use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::AtomicU32;
use bytes::Bytes;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::sync::mpsc;
use tracing::debug;
use irontide_core::Id20;
use irontide_storage::Bitfield;
use irontide_wire::{ExtHandshake, Handshake, Message};
use crate::peer_codec::{PeerReader, PeerWriter};
use crate::peer_connection::PeerConnection;
use crate::torrent_peer_handler::PeerMessageHandler;
use crate::types::{PeerCommand, PeerEvent};
/// Length in bytes of the fixed-size BitTorrent handshake; used to size the
/// buffer that receives the remote peer's handshake.
const HANDSHAKE_SIZE: usize = 68;
#[allow(dead_code, clippy::too_many_arguments)]
pub(crate) async fn run_peer(
addr: SocketAddr,
stream: impl AsyncRead + AsyncWrite + Unpin + Send + 'static,
info_hash: Id20,
our_peer_id: Id20,
our_bitfield: Bitfield,
num_pieces: u32,
event_tx: mpsc::Sender<PeerEvent>,
cmd_rx: mpsc::Receiver<PeerCommand>,
enable_dht: bool,
enable_fast: bool,
encryption_mode: irontide_wire::mse::EncryptionMode,
outbound: bool,
anonymous_mode: bool,
info_bytes: Option<Bytes>,
plugins: std::sync::Arc<Vec<Box<dyn crate::extension::ExtensionPlugin>>>,
enable_holepunch: bool,
max_message_size: usize,
have_broadcast_rx: tokio::sync::broadcast::Receiver<u32>,
in_flight: Arc<AtomicU32>,
target_depth: Arc<AtomicU32>,
read_timeout_secs: u64,
write_timeout_secs: u64,
data_contribution_timeout_secs: u64,
) -> crate::Result<()> {
use irontide_wire::mse::{self, EncryptionMode, MseStream};
let mut stream = if encryption_mode != EncryptionMode::Disabled {
let crypto_provide = match encryption_mode {
EncryptionMode::Forced => irontide_wire::mse::CRYPTO_RC4,
EncryptionMode::PreferPlaintext if outbound => irontide_wire::mse::CRYPTO_PLAINTEXT,
_ => irontide_wire::mse::CRYPTO_PLAINTEXT | irontide_wire::mse::CRYPTO_RC4,
};
let result = if outbound {
match tokio::time::timeout(
std::time::Duration::from_secs(5),
mse::handshake::negotiate_outbound(stream, &info_hash, crypto_provide),
)
.await
{
Ok(r) => r,
Err(_) => Err(irontide_wire::Error::EncryptionHandshakeFailed(
"MSE handshake timed out".into(),
)),
}
} else {
mse::handshake::negotiate_inbound(
stream,
&info_hash,
encryption_mode == EncryptionMode::Forced,
)
.await
};
match result {
Ok(r) => r.stream,
Err(e) => return Err(crate::Error::Wire(e)),
}
} else {
MseStream::plaintext(stream)
};
let hs_timeout = if read_timeout_secs > 0 {
Some(std::time::Duration::from_secs(read_timeout_secs))
} else {
None };
let mut our_hs = Handshake::new(info_hash, our_peer_id);
if enable_dht {
our_hs = our_hs.with_dht();
}
if enable_fast {
our_hs = our_hs.with_fast();
}
let hs_bytes = our_hs.to_bytes();
if let Some(timeout) = hs_timeout {
tokio::time::timeout(timeout, stream.write_all(&hs_bytes))
.await
.map_err(|_| crate::Error::Connection("BT handshake write timed out".into()))??;
} else {
stream.write_all(&hs_bytes).await?;
}
stream.flush().await?;
let mut hs_buf = [0u8; HANDSHAKE_SIZE];
if let Some(timeout) = hs_timeout {
tokio::time::timeout(timeout, stream.read_exact(&mut hs_buf))
.await
.map_err(|_| crate::Error::Connection("BT handshake read timed out".into()))??;
} else {
stream.read_exact(&mut hs_buf).await?;
}
let their_hs = Handshake::from_bytes(&hs_buf)?;
if their_hs.info_hash != info_hash {
return Err(crate::Error::Connection("info_hash mismatch".into()));
}
let _ = event_tx
.send(PeerEvent::HandshakeComplete { peer_addr: addr })
.await;
let (reader, writer) = tokio::io::split(stream);
let reader = PeerReader::new(crate::vectored_io::VectoredCompat(reader), max_message_size);
let mut writer = PeerWriter::new(writer);
let plugin_names: Vec<&str> = plugins.iter().map(|p| p.name()).collect();
let peer_supports_extensions = their_hs.supports_extensions();
if peer_supports_extensions {
let mut ext_hs = ExtHandshake::new_with_plugins(&plugin_names);
if !enable_holepunch {
ext_hs.m.remove("ut_holepunch");
}
if anonymous_mode {
ext_hs.v = None;
ext_hs.p = None;
ext_hs.reqq = None;
ext_hs.upload_only = None;
}
let payload = ext_hs.to_bytes().map_err(crate::Error::Wire)?;
writer
.send(&Message::Extended { ext_id: 0, payload })
.await?;
}
let both_support_fast = enable_fast && their_hs.supports_fast();
if both_support_fast {
if num_pieces > 0 && our_bitfield.count_ones() == num_pieces {
writer.send(&Message::HaveAll).await?;
} else if our_bitfield.count_ones() == 0 {
writer.send(&Message::HaveNone).await?;
} else {
writer
.send(&Message::Bitfield(Bytes::copy_from_slice(
our_bitfield.as_bytes(),
)))
.await?;
}
} else if our_bitfield.count_ones() > 0 {
writer
.send(&Message::Bitfield(Bytes::copy_from_slice(
our_bitfield.as_bytes(),
)))
.await?;
}
if both_support_fast && num_pieces > 0 {
let fast_set =
irontide_wire::allowed_fast_set_for_ip(&info_hash, addr.ip(), num_pieces, 10);
for index in fast_set {
writer.send(&Message::AllowedFast(index)).await?;
}
}
writer.send(&Message::Unchoke).await?;
let our_ext = ExtHandshake::new_with_plugins(&plugin_names);
let our_ut_metadata: Option<u8> = our_ext.ext_id("ut_metadata");
let our_ut_pex: Option<u8> = our_ext.ext_id("ut_pex");
let our_lt_trackers: Option<u8> = our_ext.ext_id("lt_trackers");
let our_ut_holepunch: Option<u8> = if enable_holepunch {
our_ext.ext_id("ut_holepunch")
} else {
None
};
let our_lt_donthave: Option<u8> = our_ext.ext_id("lt_donthave");
for plugin in plugins.iter() {
plugin.on_peer_connected(&info_hash, addr);
}
debug!(%addr, num_pieces, "entering main loop");
let handler = PeerMessageHandler::new(
addr,
num_pieces,
event_tx.clone(),
both_support_fast,
info_hash,
info_bytes,
plugins,
our_ut_metadata,
our_ut_pex,
our_lt_trackers,
our_ut_holepunch,
our_lt_donthave,
);
let connection = PeerConnection::new(
handler,
reader,
writer,
cmd_rx,
event_tx,
have_broadcast_rx,
in_flight,
target_depth,
read_timeout_secs,
write_timeout_secs,
data_contribution_timeout_secs,
);
connection.run().await
}
#[cfg(test)]
mod tests {
use super::*;
use bytes::Bytes;
use irontide_wire::{ExtHandshake, MetadataMessage, MetadataMessageType};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::mpsc;
/// Returns a broadcast receiver whose sender has already been dropped, for
/// tests that do not exercise the have-broadcast path.
fn dummy_broadcast_rx() -> tokio::sync::broadcast::Receiver<u32> {
    // The `_` pattern does not bind the sender, so it is dropped right away.
    let (_, receiver) = tokio::sync::broadcast::channel(16);
    receiver
}
/// Fixed loopback address (`127.0.0.1:6881`) used as the peer address in tests.
fn test_addr() -> SocketAddr {
    SocketAddr::from(([127, 0, 0, 1], 6881))
}
/// Fixed info-hash shared by both sides of every test connection.
fn test_info_hash() -> Id20 {
    let hex = "aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d";
    Id20::from_hex(hex).unwrap()
}
/// Fixed peer id used for the peer under test.
fn test_peer_id() -> Id20 {
    let hex = "0102030405060708091011121314151617181920";
    Id20::from_hex(hex).unwrap()
}
/// Fixed peer id used for the simulated remote side.
fn remote_peer_id() -> Id20 {
    let hex = "2122232425262728293031323334353637383940";
    Id20::from_hex(hex).unwrap()
}
/// Plays the remote side of the BitTorrent handshake.
///
/// Reads the handshake sent by the peer under test, parses it, replies with a
/// handshake built by `Handshake::new(info_hash, remote_id)`, and returns the
/// parsed handshake so the caller can inspect it.
async fn do_remote_handshake(
    stream: &mut (impl AsyncRead + AsyncWrite + Unpin),
    info_hash: Id20,
    remote_id: Id20,
) -> Handshake {
    let mut incoming = [0u8; HANDSHAKE_SIZE];
    stream.read_exact(&mut incoming).await.unwrap();
    let received = Handshake::from_bytes(&incoming).unwrap();

    let reply = Handshake::new(info_hash, remote_id);
    stream.write_all(&reply.to_bytes()).await.unwrap();
    stream.flush().await.unwrap();

    received
}
/// Reads one length-prefixed message: a 4-byte big-endian length followed by
/// that many payload bytes, decoded via `Message::from_payload`.
async fn read_framed_message(stream: &mut (impl AsyncRead + Unpin)) -> Message {
    let mut prefix = [0u8; 4];
    stream.read_exact(&mut prefix).await.unwrap();

    let body_len = u32::from_be_bytes(prefix) as usize;
    let mut body = vec![0u8; body_len];
    // A zero-length frame has no payload to read.
    if !body.is_empty() {
        stream.read_exact(&mut body).await.unwrap();
    }

    Message::from_payload(Bytes::from(body)).unwrap()
}
/// Serializes `msg` with `Message::to_bytes`, writes it to `stream`, and flushes.
async fn write_framed_message(stream: &mut (impl AsyncWrite + Unpin), msg: &Message) {
    let encoded = msg.to_bytes();
    stream.write_all(&encoded).await.unwrap();
    stream.flush().await.unwrap();
}
/// Receives the next event and panics unless it is `PeerEvent::HandshakeComplete`.
async fn expect_handshake_complete(event_rx: &mut mpsc::Receiver<PeerEvent>) {
    let evt = event_rx.recv().await.unwrap();
    if !matches!(evt, PeerEvent::HandshakeComplete { .. }) {
        panic!("expected HandshakeComplete event first, got: {evt:?}");
    }
}
/// Plays the remote side of the extension handshake.
///
/// Expects the next frame to be an extended message with `ext_id == 0`, parses
/// it, then answers with a remote extension handshake advertising
/// `ut_metadata = 3` and `ut_pex = 4`. Returns the handshake that was received.
async fn do_remote_ext_handshake(
    stream: &mut (impl AsyncRead + AsyncWrite + Unpin),
) -> ExtHandshake {
    let received = match read_framed_message(stream).await {
        Message::Extended { ext_id: 0, payload } => ExtHandshake::from_bytes(&payload).unwrap(),
        other => panic!("expected ext handshake, got: {other:?}"),
    };

    let mut reply = ExtHandshake::default();
    reply.m.insert("ut_metadata".into(), 3);
    reply.m.insert("ut_pex".into(), 4);
    reply.v = Some("TestPeer 1.0".into());

    let payload = reply.to_bytes().unwrap();
    write_framed_message(stream, &Message::Extended { ext_id: 0, payload }).await;

    received
}
/// The peer sends a handshake carrying our info-hash and peer id, reports
/// `HandshakeComplete`, returns `Ok` after `Shutdown`, and then emits
/// `Disconnected`.
#[tokio::test]
async fn handshake_exchange() {
    let info_hash = test_info_hash();
    let our_id = test_peer_id();
    let remote_id = remote_peer_id();
    let bitfield = Bitfield::new(10);

    let (local_io, mut remote_io) = tokio::io::duplex(8192);
    let (events_tx, mut events_rx) = mpsc::channel(32);
    let (commands_tx, commands_rx) = mpsc::channel(32);

    let peer_task = tokio::spawn(async move {
        run_peer(
            test_addr(),
            local_io,
            info_hash,
            our_id,
            bitfield,
            10,
            events_tx,
            commands_rx,
            false, // enable_dht
            false, // enable_fast
            irontide_wire::mse::EncryptionMode::Disabled,
            false,                // outbound
            false,                // anonymous_mode
            None,                 // info_bytes
            Arc::new(Vec::new()), // plugins
            true,                 // enable_holepunch
            16 * 1024 * 1024,     // max_message_size
            dummy_broadcast_rx(),
            Arc::new(AtomicU32::new(0)),   // in_flight
            Arc::new(AtomicU32::new(128)), // target_depth
            0,                             // read_timeout_secs
            0,                             // write_timeout_secs
            0,                             // data_contribution_timeout_secs
        )
        .await
    });

    let our_hs = do_remote_handshake(&mut remote_io, info_hash, remote_id).await;
    assert_eq!(our_hs.info_hash, info_hash);
    assert_eq!(our_hs.peer_id, our_id);
    assert!(our_hs.supports_extensions());
    expect_handshake_complete(&mut events_rx).await;

    commands_tx.send(PeerCommand::Shutdown).await.unwrap();
    let result = peer_task.await.unwrap();
    assert!(result.is_ok());

    let evt = events_rx.recv().await.unwrap();
    assert!(
        matches!(evt, PeerEvent::Disconnected { .. }),
        "expected Disconnected event, got: {evt:?}"
    );
}
/// A remote handshake carrying a different info-hash makes `run_peer` fail
/// with an "info_hash mismatch" error.
#[tokio::test]
async fn handshake_info_hash_mismatch() {
    let info_hash = test_info_hash();
    let wrong_hash = Id20::from_hex("bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb").unwrap();
    let our_id = test_peer_id();
    let remote_id = remote_peer_id();
    let bitfield = Bitfield::new(10);

    let (local_io, mut remote_io) = tokio::io::duplex(8192);
    // The unused halves are kept alive so the channels stay open.
    let (events_tx, _events_rx) = mpsc::channel(32);
    let (_commands_tx, commands_rx) = mpsc::channel(32);

    let peer_task = tokio::spawn(async move {
        run_peer(
            test_addr(),
            local_io,
            info_hash,
            our_id,
            bitfield,
            10,
            events_tx,
            commands_rx,
            false, // enable_dht
            false, // enable_fast
            irontide_wire::mse::EncryptionMode::Disabled,
            false,                // outbound
            false,                // anonymous_mode
            None,                 // info_bytes
            Arc::new(Vec::new()), // plugins
            true,                 // enable_holepunch
            16 * 1024 * 1024,     // max_message_size
            dummy_broadcast_rx(),
            Arc::new(AtomicU32::new(0)),   // in_flight
            Arc::new(AtomicU32::new(128)), // target_depth
            0,                             // read_timeout_secs
            0,                             // write_timeout_secs
            0,                             // data_contribution_timeout_secs
        )
        .await
    });

    // Consume the local handshake, then answer with the wrong info-hash.
    let mut incoming = [0u8; HANDSHAKE_SIZE];
    remote_io.read_exact(&mut incoming).await.unwrap();
    let bad_reply = Handshake::new(wrong_hash, remote_id).to_bytes();
    remote_io.write_all(&bad_reply).await.unwrap();
    remote_io.flush().await.unwrap();

    let result = peer_task.await.unwrap();
    assert!(result.is_err());
    let err_msg = result.unwrap_err().to_string();
    assert!(
        err_msg.contains("info_hash mismatch"),
        "expected info_hash mismatch error, got: {err_msg}"
    );
}
/// The first framed message after the handshake is an extension handshake
/// (`ext_id == 0`) that advertises `ut_metadata` and `ut_pex`.
#[tokio::test]
async fn extension_negotiation() {
    let info_hash = test_info_hash();
    let bitfield = Bitfield::new(10);

    let (local_io, mut remote_io) = tokio::io::duplex(8192);
    let (events_tx, _events_rx) = mpsc::channel(32);
    let (commands_tx, commands_rx) = mpsc::channel(32);

    let peer_task = tokio::spawn(async move {
        run_peer(
            test_addr(),
            local_io,
            info_hash,
            test_peer_id(),
            bitfield,
            10,
            events_tx,
            commands_rx,
            false, // enable_dht
            false, // enable_fast
            irontide_wire::mse::EncryptionMode::Disabled,
            false,                // outbound
            false,                // anonymous_mode
            None,                 // info_bytes
            Arc::new(Vec::new()), // plugins
            true,                 // enable_holepunch
            16 * 1024 * 1024,     // max_message_size
            dummy_broadcast_rx(),
            Arc::new(AtomicU32::new(0)),   // in_flight
            Arc::new(AtomicU32::new(128)), // target_depth
            0,                             // read_timeout_secs
            0,                             // write_timeout_secs
            0,                             // data_contribution_timeout_secs
        )
        .await
    });

    do_remote_handshake(&mut remote_io, info_hash, remote_peer_id()).await;

    let (ext_id, payload) = match read_framed_message(&mut remote_io).await {
        Message::Extended { ext_id, payload } => (ext_id, payload),
        other => panic!("expected Extended handshake, got: {other:?}"),
    };
    assert_eq!(ext_id, 0, "ext handshake should be ext_id=0");
    let ext_hs = ExtHandshake::from_bytes(&payload).unwrap();
    assert!(
        ext_hs.ext_id("ut_metadata").is_some(),
        "should advertise ut_metadata"
    );
    assert!(ext_hs.ext_id("ut_pex").is_some(), "should advertise ut_pex");

    commands_tx.send(PeerCommand::Shutdown).await.unwrap();
    let _ = peer_task.await;
}
/// With a partially filled bitfield and no fast extension, the peer sends a
/// `Bitfield` message whose bytes match our bitfield.
#[tokio::test]
async fn bitfield_exchange() {
    let info_hash = test_info_hash();

    let mut bitfield = Bitfield::new(16);
    for piece in [0, 5, 15] {
        bitfield.set(piece);
    }
    let expected_bytes = bitfield.as_bytes().to_vec();

    let (local_io, mut remote_io) = tokio::io::duplex(8192);
    let (events_tx, _events_rx) = mpsc::channel(32);
    let (commands_tx, commands_rx) = mpsc::channel(32);

    let peer_task = tokio::spawn(async move {
        run_peer(
            test_addr(),
            local_io,
            info_hash,
            test_peer_id(),
            bitfield,
            16,
            events_tx,
            commands_rx,
            false, // enable_dht
            false, // enable_fast
            irontide_wire::mse::EncryptionMode::Disabled,
            false,                // outbound
            false,                // anonymous_mode
            None,                 // info_bytes
            Arc::new(Vec::new()), // plugins
            true,                 // enable_holepunch
            16 * 1024 * 1024,     // max_message_size
            dummy_broadcast_rx(),
            Arc::new(AtomicU32::new(0)),   // in_flight
            Arc::new(AtomicU32::new(128)), // target_depth
            0,                             // read_timeout_secs
            0,                             // write_timeout_secs
            0,                             // data_contribution_timeout_secs
        )
        .await
    });

    do_remote_handshake(&mut remote_io, info_hash, remote_peer_id()).await;
    let _ext_handshake = read_framed_message(&mut remote_io).await;

    match read_framed_message(&mut remote_io).await {
        Message::Bitfield(data) => {
            assert_eq!(data.as_ref(), expected_bytes.as_slice());
        }
        other => panic!("expected Bitfield, got: {other:?}"),
    }

    commands_tx.send(PeerCommand::Shutdown).await.unwrap();
    let _ = peer_task.await;
}
/// Incoming `Choke` / `Unchoke` messages surface as `PeerEvent::PeerChoking`
/// with `choking` set to `true` / `false` respectively.
#[tokio::test]
async fn choke_unchoke_state() {
    let info_hash = test_info_hash();
    let bitfield = Bitfield::new(10);

    let (local_io, mut remote_io) = tokio::io::duplex(8192);
    let (events_tx, mut events_rx) = mpsc::channel(32);
    let (commands_tx, commands_rx) = mpsc::channel(32);

    let peer_task = tokio::spawn(async move {
        run_peer(
            test_addr(),
            local_io,
            info_hash,
            test_peer_id(),
            bitfield,
            10,
            events_tx,
            commands_rx,
            false, // enable_dht
            false, // enable_fast
            irontide_wire::mse::EncryptionMode::Disabled,
            false,                // outbound
            false,                // anonymous_mode
            None,                 // info_bytes
            Arc::new(Vec::new()), // plugins
            true,                 // enable_holepunch
            16 * 1024 * 1024,     // max_message_size
            dummy_broadcast_rx(),
            Arc::new(AtomicU32::new(0)),   // in_flight
            Arc::new(AtomicU32::new(128)), // target_depth
            0,                             // read_timeout_secs
            0,                             // write_timeout_secs
            0,                             // data_contribution_timeout_secs
        )
        .await
    });

    do_remote_handshake(&mut remote_io, info_hash, remote_peer_id()).await;
    expect_handshake_complete(&mut events_rx).await;
    let _ext_handshake = read_framed_message(&mut remote_io).await;

    // Remote chokes us.
    write_framed_message(&mut remote_io, &Message::Choke).await;
    match events_rx.recv().await.unwrap() {
        PeerEvent::PeerChoking { choking, .. } => assert!(choking),
        other => panic!("expected PeerChoking, got: {other:?}"),
    }

    // Remote unchokes us.
    write_framed_message(&mut remote_io, &Message::Unchoke).await;
    match events_rx.recv().await.unwrap() {
        PeerEvent::PeerChoking { choking, .. } => assert!(!choking),
        other => panic!("expected PeerChoking, got: {other:?}"),
    }

    commands_tx.send(PeerCommand::Shutdown).await.unwrap();
    let _ = peer_task.await;
}
/// `PeerCommand::SetInterested(true/false)` is translated into
/// `Interested` / `NotInterested` wire messages.
#[tokio::test]
async fn interested_signaling() {
    let info_hash = test_info_hash();
    let bitfield = Bitfield::new(10);

    let (local_io, mut remote_io) = tokio::io::duplex(8192);
    let (events_tx, _events_rx) = mpsc::channel(32);
    let (commands_tx, commands_rx) = mpsc::channel(32);

    let peer_task = tokio::spawn(async move {
        run_peer(
            test_addr(),
            local_io,
            info_hash,
            test_peer_id(),
            bitfield,
            10,
            events_tx,
            commands_rx,
            false, // enable_dht
            false, // enable_fast
            irontide_wire::mse::EncryptionMode::Disabled,
            false,                // outbound
            false,                // anonymous_mode
            None,                 // info_bytes
            Arc::new(Vec::new()), // plugins
            true,                 // enable_holepunch
            16 * 1024 * 1024,     // max_message_size
            dummy_broadcast_rx(),
            Arc::new(AtomicU32::new(0)),   // in_flight
            Arc::new(AtomicU32::new(128)), // target_depth
            0,                             // read_timeout_secs
            0,                             // write_timeout_secs
            0,                             // data_contribution_timeout_secs
        )
        .await
    });

    do_remote_handshake(&mut remote_io, info_hash, remote_peer_id()).await;
    let _ext_handshake = read_framed_message(&mut remote_io).await;
    let _unchoke = read_framed_message(&mut remote_io).await;

    commands_tx
        .send(PeerCommand::SetInterested(true))
        .await
        .unwrap();
    let first = read_framed_message(&mut remote_io).await;
    assert_eq!(first, Message::Interested);

    commands_tx
        .send(PeerCommand::SetInterested(false))
        .await
        .unwrap();
    let second = read_framed_message(&mut remote_io).await;
    assert_eq!(second, Message::NotInterested);

    commands_tx.send(PeerCommand::Shutdown).await.unwrap();
    let _ = peer_task.await;
}
/// A piece index published on the have-broadcast channel is forwarded to the
/// remote as a `Have` message.
#[tokio::test]
async fn have_forwarding_via_broadcast() {
    let info_hash = test_info_hash();
    let bitfield = Bitfield::new(10);

    let (local_io, mut remote_io) = tokio::io::duplex(8192);
    let (events_tx, _events_rx) = mpsc::channel(32);
    let (commands_tx, commands_rx) = mpsc::channel(32);
    let (have_tx, have_rx) = tokio::sync::broadcast::channel(16);

    let peer_task = tokio::spawn(async move {
        run_peer(
            test_addr(),
            local_io,
            info_hash,
            test_peer_id(),
            bitfield,
            10,
            events_tx,
            commands_rx,
            false, // enable_dht
            false, // enable_fast
            irontide_wire::mse::EncryptionMode::Disabled,
            false,                // outbound
            false,                // anonymous_mode
            None,                 // info_bytes
            Arc::new(Vec::new()), // plugins
            true,                 // enable_holepunch
            16 * 1024 * 1024,     // max_message_size
            have_rx,
            Arc::new(AtomicU32::new(0)),   // in_flight
            Arc::new(AtomicU32::new(128)), // target_depth
            0,                             // read_timeout_secs
            0,                             // write_timeout_secs
            0,                             // data_contribution_timeout_secs
        )
        .await
    });

    do_remote_handshake(&mut remote_io, info_hash, remote_peer_id()).await;
    let _ext_handshake = read_framed_message(&mut remote_io).await;
    let _unchoke = read_framed_message(&mut remote_io).await;

    have_tx.send(5).unwrap();
    let forwarded = read_framed_message(&mut remote_io).await;
    assert_eq!(forwarded, Message::Have { index: 5 });

    commands_tx.send(PeerCommand::Shutdown).await.unwrap();
    let _ = peer_task.await;
}
/// A `Request` command is written to the wire, and a `Piece` reply from the
/// remote surfaces as `PeerEvent::PieceData` with the same index, offset and
/// data.
#[tokio::test]
async fn request_piece_flow() {
    let info_hash = test_info_hash();
    let bitfield = Bitfield::new(10);

    let (local_io, mut remote_io) = tokio::io::duplex(8192);
    let (events_tx, mut events_rx) = mpsc::channel(32);
    let (commands_tx, commands_rx) = mpsc::channel(32);

    let peer_task = tokio::spawn(async move {
        run_peer(
            test_addr(),
            local_io,
            info_hash,
            test_peer_id(),
            bitfield,
            10,
            events_tx,
            commands_rx,
            false, // enable_dht
            false, // enable_fast
            irontide_wire::mse::EncryptionMode::Disabled,
            false,                // outbound
            false,                // anonymous_mode
            None,                 // info_bytes
            Arc::new(Vec::new()), // plugins
            true,                 // enable_holepunch
            16 * 1024 * 1024,     // max_message_size
            dummy_broadcast_rx(),
            Arc::new(AtomicU32::new(0)),   // in_flight
            Arc::new(AtomicU32::new(128)), // target_depth
            0,                             // read_timeout_secs
            0,                             // write_timeout_secs
            0,                             // data_contribution_timeout_secs
        )
        .await
    });

    do_remote_handshake(&mut remote_io, info_hash, remote_peer_id()).await;
    expect_handshake_complete(&mut events_rx).await;
    let _ext_handshake = read_framed_message(&mut remote_io).await;
    let _unchoke = read_framed_message(&mut remote_io).await;

    // Outgoing request.
    let request_cmd = PeerCommand::Request {
        index: 0,
        begin: 0,
        length: 16384,
    };
    commands_tx.send(request_cmd).await.unwrap();
    let on_wire = read_framed_message(&mut remote_io).await;
    assert_eq!(
        on_wire,
        Message::Request {
            index: 0,
            begin: 0,
            length: 16384,
        }
    );

    // Incoming piece.
    let piece_data = Bytes::from_static(b"test piece data!");
    let piece_msg = Message::Piece {
        index: 0,
        begin: 0,
        data_0: piece_data.clone(),
        data_1: Bytes::new(),
    };
    write_framed_message(&mut remote_io, &piece_msg).await;

    match events_rx.recv().await.unwrap() {
        PeerEvent::PieceData {
            index, begin, data, ..
        } => {
            assert_eq!(index, 0);
            assert_eq!(begin, 0);
            assert_eq!(data, piece_data);
        }
        other => panic!("expected PieceData, got: {other:?}"),
    }

    commands_tx.send(PeerCommand::Shutdown).await.unwrap();
    let _ = peer_task.await;
}
/// A `KeepAlive` from the remote does not disturb the connection: a command
/// issued afterwards is still processed and reaches the wire.
#[tokio::test]
async fn keepalive() {
    let info_hash = test_info_hash();
    let bitfield = Bitfield::new(10);

    let (local_io, mut remote_io) = tokio::io::duplex(8192);
    let (events_tx, _events_rx) = mpsc::channel(32);
    let (commands_tx, commands_rx) = mpsc::channel(32);

    let peer_task = tokio::spawn(async move {
        run_peer(
            test_addr(),
            local_io,
            info_hash,
            test_peer_id(),
            bitfield,
            10,
            events_tx,
            commands_rx,
            false, // enable_dht
            false, // enable_fast
            irontide_wire::mse::EncryptionMode::Disabled,
            false,                // outbound
            false,                // anonymous_mode
            None,                 // info_bytes
            Arc::new(Vec::new()), // plugins
            true,                 // enable_holepunch
            16 * 1024 * 1024,     // max_message_size
            dummy_broadcast_rx(),
            Arc::new(AtomicU32::new(0)),   // in_flight
            Arc::new(AtomicU32::new(128)), // target_depth
            0,                             // read_timeout_secs
            0,                             // write_timeout_secs
            0,                             // data_contribution_timeout_secs
        )
        .await
    });

    do_remote_handshake(&mut remote_io, info_hash, remote_peer_id()).await;
    let _ext_handshake = read_framed_message(&mut remote_io).await;
    let _unchoke = read_framed_message(&mut remote_io).await;

    write_framed_message(&mut remote_io, &Message::KeepAlive).await;

    // The connection must still be responsive after the keep-alive.
    commands_tx
        .send(PeerCommand::SetInterested(true))
        .await
        .unwrap();
    let reply = read_framed_message(&mut remote_io).await;
    assert_eq!(reply, Message::Interested);

    commands_tx.send(PeerCommand::Shutdown).await.unwrap();
    let _ = peer_task.await;
}
/// Dropping the remote end of the stream after the handshake results in a
/// `PeerEvent::Disconnected` event.
#[tokio::test]
async fn graceful_disconnect() {
    let info_hash = test_info_hash();
    let bitfield = Bitfield::new(10);

    let (local_io, mut remote_io) = tokio::io::duplex(8192);
    let (events_tx, mut events_rx) = mpsc::channel(32);
    // Keep the command sender alive; no command is sent in this test.
    let (_commands_tx, commands_rx) = mpsc::channel(32);

    let peer_task = tokio::spawn(async move {
        run_peer(
            test_addr(),
            local_io,
            info_hash,
            test_peer_id(),
            bitfield,
            10,
            events_tx,
            commands_rx,
            false, // enable_dht
            false, // enable_fast
            irontide_wire::mse::EncryptionMode::Disabled,
            false,                // outbound
            false,                // anonymous_mode
            None,                 // info_bytes
            Arc::new(Vec::new()), // plugins
            true,                 // enable_holepunch
            16 * 1024 * 1024,     // max_message_size
            dummy_broadcast_rx(),
            Arc::new(AtomicU32::new(0)),   // in_flight
            Arc::new(AtomicU32::new(128)), // target_depth
            0,                             // read_timeout_secs
            0,                             // write_timeout_secs
            0,                             // data_contribution_timeout_secs
        )
        .await
    });

    do_remote_handshake(&mut remote_io, info_hash, remote_peer_id()).await;
    expect_handshake_complete(&mut events_rx).await;
    let _ext_handshake = read_framed_message(&mut remote_io).await;

    // Close the remote side of the connection.
    drop(remote_io);

    let evt = events_rx.recv().await.unwrap();
    assert!(
        matches!(evt, PeerEvent::Disconnected { .. }),
        "expected Disconnected, got: {evt:?}"
    );
    let _ = peer_task.await;
}
/// After the remote advertises `ut_metadata = 3`, a `RequestMetadata` command
/// is sent as an extended message with `ext_id == 3` carrying a metadata
/// request for the given piece.
#[tokio::test]
async fn metadata_request_response() {
    let info_hash = test_info_hash();
    let bitfield = Bitfield::new(10);

    let (local_io, mut remote_io) = tokio::io::duplex(8192);
    let (events_tx, mut events_rx) = mpsc::channel(32);
    let (commands_tx, commands_rx) = mpsc::channel(32);

    let peer_task = tokio::spawn(async move {
        run_peer(
            test_addr(),
            local_io,
            info_hash,
            test_peer_id(),
            bitfield,
            10,
            events_tx,
            commands_rx,
            false, // enable_dht
            false, // enable_fast
            irontide_wire::mse::EncryptionMode::Disabled,
            false,                // outbound
            false,                // anonymous_mode
            None,                 // info_bytes
            Arc::new(Vec::new()), // plugins
            true,                 // enable_holepunch
            16 * 1024 * 1024,     // max_message_size
            dummy_broadcast_rx(),
            Arc::new(AtomicU32::new(0)),   // in_flight
            Arc::new(AtomicU32::new(128)), // target_depth
            0,                             // read_timeout_secs
            0,                             // write_timeout_secs
            0,                             // data_contribution_timeout_secs
        )
        .await
    });

    do_remote_handshake(&mut remote_io, info_hash, remote_peer_id()).await;
    expect_handshake_complete(&mut events_rx).await;
    let _local_ext = do_remote_ext_handshake(&mut remote_io).await;
    let _unchoke = read_framed_message(&mut remote_io).await;

    let evt = events_rx.recv().await.unwrap();
    assert!(matches!(evt, PeerEvent::ExtHandshake { .. }));

    commands_tx
        .send(PeerCommand::RequestMetadata { piece: 0 })
        .await
        .unwrap();

    let (ext_id, payload) = match read_framed_message(&mut remote_io).await {
        Message::Extended { ext_id, payload } => (ext_id, payload),
        other => panic!("expected Extended metadata request, got: {other:?}"),
    };
    assert_eq!(ext_id, 3, "should use remote's ut_metadata ID");
    let meta = MetadataMessage::from_bytes(&payload).unwrap();
    assert_eq!(meta.msg_type, MetadataMessageType::Request);
    assert_eq!(meta.piece, 0);

    commands_tx.send(PeerCommand::Shutdown).await.unwrap();
    let _ = peer_task.await;
}
/// A PEX message received on extended id 2 is surfaced as
/// `PeerEvent::PexPeers` containing the decoded peer address.
#[tokio::test]
async fn pex_message_handling() {
    let info_hash = test_info_hash();
    let bitfield = Bitfield::new(10);

    let (local_io, mut remote_io) = tokio::io::duplex(8192);
    let (events_tx, mut events_rx) = mpsc::channel(32);
    let (commands_tx, commands_rx) = mpsc::channel(32);

    let peer_task = tokio::spawn(async move {
        run_peer(
            test_addr(),
            local_io,
            info_hash,
            test_peer_id(),
            bitfield,
            10,
            events_tx,
            commands_rx,
            false, // enable_dht
            false, // enable_fast
            irontide_wire::mse::EncryptionMode::Disabled,
            false,                // outbound
            false,                // anonymous_mode
            None,                 // info_bytes
            Arc::new(Vec::new()), // plugins
            true,                 // enable_holepunch
            16 * 1024 * 1024,     // max_message_size
            dummy_broadcast_rx(),
            Arc::new(AtomicU32::new(0)),   // in_flight
            Arc::new(AtomicU32::new(128)), // target_depth
            0,                             // read_timeout_secs
            0,                             // write_timeout_secs
            0,                             // data_contribution_timeout_secs
        )
        .await
    });

    do_remote_handshake(&mut remote_io, info_hash, remote_peer_id()).await;
    expect_handshake_complete(&mut events_rx).await;
    let _local_ext = do_remote_ext_handshake(&mut remote_io).await;

    let evt = events_rx.recv().await.unwrap();
    assert!(matches!(evt, PeerEvent::ExtHandshake { .. }));

    // One compact entry: 10.0.0.1 followed by port 0x1F90 (8080).
    let pex = crate::pex::PexMessage {
        added: vec![10, 0, 0, 1, 0x1F, 0x90],
        added_flags: vec![0x00],
        ..Default::default()
    };
    let pex_frame = Message::Extended {
        ext_id: 2,
        payload: pex.to_bytes().unwrap(),
    };
    write_framed_message(&mut remote_io, &pex_frame).await;

    match events_rx.recv().await.unwrap() {
        PeerEvent::PexPeers { new_peers } => {
            assert_eq!(new_peers.len(), 1);
            assert_eq!(new_peers[0].to_string(), "10.0.0.1:8080");
        }
        other => panic!("expected PexPeers, got: {other:?}"),
    }

    commands_tx.send(PeerCommand::Shutdown).await.unwrap();
    let _ = peer_task.await;
}
/// Same as the plain remote handshake helper, except that the reply is built
/// with `with_fast()` so the remote advertises the fast extension.
async fn do_remote_handshake_fast(
    stream: &mut (impl AsyncRead + AsyncWrite + Unpin),
    info_hash: Id20,
    remote_id: Id20,
) -> Handshake {
    let mut incoming = [0u8; HANDSHAKE_SIZE];
    stream.read_exact(&mut incoming).await.unwrap();
    let received = Handshake::from_bytes(&incoming).unwrap();

    let reply = Handshake::new(info_hash, remote_id).with_fast();
    stream.write_all(&reply.to_bytes()).await.unwrap();
    stream.flush().await.unwrap();

    received
}
/// With the fast extension negotiated and every piece present, the peer
/// announces its pieces with `HaveAll`.
#[tokio::test]
async fn fast_extension_have_all_sent_when_complete() {
    let info_hash = test_info_hash();
    let num_pieces = 10u32;

    let mut bitfield = Bitfield::new(num_pieces);
    (0..num_pieces).for_each(|piece| {
        bitfield.set(piece);
    });

    let (local_io, mut remote_io) = tokio::io::duplex(8192);
    let (events_tx, _events_rx) = mpsc::channel(32);
    let (commands_tx, commands_rx) = mpsc::channel(32);

    let peer_task = tokio::spawn(async move {
        run_peer(
            test_addr(),
            local_io,
            info_hash,
            test_peer_id(),
            bitfield,
            num_pieces,
            events_tx,
            commands_rx,
            false, // enable_dht
            true,  // enable_fast
            irontide_wire::mse::EncryptionMode::Disabled,
            false,                // outbound
            false,                // anonymous_mode
            None,                 // info_bytes
            Arc::new(Vec::new()), // plugins
            true,                 // enable_holepunch
            16 * 1024 * 1024,     // max_message_size
            dummy_broadcast_rx(),
            Arc::new(AtomicU32::new(0)),   // in_flight
            Arc::new(AtomicU32::new(128)), // target_depth
            0,                             // read_timeout_secs
            0,                             // write_timeout_secs
            0,                             // data_contribution_timeout_secs
        )
        .await
    });

    let our_hs = do_remote_handshake_fast(&mut remote_io, info_hash, remote_peer_id()).await;
    assert!(
        our_hs.supports_fast(),
        "our handshake should advertise fast"
    );

    let _ext_handshake = read_framed_message(&mut remote_io).await;
    let announcement = read_framed_message(&mut remote_io).await;
    assert_eq!(announcement, Message::HaveAll);

    commands_tx.send(PeerCommand::Shutdown).await.unwrap();
    let _ = peer_task.await;
}
/// With the fast extension negotiated and no pieces present, the peer
/// announces its pieces with `HaveNone`.
#[tokio::test]
async fn fast_extension_have_none_sent_when_empty() {
    let info_hash = test_info_hash();
    let num_pieces = 10u32;
    let bitfield = Bitfield::new(num_pieces);

    let (local_io, mut remote_io) = tokio::io::duplex(8192);
    let (events_tx, _events_rx) = mpsc::channel(32);
    let (commands_tx, commands_rx) = mpsc::channel(32);

    let peer_task = tokio::spawn(async move {
        run_peer(
            test_addr(),
            local_io,
            info_hash,
            test_peer_id(),
            bitfield,
            num_pieces,
            events_tx,
            commands_rx,
            false, // enable_dht
            true,  // enable_fast
            irontide_wire::mse::EncryptionMode::Disabled,
            false,                // outbound
            false,                // anonymous_mode
            None,                 // info_bytes
            Arc::new(Vec::new()), // plugins
            true,                 // enable_holepunch
            16 * 1024 * 1024,     // max_message_size
            dummy_broadcast_rx(),
            Arc::new(AtomicU32::new(0)),   // in_flight
            Arc::new(AtomicU32::new(128)), // target_depth
            0,                             // read_timeout_secs
            0,                             // write_timeout_secs
            0,                             // data_contribution_timeout_secs
        )
        .await
    });

    let our_hs = do_remote_handshake_fast(&mut remote_io, info_hash, remote_peer_id()).await;
    assert!(our_hs.supports_fast());

    let _ext_handshake = read_framed_message(&mut remote_io).await;
    let announcement = read_framed_message(&mut remote_io).await;
    assert_eq!(announcement, Message::HaveNone);

    commands_tx.send(PeerCommand::Shutdown).await.unwrap();
    let _ = peer_task.await;
}
/// With the fast extension negotiated, the peer follows `HaveNone` with ten
/// `AllowedFast` messages whose indices are in range and pairwise distinct.
#[tokio::test]
async fn fast_extension_sends_allowed_fast_on_connect() {
    let info_hash = test_info_hash();
    let num_pieces = 100u32;
    let bitfield = Bitfield::new(num_pieces);

    let (local_io, mut remote_io) = tokio::io::duplex(8192);
    let (events_tx, _events_rx) = mpsc::channel(32);
    let (commands_tx, commands_rx) = mpsc::channel(32);

    let peer_task = tokio::spawn(async move {
        run_peer(
            test_addr(),
            local_io,
            info_hash,
            test_peer_id(),
            bitfield,
            num_pieces,
            events_tx,
            commands_rx,
            false, // enable_dht
            true,  // enable_fast
            irontide_wire::mse::EncryptionMode::Disabled,
            false,                // outbound
            false,                // anonymous_mode
            None,                 // info_bytes
            Arc::new(Vec::new()), // plugins
            true,                 // enable_holepunch
            16 * 1024 * 1024,     // max_message_size
            dummy_broadcast_rx(),
            Arc::new(AtomicU32::new(0)),   // in_flight
            Arc::new(AtomicU32::new(128)), // target_depth
            0,                             // read_timeout_secs
            0,                             // write_timeout_secs
            0,                             // data_contribution_timeout_secs
        )
        .await
    });

    let _our_hs = do_remote_handshake_fast(&mut remote_io, info_hash, remote_peer_id()).await;
    let _ext_handshake = read_framed_message(&mut remote_io).await;

    let announcement = read_framed_message(&mut remote_io).await;
    assert_eq!(announcement, Message::HaveNone);

    let mut fast_indices = std::collections::HashSet::new();
    for _ in 0..10 {
        match read_framed_message(&mut remote_io).await {
            Message::AllowedFast(index) => {
                assert!(index < num_pieces, "AllowedFast index out of range");
                fast_indices.insert(index);
            }
            other => panic!("expected AllowedFast, got: {other:?}"),
        }
    }
    assert_eq!(
        fast_indices.len(),
        10,
        "should have 10 unique AllowedFast indices"
    );

    commands_tx.send(PeerCommand::Shutdown).await.unwrap();
    let _ = peer_task.await;
}
#[tokio::test]
async fn no_allowed_fast_without_fast_extension() {
let (client_stream, mut server_stream) = tokio::io::duplex(8192);
let (event_tx, _event_rx) = mpsc::channel(32);
let (cmd_tx, cmd_rx) = mpsc::channel(32);
let info_hash = test_info_hash();
let num_pieces = 100u32;
let bitfield = Bitfield::new(num_pieces);
let handle = tokio::spawn(async move {
run_peer(
test_addr(),
client_stream,
info_hash,
test_peer_id(),
bitfield,
num_pieces,
event_tx,
cmd_rx,
false,
false, irontide_wire::mse::EncryptionMode::Disabled,
false, false, None, std::sync::Arc::new(Vec::new()), true, 16 * 1024 * 1024, dummy_broadcast_rx(),
Arc::new(AtomicU32::new(0)),
Arc::new(AtomicU32::new(128)),
0, 0, 0, )
.await
});
do_remote_handshake(&mut server_stream, info_hash, remote_peer_id()).await;
let _ext_hs_msg = read_framed_message(&mut server_stream).await;
cmd_tx.send(PeerCommand::Shutdown).await.unwrap();
let result = handle.await.unwrap();
assert!(result.is_ok());
}
/// A `SendPiece` command is written to the wire as a `Piece` message with the
/// same index and offset, all data in `data_0`, and an empty `data_1`.
#[tokio::test]
async fn send_piece_command_sends_piece_message() {
    let info_hash = test_info_hash();
    let bitfield = Bitfield::new(10);

    let (local_io, mut remote_io) = tokio::io::duplex(65536);
    let (events_tx, _events_rx) = mpsc::channel(32);
    let (commands_tx, commands_rx) = mpsc::channel(32);

    let peer_task = tokio::spawn(async move {
        run_peer(
            test_addr(),
            local_io,
            info_hash,
            test_peer_id(),
            bitfield,
            10,
            events_tx,
            commands_rx,
            false, // enable_dht
            false, // enable_fast
            irontide_wire::mse::EncryptionMode::Disabled,
            false,                // outbound
            false,                // anonymous_mode
            None,                 // info_bytes
            Arc::new(Vec::new()), // plugins
            true,                 // enable_holepunch
            16 * 1024 * 1024,     // max_message_size
            dummy_broadcast_rx(),
            Arc::new(AtomicU32::new(0)),   // in_flight
            Arc::new(AtomicU32::new(128)), // target_depth
            0,                             // read_timeout_secs
            0,                             // write_timeout_secs
            0,                             // data_contribution_timeout_secs
        )
        .await
    });

    do_remote_handshake(&mut remote_io, info_hash, remote_peer_id()).await;
    let _ext_handshake = read_framed_message(&mut remote_io).await;
    let _unchoke = read_framed_message(&mut remote_io).await;

    let piece_data = Bytes::from(vec![0xAB; 16384]);
    let send_cmd = PeerCommand::SendPiece {
        index: 0,
        begin: 0,
        data: piece_data.clone(),
    };
    commands_tx.send(send_cmd).await.unwrap();

    match read_framed_message(&mut remote_io).await {
        Message::Piece {
            index,
            begin,
            data_0,
            data_1,
        } => {
            assert_eq!(index, 0);
            assert_eq!(begin, 0);
            assert!(data_1.is_empty());
            assert_eq!(data_0, piece_data);
        }
        other => panic!("expected Piece, got: {other:?}"),
    }

    commands_tx.send(PeerCommand::Shutdown).await.unwrap();
    let _ = peer_task.await;
}
/// When `info_bytes` is supplied, a metadata request from the remote is
/// answered with a `Data` message on the remote's `ut_metadata` id, carrying
/// the full info dictionary and its total size.
#[tokio::test]
async fn serves_metadata_request() {
    let info_hash = test_info_hash();
    let bitfield = Bitfield::new(10);
    let info_raw = b"d4:name4:test12:piece lengthi16384e6:pieces20:AAAAAAAAAAAAAAAAAAAAe";
    let info_bytes = Bytes::from_static(info_raw);

    let (local_io, mut remote_io) = tokio::io::duplex(65536);
    let (events_tx, _events_rx) = mpsc::channel(16);
    let (commands_tx, commands_rx) = mpsc::channel(16);

    let peer_task = tokio::spawn(async move {
        run_peer(
            test_addr(),
            local_io,
            info_hash,
            test_peer_id(),
            bitfield,
            10,
            events_tx,
            commands_rx,
            false, // enable_dht
            false, // enable_fast
            irontide_wire::mse::EncryptionMode::Disabled,
            false,                // outbound
            false,                // anonymous_mode
            Some(info_bytes),     // info_bytes
            Arc::new(Vec::new()), // plugins
            true,                 // enable_holepunch
            16 * 1024 * 1024,     // max_message_size
            dummy_broadcast_rx(),
            Arc::new(AtomicU32::new(0)),   // in_flight
            Arc::new(AtomicU32::new(128)), // target_depth
            0,                             // read_timeout_secs
            0,                             // write_timeout_secs
            0,                             // data_contribution_timeout_secs
        )
        .await
    });

    do_remote_handshake(&mut remote_io, info_hash, remote_peer_id()).await;

    // Learn which id the peer under test assigned to ut_metadata.
    let ext_hs_msg = read_framed_message(&mut remote_io).await;
    let our_ut_metadata_id = match &ext_hs_msg {
        Message::Extended { ext_id: 0, payload } => ExtHandshake::from_bytes(payload)
            .unwrap()
            .ext_id("ut_metadata")
            .unwrap(),
        other => panic!("expected ext handshake, got: {other:?}"),
    };
    let _unchoke = read_framed_message(&mut remote_io).await;

    // Advertise the remote's own ut_metadata id (5).
    let mut remote_ext = ExtHandshake::new();
    remote_ext.m.insert("ut_metadata".into(), 5);
    let remote_ext_frame = Message::Extended {
        ext_id: 0,
        payload: remote_ext.to_bytes().unwrap(),
    };
    write_framed_message(&mut remote_io, &remote_ext_frame).await;

    // Ask for metadata piece 0 using the id the peer under test advertised.
    let request_frame = Message::Extended {
        ext_id: our_ut_metadata_id,
        payload: MetadataMessage::request(0).to_bytes().unwrap(),
    };
    write_framed_message(&mut remote_io, &request_frame).await;

    match read_framed_message(&mut remote_io).await {
        Message::Extended { ext_id, payload } => {
            assert_eq!(ext_id, 5, "should use remote's ut_metadata id");
            let meta_msg = MetadataMessage::from_bytes(&payload).unwrap();
            assert_eq!(meta_msg.msg_type, MetadataMessageType::Data);
            assert_eq!(meta_msg.piece, 0);
            assert_eq!(meta_msg.total_size, Some(info_raw.len() as u64));
            assert_eq!(meta_msg.data.as_deref(), Some(info_raw.as_ref()));
        }
        other => panic!("expected Extended data response, got: {other:?}"),
    }

    commands_tx.send(PeerCommand::Shutdown).await.unwrap();
    let _ = peer_task.await;
}
/// Test-only extension plugin that reports the name `ut_echo` and answers
/// every message with a copy of the payload it received.
struct TestEchoPlugin;

impl crate::extension::ExtensionPlugin for TestEchoPlugin {
    /// Name under which this plugin identifies itself.
    fn name(&self) -> &str {
        "ut_echo"
    }

    /// Returns `Some` with an owned copy of `payload`; the info hash and the
    /// peer address are not consulted.
    fn on_message(
        &self,
        _info_hash: &Id20,
        _peer_addr: SocketAddr,
        payload: &[u8],
    ) -> Option<Vec<u8>> {
        let echoed = Vec::from(payload);
        Some(echoed)
    }
}
/// With one plugin registered, our extension handshake advertises it at id 10
/// while `ut_metadata` stays at id 1.
#[tokio::test]
async fn plugin_advertised_in_ext_handshake() {
    let (peer_io, mut remote_io) = tokio::io::duplex(8192);
    let (event_tx, _event_rx) = mpsc::channel(32);
    let (cmd_tx, cmd_rx) = mpsc::channel(32);
    let info_hash = test_info_hash();
    let our_bitfield = Bitfield::new(10);

    // Box the plugin as a trait object first so the Vec has the element type
    // `run_peer` expects.
    let echo: Box<dyn crate::extension::ExtensionPlugin> = Box::new(TestEchoPlugin);
    let plugins = Arc::new(vec![echo]);

    let peer_task = tokio::spawn(async move {
        run_peer(
            test_addr(),
            peer_io,
            info_hash,
            test_peer_id(),
            our_bitfield,
            10, // num_pieces
            event_tx,
            cmd_rx,
            false, // enable_dht
            false, // enable_fast
            irontide_wire::mse::EncryptionMode::Disabled,
            false, // outbound
            false, // anonymous_mode
            None,  // info_bytes
            plugins,
            true,             // enable_holepunch
            16 * 1024 * 1024, // max_message_size
            dummy_broadcast_rx(),
            Arc::new(AtomicU32::new(0)),   // in_flight
            Arc::new(AtomicU32::new(128)), // target_depth
            0, // read_timeout_secs
            0, // write_timeout_secs
            0, // data_contribution_timeout_secs
        )
        .await
    });

    do_remote_handshake(&mut remote_io, info_hash, remote_peer_id()).await;

    // First framed message from the peer task must be the extension handshake.
    let first_frame = read_framed_message(&mut remote_io).await;
    match first_frame {
        Message::Extended { ext_id: 0, payload } => {
            let advertised = ExtHandshake::from_bytes(&payload).unwrap();
            assert_eq!(
                advertised.ext_id("ut_echo"),
                Some(10),
                "plugin should be at ID 10"
            );
            assert_eq!(advertised.ext_id("ut_metadata"), Some(1), "built-in unchanged");
        }
        other => panic!("expected ext handshake, got: {other:?}"),
    }

    cmd_tx.send(PeerCommand::Shutdown).await.unwrap();
    let _ = peer_task.await;
}
/// A message sent to our `ut_echo` id is dispatched to the plugin, and the
/// plugin's reply goes back under the id the remote advertised for `ut_echo`.
#[tokio::test]
async fn plugin_message_echo_dispatch() {
    let (peer_io, mut remote_io) = tokio::io::duplex(8192);
    let (event_tx, _event_rx) = mpsc::channel(32);
    let (cmd_tx, cmd_rx) = mpsc::channel(32);
    let info_hash = test_info_hash();
    let our_bitfield = Bitfield::new(10);

    let echo: Box<dyn crate::extension::ExtensionPlugin> = Box::new(TestEchoPlugin);
    let plugins = Arc::new(vec![echo]);

    let peer_task = tokio::spawn(async move {
        run_peer(
            test_addr(),
            peer_io,
            info_hash,
            test_peer_id(),
            our_bitfield,
            10, // num_pieces
            event_tx,
            cmd_rx,
            false, // enable_dht
            false, // enable_fast
            irontide_wire::mse::EncryptionMode::Disabled,
            false, // outbound
            false, // anonymous_mode
            None,  // info_bytes
            plugins,
            true,             // enable_holepunch
            16 * 1024 * 1024, // max_message_size
            dummy_broadcast_rx(),
            Arc::new(AtomicU32::new(0)),   // in_flight
            Arc::new(AtomicU32::new(128)), // target_depth
            0, // read_timeout_secs
            0, // write_timeout_secs
            0, // data_contribution_timeout_secs
        )
        .await
    });

    do_remote_handshake(&mut remote_io, info_hash, remote_peer_id()).await;

    // Our extension handshake tells the remote which id reaches the echo plugin.
    let first_frame = read_framed_message(&mut remote_io).await;
    let our_ext_hs = match &first_frame {
        Message::Extended { ext_id: 0, payload } => ExtHandshake::from_bytes(payload).unwrap(),
        other => panic!("expected ext handshake, got: {other:?}"),
    };
    let our_ut_echo_id = our_ext_hs.ext_id("ut_echo").unwrap();
    assert_eq!(our_ut_echo_id, 10);

    // The next frame is read and discarded (the binding name indicates an unchoke).
    let _unchoke = read_framed_message(&mut remote_io).await;

    // Remote advertises its own ids; replies to ut_echo must be sent to 42.
    let mut remote_ext = ExtHandshake::default();
    remote_ext.m.insert("ut_metadata".into(), 3);
    remote_ext.m.insert("ut_pex".into(), 4);
    remote_ext.m.insert("ut_echo".into(), 42);
    let remote_handshake = Message::Extended {
        ext_id: 0,
        payload: remote_ext.to_bytes().unwrap(),
    };
    write_framed_message(&mut remote_io, &remote_handshake).await;

    // Address the plugin through the id we advertised.
    let echo_request = Message::Extended {
        ext_id: our_ut_echo_id,
        payload: Bytes::from_static(b"hello plugin"),
    };
    write_framed_message(&mut remote_io, &echo_request).await;

    let reply = read_framed_message(&mut remote_io).await;
    match reply {
        Message::Extended { ext_id, payload } => {
            assert_eq!(ext_id, 42, "response should use peer's ut_echo id");
            assert_eq!(payload.as_ref(), b"hello plugin");
        }
        other => panic!("expected echo response, got: {other:?}"),
    }

    cmd_tx.send(PeerCommand::Shutdown).await.unwrap();
    let _ = peer_task.await;
}
/// An incoming holepunch rendezvous message is surfaced as
/// `PeerEvent::HolepunchRendezvous` carrying the sender address and the target.
#[tokio::test]
async fn holepunch_rendezvous_event() {
    use irontide_wire::HolepunchMessage;

    let (peer_io, mut remote_io) = tokio::io::duplex(8192);
    let (event_tx, mut events) = mpsc::channel(32);
    let (cmd_tx, cmd_rx) = mpsc::channel(32);
    let info_hash = test_info_hash();
    let our_bitfield = Bitfield::new(10);

    let peer_task = tokio::spawn(async move {
        run_peer(
            test_addr(),
            peer_io,
            info_hash,
            test_peer_id(),
            our_bitfield,
            10, // num_pieces
            event_tx,
            cmd_rx,
            false, // enable_dht
            false, // enable_fast
            irontide_wire::mse::EncryptionMode::Disabled,
            false, // outbound
            false, // anonymous_mode
            None,  // info_bytes
            Arc::new(Vec::new()), // no extension plugins
            true,                 // enable_holepunch
            16 * 1024 * 1024,     // max_message_size
            dummy_broadcast_rx(),
            Arc::new(AtomicU32::new(0)),   // in_flight
            Arc::new(AtomicU32::new(128)), // target_depth
            0, // read_timeout_secs
            0, // write_timeout_secs
            0, // data_contribution_timeout_secs
        )
        .await
    });

    do_remote_handshake(&mut remote_io, info_hash, remote_peer_id()).await;
    expect_handshake_complete(&mut events).await;

    // Learn which id we advertised for ut_holepunch from our extension handshake.
    let first_frame = read_framed_message(&mut remote_io).await;
    let our_ut_holepunch_id = match &first_frame {
        Message::Extended { ext_id: 0, payload } => ExtHandshake::from_bytes(payload)
            .unwrap()
            .ext_id("ut_holepunch")
            .unwrap(),
        other => panic!("expected ext handshake, got: {other:?}"),
    };

    // Send the remote's extension handshake and wait for the matching event.
    let mut remote_ext = ExtHandshake::default();
    remote_ext.m.insert("ut_metadata".into(), 3);
    let remote_handshake = Message::Extended {
        ext_id: 0,
        payload: remote_ext.to_bytes().unwrap(),
    };
    write_framed_message(&mut remote_io, &remote_handshake).await;
    let handshake_event = events.recv().await.unwrap();
    assert!(matches!(handshake_event, PeerEvent::ExtHandshake { .. }));

    // Deliver a rendezvous request for `target` on our holepunch id.
    let target = "10.0.0.1:8080".parse::<SocketAddr>().unwrap();
    let rendezvous = Message::Extended {
        ext_id: our_ut_holepunch_id,
        payload: HolepunchMessage::rendezvous(target).to_bytes(),
    };
    write_framed_message(&mut remote_io, &rendezvous).await;

    let holepunch_event = events.recv().await.unwrap();
    match holepunch_event {
        PeerEvent::HolepunchRendezvous {
            peer_addr,
            target: reported_target,
        } => {
            assert_eq!(peer_addr, test_addr());
            assert_eq!(reported_target, target);
        }
        other => panic!("expected HolepunchRendezvous, got: {other:?}"),
    }

    cmd_tx.send(PeerCommand::Shutdown).await.unwrap();
    let _ = peer_task.await;
}
/// An incoming holepunch connect message is surfaced as
/// `PeerEvent::HolepunchConnect` carrying the sender address and the target.
#[tokio::test]
async fn holepunch_connect_event() {
    use irontide_wire::HolepunchMessage;

    let (peer_io, mut remote_io) = tokio::io::duplex(8192);
    let (event_tx, mut events) = mpsc::channel(32);
    let (cmd_tx, cmd_rx) = mpsc::channel(32);
    let info_hash = test_info_hash();
    let our_bitfield = Bitfield::new(10);

    let peer_task = tokio::spawn(async move {
        run_peer(
            test_addr(),
            peer_io,
            info_hash,
            test_peer_id(),
            our_bitfield,
            10, // num_pieces
            event_tx,
            cmd_rx,
            false, // enable_dht
            false, // enable_fast
            irontide_wire::mse::EncryptionMode::Disabled,
            false, // outbound
            false, // anonymous_mode
            None,  // info_bytes
            Arc::new(Vec::new()), // no extension plugins
            true,                 // enable_holepunch
            16 * 1024 * 1024,     // max_message_size
            dummy_broadcast_rx(),
            Arc::new(AtomicU32::new(0)),   // in_flight
            Arc::new(AtomicU32::new(128)), // target_depth
            0, // read_timeout_secs
            0, // write_timeout_secs
            0, // data_contribution_timeout_secs
        )
        .await
    });

    do_remote_handshake(&mut remote_io, info_hash, remote_peer_id()).await;
    expect_handshake_complete(&mut events).await;

    // Learn which id we advertised for ut_holepunch from our extension handshake.
    let first_frame = read_framed_message(&mut remote_io).await;
    let our_ut_holepunch_id = match &first_frame {
        Message::Extended { ext_id: 0, payload } => {
            let ours = ExtHandshake::from_bytes(payload).unwrap();
            ours.ext_id("ut_holepunch").unwrap()
        }
        other => panic!("expected ext handshake, got: {other:?}"),
    };

    // Send the remote's extension handshake and wait for the matching event.
    let mut remote_ext = ExtHandshake::default();
    remote_ext.m.insert("ut_metadata".into(), 3);
    let remote_handshake = Message::Extended {
        ext_id: 0,
        payload: remote_ext.to_bytes().unwrap(),
    };
    write_framed_message(&mut remote_io, &remote_handshake).await;
    let handshake_event = events.recv().await.unwrap();
    assert!(matches!(handshake_event, PeerEvent::ExtHandshake { .. }));

    // Deliver a connect instruction for `target` on our holepunch id.
    let target = "192.168.1.100:6881".parse::<SocketAddr>().unwrap();
    let connect = Message::Extended {
        ext_id: our_ut_holepunch_id,
        payload: HolepunchMessage::connect(target).to_bytes(),
    };
    write_framed_message(&mut remote_io, &connect).await;

    let holepunch_event = events.recv().await.unwrap();
    match holepunch_event {
        PeerEvent::HolepunchConnect {
            peer_addr,
            target: reported_target,
        } => {
            assert_eq!(peer_addr, test_addr());
            assert_eq!(reported_target, target);
        }
        other => panic!("expected HolepunchConnect, got: {other:?}"),
    }

    cmd_tx.send(PeerCommand::Shutdown).await.unwrap();
    let _ = peer_task.await;
}
/// An incoming holepunch error message is surfaced as
/// `PeerEvent::HolepunchError` with the sender address, target and error code.
#[tokio::test]
async fn holepunch_error_event() {
    use irontide_wire::{HolepunchError, HolepunchMessage};

    let (peer_io, mut remote_io) = tokio::io::duplex(8192);
    let (event_tx, mut events) = mpsc::channel(32);
    let (cmd_tx, cmd_rx) = mpsc::channel(32);
    let info_hash = test_info_hash();
    let our_bitfield = Bitfield::new(10);

    let peer_task = tokio::spawn(async move {
        run_peer(
            test_addr(),
            peer_io,
            info_hash,
            test_peer_id(),
            our_bitfield,
            10, // num_pieces
            event_tx,
            cmd_rx,
            false, // enable_dht
            false, // enable_fast
            irontide_wire::mse::EncryptionMode::Disabled,
            false, // outbound
            false, // anonymous_mode
            None,  // info_bytes
            Arc::new(Vec::new()), // no extension plugins
            true,                 // enable_holepunch
            16 * 1024 * 1024,     // max_message_size
            dummy_broadcast_rx(),
            Arc::new(AtomicU32::new(0)),   // in_flight
            Arc::new(AtomicU32::new(128)), // target_depth
            0, // read_timeout_secs
            0, // write_timeout_secs
            0, // data_contribution_timeout_secs
        )
        .await
    });

    do_remote_handshake(&mut remote_io, info_hash, remote_peer_id()).await;
    expect_handshake_complete(&mut events).await;

    // Learn which id we advertised for ut_holepunch from our extension handshake.
    let first_frame = read_framed_message(&mut remote_io).await;
    let our_ut_holepunch_id = match &first_frame {
        Message::Extended { ext_id: 0, payload } => {
            let ours = ExtHandshake::from_bytes(payload).unwrap();
            ours.ext_id("ut_holepunch").unwrap()
        }
        other => panic!("expected ext handshake, got: {other:?}"),
    };

    // Send the remote's extension handshake and wait for the matching event.
    let mut remote_ext = ExtHandshake::default();
    remote_ext.m.insert("ut_metadata".into(), 3);
    let remote_handshake = Message::Extended {
        ext_id: 0,
        payload: remote_ext.to_bytes().unwrap(),
    };
    write_framed_message(&mut remote_io, &remote_handshake).await;
    let handshake_event = events.recv().await.unwrap();
    assert!(matches!(handshake_event, PeerEvent::ExtHandshake { .. }));

    // Deliver a NotConnected error about `target` on our holepunch id.
    let target = "172.16.0.5:51413".parse::<SocketAddr>().unwrap();
    let failure = Message::Extended {
        ext_id: our_ut_holepunch_id,
        payload: HolepunchMessage::error(target, HolepunchError::NotConnected).to_bytes(),
    };
    write_framed_message(&mut remote_io, &failure).await;

    let holepunch_event = events.recv().await.unwrap();
    match holepunch_event {
        PeerEvent::HolepunchError {
            peer_addr,
            target: reported_target,
            error_code,
        } => {
            assert_eq!(peer_addr, test_addr());
            assert_eq!(reported_target, target);
            // NOTE(review): 2 is presumably the numeric code of
            // `HolepunchError::NotConnected` — confirm against irontide_wire.
            assert_eq!(error_code, 2);
        }
        other => panic!("expected HolepunchError, got: {other:?}"),
    }

    cmd_tx.send(PeerCommand::Shutdown).await.unwrap();
    let _ = peer_task.await;
}
/// With `enable_holepunch` off, our extension handshake omits `ut_holepunch`
/// but still lists `ut_metadata` and `ut_pex`.
#[tokio::test]
async fn holepunch_disabled_removes_from_ext_handshake() {
    let (peer_io, mut remote_io) = tokio::io::duplex(8192);
    let (event_tx, _event_rx) = mpsc::channel(32);
    let (cmd_tx, cmd_rx) = mpsc::channel(32);
    let info_hash = test_info_hash();
    let our_bitfield = Bitfield::new(10);

    let peer_task = tokio::spawn(async move {
        run_peer(
            test_addr(),
            peer_io,
            info_hash,
            test_peer_id(),
            our_bitfield,
            10, // num_pieces
            event_tx,
            cmd_rx,
            false, // enable_dht
            false, // enable_fast
            irontide_wire::mse::EncryptionMode::Disabled,
            false, // outbound
            false, // anonymous_mode
            None,  // info_bytes
            Arc::new(Vec::new()), // no extension plugins
            false,                // enable_holepunch: the setting under test
            16 * 1024 * 1024,     // max_message_size
            dummy_broadcast_rx(),
            Arc::new(AtomicU32::new(0)),   // in_flight
            Arc::new(AtomicU32::new(128)), // target_depth
            0, // read_timeout_secs
            0, // write_timeout_secs
            0, // data_contribution_timeout_secs
        )
        .await
    });

    do_remote_handshake(&mut remote_io, info_hash, remote_peer_id()).await;

    // First framed message from the peer task must be the extension handshake.
    let first_frame = read_framed_message(&mut remote_io).await;
    match first_frame {
        Message::Extended { ext_id: 0, payload } => {
            let advertised = ExtHandshake::from_bytes(&payload).unwrap();
            assert!(
                advertised.ext_id("ut_holepunch").is_none(),
                "ut_holepunch should not be advertised when disabled"
            );
            assert!(advertised.ext_id("ut_metadata").is_some());
            assert!(advertised.ext_id("ut_pex").is_some());
        }
        other => panic!("expected ext handshake, got: {other:?}"),
    }

    cmd_tx.send(PeerCommand::Shutdown).await.unwrap();
    let _ = peer_task.await;
}
/// With `enable_holepunch` on, our extension handshake advertises
/// `ut_holepunch` at id 4.
#[tokio::test]
async fn holepunch_enabled_advertised_in_ext_handshake() {
    let (peer_io, mut remote_io) = tokio::io::duplex(8192);
    let (event_tx, _event_rx) = mpsc::channel(32);
    let (cmd_tx, cmd_rx) = mpsc::channel(32);
    let info_hash = test_info_hash();
    let our_bitfield = Bitfield::new(10);

    let peer_task = tokio::spawn(async move {
        run_peer(
            test_addr(),
            peer_io,
            info_hash,
            test_peer_id(),
            our_bitfield,
            10, // num_pieces
            event_tx,
            cmd_rx,
            false, // enable_dht
            false, // enable_fast
            irontide_wire::mse::EncryptionMode::Disabled,
            false, // outbound
            false, // anonymous_mode
            None,  // info_bytes
            Arc::new(Vec::new()), // no extension plugins
            true,                 // enable_holepunch: the setting under test
            16 * 1024 * 1024,     // max_message_size
            dummy_broadcast_rx(),
            Arc::new(AtomicU32::new(0)),   // in_flight
            Arc::new(AtomicU32::new(128)), // target_depth
            0, // read_timeout_secs
            0, // write_timeout_secs
            0, // data_contribution_timeout_secs
        )
        .await
    });

    do_remote_handshake(&mut remote_io, info_hash, remote_peer_id()).await;

    // First framed message from the peer task must be the extension handshake.
    let first_frame = read_framed_message(&mut remote_io).await;
    match first_frame {
        Message::Extended { ext_id: 0, payload } => {
            let advertised = ExtHandshake::from_bytes(&payload).unwrap();
            assert_eq!(
                advertised.ext_id("ut_holepunch"),
                Some(4),
                "ut_holepunch should be advertised at ID 4"
            );
        }
        other => panic!("expected ext handshake, got: {other:?}"),
    }

    cmd_tx.send(PeerCommand::Shutdown).await.unwrap();
    let _ = peer_task.await;
}
/// `PeerCommand::SendHolepunch` results in an extended message on the wire,
/// addressed with the `ut_holepunch` id the remote advertised.
#[tokio::test]
async fn send_holepunch_command_sends_extended_message() {
    use irontide_wire::HolepunchMessage;

    let (peer_io, mut remote_io) = tokio::io::duplex(8192);
    let (event_tx, mut events) = mpsc::channel(32);
    let (cmd_tx, cmd_rx) = mpsc::channel(32);
    let info_hash = test_info_hash();
    let our_bitfield = Bitfield::new(10);

    let peer_task = tokio::spawn(async move {
        run_peer(
            test_addr(),
            peer_io,
            info_hash,
            test_peer_id(),
            our_bitfield,
            10, // num_pieces
            event_tx,
            cmd_rx,
            false, // enable_dht
            false, // enable_fast
            irontide_wire::mse::EncryptionMode::Disabled,
            false, // outbound
            false, // anonymous_mode
            None,  // info_bytes
            Arc::new(Vec::new()), // no extension plugins
            true,                 // enable_holepunch
            16 * 1024 * 1024,     // max_message_size
            dummy_broadcast_rx(),
            Arc::new(AtomicU32::new(0)),   // in_flight
            Arc::new(AtomicU32::new(128)), // target_depth
            0, // read_timeout_secs
            0, // write_timeout_secs
            0, // data_contribution_timeout_secs
        )
        .await
    });

    do_remote_handshake(&mut remote_io, info_hash, remote_peer_id()).await;
    expect_handshake_complete(&mut events).await;

    // Our extension handshake only needs to parse here; its contents are unused.
    let first_frame = read_framed_message(&mut remote_io).await;
    let _parsed_handshake = match &first_frame {
        Message::Extended { ext_id: 0, payload } => ExtHandshake::from_bytes(payload).unwrap(),
        other => panic!("expected ext handshake, got: {other:?}"),
    };

    // The next frame is read and discarded (the binding name indicates an unchoke),
    // so the following read sees the holepunch message.
    let _unchoke = read_framed_message(&mut remote_io).await;

    // Remote advertises ut_holepunch at id 7; outgoing holepunch must use it.
    let mut remote_ext = ExtHandshake::default();
    remote_ext.m.insert("ut_metadata".into(), 3);
    remote_ext.m.insert("ut_holepunch".into(), 7);
    let remote_handshake = Message::Extended {
        ext_id: 0,
        payload: remote_ext.to_bytes().unwrap(),
    };
    write_framed_message(&mut remote_io, &remote_handshake).await;
    let handshake_event = events.recv().await.unwrap();
    assert!(matches!(handshake_event, PeerEvent::ExtHandshake { .. }));

    // Ask the peer task to send a holepunch connect message.
    let target = "10.0.0.1:8080".parse::<SocketAddr>().unwrap();
    let hp_msg = HolepunchMessage::connect(target);
    let command = PeerCommand::SendHolepunch(hp_msg.clone());
    cmd_tx.send(command).await.unwrap();

    let outgoing = read_framed_message(&mut remote_io).await;
    match outgoing {
        Message::Extended { ext_id, payload } => {
            assert_eq!(ext_id, 7, "should use remote's ut_holepunch ID");
            let parsed = HolepunchMessage::from_bytes(&payload).unwrap();
            assert_eq!(parsed, hp_msg);
        }
        other => panic!("expected Extended holepunch message, got: {other:?}"),
    }

    cmd_tx.send(PeerCommand::Shutdown).await.unwrap();
    let _ = peer_task.await;
}
/// A message arriving on our advertised `ut_holepunch` id (asserted to be 4)
/// is routed to a `PeerEvent::HolepunchRendezvous` event.
#[tokio::test]
async fn incoming_holepunch_routed_to_event() {
    use irontide_wire::HolepunchMessage;

    let (peer_io, mut remote_io) = tokio::io::duplex(8192);
    let (event_tx, mut events) = mpsc::channel(32);
    let (cmd_tx, cmd_rx) = mpsc::channel(32);
    let info_hash = test_info_hash();
    let our_bitfield = Bitfield::new(10);

    let peer_task = tokio::spawn(async move {
        run_peer(
            test_addr(),
            peer_io,
            info_hash,
            test_peer_id(),
            our_bitfield,
            10, // num_pieces
            event_tx,
            cmd_rx,
            false, // enable_dht
            false, // enable_fast
            irontide_wire::mse::EncryptionMode::Disabled,
            false, // outbound
            false, // anonymous_mode
            None,  // info_bytes
            Arc::new(Vec::new()), // no extension plugins
            true,                 // enable_holepunch
            16 * 1024 * 1024,     // max_message_size
            dummy_broadcast_rx(),
            Arc::new(AtomicU32::new(0)),   // in_flight
            Arc::new(AtomicU32::new(128)), // target_depth
            0, // read_timeout_secs
            0, // write_timeout_secs
            0, // data_contribution_timeout_secs
        )
        .await
    });

    do_remote_handshake(&mut remote_io, info_hash, remote_peer_id()).await;
    expect_handshake_complete(&mut events).await;

    // Our extension handshake must advertise ut_holepunch, and at id 4.
    let first_frame = read_framed_message(&mut remote_io).await;
    let our_ut_holepunch_id = match &first_frame {
        Message::Extended { ext_id: 0, payload } => ExtHandshake::from_bytes(payload)
            .unwrap()
            .ext_id("ut_holepunch")
            .unwrap(),
        other => panic!("expected ext handshake, got: {other:?}"),
    };
    assert_eq!(our_ut_holepunch_id, 4);

    // Send the remote's extension handshake and wait for the matching event.
    let mut remote_ext = ExtHandshake::default();
    remote_ext.m.insert("ut_metadata".into(), 3);
    let remote_handshake = Message::Extended {
        ext_id: 0,
        payload: remote_ext.to_bytes().unwrap(),
    };
    write_framed_message(&mut remote_io, &remote_handshake).await;
    let handshake_event = events.recv().await.unwrap();
    assert!(matches!(handshake_event, PeerEvent::ExtHandshake { .. }));

    // Deliver a rendezvous request for `target` on our holepunch id.
    let target = "10.0.0.1:8080".parse::<SocketAddr>().unwrap();
    let rendezvous = Message::Extended {
        ext_id: our_ut_holepunch_id,
        payload: HolepunchMessage::rendezvous(target).to_bytes(),
    };
    write_framed_message(&mut remote_io, &rendezvous).await;

    let holepunch_event = events.recv().await.unwrap();
    match holepunch_event {
        PeerEvent::HolepunchRendezvous {
            peer_addr,
            target: reported_target,
        } => {
            assert_eq!(peer_addr, test_addr());
            assert_eq!(reported_target, target);
        }
        other => panic!("expected HolepunchRendezvous, got: {other:?}"),
    }

    cmd_tx.send(PeerCommand::Shutdown).await.unwrap();
    let _ = peer_task.await;
}
/// An extension handshake whose `v`, `p`, `reqq` and `upload_only` fields are
/// cleared keeps them absent after an encode/decode round trip, while the
/// extension map `m` stays populated.
///
/// Note: the fields are cleared by hand here; this test does not call
/// `run_peer`, so it exercises only the handshake type's serialization.
#[test]
fn anonymous_mode_suppresses_ext_handshake_fields() {
    let mut handshake = ExtHandshake::new();

    // Start from a handshake that carries identifying values...
    handshake.p = Some(6881);
    handshake.upload_only = Some(1);
    handshake.reqq = Some(250);

    // ...then strip every field this test treats as identifying.
    handshake.v = None;
    handshake.p = None;
    handshake.reqq = None;
    handshake.upload_only = None;

    // In-memory state: identifying fields gone, extension map intact.
    assert!(handshake.v.is_none());
    assert!(handshake.p.is_none());
    assert!(handshake.reqq.is_none());
    assert!(handshake.upload_only.is_none());
    assert!(!handshake.m.is_empty());

    // The same must hold after going through the wire encoding.
    let wire = handshake.to_bytes().unwrap();
    let round_tripped = ExtHandshake::from_bytes(&wire).unwrap();
    assert!(round_tripped.v.is_none());
    assert!(round_tripped.p.is_none());
    assert!(round_tripped.reqq.is_none());
    assert!(round_tripped.upload_only.is_none());
    assert!(!round_tripped.m.is_empty());
}
}