use std::collections::HashMap;
use crate::{
Array, Result,
error::{EmptyInputPayload, Error, LengthMismatchPayload},
lm::{load::Weights, tuner::optimizers::base::Optimizer},
};
/// Boxed predicate used to route a named array to one of the child optimizers.
///
/// It is called with the entry's key and its array, and returns `true` when the
/// entry belongs to the optimizer paired with this filter. The same predicate is
/// evaluated against parameters (during `init`) and against gradients (during
/// `apply_gradients`).
pub type FilterFn = Box<dyn Fn(&str, &Array) -> bool>;
/// Optimizer that fans parameters out to several child optimizers.
///
/// Each entry is routed to the optimizer whose filter is the first to accept it;
/// entries that no filter accepts go to the last optimizer.
pub struct MultiOptimizer {
/// Child optimizers. `optimizers[i]` is paired with `filters[i]`; the final
/// optimizer has no filter and receives every entry left unmatched.
optimizers: Vec<Box<dyn Optimizer>>,
/// Routing predicates, tried in order. `new` requires exactly one fewer filter
/// than there are optimizers.
filters: Vec<FilterFn>,
/// Number of `apply_gradients` calls that ran to completion.
step_count: usize,
/// Learning rate of `optimizers[0]`, sampled at construction and again after
/// each completed `apply_gradients` call.
current_lr: f32,
}
impl MultiOptimizer {
    /// Creates a composite optimizer from `optimizers` and their routing `filters`.
    ///
    /// `filters[i]` selects the entries handled by `optimizers[i]`; the last
    /// optimizer takes whatever no filter accepts, so there must be exactly one
    /// fewer filter than optimizers. The reported learning rate starts out as the
    /// first optimizer's learning rate.
    ///
    /// # Errors
    ///
    /// * `Error::EmptyInput` when `optimizers` is empty.
    /// * `Error::LengthMismatch` when `filters.len() != optimizers.len() - 1`.
    pub fn new(optimizers: Vec<Box<dyn Optimizer>>, filters: Vec<FilterFn>) -> Result<Self> {
        let Some(primary) = optimizers.first() else {
            return Err(Error::EmptyInput(EmptyInputPayload::new(
                "MultiOptimizer: optimizers",
            )));
        };

        // Cannot underflow: `optimizers` holds at least one element here.
        let expected_filters = optimizers.len() - 1;
        if filters.len() != expected_filters {
            return Err(Error::LengthMismatch(LengthMismatchPayload::new(
                "MultiOptimizer: filters (must equal optimizers - 1)",
                expected_filters,
                filters.len(),
            )));
        }

        let current_lr = primary.learning_rate();
        Ok(Self {
            optimizers,
            filters,
            step_count: 0,
            current_lr,
        })
    }

    /// Partitions `weights` into one map per child optimizer.
    ///
    /// Every entry lands in the bucket of the first filter that accepts it, or in
    /// the last bucket when no filter does. Arrays are duplicated with
    /// `try_clone`, whose error is propagated.
    fn split_dictionary(&self, weights: &Weights) -> Result<Vec<Weights>> {
        let mut buckets: Vec<Weights> = std::iter::repeat_with(HashMap::new)
            .take(self.optimizers.len())
            .collect();

        for (name, array) in weights {
            // First matching filter wins; otherwise fall through to the last bucket.
            let slot = self
                .filters
                .iter()
                .position(|accepts| accepts(name, array))
                .unwrap_or_else(|| buckets.len() - 1);
            buckets[slot].insert(name.clone(), array.try_clone()?);
        }

        Ok(buckets)
    }
}
impl Optimizer for MultiOptimizer {
    /// Initializes every child optimizer with the subset of `params` routed to it.
    ///
    /// Stops at, and returns, the first error produced by the split or by a child.
    fn init(&mut self, params: &Weights) -> Result<()> {
        let partitions = self.split_dictionary(params)?;
        self.optimizers
            .iter_mut()
            .zip(partitions)
            .try_for_each(|(child, subset)| child.init(&subset))
    }

    /// Runs `preflight` on each child in order, returning the first error.
    fn preflight(&mut self) -> Result<()> {
        self.optimizers
            .iter_mut()
            .try_for_each(|child| child.preflight())
    }

    /// Applies `gradients` to `params`, delegating each entry to its child optimizer.
    ///
    /// All children are preflighted before any of them is asked to update, so a
    /// preflight failure leaves `params` and the step counter untouched. On
    /// success the step counter advances by one and the cached learning rate is
    /// refreshed from the first child.
    fn apply_gradients(&mut self, gradients: &Weights, params: &mut Weights) -> Result<()> {
        self.preflight()?;

        let grad_partitions = self.split_dictionary(gradients)?;
        for (child, grad_subset) in self.optimizers.iter_mut().zip(&grad_partitions) {
            // Hand the child only the parameters it has gradients for; gradient
            // keys that have no matching parameter are skipped.
            let mut param_subset: Weights = HashMap::with_capacity(grad_subset.len());
            for name in grad_subset.keys() {
                let Some(current) = params.get(name) else {
                    continue;
                };
                param_subset.insert(name.clone(), current.try_clone()?);
            }

            child.apply_gradients(grad_subset, &mut param_subset)?;

            // Write the child's results back over the caller's parameters.
            params.extend(param_subset);
        }

        self.step_count += 1;
        self.current_lr = self.optimizers[0].learning_rate();
        Ok(())
    }

    /// Number of `apply_gradients` calls that have completed successfully.
    fn step(&self) -> usize {
        self.step_count
    }

    /// Learning rate of the first child, as cached at construction or after the
    /// most recent successful `apply_gradients`.
    fn learning_rate(&self) -> f32 {
        self.current_lr
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lm::tuner::optimizers::sgd::SGD;

    /// Builds an `f32` array with an empty shape (a scalar) holding `v`.
    fn scalar(v: f32) -> Result<Array> {
        Array::full::<f32>(&[0i32; 0], v)
    }

    /// Reads the single `f32` stored in `a`; works on a clone because `item`
    /// is called on a mutable binding.
    fn read_scalar(a: &Array) -> Result<f32> {
        let mut clone = a.try_clone()?;
        clone.item::<f32>()
    }

    /// Entries matching the filter are updated by the first optimizer; the rest
    /// fall through to the last one, each with its own learning rate.
    #[test]
    fn multi_routes_by_filter_to_distinct_sgd_lrs() -> Result<()> {
        let bias_sgd: Box<dyn Optimizer> = Box::new(SGD::vanilla(1e-3)?);
        let weight_sgd: Box<dyn Optimizer> = Box::new(SGD::vanilla(1e-1)?);
        let bias_filter: FilterFn = Box::new(|name, _| name.starts_with("bias."));
        let mut multi = MultiOptimizer::new(vec![bias_sgd, weight_sgd], vec![bias_filter])?;

        let mut params: Weights = HashMap::new();
        params.insert("bias.0".into(), scalar(1.0)?);
        params.insert("layer.weight".into(), scalar(1.0)?);
        let mut grads: Weights = HashMap::new();
        grads.insert("bias.0".into(), scalar(1.0)?);
        grads.insert("layer.weight".into(), scalar(1.0)?);

        multi.apply_gradients(&grads, &mut params)?;

        assert!((read_scalar(&params["bias.0"])? - 0.999).abs() < 1e-6);
        assert!((read_scalar(&params["layer.weight"])? - 0.9).abs() < 1e-6);
        Ok(())
    }

    /// Two optimizers require exactly one filter; zero filters must be rejected.
    #[test]
    fn multi_rejects_wrong_filter_count() {
        let res = MultiOptimizer::new(
            vec![
                Box::new(SGD::vanilla(0.1).unwrap()),
                Box::new(SGD::vanilla(0.1).unwrap()),
            ],
            vec![],
        );
        assert!(res.is_err());
    }

    /// When the second child's schedule yields NaN on the second step, the whole
    /// apply must fail without advancing the step counter or touching `x`, which
    /// belongs to the first (healthy) child.
    #[test]
    fn multi_optimizer_atomicity_on_mid_run_nan_schedule() -> Result<()> {
        use crate::lm::tuner::optimizers::base::LearningRate;

        let child0: Box<dyn Optimizer> = Box::new(SGD::vanilla(0.1)?);
        let bad_schedule =
            LearningRate::Schedule(Box::new(|step| if step == 0 { 0.1_f32 } else { f32::NAN }));
        let child1: Box<dyn Optimizer> = Box::new(SGD::vanilla(bad_schedule)?);
        let x_filter: FilterFn = Box::new(|name, _| name == "x");
        let mut multi = MultiOptimizer::new(vec![child0, child1], vec![x_filter])?;

        let mut params: Weights = HashMap::new();
        params.insert("x".into(), scalar(1.0)?);
        params.insert("y".into(), scalar(1.0)?);
        let mut grads: Weights = HashMap::new();
        grads.insert("x".into(), scalar(1.0)?);
        grads.insert("y".into(), scalar(1.0)?);

        multi.apply_gradients(&grads, &mut params)?;
        assert_eq!(
            multi.step(),
            1,
            "multi step_count must be 1 after first apply"
        );

        let err = multi.apply_gradients(&grads, &mut params);
        assert!(
            err.is_err(),
            "second apply_gradients must err when child1 schedule goes NaN"
        );
        assert_eq!(
            multi.step(),
            1,
            "MultiOptimizer step_count must not advance when preflight rejects"
        );

        let x_val = read_scalar(&params["x"])?;
        assert!(
            (x_val - 0.9).abs() < 1e-6,
            "x param must not be mutated by the rejected second apply (got {x_val})"
        );
        Ok(())
    }

    /// Uses a schedule that ignores `step` and alternates 0.1 / NaN on successive
    /// invocations. A single apply must still succeed with both parameters at 0.9.
    ///
    /// NOTE(review): presumably this guards against the schedule being sampled a
    /// second time between preflight and the actual update — confirm against the
    /// SGD implementation.
    #[test]
    fn multi_optimizer_atomicity_holds_for_stateful_schedule() -> Result<()> {
        use crate::lm::tuner::optimizers::base::LearningRate;
        use std::{cell::Cell, rc::Rc};

        let call_count = Rc::new(Cell::new(0u32));
        let bad_schedule = LearningRate::Schedule(Box::new(move |_step| {
            let n = call_count.get();
            call_count.set(n + 1);
            if n.is_multiple_of(2) {
                0.1_f32
            } else {
                f32::NAN
            }
        }));
        let child0: Box<dyn Optimizer> = Box::new(SGD::vanilla(0.1)?);
        let child1: Box<dyn Optimizer> = Box::new(SGD::vanilla(bad_schedule)?);
        let x_filter: FilterFn = Box::new(|name, _| name == "x");
        let mut multi = MultiOptimizer::new(vec![child0, child1], vec![x_filter])?;

        let mut params: Weights = HashMap::new();
        params.insert("x".into(), scalar(1.0)?);
        params.insert("y".into(), scalar(1.0)?);
        let mut grads: Weights = HashMap::new();
        grads.insert("x".into(), scalar(1.0)?);
        grads.insert("y".into(), scalar(1.0)?);

        multi.apply_gradients(&grads, &mut params)?;
        assert_eq!(multi.step(), 1, "step must advance after successful apply");

        let x_val = read_scalar(&params["x"])?;
        let y_val = read_scalar(&params["y"])?;
        assert!((x_val - 0.9).abs() < 1e-6, "x should be 0.9, got {x_val}");
        assert!((y_val - 0.9).abs() < 1e-6, "y should be 0.9, got {y_val}");
        Ok(())
    }
}