Skip to main content

protowire_pb/
codec.rs

1// SPDX-License-Identifier: MIT
2// Copyright (c) 2026 TrendVidia, LLC.
3//! Schema-driven binary marshal/unmarshal via the [`Message`] trait.
4//!
5//! Each Rust message type implements `Message` to encode/decode its own
6//! fields one at a time over the wire primitives in [`crate::wire`].
7//! Helpers for nested-message blobs live here so impls stay short.
8//!
9//! Wire-format choices match proto3 semantics:
10//!
//! - `int32` / `int64`: plain varint; negative values are sign-extended to
//!   64 bits first, so they always take the full 10-byte varint form.
13//! - `sint32` / `sint64`: zigzag varint; more compact for negative values.
14//! - `uint32` / `uint64`: plain varint.
15//! - `bool`: varint (0 / 1).
16//! - `float`: fixed32; `double`: fixed64.
17//! - `string` and `bytes`: length-delimited.
18//! - nested messages: length-delimited.
19//! - repeated fields: one tag+value per element (non-packed).
20//! - maps: each entry is a length-delimited `MapEntry { key=1; value=2 }`.
21
22use crate::wire::{Error, Reader, Result, WireType, Writer, MAX_NESTING_DEPTH};
23
24/// A message with self-contained encode/decode. Mirrors the role of
25/// `prost::Message` for our trait-based codec.
26pub trait Message: Sized + Default {
27    /// Append the message's fields to `w`. Caller is responsible for
28    /// the surrounding length prefix when used as a nested message.
29    fn encode_to(&self, w: &mut Writer);
30
31    /// Merge a single field (already-decoded tag) into `self`.
32    /// Implementations should call `r.skip(wire_type)` for unknown numbers.
33    fn merge_field(
34        &mut self,
35        field_number: u32,
36        wire_type: WireType,
37        r: &mut Reader<'_>,
38    ) -> Result<()>;
39}
40
41/// Encode a message to a byte vector.
42pub fn marshal<M: Message>(value: &M) -> Vec<u8> {
43    let mut w = Writer::new();
44    value.encode_to(&mut w);
45    w.finish()
46}
47
48/// Decode a message from a byte slice.
49pub fn unmarshal<M: Message>(data: &[u8]) -> Result<M> {
50    let mut r = Reader::new(data);
51    let mut msg = M::default();
52    while !r.eof() {
53        let (num, wt) = r.tag()?;
54        msg.merge_field(num, wt, &mut r)?;
55    }
56    Ok(msg)
57}
58
59/// Write a nested message at `field_number` as a length-delimited blob.
60pub fn write_message<M: Message>(w: &mut Writer, field_number: u32, msg: &M) {
61    let mut inner = Writer::new();
62    msg.encode_to(&mut inner);
63    let bytes = inner.finish();
64    w.tag(field_number, WireType::LengthDelimited);
65    w.varint(bytes.len() as u64);
66    w.raw(&bytes);
67}
68
69/// Read a length-delimited nested message. The reader's tag is already consumed.
70///
71/// Increments `r.depth` for the duration of the inner decode and rejects with
72/// [`Error::DepthExceeded`] before recursing past [`MAX_NESTING_DEPTH`]. Per
73/// HARDENING.md §Recursion, the counter must persist across `merge_field` →
74/// `read_message` re-entry; that's why it lives on the `Reader`, not as a
75/// thread-local or function argument.
76///
77/// The length-prefix bounds check uses `checked_add` so that a maximum-value
78/// varint length (2^64 - 1) cannot wrap `pos + len` past a naive comparison
79/// and trip a slice-indexing panic — HARDENING.md §API contract item 3.
80pub fn read_message<M: Message>(r: &mut Reader<'_>) -> Result<M> {
81    let len = r.varint()?;
82    let len = usize::try_from(len).map_err(|_| Error::NestedExceedsBuffer)?;
83    let end = r.pos.checked_add(len).ok_or(Error::NestedExceedsBuffer)?;
84    if end > r.data().len() {
85        return Err(Error::NestedExceedsBuffer);
86    }
87    if r.depth >= MAX_NESTING_DEPTH {
88        return Err(Error::DepthExceeded(MAX_NESTING_DEPTH));
89    }
90    r.depth += 1;
91    let result = (|| -> Result<M> {
92        let mut msg = M::default();
93        while r.pos < end {
94            let (num, wt) = r.tag()?;
95            msg.merge_field(num, wt, r)?;
96        }
97        if r.pos != end {
98            return Err(Error::Overrun { pos: r.pos, end });
99        }
100        Ok(msg)
101    })();
102    r.depth -= 1;
103    result
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeMap;

    // --- Test message types ---
    //
    // Mirror `Inner` / `Outer` from the TS port's pb/codec.test.ts.

    #[derive(Debug, Default, Clone, PartialEq)]
    struct Inner {
        name: String,
        value: i32,
    }

    impl Message for Inner {
        fn encode_to(&self, w: &mut Writer) {
            if !self.name.is_empty() {
                w.tag(1, WireType::LengthDelimited);
                w.string(&self.name);
            }
            if self.value != 0 {
                w.tag(2, WireType::Varint);
                w.varint_i32(self.value);
            }
        }

        fn merge_field(&mut self, num: u32, wt: WireType, r: &mut Reader<'_>) -> Result<()> {
            match num {
                1 => self.name = r.string()?,
                2 => self.value = r.varint()? as i32,
                _ => r.skip(wt)?,
            }
            Ok(())
        }
    }

    #[derive(Debug, Default, Clone, PartialEq)]
    struct Outer {
        title: String,
        count: u32,
        score: f64,
        active: bool,
        data: Vec<u8>,
        items: Vec<Inner>,
        signed: i64,
        small_f: f32,
    }

    impl Message for Outer {
        fn encode_to(&self, w: &mut Writer) {
            if !self.title.is_empty() {
                w.tag(1, WireType::LengthDelimited);
                w.string(&self.title);
            }
            if self.count != 0 {
                w.tag(2, WireType::Varint);
                w.varint(self.count as u64);
            }
            if self.score != 0.0 {
                w.tag(3, WireType::Fixed64);
                w.double(self.score);
            }
            if self.active {
                w.tag(4, WireType::Varint);
                w.varint(1);
            }
            if !self.data.is_empty() {
                w.tag(5, WireType::LengthDelimited);
                w.bytes(&self.data);
            }
            for item in &self.items {
                write_message(w, 6, item);
            }
            if self.signed != 0 {
                w.tag(8, WireType::Varint);
                w.varint_i64(self.signed);
            }
            if self.small_f != 0.0 {
                w.tag(9, WireType::Fixed32);
                w.float(self.small_f);
            }
        }

        fn merge_field(&mut self, num: u32, wt: WireType, r: &mut Reader<'_>) -> Result<()> {
            match num {
                1 => self.title = r.string()?,
                2 => self.count = r.varint()? as u32,
                3 => self.score = r.double()?,
                4 => self.active = r.varint()? != 0,
                5 => self.data = r.bytes()?,
                6 => self.items.push(read_message(r)?),
                8 => self.signed = r.varint()? as i64,
                9 => self.small_f = r.float()?,
                _ => r.skip(wt)?,
            }
            Ok(())
        }
    }

    #[test]
    fn populated_message_round_trip() {
        let orig = Outer {
            title: "hello".into(),
            count: 42,
            score: 3.125,
            active: true,
            data: vec![0xde, 0xad],
            items: vec![
                Inner {
                    name: "a".into(),
                    value: 1,
                },
                Inner {
                    name: "b".into(),
                    value: -7,
                },
            ],
            signed: -12345,
            small_f: 2.5,
        };
        let bytes = marshal(&orig);
        let got: Outer = unmarshal(&bytes).unwrap();
        assert_eq!(got, orig);
    }

    #[test]
    fn all_zero_message_marshals_to_empty_bytes() {
        let bytes = marshal(&Outer::default());
        assert!(bytes.is_empty());
    }

    #[test]
    fn empty_bytes_unmarshal_to_default() {
        let got: Outer = unmarshal(&[]).unwrap();
        assert_eq!(got, Outer::default());
    }

    #[test]
    fn unknown_fields_are_skipped() {
        #[derive(Debug, Default, PartialEq)]
        struct Big {
            a: String,
            b: String,
            c: String,
        }
        impl Message for Big {
            fn encode_to(&self, w: &mut Writer) {
                if !self.a.is_empty() {
                    w.tag(1, WireType::LengthDelimited);
                    w.string(&self.a);
                }
                if !self.b.is_empty() {
                    w.tag(2, WireType::LengthDelimited);
                    w.string(&self.b);
                }
                if !self.c.is_empty() {
                    w.tag(3, WireType::LengthDelimited);
                    w.string(&self.c);
                }
            }
            fn merge_field(&mut self, num: u32, wt: WireType, r: &mut Reader<'_>) -> Result<()> {
                match num {
                    1 => self.a = r.string()?,
                    2 => self.b = r.string()?,
                    3 => self.c = r.string()?,
                    _ => r.skip(wt)?,
                }
                Ok(())
            }
        }
        #[derive(Debug, Default, PartialEq)]
        struct Small {
            a: String,
        }
        impl Message for Small {
            fn encode_to(&self, w: &mut Writer) {
                if !self.a.is_empty() {
                    w.tag(1, WireType::LengthDelimited);
                    w.string(&self.a);
                }
            }
            fn merge_field(&mut self, num: u32, wt: WireType, r: &mut Reader<'_>) -> Result<()> {
                match num {
                    1 => self.a = r.string()?,
                    _ => r.skip(wt)?,
                }
                Ok(())
            }
        }

        let bytes = marshal(&Big {
            a: "aa".into(),
            b: "bb".into(),
            c: "cc".into(),
        });
        let got: Small = unmarshal(&bytes).unwrap();
        assert_eq!(got.a, "aa");
    }

    #[derive(Debug, Default, Clone, PartialEq)]
    struct Wrap {
        inner: Option<Inner>,
    }

    impl Message for Wrap {
        fn encode_to(&self, w: &mut Writer) {
            if let Some(ref i) = self.inner {
                write_message(w, 1, i);
            }
        }
        fn merge_field(&mut self, num: u32, wt: WireType, r: &mut Reader<'_>) -> Result<()> {
            match num {
                1 => self.inner = Some(read_message(r)?),
                _ => r.skip(wt)?,
            }
            Ok(())
        }
    }

    #[test]
    fn singular_nested_none_omits_tag() {
        let bytes = marshal(&Wrap { inner: None });
        assert!(bytes.is_empty());
        let got: Wrap = unmarshal(&bytes).unwrap();
        assert!(got.inner.is_none());
    }

    #[test]
    fn singular_nested_populated_round_trips() {
        let bytes = marshal(&Wrap {
            inner: Some(Inner {
                name: "x".into(),
                value: 9,
            }),
        });
        let got: Wrap = unmarshal(&bytes).unwrap();
        assert_eq!(
            got.inner,
            Some(Inner {
                name: "x".into(),
                value: 9
            })
        );
    }

    #[test]
    fn singular_nested_empty_emits_zero_length_blob() {
        // Some(Inner::default()) should still emit tag(1, LengthDelim) + len 0.
        let bytes = marshal(&Wrap {
            inner: Some(Inner::default()),
        });
        assert_eq!(bytes, vec![0x0a, 0x00]);
        let got: Wrap = unmarshal(&bytes).unwrap();
        assert_eq!(got.inner, Some(Inner::default()));
    }

    #[test]
    fn nested_length_past_buffer_or_overflowing_is_rejected() {
        // Declared length 5, but only one byte follows.
        let short: Result<Wrap> = unmarshal(&[0x0a, 0x05, 0x08]);
        assert!(
            matches!(short, Err(Error::NestedExceedsBuffer)),
            "got {:?}",
            short
        );
        // Declared length 2^64 - 1: `pos + len` must not wrap.
        let huge = [0x0a, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x01];
        let res: Result<Wrap> = unmarshal(&huge);
        assert!(
            matches!(res, Err(Error::NestedExceedsBuffer)),
            "got {:?}",
            res
        );
    }

    // --- Map fields ---
    //
    // Map entries are framed by hand here, as generated code would do. The
    // decoding side must apply the same length checks as `read_message`.

    /// Read a map entry's length prefix and return the absolute offset where
    /// the entry ends.
    ///
    /// The checks match `read_message`: the length must fit in `usize`,
    /// `pos + len` must not overflow, and the entry must not run past the
    /// input. A bare `r.pos + len` would panic on overflow in debug builds
    /// when given an adversarial length such as 2^64 - 1.
    fn map_entry_end(r: &mut Reader<'_>) -> Result<usize> {
        let len = usize::try_from(r.varint()?).map_err(|_| Error::NestedExceedsBuffer)?;
        let end = r.pos.checked_add(len).ok_or(Error::NestedExceedsBuffer)?;
        if end > r.data().len() {
            return Err(Error::NestedExceedsBuffer);
        }
        Ok(end)
    }

    /// Reject a map entry whose fields did not end exactly at `end`.
    fn check_map_entry_end(r: &Reader<'_>, end: usize) -> Result<()> {
        if r.pos == end {
            Ok(())
        } else {
            Err(Error::Overrun { pos: r.pos, end })
        }
    }

    #[derive(Debug, Default, Clone, PartialEq)]
    struct WithStringMap {
        meta: BTreeMap<String, String>,
    }

    impl Message for WithStringMap {
        fn encode_to(&self, w: &mut Writer) {
            for (k, v) in &self.meta {
                let mut inner = Writer::new();
                if !k.is_empty() {
                    inner.tag(1, WireType::LengthDelimited);
                    inner.string(k);
                }
                if !v.is_empty() {
                    inner.tag(2, WireType::LengthDelimited);
                    inner.string(v);
                }
                let bytes = inner.finish();
                w.tag(1, WireType::LengthDelimited);
                w.varint(bytes.len() as u64);
                w.raw(&bytes);
            }
        }
        fn merge_field(&mut self, num: u32, wt: WireType, r: &mut Reader<'_>) -> Result<()> {
            match num {
                1 => {
                    let end = map_entry_end(r)?;
                    let mut k = String::new();
                    let mut v = String::new();
                    while r.pos < end {
                        let (n, entry_wt) = r.tag()?;
                        match n {
                            1 => k = r.string()?,
                            2 => v = r.string()?,
                            _ => r.skip(entry_wt)?,
                        }
                    }
                    check_map_entry_end(r, end)?;
                    self.meta.insert(k, v);
                }
                _ => r.skip(wt)?,
            }
            Ok(())
        }
    }

    #[test]
    fn map_string_string_round_trips() {
        let mut meta = BTreeMap::new();
        meta.insert("a".into(), "1".into());
        meta.insert("b".into(), "2".into());
        meta.insert("key with space".into(), "v".into());
        let bytes = marshal(&WithStringMap { meta: meta.clone() });
        let got: WithStringMap = unmarshal(&bytes).unwrap();
        assert_eq!(got.meta, meta);
    }

    #[test]
    fn map_string_string_empty_produces_empty_bytes() {
        let bytes = marshal(&WithStringMap::default());
        assert!(bytes.is_empty());
        let got: WithStringMap = unmarshal(&bytes).unwrap();
        assert!(got.meta.is_empty());
    }

    #[test]
    fn map_entry_length_overflow_is_rejected() {
        // Entry length 2^64 - 1 must be an error, not an arithmetic panic.
        let huge = [0x0a, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x01];
        let res: Result<WithStringMap> = unmarshal(&huge);
        assert!(
            matches!(res, Err(Error::NestedExceedsBuffer)),
            "got {:?}",
            res
        );
    }

    #[derive(Debug, Default, Clone, PartialEq)]
    struct WithIntMap {
        codes: BTreeMap<i32, String>,
    }

    impl Message for WithIntMap {
        fn encode_to(&self, w: &mut Writer) {
            for (k, v) in &self.codes {
                let mut inner = Writer::new();
                if *k != 0 {
                    inner.tag(1, WireType::Varint);
                    inner.varint_i32(*k);
                }
                if !v.is_empty() {
                    inner.tag(2, WireType::LengthDelimited);
                    inner.string(v);
                }
                let bytes = inner.finish();
                w.tag(1, WireType::LengthDelimited);
                w.varint(bytes.len() as u64);
                w.raw(&bytes);
            }
        }
        fn merge_field(&mut self, num: u32, wt: WireType, r: &mut Reader<'_>) -> Result<()> {
            match num {
                1 => {
                    let end = map_entry_end(r)?;
                    let mut k: i32 = 0;
                    let mut v = String::new();
                    while r.pos < end {
                        let (n, entry_wt) = r.tag()?;
                        match n {
                            1 => k = r.varint()? as i32,
                            2 => v = r.string()?,
                            _ => r.skip(entry_wt)?,
                        }
                    }
                    check_map_entry_end(r, end)?;
                    self.codes.insert(k, v);
                }
                _ => r.skip(wt)?,
            }
            Ok(())
        }
    }

    #[test]
    fn map_int32_string_round_trips() {
        let mut codes = BTreeMap::new();
        codes.insert(404, "Not Found".into());
        codes.insert(500, "Internal".into());
        let bytes = marshal(&WithIntMap {
            codes: codes.clone(),
        });
        let got: WithIntMap = unmarshal(&bytes).unwrap();
        assert_eq!(got.codes, codes);
    }

    // --- Cross-port wire-contract specifics ---
    //
    // The two TS-only schema-validation tests (duplicate field number,
    // repeated+map) don't apply to a trait-based codec — those are
    // compile-time invariants here. In their place, the two tests below
    // pin down the wire-format invariants the cross-port script depends on.

    #[derive(Debug, Default, Clone, PartialEq)]
    struct SignedI32 {
        v: i32,
    }
    impl Message for SignedI32 {
        fn encode_to(&self, w: &mut Writer) {
            if self.v != 0 {
                w.tag(1, WireType::Varint);
                w.varint_i32(self.v);
            }
        }
        fn merge_field(&mut self, num: u32, wt: WireType, r: &mut Reader<'_>) -> Result<()> {
            match num {
                1 => self.v = r.varint()? as i32,
                _ => r.skip(wt)?,
            }
            Ok(())
        }
    }

    #[test]
    fn proto3_int32_negative_sign_extends_to_10_byte_varint() {
        // Cross-port contract: -1 as proto3 int32 emits FF FF FF FF FF FF FF FF FF 01
        // (sign-extended uint64). Required for envelope parity with Go/C++/TS/Java.
        let bytes = marshal(&SignedI32 { v: -1 });
        assert_eq!(
            bytes,
            vec![0x08, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x01]
        );
        let got: SignedI32 = unmarshal(&bytes).unwrap();
        assert_eq!(got.v, -1);
    }

    #[derive(Debug, Default, Clone, PartialEq)]
    struct ZigzagI32 {
        v: i32,
    }
    impl Message for ZigzagI32 {
        fn encode_to(&self, w: &mut Writer) {
            if self.v != 0 {
                w.tag(1, WireType::Varint);
                w.zigzag32(self.v);
            }
        }
        fn merge_field(&mut self, num: u32, wt: WireType, r: &mut Reader<'_>) -> Result<()> {
            match num {
                1 => self.v = r.zigzag32()?,
                _ => r.skip(wt)?,
            }
            Ok(())
        }
    }

    #[test]
    fn sint32_zigzag_is_compact_for_negative_values() {
        // -1 as sint32 emits a single byte (0x01) instead of the 10-byte
        // sign-extended int32 form. This is the wire-format choice the
        // `zigzag` opt-in selects.
        let bytes = marshal(&ZigzagI32 { v: -1 });
        assert_eq!(bytes, vec![0x08, 0x01]);
        let got: ZigzagI32 = unmarshal(&bytes).unwrap();
        assert_eq!(got.v, -1);
    }

    // --- HARDENING.md §Recursion -----------------------------------------
    //
    // `read_message` must reject before recursing past MAX_NESTING_DEPTH and
    // must not crash on adversarial deep-nesting input. The cap matches the
    // cross-port HARDENING.md default of 100.

    /// Self-recursive PB message — mirrors `adversarial.v1.Tree` in the
    /// shared corpus. A tower of `Tree`s is the canonical adversarial
    /// fixture for depth-cap testing.
    #[derive(Default, Debug)]
    struct Tree {
        child: Option<Box<Tree>>,
    }

    impl Message for Tree {
        fn encode_to(&self, w: &mut Writer) {
            if let Some(c) = &self.child {
                write_message(w, 1, c.as_ref());
            }
        }
        fn merge_field(&mut self, num: u32, wt: WireType, r: &mut Reader<'_>) -> Result<()> {
            match num {
                1 => self.child = Some(Box::new(read_message(r)?)),
                _ => r.skip(wt)?,
            }
            Ok(())
        }
    }

    /// Build wire bytes for a Tree of `depth` nested `child` levels.
    ///
    /// Each level wraps the previous payload as `tag(1, LEN) varint(len)
    /// payload`. Prepending at every level would re-copy the whole payload
    /// each time, which is quadratic in `depth`: billions of byte copies for
    /// the 100k-deep case. Instead the buffer is built reversed. Each level
    /// appends its header bytes backwards, and one final `reverse()` restores
    /// wire order, keeping the build linear in the output size.
    fn build_tree_bytes(depth: usize) -> Vec<u8> {
        let mut reversed: Vec<u8> = Vec::new(); // empty leaf
        let mut header: Vec<u8> = Vec::with_capacity(11);
        for _ in 0..depth {
            header.clear();
            header.push(0x0a); // tag(1, LengthDelimited)
            let mut len = reversed.len() as u64;
            while len >= 0x80 {
                header.push(((len & 0x7f) as u8) | 0x80);
                len >>= 7;
            }
            header.push(len as u8);
            reversed.extend(header.iter().rev());
        }
        reversed.reverse();
        reversed
    }

    #[test]
    fn build_tree_bytes_frames_outermost_first() {
        assert!(build_tree_bytes(0).is_empty());
        assert_eq!(build_tree_bytes(1), vec![0x0a, 0x00]);
        assert_eq!(build_tree_bytes(2), vec![0x0a, 0x02, 0x0a, 0x00]);
    }

    #[test]
    fn deep_submessage_at_limit_is_accepted() {
        // 100 nested children → root + 100 read_message calls. Cap is the
        // increment count, so 100 levels of read_message reach depth=100
        // without exceeding it.
        let bytes = build_tree_bytes(100);
        let _: Tree = unmarshal(&bytes).unwrap();
    }

    #[test]
    fn deep_submessage_one_past_limit_returns_depth_exceeded() {
        // The 101st read_message call sees depth == 100 and must reject.
        let bytes = build_tree_bytes(101);
        let res: Result<Tree> = unmarshal(&bytes);
        assert!(
            matches!(res, Err(Error::DepthExceeded(100))),
            "got {:?}",
            res
        );
    }

    #[test]
    fn deep_submessage_past_limit_returns_depth_exceeded() {
        // 200 levels must reject cleanly, not crash.
        let bytes = build_tree_bytes(200);
        let res: Result<Tree> = unmarshal(&bytes);
        assert!(
            matches!(res, Err(Error::DepthExceeded(100))),
            "got {:?}",
            res
        );
    }

    #[test]
    fn deep_submessage_at_extreme_depth_rejects_without_stack_overflow() {
        // 100k-deep is the SIGABRT case from issue #1. The cap must trip
        // before native stack exhaustion.
        let bytes = build_tree_bytes(100_000);
        let res: Result<Tree> = unmarshal(&bytes);
        assert!(
            matches!(res, Err(Error::DepthExceeded(100))),
            "got {:?}",
            res
        );
    }
}