Skip to main content

prototext_core/
schema.rs

1// SPDX-FileCopyrightText: 2025 - 2026 Frederic Ruget <fred@atlant.is> <fred@s3ns.io> (GitHub: @douzebis)
2// SPDX-FileCopyrightText: 2025 - 2026 Thales Cloud Sécurisé
3//
4// SPDX-License-Identifier: MIT
5
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use prost::Message as ProstMessage;
10use prost_types::{field_descriptor_proto::Type, FileDescriptorProto, FileDescriptorSet};
11
12// ── Enum value collection ──────────────────────────────────────────────────────
13
/// Temporary map from fully-qualified enum type name (leading dot, e.g.
/// `".pkg.Outer.Kind"`) → `(number, value name)` table sorted by number.
///
/// Filled once per `parse_schema` call (pass 1) and then consulted to build
/// `FieldInfo::enum_values` for ENUM fields.
type EnumValueMap = HashMap<String, Vec<(i32, Box<str>)>>;
16
17// ── Public types ──────────────────────────────────────────────────────────────
18
/// Per-field information extracted from a FieldDescriptorProto.
#[derive(Debug, Clone)]
pub struct FieldInfo {
    /// Field name as declared in the descriptor (empty if the descriptor
    /// omits it).  Used for annotations.
    pub name: String,
    /// Proto type constant (matches Python's FieldDescriptor.TYPE_* values;
    /// see the `proto_type` module).  `0` if the descriptor omits the type.
    pub proto_type: i32,
    /// Label constant (LABEL_OPTIONAL = 1, LABEL_REQUIRED = 2, LABEL_REPEATED = 3;
    /// see the `proto_label` module).  `0` if the descriptor omits the label.
    pub label: i32,
    /// `true` for repeated scalar fields encoded as packed (proto2 [packed=true]
    /// or proto3 implicit packing).
    pub is_packed: bool,
    /// Fully-qualified type name for MESSAGE / GROUP fields (e.g. ".pkg.Inner").
    /// `None` for every other field type.
    pub nested_type_name: Option<String>,
    /// The descriptor's `type_name` for ENUM fields (e.g. ".pkg.Color"); it is
    /// the key used to look up `enum_values`.  `None` for every other type.
    pub enum_type_name: Option<String>,
    /// Short enum/message name for annotations: the last dot-separated
    /// component of the descriptor's `type_name`.  `None` when the field has
    /// no `type_name`.
    pub type_display_name: Option<String>,
    /// Numeric value → symbolic name table for ENUM fields.
    /// Sorted by numeric value for O(log n) lookup via binary_search_by_key.
    /// Entries that share a number (enum aliases) are all kept, so a binary
    /// search may land on any one of them.
    /// Empty for non-ENUM fields and for ENUM fields whose enum type is not
    /// defined in the parsed schema.
    pub enum_values: Box<[(i32, Box<str>)]>,
}
42
/// Schema for one protobuf message type: maps field number → FieldInfo.
#[derive(Debug)]
pub struct MessageSchema {
    /// Short (unqualified) message name from the descriptor, for annotations;
    /// empty if the descriptor omits it.
    pub name: String,
    /// Field-number to FieldInfo map.  Fields that carry no number in the
    /// descriptor are not included.
    pub fields: HashMap<u32, FieldInfo>,
}
51
52/// The parsed, indexed form of a FileDescriptorProto/Set.
53///
54/// `root_type_name` is the fully-qualified name of the root message (the one
55/// that corresponds to a `foo.pb` payload).  `messages` maps every reachable
56/// type to its `MessageSchema`.
57pub struct ParsedSchema {
58    pub root_type_name: String,
59    pub messages: HashMap<String, Arc<MessageSchema>>,
60    /// OPT-6: Pre-built Arc<HashMap> for the protoc renderer so `get_schemas()`
61    /// in lib.rs does not rebuild it on every `encode_pb("protoc")` call.
62    /// Built once at `parse_schema()` time and shared across all calls.
63    pub all_schemas: Arc<HashMap<String, Arc<MessageSchema>>>,
64}
65
66impl ParsedSchema {
67    /// Construct an empty (no-schema) `ParsedSchema`.
68    ///
69    /// Equivalent to `parse_schema(b"", "")` but infallible and allocation-free.
70    pub fn empty() -> Self {
71        ParsedSchema {
72            root_type_name: String::new(),
73            messages: HashMap::new(),
74            all_schemas: Arc::new(HashMap::new()),
75        }
76    }
77
78    /// Return the `MessageSchema` for the root message, or `None` for an empty
79    /// schema (no-schema mode, equivalent to `ctx.schema = None`).
80    pub fn root_schema(&self) -> Option<Arc<MessageSchema>> {
81        if self.root_type_name.is_empty() {
82            None
83        } else {
84            self.messages.get(&self.root_type_name).cloned()
85        }
86    }
87}
88
89// ── Error type ────────────────────────────────────────────────────────────────
90
/// Failure modes of `parse_schema` when turning descriptor bytes into a schema.
#[non_exhaustive]
#[derive(Debug)]
pub enum SchemaError {
    /// The input bytes decode as neither a `FileDescriptorSet` nor a
    /// `FileDescriptorProto`; the payload is a human-readable explanation.
    InvalidDescriptor(String),
    /// The descriptor decoded, but the requested root message is not defined
    /// in it; the payload describes the failed lookup.
    MessageNotFound(String),
}

impl std::fmt::Display for SchemaError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every variant renders as "<kind>: <detail>".
        let (kind, detail) = match self {
            Self::InvalidDescriptor(detail) => ("invalid descriptor", detail),
            Self::MessageNotFound(detail) => ("message not found", detail),
        };
        write!(f, "{kind}: {detail}")
    }
}

impl std::error::Error for SchemaError {}
112
113// ── Public entry point ────────────────────────────────────────────────────────
114
115/// Parse `schema_bytes` into a `ParsedSchema`.
116///
117/// Tries `FileDescriptorSet` first, falls back to `FileDescriptorProto`,
118/// mirroring the Python `load_schema_descriptor` behaviour.
119///
120/// An empty `schema_bytes` or empty `root_msg_name` returns a schema whose
121/// `root_schema()` is `None` (no-schema mode).
122pub fn parse_schema(schema_bytes: &[u8], root_msg_name: &str) -> Result<ParsedSchema, SchemaError> {
123    if schema_bytes.is_empty() || root_msg_name.is_empty() {
124        return Ok(ParsedSchema {
125            root_type_name: String::new(),
126            messages: HashMap::new(),
127            all_schemas: Arc::new(HashMap::new()), // OPT-6: empty cache
128        });
129    }
130
131    // Collect all FileDescriptorProtos: try FDS first, then bare FDP.
132    let files: Vec<FileDescriptorProto> = if let Ok(fds) = FileDescriptorSet::decode(schema_bytes) {
133        fds.file
134    } else if let Ok(fdp) = FileDescriptorProto::decode(schema_bytes) {
135        vec![fdp]
136    } else {
137        return Err(SchemaError::InvalidDescriptor(
138            "schema_bytes is neither a valid FileDescriptorSet nor FileDescriptorProto".into(),
139        ));
140    };
141
142    // Build a global registry of type_name → DescriptorProto (flat, all files).
143    // Keys are fully-qualified names like ".package.MessageName".
144    let mut raw: HashMap<String, prost_types::DescriptorProto> = HashMap::new();
145    for file in &files {
146        let pkg = file.package.as_deref().unwrap_or("");
147        collect_message_types(pkg, &file.message_type, &mut raw);
148    }
149
150    // Pass 1: collect all enum value tables from all files.
151    // Keys are fully-qualified enum type names like ".google.protobuf.FieldDescriptorProto.Type".
152    let mut enum_map: EnumValueMap = HashMap::new();
153    for file in &files {
154        let pkg = file.package.as_deref().unwrap_or("");
155        collect_enum_types(pkg, &file.enum_type, &mut enum_map);
156        // Also collect enums nested inside message types.
157        collect_nested_enum_types(pkg, &file.message_type, &mut enum_map);
158    }
159
160    // Now build MessageSchema for every collected type.
161    let mut messages: HashMap<String, Arc<MessageSchema>> = HashMap::new();
162    for (fqn, dp) in &raw {
163        let schema = build_message_schema(dp, &raw, &enum_map);
164        messages.insert(fqn.clone(), Arc::new(schema));
165    }
166
167    // Normalise root_msg_name: add leading dot if missing.
168    let root_type_name = if root_msg_name.starts_with('.') {
169        root_msg_name.to_string()
170    } else {
171        format!(".{}", root_msg_name)
172    };
173
174    if !messages.contains_key(&root_type_name) {
175        return Err(SchemaError::MessageNotFound(format!(
176            "root message '{}' not found in schema (available: {})",
177            root_type_name,
178            messages.keys().cloned().collect::<Vec<_>>().join(", ")
179        )));
180    }
181
182    // OPT-6: build all_schemas once here so get_schemas() in lib.rs can just
183    // clone the Arc instead of rebuilding the HashMap on every encode_pb("protoc").
184    let all_schemas = Arc::new(
185        messages
186            .iter()
187            .map(|(k, v)| (k.clone(), Arc::clone(v)))
188            .collect::<HashMap<_, _>>(),
189    );
190
191    Ok(ParsedSchema {
192        root_type_name,
193        messages,
194        all_schemas,
195    })
196}
197
198// ── Internal helpers ──────────────────────────────────────────────────────────
199
200/// Recursively collect all DescriptorProtos from a list of top-level or nested
201/// message types, building their fully-qualified names.
202fn collect_message_types(
203    parent_prefix: &str,
204    descriptors: &[prost_types::DescriptorProto],
205    out: &mut HashMap<String, prost_types::DescriptorProto>,
206) {
207    for dp in descriptors {
208        let name = dp.name.as_deref().unwrap_or("");
209        let fqn = if parent_prefix.is_empty() {
210            format!(".{}", name)
211        } else {
212            format!(".{}.{}", parent_prefix, name)
213        };
214        // Recurse into nested types before inserting (order doesn't matter for HashMap).
215        let nested_prefix = if parent_prefix.is_empty() {
216            name.to_string()
217        } else {
218            format!("{}.{}", parent_prefix, name)
219        };
220        collect_message_types(&nested_prefix, &dp.nested_type, out);
221        out.insert(fqn, dp.clone());
222    }
223}
224
225/// Build a `MessageSchema` from a `DescriptorProto`.
226fn build_message_schema(
227    dp: &prost_types::DescriptorProto,
228    _all: &HashMap<String, prost_types::DescriptorProto>,
229    enum_map: &EnumValueMap,
230) -> MessageSchema {
231    let mut fields = HashMap::new();
232    for fdp in &dp.field {
233        let number = match fdp.number {
234            Some(n) => n as u32,
235            None => continue,
236        };
237        let proto_type = fdp.r#type.unwrap_or(0);
238        let label = fdp.label.unwrap_or(0);
239
240        // Determine if the field is packed.
241        // The options.packed flag is set by protoc for both proto2 [packed=true]
242        // and proto3 implicit packed fields.
243        let is_packed = fdp.options.as_ref().and_then(|o| o.packed).unwrap_or(false);
244
245        // For proto3, repeated scalar fields are packed by default even if
246        // options.packed is not set in the descriptor.  Check syntax.
247        // (We use a simple heuristic: if label==REPEATED and type is scalar,
248        //  treat it as packed for proto3 files.  We don't have the file syntax
249        //  here, so we rely on protoc having already set options.packed.)
250        // In practice, protoc sets options.packed=true for all packed fields.
251
252        let nested_type_name = fdp.type_name.clone();
253
254        // Short display name for annotations (last component of type_name).
255        let type_display_name = nested_type_name
256            .as_ref()
257            .map(|tn| tn.rsplit('.').next().unwrap_or(tn).to_string());
258
259        // Enum type name (only for ENUM fields).
260        let enum_type_name = if proto_type == Type::Enum as i32 {
261            nested_type_name.clone()
262        } else {
263            None
264        };
265
266        // Resolve enum value table for ENUM fields (pass 2 — enum_map already built).
267        let enum_values: Box<[(i32, Box<str>)]> = if proto_type == Type::Enum as i32 {
268            if let Some(etn) = &enum_type_name {
269                if let Some(vals) = enum_map.get(etn.as_str()) {
270                    vals.iter()
271                        .map(|(n, s)| (*n, s.clone()))
272                        .collect::<Vec<_>>()
273                        .into_boxed_slice()
274                } else {
275                    Box::default()
276                }
277            } else {
278                Box::default()
279            }
280        } else {
281            Box::default()
282        };
283
284        let fi = FieldInfo {
285            name: fdp.name.clone().unwrap_or_default(),
286            proto_type,
287            label,
288            is_packed,
289            nested_type_name: if proto_type == Type::Message as i32
290                || proto_type == Type::Group as i32
291            {
292                nested_type_name
293            } else {
294                None
295            },
296            enum_type_name,
297            type_display_name,
298            enum_values,
299        };
300        fields.insert(number, fi);
301    }
302
303    MessageSchema {
304        name: dp.name.clone().unwrap_or_default(),
305        fields,
306    }
307}
308
309/// Collect top-level enum types from a file into the enum_map.
310fn collect_enum_types(
311    parent_prefix: &str,
312    enums: &[prost_types::EnumDescriptorProto],
313    out: &mut EnumValueMap,
314) {
315    for edp in enums {
316        let name = edp.name.as_deref().unwrap_or("");
317        let fqn = if parent_prefix.is_empty() {
318            format!(".{}", name)
319        } else {
320            format!(".{}.{}", parent_prefix, name)
321        };
322        let mut vals: Vec<(i32, Box<str>)> = edp
323            .value
324            .iter()
325            .filter_map(|vdp| {
326                let n = vdp.number?;
327                let s: Box<str> = vdp.name.as_deref().unwrap_or("").into();
328                Some((n, s))
329            })
330            .collect();
331        vals.sort_by_key(|(n, _)| *n);
332        out.insert(fqn, vals);
333    }
334}
335
336/// Recursively collect enum types nested inside message types.
337fn collect_nested_enum_types(
338    parent_prefix: &str,
339    descriptors: &[prost_types::DescriptorProto],
340    out: &mut EnumValueMap,
341) {
342    for dp in descriptors {
343        let name = dp.name.as_deref().unwrap_or("");
344        let prefix = if parent_prefix.is_empty() {
345            name.to_string()
346        } else {
347            format!("{}.{}", parent_prefix, name)
348        };
349        collect_enum_types(&prefix, &dp.enum_type, out);
350        collect_nested_enum_types(&prefix, &dp.nested_type, out);
351    }
352}
353
354// ── Proto type constants (matching Python's FieldDescriptor.TYPE_*) ────────────
355
/// `FieldDescriptorProto.type` values, as stored in `FieldInfo::proto_type`.
///
/// Numbering follows Python's `FieldDescriptor.TYPE_*` constants (the
/// `FieldDescriptorProto.Type` enum of descriptor.proto).
pub mod proto_type {
    pub const DOUBLE: i32 = 1;
    pub const FLOAT: i32 = 2;
    pub const INT64: i32 = 3;
    pub const UINT64: i32 = 4;
    pub const INT32: i32 = 5;
    pub const FIXED64: i32 = 6;
    pub const FIXED32: i32 = 7;
    pub const BOOL: i32 = 8;
    pub const STRING: i32 = 9;
    pub const GROUP: i32 = 10;
    pub const MESSAGE: i32 = 11;
    pub const BYTES: i32 = 12;
    pub const UINT32: i32 = 13;
    pub const ENUM: i32 = 14;
    pub const SFIXED32: i32 = 15;
    pub const SFIXED64: i32 = 16;
    pub const SINT32: i32 = 17;
    pub const SINT64: i32 = 18;
}
376
/// `FieldDescriptorProto.label` values, as stored in `FieldInfo::label`
/// (Python's `FieldDescriptor.LABEL_*` constants).
pub mod proto_label {
    pub const OPTIONAL: i32 = 1;
    pub const REQUIRED: i32 = 2;
    pub const REPEATED: i32 = 3;
}
382
383// ── Unit tests ────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;
    use prost::Message as ProstMessage;
    use prost_types::{
        DescriptorProto, EnumDescriptorProto, EnumValueDescriptorProto, FieldDescriptorProto,
        FileDescriptorProto, FileDescriptorSet,
    };

    /// Encode a `FileDescriptorSet` holding exactly `file`.
    fn encode_fds(file: FileDescriptorProto) -> Vec<u8> {
        let fds = FileDescriptorSet { file: vec![file] };
        let mut buf = Vec::new();
        fds.encode(&mut buf).unwrap();
        buf
    }

    /// Build a minimal FileDescriptorSet bytes containing one proto2 file (no
    /// package) with the given message descriptor and any top-level enums provided.
    fn build_fds(enums: Vec<EnumDescriptorProto>, message: DescriptorProto) -> Vec<u8> {
        encode_fds(FileDescriptorProto {
            name: Some("test.proto".into()),
            syntax: Some("proto2".into()),
            enum_type: enums,
            message_type: vec![message],
            ..Default::default()
        })
    }

    /// A message named `name` with the given fields and nothing else.
    fn plain_message(name: &str, field: Vec<FieldDescriptorProto>) -> DescriptorProto {
        DescriptorProto {
            name: Some(name.into()),
            field,
            ..Default::default()
        }
    }

    /// Extract the error of a `parse_schema` call.  Matching by hand keeps the
    /// tests independent of whether `ParsedSchema` implements `Debug`.
    fn expect_err(result: Result<ParsedSchema, SchemaError>) -> SchemaError {
        match result {
            Ok(_) => panic!("parse_schema unexpectedly succeeded"),
            Err(e) => e,
        }
    }

    fn enum_value(name: &str, number: i32) -> EnumValueDescriptorProto {
        EnumValueDescriptorProto {
            name: Some(name.into()),
            number: Some(number),
            ..Default::default()
        }
    }

    fn enum_field(name: &str, number: i32, type_name: &str) -> FieldDescriptorProto {
        FieldDescriptorProto {
            name: Some(name.into()),
            number: Some(number),
            r#type: Some(proto_type::ENUM),
            label: Some(proto_label::OPTIONAL),
            type_name: Some(type_name.into()),
            ..Default::default()
        }
    }

    fn message_field(name: &str, number: i32, type_name: &str) -> FieldDescriptorProto {
        FieldDescriptorProto {
            name: Some(name.into()),
            number: Some(number),
            r#type: Some(proto_type::MESSAGE),
            label: Some(proto_label::OPTIONAL),
            type_name: Some(type_name.into()),
            ..Default::default()
        }
    }

    fn int32_field(name: &str, number: i32) -> FieldDescriptorProto {
        FieldDescriptorProto {
            name: Some(name.into()),
            number: Some(number),
            r#type: Some(proto_type::INT32),
            label: Some(proto_label::OPTIONAL),
            ..Default::default()
        }
    }

    // ── §8.1 Two-pass enum collection ─────────────────────────────────────────

    #[test]
    fn two_pass_enum_collection() {
        // Schema: enum Color { RED=0; GREEN=1; BLUE=2; }
        //         message Msg { optional Color color = 1; optional int32 id = 2; }
        let color_enum = EnumDescriptorProto {
            name: Some("Color".into()),
            value: vec![
                enum_value("RED", 0),
                enum_value("GREEN", 1),
                enum_value("BLUE", 2),
            ],
            ..Default::default()
        };
        let msg = DescriptorProto {
            name: Some("Msg".into()),
            field: vec![enum_field("color", 1, ".Color"), int32_field("id", 2)],
            ..Default::default()
        };
        let fds_bytes = build_fds(vec![color_enum], msg);
        let schema = parse_schema(&fds_bytes, "Msg").unwrap();
        let root = schema.root_schema().unwrap();

        let color_fi = root.fields.get(&1).expect("field 1 must exist");
        assert_eq!(
            color_fi.enum_values.as_ref(),
            &[(0, "RED".into()), (1, "GREEN".into()), (2, "BLUE".into())],
            "enum_values must be sorted by numeric value"
        );

        let id_fi = root.fields.get(&2).expect("field 2 must exist");
        assert!(
            id_fi.enum_values.is_empty(),
            "non-enum field must have empty enum_values"
        );
    }

    // ── §8.2 Enum named after primitive keyword ────────────────────────────────

    #[test]
    fn enum_named_float_not_mistaken_for_primitive() {
        // Schema: enum float { FLOAT_ZERO=0; FLOAT_ONE=1; }
        //         message Msg { optional float kind = 1; }
        let float_enum = EnumDescriptorProto {
            name: Some("float".into()),
            value: vec![enum_value("FLOAT_ZERO", 0), enum_value("FLOAT_ONE", 1)],
            ..Default::default()
        };
        let msg = DescriptorProto {
            name: Some("Msg".into()),
            field: vec![enum_field("kind", 1, ".float")],
            ..Default::default()
        };
        let fds_bytes = build_fds(vec![float_enum], msg);
        let schema = parse_schema(&fds_bytes, "Msg").unwrap();
        let root = schema.root_schema().unwrap();

        let kind_fi = root.fields.get(&1).expect("field 1 must exist");
        assert_eq!(
            kind_fi.proto_type,
            proto_type::ENUM,
            "field 'kind' typed by the enum named 'float' must have proto_type=ENUM"
        );
        assert!(
            !kind_fi.enum_values.is_empty(),
            "enum named 'float' must have non-empty enum_values"
        );
    }

    // ── Entry-point edge cases ─────────────────────────────────────────────────

    #[test]
    fn empty_inputs_select_no_schema_mode() {
        let fds_bytes = build_fds(vec![], plain_message("Msg", vec![int32_field("id", 1)]));

        let no_bytes = parse_schema(b"", "Msg").unwrap();
        assert!(no_bytes.root_schema().is_none());
        assert!(no_bytes.messages.is_empty());

        let no_root = parse_schema(&fds_bytes, "").unwrap();
        assert!(no_root.root_schema().is_none());
        assert!(no_root.all_schemas.is_empty());
    }

    #[test]
    fn undecodable_bytes_are_rejected() {
        // 0xff starts a varint that never terminates: invalid for both
        // FileDescriptorSet and FileDescriptorProto.
        let err = expect_err(parse_schema(&[0xffu8], "Msg"));
        assert!(
            matches!(err, SchemaError::InvalidDescriptor(_)),
            "unexpected error: {}",
            err
        );
    }

    #[test]
    fn unknown_root_message_is_reported() {
        let fds_bytes = build_fds(vec![], plain_message("Msg", vec![]));
        match expect_err(parse_schema(&fds_bytes, "Missing")) {
            SchemaError::MessageNotFound(msg) => assert!(msg.contains(".Missing"), "{}", msg),
            other => panic!("unexpected error: {}", other),
        }
    }

    #[test]
    fn root_name_leading_dot_is_optional() {
        let fds_bytes = build_fds(vec![], plain_message("Msg", vec![int32_field("id", 1)]));
        for root in &["Msg", ".Msg"] {
            let schema = parse_schema(&fds_bytes, root).unwrap();
            assert_eq!(schema.root_type_name, ".Msg");
            assert!(schema.root_schema().is_some(), "root '{}' must resolve", root);
            assert_eq!(schema.all_schemas.len(), schema.messages.len());
        }
    }

    #[test]
    fn nested_types_resolve_under_package_prefix() {
        // package pkg;
        // message Outer {
        //   enum Kind { KIND_B = 5; KIND_A = 0; }
        //   message Inner { optional int32 x = 1; }
        //   optional Kind kind = 1;
        //   optional Inner inner = 2;
        // }
        let outer = DescriptorProto {
            name: Some("Outer".into()),
            field: vec![
                enum_field("kind", 1, ".pkg.Outer.Kind"),
                message_field("inner", 2, ".pkg.Outer.Inner"),
            ],
            nested_type: vec![plain_message("Inner", vec![int32_field("x", 1)])],
            enum_type: vec![EnumDescriptorProto {
                name: Some("Kind".into()),
                value: vec![enum_value("KIND_B", 5), enum_value("KIND_A", 0)],
                ..Default::default()
            }],
            ..Default::default()
        };
        let fds_bytes = encode_fds(FileDescriptorProto {
            name: Some("nested.proto".into()),
            package: Some("pkg".into()),
            syntax: Some("proto2".into()),
            message_type: vec![outer],
            ..Default::default()
        });
        let schema = parse_schema(&fds_bytes, "pkg.Outer").unwrap();
        assert!(
            schema.messages.contains_key(".pkg.Outer.Inner"),
            "nested message must be registered under its fully-qualified name"
        );

        let root = schema.root_schema().unwrap();
        let kind = root.fields.get(&1).expect("field 1 must exist");
        assert_eq!(kind.enum_type_name.as_deref(), Some(".pkg.Outer.Kind"));
        let kind_values: Vec<(i32, &str)> =
            kind.enum_values.iter().map(|(n, s)| (*n, &**s)).collect();
        assert_eq!(
            kind_values,
            [(0, "KIND_A"), (5, "KIND_B")],
            "nested enum values must be resolved and sorted"
        );

        let inner = root.fields.get(&2).expect("field 2 must exist");
        assert_eq!(inner.nested_type_name.as_deref(), Some(".pkg.Outer.Inner"));
        assert_eq!(inner.type_display_name.as_deref(), Some("Inner"));
        assert!(inner.enum_type_name.is_none());
        assert!(inner.enum_values.is_empty());
    }
}