use std::{
collections::{HashMap, hash_map::Entry},
num::NonZeroUsize,
time::{Duration, Instant},
};
use super::{
ContinuationFrameHeader,
MessageKey,
budget::{AggregateBudgets, check_aggregate_budgets, check_size_limit},
error::{MessageAssemblyError, MessageSeriesError, MessageSeriesStatus},
series::MessageSeries,
types::{AssembledMessage, EnvelopeRouting, FirstFrameInput},
};
/// Buffered state for one message whose final frame has not arrived yet.
#[derive(Debug)]
struct PartialAssembly {
    /// Sequencing state consulted for every continuation frame.
    series: MessageSeries,
    /// Routing information captured when the assembly was started.
    routing: EnvelopeRouting,
    /// Metadata bytes stored via `set_metadata`.
    metadata: Vec<u8>,
    /// Body bytes received so far, in arrival order.
    body_buffer: Vec<u8>,
    /// Time the assembly was started; compared against the timeout when purging.
    started_at: Instant,
}

impl PartialAssembly {
    /// Starts an assembly with empty metadata and body buffers.
    fn new(series: MessageSeries, routing: EnvelopeRouting, started_at: Instant) -> Self {
        let metadata = Vec::new();
        let body_buffer = Vec::new();
        Self {
            series,
            routing,
            metadata,
            body_buffer,
            started_at,
        }
    }

    /// Appends `data` to the end of the buffered body.
    fn push_body(&mut self, data: &[u8]) {
        self.body_buffer.extend_from_slice(data);
    }

    /// Replaces the stored metadata with `data`.
    fn set_metadata(&mut self, data: Vec<u8>) {
        self.metadata = data;
    }

    /// Number of body bytes buffered so far; metadata is not included.
    fn accumulated_len(&self) -> usize {
        self.body_buffer.len()
    }

    /// Bytes held by this assembly (body plus metadata), saturating at `usize::MAX`.
    fn buffered_bytes(&self) -> usize {
        let body_bytes = self.accumulated_len();
        body_bytes.saturating_add(self.metadata.len())
    }
}
/// Reassembles messages that arrive as a first frame followed by zero or more
/// continuation frames, keyed by [`MessageKey`].
///
/// In-progress messages are buffered until their final frame arrives. An
/// assembly is discarded when it breaks the per-message size limit or the
/// aggregate budgets, or when it hits an unrecoverable continuity error. It is
/// also purged once it has been in progress for at least `timeout` since its
/// first frame.
#[derive(Debug)]
pub struct MessageAssemblyState {
    /// Limit on metadata plus body bytes for a single message.
    max_message_size: NonZeroUsize,
    /// Maximum age of an in-progress assembly, measured from its first frame.
    timeout: Duration,
    /// In-progress assemblies, at most one per key.
    assemblies: HashMap<MessageKey, PartialAssembly>,
    /// Optional cross-assembly byte budgets, enforced via `check_aggregate_budgets`.
    budgets: AggregateBudgets,
}

impl MessageAssemblyState {
    /// Creates an assembler with the given per-message size limit and
    /// assembly timeout. No aggregate budgets are set.
    #[must_use]
    pub fn new(max_message_size: NonZeroUsize, timeout: Duration) -> Self {
        Self::with_budgets(max_message_size, timeout, None, None)
    }

    /// Creates an assembler with the given per-message size limit, assembly
    /// timeout and optional aggregate budgets.
    ///
    /// `connection_budget` and `in_flight_budget` populate the `connection`
    /// and `in_flight` fields of [`AggregateBudgets`]. Pass `None` to leave a
    /// budget unset.
    #[must_use]
    pub fn with_budgets(
        max_message_size: NonZeroUsize,
        timeout: Duration,
        connection_budget: Option<NonZeroUsize>,
        in_flight_budget: Option<NonZeroUsize>,
    ) -> Self {
        Self {
            max_message_size,
            timeout,
            assemblies: HashMap::new(),
            budgets: AggregateBudgets {
                connection: connection_budget,
                in_flight: in_flight_budget,
            },
        }
    }

    /// Accepts the first frame of a message, timestamped with [`Instant::now`].
    ///
    /// # Errors
    ///
    /// Same as [`Self::accept_first_frame_at`].
    pub fn accept_first_frame(
        &mut self,
        input: FirstFrameInput<'_>,
    ) -> Result<Option<AssembledMessage>, MessageAssemblyError> {
        self.accept_first_frame_at(input, Instant::now())
    }

    /// Accepts the first frame of a message at time `now`.
    ///
    /// Expired assemblies are purged first.
    ///
    /// A frame flagged `is_last` is returned immediately as a complete
    /// message and nothing is buffered. Otherwise a new partial assembly is
    /// started and `Ok(None)` is returned.
    ///
    /// # Errors
    ///
    /// - [`MessageAssemblyError::DuplicateFirstFrame`] if an assembly for the
    ///   same key is already in progress.
    /// - [`MessageAssemblyError::MessageTooLarge`] if metadata length plus body
    ///   length exceeds the per-message limit. The body length is the declared
    ///   `total_body_len` when present, but never less than the bytes this
    ///   frame actually carries.
    /// - Any error from the aggregate budget check, when budgets are enabled
    ///   and the frame has to be buffered.
    pub fn accept_first_frame_at(
        &mut self,
        input: FirstFrameInput<'_>,
        now: Instant,
    ) -> Result<Option<AssembledMessage>, MessageAssemblyError> {
        self.purge_expired_at(now);
        let key = input.header.message_key;
        if self.assemblies.contains_key(&key) {
            return Err(MessageAssemblyError::DuplicateFirstFrame { key });
        }

        // An understated declared total must not let an oversized first frame
        // slip past the limit, so use the larger of the declared and actual lengths.
        let actual_body_len = input.body.len();
        let body_len = input
            .header
            .total_body_len
            .map_or(actual_body_len, |declared| declared.max(actual_body_len));
        let total_message_size = body_len.saturating_add(input.metadata.len());
        if total_message_size > self.max_message_size.get() {
            return Err(MessageAssemblyError::MessageTooLarge {
                key,
                attempted: total_message_size,
                limit: self.max_message_size,
            });
        }

        // Single-frame message: nothing is buffered, so aggregate budgets do not apply.
        if input.header.is_last {
            return Ok(Some(AssembledMessage::new(
                key,
                input.routing,
                input.metadata,
                input.body.to_vec(),
            )));
        }

        if self.budgets.is_enabled() {
            let incoming_bytes = actual_body_len.saturating_add(input.metadata.len());
            check_aggregate_budgets(
                key,
                self.total_buffered_bytes(),
                incoming_bytes,
                &self.budgets,
            )?;
        }

        let series = MessageSeries::from_first_frame(input.header);
        let mut partial = PartialAssembly::new(series, input.routing, now);
        partial.set_metadata(input.metadata);
        partial.push_body(input.body);
        self.assemblies.insert(key, partial);
        Ok(None)
    }

    /// Accepts a continuation frame, timestamped with [`Instant::now`].
    ///
    /// # Errors
    ///
    /// Same as [`Self::accept_continuation_frame_at`].
    pub fn accept_continuation_frame(
        &mut self,
        header: &ContinuationFrameHeader,
        body: &[u8],
    ) -> Result<Option<AssembledMessage>, MessageAssemblyError> {
        self.accept_continuation_frame_at(header, body, Instant::now())
    }

    /// Reports whether a series error leaves the assembly unusable. In that
    /// case the partial assembly for the key is discarded.
    const fn is_unrecoverable_continuity_error(error: &MessageSeriesError) -> bool {
        matches!(
            error,
            MessageSeriesError::KeyMismatch { .. }
                | MessageSeriesError::SequenceOverflow { .. }
                | MessageSeriesError::MissingFirstFrame { .. }
                | MessageSeriesError::MissingSequence { .. }
                | MessageSeriesError::ContinuationBodyLengthMismatch { .. }
        )
    }

    /// Accepts a continuation frame at time `now`.
    ///
    /// Expired assemblies are purged first. The frame body is appended to the
    /// matching partial assembly. When the series reports the message
    /// complete, the assembly is removed and returned as `Ok(Some(_))`;
    /// otherwise `Ok(None)` is returned.
    ///
    /// # Errors
    ///
    /// - [`MessageSeriesError::ContinuationBodyLengthMismatch`] (wrapped in
    ///   [`MessageAssemblyError::Series`]) if `header.body_len` disagrees with
    ///   `body.len()`. Any assembly for the key is discarded.
    /// - [`MessageSeriesError::MissingFirstFrame`] (wrapped in
    ///   [`MessageAssemblyError::Series`]) if no assembly is in progress for
    ///   the key.
    /// - Any error from the series' continuation check, wrapped in
    ///   [`MessageAssemblyError::Series`]. The assembly is discarded when the
    ///   error is unrecoverable.
    /// - A size-limit or aggregate-budget error. The assembly is discarded.
    pub fn accept_continuation_frame_at(
        &mut self,
        header: &ContinuationFrameHeader,
        body: &[u8],
        now: Instant,
    ) -> Result<Option<AssembledMessage>, MessageAssemblyError> {
        self.purge_expired_at(now);
        let key = header.message_key;
        if header.body_len != body.len() {
            // This error is classed as unrecoverable by
            // `is_unrecoverable_continuity_error`, so drop the assembly here too
            // rather than leaving it waiting on a corrupted frame stream.
            self.assemblies.remove(&key);
            return Err(MessageAssemblyError::Series(
                MessageSeriesError::ContinuationBodyLengthMismatch {
                    key,
                    header_len: header.body_len,
                    actual_len: body.len(),
                },
            ));
        }

        let max_message_size = self.max_message_size;
        let budgets = self.budgets;
        // Sample the aggregate total before taking the entry, which mutably
        // borrows the map for the rest of the function.
        let buffered_total = if budgets.is_enabled() {
            self.total_buffered_bytes()
        } else {
            0
        };

        let Entry::Occupied(mut entry) = self.assemblies.entry(key) else {
            return Err(MessageAssemblyError::Series(
                MessageSeriesError::MissingFirstFrame { key },
            ));
        };

        let status = match entry.get_mut().series.accept_continuation(header) {
            Ok(status) => status,
            Err(e) => {
                if Self::is_unrecoverable_continuity_error(&e) {
                    entry.remove();
                }
                return Err(MessageAssemblyError::Series(e));
            }
        };

        // Count metadata as well as body bytes, so continuations are held to
        // the same metadata-plus-body limit the first frame was checked against.
        let partial = entry.get();
        let accumulated = partial
            .accumulated_len()
            .saturating_add(partial.metadata.len());
        if let Err(e) = check_size_limit(max_message_size, key, accumulated, body.len()) {
            entry.remove();
            return Err(e);
        }
        if let Err(e) = check_aggregate_budgets(key, buffered_total, body.len(), &budgets) {
            entry.remove();
            return Err(e);
        }

        entry.get_mut().push_body(body);
        match status {
            MessageSeriesStatus::Incomplete => Ok(None),
            MessageSeriesStatus::Complete => {
                let partial = entry.remove();
                Ok(Some(AssembledMessage::new(
                    key,
                    partial.routing,
                    partial.metadata,
                    partial.body_buffer,
                )))
            }
        }
    }

    /// Removes assemblies that have been in progress for at least `timeout`
    /// as of [`Instant::now`], and returns their keys.
    pub fn purge_expired(&mut self) -> Vec<MessageKey> { self.purge_expired_at(Instant::now()) }

    /// Removes assemblies started at least `timeout` before `now`, and returns
    /// their keys in no particular order.
    ///
    /// An assembly whose start time is later than `now` is treated as having
    /// zero age.
    pub fn purge_expired_at(&mut self, now: Instant) -> Vec<MessageKey> {
        let mut evicted = Vec::new();
        let timeout = self.timeout;
        self.assemblies.retain(|key, partial| {
            let expired = now.saturating_duration_since(partial.started_at) >= timeout;
            if expired {
                evicted.push(*key);
            }
            !expired
        });
        evicted
    }

    /// Total metadata plus body bytes buffered across all in-progress
    /// assemblies, saturating at `usize::MAX`.
    #[must_use]
    pub fn total_buffered_bytes(&self) -> usize {
        self.assemblies
            .values()
            .map(PartialAssembly::buffered_bytes)
            .fold(0, usize::saturating_add)
    }

    /// Number of in-progress assemblies.
    #[must_use]
    pub fn buffered_count(&self) -> usize { self.assemblies.len() }
}