1use crate::wire_format::{MAX_VARINT_BYTES, VARINT_CONTINUATION_BIT, VARINT_PAYLOAD_MASK};
21use crate::{ProtobufError, Result};
22use ::std::io::Read;
23use ::std::iter::Iterator;
24
25use super::Varint;
26
/// Progress of a varint decode that may be spread across several byte sources.
///
/// A fresh state comes from `DecodeState::new()` (or `Default`); an in-progress
/// one is handed back inside `DecodeOutcome::Incomplete` and can be fed more bytes.
#[derive(Debug, Clone, Copy, Default)]
#[derive(PartialEq, Eq, Hash)]
pub struct DecodeState {
    /// Payload bits accumulated so far, least-significant 7-bit group first.
    decoded_value: u64,
    /// Bit position at which the next 7-bit payload group will be placed.
    shift: u32,
    /// Number of bytes fed into this state so far.
    bytes_consumed: usize,
}
40
41#[derive(Debug, Clone, Copy, PartialEq, Eq)]
46pub enum DecodeOutcome {
47 Complete(Varint),
49 Incomplete(DecodeState),
54 Empty,
56}
57
58impl DecodeState {
59 pub(crate) fn new() -> Self {
60 Self::default()
61 }
62
63 #[allow(dead_code)] pub(crate) fn bytes_consumed(&self) -> usize {
65 self.bytes_consumed
66 }
67
68 pub(crate) fn feed<I, E>(mut self, iter: I) -> Result<DecodeOutcome>
69 where
70 I: Iterator<Item = ::std::result::Result<u8, E>>,
71 E: Into<ProtobufError>,
72 {
73 for byte_result in iter.take(MAX_VARINT_BYTES - self.bytes_consumed) {
74 let byte = byte_result.map_err(Into::into)?;
75 self.bytes_consumed += 1;
76
77 let value = (byte & VARINT_PAYLOAD_MASK) as u64;
78 self.decoded_value |= value << self.shift;
79
80 if byte & VARINT_CONTINUATION_BIT == 0 {
81 let result_bytes = self.decoded_value.to_le_bytes();
82 return Ok(DecodeOutcome::Complete(Varint::new(result_bytes)));
83 }
84 self.shift += 7;
85 }
86
87 if self.bytes_consumed == 0 {
88 return Ok(DecodeOutcome::Empty);
89 }
90
91 if self.bytes_consumed >= MAX_VARINT_BYTES {
92 return Err(ProtobufError::VarintTooLong);
93 }
94
95 Ok(DecodeOutcome::Incomplete(self))
96 }
97}
98
/// Iterator adapter that decodes consecutive varints from a fallible byte source.
///
/// Created by the `read_varints` methods of the extension traits in this module.
/// Trait bounds live on the `impl` blocks rather than on the struct itself.
#[must_use = "iterators are lazy and do nothing unless consumed"]
#[derive(Debug)]
pub struct VarintIterator<I> {
    bytes: I,
}
106
/// Adapts an iterator of plain bytes into an iterator of infallible `Result`s.
///
/// This lets infallible byte sources share the decoding path used for fallible ones.
#[derive(Debug, Clone)]
pub struct ToResultIterator<I> {
    inner: I,
}

impl<I> Iterator for ToResultIterator<I>
where
    I: Iterator<Item = u8>,
{
    type Item = ::std::result::Result<u8, ::std::convert::Infallible>;

    fn next(&mut self) -> Option<Self::Item> {
        self.inner.next().map(Ok)
    }

    // Each inner byte maps to exactly one `Ok`, so the inner bounds are exact here too.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}
124
125impl<I, E> VarintIterator<I>
126where
127 I: Iterator<Item = ::std::result::Result<u8, E>>,
128 E: Into<ProtobufError>,
129{
130 fn new(bytes: I) -> Self {
131 Self { bytes }
132 }
133}
134
135impl<I, E> Iterator for VarintIterator<I>
136where
137 I: Iterator<Item = ::std::result::Result<u8, E>>,
138 E: Into<ProtobufError>,
139{
140 type Item = Result<Varint>;
141
142 fn next(&mut self) -> Option<Self::Item> {
143 match DecodeState::new().feed(&mut self.bytes) {
144 Ok(DecodeOutcome::Complete(v)) => Some(Ok(v)),
145 Ok(DecodeOutcome::Empty) => None,
146 Ok(DecodeOutcome::Incomplete(_)) => Some(Err(ProtobufError::UnexpectedEof)),
147 Err(e) => Some(Err(e)),
148 }
149 }
150}
151
152pub trait IteratorExtVarint {
174 fn read_varint(self) -> Result<Option<Varint>>;
179
180 fn read_varint_partial(self) -> Result<DecodeOutcome>;
182
183 fn read_varint_resume(self, state: DecodeState) -> Result<DecodeOutcome>;
188
189 fn read_varints(self) -> VarintIterator<ToResultIterator<Self>>
206 where
207 Self: Sized + Iterator<Item = u8>;
208}
209
210impl<I> IteratorExtVarint for I
211where
212 I: Iterator<Item = u8>,
213{
214 fn read_varint(self) -> Result<Option<Varint>> {
215 match DecodeState::new().feed(self.map(Ok::<u8, ::std::convert::Infallible>)) {
216 Ok(DecodeOutcome::Complete(v)) => Ok(Some(v)),
217 Ok(DecodeOutcome::Empty) => Ok(None),
218 Ok(DecodeOutcome::Incomplete(_)) => Err(crate::ProtobufError::UnexpectedEof),
219 Err(e) => Err(e),
220 }
221 }
222
223 fn read_varint_partial(self) -> Result<DecodeOutcome> {
224 DecodeState::new().feed(self.map(Ok::<u8, ::std::convert::Infallible>))
225 }
226
227 fn read_varint_resume(self, state: DecodeState) -> Result<DecodeOutcome> {
228 state.feed(self.map(Ok::<u8, ::std::convert::Infallible>))
229 }
230
231 fn read_varints(self) -> VarintIterator<ToResultIterator<Self>>
232 where
233 Self: Sized,
234 {
235 VarintIterator::new(ToResultIterator { inner: self })
236 }
237}
238
239pub trait TryIteratorExtVarint {
265 fn read_varint(self) -> Result<Option<Varint>>;
270
271 fn read_varint_partial(self) -> Result<DecodeOutcome>;
273
274 fn read_varint_resume(self, state: DecodeState) -> Result<DecodeOutcome>;
276
277 fn read_varints(self) -> VarintIterator<Self>
296 where
297 Self: Sized + Iterator;
298}
299
300impl<I, E> TryIteratorExtVarint for I
301where
302 I: Iterator<Item = ::std::result::Result<u8, E>>,
303 E: Into<ProtobufError>,
304{
305 fn read_varint(self) -> Result<Option<Varint>> {
306 match DecodeState::new().feed(self) {
307 Ok(DecodeOutcome::Complete(v)) => Ok(Some(v)),
308 Ok(DecodeOutcome::Empty) => Ok(None),
309 Ok(DecodeOutcome::Incomplete(_)) => Err(crate::ProtobufError::UnexpectedEof),
310 Err(e) => Err(e),
311 }
312 }
313
314 fn read_varint_partial(self) -> Result<DecodeOutcome> {
315 DecodeState::new().feed(self)
316 }
317
318 fn read_varint_resume(self, state: DecodeState) -> Result<DecodeOutcome> {
319 state.feed(self)
320 }
321
322 fn read_varints(self) -> VarintIterator<Self>
323 where
324 Self: Sized,
325 {
326 VarintIterator::new(self)
327 }
328}
329
330pub trait ReadExtVarint {
374 fn read_varint(&mut self) -> Result<Option<Varint>> {
379 match self.read_varint_partial()? {
380 DecodeOutcome::Complete(v) => Ok(Some(v)),
381 DecodeOutcome::Empty => Ok(None),
382 DecodeOutcome::Incomplete(_) => Err(crate::ProtobufError::UnexpectedEof),
383 }
384 }
385
386 fn read_varint_partial(&mut self) -> Result<DecodeOutcome>;
388
389 fn read_varint_resume(&mut self, state: DecodeState) -> Result<DecodeOutcome>;
391
392 fn read_varints(&mut self) -> VarintIterator<::std::io::Bytes<&mut Self>>
410 where
411 Self: ::std::io::Read;
412}
413
414impl<R> ReadExtVarint for R
415where
416 R: Read,
417{
418 #[allow(clippy::unbuffered_bytes)] fn read_varint_partial(&mut self) -> Result<DecodeOutcome> {
420 DecodeState::new().feed(self.bytes())
421 }
422
423 #[allow(clippy::unbuffered_bytes)]
424 fn read_varint_resume(&mut self, state: DecodeState) -> Result<DecodeOutcome> {
425 state.feed(self.bytes())
426 }
427
428 #[allow(clippy::unbuffered_bytes)] fn read_varints(&mut self) -> VarintIterator<::std::io::Bytes<&mut Self>> {
430 VarintIterator::new(self.bytes())
431 }
432}
433
// Unit tests for the iterator, fallible-iterator and `Read` varint decoding front-ends,
// covering complete, empty, truncated, resumed and over-long inputs.
#[cfg(test)]
mod tests {
    use super::{DecodeOutcome, DecodeState, IteratorExtVarint, ReadExtVarint, TryIteratorExtVarint};
    use crate::ProtobufError;
    use crate::varint::Varint;

    #[test]
    fn test_read_varint_from_iterator() {
        let input = [0x96, 0x01];
        let iter = input.iter().copied();
        let outcome = iter.read_varint_partial().unwrap();
        let varint = match outcome {
            DecodeOutcome::Complete(v) => v,
            _ => panic!("expected Complete"),
        };
        assert_eq!(varint.to_uint64(), 150);
    }

    #[test]
    fn test_iterator_ext_varint_trait() {
        let bytes = vec![0x96, 0x01];
        let iter = bytes.into_iter();
        let outcome = iter.read_varint_partial().unwrap();
        let varint = match outcome {
            DecodeOutcome::Complete(v) => v,
            _ => panic!("expected Complete"),
        };
        assert_eq!(varint.to_uint64(), 150);
    }

    #[test]
    fn test_iterator_ext_varint_empty() {
        let outcome = IteratorExtVarint::read_varint_partial(::std::iter::empty()).unwrap();
        assert_eq!(outcome, DecodeOutcome::Empty);
    }

    #[test]
    fn test_iterator_ext_varint_read_varints() {
        let bytes = vec![0x96, 0x01, 0x7F, 0x01];
        let iter = bytes.into_iter();
        let varints: Vec<Varint> = iter.read_varints().collect::<Result<Vec<_>, _>>().unwrap();
        assert_eq!(varints.len(), 3);
        assert_eq!(varints[0].to_uint64(), 150);
        assert_eq!(varints[1].to_uint64(), 127);
        assert_eq!(varints[2].to_uint64(), 1);
    }

    #[test]
    fn test_iterator_ext_varint_incomplete() {
        let bytes = vec![0x80u8];
        let result = IteratorExtVarint::read_varint_partial(bytes.into_iter());
        assert!(result.is_ok());
        assert!(matches!(
            result,
            Ok(DecodeOutcome::Incomplete(state)) if state.bytes_consumed() == 1
        ));
    }

    #[test]
    fn test_read_varint_wrappers_map_outcomes() {
        // Empty input is "no varint", not an error.
        assert!(matches!(
            IteratorExtVarint::read_varint(Vec::<u8>::new().into_iter()),
            Ok(None)
        ));
        // A truncated varint surfaces as UnexpectedEof from the non-partial API.
        assert!(matches!(
            IteratorExtVarint::read_varint(vec![0x80u8].into_iter()),
            Err(ProtobufError::UnexpectedEof)
        ));
        let varint = IteratorExtVarint::read_varint(vec![0x96u8, 0x01].into_iter())
            .unwrap()
            .expect("expected a varint");
        assert_eq!(varint.to_uint64(), 150);
    }

    #[test]
    fn test_read_varint_max_u64() {
        // Nine all-ones payload groups plus a final 0x01 encode exactly u64::MAX.
        let bytes = [0xFFu8, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01];
        let varint = IteratorExtVarint::read_varint(bytes.iter().copied())
            .unwrap()
            .expect("expected a varint");
        assert_eq!(varint.to_uint64(), u64::MAX);
    }

    #[test]
    fn test_varint_too_long_is_rejected() {
        // Every byte keeps the continuation bit set, so the varint never terminates.
        let bytes = [0x80u8; 11];
        let result = IteratorExtVarint::read_varint_partial(bytes.iter().copied());
        assert!(matches!(result, Err(ProtobufError::VarintTooLong)));
    }

    #[test]
    fn test_read_varints_truncated_tail() {
        let mut varints = vec![0x96u8, 0x01, 0x80].into_iter().read_varints();
        assert_eq!(varints.next().unwrap().unwrap().to_uint64(), 150);
        assert!(matches!(
            varints.next(),
            Some(Err(ProtobufError::UnexpectedEof))
        ));
        assert!(varints.next().is_none());
    }

    #[test]
    fn test_read_ext_varint_from_slice() {
        let mut slice = &[0x96u8, 0x01][..];
        let outcome = slice.read_varint_partial().unwrap();
        let varint = match outcome {
            DecodeOutcome::Complete(v) => v,
            _ => panic!("expected Complete"),
        };
        assert_eq!(varint.to_uint64(), 150);
    }

    #[test]
    fn test_read_ext_varint_from_empty_slice() {
        let mut slice: &[u8] = &[];
        let outcome = slice.read_varint_partial().unwrap();
        assert_eq!(outcome, DecodeOutcome::Empty);
    }

    #[test]
    fn test_read_ext_varint_from_slice_incomplete() {
        let mut slice = &[0x80u8][..];
        let result = slice.read_varint_partial();
        assert!(result.is_ok());
        assert!(matches!(result, Ok(DecodeOutcome::Incomplete(_))));
    }

    #[test]
    fn test_read_ext_varint_does_not_over_read() {
        // Only the varint's own bytes may be consumed; the trailing byte must remain.
        let mut slice = &[0x96u8, 0x01, 0x05][..];
        let varint = ReadExtVarint::read_varint(&mut slice)
            .unwrap()
            .expect("expected a varint");
        assert_eq!(varint.to_uint64(), 150);
        assert_eq!(slice, &[0x05u8][..]);
    }

    #[test]
    fn test_decode_state_feed_incomplete() {
        use ::std::convert::Infallible;

        let slice = &[0x80u8][..];
        let result = DecodeState::new().feed(slice.iter().copied().map(Ok::<u8, Infallible>));
        assert!(result.is_ok());
        assert!(matches!(result, Ok(DecodeOutcome::Incomplete(_))));
    }

    #[test]
    fn test_decode_state_feed_resume_complete() {
        use ::std::convert::Infallible;

        let buf1 = &[0x80u8][..];
        let Ok(DecodeOutcome::Incomplete(state)) =
            DecodeState::new().feed(buf1.iter().copied().map(Ok::<u8, Infallible>))
        else {
            panic!("Expected Incomplete");
        };

        let buf2 = &[0x01u8][..];
        let Ok(DecodeOutcome::Complete(varint)) =
            state.feed(buf2.iter().copied().map(Ok::<u8, Infallible>))
        else {
            panic!("Expected Complete");
        };
        assert_eq!(varint.to_uint64(), 128);
    }

    #[test]
    fn test_decode_state_feed_resume_incomplete() {
        use ::std::convert::Infallible;

        let buf1 = &[0x80u8][..];
        let Ok(DecodeOutcome::Incomplete(state)) =
            DecodeState::new().feed(buf1.iter().copied().map(Ok::<u8, Infallible>))
        else {
            panic!("Expected Incomplete");
        };

        let buf2 = &[0x80u8][..];
        let result = state.feed(buf2.iter().copied().map(Ok::<u8, Infallible>));
        assert!(result.is_ok());
        assert!(matches!(result, Ok(DecodeOutcome::Incomplete(_))));
    }

    #[test]
    fn test_try_iterator_ext_varint() {
        use ::std::io::{Cursor, Read};

        let data = vec![0x96, 0x01];
        let reader = Cursor::new(data);
        let iter = reader.bytes();
        let outcome = TryIteratorExtVarint::read_varint_partial(iter).unwrap();
        let varint = match outcome {
            DecodeOutcome::Complete(v) => v,
            _ => panic!("expected Complete"),
        };
        assert_eq!(varint.to_uint64(), 150);
    }

    #[test]
    fn test_try_iterator_ext_varint_empty() {
        use ::std::io::{Cursor, Read};

        let data = vec![];
        let reader = Cursor::new(data);
        let iter = reader.bytes();
        let outcome = TryIteratorExtVarint::read_varint_partial(iter).unwrap();
        assert_eq!(outcome, DecodeOutcome::Empty);
    }

    #[test]
    fn test_try_iterator_ext_varint_error() {
        use ::std::io::ErrorKind;

        let error = ::std::io::Error::new(ErrorKind::UnexpectedEof, "test error");
        let iter = ::std::iter::once(Err(error));
        let result = TryIteratorExtVarint::read_varint_partial(iter);

        assert!(result.is_err());
        if let Err(ProtobufError::IoError(io_err)) = result {
            assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof);
        } else {
            panic!("Expected IoError");
        }
    }

    #[test]
    fn test_try_iterator_ext_varint_read_varints() {
        use ::std::io::{Cursor, Read};

        let data = vec![0x96, 0x01, 0x7F, 0x01];
        let reader = Cursor::new(data);
        let iter = reader.bytes();
        let varints: Vec<Varint> = iter.read_varints().collect::<Result<Vec<_>, _>>().unwrap();
        assert_eq!(varints.len(), 3);
        assert_eq!(varints[0].to_uint64(), 150);
        assert_eq!(varints[1].to_uint64(), 127);
        assert_eq!(varints[2].to_uint64(), 1);
    }

    #[test]
    fn test_read_ext_varint_trait() {
        use ::std::io::Cursor;

        let input = [0x96, 0x01];
        let mut reader = Cursor::new(input);
        let outcome = reader.read_varint_partial().unwrap();
        let varint = match outcome {
            DecodeOutcome::Complete(v) => v,
            _ => panic!("expected Complete"),
        };
        assert_eq!(varint.to_uint64(), 150);
    }

    #[test]
    fn test_read_ext_varint_read_varints() {
        use ::std::io::Cursor;

        let data = vec![0x96, 0x01, 0x7F, 0x01];
        let mut reader = Cursor::new(data);
        let varints: Vec<Varint> = reader
            .read_varints()
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        assert_eq!(varints.len(), 3);
        assert_eq!(varints[0].to_uint64(), 150);
        assert_eq!(varints[1].to_uint64(), 127);
        assert_eq!(varints[2].to_uint64(), 1);
    }

    #[test]
    fn test_write_varint_roundtrip() {
        use crate::varint::WriteExtVarint;

        let test_values = vec![0, 1, 127, 128, 150, 255, 256, 65535, 0x7FFFFFFF];

        for &value in &test_values {
            let varint = Varint::from_uint64(value);

            let mut buffer = Vec::new();
            buffer.write_varint(&varint).unwrap();

            let iter = buffer.iter().copied();
            let outcome = iter.read_varint_partial().unwrap();
            let decoded_varint = match outcome {
                DecodeOutcome::Complete(v) => v,
                _ => panic!("expected Complete"),
            };
            let decoded_value = decoded_varint.to_uint64();

            assert_eq!(decoded_value, value, "Roundtrip failed for value {}", value);
        }
    }
}