1use crate::{
33 error::{QuantRS2Error, QuantRS2Result},
34 gate::GateOp,
35 qubit::QubitId,
36};
37use scirs2_core::ndarray::{Array1, Array2, Axis};
38use scirs2_core::random::prelude::*;
39use scirs2_core::Complex64;
40use std::f64::consts::PI;
41
/// Hyper-parameters shared by the quantum meta-learners in this module
/// (`QuantumMAML` and `QuantumReptile`).
#[derive(Debug, Clone)]
pub struct QuantumMetaLearningConfig {
    /// Number of qubits of the variational circuit.
    pub num_qubits: usize,
    /// Number of rotation + entangling layers in the circuit.
    pub circuit_depth: usize,
    /// Step size of the task-specific (inner-loop) gradient updates.
    pub inner_lr: f64,
    /// Step size of the meta (outer-loop) update.
    pub outer_lr: f64,
    /// Number of inner-loop gradient steps taken when adapting to a task.
    pub inner_steps: usize,
    /// Support examples per class (as passed to `QuantumTask::random`).
    pub n_support: usize,
    /// Query examples per class (as passed to `QuantumTask::random`).
    pub n_query: usize,
    /// Number of classes per task; also the circuit's number of outputs.
    pub n_way: usize,
    /// Intended number of tasks per meta-batch. Not read by the learners
    /// themselves; callers choose how many tasks to pass per step.
    pub meta_batch_size: usize,
}

impl Default for QuantumMetaLearningConfig {
    /// A small 4-qubit, 4-layer, 2-way setup with 5 support / 15 query
    /// examples per class, 5 inner steps and a meta-batch of 4 tasks.
    fn default() -> Self {
        let (inner_lr, outer_lr) = (0.01, 0.001);
        Self {
            num_qubits: 4,
            circuit_depth: 4,
            n_way: 2,
            n_support: 5,
            n_query: 15,
            inner_steps: 5,
            inner_lr,
            outer_lr,
            meta_batch_size: 4,
        }
    }
}
80
81#[derive(Debug, Clone)]
83pub struct QuantumTask {
84 pub support_states: Vec<Array1<Complex64>>,
86 pub support_labels: Vec<usize>,
88 pub query_states: Vec<Array1<Complex64>>,
90 pub query_labels: Vec<usize>,
92}
93
94impl QuantumTask {
95 pub const fn new(
97 support_states: Vec<Array1<Complex64>>,
98 support_labels: Vec<usize>,
99 query_states: Vec<Array1<Complex64>>,
100 query_labels: Vec<usize>,
101 ) -> Self {
102 Self {
103 support_states,
104 support_labels,
105 query_states,
106 query_labels,
107 }
108 }
109
110 pub fn random(num_qubits: usize, n_way: usize, n_support: usize, n_query: usize) -> Self {
112 let mut rng = thread_rng();
113 let dim = 1 << num_qubits;
114
115 let mut support_states = Vec::new();
116 let mut support_labels = Vec::new();
117 let mut query_states = Vec::new();
118 let mut query_labels = Vec::new();
119
120 for class in 0..n_way {
121 let mut prototype = Array1::from_shape_fn(dim, |_| {
123 Complex64::new(rng.random_range(-1.0..1.0), rng.random_range(-1.0..1.0))
124 });
125 let norm: f64 = prototype.iter().map(|x| x.norm_sqr()).sum::<f64>().sqrt();
126 for i in 0..dim {
127 prototype[i] = prototype[i] / norm;
128 }
129
130 for _ in 0..n_support {
132 let mut state = prototype.clone();
133 for i in 0..dim {
135 state[i] = state[i]
136 + Complex64::new(rng.random_range(-0.1..0.1), rng.random_range(-0.1..0.1));
137 }
138 let norm: f64 = state.iter().map(|x| x.norm_sqr()).sum::<f64>().sqrt();
139 for i in 0..dim {
140 state[i] = state[i] / norm;
141 }
142 support_states.push(state);
143 support_labels.push(class);
144 }
145
146 for _ in 0..n_query {
148 let mut state = prototype.clone();
149 for i in 0..dim {
151 state[i] = state[i]
152 + Complex64::new(rng.random_range(-0.1..0.1), rng.random_range(-0.1..0.1));
153 }
154 let norm: f64 = state.iter().map(|x| x.norm_sqr()).sum::<f64>().sqrt();
155 for i in 0..dim {
156 state[i] = state[i] / norm;
157 }
158 query_states.push(state);
159 query_labels.push(class);
160 }
161 }
162
163 Self {
164 support_states,
165 support_labels,
166 query_states,
167 query_labels,
168 }
169 }
170}
171
172#[derive(Debug, Clone)]
174pub struct QuantumMetaCircuit {
175 num_qubits: usize,
177 depth: usize,
179 num_classes: usize,
181 params: Array2<f64>,
183 readout_weights: Array2<f64>,
185}
186
187impl QuantumMetaCircuit {
188 pub fn new(num_qubits: usize, depth: usize, num_classes: usize) -> Self {
190 let mut rng = thread_rng();
191
192 let params = Array2::from_shape_fn((depth, num_qubits * 3), |_| rng.random_range(-PI..PI));
193
194 let scale = (2.0 / num_qubits as f64).sqrt();
195 let readout_weights = Array2::from_shape_fn((num_classes, num_qubits), |_| {
196 rng.random_range(-scale..scale)
197 });
198
199 Self {
200 num_qubits,
201 depth,
202 num_classes,
203 params,
204 readout_weights,
205 }
206 }
207
208 pub fn forward(&self, state: &Array1<Complex64>) -> QuantRS2Result<Array1<f64>> {
210 let mut encoded = state.clone();
212
213 for layer in 0..self.depth {
214 for q in 0..self.num_qubits {
216 let rx = self.params[[layer, q * 3]];
217 let ry = self.params[[layer, q * 3 + 1]];
218 let rz = self.params[[layer, q * 3 + 2]];
219
220 encoded = self.apply_rotation(&encoded, q, rx, ry, rz)?;
221 }
222
223 for q in 0..self.num_qubits - 1 {
225 encoded = self.apply_cnot(&encoded, q, q + 1)?;
226 }
227 }
228
229 let mut expectations = Array1::zeros(self.num_qubits);
231 for q in 0..self.num_qubits {
232 expectations[q] = self.pauli_z_expectation(&encoded, q)?;
233 }
234
235 let mut logits = Array1::zeros(self.num_classes);
237 for i in 0..self.num_classes {
238 for j in 0..self.num_qubits {
239 logits[i] += self.readout_weights[[i, j]] * expectations[j];
240 }
241 }
242
243 let max_logit = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
245 let mut probs = Array1::zeros(self.num_classes);
246 let mut sum_exp = 0.0;
247
248 for i in 0..self.num_classes {
249 probs[i] = (logits[i] - max_logit).exp();
250 sum_exp += probs[i];
251 }
252
253 for i in 0..self.num_classes {
254 probs[i] /= sum_exp;
255 }
256
257 Ok(probs)
258 }
259
260 pub fn compute_loss(
262 &self,
263 states: &[Array1<Complex64>],
264 labels: &[usize],
265 ) -> QuantRS2Result<f64> {
266 let mut total_loss = 0.0;
267
268 for (state, &label) in states.iter().zip(labels.iter()) {
269 let probs = self.forward(state)?;
270 total_loss -= probs[label].ln();
272 }
273
274 Ok(total_loss / states.len() as f64)
275 }
276
277 pub fn compute_gradients(
279 &self,
280 states: &[Array1<Complex64>],
281 labels: &[usize],
282 ) -> QuantRS2Result<(Array2<f64>, Array2<f64>)> {
283 let epsilon = 1e-4;
284
285 let mut param_grads = Array2::zeros(self.params.dim());
287
288 for i in 0..self.params.shape()[0] {
289 for j in 0..self.params.shape()[1] {
290 let mut circuit_plus = self.clone();
291 circuit_plus.params[[i, j]] += epsilon;
292 let loss_plus = circuit_plus.compute_loss(states, labels)?;
293
294 let mut circuit_minus = self.clone();
295 circuit_minus.params[[i, j]] -= epsilon;
296 let loss_minus = circuit_minus.compute_loss(states, labels)?;
297
298 param_grads[[i, j]] = (loss_plus - loss_minus) / (2.0 * epsilon);
299 }
300 }
301
302 let mut readout_grads = Array2::zeros(self.readout_weights.dim());
304
305 for i in 0..self.readout_weights.shape()[0] {
306 for j in 0..self.readout_weights.shape()[1] {
307 let mut circuit_plus = self.clone();
308 circuit_plus.readout_weights[[i, j]] += epsilon;
309 let loss_plus = circuit_plus.compute_loss(states, labels)?;
310
311 let mut circuit_minus = self.clone();
312 circuit_minus.readout_weights[[i, j]] -= epsilon;
313 let loss_minus = circuit_minus.compute_loss(states, labels)?;
314
315 readout_grads[[i, j]] = (loss_plus - loss_minus) / (2.0 * epsilon);
316 }
317 }
318
319 Ok((param_grads, readout_grads))
320 }
321
322 pub fn update_params(
324 &mut self,
325 param_grads: &Array2<f64>,
326 readout_grads: &Array2<f64>,
327 lr: f64,
328 ) {
329 self.params = &self.params - &(param_grads * lr);
330 self.readout_weights = &self.readout_weights - &(readout_grads * lr);
331 }
332
333 fn apply_rotation(
335 &self,
336 state: &Array1<Complex64>,
337 qubit: usize,
338 rx: f64,
339 ry: f64,
340 rz: f64,
341 ) -> QuantRS2Result<Array1<Complex64>> {
342 let mut result = state.clone();
343 result = self.apply_rz_gate(&result, qubit, rz)?;
344 result = self.apply_ry_gate(&result, qubit, ry)?;
345 result = self.apply_rx_gate(&result, qubit, rx)?;
346 Ok(result)
347 }
348
349 fn apply_rx_gate(
350 &self,
351 state: &Array1<Complex64>,
352 qubit: usize,
353 angle: f64,
354 ) -> QuantRS2Result<Array1<Complex64>> {
355 let dim = state.len();
356 let mut new_state = Array1::zeros(dim);
357 let cos_half = Complex64::new((angle / 2.0).cos(), 0.0);
358 let sin_half = Complex64::new(0.0, -(angle / 2.0).sin());
359
360 for i in 0..dim {
361 let j = i ^ (1 << qubit);
362 new_state[i] = state[i] * cos_half + state[j] * sin_half;
363 }
364
365 Ok(new_state)
366 }
367
368 fn apply_ry_gate(
369 &self,
370 state: &Array1<Complex64>,
371 qubit: usize,
372 angle: f64,
373 ) -> QuantRS2Result<Array1<Complex64>> {
374 let dim = state.len();
375 let mut new_state = Array1::zeros(dim);
376 let cos_half = (angle / 2.0).cos();
377 let sin_half = (angle / 2.0).sin();
378
379 for i in 0..dim {
380 let bit = (i >> qubit) & 1;
381 let j = i ^ (1 << qubit);
382 if bit == 0 {
383 new_state[i] = state[i] * cos_half - state[j] * sin_half;
384 } else {
385 new_state[i] = state[i] * cos_half + state[j] * sin_half;
386 }
387 }
388
389 Ok(new_state)
390 }
391
392 fn apply_rz_gate(
393 &self,
394 state: &Array1<Complex64>,
395 qubit: usize,
396 angle: f64,
397 ) -> QuantRS2Result<Array1<Complex64>> {
398 let dim = state.len();
399 let mut new_state = state.clone();
400 let phase = Complex64::new((angle / 2.0).cos(), -(angle / 2.0).sin());
401
402 for i in 0..dim {
403 let bit = (i >> qubit) & 1;
404 new_state[i] = if bit == 1 {
405 new_state[i] * phase
406 } else {
407 new_state[i] * phase.conj()
408 };
409 }
410
411 Ok(new_state)
412 }
413
414 fn apply_cnot(
415 &self,
416 state: &Array1<Complex64>,
417 control: usize,
418 target: usize,
419 ) -> QuantRS2Result<Array1<Complex64>> {
420 let dim = state.len();
421 let mut new_state = state.clone();
422
423 for i in 0..dim {
424 let control_bit = (i >> control) & 1;
425 if control_bit == 1 {
426 let j = i ^ (1 << target);
427 if i < j {
428 let temp = new_state[i];
429 new_state[i] = new_state[j];
430 new_state[j] = temp;
431 }
432 }
433 }
434
435 Ok(new_state)
436 }
437
438 fn pauli_z_expectation(&self, state: &Array1<Complex64>, qubit: usize) -> QuantRS2Result<f64> {
439 let dim = state.len();
440 let mut expectation = 0.0;
441
442 for i in 0..dim {
443 let bit = (i >> qubit) & 1;
444 let sign = if bit == 0 { 1.0 } else { -1.0 };
445 expectation += sign * state[i].norm_sqr();
446 }
447
448 Ok(expectation)
449 }
450}
451
452#[derive(Debug, Clone)]
454pub struct QuantumMAML {
455 config: QuantumMetaLearningConfig,
457 meta_model: QuantumMetaCircuit,
459}
460
461impl QuantumMAML {
462 pub fn new(config: QuantumMetaLearningConfig) -> Self {
464 let meta_model =
465 QuantumMetaCircuit::new(config.num_qubits, config.circuit_depth, config.n_way);
466
467 Self { config, meta_model }
468 }
469
470 pub fn meta_train_step(&mut self, tasks: &[QuantumTask]) -> QuantRS2Result<f64> {
472 let mut meta_param_grads = Array2::zeros(self.meta_model.params.dim());
473 let mut meta_readout_grads = Array2::zeros(self.meta_model.readout_weights.dim());
474 let mut total_loss = 0.0;
475
476 for task in tasks {
477 let mut adapted_model = self.meta_model.clone();
479
480 for _ in 0..self.config.inner_steps {
481 let (param_grads, readout_grads) =
482 adapted_model.compute_gradients(&task.support_states, &task.support_labels)?;
483
484 adapted_model.update_params(¶m_grads, &readout_grads, self.config.inner_lr);
485 }
486
487 let query_loss = adapted_model.compute_loss(&task.query_states, &task.query_labels)?;
489 total_loss += query_loss;
490
491 let (param_grads, readout_grads) =
493 adapted_model.compute_gradients(&task.query_states, &task.query_labels)?;
494
495 meta_param_grads = meta_param_grads + param_grads;
496 meta_readout_grads = meta_readout_grads + readout_grads;
497 }
498
499 meta_param_grads = meta_param_grads / (tasks.len() as f64);
501 meta_readout_grads = meta_readout_grads / (tasks.len() as f64);
502
503 self.meta_model
505 .update_params(&meta_param_grads, &meta_readout_grads, self.config.outer_lr);
506
507 Ok(total_loss / tasks.len() as f64)
508 }
509
510 pub fn adapt(&self, task: &QuantumTask) -> QuantRS2Result<QuantumMetaCircuit> {
512 let mut adapted_model = self.meta_model.clone();
513
514 for _ in 0..self.config.inner_steps {
515 let (param_grads, readout_grads) =
516 adapted_model.compute_gradients(&task.support_states, &task.support_labels)?;
517
518 adapted_model.update_params(¶m_grads, &readout_grads, self.config.inner_lr);
519 }
520
521 Ok(adapted_model)
522 }
523
524 pub fn evaluate(&self, task: &QuantumTask) -> QuantRS2Result<f64> {
526 let adapted_model = self.adapt(task)?;
527
528 let mut correct = 0;
529 for (state, &label) in task.query_states.iter().zip(task.query_labels.iter()) {
530 let probs = adapted_model.forward(state)?;
531 let mut max_prob = f64::NEG_INFINITY;
532 let mut predicted = 0;
533
534 for (i, &prob) in probs.iter().enumerate() {
535 if prob > max_prob {
536 max_prob = prob;
537 predicted = i;
538 }
539 }
540
541 if predicted == label {
542 correct += 1;
543 }
544 }
545
546 Ok(correct as f64 / task.query_states.len() as f64)
547 }
548
549 pub const fn meta_model(&self) -> &QuantumMetaCircuit {
551 &self.meta_model
552 }
553}
554
555#[derive(Debug, Clone)]
557pub struct QuantumReptile {
558 config: QuantumMetaLearningConfig,
560 meta_model: QuantumMetaCircuit,
562}
563
564impl QuantumReptile {
565 pub fn new(config: QuantumMetaLearningConfig) -> Self {
567 let meta_model =
568 QuantumMetaCircuit::new(config.num_qubits, config.circuit_depth, config.n_way);
569
570 Self { config, meta_model }
571 }
572
573 pub fn meta_train_step(&mut self, task: &QuantumTask) -> QuantRS2Result<f64> {
575 let mut adapted_model = self.meta_model.clone();
577
578 for _ in 0..self.config.inner_steps {
579 let (param_grads, readout_grads) =
580 adapted_model.compute_gradients(&task.support_states, &task.support_labels)?;
581
582 adapted_model.update_params(¶m_grads, &readout_grads, self.config.inner_lr);
583 }
584
585 let loss = adapted_model.compute_loss(&task.query_states, &task.query_labels)?;
587
588 let param_diff = &adapted_model.params - &self.meta_model.params;
590 let readout_diff = &adapted_model.readout_weights - &self.meta_model.readout_weights;
591
592 self.meta_model.params = &self.meta_model.params + &(param_diff * self.config.outer_lr);
593 self.meta_model.readout_weights =
594 &self.meta_model.readout_weights + &(readout_diff * self.config.outer_lr);
595
596 Ok(loss)
597 }
598
599 pub fn adapt(&self, task: &QuantumTask) -> QuantRS2Result<QuantumMetaCircuit> {
601 let mut adapted_model = self.meta_model.clone();
602
603 for _ in 0..self.config.inner_steps {
604 let (param_grads, readout_grads) =
605 adapted_model.compute_gradients(&task.support_states, &task.support_labels)?;
606
607 adapted_model.update_params(¶m_grads, &readout_grads, self.config.inner_lr);
608 }
609
610 Ok(adapted_model)
611 }
612}
613
#[cfg(test)]
mod tests {
    use super::*;

    /// Small configuration that keeps the finite-difference training loops fast.
    fn small_config() -> QuantumMetaLearningConfig {
        QuantumMetaLearningConfig {
            num_qubits: 2,
            circuit_depth: 2,
            inner_lr: 0.01,
            outer_lr: 0.001,
            inner_steps: 3,
            n_support: 2,
            n_query: 5,
            n_way: 2,
            meta_batch_size: 2,
        }
    }

    /// Random task sized according to `config`.
    fn random_task(config: &QuantumMetaLearningConfig) -> QuantumTask {
        QuantumTask::random(
            config.num_qubits,
            config.n_way,
            config.n_support,
            config.n_query,
        )
    }

    #[test]
    fn test_quantum_meta_circuit() {
        let circuit = QuantumMetaCircuit::new(3, 2, 2);

        // |000⟩ basis state.
        let state = Array1::from_shape_fn(8, |i| {
            if i == 0 {
                Complex64::new(1.0, 0.0)
            } else {
                Complex64::new(0.0, 0.0)
            }
        });

        let probs = circuit
            .forward(&state)
            .expect("forward pass should succeed");
        assert_eq!(probs.len(), 2);

        let sum: f64 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_random_task_shapes_and_normalisation() {
        let config = small_config();
        let task = random_task(&config);

        assert_eq!(task.support_states.len(), config.n_way * config.n_support);
        assert_eq!(task.support_labels.len(), config.n_way * config.n_support);
        assert_eq!(task.query_states.len(), config.n_way * config.n_query);
        assert_eq!(task.query_labels.len(), config.n_way * config.n_query);

        for state in task.support_states.iter().chain(&task.query_states) {
            assert_eq!(state.len(), 1 << config.num_qubits);
            let norm_sqr: f64 = state.iter().map(|amp| amp.norm_sqr()).sum();
            assert!((norm_sqr - 1.0).abs() < 1e-9);
        }

        assert!(task
            .support_labels
            .iter()
            .chain(&task.query_labels)
            .all(|&label| label < config.n_way));
    }

    #[test]
    fn test_quantum_maml() {
        let config = small_config();
        let maml = QuantumMAML::new(config.clone());
        let task = random_task(&config);

        let adapted_model = maml.adapt(&task).expect("MAML adaptation should succeed");
        let probs = adapted_model
            .forward(&task.query_states[0])
            .expect("adapted model forward pass should succeed");

        assert_eq!(probs.len(), config.n_way);
    }

    #[test]
    fn test_maml_meta_train_step_and_evaluate() {
        let config = small_config();
        let mut maml = QuantumMAML::new(config.clone());
        let tasks = vec![random_task(&config), random_task(&config)];

        let loss = maml
            .meta_train_step(&tasks)
            .expect("MAML meta-training step should succeed");
        assert!(loss.is_finite() && loss >= 0.0);

        let accuracy = maml
            .evaluate(&tasks[0])
            .expect("MAML evaluation should succeed");
        assert!((0.0..=1.0).contains(&accuracy));
    }

    #[test]
    fn test_quantum_reptile() {
        let config = small_config();
        let mut reptile = QuantumReptile::new(config.clone());
        let task = random_task(&config);

        let loss = reptile
            .meta_train_step(&task)
            .expect("Reptile meta-training step should succeed");
        assert!(loss.is_finite() && loss >= 0.0);

        let adapted_model = reptile
            .adapt(&task)
            .expect("Reptile adaptation should succeed");
        let probs = adapted_model
            .forward(&task.query_states[0])
            .expect("adapted model forward pass should succeed");
        assert_eq!(probs.len(), config.n_way);
    }
}