prost_build/
message_graph.rs1use std::collections::HashMap;
2
3use petgraph::algo::has_path_connecting;
4use petgraph::graph::NodeIndex;
5use petgraph::Graph;
6
7use prost_types::{
8 field_descriptor_proto::{Label, Type},
9 DescriptorProto, FileDescriptorProto,
10};
11
12pub struct MessageGraph {
16 index: HashMap<String, NodeIndex>,
17 graph: Graph<String, ()>,
18 messages: HashMap<String, DescriptorProto>,
19}
20
21impl MessageGraph {
22 pub(crate) fn new<'a>(files: impl Iterator<Item = &'a FileDescriptorProto>) -> MessageGraph {
23 let mut msg_graph = MessageGraph {
24 index: HashMap::new(),
25 graph: Graph::new(),
26 messages: HashMap::new(),
27 };
28
29 for file in files {
30 let package = format!(
31 "{}{}",
32 if file.package.is_some() { "." } else { "" },
33 file.package.as_deref().unwrap_or("")
34 );
35 for msg in &file.message_type {
36 msg_graph.add_message(&package, msg);
37 }
38 }
39
40 msg_graph
41 }
42
43 fn get_or_insert_index(&mut self, msg_name: String) -> NodeIndex {
44 assert_eq!(b'.', msg_name.as_bytes()[0]);
45 *self
46 .index
47 .entry(msg_name.clone())
48 .or_insert_with(|| self.graph.add_node(msg_name))
49 }
50
51 fn add_message(&mut self, package: &str, msg: &DescriptorProto) {
57 let msg_name = format!("{}.{}", package, msg.name.as_ref().unwrap());
58 let msg_index = self.get_or_insert_index(msg_name.clone());
59
60 for field in &msg.field {
61 if field.r#type() == Type::Message && field.label() != Label::Repeated {
62 let field_index = self.get_or_insert_index(field.type_name.clone().unwrap());
63 self.graph.add_edge(msg_index, field_index, ());
64 }
65 }
66 self.messages.insert(msg_name.clone(), msg.clone());
67
68 for msg in &msg.nested_type {
69 self.add_message(&msg_name, msg);
70 }
71 }
72
73 pub fn get_message(&self, message: &str) -> Option<&DescriptorProto> {
75 self.messages.get(message)
76 }
77
78 pub fn is_nested(&self, outer: &str, inner: &str) -> bool {
80 let outer = match self.index.get(outer) {
81 Some(outer) => *outer,
82 None => return false,
83 };
84 let inner = match self.index.get(inner) {
85 Some(inner) => *inner,
86 None => return false,
87 };
88
89 has_path_connecting(&self.graph, outer, inner, None)
90 }
91}