use anyhow::Result;
use bincode_next::{config::standard, encode_to_vec};
use bon::Builder;
use tokio::{io::AsyncWriteExt as _, net::tcp::OwnedWriteHalf};
use crate::Frame;
/// Sending side of a TCP connection: writes [`Frame`]s or raw bytes to the peer.
///
/// Built through the derived builder, e.g.
/// `ConnectionWriter::builder().writer(write_half).build()`.
#[derive(Debug, Builder)]
pub struct ConnectionWriter {
    /// Owned write half of the TCP stream that all output is written to.
    writer: OwnedWriteHalf,
}
impl ConnectionWriter {
    /// Encodes `frame` with bincode's `standard()` config and sends it to the peer.
    ///
    /// Wire layout, in order:
    /// 1. one byte: `frame.id()`
    /// 2. eight bytes: payload length as a big-endian `u64`
    /// 3. the bincode-encoded payload
    ///
    /// The whole packet is assembled in memory and written with a single `write_all`,
    /// followed by a flush.
    ///
    /// # Errors
    ///
    /// Returns an error if encoding the frame fails, or if writing to / flushing the
    /// socket fails.
    pub async fn write_frame(&mut self, frame: &Frame) -> Result<()> {
        let id = frame.id();
        let payload = encode_to_vec(frame, standard())?;
        // Pin the prefix to a u64: `usize::to_be_bytes` is 8 bytes on 64-bit targets but
        // only 4 on 32-bit ones, which would silently change the wire format. On 64-bit
        // targets this produces exactly the same bytes as before.
        let len = u64::try_from(payload.len())?;

        // One contiguous buffer -> one write, instead of three small writes on an
        // unbuffered socket (fewer syscalls, no tiny header-only segments).
        let mut packet = Vec::with_capacity(1 + std::mem::size_of::<u64>() + payload.len());
        packet.push(id);
        packet.extend_from_slice(&len.to_be_bytes());
        packet.extend_from_slice(&payload);

        self.write_bytes(&packet).await
    }

    /// Writes `bytes` to the peer verbatim (no id or length prefix) and flushes.
    ///
    /// # Errors
    ///
    /// Returns an error if writing to or flushing the socket fails.
    pub async fn write_bytes(&mut self, bytes: &[u8]) -> Result<()> {
        self.writer.write_all(bytes).await?;
        self.writer.flush().await.map_err(Into::into)
    }
}
#[cfg(test)]
mod tests {
    use tokio::{
        io::AsyncReadExt as _,
        net::{TcpListener, TcpStream},
    };

    use super::*;

    /// Opens a loopback TCP connection and returns the accepted server-side stream
    /// together with a `ConnectionWriter` wrapping the client's write half.
    async fn connected_pair() -> (TcpStream, ConnectionWriter) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let (accepted, connected) = tokio::join!(listener.accept(), TcpStream::connect(addr));
        let (server, _) = accepted.unwrap();
        let (_, client_w) = connected.unwrap().into_split();
        let writer = ConnectionWriter::builder().writer(client_w).build();
        (server, writer)
    }

    #[tokio::test]
    async fn write_frame_succeeds() {
        let (mut server, mut writer) = connected_pair().await;
        let frame = Frame::KexFailure;
        writer.write_frame(&frame).await.unwrap();

        // Header: 1-byte frame id, then the payload length as a big-endian u64.
        assert_eq!(server.read_u8().await.unwrap(), frame.id());
        let expected = encode_to_vec(&frame, standard()).unwrap();
        let len = server.read_u64().await.unwrap();
        assert_eq!(len, u64::try_from(expected.len()).unwrap());

        // Body: exactly the bincode encoding of the frame.
        let mut payload = vec![0u8; expected.len()];
        server.read_exact(&mut payload).await.unwrap();
        assert_eq!(payload, expected);
    }

    #[tokio::test]
    async fn write_bytes_succeeds() {
        let (mut server, mut writer) = connected_pair().await;
        writer.write_bytes(b"hello").await.unwrap();

        // Raw bytes must arrive verbatim, with no id or length prefix.
        let mut received = [0u8; 5];
        server.read_exact(&mut received).await.unwrap();
        assert_eq!(&received, b"hello");
    }
}