use std::ops::{Div, Sub};
use crate::{welford_online::welford_online, Iterstats};
/// Converts a sequence of values into their z-scores.
///
/// `A` is the element type the input iterator yields (it defaults to the
/// implementing type itself), and [`ZScore::Output`] is the numeric type of
/// the standardised values handed back through [`ZScoreIter`].
pub trait ZScore<A = Self>: Sized {
    /// Numeric type of every produced z-score.
    type Output;

    /// Consumes `iter` and returns an iterator over the z-scores of its elements.
    ///
    /// The `Clone` bound lets implementations walk the input more than once
    /// if they need to.
    fn zscore<I: Clone + Iterator<Item = A>>(iter: I) -> ZScoreIter<Self::Output>;
}
/// Lazily yields z-scores for a buffered set of values.
///
/// Holds the input values together with the `mean` and `stddev` they are
/// standardised against; each call to `next` yields `(value - mean) / stddev`.
#[derive(Debug, Clone)]
pub struct ZScoreIter<T> {
    /// Buffered input values, consumed front to back.
    iter: std::vec::IntoIter<T>,
    /// Mean subtracted from every value.
    mean: T,
    /// Standard deviation every centred value is divided by.
    stddev: T,
}

impl<T> Iterator for ZScoreIter<T>
where
    T: Copy + Sub<Output = T> + Div<Output = T>,
{
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        self.iter.next().map(|v| (v - self.mean) / self.stddev)
    }

    // The buffered values have an exact length; forward it so `collect` and
    // similar consumers can preallocate instead of seeing the default `(0, None)`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iter.size_hint()
    }
}

/// Every buffered value maps to exactly one z-score, so the remaining length is exact.
impl<T> ExactSizeIterator for ZScoreIter<T> where T: Copy + Sub<Output = T> + Div<Output = T> {}
/// Implements [`ZScore`] for the floating-point type `$typ`, both for
/// iterators yielding `$typ` by value and for iterators yielding `&$typ`.
///
/// The standard deviation divides the Welford sum of squares by `count`
/// (not `count - 1`). When that standard deviation is zero, the yielded
/// values follow IEEE-754 division (`0 / 0` is NaN).
macro_rules! zscore_impl {
    ($typ:ty) => {
        impl ZScore for $typ {
            type Output = $typ;

            fn zscore<I>(iter: I) -> ZScoreIter<Self::Output>
            where
                I: Iterator<Item = Self> + Clone,
            {
                // Drain the source exactly once and compute the statistics from
                // the buffer, so work in the caller's adapter chain is not repeated.
                let values: Vec<$typ> = iter.collect();
                // Extract the statistics in their own scope so no borrow of
                // `values` outlives this block before the buffer is moved below.
                let (mean, stddev) = {
                    let wo = welford_online(values.iter().copied());
                    (wo.mean, (wo.sum_of_squares / wo.count).sqrt())
                };
                ZScoreIter {
                    iter: values.into_iter(),
                    mean,
                    stddev,
                }
            }
        }

        impl ZScore for &$typ {
            type Output = $typ;

            fn zscore<I>(iter: I) -> ZScoreIter<Self::Output>
            where
                I: Iterator<Item = Self> + Clone,
            {
                // Dereference to owned values and reuse the by-value implementation.
                iter.copied().zscore()
            }
        }
    };
}

zscore_impl!(f64);
zscore_impl!(f32);
#[cfg(test)]
mod tests {
    use super::*;
    use paste::paste;

    /// Generates `<name>_into_iter` and `<name>_iter` tests for `ZScore::zscore`.
    ///
    /// `=> nan` expects one NaN z-score per input element; `=> <expr>` compares
    /// the collected z-scores to `<expr>` exactly. The `nan` arms must stay above
    /// the `$expected:expr` arms: macro arms are tried in order, and the bare token
    /// `nan` would otherwise be accepted as an expression.
    macro_rules! test_zscore {
        ( $name:ident: $iterty:ty as $iter:expr; into_iter => nan ) => {
            paste! {
                #[test]
                fn [<$name _into_iter >]() {
                    let input = $iter;
                    let zscores: Vec<_> = <$iterty>::zscore(input.into_iter()).collect();
                    // `all` is vacuously true on an empty result, so pin the length too.
                    assert_eq!(zscores.len(), input.len());
                    assert!(zscores.iter().all(|v| v.is_nan()));
                }
            }
        };
        ( $name:ident: $iterty:ty as $iter:expr; iter => nan ) => {
            paste! {
                #[test]
                fn [<$name _iter >]() {
                    let input = $iter;
                    let zscores: Vec<_> = <&$iterty>::zscore(input.iter()).collect();
                    // `all` is vacuously true on an empty result, so pin the length too.
                    assert_eq!(zscores.len(), input.len());
                    assert!(zscores.iter().all(|v| v.is_nan()));
                }
            }
        };
        ( $name:ident: $iterty:ty as $iter:expr => nan ) => {
            test_zscore!($name: $iterty as $iter ; into_iter => nan);
            test_zscore!($name: $iterty as $iter ; iter => nan);
        };
        ( $name:ident: $iterty:ty as $iter:expr; into_iter => $expected:expr ) => {
            paste! {
                #[test]
                fn [<$name _into_iter >]() {
                    let zscores: Vec<_> = <$iterty>::zscore($iter.into_iter()).collect();
                    assert_eq!(zscores, $expected);
                }
            }
        };
        ( $name:ident: $iterty:ty as $iter:expr; iter => $expected:expr ) => {
            paste! {
                #[test]
                fn [<$name _iter >]() {
                    let zscores: Vec<_> = <&$iterty>::zscore($iter.iter()).collect();
                    assert_eq!(zscores, $expected);
                }
            }
        };
        ( $name:ident: $iterty:ty as $iter:expr => $expected:expr ) => {
            test_zscore!($name: $iterty as $iter ; into_iter => $expected);
            test_zscore!($name: $iterty as $iter ; iter => $expected);
        };
    }

    test_zscore!(f64: f64 as [1.0, 2.0, 3.0, 4.0] => vec![-1.3416407864998738, -0.4472135954999579, 0.4472135954999579, 1.3416407864998738] );
    test_zscore!(f64_with_nan: f64 as [1.0, 2.0, 3.0, 4.0, f64::NAN] => nan );
    test_zscore!(f64_with_inf: f64 as [1.0, 2.0, 3.0, 4.0, f64::INFINITY] => nan );
    test_zscore!(f32: f32 as [1.0, 2.0, 3.0, 4.0] => vec![-1.3416407, -0.4472136, 0.4472136, 1.3416407] );
    test_zscore!(f32_with_nan: f32 as [1.0, 2.0, 3.0, 4.0, f32::NAN] => nan );
    test_zscore!(f32_with_inf: f32 as [1.0, 2.0, 3.0, 4.0, f32::INFINITY] => nan );
}