// Source file: onnx_helpers/builder/graph.rs

//! Graph builder: assembles an ONNX `GraphProto` from nodes, inputs, outputs and initializers.
2
3use onnx_pb::{GraphProto, NodeProto, TensorProto, TypeProto, ValueInfoProto};
4
5use crate::{
6    builder::{self, Bag, Marker},
7    nodes::*,
8};
9
10/// Graph builder.
11#[derive(Default, Clone)]
12pub struct Graph {
13    name: String,
14    nodes: Vec<NodeProto>,
15    inputs: Vec<ValueInfoProto>,
16    outputs: Vec<ValueInfoProto>,
17    initializers: Vec<TensorProto>,
18    doc_string: Option<String>,
19    constants: i64,
20    bag: Bag,
21}
22
23impl Graph {
24    /// Creates a new builder.
25    #[inline]
26    pub fn new<S: Into<String>>(name: S) -> Self {
27        Graph {
28            name: name.into(),
29            ..Graph::default()
30        }
31    }
32
33    /// Sets graph name.
34    #[inline]
35    pub fn name<S: Into<String>>(mut self, name: S) -> Self {
36        self.name = name.into();
37        self
38    }
39
40    /// Sets graph doc_string.
41    #[inline]
42    pub fn doc_string<S: Into<String>>(mut self, doc_string: S) -> Self {
43        self.doc_string = Some(doc_string.into());
44        self
45    }
46
47    /// Creates constant node in a graph.
48    #[inline]
49    pub fn constant<S: Into<String>, T: Into<TensorProto>>(&mut self, name: S, tensor: T) -> Node {
50        let mut node: Node = ops::Constant::new(name, tensor).into();
51        node.bag = Some(self.bag.clone());
52        self.bag.node(node.inner.clone());
53        self.constants += 1;
54        node
55    }
56
57    /// Creates a concat node in a graph.
58    #[inline(always)]
59    pub fn concat<I>(&mut self, axis: i64, inputs: I) -> Node
60    where
61        I: IntoIterator,
62        I::Item: Into<String>,
63    {
64        let mut node: Node = ops::Concat::new(axis, inputs).into();
65        node.bag = Some(self.bag.clone());
66        self.bag.node(node.inner.clone());
67        node
68    }
69
70    /// Inserts graph nodes.
71    #[inline]
72    pub fn nodes<T: Into<NodeProto>>(mut self, node: T) -> Self {
73        self.nodes.push(node.into());
74        self
75    }
76
77    /// Inserts graph inputs.
78    #[inline]
79    pub fn inputs<T: Into<ValueInfoProto>>(mut self, input: T) -> Self {
80        self.inputs.push(input.into());
81        self
82    }
83
84    /// Inserts graph outputs.
85    #[inline]
86    pub fn outputs<T: Into<ValueInfoProto>>(mut self, output: T) -> Self {
87        self.outputs.push(output.into());
88        self
89    }
90
91    /// Inserts typed graph outputs.
92    #[inline]
93    pub fn outputs_typed<T: Into<ValueInfoProto>, D: Into<TypeProto>>(
94        mut self,
95        output: T,
96        typ: D,
97    ) -> Self {
98        let mut info = output.into();
99        if info.r#type.is_none() {
100            info.r#type = Some(typ.into());
101        }
102        self.outputs.push(info);
103        self
104    }
105
106    /// Inserts graph initializers.
107    #[inline]
108    pub fn initializer<T: Into<TensorProto>>(mut self, initializer: T) -> Self {
109        self.initializers.push(initializer.into());
110        self
111    }
112
113    /// Creates graph node builder.
114    #[inline]
115    pub fn node<T: Into<String>>(&mut self, name: T) -> builder::Node {
116        let mut node = builder::Node::default().name(name);
117        node.bag = Some(self.bag.clone());
118        node
119    }
120
121    /// Creates graph input builder.
122    #[inline]
123    pub fn input<T: Into<String>>(&mut self, name: T) -> builder::Value {
124        let mut value = builder::Value::new(name);
125        value.bag = Some(self.bag.clone());
126        value.marker = Some(Marker::Input);
127        value
128    }
129
130    /// Creates graph output builder.
131    #[inline]
132    pub fn output<T: Into<String>>(&mut self, name: T) -> builder::Value {
133        let mut value = builder::Value::new(name);
134        value.bag = Some(self.bag.clone());
135        value.marker = Some(Marker::Output);
136        value
137    }
138
139    /// Builds a model builder from graph.
140    #[inline]
141    pub fn model(self) -> builder::Model {
142        builder::Model::new(self.build())
143    }
144
145    /// Builds the graph.
146    #[inline]
147    pub fn build(self) -> GraphProto {
148        let mut nodes = self.nodes;
149        nodes.extend(self.bag.nodes().into_iter().map(Into::into));
150        nodes.dedup_by(|a, b| a.name == b.name);
151        sort_nodes(&mut nodes);
152        let mut inputs = self.inputs;
153        inputs.extend(self.bag.inputs().into_iter().map(Into::into));
154        inputs.dedup_by(|a, b| a.name == b.name);
155        let mut outputs = self.outputs;
156        outputs.extend(self.bag.outputs().into_iter().map(Into::into));
157        outputs.dedup_by(|a, b| a.name == b.name);
158        GraphProto {
159            name: self.name,
160            node: nodes,
161            input: inputs,
162            output: outputs,
163            doc_string: self.doc_string.unwrap_or_default(),
164            initializer: self.initializers,
165            ..GraphProto::default()
166        }
167    }
168}
169
170impl Into<GraphProto> for Graph {
171    fn into(self) -> GraphProto {
172        self.build()
173    }
174}
175
176fn sort_nodes(nodes: &mut Vec<NodeProto>) {
177    use std::collections::HashMap;
178    use std::iter::FromIterator;
179
180    use petgraph::{algo::toposort, graphmap::DiGraphMap};
181
182    let mut g = DiGraphMap::new();
183    for node in nodes.iter() {
184        for input in node.input.iter() {
185            g.add_edge(input, &node.name, 1);
186        }
187        for output in node.output.iter() {
188            g.add_edge(&node.name, output, 1);
189        }
190    }
191    let sorted: HashMap<String, usize> = HashMap::from_iter(
192        toposort(&g, None)
193            .unwrap()
194            .into_iter()
195            .enumerate()
196            .map(|(i, s)| (s.to_owned(), i)),
197    );
198    nodes.sort_by(|a, b| {
199        let a = sorted.get(&a.name).unwrap();
200        let b = sorted.get(&b.name).unwrap();
201        a.partial_cmp(b).unwrap()
202    });
203}