ntex_amqp_codec/codec/decode.rs

1use std::{char, collections, convert::TryFrom, hash::BuildHasher, hash::Hash};
2
3use byteorder::{BigEndian, ByteOrder};
4use chrono::{DateTime, TimeZone, Utc};
5use ntex_bytes::{Buf, ByteString, Bytes};
6use ordered_float::OrderedFloat;
7use uuid::Uuid;
8
9use crate::codec::{self, ArrayHeader, Decode, DecodeFormatted, ListHeader, MapHeader};
10use crate::error::AmqpParseError;
11use crate::framing::{self, AmqpFrame, SaslFrame, HEADER_LEN};
12use crate::protocol;
13use crate::types::{
14    Array, Constructor, DescribedCompound, Descriptor, List, Multiple, Str, Symbol, Variant,
15    VariantMap, VecStringMap, VecSymbolMap,
16};
17use crate::HashMap;
18
/// Reads a fixed-width big-endian value from the front of `$buf` and advances past it.
///
/// `$reader` names a `BigEndian` reader function (e.g. `read_u32`) and `$len` is the number
/// of bytes it consumes. The expansion is a `Result` expression ending in `Ok(value)`; the
/// length guard is delegated to `decode_check_len!` (defined elsewhere). `$len` is expanded
/// more than once, so pass a literal.
macro_rules! be_read {
    ($buf:ident, $reader:ident, $len:expr) => {{
        decode_check_len!($buf, $len);
        let value = BigEndian::$reader(&$buf);
        $buf.advance($len);
        Ok(value)
    }};
}
27
28fn read_u8(input: &mut Bytes) -> Result<u8, AmqpParseError> {
29    decode_check_len!(input, 1);
30    let code = input[0];
31    input.advance(1);
32    Ok(code)
33}
34
35fn read_i8(input: &mut Bytes) -> Result<i8, AmqpParseError> {
36    decode_check_len!(input, 1);
37    let code = input[0] as i8;
38    input.advance(1);
39    Ok(code)
40}
41
42fn read_bytes_u8(input: &mut Bytes) -> Result<Bytes, AmqpParseError> {
43    let len = read_u8(input)?;
44    let len = len as usize;
45    decode_check_len!(input, len);
46    Ok(input.split_to(len))
47}
48
49fn read_bytes_u32(input: &mut Bytes) -> Result<Bytes, AmqpParseError> {
50    let result: Result<u32, AmqpParseError> = be_read!(input, read_u32, 4);
51    let len = result?;
52    let len = len as usize;
53    decode_check_len!(input, len);
54    Ok(input.split_to(len))
55}
56
/// Returns `Err(AmqpParseError::InvalidFormatCode(actual))` from the enclosing function
/// unless `actual` equals `expected`.
///
/// `AmqpParseError` is resolved at the call site, so it must be in scope there.
#[macro_export]
macro_rules! validate_code {
    ($actual:ident, $expected:expr) => {
        if $actual != $expected {
            return Err(AmqpParseError::InvalidFormatCode($actual));
        }
    };
}
65
66impl DecodeFormatted for bool {
67    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
68        match fmt {
69            codec::FORMATCODE_BOOLEAN => read_u8(input).map(|o| o != 0),
70            codec::FORMATCODE_BOOLEAN_TRUE => Ok(true),
71            codec::FORMATCODE_BOOLEAN_FALSE => Ok(false),
72            _ => Err(AmqpParseError::InvalidFormatCode(fmt)),
73        }
74    }
75}
76
77impl DecodeFormatted for u8 {
78    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
79        validate_code!(fmt, codec::FORMATCODE_UBYTE);
80        read_u8(input)
81    }
82}
83
84impl DecodeFormatted for u16 {
85    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
86        validate_code!(fmt, codec::FORMATCODE_USHORT);
87        be_read!(input, read_u16, 2)
88    }
89}
90
91impl DecodeFormatted for u32 {
92    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
93        match fmt {
94            codec::FORMATCODE_UINT => be_read!(input, read_u32, 4),
95            codec::FORMATCODE_SMALLUINT => read_u8(input).map(u32::from),
96            codec::FORMATCODE_UINT_0 => Ok(0),
97            _ => Err(AmqpParseError::InvalidFormatCode(fmt)),
98        }
99    }
100}
101
102impl DecodeFormatted for u64 {
103    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
104        match fmt {
105            codec::FORMATCODE_ULONG => be_read!(input, read_u64, 8),
106            codec::FORMATCODE_SMALLULONG => read_u8(input).map(u64::from),
107            codec::FORMATCODE_ULONG_0 => Ok(0),
108            _ => Err(AmqpParseError::InvalidFormatCode(fmt)),
109        }
110    }
111}
112
113impl DecodeFormatted for i8 {
114    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
115        validate_code!(fmt, codec::FORMATCODE_BYTE);
116        read_i8(input)
117    }
118}
119
120impl DecodeFormatted for i16 {
121    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
122        validate_code!(fmt, codec::FORMATCODE_SHORT);
123        be_read!(input, read_i16, 2)
124    }
125}
126
127impl DecodeFormatted for i32 {
128    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
129        match fmt {
130            codec::FORMATCODE_INT => be_read!(input, read_i32, 4),
131            codec::FORMATCODE_SMALLINT => read_i8(input).map(i32::from),
132            _ => Err(AmqpParseError::InvalidFormatCode(fmt)),
133        }
134    }
135}
136
137impl DecodeFormatted for i64 {
138    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
139        match fmt {
140            codec::FORMATCODE_LONG => be_read!(input, read_i64, 8),
141            codec::FORMATCODE_SMALLLONG => read_i8(input).map(i64::from),
142            _ => Err(AmqpParseError::InvalidFormatCode(fmt)),
143        }
144    }
145}
146
147impl DecodeFormatted for f32 {
148    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
149        validate_code!(fmt, codec::FORMATCODE_FLOAT);
150        be_read!(input, read_f32, 4)
151    }
152}
153
154impl DecodeFormatted for f64 {
155    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
156        validate_code!(fmt, codec::FORMATCODE_DOUBLE);
157        be_read!(input, read_f64, 8)
158    }
159}
160
161impl DecodeFormatted for char {
162    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
163        validate_code!(fmt, codec::FORMATCODE_CHAR);
164        let result: Result<u32, AmqpParseError> = be_read!(input, read_u32, 4);
165        let o = result?;
166        if let Some(c) = char::from_u32(o) {
167            Ok(c)
168        } else {
169            Err(AmqpParseError::InvalidChar(o))
170        } // todo: replace with CharTryFromError once try_from is stabilized
171    }
172}
173
174impl DecodeFormatted for DateTime<Utc> {
175    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
176        validate_code!(fmt, codec::FORMATCODE_TIMESTAMP);
177        be_read!(input, read_i64, 8).and_then(datetime_from_millis)
178    }
179}
180
181impl DecodeFormatted for Uuid {
182    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
183        validate_code!(fmt, codec::FORMATCODE_UUID);
184        decode_check_len!(input, 16);
185        let uuid =
186            Uuid::from_slice(&input.split_to(16)).map_err(|_| AmqpParseError::UuidParseError)?;
187        Ok(uuid)
188    }
189}
190
191impl DecodeFormatted for Bytes {
192    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
193        match fmt {
194            codec::FORMATCODE_BINARY8 => read_bytes_u8(input),
195            codec::FORMATCODE_BINARY32 => read_bytes_u32(input),
196            _ => Err(AmqpParseError::InvalidFormatCode(fmt)),
197        }
198    }
199}
200
201impl DecodeFormatted for ByteString {
202    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
203        match fmt {
204            codec::FORMATCODE_STRING8 => {
205                let bytes = read_bytes_u8(input)?;
206                Ok(ByteString::try_from(bytes).map_err(|_| AmqpParseError::Utf8Error)?)
207            }
208            codec::FORMATCODE_STRING32 => {
209                let bytes = read_bytes_u32(input)?;
210                Ok(ByteString::try_from(bytes).map_err(|_| AmqpParseError::Utf8Error)?)
211            }
212            _ => Err(AmqpParseError::InvalidFormatCode(fmt)),
213        }
214    }
215}
216
217impl DecodeFormatted for Str {
218    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
219        Ok(Str::from(ByteString::decode_with_format(input, fmt)?))
220    }
221}
222
223impl DecodeFormatted for Symbol {
224    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
225        match fmt {
226            codec::FORMATCODE_SYMBOL8 => {
227                let bytes = read_bytes_u8(input)?;
228                Ok(Symbol(Str::from(
229                    ByteString::try_from(bytes).map_err(|_| AmqpParseError::Utf8Error)?,
230                )))
231            }
232            codec::FORMATCODE_SYMBOL32 => {
233                let bytes = read_bytes_u32(input)?;
234                Ok(Symbol(Str::from(
235                    ByteString::try_from(bytes).map_err(|_| AmqpParseError::Utf8Error)?,
236                )))
237            }
238            _ => Err(AmqpParseError::InvalidFormatCode(fmt)),
239        }
240    }
241}
242
243impl<K: Decode + Eq + Hash, V: Decode, S: BuildHasher + Default> DecodeFormatted
244    for collections::HashMap<K, V, S>
245{
246    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
247        let header = MapHeader::decode_with_format(input, fmt)?;
248        decode_check_len!(input, header.size as usize);
249        let mut map_input = input.split_to(header.size as usize);
250        let count = header.count / 2;
251        let mut map: collections::HashMap<K, V, S> =
252            collections::HashMap::with_capacity_and_hasher(count as usize, Default::default());
253        for _ in 0..count {
254            let key = K::decode(&mut map_input)?;
255            let value = V::decode(&mut map_input)?;
256            map.insert(key, value); // todo: ensure None returned?
257        }
258        // todo: validate map_input is empty
259        Ok(map)
260    }
261}
262
263impl<T: DecodeFormatted> DecodeFormatted for Vec<T> {
264    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
265        let header = ArrayHeader::decode_with_format(input, fmt)?;
266        decode_check_len!(input, header.size as usize);
267        let elem_ctor = Constructor::decode(input)?;
268        let elem_fmt = match elem_ctor {
269            Constructor::FormatCode(code) => code,
270            Constructor::Described { descriptor, .. } => {
271                // todo: mg: described types are not supported OOTB at this point
272                return Err(AmqpParseError::InvalidDescriptor(Box::new(descriptor)));
273            }
274        };
275        let mut result: Vec<T> = Vec::with_capacity(header.count as usize);
276        for _ in 0..header.count {
277            let decoded = T::decode_with_format(input, elem_fmt)?;
278            result.push(decoded);
279        }
280        // todo: ensure header.size bytes were read out from input
281        Ok(result)
282    }
283}
284
285impl DecodeFormatted for VecSymbolMap {
286    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
287        let header = MapHeader::decode_with_format(input, fmt)?;
288        decode_check_len!(input, header.size as usize);
289        let mut map_input = input.split_to(header.size as usize);
290        let count = header.count / 2;
291        let mut map = Vec::with_capacity(count as usize);
292        for _ in 0..count {
293            let key = Symbol::decode(&mut map_input)?;
294            let value = Variant::decode(&mut map_input)?;
295            map.push((key, value)); // todo: mg: ensure None is returned
296        }
297        // todo: ensure header.size bytes were read out from input after header
298        Ok(VecSymbolMap(map))
299    }
300}
301
302impl DecodeFormatted for VecStringMap {
303    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
304        let header = MapHeader::decode_with_format(input, fmt)?;
305        decode_check_len!(input, header.size as usize);
306        let mut map_input = input.split_to(header.size as usize);
307        let count = header.count / 2;
308        let mut map = Vec::with_capacity(count as usize);
309        for _ in 0..count {
310            let key = Str::decode(&mut map_input)?;
311            let value = Variant::decode(&mut map_input)?;
312            map.push((key, value)); // todo: ensure None returned?
313        }
314        // todo: validate map_input is empty
315        Ok(VecStringMap(map))
316    }
317}
318
319impl<T: DecodeFormatted> DecodeFormatted for Multiple<T> {
320    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
321        match fmt {
322            codec::FORMATCODE_ARRAY8 | codec::FORMATCODE_ARRAY32 => {
323                let items = Vec::<T>::decode_with_format(input, fmt)?;
324                Ok(Multiple(items))
325            }
326            codec::FORMATCODE_DESCRIBED => {
327                let descriptor = Descriptor::decode_with_format(input, fmt)?;
328                // todo: mg: described types are not supported OOTB at this point
329                Err(AmqpParseError::InvalidDescriptor(Box::new(descriptor)))
330            }
331            _ => {
332                let item = T::decode_with_format(input, fmt)?;
333                Ok(Multiple(vec![item]))
334            }
335        }
336    }
337}
338
339impl DecodeFormatted for List {
340    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
341        let header = ListHeader::decode_with_format(input, fmt)?;
342        let mut result: Vec<Variant> = Vec::with_capacity(header.count as usize);
343        for _ in 0..header.count {
344            let decoded = Variant::decode(input)?;
345            result.push(decoded);
346        }
347        Ok(List(result))
348    }
349}
350
351impl DecodeFormatted for Variant {
352    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
353        match fmt {
354            codec::FORMATCODE_NULL => Ok(Variant::Null),
355            codec::FORMATCODE_BOOLEAN => bool::decode_with_format(input, fmt).map(Variant::Boolean),
356            codec::FORMATCODE_BOOLEAN_FALSE => Ok(Variant::Boolean(false)),
357            codec::FORMATCODE_BOOLEAN_TRUE => Ok(Variant::Boolean(true)),
358            codec::FORMATCODE_UINT_0 => Ok(Variant::Uint(0)),
359            codec::FORMATCODE_ULONG_0 => Ok(Variant::Ulong(0)),
360            codec::FORMATCODE_UBYTE => u8::decode_with_format(input, fmt).map(Variant::Ubyte),
361            codec::FORMATCODE_USHORT => u16::decode_with_format(input, fmt).map(Variant::Ushort),
362            codec::FORMATCODE_UINT => u32::decode_with_format(input, fmt).map(Variant::Uint),
363            codec::FORMATCODE_ULONG => u64::decode_with_format(input, fmt).map(Variant::Ulong),
364            codec::FORMATCODE_BYTE => i8::decode_with_format(input, fmt).map(Variant::Byte),
365            codec::FORMATCODE_SHORT => i16::decode_with_format(input, fmt).map(Variant::Short),
366            codec::FORMATCODE_INT => i32::decode_with_format(input, fmt).map(Variant::Int),
367            codec::FORMATCODE_LONG => i64::decode_with_format(input, fmt).map(Variant::Long),
368            codec::FORMATCODE_SMALLUINT => u32::decode_with_format(input, fmt).map(Variant::Uint),
369            codec::FORMATCODE_SMALLULONG => u64::decode_with_format(input, fmt).map(Variant::Ulong),
370            codec::FORMATCODE_SMALLINT => i32::decode_with_format(input, fmt).map(Variant::Int),
371            codec::FORMATCODE_SMALLLONG => i64::decode_with_format(input, fmt).map(Variant::Long),
372            codec::FORMATCODE_FLOAT => {
373                f32::decode_with_format(input, fmt).map(|o| Variant::Float(OrderedFloat(o)))
374            }
375            codec::FORMATCODE_DOUBLE => {
376                f64::decode_with_format(input, fmt).map(|o| Variant::Double(OrderedFloat(o)))
377            }
378            codec::FORMATCODE_DECIMAL32 => read_fixed_bytes(input).map(Variant::Decimal32),
379            codec::FORMATCODE_DECIMAL64 => read_fixed_bytes(input).map(Variant::Decimal64),
380            codec::FORMATCODE_DECIMAL128 => read_fixed_bytes(input).map(Variant::Decimal128),
381            codec::FORMATCODE_CHAR => char::decode_with_format(input, fmt).map(Variant::Char),
382            codec::FORMATCODE_TIMESTAMP => {
383                DateTime::<Utc>::decode_with_format(input, fmt).map(Variant::Timestamp)
384            }
385            codec::FORMATCODE_UUID => Uuid::decode_with_format(input, fmt).map(Variant::Uuid),
386            codec::FORMATCODE_BINARY8 | codec::FORMATCODE_BINARY32 => {
387                Bytes::decode_with_format(input, fmt).map(Variant::Binary)
388            }
389            codec::FORMATCODE_STRING8 | codec::FORMATCODE_STRING32 => {
390                ByteString::decode_with_format(input, fmt).map(|o| Variant::String(o.into()))
391            }
392            codec::FORMATCODE_SYMBOL8 | codec::FORMATCODE_SYMBOL32 => {
393                Symbol::decode_with_format(input, fmt).map(Variant::Symbol)
394            }
395            codec::FORMATCODE_LIST0 => Ok(Variant::List(List(vec![]))),
396            codec::FORMATCODE_LIST8 | codec::FORMATCODE_LIST32 => {
397                List::decode_with_format(input, fmt).map(Variant::List)
398            }
399            codec::FORMATCODE_ARRAY8 | codec::FORMATCODE_ARRAY32 => {
400                Array::decode_with_format(input, fmt).map(Variant::Array)
401            }
402            codec::FORMATCODE_MAP8 | codec::FORMATCODE_MAP32 => {
403                HashMap::<Variant, Variant>::decode_with_format(input, fmt)
404                    .map(|o| Variant::Map(VariantMap::new(o)))
405            }
406            codec::FORMATCODE_DESCRIBED => {
407                let descriptor = Descriptor::decode(input)?;
408                let format_code = {
409                    decode_check_len!(input, 1);
410                    let code = input[0];
411                    Ok(code)
412                }?;
413                match format_code {
414                    codec::FORMATCODE_LIST0 => {
415                        input.advance(1); // advance past format code
416                        Ok(Variant::DescribedCompound(DescribedCompound::new(
417                            descriptor,
418                            Bytes::from_static(&[codec::FORMATCODE_LIST0]),
419                        )))
420                    }
421                    codec::FORMATCODE_LIST8 | codec::FORMATCODE_MAP8 | codec::FORMATCODE_ARRAY8 => {
422                        decode_check_len!(input, 2);
423                        let size = input[1] as usize;
424                        decode_check_len!(input, 2 + size);
425                        let data = input.split_to(2 + size);
426                        Ok(Variant::DescribedCompound(DescribedCompound::new(
427                            descriptor, data,
428                        )))
429                    }
430                    codec::FORMATCODE_LIST32
431                    | codec::FORMATCODE_MAP32
432                    | codec::FORMATCODE_ARRAY32 => {
433                        decode_check_len!(input, 5);
434                        let size = u32::from_be_bytes(input[1..5].try_into().unwrap()) as usize;
435                        decode_check_len!(input, 5 + size);
436                        let data = input.split_to(5 + size);
437                        Ok(Variant::DescribedCompound(DescribedCompound::new(
438                            descriptor, data,
439                        )))
440                    }
441                    _ => {
442                        input.advance(1); // advance past format code
443                        let value = Variant::decode_with_format(input, format_code)?;
444                        Ok(Variant::Described((descriptor, Box::new(value))))
445                    }
446                }
447            }
448            _ => Err(AmqpParseError::InvalidFormatCode(fmt)),
449        }
450    }
451}
452
453impl<T: DecodeFormatted> DecodeFormatted for Option<T> {
454    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
455        match fmt {
456            codec::FORMATCODE_NULL => Ok(None),
457            _ => T::decode_with_format(input, fmt).map(Some),
458        }
459    }
460}
461
462impl DecodeFormatted for Descriptor {
463    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
464        match fmt {
465            codec::FORMATCODE_SMALLULONG => {
466                u64::decode_with_format(input, fmt).map(Descriptor::Ulong)
467            }
468            codec::FORMATCODE_ULONG => u64::decode_with_format(input, fmt).map(Descriptor::Ulong),
469            codec::FORMATCODE_SYMBOL8 => {
470                Symbol::decode_with_format(input, fmt).map(Descriptor::Symbol)
471            }
472            codec::FORMATCODE_SYMBOL32 => {
473                Symbol::decode_with_format(input, fmt).map(Descriptor::Symbol)
474            }
475            _ => Err(AmqpParseError::InvalidFormatCode(fmt)),
476        }
477    }
478}
479
480impl DecodeFormatted for Constructor {
481    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
482        match fmt {
483            codec::FORMATCODE_DESCRIBED => {
484                let descriptor = Descriptor::decode(input)?;
485                let format_code = codec::decode_format_code(input)?;
486                Ok(Constructor::Described {
487                    descriptor,
488                    format_code,
489                })
490            }
491            _ => Ok(Constructor::FormatCode(fmt)),
492        }
493    }
494}
495
496impl Decode for AmqpFrame {
497    fn decode(input: &mut Bytes) -> Result<Self, AmqpParseError> {
498        let channel_id = decode_frame_header(input, framing::FRAME_TYPE_AMQP)?;
499        let performative = protocol::Frame::decode(input)?;
500        Ok(AmqpFrame::new(channel_id, performative))
501    }
502}
503
504impl Decode for SaslFrame {
505    fn decode(input: &mut Bytes) -> Result<Self, AmqpParseError> {
506        let _ = decode_frame_header(input, framing::FRAME_TYPE_SASL)?;
507        let frame = protocol::SaslFrameBody::decode(input)?;
508        Ok(SaslFrame { body: frame })
509    }
510}
511
512impl DecodeFormatted for ListHeader {
513    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
514        match fmt {
515            codec::FORMATCODE_LIST0 => Ok(ListHeader { count: 0, size: 0 }),
516            codec::FORMATCODE_LIST8 => {
517                decode_compound8(input).map(|(size, count)| ListHeader { count, size })
518            }
519            codec::FORMATCODE_LIST32 => {
520                decode_compound32(input).map(|(size, count)| ListHeader { count, size })
521            }
522            _ => Err(AmqpParseError::InvalidFormatCode(fmt)),
523        }
524    }
525}
526
527impl DecodeFormatted for MapHeader {
528    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
529        match fmt {
530            codec::FORMATCODE_MAP8 => {
531                decode_compound8(input).map(|(size, count)| MapHeader { count, size })
532            }
533            codec::FORMATCODE_MAP32 => {
534                decode_compound32(input).map(|(size, count)| MapHeader { count, size })
535            }
536            _ => Err(AmqpParseError::InvalidFormatCode(fmt)),
537        }
538    }
539}
540
541impl DecodeFormatted for ArrayHeader {
542    fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
543        match fmt {
544            codec::FORMATCODE_ARRAY8 => {
545                decode_compound8(input).map(|(size, count)| ArrayHeader { count, size })
546            }
547            codec::FORMATCODE_ARRAY32 => {
548                decode_compound32(input).map(|(size, count)| ArrayHeader { count, size })
549            }
550            _ => Err(AmqpParseError::InvalidFormatCode(fmt)),
551        }
552    }
553}
554
555fn decode_frame_header(input: &mut Bytes, expected_frame_type: u8) -> Result<u16, AmqpParseError> {
556    decode_check_len!(input, 4);
557    let doff = input[0];
558    let frame_type = input[1];
559    if frame_type != expected_frame_type {
560        return Err(AmqpParseError::UnexpectedFrameType(frame_type));
561    }
562
563    let channel_id = BigEndian::read_u16(&input[2..]);
564    let doff = doff as usize * 4;
565    if doff < HEADER_LEN {
566        return Err(AmqpParseError::InvalidSize);
567    }
568    // skipping remaining two header bytes and ext header
569    let ext_header_len = doff - HEADER_LEN + 4;
570    decode_check_len!(input, ext_header_len);
571    input.advance(ext_header_len);
572    Ok(channel_id)
573}
574
575fn decode_compound8(input: &mut Bytes) -> Result<(u32, u32), AmqpParseError> {
576    decode_check_len!(input, 2);
577    let size = input[0] - 1; // -1 for 1 byte count
578    let count = input[1];
579    input.advance(2);
580    Ok((u32::from(size), u32::from(count)))
581}
582
583fn decode_compound32(input: &mut Bytes) -> Result<(u32, u32), AmqpParseError> {
584    decode_check_len!(input, 8);
585    let size = BigEndian::read_u32(input) - 4; // -4 for 4 byte count
586    let count = BigEndian::read_u32(&input[4..]);
587    input.advance(8);
588    Ok((size, count))
589}
590
591fn datetime_from_millis(millis: i64) -> Result<DateTime<Utc>, AmqpParseError> {
592    let seconds = millis / 1000;
593    if seconds < 0 {
594        // In order to handle time before 1970 correctly, we need to subtract a second
595        // and use the nanoseconds field to add it back. This is a result of the nanoseconds
596        // parameter being u32
597        let nanoseconds = ((1000 + (millis - (seconds * 1000))) * 1_000_000).unsigned_abs();
598        Utc.timestamp_opt(seconds - 1, nanoseconds as u32)
599            .earliest()
600            .ok_or(AmqpParseError::DatetimeParseError)
601    } else {
602        let nanoseconds = ((millis - (seconds * 1000)) * 1_000_000).unsigned_abs();
603        Utc.timestamp_opt(seconds, nanoseconds as u32)
604            .earliest()
605            .ok_or(AmqpParseError::DatetimeParseError)
606    }
607}
608
609fn read_fixed_bytes<const N: usize>(input: &mut Bytes) -> Result<[u8; N], AmqpParseError> {
610    decode_check_len!(input, N);
611    let mut data = [0u8; N];
612    data.copy_from_slice(&input[..N]);
613    input.advance(N);
614    Ok(data)
615}
616
617#[cfg(test)]
618mod tests {
619    use chrono::TimeDelta;
620    use ntex_bytes::{BufMut, BytesMut};
621    use test_case::test_case;
622
623    use super::*;
624    use crate::codec::{Decode, Encode};
625
626    const LOREM: &str = include_str!("lorem.txt");
627
628    macro_rules! decode_tests {
629        ($($name:ident: $kind:ident, $test:expr, $expected:expr,)*) => {
630        $(
631            #[test]
632            fn $name() {
633                let mut b1 = BytesMut::with_capacity(($test).encoded_size());
634                ($test).encode(&mut b1);
635                assert_eq!($expected, <$kind as Decode>::decode(&mut b1.freeze()).unwrap());
636            }
637        )*
638        }
639    }
640
641    decode_tests! {
642        ubyte: u8, 255_u8, 255_u8,
643        ushort: u16, 350_u16, 350_u16,
644
645        uint_zero: u32, 0_u32, 0_u32,
646        uint_small: u32, 128_u32, 128_u32,
647        uint_big: u32, 2147483647_u32, 2147483647_u32,
648
649        ulong_zero: u64, 0_u64, 0_u64,
650        ulong_small: u64, 128_u64, 128_u64,
651        uulong_big: u64, 2147483649_u64, 2147483649_u64,
652
653        byte: i8, -128_i8, -128_i8,
654        short: i16, -255_i16, -255_i16,
655
656        int_zero: i32, 0_i32, 0_i32,
657        int_small: i32, -50000_i32, -50000_i32,
658        int_neg: i32, -128_i32, -128_i32,
659
660        long_zero: i64, 0_i64, 0_i64,
661        long_big: i64, -2147483647_i64, -2147483647_i64,
662        long_small: i64, -128_i64, -128_i64,
663
664        float: f32, 1.234_f32, 1.234_f32,
665        double: f64, 1.234_f64, 1.234_f64,
666
667        test_char: char, '💯', '💯',
668
669        uuid: Uuid, Uuid::from_slice(&[4, 54, 67, 12, 43, 2, 98, 76, 32, 50, 87, 5, 1, 33, 43, 87]).expect("parse error"),
670        Uuid::parse_str("0436430c2b02624c2032570501212b57").expect("parse error"),
671
672        binary_short: Bytes, Bytes::from(&[4u8, 5u8][..]), Bytes::from(&[4u8, 5u8][..]),
673        binary_long: Bytes, Bytes::from(&[4u8; 500][..]), Bytes::from(&[4u8; 500][..]),
674
675        string_short: ByteString, ByteString::from("Hello there"), ByteString::from("Hello there"),
676        string_long: ByteString, ByteString::from(LOREM), ByteString::from(LOREM),
677
678        // symbol_short: Symbol, Symbol::from("Hello there"), Symbol::from("Hello there"),
679        // symbol_long: Symbol, Symbol::from(LOREM), Symbol::from(LOREM),
680
681        variant_ubyte: Variant, Variant::Ubyte(255_u8), Variant::Ubyte(255_u8),
682        variant_ushort: Variant, Variant::Ushort(350_u16), Variant::Ushort(350_u16),
683
684        variant_uint_zero: Variant, Variant::Uint(0_u32), Variant::Uint(0_u32),
685        variant_uint_small: Variant, Variant::Uint(128_u32), Variant::Uint(128_u32),
686        variant_uint_big: Variant, Variant::Uint(2147483647_u32), Variant::Uint(2147483647_u32),
687
688        variant_ulong_zero: Variant, Variant::Ulong(0_u64), Variant::Ulong(0_u64),
689        variant_ulong_small: Variant, Variant::Ulong(128_u64), Variant::Ulong(128_u64),
690        variant_ulong_big: Variant, Variant::Ulong(2147483649_u64), Variant::Ulong(2147483649_u64),
691
692        variant_byte: Variant, Variant::Byte(-128_i8), Variant::Byte(-128_i8),
693        variant_short: Variant, Variant::Short(-255_i16), Variant::Short(-255_i16),
694
695        variant_int_zero: Variant, Variant::Int(0_i32), Variant::Int(0_i32),
696        variant_int_small: Variant, Variant::Int(-50000_i32), Variant::Int(-50000_i32),
697        variant_int_neg: Variant, Variant::Int(-128_i32), Variant::Int(-128_i32),
698
699        variant_long_zero: Variant, Variant::Long(0_i64), Variant::Long(0_i64),
700        variant_long_big: Variant, Variant::Long(-2147483647_i64), Variant::Long(-2147483647_i64),
701        variant_long_small: Variant, Variant::Long(-128_i64), Variant::Long(-128_i64),
702
703        variant_float: Variant, Variant::Float(OrderedFloat(1.234_f32)), Variant::Float(OrderedFloat(1.234_f32)),
704        variant_double: Variant, Variant::Double(OrderedFloat(1.234_f64)), Variant::Double(OrderedFloat(1.234_f64)),
705
706        variant_char: Variant, Variant::Char('💯'), Variant::Char('💯'),
707
708        variant_uuid: Variant, Variant::Uuid(Uuid::from_slice(&[4, 54, 67, 12, 43, 2, 98, 76, 32, 50, 87, 5, 1, 33, 43, 87]).expect("parse error")),
709        Variant::Uuid(Uuid::parse_str("0436430c2b02624c2032570501212b57").expect("parse error")),
710
711        variant_binary_short: Variant, Variant::Binary(Bytes::from(&[4u8, 5u8][..])), Variant::Binary(Bytes::from(&[4u8, 5u8][..])),
712        variant_binary_long: Variant, Variant::Binary(Bytes::from(&[4u8; 500][..])), Variant::Binary(Bytes::from(&[4u8; 500][..])),
713
714        variant_string_short: Variant, Variant::String(ByteString::from("Hello there").into()), Variant::String(ByteString::from("Hello there").into()),
715        variant_string_long: Variant, Variant::String(ByteString::from(LOREM).into()), Variant::String(ByteString::from(LOREM).into()),
716
717        // variant_symbol_short: Variant, Variant::Symbol(Symbol::from("Hello there")), Variant::Symbol(Symbol::from("Hello there")),
718        // variant_symbol_long: Variant, Variant::Symbol(Symbol::from(LOREM)), Variant::Symbol(Symbol::from(LOREM)),
719    }
720
721    fn unwrap_value<T>(res: Result<T, AmqpParseError>) -> T {
722        assert!(res.is_ok());
723        res.unwrap()
724    }
725
726    #[test]
727    fn test_bool_true() {
728        let mut b1 = BytesMut::with_capacity(0);
729        b1.put_u8(0x41);
730        assert!(unwrap_value(bool::decode(&mut b1.freeze())));
731
732        let mut b2 = BytesMut::with_capacity(0);
733        b2.put_u8(0x56);
734        b2.put_u8(0x01);
735        assert!(unwrap_value(bool::decode(&mut b2.freeze())));
736    }
737
738    #[test]
739    fn test_bool_false() {
740        let mut b1 = BytesMut::with_capacity(0);
741        b1.put_u8(0x42u8);
742        assert!(!unwrap_value(bool::decode(&mut b1.freeze())));
743
744        let mut b2 = BytesMut::with_capacity(0);
745        b2.put_u8(0x56);
746        b2.put_u8(0x00);
747        assert!(!unwrap_value(bool::decode(&mut b2.freeze())));
748    }
749
750    /// UTC with a precision of milliseconds. For example, 1311704463521
751    /// represents the moment 2011-07-26T18:21:03.521Z.
752    #[test]
753    fn test_timestamp() {
754        let mut b1 = BytesMut::with_capacity(0);
755        let datetime =
756            Utc.with_ymd_and_hms(2011, 7, 26, 18, 21, 3).unwrap() + TimeDelta::milliseconds(521);
757        datetime.encode(&mut b1);
758
759        let expected =
760            Utc.with_ymd_and_hms(2011, 7, 26, 18, 21, 3).unwrap() + TimeDelta::milliseconds(521);
761        assert_eq!(
762            expected,
763            unwrap_value(DateTime::<Utc>::decode(&mut b1.freeze()))
764        );
765    }
766
767    #[test]
768    fn test_timestamp_pre_unix() {
769        let mut b1 = BytesMut::with_capacity(0);
770        let datetime =
771            Utc.with_ymd_and_hms(1968, 7, 26, 18, 21, 3).unwrap() + TimeDelta::milliseconds(521);
772        datetime.encode(&mut b1);
773
774        let expected =
775            Utc.with_ymd_and_hms(1968, 7, 26, 18, 21, 3).unwrap() + TimeDelta::milliseconds(521);
776        assert_eq!(
777            expected,
778            unwrap_value(DateTime::<Utc>::decode(&mut b1.freeze()))
779        );
780    }
781
782    #[test]
783    fn variant_null() {
784        let mut b = BytesMut::with_capacity(0);
785        Variant::Null.encode(&mut b);
786        let t = unwrap_value(Variant::decode(&mut b.freeze()));
787        assert_eq!(Variant::Null, t);
788    }
789
790    #[test]
791    fn variant_bool_true() {
792        let mut b1 = BytesMut::with_capacity(0);
793        b1.put_u8(0x41);
794        assert_eq!(
795            Variant::Boolean(true),
796            unwrap_value(Variant::decode(&mut b1.freeze()))
797        );
798
799        let mut b2 = BytesMut::with_capacity(0);
800        b2.put_u8(0x56);
801        b2.put_u8(0x01);
802        assert_eq!(
803            Variant::Boolean(true),
804            unwrap_value(Variant::decode(&mut b2.freeze()))
805        );
806    }
807
808    #[test]
809    fn variant_bool_false() {
810        let mut b1 = BytesMut::with_capacity(0);
811        b1.put_u8(0x42u8);
812        assert_eq!(
813            Variant::Boolean(false),
814            unwrap_value(Variant::decode(&mut b1.freeze()))
815        );
816
817        let mut b2 = BytesMut::with_capacity(0);
818        b2.put_u8(0x56);
819        b2.put_u8(0x00);
820        assert_eq!(
821            Variant::Boolean(false),
822            unwrap_value(Variant::decode(&mut b2.freeze()))
823        );
824    }
825
826    /// UTC with a precision of milliseconds. For example, 1311704463521
827    /// represents the moment 2011-07-26T18:21:03.521Z.
828    #[test]
829    fn variant_timestamp() {
830        let mut b1 = BytesMut::with_capacity(0);
831        let datetime =
832            Utc.with_ymd_and_hms(2011, 7, 26, 18, 21, 3).unwrap() + TimeDelta::milliseconds(521);
833        Variant::Timestamp(datetime).encode(&mut b1);
834
835        let expected =
836            Utc.with_ymd_and_hms(2011, 7, 26, 18, 21, 3).unwrap() + TimeDelta::milliseconds(521);
837        assert_eq!(
838            Variant::Timestamp(expected),
839            unwrap_value(Variant::decode(&mut b1.freeze()))
840        );
841    }
842
843    #[test]
844    fn variant_timestamp_pre_unix() {
845        let mut b1 = BytesMut::with_capacity(0);
846        let datetime =
847            Utc.with_ymd_and_hms(1968, 7, 26, 18, 21, 3).unwrap() + TimeDelta::milliseconds(521);
848        Variant::Timestamp(datetime).encode(&mut b1);
849
850        let expected =
851            Utc.with_ymd_and_hms(1968, 7, 26, 18, 21, 3).unwrap() + TimeDelta::milliseconds(521);
852        assert_eq!(
853            Variant::Timestamp(expected),
854            unwrap_value(Variant::decode(&mut b1.freeze()))
855        );
856    }
857
858    #[test_case(
859        b"\x00\xa3\x07foo:bar\xc0\x03\x01\x50\x03",
860        Descriptor::Symbol("foo:bar".into()),
861        List(vec![Variant::Ubyte(3)]);
862        "described 'foo:bar', list8 w/one u8 field with value 3")]
863    #[test_case(
864        b"\x00\x80\x00\x00\x01\x37\x00\x00\x03\xe9\x45",
865        Descriptor::Ulong((311 << 32) + 1001), List(vec![]); "described 311:1001, list0")]
866    #[test_case(
867        b"\x00\x80\x00\x01\xd4\xc0\x00\x03\x82\x70\xd0\x00\x00\x00\x0c\x00\x00\x00\x03\x53\x6f\xa1\x03abc\x42",
868        Descriptor::Ulong((120_000 << 32) + 230_000),
869        List(vec![Variant::Ulong(111), Variant::String("abc".into()), Variant::Boolean(false)]);
870        "described 120000:230000, list32 w/3 fields: smallulong: 111, string8: 'abc', booleanfalse")]
871    fn decode_described_list(
872        input: &'static [u8],
873        expected_descriptor: Descriptor,
874        expected_list: List,
875    ) {
876        let mut buf = Bytes::from(input);
877        let variant = Variant::decode(&mut buf).unwrap();
878        assert!(buf.is_empty(), "Expected no remaining bytes after decoding");
879        let dc = match variant {
880            Variant::DescribedCompound(dc) => dc,
881            _ => panic!("Expected a DescribedCompound variant"),
882        };
883        assert_eq!(dc.descriptor(), &expected_descriptor);
884        println!("{:02x?}", dc.data.as_ref());
885        let decoded_list: List = dc.decode().expect("Failed to decode List");
886        assert_eq!(decoded_list, expected_list);
887    }
888
889    #[test_case(
890        b"\x00\xa3\x05a:b:c\xc1\x08\x04\x50\x03\x41\x50\xc8\x56\x00",
891        Descriptor::Symbol("a:b:c".into()),
892        vec![(Variant::Ubyte(3), Variant::Boolean(true)), (Variant::Ubyte(200), Variant::Boolean(false))];
893        "described 'a:b:c', map8 with 2 pairs: (ubyte(3), true), (ubyte(200), false)")]
894    #[test_case(
895        b"\x00\x80\x00\x01\xd4\xc0\x00\x03\x82\x70\xd1\x00\x00\x00\x0a\x00\x00\x00\x02\x73\x00\x00\x00z\x40",
896        Descriptor::Ulong((120_000 << 32) + 230_000),
897        vec![(Variant::Char('z'), Variant::Null)];
898        "described 120000:230000, map32 with 1 pair: char: 'z', null")]
899    fn decode_described_map(
900        input: &'static [u8],
901        expected_descriptor: Descriptor,
902        expected_map: Vec<(Variant, Variant)>,
903    ) {
904        let mut buf = Bytes::from(input);
905        let variant = Variant::decode(&mut buf).unwrap();
906        assert!(buf.is_empty(), "Expected no remaining bytes after decoding");
907        let dc = match variant {
908            Variant::DescribedCompound(dc) => dc,
909            _ => panic!("Expected a DescribedCompound variant"),
910        };
911        assert_eq!(dc.descriptor(), &expected_descriptor);
912        println!("{:02x?}", dc.data.as_ref());
913        let decoded_map: HashMap<Variant, Variant> = dc.decode().expect("Failed to decode List");
914        let expected_map: HashMap<Variant, Variant> = expected_map.into_iter().collect();
915        assert_eq!(decoded_map, expected_map);
916    }
917
918    #[test_case(
919        b"\x00\xa3\x07foo:bar\xe0\x05\x03\x50\x01\x02\x03",
920        Descriptor::Symbol("foo:bar".into()),
921        Constructor::FormatCode(codec::FORMATCODE_UBYTE),
922        vec![Variant::Ubyte(1), Variant::Ubyte(2), Variant::Ubyte(3)];
923        "described 'foo:bar', array8 w/3 u8 elements: 1, 2, 3")]
924    fn decode_described_array(
925        input: &'static [u8],
926        expected_descriptor: Descriptor,
927        expected_el_ctor: Constructor,
928        expected_array: Vec<Variant>,
929    ) {
930        // todo: mg: array decoding: add check that all bytes are read out according to size when done decoding array elements /
931        // list fields / map key-value pairs according to count
932        let mut buf = Bytes::from(input);
933        let variant = Variant::decode(&mut buf).unwrap();
934        assert!(buf.is_empty(), "Expected no remaining bytes after decoding");
935        let dc = match variant {
936            Variant::DescribedCompound(dc) => dc,
937            _ => panic!("Expected a DescribedCompound variant"),
938        };
939        assert_eq!(dc.descriptor(), &expected_descriptor);
940        println!("{:02x?}", dc.data.as_ref());
941        let decoded_array: Array = dc.decode().expect("Failed to decode Array");
942        assert_eq!(decoded_array.element_constructor(), &expected_el_ctor);
943        let array_items: Vec<Variant> = decoded_array
944            .decode()
945            .expect("Failed to decode Array items using Variant type");
946        assert_eq!(array_items, expected_array);
947    }
948
949    #[test]
950    fn option_i8() {
951        let mut b1 = BytesMut::with_capacity(0);
952        Some(42i8).encode(&mut b1);
953
954        assert_eq!(
955            Some(42),
956            unwrap_value(Option::<i8>::decode(&mut b1.freeze()))
957        );
958
959        let mut b2 = BytesMut::with_capacity(0);
960        let o1: Option<i8> = None;
961        o1.encode(&mut b2);
962
963        assert_eq!(None, unwrap_value(Option::<i8>::decode(&mut b2.freeze())));
964    }
965
966    #[test]
967    fn option_string() {
968        let mut b1 = BytesMut::with_capacity(0);
969        Some(ByteString::from("hello")).encode(&mut b1);
970
971        assert_eq!(
972            Some(ByteString::from("hello")),
973            unwrap_value(Option::<ByteString>::decode(&mut b1.freeze()))
974        );
975
976        let mut b2 = BytesMut::with_capacity(0);
977        let o1: Option<ByteString> = None;
978        o1.encode(&mut b2);
979
980        assert_eq!(
981            None,
982            unwrap_value(Option::<ByteString>::decode(&mut b2.freeze()))
983        );
984    }
985}