use crate::autodiff::optimizers::Optimizer;
use crate::error::{MLError, Result};
use crate::optimization::OptimizationMethod;
use crate::qnn::{QNNLayerType, QuantumNeuralNetwork};
use ndarray::{s, Array1, Array2, Array3, Axis};
use quantrs2_circuit::builder::{Circuit, Simulator};
use quantrs2_core::gate::{
    single::{RotationX, RotationY, RotationZ},
    GateOp,
};
use quantrs2_sim::statevector::StateVectorSimulator;
use std::collections::HashMap;

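/// Supported meta-learning algorithms and their hyperparameters.
///
/// Each variant carries its inner-loop configuration: the number of
/// adaptation steps and (except for Meta-SGD, which learns them) the
/// inner-loop learning rate.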
#[derive(Debug, Clone, Copy)]
pub enum MetaLearningAlgorithm {
    /// Model-Agnostic Meta-Learning (Finn et al., 2017)
    MAML {
        inner_steps: usize,
        inner_lr: f64,
        first_order: bool,
    },

    /// Reptile (Nichol et al., 2018): repeated interpolation toward
    /// task-adapted parameters
    Reptile { inner_steps: usize, inner_lr: f64 },

    /// MAML combined with prototype-based regularization
    ProtoMAML {
        inner_steps: usize,
        inner_lr: f64,
        proto_weight: f64,
    },

    /// Meta-SGD (Li et al., 2017): learns a per-parameter inner learning rate
    MetaSGD { inner_steps: usize },

    /// ANIL (Raghu et al., 2020): adapts only the final-layer parameters
    ANIL { inner_steps: usize, inner_lr: f64 },
}

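/// A single meta-learning task: a support set for adaptation and a query set
/// for evaluation.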
#[derive(Debug, Clone)]
pub struct MetaTask {
    /// Unique task identifier
    pub id: String,

    /// Support set: (input, label) pairs used for inner-loop adaptation
    pub train_data: Vec<(Array1<f64>, usize)>,

    /// Query set: (input, label) pairs used for evaluation
    pub test_data: Vec<(Array1<f64>, usize)>,

    /// Number of classes in this task
    pub num_classes: usize,

    /// Task metadata (e.g., the parameters that generated it)
    pub metadata: HashMap<String, f64>,
}

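/// Quantum meta-learner: maintains shared meta-parameters for a quantum
/// neural network and adapts them to new tasks with a few gradient steps.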
pub struct QuantumMetaLearner {
    /// Meta-learning algorithm in use
    algorithm: MetaLearningAlgorithm,

    /// Underlying quantum neural network
    model: QuantumNeuralNetwork,

    /// Shared meta-parameters (the learned initialization)
    meta_params: Array1<f64>,

    /// Per-parameter learning rates (Meta-SGD only)
    per_param_lr: Option<Array1<f64>>,

    /// Task embeddings, keyed by task id
    task_embeddings: HashMap<String, Array1<f64>>,

    /// Meta-training history
    history: MetaLearningHistory,
}

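/// Records losses and accuracies collected during meta-training.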
#[derive(Debug, Clone)]
pub struct MetaLearningHistory {
    /// Meta-training loss per epoch
    pub meta_train_losses: Vec<f64>,

    /// Meta-validation accuracy per epoch
    pub meta_val_accuracies: Vec<f64>,

    /// Per-task performance traces, keyed by task id
    pub task_performance: HashMap<String, Vec<f64>>,
}

impl QuantumMetaLearner {
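    /// Creates a new meta-learner, initializing the meta-parameters from the
    /// model's current parameters. Meta-SGD additionally receives a
    /// per-parameter learning-rate vector initialized to 0.01.
    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `qnn` is a `QuantumNeuralNetwork` built
    /// elsewhere:
    ///
    /// ```ignore
    /// let algo = MetaLearningAlgorithm::MAML {
    ///     inner_steps: 5,
    ///     inner_lr: 0.01,
    ///     first_order: true,
    /// };
    /// let learner = QuantumMetaLearner::new(algo, qnn);
    /// ```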
    pub fn new(algorithm: MetaLearningAlgorithm, model: QuantumNeuralNetwork) -> Self {
        let num_params = model.parameters.len();
        let meta_params = model.parameters.clone();

        // Only Meta-SGD maintains learnable per-parameter learning rates
        let per_param_lr = match algorithm {
            MetaLearningAlgorithm::MetaSGD { .. } => Some(Array1::from_elem(num_params, 0.01)),
            _ => None,
        };

        Self {
            algorithm,
            model,
            meta_params,
            per_param_lr,
            task_embeddings: HashMap::new(),
            history: MetaLearningHistory {
                meta_train_losses: Vec::new(),
                meta_val_accuracies: Vec::new(),
                task_performance: HashMap::new(),
            },
        }
    }

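    /// Runs the outer meta-training loop.
    ///
    /// Each epoch samples a batch of tasks and applies the configured
    /// algorithm's meta-update. Note: the per-algorithm updates currently use
    /// fixed internal step sizes, so `meta_optimizer` is accepted but not yet
    /// consulted.
    ///
    /// # Example
    ///
    /// A hedged sketch, assuming `tasks` was produced by a `TaskGenerator`:
    ///
    /// ```ignore
    /// let mut opt = crate::autodiff::optimizers::Adam::new(0.001);
    /// learner.meta_train(&tasks, &mut opt, 100, 4)?;
    /// ```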
    pub fn meta_train(
        &mut self,
        tasks: &[MetaTask],
        meta_optimizer: &mut dyn Optimizer,
        meta_epochs: usize,
        tasks_per_batch: usize,
    ) -> Result<()> {
        println!("Starting meta-training with {} tasks...", tasks.len());

        for epoch in 0..meta_epochs {
            // Sample a batch of tasks for this meta-step
            let task_batch = self.sample_task_batch(tasks, tasks_per_batch);

            // Apply the meta-update for the configured algorithm
            let (epoch_loss, epoch_acc) = match self.algorithm {
                MetaLearningAlgorithm::MAML { .. } => {
                    self.maml_update(&task_batch, meta_optimizer)?
                }
                MetaLearningAlgorithm::Reptile { .. } => {
                    self.reptile_update(&task_batch, meta_optimizer)?
                }
                MetaLearningAlgorithm::ProtoMAML { .. } => {
                    self.protomaml_update(&task_batch, meta_optimizer)?
                }
                MetaLearningAlgorithm::MetaSGD { .. } => {
                    self.metasgd_update(&task_batch, meta_optimizer)?
                }
                MetaLearningAlgorithm::ANIL { .. } => {
                    self.anil_update(&task_batch, meta_optimizer)?
                }
            };

            self.history.meta_train_losses.push(epoch_loss);
            self.history.meta_val_accuracies.push(epoch_acc);

            if epoch % 10 == 0 {
                println!(
                    "Epoch {}: Loss = {:.4}, Accuracy = {:.2}%",
                    epoch,
                    epoch_loss,
                    epoch_acc * 100.0
                );
            }
        }

        Ok(())
    }

    /// MAML meta-update: adapts a copy of the meta-parameters on each task's
    /// support set, then accumulates meta-gradients from the query set.
    fn maml_update(
        &mut self,
        tasks: &[MetaTask],
        optimizer: &mut dyn Optimizer,
    ) -> Result<(f64, f64)> {
        let (inner_steps, inner_lr, first_order) = match self.algorithm {
            MetaLearningAlgorithm::MAML {
                inner_steps,
                inner_lr,
                first_order,
            } => (inner_steps, inner_lr, first_order),
            _ => unreachable!(),
        };

        let mut total_loss = 0.0;
        let mut total_acc = 0.0;
        let mut meta_gradients = Array1::zeros(self.meta_params.len());

        for task in tasks {
            // Inner loop: adapt a copy of the meta-parameters on the support set
            let mut task_params = self.meta_params.clone();

            for _ in 0..inner_steps {
                let grad = self.compute_task_gradient(&task.train_data, &task_params)?;
                task_params = task_params - inner_lr * &grad;
            }

            // Evaluate the adapted parameters on the query set
            let (query_loss, query_acc) = self.evaluate_task(&task.test_data, &task_params)?;
            total_loss += query_loss;
            total_acc += query_acc;

            if !first_order {
                // Second-order MAML: differentiate through the inner loop
                let meta_grad = self.compute_maml_gradient(task, &task_params, inner_lr)?;
                meta_gradients = meta_gradients + meta_grad;
            } else {
                // First-order approximation: use the query-set gradient directly
                let grad = self.compute_task_gradient(&task.test_data, &task_params)?;
                meta_gradients = meta_gradients + grad;
            }
        }

        // Average over the task batch and take a fixed-size meta-step
        // (the `optimizer` argument is not consulted here yet)
        meta_gradients = meta_gradients / tasks.len() as f64;
        self.meta_params = self.meta_params.clone() - 0.001 * &meta_gradients;

        Ok((
            total_loss / tasks.len() as f64,
            total_acc / tasks.len() as f64,
        ))
    }

    /// Reptile meta-update: moves the meta-parameters a fraction `epsilon`
    /// toward each task's adapted parameters.
    fn reptile_update(
        &mut self,
        tasks: &[MetaTask],
        optimizer: &mut dyn Optimizer,
    ) -> Result<(f64, f64)> {
        let (inner_steps, inner_lr) = match self.algorithm {
            MetaLearningAlgorithm::Reptile {
                inner_steps,
                inner_lr,
            } => (inner_steps, inner_lr),
            _ => unreachable!(),
        };

        let mut total_loss = 0.0;
        let mut total_acc = 0.0;
        // Interpolation step size toward the adapted parameters
        let epsilon = 0.1;

        for task in tasks {
            let mut task_params = self.meta_params.clone();

            // Inner loop: plain SGD on the support set
            for _ in 0..inner_steps {
                let grad = self.compute_task_gradient(&task.train_data, &task_params)?;
                task_params = task_params - inner_lr * &grad;
            }

            let (loss, acc) = self.evaluate_task(&task.test_data, &task_params)?;
            total_loss += loss;
            total_acc += acc;

            // Reptile step: meta_params += epsilon * (task_params - meta_params)
            let direction = &task_params - &self.meta_params;
            self.meta_params = &self.meta_params + epsilon * &direction;
        }

        Ok((
            total_loss / tasks.len() as f64,
            total_acc / tasks.len() as f64,
        ))
    }

    /// ProtoMAML meta-update: MAML-style inner-loop adaptation with an
    /// additional prototype-based regularization term. Note that this routine
    /// does not yet apply an outer-loop update to the meta-parameters.
    fn protomaml_update(
        &mut self,
        tasks: &[MetaTask],
        optimizer: &mut dyn Optimizer,
    ) -> Result<(f64, f64)> {
        let (inner_steps, inner_lr, proto_weight) = match self.algorithm {
            MetaLearningAlgorithm::ProtoMAML {
                inner_steps,
                inner_lr,
                proto_weight,
            } => (inner_steps, inner_lr, proto_weight),
            _ => unreachable!(),
        };

        let mut total_loss = 0.0;
        let mut total_acc = 0.0;

        for task in tasks {
            // Class prototypes computed from the support set
            let prototypes = self.compute_prototypes(&task.train_data, task.num_classes)?;

            let mut task_params = self.meta_params.clone();

            // Inner loop: gradient step plus weighted prototype regularization
            for _ in 0..inner_steps {
                let grad = self.compute_task_gradient(&task.train_data, &task_params)?;
                let proto_reg =
                    self.prototype_regularization(&task.train_data, &prototypes, &task_params)?;
                task_params = task_params - inner_lr * (&grad + proto_weight * &proto_reg);
            }

            let (loss, acc) =
                self.evaluate_with_prototypes(&task.test_data, &prototypes, &task_params)?;
            total_loss += loss;
            total_acc += acc;
        }

        Ok((
            total_loss / tasks.len() as f64,
            total_acc / tasks.len() as f64,
        ))
    }

    /// Meta-SGD meta-update: adapts with per-parameter learning rates and
    /// then updates those learning rates from their accumulated gradients.
    fn metasgd_update(
        &mut self,
        tasks: &[MetaTask],
        optimizer: &mut dyn Optimizer,
    ) -> Result<(f64, f64)> {
        let inner_steps = match self.algorithm {
            MetaLearningAlgorithm::MetaSGD { inner_steps } => inner_steps,
            _ => unreachable!(),
        };

        let mut total_loss = 0.0;
        let mut total_acc = 0.0;
        let mut meta_lr_gradients = Array1::zeros(self.meta_params.len());

        for task in tasks {
            let mut task_params = self.meta_params.clone();

            // Inner loop: element-wise learned learning rates
            for _ in 0..inner_steps {
                let grad = self.compute_task_gradient(&task.train_data, &task_params)?;
                let lr = self.per_param_lr.as_ref().unwrap();
                task_params = task_params - lr * &grad;
            }

            let (loss, acc) = self.evaluate_task(&task.test_data, &task_params)?;
            total_loss += loss;
            total_acc += acc;

            // Accumulate gradients with respect to the learning rates
            let lr_grad = self.compute_lr_gradient(task, &task_params)?;
            meta_lr_gradients = meta_lr_gradients + lr_grad;
        }

        // Update the per-parameter learning rates with a fixed meta step size
        if let Some(ref mut lr) = self.per_param_lr {
            *lr = lr.clone() - &(0.001 * &meta_lr_gradients / tasks.len() as f64);
        }

        Ok((
            total_loss / tasks.len() as f64,
            total_acc / tasks.len() as f64,
        ))
    }

    /// ANIL meta-update: the inner loop adapts only the final-layer ("head")
    /// parameters while the earlier "body" parameters stay frozen. The
    /// meta-parameters themselves are not updated here yet.
    fn anil_update(
        &mut self,
        tasks: &[MetaTask],
        optimizer: &mut dyn Optimizer,
    ) -> Result<(f64, f64)> {
        let (inner_steps, inner_lr) = match self.algorithm {
            MetaLearningAlgorithm::ANIL {
                inner_steps,
                inner_lr,
            } => (inner_steps, inner_lr),
            _ => unreachable!(),
        };

        // Approximate the "head" as the final quarter of the parameter vector
        let num_params = self.meta_params.len();
        let final_layer_start = (num_params * 3) / 4;

        let mut total_loss = 0.0;
        let mut total_acc = 0.0;

        for task in tasks {
            let mut task_params = self.meta_params.clone();

            for _ in 0..inner_steps {
                let grad = self.compute_task_gradient(&task.train_data, &task_params)?;

                // Update only the head parameters
                for i in final_layer_start..num_params {
                    task_params[i] -= inner_lr * grad[i];
                }
            }

            let (loss, acc) = self.evaluate_task(&task.test_data, &task_params)?;
            total_loss += loss;
            total_acc += acc;
        }

        Ok((
            total_loss / tasks.len() as f64,
            total_acc / tasks.len() as f64,
        ))
    }

    /// Computes the loss gradient for a task's data at the given parameters.
    ///
    /// Placeholder: returns zeros. A full implementation would estimate
    /// gradients of the quantum model, e.g. via the parameter-shift rule.
    fn compute_task_gradient(
        &self,
        data: &[(Array1<f64>, usize)],
        params: &Array1<f64>,
    ) -> Result<Array1<f64>> {
        Ok(Array1::zeros(params.len()))
    }

    /// Evaluates loss and accuracy of the given parameters on a dataset.
    ///
    /// Placeholder: returns random values in lieu of running the quantum
    /// circuit.
    fn evaluate_task(
        &self,
        data: &[(Array1<f64>, usize)],
        params: &Array1<f64>,
    ) -> Result<(f64, f64)> {
        let loss = 0.5 + 0.5 * rand::random::<f64>();
        let acc = 0.5 + 0.3 * rand::random::<f64>();
        Ok((loss, acc))
    }

    /// Computes the second-order MAML meta-gradient for a task.
    ///
    /// Placeholder: returns zeros. A full implementation would differentiate
    /// through the inner-loop updates.
    fn compute_maml_gradient(
        &self,
        task: &MetaTask,
        adapted_params: &Array1<f64>,
        inner_lr: f64,
    ) -> Result<Array1<f64>> {
        Ok(Array1::zeros(self.meta_params.len()))
    }

    /// Computes per-class prototype vectors from the support set.
    ///
    /// Placeholder: counts class occurrences but leaves the prototypes at
    /// zero; a full implementation would average embedded features per class.
    fn compute_prototypes(
        &self,
        data: &[(Array1<f64>, usize)],
        num_classes: usize,
    ) -> Result<Vec<Array1<f64>>> {
        let feature_dim = 16; // fixed embedding size for now
        let prototypes = vec![Array1::zeros(feature_dim); num_classes];
        let mut counts = vec![0; num_classes];

        for (_, label) in data {
            counts[*label] += 1;
        }

        Ok(prototypes)
    }

    /// Computes the gradient of the prototype-regularization term.
    ///
    /// Placeholder: returns zeros.
    fn prototype_regularization(
        &self,
        data: &[(Array1<f64>, usize)],
        prototypes: &[Array1<f64>],
        params: &Array1<f64>,
    ) -> Result<Array1<f64>> {
        Ok(Array1::zeros(params.len()))
    }

    /// Evaluates the query set using the class prototypes.
    ///
    /// Placeholder: returns fixed (loss, accuracy) values.
    fn evaluate_with_prototypes(
        &self,
        data: &[(Array1<f64>, usize)],
        prototypes: &[Array1<f64>],
        params: &Array1<f64>,
    ) -> Result<(f64, f64)> {
        Ok((0.3, 0.7))
    }

    /// Computes the gradient with respect to the Meta-SGD learning rates.
    ///
    /// Placeholder: returns zeros.
    fn compute_lr_gradient(
        &self,
        task: &MetaTask,
        adapted_params: &Array1<f64>,
    ) -> Result<Array1<f64>> {
        Ok(Array1::zeros(self.meta_params.len()))
    }

    /// Samples a batch of tasks uniformly at random, with replacement.
    fn sample_task_batch(&self, tasks: &[MetaTask], batch_size: usize) -> Vec<MetaTask> {
        let mut batch = Vec::new();
        let mut rng = fastrand::Rng::new();

        for _ in 0..batch_size.min(tasks.len()) {
            let idx = rng.usize(0..tasks.len());
            batch.push(tasks[idx].clone());
        }

        batch
    }

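    /// Adapts the meta-parameters to a new task using the configured
    /// inner-loop procedure and returns the adapted parameter vector; the
    /// stored meta-parameters are left unchanged.
    ///
    /// # Example
    ///
    /// A hedged sketch, assuming `learner` has been meta-trained and `task`
    /// comes from a `TaskGenerator`:
    ///
    /// ```ignore
    /// let adapted = learner.adapt_to_task(&task)?;
    /// assert_eq!(adapted.len(), learner.meta_params().len());
    /// ```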
    pub fn adapt_to_task(&mut self, task: &MetaTask) -> Result<Array1<f64>> {
        let adapted_params = match self.algorithm {
            // Note: ANIL is treated like MAML here; all parameters are adapted.
            MetaLearningAlgorithm::MAML {
                inner_steps,
                inner_lr,
                ..
            }
            | MetaLearningAlgorithm::Reptile {
                inner_steps,
                inner_lr,
            }
            | MetaLearningAlgorithm::ProtoMAML {
                inner_steps,
                inner_lr,
                ..
            }
            | MetaLearningAlgorithm::ANIL {
                inner_steps,
                inner_lr,
            } => {
                let mut params = self.meta_params.clone();
                for _ in 0..inner_steps {
                    let grad = self.compute_task_gradient(&task.train_data, &params)?;
                    params = params - inner_lr * &grad;
                }
                params
            }
            MetaLearningAlgorithm::MetaSGD { inner_steps } => {
                let mut params = self.meta_params.clone();
                let lr = self.per_param_lr.as_ref().unwrap();
                for _ in 0..inner_steps {
                    let grad = self.compute_task_gradient(&task.train_data, &params)?;
                    params = params - lr * &grad;
                }
                params
            }
        };

        Ok(adapted_params)
    }

    /// Returns the learned embedding for a task, if one exists.
    pub fn get_task_embedding(&self, task_id: &str) -> Option<&Array1<f64>> {
        self.task_embeddings.get(task_id)
    }

    /// Returns the current meta-parameters.
    pub fn meta_params(&self) -> &Array1<f64> {
        &self.meta_params
    }

    /// Returns the per-parameter learning rates (Meta-SGD only).
    pub fn per_param_lr(&self) -> Option<&Array1<f64>> {
        self.per_param_lr.as_ref()
    }
}

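/// Continual meta-learner: wraps a `QuantumMetaLearner` with an episodic
/// memory buffer, replaying stored tasks alongside each new one to mitigate
/// forgetting.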
pub struct ContinualMetaLearner {
    /// Underlying meta-learner
    meta_learner: QuantumMetaLearner,

    /// Episodic memory of past tasks
    memory_buffer: Vec<MetaTask>,

    /// Maximum number of tasks retained in memory
    memory_capacity: usize,

    /// Fraction of the memory buffer replayed alongside each new task
    replay_ratio: f64,
}

impl ContinualMetaLearner {
    pub fn new(
        meta_learner: QuantumMetaLearner,
        memory_capacity: usize,
        replay_ratio: f64,
    ) -> Self {
        Self {
            meta_learner,
            memory_buffer: Vec::new(),
            memory_capacity,
            replay_ratio,
        }
    }

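    /// Learns a new task while replaying tasks from memory.
    ///
    /// The new task is stored (evicting a random entry once the buffer is
    /// full), mixed with a random replay sample of stored tasks, and the
    /// combined batch is meta-trained for a fixed 10 epochs.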
    pub fn learn_task(&mut self, new_task: MetaTask) -> Result<()> {
        // Store the task, replacing a random entry once the buffer is full
        if self.memory_buffer.len() < self.memory_capacity {
            self.memory_buffer.push(new_task.clone());
        } else {
            let idx = fastrand::usize(0..self.memory_buffer.len());
            self.memory_buffer[idx] = new_task.clone();
        }

        // Mix the new task with replayed tasks from memory
        let num_replay = (self.memory_buffer.len() as f64 * self.replay_ratio) as usize;
        let mut task_batch = vec![new_task];

        for _ in 0..num_replay {
            let idx = fastrand::usize(0..self.memory_buffer.len());
            task_batch.push(self.memory_buffer[idx].clone());
        }

        let mut dummy_optimizer = crate::autodiff::optimizers::Adam::new(0.001);
        self.meta_learner
            .meta_train(&task_batch, &mut dummy_optimizer, 10, task_batch.len())?;

        Ok(())
    }

    /// Returns the number of tasks currently held in memory.
    pub fn memory_buffer_len(&self) -> usize {
        self.memory_buffer.len()
    }
}

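/// Generates synthetic classification tasks for meta-learning experiments.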
pub struct TaskGenerator {
    /// Dimensionality of generated feature vectors
    feature_dim: usize,

    /// Number of classes per generated task
    num_classes: usize,

    /// Shared generation parameters
    task_params: HashMap<String, f64>,
}

impl TaskGenerator {
    pub fn new(feature_dim: usize, num_classes: usize) -> Self {
        Self {
            feature_dim,
            num_classes,
            task_params: HashMap::new(),
        }
    }

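    /// Generates a sinusoid task: inputs x drawn uniformly from [-5, 5) are
    /// labeled by the sign of `amplitude * sin(x + phase)`, with a random
    /// amplitude and phase per task (a binary-classification variant of the
    /// classic sinusoid meta-learning benchmark).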
    pub fn generate_sinusoid_task(&self, num_samples: usize) -> MetaTask {
        // Random amplitude in [0.1, 5.0) and phase in [0, 2π)
        let amplitude = 0.1 + 4.9 * rand::random::<f64>();
        let phase = 2.0 * std::f64::consts::PI * rand::random::<f64>();

        let mut train_data = Vec::new();
        let mut test_data = Vec::new();

        for i in 0..num_samples {
            let x = -5.0 + 10.0 * rand::random::<f64>();
            let y = amplitude * (x + phase).sin();

            let input = Array1::from_vec(vec![x]);
            // Binarize by the sign of the sinusoid
            let label = if y > 0.0 { 1 } else { 0 };

            // First half to the support set, second half to the query set
            if i < num_samples / 2 {
                train_data.push((input, label));
            } else {
                test_data.push((input, label));
            }
        }

        MetaTask {
            id: format!("sin_a{:.2}_p{:.2}", amplitude, phase),
            train_data,
            test_data,
            num_classes: 2,
            metadata: vec![
                ("amplitude".to_string(), amplitude),
                ("phase".to_string(), phase),
            ]
            .into_iter()
            .collect(),
        }
    }

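    /// Generates a rotation task: near-one-hot class features with additive
    /// noise, where the first two feature dimensions are rotated by a random
    /// per-task angle.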
    pub fn generate_rotation_task(&self, num_samples: usize) -> MetaTask {
        let angle = 2.0 * std::f64::consts::PI * rand::random::<f64>();
        let cos_a = angle.cos();
        let sin_a = angle.sin();

        let mut train_data = Vec::new();
        let mut test_data = Vec::new();

        for i in 0..num_samples {
            let mut features = Array1::zeros(self.feature_dim);
            let label = i % self.num_classes;

            // Class-indicative pattern plus small uniform noise
            for j in 0..self.feature_dim {
                features[j] = if j % self.num_classes == label {
                    1.0
                } else {
                    0.0
                };
                features[j] += 0.1 * rand::random::<f64>();
            }

            // Rotate the first two feature dimensions by the task angle
            if self.feature_dim >= 2 {
                let x = features[0];
                let y = features[1];
                features[0] = cos_a * x - sin_a * y;
                features[1] = sin_a * x + cos_a * y;
            }

            // First half to the support set, second half to the query set
            if i < num_samples / 2 {
                train_data.push((features, label));
            } else {
                test_data.push((features, label));
            }
        }

        MetaTask {
            id: format!("rot_{:.2}", angle),
            train_data,
            test_data,
            num_classes: self.num_classes,
            metadata: vec![("rotation_angle".to_string(), angle)]
                .into_iter()
                .collect(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::autodiff::optimizers::Adam;
    use crate::qnn::QNNLayerType;

    #[test]
    fn test_task_generator() {
        let generator = TaskGenerator::new(4, 2);

        let sin_task = generator.generate_sinusoid_task(20);
        assert_eq!(sin_task.train_data.len(), 10);
        assert_eq!(sin_task.test_data.len(), 10);

        let rot_task = generator.generate_rotation_task(30);
        assert_eq!(rot_task.train_data.len(), 15);
        assert_eq!(rot_task.test_data.len(), 15);
    }

    #[test]
    fn test_meta_learner_creation() {
        let layers = vec![
            QNNLayerType::EncodingLayer { num_features: 4 },
            QNNLayerType::VariationalLayer { num_params: 8 },
            QNNLayerType::MeasurementLayer {
                measurement_basis: "computational".to_string(),
            },
        ];

        let qnn = QuantumNeuralNetwork::new(layers, 4, 4, 2).unwrap();

        let maml_algo = MetaLearningAlgorithm::MAML {
            inner_steps: 5,
            inner_lr: 0.01,
            first_order: true,
        };

        let meta_learner = QuantumMetaLearner::new(maml_algo, qnn);
        assert!(meta_learner.per_param_lr.is_none());

        let layers2 = vec![
            QNNLayerType::EncodingLayer { num_features: 4 },
            QNNLayerType::VariationalLayer { num_params: 8 },
        ];
        let qnn2 = QuantumNeuralNetwork::new(layers2, 4, 4, 2).unwrap();

        let metasgd_algo = MetaLearningAlgorithm::MetaSGD { inner_steps: 3 };
        let meta_sgd = QuantumMetaLearner::new(metasgd_algo, qnn2);
        assert!(meta_sgd.per_param_lr.is_some());
    }

    #[test]
    fn test_task_adaptation() {
        let layers = vec![
            QNNLayerType::EncodingLayer { num_features: 2 },
            QNNLayerType::VariationalLayer { num_params: 6 },
        ];

        let qnn = QuantumNeuralNetwork::new(layers, 4, 2, 2).unwrap();
        let algo = MetaLearningAlgorithm::Reptile {
            inner_steps: 5,
            inner_lr: 0.01,
        };

        let mut meta_learner = QuantumMetaLearner::new(algo, qnn);

        let generator = TaskGenerator::new(2, 2);
        let task = generator.generate_rotation_task(20);

        let adapted_params = meta_learner.adapt_to_task(&task).unwrap();
        assert_eq!(adapted_params.len(), meta_learner.meta_params.len());
    }
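
    // The two tests below are minimal sketches (not part of the original
    // suite); they exercise the public flow under the stubbed gradient and
    // evaluation routines above, so they check bookkeeping, not learning.

    #[test]
    fn test_meta_train_history() {
        // Meta-training should record one loss/accuracy entry per epoch.
        let layers = vec![
            QNNLayerType::EncodingLayer { num_features: 2 },
            QNNLayerType::VariationalLayer { num_params: 6 },
        ];
        let qnn = QuantumNeuralNetwork::new(layers, 4, 2, 2).unwrap();
        let algo = MetaLearningAlgorithm::MAML {
            inner_steps: 1,
            inner_lr: 0.01,
            first_order: true,
        };
        let mut meta_learner = QuantumMetaLearner::new(algo, qnn);

        let generator = TaskGenerator::new(2, 2);
        let tasks: Vec<MetaTask> = (0..4)
            .map(|_| generator.generate_rotation_task(10))
            .collect();

        let mut optimizer = Adam::new(0.001);
        meta_learner.meta_train(&tasks, &mut optimizer, 5, 2).unwrap();

        assert_eq!(meta_learner.history.meta_train_losses.len(), 5);
        assert_eq!(meta_learner.history.meta_val_accuracies.len(), 5);
    }

    #[test]
    fn test_continual_memory_capacity() {
        // Learning more tasks than the memory capacity should trigger random
        // replacement, keeping the buffer at capacity.
        let layers = vec![
            QNNLayerType::EncodingLayer { num_features: 2 },
            QNNLayerType::VariationalLayer { num_params: 6 },
        ];
        let qnn = QuantumNeuralNetwork::new(layers, 4, 2, 2).unwrap();
        let algo = MetaLearningAlgorithm::Reptile {
            inner_steps: 1,
            inner_lr: 0.01,
        };
        let meta_learner = QuantumMetaLearner::new(algo, qnn);
        let mut continual = ContinualMetaLearner::new(meta_learner, 2, 0.5);

        let generator = TaskGenerator::new(2, 2);
        for _ in 0..3 {
            continual
                .learn_task(generator.generate_rotation_task(10))
                .unwrap();
        }
        assert_eq!(continual.memory_buffer_len(), 2);
    }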
}