quantrs2_ml/keras_api/
quantum_layers.rs

1//! Quantum layers for Keras-like API
2
3use super::KerasLayer;
4use crate::error::{MLError, Result};
5use crate::simulator_backends::{DynamicCircuit, Observable, SimulatorBackend, StatevectorBackend};
6use quantrs2_circuit::prelude::*;
7use scirs2_core::ndarray::{s, ArrayD, IxDyn};
8use std::sync::Arc;
9
10/// Quantum Dense layer
11pub struct QuantumDense {
12    /// Number of qubits
13    num_qubits: usize,
14    /// Number of output features
15    units: usize,
16    /// Quantum circuit ansatz
17    ansatz_type: QuantumAnsatzType,
18    /// Number of layers in ansatz
19    num_layers: usize,
20    /// Observable for measurement
21    observable: Observable,
22    /// Backend
23    backend: Arc<dyn SimulatorBackend>,
24    /// Layer name
25    name: String,
26    /// Built flag
27    built: bool,
28    /// Input shape
29    input_shape: Option<Vec<usize>>,
30    /// Quantum parameters
31    quantum_weights: Vec<ArrayD<f64>>,
32}
33
34/// Quantum ansatz types
35#[derive(Debug, Clone)]
36pub enum QuantumAnsatzType {
37    /// Hardware efficient ansatz
38    HardwareEfficient,
39    /// Real amplitudes ansatz
40    RealAmplitudes,
41    /// Strongly entangling layers
42    StronglyEntangling,
43    /// Custom ansatz
44    Custom(DynamicCircuit),
45}
46
47impl QuantumDense {
48    /// Create new quantum dense layer
49    pub fn new(num_qubits: usize, units: usize) -> Self {
50        Self {
51            num_qubits,
52            units,
53            ansatz_type: QuantumAnsatzType::HardwareEfficient,
54            num_layers: 1,
55            observable: Observable::PauliZ(vec![0]),
56            backend: Arc::new(StatevectorBackend::new(10)),
57            name: format!("quantum_dense_{}", fastrand::u32(..)),
58            built: false,
59            input_shape: None,
60            quantum_weights: Vec::new(),
61        }
62    }
63
64    /// Set ansatz type
65    pub fn ansatz_type(mut self, ansatz_type: QuantumAnsatzType) -> Self {
66        self.ansatz_type = ansatz_type;
67        self
68    }
69
70    /// Set number of layers
71    pub fn num_layers(mut self, num_layers: usize) -> Self {
72        self.num_layers = num_layers;
73        self
74    }
75
76    /// Set observable
77    pub fn observable(mut self, observable: Observable) -> Self {
78        self.observable = observable;
79        self
80    }
81
82    /// Set backend
83    pub fn backend(mut self, backend: Arc<dyn SimulatorBackend>) -> Self {
84        self.backend = backend;
85        self
86    }
87
88    /// Set layer name
89    pub fn name(mut self, name: impl Into<String>) -> Self {
90        self.name = name.into();
91        self
92    }
93
94    /// Build quantum circuit based on ansatz type
95    fn build_quantum_circuit(&self) -> Result<DynamicCircuit> {
96        let mut builder: Circuit<8> = Circuit::new();
97
98        match &self.ansatz_type {
99            QuantumAnsatzType::HardwareEfficient => {
100                for layer in 0..self.num_layers {
101                    if layer == 0 {
102                        for qubit in 0..self.num_qubits {
103                            builder.ry(qubit, 0.0)?;
104                        }
105                    }
106
107                    for qubit in 0..self.num_qubits {
108                        builder.ry(qubit, 0.0)?;
109                        builder.rz(qubit, 0.0)?;
110                    }
111
112                    for qubit in 0..self.num_qubits - 1 {
113                        builder.cnot(qubit, qubit + 1)?;
114                    }
115                    if self.num_qubits > 2 {
116                        builder.cnot(self.num_qubits - 1, 0)?;
117                    }
118                }
119            }
120            QuantumAnsatzType::RealAmplitudes => {
121                for layer in 0..self.num_layers {
122                    if layer == 0 {
123                        for qubit in 0..self.num_qubits {
124                            builder.ry(qubit, 0.0)?;
125                        }
126                    }
127
128                    for qubit in 0..self.num_qubits {
129                        builder.ry(qubit, 0.0)?;
130                    }
131
132                    for qubit in 0..self.num_qubits - 1 {
133                        builder.cnot(qubit, qubit + 1)?;
134                    }
135                }
136            }
137            QuantumAnsatzType::StronglyEntangling => {
138                for layer in 0..self.num_layers {
139                    if layer == 0 {
140                        for qubit in 0..self.num_qubits {
141                            builder.ry(qubit, 0.0)?;
142                        }
143                    }
144
145                    for qubit in 0..self.num_qubits {
146                        builder.rx(qubit, 0.0)?;
147                        builder.ry(qubit, 0.0)?;
148                        builder.rz(qubit, 0.0)?;
149                    }
150
151                    for qubit in 0..self.num_qubits - 1 {
152                        builder.cnot(qubit, qubit + 1)?;
153                    }
154                    if self.num_qubits > 2 {
155                        builder.cnot(self.num_qubits - 1, 0)?;
156                    }
157                }
158            }
159            QuantumAnsatzType::Custom(circuit) => {
160                return Ok(circuit.clone());
161            }
162        }
163
164        let circuit = builder.build();
165        DynamicCircuit::from_circuit(circuit)
166    }
167}
168
169impl KerasLayer for QuantumDense {
170    fn build(&mut self, input_shape: &[usize]) -> Result<()> {
171        self.input_shape = Some(input_shape.to_vec());
172
173        let num_params = match &self.ansatz_type {
174            QuantumAnsatzType::HardwareEfficient => self.num_qubits * 2 * self.num_layers,
175            QuantumAnsatzType::RealAmplitudes => self.num_qubits * self.num_layers,
176            QuantumAnsatzType::StronglyEntangling => self.num_qubits * 3 * self.num_layers,
177            QuantumAnsatzType::Custom(_) => 10,
178        };
179
180        let params = ArrayD::from_shape_fn(IxDyn(&[self.units, num_params]), |_| {
181            fastrand::f64() * 2.0 * std::f64::consts::PI
182        });
183        self.quantum_weights.push(params);
184
185        self.built = true;
186        Ok(())
187    }
188
189    fn call(&self, inputs: &ArrayD<f64>) -> Result<ArrayD<f64>> {
190        if !self.built {
191            return Err(MLError::InvalidConfiguration(
192                "Layer must be built before calling".to_string(),
193            ));
194        }
195
196        let batch_size = inputs.shape()[0];
197        let mut outputs = ArrayD::zeros(IxDyn(&[batch_size, self.units]));
198
199        for batch_idx in 0..batch_size {
200            for unit_idx in 0..self.units {
201                let circuit = self.build_quantum_circuit()?;
202
203                let input_slice = inputs.slice(s![batch_idx, ..]);
204                let param_slice = self.quantum_weights[0].slice(s![unit_idx, ..]);
205
206                let combined_params: Vec<f64> = input_slice
207                    .iter()
208                    .chain(param_slice.iter())
209                    .copied()
210                    .collect();
211
212                let expectation =
213                    self.backend
214                        .expectation_value(&circuit, &combined_params, &self.observable)?;
215
216                outputs[[batch_idx, unit_idx]] = expectation;
217            }
218        }
219
220        Ok(outputs)
221    }
222
223    fn compute_output_shape(&self, input_shape: &[usize]) -> Vec<usize> {
224        let mut output_shape = input_shape.to_vec();
225        let last_idx = output_shape.len() - 1;
226        output_shape[last_idx] = self.units;
227        output_shape
228    }
229
230    fn name(&self) -> &str {
231        &self.name
232    }
233
234    fn get_weights(&self) -> Vec<ArrayD<f64>> {
235        self.quantum_weights.clone()
236    }
237
238    fn set_weights(&mut self, weights: Vec<ArrayD<f64>>) -> Result<()> {
239        if weights.len() != self.quantum_weights.len() {
240            return Err(MLError::InvalidConfiguration(
241                "Number of weight arrays doesn't match layer structure".to_string(),
242            ));
243        }
244        self.quantum_weights = weights;
245        Ok(())
246    }
247
248    fn built(&self) -> bool {
249        self.built
250    }
251}