ntex_grpc/
types.rs

1use std::{collections::HashMap, convert::TryFrom, fmt, hash::BuildHasher, hash::Hash, mem};
2
3use ntex_bytes::{Buf, BufMut, ByteString, Bytes, BytesMut};
4
5pub use crate::encoding::WireType;
6use crate::encoding::{self, DecodeError};
7
8/// Protobuf struct read/write operations
9pub trait Message: Default + Sized + fmt::Debug {
10    /// Decodes an instance of the message from a buffer
11    fn read(src: &mut Bytes) -> Result<Self, DecodeError>;
12
13    /// Encodes and writes the message to a buffer
14    fn write(&self, dst: &mut BytesMut);
15
16    /// Returns the encoded length of the message with a length delimiter
17    fn encoded_len(&self) -> usize;
18}
19
/// Describes the default of a field so serializers can decide whether the
/// field may be omitted from the output.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DefaultValue<T> {
    /// No default is known; the field is never treated as default.
    Unknown,
    /// The type's own notion of "default" applies (`NativeType::is_default`).
    Default,
    /// An explicit default; the field is default when it equals this value.
    Value(T),
}
26
27/// Protobuf type serializer
28pub trait NativeType: PartialEq + Default + Sized + fmt::Debug {
29    const TYPE: WireType;
30
31    #[inline]
32    /// Returns the encoded length of the message without a length delimiter.
33    fn value_len(&self) -> usize {
34        0
35    }
36
37    /// Deserialize from the input
38    fn merge(&mut self, src: &mut Bytes) -> Result<(), DecodeError>;
39
40    /// Check if value is default
41    fn is_default(&self) -> bool {
42        false
43    }
44
45    /// Encode field value
46    fn encode_value(&self, dst: &mut BytesMut);
47
48    #[inline]
49    /// Encode field tag and length
50    fn encode_type(&self, tag: u32, dst: &mut BytesMut) {
51        encoding::encode_key(tag, Self::TYPE, dst);
52        if !matches!(Self::TYPE, WireType::Varint | WireType::SixtyFourBit) {
53            encoding::encode_varint(self.value_len() as u64, dst);
54        }
55    }
56
57    #[inline]
58    /// Protobuf field length
59    fn encoded_len(&self, tag: u32) -> usize {
60        let value_len = self.value_len();
61        encoding::key_len(tag) + encoding::encoded_len_varint(value_len as u64) + value_len
62    }
63
64    #[inline]
65    /// Serialize protobuf field
66    fn serialize(&self, tag: u32, default: DefaultValue<&Self>, dst: &mut BytesMut) {
67        let default = match default {
68            DefaultValue::Unknown => false,
69            DefaultValue::Default => self.is_default(),
70            DefaultValue::Value(d) => self == d,
71        };
72
73        if !default {
74            self.encode_type(tag, dst);
75            self.encode_value(dst);
76        }
77    }
78
79    #[inline]
80    /// Protobuf field length
81    fn serialized_len(&self, tag: u32, default: DefaultValue<&Self>) -> usize {
82        let default = match default {
83            DefaultValue::Unknown => false,
84            DefaultValue::Default => self.is_default(),
85            DefaultValue::Value(d) => self == d,
86        };
87
88        if default {
89            0
90        } else {
91            self.encoded_len(tag)
92        }
93    }
94
95    #[inline]
96    /// Deserialize protobuf field
97    fn deserialize(
98        &mut self,
99        _: u32,
100        wtype: WireType,
101        src: &mut Bytes,
102    ) -> Result<(), DecodeError> {
103        encoding::check_wire_type(Self::TYPE, wtype)?;
104
105        if matches!(Self::TYPE, WireType::Varint | WireType::SixtyFourBit) {
106            self.merge(src)
107        } else {
108            let len = encoding::decode_varint(src)? as usize;
109            let mut buf = src.split_to_checked(len).ok_or_else(|| {
110                DecodeError::new(format!(
111                    "Not enough data, message size {} buffer size {}",
112                    len,
113                    src.len()
114                ))
115            })?;
116            self.merge(&mut buf)
117        }
118    }
119
120    #[inline]
121    /// Deserialize protobuf field to default value
122    fn deserialize_default(
123        tag: u32,
124        wtype: WireType,
125        src: &mut Bytes,
126    ) -> Result<Self, DecodeError> {
127        let mut value = Self::default();
128        value.deserialize(tag, wtype, src)?;
129        Ok(value)
130    }
131}
132
133/// Protobuf struct read/write operations
134impl Message for () {
135    fn encoded_len(&self) -> usize {
136        0
137    }
138    fn read(_: &mut Bytes) -> Result<Self, DecodeError> {
139        Ok(())
140    }
141    fn write(&self, _: &mut BytesMut) {}
142}
143
144impl<T: Message + PartialEq> NativeType for T {
145    const TYPE: WireType = WireType::LengthDelimited;
146
147    fn value_len(&self) -> usize {
148        Message::encoded_len(self)
149    }
150
151    #[inline]
152    /// Encode message to the buffer
153    fn encode_value(&self, dst: &mut BytesMut) {
154        self.write(dst)
155    }
156
157    /// Deserialize from the input
158    fn merge(&mut self, src: &mut Bytes) -> Result<(), DecodeError> {
159        *self = Message::read(src)?;
160        Ok(())
161    }
162}
163
164impl NativeType for Bytes {
165    const TYPE: WireType = WireType::LengthDelimited;
166
167    #[inline]
168    fn value_len(&self) -> usize {
169        self.len()
170    }
171
172    #[inline]
173    /// Serialize field value
174    fn encode_value(&self, dst: &mut BytesMut) {
175        dst.extend_from_slice(self);
176    }
177
178    #[inline]
179    /// Deserialize from the input
180    fn merge(&mut self, src: &mut Bytes) -> Result<(), DecodeError> {
181        *self = mem::take(src);
182        Ok(())
183    }
184
185    #[inline]
186    fn is_default(&self) -> bool {
187        self.is_empty()
188    }
189}
190
191impl NativeType for String {
192    const TYPE: WireType = WireType::LengthDelimited;
193
194    #[inline]
195    fn value_len(&self) -> usize {
196        self.len()
197    }
198
199    #[inline]
200    fn merge(&mut self, src: &mut Bytes) -> Result<(), DecodeError> {
201        if let Ok(s) = ByteString::try_from(mem::take(src)) {
202            *self = s.as_str().to_string();
203            Ok(())
204        } else {
205            Err(DecodeError::new(
206                "invalid string value: data is not UTF-8 encoded",
207            ))
208        }
209    }
210
211    #[inline]
212    fn encode_value(&self, dst: &mut BytesMut) {
213        dst.extend_from_slice(self.as_bytes());
214    }
215
216    #[inline]
217    fn is_default(&self) -> bool {
218        self.is_empty()
219    }
220}
221
222impl NativeType for ByteString {
223    const TYPE: WireType = WireType::LengthDelimited;
224
225    #[inline]
226    fn value_len(&self) -> usize {
227        self.as_slice().len()
228    }
229
230    #[inline]
231    fn merge(&mut self, src: &mut Bytes) -> Result<(), DecodeError> {
232        if let Ok(s) = ByteString::try_from(mem::take(src)) {
233            *self = s;
234            Ok(())
235        } else {
236            Err(DecodeError::new(
237                "invalid string value: data is not UTF-8 encoded",
238            ))
239        }
240    }
241
242    #[inline]
243    fn encode_value(&self, dst: &mut BytesMut) {
244        dst.extend_from_slice(self.as_bytes());
245    }
246
247    #[inline]
248    fn is_default(&self) -> bool {
249        self.is_empty()
250    }
251}
252
253impl<T: NativeType> NativeType for Option<T> {
254    const TYPE: WireType = WireType::LengthDelimited;
255
256    #[inline]
257    fn is_default(&self) -> bool {
258        self.is_none()
259    }
260
261    #[inline]
262    /// Serialize field value
263    fn encode_value(&self, _: &mut BytesMut) {}
264
265    #[inline]
266    /// Deserialize from the input
267    fn merge(&mut self, _: &mut Bytes) -> Result<(), DecodeError> {
268        Err(DecodeError::new(
269            "Cannot directly call deserialize for Option<T>",
270        ))
271    }
272
273    #[inline]
274    /// Deserialize protobuf field
275    fn deserialize(
276        &mut self,
277        tag: u32,
278        wtype: WireType,
279        src: &mut Bytes,
280    ) -> Result<(), DecodeError> {
281        let mut value: T = Default::default();
282        value.deserialize(tag, wtype, src)?;
283        *self = Some(value);
284        Ok(())
285    }
286
287    #[inline]
288    /// Serialize protobuf field
289    fn serialize(&self, tag: u32, _: DefaultValue<&Self>, dst: &mut BytesMut) {
290        if let Some(ref value) = self {
291            value.serialize(tag, DefaultValue::Unknown, dst);
292        }
293    }
294
295    #[inline]
296    /// Protobuf field length
297    fn serialized_len(&self, tag: u32, _: DefaultValue<&Self>) -> usize {
298        if let Some(ref value) = self {
299            value.serialized_len(tag, DefaultValue::Unknown)
300        } else {
301            0
302        }
303    }
304
305    #[inline]
306    /// Protobuf field length
307    fn encoded_len(&self, tag: u32) -> usize {
308        self.as_ref()
309            .map(|value| value.encoded_len(tag))
310            .unwrap_or(0)
311    }
312}
313
314impl NativeType for Vec<u8> {
315    const TYPE: WireType = WireType::LengthDelimited;
316
317    #[inline]
318    fn value_len(&self) -> usize {
319        self.len()
320    }
321
322    #[inline]
323    /// Serialize field value
324    fn encode_value(&self, dst: &mut BytesMut) {
325        dst.extend_from_slice(self.as_slice());
326    }
327
328    #[inline]
329    /// Deserialize from the input
330    fn merge(&mut self, src: &mut Bytes) -> Result<(), DecodeError> {
331        *self = Vec::from(&src[..]);
332        Ok(())
333    }
334
335    #[inline]
336    fn is_default(&self) -> bool {
337        self.is_empty()
338    }
339}
340
341impl<T: NativeType> NativeType for Vec<T> {
342    const TYPE: WireType = WireType::LengthDelimited;
343
344    #[inline]
345    /// Serialize field value
346    fn encode_value(&self, _: &mut BytesMut) {}
347
348    #[inline]
349    /// Deserialize from the input
350    fn merge(&mut self, _: &mut Bytes) -> Result<(), DecodeError> {
351        Err(DecodeError::new("Cannot directly call merge for Vec<T>"))
352    }
353
354    /// Deserialize protobuf field
355    fn deserialize(
356        &mut self,
357        tag: u32,
358        wtype: WireType,
359        src: &mut Bytes,
360    ) -> Result<(), DecodeError> {
361        if T::TYPE == WireType::Varint {
362            let len = encoding::decode_varint(src)? as usize;
363            let mut buf = src
364                .split_to_checked(len)
365                .ok_or_else(DecodeError::incomplete)?;
366            while !buf.is_empty() {
367                let mut value: T = Default::default();
368                value.merge(&mut buf)?;
369                self.push(value);
370            }
371        } else {
372            let mut value: T = Default::default();
373            value.deserialize(tag, wtype, src)?;
374            self.push(value);
375        }
376        Ok(())
377    }
378
379    /// Serialize protobuf field
380    fn serialize(&self, tag: u32, _: DefaultValue<&Self>, dst: &mut BytesMut) {
381        if T::TYPE == WireType::Varint {
382            encoding::encode_key(tag, WireType::LengthDelimited, dst);
383            encoding::encode_varint(
384                self.iter().map(|v| v.value_len()).sum::<usize>() as u64,
385                dst,
386            );
387            for item in self.iter() {
388                item.encode_value(dst);
389            }
390        } else {
391            for item in self.iter() {
392                item.serialize(tag, DefaultValue::Unknown, dst);
393            }
394        }
395    }
396
397    #[inline]
398    fn is_default(&self) -> bool {
399        self.is_empty()
400    }
401
402    /// Protobuf field length
403    fn encoded_len(&self, tag: u32) -> usize {
404        if T::TYPE == WireType::Varint {
405            let len = self.iter().map(|value| value.value_len()).sum::<usize>();
406            self.iter().map(|value| value.value_len()).sum::<usize>()
407                + encoding::key_len(tag)
408                + encoding::encoded_len_varint(len as u64)
409        } else {
410            self.iter().map(|value| value.encoded_len(tag)).sum()
411        }
412    }
413}
414
415impl<K: NativeType + Eq + Hash, V: NativeType, S: BuildHasher + Default> NativeType
416    for HashMap<K, V, S>
417{
418    const TYPE: WireType = WireType::LengthDelimited;
419
420    #[inline]
421    /// Deserialize from the input
422    fn merge(&mut self, _: &mut Bytes) -> Result<(), DecodeError> {
423        Err(DecodeError::new("Cannot directly call merge for Map<K, V>"))
424    }
425
426    #[inline]
427    /// Serialize field value
428    fn encode_value(&self, _: &mut BytesMut) {}
429
430    #[inline]
431    fn is_default(&self) -> bool {
432        self.is_empty()
433    }
434
435    /// Deserialize protobuf field
436    fn deserialize(
437        &mut self,
438        _: u32,
439        wtype: WireType,
440        src: &mut Bytes,
441    ) -> Result<(), DecodeError> {
442        encoding::check_wire_type(Self::TYPE, wtype)?;
443
444        let len = encoding::decode_varint(src)? as usize;
445        let mut buf = src.split_to_checked(len).ok_or_else(|| {
446            DecodeError::new(format!(
447                "Not enough data for HashMap, message size {}, buf size {}",
448                len,
449                src.len()
450            ))
451        })?;
452        let mut key = Default::default();
453        let mut val = Default::default();
454
455        while !buf.is_empty() {
456            let (tag, wire_type) = encoding::decode_key(&mut buf)?;
457            match tag {
458                1 => NativeType::deserialize(&mut key, 1, wire_type, &mut buf)?,
459                2 => NativeType::deserialize(&mut val, 2, wire_type, &mut buf)?,
460                _ => return Err(DecodeError::new("Map deserialization error")),
461            }
462        }
463        self.insert(key, val);
464        Ok(())
465    }
466
467    /// Serialize protobuf field
468    fn serialize(&self, tag: u32, _: DefaultValue<&Self>, dst: &mut BytesMut) {
469        let key_default = K::default();
470        let val_default = V::default();
471
472        for item in self.iter() {
473            let skip_key = item.0 == &key_default;
474            let skip_val = item.1 == &val_default;
475
476            let len = (if skip_key { 0 } else { item.0.encoded_len(1) })
477                + (if skip_val { 0 } else { item.1.encoded_len(2) });
478
479            encoding::encode_key(tag, WireType::LengthDelimited, dst);
480            encoding::encode_varint(len as u64, dst);
481            if !skip_key {
482                item.0.serialize(1, DefaultValue::Default, dst);
483            }
484            if !skip_val {
485                item.1.serialize(2, DefaultValue::Default, dst);
486            }
487        }
488    }
489
490    /// Generic protobuf map encode function with an overridden value default.
491    fn encoded_len(&self, tag: u32) -> usize {
492        let key_default = K::default();
493        let val_default = V::default();
494
495        self.iter()
496            .map(|(key, val)| {
497                let len = (if key == &key_default {
498                    0
499                } else {
500                    key.encoded_len(1)
501                }) + (if val == &val_default {
502                    0
503                } else {
504                    val.encoded_len(2)
505                });
506
507                encoding::key_len(tag) + encoding::encoded_len_varint(len as u64) + len
508            })
509            .sum::<usize>()
510    }
511}
512
513/// Macro which emits a module containing a set of encoding functions for a
514/// variable width numeric type.
515macro_rules! varint {
516    ($ty:ident, $default:expr) => (
517        varint!($ty, $default, to_uint64(self) { *self as u64 }, from_uint64(v) { v as $ty });
518    );
519
520    ($ty:ty, $default:expr, to_uint64($slf:ident) $to_uint64:expr, from_uint64($val:ident) $from_uint64:expr) => (
521
522        impl NativeType for $ty {
523            const TYPE: WireType = WireType::Varint;
524
525            #[inline]
526            fn is_default(&self) -> bool {
527                *self == $default
528            }
529
530            #[inline]
531            fn encode_value(&$slf, dst: &mut BytesMut) {
532                encoding::encode_varint($to_uint64, dst);
533            }
534
535            #[inline]
536            fn encoded_len(&$slf, tag: u32) -> usize {
537                encoding::key_len(tag) + encoding::encoded_len_varint($to_uint64)
538            }
539
540            #[inline]
541            fn value_len(&$slf) -> usize {
542                encoding::encoded_len_varint($to_uint64)
543            }
544
545            #[inline]
546            fn merge(&mut self, src: &mut Bytes) -> Result<(), DecodeError> {
547                *self = encoding::decode_varint(src).map(|$val| $from_uint64)?;
548                Ok(())
549            }
550        }
551    );
552}
553
554varint!(bool, false,
555        to_uint64(self) u64::from(*self),
556        from_uint64(value) value != 0);
557varint!(i32, 0i32);
558varint!(i64, 0i64);
559varint!(u32, 0u32);
560varint!(u64, 0u64);
561
562/// Macro which emits a module containing a set of encoding functions for a
563/// fixed width numeric type.
564macro_rules! fixed_width {
565    ($ty:ty,
566     $width:expr,
567     $wire_type:expr,
568     $default:expr,
569     $put:expr,
570     $get:expr) => {
571        impl NativeType for $ty {
572            const TYPE: WireType = $wire_type;
573
574            #[inline]
575            fn is_default(&self) -> bool {
576                *self == $default
577            }
578
579            #[inline]
580            fn encode_value(&self, dst: &mut BytesMut) {
581                $put(dst, *self);
582            }
583
584            #[inline]
585            fn encoded_len(&self, tag: u32) -> usize {
586                encoding::key_len(tag) + $width
587            }
588
589            #[inline]
590            fn value_len(&self) -> usize {
591                $width
592            }
593
594            #[inline]
595            fn merge(&mut self, src: &mut Bytes) -> Result<(), DecodeError> {
596                if src.len() < $width {
597                    return Err(DecodeError::new("Buffer underflow"));
598                }
599                *self = $get(src);
600                Ok(())
601            }
602        }
603    };
604}
605
606fixed_width!(
607    f32,
608    4,
609    WireType::ThirtyTwoBit,
610    0f32,
611    BufMut::put_f32_le,
612    Buf::get_f32_le
613);
614fixed_width!(
615    f64,
616    8,
617    WireType::SixtyFourBit,
618    0f64,
619    BufMut::put_f64_le,
620    Buf::get_f64_le
621);
622
#[cfg(test)]
mod tests {
    use super::*;

    /// Hand-written message covering a fixed-width field, a map, a varint
    /// and an optional string.
    #[derive(Clone, PartialEq, Debug, Default)]
    pub struct TestMessage {
        f: f64,
        props: HashMap<String, u32>,
        b: bool,
        opt: Option<String>,
    }

    impl Message for TestMessage {
        #[inline]
        fn read(src: &mut Bytes) -> Result<Self, DecodeError> {
            let mut decoded = Self::default();
            while !src.is_empty() {
                let (tag, wire_type) = encoding::decode_key(src)?;
                match tag {
                    1 => NativeType::deserialize(&mut decoded.f, tag, wire_type, src)?,
                    2 => NativeType::deserialize(&mut decoded.props, tag, wire_type, src)?,
                    3 => NativeType::deserialize(&mut decoded.b, tag, wire_type, src)?,
                    4 => NativeType::deserialize(&mut decoded.opt, tag, wire_type, src)?,
                    // Unknown fields are skipped rather than rejected.
                    _ => encoding::skip_field(wire_type, tag, src)?,
                }
            }
            Ok(decoded)
        }

        fn write(&self, dst: &mut BytesMut) {
            NativeType::serialize(&self.f, 1, DefaultValue::Default, dst);
            NativeType::serialize(&self.props, 2, DefaultValue::Default, dst);
            NativeType::serialize(&self.b, 3, DefaultValue::Default, dst);
            NativeType::serialize(&self.opt, 4, DefaultValue::Default, dst);
        }

        #[inline]
        fn encoded_len(&self) -> usize {
            let f_len = NativeType::serialized_len(&self.f, 1, DefaultValue::Default);
            let props_len = NativeType::serialized_len(&self.props, 2, DefaultValue::Default);
            let b_len = NativeType::serialized_len(&self.b, 3, DefaultValue::Default);
            let opt_len = NativeType::serialized_len(&self.opt, 4, DefaultValue::Default);
            0 + f_len + props_len + b_len + opt_len
        }
    }

    /// Map entries with default keys and/or values must survive a round trip,
    /// both as a bare message and as a length-delimited field.
    #[test]
    fn test_hashmap_default_values() {
        let mut props = HashMap::new();
        props.insert("test1".to_string(), 1);
        props.insert("test2".to_string(), 0);
        props.insert("".to_string(), 0);
        let msg = TestMessage {
            f: 382.8263,
            props,
            b: true,
            opt: None,
        };

        // Bare message body.
        let mut body = BytesMut::new();
        msg.write(&mut body);
        assert_eq!(Message::encoded_len(&msg), 33);
        assert_eq!(body.len(), 33);

        // The same message written as field 1 of an enclosing message.
        let mut field = BytesMut::new();
        NativeType::serialize(&msg, 1, DefaultValue::Default, &mut field);
        assert_eq!(NativeType::encoded_len(&msg, 1), 35);
        assert_eq!(field.len(), 35);

        let from_body = TestMessage::read(&mut body.freeze()).unwrap();
        assert_eq!(Message::encoded_len(&from_body), 33);
        assert_eq!(msg, from_body);

        let mut field = field.freeze();
        let mut from_field = TestMessage::default();
        let (tag, wire_type) = encoding::decode_key(&mut field).unwrap();
        NativeType::deserialize(&mut from_field, tag, wire_type, &mut field).unwrap();
        assert_eq!(msg, from_field);
    }
}