Skip to main content

pushwire_core/
fragments.rs

1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4use thiserror::Error;
5use uuid::Uuid;
6
/// Length of the fragment header.
///
/// Layout: `[fragment_id:16][total:2][index:2]` (sizes presumably in bytes,
/// matching a 16-byte UUID followed by two `u16` fields).
pub const FRAGMENT_HEADER_LEN: usize = {
    const FRAGMENT_ID_LEN: usize = 16;
    const TOTAL_LEN: usize = 2;
    const INDEX_LEN: usize = 2;
    FRAGMENT_ID_LEN + TOTAL_LEN + INDEX_LEN
};
9
/// One piece of a payload that was split into `total` shards.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct FragmentShard {
    /// Identifier shared by every shard of the same payload; the assembler
    /// groups shards by this value.
    pub fragment_id: Uuid,
    /// Number of shards the payload was split into. `FragmentAssembler::push`
    /// rejects zero.
    pub total: u16,
    /// Zero-based position of this shard within the payload. Values that are
    /// not below the expected shard count are rejected by
    /// `FragmentAssembler::push`.
    pub index: u16,
    /// The bytes carried by this shard.
    pub payload: Vec<u8>,
}
17
/// Reasons `FragmentAssembler::push` can reject a shard.
#[derive(Debug, Error, PartialEq, Eq)]
pub enum FragmentError {
    /// The shard declared `total == 0`, or its `index` is outside the number
    /// of shards expected for its fragment id.
    #[error("invalid fragment size")]
    InvalidSize,
    /// A shard with this `index` was already stored for the same fragment id.
    #[error("duplicate fragment {index}/{total}")]
    Duplicate { index: u16, total: u16 },
    /// The shard's `total` differs from the `total` recorded when the first
    /// shard of its fragment id arrived.
    #[error("unexpected total mismatch")]
    TotalMismatch,
}
27
/// Assembles fragments keyed by fragment_id until all parts arrive.
///
/// Shards are fed in through `push`; once every index of a fragment id has
/// been received the concatenated payload is returned and the state for that
/// id is dropped.
#[derive(Default, Debug)]
pub struct FragmentAssembler {
    /// Partially received payloads, keyed by fragment id.
    inflight: HashMap<Uuid, Inflight>,
    /// Fragment ids in the order they were first seen; `enforce_retention`
    /// evicts from the front of this list.
    order: Vec<Uuid>,
}
34
/// Limits applied when cleaning up partially assembled fragments.
#[derive(Debug, Clone, Copy)]
pub struct FragmentRetention {
    /// Upper bound on the number of fragment ids kept in flight; anything
    /// above this is evicted when the policy is enforced.
    pub max_inflight: usize,
}

impl Default for FragmentRetention {
    /// Returns a policy that keeps at most 1024 fragment ids in flight.
    fn default() -> Self {
        const DEFAULT_MAX_INFLIGHT: usize = 1024;

        FragmentRetention {
            max_inflight: DEFAULT_MAX_INFLIGHT,
        }
    }
}
46
/// Reassembly state for a single fragment id.
#[derive(Debug)]
struct Inflight {
    /// Number of shards expected, fixed when this state is created.
    total: u16,
    /// One slot per shard index; a slot is `None` until that shard arrives.
    received: Vec<Option<Vec<u8>>>,
    /// How many slots have been filled so far.
    seen: u16,
}
53
54impl FragmentAssembler {
55    pub fn new() -> Self {
56        Self {
57            inflight: HashMap::new(),
58            order: Vec::new(),
59        }
60    }
61
62    /// Push a new fragment; returns Some(full_payload) when completed.
63    pub fn push(&mut self, shard: FragmentShard) -> Result<Option<Vec<u8>>, FragmentError> {
64        if shard.total == 0 {
65            return Err(FragmentError::InvalidSize);
66        }
67
68        let entry = self
69            .inflight
70            .entry(shard.fragment_id)
71            .or_insert_with(|| Inflight::new(shard.total));
72        if entry.seen == 0 {
73            self.order.push(shard.fragment_id);
74        }
75
76        if entry.total != shard.total {
77            return Err(FragmentError::TotalMismatch);
78        }
79
80        if shard.index as usize >= entry.received.len() {
81            return Err(FragmentError::InvalidSize);
82        }
83
84        if entry.received[shard.index as usize].is_some() {
85            return Err(FragmentError::Duplicate {
86                index: shard.index,
87                total: shard.total,
88            });
89        }
90
91        entry.received[shard.index as usize] = Some(shard.payload);
92        entry.seen += 1;
93
94        if entry.seen == shard.total {
95            let mut buf = Vec::new();
96            for part in entry.received.iter_mut() {
97                if let Some(mut chunk) = part.take() {
98                    buf.append(&mut chunk);
99                }
100            }
101            self.inflight.remove(&shard.fragment_id);
102            self.order.retain(|id| id != &shard.fragment_id);
103            Ok(Some(buf))
104        } else {
105            Ok(None)
106        }
107    }
108
109    pub fn inflight(&self) -> usize {
110        self.inflight.len()
111    }
112
113    /// Apply retention: evict oldest if exceeding max inflight.
114    pub fn enforce_retention(&mut self, retention: FragmentRetention) {
115        if self.inflight.len() <= retention.max_inflight {
116            return;
117        }
118        while self.inflight.len() > retention.max_inflight {
119            if let Some(oldest) = self.order.first().cloned() {
120                self.inflight.remove(&oldest);
121                self.order.remove(0);
122            } else {
123                break;
124            }
125        }
126    }
127}
128
129impl Inflight {
130    fn new(total: u16) -> Self {
131        Self {
132            total,
133            received: vec![None; total as usize],
134            seen: 0,
135        }
136    }
137}
138
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a shard for `id` so each test does not repeat the struct literal.
    fn shard(id: Uuid, total: u16, index: u16, payload: &[u8]) -> FragmentShard {
        FragmentShard {
            fragment_id: id,
            total,
            index,
            payload: payload.to_vec(),
        }
    }

    /// Shards delivered in index order complete on the last one.
    #[test]
    fn assembles_in_order() {
        let id = Uuid::new_v4();
        let mut assembler = FragmentAssembler::new();

        assert!(assembler.push(shard(id, 3, 0, b"hel")).unwrap().is_none());
        assert!(assembler.push(shard(id, 3, 1, b"lo ")).unwrap().is_none());

        let assembled = assembler
            .push(shard(id, 3, 2, b"world"))
            .unwrap()
            .expect("should assemble");
        assert_eq!(assembled, b"hello world");
        assert_eq!(assembler.inflight(), 0);
    }

    /// Arrival order does not matter; output follows shard index.
    #[test]
    fn assembles_out_of_order() {
        let id = Uuid::new_v4();
        let mut assembler = FragmentAssembler::new();

        assert!(assembler.push(shard(id, 2, 1, b"world")).unwrap().is_none());

        let assembled = assembler
            .push(shard(id, 2, 0, b"hello "))
            .unwrap()
            .expect("should assemble");
        assert_eq!(assembled, b"hello world");
    }

    /// Re-sending an index is an error but does not poison the fragment.
    #[test]
    fn rejects_duplicates() {
        let id = Uuid::new_v4();
        let mut assembler = FragmentAssembler::new();
        let first = shard(id, 2, 0, b"abc");

        assert!(assembler.push(first.clone()).unwrap().is_none());
        let err = assembler.push(first).unwrap_err();
        assert_eq!(err, FragmentError::Duplicate { index: 0, total: 2 });

        let assembled = assembler
            .push(shard(id, 2, 1, b"xyz"))
            .unwrap()
            .expect("should assemble");
        assert_eq!(assembled, b"abcxyz");
    }

    /// A shard whose `total` disagrees with the first one seen is rejected.
    #[test]
    fn mismatched_total_is_error() {
        let id = Uuid::new_v4();
        let mut assembler = FragmentAssembler::new();

        assert!(assembler.push(shard(id, 2, 0, &[])).is_ok());

        let err = assembler.push(shard(id, 3, 1, &[])).unwrap_err();
        assert_eq!(err, FragmentError::TotalMismatch);
    }

    /// With three ids in flight and a cap of two, the first-seen id is dropped.
    #[test]
    fn retention_evicts_oldest() {
        let mut assembler = FragmentAssembler::new();
        let retention = FragmentRetention { max_inflight: 2 };

        let ids: Vec<_> = (0..3).map(|_| Uuid::new_v4()).collect();
        for id in &ids {
            let _ = assembler.push(shard(*id, 2, 0, b"x")).unwrap();
        }

        assembler.enforce_retention(retention);
        assert!(assembler.inflight.contains_key(&ids[1]));
        assert!(assembler.inflight.contains_key(&ids[2]));
        assert!(!assembler.inflight.contains_key(&ids[0]));
    }
}