1#![cfg_attr(
3 not(test),
4 deny(warnings, clippy::all, clippy::pedantic, clippy::cargo, missing_docs, missing_crate_level_docs)
5)]
6#![deny(unsafe_code)]
7#![cfg_attr(not(test), no_std)]
8
9use core::cmp::Ordering;
10use core::convert::TryInto;
11use core::time::Duration;
12use num_traits::identities::{One, Zero};
13use num_traits::Float;
14use ordered_float::{FloatIsNan, NotNan};
15
16#[must_use]
20#[derive(Clone)]
21pub struct Ema<F>
22where
23 F: Float,
24{
25 mean: NotNan<F>,
26 variance: NotNan<F>,
27}
28
29impl<F> PartialEq for Ema<F>
30where
31 F: Float,
32{
33 fn eq(&self, other: &Self) -> bool {
34 self.mean.eq(&other.mean) && self.variance.eq(&other.variance)
35 }
36}
37
38impl<F> Eq for Ema<F> where F: Float {}
39
40impl<F> PartialOrd for Ema<F>
41where
42 F: Float,
43{
44 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
45 Some(self.cmp(other))
46 }
47}
48
49impl<F> Ord for Ema<F>
50where
51 F: Float,
52{
53 fn cmp(&self, other: &Self) -> Ordering {
54 self.mean.cmp(&other.mean).then_with(|| self.variance.cmp(&other.variance))
55 }
56}
57
58impl<F> Ema<F>
59where
60 F: Float + TryInto<NotNan<F>>,
61{
62 pub fn try_new(
66 mean: impl TryInto<NotNan<F>, Error = FloatIsNan>,
67 variance: impl TryInto<NotNan<F>, Error = FloatIsNan>,
68 ) -> Result<Self, FloatIsNan> {
69 Ok(Self::new(mean.try_into()?, variance.try_into()?))
70 }
71}
72
73impl<F> Ema<F>
74where
75 F: Float,
76{
77 pub fn new(mean: NotNan<F>, variance: NotNan<F>) -> Self {
81 Self { mean, variance }
82 }
83 pub fn accumulate(&mut self, value: NotNan<F>, recent_weight: NotNan<F>) {
85 let recent_weight = recent_weight.min(NotNan::one()).max(NotNan::zero());
86 let mean = self.mean;
87 let delta = value - mean;
88 let new_mean = mean + recent_weight * delta;
89 let new_variance = (NotNan::one() - recent_weight) * (self.variance + recent_weight * delta * delta);
90 self.mean = new_mean;
91 self.variance = new_variance;
92 }
93 pub fn try_accumulate(&mut self, value: F, recent_weight: F) -> Result<(), FloatIsNan> {
97 let value = NotNan::new(value)?;
98 let recent_weight = NotNan::new(recent_weight)?;
99 self.accumulate(value, recent_weight);
100 Ok(())
101 }
102 #[must_use]
104 #[inline]
105 pub fn mean(&self) -> NotNan<F> {
106 self.mean
107 }
108 #[must_use]
110 #[inline]
111 pub fn variance(&self) -> NotNan<F> {
112 self.variance
113 }
114 #[allow(clippy::missing_panics_doc)]
116 #[must_use]
117 #[inline]
118 pub fn std_dev(&self) -> NotNan<F> {
119 NotNan::new(self.variance.sqrt()).unwrap_or_else(|_| panic!("sqrt won't return NaN if it didn't start with it"))
121 }
122 #[must_use]
124 #[inline]
125 pub fn mean_duration(&self) -> Duration {
126 Duration::from_secs_f64(self.mean().to_f64().unwrap_or(0.0).max(0.0))
127 }
128 #[must_use]
130 #[inline]
131 pub fn std_dev_duration(&self) -> Duration {
132 Duration::from_secs_f64(self.std_dev().to_f64().unwrap_or(0.0).max(0.0))
133 }
134}
135
136impl<F> Default for Ema<F>
137where
138 F: Float,
139{
140 fn default() -> Self {
141 Self {
142 mean: NotNan::zero(),
143 variance: NotNan::zero(),
144 }
145 }
146}
147
148impl<F> core::fmt::Debug for Ema<F>
149where
150 F: Float + core::fmt::Debug,
151{
152 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
153 f.debug_struct("EMA")
154 .field("mean", &*self.mean)
155 .field("variance", &*self.variance)
156 .finish()
157 }
158}
159
160#[derive(Clone)]
162#[must_use]
163pub struct StableEma<F>
164where
165 F: Float,
166{
167 ema: Ema<F>,
168 recent_weight: NotNan<F>,
169}
170
171impl<F> PartialEq for StableEma<F>
172where
173 F: Float,
174{
175 fn eq(&self, other: &Self) -> bool {
176 self.ema.eq(&other.ema) && self.recent_weight.eq(&other.recent_weight)
177 }
178}
179
180impl<F> Eq for StableEma<F> where F: Float {}
181
182impl<F> PartialOrd for StableEma<F>
183where
184 F: Float,
185{
186 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
187 Some(self.cmp(other))
188 }
189}
190
191impl<F> Ord for StableEma<F>
192where
193 F: Float,
194{
195 fn cmp(&self, other: &Self) -> Ordering {
196 self.ema.cmp(&other.ema).then_with(|| self.recent_weight.cmp(&other.recent_weight))
197 }
198}
199
200impl<F> Default for StableEma<F>
201where
202 F: Float,
203{
204 fn default() -> Self {
205 Self {
206 ema: Ema::default(),
207 recent_weight: NotNan::new(F::from(0.1).unwrap_or_else(|| panic!("cannot fail"))).unwrap_or_else(|_| panic!("inner is a number")),
209 }
210 }
211}
212
213impl<F> StableEma<F>
214where
215 F: Float,
216{
217 pub fn new(mean: NotNan<F>, variance: NotNan<F>, recent_weight: NotNan<F>) -> Self {
221 Self {
222 ema: Ema::new(mean, variance),
223 recent_weight,
224 }
225 }
226
227 pub fn try_new<T: TryInto<NotNan<F>, Error = FloatIsNan>>(mean: T, variance: T, recent_weight: T) -> Result<Self, FloatIsNan> {
231 Ok(Self::new(mean.try_into()?, variance.try_into()?, recent_weight.try_into()?))
232 }
233
234 pub fn accumulate(&mut self, value: NotNan<F>) {
236 self.ema.accumulate(value, self.recent_weight)
237 }
238
239 pub fn try_accumulate(&mut self, value: F) -> Result<(), FloatIsNan> {
243 self.accumulate(NotNan::new(value)?);
244 Ok(())
245 }
246
247 #[inline]
249 #[must_use]
250 pub fn mean(&self) -> NotNan<F> {
251 self.ema.mean()
252 }
253
254 #[inline]
256 #[must_use]
257 pub fn variance(&self) -> NotNan<F> {
258 self.ema.variance()
259 }
260
261 #[inline]
263 #[must_use]
264 pub fn std_dev(&self) -> NotNan<F> {
265 self.ema.std_dev()
266 }
267
268 #[must_use]
270 pub fn recent_weight(&self) -> NotNan<F> {
271 self.recent_weight
272 }
273
274 #[inline]
276 #[must_use]
277 pub fn mean_duration(&self) -> Duration {
278 self.ema.mean_duration()
279 }
280
281 #[inline]
283 #[must_use]
284 pub fn std_dev_duration(&self) -> Duration {
285 self.ema.std_dev_duration()
286 }
287
288 #[allow(unsafe_code)]
293 pub unsafe fn set_recent_weight(&mut self, recent_weight: NotNan<F>) {
294 self.recent_weight = recent_weight;
295 }
296}
297
298impl<F> core::fmt::Debug for StableEma<F>
299where
300 F: Float + core::fmt::Debug,
301{
302 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
303 f.debug_struct("StableEMA")
304 .field("mean", &*self.ema.mean)
305 .field("variance", &*self.ema.variance)
306 .field("recent_weight", &*self.recent_weight)
307 .finish()
308 }
309}
310
#[cfg(test)]
mod test {
    use super::*;

    /// Runs a [`StableEma`] with weight `0.2` over a constant stream, then over a unit-slope
    /// ramp `1, 2, …, iters`.
    ///
    /// For the ramp with weight `w = 0.2`:
    /// - The mean settles `(1 - w) / w = 4` behind the newest sample.
    /// - Every sample therefore arrives `delta = 5` above the previous mean.
    /// - The variance settles at the fixed point of `v = (1 - w) * (v + w * delta²)`, i.e.
    ///   `v = 20`.
    ///
    /// Ramp assertions start halfway through, once the average has converged.
    fn test_ema<F: Float + num_traits::FromPrimitive + core::fmt::Debug>(
        iters: u32,
        mean_epsilon: F,
        variance_epsilon: F,
    ) {
        let mut ema = StableEma::<F>::new(
            NotNan::one(),
            NotNan::zero(),
            NotNan::new(F::from_f64(0.2).unwrap()).unwrap(),
        );
        assert_eq!(*ema.mean(), F::one());
        assert_eq!(*ema.variance(), F::zero());
        assert_eq!(*ema.std_dev(), F::zero());
        assert_eq!(ema.mean_duration(), Duration::from_secs(1));
        assert_eq!(ema.std_dev_duration(), Duration::from_secs(0));
        assert_eq!(*ema.recent_weight(), F::from_f64(0.2).unwrap());

        // A constant stream must keep the mean in place and the variance at zero.
        (0..10000).for_each(|_| ema.accumulate(NotNan::one()));
        assert_eq!(ema.mean(), NotNan::one());
        assert_eq!(ema.variance(), NotNan::zero());

        let expected_variance = F::from(20.0).unwrap();
        // The literal needs a type suffix: calling `.sqrt()` on an unsuffixed float literal
        // is an ambiguous-numeric-type error (E0689).
        let expected_std_dev = F::from(20.0_f64.sqrt()).unwrap();
        (1..=iters).for_each(|i| {
            ema.accumulate(NotNan::new(F::from(f64::from(i)).unwrap()).unwrap());
            if i > iters / 2 {
                assert!(
                    (ema.mean() - F::from(f64::from(i - 4)).unwrap()).abs() <= mean_epsilon,
                    "mean: {:?}",
                    ema.mean()
                );
                assert!(
                    (ema.variance() - expected_variance).abs() <= variance_epsilon,
                    "variance: {:?}",
                    ema.variance()
                );
                assert!(
                    (ema.std_dev() - expected_std_dev).abs() <= variance_epsilon,
                    "std_dev: {:?}",
                    ema.std_dev()
                );
            }
        });
    }

    #[test]
    fn test_types() {
        use half::{bf16, f16};
        test_ema::<f32>(10000, 1e-7, 1e-5);
        let mut ema = StableEma::<f32>::default();
        ema.try_accumulate(f32::NAN).unwrap_err();
        test_ema::<f64>(100000, 1e-7, 1e-5);
        let mut ema = StableEma::<f64>::default();
        ema.try_accumulate(f64::NAN).unwrap_err();
        test_ema::<bf16>(250, bf16::from_f32(1e-7), bf16::from_f32(0.25));
        let mut ema = Ema::<bf16>::default();
        ema.try_accumulate(bf16::from_f32(f32::NAN), bf16::from_f32(0.5)).unwrap_err();
        test_ema::<f16>(500, f16::from_f32(1e-7), f16::from_f32(0.125));
        let mut ema = Ema::<f16>::default();
        ema.try_accumulate(f16::from_f32(f32::NAN), f16::from_f32(0.5)).unwrap_err();
    }
}