pushwire_core/
fragments.rs1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4use thiserror::Error;
5use uuid::Uuid;
6
/// Size in bytes of a shard header (20).
///
/// NOTE(review): presumably a 16-byte fragment id (`Uuid`) followed by the `total` and
/// `index` counters (`u16` each), mirroring [`FragmentShard`]; the wire encoder is not in
/// this file, so confirm the layout against it.
pub const FRAGMENT_HEADER_LEN: usize = 16 + 2 * std::mem::size_of::<u16>();

10#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
11pub struct FragmentShard {
12 pub fragment_id: Uuid,
13 pub total: u16,
14 pub index: u16,
15 pub payload: Vec<u8>,
16}
17
18#[derive(Debug, Error, PartialEq, Eq)]
19pub enum FragmentError {
20 #[error("invalid fragment size")]
21 InvalidSize,
22 #[error("duplicate fragment {index}/{total}")]
23 Duplicate { index: u16, total: u16 },
24 #[error("unexpected total mismatch")]
25 TotalMismatch,
26}
27
28#[derive(Default, Debug)]
30pub struct FragmentAssembler {
31 inflight: HashMap<Uuid, Inflight>,
32 order: Vec<Uuid>,
33}
34
/// Limits applied by [`FragmentAssembler::enforce_retention`].
#[derive(Debug, Clone, Copy)]
pub struct FragmentRetention {
    /// Maximum number of incomplete messages kept; the oldest are evicted first.
    pub max_inflight: usize,
}

impl Default for FragmentRetention {
    /// Keeps at most 1024 incomplete messages.
    fn default() -> Self {
        FragmentRetention { max_inflight: 1_024 }
    }
}

/// Reassembly state for one message that is still missing shards.
#[derive(Debug)]
struct Inflight {
    /// Shard count declared by the message's first accepted shard.
    total: u16,
    /// One slot per shard index; `None` until that shard arrives.
    received: Vec<Option<Vec<u8>>>,
    /// Number of filled slots in `received`.
    seen: u16,
}

54impl FragmentAssembler {
55 pub fn new() -> Self {
56 Self {
57 inflight: HashMap::new(),
58 order: Vec::new(),
59 }
60 }
61
62 pub fn push(&mut self, shard: FragmentShard) -> Result<Option<Vec<u8>>, FragmentError> {
64 if shard.total == 0 {
65 return Err(FragmentError::InvalidSize);
66 }
67
68 let entry = self
69 .inflight
70 .entry(shard.fragment_id)
71 .or_insert_with(|| Inflight::new(shard.total));
72 if entry.seen == 0 {
73 self.order.push(shard.fragment_id);
74 }
75
76 if entry.total != shard.total {
77 return Err(FragmentError::TotalMismatch);
78 }
79
80 if shard.index as usize >= entry.received.len() {
81 return Err(FragmentError::InvalidSize);
82 }
83
84 if entry.received[shard.index as usize].is_some() {
85 return Err(FragmentError::Duplicate {
86 index: shard.index,
87 total: shard.total,
88 });
89 }
90
91 entry.received[shard.index as usize] = Some(shard.payload);
92 entry.seen += 1;
93
94 if entry.seen == shard.total {
95 let mut buf = Vec::new();
96 for part in entry.received.iter_mut() {
97 if let Some(mut chunk) = part.take() {
98 buf.append(&mut chunk);
99 }
100 }
101 self.inflight.remove(&shard.fragment_id);
102 self.order.retain(|id| id != &shard.fragment_id);
103 Ok(Some(buf))
104 } else {
105 Ok(None)
106 }
107 }
108
109 pub fn inflight(&self) -> usize {
110 self.inflight.len()
111 }
112
113 pub fn enforce_retention(&mut self, retention: FragmentRetention) {
115 if self.inflight.len() <= retention.max_inflight {
116 return;
117 }
118 while self.inflight.len() > retention.max_inflight {
119 if let Some(oldest) = self.order.first().cloned() {
120 self.inflight.remove(&oldest);
121 self.order.remove(0);
122 } else {
123 break;
124 }
125 }
126 }
127}
128
129impl Inflight {
130 fn new(total: u16) -> Self {
131 Self {
132 total,
133 received: vec![None; total as usize],
134 seen: 0,
135 }
136 }
137}
138
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a shard for message `id`, keeping each test focused on behavior.
    fn shard(id: Uuid, total: u16, index: u16, payload: &[u8]) -> FragmentShard {
        FragmentShard {
            fragment_id: id,
            total,
            index,
            payload: payload.to_vec(),
        }
    }

    #[test]
    fn assembles_in_order() {
        let id = Uuid::new_v4();
        let mut assembler = FragmentAssembler::new();

        assert!(assembler.push(shard(id, 3, 0, b"hel")).unwrap().is_none());
        assert!(assembler.push(shard(id, 3, 1, b"lo ")).unwrap().is_none());
        let assembled = assembler
            .push(shard(id, 3, 2, b"world"))
            .unwrap()
            .expect("should assemble");
        assert_eq!(assembled, b"hello world");
        assert_eq!(assembler.inflight(), 0);
    }

    #[test]
    fn assembles_out_of_order() {
        let id = Uuid::new_v4();
        let mut assembler = FragmentAssembler::new();

        assert!(assembler.push(shard(id, 2, 1, b"world")).unwrap().is_none());
        let assembled = assembler
            .push(shard(id, 2, 0, b"hello "))
            .unwrap()
            .expect("should assemble");
        assert_eq!(assembled, b"hello world");
    }

    #[test]
    fn rejects_duplicates() {
        let id = Uuid::new_v4();
        let mut assembler = FragmentAssembler::new();

        let first = shard(id, 2, 0, b"abc");
        assert!(assembler.push(first.clone()).unwrap().is_none());
        let err = assembler.push(first).unwrap_err();
        assert_eq!(err, FragmentError::Duplicate { index: 0, total: 2 });

        let assembled = assembler
            .push(shard(id, 2, 1, b"xyz"))
            .unwrap()
            .expect("should assemble");
        assert_eq!(assembled, b"abcxyz");
    }

    #[test]
    fn mismatched_total_is_error() {
        let id = Uuid::new_v4();
        let mut assembler = FragmentAssembler::new();

        assert!(assembler.push(shard(id, 2, 0, b"")).is_ok());
        let err = assembler.push(shard(id, 3, 1, b"")).unwrap_err();
        assert_eq!(err, FragmentError::TotalMismatch);

        // The rejected shard must not disturb what was already buffered.
        let assembled = assembler
            .push(shard(id, 2, 1, b""))
            .unwrap()
            .expect("should assemble");
        assert!(assembled.is_empty());
    }

    #[test]
    fn rejects_zero_total_and_out_of_range_index() {
        let id = Uuid::new_v4();
        let mut assembler = FragmentAssembler::new();

        assert_eq!(
            assembler.push(shard(id, 0, 0, b"x")).unwrap_err(),
            FragmentError::InvalidSize
        );
        assert_eq!(
            assembler.push(shard(id, 2, 2, b"x")).unwrap_err(),
            FragmentError::InvalidSize
        );

        // The message can still be assembled from valid shards afterwards.
        assert!(assembler.push(shard(id, 2, 0, b"ab")).unwrap().is_none());
        let assembled = assembler
            .push(shard(id, 2, 1, b"cd"))
            .unwrap()
            .expect("should assemble");
        assert_eq!(assembled, b"abcd");
        assert_eq!(assembler.inflight(), 0);
    }

    #[test]
    fn retention_evicts_oldest() {
        let mut assembler = FragmentAssembler::new();
        let retention = FragmentRetention { max_inflight: 2 };

        let ids: Vec<_> = (0..3).map(|_| Uuid::new_v4()).collect();
        for id in &ids {
            let _ = assembler.push(shard(*id, 2, 0, b"x")).unwrap();
        }

        assembler.enforce_retention(retention);
        assert!(assembler.inflight.contains_key(&ids[1]));
        assert!(assembler.inflight.contains_key(&ids[2]));
        assert!(!assembler.inflight.contains_key(&ids[0]));
    }

    #[test]
    fn retention_evicts_several_in_arrival_order() {
        let mut assembler = FragmentAssembler::new();
        let ids: Vec<_> = (0..5).map(|_| Uuid::new_v4()).collect();
        for id in &ids {
            assert!(assembler.push(shard(*id, 2, 0, b"x")).unwrap().is_none());
        }

        assembler.enforce_retention(FragmentRetention { max_inflight: 2 });
        assert_eq!(assembler.inflight(), 2);
        for evicted in &ids[..3] {
            assert!(!assembler.inflight.contains_key(evicted));
        }
        for kept in &ids[3..] {
            assert!(assembler.inflight.contains_key(kept));
        }
    }

    #[test]
    fn retention_under_limit_keeps_everything() {
        let mut assembler = FragmentAssembler::new();
        let ids: Vec<_> = (0..2).map(|_| Uuid::new_v4()).collect();
        for id in &ids {
            assert!(assembler.push(shard(*id, 3, 1, b"y")).unwrap().is_none());
        }

        assembler.enforce_retention(FragmentRetention { max_inflight: 2 });
        assembler.enforce_retention(FragmentRetention::default());
        assert_eq!(assembler.inflight(), 2);
        for id in &ids {
            assert!(assembler.inflight.contains_key(id));
        }
    }
}