1use crate::error::{OptimError, Result};
7use scirs2_core::ndarray::{Array, Dimension, ScalarOperand, Zip};
8use scirs2_core::numeric::Float;
9use std::collections::VecDeque;
10use std::fmt::Debug;
11
/// Strategy a [`WeightAverager`] uses to combine successive weight snapshots.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum AveragingMethod {
    /// Uniform mean over the snapshots currently held in the history window.
    MovingAverage,
    /// Exponentially decayed running average.
    ExponentialMovingAverage {
        /// Fraction of the previous average that is kept on each update; the incoming
        /// snapshot is weighted by `1 - decay`.
        decay: f64,
    },
    /// Cumulative running mean in which the incoming snapshot is weighted by the
    /// reciprocal of the step count.
    StochasticWeightAveraging,
    /// Uniform mean over the retained snapshots; handled the same way as
    /// [`AveragingMethod::MovingAverage`] by [`WeightAverager`].
    ModelSoup,
}
27
28#[derive(Debug)]
30pub struct WeightAverager<A: Float, D: Dimension> {
31 averaged_weights: Vec<Array<A, D>>,
33 weight_history: VecDeque<Vec<Array<A, D>>>,
35 step_count: usize,
37 method: AveragingMethod,
39 max_history: usize,
41 initialized: bool,
43 ema_decay: A,
45}
46
47impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> WeightAverager<A, D> {
48 pub fn new(method: AveragingMethod, maxhistory: usize) -> Self {
50 let ema_decay = match method {
51 AveragingMethod::ExponentialMovingAverage { decay } => {
52 A::from(decay).unwrap_or_else(|| A::from(0.999).expect("unwrap failed"))
53 }
54 _ => A::from(0.999).expect("unwrap failed"),
55 };
56
57 Self {
58 averaged_weights: Vec::new(),
59 weight_history: VecDeque::new(),
60 step_count: 0,
61 method,
62 max_history: maxhistory,
63 initialized: false,
64 ema_decay,
65 }
66 }
67
68 pub fn initialize(&mut self, weights: &[Array<A, D>]) -> Result<()> {
70 if self.initialized {
71 return Err(OptimError::InvalidConfig(
72 "Weight averager already initialized".to_string(),
73 ));
74 }
75
76 self.averaged_weights = weights.to_vec();
77 self.initialized = true;
78 Ok(())
79 }
80
81 pub fn update(&mut self, weights: &[Array<A, D>]) -> Result<()> {
83 if !self.initialized {
84 self.initialize(weights)?;
85 return Ok(());
86 }
87
88 if weights.len() != self.averaged_weights.len() {
89 return Err(OptimError::DimensionMismatch(format!(
90 "Expected {} weight arrays, got {}",
91 self.averaged_weights.len(),
92 weights.len()
93 )));
94 }
95
96 self.step_count += 1;
97
98 match self.method {
99 AveragingMethod::MovingAverage => {
100 self.update_moving_average(weights)?;
101 }
102 AveragingMethod::ExponentialMovingAverage { .. } => {
103 self.update_exponential_moving_average(weights)?;
104 }
105 AveragingMethod::StochasticWeightAveraging => {
106 self.update_swa(weights)?;
107 }
108 AveragingMethod::ModelSoup => {
109 self.update_model_soup(weights)?;
110 }
111 }
112
113 Ok(())
114 }
115
116 fn update_moving_average(&mut self, weights: &[Array<A, D>]) -> Result<()> {
118 self.weight_history.push_back(weights.to_vec());
120
121 if self.weight_history.len() > self.max_history {
123 self.weight_history.pop_front();
124 }
125
126 self.compute_moving_average()
128 }
129
130 fn compute_moving_average(&mut self) -> Result<()> {
132 if self.weight_history.is_empty() {
133 return Ok(());
134 }
135
136 let num_snapshots = self.weight_history.len();
137 let inv_count = A::one() / A::from(num_snapshots).expect("unwrap failed");
138
139 for avg_weight in &mut self.averaged_weights {
141 avg_weight.fill(A::zero());
142 }
143
144 for snapshot in &self.weight_history {
146 for (avg_weight, weight) in self.averaged_weights.iter_mut().zip(snapshot.iter()) {
147 Zip::from(avg_weight).and(weight).for_each(|avg, &w| {
148 *avg = *avg + w;
149 });
150 }
151 }
152
153 for avg_weight in &mut self.averaged_weights {
155 avg_weight.mapv_inplace(|x| x * inv_count);
156 }
157
158 Ok(())
159 }
160
161 fn update_exponential_moving_average(&mut self, weights: &[Array<A, D>]) -> Result<()> {
163 let alpha = A::one() - self.ema_decay;
164
165 for (avg_weight, weight) in self.averaged_weights.iter_mut().zip(weights.iter()) {
166 Zip::from(avg_weight).and(weight).for_each(|avg, &w| {
167 *avg = self.ema_decay * *avg + alpha * w;
168 });
169 }
170
171 Ok(())
172 }
173
174 fn update_swa(&mut self, weights: &[Array<A, D>]) -> Result<()> {
176 let n = A::from(self.step_count).expect("unwrap failed");
178 let inv_n = A::one() / n;
179 let prev_weight = (n - A::one()) / n;
180
181 for (avg_weight, weight) in self.averaged_weights.iter_mut().zip(weights.iter()) {
182 Zip::from(avg_weight).and(weight).for_each(|avg, &w| {
183 *avg = prev_weight * *avg + inv_n * w;
184 });
185 }
186
187 Ok(())
188 }
189
190 fn update_model_soup(&mut self, weights: &[Array<A, D>]) -> Result<()> {
192 self.weight_history.push_back(weights.to_vec());
194
195 if self.weight_history.len() > self.max_history {
196 self.weight_history.pop_front();
197 }
198
199 self.compute_moving_average()
201 }
202
203 pub fn get_averaged_weights(&self) -> &[Array<A, D>] {
205 &self.averaged_weights
206 }
207
208 pub fn get_averaged_weights_cloned(&self) -> Vec<Array<A, D>> {
210 self.averaged_weights.clone()
211 }
212
213 pub fn reset(&mut self) {
215 self.weight_history.clear();
216 self.step_count = 0;
217 for weight in &mut self.averaged_weights {
218 weight.fill(A::zero());
219 }
220 }
221
222 pub fn step_count(&self) -> usize {
224 self.step_count
225 }
226
227 pub fn is_initialized(&self) -> bool {
229 self.initialized
230 }
231
232 pub fn method(&self) -> AveragingMethod {
234 self.method
235 }
236
237 pub fn set_ema_decay(&mut self, decay: A) {
239 self.ema_decay = decay;
240 }
241}
242
243#[derive(Debug)]
245pub struct PolyakAverager<A: Float, D: Dimension> {
246 averager: WeightAverager<A, D>,
248 initial_decay: A,
250 final_decay: A,
252 decay_steps: usize,
254}
255
256impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> PolyakAverager<A, D> {
257 pub fn new(initial_decay: A, final_decay: A, decaysteps: usize) -> Self {
259 let method = AveragingMethod::ExponentialMovingAverage {
260 decay: initial_decay.to_f64().unwrap_or(0.9),
261 };
262
263 Self {
264 averager: WeightAverager::new(method, 1), initial_decay,
266 final_decay,
267 decay_steps: decaysteps,
268 }
269 }
270
271 pub fn update(&mut self, weights: &[Array<A, D>]) -> Result<()> {
273 let step = self.averager.step_count() as f64;
274 let progress = (step / self.decay_steps as f64).min(1.0);
275
276 let current_decay = self.initial_decay.to_f64().unwrap_or(0.9) * (1.0 - progress)
278 + self.final_decay.to_f64().unwrap_or(0.999) * progress;
279
280 self.averager
281 .set_ema_decay(A::from(current_decay).expect("unwrap failed"));
282 self.averager.update(weights)
283 }
284
285 pub fn get_averaged_weights(&self) -> &[Array<A, D>] {
287 self.averager.get_averaged_weights()
288 }
289
290 pub fn initialize(&mut self, weights: &[Array<A, D>]) -> Result<()> {
292 self.averager.initialize(weights)
293 }
294}
295
296pub mod gradient_centralization {
298 use super::*;
299
300 pub fn centralize_gradients<A, D>(gradients: &mut [Array<A, D>]) -> Result<()>
302 where
303 A: Float + ScalarOperand + Debug,
304 D: Dimension,
305 {
306 for grad in gradients {
307 centralize_single_gradient(grad)?;
308 }
309 Ok(())
310 }
311
312 pub fn centralize_single_gradient<A, D>(gradient: &mut Array<A, D>) -> Result<()>
314 where
315 A: Float + ScalarOperand + Debug,
316 D: Dimension,
317 {
318 if gradient.is_empty() {
319 return Ok(());
320 }
321
322 let mean = gradient.sum() / A::from(gradient.len()).expect("unwrap failed");
324
325 gradient.mapv_inplace(|x| x - mean);
327
328 Ok(())
329 }
330
331 pub fn centralize_gradients_with_scaling<A, D>(
333 gradients: &mut [Array<A, D>],
334 scale_factor: A,
335 ) -> Result<()>
336 where
337 A: Float + ScalarOperand + Debug,
338 D: Dimension,
339 {
340 centralize_gradients(gradients)?;
341
342 for grad in gradients {
344 grad.mapv_inplace(|x| x * scale_factor);
345 }
346
347 Ok(())
348 }
349}
350
351#[derive(Debug)]
353pub struct ModelEnsemble<A: Float, D: Dimension> {
354 models: Vec<Vec<Array<A, D>>>,
356 model_weights: Vec<A>,
358 ensemble_average: Option<Vec<Array<A, D>>>,
360 cache_valid: bool,
362}
363
364impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> ModelEnsemble<A, D> {
365 pub fn new() -> Self {
367 Self {
368 models: Vec::new(),
369 model_weights: Vec::new(),
370 ensemble_average: None,
371 cache_valid: false,
372 }
373 }
374
375 pub fn add_model(&mut self, weights: Vec<Array<A, D>>, weight: A) -> Result<()> {
377 if !self.models.is_empty() {
378 let expected_len = self.models[0].len();
379 if weights.len() != expected_len {
380 return Err(OptimError::DimensionMismatch(format!(
381 "Expected {} weight arrays, got {}",
382 expected_len,
383 weights.len()
384 )));
385 }
386 }
387
388 self.models.push(weights);
389 self.model_weights.push(weight);
390 self.cache_valid = false;
391 Ok(())
392 }
393
394 pub fn get_ensemble_average(&mut self) -> Result<&[Array<A, D>]> {
396 if !self.cache_valid {
397 self.compute_ensemble_average()?;
398 }
399
400 self.ensemble_average
401 .as_deref()
402 .ok_or_else(|| OptimError::InvalidConfig("No models in ensemble".to_string()))
403 }
404
405 fn compute_ensemble_average(&mut self) -> Result<()> {
407 if self.models.is_empty() {
408 return Err(OptimError::InvalidConfig(
409 "No models in ensemble".to_string(),
410 ));
411 }
412
413 let total_weight: A = self.model_weights.iter().fold(A::zero(), |acc, &w| acc + w);
415 if total_weight <= A::zero() {
416 return Err(OptimError::InvalidConfig(
417 "Total ensemble weight must be > 0".to_string(),
418 ));
419 }
420
421 let num_params = self.models[0].len();
422 let mut ensemble_avg = Vec::new();
423
424 for i in 0..num_params {
426 ensemble_avg.push(Array::zeros(self.models[0][i].raw_dim()));
427 }
428
429 for (model, &weight) in self.models.iter().zip(self.model_weights.iter()) {
431 let normalized_weight = weight / total_weight;
432
433 for (avg_param, model_param) in ensemble_avg.iter_mut().zip(model.iter()) {
434 Zip::from(avg_param)
435 .and(model_param)
436 .for_each(|avg, ¶m| {
437 *avg = *avg + normalized_weight * param;
438 });
439 }
440 }
441
442 self.ensemble_average = Some(ensemble_avg);
443 self.cache_valid = true;
444 Ok(())
445 }
446
447 pub fn clear(&mut self) {
449 self.models.clear();
450 self.model_weights.clear();
451 self.ensemble_average = None;
452 self.cache_valid = false;
453 }
454
455 pub fn len(&self) -> usize {
457 self.models.len()
458 }
459
460 pub fn is_empty(&self) -> bool {
462 self.models.is_empty()
463 }
464}
465
466impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> Default for ModelEnsemble<A, D> {
467 fn default() -> Self {
468 Self::new()
469 }
470}
471
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::Array1;

    /// Builds a parameter set consisting of one 1-D array holding `values`.
    fn single(values: &[f64]) -> Vec<Array1<f64>> {
        vec![Array1::from_vec(values.to_vec())]
    }

    /// Three updates must leave each averaged element within the range of its inputs.
    #[test]
    fn test_moving_average() {
        let mut averager = WeightAverager::new(AveragingMethod::MovingAverage, 3);

        let snapshots = [
            single(&[1.0, 2.0]),
            single(&[3.0, 4.0]),
            single(&[5.0, 6.0]),
        ];
        for snapshot in &snapshots {
            averager.update(snapshot).expect("unwrap failed");
        }

        let avg = averager.get_averaged_weights();
        assert!(avg[0][0] >= 1.0 && avg[0][0] <= 5.0);
        assert!(avg[0][1] >= 2.0 && avg[0][1] <= 6.0);
    }

    /// Seeding with 2.0 and blending 4.0 at decay 0.9 gives 0.9 * 2.0 + 0.1 * 4.0.
    #[test]
    fn test_exponential_moving_average() {
        let decay = 0.9;
        let method = AveragingMethod::ExponentialMovingAverage { decay };
        let mut averager = WeightAverager::new(method, 1);

        averager.update(&single(&[2.0])).expect("unwrap failed");
        averager.update(&single(&[4.0])).expect("unwrap failed");

        let avg = averager.get_averaged_weights();
        assert_relative_eq!(avg[0][0], 2.2, epsilon = 1e-6);
    }

    /// The snapshots 2, 4 and 6 have mean 4; only the range [3.5, 5.0] is required.
    #[test]
    fn test_swa() {
        let mut averager = WeightAverager::new(AveragingMethod::StochasticWeightAveraging, 10);

        let snapshots = [single(&[2.0]), single(&[4.0]), single(&[6.0])];
        for snapshot in &snapshots {
            averager.update(snapshot).expect("unwrap failed");
        }

        let avg = averager.get_averaged_weights();
        assert!(avg[0][0] >= 3.5 && avg[0][0] <= 5.0);
    }

    /// Centralizing [1, 2, 3, 4] removes the mean 2.5 and leaves a zero-mean array.
    #[test]
    fn test_gradient_centralization() {
        let mut gradients = single(&[1.0, 2.0, 3.0, 4.0]);

        gradient_centralization::centralize_gradients(&mut gradients).expect("unwrap failed");

        let expected = [-1.5, -0.5, 0.5, 1.5];
        for (actual, wanted) in gradients[0].iter().zip(expected.iter()) {
            assert_relative_eq!(*actual, *wanted, epsilon = 1e-6);
        }

        let mean = gradients[0].sum() / 4.0;
        assert_relative_eq!(mean, 0.0, epsilon = 1e-10);
    }

    /// After two updates the average must lie strictly between the two snapshots.
    #[test]
    fn test_polyak_averager() {
        let mut averager = PolyakAverager::new(0.5, 0.9, 10);

        averager.update(&single(&[2.0])).expect("unwrap failed");
        averager.update(&single(&[4.0])).expect("unwrap failed");

        let avg = averager.get_averaged_weights();
        assert!(avg[0][0] > 2.0 && avg[0][0] < 4.0);
    }

    /// Two equally weighted models average element-wise.
    #[test]
    fn test_model_ensemble() {
        let mut ensemble = ModelEnsemble::new();

        ensemble
            .add_model(single(&[2.0, 4.0]), 1.0)
            .expect("unwrap failed");
        ensemble
            .add_model(single(&[4.0, 2.0]), 1.0)
            .expect("unwrap failed");

        let avg = ensemble.get_ensemble_average().expect("unwrap failed");
        assert_relative_eq!(avg[0][0], 3.0, epsilon = 1e-6);
        assert_relative_eq!(avg[0][1], 3.0, epsilon = 1e-6);
    }

    /// Weights 3 and 1 on values 2 and 4 give (3 * 2 + 1 * 4) / 4.
    #[test]
    fn test_weighted_model_ensemble() {
        let mut ensemble = ModelEnsemble::new();

        ensemble
            .add_model(single(&[2.0]), 3.0)
            .expect("unwrap failed");
        ensemble
            .add_model(single(&[4.0]), 1.0)
            .expect("unwrap failed");

        let avg = ensemble.get_ensemble_average().expect("unwrap failed");
        assert_relative_eq!(avg[0][0], 2.5, epsilon = 1e-6);
    }

    /// A model with a different number of arrays is rejected.
    #[test]
    fn test_ensemble_dimension_validation() {
        let mut ensemble = ModelEnsemble::new();

        let one_array = single(&[1.0, 2.0]);
        let two_arrays = vec![
            Array1::from_vec(vec![3.0, 4.0]),
            Array1::from_vec(vec![5.0]),
        ];

        ensemble.add_model(one_array, 1.0).expect("unwrap failed");
        assert!(ensemble.add_model(two_arrays, 1.0).is_err());
    }

    /// An update with a different number of arrays is rejected.
    #[test]
    fn test_weight_averager_dimension_validation() {
        let mut averager = WeightAverager::new(AveragingMethod::MovingAverage, 3);

        let one_array = single(&[1.0, 2.0]);
        let two_arrays = vec![
            Array1::from_vec(vec![3.0, 4.0]),
            Array1::from_vec(vec![5.0]),
        ];

        averager.update(&one_array).expect("unwrap failed");
        assert!(averager.update(&two_arrays).is_err());
    }

    /// [1, 3] centralizes to [-1, 1]; scaling by 2 gives [-2, 2].
    #[test]
    fn test_gradient_centralization_with_scaling() {
        let mut gradients = single(&[1.0, 3.0]);

        gradient_centralization::centralize_gradients_with_scaling(&mut gradients, 2.0)
            .expect("unwrap failed");

        assert_relative_eq!(gradients[0][0], -2.0, epsilon = 1e-6);
        assert_relative_eq!(gradients[0][1], 2.0, epsilon = 1e-6);
    }
}