#![allow(dead_code)]
use alloc::collections::{BTreeMap, VecDeque};
use alloc::vec::Vec;
/// Numeric stream identifier.
///
/// The two low bits carry the stream type: bit 0 selects the initiator
/// (clear = client, set = server) and bit 1 the directionality
/// (clear = bidirectional, set = unidirectional), matching the layout of
/// RFC 9000 §2.1.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct StreamId(pub u64);

impl StreamId {
    /// `true` when bit 0 is clear.
    pub fn is_client_initiated(self) -> bool {
        !self.is_server_initiated()
    }

    /// `true` when bit 0 is set.
    pub fn is_server_initiated(self) -> bool {
        (self.0 & 0b01) != 0
    }

    /// `true` when bit 1 is clear.
    pub fn is_bidi(self) -> bool {
        !self.is_uni()
    }

    /// `true` when bit 1 is set.
    pub fn is_uni(self) -> bool {
        (self.0 & 0b10) != 0
    }

    /// The raw 64-bit identifier.
    pub fn value(self) -> u64 {
        let StreamId(raw) = self;
        raw
    }
}
/// Directionality of a stream.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum StreamKind {
    /// Data may flow in both directions.
    Bidi,
    /// Data flows in one direction only.
    Uni,
}
/// State of the sending half of a stream.
///
/// Variant names follow the sending-part states of RFC 9000 §3.1.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum SendState {
    /// Freshly created; nothing enqueued or carved yet.
    Ready,
    /// Data has been enqueued or carved, FIN not yet carved.
    Send,
    /// The FIN-bearing chunk has been carved.
    DataSent,
    /// NOTE(review): not assigned by the sending-half methods in this file;
    /// presumably set once the peer acknowledges all data — confirm.
    DataRecvd,
    /// The application abandoned the stream via a reset.
    ResetSent,
    /// NOTE(review): not assigned by the sending-half methods in this file;
    /// presumably set once the reset is acknowledged — confirm.
    ResetRecvd,
}
/// State of the receiving half of a stream.
///
/// Variant names follow the receiving-part states of RFC 9000 §3.2.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum RecvState {
    /// Receiving data; final size not yet known.
    Recv,
    /// Final size known, but some bytes below it are still missing.
    SizeKnown,
    /// Every byte up to the final size has been received.
    DataRecvd,
    /// The application has read through the final size.
    DataRead,
    /// The peer reset the stream.
    ResetRecvd,
    /// The reset has been acknowledged locally (see `ack_reset`).
    ResetRead,
}
/// Sending half of a stream.
///
/// Application bytes are staged in `write_buf`, handed out in chunks by
/// [`SendStream::carve`], and every carved chunk is remembered in
/// `sent_chunks` so it can be put back for retransmission.
pub(crate) struct SendStream {
    /// Current state of the sending half.
    pub(crate) state: SendState,
    /// Bytes not yet carved; the front byte lives at stream offset `write_off`.
    pub(crate) write_buf: VecDeque<u8>,
    /// Stream offset of the front of `write_buf`; rewinds when chunks are requeued.
    pub(crate) write_off: u64,
    /// Highest offset ever carved; unlike `write_off` it never rewinds.
    pub(crate) sent_offset: u64,
    /// Not read or written by this impl. NOTE(review): presumably maintained by
    /// acknowledgement processing elsewhere — confirm.
    pub(crate) acked_offset: u64,
    /// Final size of the stream, fixed by `finish`.
    pub(crate) fin_offset: Option<u64>,
    /// Whether the FIN-bearing chunk has been carved (cleared again by requeues).
    pub(crate) fin_sent: bool,
    /// Peer-granted send limit for this stream, as an absolute offset.
    pub(crate) peer_max_data: u64,
    /// Application error code recorded by `enter_reset`.
    pub(crate) reset_code: Option<u64>,
    /// Not read or written by this impl. NOTE(review): presumably the offset at
    /// which the stream was last reported as blocked — confirm.
    pub(crate) blocked_at: Option<u64>,
    /// Set by `enter_reset`. NOTE(review): presumably cleared elsewhere once the
    /// reset frame has been emitted — confirm.
    pub(crate) reset_pending: bool,
    /// Every carved chunk as `(offset, bytes, fin)`, kept for retransmission.
    pub(crate) sent_chunks: VecDeque<(u64, Vec<u8>, bool)>,
}

impl SendStream {
    /// Creates a sending half in `Ready` state that may send up to
    /// `peer_max_data` bytes.
    pub(crate) fn new(peer_max_data: u64) -> Self {
        Self {
            state: SendState::Ready,
            write_buf: VecDeque::new(),
            write_off: 0,
            sent_offset: 0,
            acked_offset: 0,
            fin_offset: None,
            fin_sent: false,
            peer_max_data,
            reset_code: None,
            blocked_at: None,
            reset_pending: false,
            sent_chunks: VecDeque::new(),
        }
    }

    /// Returns `true` if `carve` has something to emit: buffered bytes or a FIN
    /// that has not been carved yet.
    ///
    /// Always `false` once the stream is reset, because no stream data (and no
    /// FIN) may follow a reset.
    pub(crate) fn has_outbound(&self) -> bool {
        if matches!(self.state, SendState::ResetSent | SendState::ResetRecvd) {
            return false;
        }
        !self.write_buf.is_empty() || (self.fin_offset.is_some() && !self.fin_sent)
    }

    /// Credit left between the highest carved offset and the peer's limit.
    pub(crate) fn available_credit(&self) -> u64 {
        self.peer_max_data.saturating_sub(self.sent_offset)
    }

    /// Buffers as much of `data` as flow control allows and returns how many
    /// bytes were accepted.
    ///
    /// Returns 0 when the stream is not in `Ready`/`Send`, after `finish` (the
    /// final size is already fixed), or when no credit remains.
    pub(crate) fn enqueue(&mut self, data: &[u8]) -> usize {
        if !matches!(self.state, SendState::Ready | SendState::Send)
            || self.fin_offset.is_some()
        {
            return 0;
        }
        // Highest offset already committed, whether carved or still buffered.
        // After a requeue the buffer re-covers carved bytes, so taking the max
        // (instead of `sent_offset + buffered`) avoids counting them twice.
        let buffered_end = self.write_off + self.write_buf.len() as u64;
        let committed = buffered_end.max(self.sent_offset);
        let room = self.peer_max_data.saturating_sub(committed);
        // Compare in u64 so a large `room` cannot truncate on 32-bit targets;
        // the result is bounded by `data.len()`, so the cast back is lossless.
        let take = room.min(data.len() as u64) as usize;
        if take == 0 {
            return 0;
        }
        self.write_buf.extend(data[..take].iter().copied());
        if self.state == SendState::Ready {
            self.state = SendState::Send;
        }
        take
    }

    /// Fixes the final size at the end of the currently buffered data, so the
    /// `carve` that drains the buffer carries FIN.
    ///
    /// Idempotent; ignored once the stream has been reset.
    pub(crate) fn finish(&mut self) {
        if self.fin_offset.is_some()
            || matches!(self.state, SendState::ResetSent | SendState::ResetRecvd)
        {
            return;
        }
        self.fin_offset = Some(self.write_off + self.write_buf.len() as u64);
    }

    /// Removes up to `cap` bytes from the front of the buffer and returns them
    /// as `(offset, bytes, fin)`, or `None` when nothing is outbound.
    ///
    /// `fin` is set when this chunk ends exactly at the final size. Every
    /// returned chunk is also recorded in `sent_chunks`.
    pub(crate) fn carve(&mut self, cap: usize) -> Option<(u64, Vec<u8>, bool)> {
        if !self.has_outbound() {
            return None;
        }
        let offset = self.write_off;
        let take = cap.min(self.write_buf.len());
        let bytes: Vec<u8> = self.write_buf.drain(..take).collect();
        self.write_off += take as u64;
        self.sent_offset = self.sent_offset.max(self.write_off);
        let fin = !self.fin_sent
            && self.write_buf.is_empty()
            && self.fin_offset == Some(self.write_off);
        if fin {
            self.fin_sent = true;
            self.state = SendState::DataSent;
        } else if self.state == SendState::Ready && !bytes.is_empty() {
            self.state = SendState::Send;
        }
        self.sent_chunks.push_back((offset, bytes.clone(), fin));
        Some((offset, bytes, fin))
    }

    /// Moves every recorded chunk back into the send buffer so it is carved
    /// again, rewinding `write_off` to the lowest chunk offset.
    ///
    /// The recorded chunks and the still-buffered bytes are merged by offset.
    /// Bytes at a given offset never change, so any overlap (for example a chunk
    /// an earlier `requeue` already restored) is emitted only once. If any
    /// recorded chunk carried FIN, FIN is re-armed.
    pub(crate) fn requeue_all_sent(&mut self) {
        if self.sent_chunks.is_empty() {
            return;
        }
        let mut segments: Vec<(u64, Vec<u8>, bool)> = self.sent_chunks.drain(..).collect();
        let any_fin = segments.iter().any(|&(_, _, fin)| fin);
        // Treat the not-yet-carved bytes as one more segment, so overlap with
        // the recorded chunks is resolved the same way.
        let unsent: Vec<u8> = self.write_buf.drain(..).collect();
        segments.push((self.write_off, unsent, false));
        segments.sort_by_key(|seg| seg.0);

        let start = segments[0].0;
        let mut rebuilt: VecDeque<u8> = VecDeque::new();
        let mut cur_off = start;
        for (off, bytes, _) in &segments {
            let end = off + bytes.len() as u64;
            if end <= cur_off {
                continue;
            }
            // NOTE(review): `*off > cur_off` would be a hole (e.g. a middle
            // chunk removed from `sent_chunks` elsewhere); the flat buffer
            // cannot represent it and the bytes are appended contiguously.
            let skip = cur_off.saturating_sub(*off) as usize;
            rebuilt.extend(bytes[skip..].iter().copied());
            cur_off = end;
        }
        self.write_buf = rebuilt;
        self.write_off = start;
        if any_fin {
            self.fin_sent = false;
        }
    }

    /// Returns `true` while any carved chunk is still recorded in `sent_chunks`.
    pub(crate) fn has_unacked(&self) -> bool {
        !self.sent_chunks.is_empty()
    }

    /// Puts the carved chunk `[offset, offset + bytes.len())` back in front of
    /// the send buffer and re-arms FIN if `was_fin`.
    ///
    /// A chunk that is still fully buffered adds nothing. Where the chunk
    /// overlaps the front of the buffer, the overlapping buffered bytes are
    /// dropped (they are identical). This means requeueing the same chunk twice
    /// never duplicates data.
    pub(crate) fn requeue(&mut self, offset: u64, bytes: &[u8], was_fin: bool) {
        let end = offset + bytes.len() as u64;
        let buffered_end = self.write_off + self.write_buf.len() as u64;
        if offset >= self.write_off && end <= buffered_end {
            // Every byte of the chunk is still queued; nothing to restore.
        } else {
            // How much of the buffer's front the chunk already covers. When the
            // chunk does not touch `write_off` there is a hole the flat buffer
            // cannot represent; the bytes are then simply prepended as before.
            // NOTE(review): callers are expected to requeue chunks contiguous
            // with `write_off`.
            let overlap = if offset <= self.write_off && end >= self.write_off {
                (end - self.write_off).min(self.write_buf.len() as u64) as usize
            } else {
                0
            };
            let mut merged: VecDeque<u8> =
                VecDeque::with_capacity(bytes.len() + self.write_buf.len() - overlap);
            merged.extend(bytes.iter().copied());
            merged.extend(self.write_buf.iter().skip(overlap).copied());
            self.write_buf = merged;
            self.write_off = offset;
        }
        if was_fin {
            self.fin_sent = false;
        }
    }

    /// Abandons transmission with application error `code`.
    ///
    /// Discards buffered bytes, records the code, flags a reset as pending and
    /// moves to `ResetSent`. This is a no-op in the terminal `DataRecvd` and
    /// `ResetRecvd` states, so it cannot move the stream backwards.
    pub(crate) fn enter_reset(&mut self, code: u64) {
        if matches!(self.state, SendState::DataRecvd | SendState::ResetRecvd) {
            return;
        }
        self.write_buf.clear();
        self.reset_code = Some(code);
        self.reset_pending = true;
        self.state = SendState::ResetSent;
    }
}
/// Maximum number of out-of-order fragments a [`RecvStream`] keeps in
/// `pending`; a frame that would add a fragment beyond this is rejected.
pub(crate) const MAX_PENDING_FRAGMENTS: usize = 128;

/// Receiving half of a stream: reassembles incoming, possibly out-of-order,
/// data and hands contiguous bytes to the application.
pub(crate) struct RecvStream {
    /// Current state of the receiving half.
    pub(crate) state: RecvState,
    /// Contiguous bytes received but not yet consumed by `read`.
    pub(crate) delivered: VecDeque<u8>,
    /// One past the last contiguously received offset.
    pub(crate) next_offset: u64,
    /// Number of bytes consumed by `read` so far.
    pub(crate) read_off: u64,
    /// Out-of-order fragments keyed by starting offset (all above `next_offset`).
    pub(crate) pending: BTreeMap<u64, Vec<u8>>,
    /// Final size, once learned from a FIN or a reset.
    pub(crate) fin_offset: Option<u64>,
    /// Receive limit as an absolute offset; data or a final size beyond it is rejected.
    pub(crate) max_data: u64,
    /// Initialised to `max_data` and not otherwise touched here.
    /// NOTE(review): presumably the limit last advertised to the peer — confirm.
    pub(crate) max_data_announced: u64,
    /// Error code carried by the peer's reset, if any.
    pub(crate) reset_code: Option<u64>,
    /// When set, incoming data is still validated but then discarded.
    pub(crate) stop_sending_sent: bool,
    /// Not touched by this impl. NOTE(review): presumably flags a pending
    /// limit update for the peer — confirm.
    pub(crate) max_data_pending: bool,
}

impl RecvStream {
    /// Creates a receiving half in `Recv` state accepting data up to `max_data`.
    pub(crate) fn new(max_data: u64) -> Self {
        Self {
            state: RecvState::Recv,
            delivered: VecDeque::new(),
            next_offset: 0,
            read_off: 0,
            pending: BTreeMap::new(),
            fin_offset: None,
            max_data,
            max_data_announced: max_data,
            reset_code: None,
            stop_sending_sent: false,
            max_data_pending: false,
        }
    }

    /// Returns `true` when `read` has bytes to return, or when the stream
    /// reached `DataRecvd` or `ResetRecvd` (an end-of-stream or reset to report).
    pub(crate) fn is_readable(&self) -> bool {
        !self.delivered.is_empty()
            || matches!(self.state, RecvState::DataRecvd | RecvState::ResetRecvd)
    }

    /// Advances `Recv -> SizeKnown -> DataRecvd` once the final size is known.
    ///
    /// Only `Recv`/`SizeKnown` are ever changed, so `DataRecvd`, `DataRead` and
    /// the reset states never move backwards when duplicate frames arrive.
    fn settle_final_size(&mut self) {
        if let Some(final_size) = self.fin_offset {
            if !matches!(self.state, RecvState::Recv | RecvState::SizeKnown) {
                return;
            }
            self.state = if self.next_offset == final_size && self.pending.is_empty() {
                RecvState::DataRecvd
            } else {
                RecvState::SizeKnown
            };
        }
    }

    /// Processes a data frame covering `[offset, offset + data.len())`.
    ///
    /// Returns the number of bytes newly appended to `delivered` (including any
    /// buffered fragments that became contiguous).
    ///
    /// # Errors
    /// Returns `Decode` when:
    /// - the frame overflows u64 or exceeds `max_data`;
    /// - it extends past, or contradicts, a known final size;
    /// - a first FIN declares a final size below data already received;
    /// - storing it would exceed [`MAX_PENDING_FRAGMENTS`].
    pub(crate) fn on_data(
        &mut self,
        mut offset: u64,
        mut data: &[u8],
        fin: bool,
    ) -> Result<u64, crate::tls::Error> {
        // Data for a reset stream is irrelevant; drop it silently.
        if matches!(self.state, RecvState::ResetRecvd | RecvState::ResetRead) {
            return Ok(0);
        }
        let end = offset
            .checked_add(data.len() as u64)
            .ok_or(crate::tls::Error::Decode)?;
        if end > self.max_data {
            return Err(crate::tls::Error::Decode);
        }
        match self.fin_offset {
            Some(final_size) => {
                if end > final_size || (fin && end != final_size) {
                    return Err(crate::tls::Error::Decode);
                }
            }
            None if fin => {
                // A final size below bytes already seen is invalid (RFC 9000
                // §4.5); `on_reset` applies the same rule.
                let highest_received = self
                    .pending
                    .iter()
                    .map(|(&p_off, frag)| p_off + frag.len() as u64)
                    .fold(self.next_offset, u64::max);
                if end < highest_received {
                    return Err(crate::tls::Error::Decode);
                }
                self.fin_offset = Some(end);
            }
            None => {}
        }
        if self.stop_sending_sent {
            return Ok(0);
        }
        // Trim the prefix we already hold; a fully duplicate frame is done.
        if offset < self.next_offset {
            let already_have = self.next_offset - offset;
            if already_have >= data.len() as u64 {
                self.settle_final_size();
                return Ok(0);
            }
            data = &data[already_have as usize..];
            offset = self.next_offset;
        }
        let mut newly_contig: u64 = 0;
        if offset == self.next_offset {
            self.delivered.extend(data.iter().copied());
            newly_contig += data.len() as u64;
            self.next_offset += data.len() as u64;
            // Pull in buffered fragments that now touch the contiguous prefix.
            loop {
                let p_off = match self.pending.keys().next() {
                    Some(&k) if k <= self.next_offset => k,
                    _ => break,
                };
                let frag = self.pending.remove(&p_off).expect("key was just observed");
                let p_end = p_off + frag.len() as u64;
                if p_end <= self.next_offset {
                    continue;
                }
                let tail = &frag[(self.next_offset - p_off) as usize..];
                self.delivered.extend(tail.iter().copied());
                newly_contig += tail.len() as u64;
                self.next_offset = p_end;
            }
        } else if !data.is_empty() {
            // Out of order (offset is untrimmed here, so `end` still applies).
            // Skip it when the closest fragment at or before it already covers it.
            let covered = self
                .pending
                .range(..=offset)
                .next_back()
                .map_or(false, |(&p_off, frag)| p_off + frag.len() as u64 >= end);
            if !covered {
                // A shorter fragment at the same offset is replaced in place and
                // does not count against the fragment cap.
                if !self.pending.contains_key(&offset)
                    && self.pending.len() >= MAX_PENDING_FRAGMENTS
                {
                    return Err(crate::tls::Error::Decode);
                }
                self.pending.insert(offset, data.to_vec());
            }
        }
        self.settle_final_size();
        Ok(newly_contig)
    }

    /// Copies delivered bytes into `into` and returns `(copied, fin_seen)`.
    ///
    /// `fin_seen` is `true` once reading has reached the known final size (in
    /// `SizeKnown`, `DataRecvd` or `DataRead`). The stream then moves to
    /// `DataRead`.
    pub(crate) fn read(&mut self, into: &mut [u8]) -> (usize, bool) {
        let count = into.len().min(self.delivered.len());
        for (slot, byte) in into.iter_mut().zip(self.delivered.drain(..count)) {
            *slot = byte;
        }
        self.read_off += count as u64;
        let fin_seen = self.fin_offset == Some(self.read_off)
            && matches!(
                self.state,
                RecvState::DataRecvd | RecvState::SizeKnown | RecvState::DataRead
            );
        if fin_seen && self.delivered.is_empty() {
            self.state = RecvState::DataRead;
        }
        (count, fin_seen)
    }

    /// Records a peer reset with error `code` and final size `final_size`.
    ///
    /// # Errors
    /// Returns `Decode` when:
    /// - `final_size` exceeds `max_data`;
    /// - it is below data already received;
    /// - it differs from an already-known final size.
    ///
    /// A valid reset is ignored in `ResetRecvd`/`ResetRead` and in the terminal
    /// `DataRead` state.
    pub(crate) fn on_reset(&mut self, code: u64, final_size: u64) -> Result<(), crate::tls::Error> {
        let below_received = final_size < self.next_offset
            || self
                .pending
                .iter()
                .any(|(&p_off, frag)| final_size < p_off + frag.len() as u64);
        let contradicts_fin = self.fin_offset.map_or(false, |fin| fin != final_size);
        if final_size > self.max_data || below_received || contradicts_fin {
            return Err(crate::tls::Error::Decode);
        }
        if matches!(
            self.state,
            RecvState::ResetRecvd | RecvState::ResetRead | RecvState::DataRead
        ) {
            return Ok(());
        }
        self.pending.clear();
        self.reset_code = Some(code);
        self.fin_offset = Some(final_size);
        self.state = RecvState::ResetRecvd;
        Ok(())
    }

    /// Moves `ResetRecvd` to `ResetRead`; no effect in any other state.
    pub(crate) fn ack_reset(&mut self) {
        if matches!(self.state, RecvState::ResetRecvd) {
            self.state = RecvState::ResetRead;
        }
    }
}
/// A stream slot: its identifier plus whichever halves exist for it.
///
/// `new_send` fills only `send`, `new_recv` fills only `recv`, and
/// `new_bidi` fills both.
pub(crate) struct Stream {
    /// Identifier of this stream.
    pub(crate) id: StreamId,
    /// Sending half, if this side may write to the stream.
    pub(crate) send: Option<SendStream>,
    /// Receiving half, if this side may read from the stream.
    pub(crate) recv: Option<RecvStream>,
}

impl Stream {
    /// Builds a send-only stream limited by the peer's `peer_max_data`.
    pub(crate) fn new_send(id: StreamId, peer_max_data: u64) -> Self {
        let sender = SendStream::new(peer_max_data);
        Stream {
            id,
            send: Some(sender),
            recv: None,
        }
    }

    /// Builds a receive-only stream accepting up to `max_data` bytes.
    pub(crate) fn new_recv(id: StreamId, max_data: u64) -> Self {
        let receiver = RecvStream::new(max_data);
        Stream {
            id,
            send: None,
            recv: Some(receiver),
        }
    }

    /// Builds a stream with both halves: the sending half limited by
    /// `peer_max_data`, the receiving half by `self_max_data`.
    pub(crate) fn new_bidi(id: StreamId, peer_max_data: u64, self_max_data: u64) -> Self {
        let sender = SendStream::new(peer_max_data);
        let receiver = RecvStream::new(self_max_data);
        Stream {
            id,
            send: Some(sender),
            recv: Some(receiver),
        }
    }
}
// Unit tests for the stream id helpers and the send/receive halves.
#[cfg(test)]
mod tests {
    use super::*;

    // Checks every combination of the two low id bits, plus one larger id.
    #[test]
    fn stream_id_helpers() {
        let a = StreamId(0);
        assert!(a.is_client_initiated());
        assert!(!a.is_server_initiated());
        assert!(a.is_bidi());
        assert!(!a.is_uni());
        let b = StreamId(1);
        assert!(!b.is_client_initiated());
        assert!(b.is_server_initiated());
        assert!(b.is_bidi());
        assert!(!b.is_uni());
        let c = StreamId(2);
        assert!(c.is_client_initiated());
        assert!(!c.is_server_initiated());
        assert!(!c.is_bidi());
        assert!(c.is_uni());
        let d = StreamId(3);
        assert!(!d.is_client_initiated());
        assert!(d.is_server_initiated());
        assert!(!d.is_bidi());
        assert!(d.is_uni());
        // Only the two low bits matter; higher bits are the sequence number.
        let e = StreamId(0x4000_0000); assert!(e.is_client_initiated());
        assert!(e.is_bidi());
    }

    // Ready -> Send on enqueue, FIN-only carve -> DataSent; reset clears the
    // buffer and records the code.
    #[test]
    fn send_state_transitions() {
        let mut s = SendStream::new(1024);
        assert_eq!(s.state, SendState::Ready);
        let n = s.enqueue(b"hello");
        assert_eq!(n, 5);
        assert_eq!(s.state, SendState::Send);
        let (off, bytes, fin) = s.carve(100).expect("carve");
        assert_eq!(off, 0);
        assert_eq!(bytes, b"hello");
        assert!(!fin);
        assert_eq!(s.state, SendState::Send);
        s.finish();
        // After `finish` with an empty buffer, the next carve is an empty FIN chunk.
        let (off, bytes, fin) = s.carve(100).expect("carve-fin");
        assert_eq!(off, 5);
        assert!(bytes.is_empty());
        assert!(fin);
        assert_eq!(s.state, SendState::DataSent);
        let mut s2 = SendStream::new(1024);
        let _ = s2.enqueue(b"abc");
        s2.enter_reset(7);
        assert_eq!(s2.state, SendState::ResetSent);
        assert!(s2.write_buf.is_empty());
        assert_eq!(s2.reset_code, Some(7));
    }

    // Recv -> DataRecvd on in-order FIN, -> DataRead once read through FIN;
    // reset path ResetRecvd -> ResetRead.
    #[test]
    fn recv_state_transitions() {
        let mut r = RecvStream::new(1024);
        assert_eq!(r.state, RecvState::Recv);
        let n = r.on_data(0, b"hello", false).unwrap();
        assert_eq!(n, 5);
        assert_eq!(r.state, RecvState::Recv);
        let mut buf = [0u8; 16];
        let (got, fin) = r.read(&mut buf);
        assert_eq!(got, 5);
        assert!(!fin);
        assert_eq!(&buf[..got], b"hello");
        let n = r.on_data(5, b"world", true).unwrap();
        assert_eq!(n, 5);
        assert_eq!(r.state, RecvState::DataRecvd);
        let (got, fin) = r.read(&mut buf);
        assert_eq!(got, 5);
        assert!(fin);
        assert_eq!(r.state, RecvState::DataRead);
        let mut r2 = RecvStream::new(1024);
        r2.on_reset(7, 0).unwrap();
        assert_eq!(r2.state, RecvState::ResetRecvd);
        r2.ack_reset();
        assert_eq!(r2.state, RecvState::ResetRead);
    }

    // Fragments arriving in reverse order are buffered and then released all
    // at once when the gap at offset 0 is filled.
    #[test]
    fn recv_out_of_order_reassembly() {
        let mut r = RecvStream::new(1024);
        let n = r.on_data(100, &[b'C'; 50], false).unwrap();
        assert_eq!(n, 0);
        assert!(r.delivered.is_empty());
        let n = r.on_data(50, &[b'B'; 50], false).unwrap();
        assert_eq!(n, 0);
        let n = r.on_data(0, &[b'A'; 50], false).unwrap();
        assert_eq!(n, 150);
        assert_eq!(r.next_offset, 150);
        assert!(r.pending.is_empty());
        let n = r.on_data(150, &[b'D'; 10], true).unwrap();
        assert_eq!(n, 10);
        assert_eq!(r.state, RecvState::DataRecvd);
    }

    // Re-delivered or fully contained data yields no new contiguous bytes.
    #[test]
    fn recv_duplicate_dropped() {
        let mut r = RecvStream::new(1024);
        let n1 = r.on_data(0, b"hello", false).unwrap();
        let n2 = r.on_data(0, b"hello", false).unwrap();
        assert_eq!(n1, 5);
        assert_eq!(n2, 0);
        assert_eq!(r.next_offset, 5);
        let n3 = r.on_data(1, b"ell", false).unwrap();
        assert_eq!(n3, 0);
    }

    // A FIN ahead of a gap only reaches SizeKnown; DataRecvd waits for the gap.
    #[test]
    fn fin_only_after_all_data() {
        let mut r = RecvStream::new(1024);
        let _ = r.on_data(50, &[b'B'; 50], true).unwrap();
        assert_eq!(r.fin_offset, Some(100));
        assert_eq!(r.state, RecvState::SizeKnown);
        let n = r.on_data(0, &[b'A'; 50], false).unwrap();
        assert_eq!(n, 100);
        assert_eq!(r.state, RecvState::DataRecvd);
    }

    // `enqueue` accepts at most the peer-granted credit.
    #[test]
    fn send_enqueue_respects_credit() {
        let mut s = SendStream::new(100);
        let n = s.enqueue(&[0u8; 200]);
        assert_eq!(n, 100);
        assert_eq!(s.write_buf.len(), 100);
        let n = s.enqueue(&[0u8; 50]);
        assert_eq!(n, 0);
    }

    // FIN rides on the carve that drains the buffer up to the final size.
    #[test]
    fn carve_advances_offsets_and_marks_fin() {
        let mut s = SendStream::new(1024);
        let _ = s.enqueue(b"hello world");
        s.finish();
        let (off, bytes, fin) = s.carve(5).unwrap();
        assert_eq!(off, 0);
        assert_eq!(bytes, b"hello");
        assert!(!fin);
        assert_eq!(s.sent_offset, 5);
        let (off, bytes, fin) = s.carve(100).unwrap();
        assert_eq!(off, 5);
        assert_eq!(bytes, b" world");
        assert!(fin);
        assert_eq!(s.state, SendState::DataSent);
    }

    // Requeueing the latest chunk puts it back ahead of newer buffered bytes.
    #[test]
    fn requeue_rewinds_write_off() {
        let mut s = SendStream::new(1024);
        let _ = s.enqueue(b"hello");
        let (off, bytes, _fin) = s.carve(5).unwrap();
        assert_eq!(off, 0);
        let _ = s.enqueue(b" world");
        s.requeue(off, &bytes, false);
        let (off2, bytes2, _fin) = s.carve(100).unwrap();
        assert_eq!(off2, 0);
        assert_eq!(bytes2, b"hello world");
    }

    // Data ending beyond `max_data` is rejected; exactly at the limit is fine.
    #[test]
    fn recv_flow_control_overshoot_errors() {
        let mut r = RecvStream::new(50);
        let err = r.on_data(0, &[0u8; 51], false);
        assert!(err.is_err());
        let ok = r.on_data(0, &[0u8; 50], false);
        assert!(ok.is_ok());
    }

    // A valid reset drops buffered fragments and records the error code.
    #[test]
    fn reset_clears_pending_and_state() {
        let mut r = RecvStream::new(1024);
        let _ = r.on_data(100, &[b'X'; 10], false).unwrap();
        r.on_reset(42, 200).unwrap();
        assert!(r.pending.is_empty());
        assert_eq!(r.reset_code, Some(42));
        assert_eq!(r.state, RecvState::ResetRecvd);
    }

    // Data beyond a known final size is an error.
    #[test]
    fn recv_data_past_fin_offset_errors() {
        let mut r = RecvStream::new(1024);
        let n = r.on_data(0, &[b'A'; 100], true).unwrap();
        assert_eq!(n, 100);
        assert_eq!(r.fin_offset, Some(100));
        let err = r.on_data(150, &[b'B'; 10], false);
        assert!(err.is_err(), "data past fin_offset must error");
        // Result intentionally ignored: only checks this call does not panic.
        let _ = r.on_data(99, &[b'A'; 1], true);
    }

    // A second FIN with a different final size is an error.
    #[test]
    fn recv_contradictory_fin_errors() {
        let mut r = RecvStream::new(1024);
        let _ = r.on_data(0, &[b'A'; 100], true).unwrap();
        let err = r.on_data(0, &[b'A'; 120], true);
        assert!(err.is_err(), "contradictory FIN must error");
    }

    // A reset's final size must cover every buffered fragment.
    #[test]
    fn reset_below_pending_end_errors() {
        let mut r = RecvStream::new(1024);
        let _ = r.on_data(100, &[b'C'; 50], false).unwrap();
        let err = r.on_reset(0, 80);
        assert!(err.is_err(), "reset below pending end must error");
        let ok = r.on_reset(0, 200);
        assert!(ok.is_ok());
    }

    // Disjoint 1-byte fragments fill `pending` up to the cap; one more errors.
    #[test]
    fn recv_pending_fragments_are_bounded() {
        let mut r = RecvStream::new(1u64 << 30);
        for i in 0..MAX_PENDING_FRAGMENTS {
            let off = 1 + (i as u64) * 2;
            r.on_data(off, &[0xABu8; 1], false)
                .expect("fragment within cap");
        }
        assert_eq!(r.pending.len(), MAX_PENDING_FRAGMENTS);
        let off = 1 + (MAX_PENDING_FRAGMENTS as u64) * 2;
        let err = r.on_data(off, &[0xABu8; 1], false);
        assert!(err.is_err(), "fragment beyond cap must be rejected");
        assert_eq!(r.pending.len(), MAX_PENDING_FRAGMENTS);
    }

    // A longer fragment at an existing offset replaces it without counting
    // against the cap.
    #[test]
    fn recv_pending_replacement_does_not_grow_map() {
        let mut r = RecvStream::new(1u64 << 30);
        for i in 0..MAX_PENDING_FRAGMENTS {
            let off = 1 + (i as u64) * 4;
            r.on_data(off, &[0xCDu8; 1], false).expect("fragment");
        }
        assert_eq!(r.pending.len(), MAX_PENDING_FRAGMENTS);
        r.on_data(1, &[0xCDu8; 2], false).expect("replacement");
        assert_eq!(r.pending.len(), MAX_PENDING_FRAGMENTS);
    }
}