Skip to main content

procwire_client/protocol/
frame_buffer.rs

//! Frame buffer for accumulating partial reads.
//!
//! Uses `bytes::BytesMut` so that complete payloads can be split off and frozen
//! without copying. A small state machine handles fragmented frames:
//! - `WaitingForHeader`: need at least `HEADER_SIZE` (11) bytes to decode a header
//! - `WaitingForPayload`: header parsed, need `N` more payload bytes
//!
//! # Example
//!
//! ```ignore
//! use procwire_client::protocol::FrameBuffer;
//!
//! let mut buffer = FrameBuffer::new();
//!
//! // Data arrives in chunks from the socket
//! let chunk = vec![0u8; 100];
//! let frames = buffer.push(&chunk).unwrap();
//!
//! for frame in frames {
//!     println!("Got frame with method_id: {}", frame.method_id());
//! }
//! ```
23
24use bytes::{Bytes, BytesMut};
25
26use super::wire_format::{Header, DEFAULT_MAX_PAYLOAD_SIZE, HEADER_SIZE};
27use super::Frame;
28use crate::error::{ProcwireError, Result};
29
30/// State machine for frame parsing.
31#[derive(Debug, Clone)]
32enum State {
33    /// Waiting for complete header (need 11 bytes).
34    WaitingForHeader,
35    /// Header parsed, waiting for payload bytes.
36    WaitingForPayload { header: Header, remaining: u32 },
37}
38
39/// Buffer for accumulating incoming bytes and extracting complete frames.
40///
41/// Uses a state machine to handle partial reads efficiently.
42/// All data is stored in a single `BytesMut` buffer to minimize allocations.
43pub struct FrameBuffer {
44    /// Accumulated bytes from socket reads.
45    buffer: BytesMut,
46    /// Current parsing state.
47    state: State,
48    /// Maximum allowed payload size.
49    max_payload_size: u32,
50}
51
52impl FrameBuffer {
53    /// Create a new frame buffer with default settings.
54    ///
55    /// Default capacity: 64KB, max payload: 1GB.
56    pub fn new() -> Self {
57        Self {
58            buffer: BytesMut::with_capacity(64 * 1024),
59            state: State::WaitingForHeader,
60            max_payload_size: DEFAULT_MAX_PAYLOAD_SIZE,
61        }
62    }
63
64    /// Create a new frame buffer with custom max payload size.
65    pub fn with_max_payload(max_payload_size: u32) -> Self {
66        Self {
67            buffer: BytesMut::with_capacity(64 * 1024),
68            state: State::WaitingForHeader,
69            max_payload_size,
70        }
71    }
72
73    /// Create a new frame buffer with custom capacity and max payload.
74    pub fn with_capacity_and_max_payload(capacity: usize, max_payload_size: u32) -> Self {
75        Self {
76            buffer: BytesMut::with_capacity(capacity),
77            state: State::WaitingForHeader,
78            max_payload_size,
79        }
80    }
81
82    /// Push data into the buffer and extract all complete frames.
83    ///
84    /// This is the main API for processing incoming data from the socket.
85    /// Returns a vector of complete frames. If data is fragmented,
86    /// partial data is buffered internally for the next push.
87    ///
88    /// # Arguments
89    ///
90    /// * `data` - Raw bytes from socket read
91    ///
92    /// # Returns
93    ///
94    /// Vector of complete frames (may be empty if still waiting for data).
95    ///
96    /// # Errors
97    ///
98    /// Returns error if payload exceeds max_payload_size.
99    pub fn push(&mut self, data: &[u8]) -> Result<Vec<Frame>> {
100        // Single allocation to add data to buffer
101        self.buffer.extend_from_slice(data);
102
103        let mut frames = Vec::new();
104
105        // Process as many complete frames as possible
106        while let Some(frame) = self.try_extract_one()? {
107            frames.push(frame);
108        }
109
110        Ok(frames)
111    }
112
113    /// Try to extract a single frame from the buffer.
114    ///
115    /// Returns:
116    /// - `Ok(Some(frame))` if a complete frame was extracted
117    /// - `Ok(None)` if more data is needed
118    /// - `Err(...)` if protocol violation (e.g., payload too large)
119    fn try_extract_one(&mut self) -> Result<Option<Frame>> {
120        match &self.state {
121            State::WaitingForHeader => {
122                if self.buffer.len() < HEADER_SIZE {
123                    return Ok(None);
124                }
125
126                // Parse header (peek, don't consume yet)
127                let header =
128                    Header::decode(&self.buffer[..HEADER_SIZE]).expect("Buffer has enough bytes");
129
130                // Validate payload size
131                if header.payload_length > self.max_payload_size {
132                    return Err(ProcwireError::Protocol(format!(
133                        "Payload size {} exceeds maximum {}",
134                        header.payload_length, self.max_payload_size
135                    )));
136                }
137
138                // Consume header bytes
139                let _ = self.buffer.split_to(HEADER_SIZE);
140
141                if header.payload_length == 0 {
142                    // Empty payload, frame is complete
143                    return Ok(Some(Frame::new(header, Bytes::new())));
144                }
145
146                // Transition to waiting for payload
147                self.state = State::WaitingForPayload {
148                    header,
149                    remaining: header.payload_length,
150                };
151
152                // Try to get payload immediately
153                self.try_extract_one()
154            }
155
156            State::WaitingForPayload { header, remaining } => {
157                let remaining = *remaining as usize;
158
159                if self.buffer.len() < remaining {
160                    return Ok(None);
161                }
162
163                // Extract payload (zero-copy freeze)
164                let payload = self.buffer.split_to(remaining).freeze();
165                let header = *header;
166
167                // Reset state for next frame
168                self.state = State::WaitingForHeader;
169
170                Ok(Some(Frame::new(header, payload)))
171            }
172        }
173    }
174
175    /// Legacy method - try to extract a single frame.
176    ///
177    /// Prefer using `push()` which handles multiple frames efficiently.
178    #[deprecated(note = "Use push() instead for proper multi-frame handling")]
179    pub fn try_extract(&mut self) -> Option<Frame> {
180        self.try_extract_one().ok().flatten()
181    }
182
183    /// Append data to the buffer without extracting frames.
184    ///
185    /// Prefer using `push()` which does extend + extract in one call.
186    pub fn extend(&mut self, data: &[u8]) {
187        self.buffer.extend_from_slice(data);
188    }
189
190    /// Get the number of buffered bytes.
191    pub fn len(&self) -> usize {
192        self.buffer.len()
193    }
194
195    /// Check if the buffer is empty.
196    pub fn is_empty(&self) -> bool {
197        self.buffer.is_empty()
198    }
199
200    /// Clear the buffer and reset state.
201    pub fn clear(&mut self) {
202        self.buffer.clear();
203        self.state = State::WaitingForHeader;
204    }
205
206    /// Get the current state for debugging.
207    #[cfg(test)]
208    fn state_name(&self) -> &'static str {
209        match &self.state {
210            State::WaitingForHeader => "WaitingForHeader",
211            State::WaitingForPayload { .. } => "WaitingForPayload",
212        }
213    }
214}
215
216impl Default for FrameBuffer {
217    fn default() -> Self {
218        Self::new()
219    }
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::flags;

    /// Helper to create a valid frame as bytes.
    fn make_frame_bytes(method_id: u16, flags: u8, request_id: u32, payload: &[u8]) -> Vec<u8> {
        let header = Header::new(method_id, flags, request_id, payload.len() as u32);
        let mut bytes = header.encode().to_vec();
        bytes.extend_from_slice(payload);
        bytes
    }

    #[test]
    fn test_single_complete_frame() {
        let mut buffer = FrameBuffer::new();
        let frame_bytes = make_frame_bytes(1, flags::RESPONSE, 42, b"hello");

        let frames = buffer.push(&frame_bytes).unwrap();

        assert_eq!(frames.len(), 1);
        assert_eq!(frames[0].method_id(), 1);
        assert_eq!(frames[0].request_id(), 42);
        assert_eq!(&frames[0].payload[..], b"hello");
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_multiple_frames_in_one_push() {
        let mut buffer = FrameBuffer::new();

        let frame1 = make_frame_bytes(1, 0, 1, b"first");
        let frame2 = make_frame_bytes(2, 0, 2, b"second");
        let frame3 = make_frame_bytes(3, 0, 3, b"third");

        let mut combined = Vec::new();
        combined.extend_from_slice(&frame1);
        combined.extend_from_slice(&frame2);
        combined.extend_from_slice(&frame3);

        let frames = buffer.push(&combined).unwrap();

        assert_eq!(frames.len(), 3);
        assert_eq!(frames[0].method_id(), 1);
        assert_eq!(frames[1].method_id(), 2);
        assert_eq!(frames[2].method_id(), 3);
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_fragmented_header() {
        let mut buffer = FrameBuffer::new();
        let frame_bytes = make_frame_bytes(1, 0, 42, b"test");

        // Push first 5 bytes of header
        let frames = buffer.push(&frame_bytes[..5]).unwrap();
        assert!(frames.is_empty());
        assert_eq!(buffer.state_name(), "WaitingForHeader");

        // Push rest of header and payload
        let frames = buffer.push(&frame_bytes[5..]).unwrap();
        assert_eq!(frames.len(), 1);
        assert_eq!(frames[0].method_id(), 1);
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_fragmented_payload() {
        let mut buffer = FrameBuffer::new();
        let payload = b"this is a longer payload that will be fragmented";
        let frame_bytes = make_frame_bytes(1, 0, 42, payload);

        // Push header + partial payload
        let partial_len = HEADER_SIZE + 10;
        let frames = buffer.push(&frame_bytes[..partial_len]).unwrap();
        assert!(frames.is_empty());
        assert_eq!(buffer.state_name(), "WaitingForPayload");

        // Push rest of payload
        let frames = buffer.push(&frame_bytes[partial_len..]).unwrap();
        assert_eq!(frames.len(), 1);
        assert_eq!(&frames[0].payload[..], payload);
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_empty_payload() {
        let mut buffer = FrameBuffer::new();
        let frame_bytes = make_frame_bytes(1, 0, 42, b"");

        let frames = buffer.push(&frame_bytes).unwrap();

        assert_eq!(frames.len(), 1);
        assert!(frames[0].payload.is_empty());
        assert_eq!(frames[0].header.payload_length, 0);
    }

    #[test]
    fn test_consecutive_empty_payload_frames() {
        let mut buffer = FrameBuffer::new();
        let mut data = make_frame_bytes(1, 0, 1, b"");
        data.extend_from_slice(&make_frame_bytes(2, 0, 2, b""));

        let frames = buffer.push(&data).unwrap();

        assert_eq!(frames.len(), 2);
        assert_eq!(frames[0].request_id(), 1);
        assert_eq!(frames[1].request_id(), 2);
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_large_payload() {
        let mut buffer = FrameBuffer::new();
        let payload = vec![0xAB; 1024 * 1024]; // 1MB
        let frame_bytes = make_frame_bytes(1, 0, 42, &payload);

        let frames = buffer.push(&frame_bytes).unwrap();

        assert_eq!(frames.len(), 1);
        assert_eq!(frames[0].payload.len(), 1024 * 1024);
        assert!(frames[0].payload.iter().all(|&b| b == 0xAB));
    }

    #[test]
    fn test_max_payload_validation() {
        let mut buffer = FrameBuffer::with_max_payload(100);

        // Create header claiming 1000 byte payload
        let header = Header::new(1, 0, 42, 1000);
        let header_bytes = header.encode();

        let result = buffer.push(&header_bytes);

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("exceeds maximum"));
    }

    #[test]
    fn test_payload_at_exact_max_is_accepted() {
        let mut buffer = FrameBuffer::with_max_payload(4);
        let frame_bytes = make_frame_bytes(7, 0, 9, b"abcd");

        let frames = buffer.push(&frame_bytes).unwrap();

        assert_eq!(frames.len(), 1);
        assert_eq!(&frames[0].payload[..], b"abcd");
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_oversized_header_is_sticky_until_clear() {
        let mut buffer = FrameBuffer::with_max_payload(100);
        let header = Header::new(1, 0, 42, 1000);
        let header_bytes = header.encode();

        assert!(buffer.push(&header_bytes).is_err());

        // The rejected header is not consumed, so the error repeats...
        assert_eq!(buffer.len(), HEADER_SIZE);
        assert_eq!(buffer.state_name(), "WaitingForHeader");
        assert!(buffer.push(&[]).is_err());

        // ...until the caller explicitly resets the buffer.
        buffer.clear();
        let frames = buffer.push(&make_frame_bytes(2, 0, 1, b"ok")).unwrap();
        assert_eq!(frames.len(), 1);
        assert_eq!(frames[0].method_id(), 2);
    }

    #[test]
    fn test_frame_with_all_header_fields() {
        let mut buffer = FrameBuffer::new();
        let frame_bytes = make_frame_bytes(0x1234, flags::STREAM_END_RESPONSE, 0xDEADBEEF, b"data");

        let frames = buffer.push(&frame_bytes).unwrap();

        assert_eq!(frames.len(), 1);
        let frame = &frames[0];
        assert_eq!(frame.method_id(), 0x1234);
        assert_eq!(frame.header.flags, flags::STREAM_END_RESPONSE);
        assert_eq!(frame.request_id(), 0xDEADBEEF);
        assert!(frame.is_stream());
        assert!(frame.is_stream_end());
    }

    #[test]
    fn test_clear_resets_state() {
        let mut buffer = FrameBuffer::new();

        // Push partial header (not complete)
        let frame_bytes = make_frame_bytes(1, 0, 42, b"test");
        buffer.push(&frame_bytes[..5]).unwrap(); // Only 5 bytes of header

        assert_eq!(buffer.state_name(), "WaitingForHeader");
        assert!(!buffer.is_empty());
        assert_eq!(buffer.len(), 5);

        // Push rest of header to transition to WaitingForPayload
        buffer.push(&frame_bytes[5..HEADER_SIZE]).unwrap();
        assert_eq!(buffer.state_name(), "WaitingForPayload");

        buffer.clear();

        assert_eq!(buffer.state_name(), "WaitingForHeader");
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_extend_then_empty_push_extracts() {
        let mut buffer = FrameBuffer::with_capacity_and_max_payload(16, 1024);
        buffer.extend(&make_frame_bytes(3, 0, 5, b"xyz"));

        // `extend` only buffers; nothing has been parsed yet.
        assert_eq!(buffer.state_name(), "WaitingForHeader");

        let frames = buffer.push(&[]).unwrap();
        assert_eq!(frames.len(), 1);
        assert_eq!(frames[0].method_id(), 3);
        assert_eq!(&frames[0].payload[..], b"xyz");
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_mixed_complete_and_partial() {
        let mut buffer = FrameBuffer::new();

        let frame1 = make_frame_bytes(1, 0, 1, b"first");
        let frame2 = make_frame_bytes(2, 0, 2, b"second");

        // Push first complete frame + partial second
        let mut data = frame1.clone();
        data.extend_from_slice(&frame2[..5]);

        let frames = buffer.push(&data).unwrap();
        assert_eq!(frames.len(), 1);
        assert_eq!(frames[0].method_id(), 1);
        assert_eq!(buffer.state_name(), "WaitingForHeader");

        // Complete second frame
        let frames = buffer.push(&frame2[5..]).unwrap();
        assert_eq!(frames.len(), 1);
        assert_eq!(frames[0].method_id(), 2);
    }

    #[test]
    fn test_byte_at_a_time() {
        let mut buffer = FrameBuffer::new();
        let frame_bytes = make_frame_bytes(1, 0, 42, b"hi");

        let mut all_frames = Vec::new();

        for byte in &frame_bytes {
            let frames = buffer.push(&[*byte]).unwrap();
            all_frames.extend(frames);
        }

        assert_eq!(all_frames.len(), 1);
        assert_eq!(all_frames[0].method_id(), 1);
        assert_eq!(&all_frames[0].payload[..], b"hi");
    }
}
418        assert_eq!(all_frames.len(), 1);
419        assert_eq!(all_frames[0].method_id(), 1);
420        assert_eq!(&all_frames[0].payload[..], b"hi");
421    }
422}