quantrs2_ml/
meta_learning.rs

//! Quantum Meta-Learning Algorithms
//!
//! This module implements various meta-learning algorithms adapted for quantum circuits,
//! enabling quantum models to learn how to learn from limited data across multiple tasks.
5
6use crate::autodiff::optimizers::Optimizer;
7use crate::error::{MLError, Result};
8use crate::optimization::OptimizationMethod;
9use crate::qnn::{QNNLayerType, QuantumNeuralNetwork};
10use quantrs2_circuit::builder::{Circuit, Simulator};
11use quantrs2_core::gate::{
12    single::{RotationX, RotationY, RotationZ},
13    GateOp,
14};
15use quantrs2_sim::statevector::StateVectorSimulator;
16use scirs2_core::ndarray::{s, Array1, Array2, Array3, Axis};
17use scirs2_core::random::prelude::*;
18use std::collections::HashMap;
19
/// The meta-learning strategy run by `QuantumMetaLearner`, together with the
/// hyper-parameters of its inner (per-task adaptation) loop.
#[derive(Debug, Clone, Copy)]
pub enum MetaLearningAlgorithm {
    /// Model-Agnostic Meta-Learning.
    ///
    /// Each task is adapted with `inner_steps` gradient steps of size
    /// `inner_lr`; `first_order` switches the meta-gradient to the
    /// first-order approximation (FO-MAML).
    MAML {
        inner_steps: usize,
        inner_lr: f64,
        first_order: bool,
    },

    /// Reptile: after adapting to a task, the meta-parameters are moved part
    /// of the way towards the task-adapted parameters.
    Reptile {
        inner_steps: usize,
        inner_lr: f64,
    },

    /// Prototypical MAML: the inner loop adds a prototype-based
    /// regularisation term, scaled by `proto_weight`, to the task gradient.
    ProtoMAML {
        inner_steps: usize,
        inner_lr: f64,
        proto_weight: f64,
    },

    /// Meta-SGD: the inner loop uses one learnable learning rate per
    /// parameter instead of a single scalar step size.
    MetaSGD {
        inner_steps: usize,
    },

    /// Almost No Inner Loop: only the trailing parameters (treated as the
    /// final layer) are adapted in the inner loop.
    ANIL {
        inner_steps: usize,
        inner_lr: f64,
    },
}
46
47/// Task definition for meta-learning
48#[derive(Debug, Clone)]
49pub struct MetaTask {
50    /// Task identifier
51    pub id: String,
52
53    /// Training data (support set)
54    pub train_data: Vec<(Array1<f64>, usize)>,
55
56    /// Test data (query set)
57    pub test_data: Vec<(Array1<f64>, usize)>,
58
59    /// Number of classes
60    pub num_classes: usize,
61
62    /// Task-specific metadata
63    pub metadata: HashMap<String, f64>,
64}
65
66/// Base quantum meta-learner
67pub struct QuantumMetaLearner {
68    /// Meta-learning algorithm
69    algorithm: MetaLearningAlgorithm,
70
71    /// Base quantum model
72    model: QuantumNeuralNetwork,
73
74    /// Meta-parameters
75    meta_params: Array1<f64>,
76
77    /// Per-parameter learning rates (for Meta-SGD)
78    per_param_lr: Option<Array1<f64>>,
79
80    /// Task embeddings
81    task_embeddings: HashMap<String, Array1<f64>>,
82
83    /// Training history
84    history: MetaLearningHistory,
85}
86
/// Training history recorded by `QuantumMetaLearner::meta_train`.
///
/// `meta_train` appends one entry per meta-epoch to both `meta_train_losses`
/// and `meta_val_accuracies`. An empty history is available via
/// `MetaLearningHistory::default()`.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct MetaLearningHistory {
    /// Batch-mean query-set loss, one value per meta-epoch.
    pub meta_train_losses: Vec<f64>,

    /// Batch-mean query-set accuracy, one value per meta-epoch.
    ///
    /// NOTE(review): despite the name, these are measured on the query sets
    /// of the meta-training tasks, not on a separate validation split.
    pub meta_val_accuracies: Vec<f64>,

    /// Per-task performance series, presumably keyed by task id (not
    /// populated anywhere in this module yet).
    pub task_performance: HashMap<String, Vec<f64>>,
}
99
100impl QuantumMetaLearner {
101    /// Create a new quantum meta-learner
102    pub fn new(algorithm: MetaLearningAlgorithm, model: QuantumNeuralNetwork) -> Self {
103        let num_params = model.parameters.len();
104        let meta_params = model.parameters.clone();
105
106        let per_param_lr = match algorithm {
107            MetaLearningAlgorithm::MetaSGD { .. } => Some(Array1::from_elem(num_params, 0.01)),
108            _ => None,
109        };
110
111        Self {
112            algorithm,
113            model,
114            meta_params,
115            per_param_lr,
116            task_embeddings: HashMap::new(),
117            history: MetaLearningHistory {
118                meta_train_losses: Vec::new(),
119                meta_val_accuracies: Vec::new(),
120                task_performance: HashMap::new(),
121            },
122        }
123    }
124
125    /// Meta-train on multiple tasks
126    pub fn meta_train(
127        &mut self,
128        tasks: &[MetaTask],
129        meta_optimizer: &mut dyn Optimizer,
130        meta_epochs: usize,
131        tasks_per_batch: usize,
132    ) -> Result<()> {
133        println!("Starting meta-training with {} tasks...", tasks.len());
134
135        for epoch in 0..meta_epochs {
136            let mut epoch_loss = 0.0;
137            let mut epoch_acc = 0.0;
138
139            // Sample batch of tasks
140            let task_batch = self.sample_task_batch(tasks, tasks_per_batch);
141
142            // Perform meta-update based on algorithm
143            match self.algorithm {
144                MetaLearningAlgorithm::MAML { .. } => {
145                    let (loss, acc) = self.maml_update(&task_batch, meta_optimizer)?;
146                    epoch_loss += loss;
147                    epoch_acc += acc;
148                }
149                MetaLearningAlgorithm::Reptile { .. } => {
150                    let (loss, acc) = self.reptile_update(&task_batch, meta_optimizer)?;
151                    epoch_loss += loss;
152                    epoch_acc += acc;
153                }
154                MetaLearningAlgorithm::ProtoMAML { .. } => {
155                    let (loss, acc) = self.protomaml_update(&task_batch, meta_optimizer)?;
156                    epoch_loss += loss;
157                    epoch_acc += acc;
158                }
159                MetaLearningAlgorithm::MetaSGD { .. } => {
160                    let (loss, acc) = self.metasgd_update(&task_batch, meta_optimizer)?;
161                    epoch_loss += loss;
162                    epoch_acc += acc;
163                }
164                MetaLearningAlgorithm::ANIL { .. } => {
165                    let (loss, acc) = self.anil_update(&task_batch, meta_optimizer)?;
166                    epoch_loss += loss;
167                    epoch_acc += acc;
168                }
169            }
170
171            // Update history
172            self.history.meta_train_losses.push(epoch_loss);
173            self.history.meta_val_accuracies.push(epoch_acc);
174
175            if epoch % 10 == 0 {
176                println!(
177                    "Epoch {}: Loss = {:.4}, Accuracy = {:.2}%",
178                    epoch,
179                    epoch_loss,
180                    epoch_acc * 100.0
181                );
182            }
183        }
184
185        Ok(())
186    }
187
188    /// MAML update step
189    fn maml_update(
190        &mut self,
191        tasks: &[MetaTask],
192        optimizer: &mut dyn Optimizer,
193    ) -> Result<(f64, f64)> {
194        let (inner_steps, inner_lr, first_order) = match self.algorithm {
195            MetaLearningAlgorithm::MAML {
196                inner_steps,
197                inner_lr,
198                first_order,
199            } => (inner_steps, inner_lr, first_order),
200            _ => unreachable!(),
201        };
202
203        let mut total_loss = 0.0;
204        let mut total_acc = 0.0;
205        let mut meta_gradients = Array1::zeros(self.meta_params.len());
206
207        for task in tasks {
208            // Clone meta-parameters for inner loop
209            let mut task_params = self.meta_params.clone();
210
211            // Inner loop: adapt to task
212            for _ in 0..inner_steps {
213                let grad = self.compute_task_gradient(&task.train_data, &task_params)?;
214                task_params = task_params - inner_lr * &grad;
215            }
216
217            // Compute loss on query set with adapted parameters
218            let (query_loss, query_acc) = self.evaluate_task(&task.test_data, &task_params)?;
219            total_loss += query_loss;
220            total_acc += query_acc;
221
222            // Compute meta-gradient
223            if !first_order {
224                // Full second-order MAML gradient
225                let meta_grad = self.compute_maml_gradient(task, &task_params, inner_lr)?;
226                meta_gradients = meta_gradients + meta_grad;
227            } else {
228                // First-order approximation (FO-MAML)
229                let grad = self.compute_task_gradient(&task.test_data, &task_params)?;
230                meta_gradients = meta_gradients + grad;
231            }
232        }
233
234        // Average gradients and update meta-parameters
235        meta_gradients = meta_gradients / tasks.len() as f64;
236        self.meta_params = self.meta_params.clone() - 0.001 * &meta_gradients; // Meta learning rate
237
238        Ok((
239            total_loss / tasks.len() as f64,
240            total_acc / tasks.len() as f64,
241        ))
242    }
243
244    /// Reptile update step
245    fn reptile_update(
246        &mut self,
247        tasks: &[MetaTask],
248        optimizer: &mut dyn Optimizer,
249    ) -> Result<(f64, f64)> {
250        let (inner_steps, inner_lr) = match self.algorithm {
251            MetaLearningAlgorithm::Reptile {
252                inner_steps,
253                inner_lr,
254            } => (inner_steps, inner_lr),
255            _ => unreachable!(),
256        };
257
258        let mut total_loss = 0.0;
259        let mut total_acc = 0.0;
260        let epsilon = 0.1; // Reptile step size
261
262        for task in tasks {
263            // Clone meta-parameters
264            let mut task_params = self.meta_params.clone();
265
266            // Perform multiple SGD steps on task
267            for _ in 0..inner_steps {
268                let grad = self.compute_task_gradient(&task.train_data, &task_params)?;
269                task_params = task_params - inner_lr * &grad;
270            }
271
272            // Evaluate adapted model
273            let (loss, acc) = self.evaluate_task(&task.test_data, &task_params)?;
274            total_loss += loss;
275            total_acc += acc;
276
277            // Reptile update: move meta-params toward task-adapted params
278            let direction = &task_params - &self.meta_params;
279            self.meta_params = &self.meta_params + epsilon * &direction;
280        }
281
282        Ok((
283            total_loss / tasks.len() as f64,
284            total_acc / tasks.len() as f64,
285        ))
286    }
287
288    /// ProtoMAML update step
289    fn protomaml_update(
290        &mut self,
291        tasks: &[MetaTask],
292        optimizer: &mut dyn Optimizer,
293    ) -> Result<(f64, f64)> {
294        let (inner_steps, inner_lr, proto_weight) = match self.algorithm {
295            MetaLearningAlgorithm::ProtoMAML {
296                inner_steps,
297                inner_lr,
298                proto_weight,
299            } => (inner_steps, inner_lr, proto_weight),
300            _ => unreachable!(),
301        };
302
303        let mut total_loss = 0.0;
304        let mut total_acc = 0.0;
305
306        for task in tasks {
307            // Compute prototypes for each class
308            let prototypes = self.compute_prototypes(&task.train_data, task.num_classes)?;
309
310            // Clone parameters for adaptation
311            let mut task_params = self.meta_params.clone();
312
313            // Inner loop with prototype regularization
314            for _ in 0..inner_steps {
315                let grad = self.compute_task_gradient(&task.train_data, &task_params)?;
316                let proto_reg =
317                    self.prototype_regularization(&task.train_data, &prototypes, &task_params)?;
318                task_params = task_params - inner_lr * (&grad + proto_weight * &proto_reg);
319            }
320
321            // Evaluate with prototypical classification
322            let (loss, acc) =
323                self.evaluate_with_prototypes(&task.test_data, &prototypes, &task_params)?;
324            total_loss += loss;
325            total_acc += acc;
326        }
327
328        Ok((
329            total_loss / tasks.len() as f64,
330            total_acc / tasks.len() as f64,
331        ))
332    }
333
334    /// Meta-SGD update step
335    fn metasgd_update(
336        &mut self,
337        tasks: &[MetaTask],
338        optimizer: &mut dyn Optimizer,
339    ) -> Result<(f64, f64)> {
340        let inner_steps = match self.algorithm {
341            MetaLearningAlgorithm::MetaSGD { inner_steps } => inner_steps,
342            _ => unreachable!(),
343        };
344
345        let mut total_loss = 0.0;
346        let mut total_acc = 0.0;
347        let mut meta_lr_gradients = Array1::zeros(self.meta_params.len());
348
349        for task in tasks {
350            let mut task_params = self.meta_params.clone();
351
352            // Inner loop with learnable per-parameter learning rates
353            for _ in 0..inner_steps {
354                let grad = self.compute_task_gradient(&task.train_data, &task_params)?;
355                let lr = self.per_param_lr.as_ref().unwrap();
356                task_params = task_params - lr * &grad;
357            }
358
359            // Evaluate and compute gradients w.r.t. both parameters and learning rates
360            let (loss, acc) = self.evaluate_task(&task.test_data, &task_params)?;
361            total_loss += loss;
362            total_acc += acc;
363
364            // Compute gradient w.r.t. learning rates
365            let lr_grad = self.compute_lr_gradient(task, &task_params)?;
366            meta_lr_gradients = meta_lr_gradients + lr_grad;
367        }
368
369        // Update both parameters and learning rates
370        if let Some(ref mut lr) = self.per_param_lr {
371            *lr = lr.clone() - &(0.001 * &meta_lr_gradients / tasks.len() as f64);
372        }
373
374        Ok((
375            total_loss / tasks.len() as f64,
376            total_acc / tasks.len() as f64,
377        ))
378    }
379
380    /// ANIL update step
381    fn anil_update(
382        &mut self,
383        tasks: &[MetaTask],
384        optimizer: &mut dyn Optimizer,
385    ) -> Result<(f64, f64)> {
386        let (inner_steps, inner_lr) = match self.algorithm {
387            MetaLearningAlgorithm::ANIL {
388                inner_steps,
389                inner_lr,
390            } => (inner_steps, inner_lr),
391            _ => unreachable!(),
392        };
393
394        // ANIL: Only adapt the final layer(s) in inner loop
395        let num_params = self.meta_params.len();
396        let final_layer_start = (num_params * 3) / 4; // Last 25% of parameters
397
398        let mut total_loss = 0.0;
399        let mut total_acc = 0.0;
400
401        for task in tasks {
402            let mut task_params = self.meta_params.clone();
403
404            // Inner loop: only update final layer parameters
405            for _ in 0..inner_steps {
406                let grad = self.compute_task_gradient(&task.train_data, &task_params)?;
407
408                // Only update final layer
409                for i in final_layer_start..num_params {
410                    task_params[i] -= inner_lr * grad[i];
411                }
412            }
413
414            let (loss, acc) = self.evaluate_task(&task.test_data, &task_params)?;
415            total_loss += loss;
416            total_acc += acc;
417        }
418
419        Ok((
420            total_loss / tasks.len() as f64,
421            total_acc / tasks.len() as f64,
422        ))
423    }
424
425    /// Compute gradient for a task
426    fn compute_task_gradient(
427        &self,
428        data: &[(Array1<f64>, usize)],
429        params: &Array1<f64>,
430    ) -> Result<Array1<f64>> {
431        // Placeholder - would compute actual quantum gradients
432        Ok(Array1::zeros(params.len()))
433    }
434
435    /// Evaluate model on task data
436    fn evaluate_task(
437        &self,
438        data: &[(Array1<f64>, usize)],
439        params: &Array1<f64>,
440    ) -> Result<(f64, f64)> {
441        // Placeholder - would evaluate quantum model
442        let loss = 0.5 + 0.5 * thread_rng().gen::<f64>();
443        let acc = 0.5 + 0.3 * thread_rng().gen::<f64>();
444        Ok((loss, acc))
445    }
446
447    /// Compute MAML gradient (second-order)
448    fn compute_maml_gradient(
449        &self,
450        task: &MetaTask,
451        adapted_params: &Array1<f64>,
452        inner_lr: f64,
453    ) -> Result<Array1<f64>> {
454        // Placeholder - would compute Hessian-vector products
455        Ok(Array1::zeros(self.meta_params.len()))
456    }
457
458    /// Compute prototypes for ProtoMAML
459    fn compute_prototypes(
460        &self,
461        data: &[(Array1<f64>, usize)],
462        num_classes: usize,
463    ) -> Result<Vec<Array1<f64>>> {
464        let feature_dim = 16; // Placeholder dimension
465        let mut prototypes = vec![Array1::zeros(feature_dim); num_classes];
466        let mut counts = vec![0; num_classes];
467
468        // Placeholder - would encode data and compute class means
469        for (x, label) in data {
470            counts[*label] += 1;
471        }
472
473        Ok(prototypes)
474    }
475
476    /// Prototype regularization
477    fn prototype_regularization(
478        &self,
479        data: &[(Array1<f64>, usize)],
480        prototypes: &[Array1<f64>],
481        params: &Array1<f64>,
482    ) -> Result<Array1<f64>> {
483        // Placeholder - would compute prototype-based regularization
484        Ok(Array1::zeros(params.len()))
485    }
486
487    /// Evaluate with prototypical classification
488    fn evaluate_with_prototypes(
489        &self,
490        data: &[(Array1<f64>, usize)],
491        prototypes: &[Array1<f64>],
492        params: &Array1<f64>,
493    ) -> Result<(f64, f64)> {
494        // Placeholder
495        Ok((0.3, 0.7))
496    }
497
498    /// Compute gradient w.r.t. learning rates for Meta-SGD
499    fn compute_lr_gradient(
500        &self,
501        task: &MetaTask,
502        adapted_params: &Array1<f64>,
503    ) -> Result<Array1<f64>> {
504        // Placeholder
505        Ok(Array1::zeros(self.meta_params.len()))
506    }
507
508    /// Sample batch of tasks
509    fn sample_task_batch(&self, tasks: &[MetaTask], batch_size: usize) -> Vec<MetaTask> {
510        let mut batch = Vec::new();
511        let mut rng = thread_rng();
512
513        for _ in 0..batch_size.min(tasks.len()) {
514            let idx = rng.gen_range(0..tasks.len());
515            batch.push(tasks[idx].clone());
516        }
517
518        batch
519    }
520
521    /// Adapt to new task
522    pub fn adapt_to_task(&mut self, task: &MetaTask) -> Result<Array1<f64>> {
523        let adapted_params = match self.algorithm {
524            MetaLearningAlgorithm::MAML {
525                inner_steps,
526                inner_lr,
527                ..
528            }
529            | MetaLearningAlgorithm::Reptile {
530                inner_steps,
531                inner_lr,
532            }
533            | MetaLearningAlgorithm::ProtoMAML {
534                inner_steps,
535                inner_lr,
536                ..
537            }
538            | MetaLearningAlgorithm::ANIL {
539                inner_steps,
540                inner_lr,
541            } => {
542                let mut params = self.meta_params.clone();
543                for _ in 0..inner_steps {
544                    let grad = self.compute_task_gradient(&task.train_data, &params)?;
545                    params = params - inner_lr * &grad;
546                }
547                params
548            }
549            MetaLearningAlgorithm::MetaSGD { inner_steps } => {
550                let mut params = self.meta_params.clone();
551                let lr = self.per_param_lr.as_ref().unwrap();
552                for _ in 0..inner_steps {
553                    let grad = self.compute_task_gradient(&task.train_data, &params)?;
554                    params = params - lr * &grad;
555                }
556                params
557            }
558        };
559
560        Ok(adapted_params)
561    }
562
563    /// Get task embedding
564    pub fn get_task_embedding(&self, task_id: &str) -> Option<&Array1<f64>> {
565        self.task_embeddings.get(task_id)
566    }
567
568    /// Get meta parameters
569    pub fn meta_params(&self) -> &Array1<f64> {
570        &self.meta_params
571    }
572
573    /// Get per-parameter learning rates
574    pub fn per_param_lr(&self) -> Option<&Array1<f64>> {
575        self.per_param_lr.as_ref()
576    }
577}
578
579/// Continual meta-learning with memory
580pub struct ContinualMetaLearner {
581    /// Base meta-learner
582    meta_learner: QuantumMetaLearner,
583
584    /// Memory buffer for past tasks
585    memory_buffer: Vec<MetaTask>,
586
587    /// Maximum memory size
588    memory_capacity: usize,
589
590    /// Replay ratio
591    replay_ratio: f64,
592}
593
594impl ContinualMetaLearner {
595    /// Create new continual meta-learner
596    pub fn new(
597        meta_learner: QuantumMetaLearner,
598        memory_capacity: usize,
599        replay_ratio: f64,
600    ) -> Self {
601        Self {
602            meta_learner,
603            memory_buffer: Vec::new(),
604            memory_capacity,
605            replay_ratio,
606        }
607    }
608
609    /// Learn new task while preserving old knowledge
610    pub fn learn_task(&mut self, new_task: MetaTask) -> Result<()> {
611        // Add to memory with reservoir sampling
612        if self.memory_buffer.len() < self.memory_capacity {
613            self.memory_buffer.push(new_task.clone());
614        } else {
615            let idx = fastrand::usize(0..self.memory_buffer.len());
616            self.memory_buffer[idx] = new_task.clone();
617        }
618
619        // Create mixed batch with replay
620        let num_replay = (self.memory_buffer.len() as f64 * self.replay_ratio) as usize;
621        let mut task_batch = vec![new_task];
622
623        for _ in 0..num_replay {
624            let idx = fastrand::usize(0..self.memory_buffer.len());
625            task_batch.push(self.memory_buffer[idx].clone());
626        }
627
628        // Update meta-learner
629        let mut dummy_optimizer = crate::autodiff::optimizers::Adam::new(0.001);
630        self.meta_learner
631            .meta_train(&task_batch, &mut dummy_optimizer, 10, task_batch.len())?;
632
633        Ok(())
634    }
635
636    /// Get memory buffer length
637    pub fn memory_buffer_len(&self) -> usize {
638        self.memory_buffer.len()
639    }
640}
641
/// Random task factory for meta-learning experiments.
pub struct TaskGenerator {
    /// Length of the feature vectors produced by rotation tasks.
    feature_dim: usize,

    /// Number of classes used by rotation tasks.
    num_classes: usize,

    /// Extra distribution parameters (not read by the generators in this
    /// module).
    task_params: HashMap<String, f64>,
}
653
654impl TaskGenerator {
655    /// Create new task generator
656    pub fn new(feature_dim: usize, num_classes: usize) -> Self {
657        Self {
658            feature_dim,
659            num_classes,
660            task_params: HashMap::new(),
661        }
662    }
663
664    /// Generate sinusoid regression task
665    pub fn generate_sinusoid_task(&self, num_samples: usize) -> MetaTask {
666        let amplitude = 0.1 + 4.9 * thread_rng().gen::<f64>();
667        let phase = 2.0 * std::f64::consts::PI * thread_rng().gen::<f64>();
668
669        let mut train_data = Vec::new();
670        let mut test_data = Vec::new();
671
672        // Generate samples
673        for i in 0..num_samples {
674            let x = -5.0 + 10.0 * thread_rng().gen::<f64>();
675            let y = amplitude * (x + phase).sin();
676
677            let input = Array1::from_vec(vec![x]);
678            let label = if y > 0.0 { 1 } else { 0 }; // Binarize for classification
679
680            if i < num_samples / 2 {
681                train_data.push((input, label));
682            } else {
683                test_data.push((input, label));
684            }
685        }
686
687        MetaTask {
688            id: format!("sin_a{:.2}_p{:.2}", amplitude, phase),
689            train_data,
690            test_data,
691            num_classes: 2,
692            metadata: vec![
693                ("amplitude".to_string(), amplitude),
694                ("phase".to_string(), phase),
695            ]
696            .into_iter()
697            .collect(),
698        }
699    }
700
701    /// Generate classification task with rotated features
702    pub fn generate_rotation_task(&self, num_samples: usize) -> MetaTask {
703        let angle = 2.0 * std::f64::consts::PI * thread_rng().gen::<f64>();
704        let cos_a = angle.cos();
705        let sin_a = angle.sin();
706
707        let mut train_data = Vec::new();
708        let mut test_data = Vec::new();
709
710        for i in 0..num_samples {
711            // Generate base features
712            let mut features = Array1::zeros(self.feature_dim);
713            let label = i % self.num_classes;
714
715            // Class-specific pattern
716            for j in 0..self.feature_dim {
717                features[j] = if j % self.num_classes == label {
718                    1.0
719                } else {
720                    0.0
721                };
722                features[j] += 0.1 * thread_rng().gen::<f64>();
723            }
724
725            // Apply rotation (simplified for first 2 dims)
726            if self.feature_dim >= 2 {
727                let x = features[0];
728                let y = features[1];
729                features[0] = cos_a * x - sin_a * y;
730                features[1] = sin_a * x + cos_a * y;
731            }
732
733            if i < num_samples / 2 {
734                train_data.push((features, label));
735            } else {
736                test_data.push((features, label));
737            }
738        }
739
740        MetaTask {
741            id: format!("rot_{:.2}", angle),
742            train_data,
743            test_data,
744            num_classes: self.num_classes,
745            metadata: vec![("rotation_angle".to_string(), angle)]
746                .into_iter()
747                .collect(),
748        }
749    }
750}
751
#[cfg(test)]
mod tests {
    use super::*;
    use crate::qnn::QNNLayerType;

    #[test]
    fn test_task_generator() {
        let task_gen = TaskGenerator::new(4, 2);

        // 20 sinusoid samples split evenly into support and query sets.
        let sinusoid = task_gen.generate_sinusoid_task(20);
        assert_eq!(sinusoid.train_data.len(), 10);
        assert_eq!(sinusoid.test_data.len(), 10);

        // 30 rotation samples split the same way.
        let rotation = task_gen.generate_rotation_task(30);
        assert_eq!(rotation.train_data.len(), 15);
        assert_eq!(rotation.test_data.len(), 15);
    }

    #[test]
    fn test_meta_learner_creation() {
        // MAML does not allocate per-parameter learning rates.
        let maml_qnn = QuantumNeuralNetwork::new(
            vec![
                QNNLayerType::EncodingLayer { num_features: 4 },
                QNNLayerType::VariationalLayer { num_params: 8 },
                QNNLayerType::MeasurementLayer {
                    measurement_basis: "computational".to_string(),
                },
            ],
            4,
            4,
            2,
        )
        .unwrap();
        let maml = QuantumMetaLearner::new(
            MetaLearningAlgorithm::MAML {
                inner_steps: 5,
                inner_lr: 0.01,
                first_order: true,
            },
            maml_qnn,
        );
        assert!(maml.per_param_lr.is_none());

        // Meta-SGD does.
        let sgd_qnn = QuantumNeuralNetwork::new(
            vec![
                QNNLayerType::EncodingLayer { num_features: 4 },
                QNNLayerType::VariationalLayer { num_params: 8 },
            ],
            4,
            4,
            2,
        )
        .unwrap();
        let meta_sgd =
            QuantumMetaLearner::new(MetaLearningAlgorithm::MetaSGD { inner_steps: 3 }, sgd_qnn);
        assert!(meta_sgd.per_param_lr.is_some());
    }

    #[test]
    fn test_task_adaptation() {
        let qnn = QuantumNeuralNetwork::new(
            vec![
                QNNLayerType::EncodingLayer { num_features: 2 },
                QNNLayerType::VariationalLayer { num_params: 6 },
            ],
            4,
            2,
            2,
        )
        .unwrap();
        let mut learner = QuantumMetaLearner::new(
            MetaLearningAlgorithm::Reptile {
                inner_steps: 5,
                inner_lr: 0.01,
            },
            qnn,
        );

        let task = TaskGenerator::new(2, 2).generate_rotation_task(20);

        // Adaptation must yield one value per meta-parameter.
        let adapted = learner.adapt_to_task(&task).unwrap();
        assert_eq!(adapted.len(), learner.meta_params.len());
    }
}
827}