optirs_learned/transformer_based_optimizer/
meta_learning.rs

1// Meta-learning components for transformer-based optimization
2
3use super::config::{ActivationFunction, MetaLearningConfig};
4use super::feedforward::FeedForwardNetwork;
5use super::layers::LayerNormalization;
6use crate::error::Result;
7use scirs2_core::ndarray::{Array1, Array2, Array3, Axis};
8use scirs2_core::numeric::Float;
9use std::collections::{HashMap, VecDeque};
10use std::fmt::Debug;
11use std::time::Instant;
12
/// Meta-learning strategy types, selectable at runtime via
/// `TransformerMetaLearning::set_strategy`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MetaLearningStrategy {
    /// Model-Agnostic Meta-Learning (MAML): inner-loop adaptation on each
    /// task's support set, outer-loop update from query-set losses
    MAML,
    /// First-Order MAML (FOMAML): MAML with second-order derivatives dropped
    FOMAML,
    /// Reptile algorithm: meta-parameters move toward the average of
    /// task-adapted parameters
    Reptile,
    /// Gradient-based meta-learning: an adaptation network predicts a
    /// task-specific initialization, then fine-tunes it
    GradientBased,
    /// Memory-augmented networks: initialization comes from similar past
    /// experiences stored in the memory bank
    MemoryAugmented,
}
27
/// Transformer meta-learning implementation.
///
/// Owns the outer-loop meta-optimizer, a task-adaptation network, an
/// experience memory bank, and bookkeeping state, and dispatches meta-steps
/// according to the selected `MetaLearningStrategy`.
pub struct TransformerMetaLearning<T: Float + Debug + Send + Sync + 'static> {
    /// Meta-learning strategy used by `meta_step`
    strategy: MetaLearningStrategy,

    /// Meta-learning configuration (inner steps, inner/meta learning rates)
    config: MetaLearningConfig<T>,

    /// Meta-optimizer for outer loop parameter updates
    meta_optimizer: MetaOptimizer<T>,

    /// Task adaptation network (context encoding, parameter prediction,
    /// update generation)
    adaptation_network: AdaptationNetwork<T>,

    /// Memory bank for storing and retrieving task experiences
    memory_bank: MemoryBank<T>,

    /// Performance tracker (loss history and meta-step results)
    performance_tracker: PerformanceTracker<T>,

    /// Current meta-learning state (meta-parameters, loss history, scale factor)
    meta_state: MetaState<T>,
}
51
52impl<T: Float + Debug + Send + Sync + 'static + scirs2_core::ndarray::ScalarOperand>
53    TransformerMetaLearning<T>
54{
55    /// Create new transformer meta-learning component
56    pub fn new(config: &super::config::TransformerBasedOptimizerConfig<T>) -> Result<Self> {
57        let meta_config = config.meta_learning_config.clone();
58        let strategy = MetaLearningStrategy::MAML; // Default strategy
59
60        let meta_optimizer = MetaOptimizer::new(&meta_config)?;
61        let adaptation_network =
62            AdaptationNetwork::new(config.model_dimension, config.feedforward_dimension)?;
63        let memory_bank = MemoryBank::new(1000, config.model_dimension)?;
64        let performance_tracker = PerformanceTracker::new();
65        let meta_state = MetaState::new(config.model_dimension)?;
66
67        Ok(Self {
68            strategy,
69            config: meta_config,
70            meta_optimizer,
71            adaptation_network,
72            memory_bank,
73            performance_tracker,
74            meta_state,
75        })
76    }
77
78    /// Perform meta-learning step
79    pub fn meta_step(
80        &mut self,
81        tasks: &[TaskBatch<T>],
82        support_data: &[Array2<T>],
83        query_data: &[Array2<T>],
84    ) -> Result<MetaLearningResult<T>> {
85        match self.strategy {
86            MetaLearningStrategy::MAML => self.maml_step(tasks, support_data, query_data),
87            MetaLearningStrategy::FOMAML => self.fomaml_step(tasks, support_data, query_data),
88            MetaLearningStrategy::Reptile => self.reptile_step(tasks, support_data, query_data),
89            MetaLearningStrategy::GradientBased => {
90                self.gradient_based_step(tasks, support_data, query_data)
91            }
92            MetaLearningStrategy::MemoryAugmented => {
93                self.memory_augmented_step(tasks, support_data, query_data)
94            }
95        }
96    }
97
98    /// MAML meta-learning step
99    fn maml_step(
100        &mut self,
101        tasks: &[TaskBatch<T>],
102        support_data: &[Array2<T>],
103        query_data: &[Array2<T>],
104    ) -> Result<MetaLearningResult<T>> {
105        let start_time = Instant::now();
106        let mut total_loss = T::zero();
107        let mut task_adaptations = Vec::new();
108
109        for (i, task) in tasks.iter().enumerate() {
110            // Inner adaptation loop
111            let mut adapted_params = self.meta_state.get_parameters().clone();
112
113            for inner_step in 0..self.config.inner_steps {
114                // Compute gradients on support set
115                let support_loss =
116                    self.compute_task_loss(&adapted_params, &support_data[i], task)?;
117                let gradients = self.compute_gradients(&adapted_params, support_loss)?;
118
119                // Update parameters
120                for (param, grad) in adapted_params.iter_mut().zip(gradients.iter()) {
121                    *param = *param - self.config.inner_learning_rate * (*grad);
122                }
123            }
124
125            // Evaluate on query set
126            let query_loss = self.compute_task_loss(&adapted_params, &query_data[i], task)?;
127            total_loss = total_loss + query_loss;
128
129            task_adaptations.push(TaskAdaptation {
130                task_id: task.id.clone(),
131                adapted_parameters: adapted_params,
132                support_loss: self.compute_task_loss(
133                    self.meta_state.get_parameters(),
134                    &support_data[i],
135                    task,
136                )?,
137                query_loss,
138                adaptation_steps: self.config.inner_steps,
139            });
140        }
141
142        // Meta-update
143        let meta_loss = total_loss / T::from(tasks.len()).unwrap();
144        let meta_gradients = self.compute_meta_gradients(&task_adaptations)?;
145        self.meta_optimizer
146            .update(&mut self.meta_state, &meta_gradients)?;
147
148        // Update memory bank
149        for (i, adaptation) in task_adaptations.iter().enumerate() {
150            self.memory_bank.store_experience(
151                &tasks[i],
152                &adaptation.adapted_parameters,
153                adaptation.query_loss,
154            )?;
155        }
156
157        let result = MetaLearningResult {
158            meta_loss: meta_loss.to_f64().unwrap_or(0.0),
159            task_adaptations,
160            computation_time: start_time.elapsed(),
161            convergence_rate: self.estimate_convergence_rate()?,
162        };
163
164        self.performance_tracker.record_meta_step(result.clone());
165        Ok(result)
166    }
167
168    /// First-order MAML step (FOMAML)
169    fn fomaml_step(
170        &mut self,
171        tasks: &[TaskBatch<T>],
172        support_data: &[Array2<T>],
173        query_data: &[Array2<T>],
174    ) -> Result<MetaLearningResult<T>> {
175        // Simplified version of MAML that ignores second-order derivatives
176        // Similar to MAML but uses first-order approximation for efficiency
177        self.maml_step(tasks, support_data, query_data)
178    }
179
180    /// Reptile meta-learning step
181    fn reptile_step(
182        &mut self,
183        tasks: &[TaskBatch<T>],
184        support_data: &[Array2<T>],
185        query_data: &[Array2<T>],
186    ) -> Result<MetaLearningResult<T>> {
187        let start_time = Instant::now();
188        let mut parameter_updates = Vec::new();
189        let mut total_loss = T::zero();
190
191        for (i, task) in tasks.iter().enumerate() {
192            // Adapt on support set
193            let mut adapted_params = self.meta_state.get_parameters().clone();
194
195            for _ in 0..self.config.inner_steps {
196                let loss = self.compute_task_loss(&adapted_params, &support_data[i], task)?;
197                let gradients = self.compute_gradients(&adapted_params, loss)?;
198
199                for (param, grad) in adapted_params.iter_mut().zip(gradients.iter()) {
200                    *param = *param - self.config.inner_learning_rate * (*grad);
201                }
202            }
203
204            // Compute parameter difference for meta-update
205            let original_params = self.meta_state.get_parameters();
206            let param_diff: Vec<T> = adapted_params
207                .iter()
208                .zip(original_params.iter())
209                .map(|(adapted, original)| *adapted - *original)
210                .collect();
211
212            parameter_updates.push(param_diff);
213
214            // Evaluate on query set
215            let query_loss = self.compute_task_loss(&adapted_params, &query_data[i], task)?;
216            total_loss = total_loss + query_loss;
217        }
218
219        // Meta-update: move towards average of adapted parameters
220        let mut meta_update = vec![T::zero(); self.meta_state.get_parameters().len()];
221        for param_update in &parameter_updates {
222            for (i, &update) in param_update.iter().enumerate() {
223                meta_update[i] = meta_update[i] + update;
224            }
225        }
226
227        let num_tasks = T::from(tasks.len()).unwrap();
228        for update in meta_update.iter_mut() {
229            *update = *update / num_tasks;
230        }
231
232        self.meta_state
233            .update_parameters(&meta_update, self.config.meta_learning_rate)?;
234
235        let result = MetaLearningResult {
236            meta_loss: (total_loss / num_tasks).to_f64().unwrap_or(0.0),
237            task_adaptations: Vec::new(), // Reptile doesn't track individual adaptations
238            computation_time: start_time.elapsed(),
239            convergence_rate: self.estimate_convergence_rate()?,
240        };
241
242        self.performance_tracker.record_meta_step(result.clone());
243        Ok(result)
244    }
245
246    /// Gradient-based meta-learning step
247    fn gradient_based_step(
248        &mut self,
249        tasks: &[TaskBatch<T>],
250        support_data: &[Array2<T>],
251        query_data: &[Array2<T>],
252    ) -> Result<MetaLearningResult<T>> {
253        // Implement gradient-based meta-learning with direct optimization
254        let start_time = Instant::now();
255        let mut total_loss = T::zero();
256
257        for (i, task) in tasks.iter().enumerate() {
258            // Use adaptation network to predict good initialization
259            let context_embedding = self.adaptation_network.encode_task_context(task)?;
260            let predicted_params = self
261                .adaptation_network
262                .predict_parameters(&context_embedding)?;
263
264            // Fine-tune predicted parameters
265            let mut adapted_params = predicted_params;
266            for _ in 0..self.config.inner_steps {
267                let loss = self.compute_task_loss(&adapted_params, &support_data[i], task)?;
268                let gradients = self.compute_gradients(&adapted_params, loss)?;
269
270                for (param, grad) in adapted_params.iter_mut().zip(gradients.iter()) {
271                    *param = *param - self.config.inner_learning_rate * (*grad);
272                }
273            }
274
275            let query_loss = self.compute_task_loss(&adapted_params, &query_data[i], task)?;
276            total_loss = total_loss + query_loss;
277        }
278
279        let result = MetaLearningResult {
280            meta_loss: (total_loss / T::from(tasks.len()).unwrap())
281                .to_f64()
282                .unwrap_or(0.0),
283            task_adaptations: Vec::new(),
284            computation_time: start_time.elapsed(),
285            convergence_rate: self.estimate_convergence_rate()?,
286        };
287
288        Ok(result)
289    }
290
291    /// Memory-augmented meta-learning step
292    fn memory_augmented_step(
293        &mut self,
294        tasks: &[TaskBatch<T>],
295        support_data: &[Array2<T>],
296        query_data: &[Array2<T>],
297    ) -> Result<MetaLearningResult<T>> {
298        let start_time = Instant::now();
299        let mut total_loss = T::zero();
300
301        for (i, task) in tasks.iter().enumerate() {
302            // Retrieve relevant experiences from memory
303            let relevant_experiences = self.memory_bank.retrieve_similar_experiences(task, 5)?;
304
305            // Use memory to initialize adaptation
306            let memory_guided_params = self.initialize_from_memory(&relevant_experiences)?;
307
308            let mut adapted_params = memory_guided_params;
309            for _ in 0..self.config.inner_steps {
310                let loss = self.compute_task_loss(&adapted_params, &support_data[i], task)?;
311                let gradients = self.compute_gradients(&adapted_params, loss)?;
312
313                for (param, grad) in adapted_params.iter_mut().zip(gradients.iter()) {
314                    *param = *param - self.config.inner_learning_rate * (*grad);
315                }
316            }
317
318            let query_loss = self.compute_task_loss(&adapted_params, &query_data[i], task)?;
319            total_loss = total_loss + query_loss;
320
321            // Store experience
322            self.memory_bank
323                .store_experience(task, &adapted_params, query_loss)?;
324        }
325
326        let result = MetaLearningResult {
327            meta_loss: (total_loss / T::from(tasks.len()).unwrap())
328                .to_f64()
329                .unwrap_or(0.0),
330            task_adaptations: Vec::new(),
331            computation_time: start_time.elapsed(),
332            convergence_rate: self.estimate_convergence_rate()?,
333        };
334
335        Ok(result)
336    }
337
338    /// Generate optimization update using meta-learned parameters
339    pub fn generate_update(
340        &mut self,
341        transformer_output: &Array2<T>,
342        current_parameters: &Array1<T>,
343    ) -> Result<Array1<T>> {
344        // Use adaptation network to generate parameter updates
345        let update = self
346            .adaptation_network
347            .generate_parameter_update(transformer_output, current_parameters)?;
348
349        // Apply meta-learned scaling
350        let scaled_update = self.apply_meta_scaling(&update)?;
351
352        Ok(scaled_update)
353    }
354
355    /// Update meta-learning state from loss
356    pub fn update_from_loss(&mut self, loss: T) -> Result<()> {
357        self.meta_state.update_loss_history(loss);
358        self.performance_tracker
359            .record_loss(loss.to_f64().unwrap_or(0.0));
360        Ok(())
361    }
362
363    /// Set meta-learning strategy
364    pub fn set_strategy(&mut self, strategy: MetaLearningStrategy) {
365        self.strategy = strategy;
366    }
367
368    /// Get current strategy
369    pub fn get_strategy(&self) -> MetaLearningStrategy {
370        self.strategy
371    }
372
373    /// Helper methods
374    fn compute_task_loss(&self, params: &[T], data: &Array2<T>, task: &TaskBatch<T>) -> Result<T> {
375        // Simplified loss computation
376        let prediction_error = self.compute_prediction_error(params, data, task)?;
377        Ok(prediction_error)
378    }
379
380    fn compute_prediction_error(
381        &self,
382        _params: &[T],
383        data: &Array2<T>,
384        _task: &TaskBatch<T>,
385    ) -> Result<T> {
386        // Placeholder: compute actual prediction error
387        let mean_squared_error = data
388            .iter()
389            .map(|&x| x * x)
390            .fold(T::zero(), |acc, x| acc + x);
391        Ok(mean_squared_error / T::from(data.len()).unwrap())
392    }
393
394    fn compute_gradients(&self, params: &[T], loss: T) -> Result<Vec<T>> {
395        // Simplified gradient computation
396        let gradients = params
397            .iter()
398            .map(|_| loss / T::from(params.len()).unwrap())
399            .collect();
400        Ok(gradients)
401    }
402
403    fn compute_meta_gradients(&self, adaptations: &[TaskAdaptation<T>]) -> Result<Vec<T>> {
404        let param_count = adaptations[0].adapted_parameters.len();
405        let mut meta_gradients = vec![T::zero(); param_count];
406
407        for adaptation in adaptations {
408            for (i, &param) in adaptation.adapted_parameters.iter().enumerate() {
409                meta_gradients[i] = meta_gradients[i] + param * adaptation.query_loss;
410            }
411        }
412
413        let num_tasks = T::from(adaptations.len()).unwrap();
414        for grad in meta_gradients.iter_mut() {
415            *grad = *grad / num_tasks;
416        }
417
418        Ok(meta_gradients)
419    }
420
421    fn estimate_convergence_rate(&self) -> Result<f64> {
422        let loss_history = self.performance_tracker.get_loss_history();
423        if loss_history.len() < 2 {
424            return Ok(0.0);
425        }
426
427        let recent_losses: Vec<_> = loss_history.iter().rev().take(5).cloned().collect();
428        let improvement = recent_losses.last().unwrap() - recent_losses.first().unwrap();
429        Ok(improvement.clamp(0.0, 1.0))
430    }
431
432    fn apply_meta_scaling(&self, update: &Array1<T>) -> Result<Array1<T>> {
433        // Apply learned scaling factors
434        let scale_factor = self.meta_state.get_scale_factor();
435        Ok(update * scale_factor)
436    }
437
438    fn initialize_from_memory(&self, experiences: &[MemoryExperience<T>]) -> Result<Vec<T>> {
439        if experiences.is_empty() {
440            return Ok(self.meta_state.get_parameters().clone());
441        }
442
443        // Average parameters from similar experiences
444        let param_count = experiences[0].parameters.len();
445        let mut averaged_params = vec![T::zero(); param_count];
446
447        for experience in experiences {
448            for (i, &param) in experience.parameters.iter().enumerate() {
449                averaged_params[i] = averaged_params[i] + param;
450            }
451        }
452
453        let num_experiences = T::from(experiences.len()).unwrap();
454        for param in averaged_params.iter_mut() {
455            *param = *param / num_experiences;
456        }
457
458        Ok(averaged_params)
459    }
460}
461
/// Meta-optimizer for outer loop updates.
///
/// Performs SGD (optionally with momentum) on the meta-parameters held in
/// `MetaState`.
pub struct MetaOptimizer<T: Float + Debug + Send + Sync + 'static> {
    /// Outer-loop learning rate (taken from `MetaLearningConfig::meta_learning_rate`)
    learning_rate: T,

    /// Momentum coefficient; `None` selects plain SGD
    /// (NOTE(review): `new` always sets `None`, so momentum is currently unused)
    momentum: Option<T>,

    /// Velocity buffer for momentum updates; lazily allocated on first use
    velocity: Option<Vec<T>>,
}
473
474impl<T: Float + Debug + Send + Sync + 'static> MetaOptimizer<T> {
475    pub fn new(config: &MetaLearningConfig<T>) -> Result<Self> {
476        Ok(Self {
477            learning_rate: config.meta_learning_rate,
478            momentum: None,
479            velocity: None,
480        })
481    }
482
483    pub fn update(&mut self, state: &mut MetaState<T>, gradients: &[T]) -> Result<()> {
484        let params = state.get_parameters_mut();
485
486        if let Some(momentum) = self.momentum {
487            // Momentum update
488            if self.velocity.is_none() {
489                self.velocity = Some(vec![T::zero(); params.len()]);
490            }
491
492            if let Some(ref mut velocity) = self.velocity {
493                for i in 0..params.len() {
494                    velocity[i] = momentum * velocity[i] + self.learning_rate * gradients[i];
495                    params[i] = params[i] - velocity[i];
496                }
497            }
498        } else {
499            // Simple SGD update
500            for (param, &grad) in params.iter_mut().zip(gradients.iter()) {
501                *param = *param - self.learning_rate * grad;
502            }
503        }
504
505        Ok(())
506    }
507}
508
/// Task adaptation network.
///
/// Bundles three feed-forward networks: a context encoder (task features →
/// hidden), a parameter predictor (hidden → parameters), and an update
/// generator ([transformer output | current parameters] → update).
pub struct AdaptationNetwork<T: Float + Debug + Send + Sync + 'static> {
    /// Context encoder: model_dimension → hidden_dimension, ReLU
    context_encoder: FeedForwardNetwork<T>,

    /// Parameter predictor: hidden_dimension → model_dimension, Tanh
    parameter_predictor: FeedForwardNetwork<T>,

    /// Update generator: 2 * model_dimension → model_dimension, ReLU
    update_generator: FeedForwardNetwork<T>,

    /// Model dimension shared by all sub-networks
    model_dimension: usize,
}
523
524impl<T: Float + Debug + Send + Sync + 'static> AdaptationNetwork<T> {
525    pub fn new(model_dimension: usize, hidden_dimension: usize) -> Result<Self> {
526        let context_encoder =
527            FeedForwardNetwork::new(model_dimension, hidden_dimension, ActivationFunction::ReLU)?;
528
529        let parameter_predictor =
530            FeedForwardNetwork::new(hidden_dimension, model_dimension, ActivationFunction::Tanh)?;
531
532        let update_generator = FeedForwardNetwork::new(
533            model_dimension * 2, // concatenated transformer output and current params
534            model_dimension,
535            ActivationFunction::ReLU,
536        )?;
537
538        Ok(Self {
539            context_encoder,
540            parameter_predictor,
541            update_generator,
542            model_dimension,
543        })
544    }
545
546    pub fn encode_task_context(&mut self, _task: &TaskBatch<T>) -> Result<Array2<T>> {
547        // Encode task characteristics into context vector
548        let task_features = Array2::ones((1, self.model_dimension));
549        self.context_encoder.forward(&task_features)
550    }
551
552    pub fn predict_parameters(&mut self, context: &Array2<T>) -> Result<Vec<T>> {
553        let predicted = self.parameter_predictor.forward(context)?;
554        Ok(predicted.row(0).to_vec())
555    }
556
557    pub fn generate_parameter_update(
558        &mut self,
559        transformer_output: &Array2<T>,
560        current_parameters: &Array1<T>,
561    ) -> Result<Array1<T>> {
562        // Concatenate transformer output and current parameters
563        let batch_size = transformer_output.shape()[0];
564        let mut input = Array2::zeros((batch_size, self.model_dimension * 2));
565
566        for i in 0..batch_size {
567            for j in 0..self.model_dimension {
568                input[[i, j]] = transformer_output[[i, j]];
569                if j < current_parameters.len() {
570                    input[[i, j + self.model_dimension]] = current_parameters[j];
571                }
572            }
573        }
574
575        let update = self.update_generator.forward(&input)?;
576        Ok(update.row(0).to_owned())
577    }
578}
579
/// Memory bank for storing task experiences.
///
/// A bounded FIFO buffer of `MemoryExperience` entries with
/// cosine-similarity retrieval over task signatures.
pub struct MemoryBank<T: Float + Debug + Send + Sync + 'static> {
    /// Stored experiences, oldest first
    experiences: VecDeque<MemoryExperience<T>>,

    /// Maximum number of experiences kept; oldest are evicted beyond this
    max_size: usize,

    /// Expected dimension of stored parameter vectors
    /// (NOTE(review): not currently enforced by `store_experience` — confirm intent)
    parameter_dimension: usize,
}
591
592impl<T: Float + Debug + Send + Sync + 'static> MemoryBank<T> {
593    pub fn new(max_size: usize, parameter_dimension: usize) -> Result<Self> {
594        Ok(Self {
595            experiences: VecDeque::new(),
596            max_size,
597            parameter_dimension,
598        })
599    }
600
601    pub fn store_experience(
602        &mut self,
603        task: &TaskBatch<T>,
604        parameters: &[T],
605        performance: T,
606    ) -> Result<()> {
607        let experience = MemoryExperience {
608            task_signature: self.compute_task_signature(task),
609            parameters: parameters.to_vec(),
610            performance: performance.to_f64().unwrap_or(0.0),
611            timestamp: Instant::now(),
612        };
613
614        self.experiences.push_back(experience);
615
616        if self.experiences.len() > self.max_size {
617            self.experiences.pop_front();
618        }
619
620        Ok(())
621    }
622
623    pub fn retrieve_similar_experiences(
624        &self,
625        task: &TaskBatch<T>,
626        k: usize,
627    ) -> Result<Vec<MemoryExperience<T>>> {
628        let target_signature = self.compute_task_signature(task);
629
630        let mut scored_experiences: Vec<_> = self
631            .experiences
632            .iter()
633            .map(|exp| {
634                let similarity = self.compute_similarity(&target_signature, &exp.task_signature);
635                (similarity, exp.clone())
636            })
637            .collect();
638
639        scored_experiences.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
640
641        Ok(scored_experiences
642            .into_iter()
643            .take(k)
644            .map(|(_, exp)| exp)
645            .collect())
646    }
647
648    fn compute_task_signature(&self, task: &TaskBatch<T>) -> Vec<f64> {
649        // Simplified task signature
650        vec![task.difficulty, task.complexity, task.data_characteristics]
651    }
652
653    fn compute_similarity(&self, sig1: &[f64], sig2: &[f64]) -> f64 {
654        if sig1.len() != sig2.len() {
655            return 0.0;
656        }
657
658        let dot_product: f64 = sig1.iter().zip(sig2.iter()).map(|(a, b)| a * b).sum();
659        let norm1: f64 = sig1.iter().map(|x| x * x).sum::<f64>().sqrt();
660        let norm2: f64 = sig2.iter().map(|x| x * x).sum::<f64>().sqrt();
661
662        if norm1 == 0.0 || norm2 == 0.0 {
663            0.0
664        } else {
665            dot_product / (norm1 * norm2)
666        }
667    }
668}
669
// --- Supporting data structures ---

/// Description of a single meta-learning task batch.
#[derive(Debug, Clone)]
pub struct TaskBatch<T: Float + Debug + Send + Sync + 'static> {
    /// Unique task identifier
    pub id: String,
    /// Task difficulty score (part of the similarity signature)
    pub difficulty: f64,
    /// Task complexity score (part of the similarity signature)
    pub complexity: f64,
    /// Scalar summary of the task's data (part of the similarity signature)
    pub data_characteristics: f64,
    /// Ties the otherwise-unused type parameter `T` to the struct
    pub _phantom: std::marker::PhantomData<T>,
}
679
/// Outcome of adapting the meta-parameters to a single task.
#[derive(Debug, Clone)]
pub struct TaskAdaptation<T: Float + Debug + Send + Sync + 'static> {
    /// Identifier of the adapted task
    pub task_id: String,
    /// Parameters after the inner adaptation loop
    pub adapted_parameters: Vec<T>,
    /// Loss on the support set, evaluated with the pre-adaptation meta-parameters
    pub support_loss: T,
    /// Loss on the query set after adaptation
    pub query_loss: T,
    /// Number of inner-loop steps performed
    pub adaptation_steps: usize,
}
688
/// Summary of one meta-learning step.
#[derive(Debug, Clone)]
pub struct MetaLearningResult<T: Float + Debug + Send + Sync + 'static> {
    /// Mean query-set loss across the task batch
    pub meta_loss: f64,
    /// Per-task adaptation records (empty for strategies that don't track them)
    pub task_adaptations: Vec<TaskAdaptation<T>>,
    /// Wall-clock duration of the step
    pub computation_time: std::time::Duration,
    /// Estimated convergence rate, clamped to [0, 1]
    pub convergence_rate: f64,
}
696
/// A single stored (task, parameters, performance) experience.
#[derive(Debug, Clone)]
pub struct MemoryExperience<T: Float + Debug + Send + Sync + 'static> {
    /// Numeric task signature used for similarity search
    pub task_signature: Vec<f64>,
    /// Parameter vector associated with this experience
    pub parameters: Vec<T>,
    /// Recorded performance (loss) as an f64
    pub performance: f64,
    /// When the experience was stored
    pub timestamp: Instant,
}
704
/// Rolling, bounded history of losses and meta-step results.
pub struct PerformanceTracker<T: Float + Debug + Send + Sync + 'static> {
    // Recent scalar losses, oldest first (bounded)
    loss_history: VecDeque<f64>,
    // Recent meta-step results, oldest first (bounded)
    meta_results: VecDeque<MetaLearningResult<T>>,
}
709
impl<T: Float + Debug + Send + Sync + 'static> Default for PerformanceTracker<T> {
    /// Equivalent to `PerformanceTracker::new` — an empty tracker.
    fn default() -> Self {
        Self::new()
    }
}
715
716impl<T: Float + Debug + Send + Sync + 'static> PerformanceTracker<T> {
717    pub fn new() -> Self {
718        Self {
719            loss_history: VecDeque::new(),
720            meta_results: VecDeque::new(),
721        }
722    }
723
724    pub fn record_loss(&mut self, loss: f64) {
725        self.loss_history.push_back(loss);
726        if self.loss_history.len() > 1000 {
727            self.loss_history.pop_front();
728        }
729    }
730
731    pub fn record_meta_step(&mut self, result: MetaLearningResult<T>) {
732        self.meta_results.push_back(result);
733        if self.meta_results.len() > 100 {
734            self.meta_results.pop_front();
735        }
736    }
737
738    pub fn get_loss_history(&self) -> &VecDeque<f64> {
739        &self.loss_history
740    }
741}
742
/// Current meta-learning state: the meta-parameters plus auxiliary statistics.
#[derive(Debug, Clone)]
pub struct MetaState<T: Float + Debug + Send + Sync + 'static> {
    // Meta-parameters (zero-initialized by `new`)
    parameters: Vec<T>,
    // Recent losses, oldest first (bounded)
    loss_history: VecDeque<T>,
    // Scale applied to generated updates (initialized to 1)
    scale_factor: T,
}
749
750impl<T: Float + Debug + Send + Sync + 'static> MetaState<T> {
751    pub fn new(parameter_count: usize) -> Result<Self> {
752        Ok(Self {
753            parameters: vec![T::zero(); parameter_count],
754            loss_history: VecDeque::new(),
755            scale_factor: T::one(),
756        })
757    }
758
759    pub fn get_parameters(&self) -> &Vec<T> {
760        &self.parameters
761    }
762
763    pub fn get_parameters_mut(&mut self) -> &mut Vec<T> {
764        &mut self.parameters
765    }
766
767    pub fn update_parameters(&mut self, updates: &[T], learning_rate: T) -> Result<()> {
768        for (param, &update) in self.parameters.iter_mut().zip(updates.iter()) {
769            *param = *param + learning_rate * update;
770        }
771        Ok(())
772    }
773
774    pub fn update_loss_history(&mut self, loss: T) {
775        self.loss_history.push_back(loss);
776        if self.loss_history.len() > 100 {
777            self.loss_history.pop_front();
778        }
779    }
780
781    pub fn get_scale_factor(&self) -> T {
782        self.scale_factor
783    }
784}
785
#[cfg(test)]
mod tests {
    use super::*;

    // Smoke test: the component can be constructed from a default config.
    // Ignored by default — NOTE(review): presumably skipped because the full
    // transformer construction is heavy; confirm before enabling in CI.
    #[test]
    #[ignore]
    fn test_meta_learning_creation() {
        let config = super::super::config::TransformerBasedOptimizerConfig::<f32>::default();
        let meta_learning = TransformerMetaLearning::new(&config);
        assert!(meta_learning.is_ok());
    }

    // Storing an experience whose parameter length matches the bank's
    // declared dimension succeeds.
    #[test]
    fn test_memory_bank() {
        let memory = MemoryBank::<f32>::new(100, 64);
        assert!(memory.is_ok());

        let mut bank = memory.unwrap();
        let task = TaskBatch {
            id: "test".to_string(),
            difficulty: 0.5,
            complexity: 0.7,
            data_characteristics: 0.3,
            _phantom: std::marker::PhantomData,
        };

        let params = vec![0.1f32; 64];
        assert!(bank.store_experience(&task, &params, 0.8).is_ok());
    }

    // The adaptation network builds and can encode a task context.
    #[test]
    fn test_adaptation_network() {
        let network = AdaptationNetwork::<f32>::new(128, 256);
        assert!(network.is_ok());

        let mut net = network.unwrap();
        let task = TaskBatch {
            id: "test".to_string(),
            difficulty: 0.5,
            complexity: 0.7,
            data_characteristics: 0.3,
            _phantom: std::marker::PhantomData,
        };

        let context = net.encode_task_context(&task);
        assert!(context.is_ok());
    }
}