phantom_protocol/transport/
fragmentation.rs1use borsh::{BorshDeserialize, BorshSerialize};
2use std::collections::HashMap;
3use std::time::{Duration, Instant};
4
/// Largest fragment payload, in bytes. `fragment_payload` splits at this size
/// and `FragmentAssembler::process_chunk` rejects any larger fragment.
const MAX_UDP_PAYLOAD: usize = 1200;

/// Target maximum size, in bytes, of a reassembled packet.
///
/// It is enforced only indirectly through [`MAX_TOTAL_CHUNKS`], so up to
/// `MAX_TOTAL_CHUNKS * MAX_UDP_PAYLOAD` bytes can actually be reassembled.
pub const MAX_REASSEMBLED_LEN: usize = 256 * 1024;

/// Fragments needed to carry `MAX_REASSEMBLED_LEN` bytes (ceiling division).
/// It is kept as `usize` so it can be range-checked before narrowing to `u16`.
const MAX_TOTAL_CHUNKS_USIZE: usize =
    (MAX_REASSEMBLED_LEN + MAX_UDP_PAYLOAD - 1) / MAX_UDP_PAYLOAD;

// Fragment counts travel as `u16`. Fail the build rather than silently
// truncate if the limits above are ever raised too far.
const _: () = assert!(MAX_TOTAL_CHUNKS_USIZE <= u16::MAX as usize);

/// Maximum `total_chunks` accepted for one packet: just enough
/// `MAX_UDP_PAYLOAD`-byte fragments to carry [`MAX_REASSEMBLED_LEN`] bytes.
pub const MAX_TOTAL_CHUNKS: u16 = MAX_TOTAL_CHUNKS_USIZE as u16;

/// Maximum number of packets that may be mid-reassembly at once. When the
/// table is full, opening a new assembly evicts the least recently updated one.
pub const MAX_CONCURRENT_ASSEMBLIES: usize = 256;

24#[derive(BorshSerialize, BorshDeserialize, Debug, Clone)]
26pub struct CryptoFrame {
27 pub session_id: [u8; 16], pub packet_id: u32,
29 pub chunk_index: u16,
30 pub total_chunks: u16,
31 pub payload: Vec<u8>,
32}
33
34pub struct FragmentAssembler {
35 assemblies: HashMap<([u8; 16], u32), AssemblyState>,
37}
38
/// Bookkeeping for one packet that is still being reassembled.
struct AssemblyState {
    /// Fragment count announced by the frame that opened this assembly.
    total_chunks: u16,
    /// Payloads received so far, keyed by `chunk_index`.
    chunks: HashMap<u16, Vec<u8>>,
    /// When a fragment for this packet was last accepted. It drives both
    /// stalest-first eviction and NACK / timeout decisions.
    last_update: Instant,
}

45impl Default for FragmentAssembler {
46 fn default() -> Self {
47 Self::new()
48 }
49}
50
51impl FragmentAssembler {
52 pub fn new() -> Self {
53 Self {
54 assemblies: HashMap::new(),
55 }
56 }
57
58 pub fn process_chunk(&mut self, frame: CryptoFrame) -> Option<Vec<u8>> {
61 if frame.total_chunks == 0
69 || frame.total_chunks > MAX_TOTAL_CHUNKS
70 || frame.chunk_index >= frame.total_chunks
71 || frame.payload.len() > MAX_UDP_PAYLOAD
72 {
73 return None;
74 }
75
76 let key = (frame.session_id, frame.packet_id);
77
78 if !self.assemblies.contains_key(&key) && self.assemblies.len() >= MAX_CONCURRENT_ASSEMBLIES
84 {
85 self.evict_stalest();
86 }
87
88 let is_complete = {
89 let state = self.assemblies.entry(key).or_insert_with(|| AssemblyState {
90 chunks: HashMap::new(),
91 total_chunks: frame.total_chunks,
92 last_update: Instant::now(),
93 });
94
95 state.last_update = Instant::now();
96 state.chunks.insert(frame.chunk_index, frame.payload);
97
98 state.chunks.len() == state.total_chunks as usize
99 };
100
101 if is_complete {
102 #[allow(clippy::unwrap_used, clippy::disallowed_methods)]
106 let state = self.assemblies.remove(&key).unwrap();
107 let mut total_size = 0;
108 for i in 0..state.total_chunks {
109 if let Some(chunk) = state.chunks.get(&i) {
110 total_size += chunk.len();
111 } else {
112 return None;
113 }
114 }
115
116 let mut packet = Vec::with_capacity(total_size);
117 for i in 0..state.total_chunks {
118 #[allow(clippy::unwrap_used, clippy::disallowed_methods)]
122 packet.extend_from_slice(state.chunks.get(&i).unwrap());
123 }
124
125 return Some(packet);
126 }
127
128 None
129 }
130
131 fn evict_stalest(&mut self) {
136 if let Some((&stalest_key, _)) = self
137 .assemblies
138 .iter()
139 .min_by_key(|(_, state)| state.last_update)
140 {
141 self.assemblies.remove(&stalest_key);
142 }
143 }
144
145 pub fn len(&self) -> usize {
147 self.assemblies.len()
148 }
149
150 pub fn is_empty(&self) -> bool {
152 self.assemblies.is_empty()
153 }
154
155 pub fn get_nacks_and_evict(&mut self) -> Vec<([u8; 16], u32, Vec<u16>)> {
158 let now = Instant::now();
159 let mut nacks = Vec::new();
160 let mut to_remove = Vec::new();
161
162 for (key, state) in self.assemblies.iter() {
163 let elapsed = now.duration_since(state.last_update);
164
165 if elapsed > Duration::from_millis(5000) {
166 to_remove.push(*key);
168 } else if elapsed > Duration::from_millis(50) {
169 let mut missing = Vec::new();
171 for i in 0..state.total_chunks {
172 if !state.chunks.contains_key(&i) {
173 missing.push(i);
174 }
175 }
176 if !missing.is_empty() {
177 nacks.push((key.0, key.1, missing));
178 }
179 }
180 }
181
182 for k in to_remove {
183 self.assemblies.remove(&k);
184 }
185
186 nacks
187 }
188}
189
190pub fn fragment_payload(session_id: [u8; 16], packet_id: u32, payload: &[u8]) -> Vec<CryptoFrame> {
192 let mut frames = Vec::new();
193 let chunks = payload.chunks(MAX_UDP_PAYLOAD);
194 let total_chunks = chunks.len() as u16;
195
196 for (i, chunk) in chunks.enumerate() {
197 frames.push(CryptoFrame {
198 session_id,
199 packet_id,
200 chunk_index: i as u16,
201 total_chunks,
202 payload: chunk.to_vec(),
203 });
204 }
205
206 frames
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a frame for session `[0; 16]` whose payload is `payload_len`
    /// copies of `0xAB`.
    fn frame(packet_id: u32, idx: u16, total: u16, payload_len: usize) -> CryptoFrame {
        CryptoFrame {
            session_id: [0u8; 16],
            packet_id,
            chunk_index: idx,
            total_chunks: total,
            payload: vec![0xABu8; payload_len],
        }
    }

    #[test]
    fn fragment_reassemble_round_trip() {
        let payload: Vec<u8> = (0..3000u32).map(|i| i as u8).collect();
        let frames = fragment_payload([1u8; 16], 42, &payload);
        assert!(frames.len() > 1, "3000 bytes must fragment");
        let mut asm = FragmentAssembler::new();
        let mut out = None;
        for f in frames {
            if let Some(p) = asm.process_chunk(f) {
                out = Some(p);
            }
        }
        assert_eq!(out.as_deref(), Some(payload.as_slice()));
        assert!(asm.is_empty(), "completed assembly is removed");
    }

    #[test]
    fn reassembles_out_of_order_with_duplicate_fragment() {
        let payload: Vec<u8> = (0..2 * MAX_UDP_PAYLOAD + 100)
            .map(|i| (i % 251) as u8)
            .collect();
        let mut frames = fragment_payload([7u8; 16], 9, &payload);
        assert_eq!(frames.len(), 3);
        frames.reverse();

        let mut asm = FragmentAssembler::new();
        // Deliver the last fragment twice before the others arrive.
        assert!(asm.process_chunk(frames[0].clone()).is_none());
        let mut out = None;
        for f in frames {
            if let Some(p) = asm.process_chunk(f) {
                assert!(out.is_none(), "packet delivered more than once");
                out = Some(p);
            }
        }
        assert_eq!(out.as_deref(), Some(payload.as_slice()));
        assert!(asm.is_empty());
    }

    #[test]
    fn fragment_payload_edge_sizes() {
        assert!(fragment_payload([0u8; 16], 1, &[]).is_empty());

        let exact = vec![0x5Au8; 2 * MAX_UDP_PAYLOAD];
        let frames = fragment_payload([0u8; 16], 2, &exact);
        assert_eq!(frames.len(), 2);
        for (i, f) in frames.iter().enumerate() {
            assert_eq!(usize::from(f.chunk_index), i);
            assert_eq!(f.total_chunks, 2);
            assert_eq!(f.payload.len(), MAX_UDP_PAYLOAD);
        }
    }

    #[test]
    fn rejects_zero_total_chunks() {
        let mut asm = FragmentAssembler::new();
        assert!(asm.process_chunk(frame(1, 0, 0, 10)).is_none());
        assert!(asm.is_empty(), "malformed frame must not open an assembly");
    }

    #[test]
    fn rejects_out_of_range_chunk_index() {
        let mut asm = FragmentAssembler::new();
        assert!(asm.process_chunk(frame(1, 2, 2, 10)).is_none());
        assert!(asm.is_empty());
    }

    #[test]
    fn rejects_excessive_total_chunks() {
        let mut asm = FragmentAssembler::new();
        assert!(asm
            .process_chunk(frame(1, 0, MAX_TOTAL_CHUNKS.saturating_add(1), 10))
            .is_none());
        assert!(asm.is_empty());
    }

    #[test]
    fn rejects_oversized_fragment_payload() {
        let mut asm = FragmentAssembler::new();
        assert!(asm
            .process_chunk(frame(1, 0, 4, MAX_UDP_PAYLOAD + 1))
            .is_none());
        assert!(asm.is_empty());
    }

    #[test]
    fn caps_concurrent_assemblies() {
        let mut asm = FragmentAssembler::new();
        for packet_id in 0..(MAX_CONCURRENT_ASSEMBLIES as u32 * 4) {
            assert!(asm.process_chunk(frame(packet_id, 0, 4, 10)).is_none());
            assert!(
                asm.len() <= MAX_CONCURRENT_ASSEMBLIES,
                "assembly table exceeded its cap: {}",
                asm.len()
            );
        }
        assert_eq!(asm.len(), MAX_CONCURRENT_ASSEMBLIES);
    }
}