1use std::net::{Ipv4Addr, Ipv6Addr};
11
12use bytes::{Buf, BufMut};
13use thiserror::Error;
14
15use crate::{VarInt, VarIntBoundsExceeded};
16
17#[derive(Error, Debug, Copy, Clone, Eq, PartialEq)]
19#[error("unexpected end of buffer")]
20pub struct UnexpectedEnd;
21
22pub type Result<T> = ::std::result::Result<T, UnexpectedEnd>;
24
25pub trait Codec: Sized {
27 fn decode<B: Buf>(buf: &mut B) -> Result<Self>;
29 fn encode<B: BufMut>(&self, buf: &mut B);
31}
32
33impl Codec for u8 {
34 fn decode<B: Buf>(buf: &mut B) -> Result<Self> {
35 if buf.remaining() < 1 {
36 return Err(UnexpectedEnd);
37 }
38 Ok(buf.get_u8())
39 }
40 fn encode<B: BufMut>(&self, buf: &mut B) {
41 buf.put_u8(*self);
42 }
43}
44
45impl Codec for u16 {
46 fn decode<B: Buf>(buf: &mut B) -> Result<Self> {
47 if buf.remaining() < 2 {
48 return Err(UnexpectedEnd);
49 }
50 Ok(buf.get_u16())
51 }
52 fn encode<B: BufMut>(&self, buf: &mut B) {
53 buf.put_u16(*self);
54 }
55}
56
57impl Codec for u32 {
58 fn decode<B: Buf>(buf: &mut B) -> Result<Self> {
59 if buf.remaining() < 4 {
60 return Err(UnexpectedEnd);
61 }
62 Ok(buf.get_u32())
63 }
64 fn encode<B: BufMut>(&self, buf: &mut B) {
65 buf.put_u32(*self);
66 }
67}
68
69impl Codec for u64 {
70 fn decode<B: Buf>(buf: &mut B) -> Result<Self> {
71 if buf.remaining() < 8 {
72 return Err(UnexpectedEnd);
73 }
74 Ok(buf.get_u64())
75 }
76 fn encode<B: BufMut>(&self, buf: &mut B) {
77 buf.put_u64(*self);
78 }
79}
80
81impl Codec for Ipv4Addr {
82 fn decode<B: Buf>(buf: &mut B) -> Result<Self> {
83 if buf.remaining() < 4 {
84 return Err(UnexpectedEnd);
85 }
86 let mut octets = [0; 4];
87 buf.copy_to_slice(&mut octets);
88 Ok(octets.into())
89 }
90 fn encode<B: BufMut>(&self, buf: &mut B) {
91 buf.put_slice(&self.octets());
92 }
93}
94
95impl Codec for Ipv6Addr {
96 fn decode<B: Buf>(buf: &mut B) -> Result<Self> {
97 if buf.remaining() < 16 {
98 return Err(UnexpectedEnd);
99 }
100 let mut octets = [0; 16];
101 buf.copy_to_slice(&mut octets);
102 Ok(octets.into())
103 }
104 fn encode<B: BufMut>(&self, buf: &mut B) {
105 buf.put_slice(&self.octets());
106 }
107}
108
109pub trait BufExt {
111 fn get<T: Codec>(&mut self) -> Result<T>;
113 fn get_var(&mut self) -> Result<u64>;
115}
116
117impl<T: Buf> BufExt for T {
118 fn get<U: Codec>(&mut self) -> Result<U> {
119 U::decode(self)
120 }
121
122 fn get_var(&mut self) -> Result<u64> {
123 Ok(VarInt::decode(self)?.into_inner())
124 }
125}
126
127pub trait BufMutExt {
129 fn write<T: Codec>(&mut self, x: T);
131 fn write_var(&mut self, x: u64) -> std::result::Result<(), VarIntBoundsExceeded>;
133 fn write_var_or_debug_assert(&mut self, x: u64) {
135 if self.write_var(x).is_err() {
136 tracing::error!("VarInt overflow: {} exceeds maximum", x);
137 debug_assert!(false, "VarInt overflow: {}", x);
138 }
139 }
140}
141
142impl<T: BufMut> BufMutExt for T {
143 fn write<U: Codec>(&mut self, x: U) {
144 x.encode(self);
145 }
146
147 fn write_var(&mut self, x: u64) -> std::result::Result<(), VarIntBoundsExceeded> {
148 VarInt::encode_checked(x, self)
149 }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    // ------------------------------------------------------------------
    // Helpers
    // ------------------------------------------------------------------

    /// Encodes `value` with [`BufMutExt::write`], decodes it again and checks that the
    /// value is unchanged and that decoding consumed every encoded byte.
    fn assert_roundtrip<T>(value: T)
    where
        T: Codec + Copy + PartialEq + std::fmt::Debug,
    {
        let mut buf = BytesMut::new();
        buf.write(value);
        let mut read = buf.freeze();
        let decoded = T::decode(&mut read).unwrap();
        assert_eq!(decoded, value);
        assert!(!read.has_remaining());
    }

    /// Checks that decoding a `T` from `input` (too short to hold one) fails with
    /// [`UnexpectedEnd`].
    fn assert_decode_unexpected_end<T: Codec>(input: &[u8]) {
        let mut buf = BytesMut::new();
        buf.put_slice(input);
        let mut read = buf.freeze();
        assert_eq!(T::decode(&mut read).err(), Some(UnexpectedEnd));
    }

    /// Writes `value` with [`BufMutExt::write_var`] and checks that `VarInt::decode`
    /// reads the same value back.
    fn assert_var_roundtrip(value: u64) {
        let mut buf = BytesMut::new();
        buf.write_var(value).unwrap();
        let mut read = buf.freeze();
        let decoded = VarInt::decode(&mut read).unwrap();
        assert_eq!(decoded.into_inner(), value);
    }

    // ------------------------------------------------------------------
    // Codec round trips
    // ------------------------------------------------------------------

    #[test]
    fn u8_roundtrip() {
        assert_roundtrip(0xABu8);
    }

    #[test]
    fn u8_roundtrip_zero() {
        assert_roundtrip(0u8);
    }

    #[test]
    fn u8_roundtrip_max() {
        assert_roundtrip(u8::MAX);
    }

    #[test]
    fn u16_roundtrip() {
        assert_roundtrip(0xABCDu16);
    }

    #[test]
    fn u16_roundtrip_zero() {
        assert_roundtrip(0u16);
    }

    #[test]
    fn u16_roundtrip_max() {
        assert_roundtrip(u16::MAX);
    }

    #[test]
    fn u32_roundtrip() {
        assert_roundtrip(0xDEAD_BEEFu32);
    }

    #[test]
    fn u32_roundtrip_zero() {
        assert_roundtrip(0u32);
    }

    #[test]
    fn u32_roundtrip_max() {
        assert_roundtrip(u32::MAX);
    }

    #[test]
    fn u64_roundtrip() {
        assert_roundtrip(0x0123_4567_89AB_CDEFu64);
    }

    #[test]
    fn u64_roundtrip_zero() {
        assert_roundtrip(0u64);
    }

    #[test]
    fn u64_roundtrip_max() {
        assert_roundtrip(u64::MAX);
    }

    #[test]
    fn ipv4_roundtrip() {
        assert_roundtrip("192.168.1.1".parse::<Ipv4Addr>().unwrap());
    }

    #[test]
    fn ipv4_zero() {
        assert_roundtrip(Ipv4Addr::UNSPECIFIED);
    }

    #[test]
    fn ipv4_broadcast() {
        assert_roundtrip(Ipv4Addr::BROADCAST);
    }

    #[test]
    fn ipv6_roundtrip() {
        assert_roundtrip("2001:db8::1".parse::<Ipv6Addr>().unwrap());
    }

    #[test]
    fn ipv6_loopback() {
        assert_roundtrip("::1".parse::<Ipv6Addr>().unwrap());
    }

    #[test]
    fn ipv6_unspecified() {
        assert_roundtrip("::".parse::<Ipv6Addr>().unwrap());
    }

    // ------------------------------------------------------------------
    // Decoding from buffers that are one byte too short
    // ------------------------------------------------------------------

    #[test]
    fn u8_decode_empty_fails() {
        assert_decode_unexpected_end::<u8>(&[]);
    }

    #[test]
    fn u16_decode_insufficient_fails() {
        assert_decode_unexpected_end::<u16>(&[0xAB]);
    }

    #[test]
    fn u32_decode_insufficient_fails() {
        assert_decode_unexpected_end::<u32>(&[0; 3]);
    }

    #[test]
    fn u64_decode_insufficient_fails() {
        assert_decode_unexpected_end::<u64>(&[0; 7]);
    }

    #[test]
    fn ipv4_decode_insufficient_fails() {
        assert_decode_unexpected_end::<Ipv4Addr>(&[0; 3]);
    }

    #[test]
    fn ipv6_decode_insufficient_fails() {
        assert_decode_unexpected_end::<Ipv6Addr>(&[0; 15]);
    }

    // ------------------------------------------------------------------
    // BufExt
    // ------------------------------------------------------------------

    #[test]
    fn buf_ext_get_u32() {
        let mut buf = BytesMut::new();
        buf.put_u32(0xAABB_CCDD);
        let mut read = buf.freeze();
        let val: u32 = read.get().unwrap();
        assert_eq!(val, 0xAABB_CCDD);
    }

    #[test]
    fn buf_ext_get_var() {
        let mut buf = BytesMut::new();
        VarInt::from_u32(16383).encode(&mut buf);
        let mut read = buf.freeze();
        let val: u64 = read.get_var().unwrap();
        assert_eq!(val, 16383);
    }

    #[test]
    fn buf_ext_get_var_zero() {
        let mut buf = BytesMut::new();
        VarInt::from_u32(0).encode(&mut buf);
        let mut read = buf.freeze();
        let val: u64 = read.get_var().unwrap();
        assert_eq!(val, 0);
    }

    #[test]
    fn buf_ext_get_var_large() {
        let mut buf = BytesMut::new();
        let v = VarInt::MAX;
        v.encode(&mut buf);
        let mut read = buf.freeze();
        let val: u64 = read.get_var().unwrap();
        assert_eq!(val, v.into_inner());
    }

    // ------------------------------------------------------------------
    // BufMutExt
    // ------------------------------------------------------------------

    #[test]
    fn buf_mut_ext_write_u16() {
        assert_roundtrip(0x1234u16);
    }

    #[test]
    fn buf_mut_ext_write_var_small() {
        assert_var_roundtrip(42);
    }

    #[test]
    fn buf_mut_ext_write_var_medium() {
        assert_var_roundtrip(16383);
    }

    #[test]
    fn buf_mut_ext_write_var_large() {
        assert_var_roundtrip(1_073_741_823);
    }

    #[test]
    fn buf_mut_ext_write_var_max() {
        assert_var_roundtrip(VarInt::MAX.into_inner());
    }

    #[test]
    fn buf_mut_ext_write_var_overflow() {
        let mut buf = BytesMut::new();
        let result = buf.write_var(1u64 << 62);
        assert!(result.is_err());
    }

    #[test]
    fn buf_mut_ext_write_var_or_debug_assert_valid() {
        let mut buf = BytesMut::new();
        buf.write_var_or_debug_assert(42u64);
        let mut read = buf.freeze();
        let val = VarInt::decode(&mut read).unwrap();
        assert_eq!(val.into_inner(), 42);
    }

    /// Exercises the overflow path of `write_var_or_debug_assert`.
    ///
    /// With debug assertions enabled (the default for `cargo test`) the call must
    /// panic with the "VarInt overflow" message; without them it only logs the
    /// error and returns, so the test passes as long as the call does not panic.
    #[test]
    #[cfg_attr(debug_assertions, should_panic(expected = "VarInt overflow"))]
    fn write_var_or_debug_assert_overflow_logs_error() {
        let mut buf = BytesMut::new();
        buf.write_var_or_debug_assert(1u64 << 62);
    }

    // ------------------------------------------------------------------
    // BufExt and BufMutExt together
    // ------------------------------------------------------------------

    #[test]
    fn ext_traits_roundtrip_u32() {
        let mut buf = BytesMut::new();
        let v: u32 = 42;
        buf.write(v);
        let mut read = buf.freeze();
        let decoded: u32 = read.get().unwrap();
        assert_eq!(decoded, v);
    }

    #[test]
    fn ext_traits_roundtrip_mixed_types() {
        let mut buf = BytesMut::new();
        buf.write(0xABu8);
        buf.write(0x1234u16);
        buf.write(0xDEAD_BEEFu32);
        buf.write(0x0123_4567_89AB_CDEFu64);

        let mut read = buf.freeze();
        assert_eq!(read.get::<u8>().unwrap(), 0xAB);
        assert_eq!(read.get::<u16>().unwrap(), 0x1234);
        assert_eq!(read.get::<u32>().unwrap(), 0xDEAD_BEEF);
        assert_eq!(read.get::<u64>().unwrap(), 0x0123_4567_89AB_CDEF);
    }

    #[test]
    fn ext_traits_roundtrip_varint_mixed() {
        // Values on either side of each size boundary checked in the
        // `varint_size_*` tests, plus the maximum.
        let values = [
            0u64,
            63,
            64,
            16383,
            16384,
            1_073_741_823,
            1_073_741_824,
            VarInt::MAX.into_inner(),
        ];

        let mut buf = BytesMut::new();
        for &value in &values {
            buf.write_var(value).unwrap();
        }

        let mut read = buf.freeze();
        for &value in &values {
            assert_eq!(read.get_var().unwrap(), value);
        }

        assert!(!read.has_remaining());
    }

    // ------------------------------------------------------------------
    // VarInt decoding from truncated input
    // ------------------------------------------------------------------

    #[test]
    fn varint_decode_empty_fails() {
        let buf = BytesMut::new();
        let mut read = buf.freeze();
        let result: Result<VarInt> = VarInt::decode(&mut read);
        assert_eq!(result.unwrap_err(), UnexpectedEnd);
    }

    #[test]
    fn varint_decode_partial_2byte_tag() {
        let mut buf = BytesMut::new();
        // Only the first byte of a 2-byte encoding (tag bits `01`).
        buf.put_u8(0b01_000000 | 42);
        let mut read = buf.freeze();
        let result: Result<VarInt> = VarInt::decode(&mut read);
        assert_eq!(result.unwrap_err(), UnexpectedEnd);
    }

    #[test]
    fn varint_decode_partial_4byte_tag() {
        let mut buf = BytesMut::new();
        // Only the first byte of a 4-byte encoding (tag bits `10`).
        buf.put_u8(0b10_000000 | 42);
        let mut read = buf.freeze();
        let result: Result<VarInt> = VarInt::decode(&mut read);
        assert_eq!(result.unwrap_err(), UnexpectedEnd);
    }

    #[test]
    fn varint_decode_partial_8byte_tag() {
        let mut buf = BytesMut::new();
        // Only the first byte of an 8-byte encoding (tag bits `11`).
        buf.put_u8(0b11_000000);
        let mut read = buf.freeze();
        let result: Result<VarInt> = VarInt::decode(&mut read);
        assert_eq!(result.unwrap_err(), UnexpectedEnd);
    }

    // ------------------------------------------------------------------
    // VarInt encoded size
    // ------------------------------------------------------------------

    #[test]
    fn varint_size_1_byte() {
        assert_eq!(VarInt::from_u32(0).size(), 1);
        assert_eq!(VarInt::from_u32(63).size(), 1);
    }

    #[test]
    fn varint_size_2_bytes() {
        assert_eq!(VarInt::from_u32(64).size(), 2);
        assert_eq!(VarInt::from_u32(16383).size(), 2);
    }

    #[test]
    fn varint_size_4_bytes() {
        assert_eq!(VarInt::from_u32(16384).size(), 4);
        assert_eq!(VarInt::from_u32(1_073_741_823).size(), 4);
    }

    #[test]
    fn varint_size_8_bytes() {
        assert_eq!(VarInt::from_u64(1_073_741_824).unwrap().size(), 8);
        assert_eq!(VarInt::MAX.size(), 8);
    }

    // ------------------------------------------------------------------
    // VarInt conversions
    // ------------------------------------------------------------------

    #[test]
    fn varint_from_u64_valid() {
        let v = VarInt::from_u64(0).unwrap();
        assert_eq!(v.into_inner(), 0);

        let v = VarInt::from_u64(VarInt::MAX.into_inner()).unwrap();
        assert_eq!(v.into_inner(), VarInt::MAX.into_inner());
    }

    #[test]
    fn varint_from_u64_invalid() {
        let result = VarInt::from_u64(1u64 << 62);
        assert!(result.is_err());
    }

    #[test]
    fn varint_try_from_u64_valid() {
        use std::convert::TryFrom;
        let v = VarInt::try_from(42u64).unwrap();
        assert_eq!(v.into_inner(), 42);
    }

    #[test]
    fn varint_try_from_u64_invalid() {
        use std::convert::TryFrom;
        let result = VarInt::try_from(1u64 << 62);
        assert!(result.is_err());
    }

    #[test]
    fn varint_try_from_u128_valid() {
        use std::convert::TryFrom;
        let v = VarInt::try_from(42u128).unwrap();
        assert_eq!(v.into_inner(), 42);
    }

    #[test]
    fn varint_try_from_u128_invalid() {
        use std::convert::TryFrom;
        let result = VarInt::try_from((1u128 << 62) + 1);
        assert!(result.is_err());
    }

    #[test]
    fn varint_try_from_usize_valid() {
        use std::convert::TryFrom;
        let v = VarInt::try_from(42usize).unwrap();
        assert_eq!(v.into_inner(), 42);
    }

    #[test]
    fn varint_into_u64() {
        let v = VarInt::from_u32(42);
        let val: u64 = v.into();
        assert_eq!(val, 42);
    }

    #[test]
    fn varint_from_u8() {
        let v: VarInt = 42u8.into();
        assert_eq!(v.into_inner(), 42);
    }

    #[test]
    fn varint_from_u16() {
        let v: VarInt = 16383u16.into();
        assert_eq!(v.into_inner(), 16383);
    }

    #[test]
    fn varint_from_u32() {
        let v: VarInt = 42u32.into();
        assert_eq!(v.into_inner(), 42);
    }

    // ------------------------------------------------------------------
    // VarInt standard trait implementations
    // ------------------------------------------------------------------

    #[test]
    fn varint_display() {
        let v = VarInt::from_u32(42);
        assert_eq!(format!("{v}"), "42");
    }

    #[test]
    fn varint_debug() {
        let v = VarInt::from_u32(42);
        assert_eq!(format!("{v:?}"), "42");
    }

    #[test]
    fn varint_ord() {
        let small = VarInt::from_u32(10);
        let large = VarInt::from_u32(20);
        assert!(small < large);
        assert!(large > small);
        assert_eq!(small.min(large), small);
        assert_eq!(small.max(large), large);
    }

    #[test]
    fn varint_hash() {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let v1 = VarInt::from_u32(42);
        let v2 = VarInt::from_u32(42);
        let mut hasher1 = DefaultHasher::new();
        let mut hasher2 = DefaultHasher::new();
        v1.hash(&mut hasher1);
        v2.hash(&mut hasher2);
        assert_eq!(hasher1.finish(), hasher2.finish());
    }
}