1use std::{cell::Cell, marker::PhantomData};
2
3use byteorder::{BigEndian, ByteOrder};
4use ntex_bytes::{Buf, BufMut, BytesMut};
5use ntex_codec::{Decoder, Encoder};
6
7use super::error::{AmqpCodecError, ProtocolIdError};
8use super::framing::HEADER_LEN;
9use crate::codec::{Decode, Encode};
10use crate::protocol::ProtocolId;
11
/// Frame codec: decodes size-prefixed frames into `T` and encodes `T` values.
///
/// Decoding is incremental. Progress through the current frame is kept in
/// `state`, which is a `Cell` because `Decoder::decode` only receives `&self`.
#[derive(Debug)]
pub struct AmqpCodec<T: Decode + Encode> {
    /// Where the decoder currently is within a frame.
    state: Cell<DecodeState>,
    /// Largest accepted value of a frame's size prefix; `0` means no limit.
    max_size: usize,
    phantom: PhantomData<T>,
}
18
/// Progress of [`AmqpCodec`]'s decoder through the current frame.
#[derive(Debug, Clone, Copy)]
enum DecodeState {
    /// Waiting for a frame header; the 4-byte size prefix has not been consumed yet.
    FrameHeader,
    /// Size prefix already consumed; this many further bytes make up the frame.
    Frame(usize),
}
24
25impl<T: Decode + Encode> Default for AmqpCodec<T> {
26 fn default() -> Self {
27 Self::new()
28 }
29}
30
31impl<T: Decode + Encode> AmqpCodec<T> {
32 pub fn new() -> AmqpCodec<T> {
33 AmqpCodec {
34 state: Cell::new(DecodeState::FrameHeader),
35 max_size: 0,
36 phantom: PhantomData,
37 }
38 }
39
40 pub fn max_size(mut self, size: usize) -> Self {
45 self.max_size = size;
46 self
47 }
48
49 pub fn set_max_size(&mut self, size: usize) {
54 self.max_size = size;
55 }
56}
57
58impl<T: Decode + Encode> Decoder for AmqpCodec<T> {
59 type Item = T;
60 type Error = AmqpCodecError;
61
62 fn decode(&self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
63 loop {
64 match self.state.get() {
65 DecodeState::FrameHeader => {
66 let len = src.len();
67 if len < HEADER_LEN {
68 return Ok(None);
69 }
70
71 let size = BigEndian::read_u32(src.as_ref()) as usize;
73 if self.max_size != 0 && size > self.max_size {
74 return Err(AmqpCodecError::MaxSizeExceeded);
75 }
76 if size <= 4 {
77 return Err(AmqpCodecError::InvalidFrameSize);
78 }
79 self.state.set(DecodeState::Frame(size - 4));
80 src.advance(4);
81
82 if len < size {
83 return Ok(None);
84 }
85 }
86 DecodeState::Frame(size) => {
87 if src.len() < size {
88 return Ok(None);
89 }
90
91 let mut frame_buf = src.split_to(size).freeze();
92 let frame = T::decode(&mut frame_buf)?;
93 if !frame_buf.is_empty() {
94 return Err(AmqpCodecError::UnparsedBytesLeft);
96 }
97 self.state.set(DecodeState::FrameHeader);
98 return Ok(Some(frame));
99 }
100 }
101 }
102 }
103}
104
105impl<T: Decode + Encode + ::std::fmt::Debug> Encoder for AmqpCodec<T> {
106 type Item = T;
107 type Error = AmqpCodecError;
108
109 fn encode(&self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
110 let size = item.encoded_size();
111 if dst.remaining_mut() < size {
112 dst.reserve(size);
113 }
114
115 let len = dst.len();
116 item.encode(dst);
117 debug_assert!(dst.len() - len == size);
118
119 Ok(())
120 }
121}
122
/// Magic bytes that open every protocol header.
const PROTOCOL_HEADER_PREFIX: &[u8] = b"AMQP";
/// Version bytes (1.0.0) expected after the protocol-id byte.
const PROTOCOL_VERSION: &[u8] = b"\x01\x00\x00";
/// Full header length: the prefix, one protocol-id byte, then the version bytes.
const PROTOCOL_HEADER_LEN: usize = PROTOCOL_HEADER_PREFIX.len() + 1 + PROTOCOL_VERSION.len();
126
/// Codec for the fixed-size protocol header.
///
/// The header is the `"AMQP"` prefix, one protocol-id byte, and the version
/// bytes `1.0.0`.
#[derive(Default, Debug)]
pub struct ProtocolIdCodec;
129
130impl Decoder for ProtocolIdCodec {
131 type Item = ProtocolId;
132 type Error = ProtocolIdError;
133
134 fn decode(&self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
135 if src.len() < PROTOCOL_HEADER_LEN {
136 Ok(None)
137 } else {
138 let src = src.split_to(PROTOCOL_HEADER_LEN);
139 if &src[0..4] != PROTOCOL_HEADER_PREFIX {
140 Err(ProtocolIdError::InvalidHeader)
141 } else if &src[5..8] != PROTOCOL_VERSION {
142 Err(ProtocolIdError::Incompatible)
143 } else {
144 let protocol_id = src[4];
145 match protocol_id {
146 0 => Ok(Some(ProtocolId::Amqp)),
147 2 => Ok(Some(ProtocolId::AmqpTls)),
148 3 => Ok(Some(ProtocolId::AmqpSasl)),
149 _ => Err(ProtocolIdError::Unknown),
150 }
151 }
152 }
153 }
154}
155
156impl Encoder for ProtocolIdCodec {
157 type Item = ProtocolId;
158 type Error = ProtocolIdError;
159
160 fn encode(&self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
161 dst.reserve(PROTOCOL_HEADER_LEN);
162 dst.put_slice(PROTOCOL_HEADER_PREFIX);
163 dst.put_u8(item as u8);
164 dst.put_slice(PROTOCOL_VERSION);
165 Ok(())
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    use crate::AmqpFrame;

    /// A zero size prefix is rejected.
    #[test]
    fn test_decode() -> Result<(), AmqpCodecError> {
        let mut data = BytesMut::from(b"\0\0\0\0\0\0\0\0\0\x06AC@A\0S$\xc0\x01\0B".as_ref());

        let codec = AmqpCodec::<AmqpFrame>::new();
        let res = codec.decode(&mut data);
        assert!(matches!(res, Err(AmqpCodecError::InvalidFrameSize)));

        Ok(())
    }

    /// Fewer bytes than a header: wait for more input and consume nothing.
    #[test]
    fn test_decode_needs_full_header() {
        let mut data = BytesMut::from(b"\0\0\0".as_ref());
        let codec = AmqpCodec::<AmqpFrame>::new();

        assert!(matches!(codec.decode(&mut data), Ok(None)));
        assert_eq!(data.len(), 3);
    }

    /// A size prefix of exactly 4 describes a frame with no bytes after the prefix.
    #[test]
    fn test_decode_rejects_size_without_body() {
        let mut data = BytesMut::from(b"\0\0\0\x04\x02\0\0\0\0\0\0\0".as_ref());
        let codec = AmqpCodec::<AmqpFrame>::new();

        let res = codec.decode(&mut data);
        assert!(matches!(res, Err(AmqpCodecError::InvalidFrameSize)));
    }

    /// A frame above the configured limit is rejected before anything is consumed.
    #[test]
    fn test_decode_max_size() {
        // size prefix = 100, limit = 10
        let mut data = BytesMut::from(b"\0\0\0\x64\x02\0\0\0\0\0\0\0".as_ref());
        let codec = AmqpCodec::<AmqpFrame>::new().max_size(10);

        let res = codec.decode(&mut data);
        assert!(matches!(res, Err(AmqpCodecError::MaxSizeExceeded)));
        assert_eq!(data.len(), 12);
    }

    /// Every outcome of `ProtocolIdCodec::decode`.
    #[test]
    fn test_protocol_id_decode() {
        let codec = ProtocolIdCodec;

        let mut data = BytesMut::from(b"AMQP\0\x01\0".as_ref());
        assert!(matches!(codec.decode(&mut data), Ok(None)));

        let mut data = BytesMut::from(b"AMQX\0\x01\0\0".as_ref());
        assert!(matches!(codec.decode(&mut data), Err(ProtocolIdError::InvalidHeader)));

        let mut data = BytesMut::from(b"AMQP\0\x02\0\0".as_ref());
        assert!(matches!(codec.decode(&mut data), Err(ProtocolIdError::Incompatible)));

        let mut data = BytesMut::from(b"AMQP\x01\x01\0\0".as_ref());
        assert!(matches!(codec.decode(&mut data), Err(ProtocolIdError::Unknown)));

        let mut data = BytesMut::from(b"AMQP\x03\x01\0\0".as_ref());
        assert!(matches!(codec.decode(&mut data), Ok(Some(ProtocolId::AmqpSasl))));
        assert!(data.is_empty());
    }

    /// Encoding a protocol id produces a header the decoder maps back to the same id.
    #[test]
    fn test_protocol_id_encode_roundtrip() {
        let codec = ProtocolIdCodec;
        let mut buf = BytesMut::new();

        assert!(codec.encode(ProtocolId::AmqpTls, &mut buf).is_ok());
        assert_eq!(&buf[..], &b"AMQP\x02\x01\0\0"[..]);
        assert!(matches!(codec.decode(&mut buf), Ok(Some(ProtocolId::AmqpTls))));
        assert!(buf.is_empty());
    }
}