prototext_core/helpers/varint.rs
1// SPDX-FileCopyrightText: 2025 - 2026 Frederic Ruget <fred@atlant.is> <fred@s3ns.io> (GitHub: @douzebis)
2// SPDX-FileCopyrightText: 2025 - 2026 Thales Cloud Sécurisé
3//
4// SPDX-License-Identifier: MIT
5
6// ── Varint parser result ──────────────────────────────────────────────────────
7
/// Result of parsing one protobuf varint from a byte slice.
///
/// Mirrors the Python `Varint` class in `lib/varint.py`.
///
/// Exactly one of `varint` or `varint_gar` is `Some`:
/// * `varint_gar` is `Some` when the varint is truncated (buffer ends before
///   the terminator byte) or exceeds 64 bits.
/// * `varint` is `Some` for a successfully decoded varint.
///
/// `varint_ohb` describes a non-canonical ("overhung") encoding. It is `Some`
/// only for a decoded varint of two or more bytes whose terminating byte is
/// `0x00`; the count covers that `0x00` terminator plus the run of `0x80`
/// bytes immediately before it (the first byte of the varint is never counted).
///
/// The type is `Clone` and comparable with `==` so results can be stored and
/// checked directly in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VarintResult {
    /// Byte position immediately after the parsed varint.
    pub next_pos: usize,
    /// `Some(raw_bytes)` when the varint is garbage (truncated / too large).
    pub varint_gar: Option<Vec<u8>>,
    /// The decoded varint value (`Some` only when `varint_gar` is `None`).
    pub varint: Option<u64>,
    /// Number of non-canonical overhang bytes; `None` when the encoding has no
    /// overhang or when the varint is garbage.
    pub varint_ohb: Option<u64>,
}
29
30/// Parse one protobuf varint starting at `start` in `buf`.
31///
32/// Mirrors `Varint.__init__` in `lib/varint.py`.
33///
34/// OPT-3: #[inline] allows the compiler to merge this function into the
35/// parse_wiretag and decoder hot loops, enabling intra-procedural optimizations
36/// (constant-fold the shift sequence, avoid call overhead). perf showed
37/// parse_varint at 4.33% and parse_wiretag at 10.49% of Path A samples.
38#[inline]
39pub fn parse_varint(buf: &[u8], start: usize) -> VarintResult {
40 let buflen = buf.len();
41 assert!(start <= buflen);
42
43 if start == buflen {
44 // Empty buffer at this position → garbage (empty)
45 return VarintResult {
46 next_pos: start,
47 varint_gar: Some(vec![]),
48 varint: None,
49 varint_ohb: None,
50 };
51 }
52
53 let mut v: u64 = 0;
54 let mut shift: u32 = 0;
55 let mut pos = start;
56 let mut too_big = false;
57
58 loop {
59 if pos >= buflen {
60 // Truncated varint — return rest of buffer as garbage (matches Python)
61 return VarintResult {
62 next_pos: buflen,
63 varint_gar: Some(buf[start..].to_vec()),
64 varint: None,
65 varint_ohb: None,
66 };
67 }
68 let b = buf[pos];
69 pos += 1;
70
71 let bits = (b & 0x7f) as u64;
72 if shift < 64 {
73 // shift == 63: the 10th byte. Only bit 0 is valid for a u64;
74 // bits ≥ 2 would produce a value ≥ 2^64.
75 if shift == 63 && bits > 1 {
76 too_big = true;
77 } else {
78 v |= bits << shift;
79 }
80 } else {
81 // ≥ 11th byte: any set bit overflows u64.
82 if bits != 0 {
83 too_big = true;
84 }
85 }
86 shift += 7;
87
88 if b & 0x80 == 0 {
89 break; // terminator found
90 }
91
92 if shift > 70 {
93 // Absurdly long varint (> 10 bytes): consume continuation bytes and
94 // flag as too_big.
95 while pos < buflen {
96 let b2 = buf[pos];
97 pos += 1;
98 if (b2 & 0x7f) != 0 {
99 too_big = true;
100 }
101 if b2 & 0x80 == 0 {
102 break;
103 }
104 }
105 break;
106 }
107 }
108
109 if too_big {
110 // Python sets pos = buflen before its else-clause fires, so varint_gar
111 // always contains buf[start..] (rest of buffer) on overflow. Matching
112 // that behaviour ensures identical INVALID_VARINT content.
113 return VarintResult {
114 next_pos: buflen,
115 varint_gar: Some(buf[start..].to_vec()),
116 varint: None,
117 varint_ohb: None,
118 };
119 }
120
121 // The byte at buf[pos-1] is the terminator (the byte that ended the varint).
122 // Use it directly instead of tracking `last_b` across loop iterations.
123 let last_b = buf[pos - 1];
124
125 // Check for overhung bytes: terminator is 0x00 preceded by ≥1 × 0x80
126 let ohb = if last_b == 0x00 && pos > start + 1 {
127 // Count trailing 0x80 bytes before the 0x00 terminator
128 let mut count: u64 = 1;
129 let mut p = pos - 2; // byte before the 0x00
130 while p > start && buf[p] == 0x80 {
131 count += 1;
132 p -= 1;
133 }
134 Some(count)
135 } else {
136 None
137 };
138
139 VarintResult {
140 next_pos: pos,
141 varint_gar: None,
142 varint: Some(v),
143 varint_ohb: ohb,
144 }
145}
146
147/// Encode a varint value (with optional overhang bytes) back to bytes.
148///
149/// Mirrors `Varint.__bytes__` in `lib/varint.py`.
150pub fn encode_varint_bytes(value: u64, ohb: Option<u64>) -> Vec<u8> {
151 let mut out = Vec::new();
152 write_varint_ohb(value, ohb, &mut out);
153 out
154}
155
/// Append the varint encoding of `value` to `out`, optionally followed by
/// non-canonical overhang bytes, without building a temporary vector.
///
/// With `ohb` equal to `None` or `Some(0)` the canonical encoding is written.
/// With `Some(n)`, `n > 0`, the last canonical byte gets its continuation bit
/// set and is followed by `n - 1` bytes of `0x80` and a final `0x00`
/// terminator, so `n` extra bytes are emitted in total.
///
/// OPT-2: This is the in-place replacement for `encode_varint_bytes`. The old
/// function allocated a fresh Vec<u8> per call (~18 ns each; 6× slower than
/// appending to an existing Vec). Callers that already have a target buffer
/// should call this instead, eliminating the allocate-copy-free cycle that
/// showed up as 21% memmove + 11% malloc/free in the perf profile of Path A.
#[inline]
pub fn write_varint_ohb(value: u64, ohb: Option<u64>, out: &mut Vec<u8>) {
    let overhang = ohb.unwrap_or(0);

    // Every 7-bit group except the most significant one is a continuation byte.
    let mut rest = value;
    while rest >= 0x80 {
        out.push((rest & 0x7f) as u8 | 0x80);
        rest >>= 7;
    }

    // `rest` now fits in 7 bits: it is the most significant group.
    if overhang == 0 {
        out.push(rest as u8);
        return;
    }

    // Overhang requested: keep the varint open, pad, then close with 0x00.
    out.push(rest as u8 | 0x80);
    for _ in 1..overhang {
        out.push(0x80);
    }
    out.push(0x00);
}
187
188// ── Wiretag parser result ─────────────────────────────────────────────────────
189
/// Result of parsing one protobuf wire tag (field number + wire type).
///
/// Mirrors the Python `Wiretag` class in `lib/wiretag.py`.
///
/// Exactly one of `wtag_gar` or `wtype` is `Some`:
/// * `wtag_gar` is `Some` when the wire type is > 5 (invalid) or the
///   tag varint is truncated / too large.
/// * Otherwise `wtype` holds the wire type (0–5) and `wfield` the field number.
///
/// The type is comparable with `==` so parsed tags can be checked directly
/// against expected values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WiretagResult {
    /// Byte position immediately after the parsed tag.
    pub next_pos: usize,
    /// Raw bytes when the tag is garbage.
    pub wtag_gar: Option<Vec<u8>>,
    /// Wire type (0–5); `Some` only when `wtag_gar` is `None`.
    pub wtype: Option<u32>,
    /// Field number; `Some` only when `wtag_gar` is `None`.
    pub wfield: Option<u64>,
    /// Overhang count of the tag varint, when it was encoded non-canonically.
    pub wfield_ohb: Option<u64>,
    /// `Some(true)` when the field number is 0 or ≥ 2²⁹; `None` otherwise
    /// (in-range field numbers and garbage tags).
    pub wfield_oor: Option<bool>,
}
212
213/// Parse one wire tag starting at `start` in `buf`.
214///
215/// Mirrors `Wiretag.__init__` in `lib/wiretag.py`.
216///
217/// OPT-3: #[inline] pairs with #[inline] on parse_varint so the compiler can
218/// fold both into the decoder.rs hot loop without separate call frames.
219#[inline]
220pub fn parse_wiretag(buf: &[u8], start: usize) -> WiretagResult {
221 let buflen = buf.len();
222 assert!(start < buflen, "parse_wiretag called at end of buffer");
223
224 let first_byte = buf[start];
225 let wtype = (first_byte & 0x07) as u32;
226
227 if wtype > 5 {
228 // Invalid wire type: consume rest of buffer as garbage
229 return WiretagResult {
230 next_pos: buflen,
231 wtag_gar: Some(buf[start..].to_vec()),
232 wtype: None,
233 wfield: None,
234 wfield_ohb: None,
235 wfield_oor: None,
236 };
237 }
238
239 // The field number occupies bits 3.. of the varint.
240 // Parse the whole tag as a varint, then extract field number from bits 3+.
241 let vr = parse_varint(buf, start);
242
243 if let Some(gar) = vr.varint_gar {
244 // Truncated or too-large tag varint
245 return WiretagResult {
246 next_pos: vr.next_pos,
247 wtag_gar: Some(gar),
248 wtype: None,
249 wfield: None,
250 wfield_ohb: None,
251 wfield_oor: None,
252 };
253 }
254
255 let raw = vr.varint.unwrap();
256 let field_number = raw >> 3;
257 let ohb = vr.varint_ohb;
258 let oor = if field_number == 0 || field_number >= (1 << 29) {
259 Some(true)
260 } else {
261 None
262 };
263
264 WiretagResult {
265 next_pos: vr.next_pos,
266 wtag_gar: None,
267 wtype: Some(wtype),
268 wfield: Some(field_number),
269 wfield_ohb: ohb,
270 wfield_oor: oor,
271 }
272}