1use prost_reflect::{DescriptorPool, MessageDescriptor};
7
8pub struct ParsedSchema {
14 pool: DescriptorPool,
15 root_full_name: String,
16}
17
18impl ParsedSchema {
19 pub fn empty() -> Self {
21 ParsedSchema {
22 pool: DescriptorPool::new(),
23 root_full_name: String::new(),
24 }
25 }
26
27 pub fn root_descriptor(&self) -> Option<MessageDescriptor> {
30 if self.root_full_name.is_empty() {
31 None
32 } else {
33 self.pool.get_message_by_name(&self.root_full_name)
34 }
35 }
36
37 pub fn get_descriptor(&self, fqn: &str) -> Option<MessageDescriptor> {
39 self.pool.get_message_by_name(fqn)
40 }
41
42 pub fn pool(&self) -> &DescriptorPool {
44 &self.pool
45 }
46}
47
/// Errors produced while decoding a descriptor set or resolving its root message.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SchemaError {
    /// The schema bytes could not be decoded into a descriptor pool; carries the decoder's
    /// error text.
    InvalidDescriptor(String),
    /// A requested message name is absent from the pool; carries a human-readable description.
    MessageNotFound(String),
}

impl std::fmt::Display for SchemaError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SchemaError::InvalidDescriptor(msg) => write!(f, "invalid descriptor: {msg}"),
            SchemaError::MessageNotFound(msg) => write!(f, "message not found: {msg}"),
        }
    }
}

impl std::error::Error for SchemaError {}
70
71pub fn decode_pool(schema_bytes: &[u8]) -> Result<DescriptorPool, SchemaError> {
76 DescriptorPool::decode(schema_bytes).map_err(|e| SchemaError::InvalidDescriptor(e.to_string()))
77}
78
79pub fn schema_from_pool(
81 pool: DescriptorPool,
82 root_msg_name: &str,
83) -> Result<ParsedSchema, SchemaError> {
84 let root_full_name = root_msg_name.trim_start_matches('.').to_string();
85 if pool.get_message_by_name(&root_full_name).is_none() {
86 let available = pool
87 .all_messages()
88 .map(|m| m.full_name().to_string())
89 .collect::<Vec<_>>()
90 .join(", ");
91 return Err(SchemaError::MessageNotFound(format!(
92 "root message '{}' not found in schema (available: {})",
93 root_full_name, available
94 )));
95 }
96 Ok(ParsedSchema {
97 pool,
98 root_full_name,
99 })
100}
101
102pub fn parse_schema(schema_bytes: &[u8], root_msg_name: &str) -> Result<ParsedSchema, SchemaError> {
107 if schema_bytes.is_empty() || root_msg_name.is_empty() {
108 return Ok(ParsedSchema::empty());
109 }
110
111 let pool = DescriptorPool::decode(schema_bytes)
112 .map_err(|e| SchemaError::InvalidDescriptor(e.to_string()))?;
113
114 let root_full_name = root_msg_name.trim_start_matches('.').to_string();
116
117 if pool.get_message_by_name(&root_full_name).is_none() {
118 let available = pool
119 .all_messages()
120 .map(|m| m.full_name().to_string())
121 .collect::<Vec<_>>()
122 .join(", ");
123 return Err(SchemaError::MessageNotFound(format!(
124 "root message '{}' not found in schema (available: {})",
125 root_full_name, available
126 )));
127 }
128
129 Ok(ParsedSchema {
130 pool,
131 root_full_name,
132 })
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use prost::Message as ProstMessage;
    use prost_reflect::Kind;
    use prost_types::field_descriptor_proto::{Label, Type};
    use prost_types::{
        DescriptorProto, EnumDescriptorProto, EnumValueDescriptorProto, FieldDescriptorProto,
        FileDescriptorProto, FileDescriptorSet,
    };

    // `FieldDescriptorProto` stores its type and label as raw `i32` wire values; derive them
    // from prost's generated enums instead of hard-coding the numbers.
    const TYPE_ENUM: i32 = Type::Enum as i32;
    const TYPE_INT32: i32 = Type::Int32 as i32;
    const LABEL_OPTIONAL: i32 = Label::Optional as i32;

    /// Wraps `file` in a single-file `FileDescriptorSet` and returns its protobuf encoding.
    fn encode_file(file: FileDescriptorProto) -> Vec<u8> {
        let fds = FileDescriptorSet { file: vec![file] };
        let mut buf = Vec::new();
        fds.encode(&mut buf)
            .expect("encoding into a growable Vec cannot fail");
        buf
    }

    /// Encodes a package-less proto2 file `test.proto` holding `enums` and one `message`.
    fn build_fds(enums: Vec<EnumDescriptorProto>, message: DescriptorProto) -> Vec<u8> {
        encode_file(FileDescriptorProto {
            name: Some("test.proto".into()),
            syntax: Some("proto2".into()),
            enum_type: enums,
            message_type: vec![message],
            ..Default::default()
        })
    }

    /// Encodes a file that declares only an empty message named `Msg`.
    fn single_empty_msg_fds() -> Vec<u8> {
        build_fds(
            Vec::new(),
            DescriptorProto {
                name: Some("Msg".into()),
                ..Default::default()
            },
        )
    }

    fn enum_value(name: &str, number: i32) -> EnumValueDescriptorProto {
        EnumValueDescriptorProto {
            name: Some(name.into()),
            number: Some(number),
            ..Default::default()
        }
    }

    fn enum_field(name: &str, number: i32, type_name: &str) -> FieldDescriptorProto {
        FieldDescriptorProto {
            name: Some(name.into()),
            number: Some(number),
            r#type: Some(TYPE_ENUM),
            label: Some(LABEL_OPTIONAL),
            type_name: Some(type_name.into()),
            ..Default::default()
        }
    }

    fn int32_field(name: &str, number: i32) -> FieldDescriptorProto {
        FieldDescriptorProto {
            name: Some(name.into()),
            number: Some(number),
            r#type: Some(TYPE_INT32),
            label: Some(LABEL_OPTIONAL),
            ..Default::default()
        }
    }

    /// An enum-typed field and a scalar field in the same message both resolve to the
    /// expected `Kind`, with the enum's values intact.
    #[test]
    fn two_pass_enum_collection() {
        let color_enum = EnumDescriptorProto {
            name: Some("Color".into()),
            value: vec![
                enum_value("RED", 0),
                enum_value("GREEN", 1),
                enum_value("BLUE", 2),
            ],
            ..Default::default()
        };
        let msg = DescriptorProto {
            name: Some("Msg".into()),
            field: vec![enum_field("color", 1, ".Color"), int32_field("id", 2)],
            ..Default::default()
        };
        let fds_bytes = build_fds(vec![color_enum], msg);
        let schema = parse_schema(&fds_bytes, "Msg").unwrap();
        let root = schema.root_descriptor().unwrap();

        let color_fd = root.get_field(1).expect("field 1 must exist");
        let Kind::Enum(enum_desc) = color_fd.kind() else {
            panic!("field 1 must be an enum");
        };
        let names: Vec<String> = enum_desc.values().map(|v| v.name().to_owned()).collect();
        assert_eq!(
            names,
            &["RED", "GREEN", "BLUE"],
            "enum values must be present"
        );

        let id_fd = root.get_field(2).expect("field 2 must exist");
        assert_eq!(id_fd.kind(), Kind::Int32, "field 2 must be int32");
    }

    /// A file-level extension of the root message is reachable through `get_extension`.
    #[test]
    fn extension_field_visible_via_get_extension() {
        use prost_types::descriptor_proto::ExtensionRange;

        let extension_field = FieldDescriptorProto {
            name: Some("blade_count".into()),
            number: Some(1000),
            r#type: Some(TYPE_INT32),
            label: Some(LABEL_OPTIONAL),
            extendee: Some(".acme.Gadget".into()),
            ..Default::default()
        };

        let gadget_msg = DescriptorProto {
            name: Some("Gadget".into()),
            extension_range: vec![ExtensionRange {
                start: Some(1000),
                end: Some(2000),
                ..Default::default()
            }],
            ..Default::default()
        };

        let buf = encode_file(FileDescriptorProto {
            name: Some("gadget.proto".into()),
            syntax: Some("proto2".into()),
            package: Some("acme".into()),
            message_type: vec![gadget_msg],
            extension: vec![extension_field],
            ..Default::default()
        });

        let schema = parse_schema(&buf, "acme.Gadget").unwrap();
        let root = schema.root_descriptor().unwrap();

        let ext = root
            .get_extension(1000)
            .expect("extension field 1000 must be visible");
        assert_eq!(ext.full_name(), "acme.blade_count");
        assert_eq!(ext.kind(), Kind::Int32);
    }

    /// An enum whose *type name* is `float` must not be confused with the scalar `float` type.
    #[test]
    fn enum_named_float_not_mistaken_for_primitive() {
        let float_enum = EnumDescriptorProto {
            name: Some("float".into()),
            value: vec![enum_value("FLOAT_ZERO", 0), enum_value("FLOAT_ONE", 1)],
            ..Default::default()
        };
        let msg = DescriptorProto {
            name: Some("Msg".into()),
            field: vec![enum_field("kind", 1, ".float")],
            ..Default::default()
        };
        let fds_bytes = build_fds(vec![float_enum], msg);
        let schema = parse_schema(&fds_bytes, "Msg").unwrap();
        let root = schema.root_descriptor().unwrap();

        let kind_fd = root.get_field(1).expect("field 1 must exist");
        let Kind::Enum(enum_desc) = kind_fd.kind() else {
            panic!(
                "field typed by the enum '.float' must have Kind::Enum, got {:?}",
                kind_fd.kind()
            );
        };
        assert_eq!(enum_desc.full_name(), "float");
        assert!(
            enum_desc.values().count() > 0,
            "enum named 'float' must have non-empty values"
        );
    }

    /// Empty schema bytes or an empty root name short-circuit to an empty schema.
    #[test]
    fn empty_inputs_yield_empty_schema() {
        let no_bytes = parse_schema(&[], "Msg").expect("empty bytes must not error");
        assert!(no_bytes.root_descriptor().is_none());

        let no_root = parse_schema(&single_empty_msg_fds(), "")
            .expect("empty root name must not error");
        assert!(no_root.root_descriptor().is_none());
    }

    /// A fully-qualified root name with a leading `.` resolves to the same message.
    #[test]
    fn leading_dot_in_root_name_is_stripped() {
        let schema = parse_schema(&single_empty_msg_fds(), ".Msg").unwrap();
        let root = schema
            .root_descriptor()
            .expect("'.Msg' must resolve to 'Msg'");
        assert_eq!(root.full_name(), "Msg");
    }

    /// An unknown root name yields `MessageNotFound` listing the messages that do exist.
    #[test]
    fn missing_root_message_reports_available_names() {
        let Err(SchemaError::MessageNotFound(msg)) =
            parse_schema(&single_empty_msg_fds(), "Nope")
        else {
            panic!("an unknown root message must yield SchemaError::MessageNotFound");
        };
        assert_eq!(
            msg,
            "root message 'Nope' not found in schema (available: Msg)"
        );
    }

    /// Bytes that are not a valid encoded descriptor set yield `InvalidDescriptor`.
    #[test]
    fn undecodable_bytes_are_invalid_descriptor() {
        // 0xFF starts a varint whose continuation byte never arrives, so decoding must fail.
        let result = parse_schema(&[0xFF], "Msg");
        assert!(matches!(result, Err(SchemaError::InvalidDescriptor(_))));
    }
}