1use crate::error::{MLError, Result};
6use crate::simulator_backends::{DynamicCircuit, Observable, SimulationResult, SimulatorBackend};
7use quantrs2_circuit::prelude::*;
8use quantrs2_core::prelude::*;
9use scirs2_core::ndarray::{s, Array1, Array2, Array3, Array4, ArrayD, Axis};
10use std::sync::Arc;
11
12use super::types::{
13 DataEncodingType, DifferentiationMethod, PQCLayer, ParameterInitStrategy, QuantumCircuitLayer,
14 QuantumDataset, TFQLossFunction, TFQModel, TFQOptimizer,
15};
16
17#[cfg(test)]
18mod tests {
19 use super::*;
20 use crate::simulator_backends::{BackendCapabilities, StatevectorBackend};
21 use crate::tensorflow_compatibility::tfq_utils;
/// Smoke test: builds a two-qubit circuit (RY on qubits 0 and 1, then CNOT 0 -> 1), wraps it
/// in a `QuantumCircuitLayer` and checks that `forward` returns `Ok` for 2x2 inputs and
/// 2x2 parameters.
///
/// Marked `#[ignore]`, so it only runs when ignored tests are requested explicitly.
#[test]
#[ignore]
fn test_quantum_circuit_layer() {
    let mut builder = CircuitBuilder::new();
    builder.ry(0, 0.0).expect("RY gate should succeed");
    builder.ry(1, 0.0).expect("RY gate should succeed");
    builder.cnot(0, 1).expect("CNOT gate should succeed");
    let circuit = builder.build();
    let symbols = vec!["theta1".to_string(), "theta2".to_string()];
    let observable = Observable::PauliZ(vec![0, 1]);
    let backend = Arc::new(StatevectorBackend::new(8));
    let layer = QuantumCircuitLayer::new(circuit, symbols, observable, backend);
    let inputs = Array2::from_shape_vec((2, 2), vec![0.1, 0.2, 0.3, 0.4])
        .expect("Valid shape for inputs");
    let parameters = Array2::from_shape_vec((2, 2), vec![0.5, 0.6, 0.7, 0.8])
        .expect("Valid shape for parameters");
    // Both arrays are borrowed; the second argument had been corrupted into `¶meters`
    // (a decoded `&para` entity) and is restored to `&parameters`.
    let result = layer.forward(&inputs, &parameters);
    assert!(result.is_ok());
}
/// Checks that a `PQCLayer` configured with `RandomNormal { mean: 0.0, std: 0.1 }`
/// returns an array of shape `[5, 3]` from `initialize_parameters(5, 3)`.
#[test]
fn test_pqc_layer_initialization() -> Result<()> {
    // Single-gate circuit: Hadamard on qubit 0.
    let circuit = {
        let mut gates = CircuitBuilder::new();
        gates.h(0)?;
        gates.build()
    };
    let backend = Arc::new(StatevectorBackend::new(8));
    let strategy = ParameterInitStrategy::RandomNormal {
        mean: 0.0,
        std: 0.1,
    };
    let layer = PQCLayer::new(
        circuit,
        vec!["param1".to_string()],
        Observable::PauliZ(vec![0]),
        backend,
    )
    .with_initialization(strategy);

    let initial = layer.initialize_parameters(5, 3);
    assert_eq!(initial.shape(), &[5, 3]);
    Ok(())
}
/// Checks `GlorotUniform` initialization on a `[10, 6]` parameter array:
/// the shape is as requested, every sample lies within `±sqrt(6 / (2 * 6))`,
/// and the samples are not all identical (variance is strictly positive).
#[test]
fn test_glorot_uniform_initialization() -> Result<()> {
    let circuit = {
        let mut gates = CircuitBuilder::new();
        gates.h(0)?;
        gates.build()
    };
    let backend = Arc::new(StatevectorBackend::new(8));
    let layer = PQCLayer::new(
        circuit,
        vec!["param1".to_string()],
        Observable::PauliZ(vec![0]),
        backend,
    )
    .with_initialization(ParameterInitStrategy::GlorotUniform);

    let weights = layer.initialize_parameters(10, 6);
    assert_eq!(weights.shape(), &[10, 6]);

    // Symmetric bound each sample must respect.
    let limit = (6.0_f64 / (2.0_f64 * 6.0_f64)).sqrt();
    weights.iter().for_each(|&val| {
        assert!(
            val >= -limit && val <= limit,
            "Parameter {} outside range [-{}, {}]",
            val,
            limit,
            limit
        );
    });

    // Population variance: mean of squared deviations from the mean.
    let mean = weights.mean().expect("Mean calculation should succeed");
    let squared_deviations = weights.mapv(|x| (x - mean).powi(2));
    let variance = squared_deviations
        .mean()
        .expect("Variance calculation should succeed");
    assert!(variance > 0.0, "Parameters should have non-zero variance");
    Ok(())
}
/// Checks `GlorotNormal` initialization on a `[100, 10]` parameter array:
/// the sample mean is within 0.1 of zero and the sample standard deviation is
/// within 20% (relative) of `sqrt(2 / (2 * 10))`.
#[test]
fn test_glorot_normal_initialization() -> Result<()> {
    let circuit = {
        let mut gates = CircuitBuilder::new();
        gates.h(0)?;
        gates.build()
    };
    let backend = Arc::new(StatevectorBackend::new(8));
    let layer = PQCLayer::new(
        circuit,
        vec!["param1".to_string()],
        Observable::PauliZ(vec![0]),
        backend,
    )
    .with_initialization(ParameterInitStrategy::GlorotNormal);

    let weights = layer.initialize_parameters(100, 10);
    assert_eq!(weights.shape(), &[100, 10]);

    let mean = weights.mean().expect("Mean calculation should succeed");
    assert!(mean.abs() < 0.1, "Mean {} should be close to 0", mean);

    let expected_std = (2.0_f64 / (2.0_f64 * 10.0_f64)).sqrt();
    // Population standard deviation of the drawn samples.
    let squared_deviations = weights.mapv(|x| (x - mean).powi(2));
    let actual_std = squared_deviations
        .mean()
        .expect("Variance calculation should succeed")
        .sqrt();
    let relative_error = (actual_std - expected_std).abs() / expected_std;
    assert!(
        relative_error < 0.2,
        "Std {} should be close to expected {}",
        actual_std,
        expected_std
    );
    Ok(())
}
/// Pins the element-wise value of
/// `alpha * l1_ratio * sign(x) + 2 * alpha * (1 - l1_ratio) * x`
/// for `alpha = 0.01`, `l1_ratio = 0.5` on the inputs `[[1, -1], [2, -2]]`.
///
/// NOTE(review): the expression is evaluated here with array arithmetic only; no
/// regularization routine from the crate is called. Presumably it mirrors the ElasticNet
/// gradient used elsewhere; confirm against that implementation.
#[test]
fn test_elasticnet_regularization() -> Result<()> {
    let l1_ratio = 0.5_f64;
    let alpha = 0.01_f64;
    let parameters = Array2::from_shape_vec((2, 2), vec![1.0, -1.0, 2.0, -2.0])
        .expect("Valid shape for parameters");
    // L1 term depends only on the sign of each element.
    let expected_l1 = parameters.mapv(|x: f64| alpha * l1_ratio * x.signum());
    // L2 term scales each element; `parameters` is borrowed (the `&` had been corrupted
    // into `¶` by a decoded `&para` entity).
    let expected_l2 = &parameters * (2.0 * alpha * (1.0 - l1_ratio));
    let expected_total = &expected_l1 + &expected_l2;
    // x = ±1: 0.005 + 0.01 = 0.015; x = ±2: 0.005 + 0.02 = 0.025 (sign follows x).
    assert!((expected_total[[0, 0]] - 0.015_f64).abs() < 1e-10);
    assert!((expected_total[[0, 1]] - (-0.015_f64)).abs() < 1e-10);
    assert!((expected_total[[1, 0]] - 0.025_f64).abs() < 1e-10);
    assert!((expected_total[[1, 1]] - (-0.025_f64)).abs() < 1e-10);
    Ok(())
}
/// Evaluates `alpha * l1_ratio * sign(x) + 2 * alpha * (1 - l1_ratio) * x` at the two
/// extremes of `l1_ratio` for `alpha = 0.01` and inputs `[[1, 2]]`:
///
/// * `l1_ratio = 0.0`: only the L2 term remains, giving `0.02` and `0.04`.
/// * `l1_ratio = 1.0`: only the L1 term remains, giving `0.01` for both elements.
///
/// NOTE(review): as written this checks the arithmetic locally and does not call any
/// regularization code from the crate.
#[test]
fn test_elasticnet_extreme_cases() -> Result<()> {
    let alpha = 0.01_f64;
    let parameters = Array2::from_shape_vec((1, 2), vec![1.0, 2.0]).expect("Valid shape");

    // Pure L2: the sign-based term is multiplied by zero.
    let l1_ratio = 0.0_f64;
    let l1_part = parameters.mapv(|x: f64| alpha * l1_ratio * x.signum());
    // `&parameters` restored from the corrupted `¶meters`.
    let l2_part = &parameters * (2.0 * alpha * (1.0 - l1_ratio));
    let total = &l1_part + &l2_part;
    assert!((total[[0, 0]] - 0.02_f64).abs() < 1e-10);
    assert!((total[[0, 1]] - 0.04_f64).abs() < 1e-10);

    // Pure L1: the magnitude-based term is multiplied by zero.
    let l1_ratio = 1.0_f64;
    let l1_part = parameters.mapv(|x: f64| alpha * l1_ratio * x.signum());
    let l2_part = &parameters * (2.0 * alpha * (1.0 - l1_ratio));
    let total = &l1_part + &l2_part;
    assert!((total[[0, 0]] - 0.01_f64).abs() < 1e-10);
    assert!((total[[0, 1]] - 0.01_f64).abs() < 1e-10);
    Ok(())
}
/// Checks that the `tfq_utils` circuit factories report the requested qubit counts:
/// 3 for the angle data-encoding circuit and 4 for the hardware-efficient ansatz.
///
/// Marked `#[ignore]`, so it only runs when ignored tests are requested explicitly.
#[test]
#[ignore]
fn test_tfq_utils() {
    let encoding_qubits = tfq_utils::create_data_encoding_circuit(3, DataEncodingType::Angle)
        .expect("Data encoding circuit creation should succeed")
        .num_qubits();
    assert_eq!(encoding_qubits, 3);

    let ansatz_qubits = tfq_utils::create_hardware_efficient_ansatz(4, 2)
        .expect("Hardware efficient ansatz creation should succeed")
        .num_qubits();
    assert_eq!(ansatz_qubits, 4);
}
/// Builds a `QuantumDataset` from two empty circuits, a 2x3 parameter array and two labels
/// with the last constructor argument set to 1, then checks that `batches()` yields two
/// items and that the first element of the first item has length 1.
#[test]
fn test_quantum_dataset() -> Result<()> {
    let circuits = vec![CircuitBuilder::new().build(), CircuitBuilder::new().build()];
    let parameters = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
        .expect("Valid shape for parameters");
    let labels = Array1::from_vec(vec![0.0, 1.0]);

    // Construction is fallible; propagate the error instead of unwrapping.
    let dataset = QuantumDataset::new(circuits, parameters, labels, 1)?;

    let batches = dataset.batches().collect::<Vec<_>>();
    assert_eq!(batches.len(), 2);
    let first_batch = &batches[0];
    assert_eq!(first_batch.0.len(), 1);
    Ok(())
}
/// Checks that a `TFQModel` built from `vec![2, 2]` with mean-squared-error loss and an
/// Adam optimizer compiles successfully.
///
/// Marked `#[ignore]`, so it only runs when ignored tests are requested explicitly.
#[test]
#[ignore]
fn test_tfq_model() {
    let adam = TFQOptimizer::Adam {
        learning_rate: 0.01,
        beta1: 0.9,
        beta2: 0.999,
        epsilon: 1e-8,
    };
    let mut model = TFQModel::new(vec![2, 2])
        .set_loss(TFQLossFunction::MeanSquaredError)
        .set_optimizer(adam);
    assert!(model.compile().is_ok());
}
/// Checks that `CirqCircuit::new(4)` records 4 qubits and starts with no gates and no
/// parameter symbols.
#[test]
fn test_cirq_circuit_creation() {
    use crate::tensorflow_compatibility::cirq_converter::*;
    let fresh = CirqCircuit::new(4);
    // Borrow the three inspected fields in one pattern; any other fields are ignored.
    let CirqCircuit {
        num_qubits,
        gates,
        param_symbols,
        ..
    } = &fresh;
    assert_eq!(*num_qubits, 4);
    assert_eq!(gates.len(), 0);
    assert_eq!(param_symbols.len(), 0);
}
/// Checks the circuit returned by `create_bell_circuit()`: 2 qubits and 2 gates, both
/// before and after conversion with `to_quantrs2_circuit::<2>()`.
#[test]
fn test_cirq_bell_circuit() -> Result<()> {
    use crate::tensorflow_compatibility::cirq_converter::*;

    // Source (Cirq-style) representation.
    let bell = create_bell_circuit();
    assert_eq!(bell.num_qubits, 2);
    assert_eq!(bell.gates.len(), 2);

    // Converted representation must keep the same qubit and gate counts.
    let converted = bell.to_quantrs2_circuit::<2>()?;
    assert_eq!(converted.num_qubits(), 2);
    assert_eq!(converted.gates().len(), 2);
    Ok(())
}
/// Adds H, X, Y, CNOT and CZ gates to a 3-qubit `CirqCircuit` and checks that all five
/// are present both before and after conversion with `to_quantrs2_circuit::<3>()`.
#[test]
fn test_cirq_gate_conversion() -> Result<()> {
    use crate::tensorflow_compatibility::cirq_converter::*;

    // Gates in the order they are appended to the circuit.
    let gate_sequence = [
        CirqGate::H { qubit: 0 },
        CirqGate::X { qubit: 1 },
        CirqGate::Y { qubit: 2 },
        CirqGate::CNOT {
            control: 0,
            target: 1,
        },
        CirqGate::CZ {
            control: 1,
            target: 2,
        },
    ];
    let mut source = CirqCircuit::new(3);
    for gate in gate_sequence {
        source.add_gate(gate);
    }
    assert_eq!(source.gates.len(), 5);

    let converted = source.to_quantrs2_circuit::<3>()?;
    assert_eq!(converted.gates().len(), 5);
    Ok(())
}
/// Adds Rx(π/2) on qubit 0, Ry(π/4) on qubit 1 and Rz(π) on qubit 0 to a 2-qubit
/// `CirqCircuit` and checks that the converted circuit contains three gates.
#[test]
fn test_cirq_rotation_gates() -> Result<()> {
    use crate::tensorflow_compatibility::cirq_converter::*;

    let pi = std::f64::consts::PI;
    let rotations = [
        CirqGate::Rx {
            qubit: 0,
            angle: pi / 2.0,
        },
        CirqGate::Ry {
            qubit: 1,
            angle: pi / 4.0,
        },
        CirqGate::Rz {
            qubit: 0,
            angle: pi,
        },
    ];
    let mut source = CirqCircuit::new(2);
    for gate in rotations {
        source.add_gate(gate);
    }

    let converted = source.to_quantrs2_circuit::<2>()?;
    assert_eq!(converted.gates().len(), 3);
    Ok(())
}
/// Adds one `XPowGate`, one `YPowGate` and one `ZPowGate` (exponents 0.5, 1.0, 0.25,
/// all with zero global shift) to a 3-qubit `CirqCircuit` and checks that the converted
/// circuit contains three gates.
#[test]
fn test_cirq_pow_gates() -> Result<()> {
    use crate::tensorflow_compatibility::cirq_converter::*;

    let pow_gates = [
        CirqGate::XPowGate {
            qubit: 0,
            exponent: 0.5,
            global_shift: 0.0,
        },
        CirqGate::YPowGate {
            qubit: 1,
            exponent: 1.0,
            global_shift: 0.0,
        },
        CirqGate::ZPowGate {
            qubit: 2,
            exponent: 0.25,
            global_shift: 0.0,
        },
    ];
    let mut source = CirqCircuit::new(3);
    for gate in pow_gates {
        source.add_gate(gate);
    }

    let converted = source.to_quantrs2_circuit::<3>()?;
    assert_eq!(converted.gates().len(), 3);
    Ok(())
}
/// Checks the circuit returned by `create_parametric_circuit(4, 2)`: 4 qubits, 14 gates
/// and 8 parameter symbols, with the gate count preserved by
/// `to_quantrs2_circuit::<4>()`.
#[test]
fn test_cirq_parametric_circuit() -> Result<()> {
    use crate::tensorflow_compatibility::cirq_converter::*;

    let parametric = create_parametric_circuit(4, 2);
    assert_eq!(parametric.num_qubits, 4);
    assert_eq!(parametric.gates.len(), 14);
    assert_eq!(parametric.param_symbols.len(), 8);

    let converted = parametric.to_quantrs2_circuit::<4>()?;
    assert_eq!(converted.gates().len(), 14);
    Ok(())
}
/// Checks that converting the circuit from `create_bell_circuit()` with
/// `to_dynamic_circuit()` succeeds and reports 2 qubits.
#[test]
fn test_cirq_to_dynamic_circuit() -> Result<()> {
    use crate::tensorflow_compatibility::cirq_converter::*;

    let bell = create_bell_circuit();
    let as_dynamic = bell.to_dynamic_circuit()?;
    assert_eq!(as_dynamic.num_qubits(), 2);
    Ok(())
}
/// Adds a single `U3` gate (theta = π/2, phi = π/4, lambda = π/3) to a 1-qubit
/// `CirqCircuit` and checks that the converted circuit contains exactly one gate.
#[test]
fn test_cirq_u3_gate() -> Result<()> {
    use crate::tensorflow_compatibility::cirq_converter::*;

    let pi = std::f64::consts::PI;
    let u3 = CirqGate::U3 {
        qubit: 0,
        theta: pi / 2.0,
        phi: pi / 4.0,
        lambda: pi / 3.0,
    };
    let mut source = CirqCircuit::new(1);
    source.add_gate(u3);

    let converted = source.to_quantrs2_circuit::<1>()?;
    assert_eq!(converted.gates().len(), 1);
    Ok(())
}
/// Checks `pow_gate_to_angle` at three exponents: 1.0 -> π, 0.5 -> π/2, 0.25 -> π/4,
/// each within an absolute tolerance of 1e-10.
#[test]
fn test_cirq_pow_gate_to_angle() {
    use crate::tensorflow_compatibility::cirq_converter::*;

    // Each case lives in its own scope instead of shadowing the previous `angle`.
    {
        let angle = pow_gate_to_angle(1.0);
        assert!((angle - std::f64::consts::PI).abs() < 1e-10);
    }
    {
        let angle = pow_gate_to_angle(0.5);
        assert!((angle - std::f64::consts::PI / 2.0).abs() < 1e-10);
    }
    {
        let angle = pow_gate_to_angle(0.25);
        assert!((angle - std::f64::consts::PI / 4.0).abs() < 1e-10);
    }
}
/// Checks equality on `DifferentiationMethod`: each of `ParameterShift` and `Adjoint`
/// equals itself, and the two differ from each other.
#[test]
fn test_differentiation_method_enum() {
    // Fully qualified variants instead of a glob import of the enum.
    assert_eq!(
        DifferentiationMethod::ParameterShift,
        DifferentiationMethod::ParameterShift
    );
    assert_eq!(DifferentiationMethod::Adjoint, DifferentiationMethod::Adjoint);
    assert_ne!(
        DifferentiationMethod::ParameterShift,
        DifferentiationMethod::Adjoint
    );
}
/// Checks that `with_differentiation(DifferentiationMethod::Adjoint)` is reflected in the
/// layer's `differentiation_method` field.
#[test]
fn test_pqc_with_adjoint_method() -> Result<()> {
    let circuit = {
        let mut gates = CircuitBuilder::new();
        gates.h(0)?;
        gates.build()
    };
    let backend = Arc::new(StatevectorBackend::new(8));
    let layer = PQCLayer::new(
        circuit,
        vec!["param1".to_string()],
        Observable::PauliZ(vec![0]),
        backend,
    )
    .with_differentiation(DifferentiationMethod::Adjoint);

    assert_eq!(layer.differentiation_method, DifferentiationMethod::Adjoint);
    Ok(())
}
/// Checks that a `PQCLayer` built with `PQCLayer::new` alone (no `with_differentiation`
/// call) has `differentiation_method == DifferentiationMethod::ParameterShift`.
#[test]
fn test_pqc_with_parameter_shift_default() -> Result<()> {
    let circuit = {
        let mut gates = CircuitBuilder::new();
        gates.h(0)?;
        gates.build()
    };
    let backend = Arc::new(StatevectorBackend::new(8));
    let layer = PQCLayer::new(
        circuit,
        vec!["param1".to_string()],
        Observable::PauliZ(vec![0]),
        backend,
    );

    let expected = DifferentiationMethod::ParameterShift;
    assert_eq!(layer.differentiation_method, expected);
    Ok(())
}
/// Builds two layers from the same circuit, symbols, observable and backend, one with
/// `ParameterShift` and one with `Adjoint`, and checks that each layer's
/// `differentiation_method` field holds the method it was given.
#[test]
fn test_differentiation_method_switching() -> Result<()> {
    let circuit = {
        let mut gates = CircuitBuilder::new();
        gates.h(0)?;
        gates.build()
    };
    let symbols = vec!["param1".to_string()];
    let observable = Observable::PauliZ(vec![0]);
    let backend = Arc::new(StatevectorBackend::new(8));

    // First layer works on clones so the originals remain available for the second.
    let shift_layer = PQCLayer::new(
        circuit.clone(),
        symbols.clone(),
        observable.clone(),
        backend.clone(),
    )
    .with_differentiation(DifferentiationMethod::ParameterShift);
    assert_eq!(
        shift_layer.differentiation_method,
        DifferentiationMethod::ParameterShift
    );

    // Second layer consumes the original values.
    let adjoint_layer = PQCLayer::new(circuit, symbols, observable, backend)
        .with_differentiation(DifferentiationMethod::Adjoint);
    assert_eq!(
        adjoint_layer.differentiation_method,
        DifferentiationMethod::Adjoint
    );
    Ok(())
}
373}