use anyhow::Result;
use bon::Builder;
use tokio::sync::mpsc::UnboundedReceiver;
use crate::{ConnectionWriter, Frame};
/// Relays outbound [`Frame`]s from an in-process channel onto a connection.
///
/// Construct it with the `bon`-generated builder
/// (`KexSender::builder().writer(..).rx(..).build()`), then drive it with
/// `handle_send_frames`.
#[derive(Builder, Debug)]
pub struct KexSender {
    /// Connection that every frame received on `rx` is written to.
    writer: ConnectionWriter,
    /// Source of frames to send. The relay loop ends once all senders for this
    /// channel are dropped and the queue is drained.
    rx: UnboundedReceiver<Frame>,
}
impl KexSender {
    /// Forwards every frame queued on the channel to the connection writer.
    ///
    /// Keeps running until the channel is closed (every sender dropped) and no
    /// buffered frames remain, then returns `Ok(())`.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by `ConnectionWriter::write_frame`.
    /// Any frames still queued at that point stay in the channel.
    pub async fn handle_send_frames(&mut self) -> Result<()> {
        loop {
            let Some(outgoing) = self.rx.recv().await else {
                // Channel closed and empty: nothing left to relay.
                return Ok(());
            };
            self.writer.write_frame(&outgoing).await?;
        }
    }
}
#[cfg(test)]
mod tests {
    use tokio::{
        net::{TcpListener, TcpStream},
        sync::mpsc::unbounded_channel,
    };
    use super::{ConnectionWriter, Frame, KexSender};
    use crate::ConnectionReader;
    use anyhow::Result;

    /// Opens a local TCP connection and returns it as a framed pair.
    ///
    /// The pair is the accepted (server) side's read half and the connecting
    /// (client) side's write half. The unused halves are dropped right away.
    async fn make_loopback() -> Result<(ConnectionReader, ConnectionWriter)> {
        // Port 0 lets the OS choose a free port.
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;
        // Drive the accept and the connect concurrently on this task.
        let accept_one = async { listener.accept().await.map(|(stream, _peer)| stream) };
        let (accepted, connected) = tokio::join!(accept_one, TcpStream::connect(addr));
        let (server_read, _) = accepted?.into_split();
        let (_, client_write) = connected?.into_split();
        Ok((
            ConnectionReader::builder().reader(server_read).build(),
            ConnectionWriter::builder().writer(client_write).build(),
        ))
    }

    #[tokio::test]
    async fn kex_sender_relays_frame() -> Result<()> {
        let (mut server_side, client_writer) = make_loopback().await?;
        let (frame_tx, frame_rx) = unbounded_channel();
        // Queue a single frame, then close the channel so the relay loop ends.
        frame_tx.send(Frame::KexFailure).expect("test channel send");
        drop(frame_tx);
        let mut sender = KexSender::builder()
            .writer(client_writer)
            .rx(frame_rx)
            .build();
        sender.handle_send_frames().await?;
        let received = server_side.read_frame().await?;
        assert_eq!(received, Some(Frame::KexFailure));
        Ok(())
    }

    #[tokio::test]
    async fn kex_sender_stops_on_channel_close() -> Result<()> {
        let (_, client_writer) = make_loopback().await?;
        // Discarding the sender half closes the channel before anything is queued.
        let (_, frame_rx) = unbounded_channel::<Frame>();
        let mut sender = KexSender::builder()
            .writer(client_writer)
            .rx(frame_rx)
            .build();
        // With no senders left, the relay loop must return instead of waiting.
        sender.handle_send_frames().await?;
        Ok(())
    }
}