//! Quantum meta-learning: algorithms for learning QNN initializations that
//! adapt quickly to new tasks (MAML, Reptile, ProtoMAML, Meta-SGD, ANIL).

use crate::autodiff::optimizers::Optimizer;
use crate::error::Result;
use crate::qnn::QuantumNeuralNetwork;
use scirs2_core::ndarray::Array1;
use scirs2_core::random::prelude::*;
use std::collections::HashMap;

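/// Meta-learning algorithm for the outer training loop, together with its
/// inner-loop hyperparameters.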
#[derive(Debug, Clone, Copy)]
pub enum MetaLearningAlgorithm {
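    /// Model-Agnostic Meta-Learning: learns an initialization that adapts to
    /// a new task in a few inner-loop gradient steps.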
    MAML {
        inner_steps: usize,
        inner_lr: f64,
        /// Use the first-order approximation (skip second-order terms).
        first_order: bool,
    },

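    /// Reptile: first-order meta-learning that repeatedly moves the
    /// initialization toward task-adapted parameters.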
    Reptile { inner_steps: usize, inner_lr: f64 },

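    /// MAML augmented with a prototype-based regularizer in the inner loop.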
    ProtoMAML {
        inner_steps: usize,
        inner_lr: f64,
        /// Weight of the prototype regularization term.
        proto_weight: f64,
    },

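    /// Meta-SGD: learns a per-parameter inner-loop learning rate alongside
    /// the initialization.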
    MetaSGD { inner_steps: usize },

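    /// ANIL ("Almost No Inner Loop"): only the final-layer parameters are
    /// adapted in the inner loop.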
    ANIL { inner_steps: usize, inner_lr: f64 },
}

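/// A few-shot task: a support (train) set for adaptation and a query (test)
/// set for evaluation, with labeled feature vectors.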
#[derive(Debug, Clone)]
pub struct MetaTask {
    /// Unique task identifier.
    pub id: String,

    /// Support set of (features, label) pairs used for inner-loop adaptation.
    pub train_data: Vec<(Array1<f64>, usize)>,

    /// Query set of (features, label) pairs used for evaluation.
    pub test_data: Vec<(Array1<f64>, usize)>,

    /// Number of classes in the task.
    pub num_classes: usize,

    /// Task-level metadata (e.g., generating parameters).
    pub metadata: HashMap<String, f64>,
}

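/// Meta-learner that trains a quantum neural network initialization so it
/// can adapt quickly to new tasks.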
pub struct QuantumMetaLearner {
    /// Meta-learning algorithm and its hyperparameters.
    algorithm: MetaLearningAlgorithm,

    /// Underlying quantum neural network.
    model: QuantumNeuralNetwork,

    /// Meta-learned initial parameters shared across tasks.
    meta_params: Array1<f64>,

    /// Per-parameter learning rates (Meta-SGD only).
    per_param_lr: Option<Array1<f64>>,

    /// Cached task embeddings, keyed by task id.
    task_embeddings: HashMap<String, Array1<f64>>,

    /// Metrics recorded during meta-training.
    history: MetaLearningHistory,
}

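/// Metrics recorded over the course of meta-training.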
#[derive(Debug, Clone)]
pub struct MetaLearningHistory {
    /// Meta-training loss per epoch.
    pub meta_train_losses: Vec<f64>,

    /// Meta-validation accuracy per epoch.
    pub meta_val_accuracies: Vec<f64>,

    /// Per-task performance traces, keyed by task id.
    pub task_performance: HashMap<String, Vec<f64>>,
}

impl QuantumMetaLearner {
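    /// Creates a meta-learner from a QNN; Meta-SGD additionally initializes
    /// per-parameter learning rates.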
    pub fn new(algorithm: MetaLearningAlgorithm, model: QuantumNeuralNetwork) -> Self {
        // Initialize the meta-parameters from the model's current parameters.
        let num_params = model.parameters.len();
        let meta_params = model.parameters.clone();

        let per_param_lr = match algorithm {
            MetaLearningAlgorithm::MetaSGD { .. } => Some(Array1::from_elem(num_params, 0.01)),
            _ => None,
        };

        Self {
            algorithm,
            model,
            meta_params,
            per_param_lr,
            task_embeddings: HashMap::new(),
            history: MetaLearningHistory {
                meta_train_losses: Vec::new(),
                meta_val_accuracies: Vec::new(),
                task_performance: HashMap::new(),
            },
        }
    }

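    /// Runs the outer meta-training loop: each epoch samples a batch of tasks
    /// and applies the configured algorithm's meta-update.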
    pub fn meta_train(
        &mut self,
        tasks: &[MetaTask],
        meta_optimizer: &mut dyn Optimizer,
        meta_epochs: usize,
        tasks_per_batch: usize,
    ) -> Result<()> {
        println!("Starting meta-training with {} tasks...", tasks.len());

        for epoch in 0..meta_epochs {
            let task_batch = self.sample_task_batch(tasks, tasks_per_batch);

            // Dispatch to the meta-update rule for the configured algorithm.
            let (epoch_loss, epoch_acc) = match self.algorithm {
                MetaLearningAlgorithm::MAML { .. } => {
                    self.maml_update(&task_batch, meta_optimizer)?
                }
                MetaLearningAlgorithm::Reptile { .. } => {
                    self.reptile_update(&task_batch, meta_optimizer)?
                }
                MetaLearningAlgorithm::ProtoMAML { .. } => {
                    self.protomaml_update(&task_batch, meta_optimizer)?
                }
                MetaLearningAlgorithm::MetaSGD { .. } => {
                    self.metasgd_update(&task_batch, meta_optimizer)?
                }
                MetaLearningAlgorithm::ANIL { .. } => {
                    self.anil_update(&task_batch, meta_optimizer)?
                }
            };

            self.history.meta_train_losses.push(epoch_loss);
            self.history.meta_val_accuracies.push(epoch_acc);

            if epoch % 10 == 0 {
                println!(
                    "Epoch {}: Loss = {:.4}, Accuracy = {:.2}%",
                    epoch,
                    epoch_loss,
                    epoch_acc * 100.0
                );
            }
        }

        Ok(())
    }

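    /// MAML meta-update: adapts a copy of the meta-parameters on each task's
    /// support set, then aggregates query-set gradients into a meta-gradient.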
    fn maml_update(
        &mut self,
        tasks: &[MetaTask],
        _optimizer: &mut dyn Optimizer,
    ) -> Result<(f64, f64)> {
        let (inner_steps, inner_lr, first_order) = match self.algorithm {
            MetaLearningAlgorithm::MAML {
                inner_steps,
                inner_lr,
                first_order,
            } => (inner_steps, inner_lr, first_order),
            _ => unreachable!(),
        };

        let mut total_loss = 0.0;
        let mut total_acc = 0.0;
        let mut meta_gradients = Array1::zeros(self.meta_params.len());

        for task in tasks {
            // Inner loop: adapt a copy of the meta-parameters on the support set.
            let mut task_params = self.meta_params.clone();
            for _ in 0..inner_steps {
                let grad = self.compute_task_gradient(&task.train_data, &task_params)?;
                task_params = task_params - inner_lr * &grad;
            }

            // Outer objective: evaluate the adapted parameters on the query set.
            let (query_loss, query_acc) = self.evaluate_task(&task.test_data, &task_params)?;
            total_loss += query_loss;
            total_acc += query_acc;

            if !first_order {
                // Full MAML: differentiate through the inner-loop updates.
                let meta_grad = self.compute_maml_gradient(task, &task_params, inner_lr)?;
                meta_gradients = meta_gradients + meta_grad;
            } else {
                // First-order MAML: use the query gradient at the adapted parameters.
                let grad = self.compute_task_gradient(&task.test_data, &task_params)?;
                meta_gradients = meta_gradients + grad;
            }
        }

        // Average over the task batch and take a plain gradient step with a
        // fixed meta learning rate (the passed optimizer is not yet used here).
        meta_gradients = meta_gradients / tasks.len() as f64;
        self.meta_params = self.meta_params.clone() - 0.001 * &meta_gradients;

        Ok((
            total_loss / tasks.len() as f64,
            total_acc / tasks.len() as f64,
        ))
    }

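    /// Reptile meta-update: adapts to each task, then interpolates the
    /// meta-parameters toward the adapted solution.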
    fn reptile_update(
        &mut self,
        tasks: &[MetaTask],
        _optimizer: &mut dyn Optimizer,
    ) -> Result<(f64, f64)> {
        let (inner_steps, inner_lr) = match self.algorithm {
            MetaLearningAlgorithm::Reptile {
                inner_steps,
                inner_lr,
            } => (inner_steps, inner_lr),
            _ => unreachable!(),
        };

        let mut total_loss = 0.0;
        let mut total_acc = 0.0;
        // Interpolation rate: how far to move toward each task's adapted parameters.
        let epsilon = 0.1;

        for task in tasks {
            let mut task_params = self.meta_params.clone();
            for _ in 0..inner_steps {
                let grad = self.compute_task_gradient(&task.train_data, &task_params)?;
                task_params = task_params - inner_lr * &grad;
            }

            let (loss, acc) = self.evaluate_task(&task.test_data, &task_params)?;
            total_loss += loss;
            total_acc += acc;

            // Reptile step: meta_params += epsilon * (adapted - meta).
            let direction = &task_params - &self.meta_params;
            self.meta_params = &self.meta_params + epsilon * &direction;
        }

        Ok((
            total_loss / tasks.len() as f64,
            total_acc / tasks.len() as f64,
        ))
    }

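    /// ProtoMAML meta-update: MAML-style inner loop with a prototype
    /// regularizer, evaluated with prototype-based classification.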
    fn protomaml_update(
        &mut self,
        tasks: &[MetaTask],
        _optimizer: &mut dyn Optimizer,
    ) -> Result<(f64, f64)> {
        let (inner_steps, inner_lr, proto_weight) = match self.algorithm {
            MetaLearningAlgorithm::ProtoMAML {
                inner_steps,
                inner_lr,
                proto_weight,
            } => (inner_steps, inner_lr, proto_weight),
            _ => unreachable!(),
        };

        let mut total_loss = 0.0;
        let mut total_acc = 0.0;

        for task in tasks {
            // Class prototypes computed from the support set.
            let prototypes = self.compute_prototypes(&task.train_data, task.num_classes)?;

            let mut task_params = self.meta_params.clone();

            for _ in 0..inner_steps {
                let grad = self.compute_task_gradient(&task.train_data, &task_params)?;
                let proto_reg =
                    self.prototype_regularization(&task.train_data, &prototypes, &task_params)?;
                // Inner step on the task gradient plus the weighted prototype regularizer.
                task_params = task_params - inner_lr * (&grad + proto_weight * &proto_reg);
            }

            let (loss, acc) =
                self.evaluate_with_prototypes(&task.test_data, &prototypes, &task_params)?;
            total_loss += loss;
            total_acc += acc;
        }

        Ok((
            total_loss / tasks.len() as f64,
            total_acc / tasks.len() as f64,
        ))
    }

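    /// Meta-SGD meta-update: the inner loop uses learned per-parameter
    /// learning rates, which are themselves updated from accumulated gradients.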
    fn metasgd_update(
        &mut self,
        tasks: &[MetaTask],
        _optimizer: &mut dyn Optimizer,
    ) -> Result<(f64, f64)> {
        let inner_steps = match self.algorithm {
            MetaLearningAlgorithm::MetaSGD { inner_steps } => inner_steps,
            _ => unreachable!(),
        };

        let mut total_loss = 0.0;
        let mut total_acc = 0.0;
        let mut meta_lr_gradients = Array1::zeros(self.meta_params.len());

        for task in tasks {
            let mut task_params = self.meta_params.clone();

            for _ in 0..inner_steps {
                let grad = self.compute_task_gradient(&task.train_data, &task_params)?;
                // Element-wise learned learning rates, one per parameter.
                let lr = self
                    .per_param_lr
                    .as_ref()
                    .expect("per_param_lr must be initialized for MetaSGD");
                task_params = task_params - lr * &grad;
            }

            let (loss, acc) = self.evaluate_task(&task.test_data, &task_params)?;
            total_loss += loss;
            total_acc += acc;

            let lr_grad = self.compute_lr_gradient(task, &task_params)?;
            meta_lr_gradients = meta_lr_gradients + lr_grad;
        }

        // Meta-update the per-parameter learning rates with a fixed step size.
        if let Some(ref mut lr) = self.per_param_lr {
            *lr = lr.clone() - &(0.001 * &meta_lr_gradients / tasks.len() as f64);
        }

        Ok((
            total_loss / tasks.len() as f64,
            total_acc / tasks.len() as f64,
        ))
    }

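    /// ANIL meta-update: only the final-layer parameters are adapted in the
    /// inner loop.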
    fn anil_update(
        &mut self,
        tasks: &[MetaTask],
        _optimizer: &mut dyn Optimizer,
    ) -> Result<(f64, f64)> {
        let (inner_steps, inner_lr) = match self.algorithm {
            MetaLearningAlgorithm::ANIL {
                inner_steps,
                inner_lr,
            } => (inner_steps, inner_lr),
            _ => unreachable!(),
        };

        // ANIL adapts only the "head": here, the final quarter of the parameters.
        let num_params = self.meta_params.len();
        let final_layer_start = (num_params * 3) / 4;

        let mut total_loss = 0.0;
        let mut total_acc = 0.0;

        for task in tasks {
            let mut task_params = self.meta_params.clone();

            for _ in 0..inner_steps {
                let grad = self.compute_task_gradient(&task.train_data, &task_params)?;

                // Update only the final-layer parameters in the inner loop.
                for i in final_layer_start..num_params {
                    task_params[i] -= inner_lr * grad[i];
                }
            }

            let (loss, acc) = self.evaluate_task(&task.test_data, &task_params)?;
            total_loss += loss;
            total_acc += acc;
        }

        Ok((
            total_loss / tasks.len() as f64,
            total_acc / tasks.len() as f64,
        ))
    }

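    /// Gradient of the task loss with respect to the circuit parameters
    /// (currently a placeholder).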
    fn compute_task_gradient(
        &self,
        _data: &[(Array1<f64>, usize)],
        params: &Array1<f64>,
    ) -> Result<Array1<f64>> {
        // Placeholder: a full implementation would estimate gradients of the
        // task loss, e.g. via the parameter-shift rule on the quantum circuit.
        Ok(Array1::zeros(params.len()))
    }

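    /// Evaluates (loss, accuracy) of the given parameters on a dataset
    /// (currently a placeholder).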
    fn evaluate_task(
        &self,
        _data: &[(Array1<f64>, usize)],
        _params: &Array1<f64>,
    ) -> Result<(f64, f64)> {
        // Placeholder: returns random loss/accuracy until circuit evaluation is wired in.
        let loss = 0.5 + 0.5 * thread_rng().gen::<f64>();
        let acc = 0.5 + 0.3 * thread_rng().gen::<f64>();
        Ok((loss, acc))
    }

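    /// Second-order MAML meta-gradient through the inner-loop updates
    /// (currently a placeholder).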
    fn compute_maml_gradient(
        &self,
        _task: &MetaTask,
        _adapted_params: &Array1<f64>,
        _inner_lr: f64,
    ) -> Result<Array1<f64>> {
        // Placeholder: the full second-order MAML gradient is not yet implemented.
        Ok(Array1::zeros(self.meta_params.len()))
    }

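    /// Computes per-class prototype vectors as the mean of the support
    /// examples belonging to each class.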
    fn compute_prototypes(
        &self,
        data: &[(Array1<f64>, usize)],
        num_classes: usize,
    ) -> Result<Vec<Array1<f64>>> {
        // Prototype dimension follows the input features (falls back to 16 for empty data).
        let feature_dim = data.first().map_or(16, |(x, _)| x.len());
        let mut prototypes = vec![Array1::zeros(feature_dim); num_classes];
        let mut counts = vec![0usize; num_classes];

        // Accumulate per-class sums, then normalize each to the class mean.
        for (x, label) in data {
            prototypes[*label] += x;
            counts[*label] += 1;
        }
        for (proto, &count) in prototypes.iter_mut().zip(counts.iter()) {
            if count > 0 {
                *proto /= count as f64;
            }
        }

        Ok(prototypes)
    }

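    /// Gradient of the prototype regularization term (currently a placeholder).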
    fn prototype_regularization(
        &self,
        _data: &[(Array1<f64>, usize)],
        _prototypes: &[Array1<f64>],
        params: &Array1<f64>,
    ) -> Result<Array1<f64>> {
        // Placeholder: returns a zero regularization gradient.
        Ok(Array1::zeros(params.len()))
    }

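    /// Evaluates (loss, accuracy) using nearest-prototype classification
    /// (currently a placeholder).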
    fn evaluate_with_prototypes(
        &self,
        _data: &[(Array1<f64>, usize)],
        _prototypes: &[Array1<f64>],
        _params: &Array1<f64>,
    ) -> Result<(f64, f64)> {
        // Placeholder: returns a fixed (loss, accuracy) pair.
        Ok((0.3, 0.7))
    }

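    /// Gradient of the query loss with respect to the per-parameter learning
    /// rates (currently a placeholder).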
    fn compute_lr_gradient(
        &self,
        _task: &MetaTask,
        _adapted_params: &Array1<f64>,
    ) -> Result<Array1<f64>> {
        // Placeholder: returns zero gradients for the per-parameter learning rates.
        Ok(Array1::zeros(self.meta_params.len()))
    }

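    /// Samples a batch of tasks uniformly at random, with replacement.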
    fn sample_task_batch(&self, tasks: &[MetaTask], batch_size: usize) -> Vec<MetaTask> {
        let mut batch = Vec::new();
        let mut rng = thread_rng();

        // Sample with replacement, capped at the number of available tasks.
        for _ in 0..batch_size.min(tasks.len()) {
            let idx = rng.gen_range(0..tasks.len());
            batch.push(tasks[idx].clone());
        }

        batch
    }

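    /// Adapts the meta-learned parameters to a single task by running the
    /// configured inner loop, returning the adapted parameters.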
    pub fn adapt_to_task(&mut self, task: &MetaTask) -> Result<Array1<f64>> {
        let adapted_params = match self.algorithm {
            MetaLearningAlgorithm::MAML {
                inner_steps,
                inner_lr,
                ..
            }
            | MetaLearningAlgorithm::Reptile {
                inner_steps,
                inner_lr,
            }
            | MetaLearningAlgorithm::ProtoMAML {
                inner_steps,
                inner_lr,
                ..
            }
            | MetaLearningAlgorithm::ANIL {
                inner_steps,
                inner_lr,
            } => {
                let mut params = self.meta_params.clone();
                for _ in 0..inner_steps {
                    let grad = self.compute_task_gradient(&task.train_data, &params)?;
                    params = params - inner_lr * &grad;
                }
                params
            }
            MetaLearningAlgorithm::MetaSGD { inner_steps } => {
                let mut params = self.meta_params.clone();
                let lr = self
                    .per_param_lr
                    .as_ref()
                    .expect("per_param_lr must be initialized for MetaSGD");
                for _ in 0..inner_steps {
                    let grad = self.compute_task_gradient(&task.train_data, &params)?;
                    params = params - lr * &grad;
                }
                params
            }
        };

        Ok(adapted_params)
    }

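    /// Returns the stored embedding for a task, if one has been recorded.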
    pub fn get_task_embedding(&self, task_id: &str) -> Option<&Array1<f64>> {
        self.task_embeddings.get(task_id)
    }

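    /// Returns the current meta-parameters.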
    pub fn meta_params(&self) -> &Array1<f64> {
        &self.meta_params
    }

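    /// Returns the learned per-parameter learning rates (Meta-SGD only).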
    pub fn per_param_lr(&self) -> Option<&Array1<f64>> {
        self.per_param_lr.as_ref()
    }
}

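/// Continual meta-learner that mitigates catastrophic forgetting by replaying
/// tasks from a bounded memory buffer.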
pub struct ContinualMetaLearner {
    /// Underlying quantum meta-learner.
    meta_learner: QuantumMetaLearner,

    /// Replay buffer of previously seen tasks.
    memory_buffer: Vec<MetaTask>,

    /// Maximum number of tasks retained in the buffer.
    memory_capacity: usize,

    /// Fraction of the buffer replayed alongside each new task.
    replay_ratio: f64,
}

impl ContinualMetaLearner {
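    /// Wraps a meta-learner with a replay buffer of the given capacity.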
    pub fn new(
        meta_learner: QuantumMetaLearner,
        memory_capacity: usize,
        replay_ratio: f64,
    ) -> Self {
        Self {
            meta_learner,
            memory_buffer: Vec::new(),
            memory_capacity,
            replay_ratio,
        }
    }

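    /// Learns a new task while replaying a fraction of previously seen tasks.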
    pub fn learn_task(&mut self, new_task: MetaTask) -> Result<()> {
        // Store the task, evicting a random entry once the buffer is full.
        if self.memory_buffer.len() < self.memory_capacity {
            self.memory_buffer.push(new_task.clone());
        } else {
            let idx = fastrand::usize(0..self.memory_buffer.len());
            self.memory_buffer[idx] = new_task.clone();
        }

        // Mix the new task with randomly replayed tasks to mitigate forgetting.
        let num_replay = (self.memory_buffer.len() as f64 * self.replay_ratio) as usize;
        let mut task_batch = vec![new_task];

        for _ in 0..num_replay {
            let idx = fastrand::usize(0..self.memory_buffer.len());
            task_batch.push(self.memory_buffer[idx].clone());
        }

        // Short meta-training pass over the mixed batch (fresh Adam each call).
        let mut dummy_optimizer = crate::autodiff::optimizers::Adam::new(0.001);
        self.meta_learner
            .meta_train(&task_batch, &mut dummy_optimizer, 10, task_batch.len())?;

        Ok(())
    }

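    /// Number of tasks currently held in the replay buffer.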
    pub fn memory_buffer_len(&self) -> usize {
        self.memory_buffer.len()
    }
}

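/// Generates synthetic few-shot tasks for benchmarking meta-learning.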
pub struct TaskGenerator {
    /// Dimensionality of generated feature vectors.
    feature_dim: usize,

    /// Number of classes per generated task.
    num_classes: usize,

    /// Optional generator parameters (currently unused).
    task_params: HashMap<String, f64>,
}

impl TaskGenerator {
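    /// Creates a generator producing tasks with the given feature dimension
    /// and number of classes.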
    pub fn new(feature_dim: usize, num_classes: usize) -> Self {
        Self {
            feature_dim,
            num_classes,
            task_params: HashMap::new(),
        }
    }

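    /// Generates a binary task from a random sinusoid: the label is the sign
    /// of `amplitude * sin(x + phase)`; half the samples form the support set.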
    pub fn generate_sinusoid_task(&self, num_samples: usize) -> MetaTask {
        let amplitude = 0.1 + 4.9 * thread_rng().gen::<f64>();
        let phase = 2.0 * std::f64::consts::PI * thread_rng().gen::<f64>();

        let mut train_data = Vec::new();
        let mut test_data = Vec::new();

        for i in 0..num_samples {
            let x = -5.0 + 10.0 * thread_rng().gen::<f64>();
            let y = amplitude * (x + phase).sin();

            let input = Array1::from_vec(vec![x]);
            // Binarize the regression target: positive vs. negative half-wave.
            let label = if y > 0.0 { 1 } else { 0 };

            // First half of the samples form the support set, the rest the query set.
            if i < num_samples / 2 {
                train_data.push((input, label));
            } else {
                test_data.push((input, label));
            }
        }

        MetaTask {
            id: format!("sin_a{:.2}_p{:.2}", amplitude, phase),
            train_data,
            test_data,
            num_classes: 2,
            metadata: vec![
                ("amplitude".to_string(), amplitude),
                ("phase".to_string(), phase),
            ]
            .into_iter()
            .collect(),
        }
    }

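    /// Generates a classification task whose first two feature dimensions are
    /// rotated by a random task-specific angle.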
    pub fn generate_rotation_task(&self, num_samples: usize) -> MetaTask {
        let angle = 2.0 * std::f64::consts::PI * thread_rng().gen::<f64>();
        let cos_a = angle.cos();
        let sin_a = angle.sin();

        let mut train_data = Vec::new();
        let mut test_data = Vec::new();

        for i in 0..num_samples {
            let mut features = Array1::zeros(self.feature_dim);
            let label = i % self.num_classes;

            // One-hot-like class pattern plus small uniform noise.
            for j in 0..self.feature_dim {
                features[j] = if j % self.num_classes == label {
                    1.0
                } else {
                    0.0
                };
                features[j] += 0.1 * thread_rng().gen::<f64>();
            }

            // Rotate the first two feature dimensions by the task angle.
            if self.feature_dim >= 2 {
                let x = features[0];
                let y = features[1];
                features[0] = cos_a * x - sin_a * y;
                features[1] = sin_a * x + cos_a * y;
            }

            if i < num_samples / 2 {
                train_data.push((features, label));
            } else {
                test_data.push((features, label));
            }
        }

        MetaTask {
            id: format!("rot_{:.2}", angle),
            train_data,
            test_data,
            num_classes: self.num_classes,
            metadata: vec![("rotation_angle".to_string(), angle)]
                .into_iter()
                .collect(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::qnn::QNNLayerType;

    #[test]
    fn test_task_generator() {
        let generator = TaskGenerator::new(4, 2);

        let sin_task = generator.generate_sinusoid_task(20);
        assert_eq!(sin_task.train_data.len(), 10);
        assert_eq!(sin_task.test_data.len(), 10);

        let rot_task = generator.generate_rotation_task(30);
        assert_eq!(rot_task.train_data.len(), 15);
        assert_eq!(rot_task.test_data.len(), 15);
    }

    #[test]
    fn test_meta_learner_creation() {
        let layers = vec![
            QNNLayerType::EncodingLayer { num_features: 4 },
            QNNLayerType::VariationalLayer { num_params: 8 },
            QNNLayerType::MeasurementLayer {
                measurement_basis: "computational".to_string(),
            },
        ];

        let qnn = QuantumNeuralNetwork::new(layers, 4, 4, 2).expect("Failed to create QNN");

        let maml_algo = MetaLearningAlgorithm::MAML {
            inner_steps: 5,
            inner_lr: 0.01,
            first_order: true,
        };

        // MAML does not use per-parameter learning rates...
        let meta_learner = QuantumMetaLearner::new(maml_algo, qnn);
        assert!(meta_learner.per_param_lr.is_none());

        let layers2 = vec![
            QNNLayerType::EncodingLayer { num_features: 4 },
            QNNLayerType::VariationalLayer { num_params: 8 },
        ];
        let qnn2 =
            QuantumNeuralNetwork::new(layers2, 4, 4, 2).expect("Failed to create QNN for Meta-SGD");

        // ...while Meta-SGD initializes one learning rate per parameter.
        let metasgd_algo = MetaLearningAlgorithm::MetaSGD { inner_steps: 3 };
        let meta_sgd = QuantumMetaLearner::new(metasgd_algo, qnn2);
        assert!(meta_sgd.per_param_lr.is_some());
    }

    #[test]
    fn test_task_adaptation() {
        let layers = vec![
            QNNLayerType::EncodingLayer { num_features: 2 },
            QNNLayerType::VariationalLayer { num_params: 6 },
        ];

        let qnn = QuantumNeuralNetwork::new(layers, 4, 2, 2).expect("Failed to create QNN");
        let algo = MetaLearningAlgorithm::Reptile {
            inner_steps: 5,
            inner_lr: 0.01,
        };

        let mut meta_learner = QuantumMetaLearner::new(algo, qnn);

        let generator = TaskGenerator::new(2, 2);
        let task = generator.generate_rotation_task(20);

        let adapted_params = meta_learner
            .adapt_to_task(&task)
            .expect("Task adaptation should succeed");
        assert_eq!(adapted_params.len(), meta_learner.meta_params.len());
    }
}