ntex_amqp_codec/types/
variant.rs

1use std::hash::{Hash, Hasher};
2
3use chrono::{DateTime, Utc};
4use ntex_bytes::{ByteString, Bytes, BytesMut};
5use ordered_float::OrderedFloat;
6use uuid::Uuid;
7
8use crate::types::{Array, Descriptor, List, Str, Symbol};
9use crate::{protocol::Annotations, HashMap};
10use crate::{AmqpParseError, Decode, Encode};
11
12/// Represents an AMQP type for use in polymorphic collections
13#[derive(Debug, Eq, PartialEq, Hash, Clone, From)]
14pub enum Variant {
15    /// Indicates an empty value.
16    Null,
17
18    /// Represents a true or false value.
19    Boolean(bool),
20
21    /// Integer in the range 0 to 2^8 - 1 inclusive.
22    Ubyte(u8),
23
24    /// Integer in the range 0 to 2^16 - 1 inclusive.
25    Ushort(u16),
26
27    /// Integer in the range 0 to 2^32 - 1 inclusive.
28    Uint(u32),
29
30    /// Integer in the range 0 to 2^64 - 1 inclusive.
31    Ulong(u64),
32
33    /// Integer in the range 0 to 2^7 - 1 inclusive.
34    Byte(i8),
35
36    /// Integer in the range 0 to 2^15 - 1 inclusive.
37    Short(i16),
38
39    /// Integer in the range 0 to 2^32 - 1 inclusive.
40    Int(i32),
41
42    /// Integer in the range 0 to 2^64 - 1 inclusive.
43    Long(i64),
44
45    /// 32-bit floating point number (IEEE 754-2008 binary32).
46    Float(OrderedFloat<f32>),
47
48    /// 64-bit floating point number (IEEE 754-2008 binary64).
49    Double(OrderedFloat<f64>),
50
51    /// 32-bit decimal number, represented per IEEE 754-2008 decimal32 specification.
52    Decimal32([u8; 4]),
53
54    /// 64-bit decimal number, represented per IEEE 754-2008 decimal64 specification.
55    Decimal64([u8; 8]),
56
57    /// 128-bit decimal number, represented per IEEE 754-2008 decimal128 specification.
58    Decimal128([u8; 16]),
59
60    /// A single Unicode character.
61    Char(char),
62
63    /// An absolute point in time.
64    /// Represents an approximate point in time using the Unix time encoding of
65    /// UTC with a precision of milliseconds. For example, 1311704463521
66    /// represents the moment 2011-07-26T18:21:03.521Z.
67    Timestamp(DateTime<Utc>),
68
69    /// A universally unique identifier as defined by RFC-4122 section 4.1.2
70    Uuid(Uuid),
71
72    /// A sequence of octets.
73    Binary(Bytes),
74
75    /// A sequence of Unicode characters
76    String(Str),
77
78    /// Symbolic values from a constrained domain.
79    Symbol(Symbol),
80
81    /// List
82    List(List),
83
84    /// Map
85    Map(VariantMap),
86
87    /// Array
88    Array(Array),
89
90    /// Described value of primitive type. See `Variant::DescribedCompound` for
91    Described((Descriptor, Box<Variant>)),
92
93    /// Described value of compound or array type. See `Variant::DescribedCompound` for details.
94    DescribedCompound(DescribedCompound),
95}
96
97/// Represents a compound value with a descriptor. The value contains data starting with format code for the underlying AMQP type
98/// (right after the descriptor).
99#[derive(Debug, Clone, PartialEq, Eq, Hash)]
100pub struct DescribedCompound {
101    descriptor: Descriptor,
102    pub(crate) data: Bytes,
103}
104
105impl DescribedCompound {
106    /// Creates a representation of described value of compound value type based on `T` type's AMQP encoding.
107    /// `T`'s implementation of `Encode` is expected to produce the binary representation of the T in an underlying AMQP type value, starting from the format code.
108    /// For instance, if the described value is to be represented as an AMQP list with 1 ubyte field with a value of 3:
109    /// ```text
110    /// 0x00 0xa3 0x07 "foo:bar" 0xc0 0x02 0x01 0x50 0x03
111    /// ```
112    /// The `T::encode` method is expected to produce the following output:
113    /// ```text
114    /// 0xc0 0x02 0x01 0x50 0x03
115    /// ```
116    pub fn create<T: Encode>(descriptor: Descriptor, value: T) -> Self {
117        let size = value.encoded_size();
118        let mut buf = BytesMut::with_capacity(size);
119        value.encode(&mut buf);
120        DescribedCompound {
121            descriptor,
122            data: buf.freeze(),
123        }
124    }
125
126    pub(crate) fn new(descriptor: Descriptor, data: Bytes) -> Self {
127        DescribedCompound { descriptor, data }
128    }
129
130    pub fn descriptor(&self) -> &Descriptor {
131        &self.descriptor
132    }
133
134    /// Attempts to decode the described value as `T`.
135    /// `T`'s implementation of `Decode` is expected to parse the underlying AMQP type starting from the format code.
136    /// For instance, if the value is of described type represented by AMQP list with 1 ubyte field with a value of 3:
137    /// ```text
138    /// 0x00 0xa3 0x07 "foo:bar" 0xc0 0x02 0x01 0x50 0x03
139    /// ```
140    /// The `T::decode` method will be called with the following input:
141    /// ```text
142    /// 0xc0 0x02 0x01 0x50 0x03
143    /// ```
144    pub fn decode<T: Decode>(self) -> Result<T, AmqpParseError> {
145        let mut buf = self.data.clone();
146        let result = T::decode(&mut buf)?;
147        if buf.is_empty() {
148            Ok(result)
149        } else {
150            Err(AmqpParseError::InvalidSize)
151        }
152    }
153}
154
155impl Encode for DescribedCompound {
156    fn encoded_size(&self) -> usize {
157        self.descriptor.encoded_size() + self.data.len()
158    }
159
160    fn encode(&self, buf: &mut BytesMut) {
161        self.descriptor.encode(buf);
162        buf.extend_from_slice(&self.data);
163    }
164}
165
166impl From<HashMap<Variant, Variant>> for Variant {
167    fn from(data: HashMap<Variant, Variant>) -> Self {
168        Variant::Map(VariantMap { map: data })
169    }
170}
171
172impl From<ByteString> for Variant {
173    fn from(s: ByteString) -> Self {
174        Str::from(s).into()
175    }
176}
177
178impl From<String> for Variant {
179    fn from(s: String) -> Self {
180        Str::from(ByteString::from(s)).into()
181    }
182}
183
184impl From<&'static str> for Variant {
185    fn from(s: &'static str) -> Self {
186        Str::from(s).into()
187    }
188}
189
190impl PartialEq<str> for Variant {
191    fn eq(&self, other: &str) -> bool {
192        match self {
193            Variant::String(s) => s == other,
194            Variant::Symbol(s) => s == other,
195            _ => false,
196        }
197    }
198}
199
200impl Variant {
201    pub fn as_str(&self) -> Option<&str> {
202        match self {
203            Variant::String(s) => Some(s.as_str()),
204            Variant::Symbol(s) => Some(s.as_str()),
205            _ => None,
206        }
207    }
208
209    /// Expresses integer-typed variant values as i64 value when possible. Notably, does not include ulong.
210    /// Returns `None` for variants other than supported integers.
211    pub fn as_long(&self) -> Option<i64> {
212        match self {
213            Variant::Ubyte(v) => Some(*v as i64),
214            Variant::Ushort(v) => Some(*v as i64),
215            Variant::Uint(v) => Some(*v as i64),
216            Variant::Byte(v) => Some(*v as i64),
217            Variant::Short(v) => Some(*v as i64),
218            Variant::Int(v) => Some(*v as i64),
219            Variant::Long(v) => Some(*v),
220            _ => None,
221        }
222    }
223
224    /// Expresses unsigned integer-typed variant values as u64 value. Returns `None` for variants other than unsigned integers.
225    pub fn as_ulong(&self) -> Option<u64> {
226        match self {
227            Variant::Ubyte(v) => Some(*v as u64),
228            Variant::Ushort(v) => Some(*v as u64),
229            Variant::Uint(v) => Some(*v as u64),
230            Variant::Ulong(v) => Some(*v),
231            _ => None,
232        }
233    }
234
235    pub fn to_bytes_str(&self) -> Option<ByteString> {
236        match self {
237            Variant::String(s) => Some(s.to_bytes_str()),
238            Variant::Symbol(s) => Some(s.to_bytes_str()),
239            _ => None,
240        }
241    }
242}
243
244#[derive(PartialEq, Eq, Clone, Debug)]
245pub struct VariantMap {
246    pub map: HashMap<Variant, Variant>,
247}
248
249impl VariantMap {
250    pub fn new(map: HashMap<Variant, Variant>) -> VariantMap {
251        VariantMap { map }
252    }
253}
254
255#[allow(clippy::derived_hash_with_manual_eq)]
256impl Hash for VariantMap {
257    fn hash<H: Hasher>(&self, _state: &mut H) {
258        unimplemented!()
259    }
260}
261
262#[derive(PartialEq, Eq, Clone, Debug)]
263pub struct VecSymbolMap(pub Vec<(Symbol, Variant)>);
264
265impl Default for VecSymbolMap {
266    fn default() -> Self {
267        VecSymbolMap(Vec::with_capacity(8))
268    }
269}
270
271impl From<Annotations> for VecSymbolMap {
272    fn from(anns: Annotations) -> VecSymbolMap {
273        VecSymbolMap(anns.into_iter().collect())
274    }
275}
276
277impl From<Vec<(Symbol, Variant)>> for VecSymbolMap {
278    fn from(data: Vec<(Symbol, Variant)>) -> VecSymbolMap {
279        VecSymbolMap(data)
280    }
281}
282
283impl std::ops::Deref for VecSymbolMap {
284    type Target = Vec<(Symbol, Variant)>;
285
286    fn deref(&self) -> &Self::Target {
287        &self.0
288    }
289}
290
291impl std::ops::DerefMut for VecSymbolMap {
292    fn deref_mut(&mut self) -> &mut Self::Target {
293        &mut self.0
294    }
295}
296
297#[derive(PartialEq, Eq, Clone, Debug)]
298pub struct VecStringMap(pub Vec<(Str, Variant)>);
299
300impl Default for VecStringMap {
301    fn default() -> Self {
302        VecStringMap(Vec::with_capacity(8))
303    }
304}
305
306impl From<Vec<(Str, Variant)>> for VecStringMap {
307    fn from(data: Vec<(Str, Variant)>) -> VecStringMap {
308        VecStringMap(data)
309    }
310}
311
312impl From<HashMap<Str, Variant>> for VecStringMap {
313    fn from(map: HashMap<Str, Variant>) -> VecStringMap {
314        VecStringMap(map.into_iter().collect())
315    }
316}
317
318impl std::ops::Deref for VecStringMap {
319    type Target = Vec<(Str, Variant)>;
320
321    fn deref(&self) -> &Self::Target {
322        &self.0
323    }
324}
325
326impl std::ops::DerefMut for VecStringMap {
327    fn deref_mut(&mut self) -> &mut Self::Target {
328        &mut self.0
329    }
330}
331
#[cfg(test)]
mod tests {
    use ntex_bytes::{Buf, BufMut};

    use crate::{codec::ListHeader, format_codes};

    use super::*;

    #[test]
    fn bytes_eq() {
        let hello = Variant::Binary(Bytes::from(&b"hello"[..]));
        let hello_again = Variant::Binary(Bytes::from(&b"hello"[..]));
        let world = Variant::Binary(Bytes::from(&b"world"[..]));

        assert_eq!(hello, hello_again);
        assert_ne!(hello, world);
    }

    #[test]
    fn string_eq() {
        let hello = Variant::String(ByteString::from("hello").into());
        let world = Variant::String(ByteString::from("world!").into());

        assert_eq!(Variant::String(ByteString::from("hello").into()), hello);
        assert_ne!(hello, world);
    }

    #[test]
    fn symbol_eq() {
        let hello = Variant::Symbol(Symbol::from("hello"));
        let world = Variant::Symbol(Symbol::from("world!"));

        assert_eq!(Variant::Symbol(Symbol::from("hello")), hello);
        assert_ne!(hello, world);
    }

    // <type name="mqtt-metadata" class="composite" source="list">
    //   <descriptor name="contoso:test"/>
    //   <field name="field1" type="string" mandatory="true"/>
    //   <field name="field2" type="ubyte" mandatory="true"/>
    //   <field name="field3" type="string"/>
    // </type>
    #[derive(Debug, PartialEq, Eq, Clone)]
    struct CustomList {
        field1: ByteString,
        field2: u8,
        field3: Option<ByteString>,
    }

    impl CustomList {
        /// Encoded size of the present fields, excluding the list header.
        fn encoded_data_size(&self) -> usize {
            let optional = self.field3.as_ref().map_or(0, |f| f.encoded_size());
            self.field1.encoded_size() + self.field2.encoded_size() + optional
        }

        /// Whether the payload (fields plus a one-byte count) no longer fits the list8 size byte.
        fn needs_list32(data_size: usize) -> bool {
            data_size + 1 > u8::MAX as usize
        }

        /// Number of list elements that will be written.
        fn field_count(&self) -> u8 {
            if self.field3.is_some() {
                3
            } else {
                2
            }
        }
    }

    impl crate::DecodeFormatted for CustomList {
        fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
            let header = ListHeader::decode_with_format(input, fmt)?;
            // field1 and field2 are mandatory, so at least two elements must be present.
            if header.count < 2 {
                return Err(AmqpParseError::RequiredFieldOmitted("field2"));
            }
            let field1 = ByteString::decode(input)?;
            let field2 = u8::decode(input)?;
            let field3 = match header.count {
                3 => Some(ByteString::decode(input)?),
                _ => None,
            };
            // Every byte of the input must belong to the list.
            if input.has_remaining() {
                return Err(AmqpParseError::InvalidSize);
            }
            Ok(CustomList {
                field1,
                field2,
                field3,
            })
        }
    }

    impl crate::Encode for CustomList {
        fn encoded_size(&self) -> usize {
            let data_size = self.encoded_data_size();
            // Header is format code + size + count; size and count are 4 bytes each for list32.
            let header_size = if CustomList::needs_list32(data_size) {
                1 + 4 + 4
            } else {
                1 + 1 + 1
            };
            data_size + header_size
        }

        fn encode(&self, buf: &mut BytesMut) {
            let count = self.field_count();
            let data_size = self.encoded_data_size();
            if CustomList::needs_list32(data_size) {
                buf.put_u8(format_codes::FORMATCODE_LIST32);
                buf.put_u32((4 + data_size) as u32); // size includes the 4-byte count
                buf.put_u32(u32::from(count));
            } else {
                buf.put_u8(format_codes::FORMATCODE_LIST8);
                buf.put_u8((1 + data_size) as u8); // size includes the 1-byte count
                buf.put_u8(count);
            }
            self.field1.encode(buf);
            self.field2.encode(buf);
            if let Some(field3) = &self.field3 {
                field3.encode(buf);
            }
        }
    }

    #[test]
    fn described_custom_list_recoding() {
        let original = CustomList {
            field1: ByteString::from("value1"),
            field2: 115,
            field3: Some(ByteString::from("value3")),
        };
        let value = Variant::DescribedCompound(DescribedCompound::create(
            Descriptor::Symbol("contoso:test".into()),
            original.clone(),
        ));

        let mut buf = BytesMut::with_capacity(value.encoded_size());
        value.encode(&mut buf);
        let encoded = buf.freeze();
        assert_eq!(
            encoded.as_ref(),
            &b"\x00\xa3\x0ccontoso:test\xc0\x13\x03\xa1\x06value1\x50\x73\xa1\x06value3"[..]
        );

        let mut input = encoded.clone();
        let decoded = Variant::decode(&mut input).unwrap();
        assert_eq!(decoded, value);

        let decoded_list = if let Variant::DescribedCompound(compound) = decoded {
            compound.decode::<CustomList>().unwrap()
        } else {
            panic!("Expected a described compound")
        };
        assert_eq!(decoded_list, original);
    }
}
472}