use std::{
collections::{
HashMap,
hash_map::{Entry, OccupiedEntry},
},
num::NonZeroUsize,
time::{Duration, Instant},
};
use bincode::error::DecodeError;
use super::{FragmentHeader, FragmentSeries, FragmentStatus, MessageId, ReassemblyError};
use crate::message::Message;
/// A message whose fragments are still arriving and is buffered inside a [`Reassembler`].
#[derive(Debug)]
struct PartialMessage {
/// Fragment bookkeeping for this message; its `accept` method classifies each incoming
/// header as incomplete, duplicate, or complete.
series: FragmentSeries,
/// Payload bytes of the non-duplicate fragments accepted so far, appended in push order.
buffer: Vec<u8>,
/// The `now` value supplied with the first fragment; expiry in `purge_expired_at` is
/// measured from this instant.
started_at: Instant,
}
/// The inputs describing the fragment that opens a new message id, bundled so they can be
/// handed to `Reassembler::push_first_fragment` as a single argument.
struct FirstFragment<'a> {
    /// Current time for this push; stored as the partial message's start time when the
    /// message is not already complete after this fragment.
    now: Instant,
    /// Header of the opening fragment.
    header: FragmentHeader,
    /// Payload bytes carried by the opening fragment.
    payload: &'a [u8],
}
impl PartialMessage {
    /// Starts a partial message whose buffer is seeded with a copy of `payload`.
    fn new(series: FragmentSeries, payload: &[u8], started_at: Instant) -> Self {
        let buffer = Vec::from(payload);
        Self {
            started_at,
            series,
            buffer,
        }
    }

    /// Appends the bytes of a subsequent fragment to the end of the buffer.
    fn push(&mut self, payload: &[u8]) {
        self.buffer.extend(payload);
    }

    /// Number of payload bytes accumulated so far.
    fn len(&self) -> usize {
        self.buffer.len()
    }

    /// Time the first fragment was pushed.
    fn started_at(&self) -> Instant {
        self.started_at
    }

    /// Consumes the partial message, yielding the accumulated bytes.
    fn into_buffer(self) -> Vec<u8> {
        let Self { buffer, .. } = self;
        buffer
    }
}
/// A message payload paired with its id; [`Reassembler`] produces one when a message's
/// completing fragment is accepted.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ReassembledMessage {
/// Id of the message this payload belongs to.
message_id: MessageId,
/// The message bytes (for reassembled messages, the fragment payloads in push order).
payload: Vec<u8>,
}
impl ReassembledMessage {
    /// Pairs a payload with the id of the message it belongs to.
    #[must_use]
    pub fn new(message_id: MessageId, payload: Vec<u8>) -> Self {
        Self { message_id, payload }
    }

    /// Id of the message this payload belongs to.
    #[must_use]
    pub const fn message_id(&self) -> MessageId {
        self.message_id
    }

    /// Borrows the payload bytes.
    #[must_use]
    pub fn payload(&self) -> &[u8] {
        &self.payload
    }

    /// Consumes the message and hands back the owned payload buffer.
    #[must_use]
    pub fn into_payload(self) -> Vec<u8> {
        let Self { payload, .. } = self;
        payload
    }

    /// Decodes the payload as a value of type `M` using `M::from_bytes`.
    ///
    /// Only the first element of the tuple returned by `from_bytes` is kept; the second
    /// element (presumably the number of bytes consumed — confirm against `Message`) is
    /// discarded, so leftover bytes are not reported here.
    ///
    /// # Errors
    ///
    /// Propagates the [`DecodeError`] returned by `M::from_bytes` when the payload cannot be
    /// decoded.
    pub fn decode<M: Message>(&self) -> Result<M, DecodeError> {
        let bytes = self.payload.as_slice();
        let (decoded, _) = M::from_bytes(bytes)?;
        Ok(decoded)
    }
}
/// Collects fragments and reassembles them into complete messages, buffering partial
/// messages keyed by [`MessageId`] until they complete, fail, or time out.
#[derive(Debug)]
pub struct Reassembler {
/// Upper bound on the total payload bytes a single message may accumulate.
max_message_size: NonZeroUsize,
/// How long a partial message may stay buffered, measured from the time supplied with its
/// first fragment.
timeout: Duration,
/// Partial messages still waiting for further fragments.
buffers: HashMap<MessageId, PartialMessage>,
}
impl Reassembler {
    /// Creates a reassembler with nothing buffered.
    ///
    /// * `max_message_size` — the largest total payload, in bytes, one message may reach.
    /// * `timeout` — how long a partial message may stay buffered, counted from the time
    ///   supplied with its first fragment.
    #[must_use]
    pub fn new(max_message_size: NonZeroUsize, timeout: Duration) -> Self {
        Self {
            buffers: HashMap::new(),
            timeout,
            max_message_size,
        }
    }

    /// Feeds one fragment into the reassembler, using [`Instant::now`] as the current time.
    ///
    /// This is [`Reassembler::push_at`] with `now = Instant::now()`.
    ///
    /// # Errors
    ///
    /// Same as [`Reassembler::push_at`].
    ///
    /// # Panics
    ///
    /// Same as [`Reassembler::push_at`].
    pub fn push(
        &mut self,
        header: FragmentHeader,
        payload: impl AsRef<[u8]>,
    ) -> Result<Option<ReassembledMessage>, ReassemblyError> {
        let now = Instant::now();
        self.push_at(header, payload, now)
    }

    /// Feeds one fragment into the reassembler, treating `now` as the current time.
    ///
    /// Expired partial messages are purged first (see [`Reassembler::purge_expired_at`]).
    /// So a fragment for a message that has just timed out is handled as the first fragment of
    /// a new series.
    ///
    /// Returns `Ok(Some(_))` when this fragment completes its message. Returns `Ok(None)` when
    /// more fragments are needed or an already-buffered message reported the fragment as a
    /// duplicate.
    ///
    /// # Errors
    ///
    /// * [`ReassemblyError::MessageTooLarge`] when the accumulated payload would exceed the
    ///   configured maximum.
    /// * The error converted from the fragment series when it rejects the header.
    ///
    /// In both cases any partial message buffered under that id is dropped.
    ///
    /// # Panics
    ///
    /// Panics if a freshly created `FragmentSeries` reports its first fragment as a duplicate.
    /// The implementation treats that as impossible.
    pub fn push_at(
        &mut self,
        header: FragmentHeader,
        payload: impl AsRef<[u8]>,
        now: Instant,
    ) -> Result<Option<ReassembledMessage>, ReassemblyError> {
        self.purge_expired_at(now);
        let bytes = payload.as_ref();
        let limit = self.max_message_size;
        match self.buffers.entry(header.message_id()) {
            Entry::Occupied(slot) => Self::push_existing_fragment(limit, slot, header, bytes),
            Entry::Vacant(slot) => {
                let first = FirstFragment {
                    header,
                    payload: bytes,
                    now,
                };
                Self::push_first_fragment(limit, slot, &first)
            }
        }
    }

    /// Drops every partial message that has timed out as of [`Instant::now`].
    ///
    /// Returns the ids of the evicted messages.
    pub fn purge_expired(&mut self) -> Vec<MessageId> {
        let now = Instant::now();
        self.purge_expired_at(now)
    }

    /// Drops every partial message whose age at `now` has reached the configured timeout.
    ///
    /// Age is measured from the message's first fragment, saturating to zero when `now` is
    /// earlier. Returns the ids of the evicted messages in the map's (unspecified) iteration
    /// order.
    pub fn purge_expired_at(&mut self, now: Instant) -> Vec<MessageId> {
        let timeout = self.timeout;
        let mut evicted_ids = Vec::new();
        self.buffers.retain(|&message_id, partial| {
            let age = now.saturating_duration_since(partial.started_at());
            let keep = age < timeout;
            if !keep {
                evicted_ids.push(message_id);
            }
            keep
        });
        evicted_ids
    }

    /// Number of partial messages currently buffered.
    #[must_use]
    pub fn buffered_len(&self) -> usize {
        self.buffers.len()
    }

    /// Succeeds when `attempted` bytes fit within `limit`; otherwise reports
    /// [`ReassemblyError::MessageTooLarge`].
    fn assert_within_limit(
        limit: NonZeroUsize,
        message_id: MessageId,
        attempted: usize,
    ) -> Result<(), ReassemblyError> {
        if attempted <= limit.get() {
            Ok(())
        } else {
            Err(ReassemblyError::MessageTooLarge {
                message_id,
                attempted,
                limit,
            })
        }
    }

    /// Appends `payload` to the buffered message behind `occupied` after checking the size
    /// limit. When `completes` is set, it takes the message out of the map and returns it.
    ///
    /// Exceeding the limit, or overflowing `usize` (reported as `attempted: usize::MAX`),
    /// removes the entry and returns the error.
    fn append_and_maybe_complete(
        limit: NonZeroUsize,
        mut occupied: OccupiedEntry<'_, MessageId, PartialMessage>,
        payload: &[u8],
        completes: bool,
    ) -> Result<Option<ReassembledMessage>, ReassemblyError> {
        let message_id = *occupied.key();
        let size_check = match occupied.get().len().checked_add(payload.len()) {
            Some(attempted) => Self::assert_within_limit(limit, message_id, attempted),
            None => Err(ReassemblyError::MessageTooLarge {
                message_id,
                attempted: usize::MAX,
                limit,
            }),
        };
        if let Err(err) = size_check {
            // An oversized fragment abandons the whole message.
            occupied.remove();
            return Err(err);
        }
        occupied.get_mut().push(payload);
        if !completes {
            return Ok(None);
        }
        let finished = occupied.remove();
        Ok(Some(ReassembledMessage::new(
            message_id,
            finished.into_buffer(),
        )))
    }

    /// Handles a fragment for a message id that already has a buffered partial message.
    ///
    /// Duplicates are ignored. A rejected header removes the entry and returns the converted
    /// error.
    fn push_existing_fragment(
        limit: NonZeroUsize,
        mut occupied: OccupiedEntry<'_, MessageId, PartialMessage>,
        header: FragmentHeader,
        payload: &[u8],
    ) -> Result<Option<ReassembledMessage>, ReassemblyError> {
        let verdict = occupied
            .get_mut()
            .series
            .accept(header)
            .map_err(ReassemblyError::from);
        let status = match verdict {
            Ok(status) => status,
            Err(err) => {
                // A rejected header abandons the whole message.
                occupied.remove();
                return Err(err);
            }
        };
        let completes = match status {
            FragmentStatus::Duplicate => return Ok(None),
            FragmentStatus::Incomplete => false,
            FragmentStatus::Complete => true,
        };
        Self::append_and_maybe_complete(limit, occupied, payload, completes)
    }

    /// Handles the first fragment seen for a message id.
    ///
    /// Starts a fresh series, validates the header and the size limit, then either returns the
    /// message immediately (it completed) or buffers it in `vacant`. Nothing is inserted on
    /// error.
    fn push_first_fragment(
        limit: NonZeroUsize,
        vacant: std::collections::hash_map::VacantEntry<'_, MessageId, PartialMessage>,
        fragment: &FirstFragment<'_>,
    ) -> Result<Option<ReassembledMessage>, ReassemblyError> {
        let header = fragment.header;
        let message_id = header.message_id();
        let mut series = FragmentSeries::new(message_id);
        let status = series.accept(header).map_err(ReassemblyError::from)?;
        Self::assert_within_limit(limit, message_id, fragment.payload.len())?;
        match status {
            FragmentStatus::Complete => {
                // The opening fragment already completes the message; nothing is buffered.
                Ok(Some(ReassembledMessage::new(
                    message_id,
                    fragment.payload.to_vec(),
                )))
            }
            FragmentStatus::Incomplete => {
                let partial = PartialMessage::new(series, fragment.payload, fragment.now);
                vacant.insert(partial);
                Ok(None)
            }
            #[expect(
                clippy::unreachable,
                reason = "The first accepted fragment for a new series cannot be duplicate"
            )]
            FragmentStatus::Duplicate => {
                unreachable!(
                    "newly created FragmentSeries starts at index 0; a first fragment cannot be \
                     duplicate"
                );
            }
        }
    }
}