quantrs2_ml/
onnx_export.rs

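//! ONNX model export support for QuantRS2-ML
//!
//! This module provides functionality to export quantum ML models to the ONNX format,
//! enabling interoperability with other ML frameworks and deployment platforms.
//!
//! Note: the current writer emits a simplified, human-readable text description of the
//! graph rather than a binary ONNX protobuf, and the importer is a placeholder, so files
//! produced here cannot yet be loaded by other ONNX runtimes.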
5
6use crate::error::{MLError, Result};
7use crate::keras_api::{
8    Activation, ActivationFunction, Dense, KerasLayer, QuantumDense, Sequential,
9};
10use crate::pytorch_api::{QuantumLinear, QuantumModule, QuantumSequential};
11use crate::simulator_backends::DynamicCircuit;
12use quantrs2_circuit::prelude::*;
13use scirs2_core::ndarray::{Array1, Array2, ArrayD};
14use std::collections::HashMap;
15use std::io::Write;
16
/// ONNX graph representation.
///
/// An in-memory computation graph made of operator nodes, the graph's declared
/// inputs and outputs, and the constant tensors (weights) that nodes reference.
#[derive(Debug, Clone)]
pub struct ONNXGraph {
    /// Graph nodes, in the order they were added
    nodes: Vec<ONNXNode>,
    /// Graph inputs
    inputs: Vec<ONNXValueInfo>,
    /// Graph outputs
    outputs: Vec<ONNXValueInfo>,
    /// Graph initializers (weights), in the order they were added
    initializers: Vec<ONNXTensor>,
    /// Graph name
    name: String,
}
31
32impl ONNXGraph {
33    /// Create new ONNX graph
34    pub fn new(name: impl Into<String>) -> Self {
35        Self {
36            nodes: Vec::new(),
37            inputs: Vec::new(),
38            outputs: Vec::new(),
39            initializers: Vec::new(),
40            name: name.into(),
41        }
42    }
43
44    /// Add node to graph
45    pub fn add_node(&mut self, node: ONNXNode) {
46        self.nodes.push(node);
47    }
48
49    /// Add input to graph
50    pub fn add_input(&mut self, input: ONNXValueInfo) {
51        self.inputs.push(input);
52    }
53
54    /// Add output to graph
55    pub fn add_output(&mut self, output: ONNXValueInfo) {
56        self.outputs.push(output);
57    }
58
59    /// Add initializer to graph
60    pub fn add_initializer(&mut self, initializer: ONNXTensor) {
61        self.initializers.push(initializer);
62    }
63
64    /// Export graph to ONNX format
65    pub fn export(&self, path: &str) -> Result<()> {
66        let onnx_proto = self.to_onnx_proto()?;
67
68        std::fs::write(path, onnx_proto)?;
69        Ok(())
70    }
71
72    /// Convert to ONNX protobuf format (simplified)
73    fn to_onnx_proto(&self) -> Result<Vec<u8>> {
74        // This is a simplified representation of ONNX protobuf
75        // In a real implementation, you would use the official ONNX protobuf schema
76
77        let mut buffer = Vec::new();
78
79        // Write ONNX header
80        writeln!(buffer, "ONNX Model Export")?;
81        writeln!(buffer, "Graph Name: {}", self.name)?;
82        writeln!(buffer, "")?;
83
84        // Write inputs
85        writeln!(buffer, "Inputs:")?;
86        for input in &self.inputs {
87            writeln!(buffer, "  {}: {:?}", input.name, input.shape)?;
88        }
89        writeln!(buffer, "")?;
90
91        // Write outputs
92        writeln!(buffer, "Outputs:")?;
93        for output in &self.outputs {
94            writeln!(buffer, "  {}: {:?}", output.name, output.shape)?;
95        }
96        writeln!(buffer, "")?;
97
98        // Write nodes
99        writeln!(buffer, "Nodes:")?;
100        for node in &self.nodes {
101            writeln!(
102                buffer,
103                "  {} ({}): {} -> {}",
104                node.name,
105                node.op_type,
106                node.inputs.join(", "),
107                node.outputs.join(", ")
108            )?;
109        }
110        writeln!(buffer, "")?;
111
112        // Write initializers
113        writeln!(buffer, "Initializers:")?;
114        for init in &self.initializers {
115            writeln!(buffer, "  {}: {:?}", init.name, init.shape)?;
116        }
117
118        Ok(buffer)
119    }
120}
121
/// ONNX node representation: a single operator invocation within a graph.
#[derive(Debug, Clone)]
pub struct ONNXNode {
    /// Node name
    name: String,
    /// Operator type (e.g. `"MatMul"`, `"Add"`, or a custom quantum op such as `"QuantumDense"`)
    op_type: String,
    /// Names of the tensors this node consumes
    inputs: Vec<String>,
    /// Names of the tensors this node produces
    outputs: Vec<String>,
    /// Node attributes, keyed by attribute name
    attributes: HashMap<String, ONNXAttribute>,
}
136
137impl ONNXNode {
138    /// Create new ONNX node
139    pub fn new(
140        name: impl Into<String>,
141        op_type: impl Into<String>,
142        inputs: Vec<String>,
143        outputs: Vec<String>,
144    ) -> Self {
145        Self {
146            name: name.into(),
147            op_type: op_type.into(),
148            inputs,
149            outputs,
150            attributes: HashMap::new(),
151        }
152    }
153
154    /// Add attribute to node
155    pub fn add_attribute(&mut self, name: impl Into<String>, value: ONNXAttribute) {
156        self.attributes.insert(name.into(), value);
157    }
158}
159
/// ONNX attribute values that can be attached to an `ONNXNode`.
#[derive(Debug, Clone)]
pub enum ONNXAttribute {
    /// Integer attribute
    Int(i64),
    /// Float attribute
    Float(f32),
    /// String attribute
    String(String),
    /// Tensor attribute
    Tensor(ONNXTensor),
    /// Integer array
    Ints(Vec<i64>),
    /// Float array
    Floats(Vec<f32>),
    /// String array
    Strings(Vec<String>),
}
178
/// ONNX value info: the name, element type and shape of a graph input or output.
#[derive(Debug, Clone)]
pub struct ONNXValueInfo {
    /// Value name
    name: String,
    /// Element data type
    data_type: ONNXDataType,
    /// Shape, one entry per dimension
    shape: Vec<i64>,
}
189
190impl ONNXValueInfo {
191    /// Create new value info
192    pub fn new(name: impl Into<String>, data_type: ONNXDataType, shape: Vec<i64>) -> Self {
193        Self {
194            name: name.into(),
195            data_type,
196            shape,
197        }
198    }
199}
200
/// ONNX tensor element types supported by this exporter.
///
/// The type is a plain fieldless enum, so it is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ONNXDataType {
    /// Float32
    Float32,
    /// Float64
    Float64,
    /// Int32
    Int32,
    /// Int64
    Int64,
    /// Bool
    Bool,
}

impl ONNXDataType {
    /// Numeric code of this element type in the ONNX `TensorProto.DataType`
    /// enum. This value is written to the `data_type`/`elem_type` fields when
    /// encoding real ONNX protobuf.
    pub fn onnx_type_code(self) -> i32 {
        match self {
            ONNXDataType::Float32 => 1,
            ONNXDataType::Int32 => 6,
            ONNXDataType::Int64 => 7,
            ONNXDataType::Bool => 9,
            ONNXDataType::Float64 => 11,
        }
    }
}
215
/// ONNX tensor representation (used for initializers and tensor attributes).
#[derive(Debug, Clone)]
pub struct ONNXTensor {
    /// Tensor name
    name: String,
    /// Element data type
    data_type: ONNXDataType,
    /// Shape, one entry per dimension
    shape: Vec<i64>,
    /// Raw element bytes. The `from_array_*` constructors fill this with
    /// little-endian `f32` values.
    data: Vec<u8>,
}
228
229impl ONNXTensor {
230    /// Create tensor from ndarray
231    pub fn from_array_f32(name: impl Into<String>, array: &ArrayD<f32>) -> Self {
232        let shape: Vec<i64> = array.shape().iter().map(|&s| s as i64).collect();
233        let data = array
234            .as_slice()
235            .expect("ArrayD is contiguous in standard layout")
236            .iter()
237            .flat_map(|&f| f.to_le_bytes())
238            .collect();
239
240        Self {
241            name: name.into(),
242            data_type: ONNXDataType::Float32,
243            shape,
244            data,
245        }
246    }
247
248    /// Create tensor from ndarray (f64)
249    pub fn from_array_f64(name: impl Into<String>, array: &ArrayD<f64>) -> Self {
250        let shape: Vec<i64> = array.shape().iter().map(|&s| s as i64).collect();
251        let data = array
252            .as_slice()
253            .expect("ArrayD is contiguous in standard layout")
254            .iter()
255            .flat_map(|&f| (f as f32).to_le_bytes()) // Convert to f32 for ONNX
256            .collect();
257
258        Self {
259            name: name.into(),
260            data_type: ONNXDataType::Float32,
261            shape,
262            data,
263        }
264    }
265}
266
/// ONNX exporter for quantum ML models.
///
/// Converts Keras-style and PyTorch-style models into an `ONNXGraph` and
/// writes it to disk.
pub struct ONNXExporter {
    /// Mapping from quantum layer type name to the custom ONNX operator name
    quantum_mappings: HashMap<String, String>,
    /// Export options
    options: ExportOptions,
}
274
275/// Export options
276#[derive(Debug, Clone)]
277pub struct ExportOptions {
278    /// ONNX opset version
279    opset_version: i64,
280    /// Include quantum layers as custom operators
281    include_quantum_ops: bool,
282    /// Optimize classical layers only
283    optimize_classical_only: bool,
284    /// Target backend for quantum operations
285    quantum_backend: QuantumBackendTarget,
286}
287
288impl Default for ExportOptions {
289    fn default() -> Self {
290        Self {
291            opset_version: 11,
292            include_quantum_ops: true,
293            optimize_classical_only: false,
294            quantum_backend: QuantumBackendTarget::Generic,
295        }
296    }
297}
298
/// Quantum backend targets for export.
///
/// The exporter records the selected target on each exported quantum node as
/// a `backend` string attribute, using this enum's `Debug` text.
#[derive(Debug, Clone)]
pub enum QuantumBackendTarget {
    /// Generic quantum backend
    Generic,
    /// Qiskit-compatible
    Qiskit,
    /// Cirq-compatible
    Cirq,
    /// PennyLane-compatible
    PennyLane,
    /// Custom backend, identified by name
    Custom(String),
}
313
314impl ONNXExporter {
315    /// Create new ONNX exporter
316    pub fn new() -> Self {
317        let mut quantum_mappings = HashMap::new();
318
319        // Map quantum operations to ONNX custom operators
320        quantum_mappings.insert("QuantumDense".to_string(), "QuantumDense".to_string());
321        quantum_mappings.insert("QuantumLinear".to_string(), "QuantumLinear".to_string());
322        quantum_mappings.insert("QuantumConv2d".to_string(), "QuantumConv2d".to_string());
323        quantum_mappings.insert("QuantumRNN".to_string(), "QuantumRNN".to_string());
324
325        Self {
326            quantum_mappings,
327            options: ExportOptions::default(),
328        }
329    }
330
331    /// Set export options
332    pub fn with_options(mut self, options: ExportOptions) -> Self {
333        self.options = options;
334        self
335    }
336
337    /// Export Sequential model to ONNX
338    pub fn export_sequential(
339        &self,
340        model: &Sequential,
341        input_shape: &[usize],
342        output_path: &str,
343    ) -> Result<()> {
344        let mut graph = ONNXGraph::new("sequential_model");
345
346        // Add input
347        let input_shape_i64: Vec<i64> = input_shape.iter().map(|&s| s as i64).collect();
348        graph.add_input(ONNXValueInfo::new(
349            "input",
350            ONNXDataType::Float32,
351            input_shape_i64,
352        ));
353
354        let mut current_output = "input".to_string();
355        let mut node_counter = 0;
356
357        // Convert each layer
358        for layer in model.layers() {
359            let layer_name = format!("layer_{}", node_counter);
360            let output_name = format!("output_{}", node_counter);
361
362            // Convert layer based on type
363            let (nodes, initializers) =
364                self.convert_layer(layer.as_ref(), &layer_name, &current_output, &output_name)?;
365
366            // Add nodes and initializers to graph
367            for node in nodes {
368                graph.add_node(node);
369            }
370            for init in initializers {
371                graph.add_initializer(init);
372            }
373
374            current_output = output_name;
375            node_counter += 1;
376        }
377
378        // Add output
379        let output_shape = model.compute_output_shape(input_shape);
380        let output_shape_i64: Vec<i64> = output_shape.iter().map(|&s| s as i64).collect();
381        graph.add_output(ONNXValueInfo::new(
382            &current_output,
383            ONNXDataType::Float32,
384            output_shape_i64,
385        ));
386
387        // Export graph
388        graph.export(output_path)?;
389        Ok(())
390    }
391
392    /// Export PyTorch-style model to ONNX
393    pub fn export_pytorch_model<T: QuantumModule>(
394        &self,
395        model: &T,
396        input_shape: &[usize],
397        output_path: &str,
398    ) -> Result<()> {
399        let mut graph = ONNXGraph::new("pytorch_model");
400
401        // Add input
402        let input_shape_i64: Vec<i64> = input_shape.iter().map(|&s| s as i64).collect();
403        graph.add_input(ONNXValueInfo::new(
404            "input",
405            ONNXDataType::Float32,
406            input_shape_i64,
407        ));
408
409        // Convert model (simplified - would need more complex analysis)
410        let node = ONNXNode::new(
411            "pytorch_model",
412            "QuantumModel",
413            vec!["input".to_string()],
414            vec!["output".to_string()],
415        );
416        graph.add_node(node);
417
418        // Add output (would need to compute actual output shape)
419        graph.add_output(ONNXValueInfo::new(
420            "output",
421            ONNXDataType::Float32,
422            vec![1, 1], // Placeholder
423        ));
424
425        // Export graph
426        graph.export(output_path)?;
427        Ok(())
428    }
429
430    /// Convert layer to ONNX nodes and initializers
431    fn convert_layer(
432        &self,
433        layer: &dyn KerasLayer,
434        layer_name: &str,
435        input_name: &str,
436        output_name: &str,
437    ) -> Result<(Vec<ONNXNode>, Vec<ONNXTensor>)> {
438        // This would need to be implemented for each layer type
439        // For now, we'll provide a simplified conversion
440
441        let layer_type = self.get_layer_type(layer);
442
443        match layer_type.as_str() {
444            "Dense" => self.convert_dense_layer(layer, layer_name, input_name, output_name),
445            "QuantumDense" => {
446                self.convert_quantum_dense_layer(layer, layer_name, input_name, output_name)
447            }
448            "Activation" => {
449                self.convert_activation_layer(layer, layer_name, input_name, output_name)
450            }
451            _ => {
452                // Generic layer conversion
453                let node = ONNXNode::new(
454                    layer_name,
455                    &layer_type,
456                    vec![input_name.to_string()],
457                    vec![output_name.to_string()],
458                );
459                Ok((vec![node], vec![]))
460            }
461        }
462    }
463
464    /// Convert Dense layer
465    fn convert_dense_layer(
466        &self,
467        layer: &dyn KerasLayer,
468        layer_name: &str,
469        input_name: &str,
470        output_name: &str,
471    ) -> Result<(Vec<ONNXNode>, Vec<ONNXTensor>)> {
472        let weights = layer.get_weights();
473        let mut nodes = Vec::new();
474        let mut initializers = Vec::new();
475
476        if weights.len() >= 1 {
477            // Add weight initializer
478            let weight_name = format!("{}_weight", layer_name);
479            let weight_tensor = ONNXTensor::from_array_f64(&weight_name, &weights[0]);
480            initializers.push(weight_tensor);
481
482            // Create MatMul node
483            let mut matmul_inputs = vec![input_name.to_string(), weight_name];
484            let matmul_output = if weights.len() > 1 {
485                format!("{}_matmul", layer_name)
486            } else {
487                output_name.to_string()
488            };
489
490            let matmul_node = ONNXNode::new(
491                format!("{}_matmul", layer_name),
492                "MatMul",
493                matmul_inputs,
494                vec![matmul_output.clone()],
495            );
496            nodes.push(matmul_node);
497
498            // Add bias if present
499            if weights.len() > 1 {
500                let bias_name = format!("{}_bias", layer_name);
501                let bias_tensor = ONNXTensor::from_array_f64(&bias_name, &weights[1]);
502                initializers.push(bias_tensor);
503
504                let add_node = ONNXNode::new(
505                    format!("{}_add", layer_name),
506                    "Add",
507                    vec![matmul_output, bias_name],
508                    vec![output_name.to_string()],
509                );
510                nodes.push(add_node);
511            }
512        }
513
514        Ok((nodes, initializers))
515    }
516
517    /// Convert QuantumDense layer
518    fn convert_quantum_dense_layer(
519        &self,
520        layer: &dyn KerasLayer,
521        layer_name: &str,
522        input_name: &str,
523        output_name: &str,
524    ) -> Result<(Vec<ONNXNode>, Vec<ONNXTensor>)> {
525        if !self.options.include_quantum_ops {
526            return Err(MLError::InvalidConfiguration(
527                "Quantum operations not supported in export options".to_string(),
528            ));
529        }
530
531        let weights = layer.get_weights();
532        let mut nodes = Vec::new();
533        let mut initializers = Vec::new();
534
535        // Add quantum parameters as initializers
536        for (i, weight) in weights.iter().enumerate() {
537            let param_name = format!("{}_param_{}", layer_name, i);
538            let param_tensor = ONNXTensor::from_array_f64(&param_name, weight);
539            initializers.push(param_tensor);
540        }
541
542        // Create custom quantum node
543        let mut quantum_node = ONNXNode::new(
544            layer_name,
545            "QuantumDense",
546            vec![input_name.to_string()],
547            vec![output_name.to_string()],
548        );
549
550        // Add quantum-specific attributes
551        quantum_node.add_attribute(
552            "backend",
553            ONNXAttribute::String(format!("{:?}", self.options.quantum_backend)),
554        );
555        quantum_node.add_attribute("domain", ONNXAttribute::String("quantrs2.ml".to_string()));
556
557        nodes.push(quantum_node);
558
559        Ok((nodes, initializers))
560    }
561
562    /// Convert Activation layer
563    fn convert_activation_layer(
564        &self,
565        _layer: &dyn KerasLayer,
566        layer_name: &str,
567        input_name: &str,
568        output_name: &str,
569    ) -> Result<(Vec<ONNXNode>, Vec<ONNXTensor>)> {
570        // For now, assume ReLU activation
571        let node = ONNXNode::new(
572            layer_name,
573            "Relu",
574            vec![input_name.to_string()],
575            vec![output_name.to_string()],
576        );
577
578        Ok((vec![node], vec![]))
579    }
580
581    /// Get layer type string
582    fn get_layer_type(&self, _layer: &dyn KerasLayer) -> String {
583        // This would need to be implemented with proper type checking
584        // For now, return a placeholder
585        "Dense".to_string()
586    }
587}
588
/// ONNX importer for loading models back into QuantRS2.
pub struct ONNXImporter {
    /// Import options
    options: ImportOptions,
}
594
595/// Import options
596#[derive(Debug, Clone)]
597pub struct ImportOptions {
598    /// Target framework
599    target_framework: TargetFramework,
600    /// Handle unsupported operators
601    handle_unsupported: UnsupportedOpHandling,
602    /// Quantum backend for imported quantum ops
603    quantum_backend: QuantumBackendTarget,
604}
605
/// Target frameworks an imported model can be converted into.
#[derive(Debug, Clone)]
pub enum TargetFramework {
    /// Keras-style Sequential model
    Keras,
    /// PyTorch-style model
    PyTorch,
    /// Raw QuantRS2 model
    QuantRS2,
}
616
/// How to handle operators the importer does not support.
#[derive(Debug, Clone)]
pub enum UnsupportedOpHandling {
    /// Raise error
    Error,
    /// Skip unsupported operators
    Skip,
    /// Replace with identity
    Identity,
    /// Custom handler, identified by name
    Custom(String),
}
629
630impl Default for ImportOptions {
631    fn default() -> Self {
632        Self {
633            target_framework: TargetFramework::Keras,
634            handle_unsupported: UnsupportedOpHandling::Error,
635            quantum_backend: QuantumBackendTarget::Generic,
636        }
637    }
638}
639
640impl ONNXImporter {
641    /// Create new ONNX importer
642    pub fn new() -> Self {
643        Self {
644            options: ImportOptions::default(),
645        }
646    }
647
648    /// Set import options
649    pub fn with_options(mut self, options: ImportOptions) -> Self {
650        self.options = options;
651        self
652    }
653
654    /// Import ONNX model to Sequential model
655    pub fn import_to_sequential(&self, path: &str) -> Result<Sequential> {
656        let graph = self.load_onnx_graph(path)?;
657        self.convert_to_sequential(&graph)
658    }
659
660    /// Load ONNX graph from file
661    fn load_onnx_graph(&self, path: &str) -> Result<ONNXGraph> {
662        // This would parse the actual ONNX protobuf file
663        // For now, return a placeholder
664        Ok(ONNXGraph::new("imported_model"))
665    }
666
667    /// Convert ONNX graph to Sequential model
668    fn convert_to_sequential(&self, _graph: &ONNXGraph) -> Result<Sequential> {
669        // This would analyze the ONNX graph and recreate the Sequential model
670        // For now, return a simple model
671        Ok(Sequential::new())
672    }
673}
674
675/// Utility functions for ONNX export/import
676pub mod utils {
677    use super::*;
678
679    /// Validate ONNX model
680    pub fn validate_onnx_model(path: &str) -> Result<ValidationReport> {
681        // This would validate the ONNX model structure and operators
682        Ok(ValidationReport {
683            valid: true,
684            errors: Vec::new(),
685            warnings: Vec::new(),
686            quantum_ops_found: false,
687        })
688    }
689
690    /// Get ONNX model info
691    pub fn get_model_info(path: &str) -> Result<ModelInfo> {
692        // This would extract basic information about the ONNX model
693        Ok(ModelInfo {
694            opset_version: 11,
695            producer_name: "QuantRS2-ML".to_string(),
696            producer_version: "0.1.0".to_string(),
697            graph_name: "model".to_string(),
698            num_nodes: 0,
699            num_initializers: 0,
700            input_shapes: Vec::new(),
701            output_shapes: Vec::new(),
702        })
703    }
704
705    /// Convert quantum circuit to ONNX custom operator
706    pub fn circuit_to_onnx_op(circuit: &DynamicCircuit, name: &str) -> Result<ONNXNode> {
707        let mut node = ONNXNode::new(
708            name,
709            "QuantumCircuit",
710            vec!["input".to_string()],
711            vec!["output".to_string()],
712        );
713
714        // Add circuit-specific attributes
715        node.add_attribute(
716            "num_qubits",
717            ONNXAttribute::Int(circuit.num_qubits() as i64),
718        );
719        node.add_attribute("num_gates", ONNXAttribute::Int(circuit.num_gates() as i64));
720        node.add_attribute("depth", ONNXAttribute::Int(circuit.depth() as i64));
721
722        // Serialize circuit structure
723        let circuit_data = serialize_circuit(circuit)?;
724        node.add_attribute("circuit_data", ONNXAttribute::String(circuit_data));
725
726        Ok(node)
727    }
728
729    /// Serialize quantum circuit to string
730    fn serialize_circuit(circuit: &DynamicCircuit) -> Result<String> {
731        // This would serialize the circuit to a string format
732        // For now, return a placeholder
733        Ok("quantum_circuit_placeholder".to_string())
734    }
735
736    /// Create ONNX metadata for quantum ML model
737    pub fn create_quantum_metadata() -> HashMap<String, String> {
738        let mut metadata = HashMap::new();
739        metadata.insert("framework".to_string(), "QuantRS2-ML".to_string());
740        metadata.insert("domain".to_string(), "quantrs2.ml".to_string());
741        metadata.insert("version".to_string(), "0.1.0".to_string());
742        metadata.insert("quantum_support".to_string(), "true".to_string());
743        metadata
744    }
745}
746
/// Validation report for ONNX models, as returned by `utils::validate_onnx_model`.
#[derive(Debug)]
pub struct ValidationReport {
    /// Whether the model passed validation
    pub valid: bool,
    /// Validation errors
    pub errors: Vec<String>,
    /// Validation warnings
    pub warnings: Vec<String>,
    /// Whether any quantum operators were found in the model
    pub quantum_ops_found: bool,
}
759
/// Basic model information, as returned by `utils::get_model_info`.
#[derive(Debug)]
pub struct ModelInfo {
    /// ONNX opset version
    pub opset_version: i64,
    /// Producer name
    pub producer_name: String,
    /// Producer version
    pub producer_version: String,
    /// Graph name
    pub graph_name: String,
    /// Number of nodes
    pub num_nodes: usize,
    /// Number of initializers
    pub num_initializers: usize,
    /// Input shapes, one per graph input
    pub input_shapes: Vec<Vec<i64>>,
    /// Output shapes, one per graph output
    pub output_shapes: Vec<Vec<i64>>,
}
780
781// Extensions for Sequential model
782impl Sequential {
783    /// Export to ONNX format
784    pub fn export_onnx(
785        &self,
786        path: &str,
787        input_shape: &[usize],
788        options: Option<ExportOptions>,
789    ) -> Result<()> {
790        let exporter = ONNXExporter::new();
791        let exporter = if let Some(opts) = options {
792            exporter.with_options(opts)
793        } else {
794            exporter
795        };
796
797        exporter.export_sequential(self, input_shape, path)
798    }
799
800    /// Get layers (placeholder - would need actual implementation)
801    fn layers(&self) -> &[Box<dyn KerasLayer>] {
802        // This would return the actual layers from the Sequential model
803        &[]
804    }
805
806    /// Compute output shape (placeholder)
807    fn compute_output_shape(&self, input_shape: &[usize]) -> Vec<usize> {
808        // This would compute the actual output shape
809        input_shape.to_vec()
810    }
811}
812
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_onnx_graph_creation() {
        let mut graph = ONNXGraph::new("test_graph");

        graph.add_input(ONNXValueInfo::new(
            "input",
            ONNXDataType::Float32,
            vec![1, 10],
        ));

        graph.add_output(ONNXValueInfo::new(
            "output",
            ONNXDataType::Float32,
            vec![1, 5],
        ));

        let node = ONNXNode::new(
            "dense_layer",
            "MatMul",
            vec!["input".to_string(), "weight".to_string()],
            vec!["output".to_string()],
        );
        graph.add_node(node);

        assert_eq!(graph.nodes.len(), 1);
        assert_eq!(graph.inputs.len(), 1);
        assert_eq!(graph.outputs.len(), 1);
    }

    #[test]
    fn test_onnx_graph_text_serialization() {
        let mut graph = ONNXGraph::new("test_graph");
        graph.add_input(ONNXValueInfo::new(
            "input",
            ONNXDataType::Float32,
            vec![1, 10],
        ));
        graph.add_node(ONNXNode::new(
            "dense_layer",
            "MatMul",
            vec!["input".to_string(), "weight".to_string()],
            vec!["output".to_string()],
        ));

        let bytes = graph
            .to_onnx_proto()
            .expect("serializing into an in-memory buffer cannot fail");
        let text = String::from_utf8(bytes).expect("serializer emits UTF-8 text");

        assert!(text.starts_with("ONNX Model Export\nGraph Name: test_graph\n\n"));
        assert!(text.contains("  input: [1, 10]\n"));
        assert!(text.contains("  dense_layer (MatMul): input, weight -> output\n"));
    }

    #[test]
    fn test_onnx_tensor_creation() {
        let array = scirs2_core::ndarray::Array2::from_shape_vec(
            (2, 3),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
        )
        .expect("Shape and vec size are compatible")
        .into_dyn();

        let tensor = ONNXTensor::from_array_f64("test_tensor", &array);
        assert_eq!(tensor.name, "test_tensor");
        assert_eq!(tensor.shape, vec![2, 3]);
        assert!(matches!(tensor.data_type, ONNXDataType::Float32));

        // Six f32 values, little-endian, in row-major order.
        assert_eq!(tensor.data.len(), 6 * 4);
        assert_eq!(&tensor.data[..4], &1.0f32.to_le_bytes());
        assert_eq!(&tensor.data[20..], &6.0f32.to_le_bytes());
    }

    #[test]
    fn test_export_options_default() {
        let options = ExportOptions::default();
        assert_eq!(options.opset_version, 11);
        assert!(options.include_quantum_ops);
        assert!(!options.optimize_classical_only);
        assert!(matches!(options.quantum_backend, QuantumBackendTarget::Generic));
    }

    #[test]
    fn test_import_options_default() {
        let options = ImportOptions::default();
        assert!(matches!(options.target_framework, TargetFramework::Keras));
        assert!(matches!(options.handle_unsupported, UnsupportedOpHandling::Error));
        assert!(matches!(options.quantum_backend, QuantumBackendTarget::Generic));
    }

    #[test]
    fn test_onnx_exporter_creation() {
        let exporter = ONNXExporter::new();
        let options = ExportOptions {
            opset_version: 13,
            include_quantum_ops: false,
            optimize_classical_only: true,
            quantum_backend: QuantumBackendTarget::Qiskit,
        };

        let exporter = exporter.with_options(options);
        assert_eq!(exporter.options.opset_version, 13);
        assert!(!exporter.options.include_quantum_ops);
    }

    #[test]
    fn test_onnx_node_attributes() {
        let mut node = ONNXNode::new(
            "test_node",
            "Conv",
            vec!["input".to_string()],
            vec!["output".to_string()],
        );

        node.add_attribute("kernel_shape", ONNXAttribute::Ints(vec![3, 3]));
        node.add_attribute("strides", ONNXAttribute::Ints(vec![1, 1]));
        // Re-adding an existing attribute name replaces it rather than duplicating it.
        node.add_attribute("strides", ONNXAttribute::Ints(vec![2, 2]));

        assert_eq!(node.attributes.len(), 2);
    }

    #[test]
    fn test_validation_utils() {
        let report = utils::validate_onnx_model("dummy_path");
        assert!(report.is_ok());

        let info = utils::get_model_info("dummy_path");
        assert!(info.is_ok());
    }
}