1pub mod encoding;
8pub mod generative_adversarial;
9pub mod layers;
10pub mod nlp;
11pub mod reinforcement_learning;
12pub mod training;
13
14use crate::{
15 error::{QuantRS2Error, QuantRS2Result},
16 gate::GateOp,
17 qubit::QubitId,
18};
19use ndarray::{Array1, Array2};
20use num_complex::Complex64;
21
22pub use layers::Parameter;
24
25pub trait QMLLayer: Send + Sync {
27 fn num_qubits(&self) -> usize;
29
30 fn parameters(&self) -> &[Parameter];
32
33 fn parameters_mut(&mut self) -> &mut [Parameter];
35
36 fn set_parameters(&mut self, values: &[f64]) -> QuantRS2Result<()> {
38 if values.len() != self.parameters().len() {
39 return Err(QuantRS2Error::InvalidInput(format!(
40 "Expected {} parameters, got {}",
41 self.parameters().len(),
42 values.len()
43 )));
44 }
45
46 for (param, &value) in self.parameters_mut().iter_mut().zip(values.iter()) {
47 param.value = value;
48 }
49
50 Ok(())
51 }
52
53 fn gates(&self) -> Vec<Box<dyn GateOp>>;
55
56 fn compute_gradients(
58 &self,
59 state: &Array1<Complex64>,
60 loss_gradient: &Array1<Complex64>,
61 ) -> QuantRS2Result<Vec<f64>>;
62
63 fn name(&self) -> &str;
65}
66
/// How input data is encoded into the circuit; selected via `QMLConfig::encoding`.
///
/// `Eq` and `Hash` are derived in addition to `PartialEq` so the strategy can
/// be used as a set/map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EncodingStrategy {
    /// Amplitude encoding.
    Amplitude,
    /// Angle encoding (the `QMLConfig` default).
    Angle,
    /// IQP-style encoding.
    IQP,
    /// Basis encoding.
    Basis,
}
79
80#[derive(Debug, Clone)]
82pub struct QMLConfig {
83 pub num_qubits: usize,
85 pub num_layers: usize,
87 pub encoding: EncodingStrategy,
89 pub entanglement: EntanglementPattern,
91 pub data_reuploading: bool,
93}
94
95impl Default for QMLConfig {
96 fn default() -> Self {
97 Self {
98 num_qubits: 4,
99 num_layers: 2,
100 encoding: EncodingStrategy::Angle,
101 entanglement: EntanglementPattern::Full,
102 data_reuploading: false,
103 }
104 }
105}
106
/// Which qubit pairs receive entangling gates; [`create_entangling_gates`]
/// turns a pattern into the concrete `(control, target)` pairs.
///
/// `Eq` and `Hash` are derived in addition to `PartialEq` so the pattern can
/// be used as a set/map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EntanglementPattern {
    /// No entangling pairs.
    None,
    /// Nearest-neighbour chain `(0,1), (1,2), …, (n-2, n-1)`.
    Linear,
    /// The nearest-neighbour chain plus the wrap-around pair `(n-1, 0)`.
    Circular,
    /// Every pair `(i, j)` with `i < j`.
    Full,
    /// Pairs starting at even indices `(0,1), (2,3), …`, followed by pairs
    /// starting at odd indices `(1,2), (3,4), …`.
    Alternating,
}
121
122pub struct QMLCircuit {
124 config: QMLConfig,
126 layers: Vec<Box<dyn QMLLayer>>,
128 num_parameters: usize,
130}
131
132impl QMLCircuit {
133 pub fn new(config: QMLConfig) -> Self {
135 Self {
136 config,
137 layers: Vec::new(),
138 num_parameters: 0,
139 }
140 }
141
142 pub fn add_layer(&mut self, layer: Box<dyn QMLLayer>) -> QuantRS2Result<()> {
144 if layer.num_qubits() != self.config.num_qubits {
145 return Err(QuantRS2Error::InvalidInput(format!(
146 "Layer has {} qubits, circuit has {}",
147 layer.num_qubits(),
148 self.config.num_qubits
149 )));
150 }
151
152 self.num_parameters += layer.parameters().len();
153 self.layers.push(layer);
154 Ok(())
155 }
156
157 pub fn parameters(&self) -> Vec<&Parameter> {
159 self.layers
160 .iter()
161 .flat_map(|layer| layer.parameters().iter())
162 .collect()
163 }
164
165 pub fn set_parameters(&mut self, values: &[f64]) -> QuantRS2Result<()> {
167 if values.len() != self.num_parameters {
168 return Err(QuantRS2Error::InvalidInput(format!(
169 "Expected {} parameters, got {}",
170 self.num_parameters,
171 values.len()
172 )));
173 }
174
175 let mut offset = 0;
176 for layer in &mut self.layers {
177 let layer_params = layer.parameters().len();
178 layer.set_parameters(&values[offset..offset + layer_params])?;
179 offset += layer_params;
180 }
181
182 Ok(())
183 }
184
185 pub fn gates(&self) -> Vec<Box<dyn GateOp>> {
187 self.layers.iter().flat_map(|layer| layer.gates()).collect()
188 }
189
190 pub fn compute_gradients(
192 &self,
193 state: &Array1<Complex64>,
194 loss_gradient: &Array1<Complex64>,
195 ) -> QuantRS2Result<Vec<f64>> {
196 let mut all_gradients = Vec::new();
197
198 for layer in &self.layers {
199 let layer_grads = layer.compute_gradients(state, loss_gradient)?;
200 all_gradients.extend(layer_grads);
201 }
202
203 Ok(all_gradients)
204 }
205}
206
207pub fn create_entangling_gates(
209 num_qubits: usize,
210 pattern: EntanglementPattern,
211) -> Vec<(QubitId, QubitId)> {
212 match pattern {
213 EntanglementPattern::None => vec![],
214
215 EntanglementPattern::Linear => (0..num_qubits - 1)
216 .map(|i| (QubitId(i as u32), QubitId((i + 1) as u32)))
217 .collect(),
218
219 EntanglementPattern::Circular => {
220 let mut gates = vec![];
221 for i in 0..num_qubits {
222 gates.push((QubitId(i as u32), QubitId(((i + 1) % num_qubits) as u32)));
223 }
224 gates
225 }
226
227 EntanglementPattern::Full => {
228 let mut gates = vec![];
229 for i in 0..num_qubits {
230 for j in i + 1..num_qubits {
231 gates.push((QubitId(i as u32), QubitId(j as u32)));
232 }
233 }
234 gates
235 }
236
237 EntanglementPattern::Alternating => {
238 let mut gates = vec![];
239 for i in (0..num_qubits - 1).step_by(2) {
241 gates.push((QubitId(i as u32), QubitId((i + 1) as u32)));
242 }
243 for i in (1..num_qubits - 1).step_by(2) {
245 gates.push((QubitId(i as u32), QubitId((i + 1) as u32)));
246 }
247 gates
248 }
249 }
250}
251
252pub fn quantum_fisher_information(
254 circuit: &QMLCircuit,
255 _state: &Array1<Complex64>,
256) -> QuantRS2Result<Array2<f64>> {
257 let num_params = circuit.num_parameters;
258 let fisher = Array2::zeros((num_params, num_params));
259
260 Ok(fisher)
267}
268
269pub fn natural_gradient(
271 gradients: &[f64],
272 fisher: &Array2<f64>,
273 regularization: f64,
274) -> QuantRS2Result<Vec<f64>> {
275 let mut regularized_fisher = fisher.clone();
277 for i in 0..fisher.nrows() {
278 regularized_fisher[(i, i)] += regularization;
279 }
280
281 let _grad_array = Array1::from_vec(gradients.to_vec());
284
285 Ok(gradients.to_vec())
288}
289
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_entanglement_patterns() {
        let chain = create_entangling_gates(4, EntanglementPattern::Linear);
        assert_eq!(chain.len(), 3);
        assert_eq!(chain[0], (QubitId(0), QubitId(1)));

        let ring = create_entangling_gates(4, EntanglementPattern::Circular);
        assert_eq!(ring.len(), 4);
        assert_eq!(ring[3], (QubitId(3), QubitId(0)));

        let all_pairs = create_entangling_gates(3, EntanglementPattern::Full);
        assert_eq!(all_pairs.len(), 3);

        let no_pairs = create_entangling_gates(4, EntanglementPattern::None);
        assert!(no_pairs.is_empty());
    }

    #[test]
    fn test_qml_circuit() {
        let circuit = QMLCircuit::new(QMLConfig {
            num_qubits: 2,
            num_layers: 1,
            ..QMLConfig::default()
        });
        assert_eq!(circuit.num_parameters, 0);
    }
}