#![allow(dead_code)]
use alloc::collections::BTreeMap;
use alloc::vec::Vec;
use crate::tls::Error;
/// Cap, in bytes, on out-of-order inbound data that `CryptoBuf::on_crypto`
/// will hold while waiting for a gap to fill (64 KiB).
pub(crate) const MAX_PENDING_CRYPTO_BYTES: usize = 1 << 16;
/// Cap on the number of distinct out-of-order inbound fragments held at once.
pub(crate) const MAX_PENDING_FRAGMENTS: usize = 0x20;
/// Buffer for one CRYPTO byte stream in both directions: it reassembles
/// inbound fragments into in-order bytes, and queues outbound bytes while
/// remembering what was handed out so ranges can be re-sent.
#[derive(Default)]
pub(crate) struct CryptoBuf {
    // ---- receive direction ----
    /// Offset of the first inbound byte not yet released to the caller.
    next_offset: u64,
    /// Out-of-order inbound fragments, keyed by their starting stream offset.
    pending: BTreeMap<u64, Vec<u8>>,

    // ---- send direction ----
    /// Stream offset of `outbound[0]`.
    outbound_offset: u64,
    /// Bytes queued for transmission that `carve` has not handed out yet.
    outbound: Vec<u8>,
    /// Offset and contents of the most recent chunk returned by `carve`.
    last_sent: Option<(u64, Vec<u8>)>,
    /// Chunks handed out by `carve`, keyed by stream offset; read back when
    /// rebuilding byte ranges for retransmission.
    /// NOTE(review): no code shown here ever removes entries — confirm whether
    /// pruning on acknowledgement happens elsewhere.
    sent_history: BTreeMap<u64, Vec<u8>>,
}
impl CryptoBuf {
pub(crate) fn new() -> Self {
Self::default()
}
pub(crate) fn next_offset(&self) -> u64 {
self.next_offset
}
pub(crate) fn is_pending_empty(&self) -> bool {
self.pending.is_empty()
}
pub(crate) fn on_crypto(&mut self, mut offset: u64, mut data: &[u8]) -> Result<Vec<u8>, Error> {
if offset < self.next_offset {
let skip = (self.next_offset - offset) as usize;
if skip >= data.len() {
return Ok(Vec::new());
}
data = &data[skip..];
offset = self.next_offset;
}
if data.is_empty() {
return Ok(Vec::new());
}
if offset == self.next_offset {
let mut released = data.to_vec();
self.next_offset += data.len() as u64;
while let Some((&p_off, _)) = self.pending.iter().next()
&& p_off <= self.next_offset
{
let frag = self.pending.remove(&p_off).expect("just peeked");
let end = p_off + frag.len() as u64;
if end <= self.next_offset {
continue;
}
let skip = (self.next_offset - p_off) as usize;
let new_bytes = &frag[skip..];
released.extend_from_slice(new_bytes);
self.next_offset = end;
}
return Ok(released);
}
let end = offset + data.len() as u64;
if let Some((&prev_off, prev_data)) = self.pending.range(..offset).next_back() {
let prev_end = prev_off + prev_data.len() as u64;
if prev_end >= end {
return Ok(Vec::new());
}
}
let existing_len = self.pending.get(&offset).map(|v| v.len()).unwrap_or(0);
if data.len() <= existing_len {
return Ok(Vec::new());
}
let new_total_bytes = self
.total_pending_bytes()
.saturating_add(data.len() - existing_len);
if new_total_bytes > MAX_PENDING_CRYPTO_BYTES {
return Err(Error::Decode);
}
if existing_len == 0 && self.pending.len() >= MAX_PENDING_FRAGMENTS {
return Err(Error::Decode);
}
self.pending.insert(offset, data.to_vec());
Ok(Vec::new())
}
fn total_pending_bytes(&self) -> usize {
self.pending.values().map(Vec::len).sum()
}
pub(crate) fn enqueue_outbound(&mut self, data: &[u8]) {
self.outbound.extend_from_slice(data);
}
pub(crate) fn outbound_pending(&self) -> bool {
!self.outbound.is_empty()
}
pub(crate) fn outbound_len(&self) -> usize {
self.outbound.len()
}
pub(crate) fn outbound_offset_for_test(&self) -> u64 {
self.outbound_offset
}
pub(crate) fn carve(&mut self, cap: usize) -> Option<(u64, Vec<u8>)> {
if self.outbound.is_empty() {
return None;
}
let take = core::cmp::min(cap, self.outbound.len());
let chunk = self.outbound.drain(..take).collect::<Vec<u8>>();
let offset = self.outbound_offset;
self.outbound_offset += chunk.len() as u64;
self.last_sent = Some((offset, chunk.clone()));
self.sent_history.insert(offset, chunk.clone());
Some((offset, chunk))
}
pub(crate) fn requeue_range(&mut self, offset: u64, length: u64) -> bool {
if length == 0 {
return false;
}
let mut bytes_to_requeue: Vec<u8> = Vec::new();
let mut cursor = offset;
let end = offset.saturating_add(length);
while cursor < end {
let entry = self.sent_history.range(..=cursor).next_back();
let (entry_off, entry_bytes) = match entry {
Some((k, v)) => (*k, v.clone()),
None => return false,
};
let entry_end = entry_off + entry_bytes.len() as u64;
if entry_end <= cursor {
return false;
}
let local_skip = (cursor - entry_off) as usize;
let local_take =
core::cmp::min((end - cursor) as usize, entry_bytes.len() - local_skip);
bytes_to_requeue.extend_from_slice(&entry_bytes[local_skip..local_skip + local_take]);
cursor += local_take as u64;
}
let mut new_buf = bytes_to_requeue;
new_buf.append(&mut self.outbound);
self.outbound = new_buf;
self.outbound_offset = offset;
true
}
pub(crate) fn schedule_last_chunk_retransmit(&mut self) -> bool {
if let Some((off, bytes)) = self.last_sent.as_ref() {
let mut new_buf = bytes.clone();
new_buf.append(&mut self.outbound);
self.outbound = new_buf;
self.outbound_offset = *off;
true
} else {
false
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a brand-new, empty buffer.
    fn fresh() -> CryptoBuf {
        CryptoBuf::new()
    }

    #[test]
    fn in_order_pass_through() {
        let mut buf = fresh();
        let first = buf.on_crypto(0, b"hello").expect("ok");
        assert_eq!(first, b"hello");
        assert_eq!(buf.next_offset(), 5);
        assert!(buf.is_pending_empty());
        let second = buf.on_crypto(5, b" world").expect("ok");
        assert_eq!(second, b" world");
        assert_eq!(buf.next_offset(), 11);
    }

    #[test]
    fn out_of_order_then_in_order_merges() {
        const TAIL: &[u8] = b"second-block";
        let mut buf = fresh();
        assert!(buf.on_crypto(100, TAIL).expect("ok").is_empty());
        assert_eq!(buf.next_offset(), 0);
        assert!(!buf.is_pending_empty());
        let filler = alloc::vec![b'a'; 100];
        let merged = buf.on_crypto(0, &filler).expect("ok");
        assert_eq!(merged.len(), filler.len() + TAIL.len());
        let (head, tail) = merged.split_at(100);
        assert_eq!(head, &filler[..]);
        assert_eq!(tail, TAIL);
        assert_eq!(buf.next_offset(), 112);
        assert!(buf.is_pending_empty());
    }

    #[test]
    fn duplicate_fragment_is_swallowed() {
        let mut buf = fresh();
        buf.on_crypto(0, b"hello world").expect("ok");
        for (off, bytes) in [(0u64, &b"hello world"[..]), (2, &b"llo"[..])] {
            assert!(buf.on_crypto(off, bytes).expect("ok").is_empty());
        }
        assert_eq!(buf.next_offset(), 11);
    }

    #[test]
    fn fragment_straddles_boundary() {
        let mut buf = fresh();
        buf.on_crypto(0, b"hello").expect("ok");
        assert_eq!(buf.on_crypto(3, b"lo world").expect("ok"), b" world");
        assert_eq!(buf.next_offset(), 11);
    }

    #[test]
    fn outbound_enqueue_and_carve() {
        let mut buf = fresh();
        assert!(!buf.outbound_pending());
        buf.enqueue_outbound(b"hello world");
        assert!(buf.outbound_pending());
        assert_eq!(buf.outbound_len(), 11);
        assert_eq!(buf.carve(5).expect("carve"), (0, b"hello".to_vec()));
        assert_eq!(buf.carve(100).expect("carve rest"), (5, b" world".to_vec()));
        assert!(buf.carve(10).is_none());
        assert!(!buf.outbound_pending());
    }

    #[test]
    fn schedule_retransmit_replays_last_chunk() {
        let mut buf = fresh();
        buf.enqueue_outbound(b"AAABBB");
        assert_eq!(buf.carve(3).expect("carve"), (0, b"AAA".to_vec()));
        assert_eq!(buf.carve(3).expect("carve"), (3, b"BBB".to_vec()));
        assert!(!buf.outbound_pending());
        assert!(buf.schedule_last_chunk_retransmit());
        assert_eq!(buf.carve(3).expect("retransmit"), (3, b"BBB".to_vec()));
    }

    #[test]
    fn schedule_retransmit_with_pending_after() {
        let mut buf = fresh();
        buf.enqueue_outbound(b"AAA");
        assert_eq!(buf.carve(3).expect("carve"), (0, b"AAA".to_vec()));
        buf.enqueue_outbound(b"CCC");
        let _ = buf.schedule_last_chunk_retransmit();
        assert_eq!(buf.carve(100).expect("carve all"), (0, b"AAACCC".to_vec()));
    }

    #[test]
    fn pending_smaller_fragment_does_not_overwrite_larger() {
        let mut buf = fresh();
        buf.on_crypto(10, b"longer-fragment").expect("ok");
        buf.on_crypto(10, b"long").expect("ok");
        let released = buf.on_crypto(0, &[b'X'; 10]).expect("ok");
        assert_eq!(&released[10..], b"longer-fragment");
    }

    #[test]
    fn crypto_buf_rejects_oversize_pending() {
        const CHUNK: usize = 8 * 1024;
        let mut buf = fresh();
        let base: u64 = 1 << 20;
        let block = alloc::vec![b'A'; CHUNK];
        let fits = MAX_PENDING_CRYPTO_BYTES / CHUNK;
        let offset_of = |i: usize| base + (i * CHUNK) as u64;
        for i in 0..fits {
            buf.on_crypto(offset_of(i), &block).expect("under cap");
        }
        let overflow = buf.on_crypto(offset_of(fits), b"X");
        assert!(matches!(overflow, Err(Error::Decode)));
        assert_eq!(buf.total_pending_bytes(), MAX_PENDING_CRYPTO_BYTES);
    }

    #[test]
    fn crypto_buf_rejects_too_many_fragments() {
        let mut buf = fresh();
        let base: u64 = 1 << 20;
        let spaced = |i: usize| base + 1024 * i as u64;
        for i in 0..MAX_PENDING_FRAGMENTS {
            buf.on_crypto(spaced(i), b"X").expect("under cap");
        }
        let rejected = buf.on_crypto(spaced(MAX_PENDING_FRAGMENTS), b"X");
        assert!(matches!(rejected, Err(Error::Decode)));
        assert_eq!(buf.pending.len(), MAX_PENDING_FRAGMENTS);
    }

    #[test]
    fn crypto_buf_capacity_relaxes_after_delivery() {
        let mut buf = fresh();
        let half = MAX_PENDING_CRYPTO_BYTES / 2;
        let at = |k: usize| (half * k) as u64;
        buf.on_crypto(at(1), &alloc::vec![b'A'; half]).expect("under cap");
        assert_eq!(buf.total_pending_bytes(), half);
        let delivered = buf
            .on_crypto(at(0), &alloc::vec![b'F'; half])
            .expect("delivery");
        assert_eq!(delivered.len(), 2 * half);
        assert_eq!(buf.next_offset(), at(2));
        assert_eq!(buf.total_pending_bytes(), 0);
        buf.on_crypto(at(2), &alloc::vec![b'B'; half]).expect("in-order");
        buf.on_crypto(at(3), &alloc::vec![b'C'; half]).expect("still in-order");
        assert_eq!(buf.next_offset(), at(4));
        assert_eq!(buf.total_pending_bytes(), 0);
    }

    #[test]
    fn crypto_buf_replace_at_offset_counts_delta_only() {
        const CHUNK: usize = 8 * 1024;
        let mut buf = fresh();
        let base: u64 = 1 << 20;
        let short = alloc::vec![b'S'; CHUNK - 128];
        let full = alloc::vec![b'F'; CHUNK];
        let slot = |i: u64| base + i * CHUNK as u64;
        for i in 0..7 {
            buf.on_crypto(slot(i), &short).expect("under cap");
        }
        buf.on_crypto(slot(7), &full).expect("under cap");
        assert!(buf.on_crypto(base, &full).is_ok());
    }
}