#![allow(clippy::suboptimal_flops, clippy::float_cmp)]
macro_rules! impl_linear_regression {
($name:ident, $builder:ident, $ty:ty) => {
#[doc = concat!("use nexus_stats::regression::", stringify!($name), ";")]
#[doc = concat!("let mut r = ", stringify!($name), "::new();")]
#[doc = concat!("for x in 0..100u64 { r.update(x as ", stringify!($ty), ", 2.0 as ", stringify!($ty), " * x as ", stringify!($ty), " + 3.0 as ", stringify!($ty), "); }")]
#[doc = concat!("assert!((slope - 2.0 as ", stringify!($ty), ").abs() < 0.001 as ", stringify!($ty), ");")]
#[derive(Debug, Clone)]
pub struct $name {
sum_x: $ty,
sum_x2: $ty,
sum_y: $ty,
sum_xy: $ty,
sum_y2: $ty,
count: u64,
intercept: bool,
}
#[doc = stringify!($name)]
#[derive(Debug, Clone)]
pub struct $builder {
intercept: bool,
}
impl $name {
#[inline]
#[must_use]
pub fn builder() -> $builder {
$builder { intercept: true }
}
#[inline]
#[must_use]
pub fn new() -> Self {
Self {
sum_x: 0.0 as $ty,
sum_x2: 0.0 as $ty,
sum_y: 0.0 as $ty,
sum_xy: 0.0 as $ty,
sum_y2: 0.0 as $ty,
count: 0,
intercept: true,
}
}
#[inline]
#[must_use]
pub fn through_origin() -> Self {
Self {
sum_x: 0.0 as $ty,
sum_x2: 0.0 as $ty,
sum_y: 0.0 as $ty,
sum_xy: 0.0 as $ty,
sum_y2: 0.0 as $ty,
count: 0,
intercept: false,
}
}
#[inline]
fn dim(&self) -> usize {
1 + self.intercept as usize
}
#[inline]
pub fn update(&mut self, x: $ty, y: $ty) -> Result<(), crate::DataError> {
check_finite!(x);
check_finite!(y);
self.count += 1;
self.sum_x += x;
self.sum_x2 += x * x;
self.sum_y += y;
self.sum_xy += x * y;
self.sum_y2 += y * y;
Ok(())
}
#[must_use]
pub fn slope(&self) -> Option<$ty> {
if (self.count as usize) < self.dim() {
return Option::None;
}
if self.intercept {
let n = self.count as $ty;
let denom = n * self.sum_x2 - self.sum_x * self.sum_x;
if denom == 0.0 as $ty {
return Option::None;
}
Option::Some((n * self.sum_xy - self.sum_x * self.sum_y) / denom)
} else {
if self.sum_x2 == 0.0 as $ty {
return Option::None;
}
Option::Some(self.sum_xy / self.sum_x2)
}
}
#[must_use]
pub fn intercept_value(&self) -> Option<$ty> {
if !self.intercept {
return Option::None;
}
let slope = self.slope()?;
let n = self.count as $ty;
Option::Some((self.sum_y - slope * self.sum_x) / n)
}
#[must_use]
pub fn r_squared(&self) -> Option<$ty> {
let slope = self.slope()?;
let n = self.count as $ty;
let (ss_res, ss_tot) = if self.intercept {
let intercept = (self.sum_y - slope * self.sum_x) / n;
let ss_res = self.sum_y2
- 2.0 as $ty * slope * self.sum_xy
- 2.0 as $ty * intercept * self.sum_y
+ slope * slope * self.sum_x2
+ 2.0 as $ty * slope * intercept * self.sum_x
+ n * intercept * intercept;
let ss_tot = self.sum_y2 - self.sum_y * self.sum_y / n;
(ss_res, ss_tot)
} else {
let ss_res = self.sum_y2
- 2.0 as $ty * slope * self.sum_xy
+ slope * slope * self.sum_x2;
(ss_res, self.sum_y2)
};
if ss_tot <= 0.0 as $ty {
return Option::None;
}
Option::Some(1.0 as $ty - ss_res / ss_tot)
}
#[must_use]
pub fn predict(&self, x: $ty) -> Option<$ty> {
let slope = self.slope()?;
if self.intercept {
let intercept = self.intercept_value()?;
Option::Some(slope * x + intercept)
} else {
Option::Some(slope * x)
}
}
#[inline]
#[must_use]
pub fn count(&self) -> u64 {
self.count
}
#[inline]
#[must_use]
pub fn is_primed(&self) -> bool {
(self.count as usize) >= self.dim()
}
#[inline]
#[must_use]
pub fn has_intercept(&self) -> bool {
self.intercept
}
#[inline]
pub fn reset(&mut self) {
self.sum_x = 0.0 as $ty;
self.sum_x2 = 0.0 as $ty;
self.sum_y = 0.0 as $ty;
self.sum_xy = 0.0 as $ty;
self.sum_y2 = 0.0 as $ty;
self.count = 0;
}
}
impl Default for $name {
#[inline]
fn default() -> Self {
Self::new()
}
}
impl $builder {
#[inline]
#[must_use]
pub fn intercept(mut self, intercept: bool) -> Self {
self.intercept = intercept;
self
}
pub fn build(self) -> Result<$name, crate::ConfigError> {
Ok($name {
sum_x: 0.0 as $ty,
sum_x2: 0.0 as $ty,
sum_y: 0.0 as $ty,
sum_xy: 0.0 as $ty,
sum_y2: 0.0 as $ty,
count: 0,
intercept: self.intercept,
})
}
}
};
}
impl_linear_regression!(LinearRegressionF64, LinearRegressionF64Builder, f64);
impl_linear_regression!(LinearRegressionF32, LinearRegressionF32Builder, f32);
macro_rules! impl_ew_linear_regression {
($name:ident, $builder:ident, $ty:ty) => {
#[doc = concat!("use nexus_stats::regression::", stringify!($name), ";")]
#[doc = concat!("let mut r = ", stringify!($name), "::builder()")]
#[doc = concat!("for x in 0..200u64 { r.update(x as ", stringify!($ty), ", 2.0 as ", stringify!($ty), " * x as ", stringify!($ty), "); }")]
#[derive(Debug, Clone)]
pub struct $name {
sum_x: $ty,
sum_x2: $ty,
sum_y: $ty,
sum_xy: $ty,
sum_y2: $ty,
alpha: $ty,
one_minus_alpha: $ty,
effective_n: $ty,
count: u64,
intercept: bool,
}
#[doc = stringify!($name)]
#[derive(Debug, Clone)]
pub struct $builder {
alpha: Option<$ty>,
intercept: bool,
}
impl $name {
#[inline]
#[must_use]
pub fn builder() -> $builder {
$builder {
alpha: Option::None,
intercept: true,
}
}
fn dim(&self) -> usize {
1 + self.intercept as usize
}
#[inline]
pub fn update(&mut self, x: $ty, y: $ty) -> Result<(), crate::DataError> {
check_finite!(x);
check_finite!(y);
self.count += 1;
self.effective_n = self.one_minus_alpha * self.effective_n + 1.0 as $ty;
self.sum_x = self.one_minus_alpha * self.sum_x + x;
self.sum_x2 = self.one_minus_alpha * self.sum_x2 + x * x;
self.sum_y = self.one_minus_alpha * self.sum_y + y;
self.sum_xy = self.one_minus_alpha * self.sum_xy + x * y;
self.sum_y2 = self.one_minus_alpha * self.sum_y2 + y * y;
Ok(())
}
#[must_use]
pub fn slope(&self) -> Option<$ty> {
if (self.effective_n as usize) < self.dim() {
return Option::None;
}
if self.intercept {
let n = self.effective_n;
let denom = n * self.sum_x2 - self.sum_x * self.sum_x;
if denom == 0.0 as $ty {
return Option::None;
}
Option::Some((n * self.sum_xy - self.sum_x * self.sum_y) / denom)
} else {
if self.sum_x2 == 0.0 as $ty {
return Option::None;
}
Option::Some(self.sum_xy / self.sum_x2)
}
}
#[must_use]
pub fn intercept_value(&self) -> Option<$ty> {
if !self.intercept {
return Option::None;
}
let slope = self.slope()?;
Option::Some((self.sum_y - slope * self.sum_x) / self.effective_n)
}
#[must_use]
pub fn r_squared(&self) -> Option<$ty> {
let slope = self.slope()?;
let n = self.effective_n;
let (ss_res, ss_tot) = if self.intercept {
let intercept = (self.sum_y - slope * self.sum_x) / n;
let ss_res = self.sum_y2
- 2.0 as $ty * slope * self.sum_xy
- 2.0 as $ty * intercept * self.sum_y
+ slope * slope * self.sum_x2
+ 2.0 as $ty * slope * intercept * self.sum_x
+ n * intercept * intercept;
let ss_tot = self.sum_y2 - self.sum_y * self.sum_y / n;
(ss_res, ss_tot)
} else {
let ss_res = self.sum_y2
- 2.0 as $ty * slope * self.sum_xy
+ slope * slope * self.sum_x2;
(ss_res, self.sum_y2)
};
if ss_tot <= 0.0 as $ty {
return Option::None;
}
Option::Some(1.0 as $ty - ss_res / ss_tot)
}
#[must_use]
pub fn predict(&self, x: $ty) -> Option<$ty> {
let slope = self.slope()?;
if self.intercept {
let intercept = self.intercept_value()?;
Option::Some(slope * x + intercept)
} else {
Option::Some(slope * x)
}
}
#[inline]
#[must_use]
pub fn count(&self) -> u64 {
self.count
}
#[inline]
#[must_use]
pub fn effective_count(&self) -> $ty {
self.effective_n
}
#[inline]
#[must_use]
pub fn alpha(&self) -> $ty {
self.alpha
}
#[inline]
#[must_use]
pub fn is_primed(&self) -> bool {
(self.effective_n as usize) >= self.dim()
}
#[inline]
#[must_use]
pub fn has_intercept(&self) -> bool {
self.intercept
}
#[inline]
pub fn reset(&mut self) {
self.sum_x = 0.0 as $ty;
self.sum_x2 = 0.0 as $ty;
self.sum_y = 0.0 as $ty;
self.sum_xy = 0.0 as $ty;
self.sum_y2 = 0.0 as $ty;
self.effective_n = 0.0 as $ty;
self.count = 0;
}
}
impl $builder {
#[inline]
#[must_use]
pub fn alpha(mut self, alpha: $ty) -> Self {
self.alpha = Option::Some(alpha);
self
}
#[inline]
#[must_use]
pub fn intercept(mut self, intercept: bool) -> Self {
self.intercept = intercept;
self
}
pub fn build(self) -> Result<$name, crate::ConfigError> {
let alpha = self.alpha
.ok_or(crate::ConfigError::Missing("alpha"))?;
if !(alpha > 0.0 as $ty && alpha < 1.0 as $ty) {
return Err(crate::ConfigError::Invalid(
"alpha must be in (0, 1) exclusive",
));
}
Ok($name {
sum_x: 0.0 as $ty,
sum_x2: 0.0 as $ty,
sum_y: 0.0 as $ty,
sum_xy: 0.0 as $ty,
sum_y2: 0.0 as $ty,
alpha,
one_minus_alpha: 1.0 as $ty - alpha,
effective_n: 0.0 as $ty,
count: 0,
intercept: self.intercept,
})
}
}
};
}
impl_ew_linear_regression!(EwLinearRegressionF64, EwLinearRegressionF64Builder, f64);
impl_ew_linear_regression!(EwLinearRegressionF32, EwLinearRegressionF32Builder, f32);
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn linear_exact_fit() {
let mut r = LinearRegressionF64::new();
for x in 0..100 {
r.update(x as f64, 2.0 * x as f64 + 3.0).unwrap();
}
assert!((r.slope().unwrap() - 2.0).abs() < 1e-8);
assert!((r.intercept_value().unwrap() - 3.0).abs() < 1e-8);
assert!((r.r_squared().unwrap() - 1.0).abs() < 1e-10);
}
#[test]
fn through_origin() {
let mut r = LinearRegressionF64::through_origin();
for x in 1..100 {
r.update(x as f64, 5.0 * x as f64).unwrap();
}
assert!((r.slope().unwrap() - 5.0).abs() < 1e-8);
assert!(r.intercept_value().is_none());
assert!(!r.has_intercept());
}
#[test]
fn predict_linear() {
let mut r = LinearRegressionF64::new();
for x in 0..100 {
r.update(x as f64, 2.0 * x as f64 + 3.0).unwrap();
}
let y = r.predict(50.0).unwrap();
assert!((y - 103.0).abs() < 1e-6, "predict(50) = {y}");
}
#[test]
fn r_squared_noisy() {
let mut r = LinearRegressionF64::new();
let mut rng = 12345u64;
for x in 0..1000 {
rng ^= rng << 13;
rng ^= rng >> 7;
rng ^= rng << 17;
let noise = (rng % 100) as f64 - 50.0;
r.update(x as f64, 2.0 * x as f64 + noise).unwrap();
}
let r2 = r.r_squared().unwrap();
assert!(r2 > 0.9 && r2 < 1.0, "R² with noise = {r2}");
}
#[test]
fn constant_y_r_squared_none() {
let mut r = LinearRegressionF64::new();
for x in 0..100 {
r.update(x as f64, 42.0).unwrap();
}
assert!(r.r_squared().is_none());
}
#[test]
fn not_primed_returns_none() {
let mut r = LinearRegressionF64::new();
assert!(r.slope().is_none());
r.update(1.0, 2.0).unwrap();
assert!(r.slope().is_none()); r.update(2.0, 4.0).unwrap();
assert!(r.slope().is_some());
}
#[test]
fn through_origin_primes_at_1() {
let mut r = LinearRegressionF64::through_origin();
assert!(!r.is_primed());
r.update(1.0, 5.0).unwrap();
assert!(r.is_primed());
assert!((r.slope().unwrap() - 5.0).abs() < 1e-10);
}
#[test]
fn reset_clears_state() {
let mut r = LinearRegressionF64::new();
for x in 0..100 {
r.update(x as f64, x as f64).unwrap();
}
r.reset();
assert_eq!(r.count(), 0);
assert!(r.slope().is_none());
assert!(r.has_intercept());
}
#[test]
fn builder_no_intercept() {
let r = LinearRegressionF64::builder()
.intercept(false)
.build()
.unwrap();
assert!(!r.has_intercept());
}
#[test]
fn empty_returns_none() {
let r = LinearRegressionF64::new();
assert!(r.slope().is_none());
assert!(r.intercept_value().is_none());
assert!(r.predict(1.0).is_none());
assert!(r.r_squared().is_none());
}
#[test]
fn f32_basic() {
let mut r = LinearRegressionF32::new();
for x in 0..100u32 {
r.update(x as f32, 2.0 * x as f32 + 3.0).unwrap();
}
assert!((r.slope().unwrap() - 2.0).abs() < 0.01);
}
#[test]
fn default_is_new() {
let r = LinearRegressionF64::default();
assert_eq!(r.count(), 0);
assert!(r.has_intercept());
}
#[test]
fn ew_basic() {
let mut r = EwLinearRegressionF64::builder()
.alpha(0.05)
.build()
.unwrap();
for x in 0..500 {
r.update(x as f64, 2.0 * x as f64 + 3.0).unwrap();
}
let slope = r.slope().unwrap();
assert!((slope - 2.0).abs() < 0.5, "ew slope = {slope}");
}
#[test]
fn ew_adapts_to_trend_change() {
let mut r = EwLinearRegressionF64::builder()
.alpha(0.05)
.build()
.unwrap();
for x in 0..200 {
r.update(x as f64, x as f64).unwrap();
}
for x in 200..500 {
r.update(x as f64, -(x as f64) + 400.0).unwrap();
}
let slope = r.slope().unwrap();
assert!(slope < 0.0, "slope should be negative, got {slope}");
}
#[test]
fn ew_rejects_invalid_alpha() {
assert!(EwLinearRegressionF64::builder().alpha(0.0).build().is_err());
assert!(EwLinearRegressionF64::builder().alpha(1.0).build().is_err());
}
#[test]
fn ew_rejects_missing_alpha() {
assert!(EwLinearRegressionF64::builder().build().is_err());
}
#[test]
fn ew_reset() {
let mut r = EwLinearRegressionF64::builder()
.alpha(0.05)
.build()
.unwrap();
for x in 0..100 {
r.update(x as f64, x as f64).unwrap();
}
r.reset();
assert_eq!(r.count(), 0);
assert!(r.slope().is_none());
}
#[test]
fn ew_f32() {
let mut r = EwLinearRegressionF32::builder()
.alpha(0.05)
.build()
.unwrap();
for x in 0..200u32 {
r.update(x as f32, x as f32).unwrap();
}
assert!(r.is_primed());
}
#[test]
fn ols_rejects_nan_and_inf() {
let mut r = LinearRegressionF64::new();
assert_eq!(r.update(f64::NAN, 1.0), Err(crate::DataError::NotANumber));
assert_eq!(
r.update(1.0, f64::INFINITY),
Err(crate::DataError::Infinite)
);
assert_eq!(r.count(), 0);
}
#[test]
fn ew_rejects_nan_and_inf() {
let mut r = EwLinearRegressionF64::builder()
.alpha(0.05)
.build()
.unwrap();
assert_eq!(r.update(f64::NAN, 1.0), Err(crate::DataError::NotANumber));
assert_eq!(
r.update(1.0, f64::INFINITY),
Err(crate::DataError::Infinite)
);
assert_eq!(r.count(), 0);
}
}