amadeus_core/par_sink/
stddev.rs

1use derive_new::new;
2use educe::Educe;
3use serde::{Deserialize, Serialize};
4use std::marker::PhantomData;
5
6use super::{folder_par_sink, FolderSync, FolderSyncReducer, ParallelPipe, ParallelSink};
7use crate::util::u64_to_f64;
8
/// Sink that reduces the `f64` values produced by `pipe` to their standard deviation.
///
/// The accumulators in this file divide by the item count (not `count - 1`), so the
/// result is the population standard deviation.
#[derive(new)]
#[must_use]
pub struct StdDev<P> {
	/// Upstream pipe; the sink impl in this file requires its output to be `f64`.
	pipe: P,
}
14
// Implements `ParallelSink<Item>` for `StdDev<P>` whenever `P` is a `ParallelPipe`
// over `Item` whose output is `f64`.
// The body comes from `folder_par_sink!`, which is handed two folder types and an
// initial value for each: `SDFolder<StepA>` (consumes `f64` items) and
// `SDFolder<StepB>` (consumes `SDState` items).
// NOTE(review): `impl_par_dist!` and `folder_par_sink!` are defined elsewhere.
// Presumably the former also emits the distributed counterpart of this impl and the
// latter feeds the states accumulated by StepA into StepB — confirm against the
// macro definitions.
impl_par_dist! {
	impl<P: ParallelPipe<Item, Output = f64>, Item, > ParallelSink<Item> for StdDev<P> {
		folder_par_sink!(
			SDFolder<StepA>,
			SDFolder<StepB>,
			self,
			SDFolder::new(),
			SDFolder::new()
		);
	}
}
26
/// Zero-sized folder shared by both stages of the standard-deviation reduction.
///
/// The `Step` parameter (`StepA` or `StepB`) only selects which `FolderSync` impl is
/// used; no value of `Step` is ever stored.
#[derive(Educe, Serialize, Deserialize, new)]
#[educe(Clone)]
// Empty serde bound: no `Serialize`/`Deserialize` requirement is placed on `Step`.
#[serde(bound = "")]
pub struct SDFolder<Step> {
	/// Carries the `Step` type without owning a `Step` value.
	marker: PhantomData<fn() -> Step>,
}
33
/// Marker for the first stage: `SDFolder<StepA>` folds individual `f64` items.
pub struct StepA;
/// Marker for the second stage: `SDFolder<StepB>` merges `SDState` accumulators.
pub struct StepB;
36
/// Running statistics carried through the standard-deviation fold.
///
/// Every field is `#[new(default)]`, so `SDState::new()` takes no arguments and
/// starts from the fields' `Default` values.
#[derive(Serialize, Deserialize, new)]
pub struct SDState {
	/// Number of items folded in so far.
	#[new(default)]
	count: u64,
	/// Mean of the items folded in so far.
	#[new(default)]
	mean: f64,
	/// Population variance (divided by `count`) of the items folded in so far.
	#[new(default)]
	variance: f64,
}
46
47impl FolderSync<f64> for SDFolder<StepA> {
48	type State = SDState;
49	type Done = f64;
50
51	#[inline(always)]
52	fn zero(&mut self) -> Self::State {
53		SDState::new()
54	}
55
56	#[inline(always)]
57	fn push(&mut self, state: &mut Self::State, item: f64) {
58		// Taken from https://docs.rs/streaming-stats/0.2.3/src/stats/online.rs.html#64-103
59		let q_prev = state.variance * u64_to_f64(state.count);
60		let mean_prev = state.mean;
61		state.count += 1;
62		let count = u64_to_f64(state.count);
63		state.mean += (item - state.mean) / count;
64		state.variance = (q_prev + (item - mean_prev) * (item - state.mean)) / count;
65	}
66
67	#[inline(always)]
68	fn done(&mut self, state: Self::State) -> Self::Done {
69		state.variance.sqrt()
70	}
71}
72
73impl FolderSync<SDState> for SDFolder<StepB> {
74	type State = SDState;
75	type Done = f64;
76
77	#[inline(always)]
78	fn zero(&mut self) -> Self::State {
79		SDState::new()
80	}
81
82	#[inline(always)]
83	fn push(&mut self, state: &mut Self::State, item: SDState) {
84		let (s1, s2) = (u64_to_f64(state.count), u64_to_f64(item.count));
85		let meandiffsq = (state.mean - item.mean) * (state.mean - item.mean);
86		let mean = ((s1 * state.mean) + (s2 * item.mean)) / (s1 + s2);
87		let var = (((s1 * state.variance) + (s2 * item.variance)) / (s1 + s2))
88			+ ((s1 * s2 * meandiffsq) / ((s1 + s2) * (s1 + s2)));
89		state.count += item.count;
90		state.mean = mean;
91		state.variance = var;
92	}
93
94	#[inline(always)]
95	fn done(&mut self, state: Self::State) -> Self::Done {
96		state.variance.sqrt()
97	}
98}