//! Packet fragmentation and reassembly (`phantom_protocol/transport/fragmentation.rs`).

1use borsh::{BorshDeserialize, BorshSerialize};
2use std::collections::HashMap;
3use std::time::{Duration, Instant};
4
/// Upper bound on the payload bytes carried by a single fragment, kept small
/// enough to leave room for IP/UDP headers and protocol overhead.
const MAX_UDP_PAYLOAD: usize = 1200;

/// Upper bound, in bytes, on a logical packet the assembler will rebuild.
///
/// Together with [`MAX_TOTAL_CHUNKS`] this limits what one
/// `(session_id, packet_id)` assembly can hold to `MAX_TOTAL_CHUNKS`
/// fragments of at most `MAX_UDP_PAYLOAD` bytes each.
pub const MAX_REASSEMBLED_LEN: usize = 256 * 1024;

/// Upper bound on the fragment count a logical packet may declare, derived
/// from [`MAX_REASSEMBLED_LEN`] (219 with the current constants).
///
/// The wire field is a `u16`, so without this cap a peer could declare up to
/// 65 535 fragments and force an equally large chunk map. Frames declaring
/// more than this are dropped.
pub const MAX_TOTAL_CHUNKS: u16 = 1 + (MAX_REASSEMBLED_LEN / MAX_UDP_PAYLOAD) as u16;

/// Upper bound on the number of incomplete assemblies tracked at once.
///
/// Without it a peer could pin memory by spraying fragments across many
/// distinct `(session_id, packet_id)` keys and never completing any of them.
/// With it, worst-case resident fragment memory stays around
/// `MAX_CONCURRENT_ASSEMBLIES * MAX_REASSEMBLED_LEN` (roughly 64 MiB).
pub const MAX_CONCURRENT_ASSEMBLIES: usize = 256;

24/// Represents a single chunk of a fragmented logical packet
25#[derive(BorshSerialize, BorshDeserialize, Debug, Clone)]
26pub struct CryptoFrame {
27    pub session_id: [u8; 16], // Derived from IP + Client ID hash or explicit cookie
28    pub packet_id: u32,
29    pub chunk_index: u16,
30    pub total_chunks: u16,
31    pub payload: Vec<u8>,
32}
33
34pub struct FragmentAssembler {
35    // Map of (SessionId, PacketId) -> (Received Chunks, Total Chunks, Last Update Time)
36    assemblies: HashMap<([u8; 16], u32), AssemblyState>,
37}
38
/// Progress of one partially received logical packet.
struct AssemblyState {
    /// Fragment payloads received so far, keyed by `chunk_index`.
    chunks: HashMap<u16, Vec<u8>>,
    /// Fragment count declared by the frame that opened this assembly.
    total_chunks: u16,
    /// When this assembly last received a fragment; used for NACK and eviction timing.
    last_update: Instant,
}

45impl Default for FragmentAssembler {
46    fn default() -> Self {
47        Self::new()
48    }
49}
50
51impl FragmentAssembler {
52    pub fn new() -> Self {
53        Self {
54            assemblies: HashMap::new(),
55        }
56    }
57
58    /// Process a new CryptoFrame chunk.
59    /// Returns Some(reassembled_packet) if this chunk completes the packet.
60    pub fn process_chunk(&mut self, frame: CryptoFrame) -> Option<Vec<u8>> {
61        // Reject malformed or abusive fragments up front — for a UDP reassembler
62        // a bad fragment is simply dropped. Without these bounds a peer could
63        // pin unbounded memory: a huge `total_chunks` (up to 65 535) inflates
64        // the per-assembly chunk map, an out-of-range `chunk_index` parks bytes
65        // in a slot completion never reaches, and an oversized `payload`
66        // (borsh-decoded, so not implicitly capped at the datagram MTU)
67        // amplifies each chunk.
68        if frame.total_chunks == 0
69            || frame.total_chunks > MAX_TOTAL_CHUNKS
70            || frame.chunk_index >= frame.total_chunks
71            || frame.payload.len() > MAX_UDP_PAYLOAD
72        {
73            return None;
74        }
75
76        let key = (frame.session_id, frame.packet_id);
77
78        // Bound the number of concurrent assemblies. If this frame would open a
79        // NEW assembly while the table is full, evict the stalest one first —
80        // dropping the most-abandoned partial (typically an attacker's spray or
81        // a dead transfer) rather than letting the table grow without limit or
82        // permanently locking out fresh packets.
83        if !self.assemblies.contains_key(&key) && self.assemblies.len() >= MAX_CONCURRENT_ASSEMBLIES
84        {
85            self.evict_stalest();
86        }
87
88        let is_complete = {
89            let state = self.assemblies.entry(key).or_insert_with(|| AssemblyState {
90                chunks: HashMap::new(),
91                total_chunks: frame.total_chunks,
92                last_update: Instant::now(),
93            });
94
95            state.last_update = Instant::now();
96            state.chunks.insert(frame.chunk_index, frame.payload);
97
98            state.chunks.len() == state.total_chunks as usize
99        };
100
101        if is_complete {
102            // PANIC-SAFETY: the `is_complete` branch above just inserted the
103            // entry under `key` via `entry(key).or_insert_with(...)` and we
104            // hold `&mut self` — nothing else can have removed it.
105            #[allow(clippy::unwrap_used, clippy::disallowed_methods)]
106            let state = self.assemblies.remove(&key).unwrap();
107            let mut total_size = 0;
108            for i in 0..state.total_chunks {
109                if let Some(chunk) = state.chunks.get(&i) {
110                    total_size += chunk.len();
111                } else {
112                    return None;
113                }
114            }
115
116            let mut packet = Vec::with_capacity(total_size);
117            for i in 0..state.total_chunks {
118                // PANIC-SAFETY: the preceding loop returned early if any
119                // chunk `i` was missing; reaching this loop proves every
120                // index in `0..total_chunks` is present.
121                #[allow(clippy::unwrap_used, clippy::disallowed_methods)]
122                packet.extend_from_slice(state.chunks.get(&i).unwrap());
123            }
124
125            return Some(packet);
126        }
127
128        None
129    }
130
131    /// Evict the single least-recently-updated assembly. Used to keep the table
132    /// at or below [`MAX_CONCURRENT_ASSEMBLIES`] when a new assembly arrives at
133    /// capacity (the periodic `get_nacks_and_evict` sweep only reclaims dead
134    /// entries on a timer, which is too slow under a deliberate spray).
135    fn evict_stalest(&mut self) {
136        if let Some((&stalest_key, _)) = self
137            .assemblies
138            .iter()
139            .min_by_key(|(_, state)| state.last_update)
140        {
141            self.assemblies.remove(&stalest_key);
142        }
143    }
144
145    /// Number of in-flight (incomplete) assemblies currently tracked.
146    pub fn len(&self) -> usize {
147        self.assemblies.len()
148    }
149
150    /// Whether there are no in-flight assemblies.
151    pub fn is_empty(&self) -> bool {
152        self.assemblies.is_empty()
153    }
154
155    /// Check for timed out assemblies and return a list of missing chunks (NACK)
156    /// Also evicts purely dead assemblies (> 5000ms)
157    pub fn get_nacks_and_evict(&mut self) -> Vec<([u8; 16], u32, Vec<u16>)> {
158        let now = Instant::now();
159        let mut nacks = Vec::new();
160        let mut to_remove = Vec::new();
161
162        for (key, state) in self.assemblies.iter() {
163            let elapsed = now.duration_since(state.last_update);
164
165            if elapsed > Duration::from_millis(5000) {
166                // Dead
167                to_remove.push(*key);
168            } else if elapsed > Duration::from_millis(50) {
169                // NACK condition
170                let mut missing = Vec::new();
171                for i in 0..state.total_chunks {
172                    if !state.chunks.contains_key(&i) {
173                        missing.push(i);
174                    }
175                }
176                if !missing.is_empty() {
177                    nacks.push((key.0, key.1, missing));
178                }
179            }
180        }
181
182        for k in to_remove {
183            self.assemblies.remove(&k);
184        }
185
186        nacks
187    }
188}
189
190/// Split a large payload into CryptoFrame chunks
191pub fn fragment_payload(session_id: [u8; 16], packet_id: u32, payload: &[u8]) -> Vec<CryptoFrame> {
192    let mut frames = Vec::new();
193    let chunks = payload.chunks(MAX_UDP_PAYLOAD);
194    let total_chunks = chunks.len() as u16;
195
196    for (i, chunk) in chunks.enumerate() {
197        frames.push(CryptoFrame {
198            session_id,
199            packet_id,
200            chunk_index: i as u16,
201            total_chunks,
202            payload: chunk.to_vec(),
203        });
204    }
205
206    frames
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a frame for session `[0; 16]` whose payload is `payload_len` bytes of `0xAB`.
    fn frame(packet_id: u32, idx: u16, total: u16, payload_len: usize) -> CryptoFrame {
        CryptoFrame {
            session_id: [0u8; 16],
            packet_id,
            chunk_index: idx,
            total_chunks: total,
            payload: vec![0xABu8; payload_len],
        }
    }

    #[test]
    fn fragment_reassemble_round_trip() {
        let payload: Vec<u8> = (0..3000u32).map(|i| i as u8).collect();
        let frames = fragment_payload([1u8; 16], 42, &payload);
        assert!(frames.len() > 1, "3000 bytes must fragment");
        let mut asm = FragmentAssembler::new();
        let mut out = None;
        for f in frames {
            if let Some(p) = asm.process_chunk(f) {
                out = Some(p);
            }
        }
        assert_eq!(out.as_deref(), Some(payload.as_slice()));
        assert!(asm.is_empty(), "completed assembly is removed");
    }

    #[test]
    fn reassembles_out_of_order_chunks() {
        let payload: Vec<u8> = (0..3000u32).map(|i| (i % 251) as u8).collect();
        let mut frames = fragment_payload([2u8; 16], 7, &payload);
        frames.reverse();
        let mut asm = FragmentAssembler::new();
        let mut completed = Vec::new();
        for f in frames {
            if let Some(p) = asm.process_chunk(f) {
                completed.push(p);
            }
        }
        assert_eq!(completed, vec![payload], "exactly one packet, in original byte order");
        assert!(asm.is_empty());
    }

    #[test]
    fn duplicate_chunk_does_not_complete_packet() {
        let mut asm = FragmentAssembler::new();
        assert!(asm.process_chunk(frame(1, 0, 2, 10)).is_none());
        assert!(asm.process_chunk(frame(1, 0, 2, 10)).is_none());
        assert_eq!(asm.len(), 1, "a duplicate must not open a second assembly");
        assert_eq!(asm.process_chunk(frame(1, 1, 2, 10)), Some(vec![0xABu8; 20]));
        assert!(asm.is_empty());
    }

    #[test]
    fn single_chunk_packet_completes_immediately() {
        let mut asm = FragmentAssembler::new();
        assert_eq!(asm.process_chunk(frame(9, 0, 1, 5)), Some(vec![0xABu8; 5]));
        assert!(asm.is_empty());
    }

    #[test]
    fn sessions_with_same_packet_id_are_independent() {
        let mut asm = FragmentAssembler::new();
        let mut other_session = frame(1, 0, 2, 10);
        other_session.session_id = [9u8; 16];
        assert!(asm.process_chunk(frame(1, 0, 2, 10)).is_none());
        assert!(asm.process_chunk(other_session).is_none());
        assert_eq!(asm.len(), 2);
    }

    #[test]
    fn fragment_payload_respects_chunk_size_and_order() {
        let payload = vec![0u8; MAX_UDP_PAYLOAD * 2 + 1];
        let frames = fragment_payload([3u8; 16], 5, &payload);
        assert_eq!(frames.len(), 3);
        for (i, f) in frames.iter().enumerate() {
            assert_eq!(usize::from(f.chunk_index), i);
            assert_eq!(f.total_chunks, 3);
            assert!(f.payload.len() <= MAX_UDP_PAYLOAD);
        }
        assert_eq!(frames.last().map(|f| f.payload.len()), Some(1));
    }

    #[test]
    fn empty_payload_produces_no_frames() {
        assert!(fragment_payload([0u8; 16], 1, &[]).is_empty());
    }

    #[test]
    fn rejects_zero_total_chunks() {
        let mut asm = FragmentAssembler::new();
        assert!(asm.process_chunk(frame(1, 0, 0, 10)).is_none());
        assert!(asm.is_empty(), "malformed frame must not open an assembly");
    }

    #[test]
    fn rejects_out_of_range_chunk_index() {
        let mut asm = FragmentAssembler::new();
        // chunk_index == total_chunks is out of the valid 0..total range.
        assert!(asm.process_chunk(frame(1, 2, 2, 10)).is_none());
        assert!(asm.is_empty());
    }

    #[test]
    fn rejects_excessive_total_chunks() {
        let mut asm = FragmentAssembler::new();
        assert!(asm
            .process_chunk(frame(1, 0, MAX_TOTAL_CHUNKS.saturating_add(1), 10))
            .is_none());
        assert!(asm.is_empty());
    }

    #[test]
    fn rejects_oversized_fragment_payload() {
        let mut asm = FragmentAssembler::new();
        assert!(asm
            .process_chunk(frame(1, 0, 4, MAX_UDP_PAYLOAD + 1))
            .is_none());
        assert!(asm.is_empty());
    }

    #[test]
    fn caps_concurrent_assemblies() {
        let mut asm = FragmentAssembler::new();
        // Open far more distinct (never-completed, total_chunks=4) assemblies
        // than the cap; the table must never exceed MAX_CONCURRENT_ASSEMBLIES.
        for packet_id in 0..(MAX_CONCURRENT_ASSEMBLIES as u32 * 4) {
            assert!(asm.process_chunk(frame(packet_id, 0, 4, 10)).is_none());
            assert!(
                asm.len() <= MAX_CONCURRENT_ASSEMBLIES,
                "assembly table exceeded its cap: {}",
                asm.len()
            );
        }
        assert_eq!(asm.len(), MAX_CONCURRENT_ASSEMBLIES);
    }
}