Skip to main content

optirs_learned/transformer_based_optimizer/
meta_learning.rs

1// Meta-learning components for transformer-based optimization
2
3use super::config::{ActivationFunction, MetaLearningConfig};
4use super::feedforward::FeedForwardNetwork;
5use super::layers::LayerNormalization;
6use crate::error::Result;
7use scirs2_core::ndarray::{Array1, Array2, Array3, Axis};
8use scirs2_core::numeric::Float;
9use std::collections::{HashMap, VecDeque};
10use std::fmt::Debug;
11use std::time::Instant;
12
/// Meta-learning algorithm selected for a [`TransformerMetaLearning`] instance.
///
/// The enum is a field-less selector: it is `Copy`, totally comparable and
/// hashable, so it can be stored in sets or used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MetaLearningStrategy {
    /// Model-Agnostic Meta-Learning (MAML)
    MAML,
    /// First-Order MAML (FOMAML)
    FOMAML,
    /// Reptile algorithm
    Reptile,
    /// Gradient-based meta-learning
    GradientBased,
    /// Memory-augmented networks
    MemoryAugmented,
}
27
28/// Transformer meta-learning implementation
29pub struct TransformerMetaLearning<T: Float + Debug + Send + Sync + 'static> {
30    /// Meta-learning strategy
31    strategy: MetaLearningStrategy,
32
33    /// Configuration
34    config: MetaLearningConfig<T>,
35
36    /// Meta-optimizer for outer loop
37    meta_optimizer: MetaOptimizer<T>,
38
39    /// Task adaptation network
40    adaptation_network: AdaptationNetwork<T>,
41
42    /// Memory bank for storing task experiences
43    memory_bank: MemoryBank<T>,
44
45    /// Performance tracker
46    performance_tracker: PerformanceTracker<T>,
47
48    /// Current meta-learning state
49    meta_state: MetaState<T>,
50}
51
52impl<T: Float + Debug + Send + Sync + 'static + scirs2_core::ndarray::ScalarOperand>
53    TransformerMetaLearning<T>
54{
55    /// Create new transformer meta-learning component
56    pub fn new(config: &super::config::TransformerBasedOptimizerConfig<T>) -> Result<Self> {
57        let meta_config = config.meta_learning_config.clone();
58        let strategy = MetaLearningStrategy::MAML; // Default strategy
59
60        let meta_optimizer = MetaOptimizer::new(&meta_config)?;
61        let adaptation_network =
62            AdaptationNetwork::new(config.model_dimension, config.feedforward_dimension)?;
63        let memory_bank = MemoryBank::new(1000, config.model_dimension)?;
64        let performance_tracker = PerformanceTracker::new();
65        let meta_state = MetaState::new(config.model_dimension)?;
66
67        Ok(Self {
68            strategy,
69            config: meta_config,
70            meta_optimizer,
71            adaptation_network,
72            memory_bank,
73            performance_tracker,
74            meta_state,
75        })
76    }
77
78    /// Perform meta-learning step
79    pub fn meta_step(
80        &mut self,
81        tasks: &[TaskBatch<T>],
82        support_data: &[Array2<T>],
83        query_data: &[Array2<T>],
84    ) -> Result<MetaLearningResult<T>> {
85        match self.strategy {
86            MetaLearningStrategy::MAML => self.maml_step(tasks, support_data, query_data),
87            MetaLearningStrategy::FOMAML => self.fomaml_step(tasks, support_data, query_data),
88            MetaLearningStrategy::Reptile => self.reptile_step(tasks, support_data, query_data),
89            MetaLearningStrategy::GradientBased => {
90                self.gradient_based_step(tasks, support_data, query_data)
91            }
92            MetaLearningStrategy::MemoryAugmented => {
93                self.memory_augmented_step(tasks, support_data, query_data)
94            }
95        }
96    }
97
98    /// MAML meta-learning step
99    fn maml_step(
100        &mut self,
101        tasks: &[TaskBatch<T>],
102        support_data: &[Array2<T>],
103        query_data: &[Array2<T>],
104    ) -> Result<MetaLearningResult<T>> {
105        let start_time = Instant::now();
106        let mut total_loss = T::zero();
107        let mut task_adaptations = Vec::new();
108
109        for (i, task) in tasks.iter().enumerate() {
110            // Inner adaptation loop
111            let mut adapted_params = self.meta_state.get_parameters().clone();
112
113            for inner_step in 0..self.config.inner_steps {
114                // Compute gradients on support set
115                let support_loss =
116                    self.compute_task_loss(&adapted_params, &support_data[i], task)?;
117                let gradients = self.compute_gradients(&adapted_params, support_loss)?;
118
119                // Update parameters
120                for (param, grad) in adapted_params.iter_mut().zip(gradients.iter()) {
121                    *param = *param - self.config.inner_learning_rate * (*grad);
122                }
123            }
124
125            // Evaluate on query set
126            let query_loss = self.compute_task_loss(&adapted_params, &query_data[i], task)?;
127            total_loss = total_loss + query_loss;
128
129            task_adaptations.push(TaskAdaptation {
130                task_id: task.id.clone(),
131                adapted_parameters: adapted_params,
132                support_loss: self.compute_task_loss(
133                    self.meta_state.get_parameters(),
134                    &support_data[i],
135                    task,
136                )?,
137                query_loss,
138                adaptation_steps: self.config.inner_steps,
139            });
140        }
141
142        // Meta-update
143        let meta_loss = total_loss / T::from(tasks.len()).expect("unwrap failed");
144        let meta_gradients = self.compute_meta_gradients(&task_adaptations)?;
145        self.meta_optimizer
146            .update(&mut self.meta_state, &meta_gradients)?;
147
148        // Update memory bank
149        for (i, adaptation) in task_adaptations.iter().enumerate() {
150            self.memory_bank.store_experience(
151                &tasks[i],
152                &adaptation.adapted_parameters,
153                adaptation.query_loss,
154            )?;
155        }
156
157        let result = MetaLearningResult {
158            meta_loss: meta_loss.to_f64().unwrap_or(0.0),
159            task_adaptations,
160            computation_time: start_time.elapsed(),
161            convergence_rate: self.estimate_convergence_rate()?,
162        };
163
164        self.performance_tracker.record_meta_step(result.clone());
165        Ok(result)
166    }
167
168    /// First-order MAML step (FOMAML)
169    fn fomaml_step(
170        &mut self,
171        tasks: &[TaskBatch<T>],
172        support_data: &[Array2<T>],
173        query_data: &[Array2<T>],
174    ) -> Result<MetaLearningResult<T>> {
175        // Simplified version of MAML that ignores second-order derivatives
176        // Similar to MAML but uses first-order approximation for efficiency
177        self.maml_step(tasks, support_data, query_data)
178    }
179
180    /// Reptile meta-learning step
181    fn reptile_step(
182        &mut self,
183        tasks: &[TaskBatch<T>],
184        support_data: &[Array2<T>],
185        query_data: &[Array2<T>],
186    ) -> Result<MetaLearningResult<T>> {
187        let start_time = Instant::now();
188        let mut parameter_updates = Vec::new();
189        let mut total_loss = T::zero();
190
191        for (i, task) in tasks.iter().enumerate() {
192            // Adapt on support set
193            let mut adapted_params = self.meta_state.get_parameters().clone();
194
195            for _ in 0..self.config.inner_steps {
196                let loss = self.compute_task_loss(&adapted_params, &support_data[i], task)?;
197                let gradients = self.compute_gradients(&adapted_params, loss)?;
198
199                for (param, grad) in adapted_params.iter_mut().zip(gradients.iter()) {
200                    *param = *param - self.config.inner_learning_rate * (*grad);
201                }
202            }
203
204            // Compute parameter difference for meta-update
205            let original_params = self.meta_state.get_parameters();
206            let param_diff: Vec<T> = adapted_params
207                .iter()
208                .zip(original_params.iter())
209                .map(|(adapted, original)| *adapted - *original)
210                .collect();
211
212            parameter_updates.push(param_diff);
213
214            // Evaluate on query set
215            let query_loss = self.compute_task_loss(&adapted_params, &query_data[i], task)?;
216            total_loss = total_loss + query_loss;
217        }
218
219        // Meta-update: move towards average of adapted parameters
220        let mut meta_update = vec![T::zero(); self.meta_state.get_parameters().len()];
221        for param_update in &parameter_updates {
222            for (i, &update) in param_update.iter().enumerate() {
223                meta_update[i] = meta_update[i] + update;
224            }
225        }
226
227        let num_tasks = T::from(tasks.len()).expect("unwrap failed");
228        for update in meta_update.iter_mut() {
229            *update = *update / num_tasks;
230        }
231
232        self.meta_state
233            .update_parameters(&meta_update, self.config.meta_learning_rate)?;
234
235        let result = MetaLearningResult {
236            meta_loss: (total_loss / num_tasks).to_f64().unwrap_or(0.0),
237            task_adaptations: Vec::new(), // Reptile doesn't track individual adaptations
238            computation_time: start_time.elapsed(),
239            convergence_rate: self.estimate_convergence_rate()?,
240        };
241
242        self.performance_tracker.record_meta_step(result.clone());
243        Ok(result)
244    }
245
246    /// Gradient-based meta-learning step
247    fn gradient_based_step(
248        &mut self,
249        tasks: &[TaskBatch<T>],
250        support_data: &[Array2<T>],
251        query_data: &[Array2<T>],
252    ) -> Result<MetaLearningResult<T>> {
253        // Implement gradient-based meta-learning with direct optimization
254        let start_time = Instant::now();
255        let mut total_loss = T::zero();
256
257        for (i, task) in tasks.iter().enumerate() {
258            // Use adaptation network to predict good initialization
259            let context_embedding = self.adaptation_network.encode_task_context(task)?;
260            let predicted_params = self
261                .adaptation_network
262                .predict_parameters(&context_embedding)?;
263
264            // Fine-tune predicted parameters
265            let mut adapted_params = predicted_params;
266            for _ in 0..self.config.inner_steps {
267                let loss = self.compute_task_loss(&adapted_params, &support_data[i], task)?;
268                let gradients = self.compute_gradients(&adapted_params, loss)?;
269
270                for (param, grad) in adapted_params.iter_mut().zip(gradients.iter()) {
271                    *param = *param - self.config.inner_learning_rate * (*grad);
272                }
273            }
274
275            let query_loss = self.compute_task_loss(&adapted_params, &query_data[i], task)?;
276            total_loss = total_loss + query_loss;
277        }
278
279        let result = MetaLearningResult {
280            meta_loss: (total_loss / T::from(tasks.len()).expect("unwrap failed"))
281                .to_f64()
282                .unwrap_or(0.0),
283            task_adaptations: Vec::new(),
284            computation_time: start_time.elapsed(),
285            convergence_rate: self.estimate_convergence_rate()?,
286        };
287
288        Ok(result)
289    }
290
291    /// Memory-augmented meta-learning step
292    fn memory_augmented_step(
293        &mut self,
294        tasks: &[TaskBatch<T>],
295        support_data: &[Array2<T>],
296        query_data: &[Array2<T>],
297    ) -> Result<MetaLearningResult<T>> {
298        let start_time = Instant::now();
299        let mut total_loss = T::zero();
300
301        for (i, task) in tasks.iter().enumerate() {
302            // Retrieve relevant experiences from memory
303            let relevant_experiences = self.memory_bank.retrieve_similar_experiences(task, 5)?;
304
305            // Use memory to initialize adaptation
306            let memory_guided_params = self.initialize_from_memory(&relevant_experiences)?;
307
308            let mut adapted_params = memory_guided_params;
309            for _ in 0..self.config.inner_steps {
310                let loss = self.compute_task_loss(&adapted_params, &support_data[i], task)?;
311                let gradients = self.compute_gradients(&adapted_params, loss)?;
312
313                for (param, grad) in adapted_params.iter_mut().zip(gradients.iter()) {
314                    *param = *param - self.config.inner_learning_rate * (*grad);
315                }
316            }
317
318            let query_loss = self.compute_task_loss(&adapted_params, &query_data[i], task)?;
319            total_loss = total_loss + query_loss;
320
321            // Store experience
322            self.memory_bank
323                .store_experience(task, &adapted_params, query_loss)?;
324        }
325
326        let result = MetaLearningResult {
327            meta_loss: (total_loss / T::from(tasks.len()).expect("unwrap failed"))
328                .to_f64()
329                .unwrap_or(0.0),
330            task_adaptations: Vec::new(),
331            computation_time: start_time.elapsed(),
332            convergence_rate: self.estimate_convergence_rate()?,
333        };
334
335        Ok(result)
336    }
337
338    /// Generate optimization update using meta-learned parameters
339    pub fn generate_update(
340        &mut self,
341        transformer_output: &Array2<T>,
342        current_parameters: &Array1<T>,
343    ) -> Result<Array1<T>> {
344        // Use adaptation network to generate parameter updates
345        let update = self
346            .adaptation_network
347            .generate_parameter_update(transformer_output, current_parameters)?;
348
349        // Apply meta-learned scaling
350        let scaled_update = self.apply_meta_scaling(&update)?;
351
352        Ok(scaled_update)
353    }
354
355    /// Update meta-learning state from loss
356    pub fn update_from_loss(&mut self, loss: T) -> Result<()> {
357        self.meta_state.update_loss_history(loss);
358        self.performance_tracker
359            .record_loss(loss.to_f64().unwrap_or(0.0));
360        Ok(())
361    }
362
363    /// Set meta-learning strategy
364    pub fn set_strategy(&mut self, strategy: MetaLearningStrategy) {
365        self.strategy = strategy;
366    }
367
368    /// Get current strategy
369    pub fn get_strategy(&self) -> MetaLearningStrategy {
370        self.strategy
371    }
372
373    /// Helper methods
374    fn compute_task_loss(&self, params: &[T], data: &Array2<T>, task: &TaskBatch<T>) -> Result<T> {
375        // Simplified loss computation
376        let prediction_error = self.compute_prediction_error(params, data, task)?;
377        Ok(prediction_error)
378    }
379
380    fn compute_prediction_error(
381        &self,
382        _params: &[T],
383        data: &Array2<T>,
384        _task: &TaskBatch<T>,
385    ) -> Result<T> {
386        // Placeholder: compute actual prediction error
387        let mean_squared_error = data
388            .iter()
389            .map(|&x| x * x)
390            .fold(T::zero(), |acc, x| acc + x);
391        Ok(mean_squared_error / T::from(data.len()).expect("unwrap failed"))
392    }
393
394    fn compute_gradients(&self, params: &[T], loss: T) -> Result<Vec<T>> {
395        // Simplified gradient computation
396        let gradients = params
397            .iter()
398            .map(|_| loss / T::from(params.len()).expect("unwrap failed"))
399            .collect();
400        Ok(gradients)
401    }
402
403    fn compute_meta_gradients(&self, adaptations: &[TaskAdaptation<T>]) -> Result<Vec<T>> {
404        let param_count = adaptations[0].adapted_parameters.len();
405        let mut meta_gradients = vec![T::zero(); param_count];
406
407        for adaptation in adaptations {
408            for (i, &param) in adaptation.adapted_parameters.iter().enumerate() {
409                meta_gradients[i] = meta_gradients[i] + param * adaptation.query_loss;
410            }
411        }
412
413        let num_tasks = T::from(adaptations.len()).expect("unwrap failed");
414        for grad in meta_gradients.iter_mut() {
415            *grad = *grad / num_tasks;
416        }
417
418        Ok(meta_gradients)
419    }
420
421    fn estimate_convergence_rate(&self) -> Result<f64> {
422        let loss_history = self.performance_tracker.get_loss_history();
423        if loss_history.len() < 2 {
424            return Ok(0.0);
425        }
426
427        let recent_losses: Vec<_> = loss_history.iter().rev().take(5).cloned().collect();
428        let improvement = recent_losses.last().expect("unwrap failed")
429            - recent_losses.first().expect("unwrap failed");
430        Ok(improvement.clamp(0.0, 1.0))
431    }
432
433    fn apply_meta_scaling(&self, update: &Array1<T>) -> Result<Array1<T>> {
434        // Apply learned scaling factors
435        let scale_factor = self.meta_state.get_scale_factor();
436        Ok(update * scale_factor)
437    }
438
439    fn initialize_from_memory(&self, experiences: &[MemoryExperience<T>]) -> Result<Vec<T>> {
440        if experiences.is_empty() {
441            return Ok(self.meta_state.get_parameters().clone());
442        }
443
444        // Average parameters from similar experiences
445        let param_count = experiences[0].parameters.len();
446        let mut averaged_params = vec![T::zero(); param_count];
447
448        for experience in experiences {
449            for (i, &param) in experience.parameters.iter().enumerate() {
450                averaged_params[i] = averaged_params[i] + param;
451            }
452        }
453
454        let num_experiences = T::from(experiences.len()).expect("unwrap failed");
455        for param in averaged_params.iter_mut() {
456            *param = *param / num_experiences;
457        }
458
459        Ok(averaged_params)
460    }
461}
462
463/// Meta-optimizer for outer loop updates
464pub struct MetaOptimizer<T: Float + Debug + Send + Sync + 'static> {
465    /// Learning rate
466    learning_rate: T,
467
468    /// Momentum for SGD-style updates
469    momentum: Option<T>,
470
471    /// Velocity for momentum
472    velocity: Option<Vec<T>>,
473}
474
475impl<T: Float + Debug + Send + Sync + 'static> MetaOptimizer<T> {
476    pub fn new(config: &MetaLearningConfig<T>) -> Result<Self> {
477        Ok(Self {
478            learning_rate: config.meta_learning_rate,
479            momentum: None,
480            velocity: None,
481        })
482    }
483
484    pub fn update(&mut self, state: &mut MetaState<T>, gradients: &[T]) -> Result<()> {
485        let params = state.get_parameters_mut();
486
487        if let Some(momentum) = self.momentum {
488            // Momentum update
489            if self.velocity.is_none() {
490                self.velocity = Some(vec![T::zero(); params.len()]);
491            }
492
493            if let Some(ref mut velocity) = self.velocity {
494                for i in 0..params.len() {
495                    velocity[i] = momentum * velocity[i] + self.learning_rate * gradients[i];
496                    params[i] = params[i] - velocity[i];
497                }
498            }
499        } else {
500            // Simple SGD update
501            for (param, &grad) in params.iter_mut().zip(gradients.iter()) {
502                *param = *param - self.learning_rate * grad;
503            }
504        }
505
506        Ok(())
507    }
508}
509
510/// Task adaptation network
511pub struct AdaptationNetwork<T: Float + Debug + Send + Sync + 'static> {
512    /// Context encoder
513    context_encoder: FeedForwardNetwork<T>,
514
515    /// Parameter predictor
516    parameter_predictor: FeedForwardNetwork<T>,
517
518    /// Update generator
519    update_generator: FeedForwardNetwork<T>,
520
521    /// Model dimension
522    model_dimension: usize,
523}
524
525impl<T: Float + Debug + Send + Sync + 'static> AdaptationNetwork<T> {
526    pub fn new(model_dimension: usize, hidden_dimension: usize) -> Result<Self> {
527        let context_encoder =
528            FeedForwardNetwork::new(model_dimension, hidden_dimension, ActivationFunction::ReLU)?;
529
530        let parameter_predictor =
531            FeedForwardNetwork::new(hidden_dimension, model_dimension, ActivationFunction::Tanh)?;
532
533        let update_generator = FeedForwardNetwork::new(
534            model_dimension * 2, // concatenated transformer output and current params
535            model_dimension,
536            ActivationFunction::ReLU,
537        )?;
538
539        Ok(Self {
540            context_encoder,
541            parameter_predictor,
542            update_generator,
543            model_dimension,
544        })
545    }
546
547    pub fn encode_task_context(&mut self, _task: &TaskBatch<T>) -> Result<Array2<T>> {
548        // Encode task characteristics into context vector
549        let task_features = Array2::ones((1, self.model_dimension));
550        self.context_encoder.forward(&task_features)
551    }
552
553    pub fn predict_parameters(&mut self, context: &Array2<T>) -> Result<Vec<T>> {
554        let predicted = self.parameter_predictor.forward(context)?;
555        Ok(predicted.row(0).to_vec())
556    }
557
558    pub fn generate_parameter_update(
559        &mut self,
560        transformer_output: &Array2<T>,
561        current_parameters: &Array1<T>,
562    ) -> Result<Array1<T>> {
563        // Concatenate transformer output and current parameters
564        let batch_size = transformer_output.shape()[0];
565        let mut input = Array2::zeros((batch_size, self.model_dimension * 2));
566
567        for i in 0..batch_size {
568            for j in 0..self.model_dimension {
569                input[[i, j]] = transformer_output[[i, j]];
570                if j < current_parameters.len() {
571                    input[[i, j + self.model_dimension]] = current_parameters[j];
572                }
573            }
574        }
575
576        let update = self.update_generator.forward(&input)?;
577        Ok(update.row(0).to_owned())
578    }
579}
580
581/// Memory bank for storing task experiences
582pub struct MemoryBank<T: Float + Debug + Send + Sync + 'static> {
583    /// Stored experiences
584    experiences: VecDeque<MemoryExperience<T>>,
585
586    /// Maximum memory size
587    max_size: usize,
588
589    /// Dimension of stored parameters
590    parameter_dimension: usize,
591}
592
593impl<T: Float + Debug + Send + Sync + 'static> MemoryBank<T> {
594    pub fn new(max_size: usize, parameter_dimension: usize) -> Result<Self> {
595        Ok(Self {
596            experiences: VecDeque::new(),
597            max_size,
598            parameter_dimension,
599        })
600    }
601
602    pub fn store_experience(
603        &mut self,
604        task: &TaskBatch<T>,
605        parameters: &[T],
606        performance: T,
607    ) -> Result<()> {
608        let experience = MemoryExperience {
609            task_signature: self.compute_task_signature(task),
610            parameters: parameters.to_vec(),
611            performance: performance.to_f64().unwrap_or(0.0),
612            timestamp: Instant::now(),
613        };
614
615        self.experiences.push_back(experience);
616
617        if self.experiences.len() > self.max_size {
618            self.experiences.pop_front();
619        }
620
621        Ok(())
622    }
623
624    pub fn retrieve_similar_experiences(
625        &self,
626        task: &TaskBatch<T>,
627        k: usize,
628    ) -> Result<Vec<MemoryExperience<T>>> {
629        let target_signature = self.compute_task_signature(task);
630
631        let mut scored_experiences: Vec<_> = self
632            .experiences
633            .iter()
634            .map(|exp| {
635                let similarity = self.compute_similarity(&target_signature, &exp.task_signature);
636                (similarity, exp.clone())
637            })
638            .collect();
639
640        scored_experiences.sort_by(|a, b| b.0.partial_cmp(&a.0).expect("unwrap failed"));
641
642        Ok(scored_experiences
643            .into_iter()
644            .take(k)
645            .map(|(_, exp)| exp)
646            .collect())
647    }
648
649    fn compute_task_signature(&self, task: &TaskBatch<T>) -> Vec<f64> {
650        // Simplified task signature
651        vec![task.difficulty, task.complexity, task.data_characteristics]
652    }
653
654    fn compute_similarity(&self, sig1: &[f64], sig2: &[f64]) -> f64 {
655        if sig1.len() != sig2.len() {
656            return 0.0;
657        }
658
659        let dot_product: f64 = sig1.iter().zip(sig2.iter()).map(|(a, b)| a * b).sum();
660        let norm1: f64 = sig1.iter().map(|x| x * x).sum::<f64>().sqrt();
661        let norm2: f64 = sig2.iter().map(|x| x * x).sum::<f64>().sqrt();
662
663        if norm1 == 0.0 || norm2 == 0.0 {
664            0.0
665        } else {
666            dot_product / (norm1 * norm2)
667        }
668    }
669}
670
671/// Supporting data structures
672#[derive(Debug, Clone)]
673pub struct TaskBatch<T: Float + Debug + Send + Sync + 'static> {
674    pub id: String,
675    pub difficulty: f64,
676    pub complexity: f64,
677    pub data_characteristics: f64,
678    pub _phantom: std::marker::PhantomData<T>,
679}
680
681#[derive(Debug, Clone)]
682pub struct TaskAdaptation<T: Float + Debug + Send + Sync + 'static> {
683    pub task_id: String,
684    pub adapted_parameters: Vec<T>,
685    pub support_loss: T,
686    pub query_loss: T,
687    pub adaptation_steps: usize,
688}
689
690#[derive(Debug, Clone)]
691pub struct MetaLearningResult<T: Float + Debug + Send + Sync + 'static> {
692    pub meta_loss: f64,
693    pub task_adaptations: Vec<TaskAdaptation<T>>,
694    pub computation_time: std::time::Duration,
695    pub convergence_rate: f64,
696}
697
698#[derive(Debug, Clone)]
699pub struct MemoryExperience<T: Float + Debug + Send + Sync + 'static> {
700    pub task_signature: Vec<f64>,
701    pub parameters: Vec<T>,
702    pub performance: f64,
703    pub timestamp: Instant,
704}
705
706pub struct PerformanceTracker<T: Float + Debug + Send + Sync + 'static> {
707    loss_history: VecDeque<f64>,
708    meta_results: VecDeque<MetaLearningResult<T>>,
709}
710
711impl<T: Float + Debug + Send + Sync + 'static> Default for PerformanceTracker<T> {
712    fn default() -> Self {
713        Self::new()
714    }
715}
716
717impl<T: Float + Debug + Send + Sync + 'static> PerformanceTracker<T> {
718    pub fn new() -> Self {
719        Self {
720            loss_history: VecDeque::new(),
721            meta_results: VecDeque::new(),
722        }
723    }
724
725    pub fn record_loss(&mut self, loss: f64) {
726        self.loss_history.push_back(loss);
727        if self.loss_history.len() > 1000 {
728            self.loss_history.pop_front();
729        }
730    }
731
732    pub fn record_meta_step(&mut self, result: MetaLearningResult<T>) {
733        self.meta_results.push_back(result);
734        if self.meta_results.len() > 100 {
735            self.meta_results.pop_front();
736        }
737    }
738
739    pub fn get_loss_history(&self) -> &VecDeque<f64> {
740        &self.loss_history
741    }
742}
743
744#[derive(Debug, Clone)]
745pub struct MetaState<T: Float + Debug + Send + Sync + 'static> {
746    parameters: Vec<T>,
747    loss_history: VecDeque<T>,
748    scale_factor: T,
749}
750
751impl<T: Float + Debug + Send + Sync + 'static> MetaState<T> {
752    pub fn new(parameter_count: usize) -> Result<Self> {
753        Ok(Self {
754            parameters: vec![T::zero(); parameter_count],
755            loss_history: VecDeque::new(),
756            scale_factor: T::one(),
757        })
758    }
759
760    pub fn get_parameters(&self) -> &Vec<T> {
761        &self.parameters
762    }
763
764    pub fn get_parameters_mut(&mut self) -> &mut Vec<T> {
765        &mut self.parameters
766    }
767
768    pub fn update_parameters(&mut self, updates: &[T], learning_rate: T) -> Result<()> {
769        for (param, &update) in self.parameters.iter_mut().zip(updates.iter()) {
770            *param = *param + learning_rate * update;
771        }
772        Ok(())
773    }
774
775    pub fn update_loss_history(&mut self, loss: T) {
776        self.loss_history.push_back(loss);
777        if self.loss_history.len() > 100 {
778            self.loss_history.pop_front();
779        }
780    }
781
782    pub fn get_scale_factor(&self) -> T {
783        self.scale_factor
784    }
785}
786
#[cfg(test)]
mod tests {
    use super::*;

    /// Task descriptor shared by the tests below.
    fn sample_task() -> TaskBatch<f32> {
        TaskBatch {
            id: "test".to_string(),
            difficulty: 0.5,
            complexity: 0.7,
            data_characteristics: 0.3,
            _phantom: std::marker::PhantomData,
        }
    }

    #[test]
    #[ignore]
    fn test_meta_learning_creation() {
        let config = super::super::config::TransformerBasedOptimizerConfig::<f32>::default();
        let created = TransformerMetaLearning::new(&config);
        assert!(created.is_ok());
    }

    #[test]
    fn test_memory_bank() {
        let created = MemoryBank::<f32>::new(100, 64);
        assert!(created.is_ok());
        let mut bank = created.expect("unwrap failed");

        let stored_params = vec![0.1f32; 64];
        let outcome = bank.store_experience(&sample_task(), &stored_params, 0.8);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_adaptation_network() {
        let created = AdaptationNetwork::<f32>::new(128, 256);
        assert!(created.is_ok());
        let mut net = created.expect("unwrap failed");

        let encoded = net.encode_task_context(&sample_task());
        assert!(encoded.is_ok());
    }
}