Skip to main content

prototext_core/serialize/render_text/
mod.rs

1// SPDX-FileCopyrightText: 2025 - 2026 Frederic Ruget <fred@atlant.is> <fred@s3ns.io> (GitHub: @douzebis)
2// SPDX-FileCopyrightText: 2025 - 2026 Thales Cloud Sécurisé
3//
4// SPDX-License-Identifier: MIT
5
6mod helpers;
7mod packed;
8mod varint;
9
10use std::cell::Cell;
11use std::collections::HashMap;
12use std::sync::Arc;
13
14use prost_reflect::{Cardinality, ExtensionDescriptor, FieldDescriptor, Kind, MessageDescriptor};
15
16use crate::helpers::{
17    decode_double, decode_fixed32, decode_fixed64, decode_float, decode_sfixed32, decode_sfixed64,
18};
19use crate::helpers::{
20    parse_varint, parse_wiretag, WiretagResult, WT_END_GROUP, WT_I32, WT_I64, WT_LEN,
21    WT_START_GROUP, WT_VARINT,
22};
23use crate::schema::ParsedSchema;
24use crate::serialize::common::{
25    format_double_protoc, format_fixed32_protoc, format_fixed64_protoc, format_float_protoc,
26    format_sfixed32_protoc, format_sfixed64_protoc, format_wire_fixed32_protoc,
27    format_wire_fixed64_protoc,
28};
29
30use helpers::{
31    render_group_field, render_invalid, render_invalid_tag_type, render_len_field, render_scalar,
32    render_truncated_bytes, ScalarCtx,
33};
34use varint::{decode_varint_typed, render_varint_field, VarintKind};
35
36// Magic prefix that identifies a textual prototext payload.
37const PROTOTEXT_MAGIC: &[u8] = b"#@ prototext:";
38
39// ── FieldOrExt adapter ────────────────────────────────────────────────────────
40
41/// Unifies `FieldDescriptor` (regular field) and `ExtensionDescriptor`
42/// (extension field) for the subset of accessors used by the renderer.
43pub(super) enum FieldOrExt {
44    Field(FieldDescriptor),
45    Ext(ExtensionDescriptor),
46}
47
48impl FieldOrExt {
49    pub(super) fn kind(&self) -> Kind {
50        match self {
51            FieldOrExt::Field(f) => f.kind(),
52            FieldOrExt::Ext(e) => e.kind(),
53        }
54    }
55
56    pub(super) fn cardinality(&self) -> Cardinality {
57        match self {
58            FieldOrExt::Field(f) => f.cardinality(),
59            FieldOrExt::Ext(e) => e.cardinality(),
60        }
61    }
62
63    /// Returns `true` only for regular group fields; extensions cannot be groups.
64    pub(super) fn is_group(&self) -> bool {
65        match self {
66            FieldOrExt::Field(f) => f.is_group(),
67            FieldOrExt::Ext(_) => false,
68        }
69    }
70
71    pub(super) fn is_packed(&self) -> bool {
72        match self {
73            FieldOrExt::Field(f) => f.is_packed(),
74            FieldOrExt::Ext(_) => false,
75        }
76    }
77
78    /// The name to use in field-line output.
79    ///
80    /// Regular field: `"name"` (bare field name).
81    /// Extension field: `"[full.qualified.name]"`.
82    pub(super) fn display_name(&self) -> String {
83        match self {
84            FieldOrExt::Field(f) => f.name().to_owned(),
85            FieldOrExt::Ext(e) => format!("[{}]", e.full_name()),
86        }
87    }
88
89    /// Returns the underlying `FieldDescriptor` if this is a regular field,
90    /// or `None` for extension fields.
91    ///
92    /// Used to pass to functions that still take `Option<&FieldDescriptor>`.
93    #[allow(dead_code)]
94    pub(super) fn as_field(&self) -> Option<&FieldDescriptor> {
95        match self {
96            FieldOrExt::Field(f) => Some(f),
97            FieldOrExt::Ext(_) => None,
98        }
99    }
100}
101
102// ── Render-mode state ─────────────────────────────────────────────────────────
103//
104// `CBL_START` is set to `out.len()` by `write_close_brace` before writing a
105// `}\n` line, and reset to `out.len()` (past-end) by every other write.  It
106// is currently unused beyond being maintained; the close-brace folding feature
107// it was intended to support has been removed.
108//
// Per-thread render state. All four cells are overwritten at the start of
// every `decode_and_render` call, so the initial values below only matter
// before the first call on a given thread.
thread_local! {
    // Byte offset into the output buffer; `decode_and_render` initialises it
    // to `out.len()` right after writing the header line.
    pub(super) static CBL_START:    Cell<usize> = const { Cell::new(0) };
    // Set once per `decode_and_render` call from its `annotations` argument.
    pub(super) static ANNOTATIONS:  Cell<bool>  = const { Cell::new(false) };
    // Set once per `decode_and_render` call from its `indent_size` argument.
    pub(super) static INDENT_SIZE:  Cell<usize> = const { Cell::new(2) };
    // Tracks recursion depth; managed via `enter_level()` / `LevelGuard`
    // (`enter_level` adds one, dropping the guard subtracts one) and reset
    // to 0 by `decode_and_render`.
    pub(super) static LEVEL:        Cell<usize> = const { Cell::new(0) };
}
117
118/// RAII guard for `LEVEL`: increments on construction, decrements on drop.
119/// Guarantees the level is restored even if the callee panics.
120pub(super) struct LevelGuard;
121
122impl Drop for LevelGuard {
123    fn drop(&mut self) {
124        LEVEL.with(|l| l.set(l.get() - 1));
125    }
126}
127
128pub(super) fn enter_level() -> LevelGuard {
129    LEVEL.with(|l| l.set(l.get() + 1));
130    LevelGuard
131}
132
133/// Return `true` when `data` is already rendered prototext text (fast-path).
134pub fn is_prototext_text(data: &[u8]) -> bool {
135    data.starts_with(PROTOTEXT_MAGIC)
136}
137
138// ── Public entry point ────────────────────────────────────────────────────────
139
140/// Decode raw protobuf binary and render as protoc-style text in one pass.
141///
142/// Writes the `#@ prototext: protoc\n` header followed by field lines directly
143/// into a pre-allocated `Vec<u8>`.
144///
145/// Parameters mirror `format_as_text` in `lib.rs`.
146pub fn decode_and_render(
147    buf: &[u8],
148    schema: Option<&ParsedSchema>,
149    annotations: bool,
150    indent_size: usize,
151) -> Vec<u8> {
152    let capacity = buf.len() * 8;
153    let mut out = Vec::with_capacity(capacity);
154
155    // Header
156    out.extend_from_slice(b"#@ prototext: protoc\n");
157    // Initialise render-mode state.
158    // CBL_START past the end so the first write_close_brace always takes
159    // the fresh-write path.
160    CBL_START.with(|c| c.set(out.len()));
161    ANNOTATIONS.with(|c| c.set(annotations));
162    INDENT_SIZE.with(|c| c.set(indent_size));
163    LEVEL.with(|c| c.set(0));
164
165    // Build a flat name→MessageDescriptor map for nested-type lookups.
166    // Keyed by bare FQN (no leading dot), matching prost-reflect's convention.
167    let all_descriptors: Option<HashMap<String, Arc<MessageDescriptor>>> =
168        schema.map(|s| build_descriptor_map(s));
169    let all_schemas = all_descriptors.as_ref();
170
171    let root_desc: Option<MessageDescriptor> = schema.and_then(|s| s.root_descriptor());
172
173    render_message(buf, 0, None, root_desc.as_ref(), all_schemas, &mut out);
174
175    // Development instrumentation — truncate event
176    #[cfg(debug_assertions)]
177    {
178        let actual = out.len();
179        if actual < capacity {
180            eprintln!(
181                "[render_text] truncate: input_len={} capacity={} actual={} ratio={:.2}x",
182                buf.len(),
183                capacity,
184                actual,
185                actual as f64 / buf.len().max(1) as f64
186            );
187        }
188    }
189
190    out
191}
192
193/// Build a `HashMap<bare_fqn, Arc<MessageDescriptor>>` from a `ParsedSchema`.
194fn build_descriptor_map(schema: &ParsedSchema) -> HashMap<String, Arc<MessageDescriptor>> {
195    schema
196        .pool()
197        .all_messages()
198        .map(|msg| (msg.full_name().to_string(), Arc::new(msg)))
199        .collect()
200}
201
202// ── Core recursive render-while-decode ───────────────────────────────────────
203
204/// Parse and render one protobuf message into `out`.
205///
206/// Returns `(next_pos, group_end_tag)`:
207/// - `next_pos`: byte position after this message (for the caller to
208///   continue its own parse loop, or for GROUP end detection).
209/// - `group_end_tag`: `Some(tag)` when parsing terminated on a `WT_END_GROUP`.
210pub(super) fn render_message(
211    buf: &[u8],
212    start: usize,
213    my_group: Option<u64>,
214    schema: Option<&MessageDescriptor>,
215    all_schemas: Option<&HashMap<String, Arc<MessageDescriptor>>>,
216    out: &mut Vec<u8>,
217) -> (usize, Option<WiretagResult>) {
218    let buflen = buf.len();
219    let mut pos = start;
220
221    loop {
222        if pos == buflen {
223            return (pos, None);
224        }
225
226        // ── Parse wire tag ────────────────────────────────────────────────────
227
228        let tag = parse_wiretag(buf, pos);
229
230        if let Some(ref wtag_gar) = tag.wtag_gar {
231            // Invalid wire tag: consume rest of buffer as INVALID_TAG_TYPE
232            render_invalid_tag_type(wtag_gar, out);
233            return (buflen, None);
234        }
235
236        let field_number = tag.wfield.unwrap();
237        let wire_type = tag.wtype.unwrap();
238        let tag_ohb = tag.wfield_ohb;
239        let tag_oor = tag.wfield_oor.is_some();
240        pos = tag.next_pos;
241
242        // ── Schema lookup ─────────────────────────────────────────────────────
243
244        let field_schema: Option<FieldOrExt> = schema.and_then(|s| {
245            if let Some(f) = s.get_field(field_number as u32) {
246                Some(FieldOrExt::Field(f))
247            } else {
248                s.get_extension(field_number as u32).map(FieldOrExt::Ext)
249            }
250        });
251
252        // ── Wire-type dispatch ────────────────────────────────────────────────
253
254        match wire_type {
255            // ── VARINT ───────────────────────────────────────────────────────
256            WT_VARINT => {
257                let vr = parse_varint(buf, pos);
258                if let Some(ref varint_gar) = vr.varint_gar {
259                    render_invalid(
260                        field_number,
261                        field_schema.as_ref(),
262                        tag_ohb,
263                        tag_oor,
264                        "INVALID_VARINT",
265                        varint_gar,
266                        out,
267                    );
268                    return (buflen, None);
269                }
270                pos = vr.next_pos;
271                let val_ohb = vr.varint_ohb;
272                let val = vr.varint.unwrap();
273
274                let (content_kind, typed_val) = if let Some(ref fs) = field_schema {
275                    decode_varint_typed(val, fs)
276                } else {
277                    (VarintKind::Wire, val)
278                };
279
280                render_varint_field(
281                    field_number,
282                    field_schema.as_ref(),
283                    tag_ohb,
284                    tag_oor,
285                    val_ohb,
286                    content_kind,
287                    typed_val,
288                    out,
289                );
290            }
291
292            // ── FIXED64 ──────────────────────────────────────────────────────
293            WT_I64 => {
294                if pos + 8 > buflen {
295                    let raw = &buf[pos..];
296                    render_invalid(
297                        field_number,
298                        field_schema.as_ref(),
299                        tag_ohb,
300                        tag_oor,
301                        "INVALID_FIXED64",
302                        raw,
303                        out,
304                    );
305                    return (buflen, None);
306                }
307                let data = &buf[pos..pos + 8];
308                pos += 8;
309
310                let is_mismatch;
311                let mut nan_bits: Option<u64> = None;
312                let value_str = if let Some(ref fs) = field_schema {
313                    match fs.kind() {
314                        Kind::Double => {
315                            is_mismatch = false;
316                            let v = decode_double(data);
317                            if v.is_nan() {
318                                let bits = v.to_bits();
319                                if bits != f64::NAN.to_bits() {
320                                    nan_bits = Some(bits);
321                                }
322                            }
323                            format_double_protoc(v)
324                        }
325                        Kind::Fixed64 => {
326                            is_mismatch = false;
327                            format_fixed64_protoc(decode_fixed64(data))
328                        }
329                        Kind::Sfixed64 => {
330                            is_mismatch = false;
331                            format_sfixed64_protoc(decode_sfixed64(data))
332                        }
333                        _ => {
334                            is_mismatch = true;
335                            format_wire_fixed64_protoc(decode_fixed64(data))
336                        } // mismatch → hex
337                    }
338                } else {
339                    is_mismatch = false;
340                    format_wire_fixed64_protoc(decode_fixed64(data)) // unknown → hex
341                };
342
343                render_scalar(
344                    &ScalarCtx {
345                        field_number,
346                        field_schema: field_schema.as_ref(),
347                        tag_ohb,
348                        tag_oor,
349                        len_ohb: None,
350                        wire_type_name: "fixed64",
351                        nan_bits,
352                    },
353                    &value_str,
354                    is_mismatch,
355                    out,
356                );
357            }
358
359            // ── LENGTH-DELIMITED ─────────────────────────────────────────────
360            WT_LEN => {
361                let lr = parse_varint(buf, pos);
362                if let Some(ref varint_gar) = lr.varint_gar {
363                    render_invalid(
364                        field_number,
365                        field_schema.as_ref(),
366                        tag_ohb,
367                        tag_oor,
368                        "INVALID_LEN",
369                        varint_gar,
370                        out,
371                    );
372                    return (buflen, None);
373                }
374                let len_ohb = lr.varint_ohb;
375                pos = lr.next_pos;
376                let length = lr.varint.unwrap() as usize;
377
378                if pos + length > buflen {
379                    let missing = (length - (buflen - pos)) as u64;
380                    let raw = &buf[pos..];
381                    render_truncated_bytes(
382                        field_number,
383                        tag_ohb,
384                        tag_oor,
385                        len_ohb,
386                        missing,
387                        raw,
388                        out,
389                    );
390                    return (buflen, None);
391                }
392                let data = &buf[pos..pos + length];
393                pos += length;
394
395                render_len_field(
396                    field_number,
397                    field_schema.as_ref(),
398                    all_schemas,
399                    tag_ohb,
400                    tag_oor,
401                    len_ohb,
402                    data,
403                    out,
404                );
405            }
406
407            // ── START GROUP ──────────────────────────────────────────────────
408            WT_START_GROUP => {
409                render_group_field(
410                    buf,
411                    &mut pos,
412                    field_number,
413                    field_schema.as_ref(),
414                    all_schemas,
415                    tag_ohb,
416                    tag_oor,
417                    out,
418                );
419            }
420
421            // ── END GROUP ────────────────────────────────────────────────────
422            WT_END_GROUP => {
423                if my_group.is_none() {
424                    // Unexpected END_GROUP outside a group
425                    let raw = &buf[pos..];
426                    render_invalid(
427                        field_number,
428                        field_schema.as_ref(),
429                        tag_ohb,
430                        tag_oor,
431                        "INVALID_GROUP_END",
432                        raw,
433                        out,
434                    );
435                    return (buflen, None);
436                }
437                // Valid END_GROUP: return to parent without rendering a field.
438                return (pos, Some(tag));
439            }
440
441            // ── FIXED32 ──────────────────────────────────────────────────────
442            WT_I32 => {
443                if pos + 4 > buflen {
444                    let raw = &buf[pos..];
445                    render_invalid(
446                        field_number,
447                        field_schema.as_ref(),
448                        tag_ohb,
449                        tag_oor,
450                        "INVALID_FIXED32",
451                        raw,
452                        out,
453                    );
454                    return (buflen, None);
455                }
456                let data = &buf[pos..pos + 4];
457                pos += 4;
458
459                let is_mismatch;
460                let mut nan_bits: Option<u64> = None;
461                let value_str = if let Some(ref fs) = field_schema {
462                    match fs.kind() {
463                        Kind::Float => {
464                            is_mismatch = false;
465                            let v = decode_float(data);
466                            if v.is_nan() {
467                                let bits = v.to_bits();
468                                if bits != f32::NAN.to_bits() {
469                                    nan_bits = Some(bits as u64);
470                                }
471                            }
472                            format_float_protoc(v)
473                        }
474                        Kind::Fixed32 => {
475                            is_mismatch = false;
476                            format_fixed32_protoc(decode_fixed32(data))
477                        }
478                        Kind::Sfixed32 => {
479                            is_mismatch = false;
480                            format_sfixed32_protoc(decode_sfixed32(data))
481                        }
482                        _ => {
483                            is_mismatch = true;
484                            format_wire_fixed32_protoc(decode_fixed32(data))
485                        } // mismatch → hex (D2)
486                    }
487                } else {
488                    is_mismatch = false;
489                    format_wire_fixed32_protoc(decode_fixed32(data)) // unknown → hex
490                };
491
492                render_scalar(
493                    &ScalarCtx {
494                        field_number,
495                        field_schema: field_schema.as_ref(),
496                        tag_ohb,
497                        tag_oor,
498                        len_ohb: None,
499                        wire_type_name: "fixed32",
500                        nan_bits,
501                    },
502                    &value_str,
503                    is_mismatch,
504                    out,
505                );
506            }
507
508            _ => unreachable!("wire type > 5 caught by parse_wiretag"),
509        }
510    }
511}