ntex_amqp_codec/io.rs

1use std::{cell::Cell, marker::PhantomData};
2
3use byteorder::{BigEndian, ByteOrder};
4use ntex_bytes::{Buf, BufMut, BytesMut};
5use ntex_codec::{Decoder, Encoder};
6
7use super::error::{AmqpCodecError, ProtocolIdError};
8use super::framing::HEADER_LEN;
9use crate::codec::{Decode, Encode};
10use crate::protocol::ProtocolId;
11
12#[derive(Debug)]
13pub struct AmqpCodec<T: Decode + Encode> {
14    state: Cell<DecodeState>,
15    max_size: usize,
16    phantom: PhantomData<T>,
17}
18
/// Progress marker for the incremental frame decoder.
#[derive(Clone, Copy, Debug)]
enum DecodeState {
    /// Waiting for enough bytes to read a frame header (including its size field).
    FrameHeader,
    /// Size field already consumed; holds the number of remaining frame bytes
    /// (declared frame size minus the 4-byte size field) still to be collected.
    Frame(usize),
}
24
25impl<T: Decode + Encode> Default for AmqpCodec<T> {
26    fn default() -> Self {
27        Self::new()
28    }
29}
30
31impl<T: Decode + Encode> AmqpCodec<T> {
32    pub fn new() -> AmqpCodec<T> {
33        AmqpCodec {
34            state: Cell::new(DecodeState::FrameHeader),
35            max_size: 0,
36            phantom: PhantomData,
37        }
38    }
39
40    /// Set max inbound frame size.
41    ///
42    /// If max size is set to `0`, size is unlimited.
43    /// By default max size is set to `0`
44    pub fn max_size(mut self, size: usize) -> Self {
45        self.max_size = size;
46        self
47    }
48
49    /// Set max inbound frame size.
50    ///
51    /// If max size is set to `0`, size is unlimited.
52    /// By default max size is set to `0`
53    pub fn set_max_size(&mut self, size: usize) {
54        self.max_size = size;
55    }
56}
57
58impl<T: Decode + Encode> Decoder for AmqpCodec<T> {
59    type Item = T;
60    type Error = AmqpCodecError;
61
62    fn decode(&self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
63        loop {
64            match self.state.get() {
65                DecodeState::FrameHeader => {
66                    let len = src.len();
67                    if len < HEADER_LEN {
68                        return Ok(None);
69                    }
70
71                    // read frame size
72                    let size = BigEndian::read_u32(src.as_ref()) as usize;
73                    if self.max_size != 0 && size > self.max_size {
74                        return Err(AmqpCodecError::MaxSizeExceeded);
75                    }
76                    if size <= 4 {
77                        return Err(AmqpCodecError::InvalidFrameSize);
78                    }
79                    self.state.set(DecodeState::Frame(size - 4));
80                    src.advance(4);
81
82                    if len < size {
83                        return Ok(None);
84                    }
85                }
86                DecodeState::Frame(size) => {
87                    if src.len() < size {
88                        return Ok(None);
89                    }
90
91                    let mut frame_buf = src.split_to(size).freeze();
92                    let frame = T::decode(&mut frame_buf)?;
93                    if !frame_buf.is_empty() {
94                        // todo: could it really happen?
95                        return Err(AmqpCodecError::UnparsedBytesLeft);
96                    }
97                    self.state.set(DecodeState::FrameHeader);
98                    return Ok(Some(frame));
99                }
100            }
101        }
102    }
103}
104
105impl<T: Decode + Encode + ::std::fmt::Debug> Encoder for AmqpCodec<T> {
106    type Item = T;
107    type Error = AmqpCodecError;
108
109    fn encode(&self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
110        let size = item.encoded_size();
111        if dst.remaining_mut() < size {
112            dst.reserve(size);
113        }
114
115        let len = dst.len();
116        item.encode(dst);
117        debug_assert!(dst.len() - len == size);
118
119        Ok(())
120    }
121}
122
/// Total protocol header length: prefix, one protocol-id byte, then version bytes (8 bytes).
const PROTOCOL_HEADER_LEN: usize = PROTOCOL_HEADER_PREFIX.len() + 1 + PROTOCOL_VERSION.len();
/// Magic bytes that occupy the first 4 bytes of the protocol header.
const PROTOCOL_HEADER_PREFIX: &[u8] = b"AMQP";
/// Version bytes expected at offsets 5..8, right after the protocol-id byte.
const PROTOCOL_VERSION: &[u8] = b"\x01\x00\x00";
126
/// Stateless codec for the fixed-size protocol header:
/// `AMQP`, one protocol-id byte, then the version bytes.
#[derive(Debug, Default)]
pub struct ProtocolIdCodec;
129
130impl Decoder for ProtocolIdCodec {
131    type Item = ProtocolId;
132    type Error = ProtocolIdError;
133
134    fn decode(&self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
135        if src.len() < PROTOCOL_HEADER_LEN {
136            Ok(None)
137        } else {
138            let src = src.split_to(PROTOCOL_HEADER_LEN);
139            if &src[0..4] != PROTOCOL_HEADER_PREFIX {
140                Err(ProtocolIdError::InvalidHeader)
141            } else if &src[5..8] != PROTOCOL_VERSION {
142                Err(ProtocolIdError::Incompatible)
143            } else {
144                let protocol_id = src[4];
145                match protocol_id {
146                    0 => Ok(Some(ProtocolId::Amqp)),
147                    2 => Ok(Some(ProtocolId::AmqpTls)),
148                    3 => Ok(Some(ProtocolId::AmqpSasl)),
149                    _ => Err(ProtocolIdError::Unknown),
150                }
151            }
152        }
153    }
154}
155
156impl Encoder for ProtocolIdCodec {
157    type Item = ProtocolId;
158    type Error = ProtocolIdError;
159
160    fn encode(&self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
161        dst.reserve(PROTOCOL_HEADER_LEN);
162        dst.put_slice(PROTOCOL_HEADER_PREFIX);
163        dst.put_u8(item as u8);
164        dst.put_slice(PROTOCOL_VERSION);
165        Ok(())
166    }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    use crate::AmqpFrame;

    /// A declared frame size of zero must be rejected as malformed.
    #[test]
    fn test_decode() -> Result<(), AmqpCodecError> {
        let mut data = BytesMut::from(b"\0\0\0\0\0\0\0\0\0\x06AC@A\0S$\xc0\x01\0B".as_ref());

        let codec = AmqpCodec::<AmqpFrame>::new();
        let res = codec.decode(&mut data);
        assert!(matches!(res, Err(AmqpCodecError::InvalidFrameSize)));

        Ok(())
    }

    /// Fewer bytes than a frame header: nothing decoded, nothing consumed.
    #[test]
    fn test_decode_waits_for_full_header() {
        let mut data = BytesMut::from(b"\0\0\0".as_ref());

        let codec = AmqpCodec::<AmqpFrame>::new();
        assert!(matches!(codec.decode(&mut data), Ok(None)));
        assert_eq!(data.len(), 3);
    }

    /// A declared size (256) above the configured limit is rejected.
    #[test]
    fn test_decode_max_size_exceeded() {
        let mut data = BytesMut::from(b"\0\0\x01\0\x02\0\0\0".as_ref());

        let codec = AmqpCodec::<AmqpFrame>::new().max_size(16);
        assert!(matches!(
            codec.decode(&mut data),
            Err(AmqpCodecError::MaxSizeExceeded)
        ));
    }

    /// Encoding then decoding a protocol header yields the same protocol id.
    #[test]
    fn test_protocol_id_roundtrip() {
        let codec = ProtocolIdCodec;
        let mut buf = BytesMut::new();

        assert!(codec.encode(ProtocolId::AmqpSasl, &mut buf).is_ok());
        assert_eq!(buf.len(), PROTOCOL_HEADER_LEN);
        assert!(buf.starts_with(b"AMQP"));

        assert!(matches!(
            codec.decode(&mut buf),
            Ok(Some(ProtocolId::AmqpSasl))
        ));
        assert!(buf.is_empty());
    }

    /// Partial input waits; malformed headers map to the matching error.
    #[test]
    fn test_protocol_id_errors() {
        let codec = ProtocolIdCodec;

        let mut partial = BytesMut::from(b"AMQP".as_ref());
        assert!(matches!(codec.decode(&mut partial), Ok(None)));
        assert_eq!(partial.len(), 4);

        let mut bad_prefix = BytesMut::from(b"HTTP/1.1".as_ref());
        assert!(matches!(
            codec.decode(&mut bad_prefix),
            Err(ProtocolIdError::InvalidHeader)
        ));

        let mut bad_version = BytesMut::from(b"AMQP\x00\x00\x09\x00".as_ref());
        assert!(matches!(
            codec.decode(&mut bad_version),
            Err(ProtocolIdError::Incompatible)
        ));

        let mut bad_id = BytesMut::from(b"AMQP\x01\x01\x00\x00".as_ref());
        assert!(matches!(
            codec.decode(&mut bad_id),
            Err(ProtocolIdError::Unknown)
        ));
    }
}