1use std::collections::HashMap;
7use std::sync::Arc;
8
9use prost::Message as ProstMessage;
10use prost_types::{field_descriptor_proto::Type, FileDescriptorProto, FileDescriptorSet};
11
/// Enum value tables keyed by fully-qualified enum name (leading `.`).
///
/// Each entry lists the enum's `(number, value name)` pairs; `collect_enum_types`
/// stores them sorted by number.
type EnumValueMap = std::collections::HashMap<String, Vec<(i32, Box<str>)>>;
16
/// Metadata for one field of a message, derived from its `FieldDescriptorProto`.
#[derive(Debug, Clone)]
pub struct FieldInfo {
    /// Field name as declared in the descriptor (empty if the descriptor omits it).
    pub name: String,
    /// Raw field type; compare against the constants in [`proto_type`]. `0` if unknown.
    pub proto_type: i32,
    /// Raw field label; compare against the constants in [`proto_label`]. `0` if unset.
    pub label: i32,
    /// `true` only when the field's options explicitly set `packed = true`.
    pub is_packed: bool,
    /// For MESSAGE/GROUP fields, the referenced type name as written in the
    /// descriptor's `type_name`; `None` for every other type.
    pub nested_type_name: Option<String>,
    /// For ENUM fields, the referenced enum name as written in the descriptor's
    /// `type_name`; `None` for every other type.
    pub enum_type_name: Option<String>,
    /// Last `.`-separated segment of the descriptor's `type_name`, whenever one is set.
    pub type_display_name: Option<String>,
    /// `(number, name)` pairs of the enum's values, sorted by number. Empty for
    /// non-enum fields and for enum fields whose type was not found in the schema.
    pub enum_values: Box<[(i32, Box<str>)]>,
}
42
/// Field table for one message type.
#[derive(Debug)]
pub struct MessageSchema {
    /// The message's simple (unqualified) name as declared in its descriptor.
    pub name: String,
    /// Field metadata keyed by field number.
    pub fields: HashMap<u32, FieldInfo>,
}
51
/// Result of [`parse_schema`]: every message found in a descriptor plus the
/// selected root message.
pub struct ParsedSchema {
    /// Fully-qualified name of the root message. `parse_schema` always stores it
    /// with a leading `.` (e.g. `.pkg.Msg`). Empty when no root was selected.
    pub root_type_name: String,
    /// Every message in the descriptor, keyed by fully-qualified name (leading `.`).
    pub messages: HashMap<String, Arc<MessageSchema>>,
    /// The same entries as `messages` (sharing the same `Arc<MessageSchema>`
    /// values) behind an `Arc`, so the whole table can be shared by `Arc::clone`.
    pub all_schemas: Arc<HashMap<String, Arc<MessageSchema>>>,
}
65
66impl ParsedSchema {
67 pub fn empty() -> Self {
71 ParsedSchema {
72 root_type_name: String::new(),
73 messages: HashMap::new(),
74 all_schemas: Arc::new(HashMap::new()),
75 }
76 }
77
78 pub fn root_schema(&self) -> Option<Arc<MessageSchema>> {
81 if self.root_type_name.is_empty() {
82 None
83 } else {
84 self.messages.get(&self.root_type_name).cloned()
85 }
86 }
87}
88
/// Errors returned by [`parse_schema`].
#[non_exhaustive]
#[derive(Debug)]
pub enum SchemaError {
    /// The input bytes decoded as neither a `FileDescriptorSet` nor a
    /// `FileDescriptorProto`; the payload describes the problem.
    InvalidDescriptor(String),
    /// The requested root message is not defined in the schema; the payload
    /// names it and lists the messages that are available.
    MessageNotFound(String),
}

impl std::fmt::Display for SchemaError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Render the message first, then emit it in one write.
        let rendered = match self {
            Self::InvalidDescriptor(msg) => format!("invalid descriptor: {msg}"),
            Self::MessageNotFound(msg) => format!("message not found: {msg}"),
        };
        f.write_str(&rendered)
    }
}

/// `SchemaError` wraps no underlying cause, so the default `source()` (`None`) applies.
impl std::error::Error for SchemaError {}
112
113pub fn parse_schema(schema_bytes: &[u8], root_msg_name: &str) -> Result<ParsedSchema, SchemaError> {
123 if schema_bytes.is_empty() || root_msg_name.is_empty() {
124 return Ok(ParsedSchema {
125 root_type_name: String::new(),
126 messages: HashMap::new(),
127 all_schemas: Arc::new(HashMap::new()), });
129 }
130
131 let files: Vec<FileDescriptorProto> = if let Ok(fds) = FileDescriptorSet::decode(schema_bytes) {
133 fds.file
134 } else if let Ok(fdp) = FileDescriptorProto::decode(schema_bytes) {
135 vec![fdp]
136 } else {
137 return Err(SchemaError::InvalidDescriptor(
138 "schema_bytes is neither a valid FileDescriptorSet nor FileDescriptorProto".into(),
139 ));
140 };
141
142 let mut raw: HashMap<String, prost_types::DescriptorProto> = HashMap::new();
145 for file in &files {
146 let pkg = file.package.as_deref().unwrap_or("");
147 collect_message_types(pkg, &file.message_type, &mut raw);
148 }
149
150 let mut enum_map: EnumValueMap = HashMap::new();
153 for file in &files {
154 let pkg = file.package.as_deref().unwrap_or("");
155 collect_enum_types(pkg, &file.enum_type, &mut enum_map);
156 collect_nested_enum_types(pkg, &file.message_type, &mut enum_map);
158 }
159
160 let mut messages: HashMap<String, Arc<MessageSchema>> = HashMap::new();
162 for (fqn, dp) in &raw {
163 let schema = build_message_schema(dp, &raw, &enum_map);
164 messages.insert(fqn.clone(), Arc::new(schema));
165 }
166
167 let root_type_name = if root_msg_name.starts_with('.') {
169 root_msg_name.to_string()
170 } else {
171 format!(".{}", root_msg_name)
172 };
173
174 if !messages.contains_key(&root_type_name) {
175 return Err(SchemaError::MessageNotFound(format!(
176 "root message '{}' not found in schema (available: {})",
177 root_type_name,
178 messages.keys().cloned().collect::<Vec<_>>().join(", ")
179 )));
180 }
181
182 let all_schemas = Arc::new(
185 messages
186 .iter()
187 .map(|(k, v)| (k.clone(), Arc::clone(v)))
188 .collect::<HashMap<_, _>>(),
189 );
190
191 Ok(ParsedSchema {
192 root_type_name,
193 messages,
194 all_schemas,
195 })
196}
197
198fn collect_message_types(
203 parent_prefix: &str,
204 descriptors: &[prost_types::DescriptorProto],
205 out: &mut HashMap<String, prost_types::DescriptorProto>,
206) {
207 for dp in descriptors {
208 let name = dp.name.as_deref().unwrap_or("");
209 let fqn = if parent_prefix.is_empty() {
210 format!(".{}", name)
211 } else {
212 format!(".{}.{}", parent_prefix, name)
213 };
214 let nested_prefix = if parent_prefix.is_empty() {
216 name.to_string()
217 } else {
218 format!("{}.{}", parent_prefix, name)
219 };
220 collect_message_types(&nested_prefix, &dp.nested_type, out);
221 out.insert(fqn, dp.clone());
222 }
223}
224
225fn build_message_schema(
227 dp: &prost_types::DescriptorProto,
228 _all: &HashMap<String, prost_types::DescriptorProto>,
229 enum_map: &EnumValueMap,
230) -> MessageSchema {
231 let mut fields = HashMap::new();
232 for fdp in &dp.field {
233 let number = match fdp.number {
234 Some(n) => n as u32,
235 None => continue,
236 };
237 let proto_type = fdp.r#type.unwrap_or(0);
238 let label = fdp.label.unwrap_or(0);
239
240 let is_packed = fdp.options.as_ref().and_then(|o| o.packed).unwrap_or(false);
244
245 let nested_type_name = fdp.type_name.clone();
253
254 let type_display_name = nested_type_name
256 .as_ref()
257 .map(|tn| tn.rsplit('.').next().unwrap_or(tn).to_string());
258
259 let enum_type_name = if proto_type == Type::Enum as i32 {
261 nested_type_name.clone()
262 } else {
263 None
264 };
265
266 let enum_values: Box<[(i32, Box<str>)]> = if proto_type == Type::Enum as i32 {
268 if let Some(etn) = &enum_type_name {
269 if let Some(vals) = enum_map.get(etn.as_str()) {
270 vals.iter()
271 .map(|(n, s)| (*n, s.clone()))
272 .collect::<Vec<_>>()
273 .into_boxed_slice()
274 } else {
275 Box::default()
276 }
277 } else {
278 Box::default()
279 }
280 } else {
281 Box::default()
282 };
283
284 let fi = FieldInfo {
285 name: fdp.name.clone().unwrap_or_default(),
286 proto_type,
287 label,
288 is_packed,
289 nested_type_name: if proto_type == Type::Message as i32
290 || proto_type == Type::Group as i32
291 {
292 nested_type_name
293 } else {
294 None
295 },
296 enum_type_name,
297 type_display_name,
298 enum_values,
299 };
300 fields.insert(number, fi);
301 }
302
303 MessageSchema {
304 name: dp.name.clone().unwrap_or_default(),
305 fields,
306 }
307}
308
309fn collect_enum_types(
311 parent_prefix: &str,
312 enums: &[prost_types::EnumDescriptorProto],
313 out: &mut EnumValueMap,
314) {
315 for edp in enums {
316 let name = edp.name.as_deref().unwrap_or("");
317 let fqn = if parent_prefix.is_empty() {
318 format!(".{}", name)
319 } else {
320 format!(".{}.{}", parent_prefix, name)
321 };
322 let mut vals: Vec<(i32, Box<str>)> = edp
323 .value
324 .iter()
325 .filter_map(|vdp| {
326 let n = vdp.number?;
327 let s: Box<str> = vdp.name.as_deref().unwrap_or("").into();
328 Some((n, s))
329 })
330 .collect();
331 vals.sort_by_key(|(n, _)| *n);
332 out.insert(fqn, vals);
333 }
334}
335
336fn collect_nested_enum_types(
338 parent_prefix: &str,
339 descriptors: &[prost_types::DescriptorProto],
340 out: &mut EnumValueMap,
341) {
342 for dp in descriptors {
343 let name = dp.name.as_deref().unwrap_or("");
344 let prefix = if parent_prefix.is_empty() {
345 name.to_string()
346 } else {
347 format!("{}.{}", parent_prefix, name)
348 };
349 collect_enum_types(&prefix, &dp.enum_type, out);
350 collect_nested_enum_types(&prefix, &dp.nested_type, out);
351 }
352}
353
/// Numeric values of `FieldDescriptorProto.Type` from
/// `google/protobuf/descriptor.proto`, for comparing against
/// [`FieldInfo::proto_type`].
pub mod proto_type {
    pub const DOUBLE: i32 = 1;
    pub const FLOAT: i32 = 2;
    pub const INT64: i32 = 3;
    pub const UINT64: i32 = 4;
    pub const INT32: i32 = 5;
    pub const FIXED64: i32 = 6;
    pub const FIXED32: i32 = 7;
    pub const BOOL: i32 = 8;
    pub const STRING: i32 = 9;
    // proto2 group (deprecated in the protobuf language).
    pub const GROUP: i32 = 10;
    pub const MESSAGE: i32 = 11;
    pub const BYTES: i32 = 12;
    pub const UINT32: i32 = 13;
    pub const ENUM: i32 = 14;
    pub const SFIXED32: i32 = 15;
    pub const SFIXED64: i32 = 16;
    // ZigZag-encoded signed varints.
    pub const SINT32: i32 = 17;
    pub const SINT64: i32 = 18;
}
376
/// Numeric values of `FieldDescriptorProto.Label` from
/// `google/protobuf/descriptor.proto`, for comparing against
/// [`FieldInfo::label`].
pub mod proto_label {
    pub const OPTIONAL: i32 = 1;
    pub const REQUIRED: i32 = 2;
    pub const REPEATED: i32 = 3;
}
382
#[cfg(test)]
mod tests {
    use super::*;
    use prost::Message as ProstMessage;
    use prost_types::{
        DescriptorProto, EnumDescriptorProto, EnumValueDescriptorProto, FieldDescriptorProto,
        FileDescriptorProto, FileDescriptorSet,
    };

    /// Encodes a one-file `FileDescriptorSet` (proto2, no package) holding
    /// `enums` at file scope and `message` as its only top-level message.
    fn build_fds(enums: Vec<EnumDescriptorProto>, message: DescriptorProto) -> Vec<u8> {
        let file = FileDescriptorProto {
            name: Some("test.proto".into()),
            syntax: Some("proto2".into()),
            enum_type: enums,
            message_type: vec![message],
            ..Default::default()
        };
        let fds = FileDescriptorSet { file: vec![file] };
        let mut buf = Vec::new();
        fds.encode(&mut buf).unwrap();
        buf
    }

    fn enum_value(name: &str, number: i32) -> EnumValueDescriptorProto {
        EnumValueDescriptorProto {
            name: Some(name.into()),
            number: Some(number),
            ..Default::default()
        }
    }

    fn enum_field(name: &str, number: i32, type_name: &str) -> FieldDescriptorProto {
        FieldDescriptorProto {
            name: Some(name.into()),
            number: Some(number),
            r#type: Some(proto_type::ENUM),
            label: Some(proto_label::OPTIONAL),
            type_name: Some(type_name.into()),
            ..Default::default()
        }
    }

    fn int32_field(name: &str, number: i32) -> FieldDescriptorProto {
        FieldDescriptorProto {
            name: Some(name.into()),
            number: Some(number),
            r#type: Some(proto_type::INT32),
            label: Some(proto_label::OPTIONAL),
            ..Default::default()
        }
    }

    /// A message with no fields, used where only the message's presence matters.
    fn empty_message(name: &str) -> DescriptorProto {
        DescriptorProto {
            name: Some(name.into()),
            ..Default::default()
        }
    }

    /// All enums are collected before any message is built, so enum fields get
    /// their value tables; non-enum fields get none.
    #[test]
    fn two_pass_enum_collection() {
        let color_enum = EnumDescriptorProto {
            name: Some("Color".into()),
            value: vec![
                enum_value("RED", 0),
                enum_value("GREEN", 1),
                enum_value("BLUE", 2),
            ],
            ..Default::default()
        };
        let msg = DescriptorProto {
            name: Some("Msg".into()),
            field: vec![enum_field("color", 1, ".Color"), int32_field("id", 2)],
            ..Default::default()
        };
        let fds_bytes = build_fds(vec![color_enum], msg);
        let schema = parse_schema(&fds_bytes, "Msg").unwrap();
        let root = schema.root_schema().unwrap();

        let color_fi = root.fields.get(&1).expect("field 1 must exist");
        assert_eq!(
            color_fi.enum_values.as_ref(),
            &[(0, "RED".into()), (1, "GREEN".into()), (2, "BLUE".into())],
            "enum_values must be sorted by numeric value"
        );

        let id_fi = root.fields.get(&2).expect("field 2 must exist");
        assert!(
            id_fi.enum_values.is_empty(),
            "non-enum field must have empty enum_values"
        );
    }

    /// The field's declared type decides its kind, not its type name: an enum
    /// called `float` must stay an enum.
    #[test]
    fn enum_named_float_not_mistaken_for_primitive() {
        let float_enum = EnumDescriptorProto {
            name: Some("float".into()),
            value: vec![enum_value("FLOAT_ZERO", 0), enum_value("FLOAT_ONE", 1)],
            ..Default::default()
        };
        let msg = DescriptorProto {
            name: Some("Msg".into()),
            field: vec![enum_field("kind", 1, ".float")],
            ..Default::default()
        };
        let fds_bytes = build_fds(vec![float_enum], msg);
        let schema = parse_schema(&fds_bytes, "Msg").unwrap();
        let root = schema.root_schema().unwrap();

        let kind_fi = root.fields.get(&1).expect("field 1 must exist");
        assert_eq!(
            kind_fi.proto_type,
            proto_type::ENUM,
            "field named 'float' backed by an enum must have proto_type=ENUM"
        );
        assert!(
            !kind_fi.enum_values.is_empty(),
            "enum named 'float' must have non-empty enum_values"
        );
    }

    /// Enums declared inside a message are registered as `.Outer.Inner` and
    /// resolved (sorted by number) for fields that reference them.
    #[test]
    fn nested_enum_is_resolved() {
        let kind_enum = EnumDescriptorProto {
            name: Some("Kind".into()),
            value: vec![enum_value("B", 1), enum_value("A", 0)],
            ..Default::default()
        };
        let msg = DescriptorProto {
            name: Some("Outer".into()),
            field: vec![enum_field("kind", 1, ".Outer.Kind")],
            enum_type: vec![kind_enum],
            ..Default::default()
        };
        let schema = parse_schema(&build_fds(vec![], msg), "Outer").unwrap();
        let root = schema.root_schema().unwrap();

        let kind_fi = root.fields.get(&1).expect("field 1 must exist");
        assert_eq!(kind_fi.enum_type_name.as_deref(), Some(".Outer.Kind"));
        assert_eq!(kind_fi.type_display_name.as_deref(), Some("Kind"));
        assert!(kind_fi.nested_type_name.is_none());
        assert_eq!(
            kind_fi.enum_values.as_ref(),
            &[(0, "A".into()), (1, "B".into())]
        );
    }

    /// Nested messages are indexed by fully-qualified name and can serve as root;
    /// message-typed fields record their nested type name.
    #[test]
    fn nested_message_is_indexed_by_fqn() {
        let inner = DescriptorProto {
            name: Some("Inner".into()),
            field: vec![int32_field("x", 1)],
            ..Default::default()
        };
        let outer = DescriptorProto {
            name: Some("Outer".into()),
            field: vec![FieldDescriptorProto {
                name: Some("inner".into()),
                number: Some(1),
                r#type: Some(proto_type::MESSAGE),
                label: Some(proto_label::OPTIONAL),
                type_name: Some(".Outer.Inner".into()),
                ..Default::default()
            }],
            nested_type: vec![inner],
            ..Default::default()
        };
        let schema = parse_schema(&build_fds(vec![], outer), "Outer.Inner").unwrap();

        assert_eq!(schema.root_type_name, ".Outer.Inner");
        let root = schema.root_schema().unwrap();
        assert_eq!(root.name, "Inner");
        assert!(root.fields.contains_key(&1));

        let inner_field = &schema.messages[".Outer"].fields[&1];
        assert_eq!(inner_field.nested_type_name.as_deref(), Some(".Outer.Inner"));
        assert!(inner_field.enum_type_name.is_none());
        assert_eq!(schema.all_schemas.len(), schema.messages.len());
    }

    /// `Msg` and `.Msg` name the same root.
    #[test]
    fn leading_dot_in_root_name_is_optional() {
        let fds_bytes = build_fds(vec![], empty_message("Msg"));
        assert_eq!(parse_schema(&fds_bytes, "Msg").unwrap().root_type_name, ".Msg");
        assert_eq!(parse_schema(&fds_bytes, ".Msg").unwrap().root_type_name, ".Msg");
    }

    #[test]
    fn missing_root_message_is_an_error() {
        let fds_bytes = build_fds(vec![], empty_message("Msg"));
        assert!(matches!(
            parse_schema(&fds_bytes, "Other"),
            Err(SchemaError::MessageNotFound(_))
        ));
    }

    /// Empty bytes or an empty root name short-circuit to an empty schema.
    #[test]
    fn empty_input_yields_empty_schema() {
        let schema = parse_schema(&[], "Msg").unwrap();
        assert!(schema.root_type_name.is_empty());
        assert!(schema.messages.is_empty());
        assert!(schema.root_schema().is_none());

        let fds_bytes = build_fds(vec![], empty_message("Msg"));
        assert!(parse_schema(&fds_bytes, "").unwrap().messages.is_empty());
    }

    #[test]
    fn undecodable_bytes_are_rejected() {
        assert!(matches!(
            parse_schema(&[0xff, 0xff], "Msg"),
            Err(SchemaError::InvalidDescriptor(_))
        ));
    }
}