use crate::error::{MLError, Result};
use crate::keras_api::{
    Activation, ActivationFunction, Dense, KerasLayer, QuantumDense, Sequential,
};
use crate::pytorch_api::{QuantumLinear, QuantumModule, QuantumSequential};
use crate::simulator_backends::DynamicCircuit;
use ndarray::{Array1, Array2, ArrayD};
use quantrs2_circuit::prelude::*;
use std::collections::HashMap;
use std::io::Write;

#[derive(Debug, Clone)]
pub struct ONNXGraph {
    nodes: Vec<ONNXNode>,
    inputs: Vec<ONNXValueInfo>,
    outputs: Vec<ONNXValueInfo>,
    initializers: Vec<ONNXTensor>,
    name: String,
}

impl ONNXGraph {
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            nodes: Vec::new(),
            inputs: Vec::new(),
            outputs: Vec::new(),
            initializers: Vec::new(),
            name: name.into(),
        }
    }

    pub fn add_node(&mut self, node: ONNXNode) {
        self.nodes.push(node);
    }

    pub fn add_input(&mut self, input: ONNXValueInfo) {
        self.inputs.push(input);
    }

    pub fn add_output(&mut self, output: ONNXValueInfo) {
        self.outputs.push(output);
    }

    pub fn add_initializer(&mut self, initializer: ONNXTensor) {
        self.initializers.push(initializer);
    }

    pub fn export(&self, path: &str) -> Result<()> {
        let onnx_proto = self.to_onnx_proto()?;
        std::fs::write(path, onnx_proto)?;
        Ok(())
    }

    /// Serializes the graph. Note that this currently emits a human-readable
    /// text summary as a placeholder rather than the binary ONNX protobuf.
    fn to_onnx_proto(&self) -> Result<Vec<u8>> {
        let mut buffer = Vec::new();

        writeln!(buffer, "ONNX Model Export")?;
        writeln!(buffer, "Graph Name: {}", self.name)?;
        writeln!(buffer)?;

        writeln!(buffer, "Inputs:")?;
        for input in &self.inputs {
            writeln!(buffer, " {}: {:?}", input.name, input.shape)?;
        }
        writeln!(buffer)?;

        writeln!(buffer, "Outputs:")?;
        for output in &self.outputs {
            writeln!(buffer, " {}: {:?}", output.name, output.shape)?;
        }
        writeln!(buffer)?;

        writeln!(buffer, "Nodes:")?;
        for node in &self.nodes {
            writeln!(
                buffer,
                " {} ({}): {} -> {}",
                node.name,
                node.op_type,
                node.inputs.join(", "),
                node.outputs.join(", ")
            )?;
        }
        writeln!(buffer)?;

        writeln!(buffer, "Initializers:")?;
        for init in &self.initializers {
            writeln!(buffer, " {}: {:?}", init.name, init.shape)?;
        }

        Ok(buffer)
    }
}

#[derive(Debug, Clone)]
pub struct ONNXNode {
    name: String,
    op_type: String,
    inputs: Vec<String>,
    outputs: Vec<String>,
    attributes: HashMap<String, ONNXAttribute>,
}

impl ONNXNode {
    pub fn new(
        name: impl Into<String>,
        op_type: impl Into<String>,
        inputs: Vec<String>,
        outputs: Vec<String>,
    ) -> Self {
        Self {
            name: name.into(),
            op_type: op_type.into(),
            inputs,
            outputs,
            attributes: HashMap::new(),
        }
    }

    pub fn add_attribute(&mut self, name: impl Into<String>, value: ONNXAttribute) {
        self.attributes.insert(name.into(), value);
    }
}

#[derive(Debug, Clone)]
pub enum ONNXAttribute {
    Int(i64),
    Float(f32),
    String(String),
    Tensor(ONNXTensor),
    Ints(Vec<i64>),
    Floats(Vec<f32>),
    Strings(Vec<String>),
}

#[derive(Debug, Clone)]
pub struct ONNXValueInfo {
    name: String,
    data_type: ONNXDataType,
    shape: Vec<i64>,
}

impl ONNXValueInfo {
    pub fn new(name: impl Into<String>, data_type: ONNXDataType, shape: Vec<i64>) -> Self {
        Self {
            name: name.into(),
            data_type,
            shape,
        }
    }
}

#[derive(Debug, Clone)]
pub enum ONNXDataType {
    Float32,
    Float64,
    Int32,
    Int64,
    Bool,
}

#[derive(Debug, Clone)]
pub struct ONNXTensor {
    name: String,
    data_type: ONNXDataType,
    shape: Vec<i64>,
    data: Vec<u8>,
}

impl ONNXTensor {
    pub fn from_array_f32(name: impl Into<String>, array: &ArrayD<f32>) -> Self {
        let shape: Vec<i64> = array.shape().iter().map(|&s| s as i64).collect();
        // Iterate over elements directly so non-contiguous arrays do not panic.
        let data = array.iter().flat_map(|&f| f.to_le_bytes()).collect();

        Self {
            name: name.into(),
            data_type: ONNXDataType::Float32,
            shape,
            data,
        }
    }

    /// Builds a tensor from an `f64` array, downcasting each element to `f32`
    /// so that the stored bytes match the declared `Float32` data type.
    pub fn from_array_f64(name: impl Into<String>, array: &ArrayD<f64>) -> Self {
        let shape: Vec<i64> = array.shape().iter().map(|&s| s as i64).collect();
        let data = array
            .iter()
            .flat_map(|&f| (f as f32).to_le_bytes())
            .collect();

        Self {
            name: name.into(),
            data_type: ONNXDataType::Float32,
            shape,
            data,
        }
    }
}

pub struct ONNXExporter {
    quantum_mappings: HashMap<String, String>,
    options: ExportOptions,
}

#[derive(Debug, Clone)]
pub struct ExportOptions {
    opset_version: i64,
    include_quantum_ops: bool,
    optimize_classical_only: bool,
    quantum_backend: QuantumBackendTarget,
}

impl Default for ExportOptions {
    fn default() -> Self {
        Self {
            opset_version: 11,
            include_quantum_ops: true,
            optimize_classical_only: false,
            quantum_backend: QuantumBackendTarget::Generic,
        }
    }
}

#[derive(Debug, Clone)]
pub enum QuantumBackendTarget {
    Generic,
    Qiskit,
    Cirq,
    PennyLane,
    Custom(String),
}

impl ONNXExporter {
    pub fn new() -> Self {
        let mut quantum_mappings = HashMap::new();

        quantum_mappings.insert("QuantumDense".to_string(), "QuantumDense".to_string());
        quantum_mappings.insert("QuantumLinear".to_string(), "QuantumLinear".to_string());
        quantum_mappings.insert("QuantumConv2d".to_string(), "QuantumConv2d".to_string());
        quantum_mappings.insert("QuantumRNN".to_string(), "QuantumRNN".to_string());

        Self {
            quantum_mappings,
            options: ExportOptions::default(),
        }
    }

    pub fn with_options(mut self, options: ExportOptions) -> Self {
        self.options = options;
        self
    }

    pub fn export_sequential(
        &self,
        model: &Sequential,
        input_shape: &[usize],
        output_path: &str,
    ) -> Result<()> {
        let mut graph = ONNXGraph::new("sequential_model");

        let input_shape_i64: Vec<i64> = input_shape.iter().map(|&s| s as i64).collect();
        graph.add_input(ONNXValueInfo::new(
            "input",
            ONNXDataType::Float32,
            input_shape_i64,
        ));

        let mut current_output = "input".to_string();
        let mut node_counter = 0;

        for layer in model.layers() {
            let layer_name = format!("layer_{}", node_counter);
            let output_name = format!("output_{}", node_counter);

            let (nodes, initializers) =
                self.convert_layer(layer.as_ref(), &layer_name, &current_output, &output_name)?;

            for node in nodes {
                graph.add_node(node);
            }
            for init in initializers {
                graph.add_initializer(init);
            }

            current_output = output_name;
            node_counter += 1;
        }

        let output_shape = model.compute_output_shape(input_shape);
        let output_shape_i64: Vec<i64> = output_shape.iter().map(|&s| s as i64).collect();
        graph.add_output(ONNXValueInfo::new(
            &current_output,
            ONNXDataType::Float32,
            output_shape_i64,
        ));

        graph.export(output_path)?;
        Ok(())
    }

    pub fn export_pytorch_model<T: QuantumModule>(
        &self,
        _model: &T,
        input_shape: &[usize],
        output_path: &str,
    ) -> Result<()> {
        let mut graph = ONNXGraph::new("pytorch_model");

        let input_shape_i64: Vec<i64> = input_shape.iter().map(|&s| s as i64).collect();
        graph.add_input(ONNXValueInfo::new(
            "input",
            ONNXDataType::Float32,
            input_shape_i64,
        ));

        // The whole module is exported as a single opaque quantum node.
        let node = ONNXNode::new(
            "pytorch_model",
            "QuantumModel",
            vec!["input".to_string()],
            vec!["output".to_string()],
        );
        graph.add_node(node);

        graph.add_output(ONNXValueInfo::new(
            "output",
            ONNXDataType::Float32,
            // Placeholder output shape; the real shape depends on the module.
            vec![1, 1],
        ));

        graph.export(output_path)?;
        Ok(())
    }

    fn convert_layer(
        &self,
        layer: &dyn KerasLayer,
        layer_name: &str,
        input_name: &str,
        output_name: &str,
    ) -> Result<(Vec<ONNXNode>, Vec<ONNXTensor>)> {
        let layer_type = self.get_layer_type(layer);

        match layer_type.as_str() {
            "Dense" => self.convert_dense_layer(layer, layer_name, input_name, output_name),
            "QuantumDense" => {
                self.convert_quantum_dense_layer(layer, layer_name, input_name, output_name)
            }
            "Activation" => {
                self.convert_activation_layer(layer, layer_name, input_name, output_name)
            }
            _ => {
                // Unknown layer types are passed through as a single opaque node.
                let node = ONNXNode::new(
                    layer_name,
                    &layer_type,
                    vec![input_name.to_string()],
                    vec![output_name.to_string()],
                );
                Ok((vec![node], vec![]))
            }
        }
    }

    fn convert_dense_layer(
        &self,
        layer: &dyn KerasLayer,
        layer_name: &str,
        input_name: &str,
        output_name: &str,
    ) -> Result<(Vec<ONNXNode>, Vec<ONNXTensor>)> {
        let weights = layer.get_weights();
        let mut nodes = Vec::new();
        let mut initializers = Vec::new();

        if !weights.is_empty() {
            let weight_name = format!("{}_weight", layer_name);
            let weight_tensor = ONNXTensor::from_array_f64(&weight_name, &weights[0]);
            initializers.push(weight_tensor);

            let matmul_inputs = vec![input_name.to_string(), weight_name];
            // MatMul writes to an intermediate name only when a bias Add follows.
            let matmul_output = if weights.len() > 1 {
                format!("{}_matmul", layer_name)
            } else {
                output_name.to_string()
            };

            let matmul_node = ONNXNode::new(
                format!("{}_matmul", layer_name),
                "MatMul",
                matmul_inputs,
                vec![matmul_output.clone()],
            );
            nodes.push(matmul_node);

            if weights.len() > 1 {
                let bias_name = format!("{}_bias", layer_name);
                let bias_tensor = ONNXTensor::from_array_f64(&bias_name, &weights[1]);
                initializers.push(bias_tensor);

                let add_node = ONNXNode::new(
                    format!("{}_add", layer_name),
                    "Add",
                    vec![matmul_output, bias_name],
                    vec![output_name.to_string()],
                );
                nodes.push(add_node);
            }
        }

        Ok((nodes, initializers))
    }

    fn convert_quantum_dense_layer(
        &self,
        layer: &dyn KerasLayer,
        layer_name: &str,
        input_name: &str,
        output_name: &str,
    ) -> Result<(Vec<ONNXNode>, Vec<ONNXTensor>)> {
        if !self.options.include_quantum_ops {
            return Err(MLError::InvalidConfiguration(
                "Quantum operations are disabled in the export options".to_string(),
            ));
        }

        let weights = layer.get_weights();
        let mut nodes = Vec::new();
        let mut initializers = Vec::new();

        // Export each parameter array as an initializer of the custom op.
        for (i, weight) in weights.iter().enumerate() {
            let param_name = format!("{}_param_{}", layer_name, i);
            let param_tensor = ONNXTensor::from_array_f64(&param_name, weight);
            initializers.push(param_tensor);
        }

        let mut quantum_node = ONNXNode::new(
            layer_name,
            "QuantumDense",
            vec![input_name.to_string()],
            vec![output_name.to_string()],
        );

        quantum_node.add_attribute(
            "backend",
            ONNXAttribute::String(format!("{:?}", self.options.quantum_backend)),
        );
        quantum_node.add_attribute("domain", ONNXAttribute::String("quantrs2.ml".to_string()));

        nodes.push(quantum_node);

        Ok((nodes, initializers))
    }

    fn convert_activation_layer(
        &self,
        _layer: &dyn KerasLayer,
        layer_name: &str,
        input_name: &str,
        output_name: &str,
    ) -> Result<(Vec<ONNXNode>, Vec<ONNXTensor>)> {
        // Simplified: every activation is currently exported as ReLU.
        let node = ONNXNode::new(
            layer_name,
            "Relu",
            vec![input_name.to_string()],
            vec![output_name.to_string()],
        );

        Ok((vec![node], vec![]))
    }

    /// Simplified layer-type detection; a complete implementation would
    /// inspect the concrete layer instead of assuming `Dense`.
    fn get_layer_type(&self, _layer: &dyn KerasLayer) -> String {
        "Dense".to_string()
    }
}

pub struct ONNXImporter {
    options: ImportOptions,
}

#[derive(Debug, Clone)]
pub struct ImportOptions {
    target_framework: TargetFramework,
    handle_unsupported: UnsupportedOpHandling,
    quantum_backend: QuantumBackendTarget,
}

#[derive(Debug, Clone)]
pub enum TargetFramework {
    Keras,
    PyTorch,
    QuantRS2,
}

#[derive(Debug, Clone)]
pub enum UnsupportedOpHandling {
    Error,
    Skip,
    Identity,
    Custom(String),
}

impl Default for ImportOptions {
    fn default() -> Self {
        Self {
            target_framework: TargetFramework::Keras,
            handle_unsupported: UnsupportedOpHandling::Error,
            quantum_backend: QuantumBackendTarget::Generic,
        }
    }
}

impl ONNXImporter {
    pub fn new() -> Self {
        Self {
            options: ImportOptions::default(),
        }
    }

    pub fn with_options(mut self, options: ImportOptions) -> Self {
        self.options = options;
        self
    }

    pub fn import_to_sequential(&self, path: &str) -> Result<Sequential> {
        let graph = self.load_onnx_graph(path)?;
        self.convert_to_sequential(&graph)
    }

    /// Placeholder loader: a complete implementation would parse the ONNX
    /// protobuf at `_path`.
    fn load_onnx_graph(&self, _path: &str) -> Result<ONNXGraph> {
        Ok(ONNXGraph::new("imported_model"))
    }

    /// Placeholder conversion: a complete implementation would rebuild layers
    /// from the graph nodes and initializers.
    fn convert_to_sequential(&self, _graph: &ONNXGraph) -> Result<Sequential> {
        Ok(Sequential::new())
    }
}

pub mod utils {
    use super::*;

    /// Placeholder validation that currently reports every model as valid.
    pub fn validate_onnx_model(_path: &str) -> Result<ValidationReport> {
        Ok(ValidationReport {
            valid: true,
            errors: Vec::new(),
            warnings: Vec::new(),
            quantum_ops_found: false,
        })
    }

    /// Placeholder metadata query that returns default values rather than
    /// reading them from the model file.
    pub fn get_model_info(_path: &str) -> Result<ModelInfo> {
        Ok(ModelInfo {
            opset_version: 11,
            producer_name: "QuantRS2-ML".to_string(),
            producer_version: "0.1.0".to_string(),
            graph_name: "model".to_string(),
            num_nodes: 0,
            num_initializers: 0,
            input_shapes: Vec::new(),
            output_shapes: Vec::new(),
        })
    }

    pub fn circuit_to_onnx_op(circuit: &DynamicCircuit, name: &str) -> Result<ONNXNode> {
        let mut node = ONNXNode::new(
            name,
            "QuantumCircuit",
            vec!["input".to_string()],
            vec!["output".to_string()],
        );

        node.add_attribute(
            "num_qubits",
            ONNXAttribute::Int(circuit.num_qubits() as i64),
        );
        node.add_attribute("num_gates", ONNXAttribute::Int(circuit.num_gates() as i64));
        node.add_attribute("depth", ONNXAttribute::Int(circuit.depth() as i64));

        let circuit_data = serialize_circuit(circuit)?;
        node.add_attribute("circuit_data", ONNXAttribute::String(circuit_data));

        Ok(node)
    }

    /// Placeholder circuit serialization.
    fn serialize_circuit(_circuit: &DynamicCircuit) -> Result<String> {
        Ok("quantum_circuit_placeholder".to_string())
    }

    pub fn create_quantum_metadata() -> HashMap<String, String> {
        let mut metadata = HashMap::new();
        metadata.insert("framework".to_string(), "QuantRS2-ML".to_string());
        metadata.insert("domain".to_string(), "quantrs2.ml".to_string());
        metadata.insert("version".to_string(), "0.1.0".to_string());
        metadata.insert("quantum_support".to_string(), "true".to_string());
        metadata
    }
}

#[derive(Debug)]
pub struct ValidationReport {
    pub valid: bool,
    pub errors: Vec<String>,
    pub warnings: Vec<String>,
    pub quantum_ops_found: bool,
}

#[derive(Debug)]
pub struct ModelInfo {
    pub opset_version: i64,
    pub producer_name: String,
    pub producer_version: String,
    pub graph_name: String,
    pub num_nodes: usize,
    pub num_initializers: usize,
    pub input_shapes: Vec<Vec<i64>>,
    pub output_shapes: Vec<Vec<i64>>,
}

impl Sequential {
    pub fn export_onnx(
        &self,
        path: &str,
        input_shape: &[usize],
        options: Option<ExportOptions>,
    ) -> Result<()> {
        let exporter = ONNXExporter::new();
        let exporter = if let Some(opts) = options {
            exporter.with_options(opts)
        } else {
            exporter
        };

        exporter.export_sequential(self, input_shape, path)
    }

    /// Placeholder accessor that currently exposes no layers.
    fn layers(&self) -> &[Box<dyn KerasLayer>] {
        &[]
    }

    /// Placeholder that assumes the output shape equals the input shape.
    fn compute_output_shape(&self, input_shape: &[usize]) -> Vec<usize> {
        input_shape.to_vec()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::keras_api::{ActivationFunction, Dense};

    #[test]
    fn test_onnx_graph_creation() {
        let mut graph = ONNXGraph::new("test_graph");

        graph.add_input(ONNXValueInfo::new(
            "input",
            ONNXDataType::Float32,
            vec![1, 10],
        ));

        graph.add_output(ONNXValueInfo::new(
            "output",
            ONNXDataType::Float32,
            vec![1, 5],
        ));

        let node = ONNXNode::new(
            "dense_layer",
            "MatMul",
            vec!["input".to_string(), "weight".to_string()],
            vec!["output".to_string()],
        );
        graph.add_node(node);

        assert_eq!(graph.nodes.len(), 1);
        assert_eq!(graph.inputs.len(), 1);
        assert_eq!(graph.outputs.len(), 1);
    }

    #[test]
    fn test_onnx_tensor_creation() {
        let array = ndarray::Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap()
            .into_dyn();

        let tensor = ONNXTensor::from_array_f64("test_tensor", &array);
        assert_eq!(tensor.name, "test_tensor");
        assert_eq!(tensor.shape, vec![2, 3]);
    }
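
    // Added check (a small sketch): `from_array_f64` downcasts each element to
    // `f32` before serializing, so the buffer should contain four little-endian
    // bytes per element. The tensor name used here is arbitrary.
    #[test]
    fn test_onnx_tensor_f64_byte_layout() {
        let array = ndarray::arr1(&[1.0f64, 2.0]).into_dyn();
        let tensor = ONNXTensor::from_array_f64("layout_tensor", &array);

        // Two elements, four bytes each after the f32 downcast.
        assert_eq!(tensor.data.len(), 2 * 4);
        // The first element matches the little-endian bytes of 1.0f32.
        assert_eq!(&tensor.data[0..4], &1.0f32.to_le_bytes()[..]);
    }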

    #[test]
    fn test_onnx_exporter_creation() {
        let exporter = ONNXExporter::new();
        let options = ExportOptions {
            opset_version: 13,
            include_quantum_ops: false,
            optimize_classical_only: true,
            quantum_backend: QuantumBackendTarget::Qiskit,
        };

        let exporter = exporter.with_options(options);
        assert_eq!(exporter.options.opset_version, 13);
        assert!(!exporter.options.include_quantum_ops);
    }

    #[test]
    fn test_onnx_node_attributes() {
        let mut node = ONNXNode::new(
            "test_node",
            "Conv",
            vec!["input".to_string()],
            vec!["output".to_string()],
        );

        node.add_attribute("kernel_shape", ONNXAttribute::Ints(vec![3, 3]));
        node.add_attribute("strides", ONNXAttribute::Ints(vec![1, 1]));

        assert_eq!(node.attributes.len(), 2);
    }
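
    // Added check (a small sketch): `ONNXGraph::export` currently writes a
    // human-readable text summary rather than binary ONNX protobuf, so the
    // output can be read back as a string. The temporary file name is arbitrary.
    #[test]
    fn test_onnx_graph_export_writes_text_summary() {
        let mut graph = ONNXGraph::new("export_test_graph");
        graph.add_input(ONNXValueInfo::new(
            "input",
            ONNXDataType::Float32,
            vec![1, 4],
        ));

        let path = std::env::temp_dir().join("quantrs2_onnx_graph_export_test.onnx");
        graph.export(path.to_str().unwrap()).unwrap();

        let contents = std::fs::read_to_string(&path).unwrap();
        assert!(contents.contains("Graph Name: export_test_graph"));

        let _ = std::fs::remove_file(&path);
    }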

    #[test]
    fn test_validation_utils() {
        let report = utils::validate_onnx_model("dummy_path");
        assert!(report.is_ok());

        let info = utils::get_model_info("dummy_path");
        assert!(info.is_ok());
    }
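
    // Added check (a small sketch): a smoke test for `Sequential::export_onnx`,
    // assuming `Sequential::new()` builds an empty model as it does in
    // `ONNXImporter::convert_to_sequential`. The temporary file name is arbitrary.
    #[test]
    fn test_sequential_export_onnx_smoke() {
        let model = Sequential::new();
        let path = std::env::temp_dir().join("quantrs2_onnx_sequential_export_test.onnx");

        let result = model.export_onnx(path.to_str().unwrap(), &[1, 4], None);
        assert!(result.is_ok());
        assert!(path.exists());

        let _ = std::fs::remove_file(&path);
    }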
}