Skip to main content

aetheris_protocol/
reassembler.rs

1//! Logic for reassembling fragmented network messages.
2
3use std::collections::{HashMap, hash_map};
4use std::time::{Duration, Instant};
5
6use crate::events::FragmentedEvent;
7use crate::types::ClientId;
8
9/// Buffers fragments for a single larger message from a specific client.
10#[derive(Debug, Clone)]
11struct FragmentBuffer {
12    /// When the first fragment of this message was received.
13    start_time: Instant,
14    /// Total number of fragments expected.
15    total_fragments: u16,
16    /// Fragments received so far.
17    fragments: Vec<Option<Vec<u8>>>,
18    /// Number of fragments currently present in the buffer.
19    count: u16,
20}
21
22impl FragmentBuffer {
23    fn new(total_fragments: u16) -> Option<Self> {
24        if total_fragments == 0 || total_fragments > crate::MAX_TOTAL_FRAGMENTS {
25            return None;
26        }
27
28        Some(Self {
29            start_time: Instant::now(),
30            total_fragments,
31            fragments: vec![None; total_fragments as usize],
32            count: 0,
33        })
34    }
35
36    fn add(&mut self, index: u16, payload: Vec<u8>) -> Option<Vec<u8>> {
37        let idx = index as usize;
38        if idx >= self.fragments.len() {
39            return None;
40        }
41
42        if self.fragments[idx].is_none() {
43            self.fragments[idx] = Some(payload);
44            self.count += 1;
45        }
46
47        if self.count == self.total_fragments {
48            let mut full_payload = Vec::new();
49            for frag in self.fragments.drain(..) {
50                full_payload.extend(frag.unwrap());
51            }
52            Some(full_payload)
53        } else {
54            None
55        }
56    }
57
58    fn is_stale(&self, timeout: Duration) -> bool {
59        self.start_time.elapsed() > timeout
60    }
61}
62
63/// A stateful reassembler that tracks fragmented messages from multiple clients.
64#[derive(Debug, Default, Clone)]
65pub struct Reassembler {
66    /// `message_id` -> buffer
67    buffers: HashMap<(ClientId, u32), FragmentBuffer>,
68    /// How long to keep fragments before discarding.
69    timeout: Duration,
70}
71
72impl Reassembler {
73    /// Creates a new reassembler with a default timeout of 5 seconds.
74    #[must_use]
75    pub fn new() -> Self {
76        Self {
77            buffers: HashMap::new(),
78            timeout: Duration::from_secs(5),
79        }
80    }
81
82    /// Sets a custom timeout for message reassembly.
83    #[must_use]
84    pub fn with_timeout(mut self, timeout: Duration) -> Self {
85        self.timeout = timeout;
86        self
87    }
88
89    /// Adds a fragment to the reassembler.
90    ///
91    /// Returns the full reassembled message if this was the last fragment,
92    /// otherwise returns `None`.
93    pub fn add(&mut self, client_id: ClientId, event: FragmentedEvent) -> Option<Vec<u8>> {
94        // Security check: ensure total_fragments is valid from untrusted input
95        if event.total_fragments == 0 || event.total_fragments > crate::MAX_TOTAL_FRAGMENTS {
96            tracing::warn!(
97                "Rejecting fragment with invalid total_fragments: {}",
98                event.total_fragments
99            );
100            return None;
101        }
102
103        let key = (client_id, event.message_id);
104
105        let buffer = match self.buffers.entry(key) {
106            hash_map::Entry::Occupied(e) => e.into_mut(),
107            hash_map::Entry::Vacant(e) => match FragmentBuffer::new(event.total_fragments) {
108                Some(buf) => e.insert(buf),
109                None => return None,
110            },
111        };
112
113        // Safety check: ensure total_fragments matches what we original expected for this message_id
114        if buffer.total_fragments != event.total_fragments {
115            tracing::warn!(
116                "Fragment mismatch for message_id {}: expected {}, got {}",
117                event.message_id,
118                buffer.total_fragments,
119                event.total_fragments
120            );
121            return None;
122        }
123
124        let result = buffer.add(event.fragment_index, event.payload);
125
126        if result.is_some() {
127            self.buffers.remove(&key);
128        }
129
130        result
131    }
132
133    /// Discards messages that have haven't been completed within the timeout.
134    pub fn cleanup(&mut self) {
135        self.buffers
136            .retain(|_, buffer| !buffer.is_stale(self.timeout));
137    }
138}
139
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds fragment `fragment_index` of `total_fragments` for message `message_id`.
    fn fragment(
        message_id: u32,
        fragment_index: u16,
        total_fragments: u16,
        payload: &[u8],
    ) -> FragmentedEvent {
        FragmentedEvent {
            message_id,
            fragment_index,
            total_fragments,
            payload: payload.to_vec(),
        }
    }

    #[test]
    fn test_reassembly_ordered() {
        let mut reassembler = Reassembler::new();
        let client = ClientId(1);

        assert!(reassembler.add(client, fragment(100, 0, 2, &[1, 2])).is_none());

        let message = reassembler
            .add(client, fragment(100, 1, 2, &[3, 4]))
            .expect("the second fragment completes the message");
        assert_eq!(message, vec![1, 2, 3, 4]);
    }

    #[test]
    fn test_reassembly_out_of_order() {
        let mut reassembler = Reassembler::new();
        let client = ClientId(1);

        // Deliver the last fragment first, then the first, then the middle one.
        assert!(reassembler.add(client, fragment(101, 2, 3, &[3])).is_none());
        assert!(reassembler.add(client, fragment(101, 0, 3, &[1])).is_none());

        let message = reassembler
            .add(client, fragment(101, 1, 3, &[2]))
            .expect("the final missing fragment completes the message");
        assert_eq!(message, vec![1, 2, 3]);
    }

    #[test]
    fn test_cleanup() {
        let mut reassembler = Reassembler::new().with_timeout(Duration::from_millis(10));

        // Only one of two fragments arrives, leaving the message pending.
        let _ = reassembler.add(ClientId(1), fragment(102, 0, 2, &[1]));

        std::thread::sleep(Duration::from_millis(20));
        reassembler.cleanup();
        assert!(reassembler.buffers.is_empty());
    }
}
219}