1pub mod encoding;
8pub mod layers;
9pub mod training;
10
11use crate::{
12 error::{QuantRS2Error, QuantRS2Result},
13 gate::GateOp,
14 qubit::QubitId,
15};
16use ndarray::{Array1, Array2};
17use num_complex::Complex64;
18use std::collections::HashMap;
19
20pub use layers::Parameter;
22
23pub trait QMLLayer: Send + Sync {
25 fn num_qubits(&self) -> usize;
27
28 fn parameters(&self) -> &[Parameter];
30
31 fn parameters_mut(&mut self) -> &mut [Parameter];
33
34 fn set_parameters(&mut self, values: &[f64]) -> QuantRS2Result<()> {
36 if values.len() != self.parameters().len() {
37 return Err(QuantRS2Error::InvalidInput(format!(
38 "Expected {} parameters, got {}",
39 self.parameters().len(),
40 values.len()
41 )));
42 }
43
44 for (param, &value) in self.parameters_mut().iter_mut().zip(values.iter()) {
45 param.value = value;
46 }
47
48 Ok(())
49 }
50
51 fn gates(&self) -> Vec<Box<dyn GateOp>>;
53
54 fn compute_gradients(
56 &self,
57 state: &Array1<Complex64>,
58 loss_gradient: &Array1<Complex64>,
59 ) -> QuantRS2Result<Vec<f64>>;
60
61 fn name(&self) -> &str;
63}
64
/// Selects the data-encoding strategy stored in [`QMLConfig::encoding`].
///
/// The enum is fieldless, so it also derives `Eq` and `Hash` and can be used
/// as a map or set key.
///
/// NOTE(review): the variants are not interpreted anywhere in this file;
/// their exact semantics are presumably defined by the `encoding` module —
/// confirm there.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EncodingStrategy {
    /// Amplitude encoding.
    Amplitude,
    /// Angle encoding (the default used by `QMLConfig::default`).
    Angle,
    /// IQP encoding.
    IQP,
    /// Basis encoding.
    Basis,
}
77
78#[derive(Debug, Clone)]
80pub struct QMLConfig {
81 pub num_qubits: usize,
83 pub num_layers: usize,
85 pub encoding: EncodingStrategy,
87 pub entanglement: EntanglementPattern,
89 pub data_reuploading: bool,
91}
92
93impl Default for QMLConfig {
94 fn default() -> Self {
95 Self {
96 num_qubits: 4,
97 num_layers: 2,
98 encoding: EncodingStrategy::Angle,
99 entanglement: EntanglementPattern::Full,
100 data_reuploading: false,
101 }
102 }
103}
104
/// Layout of two-qubit entangling pairs, as produced by
/// [`create_entangling_gates`].
///
/// The enum is fieldless, so it also derives `Eq` and `Hash` and can be used
/// as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EntanglementPattern {
    /// No entangling pairs at all.
    None,
    /// Nearest-neighbour chain: `(i, i + 1)` for consecutive qubits.
    Linear,
    /// Ring: like a chain, with an additional wrap-around pair from the last
    /// qubit back to qubit 0.
    Circular,
    /// All-to-all: every pair `(i, j)` with `i < j`.
    Full,
    /// Nearest-neighbour pairs emitted in two passes: first those starting at
    /// even indices, then those starting at odd indices.
    Alternating,
}
119
120pub struct QMLCircuit {
122 config: QMLConfig,
124 layers: Vec<Box<dyn QMLLayer>>,
126 num_parameters: usize,
128}
129
130impl QMLCircuit {
131 pub fn new(config: QMLConfig) -> Self {
133 Self {
134 config,
135 layers: Vec::new(),
136 num_parameters: 0,
137 }
138 }
139
140 pub fn add_layer(&mut self, layer: Box<dyn QMLLayer>) -> QuantRS2Result<()> {
142 if layer.num_qubits() != self.config.num_qubits {
143 return Err(QuantRS2Error::InvalidInput(format!(
144 "Layer has {} qubits, circuit has {}",
145 layer.num_qubits(),
146 self.config.num_qubits
147 )));
148 }
149
150 self.num_parameters += layer.parameters().len();
151 self.layers.push(layer);
152 Ok(())
153 }
154
155 pub fn parameters(&self) -> Vec<&Parameter> {
157 self.layers
158 .iter()
159 .flat_map(|layer| layer.parameters().iter())
160 .collect()
161 }
162
163 pub fn set_parameters(&mut self, values: &[f64]) -> QuantRS2Result<()> {
165 if values.len() != self.num_parameters {
166 return Err(QuantRS2Error::InvalidInput(format!(
167 "Expected {} parameters, got {}",
168 self.num_parameters,
169 values.len()
170 )));
171 }
172
173 let mut offset = 0;
174 for layer in &mut self.layers {
175 let layer_params = layer.parameters().len();
176 layer.set_parameters(&values[offset..offset + layer_params])?;
177 offset += layer_params;
178 }
179
180 Ok(())
181 }
182
183 pub fn gates(&self) -> Vec<Box<dyn GateOp>> {
185 self.layers.iter().flat_map(|layer| layer.gates()).collect()
186 }
187
188 pub fn compute_gradients(
190 &self,
191 state: &Array1<Complex64>,
192 loss_gradient: &Array1<Complex64>,
193 ) -> QuantRS2Result<Vec<f64>> {
194 let mut all_gradients = Vec::new();
195
196 for layer in &self.layers {
197 let layer_grads = layer.compute_gradients(state, loss_gradient)?;
198 all_gradients.extend(layer_grads);
199 }
200
201 Ok(all_gradients)
202 }
203}
204
205pub fn create_entangling_gates(
207 num_qubits: usize,
208 pattern: EntanglementPattern,
209) -> Vec<(QubitId, QubitId)> {
210 match pattern {
211 EntanglementPattern::None => vec![],
212
213 EntanglementPattern::Linear => (0..num_qubits - 1)
214 .map(|i| (QubitId(i as u32), QubitId((i + 1) as u32)))
215 .collect(),
216
217 EntanglementPattern::Circular => {
218 let mut gates = vec![];
219 for i in 0..num_qubits {
220 gates.push((QubitId(i as u32), QubitId(((i + 1) % num_qubits) as u32)));
221 }
222 gates
223 }
224
225 EntanglementPattern::Full => {
226 let mut gates = vec![];
227 for i in 0..num_qubits {
228 for j in i + 1..num_qubits {
229 gates.push((QubitId(i as u32), QubitId(j as u32)));
230 }
231 }
232 gates
233 }
234
235 EntanglementPattern::Alternating => {
236 let mut gates = vec![];
237 for i in (0..num_qubits - 1).step_by(2) {
239 gates.push((QubitId(i as u32), QubitId((i + 1) as u32)));
240 }
241 for i in (1..num_qubits - 1).step_by(2) {
243 gates.push((QubitId(i as u32), QubitId((i + 1) as u32)));
244 }
245 gates
246 }
247 }
248}
249
250pub fn quantum_fisher_information(
252 circuit: &QMLCircuit,
253 state: &Array1<Complex64>,
254) -> QuantRS2Result<Array2<f64>> {
255 let num_params = circuit.num_parameters;
256 let mut fisher = Array2::zeros((num_params, num_params));
257
258 Ok(fisher)
265}
266
267pub fn natural_gradient(
269 gradients: &[f64],
270 fisher: &Array2<f64>,
271 regularization: f64,
272) -> QuantRS2Result<Vec<f64>> {
273 let mut regularized_fisher = fisher.clone();
275 for i in 0..fisher.nrows() {
276 regularized_fisher[(i, i)] += regularization;
277 }
278
279 let grad_array = Array1::from_vec(gradients.to_vec());
282
283 Ok(gradients.to_vec())
286}
287
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the pair layout of every entanglement pattern, including
    /// `Alternating`.
    #[test]
    fn test_entanglement_patterns() {
        let linear = create_entangling_gates(4, EntanglementPattern::Linear);
        assert_eq!(linear.len(), 3);
        assert_eq!(linear[0], (QubitId(0), QubitId(1)));

        let circular = create_entangling_gates(4, EntanglementPattern::Circular);
        assert_eq!(circular.len(), 4);
        assert_eq!(circular[3], (QubitId(3), QubitId(0)));

        // Three qubits give the three pairs (0,1), (0,2), (1,2).
        let full = create_entangling_gates(3, EntanglementPattern::Full);
        assert_eq!(full.len(), 3);
        assert_eq!(full[0], (QubitId(0), QubitId(1)));
        assert_eq!(full[1], (QubitId(0), QubitId(2)));
        assert_eq!(full[2], (QubitId(1), QubitId(2)));

        let none = create_entangling_gates(4, EntanglementPattern::None);
        assert_eq!(none.len(), 0);

        // Even-start pairs come first, then odd-start pairs.
        let alternating = create_entangling_gates(5, EntanglementPattern::Alternating);
        assert_eq!(alternating.len(), 4);
        assert_eq!(alternating[0], (QubitId(0), QubitId(1)));
        assert_eq!(alternating[1], (QubitId(2), QubitId(3)));
        assert_eq!(alternating[2], (QubitId(1), QubitId(2)));
        assert_eq!(alternating[3], (QubitId(3), QubitId(4)));
    }

    /// A freshly created circuit has no parameters and no gates.
    #[test]
    fn test_qml_circuit() {
        let config = QMLConfig {
            num_qubits: 2,
            num_layers: 1,
            ..Default::default()
        };

        let circuit = QMLCircuit::new(config);
        assert_eq!(circuit.num_parameters, 0);
        assert!(circuit.parameters().is_empty());
        assert!(circuit.gates().is_empty());
    }
}