miden_serde_utils/byte_reader.rs
1// Copyright (c) Facebook, Inc. and its affiliates.
2//
3// This source code is licensed under the MIT license found in the
4// LICENSE file in the root directory of this source tree.
5
6#[cfg(feature = "std")]
7use alloc::string::ToString;
8use alloc::{format, string::String, vec::Vec};
9#[cfg(feature = "std")]
10use core::cell::{Ref, RefCell};
11#[cfg(feature = "std")]
12use std::io::BufRead;
13
14use crate::{Deserializable, DeserializationError};
15
16// BYTE READER TRAIT
17// ================================================================================================
18
19/// Defines how primitive values are to be read from `Self`.
20///
21/// Whenever data is read from the reader using any of the `read_*` functions, the reader advances
22/// to the next unread byte. If the error occurs, the reader is not rolled back to the state prior
23/// to calling any of the function.
24pub trait ByteReader {
25 // REQUIRED METHODS
26 // --------------------------------------------------------------------------------------------
27
28 /// Returns a single byte read from `self`.
29 ///
30 /// # Errors
31 /// Returns a [DeserializationError] error the reader is at EOF.
32 fn read_u8(&mut self) -> Result<u8, DeserializationError>;
33
34 /// Returns the next byte to be read from `self` without advancing the reader to the next byte.
35 ///
36 /// # Errors
37 /// Returns a [DeserializationError] error the reader is at EOF.
38 fn peek_u8(&self) -> Result<u8, DeserializationError>;
39
40 /// Returns a slice of bytes of the specified length read from `self`.
41 ///
42 /// # Errors
43 /// Returns a [DeserializationError] if a slice of the specified length could not be read
44 /// from `self`.
45 fn read_slice(&mut self, len: usize) -> Result<&[u8], DeserializationError>;
46
47 /// Returns a byte array of length `N` read from `self`.
48 ///
49 /// # Errors
50 /// Returns a [DeserializationError] if an array of the specified length could not be read
51 /// from `self`.
52 fn read_array<const N: usize>(&mut self) -> Result<[u8; N], DeserializationError>;
53
54 /// Checks if it is possible to read at least `num_bytes` bytes from this ByteReader
55 ///
56 /// # Errors
57 /// Returns an error if, when reading the requested number of bytes, we go beyond the
58 /// the data available in the reader.
59 fn check_eor(&self, num_bytes: usize) -> Result<(), DeserializationError>;
60
61 /// Returns true if there are more bytes left to be read from `self`.
62 fn has_more_bytes(&self) -> bool;
63
64 /// Returns the maximum number of elements that can be safely allocated, given each
65 /// element occupies `element_size` bytes when serialized.
66 ///
67 /// This can be used by callers to pre-validate collection lengths before iterating,
68 /// preventing denial-of-service attacks from malicious length prefixes that claim
69 /// billions of elements.
70 ///
71 /// The default implementation returns `usize::MAX`, meaning no limit is enforced.
72 /// [`BudgetedReader`] overrides this to return `remaining_budget / element_size`,
73 /// providing tight, adaptive limits based on the caller's budget.
74 ///
75 /// # Arguments
76 /// * `element_size` - The serialized size of one element, from
77 /// [`Deserializable::min_serialized_size`]. Defaults to `size_of::<D>()` but can be
78 /// overridden for types where serialized size differs from in-memory size.
79 fn max_alloc(&self, _element_size: usize) -> usize {
80 usize::MAX
81 }
82
83 // PROVIDED METHODS
84 // --------------------------------------------------------------------------------------------
85
86 /// Returns a boolean value read from `self` consuming 1 byte from the reader.
87 ///
88 /// # Errors
89 /// Returns a [DeserializationError] if a u16 value could not be read from `self`.
90 fn read_bool(&mut self) -> Result<bool, DeserializationError> {
91 let byte = self.read_u8()?;
92 match byte {
93 0 => Ok(false),
94 1 => Ok(true),
95 _ => Err(DeserializationError::InvalidValue(format!("{byte} is not a boolean value"))),
96 }
97 }
98
99 /// Returns a u16 value read from `self` in little-endian byte order.
100 ///
101 /// # Errors
102 /// Returns a [DeserializationError] if a u16 value could not be read from `self`.
103 fn read_u16(&mut self) -> Result<u16, DeserializationError> {
104 let bytes = self.read_array::<2>()?;
105 Ok(u16::from_le_bytes(bytes))
106 }
107
108 /// Returns a u32 value read from `self` in little-endian byte order.
109 ///
110 /// # Errors
111 /// Returns a [DeserializationError] if a u32 value could not be read from `self`.
112 fn read_u32(&mut self) -> Result<u32, DeserializationError> {
113 let bytes = self.read_array::<4>()?;
114 Ok(u32::from_le_bytes(bytes))
115 }
116
117 /// Returns a u64 value read from `self` in little-endian byte order.
118 ///
119 /// # Errors
120 /// Returns a [DeserializationError] if a u64 value could not be read from `self`.
121 fn read_u64(&mut self) -> Result<u64, DeserializationError> {
122 let bytes = self.read_array::<8>()?;
123 Ok(u64::from_le_bytes(bytes))
124 }
125
126 /// Returns a u128 value read from `self` in little-endian byte order.
127 ///
128 /// # Errors
129 /// Returns a [DeserializationError] if a u128 value could not be read from `self`.
130 fn read_u128(&mut self) -> Result<u128, DeserializationError> {
131 let bytes = self.read_array::<16>()?;
132 Ok(u128::from_le_bytes(bytes))
133 }
134
135 /// Returns a usize value read from `self` in [vint64](https://docs.rs/vint64/latest/vint64/)
136 /// format.
137 ///
138 /// # Errors
139 /// Returns a [DeserializationError] if:
140 /// * usize value could not be read from `self`.
141 /// * encoded value is greater than `usize` maximum value on a given platform.
142 fn read_usize(&mut self) -> Result<usize, DeserializationError> {
143 let first_byte = self.peek_u8()?;
144 let length = first_byte.trailing_zeros() as usize + 1;
145
146 let result = if length == 9 {
147 // 9-byte special case
148 self.read_u8()?;
149 let value = self.read_array::<8>()?;
150 u64::from_le_bytes(value)
151 } else {
152 let mut encoded = [0u8; 8];
153 let value = self.read_slice(length)?;
154 encoded[..length].copy_from_slice(value);
155 u64::from_le_bytes(encoded) >> length
156 };
157
158 // check if the result value is within acceptable bounds for `usize` on a given platform
159 if result > usize::MAX as u64 {
160 return Err(DeserializationError::InvalidValue(format!(
161 "Encoded value must be less than {}, but {} was provided",
162 usize::MAX,
163 result
164 )));
165 }
166
167 Ok(result as usize)
168 }
169
170 /// Returns a byte vector of the specified length read from `self`.
171 ///
172 /// # Errors
173 /// Returns a [DeserializationError] if a vector of the specified length could not be read
174 /// from `self`.
175 fn read_vec(&mut self, len: usize) -> Result<Vec<u8>, DeserializationError> {
176 let data = self.read_slice(len)?;
177 Ok(data.to_vec())
178 }
179
180 /// Returns a String of the specified length read from `self`.
181 ///
182 /// # Errors
183 /// Returns a [DeserializationError] if a String of the specified length could not be read
184 /// from `self`.
185 fn read_string(&mut self, num_bytes: usize) -> Result<String, DeserializationError> {
186 let data = self.read_vec(num_bytes)?;
187 String::from_utf8(data).map_err(|err| DeserializationError::InvalidValue(format!("{err}")))
188 }
189
190 /// Reads a deserializable value from `self`.
191 ///
192 /// # Errors
193 /// Returns a [DeserializationError] if the specified value could not be read from `self`.
194 fn read<D>(&mut self) -> Result<D, DeserializationError>
195 where
196 Self: Sized,
197 D: Deserializable,
198 {
199 D::read_from(self)
200 }
201
202 /// Returns an iterator that deserializes `num_elements` instances of `D` from this reader.
203 ///
204 /// This method validates the requested count against the reader's capacity before returning
205 /// the iterator, rejecting implausible lengths early. Each element is then deserialized
206 /// lazily as the iterator is consumed.
207 ///
208 /// # Errors
209 ///
210 /// Returns an error if `num_elements` exceeds `self.max_alloc(D::min_serialized_size())`,
211 /// indicating the reader cannot allocate that many elements.
212 ///
213 /// # Example
214 ///
215 /// ```ignore
216 /// // Collect into a Vec
217 /// let items: Vec<u64> = reader
218 /// .read_many_iter::<u64>(count)?
219 /// .collect::<Result<_, _>>()?;
220 ///
221 /// // Collect directly into a BTreeMap (no intermediate Vec)
222 /// let map: BTreeMap<K, V> = reader
223 /// .read_many_iter::<(K, V)>(count)?
224 /// .collect::<Result<_, _>>()?;
225 /// ```
226 fn read_many_iter<D>(
227 &mut self,
228 num_elements: usize,
229 ) -> Result<ReadManyIter<'_, Self, D>, DeserializationError>
230 where
231 Self: Sized,
232 D: Deserializable,
233 {
234 let max_elements = self.max_alloc(D::min_serialized_size());
235 if num_elements > max_elements {
236 return Err(DeserializationError::InvalidValue(format!(
237 "requested {num_elements} elements but reader can provide at most {max_elements}"
238 )));
239 }
240 Ok(ReadManyIter {
241 reader: self,
242 remaining: num_elements,
243 _item: core::marker::PhantomData,
244 })
245 }
246}
247
248// READ MANY ITERATOR
249// ================================================================================================
250
/// Iterator that lazily deserializes elements from a [`ByteReader`].
///
/// Created by [`ByteReader::read_many_iter`]. Each call to `next()` deserializes one element.
/// This avoids upfront allocation and naturally integrates with [`BudgetedReader`] for
/// protection against malicious inputs.
pub struct ReadManyIter<'reader, R: ByteReader, D: Deserializable> {
    // The reader the elements are deserialized from.
    reader: &'reader mut R,
    // Number of elements this iterator will still attempt to deserialize.
    remaining: usize,
    // Ties the iterator to the element type `D` without storing a value of that type.
    _item: core::marker::PhantomData<D>,
}
261
262impl<'reader, R: ByteReader, D: Deserializable> Iterator for ReadManyIter<'reader, R, D> {
263 type Item = Result<D, DeserializationError>;
264
265 fn next(&mut self) -> Option<Self::Item> {
266 if self.remaining > 0 {
267 self.remaining -= 1;
268 Some(D::read_from(self.reader))
269 } else {
270 None
271 }
272 }
273
274 fn size_hint(&self) -> (usize, Option<usize>) {
275 (self.remaining, Some(self.remaining))
276 }
277}
278
// `size_hint` reports `remaining` as both the lower and the upper bound, so the default
// `ExactSizeIterator::len` implementation applies.
impl<'reader, R: ByteReader, D: Deserializable> ExactSizeIterator for ReadManyIter<'reader, R, D> {}
280
281// STANDARD LIBRARY ADAPTER
282// ================================================================================================
283
/// An adapter of [ByteReader] to any type that implements [std::io::Read].
///
/// In particular, this covers things like [std::fs::File], standard input, etc.
#[cfg(feature = "std")]
pub struct ReadAdapter<'a> {
    // NOTE: The [ByteReader] trait does not currently support reader implementations that require
    // mutation during `peek_u8`, `has_more_bytes`, and `check_eor`. These (or equivalent)
    // operations on the standard library [std::io::BufRead] trait require a mutable reference, as
    // it may be necessary to read from the underlying input to implement them.
    //
    // To handle this, we wrap the underlying reader in a [RefCell]. This allows us to mutate the
    // reader if necessary during a call to one of the above-mentioned trait methods, without
    // sacrificing safety - at the cost of enforcing Rust's borrowing semantics dynamically.
    //
    // This should not be a problem in practice, except in the case where `read_slice` is called,
    // and the reference returned is from `reader` directly, rather than `buf`. If a call to one
    // of the above-mentioned methods is made while that reference is live, and we attempt to read
    // from `reader`, a panic will occur.
    //
    // Ultimately, this should be addressed by making the [ByteReader] trait align with the
    // standard library I/O traits, so this is a temporary solution.
    reader: RefCell<std::io::BufReader<&'a mut dyn std::io::Read>>,
    // A temporary buffer to store chunks read from `reader` that are larger than what is required
    // for the higher-level [ByteReader] APIs.
    //
    // By default we attempt to satisfy reads from `reader` directly, but that is not always
    // possible.
    buf: alloc::vec::Vec<u8>,
    // The position in `buf` at which we should start reading the next byte, when `buf` is
    // non-empty.
    pos: usize,
    // This is set when we attempt to read from `reader` and get an empty buffer. This indicates
    // that once we exhaust `buf`, we have truly reached end-of-file.
    //
    // It is used to handle functions like `check_eor` more accurately once it is set.
    guaranteed_eof: bool,
}
321
#[cfg(feature = "std")]
impl<'a> ReadAdapter<'a> {
    /// Create a new [ByteReader] adapter for the given implementation of [std::io::Read]
    pub fn new(reader: &'a mut dyn std::io::Read) -> Self {
        Self {
            reader: RefCell::new(std::io::BufReader::with_capacity(256, reader)),
            buf: Default::default(),
            pos: 0,
            guaranteed_eof: false,
        }
    }

    /// Get the unread part of the internal adapter buffer as a (possibly empty) slice of bytes
    #[inline(always)]
    fn buffer(&self) -> &[u8] {
        self.buf.get(self.pos..).unwrap_or(&[])
    }

    /// Get the unread part of the internal adapter buffer, or `None` if there is nothing unread
    #[inline(always)]
    fn non_empty_buffer(&self) -> Option<&[u8]> {
        self.buf.get(self.pos..).filter(|b| !b.is_empty())
    }

    /// Return the current reader buffer as a (possibly empty) slice of bytes.
    ///
    /// This buffer being empty _does not_ mean we're at EOF; call
    /// [Self::non_empty_reader_buffer_mut] to find out.
    #[inline(always)]
    fn reader_buffer(&self) -> Ref<'_, [u8]> {
        Ref::map(self.reader.borrow(), |r| r.buffer())
    }

    /// Converts an I/O error raised while filling the reader buffer into a
    /// [DeserializationError].
    ///
    /// `UnexpectedEof` maps to [DeserializationError::UnexpectedEOF]; every other error is
    /// reported as [DeserializationError::UnknownError] carrying the textual form of its kind.
    fn map_io_error(err: std::io::Error) -> DeserializationError {
        match err.kind() {
            std::io::ErrorKind::UnexpectedEof => DeserializationError::UnexpectedEOF,
            kind => DeserializationError::UnknownError(kind.to_string()),
        }
    }

    /// Return the current reader buffer, reading from the underlying reader
    /// if the buffer is empty.
    ///
    /// Returns `Ok` only if the buffer is non-empty, and no errors occurred
    /// while filling it (if filling was needed). An empty buffer after filling
    /// means end-of-file, which is recorded in `guaranteed_eof`.
    fn non_empty_reader_buffer_mut(&mut self) -> Result<&[u8], DeserializationError> {
        let buf = self.reader.get_mut().fill_buf().map_err(Self::map_io_error)?;
        if buf.is_empty() {
            self.guaranteed_eof = true;
            Err(DeserializationError::UnexpectedEOF)
        } else {
            Ok(buf)
        }
    }

    /// Same as [Self::non_empty_reader_buffer_mut], but with dynamically-enforced
    /// borrow check rules so that it can be called in functions like `peek_u8`.
    ///
    /// This comes with overhead for the dynamic checks, so you should prefer
    /// to call [Self::non_empty_reader_buffer_mut] if you already have a mutable
    /// reference to `self`. Unlike that function, this one cannot record
    /// `guaranteed_eof`, since it only has shared access to `self`.
    fn non_empty_reader_buffer(&self) -> Result<Ref<'_, [u8]>, DeserializationError> {
        {
            // The mutable borrow is confined to this scope so that it is released
            // before we re-borrow the reader immutably below.
            let mut reader = self.reader.borrow_mut();
            let buf = reader.fill_buf().map_err(Self::map_io_error)?;
            if buf.is_empty() {
                return Err(DeserializationError::UnexpectedEOF);
            }
        }
        Ok(self.reader_buffer())
    }

    /// Returns true if the capacity of `buf`, minus the bytes still unread in it, is at
    /// least `n` bytes
    #[inline]
    fn has_remaining_capacity(&self, n: usize) -> bool {
        let remaining = self.buf.capacity() - self.buffer().len();
        remaining >= n
    }

    /// Drops the consumed contents of `buf` once every byte in it has been read.
    ///
    /// Both the buffer contents and `pos` are reset together, so `pos` never points past
    /// the end of data that is appended to `buf` later.
    fn reset_if_drained(&mut self) {
        if self.pos > 0 && self.pos >= self.buf.len() {
            self.buf.clear();
            self.pos = 0;
        }
    }

    /// Takes the next byte from the input, returning an error if the operation fails
    fn pop(&mut self) -> Result<u8, DeserializationError> {
        // Serve the byte from the local buffer if it has unread data
        if let Some(byte) = self.non_empty_buffer().map(|b| b[0]) {
            self.pos += 1;
            return Ok(byte);
        }
        // Otherwise take it straight from the reader's buffer
        match self.non_empty_reader_buffer_mut().map(|b| b[0]) {
            Ok(byte) => {
                self.reader.get_mut().consume(1);
                Ok(byte)
            },
            Err(err) => {
                self.guaranteed_eof = true;
                Err(err)
            },
        }
    }

    /// Takes the next `N` bytes from the input as an array, returning an error if the operation
    /// fails.
    ///
    /// Bytes are taken from the local buffer first. If the local buffer is empty and the
    /// reader's buffer already holds `N` bytes, they are copied directly from the reader.
    /// In every other case the missing bytes are first moved into the local buffer, so a
    /// value that straddles two reader chunks is still read correctly, and bytes gathered
    /// before hitting end-of-file stay available in the local buffer.
    fn read_exact<const N: usize>(&mut self) -> Result<[u8; N], DeserializationError> {
        let mut output = [0u8; N];
        if N == 0 {
            return Ok(output);
        }

        if self.buffer().is_empty() {
            // Fast path: nothing buffered locally, try to satisfy the read from the
            // reader's own buffer without copying through `self.buf`.
            let chunk = self.non_empty_reader_buffer_mut()?;
            if chunk.len() >= N {
                output.copy_from_slice(&chunk[..N]);
                self.reader.get_mut().consume(N);
                self.reset_if_drained();
                return Ok(output);
            }
            // The reader's current chunk is too short, but that does not mean EOF: more
            // data may follow once the chunk is consumed, so fall through to buffering.
        }

        // General path: make sure the local buffer holds at least `N` unread bytes
        // (a no-op when it already does), then copy out of it.
        self.buffer_at_least(N)?;
        output.copy_from_slice(&self.buffer()[..N]);
        self.pos += N;
        self.reset_if_drained();

        Ok(output)
    }

    /// Ensures that `self.buf` holds at least `count` unread bytes, pulling more data from
    /// the underlying reader as needed.
    ///
    /// This should only be called when we can't read from the reader directly.
    ///
    /// # Errors
    /// Returns an error if the underlying reader fails, or reaches end-of-file before
    /// `count` unread bytes are available. Bytes moved into `self.buf` before the error
    /// remain there.
    fn buffer_at_least(&mut self, count: usize) -> Result<(), DeserializationError> {
        // `count` is the total number of unread bytes required, so it is compared against
        // the unread length on every iteration and never adjusted.
        while self.buffer().len() < count {
            // This operation will return an error if the underlying reader hits EOF
            self.non_empty_reader_buffer_mut()?;

            // Move everything the reader has buffered into `self.buf`.
            //
            // NOTE: We have to re-borrow the reader buffer here, since we can't get a mutable
            // reference to `self.buf` while holding the slice returned above.
            let reader = self.reader.get_mut();
            let chunk = reader.buffer();
            let consumed = chunk.len();
            self.buf.extend_from_slice(chunk);
            reader.consume(consumed);
        }
        Ok(())
    }
}
539
#[cfg(feature = "std")]
impl ByteReader for ReadAdapter<'_> {
    #[inline(always)]
    fn read_u8(&mut self) -> Result<u8, DeserializationError> {
        self.pop()
    }

    /// NOTE: If we happen to not have any bytes buffered yet when this is called, then we will be
    /// forced to try and read from the underlying reader. This requires a mutable reference, which
    /// is obtained dynamically via [RefCell].
    ///
    /// <div class="warning">
    /// Callers must ensure that they do not hold any immutable references to the buffer of this
    /// reader when calling this function so as to avoid a situation in which the dynamic borrow
    /// check fails. Specifically, you must not be holding a reference to the result of
    /// [Self::read_slice] when this function is called.
    /// </div>
    fn peek_u8(&self) -> Result<u8, DeserializationError> {
        match self.buffer().first() {
            Some(byte) => Ok(*byte),
            None => self.non_empty_reader_buffer().map(|b| b[0]),
        }
    }

    /// Reads `len` bytes into the adapter's local buffer and returns them as a slice.
    ///
    /// # Errors
    /// Returns a [DeserializationError] if `len` bytes could not be obtained from the
    /// underlying reader.
    fn read_slice(&mut self, len: usize) -> Result<&[u8], DeserializationError> {
        // Edge case
        if len == 0 {
            return Ok(&[]);
        }

        // If we have unused buffer, and the consumed portion is
        // large enough, we will move the unused portion of the buffer
        // to the start, freeing up bytes at the end for more reads
        // before forcing a reallocation
        let should_optimize_storage = self.pos >= 16 && !self.has_remaining_capacity(len);
        if should_optimize_storage {
            // Dropping the consumed prefix shifts the unread bytes to the front of `buf`.
            // The bound is clamped so that a `pos` past the end simply empties the buffer.
            let consumed = self.pos.min(self.buf.len());
            self.buf.drain(..consumed);
            self.pos = 0;
        }

        // Fill the buffer so we have at least `len` bytes available,
        // this will return an error if we hit EOF first
        self.buffer_at_least(len)?;

        // Defensive bounds check: report a short buffer as an error instead of panicking.
        let start = self.pos;
        let end = start
            .checked_add(len)
            .filter(|&end| end <= self.buf.len())
            .ok_or(DeserializationError::UnexpectedEOF)?;
        self.pos = end;
        Ok(&self.buf[start..end])
    }

    #[inline]
    fn read_array<const N: usize>(&mut self) -> Result<[u8; N], DeserializationError> {
        // A zero-length array never needs to touch the input
        if N == 0 {
            return Ok([0; N]);
        }
        self.read_exact()
    }

    fn check_eor(&self, num_bytes: usize) -> Result<(), DeserializationError> {
        // Do we have sufficient data in the local buffer?
        let buffer_len = self.buffer().len();
        if buffer_len >= num_bytes {
            return Ok(());
        }

        // What about if we include what is in the local buffer and the reader's buffer?
        // This fails with an error if the reader has nothing more to offer.
        let reader_buffer_len = self.non_empty_reader_buffer().map(|b| b.len())?;
        let buffer_len = buffer_len + reader_buffer_len;
        if buffer_len >= num_bytes {
            return Ok(());
        }

        // We have no more input, thus can't fulfill a request of `num_bytes`
        if self.guaranteed_eof {
            return Err(DeserializationError::UnexpectedEOF);
        }

        // Because this function is read-only, we must optimistically assume we can read `num_bytes`
        // from the input, and fail later if that does not hold. We know we're not at EOF yet, but
        // that's all we can say without buffering more from the reader. We could make use of
        // `buffer_at_least`, which would guarantee a correct result, but it would also impose
        // additional restrictions on the use of this function, e.g. not using it while holding a
        // reference returned from `read_slice`. Since it is not a memory safety violation to return
        // an optimistic result here, it makes for a better tradeoff.
        Ok(())
    }

    #[inline]
    fn has_more_bytes(&self) -> bool {
        // Unread local data counts; otherwise ask the reader (which may need to fill its buffer)
        !self.buffer().is_empty() || self.non_empty_reader_buffer().is_ok()
    }
}
639
640// CURSOR
641// ================================================================================================
642
// Expands to the unread tail of a cursor's underlying byte buffer as a `&[u8]`.
//
// The cursor position is clamped to the buffer length before slicing, so a position past
// the end of the data yields an empty slice rather than an out-of-bounds panic.
#[cfg(feature = "std")]
macro_rules! cursor_remaining_buf {
    ($cursor:ident) => {{
        let buf = $cursor.get_ref().as_ref();
        let start = $cursor.position().min(buf.len() as u64) as usize;
        &buf[start..]
    }};
}
651
#[cfg(feature = "std")]
impl<T: AsRef<[u8]>> ByteReader for std::io::Cursor<T> {
    /// Reads the byte at the cursor position and moves the cursor forward by one.
    fn read_u8(&mut self) -> Result<u8, DeserializationError> {
        match cursor_remaining_buf!(self).first().copied() {
            Some(byte) => {
                self.set_position(self.position() + 1);
                Ok(byte)
            },
            None => Err(DeserializationError::UnexpectedEOF),
        }
    }

    /// Returns the byte at the cursor position without moving the cursor.
    fn peek_u8(&self) -> Result<u8, DeserializationError> {
        let remaining = cursor_remaining_buf!(self);
        match remaining.first() {
            Some(byte) => Ok(*byte),
            None => Err(DeserializationError::UnexpectedEOF),
        }
    }

    /// Returns the next `len` bytes and moves the cursor past them.
    fn read_slice(&mut self, len: usize) -> Result<&[u8], DeserializationError> {
        let pos = self.position();
        let total = self.get_ref().as_ref().len() as u64;
        let requested = len as u64;

        // `saturating_sub` treats a cursor positioned past the end as having nothing left
        if total.saturating_sub(pos) >= requested {
            self.set_position(pos + requested);
            let start = pos.min(total) as usize;
            Ok(&self.get_ref().as_ref()[start..(start + len)])
        } else {
            Err(DeserializationError::UnexpectedEOF)
        }
    }

    /// Returns the next `N` bytes as an array and moves the cursor past them.
    fn read_array<const N: usize>(&mut self) -> Result<[u8; N], DeserializationError> {
        let bytes = self.read_slice(N)?;
        let mut result = [0u8; N];
        result.copy_from_slice(bytes);
        Ok(result)
    }

    /// Succeeds only if at least `num_bytes` bytes remain after the cursor position.
    fn check_eor(&self, num_bytes: usize) -> Result<(), DeserializationError> {
        let available = cursor_remaining_buf!(self).len();
        if available < num_bytes {
            return Err(DeserializationError::UnexpectedEOF);
        }
        Ok(())
    }

    /// Returns true while the cursor position is before the end of the data.
    #[inline]
    fn has_more_bytes(&self) -> bool {
        self.position() < self.get_ref().as_ref().len() as u64
    }
}
707
708// SLICE READER
709// ================================================================================================
710
711/// Implements [ByteReader] trait for a slice of bytes.
712///
713/// NOTE: If you are building with the `std` feature, you should probably prefer [std::io::Cursor]
714/// instead. However, [SliceReader] is still useful in no-std environments until stabilization of
715/// the `core_io_borrowed_buf` feature.
716pub struct SliceReader<'a> {
717 source: &'a [u8],
718 pos: usize,
719}
720
721impl<'a> SliceReader<'a> {
722 /// Creates a new slice reader from the specified slice.
723 pub fn new(source: &'a [u8]) -> Self {
724 SliceReader { source, pos: 0 }
725 }
726}
727
728impl ByteReader for SliceReader<'_> {
729 fn read_u8(&mut self) -> Result<u8, DeserializationError> {
730 self.check_eor(1)?;
731 let result = self.source[self.pos];
732 self.pos += 1;
733 Ok(result)
734 }
735
736 fn peek_u8(&self) -> Result<u8, DeserializationError> {
737 self.check_eor(1)?;
738 Ok(self.source[self.pos])
739 }
740
741 fn read_slice(&mut self, len: usize) -> Result<&[u8], DeserializationError> {
742 self.check_eor(len)?;
743 let result = &self.source[self.pos..self.pos + len];
744 self.pos += len;
745 Ok(result)
746 }
747
748 fn read_array<const N: usize>(&mut self) -> Result<[u8; N], DeserializationError> {
749 self.check_eor(N)?;
750 let mut result = [0_u8; N];
751 result.copy_from_slice(&self.source[self.pos..self.pos + N]);
752 self.pos += N;
753 Ok(result)
754 }
755
756 fn check_eor(&self, num_bytes: usize) -> Result<(), DeserializationError> {
757 if self.pos + num_bytes > self.source.len() {
758 return Err(DeserializationError::UnexpectedEOF);
759 }
760 Ok(())
761 }
762
763 fn has_more_bytes(&self) -> bool {
764 self.pos < self.source.len()
765 }
766}
767
768// BUDGETED READER
769// ================================================================================================
770
771/// A reader wrapper that enforces a byte budget during deserialization.
772///
773/// # Threat Model
774///
775/// Malicious input can attack deserialization in two ways:
776///
777/// 1. **Fake length prefix**: Input claims `len = 2^60` elements, causing allocation of a huge
778/// `Vec` before any data is read.
779/// 2. **Oversized input**: Attacker sends gigabytes of valid-looking data to exhaust memory over
780/// time.
781///
782/// # Defense Strategy
783///
784/// Use `BudgetedReader` to limit total bytes consumed. Its [`max_alloc`](ByteReader::max_alloc)
785/// method derives a bound from the remaining budget, which
786/// [`read_many_iter`](ByteReader::read_many_iter) checks before iterating.
787///
788/// ## Problem: SliceReader alone doesn't bound allocations
789///
790/// ```
791/// use miden_serde_utils::{ByteReader, Deserializable, SliceReader};
792///
793/// // Malicious input: length prefix says 1 billion u64s, but only 16 bytes of data
794/// let mut data = Vec::new();
795/// data.push(0u8); // vint64 9-byte marker
796/// data.extend_from_slice(&1_000_000_000u64.to_le_bytes());
797/// data.extend_from_slice(&[0u8; 16]);
798///
799/// // SliceReader returns usize::MAX from max_alloc, so read_many_iter accepts
800/// // any length. This would try to iterate 1 billion times (slow, not OOM,
801/// // but still a DoS vector).
802/// let reader = SliceReader::new(&data);
803/// assert_eq!(reader.max_alloc(8), usize::MAX);
804/// ```
805///
806/// ## Solution: BudgetedReader bounds allocations via max_alloc
807///
808/// ```
809/// use miden_serde_utils::{BudgetedReader, ByteReader, Deserializable, SliceReader};
810///
811/// // Same malicious input
812/// let mut data = Vec::new();
813/// data.push(0u8);
814/// data.extend_from_slice(&1_000_000_000u64.to_le_bytes());
815/// data.extend_from_slice(&[0u8; 16]);
816///
817/// // BudgetedReader with 64-byte budget: max_alloc(8) = 64/8 = 8 elements
818/// let inner = SliceReader::new(&data);
819/// let reader = BudgetedReader::new(inner, 64);
820/// assert_eq!(reader.max_alloc(8), 8);
821///
822/// // read_many_iter rejects the 1B length since 1B > 8
823/// let result = Vec::<u64>::read_from_bytes_with_budget(&data, 64);
824/// assert!(result.is_err());
825/// ```
826///
827/// ## Best practice: Set budget to expected input size
828///
829/// ```
830/// use miden_serde_utils::{ByteWriter, Deserializable, Serializable};
831///
832/// // Legitimate input: 3 u64s, properly serialized
833/// let original = vec![1u64, 2, 3];
834/// let mut data = Vec::new();
835/// original.write_into(&mut data);
836///
837/// // Budget = data.len() bounds both fake lengths and total consumption
838/// let result = Vec::<u64>::read_from_bytes_with_budget(&data, data.len());
839/// assert_eq!(result.unwrap(), vec![1, 2, 3]);
840/// ```
pub struct BudgetedReader<R> {
    // The wrapped reader that all reads are delegated to.
    inner: R,
    // Number of bytes that may still be consumed before the budget is exhausted.
    remaining: usize,
}
845
846impl<R> BudgetedReader<R> {
847 /// Wraps a reader with the specified byte budget.
848 pub fn new(inner: R, budget: usize) -> Self {
849 Self { inner, remaining: budget }
850 }
851
852 /// Returns remaining budget in bytes.
853 pub fn remaining(&self) -> usize {
854 self.remaining
855 }
856
857 /// Consumes budget, returning an error if insufficient.
858 fn consume_budget(&mut self, n: usize) -> Result<(), DeserializationError> {
859 if n > self.remaining {
860 return Err(DeserializationError::InvalidValue(format!(
861 "budget exhausted: requested {n} bytes, {} remaining",
862 self.remaining
863 )));
864 }
865 self.remaining -= n;
866 Ok(())
867 }
868}
869
870impl<R: ByteReader> ByteReader for BudgetedReader<R> {
871 fn read_u8(&mut self) -> Result<u8, DeserializationError> {
872 self.consume_budget(1)?;
873 self.inner.read_u8()
874 }
875
876 fn peek_u8(&self) -> Result<u8, DeserializationError> {
877 // peek doesn't consume budget since it doesn't advance the reader
878 self.inner.peek_u8()
879 }
880
881 fn read_slice(&mut self, len: usize) -> Result<&[u8], DeserializationError> {
882 self.consume_budget(len)?;
883 self.inner.read_slice(len)
884 }
885
886 fn read_array<const N: usize>(&mut self) -> Result<[u8; N], DeserializationError> {
887 self.consume_budget(N)?;
888 self.inner.read_array()
889 }
890
891 fn check_eor(&self, num_bytes: usize) -> Result<(), DeserializationError> {
892 // check budget first, then delegate
893 if num_bytes > self.remaining {
894 return Err(DeserializationError::InvalidValue(format!(
895 "budget exhausted: requested {num_bytes} bytes, {} remaining",
896 self.remaining
897 )));
898 }
899 self.inner.check_eor(num_bytes)
900 }
901
902 fn has_more_bytes(&self) -> bool {
903 self.remaining > 0 && self.inner.has_more_bytes()
904 }
905
906 fn max_alloc(&self, element_size: usize) -> usize {
907 if element_size == 0 {
908 return usize::MAX; // ZSTs don't consume budget
909 }
910 self.remaining / element_size
911 }
912}
913
#[cfg(all(test, feature = "std"))]
mod tests {
    use std::io::Cursor;

    use super::*;
    use crate::ByteWriter;

    /// A path in the system temp directory that is deleted when the guard is dropped.
    ///
    /// The file name embeds the process id so that concurrently running test processes do not
    /// clobber (or delete) each other's files. Declare the guard before any handle that opens
    /// the file: locals drop in reverse declaration order, so the handle is closed before the
    /// file is removed.
    struct TempFile(std::path::PathBuf);

    impl TempFile {
        /// Builds the path `<temp_dir>/<stem>_<pid>.bin` without creating the file.
        fn new(stem: &str) -> Self {
            let file_name = std::format!("{stem}_{}.bin", std::process::id());
            Self(std::env::temp_dir().join(file_name))
        }

        /// Returns the guarded path.
        fn path(&self) -> &std::path::Path {
            &self.0
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best-effort cleanup: a failed removal must not mask the actual test outcome.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// An always-empty source reports EOF for every non-empty read, while zero-length reads
    /// still succeed.
    #[test]
    fn read_adapter_empty() -> Result<(), DeserializationError> {
        let mut reader = std::io::empty();
        let mut adapter = ReadAdapter::new(&mut reader);
        assert!(!adapter.has_more_bytes());
        assert_eq!(adapter.check_eor(8), Err(DeserializationError::UnexpectedEOF));
        assert_eq!(adapter.peek_u8(), Err(DeserializationError::UnexpectedEOF));
        assert_eq!(adapter.read_u8(), Err(DeserializationError::UnexpectedEOF));
        assert_eq!(adapter.read_slice(0), Ok([].as_slice()));
        assert_eq!(adapter.read_slice(1), Err(DeserializationError::UnexpectedEOF));
        assert_eq!(adapter.read_array(), Ok([]));
        assert_eq!(adapter.read_array::<1>(), Err(DeserializationError::UnexpectedEOF));
        Ok(())
    }

    /// An endless source of a repeated byte satisfies every read with that byte.
    #[test]
    fn read_adapter_passthrough() -> Result<(), DeserializationError> {
        let mut reader = std::io::repeat(0b101);
        let mut adapter = ReadAdapter::new(&mut reader);
        assert!(adapter.has_more_bytes());
        assert_eq!(adapter.check_eor(8), Ok(()));
        assert_eq!(adapter.peek_u8(), Ok(0b101));
        assert_eq!(adapter.read_u8(), Ok(0b101));
        assert_eq!(adapter.read_slice(0), Ok([].as_slice()));
        assert_eq!(adapter.read_slice(4), Ok([0b101, 0b101, 0b101, 0b101].as_slice()));
        assert_eq!(adapter.read_array(), Ok([]));
        assert_eq!(adapter.read_array(), Ok([0b101, 0b101]));
        Ok(())
    }

    /// Reading exactly the available bytes leaves the adapter at EOF.
    #[test]
    fn read_adapter_exact() {
        const VALUE: usize = 2048;
        let mut reader = Cursor::new(VALUE.to_le_bytes());
        let mut adapter = ReadAdapter::new(&mut reader);
        assert_eq!(usize::from_le_bytes(adapter.read_array().unwrap()), VALUE);
        assert!(!adapter.has_more_bytes());
        assert_eq!(adapter.peek_u8(), Err(DeserializationError::UnexpectedEOF));
        assert_eq!(adapter.read_u8(), Err(DeserializationError::UnexpectedEOF));
    }

    /// A `usize` written through `ByteWriter` is read back unchanged through the adapter.
    #[test]
    fn read_adapter_roundtrip() {
        const VALUE: usize = 2048;

        // Write VALUE to storage
        let mut cursor = Cursor::new([0; core::mem::size_of::<usize>()]);
        cursor.write_usize(VALUE);

        // Read VALUE from storage
        cursor.set_position(0);
        let mut adapter = ReadAdapter::new(&mut cursor);

        assert_eq!(adapter.read_usize(), Ok(VALUE));
    }

    /// Values encoded into a file can be decoded again through a `ReadAdapter` over that file.
    #[test]
    fn read_adapter_for_file() {
        use std::fs::File;

        // Declared first so that it is dropped (and the file removed) after `file` is closed.
        let tmp = TempFile::new("read_adapter_for_file");

        // Encode some data to a buffer, then write that buffer to a file
        {
            let mut buf = Vec::<u8>::with_capacity(256);
            buf.write_bytes(b"MAGIC\0");
            buf.write_bool(true);
            buf.write_u32(0xbeef);
            buf.write_usize(0xfeed);
            buf.write_u16(0x5);

            std::fs::write(tmp.path(), &buf).unwrap();
        }

        // Open the file, and try to decode the encoded items
        let mut file = File::open(tmp.path()).unwrap();
        let mut reader = ReadAdapter::new(&mut file);
        assert_eq!(reader.peek_u8().unwrap(), b'M');
        assert_eq!(reader.read_slice(6).unwrap(), b"MAGIC\0");
        assert!(reader.read_bool().unwrap());
        assert_eq!(reader.read_u32().unwrap(), 0xbeef);
        assert_eq!(reader.read_usize().unwrap(), 0xfeed);
        assert_eq!(reader.read_u16().unwrap(), 0x5);
        assert!(!reader.has_more_bytes(), "expected there to be no more data in the input");
    }

    /// Regression test for issue 383: checks the adapter's internal buffer length and position
    /// across an array read followed by slice reads.
    #[test]
    fn read_adapter_issue_383() {
        const STR_BYTES: &[u8] = b"just a string";

        use std::fs::File;

        // Declared first so that it is dropped (and the file removed) after `file` is closed.
        let tmp = TempFile::new("issue_383");

        // Encode some data to a buffer, then write that buffer to a file
        {
            let mut buf = Vec::<u8>::with_capacity(1024);
            // 16 bytes holding the u128 value
            buf.write_u128(2 * u64::MAX as u128);
            // Zero-pad so that the byte string starts at offset 512
            buf.resize(512, 0);
            buf.write_bytes(STR_BYTES);
            buf.write_u32(0xbeef);

            std::fs::write(tmp.path(), &buf).unwrap();
        }

        // Open the file, and try to decode the encoded items
        let mut file = File::open(tmp.path()).unwrap();
        let mut reader = ReadAdapter::new(&mut file);
        assert_eq!(reader.read_u128().unwrap(), 2 * u64::MAX as u128);
        assert_eq!(reader.buf.len(), 0);
        assert_eq!(reader.pos, 0);
        // Read to offset 512 (we're 16 bytes into the underlying file, i.e. offset of 496)
        reader.read_slice(496).unwrap();
        assert_eq!(reader.buf.len(), 496);
        assert_eq!(reader.pos, 496);
        // The byte string is 13 bytes, followed by 4 bytes containing the trailing u32 value.
        // We expect that the underlying reader will buffer the remaining bytes of the file when
        // reading STR_BYTES, so the total size of our adapter's buffer should be
        // 496 + STR_BYTES.len() + size_of::<u32>();
        assert_eq!(reader.read_slice(STR_BYTES.len()).unwrap(), STR_BYTES);
        assert_eq!(reader.buf.len(), 496 + STR_BYTES.len() + core::mem::size_of::<u32>());
        // We haven't read the u32 yet
        assert_eq!(reader.pos, 509);
        assert_eq!(reader.read_u32().unwrap(), 0xbeef);
        // Now we have
        assert_eq!(reader.pos, 513);
        assert!(!reader.has_more_bytes(), "expected there to be no more data in the input");
    }

    #[test]
    fn budgeted_reader_basic() {
        let data = [1u8, 2, 3, 4, 5, 6, 7, 8];
        let inner = SliceReader::new(&data);
        let mut reader = BudgetedReader::new(inner, 4);

        assert_eq!(reader.remaining(), 4);
        assert!(reader.has_more_bytes());

        // read 4 bytes (within budget)
        assert_eq!(reader.read_u32().unwrap(), 0x04030201);
        assert_eq!(reader.remaining(), 0);

        // budget exhausted
        assert!(!reader.has_more_bytes());
        assert!(reader.read_u8().is_err());
    }

    #[test]
    fn budgeted_reader_peek_does_not_consume() {
        let data = [42u8];
        let inner = SliceReader::new(&data);
        let mut reader = BudgetedReader::new(inner, 1);

        // peek multiple times, budget unchanged
        assert_eq!(reader.peek_u8().unwrap(), 42);
        assert_eq!(reader.peek_u8().unwrap(), 42);
        assert_eq!(reader.remaining(), 1);

        // actual read consumes budget
        assert_eq!(reader.read_u8().unwrap(), 42);
        assert_eq!(reader.remaining(), 0);
    }

    #[test]
    fn budgeted_reader_check_eor_respects_budget() {
        let data = [0u8; 100];
        let inner = SliceReader::new(&data);
        let reader = BudgetedReader::new(inner, 10);

        // within budget
        assert!(reader.check_eor(10).is_ok());

        // exceeds budget (even though inner has enough bytes)
        assert!(reader.check_eor(11).is_err());
    }

    #[test]
    fn budgeted_reader_read_slice() {
        let data = [1u8, 2, 3, 4, 5];
        let inner = SliceReader::new(&data);
        let mut reader = BudgetedReader::new(inner, 3);

        // read 3 bytes (exactly budget)
        assert_eq!(reader.read_slice(3).unwrap(), &[1, 2, 3]);
        assert_eq!(reader.remaining(), 0);

        // can't read more
        assert!(reader.read_slice(1).is_err());
    }

    #[test]
    fn budgeted_reader_read_array() {
        let data = [0xaau8, 0xbb, 0xcc, 0xdd];
        let inner = SliceReader::new(&data);
        let mut reader = BudgetedReader::new(inner, 2);

        // read 2-byte array
        assert_eq!(reader.read_array::<2>().unwrap(), [0xaa, 0xbb]);
        assert_eq!(reader.remaining(), 0);

        // budget exhausted
        assert!(reader.read_array::<2>().is_err());
    }

    #[test]
    fn budgeted_reader_zero_budget() {
        let data = [1u8];
        let inner = SliceReader::new(&data);
        let mut reader = BudgetedReader::new(inner, 0);

        assert!(!reader.has_more_bytes());
        assert!(reader.read_u8().is_err());
        // peek still works (doesn't consume budget)
        assert_eq!(reader.peek_u8().unwrap(), 1);
    }

    #[test]
    fn budgeted_reader_max_alloc() {
        let data = [0u8; 100];
        let inner = SliceReader::new(&data);
        let reader = BudgetedReader::new(inner, 64);

        // 64 bytes budget / 8 bytes per u64 = 8 elements max
        assert_eq!(reader.max_alloc(8), 8);

        // 64 bytes budget / 1 byte per u8 = 64 elements max
        assert_eq!(reader.max_alloc(1), 64);

        // 64 bytes budget / 16 bytes per u128 = 4 elements max
        assert_eq!(reader.max_alloc(16), 4);

        // ZSTs (0 bytes) return usize::MAX
        assert_eq!(reader.max_alloc(0), usize::MAX);
    }

    #[test]
    fn unbounded_reader_max_alloc_returns_max() {
        let data = [0u8; 100];
        let reader = SliceReader::new(&data);

        // Unbounded readers return usize::MAX
        assert_eq!(reader.max_alloc(1), usize::MAX);
        assert_eq!(reader.max_alloc(8), usize::MAX);
    }

    // ============================================================================================
    // The following tests document the threat model and defense layers.
    // ============================================================================================

    /// SliceReader alone does NOT reject fake length prefixes.
    ///
    /// A malicious input claiming 1000 elements will be accepted by read_many_iter
    /// because SliceReader.max_alloc() returns usize::MAX. The deserialization will
    /// eventually fail with UnexpectedEOF, but only after attempting to iterate
    /// (which could be slow for huge counts, though not OOM since we don't pre-allocate).
    #[test]
    fn slice_reader_accepts_fake_length_prefix() {
        let mut data = Vec::new();
        // Write length = 1000. The compact 2-byte vint64 encoding would be
        // (0x03E8 << 2) | 0b10 = 0x0FA2; for simplicity, use the 9-byte form instead
        // (a zero marker byte followed by the value as a little-endian u64).
        data.push(0); // 9-byte marker
        data.extend_from_slice(&1000u64.to_le_bytes());
        // Only 8 bytes of actual u64 data (1 element, not 1000)
        data.extend_from_slice(&42u64.to_le_bytes());

        // read_many_iter passes the max_alloc check (usize::MAX >= 1000)
        let mut reader = SliceReader::new(&data);
        let _len = reader.read_usize().unwrap();
        let iter_result = reader.read_many_iter::<u64>(1000);

        // The iterator is created successfully
        assert!(iter_result.is_ok());

        // But collecting fails on the 2nd element (EOF)
        let collect_result: Result<Vec<u64>, _> = iter_result.unwrap().collect();
        assert!(collect_result.is_err());
        assert!(matches!(collect_result.unwrap_err(), DeserializationError::UnexpectedEOF));
    }

    /// BudgetedReader rejects fake length prefixes BEFORE iteration begins.
    ///
    /// With a 64-byte budget, 55 bytes remain after the 9-byte length prefix, so
    /// max_alloc(8) = 6 and a claim of 1000 elements is rejected immediately by read_many_iter.
    #[test]
    fn budgeted_reader_rejects_fake_length_upfront() {
        let mut data = Vec::new();
        data.push(0); // 9-byte vint64 marker
        data.extend_from_slice(&1000u64.to_le_bytes());
        data.extend_from_slice(&42u64.to_le_bytes());

        let inner = SliceReader::new(&data);
        let mut reader = BudgetedReader::new(inner, 64);

        let _len = reader.read_usize().unwrap(); // consumes 9 bytes, 55 remaining
        // 55 / 8 = 6 elements max
        let iter_result = reader.read_many_iter::<u64>(1000);

        // Rejected immediately: 1000 > 6
        match iter_result {
            Err(DeserializationError::InvalidValue(_)) => {}, // expected
            other => panic!("expected InvalidValue error, got {:?}", other.map(|_| "Ok")),
        }
    }

    /// Best practice: budget = input length provides both protections.
    ///
    /// 1. Fake length prefixes are bounded by max_alloc (remaining_bytes / element_size)
    /// 2. Total consumption is bounded by the budget
    #[test]
    fn budget_equals_input_length_is_safe() {
        // Valid input: 2 u64s
        let original = vec![100u64, 200];
        let mut data = Vec::new();
        crate::Serializable::write_into(&original, &mut data);

        // Budget = exact input size
        let result = Vec::<u64>::read_from_bytes_with_budget(&data, data.len());
        assert_eq!(result.unwrap(), vec![100, 200]);

        // Malicious input claiming 1000 elements (same serialized prefix manipulation)
        let mut evil_data = Vec::new();
        evil_data.push(0); // 9-byte vint64
        evil_data.extend_from_slice(&1000u64.to_le_bytes());
        evil_data.extend_from_slice(&42u64.to_le_bytes()); // only 1 actual element

        // Budget = input length (17 bytes). After reading length (9 bytes), 8 remain.
        // max_alloc(8) = 8/8 = 1, so 1000 > 1 fails.
        let result = Vec::<u64>::read_from_bytes_with_budget(&evil_data, evil_data.len());
        assert!(result.is_err());
    }

    // ============================================================================================
    // Tests documenting min_serialized_size()-based allocation bounds (defaults to size_of)
    // ============================================================================================

    /// The max_alloc check uses D::min_serialized_size() to bound memory allocation.
    /// By default, min_serialized_size() returns size_of::<D>().
    ///
    /// For flat collections like Vec<u64>, this works well: we check that
    /// budget / min_serialized_size() >= requested_count before allocating.
    #[test]
    fn min_serialized_size_bounds_flat_collections() {
        let mut data = Vec::new();
        data.push(0); // 9-byte vint64 marker
        data.extend_from_slice(&1000u64.to_le_bytes()); // claim 1000 u64s
        data.extend_from_slice(&[0u8; 16]); // only 2 u64s of actual data

        let inner = SliceReader::new(&data);
        // Budget of 80 bytes: after reading 9-byte length, 71 remain.
        // max_alloc(u64::min_serialized_size()) = 71 / 8 = 8 elements max
        let mut reader = BudgetedReader::new(inner, 80);

        let _len = reader.read_usize().unwrap();
        let result = reader.read_many_iter::<u64>(1000);

        // Rejected: 1000 > 8
        assert!(result.is_err());
    }

    /// For nested collections like Vec<Vec<u64>>, min_serialized_size() returns 1 (the minimum
    /// vint length prefix), not size_of. This is more permissive but accurate: a
    /// serialized Vec can be as small as 1 byte (empty vec).
    ///
    /// The early-abort check uses this minimum, and budget enforcement during actual
    /// reads provides the real protection against malicious input.
    #[test]
    fn min_serialized_size_override_for_nested_collections() {
        // Vec<u64>::min_serialized_size() returns 1 (minimum vint prefix), not size_of
        assert_eq!(<Vec<u64>>::min_serialized_size(), 1);

        let mut data = Vec::new();
        data.push(0); // 9-byte vint64 marker
        data.extend_from_slice(&100u64.to_le_bytes()); // claim 100 inner Vecs
        // Only provide enough data for 1 empty inner Vec
        data.push(0b1); // 1-byte vint64 encoding of 0, i.e. (0 << 1) | 1 (empty inner vec)

        let inner = SliceReader::new(&data);
        // With min_serialized_size() = 1, the early check needs at least 100 bytes of budget
        // left after the 9-byte length prefix. A budget of 101 would leave 101 - 9 = 92 < 100
        // and be rejected, so use 110: 110 - 9 = 101 remaining, and 101 / 1 = 101 >= 100.
        let mut reader = BudgetedReader::new(inner, 110);

        let _len = reader.read_usize().unwrap();
        let result = reader.read_many_iter::<Vec<u64>>(100);

        // The early check passes (100 <= 101)
        assert!(result.is_ok());

        // But deserialization fails when we try to read 100 inner Vecs with only 1
        let collect_result: Result<Vec<Vec<u64>>, _> = result.unwrap().collect();
        assert!(collect_result.is_err());
    }

    /// Demonstrates that min_serialized_size() approach still provides security for nested
    /// collections, just with later detection. The budget is enforced during reads.
    #[test]
    fn nested_collections_still_protected_by_budget() {
        // With Vec::min_serialized_size() = 1, the early check is permissive.
        // Security comes from budget enforcement during actual reads.
        let mut data = Vec::new();
        data.push(0); // 9-byte vint64 marker
        data.extend_from_slice(&10u64.to_le_bytes()); // claim 10 inner Vecs
        // Each inner vec claims 1000 u64s but provides none
        for _ in 0..10 {
            data.push(0); // 9-byte vint64 marker
            data.extend_from_slice(&1000u64.to_le_bytes());
        }

        let inner = SliceReader::new(&data);
        // Small budget: will run out during inner deserialization
        let mut reader = BudgetedReader::new(inner, 100);

        // Outer length read succeeds (consumes 9 bytes, 91 remaining)
        let _len = reader.read_usize().unwrap();

        // With Vec::min_serialized_size() = 1, early check passes: 91 / 1 = 91 >= 10
        let result = reader.read_many_iter::<Vec<u64>>(10);
        assert!(result.is_ok());

        // But collecting fails because the inner vecs claim 1000 u64s each,
        // exhausting the budget during inner deserialization
        let collect_result: Result<Vec<Vec<u64>>, _> = result.unwrap().collect();
        assert!(collect_result.is_err());
    }
}