1#![allow(
2 clippy::cast_lossless,
3 clippy::cast_sign_loss,
4 clippy::cast_possible_wrap
5)]
6use std::{collections::HashMap, convert::TryFrom, fmt, hash::BuildHasher, hash::Hash, mem};
7
8use ntex_bytes::{Buf, BufMut, ByteString, Bytes, BytesMut};
9
10pub use crate::encoding::WireType;
11use crate::encoding::{self, DecodeError};
12
13pub trait Message: Default + Sized + fmt::Debug {
15 fn read(src: &mut Bytes) -> Result<Self, DecodeError>;
17
18 fn write(&self, dst: &mut BytesMut);
20
21 fn encoded_len(&self) -> usize;
23}
24
/// Describes the default a field is compared against when deciding whether it
/// can be omitted from the encoded output.
///
/// Derives `Clone`/`Copy` because it is passed by value (typically as
/// `DefaultValue<&T>`), and `Debug` so it can appear in diagnostics.
#[derive(Debug, Clone, Copy)]
pub enum DefaultValue<T> {
    /// No default is known: the field is always written.
    Unknown,
    /// The field is skipped when it reports itself as the type's default
    /// (`NativeType::is_default`).
    Default,
    /// The field is skipped when it compares equal to the given value.
    Value(T),
}
31
32pub trait NativeType: PartialEq + Default + Sized + fmt::Debug {
34 const TYPE: WireType;
35
36 #[inline]
37 fn value_len(&self) -> usize {
39 0
40 }
41
42 fn merge(&mut self, src: &mut Bytes) -> Result<(), DecodeError>;
44
45 fn is_default(&self) -> bool {
47 false
48 }
49
50 fn encode_value(&self, dst: &mut BytesMut);
52
53 #[inline]
54 fn encode_type(&self, tag: u32, dst: &mut BytesMut) {
56 encoding::encode_key(tag, Self::TYPE, dst);
57 if !matches!(Self::TYPE, WireType::Varint | WireType::SixtyFourBit) {
58 encoding::encode_varint(self.value_len() as u64, dst);
59 }
60 }
61
62 #[inline]
63 fn encoded_len(&self, tag: u32) -> usize {
65 let value_len = self.value_len();
66 encoding::key_len(tag) + encoding::encoded_len_varint(value_len as u64) + value_len
67 }
68
69 #[inline]
70 fn serialize(&self, tag: u32, default: DefaultValue<&Self>, dst: &mut BytesMut) {
72 let default = match default {
73 DefaultValue::Unknown => false,
74 DefaultValue::Default => self.is_default(),
75 DefaultValue::Value(d) => self == d,
76 };
77
78 if !default {
79 self.encode_type(tag, dst);
80 self.encode_value(dst);
81 }
82 }
83
84 #[inline]
85 fn serialized_len(&self, tag: u32, default: DefaultValue<&Self>) -> usize {
87 let default = match default {
88 DefaultValue::Unknown => false,
89 DefaultValue::Default => self.is_default(),
90 DefaultValue::Value(d) => self == d,
91 };
92
93 if default { 0 } else { self.encoded_len(tag) }
94 }
95
96 #[inline]
97 fn deserialize(
99 &mut self,
100 _: u32,
101 wtype: WireType,
102 src: &mut Bytes,
103 ) -> Result<(), DecodeError> {
104 encoding::check_wire_type(Self::TYPE, wtype)?;
105
106 if matches!(Self::TYPE, WireType::Varint | WireType::SixtyFourBit) {
107 self.merge(src)
108 } else {
109 let len = encoding::decode_varint(src)? as usize;
110 let mut buf = src.split_to_checked(len).ok_or_else(|| {
111 DecodeError::new(format!(
112 "Not enough data, message size {} buffer size {}",
113 len,
114 src.len()
115 ))
116 })?;
117 self.merge(&mut buf)
118 }
119 }
120
121 #[inline]
122 fn deserialize_default(
124 tag: u32,
125 wtype: WireType,
126 src: &mut Bytes,
127 ) -> Result<Self, DecodeError> {
128 let mut value = Self::default();
129 value.deserialize(tag, wtype, src)?;
130 Ok(value)
131 }
132}
133
134impl Message for () {
136 fn encoded_len(&self) -> usize {
137 0
138 }
139 fn read(_: &mut Bytes) -> Result<Self, DecodeError> {
140 Ok(())
141 }
142 fn write(&self, _: &mut BytesMut) {}
143}
144
145impl<T: Message + PartialEq> NativeType for T {
146 const TYPE: WireType = WireType::LengthDelimited;
147
148 fn value_len(&self) -> usize {
149 Message::encoded_len(self)
150 }
151
152 #[inline]
153 fn encode_value(&self, dst: &mut BytesMut) {
155 self.write(dst);
156 }
157
158 fn merge(&mut self, src: &mut Bytes) -> Result<(), DecodeError> {
160 *self = Message::read(src)?;
161 Ok(())
162 }
163}
164
165impl NativeType for Bytes {
166 const TYPE: WireType = WireType::LengthDelimited;
167
168 #[inline]
169 fn value_len(&self) -> usize {
170 self.len()
171 }
172
173 #[inline]
174 fn encode_value(&self, dst: &mut BytesMut) {
176 dst.extend_from_slice(self);
177 }
178
179 #[inline]
180 fn merge(&mut self, src: &mut Bytes) -> Result<(), DecodeError> {
182 *self = mem::take(src);
183 Ok(())
184 }
185
186 #[inline]
187 fn is_default(&self) -> bool {
188 self.is_empty()
189 }
190}
191
192impl NativeType for String {
193 const TYPE: WireType = WireType::LengthDelimited;
194
195 #[inline]
196 fn value_len(&self) -> usize {
197 self.len()
198 }
199
200 #[inline]
201 fn merge(&mut self, src: &mut Bytes) -> Result<(), DecodeError> {
202 if let Ok(s) = ByteString::try_from(mem::take(src)) {
203 *self = s.as_str().to_string();
204 Ok(())
205 } else {
206 Err(DecodeError::new(
207 "invalid string value: data is not UTF-8 encoded",
208 ))
209 }
210 }
211
212 #[inline]
213 fn encode_value(&self, dst: &mut BytesMut) {
214 dst.extend_from_slice(self.as_bytes());
215 }
216
217 #[inline]
218 fn is_default(&self) -> bool {
219 self.is_empty()
220 }
221}
222
223impl NativeType for ByteString {
224 const TYPE: WireType = WireType::LengthDelimited;
225
226 #[inline]
227 fn value_len(&self) -> usize {
228 self.as_slice().len()
229 }
230
231 #[inline]
232 fn merge(&mut self, src: &mut Bytes) -> Result<(), DecodeError> {
233 if let Ok(s) = ByteString::try_from(mem::take(src)) {
234 *self = s;
235 Ok(())
236 } else {
237 Err(DecodeError::new(
238 "invalid string value: data is not UTF-8 encoded",
239 ))
240 }
241 }
242
243 #[inline]
244 fn encode_value(&self, dst: &mut BytesMut) {
245 dst.extend_from_slice(self.as_bytes());
246 }
247
248 #[inline]
249 fn is_default(&self) -> bool {
250 self.is_empty()
251 }
252}
253
254impl<T: NativeType> NativeType for Option<T> {
255 const TYPE: WireType = WireType::LengthDelimited;
256
257 #[inline]
258 fn is_default(&self) -> bool {
259 self.is_none()
260 }
261
262 #[inline]
263 fn encode_value(&self, _: &mut BytesMut) {}
265
266 #[inline]
267 fn merge(&mut self, _: &mut Bytes) -> Result<(), DecodeError> {
269 Err(DecodeError::new(
270 "Cannot directly call deserialize for Option<T>",
271 ))
272 }
273
274 #[inline]
275 fn deserialize(
277 &mut self,
278 tag: u32,
279 wtype: WireType,
280 src: &mut Bytes,
281 ) -> Result<(), DecodeError> {
282 let mut value: T = Default::default();
283 value.deserialize(tag, wtype, src)?;
284 *self = Some(value);
285 Ok(())
286 }
287
288 #[inline]
289 fn serialize(&self, tag: u32, _: DefaultValue<&Self>, dst: &mut BytesMut) {
291 if let Some(value) = self {
292 value.serialize(tag, DefaultValue::Unknown, dst);
293 }
294 }
295
296 #[inline]
297 fn serialized_len(&self, tag: u32, _: DefaultValue<&Self>) -> usize {
299 if let Some(value) = self {
300 value.serialized_len(tag, DefaultValue::Unknown)
301 } else {
302 0
303 }
304 }
305
306 #[inline]
307 fn encoded_len(&self, tag: u32) -> usize {
309 self.as_ref().map_or(0, |value| value.encoded_len(tag))
310 }
311}
312
313impl NativeType for Vec<u8> {
314 const TYPE: WireType = WireType::LengthDelimited;
315
316 #[inline]
317 fn value_len(&self) -> usize {
318 self.len()
319 }
320
321 #[inline]
322 fn encode_value(&self, dst: &mut BytesMut) {
324 dst.extend_from_slice(self.as_slice());
325 }
326
327 #[inline]
328 fn merge(&mut self, src: &mut Bytes) -> Result<(), DecodeError> {
330 *self = Vec::from(&src[..]);
331 Ok(())
332 }
333
334 #[inline]
335 fn is_default(&self) -> bool {
336 self.is_empty()
337 }
338}
339
340impl<T: NativeType> NativeType for Vec<T> {
341 const TYPE: WireType = WireType::LengthDelimited;
342
343 #[inline]
344 fn encode_value(&self, _: &mut BytesMut) {}
346
347 #[inline]
348 fn merge(&mut self, _: &mut Bytes) -> Result<(), DecodeError> {
350 Err(DecodeError::new("Cannot directly call merge for Vec<T>"))
351 }
352
353 fn deserialize(
355 &mut self,
356 tag: u32,
357 wtype: WireType,
358 src: &mut Bytes,
359 ) -> Result<(), DecodeError> {
360 if T::TYPE == WireType::Varint {
361 let len = encoding::decode_varint(src)? as usize;
362 let mut buf = src
363 .split_to_checked(len)
364 .ok_or_else(DecodeError::incomplete)?;
365 while !buf.is_empty() {
366 let mut value: T = Default::default();
367 value.merge(&mut buf)?;
368 self.push(value);
369 }
370 } else {
371 let mut value: T = Default::default();
372 value.deserialize(tag, wtype, src)?;
373 self.push(value);
374 }
375 Ok(())
376 }
377
378 fn serialize(&self, tag: u32, _: DefaultValue<&Self>, dst: &mut BytesMut) {
380 if self.is_empty() {
381 return;
382 }
383 if T::TYPE == WireType::Varint {
384 encoding::encode_key(tag, WireType::LengthDelimited, dst);
385 encoding::encode_varint(
386 self.iter().map(NativeType::value_len).sum::<usize>() as u64,
387 dst,
388 );
389 for item in self {
390 item.encode_value(dst);
391 }
392 } else {
393 for item in self {
394 item.serialize(tag, DefaultValue::Unknown, dst);
395 }
396 }
397 }
398
399 #[inline]
400 fn is_default(&self) -> bool {
401 self.is_empty()
402 }
403
404 fn encoded_len(&self, tag: u32) -> usize {
406 if T::TYPE == WireType::Varint {
407 let len = self.iter().map(NativeType::value_len).sum::<usize>();
408 self.iter().map(NativeType::value_len).sum::<usize>()
409 + encoding::key_len(tag)
410 + encoding::encoded_len_varint(len as u64)
411 } else {
412 self.iter().map(|value| value.encoded_len(tag)).sum()
413 }
414 }
415}
416
417impl<K: NativeType + Eq + Hash, V: NativeType, S: BuildHasher + Default> NativeType
418 for HashMap<K, V, S>
419{
420 const TYPE: WireType = WireType::LengthDelimited;
421
422 #[inline]
423 fn merge(&mut self, _: &mut Bytes) -> Result<(), DecodeError> {
425 Err(DecodeError::new("Cannot directly call merge for Map<K, V>"))
426 }
427
428 #[inline]
429 fn encode_value(&self, _: &mut BytesMut) {}
431
432 #[inline]
433 fn is_default(&self) -> bool {
434 self.is_empty()
435 }
436
437 fn deserialize(
439 &mut self,
440 _: u32,
441 wtype: WireType,
442 src: &mut Bytes,
443 ) -> Result<(), DecodeError> {
444 encoding::check_wire_type(Self::TYPE, wtype)?;
445
446 let len = encoding::decode_varint(src)? as usize;
447 let mut buf = src.split_to_checked(len).ok_or_else(|| {
448 DecodeError::new(format!(
449 "Not enough data for HashMap, message size {}, buf size {}",
450 len,
451 src.len()
452 ))
453 })?;
454 let mut key = Default::default();
455 let mut val = Default::default();
456
457 while !buf.is_empty() {
458 let (tag, wire_type) = encoding::decode_key(&mut buf)?;
459 match tag {
460 1 => NativeType::deserialize(&mut key, 1, wire_type, &mut buf)?,
461 2 => NativeType::deserialize(&mut val, 2, wire_type, &mut buf)?,
462 _ => return Err(DecodeError::new("Map deserialization error")),
463 }
464 }
465 self.insert(key, val);
466 Ok(())
467 }
468
469 fn serialize(&self, tag: u32, _: DefaultValue<&Self>, dst: &mut BytesMut) {
471 let key_default = K::default();
472 let val_default = V::default();
473
474 for item in self {
475 let skip_key = item.0 == &key_default;
476 let skip_val = item.1 == &val_default;
477
478 let len = (if skip_key { 0 } else { item.0.encoded_len(1) })
479 + (if skip_val { 0 } else { item.1.encoded_len(2) });
480
481 encoding::encode_key(tag, WireType::LengthDelimited, dst);
482 encoding::encode_varint(len as u64, dst);
483 if !skip_key {
484 item.0.serialize(1, DefaultValue::Default, dst);
485 }
486 if !skip_val {
487 item.1.serialize(2, DefaultValue::Default, dst);
488 }
489 }
490 }
491
492 fn encoded_len(&self, tag: u32) -> usize {
494 let key_default = K::default();
495 let val_default = V::default();
496
497 self.iter()
498 .map(|(key, val)| {
499 let len = (if key == &key_default {
500 0
501 } else {
502 key.encoded_len(1)
503 }) + (if val == &val_default {
504 0
505 } else {
506 val.encoded_len(2)
507 });
508
509 encoding::key_len(tag) + encoding::encoded_len_varint(len as u64) + len
510 })
511 .sum::<usize>()
512 }
513}
514
/// Implements `NativeType` for a scalar encoded with the `Varint` wire type.
///
/// The value is converted to a `u64` for encoding and converted back from the
/// decoded `u64` on read; `$default` is the value `is_default` compares to.
macro_rules! varint {
    // Shorthand for integer types: plain `as` casts in both directions
    // (signed values are sign-extended to 64 bits on encode and truncated
    // back to `$ty` on decode).
    ($ty:ident, $default:expr) => (
        varint!($ty, $default, to_uint64(self) { *self as u64 }, from_uint64(v) { v as $ty });
    );

    // Full form: `$to_uint64` is an expression over the receiver (named by
    // `$slf`) producing the u64 to encode; `$from_uint64` maps the decoded
    // u64 (bound to `$val`) back to `$ty`.
    // NOTE(review): the receiver is spelled `&$slf` presumably so that `self`
    // inside the caller-supplied expression resolves to it under macro
    // hygiene.
    ($ty:ty, $default:expr, to_uint64($slf:ident) $to_uint64:expr, from_uint64($val:ident) $from_uint64:expr) => (

        impl NativeType for $ty {
            const TYPE: WireType = WireType::Varint;

            #[inline]
            fn is_default(&self) -> bool {
                *self == $default
            }

            #[inline]
            fn encode_value(&$slf, dst: &mut BytesMut) {
                encoding::encode_varint($to_uint64, dst);
            }

            // Key plus varint value; varints carry no length prefix.
            #[inline]
            fn encoded_len(&$slf, tag: u32) -> usize {
                encoding::key_len(tag) + encoding::encoded_len_varint($to_uint64)
            }

            #[inline]
            fn value_len(&$slf) -> usize {
                encoding::encoded_len_varint($to_uint64)
            }

            // Reads one varint from `src` and replaces the current value.
            #[inline]
            fn merge(&mut self, src: &mut Bytes) -> Result<(), DecodeError> {
                *self = encoding::decode_varint(src).map(|$val| $from_uint64)?;
                Ok(())
            }
        }
    );
}
555
556varint!(bool, false,
557 to_uint64(self) u64::from(*self),
558 from_uint64(value) value != 0);
559varint!(i32, 0i32);
560varint!(i64, 0i64);
561varint!(u32, 0u32);
562varint!(u64, 0u64);
563
/// Implements `NativeType` for a fixed-width scalar.
///
/// * `$ty` – the Rust type; `$width` – encoded size in bytes;
/// * `$wire_type` – its protobuf wire type; `$default` – value `is_default`
///   compares against;
/// * `$put` / `$get` – buffer writer / reader for exactly `$width` bytes.
macro_rules! fixed_width {
    ($ty:ty, $width:expr, $wire_type:expr, $default:expr, $put:expr, $get:expr) => {
        impl NativeType for $ty {
            const TYPE: WireType = $wire_type;

            #[inline]
            fn value_len(&self) -> usize {
                $width
            }

            // Key plus the fixed-size payload.
            #[inline]
            fn encoded_len(&self, tag: u32) -> usize {
                $width + encoding::key_len(tag)
            }

            #[inline]
            fn is_default(&self) -> bool {
                $default == *self
            }

            #[inline]
            fn encode_value(&self, dst: &mut BytesMut) {
                let value = *self;
                $put(dst, value);
            }

            // Reads exactly `$width` bytes, refusing to read past the end.
            #[inline]
            fn merge(&mut self, src: &mut Bytes) -> Result<(), DecodeError> {
                if src.len() >= $width {
                    *self = $get(src);
                    Ok(())
                } else {
                    Err(DecodeError::new("Buffer underflow"))
                }
            }
        }
    };
}
607
608fixed_width!(
609 f32,
610 4,
611 WireType::ThirtyTwoBit,
612 0f32,
613 BufMut::put_f32_le,
614 Buf::get_f32_le
615);
616fixed_width!(
617 f64,
618 8,
619 WireType::SixtyFourBit,
620 0f64,
621 BufMut::put_f64_le,
622 Buf::get_f64_le
623);
624
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Clone, PartialEq, Debug, Default)]
    pub struct TestMessage {
        f: f64,
        props: HashMap<String, u32>,
        b: bool,
        opt: Option<String>,
    }

    impl Message for TestMessage {
        #[inline]
        fn encoded_len(&self) -> usize {
            NativeType::serialized_len(&self.f, 1, DefaultValue::Default)
                + NativeType::serialized_len(&self.props, 2, DefaultValue::Default)
                + NativeType::serialized_len(&self.b, 3, DefaultValue::Default)
                + NativeType::serialized_len(&self.opt, 4, DefaultValue::Default)
        }

        fn write(&self, dst: &mut BytesMut) {
            NativeType::serialize(&self.f, 1, DefaultValue::Default, dst);
            NativeType::serialize(&self.props, 2, DefaultValue::Default, dst);
            NativeType::serialize(&self.b, 3, DefaultValue::Default, dst);
            NativeType::serialize(&self.opt, 4, DefaultValue::Default, dst);
        }

        #[inline]
        fn read(src: &mut Bytes) -> Result<Self, DecodeError> {
            let mut decoded = Self::default();
            while !src.is_empty() {
                let (tag, wire_type) = encoding::decode_key(src)?;
                match tag {
                    1 => NativeType::deserialize(&mut decoded.f, tag, wire_type, src)?,
                    2 => NativeType::deserialize(&mut decoded.props, tag, wire_type, src)?,
                    3 => NativeType::deserialize(&mut decoded.b, tag, wire_type, src)?,
                    4 => NativeType::deserialize(&mut decoded.opt, tag, wire_type, src)?,
                    _ => encoding::skip_field(wire_type, tag, src)?,
                }
            }
            Ok(decoded)
        }
    }

    #[test]
    fn test_hashmap_default_values() {
        // Includes entries whose value, and whose key and value, are defaults.
        let props = [("test1", 1), ("test2", 0), ("", 0)]
            .iter()
            .map(|&(key, val)| (key.to_string(), val))
            .collect::<HashMap<String, u32>>();
        let msg = TestMessage {
            f: 382.8263,
            b: true,
            props,
            ..TestMessage::default()
        };

        // Bare message body.
        let mut body = BytesMut::new();
        msg.write(&mut body);
        assert_eq!(Message::encoded_len(&msg), 33);
        assert_eq!(body.len(), 33);

        // Same message framed as field 1 (key + length prefix).
        let mut framed = BytesMut::new();
        NativeType::serialize(&msg, 1, DefaultValue::Default, &mut framed);
        assert_eq!(NativeType::encoded_len(&msg, 1), 35);
        assert_eq!(framed.len(), 35);

        let msg2 = TestMessage::read(&mut body.freeze()).unwrap();
        assert_eq!(Message::encoded_len(&msg2), 33);
        assert_eq!(msg, msg2);

        let mut framed = framed.freeze();
        let mut msg3 = TestMessage::default();
        let (tag, wire_type) = encoding::decode_key(&mut framed).unwrap();
        NativeType::deserialize(&mut msg3, tag, wire_type, &mut framed).unwrap();
        assert_eq!(msg, msg3);
    }
}