1use super::*;
7use crate::{CircuitResult, DeviceError, DeviceResult, QuantumDevice};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::Arc;
11use tokio::sync::RwLock;
12
13pub trait QuantumNeuralNetwork: Send + Sync {
15 fn forward(&self, input: &[f64]) -> DeviceResult<Vec<f64>>;
17
18 fn parameters(&self) -> &[f64];
20
21 fn set_parameters(&mut self, params: Vec<f64>) -> DeviceResult<()>;
23
24 fn parameter_count(&self) -> usize;
26
27 fn architecture(&self) -> QNNArchitecture;
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct QNNArchitecture {
34 pub network_type: QNNType,
35 pub num_qubits: usize,
36 pub num_layers: usize,
37 pub num_parameters: usize,
38 pub input_encoding: InputEncoding,
39 pub output_decoding: OutputDecoding,
40 pub entangling_strategy: EntanglingStrategy,
41}
42
43#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
45pub enum QNNType {
46 PQC,
48 QCNN,
50 VQC,
52 QGAN,
54 HybridCQN,
56 QRNN,
58}
59
60#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
62pub enum InputEncoding {
63 Amplitude,
65 Angle,
67 Basis,
69 CoherentState,
71 Displacement,
73}
74
75#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
77pub enum OutputDecoding {
78 PauliExpectation,
80 Probabilities,
82 Fidelity,
84 CoherentMeasurement,
86}
87
88#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
90pub enum EntanglingStrategy {
91 Linear,
92 Circular,
93 AllToAll,
94 Random,
95 Hardware,
96 Custom(Vec<(usize, usize)>),
97}
98
99pub struct PQCNetwork {
101 device: Arc<RwLock<dyn QuantumDevice + Send + Sync>>,
102 num_qubits: usize,
103 num_layers: usize,
104 parameters: Vec<f64>,
105 input_encoding: InputEncoding,
106 output_decoding: OutputDecoding,
107 entangling_strategy: EntanglingStrategy,
108 measurement_operators: Vec<PauliOperator>,
109}
110
111impl PQCNetwork {
112 pub fn new(
114 device: Arc<RwLock<dyn QuantumDevice + Send + Sync>>,
115 num_qubits: usize,
116 num_layers: usize,
117 input_encoding: InputEncoding,
118 output_decoding: OutputDecoding,
119 entangling_strategy: EntanglingStrategy,
120 ) -> Self {
121 let parameter_count = Self::calculate_parameter_count(num_qubits, num_layers);
122 let parameters = (0..parameter_count)
123 .map(|_| fastrand::f64() * 2.0 * std::f64::consts::PI)
124 .collect();
125
126 let measurement_operators = (0..num_qubits).map(|_| PauliOperator::Z).collect();
127
128 Self {
129 device,
130 num_qubits,
131 num_layers,
132 parameters,
133 input_encoding,
134 output_decoding,
135 entangling_strategy,
136 measurement_operators,
137 }
138 }
139
140 const fn calculate_parameter_count(num_qubits: usize, num_layers: usize) -> usize {
141 3 * num_qubits * num_layers
143 }
144
145 pub async fn build_circuit(&self, input: &[f64]) -> DeviceResult<ParameterizedQuantumCircuit> {
147 let mut circuit = ParameterizedQuantumCircuit::new(self.num_qubits);
148
149 self.encode_input(&mut circuit, input).await?;
151
152 let mut param_idx = 0;
154 for layer in 0..self.num_layers {
155 for qubit in 0..self.num_qubits {
157 circuit.add_rx_gate(qubit, self.parameters[param_idx])?;
158 param_idx += 1;
159 circuit.add_ry_gate(qubit, self.parameters[param_idx])?;
160 param_idx += 1;
161 circuit.add_rz_gate(qubit, self.parameters[param_idx])?;
162 param_idx += 1;
163 }
164
165 self.add_entangling_gates(&mut circuit, layer).await?;
167 }
168
169 Ok(circuit)
170 }
171
172 async fn encode_input(
173 &self,
174 circuit: &mut ParameterizedQuantumCircuit,
175 input: &[f64],
176 ) -> DeviceResult<()> {
177 match self.input_encoding {
178 InputEncoding::Angle => {
179 let padded_input = self.pad_input(input, self.num_qubits);
181 for (qubit, &value) in padded_input.iter().enumerate() {
182 circuit.add_ry_gate(qubit, value)?;
183 }
184 }
185 InputEncoding::Amplitude => {
186 for qubit in 0..self.num_qubits {
189 circuit.add_h_gate(qubit)?;
190 }
191 }
193 InputEncoding::Basis => {
194 let binary_input = self.convert_to_binary(input);
196 for (qubit, &bit) in binary_input.iter().enumerate() {
197 if bit == 1 {
198 circuit.add_x_gate(qubit)?;
199 }
200 }
201 }
202 _ => {
203 return Err(DeviceError::InvalidInput(format!(
204 "Input encoding {:?} not implemented for PQC",
205 self.input_encoding
206 )));
207 }
208 }
209 Ok(())
210 }
211
212 async fn add_entangling_gates(
213 &self,
214 circuit: &mut ParameterizedQuantumCircuit,
215 _layer: usize,
216 ) -> DeviceResult<()> {
217 match &self.entangling_strategy {
218 EntanglingStrategy::Linear => {
219 for qubit in 0..self.num_qubits - 1 {
220 circuit.add_cnot_gate(qubit, qubit + 1)?;
221 }
222 }
223 EntanglingStrategy::Circular => {
224 for qubit in 0..self.num_qubits - 1 {
225 circuit.add_cnot_gate(qubit, qubit + 1)?;
226 }
227 if self.num_qubits > 2 {
228 circuit.add_cnot_gate(self.num_qubits - 1, 0)?;
229 }
230 }
231 EntanglingStrategy::AllToAll => {
232 for i in 0..self.num_qubits {
233 for j in i + 1..self.num_qubits {
234 circuit.add_cnot_gate(i, j)?;
235 }
236 }
237 }
238 EntanglingStrategy::Custom(connections) => {
239 for &(control, target) in connections {
240 if control < self.num_qubits && target < self.num_qubits {
241 circuit.add_cnot_gate(control, target)?;
242 }
243 }
244 }
245 _ => {
246 for qubit in 0..self.num_qubits - 1 {
248 circuit.add_cnot_gate(qubit, qubit + 1)?;
249 }
250 }
251 }
252 Ok(())
253 }
254
255 fn pad_input(&self, input: &[f64], target_size: usize) -> Vec<f64> {
256 let mut padded = input.to_vec();
257 while padded.len() < target_size {
258 padded.push(0.0);
259 }
260 padded.truncate(target_size);
261 padded
262 }
263
264 fn convert_to_binary(&self, input: &[f64]) -> Vec<u8> {
265 let mut binary = Vec::new();
266 for &value in input {
267 let int_value = (value * 255.0) as u8;
268 for i in 0..8 {
269 binary.push((int_value >> i) & 1);
270 if binary.len() >= self.num_qubits {
271 break;
272 }
273 }
274 if binary.len() >= self.num_qubits {
275 break;
276 }
277 }
278 while binary.len() < self.num_qubits {
279 binary.push(0);
280 }
281 binary.truncate(self.num_qubits);
282 binary
283 }
284
285 async fn decode_output(&self, circuit_result: &CircuitResult) -> DeviceResult<Vec<f64>> {
286 match self.output_decoding {
287 OutputDecoding::PauliExpectation => {
288 let mut expectations = Vec::new();
290 for (qubit, pauli_op) in self.measurement_operators.iter().enumerate() {
291 let expectation =
292 self.compute_pauli_expectation(circuit_result, qubit, pauli_op)?;
293 expectations.push(expectation);
294 }
295 Ok(expectations)
296 }
297 OutputDecoding::Probabilities => {
298 let total_shots = circuit_result.shots as f64;
300 let mut probs = Vec::new();
301
302 for i in 0..self.num_qubits {
303 let mut prob_one = 0.0;
304 for (bitstring, count) in &circuit_result.counts {
305 if let Some(bit_char) = bitstring.chars().nth(i) {
306 if bit_char == '1' {
307 prob_one += *count as f64 / total_shots;
308 }
309 }
310 }
311 probs.push(prob_one);
312 }
313 Ok(probs)
314 }
315 _ => Err(DeviceError::InvalidInput(format!(
316 "Output decoding {:?} not implemented",
317 self.output_decoding
318 ))),
319 }
320 }
321
322 fn compute_pauli_expectation(
323 &self,
324 circuit_result: &CircuitResult,
325 qubit: usize,
326 pauli_op: &PauliOperator,
327 ) -> DeviceResult<f64> {
328 let mut expectation = 0.0;
329 let total_shots = circuit_result.shots as f64;
330
331 for (bitstring, count) in &circuit_result.counts {
332 let probability = *count as f64 / total_shots;
333
334 let eigenvalue = if let Some(bit_char) = bitstring.chars().nth(qubit) {
335 match pauli_op {
336 PauliOperator::Z => {
337 if bit_char == '0' {
338 1.0
339 } else {
340 -1.0
341 }
342 }
343 PauliOperator::X | PauliOperator::Y => {
344 return Err(DeviceError::InvalidInput(
346 "X and Y Pauli measurements require basis rotation".to_string(),
347 ));
348 }
349 PauliOperator::I => 1.0,
350 }
351 } else {
352 0.0
353 };
354
355 expectation += probability * eigenvalue;
356 }
357
358 Ok(expectation)
359 }
360
361 async fn execute_circuit_helper(
363 device: &(dyn QuantumDevice + Send + Sync),
364 circuit: &ParameterizedQuantumCircuit,
365 shots: usize,
366 ) -> DeviceResult<CircuitResult> {
367 let mut counts = std::collections::HashMap::new();
370 counts.insert("0".repeat(circuit.num_qubits()), shots / 2);
371 counts.insert("1".repeat(circuit.num_qubits()), shots / 2);
372
373 Ok(CircuitResult {
374 counts,
375 shots,
376 metadata: std::collections::HashMap::new(),
377 })
378 }
379}
380
381impl QuantumNeuralNetwork for PQCNetwork {
382 fn forward(&self, input: &[f64]) -> DeviceResult<Vec<f64>> {
383 let rt = tokio::runtime::Runtime::new().map_err(|e| {
385 DeviceError::ExecutionFailed(format!("Failed to create tokio runtime: {e}"))
386 })?;
387 rt.block_on(async {
388 let circuit = self.build_circuit(input).await?;
389 let device = self.device.read().await;
390 let result = Self::execute_circuit_helper(&*device, &circuit, 1024).await?;
391 self.decode_output(&result).await
392 })
393 }
394
395 fn parameters(&self) -> &[f64] {
396 &self.parameters
397 }
398
399 fn set_parameters(&mut self, params: Vec<f64>) -> DeviceResult<()> {
400 if params.len() != self.parameters.len() {
401 return Err(DeviceError::InvalidInput(format!(
402 "Expected {} parameters, got {}",
403 self.parameters.len(),
404 params.len()
405 )));
406 }
407 self.parameters = params;
408 Ok(())
409 }
410
411 fn parameter_count(&self) -> usize {
412 self.parameters.len()
413 }
414
415 fn architecture(&self) -> QNNArchitecture {
416 QNNArchitecture {
417 network_type: QNNType::PQC,
418 num_qubits: self.num_qubits,
419 num_layers: self.num_layers,
420 num_parameters: self.parameters.len(),
421 input_encoding: self.input_encoding.clone(),
422 output_decoding: self.output_decoding.clone(),
423 entangling_strategy: self.entangling_strategy.clone(),
424 }
425 }
426}
427
428pub struct QCNN {
430 device: Arc<RwLock<dyn QuantumDevice + Send + Sync>>,
431 num_qubits: usize,
432 conv_layers: Vec<QConvLayer>,
433 pooling_layers: Vec<QPoolingLayer>,
434 parameters: Vec<f64>,
435 input_encoding: InputEncoding,
436}
437
438#[derive(Debug, Clone, Serialize, Deserialize)]
440pub struct QConvLayer {
441 pub kernel_size: usize,
442 pub stride: usize,
443 pub num_filters: usize,
444 pub parameter_indices: Vec<usize>,
445}
446
447#[derive(Debug, Clone, Serialize, Deserialize)]
449pub struct QPoolingLayer {
450 pub pool_size: usize,
451 pub pool_type: QPoolingType,
452}
453
454#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
455pub enum QPoolingType {
456 Max,
457 Average,
458 Measurement,
459}
460
461impl QCNN {
462 pub fn new(
463 device: Arc<RwLock<dyn QuantumDevice + Send + Sync>>,
464 num_qubits: usize,
465 conv_layers: Vec<QConvLayer>,
466 pooling_layers: Vec<QPoolingLayer>,
467 input_encoding: InputEncoding,
468 ) -> Self {
469 let total_params = conv_layers.iter()
470 .map(|layer| layer.num_filters * layer.kernel_size * 3) .sum();
472
473 let parameters = (0..total_params)
474 .map(|_| fastrand::f64() * 2.0 * std::f64::consts::PI)
475 .collect();
476
477 Self {
478 device,
479 num_qubits,
480 conv_layers,
481 pooling_layers,
482 parameters,
483 input_encoding,
484 }
485 }
486
487 pub async fn build_circuit(&self, input: &[f64]) -> DeviceResult<ParameterizedQuantumCircuit> {
488 let mut circuit = ParameterizedQuantumCircuit::new(self.num_qubits);
489
490 self.encode_input(&mut circuit, input).await?;
492
493 let mut current_qubits = self.num_qubits;
494
495 for (conv_layer, pool_layer) in self.conv_layers.iter().zip(self.pooling_layers.iter()) {
497 self.apply_conv_layer(&mut circuit, conv_layer, current_qubits)
499 .await?;
500
501 current_qubits = self
503 .apply_pooling_layer(&mut circuit, pool_layer, current_qubits)
504 .await?;
505 }
506
507 Ok(circuit)
508 }
509
510 async fn encode_input(
511 &self,
512 circuit: &mut ParameterizedQuantumCircuit,
513 input: &[f64],
514 ) -> DeviceResult<()> {
515 match self.input_encoding {
516 InputEncoding::Angle => {
517 let padded_input = self.pad_input(input, self.num_qubits);
518 for (qubit, &value) in padded_input.iter().enumerate() {
519 circuit.add_ry_gate(qubit, value)?;
520 }
521 }
522 InputEncoding::Amplitude => {
523 for qubit in 0..self.num_qubits {
525 circuit.add_h_gate(qubit)?;
526 }
527 }
528 _ => {
529 return Err(DeviceError::InvalidInput(format!(
530 "Input encoding {:?} not implemented for QCNN",
531 self.input_encoding
532 )));
533 }
534 }
535 Ok(())
536 }
537
538 async fn apply_conv_layer(
539 &self,
540 circuit: &mut ParameterizedQuantumCircuit,
541 layer: &QConvLayer,
542 num_active_qubits: usize,
543 ) -> DeviceResult<()> {
544 let num_windows = (num_active_qubits - layer.kernel_size) / layer.stride + 1;
545
546 for window in 0..num_windows {
547 let start_qubit = window * layer.stride;
548
549 for filter in 0..layer.num_filters {
550 let param_offset = filter * layer.kernel_size * 3;
551
552 for i in 0..layer.kernel_size {
554 let qubit = start_qubit + i;
555 let param_base = param_offset + i * 3;
556
557 if param_base + 2 < self.parameters.len() {
558 circuit.add_rx_gate(qubit, self.parameters[param_base])?;
559 circuit.add_ry_gate(qubit, self.parameters[param_base + 1])?;
560 circuit.add_rz_gate(qubit, self.parameters[param_base + 2])?;
561 }
562 }
563
564 for i in 0..layer.kernel_size - 1 {
566 let control = start_qubit + i;
567 let target = start_qubit + i + 1;
568 circuit.add_cnot_gate(control, target)?;
569 }
570 }
571 }
572
573 Ok(())
574 }
575
576 async fn apply_pooling_layer(
577 &self,
578 circuit: &mut ParameterizedQuantumCircuit,
579 layer: &QPoolingLayer,
580 num_active_qubits: usize,
581 ) -> DeviceResult<usize> {
582 let num_pools = num_active_qubits / layer.pool_size;
583
584 match layer.pool_type {
585 QPoolingType::Measurement => {
586 Ok(num_pools)
589 }
590 QPoolingType::Max | QPoolingType::Average => {
591 for pool in 0..num_pools {
593 let start_qubit = pool * layer.pool_size;
594
595 for i in 0..layer.pool_size - 1 {
597 let qubit1 = start_qubit + i;
598 let qubit2 = start_qubit + i + 1;
599 circuit.add_cnot_gate(qubit1, qubit2)?;
600 }
601 }
602 Ok(num_pools)
603 }
604 }
605 }
606
607 fn pad_input(&self, input: &[f64], target_size: usize) -> Vec<f64> {
608 let mut padded = input.to_vec();
609 while padded.len() < target_size {
610 padded.push(0.0);
611 }
612 padded.truncate(target_size);
613 padded
614 }
615}
616
617impl QuantumNeuralNetwork for QCNN {
618 fn forward(&self, input: &[f64]) -> DeviceResult<Vec<f64>> {
619 let rt = tokio::runtime::Runtime::new().map_err(|e| {
620 DeviceError::ExecutionFailed(format!("Failed to create tokio runtime: {e}"))
621 })?;
622 rt.block_on(async {
623 let circuit = self.build_circuit(input).await?;
624 let device = self.device.read().await;
625 let result = Self::execute_circuit_helper(&*device, &circuit, 1024).await?;
626
627 let mut output = Vec::new();
629 let total_shots = result.shots as f64;
630
631 for i in 0..self.num_qubits.min(8) {
632 let mut prob_one = 0.0;
634 for (bitstring, count) in &result.counts {
635 if let Some(bit_char) = bitstring.chars().nth(i) {
636 if bit_char == '1' {
637 prob_one += *count as f64 / total_shots;
638 }
639 }
640 }
641 output.push(prob_one);
642 }
643
644 Ok(output)
645 })
646 }
647
648 fn parameters(&self) -> &[f64] {
649 &self.parameters
650 }
651
652 fn set_parameters(&mut self, params: Vec<f64>) -> DeviceResult<()> {
653 if params.len() != self.parameters.len() {
654 return Err(DeviceError::InvalidInput(format!(
655 "Expected {} parameters, got {}",
656 self.parameters.len(),
657 params.len()
658 )));
659 }
660 self.parameters = params;
661 Ok(())
662 }
663
664 fn parameter_count(&self) -> usize {
665 self.parameters.len()
666 }
667
668 fn architecture(&self) -> QNNArchitecture {
669 QNNArchitecture {
670 network_type: QNNType::QCNN,
671 num_qubits: self.num_qubits,
672 num_layers: self.conv_layers.len(),
673 num_parameters: self.parameters.len(),
674 input_encoding: self.input_encoding.clone(),
675 output_decoding: OutputDecoding::Probabilities,
676 entangling_strategy: EntanglingStrategy::Linear,
677 }
678 }
679}
680
681impl QCNN {
682 async fn execute_circuit_helper(
684 device: &(dyn QuantumDevice + Send + Sync),
685 circuit: &ParameterizedQuantumCircuit,
686 shots: usize,
687 ) -> DeviceResult<CircuitResult> {
688 let mut counts = std::collections::HashMap::new();
691 counts.insert("0".repeat(circuit.num_qubits()), shots / 2);
692 counts.insert("1".repeat(circuit.num_qubits()), shots / 2);
693
694 Ok(CircuitResult {
695 counts,
696 shots,
697 metadata: std::collections::HashMap::new(),
698 })
699 }
700}
701
702pub struct VQC {
704 pqc_network: PQCNetwork,
705 class_mapping: HashMap<usize, String>,
706}
707
708impl VQC {
709 pub fn new(
710 device: Arc<RwLock<dyn QuantumDevice + Send + Sync>>,
711 num_qubits: usize,
712 num_layers: usize,
713 num_classes: usize,
714 ) -> Self {
715 let pqc_network = PQCNetwork::new(
716 device,
717 num_qubits,
718 num_layers,
719 InputEncoding::Angle,
720 OutputDecoding::PauliExpectation,
721 EntanglingStrategy::Linear,
722 );
723
724 let class_mapping = (0..num_classes)
725 .map(|i| (i, format!("class_{i}")))
726 .collect();
727
728 Self {
729 pqc_network,
730 class_mapping,
731 }
732 }
733
734 pub fn classify(&self, input: &[f64]) -> DeviceResult<ClassificationResult> {
735 let raw_output = self.pqc_network.forward(input)?;
736
737 let class_probs = self.softmax(&raw_output);
739
740 let (predicted_class, confidence) = class_probs
742 .iter()
743 .enumerate()
744 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
745 .map_or((0, 0.0), |(idx, &prob)| (idx, prob));
746
747 let class_name = self
748 .class_mapping
749 .get(&predicted_class)
750 .cloned()
751 .unwrap_or_else(|| "unknown".to_string());
752
753 Ok(ClassificationResult {
754 predicted_class,
755 class_name,
756 confidence,
757 class_probabilities: class_probs,
758 })
759 }
760
761 fn softmax(&self, values: &[f64]) -> Vec<f64> {
762 let max_val = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
763 let exp_values: Vec<f64> = values.iter().map(|&x| (x - max_val).exp()).collect();
764 let sum_exp: f64 = exp_values.iter().sum();
765 exp_values.iter().map(|&x| x / sum_exp).collect()
766 }
767}
768
769impl QuantumNeuralNetwork for VQC {
770 fn forward(&self, input: &[f64]) -> DeviceResult<Vec<f64>> {
771 self.pqc_network.forward(input)
772 }
773
774 fn parameters(&self) -> &[f64] {
775 self.pqc_network.parameters()
776 }
777
778 fn set_parameters(&mut self, params: Vec<f64>) -> DeviceResult<()> {
779 self.pqc_network.set_parameters(params)
780 }
781
782 fn parameter_count(&self) -> usize {
783 self.pqc_network.parameter_count()
784 }
785
786 fn architecture(&self) -> QNNArchitecture {
787 let mut arch = self.pqc_network.architecture();
788 arch.network_type = QNNType::VQC;
789 arch
790 }
791}
792
793#[derive(Debug, Clone, Serialize, Deserialize)]
795pub struct ClassificationResult {
796 pub predicted_class: usize,
797 pub class_name: String,
798 pub confidence: f64,
799 pub class_probabilities: Vec<f64>,
800}
801
802pub fn create_pqc_classifier(
804 device: Arc<RwLock<dyn QuantumDevice + Send + Sync>>,
805 num_features: usize,
806 num_classes: usize,
807 num_layers: usize,
808) -> DeviceResult<VQC> {
809 let num_qubits =
810 (num_features as f64).log2().ceil() as usize + (num_classes as f64).log2().ceil() as usize;
811 Ok(VQC::new(device, num_qubits, num_layers, num_classes))
812}
813
814pub fn create_qcnn_classifier(
816 device: Arc<RwLock<dyn QuantumDevice + Send + Sync>>,
817 image_size: usize,
818) -> DeviceResult<QCNN> {
819 let num_qubits = (image_size as f64).log2().ceil() as usize;
820
821 let conv_layers = vec![
822 QConvLayer {
823 kernel_size: 2,
824 stride: 1,
825 num_filters: 2,
826 parameter_indices: (0..12).collect(), },
828 QConvLayer {
829 kernel_size: 2,
830 stride: 1,
831 num_filters: 1,
832 parameter_indices: (12..18).collect(), },
834 ];
835
836 let pooling_layers = vec![
837 QPoolingLayer {
838 pool_size: 2,
839 pool_type: QPoolingType::Measurement,
840 },
841 QPoolingLayer {
842 pool_size: 2,
843 pool_type: QPoolingType::Measurement,
844 },
845 ];
846
847 Ok(QCNN::new(
848 device,
849 num_qubits,
850 conv_layers,
851 pooling_layers,
852 InputEncoding::Angle,
853 ))
854}
855
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::create_mock_quantum_device;

    /// Four-qubit, two-layer PQC with angle encoding, Z expectations and a linear chain.
    fn four_qubit_two_layer_pqc() -> PQCNetwork {
        PQCNetwork::new(
            create_mock_quantum_device(),
            4,
            2,
            InputEncoding::Angle,
            OutputDecoding::PauliExpectation,
            EntanglingStrategy::Linear,
        )
    }

    #[test]
    fn test_pqc_network_creation() {
        let network = four_qubit_two_layer_pqc();

        assert_eq!(network.num_qubits, 4);
        assert_eq!(network.num_layers, 2);
        // 3 rotations x 4 qubits x 2 layers.
        assert_eq!(network.parameter_count(), 24);
    }

    #[test]
    fn test_vqc_creation() {
        let classifier = VQC::new(create_mock_quantum_device(), 4, 2, 3);

        assert_eq!(classifier.class_mapping.len(), 3);
        assert_eq!(classifier.parameter_count(), 24);
    }

    #[test]
    fn test_qcnn_creation() {
        let conv = QConvLayer {
            kernel_size: 2,
            stride: 1,
            num_filters: 1,
            parameter_indices: (0..6).collect(),
        };
        let pool = QPoolingLayer {
            pool_size: 2,
            pool_type: QPoolingType::Max,
        };

        let qcnn = QCNN::new(
            create_mock_quantum_device(),
            4,
            vec![conv],
            vec![pool],
            InputEncoding::Angle,
        );

        assert_eq!(qcnn.num_qubits, 4);
        // 1 filter x kernel 2 x 3 rotations.
        assert_eq!(qcnn.parameter_count(), 6);
    }

    #[test]
    fn test_softmax() {
        let classifier = VQC::new(create_mock_quantum_device(), 4, 2, 3);

        let output = classifier.softmax(&[1.0, 2.0, 3.0]);

        assert_eq!(output.len(), 3);
        let total: f64 = output.iter().sum();
        assert!((total - 1.0).abs() < 1e-10);
        // Larger logits must map to strictly larger probabilities.
        assert!(output.windows(2).all(|pair| pair[1] > pair[0]));
    }

    #[test]
    fn test_parameter_operations() {
        let mut network = four_qubit_two_layer_pqc();

        let original_params = network.parameters().to_vec();
        let zeros = vec![0.0; network.parameter_count()];

        network
            .set_parameters(zeros.clone())
            .expect("Setting parameters should succeed");
        assert_eq!(network.parameters(), zeros.as_slice());
        assert_ne!(network.parameters(), original_params.as_slice());

        // A vector of the wrong length must be rejected.
        assert!(network.set_parameters(vec![0.0; 5]).is_err());
    }
}