ema/
lib.rs

1//! Library for using exponential moving averages that is generic over the underlying float type.
2#![cfg_attr(
3  not(test),
4  deny(warnings, clippy::all, clippy::pedantic, clippy::cargo, missing_docs, missing_crate_level_docs)
5)]
6#![deny(unsafe_code)]
7#![cfg_attr(not(test), no_std)]
8
9use core::cmp::Ordering;
10use core::convert::TryInto;
11use core::time::Duration;
12use num_traits::identities::{One, Zero};
13use num_traits::Float;
14use ordered_float::{FloatIsNan, NotNan};
15
16/// A struct representing an exponential moving average
17///
18/// The weighting can be chosen for each accumulation. To have the weighting be part of the struct see [`StableEma`]
19#[must_use]
20#[derive(Clone)]
21pub struct Ema<F>
22where
23  F: Float,
24{
25  mean: NotNan<F>,
26  variance: NotNan<F>,
27}
28
29impl<F> PartialEq for Ema<F>
30where
31  F: Float,
32{
33  fn eq(&self, other: &Self) -> bool {
34    self.mean.eq(&other.mean) && self.variance.eq(&other.variance)
35  }
36}
37
38impl<F> Eq for Ema<F> where F: Float {}
39
40impl<F> PartialOrd for Ema<F>
41where
42  F: Float,
43{
44  fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
45    Some(self.cmp(other))
46  }
47}
48
49impl<F> Ord for Ema<F>
50where
51  F: Float,
52{
53  fn cmp(&self, other: &Self) -> Ordering {
54    self.mean.cmp(&other.mean).then_with(|| self.variance.cmp(&other.variance))
55  }
56}
57
58impl<F> Ema<F>
59where
60  F: Float + TryInto<NotNan<F>>,
61{
62  /// Tries to create a new `Ema` struct from raw float values
63  /// # Errors
64  /// Fails if `mean` or `variance` are NaN
65  pub fn try_new(
66    mean: impl TryInto<NotNan<F>, Error = FloatIsNan>,
67    variance: impl TryInto<NotNan<F>, Error = FloatIsNan>,
68  ) -> Result<Self, FloatIsNan> {
69    Ok(Self::new(mean.try_into()?, variance.try_into()?))
70  }
71}
72
73impl<F> Ema<F>
74where
75  F: Float,
76{
77  /// Returns a new `Ema` struct with the mean and variance estimates already initialized.
78  ///
79  /// It is recommended to choose these values to be as close to expected as possible so that they can converge quickly
80  pub fn new(mean: NotNan<F>, variance: NotNan<F>) -> Self {
81    Self { mean, variance }
82  }
83  /// Accumulates a new value into this `Ema`. The mean and variance are adjusted by the `recent_weight`
84  pub fn accumulate(&mut self, value: NotNan<F>, recent_weight: NotNan<F>) {
85    let recent_weight = recent_weight.min(NotNan::one()).max(NotNan::zero());
86    let mean = self.mean;
87    let delta = value - mean;
88    let new_mean = mean + recent_weight * delta;
89    let new_variance = (NotNan::one() - recent_weight) * (self.variance + recent_weight * delta * delta);
90    self.mean = new_mean;
91    self.variance = new_variance;
92  }
93  /// Tries to accumulate raw flaot values.
94  /// # Errors
95  /// Fails if `value` or `recent_weight` are NaN
96  pub fn try_accumulate(&mut self, value: F, recent_weight: F) -> Result<(), FloatIsNan> {
97    let value = NotNan::new(value)?;
98    let recent_weight = NotNan::new(recent_weight)?;
99    self.accumulate(value, recent_weight);
100    Ok(())
101  }
102  /// Returns the mean of this `Ema`
103  #[must_use]
104  #[inline]
105  pub fn mean(&self) -> NotNan<F> {
106    self.mean
107  }
108  /// Returns the variance of this `Ema`
109  #[must_use]
110  #[inline]
111  pub fn variance(&self) -> NotNan<F> {
112    self.variance
113  }
114  /// Returns the standard deviation of this `Ema`
115  #[allow(clippy::missing_panics_doc)]
116  #[must_use]
117  #[inline]
118  pub fn std_dev(&self) -> NotNan<F> {
119    // Not using `unwrap` or `expect` because we don't want to force the associated type to be `Debug`
120    NotNan::new(self.variance.sqrt()).unwrap_or_else(|_| panic!("sqrt won't return NaN if it didn't start with it"))
121  }
122  /// Returns the mean of this `Ema` as a duration in seconds. Useful when using an `Ema` to time events.
123  #[must_use]
124  #[inline]
125  pub fn mean_duration(&self) -> Duration {
126    Duration::from_secs_f64(self.mean().to_f64().unwrap_or(0.0).max(0.0))
127  }
128  /// Returns the standard deviation of this `Ema` as a duration in seconds. Useful when using an `Ema` to time events
129  #[must_use]
130  #[inline]
131  pub fn std_dev_duration(&self) -> Duration {
132    Duration::from_secs_f64(self.std_dev().to_f64().unwrap_or(0.0).max(0.0))
133  }
134}
135
136impl<F> Default for Ema<F>
137where
138  F: Float,
139{
140  fn default() -> Self {
141    Self {
142      mean: NotNan::zero(),
143      variance: NotNan::zero(),
144    }
145  }
146}
147
148impl<F> core::fmt::Debug for Ema<F>
149where
150  F: Float + core::fmt::Debug,
151{
152  fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
153    f.debug_struct("EMA")
154      .field("mean", &*self.mean)
155      .field("variance", &*self.variance)
156      .finish()
157  }
158}
159
160/// A stable [Ema] where the `recent_weight` is set at initialization and the same value is always used.
161#[derive(Clone)]
162#[must_use]
163pub struct StableEma<F>
164where
165  F: Float,
166{
167  ema: Ema<F>,
168  recent_weight: NotNan<F>,
169}
170
171impl<F> PartialEq for StableEma<F>
172where
173  F: Float,
174{
175  fn eq(&self, other: &Self) -> bool {
176    self.ema.eq(&other.ema) && self.recent_weight.eq(&other.recent_weight)
177  }
178}
179
180impl<F> Eq for StableEma<F> where F: Float {}
181
182impl<F> PartialOrd for StableEma<F>
183where
184  F: Float,
185{
186  fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
187    Some(self.cmp(other))
188  }
189}
190
191impl<F> Ord for StableEma<F>
192where
193  F: Float,
194{
195  fn cmp(&self, other: &Self) -> Ordering {
196    self.ema.cmp(&other.ema).then_with(|| self.recent_weight.cmp(&other.recent_weight))
197  }
198}
199
200impl<F> Default for StableEma<F>
201where
202  F: Float,
203{
204  fn default() -> Self {
205    Self {
206      ema: Ema::default(),
207      // Doing panics and stuff to avoid trait bounds.
208      recent_weight: NotNan::new(F::from(0.1).unwrap_or_else(|| panic!("cannot fail"))).unwrap_or_else(|_| panic!("inner is a number")),
209    }
210  }
211}
212
213impl<F> StableEma<F>
214where
215  F: Float,
216{
217  /// Returns a new `StableEma` with the `mean`, `variance`, and `recent_weight` all initialized.
218  ///
219  /// It is recommended to choose the `mean` and `variance` to be as close to expected as possible so that they can converge quickly
220  pub fn new(mean: NotNan<F>, variance: NotNan<F>, recent_weight: NotNan<F>) -> Self {
221    Self {
222      ema: Ema::new(mean, variance),
223      recent_weight,
224    }
225  }
226
227  /// Tries to create a new `StableEma` from raw float values.
228  /// # Errors
229  /// Fails if `mean`, `variance`, or `recent_weight` are NaN
230  pub fn try_new<T: TryInto<NotNan<F>, Error = FloatIsNan>>(mean: T, variance: T, recent_weight: T) -> Result<Self, FloatIsNan> {
231    Ok(Self::new(mean.try_into()?, variance.try_into()?, recent_weight.try_into()?))
232  }
233
234  /// Accumulates the value to this `StableEma`
235  pub fn accumulate(&mut self, value: NotNan<F>) {
236    self.ema.accumulate(value, self.recent_weight)
237  }
238
239  /// Tries to accumulate a raw float value
240  /// # Errors
241  /// Fails if `value` is NaN
242  pub fn try_accumulate(&mut self, value: F) -> Result<(), FloatIsNan> {
243    self.accumulate(NotNan::new(value)?);
244    Ok(())
245  }
246
247  /// Returns the mean of this `StableEma`
248  #[inline]
249  #[must_use]
250  pub fn mean(&self) -> NotNan<F> {
251    self.ema.mean()
252  }
253
254  /// Returns the variance of this `StableEma`
255  #[inline]
256  #[must_use]
257  pub fn variance(&self) -> NotNan<F> {
258    self.ema.variance()
259  }
260
261  /// Returns the standard deviation of this `StableEma`
262  #[inline]
263  #[must_use]
264  pub fn std_dev(&self) -> NotNan<F> {
265    self.ema.std_dev()
266  }
267
268  /// Returns the recent weight that this `StableEma` uses to accumulate values
269  #[must_use]
270  pub fn recent_weight(&self) -> NotNan<F> {
271    self.recent_weight
272  }
273
274  /// Returns the mean of this `StableEma` as a duration in seconds. Useful when using an `Ema` to time events.
275  #[inline]
276  #[must_use]
277  pub fn mean_duration(&self) -> Duration {
278    self.ema.mean_duration()
279  }
280
281  /// Returns the standard deviation of this `StableEma` as a duration in seconds. Useful when using an `Ema` to time events.
282  #[inline]
283  #[must_use]
284  pub fn std_dev_duration(&self) -> Duration {
285    self.ema.std_dev_duration()
286  }
287
288  /// Change the recent weight.
289  /// # Safety
290  /// This is not unsafe to call, but it violates the notion that this has
291  /// a stable recent weight
292  #[allow(unsafe_code)]
293  pub unsafe fn set_recent_weight(&mut self, recent_weight: NotNan<F>) {
294    self.recent_weight = recent_weight;
295  }
296}
297
298impl<F> core::fmt::Debug for StableEma<F>
299where
300  F: Float + core::fmt::Debug,
301{
302  fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
303    f.debug_struct("StableEMA")
304      .field("mean", &*self.ema.mean)
305      .field("variance", &*self.ema.variance)
306      .field("recent_weight", &*self.recent_weight)
307      .finish()
308  }
309}
310
#[cfg(test)]
mod test {
  use super::*;

  /// Drives a `StableEma` (recent weight 0.2) first with a constant input and then with the ramp `1..=iters`.
  ///
  /// During the second half of the ramp the mean must stay within `mean_epsilon` of `i - 4`, and the
  /// variance and standard deviation within `variance_epsilon` of `20` and `sqrt(20)` respectively.
  fn test_ema<F: Float + num_traits::FromPrimitive + core::fmt::Debug>(iters: u32, mean_epsilon: F, variance_epsilon: F) {
    let weight = NotNan::new(F::from_f64(0.2).unwrap()).unwrap();
    let mut ema = StableEma::<F>::new(NotNan::one(), NotNan::zero(), weight);

    // The freshly built average reports exactly what it was given.
    assert_eq!(*ema.mean(), F::one());
    assert_eq!(*ema.variance(), F::zero());
    assert_eq!(*ema.std_dev(), F::zero());
    assert_eq!(ema.mean_duration(), Duration::from_secs(1));
    assert_eq!(ema.std_dev_duration(), Duration::from_secs(0));
    assert_eq!(*ema.recent_weight(), F::from_f64(0.2).unwrap());

    // Feeding the current mean over and over must not move the estimates.
    for _ in 0..10000 {
      ema.accumulate(NotNan::one());
    }
    assert_eq!(ema.mean(), NotNan::one());
    assert_eq!(ema.variance(), NotNan::zero());

    for i in 1..=iters {
      ema.accumulate(NotNan::new(F::from(i as f64).unwrap()).unwrap());
      // Only check once the average has had the first half of the ramp to settle.
      if i <= iters / 2 {
        continue;
      }
      assert!(
        (ema.mean() - F::from((i - 4) as f64).unwrap()).abs() <= mean_epsilon,
        "mean: {:?}",
        ema.mean()
      );
      assert!(
        (ema.variance() - F::from(20.0).unwrap()).abs() <= variance_epsilon,
        "variance: {:?}",
        ema.variance()
      );
      assert!(
        (ema.std_dev() - F::from(20.0.sqrt()).unwrap()).abs() <= variance_epsilon,
        "std_dev: {:?}",
        ema.std_dev()
      );
    }
  }

  /// Runs the shared scenario for each supported float type and checks that NaN input is rejected.
  #[test]
  fn test_types() {
    use half::{bf16, f16};

    test_ema::<f32>(10000, 1e-7, 1e-5);
    let mut single = StableEma::<f32>::default();
    assert!(single.try_accumulate(f32::NAN).is_err());

    test_ema::<f64>(100000, 1e-7, 1e-5);
    let mut double = StableEma::<f64>::default();
    assert!(double.try_accumulate(f64::NAN).is_err());

    test_ema::<bf16>(250, bf16::from_f32(1e-7), bf16::from_f32(0.25));
    let mut brain = Ema::<bf16>::default();
    assert!(brain.try_accumulate(bf16::from_f32(f32::NAN), bf16::from_f32(0.5)).is_err());

    test_ema::<f16>(500, f16::from_f32(1e-7), f16::from_f32(0.125));
    let mut half_precision = Ema::<f16>::default();
    assert!(half_precision.try_accumulate(f16::from_f32(f32::NAN), f16::from_f32(0.5)).is_err());
  }
}