//! Meta-learning algorithms: MAML and Reptile.
//!
//! Both learn a parameter initialization that adapts quickly to new tasks
//! from a small support set. The gradient and loss computations in this
//! module are placeholders to be wired up to a concrete model.
use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{Array1, Array2};
use std::collections::HashMap;
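/// Configuration for MAML (Model-Agnostic Meta-Learning).
///
/// `inner_steps` and `inner_lr` drive the per-task adaptation loop, and
/// `outer_lr` scales the meta-update. Setting `first_order` drops the
/// second-order term (FOMAML), trading gradient fidelity for speed.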
#[derive(Debug, Clone)]
pub struct MAMLConfig {
pub inner_steps: usize,
pub inner_lr: f64,
pub outer_lr: f64,
pub first_order: bool,
}
impl Default for MAMLConfig {
fn default() -> Self {
Self {
inner_steps: 5,
inner_lr: 0.01,
outer_lr: 0.001,
first_order: false,
}
}
}
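/// Configuration for the Reptile meta-learner.
///
/// Reptile uses no second-order information: after `inner_steps` SGD steps
/// per task, `outer_lr` scales how far the meta-parameters move toward the
/// adapted solution.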
#[derive(Debug, Clone)]
pub struct ReptileConfig {
pub inner_steps: usize,
pub inner_lr: f64,
pub outer_lr: f64,
}
impl Default for ReptileConfig {
fn default() -> Self {
Self {
inner_steps: 10,
inner_lr: 0.01,
outer_lr: 0.1,
}
}
}
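/// A single few-shot task: a support set used for adaptation and a query
/// set used to evaluate the adapted parameters.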
#[derive(Debug, Clone)]
pub struct MetaTask {
pub support_x: Array2<f64>,
pub support_y: Array2<f64>,
pub query_x: Array2<f64>,
pub query_y: Array2<f64>,
}
impl MetaTask {
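    /// Builds a task, validating that X and Y row counts agree within the
    /// support set and within the query set.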
pub fn new(
support_x: Array2<f64>,
support_y: Array2<f64>,
query_x: Array2<f64>,
query_y: Array2<f64>,
) -> TrainResult<Self> {
if support_x.nrows() != support_y.nrows() {
return Err(TrainError::InvalidParameter(format!(
"Support X rows ({}) must match support Y rows ({})",
support_x.nrows(),
support_y.nrows()
)));
}
if query_x.nrows() != query_y.nrows() {
return Err(TrainError::InvalidParameter(format!(
"Query X rows ({}) must match query Y rows ({})",
query_x.nrows(),
query_y.nrows()
)));
}
Ok(Self {
support_x,
support_y,
query_x,
query_y,
})
}
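    /// Number of support examples.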
pub fn support_size(&self) -> usize {
self.support_x.nrows()
}
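    /// Number of query examples.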
pub fn query_size(&self) -> usize {
self.query_x.nrows()
}
}
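/// Common interface for gradient-based meta-learners.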
pub trait MetaLearner {
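    /// Runs one meta-update over a batch of tasks, returning the updated
    /// meta-parameters and the mean task loss.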
fn meta_step(
&self,
tasks: &[MetaTask],
parameters: &HashMap<String, Array1<f64>>,
) -> TrainResult<(HashMap<String, Array1<f64>>, f64)>;
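    /// Adapts the given parameters to a single task using its support set.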
fn adapt(
&self,
task: &MetaTask,
parameters: &HashMap<String, Array1<f64>>,
) -> TrainResult<HashMap<String, Array1<f64>>>;
}
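/// MAML: adapts to each task in an inner loop, then updates the
/// meta-parameters with gradients of the resulting query losses.
///
/// A minimal usage sketch (illustrative only; assumes this module's items
/// are in scope and that `tasks` was built elsewhere):
///
/// ```ignore
/// let maml = MAML::default();
/// let mut params = HashMap::new();
/// params.insert("weights".to_string(), Array1::zeros(10));
/// let (new_params, meta_loss) = maml.meta_step(&tasks, &params)?;
/// ```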
#[derive(Debug, Clone)]
pub struct MAML {
config: MAMLConfig,
}
impl MAML {
pub fn new(config: MAMLConfig) -> Self {
Self { config }
}
}
impl Default for MAML {
fn default() -> Self {
Self::new(MAMLConfig::default())
}
}
impl MetaLearner for MAML {
fn meta_step(
&self,
tasks: &[MetaTask],
parameters: &HashMap<String, Array1<f64>>,
) -> TrainResult<(HashMap<String, Array1<f64>>, f64)> {
        if tasks.is_empty() {
            return Err(TrainError::InvalidParameter(
                "meta_step requires at least one task".to_string(),
            ));
        }
        let mut meta_gradients: HashMap<String, Array1<f64>> = HashMap::new();
        let mut total_loss = 0.0;
        // Zero-initialize one gradient accumulator per parameter tensor.
        for (name, param) in parameters {
            meta_gradients.insert(name.clone(), Array1::zeros(param.len()));
        }
        for task in tasks {
            // Inner loop: adapt to the task, then evaluate on its query set.
            let adapted_params = self.adapt(task, parameters)?;
            let query_loss = self.compute_query_loss(task, &adapted_params)?;
            total_loss += query_loss;
            // First-order MAML (FOMAML) skips differentiating through the
            // inner-loop updates; full MAML backpropagates through them.
            let task_gradients = if self.config.first_order {
                self.compute_first_order_gradients(task, &adapted_params)?
            } else {
                self.compute_second_order_gradients(task, parameters, &adapted_params)?
            };
            // Accumulate this task's contribution to the meta-gradient.
            for (name, grad) in task_gradients {
                if let Some(meta_grad) = meta_gradients.get_mut(&name) {
                    *meta_grad += &grad;
                }
            }
}
        // Average over the task batch, then take one outer gradient step.
        let n_tasks = tasks.len() as f64;
        for grad in meta_gradients.values_mut() {
            *grad /= n_tasks;
        }
        total_loss /= n_tasks;
        let mut updated_params = HashMap::new();
        for (name, param) in parameters {
            if let Some(grad) = meta_gradients.get(name) {
                updated_params.insert(name.clone(), param - &(grad * self.config.outer_lr));
            }
        }
Ok((updated_params, total_loss))
}
fn adapt(
&self,
task: &MetaTask,
parameters: &HashMap<String, Array1<f64>>,
) -> TrainResult<HashMap<String, Array1<f64>>> {
        let mut adapted_params = parameters.clone();
        // Inner loop: `inner_steps` SGD steps on the task's support set.
        for _ in 0..self.config.inner_steps {
            let gradients = self.compute_support_gradients(task, &adapted_params)?;
            for (name, param) in &mut adapted_params {
                if let Some(grad) = gradients.get(name) {
                    *param -= &(grad * self.config.inner_lr);
                }
            }
        }
Ok(adapted_params)
}
}
impl MAML {
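    /// Placeholder: returns a zero gradient for a single "weights" vector.
    /// A real implementation would differentiate the support-set loss.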
fn compute_support_gradients(
&self,
task: &MetaTask,
_parameters: &HashMap<String, Array1<f64>>,
) -> TrainResult<HashMap<String, Array1<f64>>> {
let mut gradients = HashMap::new();
gradients.insert("weights".to_string(), Array1::zeros(task.support_x.ncols()));
Ok(gradients)
}
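    /// Placeholder query loss, proportional to the query-set size.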
fn compute_query_loss(
&self,
task: &MetaTask,
_parameters: &HashMap<String, Array1<f64>>,
) -> TrainResult<f64> {
Ok(task.query_size() as f64 * 0.1)
}
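    /// Placeholder for the first-order (FOMAML) gradient, which treats the
    /// adapted parameters as constants of the meta-parameters.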
fn compute_first_order_gradients(
&self,
task: &MetaTask,
_parameters: &HashMap<String, Array1<f64>>,
) -> TrainResult<HashMap<String, Array1<f64>>> {
let mut gradients = HashMap::new();
gradients.insert("weights".to_string(), Array1::zeros(task.query_x.ncols()));
Ok(gradients)
}
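    /// Placeholder for the full second-order MAML gradient, which would
    /// backpropagate through the inner-loop updates.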
fn compute_second_order_gradients(
&self,
task: &MetaTask,
_meta_params: &HashMap<String, Array1<f64>>,
_adapted_params: &HashMap<String, Array1<f64>>,
) -> TrainResult<HashMap<String, Array1<f64>>> {
let mut gradients = HashMap::new();
gradients.insert("weights".to_string(), Array1::zeros(task.query_x.ncols()));
Ok(gradients)
}
}
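/// Reptile: repeatedly adapts to sampled tasks and moves the
/// meta-parameters toward each adapted solution, avoiding second-order
/// derivatives entirely.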
#[derive(Debug, Clone)]
pub struct Reptile {
config: ReptileConfig,
}
impl Reptile {
pub fn new(config: ReptileConfig) -> Self {
Self { config }
}
}
impl Default for Reptile {
fn default() -> Self {
Self::new(ReptileConfig::default())
}
}
impl MetaLearner for Reptile {
fn meta_step(
&self,
tasks: &[MetaTask],
parameters: &HashMap<String, Array1<f64>>,
) -> TrainResult<(HashMap<String, Array1<f64>>, f64)> {
        if tasks.is_empty() {
            return Err(TrainError::InvalidParameter(
                "meta_step requires at least one task".to_string(),
            ));
        }
        let mut total_loss = 0.0;
        let mut accumulated_delta: HashMap<String, Array1<f64>> = HashMap::new();
        // Zero-initialize one displacement accumulator per parameter tensor.
        for (name, param) in parameters {
            accumulated_delta.insert(name.clone(), Array1::zeros(param.len()));
        }
        for task in tasks {
            let task_params = self.adapt(task, parameters)?;
            let task_loss = self.compute_task_loss(task, &task_params)?;
            total_loss += task_loss;
            // Accumulate the displacement (adapted - initial) per parameter.
            for (name, param) in parameters {
                if let Some(task_param) = task_params.get(name) {
                    if let Some(acc_delta) = accumulated_delta.get_mut(name) {
                        *acc_delta += &(task_param - param);
                    }
                }
            }
}
        // Average displacements, then move the meta-parameters toward them.
        let n_tasks = tasks.len() as f64;
        for delta in accumulated_delta.values_mut() {
            *delta /= n_tasks;
        }
        total_loss /= n_tasks;
        let mut updated_params = HashMap::new();
        for (name, param) in parameters {
            if let Some(delta) = accumulated_delta.get(name) {
                updated_params.insert(name.clone(), param + &(delta * self.config.outer_lr));
            }
        }
Ok((updated_params, total_loss))
}
fn adapt(
&self,
task: &MetaTask,
parameters: &HashMap<String, Array1<f64>>,
) -> TrainResult<HashMap<String, Array1<f64>>> {
        let mut task_params = parameters.clone();
        // Inner loop: plain SGD on the support set from the meta-parameters.
        for _ in 0..self.config.inner_steps {
            let gradients = self.compute_support_gradients(task, &task_params)?;
            for (name, param) in &mut task_params {
                if let Some(grad) = gradients.get(name) {
                    *param -= &(grad * self.config.inner_lr);
                }
            }
        }
Ok(task_params)
}
}
impl Reptile {
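    /// Placeholder: returns a zero gradient for a single "weights" vector.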
fn compute_support_gradients(
&self,
task: &MetaTask,
_parameters: &HashMap<String, Array1<f64>>,
) -> TrainResult<HashMap<String, Array1<f64>>> {
let mut gradients = HashMap::new();
gradients.insert("weights".to_string(), Array1::zeros(task.support_x.ncols()));
Ok(gradients)
}
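    /// Placeholder task loss, proportional to the query-set size.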
fn compute_task_loss(
&self,
task: &MetaTask,
_parameters: &HashMap<String, Array1<f64>>,
) -> TrainResult<f64> {
Ok(task.query_size() as f64 * 0.1)
}
}
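/// Tracks meta-training progress: the per-step meta-loss history and
/// per-task inner-loop loss curves.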
#[derive(Debug, Clone, Default)]
pub struct MetaStats {
pub meta_losses: Vec<f64>,
pub task_losses: Vec<Vec<f64>>,
pub iterations: usize,
}
impl MetaStats {
pub fn new() -> Self {
Self::default()
}
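    /// Records the meta-loss of one completed meta-step.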
pub fn record_meta_step(&mut self, meta_loss: f64) {
self.meta_losses.push(meta_loss);
self.iterations += 1;
}
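    /// Records the inner-loop loss curve for `task_id`, growing the table
    /// as needed.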
pub fn record_task_adaptation(&mut self, task_id: usize, losses: Vec<f64>) {
while self.task_losses.len() <= task_id {
self.task_losses.push(Vec::new());
}
self.task_losses[task_id] = losses;
}
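    /// Mean meta-loss over the last `last_n` steps; 0.0 with no history.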
pub fn avg_meta_loss(&self, last_n: usize) -> f64 {
        if last_n == 0 || self.meta_losses.is_empty() {
            return 0.0;
        }
let n = last_n.min(self.meta_losses.len());
let start = self.meta_losses.len() - n;
self.meta_losses[start..].iter().sum::<f64>() / n as f64
}
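    /// True when the mean loss over the most recent `window` steps beats
    /// the preceding `window`; false until `2 * window` steps exist.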
pub fn is_improving(&self, window: usize) -> bool {
if self.meta_losses.len() < window * 2 {
return false;
}
let recent = self.avg_meta_loss(window);
let previous = {
let start = self.meta_losses.len() - window * 2;
let end = self.meta_losses.len() - window;
self.meta_losses[start..end].iter().sum::<f64>() / window as f64
};
recent < previous
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_maml_config_default() {
let config = MAMLConfig::default();
assert_eq!(config.inner_steps, 5);
assert_eq!(config.inner_lr, 0.01);
assert_eq!(config.outer_lr, 0.001);
assert!(!config.first_order);
}
#[test]
fn test_reptile_config_default() {
let config = ReptileConfig::default();
assert_eq!(config.inner_steps, 10);
assert_eq!(config.inner_lr, 0.01);
assert_eq!(config.outer_lr, 0.1);
}
#[test]
fn test_meta_task_creation() {
let support_x = Array2::zeros((5, 10));
let support_y = Array2::zeros((5, 2));
let query_x = Array2::zeros((15, 10));
let query_y = Array2::zeros((15, 2));
        let task = MetaTask::new(support_x, support_y, query_x, query_y).expect("dimensions match");
assert_eq!(task.support_size(), 5);
assert_eq!(task.query_size(), 15);
}
#[test]
fn test_meta_task_validation() {
let support_x = Array2::zeros((5, 10));
        let support_y = Array2::zeros((4, 2));
        let query_x = Array2::zeros((15, 10));
let query_y = Array2::zeros((15, 2));
let result = MetaTask::new(support_x, support_y, query_x, query_y);
assert!(result.is_err());
}
#[test]
fn test_maml_creation() {
let config = MAMLConfig::default();
let maml = MAML::new(config);
assert_eq!(maml.config.inner_steps, 5);
}
#[test]
fn test_maml_default() {
let maml = MAML::default();
assert_eq!(maml.config.inner_steps, 5);
}
#[test]
fn test_reptile_creation() {
let config = ReptileConfig::default();
let reptile = Reptile::new(config);
assert_eq!(reptile.config.inner_steps, 10);
}
#[test]
fn test_reptile_default() {
let reptile = Reptile::default();
assert_eq!(reptile.config.inner_steps, 10);
}
#[test]
fn test_maml_adapt() {
let maml = MAML::default();
let task = create_dummy_task();
let mut params = HashMap::new();
params.insert("weights".to_string(), Array1::zeros(10));
        let adapted = maml.adapt(&task, &params).expect("adaptation succeeds");
assert!(adapted.contains_key("weights"));
}
#[test]
fn test_reptile_adapt() {
let reptile = Reptile::default();
let task = create_dummy_task();
let mut params = HashMap::new();
params.insert("weights".to_string(), Array1::zeros(10));
        let adapted = reptile.adapt(&task, &params).expect("adaptation succeeds");
assert!(adapted.contains_key("weights"));
}
#[test]
fn test_maml_meta_step() {
let maml = MAML::default();
let tasks = vec![create_dummy_task(), create_dummy_task()];
let mut params = HashMap::new();
params.insert("weights".to_string(), Array1::zeros(10));
        let (updated_params, loss) = maml.meta_step(&tasks, &params).expect("meta step succeeds");
assert!(updated_params.contains_key("weights"));
assert!(loss >= 0.0);
}
#[test]
fn test_reptile_meta_step() {
let reptile = Reptile::default();
let tasks = vec![create_dummy_task(), create_dummy_task()];
let mut params = HashMap::new();
params.insert("weights".to_string(), Array1::zeros(10));
        let (updated_params, loss) = reptile.meta_step(&tasks, &params).expect("meta step succeeds");
assert!(updated_params.contains_key("weights"));
assert!(loss >= 0.0);
}
#[test]
fn test_meta_stats() {
let mut stats = MetaStats::new();
stats.record_meta_step(1.0);
stats.record_meta_step(0.8);
stats.record_meta_step(0.6);
assert_eq!(stats.iterations, 3);
assert_eq!(stats.meta_losses.len(), 3);
assert_eq!(stats.avg_meta_loss(2), 0.7);
}
#[test]
fn test_meta_stats_improvement() {
let mut stats = MetaStats::new();
for i in 0..20 {
stats.record_meta_step(1.0 - i as f64 * 0.01);
}
assert!(stats.is_improving(5));
}
#[test]
fn test_meta_stats_task_adaptation() {
let mut stats = MetaStats::new();
stats.record_task_adaptation(0, vec![1.0, 0.8, 0.6]);
stats.record_task_adaptation(1, vec![1.2, 0.9, 0.7]);
assert_eq!(stats.task_losses.len(), 2);
assert_eq!(stats.task_losses[0].len(), 3);
}
fn create_dummy_task() -> MetaTask {
let support_x = Array2::zeros((5, 10));
let support_y = Array2::zeros((5, 2));
let query_x = Array2::zeros((15, 10));
let query_y = Array2::zeros((15, 2));
        MetaTask::new(support_x, support_y, query_x, query_y).expect("dimensions match")
}
}