1use std::collections::HashMap;
15use thiserror::Error;
16
/// Reasons a chunk can be rejected by [`Reassembler::push`].
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum ChunkError {
    /// The chunk's `chunk_idx` is not below its `total_chunks` (this also
    /// covers a declared total of zero).
    #[error("invalid chunk: idx={idx} >= total={total}")]
    InvalidIndex { idx: u32, total: u32 },
    /// The chunk's `total_chunks` disagrees with the total recorded for the
    /// same `message_id` by an earlier chunk.
    #[error("size mismatch for message {message_id}: previously {previous}, now {now}")]
    SizeMismatch { message_id: u64, previous: u32, now: u32 },
}
25
/// One piece of a message, as produced by [`Chunker::split`].
///
/// All chunks of one message share `message_id` and `total_chunks`;
/// `chunk_idx` runs from 0 to `total_chunks - 1`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Chunk {
    /// Identifier shared by every chunk of the same message.
    pub message_id: u64,
    /// Zero-based position of this chunk within its message.
    pub chunk_idx: u32,
    /// Number of chunks the message was split into.
    pub total_chunks: u32,
    /// This chunk's slice of the message bytes.
    pub payload: Vec<u8>,
}
34
35impl Chunk {
36 pub fn to_wire(&self) -> Vec<u8> {
40 let mut buf = Vec::with_capacity(16 + self.payload.len());
41 buf.extend_from_slice(&self.message_id.to_le_bytes());
42 buf.extend_from_slice(&self.chunk_idx.to_le_bytes());
43 buf.extend_from_slice(&self.total_chunks.to_le_bytes());
44 buf.extend_from_slice(&self.payload);
45 buf
46 }
47
48 pub fn from_wire(bytes: &[u8]) -> Option<Self> {
49 if bytes.len() < 16 {
50 return None;
51 }
52 let mut id_bytes = [0u8; 8];
53 id_bytes.copy_from_slice(&bytes[..8]);
54 let mut idx_bytes = [0u8; 4];
55 idx_bytes.copy_from_slice(&bytes[8..12]);
56 let mut total_bytes = [0u8; 4];
57 total_bytes.copy_from_slice(&bytes[12..16]);
58 Some(Self {
59 message_id: u64::from_le_bytes(id_bytes),
60 chunk_idx: u32::from_le_bytes(idx_bytes),
61 total_chunks: u32::from_le_bytes(total_bytes),
62 payload: bytes[16..].to_vec(),
63 })
64 }
65}
66
/// Splits byte payloads into [`Chunk`]s of at most `chunk_size` bytes each.
///
/// Prefer [`Chunker::new`], which rejects a zero size. The field is public,
/// so a zero value can still be assigned directly; `split` panics on a
/// non-empty payload in that case.
#[derive(Debug, Clone)]
pub struct Chunker {
    /// Maximum number of payload bytes per chunk; must be at least 1.
    pub chunk_size: usize,
}
71
72impl Chunker {
73 pub fn new(chunk_size: usize) -> Self {
74 assert!(chunk_size >= 1, "chunk_size must be >= 1");
75 Self { chunk_size }
76 }
77
78 pub fn split(&self, message_id: u64, payload: &[u8]) -> Vec<Chunk> {
82 if payload.is_empty() {
83 return vec![Chunk { message_id, chunk_idx: 0, total_chunks: 1, payload: Vec::new() }];
84 }
85 let total = payload.len().div_ceil(self.chunk_size);
86 let mut out = Vec::with_capacity(total);
87 for (i, chunk_payload) in payload.chunks(self.chunk_size).enumerate() {
88 out.push(Chunk {
89 message_id,
90 chunk_idx: i as u32,
91 total_chunks: total as u32,
92 payload: chunk_payload.to_vec(),
93 });
94 }
95 out
96 }
97}
98
99#[derive(Default)]
102pub struct Reassembler {
103 pending: HashMap<u64, Pending>,
104}
105
106struct Pending {
107 total: u32,
108 chunks: Vec<Option<Vec<u8>>>,
109 received: u32,
110}
111
112impl Reassembler {
113 pub fn new() -> Self {
114 Self::default()
115 }
116
117 pub fn push(&mut self, chunk: Chunk) -> Result<Option<Vec<u8>>, ChunkError> {
120 if chunk.total_chunks == 0 || chunk.chunk_idx >= chunk.total_chunks {
121 return Err(ChunkError::InvalidIndex { idx: chunk.chunk_idx, total: chunk.total_chunks });
122 }
123 let entry = self.pending.entry(chunk.message_id).or_insert_with(|| Pending {
124 total: chunk.total_chunks,
125 chunks: (0..chunk.total_chunks).map(|_| None).collect(),
126 received: 0,
127 });
128 if entry.total != chunk.total_chunks {
129 return Err(ChunkError::SizeMismatch {
130 message_id: chunk.message_id,
131 previous: entry.total,
132 now: chunk.total_chunks,
133 });
134 }
135 let slot = &mut entry.chunks[chunk.chunk_idx as usize];
136 if slot.is_none() {
137 *slot = Some(chunk.payload);
138 entry.received += 1;
139 }
140 if entry.received < entry.total {
141 return Ok(None);
142 }
143 let pending = self.pending.remove(&chunk.message_id).expect("just present");
145 let total_len: usize = pending.chunks.iter().filter_map(|c| c.as_ref()).map(|v| v.len()).sum();
146 let mut out = Vec::with_capacity(total_len);
147 for buf in pending.chunks.into_iter().flatten() {
148 out.extend_from_slice(&buf);
149 }
150 Ok(Some(out))
151 }
152
153 pub fn pending_message_count(&self) -> usize {
154 self.pending.len()
155 }
156}
157
// Unit tests for splitting, wire encoding and reassembly.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn split_packs_chunks_with_indices() {
        let c = Chunker::new(3);
        let chunks = c.split(42, b"abcdefgh");
        assert_eq!(chunks.len(), 3);
        assert_eq!(
            chunks[0],
            Chunk { message_id: 42, chunk_idx: 0, total_chunks: 3, payload: b"abc".to_vec() }
        );
        assert_eq!(
            chunks[1],
            Chunk { message_id: 42, chunk_idx: 1, total_chunks: 3, payload: b"def".to_vec() }
        );
        assert_eq!(
            chunks[2],
            Chunk { message_id: 42, chunk_idx: 2, total_chunks: 3, payload: b"gh".to_vec() }
        );
    }

    #[test]
    fn empty_payload_yields_single_chunk() {
        let chunks = Chunker::new(8).split(7, b"");
        assert_eq!(chunks.len(), 1);
        assert!(chunks[0].payload.is_empty());
        assert_eq!(chunks[0].total_chunks, 1);
    }

    #[test]
    fn empty_payload_round_trips() {
        let mut r = Reassembler::new();
        let mut chunks = Chunker::new(8).split(7, b"");
        let only = chunks.remove(0);
        assert_eq!(r.push(only).unwrap(), Some(Vec::new()));
        assert_eq!(r.pending_message_count(), 0);
    }

    #[test]
    #[should_panic(expected = "chunk_size must be >= 1")]
    fn chunker_rejects_zero_size() {
        let _ = Chunker::new(0);
    }

    #[test]
    fn round_trip_split_then_reassemble() {
        let c = Chunker::new(4);
        let payload = b"hello world! this is a longer payload than 4 bytes.";
        let chunks = c.split(1, payload);
        let mut r = Reassembler::new();
        let mut completed = None;
        for chunk in chunks {
            if let Some(full) = r.push(chunk).unwrap() {
                completed = Some(full);
            }
        }
        assert_eq!(completed.unwrap(), payload);
        assert_eq!(r.pending_message_count(), 0);
    }

    #[test]
    fn out_of_order_chunks_reassemble_correctly() {
        let c = Chunker::new(2);
        let mut chunks = c.split(7, b"abcdef");
        chunks.reverse();
        let mut r = Reassembler::new();
        let mut full = None;
        for chunk in chunks {
            if let Some(payload) = r.push(chunk).unwrap() {
                full = Some(payload);
            }
        }
        assert_eq!(full.unwrap(), b"abcdef");
    }

    #[test]
    fn duplicate_chunks_are_idempotent() {
        let c = Chunker::new(2);
        let chunks = c.split(9, b"abcd");
        let mut r = Reassembler::new();
        assert_eq!(r.push(chunks[0].clone()).unwrap(), None);
        assert_eq!(r.push(chunks[0].clone()).unwrap(), None);
        assert_eq!(r.pending_message_count(), 1);
        let full = r.push(chunks[1].clone()).unwrap();
        assert_eq!(full.unwrap(), b"abcd");
        assert_eq!(r.pending_message_count(), 0);
    }

    #[test]
    fn invalid_index_errors() {
        let mut r = Reassembler::new();
        let bad = Chunk { message_id: 1, chunk_idx: 5, total_chunks: 3, payload: vec![] };
        let result = r.push(bad);
        assert!(matches!(result, Err(ChunkError::InvalidIndex { .. })));
    }

    #[test]
    fn zero_total_chunks_is_invalid() {
        let mut r = Reassembler::new();
        let bad = Chunk { message_id: 1, chunk_idx: 0, total_chunks: 0, payload: vec![] };
        assert!(matches!(r.push(bad), Err(ChunkError::InvalidIndex { idx: 0, total: 0 })));
        assert_eq!(r.pending_message_count(), 0);
    }

    #[test]
    fn size_mismatch_errors() {
        let mut r = Reassembler::new();
        let first = Chunk { message_id: 1, chunk_idx: 0, total_chunks: 3, payload: vec![1] };
        assert_eq!(r.push(first).unwrap(), None);
        let conflicting = Chunk { message_id: 1, chunk_idx: 1, total_chunks: 4, payload: vec![2] };
        assert!(matches!(
            r.push(conflicting),
            Err(ChunkError::SizeMismatch { message_id: 1, previous: 3, now: 4 })
        ));
    }

    #[test]
    fn wire_round_trip() {
        let c = Chunk { message_id: 0xdeadbeef, chunk_idx: 3, total_chunks: 7, payload: b"hello".to_vec() };
        let bytes = c.to_wire();
        let parsed = Chunk::from_wire(&bytes).unwrap();
        assert_eq!(parsed, c);
    }

    #[test]
    fn from_wire_rejects_short_input() {
        assert_eq!(Chunk::from_wire(&[0u8; 15]), None);
        assert_eq!(Chunk::from_wire(&[]), None);
    }

    #[test]
    fn from_wire_accepts_header_only() {
        let parsed = Chunk::from_wire(&[0u8; 16]).unwrap();
        assert_eq!(
            parsed,
            Chunk { message_id: 0, chunk_idx: 0, total_chunks: 0, payload: vec![] }
        );
    }
}