procwire_client/protocol/
frame_buffer.rs1use bytes::{Bytes, BytesMut};
25
26use super::wire_format::{Header, DEFAULT_MAX_PAYLOAD_SIZE, HEADER_SIZE};
27use super::Frame;
28use crate::error::{ProcwireError, Result};
29
30#[derive(Debug, Clone)]
32enum State {
33 WaitingForHeader,
35 WaitingForPayload { header: Header, remaining: u32 },
37}
38
39pub struct FrameBuffer {
44 buffer: BytesMut,
46 state: State,
48 max_payload_size: u32,
50}
51
52impl FrameBuffer {
53 pub fn new() -> Self {
57 Self {
58 buffer: BytesMut::with_capacity(64 * 1024),
59 state: State::WaitingForHeader,
60 max_payload_size: DEFAULT_MAX_PAYLOAD_SIZE,
61 }
62 }
63
64 pub fn with_max_payload(max_payload_size: u32) -> Self {
66 Self {
67 buffer: BytesMut::with_capacity(64 * 1024),
68 state: State::WaitingForHeader,
69 max_payload_size,
70 }
71 }
72
73 pub fn with_capacity_and_max_payload(capacity: usize, max_payload_size: u32) -> Self {
75 Self {
76 buffer: BytesMut::with_capacity(capacity),
77 state: State::WaitingForHeader,
78 max_payload_size,
79 }
80 }
81
82 pub fn push(&mut self, data: &[u8]) -> Result<Vec<Frame>> {
100 self.buffer.extend_from_slice(data);
102
103 let mut frames = Vec::new();
104
105 while let Some(frame) = self.try_extract_one()? {
107 frames.push(frame);
108 }
109
110 Ok(frames)
111 }
112
113 fn try_extract_one(&mut self) -> Result<Option<Frame>> {
120 match &self.state {
121 State::WaitingForHeader => {
122 if self.buffer.len() < HEADER_SIZE {
123 return Ok(None);
124 }
125
126 let header =
128 Header::decode(&self.buffer[..HEADER_SIZE]).expect("Buffer has enough bytes");
129
130 if header.payload_length > self.max_payload_size {
132 return Err(ProcwireError::Protocol(format!(
133 "Payload size {} exceeds maximum {}",
134 header.payload_length, self.max_payload_size
135 )));
136 }
137
138 let _ = self.buffer.split_to(HEADER_SIZE);
140
141 if header.payload_length == 0 {
142 return Ok(Some(Frame::new(header, Bytes::new())));
144 }
145
146 self.state = State::WaitingForPayload {
148 header,
149 remaining: header.payload_length,
150 };
151
152 self.try_extract_one()
154 }
155
156 State::WaitingForPayload { header, remaining } => {
157 let remaining = *remaining as usize;
158
159 if self.buffer.len() < remaining {
160 return Ok(None);
161 }
162
163 let payload = self.buffer.split_to(remaining).freeze();
165 let header = *header;
166
167 self.state = State::WaitingForHeader;
169
170 Ok(Some(Frame::new(header, payload)))
171 }
172 }
173 }
174
175 #[deprecated(note = "Use push() instead for proper multi-frame handling")]
179 pub fn try_extract(&mut self) -> Option<Frame> {
180 self.try_extract_one().ok().flatten()
181 }
182
183 pub fn extend(&mut self, data: &[u8]) {
187 self.buffer.extend_from_slice(data);
188 }
189
190 pub fn len(&self) -> usize {
192 self.buffer.len()
193 }
194
195 pub fn is_empty(&self) -> bool {
197 self.buffer.is_empty()
198 }
199
200 pub fn clear(&mut self) {
202 self.buffer.clear();
203 self.state = State::WaitingForHeader;
204 }
205
206 #[cfg(test)]
208 fn state_name(&self) -> &'static str {
209 match &self.state {
210 State::WaitingForHeader => "WaitingForHeader",
211 State::WaitingForPayload { .. } => "WaitingForPayload",
212 }
213 }
214}
215
216impl Default for FrameBuffer {
217 fn default() -> Self {
218 Self::new()
219 }
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::flags;

    /// Encodes one complete frame: the header describing `payload`, then the payload bytes.
    fn make_frame_bytes(method_id: u16, flags: u8, request_id: u32, payload: &[u8]) -> Vec<u8> {
        let payload_length = u32::try_from(payload.len()).expect("test payload fits in u32");
        let header = Header::new(method_id, flags, request_id, payload_length);
        let mut bytes = header.encode().to_vec();
        bytes.extend_from_slice(payload);
        bytes
    }

    #[test]
    fn test_single_complete_frame() {
        let mut buffer = FrameBuffer::new();
        let frame_bytes = make_frame_bytes(1, flags::RESPONSE, 42, b"hello");

        let frames = buffer.push(&frame_bytes).unwrap();

        assert_eq!(frames.len(), 1);
        assert_eq!(frames[0].method_id(), 1);
        assert_eq!(frames[0].request_id(), 42);
        assert_eq!(&frames[0].payload[..], b"hello");
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_multiple_frames_in_one_push() {
        let mut buffer = FrameBuffer::new();

        let combined = [
            make_frame_bytes(1, 0, 1, b"first"),
            make_frame_bytes(2, 0, 2, b"second"),
            make_frame_bytes(3, 0, 3, b"third"),
        ]
        .concat();

        let frames = buffer.push(&combined).unwrap();

        assert_eq!(frames.len(), 3);
        assert_eq!(frames[0].method_id(), 1);
        assert_eq!(frames[1].method_id(), 2);
        assert_eq!(frames[2].method_id(), 3);
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_fragmented_header() {
        let mut buffer = FrameBuffer::new();
        let frame_bytes = make_frame_bytes(1, 0, 42, b"test");

        let frames = buffer.push(&frame_bytes[..5]).unwrap();
        assert!(frames.is_empty());
        assert_eq!(buffer.state_name(), "WaitingForHeader");

        let frames = buffer.push(&frame_bytes[5..]).unwrap();
        assert_eq!(frames.len(), 1);
        assert_eq!(frames[0].method_id(), 1);
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_fragmented_payload() {
        let mut buffer = FrameBuffer::new();
        let payload = b"this is a longer payload that will be fragmented";
        let frame_bytes = make_frame_bytes(1, 0, 42, payload);

        let partial_len = HEADER_SIZE + 10;
        let frames = buffer.push(&frame_bytes[..partial_len]).unwrap();
        assert!(frames.is_empty());
        assert_eq!(buffer.state_name(), "WaitingForPayload");

        let frames = buffer.push(&frame_bytes[partial_len..]).unwrap();
        assert_eq!(frames.len(), 1);
        assert_eq!(&frames[0].payload[..], payload);
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_empty_payload() {
        let mut buffer = FrameBuffer::new();
        let frame_bytes = make_frame_bytes(1, 0, 42, b"");

        let frames = buffer.push(&frame_bytes).unwrap();

        assert_eq!(frames.len(), 1);
        assert!(frames[0].payload.is_empty());
        assert_eq!(frames[0].header.payload_length, 0);
        assert!(buffer.is_empty());
        assert_eq!(buffer.state_name(), "WaitingForHeader");
    }

    #[test]
    fn test_empty_payload_followed_by_frame() {
        let mut buffer = FrameBuffer::new();
        let data = [make_frame_bytes(1, 0, 1, b""), make_frame_bytes(2, 0, 2, b"x")].concat();

        let frames = buffer.push(&data).unwrap();

        assert_eq!(frames.len(), 2);
        assert!(frames[0].payload.is_empty());
        assert_eq!(frames[1].method_id(), 2);
        assert_eq!(&frames[1].payload[..], b"x");
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_large_payload() {
        let mut buffer = FrameBuffer::new();
        let payload = vec![0xAB; 1024 * 1024];
        let frame_bytes = make_frame_bytes(1, 0, 42, &payload);

        let frames = buffer.push(&frame_bytes).unwrap();

        assert_eq!(frames.len(), 1);
        assert_eq!(frames[0].payload.len(), 1024 * 1024);
        assert!(frames[0].payload.iter().all(|&b| b == 0xAB));
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_max_payload_validation() {
        let mut buffer = FrameBuffer::with_max_payload(100);

        let header = Header::new(1, 0, 42, 1000);
        let header_bytes = header.encode();

        let result = buffer.push(&header_bytes);

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("exceeds maximum"));
    }

    #[test]
    fn test_payload_equal_to_max_is_accepted() {
        // The limit is inclusive: only payloads strictly larger than it are rejected.
        let mut buffer = FrameBuffer::with_max_payload(4);
        let frame_bytes = make_frame_bytes(1, 0, 9, b"four");

        let frames = buffer.push(&frame_bytes).unwrap();

        assert_eq!(frames.len(), 1);
        assert_eq!(&frames[0].payload[..], b"four");
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_clear_recovers_after_protocol_error() {
        let mut buffer = FrameBuffer::with_max_payload(100);
        let oversized = Header::new(1, 0, 42, 1000).encode();
        assert!(buffer.push(&oversized).is_err());

        buffer.clear();

        let frames = buffer.push(&make_frame_bytes(2, 0, 7, b"ok")).unwrap();
        assert_eq!(frames.len(), 1);
        assert_eq!(frames[0].request_id(), 7);
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_frame_with_all_header_fields() {
        let mut buffer = FrameBuffer::new();
        let frame_bytes = make_frame_bytes(0x1234, flags::STREAM_END_RESPONSE, 0xDEADBEEF, b"data");

        let frames = buffer.push(&frame_bytes).unwrap();

        assert_eq!(frames.len(), 1);
        let frame = &frames[0];
        assert_eq!(frame.method_id(), 0x1234);
        assert_eq!(frame.header.flags, flags::STREAM_END_RESPONSE);
        assert_eq!(frame.request_id(), 0xDEADBEEF);
        assert!(frame.is_stream());
        assert!(frame.is_stream_end());
    }

    #[test]
    fn test_clear_resets_state() {
        let mut buffer = FrameBuffer::new();
        let frame_bytes = make_frame_bytes(1, 0, 42, b"test");

        buffer.push(&frame_bytes[..5]).unwrap();
        assert_eq!(buffer.state_name(), "WaitingForHeader");
        assert!(!buffer.is_empty());
        assert_eq!(buffer.len(), 5);

        buffer.push(&frame_bytes[5..HEADER_SIZE]).unwrap();
        assert_eq!(buffer.state_name(), "WaitingForPayload");

        buffer.clear();

        assert_eq!(buffer.state_name(), "WaitingForHeader");
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_mixed_complete_and_partial() {
        let mut buffer = FrameBuffer::new();

        let frame1 = make_frame_bytes(1, 0, 1, b"first");
        let frame2 = make_frame_bytes(2, 0, 2, b"second");

        let mut data = frame1;
        data.extend_from_slice(&frame2[..5]);

        let frames = buffer.push(&data).unwrap();
        assert_eq!(frames.len(), 1);
        assert_eq!(frames[0].method_id(), 1);
        assert_eq!(buffer.state_name(), "WaitingForHeader");

        let frames = buffer.push(&frame2[5..]).unwrap();
        assert_eq!(frames.len(), 1);
        assert_eq!(frames[0].method_id(), 2);
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_byte_at_a_time() {
        let mut buffer = FrameBuffer::new();
        let frame_bytes = make_frame_bytes(1, 0, 42, b"hi");

        let mut all_frames = Vec::new();
        for byte in &frame_bytes {
            all_frames.extend(buffer.push(&[*byte]).unwrap());
        }

        assert_eq!(all_frames.len(), 1);
        assert_eq!(all_frames[0].method_id(), 1);
        assert_eq!(&all_frames[0].payload[..], b"hi");
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_default_starts_empty() {
        let buffer = FrameBuffer::default();

        assert!(buffer.is_empty());
        assert_eq!(buffer.state_name(), "WaitingForHeader");
    }
}