use super::*;
use std::{
net::{SocketAddr, TcpListener, TcpStream},
thread::sleep,
time::Duration,
};
/// Minimal message type used by the tests: an owned byte buffer.
#[derive(Clone, Debug, PartialEq)]
struct Msg {
    data: Vec<u8>,
}

impl TryFrom<&[u8]> for Msg {
    type Error = ();

    /// Copies the given bytes into a new `Msg`; this conversion never fails.
    fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
        let data = bytes.to_vec();
        Ok(Msg { data })
    }
}

impl TryFrom<Msg> for Vec<u8> {
    type Error = ();

    /// Moves the buffer out of the `Msg`; this conversion never fails.
    fn try_from(msg: Msg) -> Result<Self, Self::Error> {
        let Msg { data } = msg;
        Ok(data)
    }
}

/// Payload size (in bytes) for the big-payload tests.
const TEST_SIZE: usize = 10 * 1_000_000;
/// Sends a `TEST_SIZE`-byte payload from the accepted (host) end to the connecting
/// (client) end and checks that it arrives intact and that the sent/read counters agree.
#[test]
fn test_big_payload_from_server() {
    let (mut h, mut c) = build_stream_pair(2000);
    let m: Vec<u8> = Msg {
        data: vec![1; TEST_SIZE],
    }
    .try_into()
    .unwrap();

    // Snapshot the counters so only the bytes of this transfer are measured.
    let sent_before = h.total_sent();
    let read_before = c.total_read();

    h.write(m.clone()).unwrap();
    // `wait_msg` returns `None` once its polling budget is exhausted.
    let msg_rec = wait_msg(&mut c)
        .expect("client did not receive the payload before wait_msg timed out");

    let sent = h.total_sent() - sent_before;
    let read = c.total_read() - read_before;
    assert_eq!(m, msg_rec);
    assert_eq!(sent, read);
    assert_eq!(sent, TEST_SIZE);
}
/// Sends a `TEST_SIZE`-byte payload from the connecting (client) end to the accepted
/// (host) end and checks that it arrives intact and that the sent/read counters agree.
#[test]
fn test_big_payload_from_client() {
    let (mut h, mut c) = build_stream_pair(2001);
    let m: Vec<u8> = Msg {
        data: vec![1; TEST_SIZE],
    }
    .try_into()
    .unwrap();

    // Snapshot the counters so only the bytes of this transfer are measured.
    let sent_before = c.total_sent();
    let read_before = h.total_read();

    c.write(m.clone()).unwrap();
    // `wait_msg` returns `None` once its polling budget is exhausted.
    let msg_rec = wait_msg(&mut h)
        .expect("host did not receive the payload before wait_msg timed out");

    let sent = c.total_sent() - sent_before;
    let read = h.total_read() - read_before;
    assert_eq!(m, msg_rec);
    assert_eq!(sent, read);
    assert_eq!(sent, TEST_SIZE);
}
/// Two independent host-side streams each send the same 100-byte payload ten times;
/// each matching client must receive it unchanged every round.
#[test]
fn test_multichannels() {
    let (mut h1, mut h2, mut c1, mut c2) = build_stream_triple(2002);
    let payload: Vec<u8> = Msg { data: vec![1; 100] }.try_into().unwrap();

    for _round in 0..10 {
        h1.write(payload.clone()).unwrap();
        h2.write(payload.clone()).unwrap();

        let got1 = wait_msg(&mut c1).unwrap();
        let got2 = wait_msg(&mut c2).unwrap();

        assert_eq!(payload, *got1);
        assert_eq!(payload, *got2);
    }
}
/// Same round-trip as the multichannel test, but every end is configured with
/// compression and encryption packs.
#[test]
#[cfg(all(feature = "compression", feature = "encryption"))]
fn test_compresssion_and_encryption() {
    let (mut h1, mut h2, mut c1, mut c2) = build_stream_triple_comp_enc(2003);
    let payload: Vec<u8> = Msg { data: vec![1; 100] }.try_into().unwrap();

    for _round in 0..10 {
        h1.write(payload.clone()).unwrap();
        h2.write(payload.clone()).unwrap();

        let got1 = wait_msg(&mut c1).unwrap();
        let got2 = wait_msg(&mut c2).unwrap();

        assert_eq!(payload, *got1);
        assert_eq!(payload, *got2);
    }
}
/// Polls `c` for a complete message.
///
/// Sleeps 100 ms before each of at most 101 reads (about 10 s in total). Returns the first
/// message read, or `None` if every read came back empty. Panics if `read` fails, because
/// its result is unwrapped.
fn wait_msg(c: &mut NonBlockStream) -> Option<Vec<u8>> {
    for _ in 0..=100 {
        sleep(Duration::from_millis(100));
        if let Some(msg) = c.read().unwrap() {
            return Some(msg);
        }
    }
    None
}
/// Creates one connected loopback TCP pair and wraps both ends as `NonBlockStream`s.
///
/// Returns `(host, client)`. `host` is the stream returned by `accept`; `client` is the
/// stream that initiated the connection. Passing `p == 0` binds an OS-assigned free port.
fn build_stream_pair(p: u16) -> (NonBlockStream, NonBlockStream) {
    let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], p))).unwrap();
    // Connect to the address the listener actually bound. With `p == 0` the requested
    // port (0) is not the one the OS picked.
    let addr = listener.local_addr().unwrap();
    let client = TcpStream::connect(addr).unwrap();
    let (host, _) = listener.accept().unwrap();
    (host.into(), client.into())
}
/// Opens two loopback connections via `create_connections(p)` and wraps every end as a
/// `NonBlockStream`.
///
/// Returns `(host_to_c1, host_to_c2, c1, c2)`.
fn build_stream_triple(
    p: u16,
) -> (
    NonBlockStream,
    NonBlockStream,
    NonBlockStream,
    NonBlockStream,
) {
    let (client1, host1, client2, host2) = create_connections(p);
    let host_side = (host1.into(), host2.into());
    let client_side = (client1.into(), client2.into());
    (host_side.0, host_side.1, client_side.0, client_side.1)
}
/// Opens two loopback TCP connections to a listener bound on `127.0.0.1:p`.
///
/// Returns `(c1, h_to_c1, c2, h_to_c2)`. Each `c*` is a connecting stream; each `h_to_c*`
/// is the stream `accept` returned for it. Passing `p == 0` binds an OS-assigned free port.
fn create_connections(p: u16) -> (TcpStream, TcpStream, TcpStream, TcpStream) {
    let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], p))).unwrap();
    // Connect to the listener's real address. With `p == 0` the requested port is not
    // the one the OS bound.
    let addr = listener.local_addr().unwrap();
    let c1 = TcpStream::connect(addr).unwrap();
    let (h_to_c1, _) = listener.accept().unwrap();
    let c2 = TcpStream::connect(addr).unwrap();
    let (h_to_c2, _) = listener.accept().unwrap();
    (c1, h_to_c1, c2, h_to_c2)
}
/// Like `build_stream_triple`, but every end is built with compression and encryption
/// packs.
///
/// Returns `(host_to_c1, host_to_c2, c1, c2)`.
#[cfg(all(feature = "compression", feature = "encryption"))]
fn build_stream_triple_comp_enc(
    p: u16,
) -> (
    NonBlockStream,
    NonBlockStream,
    NonBlockStream,
    NonBlockStream,
) {
    let (c1, h_to_c1, c2, h_to_c2) = create_connections(p);
    // All four ends get freshly built packs using the same all-zero key.
    let key = [0u8; 32];
    let wrap = |stream: TcpStream| {
        NonBlockStream::from_version_packs(
            Default::default(),
            Packs::default().compress().encrypt(&key),
            stream,
        )
    };
    // Tuple fields evaluate left to right, so construction order matches
    // host_to_c1, host_to_c2, c1, c2.
    (wrap(h_to_c1), wrap(h_to_c2), wrap(c1), wrap(c2))
}