Skip to main content

prototext_core/decoder/
mod.rs

1// SPDX-FileCopyrightText: 2025-2026 Frederic Ruget <fred@atlant.is> (GitHub: @douzebis)
2// SPDX-FileCopyrightText: 2025-2026 THALES CLOUD SECURISE SAS
3//
4// SPDX-License-Identifier: MIT
5
6mod codec;
7mod packed;
8mod types;
9
10pub use types::*;
11
12use prost_reflect::{FieldDescriptor, Kind, MessageDescriptor};
13
14use crate::helpers::{
15    decode_double, decode_fixed32, decode_fixed64, decode_float, decode_sfixed32, decode_sfixed64,
16    parse_varint, parse_wiretag, WiretagResult, WT_END_GROUP, WT_I32, WT_I64, WT_LEN,
17    WT_START_GROUP, WT_VARINT,
18};
19use crate::schema::ParsedSchema;
20
21use codec::{decode_varint_by_kind, format_annotation, TypeMismatch};
22use packed::decode_len_field;
23
24// ── Public entry point ────────────────────────────────────────────────────────
25
26/// Decode a binary protobuf payload into a lossless `ProtoTextMessage`.
27///
28/// Mirrors `decode_pb()` in `decode.py`.
29pub fn ingest_pb(
30    pb_bytes: &[u8],
31    full_schema: &ParsedSchema,
32    annotations: bool,
33) -> ProtoTextMessage {
34    let root = full_schema.root_descriptor();
35    let (msg, _, _, _) = parse_message(pb_bytes, 0, None, root.as_ref(), full_schema, annotations);
36    msg
37}
38
39// ── Core recursive parser ─────────────────────────────────────────────────────
40
41/// Parse one protobuf message starting at `start`.
42///
43/// Returns `(message, next_pos, group_end_tag, malformities)`.
44/// `group_end_tag` is `Some(tag)` when the parse terminated on an END_GROUP
45/// wire type — the caller uses this to detect mismatched group numbers.
46/// `malformities` counts structurally invalid fields (invalid tag, truncated
47/// data, etc.).  Non-canonical bytes do not count.  Used by `decode_len_field`
48/// to decide whether a LEN payload is a valid nested message (spec 0097).
49///
50/// Mirrors `parse_message()` in `decode.py`.
51pub fn parse_message(
52    buf: &[u8],
53    start: usize,
54    my_group: Option<u64>, // Some(field_number) when inside a group
55    schema: Option<&MessageDescriptor>,
56    full_schema: &ParsedSchema, // full registry for nested type lookups
57    annotations: bool,
58) -> (ProtoTextMessage, usize, Option<WiretagResult>, u32) {
59    let buflen = buf.len();
60    let mut pos = start;
61    let mut message = ProtoTextMessage::default();
62    let mut malformities: u32 = 0;
63
64    loop {
65        if pos == buflen {
66            return (message, pos, None, malformities);
67        }
68
69        let mut field = ProtoTextField::default();
70
71        // ── Parse wire tag ────────────────────────────────────────────────────
72
73        let tag = parse_wiretag(buf, pos);
74
75        if let Some(wtag_gar) = tag.wtag_gar {
76            // Invalid wire type: consume rest of buffer
77            if annotations {
78                field.annotations.push("invalid field".to_string());
79            }
80            field.field_number = Some(0);
81            field.content = ProtoTextContent::InvalidTagType(wtag_gar);
82            pos = buflen;
83            malformities += 1;
84            message.fields.push(field);
85            continue;
86        }
87
88        let field_number = tag.wfield.unwrap();
89        let wire_type = tag.wtype.unwrap();
90        field.field_number = Some(field_number);
91        if let Some(ohb) = tag.wfield_ohb {
92            field.tag_overhang_count = Some(ohb);
93        }
94        if tag.wfield_oor.is_some() {
95            field.tag_is_out_of_range = true;
96        }
97        pos = tag.next_pos;
98
99        // ── Schema lookup ─────────────────────────────────────────────────────
100
101        let field_schema: Option<FieldDescriptor> =
102            schema.and_then(|s| s.get_field(field_number as u32));
103
104        if annotations {
105            if schema.is_none() {
106                field.annotations.push("no schema".to_string());
107            } else if let Some(ref fs) = field_schema {
108                field.annotations.push(format_annotation(fs));
109            } else {
110                field.annotations.push("unknown field".to_string());
111            }
112        }
113
114        // ── Wire-type dispatch ────────────────────────────────────────────────
115
116        match wire_type {
117            // ── VARINT ───────────────────────────────────────────────────────
118            WT_VARINT => {
119                let vr = parse_varint(buf, pos);
120                if let Some(varint_gar) = vr.varint_gar {
121                    field.content = ProtoTextContent::InvalidVarint(varint_gar);
122                    pos = buflen;
123                    malformities += 1;
124                    message.fields.push(field);
125                    continue;
126                }
127                pos = vr.next_pos;
128                if let Some(ohb) = vr.varint_ohb {
129                    field.value_overhang_count = Some(ohb);
130                }
131                let val = vr.varint.unwrap();
132
133                if let Some(ref fs) = field_schema {
134                    match decode_varint_by_kind(val, fs.kind()) {
135                        Ok(content) => field.content = content,
136                        Err(TypeMismatch) => {
137                            field.proto2_has_type_mismatch = true;
138                            field.content = ProtoTextContent::WireVarint(val);
139                        }
140                    }
141                } else {
142                    field.content = ProtoTextContent::WireVarint(val);
143                }
144            }
145
146            // ── FIXED64 ──────────────────────────────────────────────────────
147            WT_I64 => {
148                if pos + 8 > buflen {
149                    field.content = ProtoTextContent::InvalidFixed64(buf[pos..].to_vec());
150                    pos = buflen;
151                    malformities += 1;
152                    message.fields.push(field);
153                    continue;
154                }
155                let data = &buf[pos..pos + 8];
156                pos += 8;
157
158                if let Some(ref fs) = field_schema {
159                    match fs.kind() {
160                        Kind::Double => {
161                            field.content = ProtoTextContent::Double(decode_double(data));
162                        }
163                        Kind::Fixed64 => {
164                            field.content = ProtoTextContent::PFixed64(decode_fixed64(data));
165                        }
166                        Kind::Sfixed64 => {
167                            field.content = ProtoTextContent::Sfixed64(decode_sfixed64(data));
168                        }
169                        _ => {
170                            field.proto2_has_type_mismatch = true;
171                            field.content = ProtoTextContent::WireFixed64(decode_fixed64(data));
172                        }
173                    }
174                } else {
175                    field.content = ProtoTextContent::WireFixed64(decode_fixed64(data));
176                }
177            }
178
179            // ── LENGTH-DELIMITED ─────────────────────────────────────────────
180            WT_LEN => {
181                let lr = parse_varint(buf, pos);
182                if lr.varint_gar.is_some() {
183                    field.content = ProtoTextContent::InvalidBytesLength(buf[pos..].to_vec());
184                    pos = buflen;
185                    malformities += 1;
186                    message.fields.push(field);
187                    continue;
188                }
189                pos = lr.next_pos;
190                if let Some(ohb) = lr.varint_ohb {
191                    field.length_overhang_count = Some(ohb);
192                }
193                let length = lr.varint.unwrap() as usize;
194
195                if pos + length > buflen {
196                    field.missing_bytes_count = Some((length - (buflen - pos)) as u64);
197                    field.content = ProtoTextContent::TruncatedBytes(buf[pos..].to_vec());
198                    pos = buflen;
199                    malformities += 1;
200                    message.fields.push(field);
201                    continue;
202                }
203                let data = &buf[pos..pos + length];
204                pos += length;
205
206                decode_len_field(
207                    data,
208                    field_schema.as_ref(),
209                    full_schema,
210                    annotations,
211                    &mut field,
212                );
213            }
214
215            // ── START GROUP ──────────────────────────────────────────────────
216            WT_START_GROUP => {
217                // Resolve nested schema via the full registry.
218                let nested_desc: Option<MessageDescriptor> = field_schema
219                    .as_ref()
220                    .filter(|fs| fs.is_group())
221                    .and_then(|fs| {
222                        if let Kind::Message(msg_desc) = fs.kind() {
223                            Some(msg_desc)
224                        } else {
225                            None
226                        }
227                    });
228
229                let (nested_msg, new_pos, end_tag, _) = parse_message(
230                    buf,
231                    pos,
232                    Some(field_number),
233                    nested_desc.as_ref(),
234                    full_schema,
235                    annotations,
236                );
237                pos = new_pos;
238
239                if end_tag.is_none() {
240                    field.open_ended_group = true;
241                    malformities += 1;
242                } else if let Some(ref et) = end_tag {
243                    if let Some(ohb) = et.wfield_ohb {
244                        field.end_tag_overhang_count = Some(ohb);
245                    }
246                    if et.wfield_oor.is_some() {
247                        field.end_tag_is_out_of_range = true;
248                    }
249                    let end_field = et.wfield.unwrap_or(0);
250                    if end_field != field_number {
251                        field.mismatched_group_end = Some(end_field);
252                    }
253                }
254                // Always store as proto2-level group (field 40), matching Python.
255                field.content = ProtoTextContent::Group(Box::new(nested_msg));
256            }
257
258            // ── END GROUP ────────────────────────────────────────────────────
259            WT_END_GROUP => {
260                if my_group.is_none() {
261                    // Unexpected END_GROUP outside a group
262                    field.content = ProtoTextContent::InvalidGroupEnd(buf[pos..].to_vec());
263                    pos = buflen;
264                    malformities += 1;
265                    message.fields.push(field);
266                    continue;
267                }
268                // Valid END_GROUP: return WITHOUT pushing the tag as a field.
269                return (message, pos, Some(tag), malformities);
270            }
271
272            // ── FIXED32 ──────────────────────────────────────────────────────
273            WT_I32 => {
274                if pos + 4 > buflen {
275                    field.content = ProtoTextContent::InvalidFixed32(buf[pos..].to_vec());
276                    pos = buflen;
277                    malformities += 1;
278                    message.fields.push(field);
279                    continue;
280                }
281                let data = &buf[pos..pos + 4];
282                pos += 4;
283
284                if let Some(ref fs) = field_schema {
285                    match fs.kind() {
286                        Kind::Float => {
287                            field.content = ProtoTextContent::Float(decode_float(data));
288                        }
289                        Kind::Fixed32 => {
290                            field.content = ProtoTextContent::PFixed32(decode_fixed32(data));
291                        }
292                        Kind::Sfixed32 => {
293                            field.content = ProtoTextContent::Sfixed32(decode_sfixed32(data));
294                        }
295                        _ => {
296                            field.proto2_has_type_mismatch = true;
297                            // Fallback: Python uses field.fixed32 (proto2-level, field 37)
298                            field.content = ProtoTextContent::PFixed32(decode_fixed32(data));
299                        }
300                    }
301                } else {
302                    // No schema fallback: Python uses field.fixed32 (proto2-level, field 37)
303                    field.content = ProtoTextContent::PFixed32(decode_fixed32(data));
304                }
305            }
306
307            _ => {
308                // Wire types 0–5 are the only valid ones; wiretag parser rejects >5.
309                unreachable!("wire type > 5 should have been caught by parse_wiretag");
310            }
311        }
312
313        message.fields.push(field);
314    }
315}