Skip to main content

prototext_core/
schema.rs

1// SPDX-FileCopyrightText: 2025-2026 Frederic Ruget <fred@atlant.is> (GitHub: @douzebis)
2// SPDX-FileCopyrightText: 2025-2026 THALES CLOUD SECURISE SAS
3//
4// SPDX-License-Identifier: MIT
5
6use prost_reflect::{DescriptorPool, MessageDescriptor};
7
8// ── Public types ──────────────────────────────────────────────────────────────
9
10/// The parsed, indexed form of a `FileDescriptorSet`.
11///
12/// Owns a `DescriptorPool` and the fully-qualified name of the root message.
13pub struct ParsedSchema {
14    pool: DescriptorPool,
15    root_full_name: String,
16}
17
18impl ParsedSchema {
19    /// Construct an empty (no-schema) `ParsedSchema`.
20    pub fn empty() -> Self {
21        ParsedSchema {
22            pool: DescriptorPool::new(),
23            root_full_name: String::new(),
24        }
25    }
26
27    /// Return the `MessageDescriptor` for the root message, or `None` for an
28    /// empty schema (no-schema mode).
29    pub fn root_descriptor(&self) -> Option<MessageDescriptor> {
30        if self.root_full_name.is_empty() {
31            None
32        } else {
33            self.pool.get_message_by_name(&self.root_full_name)
34        }
35    }
36
37    /// Look up a message descriptor by fully-qualified name (no leading dot).
38    pub fn get_descriptor(&self, fqn: &str) -> Option<MessageDescriptor> {
39        self.pool.get_message_by_name(fqn)
40    }
41
42    /// Access the underlying descriptor pool.
43    pub fn pool(&self) -> &DescriptorPool {
44        &self.pool
45    }
46}
47
48// ── Error type ────────────────────────────────────────────────────────────────
49
/// Failure modes of parsing a protobuf schema descriptor.
///
/// Each variant carries a human-readable detail string that is appended to a
/// fixed prefix by the `Display` implementation.
#[non_exhaustive]
#[derive(Debug)]
pub enum SchemaError {
    /// The supplied `schema_bytes` did not decode as a `FileDescriptorSet`.
    InvalidDescriptor(String),
    /// The parsed descriptor does not contain the requested root message.
    MessageNotFound(String),
}

impl std::fmt::Display for SchemaError {
    /// Render as `"<prefix>: <detail>"`, where the prefix identifies the variant.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        use SchemaError::{InvalidDescriptor, MessageNotFound};

        match self {
            InvalidDescriptor(msg) => write!(out, "invalid descriptor: {msg}"),
            MessageNotFound(msg) => write!(out, "message not found: {msg}"),
        }
    }
}

// No underlying source error is retained (details are stored as strings), so
// the default `Error` methods are sufficient.
impl std::error::Error for SchemaError {}
70
71// ── Public entry point ────────────────────────────────────────────────────────
72
73/// Decode `schema_bytes` into a `DescriptorPool`, for use with
74/// `schema_from_pool` when instantiating multiple types from the same bytes.
75pub fn decode_pool(schema_bytes: &[u8]) -> Result<DescriptorPool, SchemaError> {
76    DescriptorPool::decode(schema_bytes).map_err(|e| SchemaError::InvalidDescriptor(e.to_string()))
77}
78
79/// Create a `ParsedSchema` from an already-decoded pool and a root message name.
80pub fn schema_from_pool(
81    pool: DescriptorPool,
82    root_msg_name: &str,
83) -> Result<ParsedSchema, SchemaError> {
84    let root_full_name = root_msg_name.trim_start_matches('.').to_string();
85    if pool.get_message_by_name(&root_full_name).is_none() {
86        let available = pool
87            .all_messages()
88            .map(|m| m.full_name().to_string())
89            .collect::<Vec<_>>()
90            .join(", ");
91        return Err(SchemaError::MessageNotFound(format!(
92            "root message '{}' not found in schema (available: {})",
93            root_full_name, available
94        )));
95    }
96    Ok(ParsedSchema {
97        pool,
98        root_full_name,
99    })
100}
101
102/// Parse `schema_bytes` into a `ParsedSchema`.
103///
104/// An empty `schema_bytes` or empty `root_msg_name` returns a schema whose
105/// `root_descriptor()` is `None` (no-schema mode).
106pub fn parse_schema(schema_bytes: &[u8], root_msg_name: &str) -> Result<ParsedSchema, SchemaError> {
107    if schema_bytes.is_empty() || root_msg_name.is_empty() {
108        return Ok(ParsedSchema::empty());
109    }
110
111    let pool = DescriptorPool::decode(schema_bytes)
112        .map_err(|e| SchemaError::InvalidDescriptor(e.to_string()))?;
113
114    // prost-reflect uses no leading dot; strip one if the caller passed it.
115    let root_full_name = root_msg_name.trim_start_matches('.').to_string();
116
117    if pool.get_message_by_name(&root_full_name).is_none() {
118        let available = pool
119            .all_messages()
120            .map(|m| m.full_name().to_string())
121            .collect::<Vec<_>>()
122            .join(", ");
123        return Err(SchemaError::MessageNotFound(format!(
124            "root message '{}' not found in schema (available: {})",
125            root_full_name, available
126        )));
127    }
128
129    Ok(ParsedSchema {
130        pool,
131        root_full_name,
132    })
133}
134
135// ── Unit tests ────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;
    use prost::Message as ProstMessage;
    use prost_reflect::Kind;
    use prost_types::{
        DescriptorProto, EnumDescriptorProto, EnumValueDescriptorProto, FieldDescriptorProto,
        FileDescriptorProto, FileDescriptorSet,
    };

    // `FieldDescriptorProto` type/label integers, used only to construct test
    // descriptors.
    const TYPE_ENUM: i32 = 14;
    const TYPE_INT32: i32 = 5;
    const LABEL_OPTIONAL: i32 = 1;

    /// Wrap a single file in a `FileDescriptorSet` and serialize it to bytes.
    fn encode_fds(file: FileDescriptorProto) -> Vec<u8> {
        let fds = FileDescriptorSet { file: vec![file] };
        let mut buf = Vec::new();
        fds.encode(&mut buf).unwrap();
        buf
    }

    /// Build minimal `FileDescriptorSet` bytes containing one package-less
    /// proto2 file with the given message and any top-level enums provided.
    fn build_fds(enums: Vec<EnumDescriptorProto>, message: DescriptorProto) -> Vec<u8> {
        encode_fds(FileDescriptorProto {
            name: Some("test.proto".into()),
            syntax: Some("proto2".into()),
            enum_type: enums,
            message_type: vec![message],
            ..Default::default()
        })
    }

    /// Build an enum value descriptor `name = number`.
    fn enum_value(name: &str, number: i32) -> EnumValueDescriptorProto {
        EnumValueDescriptorProto {
            name: Some(name.into()),
            number: Some(number),
            ..Default::default()
        }
    }

    /// Build an optional enum-typed field referring to `type_name`.
    fn enum_field(name: &str, number: i32, type_name: &str) -> FieldDescriptorProto {
        FieldDescriptorProto {
            name: Some(name.into()),
            number: Some(number),
            r#type: Some(TYPE_ENUM),
            label: Some(LABEL_OPTIONAL),
            type_name: Some(type_name.into()),
            ..Default::default()
        }
    }

    /// Build an optional `int32` field.
    fn int32_field(name: &str, number: i32) -> FieldDescriptorProto {
        FieldDescriptorProto {
            name: Some(name.into()),
            number: Some(number),
            r#type: Some(TYPE_INT32),
            label: Some(LABEL_OPTIONAL),
            ..Default::default()
        }
    }

    // ── §8.1 Two-pass enum collection ─────────────────────────────────────────

    #[test]
    fn two_pass_enum_collection() {
        // Schema: enum Color { RED=0; GREEN=1; BLUE=2; }
        //         message Msg { optional Color color = 1; optional int32 id = 2; }
        let color_enum = EnumDescriptorProto {
            name: Some("Color".into()),
            value: vec![
                enum_value("RED", 0),
                enum_value("GREEN", 1),
                enum_value("BLUE", 2),
            ],
            ..Default::default()
        };
        let msg = DescriptorProto {
            name: Some("Msg".into()),
            field: vec![enum_field("color", 1, ".Color"), int32_field("id", 2)],
            ..Default::default()
        };
        let fds_bytes = build_fds(vec![color_enum], msg);
        let schema = parse_schema(&fds_bytes, "Msg").unwrap();
        let root = schema.root_descriptor().unwrap();

        let color_fd = root.get_field(1).expect("field 1 must exist");
        let Kind::Enum(enum_desc) = color_fd.kind() else {
            panic!("field 1 must be an enum");
        };
        let names: Vec<String> = enum_desc.values().map(|v| v.name().to_owned()).collect();
        assert_eq!(
            names,
            &["RED", "GREEN", "BLUE"],
            "enum values must be present"
        );

        let id_fd = root.get_field(2).expect("field 2 must exist");
        assert_eq!(id_fd.kind(), Kind::Int32, "field 2 must be int32");
    }

    // ── §8.2 Enum named after primitive keyword ────────────────────────────────

    #[test]
    fn enum_named_float_not_mistaken_for_primitive() {
        // Schema: enum float { FLOAT_ZERO=0; FLOAT_ONE=1; }
        //         message Msg { optional float kind = 1; }
        let float_enum = EnumDescriptorProto {
            name: Some("float".into()),
            value: vec![enum_value("FLOAT_ZERO", 0), enum_value("FLOAT_ONE", 1)],
            ..Default::default()
        };
        let msg = DescriptorProto {
            name: Some("Msg".into()),
            field: vec![enum_field("kind", 1, ".float")],
            ..Default::default()
        };
        let fds_bytes = build_fds(vec![float_enum], msg);
        let schema = parse_schema(&fds_bytes, "Msg").unwrap();
        let root = schema.root_descriptor().unwrap();

        let kind_fd = root.get_field(1).expect("field 1 must exist");
        // A single destructuring both checks the kind and yields the enum
        // descriptor needed for the follow-up assertion.
        let Kind::Enum(enum_desc) = kind_fd.kind() else {
            panic!("field typed as enum 'float' must have Kind::Enum");
        };
        assert!(
            enum_desc.values().count() > 0,
            "enum named 'float' must have non-empty values"
        );
    }

    // ── §8.3 Extension field visibility ───────────────────────────────────────

    /// Demonstrates that after spec-0011, `ParsedSchema` exposes extension
    /// descriptors registered on a message, making them available for
    /// rendering as `[pkg.ext_name]`.
    ///
    /// Schema:
    ///   package acme;
    ///   message Gadget { extensions 1000 to 1999; }
    ///   extend Gadget { optional int32 blade_count = 1000; }
    ///
    /// For reference only (this test checks the descriptor lookup and does not
    /// decode any wire bytes), field 1000 with value 42 would be encoded as:
    ///   tag  = (1000 << 3) | 0 = 8000 → varint 0xC0 0x3E
    ///   value = 42 → varint 0x2A
    #[test]
    fn extension_field_visible_via_get_extension() {
        use prost_types::descriptor_proto::ExtensionRange;

        // An extension is an ordinary field descriptor plus an `extendee`.
        let extension_field = FieldDescriptorProto {
            extendee: Some(".acme.Gadget".into()),
            ..int32_field("blade_count", 1000)
        };

        let gadget_msg = DescriptorProto {
            name: Some("Gadget".into()),
            extension_range: vec![ExtensionRange {
                start: Some(1000),
                // `end` is exclusive: this declares `extensions 1000 to 1999`.
                end: Some(2000),
                ..Default::default()
            }],
            ..Default::default()
        };

        // `build_fds` cannot set a package or file-level extensions, so the
        // file descriptor is assembled directly here.
        let buf = encode_fds(FileDescriptorProto {
            name: Some("gadget.proto".into()),
            syntax: Some("proto2".into()),
            package: Some("acme".into()),
            message_type: vec![gadget_msg],
            extension: vec![extension_field],
            ..Default::default()
        });

        let schema = parse_schema(&buf, "acme.Gadget").unwrap();
        let root = schema.root_descriptor().unwrap();

        // The extension must be visible via get_extension().
        let ext = root
            .get_extension(1000)
            .expect("extension field 1000 must be visible");
        assert_eq!(ext.full_name(), "acme.blade_count");
        assert_eq!(ext.kind(), Kind::Int32);
    }
}
328}