miden_serde_utils/byte_reader.rs
1// Copyright (c) Facebook, Inc. and its affiliates.
2//
3// This source code is licensed under the MIT license found in the
4// LICENSE file in the root directory of this source tree.
5
6#[cfg(feature = "std")]
7use alloc::string::ToString;
8use alloc::{format, string::String, vec::Vec};
9#[cfg(feature = "std")]
10use core::cell::{Ref, RefCell};
11#[cfg(feature = "std")]
12use std::io::BufRead;
13
14use crate::{Deserializable, DeserializationError};
15
16// BYTE READER TRAIT
17// ================================================================================================
18
19/// Defines how primitive values are to be read from `Self`.
20///
21/// Whenever data is read from the reader using any of the `read_*` functions, the reader advances
22/// to the next unread byte. If the error occurs, the reader is not rolled back to the state prior
23/// to calling any of the function.
24pub trait ByteReader {
25 // REQUIRED METHODS
26 // --------------------------------------------------------------------------------------------
27
28 /// Returns a single byte read from `self`.
29 ///
30 /// # Errors
31 /// Returns a [DeserializationError] error the reader is at EOF.
32 fn read_u8(&mut self) -> Result<u8, DeserializationError>;
33
34 /// Returns the next byte to be read from `self` without advancing the reader to the next byte.
35 ///
36 /// # Errors
37 /// Returns a [DeserializationError] error the reader is at EOF.
38 fn peek_u8(&self) -> Result<u8, DeserializationError>;
39
40 /// Returns a slice of bytes of the specified length read from `self`.
41 ///
42 /// # Errors
43 /// Returns a [DeserializationError] if a slice of the specified length could not be read
44 /// from `self`.
45 fn read_slice(&mut self, len: usize) -> Result<&[u8], DeserializationError>;
46
47 /// Returns a byte array of length `N` read from `self`.
48 ///
49 /// # Errors
50 /// Returns a [DeserializationError] if an array of the specified length could not be read
51 /// from `self`.
52 fn read_array<const N: usize>(&mut self) -> Result<[u8; N], DeserializationError>;
53
54 /// Checks if it is possible to read at least `num_bytes` bytes from this ByteReader
55 ///
56 /// # Errors
57 /// Returns an error if, when reading the requested number of bytes, we go beyond the
58 /// the data available in the reader.
59 fn check_eor(&self, num_bytes: usize) -> Result<(), DeserializationError>;
60
61 /// Returns true if there are more bytes left to be read from `self`.
62 fn has_more_bytes(&self) -> bool;
63
64 /// Returns the maximum number of elements that can be safely allocated, given each
65 /// element occupies `element_size` bytes when serialized.
66 ///
67 /// This can be used by callers to pre-validate collection lengths before iterating,
68 /// preventing denial-of-service attacks from malicious length prefixes that claim
69 /// billions of elements.
70 ///
71 /// The default implementation returns `usize::MAX`, meaning no limit is enforced.
72 /// [`BudgetedReader`] overrides this to return `remaining_budget / element_size`,
73 /// providing tight, adaptive limits based on the caller's budget.
74 ///
75 /// # Arguments
76 /// * `element_size` - The serialized size of one element, from
77 /// [`Deserializable::min_serialized_size`]. Defaults to `size_of::<D>()` but can be
78 /// overridden for types where serialized size differs from in-memory size.
79 fn max_alloc(&self, _element_size: usize) -> usize {
80 usize::MAX
81 }
82
83 // PROVIDED METHODS
84 // --------------------------------------------------------------------------------------------
85
86 /// Returns a boolean value read from `self` consuming 1 byte from the reader.
87 ///
88 /// # Errors
89 /// Returns a [DeserializationError] if a u16 value could not be read from `self`.
90 fn read_bool(&mut self) -> Result<bool, DeserializationError> {
91 let byte = self.read_u8()?;
92 match byte {
93 0 => Ok(false),
94 1 => Ok(true),
95 _ => Err(DeserializationError::InvalidValue(format!("{byte} is not a boolean value"))),
96 }
97 }
98
99 /// Returns a u16 value read from `self` in little-endian byte order.
100 ///
101 /// # Errors
102 /// Returns a [DeserializationError] if a u16 value could not be read from `self`.
103 fn read_u16(&mut self) -> Result<u16, DeserializationError> {
104 let bytes = self.read_array::<2>()?;
105 Ok(u16::from_le_bytes(bytes))
106 }
107
108 /// Returns a u32 value read from `self` in little-endian byte order.
109 ///
110 /// # Errors
111 /// Returns a [DeserializationError] if a u32 value could not be read from `self`.
112 fn read_u32(&mut self) -> Result<u32, DeserializationError> {
113 let bytes = self.read_array::<4>()?;
114 Ok(u32::from_le_bytes(bytes))
115 }
116
117 /// Returns a u64 value read from `self` in little-endian byte order.
118 ///
119 /// # Errors
120 /// Returns a [DeserializationError] if a u64 value could not be read from `self`.
121 fn read_u64(&mut self) -> Result<u64, DeserializationError> {
122 let bytes = self.read_array::<8>()?;
123 Ok(u64::from_le_bytes(bytes))
124 }
125
126 /// Returns a u128 value read from `self` in little-endian byte order.
127 ///
128 /// # Errors
129 /// Returns a [DeserializationError] if a u128 value could not be read from `self`.
130 fn read_u128(&mut self) -> Result<u128, DeserializationError> {
131 let bytes = self.read_array::<16>()?;
132 Ok(u128::from_le_bytes(bytes))
133 }
134
135 /// Returns a usize value read from `self` in [vint64](https://docs.rs/vint64/latest/vint64/)
136 /// format.
137 ///
138 /// # Errors
139 /// Returns a [DeserializationError] if:
140 /// * usize value could not be read from `self`.
141 /// * encoded value is greater than `usize` maximum value on a given platform.
142 fn read_usize(&mut self) -> Result<usize, DeserializationError> {
143 let first_byte = self.peek_u8()?;
144 let length = first_byte.trailing_zeros() as usize + 1;
145
146 let result = if length == 9 {
147 // 9-byte special case
148 self.read_u8()?;
149 let value = self.read_array::<8>()?;
150 u64::from_le_bytes(value)
151 } else {
152 let mut encoded = [0u8; 8];
153 let value = self.read_slice(length)?;
154 encoded[..length].copy_from_slice(value);
155 u64::from_le_bytes(encoded) >> length
156 };
157
158 // check if the result value is within acceptable bounds for `usize` on a given platform
159 if result > usize::MAX as u64 {
160 return Err(DeserializationError::InvalidValue(format!(
161 "Encoded value must be less than {}, but {} was provided",
162 usize::MAX,
163 result
164 )));
165 }
166
167 Ok(result as usize)
168 }
169
170 /// Returns a byte vector of the specified length read from `self`.
171 ///
172 /// # Errors
173 /// Returns a [DeserializationError] if a vector of the specified length could not be read
174 /// from `self`.
175 fn read_vec(&mut self, len: usize) -> Result<Vec<u8>, DeserializationError> {
176 let data = self.read_slice(len)?;
177 Ok(data.to_vec())
178 }
179
180 /// Returns a String of the specified length read from `self`.
181 ///
182 /// # Errors
183 /// Returns a [DeserializationError] if a String of the specified length could not be read
184 /// from `self`.
185 fn read_string(&mut self, num_bytes: usize) -> Result<String, DeserializationError> {
186 let data = self.read_vec(num_bytes)?;
187 String::from_utf8(data).map_err(|err| DeserializationError::InvalidValue(format!("{err}")))
188 }
189
190 /// Reads a deserializable value from `self`.
191 ///
192 /// # Errors
193 /// Returns a [DeserializationError] if the specified value could not be read from `self`.
194 fn read<D>(&mut self) -> Result<D, DeserializationError>
195 where
196 Self: Sized,
197 D: Deserializable,
198 {
199 D::read_from(self)
200 }
201
202 /// Returns an iterator that deserializes `num_elements` instances of `D` from this reader.
203 ///
204 /// This method validates the requested count against the reader's capacity before returning
205 /// the iterator, rejecting implausible lengths early. Each element is then deserialized
206 /// lazily as the iterator is consumed.
207 ///
208 /// # Errors
209 ///
210 /// Returns an error if `num_elements` exceeds `self.max_alloc(D::min_serialized_size())`,
211 /// indicating the reader cannot allocate that many elements.
212 ///
213 /// # Example
214 ///
215 /// ```ignore
216 /// // Collect into a Vec
217 /// let items: Vec<u64> = reader
218 /// .read_many_iter::<u64>(count)?
219 /// .collect::<Result<_, _>>()?;
220 ///
221 /// // Collect directly into a BTreeMap (no intermediate Vec)
222 /// let map: BTreeMap<K, V> = reader
223 /// .read_many_iter::<(K, V)>(count)?
224 /// .collect::<Result<_, _>>()?;
225 /// ```
226 fn read_many_iter<D>(
227 &mut self,
228 num_elements: usize,
229 ) -> Result<ReadManyIter<'_, Self, D>, DeserializationError>
230 where
231 Self: Sized,
232 D: Deserializable,
233 {
234 let max_elements = self.max_alloc(D::min_serialized_size());
235 if num_elements > max_elements {
236 return Err(DeserializationError::InvalidValue(format!(
237 "requested {num_elements} elements but reader can provide at most {max_elements}"
238 )));
239 }
240 Ok(ReadManyIter {
241 reader: self,
242 remaining: num_elements,
243 _item: core::marker::PhantomData,
244 })
245 }
246}
247
248// READ MANY ITERATOR
249// ================================================================================================
250
251/// Iterator that lazily deserializes elements from a [`ByteReader`].
252///
253/// Created by [`ByteReader::read_many_iter`]. Each call to `next()` deserializes one element.
254/// This avoids upfront allocation and naturally integrates with [`BudgetedReader`] for
255/// protection against malicious inputs.
256pub struct ReadManyIter<'reader, R: ByteReader, D: Deserializable> {
257 reader: &'reader mut R,
258 remaining: usize,
259 _item: core::marker::PhantomData<D>,
260}
261
262impl<'reader, R: ByteReader, D: Deserializable> Iterator for ReadManyIter<'reader, R, D> {
263 type Item = Result<D, DeserializationError>;
264
265 fn next(&mut self) -> Option<Self::Item> {
266 if self.remaining > 0 {
267 self.remaining -= 1;
268 Some(D::read_from(self.reader))
269 } else {
270 None
271 }
272 }
273
274 fn size_hint(&self) -> (usize, Option<usize>) {
275 (self.remaining, Some(self.remaining))
276 }
277}
278
279impl<'reader, R: ByteReader, D: Deserializable> ExactSizeIterator for ReadManyIter<'reader, R, D> {}
280
281// STANDARD LIBRARY ADAPTER
282// ================================================================================================
283
/// An adapter of [ByteReader] to any type that implements [std::io::Read]
///
/// In particular, this covers things like [std::fs::File], standard input, etc.
///
/// The adapter keeps two layers of buffered data: the buffer owned by the wrapped
/// [std::io::BufReader], and a local spill-over buffer (`buf`) that is only populated when a
/// request cannot be served from the reader's buffer directly.
#[cfg(feature = "std")]
pub struct ReadAdapter<'a> {
    // NOTE: The [ByteReader] trait does not currently support reader implementations that require
    // mutation during `peek_u8`, `has_more_bytes`, and `check_eor`. These (or equivalent)
    // operations on the standard library [std::io::BufRead] trait require a mutable reference, as
    // it may be necessary to read from the underlying input to implement them.
    //
    // To handle this, we wrap the underlying reader in a [RefCell]; this allows us to mutate the
    // reader if necessary during a call to one of the above-mentioned trait methods, without
    // sacrificing safety - at the cost of enforcing Rust's borrowing semantics dynamically.
    //
    // This should not be a problem in practice, except in the case where `read_slice` is called,
    // and the reference returned is from `reader` directly, rather than `buf`. If a call to one
    // of the above-mentioned methods is made while that reference is live, and we attempt to read
    // from `reader`, a panic will occur.
    //
    // Ultimately, this should be addressed by making the [ByteReader] trait align with the
    // standard library I/O traits, so this is a temporary solution.
    reader: RefCell<std::io::BufReader<&'a mut dyn std::io::Read>>,
    // A temporary buffer to store chunks read from `reader` that are larger than what is required
    // for the higher-level [ByteReader] APIs.
    //
    // By default we attempt to satisfy reads from `reader` directly, but that is not always
    // possible.
    buf: Vec<u8>,
    // The position in `buf` at which we should start reading the next byte, when `buf` is
    // non-empty. Bytes in `buf` before this position have already been handed out to callers.
    pos: usize,
    // This is set when we attempt to read from `reader` and get an empty buffer. This indicates
    // that once we exhaust `buf`, we have truly reached end-of-file.
    //
    // We will use this to more accurately handle functions like `has_more_bytes` when this is set.
    guaranteed_eof: bool,
}
321
#[cfg(feature = "std")]
impl<'a> ReadAdapter<'a> {
    /// Create a new [ByteReader] adapter for the given implementation of [std::io::Read]
    ///
    /// The reader is wrapped in a [std::io::BufReader] with a capacity of 256 bytes; the local
    /// spill-over buffer starts out empty.
    pub fn new(reader: &'a mut dyn std::io::Read) -> Self {
        Self {
            reader: RefCell::new(std::io::BufReader::with_capacity(256, reader)),
            buf: Default::default(),
            pos: 0,
            guaranteed_eof: false,
        }
    }

    /// Get the unread part of the internal adapter buffer as a (possibly empty) slice of bytes
    #[inline(always)]
    fn buffer(&self) -> &[u8] {
        self.buf.get(self.pos..).unwrap_or(&[])
    }

    /// Get the unread part of the internal adapter buffer as a slice of bytes, or `None` if
    /// there are no unread bytes in it
    #[inline(always)]
    fn non_empty_buffer(&self) -> Option<&[u8]> {
        self.buf.get(self.pos..).filter(|b| !b.is_empty())
    }

    /// Return the current reader buffer as a (possibly empty) slice of bytes.
    ///
    /// This buffer being empty _does not_ mean we're at EOF, you must call
    /// [Self::non_empty_reader_buffer_mut] first.
    #[inline(always)]
    fn reader_buffer(&self) -> Ref<'_, [u8]> {
        Ref::map(self.reader.borrow(), |r| r.buffer())
    }

    /// Return the current reader buffer, reading from the underlying reader
    /// if the buffer is empty.
    ///
    /// Returns `Ok` only if the buffer is non-empty, and no errors occurred
    /// while filling it (if filling was needed). The returned slice may be shorter than what
    /// the caller ultimately needs; a short slice does not imply end-of-file.
    ///
    /// When the reader yields an empty buffer, `guaranteed_eof` is set and
    /// [DeserializationError::UnexpectedEOF] is returned.
    fn non_empty_reader_buffer_mut(&mut self) -> Result<&[u8], DeserializationError> {
        use std::io::ErrorKind;
        let buf = self.reader.get_mut().fill_buf().map_err(|e| match e.kind() {
            ErrorKind::UnexpectedEof => DeserializationError::UnexpectedEOF,
            e => DeserializationError::UnknownError(e.to_string()),
        })?;
        if buf.is_empty() {
            self.guaranteed_eof = true;
            Err(DeserializationError::UnexpectedEOF)
        } else {
            Ok(buf)
        }
    }

    /// Same as [Self::non_empty_reader_buffer_mut], but with dynamically-enforced
    /// borrow check rules so that it can be called in functions like `peek_u8`.
    ///
    /// This comes with overhead for the dynamic checks, so you should prefer
    /// to call [Self::non_empty_reader_buffer_mut] if you already have a mutable
    /// reference to `self`.
    ///
    /// Because it only has `&self`, this variant does not record `guaranteed_eof`.
    fn non_empty_reader_buffer(&self) -> Result<Ref<'_, [u8]>, DeserializationError> {
        use std::io::ErrorKind;
        let mut reader = self.reader.borrow_mut();
        let buf = reader.fill_buf().map_err(|e| match e.kind() {
            ErrorKind::UnexpectedEof => DeserializationError::UnexpectedEOF,
            e => DeserializationError::UnknownError(e.to_string()),
        })?;
        if buf.is_empty() {
            Err(DeserializationError::UnexpectedEOF)
        } else {
            // Re-borrow immutably
            drop(reader);
            Ok(self.reader_buffer())
        }
    }

    /// Returns true if the capacity of `buf`, less the number of unread bytes it currently
    /// holds, is at least `n` bytes
    #[inline]
    fn has_remaining_capacity(&self, n: usize) -> bool {
        let remaining = self.buf.capacity() - self.buffer().len();
        remaining >= n
    }

    /// Takes the next byte from the input, returning an error if the operation fails
    ///
    /// Unread bytes in the local buffer are served first; only when it is exhausted is the
    /// byte taken from the reader's buffer.
    fn pop(&mut self) -> Result<u8, DeserializationError> {
        if let Some(byte) = self.non_empty_buffer().map(|b| b[0]) {
            self.pos += 1;
            return Ok(byte);
        }
        let result = self.non_empty_reader_buffer_mut().map(|b| b[0]);
        if result.is_ok() {
            self.reader.get_mut().consume(1);
        } else {
            self.guaranteed_eof = true;
        }
        result
    }

    /// Takes the next `N` bytes from the input as an array, returning an error if the operation
    /// fails
    ///
    /// Bytes are taken from the unread part of the local buffer first, and the remainder is
    /// copied straight out of the reader's buffer, refilling it as many times as necessary.
    /// A short chunk from the reader is therefore never mistaken for end-of-file; only an
    /// empty chunk is.
    ///
    /// # Errors
    /// Returns [DeserializationError::UnexpectedEOF] if the input ends before `N` bytes were
    /// obtained, or the mapped I/O error if the underlying reader fails. As documented on
    /// [ByteReader], the reader is not rolled back on failure: bytes consumed before the error
    /// stay consumed.
    fn read_exact<const N: usize>(&mut self) -> Result<[u8; N], DeserializationError> {
        let mut output = [0u8; N];
        let buffered = self.buffer().len();

        if buffered >= N {
            // The local buffer alone can satisfy the request
            output.copy_from_slice(&self.buffer()[..N]);
            self.pos += N;
        } else {
            // Drain whatever is left in the local buffer first so byte order is preserved...
            output[..buffered].copy_from_slice(self.buffer());
            self.pos += buffered;

            // ...then top up from the reader until the array is full. `fill_buf` may hand back
            // fewer bytes than we still need (short reads, or a partially drained chunk), so
            // keep going until we have everything or the reader reports end-of-file.
            let mut filled = buffered;
            while filled < N {
                let chunk = self.non_empty_reader_buffer_mut()?;
                let take = chunk.len().min(N - filled);
                output[filled..filled + take].copy_from_slice(&chunk[..take]);
                self.reader.get_mut().consume(take);
                filled += take;
            }
        }

        // Once every buffered byte has been handed out, recycle the local buffer. `pos` must be
        // reset together with the length, otherwise later reads would index `buf` with a stale
        // offset.
        if self.pos > 0 && self.buffer().is_empty() {
            self.buf.clear();
            self.pos = 0;
        }

        Ok(output)
    }

    /// Fill `self.buf` until it holds at least `count` unread bytes
    ///
    /// This should only be called when we can't read from the reader directly
    ///
    /// # Errors
    /// Returns an error if the underlying reader hits EOF (or fails) before `count` unread
    /// bytes are available.
    fn buffer_at_least(&mut self, count: usize) -> Result<(), DeserializationError> {
        // `count` is the target for the *total* number of unread bytes, so it is compared
        // against the unread length on every iteration and never decremented.
        while self.buffer().len() < count {
            // This operation will return an error if the underlying reader hits EOF
            self.non_empty_reader_buffer_mut()?;

            // Extend `self.buf` with the bytes read from the underlying reader.
            //
            // NOTE: We have to re-borrow the reader buffer here, since we can't get a mutable
            // reference to `self.buf` while holding an immutable reference to the reader buffer.
            let reader = self.reader.get_mut();
            let chunk = reader.buffer();
            let consumed = chunk.len();
            self.buf.extend_from_slice(chunk);
            reader.consume(consumed);
        }
        Ok(())
    }
}
539
#[cfg(feature = "std")]
impl ByteReader for ReadAdapter<'_> {
    /// Reads one byte, preferring the local buffer over the underlying reader.
    #[inline(always)]
    fn read_u8(&mut self) -> Result<u8, DeserializationError> {
        self.pop()
    }

    /// NOTE: If we happen to not have any bytes buffered yet when this is called, then we will be
    /// forced to try and read from the underlying reader. This requires a mutable reference, which
    /// is obtained dynamically via [RefCell].
    ///
    /// <div class="warning">
    /// Callers must ensure that they do not hold any immutable references to the buffer of this
    /// reader when calling this function so as to avoid a situation in which the dynamic borrow
    /// check fails. Specifically, you must not be holding a reference to the result of
    /// [Self::read_slice] when this function is called.
    /// </div>
    fn peek_u8(&self) -> Result<u8, DeserializationError> {
        match self.buffer().first() {
            Some(byte) => Ok(*byte),
            None => self.non_empty_reader_buffer().map(|b| b[0]),
        }
    }

    /// Reads `len` bytes into the local buffer and returns a slice borrowed from it.
    fn read_slice(&mut self, len: usize) -> Result<&[u8], DeserializationError> {
        // Edge case
        if len == 0 {
            return Ok(&[]);
        }

        // When a sizeable prefix of the local buffer has already been consumed and the spare
        // capacity is insufficient, slide the unread bytes down to the front of the buffer so
        // the freed space can be reused before a reallocation is forced.
        if self.pos >= 16 && !self.has_remaining_capacity(len) {
            let unread = self.buffer().len();
            // Offset of the first unread byte; equals the buffer length when nothing is unread.
            let start = self.buf.len() - unread;
            self.buf.copy_within(start.., 0);
            self.buf.truncate(unread);
            self.pos = 0;
        }

        // Fill the buffer so we have at least `len` bytes available,
        // this will return an error if we hit EOF first
        self.buffer_at_least(len)?;

        let start = self.pos;
        self.pos += len;
        Ok(&self.buf[start..(start + len)])
    }

    /// Reads `N` bytes as an array; a zero-length array is returned without touching the input.
    #[inline]
    fn read_array<const N: usize>(&mut self) -> Result<[u8; N], DeserializationError> {
        if N == 0 {
            return Ok([0; N]);
        }
        self.read_exact()
    }

    /// Best-effort check that `num_bytes` more bytes can be read.
    ///
    /// The answer is exact when the data already buffered (locally and in the reader) covers the
    /// request, or when end-of-file has been observed. Otherwise it is optimistic.
    fn check_eor(&self, num_bytes: usize) -> Result<(), DeserializationError> {
        // Do we have sufficient data in the local buffer?
        let local = self.buffer().len();
        if local >= num_bytes {
            return Ok(());
        }

        // What about if we include what is in the local buffer and the reader's buffer? This
        // fails with an error if the reader has nothing more to give.
        let upstream = self.non_empty_reader_buffer()?.len();

        // We have no more input, thus can't fulfill a request of `num_bytes`
        if local + upstream < num_bytes && self.guaranteed_eof {
            return Err(DeserializationError::UnexpectedEOF);
        }

        // Because this function is read-only, we must optimistically assume we can read `num_bytes`
        // from the input, and fail later if that does not hold. We know we're not at EOF yet, but
        // that's all we can say without buffering more from the reader. We could make use of
        // `buffer_at_least`, which would guarantee a correct result, but it would also impose
        // additional restrictions on the use of this function, e.g. not using it while holding a
        // reference returned from `read_slice`. Since it is not a memory safety violation to return
        // an optimistic result here, it makes for a better tradeoff.
        Ok(())
    }

    /// Returns true if unread bytes are buffered locally, or the reader can supply at least one
    /// more byte.
    #[inline]
    fn has_more_bytes(&self) -> bool {
        if self.buffer().is_empty() {
            self.non_empty_reader_buffer().is_ok()
        } else {
            true
        }
    }
}
639
640// CURSOR
641// ================================================================================================
642
// Expands to the part of a cursor's underlying bytes that lies at or after its current position.
// The position is clamped to the length of the data, so a cursor positioned past the end yields
// an empty slice rather than panicking.
#[cfg(feature = "std")]
macro_rules! cursor_remaining_buf {
    ($cursor:ident) => {{
        let bytes = $cursor.get_ref().as_ref();
        let offset = core::cmp::min($cursor.position(), bytes.len() as u64) as usize;
        &bytes[offset..]
    }};
}
651
#[cfg(feature = "std")]
impl<T: AsRef<[u8]>> ByteReader for std::io::Cursor<T> {
    /// Reads the byte at the cursor position and advances the cursor by one.
    fn read_u8(&mut self) -> Result<u8, DeserializationError> {
        let next = cursor_remaining_buf!(self).first().copied();
        match next {
            Some(byte) => {
                self.set_position(self.position() + 1);
                Ok(byte)
            },
            None => Err(DeserializationError::UnexpectedEOF),
        }
    }

    /// Returns the byte at the cursor position without moving the cursor.
    fn peek_u8(&self) -> Result<u8, DeserializationError> {
        match cursor_remaining_buf!(self).first() {
            Some(byte) => Ok(*byte),
            None => Err(DeserializationError::UnexpectedEOF),
        }
    }

    /// Returns the next `len` bytes as a slice of the underlying data and advances the cursor
    /// past them.
    fn read_slice(&mut self, len: usize) -> Result<&[u8], DeserializationError> {
        let pos = self.position();
        let size = self.get_ref().as_ref().len() as u64;

        // `saturating_sub` treats a cursor positioned past the end as having nothing left
        let available = size.saturating_sub(pos);
        if available < len as u64 {
            return Err(DeserializationError::UnexpectedEOF);
        }

        self.set_position(pos + len as u64);
        let start = pos.min(size) as usize;
        let end = start + len;
        Ok(&self.get_ref().as_ref()[start..end])
    }

    /// Copies the next `N` bytes into an array and advances the cursor past them.
    fn read_array<const N: usize>(&mut self) -> Result<[u8; N], DeserializationError> {
        let mut array = [0u8; N];
        array.copy_from_slice(self.read_slice(N)?);
        Ok(array)
    }

    /// Succeeds only if at least `num_bytes` bytes remain after the cursor position.
    fn check_eor(&self, num_bytes: usize) -> Result<(), DeserializationError> {
        let available = cursor_remaining_buf!(self).len();
        if available < num_bytes {
            Err(DeserializationError::UnexpectedEOF)
        } else {
            Ok(())
        }
    }

    /// Returns true while the cursor position is before the end of the underlying data.
    #[inline]
    fn has_more_bytes(&self) -> bool {
        self.position() < self.get_ref().as_ref().len() as u64
    }
}
707
708// SLICE READER
709// ================================================================================================
710
/// A [ByteReader] over a borrowed slice of bytes.
///
/// NOTE: If you are building with the `std` feature, you should probably prefer [std::io::Cursor]
/// instead. However, [SliceReader] is still useful in no-std environments until stabilization of
/// the `core_io_borrowed_buf` feature.
pub struct SliceReader<'a> {
    /// The bytes being read.
    source: &'a [u8],
    /// Index into `source` of the next unread byte.
    pos: usize,
}

impl<'a> SliceReader<'a> {
    /// Creates a new slice reader positioned at the start of the specified slice.
    pub fn new(source: &'a [u8]) -> Self {
        Self { source, pos: 0 }
    }
}
727
728impl ByteReader for SliceReader<'_> {
729 fn read_u8(&mut self) -> Result<u8, DeserializationError> {
730 self.check_eor(1)?;
731 let result = self.source[self.pos];
732 self.pos += 1;
733 Ok(result)
734 }
735
736 fn peek_u8(&self) -> Result<u8, DeserializationError> {
737 self.check_eor(1)?;
738 Ok(self.source[self.pos])
739 }
740
741 fn read_slice(&mut self, len: usize) -> Result<&[u8], DeserializationError> {
742 self.check_eor(len)?;
743 let result = &self.source[self.pos..self.pos + len];
744 self.pos += len;
745 Ok(result)
746 }
747
748 fn read_array<const N: usize>(&mut self) -> Result<[u8; N], DeserializationError> {
749 self.check_eor(N)?;
750 let mut result = [0_u8; N];
751 result.copy_from_slice(&self.source[self.pos..self.pos + N]);
752 self.pos += N;
753 Ok(result)
754 }
755
756 fn check_eor(&self, num_bytes: usize) -> Result<(), DeserializationError> {
757 self.pos
758 .checked_add(num_bytes)
759 .filter(|end| *end <= self.source.len())
760 .map(|_| ())
761 .ok_or(DeserializationError::UnexpectedEOF)
762 }
763
764 fn has_more_bytes(&self) -> bool {
765 self.pos < self.source.len()
766 }
767}
768
769// BUDGETED READER
770// ================================================================================================
771
772/// A reader wrapper that enforces a byte budget during deserialization.
773///
774/// # Threat Model
775///
776/// Malicious input can attack deserialization in two ways:
777///
778/// 1. **Fake length prefix**: Input claims `len = 2^60` elements, causing allocation of a huge
779/// `Vec` before any data is read.
780/// 2. **Oversized input**: Attacker sends gigabytes of valid-looking data to exhaust memory over
781/// time.
782///
783/// # Defense Strategy
784///
785/// Use `BudgetedReader` to limit total bytes consumed. Its [`max_alloc`](ByteReader::max_alloc)
786/// method derives a bound from the remaining budget, which
787/// [`read_many_iter`](ByteReader::read_many_iter) checks before iterating.
788///
789/// ## Problem: SliceReader alone doesn't bound allocations
790///
791/// ```
792/// use miden_serde_utils::{ByteReader, Deserializable, SliceReader};
793///
794/// // Malicious input: length prefix says 1 billion u64s, but only 16 bytes of data
795/// let mut data = Vec::new();
796/// data.push(0u8); // vint64 9-byte marker
797/// data.extend_from_slice(&1_000_000_000u64.to_le_bytes());
798/// data.extend_from_slice(&[0u8; 16]);
799///
800/// // SliceReader returns usize::MAX from max_alloc, so read_many_iter accepts
801/// // any length. This would try to iterate 1 billion times (slow, not OOM,
802/// // but still a DoS vector).
803/// let reader = SliceReader::new(&data);
804/// assert_eq!(reader.max_alloc(8), usize::MAX);
805/// ```
806///
807/// ## Solution: BudgetedReader bounds allocations via max_alloc
808///
809/// ```
810/// use miden_serde_utils::{BudgetedReader, ByteReader, Deserializable, SliceReader};
811///
812/// // Same malicious input
813/// let mut data = Vec::new();
814/// data.push(0u8);
815/// data.extend_from_slice(&1_000_000_000u64.to_le_bytes());
816/// data.extend_from_slice(&[0u8; 16]);
817///
818/// // BudgetedReader with 64-byte budget: max_alloc(8) = 64/8 = 8 elements
819/// let inner = SliceReader::new(&data);
820/// let reader = BudgetedReader::new(inner, 64);
821/// assert_eq!(reader.max_alloc(8), 8);
822///
823/// // read_many_iter rejects the 1B length since 1B > 8
824/// let result = Vec::<u64>::read_from_bytes_with_budget(&data, 64);
825/// assert!(result.is_err());
826/// ```
827///
828/// ## Best practice: Set budget to expected input size
829///
830/// ```
831/// use miden_serde_utils::{ByteWriter, Deserializable, Serializable};
832///
833/// // Legitimate input: 3 u64s, properly serialized
834/// let original = vec![1u64, 2, 3];
835/// let mut data = Vec::new();
836/// original.write_into(&mut data);
837///
838/// // Budget = data.len() bounds both fake lengths and total consumption
839/// let result = Vec::<u64>::read_from_bytes_with_budget(&data, data.len());
840/// assert_eq!(result.unwrap(), vec![1, 2, 3]);
841/// ```
842pub struct BudgetedReader<R> {
843 inner: R,
844 remaining: usize,
845}
846
847impl<R> BudgetedReader<R> {
848 /// Wraps a reader with the specified byte budget.
849 pub fn new(inner: R, budget: usize) -> Self {
850 Self { inner, remaining: budget }
851 }
852
853 /// Returns remaining budget in bytes.
854 pub fn remaining(&self) -> usize {
855 self.remaining
856 }
857
858 /// Consumes budget, returning an error if insufficient.
859 fn consume_budget(&mut self, n: usize) -> Result<(), DeserializationError> {
860 if n > self.remaining {
861 return Err(DeserializationError::InvalidValue(format!(
862 "budget exhausted: requested {n} bytes, {} remaining",
863 self.remaining
864 )));
865 }
866 self.remaining -= n;
867 Ok(())
868 }
869}
870
871impl<R: ByteReader> ByteReader for BudgetedReader<R> {
872 fn read_u8(&mut self) -> Result<u8, DeserializationError> {
873 self.consume_budget(1)?;
874 self.inner.read_u8()
875 }
876
877 fn peek_u8(&self) -> Result<u8, DeserializationError> {
878 // peek doesn't consume budget since it doesn't advance the reader
879 self.inner.peek_u8()
880 }
881
882 fn read_slice(&mut self, len: usize) -> Result<&[u8], DeserializationError> {
883 self.consume_budget(len)?;
884 self.inner.read_slice(len)
885 }
886
887 fn read_array<const N: usize>(&mut self) -> Result<[u8; N], DeserializationError> {
888 self.consume_budget(N)?;
889 self.inner.read_array()
890 }
891
892 fn check_eor(&self, num_bytes: usize) -> Result<(), DeserializationError> {
893 // check budget first, then delegate
894 if num_bytes > self.remaining {
895 return Err(DeserializationError::InvalidValue(format!(
896 "budget exhausted: requested {num_bytes} bytes, {} remaining",
897 self.remaining
898 )));
899 }
900 self.inner.check_eor(num_bytes)
901 }
902
903 fn has_more_bytes(&self) -> bool {
904 self.remaining > 0 && self.inner.has_more_bytes()
905 }
906
907 fn max_alloc(&self, element_size: usize) -> usize {
908 if element_size == 0 {
909 return usize::MAX; // ZSTs don't consume budget
910 }
911 self.remaining / element_size
912 }
913}
914
915#[cfg(all(test, feature = "std"))]
916mod tests {
917 use core::mem::size_of;
918 use std::io::Cursor;
919
920 use super::*;
921 use crate::ByteWriter;
922
923 #[test]
924 fn read_adapter_empty() -> Result<(), DeserializationError> {
925 let mut reader = std::io::empty();
926 let mut adapter = ReadAdapter::new(&mut reader);
927 assert!(!adapter.has_more_bytes());
928 assert_eq!(adapter.check_eor(8), Err(DeserializationError::UnexpectedEOF));
929 assert_eq!(adapter.peek_u8(), Err(DeserializationError::UnexpectedEOF));
930 assert_eq!(adapter.read_u8(), Err(DeserializationError::UnexpectedEOF));
931 assert_eq!(adapter.read_slice(0), Ok([].as_slice()));
932 assert_eq!(adapter.read_slice(1), Err(DeserializationError::UnexpectedEOF));
933 assert_eq!(adapter.read_array(), Ok([]));
934 assert_eq!(adapter.read_array::<1>(), Err(DeserializationError::UnexpectedEOF));
935 Ok(())
936 }
937
938 #[test]
939 fn read_adapter_passthrough() -> Result<(), DeserializationError> {
940 let mut reader = std::io::repeat(0b101);
941 let mut adapter = ReadAdapter::new(&mut reader);
942 assert!(adapter.has_more_bytes());
943 assert_eq!(adapter.check_eor(8), Ok(()));
944 assert_eq!(adapter.peek_u8(), Ok(0b101));
945 assert_eq!(adapter.read_u8(), Ok(0b101));
946 assert_eq!(adapter.read_slice(0), Ok([].as_slice()));
947 assert_eq!(adapter.read_slice(4), Ok([0b101, 0b101, 0b101, 0b101].as_slice()));
948 assert_eq!(adapter.read_array(), Ok([]));
949 assert_eq!(adapter.read_array(), Ok([0b101, 0b101]));
950 Ok(())
951 }
952
953 #[test]
954 fn read_adapter_exact() {
955 const VALUE: usize = 2048;
956 let mut reader = Cursor::new(VALUE.to_le_bytes());
957 let mut adapter = ReadAdapter::new(&mut reader);
958 assert_eq!(usize::from_le_bytes(adapter.read_array().unwrap()), VALUE);
959 assert!(!adapter.has_more_bytes());
960 assert_eq!(adapter.peek_u8(), Err(DeserializationError::UnexpectedEOF));
961 assert_eq!(adapter.read_u8(), Err(DeserializationError::UnexpectedEOF));
962 }
963
964 #[test]
965 fn read_adapter_roundtrip() {
966 const VALUE: usize = 2048;
967
968 // Write VALUE to storage
969 let mut cursor = Cursor::new([0; size_of::<usize>()]);
970 cursor.write_usize(VALUE);
971
972 // Read VALUE from storage
973 cursor.set_position(0);
974 let mut adapter = ReadAdapter::new(&mut cursor);
975
976 assert_eq!(adapter.read_usize(), Ok(VALUE));
977 }
978
979 #[test]
980 fn read_adapter_for_file() {
981 use std::fs::File;
982
983 use crate::ByteWriter;
984
985 let path = std::env::temp_dir().join("read_adapter_for_file.bin");
986
987 // Encode some data to a buffer, then write that buffer to a file
988 {
989 let mut buf = Vec::<u8>::with_capacity(256);
990 buf.write_bytes(b"MAGIC\0");
991 buf.write_bool(true);
992 buf.write_u32(0xbeef);
993 buf.write_usize(0xfeed);
994 buf.write_u16(0x5);
995
996 std::fs::write(&path, &buf).unwrap();
997 }
998
999 // Open the file, and try to decode the encoded items
1000 let mut file = File::open(&path).unwrap();
1001 let mut reader = ReadAdapter::new(&mut file);
1002 assert_eq!(reader.peek_u8().unwrap(), b'M');
1003 assert_eq!(reader.read_slice(6).unwrap(), b"MAGIC\0");
1004 assert!(reader.read_bool().unwrap());
1005 assert_eq!(reader.read_u32().unwrap(), 0xbeef);
1006 assert_eq!(reader.read_usize().unwrap(), 0xfeed);
1007 assert_eq!(reader.read_u16().unwrap(), 0x5);
1008 assert!(!reader.has_more_bytes(), "expected there to be no more data in the input");
1009 }
1010
1011 #[test]
1012 fn read_adapter_issue_383() {
1013 const STR_BYTES: &[u8] = b"just a string";
1014
1015 use std::fs::File;
1016
1017 use crate::ByteWriter;
1018
1019 let path = std::env::temp_dir().join("issue_383.bin");
1020
1021 // Encode some data to a buffer, then write that buffer to a file
1022 {
1023 let mut buf = vec![0u8; 1024];
1024 unsafe {
1025 buf.set_len(0);
1026 }
1027 buf.write_u128(2 * u64::MAX as u128);
1028 unsafe {
1029 buf.set_len(512);
1030 }
1031 buf.write_bytes(STR_BYTES);
1032 buf.write_u32(0xbeef);
1033
1034 std::fs::write(&path, &buf).unwrap();
1035 }
1036
1037 // Open the file, and try to decode the encoded items
1038 let mut file = File::open(&path).unwrap();
1039 let mut reader = ReadAdapter::new(&mut file);
1040 assert_eq!(reader.read_u128().unwrap(), 2 * u64::MAX as u128);
1041 assert_eq!(reader.buf.len(), 0);
1042 assert_eq!(reader.pos, 0);
1043 // Read to offset 512 (we're 16 bytes into the underlying file, i.e. offset of 496)
1044 reader.read_slice(496).unwrap();
1045 assert_eq!(reader.buf.len(), 496);
1046 assert_eq!(reader.pos, 496);
1047 // The byte string is 13 bytes, followed by 4 bytes containing the trailing u32 value.
1048 // We expect that the underlying reader will buffer the remaining bytes of the file when
1049 // reading STR_BYTES, so the total size of our adapter's buffer should be
1050 // 496 + STR_BYTES.len() + size_of::<u32>();
1051 assert_eq!(reader.read_slice(STR_BYTES.len()).unwrap(), STR_BYTES);
1052 assert_eq!(reader.buf.len(), 496 + STR_BYTES.len() + size_of::<u32>());
1053 // We haven't read the u32 yet
1054 assert_eq!(reader.pos, 509);
1055 assert_eq!(reader.read_u32().unwrap(), 0xbeef);
1056 // Now we have
1057 assert_eq!(reader.pos, 513);
1058 assert!(!reader.has_more_bytes(), "expected there to be no more data in the input");
1059 }
1060
1061 #[test]
1062 fn budgeted_reader_basic() {
1063 let data = [1u8, 2, 3, 4, 5, 6, 7, 8];
1064 let inner = SliceReader::new(&data);
1065 let mut reader = BudgetedReader::new(inner, 4);
1066
1067 assert_eq!(reader.remaining(), 4);
1068 assert!(reader.has_more_bytes());
1069
1070 // read 4 bytes (within budget)
1071 assert_eq!(reader.read_u32().unwrap(), 0x04030201);
1072 assert_eq!(reader.remaining(), 0);
1073
1074 // budget exhausted
1075 assert!(!reader.has_more_bytes());
1076 assert!(reader.read_u8().is_err());
1077 }
1078
1079 #[test]
1080 fn budgeted_reader_peek_does_not_consume() {
1081 let data = [42u8];
1082 let inner = SliceReader::new(&data);
1083 let mut reader = BudgetedReader::new(inner, 1);
1084
1085 // peek multiple times, budget unchanged
1086 assert_eq!(reader.peek_u8().unwrap(), 42);
1087 assert_eq!(reader.peek_u8().unwrap(), 42);
1088 assert_eq!(reader.remaining(), 1);
1089
1090 // actual read consumes budget
1091 assert_eq!(reader.read_u8().unwrap(), 42);
1092 assert_eq!(reader.remaining(), 0);
1093 }
1094
1095 #[test]
1096 fn budgeted_reader_check_eor_respects_budget() {
1097 let data = [0u8; 100];
1098 let inner = SliceReader::new(&data);
1099 let reader = BudgetedReader::new(inner, 10);
1100
1101 // within budget
1102 assert!(reader.check_eor(10).is_ok());
1103
1104 // exceeds budget (even though inner has enough bytes)
1105 assert!(reader.check_eor(11).is_err());
1106 }
1107
1108 #[test]
1109 fn budgeted_reader_read_slice() {
1110 let data = [1u8, 2, 3, 4, 5];
1111 let inner = SliceReader::new(&data);
1112 let mut reader = BudgetedReader::new(inner, 3);
1113
1114 // read 3 bytes (exactly budget)
1115 assert_eq!(reader.read_slice(3).unwrap(), &[1, 2, 3]);
1116 assert_eq!(reader.remaining(), 0);
1117
1118 // can't read more
1119 assert!(reader.read_slice(1).is_err());
1120 }
1121
1122 #[test]
1123 fn budgeted_reader_read_array() {
1124 let data = [0xaau8, 0xbb, 0xcc, 0xdd];
1125 let inner = SliceReader::new(&data);
1126 let mut reader = BudgetedReader::new(inner, 2);
1127
1128 // read 2-byte array
1129 assert_eq!(reader.read_array::<2>().unwrap(), [0xaa, 0xbb]);
1130 assert_eq!(reader.remaining(), 0);
1131
1132 // budget exhausted
1133 assert!(reader.read_array::<2>().is_err());
1134 }
1135
1136 #[test]
1137 fn budgeted_reader_zero_budget() {
1138 let data = [1u8];
1139 let inner = SliceReader::new(&data);
1140 let mut reader = BudgetedReader::new(inner, 0);
1141
1142 assert!(!reader.has_more_bytes());
1143 assert!(reader.read_u8().is_err());
1144 // peek still works (doesn't consume budget)
1145 assert_eq!(reader.peek_u8().unwrap(), 1);
1146 }
1147
1148 #[test]
1149 fn budgeted_reader_max_alloc() {
1150 let data = [0u8; 100];
1151 let inner = SliceReader::new(&data);
1152 let reader = BudgetedReader::new(inner, 64);
1153
1154 // 64 bytes budget / 8 bytes per u64 = 8 elements max
1155 assert_eq!(reader.max_alloc(8), 8);
1156
1157 // 64 bytes budget / 1 byte per u8 = 64 elements max
1158 assert_eq!(reader.max_alloc(1), 64);
1159
1160 // 64 bytes budget / 16 bytes per u128 = 4 elements max
1161 assert_eq!(reader.max_alloc(16), 4);
1162
1163 // ZSTs (0 bytes) return usize::MAX
1164 assert_eq!(reader.max_alloc(0), usize::MAX);
1165 }
1166
1167 #[test]
1168 fn unbounded_reader_max_alloc_returns_max() {
1169 let data = [0u8; 100];
1170 let reader = SliceReader::new(&data);
1171
1172 // Unbounded readers return usize::MAX
1173 assert_eq!(reader.max_alloc(1), usize::MAX);
1174 assert_eq!(reader.max_alloc(8), usize::MAX);
1175 }
1176
1177 #[test]
1178 fn slice_reader_rejects_overflowing_read_lengths() {
1179 let data = [1u8];
1180 let mut reader = SliceReader::new(&data);
1181
1182 assert_eq!(reader.read_u8().unwrap(), 1);
1183 assert_eq!(reader.read_slice(usize::MAX), Err(DeserializationError::UnexpectedEOF));
1184 assert_eq!(reader.check_eor(usize::MAX), Err(DeserializationError::UnexpectedEOF));
1185 }
1186
1187 // ============================================================================================
1188 // The following tests document the threat model and defense layers.
1189 // ============================================================================================
1190
1191 /// SliceReader alone does NOT reject fake length prefixes.
1192 ///
1193 /// A malicious input claiming 1000 elements will be accepted by read_many_iter
1194 /// because SliceReader.max_alloc() returns usize::MAX. The deserialization will
1195 /// eventually fail with UnexpectedEOF, but only after attempting to iterate
1196 /// (which could be slow for huge counts, though not OOM since we don't pre-allocate).
1197 #[test]
1198 fn slice_reader_accepts_fake_length_prefix() {
1199 let mut data = Vec::new();
1200 // Write length = 1000 (vint64 encoding: 0x07D0 << 2 | 0b10 = 0x1F42)
1201 // For simplicity, use the 9-byte form
1202 data.push(0); // 9-byte marker
1203 data.extend_from_slice(&1000u64.to_le_bytes());
1204 // Only 8 bytes of actual u64 data (1 element, not 1000)
1205 data.extend_from_slice(&42u64.to_le_bytes());
1206
1207 // read_many_iter passes the max_alloc check (usize::MAX >= 1000)
1208 let mut reader = SliceReader::new(&data);
1209 let _len = reader.read_usize().unwrap();
1210 let iter_result = reader.read_many_iter::<u64>(1000);
1211
1212 // The iterator is created successfully
1213 assert!(iter_result.is_ok());
1214
1215 // But collecting fails on the 2nd element (EOF)
1216 let collect_result: Result<Vec<u64>, _> = iter_result.unwrap().collect();
1217 assert!(collect_result.is_err());
1218 assert!(matches!(collect_result.unwrap_err(), DeserializationError::UnexpectedEOF));
1219 }
1220
1221 /// BudgetedReader rejects fake length prefixes BEFORE iteration begins.
1222 ///
1223 /// With a 64-byte budget, max_alloc(8) = 8, so a claim of 1000 elements
1224 /// is rejected immediately by read_many_iter.
1225 #[test]
1226 fn budgeted_reader_rejects_fake_length_upfront() {
1227 let mut data = Vec::new();
1228 data.push(0); // 9-byte vint64 marker
1229 data.extend_from_slice(&1000u64.to_le_bytes());
1230 data.extend_from_slice(&42u64.to_le_bytes());
1231
1232 let inner = SliceReader::new(&data);
1233 let mut reader = BudgetedReader::new(inner, 64);
1234
1235 let _len = reader.read_usize().unwrap(); // consumes 9 bytes, 55 remaining
1236 // 55 / 8 = 6 elements max
1237 let iter_result = reader.read_many_iter::<u64>(1000);
1238
1239 // Rejected immediately: 1000 > 6
1240 match iter_result {
1241 Err(DeserializationError::InvalidValue(_)) => {}, // expected
1242 other => panic!("expected InvalidValue error, got {:?}", other.map(|_| "Ok")),
1243 }
1244 }
1245
1246 /// Best practice: budget = input length provides both protections.
1247 ///
1248 /// 1. Fake length prefixes are bounded by max_alloc (remaining_bytes / element_size)
1249 /// 2. Total consumption is bounded by the budget
1250 #[test]
1251 fn budget_equals_input_length_is_safe() {
1252 // Valid input: 2 u64s
1253 let original = vec![100u64, 200];
1254 let mut data = Vec::new();
1255 crate::Serializable::write_into(&original, &mut data);
1256
1257 // Budget = exact input size
1258 let result = Vec::<u64>::read_from_bytes_with_budget(&data, data.len());
1259 assert_eq!(result.unwrap(), vec![100, 200]);
1260
1261 // Malicious input claiming 1000 elements (same serialized prefix manipulation)
1262 let mut evil_data = Vec::new();
1263 evil_data.push(0); // 9-byte vint64
1264 evil_data.extend_from_slice(&1000u64.to_le_bytes());
1265 evil_data.extend_from_slice(&42u64.to_le_bytes()); // only 1 actual element
1266
1267 // Budget = input length (17 bytes). After reading length (9 bytes), 8 remain.
1268 // max_alloc(8) = 8/8 = 1, so 1000 > 1 fails.
1269 let result = Vec::<u64>::read_from_bytes_with_budget(&evil_data, evil_data.len());
1270 assert!(result.is_err());
1271 }
1272
1273 // ============================================================================================
1274 // Tests documenting min_serialized_size()-based allocation bounds (defaults to size_of)
1275 // ============================================================================================
1276
1277 /// The max_alloc check uses D::min_serialized_size() to bound memory allocation.
1278 /// By default, min_serialized_size() returns size_of::<D>().
1279 ///
1280 /// For flat collections like Vec<u64>, this works well: we check that
1281 /// budget / min_serialized_size() >= requested_count before allocating.
1282 #[test]
1283 fn min_serialized_size_bounds_flat_collections() {
1284 let mut data = Vec::new();
1285 data.push(0); // 9-byte vint64 marker
1286 data.extend_from_slice(&1000u64.to_le_bytes()); // claim 1000 u64s
1287 data.extend_from_slice(&[0u8; 16]); // only 2 u64s of actual data
1288
1289 let inner = SliceReader::new(&data);
1290 // Budget of 80 bytes: after reading 9-byte length, 71 remain.
1291 // max_alloc(u64::min_serialized_size()) = 71 / 8 = 8 elements max
1292 let mut reader = BudgetedReader::new(inner, 80);
1293
1294 let _len = reader.read_usize().unwrap();
1295 let result = reader.read_many_iter::<u64>(1000);
1296
1297 // Rejected: 1000 > 8
1298 assert!(result.is_err());
1299 }
1300
1301 /// For nested collections like Vec<Vec<u64>>, min_serialized_size() returns 1 (the minimum
1302 /// vint length prefix), not size_of. This is more permissive but accurate: a
1303 /// serialized Vec can be as small as 1 byte (empty vec).
1304 ///
1305 /// The early-abort check uses this minimum, and budget enforcement during actual
1306 /// reads provides the real protection against malicious input.
1307 #[test]
1308 fn min_serialized_size_override_for_nested_collections() {
1309 // Vec<u64>::min_serialized_size() returns 1 (minimum vint prefix), not size_of
1310 assert_eq!(<Vec<u64>>::min_serialized_size(), 1);
1311
1312 let mut data = Vec::new();
1313 data.push(0); // 9-byte vint64 marker
1314 data.extend_from_slice(&100u64.to_le_bytes()); // claim 100 inner Vecs
1315 // Only provide enough data for 1 empty inner Vec
1316 data.push(0b10); // vint64 for 0 (empty inner vec)
1317
1318 let inner = SliceReader::new(&data);
1319 // With min_serialized_size() = 1, we need budget >= 100 to pass the early check.
1320 // After reading 9-byte length, 101 - 9 = 92 remaining, 92 / 1 = 92 < 100.
1321 // So with budget = 110, we get 110 - 9 = 101 remaining, 101 >= 100.
1322 let mut reader = BudgetedReader::new(inner, 110);
1323
1324 let _len = reader.read_usize().unwrap();
1325 let result = reader.read_many_iter::<Vec<u64>>(100);
1326
1327 // The early check passes (100 <= 101)
1328 assert!(result.is_ok());
1329
1330 // But deserialization fails when we try to read 100 inner Vecs with only 1
1331 let collect_result: Result<Vec<Vec<u64>>, _> = result.unwrap().collect();
1332 assert!(collect_result.is_err());
1333 }
1334
1335 /// Demonstrates that min_serialized_size() approach still provides security for nested
1336 /// collections, just with later detection. The budget is enforced during reads.
1337 #[test]
1338 fn nested_collections_still_protected_by_budget() {
1339 // With Vec::min_serialized_size() = 1, the early check is permissive.
1340 // Security comes from budget enforcement during actual reads.
1341 let mut data = Vec::new();
1342 data.push(0); // 9-byte vint64 marker
1343 data.extend_from_slice(&10u64.to_le_bytes()); // claim 10 inner Vecs
1344 // Each inner vec claims 1000 u64s but provides none
1345 for _ in 0..10 {
1346 data.push(0); // 9-byte vint64 marker
1347 data.extend_from_slice(&1000u64.to_le_bytes());
1348 }
1349
1350 let inner = SliceReader::new(&data);
1351 // Small budget: will run out during inner deserialization
1352 let mut reader = BudgetedReader::new(inner, 100);
1353
1354 // Outer length read succeeds (consumes 9 bytes, 91 remaining)
1355 let _len = reader.read_usize().unwrap();
1356
1357 // With Vec::min_serialized_size() = 1, early check passes: 91 / 1 = 91 >= 10
1358 let result = reader.read_many_iter::<Vec<u64>>(10);
1359 assert!(result.is_ok());
1360
1361 // But collecting fails because the inner vecs claim 1000 u64s each,
1362 // exhausting the budget during inner deserialization
1363 let collect_result: Result<Vec<Vec<u64>>, _> = result.unwrap().collect();
1364 assert!(collect_result.is_err());
1365 }
1366
1367 /// Tuples should use sum of element min_serialized_size, not size_of (which includes padding).
1368 ///
1369 /// This test verifies that (u8, u64) has min_serialized_size = 9 (1 + 8) not 16 (in-memory size
1370 /// with 7 bytes of alignment padding).
1371 #[test]
1372 fn tuple_min_serialized_size_excludes_padding() {
1373 // Serialized: 1 byte for u8 + 8 bytes for u64 = 9 bytes
1374 // In-memory: 8 bytes for u8 (with 7 bytes padding) + 8 bytes for u64 = 16 bytes
1375 assert_eq!(<(u8, u64)>::min_serialized_size(), 9);
1376 assert_eq!(size_of::<(u8, u64)>(), 16);
1377
1378 // Verify budget calculation uses 9, not 16
1379 let mut data = Vec::new();
1380 data.push(0); // 9-byte vint64 marker
1381 data.extend_from_slice(&4u64.to_le_bytes()); // claim 4 tuples
1382 // Provide exactly 4 tuples worth of data: 4 * 9 = 36 bytes
1383 for i in 0u8..4 {
1384 data.push(i); // u8
1385 data.extend_from_slice(&(i as u64).to_le_bytes()); // u64
1386 }
1387
1388 let inner = SliceReader::new(&data);
1389 // Budget: 9 (length prefix) + 36 (data) = 45 bytes
1390 let mut reader = BudgetedReader::new(inner, 45);
1391
1392 let _len = reader.read_usize().unwrap();
1393 // With min_serialized_size = 9: remaining = 45 - 9 = 36, max_elements = 36 / 9 = 4
1394 // This should succeed (4 <= 4)
1395 let result = reader.read_many_iter::<(u8, u64)>(4);
1396 assert!(result.is_ok());
1397
1398 // With min_serialized_size = 16 (wrong): max_elements = 36 / 16 = 2
1399 // This would fail (4 > 2)
1400 let collect_result: Result<Vec<(u8, u64)>, _> = result.unwrap().collect();
1401 assert!(collect_result.is_ok());
1402 assert_eq!(collect_result.unwrap().len(), 4);
1403 }
1404}