1use yscv_tensor::Tensor;
2
3pub struct ExponentialMovingAverage {
10 decay: f32,
11 shadow_params: Vec<Tensor>,
12 num_updates: usize,
13}
14
15impl ExponentialMovingAverage {
16 pub fn new(decay: f32) -> Self {
18 Self {
19 decay,
20 shadow_params: Vec::new(),
21 num_updates: 0,
22 }
23 }
24
25 pub fn register(&mut self, params: &[Tensor]) {
27 self.shadow_params = params.to_vec();
28 }
29
30 pub fn update(&mut self, params: &[Tensor]) {
35 assert_eq!(
36 params.len(),
37 self.shadow_params.len(),
38 "param count mismatch: expected {} but got {}",
39 self.shadow_params.len(),
40 params.len(),
41 );
42 let decay = self.decay;
43 let one_minus_decay = 1.0 - decay;
44 for (shadow, param) in self.shadow_params.iter_mut().zip(params.iter()) {
45 let s = shadow.data_mut();
46 let p = param.data();
47 assert_eq!(s.len(), p.len(), "tensor length mismatch in EMA update");
48 let len = s.len();
49 for i in 0..len {
50 s[i] = decay * s[i] + one_minus_decay * p[i];
51 }
52 }
53 self.num_updates += 1;
54 }
55
56 pub fn shadow_params(&self) -> &[Tensor] {
58 &self.shadow_params
59 }
60
61 pub fn apply_shadow(&self, params: &mut [Tensor]) {
66 assert_eq!(
67 params.len(),
68 self.shadow_params.len(),
69 "param count mismatch in apply_shadow",
70 );
71 for (dst, src) in params.iter_mut().zip(self.shadow_params.iter()) {
72 let d = dst.data_mut();
73 let s = src.data();
74 assert_eq!(d.len(), s.len(), "tensor length mismatch in apply_shadow");
75 d.copy_from_slice(s);
76 }
77 }
78
79 pub fn num_updates(&self) -> usize {
81 self.num_updates
82 }
83}