quantrs2_ml/keras_api/
quantum_layers.rs1use super::KerasLayer;
4use crate::error::{MLError, Result};
5use crate::simulator_backends::{DynamicCircuit, Observable, SimulatorBackend, StatevectorBackend};
6use quantrs2_circuit::prelude::*;
7use scirs2_core::ndarray::{s, ArrayD, IxDyn};
8use std::sync::Arc;
9
10pub struct QuantumDense {
12 num_qubits: usize,
14 units: usize,
16 ansatz_type: QuantumAnsatzType,
18 num_layers: usize,
20 observable: Observable,
22 backend: Arc<dyn SimulatorBackend>,
24 name: String,
26 built: bool,
28 input_shape: Option<Vec<usize>>,
30 quantum_weights: Vec<ArrayD<f64>>,
32}
33
34#[derive(Debug, Clone)]
36pub enum QuantumAnsatzType {
37 HardwareEfficient,
39 RealAmplitudes,
41 StronglyEntangling,
43 Custom(DynamicCircuit),
45}
46
47impl QuantumDense {
48 pub fn new(num_qubits: usize, units: usize) -> Self {
50 Self {
51 num_qubits,
52 units,
53 ansatz_type: QuantumAnsatzType::HardwareEfficient,
54 num_layers: 1,
55 observable: Observable::PauliZ(vec![0]),
56 backend: Arc::new(StatevectorBackend::new(10)),
57 name: format!("quantum_dense_{}", fastrand::u32(..)),
58 built: false,
59 input_shape: None,
60 quantum_weights: Vec::new(),
61 }
62 }
63
64 pub fn ansatz_type(mut self, ansatz_type: QuantumAnsatzType) -> Self {
66 self.ansatz_type = ansatz_type;
67 self
68 }
69
70 pub fn num_layers(mut self, num_layers: usize) -> Self {
72 self.num_layers = num_layers;
73 self
74 }
75
76 pub fn observable(mut self, observable: Observable) -> Self {
78 self.observable = observable;
79 self
80 }
81
82 pub fn backend(mut self, backend: Arc<dyn SimulatorBackend>) -> Self {
84 self.backend = backend;
85 self
86 }
87
88 pub fn name(mut self, name: impl Into<String>) -> Self {
90 self.name = name.into();
91 self
92 }
93
94 fn build_quantum_circuit(&self) -> Result<DynamicCircuit> {
96 let mut builder: Circuit<8> = Circuit::new();
97
98 match &self.ansatz_type {
99 QuantumAnsatzType::HardwareEfficient => {
100 for layer in 0..self.num_layers {
101 if layer == 0 {
102 for qubit in 0..self.num_qubits {
103 builder.ry(qubit, 0.0)?;
104 }
105 }
106
107 for qubit in 0..self.num_qubits {
108 builder.ry(qubit, 0.0)?;
109 builder.rz(qubit, 0.0)?;
110 }
111
112 for qubit in 0..self.num_qubits - 1 {
113 builder.cnot(qubit, qubit + 1)?;
114 }
115 if self.num_qubits > 2 {
116 builder.cnot(self.num_qubits - 1, 0)?;
117 }
118 }
119 }
120 QuantumAnsatzType::RealAmplitudes => {
121 for layer in 0..self.num_layers {
122 if layer == 0 {
123 for qubit in 0..self.num_qubits {
124 builder.ry(qubit, 0.0)?;
125 }
126 }
127
128 for qubit in 0..self.num_qubits {
129 builder.ry(qubit, 0.0)?;
130 }
131
132 for qubit in 0..self.num_qubits - 1 {
133 builder.cnot(qubit, qubit + 1)?;
134 }
135 }
136 }
137 QuantumAnsatzType::StronglyEntangling => {
138 for layer in 0..self.num_layers {
139 if layer == 0 {
140 for qubit in 0..self.num_qubits {
141 builder.ry(qubit, 0.0)?;
142 }
143 }
144
145 for qubit in 0..self.num_qubits {
146 builder.rx(qubit, 0.0)?;
147 builder.ry(qubit, 0.0)?;
148 builder.rz(qubit, 0.0)?;
149 }
150
151 for qubit in 0..self.num_qubits - 1 {
152 builder.cnot(qubit, qubit + 1)?;
153 }
154 if self.num_qubits > 2 {
155 builder.cnot(self.num_qubits - 1, 0)?;
156 }
157 }
158 }
159 QuantumAnsatzType::Custom(circuit) => {
160 return Ok(circuit.clone());
161 }
162 }
163
164 let circuit = builder.build();
165 DynamicCircuit::from_circuit(circuit)
166 }
167}
168
169impl KerasLayer for QuantumDense {
170 fn build(&mut self, input_shape: &[usize]) -> Result<()> {
171 self.input_shape = Some(input_shape.to_vec());
172
173 let num_params = match &self.ansatz_type {
174 QuantumAnsatzType::HardwareEfficient => self.num_qubits * 2 * self.num_layers,
175 QuantumAnsatzType::RealAmplitudes => self.num_qubits * self.num_layers,
176 QuantumAnsatzType::StronglyEntangling => self.num_qubits * 3 * self.num_layers,
177 QuantumAnsatzType::Custom(_) => 10,
178 };
179
180 let params = ArrayD::from_shape_fn(IxDyn(&[self.units, num_params]), |_| {
181 fastrand::f64() * 2.0 * std::f64::consts::PI
182 });
183 self.quantum_weights.push(params);
184
185 self.built = true;
186 Ok(())
187 }
188
189 fn call(&self, inputs: &ArrayD<f64>) -> Result<ArrayD<f64>> {
190 if !self.built {
191 return Err(MLError::InvalidConfiguration(
192 "Layer must be built before calling".to_string(),
193 ));
194 }
195
196 let batch_size = inputs.shape()[0];
197 let mut outputs = ArrayD::zeros(IxDyn(&[batch_size, self.units]));
198
199 for batch_idx in 0..batch_size {
200 for unit_idx in 0..self.units {
201 let circuit = self.build_quantum_circuit()?;
202
203 let input_slice = inputs.slice(s![batch_idx, ..]);
204 let param_slice = self.quantum_weights[0].slice(s![unit_idx, ..]);
205
206 let combined_params: Vec<f64> = input_slice
207 .iter()
208 .chain(param_slice.iter())
209 .copied()
210 .collect();
211
212 let expectation =
213 self.backend
214 .expectation_value(&circuit, &combined_params, &self.observable)?;
215
216 outputs[[batch_idx, unit_idx]] = expectation;
217 }
218 }
219
220 Ok(outputs)
221 }
222
223 fn compute_output_shape(&self, input_shape: &[usize]) -> Vec<usize> {
224 let mut output_shape = input_shape.to_vec();
225 let last_idx = output_shape.len() - 1;
226 output_shape[last_idx] = self.units;
227 output_shape
228 }
229
230 fn name(&self) -> &str {
231 &self.name
232 }
233
234 fn get_weights(&self) -> Vec<ArrayD<f64>> {
235 self.quantum_weights.clone()
236 }
237
238 fn set_weights(&mut self, weights: Vec<ArrayD<f64>>) -> Result<()> {
239 if weights.len() != self.quantum_weights.len() {
240 return Err(MLError::InvalidConfiguration(
241 "Number of weight arrays doesn't match layer structure".to_string(),
242 ));
243 }
244 self.quantum_weights = weights;
245 Ok(())
246 }
247
248 fn built(&self) -> bool {
249 self.built
250 }
251}