Skip to main content

aetheris_protocol/
reassembler.rs

//! Logic for reassembling fragmented network messages.
2
3use std::collections::{HashMap, hash_map};
4use std::time::{Duration, Instant};
5
6use crate::events::FragmentedEvent;
7use crate::types::ClientId;
8
9/// Buffers fragments for a single larger message from a specific client.
10#[derive(Debug, Clone)]
11struct FragmentBuffer {
12    /// When the first fragment of this message was received.
13    start_time: Instant,
14    /// Total number of fragments expected.
15    total_fragments: u16,
16    /// Fragments received so far.
17    fragments: Vec<Option<Vec<u8>>>,
18    /// Number of fragments currently present in the buffer.
19    count: u16,
20}
21
22impl FragmentBuffer {
23    fn new(total_fragments: u16) -> Option<Self> {
24        if total_fragments == 0 || total_fragments > crate::MAX_TOTAL_FRAGMENTS {
25            return None;
26        }
27
28        Some(Self {
29            start_time: Instant::now(),
30            total_fragments,
31            fragments: vec![None; total_fragments as usize],
32            count: 0,
33        })
34    }
35
36    fn add(&mut self, index: u16, payload: Vec<u8>) -> Option<Vec<u8>> {
37        let idx = index as usize;
38        if idx >= self.fragments.len() {
39            return None;
40        }
41
42        if self.fragments[idx].is_none() {
43            self.fragments[idx] = Some(payload);
44            self.count += 1;
45        }
46
47        if self.count == self.total_fragments {
48            let mut full_payload = Vec::new();
49            for frag in self.fragments.drain(..) {
50                full_payload.extend(frag.unwrap());
51            }
52            Some(full_payload)
53        } else {
54            None
55        }
56    }
57
58    fn is_stale(&self, timeout: Duration) -> bool {
59        self.start_time.elapsed() > timeout
60    }
61}
62
63/// A stateful reassembler that tracks fragmented messages from multiple clients.
64#[derive(Debug, Default, Clone)]
65pub struct Reassembler {
66    /// `message_id` -> buffer
67    buffers: HashMap<(ClientId, u32), FragmentBuffer>,
68    /// How long to keep fragments before discarding.
69    timeout: Duration,
70}
71
72impl Reassembler {
73    /// Creates a new reassembler with a default timeout of 5 seconds.
74    #[must_use]
75    pub fn new() -> Self {
76        Self {
77            buffers: HashMap::new(),
78            timeout: Duration::from_secs(5),
79        }
80    }
81
82    /// Sets a custom timeout for message reassembly.
83    #[must_use]
84    pub fn with_timeout(mut self, timeout: Duration) -> Self {
85        self.timeout = timeout;
86        self
87    }
88
89    /// Ingests a fragment into the reassembler.
90    ///
91    /// Returns the full reassembled message if this was the last fragment,
92    /// otherwise returns `None`.
93    pub fn ingest(&mut self, client_id: ClientId, event: FragmentedEvent) -> Option<Vec<u8>> {
94        // Security check: ensure total_fragments is valid from untrusted input
95        if event.total_fragments == 0 || event.total_fragments > crate::MAX_TOTAL_FRAGMENTS {
96            tracing::warn!(
97                "Rejecting fragment with invalid total_fragments: {}",
98                event.total_fragments
99            );
100            return None;
101        }
102
103        let key = (client_id, event.message_id);
104
105        let buffer = match self.buffers.entry(key) {
106            hash_map::Entry::Occupied(e) => e.into_mut(),
107            hash_map::Entry::Vacant(e) => match FragmentBuffer::new(event.total_fragments) {
108                Some(buf) => e.insert(buf),
109                None => return None,
110            },
111        };
112
113        // Safety check: ensure total_fragments matches what we original expected for this message_id
114        if buffer.total_fragments != event.total_fragments {
115            tracing::warn!(
116                "Fragment mismatch for message_id {}: expected {}, got {}",
117                event.message_id,
118                buffer.total_fragments,
119                event.total_fragments
120            );
121            return None;
122        }
123
124        let result = buffer.add(event.fragment_index, event.payload);
125
126        if result.is_some() {
127            self.buffers.remove(&key);
128        }
129
130        result
131    }
132
133    /// Discards messages that have haven't been completed within the timeout.
134    pub fn prune(&mut self) {
135        self.buffers
136            .retain(|_, buffer| !buffer.is_stale(self.timeout));
137    }
138
139    /// **DEPRECATED**: Use `ingest()` instead.
140    #[deprecated(since = "0.2.4", note = "Renamed to ingest() for consistency")]
141    pub fn add(&mut self, client_id: ClientId, event: FragmentedEvent) -> Option<Vec<u8>> {
142        self.ingest(client_id, event)
143    }
144
145    /// **DEPRECATED**: Use `prune()` instead.
146    #[deprecated(since = "0.2.4", note = "Renamed to prune() for consistency")]
147    pub fn cleanup(&mut self) {
148        self.prune();
149    }
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `FragmentedEvent` in the tests below.
    fn fragment(
        message_id: u32,
        fragment_index: u16,
        total_fragments: u16,
        payload: &[u8],
    ) -> FragmentedEvent {
        FragmentedEvent {
            message_id,
            fragment_index,
            total_fragments,
            payload: payload.to_vec(),
        }
    }

    #[test]
    fn test_reassembly_ordered() {
        let client = ClientId(1);
        let mut reassembler = Reassembler::new();

        let first = reassembler.ingest(client, fragment(100, 0, 2, &[1, 2]));
        assert!(first.is_none());

        let last = reassembler.ingest(client, fragment(100, 1, 2, &[3, 4]));
        assert_eq!(last, Some(vec![1, 2, 3, 4]));
    }

    #[test]
    fn test_reassembly_out_of_order() {
        let client = ClientId(1);
        let message_id = 101;
        let mut reassembler = Reassembler::new();

        // Tail first, then head: neither completes the message on its own.
        for (index, byte) in [(2_u16, 3_u8), (0, 1)] {
            let partial = reassembler.ingest(client, fragment(message_id, index, 3, &[byte]));
            assert!(partial.is_none());
        }

        // The middle fragment fills the last gap.
        let result = reassembler.ingest(client, fragment(message_id, 1, 3, &[2]));
        assert_eq!(result, Some(vec![1, 2, 3]));
    }

    #[test]
    fn test_cleanup() {
        let client = ClientId(1);
        let mut reassembler = Reassembler::new().with_timeout(Duration::from_millis(10));

        // Only half of the message ever arrives, so it has to age out.
        reassembler.ingest(client, fragment(102, 0, 2, &[1]));
        std::thread::sleep(Duration::from_millis(20));

        reassembler.prune();
        assert!(
            reassembler.buffers.is_empty(),
            "stale partial message was not pruned"
        );
    }
}