Skip to main content

atomr_streams/
framing.rs

1//! Framing utilities — split a byte stream into messages.
2
3use bytes::{Buf, Bytes, BytesMut};
4use futures::stream::{BoxStream, StreamExt};
5use thiserror::Error;
6
7use crate::flow::Flow;
8
9#[derive(Debug, Error)]
10#[non_exhaustive]
11pub enum FramingError {
12    #[error("frame exceeds {0} bytes")]
13    FrameTooLarge(usize),
14    #[error("truncated frame at end of stream")]
15    Truncated,
16}
17
/// Namespace for the framing flow constructors; it carries no state of its own.
#[derive(Debug, Clone, Copy, Default)]
pub struct Framing;
19
/// State threaded through `futures::stream::unfold` by the framing flows.
struct FrameState<S> {
    /// Upstream chunk stream; polled for more bytes whenever the buffer does not
    /// yet hold a complete frame.
    stream: S,
    /// Bytes received from upstream that have not been emitted as a frame yet.
    buf: BytesMut,
    /// Set once an error has been emitted so that the next poll ends the stream.
    done: bool,
}
25
26impl Framing {
27    /// Split an incoming `Bytes` stream using a single-byte delimiter, dropping
28    /// the delimiter from each produced frame.
29    pub fn delimiter(delimiter: u8, max_frame_length: usize) -> Flow<Bytes, Result<Bytes, FramingError>> {
30        Flow {
31            transform: Box::new(move |stream: BoxStream<'static, Bytes>| {
32                futures::stream::unfold(
33                    FrameState { stream, buf: BytesMut::new(), done: false },
34                    move |mut st| async move {
35                        if st.done {
36                            return None;
37                        }
38                        loop {
39                            if let Some(pos) = st.buf.iter().position(|b| *b == delimiter) {
40                                let frame = st.buf.split_to(pos).freeze();
41                                st.buf.advance(1);
42                                if frame.len() > max_frame_length {
43                                    st.done = true;
44                                    return Some((Err(FramingError::FrameTooLarge(max_frame_length)), st));
45                                }
46                                return Some((Ok(frame), st));
47                            }
48                            match st.stream.next().await {
49                                Some(chunk) => {
50                                    st.buf.extend_from_slice(&chunk);
51                                    if st.buf.len() > max_frame_length {
52                                        st.done = true;
53                                        return Some((
54                                            Err(FramingError::FrameTooLarge(max_frame_length)),
55                                            st,
56                                        ));
57                                    }
58                                }
59                                None => {
60                                    if st.buf.is_empty() {
61                                        return None;
62                                    }
63                                    st.done = true;
64                                    return Some((Err(FramingError::Truncated), st));
65                                }
66                            }
67                        }
68                    },
69                )
70                .boxed()
71            }),
72        }
73    }
74
75    /// Split by length-prefixed frames. The prefix is a little-endian u32
76    /// giving the size of the payload.
77    pub fn length_field(max_frame_length: usize) -> Flow<Bytes, Result<Bytes, FramingError>> {
78        Flow {
79            transform: Box::new(move |stream: BoxStream<'static, Bytes>| {
80                futures::stream::unfold(
81                    FrameState { stream, buf: BytesMut::new(), done: false },
82                    move |mut st| async move {
83                        if st.done {
84                            return None;
85                        }
86                        loop {
87                            if st.buf.len() >= 4 {
88                                let len = u32::from_le_bytes(st.buf[..4].try_into().unwrap()) as usize;
89                                if len > max_frame_length {
90                                    st.done = true;
91                                    return Some((Err(FramingError::FrameTooLarge(max_frame_length)), st));
92                                }
93                                if st.buf.len() >= 4 + len {
94                                    st.buf.advance(4);
95                                    let frame = st.buf.split_to(len).freeze();
96                                    return Some((Ok(frame), st));
97                                }
98                            }
99                            match st.stream.next().await {
100                                Some(chunk) => st.buf.extend_from_slice(&chunk),
101                                None => {
102                                    if st.buf.is_empty() {
103                                        return None;
104                                    }
105                                    st.done = true;
106                                    return Some((Err(FramingError::Truncated), st));
107                                }
108                            }
109                        }
110                    },
111                )
112                .boxed()
113            }),
114        }
115    }
116}
117
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sink::Sink;
    use crate::source::Source;

    /// Frames that straddle chunk boundaries are reassembled.
    #[tokio::test]
    async fn delimiter_framing_splits_chunks() {
        let source =
            Source::from_iter(vec![Bytes::from_static(b"hello\nwo"), Bytes::from_static(b"rld\nfoo\n")]);
        let framed = source.via(Framing::delimiter(b'\n', 1024));
        let out: Vec<_> = Sink::collect(framed).await;
        let ok: Vec<_> = out.into_iter().map(|r| r.unwrap()).collect();
        assert_eq!(
            ok,
            vec![Bytes::from_static(b"hello"), Bytes::from_static(b"world"), Bytes::from_static(b"foo"),]
        );
    }

    /// Bytes left over without a trailing delimiter surface as `Truncated`
    /// after the complete frames have been emitted.
    #[tokio::test]
    async fn delimiter_framing_reports_truncated_tail() {
        let source = Source::from_iter(vec![Bytes::from_static(b"abc\ndef")]);
        let framed = source.via(Framing::delimiter(b'\n', 1024));
        let out: Vec<_> = Sink::collect(framed).await;
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].as_ref().unwrap(), &Bytes::from_static(b"abc"));
        assert!(matches!(out[1], Err(FramingError::Truncated)));
    }

    /// A frame longer than the limit yields a single error and ends the stream.
    #[tokio::test]
    async fn delimiter_framing_rejects_oversized_frame() {
        let source = Source::from_iter(vec![Bytes::from_static(b"abcdef\n")]);
        let framed = source.via(Framing::delimiter(b'\n', 3));
        let out: Vec<_> = Sink::collect(framed).await;
        assert_eq!(out.len(), 1);
        assert!(matches!(out[0], Err(FramingError::FrameTooLarge(3))));
    }

    /// Length prefixes and payloads that straddle chunk boundaries are reassembled.
    #[tokio::test]
    async fn length_field_framing_handles_splits() {
        let mut buf = Vec::new();
        let msgs: [&[u8]; 2] = [b"abc", b"hello"];
        for m in msgs {
            buf.extend_from_slice(&(m.len() as u32).to_le_bytes());
            buf.extend_from_slice(m);
        }
        let source =
            Source::from_iter(vec![Bytes::copy_from_slice(&buf[..5]), Bytes::copy_from_slice(&buf[5..])]);
        let framed = source.via(Framing::length_field(1024));
        let out: Vec<_> = Sink::collect(framed).await;
        let ok: Vec<_> = out.into_iter().map(|r| r.unwrap()).collect();
        assert_eq!(ok, vec![Bytes::from_static(b"abc"), Bytes::from_static(b"hello")]);
    }

    /// A prefix announcing more than the limit is rejected before any payload
    /// is awaited.
    #[tokio::test]
    async fn length_field_framing_rejects_oversized_frame() {
        let mut buf = Vec::new();
        buf.extend_from_slice(&10u32.to_le_bytes());
        buf.extend_from_slice(b"0123456789");
        let source = Source::from_iter(vec![Bytes::copy_from_slice(&buf)]);
        let framed = source.via(Framing::length_field(4));
        let out: Vec<_> = Sink::collect(framed).await;
        assert_eq!(out.len(), 1);
        assert!(matches!(out[0], Err(FramingError::FrameTooLarge(4))));
    }

    /// A payload shorter than its prefix announces surfaces as `Truncated`.
    #[tokio::test]
    async fn length_field_framing_reports_truncated_payload() {
        let mut buf = Vec::new();
        buf.extend_from_slice(&5u32.to_le_bytes());
        buf.extend_from_slice(b"ab");
        let source = Source::from_iter(vec![Bytes::copy_from_slice(&buf)]);
        let framed = source.via(Framing::length_field(1024));
        let out: Vec<_> = Sink::collect(framed).await;
        assert_eq!(out.len(), 1);
        assert!(matches!(out[0], Err(FramingError::Truncated)));
    }
}