Skip to main content

prototext_core/decoder/
mod.rs

1// SPDX-FileCopyrightText: 2025 - 2026 Frederic Ruget <fred@atlant.is> <fred@s3ns.io> (GitHub: @douzebis)
2// SPDX-FileCopyrightText: 2025 - 2026 Thales Cloud Sécurisé
3//
4// SPDX-License-Identifier: MIT
5
6mod codec;
7mod packed;
8mod types;
9
10pub use types::*;
11
12use prost_reflect::{FieldDescriptor, Kind, MessageDescriptor};
13
14use crate::helpers::{
15    decode_double, decode_fixed32, decode_fixed64, decode_float, decode_sfixed32, decode_sfixed64,
16    parse_varint, parse_wiretag, WiretagResult, WT_END_GROUP, WT_I32, WT_I64, WT_LEN,
17    WT_START_GROUP, WT_VARINT,
18};
19use crate::schema::ParsedSchema;
20
21use codec::{decode_varint_by_kind, format_annotation, TypeMismatch};
22use packed::decode_len_field;
23
24// ── Public entry point ────────────────────────────────────────────────────────
25
26/// Decode a binary protobuf payload into a lossless `ProtoTextMessage`.
27///
28/// Mirrors `decode_pb()` in `decode.py`.
29pub fn ingest_pb(
30    pb_bytes: &[u8],
31    full_schema: &ParsedSchema,
32    annotations: bool,
33) -> ProtoTextMessage {
34    let root = full_schema.root_descriptor();
35    let (msg, _, _) = parse_message(pb_bytes, 0, None, root.as_ref(), full_schema, annotations);
36    msg
37}
38
39// ── Core recursive parser ─────────────────────────────────────────────────────
40
41/// Parse one protobuf message starting at `start`.
42///
43/// Returns `(message, next_pos, group_end_tag)`.
44/// `group_end_tag` is `Some(tag)` when the parse terminated on an END_GROUP
45/// wire type — the caller uses this to detect mismatched group numbers.
46///
47/// Mirrors `parse_message()` in `decode.py`.
48pub fn parse_message(
49    buf: &[u8],
50    start: usize,
51    my_group: Option<u64>, // Some(field_number) when inside a group
52    schema: Option<&MessageDescriptor>,
53    full_schema: &ParsedSchema, // full registry for nested type lookups
54    annotations: bool,
55) -> (ProtoTextMessage, usize, Option<WiretagResult>) {
56    let buflen = buf.len();
57    let mut pos = start;
58    let mut message = ProtoTextMessage::default();
59
60    loop {
61        if pos == buflen {
62            return (message, pos, None);
63        }
64
65        let mut field = ProtoTextField::default();
66
67        // ── Parse wire tag ────────────────────────────────────────────────────
68
69        let tag = parse_wiretag(buf, pos);
70
71        if let Some(wtag_gar) = tag.wtag_gar {
72            // Invalid wire type: consume rest of buffer
73            if annotations {
74                field.annotations.push("invalid field".to_string());
75            }
76            field.field_number = Some(0);
77            field.content = ProtoTextContent::InvalidTagType(wtag_gar);
78            pos = buflen;
79            message.fields.push(field);
80            continue;
81        }
82
83        let field_number = tag.wfield.unwrap();
84        let wire_type = tag.wtype.unwrap();
85        field.field_number = Some(field_number);
86        if let Some(ohb) = tag.wfield_ohb {
87            field.tag_overhang_count = Some(ohb);
88        }
89        if tag.wfield_oor.is_some() {
90            field.tag_is_out_of_range = true;
91        }
92        pos = tag.next_pos;
93
94        // ── Schema lookup ─────────────────────────────────────────────────────
95
96        let field_schema: Option<FieldDescriptor> =
97            schema.and_then(|s| s.get_field(field_number as u32));
98
99        if annotations {
100            if schema.is_none() {
101                field.annotations.push("no schema".to_string());
102            } else if let Some(ref fs) = field_schema {
103                field.annotations.push(format_annotation(fs));
104            } else {
105                field.annotations.push("unknown field".to_string());
106            }
107        }
108
109        // ── Wire-type dispatch ────────────────────────────────────────────────
110
111        match wire_type {
112            // ── VARINT ───────────────────────────────────────────────────────
113            WT_VARINT => {
114                let vr = parse_varint(buf, pos);
115                if let Some(varint_gar) = vr.varint_gar {
116                    field.content = ProtoTextContent::InvalidVarint(varint_gar);
117                    pos = buflen;
118                    message.fields.push(field);
119                    continue;
120                }
121                pos = vr.next_pos;
122                if let Some(ohb) = vr.varint_ohb {
123                    field.value_overhang_count = Some(ohb);
124                }
125                let val = vr.varint.unwrap();
126
127                if let Some(ref fs) = field_schema {
128                    match decode_varint_by_kind(val, fs.kind()) {
129                        Ok(content) => field.content = content,
130                        Err(TypeMismatch) => {
131                            field.proto2_has_type_mismatch = true;
132                            field.content = ProtoTextContent::WireVarint(val);
133                        }
134                    }
135                } else {
136                    field.content = ProtoTextContent::WireVarint(val);
137                }
138            }
139
140            // ── FIXED64 ──────────────────────────────────────────────────────
141            WT_I64 => {
142                if pos + 8 > buflen {
143                    field.content = ProtoTextContent::InvalidFixed64(buf[pos..].to_vec());
144                    pos = buflen;
145                    message.fields.push(field);
146                    continue;
147                }
148                let data = &buf[pos..pos + 8];
149                pos += 8;
150
151                if let Some(ref fs) = field_schema {
152                    field.content = match fs.kind() {
153                        Kind::Double => ProtoTextContent::Double(decode_double(data)),
154                        Kind::Fixed64 => ProtoTextContent::PFixed64(decode_fixed64(data)),
155                        Kind::Sfixed64 => ProtoTextContent::Sfixed64(decode_sfixed64(data)),
156                        _ => ProtoTextContent::WireFixed64(decode_fixed64(data)),
157                    };
158                } else {
159                    field.content = ProtoTextContent::WireFixed64(decode_fixed64(data));
160                }
161            }
162
163            // ── LENGTH-DELIMITED ─────────────────────────────────────────────
164            WT_LEN => {
165                let lr = parse_varint(buf, pos);
166                if lr.varint_gar.is_some() {
167                    field.content = ProtoTextContent::InvalidBytesLength(buf[pos..].to_vec());
168                    pos = buflen;
169                    message.fields.push(field);
170                    continue;
171                }
172                pos = lr.next_pos;
173                if let Some(ohb) = lr.varint_ohb {
174                    field.length_overhang_count = Some(ohb);
175                }
176                let length = lr.varint.unwrap() as usize;
177
178                if pos + length > buflen {
179                    field.missing_bytes_count = Some((length - (buflen - pos)) as u64);
180                    field.content = ProtoTextContent::TruncatedBytes(buf[pos..].to_vec());
181                    pos = buflen;
182                    message.fields.push(field);
183                    continue;
184                }
185                let data = &buf[pos..pos + length];
186                pos += length;
187
188                decode_len_field(
189                    data,
190                    field_schema.as_ref(),
191                    full_schema,
192                    annotations,
193                    &mut field,
194                );
195            }
196
197            // ── START GROUP ──────────────────────────────────────────────────
198            WT_START_GROUP => {
199                // Resolve nested schema via the full registry.
200                let nested_desc: Option<MessageDescriptor> = field_schema
201                    .as_ref()
202                    .filter(|fs| fs.is_group())
203                    .and_then(|fs| {
204                        if let Kind::Message(msg_desc) = fs.kind() {
205                            Some(msg_desc)
206                        } else {
207                            None
208                        }
209                    });
210
211                let (nested_msg, new_pos, end_tag) = parse_message(
212                    buf,
213                    pos,
214                    Some(field_number),
215                    nested_desc.as_ref(),
216                    full_schema,
217                    annotations,
218                );
219                pos = new_pos;
220
221                if end_tag.is_none() {
222                    field.open_ended_group = true;
223                } else if let Some(ref et) = end_tag {
224                    if let Some(ohb) = et.wfield_ohb {
225                        field.end_tag_overhang_count = Some(ohb);
226                    }
227                    if et.wfield_oor.is_some() {
228                        field.end_tag_is_out_of_range = true;
229                    }
230                    let end_field = et.wfield.unwrap_or(0);
231                    if end_field != field_number {
232                        field.mismatched_group_end = Some(end_field);
233                    }
234                }
235                // Always store as proto2-level group (field 40), matching Python.
236                field.content = ProtoTextContent::Group(Box::new(nested_msg));
237            }
238
239            // ── END GROUP ────────────────────────────────────────────────────
240            WT_END_GROUP => {
241                if my_group.is_none() {
242                    // Unexpected END_GROUP outside a group
243                    field.content = ProtoTextContent::InvalidGroupEnd(buf[pos..].to_vec());
244                    pos = buflen;
245                    message.fields.push(field);
246                    continue;
247                }
248                // Valid END_GROUP: return WITHOUT pushing the tag as a field.
249                return (message, pos, Some(tag));
250            }
251
252            // ── FIXED32 ──────────────────────────────────────────────────────
253            WT_I32 => {
254                if pos + 4 > buflen {
255                    field.content = ProtoTextContent::InvalidFixed32(buf[pos..].to_vec());
256                    pos = buflen;
257                    message.fields.push(field);
258                    continue;
259                }
260                let data = &buf[pos..pos + 4];
261                pos += 4;
262
263                if let Some(ref fs) = field_schema {
264                    field.content = match fs.kind() {
265                        Kind::Float => ProtoTextContent::Float(decode_float(data)),
266                        Kind::Fixed32 => ProtoTextContent::PFixed32(decode_fixed32(data)),
267                        Kind::Sfixed32 => ProtoTextContent::Sfixed32(decode_sfixed32(data)),
268                        // Fallback: Python uses field.fixed32 (proto2-level, field 37)
269                        _ => ProtoTextContent::PFixed32(decode_fixed32(data)),
270                    };
271                } else {
272                    // No schema fallback: Python uses field.fixed32 (proto2-level, field 37)
273                    field.content = ProtoTextContent::PFixed32(decode_fixed32(data));
274                }
275            }
276
277            _ => {
278                // Wire types 0–5 are the only valid ones; wiretag parser rejects >5.
279                unreachable!("wire type > 5 should have been caught by parse_wiretag");
280            }
281        }
282
283        message.fields.push(field);
284    }
285}