Skip to main content

aetheris_protocol/
reassembler.rs

1//! Logic for reassembling fragmented network messages.
2
3use std::collections::{HashMap, hash_map};
4use std::time::Duration;
5use web_time::Instant;
6
7use crate::events::FragmentedEvent;
8use crate::types::ClientId;
9
10/// Buffers fragments for a single larger message from a specific client.
11#[derive(Debug, Clone)]
12struct FragmentBuffer {
13    /// When the first fragment of this message was received.
14    start_time: Instant,
15    /// Total number of fragments expected.
16    total_fragments: u16,
17    /// Fragments received so far.
18    fragments: Vec<Option<Vec<u8>>>,
19    /// Number of fragments currently present in the buffer.
20    count: u16,
21}
22
23impl FragmentBuffer {
24    fn new(total_fragments: u16) -> Option<Self> {
25        if total_fragments == 0 || total_fragments > crate::MAX_TOTAL_FRAGMENTS {
26            return None;
27        }
28
29        Some(Self {
30            start_time: Instant::now(),
31            total_fragments,
32            fragments: vec![None; total_fragments as usize],
33            count: 0,
34        })
35    }
36
37    fn add(&mut self, index: u16, payload: Vec<u8>) -> Option<Vec<u8>> {
38        let idx = index as usize;
39        if idx >= self.fragments.len() {
40            return None;
41        }
42
43        if self.fragments[idx].is_none() {
44            self.fragments[idx] = Some(payload);
45            self.count += 1;
46        }
47
48        if self.count == self.total_fragments {
49            let mut full_payload = Vec::new();
50            for frag in self.fragments.drain(..) {
51                full_payload.extend(frag.unwrap());
52            }
53            Some(full_payload)
54        } else {
55            None
56        }
57    }
58
59    fn is_stale(&self, timeout: Duration) -> bool {
60        self.start_time.elapsed() > timeout
61    }
62}
63
64/// A stateful reassembler that tracks fragmented messages from multiple clients.
65#[derive(Debug, Default, Clone)]
66pub struct Reassembler {
67    /// `message_id` -> buffer
68    buffers: HashMap<(ClientId, u32), FragmentBuffer>,
69    /// How long to keep fragments before discarding.
70    timeout: Duration,
71}
72
73impl Reassembler {
74    /// Creates a new reassembler with a default timeout of 5 seconds.
75    #[must_use]
76    pub fn new() -> Self {
77        Self {
78            buffers: HashMap::new(),
79            timeout: Duration::from_secs(5),
80        }
81    }
82
83    /// Sets a custom timeout for message reassembly.
84    #[must_use]
85    pub fn with_timeout(mut self, timeout: Duration) -> Self {
86        self.timeout = timeout;
87        self
88    }
89
90    /// Ingests a fragment into the reassembler.
91    ///
92    /// Returns the full reassembled message if this was the last fragment,
93    /// otherwise returns `None`.
94    pub fn ingest(&mut self, client_id: ClientId, event: FragmentedEvent) -> Option<Vec<u8>> {
95        // Security check: ensure total_fragments is valid from untrusted input
96        if event.total_fragments == 0 || event.total_fragments > crate::MAX_TOTAL_FRAGMENTS {
97            tracing::warn!(
98                "Rejecting fragment with invalid total_fragments: {}",
99                event.total_fragments
100            );
101            return None;
102        }
103
104        let key = (client_id, event.message_id);
105
106        let buffer = match self.buffers.entry(key) {
107            hash_map::Entry::Occupied(e) => e.into_mut(),
108            hash_map::Entry::Vacant(e) => match FragmentBuffer::new(event.total_fragments) {
109                Some(buf) => e.insert(buf),
110                None => return None,
111            },
112        };
113
114        // Safety check: ensure total_fragments matches what we original expected for this message_id
115        if buffer.total_fragments != event.total_fragments {
116            tracing::warn!(
117                "Fragment mismatch for message_id {}: expected {}, got {}",
118                event.message_id,
119                buffer.total_fragments,
120                event.total_fragments
121            );
122            return None;
123        }
124
125        let result = buffer.add(event.fragment_index, event.payload);
126
127        if result.is_some() {
128            self.buffers.remove(&key);
129        }
130
131        result
132    }
133
134    /// Discards messages that have haven't been completed within the timeout.
135    pub fn prune(&mut self) {
136        self.buffers
137            .retain(|_, buffer| !buffer.is_stale(self.timeout));
138    }
139
140    /// **DEPRECATED**: Use `ingest()` instead.
141    #[deprecated(since = "0.2.4", note = "Renamed to ingest() for consistency")]
142    pub fn add(&mut self, client_id: ClientId, event: FragmentedEvent) -> Option<Vec<u8>> {
143        self.ingest(client_id, event)
144    }
145
146    /// **DEPRECATED**: Use `prune()` instead.
147    #[deprecated(since = "0.2.4", note = "Renamed to prune() for consistency")]
148    pub fn cleanup(&mut self) {
149        self.prune();
150    }
151}
152
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds one fragment of message `message_id` carrying `payload`.
    fn fragment(
        message_id: u32,
        fragment_index: u16,
        total_fragments: u16,
        payload: &[u8],
    ) -> FragmentedEvent {
        FragmentedEvent {
            message_id,
            fragment_index,
            total_fragments,
            payload: payload.to_vec(),
        }
    }

    /// Fragments delivered in index order yield the concatenated payload on the last one.
    #[test]
    fn test_reassembly_ordered() {
        let client = ClientId(1);
        let mut reassembler = Reassembler::new();

        let first = reassembler.ingest(client, fragment(100, 0, 2, &[1, 2]));
        assert!(first.is_none());

        let message = reassembler
            .ingest(client, fragment(100, 1, 2, &[3, 4]))
            .unwrap();
        assert_eq!(message, vec![1, 2, 3, 4]);
    }

    /// Arrival order does not matter: the payload is assembled by fragment index.
    #[test]
    fn test_reassembly_out_of_order() {
        let client = ClientId(1);
        let mut reassembler = Reassembler::new();

        // Deliver the last fragment first, then the first; neither completes the message.
        for (index, byte) in [(2u16, 3u8), (0, 1)] {
            let partial = reassembler.ingest(client, fragment(101, index, 3, &[byte]));
            assert!(partial.is_none());
        }

        let message = reassembler
            .ingest(client, fragment(101, 1, 3, &[2]))
            .unwrap();
        assert_eq!(message, vec![1, 2, 3]);
    }

    /// An incomplete message is dropped by `prune` once its timeout has passed.
    #[test]
    fn test_cleanup() {
        let client = ClientId(1);
        let mut reassembler = Reassembler::new().with_timeout(Duration::from_millis(10));

        reassembler.ingest(client, fragment(102, 0, 2, &[1]));

        // Wait past the 10 ms timeout so the pending buffer counts as stale.
        std::thread::sleep(Duration::from_millis(20));
        reassembler.prune();
        assert!(reassembler.buffers.is_empty());
    }
}