use super::{StreamError, StreamId};
use crate::bytes::Bytes;
use crate::types::outcome::Outcome;
use std::collections::BTreeMap;
/// A contiguous run of stream bytes tagged with its position in the stream.
#[derive(Debug, Clone)]
pub struct DataSegment {
/// Offset of the first byte of `data` within the stream.
pub offset: u64,
/// Payload bytes carried by this segment.
pub data: Bytes,
/// Whether this segment marks the end of the stream; when set, the
/// segment's end offset is treated as the stream's final size.
pub is_final: bool,
}
impl DataSegment {
    /// Builds a segment carrying `data` that starts at stream offset `offset`.
    ///
    /// `is_final` marks the segment as the one that ends the stream.
    pub fn new(offset: u64, data: Bytes, is_final: bool) -> Self {
        Self {
            offset,
            data,
            is_final,
        }
    }

    /// Returns the offset one past the last byte of this segment.
    ///
    /// The addition saturates at `u64::MAX`: the offset is supplied by the
    /// sender, and a plain `+` would panic in debug builds and wrap in release
    /// builds, producing an end offset smaller than the start offset and
    /// corrupting every range comparison built on top of it.
    pub fn end_offset(&self) -> u64 {
        self.offset.saturating_add(self.data.len() as u64)
    }

    /// Returns `true` when the half-open ranges `[offset, end_offset)` of the
    /// two segments share at least one position.
    pub fn overlaps_with(&self, other: &DataSegment) -> bool {
        self.offset < other.end_offset() && other.offset < self.end_offset()
    }

    /// Returns `true` when one segment ends exactly where the other begins.
    pub fn is_adjacent_to(&self, other: &DataSegment) -> bool {
        self.end_offset() == other.offset || other.end_offset() == self.offset
    }
}
/// Buffers stream segments that arrive out of order and releases their
/// payloads in offset order once the data becomes contiguous.
#[derive(Debug)]
pub struct ReassemblyBuffer {
/// Segments waiting for delivery, keyed by their starting offset.
segments: BTreeMap<u64, DataSegment>,
/// Offset of the next byte to hand out; everything below it has already
/// been delivered.
next_offset: u64,
/// End offset of the stream, recorded from a segment flagged as final.
final_size: Option<u64>,
/// Whether a segment flagged as final has been seen.
received_final: bool,
/// Upper bound on `buffered_data_size`; inserts that would exceed it fail.
max_buffered_data: u64,
/// Total payload bytes currently held in `segments`.
buffered_data_size: u64,
}
impl ReassemblyBuffer {
    /// Creates an empty buffer that will hold at most `max_buffered_data`
    /// bytes of not-yet-delivered payload.
    pub fn new(max_buffered_data: u64) -> Self {
        Self {
            segments: BTreeMap::new(),
            next_offset: 0,
            final_size: None,
            received_final: false,
            max_buffered_data,
            buffered_data_size: 0,
        }
    }

    /// Adds `segment` to the buffer and returns every payload chunk that has
    /// become deliverable, in stream order.
    ///
    /// Bytes that were already delivered or are already buffered are dropped,
    /// provided the buffered copy is byte-identical. Empty segments never
    /// occupy a slot in the buffer; an empty final segment only records the
    /// final size. No state is modified when an error is returned.
    ///
    /// # Errors
    ///
    /// * `StreamError::ConnectionError` when the segment's end offset does not
    ///   fit in a `u64`, or when buffering the new bytes would exceed the
    ///   configured limit.
    /// * `StreamError::FinalSizeMismatch` when the segment contradicts the
    ///   stream's final size: a second final segment ending elsewhere, data
    ///   past the known final size, or a final size below data already
    ///   received.
    /// * `StreamError::InvalidState` when the segment overlaps buffered data
    ///   with different content.
    pub fn insert_segment(&mut self, mut segment: DataSegment) -> Outcome<Vec<Bytes>, StreamError> {
        let is_final = segment.is_final;

        // The offset comes from the sender, so compute the end with overflow
        // detection instead of trusting `offset + len` to fit.
        let Some(segment_end) = segment.offset.checked_add(segment.data.len() as u64) else {
            return Outcome::err(StreamError::ConnectionError {
                reason: "Segment end offset overflows u64".to_string(),
            });
        };

        // Validate against the final size before any trimming or early
        // return: trimming never moves the end offset, and a fully duplicate
        // segment can still carry a (possibly contradictory) final flag.
        if let Err(err) = self.check_final_size(segment_end, is_final) {
            return Outcome::err(err);
        }

        // Nothing new to buffer: either every byte was delivered already or
        // the segment is empty. An empty segment must never be stored, since
        // it could replace a real segment sharing its key or sit below
        // `next_offset` forever and stall delivery.
        if segment_end <= self.next_offset || segment.data.is_empty() {
            if is_final {
                self.record_final_size(segment_end);
            }
            return Outcome::ok(Vec::new());
        }

        // Drop the prefix that has already been handed out.
        if segment.offset < self.next_offset {
            let delivered_len = (self.next_offset - segment.offset) as usize;
            segment.data = segment.data.slice(delivered_len..);
            segment.offset = self.next_offset;
        }

        let fresh_segments = match self.uncovered_segments(segment) {
            Ok(segments) => segments,
            Err(err) => return Outcome::err(err),
        };

        let fresh_bytes: u64 = fresh_segments
            .iter()
            .map(|fresh| fresh.data.len() as u64)
            .sum();
        if self.buffered_data_size.saturating_add(fresh_bytes) > self.max_buffered_data {
            return Outcome::err(StreamError::ConnectionError {
                reason: "Reassembly buffer limit exceeded".to_string(),
            });
        }

        // Every check passed; only now is state committed.
        if is_final {
            self.record_final_size(segment_end);
        }
        for fresh in fresh_segments {
            self.buffered_data_size += fresh.data.len() as u64;
            self.segments.insert(fresh.offset, fresh);
        }

        Outcome::ok(self.extract_deliverable_data())
    }

    /// Checks a segment ending at `segment_end` against what is known about
    /// the stream's final size, without modifying the buffer.
    fn check_final_size(&self, segment_end: u64, is_final: bool) -> Result<(), StreamError> {
        // NOTE(review): the buffer does not know which stream it serves, so
        // the error carries a placeholder id of 0 — confirm callers replace it.
        let mismatch = |expected: u64| StreamError::FinalSizeMismatch {
            stream_id: StreamId::new(0),
            expected,
            actual: segment_end,
        };

        match self.final_size {
            // Final size known: nothing may extend past it, and a repeated
            // final segment must end exactly on it.
            Some(final_size) => {
                if segment_end > final_size || (is_final && segment_end != final_size) {
                    return Err(mismatch(final_size));
                }
            }
            // First final segment: it must not cut off data already received.
            // Buffered segments are disjoint and sorted, so the last one marks
            // the highest received offset; `expected` reports that offset.
            None if is_final => {
                let highest_received = self
                    .segments
                    .last_key_value()
                    .map_or(self.next_offset, |(_, last)| last.end_offset());
                if segment_end < highest_received {
                    return Err(mismatch(highest_received));
                }
            }
            None => {}
        }
        Ok(())
    }

    /// Records that the stream ends at `final_size`.
    fn record_final_size(&mut self, final_size: u64) {
        self.final_size = Some(final_size);
        self.received_final = true;
    }

    /// Splits `segment` into the pieces not yet covered by buffered data.
    ///
    /// Wherever `segment` overlaps a buffered segment the overlapping bytes
    /// must match; otherwise `StreamError::InvalidState` is returned. Empty
    /// pieces are never produced.
    fn uncovered_segments(&self, segment: DataSegment) -> Result<Vec<DataSegment>, StreamError> {
        let segment_len = segment.data.len();
        let segment_end_offset = segment.end_offset();

        // Buffered segments are disjoint and keyed by start offset, so the only
        // candidates for overlap are the last segment starting at or before
        // `segment.offset` and the segments starting inside the new one.
        let scan_from = self
            .segments
            .range(..=segment.offset)
            .next_back()
            .map_or(segment.offset, |(&start, _)| start);

        // Byte ranges of `segment.data` still uncovered, relative to its start.
        let mut ranges = vec![(0usize, segment_len)];
        for (_, existing) in self.segments.range(scan_from..segment_end_offset) {
            if !segment.overlaps_with(existing) {
                continue;
            }
            let overlap_start = segment.offset.max(existing.offset);
            let overlap_end = segment_end_offset.min(existing.end_offset());
            // The overlap expressed relative to each segment's own payload.
            let covered_start = (overlap_start - segment.offset) as usize;
            let covered_end = (overlap_end - segment.offset) as usize;
            let existing_start = (overlap_start - existing.offset) as usize;
            let existing_end = (overlap_end - existing.offset) as usize;
            if segment.data.slice(covered_start..covered_end)
                != existing.data.slice(existing_start..existing_end)
            {
                return Err(StreamError::InvalidState {
                    stream_id: StreamId::new(0),
                    state: format!(
                        "Conflicting overlapping segment at offset {}",
                        segment.offset
                    ),
                });
            }

            // Carve the covered span out of every range it intersects.
            let mut remaining = Vec::with_capacity(ranges.len() + 1);
            for &(start, end) in &ranges {
                if covered_end <= start || covered_start >= end {
                    remaining.push((start, end));
                    continue;
                }
                if start < covered_start {
                    remaining.push((start, covered_start));
                }
                if covered_end < end {
                    remaining.push((covered_end, end));
                }
            }
            ranges = remaining;
            if ranges.is_empty() {
                break;
            }
        }

        Ok(ranges
            .into_iter()
            .filter(|&(start, end)| start < end)
            .map(|(start, end)| DataSegment {
                offset: segment.offset + start as u64,
                data: segment.data.slice(start..end),
                // Only the piece containing the last byte keeps the final flag.
                is_final: segment.is_final && end == segment_len,
            })
            .collect())
    }

    /// Removes and returns the payloads of all buffered segments that are
    /// contiguous with `next_offset`, advancing `next_offset` past them.
    fn extract_deliverable_data(&mut self) -> Vec<Bytes> {
        let mut deliverable = Vec::new();
        while let Some(entry) = self.segments.first_entry() {
            if *entry.key() != self.next_offset {
                break;
            }
            let segment = entry.remove();
            self.next_offset = segment.end_offset();
            self.buffered_data_size -= segment.data.len() as u64;
            deliverable.push(segment.data);
        }
        deliverable
    }

    /// Returns `true` once a final segment has been seen, nothing is left in
    /// the buffer, and delivery has reached the final size.
    pub fn is_complete(&self) -> bool {
        self.received_final
            && self.segments.is_empty()
            && self.final_size.is_some_and(|size| self.next_offset >= size)
    }

    /// Returns the offset of the next byte that will be delivered.
    pub fn next_expected_offset(&self) -> u64 {
        self.next_offset
    }

    /// Returns the stream's final size, if a final segment has been recorded.
    pub fn final_size(&self) -> Option<u64> {
        self.final_size
    }

    /// Returns whether a final segment has been recorded.
    pub fn received_final_segment(&self) -> bool {
        self.received_final
    }

    /// Returns the number of segments currently waiting in the buffer.
    pub fn buffered_segments(&self) -> usize {
        self.segments.len()
    }

    /// Returns the number of payload bytes currently waiting in the buffer.
    pub fn buffered_data_size(&self) -> u64 {
        self.buffered_data_size
    }

    /// Returns a snapshot of the buffer's counters and completion state.
    pub fn statistics(&self) -> ReassemblyStats {
        ReassemblyStats {
            next_offset: self.next_offset,
            final_size: self.final_size,
            buffered_segments: self.segments.len(),
            buffered_data_size: self.buffered_data_size,
            max_buffered_data: self.max_buffered_data,
            gaps: self.count_gaps(),
            is_complete: self.is_complete(),
        }
    }

    /// Counts the holes in front of and between buffered segments, walking
    /// them in offset order starting from `next_offset`.
    fn count_gaps(&self) -> usize {
        let mut gaps = 0;
        let mut expected_offset = self.next_offset;
        for (&offset, segment) in &self.segments {
            if offset > expected_offset {
                gaps += 1;
            }
            expected_offset = segment.end_offset();
        }
        gaps
    }

    /// Discards all buffered data and returns the buffer to its initial
    /// state; the configured size limit is kept.
    pub fn reset(&mut self) {
        self.segments.clear();
        self.next_offset = 0;
        self.final_size = None;
        self.received_final = false;
        self.buffered_data_size = 0;
    }

    /// Returns `true` when at least one hole precedes a buffered segment.
    pub fn has_gaps(&self) -> bool {
        // Stops at the first hole instead of counting all of them.
        self.earliest_gap_offset().is_some()
    }

    /// Returns the offset at which the first hole starts, or `None` when the
    /// buffer is empty or its segments are contiguous with `next_offset`.
    pub fn earliest_gap_offset(&self) -> Option<u64> {
        let mut expected_offset = self.next_offset;
        for (&offset, segment) in &self.segments {
            if offset > expected_offset {
                return Some(expected_offset);
            }
            expected_offset = segment.end_offset();
        }
        None
    }
}
/// Point-in-time snapshot of a `ReassemblyBuffer`, as produced by
/// `ReassemblyBuffer::statistics`.
#[derive(Debug, Clone)]
pub struct ReassemblyStats {
/// Offset of the next byte the buffer will deliver.
pub next_offset: u64,
/// Final size of the stream, if one has been recorded.
pub final_size: Option<u64>,
/// Number of segments waiting in the buffer.
pub buffered_segments: usize,
/// Payload bytes waiting in the buffer.
pub buffered_data_size: u64,
/// Configured limit on buffered payload bytes.
pub max_buffered_data: u64,
/// Number of holes in front of or between the buffered segments.
pub gaps: usize,
/// Whether the buffer reported the stream as complete.
pub is_complete: bool,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bytes::Bytes;

    /// In-order segments are released one chunk per insert, and the final
    /// segment completes the stream.
    #[test]
    fn test_reassembly_in_order() {
        let mut reassembly = ReassemblyBuffer::new(10000);
        let head = DataSegment::new(0, Bytes::from("hello"), false);
        let tail = DataSegment::new(5, Bytes::from("world"), true);

        let after_head = reassembly.insert_segment(head).unwrap();
        assert_eq!(after_head.len(), 1);
        assert_eq!(&after_head[0][..], b"hello");

        let after_tail = reassembly.insert_segment(tail).unwrap();
        assert_eq!(after_tail.len(), 1);
        assert_eq!(&after_tail[0][..], b"world");

        assert!(reassembly.is_complete());
        assert_eq!(reassembly.final_size(), Some(10));
    }

    /// A segment arriving ahead of its predecessor is held back, then both
    /// are released together in stream order.
    #[test]
    fn test_reassembly_out_of_order() {
        let mut reassembly = ReassemblyBuffer::new(10000);
        let tail = DataSegment::new(5, Bytes::from("world"), true);
        let head = DataSegment::new(0, Bytes::from("hello"), false);

        let after_tail = reassembly.insert_segment(tail).unwrap();
        assert_eq!(after_tail.len(), 0);

        let after_head = reassembly.insert_segment(head).unwrap();
        assert_eq!(after_head.len(), 2);
        assert_eq!(&after_head[0][..], b"hello");
        assert_eq!(&after_head[1][..], b"world");

        assert!(reassembly.is_complete());
    }

    /// Two final segments that disagree about where the stream ends are
    /// rejected.
    #[test]
    fn test_final_size_mismatch() {
        let mut reassembly = ReassemblyBuffer::new(10000);
        let first_final = DataSegment::new(0, Bytes::from("hello"), true);
        let second_final = DataSegment::new(5, Bytes::from("world"), true);

        reassembly.insert_segment(first_final).unwrap();

        let outcome = reassembly.insert_segment(second_final);
        assert!(outcome.is_err());
    }

    /// An overlap with identical bytes is absorbed silently; an overlap with
    /// different bytes is an error.
    #[test]
    fn test_overlapping_segments() {
        let mut reassembly = ReassemblyBuffer::new(10000);
        let buffered = DataSegment::new(5, Bytes::from("world"), false);
        let matching = DataSegment::new(7, Bytes::from("rld"), false);
        let conflicting = DataSegment::new(6, Bytes::from("XX"), false);

        reassembly.insert_segment(buffered).unwrap();

        let released = reassembly.insert_segment(matching).unwrap();
        assert!(released.is_empty());

        let outcome = reassembly.insert_segment(conflicting);
        assert!(outcome.is_err());
    }

    /// A segment larger than the configured limit is refused.
    #[test]
    fn test_buffer_limit() {
        let mut reassembly = ReassemblyBuffer::new(10);
        let oversized = DataSegment::new(0, Bytes::from("this is too large"), false);

        let outcome = reassembly.insert_segment(oversized);
        assert!(outcome.is_err());
    }
}