//! Weight averaging utilities: moving averages, exponential moving averages,
//! stochastic weight averaging (SWA), model soups, Polyak averaging, gradient
//! centralization, and weighted model ensembles.

use crate::error::{OptimError, Result};
use scirs2_core::ndarray::{Array, Dimension, ScalarOperand, Zip};
use scirs2_core::numeric::Float;
use std::collections::VecDeque;
use std::fmt::Debug;

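/// Strategy for combining successive weight snapshots.
///
/// For an incoming snapshot `w`, the [`WeightAverager`] applies:
/// * `MovingAverage` / `ModelSoup`: uniform mean of the last `max_history` snapshots
/// * `ExponentialMovingAverage`: `avg = decay * avg + (1 - decay) * w`
/// * `StochasticWeightAveraging`: `avg_n = ((n - 1) * avg_{n-1} + w_n) / n`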
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AveragingMethod {
    MovingAverage,
    ExponentialMovingAverage {
        decay: f64,
    },
    StochasticWeightAveraging,
    ModelSoup,
}

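/// Maintains a running average of a set of weight arrays using the configured
/// [`AveragingMethod`].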
#[derive(Debug)]
pub struct WeightAverager<A: Float, D: Dimension> {
    /// Current running average, one array per parameter group.
    averaged_weights: Vec<Array<A, D>>,
    /// Recent snapshots kept for window-based methods.
    weight_history: VecDeque<Vec<Array<A, D>>>,
    /// Number of updates applied since initialization.
    step_count: usize,
    /// Averaging strategy.
    method: AveragingMethod,
    /// Maximum number of snapshots retained in `weight_history`.
    max_history: usize,
    /// Whether the average has been seeded with an initial set of weights.
    initialized: bool,
    /// Decay factor used by the exponential moving average.
    ema_decay: A,
}

impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> WeightAverager<A, D> {
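    /// Creates a new averager using `method`, keeping at most `max_history`
    /// snapshots for window-based methods.
    ///
    /// A minimal usage sketch (not compiled as a doctest; assumes `f64`
    /// weights stored in `Array1`):
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::Array1;
    /// let mut averager = WeightAverager::new(AveragingMethod::MovingAverage, 8);
    /// averager.update(&[Array1::from_vec(vec![1.0_f64, 2.0])]).unwrap();
    /// let averaged = averager.get_averaged_weights();
    /// ```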
    pub fn new(method: AveragingMethod, max_history: usize) -> Self {
        let ema_decay = match method {
            AveragingMethod::ExponentialMovingAverage { decay } => {
                A::from(decay).unwrap_or_else(|| A::from(0.999).unwrap())
            }
            _ => A::from(0.999).unwrap(),
        };

        Self {
            averaged_weights: Vec::new(),
            weight_history: VecDeque::new(),
            step_count: 0,
            method,
            max_history,
            initialized: false,
            ema_decay,
        }
    }

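    /// Seeds the running average with the given weights. Returns an error if
    /// the averager has already been initialized.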
    pub fn initialize(&mut self, weights: &[Array<A, D>]) -> Result<()> {
        if self.initialized {
            return Err(OptimError::InvalidConfig(
                "Weight averager already initialized".to_string(),
            ));
        }

        self.averaged_weights = weights.to_vec();
        self.initialized = true;
        Ok(())
    }

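    /// Incorporates a new weight snapshot into the running average.
    ///
    /// The first call seeds the average via [`Self::initialize`] and does not
    /// count towards `step_count`.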
    pub fn update(&mut self, weights: &[Array<A, D>]) -> Result<()> {
        if !self.initialized {
            self.initialize(weights)?;
            return Ok(());
        }

        if weights.len() != self.averaged_weights.len() {
            return Err(OptimError::DimensionMismatch(format!(
                "Expected {} weight arrays, got {}",
                self.averaged_weights.len(),
                weights.len()
            )));
        }

        self.step_count += 1;

        match self.method {
            AveragingMethod::MovingAverage => {
                self.update_moving_average(weights)?;
            }
            AveragingMethod::ExponentialMovingAverage { .. } => {
                self.update_exponential_moving_average(weights)?;
            }
            AveragingMethod::StochasticWeightAveraging => {
                self.update_swa(weights)?;
            }
            AveragingMethod::ModelSoup => {
                self.update_model_soup(weights)?;
            }
        }

        Ok(())
    }

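    /// Pushes the snapshot into the sliding window, dropping the oldest entry
    /// once `max_history` is exceeded, then recomputes the uniform mean.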
    fn update_moving_average(&mut self, weights: &[Array<A, D>]) -> Result<()> {
        self.weight_history.push_back(weights.to_vec());

        if self.weight_history.len() > self.max_history {
            self.weight_history.pop_front();
        }

        self.compute_moving_average()
    }

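    /// Recomputes the uniform mean over all snapshots currently in the window.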
    fn compute_moving_average(&mut self) -> Result<()> {
        if self.weight_history.is_empty() {
            return Ok(());
        }

        let num_snapshots = self.weight_history.len();
        let inv_count = A::one() / A::from(num_snapshots).unwrap();

        for avg_weight in &mut self.averaged_weights {
            avg_weight.fill(A::zero());
        }

        for snapshot in &self.weight_history {
            for (avg_weight, weight) in self.averaged_weights.iter_mut().zip(snapshot.iter()) {
                Zip::from(avg_weight).and(weight).for_each(|avg, &w| {
                    *avg = *avg + w;
                });
            }
        }

        for avg_weight in &mut self.averaged_weights {
            avg_weight.mapv_inplace(|x| x * inv_count);
        }

        Ok(())
    }

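    /// Exponential moving average update: `avg = decay * avg + (1 - decay) * w`.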
    fn update_exponential_moving_average(&mut self, weights: &[Array<A, D>]) -> Result<()> {
        let decay = self.ema_decay;
        let alpha = A::one() - decay;

        for (avg_weight, weight) in self.averaged_weights.iter_mut().zip(weights.iter()) {
            Zip::from(avg_weight).and(weight).for_each(|avg, &w| {
                *avg = decay * *avg + alpha * w;
            });
        }

        Ok(())
    }

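    /// SWA running mean: `avg_n = ((n - 1) * avg_{n-1} + w_n) / n`, where `n`
    /// is the current `step_count`.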
    fn update_swa(&mut self, weights: &[Array<A, D>]) -> Result<()> {
        let n = A::from(self.step_count).unwrap();
        let inv_n = A::one() / n;
        let prev_weight = (n - A::one()) / n;

        for (avg_weight, weight) in self.averaged_weights.iter_mut().zip(weights.iter()) {
            Zip::from(avg_weight).and(weight).for_each(|avg, &w| {
                *avg = prev_weight * *avg + inv_n * w;
            });
        }

        Ok(())
    }

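    /// Model-soup update; currently identical to the sliding-window uniform
    /// average over the stored checkpoints.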
    fn update_model_soup(&mut self, weights: &[Array<A, D>]) -> Result<()> {
        self.weight_history.push_back(weights.to_vec());

        if self.weight_history.len() > self.max_history {
            self.weight_history.pop_front();
        }

        self.compute_moving_average()
    }

    /// Returns a view of the current averaged weights.
    pub fn get_averaged_weights(&self) -> &[Array<A, D>] {
        &self.averaged_weights
    }

    /// Returns a clone of the current averaged weights.
    pub fn get_averaged_weights_cloned(&self) -> Vec<Array<A, D>> {
        self.averaged_weights.clone()
    }

    /// Clears the snapshot history, resets the step counter, and zeroes the
    /// running average. The averager remains marked as initialized.
    pub fn reset(&mut self) {
        self.weight_history.clear();
        self.step_count = 0;
        for weight in &mut self.averaged_weights {
            weight.fill(A::zero());
        }
    }

    /// Number of updates applied since initialization.
    pub fn step_count(&self) -> usize {
        self.step_count
    }

    /// Whether the averager has been seeded with initial weights.
    pub fn is_initialized(&self) -> bool {
        self.initialized
    }

    /// The configured averaging method.
    pub fn method(&self) -> AveragingMethod {
        self.method
    }

    /// Overrides the decay factor used by the exponential moving average.
    pub fn set_ema_decay(&mut self, decay: A) {
        self.ema_decay = decay;
    }
}

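/// Exponential-moving-average weight averager whose decay factor is linearly
/// annealed from `initial_decay` to `final_decay` over `decay_steps` updates
/// (Polyak-style averaging).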
#[derive(Debug)]
pub struct PolyakAverager<A: Float, D: Dimension> {
    /// Underlying EMA-based weight averager.
    averager: WeightAverager<A, D>,
    /// Decay factor used at the start of training.
    initial_decay: A,
    /// Decay factor reached after `decay_steps` updates.
    final_decay: A,
    /// Number of updates over which the decay is annealed.
    decay_steps: usize,
}

impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> PolyakAverager<A, D> {
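    /// Creates a Polyak averager that anneals the EMA decay from
    /// `initial_decay` to `final_decay` over `decay_steps` updates.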
    pub fn new(initial_decay: A, final_decay: A, decay_steps: usize) -> Self {
        let method = AveragingMethod::ExponentialMovingAverage {
            decay: initial_decay.to_f64().unwrap_or(0.9),
        };

        Self {
            averager: WeightAverager::new(method, 1),
            initial_decay,
            final_decay,
            decay_steps,
        }
    }

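    /// Updates the running average after setting the EMA decay to
    /// `initial_decay * (1 - p) + final_decay * p`, where
    /// `p = min(step_count / decay_steps, 1)`.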
    pub fn update(&mut self, weights: &[Array<A, D>]) -> Result<()> {
        let step = self.averager.step_count() as f64;
        let progress = (step / self.decay_steps as f64).min(1.0);

        let current_decay = self.initial_decay.to_f64().unwrap_or(0.9) * (1.0 - progress)
            + self.final_decay.to_f64().unwrap_or(0.999) * progress;

        self.averager.set_ema_decay(A::from(current_decay).unwrap());
        self.averager.update(weights)
    }

    /// Returns a view of the current averaged weights.
    pub fn get_averaged_weights(&self) -> &[Array<A, D>] {
        self.averager.get_averaged_weights()
    }

    /// Seeds the underlying averager with the given weights.
    pub fn initialize(&mut self, weights: &[Array<A, D>]) -> Result<()> {
        self.averager.initialize(weights)
    }
}

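/// Gradient centralization utilities: subtract the per-tensor mean from each
/// gradient before the optimizer step.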
pub mod gradient_centralization {
    use super::*;

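    /// Centralizes every gradient in `gradients` in place by subtracting its mean.
    ///
    /// A minimal usage sketch (not compiled as a doctest; assumes `f64`
    /// gradients stored in `Array1`):
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::Array1;
    /// let mut grads = vec![Array1::from_vec(vec![1.0_f64, 3.0])];
    /// gradient_centralization::centralize_gradients(&mut grads).unwrap();
    /// // grads[0] is now [-1.0, 1.0]
    /// ```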
    pub fn centralize_gradients<A, D>(gradients: &mut [Array<A, D>]) -> Result<()>
    where
        A: Float + ScalarOperand + Debug,
        D: Dimension,
    {
        for grad in gradients {
            centralize_single_gradient(grad)?;
        }
        Ok(())
    }

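    /// Subtracts the mean of `gradient` from every element. Empty gradients
    /// are left unchanged.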
    pub fn centralize_single_gradient<A, D>(gradient: &mut Array<A, D>) -> Result<()>
    where
        A: Float + ScalarOperand + Debug,
        D: Dimension,
    {
        if gradient.is_empty() {
            return Ok(());
        }

        let mean = gradient.sum() / A::from(gradient.len()).unwrap();

        gradient.mapv_inplace(|x| x - mean);

        Ok(())
    }

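    /// Centralizes the gradients and then multiplies every element by
    /// `scale_factor`.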
    pub fn centralize_gradients_with_scaling<A, D>(
        gradients: &mut [Array<A, D>],
        scale_factor: A,
    ) -> Result<()>
    where
        A: Float + ScalarOperand + Debug,
        D: Dimension,
    {
        centralize_gradients(gradients)?;

        for grad in gradients {
            grad.mapv_inplace(|x| x * scale_factor);
        }

        Ok(())
    }
}

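/// Weighted average of several models' weights ("model ensemble"), with the
/// result cached until another model is added.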
#[derive(Debug)]
pub struct ModelEnsemble<A: Float, D: Dimension> {
    /// Weight arrays of each member model.
    models: Vec<Vec<Array<A, D>>>,
    /// Ensemble weight assigned to each model.
    model_weights: Vec<A>,
    /// Cached weighted average of the member models.
    ensemble_average: Option<Vec<Array<A, D>>>,
    /// Whether `ensemble_average` reflects the current set of models.
    cache_valid: bool,
}

impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> ModelEnsemble<A, D> {
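    /// Creates an empty ensemble.
    ///
    /// A minimal usage sketch (not compiled as a doctest; assumes `f64`
    /// weights stored in `Array1`):
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::Array1;
    /// let mut ensemble = ModelEnsemble::new();
    /// ensemble.add_model(vec![Array1::from_vec(vec![2.0_f64])], 1.0).unwrap();
    /// ensemble.add_model(vec![Array1::from_vec(vec![4.0_f64])], 3.0).unwrap();
    /// let avg = ensemble.get_ensemble_average().unwrap(); // weighted 1:3 average
    /// ```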
    pub fn new() -> Self {
        Self {
            models: Vec::new(),
            model_weights: Vec::new(),
            ensemble_average: None,
            cache_valid: false,
        }
    }

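    /// Adds a model's weight arrays with the given ensemble weight. Returns an
    /// error if the number of weight arrays differs from previously added models.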
    pub fn add_model(&mut self, weights: Vec<Array<A, D>>, weight: A) -> Result<()> {
        if !self.models.is_empty() {
            let expected_len = self.models[0].len();
            if weights.len() != expected_len {
                return Err(OptimError::DimensionMismatch(format!(
                    "Expected {} weight arrays, got {}",
                    expected_len,
                    weights.len()
                )));
            }
        }

        self.models.push(weights);
        self.model_weights.push(weight);
        self.cache_valid = false;
        Ok(())
    }

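    /// Returns the weight-normalized ensemble average, recomputing it if the
    /// cache is stale.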
    pub fn get_ensemble_average(&mut self) -> Result<&[Array<A, D>]> {
        if !self.cache_valid {
            self.compute_ensemble_average()?;
        }

        self.ensemble_average
            .as_deref()
            .ok_or_else(|| OptimError::InvalidConfig("No models in ensemble".to_string()))
    }

    fn compute_ensemble_average(&mut self) -> Result<()> {
        if self.models.is_empty() {
            return Err(OptimError::InvalidConfig(
                "No models in ensemble".to_string(),
            ));
        }

        let total_weight: A = self.model_weights.iter().fold(A::zero(), |acc, &w| acc + w);
        if total_weight <= A::zero() {
            return Err(OptimError::InvalidConfig(
                "Total ensemble weight must be > 0".to_string(),
            ));
        }

        let num_params = self.models[0].len();
        let mut ensemble_avg = Vec::new();

        for i in 0..num_params {
            ensemble_avg.push(Array::zeros(self.models[0][i].raw_dim()));
        }

        // Accumulate each model scaled by its normalized ensemble weight.
        for (model, &weight) in self.models.iter().zip(self.model_weights.iter()) {
            let normalized_weight = weight / total_weight;

            for (avg_param, model_param) in ensemble_avg.iter_mut().zip(model.iter()) {
                Zip::from(avg_param)
                    .and(model_param)
                    .for_each(|avg, &param| {
                        *avg = *avg + normalized_weight * param;
                    });
            }
        }

        self.ensemble_average = Some(ensemble_avg);
        self.cache_valid = true;
        Ok(())
    }

    /// Removes all models and invalidates the cached average.
    pub fn clear(&mut self) {
        self.models.clear();
        self.model_weights.clear();
        self.ensemble_average = None;
        self.cache_valid = false;
    }

    /// Number of models in the ensemble.
    pub fn len(&self) -> usize {
        self.models.len()
    }

    /// Whether the ensemble contains no models.
    pub fn is_empty(&self) -> bool {
        self.models.is_empty()
    }
}

impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> Default for ModelEnsemble<A, D> {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_moving_average() {
        let mut averager = WeightAverager::new(AveragingMethod::MovingAverage, 3);

        let weights1 = vec![Array1::from_vec(vec![1.0, 2.0])];
        let weights2 = vec![Array1::from_vec(vec![3.0, 4.0])];
        let weights3 = vec![Array1::from_vec(vec![5.0, 6.0])];

        averager.update(&weights1).unwrap();
        averager.update(&weights2).unwrap();
        averager.update(&weights3).unwrap();

        let avg = averager.get_averaged_weights();
        assert!(avg[0][0] >= 1.0 && avg[0][0] <= 5.0);
        assert!(avg[0][1] >= 2.0 && avg[0][1] <= 6.0);
    }

    #[test]
    fn test_exponential_moving_average() {
        let decay = 0.9;
        let mut averager =
            WeightAverager::new(AveragingMethod::ExponentialMovingAverage { decay }, 1);

        let weights1 = vec![Array1::from_vec(vec![2.0])];
        let weights2 = vec![Array1::from_vec(vec![4.0])];

        averager.update(&weights1).unwrap();
        averager.update(&weights2).unwrap();

        let avg = averager.get_averaged_weights();
        assert_relative_eq!(avg[0][0], 2.2, epsilon = 1e-6);
    }

    #[test]
    fn test_swa() {
        let mut averager = WeightAverager::new(AveragingMethod::StochasticWeightAveraging, 10);

        let weights1 = vec![Array1::from_vec(vec![2.0])];
        let weights2 = vec![Array1::from_vec(vec![4.0])];
        let weights3 = vec![Array1::from_vec(vec![6.0])];

        averager.update(&weights1).unwrap();
        averager.update(&weights2).unwrap();
        averager.update(&weights3).unwrap();

        let avg = averager.get_averaged_weights();
        assert!(avg[0][0] >= 3.5 && avg[0][0] <= 5.0);
    }

    #[test]
    fn test_gradient_centralization() {
        let mut gradients = vec![Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0])];

        gradient_centralization::centralize_gradients(&mut gradients).unwrap();

        let expected = [-1.5, -0.5, 0.5, 1.5];
        for (actual, expected) in gradients[0].iter().zip(expected.iter()) {
            assert_relative_eq!(*actual, *expected, epsilon = 1e-6);
        }

        let mean = gradients[0].sum() / 4.0;
        assert_relative_eq!(mean, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_polyak_averager() {
        let mut averager = PolyakAverager::new(0.5, 0.9, 10);

        let weights1 = vec![Array1::from_vec(vec![2.0])];
        let weights2 = vec![Array1::from_vec(vec![4.0])];

        averager.update(&weights1).unwrap();
        averager.update(&weights2).unwrap();

        let avg = averager.get_averaged_weights();
        assert!(avg[0][0] > 2.0 && avg[0][0] < 4.0);
    }

    #[test]
    fn test_model_ensemble() {
        let mut ensemble = ModelEnsemble::new();

        let model1 = vec![Array1::from_vec(vec![2.0, 4.0])];
        let model2 = vec![Array1::from_vec(vec![4.0, 2.0])];

        ensemble.add_model(model1, 1.0).unwrap();
        ensemble.add_model(model2, 1.0).unwrap();

        let avg = ensemble.get_ensemble_average().unwrap();
        assert_relative_eq!(avg[0][0], 3.0, epsilon = 1e-6);
        assert_relative_eq!(avg[0][1], 3.0, epsilon = 1e-6);
    }

    #[test]
    fn test_weighted_model_ensemble() {
        let mut ensemble = ModelEnsemble::new();

        let model1 = vec![Array1::from_vec(vec![2.0])];
        let model2 = vec![Array1::from_vec(vec![4.0])];

        ensemble.add_model(model1, 3.0).unwrap();
        ensemble.add_model(model2, 1.0).unwrap();

        let avg = ensemble.get_ensemble_average().unwrap();
        assert_relative_eq!(avg[0][0], 2.5, epsilon = 1e-6);
    }

    #[test]
    fn test_ensemble_dimension_validation() {
        let mut ensemble = ModelEnsemble::new();

        let model1 = vec![Array1::from_vec(vec![1.0, 2.0])];
        let model2 = vec![
            Array1::from_vec(vec![3.0, 4.0]),
            Array1::from_vec(vec![5.0]),
        ];

        ensemble.add_model(model1, 1.0).unwrap();
        assert!(ensemble.add_model(model2, 1.0).is_err());
    }

    #[test]
    fn test_weight_averager_dimension_validation() {
        let mut averager = WeightAverager::new(AveragingMethod::MovingAverage, 3);

        let weights1 = vec![Array1::from_vec(vec![1.0, 2.0])];
        let weights2 = vec![
            Array1::from_vec(vec![3.0, 4.0]),
            Array1::from_vec(vec![5.0]),
        ];

        averager.update(&weights1).unwrap();
        assert!(averager.update(&weights2).is_err());
    }

    #[test]
    fn test_gradient_centralization_with_scaling() {
        let mut gradients = vec![Array1::from_vec(vec![1.0, 3.0])];

        gradient_centralization::centralize_gradients_with_scaling(&mut gradients, 2.0).unwrap();

        assert_relative_eq!(gradients[0][0], -2.0, epsilon = 1e-6);
        assert_relative_eq!(gradients[0][1], 2.0, epsilon = 1e-6);
    }
}