1use bytes::{Buf, Bytes, BytesMut};
4use futures::stream::{BoxStream, StreamExt};
5use thiserror::Error;
6
7use crate::flow::Flow;
8
9#[derive(Debug, Error)]
10#[non_exhaustive]
11pub enum FramingError {
12 #[error("frame exceeds {0} bytes")]
13 FrameTooLarge(usize),
14 #[error("truncated frame at end of stream")]
15 Truncated,
16}
17
/// Namespace type for the framing flow constructors.
///
/// It carries no data; use the associated functions [`Framing::delimiter`]
/// and [`Framing::length_field`] to build a flow.
pub struct Framing;
19
20struct FrameState<S> {
21 stream: S,
22 buf: BytesMut,
23 done: bool,
24}
25
26impl Framing {
27 pub fn delimiter(delimiter: u8, max_frame_length: usize) -> Flow<Bytes, Result<Bytes, FramingError>> {
30 Flow {
31 transform: Box::new(move |stream: BoxStream<'static, Bytes>| {
32 futures::stream::unfold(
33 FrameState { stream, buf: BytesMut::new(), done: false },
34 move |mut st| async move {
35 if st.done {
36 return None;
37 }
38 loop {
39 if let Some(pos) = st.buf.iter().position(|b| *b == delimiter) {
40 let frame = st.buf.split_to(pos).freeze();
41 st.buf.advance(1);
42 if frame.len() > max_frame_length {
43 st.done = true;
44 return Some((Err(FramingError::FrameTooLarge(max_frame_length)), st));
45 }
46 return Some((Ok(frame), st));
47 }
48 match st.stream.next().await {
49 Some(chunk) => {
50 st.buf.extend_from_slice(&chunk);
51 if st.buf.len() > max_frame_length {
52 st.done = true;
53 return Some((
54 Err(FramingError::FrameTooLarge(max_frame_length)),
55 st,
56 ));
57 }
58 }
59 None => {
60 if st.buf.is_empty() {
61 return None;
62 }
63 st.done = true;
64 return Some((Err(FramingError::Truncated), st));
65 }
66 }
67 }
68 },
69 )
70 .boxed()
71 }),
72 }
73 }
74
75 pub fn length_field(max_frame_length: usize) -> Flow<Bytes, Result<Bytes, FramingError>> {
78 Flow {
79 transform: Box::new(move |stream: BoxStream<'static, Bytes>| {
80 futures::stream::unfold(
81 FrameState { stream, buf: BytesMut::new(), done: false },
82 move |mut st| async move {
83 if st.done {
84 return None;
85 }
86 loop {
87 if st.buf.len() >= 4 {
88 let len = u32::from_le_bytes(st.buf[..4].try_into().unwrap()) as usize;
89 if len > max_frame_length {
90 st.done = true;
91 return Some((Err(FramingError::FrameTooLarge(max_frame_length)), st));
92 }
93 if st.buf.len() >= 4 + len {
94 st.buf.advance(4);
95 let frame = st.buf.split_to(len).freeze();
96 return Some((Ok(frame), st));
97 }
98 }
99 match st.stream.next().await {
100 Some(chunk) => st.buf.extend_from_slice(&chunk),
101 None => {
102 if st.buf.is_empty() {
103 return None;
104 }
105 st.done = true;
106 return Some((Err(FramingError::Truncated), st));
107 }
108 }
109 }
110 },
111 )
112 .boxed()
113 }),
114 }
115 }
116}
117
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sink::Sink;
    use crate::source::Source;

    /// Frames that straddle a chunk boundary are reassembled, and the
    /// delimiter is stripped from every frame.
    #[tokio::test]
    async fn delimiter_framing_splits_chunks() {
        let chunks = vec![Bytes::from_static(b"hello\nwo"), Bytes::from_static(b"rld\nfoo\n")];
        let flow = Framing::delimiter(b'\n', 1024);

        let results: Vec<_> = Sink::collect(Source::from_iter(chunks).via(flow)).await;
        let frames: Vec<_> = results.into_iter().map(Result::unwrap).collect();

        let expected = vec![
            Bytes::from_static(b"hello"),
            Bytes::from_static(b"world"),
            Bytes::from_static(b"foo"),
        ];
        assert_eq!(frames, expected);
    }

    /// Two length-prefixed messages are recovered even though the input is
    /// cut one byte into the first payload.
    #[tokio::test]
    async fn length_field_framing_handles_splits() {
        let payloads: [&[u8]; 2] = [b"abc", b"hello"];
        let mut wire = Vec::new();
        for payload in payloads {
            let prefix = (payload.len() as u32).to_le_bytes();
            wire.extend_from_slice(&prefix);
            wire.extend_from_slice(payload);
        }

        let (head, tail) = wire.split_at(5);
        let chunks = vec![Bytes::copy_from_slice(head), Bytes::copy_from_slice(tail)];
        let flow = Framing::length_field(1024);

        let results: Vec<_> = Sink::collect(Source::from_iter(chunks).via(flow)).await;
        let frames: Vec<_> = results.into_iter().map(Result::unwrap).collect();

        assert_eq!(frames, vec![Bytes::from_static(b"abc"), Bytes::from_static(b"hello")]);
    }
}