1use std::{collections::HashMap, convert::TryFrom, fmt, hash::BuildHasher, hash::Hash, mem};
2
3use ntex_bytes::{Buf, BufMut, ByteString, Bytes, BytesMut};
4
5pub use crate::encoding::WireType;
6use crate::encoding::{self, DecodeError};
7
8pub trait Message: Default + Sized + fmt::Debug {
10 fn read(src: &mut Bytes) -> Result<Self, DecodeError>;
12
13 fn write(&self, dst: &mut BytesMut);
15
16 fn encoded_len(&self) -> usize;
18}
19
/// Controls whether the default `NativeType::serialize` /
/// `NativeType::serialized_len` implementations may omit a field.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DefaultValue<T> {
    /// No default is known: the field is always written.
    Unknown,
    /// Omit the field when `NativeType::is_default` returns `true`.
    Default,
    /// Omit the field when it compares equal (`==`) to the contained value.
    Value(T),
}
26
27pub trait NativeType: PartialEq + Default + Sized + fmt::Debug {
29 const TYPE: WireType;
30
31 #[inline]
32 fn value_len(&self) -> usize {
34 0
35 }
36
37 fn merge(&mut self, src: &mut Bytes) -> Result<(), DecodeError>;
39
40 fn is_default(&self) -> bool {
42 false
43 }
44
45 fn encode_value(&self, dst: &mut BytesMut);
47
48 #[inline]
49 fn encode_type(&self, tag: u32, dst: &mut BytesMut) {
51 encoding::encode_key(tag, Self::TYPE, dst);
52 if !matches!(Self::TYPE, WireType::Varint | WireType::SixtyFourBit) {
53 encoding::encode_varint(self.value_len() as u64, dst);
54 }
55 }
56
57 #[inline]
58 fn encoded_len(&self, tag: u32) -> usize {
60 let value_len = self.value_len();
61 encoding::key_len(tag) + encoding::encoded_len_varint(value_len as u64) + value_len
62 }
63
64 #[inline]
65 fn serialize(&self, tag: u32, default: DefaultValue<&Self>, dst: &mut BytesMut) {
67 let default = match default {
68 DefaultValue::Unknown => false,
69 DefaultValue::Default => self.is_default(),
70 DefaultValue::Value(d) => self == d,
71 };
72
73 if !default {
74 self.encode_type(tag, dst);
75 self.encode_value(dst);
76 }
77 }
78
79 #[inline]
80 fn serialized_len(&self, tag: u32, default: DefaultValue<&Self>) -> usize {
82 let default = match default {
83 DefaultValue::Unknown => false,
84 DefaultValue::Default => self.is_default(),
85 DefaultValue::Value(d) => self == d,
86 };
87
88 if default {
89 0
90 } else {
91 self.encoded_len(tag)
92 }
93 }
94
95 #[inline]
96 fn deserialize(
98 &mut self,
99 _: u32,
100 wtype: WireType,
101 src: &mut Bytes,
102 ) -> Result<(), DecodeError> {
103 encoding::check_wire_type(Self::TYPE, wtype)?;
104
105 if matches!(Self::TYPE, WireType::Varint | WireType::SixtyFourBit) {
106 self.merge(src)
107 } else {
108 let len = encoding::decode_varint(src)? as usize;
109 let mut buf = src.split_to_checked(len).ok_or_else(|| {
110 DecodeError::new(format!(
111 "Not enough data, message size {} buffer size {}",
112 len,
113 src.len()
114 ))
115 })?;
116 self.merge(&mut buf)
117 }
118 }
119
120 #[inline]
121 fn deserialize_default(
123 tag: u32,
124 wtype: WireType,
125 src: &mut Bytes,
126 ) -> Result<Self, DecodeError> {
127 let mut value = Self::default();
128 value.deserialize(tag, wtype, src)?;
129 Ok(value)
130 }
131}
132
133impl Message for () {
135 fn encoded_len(&self) -> usize {
136 0
137 }
138 fn read(_: &mut Bytes) -> Result<Self, DecodeError> {
139 Ok(())
140 }
141 fn write(&self, _: &mut BytesMut) {}
142}
143
144impl<T: Message + PartialEq> NativeType for T {
145 const TYPE: WireType = WireType::LengthDelimited;
146
147 fn value_len(&self) -> usize {
148 Message::encoded_len(self)
149 }
150
151 #[inline]
152 fn encode_value(&self, dst: &mut BytesMut) {
154 self.write(dst)
155 }
156
157 fn merge(&mut self, src: &mut Bytes) -> Result<(), DecodeError> {
159 *self = Message::read(src)?;
160 Ok(())
161 }
162}
163
164impl NativeType for Bytes {
165 const TYPE: WireType = WireType::LengthDelimited;
166
167 #[inline]
168 fn value_len(&self) -> usize {
169 self.len()
170 }
171
172 #[inline]
173 fn encode_value(&self, dst: &mut BytesMut) {
175 dst.extend_from_slice(self);
176 }
177
178 #[inline]
179 fn merge(&mut self, src: &mut Bytes) -> Result<(), DecodeError> {
181 *self = mem::take(src);
182 Ok(())
183 }
184
185 #[inline]
186 fn is_default(&self) -> bool {
187 self.is_empty()
188 }
189}
190
191impl NativeType for String {
192 const TYPE: WireType = WireType::LengthDelimited;
193
194 #[inline]
195 fn value_len(&self) -> usize {
196 self.len()
197 }
198
199 #[inline]
200 fn merge(&mut self, src: &mut Bytes) -> Result<(), DecodeError> {
201 if let Ok(s) = ByteString::try_from(mem::take(src)) {
202 *self = s.as_str().to_string();
203 Ok(())
204 } else {
205 Err(DecodeError::new(
206 "invalid string value: data is not UTF-8 encoded",
207 ))
208 }
209 }
210
211 #[inline]
212 fn encode_value(&self, dst: &mut BytesMut) {
213 dst.extend_from_slice(self.as_bytes());
214 }
215
216 #[inline]
217 fn is_default(&self) -> bool {
218 self.is_empty()
219 }
220}
221
222impl NativeType for ByteString {
223 const TYPE: WireType = WireType::LengthDelimited;
224
225 #[inline]
226 fn value_len(&self) -> usize {
227 self.as_slice().len()
228 }
229
230 #[inline]
231 fn merge(&mut self, src: &mut Bytes) -> Result<(), DecodeError> {
232 if let Ok(s) = ByteString::try_from(mem::take(src)) {
233 *self = s;
234 Ok(())
235 } else {
236 Err(DecodeError::new(
237 "invalid string value: data is not UTF-8 encoded",
238 ))
239 }
240 }
241
242 #[inline]
243 fn encode_value(&self, dst: &mut BytesMut) {
244 dst.extend_from_slice(self.as_bytes());
245 }
246
247 #[inline]
248 fn is_default(&self) -> bool {
249 self.is_empty()
250 }
251}
252
253impl<T: NativeType> NativeType for Option<T> {
254 const TYPE: WireType = WireType::LengthDelimited;
255
256 #[inline]
257 fn is_default(&self) -> bool {
258 self.is_none()
259 }
260
261 #[inline]
262 fn encode_value(&self, _: &mut BytesMut) {}
264
265 #[inline]
266 fn merge(&mut self, _: &mut Bytes) -> Result<(), DecodeError> {
268 Err(DecodeError::new(
269 "Cannot directly call deserialize for Option<T>",
270 ))
271 }
272
273 #[inline]
274 fn deserialize(
276 &mut self,
277 tag: u32,
278 wtype: WireType,
279 src: &mut Bytes,
280 ) -> Result<(), DecodeError> {
281 let mut value: T = Default::default();
282 value.deserialize(tag, wtype, src)?;
283 *self = Some(value);
284 Ok(())
285 }
286
287 #[inline]
288 fn serialize(&self, tag: u32, _: DefaultValue<&Self>, dst: &mut BytesMut) {
290 if let Some(ref value) = self {
291 value.serialize(tag, DefaultValue::Unknown, dst);
292 }
293 }
294
295 #[inline]
296 fn serialized_len(&self, tag: u32, _: DefaultValue<&Self>) -> usize {
298 if let Some(ref value) = self {
299 value.serialized_len(tag, DefaultValue::Unknown)
300 } else {
301 0
302 }
303 }
304
305 #[inline]
306 fn encoded_len(&self, tag: u32) -> usize {
308 self.as_ref()
309 .map(|value| value.encoded_len(tag))
310 .unwrap_or(0)
311 }
312}
313
314impl NativeType for Vec<u8> {
315 const TYPE: WireType = WireType::LengthDelimited;
316
317 #[inline]
318 fn value_len(&self) -> usize {
319 self.len()
320 }
321
322 #[inline]
323 fn encode_value(&self, dst: &mut BytesMut) {
325 dst.extend_from_slice(self.as_slice());
326 }
327
328 #[inline]
329 fn merge(&mut self, src: &mut Bytes) -> Result<(), DecodeError> {
331 *self = Vec::from(&src[..]);
332 Ok(())
333 }
334
335 #[inline]
336 fn is_default(&self) -> bool {
337 self.is_empty()
338 }
339}
340
341impl<T: NativeType> NativeType for Vec<T> {
342 const TYPE: WireType = WireType::LengthDelimited;
343
344 #[inline]
345 fn encode_value(&self, _: &mut BytesMut) {}
347
348 #[inline]
349 fn merge(&mut self, _: &mut Bytes) -> Result<(), DecodeError> {
351 Err(DecodeError::new("Cannot directly call merge for Vec<T>"))
352 }
353
354 fn deserialize(
356 &mut self,
357 tag: u32,
358 wtype: WireType,
359 src: &mut Bytes,
360 ) -> Result<(), DecodeError> {
361 if T::TYPE == WireType::Varint {
362 let len = encoding::decode_varint(src)? as usize;
363 let mut buf = src
364 .split_to_checked(len)
365 .ok_or_else(DecodeError::incomplete)?;
366 while !buf.is_empty() {
367 let mut value: T = Default::default();
368 value.merge(&mut buf)?;
369 self.push(value);
370 }
371 } else {
372 let mut value: T = Default::default();
373 value.deserialize(tag, wtype, src)?;
374 self.push(value);
375 }
376 Ok(())
377 }
378
379 fn serialize(&self, tag: u32, _: DefaultValue<&Self>, dst: &mut BytesMut) {
381 if T::TYPE == WireType::Varint {
382 encoding::encode_key(tag, WireType::LengthDelimited, dst);
383 encoding::encode_varint(
384 self.iter().map(|v| v.value_len()).sum::<usize>() as u64,
385 dst,
386 );
387 for item in self.iter() {
388 item.encode_value(dst);
389 }
390 } else {
391 for item in self.iter() {
392 item.serialize(tag, DefaultValue::Unknown, dst);
393 }
394 }
395 }
396
397 #[inline]
398 fn is_default(&self) -> bool {
399 self.is_empty()
400 }
401
402 fn encoded_len(&self, tag: u32) -> usize {
404 if T::TYPE == WireType::Varint {
405 let len = self.iter().map(|value| value.value_len()).sum::<usize>();
406 self.iter().map(|value| value.value_len()).sum::<usize>()
407 + encoding::key_len(tag)
408 + encoding::encoded_len_varint(len as u64)
409 } else {
410 self.iter().map(|value| value.encoded_len(tag)).sum()
411 }
412 }
413}
414
415impl<K: NativeType + Eq + Hash, V: NativeType, S: BuildHasher + Default> NativeType
416 for HashMap<K, V, S>
417{
418 const TYPE: WireType = WireType::LengthDelimited;
419
420 #[inline]
421 fn merge(&mut self, _: &mut Bytes) -> Result<(), DecodeError> {
423 Err(DecodeError::new("Cannot directly call merge for Map<K, V>"))
424 }
425
426 #[inline]
427 fn encode_value(&self, _: &mut BytesMut) {}
429
430 #[inline]
431 fn is_default(&self) -> bool {
432 self.is_empty()
433 }
434
435 fn deserialize(
437 &mut self,
438 _: u32,
439 wtype: WireType,
440 src: &mut Bytes,
441 ) -> Result<(), DecodeError> {
442 encoding::check_wire_type(Self::TYPE, wtype)?;
443
444 let len = encoding::decode_varint(src)? as usize;
445 let mut buf = src.split_to_checked(len).ok_or_else(|| {
446 DecodeError::new(format!(
447 "Not enough data for HashMap, message size {}, buf size {}",
448 len,
449 src.len()
450 ))
451 })?;
452 let mut key = Default::default();
453 let mut val = Default::default();
454
455 while !buf.is_empty() {
456 let (tag, wire_type) = encoding::decode_key(&mut buf)?;
457 match tag {
458 1 => NativeType::deserialize(&mut key, 1, wire_type, &mut buf)?,
459 2 => NativeType::deserialize(&mut val, 2, wire_type, &mut buf)?,
460 _ => return Err(DecodeError::new("Map deserialization error")),
461 }
462 }
463 self.insert(key, val);
464 Ok(())
465 }
466
467 fn serialize(&self, tag: u32, _: DefaultValue<&Self>, dst: &mut BytesMut) {
469 let key_default = K::default();
470 let val_default = V::default();
471
472 for item in self.iter() {
473 let skip_key = item.0 == &key_default;
474 let skip_val = item.1 == &val_default;
475
476 let len = (if skip_key { 0 } else { item.0.encoded_len(1) })
477 + (if skip_val { 0 } else { item.1.encoded_len(2) });
478
479 encoding::encode_key(tag, WireType::LengthDelimited, dst);
480 encoding::encode_varint(len as u64, dst);
481 if !skip_key {
482 item.0.serialize(1, DefaultValue::Default, dst);
483 }
484 if !skip_val {
485 item.1.serialize(2, DefaultValue::Default, dst);
486 }
487 }
488 }
489
490 fn encoded_len(&self, tag: u32) -> usize {
492 let key_default = K::default();
493 let val_default = V::default();
494
495 self.iter()
496 .map(|(key, val)| {
497 let len = (if key == &key_default {
498 0
499 } else {
500 key.encoded_len(1)
501 }) + (if val == &val_default {
502 0
503 } else {
504 val.encoded_len(2)
505 });
506
507 encoding::key_len(tag) + encoding::encoded_len_varint(len as u64) + len
508 })
509 .sum::<usize>()
510 }
511}
512
/// Implements [`NativeType`] with `WireType::Varint` for a scalar type.
///
/// The short form `varint!(ty, default)` converts with `as` casts in both
/// directions. The long form takes explicit conversions:
/// `to_uint64(self) <expr>` maps the receiver to the `u64` written on the wire,
/// and `from_uint64(v) <expr>` maps a decoded `u64` bound to `v` back to the type.
macro_rules! varint {
    // Short form: plain `as` casts to and from `u64`.
    ($ty:ident, $default:expr) => (
        varint!($ty, $default, to_uint64(self) { *self as u64 }, from_uint64(v) { v as $ty });
    );

    // Long form. `$slf` carries the caller's `self` token and is used as the
    // receiver below, so the call-site expression `$to_uint64` can refer to it.
    ($ty:ty, $default:expr, to_uint64($slf:ident) $to_uint64:expr, from_uint64($val:ident) $from_uint64:expr) => (

        impl NativeType for $ty {
            const TYPE: WireType = WireType::Varint;

            // The field is "default" when it equals `$default`.
            #[inline]
            fn is_default(&self) -> bool {
                *self == $default
            }

            // Varint payload only; varints carry no length prefix.
            #[inline]
            fn encode_value(&$slf, dst: &mut BytesMut) {
                encoding::encode_varint($to_uint64, dst);
            }

            // Key plus varint payload (no length prefix).
            #[inline]
            fn encoded_len(&$slf, tag: u32) -> usize {
                encoding::key_len(tag) + encoding::encoded_len_varint($to_uint64)
            }

            #[inline]
            fn value_len(&$slf) -> usize {
                encoding::encoded_len_varint($to_uint64)
            }

            // Reads one varint from `src` and converts it with `$from_uint64`.
            #[inline]
            fn merge(&mut self, src: &mut Bytes) -> Result<(), DecodeError> {
                *self = encoding::decode_varint(src).map(|$val| $from_uint64)?;
                Ok(())
            }
        }
    );
}
553
554varint!(bool, false,
555 to_uint64(self) u64::from(*self),
556 from_uint64(value) value != 0);
557varint!(i32, 0i32);
558varint!(i64, 0i64);
559varint!(u32, 0u32);
560varint!(u64, 0u64);
561
/// Implements [`NativeType`] for a fixed-size scalar.
///
/// Arguments: the type, its encoded size in bytes, its wire type, its default
/// value, and the `BufMut` writer / `Buf` reader used for the payload.
macro_rules! fixed_width {
    ($ty:ty,
     $width:expr,
     $wire_type:expr,
     $default:expr,
     $put:expr,
     $get:expr) => {
        impl NativeType for $ty {
            const TYPE: WireType = $wire_type;

            #[inline]
            fn value_len(&self) -> usize {
                $width
            }

            // Key plus fixed-size payload; no length prefix.
            #[inline]
            fn encoded_len(&self, tag: u32) -> usize {
                $width + encoding::key_len(tag)
            }

            #[inline]
            fn is_default(&self) -> bool {
                $default == *self
            }

            #[inline]
            fn encode_value(&self, dst: &mut BytesMut) {
                let value: $ty = *self;
                $put(dst, value);
            }

            // Rejects truncated input before reading the payload.
            #[inline]
            fn merge(&mut self, src: &mut Bytes) -> Result<(), DecodeError> {
                if src.len() >= $width {
                    *self = $get(src);
                    Ok(())
                } else {
                    Err(DecodeError::new("Buffer underflow"))
                }
            }
        }
    };
}
605
606fixed_width!(
607 f32,
608 4,
609 WireType::ThirtyTwoBit,
610 0f32,
611 BufMut::put_f32_le,
612 Buf::get_f32_le
613);
614fixed_width!(
615 f64,
616 8,
617 WireType::SixtyFourBit,
618 0f64,
619 BufMut::put_f64_le,
620 Buf::get_f64_le
621);
622
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Clone, PartialEq, Debug, Default)]
    pub struct TestMessage {
        f: f64,
        props: HashMap<String, u32>,
        b: bool,
        opt: Option<String>,
    }

    impl Message for TestMessage {
        #[inline]
        fn read(src: &mut Bytes) -> Result<Self, DecodeError> {
            let mut msg = TestMessage::default();
            while !src.is_empty() {
                let (tag, wire_type) = encoding::decode_key(src)?;
                match tag {
                    1 => NativeType::deserialize(&mut msg.f, tag, wire_type, src)?,
                    2 => NativeType::deserialize(&mut msg.props, tag, wire_type, src)?,
                    3 => NativeType::deserialize(&mut msg.b, tag, wire_type, src)?,
                    4 => NativeType::deserialize(&mut msg.opt, tag, wire_type, src)?,
                    _ => encoding::skip_field(wire_type, tag, src)?,
                }
            }
            Ok(msg)
        }

        fn write(&self, dst: &mut BytesMut) {
            NativeType::serialize(&self.f, 1, DefaultValue::Default, dst);
            NativeType::serialize(&self.props, 2, DefaultValue::Default, dst);
            NativeType::serialize(&self.b, 3, DefaultValue::Default, dst);
            NativeType::serialize(&self.opt, 4, DefaultValue::Default, dst);
        }

        #[inline]
        fn encoded_len(&self) -> usize {
            NativeType::serialized_len(&self.f, 1, DefaultValue::Default)
                + NativeType::serialized_len(&self.props, 2, DefaultValue::Default)
                + NativeType::serialized_len(&self.b, 3, DefaultValue::Default)
                + NativeType::serialized_len(&self.opt, 4, DefaultValue::Default)
        }
    }

    #[test]
    fn test_hashmap_default_values() {
        // One entry with both fields set, one with a default value, and one
        // where both key and value are defaults.
        let mut props: HashMap<String, u32> = HashMap::new();
        props.insert("test1".to_string(), 1);
        props.insert("test2".to_string(), 0);
        props.insert(String::new(), 0);

        let msg = TestMessage {
            f: 382.8263,
            b: true,
            props,
            ..Default::default()
        };

        // Top-level message encoding.
        let mut plain = BytesMut::new();
        msg.write(&mut plain);
        assert_eq!(Message::encoded_len(&msg), 33);
        assert_eq!(plain.len(), 33);

        // The same message nested as field 1 adds key + length prefix.
        let mut nested = BytesMut::new();
        msg.serialize(1, DefaultValue::Default, &mut nested);
        assert_eq!(NativeType::encoded_len(&msg, 1), 35);
        assert_eq!(nested.len(), 35);

        let decoded = TestMessage::read(&mut plain.freeze()).unwrap();
        assert_eq!(Message::encoded_len(&decoded), 33);
        assert_eq!(msg, decoded);

        let mut nested = nested.freeze();
        let (tag, wire_type) = encoding::decode_key(&mut nested).unwrap();
        let mut decoded_nested = TestMessage::default();
        decoded_nested
            .deserialize(tag, wire_type, &mut nested)
            .unwrap();
        assert_eq!(msg, decoded_nested);
    }
}