use super::*;
use anyhow::Result;
use scirs2_core::ndarray_ext::{Array1, Array2};
use scirs2_core::random::{Random, RngExt};
use std::collections::HashMap;
#[derive(Debug)]
pub struct MetaLearner {
pub config: MetaLearningConfig,
pub meta_parameters: HashMap<String, Array2<f32>>,
pub task_parameters: HashMap<String, Array2<f32>>,
}
impl MetaLearner {
pub fn new(config: MetaLearningConfig) -> Self {
let mut meta_parameters = HashMap::new();
let mut task_parameters = HashMap::new();
let mut random = Random::default();
meta_parameters.insert(
"meta_weights".to_string(),
Array2::from_shape_fn((512, 512), |_| (random.random::<f32>() - 0.5) * 0.1),
);
let mut random = Random::default();
task_parameters.insert(
"adaptation_weights".to_string(),
Array2::from_shape_fn((256, 512), |_| (random.random::<f32>() - 0.5) * 0.1),
);
Self {
config,
meta_parameters,
task_parameters,
}
}
pub fn adapt_to_task(
&mut self,
support_set: &[(Array1<f32>, Array1<f32>)],
_query_set: &[(Array1<f32>, Array1<f32>)],
) -> Result<HashMap<String, Array2<f32>>> {
match self.config.algorithm {
MetaLearningAlgorithm::MAML => self.maml_adaptation(support_set),
MetaLearningAlgorithm::ProtoNet => self.prototypical_adaptation(support_set),
_ => self.maml_adaptation(support_set),
}
}
fn maml_adaptation(
&mut self,
support_set: &[(Array1<f32>, Array1<f32>)],
) -> Result<HashMap<String, Array2<f32>>> {
let mut adapted_params = self.meta_parameters.clone();
for _step in 0..self.config.adaptation_steps {
for (input, _target) in support_set {
if let Some(weights) = adapted_params.get_mut("meta_weights") {
let _output = weights.dot(input);
*weights = &*weights * 0.99; }
}
}
Ok(adapted_params)
}
fn prototypical_adaptation(
&self,
support_set: &[(Array1<f32>, Array1<f32>)],
) -> Result<HashMap<String, Array2<f32>>> {
let mut prototypes = HashMap::new();
let mut class_counts = HashMap::new();
for (input, target) in support_set {
let class_id = target[0] as i32;
let class_key = class_id.to_string();
let prototype = prototypes
.entry(class_key.clone())
.or_insert(Array1::zeros(input.len()));
let count = class_counts.entry(class_key).or_insert(0);
*prototype = &*prototype + input;
*count += 1;
}
for (class_key, count) in class_counts {
if let Some(prototype) = prototypes.get_mut(&class_key) {
*prototype /= count as f32;
}
}
Ok(self.meta_parameters.clone())
}
}