Skip to main content

theater_server/
fragmenting_codec.rs

1//! Fixed FragmentingCodec that works properly with Framed::split()
2
3use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
4use bytes::{Bytes, BytesMut};
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::io;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::sync::{Arc, Mutex};
10use std::time::{Duration, Instant};
11use tokio_util::codec::{Decoder, Encoder, LengthDelimitedCodec};
12use tracing::{debug, warn};
13
/// Maximum number of raw payload bytes carried by a single fragment (8MB).
///
/// Fragment payloads are base64-encoded and wrapped in a JSON envelope before being
/// framed, so the chunk size is kept well under the 32MB frame limit configured on
/// the inner length-delimited codec. Messages of at most this many bytes are sent
/// as a single `Complete` frame; larger ones are split into fragments.
const MAX_FRAGMENT_DATA_SIZE: usize = 8 * 1024 * 1024; // Reduced to 8MB to account for base64 + JSON overhead
17
/// How long an incomplete message may wait for its remaining fragments (30 seconds).
///
/// The age is measured from the creation of the partial message (i.e. when its first
/// fragment arrived). Expired entries are only removed by the periodic cleanup sweep,
/// so actual removal can lag behind this deadline.
const FRAGMENT_TIMEOUT: Duration = Duration::from_secs(30);
20
/// One chunk of a message that was too large to send as a single frame.
///
/// All fragments that share a `message_id` belong to the same message; the decoder
/// concatenates their decoded payloads in `fragment_index` order.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Fragment {
    /// Identifier shared by every fragment of one message (assigned by the sender)
    message_id: u64,
    /// Position of this fragment within the message (0-based)
    fragment_index: u32,
    /// Total number of fragments the message was split into
    total_fragments: u32,
    /// This fragment's payload bytes, base64-encoded (standard alphabet) so they
    /// travel as a single JSON string
    data: String,
}
33
/// Wire envelope for every frame: either a whole message or one fragment of a larger one.
///
/// The envelope is serialized with `serde_json` and then framed by the
/// length-delimited codec.
#[derive(Debug, Clone, Serialize, Deserialize)]
enum FrameType {
    /// A complete message that doesn't need fragmentation.
    /// NOTE(review): `Vec<u8>` uses serde's default sequence encoding, which in JSON
    /// is presumably one decimal number per byte (several times the raw size) —
    /// confirm before relying on the size headroom of this variant.
    Complete(Vec<u8>),
    /// A fragment of a larger message
    Fragment(Fragment),
}
42
43/// Partial message being reassembled
44#[derive(Debug)]
45struct PartialMessage {
46    /// When this partial message was first created
47    created_at: Instant,
48    /// Total number of fragments expected
49    total_fragments: u32,
50    /// Fragments received so far, indexed by fragment_index
51    fragments: HashMap<u32, Vec<u8>>,
52}
53
54impl PartialMessage {
55    fn new(total_fragments: u32) -> Self {
56        Self {
57            created_at: Instant::now(),
58            total_fragments,
59            fragments: HashMap::new(),
60        }
61    }
62
63    fn add_fragment(&mut self, index: u32, data: Vec<u8>) {
64        self.fragments.insert(index, data);
65    }
66
67    fn is_complete(&self) -> bool {
68        self.fragments.len() == self.total_fragments as usize
69    }
70
71    fn is_expired(&self) -> bool {
72        self.created_at.elapsed() > FRAGMENT_TIMEOUT
73    }
74
75    fn reassemble(self) -> io::Result<Vec<u8>> {
76        if !self.is_complete() {
77            return Err(io::Error::new(
78                io::ErrorKind::InvalidData,
79                "Cannot reassemble incomplete message",
80            ));
81        }
82
83        let mut result = Vec::new();
84
85        // Reassemble fragments in order
86        for i in 0..self.total_fragments {
87            if let Some(fragment_data) = self.fragments.get(&i) {
88                result.extend_from_slice(fragment_data);
89            } else {
90                return Err(io::Error::new(
91                    io::ErrorKind::InvalidData,
92                    format!("Missing fragment {}", i),
93                ));
94            }
95        }
96
97        Ok(result)
98    }
99}
100
101/// Shared state between encoder and decoder
102#[derive(Debug)]
103struct SharedState {
104    /// Counter for generating unique message IDs
105    next_message_id: AtomicU64,
106    /// Partial messages being reassembled (keyed by message_id)
107    partial_messages: Mutex<HashMap<u64, PartialMessage>>,
108    /// Last time we cleaned up expired partial messages
109    last_cleanup: Mutex<Instant>,
110}
111
112impl SharedState {
113    fn new() -> Self {
114        Self {
115            next_message_id: AtomicU64::new(1),
116            partial_messages: Mutex::new(HashMap::new()),
117            last_cleanup: Mutex::new(Instant::now()),
118        }
119    }
120
121    fn next_message_id(&self) -> u64 {
122        self.next_message_id.fetch_add(1, Ordering::Relaxed)
123    }
124
125    fn cleanup_expired(&self) {
126        // Only cleanup periodically to avoid overhead
127        {
128            let last_cleanup = self.last_cleanup.lock().unwrap();
129            if last_cleanup.elapsed() < Duration::from_secs(10) {
130                return;
131            }
132        }
133
134        let mut partial_messages = self.partial_messages.lock().unwrap();
135        let before_count = partial_messages.len();
136
137        partial_messages.retain(|message_id, partial| {
138            if partial.is_expired() {
139                warn!("Cleaning up expired partial message {}", message_id);
140                false
141            } else {
142                true
143            }
144        });
145
146        let cleaned = before_count - partial_messages.len();
147        if cleaned > 0 {
148            debug!("Cleaned up {} expired partial messages", cleaned);
149        }
150
151        *self.last_cleanup.lock().unwrap() = Instant::now();
152    }
153}
154
/// A codec that transparently handles message fragmentation.
///
/// Each message is wrapped in a JSON envelope and framed by a length-delimited
/// codec. Messages larger than `MAX_FRAGMENT_DATA_SIZE` are split into fragments
/// by the encoder and reassembled by the decoder.
#[derive(Debug)]
pub struct FragmentingCodec {
    /// Underlying length-delimited codec that frames each serialized envelope
    inner: LengthDelimitedCodec,
    /// Shared state between encoder and decoder (message-id counter and
    /// in-progress reassembly buffers); clones of the codec share the same instance
    shared_state: Arc<SharedState>,
}
163
164impl FragmentingCodec {
165    /// Create a new fragmenting codec with the same configuration as Theater's current setup
166    pub fn new() -> Self {
167        let mut inner = LengthDelimitedCodec::new();
168        inner.set_max_frame_length(32 * 1024 * 1024); // 32MB max frame
169
170        Self {
171            inner,
172            shared_state: Arc::new(SharedState::new()),
173        }
174    }
175
176    /// Fragment a large message into chunks
177    fn fragment_message(&self, data: &[u8]) -> Vec<Fragment> {
178        let message_id = self.shared_state.next_message_id();
179        let total_size = data.len();
180
181        // Use the defined chunk size constant
182        let chunk_size = MAX_FRAGMENT_DATA_SIZE;
183
184        // Calculate how many fragments we need
185        let total_fragments = total_size.div_ceil(chunk_size);
186
187        debug!(
188            "Fragmenting message {} into {} fragments (total size: {} bytes, chunk size: {} bytes)",
189            message_id, total_fragments, total_size, chunk_size
190        );
191
192        let mut fragments = Vec::new();
193
194        for (i, chunk) in data.chunks(chunk_size).enumerate() {
195            let fragment = Fragment {
196                message_id,
197                fragment_index: i as u32,
198                total_fragments: total_fragments as u32,
199                data: BASE64.encode(chunk),
200            };
201
202            // Debug: check serialized size to ensure it's under the frame limit
203            if let Ok(serialized) = serde_json::to_vec(&FrameType::Fragment(fragment.clone())) {
204                debug!("Fragment {} serialized size: {} bytes", i, serialized.len());
205                if serialized.len() > 31 * 1024 * 1024 {
206                    // Close to 32MB limit
207                    warn!("Fragment {} serialized size ({} bytes) is dangerously close to frame limit", i, serialized.len());
208                }
209            }
210
211            fragments.push(fragment);
212        }
213
214        fragments
215    }
216}
217
218impl Default for FragmentingCodec {
219    fn default() -> Self {
220        Self::new()
221    }
222}
223
224impl Clone for FragmentingCodec {
225    fn clone(&self) -> Self {
226        let mut inner = LengthDelimitedCodec::new();
227        inner.set_max_frame_length(32 * 1024 * 1024); // 32MB max frame - CRITICAL!
228
229        Self {
230            inner,
231            shared_state: Arc::clone(&self.shared_state),
232        }
233    }
234}
235
236impl Encoder<Bytes> for FragmentingCodec {
237    type Error = io::Error;
238
239    fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
240        let data = item.to_vec();
241
242        // Check if we need to fragment this message
243        if data.len() <= MAX_FRAGMENT_DATA_SIZE {
244            // Small message - send as complete
245            let frame = FrameType::Complete(data);
246            let serialized = serde_json::to_vec(&frame).map_err(|e| {
247                io::Error::new(
248                    io::ErrorKind::InvalidData,
249                    format!("Failed to serialize frame: {}", e),
250                )
251            })?;
252
253            self.inner.encode(Bytes::from(serialized), dst)
254        } else {
255            // Large message - fragment it
256            let fragments = self.fragment_message(&data);
257
258            // Encode each fragment into the destination buffer
259            for fragment in fragments {
260                let frame = FrameType::Fragment(fragment);
261                let serialized = serde_json::to_vec(&frame).map_err(|e| {
262                    io::Error::new(
263                        io::ErrorKind::InvalidData,
264                        format!("Failed to serialize fragment: {}", e),
265                    )
266                })?;
267
268                // Create a temporary buffer for this fragment
269                let mut fragment_buf = BytesMut::new();
270                self.inner
271                    .encode(Bytes::from(serialized), &mut fragment_buf)?;
272
273                // Append to the main destination buffer
274                dst.extend_from_slice(&fragment_buf);
275            }
276
277            Ok(())
278        }
279    }
280}
281
282impl Decoder for FragmentingCodec {
283    type Item = Bytes;
284    type Error = io::Error;
285
286    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
287        // Clean up expired messages periodically
288        self.shared_state.cleanup_expired();
289
290        // Try to decode a frame from the underlying codec
291        if let Some(frame_bytes) = self.inner.decode(src)? {
292            // Deserialize the frame
293            let frame: FrameType = serde_json::from_slice(&frame_bytes).map_err(|e| {
294                io::Error::new(
295                    io::ErrorKind::InvalidData,
296                    format!("Failed to deserialize frame: {}", e),
297                )
298            })?;
299
300            match frame {
301                FrameType::Complete(data) => {
302                    // Complete message - return immediately
303                    Ok(Some(Bytes::from(data)))
304                }
305                FrameType::Fragment(fragment) => {
306                    // Fragment - add to partial message
307                    let message_id = fragment.message_id;
308                    let fragment_index = fragment.fragment_index;
309                    let total_fragments = fragment.total_fragments;
310
311                    debug!(
312                        "Received fragment {}/{} for message {}",
313                        fragment_index + 1,
314                        total_fragments,
315                        message_id
316                    );
317
318                    // Decode the base64 data
319                    let fragment_data = BASE64.decode(&fragment.data).map_err(|e| {
320                        io::Error::new(
321                            io::ErrorKind::InvalidData,
322                            format!("Failed to decode fragment data: {}", e),
323                        )
324                    })?;
325
326                    // Get or create partial message
327                    let mut partial_messages = self.shared_state.partial_messages.lock().unwrap();
328                    let partial = partial_messages
329                        .entry(message_id)
330                        .or_insert_with(|| PartialMessage::new(total_fragments));
331
332                    // Add this fragment
333                    partial.add_fragment(fragment_index, fragment_data);
334
335                    // Check if message is complete
336                    if partial.is_complete() {
337                        debug!("Message {} is complete, reassembling", message_id);
338
339                        // Remove from partial messages and reassemble
340                        let partial = partial_messages.remove(&message_id).unwrap();
341                        drop(partial_messages); // Release the lock
342
343                        let complete_data = partial.reassemble()?;
344                        Ok(Some(Bytes::from(complete_data)))
345                    } else {
346                        // Still waiting for more fragments
347                        debug!(
348                            "Message {} still incomplete ({}/{} fragments)",
349                            message_id,
350                            partial.fragments.len(),
351                            total_fragments
352                        );
353                        Ok(None)
354                    }
355                }
356            }
357        } else {
358            // No complete frame available yet
359            Ok(None)
360        }
361    }
362}
363
/// Unit and round-trip tests for the fragmenting codec.
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{SinkExt, StreamExt};
    use tokio::io::duplex;
    use tokio_util::codec::{FramedRead, FramedWrite};

    /// A message below the fragmentation threshold round-trips unchanged.
    #[tokio::test]
    async fn test_small_message_no_fragmentation() {
        let (client, server) = duplex(1024);

        // Writer and reader use independent codecs, as two peers would.
        let codec_write = FragmentingCodec::new();
        let codec_read = FragmentingCodec::new();

        let mut writer = FramedWrite::new(client, codec_write);
        let mut reader = FramedRead::new(server, codec_read);

        let test_data = b"Hello, World!";

        // Send small message
        writer.send(Bytes::from(&test_data[..])).await.unwrap();
        drop(writer); // Close writer

        // Receive should get the same data
        let received = reader.next().await.unwrap().unwrap();
        assert_eq!(received.as_ref(), test_data);
    }

    /// A message just over the threshold is split into fragments and reassembled
    /// into the original bytes on the receiving side.
    #[tokio::test]
    async fn test_large_message_fragmentation() {
        // The duplex buffer must hold the whole encoded message, because the send
        // completes before anything is read from the other end.
        let (client, server) = duplex(64 * 1024 * 1024); // Large buffer

        let codec_write = FragmentingCodec::new();
        let codec_read = FragmentingCodec::new();

        let mut writer = FramedWrite::new(client, codec_write);
        let mut reader = FramedRead::new(server, codec_read);

        // Create a message larger than MAX_FRAGMENT_DATA_SIZE
        let test_data = vec![0xAB; MAX_FRAGMENT_DATA_SIZE + 1000];

        // Send large message
        match writer.send(Bytes::from(test_data.clone())).await {
            Ok(_) => println!("Successfully sent large message"),
            Err(e) => {
                println!("Error sending: {:?}", e);
                panic!("Failed to send: {}", e);
            }
        }
        drop(writer); // Close writer

        // Receive should get the same data
        let received = reader.next().await.unwrap().unwrap();
        assert_eq!(received.as_ref(), &test_data[..]);
    }

    /// `fragment_message` produces correctly numbered fragments whose decoded
    /// payloads concatenate back to the input.
    #[test]
    fn test_fragment_message() {
        let codec = FragmentingCodec::new();
        // 500 bytes over one chunk, so exactly two fragments are expected.
        let data = vec![0x42; MAX_FRAGMENT_DATA_SIZE + 500];

        let fragments = codec.fragment_message(&data);

        assert_eq!(fragments.len(), 2);
        assert_eq!(fragments[0].fragment_index, 0);
        assert_eq!(fragments[1].fragment_index, 1);
        assert_eq!(fragments[0].total_fragments, 2);
        assert_eq!(fragments[1].total_fragments, 2);
        assert_eq!(fragments[0].message_id, fragments[1].message_id);

        // Check data integrity
        let mut reassembled = Vec::new();
        let decoded_0 = BASE64.decode(&fragments[0].data).unwrap();
        let decoded_1 = BASE64.decode(&fragments[1].data).unwrap();
        reassembled.extend_from_slice(&decoded_0);
        reassembled.extend_from_slice(&decoded_1);
        assert_eq!(reassembled, data);
    }

    /// Fragments added out of order are reassembled in index order.
    #[test]
    fn test_partial_message_assembly() {
        let mut partial = PartialMessage::new(3);

        assert!(!partial.is_complete());

        // Indices 0 and 2 first; the message is still missing index 1.
        partial.add_fragment(0, vec![1, 2, 3]);
        partial.add_fragment(2, vec![7, 8, 9]);
        assert!(!partial.is_complete());

        partial.add_fragment(1, vec![4, 5, 6]);
        assert!(partial.is_complete());

        let reassembled = partial.reassemble().unwrap();
        assert_eq!(reassembled, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }
}
459}