use crate::autodiff::optimizers::Optimizer;
use crate::error::{MLError, Result};
use crate::optimization::OptimizationMethod;
use crate::qnn::{QNNLayerType, QuantumNeuralNetwork};
use quantrs2_circuit::builder::{Circuit, Simulator};
use quantrs2_core::gate::{
    single::{RotationX, RotationY, RotationZ},
    GateOp,
};
use quantrs2_sim::statevector::StateVectorSimulator;
use scirs2_core::ndarray::{s, Array1, Array2, Array3, Axis};
use scirs2_core::random::prelude::*;
use std::collections::HashMap;

/// Meta-learning algorithm variants.
#[derive(Debug, Clone, Copy)]
pub enum MetaLearningAlgorithm {
    /// Model-Agnostic Meta-Learning (MAML).
    MAML {
        inner_steps: usize,
        inner_lr: f64,
        first_order: bool,
    },

    /// Reptile: first-order meta-learning via parameter interpolation.
    Reptile { inner_steps: usize, inner_lr: f64 },

    /// MAML combined with a prototype-based regularizer.
    ProtoMAML {
        inner_steps: usize,
        inner_lr: f64,
        proto_weight: f64,
    },

    /// Meta-SGD: learns a per-parameter inner learning rate.
    MetaSGD { inner_steps: usize },

    /// ANIL: adapts only the final layers in the inner loop.
    ANIL { inner_steps: usize, inner_lr: f64 },
}

/// A single meta-learning task with support and query sets.
#[derive(Debug, Clone)]
pub struct MetaTask {
    /// Unique task identifier.
    pub id: String,

    /// Support (training) examples as (input, label) pairs.
    pub train_data: Vec<(Array1<f64>, usize)>,

    /// Query (test) examples as (input, label) pairs.
    pub test_data: Vec<(Array1<f64>, usize)>,

    /// Number of classes in this task.
    pub num_classes: usize,

    /// Task-level metadata (e.g. the generating parameters).
    pub metadata: HashMap<String, f64>,
}

/// Quantum meta-learner pairing a QNN with a meta-learning algorithm.
pub struct QuantumMetaLearner {
    /// Meta-learning algorithm in use.
    algorithm: MetaLearningAlgorithm,

    /// Underlying quantum neural network.
    model: QuantumNeuralNetwork,

    /// Meta-learned initialization parameters.
    meta_params: Array1<f64>,

    /// Per-parameter learning rates (Meta-SGD only).
    per_param_lr: Option<Array1<f64>>,

    /// Learned task embeddings keyed by task id.
    task_embeddings: HashMap<String, Array1<f64>>,

    /// Meta-training history.
    history: MetaLearningHistory,
}

/// Records of the meta-training process.
#[derive(Debug, Clone)]
pub struct MetaLearningHistory {
    /// Meta-training loss per epoch.
    pub meta_train_losses: Vec<f64>,

    /// Meta-validation accuracy per epoch.
    pub meta_val_accuracies: Vec<f64>,

    /// Per-task performance traces keyed by task id.
    pub task_performance: HashMap<String, Vec<f64>>,
}

impl QuantumMetaLearner {
    /// Creates a new meta-learner from an algorithm and a QNN model.
    pub fn new(algorithm: MetaLearningAlgorithm, model: QuantumNeuralNetwork) -> Self {
        let num_params = model.parameters.len();
        let meta_params = model.parameters.clone();

        // Only Meta-SGD learns a per-parameter inner learning rate.
        let per_param_lr = match algorithm {
            MetaLearningAlgorithm::MetaSGD { .. } => Some(Array1::from_elem(num_params, 0.01)),
            _ => None,
        };

        Self {
            algorithm,
            model,
            meta_params,
            per_param_lr,
            task_embeddings: HashMap::new(),
            history: MetaLearningHistory {
                meta_train_losses: Vec::new(),
                meta_val_accuracies: Vec::new(),
                task_performance: HashMap::new(),
            },
        }
    }

    /// Runs the outer meta-training loop over batches of tasks.
    pub fn meta_train(
        &mut self,
        tasks: &[MetaTask],
        meta_optimizer: &mut dyn Optimizer,
        meta_epochs: usize,
        tasks_per_batch: usize,
    ) -> Result<()> {
        println!("Starting meta-training with {} tasks...", tasks.len());

        for epoch in 0..meta_epochs {
            let mut epoch_loss = 0.0;
            let mut epoch_acc = 0.0;

            // Sample a batch of tasks for this meta-update.
            let task_batch = self.sample_task_batch(tasks, tasks_per_batch);

            // Dispatch to the algorithm-specific meta-update.
            match self.algorithm {
                MetaLearningAlgorithm::MAML { .. } => {
                    let (loss, acc) = self.maml_update(&task_batch, meta_optimizer)?;
                    epoch_loss += loss;
                    epoch_acc += acc;
                }
                MetaLearningAlgorithm::Reptile { .. } => {
                    let (loss, acc) = self.reptile_update(&task_batch, meta_optimizer)?;
                    epoch_loss += loss;
                    epoch_acc += acc;
                }
                MetaLearningAlgorithm::ProtoMAML { .. } => {
                    let (loss, acc) = self.protomaml_update(&task_batch, meta_optimizer)?;
                    epoch_loss += loss;
                    epoch_acc += acc;
                }
                MetaLearningAlgorithm::MetaSGD { .. } => {
                    let (loss, acc) = self.metasgd_update(&task_batch, meta_optimizer)?;
                    epoch_loss += loss;
                    epoch_acc += acc;
                }
                MetaLearningAlgorithm::ANIL { .. } => {
                    let (loss, acc) = self.anil_update(&task_batch, meta_optimizer)?;
                    epoch_loss += loss;
                    epoch_acc += acc;
                }
            }

            self.history.meta_train_losses.push(epoch_loss);
            self.history.meta_val_accuracies.push(epoch_acc);

            if epoch % 10 == 0 {
                println!(
                    "Epoch {}: Loss = {:.4}, Accuracy = {:.2}%",
                    epoch,
                    epoch_loss,
                    epoch_acc * 100.0
                );
            }
        }

        Ok(())
    }

    /// MAML meta-update: adapt on each task's support set, evaluate on the
    /// query set, and accumulate meta-gradients at the initialization.
    fn maml_update(
        &mut self,
        tasks: &[MetaTask],
        optimizer: &mut dyn Optimizer,
    ) -> Result<(f64, f64)> {
        let (inner_steps, inner_lr, first_order) = match self.algorithm {
            MetaLearningAlgorithm::MAML {
                inner_steps,
                inner_lr,
                first_order,
            } => (inner_steps, inner_lr, first_order),
            _ => unreachable!(),
        };

        let mut total_loss = 0.0;
        let mut total_acc = 0.0;
        let mut meta_gradients = Array1::zeros(self.meta_params.len());

        for task in tasks {
            // Inner-loop adaptation starting from the meta-parameters.
            let mut task_params = self.meta_params.clone();

            for _ in 0..inner_steps {
                let grad = self.compute_task_gradient(&task.train_data, &task_params)?;
                task_params = task_params - inner_lr * &grad;
            }

            // Evaluate the adapted parameters on the query set.
            let (query_loss, query_acc) = self.evaluate_task(&task.test_data, &task_params)?;
            total_loss += query_loss;
            total_acc += query_acc;

            if !first_order {
                // Second-order MAML: differentiate through the inner loop.
                let meta_grad = self.compute_maml_gradient(task, &task_params, inner_lr)?;
                meta_gradients = meta_gradients + meta_grad;
            } else {
                // First-order approximation: use the query gradient directly.
                let grad = self.compute_task_gradient(&task.test_data, &task_params)?;
                meta_gradients = meta_gradients + grad;
            }
        }

        // Apply the averaged meta-gradient with a fixed meta learning rate.
        meta_gradients = meta_gradients / tasks.len() as f64;
        self.meta_params = self.meta_params.clone() - 0.001 * &meta_gradients;

        Ok((
            total_loss / tasks.len() as f64,
            total_acc / tasks.len() as f64,
        ))
    }

    /// Reptile meta-update: move the meta-parameters a fraction of the way
    /// toward each task's adapted parameters.
    fn reptile_update(
        &mut self,
        tasks: &[MetaTask],
        optimizer: &mut dyn Optimizer,
    ) -> Result<(f64, f64)> {
        let (inner_steps, inner_lr) = match self.algorithm {
            MetaLearningAlgorithm::Reptile {
                inner_steps,
                inner_lr,
            } => (inner_steps, inner_lr),
            _ => unreachable!(),
        };

        let mut total_loss = 0.0;
        let mut total_acc = 0.0;
        // Interpolation step size toward the adapted parameters.
        let epsilon = 0.1;

        for task in tasks {
            let mut task_params = self.meta_params.clone();

            for _ in 0..inner_steps {
                let grad = self.compute_task_gradient(&task.train_data, &task_params)?;
                task_params = task_params - inner_lr * &grad;
            }

            let (loss, acc) = self.evaluate_task(&task.test_data, &task_params)?;
            total_loss += loss;
            total_acc += acc;

            // Interpolate the meta-parameters toward the task solution.
            let direction = &task_params - &self.meta_params;
            self.meta_params = &self.meta_params + epsilon * &direction;
        }

        Ok((
            total_loss / tasks.len() as f64,
            total_acc / tasks.len() as f64,
        ))
    }

    /// ProtoMAML meta-update: MAML adaptation with a prototype-based
    /// regularizer applied in the inner loop.
    fn protomaml_update(
        &mut self,
        tasks: &[MetaTask],
        optimizer: &mut dyn Optimizer,
    ) -> Result<(f64, f64)> {
        let (inner_steps, inner_lr, proto_weight) = match self.algorithm {
            MetaLearningAlgorithm::ProtoMAML {
                inner_steps,
                inner_lr,
                proto_weight,
            } => (inner_steps, inner_lr, proto_weight),
            _ => unreachable!(),
        };

        let mut total_loss = 0.0;
        let mut total_acc = 0.0;

        for task in tasks {
            // Class prototypes computed from the support set.
            let prototypes = self.compute_prototypes(&task.train_data, task.num_classes)?;

            let mut task_params = self.meta_params.clone();

            for _ in 0..inner_steps {
                let grad = self.compute_task_gradient(&task.train_data, &task_params)?;
                let proto_reg =
                    self.prototype_regularization(&task.train_data, &prototypes, &task_params)?;
                task_params = task_params - inner_lr * (&grad + proto_weight * &proto_reg);
            }

            let (loss, acc) =
                self.evaluate_with_prototypes(&task.test_data, &prototypes, &task_params)?;
            total_loss += loss;
            total_acc += acc;
        }

        Ok((
            total_loss / tasks.len() as f64,
            total_acc / tasks.len() as f64,
        ))
    }

    /// Meta-SGD update: adapt with learned per-parameter learning rates and
    /// update those learning rates in the outer loop.
    fn metasgd_update(
        &mut self,
        tasks: &[MetaTask],
        optimizer: &mut dyn Optimizer,
    ) -> Result<(f64, f64)> {
        let inner_steps = match self.algorithm {
            MetaLearningAlgorithm::MetaSGD { inner_steps } => inner_steps,
            _ => unreachable!(),
        };

        let mut total_loss = 0.0;
        let mut total_acc = 0.0;
        let mut meta_lr_gradients = Array1::zeros(self.meta_params.len());

        for task in tasks {
            let mut task_params = self.meta_params.clone();

            for _ in 0..inner_steps {
                let grad = self.compute_task_gradient(&task.train_data, &task_params)?;
                let lr = self.per_param_lr.as_ref().unwrap();
                task_params = task_params - lr * &grad;
            }

            let (loss, acc) = self.evaluate_task(&task.test_data, &task_params)?;
            total_loss += loss;
            total_acc += acc;

            // Accumulate gradients with respect to the per-parameter learning rates.
            let lr_grad = self.compute_lr_gradient(task, &task_params)?;
            meta_lr_gradients = meta_lr_gradients + lr_grad;
        }

        // Update the learned learning rates with a fixed meta step size.
        if let Some(ref mut lr) = self.per_param_lr {
            *lr = lr.clone() - &(0.001 * &meta_lr_gradients / tasks.len() as f64);
        }

        Ok((
            total_loss / tasks.len() as f64,
            total_acc / tasks.len() as f64,
        ))
    }

    /// ANIL update: only the final portion of the parameters is adapted in
    /// the inner loop; the remaining "body" is shared across tasks.
    fn anil_update(
        &mut self,
        tasks: &[MetaTask],
        optimizer: &mut dyn Optimizer,
    ) -> Result<(f64, f64)> {
        let (inner_steps, inner_lr) = match self.algorithm {
            MetaLearningAlgorithm::ANIL {
                inner_steps,
                inner_lr,
            } => (inner_steps, inner_lr),
            _ => unreachable!(),
        };

        let num_params = self.meta_params.len();
        // Treat the last quarter of the parameters as the task-specific head.
        let final_layer_start = (num_params * 3) / 4;

        let mut total_loss = 0.0;
        let mut total_acc = 0.0;

        for task in tasks {
            let mut task_params = self.meta_params.clone();

            for _ in 0..inner_steps {
                let grad = self.compute_task_gradient(&task.train_data, &task_params)?;

                // Only adapt the head parameters.
                for i in final_layer_start..num_params {
                    task_params[i] -= inner_lr * grad[i];
                }
            }

            let (loss, acc) = self.evaluate_task(&task.test_data, &task_params)?;
            total_loss += loss;
            total_acc += acc;
        }

        Ok((
            total_loss / tasks.len() as f64,
            total_acc / tasks.len() as f64,
        ))
    }

    /// Computes the gradient of the task loss with respect to the circuit
    /// parameters. Placeholder: returns zeros; a full implementation would
    /// use parameter-shift gradients on the quantum circuit.
    fn compute_task_gradient(
        &self,
        data: &[(Array1<f64>, usize)],
        params: &Array1<f64>,
    ) -> Result<Array1<f64>> {
        Ok(Array1::zeros(params.len()))
    }

    /// Evaluates (loss, accuracy) on a dataset. Placeholder: returns random
    /// values in lieu of running the circuit.
    fn evaluate_task(
        &self,
        data: &[(Array1<f64>, usize)],
        params: &Array1<f64>,
    ) -> Result<(f64, f64)> {
        let loss = 0.5 + 0.5 * thread_rng().gen::<f64>();
        let acc = 0.5 + 0.3 * thread_rng().gen::<f64>();
        Ok((loss, acc))
    }

    /// Computes the second-order MAML meta-gradient. Placeholder: returns
    /// zeros; a full implementation would differentiate through the inner
    /// loop.
    fn compute_maml_gradient(
        &self,
        task: &MetaTask,
        adapted_params: &Array1<f64>,
        inner_lr: f64,
    ) -> Result<Array1<f64>> {
        Ok(Array1::zeros(self.meta_params.len()))
    }

    /// Computes per-class prototypes by averaging support-set features.
    /// The feature dimension is a placeholder; a full implementation would
    /// embed inputs through the quantum circuit before averaging.
    fn compute_prototypes(
        &self,
        data: &[(Array1<f64>, usize)],
        num_classes: usize,
    ) -> Result<Vec<Array1<f64>>> {
        let feature_dim = 16;
        let mut prototypes = vec![Array1::zeros(feature_dim); num_classes];
        let mut counts = vec![0; num_classes];

        for (x, label) in data {
            // Accumulate raw input features, truncated to feature_dim.
            for j in 0..feature_dim.min(x.len()) {
                prototypes[*label][j] += x[j];
            }
            counts[*label] += 1;
        }

        // Average the accumulated features per class.
        for (proto, &count) in prototypes.iter_mut().zip(counts.iter()) {
            if count > 0 {
                *proto /= count as f64;
            }
        }

        Ok(prototypes)
    }

    /// Gradient of the prototype regularizer. Placeholder: returns zeros.
    fn prototype_regularization(
        &self,
        data: &[(Array1<f64>, usize)],
        prototypes: &[Array1<f64>],
        params: &Array1<f64>,
    ) -> Result<Array1<f64>> {
        Ok(Array1::zeros(params.len()))
    }

    /// Evaluates (loss, accuracy) using nearest-prototype classification.
    /// Placeholder: returns fixed values.
    fn evaluate_with_prototypes(
        &self,
        data: &[(Array1<f64>, usize)],
        prototypes: &[Array1<f64>],
        params: &Array1<f64>,
    ) -> Result<(f64, f64)> {
        Ok((0.3, 0.7))
    }

    /// Gradient of the query loss with respect to the per-parameter learning
    /// rates (Meta-SGD). Placeholder: returns zeros.
    fn compute_lr_gradient(
        &self,
        task: &MetaTask,
        adapted_params: &Array1<f64>,
    ) -> Result<Array1<f64>> {
        Ok(Array1::zeros(self.meta_params.len()))
    }

    /// Samples a batch of tasks uniformly at random (with replacement).
    fn sample_task_batch(&self, tasks: &[MetaTask], batch_size: usize) -> Vec<MetaTask> {
        let mut batch = Vec::new();
        let mut rng = thread_rng();

        for _ in 0..batch_size.min(tasks.len()) {
            let idx = rng.gen_range(0..tasks.len());
            batch.push(tasks[idx].clone());
        }

        batch
    }

    /// Adapts the meta-parameters to a new task and returns the adapted
    /// parameters, leaving the meta-parameters themselves unchanged.
    pub fn adapt_to_task(&mut self, task: &MetaTask) -> Result<Array1<f64>> {
        let adapted_params = match self.algorithm {
            MetaLearningAlgorithm::MAML {
                inner_steps,
                inner_lr,
                ..
            }
            | MetaLearningAlgorithm::Reptile {
                inner_steps,
                inner_lr,
            }
            | MetaLearningAlgorithm::ProtoMAML {
                inner_steps,
                inner_lr,
                ..
            }
            | MetaLearningAlgorithm::ANIL {
                inner_steps,
                inner_lr,
            } => {
                let mut params = self.meta_params.clone();
                for _ in 0..inner_steps {
                    let grad = self.compute_task_gradient(&task.train_data, &params)?;
                    params = params - inner_lr * &grad;
                }
                params
            }
            MetaLearningAlgorithm::MetaSGD { inner_steps } => {
                let mut params = self.meta_params.clone();
                let lr = self.per_param_lr.as_ref().unwrap();
                for _ in 0..inner_steps {
                    let grad = self.compute_task_gradient(&task.train_data, &params)?;
                    params = params - lr * &grad;
                }
                params
            }
        };

        Ok(adapted_params)
    }

    /// Returns the stored embedding for a task, if any.
    pub fn get_task_embedding(&self, task_id: &str) -> Option<&Array1<f64>> {
        self.task_embeddings.get(task_id)
    }

    /// Returns the current meta-parameters.
    pub fn meta_params(&self) -> &Array1<f64> {
        &self.meta_params
    }

    /// Returns the per-parameter learning rates (Meta-SGD only).
    pub fn per_param_lr(&self) -> Option<&Array1<f64>> {
        self.per_param_lr.as_ref()
    }
}

/// Continual meta-learner that mitigates forgetting with experience replay.
pub struct ContinualMetaLearner {
    /// Underlying meta-learner.
    meta_learner: QuantumMetaLearner,

    /// Replay buffer of previously seen tasks.
    memory_buffer: Vec<MetaTask>,

    /// Maximum number of tasks retained in the buffer.
    memory_capacity: usize,

    /// Fraction of the buffer replayed alongside each new task.
    replay_ratio: f64,
}

impl ContinualMetaLearner {
    /// Creates a new continual meta-learner.
    pub fn new(
        meta_learner: QuantumMetaLearner,
        memory_capacity: usize,
        replay_ratio: f64,
    ) -> Self {
        Self {
            meta_learner,
            memory_buffer: Vec::new(),
            memory_capacity,
            replay_ratio,
        }
    }

    /// Learns a new task while replaying tasks from memory.
    pub fn learn_task(&mut self, new_task: MetaTask) -> Result<()> {
        // Store the task, replacing a random entry once the buffer is full.
        if self.memory_buffer.len() < self.memory_capacity {
            self.memory_buffer.push(new_task.clone());
        } else {
            let idx = fastrand::usize(0..self.memory_buffer.len());
            self.memory_buffer[idx] = new_task.clone();
        }

        // Mix the new task with randomly replayed tasks from memory.
        let num_replay = (self.memory_buffer.len() as f64 * self.replay_ratio) as usize;
        let mut task_batch = vec![new_task];

        for _ in 0..num_replay {
            let idx = fastrand::usize(0..self.memory_buffer.len());
            task_batch.push(self.memory_buffer[idx].clone());
        }

        // Meta-train on the mixed batch with a fresh optimizer.
        let mut dummy_optimizer = crate::autodiff::optimizers::Adam::new(0.001);
        self.meta_learner
            .meta_train(&task_batch, &mut dummy_optimizer, 10, task_batch.len())?;

        Ok(())
    }

    /// Returns the number of tasks currently held in memory.
    pub fn memory_buffer_len(&self) -> usize {
        self.memory_buffer.len()
    }
}

/// Generator for synthetic meta-learning tasks.
pub struct TaskGenerator {
    /// Input feature dimension.
    feature_dim: usize,

    /// Number of classes per task.
    num_classes: usize,

    /// Optional generator parameters.
    task_params: HashMap<String, f64>,
}

impl TaskGenerator {
    /// Creates a new task generator.
    pub fn new(feature_dim: usize, num_classes: usize) -> Self {
        Self {
            feature_dim,
            num_classes,
            task_params: HashMap::new(),
        }
    }

    /// Generates a binary classification task from a random sinusoid: the
    /// label is the sign of `amplitude * sin(x + phase)`.
    pub fn generate_sinusoid_task(&self, num_samples: usize) -> MetaTask {
        let amplitude = 0.1 + 4.9 * thread_rng().gen::<f64>();
        let phase = 2.0 * std::f64::consts::PI * thread_rng().gen::<f64>();

        let mut train_data = Vec::new();
        let mut test_data = Vec::new();

        for i in 0..num_samples {
            let x = -5.0 + 10.0 * thread_rng().gen::<f64>();
            let y = amplitude * (x + phase).sin();

            let input = Array1::from_vec(vec![x]);
            let label = if y > 0.0 { 1 } else { 0 };

            // First half for training, second half for testing.
            if i < num_samples / 2 {
                train_data.push((input, label));
            } else {
                test_data.push((input, label));
            }
        }

        MetaTask {
            id: format!("sin_a{:.2}_p{:.2}", amplitude, phase),
            train_data,
            test_data,
            num_classes: 2,
            metadata: vec![
                ("amplitude".to_string(), amplitude),
                ("phase".to_string(), phase),
            ]
            .into_iter()
            .collect(),
        }
    }

    /// Generates a classification task whose first two features are rotated
    /// by a random task-specific angle.
    pub fn generate_rotation_task(&self, num_samples: usize) -> MetaTask {
        let angle = 2.0 * std::f64::consts::PI * thread_rng().gen::<f64>();
        let cos_a = angle.cos();
        let sin_a = angle.sin();

        let mut train_data = Vec::new();
        let mut test_data = Vec::new();

        for i in 0..num_samples {
            let mut features = Array1::zeros(self.feature_dim);
            let label = i % self.num_classes;

            // Class-indicator features with small additive noise.
            for j in 0..self.feature_dim {
                features[j] = if j % self.num_classes == label {
                    1.0
                } else {
                    0.0
                };
                features[j] += 0.1 * thread_rng().gen::<f64>();
            }

            // Rotate the first two feature dimensions by the task angle.
            if self.feature_dim >= 2 {
                let x = features[0];
                let y = features[1];
                features[0] = cos_a * x - sin_a * y;
                features[1] = sin_a * x + cos_a * y;
            }

            if i < num_samples / 2 {
                train_data.push((features, label));
            } else {
                test_data.push((features, label));
            }
        }

        MetaTask {
            id: format!("rot_{:.2}", angle),
            train_data,
            test_data,
            num_classes: self.num_classes,
            metadata: vec![("rotation_angle".to_string(), angle)]
                .into_iter()
                .collect(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::autodiff::optimizers::Adam;
    use crate::qnn::QNNLayerType;

    #[test]
    fn test_task_generator() {
        let generator = TaskGenerator::new(4, 2);

        let sin_task = generator.generate_sinusoid_task(20);
        assert_eq!(sin_task.train_data.len(), 10);
        assert_eq!(sin_task.test_data.len(), 10);

        let rot_task = generator.generate_rotation_task(30);
        assert_eq!(rot_task.train_data.len(), 15);
        assert_eq!(rot_task.test_data.len(), 15);
    }

    #[test]
    fn test_meta_learner_creation() {
        let layers = vec![
            QNNLayerType::EncodingLayer { num_features: 4 },
            QNNLayerType::VariationalLayer { num_params: 8 },
            QNNLayerType::MeasurementLayer {
                measurement_basis: "computational".to_string(),
            },
        ];

        let qnn = QuantumNeuralNetwork::new(layers, 4, 4, 2).unwrap();

        let maml_algo = MetaLearningAlgorithm::MAML {
            inner_steps: 5,
            inner_lr: 0.01,
            first_order: true,
        };

        let meta_learner = QuantumMetaLearner::new(maml_algo, qnn);
        assert!(meta_learner.per_param_lr.is_none());

        let layers2 = vec![
            QNNLayerType::EncodingLayer { num_features: 4 },
            QNNLayerType::VariationalLayer { num_params: 8 },
        ];
        let qnn2 = QuantumNeuralNetwork::new(layers2, 4, 4, 2).unwrap();

        let metasgd_algo = MetaLearningAlgorithm::MetaSGD { inner_steps: 3 };
        let meta_sgd = QuantumMetaLearner::new(metasgd_algo, qnn2);
        assert!(meta_sgd.per_param_lr.is_some());
    }

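    #[test]
    fn test_meta_training_smoke() {
        // Smoke test for the outer meta-training loop. It relies on the
        // placeholder gradient/evaluation stubs above, so it only checks the
        // history bookkeeping, not learning quality.
        let layers = vec![
            QNNLayerType::EncodingLayer { num_features: 2 },
            QNNLayerType::VariationalLayer { num_params: 6 },
        ];
        let qnn = QuantumNeuralNetwork::new(layers, 4, 2, 2).unwrap();

        let algo = MetaLearningAlgorithm::MAML {
            inner_steps: 2,
            inner_lr: 0.01,
            first_order: true,
        };
        let mut meta_learner = QuantumMetaLearner::new(algo, qnn);

        let generator = TaskGenerator::new(2, 2);
        let tasks: Vec<MetaTask> = (0..4)
            .map(|_| generator.generate_rotation_task(10))
            .collect();

        let mut optimizer = Adam::new(0.001);
        meta_learner.meta_train(&tasks, &mut optimizer, 5, 2).unwrap();

        assert_eq!(meta_learner.history.meta_train_losses.len(), 5);
        assert_eq!(meta_learner.history.meta_val_accuracies.len(), 5);
    }
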
    #[test]
    fn test_task_adaptation() {
        let layers = vec![
            QNNLayerType::EncodingLayer { num_features: 2 },
            QNNLayerType::VariationalLayer { num_params: 6 },
        ];

        let qnn = QuantumNeuralNetwork::new(layers, 4, 2, 2).unwrap();
        let algo = MetaLearningAlgorithm::Reptile {
            inner_steps: 5,
            inner_lr: 0.01,
        };

        let mut meta_learner = QuantumMetaLearner::new(algo, qnn);

        let generator = TaskGenerator::new(2, 2);
        let task = generator.generate_rotation_task(20);

        let adapted_params = meta_learner.adapt_to_task(&task).unwrap();
        assert_eq!(adapted_params.len(), meta_learner.meta_params.len());
    }
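
    #[test]
    fn test_continual_meta_learner() {
        // Smoke test for replay-based continual learning: the memory buffer
        // grows with each task and is capped at its capacity. Uses the same
        // placeholder stubs as the other tests.
        let layers = vec![
            QNNLayerType::EncodingLayer { num_features: 2 },
            QNNLayerType::VariationalLayer { num_params: 6 },
        ];
        let qnn = QuantumNeuralNetwork::new(layers, 4, 2, 2).unwrap();

        let algo = MetaLearningAlgorithm::Reptile {
            inner_steps: 2,
            inner_lr: 0.01,
        };
        let meta_learner = QuantumMetaLearner::new(algo, qnn);
        let mut continual = ContinualMetaLearner::new(meta_learner, 3, 0.5);

        let generator = TaskGenerator::new(2, 2);
        for _ in 0..5 {
            continual
                .learn_task(generator.generate_rotation_task(10))
                .unwrap();
        }

        assert_eq!(continual.memory_buffer_len(), 3);
    }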
}