amp_async/
codecs.rs

1use std::convert::TryInto;
2
3use bytes::{Buf, Bytes, BytesMut};
4use tokio_util::codec::Decoder;
5
/// Largest key length, in bytes, that `read_key` accepts (0xff = 255); longer keys are
/// rejected with `CodecError::KeyTooLong`.
pub(crate) const AMP_KEY_LIMIT: usize = 0xff;
/// Width, in bytes, of the `u16` length prefix that precedes every key and value.
const LENGTH_SIZE: usize = std::mem::size_of::<u16>();
8
/// Incremental decoder state for frames made of length-prefixed key/value pairs.
///
/// `D` is the collection that the pairs of one frame are accumulated into.
///
/// NOTE(review): the default `D = Vec<(Bytes, Bytes)>` does not appear to satisfy the
/// `Extend<(Vec<u8>, Bytes)>` bound required by the `Decoder` impl in this file — confirm
/// which pair type is intended.
#[derive(Debug, Default, PartialEq)]
pub struct Dec<D = Vec<(Bytes, Bytes)>> {
    /// Whether the next length-prefixed chunk is interpreted as a key or as a value.
    state: State,
    /// Key read most recently, held until its value has been read.
    key: Vec<u8>,
    /// Pairs collected so far for the frame currently being decoded.
    frame: D,
}
15
/// Which half of a key/value pair the decoder expects next.
#[derive(Debug, PartialEq)]
enum State {
    /// The next length-prefixed chunk is a key (a zero length ends the frame).
    Key,
    /// The next length-prefixed chunk is the value for the key already read.
    Value,
}

// A freshly constructed decoder starts out expecting a key.
impl Default for State {
    fn default() -> Self {
        Self::Key
    }
}
27
28impl<D> Dec<D>
29where
30    D: Default,
31{
32    pub fn new() -> Self {
33        Default::default()
34    }
35
36    fn read_key(length: usize, buf: &mut BytesMut) -> Result<Option<Bytes>, CodecError> {
37        if length > AMP_KEY_LIMIT {
38            return Err(CodecError::KeyTooLong);
39        }
40
41        Ok(Self::read_delimited(length, buf))
42    }
43
44    fn read_delimited(length: usize, buf: &mut BytesMut) -> Option<Bytes> {
45        if buf.len() >= length + LENGTH_SIZE {
46            buf.advance(LENGTH_SIZE);
47            Some(buf.split_to(length).freeze())
48        } else {
49            None
50        }
51    }
52}
53
54impl<D> Decoder for Dec<D>
55where
56    D: Default + Extend<(Vec<u8>, Bytes)>,
57{
58    type Error = CodecError;
59    type Item = D;
60
61    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
62        loop {
63            if buf.len() < LENGTH_SIZE {
64                return Ok(None);
65            }
66
67            let (length_bytes, _) = buf.split_at(LENGTH_SIZE);
68            let length = usize::from(u16::from_be_bytes(length_bytes.try_into().unwrap()));
69
70            match self.state {
71                State::Key => {
72                    if length == 0 {
73                        buf.advance(LENGTH_SIZE);
74                        return Ok(Some(std::mem::take(&mut self.frame)));
75                    } else {
76                        match Self::read_key(length, buf)? {
77                            Some(key) => {
78                                self.key = key.to_vec();
79                                self.state = State::Value;
80                            }
81                            None => {
82                                return Ok(None);
83                            }
84                        }
85                    }
86                }
87                State::Value => match Self::read_delimited(length, buf) {
88                    Some(value) => {
89                        let key = std::mem::take(&mut self.key);
90                        self.frame.extend(std::iter::once((key, value)));
91                        self.state = State::Key;
92                    }
93                    None => {
94                        return Ok(None);
95                    }
96                },
97            }
98        }
99    }
100}
101
/// Errors reported by the codec.
#[derive(Debug)]
pub enum CodecError {
    /// An underlying I/O operation failed; the original error is kept as the source.
    IO(std::io::Error),
    /// A key's declared length exceeded the key length limit.
    KeyTooLong,
    /// A key was empty where a non-empty key was required.
    /// NOTE(review): not raised anywhere in this file — presumably used by the encoder.
    EmptyKey,
    /// A value was too long to be represented.
    /// NOTE(review): not raised anywhere in this file — presumably used by the encoder.
    ValueTooLong,
    /// An error message carried as text.
    /// NOTE(review): presumably produced by the serde integration — confirm.
    Serde(String),
    /// The requested operation or type is not supported.
    /// NOTE(review): not raised anywhere in this file — confirm where it originates.
    Unsupported,
}

impl From<std::io::Error> for CodecError {
    fn from(err: std::io::Error) -> Self {
        Self::IO(err)
    }
}

impl std::fmt::Display for CodecError {
    /// Writes a short human-readable description; `Debug` remains available for the
    /// structural dump.
    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::IO(err) => write!(fmt, "I/O error: {}", err),
            Self::KeyTooLong => fmt.write_str("key too long"),
            Self::EmptyKey => fmt.write_str("empty key"),
            Self::ValueTooLong => fmt.write_str("value too long"),
            Self::Serde(msg) => write!(fmt, "serde error: {}", msg),
            Self::Unsupported => fmt.write_str("unsupported"),
        }
    }
}

impl std::error::Error for CodecError {
    /// Exposes the wrapped I/O error so callers can walk the error chain.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::IO(err) => Some(err),
            Self::KeyTooLong
            | Self::EmptyKey
            | Self::ValueTooLong
            | Self::Serde(_)
            | Self::Unsupported => None,
        }
    }
}
125
// Round-trip checks against a single reference frame: decoding the bytes must yield the
// expected pairs, and serializing the equivalent request must yield the same bytes.
#[cfg(test)]
mod test {
    use amp_serde::Request;
    use bytes::BytesMut;
    use serde::Serialize;
    // Imported anonymously: only the trait's methods are needed, so the name `Decoder`
    // below resolves through `crate::*` instead.
    // NOTE(review): this file defines the type as `Dec` — confirm the crate root exports
    // it under the name `Decoder`.
    use tokio_util::codec::Decoder as _;

    use crate::*;

    // One complete frame: each key and value is preceded by a two-byte big-endian length,
    // and the trailing 0x00 0x00 (zero-length key) terminates the frame.
    const WWW_EXAMPLE: &[u8] = &[
        0x00, 0x04, 0x5F, 0x61, 0x73, 0x6B, 0x00, 0x02, 0x32, 0x33, 0x00, 0x08, 0x5F, 0x63, 0x6F,
        0x6D, 0x6D, 0x61, 0x6E, 0x64, 0x00, 0x03, 0x53, 0x75, 0x6D, 0x00, 0x01, 0x61, 0x00, 0x02,
        0x31, 0x33, 0x00, 0x01, 0x62, 0x00, 0x02, 0x38, 0x31, 0x00, 0x00,
    ];
    // The key/value pairs encoded in `WWW_EXAMPLE`, in wire order.
    const WWW_EXAMPLE_DEC: &[(&[u8], &[u8])] = &[
        (b"_ask", b"23"),
        (b"_command", b"Sum"),
        (b"a", b"13"),
        (b"b", b"81"),
    ];

    // Decoding the reference bytes yields the expected pairs, drains the buffer, and
    // leaves the decoder equal to a freshly constructed one.
    #[test]
    fn decode_example() {
        let mut dec = Decoder::<Vec<_>>::new();
        let mut buf = BytesMut::new();
        buf.extend(WWW_EXAMPLE);

        let frame = dec.decode(&mut buf).unwrap().unwrap();

        assert_eq!(
            frame
                .iter()
                .map(|(k, v)| (k.as_ref(), v.as_ref()))
                .collect::<Vec<_>>(),
            WWW_EXAMPLE_DEC
        );
        assert_eq!(buf.len(), 0);
        assert_eq!(dec, Decoder::<Vec<_>>::new());
    }

    // Serializing the equivalent request reproduces the reference bytes exactly.
    #[test]
    fn encode_example() {
        #[derive(Serialize)]
        struct Sum {
            a: u32,
            b: u32,
        }
        let fields = Sum { a: 13, b: 81 };

        let buf = amp_serde::to_bytes(Request {
            command: "Sum".into(),
            tag: Some(b"23".as_ref().into()),
            fields,
        })
        .unwrap();

        assert_eq!(buf, WWW_EXAMPLE);
    }
}