miden_serde_utils/
byte_reader.rs

1// Copyright (c) Facebook, Inc. and its affiliates.
2//
3// This source code is licensed under the MIT license found in the
4// LICENSE file in the root directory of this source tree.
5
6#[cfg(feature = "std")]
7use alloc::string::ToString;
8use alloc::{format, string::String, vec::Vec};
9#[cfg(feature = "std")]
10use core::cell::{Ref, RefCell};
11#[cfg(feature = "std")]
12use std::io::BufRead;
13
14use crate::{Deserializable, DeserializationError};
15
16// BYTE READER TRAIT
17// ================================================================================================
18
19/// Defines how primitive values are to be read from `Self`.
20///
21/// Whenever data is read from the reader using any of the `read_*` functions, the reader advances
22/// to the next unread byte. If the error occurs, the reader is not rolled back to the state prior
23/// to calling any of the function.
24pub trait ByteReader {
25    // REQUIRED METHODS
26    // --------------------------------------------------------------------------------------------
27
28    /// Returns a single byte read from `self`.
29    ///
30    /// # Errors
31    /// Returns a [DeserializationError] error the reader is at EOF.
32    fn read_u8(&mut self) -> Result<u8, DeserializationError>;
33
34    /// Returns the next byte to be read from `self` without advancing the reader to the next byte.
35    ///
36    /// # Errors
37    /// Returns a [DeserializationError] error the reader is at EOF.
38    fn peek_u8(&self) -> Result<u8, DeserializationError>;
39
40    /// Returns a slice of bytes of the specified length read from `self`.
41    ///
42    /// # Errors
43    /// Returns a [DeserializationError] if a slice of the specified length could not be read
44    /// from `self`.
45    fn read_slice(&mut self, len: usize) -> Result<&[u8], DeserializationError>;
46
47    /// Returns a byte array of length `N` read from `self`.
48    ///
49    /// # Errors
50    /// Returns a [DeserializationError] if an array of the specified length could not be read
51    /// from `self`.
52    fn read_array<const N: usize>(&mut self) -> Result<[u8; N], DeserializationError>;
53
54    /// Checks if it is possible to read at least `num_bytes` bytes from this ByteReader
55    ///
56    /// # Errors
57    /// Returns an error if, when reading the requested number of bytes, we go beyond the
58    /// the data available in the reader.
59    fn check_eor(&self, num_bytes: usize) -> Result<(), DeserializationError>;
60
61    /// Returns true if there are more bytes left to be read from `self`.
62    fn has_more_bytes(&self) -> bool;
63
64    /// Returns the maximum number of elements that can be safely allocated, given each
65    /// element occupies `element_size` bytes when serialized.
66    ///
67    /// This can be used by callers to pre-validate collection lengths before iterating,
68    /// preventing denial-of-service attacks from malicious length prefixes that claim
69    /// billions of elements.
70    ///
71    /// The default implementation returns `usize::MAX`, meaning no limit is enforced.
72    /// [`BudgetedReader`] overrides this to return `remaining_budget / element_size`,
73    /// providing tight, adaptive limits based on the caller's budget.
74    ///
75    /// # Arguments
76    /// * `element_size` - The serialized size of one element, from
77    ///   [`Deserializable::min_serialized_size`]. Defaults to `size_of::<D>()` but can be
78    ///   overridden for types where serialized size differs from in-memory size.
79    fn max_alloc(&self, _element_size: usize) -> usize {
80        usize::MAX
81    }
82
83    // PROVIDED METHODS
84    // --------------------------------------------------------------------------------------------
85
86    /// Returns a boolean value read from `self` consuming 1 byte from the reader.
87    ///
88    /// # Errors
89    /// Returns a [DeserializationError] if a u16 value could not be read from `self`.
90    fn read_bool(&mut self) -> Result<bool, DeserializationError> {
91        let byte = self.read_u8()?;
92        match byte {
93            0 => Ok(false),
94            1 => Ok(true),
95            _ => Err(DeserializationError::InvalidValue(format!("{byte} is not a boolean value"))),
96        }
97    }
98
99    /// Returns a u16 value read from `self` in little-endian byte order.
100    ///
101    /// # Errors
102    /// Returns a [DeserializationError] if a u16 value could not be read from `self`.
103    fn read_u16(&mut self) -> Result<u16, DeserializationError> {
104        let bytes = self.read_array::<2>()?;
105        Ok(u16::from_le_bytes(bytes))
106    }
107
108    /// Returns a u32 value read from `self` in little-endian byte order.
109    ///
110    /// # Errors
111    /// Returns a [DeserializationError] if a u32 value could not be read from `self`.
112    fn read_u32(&mut self) -> Result<u32, DeserializationError> {
113        let bytes = self.read_array::<4>()?;
114        Ok(u32::from_le_bytes(bytes))
115    }
116
117    /// Returns a u64 value read from `self` in little-endian byte order.
118    ///
119    /// # Errors
120    /// Returns a [DeserializationError] if a u64 value could not be read from `self`.
121    fn read_u64(&mut self) -> Result<u64, DeserializationError> {
122        let bytes = self.read_array::<8>()?;
123        Ok(u64::from_le_bytes(bytes))
124    }
125
126    /// Returns a u128 value read from `self` in little-endian byte order.
127    ///
128    /// # Errors
129    /// Returns a [DeserializationError] if a u128 value could not be read from `self`.
130    fn read_u128(&mut self) -> Result<u128, DeserializationError> {
131        let bytes = self.read_array::<16>()?;
132        Ok(u128::from_le_bytes(bytes))
133    }
134
135    /// Returns a usize value read from `self` in [vint64](https://docs.rs/vint64/latest/vint64/)
136    /// format.
137    ///
138    /// # Errors
139    /// Returns a [DeserializationError] if:
140    /// * usize value could not be read from `self`.
141    /// * encoded value is greater than `usize` maximum value on a given platform.
142    fn read_usize(&mut self) -> Result<usize, DeserializationError> {
143        let first_byte = self.peek_u8()?;
144        let length = first_byte.trailing_zeros() as usize + 1;
145
146        let result = if length == 9 {
147            // 9-byte special case
148            self.read_u8()?;
149            let value = self.read_array::<8>()?;
150            u64::from_le_bytes(value)
151        } else {
152            let mut encoded = [0u8; 8];
153            let value = self.read_slice(length)?;
154            encoded[..length].copy_from_slice(value);
155            u64::from_le_bytes(encoded) >> length
156        };
157
158        // check if the result value is within acceptable bounds for `usize` on a given platform
159        if result > usize::MAX as u64 {
160            return Err(DeserializationError::InvalidValue(format!(
161                "Encoded value must be less than {}, but {} was provided",
162                usize::MAX,
163                result
164            )));
165        }
166
167        Ok(result as usize)
168    }
169
170    /// Returns a byte vector of the specified length read from `self`.
171    ///
172    /// # Errors
173    /// Returns a [DeserializationError] if a vector of the specified length could not be read
174    /// from `self`.
175    fn read_vec(&mut self, len: usize) -> Result<Vec<u8>, DeserializationError> {
176        let data = self.read_slice(len)?;
177        Ok(data.to_vec())
178    }
179
180    /// Returns a String of the specified length read from `self`.
181    ///
182    /// # Errors
183    /// Returns a [DeserializationError] if a String of the specified length could not be read
184    /// from `self`.
185    fn read_string(&mut self, num_bytes: usize) -> Result<String, DeserializationError> {
186        let data = self.read_vec(num_bytes)?;
187        String::from_utf8(data).map_err(|err| DeserializationError::InvalidValue(format!("{err}")))
188    }
189
190    /// Reads a deserializable value from `self`.
191    ///
192    /// # Errors
193    /// Returns a [DeserializationError] if the specified value could not be read from `self`.
194    fn read<D>(&mut self) -> Result<D, DeserializationError>
195    where
196        Self: Sized,
197        D: Deserializable,
198    {
199        D::read_from(self)
200    }
201
202    /// Returns an iterator that deserializes `num_elements` instances of `D` from this reader.
203    ///
204    /// This method validates the requested count against the reader's capacity before returning
205    /// the iterator, rejecting implausible lengths early. Each element is then deserialized
206    /// lazily as the iterator is consumed.
207    ///
208    /// # Errors
209    ///
210    /// Returns an error if `num_elements` exceeds `self.max_alloc(D::min_serialized_size())`,
211    /// indicating the reader cannot allocate that many elements.
212    ///
213    /// # Example
214    ///
215    /// ```ignore
216    /// // Collect into a Vec
217    /// let items: Vec<u64> = reader
218    ///     .read_many_iter::<u64>(count)?
219    ///     .collect::<Result<_, _>>()?;
220    ///
221    /// // Collect directly into a BTreeMap (no intermediate Vec)
222    /// let map: BTreeMap<K, V> = reader
223    ///     .read_many_iter::<(K, V)>(count)?
224    ///     .collect::<Result<_, _>>()?;
225    /// ```
226    fn read_many_iter<D>(
227        &mut self,
228        num_elements: usize,
229    ) -> Result<ReadManyIter<'_, Self, D>, DeserializationError>
230    where
231        Self: Sized,
232        D: Deserializable,
233    {
234        let max_elements = self.max_alloc(D::min_serialized_size());
235        if num_elements > max_elements {
236            return Err(DeserializationError::InvalidValue(format!(
237                "requested {num_elements} elements but reader can provide at most {max_elements}"
238            )));
239        }
240        Ok(ReadManyIter {
241            reader: self,
242            remaining: num_elements,
243            _item: core::marker::PhantomData,
244        })
245    }
246}
247
248// READ MANY ITERATOR
249// ================================================================================================
250
251/// Iterator that lazily deserializes elements from a [`ByteReader`].
252///
253/// Created by [`ByteReader::read_many_iter`]. Each call to `next()` deserializes one element.
254/// This avoids upfront allocation and naturally integrates with [`BudgetedReader`] for
255/// protection against malicious inputs.
256pub struct ReadManyIter<'reader, R: ByteReader, D: Deserializable> {
257    reader: &'reader mut R,
258    remaining: usize,
259    _item: core::marker::PhantomData<D>,
260}
261
262impl<'reader, R: ByteReader, D: Deserializable> Iterator for ReadManyIter<'reader, R, D> {
263    type Item = Result<D, DeserializationError>;
264
265    fn next(&mut self) -> Option<Self::Item> {
266        if self.remaining > 0 {
267            self.remaining -= 1;
268            Some(D::read_from(self.reader))
269        } else {
270            None
271        }
272    }
273
274    fn size_hint(&self) -> (usize, Option<usize>) {
275        (self.remaining, Some(self.remaining))
276    }
277}
278
279impl<'reader, R: ByteReader, D: Deserializable> ExactSizeIterator for ReadManyIter<'reader, R, D> {}
280
281// STANDARD LIBRARY ADAPTER
282// ================================================================================================
283
/// Bridges any implementation of [std::io::Read] to the [ByteReader] interface.
///
/// This makes it possible to deserialize directly from sources such as [std::fs::File] or
/// standard input.
#[cfg(feature = "std")]
pub struct ReadAdapter<'a> {
    // Why a [RefCell]: the [ByteReader] trait only hands `&self` to `peek_u8`, `has_more_bytes`
    // and `check_eor`, yet answering them may require refilling the buffer of the underlying
    // [std::io::BufRead], and that needs `&mut`. Interior mutability lets those methods refill
    // the reader safely, with Rust's borrowing rules checked at runtime rather than at compile
    // time.
    //
    // The runtime check can only fail (panic) if a borrow of the reader's internal buffer is
    // still alive at the moment one of those methods tries to refill it.
    //
    // This is a stop-gap: the proper fix is to align the [ByteReader] trait with the standard
    // library I/O traits.
    reader: RefCell<std::io::BufReader<&'a mut dyn std::io::Read>>,
    // Local staging area for bytes pulled out of `reader` when a request cannot be served from
    // the reader's own buffer alone. Reads are served from `reader` directly whenever possible.
    buf: Vec<u8>,
    // Index in `buf` of the next unread byte; the unread data is `buf[pos..]`.
    pos: usize,
    // Set once `reader` has been observed to be exhausted, meaning that after `buf` runs dry
    // there is truly no more input. `check_eor` relies on it to give a more precise answer.
    guaranteed_eof: bool,
}
321
#[cfg(feature = "std")]
impl<'a> ReadAdapter<'a> {
    /// Create a new [ByteReader] adapter for the given implementation of [std::io::Read]
    pub fn new(reader: &'a mut dyn std::io::Read) -> Self {
        Self {
            reader: RefCell::new(std::io::BufReader::with_capacity(256, reader)),
            buf: Default::default(),
            pos: 0,
            guaranteed_eof: false,
        }
    }

    /// Get the unread part of the internal adapter buffer as a (possibly empty) slice of bytes
    #[inline(always)]
    fn buffer(&self) -> &[u8] {
        self.buf.get(self.pos..).unwrap_or(&[])
    }

    /// Get the unread part of the internal adapter buffer, or `None` if it is empty
    #[inline(always)]
    fn non_empty_buffer(&self) -> Option<&[u8]> {
        self.buf.get(self.pos..).filter(|b| !b.is_empty())
    }

    /// Return the current reader buffer as a (possibly empty) slice of bytes.
    ///
    /// This buffer being empty _does not_ mean we're at EOF; you must call
    /// [Self::non_empty_reader_buffer_mut] (or [Self::non_empty_reader_buffer]) first.
    #[inline(always)]
    fn reader_buffer(&self) -> Ref<'_, [u8]> {
        Ref::map(self.reader.borrow(), |r| r.buffer())
    }

    /// Converts an error reported by the underlying reader into a [DeserializationError].
    ///
    /// `UnexpectedEof` becomes [DeserializationError::UnexpectedEOF]; any other error becomes
    /// [DeserializationError::UnknownError] carrying the description of the error kind.
    fn map_io_error(err: std::io::Error) -> DeserializationError {
        match err.kind() {
            std::io::ErrorKind::UnexpectedEof => DeserializationError::UnexpectedEOF,
            kind => DeserializationError::UnknownError(kind.to_string()),
        }
    }

    /// Return the current reader buffer, reading from the underlying reader
    /// if the buffer is empty.
    ///
    /// Returns `Ok` only if the buffer is non-empty, and no errors occurred
    /// while filling it (if filling was needed). Records `guaranteed_eof` when the underlying
    /// reader turns out to be exhausted.
    fn non_empty_reader_buffer_mut(&mut self) -> Result<&[u8], DeserializationError> {
        let available = match self.reader.get_mut().fill_buf() {
            Ok(chunk) => chunk.len(),
            Err(err) => {
                if err.kind() == std::io::ErrorKind::UnexpectedEof {
                    self.guaranteed_eof = true;
                }
                return Err(Self::map_io_error(err));
            },
        };
        if available == 0 {
            self.guaranteed_eof = true;
            return Err(DeserializationError::UnexpectedEOF);
        }
        Ok(self.reader.get_mut().buffer())
    }

    /// Same as [Self::non_empty_reader_buffer_mut], but with dynamically-enforced
    /// borrow check rules so that it can be called in functions like `peek_u8`.
    ///
    /// This comes with overhead for the dynamic checks, so you should prefer
    /// to call [Self::non_empty_reader_buffer_mut] if you already have a mutable
    /// reference to `self`. Unlike that method, this one cannot record `guaranteed_eof`.
    fn non_empty_reader_buffer(&self) -> Result<Ref<'_, [u8]>, DeserializationError> {
        // The mutable RefCell borrow is released at the end of this statement, before the
        // immutable re-borrow below.
        let available = self.reader.borrow_mut().fill_buf().map_err(Self::map_io_error)?.len();
        if available == 0 {
            return Err(DeserializationError::UnexpectedEOF);
        }
        Ok(self.reader_buffer())
    }

    /// Returns true if `buf`'s capacity, minus the unread bytes it currently holds, is at
    /// least `n`.
    #[inline]
    fn has_remaining_capacity(&self, n: usize) -> bool {
        let remaining = self.buf.capacity() - self.buffer().len();
        remaining >= n
    }

    /// Empties the local buffer once all of its bytes have been consumed, so that `pos` always
    /// indexes into `buf` and later reads fill it from the start again.
    #[inline]
    fn reset_if_drained(&mut self) {
        if self.pos > 0 && self.pos >= self.buf.len() {
            self.buf.clear();
            self.pos = 0;
        }
    }

    /// Takes the next byte from the input, returning an error if the operation fails
    fn pop(&mut self) -> Result<u8, DeserializationError> {
        if let Some(byte) = self.non_empty_buffer().map(|b| b[0]) {
            self.pos += 1;
            self.reset_if_drained();
            return Ok(byte);
        }
        // `non_empty_reader_buffer_mut` records `guaranteed_eof` itself when the input is
        // exhausted; other I/O errors must not be mistaken for EOF.
        let byte = self.non_empty_reader_buffer_mut()?[0];
        self.reader.get_mut().consume(1);
        Ok(byte)
    }

    /// Takes the next `N` bytes from the input as an array, returning an error if the operation
    /// fails.
    ///
    /// Bytes already staged in `buf` are used first, then the reader's buffer. If together they
    /// hold fewer than `N` bytes, more input is pulled from the underlying reader, because a short
    /// reader buffer does not by itself mean end-of-file.
    fn read_exact<const N: usize>(&mut self) -> Result<[u8; N], DeserializationError> {
        let mut output = [0u8; N];
        let buffered = self.buffer().len();

        if buffered >= N {
            // Everything we need is already staged locally.
            output.copy_from_slice(&self.buffer()[..N]);
            self.pos += N;
        } else {
            // Errors if the underlying reader is already exhausted, in which case the request
            // cannot be satisfied.
            let available = self.non_empty_reader_buffer_mut()?.len();
            if buffered + available >= N {
                // Local bytes first, then the remainder straight from the reader's buffer.
                let needed = N - buffered;
                output[..buffered].copy_from_slice(self.buffer());
                output[buffered..].copy_from_slice(&self.reader.get_mut().buffer()[..needed]);
                self.pos += buffered;
                self.reader.get_mut().consume(needed);
            } else {
                // Not enough in either buffer yet: stage input locally until `N` bytes are
                // available, or EOF is reached.
                self.buffer_at_least(N)?;
                output.copy_from_slice(&self.buffer()[..N]);
                self.pos += N;
            }
        }

        self.reset_if_drained();
        Ok(output)
    }

    /// Pulls data from the underlying reader into `self.buf` until it holds at least `count`
    /// unread bytes.
    ///
    /// Whole chunks of the reader's buffer are moved at a time, so more than `count` bytes may be
    /// staged. Before pulling anything in, the already-consumed prefix of `self.buf` is discarded
    /// so that the buffer does not keep growing over a long stream.
    ///
    /// This should only be called when we can't read from the reader directly.
    ///
    /// # Errors
    /// Returns an error if the underlying reader fails, or reaches EOF before `count` unread
    /// bytes are available.
    fn buffer_at_least(&mut self, count: usize) -> Result<(), DeserializationError> {
        if self.buffer().len() >= count {
            return Ok(());
        }

        // Discard consumed bytes; the unread bytes (`self.buffer()`) are preserved.
        if self.pos > 0 {
            let consumed = self.pos.min(self.buf.len());
            self.buf.drain(..consumed);
            self.pos = 0;
        }

        // `pos` is now 0, so `self.buf.len()` is exactly the number of unread bytes.
        while self.buf.len() < count {
            // This operation will return an error if the underlying reader hits EOF
            self.non_empty_reader_buffer_mut()?;

            let reader = self.reader.get_mut();
            let chunk = reader.buffer();
            let taken = chunk.len();
            self.buf.extend_from_slice(chunk);
            reader.consume(taken);
        }
        Ok(())
    }
}
539
#[cfg(feature = "std")]
impl ByteReader for ReadAdapter<'_> {
    #[inline(always)]
    fn read_u8(&mut self) -> Result<u8, DeserializationError> {
        self.pop()
    }

    /// NOTE: If we happen to not have any bytes buffered yet when this is called, then we will be
    /// forced to try and read from the underlying reader. This requires a mutable reference, which
    /// is obtained dynamically via [RefCell].
    ///
    /// <div class="warning">
    /// Callers must ensure that they do not hold any immutable references to the buffer of this
    /// reader when calling this function so as to avoid a situation in which the dynamic borrow
    /// check fails. Specifically, you must not be holding a reference to the result of
    /// [Self::read_slice] when this function is called.
    /// </div>
    fn peek_u8(&self) -> Result<u8, DeserializationError> {
        if let Some(byte) = self.buffer().first() {
            return Ok(*byte);
        }
        self.non_empty_reader_buffer().map(|b| b[0])
    }

    /// Returns the next `len` bytes as a slice of the adapter's local buffer.
    ///
    /// The bytes are first staged in the local buffer, so the returned slice never points into
    /// the underlying reader's own buffer.
    fn read_slice(&mut self, len: usize) -> Result<&[u8], DeserializationError> {
        // Edge case
        if len == 0 {
            return Ok(&[]);
        }

        // If a sizeable prefix of the local buffer has already been consumed and
        // `has_remaining_capacity` reports too little room for `len` bytes, move the unread bytes
        // to the front of the buffer before reading more. `drain` shifts them in place; no
        // `unsafe` pointer copying is needed.
        if self.pos >= 16 && !self.has_remaining_capacity(len) {
            let consumed = self.pos.min(self.buf.len());
            self.buf.drain(..consumed);
            self.pos = 0;
        }

        // Fill the buffer so we have at least `len` bytes available,
        // this will return an error if we hit EOF first
        self.buffer_at_least(len)?;

        // Checked slicing: if fewer than `len` bytes ended up staged, report EOF instead of
        // panicking on an out-of-range index.
        let start = self.pos;
        let end = start.checked_add(len).ok_or(DeserializationError::UnexpectedEOF)?;
        let slice = self.buf.get(start..end).ok_or(DeserializationError::UnexpectedEOF)?;
        self.pos = end;
        Ok(slice)
    }

    #[inline]
    fn read_array<const N: usize>(&mut self) -> Result<[u8; N], DeserializationError> {
        if N == 0 {
            return Ok([0; N]);
        }
        self.read_exact()
    }

    fn check_eor(&self, num_bytes: usize) -> Result<(), DeserializationError> {
        // Do we have sufficient data in the local buffer?
        let buffer_len = self.buffer().len();
        if buffer_len >= num_bytes {
            return Ok(());
        }

        // What about if we include what is in the local buffer and the reader's buffer?
        let reader_buffer_len = self.non_empty_reader_buffer().map(|b| b.len())?;
        let buffer_len = buffer_len + reader_buffer_len;
        if buffer_len >= num_bytes {
            return Ok(());
        }

        // We have no more input, thus can't fulfill a request of `num_bytes`
        if self.guaranteed_eof {
            return Err(DeserializationError::UnexpectedEOF);
        }

        // Because this function is read-only, we must optimistically assume we can read `num_bytes`
        // from the input, and fail later if that does not hold. We know we're not at EOF yet, but
        // that's all we can say without buffering more from the reader. We could make use of
        // `buffer_at_least`, which would guarantee a correct result, but it would also impose
        // additional restrictions on the use of this function, e.g. not using it while holding a
        // reference returned from `read_slice`. Since it is not a memory safety violation to return
        // an optimistic result here, it makes for a better tradeoff.
        Ok(())
    }

    #[inline]
    fn has_more_bytes(&self) -> bool {
        !self.buffer().is_empty() || self.non_empty_reader_buffer().is_ok()
    }
}
639
640// CURSOR
641// ================================================================================================
642
#[cfg(feature = "std")]
macro_rules! cursor_remaining_buf {
    // Expands to the not-yet-read tail of a `std::io::Cursor`'s data; the result is empty when
    // the cursor position is at or beyond the end of the data.
    ($cursor:ident) => {{
        let data: &[u8] = $cursor.get_ref().as_ref();
        let consumed = $cursor.position();
        if consumed < data.len() as u64 {
            &data[consumed as usize..]
        } else {
            &data[data.len()..]
        }
    }};
}
651
#[cfg(feature = "std")]
impl<T: AsRef<[u8]>> ByteReader for std::io::Cursor<T> {
    fn read_u8(&mut self) -> Result<u8, DeserializationError> {
        match cursor_remaining_buf!(self).first().copied() {
            None => Err(DeserializationError::UnexpectedEOF),
            Some(byte) => {
                self.set_position(self.position() + 1);
                Ok(byte)
            },
        }
    }

    fn peek_u8(&self) -> Result<u8, DeserializationError> {
        match cursor_remaining_buf!(self) {
            [head, ..] => Ok(*head),
            [] => Err(DeserializationError::UnexpectedEOF),
        }
    }

    fn read_slice(&mut self, len: usize) -> Result<&[u8], DeserializationError> {
        let start = self.position();
        let total = self.get_ref().as_ref().len() as u64;
        let available = total.saturating_sub(start);
        if available < len as u64 {
            return Err(DeserializationError::UnexpectedEOF);
        }
        // Advance first; the returned slice is taken from the old position.
        self.set_position(start + len as u64);
        let offset = start.min(total) as usize;
        Ok(&self.get_ref().as_ref()[offset..(offset + len)])
    }

    fn read_array<const N: usize>(&mut self) -> Result<[u8; N], DeserializationError> {
        let bytes = self.read_slice(N)?;
        let mut out = [0u8; N];
        out.copy_from_slice(bytes);
        Ok(out)
    }

    fn check_eor(&self, num_bytes: usize) -> Result<(), DeserializationError> {
        let remaining = cursor_remaining_buf!(self).len();
        if remaining < num_bytes {
            return Err(DeserializationError::UnexpectedEOF);
        }
        Ok(())
    }

    #[inline]
    fn has_more_bytes(&self) -> bool {
        !cursor_remaining_buf!(self).is_empty()
    }
}
707
708// SLICE READER
709// ================================================================================================
710
/// A [ByteReader] over an in-memory slice of bytes.
///
/// NOTE: If you are building with the `std` feature, you should probably prefer [std::io::Cursor]
/// instead. However, [SliceReader] is still useful in no-std environments until stabilization of
/// the `core_io_borrowed_buf` feature.
pub struct SliceReader<'a> {
    /// The bytes being read.
    source: &'a [u8],
    /// Index of the next unread byte in `source`.
    pos: usize,
}

impl<'a> SliceReader<'a> {
    /// Creates a reader positioned at the first byte of `source`.
    pub fn new(source: &'a [u8]) -> Self {
        Self { pos: 0, source }
    }
}
727
728impl ByteReader for SliceReader<'_> {
729    fn read_u8(&mut self) -> Result<u8, DeserializationError> {
730        self.check_eor(1)?;
731        let result = self.source[self.pos];
732        self.pos += 1;
733        Ok(result)
734    }
735
736    fn peek_u8(&self) -> Result<u8, DeserializationError> {
737        self.check_eor(1)?;
738        Ok(self.source[self.pos])
739    }
740
741    fn read_slice(&mut self, len: usize) -> Result<&[u8], DeserializationError> {
742        self.check_eor(len)?;
743        let result = &self.source[self.pos..self.pos + len];
744        self.pos += len;
745        Ok(result)
746    }
747
748    fn read_array<const N: usize>(&mut self) -> Result<[u8; N], DeserializationError> {
749        self.check_eor(N)?;
750        let mut result = [0_u8; N];
751        result.copy_from_slice(&self.source[self.pos..self.pos + N]);
752        self.pos += N;
753        Ok(result)
754    }
755
756    fn check_eor(&self, num_bytes: usize) -> Result<(), DeserializationError> {
757        if self.pos + num_bytes > self.source.len() {
758            return Err(DeserializationError::UnexpectedEOF);
759        }
760        Ok(())
761    }
762
763    fn has_more_bytes(&self) -> bool {
764        self.pos < self.source.len()
765    }
766}
767
768// BUDGETED READER
769// ================================================================================================
770
771/// A reader wrapper that enforces a byte budget during deserialization.
772///
773/// # Threat Model
774///
775/// Malicious input can attack deserialization in two ways:
776///
777/// 1. **Fake length prefix**: Input claims `len = 2^60` elements, causing allocation of a huge
778///    `Vec` before any data is read.
779/// 2. **Oversized input**: Attacker sends gigabytes of valid-looking data to exhaust memory over
780///    time.
781///
782/// # Defense Strategy
783///
784/// Use `BudgetedReader` to limit total bytes consumed. Its [`max_alloc`](ByteReader::max_alloc)
785/// method derives a bound from the remaining budget, which
786/// [`read_many_iter`](ByteReader::read_many_iter) checks before iterating.
787///
788/// ## Problem: SliceReader alone doesn't bound allocations
789///
790/// ```
791/// use miden_serde_utils::{ByteReader, Deserializable, SliceReader};
792///
793/// // Malicious input: length prefix says 1 billion u64s, but only 16 bytes of data
794/// let mut data = Vec::new();
795/// data.push(0u8); // vint64 9-byte marker
796/// data.extend_from_slice(&1_000_000_000u64.to_le_bytes());
797/// data.extend_from_slice(&[0u8; 16]);
798///
799/// // SliceReader returns usize::MAX from max_alloc, so read_many_iter accepts
800/// // any length. This would try to iterate 1 billion times (slow, not OOM,
801/// // but still a DoS vector).
802/// let reader = SliceReader::new(&data);
803/// assert_eq!(reader.max_alloc(8), usize::MAX);
804/// ```
805///
806/// ## Solution: BudgetedReader bounds allocations via max_alloc
807///
808/// ```
809/// use miden_serde_utils::{BudgetedReader, ByteReader, Deserializable, SliceReader};
810///
811/// // Same malicious input
812/// let mut data = Vec::new();
813/// data.push(0u8);
814/// data.extend_from_slice(&1_000_000_000u64.to_le_bytes());
815/// data.extend_from_slice(&[0u8; 16]);
816///
817/// // BudgetedReader with 64-byte budget: max_alloc(8) = 64/8 = 8 elements
818/// let inner = SliceReader::new(&data);
819/// let reader = BudgetedReader::new(inner, 64);
820/// assert_eq!(reader.max_alloc(8), 8);
821///
822/// // read_many_iter rejects the 1B length since 1B > 8
823/// let result = Vec::<u64>::read_from_bytes_with_budget(&data, 64);
824/// assert!(result.is_err());
825/// ```
826///
827/// ## Best practice: Set budget to expected input size
828///
829/// ```
830/// use miden_serde_utils::{ByteWriter, Deserializable, Serializable};
831///
832/// // Legitimate input: 3 u64s, properly serialized
833/// let original = vec![1u64, 2, 3];
834/// let mut data = Vec::new();
835/// original.write_into(&mut data);
836///
837/// // Budget = data.len() bounds both fake lengths and total consumption
838/// let result = Vec::<u64>::read_from_bytes_with_budget(&data, data.len());
839/// assert_eq!(result.unwrap(), vec![1, 2, 3]);
840/// ```
841pub struct BudgetedReader<R> {
842    inner: R,
843    remaining: usize,
844}
845
846impl<R> BudgetedReader<R> {
847    /// Wraps a reader with the specified byte budget.
848    pub fn new(inner: R, budget: usize) -> Self {
849        Self { inner, remaining: budget }
850    }
851
852    /// Returns remaining budget in bytes.
853    pub fn remaining(&self) -> usize {
854        self.remaining
855    }
856
857    /// Consumes budget, returning an error if insufficient.
858    fn consume_budget(&mut self, n: usize) -> Result<(), DeserializationError> {
859        if n > self.remaining {
860            return Err(DeserializationError::InvalidValue(format!(
861                "budget exhausted: requested {n} bytes, {} remaining",
862                self.remaining
863            )));
864        }
865        self.remaining -= n;
866        Ok(())
867    }
868}
869
870impl<R: ByteReader> ByteReader for BudgetedReader<R> {
871    fn read_u8(&mut self) -> Result<u8, DeserializationError> {
872        self.consume_budget(1)?;
873        self.inner.read_u8()
874    }
875
876    fn peek_u8(&self) -> Result<u8, DeserializationError> {
877        // peek doesn't consume budget since it doesn't advance the reader
878        self.inner.peek_u8()
879    }
880
881    fn read_slice(&mut self, len: usize) -> Result<&[u8], DeserializationError> {
882        self.consume_budget(len)?;
883        self.inner.read_slice(len)
884    }
885
886    fn read_array<const N: usize>(&mut self) -> Result<[u8; N], DeserializationError> {
887        self.consume_budget(N)?;
888        self.inner.read_array()
889    }
890
891    fn check_eor(&self, num_bytes: usize) -> Result<(), DeserializationError> {
892        // check budget first, then delegate
893        if num_bytes > self.remaining {
894            return Err(DeserializationError::InvalidValue(format!(
895                "budget exhausted: requested {num_bytes} bytes, {} remaining",
896                self.remaining
897            )));
898        }
899        self.inner.check_eor(num_bytes)
900    }
901
902    fn has_more_bytes(&self) -> bool {
903        self.remaining > 0 && self.inner.has_more_bytes()
904    }
905
906    fn max_alloc(&self, element_size: usize) -> usize {
907        if element_size == 0 {
908            return usize::MAX; // ZSTs don't consume budget
909        }
910        self.remaining / element_size
911    }
912}
913
#[cfg(all(test, feature = "std"))]
mod tests {
    // Unit tests for ReadAdapter, SliceReader and BudgetedReader, plus tests documenting the
    // threat model that BudgetedReader defends against.
    use std::io::Cursor;

    use super::*;
    use crate::ByteWriter;

    #[test]
    fn read_adapter_empty() -> Result<(), DeserializationError> {
        let mut reader = std::io::empty();
        let mut adapter = ReadAdapter::new(&mut reader);
        assert!(!adapter.has_more_bytes());
        assert_eq!(adapter.check_eor(8), Err(DeserializationError::UnexpectedEOF));
        assert_eq!(adapter.peek_u8(), Err(DeserializationError::UnexpectedEOF));
        assert_eq!(adapter.read_u8(), Err(DeserializationError::UnexpectedEOF));
        assert_eq!(adapter.read_slice(0), Ok([].as_slice()));
        assert_eq!(adapter.read_slice(1), Err(DeserializationError::UnexpectedEOF));
        assert_eq!(adapter.read_array(), Ok([]));
        assert_eq!(adapter.read_array::<1>(), Err(DeserializationError::UnexpectedEOF));
        Ok(())
    }

    #[test]
    fn read_adapter_passthrough() -> Result<(), DeserializationError> {
        let mut reader = std::io::repeat(0b101);
        let mut adapter = ReadAdapter::new(&mut reader);
        assert!(adapter.has_more_bytes());
        assert_eq!(adapter.check_eor(8), Ok(()));
        assert_eq!(adapter.peek_u8(), Ok(0b101));
        assert_eq!(adapter.read_u8(), Ok(0b101));
        assert_eq!(adapter.read_slice(0), Ok([].as_slice()));
        assert_eq!(adapter.read_slice(4), Ok([0b101, 0b101, 0b101, 0b101].as_slice()));
        assert_eq!(adapter.read_array(), Ok([]));
        assert_eq!(adapter.read_array(), Ok([0b101, 0b101]));
        Ok(())
    }

    #[test]
    fn read_adapter_exact() {
        const VALUE: usize = 2048;
        let mut reader = Cursor::new(VALUE.to_le_bytes());
        let mut adapter = ReadAdapter::new(&mut reader);
        assert_eq!(usize::from_le_bytes(adapter.read_array().unwrap()), VALUE);
        assert!(!adapter.has_more_bytes());
        assert_eq!(adapter.peek_u8(), Err(DeserializationError::UnexpectedEOF));
        assert_eq!(adapter.read_u8(), Err(DeserializationError::UnexpectedEOF));
    }

    #[test]
    fn read_adapter_roundtrip() {
        const VALUE: usize = 2048;

        // Write VALUE to storage
        let mut cursor = Cursor::new([0; core::mem::size_of::<usize>()]);
        cursor.write_usize(VALUE);

        // Read VALUE from storage
        cursor.set_position(0);
        let mut adapter = ReadAdapter::new(&mut cursor);

        assert_eq!(adapter.read_usize(), Ok(VALUE));
    }

    #[test]
    fn read_adapter_for_file() {
        use std::fs::File;

        let path = std::env::temp_dir().join("read_adapter_for_file.bin");

        // Encode some data to a buffer, then write that buffer to a file
        {
            let mut buf = Vec::<u8>::with_capacity(256);
            buf.write_bytes(b"MAGIC\0");
            buf.write_bool(true);
            buf.write_u32(0xbeef);
            buf.write_usize(0xfeed);
            buf.write_u16(0x5);

            std::fs::write(&path, &buf).unwrap();
        }

        // Open the file, and try to decode the encoded items
        let mut file = File::open(&path).unwrap();
        let mut reader = ReadAdapter::new(&mut file);
        assert_eq!(reader.peek_u8().unwrap(), b'M');
        assert_eq!(reader.read_slice(6).unwrap(), b"MAGIC\0");
        assert!(reader.read_bool().unwrap());
        assert_eq!(reader.read_u32().unwrap(), 0xbeef);
        assert_eq!(reader.read_usize().unwrap(), 0xfeed);
        assert_eq!(reader.read_u16().unwrap(), 0x5);
        assert!(!reader.has_more_bytes(), "expected there to be no more data in the input");
    }

    #[test]
    fn read_adapter_issue_383() {
        const STR_BYTES: &[u8] = b"just a string";

        use std::fs::File;

        let path = std::env::temp_dir().join("issue_383.bin");

        // Encode some data to a buffer, then write that buffer to a file
        {
            let mut buf = Vec::<u8>::with_capacity(1024);
            buf.write_u128(2 * u64::MAX as u128);
            // Zero-pad the buffer so that STR_BYTES starts at offset 512 in the file
            buf.resize(512, 0);
            buf.write_bytes(STR_BYTES);
            buf.write_u32(0xbeef);

            std::fs::write(&path, &buf).unwrap();
        }

        // Open the file, and try to decode the encoded items
        let mut file = File::open(&path).unwrap();
        let mut reader = ReadAdapter::new(&mut file);
        assert_eq!(reader.read_u128().unwrap(), 2 * u64::MAX as u128);
        assert_eq!(reader.buf.len(), 0);
        assert_eq!(reader.pos, 0);
        // Read to offset 512 (we're 16 bytes into the underlying file, i.e. offset of 496)
        reader.read_slice(496).unwrap();
        assert_eq!(reader.buf.len(), 496);
        assert_eq!(reader.pos, 496);
        // The byte string is 13 bytes, followed by 4 bytes containing the trailing u32 value.
        // We expect that the underlying reader will buffer the remaining bytes of the file when
        // reading STR_BYTES, so the total size of our adapter's buffer should be
        // 496 + STR_BYTES.len() + size_of::<u32>();
        assert_eq!(reader.read_slice(STR_BYTES.len()).unwrap(), STR_BYTES);
        assert_eq!(reader.buf.len(), 496 + STR_BYTES.len() + core::mem::size_of::<u32>());
        // We haven't read the u32 yet
        assert_eq!(reader.pos, 509);
        assert_eq!(reader.read_u32().unwrap(), 0xbeef);
        // Now we have
        assert_eq!(reader.pos, 513);
        assert!(!reader.has_more_bytes(), "expected there to be no more data in the input");
    }

    #[test]
    fn budgeted_reader_basic() {
        let data = [1u8, 2, 3, 4, 5, 6, 7, 8];
        let inner = SliceReader::new(&data);
        let mut reader = BudgetedReader::new(inner, 4);

        assert_eq!(reader.remaining(), 4);
        assert!(reader.has_more_bytes());

        // read 4 bytes (within budget)
        assert_eq!(reader.read_u32().unwrap(), 0x04030201);
        assert_eq!(reader.remaining(), 0);

        // budget exhausted
        assert!(!reader.has_more_bytes());
        assert!(reader.read_u8().is_err());
    }

    #[test]
    fn budgeted_reader_peek_does_not_consume() {
        let data = [42u8];
        let inner = SliceReader::new(&data);
        let mut reader = BudgetedReader::new(inner, 1);

        // peek multiple times, budget unchanged
        assert_eq!(reader.peek_u8().unwrap(), 42);
        assert_eq!(reader.peek_u8().unwrap(), 42);
        assert_eq!(reader.remaining(), 1);

        // actual read consumes budget
        assert_eq!(reader.read_u8().unwrap(), 42);
        assert_eq!(reader.remaining(), 0);
    }

    #[test]
    fn budgeted_reader_check_eor_respects_budget() {
        let data = [0u8; 100];
        let inner = SliceReader::new(&data);
        let reader = BudgetedReader::new(inner, 10);

        // within budget
        assert!(reader.check_eor(10).is_ok());

        // exceeds budget (even though inner has enough bytes)
        assert!(reader.check_eor(11).is_err());
    }

    #[test]
    fn budgeted_reader_read_slice() {
        let data = [1u8, 2, 3, 4, 5];
        let inner = SliceReader::new(&data);
        let mut reader = BudgetedReader::new(inner, 3);

        // read 3 bytes (exactly budget)
        assert_eq!(reader.read_slice(3).unwrap(), &[1, 2, 3]);
        assert_eq!(reader.remaining(), 0);

        // can't read more
        assert!(reader.read_slice(1).is_err());
    }

    #[test]
    fn budgeted_reader_read_array() {
        let data = [0xaau8, 0xbb, 0xcc, 0xdd];
        let inner = SliceReader::new(&data);
        let mut reader = BudgetedReader::new(inner, 2);

        // read 2-byte array
        assert_eq!(reader.read_array::<2>().unwrap(), [0xaa, 0xbb]);
        assert_eq!(reader.remaining(), 0);

        // budget exhausted
        assert!(reader.read_array::<2>().is_err());
    }

    #[test]
    fn budgeted_reader_zero_budget() {
        let data = [1u8];
        let inner = SliceReader::new(&data);
        let mut reader = BudgetedReader::new(inner, 0);

        assert!(!reader.has_more_bytes());
        assert!(reader.read_u8().is_err());
        // peek still works (doesn't consume budget)
        assert_eq!(reader.peek_u8().unwrap(), 1);
    }

    #[test]
    fn budgeted_reader_max_alloc() {
        let data = [0u8; 100];
        let inner = SliceReader::new(&data);
        let reader = BudgetedReader::new(inner, 64);

        // 64 bytes budget / 8 bytes per u64 = 8 elements max
        assert_eq!(reader.max_alloc(8), 8);

        // 64 bytes budget / 1 byte per u8 = 64 elements max
        assert_eq!(reader.max_alloc(1), 64);

        // 64 bytes budget / 16 bytes per u128 = 4 elements max
        assert_eq!(reader.max_alloc(16), 4);

        // ZSTs (0 bytes) return usize::MAX
        assert_eq!(reader.max_alloc(0), usize::MAX);
    }

    #[test]
    fn unbounded_reader_max_alloc_returns_max() {
        let data = [0u8; 100];
        let reader = SliceReader::new(&data);

        // Unbounded readers return usize::MAX
        assert_eq!(reader.max_alloc(1), usize::MAX);
        assert_eq!(reader.max_alloc(8), usize::MAX);
    }

    // ============================================================================================
    // The following tests document the threat model and defense layers.
    // ============================================================================================

    /// SliceReader alone does NOT reject fake length prefixes.
    ///
    /// A malicious input claiming 1000 elements will be accepted by read_many_iter
    /// because SliceReader.max_alloc() returns usize::MAX. The deserialization will
    /// eventually fail with UnexpectedEOF, but only after attempting to iterate
    /// (which could be slow for huge counts, though not OOM since we don't pre-allocate).
    #[test]
    fn slice_reader_accepts_fake_length_prefix() {
        let mut data = Vec::new();
        // Write length = 1000. Its compact vint64 form would take 2 bytes
        // (1000 << 2 | 0b10 = 0x0FA2); for simplicity, use the 9-byte form instead.
        data.push(0); // 9-byte marker
        data.extend_from_slice(&1000u64.to_le_bytes());
        // Only 8 bytes of actual u64 data (1 element, not 1000)
        data.extend_from_slice(&42u64.to_le_bytes());

        // read_many_iter passes the max_alloc check (usize::MAX >= 1000)
        let mut reader = SliceReader::new(&data);
        let _len = reader.read_usize().unwrap();
        let iter_result = reader.read_many_iter::<u64>(1000);

        // The iterator is created successfully
        assert!(iter_result.is_ok());

        // But collecting fails on the 2nd element (EOF)
        let collect_result: Result<Vec<u64>, _> = iter_result.unwrap().collect();
        assert!(collect_result.is_err());
        assert!(matches!(collect_result.unwrap_err(), DeserializationError::UnexpectedEOF));
    }

    /// BudgetedReader rejects fake length prefixes BEFORE iteration begins.
    ///
    /// With a 64-byte budget, max_alloc(8) = 8, so a claim of 1000 elements
    /// is rejected immediately by read_many_iter.
    #[test]
    fn budgeted_reader_rejects_fake_length_upfront() {
        let mut data = Vec::new();
        data.push(0); // 9-byte vint64 marker
        data.extend_from_slice(&1000u64.to_le_bytes());
        data.extend_from_slice(&42u64.to_le_bytes());

        let inner = SliceReader::new(&data);
        let mut reader = BudgetedReader::new(inner, 64);

        let _len = reader.read_usize().unwrap(); // consumes 9 bytes, 55 remaining
        // 55 / 8 = 6 elements max
        let iter_result = reader.read_many_iter::<u64>(1000);

        // Rejected immediately: 1000 > 6
        match iter_result {
            Err(DeserializationError::InvalidValue(_)) => {}, // expected
            other => panic!("expected InvalidValue error, got {:?}", other.map(|_| "Ok")),
        }
    }

    /// Best practice: budget = input length provides both protections.
    ///
    /// 1. Fake length prefixes are bounded by max_alloc (remaining_bytes / element_size)
    /// 2. Total consumption is bounded by the budget
    #[test]
    fn budget_equals_input_length_is_safe() {
        // Valid input: 2 u64s
        let original = vec![100u64, 200];
        let mut data = Vec::new();
        crate::Serializable::write_into(&original, &mut data);

        // Budget = exact input size
        let result = Vec::<u64>::read_from_bytes_with_budget(&data, data.len());
        assert_eq!(result.unwrap(), vec![100, 200]);

        // Malicious input claiming 1000 elements (same serialized prefix manipulation)
        let mut evil_data = Vec::new();
        evil_data.push(0); // 9-byte vint64
        evil_data.extend_from_slice(&1000u64.to_le_bytes());
        evil_data.extend_from_slice(&42u64.to_le_bytes()); // only 1 actual element

        // Budget = input length (17 bytes). After reading length (9 bytes), 8 remain.
        // max_alloc(8) = 8/8 = 1, so 1000 > 1 fails.
        let result = Vec::<u64>::read_from_bytes_with_budget(&evil_data, evil_data.len());
        assert!(result.is_err());
    }

    // ============================================================================================
    // Tests documenting min_serialized_size()-based allocation bounds (defaults to size_of)
    // ============================================================================================

    /// The max_alloc check uses D::min_serialized_size() to bound memory allocation.
    /// By default, min_serialized_size() returns size_of::<D>().
    ///
    /// For flat collections like Vec<u64>, this works well: we check that
    /// budget / min_serialized_size() >= requested_count before allocating.
    #[test]
    fn min_serialized_size_bounds_flat_collections() {
        let mut data = Vec::new();
        data.push(0); // 9-byte vint64 marker
        data.extend_from_slice(&1000u64.to_le_bytes()); // claim 1000 u64s
        data.extend_from_slice(&[0u8; 16]); // only 2 u64s of actual data

        let inner = SliceReader::new(&data);
        // Budget of 80 bytes: after reading 9-byte length, 71 remain.
        // max_alloc(u64::min_serialized_size()) = 71 / 8 = 8 elements max
        let mut reader = BudgetedReader::new(inner, 80);

        let _len = reader.read_usize().unwrap();
        let result = reader.read_many_iter::<u64>(1000);

        // Rejected: 1000 > 8
        assert!(result.is_err());
    }

    /// For nested collections like Vec<Vec<u64>>, min_serialized_size() returns 1 (the minimum
    /// vint length prefix), not size_of. This is more permissive but accurate: a
    /// serialized Vec can be as small as 1 byte (empty vec).
    ///
    /// The early-abort check uses this minimum, and budget enforcement during actual
    /// reads provides the real protection against malicious input.
    #[test]
    fn min_serialized_size_override_for_nested_collections() {
        // Vec<u64>::min_serialized_size() returns 1 (minimum vint prefix), not size_of
        assert_eq!(<Vec<u64>>::min_serialized_size(), 1);

        let mut data = Vec::new();
        data.push(0); // 9-byte vint64 marker
        data.extend_from_slice(&100u64.to_le_bytes()); // claim 100 inner Vecs
        // Only provide enough data for 1 empty inner Vec
        data.push(0b1); // vint64 for 0 (1-byte form: 0 << 1 | 1)

        let inner = SliceReader::new(&data);
        // With min_serialized_size() = 1, at least 100 bytes of budget must remain after the
        // 9-byte length prefix for the early check to pass, i.e. budget >= 109.
        // For example, budget = 101 would leave 92 remaining, and 92 / 1 = 92 < 100.
        // With budget = 110, 110 - 9 = 101 remain, and 101 >= 100.
        let mut reader = BudgetedReader::new(inner, 110);

        let _len = reader.read_usize().unwrap();
        let result = reader.read_many_iter::<Vec<u64>>(100);

        // The early check passes (100 <= 101)
        assert!(result.is_ok());

        // But deserialization fails when we try to read 100 inner Vecs with only 1:
        // the first (empty) inner Vec decodes, the second hits EOF
        let collect_result: Result<Vec<Vec<u64>>, _> = result.unwrap().collect();
        assert!(collect_result.is_err());
    }

    /// Demonstrates that min_serialized_size() approach still provides security for nested
    /// collections, just with later detection. The budget is enforced during reads.
    #[test]
    fn nested_collections_still_protected_by_budget() {
        // With Vec::min_serialized_size() = 1, the early check is permissive.
        // Security comes from budget enforcement during actual reads.
        let mut data = Vec::new();
        data.push(0); // 9-byte vint64 marker
        data.extend_from_slice(&10u64.to_le_bytes()); // claim 10 inner Vecs
        // Each inner vec claims 1000 u64s but provides none
        for _ in 0..10 {
            data.push(0); // 9-byte vint64 marker
            data.extend_from_slice(&1000u64.to_le_bytes());
        }

        let inner = SliceReader::new(&data);
        // Small budget: will run out during inner deserialization
        let mut reader = BudgetedReader::new(inner, 100);

        // Outer length read succeeds (consumes 9 bytes, 91 remaining)
        let _len = reader.read_usize().unwrap();

        // With Vec::min_serialized_size() = 1, early check passes: 91 / 1 = 91 >= 10
        let result = reader.read_many_iter::<Vec<u64>>(10);
        assert!(result.is_ok());

        // But collecting fails because the inner vecs claim 1000 u64s each,
        // exhausting the budget during inner deserialization
        let collect_result: Result<Vec<Vec<u64>>, _> = result.unwrap().collect();
        assert!(collect_result.is_err());
    }
}