amqp_codec/
io.rs

1use std::marker::PhantomData;
2
3use actix_codec::{Decoder, Encoder};
4use byteorder::{BigEndian, ByteOrder};
5use bytes::{BufMut, BytesMut};
6
7use super::errors::{AmqpCodecError, ProtocolIdError};
8use super::framing::HEADER_LEN;
9use crate::codec::{Decode, Encode};
10use crate::protocol::ProtocolId;
11
/// Minimum spare buffer capacity (4 KiB) the codec tries to keep available before
/// reading or writing a frame.
const SIZE_LOW_WM: usize = 4 * 1024;
/// Step (32 KiB) by which buffers are grown once spare capacity falls below the low mark.
const SIZE_HIGH_WM: usize = 32 * 1024;
14
/// Length-prefixed frame codec producing and consuming values of type `T`.
///
/// Decoding is incremental: between calls the codec remembers whether it is still
/// waiting for a frame header or for the rest of a frame whose size it already read.
#[derive(Debug)]
pub struct AmqpCodec<T: Decode + Encode> {
    /// Where the decoder is within the current inbound frame.
    state: DecodeState,
    /// Maximum accepted inbound frame size as declared in the frame header; `0` means unlimited.
    max_size: usize,
    /// `T` is only decoded/encoded by the codec, never stored in it.
    phantom: PhantomData<T>,
}
21
/// Progress of [`AmqpCodec`]'s decoder through the current frame.
#[derive(Debug, Clone, Copy)]
enum DecodeState {
    /// Waiting until at least `HEADER_LEN` bytes are buffered so the 4-byte frame size can be read.
    FrameHeader,
    /// The 4-byte size field has been consumed; holds how many more bytes belong to this
    /// frame (the declared size minus those 4 bytes).
    Frame(usize),
}
27
28impl<T: Decode + Encode> Default for AmqpCodec<T> {
29    fn default() -> Self {
30        Self::new()
31    }
32}
33
34impl<T: Decode + Encode> AmqpCodec<T> {
35    pub fn new() -> AmqpCodec<T> {
36        AmqpCodec {
37            state: DecodeState::FrameHeader,
38            max_size: 0,
39            phantom: PhantomData,
40        }
41    }
42
43    /// Set max inbound frame size.
44    ///
45    /// If max size is set to `0`, size is unlimited.
46    /// By default max size is set to `0`
47    pub fn max_size(&mut self, size: usize) {
48        self.max_size = size;
49    }
50}
51
52impl<T: Decode + Encode> Decoder for AmqpCodec<T> {
53    type Item = T;
54    type Error = AmqpCodecError;
55
56    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
57        loop {
58            match self.state {
59                DecodeState::FrameHeader => {
60                    let len = src.len();
61                    if len < HEADER_LEN {
62                        return Ok(None);
63                    }
64
65                    // read frame size
66                    let size = BigEndian::read_u32(src.as_ref()) as usize;
67                    if self.max_size != 0 && size > self.max_size {
68                        return Err(AmqpCodecError::MaxSizeExceeded);
69                    }
70                    self.state = DecodeState::Frame(size - 4);
71                    src.split_to(4);
72
73                    if len < size {
74                        // extend receiving buffer to fit the whole frame
75                        if src.remaining_mut() < std::cmp::max(SIZE_LOW_WM, size + HEADER_LEN) {
76                            src.reserve(SIZE_HIGH_WM);
77                        }
78                        return Ok(None);
79                    }
80                }
81                DecodeState::Frame(size) => {
82                    if src.len() < size {
83                        return Ok(None);
84                    }
85
86                    let frame_buf = src.split_to(size);
87                    let (remainder, frame) = T::decode(frame_buf.as_ref())?;
88                    if !remainder.is_empty() {
89                        // todo: could it really happen?
90                        return Err(AmqpCodecError::UnparsedBytesLeft);
91                    }
92                    self.state = DecodeState::FrameHeader;
93                    return Ok(Some(frame));
94                }
95            }
96        }
97    }
98}
99
100impl<T: Decode + Encode + ::std::fmt::Debug> Encoder for AmqpCodec<T> {
101    type Item = T;
102    type Error = AmqpCodecError;
103
104    fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
105        let size = item.encoded_size();
106        let need = std::cmp::max(SIZE_LOW_WM, size);
107        if dst.remaining_mut() < need {
108            dst.reserve(std::cmp::max(need, SIZE_HIGH_WM));
109        }
110
111        item.encode(dst);
112        Ok(())
113    }
114}
115
/// Total length of the protocol header: 4-byte prefix + 1 protocol-id byte + 3 version bytes.
const PROTOCOL_HEADER_LEN: usize = 8;
/// Fixed ASCII prefix that opens the protocol header.
const PROTOCOL_HEADER_PREFIX: &[u8] = &[b'A', b'M', b'Q', b'P'];
/// Version bytes (major, minor, revision) that follow the protocol-id byte.
const PROTOCOL_VERSION: &[u8] = &[1, 0, 0];
119
/// Stateless codec for the 8-byte protocol header: the `"AMQP"` prefix, one
/// protocol-id byte, then the version bytes `1 0 0`.
#[derive(Default, Debug)]
pub struct ProtocolIdCodec;
122
123impl Decoder for ProtocolIdCodec {
124    type Item = ProtocolId;
125    type Error = ProtocolIdError;
126
127    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
128        if src.len() < PROTOCOL_HEADER_LEN {
129            Ok(None)
130        } else {
131            let src = src.split_to(8);
132            if &src[0..4] != PROTOCOL_HEADER_PREFIX {
133                Err(ProtocolIdError::InvalidHeader)
134            } else if &src[5..8] != PROTOCOL_VERSION {
135                Err(ProtocolIdError::Incompatible)
136            } else {
137                let protocol_id = src[4];
138                match protocol_id {
139                    0 => Ok(Some(ProtocolId::Amqp)),
140                    2 => Ok(Some(ProtocolId::AmqpTls)),
141                    3 => Ok(Some(ProtocolId::AmqpSasl)),
142                    _ => Err(ProtocolIdError::Unknown),
143                }
144            }
145        }
146    }
147}
148
149impl Encoder for ProtocolIdCodec {
150    type Item = ProtocolId;
151    type Error = ProtocolIdError;
152
153    fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
154        dst.reserve(PROTOCOL_HEADER_LEN);
155        dst.put_slice(PROTOCOL_HEADER_PREFIX);
156        dst.put_u8(item as u8);
157        dst.put_slice(PROTOCOL_VERSION);
158        Ok(())
159    }
160}