use crate::error::{OptimError, Result};
use scirs2_core::ndarray::{Array, Dimension, ScalarOperand, Zip};
use scirs2_core::numeric::Float;
use std::collections::VecDeque;
use std::fmt::Debug;
/// Strategy used by `WeightAverager` to combine successive weight snapshots.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AveragingMethod {
    /// Uniform average over a bounded window of the most recent snapshots.
    MovingAverage,
    /// Exponential moving average: `avg = decay * avg + (1 - decay) * weights`.
    ExponentialMovingAverage {
        /// Weight given to the previous average on each update.
        decay: f64,
    },
    /// Cumulative running mean of the snapshots (Stochastic Weight Averaging).
    StochasticWeightAveraging,
    /// Uniform average over a bounded window of snapshots ("model soup").
    ModelSoup,
}
/// Maintains an averaged copy of a set of model weight arrays.
///
/// The averaging rule is selected by `AveragingMethod` and applied on each `update`.
#[derive(Debug)]
pub struct WeightAverager<A: Float, D: Dimension> {
    /// Current averaged weights, one array per parameter tensor.
    averaged_weights: Vec<Array<A, D>>,
    /// Recent weight snapshots used by the window-based methods
    /// (`MovingAverage`, `ModelSoup`); trimmed to at most `max_history` entries.
    weight_history: VecDeque<Vec<Array<A, D>>>,
    /// Number of updates applied after the averager was initialized.
    step_count: usize,
    /// Averaging rule in use.
    method: AveragingMethod,
    /// Maximum number of snapshots kept in `weight_history`.
    max_history: usize,
    /// Whether the averager has been seeded with a first set of weights.
    initialized: bool,
    /// Decay factor read by the exponential-moving-average update.
    ema_decay: A,
}
impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> WeightAverager<A, D> {
    /// Creates an uninitialized averager.
    ///
    /// `maxhistory` bounds the snapshot window of the window-based methods
    /// (`MovingAverage`, `ModelSoup`). It is clamped to at least 1, because an
    /// empty window could never hold a snapshot and the average would freeze.
    /// For `ExponentialMovingAverage` the variant's `decay` is used (0.999 if it
    /// cannot be represented in `A`). Other methods store 0.999 but never read it.
    pub fn new(method: AveragingMethod, maxhistory: usize) -> Self {
        let ema_decay = match method {
            AveragingMethod::ExponentialMovingAverage { decay } => {
                A::from(decay).unwrap_or_else(|| A::from(0.999).expect("unwrap failed"))
            }
            _ => A::from(0.999).expect("unwrap failed"),
        };
        Self {
            averaged_weights: Vec::new(),
            weight_history: VecDeque::new(),
            step_count: 0,
            method,
            max_history: maxhistory.max(1),
            initialized: false,
            ema_decay,
        }
    }

    /// Seeds the averager with `weights` as the first snapshot.
    ///
    /// For the window-based methods the seed is also recorded in the history.
    /// Without this, the first snapshot would be silently dropped from the
    /// window on the next `update`.
    ///
    /// # Errors
    /// Returns `OptimError::InvalidConfig` if the averager is already initialized.
    pub fn initialize(&mut self, weights: &[Array<A, D>]) -> Result<()> {
        if self.initialized {
            return Err(OptimError::InvalidConfig(
                "Weight averager already initialized".to_string(),
            ));
        }
        self.averaged_weights = weights.to_vec();
        if self.uses_history() {
            self.push_snapshot(weights);
        }
        self.initialized = true;
        Ok(())
    }

    /// Folds a new snapshot of `weights` into the average.
    ///
    /// The first call (or the first call after `reset`) only seeds the averager
    /// via `initialize` and does not advance `step_count`.
    ///
    /// # Errors
    /// Returns `OptimError::DimensionMismatch` if the number of arrays, or the
    /// shape of any array, differs from the averaged weights. Checking shapes
    /// up front avoids a panic inside `Zip`.
    pub fn update(&mut self, weights: &[Array<A, D>]) -> Result<()> {
        if !self.initialized {
            return self.initialize(weights);
        }
        self.check_compatible(weights)?;
        self.step_count += 1;
        match self.method {
            AveragingMethod::MovingAverage => self.update_moving_average(weights),
            AveragingMethod::ExponentialMovingAverage { .. } => {
                self.update_exponential_moving_average(weights)
            }
            AveragingMethod::StochasticWeightAveraging => self.update_swa(weights),
            AveragingMethod::ModelSoup => self.update_model_soup(weights),
        }
    }

    /// Returns true for methods that recompute the average from `weight_history`.
    fn uses_history(&self) -> bool {
        matches!(
            self.method,
            AveragingMethod::MovingAverage | AveragingMethod::ModelSoup
        )
    }

    /// Appends a snapshot and trims the history to `max_history` entries.
    fn push_snapshot(&mut self, weights: &[Array<A, D>]) {
        self.weight_history.push_back(weights.to_vec());
        while self.weight_history.len() > self.max_history {
            self.weight_history.pop_front();
        }
    }

    /// Verifies that `weights` matches the averaged weights in count and per-array shape.
    fn check_compatible(&self, weights: &[Array<A, D>]) -> Result<()> {
        if weights.len() != self.averaged_weights.len() {
            return Err(OptimError::DimensionMismatch(format!(
                "Expected {} weight arrays, got {}",
                self.averaged_weights.len(),
                weights.len()
            )));
        }
        for (index, (avg, weight)) in self.averaged_weights.iter().zip(weights).enumerate() {
            if avg.shape() != weight.shape() {
                return Err(OptimError::DimensionMismatch(format!(
                    "Weight array {} has shape {:?}, expected {:?}",
                    index,
                    weight.shape(),
                    avg.shape()
                )));
            }
        }
        Ok(())
    }

    /// Records `weights` in the window and recomputes the uniform window average.
    fn update_moving_average(&mut self, weights: &[Array<A, D>]) -> Result<()> {
        self.push_snapshot(weights);
        self.compute_moving_average()
    }

    /// Overwrites `averaged_weights` with the element-wise mean of `weight_history`.
    fn compute_moving_average(&mut self) -> Result<()> {
        if self.weight_history.is_empty() {
            return Ok(());
        }
        let inv_count =
            A::one() / A::from(self.weight_history.len()).expect("unwrap failed");
        for avg in &mut self.averaged_weights {
            avg.fill(A::zero());
        }
        for snapshot in &self.weight_history {
            for (avg, weight) in self.averaged_weights.iter_mut().zip(snapshot) {
                Zip::from(avg).and(weight).for_each(|a, &w| *a = *a + w);
            }
        }
        for avg in &mut self.averaged_weights {
            avg.mapv_inplace(|x| x * inv_count);
        }
        Ok(())
    }

    /// Applies `avg = decay * avg + (1 - decay) * weights` element-wise.
    fn update_exponential_moving_average(&mut self, weights: &[Array<A, D>]) -> Result<()> {
        let decay = self.ema_decay;
        let alpha = A::one() - decay;
        for (avg, weight) in self.averaged_weights.iter_mut().zip(weights) {
            Zip::from(avg).and(weight).for_each(|a, &w| *a = decay * *a + alpha * w);
        }
        Ok(())
    }

    /// Updates the cumulative mean of all snapshots seen so far.
    fn update_swa(&mut self, weights: &[Array<A, D>]) -> Result<()> {
        // The seed from `initialize` is not counted in `step_count`, so after this
        // update `step_count + 1` snapshots contribute to the mean.
        let n = A::from(self.step_count + 1).expect("unwrap failed");
        let inv_n = A::one() / n;
        let prev_weight = (n - A::one()) / n;
        for (avg, weight) in self.averaged_weights.iter_mut().zip(weights) {
            Zip::from(avg)
                .and(weight)
                .for_each(|a, &w| *a = prev_weight * *a + inv_n * w);
        }
        Ok(())
    }

    /// Model soup over the snapshot window. This is a uniform window average,
    /// identical to `MovingAverage`.
    fn update_model_soup(&mut self, weights: &[Array<A, D>]) -> Result<()> {
        self.update_moving_average(weights)
    }

    /// Borrows the current averaged weights.
    pub fn get_averaged_weights(&self) -> &[Array<A, D>] {
        &self.averaged_weights
    }

    /// Returns an owned copy of the current averaged weights.
    pub fn get_averaged_weights_cloned(&self) -> Vec<Array<A, D>> {
        self.averaged_weights.clone()
    }

    /// Clears the history and step counter and zeroes the averaged weights.
    ///
    /// The averager is also marked uninitialized. The next `update` (or an
    /// explicit `initialize`) then re-seeds it from real weights, instead of
    /// averaging against the zeroed buffers.
    pub fn reset(&mut self) {
        self.weight_history.clear();
        self.step_count = 0;
        self.initialized = false;
        for weight in &mut self.averaged_weights {
            weight.fill(A::zero());
        }
    }

    /// Number of updates applied since initialization (the seeding call is not counted).
    pub fn step_count(&self) -> usize {
        self.step_count
    }

    /// Whether the averager has been seeded with weights.
    pub fn is_initialized(&self) -> bool {
        self.initialized
    }

    /// The averaging method in use.
    pub fn method(&self) -> AveragingMethod {
        self.method
    }

    /// Overrides the decay used by the exponential-moving-average update.
    pub fn set_ema_decay(&mut self, decay: A) {
        self.ema_decay = decay;
    }
}
/// Exponential moving average of weights whose decay is linearly scheduled
/// from `initial_decay` to `final_decay`.
#[derive(Debug)]
pub struct PolyakAverager<A: Float, D: Dimension> {
    /// Underlying EMA averager; its decay is overwritten before every update.
    averager: WeightAverager<A, D>,
    /// Decay used at the start of the schedule.
    initial_decay: A,
    /// Decay reached once `decay_steps` updates have been applied.
    final_decay: A,
    /// Number of steps over which the decay is interpolated.
    decay_steps: usize,
}
impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> PolyakAverager<A, D> {
    /// Creates a Polyak averager whose EMA decay moves linearly from
    /// `initial_decay` to `final_decay` over `decaysteps` updates.
    pub fn new(initial_decay: A, final_decay: A, decaysteps: usize) -> Self {
        let method = AveragingMethod::ExponentialMovingAverage {
            decay: initial_decay.to_f64().unwrap_or(0.9),
        };
        Self {
            averager: WeightAverager::new(method, 1),
            initial_decay,
            final_decay,
            decay_steps: decaysteps,
        }
    }

    /// Decay for the next update: linear interpolation between the initial and
    /// final decay, clamped at the final value once `decay_steps` is reached.
    ///
    /// A zero-length schedule uses the final decay immediately. The division is
    /// avoided so this does not rely on NaN semantics.
    fn current_decay(&self) -> A {
        let progress = if self.decay_steps == 0 {
            A::one()
        } else {
            let ratio = self.averager.step_count() as f64 / self.decay_steps as f64;
            A::from(ratio.min(1.0)).unwrap_or_else(A::one)
        };
        // Interpolate in `A` directly, with no round-trip through f64 and no
        // fallback constants.
        self.initial_decay + (self.final_decay - self.initial_decay) * progress
    }

    /// Applies the scheduled decay and folds `weights` into the average.
    ///
    /// # Errors
    /// Propagates errors from the underlying `WeightAverager::update`.
    pub fn update(&mut self, weights: &[Array<A, D>]) -> Result<()> {
        let decay = self.current_decay();
        self.averager.set_ema_decay(decay);
        self.averager.update(weights)
    }

    /// Borrows the current averaged weights.
    pub fn get_averaged_weights(&self) -> &[Array<A, D>] {
        self.averager.get_averaged_weights()
    }

    /// Seeds the underlying averager with `weights`.
    ///
    /// # Errors
    /// Propagates errors from `WeightAverager::initialize`.
    pub fn initialize(&mut self, weights: &[Array<A, D>]) -> Result<()> {
        self.averager.initialize(weights)
    }
}
/// Gradient centralization: shifting gradient arrays so their elements have zero mean.
pub mod gradient_centralization {
    use super::*;

    /// Centralizes every array in `gradients` in place, independently of each other.
    pub fn centralize_gradients<A, D>(gradients: &mut [Array<A, D>]) -> Result<()>
    where
        A: Float + ScalarOperand + Debug,
        D: Dimension,
    {
        gradients.iter_mut().try_for_each(centralize_single_gradient)
    }

    /// Subtracts the mean of all elements of `gradient` from every element, in place.
    ///
    /// The mean is taken over the whole array regardless of its dimensionality.
    /// Empty arrays are left untouched.
    /// NOTE(review): published Gradient Centralization subtracts per-output-channel
    /// means for multi-dimensional weights; confirm the global mean is intended.
    pub fn centralize_single_gradient<A, D>(gradient: &mut Array<A, D>) -> Result<()>
    where
        A: Float + ScalarOperand + Debug,
        D: Dimension,
    {
        let count = gradient.len();
        if count == 0 {
            return Ok(());
        }
        let mean = gradient.sum() / A::from(count).expect("unwrap failed");
        gradient.mapv_inplace(|value| value - mean);
        Ok(())
    }

    /// Centralizes every gradient, then multiplies each element by `scale_factor`.
    pub fn centralize_gradients_with_scaling<A, D>(
        gradients: &mut [Array<A, D>],
        scale_factor: A,
    ) -> Result<()>
    where
        A: Float + ScalarOperand + Debug,
        D: Dimension,
    {
        centralize_gradients(gradients)?;
        gradients
            .iter_mut()
            .for_each(|gradient| gradient.mapv_inplace(|value| value * scale_factor));
        Ok(())
    }
}
/// Weighted average of the parameter arrays of several models.
#[derive(Debug)]
pub struct ModelEnsemble<A: Float, D: Dimension> {
    /// Parameter arrays of each member model.
    models: Vec<Vec<Array<A, D>>>,
    /// Unnormalized weight of each model, parallel to `models`.
    model_weights: Vec<A>,
    /// Cached weighted average; `None` until first computed or after `clear`.
    ensemble_average: Option<Vec<Array<A, D>>>,
    /// Whether `ensemble_average` reflects the current set of models.
    cache_valid: bool,
}
impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> ModelEnsemble<A, D> {
    /// Creates an empty ensemble.
    pub fn new() -> Self {
        Self {
            models: Vec::new(),
            model_weights: Vec::new(),
            ensemble_average: None,
            cache_valid: false,
        }
    }

    /// Adds a model's parameter arrays with an unnormalized ensemble `weight`.
    ///
    /// Weights are normalized by their sum when the average is computed.
    ///
    /// # Errors
    /// Returns `OptimError::DimensionMismatch` if `weights` has a different
    /// number of arrays, or any array has a different shape, than the first
    /// model added. Rejecting shape mismatches here keeps the averaging step
    /// from panicking inside `Zip`.
    pub fn add_model(&mut self, weights: Vec<Array<A, D>>, weight: A) -> Result<()> {
        if let Some(reference) = self.models.first() {
            if weights.len() != reference.len() {
                return Err(OptimError::DimensionMismatch(format!(
                    "Expected {} weight arrays, got {}",
                    reference.len(),
                    weights.len()
                )));
            }
            for (index, (expected, actual)) in reference.iter().zip(&weights).enumerate() {
                if expected.shape() != actual.shape() {
                    return Err(OptimError::DimensionMismatch(format!(
                        "Weight array {} has shape {:?}, expected {:?}",
                        index,
                        actual.shape(),
                        expected.shape()
                    )));
                }
            }
        }
        self.models.push(weights);
        self.model_weights.push(weight);
        self.cache_valid = false;
        Ok(())
    }

    /// Returns the weighted average of all models. The average is recomputed
    /// if models were added or cleared since the last call.
    ///
    /// # Errors
    /// Returns `OptimError::InvalidConfig` if the ensemble is empty, or if the
    /// total weight is not strictly positive (including NaN).
    pub fn get_ensemble_average(&mut self) -> Result<&[Array<A, D>]> {
        if !self.cache_valid {
            self.compute_ensemble_average()?;
        }
        self.ensemble_average
            .as_deref()
            .ok_or_else(|| OptimError::InvalidConfig("No models in ensemble".to_string()))
    }

    /// Computes `sum_i (w_i / sum_j w_j) * model_i` element-wise and caches it.
    fn compute_ensemble_average(&mut self) -> Result<()> {
        if self.models.is_empty() {
            return Err(OptimError::InvalidConfig(
                "No models in ensemble".to_string(),
            ));
        }
        let total_weight: A = self.model_weights.iter().fold(A::zero(), |acc, &w| acc + w);
        // `NaN <= 0` is false, so test NaN explicitly. Otherwise a NaN weight
        // would slip through and poison the average.
        if total_weight.is_nan() || total_weight <= A::zero() {
            return Err(OptimError::InvalidConfig(
                "Total ensemble weight must be > 0".to_string(),
            ));
        }
        let mut ensemble_avg: Vec<Array<A, D>> = self.models[0]
            .iter()
            .map(|param| Array::zeros(param.raw_dim()))
            .collect();
        for (model, &weight) in self.models.iter().zip(&self.model_weights) {
            let normalized_weight = weight / total_weight;
            for (avg_param, model_param) in ensemble_avg.iter_mut().zip(model) {
                Zip::from(avg_param)
                    .and(model_param)
                    .for_each(|avg, &param| *avg = *avg + normalized_weight * param);
            }
        }
        self.ensemble_average = Some(ensemble_avg);
        self.cache_valid = true;
        Ok(())
    }

    /// Removes all models and drops the cached average.
    pub fn clear(&mut self) {
        self.models.clear();
        self.model_weights.clear();
        self.ensemble_average = None;
        self.cache_valid = false;
    }

    /// Number of models in the ensemble.
    pub fn len(&self) -> usize {
        self.models.len()
    }

    /// Whether the ensemble has no models.
    pub fn is_empty(&self) -> bool {
        self.models.is_empty()
    }
}
impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> Default for ModelEnsemble<A, D> {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::Array1;

    /// Three snapshots through a window-3 moving average stay within the input range.
    #[test]
    fn test_moving_average() {
        let mut averager = WeightAverager::new(AveragingMethod::MovingAverage, 3);
        let snapshots = vec![
            vec![Array1::from_vec(vec![1.0, 2.0])],
            vec![Array1::from_vec(vec![3.0, 4.0])],
            vec![Array1::from_vec(vec![5.0, 6.0])],
        ];
        for snapshot in &snapshots {
            averager.update(snapshot).expect("unwrap failed");
        }
        let averaged = averager.get_averaged_weights();
        assert!(averaged[0][0] >= 1.0 && averaged[0][0] <= 5.0);
        assert!(averaged[0][1] >= 2.0 && averaged[0][1] <= 6.0);
    }

    /// With decay 0.9, seeding at 2.0 then updating with 4.0 gives 0.9*2 + 0.1*4.
    #[test]
    fn test_exponential_moving_average() {
        let decay = 0.9;
        let method = AveragingMethod::ExponentialMovingAverage { decay };
        let mut averager = WeightAverager::new(method, 1);
        let first = vec![Array1::from_vec(vec![2.0])];
        let second = vec![Array1::from_vec(vec![4.0])];
        averager.update(&first).expect("unwrap failed");
        averager.update(&second).expect("unwrap failed");
        let averaged = averager.get_averaged_weights();
        assert_relative_eq!(averaged[0][0], 2.2, epsilon = 1e-6);
    }

    /// SWA over 2, 4, 6 lands in a plausible range around the mean.
    #[test]
    fn test_swa() {
        let mut averager = WeightAverager::new(AveragingMethod::StochasticWeightAveraging, 10);
        let snapshots = vec![
            vec![Array1::from_vec(vec![2.0])],
            vec![Array1::from_vec(vec![4.0])],
            vec![Array1::from_vec(vec![6.0])],
        ];
        for snapshot in &snapshots {
            averager.update(snapshot).expect("unwrap failed");
        }
        let averaged = averager.get_averaged_weights();
        assert!(averaged[0][0] >= 3.5 && averaged[0][0] <= 5.0);
    }

    /// Centralizing [1, 2, 3, 4] subtracts its mean of 2.5.
    #[test]
    fn test_gradient_centralization() {
        let mut gradients = vec![Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0])];
        gradient_centralization::centralize_gradients(&mut gradients).expect("unwrap failed");
        let expected = [-1.5, -0.5, 0.5, 1.5];
        for (&got, &want) in gradients[0].iter().zip(expected.iter()) {
            assert_relative_eq!(got, want, epsilon = 1e-6);
        }
        let mean = gradients[0].sum() / 4.0;
        assert_relative_eq!(mean, 0.0, epsilon = 1e-10);
    }

    /// The Polyak average of 2.0 and 4.0 lies strictly between them.
    #[test]
    fn test_polyak_averager() {
        let mut averager = PolyakAverager::new(0.5, 0.9, 10);
        let first = vec![Array1::from_vec(vec![2.0])];
        let second = vec![Array1::from_vec(vec![4.0])];
        averager.update(&first).expect("unwrap failed");
        averager.update(&second).expect("unwrap failed");
        let averaged = averager.get_averaged_weights();
        assert!(averaged[0][0] > 2.0 && averaged[0][0] < 4.0);
    }

    /// Two equally weighted models average element-wise.
    #[test]
    fn test_model_ensemble() {
        let mut ensemble = ModelEnsemble::new();
        ensemble
            .add_model(vec![Array1::from_vec(vec![2.0, 4.0])], 1.0)
            .expect("unwrap failed");
        ensemble
            .add_model(vec![Array1::from_vec(vec![4.0, 2.0])], 1.0)
            .expect("unwrap failed");
        let averaged = ensemble.get_ensemble_average().expect("unwrap failed");
        assert_relative_eq!(averaged[0][0], 3.0, epsilon = 1e-6);
        assert_relative_eq!(averaged[0][1], 3.0, epsilon = 1e-6);
    }

    /// Weights 3:1 over models 2.0 and 4.0 give 0.75*2 + 0.25*4.
    #[test]
    fn test_weighted_model_ensemble() {
        let mut ensemble = ModelEnsemble::new();
        ensemble
            .add_model(vec![Array1::from_vec(vec![2.0])], 3.0)
            .expect("unwrap failed");
        ensemble
            .add_model(vec![Array1::from_vec(vec![4.0])], 1.0)
            .expect("unwrap failed");
        let averaged = ensemble.get_ensemble_average().expect("unwrap failed");
        assert_relative_eq!(averaged[0][0], 2.5, epsilon = 1e-6);
    }

    /// A model with a different number of arrays is rejected.
    #[test]
    fn test_ensemble_dimension_validation() {
        let mut ensemble = ModelEnsemble::new();
        let reference = vec![Array1::from_vec(vec![1.0, 2.0])];
        let mismatched = vec![
            Array1::from_vec(vec![3.0, 4.0]),
            Array1::from_vec(vec![5.0]),
        ];
        ensemble.add_model(reference, 1.0).expect("unwrap failed");
        assert!(ensemble.add_model(mismatched, 1.0).is_err());
    }

    /// An update with a different number of arrays is rejected.
    #[test]
    fn test_weight_averager_dimension_validation() {
        let mut averager = WeightAverager::new(AveragingMethod::MovingAverage, 3);
        let reference = vec![Array1::from_vec(vec![1.0, 2.0])];
        let mismatched = vec![
            Array1::from_vec(vec![3.0, 4.0]),
            Array1::from_vec(vec![5.0]),
        ];
        averager.update(&reference).expect("unwrap failed");
        assert!(averager.update(&mismatched).is_err());
    }

    /// Centralizing [1, 3] gives [-1, 1]; scaling by 2 gives [-2, 2].
    #[test]
    fn test_gradient_centralization_with_scaling() {
        let mut gradients = vec![Array1::from_vec(vec![1.0, 3.0])];
        gradient_centralization::centralize_gradients_with_scaling(&mut gradients, 2.0)
            .expect("unwrap failed");
        assert_relative_eq!(gradients[0][0], -2.0, epsilon = 1e-6);
        assert_relative_eq!(gradients[0][1], 2.0, epsilon = 1e-6);
    }
}