quantrs2_device/quantum_ml/
quantum_neural_networks.rs

1//! Quantum Neural Networks
2//!
3//! This module implements various quantum neural network architectures including
4//! parameterized quantum circuits, quantum convolutional networks, and hybrid models.
5
6use super::*;
7use crate::{CircuitResult, DeviceError, DeviceResult, QuantumDevice};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::Arc;
11use tokio::sync::RwLock;
12
/// Quantum Neural Network trait
///
/// Common interface for the network types defined in this module. Implementors
/// must be `Send + Sync`.
pub trait QuantumNeuralNetwork: Send + Sync {
    /// Forward pass through the network
    ///
    /// Maps the classical `input` values to the network's decoded outputs and
    /// returns an error if the pass cannot be completed.
    fn forward(&self, input: &[f64]) -> DeviceResult<Vec<f64>>;

    /// Get trainable parameters
    ///
    /// Returns a borrowed view of the current parameter values.
    fn parameters(&self) -> &[f64];

    /// Set trainable parameters
    ///
    /// Replaces the current parameter values with `params`. The implementations
    /// in this module return an error when `params.len()` differs from the
    /// current parameter count.
    fn set_parameters(&mut self, params: Vec<f64>) -> DeviceResult<()>;

    /// Get parameter count
    fn parameter_count(&self) -> usize;

    /// Get network architecture description
    fn architecture(&self) -> QNNArchitecture;
}
30
/// QNN Architecture description
///
/// Serializable summary returned by [`QuantumNeuralNetwork::architecture`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QNNArchitecture {
    /// Which kind of network produced this description.
    pub network_type: QNNType,
    /// Number of qubits in the network's register.
    pub num_qubits: usize,
    /// Number of layers (for `QCNN` this is the number of convolutional layers).
    pub num_layers: usize,
    /// Length of the trainable parameter vector.
    pub num_parameters: usize,
    /// How classical inputs are loaded into the circuit.
    pub input_encoding: InputEncoding,
    /// How measurement results are turned into output values.
    pub output_decoding: OutputDecoding,
    /// Pattern of two-qubit gates used between qubits.
    pub entangling_strategy: EntanglingStrategy,
}
42
/// Types of quantum neural networks
///
/// Tag stored in [`QNNArchitecture::network_type`]. In the code visible in this
/// module only `PQC`, `QCNN` and `VQC` are reported by an `architecture()`
/// implementation.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum QNNType {
    /// Parameterized Quantum Circuit
    PQC,
    /// Quantum Convolutional Neural Network
    QCNN,
    /// Variational Quantum Classifier
    VQC,
    /// Quantum Generative Adversarial Network
    QGAN,
    /// Hybrid Classical-Quantum Network
    HybridCQN,
    /// Quantum Recurrent Neural Network
    QRNN,
}
59
/// Input encoding strategies
///
/// Selects how classical input values are loaded into the circuit before the
/// trainable layers are applied.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum InputEncoding {
    /// Amplitude encoding
    ///
    /// `PQCNetwork` and `QCNN` currently approximate this with a Hadamard gate
    /// on every qubit; the input values themselves are not used.
    Amplitude,
    /// Angle encoding
    ///
    /// Each input value becomes the angle of an RY rotation on one qubit; the
    /// input is zero-padded or truncated to the number of qubits.
    Angle,
    /// Basis encoding
    ///
    /// Implemented by `PQCNetwork` only; `QCNN` rejects it with an error.
    Basis,
    /// Coherent state encoding (for CV systems)
    ///
    /// Rejected with an error by both `PQCNetwork` and `QCNN`.
    CoherentState,
    /// Displacement encoding
    ///
    /// Rejected with an error by both `PQCNetwork` and `QCNN`.
    Displacement,
}
74
/// Output decoding strategies
///
/// Selects how measurement results are converted into the network's output
/// vector.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum OutputDecoding {
    /// Expectation value of Pauli operators
    ///
    /// One value per measured qubit.
    PauliExpectation,
    /// Measurement probabilities
    ///
    /// One value per qubit: the fraction of shots in which that qubit read `1`.
    Probabilities,
    /// Fidelity measurement
    ///
    /// Not implemented by `PQCNetwork`, which returns an error for it.
    Fidelity,
    /// Coherent state measurement
    ///
    /// Not implemented by `PQCNetwork`, which returns an error for it.
    CoherentMeasurement,
}
87
/// Entangling strategies
///
/// Pattern of CNOT gates that `PQCNetwork` appends after the rotation gates of
/// every layer.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum EntanglingStrategy {
    /// CNOT from each qubit to its next neighbour (`q -> q + 1`).
    Linear,
    /// The linear chain plus, when there are more than two qubits, a CNOT from
    /// the last qubit back to qubit 0.
    Circular,
    /// CNOT between every ordered pair `(i, j)` with `i < j`.
    AllToAll,
    /// Currently handled by `PQCNetwork` as the linear chain.
    Random,
    /// Currently handled by `PQCNetwork` as the linear chain.
    Hardware,
    /// Explicit list of `(control, target)` qubit pairs; pairs that reference a
    /// qubit outside the register are skipped.
    Custom(Vec<(usize, usize)>),
}
98
/// Parameterized Quantum Circuit Network
///
/// A layered circuit: input encoding, then `num_layers` repetitions of
/// per-qubit RX/RY/RZ rotations followed by entangling CNOT gates.
pub struct PQCNetwork {
    /// Shared handle to the quantum device, guarded by an async read/write lock.
    device: Arc<RwLock<dyn QuantumDevice + Send + Sync>>,
    /// Number of qubits in the circuit.
    num_qubits: usize,
    /// Number of rotation + entangling layers.
    num_layers: usize,
    /// Trainable rotation angles: three per qubit per layer.
    parameters: Vec<f64>,
    /// How classical inputs are loaded into the circuit.
    input_encoding: InputEncoding,
    /// How measurement results are turned into outputs.
    output_decoding: OutputDecoding,
    /// CNOT pattern applied at the end of every layer.
    entangling_strategy: EntanglingStrategy,
    /// Pauli operator measured on each qubit; `new` sets every entry to `Z`.
    measurement_operators: Vec<PauliOperator>,
}
110
111impl PQCNetwork {
112    /// Create a new PQC network
113    pub fn new(
114        device: Arc<RwLock<dyn QuantumDevice + Send + Sync>>,
115        num_qubits: usize,
116        num_layers: usize,
117        input_encoding: InputEncoding,
118        output_decoding: OutputDecoding,
119        entangling_strategy: EntanglingStrategy,
120    ) -> Self {
121        let parameter_count = Self::calculate_parameter_count(num_qubits, num_layers);
122        let parameters = (0..parameter_count)
123            .map(|_| fastrand::f64() * 2.0 * std::f64::consts::PI)
124            .collect();
125
126        let measurement_operators = (0..num_qubits).map(|_| PauliOperator::Z).collect();
127
128        Self {
129            device,
130            num_qubits,
131            num_layers,
132            parameters,
133            input_encoding,
134            output_decoding,
135            entangling_strategy,
136            measurement_operators,
137        }
138    }
139
140    const fn calculate_parameter_count(num_qubits: usize, num_layers: usize) -> usize {
141        // Each layer has 3 rotation gates per qubit
142        3 * num_qubits * num_layers
143    }
144
145    /// Build the quantum circuit for given input
146    pub async fn build_circuit(&self, input: &[f64]) -> DeviceResult<ParameterizedQuantumCircuit> {
147        let mut circuit = ParameterizedQuantumCircuit::new(self.num_qubits);
148
149        // Input encoding
150        self.encode_input(&mut circuit, input).await?;
151
152        // Parameterized layers
153        let mut param_idx = 0;
154        for layer in 0..self.num_layers {
155            // Rotation gates
156            for qubit in 0..self.num_qubits {
157                circuit.add_rx_gate(qubit, self.parameters[param_idx])?;
158                param_idx += 1;
159                circuit.add_ry_gate(qubit, self.parameters[param_idx])?;
160                param_idx += 1;
161                circuit.add_rz_gate(qubit, self.parameters[param_idx])?;
162                param_idx += 1;
163            }
164
165            // Entangling gates
166            self.add_entangling_gates(&mut circuit, layer).await?;
167        }
168
169        Ok(circuit)
170    }
171
172    async fn encode_input(
173        &self,
174        circuit: &mut ParameterizedQuantumCircuit,
175        input: &[f64],
176    ) -> DeviceResult<()> {
177        match self.input_encoding {
178            InputEncoding::Angle => {
179                // Encode input as rotation angles
180                let padded_input = self.pad_input(input, self.num_qubits);
181                for (qubit, &value) in padded_input.iter().enumerate() {
182                    circuit.add_ry_gate(qubit, value)?;
183                }
184            }
185            InputEncoding::Amplitude => {
186                // Amplitude encoding requires state preparation
187                // This is a simplified implementation
188                for qubit in 0..self.num_qubits {
189                    circuit.add_h_gate(qubit)?;
190                }
191                // Would need more sophisticated amplitude encoding
192            }
193            InputEncoding::Basis => {
194                // Basis encoding: encode classical bits as computational basis states
195                let binary_input = self.convert_to_binary(input);
196                for (qubit, &bit) in binary_input.iter().enumerate() {
197                    if bit == 1 {
198                        circuit.add_x_gate(qubit)?;
199                    }
200                }
201            }
202            _ => {
203                return Err(DeviceError::InvalidInput(format!(
204                    "Input encoding {:?} not implemented for PQC",
205                    self.input_encoding
206                )));
207            }
208        }
209        Ok(())
210    }
211
212    async fn add_entangling_gates(
213        &self,
214        circuit: &mut ParameterizedQuantumCircuit,
215        _layer: usize,
216    ) -> DeviceResult<()> {
217        match &self.entangling_strategy {
218            EntanglingStrategy::Linear => {
219                for qubit in 0..self.num_qubits - 1 {
220                    circuit.add_cnot_gate(qubit, qubit + 1)?;
221                }
222            }
223            EntanglingStrategy::Circular => {
224                for qubit in 0..self.num_qubits - 1 {
225                    circuit.add_cnot_gate(qubit, qubit + 1)?;
226                }
227                if self.num_qubits > 2 {
228                    circuit.add_cnot_gate(self.num_qubits - 1, 0)?;
229                }
230            }
231            EntanglingStrategy::AllToAll => {
232                for i in 0..self.num_qubits {
233                    for j in i + 1..self.num_qubits {
234                        circuit.add_cnot_gate(i, j)?;
235                    }
236                }
237            }
238            EntanglingStrategy::Custom(connections) => {
239                for &(control, target) in connections {
240                    if control < self.num_qubits && target < self.num_qubits {
241                        circuit.add_cnot_gate(control, target)?;
242                    }
243                }
244            }
245            _ => {
246                // Default to linear
247                for qubit in 0..self.num_qubits - 1 {
248                    circuit.add_cnot_gate(qubit, qubit + 1)?;
249                }
250            }
251        }
252        Ok(())
253    }
254
255    fn pad_input(&self, input: &[f64], target_size: usize) -> Vec<f64> {
256        let mut padded = input.to_vec();
257        while padded.len() < target_size {
258            padded.push(0.0);
259        }
260        padded.truncate(target_size);
261        padded
262    }
263
264    fn convert_to_binary(&self, input: &[f64]) -> Vec<u8> {
265        let mut binary = Vec::new();
266        for &value in input {
267            let int_value = (value * 255.0) as u8;
268            for i in 0..8 {
269                binary.push((int_value >> i) & 1);
270                if binary.len() >= self.num_qubits {
271                    break;
272                }
273            }
274            if binary.len() >= self.num_qubits {
275                break;
276            }
277        }
278        while binary.len() < self.num_qubits {
279            binary.push(0);
280        }
281        binary.truncate(self.num_qubits);
282        binary
283    }
284
285    async fn decode_output(&self, circuit_result: &CircuitResult) -> DeviceResult<Vec<f64>> {
286        match self.output_decoding {
287            OutputDecoding::PauliExpectation => {
288                // Compute expectation values of Pauli operators
289                let mut expectations = Vec::new();
290                for (qubit, pauli_op) in self.measurement_operators.iter().enumerate() {
291                    let expectation =
292                        self.compute_pauli_expectation(circuit_result, qubit, pauli_op)?;
293                    expectations.push(expectation);
294                }
295                Ok(expectations)
296            }
297            OutputDecoding::Probabilities => {
298                // Convert measurement counts to probabilities
299                let total_shots = circuit_result.shots as f64;
300                let mut probs = Vec::new();
301
302                for i in 0..self.num_qubits {
303                    let mut prob_one = 0.0;
304                    for (bitstring, count) in &circuit_result.counts {
305                        if let Some(bit_char) = bitstring.chars().nth(i) {
306                            if bit_char == '1' {
307                                prob_one += *count as f64 / total_shots;
308                            }
309                        }
310                    }
311                    probs.push(prob_one);
312                }
313                Ok(probs)
314            }
315            _ => Err(DeviceError::InvalidInput(format!(
316                "Output decoding {:?} not implemented",
317                self.output_decoding
318            ))),
319        }
320    }
321
322    fn compute_pauli_expectation(
323        &self,
324        circuit_result: &CircuitResult,
325        qubit: usize,
326        pauli_op: &PauliOperator,
327    ) -> DeviceResult<f64> {
328        let mut expectation = 0.0;
329        let total_shots = circuit_result.shots as f64;
330
331        for (bitstring, count) in &circuit_result.counts {
332            let probability = *count as f64 / total_shots;
333
334            let eigenvalue = if let Some(bit_char) = bitstring.chars().nth(qubit) {
335                match pauli_op {
336                    PauliOperator::Z => {
337                        if bit_char == '0' {
338                            1.0
339                        } else {
340                            -1.0
341                        }
342                    }
343                    PauliOperator::X | PauliOperator::Y => {
344                        // Would need different measurement basis
345                        return Err(DeviceError::InvalidInput(
346                            "X and Y Pauli measurements require basis rotation".to_string(),
347                        ));
348                    }
349                    PauliOperator::I => 1.0,
350                }
351            } else {
352                0.0
353            };
354
355            expectation += probability * eigenvalue;
356        }
357
358        Ok(expectation)
359    }
360
361    /// Execute a circuit on the quantum device (helper function to work around trait object limitations)
362    async fn execute_circuit_helper(
363        device: &(dyn QuantumDevice + Send + Sync),
364        circuit: &ParameterizedQuantumCircuit,
365        shots: usize,
366    ) -> DeviceResult<CircuitResult> {
367        // For now, return a mock result since we can't execute circuits directly
368        // In a real implementation, this would need proper circuit execution
369        let mut counts = std::collections::HashMap::new();
370        counts.insert("0".repeat(circuit.num_qubits()), shots / 2);
371        counts.insert("1".repeat(circuit.num_qubits()), shots / 2);
372
373        Ok(CircuitResult {
374            counts,
375            shots,
376            metadata: std::collections::HashMap::new(),
377        })
378    }
379}
380
381impl QuantumNeuralNetwork for PQCNetwork {
382    fn forward(&self, input: &[f64]) -> DeviceResult<Vec<f64>> {
383        // This would need to be async in practice
384        let rt = tokio::runtime::Runtime::new().map_err(|e| {
385            DeviceError::ExecutionFailed(format!("Failed to create tokio runtime: {e}"))
386        })?;
387        rt.block_on(async {
388            let circuit = self.build_circuit(input).await?;
389            let device = self.device.read().await;
390            let result = Self::execute_circuit_helper(&*device, &circuit, 1024).await?;
391            self.decode_output(&result).await
392        })
393    }
394
395    fn parameters(&self) -> &[f64] {
396        &self.parameters
397    }
398
399    fn set_parameters(&mut self, params: Vec<f64>) -> DeviceResult<()> {
400        if params.len() != self.parameters.len() {
401            return Err(DeviceError::InvalidInput(format!(
402                "Expected {} parameters, got {}",
403                self.parameters.len(),
404                params.len()
405            )));
406        }
407        self.parameters = params;
408        Ok(())
409    }
410
411    fn parameter_count(&self) -> usize {
412        self.parameters.len()
413    }
414
415    fn architecture(&self) -> QNNArchitecture {
416        QNNArchitecture {
417            network_type: QNNType::PQC,
418            num_qubits: self.num_qubits,
419            num_layers: self.num_layers,
420            num_parameters: self.parameters.len(),
421            input_encoding: self.input_encoding.clone(),
422            output_decoding: self.output_decoding.clone(),
423            entangling_strategy: self.entangling_strategy.clone(),
424        }
425    }
426}
427
/// Quantum Convolutional Neural Network
///
/// Alternates convolutional and pooling layers on a qubit register after an
/// input-encoding step.
pub struct QCNN {
    /// Shared handle to the quantum device, guarded by an async read/write lock.
    device: Arc<RwLock<dyn QuantumDevice + Send + Sync>>,
    /// Number of qubits in the circuit.
    num_qubits: usize,
    /// Convolutional layers, applied in order and paired with `pooling_layers`.
    conv_layers: Vec<QConvLayer>,
    /// Pooling layers; layer `i` runs right after convolutional layer `i`.
    pooling_layers: Vec<QPoolingLayer>,
    /// Trainable rotation angles for all convolutional layers.
    parameters: Vec<f64>,
    /// How classical inputs are loaded into the circuit.
    input_encoding: InputEncoding,
}
437
/// Quantum Convolutional Layer
///
/// `QCNN::new` allocates `num_filters * kernel_size * 3` parameters for each
/// layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QConvLayer {
    /// Number of adjacent qubits covered by one window.
    pub kernel_size: usize,
    /// Distance, in qubits, between the starts of consecutive windows.
    pub stride: usize,
    /// Number of filters applied to every window.
    pub num_filters: usize,
    /// Indices of this layer's entries in the network parameter vector.
    /// NOTE(review): the QCNN code visible in this module never reads this
    /// field — confirm whether other code relies on it.
    pub parameter_indices: Vec<usize>,
}
446
/// Quantum Pooling Layer
///
/// Reduces the number of active qubits to `active / pool_size` (integer
/// division).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QPoolingLayer {
    /// Number of adjacent qubits merged into one pool.
    pub pool_size: usize,
    /// How the qubits of a pool are combined.
    pub pool_type: QPoolingType,
}
453
/// Pooling strategies for [`QPoolingLayer`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum QPoolingType {
    /// Currently implemented as a chain of CNOT gates inside each pool
    /// (same circuit as `Average`).
    Max,
    /// Currently implemented as a chain of CNOT gates inside each pool
    /// (same circuit as `Max`).
    Average,
    /// Reduces the active qubit count without adding any gates (simplified).
    Measurement,
}
460
461impl QCNN {
462    pub fn new(
463        device: Arc<RwLock<dyn QuantumDevice + Send + Sync>>,
464        num_qubits: usize,
465        conv_layers: Vec<QConvLayer>,
466        pooling_layers: Vec<QPoolingLayer>,
467        input_encoding: InputEncoding,
468    ) -> Self {
469        let total_params = conv_layers.iter()
470            .map(|layer| layer.num_filters * layer.kernel_size * 3) // 3 rotation gates per kernel
471            .sum();
472
473        let parameters = (0..total_params)
474            .map(|_| fastrand::f64() * 2.0 * std::f64::consts::PI)
475            .collect();
476
477        Self {
478            device,
479            num_qubits,
480            conv_layers,
481            pooling_layers,
482            parameters,
483            input_encoding,
484        }
485    }
486
487    pub async fn build_circuit(&self, input: &[f64]) -> DeviceResult<ParameterizedQuantumCircuit> {
488        let mut circuit = ParameterizedQuantumCircuit::new(self.num_qubits);
489
490        // Input encoding
491        self.encode_input(&mut circuit, input).await?;
492
493        let mut current_qubits = self.num_qubits;
494
495        // Apply convolutional and pooling layers alternately
496        for (conv_layer, pool_layer) in self.conv_layers.iter().zip(self.pooling_layers.iter()) {
497            // Apply convolutional layer
498            self.apply_conv_layer(&mut circuit, conv_layer, current_qubits)
499                .await?;
500
501            // Apply pooling layer
502            current_qubits = self
503                .apply_pooling_layer(&mut circuit, pool_layer, current_qubits)
504                .await?;
505        }
506
507        Ok(circuit)
508    }
509
510    async fn encode_input(
511        &self,
512        circuit: &mut ParameterizedQuantumCircuit,
513        input: &[f64],
514    ) -> DeviceResult<()> {
515        match self.input_encoding {
516            InputEncoding::Angle => {
517                let padded_input = self.pad_input(input, self.num_qubits);
518                for (qubit, &value) in padded_input.iter().enumerate() {
519                    circuit.add_ry_gate(qubit, value)?;
520                }
521            }
522            InputEncoding::Amplitude => {
523                // Initialize in superposition
524                for qubit in 0..self.num_qubits {
525                    circuit.add_h_gate(qubit)?;
526                }
527            }
528            _ => {
529                return Err(DeviceError::InvalidInput(format!(
530                    "Input encoding {:?} not implemented for QCNN",
531                    self.input_encoding
532                )));
533            }
534        }
535        Ok(())
536    }
537
538    async fn apply_conv_layer(
539        &self,
540        circuit: &mut ParameterizedQuantumCircuit,
541        layer: &QConvLayer,
542        num_active_qubits: usize,
543    ) -> DeviceResult<()> {
544        let num_windows = (num_active_qubits - layer.kernel_size) / layer.stride + 1;
545
546        for window in 0..num_windows {
547            let start_qubit = window * layer.stride;
548
549            for filter in 0..layer.num_filters {
550                let param_offset = filter * layer.kernel_size * 3;
551
552                // Apply parameterized gates to qubits in the window
553                for i in 0..layer.kernel_size {
554                    let qubit = start_qubit + i;
555                    let param_base = param_offset + i * 3;
556
557                    if param_base + 2 < self.parameters.len() {
558                        circuit.add_rx_gate(qubit, self.parameters[param_base])?;
559                        circuit.add_ry_gate(qubit, self.parameters[param_base + 1])?;
560                        circuit.add_rz_gate(qubit, self.parameters[param_base + 2])?;
561                    }
562                }
563
564                // Apply entangling gates within the window
565                for i in 0..layer.kernel_size - 1 {
566                    let control = start_qubit + i;
567                    let target = start_qubit + i + 1;
568                    circuit.add_cnot_gate(control, target)?;
569                }
570            }
571        }
572
573        Ok(())
574    }
575
576    async fn apply_pooling_layer(
577        &self,
578        circuit: &mut ParameterizedQuantumCircuit,
579        layer: &QPoolingLayer,
580        num_active_qubits: usize,
581    ) -> DeviceResult<usize> {
582        let num_pools = num_active_qubits / layer.pool_size;
583
584        match layer.pool_type {
585            QPoolingType::Measurement => {
586                // Measure and discard some qubits (simplified)
587                // In practice, this would involve partial measurements
588                Ok(num_pools)
589            }
590            QPoolingType::Max | QPoolingType::Average => {
591                // Apply pooling unitaries (simplified)
592                for pool in 0..num_pools {
593                    let start_qubit = pool * layer.pool_size;
594
595                    // Apply pooling gates
596                    for i in 0..layer.pool_size - 1 {
597                        let qubit1 = start_qubit + i;
598                        let qubit2 = start_qubit + i + 1;
599                        circuit.add_cnot_gate(qubit1, qubit2)?;
600                    }
601                }
602                Ok(num_pools)
603            }
604        }
605    }
606
607    fn pad_input(&self, input: &[f64], target_size: usize) -> Vec<f64> {
608        let mut padded = input.to_vec();
609        while padded.len() < target_size {
610            padded.push(0.0);
611        }
612        padded.truncate(target_size);
613        padded
614    }
615}
616
617impl QuantumNeuralNetwork for QCNN {
618    fn forward(&self, input: &[f64]) -> DeviceResult<Vec<f64>> {
619        let rt = tokio::runtime::Runtime::new().map_err(|e| {
620            DeviceError::ExecutionFailed(format!("Failed to create tokio runtime: {e}"))
621        })?;
622        rt.block_on(async {
623            let circuit = self.build_circuit(input).await?;
624            let device = self.device.read().await;
625            let result = Self::execute_circuit_helper(&*device, &circuit, 1024).await?;
626
627            // Simple output decoding for QCNN
628            let mut output = Vec::new();
629            let total_shots = result.shots as f64;
630
631            for i in 0..self.num_qubits.min(8) {
632                // Limit output size
633                let mut prob_one = 0.0;
634                for (bitstring, count) in &result.counts {
635                    if let Some(bit_char) = bitstring.chars().nth(i) {
636                        if bit_char == '1' {
637                            prob_one += *count as f64 / total_shots;
638                        }
639                    }
640                }
641                output.push(prob_one);
642            }
643
644            Ok(output)
645        })
646    }
647
648    fn parameters(&self) -> &[f64] {
649        &self.parameters
650    }
651
652    fn set_parameters(&mut self, params: Vec<f64>) -> DeviceResult<()> {
653        if params.len() != self.parameters.len() {
654            return Err(DeviceError::InvalidInput(format!(
655                "Expected {} parameters, got {}",
656                self.parameters.len(),
657                params.len()
658            )));
659        }
660        self.parameters = params;
661        Ok(())
662    }
663
664    fn parameter_count(&self) -> usize {
665        self.parameters.len()
666    }
667
668    fn architecture(&self) -> QNNArchitecture {
669        QNNArchitecture {
670            network_type: QNNType::QCNN,
671            num_qubits: self.num_qubits,
672            num_layers: self.conv_layers.len(),
673            num_parameters: self.parameters.len(),
674            input_encoding: self.input_encoding.clone(),
675            output_decoding: OutputDecoding::Probabilities,
676            entangling_strategy: EntanglingStrategy::Linear,
677        }
678    }
679}
680
681impl QCNN {
682    /// Execute a circuit on the quantum device (helper function to work around trait object limitations)
683    async fn execute_circuit_helper(
684        device: &(dyn QuantumDevice + Send + Sync),
685        circuit: &ParameterizedQuantumCircuit,
686        shots: usize,
687    ) -> DeviceResult<CircuitResult> {
688        // For now, return a mock result since we can't execute circuits directly
689        // In a real implementation, this would need proper circuit execution
690        let mut counts = std::collections::HashMap::new();
691        counts.insert("0".repeat(circuit.num_qubits()), shots / 2);
692        counts.insert("1".repeat(circuit.num_qubits()), shots / 2);
693
694        Ok(CircuitResult {
695            counts,
696            shots,
697            metadata: std::collections::HashMap::new(),
698        })
699    }
700}
701
/// Variational Quantum Classifier
///
/// Wraps a [`PQCNetwork`] and maps its outputs to named classes.
pub struct VQC {
    /// Underlying parameterized circuit that produces the raw outputs.
    pqc_network: PQCNetwork,
    /// Class index to class name; `new` fills it with `class_<index>` labels.
    class_mapping: HashMap<usize, String>,
}
707
708impl VQC {
709    pub fn new(
710        device: Arc<RwLock<dyn QuantumDevice + Send + Sync>>,
711        num_qubits: usize,
712        num_layers: usize,
713        num_classes: usize,
714    ) -> Self {
715        let pqc_network = PQCNetwork::new(
716            device,
717            num_qubits,
718            num_layers,
719            InputEncoding::Angle,
720            OutputDecoding::PauliExpectation,
721            EntanglingStrategy::Linear,
722        );
723
724        let class_mapping = (0..num_classes)
725            .map(|i| (i, format!("class_{i}")))
726            .collect();
727
728        Self {
729            pqc_network,
730            class_mapping,
731        }
732    }
733
734    pub fn classify(&self, input: &[f64]) -> DeviceResult<ClassificationResult> {
735        let raw_output = self.pqc_network.forward(input)?;
736
737        // Convert quantum outputs to class probabilities
738        let class_probs = self.softmax(&raw_output);
739
740        // Find predicted class
741        let (predicted_class, confidence) = class_probs
742            .iter()
743            .enumerate()
744            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
745            .map_or((0, 0.0), |(idx, &prob)| (idx, prob));
746
747        let class_name = self
748            .class_mapping
749            .get(&predicted_class)
750            .cloned()
751            .unwrap_or_else(|| "unknown".to_string());
752
753        Ok(ClassificationResult {
754            predicted_class,
755            class_name,
756            confidence,
757            class_probabilities: class_probs,
758        })
759    }
760
761    fn softmax(&self, values: &[f64]) -> Vec<f64> {
762        let max_val = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
763        let exp_values: Vec<f64> = values.iter().map(|&x| (x - max_val).exp()).collect();
764        let sum_exp: f64 = exp_values.iter().sum();
765        exp_values.iter().map(|&x| x / sum_exp).collect()
766    }
767}
768
/// Every method delegates to the wrapped [`PQCNetwork`]; only the reported
/// network type differs.
impl QuantumNeuralNetwork for VQC {
    /// Returns the raw outputs of the wrapped network (no softmax applied).
    fn forward(&self, input: &[f64]) -> DeviceResult<Vec<f64>> {
        self.pqc_network.forward(input)
    }

    /// Returns the wrapped network's parameters.
    fn parameters(&self) -> &[f64] {
        self.pqc_network.parameters()
    }

    /// Replaces the wrapped network's parameters; fails on a length mismatch.
    fn set_parameters(&mut self, params: Vec<f64>) -> DeviceResult<()> {
        self.pqc_network.set_parameters(params)
    }

    /// Returns the wrapped network's parameter count.
    fn parameter_count(&self) -> usize {
        self.pqc_network.parameter_count()
    }

    /// Returns the wrapped network's architecture with the type set to
    /// `QNNType::VQC`.
    fn architecture(&self) -> QNNArchitecture {
        let mut arch = self.pqc_network.architecture();
        arch.network_type = QNNType::VQC;
        arch
    }
}
792
/// Classification result
///
/// Returned by `VQC::classify`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClassificationResult {
    /// Index of the entry with the highest probability.
    pub predicted_class: usize,
    /// Label of the predicted class, or `"unknown"` if the index has no label.
    pub class_name: String,
    /// Probability assigned to the predicted class.
    pub confidence: f64,
    /// Softmax-normalised probabilities, indexed by class.
    pub class_probabilities: Vec<f64>,
}
801
802/// Create a PQC network for classification
803pub fn create_pqc_classifier(
804    device: Arc<RwLock<dyn QuantumDevice + Send + Sync>>,
805    num_features: usize,
806    num_classes: usize,
807    num_layers: usize,
808) -> DeviceResult<VQC> {
809    let num_qubits =
810        (num_features as f64).log2().ceil() as usize + (num_classes as f64).log2().ceil() as usize;
811    Ok(VQC::new(device, num_qubits, num_layers, num_classes))
812}
813
814/// Create a QCNN for image classification
815pub fn create_qcnn_classifier(
816    device: Arc<RwLock<dyn QuantumDevice + Send + Sync>>,
817    image_size: usize,
818) -> DeviceResult<QCNN> {
819    let num_qubits = (image_size as f64).log2().ceil() as usize;
820
821    let conv_layers = vec![
822        QConvLayer {
823            kernel_size: 2,
824            stride: 1,
825            num_filters: 2,
826            parameter_indices: (0..12).collect(), // 2 filters * 2 kernel_size * 3 params
827        },
828        QConvLayer {
829            kernel_size: 2,
830            stride: 1,
831            num_filters: 1,
832            parameter_indices: (12..18).collect(), // 1 filter * 2 kernel_size * 3 params
833        },
834    ];
835
836    let pooling_layers = vec![
837        QPoolingLayer {
838            pool_size: 2,
839            pool_type: QPoolingType::Measurement,
840        },
841        QPoolingLayer {
842            pool_size: 2,
843            pool_type: QPoolingType::Measurement,
844        },
845    ];
846
847    Ok(QCNN::new(
848        device,
849        num_qubits,
850        conv_layers,
851        pooling_layers,
852        InputEncoding::Angle,
853    ))
854}
855
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::create_mock_quantum_device;

    /// Builds the 4-qubit, 2-layer angle-encoded network used by the PQC tests.
    fn sample_pqc() -> PQCNetwork {
        PQCNetwork::new(
            create_mock_quantum_device(),
            4,
            2,
            InputEncoding::Angle,
            OutputDecoding::PauliExpectation,
            EntanglingStrategy::Linear,
        )
    }

    /// Builds the 4-qubit, 2-layer, 3-class classifier used by the VQC tests.
    fn sample_vqc() -> VQC {
        VQC::new(create_mock_quantum_device(), 4, 2, 3)
    }

    #[test]
    fn test_pqc_network_creation() {
        let pqc = sample_pqc();

        assert_eq!(pqc.num_qubits, 4);
        assert_eq!(pqc.num_layers, 2);
        // 3 rotations * 4 qubits * 2 layers
        assert_eq!(pqc.parameter_count(), 24);
    }

    #[test]
    fn test_vqc_creation() {
        let vqc = sample_vqc();

        assert_eq!(vqc.class_mapping.len(), 3);
        assert_eq!(vqc.parameter_count(), 24);
    }

    #[test]
    fn test_qcnn_creation() {
        let conv = QConvLayer {
            kernel_size: 2,
            stride: 1,
            num_filters: 1,
            parameter_indices: (0..6).collect(),
        };
        let pool = QPoolingLayer {
            pool_size: 2,
            pool_type: QPoolingType::Max,
        };

        let network = QCNN::new(
            create_mock_quantum_device(),
            4,
            vec![conv],
            vec![pool],
            InputEncoding::Angle,
        );

        assert_eq!(network.num_qubits, 4);
        assert_eq!(network.parameter_count(), 6);
    }

    #[test]
    fn test_softmax() {
        let vqc = sample_vqc();

        let probabilities = vqc.softmax(&[1.0, 2.0, 3.0]);

        assert_eq!(probabilities.len(), 3);
        let total: f64 = probabilities.iter().sum();
        assert!((total - 1.0).abs() < 1e-10);
        // Larger scores must map to larger probabilities.
        assert!(probabilities[2] > probabilities[1]);
        assert!(probabilities[1] > probabilities[0]);
    }

    #[test]
    fn test_parameter_operations() {
        let mut pqc = sample_pqc();

        let before = pqc.parameters().to_vec();
        let zeros = vec![0.0; pqc.parameter_count()];

        pqc.set_parameters(zeros.clone())
            .expect("Setting parameters should succeed");
        assert_eq!(pqc.parameters(), zeros.as_slice());
        assert_ne!(pqc.parameters(), before.as_slice());

        // Test invalid parameter count
        let wrong_length = vec![0.0; 5];
        assert!(pqc.set_parameters(wrong_length).is_err());
    }
}