use std::collections::HashMap;
/// Size in bytes of one `u32` field as stored in a chunk header.
const UINT32_LEN: usize = 4;
/// Size in bytes of a full chunk header: three `u32` fields
/// (message id, chunk offset, total message size).
const HEADER_LENGTH: usize = UINT32_LEN * 3;
/// Parses a chunk header into its `(id, offset, total_size)` fields.
///
/// The header is three consecutive little-endian `u32` values.
///
/// # Errors
///
/// Returns an error when `header` is not exactly `HEADER_LENGTH` bytes long.
fn decode_header(header: &[u8]) -> anyhow::Result<(u32, u32, u32)> {
    let actual_len = header.len();
    if actual_len != UINT32_LEN * 3 {
        return Err(anyhow::anyhow!(
            "invalid chunk header length: expected {}, got {}",
            HEADER_LENGTH,
            actual_len
        ));
    }

    let mut fields = [0u32; 3];
    for (field, word) in fields.iter_mut().zip(header.chunks_exact(4)) {
        // `chunks_exact(4)` only yields 4-byte slices, so the conversion cannot fail.
        *field = u32::from_le_bytes(word.try_into().unwrap());
    }

    let [id, offset, total_size] = fields;
    Ok((id, offset, total_size))
}
/// Serializes `id`, `offset` and `total_size` into a chunk header.
///
/// Each value is written as a little-endian `u32`, in that order.
fn encode_header(id: u32, offset: u32, total_size: u32) -> [u8; HEADER_LENGTH] {
    let mut header = [0u8; HEADER_LENGTH];
    let fields = [id, offset, total_size];
    for (slot, field) in header.chunks_exact_mut(4).zip(fields) {
        slot.copy_from_slice(&field.to_le_bytes());
    }
    header
}
/// Reassembly state for one message whose chunks are still arriving.
#[derive(Default)]
struct PendingMessage {
    /// Number of distinct bytes covered by `received_ranges`.
    received_size: usize,
    /// Message buffer that incoming chunk payloads are copied into.
    content: Vec<u8>,
    /// Half-open `(start, end)` byte ranges recorded so far, kept in ascending
    /// order with overlapping or touching ranges merged together.
    received_ranges: Vec<(usize, usize)>,
}

impl PendingMessage {
    /// Records that bytes `start..end` have been received.
    ///
    /// The new range is merged with every stored range it overlaps or touches,
    /// and `received_size` grows only by the bytes that were not already
    /// covered, so duplicated or overlapping chunks are not counted twice.
    /// Empty or inverted ranges (`start >= end`) are ignored.
    fn add_received_range(&mut self, start: usize, end: usize) {
        if end <= start {
            return;
        }

        // First stored range that ends at or after `start`; everything before it
        // lies strictly to the left of the new range and is left untouched.
        let first = self
            .received_ranges
            .iter()
            .position(|&(_, hi)| hi >= start)
            .unwrap_or(self.received_ranges.len());

        let mut span = (start, end);
        let mut already_covered = 0;
        let mut last = first;
        for &(lo, hi) in &self.received_ranges[first..] {
            // Ranges starting past the merged span are unaffected, and so is
            // everything after them.
            if lo > span.1 {
                break;
            }
            // Bytes of the *new* range that this stored range already covers.
            already_covered += hi.min(end).saturating_sub(lo.max(start));
            span = (span.0.min(lo), span.1.max(hi));
            last += 1;
        }

        // Replace every absorbed range with the single merged span.
        self.received_ranges
            .splice(first..last, std::iter::once(span));
        self.received_size += (end - start) - already_covered;
    }
}
/// Reassembles messages from chunks that each carry an
/// `(id, offset, total_size)` header followed by a slice of the payload.
pub(crate) struct UnChunker {
    /// Largest `total_size` a chunk header may declare before the chunk is rejected.
    max_message_size: usize,
    /// Partially received messages, keyed by the message id from the chunk header.
    pending_message: HashMap<u32, PendingMessage>,
}

impl UnChunker {
    /// Creates an un-chunker with no pending messages and a 512-byte message limit.
    pub(crate) fn new() -> Self {
        Self {
            max_message_size: 512,
            pending_message: HashMap::default(),
        }
    }

    /// Sets the largest total message size accepted by later calls to
    /// [`UnChunker::process_chunk`].
    pub(crate) fn set_max_message_size(&mut self, size: usize) {
        self.max_message_size = size;
    }

    /// Consumes one chunk (header plus payload) and stores its payload in the
    /// message it belongs to.
    ///
    /// Returns `Ok(Some(content))` once every byte of the message has been
    /// received (the message is then forgotten), and `Ok(None)` while bytes are
    /// still missing.
    ///
    /// # Errors
    ///
    /// Returns an error when:
    /// - the chunk is shorter than the header;
    /// - the declared total size exceeds the configured maximum;
    /// - the offset, or offset plus payload length, falls outside the declared
    ///   total size;
    /// - the declared total size differs from the one declared by an earlier
    ///   chunk of the same message (the pending message is discarded).
    pub(crate) fn process_chunk(&mut self, chunk: &[u8]) -> anyhow::Result<Option<Vec<u8>>> {
        if chunk.len() < HEADER_LENGTH {
            return Err(anyhow::anyhow!(
                "chunk shorter than header: {} < {}",
                chunk.len(),
                HEADER_LENGTH
            ));
        }
        let (header, chunk_content) = chunk.split_at(HEADER_LENGTH);
        let (id, offset, total_size) = decode_header(header)?;
        let offset = offset as usize;
        let total_size = total_size as usize;

        // Checked before any allocation so a hostile header cannot force a huge buffer.
        if total_size > self.max_message_size {
            return Err(anyhow::anyhow!(
                "Total size for message {} exceeds the allocation limit allowed.\nMaximum size = {},\nReceived size = {}.",
                id,
                self.max_message_size,
                total_size
            ));
        }
        if offset > total_size {
            return Err(anyhow::anyhow!(
                "chunk offset out of bounds for message {}: offset={}, total_size={}",
                id,
                offset,
                total_size
            ));
        }
        let end = offset
            .checked_add(chunk_content.len())
            .ok_or_else(|| anyhow::anyhow!("chunk size overflow for message {}", id))?;
        if end > total_size {
            return Err(anyhow::anyhow!(
                "chunk exceeds message bounds for message {}: offset={}, chunk_size={}, total_size={}",
                id,
                offset,
                chunk_content.len(),
                total_size
            ));
        }

        // Single lookup: reuse the pending message or start a zero-filled one
        // sized from this chunk's header.
        let pending = self
            .pending_message
            .entry(id)
            .or_insert_with(|| PendingMessage {
                received_size: 0,
                content: vec![0; total_size],
                received_ranges: Vec::new(),
            });

        // The buffer length is the total size declared by the first chunk seen.
        if pending.content.len() != total_size {
            self.pending_message.remove(&id);
            return Err(anyhow::anyhow!(
                "Total size in chunk header for message {} does not match total size declared by previous chunk.",
                id
            ));
        }

        pending.content[offset..end].copy_from_slice(chunk_content);
        pending.add_received_range(offset, end);

        if pending.received_size != total_size {
            return Ok(None);
        }
        // Complete: hand the buffer to the caller and drop the bookkeeping.
        Ok(self.pending_message.remove(&id).map(|message| message.content))
    }
}
/// Splits `message` into chunks, each prefixed with an
/// `(id, offset, total_size)` header.
///
/// All chunks share one randomly generated id. Each payload holds at most
/// `max_size - HEADER_LENGTH` bytes, but never fewer than one byte, so a
/// `max_size` at or below the header length still makes progress. An empty
/// message yields a single header-only chunk.
///
/// # Panics
///
/// Panics if `message.len()` does not fit in a `u32`.
pub(crate) fn generate_chunks(message: &[u8], max_size: usize) -> Vec<Vec<u8>> {
    use rand;
    let declared_len =
        u32::try_from(message.len()).expect("message length does not fit in chunk header");
    let payload_capacity = max_size.saturating_sub(HEADER_LENGTH).max(1);
    let id = rand::random();

    if message.is_empty() {
        return vec![encode_header(id, 0, 0).to_vec()];
    }

    // Running byte offset of the next payload within `message`.
    let mut consumed = 0usize;
    message
        .chunks(payload_capacity)
        .map(|payload| {
            // `consumed` is below `message.len()`, which was checked to fit in a `u32`.
            let header = encode_header(id, consumed as u32, declared_len);
            consumed += payload.len();
            [&header[..], payload].concat()
        })
        .collect()
}