1use std::collections::HashMap;
13use thiserror::Error;
14
15#[derive(Debug, Error)]
16#[non_exhaustive]
17pub enum ChunkError {
18 #[error("invalid chunk: idx={idx} >= total={total}")]
19 InvalidIndex { idx: u32, total: u32 },
20 #[error("size mismatch for message {message_id}: previously {previous}, now {now}")]
21 SizeMismatch { message_id: u64, previous: u32, now: u32 },
22}
23
/// One piece of a larger message.
///
/// Wire layout produced by [`Chunk::to_wire`] and parsed by [`Chunk::from_wire`]:
/// a 16-byte little-endian header — `message_id` (u64), `chunk_idx` (u32),
/// `total_chunks` (u32) — followed by the raw payload bytes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Chunk {
    /// Identifies the message this chunk belongs to.
    pub message_id: u64,
    /// Zero-based position of this chunk within its message.
    pub chunk_idx: u32,
    /// Number of chunks the message was split into.
    pub total_chunks: u32,
    /// This chunk's share of the message bytes.
    pub payload: Vec<u8>,
}

impl Chunk {
    /// Serializes the chunk as the 16-byte little-endian header followed by the payload.
    pub fn to_wire(&self) -> Vec<u8> {
        let id = self.message_id.to_le_bytes();
        let idx = self.chunk_idx.to_le_bytes();
        let total = self.total_chunks.to_le_bytes();
        let mut wire = Vec::with_capacity(id.len() + idx.len() + total.len() + self.payload.len());
        for part in [&id[..], &idx[..], &total[..], &self.payload[..]] {
            wire.extend_from_slice(part);
        }
        wire
    }

    /// Parses bytes produced by [`Chunk::to_wire`].
    ///
    /// Returns `None` when fewer than 16 bytes (one full header) are given;
    /// every byte after the header becomes the payload, which may be empty.
    /// The header fields are not validated here.
    pub fn from_wire(bytes: &[u8]) -> Option<Self> {
        let header = bytes.get(..16)?;
        // Decodes the little-endian u32 stored at `at..at + 4` within the header.
        let read_u32 = |at: usize| {
            let mut raw = [0u8; 4];
            raw.copy_from_slice(&header[at..at + 4]);
            u32::from_le_bytes(raw)
        };
        let mut id_raw = [0u8; 8];
        id_raw.copy_from_slice(&header[..8]);
        Some(Self {
            message_id: u64::from_le_bytes(id_raw),
            chunk_idx: read_u32(8),
            total_chunks: read_u32(12),
            payload: bytes[16..].to_vec(),
        })
    }
}
64
65pub struct Chunker {
67 pub chunk_size: usize,
68}
69
70impl Chunker {
71 pub fn new(chunk_size: usize) -> Self {
72 assert!(chunk_size >= 1, "chunk_size must be >= 1");
73 Self { chunk_size }
74 }
75
76 pub fn split(&self, message_id: u64, payload: &[u8]) -> Vec<Chunk> {
80 if payload.is_empty() {
81 return vec![Chunk { message_id, chunk_idx: 0, total_chunks: 1, payload: Vec::new() }];
82 }
83 let total = payload.len().div_ceil(self.chunk_size);
84 let mut out = Vec::with_capacity(total);
85 for (i, chunk_payload) in payload.chunks(self.chunk_size).enumerate() {
86 out.push(Chunk {
87 message_id,
88 chunk_idx: i as u32,
89 total_chunks: total as u32,
90 payload: chunk_payload.to_vec(),
91 });
92 }
93 out
94 }
95}
96
97#[derive(Default)]
100pub struct Reassembler {
101 pending: HashMap<u64, Pending>,
102}
103
104struct Pending {
105 total: u32,
106 chunks: Vec<Option<Vec<u8>>>,
107 received: u32,
108 started_at: std::time::Instant,
109}
110
111impl Reassembler {
112 pub fn new() -> Self {
113 Self::default()
114 }
115
116 pub fn push(&mut self, chunk: Chunk) -> Result<Option<Vec<u8>>, ChunkError> {
119 if chunk.total_chunks == 0 || chunk.chunk_idx >= chunk.total_chunks {
120 return Err(ChunkError::InvalidIndex { idx: chunk.chunk_idx, total: chunk.total_chunks });
121 }
122 let entry = self.pending.entry(chunk.message_id).or_insert_with(|| Pending {
123 total: chunk.total_chunks,
124 chunks: (0..chunk.total_chunks).map(|_| None).collect(),
125 received: 0,
126 started_at: std::time::Instant::now(),
127 });
128 if entry.total != chunk.total_chunks {
129 return Err(ChunkError::SizeMismatch {
130 message_id: chunk.message_id,
131 previous: entry.total,
132 now: chunk.total_chunks,
133 });
134 }
135 let slot = &mut entry.chunks[chunk.chunk_idx as usize];
136 if slot.is_none() {
137 *slot = Some(chunk.payload);
138 entry.received += 1;
139 }
140 if entry.received < entry.total {
141 return Ok(None);
142 }
143 let pending = self.pending.remove(&chunk.message_id).expect("just present");
145 let total_len: usize = pending.chunks.iter().filter_map(|c| c.as_ref()).map(|v| v.len()).sum();
146 let mut out = Vec::with_capacity(total_len);
147 for buf in pending.chunks.into_iter().flatten() {
148 out.extend_from_slice(&buf);
149 }
150 Ok(Some(out))
151 }
152
153 pub fn pending_message_count(&self) -> usize {
154 self.pending.len()
155 }
156
157 pub fn gc_expired(&mut self, older_than: std::time::Duration) -> usize {
162 let now = std::time::Instant::now();
163 let before = self.pending.len();
164 self.pending.retain(|_, p| now.duration_since(p.started_at) < older_than);
165 before - self.pending.len()
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn split_packs_chunks_with_indices() {
        let c = Chunker::new(3);
        let chunks = c.split(42, b"abcdefgh");
        assert_eq!(chunks.len(), 3);
        assert_eq!(
            chunks[0],
            Chunk { message_id: 42, chunk_idx: 0, total_chunks: 3, payload: b"abc".to_vec() }
        );
        assert_eq!(
            chunks[1],
            Chunk { message_id: 42, chunk_idx: 1, total_chunks: 3, payload: b"def".to_vec() }
        );
        assert_eq!(
            chunks[2],
            Chunk { message_id: 42, chunk_idx: 2, total_chunks: 3, payload: b"gh".to_vec() }
        );
    }

    #[test]
    fn split_exact_multiple_has_no_trailing_empty_chunk() {
        let chunks = Chunker::new(2).split(5, b"abcd");
        assert_eq!(chunks.len(), 2);
        assert!(chunks.iter().all(|c| c.total_chunks == 2 && c.payload.len() == 2));
    }

    #[test]
    fn empty_payload_yields_single_chunk() {
        let chunks = Chunker::new(8).split(7, b"");
        assert_eq!(chunks.len(), 1);
        assert!(chunks[0].payload.is_empty());
        assert_eq!(chunks[0].total_chunks, 1);
    }

    #[test]
    fn round_trip_split_then_reassemble() {
        let c = Chunker::new(4);
        let payload = b"hello world! this is a longer payload than 4 bytes.";
        let chunks = c.split(1, payload);
        let mut r = Reassembler::new();
        let mut completed = None;
        for chunk in chunks {
            if let Some(full) = r.push(chunk).unwrap() {
                completed = Some(full);
            }
        }
        assert_eq!(completed.unwrap(), payload);
        assert_eq!(r.pending_message_count(), 0);
    }

    #[test]
    fn out_of_order_chunks_reassemble_correctly() {
        let mut chunks = Chunker::new(2).split(7, b"abcdef");
        chunks.reverse();
        let mut r = Reassembler::new();
        let mut full = None;
        for chunk in chunks {
            if let Some(payload) = r.push(chunk).unwrap() {
                full = Some(payload);
            }
        }
        assert_eq!(full.unwrap(), b"abcdef");
    }

    #[test]
    fn duplicate_chunks_are_idempotent() {
        let chunks = Chunker::new(2).split(9, b"abcd");
        let mut r = Reassembler::new();
        assert_eq!(r.push(chunks[0].clone()).unwrap(), None);
        // A repeated chunk must neither complete the message nor count twice.
        assert_eq!(r.push(chunks[0].clone()).unwrap(), None);
        assert_eq!(r.push(chunks[1].clone()).unwrap(), Some(b"abcd".to_vec()));
    }

    #[test]
    fn invalid_index_errors() {
        let mut r = Reassembler::new();
        let bad = Chunk { message_id: 1, chunk_idx: 5, total_chunks: 3, payload: vec![] };
        assert!(matches!(r.push(bad), Err(ChunkError::InvalidIndex { .. })));
    }

    #[test]
    fn zero_total_chunks_errors() {
        let mut r = Reassembler::new();
        let bad = Chunk { message_id: 1, chunk_idx: 0, total_chunks: 0, payload: vec![] };
        assert!(matches!(r.push(bad), Err(ChunkError::InvalidIndex { .. })));
        assert_eq!(r.pending_message_count(), 0);
    }

    #[test]
    fn size_mismatch_errors() {
        let mut r = Reassembler::new();
        let first = Chunk { message_id: 1, chunk_idx: 0, total_chunks: 3, payload: vec![1] };
        assert_eq!(r.push(first).unwrap(), None);
        let conflicting = Chunk { message_id: 1, chunk_idx: 1, total_chunks: 4, payload: vec![2] };
        assert!(matches!(r.push(conflicting), Err(ChunkError::SizeMismatch { .. })));
        // The earlier partial message is kept after the rejected chunk.
        assert_eq!(r.pending_message_count(), 1);
    }

    #[test]
    fn wire_round_trip() {
        let c = Chunk { message_id: 0xdeadbeef, chunk_idx: 3, total_chunks: 7, payload: b"hello".to_vec() };
        let bytes = c.to_wire();
        let parsed = Chunk::from_wire(&bytes).unwrap();
        assert_eq!(parsed, c);
    }

    #[test]
    fn from_wire_rejects_truncated_header() {
        assert_eq!(Chunk::from_wire(&[0u8; 15]), None);
        assert!(Chunk::from_wire(&[0u8; 16]).unwrap().payload.is_empty());
    }

    #[test]
    fn gc_expired_evicts_old_partials() {
        let mut r = Reassembler::new();
        let first = Chunk { message_id: 99, chunk_idx: 0, total_chunks: 2, payload: vec![1] };
        assert_eq!(r.push(first).unwrap(), None);
        assert_eq!(r.pending_message_count(), 1);
        // No age is below zero, so a zero threshold evicts everything without sleeping.
        let swept = r.gc_expired(Duration::ZERO);
        assert_eq!(swept, 1);
        assert_eq!(r.pending_message_count(), 0);
    }

    #[test]
    fn gc_expired_keeps_fresh_partials() {
        let mut r = Reassembler::new();
        let first = Chunk { message_id: 1, chunk_idx: 0, total_chunks: 2, payload: vec![1] };
        assert_eq!(r.push(first).unwrap(), None);
        let swept = r.gc_expired(Duration::from_secs(60));
        assert_eq!(swept, 0);
        assert_eq!(r.pending_message_count(), 1);
    }
}