Skip to main content

protowire_pb/
wire.rs

// SPDX-License-Identifier: MIT
// Copyright (c) 2026 TrendVidia, LLC.
//! Low-level protobuf wire-format primitives: varint, zigzag, fixed32/64,
//! length-delimited bytes, and tag (field-number + wire-type) encoding.
//!
//! Mirrors `google.golang.org/protobuf/encoding/protowire` at the call sites
//! used by the schema-free `pb` codec, and the TS port's `wire.ts`.
8
9use thiserror::Error;
10
/// Maximum permitted nesting of protobuf submessages, groups, and map entries
/// (the `MaxNestingDepth` limit from HARDENING.md).
///
/// The limit is checked before descending into an inner message, which turns
/// a pathologically deep input (e.g. 100k levels) into a clean
/// `Err(DepthExceeded)` rather than a stack-overflow abort.
pub const MAX_NESTING_DEPTH: usize = 100;
16
17#[derive(Debug, Error)]
18pub enum Error {
19    #[error("truncated varint")]
20    TruncatedVarint,
21    #[error("varint exceeds 10 bytes")]
22    VarintTooLong,
23    #[error("truncated fixed32")]
24    TruncatedFixed32,
25    #[error("truncated fixed64")]
26    TruncatedFixed64,
27    #[error("truncated length-delimited")]
28    TruncatedLengthDelim,
29    #[error("invalid tag: field number 0 at offset {0}")]
30    InvalidTag(usize),
31    #[error("unknown wire type {0}")]
32    UnknownWireType(u8),
33    #[error("group wire types are not supported")]
34    GroupNotSupported,
35    #[error("invalid utf-8 in string field: {0}")]
36    InvalidUtf8(#[from] std::string::FromUtf8Error),
37    #[error("nested message exceeds buffer")]
38    NestedExceedsBuffer,
39    #[error("message overran (pos={pos}, end={end})")]
40    Overrun { pos: usize, end: usize },
41    #[error("nesting depth exceeds MaxNestingDepth ({0})")]
42    DepthExceeded(usize),
43}
44
45pub type Result<T> = std::result::Result<T, Error>;
46
/// Wire-type code stored in the low three bits of a field tag.
///
/// The discriminants are the on-wire numeric codes, so `wire_type as u8`
/// yields the value to encode.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum WireType {
    /// Payload is a single varint.
    Varint = 0,
    /// Payload is eight fixed bytes.
    Fixed64 = 1,
    /// Payload is a varint length followed by that many bytes.
    LengthDelimited = 2,
    /// Start-of-group marker.
    StartGroup = 3,
    /// End-of-group marker.
    EndGroup = 4,
    /// Payload is four fixed bytes.
    Fixed32 = 5,
}
57
58impl WireType {
59    pub fn from_u8(v: u8) -> Result<Self> {
60        match v {
61            0 => Ok(Self::Varint),
62            1 => Ok(Self::Fixed64),
63            2 => Ok(Self::LengthDelimited),
64            3 => Ok(Self::StartGroup),
65            4 => Ok(Self::EndGroup),
66            5 => Ok(Self::Fixed32),
67            _ => Err(Error::UnknownWireType(v)),
68        }
69    }
70}
71
/// Append-only encoder that accumulates protobuf wire bytes in memory.
///
/// The `Default` value holds an empty buffer.
#[derive(Debug, Default)]
pub struct Writer {
    /// Bytes emitted so far, in write order.
    buf: Vec<u8>,
}
76
77impl Writer {
78    pub fn new() -> Self {
79        Self::default()
80    }
81
82    pub fn with_capacity(cap: usize) -> Self {
83        Self {
84            buf: Vec::with_capacity(cap),
85        }
86    }
87
88    pub fn finish(self) -> Vec<u8> {
89        self.buf
90    }
91
92    pub fn len(&self) -> usize {
93        self.buf.len()
94    }
95
96    pub fn is_empty(&self) -> bool {
97        self.buf.is_empty()
98    }
99
100    /// Append raw bytes (no length prefix).
101    pub fn raw(&mut self, b: &[u8]) {
102        self.buf.extend_from_slice(b);
103    }
104
105    /// Write an unsigned varint.
106    pub fn varint(&mut self, mut v: u64) {
107        while v >= 0x80 {
108            self.buf.push(((v & 0x7f) as u8) | 0x80);
109            v >>= 7;
110        }
111        self.buf.push(v as u8);
112    }
113
114    /// Proto3 `int32`: plain varint, with negative values sign-extended to a
115    /// 10-byte uint64.
116    pub fn varint_i32(&mut self, v: i32) {
117        self.varint(v as i64 as u64);
118    }
119
120    /// Proto3 `int64`: plain varint, two's-complement uint64 form.
121    pub fn varint_i64(&mut self, v: i64) {
122        self.varint(v as u64);
123    }
124
125    /// Zigzag-encoded signed varint (proto3 `sint32`).
126    pub fn zigzag32(&mut self, v: i32) {
127        let u = ((v << 1) ^ (v >> 31)) as u32;
128        self.varint(u as u64);
129    }
130
131    /// Zigzag-encoded signed varint (proto3 `sint64`).
132    pub fn zigzag64(&mut self, v: i64) {
133        let u = ((v << 1) ^ (v >> 63)) as u64;
134        self.varint(u);
135    }
136
137    /// Little-endian fixed 32-bit unsigned integer.
138    pub fn fixed32(&mut self, v: u32) {
139        self.buf.extend_from_slice(&v.to_le_bytes());
140    }
141
142    /// Little-endian fixed 64-bit unsigned integer.
143    pub fn fixed64(&mut self, v: u64) {
144        self.buf.extend_from_slice(&v.to_le_bytes());
145    }
146
147    /// IEEE 754 32-bit float, little-endian.
148    pub fn float(&mut self, v: f32) {
149        self.buf.extend_from_slice(&v.to_le_bytes());
150    }
151
152    /// IEEE 754 64-bit double, little-endian.
153    pub fn double(&mut self, v: f64) {
154        self.buf.extend_from_slice(&v.to_le_bytes());
155    }
156
157    /// UTF-8 length-prefixed string.
158    pub fn string(&mut self, v: &str) {
159        self.bytes(v.as_bytes());
160    }
161
162    /// Length-prefixed byte sequence.
163    pub fn bytes(&mut self, v: &[u8]) {
164        self.varint(v.len() as u64);
165        self.raw(v);
166    }
167
168    /// Tag = (field_number << 3) | wire_type, encoded as a varint.
169    ///
170    /// Panics on out-of-range field numbers — that's a programmer error,
171    /// not a wire-format failure.
172    pub fn tag(&mut self, field_number: u32, wire_type: WireType) {
173        assert!(
174            (1..=0x1fff_ffff).contains(&field_number),
175            "field number out of range: {field_number}"
176        );
177        self.varint(((field_number as u64) << 3) | (wire_type as u64));
178    }
179}
180
/// Cursor over a borrowed byte buffer, used to decode protobuf wire data.
///
/// The reader never copies or modifies `data`; decoding methods advance
/// `pos` as they consume bytes.
#[derive(Debug)]
pub struct Reader<'a> {
    /// The complete input buffer.
    pub(crate) data: &'a [u8],
    /// Offset into `data` of the next unread byte.
    pub pos: usize,
    /// Live recursion depth, incremented by `read_message` when entering a
    /// nested submessage and decremented on exit. The depth survives across
    /// `merge_field` calls so a `Message` impl that hands the same `Reader`
    /// to a fresh `read_message` cannot reset it to zero.
    pub(crate) depth: usize,
}
190
191impl<'a> Reader<'a> {
192    pub fn new(data: &'a [u8]) -> Self {
193        Self {
194            data,
195            pos: 0,
196            depth: 0,
197        }
198    }
199
200    pub fn data(&self) -> &'a [u8] {
201        self.data
202    }
203
204    pub fn eof(&self) -> bool {
205        self.pos >= self.data.len()
206    }
207
208    pub fn remaining(&self) -> usize {
209        self.data.len().saturating_sub(self.pos)
210    }
211
212    /// Read an unsigned varint, up to 10 bytes (uint64 range).
213    pub fn varint(&mut self) -> Result<u64> {
214        let mut result: u64 = 0;
215        let mut shift = 0u32;
216        for i in 0..10 {
217            if self.pos >= self.data.len() {
218                return Err(Error::TruncatedVarint);
219            }
220            let byte = self.data[self.pos];
221            self.pos += 1;
222            result |= ((byte & 0x7f) as u64) << shift;
223            if byte & 0x80 == 0 {
224                return Ok(result);
225            }
226            shift += 7;
227            if i == 9 {
228                return Err(Error::VarintTooLong);
229            }
230        }
231        Err(Error::VarintTooLong)
232    }
233
234    /// Decode a zigzag varint as a 32-bit signed integer.
235    pub fn zigzag32(&mut self) -> Result<i32> {
236        let u = self.varint()? as u32;
237        Ok(((u >> 1) as i32) ^ -((u & 1) as i32))
238    }
239
240    /// Decode a zigzag varint as a 64-bit signed integer.
241    pub fn zigzag64(&mut self) -> Result<i64> {
242        let u = self.varint()?;
243        Ok(((u >> 1) as i64) ^ -((u & 1) as i64))
244    }
245
246    pub fn fixed32(&mut self) -> Result<u32> {
247        if self.pos + 4 > self.data.len() {
248            return Err(Error::TruncatedFixed32);
249        }
250        let v = u32::from_le_bytes(self.data[self.pos..self.pos + 4].try_into().unwrap());
251        self.pos += 4;
252        Ok(v)
253    }
254
255    pub fn fixed64(&mut self) -> Result<u64> {
256        if self.pos + 8 > self.data.len() {
257            return Err(Error::TruncatedFixed64);
258        }
259        let v = u64::from_le_bytes(self.data[self.pos..self.pos + 8].try_into().unwrap());
260        self.pos += 8;
261        Ok(v)
262    }
263
264    pub fn float(&mut self) -> Result<f32> {
265        Ok(f32::from_bits(self.fixed32()?))
266    }
267
268    pub fn double(&mut self) -> Result<f64> {
269        Ok(f64::from_bits(self.fixed64()?))
270    }
271
272    /// Length-prefixed bytes; returns a copy.
273    pub fn bytes(&mut self) -> Result<Vec<u8>> {
274        Ok(self.bytes_view()?.to_vec())
275    }
276
277    /// Length-prefixed bytes; returns a borrow into the underlying buffer.
278    ///
279    /// Guards against attacker-supplied length-prefix overflow per
280    /// HARDENING.md §API contract item 3: a 10-byte varint of `2^64-1`
281    /// would wrap `pos + len` to a small value and slip past a naive
282    /// bounds check, then trip a slice-indexing panic. Compute the end
283    /// offset with `checked_add` and reject before slicing.
284    pub fn bytes_view(&mut self) -> Result<&'a [u8]> {
285        let len = self.read_length()?;
286        let end = self.pos + len;
287        let view = &self.data[self.pos..end];
288        self.pos = end;
289        Ok(view)
290    }
291
292    /// Read a varint length and validate it fits in the remaining buffer.
293    /// Returns the length as `usize` ready for slicing. Used by every
294    /// length-delimited consumer (`bytes_view`, `skip`, `read_message`)
295    /// so the overflow guard exists in exactly one place.
296    fn read_length(&mut self) -> Result<usize> {
297        let len = self.varint()?;
298        let len = usize::try_from(len).map_err(|_| Error::TruncatedLengthDelim)?;
299        let end = self
300            .pos
301            .checked_add(len)
302            .ok_or(Error::TruncatedLengthDelim)?;
303        if end > self.data.len() {
304            return Err(Error::TruncatedLengthDelim);
305        }
306        Ok(len)
307    }
308
309    /// UTF-8 length-prefixed string.
310    pub fn string(&mut self) -> Result<String> {
311        let bytes = self.bytes_view()?.to_vec();
312        Ok(String::from_utf8(bytes)?)
313    }
314
315    /// Decode a tag varint into (field_number, wire_type).
316    pub fn tag(&mut self) -> Result<(u32, WireType)> {
317        let t = self.varint()?;
318        let wire_type = WireType::from_u8((t & 0x7) as u8)?;
319        let field_number = (t >> 3) as u32;
320        if field_number == 0 {
321            return Err(Error::InvalidTag(self.pos));
322        }
323        Ok((field_number, wire_type))
324    }
325
326    /// Skip the value of a field with the given wire type.
327    pub fn skip(&mut self, wire_type: WireType) -> Result<()> {
328        match wire_type {
329            WireType::Varint => {
330                self.varint()?;
331            }
332            WireType::Fixed64 => {
333                if self.pos + 8 > self.data.len() {
334                    return Err(Error::TruncatedFixed64);
335                }
336                self.pos += 8;
337            }
338            WireType::LengthDelimited => {
339                let len = self.read_length()?;
340                self.pos += len;
341            }
342            WireType::Fixed32 => {
343                if self.pos + 4 > self.data.len() {
344                    return Err(Error::TruncatedFixed32);
345                }
346                self.pos += 4;
347            }
348            WireType::StartGroup | WireType::EndGroup => {
349                return Err(Error::GroupNotSupported);
350            }
351        }
352        Ok(())
353    }
354}
355
#[cfg(test)]
mod tests {
    use super::*;

    /// `u64::MAX` as a varint: nine `0xff` bytes followed by `0x01`.
    const U64_MAX_VARINT: [u8; 10] = [0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x01];

    /// Run `fill` against a fresh `Writer` and return the bytes it produced.
    fn encode(fill: impl FnOnce(&mut Writer)) -> Vec<u8> {
        let mut w = Writer::new();
        fill(&mut w);
        w.finish()
    }

    /// `tag(1, LengthDelimited)` followed by a `u64::MAX` length prefix.
    fn hostile_length_field() -> Vec<u8> {
        let mut bytes = vec![0x0a];
        bytes.extend_from_slice(&U64_MAX_VARINT);
        bytes
    }

    /// Encode `v` as a varint, decode it again, and check that the decoder
    /// consumed every byte.
    fn round_trip_varint(v: u64) -> u64 {
        let bytes = encode(|w| w.varint(v));
        let mut r = Reader::new(&bytes);
        let out = r.varint().unwrap();
        assert!(r.eof());
        out
    }

    #[test]
    fn varint_encodes_zero_as_single_byte() {
        assert_eq!(encode(|w| w.varint(0)), vec![0]);
    }

    #[test]
    fn varint_round_trips_small_numbers() {
        for &v in &[0u64, 1, 127, 128, 255, 256, 16383, 16384] {
            assert_eq!(round_trip_varint(v), v);
        }
    }

    #[test]
    fn varint_round_trips_up_to_i64_max() {
        let v = i64::MAX as u64;
        assert_eq!(round_trip_varint(v), v);
    }

    #[test]
    fn varint_round_trips_full_uint64_range() {
        for &v in &[0u64, 1, 0x80, 0xff, 0xffff, 0xffff_ffff, u64::MAX] {
            assert_eq!(round_trip_varint(v), v);
        }
    }

    #[test]
    fn varint_encodes_150_as_canonical_proto_example() {
        assert_eq!(encode(|w| w.varint(150)), vec![0x96, 0x01]);
    }

    #[test]
    fn zigzag32_matches_proto3_spec() {
        // (signed input, expected zigzag value on the wire)
        let cases: [(i32, u32); 6] = [
            (0, 0),
            (-1, 1),
            (1, 2),
            (-2, 3),
            (2147483647, 4294967294),
            (-2147483648, 4294967295),
        ];
        for &(signed, encoded) in cases.iter() {
            let bytes = encode(|w| w.zigzag32(signed));

            // The raw varint must carry the zigzag-mapped value...
            let mut as_varint = Reader::new(&bytes);
            assert_eq!(as_varint.varint().unwrap() as u32, encoded);

            // ...and decoding it as zigzag must restore the input.
            let mut as_zigzag = Reader::new(&bytes);
            assert_eq!(as_zigzag.zigzag32().unwrap(), signed);
        }
    }

    #[test]
    fn zigzag64_round_trips_boundary_values() {
        for &v in &[0i64, -1, 1, -2, i64::MAX, i64::MIN] {
            let bytes = encode(|w| w.zigzag64(v));
            let mut r = Reader::new(&bytes);
            assert_eq!(r.zigzag64().unwrap(), v);
        }
    }

    #[test]
    fn fixed32_round_trips() {
        for &v in &[0u32, 1, 0x7fff_ffff, 0xffff_ffff] {
            let bytes = encode(|w| w.fixed32(v));
            let mut r = Reader::new(&bytes);
            assert_eq!(r.fixed32().unwrap(), v);
        }
    }

    #[test]
    fn fixed64_round_trips_uint64() {
        for &v in &[0u64, 1, 0xffff_ffff, u64::MAX] {
            let bytes = encode(|w| w.fixed64(v));
            let mut r = Reader::new(&bytes);
            assert_eq!(r.fixed64().unwrap(), v);
        }
    }

    #[test]
    fn float_and_double_round_trip() {
        let bytes = encode(|w| {
            w.float(2.5);
            w.double(std::f64::consts::PI);
        });
        let mut r = Reader::new(&bytes);
        assert!((r.float().unwrap() - 2.5).abs() < 1e-5);
        assert_eq!(r.double().unwrap(), std::f64::consts::PI);
    }

    #[test]
    fn utf8_strings_round_trip() {
        let text = "héllo, 世界";
        let bytes = encode(|w| w.string(text));
        let mut r = Reader::new(&bytes);
        assert_eq!(r.string().unwrap(), "héllo, 世界");
    }

    #[test]
    fn bytes_round_trip() {
        let bytes = encode(|w| w.bytes(&[0xde, 0xad, 0xbe, 0xef]));
        let mut r = Reader::new(&bytes);
        assert_eq!(r.bytes().unwrap(), vec![0xde, 0xad, 0xbe, 0xef]);
    }

    #[test]
    fn tag_for_field_1_varint_is_0x08() {
        assert_eq!(encode(|w| w.tag(1, WireType::Varint)), vec![0x08]);
    }

    #[test]
    fn tag_decodes_back_to_field_number_and_wire_type() {
        let bytes = encode(|w| w.tag(15, WireType::LengthDelimited));
        let mut r = Reader::new(&bytes);
        assert_eq!(r.tag().unwrap(), (15, WireType::LengthDelimited));
    }

    #[test]
    fn skip_handles_each_wire_type() {
        // Fields 1-4 cover every skippable wire type; field 5 is the one
        // value the loop below keeps.
        let bytes = encode(|w| {
            w.tag(1, WireType::Varint);
            w.varint(150);
            w.tag(2, WireType::Fixed32);
            w.fixed32(0xdead_beef);
            w.tag(3, WireType::Fixed64);
            w.fixed64(0xdead_beef_cafe_babe);
            w.tag(4, WireType::LengthDelimited);
            w.string("skip me");
            w.tag(5, WireType::Varint);
            w.varint(7);
        });

        let mut r = Reader::new(&bytes);
        let mut keep5: Option<u64> = None;
        while !r.eof() {
            match r.tag().unwrap() {
                (5, _) => keep5 = Some(r.varint().unwrap()),
                (_, wt) => r.skip(wt).unwrap(),
            }
        }
        assert_eq!(keep5, Some(7));
    }

    #[test]
    fn truncated_varint_is_rejected() {
        let mut r = Reader::new(&[0x80]);
        assert!(matches!(r.varint(), Err(Error::TruncatedVarint)));
    }

    #[test]
    fn length_prefix_max_varint_does_not_overflow() {
        // A naive `pos + len > data.len()` check would wrap on a u64::MAX
        // length, slip past the bounds check, and then panic on the slice.
        // HARDENING.md §API contract requires a clean reject.
        let bytes = hostile_length_field();
        let mut r = Reader::new(&bytes);
        r.tag().unwrap();
        assert!(matches!(r.bytes_view(), Err(Error::TruncatedLengthDelim)));
    }

    #[test]
    fn length_prefix_overflow_during_skip_does_not_panic() {
        let bytes = hostile_length_field();
        let mut r = Reader::new(&bytes);
        let (_, wt) = r.tag().unwrap();
        assert!(matches!(r.skip(wt), Err(Error::TruncatedLengthDelim)));
    }
}