Skip to main content

ntex_amqp_codec/codec/
decode.rs

1use std::{char, collections, convert::TryFrom, hash::BuildHasher, hash::Hash};
2
3use byteorder::{BigEndian, ByteOrder};
4use chrono::{DateTime, TimeZone, Utc};
5use ntex_bytes::{Buf, ByteString, Bytes};
6use ordered_float::OrderedFloat;
7use uuid::Uuid;
8
9use crate::codec::{self, ArrayHeader, Composite, Decode, DecodeFormatted, ListHeader, MapHeader};
10use crate::error::AmqpParseError;
11use crate::framing::{self, AmqpFrame, HEADER_LEN, SaslFrame};
12use crate::types::{
13    Array, Constructor, DescribedCompound, Descriptor, List, ListDescribed, Multiple, Str, Symbol,
14    Variant, VariantMap, VecStringMap, VecSymbolMap,
15};
16use crate::{HashMap, protocol};
17
/// Reads one big-endian number from the front of `$input` and consumes it.
///
/// * `$input` - identifier of the buffer being decoded.
/// * `$fn` - name of the `BigEndian` reader to call (e.g. `read_u32`).
/// * `$size` - number of bytes that reader consumes.
///
/// `decode_check_len!` is applied to `$input` and `$size` before any byte is touched.
/// The macro evaluates to `Ok(value)`; the error type is inferred from the call site.
macro_rules! be_read {
    ($input:ident, $fn:ident, $size:expr) => {{
        decode_check_len!($input, $size);
        let value = BigEndian::$fn(&$input);
        $input.advance($size);
        Ok(value)
    }};
}
26
27fn read_u8(input: &mut Bytes) -> Result<u8, AmqpParseError> {
28    decode_check_len!(input, 1);
29    let code = input[0];
30    input.advance(1);
31    Ok(code)
32}
33
34fn read_i8(input: &mut Bytes) -> Result<i8, AmqpParseError> {
35    decode_check_len!(input, 1);
36    let code = input[0] as i8;
37    input.advance(1);
38    Ok(code)
39}
40
41fn read_bytes_u8(input: &mut Bytes) -> Result<Bytes, AmqpParseError> {
42    let len = read_u8(input)?;
43    let len = len as usize;
44    decode_check_len!(input, len);
45    Ok(input.split_to(len))
46}
47
48fn read_bytes_u32(input: &mut Bytes) -> Result<Bytes, AmqpParseError> {
49    let result: Result<u32, AmqpParseError> = be_read!(input, read_u32, 4);
50    let len = result?;
51    let len = len as usize;
52    decode_check_len!(input, len);
53    Ok(input.split_to(len))
54}
55
/// Guards a decoder against an unexpected format code.
///
/// When `$fmt` is not equal to `$code`, the enclosing function returns
/// `Err(AmqpParseError::InvalidFormatCode($fmt))`; otherwise execution continues.
#[macro_export]
macro_rules! validate_code {
    ($fmt:ident, $code:expr) => {
        match $fmt {
            actual if actual == $code => {}
            actual => return Err(AmqpParseError::InvalidFormatCode(actual)),
        }
    };
}
64
65impl DecodeFormatted for bool {
66    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
67        match fmt {
68            codec::FORMATCODE_BOOLEAN => read_u8(input).map(|o| o != 0),
69            codec::FORMATCODE_BOOLEAN_TRUE => Ok(true),
70            codec::FORMATCODE_BOOLEAN_FALSE => Ok(false),
71            _ => Err(AmqpParseError::InvalidFormatCode(fmt)),
72        }
73    }
74}
75
76impl DecodeFormatted for u8 {
77    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
78        validate_code!(fmt, codec::FORMATCODE_UBYTE);
79        read_u8(input)
80    }
81}
82
83impl DecodeFormatted for u16 {
84    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
85        validate_code!(fmt, codec::FORMATCODE_USHORT);
86        be_read!(input, read_u16, 2)
87    }
88}
89
90impl DecodeFormatted for u32 {
91    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
92        match fmt {
93            codec::FORMATCODE_UINT => be_read!(input, read_u32, 4),
94            codec::FORMATCODE_SMALLUINT => read_u8(input).map(u32::from),
95            codec::FORMATCODE_UINT_0 => Ok(0),
96            _ => Err(AmqpParseError::InvalidFormatCode(fmt)),
97        }
98    }
99}
100
101impl DecodeFormatted for u64 {
102    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
103        match fmt {
104            codec::FORMATCODE_ULONG => be_read!(input, read_u64, 8),
105            codec::FORMATCODE_SMALLULONG => read_u8(input).map(u64::from),
106            codec::FORMATCODE_ULONG_0 => Ok(0),
107            _ => Err(AmqpParseError::InvalidFormatCode(fmt)),
108        }
109    }
110}
111
112impl DecodeFormatted for i8 {
113    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
114        validate_code!(fmt, codec::FORMATCODE_BYTE);
115        read_i8(input)
116    }
117}
118
119impl DecodeFormatted for i16 {
120    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
121        validate_code!(fmt, codec::FORMATCODE_SHORT);
122        be_read!(input, read_i16, 2)
123    }
124}
125
126impl DecodeFormatted for i32 {
127    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
128        match fmt {
129            codec::FORMATCODE_INT => be_read!(input, read_i32, 4),
130            codec::FORMATCODE_SMALLINT => read_i8(input).map(i32::from),
131            _ => Err(AmqpParseError::InvalidFormatCode(fmt)),
132        }
133    }
134}
135
136impl DecodeFormatted for i64 {
137    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
138        match fmt {
139            codec::FORMATCODE_LONG => be_read!(input, read_i64, 8),
140            codec::FORMATCODE_SMALLLONG => read_i8(input).map(i64::from),
141            _ => Err(AmqpParseError::InvalidFormatCode(fmt)),
142        }
143    }
144}
145
146impl DecodeFormatted for f32 {
147    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
148        validate_code!(fmt, codec::FORMATCODE_FLOAT);
149        be_read!(input, read_f32, 4)
150    }
151}
152
153impl DecodeFormatted for f64 {
154    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
155        validate_code!(fmt, codec::FORMATCODE_DOUBLE);
156        be_read!(input, read_f64, 8)
157    }
158}
159
160impl DecodeFormatted for char {
161    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
162        validate_code!(fmt, codec::FORMATCODE_CHAR);
163        let result: Result<u32, AmqpParseError> = be_read!(input, read_u32, 4);
164        let o = result?;
165        if let Some(c) = char::from_u32(o) {
166            Ok(c)
167        } else {
168            Err(AmqpParseError::InvalidChar(o))
169        } // todo: replace with CharTryFromError once try_from is stabilized
170    }
171}
172
173impl DecodeFormatted for DateTime<Utc> {
174    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
175        validate_code!(fmt, codec::FORMATCODE_TIMESTAMP);
176        be_read!(input, read_i64, 8).and_then(datetime_from_millis)
177    }
178}
179
180impl DecodeFormatted for Uuid {
181    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
182        validate_code!(fmt, codec::FORMATCODE_UUID);
183        decode_check_len!(input, 16);
184        let uuid =
185            Uuid::from_slice(&input.split_to(16)).map_err(|_| AmqpParseError::UuidParseError)?;
186        Ok(uuid)
187    }
188}
189
190impl DecodeFormatted for Bytes {
191    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
192        match fmt {
193            codec::FORMATCODE_BINARY8 => read_bytes_u8(input),
194            codec::FORMATCODE_BINARY32 => read_bytes_u32(input),
195            _ => Err(AmqpParseError::InvalidFormatCode(fmt)),
196        }
197    }
198}
199
200impl DecodeFormatted for ByteString {
201    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
202        match fmt {
203            codec::FORMATCODE_STRING8 => {
204                let bytes = read_bytes_u8(input)?;
205                Ok(ByteString::try_from(bytes).map_err(|_| AmqpParseError::Utf8Error)?)
206            }
207            codec::FORMATCODE_STRING32 => {
208                let bytes = read_bytes_u32(input)?;
209                Ok(ByteString::try_from(bytes).map_err(|_| AmqpParseError::Utf8Error)?)
210            }
211            _ => Err(AmqpParseError::InvalidFormatCode(fmt)),
212        }
213    }
214}
215
216impl DecodeFormatted for Str {
217    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
218        Ok(Str::from(ByteString::decode_with_format(input, fmt)?))
219    }
220}
221
222impl DecodeFormatted for Symbol {
223    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
224        match fmt {
225            codec::FORMATCODE_SYMBOL8 => {
226                let bytes = read_bytes_u8(input)?;
227                Ok(Symbol(Str::from(
228                    ByteString::try_from(bytes).map_err(|_| AmqpParseError::Utf8Error)?,
229                )))
230            }
231            codec::FORMATCODE_SYMBOL32 => {
232                let bytes = read_bytes_u32(input)?;
233                Ok(Symbol(Str::from(
234                    ByteString::try_from(bytes).map_err(|_| AmqpParseError::Utf8Error)?,
235                )))
236            }
237            _ => Err(AmqpParseError::InvalidFormatCode(fmt)),
238        }
239    }
240}
241
242impl<K: Decode + Eq + Hash, V: Decode, S: BuildHasher + Default> DecodeFormatted
243    for collections::HashMap<K, V, S>
244{
245    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
246        let header = MapHeader::decode_with_format(input, fmt)?;
247        decode_check_len!(input, header.size as usize);
248        let mut map_input = input.split_to(header.size as usize);
249        let count = header.count / 2;
250        let mut map: collections::HashMap<K, V, S> =
251            collections::HashMap::with_capacity_and_hasher(count as usize, Default::default());
252        for _ in 0..count {
253            let key = K::decode(&mut map_input)?;
254            let value = V::decode(&mut map_input)?;
255            map.insert(key, value); // todo: ensure None returned?
256        }
257        // todo: validate map_input is empty
258        Ok(map)
259    }
260}
261
262impl<T: DecodeFormatted> DecodeFormatted for Vec<T> {
263    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
264        let header = ArrayHeader::decode_with_format(input, fmt)?;
265        decode_check_len!(input, header.size as usize);
266        let elem_ctor = Constructor::decode(input)?;
267        let elem_fmt = match elem_ctor {
268            Constructor::FormatCode(code) => code,
269            Constructor::Described { descriptor, .. } => {
270                // todo: mg: described types are not supported OOTB at this point
271                return Err(AmqpParseError::InvalidDescriptor(Box::new(descriptor)));
272            }
273        };
274        let mut result: Vec<T> = Vec::with_capacity(header.count as usize);
275        for _ in 0..header.count {
276            let decoded = T::decode_with_format(input, elem_fmt)?;
277            result.push(decoded);
278        }
279        // todo: ensure header.size bytes were read out from input
280        Ok(result)
281    }
282}
283
284impl DecodeFormatted for VecSymbolMap {
285    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
286        let header = MapHeader::decode_with_format(input, fmt)?;
287        decode_check_len!(input, header.size as usize);
288        let mut map_input = input.split_to(header.size as usize);
289        let count = header.count / 2;
290        let mut map = Vec::with_capacity(count as usize);
291        for _ in 0..count {
292            let key = Symbol::decode(&mut map_input)?;
293            let value = Variant::decode(&mut map_input)?;
294            map.push((key, value)); // todo: mg: ensure None is returned
295        }
296        // todo: ensure header.size bytes were read out from input after header
297        Ok(VecSymbolMap(map))
298    }
299}
300
301impl DecodeFormatted for VecStringMap {
302    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
303        let header = MapHeader::decode_with_format(input, fmt)?;
304        decode_check_len!(input, header.size as usize);
305        let mut map_input = input.split_to(header.size as usize);
306        let count = header.count / 2;
307        let mut map = Vec::with_capacity(count as usize);
308        for _ in 0..count {
309            let key = Str::decode(&mut map_input)?;
310            let value = Variant::decode(&mut map_input)?;
311            map.push((key, value)); // todo: ensure None returned?
312        }
313        // todo: validate map_input is empty
314        Ok(VecStringMap(map))
315    }
316}
317
318impl<T: DecodeFormatted> DecodeFormatted for Multiple<T> {
319    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
320        match fmt {
321            codec::FORMATCODE_ARRAY8 | codec::FORMATCODE_ARRAY32 => {
322                let items = Vec::<T>::decode_with_format(input, fmt)?;
323                Ok(Multiple(items))
324            }
325            codec::FORMATCODE_DESCRIBED => {
326                let descriptor = Descriptor::decode_with_format(input, fmt)?;
327                // todo: mg: described types are not supported OOTB at this point
328                Err(AmqpParseError::InvalidDescriptor(Box::new(descriptor)))
329            }
330            _ => {
331                let item = T::decode_with_format(input, fmt)?;
332                Ok(Multiple(vec![item]))
333            }
334        }
335    }
336}
337
338impl DecodeFormatted for List {
339    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
340        let header = ListHeader::decode_with_format(input, fmt)?;
341        let mut result: Vec<Variant> = Vec::with_capacity(header.count as usize);
342        for _ in 0..header.count {
343            let decoded = Variant::decode(input)?;
344            result.push(decoded);
345        }
346        Ok(List(result))
347    }
348}
349
350impl<T: Composite> DecodeFormatted for ListDescribed<T> {
351    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
352        let header = ListHeader::decode_with_format(input, fmt)?;
353
354        let descr = T::descriptor();
355        let mut result: Vec<T> = Vec::with_capacity(header.count as usize);
356        for _ in 0..header.count {
357            if let Variant::DescribedCompound(decoded) = Variant::decode(input)? {
358                if &descr != decoded.descriptor() {
359                    return Err(AmqpParseError::UnexpectedType("Unexpected descriptor"));
360                }
361                result.push(decoded.decode()?);
362            } else {
363                return Err(AmqpParseError::UnexpectedType("Expected compound type"));
364            }
365        }
366
367        Ok(ListDescribed(result))
368    }
369}
370
371impl DecodeFormatted for Variant {
372    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
373        match fmt {
374            codec::FORMATCODE_NULL => Ok(Variant::Null),
375            codec::FORMATCODE_BOOLEAN => bool::decode_with_format(input, fmt).map(Variant::Boolean),
376            codec::FORMATCODE_BOOLEAN_FALSE => Ok(Variant::Boolean(false)),
377            codec::FORMATCODE_BOOLEAN_TRUE => Ok(Variant::Boolean(true)),
378            codec::FORMATCODE_UINT_0 => Ok(Variant::Uint(0)),
379            codec::FORMATCODE_ULONG_0 => Ok(Variant::Ulong(0)),
380            codec::FORMATCODE_UBYTE => u8::decode_with_format(input, fmt).map(Variant::Ubyte),
381            codec::FORMATCODE_USHORT => u16::decode_with_format(input, fmt).map(Variant::Ushort),
382            codec::FORMATCODE_UINT => u32::decode_with_format(input, fmt).map(Variant::Uint),
383            codec::FORMATCODE_ULONG => u64::decode_with_format(input, fmt).map(Variant::Ulong),
384            codec::FORMATCODE_BYTE => i8::decode_with_format(input, fmt).map(Variant::Byte),
385            codec::FORMATCODE_SHORT => i16::decode_with_format(input, fmt).map(Variant::Short),
386            codec::FORMATCODE_INT => i32::decode_with_format(input, fmt).map(Variant::Int),
387            codec::FORMATCODE_LONG => i64::decode_with_format(input, fmt).map(Variant::Long),
388            codec::FORMATCODE_SMALLUINT => u32::decode_with_format(input, fmt).map(Variant::Uint),
389            codec::FORMATCODE_SMALLULONG => u64::decode_with_format(input, fmt).map(Variant::Ulong),
390            codec::FORMATCODE_SMALLINT => i32::decode_with_format(input, fmt).map(Variant::Int),
391            codec::FORMATCODE_SMALLLONG => i64::decode_with_format(input, fmt).map(Variant::Long),
392            codec::FORMATCODE_FLOAT => {
393                f32::decode_with_format(input, fmt).map(|o| Variant::Float(OrderedFloat(o)))
394            }
395            codec::FORMATCODE_DOUBLE => {
396                f64::decode_with_format(input, fmt).map(|o| Variant::Double(OrderedFloat(o)))
397            }
398            codec::FORMATCODE_DECIMAL32 => read_fixed_bytes(input).map(Variant::Decimal32),
399            codec::FORMATCODE_DECIMAL64 => read_fixed_bytes(input).map(Variant::Decimal64),
400            codec::FORMATCODE_DECIMAL128 => read_fixed_bytes(input).map(Variant::Decimal128),
401            codec::FORMATCODE_CHAR => char::decode_with_format(input, fmt).map(Variant::Char),
402            codec::FORMATCODE_TIMESTAMP => {
403                DateTime::<Utc>::decode_with_format(input, fmt).map(Variant::Timestamp)
404            }
405            codec::FORMATCODE_UUID => Uuid::decode_with_format(input, fmt).map(Variant::Uuid),
406            codec::FORMATCODE_BINARY8 | codec::FORMATCODE_BINARY32 => {
407                Bytes::decode_with_format(input, fmt).map(Variant::Binary)
408            }
409            codec::FORMATCODE_STRING8 | codec::FORMATCODE_STRING32 => {
410                ByteString::decode_with_format(input, fmt).map(|o| Variant::String(o.into()))
411            }
412            codec::FORMATCODE_SYMBOL8 | codec::FORMATCODE_SYMBOL32 => {
413                Symbol::decode_with_format(input, fmt).map(Variant::Symbol)
414            }
415            codec::FORMATCODE_LIST0 => Ok(Variant::List(List(vec![]))),
416            codec::FORMATCODE_LIST8 | codec::FORMATCODE_LIST32 => {
417                List::decode_with_format(input, fmt).map(Variant::List)
418            }
419            codec::FORMATCODE_ARRAY8 | codec::FORMATCODE_ARRAY32 => {
420                Array::decode_with_format(input, fmt).map(Variant::Array)
421            }
422            codec::FORMATCODE_MAP8 | codec::FORMATCODE_MAP32 => {
423                HashMap::<Variant, Variant>::decode_with_format(input, fmt)
424                    .map(|o| Variant::Map(VariantMap::new(o)))
425            }
426            codec::FORMATCODE_DESCRIBED => {
427                let descriptor = Descriptor::decode(input)?;
428                let format_code = {
429                    decode_check_len!(input, 1);
430                    let code = input[0];
431                    Ok(code)
432                }?;
433                match format_code {
434                    codec::FORMATCODE_LIST0 => {
435                        input.advance(1); // advance past format code
436                        Ok(Variant::DescribedCompound(DescribedCompound::new(
437                            descriptor,
438                            Bytes::from_static(&[codec::FORMATCODE_LIST0]),
439                        )))
440                    }
441                    codec::FORMATCODE_LIST8 | codec::FORMATCODE_MAP8 | codec::FORMATCODE_ARRAY8 => {
442                        decode_check_len!(input, 2);
443                        let size = input[1] as usize;
444                        decode_check_len!(input, 2 + size);
445                        let data = input.split_to(2 + size);
446                        Ok(Variant::DescribedCompound(DescribedCompound::new(
447                            descriptor, data,
448                        )))
449                    }
450                    codec::FORMATCODE_LIST32
451                    | codec::FORMATCODE_MAP32
452                    | codec::FORMATCODE_ARRAY32 => {
453                        decode_check_len!(input, 5);
454                        let size = u32::from_be_bytes(input[1..5].try_into().unwrap()) as usize;
455                        decode_check_len!(input, 5 + size);
456                        let data = input.split_to(5 + size);
457                        Ok(Variant::DescribedCompound(DescribedCompound::new(
458                            descriptor, data,
459                        )))
460                    }
461                    _ => {
462                        input.advance(1); // advance past format code
463                        let value = Variant::decode_with_format(input, format_code)?;
464                        Ok(Variant::Described((descriptor, Box::new(value))))
465                    }
466                }
467            }
468            _ => Err(AmqpParseError::InvalidFormatCode(fmt)),
469        }
470    }
471}
472
473impl<T: DecodeFormatted> DecodeFormatted for Option<T> {
474    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
475        match fmt {
476            codec::FORMATCODE_NULL => Ok(None),
477            _ => T::decode_with_format(input, fmt).map(Some),
478        }
479    }
480}
481
482impl DecodeFormatted for Descriptor {
483    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
484        match fmt {
485            codec::FORMATCODE_SMALLULONG => {
486                u64::decode_with_format(input, fmt).map(Descriptor::Ulong)
487            }
488            codec::FORMATCODE_ULONG => u64::decode_with_format(input, fmt).map(Descriptor::Ulong),
489            codec::FORMATCODE_SYMBOL8 => {
490                Symbol::decode_with_format(input, fmt).map(Descriptor::Symbol)
491            }
492            codec::FORMATCODE_SYMBOL32 => {
493                Symbol::decode_with_format(input, fmt).map(Descriptor::Symbol)
494            }
495            _ => Err(AmqpParseError::InvalidFormatCode(fmt)),
496        }
497    }
498}
499
500impl DecodeFormatted for Constructor {
501    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
502        match fmt {
503            codec::FORMATCODE_DESCRIBED => {
504                let descriptor = Descriptor::decode(input)?;
505                let format_code = codec::decode_format_code(input)?;
506                Ok(Constructor::Described {
507                    descriptor,
508                    format_code,
509                })
510            }
511            _ => Ok(Constructor::FormatCode(fmt)),
512        }
513    }
514}
515
516impl Decode for AmqpFrame {
517    fn decode(input: &mut Bytes) -> Result<Self, AmqpParseError> {
518        let channel_id = decode_frame_header(input, framing::FRAME_TYPE_AMQP)?;
519        let performative = protocol::Frame::decode(input)?;
520        Ok(AmqpFrame::new(channel_id, performative))
521    }
522}
523
524impl Decode for SaslFrame {
525    fn decode(input: &mut Bytes) -> Result<Self, AmqpParseError> {
526        let _ = decode_frame_header(input, framing::FRAME_TYPE_SASL)?;
527        let frame = protocol::SaslFrameBody::decode(input)?;
528        Ok(SaslFrame { body: frame })
529    }
530}
531
532impl DecodeFormatted for ListHeader {
533    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
534        match fmt {
535            codec::FORMATCODE_LIST0 => Ok(ListHeader { count: 0, size: 0 }),
536            codec::FORMATCODE_LIST8 => {
537                decode_compound8(input).map(|(size, count)| ListHeader { count, size })
538            }
539            codec::FORMATCODE_LIST32 => {
540                decode_compound32(input).map(|(size, count)| ListHeader { count, size })
541            }
542            _ => Err(AmqpParseError::InvalidFormatCode(fmt)),
543        }
544    }
545}
546
547impl DecodeFormatted for MapHeader {
548    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
549        match fmt {
550            codec::FORMATCODE_MAP8 => {
551                decode_compound8(input).map(|(size, count)| MapHeader { count, size })
552            }
553            codec::FORMATCODE_MAP32 => {
554                decode_compound32(input).map(|(size, count)| MapHeader { count, size })
555            }
556            _ => Err(AmqpParseError::InvalidFormatCode(fmt)),
557        }
558    }
559}
560
561impl DecodeFormatted for ArrayHeader {
562    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
563        match fmt {
564            codec::FORMATCODE_ARRAY8 => {
565                decode_compound8(input).map(|(size, count)| ArrayHeader { count, size })
566            }
567            codec::FORMATCODE_ARRAY32 => {
568                decode_compound32(input).map(|(size, count)| ArrayHeader { count, size })
569            }
570            _ => Err(AmqpParseError::InvalidFormatCode(fmt)),
571        }
572    }
573}
574
575fn decode_frame_header(input: &mut Bytes, expected_frame_type: u8) -> Result<u16, AmqpParseError> {
576    decode_check_len!(input, 4);
577    let doff = input[0];
578    let frame_type = input[1];
579    if frame_type != expected_frame_type {
580        return Err(AmqpParseError::UnexpectedFrameType(frame_type));
581    }
582
583    let channel_id = BigEndian::read_u16(&input[2..]);
584    let doff = doff as usize * 4;
585    if doff < HEADER_LEN {
586        return Err(AmqpParseError::InvalidSize);
587    }
588    // skipping remaining two header bytes and ext header
589    let ext_header_len = doff - HEADER_LEN + 4;
590    decode_check_len!(input, ext_header_len);
591    input.advance(ext_header_len);
592    Ok(channel_id)
593}
594
595fn decode_compound8(input: &mut Bytes) -> Result<(u32, u32), AmqpParseError> {
596    decode_check_len!(input, 2);
597    let size = input[0] - 1; // -1 for 1 byte count
598    let count = input[1];
599    input.advance(2);
600    Ok((u32::from(size), u32::from(count)))
601}
602
603fn decode_compound32(input: &mut Bytes) -> Result<(u32, u32), AmqpParseError> {
604    decode_check_len!(input, 8);
605    let size = BigEndian::read_u32(input) - 4; // -4 for 4 byte count
606    let count = BigEndian::read_u32(&input[4..]);
607    input.advance(8);
608    Ok((size, count))
609}
610
611fn datetime_from_millis(millis: i64) -> Result<DateTime<Utc>, AmqpParseError> {
612    let seconds = millis / 1000;
613    if seconds < 0 {
614        // In order to handle time before 1970 correctly, we need to subtract a second
615        // and use the nanoseconds field to add it back. This is a result of the nanoseconds
616        // parameter being u32
617        let nanoseconds = ((1000 + (millis - (seconds * 1000))) * 1_000_000).unsigned_abs();
618        Utc.timestamp_opt(seconds - 1, nanoseconds as u32)
619            .earliest()
620            .ok_or(AmqpParseError::DatetimeParseError)
621    } else {
622        let nanoseconds = ((millis - (seconds * 1000)) * 1_000_000).unsigned_abs();
623        Utc.timestamp_opt(seconds, nanoseconds as u32)
624            .earliest()
625            .ok_or(AmqpParseError::DatetimeParseError)
626    }
627}
628
629fn read_fixed_bytes<const N: usize>(input: &mut Bytes) -> Result<[u8; N], AmqpParseError> {
630    decode_check_len!(input, N);
631    let mut data = [0u8; N];
632    data.copy_from_slice(&input[..N]);
633    input.advance(N);
634    Ok(data)
635}
636
637#[cfg(test)]
638mod tests {
639    use chrono::TimeDelta;
640    use ntex_bytes::{BufMut, BytesMut};
641    use test_case::test_case;
642
643    use super::*;
644    use crate::codec::{Decode, Encode};
645
646    const LOREM: &str = include_str!("lorem.txt");
647
648    macro_rules! decode_tests {
649        ($($name:ident: $kind:ident, $test:expr, $expected:expr,)*) => {
650        $(
651            #[test]
652            fn $name() {
653                let mut b1 = BytesMut::with_capacity(($test).encoded_size());
654                ($test).encode(&mut b1);
655                assert_eq!($expected, <$kind as Decode>::decode(&mut b1.freeze()).unwrap());
656            }
657        )*
658        }
659    }
660
661    decode_tests! {
662        ubyte: u8, 255_u8, 255_u8,
663        ushort: u16, 350_u16, 350_u16,
664
665        uint_zero: u32, 0_u32, 0_u32,
666        uint_small: u32, 128_u32, 128_u32,
667        uint_big: u32, 2147483647_u32, 2147483647_u32,
668
669        ulong_zero: u64, 0_u64, 0_u64,
670        ulong_small: u64, 128_u64, 128_u64,
671        uulong_big: u64, 2147483649_u64, 2147483649_u64,
672
673        byte: i8, -128_i8, -128_i8,
674        short: i16, -255_i16, -255_i16,
675
676        int_zero: i32, 0_i32, 0_i32,
677        int_small: i32, -50000_i32, -50000_i32,
678        int_neg: i32, -128_i32, -128_i32,
679
680        long_zero: i64, 0_i64, 0_i64,
681        long_big: i64, -2147483647_i64, -2147483647_i64,
682        long_small: i64, -128_i64, -128_i64,
683
684        float: f32, 1.234_f32, 1.234_f32,
685        double: f64, 1.234_f64, 1.234_f64,
686
687        test_char: char, '💯', '💯',
688
689        uuid: Uuid, Uuid::from_slice(&[4, 54, 67, 12, 43, 2, 98, 76, 32, 50, 87, 5, 1, 33, 43, 87]).expect("parse error"),
690        Uuid::parse_str("0436430c2b02624c2032570501212b57").expect("parse error"),
691
692        binary_short: Bytes, Bytes::from(&[4u8, 5u8][..]), Bytes::from(&[4u8, 5u8][..]),
693        binary_long: Bytes, Bytes::from(&[4u8; 500][..]), Bytes::from(&[4u8; 500][..]),
694
695        string_short: ByteString, ByteString::from("Hello there"), ByteString::from("Hello there"),
696        string_long: ByteString, ByteString::from(LOREM), ByteString::from(LOREM),
697
698        // symbol_short: Symbol, Symbol::from("Hello there"), Symbol::from("Hello there"),
699        // symbol_long: Symbol, Symbol::from(LOREM), Symbol::from(LOREM),
700
701        variant_ubyte: Variant, Variant::Ubyte(255_u8), Variant::Ubyte(255_u8),
702        variant_ushort: Variant, Variant::Ushort(350_u16), Variant::Ushort(350_u16),
703
704        variant_uint_zero: Variant, Variant::Uint(0_u32), Variant::Uint(0_u32),
705        variant_uint_small: Variant, Variant::Uint(128_u32), Variant::Uint(128_u32),
706        variant_uint_big: Variant, Variant::Uint(2147483647_u32), Variant::Uint(2147483647_u32),
707
708        variant_ulong_zero: Variant, Variant::Ulong(0_u64), Variant::Ulong(0_u64),
709        variant_ulong_small: Variant, Variant::Ulong(128_u64), Variant::Ulong(128_u64),
710        variant_ulong_big: Variant, Variant::Ulong(2147483649_u64), Variant::Ulong(2147483649_u64),
711
712        variant_byte: Variant, Variant::Byte(-128_i8), Variant::Byte(-128_i8),
713        variant_short: Variant, Variant::Short(-255_i16), Variant::Short(-255_i16),
714
715        variant_int_zero: Variant, Variant::Int(0_i32), Variant::Int(0_i32),
716        variant_int_small: Variant, Variant::Int(-50000_i32), Variant::Int(-50000_i32),
717        variant_int_neg: Variant, Variant::Int(-128_i32), Variant::Int(-128_i32),
718
719        variant_long_zero: Variant, Variant::Long(0_i64), Variant::Long(0_i64),
720        variant_long_big: Variant, Variant::Long(-2147483647_i64), Variant::Long(-2147483647_i64),
721        variant_long_small: Variant, Variant::Long(-128_i64), Variant::Long(-128_i64),
722
723        variant_float: Variant, Variant::Float(OrderedFloat(1.234_f32)), Variant::Float(OrderedFloat(1.234_f32)),
724        variant_double: Variant, Variant::Double(OrderedFloat(1.234_f64)), Variant::Double(OrderedFloat(1.234_f64)),
725
726        variant_char: Variant, Variant::Char('💯'), Variant::Char('💯'),
727
728        variant_uuid: Variant, Variant::Uuid(Uuid::from_slice(&[4, 54, 67, 12, 43, 2, 98, 76, 32, 50, 87, 5, 1, 33, 43, 87]).expect("parse error")),
729        Variant::Uuid(Uuid::parse_str("0436430c2b02624c2032570501212b57").expect("parse error")),
730
731        variant_binary_short: Variant, Variant::Binary(Bytes::from(&[4u8, 5u8][..])), Variant::Binary(Bytes::from(&[4u8, 5u8][..])),
732        variant_binary_long: Variant, Variant::Binary(Bytes::from(&[4u8; 500][..])), Variant::Binary(Bytes::from(&[4u8; 500][..])),
733
734        variant_string_short: Variant, Variant::String(ByteString::from("Hello there").into()), Variant::String(ByteString::from("Hello there").into()),
735        variant_string_long: Variant, Variant::String(ByteString::from(LOREM).into()), Variant::String(ByteString::from(LOREM).into()),
736
737        // variant_symbol_short: Variant, Variant::Symbol(Symbol::from("Hello there")), Variant::Symbol(Symbol::from("Hello there")),
738        // variant_symbol_long: Variant, Variant::Symbol(Symbol::from(LOREM)), Variant::Symbol(Symbol::from(LOREM)),
739    }
740
741    fn unwrap_value<T>(res: Result<T, AmqpParseError>) -> T {
742        assert!(res.is_ok());
743        res.unwrap()
744    }
745
/// `bool::decode` yields `true` both for the single byte `0x41` and for the
/// two-byte sequence `0x56 0x01`.
#[test]
fn test_bool_true() {
    let mut single_byte = Bytes::from(&[0x41u8][..]);
    let decoded = unwrap_value(bool::decode(&mut single_byte));
    assert!(decoded);

    let mut with_payload = Bytes::from(&[0x56u8, 0x01][..]);
    let decoded = unwrap_value(bool::decode(&mut with_payload));
    assert!(decoded);
}
757
/// `bool::decode` yields `false` both for the single byte `0x42` and for the
/// two-byte sequence `0x56 0x00`.
#[test]
fn test_bool_false() {
    let mut single_byte = Bytes::from(&[0x42u8][..]);
    let decoded = unwrap_value(bool::decode(&mut single_byte));
    assert!(!decoded);

    let mut with_payload = Bytes::from(&[0x56u8, 0x00][..]);
    let decoded = unwrap_value(bool::decode(&mut with_payload));
    assert!(!decoded);
}
769
/// Encodes a `DateTime<Utc>` and checks that decoding returns the same instant.
///
/// Timestamps are UTC with a precision of milliseconds. For example,
/// 1311704463521 represents the moment 2011-07-26T18:21:03.521Z, which is the
/// instant used here.
#[test]
fn test_timestamp() {
    let whole_seconds = Utc.with_ymd_and_hms(2011, 7, 26, 18, 21, 3).unwrap();
    let moment = whole_seconds + TimeDelta::milliseconds(521);

    let mut encoded = BytesMut::with_capacity(0);
    moment.encode(&mut encoded);

    let decoded = unwrap_value(DateTime::<Utc>::decode(&mut encoded.freeze()));
    assert_eq!(moment, decoded);
}
786
/// Same round trip as `test_timestamp`, but for an instant in 1968, i.e.
/// before the Unix epoch.
#[test]
fn test_timestamp_pre_unix() {
    let whole_seconds = Utc.with_ymd_and_hms(1968, 7, 26, 18, 21, 3).unwrap();
    let moment = whole_seconds + TimeDelta::milliseconds(521);

    let mut encoded = BytesMut::with_capacity(0);
    moment.encode(&mut encoded);

    let decoded = unwrap_value(DateTime::<Utc>::decode(&mut encoded.freeze()));
    assert_eq!(moment, decoded);
}
801
/// An encoded `Variant::Null` decodes back to `Variant::Null`.
#[test]
fn variant_null() {
    let mut encoded = BytesMut::with_capacity(0);
    Variant::Null.encode(&mut encoded);

    let mut input = encoded.freeze();
    let decoded = unwrap_value(Variant::decode(&mut input));
    assert_eq!(Variant::Null, decoded);
}
809
/// `Variant::decode` yields `Variant::Boolean(true)` for the single byte
/// `0x41` as well as for the two-byte sequence `0x56 0x01`.
#[test]
fn variant_bool_true() {
    let encodings: [&[u8]; 2] = [&[0x41], &[0x56, 0x01]];

    for encoding in encodings {
        let mut input = Bytes::from(encoding);
        assert_eq!(
            Variant::Boolean(true),
            unwrap_value(Variant::decode(&mut input))
        );
    }
}
827
/// `Variant::decode` yields `Variant::Boolean(false)` for the single byte
/// `0x42` as well as for the two-byte sequence `0x56 0x00`.
#[test]
fn variant_bool_false() {
    let encodings: [&[u8]; 2] = [&[0x42], &[0x56, 0x00]];

    for encoding in encodings {
        let mut input = Bytes::from(encoding);
        assert_eq!(
            Variant::Boolean(false),
            unwrap_value(Variant::decode(&mut input))
        );
    }
}
845
/// Encodes a `Variant::Timestamp` and checks that `Variant::decode` returns
/// the same value.
///
/// Timestamps are UTC with a precision of milliseconds. For example,
/// 1311704463521 represents the moment 2011-07-26T18:21:03.521Z, which is the
/// instant used here.
#[test]
fn variant_timestamp() {
    let whole_seconds = Utc.with_ymd_and_hms(2011, 7, 26, 18, 21, 3).unwrap();
    let moment = whole_seconds + TimeDelta::milliseconds(521);

    let mut encoded = BytesMut::with_capacity(0);
    Variant::Timestamp(moment).encode(&mut encoded);

    let decoded = unwrap_value(Variant::decode(&mut encoded.freeze()));
    assert_eq!(Variant::Timestamp(moment), decoded);
}
862
/// Same round trip as `variant_timestamp`, but for an instant in 1968, i.e.
/// before the Unix epoch.
#[test]
fn variant_timestamp_pre_unix() {
    let whole_seconds = Utc.with_ymd_and_hms(1968, 7, 26, 18, 21, 3).unwrap();
    let moment = whole_seconds + TimeDelta::milliseconds(521);

    let mut encoded = BytesMut::with_capacity(0);
    Variant::Timestamp(moment).encode(&mut encoded);

    let decoded = unwrap_value(Variant::decode(&mut encoded.freeze()));
    assert_eq!(Variant::Timestamp(moment), decoded);
}
877
878    #[test_case(
879        b"\x00\xa3\x07foo:bar\xc0\x03\x01\x50\x03",
880        Descriptor::Symbol("foo:bar".into()),
881        List(vec![Variant::Ubyte(3)]);
882        "described 'foo:bar', list8 w/one u8 field with value 3")]
883    #[test_case(
884        b"\x00\x80\x00\x00\x01\x37\x00\x00\x03\xe9\x45",
885        Descriptor::Ulong((311 << 32) + 1001), List(vec![]); "described 311:1001, list0")]
886    #[test_case(
887        b"\x00\x80\x00\x01\xd4\xc0\x00\x03\x82\x70\xd0\x00\x00\x00\x0c\x00\x00\x00\x03\x53\x6f\xa1\x03abc\x42",
888        Descriptor::Ulong((120_000 << 32) + 230_000),
889        List(vec![Variant::Ulong(111), Variant::String("abc".into()), Variant::Boolean(false)]);
890        "described 120000:230000, list32 w/3 fields: smallulong: 111, string8: 'abc', booleanfalse")]
891    fn decode_described_list(
892        input: &'static [u8],
893        expected_descriptor: Descriptor,
894        expected_list: List,
895    ) {
896        let mut buf = Bytes::from(input);
897        let variant = Variant::decode(&mut buf).unwrap();
898        assert!(buf.is_empty(), "Expected no remaining bytes after decoding");
899        let dc = match variant {
900            Variant::DescribedCompound(dc) => dc,
901            _ => panic!("Expected a DescribedCompound variant"),
902        };
903        assert_eq!(dc.descriptor(), &expected_descriptor);
904        println!("{:02x?}", dc.data.as_ref());
905        let decoded_list: List = dc.decode().expect("Failed to decode List");
906        assert_eq!(decoded_list, expected_list);
907    }
908
909    #[test_case(
910        b"\x00\xa3\x05a:b:c\xc1\x08\x04\x50\x03\x41\x50\xc8\x56\x00",
911        Descriptor::Symbol("a:b:c".into()),
912        vec![(Variant::Ubyte(3), Variant::Boolean(true)), (Variant::Ubyte(200), Variant::Boolean(false))];
913        "described 'a:b:c', map8 with 2 pairs: (ubyte(3), true), (ubyte(200), false)")]
914    #[test_case(
915        b"\x00\x80\x00\x01\xd4\xc0\x00\x03\x82\x70\xd1\x00\x00\x00\x0a\x00\x00\x00\x02\x73\x00\x00\x00z\x40",
916        Descriptor::Ulong((120_000 << 32) + 230_000),
917        vec![(Variant::Char('z'), Variant::Null)];
918        "described 120000:230000, map32 with 1 pair: char: 'z', null")]
919    fn decode_described_map(
920        input: &'static [u8],
921        expected_descriptor: Descriptor,
922        expected_map: Vec<(Variant, Variant)>,
923    ) {
924        let mut buf = Bytes::from(input);
925        let variant = Variant::decode(&mut buf).unwrap();
926        assert!(buf.is_empty(), "Expected no remaining bytes after decoding");
927        let dc = match variant {
928            Variant::DescribedCompound(dc) => dc,
929            _ => panic!("Expected a DescribedCompound variant"),
930        };
931        assert_eq!(dc.descriptor(), &expected_descriptor);
932        println!("{:02x?}", dc.data.as_ref());
933        let decoded_map: HashMap<Variant, Variant> = dc.decode().expect("Failed to decode List");
934        let expected_map: HashMap<Variant, Variant> = expected_map.into_iter().collect();
935        assert_eq!(decoded_map, expected_map);
936    }
937
938    #[test_case(
939        b"\x00\xa3\x07foo:bar\xe0\x05\x03\x50\x01\x02\x03",
940        Descriptor::Symbol("foo:bar".into()),
941        Constructor::FormatCode(codec::FORMATCODE_UBYTE),
942        vec![Variant::Ubyte(1), Variant::Ubyte(2), Variant::Ubyte(3)];
943        "described 'foo:bar', array8 w/3 u8 elements: 1, 2, 3")]
944    fn decode_described_array(
945        input: &'static [u8],
946        expected_descriptor: Descriptor,
947        expected_el_ctor: Constructor,
948        expected_array: Vec<Variant>,
949    ) {
950        // todo: mg: array decoding: add check that all bytes are read out according to size when done decoding array elements /
951        // list fields / map key-value pairs according to count
952        let mut buf = Bytes::from(input);
953        let variant = Variant::decode(&mut buf).unwrap();
954        assert!(buf.is_empty(), "Expected no remaining bytes after decoding");
955        let dc = match variant {
956            Variant::DescribedCompound(dc) => dc,
957            _ => panic!("Expected a DescribedCompound variant"),
958        };
959        assert_eq!(dc.descriptor(), &expected_descriptor);
960        println!("{:02x?}", dc.data.as_ref());
961        let decoded_array: Array = dc.decode().expect("Failed to decode Array");
962        assert_eq!(decoded_array.element_constructor(), &expected_el_ctor);
963        let array_items: Vec<Variant> = decoded_array
964            .decode()
965            .expect("Failed to decode Array items using Variant type");
966        assert_eq!(array_items, expected_array);
967    }
968
/// Round-trips `Option<i8>` through encode/decode for both `Some` and `None`.
#[test]
fn option_i8() {
    // `Some` case.
    let mut encoded_some = BytesMut::with_capacity(0);
    Some(42i8).encode(&mut encoded_some);
    let decoded_some = unwrap_value(Option::<i8>::decode(&mut encoded_some.freeze()));
    assert_eq!(Some(42), decoded_some);

    // `None` case.
    let mut encoded_none = BytesMut::with_capacity(0);
    None::<i8>.encode(&mut encoded_none);
    let decoded_none = unwrap_value(Option::<i8>::decode(&mut encoded_none.freeze()));
    assert_eq!(None, decoded_none);
}
985
/// Round-trips `Option<ByteString>` through encode/decode for both `Some` and
/// `None`.
#[test]
fn option_string() {
    // `Some` case.
    let mut encoded_some = BytesMut::with_capacity(0);
    Some(ByteString::from("hello")).encode(&mut encoded_some);
    let decoded_some = unwrap_value(Option::<ByteString>::decode(&mut encoded_some.freeze()));
    assert_eq!(Some(ByteString::from("hello")), decoded_some);

    // `None` case.
    let mut encoded_none = BytesMut::with_capacity(0);
    None::<ByteString>.encode(&mut encoded_none);
    let decoded_none = unwrap_value(Option::<ByteString>::decode(&mut encoded_none.freeze()));
    assert_eq!(None, decoded_none);
}
1005}