1use crate::{
33 error::{QuantRS2Error, QuantRS2Result},
34 gate::GateOp,
35 qubit::QubitId,
36};
37use scirs2_core::ndarray::{Array1, Array2, Axis};
38use scirs2_core::random::prelude::*;
39use scirs2_core::Complex64;
40use std::f64::consts::PI;
41
/// Hyper-parameters shared by the quantum meta-learners in this module.
#[derive(Debug, Clone)]
pub struct QuantumMetaLearningConfig {
    /// Number of qubits of the learner circuit.
    pub num_qubits: usize,
    /// Number of rotation + entanglement layers in the learner circuit.
    pub circuit_depth: usize,
    /// Step size of the task-level (inner-loop) gradient descent.
    pub inner_lr: f64,
    /// Step size of the meta-level (outer-loop) update.
    pub outer_lr: f64,
    /// Number of inner-loop gradient steps taken on each task's support set.
    pub inner_steps: usize,
    /// Support examples per class (the `n_support` argument of `QuantumTask::random`).
    pub n_support: usize,
    /// Query examples per class (the `n_query` argument of `QuantumTask::random`).
    pub n_query: usize,
    /// Number of classes per task; also the output width of the learner circuit.
    pub n_way: usize,
    /// Intended number of tasks per meta-batch.
    /// NOTE(review): not read by the learners in this module — confirm callers use it.
    pub meta_batch_size: usize,
}

impl Default for QuantumMetaLearningConfig {
    /// 2-way tasks with 5 support and 15 query examples per class, learned by a
    /// 4-qubit, depth-4 circuit with 5 inner steps.
    fn default() -> Self {
        Self {
            n_way: 2,
            n_support: 5,
            n_query: 15,
            num_qubits: 4,
            circuit_depth: 4,
            inner_steps: 5,
            inner_lr: 0.01,
            outer_lr: 0.001,
            meta_batch_size: 4,
        }
    }
}
80
81#[derive(Debug, Clone)]
83pub struct QuantumTask {
84 pub support_states: Vec<Array1<Complex64>>,
86 pub support_labels: Vec<usize>,
88 pub query_states: Vec<Array1<Complex64>>,
90 pub query_labels: Vec<usize>,
92}
93
94impl QuantumTask {
95 pub fn new(
97 support_states: Vec<Array1<Complex64>>,
98 support_labels: Vec<usize>,
99 query_states: Vec<Array1<Complex64>>,
100 query_labels: Vec<usize>,
101 ) -> Self {
102 Self {
103 support_states,
104 support_labels,
105 query_states,
106 query_labels,
107 }
108 }
109
110 pub fn random(num_qubits: usize, n_way: usize, n_support: usize, n_query: usize) -> Self {
112 let mut rng = thread_rng();
113 let dim = 1 << num_qubits;
114
115 let mut support_states = Vec::new();
116 let mut support_labels = Vec::new();
117 let mut query_states = Vec::new();
118 let mut query_labels = Vec::new();
119
120 for class in 0..n_way {
121 let mut prototype = Array1::from_shape_fn(dim, |_| {
123 Complex64::new(rng.gen_range(-1.0..1.0), rng.gen_range(-1.0..1.0))
124 });
125 let norm: f64 = prototype.iter().map(|x| x.norm_sqr()).sum::<f64>().sqrt();
126 for i in 0..dim {
127 prototype[i] = prototype[i] / norm;
128 }
129
130 for _ in 0..n_support {
132 let mut state = prototype.clone();
133 for i in 0..dim {
135 state[i] = state[i]
136 + Complex64::new(rng.gen_range(-0.1..0.1), rng.gen_range(-0.1..0.1));
137 }
138 let norm: f64 = state.iter().map(|x| x.norm_sqr()).sum::<f64>().sqrt();
139 for i in 0..dim {
140 state[i] = state[i] / norm;
141 }
142 support_states.push(state);
143 support_labels.push(class);
144 }
145
146 for _ in 0..n_query {
148 let mut state = prototype.clone();
149 for i in 0..dim {
151 state[i] = state[i]
152 + Complex64::new(rng.gen_range(-0.1..0.1), rng.gen_range(-0.1..0.1));
153 }
154 let norm: f64 = state.iter().map(|x| x.norm_sqr()).sum::<f64>().sqrt();
155 for i in 0..dim {
156 state[i] = state[i] / norm;
157 }
158 query_states.push(state);
159 query_labels.push(class);
160 }
161 }
162
163 Self {
164 support_states,
165 support_labels,
166 query_states,
167 query_labels,
168 }
169 }
170}
171
172#[derive(Debug, Clone)]
174pub struct QuantumMetaCircuit {
175 num_qubits: usize,
177 depth: usize,
179 num_classes: usize,
181 params: Array2<f64>,
183 readout_weights: Array2<f64>,
185}
186
187impl QuantumMetaCircuit {
188 pub fn new(num_qubits: usize, depth: usize, num_classes: usize) -> Self {
190 let mut rng = thread_rng();
191
192 let params = Array2::from_shape_fn((depth, num_qubits * 3), |_| rng.gen_range(-PI..PI));
193
194 let scale = (2.0 / num_qubits as f64).sqrt();
195 let readout_weights =
196 Array2::from_shape_fn((num_classes, num_qubits), |_| rng.gen_range(-scale..scale));
197
198 Self {
199 num_qubits,
200 depth,
201 num_classes,
202 params,
203 readout_weights,
204 }
205 }
206
207 pub fn forward(&self, state: &Array1<Complex64>) -> QuantRS2Result<Array1<f64>> {
209 let mut encoded = state.clone();
211
212 for layer in 0..self.depth {
213 for q in 0..self.num_qubits {
215 let rx = self.params[[layer, q * 3]];
216 let ry = self.params[[layer, q * 3 + 1]];
217 let rz = self.params[[layer, q * 3 + 2]];
218
219 encoded = self.apply_rotation(&encoded, q, rx, ry, rz)?;
220 }
221
222 for q in 0..self.num_qubits - 1 {
224 encoded = self.apply_cnot(&encoded, q, q + 1)?;
225 }
226 }
227
228 let mut expectations = Array1::zeros(self.num_qubits);
230 for q in 0..self.num_qubits {
231 expectations[q] = self.pauli_z_expectation(&encoded, q)?;
232 }
233
234 let mut logits = Array1::zeros(self.num_classes);
236 for i in 0..self.num_classes {
237 for j in 0..self.num_qubits {
238 logits[i] += self.readout_weights[[i, j]] * expectations[j];
239 }
240 }
241
242 let max_logit = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
244 let mut probs = Array1::zeros(self.num_classes);
245 let mut sum_exp = 0.0;
246
247 for i in 0..self.num_classes {
248 probs[i] = (logits[i] - max_logit).exp();
249 sum_exp += probs[i];
250 }
251
252 for i in 0..self.num_classes {
253 probs[i] /= sum_exp;
254 }
255
256 Ok(probs)
257 }
258
259 pub fn compute_loss(
261 &self,
262 states: &[Array1<Complex64>],
263 labels: &[usize],
264 ) -> QuantRS2Result<f64> {
265 let mut total_loss = 0.0;
266
267 for (state, &label) in states.iter().zip(labels.iter()) {
268 let probs = self.forward(state)?;
269 total_loss -= probs[label].ln();
271 }
272
273 Ok(total_loss / states.len() as f64)
274 }
275
276 pub fn compute_gradients(
278 &self,
279 states: &[Array1<Complex64>],
280 labels: &[usize],
281 ) -> QuantRS2Result<(Array2<f64>, Array2<f64>)> {
282 let epsilon = 1e-4;
283
284 let mut param_grads = Array2::zeros(self.params.dim());
286
287 for i in 0..self.params.shape()[0] {
288 for j in 0..self.params.shape()[1] {
289 let mut circuit_plus = self.clone();
290 circuit_plus.params[[i, j]] += epsilon;
291 let loss_plus = circuit_plus.compute_loss(states, labels)?;
292
293 let mut circuit_minus = self.clone();
294 circuit_minus.params[[i, j]] -= epsilon;
295 let loss_minus = circuit_minus.compute_loss(states, labels)?;
296
297 param_grads[[i, j]] = (loss_plus - loss_minus) / (2.0 * epsilon);
298 }
299 }
300
301 let mut readout_grads = Array2::zeros(self.readout_weights.dim());
303
304 for i in 0..self.readout_weights.shape()[0] {
305 for j in 0..self.readout_weights.shape()[1] {
306 let mut circuit_plus = self.clone();
307 circuit_plus.readout_weights[[i, j]] += epsilon;
308 let loss_plus = circuit_plus.compute_loss(states, labels)?;
309
310 let mut circuit_minus = self.clone();
311 circuit_minus.readout_weights[[i, j]] -= epsilon;
312 let loss_minus = circuit_minus.compute_loss(states, labels)?;
313
314 readout_grads[[i, j]] = (loss_plus - loss_minus) / (2.0 * epsilon);
315 }
316 }
317
318 Ok((param_grads, readout_grads))
319 }
320
321 pub fn update_params(
323 &mut self,
324 param_grads: &Array2<f64>,
325 readout_grads: &Array2<f64>,
326 lr: f64,
327 ) {
328 self.params = &self.params - &(param_grads * lr);
329 self.readout_weights = &self.readout_weights - &(readout_grads * lr);
330 }
331
332 fn apply_rotation(
334 &self,
335 state: &Array1<Complex64>,
336 qubit: usize,
337 rx: f64,
338 ry: f64,
339 rz: f64,
340 ) -> QuantRS2Result<Array1<Complex64>> {
341 let mut result = state.clone();
342 result = self.apply_rz_gate(&result, qubit, rz)?;
343 result = self.apply_ry_gate(&result, qubit, ry)?;
344 result = self.apply_rx_gate(&result, qubit, rx)?;
345 Ok(result)
346 }
347
348 fn apply_rx_gate(
349 &self,
350 state: &Array1<Complex64>,
351 qubit: usize,
352 angle: f64,
353 ) -> QuantRS2Result<Array1<Complex64>> {
354 let dim = state.len();
355 let mut new_state = Array1::zeros(dim);
356 let cos_half = Complex64::new((angle / 2.0).cos(), 0.0);
357 let sin_half = Complex64::new(0.0, -(angle / 2.0).sin());
358
359 for i in 0..dim {
360 let j = i ^ (1 << qubit);
361 new_state[i] = state[i] * cos_half + state[j] * sin_half;
362 }
363
364 Ok(new_state)
365 }
366
367 fn apply_ry_gate(
368 &self,
369 state: &Array1<Complex64>,
370 qubit: usize,
371 angle: f64,
372 ) -> QuantRS2Result<Array1<Complex64>> {
373 let dim = state.len();
374 let mut new_state = Array1::zeros(dim);
375 let cos_half = (angle / 2.0).cos();
376 let sin_half = (angle / 2.0).sin();
377
378 for i in 0..dim {
379 let bit = (i >> qubit) & 1;
380 let j = i ^ (1 << qubit);
381 if bit == 0 {
382 new_state[i] = state[i] * cos_half - state[j] * sin_half;
383 } else {
384 new_state[i] = state[i] * cos_half + state[j] * sin_half;
385 }
386 }
387
388 Ok(new_state)
389 }
390
391 fn apply_rz_gate(
392 &self,
393 state: &Array1<Complex64>,
394 qubit: usize,
395 angle: f64,
396 ) -> QuantRS2Result<Array1<Complex64>> {
397 let dim = state.len();
398 let mut new_state = state.clone();
399 let phase = Complex64::new((angle / 2.0).cos(), -(angle / 2.0).sin());
400
401 for i in 0..dim {
402 let bit = (i >> qubit) & 1;
403 new_state[i] = if bit == 1 {
404 new_state[i] * phase
405 } else {
406 new_state[i] * phase.conj()
407 };
408 }
409
410 Ok(new_state)
411 }
412
413 fn apply_cnot(
414 &self,
415 state: &Array1<Complex64>,
416 control: usize,
417 target: usize,
418 ) -> QuantRS2Result<Array1<Complex64>> {
419 let dim = state.len();
420 let mut new_state = state.clone();
421
422 for i in 0..dim {
423 let control_bit = (i >> control) & 1;
424 if control_bit == 1 {
425 let j = i ^ (1 << target);
426 if i < j {
427 let temp = new_state[i];
428 new_state[i] = new_state[j];
429 new_state[j] = temp;
430 }
431 }
432 }
433
434 Ok(new_state)
435 }
436
437 fn pauli_z_expectation(&self, state: &Array1<Complex64>, qubit: usize) -> QuantRS2Result<f64> {
438 let dim = state.len();
439 let mut expectation = 0.0;
440
441 for i in 0..dim {
442 let bit = (i >> qubit) & 1;
443 let sign = if bit == 0 { 1.0 } else { -1.0 };
444 expectation += sign * state[i].norm_sqr();
445 }
446
447 Ok(expectation)
448 }
449}
450
451#[derive(Debug, Clone)]
453pub struct QuantumMAML {
454 config: QuantumMetaLearningConfig,
456 meta_model: QuantumMetaCircuit,
458}
459
460impl QuantumMAML {
461 pub fn new(config: QuantumMetaLearningConfig) -> Self {
463 let meta_model =
464 QuantumMetaCircuit::new(config.num_qubits, config.circuit_depth, config.n_way);
465
466 Self { config, meta_model }
467 }
468
469 pub fn meta_train_step(&mut self, tasks: &[QuantumTask]) -> QuantRS2Result<f64> {
471 let mut meta_param_grads = Array2::zeros(self.meta_model.params.dim());
472 let mut meta_readout_grads = Array2::zeros(self.meta_model.readout_weights.dim());
473 let mut total_loss = 0.0;
474
475 for task in tasks {
476 let mut adapted_model = self.meta_model.clone();
478
479 for _ in 0..self.config.inner_steps {
480 let (param_grads, readout_grads) =
481 adapted_model.compute_gradients(&task.support_states, &task.support_labels)?;
482
483 adapted_model.update_params(¶m_grads, &readout_grads, self.config.inner_lr);
484 }
485
486 let query_loss = adapted_model.compute_loss(&task.query_states, &task.query_labels)?;
488 total_loss += query_loss;
489
490 let (param_grads, readout_grads) =
492 adapted_model.compute_gradients(&task.query_states, &task.query_labels)?;
493
494 meta_param_grads = meta_param_grads + param_grads;
495 meta_readout_grads = meta_readout_grads + readout_grads;
496 }
497
498 meta_param_grads = meta_param_grads / (tasks.len() as f64);
500 meta_readout_grads = meta_readout_grads / (tasks.len() as f64);
501
502 self.meta_model
504 .update_params(&meta_param_grads, &meta_readout_grads, self.config.outer_lr);
505
506 Ok(total_loss / tasks.len() as f64)
507 }
508
509 pub fn adapt(&self, task: &QuantumTask) -> QuantRS2Result<QuantumMetaCircuit> {
511 let mut adapted_model = self.meta_model.clone();
512
513 for _ in 0..self.config.inner_steps {
514 let (param_grads, readout_grads) =
515 adapted_model.compute_gradients(&task.support_states, &task.support_labels)?;
516
517 adapted_model.update_params(¶m_grads, &readout_grads, self.config.inner_lr);
518 }
519
520 Ok(adapted_model)
521 }
522
523 pub fn evaluate(&self, task: &QuantumTask) -> QuantRS2Result<f64> {
525 let adapted_model = self.adapt(task)?;
526
527 let mut correct = 0;
528 for (state, &label) in task.query_states.iter().zip(task.query_labels.iter()) {
529 let probs = adapted_model.forward(state)?;
530 let mut max_prob = f64::NEG_INFINITY;
531 let mut predicted = 0;
532
533 for (i, &prob) in probs.iter().enumerate() {
534 if prob > max_prob {
535 max_prob = prob;
536 predicted = i;
537 }
538 }
539
540 if predicted == label {
541 correct += 1;
542 }
543 }
544
545 Ok(correct as f64 / task.query_states.len() as f64)
546 }
547
548 pub fn meta_model(&self) -> &QuantumMetaCircuit {
550 &self.meta_model
551 }
552}
553
554#[derive(Debug, Clone)]
556pub struct QuantumReptile {
557 config: QuantumMetaLearningConfig,
559 meta_model: QuantumMetaCircuit,
561}
562
563impl QuantumReptile {
564 pub fn new(config: QuantumMetaLearningConfig) -> Self {
566 let meta_model =
567 QuantumMetaCircuit::new(config.num_qubits, config.circuit_depth, config.n_way);
568
569 Self { config, meta_model }
570 }
571
572 pub fn meta_train_step(&mut self, task: &QuantumTask) -> QuantRS2Result<f64> {
574 let mut adapted_model = self.meta_model.clone();
576
577 for _ in 0..self.config.inner_steps {
578 let (param_grads, readout_grads) =
579 adapted_model.compute_gradients(&task.support_states, &task.support_labels)?;
580
581 adapted_model.update_params(¶m_grads, &readout_grads, self.config.inner_lr);
582 }
583
584 let loss = adapted_model.compute_loss(&task.query_states, &task.query_labels)?;
586
587 let param_diff = &adapted_model.params - &self.meta_model.params;
589 let readout_diff = &adapted_model.readout_weights - &self.meta_model.readout_weights;
590
591 self.meta_model.params = &self.meta_model.params + &(param_diff * self.config.outer_lr);
592 self.meta_model.readout_weights =
593 &self.meta_model.readout_weights + &(readout_diff * self.config.outer_lr);
594
595 Ok(loss)
596 }
597
598 pub fn adapt(&self, task: &QuantumTask) -> QuantRS2Result<QuantumMetaCircuit> {
600 let mut adapted_model = self.meta_model.clone();
601
602 for _ in 0..self.config.inner_steps {
603 let (param_grads, readout_grads) =
604 adapted_model.compute_gradients(&task.support_states, &task.support_labels)?;
605
606 adapted_model.update_params(¶m_grads, &readout_grads, self.config.inner_lr);
607 }
608
609 Ok(adapted_model)
610 }
611}
612
#[cfg(test)]
mod tests {
    use super::*;

    /// The circuit maps |000⟩ to a normalised distribution over two classes.
    #[test]
    fn test_quantum_meta_circuit() {
        let circuit = QuantumMetaCircuit::new(3, 2, 2);

        // |000⟩: unit amplitude on basis state 0, zero on the other seven.
        let mut state = Array1::from_elem(8, Complex64::new(0.0, 0.0));
        state[0] = Complex64::new(1.0, 0.0);

        let probs = circuit.forward(&state).unwrap();
        assert_eq!(probs.len(), 2);

        let total = probs.iter().sum::<f64>();
        assert!((total - 1.0).abs() < 1e-6);
    }

    /// A model adapted by MAML produces one probability per class.
    #[test]
    fn test_quantum_maml() {
        // Defaults supply inner_lr = 0.01, outer_lr = 0.001 and n_way = 2.
        let config = QuantumMetaLearningConfig {
            num_qubits: 2,
            circuit_depth: 2,
            inner_steps: 3,
            n_support: 2,
            n_query: 5,
            meta_batch_size: 2,
            ..Default::default()
        };

        let maml = QuantumMAML::new(config.clone());
        let task = QuantumTask::random(
            config.num_qubits,
            config.n_way,
            config.n_support,
            config.n_query,
        );

        let adapted = maml.adapt(&task).unwrap();
        let first_query = &task.query_states[0];
        let probs = adapted.forward(first_query).unwrap();

        assert_eq!(probs.len(), config.n_way);
    }
}