mod stream_producer;
use stream_producer::*;
use stream_multiplexer::*;
use bytes::{Bytes, BytesMut};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
use tokio::sync::mpsc as channel;
use std::io::Result as IoResult;
use std::sync::{Arc, Mutex};
#[derive(Clone)]
struct SharedIdGen<T: IdGen = IncrementIdGen> {
inner: Arc<Mutex<T>>,
id_rx: Arc<Mutex<channel::UnboundedReceiver<usize>>>,
id_tx: channel::UnboundedSender<usize>,
}
impl<T: IdGen> IdGen for SharedIdGen<T> {
fn next(&mut self) -> usize {
let id = self.inner.lock().unwrap().next();
self.id_tx
.send(id)
.expect("should be able to send updated id");
id
}
fn id(&self) -> usize {
self.inner.lock().unwrap().id()
}
fn seed(&mut self, seed: usize) {
self.inner.lock().unwrap().seed(seed)
}
}
impl<T: IdGen> Default for SharedIdGen<T> {
fn default() -> Self {
let (id_tx, id_rx) = channel::unbounded_channel();
Self {
inner: Default::default(),
id_tx,
id_rx: Arc::new(Mutex::new(id_rx)),
}
}
}
impl<T: IdGen> SharedIdGen<T> {
pub async fn wait_for_next_id(&self) -> usize {
self.id_rx.lock().unwrap().recv().await.unwrap()
}
}
/// Installs a global `tracing` subscriber that prints events up to `TRACE` level.
///
/// Intended for ad-hoc debugging: call it at the top of a test body. It may be
/// called from several tests in the same binary. Only the first call installs
/// the subscriber, and later calls are no-ops instead of panicking.
///
/// # Panics
///
/// Panics if a global default subscriber was already installed by something
/// other than this function.
// No test calls this by default; it is kept around for debugging sessions.
#[allow(dead_code)]
pub(crate) fn init_logging() {
    use tracing_subscriber::FmtSubscriber;

    static INIT: std::sync::Once = std::sync::Once::new();
    INIT.call_once(|| {
        let subscriber = FmtSubscriber::builder()
            .with_max_level(tracing::Level::TRACE)
            .finish();
        tracing::subscriber::set_global_default(subscriber)
            .expect("setting default subscriber failed");
    });
}
/// Binds a `TcpListener` to port 0 on the IPv4 loopback address.
///
/// Port 0 asks the OS for any free port. Callers read the chosen address back
/// with `local_addr()`.
async fn bind() -> IoResult<TcpListener> {
    const LOOPBACK_ANY_PORT: &str = "127.0.0.1:0";
    tracing::info!("Starting");
    tracing::info!("Binding {:?}", LOOPBACK_ANY_PORT);
    TcpListener::bind(LOOPBACK_ANY_PORT).await
}
/// With no clients connected, sending `ControlMessage::Shutdown` lets the
/// spawned `run` task finish without panicking.
#[tokio::test(basic_scheduler)]
async fn shutdown() {
    let listener = bind().await.unwrap();
    let producer = TcpStreamProducer::new(listener);

    let (control_tx, control_rx) = channel::unbounded_channel();
    // Underscore-prefixed (not `_`) so the unused channel ends stay alive.
    let (_outgoing_tx, outgoing_rx) = channel::unbounded_channel();
    let (incoming_tx, _incoming_rx) = channel::channel(10);

    let mux = Multiplexer::new(8, outgoing_rx, vec![incoming_tx]);
    let run_handle = tokio::task::spawn(mux.run(producer, control_rx));

    control_tx.send(ControlMessage::Shutdown).unwrap();
    let joined = run_handle.await;
    assert!(joined.is_ok());
}
/// A client that connects and immediately shuts its socket down does not stop
/// the spawned `run` task from finishing cleanly on `ControlMessage::Shutdown`.
#[tokio::test(basic_scheduler)]
async fn socket_shutdown() {
    let listener = bind().await.unwrap();
    let addr = listener.local_addr().unwrap();
    let producer = TcpStreamProducer::new(listener);

    // The client hangs up before the multiplexer task has been spawned.
    let client = tokio::net::TcpStream::connect(addr).await.unwrap();
    client.shutdown(std::net::Shutdown::Both).unwrap();

    let (control_tx, control_rx) = channel::unbounded_channel();
    let (_outgoing_tx, outgoing_rx) = channel::unbounded_channel();
    let (incoming_tx, _incoming_rx) = channel::channel(10);

    let mux = Multiplexer::new(8, outgoing_rx, vec![incoming_tx]);
    let run_handle = tokio::task::spawn(mux.run(producer, control_rx));

    control_tx.send(ControlMessage::Shutdown).unwrap();
    let joined = run_handle.await;
    assert!(joined.is_ok());
}
/// A message sent on the outgoing channel reaches only the addressed client.
/// It arrives prefixed by its 2-byte length (`\0\t` == 9 for "a message").
#[tokio::test(basic_scheduler)]
async fn write_packets() {
    let socket = bind().await.unwrap();
    let local_addr = socket.local_addr().unwrap();
    let socket = TcpStreamProducer::new(socket);
    let (control_write, control_read) = channel::unbounded_channel();
    let (data_write, data_read) = channel::unbounded_channel();
    let id_gen: SharedIdGen<IncrementIdGen> = SharedIdGen::default();
    let (in_data_tx, _in_data_rx) = channel::channel(10);
    let multiplexer = Multiplexer::with_id_gen(8, id_gen.clone(), data_read, vec![in_data_tx]);
    let shutdown_status = tokio::task::spawn(multiplexer.run(socket, control_read));

    // Wait until each connection has been assigned an id before using it.
    let mut client1 = tokio::net::TcpStream::connect(local_addr).await.unwrap();
    let client1_id = id_gen.wait_for_next_id().await;
    let client2 = tokio::net::TcpStream::connect(local_addr).await.unwrap();
    let _client2_id = id_gen.wait_for_next_id().await;

    let data = Bytes::from("a message");
    data_write
        .send(OutgoingMessage::new(vec![client1_id], vec![data]).into())
        .unwrap();

    // TCP may split the frame across several reads, so keep reading until the
    // full frame (length prefix + payload) is buffered. The buffer is
    // pre-sized because `read_buf` only fills spare capacity.
    let expected = "\0\ta message".as_bytes();
    let mut read_data = BytesMut::with_capacity(expected.len());
    while read_data.len() < expected.len() {
        let read = client1.read_buf(&mut read_data).await.unwrap();
        assert_ne!(read, 0, "connection closed before the whole frame arrived");
    }
    assert_eq!(read_data, expected);

    client1.shutdown(std::net::Shutdown::Both).unwrap();
    client2.shutdown(std::net::Shutdown::Both).unwrap();
    control_write.send(ControlMessage::Shutdown).unwrap();
    assert!(shutdown_status.await.is_ok());
}
/// Length-prefixed frames written by a client show up on the incoming channel,
/// between a `StreamConnected` and a graceful `StreamDisconnected` packet.
#[tokio::test(basic_scheduler)]
async fn read_packets() {
    let socket = bind().await.unwrap();
    let local_addr = socket.local_addr().unwrap();
    let socket = TcpStreamProducer::new(socket);
    let (control_write, control_read) = channel::unbounded_channel();
    let (_data_write, data_read) = channel::unbounded_channel();
    let id_gen: SharedIdGen<IncrementIdGen> = SharedIdGen::default();
    let (in_data_tx, mut in_data_rx) = channel::channel(10);
    let multiplexer = Multiplexer::with_id_gen(8, id_gen.clone(), data_read, vec![in_data_tx]);
    let shutdown_status = tokio::task::spawn(multiplexer.run(socket, control_read));

    let mut client1 = tokio::net::TcpStream::connect(local_addr).await.unwrap();
    let client1_id = id_gen.wait_for_next_id().await;
    assert_eq!(1, client1_id);
    let message = in_data_rx.recv().await.expect("Should have connected.");
    matches::assert_matches!(message, IncomingPacket::StreamConnected(_));

    for _ in 0..2 {
        // `write_all` retries until the whole frame is written; a single
        // `write_buf` call is allowed to perform a short write.
        client1.write_all(b"\0\ta message").await.unwrap();
        let incoming_packet = in_data_rx.recv().await.unwrap();
        assert_eq!(incoming_packet.id(), client1_id);
        assert_eq!(
            incoming_packet
                .value()
                .expect("should have a value")
                .as_ref()
                .unwrap(),
            &Bytes::from("a message")
        );
    }

    client1.shutdown(std::net::Shutdown::Both).unwrap();
    let message = in_data_rx.recv().await.expect("Should have disconnected.");
    matches::assert_matches!(
        message,
        IncomingPacket::StreamDisconnected(_, DisconnectReason::Graceful)
    );
    control_write.send(ControlMessage::Shutdown).unwrap();
    assert!(shutdown_status.await.is_ok());
}
/// `OutgoingPacket::ChangeChannel` moves a stream from incoming channel 0 to
/// channel 1. Channel 0 sees a `ChannelChange(1)` disconnect, channel 1 sees a
/// connect, and traffic keeps flowing in both directions afterwards.
#[tokio::test(basic_scheduler)]
async fn change_channel() {
    let socket = bind().await.unwrap();
    let local_addr = socket.local_addr().unwrap();
    let socket = TcpStreamProducer::new(socket);
    let (control_write, control_read) = channel::unbounded_channel();
    let (data_write, data_read) = channel::unbounded_channel();
    let mut id_gen: SharedIdGen<IncrementIdGen> = SharedIdGen::default();
    id_gen.seed(100);
    let (in_data_tx0, mut in_data_rx0) = channel::channel(10);
    let (in_data_tx1, mut in_data_rx1) = channel::channel(10);
    let multiplexer =
        Multiplexer::with_id_gen(8, id_gen.clone(), data_read, vec![in_data_tx0, in_data_tx1]);
    let shutdown_status = tokio::task::spawn(multiplexer.run(socket, control_read));

    let mut client1 = tokio::net::TcpStream::connect(local_addr).await.unwrap();
    let client1_id = id_gen.wait_for_next_id().await;
    assert_eq!(client1_id, 101);
    let message = in_data_rx0.recv().await.expect("Should have connected.");
    matches::assert_matches!(message, IncomingPacket::StreamConnected(_));

    // `write_all` guards against short writes that `write_buf` permits.
    client1.write_all(b"\0\ta message").await.unwrap();
    let incoming_packet = in_data_rx0.recv().await.unwrap();
    assert_eq!(incoming_packet.id(), client1_id);
    assert_eq!(
        incoming_packet
            .value()
            .expect("should have a value")
            .as_ref()
            .unwrap(),
        &Bytes::from("a message")
    );

    let change_channel = OutgoingPacket::ChangeChannel(vec![client1_id], 1);
    data_write.send(change_channel).unwrap();
    let message = in_data_rx0.recv().await.expect("Should have disconnected.");
    matches::assert_matches!(
        message,
        IncomingPacket::StreamDisconnected(_, DisconnectReason::ChannelChange(1))
    );
    let message = in_data_rx1.recv().await.expect("Should have connected.");
    matches::assert_matches!(message, IncomingPacket::StreamConnected(_));

    let data = Bytes::from("a message from the server");
    data_write
        .send(OutgoingMessage::new(vec![client1_id], vec![data]).into())
        .unwrap();

    // Read until the whole frame (`\0\x19` == 25-byte length prefix + payload)
    // is buffered, since TCP may deliver it in pieces.
    let expected = "\0\x19a message from the server".as_bytes();
    let mut read_data = BytesMut::with_capacity(expected.len());
    while read_data.len() < expected.len() {
        let read = client1.read_buf(&mut read_data).await.unwrap();
        assert_ne!(read, 0, "connection closed before the whole frame arrived");
    }
    assert_eq!(read_data, expected);

    client1.write_all(b"\0\x10a second message").await.unwrap();
    let incoming_packet = in_data_rx1.recv().await.unwrap();
    assert_eq!(incoming_packet.id(), client1_id);
    assert_eq!(
        incoming_packet
            .value()
            .expect("Should have a value")
            .as_ref()
            .unwrap(),
        &Bytes::from("a second message")
    );

    client1.shutdown(std::net::Shutdown::Both).unwrap();
    control_write.send(ControlMessage::Shutdown).unwrap();
    assert!(shutdown_status.await.is_ok());
}
/// When a connected client shuts its socket down, the multiplexer reports a
/// `StreamDisconnected` with `DisconnectReason::Graceful` for that client's id.
#[tokio::test(basic_scheduler)]
async fn linkdead() {
    let listener = bind().await.unwrap();
    let addr = listener.local_addr().unwrap();
    let producer = TcpStreamProducer::new(listener);

    let (control_tx, control_rx) = channel::unbounded_channel();
    let (_outgoing_tx, outgoing_rx) = channel::unbounded_channel();
    let (incoming_tx, mut incoming_rx) = channel::channel(10);
    let ids: SharedIdGen<IncrementIdGen> = SharedIdGen::default();

    let mux = Multiplexer::with_id_gen(8, ids.clone(), outgoing_rx, vec![incoming_tx]);
    let run_handle = tokio::task::spawn(mux.run(producer, control_rx));

    let client = tokio::net::TcpStream::connect(addr).await.unwrap();
    let client_id = ids.wait_for_next_id().await;
    assert_eq!(1, client_id);

    let connected = incoming_rx.recv().await.expect("Should have connected.");
    matches::assert_matches!(connected, IncomingPacket::StreamConnected(_));

    client.shutdown(std::net::Shutdown::Both).unwrap();

    let disconnected = incoming_rx.recv().await.expect("should have gone linkdead");
    assert_eq!(client_id, disconnected.id());
    matches::assert_matches!(
        disconnected,
        IncomingPacket::StreamDisconnected(_, DisconnectReason::Graceful)
    );

    control_tx.send(ControlMessage::Shutdown).unwrap();
    let joined = run_handle.await;
    assert!(joined.is_ok());
}