use crate::constants::{FRAME_HEADER_SIZE, FRAME_LENGTH_FIELD_SIZE};
use crate::frame::{DecodedFrame, FrameCodec, FrameDecodeError, FrameKind};
use std::collections::{BTreeMap, HashMap, VecDeque};
/// Incremental decoder for a byte stream that interleaves frames from several streams.
///
/// Bytes are fed in through [`FrameMuxStreamDecoder::read_bytes`]; any trailing bytes that do
/// not yet form a complete frame are retained until a later call supplies the rest.
pub struct FrameMuxStreamDecoder {
    /// Received bytes that have not yet been consumed as complete frames.
    buffer: Vec<u8>,
    /// Per-stream reassembly state, keyed by stream id.
    streams: HashMap<u32, StreamReassembly>,
}
/// Per-stream state used to release frames in `seq_id` order.
struct StreamReassembly {
    /// Sequence id of the next frame that may be released to the caller.
    next_expected: u32,
    /// Frames waiting to be released, keyed by sequence id.
    buffer: BTreeMap<u32, DecodedFrame>,
    /// Set when a `Cancel` frame is processed for this stream.
    is_canceled: bool,
    /// Set when an `End` frame is received for this stream.
    is_ended: bool,
}
/// Iterator over the results produced by a single [`FrameMuxStreamDecoder::read_bytes`] call.
pub struct FrameDecoderIterator {
    /// Decoded frames and decode errors, in the order they were produced.
    queue: VecDeque<Result<DecodedFrame, FrameDecodeError>>,
}
/// Yields queued results front-to-back; the number of remaining items is always known exactly.
impl Iterator for FrameDecoderIterator {
    type Item = Result<DecodedFrame, FrameDecodeError>;

    fn next(&mut self) -> Option<Self::Item> {
        self.queue.pop_front()
    }

    // The backing queue's length is exact, which lets `collect` preallocate and enables `len()`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.queue.len();
        (remaining, Some(remaining))
    }
}

impl ExactSizeIterator for FrameDecoderIterator {}

// `pop_front` on an empty queue keeps returning `None`, so the iterator is fused.
impl std::iter::FusedIterator for FrameDecoderIterator {}
impl Default for FrameMuxStreamDecoder {
fn default() -> Self {
Self::new()
}
}
impl FrameMuxStreamDecoder {
pub fn new() -> Self {
Self {
buffer: Vec::new(),
streams: HashMap::new(),
}
}
pub fn read_bytes(&mut self, data: &[u8]) -> FrameDecoderIterator {
self.buffer.extend_from_slice(data);
let mut queue = VecDeque::new();
while self.buffer.len() >= FRAME_LENGTH_FIELD_SIZE {
let len = match self
.buffer
.get(..FRAME_LENGTH_FIELD_SIZE)
.and_then(|bytes| bytes.try_into().ok())
.map(u32::from_le_bytes)
{
Some(n) => n as usize,
None => {
queue.push_back(Err(FrameDecodeError::IncompleteHeader));
break;
}
};
let total = FRAME_HEADER_SIZE + len;
if self.buffer.len() < total {
break;
}
match FrameCodec::decode(&self.buffer[..total]) {
Ok(mut frame) => {
let stream_id = frame.inner.stream_id;
let frame_kind = frame.inner.kind;
self.buffer.drain(..total);
if let Some(stream) = self.streams.get(&stream_id)
&& stream.is_canceled
{
frame.decode_error = Some(FrameDecodeError::ReadAfterCancel);
queue.push_back(Ok(frame));
continue;
}
if frame_kind == FrameKind::Cancel {
if let Some(stream) = self.streams.get_mut(&stream_id) {
stream.is_canceled = true;
}
frame.decode_error = Some(FrameDecodeError::ReadAfterCancel);
queue.push_back(Ok(frame));
self.streams.remove(&stream_id);
continue;
}
let stream =
self.streams
.entry(stream_id)
.or_insert_with(|| StreamReassembly {
next_expected: 0,
buffer: BTreeMap::new(),
is_canceled: false,
is_ended: false,
});
if frame_kind == FrameKind::End {
stream.is_ended = true;
}
stream.buffer.insert(frame.inner.seq_id, frame);
while let Some(buffered_frame) = stream.buffer.remove(&stream.next_expected) {
stream.next_expected += 1;
queue.push_back(Ok(buffered_frame));
}
if stream.is_ended && stream.buffer.is_empty() {
self.streams.remove(&stream_id);
}
}
Err(e) => {
self.buffer.drain(..total);
queue.push_back(Err(e));
continue;
}
}
}
FrameDecoderIterator { queue }
}
}