use std::{fmt::Debug, io::ErrorKind};
use serde::{de::DeserializeOwned, Serialize};
use tokio::{
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader, BufWriter},
sync::mpsc,
};
use crate::{do_spawn, MaybeSend};
pub(crate) async fn framed_transport<T, U, R, W>(
read_stream: R,
write_stream: W,
mut req_rx: mpsc::Receiver<T>,
res_tx: mpsc::Sender<U>,
) where
R: AsyncRead + Unpin + MaybeSend + 'static,
W: AsyncWrite + Unpin + MaybeSend + 'static,
T: Debug + Serialize + Send + 'static,
U: Debug + DeserializeOwned + Send + 'static,
{
do_spawn(async move {
let mut br = BufReader::new(read_stream);
loop {
let size = match br.read_u64().await {
Ok(s) => s as usize,
Err(e) if e.kind() == ErrorKind::UnexpectedEof => break,
Err(e) => panic!("Got unexpected error: {:#?}", e),
};
let mut data = vec![0u8; size];
br.read_exact(&mut data).await.unwrap();
let response: U = bincode::deserialize(&data).unwrap();
res_tx.send(response).await.unwrap();
}
});
do_spawn(async move {
let mut bw = BufWriter::new(write_stream);
loop {
let item = match req_rx.try_recv() {
Ok(item) => item,
Err(mpsc::error::TryRecvError::Empty) => {
bw.flush().await.unwrap();
match req_rx.recv().await {
Some(item) => item,
None => break,
}
}
Err(mpsc::error::TryRecvError::Disconnected) => {
break;
}
};
let data = bincode::serialize(&item).unwrap();
bw.write_u64(data.len() as _).await.unwrap();
bw.write_all(&data).await.unwrap();
}
});
}
/// Builds the channel pair for a framed connection over `read_stream` /
/// `write_stream`.
///
/// Creates one bounded channel for outgoing `T` requests and one for incoming
/// `U` responses, each holding up to 32 messages. It hands the streams and the
/// internal channel ends to `framed_transport`, then returns the request sender
/// and the response receiver to the caller.
pub(crate) async fn frame<T, U, R, W>(
    read_stream: R,
    write_stream: W,
) -> (mpsc::Sender<T>, mpsc::Receiver<U>)
where
    R: AsyncRead + Unpin + MaybeSend + 'static,
    W: AsyncWrite + Unpin + MaybeSend + 'static,
    T: Debug + Serialize + Send + 'static,
    U: Debug + DeserializeOwned + Send + 'static,
{
    const CHANNEL_CAPACITY: usize = 32;

    let (request_sender, request_receiver) = mpsc::channel::<T>(CHANNEL_CAPACITY);
    let (response_sender, response_receiver) = mpsc::channel::<U>(CHANNEL_CAPACITY);

    framed_transport(read_stream, write_stream, request_receiver, response_sender).await;

    (request_sender, response_receiver)
}