//! ONNX model export support for QuantRS2-ML
//!
//! This module provides functionality to export quantum ML models to the ONNX format,
//! enabling interoperability with other ML frameworks and deployment platforms.
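//!
//! # Example
//!
//! A minimal usage sketch (marked `ignore`, so it is not compiled as a doctest);
//! it assumes this module is reachable as `quantrs2_ml::onnx_export`.
//!
//! ```ignore
//! use quantrs2_ml::onnx_export::{ONNXDataType, ONNXGraph, ONNXNode, ONNXValueInfo};
//!
//! let mut graph = ONNXGraph::new("demo_graph");
//! graph.add_input(ONNXValueInfo::new("input", ONNXDataType::Float32, vec![1, 4]));
//! graph.add_output(ONNXValueInfo::new("output", ONNXDataType::Float32, vec![1, 2]));
//! graph.add_node(ONNXNode::new(
//!     "dense",
//!     "MatMul",
//!     vec!["input".to_string(), "weight".to_string()],
//!     vec!["output".to_string()],
//! ));
//! graph.export("demo_model.onnx.txt").expect("export failed");
//! ```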

use crate::error::{MLError, Result};
use crate::keras_api::{KerasLayer, Sequential};
use crate::pytorch_api::QuantumModule;
use crate::simulator_backends::DynamicCircuit;
use ndarray::ArrayD;
use quantrs2_circuit::prelude::*;
use std::collections::HashMap;
use std::io::Write;

/// ONNX graph representation
#[derive(Debug, Clone)]
pub struct ONNXGraph {
    /// Graph nodes
    nodes: Vec<ONNXNode>,
    /// Graph inputs
    inputs: Vec<ONNXValueInfo>,
    /// Graph outputs
    outputs: Vec<ONNXValueInfo>,
    /// Graph initializers (weights)
    initializers: Vec<ONNXTensor>,
    /// Graph name
    name: String,
}

impl ONNXGraph {
    /// Create new ONNX graph
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            nodes: Vec::new(),
            inputs: Vec::new(),
            outputs: Vec::new(),
            initializers: Vec::new(),
            name: name.into(),
        }
    }

    /// Add node to graph
    pub fn add_node(&mut self, node: ONNXNode) {
        self.nodes.push(node);
    }

    /// Add input to graph
    pub fn add_input(&mut self, input: ONNXValueInfo) {
        self.inputs.push(input);
    }

    /// Add output to graph
    pub fn add_output(&mut self, output: ONNXValueInfo) {
        self.outputs.push(output);
    }

    /// Add initializer to graph
    pub fn add_initializer(&mut self, initializer: ONNXTensor) {
        self.initializers.push(initializer);
    }

    /// Export graph to ONNX format
    pub fn export(&self, path: &str) -> Result<()> {
        let onnx_proto = self.to_onnx_proto()?;

        std::fs::write(path, onnx_proto)?;
        Ok(())
    }

    /// Serialize the graph (simplified textual placeholder, not real ONNX protobuf)
    fn to_onnx_proto(&self) -> Result<Vec<u8>> {
        // This is a simplified, human-readable stand-in for the ONNX protobuf encoding.
        // A real implementation would serialize against the official ONNX protobuf schema.

        let mut buffer = Vec::new();

        // Write ONNX header
        writeln!(buffer, "ONNX Model Export")?;
        writeln!(buffer, "Graph Name: {}", self.name)?;
        writeln!(buffer)?;

        // Write inputs
        writeln!(buffer, "Inputs:")?;
        for input in &self.inputs {
            writeln!(buffer, "  {}: {:?}", input.name, input.shape)?;
        }
        writeln!(buffer)?;

        // Write outputs
        writeln!(buffer, "Outputs:")?;
        for output in &self.outputs {
            writeln!(buffer, "  {}: {:?}", output.name, output.shape)?;
        }
        writeln!(buffer)?;

        // Write nodes
        writeln!(buffer, "Nodes:")?;
        for node in &self.nodes {
            writeln!(
                buffer,
                "  {} ({}): {} -> {}",
                node.name,
                node.op_type,
                node.inputs.join(", "),
                node.outputs.join(", ")
            )?;
        }
        writeln!(buffer)?;

        // Write initializers
        writeln!(buffer, "Initializers:")?;
        for init in &self.initializers {
            writeln!(buffer, "  {}: {:?}", init.name, init.shape)?;
        }

        Ok(buffer)
    }
}

/// ONNX node representation
#[derive(Debug, Clone)]
pub struct ONNXNode {
    /// Node name
    name: String,
    /// Operator type
    op_type: String,
    /// Input names
    inputs: Vec<String>,
    /// Output names
    outputs: Vec<String>,
    /// Node attributes
    attributes: HashMap<String, ONNXAttribute>,
}

impl ONNXNode {
    /// Create new ONNX node
    pub fn new(
        name: impl Into<String>,
        op_type: impl Into<String>,
        inputs: Vec<String>,
        outputs: Vec<String>,
    ) -> Self {
        Self {
            name: name.into(),
            op_type: op_type.into(),
            inputs,
            outputs,
            attributes: HashMap::new(),
        }
    }

    /// Add attribute to node
    pub fn add_attribute(&mut self, name: impl Into<String>, value: ONNXAttribute) {
        self.attributes.insert(name.into(), value);
    }
}

/// ONNX attribute types
#[derive(Debug, Clone)]
pub enum ONNXAttribute {
    /// Integer attribute
    Int(i64),
    /// Float attribute
    Float(f32),
    /// String attribute
    String(String),
    /// Tensor attribute
    Tensor(ONNXTensor),
    /// Integer array
    Ints(Vec<i64>),
    /// Float array
    Floats(Vec<f32>),
    /// String array
    Strings(Vec<String>),
}

/// ONNX value info (for inputs/outputs)
#[derive(Debug, Clone)]
pub struct ONNXValueInfo {
    /// Value name
    name: String,
    /// Data type
    data_type: ONNXDataType,
    /// Shape
    shape: Vec<i64>,
}

impl ONNXValueInfo {
    /// Create new value info
    pub fn new(name: impl Into<String>, data_type: ONNXDataType, shape: Vec<i64>) -> Self {
        Self {
            name: name.into(),
            data_type,
            shape,
        }
    }
}

/// ONNX data types
#[derive(Debug, Clone)]
pub enum ONNXDataType {
    /// Float32
    Float32,
    /// Float64
    Float64,
    /// Int32
    Int32,
    /// Int64
    Int64,
    /// Bool
    Bool,
}

/// ONNX tensor representation
#[derive(Debug, Clone)]
pub struct ONNXTensor {
    /// Tensor name
    name: String,
    /// Data type
    data_type: ONNXDataType,
    /// Shape
    shape: Vec<i64>,
    /// Raw data
    data: Vec<u8>,
}

impl ONNXTensor {
    /// Create tensor from ndarray
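    ///
    /// # Example
    ///
    /// A small sketch (marked `ignore`, so it is not compiled as a doctest); it
    /// uses only `ndarray` plus this type.
    ///
    /// ```ignore
    /// use ndarray::{ArrayD, IxDyn};
    ///
    /// let values = ArrayD::from_shape_vec(IxDyn(&[2, 2]), vec![1.0f32, 2.0, 3.0, 4.0]).unwrap();
    /// let tensor = ONNXTensor::from_array_f32("weights", &values);
    /// ```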
    pub fn from_array_f32(name: impl Into<String>, array: &ArrayD<f32>) -> Self {
        let shape: Vec<i64> = array.shape().iter().map(|&s| s as i64).collect();
        // Iterate the array directly (in logical order) so non-contiguous views
        // do not panic the way `as_slice().unwrap()` would.
        let data = array.iter().flat_map(|&f| f.to_le_bytes()).collect();

        Self {
            name: name.into(),
            data_type: ONNXDataType::Float32,
            shape,
            data,
        }
    }

    /// Create tensor from ndarray (f64)
    pub fn from_array_f64(name: impl Into<String>, array: &ArrayD<f64>) -> Self {
        let shape: Vec<i64> = array.shape().iter().map(|&s| s as i64).collect();
        let data = array
            .iter()
            .flat_map(|&f| (f as f32).to_le_bytes()) // Convert to f32 for ONNX storage
            .collect();

        Self {
            name: name.into(),
            data_type: ONNXDataType::Float32,
            shape,
            data,
        }
    }
}

/// ONNX exporter for quantum ML models
pub struct ONNXExporter {
    /// Quantum operator mappings
    quantum_mappings: HashMap<String, String>,
    /// Export options
    options: ExportOptions,
}

/// Export options
#[derive(Debug, Clone)]
pub struct ExportOptions {
    /// ONNX opset version
    opset_version: i64,
    /// Include quantum layers as custom operators
    include_quantum_ops: bool,
    /// Optimize classical layers only
    optimize_classical_only: bool,
    /// Target backend for quantum operations
    quantum_backend: QuantumBackendTarget,
}

impl Default for ExportOptions {
    fn default() -> Self {
        Self {
            opset_version: 11,
            include_quantum_ops: true,
            optimize_classical_only: false,
            quantum_backend: QuantumBackendTarget::Generic,
        }
    }
}

/// Quantum backend targets for export
#[derive(Debug, Clone)]
pub enum QuantumBackendTarget {
    /// Generic quantum backend
    Generic,
    /// Qiskit-compatible
    Qiskit,
    /// Cirq-compatible
    Cirq,
    /// PennyLane-compatible
    PennyLane,
    /// Custom backend
    Custom(String),
}

impl ONNXExporter {
    /// Create new ONNX exporter
    pub fn new() -> Self {
        let mut quantum_mappings = HashMap::new();

        // Map quantum operations to ONNX custom operators
        quantum_mappings.insert("QuantumDense".to_string(), "QuantumDense".to_string());
        quantum_mappings.insert("QuantumLinear".to_string(), "QuantumLinear".to_string());
        quantum_mappings.insert("QuantumConv2d".to_string(), "QuantumConv2d".to_string());
        quantum_mappings.insert("QuantumRNN".to_string(), "QuantumRNN".to_string());

        Self {
            quantum_mappings,
            options: ExportOptions::default(),
        }
    }

    /// Set export options
    pub fn with_options(mut self, options: ExportOptions) -> Self {
        self.options = options;
        self
    }

    /// Export Sequential model to ONNX
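    ///
    /// # Example
    ///
    /// A minimal sketch (marked `ignore`, so it is not compiled as a doctest);
    /// the empty `Sequential` and the output path are illustrative placeholders.
    ///
    /// ```ignore
    /// let model = Sequential::new();
    /// let exporter = ONNXExporter::new();
    /// exporter
    ///     .export_sequential(&model, &[1, 8], "sequential_model.onnx.txt")
    ///     .expect("ONNX export failed");
    /// ```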
    pub fn export_sequential(
        &self,
        model: &Sequential,
        input_shape: &[usize],
        output_path: &str,
    ) -> Result<()> {
        let mut graph = ONNXGraph::new("sequential_model");

        // Add input
        let input_shape_i64: Vec<i64> = input_shape.iter().map(|&s| s as i64).collect();
        graph.add_input(ONNXValueInfo::new(
            "input",
            ONNXDataType::Float32,
            input_shape_i64,
        ));

        let mut current_output = "input".to_string();
        let mut node_counter = 0;

        // Convert each layer
        for layer in model.layers() {
            let layer_name = format!("layer_{}", node_counter);
            let output_name = format!("output_{}", node_counter);

            // Convert layer based on type
            let (nodes, initializers) =
                self.convert_layer(layer.as_ref(), &layer_name, &current_output, &output_name)?;

            // Add nodes and initializers to graph
            for node in nodes {
                graph.add_node(node);
            }
            for init in initializers {
                graph.add_initializer(init);
            }

            current_output = output_name;
            node_counter += 1;
        }

        // Add output
        let output_shape = model.compute_output_shape(input_shape);
        let output_shape_i64: Vec<i64> = output_shape.iter().map(|&s| s as i64).collect();
        graph.add_output(ONNXValueInfo::new(
            &current_output,
            ONNXDataType::Float32,
            output_shape_i64,
        ));

        // Export graph
        graph.export(output_path)?;
        Ok(())
    }

    /// Export PyTorch-style model to ONNX
    pub fn export_pytorch_model<T: QuantumModule>(
        &self,
        model: &T,
        input_shape: &[usize],
        output_path: &str,
    ) -> Result<()> {
        let mut graph = ONNXGraph::new("pytorch_model");

        // Add input
        let input_shape_i64: Vec<i64> = input_shape.iter().map(|&s| s as i64).collect();
        graph.add_input(ONNXValueInfo::new(
            "input",
            ONNXDataType::Float32,
            input_shape_i64,
        ));

        // Convert model (simplified - would need more complex analysis)
        let node = ONNXNode::new(
            "pytorch_model",
            "QuantumModel",
            vec!["input".to_string()],
            vec!["output".to_string()],
        );
        graph.add_node(node);

        // Add output (would need to compute actual output shape)
        graph.add_output(ONNXValueInfo::new(
            "output",
            ONNXDataType::Float32,
            vec![1, 1], // Placeholder
        ));

        // Export graph
        graph.export(output_path)?;
        Ok(())
    }

    /// Convert layer to ONNX nodes and initializers
    fn convert_layer(
        &self,
        layer: &dyn KerasLayer,
        layer_name: &str,
        input_name: &str,
        output_name: &str,
    ) -> Result<(Vec<ONNXNode>, Vec<ONNXTensor>)> {
        // This would need to be implemented for each layer type
        // For now, we'll provide a simplified conversion

        let layer_type = self.get_layer_type(layer);

        match layer_type.as_str() {
            "Dense" => self.convert_dense_layer(layer, layer_name, input_name, output_name),
            "QuantumDense" => {
                self.convert_quantum_dense_layer(layer, layer_name, input_name, output_name)
            }
            "Activation" => {
                self.convert_activation_layer(layer, layer_name, input_name, output_name)
            }
            _ => {
                // Generic layer conversion
                let node = ONNXNode::new(
                    layer_name,
                    &layer_type,
                    vec![input_name.to_string()],
                    vec![output_name.to_string()],
                );
                Ok((vec![node], vec![]))
            }
        }
    }

    /// Convert Dense layer
    fn convert_dense_layer(
        &self,
        layer: &dyn KerasLayer,
        layer_name: &str,
        input_name: &str,
        output_name: &str,
    ) -> Result<(Vec<ONNXNode>, Vec<ONNXTensor>)> {
        let weights = layer.get_weights();
        let mut nodes = Vec::new();
        let mut initializers = Vec::new();

        if !weights.is_empty() {
            // Add weight initializer
            let weight_name = format!("{}_weight", layer_name);
            let weight_tensor = ONNXTensor::from_array_f64(&weight_name, &weights[0]);
            initializers.push(weight_tensor);

            // Create MatMul node
            let matmul_inputs = vec![input_name.to_string(), weight_name];
            let matmul_output = if weights.len() > 1 {
                format!("{}_matmul", layer_name)
            } else {
                output_name.to_string()
            };

            let matmul_node = ONNXNode::new(
                format!("{}_matmul", layer_name),
                "MatMul",
                matmul_inputs,
                vec![matmul_output.clone()],
            );
            nodes.push(matmul_node);

            // Add bias if present
            if weights.len() > 1 {
                let bias_name = format!("{}_bias", layer_name);
                let bias_tensor = ONNXTensor::from_array_f64(&bias_name, &weights[1]);
                initializers.push(bias_tensor);

                let add_node = ONNXNode::new(
                    format!("{}_add", layer_name),
                    "Add",
                    vec![matmul_output, bias_name],
                    vec![output_name.to_string()],
                );
                nodes.push(add_node);
            }
        }

        Ok((nodes, initializers))
    }

    /// Convert QuantumDense layer
    fn convert_quantum_dense_layer(
        &self,
        layer: &dyn KerasLayer,
        layer_name: &str,
        input_name: &str,
        output_name: &str,
    ) -> Result<(Vec<ONNXNode>, Vec<ONNXTensor>)> {
        if !self.options.include_quantum_ops {
            return Err(MLError::InvalidConfiguration(
                "Quantum operators are disabled by the export options (include_quantum_ops = false)"
                    .to_string(),
            ));
        }

        let weights = layer.get_weights();
        let mut nodes = Vec::new();
        let mut initializers = Vec::new();

        // Add quantum parameters as initializers
        for (i, weight) in weights.iter().enumerate() {
            let param_name = format!("{}_param_{}", layer_name, i);
            let param_tensor = ONNXTensor::from_array_f64(&param_name, weight);
            initializers.push(param_tensor);
        }

        // Create custom quantum node
        let mut quantum_node = ONNXNode::new(
            layer_name,
            "QuantumDense",
            vec![input_name.to_string()],
            vec![output_name.to_string()],
        );

        // Add quantum-specific attributes
        quantum_node.add_attribute(
            "backend",
            ONNXAttribute::String(format!("{:?}", self.options.quantum_backend)),
        );
        quantum_node.add_attribute("domain", ONNXAttribute::String("quantrs2.ml".to_string()));

        nodes.push(quantum_node);

        Ok((nodes, initializers))
    }

    /// Convert Activation layer
    fn convert_activation_layer(
        &self,
        _layer: &dyn KerasLayer,
        layer_name: &str,
        input_name: &str,
        output_name: &str,
    ) -> Result<(Vec<ONNXNode>, Vec<ONNXTensor>)> {
        // Placeholder: always emits ReLU; a full implementation would inspect
        // the layer's actual activation function.
        let node = ONNXNode::new(
            layer_name,
            "Relu",
            vec![input_name.to_string()],
            vec![output_name.to_string()],
        );

        Ok((vec![node], vec![]))
    }

    /// Get layer type string
    fn get_layer_type(&self, _layer: &dyn KerasLayer) -> String {
        // This would need to be implemented with proper type checking
        // For now, return a placeholder
        "Dense".to_string()
    }
}

/// ONNX importer for loading models back into QuantRS2
pub struct ONNXImporter {
    /// Import options
    options: ImportOptions,
}

/// Import options
#[derive(Debug, Clone)]
pub struct ImportOptions {
    /// Target framework
    target_framework: TargetFramework,
    /// Handle unsupported operators
    handle_unsupported: UnsupportedOpHandling,
    /// Quantum backend for imported quantum ops
    quantum_backend: QuantumBackendTarget,
}

/// Target frameworks for import
#[derive(Debug, Clone)]
pub enum TargetFramework {
    /// Keras-style Sequential model
    Keras,
    /// PyTorch-style model
    PyTorch,
    /// Raw QuantRS2 model
    QuantRS2,
}

/// How to handle unsupported operators
#[derive(Debug, Clone)]
pub enum UnsupportedOpHandling {
    /// Raise error
    Error,
    /// Skip unsupported operators
    Skip,
    /// Replace with identity
    Identity,
    /// Custom handler
    Custom(String),
}

impl Default for ImportOptions {
    fn default() -> Self {
        Self {
            target_framework: TargetFramework::Keras,
            handle_unsupported: UnsupportedOpHandling::Error,
            quantum_backend: QuantumBackendTarget::Generic,
        }
    }
}

impl ONNXImporter {
    /// Create new ONNX importer
    pub fn new() -> Self {
        Self {
            options: ImportOptions::default(),
        }
    }

    /// Set import options
    pub fn with_options(mut self, options: ImportOptions) -> Self {
        self.options = options;
        self
    }

    /// Import ONNX model to Sequential model
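    ///
    /// # Example
    ///
    /// A minimal sketch (marked `ignore`, so it is not compiled as a doctest);
    /// the path is an illustrative placeholder.
    ///
    /// ```ignore
    /// let importer = ONNXImporter::new().with_options(ImportOptions::default());
    /// let model = importer
    ///     .import_to_sequential("exported_model.onnx.txt")
    ///     .expect("ONNX import failed");
    /// ```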
    pub fn import_to_sequential(&self, path: &str) -> Result<Sequential> {
        let graph = self.load_onnx_graph(path)?;
        self.convert_to_sequential(&graph)
    }

    /// Load ONNX graph from file
    fn load_onnx_graph(&self, _path: &str) -> Result<ONNXGraph> {
        // This would parse the actual ONNX protobuf file
        // For now, return a placeholder
        Ok(ONNXGraph::new("imported_model"))
    }

    /// Convert ONNX graph to Sequential model
    fn convert_to_sequential(&self, _graph: &ONNXGraph) -> Result<Sequential> {
        // This would analyze the ONNX graph and recreate the Sequential model
        // For now, return a simple model
        Ok(Sequential::new())
    }
}

/// Utility functions for ONNX export/import
pub mod utils {
    use super::*;

    /// Validate ONNX model
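    ///
    /// # Example
    ///
    /// A minimal sketch (marked `ignore`, so it is not compiled as a doctest);
    /// note that the current implementation returns a placeholder report.
    ///
    /// ```ignore
    /// let report = utils::validate_onnx_model("exported_model.onnx.txt")
    ///     .expect("validation failed");
    /// if !report.valid {
    ///     eprintln!("validation errors: {:?}", report.errors);
    /// }
    /// ```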
    pub fn validate_onnx_model(_path: &str) -> Result<ValidationReport> {
        // This would validate the ONNX model structure and operators
        Ok(ValidationReport {
            valid: true,
            errors: Vec::new(),
            warnings: Vec::new(),
            quantum_ops_found: false,
        })
    }

    /// Get ONNX model info
    pub fn get_model_info(_path: &str) -> Result<ModelInfo> {
        // This would extract basic information about the ONNX model
        Ok(ModelInfo {
            opset_version: 11,
            producer_name: "QuantRS2-ML".to_string(),
            producer_version: "0.1.0".to_string(),
            graph_name: "model".to_string(),
            num_nodes: 0,
            num_initializers: 0,
            input_shapes: Vec::new(),
            output_shapes: Vec::new(),
        })
    }

    /// Convert quantum circuit to ONNX custom operator
    pub fn circuit_to_onnx_op(circuit: &DynamicCircuit, name: &str) -> Result<ONNXNode> {
        let mut node = ONNXNode::new(
            name,
            "QuantumCircuit",
            vec!["input".to_string()],
            vec!["output".to_string()],
        );

        // Add circuit-specific attributes
        node.add_attribute(
            "num_qubits",
            ONNXAttribute::Int(circuit.num_qubits() as i64),
        );
        node.add_attribute("num_gates", ONNXAttribute::Int(circuit.num_gates() as i64));
        node.add_attribute("depth", ONNXAttribute::Int(circuit.depth() as i64));

        // Serialize circuit structure
        let circuit_data = serialize_circuit(circuit)?;
        node.add_attribute("circuit_data", ONNXAttribute::String(circuit_data));

        Ok(node)
    }

    /// Serialize quantum circuit to string
    fn serialize_circuit(_circuit: &DynamicCircuit) -> Result<String> {
        // This would serialize the circuit to a string format
        // For now, return a placeholder
        Ok("quantum_circuit_placeholder".to_string())
    }

    /// Create ONNX metadata for quantum ML model
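    ///
    /// # Example
    ///
    /// A small sketch (marked `ignore`, so it is not compiled as a doctest); the
    /// keys shown are exactly the ones inserted below.
    ///
    /// ```ignore
    /// let metadata = utils::create_quantum_metadata();
    /// assert_eq!(metadata.get("framework").map(String::as_str), Some("QuantRS2-ML"));
    /// assert_eq!(metadata.get("quantum_support").map(String::as_str), Some("true"));
    /// ```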
    pub fn create_quantum_metadata() -> HashMap<String, String> {
        let mut metadata = HashMap::new();
        metadata.insert("framework".to_string(), "QuantRS2-ML".to_string());
        metadata.insert("domain".to_string(), "quantrs2.ml".to_string());
        metadata.insert("version".to_string(), "0.1.0".to_string());
        metadata.insert("quantum_support".to_string(), "true".to_string());
        metadata
    }
}

/// Validation report for ONNX models
#[derive(Debug)]
pub struct ValidationReport {
    /// Model is valid
    pub valid: bool,
    /// Validation errors
    pub errors: Vec<String>,
    /// Validation warnings
    pub warnings: Vec<String>,
    /// Quantum operators found
    pub quantum_ops_found: bool,
}

/// Model information
#[derive(Debug)]
pub struct ModelInfo {
    /// ONNX opset version
    pub opset_version: i64,
    /// Producer name
    pub producer_name: String,
    /// Producer version
    pub producer_version: String,
    /// Graph name
    pub graph_name: String,
    /// Number of nodes
    pub num_nodes: usize,
    /// Number of initializers
    pub num_initializers: usize,
    /// Input shapes
    pub input_shapes: Vec<Vec<i64>>,
    /// Output shapes
    pub output_shapes: Vec<Vec<i64>>,
}

// Extensions for Sequential model
impl Sequential {
    /// Export to ONNX format
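    ///
    /// # Example
    ///
    /// A minimal sketch (marked `ignore`, so it is not compiled as a doctest);
    /// the empty model and output path are illustrative placeholders.
    ///
    /// ```ignore
    /// let model = Sequential::new();
    /// model
    ///     .export_onnx("model.onnx.txt", &[1, 8], Some(ExportOptions::default()))
    ///     .expect("ONNX export failed");
    /// ```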
    pub fn export_onnx(
        &self,
        path: &str,
        input_shape: &[usize],
        options: Option<ExportOptions>,
    ) -> Result<()> {
        let exporter = ONNXExporter::new();
        let exporter = if let Some(opts) = options {
            exporter.with_options(opts)
        } else {
            exporter
        };

        exporter.export_sequential(self, input_shape, path)
    }

    /// Get layers (placeholder - would need actual implementation)
    fn layers(&self) -> &[Box<dyn KerasLayer>] {
        // This would return the actual layers from the Sequential model
        &[]
    }

    /// Compute output shape (placeholder)
    fn compute_output_shape(&self, input_shape: &[usize]) -> Vec<usize> {
        // This would compute the actual output shape
        input_shape.to_vec()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_onnx_graph_creation() {
        let mut graph = ONNXGraph::new("test_graph");

        graph.add_input(ONNXValueInfo::new(
            "input",
            ONNXDataType::Float32,
            vec![1, 10],
        ));

        graph.add_output(ONNXValueInfo::new(
            "output",
            ONNXDataType::Float32,
            vec![1, 5],
        ));

        let node = ONNXNode::new(
            "dense_layer",
            "MatMul",
            vec!["input".to_string(), "weight".to_string()],
            vec!["output".to_string()],
        );
        graph.add_node(node);

        assert_eq!(graph.nodes.len(), 1);
        assert_eq!(graph.inputs.len(), 1);
        assert_eq!(graph.outputs.len(), 1);
    }

    #[test]
    fn test_onnx_tensor_creation() {
        let array = ndarray::Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap()
            .into_dyn();

        let tensor = ONNXTensor::from_array_f64("test_tensor", &array);
        assert_eq!(tensor.name, "test_tensor");
        assert_eq!(tensor.shape, vec![2, 3]);
    }

    #[test]
    fn test_onnx_exporter_creation() {
        let exporter = ONNXExporter::new();
        let options = ExportOptions {
            opset_version: 13,
            include_quantum_ops: false,
            optimize_classical_only: true,
            quantum_backend: QuantumBackendTarget::Qiskit,
        };

        let exporter = exporter.with_options(options);
        assert_eq!(exporter.options.opset_version, 13);
        assert!(!exporter.options.include_quantum_ops);
    }

    #[test]
    fn test_onnx_node_attributes() {
        let mut node = ONNXNode::new(
            "test_node",
            "Conv",
            vec!["input".to_string()],
            vec!["output".to_string()],
        );

        node.add_attribute("kernel_shape", ONNXAttribute::Ints(vec![3, 3]));
        node.add_attribute("strides", ONNXAttribute::Ints(vec![1, 1]));

        assert_eq!(node.attributes.len(), 2);
    }

    #[test]
    fn test_validation_utils() {
        let report = utils::validate_onnx_model("dummy_path");
        assert!(report.is_ok());

        let info = utils::get_model_info("dummy_path");
        assert!(info.is_ok());
    }
}