use crate::{
libp2p::read_write::ReadWrite,
util::{leb128, protobuf},
};
use alloc::{borrow::ToOwned as _, vec::Vec};
use core::{cmp, fmt, mem, ops};
/// State of the framing layer that wraps the data of an inner stream in length-prefixed
/// protobuf frames (field 1 = `flag`, field 2 = `message`) and tracks the `FIN`/`FIN_ACK`
/// handshake of both writing sides.
pub struct WebRtcFraming {
    /// Number of bytes the inner stream asked for at the end of its last processing, or
    /// `None` if the inner stream has never been processed. Frames are decoded from the outer
    /// stream until `receive_buffer` holds at least this many bytes.
    inner_stream_expected_incoming_bytes: Option<usize>,
    /// Payloads extracted from received frames that the inner stream hasn't consumed yet.
    receive_buffer: Vec<u8>,
    /// Whether the remote has sent `FIN`, and whether our `FIN_ACK` has been queued.
    remote_write_state: RemoteWriteState,
    /// Whether we have queued a `FIN`, and whether the remote has acknowledged it.
    local_write_state: LocalWriteState,
}
/// State of the local writing side.
enum LocalWriteState {
    /// The inner stream is still allowed to queue data.
    Open,
    /// The inner stream closed its writing side and a frame carrying the `FIN` flag has been
    /// queued on the outer stream.
    FinBuffered,
    /// The remote has answered our `FIN` with a `FIN_ACK` flag.
    FinAcked,
}
/// State of the remote's writing side.
enum RemoteWriteState {
    /// The remote is still allowed to send data.
    Open,
    /// The remote has sent `FIN`; a `FIN_ACK` still has to be queued in response.
    Closed,
    /// The remote has sent `FIN` and our `FIN_ACK` has been queued on the outer stream.
    ClosedAckBuffered,
}
/// Initial capacity, in bytes, of the buffer holding decoded payloads waiting to be read by
/// the inner stream.
const RECEIVE_BUFFER_CAPACITY: usize = 2048;

/// Framing overhead assumed to accompany any payload. It is added to the number of bytes the
/// inner stream expects when computing how many bytes to request from the outer stream.
const PROTOBUF_FRAME_MIN_LEN: usize = 2;

/// Upper bound on the framing overhead of an outgoing frame (length prefix, flag field,
/// message tag and message length). This amount is reserved out of the outer stream's
/// writable space.
const PROTOBUF_FRAME_MAX_LEN: usize = 8;

/// Maximum size of a message, framing overhead included. The inner stream is never offered
/// more than `MAX_PROTOBUF_MESSAGE_LEN - PROTOBUF_FRAME_MAX_LEN` writable bytes, so that
/// outgoing frames stay within this bound.
const MAX_PROTOBUF_MESSAGE_LEN: usize = 16384;

// `MAX_PROTOBUF_MESSAGE_LEN - PROTOBUF_FRAME_MAX_LEN` is computed when sizing the inner
// stream's writable space. Make sure at compile time that it can neither underflow nor be
// zero.
const _: () = assert!(MAX_PROTOBUF_MESSAGE_LEN > PROTOBUF_FRAME_MAX_LEN);
impl WebRtcFraming {
pub fn new() -> Self {
WebRtcFraming {
inner_stream_expected_incoming_bytes: None,
receive_buffer: Vec::with_capacity(RECEIVE_BUFFER_CAPACITY),
remote_write_state: RemoteWriteState::Open,
local_write_state: LocalWriteState::Open,
}
}
pub fn read_write<'a, TNow: Clone>(
&'a mut self,
outer_read_write: &'a mut ReadWrite<TNow>,
) -> Result<InnerReadWrite<'a, TNow>, Error> {
loop {
if self
.inner_stream_expected_incoming_bytes
.map_or(true, |rq_bytes| rq_bytes <= self.receive_buffer.len())
{
break;
}
let bytes_to_discard = {
let mut parser = nom::combinator::map_parser::<_, _, nom::error::Error<&[u8]>, _, _>(
nom::multi::length_data(crate::util::leb128::nom_leb128_usize),
protobuf::message_decode! {
#[optional] flags = 1 => protobuf::enum_tag_decode,
#[optional] message = 2 => protobuf::bytes_tag_decode,
},
);
match nom::Parser::parse(&mut parser, &outer_read_write.incoming_buffer) {
Ok((rest, framed_message)) => {
if framed_message.flags.map_or(false, |f| f == 2) {
return Err(Error::RemoteResetDesired);
}
if framed_message.message.map_or(false, |msg| !msg.is_empty())
&& !matches!(self.remote_write_state, RemoteWriteState::Open)
{
return Err(Error::DataAfterFin);
}
if framed_message.flags.map_or(false, |f| f == 3) {
if matches!(self.local_write_state, LocalWriteState::Open) {
return Err(Error::FinAckWithoutFin);
}
self.local_write_state = LocalWriteState::FinAcked;
}
if matches!(self.remote_write_state, RemoteWriteState::Open)
&& framed_message.flags.map_or(false, |f| f == 0)
{
self.remote_write_state = RemoteWriteState::Closed;
}
if let Some(message) = framed_message.message {
self.receive_buffer.extend_from_slice(message);
}
outer_read_write.incoming_buffer.len() - rest.len()
}
Err(nom::Err::Incomplete(needed)) => {
let Some(expected_incoming_bytes) =
&mut outer_read_write.expected_incoming_bytes
else {
return Err(Error::EofIncompleteFrame);
};
*expected_incoming_bytes = outer_read_write.incoming_buffer.len()
+ match needed {
nom::Needed::Size(s) => s.get(),
nom::Needed::Unknown => 1,
};
break;
}
Err(_) => {
return Err(Error::InvalidFrame);
}
}
};
let _extract_result = outer_read_write.incoming_bytes_take(bytes_to_discard);
debug_assert!(matches!(_extract_result, Ok(Some(_))));
}
Ok(InnerReadWrite {
inner_read_write: ReadWrite {
now: outer_read_write.now.clone(),
incoming_buffer: mem::take(&mut self.receive_buffer),
read_bytes: 0,
expected_incoming_bytes: if matches!(
self.remote_write_state,
RemoteWriteState::Open
) {
Some(0)
} else {
None
},
write_buffers: Vec::new(),
write_bytes_queued: 0,
write_bytes_queueable: if matches!(self.local_write_state, LocalWriteState::Open) {
outer_read_write
.write_bytes_queueable
.map(|outer_writable| {
cmp::min(
outer_writable.saturating_sub(PROTOBUF_FRAME_MAX_LEN),
MAX_PROTOBUF_MESSAGE_LEN - PROTOBUF_FRAME_MAX_LEN,
)
})
} else {
None
},
wake_up_after: outer_read_write.wake_up_after.clone(),
},
framing: self,
outer_read_write,
})
}
}
impl fmt::Debug for WebRtcFraming {
    /// Prints only the type name; none of the internal framing state is shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("WebRtcFraming")
    }
}
/// Returned by [`WebRtcFraming::read_write`]. Dereferences to the unframed [`ReadWrite`] that
/// the inner stream should use. When dropped, the inner stream's state is copied back into
/// the framing state and the outer [`ReadWrite`], and any queued data or flag is framed and
/// appended to the outer write buffers.
pub struct InnerReadWrite<'a, TNow: Clone> {
    /// Framing state, updated when this object is dropped.
    framing: &'a mut WebRtcFraming,
    /// The underlying stream that carries the frames.
    outer_read_write: &'a mut ReadWrite<TNow>,
    /// Unframed view exposed to the inner stream through `Deref`/`DerefMut`.
    inner_read_write: ReadWrite<TNow>,
}
impl<TNow: Clone> ops::Deref for InnerReadWrite<'_, TNow> {
    type Target = ReadWrite<TNow>;

    /// Read access to the unframed [`ReadWrite`] of the inner stream.
    fn deref(&self) -> &ReadWrite<TNow> {
        &self.inner_read_write
    }
}

impl<TNow: Clone> ops::DerefMut for InnerReadWrite<'_, TNow> {
    /// Write access to the unframed [`ReadWrite`] of the inner stream.
    fn deref_mut(&mut self) -> &mut ReadWrite<TNow> {
        &mut self.inner_read_write
    }
}
impl<'a, TNow: Clone> Drop for InnerReadWrite<'a, TNow> {
    /// Hands the results of the inner stream's processing back to the framing state and the
    /// outer [`ReadWrite`], then frames any queued data and/or pending flag.
    fn drop(&mut self) {
        // The inner stream might have consumed bytes from the receive buffer and expect to be
        // processed again although no new byte arrives on the outer stream. The same applies
        // after the very first processing, which never pulls frames. Wake up as soon as
        // possible in these cases to avoid a stall.
        if self.framing.inner_stream_expected_incoming_bytes.is_none()
            || self.inner_read_write.read_bytes != 0
        {
            self.outer_read_write.wake_up_asap();
        }

        // Timer and reading side.
        self.outer_read_write.wake_up_after = self.inner_read_write.wake_up_after.clone();
        self.framing.receive_buffer = mem::take(&mut self.inner_read_write.incoming_buffer);
        self.framing.inner_stream_expected_incoming_bytes =
            Some(self.inner_read_write.expected_incoming_bytes.unwrap_or(0));
        if let Some(expected_incoming_bytes) = &mut self.outer_read_write.expected_incoming_bytes {
            *expected_incoming_bytes = cmp::max(
                *expected_incoming_bytes,
                self.inner_read_write.expected_incoming_bytes.unwrap_or(0) + PROTOBUF_FRAME_MIN_LEN,
            );
        }

        // A flag can only be delivered while the outer writing side is open. Otherwise the
        // flag is left pending: the state isn't advanced and nothing is written into a
        // closed stream. No wake-up is requested for it, as that would spin forever.
        let outer_is_writable = self.outer_read_write.write_bytes_queueable.is_some();

        // At most one flag per frame. `FIN` has priority over `FIN_ACK`.
        let flag_to_send_out: Option<u32> = if !outer_is_writable {
            None
        } else if matches!(self.framing.local_write_state, LocalWriteState::Open)
            && self.inner_read_write.write_bytes_queueable.is_none()
        {
            // The inner stream closed its writing side: send `FIN` (flag 0).
            self.framing.local_write_state = LocalWriteState::FinBuffered;
            Some(0)
        } else if matches!(self.framing.remote_write_state, RemoteWriteState::Closed) {
            // The remote sent `FIN`: answer with `FIN_ACK` (flag 3).
            self.framing.remote_write_state = RemoteWriteState::ClosedAckBuffered;
            Some(3)
        } else {
            None
        };

        // `FIN` took precedence over a `FIN_ACK` that is still owed to the remote. Make sure
        // we get processed again so that the `FIN_ACK` goes out too.
        if flag_to_send_out == Some(0)
            && matches!(self.framing.remote_write_state, RemoteWriteState::Closed)
        {
            self.outer_read_write.wake_up_asap();
        }

        if flag_to_send_out.is_some() || self.inner_read_write.write_bytes_queued != 0 {
            // Placeholder for the length prefix, filled once the frame length is known.
            let message_length_prefix_index = self.outer_read_write.write_buffers.len();
            self.outer_read_write
                .write_buffers
                .push(Vec::with_capacity(4));
            let mut length_prefix_value = 0;

            // Field 1: the flag, if any.
            if let Some(flag_to_send_out) = flag_to_send_out {
                for buffer in protobuf::uint32_tag_encode(1, flag_to_send_out) {
                    let buffer = buffer.as_ref();
                    length_prefix_value += buffer.len();
                    self.outer_read_write.write_buffers.push(buffer.to_owned());
                }
            }

            // Field 2: the (possibly empty) message, as a length-delimited field.
            let data_protobuf_tag = protobuf::tag_encode(2, 2).collect::<Vec<_>>();
            length_prefix_value += data_protobuf_tag.len();
            self.outer_read_write.write_buffers.push(data_protobuf_tag);
            let data_len =
                leb128::encode_usize(self.inner_read_write.write_bytes_queued).collect::<Vec<_>>();
            length_prefix_value += data_len.len();
            self.outer_read_write.write_buffers.push(data_len);
            length_prefix_value += self.inner_read_write.write_bytes_queued;
            self.outer_read_write
                .write_buffers
                .extend(mem::take(&mut self.inner_read_write.write_buffers));

            let length_prefix = leb128::encode_usize(length_prefix_value).collect::<Vec<_>>();
            let total_length = length_prefix_value + length_prefix.len();
            self.outer_read_write.write_buffers[message_length_prefix_index] = length_prefix;
            self.outer_read_write.write_bytes_queued += total_length;

            // Frames carrying data always fit, because the inner stream was offered the outer
            // space minus `PROTOBUF_FRAME_MAX_LEN`. A frame carrying only a flag, however, can
            // slightly exceed the outer queueable amount when the outer buffer is nearly full.
            // Saturate instead of underflowing.
            if let Some(queueable) = self.outer_read_write.write_bytes_queueable.as_mut() {
                *queueable = queueable.saturating_sub(total_length);
            }
        }
    }
}
/// Error that can happen while decoding frames in [`WebRtcFraming::read_write`].
#[derive(Debug, derive_more::Display, derive_more::Error)]
pub enum Error {
    /// The remote sent the `RESET_STREAM` flag (flag value 2).
    RemoteResetDesired,
    /// A frame could not be decoded.
    InvalidFrame,
    /// The remote sent a non-empty message after having sent the `FIN` flag.
    DataAfterFin,
    /// The outer `expected_incoming_bytes` is `None` (reading side closed) while the incoming
    /// buffer doesn't contain a complete frame.
    EofIncompleteFrame,
    /// The remote sent the `FIN_ACK` flag before any `FIN` flag was sent locally.
    FinAckWithoutFin,
}