1use std::convert::TryInto;
2
3use bytes::{Buf, Bytes, BytesMut};
4use tokio_util::codec::Decoder;
5
/// Longest key, in bytes, the decoder accepts before failing with
/// `CodecError::KeyTooLong`.
pub(crate) const AMP_KEY_LIMIT: usize = 255;

/// Width in bytes of the big-endian length prefix (one `u16`) that precedes
/// every key and value on the wire.
const LENGTH_SIZE: usize = 2;
8
9#[derive(Debug, Default, PartialEq)]
10pub struct Dec<D = Vec<(Bytes, Bytes)>> {
11 state: State,
12 key: Vec<u8>,
13 frame: D,
14}
15
/// Which half of a key/value pair the decoder expects to read next.
#[derive(Debug, PartialEq)]
enum State {
    /// Expecting a key length prefix, or the zero-length box terminator.
    Key,
    /// A key has been read; expecting the value that belongs to it.
    Value,
}

impl Default for State {
    /// A freshly created decoder starts by reading a key.
    fn default() -> Self {
        Self::Key
    }
}
27
28impl<D> Dec<D>
29where
30 D: Default,
31{
32 pub fn new() -> Self {
33 Default::default()
34 }
35
36 fn read_key(length: usize, buf: &mut BytesMut) -> Result<Option<Bytes>, CodecError> {
37 if length > AMP_KEY_LIMIT {
38 return Err(CodecError::KeyTooLong);
39 }
40
41 Ok(Self::read_delimited(length, buf))
42 }
43
44 fn read_delimited(length: usize, buf: &mut BytesMut) -> Option<Bytes> {
45 if buf.len() >= length + LENGTH_SIZE {
46 buf.advance(LENGTH_SIZE);
47 Some(buf.split_to(length).freeze())
48 } else {
49 None
50 }
51 }
52}
53
54impl<D> Decoder for Dec<D>
55where
56 D: Default + Extend<(Vec<u8>, Bytes)>,
57{
58 type Error = CodecError;
59 type Item = D;
60
61 fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
62 loop {
63 if buf.len() < LENGTH_SIZE {
64 return Ok(None);
65 }
66
67 let (length_bytes, _) = buf.split_at(LENGTH_SIZE);
68 let length = usize::from(u16::from_be_bytes(length_bytes.try_into().unwrap()));
69
70 match self.state {
71 State::Key => {
72 if length == 0 {
73 buf.advance(LENGTH_SIZE);
74 return Ok(Some(std::mem::take(&mut self.frame)));
75 } else {
76 match Self::read_key(length, buf)? {
77 Some(key) => {
78 self.key = key.to_vec();
79 self.state = State::Value;
80 }
81 None => {
82 return Ok(None);
83 }
84 }
85 }
86 }
87 State::Value => match Self::read_delimited(length, buf) {
88 Some(value) => {
89 let key = std::mem::take(&mut self.key);
90 self.frame.extend(std::iter::once((key, value)));
91 self.state = State::Key;
92 }
93 None => {
94 return Ok(None);
95 }
96 },
97 }
98 }
99 }
100}
101
/// Errors reported by the AMP codec.
///
/// The decoder in this module produces only `KeyTooLong`. `IO` is produced
/// through the `From<std::io::Error>` conversion. The remaining variants are
/// presumably raised elsewhere in the crate (encoder / serde layer); confirm
/// against their call sites.
#[derive(Debug)]
pub enum CodecError {
    /// An underlying I/O error.
    IO(std::io::Error),
    /// A key was longer than `AMP_KEY_LIMIT` bytes.
    KeyTooLong,
    /// A key was empty (not raised by the decoder in this module).
    EmptyKey,
    /// A value was too long (not raised by the decoder in this module).
    ValueTooLong,
    /// A (de)serialization failure, carried as its message.
    Serde(String),
    /// An unsupported operation or data type.
    Unsupported,
}

impl From<std::io::Error> for CodecError {
    fn from(err: std::io::Error) -> Self {
        Self::IO(err)
    }
}

impl std::fmt::Display for CodecError {
    /// Formats the error using its `Debug` representation.
    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> {
        write!(fmt, "{:?}", self)
    }
}

impl std::error::Error for CodecError {
    /// Exposes the wrapped I/O error so error reporters can walk the cause
    /// chain. The other variants carry no underlying error.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            CodecError::IO(err) => Some(err),
            CodecError::KeyTooLong
            | CodecError::EmptyKey
            | CodecError::ValueTooLong
            | CodecError::Serde(_)
            | CodecError::Unsupported => None,
        }
    }
}
125
#[cfg(test)]
mod test {
    use amp_serde::Request;
    use bytes::BytesMut;
    use serde::Serialize;
    use tokio_util::codec::Decoder as _;

    use crate::*;

    /// Wire encoding of a `Sum { a: 13, b: 81 }` request tagged `23`,
    /// including the zero-length terminator key.
    const WWW_EXAMPLE: &[u8] = &[
        0x00, 0x04, 0x5F, 0x61, 0x73, 0x6B, 0x00, 0x02, 0x32, 0x33, 0x00, 0x08, 0x5F, 0x63, 0x6F,
        0x6D, 0x6D, 0x61, 0x6E, 0x64, 0x00, 0x03, 0x53, 0x75, 0x6D, 0x00, 0x01, 0x61, 0x00, 0x02,
        0x31, 0x33, 0x00, 0x01, 0x62, 0x00, 0x02, 0x38, 0x31, 0x00, 0x00,
    ];

    /// The key/value pairs carried by `WWW_EXAMPLE`, in wire order.
    const WWW_EXAMPLE_DEC: &[(&[u8], &[u8])] = &[
        (b"_ask", b"23"),
        (b"_command", b"Sum"),
        (b"a", b"13"),
        (b"b", b"81"),
    ];

    /// Argument fields of the example `Sum` command.
    #[derive(Serialize)]
    struct Sum {
        a: u32,
        b: u32,
    }

    #[test]
    fn decode_example() {
        let mut decoder = Decoder::<Vec<_>>::new();
        let mut input = BytesMut::from(WWW_EXAMPLE);

        let decoded = decoder.decode(&mut input).unwrap().unwrap();

        assert_eq!(
            decoded
                .iter()
                .map(|(k, v)| (k.as_ref(), v.as_ref()))
                .collect::<Vec<_>>(),
            WWW_EXAMPLE_DEC
        );
        // The whole box, terminator included, must have been consumed and the
        // decoder returned to its initial state.
        assert!(input.is_empty());
        assert_eq!(decoder, Decoder::<Vec<_>>::new());
    }

    #[test]
    fn encode_example() {
        let request = Request {
            command: "Sum".into(),
            tag: Some(b"23".as_ref().into()),
            fields: Sum { a: 13, b: 81 },
        };

        let encoded = amp_serde::to_bytes(request).unwrap();

        assert_eq!(encoded, WWW_EXAMPLE);
    }
}