use std::io::Cursor;
use anyhow::Result;
use bon::Builder;
use bytes::{Buf as _, BytesMut};
use tokio::{io::AsyncReadExt as _, net::tcp::OwnedReadHalf};
use crate::{Frame, error::Error};
/// Initial capacity, in bytes, reserved for a reader's frame buffer when the
/// builder is not given an explicit buffer.
const READ_BUFFER_CAPACITY: usize = 4096;

/// Read side of a TCP connection whose incoming bytes are decoded into
/// [`Frame`]s.
///
/// Build it with the generated builder. `reader` is required. `buffer` is
/// optional and defaults to an empty `BytesMut` with 4096 bytes of capacity.
#[derive(Builder, Debug)]
pub struct ConnectionReader {
    /// Owned read half of the underlying TCP stream.
    reader: OwnedReadHalf,
    /// Bytes received from `reader` that have not yet been consumed as part
    /// of a complete frame.
    #[builder(default = BytesMut::with_capacity(READ_BUFFER_CAPACITY))]
    buffer: BytesMut,
}
impl ConnectionReader {
pub async fn read_frame(&mut self) -> Result<Option<Frame>> {
loop {
if let Some(frame) = self.parse_frame()? {
return Ok(Some(frame));
}
if 0 == self.reader.read_buf(&mut self.buffer).await? {
if self.buffer.is_empty() {
return Ok(None);
}
return Err(Error::ConnectionResetByPeer.into());
}
}
}
fn parse_frame(&mut self) -> Result<Option<Frame>> {
let mut buf = Cursor::new(&self.buffer[..]);
buf.set_position(0);
match Frame::parse(&mut buf) {
Ok(Some(frame)) => {
let len = usize::try_from(buf.position())?;
self.buffer.advance(len);
Ok(Some(frame))
}
Ok(None) => {
Ok(None)
}
Err(err) => Err(err),
}
}
}
#[cfg(test)]
mod tests {
    use tokio::net::{TcpListener, TcpStream};

    use super::*;
    use crate::ConnectionWriter;

    /// Builds a `(reader, writer)` pair joined by a loopback TCP connection.
    ///
    /// The reader wraps the read half of the accepted (server-side) stream.
    /// The writer wraps the write half of the client-side stream, so bytes
    /// written through the writer arrive at the reader. The two unused halves
    /// are dropped straight away.
    async fn make_loopback() -> (ConnectionReader, ConnectionWriter) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        // Accept and connect concurrently so neither side blocks the other.
        let (accepted, connected) = tokio::join!(listener.accept(), TcpStream::connect(addr));
        let (server_stream, _) = accepted.unwrap();
        let client_stream = connected.unwrap();

        let (server_read, _) = server_stream.into_split();
        let (_, client_write) = client_stream.into_split();

        (
            ConnectionReader::builder().reader(server_read).build(),
            ConnectionWriter::builder().writer(client_write).build(),
        )
    }

    /// A frame written on one end is decoded unchanged on the other.
    #[tokio::test]
    async fn read_frame_round_trip() {
        let (mut reader, mut writer) = make_loopback().await;

        writer.write_frame(&Frame::KexFailure).await.unwrap();
        // Release the client's write half once the frame has been sent.
        drop(writer);

        assert_eq!(reader.read_frame().await.unwrap(), Some(Frame::KexFailure));
    }

    /// Closing the connection without sending anything yields `Ok(None)`.
    #[tokio::test]
    async fn read_frame_eof_returns_none() {
        let (mut reader, writer) = make_loopback().await;
        drop(writer);

        assert_eq!(reader.read_frame().await.unwrap(), None);
    }
}