use crate::message::{self, reply::Reply, Message};
use std::{io, time::Duration};
use tokio::{
io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufStream},
sync::{mpsc, oneshot},
time,
};
use tracing::trace;
/// A unit of work sent from a [`Connection`] handle to the task running
/// [`StreamHandler::handle_requests`].
///
/// Every variant carries a oneshot sender through which the handler hands back
/// the outcome of the operation it performed.
enum Request {
    /// Read the next message from the stream.
    ReadMsg {
        /// Receives the message that was read, or the error that occurred.
        response: oneshot::Sender<io::Result<Message>>,
    },
    /// Write `msg` to the stream.
    WriteMsg {
        /// The message to encode and write.
        msg: Message,
        /// Receives the outcome of the write.
        response: oneshot::Sender<io::Result<()>>,
    },
}
struct StreamHandler<S>
where
S: AsyncRead + AsyncWrite + Unpin + Send,
{
stream: BufStream<S>,
}
impl<S> StreamHandler<S>
where
    S: AsyncRead + AsyncWrite + Unpin + Send,
{
    /// Takes ownership of `stream`, adding read and write buffering.
    fn new(stream: S) -> Self {
        let buffered = BufStream::new(stream);
        Self { stream: buffered }
    }

    /// Serves requests from `conn` one at a time until every sender has been
    /// dropped, then shuts the stream down.
    ///
    /// Each individual read or write is limited to `timeout`. When the limit is
    /// hit, the requester receives the `io::Error` converted from tokio's
    /// `Elapsed` (kind `TimedOut`). Replies are best-effort: a requester that
    /// stopped waiting is silently ignored.
    async fn handle_requests(mut self, mut conn: mpsc::Receiver<Request>, timeout: Duration) {
        while let Some(request) = conn.recv().await {
            // NOTE(review): on timeout the in-flight `message::read`/`message::write`
            // future is dropped. If those are not cancel-safe, the stream may be left
            // mid-frame for the next request — confirm.
            match request {
                Request::WriteMsg { msg, response } => {
                    let written = time::timeout(timeout, message::write(&mut self.stream, msg))
                        .await
                        .unwrap_or_else(|elapsed| Err(elapsed.into()));
                    let _ = response.send(written);
                }
                Request::ReadMsg { response } => {
                    let read = time::timeout(timeout, message::read(&mut self.stream))
                        .await
                        .unwrap_or_else(|elapsed| Err(elapsed.into()));
                    let _ = response.send(read);
                }
            }
        }

        // Every handle is gone. Shutdown is best-effort because nobody is left to
        // report a failure to.
        let _ = self.stream.shutdown().await;
    }
}
/// Cloneable handle to a stream that is driven by a background task.
///
/// All clones share one request channel and one handler. Reads and writes from
/// every clone are therefore performed one at a time, in the order the handler
/// receives them. The underlying stream is shut down once the last clone has
/// been dropped.
#[derive(Clone)]
pub struct Connection {
    conn: mpsc::Sender<Request>,
}

impl Connection {
    /// Spawns a task that owns `stream` and returns a handle for talking to it.
    ///
    /// `timeout` bounds each individual message read or write performed by the
    /// task.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime, because the handler is started
    /// with [`tokio::spawn`].
    pub fn new<S>(stream: S, timeout: Duration) -> Self
    where
        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
    {
        // Bounded to one queued request; other senders wait in `send` until the
        // handler picks it up.
        let (messages_tx, messages_rx) = mpsc::channel(1);
        let handler = StreamHandler::new(stream);
        tokio::spawn(handler.handle_requests(messages_rx, timeout));
        Self { conn: messages_tx }
    }

    /// Reads the next message from the stream.
    ///
    /// NOTE(review): if this future is dropped after the request was queued, the
    /// handler still performs the read and the message is discarded.
    ///
    /// # Errors
    ///
    /// Returns the error produced by the underlying read. Returns an error of kind
    /// `TimedOut` if the read exceeds the connection's timeout. Returns
    /// `BrokenPipe` if the handler task is no longer running.
    pub async fn read_message(&self) -> io::Result<Message> {
        let (response_tx, response) = oneshot::channel();
        let request = Request::ReadMsg {
            response: response_tx,
        };
        let result = self.do_request(request, response).await;
        if let Ok(msg) = &result {
            trace!(?msg, "message read");
        }
        result
    }

    /// Converts `reply` into a [`Message`] and writes it to the stream.
    ///
    /// # Errors
    ///
    /// Same as [`Connection::write_message`].
    pub async fn write_reply(&self, reply: Reply) -> io::Result<()> {
        let msg = reply.into_message();
        self.write_message(msg).await
    }

    /// Writes `msg` to the stream.
    ///
    /// # Errors
    ///
    /// Returns the error produced by the underlying write. Returns an error of kind
    /// `TimedOut` if the write exceeds the connection's timeout. Returns
    /// `BrokenPipe` if the handler task is no longer running.
    pub async fn write_message(&self, msg: Message) -> io::Result<()> {
        let (response_tx, response) = oneshot::channel();
        trace!(?msg, "writing message");
        let request = Request::WriteMsg {
            msg,
            response: response_tx,
        };
        self.do_request(request, response).await
    }

    /// Queues `request` for the handler task and waits for its reply.
    ///
    /// While this handle exists the handler only stops if its task ended
    /// abnormally (e.g. a panic during message I/O, or runtime shutdown). That
    /// case is reported as an `io::Error` of kind `BrokenPipe` instead of
    /// panicking the calling task.
    async fn do_request<T>(
        &self,
        request: Request,
        response: oneshot::Receiver<io::Result<T>>,
    ) -> io::Result<T> {
        if self.conn.send(request).await.is_err() {
            // The receiving end is gone, so the handler task has exited.
            return Err(io::Error::new(
                io::ErrorKind::BrokenPipe,
                "connection stream closed",
            ));
        }
        match response.await {
            Ok(result) => result,
            // The handler dropped the reply sender without answering: it exited
            // while serving this request.
            Err(_) => Err(io::Error::new(
                io::ErrorKind::BrokenPipe,
                "connection stream exited",
            )),
        }
    }
}
// Tests drive `Connection` over an in-memory `tokio::io::duplex` pipe; `client`
// plays the peer on the other end of the stream.
#[cfg(test)]
mod tests {
use super::*;
use std::io::ErrorKind;
use tokio::{io::AsyncReadExt, join};
// Two clones of one `Connection` share the same stream: it stays open after the
// first clone is dropped and is shut down only when the last one is.
#[tokio::test]
async fn multiple_connections() {
let (mut client, stream) = tokio::io::duplex(100);
let conn1 = Connection::new(stream, Duration::from_secs(30));
let conn2 = conn1.clone();
// The fixtures suggest a frame is a 4-byte big-endian length (covering the
// type byte and payload), then the type byte, then the payload.
client.write_all(b"\0\0\0\x03xyz").await.unwrap();
let msg = conn1.read_message().await.unwrap();
assert_eq!(msg, Message::new(b'x', "yz"));
// Echo the message back and check it is re-encoded to the same bytes.
conn1.write_message(msg).await.unwrap();
let mut buffer = vec![0; 7];
client.read_exact(&mut buffer).await.unwrap();
assert_eq!(buffer, b"\0\0\0\x03xyz");
// Dropping one clone must not close the stream for the other.
drop(conn1);
let msg = Message::new(b'x', "abc");
conn2.write_message(msg).await.unwrap();
let mut buffer = vec![0; 8];
client.read_exact(&mut buffer).await.unwrap();
assert_eq!(buffer, b"\0\0\0\x04xabc");
// With the last handle gone the handler shuts the stream down, so the peer sees EOF.
drop(conn2);
let e = client.read_u8().await.unwrap_err();
assert_eq!(e.kind(), ErrorKind::UnexpectedEof);
}
// A read that stalls mid-frame fails with `TimedOut`. The handle is then
// dropped, the handler exits, and the peer's late write hits a closed pipe.
#[tokio::test]
async fn connection_timeout() {
let timeout = Duration::from_secs(30);
let (mut client, stream) = tokio::io::duplex(100);
let conn = Connection::new(stream, timeout);
// With the clock paused, Tokio advances time automatically whenever the
// runtime is idle, so the 30s timeout and the 35s sleep pass without real waiting.
time::pause();
let (stream_result, client_result) = join!(
// `conn` is moved in and dropped as soon as this read completes.
async move { conn.read_message().await },
async move {
// Send only a 4-byte header, then stay silent past the read timeout.
client.write_all(b"\0\0\0\x05").await.unwrap();
time::sleep(timeout + Duration::from_secs(5)).await;
client.write_all(b"Xyzabc").await
},
);
time::resume();
let e = stream_result.unwrap_err();
assert_eq!(e.kind(), ErrorKind::TimedOut);
let e = client_result.unwrap_err();
assert_eq!(e.kind(), ErrorKind::BrokenPipe);
}
}