use std::{
fmt::Debug,
io::{Error, ErrorKind},
mem::swap,
};
use crate::{
DriveOutcome, Flush, Publish, PublishOutcome, Receive, ReceiveOutcome, Session, SessionStatus,
buffer::GrowableCircleBuf,
};
/// Bidirectional framing adapter over a byte-oriented session.
///
/// Outgoing frames are encoded by `serialize_frame` into `write_buffer` and pushed to
/// `session` when the adapter is driven; incoming bytes read from `session` are accumulated
/// in `read_buffer` and decoded into frames by `deserialize_frame`.
pub struct FrameDuplex<S, DF, SF> {
    /// The wrapped session that carries raw bytes in both directions.
    session: S,
    /// Decoder used to cut frames off the front of `read_buffer`.
    deserialize_frame: DF,
    /// Encoder used to append outgoing frames to `write_buffer`.
    serialize_frame: SF,
    /// Encoded outgoing bytes that `session` has not accepted yet.
    write_buffer: GrowableCircleBuf,
    /// Received bytes that have not yet been consumed as frames.
    read_buffer: Vec<u8>,
    /// Size of the frame most recently returned by `receive`. Those bytes stay at the front of
    /// `read_buffer` (the returned frame may borrow them) until the next `receive` call
    /// discards them.
    read_advance: usize,
}
impl<S, DF, SF> FrameDuplex<S, DF, SF>
where
    S: for<'a> Publish<PublishPayload<'a> = &'a [u8]>
        + for<'a> Receive<ReceivePayload<'a> = &'a [u8]>
        + 'static,
    DF: DeserializeFrame + 'static,
    SF: SerializeFrame + 'static,
{
    /// Wraps `session`, encoding outgoing frames with `serialize_frame` and decoding incoming
    /// bytes with `deserialize_frame`.
    ///
    /// `write_buffer_capacity` is passed to `GrowableCircleBuf::new`. If that call fails, a
    /// buffer is requested with `usize::MAX / 2` instead.
    /// NOTE(review): the intent of that fallback size is not evident from this file — confirm.
    ///
    /// # Panics
    ///
    /// Panics if the fallback `GrowableCircleBuf::new(usize::MAX / 2)` also fails.
    pub fn new(
        session: S,
        deserialize_frame: DF,
        serialize_frame: SF,
        write_buffer_capacity: usize,
    ) -> Self {
        Self {
            session,
            deserialize_frame,
            serialize_frame,
            write_buffer: GrowableCircleBuf::new(write_buffer_capacity)
                .unwrap_or_else(|_| GrowableCircleBuf::new(usize::MAX / 2).unwrap()),
            read_buffer: Vec::new(),
            read_advance: 0,
        }
    }

    /// Mutable access to the raw receive buffer for crate-internal callers.
    ///
    /// NOTE(review): the first `read_advance` bytes may still belong to the frame returned by
    /// the last `receive` call; they are dropped at the start of the next `receive`, so callers
    /// editing the buffer presumably need to account for that.
    pub(crate) fn read_buffer_mut(&mut self) -> &mut Vec<u8> {
        &mut self.read_buffer
    }
}
impl<S, DF, SF> Session for FrameDuplex<S, DF, SF>
where
    S: for<'a> Publish<PublishPayload<'a> = &'a [u8]> + 'static,
    DF: DeserializeFrame + 'static,
    SF: SerializeFrame + 'static,
{
    /// Reports the wrapped session's status unchanged.
    fn status(&self) -> crate::SessionStatus {
        self.session.status()
    }

    /// Drives the wrapped session, then offers the buffered outgoing bytes to it once.
    ///
    /// Returns `DriveOutcome::Active` if the session accepted at least one buffered byte;
    /// otherwise returns whatever the wrapped session's own `drive` reported.
    fn drive(&mut self) -> Result<DriveOutcome, std::io::Error> {
        let session_outcome = self.session.drive()?;
        if self.write_buffer.is_empty() {
            return Ok(session_outcome);
        }

        let pending = self.write_buffer.peek_read();
        let offered = pending.len();
        let accepted = match self.session.publish(pending)? {
            PublishOutcome::Published => offered,
            // Whatever the session hands back was not written.
            PublishOutcome::Incomplete(rest) => offered - rest.len(),
        };
        self.write_buffer.advance_read(accepted)?;

        Ok(if accepted > 0 {
            DriveOutcome::Active
        } else {
            session_outcome
        })
    }
}
impl<S, DF, SF> Publish for FrameDuplex<S, DF, SF>
where
    S: for<'a> Publish<PublishPayload<'a> = &'a [u8]> + 'static,
    DF: DeserializeFrame + 'static,
    SF: SerializeFrame + 'static,
{
    type PublishPayload<'a> = SF::SerializedFrame<'a>;

    /// Serializes `frame` into the write buffer, then drives the adapter once.
    ///
    /// Returns the serializer's outcome for `frame`.
    ///
    /// # Errors
    ///
    /// * `ErrorKind::NotConnected` if the wrapped session's status is not `Established`;
    ///   nothing is serialized in that case.
    /// * Any error from the serializer or from `drive`. A `drive` error is reported after the
    ///   serializer has already run, so the frame may already be sitting in the write buffer.
    fn publish<'a>(
        &mut self,
        frame: Self::PublishPayload<'a>,
    ) -> Result<PublishOutcome<Self::PublishPayload<'a>>, Error> {
        let status = self.session.status();
        if status != SessionStatus::Established {
            return Err(Error::new(
                ErrorKind::NotConnected,
                "underlying session is not established",
            ));
        }

        let serialized = self
            .serialize_frame
            .serialize_frame(frame, &mut self.write_buffer)?;
        self.drive().map(|_| serialized)
    }
}
impl<S, DF, SF> Flush for FrameDuplex<S, DF, SF>
where
    S: for<'a> Publish<PublishPayload<'a> = &'a [u8]> + Flush + 'static,
    DF: DeserializeFrame + 'static,
    SF: SerializeFrame + 'static,
{
    /// Drives until every buffered outgoing byte has been handed to the wrapped session, then
    /// flushes that session.
    ///
    /// The loop has no backoff: it re-drives immediately while the session accepts nothing,
    /// and stops only when the buffer drains or `drive` returns an error.
    fn flush(&mut self) -> Result<(), std::io::Error> {
        loop {
            if self.write_buffer.is_empty() {
                break;
            }
            self.drive()?;
        }
        self.session.flush()
    }
}
impl<S, DF, SF> Receive for FrameDuplex<S, DF, SF>
where
    S: for<'a> Publish<PublishPayload<'a> = &'a [u8]>
        + for<'a> Receive<ReceivePayload<'a> = &'a [u8]>
        + 'static,
    DF: DeserializeFrame + 'static,
    SF: SerializeFrame + 'static,
{
    type ReceivePayload<'a> = DF::DeserializedFrame<'a>;

    /// Returns the next complete frame, or pulls more bytes from the wrapped session.
    ///
    /// Each call does exactly one of the following:
    /// * returns a frame that is already complete in the receive buffer;
    /// * reads one chunk from the session, appends it, and returns `ReceiveOutcome::Active`
    ///   (a frame completed by that chunk is returned by the next call);
    /// * forwards the session's `Active` / `Idle` outcome;
    /// * on `ErrorKind::UnexpectedEof`, returns a final buffered frame if
    ///   `check_deserialize_frame(.., true)` accepts it; otherwise propagates the error.
    fn receive<'a>(
        &'a mut self,
    ) -> Result<ReceiveOutcome<Self::ReceivePayload<'a>>, std::io::Error> {
        // The frame handed out by the previous call may borrow `read_buffer`, so its bytes can
        // only be discarded now. `drain` shifts the remainder to the front in place and keeps
        // the allocation, instead of copying into a freshly allocated `Vec` for every frame.
        if self.read_advance != 0 {
            self.read_buffer.drain(..self.read_advance);
            self.read_advance = 0;
        }
        if self
            .deserialize_frame
            .check_deserialize_frame(&self.read_buffer, false)?
        {
            let de = self
                .deserialize_frame
                .deserialize_frame(&self.read_buffer)?;
            self.read_advance = de.size;
            return Ok(ReceiveOutcome::Payload(de.frame));
        }
        let data = match self.session.receive() {
            Ok(ReceiveOutcome::Payload(data)) => data,
            Ok(ReceiveOutcome::Active) => return Ok(ReceiveOutcome::Active),
            Ok(ReceiveOutcome::Idle) => return Ok(ReceiveOutcome::Idle),
            Err(err) => {
                // At end of stream the deserializer gets one last chance to accept what is
                // buffered before the error is surfaced.
                if let ErrorKind::UnexpectedEof = err.kind() {
                    if self
                        .deserialize_frame
                        .check_deserialize_frame(&self.read_buffer, true)?
                    {
                        let de = self
                            .deserialize_frame
                            .deserialize_frame(&self.read_buffer)?;
                        self.read_advance = de.size;
                        return Ok(ReceiveOutcome::Payload(de.frame));
                    }
                }
                return Err(err);
            }
        };
        self.read_buffer.extend_from_slice(data);
        Ok(ReceiveOutcome::Active)
    }
}
impl<S, DF, SF> Debug for FrameDuplex<S, DF, SF>
where
    S: Session + 'static,
    DF: DeserializeFrame + 'static,
    SF: SerializeFrame + 'static,
{
    /// Formats only the wrapped session. The codecs and buffers are omitted, and
    /// `finish_non_exhaustive` shows that omission as a trailing `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FrameDuplex")
            .field("session", &self.session)
            .finish_non_exhaustive()
    }
}
/// Publish-only framing adapter: encodes frames into a write buffer and pushes the encoded
/// bytes to a byte-oriented session when driven.
pub struct FramePublisher<S, F> {
    /// The wrapped session that receives the encoded bytes.
    session: S,
    /// Encoder used to append outgoing frames to `write_buffer`.
    framing_strategy: F,
    /// Encoded outgoing bytes that `session` has not accepted yet.
    write_buffer: GrowableCircleBuf,
}
impl<S, F> FramePublisher<S, F>
where
    S: for<'a> Publish<PublishPayload<'a> = &'a [u8]> + 'static,
    F: SerializeFrame + 'static,
{
    /// Wraps `session`, encoding outgoing frames with `framing_strategy`.
    ///
    /// `write_buffer_capacity` is passed to `GrowableCircleBuf::new`. If that call fails, a
    /// buffer is requested with `usize::MAX / 2` instead.
    ///
    /// # Panics
    ///
    /// Panics if the fallback `GrowableCircleBuf::new(usize::MAX / 2)` also fails.
    pub fn new(session: S, framing_strategy: F, write_buffer_capacity: usize) -> Self {
        let write_buffer = match GrowableCircleBuf::new(write_buffer_capacity) {
            Ok(buf) => buf,
            Err(_) => GrowableCircleBuf::new(usize::MAX / 2).unwrap(),
        };
        Self {
            session,
            framing_strategy,
            write_buffer,
        }
    }
}
impl<S, F> Session for FramePublisher<S, F>
where
    S: for<'a> Publish<PublishPayload<'a> = &'a [u8]> + 'static,
    F: SerializeFrame + 'static,
{
    /// Reports the wrapped session's status unchanged.
    fn status(&self) -> crate::SessionStatus {
        self.session.status()
    }

    /// Drives the wrapped session, then offers the buffered outgoing bytes to it once.
    ///
    /// Returns `DriveOutcome::Active` if the session accepted at least one buffered byte;
    /// otherwise returns whatever the wrapped session's own `drive` reported.
    fn drive(&mut self) -> Result<DriveOutcome, std::io::Error> {
        let session_outcome = self.session.drive()?;
        if self.write_buffer.is_empty() {
            return Ok(session_outcome);
        }

        let pending = self.write_buffer.peek_read();
        let offered = pending.len();
        let accepted = match self.session.publish(pending)? {
            PublishOutcome::Published => offered,
            // Whatever the session hands back was not written.
            PublishOutcome::Incomplete(rest) => offered - rest.len(),
        };
        self.write_buffer.advance_read(accepted)?;

        Ok(if accepted > 0 {
            DriveOutcome::Active
        } else {
            session_outcome
        })
    }
}
impl<S, F> Publish for FramePublisher<S, F>
where
    S: for<'a> Publish<PublishPayload<'a> = &'a [u8]> + 'static,
    F: SerializeFrame + 'static,
{
    type PublishPayload<'a> = F::SerializedFrame<'a>;

    /// Serializes `frame` into the write buffer, then drives the publisher once.
    ///
    /// Returns the serializer's outcome for `frame`.
    ///
    /// # Errors
    ///
    /// * `ErrorKind::NotConnected` if the wrapped session's status is not `Established`;
    ///   nothing is serialized in that case.
    /// * Any error from the serializer or from `drive`. A `drive` error is reported after the
    ///   serializer has already run, so the frame may already be sitting in the write buffer.
    fn publish<'a>(
        &mut self,
        frame: Self::PublishPayload<'a>,
    ) -> Result<PublishOutcome<Self::PublishPayload<'a>>, Error> {
        let status = self.session.status();
        if status != SessionStatus::Established {
            return Err(Error::new(
                ErrorKind::NotConnected,
                "underlying session is not established",
            ));
        }

        let serialized = self
            .framing_strategy
            .serialize_frame(frame, &mut self.write_buffer)?;
        self.drive().map(|_| serialized)
    }
}
impl<S, F> Flush for FramePublisher<S, F>
where
    S: for<'a> Publish<PublishPayload<'a> = &'a [u8]> + Flush + 'static,
    F: SerializeFrame + 'static,
{
    /// Drives until every buffered outgoing byte has been handed to the wrapped session, then
    /// flushes that session.
    ///
    /// The loop has no backoff: it re-drives immediately while the session accepts nothing,
    /// and stops only when the buffer drains or `drive` returns an error.
    fn flush(&mut self) -> Result<(), std::io::Error> {
        loop {
            if self.write_buffer.is_empty() {
                break;
            }
            self.drive()?;
        }
        self.session.flush()
    }
}
impl<S, F> Debug for FramePublisher<S, F>
where
    S: Session + 'static,
    F: SerializeFrame + 'static,
{
    /// Formats only the wrapped session. The framing strategy and write buffer are omitted,
    /// and `finish_non_exhaustive` shows that omission as a trailing `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FramePublisher")
            .field("session", &self.session)
            .finish_non_exhaustive()
    }
}
/// Receive-only framing adapter: accumulates bytes read from a session and decodes them into
/// frames.
pub struct FrameReceiver<S, F> {
    /// The wrapped session that supplies raw bytes.
    session: S,
    /// Decoder used to cut frames off the front of `read_buffer`.
    deserialize_frame: F,
    /// Received bytes that have not yet been consumed as frames.
    read_buffer: Vec<u8>,
    /// Size of the frame most recently returned by `receive`. Those bytes stay at the front of
    /// `read_buffer` (the returned frame may borrow them) until the next `receive` call
    /// discards them.
    read_advance: usize,
}
impl<S, F> FrameReceiver<S, F>
where
    S: Session + 'static,
    F: DeserializeFrame + 'static,
{
    /// Wraps `session`, decoding its incoming bytes into frames with `deserialize_frame`.
    ///
    /// The receive buffer starts empty and grows as data arrives.
    pub fn new(session: S, deserialize_frame: F) -> Self {
        let read_buffer = Vec::new();
        Self {
            read_advance: 0,
            read_buffer,
            deserialize_frame,
            session,
        }
    }
}
impl<S, F> Session for FrameReceiver<S, F>
where
    S: Session + 'static,
    F: DeserializeFrame + 'static,
{
    /// Reports the wrapped session's status unchanged.
    fn status(&self) -> crate::SessionStatus {
        self.session.status()
    }

    /// Drives the wrapped session. There is no outgoing buffer to push here, so this is a pure
    /// delegation.
    fn drive(&mut self) -> Result<DriveOutcome, std::io::Error> {
        self.session.drive()
    }
}
impl<S, F> Receive for FrameReceiver<S, F>
where
    S: for<'a> Receive<ReceivePayload<'a> = &'a [u8]> + 'static,
    F: DeserializeFrame + 'static,
{
    type ReceivePayload<'a> = F::DeserializedFrame<'a>;

    /// Returns the next complete frame, or pulls more bytes from the wrapped session.
    ///
    /// Each call does exactly one of the following:
    /// * returns a frame that is already complete in the receive buffer;
    /// * reads one chunk from the session, appends it, and returns `ReceiveOutcome::Active`
    ///   (a frame completed by that chunk is returned by the next call);
    /// * forwards the session's `Active` / `Idle` outcome;
    /// * on `ErrorKind::UnexpectedEof`, returns a final buffered frame if
    ///   `check_deserialize_frame(.., true)` accepts it; otherwise propagates the error.
    fn receive<'a>(
        &'a mut self,
    ) -> Result<ReceiveOutcome<Self::ReceivePayload<'a>>, std::io::Error> {
        // The frame handed out by the previous call may borrow `read_buffer`, so its bytes can
        // only be discarded now. `drain` shifts the remainder to the front in place and keeps
        // the allocation, instead of copying into a freshly allocated `Vec` for every frame.
        if self.read_advance != 0 {
            self.read_buffer.drain(..self.read_advance);
            self.read_advance = 0;
        }
        if self
            .deserialize_frame
            .check_deserialize_frame(&self.read_buffer, false)?
        {
            let de = self
                .deserialize_frame
                .deserialize_frame(&self.read_buffer)?;
            self.read_advance = de.size;
            return Ok(ReceiveOutcome::Payload(de.frame));
        }
        let data = match self.session.receive() {
            Ok(ReceiveOutcome::Payload(data)) => data,
            Ok(ReceiveOutcome::Active) => return Ok(ReceiveOutcome::Active),
            Ok(ReceiveOutcome::Idle) => return Ok(ReceiveOutcome::Idle),
            Err(err) => {
                // At end of stream the deserializer gets one last chance to accept what is
                // buffered before the error is surfaced.
                if let ErrorKind::UnexpectedEof = err.kind() {
                    if self
                        .deserialize_frame
                        .check_deserialize_frame(&self.read_buffer, true)?
                    {
                        let de = self
                            .deserialize_frame
                            .deserialize_frame(&self.read_buffer)?;
                        self.read_advance = de.size;
                        return Ok(ReceiveOutcome::Payload(de.frame));
                    }
                }
                return Err(err);
            }
        };
        self.read_buffer.extend_from_slice(data);
        Ok(ReceiveOutcome::Active)
    }
}
impl<S, F> Debug for FrameReceiver<S, F>
where
    S: Session + 'static,
    F: DeserializeFrame + 'static,
{
    /// Formats only the wrapped session. The decoder and receive buffer are omitted, and
    /// `finish_non_exhaustive` shows that omission as a trailing `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FrameReceiver")
            .field("session", &self.session)
            .finish_non_exhaustive()
    }
}
/// Decodes frames from the front of a byte buffer accumulated by a receiver.
pub trait DeserializeFrame {
    /// The decoded frame type; it may borrow from both the input bytes and the deserializer.
    type DeserializedFrame<'a>
    where
        Self: 'a;
    /// Returns whether `data` starts with a frame that `deserialize_frame` can decode.
    ///
    /// The receivers in this module pass `eof = false` during normal reads and `eof = true`
    /// after the underlying session reports `ErrorKind::UnexpectedEof`. Presumably this lets
    /// formats delimited by end-of-stream accept a trailing frame.
    fn check_deserialize_frame(&mut self, data: &[u8], eof: bool) -> Result<bool, Error>;
    /// Decodes the frame at the front of `data`.
    ///
    /// The returned `SizedFrame::size` is the number of bytes the frame occupies. The
    /// receivers in this module drop exactly that many bytes from their buffer at the start of
    /// the next `receive`. They only call this after `check_deserialize_frame` returned `true`.
    fn deserialize_frame<'a>(
        &'a mut self,
        data: &'a [u8],
    ) -> Result<SizedFrame<Self::DeserializedFrame<'a>>, Error>;
}
/// Encodes outgoing frames into a publisher's write buffer.
pub trait SerializeFrame {
    /// The frame type accepted by `serialize_frame`.
    type SerializedFrame<'a>
    where
        Self: 'a;
    /// Appends the encoded form of `frame` to `buffer`.
    ///
    /// `U64FrameSerializer` returns `PublishOutcome::Published` when its write is accepted.
    /// When `try_write` reports `false`, it hands the frame back as
    /// `PublishOutcome::Incomplete(frame)` instead. NOTE(review): presumably so the caller can
    /// retry with the same frame — confirm for other implementations.
    fn serialize_frame<'a>(
        &mut self,
        frame: Self::SerializedFrame<'a>,
        buffer: &mut GrowableCircleBuf,
    ) -> Result<PublishOutcome<Self::SerializedFrame<'a>>, Error>;
}
/// A decoded frame together with the number of input bytes it occupied.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SizedFrame<T> {
    /// The decoded frame.
    pub frame: T,
    /// How many bytes at the front of the input the frame consumed. The receivers in this
    /// module discard exactly this many bytes before their next read.
    pub size: usize,
}

impl<T> SizedFrame<T> {
    /// Pairs `frame` with the number of input bytes (`size`) it was decoded from.
    pub fn new(frame: T, size: usize) -> Self {
        Self { frame, size }
    }
}
/// Serializes byte-slice frames as an 8-byte little-endian `u64` length followed by the bytes.
#[derive(Debug, Default)]
pub struct U64FrameSerializer {
    /// Scratch space for the length prefix of the frame currently being written.
    header: [u8; 8],
}

impl U64FrameSerializer {
    /// Creates a serializer with a zeroed header; equivalent to `U64FrameSerializer::default()`.
    pub fn new() -> Self {
        Self { header: [0; 8] }
    }
}
impl SerializeFrame for U64FrameSerializer {
    type SerializedFrame<'a> = &'a [u8];

    /// Buffers `data` behind an 8-byte little-endian length prefix.
    ///
    /// The header and payload are passed to `try_write` together. If it reports `false`,
    /// `data` is handed back as `PublishOutcome::Incomplete`.
    ///
    /// # Errors
    ///
    /// `ErrorKind::InvalidData` if `data.len()` does not fit in a `u64`; errors from
    /// `try_write` are propagated.
    fn serialize_frame<'a>(
        &mut self,
        data: Self::SerializedFrame<'a>,
        write_buffer: &mut GrowableCircleBuf,
    ) -> Result<PublishOutcome<Self::SerializedFrame<'a>>, Error> {
        let frame_len = u64::try_from(data.len()).map_err(|_| {
            Error::new(ErrorKind::InvalidData, "frame to serialize exceeds u64")
        })?;
        self.header = frame_len.to_le_bytes();

        let parts = vec![self.header.as_slice(), data];
        let accepted = write_buffer.try_write(&parts)?;
        Ok(if accepted {
            PublishOutcome::Published
        } else {
            PublishOutcome::Incomplete(data)
        })
    }
}
/// Deserializes frames made of an 8-byte little-endian `u64` length prefix followed by that
/// many payload bytes (the format written by `U64FrameSerializer`).
#[derive(Debug, Default)]
pub struct U64FrameDeserializer {}

impl U64FrameDeserializer {
    /// Creates a deserializer. It carries no state, so this is equivalent to
    /// `U64FrameDeserializer::default()`.
    pub fn new() -> Self {
        Self {}
    }
}
impl DeserializeFrame for U64FrameDeserializer {
    type DeserializedFrame<'a> = &'a [u8];

    /// Reports whether `data` begins with a complete length-prefixed frame.
    ///
    /// `_eof` is ignored: the header alone tells whether the frame is complete.
    ///
    /// # Errors
    ///
    /// `ErrorKind::InvalidData` if the encoded length does not fit in `usize`.
    fn check_deserialize_frame(&mut self, data: &[u8], _eof: bool) -> Result<bool, Error> {
        Ok(match u64_frame_body_len(data)? {
            Some(body_len) => data.len() - 8 >= body_len,
            None => false,
        })
    }

    /// Returns the payload of the frame at the front of `data` and its total size
    /// (8 header bytes plus the payload length).
    ///
    /// # Errors
    ///
    /// `ErrorKind::InvalidData` if `data` holds less than one full frame, or if the encoded
    /// length does not fit in `usize`.
    fn deserialize_frame<'a>(
        &'a mut self,
        data: &'a [u8],
    ) -> Result<SizedFrame<Self::DeserializedFrame<'a>>, Error> {
        let partial =
            || Error::new(ErrorKind::InvalidData, "cannot deserialize partial frame");

        let body_len = u64_frame_body_len(data)?.ok_or_else(partial)?;
        if data.len() - 8 < body_len {
            return Err(partial());
        }
        let body = &data[8..][..body_len];
        Ok(SizedFrame::new(body, 8 + body_len))
    }
}

/// Decodes the 8-byte little-endian length header at the start of `data`.
///
/// Returns `Ok(None)` when fewer than 8 bytes are available, and an `InvalidData` error when
/// the encoded length does not fit in `usize`.
fn u64_frame_body_len(data: &[u8]) -> Result<Option<usize>, Error> {
    let header: [u8; 8] = match data.get(..8) {
        Some(bytes) => bytes
            .try_into()
            .expect("expected 8 byte slice to be 8 bytes long"),
        None => return Ok(None),
    };
    usize::try_from(u64::from_le_bytes(header))
        .map(Some)
        .map_err(|_| Error::new(ErrorKind::InvalidData, "frame to deserialize exceeds usize"))
}