Skip to main content

prototext_core/
helpers.rs

1// SPDX-FileCopyrightText: 2025 - 2026 Frederic Ruget <fred@atlant.is> <fred@s3ns.io> (GitHub: @douzebis)
2// SPDX-FileCopyrightText: 2025 - 2026 Thales Cloud Sécurisé
3//
4// SPDX-License-Identifier: MIT
5
6// ── Wire-type constants ───────────────────────────────────────────────────────
7
/// Wire type 0: base-128 varint.
pub const WT_VARINT: u32 = 0;
/// Wire type 1: 8-byte little-endian fixed-width value.
pub const WT_I64: u32 = 1;
/// Wire type 2: length-delimited payload (varint length, then the bytes).
pub const WT_LEN: u32 = 2;
/// Wire type 3: start of a group.
pub const WT_START_GROUP: u32 = 3;
/// Wire type 4: end of a group.
pub const WT_END_GROUP: u32 = 4;
/// Wire type 5: 4-byte little-endian fixed-width value.
pub const WT_I32: u32 = 5;
14
15// ── Varint parser result ──────────────────────────────────────────────────────
16
/// Result of parsing one protobuf varint from a byte slice.
///
/// Mirrors the Python `Varint` class in `lib/varint.py`.
///
/// Exactly one of `varint` or `varint_gar` is `Some`:
/// * `varint_gar` is `Some` when the varint is truncated (buffer ends before
///   the terminator byte) or exceeds 64 bits.
/// * `varint` is `Some` for a successfully decoded varint.
///
/// `varint_ohb` counts non-canonical (overhung) trailing bytes.  It is set
/// only for a multi-byte varint whose terminator is `0x00`, and equals one
/// (for that `0x00`) plus the number of `0x80` bytes directly before it; the
/// varint's first byte is never counted.  Re-encoding `varint` with this
/// count through `encode_varint_bytes` reproduces the parsed bytes exactly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VarintResult {
    /// Byte position immediately after the parsed varint; `buf.len()` when
    /// `varint_gar` is `Some`.
    pub next_pos: usize,
    /// `Some(raw_bytes)` when the varint is garbage (truncated / too large).
    pub varint_gar: Option<Vec<u8>>,
    /// The decoded varint value (valid only when `varint_gar` is `None`).
    pub varint: Option<u64>,
    /// Number of non-canonical overhang bytes (valid only when `varint_gar` is `None`).
    pub varint_ohb: Option<u64>,
}

/// Parse one protobuf varint starting at `start` in `buf`.
///
/// Mirrors `Varint.__init__` in `lib/varint.py`.
///
/// # Returns
///
/// * Truncated input (no terminator byte before the end of `buf`, which
///   includes `start == buf.len()`) yields `varint_gar = Some(buf[start..])`.
///   This holds no matter how long the run of continuation bytes is.
/// * A value that does not fit in 64 bits also yields
///   `varint_gar = Some(buf[start..])`.  That is the *rest of the buffer*,
///   not just the varint's own bytes.
/// * Otherwise `varint = Some(value)`, and `varint_ohb` is set for
///   non-canonical `… 0x80 0x00` encodings.
///
/// In both garbage cases `next_pos == buf.len()`.  On success, `next_pos`
/// points just past the terminator.  Varints longer than ten bytes still
/// decode successfully as long as every payload bit beyond bit 63 is zero.
///
/// # Panics
///
/// Panics if `start > buf.len()`.
///
/// OPT-3: #[inline] allows the compiler to merge this function into the
/// parse_wiretag and decoder hot loops, enabling intra-procedural optimizations
/// (constant-fold the shift sequence, avoid call overhead).  perf showed
/// parse_varint at 4.33% and parse_wiretag at 10.49% of Path A samples.
#[inline]
pub fn parse_varint(buf: &[u8], start: usize) -> VarintResult {
    let buflen = buf.len();
    assert!(start <= buflen);

    // Both failure modes report the rest of the buffer as garbage.  Python
    // sets pos = buflen before building its garbage slice, so matching that
    // keeps INVALID_VARINT content identical.
    let garbage = || VarintResult {
        next_pos: buflen,
        varint_gar: Some(buf[start..].to_vec()),
        varint: None,
        varint_ohb: None,
    };

    let mut v: u64 = 0;
    // Bit offset of the current byte's 7-bit payload.  It stops growing at 70
    // (past the 64-bit boundary) because every later byte is handled the same
    // way.  The cap also keeps the u32 from overflowing on absurdly long input.
    let mut shift: u32 = 0;
    let mut pos = start;
    let mut too_big = false;

    loop {
        // Reaching the end of `buf` before a terminator means the varint is
        // truncated.  This also covers `start == buflen` (empty garbage).
        if pos >= buflen {
            return garbage();
        }
        let b = buf[pos];
        pos += 1;

        let bits = u64::from(b & 0x7f);
        match shift {
            // Bytes 1-9: the whole payload fits below bit 63.
            0..=62 => v |= bits << shift,
            // Byte 10: only bit 0 is still inside a u64; bits >= 2 would
            // produce a value >= 2^64.
            63 => {
                if bits > 1 {
                    too_big = true;
                } else {
                    v |= bits << shift;
                }
            }
            // Byte 11 and beyond: any set payload bit overflows u64.
            _ => {
                if bits != 0 {
                    too_big = true;
                }
            }
        }

        if b & 0x80 == 0 {
            break; // terminator found
        }
        if shift < 64 {
            shift += 7;
        }
    }

    if too_big {
        return garbage();
    }

    // buf[pos - 1] is the terminator.  A 0x00 terminator on a multi-byte
    // varint is non-canonical.  Count it, plus the run of 0x80 padding bytes
    // just before it.  The first byte always belongs to the value, so the scan
    // stops short of it.
    let ohb = if buf[pos - 1] == 0x00 && pos - start > 1 {
        let padding = buf[start + 1..pos - 1]
            .iter()
            .rev()
            .take_while(|&&byte| byte == 0x80)
            .count() as u64;
        Some(1 + padding)
    } else {
        None
    };

    VarintResult {
        next_pos: pos,
        varint_gar: None,
        varint: Some(v),
        varint_ohb: ohb,
    }
}
155
/// Encode `value` as a varint, optionally followed by overhang padding, and
/// return the bytes in a new `Vec<u8>`.
///
/// Mirrors `Varint.__bytes__` in `lib/varint.py`.  This is a convenience
/// wrapper around [`write_varint_ohb`]; prefer that function when a target
/// buffer already exists.
pub fn encode_varint_bytes(value: u64, ohb: Option<u64>) -> Vec<u8> {
    let mut encoded: Vec<u8> = Vec::new();
    write_varint_ohb(value, ohb, &mut encoded);
    encoded
}

/// Append the varint encoding of `value`, optionally padded with overhang
/// bytes, to the end of `out`.
///
/// An `ohb` of `None` or `Some(0)` produces the canonical (shortest) encoding.
/// `Some(n)` with `n > 0` does three things: it sets the continuation bit on
/// the canonical last byte, appends `n - 1` bytes of `0x80`, and finishes with
/// a `0x00` terminator.  The result is `n` bytes longer than canonical.
///
/// OPT-2: This is the in-place replacement for `encode_varint_bytes`.  Writing
/// into the caller's buffer avoids allocating an intermediate Vec<u8> per call
/// (~18 ns each; 6× slower than appending to an existing Vec).  That
/// allocate-copy-free cycle showed up as 21% memmove + 11% malloc/free in the
/// perf profile of Path A.  `out` itself may still grow as bytes are pushed.
#[inline]
pub fn write_varint_ohb(value: u64, ohb: Option<u64>, out: &mut Vec<u8>) {
    // Canonical form: low 7-bit group first, continuation bit on all but the last.
    let mut rest = value;
    while rest >= 0x80 {
        out.push(0x80 | (rest & 0x7f) as u8);
        rest >>= 7;
    }
    out.push(rest as u8);

    let pad = ohb.unwrap_or(0);
    if pad == 0 {
        return;
    }
    // The canonical terminator becomes a continuation byte; `out` is non-empty
    // because at least one byte was just pushed.
    let last = out.len() - 1;
    out[last] |= 0x80;
    for _ in 1..pad {
        out.push(0x80);
    }
    out.push(0x00); // final terminator
}
196
197// ── Wiretag parser result ─────────────────────────────────────────────────────
198
199/// Result of parsing one protobuf wire tag (field number + wire type).
200///
201/// Mirrors the Python `Wiretag` class in `lib/wiretag.py`.
202///
203/// Exactly one of `wtag_gar` or `wtype` is valid:
204/// * `wtag_gar` is `Some` when the wire type is > 5 (invalid) or the
205///   field-number varint is truncated / too large.
206/// * Otherwise `wtype` holds the wire type (0–5) and `wfield` the field number.
207#[derive(Debug, Clone)]
208pub struct WiretagResult {
209    pub next_pos: usize,
210    /// Raw bytes when the tag is garbage.
211    pub wtag_gar: Option<Vec<u8>>,
212    /// Wire type (0–5); valid only when `wtag_gar` is `None`.
213    pub wtype: Option<u32>,
214    /// Field number; valid only when `wtag_gar` is `None`.
215    pub wfield: Option<u64>,
216    /// Overhang count in the field-number varint.
217    pub wfield_ohb: Option<u64>,
218    /// `true` when field number is 0 or ≥ 2²⁹.
219    pub wfield_oor: Option<bool>,
220}
221
222/// Parse one wire tag starting at `start` in `buf`.
223///
224/// Mirrors `Wiretag.__init__` in `lib/wiretag.py`.
225///
226/// OPT-3: #[inline] pairs with #[inline] on parse_varint so the compiler can
227/// fold both into the decoder.rs hot loop without separate call frames.
228#[inline]
229pub fn parse_wiretag(buf: &[u8], start: usize) -> WiretagResult {
230    let buflen = buf.len();
231    assert!(start < buflen, "parse_wiretag called at end of buffer");
232
233    let first_byte = buf[start];
234    let wtype = (first_byte & 0x07) as u32;
235
236    if wtype > 5 {
237        // Invalid wire type: consume rest of buffer as garbage
238        return WiretagResult {
239            next_pos: buflen,
240            wtag_gar: Some(buf[start..].to_vec()),
241            wtype: None,
242            wfield: None,
243            wfield_ohb: None,
244            wfield_oor: None,
245        };
246    }
247
248    // The field number occupies bits 3.. of the varint.
249    // Parse the whole tag as a varint, then extract field number from bits 3+.
250    let vr = parse_varint(buf, start);
251
252    if let Some(gar) = vr.varint_gar {
253        // Truncated or too-large tag varint
254        return WiretagResult {
255            next_pos: vr.next_pos,
256            wtag_gar: Some(gar),
257            wtype: None,
258            wfield: None,
259            wfield_ohb: None,
260            wfield_oor: None,
261        };
262    }
263
264    let raw = vr.varint.unwrap();
265    let field_number = raw >> 3;
266    let ohb = vr.varint_ohb;
267    let oor = if field_number == 0 || field_number >= (1 << 29) {
268        Some(true)
269    } else {
270        None
271    };
272
273    WiretagResult {
274        next_pos: vr.next_pos,
275        wtag_gar: None,
276        wtype: Some(wtype),
277        wfield: Some(field_number),
278        wfield_ohb: ohb,
279        wfield_oor: oor,
280    }
281}
282
283// ── Numeric type codecs ───────────────────────────────────────────────────────
284// Mirror prototext/helpers.py exactly.  All functions are trivial but must
285// be correct — errors here silently break round-trip fidelity.
286
/// Decode a varint as `int64`: all 64 bits reinterpreted as two's complement.
#[inline]
pub fn decode_int64(v: u64) -> i64 {
    v as i64
}

/// Decode a varint as `int32`: the low 32 bits reinterpreted as two's
/// complement; the high 32 bits are ignored.
#[inline]
pub fn decode_int32(v: u64) -> i32 {
    v as i32
}

/// Decode a varint as `uint32`: keep only the low 32 bits.
#[inline]
pub fn decode_uint32(v: u64) -> u32 {
    (v & 0xFFFF_FFFF) as u32
}

/// Decode a varint as `uint64` (identity).
#[inline]
pub fn decode_uint64(v: u64) -> u64 {
    v
}

/// Decode a varint as `bool`: zero is `false`, any non-zero value is `true`.
#[inline]
pub fn decode_bool(v: u64) -> bool {
    v > 0
}

/// Decode a varint as `sint32` (zig-zag over the low 32 bits).
///
/// Even inputs map to non-negative values, odd inputs to negative ones:
/// 0 → 0, 1 → -1, 2 → 1, 3 → -2, …
#[inline]
pub fn decode_sint32(v: u64) -> i32 {
    let low = v as u32;
    // (low & 1).wrapping_neg() is all-ones for odd inputs, zero for even.
    ((low >> 1) ^ (low & 1).wrapping_neg()) as i32
}

/// Decode a varint as `sint64` (zig-zag over all 64 bits).
#[inline]
pub fn decode_sint64(v: u64) -> i64 {
    ((v >> 1) ^ (v & 1).wrapping_neg()) as i64
}
329
/// Decode the first 4 bytes of `data` as a little-endian `fixed32` (uint32).
///
/// # Panics
///
/// Panics if `data` is shorter than 4 bytes; extra bytes are ignored.
#[inline]
pub fn decode_fixed32(data: &[u8]) -> u32 {
    let w = &data[..4];
    u32::from_le_bytes([w[0], w[1], w[2], w[3]])
}

/// Decode the first 4 bytes of `data` as a little-endian `sfixed32` (int32).
///
/// # Panics
///
/// Panics if `data` is shorter than 4 bytes.
#[inline]
pub fn decode_sfixed32(data: &[u8]) -> i32 {
    decode_fixed32(data) as i32
}

/// Decode the first 4 bytes of `data` as a little-endian IEEE-754 `f32`.
///
/// # Panics
///
/// Panics if `data` is shorter than 4 bytes.
#[inline]
pub fn decode_float(data: &[u8]) -> f32 {
    f32::from_bits(decode_fixed32(data))
}

/// Decode the first 8 bytes of `data` as a little-endian `fixed64` (uint64).
///
/// # Panics
///
/// Panics if `data` is shorter than 8 bytes; extra bytes are ignored.
#[inline]
pub fn decode_fixed64(data: &[u8]) -> u64 {
    let w = &data[..8];
    u64::from_le_bytes([w[0], w[1], w[2], w[3], w[4], w[5], w[6], w[7]])
}

/// Decode the first 8 bytes of `data` as a little-endian `sfixed64` (int64).
///
/// # Panics
///
/// Panics if `data` is shorter than 8 bytes.
#[inline]
pub fn decode_sfixed64(data: &[u8]) -> i64 {
    decode_fixed64(data) as i64
}

/// Decode the first 8 bytes of `data` as a little-endian IEEE-754 `f64`.
///
/// # Panics
///
/// Panics if `data` is shorter than 8 bytes.
#[inline]
pub fn decode_double(data: &[u8]) -> f64 {
    f64::from_bits(decode_fixed64(data))
}
365
366// ── Wire-encoding helpers (used by pt_codec.rs) ───────────────────────────────
367
/// Append the canonical (shortest) varint encoding of `value` to `buf`.
///
/// Each byte carries seven payload bits, least-significant group first.
/// Bit 7 (the continuation bit) is set on every byte except the last.
#[inline]
pub fn write_varint(value: u64, buf: &mut Vec<u8>) {
    let mut remaining = value;
    while remaining >= 0x80 {
        buf.push(0x80 | (remaining & 0x7f) as u8);
        remaining >>= 7;
    }
    buf.push(remaining as u8);
}
383
384/// Append a field tag (wire_type 0–5) to `buf`.
385#[inline]
386pub fn write_tag(field_number: u32, wire_type: u32, buf: &mut Vec<u8>) {
387    write_varint(((field_number as u64) << 3) | (wire_type as u64), buf);
388}
389
390/// Append a VARINT field to `buf`.
391#[inline]
392pub fn write_varint_field(field_number: u32, value: u64, buf: &mut Vec<u8>) {
393    write_tag(field_number, WT_VARINT, buf);
394    write_varint(value, buf);
395}
396
397/// Append a bool field (VARINT) to `buf` — only written when `true`.
398#[inline]
399pub fn write_bool_field(field_number: u32, value: bool, buf: &mut Vec<u8>) {
400    if value {
401        write_tag(field_number, WT_VARINT, buf);
402        buf.push(1u8);
403    }
404}
405
406/// Append an optional VARINT field to `buf` — only written when `Some`.
407#[inline]
408pub fn write_opt_varint_field(field_number: u32, value: Option<u64>, buf: &mut Vec<u8>) {
409    if let Some(v) = value {
410        write_varint_field(field_number, v, buf);
411    }
412}
413
414/// Append a LEN-delimited field to `buf`.
415#[inline]
416pub fn write_len_field(field_number: u32, data: &[u8], buf: &mut Vec<u8>) {
417    write_tag(field_number, WT_LEN, buf);
418    write_varint(data.len() as u64, buf);
419    buf.extend_from_slice(data);
420}
421
422/// Append a fixed-32 field to `buf`.
423pub fn write_fixed32_field(field_number: u32, value: u32, buf: &mut Vec<u8>) {
424    write_tag(field_number, WT_I32, buf);
425    buf.extend_from_slice(&value.to_le_bytes());
426}
427
428/// Append a fixed-64 field to `buf`.
429pub fn write_fixed64_field(field_number: u32, value: u64, buf: &mut Vec<u8>) {
430    write_tag(field_number, WT_I64, buf);
431    buf.extend_from_slice(&value.to_le_bytes());
432}
433
434// ── Unit tests ────────────────────────────────────────────────────────────────
435
#[cfg(test)]
mod tests {
    use super::*;

    // ── varint round-trips ────────────────────────────────────────────────────

    #[test]
    fn varint_zero() {
        let buf = [0x00u8];
        let r = parse_varint(&buf, 0);
        assert_eq!(r.varint, Some(0));
        assert_eq!(r.varint_ohb, None);
        assert_eq!(r.next_pos, 1);
    }

    #[test]
    fn varint_one_byte() {
        let buf = [0x01u8];
        let r = parse_varint(&buf, 0);
        assert_eq!(r.varint, Some(1));
        assert_eq!(r.next_pos, 1);
    }

    #[test]
    fn varint_150() {
        // 150 = 0x96 0x01
        let buf = [0x96u8, 0x01];
        let r = parse_varint(&buf, 0);
        assert_eq!(r.varint, Some(150));
        assert_eq!(r.next_pos, 2);
        assert_eq!(r.varint_ohb, None);
    }

    #[test]
    fn varint_nonzero_start() {
        // Parsing starts mid-buffer; next_pos is an absolute position.
        let buf = [0xAAu8, 0x96, 0x01];
        let r = parse_varint(&buf, 1);
        assert_eq!(r.varint, Some(150));
        assert_eq!(r.next_pos, 3);
    }

    #[test]
    fn varint_max_u64() {
        // max u64: nine 0xFF bytes followed by 0x01 (10 bytes in total)
        let buf = [0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x01];
        let r = parse_varint(&buf, 0);
        assert_eq!(r.varint, Some(u64::MAX));
        assert_eq!(r.next_pos, 10);
    }

    #[test]
    fn varint_too_big_consumes_rest_of_buffer() {
        // The 10th byte carries payload 2, so the value is >= 2^64.  As in
        // Python, the garbage is the rest of the buffer, trailing 0x05 included.
        let buf = [0xffu8, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x02, 0x05];
        let r = parse_varint(&buf, 0);
        assert!(r.varint.is_none());
        assert_eq!(r.varint_gar, Some(buf.to_vec()));
        assert_eq!(r.next_pos, buf.len());
    }

    #[test]
    fn varint_truncated() {
        // Continuation byte with no terminator
        let buf = [0x80u8, 0x80];
        let r = parse_varint(&buf, 0);
        assert!(r.varint_gar.is_some());
        assert!(r.varint.is_none());
    }

    #[test]
    fn varint_empty_at_end() {
        let buf = [0x01u8];
        let r = parse_varint(&buf, 1); // start == buflen
        assert!(r.varint_gar.is_some());
        assert_eq!(r.varint_gar.unwrap(), Vec::<u8>::new());
    }

    #[test]
    fn varint_overhang_one() {
        // 0x00 encoded non-canonically as 0x80 0x00
        let buf = [0x80u8, 0x00];
        let r = parse_varint(&buf, 0);
        assert_eq!(r.varint, Some(0));
        assert_eq!(r.varint_ohb, Some(1));
    }

    #[test]
    fn varint_overhang_two() {
        // 0x00 encoded as 0x80 0x80 0x00  (2 overhang bytes)
        let buf = [0x80u8, 0x80, 0x00];
        let r = parse_varint(&buf, 0);
        assert_eq!(r.varint, Some(0));
        assert_eq!(r.varint_ohb, Some(2));
    }

    #[test]
    fn varint_overhang_roundtrip() {
        // Encoding with an overhang count and parsing back must recover both
        // the value and the count, including past the 10-byte boundary.
        for val in [0u64, 1, 127, 128, 300, u64::MAX] {
            for ohb in 1u64..=3 {
                let encoded = encode_varint_bytes(val, Some(ohb));
                let r = parse_varint(&encoded, 0);
                assert_eq!(r.varint, Some(val), "value for {val} / {ohb}");
                assert_eq!(r.varint_ohb, Some(ohb), "ohb for {val} / {ohb}");
                assert_eq!(r.next_pos, encoded.len());
            }
        }
    }

    #[test]
    fn varint_encode_with_overhang() {
        let bytes = encode_varint_bytes(0, Some(1));
        assert_eq!(bytes, vec![0x80, 0x00]);

        let bytes2 = encode_varint_bytes(0, Some(2));
        assert_eq!(bytes2, vec![0x80, 0x80, 0x00]);

        let bytes3 = encode_varint_bytes(150, None);
        assert_eq!(bytes3, vec![0x96, 0x01]);
    }

    #[test]
    fn write_varint_ohb_appends_to_existing_buffer() {
        let mut out = vec![0xAAu8];
        write_varint_ohb(1, Some(1), &mut out);
        assert_eq!(out, vec![0xAA, 0x81, 0x00]);
        // Some(0) behaves like None: canonical encoding.
        assert_eq!(encode_varint_bytes(5, Some(0)), vec![0x05]);
    }

    #[test]
    fn varint_encode_roundtrip() {
        for val in [0u64, 1, 127, 128, 300, 16383, 16384, u64::MAX] {
            let encoded = encode_varint_bytes(val, None);
            let r = parse_varint(&encoded, 0);
            assert_eq!(r.varint, Some(val), "roundtrip failed for {val}");
            assert_eq!(r.next_pos, encoded.len());
        }
    }

    // ── wiretag ──────────────────────────────────────────────────────────────

    #[test]
    fn wiretag_field1_varint() {
        // tag for field 1, wire type 0: (1 << 3) | 0 = 0x08
        let buf = [0x08u8];
        let r = parse_wiretag(&buf, 0);
        assert_eq!(r.wtype, Some(0));
        assert_eq!(r.wfield, Some(1));
        assert_eq!(r.wfield_ohb, None);
        assert_eq!(r.wfield_oor, None);
    }

    #[test]
    fn wiretag_invalid_wire_type() {
        // Wire type 6 is invalid; the whole remainder becomes garbage.
        let buf = [0x06u8, 0x00, 0x01];
        let r = parse_wiretag(&buf, 0);
        assert!(r.wtag_gar.is_some());
        assert!(r.wtype.is_none());
        assert_eq!(r.wtag_gar, Some(buf.to_vec()));
        assert_eq!(r.next_pos, buf.len());
    }

    #[test]
    fn wiretag_field_number_zero_is_oor() {
        // Tag byte 0x00 → wire type 0, field number 0.  parse_wiretag
        // requires start < buflen, so the buffer holds exactly that byte.
        let buf = [0x00u8];
        let r = parse_wiretag(&buf, 0);
        assert_eq!(r.wfield, Some(0));
        assert_eq!(r.wfield_oor, Some(true));
    }

    #[test]
    fn wiretag_large_field_number_is_oor() {
        // 2^29 - 1 is the largest valid field number; 2^29 is out of range.
        let max_ok = encode_varint_bytes(((1u64 << 29) - 1) << 3, None);
        let r = parse_wiretag(&max_ok, 0);
        assert_eq!(r.wfield, Some((1 << 29) - 1));
        assert_eq!(r.wfield_oor, None);

        let too_big = encode_varint_bytes(1u64 << 32, None);
        let r = parse_wiretag(&too_big, 0);
        assert_eq!(r.wtype, Some(0));
        assert_eq!(r.wfield, Some(1 << 29));
        assert_eq!(r.wfield_oor, Some(true));
    }

    #[test]
    fn wiretag_truncated() {
        // Continuation bit set but no further bytes.
        let buf = [0x88u8];
        let r = parse_wiretag(&buf, 0);
        assert_eq!(r.wtag_gar, Some(vec![0x88]));
        assert!(r.wtype.is_none());
        assert_eq!(r.next_pos, 1);
    }

    #[test]
    fn wiretag_overhung() {
        // Field 1, wire type 0 encoded non-canonically: (0x08) as 0x88 0x00
        let buf = [0x88u8, 0x00];
        let r = parse_wiretag(&buf, 0);
        assert_eq!(r.wtype, Some(0));
        assert_eq!(r.wfield, Some(1));
        assert_eq!(r.wfield_ohb, Some(1));
    }

    // ── numeric codecs ────────────────────────────────────────────────────────

    #[test]
    fn int32_negative() {
        // -1 as int32 is stored as 0xFFFFFFFF in a varint
        assert_eq!(decode_int32(0xFFFFFFFF), -1i32);
    }

    #[test]
    fn int64_negative() {
        assert_eq!(decode_int64(u64::MAX), -1i64);
    }

    #[test]
    fn sint32_roundtrip() {
        for v in [-1i32, 0, 1, -2, 2, i32::MIN, i32::MAX] {
            let encoded = if v >= 0 {
                ((v as u32) << 1) as u64
            } else {
                ((!v as u32) * 2 + 1) as u64
            };
            assert_eq!(decode_sint32(encoded), v, "sint32 roundtrip for {v}");
        }
    }

    #[test]
    fn sint64_roundtrip() {
        for v in [-1i64, 0, 1, -2, 2, i64::MIN, i64::MAX] {
            let encoded = if v >= 0 {
                (v as u64) << 1
            } else {
                ((!v as u64) << 1) | 1
            };
            assert_eq!(decode_sint64(encoded), v, "sint64 roundtrip for {v}");
        }
    }

    #[test]
    fn zigzag_known_values() {
        assert_eq!(decode_sint32(1), -1);
        assert_eq!(decode_sint32(2), 1);
        assert_eq!(decode_sint32(3), -2);
        assert_eq!(decode_sint32(0xFFFF_FFFE), i32::MAX);
        assert_eq!(decode_sint32(0xFFFF_FFFF), i32::MIN);
        assert_eq!(decode_sint64(u64::MAX), i64::MIN);
    }

    #[test]
    fn bool_nonzero_is_true() {
        assert!(!decode_bool(0));
        assert!(decode_bool(1));
        assert!(decode_bool(2));
    }

    #[test]
    fn fixed32_little_endian() {
        let data = [0x01u8, 0x00, 0x00, 0x00];
        assert_eq!(decode_fixed32(&data), 1u32);
        let data2 = [0xFFu8, 0xFF, 0xFF, 0xFF];
        assert_eq!(decode_fixed32(&data2), u32::MAX);
    }

    #[test]
    fn fixed64_little_endian() {
        let data = [0x01u8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];
        assert_eq!(decode_fixed64(&data), 1u64);
    }

    #[test]
    fn double_roundtrip() {
        let val = std::f64::consts::PI;
        let data = val.to_le_bytes();
        assert_eq!(decode_double(&data), val);
    }

    #[test]
    fn float_roundtrip() {
        let val = 1.5f32;
        let data = val.to_le_bytes();
        assert_eq!(decode_float(&data), val);
    }

    // ── wire-encoding helpers ─────────────────────────────────────────────────

    #[test]
    fn write_varint_field_roundtrip() {
        let mut buf = Vec::new();
        write_varint_field(1, 300, &mut buf);
        // tag: (1<<3)|0 = 0x08; value 300 = 0xAC 0x02
        assert_eq!(buf, vec![0x08, 0xAC, 0x02]);
    }

    #[test]
    fn write_len_field_roundtrip() {
        let mut buf = Vec::new();
        write_len_field(2, b"hi", &mut buf);
        // tag: (2<<3)|2 = 0x12; length: 0x02; data: 0x68 0x69
        assert_eq!(buf, vec![0x12, 0x02, 0x68, 0x69]);
    }

    #[test]
    fn write_fixed_and_optional_fields() {
        let mut buf = Vec::new();
        write_fixed32_field(1, 1, &mut buf);
        // tag: (1<<3)|5 = 0x0D; then 4 little-endian bytes
        assert_eq!(buf, vec![0x0D, 0x01, 0x00, 0x00, 0x00]);

        buf.clear();
        write_fixed64_field(2, 1, &mut buf);
        // tag: (2<<3)|1 = 0x11; then 8 little-endian bytes
        assert_eq!(buf, vec![0x11, 0x01, 0, 0, 0, 0, 0, 0, 0]);

        buf.clear();
        write_bool_field(3, false, &mut buf);
        write_opt_varint_field(4, None, &mut buf);
        assert!(buf.is_empty());

        write_bool_field(3, true, &mut buf);
        write_opt_varint_field(4, Some(150), &mut buf);
        // 0x18 = (3<<3)|0, 0x20 = (4<<3)|0
        assert_eq!(buf, vec![0x18, 0x01, 0x20, 0x96, 0x01]);
    }
}