// numrs/ops/export.rs

1//! ONNX Export Module
2//!
3//! Provides functionality to automatically export Autograd graphs to ONNX.
4
5use crate::autograd::{OpKind, Tensor};
6use crate::llo::{array_to_onnx_tensor, OnnxAttribute, OnnxModel, OnnxNode, OnnxTensor};
7use anyhow::{anyhow, Result};
8use std::collections::{HashMap, HashSet};
9
/// Exporter state context
///
/// Accumulates an ONNX model while walking an autograd graph backwards
/// from its output tensor.
struct GraphExporter {
    // The ONNX model under construction (nodes, inputs, initializers).
    model: OnnxModel,
    _visited: HashSet<usize>,             // Visited NodeIds (currently unused)
    node_outputs: HashMap<usize, String>, // NodeId -> Output Name
    tensor_names: HashMap<usize, String>, // Tensor ID (pointer) -> Name
    // Monotonic counter used to mint unique "tensor_N" names.
    name_counter: usize,
}
18
19impl GraphExporter {
20    fn new(name: &str) -> Self {
21        GraphExporter {
22            model: OnnxModel::new(name),
23            _visited: HashSet::new(),
24            node_outputs: HashMap::new(),
25            tensor_names: HashMap::new(),
26            name_counter: 0,
27        }
28    }
29
30    fn get_tensor_name(&mut self, tensor: &Tensor) -> String {
31        // Use tensor's data pointer as unique ID roughly
32        let id = tensor.data.data.as_ptr() as usize;
33
34        if let Some(name) = self.tensor_names.get(&id) {
35            return name.clone();
36        }
37
38        let name = format!("tensor_{}", self.name_counter);
39        self.name_counter += 1;
40        self.tensor_names.insert(id, name.clone());
41        name
42    }
43
44    fn traverse(&mut self, tensor: &Tensor) -> Result<String> {
45        let output_name = self.get_tensor_name(tensor);
46
47        // If leaf (no compute node), it's either Input or Weight
48        if tensor.compute_node.is_none() {
49            if self
50                .model
51                .graph
52                .inputs
53                .iter()
54                .any(|i| i.name == output_name)
55                || self
56                    .model
57                    .graph
58                    .initializers
59                    .iter()
60                    .any(|i| i.name == output_name)
61            {
62                return Ok(output_name);
63            }
64
65            if tensor.requires_grad {
66                // Weight (Parameter)
67                let onnx_tensor = array_to_onnx_tensor(&output_name, &tensor.data)?;
68                self.model.add_initializer(onnx_tensor);
69            } else {
70                // Input
71                let input = OnnxTensor {
72                    name: output_name.clone(),
73                    dtype: 1, // FLOAT
74                    shape: tensor.shape().to_vec(),
75                    data: vec![], // Inputs don't have data
76                };
77                self.model.add_input(input);
78            }
79            return Ok(output_name);
80        }
81
82        // If compute node, traverse recurrently
83        let node = tensor.compute_node.as_ref().unwrap();
84        // (We assume NodeId is unique per operation instance)
85        // Accessing private field 'id' might be tricky if not pub, checking mod.rs...
86        // Assuming we can check visitation by output name existence for now
87        // A better way is needed if multiple outputs per node, but here 1 tensor = 1 node mostly.
88
89        // Check if we already visited this operation by checking if its output is defined?
90        // Actually, Autograd structure in NumRs is: Tensor HAS A Node.
91        // So visiting the Tensor visits the Node.
92        // We need to avoid re-emitting the same node if multiple tensors point to it (not common here)
93        // OR if the graph splits and joins.
94
95        // To handle diamond dependencies, we should check availability of inputs.
96        // But here we are traversing BACKWARDS from output.
97        // Better approach:
98        // 1. Traverse recursively to inputs.
99        // 2. Once inputs are ready (names returned), emit THIS node.
100
101        // Check cycle/visited
102        // Using output_name as unique key for "tensor computed status"
103        if self.node_outputs.values().any(|n| n == &output_name) {
104            return Ok(output_name);
105        }
106
107        let mut input_names = Vec::new();
108        for input in &node.inputs {
109            input_names.push(self.traverse(input)?);
110        }
111
112        // Emit ONNX Node
113        let node_name = format!("node_{}_{}", self.name_counter, output_name);
114
115        let onnx_node = match &node.op {
116            OpKind::Add => OnnxNode {
117                op_type: "Add".to_string(),
118                name: node_name,
119                inputs: input_names,
120                outputs: vec![output_name.clone()],
121                attributes: HashMap::new(),
122            },
123            OpKind::Mul => OnnxNode {
124                op_type: "Mul".to_string(),
125                name: node_name,
126                inputs: input_names,
127                outputs: vec![output_name.clone()],
128                attributes: HashMap::new(),
129            },
130            OpKind::MatMul => OnnxNode {
131                op_type: "MatMul".to_string(),
132                name: node_name,
133                inputs: input_names,
134                outputs: vec![output_name.clone()],
135                attributes: HashMap::new(),
136            },
137            OpKind::ReLU => OnnxNode {
138                op_type: "Relu".to_string(),
139                name: node_name,
140                inputs: input_names,
141                outputs: vec![output_name.clone()],
142                attributes: HashMap::new(),
143            },
144            OpKind::Sigmoid => OnnxNode {
145                op_type: "Sigmoid".to_string(),
146                name: node_name,
147                inputs: input_names,
148                outputs: vec![output_name.clone()],
149                attributes: HashMap::new(),
150            },
151            OpKind::Softmax => {
152                let mut attrs = HashMap::new();
153                attrs.insert("axis".to_string(), OnnxAttribute::Int(1));
154                OnnxNode {
155                    op_type: "Softmax".to_string(),
156                    name: node_name,
157                    inputs: input_names,
158                    outputs: vec![output_name.clone()],
159                    attributes: attrs,
160                }
161            }
162            OpKind::Log => OnnxNode {
163                op_type: "Log".to_string(),
164                name: node_name,
165                inputs: input_names,
166                outputs: vec![output_name.clone()],
167                attributes: HashMap::new(),
168            },
169            OpKind::Exp => OnnxNode {
170                op_type: "Exp".to_string(),
171                name: node_name,
172                inputs: input_names,
173                outputs: vec![output_name.clone()],
174                attributes: HashMap::new(),
175            },
176            OpKind::Sub => OnnxNode {
177                op_type: "Sub".to_string(),
178                name: node_name,
179                inputs: input_names,
180                outputs: vec![output_name.clone()],
181                attributes: HashMap::new(),
182            },
183            OpKind::Div => OnnxNode {
184                op_type: "Div".to_string(),
185                name: node_name,
186                inputs: input_names,
187                outputs: vec![output_name.clone()],
188                attributes: HashMap::new(),
189            },
190            OpKind::Transpose => {
191                let mut attrs = HashMap::new();
192                // Default transpose in autograd is likely 2D swap (perm=[1,0])
193                attrs.insert("perm".to_string(), OnnxAttribute::Ints(vec![1, 0]));
194                OnnxNode {
195                    op_type: "Transpose".to_string(),
196                    name: node_name,
197                    inputs: input_names,
198                    outputs: vec![output_name.clone()],
199                    attributes: attrs,
200                }
201            }
202            OpKind::Conv1D { stride, padding } => {
203                let mut attrs = HashMap::new();
204                attrs.insert(
205                    "strides".to_string(),
206                    OnnxAttribute::Ints(vec![*stride as i64]),
207                );
208                attrs.insert(
209                    "pads".to_string(),
210                    OnnxAttribute::Ints(vec![*padding as i64, *padding as i64]),
211                ); // Start, End
212
213                // Get kernel_shape from weight (input[1])
214                if node.inputs.len() >= 2 {
215                    let weight = &node.inputs[1];
216                    let k = weight.shape()[2];
217                    attrs.insert(
218                        "kernel_shape".to_string(),
219                        OnnxAttribute::Ints(vec![k as i64]),
220                    );
221                }
222
223                OnnxNode {
224                    op_type: "Conv".to_string(),
225                    name: node_name,
226                    inputs: input_names,
227                    outputs: vec![output_name.clone()],
228                    attributes: attrs,
229                }
230            }
231            OpKind::Flatten { start_dim, end_dim } => {
232                // If standard Flatten(1, -1):
233                if *start_dim == 1 && (*end_dim == usize::MAX || *end_dim == 2) {
234                    let mut attrs = HashMap::new();
235                    attrs.insert("axis".to_string(), OnnxAttribute::Int(1));
236                    OnnxNode {
237                        op_type: "Flatten".to_string(),
238                        name: node_name,
239                        inputs: input_names,
240                        outputs: vec![output_name.clone()],
241                        attributes: attrs,
242                    }
243                } else {
244                    let mut attrs = HashMap::new();
245                    attrs.insert("axis".to_string(), OnnxAttribute::Int(*start_dim as i64));
246                    OnnxNode {
247                        op_type: "Flatten".to_string(),
248                        name: node_name,
249                        inputs: input_names,
250                        outputs: vec![output_name.clone()],
251                        attributes: attrs,
252                    }
253                }
254            }
255            OpKind::Reshape { shape } => {
256                // Reshape consumes: Data, Shape (as a Tensor!)
257
258                // 1. Create Shape Initializer
259                let shape_name = format!("{}_shape_const", output_name);
260
261                // Manually construct OnnxTensor
262                let shape_tensor = OnnxTensor {
263                    name: shape_name.clone(),
264                    dtype: 7, // INT64
265                    shape: vec![shape.len()],
266                    data: unsafe {
267                        let i64s: Vec<i64> = shape.iter().map(|&x| x as i64).collect();
268                        let ptr = i64s.as_ptr() as *const u8;
269                        let len = i64s.len() * 8;
270                        std::slice::from_raw_parts(ptr, len).to_vec()
271                    },
272                };
273                self.model.add_initializer(shape_tensor);
274
275                // 2. Add input dependency
276                let mut node_inputs = input_names.clone();
277                node_inputs.push(shape_name);
278
279                OnnxNode {
280                    op_type: "Reshape".to_string(),
281                    name: node_name,
282                    inputs: node_inputs,
283                    outputs: vec![output_name.clone()],
284                    attributes: HashMap::new(),
285                }
286            }
287            OpKind::BatchNorm {
288                training: _,
289                momentum,
290                eps,
291            } => {
292                let mut attrs = HashMap::new();
293                attrs.insert("epsilon".to_string(), OnnxAttribute::Float(*eps));
294                attrs.insert("momentum".to_string(), OnnxAttribute::Float(*momentum));
295
296                OnnxNode {
297                    op_type: "BatchNormalization".to_string(),
298                    name: node_name,
299                    inputs: input_names,
300                    outputs: vec![output_name.clone()],
301                    attributes: attrs,
302                }
303            }
304            OpKind::Dropout { p, training: _ } => {
305                let mut attrs = HashMap::new();
306                attrs.insert("ratio".to_string(), OnnxAttribute::Float(*p));
307                OnnxNode {
308                    op_type: "Dropout".to_string(),
309                    name: node_name,
310                    inputs: input_names,
311                    outputs: vec![output_name.clone()],
312                    attributes: attrs,
313                }
314            }
315            // Handle specific complex ops or fallbacks
316            _ => return Err(anyhow!("Unsupported op for export: {:?}", node.op)),
317        };
318
319        self.model.add_node(onnx_node);
320        self.node_outputs
321            .insert(self.name_counter, output_name.clone()); // Dummy ID usage
322
323        Ok(output_name)
324    }
325}
326
327/// Export a tensor's computational graph to ONNX JSON string
328///
329/// Use this when file system access is not available (e.g. WASM).
330pub fn export_to_json(output: &Tensor) -> Result<String> {
331    let mut exporter = GraphExporter::new("exported_model");
332
333    // Traverse graph to populate model
334    let output_name = exporter.traverse(output)?;
335
336    // Set output
337    exporter.model.set_outputs(vec![output_name]);
338
339    // Serialize
340    crate::ops::model::serialize_onnx(&exporter.model)
341}
342
343/// Export a tensor's computational graph to ONNX
344///
345/// Automatically traverses the graph backwards from `output`, identifying
346/// parameters (requires_grad=true) and inputs (requires_grad=false).
347///
348/// # Arguments
349/// * `output` - The output tensor of the model (must be a computed tensor)
350/// * `path` - Path to save the .onnx (json) file
351pub fn export_to_onnx(output: &Tensor, path: &str) -> Result<()> {
352    let json = export_to_json(output)?;
353    std::fs::write(path, json)?;
354    Ok(())
355}