prost_build/
message_graph.rs

1use std::collections::HashMap;
2
3use petgraph::algo::has_path_connecting;
4use petgraph::graph::NodeIndex;
5use petgraph::Graph;
6
7use prost_types::{
8    field_descriptor_proto::{Label, Type},
9    DescriptorProto, FileDescriptorProto,
10};
11
12/// `MessageGraph` builds a graph of messages whose edges correspond to nesting.
13/// The goal is to recognize when message types are recursively nested, so
14/// that fields can be boxed when necessary.
15pub struct MessageGraph {
16    index: HashMap<String, NodeIndex>,
17    graph: Graph<String, ()>,
18    messages: HashMap<String, DescriptorProto>,
19}
20
21impl MessageGraph {
22    pub(crate) fn new<'a>(files: impl Iterator<Item = &'a FileDescriptorProto>) -> MessageGraph {
23        let mut msg_graph = MessageGraph {
24            index: HashMap::new(),
25            graph: Graph::new(),
26            messages: HashMap::new(),
27        };
28
29        for file in files {
30            let package = format!(
31                "{}{}",
32                if file.package.is_some() { "." } else { "" },
33                file.package.as_deref().unwrap_or("")
34            );
35            for msg in &file.message_type {
36                msg_graph.add_message(&package, msg);
37            }
38        }
39
40        msg_graph
41    }
42
43    fn get_or_insert_index(&mut self, msg_name: String) -> NodeIndex {
44        assert_eq!(b'.', msg_name.as_bytes()[0]);
45        *self
46            .index
47            .entry(msg_name.clone())
48            .or_insert_with(|| self.graph.add_node(msg_name))
49    }
50
51    /// Adds message to graph IFF it contains a non-repeated field containing another message.
52    /// The purpose of the message graph is detecting recursively nested messages and co-recursively nested messages.
53    /// Because prost does not box message fields, recursively nested messages would not compile in Rust.
54    /// To allow recursive messages, the message graph is used to detect recursion and automatically box the recursive field.
55    /// Since repeated messages are already put in a Vec, boxing them isn’t necessary even if the reference is recursive.
56    fn add_message(&mut self, package: &str, msg: &DescriptorProto) {
57        let msg_name = format!("{}.{}", package, msg.name.as_ref().unwrap());
58        let msg_index = self.get_or_insert_index(msg_name.clone());
59
60        for field in &msg.field {
61            if field.r#type() == Type::Message && field.label() != Label::Repeated {
62                let field_index = self.get_or_insert_index(field.type_name.clone().unwrap());
63                self.graph.add_edge(msg_index, field_index, ());
64            }
65        }
66        self.messages.insert(msg_name.clone(), msg.clone());
67
68        for msg in &msg.nested_type {
69            self.add_message(&msg_name, msg);
70        }
71    }
72
73    /// Try get a message descriptor from current message graph
74    pub fn get_message(&self, message: &str) -> Option<&DescriptorProto> {
75        self.messages.get(message)
76    }
77
78    /// Returns true if message type `inner` is nested in message type `outer`.
79    pub fn is_nested(&self, outer: &str, inner: &str) -> bool {
80        let outer = match self.index.get(outer) {
81            Some(outer) => *outer,
82            None => return false,
83        };
84        let inner = match self.index.get(inner) {
85            Some(inner) => *inner,
86            None => return false,
87        };
88
89        has_path_connecting(&self.graph, outer, inner, None)
90    }
91}