
//! Quantum Meta-Learning Algorithms
//!
//! This module implements various meta-learning algorithms adapted for quantum circuits,
//! enabling quantum models to learn how to learn from limited data across multiple tasks.

use crate::autodiff::optimizers::Optimizer;
use crate::error::{MLError, Result};
use crate::optimization::OptimizationMethod;
use crate::qnn::{QNNLayerType, QuantumNeuralNetwork};
use quantrs2_circuit::builder::{Circuit, Simulator};
use quantrs2_core::gate::{
    single::{RotationX, RotationY, RotationZ},
    GateOp,
};
use quantrs2_sim::statevector::StateVectorSimulator;
use scirs2_core::ndarray::{s, Array1, Array2, Array3, Axis};
use scirs2_core::random::prelude::*;
use std::collections::HashMap;

/// Meta-learning algorithm types
#[derive(Debug, Clone, Copy)]
pub enum MetaLearningAlgorithm {
    /// Model-Agnostic Meta-Learning
    MAML {
        inner_steps: usize,
        inner_lr: f64,
        first_order: bool,
    },

    /// Reptile algorithm
    Reptile { inner_steps: usize, inner_lr: f64 },

    /// Prototypical MAML
    ProtoMAML {
        inner_steps: usize,
        inner_lr: f64,
        proto_weight: f64,
    },

    /// Meta-SGD with learnable inner learning rates
    MetaSGD { inner_steps: usize },

    /// Almost No Inner Loop (ANIL)
    ANIL { inner_steps: usize, inner_lr: f64 },
}

/// Task definition for meta-learning
#[derive(Debug, Clone)]
pub struct MetaTask {
    /// Task identifier
    pub id: String,

    /// Training data (support set)
    pub train_data: Vec<(Array1<f64>, usize)>,

    /// Test data (query set)
    pub test_data: Vec<(Array1<f64>, usize)>,

    /// Number of classes
    pub num_classes: usize,

    /// Task-specific metadata
    pub metadata: HashMap<String, f64>,
}
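
// Constructing a task by hand looks like the following (a minimal sketch; the
// data values are purely illustrative):
//
//     let task = MetaTask {
//         id: "toy".to_string(),
//         train_data: vec![(Array1::from_vec(vec![0.1, 0.9]), 0)],
//         test_data: vec![(Array1::from_vec(vec![0.8, 0.2]), 1)],
//         num_classes: 2,
//         metadata: HashMap::new(),
//     };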

/// Base quantum meta-learner
pub struct QuantumMetaLearner {
    /// Meta-learning algorithm
    algorithm: MetaLearningAlgorithm,

    /// Base quantum model
    model: QuantumNeuralNetwork,

    /// Meta-parameters
    meta_params: Array1<f64>,

    /// Per-parameter learning rates (for Meta-SGD)
    per_param_lr: Option<Array1<f64>>,

    /// Task embeddings
    task_embeddings: HashMap<String, Array1<f64>>,

    /// Training history
    history: MetaLearningHistory,
}

/// Training history for meta-learning
#[derive(Debug, Clone)]
pub struct MetaLearningHistory {
    /// Meta-train losses
    pub meta_train_losses: Vec<f64>,

    /// Meta-validation accuracies
    pub meta_val_accuracies: Vec<f64>,

    /// Per-task performance
    pub task_performance: HashMap<String, Vec<f64>>,
}

impl QuantumMetaLearner {
    /// Create a new quantum meta-learner
    pub fn new(algorithm: MetaLearningAlgorithm, model: QuantumNeuralNetwork) -> Self {
        let num_params = model.parameters.len();
        let meta_params = model.parameters.clone();

        let per_param_lr = match algorithm {
            MetaLearningAlgorithm::MetaSGD { .. } => Some(Array1::from_elem(num_params, 0.01)),
            _ => None,
        };

        Self {
            algorithm,
            model,
            meta_params,
            per_param_lr,
            task_embeddings: HashMap::new(),
            history: MetaLearningHistory {
                meta_train_losses: Vec::new(),
                meta_val_accuracies: Vec::new(),
                task_performance: HashMap::new(),
            },
        }
    }

    /// Meta-train on multiple tasks
    pub fn meta_train(
        &mut self,
        tasks: &[MetaTask],
        meta_optimizer: &mut dyn Optimizer,
        meta_epochs: usize,
        tasks_per_batch: usize,
    ) -> Result<()> {
        println!("Starting meta-training with {} tasks...", tasks.len());

        for epoch in 0..meta_epochs {
            let mut epoch_loss = 0.0;
            let mut epoch_acc = 0.0;

            // Sample batch of tasks
            let task_batch = self.sample_task_batch(tasks, tasks_per_batch);

            // Perform meta-update based on algorithm
            match self.algorithm {
                MetaLearningAlgorithm::MAML { .. } => {
                    let (loss, acc) = self.maml_update(&task_batch, meta_optimizer)?;
                    epoch_loss += loss;
                    epoch_acc += acc;
                }
                MetaLearningAlgorithm::Reptile { .. } => {
                    let (loss, acc) = self.reptile_update(&task_batch, meta_optimizer)?;
                    epoch_loss += loss;
                    epoch_acc += acc;
                }
                MetaLearningAlgorithm::ProtoMAML { .. } => {
                    let (loss, acc) = self.protomaml_update(&task_batch, meta_optimizer)?;
                    epoch_loss += loss;
                    epoch_acc += acc;
                }
                MetaLearningAlgorithm::MetaSGD { .. } => {
                    let (loss, acc) = self.metasgd_update(&task_batch, meta_optimizer)?;
                    epoch_loss += loss;
                    epoch_acc += acc;
                }
                MetaLearningAlgorithm::ANIL { .. } => {
                    let (loss, acc) = self.anil_update(&task_batch, meta_optimizer)?;
                    epoch_loss += loss;
                    epoch_acc += acc;
                }
            }

            // Update history
            self.history.meta_train_losses.push(epoch_loss);
            self.history.meta_val_accuracies.push(epoch_acc);

            if epoch % 10 == 0 {
                println!(
                    "Epoch {}: Loss = {:.4}, Accuracy = {:.2}%",
                    epoch,
                    epoch_loss,
                    epoch_acc * 100.0
                );
            }
        }

        Ok(())
    }

    /// MAML update step.
    ///
    /// Inner loop: adapt θ' = θ − α ∇_θ L_support(θ) for `inner_steps` steps;
    /// meta-update: θ ← θ − β Σ_tasks ∇_θ L_query(θ').
    fn maml_update(
        &mut self,
        tasks: &[MetaTask],
        optimizer: &mut dyn Optimizer,
    ) -> Result<(f64, f64)> {
        let (inner_steps, inner_lr, first_order) = match self.algorithm {
            MetaLearningAlgorithm::MAML {
                inner_steps,
                inner_lr,
                first_order,
            } => (inner_steps, inner_lr, first_order),
            _ => unreachable!(),
        };

        let mut total_loss = 0.0;
        let mut total_acc = 0.0;
        let mut meta_gradients = Array1::zeros(self.meta_params.len());

        for task in tasks {
            // Clone meta-parameters for inner loop
            let mut task_params = self.meta_params.clone();

            // Inner loop: adapt to task
            for _ in 0..inner_steps {
                let grad = self.compute_task_gradient(&task.train_data, &task_params)?;
                task_params = task_params - inner_lr * &grad;
            }

            // Compute loss on query set with adapted parameters
            let (query_loss, query_acc) = self.evaluate_task(&task.test_data, &task_params)?;
            total_loss += query_loss;
            total_acc += query_acc;

            // Compute meta-gradient
            if !first_order {
                // Full second-order MAML gradient
                let meta_grad = self.compute_maml_gradient(task, &task_params, inner_lr)?;
                meta_gradients = meta_gradients + meta_grad;
            } else {
                // First-order approximation (FO-MAML)
                let grad = self.compute_task_gradient(&task.test_data, &task_params)?;
                meta_gradients = meta_gradients + grad;
            }
        }

        // Average gradients and update meta-parameters
        meta_gradients = meta_gradients / tasks.len() as f64;
        self.meta_params = self.meta_params.clone() - 0.001 * &meta_gradients; // Fixed meta learning rate (0.001); the `optimizer` argument is not yet used here

        Ok((
            total_loss / tasks.len() as f64,
            total_acc / tasks.len() as f64,
        ))
    }

    /// Reptile update step: after `inner_steps` SGD steps on a task yield θ_task,
    /// move the meta-parameters θ ← θ + ε (θ_task − θ).
    fn reptile_update(
        &mut self,
        tasks: &[MetaTask],
        optimizer: &mut dyn Optimizer,
    ) -> Result<(f64, f64)> {
        let (inner_steps, inner_lr) = match self.algorithm {
            MetaLearningAlgorithm::Reptile {
                inner_steps,
                inner_lr,
            } => (inner_steps, inner_lr),
            _ => unreachable!(),
        };

        let mut total_loss = 0.0;
        let mut total_acc = 0.0;
        let epsilon = 0.1; // Reptile step size

        for task in tasks {
            // Clone meta-parameters
            let mut task_params = self.meta_params.clone();

            // Perform multiple SGD steps on task
            for _ in 0..inner_steps {
                let grad = self.compute_task_gradient(&task.train_data, &task_params)?;
                task_params = task_params - inner_lr * &grad;
            }

            // Evaluate adapted model
            let (loss, acc) = self.evaluate_task(&task.test_data, &task_params)?;
            total_loss += loss;
            total_acc += acc;

            // Reptile update: move meta-params toward task-adapted params
            let direction = &task_params - &self.meta_params;
            self.meta_params = &self.meta_params + epsilon * &direction;
        }

        Ok((
            total_loss / tasks.len() as f64,
            total_acc / tasks.len() as f64,
        ))
    }

    /// ProtoMAML update step: MAML-style inner adaptation with an added
    /// prototype-regularization term, evaluated with prototype-based classification.
    fn protomaml_update(
        &mut self,
        tasks: &[MetaTask],
        optimizer: &mut dyn Optimizer,
    ) -> Result<(f64, f64)> {
        let (inner_steps, inner_lr, proto_weight) = match self.algorithm {
            MetaLearningAlgorithm::ProtoMAML {
                inner_steps,
                inner_lr,
                proto_weight,
            } => (inner_steps, inner_lr, proto_weight),
            _ => unreachable!(),
        };

        let mut total_loss = 0.0;
        let mut total_acc = 0.0;

        for task in tasks {
            // Compute prototypes for each class
            let prototypes = self.compute_prototypes(&task.train_data, task.num_classes)?;

            // Clone parameters for adaptation
            let mut task_params = self.meta_params.clone();

            // Inner loop with prototype regularization
            for _ in 0..inner_steps {
                let grad = self.compute_task_gradient(&task.train_data, &task_params)?;
                let proto_reg =
                    self.prototype_regularization(&task.train_data, &prototypes, &task_params)?;
                task_params = task_params - inner_lr * (&grad + proto_weight * &proto_reg);
            }

            // Evaluate with prototypical classification
            let (loss, acc) =
                self.evaluate_with_prototypes(&task.test_data, &prototypes, &task_params)?;
            total_loss += loss;
            total_acc += acc;
        }

        Ok((
            total_loss / tasks.len() as f64,
            total_acc / tasks.len() as f64,
        ))
    }

    /// Meta-SGD update step: the inner update uses a learned per-parameter
    /// learning-rate vector α, i.e. θ' = θ − α ⊙ ∇L, and α is itself meta-trained.
    fn metasgd_update(
        &mut self,
        tasks: &[MetaTask],
        optimizer: &mut dyn Optimizer,
    ) -> Result<(f64, f64)> {
        let inner_steps = match self.algorithm {
            MetaLearningAlgorithm::MetaSGD { inner_steps } => inner_steps,
            _ => unreachable!(),
        };

        let mut total_loss = 0.0;
        let mut total_acc = 0.0;
        let mut meta_lr_gradients = Array1::zeros(self.meta_params.len());

        for task in tasks {
            let mut task_params = self.meta_params.clone();

            // Inner loop with learnable per-parameter learning rates
            for _ in 0..inner_steps {
                let grad = self.compute_task_gradient(&task.train_data, &task_params)?;
                let lr = self
                    .per_param_lr
                    .as_ref()
                    .expect("per_param_lr must be initialized for MetaSGD");
                task_params = task_params - lr * &grad;
            }

            // Evaluate and compute gradients w.r.t. both parameters and learning rates
            let (loss, acc) = self.evaluate_task(&task.test_data, &task_params)?;
            total_loss += loss;
            total_acc += acc;

            // Compute gradient w.r.t. learning rates
            let lr_grad = self.compute_lr_gradient(task, &task_params)?;
            meta_lr_gradients = meta_lr_gradients + lr_grad;
        }

        // Update the learnable learning rates (the meta-parameters themselves are not updated here)
        if let Some(ref mut lr) = self.per_param_lr {
            *lr = lr.clone() - &(0.001 * &meta_lr_gradients / tasks.len() as f64);
        }

        Ok((
            total_loss / tasks.len() as f64,
            total_acc / tasks.len() as f64,
        ))
    }

    /// ANIL update step: only the final (head) parameters are adapted in the
    /// inner loop; the shared body is reused across tasks.
    fn anil_update(
        &mut self,
        tasks: &[MetaTask],
        optimizer: &mut dyn Optimizer,
    ) -> Result<(f64, f64)> {
        let (inner_steps, inner_lr) = match self.algorithm {
            MetaLearningAlgorithm::ANIL {
                inner_steps,
                inner_lr,
            } => (inner_steps, inner_lr),
            _ => unreachable!(),
        };

        // ANIL: Only adapt the final layer(s) in inner loop
        let num_params = self.meta_params.len();
        let final_layer_start = (num_params * 3) / 4; // Last 25% of parameters

        let mut total_loss = 0.0;
        let mut total_acc = 0.0;

        for task in tasks {
            let mut task_params = self.meta_params.clone();

            // Inner loop: only update final layer parameters
            for _ in 0..inner_steps {
                let grad = self.compute_task_gradient(&task.train_data, &task_params)?;

                // Only update final layer
                for i in final_layer_start..num_params {
                    task_params[i] -= inner_lr * grad[i];
                }
            }

            let (loss, acc) = self.evaluate_task(&task.test_data, &task_params)?;
            total_loss += loss;
            total_acc += acc;
        }

        Ok((
            total_loss / tasks.len() as f64,
            total_acc / tasks.len() as f64,
        ))
    }

    /// Compute gradient for a task
    fn compute_task_gradient(
        &self,
        data: &[(Array1<f64>, usize)],
        params: &Array1<f64>,
    ) -> Result<Array1<f64>> {
        // Placeholder: a full implementation would compute quantum gradients,
        // e.g. via the parameter-shift rule (see the sketch below)
        Ok(Array1::zeros(params.len()))
    }
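
    // A minimal sketch of a parameter-shift gradient, assuming the loss can be
    // evaluated at shifted parameter values through a closure. The `loss_fn`
    // closure is a hypothetical stand-in for running the quantum circuit and
    // scoring it; it is not part of this module's current API.
    #[allow(dead_code)]
    fn parameter_shift_gradient(
        params: &Array1<f64>,
        loss_fn: &dyn Fn(&Array1<f64>) -> f64,
    ) -> Array1<f64> {
        let shift = std::f64::consts::FRAC_PI_2;
        let mut grad = Array1::zeros(params.len());
        for i in 0..params.len() {
            let mut plus = params.clone();
            plus[i] += shift;
            let mut minus = params.clone();
            minus[i] -= shift;
            // Parameter-shift rule: dL/dθ_i = [L(θ + π/2·e_i) − L(θ − π/2·e_i)] / 2
            grad[i] = (loss_fn(&plus) - loss_fn(&minus)) / 2.0;
        }
        grad
    }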

    /// Evaluate model on task data
    fn evaluate_task(
        &self,
        data: &[(Array1<f64>, usize)],
        params: &Array1<f64>,
    ) -> Result<(f64, f64)> {
        // Placeholder - would evaluate quantum model
        let loss = 0.5 + 0.5 * thread_rng().gen::<f64>();
        let acc = 0.5 + 0.3 * thread_rng().gen::<f64>();
        Ok((loss, acc))
    }

    /// Compute MAML gradient (second-order)
    fn compute_maml_gradient(
        &self,
        task: &MetaTask,
        adapted_params: &Array1<f64>,
        inner_lr: f64,
    ) -> Result<Array1<f64>> {
        // Placeholder - would compute Hessian-vector products
        Ok(Array1::zeros(self.meta_params.len()))
    }

    /// Compute prototypes for ProtoMAML
    fn compute_prototypes(
        &self,
        data: &[(Array1<f64>, usize)],
        num_classes: usize,
    ) -> Result<Vec<Array1<f64>>> {
        let feature_dim = 16; // Placeholder dimension
        let mut prototypes = vec![Array1::zeros(feature_dim); num_classes];
        let mut counts = vec![0; num_classes];

        // Placeholder: only tallies class counts; a full implementation would
        // encode the data and average the features per class (see the sketch below)
        for (_x, label) in data {
            counts[*label] += 1;
        }

        Ok(prototypes)
    }
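
    // A minimal sketch of the class-mean computation, assuming a hypothetical
    // `encode` closure that maps an input to its quantum feature vector; the
    // closure is illustrative and not part of this module's current API.
    #[allow(dead_code)]
    fn compute_prototypes_with(
        data: &[(Array1<f64>, usize)],
        num_classes: usize,
        feature_dim: usize,
        encode: &dyn Fn(&Array1<f64>) -> Array1<f64>,
    ) -> Vec<Array1<f64>> {
        let mut prototypes = vec![Array1::zeros(feature_dim); num_classes];
        let mut counts = vec![0usize; num_classes];

        // Sum encoded features per class
        for (x, label) in data {
            let feats = encode(x);
            prototypes[*label] = &prototypes[*label] + &feats;
            counts[*label] += 1;
        }

        // Divide by class counts to obtain class means
        for (proto, &count) in prototypes.iter_mut().zip(counts.iter()) {
            if count > 0 {
                *proto /= count as f64;
            }
        }

        prototypes
    }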

    /// Prototype regularization
    fn prototype_regularization(
        &self,
        data: &[(Array1<f64>, usize)],
        prototypes: &[Array1<f64>],
        params: &Array1<f64>,
    ) -> Result<Array1<f64>> {
        // Placeholder - would compute prototype-based regularization
        Ok(Array1::zeros(params.len()))
    }

    /// Evaluate with prototypical classification
    fn evaluate_with_prototypes(
        &self,
        data: &[(Array1<f64>, usize)],
        prototypes: &[Array1<f64>],
        params: &Array1<f64>,
    ) -> Result<(f64, f64)> {
        // Placeholder: fixed (loss, accuracy); a full implementation would
        // classify each query point by its nearest prototype (see the sketch below)
        Ok((0.3, 0.7))
    }
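
    // A minimal sketch of nearest-prototype classification: assign the class
    // whose prototype minimizes the squared Euclidean distance to the encoded
    // input. `features` would come from the same hypothetical encoder as above.
    #[allow(dead_code)]
    fn nearest_prototype(features: &Array1<f64>, prototypes: &[Array1<f64>]) -> usize {
        prototypes
            .iter()
            .enumerate()
            .map(|(class, proto)| {
                let diff = features - proto;
                (class, diff.dot(&diff)) // squared distance
            })
            .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(class, _)| class)
            .unwrap_or(0)
    }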

    /// Compute gradient w.r.t. learning rates for Meta-SGD
    fn compute_lr_gradient(
        &self,
        task: &MetaTask,
        adapted_params: &Array1<f64>,
    ) -> Result<Array1<f64>> {
        // Placeholder
        Ok(Array1::zeros(self.meta_params.len()))
    }

    /// Sample a batch of tasks (uniformly, with replacement)
    fn sample_task_batch(&self, tasks: &[MetaTask], batch_size: usize) -> Vec<MetaTask> {
        let mut batch = Vec::new();
        let mut rng = thread_rng();

        for _ in 0..batch_size.min(tasks.len()) {
            let idx = rng.gen_range(0..tasks.len());
            batch.push(tasks[idx].clone());
        }

        batch
    }

    /// Adapt to new task
    pub fn adapt_to_task(&mut self, task: &MetaTask) -> Result<Array1<f64>> {
        let adapted_params = match self.algorithm {
            MetaLearningAlgorithm::MAML {
                inner_steps,
                inner_lr,
                ..
            }
            | MetaLearningAlgorithm::Reptile {
                inner_steps,
                inner_lr,
            }
            | MetaLearningAlgorithm::ProtoMAML {
                inner_steps,
                inner_lr,
                ..
            }
            | MetaLearningAlgorithm::ANIL {
                inner_steps,
                inner_lr,
            } => {
                let mut params = self.meta_params.clone();
                for _ in 0..inner_steps {
                    let grad = self.compute_task_gradient(&task.train_data, &params)?;
                    params = params - inner_lr * &grad;
                }
                params
            }
            MetaLearningAlgorithm::MetaSGD { inner_steps } => {
                let mut params = self.meta_params.clone();
                let lr = self
                    .per_param_lr
                    .as_ref()
                    .expect("per_param_lr must be initialized for MetaSGD");
                for _ in 0..inner_steps {
                    let grad = self.compute_task_gradient(&task.train_data, &params)?;
                    params = params - lr * &grad;
                }
                params
            }
        };

        Ok(adapted_params)
    }

    /// Get task embedding
    pub fn get_task_embedding(&self, task_id: &str) -> Option<&Array1<f64>> {
        self.task_embeddings.get(task_id)
    }

    /// Get meta parameters
    pub fn meta_params(&self) -> &Array1<f64> {
        &self.meta_params
    }

    /// Get per-parameter learning rates
    pub fn per_param_lr(&self) -> Option<&Array1<f64>> {
        self.per_param_lr.as_ref()
    }
}
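
// End-to-end usage looks roughly like the following (a sketch; `qnn` stands in
// for a QuantumNeuralNetwork built as in the tests at the bottom of this file,
// and `tasks`/`new_task` for MetaTask values from a TaskGenerator):
//
//     let algo = MetaLearningAlgorithm::MAML {
//         inner_steps: 5,
//         inner_lr: 0.01,
//         first_order: true,
//     };
//     let mut learner = QuantumMetaLearner::new(algo, qnn);
//     let mut optimizer = crate::autodiff::optimizers::Adam::new(0.001);
//     learner.meta_train(&tasks, &mut optimizer, 100, 4)?;
//     let adapted = learner.adapt_to_task(&new_task)?;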

/// Continual meta-learning with memory
pub struct ContinualMetaLearner {
    /// Base meta-learner
    meta_learner: QuantumMetaLearner,

    /// Memory buffer for past tasks
    memory_buffer: Vec<MetaTask>,

    /// Maximum memory size
    memory_capacity: usize,

    /// Replay ratio
    replay_ratio: f64,
}

impl ContinualMetaLearner {
    /// Create new continual meta-learner
    pub fn new(
        meta_learner: QuantumMetaLearner,
        memory_capacity: usize,
        replay_ratio: f64,
    ) -> Self {
        Self {
            meta_learner,
            memory_buffer: Vec::new(),
            memory_capacity,
            replay_ratio,
        }
    }

    /// Learn new task while preserving old knowledge
    pub fn learn_task(&mut self, new_task: MetaTask) -> Result<()> {
        // Add to memory; once the buffer is full, evict a random entry (simple random replacement)
        if self.memory_buffer.len() < self.memory_capacity {
            self.memory_buffer.push(new_task.clone());
        } else {
            let idx = fastrand::usize(0..self.memory_buffer.len());
            self.memory_buffer[idx] = new_task.clone();
        }

        // Create mixed batch with replay
        let num_replay = (self.memory_buffer.len() as f64 * self.replay_ratio) as usize;
        let mut task_batch = vec![new_task];

        for _ in 0..num_replay {
            let idx = fastrand::usize(0..self.memory_buffer.len());
            task_batch.push(self.memory_buffer[idx].clone());
        }

        // Update meta-learner
        let mut dummy_optimizer = crate::autodiff::optimizers::Adam::new(0.001);
        self.meta_learner
            .meta_train(&task_batch, &mut dummy_optimizer, 10, task_batch.len())?;

        Ok(())
    }

    /// Get memory buffer length
    pub fn memory_buffer_len(&self) -> usize {
        self.memory_buffer.len()
    }
}
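
// A continual-learning loop then interleaves new tasks with replayed ones
// (a sketch; `meta_learner` is a QuantumMetaLearner built as above, and
// `task_stream` is a hypothetical iterator of MetaTask values):
//
//     let mut continual = ContinualMetaLearner::new(meta_learner, 50, 0.3);
//     for task in task_stream {
//         continual.learn_task(task)?;
//     }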

/// Task generator for meta-learning experiments
pub struct TaskGenerator {
    /// Feature dimension
    feature_dim: usize,

    /// Number of classes per task
    num_classes: usize,

    /// Task distribution parameters
    task_params: HashMap<String, f64>,
}

impl TaskGenerator {
    /// Create new task generator
    pub fn new(feature_dim: usize, num_classes: usize) -> Self {
        Self {
            feature_dim,
            num_classes,
            task_params: HashMap::new(),
        }
    }

    /// Generate a sinusoid task (the regression target is binarized by sign into two classes)
    pub fn generate_sinusoid_task(&self, num_samples: usize) -> MetaTask {
        let amplitude = 0.1 + 4.9 * thread_rng().gen::<f64>();
        let phase = 2.0 * std::f64::consts::PI * thread_rng().gen::<f64>();

        let mut train_data = Vec::new();
        let mut test_data = Vec::new();

        // Generate samples
        for i in 0..num_samples {
            let x = -5.0 + 10.0 * thread_rng().gen::<f64>();
            let y = amplitude * (x + phase).sin();

            let input = Array1::from_vec(vec![x]);
            let label = if y > 0.0 { 1 } else { 0 }; // Binarize for classification

            if i < num_samples / 2 {
                train_data.push((input, label));
            } else {
                test_data.push((input, label));
            }
        }

        MetaTask {
            id: format!("sin_a{:.2}_p{:.2}", amplitude, phase),
            train_data,
            test_data,
            num_classes: 2,
            metadata: vec![
                ("amplitude".to_string(), amplitude),
                ("phase".to_string(), phase),
            ]
            .into_iter()
            .collect(),
        }
    }

    /// Generate classification task with rotated features
    pub fn generate_rotation_task(&self, num_samples: usize) -> MetaTask {
        let angle = 2.0 * std::f64::consts::PI * thread_rng().gen::<f64>();
        let cos_a = angle.cos();
        let sin_a = angle.sin();

        let mut train_data = Vec::new();
        let mut test_data = Vec::new();

        for i in 0..num_samples {
            // Generate base features
            let mut features = Array1::zeros(self.feature_dim);
            let label = i % self.num_classes;

            // Class-specific pattern
            for j in 0..self.feature_dim {
                features[j] = if j % self.num_classes == label {
                    1.0
                } else {
                    0.0
                };
                features[j] += 0.1 * thread_rng().gen::<f64>();
            }

            // Apply rotation (simplified for first 2 dims)
            if self.feature_dim >= 2 {
                let x = features[0];
                let y = features[1];
                features[0] = cos_a * x - sin_a * y;
                features[1] = sin_a * x + cos_a * y;
            }

            if i < num_samples / 2 {
                train_data.push((features, label));
            } else {
                test_data.push((features, label));
            }
        }

        MetaTask {
            id: format!("rot_{:.2}", angle),
            train_data,
            test_data,
            num_classes: self.num_classes,
            metadata: vec![("rotation_angle".to_string(), angle)]
                .into_iter()
                .collect(),
        }
    }
}
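
// Generating a task distribution for meta-training is then one call per task
// (a sketch; the counts are illustrative):
//
//     let generator = TaskGenerator::new(4, 2);
//     let tasks: Vec<MetaTask> = (0..100)
//         .map(|_| generator.generate_rotation_task(20))
//         .collect();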

#[cfg(test)]
mod tests {
    use super::*;
    use crate::autodiff::optimizers::Adam;
    use crate::qnn::QNNLayerType;

    #[test]
    fn test_task_generator() {
        let generator = TaskGenerator::new(4, 2);

        let sin_task = generator.generate_sinusoid_task(20);
        assert_eq!(sin_task.train_data.len(), 10);
        assert_eq!(sin_task.test_data.len(), 10);

        let rot_task = generator.generate_rotation_task(30);
        assert_eq!(rot_task.train_data.len(), 15);
        assert_eq!(rot_task.test_data.len(), 15);
    }

    #[test]
    fn test_meta_learner_creation() {
        let layers = vec![
            QNNLayerType::EncodingLayer { num_features: 4 },
            QNNLayerType::VariationalLayer { num_params: 8 },
            QNNLayerType::MeasurementLayer {
                measurement_basis: "computational".to_string(),
            },
        ];

        let qnn = QuantumNeuralNetwork::new(layers, 4, 4, 2).expect("Failed to create QNN");

        let maml_algo = MetaLearningAlgorithm::MAML {
            inner_steps: 5,
            inner_lr: 0.01,
            first_order: true,
        };

        let meta_learner = QuantumMetaLearner::new(maml_algo, qnn);
        assert!(meta_learner.per_param_lr.is_none());

        // Test Meta-SGD
        let layers2 = vec![
            QNNLayerType::EncodingLayer { num_features: 4 },
            QNNLayerType::VariationalLayer { num_params: 8 },
        ];
        let qnn2 =
            QuantumNeuralNetwork::new(layers2, 4, 4, 2).expect("Failed to create QNN for Meta-SGD");

        let metasgd_algo = MetaLearningAlgorithm::MetaSGD { inner_steps: 3 };
        let meta_sgd = QuantumMetaLearner::new(metasgd_algo, qnn2);
        assert!(meta_sgd.per_param_lr.is_some());
    }

    #[test]
    fn test_task_adaptation() {
        let layers = vec![
            QNNLayerType::EncodingLayer { num_features: 2 },
            QNNLayerType::VariationalLayer { num_params: 6 },
        ];

        let qnn = QuantumNeuralNetwork::new(layers, 4, 2, 2).expect("Failed to create QNN");
        let algo = MetaLearningAlgorithm::Reptile {
            inner_steps: 5,
            inner_lr: 0.01,
        };

        let mut meta_learner = QuantumMetaLearner::new(algo, qnn);

        // Generate task
        let generator = TaskGenerator::new(2, 2);
        let task = generator.generate_rotation_task(20);

        // Adapt to task
        let adapted_params = meta_learner
            .adapt_to_task(&task)
            .expect("Task adaptation should succeed");
        assert_eq!(adapted_params.len(), meta_learner.meta_params.len());
    }
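
    // A minimal continual-learning smoke test (a sketch under the same QNN
    // constructor assumptions as the tests above).
    #[test]
    fn test_continual_learner_memory() {
        let layers = vec![
            QNNLayerType::EncodingLayer { num_features: 2 },
            QNNLayerType::VariationalLayer { num_params: 6 },
        ];
        let qnn = QuantumNeuralNetwork::new(layers, 4, 2, 2).expect("Failed to create QNN");
        let algo = MetaLearningAlgorithm::Reptile {
            inner_steps: 2,
            inner_lr: 0.01,
        };
        let meta_learner = QuantumMetaLearner::new(algo, qnn);
        let mut continual = ContinualMetaLearner::new(meta_learner, 5, 0.5);

        let generator = TaskGenerator::new(2, 2);
        for _ in 0..8 {
            let task = generator.generate_rotation_task(10);
            continual.learn_task(task).expect("learn_task should succeed");
        }

        // The memory buffer never exceeds its capacity
        assert!(continual.memory_buffer_len() <= 5);
    }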
}