quantrs2_ml/meta_learning.rs

//! Quantum Meta-Learning Algorithms
//!
//! This module implements various meta-learning algorithms adapted for quantum circuits,
//! enabling quantum models to learn how to learn from limited data across multiple tasks.
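//!
//! # Example
//!
//! A minimal usage sketch (illustrative only, marked `ignore` so it is not run as a
//! doctest); it follows the constructor signatures used in this module's tests, and the
//! `quantrs2_ml::` import paths are assumed:
//!
//! ```ignore
//! use quantrs2_ml::autodiff::optimizers::Adam;
//! use quantrs2_ml::qnn::{QNNLayerType, QuantumNeuralNetwork};
//!
//! let layers = vec![
//!     QNNLayerType::EncodingLayer { num_features: 4 },
//!     QNNLayerType::VariationalLayer { num_params: 8 },
//! ];
//! let qnn = QuantumNeuralNetwork::new(layers, 4, 4, 2)?;
//!
//! let algorithm = MetaLearningAlgorithm::MAML {
//!     inner_steps: 5,
//!     inner_lr: 0.01,
//!     first_order: true,
//! };
//! let mut learner = QuantumMetaLearner::new(algorithm, qnn);
//!
//! // Meta-train on a batch of generated tasks, then adapt to a new one
//! let generator = TaskGenerator::new(4, 2);
//! let tasks: Vec<MetaTask> = (0..8).map(|_| generator.generate_rotation_task(20)).collect();
//! let mut optimizer = Adam::new(0.001);
//! learner.meta_train(&tasks, &mut optimizer, 100, 4)?;
//! let adapted_params = learner.adapt_to_task(&generator.generate_rotation_task(20))?;
//! ```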

use crate::autodiff::optimizers::Optimizer;
use crate::error::{MLError, Result};
use crate::optimization::OptimizationMethod;
use crate::qnn::{QNNLayerType, QuantumNeuralNetwork};
use ndarray::{s, Array1, Array2, Array3, Axis};
use quantrs2_circuit::builder::{Circuit, Simulator};
use quantrs2_core::gate::{
    single::{RotationX, RotationY, RotationZ},
    GateOp,
};
use quantrs2_sim::statevector::StateVectorSimulator;
use std::collections::HashMap;

/// Meta-learning algorithm types
#[derive(Debug, Clone, Copy)]
pub enum MetaLearningAlgorithm {
    /// Model-Agnostic Meta-Learning
    MAML {
        inner_steps: usize,
        inner_lr: f64,
        first_order: bool,
    },

    /// Reptile algorithm
    Reptile { inner_steps: usize, inner_lr: f64 },

    /// Prototypical MAML
    ProtoMAML {
        inner_steps: usize,
        inner_lr: f64,
        proto_weight: f64,
    },

    /// Meta-SGD with learnable inner learning rates
    MetaSGD { inner_steps: usize },

    /// Almost No Inner Loop (ANIL)
    ANIL { inner_steps: usize, inner_lr: f64 },
}

/// Task definition for meta-learning
#[derive(Debug, Clone)]
pub struct MetaTask {
    /// Task identifier
    pub id: String,

    /// Training data (support set)
    pub train_data: Vec<(Array1<f64>, usize)>,

    /// Test data (query set)
    pub test_data: Vec<(Array1<f64>, usize)>,

    /// Number of classes
    pub num_classes: usize,

    /// Task-specific metadata
    pub metadata: HashMap<String, f64>,
}

/// Base quantum meta-learner
pub struct QuantumMetaLearner {
    /// Meta-learning algorithm
    algorithm: MetaLearningAlgorithm,

    /// Base quantum model
    model: QuantumNeuralNetwork,

    /// Meta-parameters
    meta_params: Array1<f64>,

    /// Per-parameter learning rates (for Meta-SGD)
    per_param_lr: Option<Array1<f64>>,

    /// Task embeddings
    task_embeddings: HashMap<String, Array1<f64>>,

    /// Training history
    history: MetaLearningHistory,
}

/// Training history for meta-learning
#[derive(Debug, Clone)]
pub struct MetaLearningHistory {
    /// Meta-train losses
    pub meta_train_losses: Vec<f64>,

    /// Meta-validation accuracies
    pub meta_val_accuracies: Vec<f64>,

    /// Per-task performance
    pub task_performance: HashMap<String, Vec<f64>>,
}

impl QuantumMetaLearner {
    /// Create a new quantum meta-learner
    pub fn new(algorithm: MetaLearningAlgorithm, model: QuantumNeuralNetwork) -> Self {
        let num_params = model.parameters.len();
        let meta_params = model.parameters.clone();

        let per_param_lr = match algorithm {
            MetaLearningAlgorithm::MetaSGD { .. } => Some(Array1::from_elem(num_params, 0.01)),
            _ => None,
        };

        Self {
            algorithm,
            model,
            meta_params,
            per_param_lr,
            task_embeddings: HashMap::new(),
            history: MetaLearningHistory {
                meta_train_losses: Vec::new(),
                meta_val_accuracies: Vec::new(),
                task_performance: HashMap::new(),
            },
        }
    }

    /// Meta-train on multiple tasks
    pub fn meta_train(
        &mut self,
        tasks: &[MetaTask],
        meta_optimizer: &mut dyn Optimizer,
        meta_epochs: usize,
        tasks_per_batch: usize,
    ) -> Result<()> {
        println!("Starting meta-training with {} tasks...", tasks.len());

        for epoch in 0..meta_epochs {
            let mut epoch_loss = 0.0;
            let mut epoch_acc = 0.0;

            // Sample batch of tasks
            let task_batch = self.sample_task_batch(tasks, tasks_per_batch);

            // Perform meta-update based on algorithm
            match self.algorithm {
                MetaLearningAlgorithm::MAML { .. } => {
                    let (loss, acc) = self.maml_update(&task_batch, meta_optimizer)?;
                    epoch_loss += loss;
                    epoch_acc += acc;
                }
                MetaLearningAlgorithm::Reptile { .. } => {
                    let (loss, acc) = self.reptile_update(&task_batch, meta_optimizer)?;
                    epoch_loss += loss;
                    epoch_acc += acc;
                }
                MetaLearningAlgorithm::ProtoMAML { .. } => {
                    let (loss, acc) = self.protomaml_update(&task_batch, meta_optimizer)?;
                    epoch_loss += loss;
                    epoch_acc += acc;
                }
                MetaLearningAlgorithm::MetaSGD { .. } => {
                    let (loss, acc) = self.metasgd_update(&task_batch, meta_optimizer)?;
                    epoch_loss += loss;
                    epoch_acc += acc;
                }
                MetaLearningAlgorithm::ANIL { .. } => {
                    let (loss, acc) = self.anil_update(&task_batch, meta_optimizer)?;
                    epoch_loss += loss;
                    epoch_acc += acc;
                }
            }

            // Update history
            self.history.meta_train_losses.push(epoch_loss);
            self.history.meta_val_accuracies.push(epoch_acc);

            if epoch % 10 == 0 {
                println!(
                    "Epoch {}: Loss = {:.4}, Accuracy = {:.2}%",
                    epoch,
                    epoch_loss,
                    epoch_acc * 100.0
                );
            }
        }

        Ok(())
    }

    /// MAML update step
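    ///
    /// Inner loop (per task): adapt a copy of the meta-parameters on the support
    /// set, `theta_i' = theta - inner_lr * grad L_support(theta)`, repeated
    /// `inner_steps` times.
    /// Outer loop: update the meta-parameters from the query-set loss of the
    /// adapted copies, `theta <- theta - beta * sum_i grad L_query(theta_i')`,
    /// using either the full second-order gradient or the first-order
    /// approximation when `first_order` is set.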
    fn maml_update(
        &mut self,
        tasks: &[MetaTask],
        optimizer: &mut dyn Optimizer,
    ) -> Result<(f64, f64)> {
        let (inner_steps, inner_lr, first_order) = match self.algorithm {
            MetaLearningAlgorithm::MAML {
                inner_steps,
                inner_lr,
                first_order,
            } => (inner_steps, inner_lr, first_order),
            _ => unreachable!(),
        };

        let mut total_loss = 0.0;
        let mut total_acc = 0.0;
        let mut meta_gradients = Array1::zeros(self.meta_params.len());

        for task in tasks {
            // Clone meta-parameters for inner loop
            let mut task_params = self.meta_params.clone();

            // Inner loop: adapt to task
            for _ in 0..inner_steps {
                let grad = self.compute_task_gradient(&task.train_data, &task_params)?;
                task_params = task_params - inner_lr * &grad;
            }

            // Compute loss on query set with adapted parameters
            let (query_loss, query_acc) = self.evaluate_task(&task.test_data, &task_params)?;
            total_loss += query_loss;
            total_acc += query_acc;

            // Compute meta-gradient
            if !first_order {
                // Full second-order MAML gradient
                let meta_grad = self.compute_maml_gradient(task, &task_params, inner_lr)?;
                meta_gradients = meta_gradients + meta_grad;
            } else {
                // First-order approximation (FO-MAML)
                let grad = self.compute_task_gradient(&task.test_data, &task_params)?;
                meta_gradients = meta_gradients + grad;
            }
        }

        // Average gradients and update meta-parameters. A fixed meta learning
        // rate is used here; the `optimizer` argument is not yet applied in
        // this placeholder implementation.
        meta_gradients = meta_gradients / tasks.len() as f64;
        self.meta_params = self.meta_params.clone() - 0.001 * &meta_gradients;

        Ok((
            total_loss / tasks.len() as f64,
            total_acc / tasks.len() as f64,
        ))
    }

    /// Reptile update step
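    ///
    /// For each sampled task, run a few SGD steps on the support set and then
    /// move the meta-parameters toward the adapted parameters:
    /// `theta <- theta + epsilon * (theta_task - theta)`.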
    fn reptile_update(
        &mut self,
        tasks: &[MetaTask],
        optimizer: &mut dyn Optimizer,
    ) -> Result<(f64, f64)> {
        let (inner_steps, inner_lr) = match self.algorithm {
            MetaLearningAlgorithm::Reptile {
                inner_steps,
                inner_lr,
            } => (inner_steps, inner_lr),
            _ => unreachable!(),
        };

        let mut total_loss = 0.0;
        let mut total_acc = 0.0;
        let epsilon = 0.1; // Reptile step size

        for task in tasks {
            // Clone meta-parameters
            let mut task_params = self.meta_params.clone();

            // Perform multiple SGD steps on task
            for _ in 0..inner_steps {
                let grad = self.compute_task_gradient(&task.train_data, &task_params)?;
                task_params = task_params - inner_lr * &grad;
            }

            // Evaluate adapted model
            let (loss, acc) = self.evaluate_task(&task.test_data, &task_params)?;
            total_loss += loss;
            total_acc += acc;

            // Reptile update: move meta-params toward task-adapted params
            let direction = &task_params - &self.meta_params;
            self.meta_params = &self.meta_params + epsilon * &direction;
        }

        Ok((
            total_loss / tasks.len() as f64,
            total_acc / tasks.len() as f64,
        ))
    }

    /// ProtoMAML update step
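    ///
    /// Combines MAML-style inner-loop adaptation with prototypical networks:
    /// class prototypes (per-class mean embeddings of the support set) both
    /// regularize the inner loop (weighted by `proto_weight`) and drive
    /// nearest-prototype classification on the query set.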
    fn protomaml_update(
        &mut self,
        tasks: &[MetaTask],
        optimizer: &mut dyn Optimizer,
    ) -> Result<(f64, f64)> {
        let (inner_steps, inner_lr, proto_weight) = match self.algorithm {
            MetaLearningAlgorithm::ProtoMAML {
                inner_steps,
                inner_lr,
                proto_weight,
            } => (inner_steps, inner_lr, proto_weight),
            _ => unreachable!(),
        };

        let mut total_loss = 0.0;
        let mut total_acc = 0.0;

        for task in tasks {
            // Compute prototypes for each class
            let prototypes = self.compute_prototypes(&task.train_data, task.num_classes)?;

            // Clone parameters for adaptation
            let mut task_params = self.meta_params.clone();

            // Inner loop with prototype regularization
            for _ in 0..inner_steps {
                let grad = self.compute_task_gradient(&task.train_data, &task_params)?;
                let proto_reg =
                    self.prototype_regularization(&task.train_data, &prototypes, &task_params)?;
                task_params = task_params - inner_lr * (&grad + proto_weight * &proto_reg);
            }

            // Evaluate with prototypical classification
            let (loss, acc) =
                self.evaluate_with_prototypes(&task.test_data, &prototypes, &task_params)?;
            total_loss += loss;
            total_acc += acc;
        }

        Ok((
            total_loss / tasks.len() as f64,
            total_acc / tasks.len() as f64,
        ))
    }

    /// Meta-SGD update step
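    ///
    /// Meta-SGD learns a per-parameter learning-rate vector `alpha` alongside
    /// the initialization, so the inner update is the element-wise
    /// `theta' = theta - alpha * grad L_support(theta)`; the outer loop adjusts
    /// `alpha` from the query-set loss.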
    fn metasgd_update(
        &mut self,
        tasks: &[MetaTask],
        optimizer: &mut dyn Optimizer,
    ) -> Result<(f64, f64)> {
        let inner_steps = match self.algorithm {
            MetaLearningAlgorithm::MetaSGD { inner_steps } => inner_steps,
            _ => unreachable!(),
        };

        let mut total_loss = 0.0;
        let mut total_acc = 0.0;
        let mut meta_lr_gradients = Array1::zeros(self.meta_params.len());

        for task in tasks {
            let mut task_params = self.meta_params.clone();

            // Inner loop with learnable per-parameter learning rates
            for _ in 0..inner_steps {
                let grad = self.compute_task_gradient(&task.train_data, &task_params)?;
                let lr = self.per_param_lr.as_ref().unwrap();
                task_params = task_params - lr * &grad;
            }

            // Evaluate and compute gradients w.r.t. both parameters and learning rates
            let (loss, acc) = self.evaluate_task(&task.test_data, &task_params)?;
            total_loss += loss;
            total_acc += acc;

            // Compute gradient w.r.t. learning rates
            let lr_grad = self.compute_lr_gradient(task, &task_params)?;
            meta_lr_gradients = meta_lr_gradients + lr_grad;
        }

        // Update the per-parameter learning rates (the meta-parameters
        // themselves are not updated in this placeholder implementation)
        if let Some(ref mut lr) = self.per_param_lr {
            *lr = lr.clone() - &(0.001 * &meta_lr_gradients / tasks.len() as f64);
        }

        Ok((
            total_loss / tasks.len() as f64,
            total_acc / tasks.len() as f64,
        ))
    }

    /// ANIL update step
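    ///
    /// ANIL restricts inner-loop adaptation to the final layer(s) of the
    /// network (approximated here as the last quarter of the parameter
    /// vector), reusing the shared feature-extractor parameters across tasks.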
    fn anil_update(
        &mut self,
        tasks: &[MetaTask],
        optimizer: &mut dyn Optimizer,
    ) -> Result<(f64, f64)> {
        let (inner_steps, inner_lr) = match self.algorithm {
            MetaLearningAlgorithm::ANIL {
                inner_steps,
                inner_lr,
            } => (inner_steps, inner_lr),
            _ => unreachable!(),
        };

        // ANIL: Only adapt the final layer(s) in inner loop
        let num_params = self.meta_params.len();
        let final_layer_start = (num_params * 3) / 4; // Last 25% of parameters

        let mut total_loss = 0.0;
        let mut total_acc = 0.0;

        for task in tasks {
            let mut task_params = self.meta_params.clone();

            // Inner loop: only update final layer parameters
            for _ in 0..inner_steps {
                let grad = self.compute_task_gradient(&task.train_data, &task_params)?;

                // Only update final layer
                for i in final_layer_start..num_params {
                    task_params[i] -= inner_lr * grad[i];
                }
            }

            let (loss, acc) = self.evaluate_task(&task.test_data, &task_params)?;
            total_loss += loss;
            total_acc += acc;
        }

        Ok((
            total_loss / tasks.len() as f64,
            total_acc / tasks.len() as f64,
        ))
    }

    /// Compute gradient for a task
    fn compute_task_gradient(
        &self,
        data: &[(Array1<f64>, usize)],
        params: &Array1<f64>,
    ) -> Result<Array1<f64>> {
        // Placeholder - would compute actual quantum gradients
        Ok(Array1::zeros(params.len()))
    }
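
    /// Illustrative sketch only (not used by the update steps above): one way
    /// to realize `compute_task_gradient` for a parameterized quantum circuit
    /// is the parameter-shift rule, estimating each partial derivative from
    /// two loss evaluations at shifted parameter values. The `loss_fn` closure
    /// here is a hypothetical stand-in for the circuit's loss on the task data.
    #[allow(dead_code)]
    fn parameter_shift_gradient_sketch(
        &self,
        params: &Array1<f64>,
        loss_fn: &dyn Fn(&Array1<f64>) -> f64,
    ) -> Array1<f64> {
        let shift = std::f64::consts::FRAC_PI_2;
        let mut grad = Array1::<f64>::zeros(params.len());
        for i in 0..params.len() {
            // Shift parameter i by +/- pi/2 and take the symmetric difference
            let mut plus = params.clone();
            let mut minus = params.clone();
            plus[i] += shift;
            minus[i] -= shift;
            grad[i] = (loss_fn(&plus) - loss_fn(&minus)) / 2.0;
        }
        grad
    }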

    /// Evaluate model on task data
    fn evaluate_task(
        &self,
        data: &[(Array1<f64>, usize)],
        params: &Array1<f64>,
    ) -> Result<(f64, f64)> {
        // Placeholder - would evaluate quantum model
        let loss = 0.5 + 0.5 * rand::random::<f64>();
        let acc = 0.5 + 0.3 * rand::random::<f64>();
        Ok((loss, acc))
    }

    /// Compute MAML gradient (second-order)
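    ///
    /// For a single inner step, the full MAML gradient of the query loss with
    /// respect to the meta-parameters is
    /// `(I - inner_lr * H_support(theta)) * grad L_query(theta')`, i.e. it
    /// requires Hessian-vector products through the inner update.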
    fn compute_maml_gradient(
        &self,
        task: &MetaTask,
        adapted_params: &Array1<f64>,
        inner_lr: f64,
    ) -> Result<Array1<f64>> {
        // Placeholder - would compute Hessian-vector products
        Ok(Array1::zeros(self.meta_params.len()))
    }

    /// Compute prototypes for ProtoMAML
    fn compute_prototypes(
        &self,
        data: &[(Array1<f64>, usize)],
        num_classes: usize,
    ) -> Result<Vec<Array1<f64>>> {
        let feature_dim = 16; // Placeholder dimension
        let mut prototypes = vec![Array1::zeros(feature_dim); num_classes];
        let mut counts = vec![0; num_classes];

        // Placeholder - would encode data and compute class means
        for (_, label) in data {
            counts[*label] += 1;
        }

        Ok(prototypes)
    }

    /// Prototype regularization
    fn prototype_regularization(
        &self,
        data: &[(Array1<f64>, usize)],
        prototypes: &[Array1<f64>],
        params: &Array1<f64>,
    ) -> Result<Array1<f64>> {
        // Placeholder - would compute prototype-based regularization
        Ok(Array1::zeros(params.len()))
    }

    /// Evaluate with prototypical classification
    fn evaluate_with_prototypes(
        &self,
        data: &[(Array1<f64>, usize)],
        prototypes: &[Array1<f64>],
        params: &Array1<f64>,
    ) -> Result<(f64, f64)> {
        // Placeholder
        Ok((0.3, 0.7))
    }

    /// Compute gradient w.r.t. learning rates for Meta-SGD
    fn compute_lr_gradient(
        &self,
        task: &MetaTask,
        adapted_params: &Array1<f64>,
    ) -> Result<Array1<f64>> {
        // Placeholder
        Ok(Array1::zeros(self.meta_params.len()))
    }

    /// Sample batch of tasks
    fn sample_task_batch(&self, tasks: &[MetaTask], batch_size: usize) -> Vec<MetaTask> {
        let mut batch = Vec::new();
        let mut rng = fastrand::Rng::new();

        for _ in 0..batch_size.min(tasks.len()) {
            let idx = rng.usize(0..tasks.len());
            batch.push(tasks[idx].clone());
        }

        batch
    }

    /// Adapt to new task
    pub fn adapt_to_task(&mut self, task: &MetaTask) -> Result<Array1<f64>> {
        let adapted_params = match self.algorithm {
            MetaLearningAlgorithm::MAML {
                inner_steps,
                inner_lr,
                ..
            }
            | MetaLearningAlgorithm::Reptile {
                inner_steps,
                inner_lr,
            }
            | MetaLearningAlgorithm::ProtoMAML {
                inner_steps,
                inner_lr,
                ..
            }
            | MetaLearningAlgorithm::ANIL {
                inner_steps,
                inner_lr,
            } => {
                let mut params = self.meta_params.clone();
                for _ in 0..inner_steps {
                    let grad = self.compute_task_gradient(&task.train_data, &params)?;
                    params = params - inner_lr * &grad;
                }
                params
            }
            MetaLearningAlgorithm::MetaSGD { inner_steps } => {
                let mut params = self.meta_params.clone();
                let lr = self.per_param_lr.as_ref().unwrap();
                for _ in 0..inner_steps {
                    let grad = self.compute_task_gradient(&task.train_data, &params)?;
                    params = params - lr * &grad;
                }
                params
            }
        };

        Ok(adapted_params)
    }

    /// Get task embedding
    pub fn get_task_embedding(&self, task_id: &str) -> Option<&Array1<f64>> {
        self.task_embeddings.get(task_id)
    }

    /// Get meta parameters
    pub fn meta_params(&self) -> &Array1<f64> {
        &self.meta_params
    }

    /// Get per-parameter learning rates
    pub fn per_param_lr(&self) -> Option<&Array1<f64>> {
        self.per_param_lr.as_ref()
    }
}

/// Continual meta-learning with memory
pub struct ContinualMetaLearner {
    /// Base meta-learner
    meta_learner: QuantumMetaLearner,

    /// Memory buffer for past tasks
    memory_buffer: Vec<MetaTask>,

    /// Maximum memory size
    memory_capacity: usize,

    /// Replay ratio
    replay_ratio: f64,
}

impl ContinualMetaLearner {
    /// Create new continual meta-learner
    pub fn new(
        meta_learner: QuantumMetaLearner,
        memory_capacity: usize,
        replay_ratio: f64,
    ) -> Self {
        Self {
            meta_learner,
            memory_buffer: Vec::new(),
            memory_capacity,
            replay_ratio,
        }
    }

    /// Learn new task while preserving old knowledge
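    ///
    /// The new task is stored in the memory buffer (replacing a random entry
    /// once the buffer is full) and then meta-trained together with a replay
    /// batch of roughly `replay_ratio * buffer_len` tasks sampled uniformly
    /// from memory, which mitigates forgetting of earlier task distributions.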
    pub fn learn_task(&mut self, new_task: MetaTask) -> Result<()> {
        // Add to memory (replace a random entry once the buffer is full)
        if self.memory_buffer.len() < self.memory_capacity {
            self.memory_buffer.push(new_task.clone());
        } else {
            let idx = fastrand::usize(0..self.memory_buffer.len());
            self.memory_buffer[idx] = new_task.clone();
        }

        // Create mixed batch with replay
        let num_replay = (self.memory_buffer.len() as f64 * self.replay_ratio) as usize;
        let mut task_batch = vec![new_task];

        for _ in 0..num_replay {
            let idx = fastrand::usize(0..self.memory_buffer.len());
            task_batch.push(self.memory_buffer[idx].clone());
        }

        // Update meta-learner
        let mut dummy_optimizer = crate::autodiff::optimizers::Adam::new(0.001);
        self.meta_learner
            .meta_train(&task_batch, &mut dummy_optimizer, 10, task_batch.len())?;

        Ok(())
    }

    /// Get memory buffer length
    pub fn memory_buffer_len(&self) -> usize {
        self.memory_buffer.len()
    }
}

/// Task generator for meta-learning experiments
pub struct TaskGenerator {
    /// Feature dimension
    feature_dim: usize,

    /// Number of classes per task
    num_classes: usize,

    /// Task distribution parameters
    task_params: HashMap<String, f64>,
}

impl TaskGenerator {
    /// Create new task generator
    pub fn new(feature_dim: usize, num_classes: usize) -> Self {
        Self {
            feature_dim,
            num_classes,
            task_params: HashMap::new(),
        }
    }

    /// Generate sinusoid regression task
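    ///
    /// Samples one task from the standard sinusoid family `y = A * sin(x + phi)`
    /// with amplitude `A` drawn from [0.1, 5.0], phase `phi` from [0, 2*pi], and
    /// inputs `x` from [-5, 5]; labels are binarized (`y > 0`) so the task fits
    /// the classification interface, and samples are split evenly into support
    /// and query sets.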
    pub fn generate_sinusoid_task(&self, num_samples: usize) -> MetaTask {
        let amplitude = 0.1 + 4.9 * rand::random::<f64>();
        let phase = 2.0 * std::f64::consts::PI * rand::random::<f64>();

        let mut train_data = Vec::new();
        let mut test_data = Vec::new();

        // Generate samples
        for i in 0..num_samples {
            let x = -5.0 + 10.0 * rand::random::<f64>();
            let y = amplitude * (x + phase).sin();

            let input = Array1::from_vec(vec![x]);
            let label = if y > 0.0 { 1 } else { 0 }; // Binarize for classification

            if i < num_samples / 2 {
                train_data.push((input, label));
            } else {
                test_data.push((input, label));
            }
        }

        MetaTask {
            id: format!("sin_a{:.2}_p{:.2}", amplitude, phase),
            train_data,
            test_data,
            num_classes: 2,
            metadata: vec![
                ("amplitude".to_string(), amplitude),
                ("phase".to_string(), phase),
            ]
            .into_iter()
            .collect(),
        }
    }

    /// Generate classification task with rotated features
    pub fn generate_rotation_task(&self, num_samples: usize) -> MetaTask {
        let angle = 2.0 * std::f64::consts::PI * rand::random::<f64>();
        let cos_a = angle.cos();
        let sin_a = angle.sin();

        let mut train_data = Vec::new();
        let mut test_data = Vec::new();

        for i in 0..num_samples {
            // Generate base features
            let mut features = Array1::zeros(self.feature_dim);
            let label = i % self.num_classes;

            // Class-specific pattern
            for j in 0..self.feature_dim {
                features[j] = if j % self.num_classes == label {
                    1.0
                } else {
                    0.0
                };
                features[j] += 0.1 * rand::random::<f64>();
            }

            // Apply rotation (simplified for first 2 dims)
            if self.feature_dim >= 2 {
                let x = features[0];
                let y = features[1];
                features[0] = cos_a * x - sin_a * y;
                features[1] = sin_a * x + cos_a * y;
            }

            if i < num_samples / 2 {
                train_data.push((features, label));
            } else {
                test_data.push((features, label));
            }
        }

        MetaTask {
            id: format!("rot_{:.2}", angle),
            train_data,
            test_data,
            num_classes: self.num_classes,
            metadata: vec![("rotation_angle".to_string(), angle)]
                .into_iter()
                .collect(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::autodiff::optimizers::Adam;
    use crate::qnn::QNNLayerType;

    #[test]
    fn test_task_generator() {
        let generator = TaskGenerator::new(4, 2);

        let sin_task = generator.generate_sinusoid_task(20);
        assert_eq!(sin_task.train_data.len(), 10);
        assert_eq!(sin_task.test_data.len(), 10);

        let rot_task = generator.generate_rotation_task(30);
        assert_eq!(rot_task.train_data.len(), 15);
        assert_eq!(rot_task.test_data.len(), 15);
    }

    #[test]
    fn test_meta_learner_creation() {
        let layers = vec![
            QNNLayerType::EncodingLayer { num_features: 4 },
            QNNLayerType::VariationalLayer { num_params: 8 },
            QNNLayerType::MeasurementLayer {
                measurement_basis: "computational".to_string(),
            },
        ];

        let qnn = QuantumNeuralNetwork::new(layers, 4, 4, 2).unwrap();

        let maml_algo = MetaLearningAlgorithm::MAML {
            inner_steps: 5,
            inner_lr: 0.01,
            first_order: true,
        };

        let meta_learner = QuantumMetaLearner::new(maml_algo, qnn);
        assert!(meta_learner.per_param_lr.is_none());

        // Test Meta-SGD
        let layers2 = vec![
            QNNLayerType::EncodingLayer { num_features: 4 },
            QNNLayerType::VariationalLayer { num_params: 8 },
        ];
        let qnn2 = QuantumNeuralNetwork::new(layers2, 4, 4, 2).unwrap();

        let metasgd_algo = MetaLearningAlgorithm::MetaSGD { inner_steps: 3 };
        let meta_sgd = QuantumMetaLearner::new(metasgd_algo, qnn2);
        assert!(meta_sgd.per_param_lr.is_some());
    }

    #[test]
    fn test_task_adaptation() {
        let layers = vec![
            QNNLayerType::EncodingLayer { num_features: 2 },
            QNNLayerType::VariationalLayer { num_params: 6 },
        ];

        let qnn = QuantumNeuralNetwork::new(layers, 4, 2, 2).unwrap();
        let algo = MetaLearningAlgorithm::Reptile {
            inner_steps: 5,
            inner_lr: 0.01,
        };

        let mut meta_learner = QuantumMetaLearner::new(algo, qnn);

        // Generate task
        let generator = TaskGenerator::new(2, 2);
        let task = generator.generate_rotation_task(20);

        // Adapt to task
        let adapted_params = meta_learner.adapt_to_task(&task).unwrap();
        assert_eq!(adapted_params.len(), meta_learner.meta_params.len());
    }
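
    // Additional test sketch (relies only on the placeholder implementations
    // above): exercises ContinualMetaLearner's replay-buffer bookkeeping.
    #[test]
    fn test_continual_meta_learner_memory() {
        let layers = vec![
            QNNLayerType::EncodingLayer { num_features: 2 },
            QNNLayerType::VariationalLayer { num_params: 6 },
        ];
        let qnn = QuantumNeuralNetwork::new(layers, 4, 2, 2).unwrap();
        let algo = MetaLearningAlgorithm::Reptile {
            inner_steps: 2,
            inner_lr: 0.01,
        };
        let meta_learner = QuantumMetaLearner::new(algo, qnn);

        let mut continual = ContinualMetaLearner::new(meta_learner, 5, 0.5);
        assert_eq!(continual.memory_buffer_len(), 0);

        let generator = TaskGenerator::new(2, 2);
        continual
            .learn_task(generator.generate_rotation_task(10))
            .unwrap();
        assert_eq!(continual.memory_buffer_len(), 1);
    }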
}