optirs_learned/meta_learning/
mamllearner_traits.rs1#[allow(unused_imports)]
12use crate::error::Result;
13#[allow(dead_code)]
14use scirs2_core::ndarray::{Array1, Array2, Dimension};
15use scirs2_core::numeric::Float;
16use std::collections::{HashMap, VecDeque};
17use std::fmt::Debug;
18
19use super::functions::MetaLearner;
20use super::types::{
21 AdaptationStatistics, AdaptationStep, MAMLLearner, MetaLearningAlgorithm, MetaTask,
22 MetaTrainingMetrics, MetaTrainingResult, QueryEvaluationMetrics, QueryEvaluationResult,
23 StabilityMetrics, TaskAdaptationMetrics, TaskAdaptationResult,
24};
25
26impl<
27 T: Float
28 + Debug
29 + 'static
30 + Default
31 + Clone
32 + Send
33 + Sync
34 + std::iter::Sum
35 + scirs2_core::ndarray::ScalarOperand,
36 D: Dimension,
37 > MetaLearner<T> for MAMLLearner<T, D>
38{
39 fn meta_train_step(
40 &mut self,
41 task_batch: &[MetaTask<T>],
42 meta_parameters: &mut HashMap<String, Array1<T>>,
43 ) -> Result<MetaTrainingResult<T>> {
44 let mut total_meta_loss = T::zero();
45 let mut task_losses = Vec::new();
46 let mut meta_gradients = HashMap::new();
47 for task in task_batch {
48 let adaptation_result =
49 self.adapt_to_task(task, meta_parameters, self.config.inner_steps)?;
50 let query_result =
51 self.evaluate_query_set(task, &adaptation_result.adapted_parameters)?;
52 task_losses.push(query_result.query_loss);
53 total_meta_loss = total_meta_loss + query_result.query_loss;
54 for (name, param) in meta_parameters.iter() {
55 let grad = Array1::zeros(param.len());
56 meta_gradients
57 .entry(name.clone())
58 .and_modify(|g: &mut Array1<T>| *g = g.clone() + &grad)
59 .or_insert(grad);
60 }
61 }
62 let batch_size = T::from(task_batch.len()).expect("unwrap failed");
63 let meta_loss = total_meta_loss / batch_size;
64 for gradient in meta_gradients.values_mut() {
65 *gradient = gradient.clone() / batch_size;
66 }
67 Ok(MetaTrainingResult {
68 meta_loss,
69 task_losses: task_losses.clone(),
70 meta_gradients,
71 metrics: MetaTrainingMetrics {
72 avg_adaptation_speed: scirs2_core::numeric::NumCast::from(2.0)
73 .unwrap_or_else(|| T::zero()),
74 generalization_performance: scirs2_core::numeric::NumCast::from(0.85)
75 .unwrap_or_else(|| T::zero()),
76 task_diversity: scirs2_core::numeric::NumCast::from(0.7)
77 .unwrap_or_else(|| T::zero()),
78 gradient_alignment: scirs2_core::numeric::NumCast::from(0.9)
79 .unwrap_or_else(|| T::zero()),
80 },
81 adaptation_stats: AdaptationStatistics {
82 convergence_steps: vec![self.config.inner_steps; task_batch.len()],
83 final_losses: task_losses.clone(),
84 adaptation_efficiency: scirs2_core::numeric::NumCast::from(0.8)
85 .unwrap_or_else(|| T::zero()),
86 stability_metrics: StabilityMetrics {
87 parameter_stability: scirs2_core::numeric::NumCast::from(0.9)
88 .unwrap_or_else(|| T::zero()),
89 performance_stability: scirs2_core::numeric::NumCast::from(0.85)
90 .unwrap_or_else(|| T::zero()),
91 gradient_stability: scirs2_core::numeric::NumCast::from(0.92)
92 .unwrap_or_else(|| T::zero()),
93 forgetting_measure: scirs2_core::numeric::NumCast::from(0.1)
94 .unwrap_or_else(|| T::zero()),
95 },
96 },
97 })
98 }
99 fn adapt_to_task(
100 &mut self,
101 task: &MetaTask<T>,
102 meta_parameters: &HashMap<String, Array1<T>>,
103 adaptation_steps: usize,
104 ) -> Result<TaskAdaptationResult<T>> {
105 let mut adapted_parameters = meta_parameters.clone();
106 let mut adaptation_trajectory = Vec::new();
107 for step in 0..adaptation_steps {
108 let loss = self.compute_support_loss(task, &adapted_parameters)?;
109 let gradients = self.compute_gradients(&adapted_parameters, loss)?;
110 let learning_rate = scirs2_core::numeric::NumCast::from(self.config.inner_lr)
111 .unwrap_or_else(|| T::zero());
112 for (name, param) in adapted_parameters.iter_mut() {
113 if let Some(grad) = gradients.get(name) {
114 for i in 0..param.len() {
115 param[i] = param[i] - learning_rate * grad[i];
116 }
117 }
118 }
119 adaptation_trajectory.push(AdaptationStep {
120 step,
121 loss,
122 gradient_norm: scirs2_core::numeric::NumCast::from(1.0)
123 .unwrap_or_else(|| T::zero()),
124 parameter_change_norm: scirs2_core::numeric::NumCast::from(0.1)
125 .unwrap_or_else(|| T::zero()),
126 learning_rate,
127 });
128 }
129 let final_loss = adaptation_trajectory
130 .last()
131 .map(|s| s.loss)
132 .unwrap_or(T::zero());
133 Ok(TaskAdaptationResult {
134 adapted_parameters,
135 adaptation_trajectory,
136 final_loss,
137 metrics: TaskAdaptationMetrics {
138 convergence_speed: scirs2_core::numeric::NumCast::from(1.5)
139 .unwrap_or_else(|| T::zero()),
140 final_performance: scirs2_core::numeric::NumCast::from(0.9)
141 .unwrap_or_else(|| T::zero()),
142 efficiency: scirs2_core::numeric::NumCast::from(0.85).unwrap_or_else(|| T::zero()),
143 robustness: scirs2_core::numeric::NumCast::from(0.8).unwrap_or_else(|| T::zero()),
144 },
145 })
146 }
147 fn evaluate_query_set(
148 &self,
149 task: &MetaTask<T>,
150 _adapted_parameters: &HashMap<String, Array1<T>>,
151 ) -> Result<QueryEvaluationResult<T>> {
152 let mut predictions = Vec::new();
153 let mut confidence_scores = Vec::new();
154 let mut total_loss = T::zero();
155 for (features, target) in task.query_set.features.iter().zip(&task.query_set.targets) {
156 let prediction = features.iter().copied().sum::<T>()
157 / T::from(features.len()).expect("unwrap failed");
158 let loss = (prediction - *target) * (prediction - *target);
159 predictions.push(prediction);
160 confidence_scores
161 .push(scirs2_core::numeric::NumCast::from(0.9).unwrap_or_else(|| T::zero()));
162 total_loss = total_loss + loss;
163 }
164 let query_loss =
165 total_loss / T::from(task.query_set.features.len()).expect("unwrap failed");
166 let accuracy = scirs2_core::numeric::NumCast::from(0.85).unwrap_or_else(|| T::zero());
167 Ok(QueryEvaluationResult {
168 query_loss,
169 accuracy,
170 predictions,
171 confidence_scores,
172 metrics: QueryEvaluationMetrics {
173 mse: Some(query_loss),
174 classification_accuracy: Some(accuracy),
175 auc: Some(scirs2_core::numeric::NumCast::from(0.9).unwrap_or_else(|| T::zero())),
176 uncertainty_quality: scirs2_core::numeric::NumCast::from(0.8)
177 .unwrap_or_else(|| T::zero()),
178 },
179 })
180 }
181 fn get_algorithm(&self) -> MetaLearningAlgorithm {
182 MetaLearningAlgorithm::MAML
183 }
184}