use std::num::NonZeroUsize;
use super::{MessageKey, error::MessageAssemblyError};
/// Optional aggregate limits consulted by `check_aggregate_budgets`.
///
/// A field set to `None` means that particular budget is not enforced.
#[derive(Clone, Copy, Debug)]
pub(super) struct AggregateBudgets {
    /// Limit whose breach is reported as `ConnectionBudgetExceeded`.
    pub(super) connection: Option<NonZeroUsize>,
    /// Limit whose breach is reported as `InFlightBudgetExceeded`.
    pub(super) in_flight: Option<NonZeroUsize>,
}

impl AggregateBudgets {
    /// Returns `true` when at least one of the two budgets is configured.
    pub(super) const fn is_enabled(&self) -> bool {
        // Only the "both absent" shape counts as disabled.
        !matches!(
            self,
            Self {
                connection: None,
                in_flight: None,
            }
        )
    }
}
/// Verifies that adding `additional_bytes` to `current_total` stays within
/// every configured aggregate budget.
///
/// The connection budget is tested first, then the in-flight budget; both are
/// compared against the same prospective total. A budget set to `None` is
/// skipped.
///
/// # Errors
///
/// * [`MessageAssemblyError::ConnectionBudgetExceeded`] when the prospective
///   total is strictly greater than `budgets.connection`.
/// * [`MessageAssemblyError::InFlightBudgetExceeded`] when the prospective
///   total is strictly greater than `budgets.in_flight`.
///
/// Each error carries `key`, the attempted total, and the limit that was hit.
pub(super) fn check_aggregate_budgets(
    key: MessageKey,
    current_total: usize,
    additional_bytes: usize,
    budgets: &AggregateBudgets,
) -> Result<(), MessageAssemblyError> {
    // Saturate instead of wrapping so an overflowing sum still compares as
    // larger than any limit below `usize::MAX`.
    let attempted = current_total.saturating_add(additional_bytes);

    if let Some(limit) = budgets.connection.filter(|cap| attempted > cap.get()) {
        return Err(MessageAssemblyError::ConnectionBudgetExceeded {
            key,
            attempted,
            limit,
        });
    }

    if let Some(limit) = budgets.in_flight.filter(|cap| attempted > cap.get()) {
        return Err(MessageAssemblyError::InFlightBudgetExceeded {
            key,
            attempted,
            limit,
        });
    }

    Ok(())
}
/// Computes `accumulated + body_len` and checks it against
/// `max_message_size`.
///
/// Returns the new length when it is less than or equal to the limit.
///
/// # Errors
///
/// Returns [`MessageAssemblyError::MessageTooLarge`] when the sum is strictly
/// greater than `max_message_size`, or when the addition overflows `usize`.
/// In the overflow case the reported `attempted` value is `usize::MAX`.
pub(super) fn check_size_limit(
    max_message_size: NonZeroUsize,
    key: MessageKey,
    accumulated: usize,
    body_len: usize,
) -> Result<usize, MessageAssemblyError> {
    match accumulated.checked_add(body_len) {
        Some(total) if total <= max_message_size.get() => Ok(total),
        // Either the sum overflowed (`None`) or it is over the limit.
        rejected => Err(MessageAssemblyError::MessageTooLarge {
            key,
            attempted: rejected.unwrap_or(usize::MAX),
            limit: max_message_size,
        }),
    }
}