use gbp::CodecError;
use serde::{Deserialize, Serialize};
use serde_bytes::ByteBuf;
use sha2::{Digest, Sha256};
use std::collections::HashMap;
/// Default number of payload bytes per chunk (64 KiB), suitable as the
/// `chunk_size` argument of `AttachmentSender::new`.
pub const DEFAULT_CHUNK_SIZE: usize = 65_536;
/// Metadata describing one chunked attachment.
///
/// Fields serialize under the short names given by `#[serde(rename)]`
/// (e.g. `attachment_id` is encoded as `"aid"`); `to_cbor`/`from_cbor`
/// convert it to and from CBOR. Field declaration order is part of the
/// encoded layout, so do not reorder fields casually.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AttachmentManifest {
/// Identifier shared by this manifest and every chunk of the attachment.
#[serde(rename = "aid")]
pub attachment_id: u64,
/// File name supplied by the sender.
#[serde(rename = "name")]
pub filename: String,
/// MIME type supplied by the sender.
#[serde(rename = "mime")]
pub mime_type: String,
/// Payload length in bytes as declared by the sender. It is not checked
/// against the reassembled payload; only `sha256` is verified.
#[serde(rename = "size")]
pub total_size: u64,
/// Number of chunks the payload was split into; valid chunk indices are
/// `0..chunk_count`.
#[serde(rename = "nc")]
pub chunk_count: u32,
/// SHA-256 digest of the complete payload.
#[serde(rename = "hash")]
pub sha256: ByteBuf,
}
impl AttachmentManifest {
    /// Serializes this manifest into a freshly allocated CBOR byte vector.
    ///
    /// # Panics
    ///
    /// Panics with `"cbor encode"` if `ciborium` reports a serialization error.
    pub fn to_cbor(&self) -> Vec<u8> {
        let mut encoded = Vec::new();
        let outcome = ciborium::into_writer(self, &mut encoded);
        outcome.expect("cbor encode");
        encoded
    }

    /// Parses a manifest from CBOR bytes.
    ///
    /// # Errors
    ///
    /// Returns `CodecError::Decode` carrying the decoder's error text when
    /// `data` is not a valid CBOR encoding of a manifest.
    pub fn from_cbor(data: &[u8]) -> Result<Self, CodecError> {
        let decoded: Result<Self, _> = ciborium::from_reader(data);
        decoded.map_err(|err| CodecError::Decode(err.to_string()))
    }
}
/// One slice of an attachment's payload.
///
/// Fields serialize under the short names given by `#[serde(rename)]`;
/// `to_cbor`/`from_cbor` convert it to and from CBOR. Field declaration
/// order is part of the encoded layout.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AttachmentChunk {
/// Identifier of the attachment this chunk belongs to.
#[serde(rename = "aid")]
pub attachment_id: u64,
/// Zero-based position of this chunk within the payload.
#[serde(rename = "idx")]
pub chunk_index: u32,
/// Total number of chunks, as set by the sender.
#[serde(rename = "nc")]
pub chunk_count: u32,
/// Raw payload bytes carried by this chunk.
#[serde(rename = "data")]
pub data: ByteBuf,
}
impl AttachmentChunk {
    /// Serializes this chunk into a freshly allocated CBOR byte vector.
    ///
    /// # Panics
    ///
    /// Panics with `"cbor encode"` if `ciborium` reports a serialization error.
    pub fn to_cbor(&self) -> Vec<u8> {
        let mut encoded = Vec::new();
        let outcome = ciborium::into_writer(self, &mut encoded);
        outcome.expect("cbor encode");
        encoded
    }

    /// Parses a chunk from CBOR bytes.
    ///
    /// # Errors
    ///
    /// Returns `CodecError::Decode` carrying the decoder's error text when
    /// `data` is not a valid CBOR encoding of a chunk.
    pub fn from_cbor(data: &[u8]) -> Result<Self, CodecError> {
        let decoded: Result<Self, _> = ciborium::from_reader(data);
        decoded.map_err(|err| CodecError::Decode(err.to_string()))
    }
}
/// Errors returned by the attachment chunking and reassembly API.
#[derive(Debug, thiserror::Error)]
pub enum AttachmentError {
/// Wraps a `CodecError`; `#[from]` lets `?` convert codec errors into this type.
#[error("decode: {0}")]
Decode(#[from] CodecError),
/// A chunk's index was not below the manifest's `chunk_count`.
#[error("chunk index {idx} out of range (count={count})")]
ChunkOutOfRange {
/// The rejected `chunk_index`.
idx: u32,
/// The manifest's `chunk_count`.
count: u32,
},
/// The SHA-256 of the reassembled payload differs from the manifest's `sha256`.
#[error("integrity check failed: hash mismatch")]
HashMismatch,
/// Assembly was requested before every chunk index had been received.
#[error("incomplete: {received}/{total} chunks received")]
Incomplete {
/// Number of distinct chunk indices received so far.
received: u32,
/// The manifest's `chunk_count`.
total: u32,
},
}
/// An outgoing attachment: its manifest plus the CBOR-encoded chunks.
pub struct AttachmentSender {
/// Manifest describing the whole payload.
pub manifest: AttachmentManifest,
/// Each element is one `AttachmentChunk` encoded with `to_cbor`, ordered by
/// `chunk_index` (element `i` carries index `i`).
pub chunks: Vec<Vec<u8>>,
}
impl AttachmentSender {
    /// Splits `data` into CBOR-encoded chunks and builds the matching manifest.
    ///
    /// A `chunk_size` of 0 is treated as 1. Every chunk except possibly the last
    /// holds exactly `chunk_size` bytes. An empty `data` produces a manifest with
    /// `chunk_count == 0` and no chunks. Chunks are emitted in index order.
    ///
    /// # Panics
    ///
    /// Panics if `data` would split into more than `u32::MAX` chunks. The wire
    /// format stores chunk indices and counts as `u32`, so a larger count cannot
    /// be represented (choose a larger `chunk_size`).
    pub fn new(
        attachment_id: u64,
        filename: impl Into<String>,
        mime_type: impl Into<String>,
        data: &[u8],
        chunk_size: usize,
    ) -> Self {
        let hash = Sha256::digest(data);
        // `slice::chunks` panics on a zero chunk size, so clamp to at least 1.
        let chunk_size = chunk_size.max(1);
        // `Chunks` is an `ExactSizeIterator`, so no intermediate Vec is needed.
        // Use a checked conversion: `as u32` would silently truncate and yield a
        // manifest whose `chunk_count` disagrees with the chunks emitted below.
        let chunk_count = u32::try_from(data.chunks(chunk_size).len()).expect(
            "attachment splits into more than u32::MAX chunks; use a larger chunk_size",
        );
        let manifest = AttachmentManifest {
            attachment_id,
            filename: filename.into(),
            mime_type: mime_type.into(),
            total_size: data.len() as u64,
            chunk_count,
            sha256: ByteBuf::from(hash.as_slice().to_vec()),
        };
        // Pairing with a `u32` counter is lossless: the count was checked above.
        let chunks = data
            .chunks(chunk_size)
            .zip(0u32..)
            .map(|(slice, chunk_index)| {
                AttachmentChunk {
                    attachment_id,
                    chunk_index,
                    chunk_count,
                    data: ByteBuf::from(slice.to_vec()),
                }
                .to_cbor()
            })
            .collect();
        Self { manifest, chunks }
    }
}
/// Collects incoming chunks for a single manifest and reassembles the payload,
/// verifying it against the manifest's SHA-256.
pub struct AttachmentAssembler {
/// Manifest the incoming chunks are checked against.
manifest: AttachmentManifest,
/// Chunk payloads keyed by `chunk_index`; only the first copy of an index is kept.
received: HashMap<u32, Vec<u8>>,
}
impl AttachmentAssembler {
    /// Creates an empty assembler expecting the chunks described by `manifest`.
    ///
    /// Storage is not preallocated from `manifest.chunk_count`, because that
    /// value comes from the sender and may be arbitrarily large.
    pub fn new(manifest: AttachmentManifest) -> Self {
        Self {
            manifest,
            received: HashMap::new(),
        }
    }

    /// Returns the manifest this assembler was created with.
    pub fn manifest(&self) -> &AttachmentManifest {
        &self.manifest
    }

    /// Returns the number of distinct chunk indices stored so far.
    pub fn received_count(&self) -> u32 {
        // Lossless: `push` stores only indices below `manifest.chunk_count` (a
        // `u32`), so the map never holds more than `u32::MAX` entries.
        self.received.len() as u32
    }

    /// Returns `true` once every index in `0..manifest.chunk_count` has been stored.
    pub fn is_complete(&self) -> bool {
        self.received_count() == self.manifest.chunk_count
    }

    /// Stores one chunk's payload under its `chunk_index`.
    ///
    /// Duplicate indices are accepted but ignored. The first payload stored for
    /// an index is kept and later copies are dropped.
    ///
    /// NOTE(review): `chunk.attachment_id` and `chunk.chunk_count` are not
    /// compared with the manifest. A chunk from another attachment with an
    /// in-range index would be stored and later surface as `HashMismatch`.
    /// Callers presumably route chunks by attachment id first; confirm.
    ///
    /// # Errors
    ///
    /// Returns `AttachmentError::ChunkOutOfRange` if
    /// `chunk.chunk_index >= manifest.chunk_count`.
    pub fn push(&mut self, chunk: AttachmentChunk) -> Result<(), AttachmentError> {
        if chunk.chunk_index >= self.manifest.chunk_count {
            return Err(AttachmentError::ChunkOutOfRange {
                idx: chunk.chunk_index,
                count: self.manifest.chunk_count,
            });
        }
        self.received
            .entry(chunk.chunk_index)
            .or_insert_with(|| chunk.data.into_vec());
        Ok(())
    }

    /// Concatenates all chunks in index order and verifies the SHA-256 digest.
    ///
    /// `manifest.total_size` is not checked; integrity rests on the hash alone.
    ///
    /// # Errors
    ///
    /// - `AttachmentError::Incomplete` if any chunk index is still missing.
    /// - `AttachmentError::HashMismatch` if the payload's digest differs from
    ///   `manifest.sha256`.
    pub fn assemble(self) -> Result<Vec<u8>, AttachmentError> {
        let total = self.manifest.chunk_count;
        let received = self.received_count();
        if received < total {
            return Err(AttachmentError::Incomplete { received, total });
        }
        // Size the buffer from the bytes actually received. `manifest.total_size`
        // is sender-declared and unverified. Trusting it lets a tiny attachment
        // request an enormous allocation (capacity-overflow panic or allocator
        // abort), and `as usize` would truncate it on 32-bit targets.
        let capacity: usize = self.received.values().map(Vec::len).sum();
        let mut payload = Vec::with_capacity(capacity);
        for index in 0..total {
            // Every index is present: `push` stores only indices < total, and the
            // count check above found `total` distinct entries. Report an error
            // rather than panic should that invariant ever break.
            let part = self
                .received
                .get(&index)
                .ok_or(AttachmentError::Incomplete { received, total })?;
            payload.extend_from_slice(part);
        }
        let hash = Sha256::digest(&payload);
        if hash.as_slice() != self.manifest.sha256.as_ref() {
            return Err(AttachmentError::HashMismatch);
        }
        Ok(payload)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic, non-trivial payload of `n` bytes.
    fn sample_data(n: usize) -> Vec<u8> {
        (0..n).map(|i| (i % 251) as u8).collect()
    }

    #[test]
    fn round_trip_small_payload() {
        let data = sample_data(100);
        let sender = AttachmentSender::new(
            1,
            "file.bin",
            "application/octet-stream",
            &data,
            DEFAULT_CHUNK_SIZE,
        );
        assert_eq!(sender.manifest.chunk_count, 1);
        let mut asm = AttachmentAssembler::new(sender.manifest);
        for cbor in &sender.chunks {
            let chunk = AttachmentChunk::from_cbor(cbor).unwrap();
            asm.push(chunk).unwrap();
        }
        let result = asm.assemble().unwrap();
        assert_eq!(result, data);
    }

    #[test]
    fn round_trip_multi_chunk() {
        let data = sample_data(300);
        let sender = AttachmentSender::new(2, "multi.bin", "application/octet-stream", &data, 100);
        assert_eq!(sender.manifest.chunk_count, 3);
        let mut asm = AttachmentAssembler::new(sender.manifest);
        for cbor in &sender.chunks {
            let chunk = AttachmentChunk::from_cbor(cbor).unwrap();
            asm.push(chunk).unwrap();
        }
        assert!(asm.is_complete());
        let result = asm.assemble().unwrap();
        assert_eq!(result, data);
    }

    #[test]
    fn out_of_order_chunks_reassemble_correctly() {
        let data = sample_data(250);
        let sender = AttachmentSender::new(3, "ooo.bin", "application/octet-stream", &data, 100);
        let mut asm = AttachmentAssembler::new(sender.manifest);
        for cbor in sender.chunks.iter().rev() {
            let chunk = AttachmentChunk::from_cbor(cbor).unwrap();
            asm.push(chunk).unwrap();
        }
        let result = asm.assemble().unwrap();
        assert_eq!(result, data);
    }

    #[test]
    fn duplicate_chunk_ignored() {
        let data = sample_data(100);
        let sender = AttachmentSender::new(
            4,
            "dup.bin",
            "application/octet-stream",
            &data,
            DEFAULT_CHUNK_SIZE,
        );
        let mut asm = AttachmentAssembler::new(sender.manifest);
        let chunk = AttachmentChunk::from_cbor(&sender.chunks[0]).unwrap();
        asm.push(chunk.clone()).unwrap();
        asm.push(chunk).unwrap();
        assert_eq!(asm.received_count(), 1);
        let result = asm.assemble().unwrap();
        assert_eq!(result, data);
    }

    #[test]
    fn hash_mismatch_detected() {
        let data = sample_data(100);
        let sender = AttachmentSender::new(
            5,
            "bad.bin",
            "application/octet-stream",
            &data,
            DEFAULT_CHUNK_SIZE,
        );
        let mut manifest = sender.manifest;
        manifest.sha256[0] ^= 0xFF;
        let mut asm = AttachmentAssembler::new(manifest);
        let chunk = AttachmentChunk::from_cbor(&sender.chunks[0]).unwrap();
        asm.push(chunk).unwrap();
        assert!(matches!(asm.assemble(), Err(AttachmentError::HashMismatch)));
    }

    #[test]
    fn incomplete_returns_error() {
        let data = sample_data(300);
        let sender = AttachmentSender::new(6, "inc.bin", "application/octet-stream", &data, 100);
        let mut asm = AttachmentAssembler::new(sender.manifest);
        let chunk = AttachmentChunk::from_cbor(&sender.chunks[0]).unwrap();
        asm.push(chunk).unwrap();
        assert!(matches!(
            asm.assemble(),
            Err(AttachmentError::Incomplete {
                received: 1,
                total: 3
            })
        ));
    }

    #[test]
    fn chunk_out_of_range_rejected() {
        let data = sample_data(100);
        let sender = AttachmentSender::new(
            7,
            "oor.bin",
            "application/octet-stream",
            &data,
            DEFAULT_CHUNK_SIZE,
        );
        let mut asm = AttachmentAssembler::new(sender.manifest);
        let bad_chunk = AttachmentChunk {
            attachment_id: 7,
            chunk_index: 99,
            chunk_count: 1,
            data: ByteBuf::new(),
        };
        assert!(matches!(
            asm.push(bad_chunk),
            Err(AttachmentError::ChunkOutOfRange { .. })
        ));
    }

    #[test]
    fn chunk_index_equal_to_count_rejected() {
        let data = sample_data(200);
        let sender = AttachmentSender::new(11, "edge.bin", "application/octet-stream", &data, 100);
        let mut asm = AttachmentAssembler::new(sender.manifest);
        let edge_chunk = AttachmentChunk {
            attachment_id: 11,
            chunk_index: 2,
            chunk_count: 2,
            data: ByteBuf::new(),
        };
        assert!(matches!(
            asm.push(edge_chunk),
            Err(AttachmentError::ChunkOutOfRange { idx: 2, count: 2 })
        ));
        assert_eq!(asm.received_count(), 0);
    }

    #[test]
    fn empty_payload_round_trip() {
        let sender = AttachmentSender::new(
            9,
            "empty.bin",
            "application/octet-stream",
            &[],
            DEFAULT_CHUNK_SIZE,
        );
        assert_eq!(sender.manifest.chunk_count, 0);
        assert_eq!(sender.manifest.total_size, 0);
        assert!(sender.chunks.is_empty());
        let asm = AttachmentAssembler::new(sender.manifest);
        assert!(asm.is_complete());
        assert_eq!(asm.assemble().unwrap(), Vec::<u8>::new());
    }

    #[test]
    fn zero_chunk_size_is_clamped_to_one() {
        let data = sample_data(3);
        let sender = AttachmentSender::new(10, "tiny.bin", "application/octet-stream", &data, 0);
        assert_eq!(sender.manifest.chunk_count, 3);
        assert_eq!(sender.chunks.len(), 3);
        let mut asm = AttachmentAssembler::new(sender.manifest);
        for cbor in &sender.chunks {
            let chunk = AttachmentChunk::from_cbor(cbor).unwrap();
            assert_eq!(chunk.data.len(), 1);
            asm.push(chunk).unwrap();
        }
        assert_eq!(asm.assemble().unwrap(), data);
    }

    #[test]
    fn manifest_cbor_round_trip() {
        let data = sample_data(50);
        let sender = AttachmentSender::new(8, "rt.bin", "text/plain", &data, DEFAULT_CHUNK_SIZE);
        let encoded = sender.manifest.to_cbor();
        let decoded = AttachmentManifest::from_cbor(&encoded).unwrap();
        assert_eq!(decoded.attachment_id, 8);
        assert_eq!(decoded.filename, "rt.bin");
        assert_eq!(decoded.mime_type, "text/plain");
        assert_eq!(decoded.total_size, 50);
        assert_eq!(decoded.chunk_count, 1);
        assert_eq!(&decoded.sha256[..], &sender.manifest.sha256[..]);
    }

    #[test]
    fn chunk_cbor_round_trip() {
        let chunk = AttachmentChunk {
            attachment_id: 12,
            chunk_index: 1,
            chunk_count: 3,
            data: ByteBuf::from(vec![1u8, 2, 3]),
        };
        let decoded = AttachmentChunk::from_cbor(&chunk.to_cbor()).unwrap();
        assert_eq!(decoded.attachment_id, 12);
        assert_eq!(decoded.chunk_index, 1);
        assert_eq!(decoded.chunk_count, 3);
        assert_eq!(&decoded.data[..], &[1u8, 2, 3][..]);
    }
}