use std::{
io::{Error, ErrorKind},
mem::swap,
};
use crate::{buf::GrowableCircleBuf, ReadStatus, Session, TlsSession, WriteStatus};
/// Adapts a byte-oriented [`Session`] into one that reads and writes whole frames.
///
/// Outgoing frames are encoded by the [`FramingStrategy`] into an internal write buffer and
/// handed to the wrapped session by `drive`/`flush`; incoming bytes are accumulated in a
/// read buffer until the strategy reports that a complete frame is available.
pub struct FramingSession<S, F> {
    /// The wrapped byte-level session.
    session: S,
    /// Encodes outgoing frames and decodes incoming ones.
    framing_strategy: F,
    /// Encoded bytes waiting to be handed to `session`.
    write_buffer: GrowableCircleBuf,
    /// Bytes read from `session` that have not yet been discarded.
    read_buffer: Vec<u8>,
    /// Length of the prefix of `read_buffer` already returned as frames and pending removal.
    read_advance: usize,
}

impl<S, F> FramingSession<S, F>
where
    F: FramingStrategy,
    S: Session<ReadData = [u8], WriteData = [u8]>,
{
    /// Wraps `session`, framing its traffic with `framing_strategy`.
    ///
    /// `write_buffer_capacity` is passed to `GrowableCircleBuf::new` (presumably the initial
    /// capacity of the outgoing buffer — confirm against `GrowableCircleBuf`). The read
    /// buffer starts empty.
    pub fn new(session: S, framing_strategy: F, write_buffer_capacity: usize) -> Self {
        let write_buffer = GrowableCircleBuf::new(write_buffer_capacity);
        Self {
            read_advance: 0,
            read_buffer: Vec::new(),
            write_buffer,
            framing_strategy,
            session,
        }
    }
}
impl<S, F> Session for FramingSession<S, F>
where
    S: Session<ReadData = [u8], WriteData = [u8]>,
    F: FramingStrategy,
{
    type ReadData = F::ReadFrame;
    type WriteData = F::WriteFrame;

    /// Reports the connection state of the wrapped session.
    fn is_connected(&self) -> bool {
        self.session.is_connected()
    }

    /// Delegates connection establishment to the wrapped session.
    fn try_connect(&mut self) -> Result<bool, Error> {
        self.session.try_connect()
    }

    /// Drives the wrapped session, then offers it the readable region of the write buffer.
    ///
    /// The bytes returned by `write_buffer.peek_read()` are passed to the wrapped session's
    /// `write`, and the write buffer is advanced past however many it accepted.
    ///
    /// Returns `Ok(true)` if at least one buffered byte was accepted during this call and
    /// `Ok(false)` otherwise (including when nothing was buffered).
    ///
    /// NOTE(review): the wrapped session's own `drive` result is discarded — confirm that
    /// callers do not expect it to be propagated.
    fn drive(&mut self) -> Result<bool, std::io::Error> {
        self.session.drive()?;
        if self.write_buffer.is_empty() {
            return Ok(false);
        }
        let write_buffer = self.write_buffer.peek_read();
        let wrote_len = match self.session.write(write_buffer)? {
            WriteStatus::Success => write_buffer.len(),
            WriteStatus::Pending(pending) => write_buffer.len() - pending.len(),
        };
        self.write_buffer.advance_read(wrote_len);
        Ok(wrote_len > 0)
    }

    /// Encodes `frame` and queues the encoded bytes in the internal write buffer.
    ///
    /// Nothing is sent to the wrapped session here; that happens in `drive`/`flush`.
    /// Returns `WriteStatus::Pending(frame)` (the whole frame) when the write buffer's
    /// `try_write` reports that it did not accept the encoded bytes.
    fn write<'a>(
        &mut self,
        frame: &'a Self::WriteData,
    ) -> Result<WriteStatus<'a, Self::WriteData>, Error> {
        let encoded = self.framing_strategy.serialize_frame(frame)?;
        if self.write_buffer.try_write(&encoded)? {
            Ok(WriteStatus::Success)
        } else {
            Ok(WriteStatus::Pending(frame))
        }
    }

    /// Returns the next complete frame, or pulls more bytes from the wrapped session.
    ///
    /// If the unconsumed part of the read buffer already starts with a complete frame, that
    /// frame is returned as `ReadStatus::Data`, borrowing from the internal buffer.
    /// Otherwise the wrapped session is read once. Any data it yields is appended and
    /// `ReadStatus::Buffered` is returned, so the caller should call `read` again. A
    /// `Buffered`/`None` status from the wrapped session is passed through unchanged.
    fn read<'a>(&'a mut self) -> Result<ReadStatus<'a, Self::ReadData>, std::io::Error> {
        let start = self.read_advance;
        if self
            .framing_strategy
            .check_deserialize_frame(&self.read_buffer[start..], false)?
        {
            let de = self
                .framing_strategy
                .deserialize_frame(&self.read_buffer[start..])?;
            // Only record how far we have consumed. The bytes are discarded later, so several
            // frames delivered by one inner read do not each re-copy the remaining tail.
            self.read_advance = start + de.size;
            return Ok(ReadStatus::Data(de.frame));
        }
        if self.read_advance != 0 {
            // Compact in place (no new allocation) before appending more input.
            self.read_buffer.drain(..self.read_advance);
            self.read_advance = 0;
        }
        let data = match self.session.read()? {
            ReadStatus::Data(data) => data,
            ReadStatus::Buffered => return Ok(ReadStatus::Buffered),
            ReadStatus::None => return Ok(ReadStatus::None),
        };
        self.read_buffer.extend_from_slice(data);
        Ok(ReadStatus::Buffered)
    }

    /// Repeatedly drives until the internal write buffer is empty, then flushes the
    /// wrapped session.
    ///
    /// The loop calls `drive` with no back-off or sleep, so it spins while the wrapped
    /// session is not accepting data.
    fn flush(&mut self) -> Result<(), std::io::Error> {
        while !self.write_buffer.is_empty() {
            self.drive()?;
        }
        self.session.flush()
    }
}
impl<S, F> TlsSession for FramingSession<S, F>
where
    F: FramingStrategy,
    S: TlsSession<ReadData = [u8], WriteData = [u8]>,
{
    /// Upgrades the wrapped session to TLS; only the wrapped session is touched.
    fn to_tls(
        &mut self,
        domain: &str,
        config: tcp_stream::TLSConfig<'_, '_, '_>,
    ) -> Result<(), std::io::Error> {
        let inner = &mut self.session;
        inner.to_tls(domain, config)
    }

    /// Reports the wrapped session's TLS handshake status.
    fn is_handshake_complete(&self) -> Result<bool, Error> {
        let inner = &self.session;
        inner.is_handshake_complete()
    }
}
/// Rules for encoding frames into bytes and carving frames back out of a byte stream.
pub trait FramingStrategy {
    /// Frame type produced when decoding.
    type ReadFrame: ?Sized;
    /// Frame type accepted for encoding.
    type WriteFrame: ?Sized;

    /// Reports whether `buffered` starts with a frame complete enough for
    /// `deserialize_frame` to decode.
    ///
    /// NOTE(review): `eof` presumably signals that no further bytes will arrive. The visible
    /// caller always passes `false` and `U64FramingStrategy` ignores it — confirm intent.
    fn check_deserialize_frame(&mut self, buffered: &[u8], eof: bool) -> Result<bool, Error>;

    /// Decodes the frame at the start of `buffered`, also reporting how many bytes it used.
    fn deserialize_frame<'a>(
        &'a mut self,
        buffered: &'a [u8],
    ) -> Result<DeserializedFrame<'a, Self::ReadFrame>, Error>;

    /// Encodes `frame` as byte slices that, written in order, form the encoded frame.
    fn serialize_frame<'a>(
        &'a mut self,
        frame: &'a Self::WriteFrame,
    ) -> Result<Vec<&'a [u8]>, Error>;
}
/// A frame decoded by a [`FramingStrategy`], borrowed from the input it was decoded from.
#[derive(Debug)]
pub struct DeserializedFrame<'a, T: ?Sized> {
    /// The decoded frame.
    pub frame: &'a T,
    /// Number of input bytes consumed to produce `frame`, i.e. how far the caller should
    /// advance past it. For `U64FramingStrategy` this is the 8-byte header plus the payload.
    pub size: usize,
}

impl<'a, T: ?Sized> DeserializedFrame<'a, T> {
    /// Bundles a decoded `frame` with the number of input bytes it consumed.
    #[must_use]
    pub fn new(frame: &'a T, size: usize) -> Self {
        Self { frame, size }
    }
}
/// Frames each message with an 8-byte little-endian `u64` length prefix.
#[derive(Debug, Clone, Default)]
pub struct U64FramingStrategy {
    /// Scratch space for the most recently serialized length header. It lives in the
    /// struct so that `serialize_frame` can return a borrowed slice of it.
    header: [u8; 8],
}

impl U64FramingStrategy {
    /// Creates a strategy with a zeroed header scratch buffer (same as `Default`).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
}
impl FramingStrategy for U64FramingStrategy {
type ReadFrame = [u8];
type WriteFrame = [u8];
fn serialize_frame<'a>(
&'a mut self,
data: &'a Self::ReadFrame,
) -> Result<Vec<&'a [u8]>, Error> {
let len = u64::try_from(data.len())
.map_err(|_| Error::new(ErrorKind::InvalidData, "frame to serialize exceeds u64"))?;
self.header.copy_from_slice(&len.to_le_bytes());
let mut buffers = Vec::new();
buffers.push(self.header.as_slice());
buffers.push(data);
Ok(buffers)
}
fn check_deserialize_frame(&mut self, data: &[u8], _eof: bool) -> Result<bool, Error> {
if data.len() < 8 {
return Ok(false);
}
let len = u64::from_le_bytes(
data[..8]
.try_into()
.expect("expected 8 byte slice to be 8 bytes long"),
);
let ulen = usize::try_from(len).map_err(|_| {
Error::new(ErrorKind::InvalidData, "frame to deserialize exceeds usize")
})?;
Ok(data.len() - 8 >= ulen)
}
fn deserialize_frame<'a>(
&'a mut self,
data: &'a [u8],
) -> Result<DeserializedFrame<'a, Self::WriteFrame>, Error> {
if data.len() < 8 {
return Err(Error::new(
ErrorKind::InvalidData,
"cannot deserialize partial frame",
));
}
let len = u64::from_le_bytes(
data[..8]
.try_into()
.expect("expected 8 byte slice to be 8 bytes long"),
);
let ulen = usize::try_from(len).map_err(|_| {
Error::new(ErrorKind::InvalidData, "frame to deserialize exceeds usize")
})?;
if data.len() - 8 >= ulen {
Ok(DeserializedFrame::new(&data[8..][..ulen], 8 + ulen))
} else {
Err(Error::new(
ErrorKind::InvalidData,
"cannot deserialize partial frame",
))
}
}
}
#[cfg(test)]
mod test {
    use crate::{
        frame::{FramingSession, U64FramingStrategy},
        tcp::{StreamingTcpSession, TcpServer},
        ReadStatus, Session, WriteStatus,
    };

    /// Sends one 512-byte frame over loopback and checks it arrives intact.
    #[test]
    fn one_small_frame() {
        let server = TcpServer::bind("127.0.0.1:34001").unwrap();
        let client = StreamingTcpSession::connect("127.0.0.1:34001")
            .unwrap()
            .with_nonblocking(true)
            .unwrap();
        let session = server
            .accept()
            .unwrap()
            .unwrap()
            .0
            .with_nonblocking(true)
            .unwrap();
        let mut client = FramingSession::new(client, U64FramingStrategy::new(), 1024);
        let mut session = FramingSession::new(session, U64FramingStrategy::new(), 1024);
        let mut read_payload = None;
        // `as u8` wraps on purpose: the payload is a repeating 0..=255 byte pattern.
        let write_payload: Vec<u8> = (0..512).map(|i| i as u8).collect();
        let mut remaining = write_payload.as_slice();
        while let WriteStatus::Pending(pw) = client.write(remaining).unwrap() {
            remaining = pw;
            client.drive().unwrap();
            if let ReadStatus::Data(read) = session.read().unwrap() {
                read_payload = Some(Vec::from(read));
            }
        }
        while read_payload.is_none() {
            client.drive().unwrap();
            if let ReadStatus::Data(read) = session.read().unwrap() {
                read_payload = Some(Vec::from(read));
            }
        }
        let read_payload = read_payload.unwrap();
        assert_eq!(read_payload.len(), write_payload.len());
        assert_eq!(read_payload, write_payload);
    }

    /// Sends one frame much larger than the initial write buffer capacity.
    #[test]
    fn one_large_frame() {
        let server = TcpServer::bind("127.0.0.1:34002").unwrap();
        let client = StreamingTcpSession::connect("127.0.0.1:34002")
            .unwrap()
            .with_nonblocking(true)
            .unwrap();
        let session = server
            .accept()
            .unwrap()
            .unwrap()
            .0
            .with_nonblocking(true)
            .unwrap();
        let mut client = FramingSession::new(client, U64FramingStrategy::new(), 1024);
        let mut session = FramingSession::new(session, U64FramingStrategy::new(), 1024);
        let mut read_payload = None;
        // `as u8` wraps on purpose: the payload is a repeating 0..=255 byte pattern.
        let write_payload: Vec<u8> = (0..888888).map(|i| i as u8).collect();
        let mut remaining = write_payload.as_slice();
        while let WriteStatus::Pending(pw) = client.write(remaining).unwrap() {
            remaining = pw;
            // Drain the write buffer so a `Pending` write can make progress instead of
            // spinning forever.
            client.drive().unwrap();
            if let ReadStatus::Data(read) = session.read().unwrap() {
                read_payload = Some(Vec::from(read));
            }
        }
        while read_payload.is_none() {
            client.drive().unwrap();
            if let ReadStatus::Data(read) = session.read().unwrap() {
                read_payload = Some(Vec::from(read));
            }
        }
        let read_payload = read_payload.unwrap();
        assert_eq!(read_payload.len(), write_payload.len());
        assert_eq!(read_payload, write_payload);
    }

    /// Streams many small frames while the receiver reads slowly, forcing back-pressure,
    /// and checks every message arrives in order.
    #[test]
    fn framing_slow_consumer() {
        const MESSAGE_COUNT: usize = 100000;
        let server = TcpServer::bind("127.0.0.1:34003").unwrap();
        let client = StreamingTcpSession::connect("127.0.0.1:34003")
            .unwrap()
            .with_nonblocking(true)
            .unwrap();
        let session = server
            .accept()
            .unwrap()
            .unwrap()
            .0
            .with_nonblocking(true)
            .unwrap();
        let mut client = FramingSession::new(client, U64FramingStrategy::new(), 1024);
        let mut session = FramingSession::new(session, U64FramingStrategy::new(), 1024);
        let mut received = Vec::new();
        let mut backpressure = false;
        for i in 0..MESSAGE_COUNT {
            let m = format!("test test test test hello world {i:06}!");
            let mut remaining = m.as_bytes();
            while let WriteStatus::Pending(pw) = client.write(remaining).unwrap() {
                client.drive().unwrap();
                remaining = pw;
                backpressure = true;
                for _ in 0..10 {
                    if let ReadStatus::Data(read) = session.read().unwrap() {
                        received.push(String::from_utf8_lossy(read).into_owned());
                    }
                }
            }
            client.drive().unwrap();
        }
        assert!(backpressure);
        while received.len() < MESSAGE_COUNT {
            client.drive().unwrap();
            if let ReadStatus::Data(read) = session.read().unwrap() {
                received.push(String::from_utf8_lossy(read).into_owned());
            }
        }
        for i in 0..MESSAGE_COUNT {
            let got = received
                .get(i)
                .unwrap_or_else(|| panic!("message idx {i}"));
            assert_eq!(got, &format!("test test test test hello world {i:06}!"));
        }
    }
}