use std::{future::ready, time::Duration};
use futures::{StreamExt, future::pending};
use mock::MockEavesdropper;
use rand::{SeedableRng, rngs::StdRng};
use tokio::{join, select, time::sleep};
use zerocopy::IntoBytes;
use crate::noise::index_table::IndexTable;
pub mod mock;
/// The eavesdropper must observe exactly one IPv4 packet per payload packet
/// sent through the tunnel, plus a fixed amount of protocol overhead.
#[tokio::test]
#[test_log::test]
async fn number_of_packets() {
    test_device_pair(async |eve| {
        // Extra packets on top of the payload packets. Presumably the two
        // handshake messages plus the single keepalive that the sibling tests
        // assert on -- TODO confirm.
        let overhead = 2 + 1;
        let expected_total = packet_count() + overhead;

        let observed_total = eve.ipv4().count().await;
        assert_eq!(observed_total, expected_total);
    })
    .await
}
/// The eavesdropper must not observe a single IPv6 packet during the
/// exchange driven by `test_device_pair`.
#[tokio::test]
#[test_log::test]
async fn ipv6_isnt_used() {
    test_device_pair(async |eve| {
        let ipv6_packets = eve.ipv6();
        // `count` only resolves once the stream has ended.
        let ipv6_count = ipv6_packets.count().await;
        assert_eq!(dbg!(ipv6_count), 0);
    })
    .await
}
/// Exactly one handshake initiation and exactly one handshake response must
/// be visible on the wire.
#[tokio::test]
#[test_log::test]
async fn one_handshake() {
    test_device_pair(async |eve| {
        // Each counter asserts on its own so a mismatch is reported as soon
        // as the corresponding stream ends.
        let count_inits = async {
            let seen = eve.wg_handshake_init().count().await;
            assert_eq!(seen, 1);
        };
        let count_resps = async {
            let seen = eve.wg_handshake_resp().count().await;
            assert_eq!(seen, 1);
        };

        join!(count_inits, count_resps);
    })
    .await
}
/// Among all WireGuard data packets seen by the eavesdropper, exactly one
/// must be a keepalive.
#[tokio::test]
#[test_log::test]
async fn one_keepalive() {
    test_device_pair(async |eve| {
        // `StreamExt::filter` expects a future-returning predicate, hence the
        // `ready` wrapper around the synchronous check.
        let keepalives = eve
            .wg_data()
            .filter(|packet| ready(packet.is_keepalive()));

        let seen = keepalives.count().await;
        assert_eq!(seen, 1);
    })
    .await
}
/// Every WireGuard data packet must carry an encrypted payload whose length
/// is a multiple of 16, and at least `packet_count()` such packets must be
/// observed.
#[tokio::test]
#[test_log::test]
async fn wg_data_length_is_x16() {
    test_device_pair(async |eve| {
        // Validate each data packet as it streams past; the mapped items are
        // `()` and only serve to be counted afterwards.
        let validated = eve.wg_data().map(|wg| {
            let payload_len = wg.encrypted_encapsulated_packet().len();
            assert!(
                payload_len.is_multiple_of(16),
                "wireguard data length must be a multiple of 16, but was {payload_len}"
            );
        });

        let wg_data_count = validated.count().await;
        // `assert!` does not print its operands, so `dbg!` surfaces the count.
        assert!(dbg!(wg_data_count) >= packet_count());
    })
    .await
}
/// The session indices visible on the wire must match the first index that
/// `IndexTable::next_id` yields for the seeds declared in `mock`.
#[tokio::test]
#[test_log::test]
async fn test_indices() {
    // Recompute the expected indices from fresh RNGs seeded like the mocks.
    // Presumably the mock devices seed their index tables with these same
    // constants -- TODO confirm against `mock`.
    let mut alice_rng = StdRng::seed_from_u64(mock::ALICE_INDEX_SEED);
    let alice_idx = IndexTable::next_id(&mut alice_rng);
    let mut bob_rng = StdRng::seed_from_u64(mock::BOB_INDEX_SEED);
    let bob_idx = IndexTable::next_id(&mut bob_rng);

    test_device_pair(async |eve| {
        // Handshake initiations must carry Alice's index as sender.
        let verify_init = eve.wg_handshake_init().for_each(async |init| {
            assert_eq!(init.sender_idx.get(), alice_idx);
        });
        // Data packets must be addressed to Bob's index.
        let verify_data = eve.wg_data().for_each(async |data| {
            assert_eq!(data.header.receiver_idx, bob_idx);
        });
        // Handshake responses must carry Bob's index as sender.
        let verify_resp = eve.wg_handshake_resp().for_each(async |resp| {
            assert_eq!(resp.sender_idx.get(), bob_idx);
        });

        join!(verify_init, verify_resp, verify_data);
    })
    .await;
}
/// Checks that Bob follows Alice's endpoint as her source IPv4 address
/// changes: after each packet from Alice, Bob's single peer must report the
/// new address as its endpoint, and the next packet Bob sends must be
/// addressed to that same address.
#[tokio::test]
#[test_log::test]
async fn test_endpoint_roaming() {
let (mut alice, mut bob, eve) = mock::device_pair().await;
let packet = mock::packet(b"Hello!");
// One round trip (Alice -> Bob -> Alice) with Alice using `alice_ip`.
let mut ping_pong = async |alice_ip| {
// NOTE(review): presumably makes Alice's outgoing packets carry `alice_ip`
// as their source address -- confirm against the mock implementation.
*alice.source_ipv4_override.lock().await = Some(alice_ip);
alice.app_tx.send(packet.clone()).await;
assert_eq!(bob.app_rx.recv().await.as_bytes(), packet.as_bytes());
// Bob must know exactly one peer, and its endpoint IP must now be `alice_ip`.
let peers = bob.device.peers().await;
assert_eq!(peers.len(), 1);
let stats = &peers[0];
assert_eq!(
stats.peer.endpoint.map(|addr| addr.ip()),
Some(alice_ip.into()),
);
// The sniffing stream is created before Bob sends anything back.
let ip_stream = eve.ip();
tokio::pin!(ip_stream);
// Wait at most five seconds for the next sniffed IP packet; a timeout
// panics with "did not see sent packet".
let next_packet = async {
tokio::time::timeout(Duration::from_secs(5), ip_stream.next())
.await
.expect("did not see sent packet")
};
// Send from Bob and wait for the sniffed packet concurrently.
let (_, sniffed_packet) = join! {
bob.app_tx.send(packet.clone()),
next_packet,
};
// Receive on Alice's side; the received value itself is not inspected.
alice.app_rx.recv().await;
// The sniffed packet must be destined for Alice's current address.
assert_eq!(
sniffed_packet.and_then(|ip| ip.destination()),
Some(alice_ip.into())
);
};
// Start at one address, switch to a second one, then switch back to the first.
ping_pong("1.2.3.4".parse().unwrap()).await;
ping_pong("1.3.3.7".parse().unwrap()).await;
ping_pong("1.2.3.4".parse().unwrap()).await;
}
/// Returns how many packets `mock::packets_of_every_size` produces.
fn packet_count() -> usize {
    let packets = mock::packets_of_every_size();
    packets.len()
}
/// Shared harness: creates a mock device pair, sends every packet from
/// `mock::packets_of_every_size` from Alice to Bob, and concurrently runs
/// `eavesdrop` on the `MockEavesdropper` belonging to that pair.
///
/// `eavesdrop` is the async callback under test; it receives the
/// eavesdropper by value.
///
/// # Panics
///
/// - "eavesdrop timeout" if `eavesdrop` does not finish within one second.
/// - "timeout" if Bob has not received every packet within one second.
/// - "no data is sent from bob to alice" if Alice's `app_rx` yields anything.
/// - If a packet received by Bob differs from the expected one.
async fn test_device_pair(eavesdrop: impl AsyncFnOnce(MockEavesdropper) + Send) {
let (mut alice, mut bob, eve) = mock::device_pair().await;
// Race the caller's eavesdropping logic against a one-second deadline.
let eavesdrop = async {
select! {
_ = eavesdrop(eve) => {}
_ = sleep(Duration::from_secs(1)) => panic!("eavesdrop timeout"),
}
};
// Takes ownership of both devices so they can be dropped once traffic is done.
let drive_connection = async move {
let packets_to_send = mock::packets_of_every_size();
let packets_to_recv = packets_to_send.clone();
let send_packets = async {
for packet in packets_to_send {
alice.app_tx.send(packet).await;
}
// Never complete: this future must not win the `select!` below.
pending().await
};
// Bob must receive the same packets, in the same order, byte for byte.
let wait_for_packets = async {
for expected_packet in packets_to_recv {
let p = bob.app_rx.recv().await;
assert_eq!(p.as_bytes(), expected_packet.as_bytes());
}
};
// The only successful outcome is `wait_for_packets` finishing first.
select! {
_ = wait_for_packets => {},
_ = send_packets => unreachable!(),
_ = alice.app_rx.recv() => panic!("no data is sent from bob to alice"),
_ = sleep(Duration::from_secs(1)) => panic!("timeout"),
}
// NOTE(review): presumably dropping the devices is what lets the
// eavesdropper's streams end, so that callers using `count()` or
// `for_each()` can complete -- confirm against the mock.
drop((alice, bob));
};
// Drive the traffic and the eavesdropper concurrently until both are done.
join! {
drive_connection,
eavesdrop
};
}