aetheris_protocol/
reassembler.rs1use std::collections::{HashMap, hash_map};
4use std::time::{Duration, Instant};
5
6use crate::events::FragmentedEvent;
7use crate::types::ClientId;
8
9#[derive(Debug, Clone)]
11struct FragmentBuffer {
12 start_time: Instant,
14 total_fragments: u16,
16 fragments: Vec<Option<Vec<u8>>>,
18 count: u16,
20}
21
22impl FragmentBuffer {
23 fn new(total_fragments: u16) -> Option<Self> {
24 if total_fragments == 0 || total_fragments > crate::MAX_TOTAL_FRAGMENTS {
25 return None;
26 }
27
28 Some(Self {
29 start_time: Instant::now(),
30 total_fragments,
31 fragments: vec![None; total_fragments as usize],
32 count: 0,
33 })
34 }
35
36 fn add(&mut self, index: u16, payload: Vec<u8>) -> Option<Vec<u8>> {
37 let idx = index as usize;
38 if idx >= self.fragments.len() {
39 return None;
40 }
41
42 if self.fragments[idx].is_none() {
43 self.fragments[idx] = Some(payload);
44 self.count += 1;
45 }
46
47 if self.count == self.total_fragments {
48 let mut full_payload = Vec::new();
49 for frag in self.fragments.drain(..) {
50 full_payload.extend(frag.unwrap());
51 }
52 Some(full_payload)
53 } else {
54 None
55 }
56 }
57
58 fn is_stale(&self, timeout: Duration) -> bool {
59 self.start_time.elapsed() > timeout
60 }
61}
62
63#[derive(Debug, Default, Clone)]
65pub struct Reassembler {
66 buffers: HashMap<(ClientId, u32), FragmentBuffer>,
68 timeout: Duration,
70}
71
72impl Reassembler {
73 #[must_use]
75 pub fn new() -> Self {
76 Self {
77 buffers: HashMap::new(),
78 timeout: Duration::from_secs(5),
79 }
80 }
81
82 #[must_use]
84 pub fn with_timeout(mut self, timeout: Duration) -> Self {
85 self.timeout = timeout;
86 self
87 }
88
89 pub fn ingest(&mut self, client_id: ClientId, event: FragmentedEvent) -> Option<Vec<u8>> {
94 if event.total_fragments == 0 || event.total_fragments > crate::MAX_TOTAL_FRAGMENTS {
96 tracing::warn!(
97 "Rejecting fragment with invalid total_fragments: {}",
98 event.total_fragments
99 );
100 return None;
101 }
102
103 let key = (client_id, event.message_id);
104
105 let buffer = match self.buffers.entry(key) {
106 hash_map::Entry::Occupied(e) => e.into_mut(),
107 hash_map::Entry::Vacant(e) => match FragmentBuffer::new(event.total_fragments) {
108 Some(buf) => e.insert(buf),
109 None => return None,
110 },
111 };
112
113 if buffer.total_fragments != event.total_fragments {
115 tracing::warn!(
116 "Fragment mismatch for message_id {}: expected {}, got {}",
117 event.message_id,
118 buffer.total_fragments,
119 event.total_fragments
120 );
121 return None;
122 }
123
124 let result = buffer.add(event.fragment_index, event.payload);
125
126 if result.is_some() {
127 self.buffers.remove(&key);
128 }
129
130 result
131 }
132
133 pub fn prune(&mut self) {
135 self.buffers
136 .retain(|_, buffer| !buffer.is_stale(self.timeout));
137 }
138
139 #[deprecated(since = "0.2.4", note = "Renamed to ingest() for consistency")]
141 pub fn add(&mut self, client_id: ClientId, event: FragmentedEvent) -> Option<Vec<u8>> {
142 self.ingest(client_id, event)
143 }
144
145 #[deprecated(since = "0.2.4", note = "Renamed to prune() for consistency")]
147 pub fn cleanup(&mut self) {
148 self.prune();
149 }
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fragment event; keeps the individual tests short.
    fn fragment(
        message_id: u32,
        fragment_index: u16,
        total_fragments: u16,
        payload: Vec<u8>,
    ) -> FragmentedEvent {
        FragmentedEvent {
            message_id,
            fragment_index,
            total_fragments,
            payload,
        }
    }

    #[test]
    fn test_reassembly_ordered() {
        let mut reassembler = Reassembler::new();
        let cid = ClientId(1);
        let mid = 100;

        let f1 = fragment(mid, 0, 2, vec![1, 2]);
        let f2 = fragment(mid, 1, 2, vec![3, 4]);

        assert!(reassembler.ingest(cid, f1).is_none());
        let result = reassembler.ingest(cid, f2).unwrap();
        assert_eq!(result, vec![1, 2, 3, 4]);
    }

    #[test]
    fn test_reassembly_out_of_order() {
        let mut reassembler = Reassembler::new();
        let cid = ClientId(1);
        let mid = 101;

        let f1 = fragment(mid, 0, 3, vec![1]);
        let f2 = fragment(mid, 1, 3, vec![2]);
        let f3 = fragment(mid, 2, 3, vec![3]);

        assert!(reassembler.ingest(cid, f3).is_none());
        assert!(reassembler.ingest(cid, f1).is_none());
        let result = reassembler.ingest(cid, f2).unwrap();
        assert_eq!(result, vec![1, 2, 3]);
    }

    #[test]
    fn test_cleanup() {
        let mut reassembler = Reassembler::new().with_timeout(Duration::from_millis(10));
        let cid = ClientId(1);
        let mid = 102;

        reassembler.ingest(cid, fragment(mid, 0, 2, vec![1]));

        std::thread::sleep(Duration::from_millis(20));
        reassembler.prune();
        assert!(reassembler.buffers.is_empty());
    }

    #[test]
    fn test_duplicate_fragment_keeps_first_payload() {
        let mut reassembler = Reassembler::new();
        let cid = ClientId(1);

        assert!(reassembler.ingest(cid, fragment(103, 0, 2, vec![1])).is_none());
        assert!(reassembler.ingest(cid, fragment(103, 0, 2, vec![9])).is_none());
        assert_eq!(
            reassembler.ingest(cid, fragment(103, 1, 2, vec![2])),
            Some(vec![1, 2])
        );
        // Completed messages release their buffer.
        assert!(reassembler.buffers.is_empty());
    }

    #[test]
    fn test_mismatched_total_fragments_rejected() {
        let mut reassembler = Reassembler::new();
        let cid = ClientId(1);

        assert!(reassembler.ingest(cid, fragment(104, 0, 2, vec![1])).is_none());
        // Same message id but a different announced total: ignored.
        assert!(reassembler.ingest(cid, fragment(104, 1, 3, vec![2])).is_none());
        assert_eq!(
            reassembler.ingest(cid, fragment(104, 1, 2, vec![3])),
            Some(vec![1, 3])
        );
    }

    #[test]
    fn test_zero_total_fragments_rejected() {
        let mut reassembler = Reassembler::new();

        assert!(reassembler
            .ingest(ClientId(1), fragment(105, 0, 0, vec![1]))
            .is_none());
        assert!(reassembler.buffers.is_empty());
    }

    #[test]
    fn test_clients_are_isolated() {
        let mut reassembler = Reassembler::new();

        // Same message id from two clients must not be merged.
        assert!(reassembler
            .ingest(ClientId(1), fragment(106, 0, 2, vec![1]))
            .is_none());
        assert!(reassembler
            .ingest(ClientId(2), fragment(106, 1, 2, vec![2]))
            .is_none());
        assert_eq!(reassembler.buffers.len(), 2);
    }
}