Skip to main content

ntex_grpc/
types.rs

1#![allow(
2    clippy::cast_lossless,
3    clippy::cast_sign_loss,
4    clippy::cast_possible_wrap
5)]
6use std::{collections::HashMap, convert::TryFrom, fmt, hash::BuildHasher, hash::Hash, mem};
7
8use ntex_bytes::{Buf, BufMut, ByteString, Bytes, BytesMut};
9
10pub use crate::encoding::WireType;
11use crate::encoding::{self, DecodeError};
12
13/// Protobuf struct read/write operations
14pub trait Message: Default + Sized + fmt::Debug {
15    /// Decodes an instance of the message from a buffer
16    fn read(src: &mut Bytes) -> Result<Self, DecodeError>;
17
18    /// Encodes and writes the message to a buffer
19    fn write(&self, dst: &mut BytesMut);
20
21    /// Returns the encoded length of the message with a length delimiter
22    fn encoded_len(&self) -> usize;
23}
24
/// Default type value
///
/// Policy passed to [`NativeType::serialize`] / [`NativeType::serialized_len`]
/// deciding when a field value may be left out of the output. This is how the
/// trait's provided implementations interpret it; some implementations
/// (e.g. for `Option`, `Vec`, `HashMap`) ignore the policy entirely.
pub enum DefaultValue<T> {
    /// No default is known: the value is always written.
    Unknown,
    /// Omit the value when its `NativeType::is_default` returns `true`.
    Default,
    /// Omit the value when it compares equal (`==`) to the contained value.
    Value(T),
}
31
32/// Protobuf type serializer
33pub trait NativeType: PartialEq + Default + Sized + fmt::Debug {
34    const TYPE: WireType;
35
36    #[inline]
37    /// Returns the encoded length of the message without a length delimiter.
38    fn value_len(&self) -> usize {
39        0
40    }
41
42    /// Deserialize from the input
43    fn merge(&mut self, src: &mut Bytes) -> Result<(), DecodeError>;
44
45    /// Check if value is default
46    fn is_default(&self) -> bool {
47        false
48    }
49
50    /// Encode field value
51    fn encode_value(&self, dst: &mut BytesMut);
52
53    #[inline]
54    /// Encode field tag and length
55    fn encode_type(&self, tag: u32, dst: &mut BytesMut) {
56        encoding::encode_key(tag, Self::TYPE, dst);
57        if !matches!(Self::TYPE, WireType::Varint | WireType::SixtyFourBit) {
58            encoding::encode_varint(self.value_len() as u64, dst);
59        }
60    }
61
62    #[inline]
63    /// Protobuf field length
64    fn encoded_len(&self, tag: u32) -> usize {
65        let value_len = self.value_len();
66        encoding::key_len(tag) + encoding::encoded_len_varint(value_len as u64) + value_len
67    }
68
69    #[inline]
70    /// Serialize protobuf field
71    fn serialize(&self, tag: u32, default: DefaultValue<&Self>, dst: &mut BytesMut) {
72        let default = match default {
73            DefaultValue::Unknown => false,
74            DefaultValue::Default => self.is_default(),
75            DefaultValue::Value(d) => self == d,
76        };
77
78        if !default {
79            self.encode_type(tag, dst);
80            self.encode_value(dst);
81        }
82    }
83
84    #[inline]
85    /// Protobuf field length
86    fn serialized_len(&self, tag: u32, default: DefaultValue<&Self>) -> usize {
87        let default = match default {
88            DefaultValue::Unknown => false,
89            DefaultValue::Default => self.is_default(),
90            DefaultValue::Value(d) => self == d,
91        };
92
93        if default { 0 } else { self.encoded_len(tag) }
94    }
95
96    #[inline]
97    /// Deserialize protobuf field
98    fn deserialize(
99        &mut self,
100        _: u32,
101        wtype: WireType,
102        src: &mut Bytes,
103    ) -> Result<(), DecodeError> {
104        encoding::check_wire_type(Self::TYPE, wtype)?;
105
106        if matches!(Self::TYPE, WireType::Varint | WireType::SixtyFourBit) {
107            self.merge(src)
108        } else {
109            let len = encoding::decode_varint(src)? as usize;
110            let mut buf = src.split_to_checked(len).ok_or_else(|| {
111                DecodeError::new(format!(
112                    "Not enough data, message size {} buffer size {}",
113                    len,
114                    src.len()
115                ))
116            })?;
117            self.merge(&mut buf)
118        }
119    }
120
121    #[inline]
122    /// Deserialize protobuf field to default value
123    fn deserialize_default(
124        tag: u32,
125        wtype: WireType,
126        src: &mut Bytes,
127    ) -> Result<Self, DecodeError> {
128        let mut value = Self::default();
129        value.deserialize(tag, wtype, src)?;
130        Ok(value)
131    }
132}
133
134/// Protobuf struct read/write operations
135impl Message for () {
136    fn encoded_len(&self) -> usize {
137        0
138    }
139    fn read(_: &mut Bytes) -> Result<Self, DecodeError> {
140        Ok(())
141    }
142    fn write(&self, _: &mut BytesMut) {}
143}
144
145impl<T: Message + PartialEq> NativeType for T {
146    const TYPE: WireType = WireType::LengthDelimited;
147
148    fn value_len(&self) -> usize {
149        Message::encoded_len(self)
150    }
151
152    #[inline]
153    /// Encode message to the buffer
154    fn encode_value(&self, dst: &mut BytesMut) {
155        self.write(dst);
156    }
157
158    /// Deserialize from the input
159    fn merge(&mut self, src: &mut Bytes) -> Result<(), DecodeError> {
160        *self = Message::read(src)?;
161        Ok(())
162    }
163}
164
165impl NativeType for Bytes {
166    const TYPE: WireType = WireType::LengthDelimited;
167
168    #[inline]
169    fn value_len(&self) -> usize {
170        self.len()
171    }
172
173    #[inline]
174    /// Serialize field value
175    fn encode_value(&self, dst: &mut BytesMut) {
176        dst.extend_from_slice(self);
177    }
178
179    #[inline]
180    /// Deserialize from the input
181    fn merge(&mut self, src: &mut Bytes) -> Result<(), DecodeError> {
182        *self = mem::take(src);
183        Ok(())
184    }
185
186    #[inline]
187    fn is_default(&self) -> bool {
188        self.is_empty()
189    }
190}
191
192impl NativeType for String {
193    const TYPE: WireType = WireType::LengthDelimited;
194
195    #[inline]
196    fn value_len(&self) -> usize {
197        self.len()
198    }
199
200    #[inline]
201    fn merge(&mut self, src: &mut Bytes) -> Result<(), DecodeError> {
202        if let Ok(s) = ByteString::try_from(mem::take(src)) {
203            *self = s.as_str().to_string();
204            Ok(())
205        } else {
206            Err(DecodeError::new(
207                "invalid string value: data is not UTF-8 encoded",
208            ))
209        }
210    }
211
212    #[inline]
213    fn encode_value(&self, dst: &mut BytesMut) {
214        dst.extend_from_slice(self.as_bytes());
215    }
216
217    #[inline]
218    fn is_default(&self) -> bool {
219        self.is_empty()
220    }
221}
222
223impl NativeType for ByteString {
224    const TYPE: WireType = WireType::LengthDelimited;
225
226    #[inline]
227    fn value_len(&self) -> usize {
228        self.as_slice().len()
229    }
230
231    #[inline]
232    fn merge(&mut self, src: &mut Bytes) -> Result<(), DecodeError> {
233        if let Ok(s) = ByteString::try_from(mem::take(src)) {
234            *self = s;
235            Ok(())
236        } else {
237            Err(DecodeError::new(
238                "invalid string value: data is not UTF-8 encoded",
239            ))
240        }
241    }
242
243    #[inline]
244    fn encode_value(&self, dst: &mut BytesMut) {
245        dst.extend_from_slice(self.as_bytes());
246    }
247
248    #[inline]
249    fn is_default(&self) -> bool {
250        self.is_empty()
251    }
252}
253
254impl<T: NativeType> NativeType for Option<T> {
255    const TYPE: WireType = WireType::LengthDelimited;
256
257    #[inline]
258    fn is_default(&self) -> bool {
259        self.is_none()
260    }
261
262    #[inline]
263    /// Serialize field value
264    fn encode_value(&self, _: &mut BytesMut) {}
265
266    #[inline]
267    /// Deserialize from the input
268    fn merge(&mut self, _: &mut Bytes) -> Result<(), DecodeError> {
269        Err(DecodeError::new(
270            "Cannot directly call deserialize for Option<T>",
271        ))
272    }
273
274    #[inline]
275    /// Deserialize protobuf field
276    fn deserialize(
277        &mut self,
278        tag: u32,
279        wtype: WireType,
280        src: &mut Bytes,
281    ) -> Result<(), DecodeError> {
282        let mut value: T = Default::default();
283        value.deserialize(tag, wtype, src)?;
284        *self = Some(value);
285        Ok(())
286    }
287
288    #[inline]
289    /// Serialize protobuf field
290    fn serialize(&self, tag: u32, _: DefaultValue<&Self>, dst: &mut BytesMut) {
291        if let Some(value) = self {
292            value.serialize(tag, DefaultValue::Unknown, dst);
293        }
294    }
295
296    #[inline]
297    /// Protobuf field length
298    fn serialized_len(&self, tag: u32, _: DefaultValue<&Self>) -> usize {
299        if let Some(value) = self {
300            value.serialized_len(tag, DefaultValue::Unknown)
301        } else {
302            0
303        }
304    }
305
306    #[inline]
307    /// Protobuf field length
308    fn encoded_len(&self, tag: u32) -> usize {
309        self.as_ref().map_or(0, |value| value.encoded_len(tag))
310    }
311}
312
313impl NativeType for Vec<u8> {
314    const TYPE: WireType = WireType::LengthDelimited;
315
316    #[inline]
317    fn value_len(&self) -> usize {
318        self.len()
319    }
320
321    #[inline]
322    /// Serialize field value
323    fn encode_value(&self, dst: &mut BytesMut) {
324        dst.extend_from_slice(self.as_slice());
325    }
326
327    #[inline]
328    /// Deserialize from the input
329    fn merge(&mut self, src: &mut Bytes) -> Result<(), DecodeError> {
330        *self = Vec::from(&src[..]);
331        Ok(())
332    }
333
334    #[inline]
335    fn is_default(&self) -> bool {
336        self.is_empty()
337    }
338}
339
340impl<T: NativeType> NativeType for Vec<T> {
341    const TYPE: WireType = WireType::LengthDelimited;
342
343    #[inline]
344    /// Serialize field value
345    fn encode_value(&self, _: &mut BytesMut) {}
346
347    #[inline]
348    /// Deserialize from the input
349    fn merge(&mut self, _: &mut Bytes) -> Result<(), DecodeError> {
350        Err(DecodeError::new("Cannot directly call merge for Vec<T>"))
351    }
352
353    /// Deserialize protobuf field
354    fn deserialize(
355        &mut self,
356        tag: u32,
357        wtype: WireType,
358        src: &mut Bytes,
359    ) -> Result<(), DecodeError> {
360        if T::TYPE == WireType::Varint {
361            let len = encoding::decode_varint(src)? as usize;
362            let mut buf = src
363                .split_to_checked(len)
364                .ok_or_else(DecodeError::incomplete)?;
365            while !buf.is_empty() {
366                let mut value: T = Default::default();
367                value.merge(&mut buf)?;
368                self.push(value);
369            }
370        } else {
371            let mut value: T = Default::default();
372            value.deserialize(tag, wtype, src)?;
373            self.push(value);
374        }
375        Ok(())
376    }
377
378    /// Serialize protobuf field
379    fn serialize(&self, tag: u32, _: DefaultValue<&Self>, dst: &mut BytesMut) {
380        if self.is_empty() {
381            return;
382        }
383        if T::TYPE == WireType::Varint {
384            encoding::encode_key(tag, WireType::LengthDelimited, dst);
385            encoding::encode_varint(
386                self.iter().map(NativeType::value_len).sum::<usize>() as u64,
387                dst,
388            );
389            for item in self {
390                item.encode_value(dst);
391            }
392        } else {
393            for item in self {
394                item.serialize(tag, DefaultValue::Unknown, dst);
395            }
396        }
397    }
398
399    #[inline]
400    fn is_default(&self) -> bool {
401        self.is_empty()
402    }
403
404    /// Protobuf field length
405    fn encoded_len(&self, tag: u32) -> usize {
406        if T::TYPE == WireType::Varint {
407            let len = self.iter().map(NativeType::value_len).sum::<usize>();
408            self.iter().map(NativeType::value_len).sum::<usize>()
409                + encoding::key_len(tag)
410                + encoding::encoded_len_varint(len as u64)
411        } else {
412            self.iter().map(|value| value.encoded_len(tag)).sum()
413        }
414    }
415}
416
417impl<K: NativeType + Eq + Hash, V: NativeType, S: BuildHasher + Default> NativeType
418    for HashMap<K, V, S>
419{
420    const TYPE: WireType = WireType::LengthDelimited;
421
422    #[inline]
423    /// Deserialize from the input
424    fn merge(&mut self, _: &mut Bytes) -> Result<(), DecodeError> {
425        Err(DecodeError::new("Cannot directly call merge for Map<K, V>"))
426    }
427
428    #[inline]
429    /// Serialize field value
430    fn encode_value(&self, _: &mut BytesMut) {}
431
432    #[inline]
433    fn is_default(&self) -> bool {
434        self.is_empty()
435    }
436
437    /// Deserialize protobuf field
438    fn deserialize(
439        &mut self,
440        _: u32,
441        wtype: WireType,
442        src: &mut Bytes,
443    ) -> Result<(), DecodeError> {
444        encoding::check_wire_type(Self::TYPE, wtype)?;
445
446        let len = encoding::decode_varint(src)? as usize;
447        let mut buf = src.split_to_checked(len).ok_or_else(|| {
448            DecodeError::new(format!(
449                "Not enough data for HashMap, message size {}, buf size {}",
450                len,
451                src.len()
452            ))
453        })?;
454        let mut key = Default::default();
455        let mut val = Default::default();
456
457        while !buf.is_empty() {
458            let (tag, wire_type) = encoding::decode_key(&mut buf)?;
459            match tag {
460                1 => NativeType::deserialize(&mut key, 1, wire_type, &mut buf)?,
461                2 => NativeType::deserialize(&mut val, 2, wire_type, &mut buf)?,
462                _ => return Err(DecodeError::new("Map deserialization error")),
463            }
464        }
465        self.insert(key, val);
466        Ok(())
467    }
468
469    /// Serialize protobuf field
470    fn serialize(&self, tag: u32, _: DefaultValue<&Self>, dst: &mut BytesMut) {
471        let key_default = K::default();
472        let val_default = V::default();
473
474        for item in self {
475            let skip_key = item.0 == &key_default;
476            let skip_val = item.1 == &val_default;
477
478            let len = (if skip_key { 0 } else { item.0.encoded_len(1) })
479                + (if skip_val { 0 } else { item.1.encoded_len(2) });
480
481            encoding::encode_key(tag, WireType::LengthDelimited, dst);
482            encoding::encode_varint(len as u64, dst);
483            if !skip_key {
484                item.0.serialize(1, DefaultValue::Default, dst);
485            }
486            if !skip_val {
487                item.1.serialize(2, DefaultValue::Default, dst);
488            }
489        }
490    }
491
492    /// Generic protobuf map encode function with an overridden value default.
493    fn encoded_len(&self, tag: u32) -> usize {
494        let key_default = K::default();
495        let val_default = V::default();
496
497        self.iter()
498            .map(|(key, val)| {
499                let len = (if key == &key_default {
500                    0
501                } else {
502                    key.encoded_len(1)
503                }) + (if val == &val_default {
504                    0
505                } else {
506                    val.encoded_len(2)
507                });
508
509                encoding::key_len(tag) + encoding::encoded_len_varint(len as u64) + len
510            })
511            .sum::<usize>()
512    }
513}
514
515/// Macro which emits a module containing a set of encoding functions for a
516/// variable width numeric type.
517macro_rules! varint {
518    ($ty:ident, $default:expr) => (
519        varint!($ty, $default, to_uint64(self) { *self as u64 }, from_uint64(v) { v as $ty });
520    );
521
522    ($ty:ty, $default:expr, to_uint64($slf:ident) $to_uint64:expr, from_uint64($val:ident) $from_uint64:expr) => (
523
524        impl NativeType for $ty {
525            const TYPE: WireType = WireType::Varint;
526
527            #[inline]
528            fn is_default(&self) -> bool {
529                *self == $default
530            }
531
532            #[inline]
533            fn encode_value(&$slf, dst: &mut BytesMut) {
534                encoding::encode_varint($to_uint64, dst);
535            }
536
537            #[inline]
538            fn encoded_len(&$slf, tag: u32) -> usize {
539                encoding::key_len(tag) + encoding::encoded_len_varint($to_uint64)
540            }
541
542            #[inline]
543            fn value_len(&$slf) -> usize {
544                encoding::encoded_len_varint($to_uint64)
545            }
546
547            #[inline]
548            fn merge(&mut self, src: &mut Bytes) -> Result<(), DecodeError> {
549                *self = encoding::decode_varint(src).map(|$val| $from_uint64)?;
550                Ok(())
551            }
552        }
553    );
554}
555
556varint!(bool, false,
557        to_uint64(self) u64::from(*self),
558        from_uint64(value) value != 0);
559varint!(i32, 0i32);
560varint!(i64, 0i64);
561varint!(u32, 0u32);
562varint!(u64, 0u64);
563
564/// Macro which emits a module containing a set of encoding functions for a
565/// fixed width numeric type.
566macro_rules! fixed_width {
567    ($ty:ty,
568     $width:expr,
569     $wire_type:expr,
570     $default:expr,
571     $put:expr,
572     $get:expr) => {
573        impl NativeType for $ty {
574            const TYPE: WireType = $wire_type;
575
576            #[inline]
577            fn is_default(&self) -> bool {
578                *self == $default
579            }
580
581            #[inline]
582            fn encode_value(&self, dst: &mut BytesMut) {
583                $put(dst, *self);
584            }
585
586            #[inline]
587            fn encoded_len(&self, tag: u32) -> usize {
588                encoding::key_len(tag) + $width
589            }
590
591            #[inline]
592            fn value_len(&self) -> usize {
593                $width
594            }
595
596            #[inline]
597            fn merge(&mut self, src: &mut Bytes) -> Result<(), DecodeError> {
598                if src.len() < $width {
599                    return Err(DecodeError::new("Buffer underflow"));
600                }
601                *self = $get(src);
602                Ok(())
603            }
604        }
605    };
606}
607
608fixed_width!(
609    f32,
610    4,
611    WireType::ThirtyTwoBit,
612    0f32,
613    BufMut::put_f32_le,
614    Buf::get_f32_le
615);
616fixed_width!(
617    f64,
618    8,
619    WireType::SixtyFourBit,
620    0f64,
621    BufMut::put_f64_le,
622    Buf::get_f64_le
623);
624
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Clone, PartialEq, Debug, Default)]
    pub struct TestMessage {
        f: f64,
        props: HashMap<String, u32>,
        b: bool,
        opt: Option<String>,
    }

    impl Message for TestMessage {
        #[inline]
        fn read(src: &mut Bytes) -> Result<Self, DecodeError> {
            let mut decoded = Self::default();
            while !src.is_empty() {
                let (tag, wire_type) = encoding::decode_key(src)?;
                match tag {
                    1 => NativeType::deserialize(&mut decoded.f, tag, wire_type, src)?,
                    2 => NativeType::deserialize(&mut decoded.props, tag, wire_type, src)?,
                    3 => NativeType::deserialize(&mut decoded.b, tag, wire_type, src)?,
                    4 => NativeType::deserialize(&mut decoded.opt, tag, wire_type, src)?,
                    _ => encoding::skip_field(wire_type, tag, src)?,
                }
            }
            Ok(decoded)
        }

        fn write(&self, dst: &mut BytesMut) {
            NativeType::serialize(&self.f, 1, DefaultValue::Default, dst);
            NativeType::serialize(&self.props, 2, DefaultValue::Default, dst);
            NativeType::serialize(&self.b, 3, DefaultValue::Default, dst);
            NativeType::serialize(&self.opt, 4, DefaultValue::Default, dst);
        }

        #[inline]
        fn encoded_len(&self) -> usize {
            NativeType::serialized_len(&self.f, 1, DefaultValue::Default)
                + NativeType::serialized_len(&self.props, 2, DefaultValue::Default)
                + NativeType::serialized_len(&self.b, 3, DefaultValue::Default)
                + NativeType::serialized_len(&self.opt, 4, DefaultValue::Default)
        }
    }

    #[test]
    fn test_hashmap_default_values() {
        // Includes entries whose key and/or value equal the type default, so
        // the map encoder's field-skipping path is exercised.
        let props: HashMap<String, u32> = [("test1", 1), ("test2", 0), ("", 0)]
            .iter()
            .map(|&(key, val)| (key.to_string(), val))
            .collect();
        let msg = TestMessage {
            f: 382.8263,
            b: true,
            props,
            ..TestMessage::default()
        };

        // Bare message body.
        let mut plain = BytesMut::new();
        msg.write(&mut plain);
        assert_eq!(Message::encoded_len(&msg), 33);
        assert_eq!(plain.len(), 33);

        // Same message framed as field 1 (key + length prefix + body).
        let mut framed = BytesMut::new();
        msg.serialize(1, DefaultValue::Default, &mut framed);
        assert_eq!(NativeType::encoded_len(&msg, 1), 35);
        assert_eq!(framed.len(), 35);

        let from_plain = TestMessage::read(&mut plain.freeze()).unwrap();
        assert_eq!(Message::encoded_len(&from_plain), 33);
        assert_eq!(msg, from_plain);

        let mut framed = framed.freeze();
        let mut from_framed = TestMessage::default();
        let (tag, wire_type) = encoding::decode_key(&mut framed).unwrap();
        from_framed.deserialize(tag, wire_type, &mut framed).unwrap();
        assert_eq!(msg, from_framed);
    }
}
702}