onnx_helpers/builder/
graph.rs1use onnx_pb::{GraphProto, NodeProto, TensorProto, TypeProto, ValueInfoProto};
4
5use crate::{
6 builder::{self, Bag, Marker},
7 nodes::*,
8};
9
10#[derive(Default, Clone)]
12pub struct Graph {
13 name: String,
14 nodes: Vec<NodeProto>,
15 inputs: Vec<ValueInfoProto>,
16 outputs: Vec<ValueInfoProto>,
17 initializers: Vec<TensorProto>,
18 doc_string: Option<String>,
19 constants: i64,
20 bag: Bag,
21}
22
23impl Graph {
24 #[inline]
26 pub fn new<S: Into<String>>(name: S) -> Self {
27 Graph {
28 name: name.into(),
29 ..Graph::default()
30 }
31 }
32
33 #[inline]
35 pub fn name<S: Into<String>>(mut self, name: S) -> Self {
36 self.name = name.into();
37 self
38 }
39
40 #[inline]
42 pub fn doc_string<S: Into<String>>(mut self, doc_string: S) -> Self {
43 self.doc_string = Some(doc_string.into());
44 self
45 }
46
47 #[inline]
49 pub fn constant<S: Into<String>, T: Into<TensorProto>>(&mut self, name: S, tensor: T) -> Node {
50 let mut node: Node = ops::Constant::new(name, tensor).into();
51 node.bag = Some(self.bag.clone());
52 self.bag.node(node.inner.clone());
53 self.constants += 1;
54 node
55 }
56
57 #[inline(always)]
59 pub fn concat<I>(&mut self, axis: i64, inputs: I) -> Node
60 where
61 I: IntoIterator,
62 I::Item: Into<String>,
63 {
64 let mut node: Node = ops::Concat::new(axis, inputs).into();
65 node.bag = Some(self.bag.clone());
66 self.bag.node(node.inner.clone());
67 node
68 }
69
70 #[inline]
72 pub fn nodes<T: Into<NodeProto>>(mut self, node: T) -> Self {
73 self.nodes.push(node.into());
74 self
75 }
76
77 #[inline]
79 pub fn inputs<T: Into<ValueInfoProto>>(mut self, input: T) -> Self {
80 self.inputs.push(input.into());
81 self
82 }
83
84 #[inline]
86 pub fn outputs<T: Into<ValueInfoProto>>(mut self, output: T) -> Self {
87 self.outputs.push(output.into());
88 self
89 }
90
91 #[inline]
93 pub fn outputs_typed<T: Into<ValueInfoProto>, D: Into<TypeProto>>(
94 mut self,
95 output: T,
96 typ: D,
97 ) -> Self {
98 let mut info = output.into();
99 if info.r#type.is_none() {
100 info.r#type = Some(typ.into());
101 }
102 self.outputs.push(info);
103 self
104 }
105
106 #[inline]
108 pub fn initializer<T: Into<TensorProto>>(mut self, initializer: T) -> Self {
109 self.initializers.push(initializer.into());
110 self
111 }
112
113 #[inline]
115 pub fn node<T: Into<String>>(&mut self, name: T) -> builder::Node {
116 let mut node = builder::Node::default().name(name);
117 node.bag = Some(self.bag.clone());
118 node
119 }
120
121 #[inline]
123 pub fn input<T: Into<String>>(&mut self, name: T) -> builder::Value {
124 let mut value = builder::Value::new(name);
125 value.bag = Some(self.bag.clone());
126 value.marker = Some(Marker::Input);
127 value
128 }
129
130 #[inline]
132 pub fn output<T: Into<String>>(&mut self, name: T) -> builder::Value {
133 let mut value = builder::Value::new(name);
134 value.bag = Some(self.bag.clone());
135 value.marker = Some(Marker::Output);
136 value
137 }
138
139 #[inline]
141 pub fn model(self) -> builder::Model {
142 builder::Model::new(self.build())
143 }
144
145 #[inline]
147 pub fn build(self) -> GraphProto {
148 let mut nodes = self.nodes;
149 nodes.extend(self.bag.nodes().into_iter().map(Into::into));
150 nodes.dedup_by(|a, b| a.name == b.name);
151 sort_nodes(&mut nodes);
152 let mut inputs = self.inputs;
153 inputs.extend(self.bag.inputs().into_iter().map(Into::into));
154 inputs.dedup_by(|a, b| a.name == b.name);
155 let mut outputs = self.outputs;
156 outputs.extend(self.bag.outputs().into_iter().map(Into::into));
157 outputs.dedup_by(|a, b| a.name == b.name);
158 GraphProto {
159 name: self.name,
160 node: nodes,
161 input: inputs,
162 output: outputs,
163 doc_string: self.doc_string.unwrap_or_default(),
164 initializer: self.initializers,
165 ..GraphProto::default()
166 }
167 }
168}
169
170impl Into<GraphProto> for Graph {
171 fn into(self) -> GraphProto {
172 self.build()
173 }
174}
175
176fn sort_nodes(nodes: &mut Vec<NodeProto>) {
177 use std::collections::HashMap;
178 use std::iter::FromIterator;
179
180 use petgraph::{algo::toposort, graphmap::DiGraphMap};
181
182 let mut g = DiGraphMap::new();
183 for node in nodes.iter() {
184 for input in node.input.iter() {
185 g.add_edge(input, &node.name, 1);
186 }
187 for output in node.output.iter() {
188 g.add_edge(&node.name, output, 1);
189 }
190 }
191 let sorted: HashMap<String, usize> = HashMap::from_iter(
192 toposort(&g, None)
193 .unwrap()
194 .into_iter()
195 .enumerate()
196 .map(|(i, s)| (s.to_owned(), i)),
197 );
198 nodes.sort_by(|a, b| {
199 let a = sorted.get(&a.name).unwrap();
200 let b = sorted.get(&b.name).unwrap();
201 a.partial_cmp(b).unwrap()
202 });
203}