// quantrs2_ml/quantum_continuous_flows/functions.rs
use crate::error::{MLError, Result};
use crate::quantum_continuous_flows::types::*;
use scirs2_core::ndarray::{Array1, Array2, Array3, ArrayView1, Axis};
use scirs2_core::random::prelude::*;
use scirs2_core::random::ChaCha20Rng;
use scirs2_core::random::{Rng, SeedableRng};
use scirs2_core::Complex64;
use std::f64::consts::PI;
#[cfg(test)]
mod tests {
    // Unit tests for `QuantumContinuousFlow`: construction under several
    // configurations, forward and inverse passes, sampling, and a forward/inverse
    // round trip.
    use super::*;

    /// A flow built from the default configuration must construct without error.
    #[test]
    fn test_quantum_continuous_flow_creation() {
        let config = QuantumContinuousFlowConfig::default();
        let flow = QuantumContinuousFlow::new(config);
        // Surface the underlying error on failure; a bare `is_ok()` assertion hides it.
        assert!(flow.is_ok(), "Flow creation failed: {:?}", flow.err());
    }

    /// The forward pass on a 4-element input must yield a 4-element latent sample
    /// and report a `quantum_advantage_ratio` of at least 1.0.
    #[test]
    fn test_flow_forward_pass() {
        let config = QuantumContinuousFlowConfig {
            input_dim: 4,
            latent_dim: 4,
            num_qubits: 4,
            num_flow_layers: 2,
            ..Default::default()
        };
        let flow = QuantumContinuousFlow::new(config).expect("Flow creation should succeed");
        let x = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);

        // `expect` already fails the test (and prints the error) when the result is `Err`,
        // so no separate `is_ok()` assertion is needed.
        let output = flow.forward(&x).expect("Forward pass should succeed");

        assert_eq!(output.latent_sample.len(), 4);
        assert!(output.quantum_enhancement.quantum_advantage_ratio >= 1.0);
    }

    /// The inverse pass on a 4-element latent vector must yield a 4-element data sample.
    #[test]
    fn test_flow_inverse_pass() {
        let config = QuantumContinuousFlowConfig {
            input_dim: 4,
            latent_dim: 4,
            num_qubits: 4,
            ..Default::default()
        };
        let flow = QuantumContinuousFlow::new(config).expect("Flow creation should succeed");
        let z = Array1::from_vec(vec![0.5, -0.3, 0.8, -0.1]);

        let output = flow.inverse(&z).expect("Inverse pass should succeed");

        assert_eq!(output.data_sample.len(), 4);
    }

    /// Requesting 5 samples from a flow with `input_dim: 2` must return a `[5, 2]`
    /// sample array together with 5 entries in `quantum_metrics`.
    #[test]
    fn test_quantum_sampling() {
        let config = QuantumContinuousFlowConfig {
            input_dim: 2,
            latent_dim: 2,
            num_qubits: 3,
            ..Default::default()
        };
        let flow = QuantumContinuousFlow::new(config).expect("Flow creation should succeed");

        let output = flow.sample(5).expect("Sampling should succeed");

        assert_eq!(output.samples.shape(), &[5, 2]);
        assert_eq!(output.quantum_metrics.len(), 5);
    }

    /// A `QuantumRealNVP` architecture using `QuantumEntangledCoupling` must construct
    /// without error.
    #[test]
    fn test_quantum_coupling_types() {
        let config = QuantumContinuousFlowConfig {
            flow_architecture: FlowArchitecture::QuantumRealNVP {
                hidden_dims: vec![32, 32],
                num_coupling_layers: 2,
                quantum_coupling_type: QuantumCouplingType::QuantumEntangledCoupling,
            },
            ..Default::default()
        };
        let flow = QuantumContinuousFlow::new(config);
        assert!(flow.is_ok(), "Flow creation failed: {:?}", flow.err());
    }

    /// A `QuantumContinuousNormalizing` architecture using the `QuantumRungeKutta4`
    /// solver and `EntanglementBasedTrace` estimation must construct without error.
    #[test]
    fn test_quantum_neural_ode_flow() {
        let config = QuantumContinuousFlowConfig {
            flow_architecture: FlowArchitecture::QuantumContinuousNormalizing {
                ode_net_dims: vec![16, 16],
                quantum_ode_solver: QuantumODESolver::QuantumRungeKutta4,
                trace_estimation_method: TraceEstimationMethod::EntanglementBasedTrace,
            },
            ..Default::default()
        };
        let flow = QuantumContinuousFlow::new(config);
        assert!(flow.is_ok(), "Flow creation failed: {:?}", flow.err());
    }

    /// A base-distribution sample must have as many elements as the configured
    /// `latent_dim` (3 here).
    #[test]
    fn test_quantum_base_distributions() {
        let config = QuantumContinuousFlowConfig {
            latent_dim: 3,
            ..Default::default()
        };
        let flow = QuantumContinuousFlow::new(config).expect("Flow creation should succeed");

        let sample = flow
            .sample_base_distribution()
            .expect("Base distribution sample should succeed");

        assert_eq!(sample.len(), 3);
    }

    /// Runs an input through `forward` and then `inverse` and checks that the summed
    /// absolute difference between the input and the reconstruction stays below 1.0.
    // NOTE(review): this test is ignored, but the reason is not recorded in this file.
    // Consider `#[ignore = "..."]` once the reason is confirmed.
    // NOTE(review): the 1.0 bound is far looser than the configured
    // `invertibility_tolerance` of 1e-8. Whether that field is meant to bound this
    // round-trip error is not visible here; confirm before tightening.
    #[test]
    #[ignore]
    fn test_invertibility_guarantees() {
        let config = QuantumContinuousFlowConfig {
            input_dim: 4,
            latent_dim: 4,
            num_qubits: 4,
            num_flow_layers: 2,
            invertibility_tolerance: 1e-8,
            ..Default::default()
        };
        let flow = QuantumContinuousFlow::new(config).expect("Flow creation should succeed");
        let x = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);

        let forward_output = flow.forward(&x).expect("Forward pass should succeed");
        let inverse_output = flow
            .inverse(&forward_output.latent_sample)
            .expect("Inverse pass should succeed");

        // Sum of element-wise absolute differences between input and reconstruction.
        let round_trip_error = (&x - &inverse_output.data_sample)
            .mapv(|diff: f64| diff.abs())
            .sum();
        assert!(
            round_trip_error < 1.0,
            "round-trip error too large: {}",
            round_trip_error
        );
    }
}