//! ONNX export and import support for QuantRS2-ML models, including custom
//! operators for quantum layers.

use crate::error::{MLError, Result};
use crate::keras_api::{
    Activation, ActivationFunction, Dense, KerasLayer, QuantumDense, Sequential,
};
use crate::pytorch_api::{QuantumLinear, QuantumModule, QuantumSequential};
use crate::simulator_backends::DynamicCircuit;
use quantrs2_circuit::prelude::*;
use scirs2_core::ndarray::{Array1, Array2, ArrayD};
use std::collections::HashMap;
use std::io::Write;

/// In-memory representation of an ONNX graph: operator nodes, graph inputs and
/// outputs, and initializers (constant tensors such as weights).
#[derive(Debug, Clone)]
pub struct ONNXGraph {
    nodes: Vec<ONNXNode>,
    inputs: Vec<ONNXValueInfo>,
    outputs: Vec<ONNXValueInfo>,
    initializers: Vec<ONNXTensor>,
    name: String,
}

impl ONNXGraph {
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            nodes: Vec::new(),
            inputs: Vec::new(),
            outputs: Vec::new(),
            initializers: Vec::new(),
            name: name.into(),
        }
    }

    pub fn add_node(&mut self, node: ONNXNode) {
        self.nodes.push(node);
    }

    pub fn add_input(&mut self, input: ONNXValueInfo) {
        self.inputs.push(input);
    }

    pub fn add_output(&mut self, output: ONNXValueInfo) {
        self.outputs.push(output);
    }

    pub fn add_initializer(&mut self, initializer: ONNXTensor) {
        self.initializers.push(initializer);
    }

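    /// Writes the graph to `path` using the text-based serialization produced by
    /// `to_onnx_proto`. A minimal usage sketch (names and the output path below
    /// are illustrative placeholders):
    ///
    /// ```ignore
    /// let mut graph = ONNXGraph::new("demo");
    /// graph.add_input(ONNXValueInfo::new("x", ONNXDataType::Float32, vec![1, 4]));
    /// graph.add_output(ONNXValueInfo::new("y", ONNXDataType::Float32, vec![1, 2]));
    /// graph.export("demo.onnx")?;
    /// ```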
    pub fn export(&self, path: &str) -> Result<()> {
        let onnx_proto = self.to_onnx_proto()?;

        std::fs::write(path, onnx_proto)?;
        Ok(())
    }

    /// Serializes the graph to bytes. This is a simplified, human-readable text
    /// format; a full implementation would emit the ONNX protobuf.
    fn to_onnx_proto(&self) -> Result<Vec<u8>> {
        let mut buffer = Vec::new();

        writeln!(buffer, "ONNX Model Export")?;
        writeln!(buffer, "Graph Name: {}", self.name)?;
        writeln!(buffer)?;

        writeln!(buffer, "Inputs:")?;
        for input in &self.inputs {
            writeln!(buffer, " {}: {:?}", input.name, input.shape)?;
        }
        writeln!(buffer)?;

        writeln!(buffer, "Outputs:")?;
        for output in &self.outputs {
            writeln!(buffer, " {}: {:?}", output.name, output.shape)?;
        }
        writeln!(buffer)?;

        writeln!(buffer, "Nodes:")?;
        for node in &self.nodes {
            writeln!(
                buffer,
                " {} ({}): {} -> {}",
                node.name,
                node.op_type,
                node.inputs.join(", "),
                node.outputs.join(", ")
            )?;
        }
        writeln!(buffer)?;

        writeln!(buffer, "Initializers:")?;
        for init in &self.initializers {
            writeln!(buffer, " {}: {:?}", init.name, init.shape)?;
        }

        Ok(buffer)
    }
}

/// A single ONNX operator node with its inputs, outputs, and attributes.
#[derive(Debug, Clone)]
pub struct ONNXNode {
    name: String,
    op_type: String,
    inputs: Vec<String>,
    outputs: Vec<String>,
    attributes: HashMap<String, ONNXAttribute>,
}

impl ONNXNode {
    pub fn new(
        name: impl Into<String>,
        op_type: impl Into<String>,
        inputs: Vec<String>,
        outputs: Vec<String>,
    ) -> Self {
        Self {
            name: name.into(),
            op_type: op_type.into(),
            inputs,
            outputs,
            attributes: HashMap::new(),
        }
    }

    pub fn add_attribute(&mut self, name: impl Into<String>, value: ONNXAttribute) {
        self.attributes.insert(name.into(), value);
    }
}

#[derive(Debug, Clone)]
pub enum ONNXAttribute {
    Int(i64),
    Float(f32),
    String(String),
    Tensor(ONNXTensor),
    Ints(Vec<i64>),
    Floats(Vec<f32>),
    Strings(Vec<String>),
}

#[derive(Debug, Clone)]
pub struct ONNXValueInfo {
    name: String,
    data_type: ONNXDataType,
    shape: Vec<i64>,
}

impl ONNXValueInfo {
    pub fn new(name: impl Into<String>, data_type: ONNXDataType, shape: Vec<i64>) -> Self {
        Self {
            name: name.into(),
            data_type,
            shape,
        }
    }
}

#[derive(Debug, Clone)]
pub enum ONNXDataType {
    Float32,
    Float64,
    Int32,
    Int64,
    Bool,
}

/// Constant tensor data (e.g. weights) stored as little-endian bytes.
#[derive(Debug, Clone)]
pub struct ONNXTensor {
    name: String,
    data_type: ONNXDataType,
    shape: Vec<i64>,
    data: Vec<u8>,
}

impl ONNXTensor {
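    /// Builds a `Float32` tensor from an `ArrayD<f32>`, flattening the elements
    /// into little-endian bytes. A minimal usage sketch (the array and name are
    /// illustrative placeholders):
    ///
    /// ```ignore
    /// let weights = scirs2_core::ndarray::Array2::<f32>::zeros((4, 2)).into_dyn();
    /// let tensor = ONNXTensor::from_array_f32("layer_0_weight", &weights);
    /// ```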
    pub fn from_array_f32(name: impl Into<String>, array: &ArrayD<f32>) -> Self {
        let shape: Vec<i64> = array.shape().iter().map(|&s| s as i64).collect();
        let data = array
            .as_slice()
            .expect("ArrayD is contiguous in standard layout")
            .iter()
            .flat_map(|&f| f.to_le_bytes())
            .collect();

        Self {
            name: name.into(),
            data_type: ONNXDataType::Float32,
            shape,
            data,
        }
    }

    pub fn from_array_f64(name: impl Into<String>, array: &ArrayD<f64>) -> Self {
        let shape: Vec<i64> = array.shape().iter().map(|&s| s as i64).collect();
        // Note: f64 values are downcast to f32 so the tensor is stored as Float32.
        let data = array
            .as_slice()
            .expect("ArrayD is contiguous in standard layout")
            .iter()
            .flat_map(|&f| (f as f32).to_le_bytes())
            .collect();

        Self {
            name: name.into(),
            data_type: ONNXDataType::Float32,
            shape,
            data,
        }
    }
}

/// Exports QuantRS2-ML models to ONNX, mapping quantum layers to custom
/// operators in the `quantrs2.ml` domain.
pub struct ONNXExporter {
    quantum_mappings: HashMap<String, String>,
    options: ExportOptions,
}

#[derive(Debug, Clone)]
pub struct ExportOptions {
    opset_version: i64,
    include_quantum_ops: bool,
    optimize_classical_only: bool,
    quantum_backend: QuantumBackendTarget,
}

impl Default for ExportOptions {
    fn default() -> Self {
        Self {
            opset_version: 11,
            include_quantum_ops: true,
            optimize_classical_only: false,
            quantum_backend: QuantumBackendTarget::Generic,
        }
    }
}

#[derive(Debug, Clone)]
pub enum QuantumBackendTarget {
    Generic,
    Qiskit,
    Cirq,
    PennyLane,
    Custom(String),
}

impl ONNXExporter {
    pub fn new() -> Self {
        let mut quantum_mappings = HashMap::new();

        quantum_mappings.insert("QuantumDense".to_string(), "QuantumDense".to_string());
        quantum_mappings.insert("QuantumLinear".to_string(), "QuantumLinear".to_string());
        quantum_mappings.insert("QuantumConv2d".to_string(), "QuantumConv2d".to_string());
        quantum_mappings.insert("QuantumRNN".to_string(), "QuantumRNN".to_string());

        Self {
            quantum_mappings,
            options: ExportOptions::default(),
        }
    }

    pub fn with_options(mut self, options: ExportOptions) -> Self {
        self.options = options;
        self
    }

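    /// Converts a `Sequential` model layer-by-layer into an ONNX graph and writes
    /// it to `output_path`. A minimal usage sketch (the empty model, shape, and
    /// path are illustrative placeholders):
    ///
    /// ```ignore
    /// let exporter = ONNXExporter::new();
    /// let model = Sequential::new();
    /// exporter.export_sequential(&model, &[1, 10], "model.onnx")?;
    /// ```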
    pub fn export_sequential(
        &self,
        model: &Sequential,
        input_shape: &[usize],
        output_path: &str,
    ) -> Result<()> {
        let mut graph = ONNXGraph::new("sequential_model");

        let input_shape_i64: Vec<i64> = input_shape.iter().map(|&s| s as i64).collect();
        graph.add_input(ONNXValueInfo::new(
            "input",
            ONNXDataType::Float32,
            input_shape_i64,
        ));

        let mut current_output = "input".to_string();
        let mut node_counter = 0;

        for layer in model.layers() {
            let layer_name = format!("layer_{}", node_counter);
            let output_name = format!("output_{}", node_counter);

            let (nodes, initializers) =
                self.convert_layer(layer.as_ref(), &layer_name, &current_output, &output_name)?;

            for node in nodes {
                graph.add_node(node);
            }
            for init in initializers {
                graph.add_initializer(init);
            }

            current_output = output_name;
            node_counter += 1;
        }

        let output_shape = model.compute_output_shape(input_shape);
        let output_shape_i64: Vec<i64> = output_shape.iter().map(|&s| s as i64).collect();
        graph.add_output(ONNXValueInfo::new(
            &current_output,
            ONNXDataType::Float32,
            output_shape_i64,
        ));

        graph.export(output_path)?;
        Ok(())
    }

    pub fn export_pytorch_model<T: QuantumModule>(
        &self,
        _model: &T,
        input_shape: &[usize],
        output_path: &str,
    ) -> Result<()> {
        let mut graph = ONNXGraph::new("pytorch_model");

        let input_shape_i64: Vec<i64> = input_shape.iter().map(|&s| s as i64).collect();
        graph.add_input(ONNXValueInfo::new(
            "input",
            ONNXDataType::Float32,
            input_shape_i64,
        ));

        // Simplified: the whole module is exported as a single opaque node.
        let node = ONNXNode::new(
            "pytorch_model",
            "QuantumModel",
            vec!["input".to_string()],
            vec!["output".to_string()],
        );
        graph.add_node(node);

        graph.add_output(ONNXValueInfo::new(
            "output",
            ONNXDataType::Float32,
            vec![1, 1], // Placeholder output shape.
        ));

        graph.export(output_path)?;
        Ok(())
    }

    fn convert_layer(
        &self,
        layer: &dyn KerasLayer,
        layer_name: &str,
        input_name: &str,
        output_name: &str,
    ) -> Result<(Vec<ONNXNode>, Vec<ONNXTensor>)> {
        let layer_type = self.get_layer_type(layer);

        match layer_type.as_str() {
            "Dense" => self.convert_dense_layer(layer, layer_name, input_name, output_name),
            "QuantumDense" => {
                self.convert_quantum_dense_layer(layer, layer_name, input_name, output_name)
            }
            "Activation" => {
                self.convert_activation_layer(layer, layer_name, input_name, output_name)
            }
            _ => {
                // Unknown layer types are exported as opaque single-node ops.
                let node = ONNXNode::new(
                    layer_name,
                    &layer_type,
                    vec![input_name.to_string()],
                    vec![output_name.to_string()],
                );
                Ok((vec![node], vec![]))
            }
        }
    }

    fn convert_dense_layer(
        &self,
        layer: &dyn KerasLayer,
        layer_name: &str,
        input_name: &str,
        output_name: &str,
    ) -> Result<(Vec<ONNXNode>, Vec<ONNXTensor>)> {
        let weights = layer.get_weights();
        let mut nodes = Vec::new();
        let mut initializers = Vec::new();

        if !weights.is_empty() {
            let weight_name = format!("{}_weight", layer_name);
            let weight_tensor = ONNXTensor::from_array_f64(&weight_name, &weights[0]);
            initializers.push(weight_tensor);

            let matmul_inputs = vec![input_name.to_string(), weight_name];
            // If a bias is present, MatMul writes to an intermediate tensor that
            // the Add node consumes; otherwise it writes the layer output directly.
            let matmul_output = if weights.len() > 1 {
                format!("{}_matmul", layer_name)
            } else {
                output_name.to_string()
            };

            let matmul_node = ONNXNode::new(
                format!("{}_matmul", layer_name),
                "MatMul",
                matmul_inputs,
                vec![matmul_output.clone()],
            );
            nodes.push(matmul_node);

            if weights.len() > 1 {
                let bias_name = format!("{}_bias", layer_name);
                let bias_tensor = ONNXTensor::from_array_f64(&bias_name, &weights[1]);
                initializers.push(bias_tensor);

                let add_node = ONNXNode::new(
                    format!("{}_add", layer_name),
                    "Add",
                    vec![matmul_output, bias_name],
                    vec![output_name.to_string()],
                );
                nodes.push(add_node);
            }
        }

        Ok((nodes, initializers))
    }

    fn convert_quantum_dense_layer(
        &self,
        layer: &dyn KerasLayer,
        layer_name: &str,
        input_name: &str,
        output_name: &str,
    ) -> Result<(Vec<ONNXNode>, Vec<ONNXTensor>)> {
        if !self.options.include_quantum_ops {
            return Err(MLError::InvalidConfiguration(
                "Quantum operations not supported in export options".to_string(),
            ));
        }

        let weights = layer.get_weights();
        let mut nodes = Vec::new();
        let mut initializers = Vec::new();

        // Export the variational parameters as initializers.
        for (i, weight) in weights.iter().enumerate() {
            let param_name = format!("{}_param_{}", layer_name, i);
            let param_tensor = ONNXTensor::from_array_f64(&param_name, weight);
            initializers.push(param_tensor);
        }

        // The quantum layer itself becomes a custom op in the `quantrs2.ml` domain.
        let mut quantum_node = ONNXNode::new(
            layer_name,
            "QuantumDense",
            vec![input_name.to_string()],
            vec![output_name.to_string()],
        );

        quantum_node.add_attribute(
            "backend",
            ONNXAttribute::String(format!("{:?}", self.options.quantum_backend)),
        );
        quantum_node.add_attribute("domain", ONNXAttribute::String("quantrs2.ml".to_string()));

        nodes.push(quantum_node);

        Ok((nodes, initializers))
    }

    fn convert_activation_layer(
        &self,
        _layer: &dyn KerasLayer,
        layer_name: &str,
        input_name: &str,
        output_name: &str,
    ) -> Result<(Vec<ONNXNode>, Vec<ONNXTensor>)> {
        // Simplified: every activation is currently exported as Relu.
        let node = ONNXNode::new(
            layer_name,
            "Relu",
            vec![input_name.to_string()],
            vec![output_name.to_string()],
        );

        Ok((vec![node], vec![]))
    }

    fn get_layer_type(&self, _layer: &dyn KerasLayer) -> String {
        // Placeholder: a full implementation would inspect the layer's concrete type.
        "Dense".to_string()
    }
}

/// Imports ONNX models back into QuantRS2-ML model types.
pub struct ONNXImporter {
    options: ImportOptions,
}

#[derive(Debug, Clone)]
pub struct ImportOptions {
    target_framework: TargetFramework,
    handle_unsupported: UnsupportedOpHandling,
    quantum_backend: QuantumBackendTarget,
}

#[derive(Debug, Clone)]
pub enum TargetFramework {
    Keras,
    PyTorch,
    QuantRS2,
}

#[derive(Debug, Clone)]
pub enum UnsupportedOpHandling {
    Error,
    Skip,
    Identity,
    Custom(String),
}

impl Default for ImportOptions {
    fn default() -> Self {
        Self {
            target_framework: TargetFramework::Keras,
            handle_unsupported: UnsupportedOpHandling::Error,
            quantum_backend: QuantumBackendTarget::Generic,
        }
    }
}

impl ONNXImporter {
    pub fn new() -> Self {
        Self {
            options: ImportOptions::default(),
        }
    }

    pub fn with_options(mut self, options: ImportOptions) -> Self {
        self.options = options;
        self
    }

    pub fn import_to_sequential(&self, path: &str) -> Result<Sequential> {
        let graph = self.load_onnx_graph(path)?;
        self.convert_to_sequential(&graph)
    }

    fn load_onnx_graph(&self, _path: &str) -> Result<ONNXGraph> {
        // Placeholder: ONNX parsing is not implemented yet.
        Ok(ONNXGraph::new("imported_model"))
    }

    fn convert_to_sequential(&self, _graph: &ONNXGraph) -> Result<Sequential> {
        // Placeholder: graph-to-layer conversion is not implemented yet.
        Ok(Sequential::new())
    }
}

pub mod utils {
    use super::*;

    pub fn validate_onnx_model(_path: &str) -> Result<ValidationReport> {
        // Placeholder: always reports a valid model.
        Ok(ValidationReport {
            valid: true,
            errors: Vec::new(),
            warnings: Vec::new(),
            quantum_ops_found: false,
        })
    }

    pub fn get_model_info(_path: &str) -> Result<ModelInfo> {
        // Placeholder metadata; a real implementation would read it from the file.
        Ok(ModelInfo {
            opset_version: 11,
            producer_name: "QuantRS2-ML".to_string(),
            producer_version: "0.1.0".to_string(),
            graph_name: "model".to_string(),
            num_nodes: 0,
            num_initializers: 0,
            input_shapes: Vec::new(),
            output_shapes: Vec::new(),
        })
    }

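    /// Wraps a quantum circuit in a single custom `QuantumCircuit` node, recording
    /// basic circuit statistics as attributes. A minimal usage sketch (the circuit
    /// value and node name are illustrative placeholders):
    ///
    /// ```ignore
    /// let node = circuit_to_onnx_op(&circuit, "vqc_block")?;
    /// ```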
    pub fn circuit_to_onnx_op(circuit: &DynamicCircuit, name: &str) -> Result<ONNXNode> {
        let mut node = ONNXNode::new(
            name,
            "QuantumCircuit",
            vec!["input".to_string()],
            vec!["output".to_string()],
        );

        node.add_attribute(
            "num_qubits",
            ONNXAttribute::Int(circuit.num_qubits() as i64),
        );
        node.add_attribute("num_gates", ONNXAttribute::Int(circuit.num_gates() as i64));
        node.add_attribute("depth", ONNXAttribute::Int(circuit.depth() as i64));

        let circuit_data = serialize_circuit(circuit)?;
        node.add_attribute("circuit_data", ONNXAttribute::String(circuit_data));

        Ok(node)
    }

    fn serialize_circuit(_circuit: &DynamicCircuit) -> Result<String> {
        // Placeholder: real circuit serialization is not implemented yet.
        Ok("quantum_circuit_placeholder".to_string())
    }

    pub fn create_quantum_metadata() -> HashMap<String, String> {
        let mut metadata = HashMap::new();
        metadata.insert("framework".to_string(), "QuantRS2-ML".to_string());
        metadata.insert("domain".to_string(), "quantrs2.ml".to_string());
        metadata.insert("version".to_string(), "0.1.0".to_string());
        metadata.insert("quantum_support".to_string(), "true".to_string());
        metadata
    }
}

#[derive(Debug)]
pub struct ValidationReport {
    pub valid: bool,
    pub errors: Vec<String>,
    pub warnings: Vec<String>,
    pub quantum_ops_found: bool,
}

#[derive(Debug)]
pub struct ModelInfo {
    pub opset_version: i64,
    pub producer_name: String,
    pub producer_version: String,
    pub graph_name: String,
    pub num_nodes: usize,
    pub num_initializers: usize,
    pub input_shapes: Vec<Vec<i64>>,
    pub output_shapes: Vec<Vec<i64>>,
}

impl Sequential {
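    /// Convenience wrapper around `ONNXExporter::export_sequential`. A minimal
    /// usage sketch (the empty model, shape, and path are illustrative
    /// placeholders):
    ///
    /// ```ignore
    /// let model = Sequential::new();
    /// model.export_onnx("model.onnx", &[1, 10], None)?;
    /// ```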
    pub fn export_onnx(
        &self,
        path: &str,
        input_shape: &[usize],
        options: Option<ExportOptions>,
    ) -> Result<()> {
        let exporter = ONNXExporter::new();
        let exporter = if let Some(opts) = options {
            exporter.with_options(opts)
        } else {
            exporter
        };

        exporter.export_sequential(self, input_shape, path)
    }

    fn layers(&self) -> &[Box<dyn KerasLayer>] {
        // Placeholder: exposes no layers until Sequential provides access to them.
        &[]
    }

    fn compute_output_shape(&self, input_shape: &[usize]) -> Vec<usize> {
        // Placeholder: assumes the output shape equals the input shape.
        input_shape.to_vec()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::keras_api::{ActivationFunction, Dense};

    #[test]
    fn test_onnx_graph_creation() {
        let mut graph = ONNXGraph::new("test_graph");

        graph.add_input(ONNXValueInfo::new(
            "input",
            ONNXDataType::Float32,
            vec![1, 10],
        ));

        graph.add_output(ONNXValueInfo::new(
            "output",
            ONNXDataType::Float32,
            vec![1, 5],
        ));

        let node = ONNXNode::new(
            "dense_layer",
            "MatMul",
            vec!["input".to_string(), "weight".to_string()],
            vec!["output".to_string()],
        );
        graph.add_node(node);

        assert_eq!(graph.nodes.len(), 1);
        assert_eq!(graph.inputs.len(), 1);
        assert_eq!(graph.outputs.len(), 1);
    }

    #[test]
    fn test_onnx_tensor_creation() {
        let array = scirs2_core::ndarray::Array2::from_shape_vec(
            (2, 3),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
        )
        .expect("Shape and vec size are compatible")
        .into_dyn();

        let tensor = ONNXTensor::from_array_f64("test_tensor", &array);
        assert_eq!(tensor.name, "test_tensor");
        assert_eq!(tensor.shape, vec![2, 3]);
    }

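    // Added example (sketch): an f32 tensor stores four little-endian bytes per
    // element, so a 2x3 array yields 24 bytes. Uses only APIs defined above.
    #[test]
    fn test_onnx_tensor_f32_byte_layout() {
        let array = scirs2_core::ndarray::Array2::from_shape_vec(
            (2, 3),
            vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0],
        )
        .expect("Shape and vec size are compatible")
        .into_dyn();

        let tensor = ONNXTensor::from_array_f32("f32_tensor", &array);
        assert_eq!(tensor.shape, vec![2, 3]);
        assert_eq!(tensor.data.len(), 6 * 4);
    }
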
    #[test]
    fn test_onnx_exporter_creation() {
        let exporter = ONNXExporter::new();
        let options = ExportOptions {
            opset_version: 13,
            include_quantum_ops: false,
            optimize_classical_only: true,
            quantum_backend: QuantumBackendTarget::Qiskit,
        };

        let exporter = exporter.with_options(options);
        assert_eq!(exporter.options.opset_version, 13);
        assert!(!exporter.options.include_quantum_ops);
    }

    #[test]
    fn test_onnx_node_attributes() {
        let mut node = ONNXNode::new(
            "test_node",
            "Conv",
            vec!["input".to_string()],
            vec!["output".to_string()],
        );

        node.add_attribute("kernel_shape", ONNXAttribute::Ints(vec![3, 3]));
        node.add_attribute("strides", ONNXAttribute::Ints(vec![1, 1]));

        assert_eq!(node.attributes.len(), 2);
    }

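    // Added example (sketch): the text-based `to_onnx_proto` output should
    // mention the graph name and each node. Uses only APIs defined above.
    #[test]
    fn test_graph_text_serialization() {
        let mut graph = ONNXGraph::new("proto_graph");
        graph.add_input(ONNXValueInfo::new("x", ONNXDataType::Float32, vec![1, 4]));
        graph.add_output(ONNXValueInfo::new("y", ONNXDataType::Float32, vec![1, 2]));
        graph.add_node(ONNXNode::new(
            "matmul_0",
            "MatMul",
            vec!["x".to_string(), "w".to_string()],
            vec!["y".to_string()],
        ));

        let bytes = graph.to_onnx_proto().expect("serialization succeeds");
        let text = String::from_utf8(bytes).expect("output is valid UTF-8");
        assert!(text.contains("Graph Name: proto_graph"));
        assert!(text.contains("matmul_0"));
    }
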
    #[test]
    fn test_validation_utils() {
        let report = utils::validate_onnx_model("dummy_path");
        assert!(report.is_ok());

        let info = utils::get_model_info("dummy_path");
        assert!(info.is_ok());
    }
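
    // Added example (sketch): exporter metadata should advertise the quantrs2.ml
    // domain and quantum support. Uses only APIs defined above.
    #[test]
    fn test_quantum_metadata_contents() {
        let metadata = utils::create_quantum_metadata();
        assert_eq!(metadata.get("domain"), Some(&"quantrs2.ml".to_string()));
        assert_eq!(metadata.get("quantum_support"), Some(&"true".to_string()));
    }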
}