Skip to main content

atomr_remote/
chunking.rs

1//! Message chunking for payloads that exceed `maximum-frame-size`.
2//!
//! Fixed-size splitting of one payload into indexed fragments. Senders whose
//! payloads exceed `chunk_size` use [`Chunker::split`] to produce ordered
//! chunks; receivers feed each chunk (in any arrival order) to
//! [`Reassembler::push`] until it returns `Ok(Some(Vec<u8>))`.
7//!
8//! The wire envelope around chunks is a tiny `(message_id, chunk_idx,
9//! total_chunks, payload)` tuple. The remote codec wraps chunks in
10//! `AkkaPdu::Payload` like any other message.
11
12use std::collections::HashMap;
13use thiserror::Error;
14
15#[derive(Debug, Error)]
16#[non_exhaustive]
17pub enum ChunkError {
18    #[error("invalid chunk: idx={idx} >= total={total}")]
19    InvalidIndex { idx: u32, total: u32 },
20    #[error("size mismatch for message {message_id}: previously {previous}, now {now}")]
21    SizeMismatch { message_id: u64, previous: u32, now: u32 },
22}
23
/// One fragment produced by [`Chunker::split`].
///
/// On the wire a chunk is a fixed [`Chunk::HEADER_LEN`]-byte little-endian
/// header followed immediately by the fragment bytes (see [`Chunk::to_wire`]).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Chunk {
    /// Identifier shared by every fragment of the same logical message.
    pub message_id: u64,
    /// Zero-based position of this fragment within its message.
    pub chunk_idx: u32,
    /// Number of fragments the message was split into.
    pub total_chunks: u32,
    /// This fragment's slice of the original payload.
    pub payload: Vec<u8>,
}

impl Chunk {
    /// Size in bytes of the header written by [`Chunk::to_wire`]:
    /// `[message_id u64 le][chunk_idx u32 le][total_chunks u32 le]`.
    pub const HEADER_LEN: usize = 16;

    /// Serialize into one contiguous buffer: the [`Chunk::HEADER_LEN`]-byte
    /// little-endian header immediately followed by `payload`, ready for the
    /// remote codec to frame.
    pub fn to_wire(&self) -> Vec<u8> {
        let mut buf = Vec::with_capacity(Self::HEADER_LEN + self.payload.len());
        buf.extend_from_slice(&self.message_id.to_le_bytes());
        buf.extend_from_slice(&self.chunk_idx.to_le_bytes());
        buf.extend_from_slice(&self.total_chunks.to_le_bytes());
        buf.extend_from_slice(&self.payload);
        buf
    }

    /// Parse a buffer produced by [`Chunk::to_wire`].
    ///
    /// Returns `None` when `bytes` is shorter than [`Chunk::HEADER_LEN`].
    /// Everything after the header becomes the payload (possibly empty).
    /// Header fields are not cross-checked here; [`Reassembler::push`]
    /// rejects inconsistent indices.
    pub fn from_wire(bytes: &[u8]) -> Option<Self> {
        if bytes.len() < Self::HEADER_LEN {
            return None;
        }
        let (header, payload) = bytes.split_at(Self::HEADER_LEN);
        let mut message_id = [0u8; 8];
        let mut chunk_idx = [0u8; 4];
        let mut total_chunks = [0u8; 4];
        message_id.copy_from_slice(&header[..8]);
        chunk_idx.copy_from_slice(&header[8..12]);
        total_chunks.copy_from_slice(&header[12..]);
        Some(Self {
            message_id: u64::from_le_bytes(message_id),
            chunk_idx: u32::from_le_bytes(chunk_idx),
            total_chunks: u32::from_le_bytes(total_chunks),
            payload: payload.to_vec(),
        })
    }
}
64
65/// Sender-side: split large payloads into ordered fragments.
66pub struct Chunker {
67    pub chunk_size: usize,
68}
69
70impl Chunker {
71    pub fn new(chunk_size: usize) -> Self {
72        assert!(chunk_size >= 1, "chunk_size must be >= 1");
73        Self { chunk_size }
74    }
75
76    /// Split `payload` into fragments. Each fragment shares the same
77    /// `message_id`. Returns at least one chunk even for an empty
78    /// payload (`total_chunks = 1`, empty payload).
79    pub fn split(&self, message_id: u64, payload: &[u8]) -> Vec<Chunk> {
80        if payload.is_empty() {
81            return vec![Chunk { message_id, chunk_idx: 0, total_chunks: 1, payload: Vec::new() }];
82        }
83        let total = payload.len().div_ceil(self.chunk_size);
84        let mut out = Vec::with_capacity(total);
85        for (i, chunk_payload) in payload.chunks(self.chunk_size).enumerate() {
86            out.push(Chunk {
87                message_id,
88                chunk_idx: i as u32,
89                total_chunks: total as u32,
90                payload: chunk_payload.to_vec(),
91            });
92        }
93        out
94    }
95}
96
97/// Receiver-side: collect chunks for the same `message_id` until the
98/// full set arrives, then return the reassembled payload.
99#[derive(Default)]
100pub struct Reassembler {
101    pending: HashMap<u64, Pending>,
102}
103
104struct Pending {
105    total: u32,
106    chunks: Vec<Option<Vec<u8>>>,
107    received: u32,
108    started_at: std::time::Instant,
109}
110
111impl Reassembler {
112    pub fn new() -> Self {
113        Self::default()
114    }
115
116    /// Feed one chunk. Returns `Some(payload)` when the message is
117    /// complete, `None` while still waiting for siblings.
118    pub fn push(&mut self, chunk: Chunk) -> Result<Option<Vec<u8>>, ChunkError> {
119        if chunk.total_chunks == 0 || chunk.chunk_idx >= chunk.total_chunks {
120            return Err(ChunkError::InvalidIndex { idx: chunk.chunk_idx, total: chunk.total_chunks });
121        }
122        let entry = self.pending.entry(chunk.message_id).or_insert_with(|| Pending {
123            total: chunk.total_chunks,
124            chunks: (0..chunk.total_chunks).map(|_| None).collect(),
125            received: 0,
126            started_at: std::time::Instant::now(),
127        });
128        if entry.total != chunk.total_chunks {
129            return Err(ChunkError::SizeMismatch {
130                message_id: chunk.message_id,
131                previous: entry.total,
132                now: chunk.total_chunks,
133            });
134        }
135        let slot = &mut entry.chunks[chunk.chunk_idx as usize];
136        if slot.is_none() {
137            *slot = Some(chunk.payload);
138            entry.received += 1;
139        }
140        if entry.received < entry.total {
141            return Ok(None);
142        }
143        // All chunks present — concatenate in order.
144        let pending = self.pending.remove(&chunk.message_id).expect("just present");
145        let total_len: usize = pending.chunks.iter().filter_map(|c| c.as_ref()).map(|v| v.len()).sum();
146        let mut out = Vec::with_capacity(total_len);
147        for buf in pending.chunks.into_iter().flatten() {
148            out.extend_from_slice(&buf);
149        }
150        Ok(Some(out))
151    }
152
153    pub fn pending_message_count(&self) -> usize {
154        self.pending.len()
155    }
156
157    /// Discard partial reassemblies older than `older_than`. Returns
158    /// the count of entries swept. Call on a low-frequency tick to
159    /// keep the reassembler bounded against peers that fall silent
160    /// mid-message.
161    pub fn gc_expired(&mut self, older_than: std::time::Duration) -> usize {
162        let now = std::time::Instant::now();
163        let before = self.pending.len();
164        self.pending.retain(|_, p| now.duration_since(p.started_at) < older_than);
165        before - self.pending.len()
166    }
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Push every chunk in order, returning the last completed payload.
    fn reassemble_all(r: &mut Reassembler, chunks: Vec<Chunk>) -> Option<Vec<u8>> {
        let mut completed = None;
        for chunk in chunks {
            if let Some(full) = r.push(chunk).unwrap() {
                completed = Some(full);
            }
        }
        completed
    }

    #[test]
    fn split_packs_chunks_with_indices() {
        let c = Chunker::new(3);
        let chunks = c.split(42, b"abcdefgh");
        assert_eq!(chunks.len(), 3);
        assert_eq!(
            chunks[0],
            Chunk { message_id: 42, chunk_idx: 0, total_chunks: 3, payload: b"abc".to_vec() }
        );
        assert_eq!(
            chunks[1],
            Chunk { message_id: 42, chunk_idx: 1, total_chunks: 3, payload: b"def".to_vec() }
        );
        assert_eq!(
            chunks[2],
            Chunk { message_id: 42, chunk_idx: 2, total_chunks: 3, payload: b"gh".to_vec() }
        );
    }

    #[test]
    fn empty_payload_yields_single_chunk() {
        let chunks = Chunker::new(8).split(7, b"");
        assert_eq!(chunks.len(), 1);
        assert!(chunks[0].payload.is_empty());
        assert_eq!(chunks[0].total_chunks, 1);
    }

    #[test]
    fn empty_payload_round_trips() {
        let mut r = Reassembler::new();
        let full = reassemble_all(&mut r, Chunker::new(8).split(3, b""));
        assert_eq!(full, Some(Vec::new()));
        assert_eq!(r.pending_message_count(), 0);
    }

    #[test]
    #[should_panic(expected = "chunk_size must be >= 1")]
    fn chunker_rejects_zero_chunk_size() {
        let _ = Chunker::new(0);
    }

    #[test]
    fn round_trip_split_then_reassemble() {
        let payload = b"hello world! this is a longer payload than 4 bytes.";
        let mut r = Reassembler::new();
        let full = reassemble_all(&mut r, Chunker::new(4).split(1, payload));
        assert_eq!(full.unwrap(), payload);
        assert_eq!(r.pending_message_count(), 0);
    }

    #[test]
    fn out_of_order_chunks_reassemble_correctly() {
        let mut chunks = Chunker::new(2).split(7, b"abcdef");
        chunks.reverse(); // hand them to the receiver in reverse order
        let mut r = Reassembler::new();
        assert_eq!(reassemble_all(&mut r, chunks).unwrap(), b"abcdef");
    }

    #[test]
    fn interleaved_messages_reassemble_independently() {
        let c = Chunker::new(2);
        let a = c.split(1, b"abcd");
        let b = c.split(2, b"wxyz");
        let mut r = Reassembler::new();
        assert_eq!(r.push(a[0].clone()).unwrap(), None);
        assert_eq!(r.push(b[1].clone()).unwrap(), None);
        assert_eq!(r.pending_message_count(), 2);
        assert_eq!(r.push(b[0].clone()).unwrap(), Some(b"wxyz".to_vec()));
        assert_eq!(r.push(a[1].clone()).unwrap(), Some(b"abcd".to_vec()));
        assert_eq!(r.pending_message_count(), 0);
    }

    #[test]
    fn duplicate_chunks_are_idempotent() {
        let chunks = Chunker::new(2).split(9, b"abcd");
        let mut r = Reassembler::new();
        assert_eq!(r.push(chunks[0].clone()).unwrap(), None);
        // Re-push the same chunk; it must not count towards completion.
        assert_eq!(r.push(chunks[0].clone()).unwrap(), None);
        assert_eq!(r.push(chunks[1].clone()).unwrap(), Some(b"abcd".to_vec()));
    }

    #[test]
    fn invalid_index_errors() {
        let mut r = Reassembler::new();
        let bad = Chunk { message_id: 1, chunk_idx: 5, total_chunks: 3, payload: vec![] };
        assert!(matches!(r.push(bad), Err(ChunkError::InvalidIndex { idx: 5, total: 3 })));
        assert_eq!(r.pending_message_count(), 0);
    }

    #[test]
    fn zero_total_chunks_is_invalid() {
        let mut r = Reassembler::new();
        let bad = Chunk { message_id: 1, chunk_idx: 0, total_chunks: 0, payload: vec![] };
        assert!(matches!(r.push(bad), Err(ChunkError::InvalidIndex { idx: 0, total: 0 })));
        assert_eq!(r.pending_message_count(), 0);
    }

    #[test]
    fn size_mismatch_errors() {
        let mut r = Reassembler::new();
        let first = Chunk { message_id: 1, chunk_idx: 0, total_chunks: 3, payload: vec![1] };
        assert_eq!(r.push(first).unwrap(), None);
        let conflicting = Chunk { message_id: 1, chunk_idx: 1, total_chunks: 4, payload: vec![2] };
        assert!(matches!(
            r.push(conflicting),
            Err(ChunkError::SizeMismatch { message_id: 1, previous: 3, now: 4 })
        ));
    }

    #[test]
    fn wire_round_trip() {
        let c = Chunk { message_id: 0xdeadbeef, chunk_idx: 3, total_chunks: 7, payload: b"hello".to_vec() };
        let bytes = c.to_wire();
        let parsed = Chunk::from_wire(&bytes).unwrap();
        assert_eq!(parsed, c);
    }

    #[test]
    fn from_wire_rejects_short_input() {
        assert!(Chunk::from_wire(&[0u8; 15]).is_none());
        let header_only = Chunk::from_wire(&[0u8; 16]).unwrap();
        assert!(header_only.payload.is_empty());
    }

    #[test]
    fn gc_expired_evicts_old_partials() {
        let mut r = Reassembler::new();
        // Insert a partial that will be aged out.
        let partial = Chunk { message_id: 99, chunk_idx: 0, total_chunks: 2, payload: vec![1] };
        assert_eq!(r.push(partial).unwrap(), None);
        assert_eq!(r.pending_message_count(), 1);
        // A zero-age threshold evicts every partial immediately: no entry's
        // age can be strictly less than zero, so no sleep is needed.
        let swept = r.gc_expired(Duration::ZERO);
        assert_eq!(swept, 1);
        assert_eq!(r.pending_message_count(), 0);
    }

    #[test]
    fn gc_expired_keeps_fresh_partials() {
        let mut r = Reassembler::new();
        let partial = Chunk { message_id: 1, chunk_idx: 0, total_chunks: 2, payload: vec![1] };
        assert_eq!(r.push(partial).unwrap(), None);
        let swept = r.gc_expired(Duration::from_secs(60));
        assert_eq!(swept, 0);
        assert_eq!(r.pending_message_count(), 1);
    }
}