Skip to main content

buffa/
message.rs

1//! The core [`Message`] trait and [`DecodeOptions`] builder.
2//!
3//! Every generated message type implements [`Message`], which provides
4//! encode/decode/merge methods and a two-pass serialization model
5//! (`compute_size` → `write_to`) that avoids the exponential-time
6//! problem affecting naïve length-delimited encoders.
7
8use bytes::{Buf, BufMut};
9
10use crate::error::DecodeError;
11use crate::message_field::DefaultInstance;
12
/// Default recursion depth limit for decoding nested messages.
///
/// Protobuf implementations are required to enforce a recursion limit to
/// prevent stack overflow from deeply nested messages in untrusted input.
/// This value (100) matches the limit used by the official protobuf
/// implementations and the protobuf conformance suite.
///
/// Pass this constant as the `depth` argument when calling [`Message::merge`]
/// at a top-level decode site. The provided convenience methods
/// ([`Message::decode`], [`Message::decode_from_slice`],
/// [`Message::merge_from_slice`], [`Message::decode_length_delimited`]) use
/// this limit automatically. [`DecodeOptions::new`] also starts from this
/// value unless it is overridden with [`DecodeOptions::with_recursion_limit`].
pub const RECURSION_LIMIT: u32 = 100;
25
26/// The core trait implemented by all protobuf message types.
27///
28/// This trait is implemented by **generated code** — you write a `.proto` file,
29/// codegen emits the Rust struct and its `Message` impl. You should almost
30/// never implement this trait by hand.
31///
32/// # Manual implementation is discouraged
33///
34/// The only reason to implement `Message` yourself is when you need a
35/// custom in-memory representation that codegen cannot produce — for
36/// example, wrapping a `std::ops::Range<i64>` as a leaf message so the
37/// rest of your code uses the natural Rust type. If you just want a message
38/// type, **write a `.proto` file instead.**
39///
40/// Manual implementation is intentionally high-friction:
41/// - You must correctly implement the two-pass serialization contract
42///   (`compute_size` caches sizes before `write_to` uses them).
43/// - You must implement wire-format decoding in `merge_field`.
44/// - [`DefaultInstance`] is an `unsafe` supertrait; you must uphold its
45///   safety contract for the default-instance pointer.
46///
47/// If you still need to do this, see the [custom types section of the
48/// user guide](https://github.com/anthropics/buffa/blob/main/docs/guide.md#custom-type-implementations)
49/// for a complete worked example, and [`DefaultInstance`] for the safety
50/// contract.
51///
52/// # Serialization model
53///
54/// Serialization is a two-pass process to avoid the exponential-time problem
55/// that affects prost with deeply nested messages:
56///
57/// 1. **`compute_size()`** — walks the message tree and caches the encoded
58///    size of every sub-message in its `CachedSize` field.
59/// 2. **`write_to()`** — walks the tree again, writing bytes and using the
60///    cached sizes for length-prefixed sub-messages.
61///
62/// The convenience method `encode()` performs both passes. If you need to
63/// serialize the same message multiple times without mutation in between,
64/// you can call `compute_size()` once and then `write_to()` repeatedly.
65///
66/// # Thread safety
67///
68/// `Message` requires `Send + Sync`. The `CachedSize` field uses `AtomicU32`
69/// with `Relaxed` ordering, so messages can be placed in an `Arc` and shared
70/// across threads. Serialization (`compute_size` → `write_to`) must still be
71/// sequenced on a single thread per message — `merge` requires `&mut self`,
72/// and `compute_size`/`write_to` must be called in order without interleaving
73/// from another thread to produce a valid encoding.
74pub trait Message: DefaultInstance + Clone + PartialEq + Send + Sync {
75    /// Compute and cache the encoded byte size of this message.
76    ///
77    /// This recursively computes sizes for all sub-messages and stores them
78    /// in each message's `CachedSize` field. Must be called before `write_to()`.
79    ///
80    /// # Size limit
81    ///
82    /// The protobuf specification limits messages to 2 GiB. The return type
83    /// is `u32`, so messages whose encoded size exceeds `u32::MAX` (4 GiB)
84    /// will produce a wrapped (undefined) size and a truncated encoding.
85    /// Stay well within the 2 GiB spec limit.
86    fn compute_size(&self) -> u32;
87
88    /// Write this message's encoded bytes to a buffer.
89    ///
90    /// Assumes `compute_size()` has already been called. Uses cached sizes
91    /// for length-delimited sub-message headers.
92    fn write_to(&self, buf: &mut impl BufMut);
93
94    /// Convenience: compute size, then write. This is the primary encoding API.
95    fn encode(&self, buf: &mut impl BufMut) {
96        self.compute_size();
97        self.write_to(buf);
98    }
99
100    /// Encode this message as a length-delimited byte sequence.
101    fn encode_length_delimited(&self, buf: &mut impl BufMut) {
102        let len = self.compute_size();
103        crate::encoding::encode_varint(len as u64, buf);
104        self.write_to(buf);
105    }
106
107    /// Encode this message to a new `Vec<u8>`.
108    fn encode_to_vec(&self) -> alloc::vec::Vec<u8> {
109        let size = self.compute_size() as usize;
110        let mut buf = alloc::vec::Vec::with_capacity(size);
111        self.write_to(&mut buf);
112        buf
113    }
114
115    /// Encode this message to a new [`bytes::Bytes`].
116    ///
117    /// Useful when handing off to networking code (hyper, tonic, axum)
118    /// that expects `Bytes` frame or body payloads. Works in `no_std`.
119    ///
120    /// This is equivalent to `Bytes::from(self.encode_to_vec())` — both
121    /// are zero-copy with respect to the encoded bytes — but saves readers
122    /// from having to know that `From<Vec<u8>> for Bytes` is zero-copy.
123    fn encode_to_bytes(&self) -> bytes::Bytes {
124        let size = self.compute_size() as usize;
125        let mut buf = bytes::BytesMut::with_capacity(size);
126        self.write_to(&mut buf);
127        buf.freeze()
128    }
129
130    /// Decode a message from a buffer.
131    fn decode(buf: &mut impl Buf) -> Result<Self, DecodeError>
132    where
133        Self: Sized,
134    {
135        let mut msg = Self::default();
136        msg.merge(buf, RECURSION_LIMIT)?;
137        Ok(msg)
138    }
139
140    /// Decode a message from a byte slice.
141    ///
142    /// Convenience wrapper around [`decode`](Self::decode) that avoids the
143    /// `&mut bytes.as_slice()` incantation.
144    fn decode_from_slice(mut data: &[u8]) -> Result<Self, DecodeError>
145    where
146        Self: Sized,
147    {
148        // `mut data` creates a local mutable copy of the fat pointer so that
149        // `Buf::advance` can move the read cursor without affecting the caller.
150        Self::decode(&mut data)
151    }
152
153    /// Decode a length-delimited message from a buffer.
154    ///
155    /// This is a **top-level** entry point.  It reads a varint length prefix,
156    /// then decodes using arithmetic bounds checking, calling
157    /// [`merge_to_limit`](Self::merge_to_limit) with a fresh
158    /// [`RECURSION_LIMIT`] budget.  Any sub-messages inside are decoded via
159    /// [`merge_length_delimited`](Self::merge_length_delimited), which tracks
160    /// and decrements the budget.
161    ///
162    /// Do **not** call this method from within a
163    /// [`merge_to_limit`](Self::merge_to_limit) implementation to decode a
164    /// nested sub-message field; use
165    /// [`merge_length_delimited`](Self::merge_length_delimited) instead so
166    /// that the caller's depth budget is propagated correctly.
167    fn decode_length_delimited(buf: &mut impl Buf) -> Result<Self, DecodeError>
168    where
169        Self: Sized,
170    {
171        // Refuse messages larger than 2 GiB to prevent allocating attacker-
172        // controlled amounts of memory from a crafted length prefix.
173        const MAX_MESSAGE_BYTES: u64 = 0x7FFF_FFFF;
174        let len_u64 = crate::encoding::decode_varint(buf)?;
175        if len_u64 > MAX_MESSAGE_BYTES {
176            return Err(DecodeError::MessageTooLarge);
177        }
178        // Safe on 32-bit: len_u64 <= 2 GiB - 1 < u32::MAX, so the cast never truncates.
179        let len = usize::try_from(len_u64).map_err(|_| DecodeError::MessageTooLarge)?;
180        if buf.remaining() < len {
181            return Err(DecodeError::UnexpectedEof);
182        }
183        // Arithmetic limit: decode `len` bytes from the buffer without
184        // wrapping it in `Take`.  This keeps the buffer type `B` unchanged
185        // through every recursion level, avoiding E0275 for recursive
186        // message types like `google.protobuf.Struct ↔ Value`.
187        let limit = buf.remaining() - len;
188        let mut msg = Self::default();
189        msg.merge_to_limit(buf, RECURSION_LIMIT, limit)?;
190        if buf.remaining() != limit {
191            let remaining = buf.remaining();
192            if remaining > limit {
193                buf.advance(remaining - limit);
194            } else {
195                return Err(DecodeError::UnexpectedEof);
196            }
197        }
198        Ok(msg)
199    }
200
201    /// Processes a single already-decoded tag and its associated field data
202    /// from `buf`.
203    ///
204    /// This is the per-field dispatch method generated for each message type.
205    /// Both [`merge_to_limit`](Self::merge_to_limit) and
206    /// [`merge_group`](Self::merge_group) call this in their respective loops.
207    ///
208    /// `depth` is the remaining nesting budget.
209    ///
210    /// # Errors
211    ///
212    /// Returns a [`DecodeError`] if:
213    /// - the buffer is truncated or malformed,
214    /// - a wire-type mismatch is detected for a known field, or
215    /// - the recursion limit is exceeded.
216    fn merge_field(
217        &mut self,
218        tag: crate::encoding::Tag,
219        buf: &mut impl Buf,
220        depth: u32,
221    ) -> Result<(), DecodeError>;
222
223    /// Merge fields from a buffer until `buf.remaining()` reaches `limit`.
224    ///
225    /// This is the core decode loop.  [`merge`](Self::merge) delegates to this
226    /// with `limit = 0` (read until exhausted).
227    /// [`merge_length_delimited`](Self::merge_length_delimited) computes
228    /// `limit` from the declared sub-message length and calls this directly.
229    ///
230    /// The caller must ensure `limit <= buf.remaining()`.  The default
231    /// implementations of [`merge`](Self::merge) and
232    /// [`merge_length_delimited`](Self::merge_length_delimited) uphold this
233    /// invariant.
234    ///
235    /// `depth` is the remaining nesting budget.  Each call to
236    /// [`merge_length_delimited`](Self::merge_length_delimited) decrements it
237    /// by one before recursing; when it reaches zero the call returns
238    /// [`DecodeError::RecursionLimitExceeded`].
239    fn merge_to_limit(
240        &mut self,
241        buf: &mut impl Buf,
242        depth: u32,
243        limit: usize,
244    ) -> Result<(), DecodeError> {
245        while buf.remaining() > limit {
246            let tag = crate::encoding::Tag::decode(buf)?;
247            self.merge_field(tag, buf, depth)?;
248        }
249        Ok(())
250    }
251
252    /// Merges a group-encoded message from `buf`, reading fields until an
253    /// EndGroup tag with the given `field_number` is encountered.
254    ///
255    /// Proto2 groups use StartGroup/EndGroup wire types instead of
256    /// length-delimited encoding. The opening StartGroup tag has already been
257    /// consumed by the caller; this method reads the group body and the
258    /// closing EndGroup tag.
259    ///
260    /// # Errors
261    ///
262    /// Returns a [`DecodeError`] if:
263    /// - the buffer is truncated before the EndGroup tag,
264    /// - an EndGroup tag is encountered with a mismatched field number,
265    /// - a wire-type mismatch is detected for a known field, or
266    /// - the recursion limit is exceeded.
267    fn merge_group(
268        &mut self,
269        buf: &mut impl Buf,
270        depth: u32,
271        field_number: u32,
272    ) -> Result<(), DecodeError> {
273        let depth = depth
274            .checked_sub(1)
275            .ok_or(DecodeError::RecursionLimitExceeded)?;
276        loop {
277            if !buf.has_remaining() {
278                return Err(DecodeError::UnexpectedEof);
279            }
280            let tag = crate::encoding::Tag::decode(buf)?;
281            if tag.wire_type() == crate::encoding::WireType::EndGroup {
282                return if tag.field_number() == field_number {
283                    Ok(())
284                } else {
285                    Err(DecodeError::InvalidEndGroup(tag.field_number()))
286                };
287            }
288            self.merge_field(tag, buf, depth)?;
289        }
290    }
291
292    /// Merge fields from a buffer into this message.
293    ///
294    /// Fields that are already set will be overwritten for singular fields,
295    /// or appended for repeated fields, following standard protobuf merge
296    /// semantics.
297    ///
298    /// `depth` is the remaining nesting budget.  Each call to
299    /// [`merge_length_delimited`](Self::merge_length_delimited) decrements it
300    /// by one before recursing; when it reaches zero the call returns
301    /// [`DecodeError::RecursionLimitExceeded`].  Pass [`RECURSION_LIMIT`] at
302    /// the outermost call site, or use the convenience methods
303    /// ([`decode`](Self::decode), [`merge_from_slice`](Self::merge_from_slice))
304    /// which do this automatically.
305    fn merge(&mut self, buf: &mut impl Buf, depth: u32) -> Result<(), DecodeError> {
306        self.merge_to_limit(buf, depth, 0)
307    }
308
309    /// Merge fields from a byte slice into this message.
310    ///
311    /// Convenience wrapper around [`merge`](Self::merge) that avoids the
312    /// `&mut bytes.as_slice()` incantation.
313    fn merge_from_slice(&mut self, mut data: &[u8]) -> Result<(), DecodeError> {
314        self.merge(&mut data, RECURSION_LIMIT)
315    }
316
317    /// Merge fields from a length-delimited sub-message payload into this message.
318    ///
319    /// Reads a varint length prefix, then calls [`merge_to_limit`](Self::merge_to_limit)
320    /// with an arithmetic bound derived from the declared sub-message length.
321    /// The buffer type `B` passes through unchanged at every recursion level,
322    /// avoiding the `E0275` trait-solver recursion limit that occurs with
323    /// `Take<&mut Take<&mut T>>` type growth.
324    ///
325    /// Used by generated code when decoding singular `MessageField<T>` fields
326    /// — the sub-message is merged into the existing value rather than
327    /// replaced, per protobuf merge semantics.
328    ///
329    /// `depth` is the remaining nesting budget passed down from the enclosing
330    /// [`merge_to_limit`](Self::merge_to_limit) call.  This method decrements
331    /// it by one before calling the inner `merge_to_limit`; when it reaches
332    /// zero it returns [`DecodeError::RecursionLimitExceeded`].
333    ///
334    /// Enforces the same 2 GiB safety limit as [`decode_length_delimited`](Self::decode_length_delimited).
335    ///
336    /// # Errors
337    ///
338    /// Returns an error if the buffer is too short, if the declared length
339    /// exceeds 2 GiB, if the recursion limit is reached, or if the inner
340    /// `merge_to_limit` call fails.
341    fn merge_length_delimited(
342        &mut self,
343        buf: &mut impl Buf,
344        depth: u32,
345    ) -> Result<(), DecodeError> {
346        let depth = depth
347            .checked_sub(1)
348            .ok_or(DecodeError::RecursionLimitExceeded)?;
349        const MAX_SUB_MESSAGE_BYTES: u64 = 0x7FFF_FFFF;
350        let len_u64 = crate::encoding::decode_varint(buf)?;
351        if len_u64 > MAX_SUB_MESSAGE_BYTES {
352            return Err(DecodeError::MessageTooLarge);
353        }
354        let len = usize::try_from(len_u64).map_err(|_| DecodeError::MessageTooLarge)?;
355        if buf.remaining() < len {
356            return Err(DecodeError::UnexpectedEof);
357        }
358        // Arithmetic limit: the sub-message occupies `len` bytes, so the
359        // decode loop should stop when `buf.remaining()` drops to
360        // `remaining - len`.  This avoids wrapping the buffer in `Take`,
361        // which would grow the type at each recursion level and trigger
362        // E0275 for recursive message types like `Struct ↔ Value`.
363        let limit = buf.remaining() - len;
364        self.merge_to_limit(buf, depth, limit)?;
365        if buf.remaining() != limit {
366            let remaining = buf.remaining();
367            if remaining > limit {
368                // Sub-message consumed fewer bytes than declared; skip the rest.
369                buf.advance(remaining - limit);
370            } else {
371                return Err(DecodeError::UnexpectedEof);
372            }
373        }
374        Ok(())
375    }
376
377    /// The cached encoded size from the last `compute_size()` call.
378    ///
379    /// Returns 0 if `compute_size()` has never been called.
380    fn cached_size(&self) -> u32;
381
382    /// Clear all fields to their default values.
383    fn clear(&mut self);
384}
385
/// Options for configuring message decoding behavior.
///
/// Use this to set custom recursion depth limits or maximum message sizes
/// when decoding from untrusted input.
///
/// Two option sets compare equal when both limits match, which makes it easy
/// to assert on configuration in tests.
///
/// # Examples
///
/// ```no_run
/// # use buffa::__doctest_fixtures::Person;
/// use buffa::DecodeOptions;
///
/// # fn example(bytes: &[u8]) -> Result<(), buffa::DecodeError> {
/// // Restrict recursion depth to 50 and message size to 1 MiB:
/// let msg: Person = DecodeOptions::new()
///     .with_recursion_limit(50)
///     .with_max_message_size(1024 * 1024)
///     .decode_from_slice(bytes)?;
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DecodeOptions {
    // Nesting budget passed as `depth` to the message's merge entry points.
    recursion_limit: u32,
    // Largest accepted input (or declared length-delimited payload), in bytes.
    max_message_size: usize,
}
411
/// Fallback size cap for decoded messages, in bytes: 2^31 - 1 (2 GiB - 1).
///
/// This equals `i32::MAX` and is the same bound that
/// `Message::merge_length_delimited` applies to every sub-message.
const DEFAULT_MAX_MESSAGE_SIZE: usize = (1 << 31) - 1;
415
416impl Default for DecodeOptions {
417    fn default() -> Self {
418        Self::new()
419    }
420}
421
422impl DecodeOptions {
423    /// Create new decode options with defaults.
424    ///
425    /// Defaults:
426    /// - `recursion_limit`: 100 (same as [`RECURSION_LIMIT`])
427    /// - `max_message_size`: 2 GiB - 1
428    pub fn new() -> Self {
429        Self {
430            recursion_limit: RECURSION_LIMIT,
431            max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
432        }
433    }
434
435    /// Set the maximum recursion depth for nested messages.
436    ///
437    /// Each nested sub-message consumes one level of depth budget. When
438    /// the budget reaches zero, decoding returns
439    /// [`DecodeError::RecursionLimitExceeded`].
440    ///
441    /// Default: 100.
442    #[must_use]
443    pub fn with_recursion_limit(mut self, limit: u32) -> Self {
444        self.recursion_limit = limit;
445        self
446    }
447
448    /// Set the maximum total message size in bytes.
449    ///
450    /// If the input buffer or length-delimited payload exceeds this size,
451    /// decoding returns [`DecodeError::MessageTooLarge`].
452    ///
453    /// This is checked at the top-level decode entry point. Individual
454    /// sub-messages are still bounded by the internal 2 GiB limit
455    /// regardless of this setting.
456    ///
457    /// Default: 2 GiB - 1 (0x7FFF_FFFF).
458    #[must_use]
459    pub fn with_max_message_size(mut self, max_bytes: usize) -> Self {
460        self.max_message_size = max_bytes;
461        self
462    }
463
464    /// Returns the configured recursion depth limit.
465    pub fn recursion_limit(&self) -> u32 {
466        self.recursion_limit
467    }
468
469    /// Returns the configured maximum message size in bytes.
470    pub fn max_message_size(&self) -> usize {
471        self.max_message_size
472    }
473
474    /// Decode a message from a buffer.
475    pub fn decode<M: Message>(&self, buf: &mut impl Buf) -> Result<M, DecodeError> {
476        if buf.remaining() > self.max_message_size {
477            return Err(DecodeError::MessageTooLarge);
478        }
479        let mut msg = M::default();
480        msg.merge(buf, self.recursion_limit)?;
481        Ok(msg)
482    }
483
484    /// Decode a message from a byte slice.
485    pub fn decode_from_slice<M: Message>(&self, data: &[u8]) -> Result<M, DecodeError> {
486        if data.len() > self.max_message_size {
487            return Err(DecodeError::MessageTooLarge);
488        }
489        let mut msg = M::default();
490        msg.merge(&mut &*data, self.recursion_limit)?;
491        Ok(msg)
492    }
493
494    /// Decode a length-delimited message from a buffer.
495    pub fn decode_length_delimited<M: Message>(
496        &self,
497        buf: &mut impl Buf,
498    ) -> Result<M, DecodeError> {
499        // Enforce the 2 GiB internal safety cap even if the user sets a
500        // larger max_message_size, to prevent allocating attacker-controlled
501        // amounts of memory from a crafted length prefix.
502        let max = core::cmp::min(
503            self.max_message_size as u64,
504            DEFAULT_MAX_MESSAGE_SIZE as u64,
505        );
506        let len_u64 = crate::encoding::decode_varint(buf)?;
507        if len_u64 > max {
508            return Err(DecodeError::MessageTooLarge);
509        }
510        let len = usize::try_from(len_u64).map_err(|_| DecodeError::MessageTooLarge)?;
511        if buf.remaining() < len {
512            return Err(DecodeError::UnexpectedEof);
513        }
514        let limit = buf.remaining() - len;
515        let mut msg = M::default();
516        msg.merge_to_limit(buf, self.recursion_limit, limit)?;
517        if buf.remaining() != limit {
518            let remaining = buf.remaining();
519            if remaining > limit {
520                buf.advance(remaining - limit);
521            } else {
522                return Err(DecodeError::UnexpectedEof);
523            }
524        }
525        Ok(msg)
526    }
527
528    /// Merge fields from a buffer into an existing message.
529    pub fn merge<M: Message>(&self, msg: &mut M, buf: &mut impl Buf) -> Result<(), DecodeError> {
530        if buf.remaining() > self.max_message_size {
531            return Err(DecodeError::MessageTooLarge);
532        }
533        msg.merge(buf, self.recursion_limit)
534    }
535
536    /// Merge fields from a byte slice into an existing message.
537    pub fn merge_from_slice<M: Message>(
538        &self,
539        msg: &mut M,
540        data: &[u8],
541    ) -> Result<(), DecodeError> {
542        if data.len() > self.max_message_size {
543            return Err(DecodeError::MessageTooLarge);
544        }
545        msg.merge(&mut &*data, self.recursion_limit)
546    }
547
548    /// Decode a zero-copy view from a byte slice.
549    pub fn decode_view<'a, V: crate::view::MessageView<'a>>(
550        &self,
551        buf: &'a [u8],
552    ) -> Result<V, DecodeError> {
553        if buf.len() > self.max_message_size {
554            return Err(DecodeError::MessageTooLarge);
555        }
556        V::decode_view_with_limit(buf, self.recursion_limit)
557    }
558
559    /// Decode a message by reading all bytes from a [`std::io::Read`] source.
560    ///
561    /// Reads until EOF, enforces `max_message_size`, then decodes the
562    /// buffered bytes. Returns `std::io::Error` to be compatible with
563    /// `Read`-based error handling.
564    #[cfg(feature = "std")]
565    pub fn decode_reader<M: Message>(
566        &self,
567        reader: &mut impl std::io::Read,
568    ) -> Result<M, std::io::Error> {
569        let bytes = self.read_limited(reader)?;
570        self.decode_from_slice::<M>(&bytes)
571            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
572    }
573
574    /// Decode a length-delimited message from a [`std::io::Read`] source.
575    ///
576    /// Reads a varint length prefix, enforces `max_message_size`, reads
577    /// exactly that many bytes, then decodes. Useful for reading sequential
578    /// length-delimited messages from a file or stream.
579    #[cfg(feature = "std")]
580    pub fn decode_length_delimited_reader<M: Message>(
581        &self,
582        reader: &mut impl std::io::Read,
583    ) -> Result<M, std::io::Error> {
584        let len = read_varint(reader)?;
585        let max = core::cmp::min(
586            self.max_message_size as u64,
587            DEFAULT_MAX_MESSAGE_SIZE as u64,
588        );
589        if len > max {
590            return Err(std::io::Error::new(
591                std::io::ErrorKind::InvalidData,
592                DecodeError::MessageTooLarge,
593            ));
594        }
595        let len = len as usize;
596        let mut buf = alloc::vec![0u8; len];
597        reader.read_exact(&mut buf)?;
598        self.decode_from_slice::<M>(&buf)
599            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
600    }
601
602    /// Read all bytes from a reader up to `max_message_size`.
603    #[cfg(feature = "std")]
604    fn read_limited(
605        &self,
606        reader: &mut impl std::io::Read,
607    ) -> Result<alloc::vec::Vec<u8>, std::io::Error> {
608        use std::io::Read as _;
609        let mut buf = alloc::vec::Vec::new();
610        reader
611            .take(self.max_message_size as u64 + 1)
612            .read_to_end(&mut buf)?;
613        if buf.len() > self.max_message_size {
614            return Err(std::io::Error::new(
615                std::io::ErrorKind::InvalidData,
616                DecodeError::MessageTooLarge,
617            ));
618        }
619        Ok(buf)
620    }
621}
622
/// Read a base-128 varint from a `std::io::Read` source, one byte at a time.
///
/// The first nine bytes each contribute seven payload bits, least-significant
/// group first, and the read stops at the first byte whose continuation bit
/// (`0x80`) is clear. A tenth byte can only supply bit 63 of the result. Any
/// value above `0x01` there (payload overflow, or a continuation bit that
/// would imply an eleventh byte) is rejected with an `InvalidData` error
/// carrying `DecodeError::VarintTooLong`. This mirrors the validation in
/// [`decode_varint_slow`](crate::encoding).
///
/// Errors from `read_exact` are returned unchanged. This includes hitting EOF
/// before the varint terminates.
#[cfg(feature = "std")]
fn read_varint(reader: &mut impl std::io::Read) -> Result<u64, std::io::Error> {
    let mut next_byte = || -> Result<u8, std::io::Error> {
        let mut scratch = [0u8; 1];
        reader.read_exact(&mut scratch)?;
        Ok(scratch[0])
    };

    let mut accumulated: u64 = 0;
    for group in 0..9u32 {
        let byte = next_byte()?;
        accumulated |= u64::from(byte & 0x7F) << (7 * group);
        if byte & 0x80 == 0 {
            return Ok(accumulated);
        }
    }

    // Tenth byte: only its lowest bit is representable (as bit 63).
    let last = next_byte()?;
    if last > 0x01 {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            DecodeError::VarintTooLong,
        ));
    }
    Ok(accumulated | (u64::from(last) << 63))
}
657
658#[cfg(test)]
659mod tests {
660    use super::*;
661    use crate::cached_size::CachedSize;
662    use crate::encoding::encode_varint;
663    use crate::error::DecodeError;
664    use crate::message_field::DefaultInstance;
665
666    // Minimal hand-written Message for testing merge_length_delimited.
667    // Includes `CachedSize` to mirror the shape of generated message structs.
668    #[derive(Clone, Debug, Default, PartialEq)]
669    struct FlatMsg {
670        value: i32,
671        __buffa_cached_size: CachedSize,
672    }
673
674    unsafe impl DefaultInstance for FlatMsg {
675        fn default_instance() -> &'static Self {
676            static INST: crate::__private::OnceBox<FlatMsg> = crate::__private::OnceBox::new();
677            INST.get_or_init(|| alloc::boxed::Box::new(FlatMsg::default()))
678        }
679    }
680
681    impl Message for FlatMsg {
682        fn compute_size(&self) -> u32 {
683            let size = if self.value != 0 {
684                1 + crate::types::int32_encoded_len(self.value) as u32
685            } else {
686                0
687            };
688            self.__buffa_cached_size.set(size);
689            size
690        }
691
692        fn write_to(&self, buf: &mut impl BufMut) {
693            if self.value != 0 {
694                crate::encoding::Tag::new(1, crate::encoding::WireType::Varint).encode(buf);
695                crate::types::encode_int32(self.value, buf);
696            }
697        }
698
699        fn merge_field(
700            &mut self,
701            tag: crate::encoding::Tag,
702            buf: &mut impl Buf,
703            _depth: u32,
704        ) -> Result<(), DecodeError> {
705            match tag.field_number() {
706                1 => {
707                    self.value = crate::types::decode_int32(buf)?;
708                }
709                _ => {
710                    crate::encoding::skip_field(tag, buf)?;
711                }
712            }
713            Ok(())
714        }
715
716        fn cached_size(&self) -> u32 {
717            self.__buffa_cached_size.get()
718        }
719
720        fn clear(&mut self) {
721            *self = Self::default();
722        }
723    }
724
725    fn wire_bytes(msg: &FlatMsg) -> alloc::vec::Vec<u8> {
726        let mut buf = alloc::vec::Vec::new();
727        msg.encode_length_delimited(&mut buf);
728        buf
729    }
730
731    #[test]
732    fn test_merge_length_delimited_basic() {
733        let src = FlatMsg {
734            value: 42,
735            __buffa_cached_size: CachedSize::default(),
736        };
737        let mut dst = FlatMsg::default();
738        dst.merge_length_delimited(&mut wire_bytes(&src).as_slice(), RECURSION_LIMIT)
739            .unwrap();
740        assert_eq!(dst.value, 42);
741    }
742
743    #[test]
744    fn test_merge_length_delimited_merges_into_existing() {
745        // Second merge overwrites (proto3 last-wins for scalar fields).
746        let mut dst = FlatMsg::default();
747        dst.merge_length_delimited(
748            &mut wire_bytes(&FlatMsg {
749                value: 1,
750                __buffa_cached_size: CachedSize::default(),
751            })
752            .as_slice(),
753            RECURSION_LIMIT,
754        )
755        .unwrap();
756        assert_eq!(dst.value, 1);
757        dst.merge_length_delimited(
758            &mut wire_bytes(&FlatMsg {
759                value: 2,
760                __buffa_cached_size: CachedSize::default(),
761            })
762            .as_slice(),
763            RECURSION_LIMIT,
764        )
765        .unwrap();
766        assert_eq!(dst.value, 2);
767    }
768
769    #[test]
770    fn test_merge_length_delimited_truncated() {
771        // Length prefix says 10 bytes but buffer contains only 2.
772        let mut buf = alloc::vec::Vec::new();
773        encode_varint(10, &mut buf);
774        buf.extend_from_slice(&[0x01, 0x01]);
775        let mut dst = FlatMsg::default();
776        assert_eq!(
777            dst.merge_length_delimited(&mut buf.as_slice(), RECURSION_LIMIT),
778            Err(DecodeError::UnexpectedEof)
779        );
780    }
781
782    #[test]
783    fn test_merge_length_delimited_oversized() {
784        // Length prefix exceeds the 2 GiB safety limit.
785        let mut buf = alloc::vec::Vec::new();
786        encode_varint(0x8000_0000u64, &mut buf); // 2 GiB + 1
787        let mut dst = FlatMsg::default();
788        assert_eq!(
789            dst.merge_length_delimited(&mut buf.as_slice(), RECURSION_LIMIT),
790            Err(DecodeError::MessageTooLarge)
791        );
792    }
793
794    #[test]
795    fn test_merge_length_delimited_recursion_limit() {
796        // depth=1 means merge_length_delimited will decrement to 0, then any
797        // nested call would return RecursionLimitExceeded.  Passing depth=1
798        // to merge_length_delimited itself is the boundary: it decrements to
799        // 0 and calls merge with depth=0, which for FlatMsg (a leaf) succeeds.
800        // Passing depth=0 directly must return RecursionLimitExceeded.
801        let src = FlatMsg {
802            value: 7,
803            __buffa_cached_size: CachedSize::default(),
804        };
805        let mut dst = FlatMsg::default();
806        assert_eq!(
807            dst.merge_length_delimited(&mut wire_bytes(&src).as_slice(), 0),
808            Err(DecodeError::RecursionLimitExceeded)
809        );
810        // depth=1 succeeds: exactly one level is consumed.
811        dst.merge_length_delimited(&mut wire_bytes(&src).as_slice(), 1)
812            .unwrap();
813        assert_eq!(dst.value, 7);
814    }
815
816    #[test]
817    fn test_decode_from_slice_basic() {
818        let src = FlatMsg {
819            value: 42,
820            __buffa_cached_size: CachedSize::default(),
821        };
822        let bytes = src.encode_to_vec();
823        let dst = FlatMsg::decode_from_slice(&bytes).unwrap();
824        assert_eq!(dst.value, 42);
825    }
826
827    #[test]
828    fn test_encode_to_bytes_matches_encode_to_vec() {
829        let src = FlatMsg {
830            value: 42,
831            __buffa_cached_size: CachedSize::default(),
832        };
833        let vec = src.encode_to_vec();
834        let bytes = src.encode_to_bytes();
835        assert_eq!(vec.as_slice(), bytes.as_ref());
836        // Round-trip through the Bytes variant.
837        let dst = FlatMsg::decode_from_slice(&bytes).unwrap();
838        assert_eq!(dst.value, 42);
839        // Empty-message case: zero-length buffer is well-defined.
840        assert!(FlatMsg::default().encode_to_bytes().is_empty());
841    }
842
843    #[test]
844    fn test_decode_from_slice_empty() {
845        let dst = FlatMsg::decode_from_slice(&[]).unwrap();
846        assert_eq!(dst.value, 0);
847    }
848
849    #[test]
850    fn test_decode_from_slice_invalid_returns_error() {
851        // A lone 0xFF byte is not a valid varint tag.
852        let result = FlatMsg::decode_from_slice(&[0xFF]);
853        assert!(result.is_err());
854    }
855
856    #[test]
857    fn test_merge_from_slice_basic() {
858        let src = FlatMsg {
859            value: 7,
860            __buffa_cached_size: CachedSize::default(),
861        };
862        let bytes = src.encode_to_vec();
863        let mut dst = FlatMsg::default();
864        dst.merge_from_slice(&bytes).unwrap();
865        assert_eq!(dst.value, 7);
866    }
867
868    #[test]
869    fn test_merge_from_slice_last_wins() {
870        let src1 = FlatMsg {
871            value: 1,
872            __buffa_cached_size: CachedSize::default(),
873        };
874        let src2 = FlatMsg {
875            value: 2,
876            __buffa_cached_size: CachedSize::default(),
877        };
878        let mut dst = FlatMsg::default();
879        dst.merge_from_slice(&src1.encode_to_vec()).unwrap();
880        dst.merge_from_slice(&src2.encode_to_vec()).unwrap();
881        // Proto3 last-wins semantics for scalar fields.
882        assert_eq!(dst.value, 2);
883    }
884
885    // ── DecodeOptions tests ──────────────────────────────────────────
886
887    #[test]
888    fn test_decode_options_default_works() {
889        let src = FlatMsg {
890            value: 99,
891            __buffa_cached_size: CachedSize::default(),
892        };
893        let bytes = src.encode_to_vec();
894        let msg: FlatMsg = DecodeOptions::new().decode_from_slice(&bytes).unwrap();
895        assert_eq!(msg.value, 99);
896    }
897
898    #[test]
899    fn test_decode_options_max_message_size_rejects() {
900        let src = FlatMsg {
901            value: 42,
902            __buffa_cached_size: CachedSize::default(),
903        };
904        let bytes = src.encode_to_vec();
905        // Set max size to 1 byte — smaller than the encoded message.
906        let result: Result<FlatMsg, _> = DecodeOptions::new()
907            .with_max_message_size(1)
908            .decode_from_slice(&bytes);
909        assert_eq!(result, Err(DecodeError::MessageTooLarge));
910    }
911
912    #[test]
913    fn test_decode_options_max_message_size_exact_boundary() {
914        let src = FlatMsg {
915            value: 42,
916            __buffa_cached_size: CachedSize::default(),
917        };
918        let bytes = src.encode_to_vec();
919        // Exact size should succeed.
920        let msg: FlatMsg = DecodeOptions::new()
921            .with_max_message_size(bytes.len())
922            .decode_from_slice(&bytes)
923            .unwrap();
924        assert_eq!(msg.value, 42);
925        // One byte less should fail.
926        let result: Result<FlatMsg, _> = DecodeOptions::new()
927            .with_max_message_size(bytes.len() - 1)
928            .decode_from_slice(&bytes);
929        assert_eq!(result, Err(DecodeError::MessageTooLarge));
930    }
931
932    #[test]
933    fn test_decode_options_custom_recursion_limit() {
934        // FlatMsg has no nested messages, so any recursion limit >= 0 works.
935        // Just verify the API compiles and runs.
936        let src = FlatMsg {
937            value: 7,
938            __buffa_cached_size: CachedSize::default(),
939        };
940        let bytes = src.encode_to_vec();
941        let msg: FlatMsg = DecodeOptions::new()
942            .with_recursion_limit(1)
943            .decode_from_slice(&bytes)
944            .unwrap();
945        assert_eq!(msg.value, 7);
946    }
947
948    #[test]
949    fn test_decode_options_merge() {
950        let src = FlatMsg {
951            value: 55,
952            __buffa_cached_size: CachedSize::default(),
953        };
954        let bytes = src.encode_to_vec();
955        let mut msg = FlatMsg::default();
956        DecodeOptions::new()
957            .merge_from_slice(&mut msg, &bytes)
958            .unwrap();
959        assert_eq!(msg.value, 55);
960    }
961
962    #[test]
963    fn test_decode_options_merge_rejects_oversize() {
964        let src = FlatMsg {
965            value: 55,
966            __buffa_cached_size: CachedSize::default(),
967        };
968        let bytes = src.encode_to_vec();
969        let mut msg = FlatMsg::default();
970        let result = DecodeOptions::new()
971            .with_max_message_size(1)
972            .merge_from_slice(&mut msg, &bytes);
973        assert_eq!(result, Err(DecodeError::MessageTooLarge));
974    }
975
976    #[test]
977    fn test_decode_options_length_delimited() {
978        let src = FlatMsg {
979            value: 42,
980            __buffa_cached_size: CachedSize::default(),
981        };
982        let mut ld_bytes = alloc::vec::Vec::new();
983        src.encode_length_delimited(&mut ld_bytes);
984        let msg: FlatMsg = DecodeOptions::new()
985            .decode_length_delimited(&mut ld_bytes.as_slice())
986            .unwrap();
987        assert_eq!(msg.value, 42);
988    }
989
990    #[test]
991    fn test_decode_options_length_delimited_rejects_oversize() {
992        let src = FlatMsg {
993            value: 42,
994            __buffa_cached_size: CachedSize::default(),
995        };
996        let mut ld_bytes = alloc::vec::Vec::new();
997        src.encode_length_delimited(&mut ld_bytes);
998        let result: Result<FlatMsg, _> = DecodeOptions::new()
999            .with_max_message_size(1)
1000            .decode_length_delimited(&mut ld_bytes.as_slice());
1001        assert_eq!(result, Err(DecodeError::MessageTooLarge));
1002    }
1003
1004    #[test]
1005    fn decode_options_getters_return_defaults() {
1006        let opts = DecodeOptions::new();
1007        assert_eq!(opts.recursion_limit(), RECURSION_LIMIT);
1008        assert_eq!(opts.max_message_size(), 0x7FFF_FFFF);
1009    }
1010
1011    #[test]
1012    fn decode_options_getters_return_custom_values() {
1013        let opts = DecodeOptions::new()
1014            .with_recursion_limit(42)
1015            .with_max_message_size(1024);
1016        assert_eq!(opts.recursion_limit(), 42);
1017        assert_eq!(opts.max_message_size(), 1024);
1018    }
1019
1020    #[test]
1021    fn test_decode_options_default_impl() {
1022        // DecodeOptions::default() ≡ DecodeOptions::new().
1023        let opts = DecodeOptions::default();
1024        assert_eq!(opts.recursion_limit(), RECURSION_LIMIT);
1025        assert_eq!(opts.max_message_size(), 0x7FFF_FFFF);
1026    }
1027
1028    #[test]
1029    fn test_decode_options_decode_buf() {
1030        // The Buf-taking decode() variant (vs decode_from_slice).
1031        let src = FlatMsg {
1032            value: 123,
1033            ..Default::default()
1034        };
1035        let bytes = src.encode_to_vec();
1036        let msg: FlatMsg = DecodeOptions::new().decode(&mut bytes.as_slice()).unwrap();
1037        assert_eq!(msg.value, 123);
1038        // Oversize check on the Buf variant.
1039        let result: Result<FlatMsg, _> = DecodeOptions::new()
1040            .with_max_message_size(1)
1041            .decode(&mut bytes.as_slice());
1042        assert_eq!(result, Err(DecodeError::MessageTooLarge));
1043    }
1044
1045    #[test]
1046    fn test_decode_options_merge_buf() {
1047        // The Buf-taking merge() variant (vs merge_from_slice).
1048        let src = FlatMsg {
1049            value: 77,
1050            ..Default::default()
1051        };
1052        let bytes = src.encode_to_vec();
1053        let mut msg = FlatMsg::default();
1054        DecodeOptions::new()
1055            .merge(&mut msg, &mut bytes.as_slice())
1056            .unwrap();
1057        assert_eq!(msg.value, 77);
1058        // Oversize check.
1059        let mut msg = FlatMsg::default();
1060        let result = DecodeOptions::new()
1061            .with_max_message_size(1)
1062            .merge(&mut msg, &mut bytes.as_slice());
1063        assert_eq!(result, Err(DecodeError::MessageTooLarge));
1064    }
1065
1066    // ── Message trait default methods ─────────────────────────────────
1067
1068    #[test]
1069    fn test_message_encode_trait_default() {
1070        // Message::encode(buf) ≡ compute_size() then write_to(buf).
1071        let src = FlatMsg {
1072            value: 42,
1073            ..Default::default()
1074        };
1075        let mut buf = alloc::vec::Vec::new();
1076        src.encode(&mut buf);
1077        assert_eq!(buf, src.encode_to_vec());
1078    }
1079
1080    #[test]
1081    fn test_message_decode_length_delimited_trait_default() {
1082        // The trait-level decode_length_delimited (distinct from
1083        // DecodeOptions::decode_length_delimited).
1084        let src = FlatMsg {
1085            value: 42,
1086            ..Default::default()
1087        };
1088        let mut ld = alloc::vec::Vec::new();
1089        src.encode_length_delimited(&mut ld);
1090        let got = FlatMsg::decode_length_delimited(&mut ld.as_slice()).unwrap();
1091        assert_eq!(got.value, 42);
1092    }
1093
1094    #[test]
1095    fn test_message_decode_length_delimited_oversize() {
1096        // Length prefix > 2 GiB → MessageTooLarge.
1097        let mut buf = alloc::vec::Vec::new();
1098        encode_varint(0x8000_0000u64, &mut buf);
1099        let result = FlatMsg::decode_length_delimited(&mut buf.as_slice());
1100        assert_eq!(result, Err(DecodeError::MessageTooLarge));
1101    }
1102
1103    #[test]
1104    fn test_message_decode_length_delimited_truncated() {
1105        // Length prefix says 10 bytes, buffer has 2.
1106        let mut buf = alloc::vec::Vec::new();
1107        encode_varint(10, &mut buf);
1108        buf.push(0x08);
1109        buf.push(0x01);
1110        let result = FlatMsg::decode_length_delimited(&mut buf.as_slice());
1111        assert_eq!(result, Err(DecodeError::UnexpectedEof));
1112    }
1113
1114    #[test]
1115    fn test_message_decode_length_delimited_with_trailing() {
1116        // Buffer has two back-to-back length-delimited messages.
1117        // decode_length_delimited should consume exactly the first one
1118        // and leave the buffer positioned at the second.
1119        let a = FlatMsg {
1120            value: 1,
1121            ..Default::default()
1122        };
1123        let b = FlatMsg {
1124            value: 2,
1125            ..Default::default()
1126        };
1127        let mut buf = alloc::vec::Vec::new();
1128        a.encode_length_delimited(&mut buf);
1129        b.encode_length_delimited(&mut buf);
1130
1131        let mut cur = buf.as_slice();
1132        let first = FlatMsg::decode_length_delimited(&mut cur).unwrap();
1133        assert_eq!(first.value, 1);
1134        let second = FlatMsg::decode_length_delimited(&mut cur).unwrap();
1135        assert_eq!(second.value, 2);
1136        assert!(cur.is_empty());
1137    }
1138
1139    // ── merge_group tests ─────────────────────────────────────────────
1140
1141    /// Build a group body for FlatMsg (field 1 = value) terminated by
1142    /// EndGroup with the given field number.
1143    fn group_bytes(value: i32, group_field_number: u32) -> alloc::vec::Vec<u8> {
1144        use crate::encoding::{Tag, WireType};
1145        let mut buf = alloc::vec::Vec::new();
1146        if value != 0 {
1147            Tag::new(1, WireType::Varint).encode(&mut buf);
1148            crate::types::encode_int32(value, &mut buf);
1149        }
1150        Tag::new(group_field_number, WireType::EndGroup).encode(&mut buf);
1151        buf
1152    }
1153
1154    #[test]
1155    fn test_merge_group_basic() {
1156        let data = group_bytes(42, 5);
1157        let mut dst = FlatMsg::default();
1158        dst.merge_group(&mut data.as_slice(), RECURSION_LIMIT, 5)
1159            .unwrap();
1160        assert_eq!(dst.value, 42);
1161    }
1162
1163    #[test]
1164    fn test_merge_group_empty() {
1165        // Group with no fields — just EndGroup.
1166        let data = group_bytes(0, 3);
1167        let mut dst = FlatMsg::default();
1168        dst.merge_group(&mut data.as_slice(), RECURSION_LIMIT, 3)
1169            .unwrap();
1170        assert_eq!(dst.value, 0);
1171    }
1172
1173    #[test]
1174    fn test_merge_group_merges_into_existing() {
1175        let data1 = group_bytes(1, 5);
1176        let data2 = group_bytes(2, 5);
1177        let mut dst = FlatMsg::default();
1178        dst.merge_group(&mut data1.as_slice(), RECURSION_LIMIT, 5)
1179            .unwrap();
1180        assert_eq!(dst.value, 1);
1181        dst.merge_group(&mut data2.as_slice(), RECURSION_LIMIT, 5)
1182            .unwrap();
1183        assert_eq!(dst.value, 2);
1184    }
1185
1186    #[test]
1187    fn test_merge_group_recursion_limit_zero() {
1188        // depth=0 should immediately fail with RecursionLimitExceeded
1189        // because merge_group decrements before entering the loop.
1190        let data = group_bytes(42, 5);
1191        let mut dst = FlatMsg::default();
1192        assert_eq!(
1193            dst.merge_group(&mut data.as_slice(), 0, 5),
1194            Err(DecodeError::RecursionLimitExceeded)
1195        );
1196    }
1197
1198    #[test]
1199    fn test_merge_group_recursion_limit_one_succeeds() {
1200        // depth=1 succeeds: merge_group decrements to 0, but FlatMsg's
1201        // merge_field doesn't recurse further.
1202        let data = group_bytes(7, 5);
1203        let mut dst = FlatMsg::default();
1204        dst.merge_group(&mut data.as_slice(), 1, 5).unwrap();
1205        assert_eq!(dst.value, 7);
1206    }
1207
1208    #[test]
1209    fn test_merge_group_mismatched_end() {
1210        // EndGroup with wrong field number.
1211        use crate::encoding::{Tag, WireType};
1212        let mut data = alloc::vec::Vec::new();
1213        Tag::new(99, WireType::EndGroup).encode(&mut data);
1214
1215        let mut dst = FlatMsg::default();
1216        assert_eq!(
1217            dst.merge_group(&mut data.as_slice(), RECURSION_LIMIT, 5),
1218            Err(DecodeError::InvalidEndGroup(99))
1219        );
1220    }
1221
1222    #[test]
1223    fn test_merge_group_truncated() {
1224        // Buffer ends without EndGroup tag.
1225        use crate::encoding::{Tag, WireType};
1226        let mut data = alloc::vec::Vec::new();
1227        Tag::new(1, WireType::Varint).encode(&mut data);
1228        crate::types::encode_int32(42, &mut data);
1229        // No EndGroup.
1230
1231        let mut dst = FlatMsg::default();
1232        assert_eq!(
1233            dst.merge_group(&mut data.as_slice(), RECURSION_LIMIT, 5),
1234            Err(DecodeError::UnexpectedEof)
1235        );
1236    }
1237
1238    #[test]
1239    fn test_merge_group_empty_buffer() {
1240        let mut dst = FlatMsg::default();
1241        assert_eq!(
1242            dst.merge_group(&mut [].as_slice(), RECURSION_LIMIT, 5),
1243            Err(DecodeError::UnexpectedEof)
1244        );
1245    }
1246
1247    #[test]
1248    fn test_merge_group_unknown_fields_skipped() {
1249        // Group body contains an unknown field (field 99) which FlatMsg
1250        // routes to skip_field; the known field (field 1) should still
1251        // be decoded.
1252        use crate::encoding::{Tag, WireType};
1253        let mut data = alloc::vec::Vec::new();
1254        // Unknown varint field 99 = 0
1255        Tag::new(99, WireType::Varint).encode(&mut data);
1256        crate::encoding::encode_varint(0, &mut data);
1257        // Known field 1 = 99
1258        Tag::new(1, WireType::Varint).encode(&mut data);
1259        crate::types::encode_int32(99, &mut data);
1260        // EndGroup(5)
1261        Tag::new(5, WireType::EndGroup).encode(&mut data);
1262
1263        let mut dst = FlatMsg::default();
1264        dst.merge_group(&mut data.as_slice(), RECURSION_LIMIT, 5)
1265            .unwrap();
1266        assert_eq!(dst.value, 99);
1267    }
1268
1269    #[test]
1270    fn test_merge_group_trailing_data_preserved() {
1271        // After the EndGroup tag, trailing data should remain in the buffer.
1272        let mut data = group_bytes(42, 5);
1273        data.extend_from_slice(&[0xDE, 0xAD]);
1274
1275        let mut cur = data.as_slice();
1276        let mut dst = FlatMsg::default();
1277        dst.merge_group(&mut cur, RECURSION_LIMIT, 5).unwrap();
1278        assert_eq!(dst.value, 42);
1279        assert_eq!(cur, &[0xDE, 0xAD]);
1280    }
1281
1282    // ── read_varint (std::io::Read) tests ──────────────────────────────
1283
1284    #[cfg(feature = "std")]
1285    mod read_varint_tests {
1286        use super::super::read_varint;
1287        use crate::encoding::encode_varint;
1288
1289        #[test]
1290        fn roundtrip_values() {
1291            let cases: &[u64] = &[0, 1, 127, 128, 300, 1 << 14, 1 << 35, 1 << 63, u64::MAX];
1292            for &v in cases {
1293                let mut buf = Vec::new();
1294                encode_varint(v, &mut buf);
1295                let got = read_varint(&mut buf.as_slice()).unwrap();
1296                assert_eq!(got, v, "roundtrip failed for {v}");
1297            }
1298        }
1299
1300        #[test]
1301        fn rejects_10th_byte_overflow() {
1302            // 9 continuation bytes + 10th byte with overflow bit (0x02).
1303            // Mirrors encoding::tests::test_varint_10th_byte_overflow_rejected.
1304            let mut bad: Vec<u8> = vec![0xFF; 9];
1305            bad.push(0x02);
1306            let err = read_varint(&mut bad.as_slice()).unwrap_err();
1307            assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
1308        }
1309
1310        #[test]
1311        fn rejects_11th_byte() {
1312            // 10 continuation bytes — implies an 11th byte is needed.
1313            let bad: &[u8] = &[0xFF; 10];
1314            let err = read_varint(&mut &bad[..]).unwrap_err();
1315            assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
1316        }
1317
1318        #[test]
1319        fn u64_max_roundtrips() {
1320            // u64::MAX requires 10 bytes with the 10th byte == 0x01.
1321            let mut buf = Vec::new();
1322            encode_varint(u64::MAX, &mut buf);
1323            assert_eq!(buf.len(), 10);
1324            assert_eq!(buf[9], 0x01);
1325            let got = read_varint(&mut buf.as_slice()).unwrap();
1326            assert_eq!(got, u64::MAX);
1327        }
1328
1329        #[test]
1330        fn eof_before_terminator_is_error() {
1331            // Single continuation byte with no follow-up.
1332            let bad: &[u8] = &[0x80];
1333            let err = read_varint(&mut &bad[..]).unwrap_err();
1334            assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
1335        }
1336
1337        #[test]
1338        fn empty_input_is_error() {
1339            let err = read_varint(&mut &[][..]).unwrap_err();
1340            assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
1341        }
1342    }
1343
1344    // ── DecodeOptions std::io::Read tests ─────────────────────────────
1345
1346    #[cfg(feature = "std")]
1347    mod reader_tests {
1348        use super::*;
1349
1350        #[test]
1351        fn decode_reader_basic() {
1352            let src = FlatMsg {
1353                value: 42,
1354                ..Default::default()
1355            };
1356            let bytes = src.encode_to_vec();
1357            let msg: FlatMsg = DecodeOptions::new()
1358                .decode_reader(&mut bytes.as_slice())
1359                .unwrap();
1360            assert_eq!(msg.value, 42);
1361        }
1362
1363        #[test]
1364        fn decode_reader_rejects_oversize() {
1365            let src = FlatMsg {
1366                value: 42,
1367                ..Default::default()
1368            };
1369            let bytes = src.encode_to_vec();
1370            let err = DecodeOptions::new()
1371                .with_max_message_size(1)
1372                .decode_reader::<FlatMsg>(&mut bytes.as_slice())
1373                .unwrap_err();
1374            assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
1375        }
1376
1377        #[test]
1378        fn decode_reader_exact_boundary() {
1379            // max_message_size == encoded length → success.
1380            let src = FlatMsg {
1381                value: 42,
1382                ..Default::default()
1383            };
1384            let bytes = src.encode_to_vec();
1385            let msg: FlatMsg = DecodeOptions::new()
1386                .with_max_message_size(bytes.len())
1387                .decode_reader(&mut bytes.as_slice())
1388                .unwrap();
1389            assert_eq!(msg.value, 42);
1390        }
1391
1392        #[test]
1393        fn decode_reader_propagates_read_error() {
1394            // A reader that errors immediately.
1395            struct ErrReader;
1396            impl std::io::Read for ErrReader {
1397                fn read(&mut self, _: &mut [u8]) -> std::io::Result<usize> {
1398                    Err(std::io::Error::new(std::io::ErrorKind::BrokenPipe, "gone"))
1399                }
1400            }
1401            let err = DecodeOptions::new()
1402                .decode_reader::<FlatMsg>(&mut ErrReader)
1403                .unwrap_err();
1404            assert_eq!(err.kind(), std::io::ErrorKind::BrokenPipe);
1405        }
1406
1407        #[test]
1408        fn decode_length_delimited_reader_basic() {
1409            let src = FlatMsg {
1410                value: 99,
1411                ..Default::default()
1412            };
1413            let mut ld = Vec::new();
1414            src.encode_length_delimited(&mut ld);
1415            let msg: FlatMsg = DecodeOptions::new()
1416                .decode_length_delimited_reader(&mut ld.as_slice())
1417                .unwrap();
1418            assert_eq!(msg.value, 99);
1419        }
1420
1421        #[test]
1422        fn decode_length_delimited_reader_rejects_oversize_prefix() {
1423            // Length prefix claims more bytes than max_message_size allows.
1424            let src = FlatMsg {
1425                value: 99,
1426                ..Default::default()
1427            };
1428            let mut ld = Vec::new();
1429            src.encode_length_delimited(&mut ld);
1430            let err = DecodeOptions::new()
1431                .with_max_message_size(1)
1432                .decode_length_delimited_reader::<FlatMsg>(&mut ld.as_slice())
1433                .unwrap_err();
1434            assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
1435        }
1436
1437        #[test]
1438        fn decode_length_delimited_reader_sequential() {
1439            // Two messages in a stream — typical log-file use case.
1440            let a = FlatMsg {
1441                value: 10,
1442                ..Default::default()
1443            };
1444            let b = FlatMsg {
1445                value: 20,
1446                ..Default::default()
1447            };
1448            let mut stream = Vec::new();
1449            a.encode_length_delimited(&mut stream);
1450            b.encode_length_delimited(&mut stream);
1451
1452            let mut cursor = std::io::Cursor::new(stream);
1453            let first: FlatMsg = DecodeOptions::new()
1454                .decode_length_delimited_reader(&mut cursor)
1455                .unwrap();
1456            assert_eq!(first.value, 10);
1457            let second: FlatMsg = DecodeOptions::new()
1458                .decode_length_delimited_reader(&mut cursor)
1459                .unwrap();
1460            assert_eq!(second.value, 20);
1461        }
1462
1463        #[test]
1464        fn decode_length_delimited_reader_truncated_body() {
1465            // Length prefix says N bytes but reader EOFs early.
1466            let mut buf = Vec::new();
1467            crate::encoding::encode_varint(100, &mut buf);
1468            buf.push(0x08);
1469            let err = DecodeOptions::new()
1470                .decode_length_delimited_reader::<FlatMsg>(&mut buf.as_slice())
1471                .unwrap_err();
1472            assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
1473        }
1474    }
1475}