1#[allow(dead_code)]
7use scirs2_core::ndarray::{Array1, Array2};
8use scirs2_core::numeric::Float;
9use std::collections::{HashMap, VecDeque};
10use std::fmt::Debug;
11
12use crate::error::{OptimError, Result};
13use crate::transformer::TransformerNetwork;
14
/// Strategy the meta-learner uses when adapting to a new task.
///
/// `PartialEq`/`Eq`/`Hash` are derived so strategies can be compared directly
/// and used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MetaLearningStrategy {
    /// Model-Agnostic Meta-Learning (inner-/outer-loop adaptation).
    MAML,
    /// Reptile-style first-order meta-learning.
    Reptile,
    /// Gradient-based meta-learning (currently served by the generic path).
    GradientBased,
    /// Memory-augmented meta-learning (currently served by the generic path).
    MemoryAugmented,
    /// Task-agnostic meta-learning (currently served by the generic path).
    TaskAgnostic,
    /// Few-shot (support/query episode) adaptation.
    FewShot,
    /// Continual learning with forgetting prevention.
    Continual,
}
33
/// Transformer-based meta-learner coordinating adaptation across tasks.
///
/// Holds the selected [`MetaLearningStrategy`] together with the per-strategy
/// state (domain adaptation, few-shot learning, continual learning), learned
/// task embeddings, and a chronological history of meta-training events.
#[derive(Debug, Clone)]
pub struct TransformerMetaLearner<
    T: Float
        + Debug
        + Default
        + Clone
        + std::iter::Sum
        + scirs2_core::ndarray::ScalarOperand
        + Send
        + Sync
        + 'static,
> {
    /// Strategy dispatched on by `adapt_to_task`.
    strategy: MetaLearningStrategy,

    /// Optional transformer network backing the meta-learner
    /// (`None` until one is attached).
    meta_transformer: Option<TransformerNetwork<T>>,

    /// Learned embedding vector per task id.
    task_embeddings: HashMap<String, Array1<T>>,

    /// Chronological record of meta-training events (front = oldest).
    meta_history: VecDeque<MetaTrainingEvent<T>>,

    /// Cross-domain adaptation state.
    domain_adapter: DomainAdapter<T>,

    /// Few-shot (support/query) adaptation state.
    few_shot_learner: FewShotLearner<T>,

    /// Continual-learning state, including forgetting prevention.
    continual_learning: ContinualLearningState<T>,

    /// Hyperparameters controlling meta-learning behaviour.
    meta_params: MetaLearningParams<T>,
}
71
/// A single recorded meta-training event (one adaptation episode).
#[derive(Debug, Clone)]
pub struct MetaTrainingEvent<T: Float + Debug + Send + Sync + 'static> {
    /// Category of the meta-level event.
    event_type: MetaEventType,

    /// Description of the task this event concerned.
    task_info: TaskInfo<T>,

    /// Performance metrics measured for this event.
    performance: MetaPerformanceMetrics<T>,

    /// Number of inner adaptation steps taken.
    adaptation_steps: usize,

    /// Logical timestamp — the meta-history length at insertion time.
    timestamp: usize,
}
90
/// Category of a recorded meta-training event.
///
/// `PartialEq`/`Eq`/`Hash` are derived so event types can be compared and
/// used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MetaEventType {
    /// Adaptation to a single task.
    TaskAdaptation,
    /// Knowledge transfer between domains.
    DomainTransfer,
    /// Few-shot learning episode.
    FewShotLearning,
    /// Continual-learning update.
    ContinualLearning,
    /// Meta-level validation run.
    MetaValidation,
}
105
/// Metadata describing an optimization task presented to the meta-learner.
#[derive(Debug, Clone)]
pub struct TaskInfo<T: Float + Debug + Send + Sync + 'static> {
    /// Unique identifier for the task.
    task_id: String,

    /// Quantitative characteristics of the task's landscape.
    characteristics: TaskCharacteristics<T>,

    /// Domain the task belongs to.
    domain: DomainInfo,

    /// Estimated task difficulty; used as the task's importance weight by
    /// the continual-learning state.
    difficulty: T,

    /// Expected performance, when known in advance.
    expected_performance: Option<T>,
}
124
/// Quantitative characteristics of a task's optimization landscape.
#[derive(Debug, Clone)]
pub struct TaskCharacteristics<T: Float + Debug + Send + Sync + 'static> {
    /// Number of problem dimensions.
    dimensionality: usize,

    /// Complexity measure of the loss landscape.
    landscape_complexity: T,

    /// Estimated noise level of evaluations.
    noise_level: T,

    /// Conditioning measure (presumably a condition-number proxy — confirm).
    conditioning: T,

    /// Degree of sparsity in the problem.
    sparsity: T,

    /// Strength of temporal dependencies.
    temporal_dependencies: T,

    /// Pairwise feature-correlation matrix.
    feature_correlations: Array2<T>,
}
149
/// Identifies and describes a problem domain.
#[derive(Debug, Clone)]
pub struct DomainInfo {
    /// Human-readable domain name.
    name: String,

    /// Coarse domain category.
    domain_type: DomainType,

    /// Names of related domains (candidates for transfer).
    related_domains: Vec<String>,

    /// Free-form numeric features describing the domain.
    features: HashMap<String, f64>,
}
165
/// Coarse category of a problem domain.
///
/// `PartialEq`/`Eq`/`Hash` are derived so domain types can be compared and
/// used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DomainType {
    /// Computer vision.
    Vision,
    /// Natural language processing.
    NLP,
    /// Reinforcement learning.
    RL,
    /// Time-series problems.
    TimeSeries,
    /// Graph-structured problems.
    Graph,
    /// Scientific computing.
    Scientific,
    /// Catch-all for domains not listed above.
    General,
}
184
/// Performance metrics collected for one meta-training event.
#[derive(Debug, Clone)]
pub struct MetaPerformanceMetrics<T: Float + Debug + Send + Sync + 'static> {
    /// Final loss after adaptation (the MAML path stores the query loss).
    final_performance: T,

    /// Convergence-speed proxy (the MAML path stores 1 / inner_steps).
    convergence_speed: T,

    /// Sample-efficiency proxy (the MAML path stores the support-set size).
    sample_efficiency: T,

    /// Generalization proxy (the MAML path stores 1 / (1 + query loss)).
    generalization: T,

    /// Stability estimate of the adaptation.
    stability: T,

    /// Resource-usage proxy (the MAML path stores the inner-step count).
    resource_usage: T,
}
206
/// Coordinates adaptation of the meta-learner across problem domains.
#[derive(Debug, Clone)]
pub struct DomainAdapter<T: Float + Debug + Send + Sync + 'static> {
    /// Per-domain adapter state, keyed by domain name.
    adapters: HashMap<String, DomainSpecificAdapter<T>>,

    /// Estimates pairwise similarity between domains.
    similarity_estimator: DomainSimilarityEstimator<T>,

    /// Adaptation strategies available to this adapter.
    adaptation_strategies: Vec<AdaptationStrategy>,

    /// Tracks how efficiently knowledge transfers between domains.
    transfer_tracker: TransferEfficiencyTracker<T>,
}
222
/// Adapter state specialized to a single domain.
#[derive(Debug, Clone)]
pub struct DomainSpecificAdapter<T: Float + Debug + Send + Sync + 'static> {
    /// Named adapter parameter vectors.
    parameters: HashMap<String, Array1<T>>,

    /// Feature vector describing the domain.
    domain_features: Array1<T>,

    /// Past adaptation episodes for this domain.
    adaptation_history: Vec<AdaptationEvent<T>>,

    /// Latest measured performance in this domain.
    domain_performance: T,
}
238
/// Few-shot learner holding stored support examples and class prototypes.
#[derive(Debug, Clone)]
pub struct FewShotLearner<T: Float + Debug + Send + Sync + 'static> {
    /// Stored support examples, keyed by identifier.
    support_memory: HashMap<String, Vec<Array1<T>>>,

    /// Prototype vector per identifier.
    prototypes: HashMap<String, Array1<T>>,

    /// Learned distance metric used for matching.
    distance_learner: DistanceMetricLearner<T>,

    /// Hyperparameters for few-shot episodes.
    adaptation_params: FewShotParams<T>,
}
254
/// State supporting continual learning over a sequence of tasks.
#[derive(Debug, Clone)]
pub struct ContinualLearningState<T: Float + Debug + Send + Sync + 'static> {
    /// EWC anchor parameter vectors, keyed by parameter name.
    ewc_params: HashMap<String, Array1<T>>,

    /// Fisher-information matrices, keyed by parameter name.
    fisher_information: HashMap<String, Array2<T>>,

    /// Importance weight per task id (set from the task's difficulty).
    task_importance: HashMap<String, T>,

    /// Replay buffer of past learning events.
    replay_buffer: Vec<ContinualLearningEvent<T>>,

    /// Strategy used to mitigate catastrophic forgetting.
    forgetting_prevention: ForgettingPreventionStrategy,
}
273
/// Hyperparameters controlling the meta-learning process.
#[derive(Debug, Clone)]
pub struct MetaLearningParams<T: Float + Debug + Send + Sync + 'static> {
    /// Outer-loop (meta) learning rate.
    meta_learning_rate: T,

    /// Number of inner-loop adaptation steps per task.
    inner_steps: usize,

    /// Number of tasks per meta-batch.
    meta_batch_size: usize,

    /// Weight given to task-diversity terms — presumably in a meta-objective;
    /// confirm against callers.
    diversity_weight: T,

    /// Coefficient scaling cross-task transfer.
    transfer_coefficient: T,

    /// Fraction of memory retained between updates (default 0.95).
    memory_retention: T,
}
295
/// How a domain adapter modifies the model for a new domain.
///
/// `PartialEq`/`Eq`/`Hash` are derived so strategies can be compared and
/// used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AdaptationStrategy {
    /// Fine-tune existing parameters on the target domain.
    FineTuning,
    /// Share parameters across domains.
    ParameterSharing,
    /// Adapt modular components independently.
    ModularAdaptation,
    /// Adapt attention components specifically.
    AttentionAdaptation,
}
304
/// Estimates pairwise similarity between problem domains.
#[derive(Debug, Clone)]
pub struct DomainSimilarityEstimator<T: Float + Debug + Send + Sync + 'static> {
    /// Similarity score per (source, target) domain-name pair.
    similarity_matrix: HashMap<(String, String), T>,
    /// Per-domain feature-extraction matrices.
    feature_extractors: HashMap<String, Array2<T>>,
}
310
/// Tracks how efficiently knowledge transfers between domains.
#[derive(Debug, Clone)]
pub struct TransferEfficiencyTracker<T: Float + Debug + Send + Sync + 'static> {
    /// Chronological record of transfer attempts.
    transfer_history: Vec<TransferEvent<T>>,
    /// Aggregated efficiency metrics, keyed by metric name.
    efficiency_metrics: HashMap<String, T>,
}
316
/// One recorded adaptation episode for a domain-specific adapter.
#[derive(Debug, Clone)]
pub struct AdaptationEvent<T: Float + Debug + Send + Sync + 'static> {
    /// Logical timestamp of the adaptation.
    timestamp: usize,
    /// Loss measured during adaptation.
    adaptation_loss: T,
    /// Performance improvement attributed to the adaptation.
    performance_gain: T,
    /// Number of adaptation steps taken.
    adaptation_steps: usize,
}
324
/// Learns a distance metric for few-shot prototype matching.
#[derive(Debug, Clone)]
pub struct DistanceMetricLearner<T: Float + Debug + Send + Sync + 'static> {
    /// Parameter matrix of the learned metric (starts as the identity).
    metric_parameters: Array2<T>,
    /// Cached similarity scores, keyed by identifier.
    learned_similarities: HashMap<String, T>,
}
330
/// Hyperparameters for few-shot adaptation episodes.
#[derive(Debug, Clone)]
pub struct FewShotParams<T: Float + Debug + Send + Sync + 'static> {
    /// Number of support examples per episode.
    support_size: usize,
    /// Number of query examples per episode.
    query_size: usize,
    /// Learning rate used during few-shot adaptation.
    adaptation_lr: T,
    /// Temperature parameter (presumably for similarity scaling — confirm).
    temperature: T,
}
338
/// Replay-buffer entry recorded during continual learning.
#[derive(Debug, Clone)]
pub struct ContinualLearningEvent<T: Float + Debug + Send + Sync + 'static> {
    /// Task the gradients were recorded for.
    task_id: String,
    /// Gradient snapshot stored for replay.
    gradients: Array1<T>,
    /// Performance achieved on the task.
    performance: T,
    /// Logical timestamp of the event.
    timestamp: usize,
}
346
/// Technique used to mitigate catastrophic forgetting in continual learning.
///
/// `PartialEq`/`Eq`/`Hash` are derived so strategies can be compared and
/// used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ForgettingPreventionStrategy {
    /// Elastic Weight Consolidation.
    EWC,
    /// PackNet-style parameter packing.
    PackNet,
    /// Progressive neural networks.
    ProgressiveNetworks,
    /// Gradient Episodic Memory.
    GEM,
}
354
/// One recorded cross-domain transfer attempt.
#[derive(Debug, Clone)]
pub struct TransferEvent<T: Float + Debug + Send + Sync + 'static> {
    /// Domain knowledge was transferred from.
    source_domain: String,
    /// Domain knowledge was transferred to.
    target_domain: String,
    /// Performance achieved after the transfer.
    transfer_performance: T,
    /// Adaptation duration (unit not specified in this file).
    adaptation_time: usize,
}
362
363impl<
364 T: Float
365 + Debug
366 + Send
367 + Sync
368 + 'static
369 + Default
370 + Clone
371 + std::iter::Sum
372 + scirs2_core::ndarray::ScalarOperand,
373 > TransformerMetaLearner<T>
374{
375 pub fn new(strategy: MetaLearningStrategy) -> Result<Self> {
377 Ok(Self {
378 strategy,
379 meta_transformer: None,
380 task_embeddings: HashMap::new(),
381 meta_history: VecDeque::new(),
382 domain_adapter: DomainAdapter::new()?,
383 few_shot_learner: FewShotLearner::new()?,
384 continual_learning: ContinualLearningState::new()?,
385 meta_params: MetaLearningParams::default(),
386 })
387 }
388
389 pub fn adapt_to_task(
391 &mut self,
392 task_info: &TaskInfo<T>,
393 support_data: &[Array1<T>],
394 query_data: &[Array1<T>],
395 ) -> Result<T> {
396 match self.strategy {
397 MetaLearningStrategy::MAML => self.maml_adaptation(task_info, support_data, query_data),
398 MetaLearningStrategy::Reptile => {
399 self.reptile_adaptation(task_info, support_data, query_data)
400 }
401 MetaLearningStrategy::FewShot => {
402 self.few_shot_adaptation(task_info, support_data, query_data)
403 }
404 MetaLearningStrategy::Continual => {
405 self.continual_adaptation(task_info, support_data, query_data)
406 }
407 _ => self.generic_adaptation(task_info, support_data, query_data),
408 }
409 }
410
411 fn maml_adaptation(
413 &mut self,
414 task_info: &TaskInfo<T>,
415 support_data: &[Array1<T>],
416 query_data: &[Array1<T>],
417 ) -> Result<T> {
418 let mut adaptation_loss = T::zero();
420
421 for _ in 0..self.meta_params.inner_steps {
423 let support_loss = self.compute_support_loss(support_data)?;
425
426 adaptation_loss = adaptation_loss + support_loss;
428 }
429
430 let query_loss = self.compute_query_loss(query_data)?;
432
433 let event = MetaTrainingEvent {
435 event_type: MetaEventType::TaskAdaptation,
436 task_info: task_info.clone(),
437 performance: MetaPerformanceMetrics {
438 final_performance: query_loss,
439 convergence_speed: scirs2_core::numeric::NumCast::from(
440 1.0 / self.meta_params.inner_steps as f64,
441 )
442 .unwrap_or_else(|| T::zero()),
443 sample_efficiency: T::from(support_data.len() as f64).expect("unwrap failed"),
444 generalization: T::one() / (T::one() + query_loss),
445 stability: scirs2_core::numeric::NumCast::from(0.9).unwrap_or_else(|| T::zero()),
446 resource_usage: scirs2_core::numeric::NumCast::from(
447 self.meta_params.inner_steps as f64,
448 )
449 .unwrap_or_else(|| T::zero()),
450 },
451 adaptation_steps: self.meta_params.inner_steps,
452 timestamp: self.meta_history.len(),
453 };
454
455 self.meta_history.push_back(event);
456
457 Ok(query_loss)
458 }
459
460 fn reptile_adaptation(
462 &mut self,
463 task_info: &TaskInfo<T>,
464 support_data: &[Array1<T>],
465 _query_data: &[Array1<T>],
466 ) -> Result<T> {
467 let initial_loss = self.compute_support_loss(support_data)?;
469
470 let mut final_loss = initial_loss;
472 for _ in 0..self.meta_params.inner_steps {
473 final_loss =
474 final_loss * scirs2_core::numeric::NumCast::from(0.95).unwrap_or_else(|| T::zero());
475 }
477
478 Ok(final_loss)
479 }
480
481 fn few_shot_adaptation(
483 &mut self,
484 task_info: &TaskInfo<T>,
485 support_data: &[Array1<T>],
486 query_data: &[Array1<T>],
487 ) -> Result<T> {
488 self.few_shot_learner
489 .adapt(task_info, support_data, query_data)
490 }
491
492 fn continual_adaptation(
494 &mut self,
495 task_info: &TaskInfo<T>,
496 support_data: &[Array1<T>],
497 query_data: &[Array1<T>],
498 ) -> Result<T> {
499 self.continual_learning
501 .update_for_task(task_info, support_data)?;
502
503 let base_loss = self.compute_support_loss(support_data)?;
505 let forgetting_penalty = self.continual_learning.compute_forgetting_penalty()?;
506
507 Ok(base_loss + forgetting_penalty)
508 }
509
510 fn generic_adaptation(
512 &mut self,
513 _task_info: &TaskInfo<T>,
514 support_data: &[Array1<T>],
515 query_data: &[Array1<T>],
516 ) -> Result<T> {
517 let support_loss = self.compute_support_loss(support_data)?;
518 let query_loss = self.compute_query_loss(query_data)?;
519 Ok((support_loss + query_loss)
520 / scirs2_core::numeric::NumCast::from(2.0).unwrap_or_else(|| T::zero()))
521 }
522
523 fn compute_support_loss(&self, support_data: &[Array1<T>]) -> Result<T> {
525 if support_data.is_empty() {
526 return Ok(T::zero());
527 }
528
529 let mut total_loss = T::zero();
530 for data in support_data {
531 let loss = data.iter().map(|&x| x * x).fold(T::zero(), |a, b| a + b);
533 total_loss = total_loss + loss;
534 }
535
536 Ok(total_loss / T::from(support_data.len() as f64).expect("unwrap failed"))
537 }
538
539 fn compute_query_loss(&self, query_data: &[Array1<T>]) -> Result<T> {
541 if query_data.is_empty() {
542 return Ok(T::zero());
543 }
544
545 let mut total_loss = T::zero();
546 for data in query_data {
547 let loss = data.iter().map(|&x| x * x).fold(T::zero(), |a, b| a + b);
549 total_loss = total_loss + loss;
550 }
551
552 Ok(total_loss / T::from(query_data.len() as f64).expect("unwrap failed"))
553 }
554
555 pub fn get_meta_statistics(&self) -> HashMap<String, T> {
557 let mut stats = HashMap::new();
558
559 stats.insert(
560 "meta_events_count".to_string(),
561 T::from(self.meta_history.len() as f64).expect("unwrap failed"),
562 );
563 stats.insert(
564 "task_embeddings_count".to_string(),
565 T::from(self.task_embeddings.len() as f64).expect("unwrap failed"),
566 );
567
568 if !self.meta_history.is_empty() {
570 let avg_performance = self
571 .meta_history
572 .iter()
573 .map(|event| event.performance.final_performance)
574 .fold(T::zero(), |a, b| a + b)
575 / T::from(self.meta_history.len() as f64).expect("unwrap failed");
576 stats.insert("average_performance".to_string(), avg_performance);
577 }
578
579 stats
580 }
581
582 pub fn update_meta_parameters(&mut self, params: MetaLearningParams<T>) {
584 self.meta_params = params;
585 }
586
587 pub fn domain_adapter(&self) -> &DomainAdapter<T> {
589 &self.domain_adapter
590 }
591
592 pub fn reset(&mut self) {
594 self.task_embeddings.clear();
595 self.meta_history.clear();
596 self.domain_adapter.reset();
597 self.few_shot_learner.reset();
598 self.continual_learning.reset();
599 }
600}
601
602impl<
604 T: Float
605 + Debug
606 + Send
607 + Sync
608 + 'static
609 + Default
610 + Clone
611 + std::iter::Sum
612 + scirs2_core::ndarray::ScalarOperand,
613 > DomainAdapter<T>
614{
615 fn new() -> Result<Self> {
616 Ok(Self {
617 adapters: HashMap::new(),
618 similarity_estimator: DomainSimilarityEstimator::new()?,
619 adaptation_strategies: vec![AdaptationStrategy::FineTuning],
620 transfer_tracker: TransferEfficiencyTracker::new()?,
621 })
622 }
623
624 fn reset(&mut self) {
625 self.adapters.clear();
626 }
627}
628
629impl<
630 T: Float
631 + Debug
632 + Send
633 + Sync
634 + 'static
635 + Default
636 + Clone
637 + std::iter::Sum
638 + scirs2_core::ndarray::ScalarOperand,
639 > FewShotLearner<T>
640{
641 fn new() -> Result<Self> {
642 Ok(Self {
643 support_memory: HashMap::new(),
644 prototypes: HashMap::new(),
645 distance_learner: DistanceMetricLearner::new()?,
646 adaptation_params: FewShotParams::default(),
647 })
648 }
649
650 fn adapt(
651 &mut self,
652 _task_info: &TaskInfo<T>,
653 support_data: &[Array1<T>],
654 query_data: &[Array1<T>],
655 ) -> Result<T> {
656 let support_loss = support_data
658 .iter()
659 .map(|x| x.iter().map(|&v| v * v).fold(T::zero(), |a, b| a + b))
660 .fold(T::zero(), |a, b| a + b);
661 let query_loss = query_data
662 .iter()
663 .map(|x| x.iter().map(|&v| v * v).fold(T::zero(), |a, b| a + b))
664 .fold(T::zero(), |a, b| a + b);
665
666 Ok((support_loss + query_loss)
667 / T::from((support_data.len() + query_data.len()) as f64).expect("unwrap failed"))
668 }
669
670 fn reset(&mut self) {
671 self.support_memory.clear();
672 self.prototypes.clear();
673 }
674}
675
676impl<
677 T: Float
678 + Debug
679 + Send
680 + Sync
681 + 'static
682 + Default
683 + Clone
684 + std::iter::Sum
685 + scirs2_core::ndarray::ScalarOperand,
686 > ContinualLearningState<T>
687{
688 fn new() -> Result<Self> {
689 Ok(Self {
690 ewc_params: HashMap::new(),
691 fisher_information: HashMap::new(),
692 task_importance: HashMap::new(),
693 replay_buffer: Vec::new(),
694 forgetting_prevention: ForgettingPreventionStrategy::EWC,
695 })
696 }
697
698 fn update_for_task(
699 &mut self,
700 task_info: &TaskInfo<T>,
701 _support_data: &[Array1<T>],
702 ) -> Result<()> {
703 self.task_importance
704 .insert(task_info.task_id.clone(), task_info.difficulty);
705 Ok(())
706 }
707
708 fn compute_forgetting_penalty(&self) -> Result<T> {
709 Ok(scirs2_core::numeric::NumCast::from(0.01).unwrap_or_else(|| T::zero()))
711 }
712
713 fn reset(&mut self) {
714 self.ewc_params.clear();
715 self.fisher_information.clear();
716 self.task_importance.clear();
717 self.replay_buffer.clear();
718 }
719}
720
721impl<
722 T: Float
723 + Debug
724 + Send
725 + Sync
726 + 'static
727 + Default
728 + Clone
729 + std::iter::Sum
730 + scirs2_core::ndarray::ScalarOperand,
731 > DomainSimilarityEstimator<T>
732{
733 fn new() -> Result<Self> {
734 Ok(Self {
735 similarity_matrix: HashMap::new(),
736 feature_extractors: HashMap::new(),
737 })
738 }
739}
740
741impl<
742 T: Float
743 + Debug
744 + Send
745 + Sync
746 + 'static
747 + Default
748 + Clone
749 + std::iter::Sum
750 + scirs2_core::ndarray::ScalarOperand,
751 > TransferEfficiencyTracker<T>
752{
753 fn new() -> Result<Self> {
754 Ok(Self {
755 transfer_history: Vec::new(),
756 efficiency_metrics: HashMap::new(),
757 })
758 }
759}
760
761impl<
762 T: Float
763 + Debug
764 + Send
765 + Sync
766 + 'static
767 + Default
768 + Clone
769 + std::iter::Sum
770 + scirs2_core::ndarray::ScalarOperand,
771 > DistanceMetricLearner<T>
772{
773 fn new() -> Result<Self> {
774 Ok(Self {
775 metric_parameters: Array2::eye(10), learned_similarities: HashMap::new(),
777 })
778 }
779}
780
781impl<
782 T: Float
783 + Debug
784 + Send
785 + Sync
786 + 'static
787 + Default
788 + Clone
789 + std::iter::Sum
790 + scirs2_core::ndarray::ScalarOperand,
791 > Default for MetaLearningParams<T>
792{
793 fn default() -> Self {
794 Self {
795 meta_learning_rate: scirs2_core::numeric::NumCast::from(0.001)
796 .unwrap_or_else(|| T::zero()),
797 inner_steps: 5,
798 meta_batch_size: 32,
799 diversity_weight: scirs2_core::numeric::NumCast::from(0.1).unwrap_or_else(|| T::zero()),
800 transfer_coefficient: scirs2_core::numeric::NumCast::from(0.5)
801 .unwrap_or_else(|| T::zero()),
802 memory_retention: scirs2_core::numeric::NumCast::from(0.95)
803 .unwrap_or_else(|| T::zero()),
804 }
805 }
806}
807
808impl<
809 T: Float
810 + Debug
811 + Send
812 + Sync
813 + 'static
814 + Default
815 + Clone
816 + std::iter::Sum
817 + scirs2_core::ndarray::ScalarOperand,
818 > Default for FewShotParams<T>
819{
820 fn default() -> Self {
821 Self {
822 support_size: 5,
823 query_size: 15,
824 adaptation_lr: scirs2_core::numeric::NumCast::from(0.01).unwrap_or_else(|| T::zero()),
825 temperature: scirs2_core::numeric::NumCast::from(1.0).unwrap_or_else(|| T::zero()),
826 }
827 }
828}