Skip to main content

prototext_core/helpers/
varint.rs

1// SPDX-FileCopyrightText: 2025 - 2026 Frederic Ruget <fred@atlant.is> <fred@s3ns.io> (GitHub: @douzebis)
2// SPDX-FileCopyrightText: 2025 - 2026 Thales Cloud Sécurisé
3//
4// SPDX-License-Identifier: MIT
5
6// ── Varint parser result ──────────────────────────────────────────────────────
7
/// Result of parsing one protobuf varint from a byte slice.
///
/// Mirrors the Python `Varint` class in `lib/varint.py`.
///
/// Exactly one of `varint` or `varint_gar` is `Some`:
/// * `varint_gar` is `Some` when the varint is truncated (the buffer ends
///   before a terminator byte, including the case of an empty remainder) or
///   its value does not fit in 64 bits.  It then holds every byte from the
///   start position to the end of the buffer, and `next_pos` is the buffer
///   length.
/// * `varint` is `Some` for a successfully decoded varint.
/// * `varint_ohb` counts trailing non-canonical (overhung) bytes: it is
///   `Some(n)` when a varint longer than one byte ends in a `0x00` terminator.
///   `n` is that `0x00` plus the run of `0x80` bytes immediately before it
///   (the first byte is never counted), so `encode_varint_bytes(value, ohb)`
///   reproduces the parsed bytes.
#[derive(Debug)]
pub struct VarintResult {
    /// Byte position immediately after the parsed varint.
    pub next_pos: usize,
    /// `Some(raw_bytes)` when the varint is garbage (truncated / too large).
    pub varint_gar: Option<Vec<u8>>,
    /// The decoded varint value (valid only when `varint_gar` is `None`).
    pub varint: Option<u64>,
    /// Number of non-canonical overhang bytes (valid only when `varint_gar` is `None`).
    pub varint_ohb: Option<u64>,
}

/// Parse one protobuf varint starting at `start` in `buf`.
///
/// Mirrors `Varint.__init__` in `lib/varint.py`.
///
/// # Panics
/// Panics if `start > buf.len()`.
///
/// OPT-3: #[inline] allows the compiler to merge this function into the
/// parse_wiretag and decoder hot loops, enabling intra-procedural optimizations
/// (constant-fold the shift sequence, avoid call overhead).  perf showed
/// parse_varint at 4.33% and parse_wiretag at 10.49% of Path A samples.
#[inline]
pub fn parse_varint(buf: &[u8], start: usize) -> VarintResult {
    let buflen = buf.len();
    assert!(start <= buflen);

    let mut v: u64 = 0;
    let mut too_big = false;
    // Position just past the terminator byte, once one has been seen.  Staying
    // `None` after the scan means the buffer ended mid-varint (or was empty).
    let mut end: Option<usize> = None;

    // Bounded by the slice length, so even absurdly long runs of continuation
    // bytes cannot overflow a shift counter or run past the buffer.
    for (i, &b) in buf[start..].iter().enumerate() {
        let bits = u64::from(b & 0x7f);
        match i {
            // Bytes 1–9 each contribute 7 payload bits (bits 0..=62).
            0..=8 => v |= bits << (7 * i),
            // 10th byte: only bit 63 is left in a u64; payload ≥ 2 would
            // produce a value ≥ 2^64.
            9 => {
                if bits > 1 {
                    too_big = true;
                } else {
                    v |= bits << 63;
                }
            }
            // 11th byte onward: any set payload bit overflows u64.  All-zero
            // payload bytes are tolerated as (non-canonical) padding.
            _ => {
                if bits != 0 {
                    too_big = true;
                }
            }
        }
        if b & 0x80 == 0 {
            end = Some(start + i + 1);
            break;
        }
    }

    let pos = match end {
        Some(pos) if !too_big => pos,
        _ => {
            // Truncated or overflowing.  Python sets pos = buflen before its
            // else-clause fires, so varint_gar always contains buf[start..]
            // (rest of buffer).  Matching that behaviour ensures identical
            // INVALID_VARINT content.  For an empty remainder this yields an
            // empty garbage vector with next_pos == start == buflen.
            return VarintResult {
                next_pos: buflen,
                varint_gar: Some(buf[start..].to_vec()),
                varint: None,
                varint_ohb: None,
            };
        }
    };

    // A canonical encoding only ends in 0x00 when it is the single byte 0x00
    // (value 0).  A longer varint ending in 0x00 carries redundant overhang
    // bytes: the 0x00 terminator plus the run of 0x80 bytes right before it.
    // buf[start] is excluded because it always belongs to the canonical form.
    let ohb = if pos > start + 1 && buf[pos - 1] == 0x00 {
        let run = buf[start + 1..pos - 1]
            .iter()
            .rev()
            .take_while(|&&b| b == 0x80)
            .count();
        Some(1 + run as u64)
    } else {
        None
    };

    VarintResult {
        next_pos: pos,
        varint_gar: None,
        varint: Some(v),
        varint_ohb: ohb,
    }
}
146
/// Serialize `value` as a protobuf varint, stretched by `ohb` overhang bytes,
/// into a freshly allocated buffer.
///
/// Mirrors `Varint.__bytes__` in `lib/varint.py`.  This is a convenience
/// wrapper around `write_varint_ohb`; prefer that function when a target
/// buffer already exists.
pub fn encode_varint_bytes(value: u64, ohb: Option<u64>) -> Vec<u8> {
    let mut encoded = Vec::new();
    write_varint_ohb(value, ohb, &mut encoded);
    encoded
}

/// Append the varint encoding of `value`, plus `ohb` overhang bytes, to `out`.
///
/// OPT-2: in-place counterpart of `encode_varint_bytes`.  Writing straight
/// into the caller's buffer avoids the per-call allocate/copy/free of a
/// temporary `Vec` (~18 ns per call, 6× slower than appending; it showed up
/// as 21% memmove + 11% malloc/free in the Path A perf profile).  `out` may
/// still reallocate when it has to grow.
///
/// Layout of the appended bytes:
/// * the canonical base-128 form of `value`, least-significant group first,
///   with the high bit set on every byte but the last;
/// * if `ohb` is `Some(n)` with `n > 0`, the canonical last byte gets its
///   continuation bit set, followed by `n - 1` bytes of `0x80` and a final
///   `0x00` terminator.  `None` and `Some(0)` leave the canonical form as is.
#[inline]
pub fn write_varint_ohb(value: u64, ohb: Option<u64>, out: &mut Vec<u8>) {
    let mut rest = value;
    // Emit a continuation byte for every 7-bit group that is not the last.
    while rest >= 0x80 {
        out.push((rest as u8) | 0x80);
        rest >>= 7;
    }
    out.push(rest as u8);

    if let Some(extra) = ohb.filter(|&n| n > 0) {
        // The canonical terminator becomes a continuation byte...
        if let Some(last) = out.last_mut() {
            *last |= 0x80;
        }
        // ...followed by zero-payload continuation padding...
        for _ in 1..extra {
            out.push(0x80);
        }
        // ...and closed by a zero-payload terminator.
        out.push(0x00);
    }
}
187
188// ── Wiretag parser result ─────────────────────────────────────────────────────
189
190/// Result of parsing one protobuf wire tag (field number + wire type).
191///
192/// Mirrors the Python `Wiretag` class in `lib/wiretag.py`.
193///
194/// Exactly one of `wtag_gar` or `wtype` is valid:
195/// * `wtag_gar` is `Some` when the wire type is > 5 (invalid) or the
196///   field-number varint is truncated / too large.
197/// * Otherwise `wtype` holds the wire type (0–5) and `wfield` the field number.
198#[derive(Debug, Clone)]
199pub struct WiretagResult {
200    pub next_pos: usize,
201    /// Raw bytes when the tag is garbage.
202    pub wtag_gar: Option<Vec<u8>>,
203    /// Wire type (0–5); valid only when `wtag_gar` is `None`.
204    pub wtype: Option<u32>,
205    /// Field number; valid only when `wtag_gar` is `None`.
206    pub wfield: Option<u64>,
207    /// Overhang count in the field-number varint.
208    pub wfield_ohb: Option<u64>,
209    /// `true` when field number is 0 or ≥ 2²⁹.
210    pub wfield_oor: Option<bool>,
211}
212
213/// Parse one wire tag starting at `start` in `buf`.
214///
215/// Mirrors `Wiretag.__init__` in `lib/wiretag.py`.
216///
217/// OPT-3: #[inline] pairs with #[inline] on parse_varint so the compiler can
218/// fold both into the decoder.rs hot loop without separate call frames.
219#[inline]
220pub fn parse_wiretag(buf: &[u8], start: usize) -> WiretagResult {
221    let buflen = buf.len();
222    assert!(start < buflen, "parse_wiretag called at end of buffer");
223
224    let first_byte = buf[start];
225    let wtype = (first_byte & 0x07) as u32;
226
227    if wtype > 5 {
228        // Invalid wire type: consume rest of buffer as garbage
229        return WiretagResult {
230            next_pos: buflen,
231            wtag_gar: Some(buf[start..].to_vec()),
232            wtype: None,
233            wfield: None,
234            wfield_ohb: None,
235            wfield_oor: None,
236        };
237    }
238
239    // The field number occupies bits 3.. of the varint.
240    // Parse the whole tag as a varint, then extract field number from bits 3+.
241    let vr = parse_varint(buf, start);
242
243    if let Some(gar) = vr.varint_gar {
244        // Truncated or too-large tag varint
245        return WiretagResult {
246            next_pos: vr.next_pos,
247            wtag_gar: Some(gar),
248            wtype: None,
249            wfield: None,
250            wfield_ohb: None,
251            wfield_oor: None,
252        };
253    }
254
255    let raw = vr.varint.unwrap();
256    let field_number = raw >> 3;
257    let ohb = vr.varint_ohb;
258    let oor = if field_number == 0 || field_number >= (1 << 29) {
259        Some(true)
260    } else {
261        None
262    };
263
264    WiretagResult {
265        next_pos: vr.next_pos,
266        wtag_gar: None,
267        wtype: Some(wtype),
268        wfield: Some(field_number),
269        wfield_ohb: ohb,
270        wfield_oor: oor,
271    }
272}