use std::{
num::NonZeroUsize,
sync::atomic::{AtomicU64, Ordering},
};
use super::{FragmentHeader, FragmentIndex, FragmentationError, MessageId};
use crate::message::Message;
/// Splits payloads into fragments of bounded size and allocates the message
/// ids those fragments are tagged with.
#[derive(Debug)]
pub struct Fragmenter {
    /// Upper bound on the number of payload bytes placed in one fragment.
    max_fragment_size: NonZeroUsize,
    /// Raw value of the next `MessageId` to hand out. It is an atomic so that
    /// `next_message_id` can advance it through a shared `&self`.
    next_message_id: AtomicU64,
}
/// Starting position for fragmentation.
///
/// `offset` is the byte offset into the payload from which fragmenting
/// begins. `index` is the fragment index assigned to the first fragment
/// produced from that offset.
#[derive(Debug, Clone, Copy)]
pub(crate) struct FragmentCursor {
    offset: usize,
    index: FragmentIndex,
}
impl FragmentCursor {
    /// Builds a cursor positioned at byte `offset` whose next fragment will be
    /// numbered `index`.
    pub(crate) const fn new(offset: usize, index: FragmentIndex) -> Self {
        Self {
            offset,
            index,
        }
    }
}
impl Fragmenter {
#[must_use]
pub const fn new(max_fragment_size: NonZeroUsize) -> Self {
Self::with_starting_id(max_fragment_size, MessageId::new(0))
}
#[must_use]
pub const fn with_starting_id(max_fragment_size: NonZeroUsize, start_at: MessageId) -> Self {
Self {
max_fragment_size,
next_message_id: AtomicU64::new(start_at.get()),
}
}
#[must_use]
pub const fn max_fragment_size(&self) -> NonZeroUsize { self.max_fragment_size }
#[must_use]
pub fn next_message_id(&self) -> MessageId {
match self
.next_message_id
.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
current.checked_add(1)
}) {
Ok(previous) => MessageId::new(previous),
Err(current) => panic!("message id counter exhausted at {current}"),
}
}
pub fn fragment_message<M: Message>(
&self,
message: &M,
) -> Result<FragmentBatch, FragmentationError> {
let bytes = message.to_bytes()?;
self.fragment_bytes(bytes)
}
pub fn fragment_bytes(
&self,
payload: impl AsRef<[u8]>,
) -> Result<FragmentBatch, FragmentationError> {
let message_id = self.next_message_id();
self.fragment_with_id(message_id, payload.as_ref())
}
pub fn fragment_with_id(
&self,
message_id: MessageId,
payload: impl AsRef<[u8]>,
) -> Result<FragmentBatch, FragmentationError> {
let fragments = self.build_fragments(message_id, payload.as_ref())?;
Ok(FragmentBatch::new(message_id, fragments))
}
fn build_fragments(
&self,
message_id: MessageId,
payload: &[u8],
) -> Result<Vec<FragmentFrame>, FragmentationError> {
self.build_fragments_from(
message_id,
payload,
FragmentCursor::new(0, FragmentIndex::zero()),
)
}
fn build_fragments_from(
&self,
message_id: MessageId,
payload: &[u8],
mut cursor: FragmentCursor,
) -> Result<Vec<FragmentFrame>, FragmentationError> {
let max = self.max_fragment_size.get();
if payload.is_empty() {
let header = FragmentHeader::new(message_id, FragmentIndex::zero(), true);
return Ok(vec![FragmentFrame::new(header, Vec::new())]);
}
let total = payload.len();
if cursor.offset > total {
return Err(FragmentationError::SliceBounds {
offset: cursor.offset,
end: cursor.offset,
total,
});
}
let mut fragments = Vec::with_capacity(div_ceil(total, max));
while cursor.offset < total {
let end = (cursor.offset + max).min(total);
let is_last = end == total;
let chunk = if let Some(slice) = payload.get(cursor.offset..end) {
slice.to_vec()
} else {
return Err(FragmentationError::SliceBounds {
offset: cursor.offset,
end,
total,
});
};
fragments.push(FragmentFrame::new(
FragmentHeader::new(message_id, cursor.index, is_last),
chunk,
));
if is_last {
break;
}
cursor.offset = end;
cursor.index = cursor
.index
.checked_increment()
.ok_or(FragmentationError::IndexOverflow { last: cursor.index })?;
}
Ok(fragments)
}
}
#[cfg(test)]
impl Fragmenter {
    /// Test-only hook that forwards to the private `build_fragments_from`.
    ///
    /// It lets tests start fragmentation from an arbitrary byte offset and
    /// fragment index.
    pub(crate) fn build_fragments_from_for_tests(
        &self,
        message_id: MessageId,
        payload: &[u8],
        cursor: FragmentCursor,
    ) -> Result<Vec<FragmentFrame>, FragmentationError> {
        Self::build_fragments_from(self, message_id, payload, cursor)
    }
}
/// One fragment: its [`FragmentHeader`] together with the payload bytes it carries.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct FragmentFrame {
    header: FragmentHeader,
    payload: Vec<u8>,
}
impl FragmentFrame {
    /// Pairs `header` with the owned `payload` bytes.
    #[must_use]
    pub fn new(header: FragmentHeader, payload: Vec<u8>) -> Self {
        Self { header, payload }
    }

    /// Borrows this fragment's header.
    #[must_use]
    pub fn header(&self) -> &FragmentHeader {
        &self.header
    }

    /// Borrows the payload bytes carried by this fragment.
    #[must_use]
    pub fn payload(&self) -> &[u8] {
        &self.payload
    }

    /// Consumes the frame and returns its header and payload.
    #[must_use]
    pub fn into_parts(self) -> (FragmentHeader, Vec<u8>) {
        let Self { header, payload } = self;
        (header, payload)
    }
}
/// The fragments produced for one message, tagged with that message's id.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct FragmentBatch {
    message_id: MessageId,
    fragments: Vec<FragmentFrame>,
}
impl FragmentBatch {
    /// Wraps `fragments` as the batch for `message_id`.
    ///
    /// Debug builds assert that `fragments` is non-empty. Release builds do
    /// not check.
    fn new(message_id: MessageId, fragments: Vec<FragmentFrame>) -> Self {
        debug_assert!(!fragments.is_empty(), "fragment batches must not be empty");
        Self { message_id, fragments }
    }

    /// Returns the message id this batch was built for.
    #[must_use]
    pub const fn message_id(&self) -> MessageId {
        self.message_id
    }

    /// Borrows the fragments in the order they were produced.
    #[must_use]
    pub fn fragments(&self) -> &[FragmentFrame] {
        &self.fragments
    }

    /// Returns the number of fragments in the batch.
    #[expect(
        clippy::len_without_is_empty,
        reason = "batches are guaranteed non-empty"
    )]
    #[must_use]
    pub fn len(&self) -> usize {
        self.fragments().len()
    }

    /// Returns `true` when the batch holds more than one fragment.
    #[must_use]
    pub fn is_fragmented(&self) -> bool {
        self.fragments.len() > 1
    }

    /// Consumes the batch and returns its fragments.
    #[must_use]
    pub fn into_fragments(self) -> Vec<FragmentFrame> {
        let Self { fragments, .. } = self;
        fragments
    }
}
impl IntoIterator for FragmentBatch {
    type Item = FragmentFrame;
    type IntoIter = std::vec::IntoIter<FragmentFrame>;

    /// Yields the fragments by value, in the same order as [`FragmentBatch::fragments`].
    fn into_iter(self) -> Self::IntoIter {
        self.into_fragments().into_iter()
    }
}
/// Returns `numerator / denominator` rounded up.
///
/// This is the number of `denominator`-sized chunks needed to cover
/// `numerator` items.
///
/// # Panics
///
/// Panics if `denominator` is zero.
fn div_ceil(numerator: usize, denominator: usize) -> usize {
    usize::div_ceil(numerator, denominator)
}