1use crate::autograd::{OpKind, Tensor};
6use crate::llo::{array_to_onnx_tensor, OnnxAttribute, OnnxModel, OnnxNode, OnnxTensor};
7use anyhow::{anyhow, Result};
8use std::collections::{HashMap, HashSet};
9
10struct GraphExporter {
12 model: OnnxModel,
13 _visited: HashSet<usize>, node_outputs: HashMap<usize, String>, tensor_names: HashMap<usize, String>, name_counter: usize,
17}
18
19impl GraphExporter {
20 fn new(name: &str) -> Self {
21 GraphExporter {
22 model: OnnxModel::new(name),
23 _visited: HashSet::new(),
24 node_outputs: HashMap::new(),
25 tensor_names: HashMap::new(),
26 name_counter: 0,
27 }
28 }
29
30 fn get_tensor_name(&mut self, tensor: &Tensor) -> String {
31 let id = tensor.data.data.as_ptr() as usize;
33
34 if let Some(name) = self.tensor_names.get(&id) {
35 return name.clone();
36 }
37
38 let name = format!("tensor_{}", self.name_counter);
39 self.name_counter += 1;
40 self.tensor_names.insert(id, name.clone());
41 name
42 }
43
44 fn traverse(&mut self, tensor: &Tensor) -> Result<String> {
45 let output_name = self.get_tensor_name(tensor);
46
47 if tensor.compute_node.is_none() {
49 if self
50 .model
51 .graph
52 .inputs
53 .iter()
54 .any(|i| i.name == output_name)
55 || self
56 .model
57 .graph
58 .initializers
59 .iter()
60 .any(|i| i.name == output_name)
61 {
62 return Ok(output_name);
63 }
64
65 if tensor.requires_grad {
66 let onnx_tensor = array_to_onnx_tensor(&output_name, &tensor.data)?;
68 self.model.add_initializer(onnx_tensor);
69 } else {
70 let input = OnnxTensor {
72 name: output_name.clone(),
73 dtype: 1, shape: tensor.shape().to_vec(),
75 data: vec![], };
77 self.model.add_input(input);
78 }
79 return Ok(output_name);
80 }
81
82 let node = tensor.compute_node.as_ref().unwrap();
84 if self.node_outputs.values().any(|n| n == &output_name) {
104 return Ok(output_name);
105 }
106
107 let mut input_names = Vec::new();
108 for input in &node.inputs {
109 input_names.push(self.traverse(input)?);
110 }
111
112 let node_name = format!("node_{}_{}", self.name_counter, output_name);
114
115 let onnx_node = match &node.op {
116 OpKind::Add => OnnxNode {
117 op_type: "Add".to_string(),
118 name: node_name,
119 inputs: input_names,
120 outputs: vec![output_name.clone()],
121 attributes: HashMap::new(),
122 },
123 OpKind::Mul => OnnxNode {
124 op_type: "Mul".to_string(),
125 name: node_name,
126 inputs: input_names,
127 outputs: vec![output_name.clone()],
128 attributes: HashMap::new(),
129 },
130 OpKind::MatMul => OnnxNode {
131 op_type: "MatMul".to_string(),
132 name: node_name,
133 inputs: input_names,
134 outputs: vec![output_name.clone()],
135 attributes: HashMap::new(),
136 },
137 OpKind::ReLU => OnnxNode {
138 op_type: "Relu".to_string(),
139 name: node_name,
140 inputs: input_names,
141 outputs: vec![output_name.clone()],
142 attributes: HashMap::new(),
143 },
144 OpKind::Sigmoid => OnnxNode {
145 op_type: "Sigmoid".to_string(),
146 name: node_name,
147 inputs: input_names,
148 outputs: vec![output_name.clone()],
149 attributes: HashMap::new(),
150 },
151 OpKind::Softmax => {
152 let mut attrs = HashMap::new();
153 attrs.insert("axis".to_string(), OnnxAttribute::Int(1));
154 OnnxNode {
155 op_type: "Softmax".to_string(),
156 name: node_name,
157 inputs: input_names,
158 outputs: vec![output_name.clone()],
159 attributes: attrs,
160 }
161 }
162 OpKind::Log => OnnxNode {
163 op_type: "Log".to_string(),
164 name: node_name,
165 inputs: input_names,
166 outputs: vec![output_name.clone()],
167 attributes: HashMap::new(),
168 },
169 OpKind::Exp => OnnxNode {
170 op_type: "Exp".to_string(),
171 name: node_name,
172 inputs: input_names,
173 outputs: vec![output_name.clone()],
174 attributes: HashMap::new(),
175 },
176 OpKind::Sub => OnnxNode {
177 op_type: "Sub".to_string(),
178 name: node_name,
179 inputs: input_names,
180 outputs: vec![output_name.clone()],
181 attributes: HashMap::new(),
182 },
183 OpKind::Div => OnnxNode {
184 op_type: "Div".to_string(),
185 name: node_name,
186 inputs: input_names,
187 outputs: vec![output_name.clone()],
188 attributes: HashMap::new(),
189 },
190 OpKind::Transpose => {
191 let mut attrs = HashMap::new();
192 attrs.insert("perm".to_string(), OnnxAttribute::Ints(vec![1, 0]));
194 OnnxNode {
195 op_type: "Transpose".to_string(),
196 name: node_name,
197 inputs: input_names,
198 outputs: vec![output_name.clone()],
199 attributes: attrs,
200 }
201 }
202 OpKind::Conv1D { stride, padding } => {
203 let mut attrs = HashMap::new();
204 attrs.insert(
205 "strides".to_string(),
206 OnnxAttribute::Ints(vec![*stride as i64]),
207 );
208 attrs.insert(
209 "pads".to_string(),
210 OnnxAttribute::Ints(vec![*padding as i64, *padding as i64]),
211 ); if node.inputs.len() >= 2 {
215 let weight = &node.inputs[1];
216 let k = weight.shape()[2];
217 attrs.insert(
218 "kernel_shape".to_string(),
219 OnnxAttribute::Ints(vec![k as i64]),
220 );
221 }
222
223 OnnxNode {
224 op_type: "Conv".to_string(),
225 name: node_name,
226 inputs: input_names,
227 outputs: vec![output_name.clone()],
228 attributes: attrs,
229 }
230 }
231 OpKind::Flatten { start_dim, end_dim } => {
232 if *start_dim == 1 && (*end_dim == usize::MAX || *end_dim == 2) {
234 let mut attrs = HashMap::new();
235 attrs.insert("axis".to_string(), OnnxAttribute::Int(1));
236 OnnxNode {
237 op_type: "Flatten".to_string(),
238 name: node_name,
239 inputs: input_names,
240 outputs: vec![output_name.clone()],
241 attributes: attrs,
242 }
243 } else {
244 let mut attrs = HashMap::new();
245 attrs.insert("axis".to_string(), OnnxAttribute::Int(*start_dim as i64));
246 OnnxNode {
247 op_type: "Flatten".to_string(),
248 name: node_name,
249 inputs: input_names,
250 outputs: vec![output_name.clone()],
251 attributes: attrs,
252 }
253 }
254 }
255 OpKind::Reshape { shape } => {
256 let shape_name = format!("{}_shape_const", output_name);
260
261 let shape_tensor = OnnxTensor {
263 name: shape_name.clone(),
264 dtype: 7, shape: vec![shape.len()],
266 data: unsafe {
267 let i64s: Vec<i64> = shape.iter().map(|&x| x as i64).collect();
268 let ptr = i64s.as_ptr() as *const u8;
269 let len = i64s.len() * 8;
270 std::slice::from_raw_parts(ptr, len).to_vec()
271 },
272 };
273 self.model.add_initializer(shape_tensor);
274
275 let mut node_inputs = input_names.clone();
277 node_inputs.push(shape_name);
278
279 OnnxNode {
280 op_type: "Reshape".to_string(),
281 name: node_name,
282 inputs: node_inputs,
283 outputs: vec![output_name.clone()],
284 attributes: HashMap::new(),
285 }
286 }
287 OpKind::BatchNorm {
288 training: _,
289 momentum,
290 eps,
291 } => {
292 let mut attrs = HashMap::new();
293 attrs.insert("epsilon".to_string(), OnnxAttribute::Float(*eps));
294 attrs.insert("momentum".to_string(), OnnxAttribute::Float(*momentum));
295
296 OnnxNode {
297 op_type: "BatchNormalization".to_string(),
298 name: node_name,
299 inputs: input_names,
300 outputs: vec![output_name.clone()],
301 attributes: attrs,
302 }
303 }
304 OpKind::Dropout { p, training: _ } => {
305 let mut attrs = HashMap::new();
306 attrs.insert("ratio".to_string(), OnnxAttribute::Float(*p));
307 OnnxNode {
308 op_type: "Dropout".to_string(),
309 name: node_name,
310 inputs: input_names,
311 outputs: vec![output_name.clone()],
312 attributes: attrs,
313 }
314 }
315 _ => return Err(anyhow!("Unsupported op for export: {:?}", node.op)),
317 };
318
319 self.model.add_node(onnx_node);
320 self.node_outputs
321 .insert(self.name_counter, output_name.clone()); Ok(output_name)
324 }
325}
326
327pub fn export_to_json(output: &Tensor) -> Result<String> {
331 let mut exporter = GraphExporter::new("exported_model");
332
333 let output_name = exporter.traverse(output)?;
335
336 exporter.model.set_outputs(vec![output_name]);
338
339 crate::ops::model::serialize_onnx(&exporter.model)
341}
342
343pub fn export_to_onnx(output: &Tensor, path: &str) -> Result<()> {
352 let json = export_to_json(output)?;
353 std::fs::write(path, json)?;
354 Ok(())
355}