Skip to main content

prototext_core/
schema.rs

// SPDX-FileCopyrightText: 2025 - 2026 Frederic Ruget <fred@atlant.is> <fred@s3ns.io> (GitHub: @douzebis)
// SPDX-FileCopyrightText: 2025 - 2026 Thales Cloud Sécurisé
//
// SPDX-License-Identifier: MIT
5
6use prost_reflect::{DescriptorPool, MessageDescriptor};
7
8// ── Public types ──────────────────────────────────────────────────────────────
9
10/// The parsed, indexed form of a `FileDescriptorSet`.
11///
12/// Owns a `DescriptorPool` and the fully-qualified name of the root message.
13pub struct ParsedSchema {
14    pool: DescriptorPool,
15    root_full_name: String,
16}
17
18impl ParsedSchema {
19    /// Construct an empty (no-schema) `ParsedSchema`.
20    pub fn empty() -> Self {
21        ParsedSchema {
22            pool: DescriptorPool::new(),
23            root_full_name: String::new(),
24        }
25    }
26
27    /// Return the `MessageDescriptor` for the root message, or `None` for an
28    /// empty schema (no-schema mode).
29    pub fn root_descriptor(&self) -> Option<MessageDescriptor> {
30        if self.root_full_name.is_empty() {
31            None
32        } else {
33            self.pool.get_message_by_name(&self.root_full_name)
34        }
35    }
36
37    /// Look up a message descriptor by fully-qualified name (no leading dot).
38    pub fn get_descriptor(&self, fqn: &str) -> Option<MessageDescriptor> {
39        self.pool.get_message_by_name(fqn)
40    }
41
42    /// Access the underlying descriptor pool.
43    pub fn pool(&self) -> &DescriptorPool {
44        &self.pool
45    }
46}
47
// ── Error type ────────────────────────────────────────────────────────────────

/// Errors produced while turning serialized descriptor bytes into a
/// [`ParsedSchema`] (see [`parse_schema`]).
///
/// Every variant carries a free-form detail string. `Display` prints it after
/// a short category prefix.
#[non_exhaustive]
#[derive(Debug)]
pub enum SchemaError {
    /// The schema bytes could not be decoded as a `FileDescriptorSet`; the
    /// payload is the decoder's error text.
    InvalidDescriptor(String),
    /// The requested root message is absent from the decoded descriptor; the
    /// payload describes the missing name.
    MessageNotFound(String),
}

impl std::fmt::Display for SchemaError {
    /// Render as `"<category>: <detail>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidDescriptor(msg) => write!(f, "invalid descriptor: {msg}"),
            Self::MessageNotFound(msg) => write!(f, "message not found: {msg}"),
        }
    }
}

/// No variant wraps an underlying error, so the default `source()` is kept.
impl std::error::Error for SchemaError {}
70
71// ── Public entry point ────────────────────────────────────────────────────────
72
73/// Parse `schema_bytes` into a `ParsedSchema`.
74///
75/// An empty `schema_bytes` or empty `root_msg_name` returns a schema whose
76/// `root_descriptor()` is `None` (no-schema mode).
77pub fn parse_schema(schema_bytes: &[u8], root_msg_name: &str) -> Result<ParsedSchema, SchemaError> {
78    if schema_bytes.is_empty() || root_msg_name.is_empty() {
79        return Ok(ParsedSchema::empty());
80    }
81
82    let pool = DescriptorPool::decode(schema_bytes)
83        .map_err(|e| SchemaError::InvalidDescriptor(e.to_string()))?;
84
85    // prost-reflect uses no leading dot; strip one if the caller passed it.
86    let root_full_name = root_msg_name.trim_start_matches('.').to_string();
87
88    if pool.get_message_by_name(&root_full_name).is_none() {
89        let available = pool
90            .all_messages()
91            .map(|m| m.full_name().to_string())
92            .collect::<Vec<_>>()
93            .join(", ");
94        return Err(SchemaError::MessageNotFound(format!(
95            "root message '{}' not found in schema (available: {})",
96            root_full_name, available
97        )));
98    }
99
100    Ok(ParsedSchema {
101        pool,
102        root_full_name,
103    })
104}
105
// ── Unit tests ────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;
    use prost::Message as ProstMessage;
    use prost_reflect::Kind;
    use prost_types::{
        DescriptorProto, EnumDescriptorProto, EnumValueDescriptorProto, FieldDescriptorProto,
        FileDescriptorProto, FileDescriptorSet,
    };

    // `FieldDescriptorProto` type/label integers, used only to construct test
    // descriptors.
    const TYPE_ENUM: i32 = 14;
    const TYPE_INT32: i32 = 5;
    const LABEL_OPTIONAL: i32 = 1;

    /// Serialize a `FileDescriptorSet` containing exactly `file`.
    fn encode_file(file: FileDescriptorProto) -> Vec<u8> {
        let fds = FileDescriptorSet { file: vec![file] };
        let mut buf = Vec::new();
        fds.encode(&mut buf).unwrap();
        buf
    }

    /// Build minimal `FileDescriptorSet` bytes: one proto2 file (`test.proto`,
    /// no package) holding `message` plus the given top-level enums.
    fn build_fds(enums: Vec<EnumDescriptorProto>, message: DescriptorProto) -> Vec<u8> {
        encode_file(FileDescriptorProto {
            name: Some("test.proto".into()),
            syntax: Some("proto2".into()),
            enum_type: enums,
            message_type: vec![message],
            ..Default::default()
        })
    }

    fn enum_value(name: &str, number: i32) -> EnumValueDescriptorProto {
        EnumValueDescriptorProto {
            name: Some(name.into()),
            number: Some(number),
            ..Default::default()
        }
    }

    fn enum_field(name: &str, number: i32, type_name: &str) -> FieldDescriptorProto {
        FieldDescriptorProto {
            name: Some(name.into()),
            number: Some(number),
            r#type: Some(TYPE_ENUM),
            label: Some(LABEL_OPTIONAL),
            type_name: Some(type_name.into()),
            ..Default::default()
        }
    }

    fn int32_field(name: &str, number: i32) -> FieldDescriptorProto {
        FieldDescriptorProto {
            name: Some(name.into()),
            number: Some(number),
            r#type: Some(TYPE_INT32),
            label: Some(LABEL_OPTIONAL),
            ..Default::default()
        }
    }

    /// Schema shared by several tests:
    ///   enum Color { RED=0; GREEN=1; BLUE=2; }
    ///   message Msg { optional Color color = 1; optional int32 id = 2; }
    fn color_schema() -> Vec<u8> {
        let color_enum = EnumDescriptorProto {
            name: Some("Color".into()),
            value: vec![
                enum_value("RED", 0),
                enum_value("GREEN", 1),
                enum_value("BLUE", 2),
            ],
            ..Default::default()
        };
        let msg = DescriptorProto {
            name: Some("Msg".into()),
            field: vec![enum_field("color", 1, ".Color"), int32_field("id", 2)],
            ..Default::default()
        };
        build_fds(vec![color_enum], msg)
    }

    // ── §8.1 Two-pass enum collection ─────────────────────────────────────────

    #[test]
    fn two_pass_enum_collection() {
        let schema = parse_schema(&color_schema(), "Msg").unwrap();
        let root = schema.root_descriptor().unwrap();

        let color_fd = root.get_field(1).expect("field 1 must exist");
        let Kind::Enum(enum_desc) = color_fd.kind() else {
            panic!("field 1 must be an enum");
        };
        let names: Vec<String> = enum_desc.values().map(|v| v.name().to_owned()).collect();
        assert_eq!(
            names,
            &["RED", "GREEN", "BLUE"],
            "enum values must be present"
        );

        let id_fd = root.get_field(2).expect("field 2 must exist");
        assert_eq!(id_fd.kind(), Kind::Int32, "field 2 must be int32");
    }

    // ── §8.3 Extension field visibility ───────────────────────────────────────

    /// Demonstrates that after spec-0011, `ParsedSchema` exposes extension
    /// descriptors registered on a message, making them available for
    /// rendering as `[pkg.ext_name]`.
    ///
    /// Schema:
    ///   package acme;
    ///   message Gadget { extensions 1000 to 1999; }
    ///   extend Gadget { optional int32 blade_count = 1000; }
    ///
    /// For reference only (this test inspects descriptors, not wire data), such
    /// a field on the wire is: field 1000, wire-type varint (0), value 42.
    ///   tag  = (1000 << 3) | 0 = 8000 → varint 0xC0 0x3E
    ///   value = 42 → varint 0x2A
    #[test]
    fn extension_field_visible_via_get_extension() {
        use prost_types::descriptor_proto::ExtensionRange;

        let extension_field = FieldDescriptorProto {
            name: Some("blade_count".into()),
            number: Some(1000),
            r#type: Some(TYPE_INT32),
            label: Some(LABEL_OPTIONAL),
            extendee: Some(".acme.Gadget".into()),
            ..Default::default()
        };

        let gadget_msg = DescriptorProto {
            name: Some("Gadget".into()),
            // `end` is exclusive, so this declares extensions 1000 to 1999.
            extension_range: vec![ExtensionRange {
                start: Some(1000),
                end: Some(2000),
                ..Default::default()
            }],
            ..Default::default()
        };

        let buf = encode_file(FileDescriptorProto {
            name: Some("gadget.proto".into()),
            syntax: Some("proto2".into()),
            package: Some("acme".into()),
            message_type: vec![gadget_msg],
            extension: vec![extension_field],
            ..Default::default()
        });

        let schema = parse_schema(&buf, "acme.Gadget").unwrap();
        let root = schema.root_descriptor().unwrap();

        // The extension must be visible via get_extension().
        let ext = root
            .get_extension(1000)
            .expect("extension field 1000 must be visible");
        assert_eq!(ext.full_name(), "acme.blade_count");
        assert_eq!(ext.kind(), Kind::Int32);
    }

    // ── §8.2 Enum named after primitive keyword ────────────────────────────────

    #[test]
    fn enum_named_float_not_mistaken_for_primitive() {
        // Schema: enum float { FLOAT_ZERO=0; FLOAT_ONE=1; }
        //         message Msg { optional .float kind = 1; }  // `.float` is the enum
        let float_enum = EnumDescriptorProto {
            name: Some("float".into()),
            value: vec![enum_value("FLOAT_ZERO", 0), enum_value("FLOAT_ONE", 1)],
            ..Default::default()
        };
        let msg = DescriptorProto {
            name: Some("Msg".into()),
            field: vec![enum_field("kind", 1, ".float")],
            ..Default::default()
        };
        let fds_bytes = build_fds(vec![float_enum], msg);
        let schema = parse_schema(&fds_bytes, "Msg").unwrap();
        let root = schema.root_descriptor().unwrap();

        let kind_fd = root.get_field(1).expect("field 1 must exist");
        let Kind::Enum(enum_desc) = kind_fd.kind() else {
            panic!("field 'kind' typed by enum 'float' must have Kind::Enum");
        };
        let names: Vec<String> = enum_desc.values().map(|v| v.name().to_owned()).collect();
        assert_eq!(
            names,
            &["FLOAT_ZERO", "FLOAT_ONE"],
            "values of enum 'float' must be present"
        );
    }

    // ── Entry-point edge cases ────────────────────────────────────────────────

    #[test]
    fn empty_inputs_yield_no_schema_mode() {
        let fds_bytes = color_schema();
        assert!(parse_schema(&[], "Msg").unwrap().root_descriptor().is_none());
        assert!(parse_schema(&fds_bytes, "").unwrap().root_descriptor().is_none());
        assert!(ParsedSchema::empty().root_descriptor().is_none());
    }

    #[test]
    fn leading_dot_on_root_name_is_accepted() {
        let schema = parse_schema(&color_schema(), ".Msg").unwrap();
        let root = schema.root_descriptor().expect("root must resolve");
        assert_eq!(root.full_name(), "Msg");
    }

    #[test]
    fn undecodable_bytes_are_rejected() {
        // A varint whose every byte has the continuation bit set is truncated,
        // so these bytes cannot decode as a FileDescriptorSet.
        let Err(err) = parse_schema(&[0xFF, 0xFF, 0xFF], "Msg") else {
            panic!("garbage bytes must not parse");
        };
        assert!(
            matches!(err, SchemaError::InvalidDescriptor(_)),
            "expected InvalidDescriptor, got {err:?}"
        );
    }

    #[test]
    fn unknown_root_message_lists_available_names() {
        let Err(err) = parse_schema(&color_schema(), "Missing") else {
            panic!("unknown root message must be rejected");
        };
        let SchemaError::MessageNotFound(detail) = &err else {
            panic!("expected MessageNotFound, got {err:?}");
        };
        assert!(detail.contains("'Missing'"), "{detail}");
        assert!(detail.contains("Msg"), "{detail}");
    }
}
295            enum_desc.values().count() > 0,
296            "enum named 'float' must have non-empty values"
297        );
298    }
299}