Skip to main content

atomr_streams/
framing.rs

//! Framing utilities — split a byte stream into discrete messages (frames).
//! akka.net counterparts: `Dsl/Framing.cs`, `Dsl/JsonFraming.cs`.
3
4use bytes::{Buf, Bytes, BytesMut};
5use futures::stream::{BoxStream, StreamExt};
6use thiserror::Error;
7
8use crate::flow::Flow;
9
10#[derive(Debug, Error)]
11#[non_exhaustive]
12pub enum FramingError {
13    #[error("frame exceeds {0} bytes")]
14    FrameTooLarge(usize),
15    #[error("truncated frame at end of stream")]
16    Truncated,
17}
18
/// Namespace for the framing flow constructors (`Framing::delimiter`,
/// `Framing::length_field`); it carries no state of its own.
#[derive(Debug, Clone, Copy, Default)]
pub struct Framing;
20
21struct FrameState<S> {
22    stream: S,
23    buf: BytesMut,
24    done: bool,
25}
26
27impl Framing {
28    /// Split an incoming `Bytes` stream using a single-byte delimiter, dropping
29    /// the delimiter from each produced frame. akka.net: `Framing.Delimiter`.
30    pub fn delimiter(delimiter: u8, max_frame_length: usize) -> Flow<Bytes, Result<Bytes, FramingError>> {
31        Flow {
32            transform: Box::new(move |stream: BoxStream<'static, Bytes>| {
33                futures::stream::unfold(
34                    FrameState { stream, buf: BytesMut::new(), done: false },
35                    move |mut st| async move {
36                        if st.done {
37                            return None;
38                        }
39                        loop {
40                            if let Some(pos) = st.buf.iter().position(|b| *b == delimiter) {
41                                let frame = st.buf.split_to(pos).freeze();
42                                st.buf.advance(1);
43                                if frame.len() > max_frame_length {
44                                    st.done = true;
45                                    return Some((Err(FramingError::FrameTooLarge(max_frame_length)), st));
46                                }
47                                return Some((Ok(frame), st));
48                            }
49                            match st.stream.next().await {
50                                Some(chunk) => {
51                                    st.buf.extend_from_slice(&chunk);
52                                    if st.buf.len() > max_frame_length {
53                                        st.done = true;
54                                        return Some((
55                                            Err(FramingError::FrameTooLarge(max_frame_length)),
56                                            st,
57                                        ));
58                                    }
59                                }
60                                None => {
61                                    if st.buf.is_empty() {
62                                        return None;
63                                    }
64                                    st.done = true;
65                                    return Some((Err(FramingError::Truncated), st));
66                                }
67                            }
68                        }
69                    },
70                )
71                .boxed()
72            }),
73        }
74    }
75
76    /// Split by length-prefixed frames. The prefix is a little-endian u32
77    /// giving the size of the payload. akka.net: `Framing.LengthField`.
78    pub fn length_field(max_frame_length: usize) -> Flow<Bytes, Result<Bytes, FramingError>> {
79        Flow {
80            transform: Box::new(move |stream: BoxStream<'static, Bytes>| {
81                futures::stream::unfold(
82                    FrameState { stream, buf: BytesMut::new(), done: false },
83                    move |mut st| async move {
84                        if st.done {
85                            return None;
86                        }
87                        loop {
88                            if st.buf.len() >= 4 {
89                                let len = u32::from_le_bytes(st.buf[..4].try_into().unwrap()) as usize;
90                                if len > max_frame_length {
91                                    st.done = true;
92                                    return Some((Err(FramingError::FrameTooLarge(max_frame_length)), st));
93                                }
94                                if st.buf.len() >= 4 + len {
95                                    st.buf.advance(4);
96                                    let frame = st.buf.split_to(len).freeze();
97                                    return Some((Ok(frame), st));
98                                }
99                            }
100                            match st.stream.next().await {
101                                Some(chunk) => st.buf.extend_from_slice(&chunk),
102                                None => {
103                                    if st.buf.is_empty() {
104                                        return None;
105                                    }
106                                    st.done = true;
107                                    return Some((Err(FramingError::Truncated), st));
108                                }
109                            }
110                        }
111                    },
112                )
113                .boxed()
114            }),
115        }
116    }
117}
118
119#[cfg(test)]
120mod tests {
121    use super::*;
122    use crate::sink::Sink;
123    use crate::source::Source;
124
125    #[tokio::test]
126    async fn delimiter_framing_splits_chunks() {
127        let source =
128            Source::from_iter(vec![Bytes::from_static(b"hello\nwo"), Bytes::from_static(b"rld\nfoo\n")]);
129        let framed = source.via(Framing::delimiter(b'\n', 1024));
130        let out: Vec<_> = Sink::collect(framed).await;
131        let ok: Vec<_> = out.into_iter().map(|r| r.unwrap()).collect();
132        assert_eq!(
133            ok,
134            vec![Bytes::from_static(b"hello"), Bytes::from_static(b"world"), Bytes::from_static(b"foo"),]
135        );
136    }
137
138    #[tokio::test]
139    async fn length_field_framing_handles_splits() {
140        let mut buf = Vec::new();
141        let msgs: [&[u8]; 2] = [b"abc", b"hello"];
142        for m in msgs {
143            buf.extend_from_slice(&(m.len() as u32).to_le_bytes());
144            buf.extend_from_slice(m);
145        }
146        let source =
147            Source::from_iter(vec![Bytes::copy_from_slice(&buf[..5]), Bytes::copy_from_slice(&buf[5..])]);
148        let framed = source.via(Framing::length_field(1024));
149        let out: Vec<_> = Sink::collect(framed).await;
150        let ok: Vec<_> = out.into_iter().map(|r| r.unwrap()).collect();
151        assert_eq!(ok, vec![Bytes::from_static(b"abc"), Bytes::from_static(b"hello")]);
152    }
153}