use std::marker::PhantomData;

use actix_codec::{Decoder, Encoder};
use byteorder::{BigEndian, ByteOrder};
use bytes::{Buf, BufMut, BytesMut};

use super::errors::{AmqpCodecError, ProtocolIdError};
use super::framing::HEADER_LEN;
use crate::codec::{Decode, Encode};
use crate::protocol::ProtocolId;
11
/// Low watermark, in bytes, for spare buffer capacity: when a buffer has less
/// free space than this (or than the pending frame needs), more is reserved.
const SIZE_LOW_WM: usize = 4096;
/// Amount of capacity, in bytes, reserved once a buffer drops below the low
/// watermark (the encoder reserves the larger of this and what the item needs).
const SIZE_HIGH_WM: usize = 32768;
14
/// Length-prefixed frame codec: decodes buffered bytes into `T` frames and
/// encodes `T` values back into bytes, optionally enforcing a maximum inbound
/// frame size.
#[derive(Debug)]
pub struct AmqpCodec<T: Decode + Encode> {
    /// Where the decoder currently is within a frame.
    state: DecodeState,
    /// Maximum accepted inbound frame size in bytes; `0` disables the check.
    max_size: usize,
    /// Ties the codec to its frame type without storing a `T`.
    phantom: PhantomData<T>,
}
21
/// Progress of the decoder through the current frame.
#[derive(Debug, Clone, Copy)]
enum DecodeState {
    /// Waiting until at least `HEADER_LEN` bytes are buffered so the 4-byte
    /// big-endian size field can be read.
    FrameHeader,
    /// The size field has been consumed; holds the number of frame bytes that
    /// remain to be read (announced size minus the 4-byte size field).
    Frame(usize),
}
27
28impl<T: Decode + Encode> Default for AmqpCodec<T> {
29 fn default() -> Self {
30 Self::new()
31 }
32}
33
34impl<T: Decode + Encode> AmqpCodec<T> {
35 pub fn new() -> AmqpCodec<T> {
36 AmqpCodec {
37 state: DecodeState::FrameHeader,
38 max_size: 0,
39 phantom: PhantomData,
40 }
41 }
42
43 pub fn max_size(&mut self, size: usize) {
48 self.max_size = size;
49 }
50}
51
52impl<T: Decode + Encode> Decoder for AmqpCodec<T> {
53 type Item = T;
54 type Error = AmqpCodecError;
55
56 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
57 loop {
58 match self.state {
59 DecodeState::FrameHeader => {
60 let len = src.len();
61 if len < HEADER_LEN {
62 return Ok(None);
63 }
64
65 let size = BigEndian::read_u32(src.as_ref()) as usize;
67 if self.max_size != 0 && size > self.max_size {
68 return Err(AmqpCodecError::MaxSizeExceeded);
69 }
70 self.state = DecodeState::Frame(size - 4);
71 src.split_to(4);
72
73 if len < size {
74 if src.remaining_mut() < std::cmp::max(SIZE_LOW_WM, size + HEADER_LEN) {
76 src.reserve(SIZE_HIGH_WM);
77 }
78 return Ok(None);
79 }
80 }
81 DecodeState::Frame(size) => {
82 if src.len() < size {
83 return Ok(None);
84 }
85
86 let frame_buf = src.split_to(size);
87 let (remainder, frame) = T::decode(frame_buf.as_ref())?;
88 if !remainder.is_empty() {
89 return Err(AmqpCodecError::UnparsedBytesLeft);
91 }
92 self.state = DecodeState::FrameHeader;
93 return Ok(Some(frame));
94 }
95 }
96 }
97 }
98}
99
100impl<T: Decode + Encode + ::std::fmt::Debug> Encoder for AmqpCodec<T> {
101 type Item = T;
102 type Error = AmqpCodecError;
103
104 fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
105 let size = item.encoded_size();
106 let need = std::cmp::max(SIZE_LOW_WM, size);
107 if dst.remaining_mut() < need {
108 dst.reserve(std::cmp::max(need, SIZE_HIGH_WM));
109 }
110
111 item.encode(dst);
112 Ok(())
113 }
114}
115
/// Total protocol header length in bytes: 4-byte prefix, 1-byte protocol id
/// and 3 version bytes.
const PROTOCOL_HEADER_LEN: usize = 8;
/// Magic bytes at the start of every protocol header.
const PROTOCOL_HEADER_PREFIX: &[u8] = b"AMQP";
/// Version bytes (header offsets 5..8) this codec writes and accepts;
/// presumably major 1, minor 0, revision 0 — confirm against the AMQP 1.0 spec.
const PROTOCOL_VERSION: &[u8] = &[1, 0, 0];
119
/// Stateless codec for the 8-byte protocol header: the `"AMQP"` prefix, a
/// protocol-id byte and the version bytes.
#[derive(Default, Debug)]
pub struct ProtocolIdCodec;
122
123impl Decoder for ProtocolIdCodec {
124 type Item = ProtocolId;
125 type Error = ProtocolIdError;
126
127 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
128 if src.len() < PROTOCOL_HEADER_LEN {
129 Ok(None)
130 } else {
131 let src = src.split_to(8);
132 if &src[0..4] != PROTOCOL_HEADER_PREFIX {
133 Err(ProtocolIdError::InvalidHeader)
134 } else if &src[5..8] != PROTOCOL_VERSION {
135 Err(ProtocolIdError::Incompatible)
136 } else {
137 let protocol_id = src[4];
138 match protocol_id {
139 0 => Ok(Some(ProtocolId::Amqp)),
140 2 => Ok(Some(ProtocolId::AmqpTls)),
141 3 => Ok(Some(ProtocolId::AmqpSasl)),
142 _ => Err(ProtocolIdError::Unknown),
143 }
144 }
145 }
146 }
147}
148
149impl Encoder for ProtocolIdCodec {
150 type Item = ProtocolId;
151 type Error = ProtocolIdError;
152
153 fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
154 dst.reserve(PROTOCOL_HEADER_LEN);
155 dst.put_slice(PROTOCOL_HEADER_PREFIX);
156 dst.put_u8(item as u8);
157 dst.put_slice(PROTOCOL_VERSION);
158 Ok(())
159 }
160}