// atomr_remote/chunking.rs
//! Message chunking for payloads that exceed `maximum-frame-size`.
//!
//! Phase 5.F of `docs/full-port-plan.md`. Akka.NET parity:
//! `Akka.Remote.Configuration.Maximum-Frame-Size` + per-PDU
//! length-prefix split. Senders that produce payloads larger than
//! `chunk_size` use [`Chunker::split`] to fragment into ordered
//! chunks; receivers feed each chunk to [`Reassembler::push`] until
//! `Some(Vec<u8>)` comes back.
//!
//! The wire envelope around chunks is a tiny `(message_id, chunk_idx,
//! total_chunks, payload)` tuple. The remote codec wraps chunks in
//! `AkkaPdu::Payload` like any other message.
13
14use std::collections::HashMap;
15use thiserror::Error;
16
/// Errors surfaced by [`Reassembler::push`] when a chunk is inconsistent
/// with either its own header or previously seen siblings.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum ChunkError {
    /// `chunk_idx` was not strictly less than `total_chunks`
    /// (also raised when `total_chunks` is zero, since no index fits).
    #[error("invalid chunk: idx={idx} >= total={total}")]
    InvalidIndex { idx: u32, total: u32 },
    /// A chunk for `message_id` reported a different `total_chunks`
    /// than the first chunk seen for that message.
    #[error("size mismatch for message {message_id}: previously {previous}, now {now}")]
    SizeMismatch { message_id: u64, previous: u32, now: u32 },
}
25
/// One fragment produced by [`Chunker::split`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Chunk {
    /// Correlates sibling fragments of the same logical message.
    pub message_id: u64,
    /// Zero-based position of this fragment within the message.
    pub chunk_idx: u32,
    /// Total number of fragments the message was split into.
    pub total_chunks: u32,
    /// The slice of the original payload carried by this fragment.
    pub payload: Vec<u8>,
}

/// Size of the wire header: `[message_id u64][chunk_idx u32][total u32]`,
/// all little-endian.
const WIRE_HEADER_LEN: usize = 16;

impl Chunk {
    /// Serialize to a single contiguous wire buffer: a 16-byte
    /// little-endian header (`message_id u64`, `chunk_idx u32`,
    /// `total_chunks u32`) followed immediately by the raw payload.
    /// The remote codec frames the result on the wire.
    pub fn to_wire(&self) -> Vec<u8> {
        let mut buf = Vec::with_capacity(WIRE_HEADER_LEN + self.payload.len());
        buf.extend_from_slice(&self.message_id.to_le_bytes());
        buf.extend_from_slice(&self.chunk_idx.to_le_bytes());
        buf.extend_from_slice(&self.total_chunks.to_le_bytes());
        buf.extend_from_slice(&self.payload);
        buf
    }

    /// Parse a buffer produced by [`Chunk::to_wire`].
    ///
    /// Returns `None` when `bytes` is too short to hold the 16-byte
    /// header; everything after the header becomes the payload.
    pub fn from_wire(bytes: &[u8]) -> Option<Self> {
        if bytes.len() < WIRE_HEADER_LEN {
            return None;
        }
        let (header, payload) = bytes.split_at(WIRE_HEADER_LEN);
        // The length check above guarantees each sub-slice is exact, so
        // the `try_into` array conversions cannot fail.
        let message_id = u64::from_le_bytes(header[..8].try_into().expect("8-byte slice"));
        let chunk_idx = u32::from_le_bytes(header[8..12].try_into().expect("4-byte slice"));
        let total_chunks = u32::from_le_bytes(header[12..16].try_into().expect("4-byte slice"));
        Some(Self { message_id, chunk_idx, total_chunks, payload: payload.to_vec() })
    }
}
66
67/// Sender-side: split large payloads into ordered fragments.
68pub struct Chunker {
69    pub chunk_size: usize,
70}
71
72impl Chunker {
73    pub fn new(chunk_size: usize) -> Self {
74        assert!(chunk_size >= 1, "chunk_size must be >= 1");
75        Self { chunk_size }
76    }
77
78    /// Split `payload` into fragments. Each fragment shares the same
79    /// `message_id`. Returns at least one chunk even for an empty
80    /// payload (`total_chunks = 1`, empty payload).
81    pub fn split(&self, message_id: u64, payload: &[u8]) -> Vec<Chunk> {
82        if payload.is_empty() {
83            return vec![Chunk { message_id, chunk_idx: 0, total_chunks: 1, payload: Vec::new() }];
84        }
85        let total = payload.len().div_ceil(self.chunk_size);
86        let mut out = Vec::with_capacity(total);
87        for (i, chunk_payload) in payload.chunks(self.chunk_size).enumerate() {
88            out.push(Chunk {
89                message_id,
90                chunk_idx: i as u32,
91                total_chunks: total as u32,
92                payload: chunk_payload.to_vec(),
93            });
94        }
95        out
96    }
97}
98
/// Receiver-side: collect chunks for the same `message_id` until the
/// full set arrives, then return the reassembled payload.
#[derive(Default)]
pub struct Reassembler {
    // In-flight messages keyed by `message_id`. An entry is removed as
    // soon as its final chunk arrives; incomplete entries persist
    // indefinitely (no timeout/eviction at this layer).
    pending: HashMap<u64, Pending>,
}

/// Bookkeeping for one partially-received message.
struct Pending {
    // `total_chunks` as reported by the first chunk seen for the message.
    total: u32,
    // One slot per chunk index; `None` until that chunk arrives.
    chunks: Vec<Option<Vec<u8>>>,
    // Number of distinct indices filled so far (duplicates don't count).
    received: u32,
}
111
112impl Reassembler {
113    pub fn new() -> Self {
114        Self::default()
115    }
116
117    /// Feed one chunk. Returns `Some(payload)` when the message is
118    /// complete, `None` while still waiting for siblings.
119    pub fn push(&mut self, chunk: Chunk) -> Result<Option<Vec<u8>>, ChunkError> {
120        if chunk.total_chunks == 0 || chunk.chunk_idx >= chunk.total_chunks {
121            return Err(ChunkError::InvalidIndex { idx: chunk.chunk_idx, total: chunk.total_chunks });
122        }
123        let entry = self.pending.entry(chunk.message_id).or_insert_with(|| Pending {
124            total: chunk.total_chunks,
125            chunks: (0..chunk.total_chunks).map(|_| None).collect(),
126            received: 0,
127        });
128        if entry.total != chunk.total_chunks {
129            return Err(ChunkError::SizeMismatch {
130                message_id: chunk.message_id,
131                previous: entry.total,
132                now: chunk.total_chunks,
133            });
134        }
135        let slot = &mut entry.chunks[chunk.chunk_idx as usize];
136        if slot.is_none() {
137            *slot = Some(chunk.payload);
138            entry.received += 1;
139        }
140        if entry.received < entry.total {
141            return Ok(None);
142        }
143        // All chunks present — concatenate in order.
144        let pending = self.pending.remove(&chunk.message_id).expect("just present");
145        let total_len: usize = pending.chunks.iter().filter_map(|c| c.as_ref()).map(|v| v.len()).sum();
146        let mut out = Vec::with_capacity(total_len);
147        for buf in pending.chunks.into_iter().flatten() {
148            out.extend_from_slice(&buf);
149        }
150        Ok(Some(out))
151    }
152
153    pub fn pending_message_count(&self) -> usize {
154        self.pending.len()
155    }
156}
157
#[cfg(test)]
mod tests {
    use super::*;

    // Chunker assigns sequential indices and a shared total.
    #[test]
    fn split_packs_chunks_with_indices() {
        let c = Chunker::new(3);
        let chunks = c.split(42, b"abcdefgh");
        assert_eq!(chunks.len(), 3);
        assert_eq!(
            chunks[0],
            Chunk { message_id: 42, chunk_idx: 0, total_chunks: 3, payload: b"abc".to_vec() }
        );
        assert_eq!(
            chunks[1],
            Chunk { message_id: 42, chunk_idx: 1, total_chunks: 3, payload: b"def".to_vec() }
        );
        assert_eq!(
            chunks[2],
            Chunk { message_id: 42, chunk_idx: 2, total_chunks: 3, payload: b"gh".to_vec() }
        );
    }

    #[test]
    fn empty_payload_yields_single_chunk() {
        let chunks = Chunker::new(8).split(7, b"");
        assert_eq!(chunks.len(), 1);
        assert!(chunks[0].payload.is_empty());
        assert_eq!(chunks[0].total_chunks, 1);
    }

    #[test]
    fn round_trip_split_then_reassemble() {
        let c = Chunker::new(4);
        let payload = b"hello world! this is a longer payload than 4 bytes.";
        let chunks = c.split(1, payload);
        let mut r = Reassembler::new();
        let mut completed = None;
        for chunk in chunks {
            if let Some(full) = r.push(chunk).unwrap() {
                completed = Some(full);
            }
        }
        assert_eq!(completed.unwrap(), payload);
        assert_eq!(r.pending_message_count(), 0);
    }

    #[test]
    fn out_of_order_chunks_reassemble_correctly() {
        let c = Chunker::new(2);
        let mut chunks = c.split(7, b"abcdef");
        chunks.reverse(); // hand them to the receiver in reverse order
        let mut r = Reassembler::new();
        let mut full = None;
        for chunk in chunks {
            if let Some(payload) = r.push(chunk).unwrap() {
                full = Some(payload);
            }
        }
        assert_eq!(full.unwrap(), b"abcdef");
    }

    #[test]
    fn duplicate_chunks_are_idempotent() {
        let c = Chunker::new(2);
        let chunks = c.split(9, b"abcd");
        let mut r = Reassembler::new();
        let _ = r.push(chunks[0].clone()).unwrap();
        // Re-push the same chunk; receiver count shouldn't double.
        let _ = r.push(chunks[0].clone()).unwrap();
        let full = r.push(chunks[1].clone()).unwrap();
        assert_eq!(full.unwrap(), b"abcd");
    }

    #[test]
    fn invalid_index_errors() {
        let mut r = Reassembler::new();
        let bad = Chunk { message_id: 1, chunk_idx: 5, total_chunks: 3, payload: vec![] };
        let result = r.push(bad);
        assert!(matches!(result, Err(ChunkError::InvalidIndex { .. })));
    }

    // New: a zero chunk count admits no valid index and must be rejected.
    #[test]
    fn total_chunks_zero_is_invalid() {
        let mut r = Reassembler::new();
        let bad = Chunk { message_id: 2, chunk_idx: 0, total_chunks: 0, payload: vec![] };
        assert!(matches!(r.push(bad), Err(ChunkError::InvalidIndex { .. })));
        assert_eq!(r.pending_message_count(), 0);
    }

    #[test]
    fn size_mismatch_errors() {
        let mut r = Reassembler::new();
        let _ = r.push(Chunk { message_id: 1, chunk_idx: 0, total_chunks: 3, payload: vec![1] });
        let conflicting = Chunk { message_id: 1, chunk_idx: 1, total_chunks: 4, payload: vec![2] };
        assert!(matches!(r.push(conflicting), Err(ChunkError::SizeMismatch { .. })));
    }

    #[test]
    fn wire_round_trip() {
        let c = Chunk { message_id: 0xdeadbeef, chunk_idx: 3, total_chunks: 7, payload: b"hello".to_vec() };
        let bytes = c.to_wire();
        let parsed = Chunk::from_wire(&bytes).unwrap();
        assert_eq!(parsed, c);
    }

    // New: buffers shorter than the 16-byte header must parse to None.
    #[test]
    fn from_wire_rejects_short_buffers() {
        assert!(Chunk::from_wire(&[]).is_none());
        assert!(Chunk::from_wire(&[0u8; 15]).is_none());
    }
}
255}