// ntex_amqp_codec/types/variant.rs

1use std::hash::{Hash, Hasher};
2
3use chrono::{DateTime, Utc};
4use ntex_bytes::{ByteString, Bytes, BytesMut};
5use ordered_float::OrderedFloat;
6use uuid::Uuid;
7
8use crate::types::{Array, Descriptor, List, Str, Symbol};
9use crate::{AmqpParseError, Decode, Encode, HashMap, protocol::Annotations};
10
11/// Represents an AMQP type for use in polymorphic collections
12#[derive(Debug, Eq, PartialEq, Hash, Clone, From)]
13pub enum Variant {
14    /// Indicates an empty value.
15    Null,
16
17    /// Represents a true or false value.
18    Boolean(bool),
19
20    /// Integer in the range 0 to 2^8 - 1 inclusive.
21    Ubyte(u8),
22
23    /// Integer in the range 0 to 2^16 - 1 inclusive.
24    Ushort(u16),
25
26    /// Integer in the range 0 to 2^32 - 1 inclusive.
27    Uint(u32),
28
29    /// Integer in the range 0 to 2^64 - 1 inclusive.
30    Ulong(u64),
31
32    /// Integer in the range 0 to 2^7 - 1 inclusive.
33    Byte(i8),
34
35    /// Integer in the range 0 to 2^15 - 1 inclusive.
36    Short(i16),
37
38    /// Integer in the range 0 to 2^32 - 1 inclusive.
39    Int(i32),
40
41    /// Integer in the range 0 to 2^64 - 1 inclusive.
42    Long(i64),
43
44    /// 32-bit floating point number (IEEE 754-2008 binary32).
45    Float(OrderedFloat<f32>),
46
47    /// 64-bit floating point number (IEEE 754-2008 binary64).
48    Double(OrderedFloat<f64>),
49
50    /// 32-bit decimal number, represented per IEEE 754-2008 decimal32 specification.
51    Decimal32([u8; 4]),
52
53    /// 64-bit decimal number, represented per IEEE 754-2008 decimal64 specification.
54    Decimal64([u8; 8]),
55
56    /// 128-bit decimal number, represented per IEEE 754-2008 decimal128 specification.
57    Decimal128([u8; 16]),
58
59    /// A single Unicode character.
60    Char(char),
61
62    /// An absolute point in time.
63    /// Represents an approximate point in time using the Unix time encoding of
64    /// UTC with a precision of milliseconds. For example, 1311704463521
65    /// represents the moment 2011-07-26T18:21:03.521Z.
66    Timestamp(DateTime<Utc>),
67
68    /// A universally unique identifier as defined by RFC-4122 section 4.1.2
69    Uuid(Uuid),
70
71    /// A sequence of octets.
72    Binary(Bytes),
73
74    /// A sequence of Unicode characters
75    String(Str),
76
77    /// Symbolic values from a constrained domain.
78    Symbol(Symbol),
79
80    /// List
81    List(List),
82
83    /// Map
84    Map(VariantMap),
85
86    /// Array
87    Array(Array),
88
89    /// Described value of primitive type. See `Variant::DescribedCompound` for
90    Described((Descriptor, Box<Variant>)),
91
92    /// Described value of compound or array type. See `Variant::DescribedCompound` for details.
93    DescribedCompound(DescribedCompound),
94}
95
96/// Represents a compound value with a descriptor. The value contains data starting with format code for the underlying AMQP type
97/// (right after the descriptor).
98#[derive(Debug, Clone, PartialEq, Eq, Hash)]
99pub struct DescribedCompound {
100    descriptor: Descriptor,
101    pub(crate) data: Bytes,
102}
103
104impl DescribedCompound {
105    /// Creates a representation of described value of compound value type based on `T` type's AMQP encoding.
106    /// `T`'s implementation of `Encode` is expected to produce the binary representation of the T in an underlying AMQP type value, starting from the format code.
107    /// For instance, if the described value is to be represented as an AMQP list with 1 ubyte field with a value of 3:
108    /// ```text
109    /// 0x00 0xa3 0x07 "foo:bar" 0xc0 0x02 0x01 0x50 0x03
110    /// ```
111    /// The `T::encode` method is expected to produce the following output:
112    /// ```text
113    /// 0xc0 0x02 0x01 0x50 0x03
114    /// ```
115    pub fn create<T: Encode>(descriptor: Descriptor, value: T) -> Self {
116        let size = value.encoded_size();
117        let mut buf = BytesMut::with_capacity(size);
118        value.encode(&mut buf);
119        DescribedCompound {
120            descriptor,
121            data: buf.freeze(),
122        }
123    }
124
125    pub(crate) fn new(descriptor: Descriptor, data: Bytes) -> Self {
126        DescribedCompound { descriptor, data }
127    }
128
129    pub fn descriptor(&self) -> &Descriptor {
130        &self.descriptor
131    }
132
133    /// Attempts to decode the described value as `T`.
134    /// `T`'s implementation of `Decode` is expected to parse the underlying AMQP type starting from the format code.
135    /// For instance, if the value is of described type represented by AMQP list with 1 ubyte field with a value of 3:
136    /// ```text
137    /// 0x00 0xa3 0x07 "foo:bar" 0xc0 0x02 0x01 0x50 0x03
138    /// ```
139    /// The `T::decode` method will be called with the following input:
140    /// ```text
141    /// 0xc0 0x02 0x01 0x50 0x03
142    /// ```
143    pub fn decode<T: Decode>(&self) -> Result<T, AmqpParseError> {
144        let mut buf = self.data.clone();
145        let result = T::decode(&mut buf)?;
146        if buf.is_empty() {
147            Ok(result)
148        } else {
149            Err(AmqpParseError::InvalidSize)
150        }
151    }
152}
153
154impl Encode for DescribedCompound {
155    fn encoded_size(&self) -> usize {
156        self.descriptor.encoded_size() + self.data.len()
157    }
158
159    fn encode(&self, buf: &mut BytesMut) {
160        self.descriptor.encode(buf);
161        buf.extend_from_slice(&self.data);
162    }
163}
164
165impl From<HashMap<Variant, Variant>> for Variant {
166    fn from(data: HashMap<Variant, Variant>) -> Self {
167        Variant::Map(VariantMap { map: data })
168    }
169}
170
171impl From<ByteString> for Variant {
172    fn from(s: ByteString) -> Self {
173        Str::from(s).into()
174    }
175}
176
177impl From<String> for Variant {
178    fn from(s: String) -> Self {
179        Str::from(ByteString::from(s)).into()
180    }
181}
182
183impl From<&'static str> for Variant {
184    fn from(s: &'static str) -> Self {
185        Str::from(s).into()
186    }
187}
188
189impl PartialEq<str> for Variant {
190    fn eq(&self, other: &str) -> bool {
191        match self {
192            Variant::String(s) => s == other,
193            Variant::Symbol(s) => s == other,
194            _ => false,
195        }
196    }
197}
198
199impl Variant {
200    pub fn as_str(&self) -> Option<&str> {
201        match self {
202            Variant::String(s) => Some(s.as_str()),
203            Variant::Symbol(s) => Some(s.as_str()),
204            _ => None,
205        }
206    }
207
208    /// Expresses integer-typed variant values as i64 value when possible. Notably, does not include ulong.
209    /// Returns `None` for variants other than supported integers.
210    pub fn as_long(&self) -> Option<i64> {
211        match self {
212            Variant::Ubyte(v) => Some(*v as i64),
213            Variant::Ushort(v) => Some(*v as i64),
214            Variant::Uint(v) => Some(*v as i64),
215            Variant::Byte(v) => Some(*v as i64),
216            Variant::Short(v) => Some(*v as i64),
217            Variant::Int(v) => Some(*v as i64),
218            Variant::Long(v) => Some(*v),
219            _ => None,
220        }
221    }
222
223    /// Expresses unsigned integer-typed variant values as u64 value. Returns `None` for variants other than unsigned integers.
224    pub fn as_ulong(&self) -> Option<u64> {
225        match self {
226            Variant::Ubyte(v) => Some(*v as u64),
227            Variant::Ushort(v) => Some(*v as u64),
228            Variant::Uint(v) => Some(*v as u64),
229            Variant::Ulong(v) => Some(*v),
230            _ => None,
231        }
232    }
233
234    pub fn to_bytes_str(&self) -> Option<ByteString> {
235        match self {
236            Variant::String(s) => Some(s.to_bytes_str()),
237            Variant::Symbol(s) => Some(s.to_bytes_str()),
238            _ => None,
239        }
240    }
241}
242
243#[derive(PartialEq, Eq, Clone, Debug)]
244pub struct VariantMap {
245    pub map: HashMap<Variant, Variant>,
246}
247
248impl VariantMap {
249    pub fn new(map: HashMap<Variant, Variant>) -> VariantMap {
250        VariantMap { map }
251    }
252}
253
254#[allow(clippy::derived_hash_with_manual_eq)]
255impl Hash for VariantMap {
256    fn hash<H: Hasher>(&self, _state: &mut H) {
257        unimplemented!()
258    }
259}
260
261#[derive(PartialEq, Eq, Clone, Debug)]
262pub struct VecSymbolMap(pub Vec<(Symbol, Variant)>);
263
264impl Default for VecSymbolMap {
265    fn default() -> Self {
266        VecSymbolMap(Vec::with_capacity(8))
267    }
268}
269
270impl From<Annotations> for VecSymbolMap {
271    fn from(anns: Annotations) -> VecSymbolMap {
272        VecSymbolMap(anns.into_iter().collect())
273    }
274}
275
276impl From<Vec<(Symbol, Variant)>> for VecSymbolMap {
277    fn from(data: Vec<(Symbol, Variant)>) -> VecSymbolMap {
278        VecSymbolMap(data)
279    }
280}
281
282impl std::ops::Deref for VecSymbolMap {
283    type Target = Vec<(Symbol, Variant)>;
284
285    fn deref(&self) -> &Self::Target {
286        &self.0
287    }
288}
289
290impl std::ops::DerefMut for VecSymbolMap {
291    fn deref_mut(&mut self) -> &mut Self::Target {
292        &mut self.0
293    }
294}
295
296#[derive(PartialEq, Eq, Clone, Debug)]
297pub struct VecStringMap(pub Vec<(Str, Variant)>);
298
299impl Default for VecStringMap {
300    fn default() -> Self {
301        VecStringMap(Vec::with_capacity(8))
302    }
303}
304
305impl From<Vec<(Str, Variant)>> for VecStringMap {
306    fn from(data: Vec<(Str, Variant)>) -> VecStringMap {
307        VecStringMap(data)
308    }
309}
310
311impl From<HashMap<Str, Variant>> for VecStringMap {
312    fn from(map: HashMap<Str, Variant>) -> VecStringMap {
313        VecStringMap(map.into_iter().collect())
314    }
315}
316
317impl std::ops::Deref for VecStringMap {
318    type Target = Vec<(Str, Variant)>;
319
320    fn deref(&self) -> &Self::Target {
321        &self.0
322    }
323}
324
325impl std::ops::DerefMut for VecStringMap {
326    fn deref_mut(&mut self) -> &mut Self::Target {
327        &mut self.0
328    }
329}
330
#[cfg(test)]
mod tests {
    use ntex_bytes::{Buf, BufMut};

    use crate::{codec::ListHeader, format_codes};

    use super::*;

    // Binary variants compare by content: equal bytes are equal, different bytes are not.
    #[test]
    fn bytes_eq() {
        let bytes1 = Variant::Binary(Bytes::from(&b"hello"[..]));
        let bytes2 = Variant::Binary(Bytes::from(&b"hello"[..]));
        let bytes3 = Variant::Binary(Bytes::from(&b"world"[..]));

        assert_eq!(bytes1, bytes2);
        assert!(bytes1 != bytes3);
    }

    // String variants compare by content.
    #[test]
    fn string_eq() {
        let a = Variant::String(ByteString::from("hello").into());
        let b = Variant::String(ByteString::from("world!").into());

        assert_eq!(Variant::String(ByteString::from("hello").into()), a);
        assert!(a != b);
    }

    // Symbol variants compare by content.
    #[test]
    fn symbol_eq() {
        let a = Variant::Symbol(Symbol::from("hello"));
        let b = Variant::Symbol(Symbol::from("world!"));

        assert_eq!(Variant::Symbol(Symbol::from("hello")), a);
        assert!(a != b);
    }

    // Test fixture: a described composite type carried as an AMQP list, used to exercise
    // `DescribedCompound::create` / `DescribedCompound::decode` round-tripping.
    //
    // <type name="mqtt-metadata" class="composite" source="list">
    //   <descriptor name="contoso:test"/>
    //   <field name="field1" type="string" mandatory="true"/>
    //   <field name="field2" type="ubyte" mandatory="true"/>
    //   <field name="field3" type="string"/>
    // </type>
    #[derive(Debug, PartialEq, Eq, Clone)]
    struct CustomList {
        field1: ByteString,
        field2: u8,
        field3: Option<ByteString>,
    }

    impl CustomList {
        // Total encoded size of the list's fields only (no format code, size or count).
        fn encoded_data_size(&self) -> usize {
            let mut size = self.field1.encoded_size() + self.field2.encoded_size();
            if let Some(ref field3) = self.field3 {
                size += field3.encoded_size();
            }
            size
        }
    }

    impl crate::DecodeFormatted for CustomList {
        // Parses the list header for `fmt`, then the 2 mandatory fields and the optional third.
        fn decode_with_format(input: &mut Bytes, fmt: u8) -> Result<Self, AmqpParseError> {
            let header = ListHeader::decode_with_format(input, fmt)?;
            // Fewer than 2 fields means at least the mandatory `field2` is missing
            // (a count of 0 would also omit `field1`; the error names `field2` either way).
            if header.count < 2 {
                return Err(AmqpParseError::RequiredFieldOmitted("field2"));
            }
            let field1 = ByteString::decode(input)?;
            let field2 = u8::decode(input)?;
            let field3 = if header.count == 3 {
                Some(ByteString::decode(input)?)
            } else {
                None
            };
            // Reject leftover bytes (this also catches a count above 3). Assumes `input`
            // holds exactly this list's encoding, as it does via `DescribedCompound::decode`.
            if input.has_remaining() {
                return Err(AmqpParseError::InvalidSize);
            }
            Ok(CustomList {
                field1,
                field2,
                field3,
            })
        }
    }

    impl crate::Encode for CustomList {
        // list8 is used while its one-byte size field (count byte + field data) fits in a u8;
        // otherwise list32 is used.
        fn encoded_size(&self) -> usize {
            let size = self.encoded_data_size();
            if size + 1 > u8::MAX as usize {
                size + 9 // 1 for format code, 4 for size, 4 for count
            } else {
                size + 3 // 1 for format code, 1 for size, 1 for count
            }
        }

        // Writes format code, size (which includes the count field) and count, then the fields.
        // The list8/list32 choice must match `encoded_size` above.
        fn encode(&self, buf: &mut BytesMut) {
            let count = if self.field3.is_some() { 3u8 } else { 2u8 };
            let data_size = self.encoded_data_size();
            if data_size + 1 > u8::MAX as usize {
                buf.put_u8(format_codes::FORMATCODE_LIST32);
                buf.put_u32((4 + data_size) as u32); // size. 4 for count
                buf.put_u32(count as u32); // count
            } else {
                buf.put_u8(format_codes::FORMATCODE_LIST8);
                buf.put_u8((1 + data_size) as u8); // size. 1 for count
                buf.put_u8(count); // count
            }
            self.field1.encode(buf);
            self.field2.encode(buf);
            if let Some(ref field3) = self.field3 {
                field3.encode(buf);
            }
        }
    }

    // Encodes a described `CustomList` through `Variant`, checks the exact wire bytes, decodes
    // it back as a `Variant`, and finally decodes the compound payload as `CustomList` again.
    #[test]
    fn described_custom_list_recoding() {
        let custom_list = CustomList {
            field1: ByteString::from("value1"),
            field2: 115,
            field3: Some(ByteString::from("value3")),
        };
        let value = Variant::DescribedCompound(DescribedCompound::create(
            Descriptor::Symbol("contoso:test".into()),
            custom_list.clone(),
        ));
        let mut buf = BytesMut::with_capacity(value.encoded_size());
        value.encode(&mut buf);
        let data = buf.freeze();
        // Expected layout:
        // - `\x00` starts the described value;
        // - `\xa3\x0c` + "contoso:test" is the 12-byte symbol descriptor;
        // - `\xc0\x13\x03` is a list8 with size 0x13 (19 = 1 count byte + 18 field bytes)
        //   and 3 fields;
        // - `\xa1\x06value1` is field1, `\x50\x73` is field2 (ubyte 115),
        //   and `\xa1\x06value3` is field3.
        assert_eq!(
            data.as_ref(),
            &b"\x00\xa3\x0ccontoso:test\xc0\x13\x03\xa1\x06value1\x50\x73\xa1\x06value3"[..]
        );
        let mut input = data.clone();
        let decoded = Variant::decode(&mut input).unwrap();
        assert_eq!(decoded, value);
        let decoded_list = match decoded {
            Variant::DescribedCompound(desc) => desc.decode::<CustomList>().unwrap(),
            _ => panic!("Expected a described compound"),
        };
        assert_eq!(decoded_list, custom_list);
    }
}