1use prost_reflect::{DescriptorPool, MessageDescriptor};
7
8pub struct ParsedSchema {
14 pool: DescriptorPool,
15 root_full_name: String,
16}
17
18impl ParsedSchema {
19 pub fn empty() -> Self {
21 ParsedSchema {
22 pool: DescriptorPool::new(),
23 root_full_name: String::new(),
24 }
25 }
26
27 pub fn root_descriptor(&self) -> Option<MessageDescriptor> {
30 if self.root_full_name.is_empty() {
31 None
32 } else {
33 self.pool.get_message_by_name(&self.root_full_name)
34 }
35 }
36
37 pub fn get_descriptor(&self, fqn: &str) -> Option<MessageDescriptor> {
39 self.pool.get_message_by_name(fqn)
40 }
41
42 pub fn pool(&self) -> &DescriptorPool {
44 &self.pool
45 }
46}
47
/// Errors returned by [`parse_schema`].
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SchemaError {
    /// The schema bytes could not be decoded into a descriptor pool; carries the
    /// decoder's error text.
    InvalidDescriptor(String),
    /// The requested root message is not present in the decoded pool; carries a
    /// human-readable description.
    MessageNotFound(String),
}

impl std::fmt::Display for SchemaError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Prefix each payload with a short label naming the failure category.
        let (label, detail) = match self {
            SchemaError::InvalidDescriptor(msg) => ("invalid descriptor", msg),
            SchemaError::MessageNotFound(msg) => ("message not found", msg),
        };
        write!(f, "{label}: {detail}")
    }
}

impl std::error::Error for SchemaError {}
70
71pub fn parse_schema(schema_bytes: &[u8], root_msg_name: &str) -> Result<ParsedSchema, SchemaError> {
78 if schema_bytes.is_empty() || root_msg_name.is_empty() {
79 return Ok(ParsedSchema::empty());
80 }
81
82 let pool = DescriptorPool::decode(schema_bytes)
83 .map_err(|e| SchemaError::InvalidDescriptor(e.to_string()))?;
84
85 let root_full_name = root_msg_name.trim_start_matches('.').to_string();
87
88 if pool.get_message_by_name(&root_full_name).is_none() {
89 let available = pool
90 .all_messages()
91 .map(|m| m.full_name().to_string())
92 .collect::<Vec<_>>()
93 .join(", ");
94 return Err(SchemaError::MessageNotFound(format!(
95 "root message '{}' not found in schema (available: {})",
96 root_full_name, available
97 )));
98 }
99
100 Ok(ParsedSchema {
101 pool,
102 root_full_name,
103 })
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use prost::Message as ProstMessage;
    use prost_reflect::Kind;
    use prost_types::descriptor_proto::ExtensionRange;
    use prost_types::field_descriptor_proto::{Label, Type};
    use prost_types::{
        DescriptorProto, EnumDescriptorProto, EnumValueDescriptorProto, FieldDescriptorProto,
        FileDescriptorProto, FileDescriptorSet,
    };

    // Wire values for `FieldDescriptorProto.type` / `.label`, derived from the generated
    // enums rather than hard-coded so they cannot drift from descriptor.proto.
    const TYPE_ENUM: i32 = Type::Enum as i32;
    const TYPE_INT32: i32 = Type::Int32 as i32;
    const LABEL_OPTIONAL: i32 = Label::Optional as i32;

    /// Encodes a single proto2 file (no package) holding `enums` and `message`.
    fn build_fds(enums: Vec<EnumDescriptorProto>, message: DescriptorProto) -> Vec<u8> {
        let file = FileDescriptorProto {
            name: Some("test.proto".into()),
            syntax: Some("proto2".into()),
            enum_type: enums,
            message_type: vec![message],
            ..Default::default()
        };
        FileDescriptorSet { file: vec![file] }.encode_to_vec()
    }

    fn enum_value(name: &str, number: i32) -> EnumValueDescriptorProto {
        EnumValueDescriptorProto {
            name: Some(name.into()),
            number: Some(number),
            ..Default::default()
        }
    }

    /// An optional enum-typed field; `type_name` is a fully-qualified reference like `.Color`.
    fn enum_field(name: &str, number: i32, type_name: &str) -> FieldDescriptorProto {
        FieldDescriptorProto {
            name: Some(name.into()),
            number: Some(number),
            r#type: Some(TYPE_ENUM),
            label: Some(LABEL_OPTIONAL),
            type_name: Some(type_name.into()),
            ..Default::default()
        }
    }

    fn int32_field(name: &str, number: i32) -> FieldDescriptorProto {
        FieldDescriptorProto {
            name: Some(name.into()),
            number: Some(number),
            r#type: Some(TYPE_INT32),
            label: Some(LABEL_OPTIONAL),
            ..Default::default()
        }
    }

    /// A message with no fields, for tests that only exercise name resolution.
    fn empty_message(name: &str) -> DescriptorProto {
        DescriptorProto {
            name: Some(name.into()),
            ..Default::default()
        }
    }

    /// Extracts the error from a `parse_schema` result, panicking if parsing succeeded.
    fn expect_schema_err(result: Result<ParsedSchema, SchemaError>) -> SchemaError {
        match result {
            Ok(_) => panic!("expected parse_schema to fail"),
            Err(err) => err,
        }
    }

    #[test]
    fn two_pass_enum_collection() {
        let color_enum = EnumDescriptorProto {
            name: Some("Color".into()),
            value: vec![
                enum_value("RED", 0),
                enum_value("GREEN", 1),
                enum_value("BLUE", 2),
            ],
            ..Default::default()
        };
        let msg = DescriptorProto {
            name: Some("Msg".into()),
            field: vec![enum_field("color", 1, ".Color"), int32_field("id", 2)],
            ..Default::default()
        };
        let fds_bytes = build_fds(vec![color_enum], msg);
        let schema = parse_schema(&fds_bytes, "Msg").unwrap();
        let root = schema.root_descriptor().unwrap();

        let color_fd = root.get_field(1).expect("field 1 must exist");
        let Kind::Enum(enum_desc) = color_fd.kind() else {
            panic!("field 1 must be an enum");
        };
        let names: Vec<String> = enum_desc.values().map(|v| v.name().to_owned()).collect();
        assert_eq!(
            names,
            &["RED", "GREEN", "BLUE"],
            "enum values must be present"
        );

        let id_fd = root.get_field(2).expect("field 2 must exist");
        assert_eq!(id_fd.kind(), Kind::Int32, "field 2 must be int32");
    }

    #[test]
    fn extension_field_visible_via_get_extension() {
        let extension_field = FieldDescriptorProto {
            name: Some("blade_count".into()),
            number: Some(1000),
            r#type: Some(TYPE_INT32),
            label: Some(LABEL_OPTIONAL),
            extendee: Some(".acme.Gadget".into()),
            ..Default::default()
        };

        let gadget_msg = DescriptorProto {
            name: Some("Gadget".into()),
            extension_range: vec![ExtensionRange {
                start: Some(1000),
                end: Some(2000),
                ..Default::default()
            }],
            ..Default::default()
        };

        let file = FileDescriptorProto {
            name: Some("gadget.proto".into()),
            syntax: Some("proto2".into()),
            package: Some("acme".into()),
            message_type: vec![gadget_msg],
            extension: vec![extension_field],
            ..Default::default()
        };
        let buf = FileDescriptorSet { file: vec![file] }.encode_to_vec();

        let schema = parse_schema(&buf, "acme.Gadget").unwrap();
        let root = schema.root_descriptor().unwrap();

        let ext = root
            .get_extension(1000)
            .expect("extension field 1000 must be visible");
        assert_eq!(ext.full_name(), "acme.blade_count");
        assert_eq!(ext.kind(), Kind::Int32);
    }

    #[test]
    fn enum_named_float_not_mistaken_for_primitive() {
        // An enum whose simple name collides with a scalar type keyword.
        let float_enum = EnumDescriptorProto {
            name: Some("float".into()),
            value: vec![enum_value("FLOAT_ZERO", 0), enum_value("FLOAT_ONE", 1)],
            ..Default::default()
        };
        let msg = DescriptorProto {
            name: Some("Msg".into()),
            field: vec![enum_field("kind", 1, ".float")],
            ..Default::default()
        };
        let fds_bytes = build_fds(vec![float_enum], msg);
        let schema = parse_schema(&fds_bytes, "Msg").unwrap();
        let root = schema.root_descriptor().unwrap();

        let kind_fd = root.get_field(1).expect("field 1 must exist");
        let Kind::Enum(enum_desc) = kind_fd.kind() else {
            panic!("field typed by an enum named 'float' must have Kind::Enum");
        };
        assert!(
            enum_desc.values().count() > 0,
            "enum named 'float' must have non-empty values"
        );
    }

    #[test]
    fn empty_inputs_yield_empty_schema() {
        let schema = parse_schema(&[], "Msg").unwrap();
        assert!(schema.root_descriptor().is_none());

        let fds_bytes = build_fds(vec![], empty_message("Msg"));
        let schema = parse_schema(&fds_bytes, "").unwrap();
        assert!(schema.root_descriptor().is_none());
    }

    #[test]
    fn leading_dot_root_name_is_accepted() {
        let fds_bytes = build_fds(vec![], empty_message("Msg"));
        let schema = parse_schema(&fds_bytes, ".Msg").unwrap();
        let root = schema.root_descriptor().expect("root must resolve");
        assert_eq!(root.full_name(), "Msg");
    }

    #[test]
    fn unknown_root_reports_available_messages() {
        let fds_bytes = build_fds(vec![], empty_message("Msg"));
        let err = expect_schema_err(parse_schema(&fds_bytes, "Missing"));
        assert!(
            matches!(err, SchemaError::MessageNotFound(_)),
            "got {err:?}"
        );
        assert!(err.to_string().contains("available: Msg"), "got {err}");
    }

    #[test]
    fn undecodable_bytes_report_invalid_descriptor() {
        // 0xFF begins a varint that never terminates, so decoding must fail.
        let err = expect_schema_err(parse_schema(&[0xFF], "Msg"));
        assert!(
            matches!(err, SchemaError::InvalidDescriptor(_)),
            "got {err:?}"
        );
    }
}