1use bytes::{Buf, Bytes, BytesMut};
5use futures::stream::{BoxStream, StreamExt};
6use thiserror::Error;
7
8use crate::flow::Flow;
9
10#[derive(Debug, Error)]
11#[non_exhaustive]
12pub enum FramingError {
13 #[error("frame exceeds {0} bytes")]
14 FrameTooLarge(usize),
15 #[error("truncated frame at end of stream")]
16 Truncated,
17}
18
/// Constructors for framing [`Flow`] stages, which turn a stream of raw [`Bytes`] chunks
/// (split at arbitrary points) into a stream of complete frames.
#[derive(Debug, Clone, Copy, Default)]
pub struct Framing;
20
21struct FrameState<S> {
22 stream: S,
23 buf: BytesMut,
24 done: bool,
25}
26
27impl Framing {
28 pub fn delimiter(delimiter: u8, max_frame_length: usize) -> Flow<Bytes, Result<Bytes, FramingError>> {
31 Flow {
32 transform: Box::new(move |stream: BoxStream<'static, Bytes>| {
33 futures::stream::unfold(
34 FrameState { stream, buf: BytesMut::new(), done: false },
35 move |mut st| async move {
36 if st.done {
37 return None;
38 }
39 loop {
40 if let Some(pos) = st.buf.iter().position(|b| *b == delimiter) {
41 let frame = st.buf.split_to(pos).freeze();
42 st.buf.advance(1);
43 if frame.len() > max_frame_length {
44 st.done = true;
45 return Some((Err(FramingError::FrameTooLarge(max_frame_length)), st));
46 }
47 return Some((Ok(frame), st));
48 }
49 match st.stream.next().await {
50 Some(chunk) => {
51 st.buf.extend_from_slice(&chunk);
52 if st.buf.len() > max_frame_length {
53 st.done = true;
54 return Some((
55 Err(FramingError::FrameTooLarge(max_frame_length)),
56 st,
57 ));
58 }
59 }
60 None => {
61 if st.buf.is_empty() {
62 return None;
63 }
64 st.done = true;
65 return Some((Err(FramingError::Truncated), st));
66 }
67 }
68 }
69 },
70 )
71 .boxed()
72 }),
73 }
74 }
75
76 pub fn length_field(max_frame_length: usize) -> Flow<Bytes, Result<Bytes, FramingError>> {
79 Flow {
80 transform: Box::new(move |stream: BoxStream<'static, Bytes>| {
81 futures::stream::unfold(
82 FrameState { stream, buf: BytesMut::new(), done: false },
83 move |mut st| async move {
84 if st.done {
85 return None;
86 }
87 loop {
88 if st.buf.len() >= 4 {
89 let len = u32::from_le_bytes(st.buf[..4].try_into().unwrap()) as usize;
90 if len > max_frame_length {
91 st.done = true;
92 return Some((Err(FramingError::FrameTooLarge(max_frame_length)), st));
93 }
94 if st.buf.len() >= 4 + len {
95 st.buf.advance(4);
96 let frame = st.buf.split_to(len).freeze();
97 return Some((Ok(frame), st));
98 }
99 }
100 match st.stream.next().await {
101 Some(chunk) => st.buf.extend_from_slice(&chunk),
102 None => {
103 if st.buf.is_empty() {
104 return None;
105 }
106 st.done = true;
107 return Some((Err(FramingError::Truncated), st));
108 }
109 }
110 }
111 },
112 )
113 .boxed()
114 }),
115 }
116 }
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sink::Sink;
    use crate::source::Source;

    /// Feeds `chunks` through `flow` and collects every emitted item, errors included.
    async fn run(
        chunks: Vec<Bytes>,
        flow: Flow<Bytes, Result<Bytes, FramingError>>,
    ) -> Vec<Result<Bytes, FramingError>> {
        Sink::collect(Source::from_iter(chunks).via(flow)).await
    }

    /// Encodes `payload` as one length-prefixed frame (4-byte little-endian length).
    fn length_prefixed(len: u32, payload: &[u8]) -> Bytes {
        let mut wire = len.to_le_bytes().to_vec();
        wire.extend_from_slice(payload);
        Bytes::from(wire)
    }

    #[tokio::test]
    async fn delimiter_framing_splits_chunks() {
        let out = run(
            vec![Bytes::from_static(b"hello\nwo"), Bytes::from_static(b"rld\nfoo\n")],
            Framing::delimiter(b'\n', 1024),
        )
        .await;
        let ok: Vec<_> = out.into_iter().map(|r| r.unwrap()).collect();
        assert_eq!(
            ok,
            vec![
                Bytes::from_static(b"hello"),
                Bytes::from_static(b"world"),
                Bytes::from_static(b"foo"),
            ]
        );
    }

    #[tokio::test]
    async fn delimiter_framing_reports_trailing_bytes_as_truncated() {
        let out = run(vec![Bytes::from_static(b"abc\nde")], Framing::delimiter(b'\n', 1024)).await;
        assert!(matches!(
            out.as_slice(),
            [Ok(frame), Err(FramingError::Truncated)] if frame == &Bytes::from_static(b"abc")
        ));
    }

    #[tokio::test]
    async fn delimiter_framing_rejects_oversized_frame_and_stops() {
        let out =
            run(vec![Bytes::from_static(b"0123456789\nok\n")], Framing::delimiter(b'\n', 4)).await;
        assert!(matches!(out.as_slice(), [Err(FramingError::FrameTooLarge(4))]));
    }

    #[tokio::test]
    async fn framing_of_empty_stream_yields_nothing() {
        assert!(run(vec![], Framing::delimiter(b'\n', 8)).await.is_empty());
        assert!(run(vec![], Framing::length_field(8)).await.is_empty());
    }

    #[tokio::test]
    async fn length_field_framing_handles_splits() {
        let mut buf = Vec::new();
        let msgs: [&[u8]; 2] = [b"abc", b"hello"];
        for m in msgs {
            buf.extend_from_slice(&(m.len() as u32).to_le_bytes());
            buf.extend_from_slice(m);
        }
        // Split inside the first payload so a frame straddles two chunks.
        let out = run(
            vec![Bytes::copy_from_slice(&buf[..5]), Bytes::copy_from_slice(&buf[5..])],
            Framing::length_field(1024),
        )
        .await;
        let ok: Vec<_> = out.into_iter().map(|r| r.unwrap()).collect();
        assert_eq!(ok, vec![Bytes::from_static(b"abc"), Bytes::from_static(b"hello")]);
    }

    #[tokio::test]
    async fn length_field_framing_rejects_oversized_header() {
        let out = run(vec![length_prefixed(2000, b"xy")], Framing::length_field(1024)).await;
        assert!(matches!(out.as_slice(), [Err(FramingError::FrameTooLarge(1024))]));
    }

    #[tokio::test]
    async fn length_field_framing_reports_partial_payload_as_truncated() {
        let out = run(vec![length_prefixed(5, b"abc")], Framing::length_field(1024)).await;
        assert!(matches!(out.as_slice(), [Err(FramingError::Truncated)]));
    }
}