// optirs_learned/transformer/training/meta_learning.rs

1// Meta-learning capabilities for transformer optimization
2//
3// This module implements meta-learning strategies that allow the transformer
4// optimizer to quickly adapt to new tasks and optimization landscapes.
5
6#[allow(dead_code)]
7use scirs2_core::ndarray::{Array1, Array2};
8use scirs2_core::numeric::Float;
9use std::collections::{HashMap, VecDeque};
10use std::fmt::Debug;
11
12use crate::error::{OptimError, Result};
13use crate::transformer::TransformerNetwork;
14
/// Meta-learning strategies selectable on [`TransformerMetaLearner`].
///
/// The variant chooses which adaptation routine `adapt_to_task` runs;
/// variants without a dedicated routine fall back to a generic path.
#[derive(Debug, Clone, Copy)]
pub enum MetaLearningStrategy {
    /// Model-Agnostic Meta-Learning (MAML): inner-loop updates on the
    /// support set, evaluated on the query set
    MAML,
    /// Reptile algorithm: repeated inner steps toward the task optimum
    Reptile,
    /// Gradient-based meta-learning (currently served by the generic fallback)
    GradientBased,
    /// Memory-augmented meta-learning (currently served by the generic fallback)
    MemoryAugmented,
    /// Task-agnostic meta-learning (currently served by the generic fallback)
    TaskAgnostic,
    /// Few-shot meta-learning via the support/query few-shot learner
    FewShot,
    /// Continual meta-learning with a catastrophic-forgetting penalty
    Continual,
}
33
/// Meta-learner that equips a transformer optimizer with task adaptation,
/// cross-domain transfer, few-shot, and continual-learning capabilities.
#[derive(Debug, Clone)]
pub struct TransformerMetaLearner<
    T: Float
        + Debug
        + Default
        + Clone
        + std::iter::Sum
        + scirs2_core::ndarray::ScalarOperand
        + Send
        + Sync
        + 'static,
> {
    /// Meta-learning strategy that selects the adaptation routine
    strategy: MetaLearningStrategy,

    /// Optional meta-transformer for higher-level learning (left `None` by `new`)
    meta_transformer: Option<TransformerNetwork<T>>,

    /// Task embeddings keyed by task identifier
    task_embeddings: HashMap<String, Array1<T>>,

    /// Meta-training history; grows unbounded and is cleared only by `reset`
    meta_history: VecDeque<MetaTrainingEvent<T>>,

    /// Domain adaptation module for cross-domain transfer
    domain_adapter: DomainAdapter<T>,

    /// Few-shot learning component (support memory, prototypes)
    few_shot_learner: FewShotLearner<T>,

    /// Continual-learning state (EWC parameters, replay buffer, importance)
    continual_learning: ContinualLearningState<T>,

    /// Hyperparameters governing meta-updates (inner steps, learning rate, …)
    meta_params: MetaLearningParams<T>,
}
71
/// A single recorded meta-training event (one adaptation episode).
#[derive(Debug, Clone)]
pub struct MetaTrainingEvent<T: Float + Debug + Send + Sync + 'static> {
    /// Kind of meta-event that produced this record
    event_type: MetaEventType,

    /// Task the event was recorded for (cloned at record time)
    task_info: TaskInfo<T>,

    /// Performance metrics measured during the event
    performance: MetaPerformanceMetrics<T>,

    /// Number of inner adaptation steps taken
    adaptation_steps: usize,

    /// Logical timestamp: the history length at record time, not wall-clock
    timestamp: usize,
}
90
/// Categories of meta-training events stored in the history.
#[derive(Debug, Clone, Copy)]
pub enum MetaEventType {
    /// Adaptation to a single new task
    TaskAdaptation,
    /// Transfer between domains
    DomainTransfer,
    /// Few-shot learning episode
    FewShotLearning,
    /// Continual-learning update
    ContinualLearning,
    /// Meta-level validation run
    MetaValidation,
}
105
/// Description of an optimization task presented to the meta-learner.
#[derive(Debug, Clone)]
pub struct TaskInfo<T: Float + Debug + Send + Sync + 'static> {
    /// Unique task identifier (used as a key into per-task maps)
    task_id: String,

    /// Quantitative characteristics of the task's loss landscape
    characteristics: TaskCharacteristics<T>,

    /// Domain the task belongs to
    domain: DomainInfo,

    /// Difficulty level; also used as the task's importance score in
    /// continual learning
    difficulty: T,

    /// Expected performance, when known in advance
    expected_performance: Option<T>,
}
124
/// Quantitative characteristics of a task's optimization landscape.
///
/// NOTE(review): value ranges (e.g. whether levels are normalized to
/// [0, 1]) are not enforced anywhere in this module — confirm with callers.
#[derive(Debug, Clone)]
pub struct TaskCharacteristics<T: Float + Debug + Send + Sync + 'static> {
    /// Problem dimensionality (number of optimized parameters, presumably)
    dimensionality: usize,

    /// Landscape complexity score
    landscape_complexity: T,

    /// Noise level of the objective
    noise_level: T,

    /// Conditioning number of the problem
    conditioning: T,

    /// Sparsity level of gradients/parameters
    sparsity: T,

    /// Strength of temporal dependencies
    temporal_dependencies: T,

    /// Pairwise feature correlation matrix
    feature_correlations: Array2<T>,
}
149
/// Metadata describing the domain a task belongs to.
#[derive(Debug, Clone)]
pub struct DomainInfo {
    /// Human-readable domain name
    name: String,

    /// Broad domain category
    domain_type: DomainType,

    /// Names of related domains (candidates for transfer)
    related_domains: Vec<String>,

    /// Free-form, domain-specific scalar features keyed by name
    features: HashMap<String, f64>,
}
165
/// Broad problem-domain categories used for transfer decisions.
#[derive(Debug, Clone, Copy)]
pub enum DomainType {
    /// Computer vision
    Vision,
    /// Natural language processing
    NLP,
    /// Reinforcement learning
    RL,
    /// Time-series modeling
    TimeSeries,
    /// Graph neural networks
    Graph,
    /// Scientific computing
    Scientific,
    /// General optimization (catch-all)
    General,
}
184
/// Performance metrics captured for one meta-training event.
///
/// In the MAML path these are populated heuristically, e.g.
/// `convergence_speed = 1 / inner_steps` and
/// `generalization = 1 / (1 + query_loss)`.
#[derive(Debug, Clone)]
pub struct MetaPerformanceMetrics<T: Float + Debug + Send + Sync + 'static> {
    /// Final loss/performance after adaptation (lower is better here)
    final_performance: T,

    /// Convergence speed estimate
    convergence_speed: T,

    /// Sample efficiency (support-set size in the MAML path)
    sample_efficiency: T,

    /// Generalization score derived from the query loss
    generalization: T,

    /// Stability measure of the adaptation
    stability: T,

    /// Resource usage estimate (inner-step count in the MAML path)
    resource_usage: T,
}
206
/// Coordinates cross-domain transfer: per-domain adapters, domain
/// similarity estimation, and transfer-efficiency bookkeeping.
#[derive(Debug, Clone)]
pub struct DomainAdapter<T: Float + Debug + Send + Sync + 'static> {
    /// Per-domain adapters keyed by domain name
    adapters: HashMap<String, DomainSpecificAdapter<T>>,

    /// Estimator of similarity between domain pairs
    similarity_estimator: DomainSimilarityEstimator<T>,

    /// Available adaptation strategies (seeded with fine-tuning)
    adaptation_strategies: Vec<AdaptationStrategy>,

    /// History and metrics of past transfers
    transfer_tracker: TransferEfficiencyTracker<T>,
}
222
/// Adapter state specialized to a single domain.
#[derive(Debug, Clone)]
pub struct DomainSpecificAdapter<T: Float + Debug + Send + Sync + 'static> {
    /// Named adapter parameter vectors
    parameters: HashMap<String, Array1<T>>,

    /// Feature vector characterizing the domain
    domain_features: Array1<T>,

    /// Chronological record of adaptation events on this domain
    adaptation_history: Vec<AdaptationEvent<T>>,

    /// Current performance achieved on the domain
    domain_performance: T,
}
238
/// Few-shot learning component: support-set memory, class prototypes,
/// and a learned distance metric.
#[derive(Debug, Clone)]
pub struct FewShotLearner<T: Float + Debug + Send + Sync + 'static> {
    /// Stored support examples keyed by task/class identifier
    support_memory: HashMap<String, Vec<Array1<T>>>,

    /// Prototype vectors keyed by task/class identifier
    prototypes: HashMap<String, Array1<T>>,

    /// Learnable distance metric for prototype matching
    distance_learner: DistanceMetricLearner<T>,

    /// Few-shot hyperparameters (support/query sizes, lr, temperature)
    adaptation_params: FewShotParams<T>,
}
254
/// State carried across tasks for continual learning and catastrophic-
/// forgetting prevention.
#[derive(Debug, Clone)]
pub struct ContinualLearningState<T: Float + Debug + Send + Sync + 'static> {
    /// Elastic Weight Consolidation anchor parameters, keyed by name
    ewc_params: HashMap<String, Array1<T>>,

    /// Fisher information matrices used to weight the EWC penalty
    fisher_information: HashMap<String, Array2<T>>,

    /// Importance score per previously seen task (task difficulty here)
    task_importance: HashMap<String, T>,

    /// Replay buffer of past gradients/performance for rehearsal
    replay_buffer: Vec<ContinualLearningEvent<T>>,

    /// Selected forgetting-prevention strategy (defaults to EWC)
    forgetting_prevention: ForgettingPreventionStrategy,
}
273
/// Hyperparameters governing meta-updates.
///
/// See the `Default` impl for the default values (lr 1e-3, 5 inner steps,
/// meta-batch 32, diversity 0.1, transfer 0.5, retention 0.95).
#[derive(Debug, Clone)]
pub struct MetaLearningParams<T: Float + Debug + Send + Sync + 'static> {
    /// Learning rate for meta-level updates
    meta_learning_rate: T,

    /// Number of inner-loop gradient steps per task
    inner_steps: usize,

    /// Number of tasks per meta-batch
    meta_batch_size: usize,

    /// Weight encouraging task diversity
    diversity_weight: T,

    /// Coefficient scaling transfer-learning contributions
    transfer_coefficient: T,

    /// Retention factor for memory decay
    memory_retention: T,
}
295
296// Additional supporting types
/// Strategies a [`DomainAdapter`] can use when transferring to a domain.
#[derive(Debug, Clone, Copy)]
pub enum AdaptationStrategy {
    /// Fine-tune all parameters on the target domain (the default)
    FineTuning,
    /// Share parameters across domains
    ParameterSharing,
    /// Adapt only selected modules
    ModularAdaptation,
    /// Adapt attention components only
    AttentionAdaptation,
}
304
/// Estimates pairwise similarity between domains for transfer decisions.
#[derive(Debug, Clone)]
pub struct DomainSimilarityEstimator<T: Float + Debug + Send + Sync + 'static> {
    /// Cached similarity scores keyed by (source, target) domain names
    similarity_matrix: HashMap<(String, String), T>,
    /// Per-domain feature-extraction matrices keyed by domain name
    feature_extractors: HashMap<String, Array2<T>>,
}
310
/// Tracks how efficiently knowledge transfers between domains.
#[derive(Debug, Clone)]
pub struct TransferEfficiencyTracker<T: Float + Debug + Send + Sync + 'static> {
    /// Chronological record of transfer events
    transfer_history: Vec<TransferEvent<T>>,
    /// Aggregated efficiency metrics keyed by name
    efficiency_metrics: HashMap<String, T>,
}
316
/// One adaptation episode recorded in a domain adapter's history.
#[derive(Debug, Clone)]
pub struct AdaptationEvent<T: Float + Debug + Send + Sync + 'static> {
    /// Logical timestamp of the event
    timestamp: usize,
    /// Loss measured during adaptation
    adaptation_loss: T,
    /// Performance improvement attributed to the adaptation
    performance_gain: T,
    /// Number of adaptation steps taken
    adaptation_steps: usize,
}
324
/// Learns a distance metric used for prototype matching in few-shot mode.
#[derive(Debug, Clone)]
pub struct DistanceMetricLearner<T: Float + Debug + Send + Sync + 'static> {
    /// Metric matrix; initialized to the identity (plain Euclidean distance)
    metric_parameters: Array2<T>,
    /// Learned similarity scores keyed by identifier
    learned_similarities: HashMap<String, T>,
}
330
/// Hyperparameters for few-shot adaptation.
///
/// Defaults (see the `Default` impl): 5 support, 15 query, lr 0.01,
/// temperature 1.0.
#[derive(Debug, Clone)]
pub struct FewShotParams<T: Float + Debug + Send + Sync + 'static> {
    /// Number of support examples per episode
    support_size: usize,
    /// Number of query examples per episode
    query_size: usize,
    /// Learning rate used during few-shot adaptation
    adaptation_lr: T,
    /// Softmax temperature for similarity scoring
    temperature: T,
}
338
/// A replay-buffer entry for continual learning.
#[derive(Debug, Clone)]
pub struct ContinualLearningEvent<T: Float + Debug + Send + Sync + 'static> {
    /// Task the gradients were collected on
    task_id: String,
    /// Stored gradient snapshot for rehearsal
    gradients: Array1<T>,
    /// Performance at the time of recording
    performance: T,
    /// Logical timestamp of the event
    timestamp: usize,
}
346
/// Catastrophic-forgetting prevention strategies from the continual-
/// learning literature.
#[derive(Debug, Clone, Copy)]
pub enum ForgettingPreventionStrategy {
    /// Elastic Weight Consolidation (quadratic penalty weighted by Fisher
    /// information; this module's default)
    EWC,
    /// PackNet-style iterative pruning and freezing
    PackNet,
    /// Progressive networks (new columns per task)
    ProgressiveNetworks,
    /// Gradient Episodic Memory
    GEM,
}
354
/// Record of one cross-domain transfer attempt.
#[derive(Debug, Clone)]
pub struct TransferEvent<T: Float + Debug + Send + Sync + 'static> {
    /// Domain knowledge was transferred from
    source_domain: String,
    /// Domain knowledge was transferred to
    target_domain: String,
    /// Performance achieved on the target after transfer
    transfer_performance: T,
    /// Time (in steps) the adaptation took
    adaptation_time: usize,
}
362
363impl<
364        T: Float
365            + Debug
366            + Send
367            + Sync
368            + 'static
369            + Default
370            + Clone
371            + std::iter::Sum
372            + scirs2_core::ndarray::ScalarOperand,
373    > TransformerMetaLearner<T>
374{
375    /// Create new meta-learner
376    pub fn new(strategy: MetaLearningStrategy) -> Result<Self> {
377        Ok(Self {
378            strategy,
379            meta_transformer: None,
380            task_embeddings: HashMap::new(),
381            meta_history: VecDeque::new(),
382            domain_adapter: DomainAdapter::new()?,
383            few_shot_learner: FewShotLearner::new()?,
384            continual_learning: ContinualLearningState::new()?,
385            meta_params: MetaLearningParams::default(),
386        })
387    }
388
389    /// Adapt to a new task
390    pub fn adapt_to_task(
391        &mut self,
392        task_info: &TaskInfo<T>,
393        support_data: &[Array1<T>],
394        query_data: &[Array1<T>],
395    ) -> Result<T> {
396        match self.strategy {
397            MetaLearningStrategy::MAML => self.maml_adaptation(task_info, support_data, query_data),
398            MetaLearningStrategy::Reptile => {
399                self.reptile_adaptation(task_info, support_data, query_data)
400            }
401            MetaLearningStrategy::FewShot => {
402                self.few_shot_adaptation(task_info, support_data, query_data)
403            }
404            MetaLearningStrategy::Continual => {
405                self.continual_adaptation(task_info, support_data, query_data)
406            }
407            _ => self.generic_adaptation(task_info, support_data, query_data),
408        }
409    }
410
411    /// MAML adaptation
412    fn maml_adaptation(
413        &mut self,
414        task_info: &TaskInfo<T>,
415        support_data: &[Array1<T>],
416        query_data: &[Array1<T>],
417    ) -> Result<T> {
418        // Simplified MAML implementation
419        let mut adaptation_loss = T::zero();
420
421        // Perform inner loop updates
422        for _ in 0..self.meta_params.inner_steps {
423            // Compute gradients on support set
424            let support_loss = self.compute_support_loss(support_data)?;
425
426            // Update parameters (simplified)
427            adaptation_loss = adaptation_loss + support_loss;
428        }
429
430        // Evaluate on query set
431        let query_loss = self.compute_query_loss(query_data)?;
432
433        // Record adaptation event
434        let event = MetaTrainingEvent {
435            event_type: MetaEventType::TaskAdaptation,
436            task_info: task_info.clone(),
437            performance: MetaPerformanceMetrics {
438                final_performance: query_loss,
439                convergence_speed: scirs2_core::numeric::NumCast::from(
440                    1.0 / self.meta_params.inner_steps as f64,
441                )
442                .unwrap_or_else(|| T::zero()),
443                sample_efficiency: T::from(support_data.len() as f64).unwrap(),
444                generalization: T::one() / (T::one() + query_loss),
445                stability: scirs2_core::numeric::NumCast::from(0.9).unwrap_or_else(|| T::zero()),
446                resource_usage: scirs2_core::numeric::NumCast::from(
447                    self.meta_params.inner_steps as f64,
448                )
449                .unwrap_or_else(|| T::zero()),
450            },
451            adaptation_steps: self.meta_params.inner_steps,
452            timestamp: self.meta_history.len(),
453        };
454
455        self.meta_history.push_back(event);
456
457        Ok(query_loss)
458    }
459
460    /// Reptile adaptation
461    fn reptile_adaptation(
462        &mut self,
463        task_info: &TaskInfo<T>,
464        support_data: &[Array1<T>],
465        _query_data: &[Array1<T>],
466    ) -> Result<T> {
467        // Simplified Reptile implementation
468        let initial_loss = self.compute_support_loss(support_data)?;
469
470        // Perform multiple gradient steps
471        let mut final_loss = initial_loss;
472        for _ in 0..self.meta_params.inner_steps {
473            final_loss =
474                final_loss * scirs2_core::numeric::NumCast::from(0.95).unwrap_or_else(|| T::zero());
475            // Simplified decay
476        }
477
478        Ok(final_loss)
479    }
480
481    /// Few-shot adaptation
482    fn few_shot_adaptation(
483        &mut self,
484        task_info: &TaskInfo<T>,
485        support_data: &[Array1<T>],
486        query_data: &[Array1<T>],
487    ) -> Result<T> {
488        self.few_shot_learner
489            .adapt(task_info, support_data, query_data)
490    }
491
492    /// Continual learning adaptation
493    fn continual_adaptation(
494        &mut self,
495        task_info: &TaskInfo<T>,
496        support_data: &[Array1<T>],
497        query_data: &[Array1<T>],
498    ) -> Result<T> {
499        // Update continual learning state
500        self.continual_learning
501            .update_for_task(task_info, support_data)?;
502
503        // Compute adaptation loss with forgetting prevention
504        let base_loss = self.compute_support_loss(support_data)?;
505        let forgetting_penalty = self.continual_learning.compute_forgetting_penalty()?;
506
507        Ok(base_loss + forgetting_penalty)
508    }
509
510    /// Generic adaptation fallback
511    fn generic_adaptation(
512        &mut self,
513        _task_info: &TaskInfo<T>,
514        support_data: &[Array1<T>],
515        query_data: &[Array1<T>],
516    ) -> Result<T> {
517        let support_loss = self.compute_support_loss(support_data)?;
518        let query_loss = self.compute_query_loss(query_data)?;
519        Ok((support_loss + query_loss)
520            / scirs2_core::numeric::NumCast::from(2.0).unwrap_or_else(|| T::zero()))
521    }
522
523    /// Compute loss on support set
524    fn compute_support_loss(&self, support_data: &[Array1<T>]) -> Result<T> {
525        if support_data.is_empty() {
526            return Ok(T::zero());
527        }
528
529        let mut total_loss = T::zero();
530        for data in support_data {
531            // Simplified loss computation
532            let loss = data.iter().map(|&x| x * x).fold(T::zero(), |a, b| a + b);
533            total_loss = total_loss + loss;
534        }
535
536        Ok(total_loss / T::from(support_data.len() as f64).unwrap())
537    }
538
539    /// Compute loss on query set
540    fn compute_query_loss(&self, query_data: &[Array1<T>]) -> Result<T> {
541        if query_data.is_empty() {
542            return Ok(T::zero());
543        }
544
545        let mut total_loss = T::zero();
546        for data in query_data {
547            // Simplified loss computation
548            let loss = data.iter().map(|&x| x * x).fold(T::zero(), |a, b| a + b);
549            total_loss = total_loss + loss;
550        }
551
552        Ok(total_loss / T::from(query_data.len() as f64).unwrap())
553    }
554
555    /// Get meta-learning statistics
556    pub fn get_meta_statistics(&self) -> HashMap<String, T> {
557        let mut stats = HashMap::new();
558
559        stats.insert(
560            "meta_events_count".to_string(),
561            T::from(self.meta_history.len() as f64).unwrap(),
562        );
563        stats.insert(
564            "task_embeddings_count".to_string(),
565            T::from(self.task_embeddings.len() as f64).unwrap(),
566        );
567
568        // Compute average performance
569        if !self.meta_history.is_empty() {
570            let avg_performance = self
571                .meta_history
572                .iter()
573                .map(|event| event.performance.final_performance)
574                .fold(T::zero(), |a, b| a + b)
575                / T::from(self.meta_history.len() as f64).unwrap();
576            stats.insert("average_performance".to_string(), avg_performance);
577        }
578
579        stats
580    }
581
582    /// Update meta-parameters
583    pub fn update_meta_parameters(&mut self, params: MetaLearningParams<T>) {
584        self.meta_params = params;
585    }
586
587    /// Get domain adapter
588    pub fn domain_adapter(&self) -> &DomainAdapter<T> {
589        &self.domain_adapter
590    }
591
592    /// Reset meta-learner state
593    pub fn reset(&mut self) {
594        self.task_embeddings.clear();
595        self.meta_history.clear();
596        self.domain_adapter.reset();
597        self.few_shot_learner.reset();
598        self.continual_learning.reset();
599    }
600}
601
602// Implementation for supporting types
603impl<
604        T: Float
605            + Debug
606            + Send
607            + Sync
608            + 'static
609            + Default
610            + Clone
611            + std::iter::Sum
612            + scirs2_core::ndarray::ScalarOperand,
613    > DomainAdapter<T>
614{
615    fn new() -> Result<Self> {
616        Ok(Self {
617            adapters: HashMap::new(),
618            similarity_estimator: DomainSimilarityEstimator::new()?,
619            adaptation_strategies: vec![AdaptationStrategy::FineTuning],
620            transfer_tracker: TransferEfficiencyTracker::new()?,
621        })
622    }
623
624    fn reset(&mut self) {
625        self.adapters.clear();
626    }
627}
628
629impl<
630        T: Float
631            + Debug
632            + Send
633            + Sync
634            + 'static
635            + Default
636            + Clone
637            + std::iter::Sum
638            + scirs2_core::ndarray::ScalarOperand,
639    > FewShotLearner<T>
640{
641    fn new() -> Result<Self> {
642        Ok(Self {
643            support_memory: HashMap::new(),
644            prototypes: HashMap::new(),
645            distance_learner: DistanceMetricLearner::new()?,
646            adaptation_params: FewShotParams::default(),
647        })
648    }
649
650    fn adapt(
651        &mut self,
652        _task_info: &TaskInfo<T>,
653        support_data: &[Array1<T>],
654        query_data: &[Array1<T>],
655    ) -> Result<T> {
656        // Simplified few-shot adaptation
657        let support_loss = support_data
658            .iter()
659            .map(|x| x.iter().map(|&v| v * v).fold(T::zero(), |a, b| a + b))
660            .fold(T::zero(), |a, b| a + b);
661        let query_loss = query_data
662            .iter()
663            .map(|x| x.iter().map(|&v| v * v).fold(T::zero(), |a, b| a + b))
664            .fold(T::zero(), |a, b| a + b);
665
666        Ok((support_loss + query_loss)
667            / T::from((support_data.len() + query_data.len()) as f64).unwrap())
668    }
669
670    fn reset(&mut self) {
671        self.support_memory.clear();
672        self.prototypes.clear();
673    }
674}
675
676impl<
677        T: Float
678            + Debug
679            + Send
680            + Sync
681            + 'static
682            + Default
683            + Clone
684            + std::iter::Sum
685            + scirs2_core::ndarray::ScalarOperand,
686    > ContinualLearningState<T>
687{
688    fn new() -> Result<Self> {
689        Ok(Self {
690            ewc_params: HashMap::new(),
691            fisher_information: HashMap::new(),
692            task_importance: HashMap::new(),
693            replay_buffer: Vec::new(),
694            forgetting_prevention: ForgettingPreventionStrategy::EWC,
695        })
696    }
697
698    fn update_for_task(
699        &mut self,
700        task_info: &TaskInfo<T>,
701        _support_data: &[Array1<T>],
702    ) -> Result<()> {
703        self.task_importance
704            .insert(task_info.task_id.clone(), task_info.difficulty);
705        Ok(())
706    }
707
708    fn compute_forgetting_penalty(&self) -> Result<T> {
709        // Simplified forgetting penalty
710        Ok(scirs2_core::numeric::NumCast::from(0.01).unwrap_or_else(|| T::zero()))
711    }
712
713    fn reset(&mut self) {
714        self.ewc_params.clear();
715        self.fisher_information.clear();
716        self.task_importance.clear();
717        self.replay_buffer.clear();
718    }
719}
720
721impl<
722        T: Float
723            + Debug
724            + Send
725            + Sync
726            + 'static
727            + Default
728            + Clone
729            + std::iter::Sum
730            + scirs2_core::ndarray::ScalarOperand,
731    > DomainSimilarityEstimator<T>
732{
733    fn new() -> Result<Self> {
734        Ok(Self {
735            similarity_matrix: HashMap::new(),
736            feature_extractors: HashMap::new(),
737        })
738    }
739}
740
741impl<
742        T: Float
743            + Debug
744            + Send
745            + Sync
746            + 'static
747            + Default
748            + Clone
749            + std::iter::Sum
750            + scirs2_core::ndarray::ScalarOperand,
751    > TransferEfficiencyTracker<T>
752{
753    fn new() -> Result<Self> {
754        Ok(Self {
755            transfer_history: Vec::new(),
756            efficiency_metrics: HashMap::new(),
757        })
758    }
759}
760
761impl<
762        T: Float
763            + Debug
764            + Send
765            + Sync
766            + 'static
767            + Default
768            + Clone
769            + std::iter::Sum
770            + scirs2_core::ndarray::ScalarOperand,
771    > DistanceMetricLearner<T>
772{
773    fn new() -> Result<Self> {
774        Ok(Self {
775            metric_parameters: Array2::eye(10), // Default 10x10 identity matrix
776            learned_similarities: HashMap::new(),
777        })
778    }
779}
780
781impl<
782        T: Float
783            + Debug
784            + Send
785            + Sync
786            + 'static
787            + Default
788            + Clone
789            + std::iter::Sum
790            + scirs2_core::ndarray::ScalarOperand,
791    > Default for MetaLearningParams<T>
792{
793    fn default() -> Self {
794        Self {
795            meta_learning_rate: scirs2_core::numeric::NumCast::from(0.001)
796                .unwrap_or_else(|| T::zero()),
797            inner_steps: 5,
798            meta_batch_size: 32,
799            diversity_weight: scirs2_core::numeric::NumCast::from(0.1).unwrap_or_else(|| T::zero()),
800            transfer_coefficient: scirs2_core::numeric::NumCast::from(0.5)
801                .unwrap_or_else(|| T::zero()),
802            memory_retention: scirs2_core::numeric::NumCast::from(0.95)
803                .unwrap_or_else(|| T::zero()),
804        }
805    }
806}
807
808impl<
809        T: Float
810            + Debug
811            + Send
812            + Sync
813            + 'static
814            + Default
815            + Clone
816            + std::iter::Sum
817            + scirs2_core::ndarray::ScalarOperand,
818    > Default for FewShotParams<T>
819{
820    fn default() -> Self {
821        Self {
822            support_size: 5,
823            query_size: 15,
824            adaptation_lr: scirs2_core::numeric::NumCast::from(0.01).unwrap_or_else(|| T::zero()),
825            temperature: scirs2_core::numeric::NumCast::from(1.0).unwrap_or_else(|| T::zero()),
826        }
827    }
828}