aetheris_protocol/
reassembler.rs1use std::collections::{HashMap, hash_map};
4use std::time::Duration;
5use web_time::Instant;
6
7use crate::events::FragmentedEvent;
8use crate::types::ClientId;
9
10#[derive(Debug, Clone)]
12struct FragmentBuffer {
13 start_time: Instant,
15 total_fragments: u16,
17 fragments: Vec<Option<Vec<u8>>>,
19 count: u16,
21}
22
23impl FragmentBuffer {
24 fn new(total_fragments: u16) -> Option<Self> {
25 if total_fragments == 0 || total_fragments > crate::MAX_TOTAL_FRAGMENTS {
26 return None;
27 }
28
29 Some(Self {
30 start_time: Instant::now(),
31 total_fragments,
32 fragments: vec![None; total_fragments as usize],
33 count: 0,
34 })
35 }
36
37 fn add(&mut self, index: u16, payload: Vec<u8>) -> Option<Vec<u8>> {
38 let idx = index as usize;
39 if idx >= self.fragments.len() {
40 return None;
41 }
42
43 if self.fragments[idx].is_none() {
44 self.fragments[idx] = Some(payload);
45 self.count += 1;
46 }
47
48 if self.count == self.total_fragments {
49 let mut full_payload = Vec::new();
50 for frag in self.fragments.drain(..) {
51 full_payload.extend(frag.unwrap());
52 }
53 Some(full_payload)
54 } else {
55 None
56 }
57 }
58
59 fn is_stale(&self, timeout: Duration) -> bool {
60 self.start_time.elapsed() > timeout
61 }
62}
63
64#[derive(Debug, Default, Clone)]
66pub struct Reassembler {
67 buffers: HashMap<(ClientId, u32), FragmentBuffer>,
69 timeout: Duration,
71}
72
73impl Reassembler {
74 #[must_use]
76 pub fn new() -> Self {
77 Self {
78 buffers: HashMap::new(),
79 timeout: Duration::from_secs(5),
80 }
81 }
82
83 #[must_use]
85 pub fn with_timeout(mut self, timeout: Duration) -> Self {
86 self.timeout = timeout;
87 self
88 }
89
90 pub fn ingest(&mut self, client_id: ClientId, event: FragmentedEvent) -> Option<Vec<u8>> {
95 if event.total_fragments == 0 || event.total_fragments > crate::MAX_TOTAL_FRAGMENTS {
97 tracing::warn!(
98 "Rejecting fragment with invalid total_fragments: {}",
99 event.total_fragments
100 );
101 return None;
102 }
103
104 let key = (client_id, event.message_id);
105
106 let buffer = match self.buffers.entry(key) {
107 hash_map::Entry::Occupied(e) => e.into_mut(),
108 hash_map::Entry::Vacant(e) => match FragmentBuffer::new(event.total_fragments) {
109 Some(buf) => e.insert(buf),
110 None => return None,
111 },
112 };
113
114 if buffer.total_fragments != event.total_fragments {
116 tracing::warn!(
117 "Fragment mismatch for message_id {}: expected {}, got {}",
118 event.message_id,
119 buffer.total_fragments,
120 event.total_fragments
121 );
122 return None;
123 }
124
125 let result = buffer.add(event.fragment_index, event.payload);
126
127 if result.is_some() {
128 self.buffers.remove(&key);
129 }
130
131 result
132 }
133
134 pub fn prune(&mut self) {
136 self.buffers
137 .retain(|_, buffer| !buffer.is_stale(self.timeout));
138 }
139
140 #[deprecated(since = "0.2.4", note = "Renamed to ingest() for consistency")]
142 pub fn add(&mut self, client_id: ClientId, event: FragmentedEvent) -> Option<Vec<u8>> {
143 self.ingest(client_id, event)
144 }
145
146 #[deprecated(since = "0.2.4", note = "Renamed to prune() for consistency")]
148 pub fn cleanup(&mut self) {
149 self.prune();
150 }
151}
152
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building one fragment of message `message_id`.
    fn fragment(
        message_id: u32,
        fragment_index: u16,
        total_fragments: u16,
        payload: Vec<u8>,
    ) -> FragmentedEvent {
        FragmentedEvent {
            message_id,
            fragment_index,
            total_fragments,
            payload,
        }
    }

    #[test]
    fn test_reassembly_ordered() {
        let client = ClientId(1);
        let mut reassembler = Reassembler::new();

        let head = fragment(100, 0, 2, vec![1, 2]);
        let tail = fragment(100, 1, 2, vec![3, 4]);

        assert!(reassembler.ingest(client, head).is_none());
        assert_eq!(reassembler.ingest(client, tail), Some(vec![1, 2, 3, 4]));
    }

    #[test]
    fn test_reassembly_out_of_order() {
        let client = ClientId(1);
        let mut reassembler = Reassembler::new();
        let message = 101;

        // Deliver the last piece first, then the first, and finish with the middle.
        assert!(reassembler
            .ingest(client, fragment(message, 2, 3, vec![3]))
            .is_none());
        assert!(reassembler
            .ingest(client, fragment(message, 0, 3, vec![1]))
            .is_none());
        assert_eq!(
            reassembler.ingest(client, fragment(message, 1, 3, vec![2])),
            Some(vec![1, 2, 3])
        );
    }

    #[test]
    fn test_cleanup() {
        let client = ClientId(1);
        let mut reassembler = Reassembler::new().with_timeout(Duration::from_millis(10));

        // Leave message 102 half-delivered so a buffer remains pending.
        reassembler.ingest(client, fragment(102, 0, 2, vec![1]));

        std::thread::sleep(Duration::from_millis(20));
        reassembler.prune();

        assert!(reassembler.buffers.is_empty());
    }
}