#![allow(clippy::suboptimal_flops)]
extern crate alloc;
use alloc::boxed::Box;
use alloc::vec;
/// Logistic function `1 / (1 + e^-z)`, evaluated in a branch-dependent form.
///
/// For `z >= 0` the value is computed as `1 / (1 + exp(-z))`; otherwise as
/// `exp(z) / (1 + exp(z))`. Either way the argument handed to
/// `crate::math::exp` is never positive, which keeps the intermediate from
/// growing without bound for large `|z|`.
#[inline]
fn sigmoid(z: f64) -> f64 {
    if z >= 0.0 {
        let decay = crate::math::exp(-z);
        return 1.0 / (1.0 + decay);
    }
    // Reached for negative `z` (and for NaN, which fails the comparison above).
    let growth = crate::math::exp(z);
    growth / (1.0 + growth)
}
/// Logistic regression model over `f64` features that is trained one sample
/// at a time via [`LogisticRegressionF64::update`].
///
/// Instances are created through [`LogisticRegressionF64::builder`].
#[derive(Debug, Clone)]
pub struct LogisticRegressionF64 {
/// One coefficient per feature; the builder initialises every entry to zero.
weights: Box<[f64]>,
/// Factor applied to the prediction error when adjusting the weights.
learning_rate: f64,
/// Required length of every feature slice; the builder uses the same value
/// to size `weights`.
dims: usize,
/// Number of `update` calls since construction or the last `reset`.
count: u64,
}
/// Builder for [`LogisticRegressionF64`].
///
/// Both settings start out unset; `build` returns an error if either one is
/// still missing.
#[derive(Debug, Clone)]
pub struct LogisticRegressionF64Builder {
/// Number of features per sample, once set via `dimensions`.
dimensions: Option<usize>,
/// Step-size factor, once set via `learning_rate`.
learning_rate: Option<f64>,
}
impl LogisticRegressionF64 {
    /// Returns a builder with neither `dimensions` nor `learning_rate` set.
    #[inline]
    #[must_use]
    pub fn builder() -> LogisticRegressionF64Builder {
        LogisticRegressionF64Builder {
            dimensions: None,
            learning_rate: None,
        }
    }

    /// Weighted sum of `features` under the current weights.
    ///
    /// Terms are accumulated left to right starting from `0.0`.
    ///
    /// # Panics
    ///
    /// Panics if `features.len()` differs from the configured dimensions.
    #[inline]
    fn linear_score(&self, features: &[f64]) -> f64 {
        assert_eq!(
            features.len(),
            self.dims,
            "feature length {} != dimensions {}",
            features.len(),
            self.dims,
        );
        self.weights
            .iter()
            .zip(features)
            .fold(0.0_f64, |acc, (w, x)| acc + w * x)
    }

    /// Returns the model's probability estimate for `features`, i.e. the
    /// sigmoid of the weighted feature sum.
    ///
    /// # Panics
    ///
    /// Panics if `features.len()` differs from [`Self::dimensions`].
    #[inline]
    #[must_use]
    pub fn predict(&self, features: &[f64]) -> f64 {
        sigmoid(self.linear_score(features))
    }

    /// Performs one gradient step towards `outcome` for the given sample and
    /// increments the update counter.
    ///
    /// Each weight moves by `learning_rate * (target - prediction) * feature`,
    /// where `target` is `1.0` for `true` and `0.0` for `false`.
    ///
    /// # Panics
    ///
    /// Panics if `features.len()` differs from [`Self::dimensions`]. In debug
    /// builds it also panics if any feature is not finite.
    #[inline]
    pub fn update(&mut self, features: &[f64], outcome: bool) {
        debug_assert!(features.iter().all(|f| f.is_finite()), "features must be finite");
        let predicted = sigmoid(self.linear_score(features));
        let target = if outcome { 1.0_f64 } else { 0.0_f64 };
        let step = self.learning_rate * (target - predicted);
        for (w, x) in self.weights.iter_mut().zip(features) {
            *w += step * x;
        }
        self.count += 1;
    }

    /// Current weight vector, one entry per feature.
    #[inline]
    #[must_use]
    pub fn weights(&self) -> &[f64] {
        &self.weights[..]
    }

    /// Number of features each sample must supply.
    #[inline]
    #[must_use]
    pub fn dimensions(&self) -> usize {
        self.dims
    }

    /// Learning rate the model was built with.
    #[inline]
    #[must_use]
    pub fn learning_rate(&self) -> f64 {
        self.learning_rate
    }

    /// Number of updates applied since construction or the last reset.
    #[inline]
    #[must_use]
    pub fn count(&self) -> u64 {
        self.count
    }

    /// `true` once at least one update has been applied.
    #[inline]
    #[must_use]
    pub fn is_primed(&self) -> bool {
        self.count != 0
    }

    /// Zeroes every weight and the update counter; dimensions and learning
    /// rate are left untouched.
    #[inline]
    pub fn reset(&mut self) {
        for w in self.weights.iter_mut() {
            *w = 0.0;
        }
        self.count = 0;
    }
}
impl LogisticRegressionF64Builder {
    /// Sets the number of features each sample carries. Must be at least 1
    /// for [`Self::build`] to succeed.
    #[inline]
    #[must_use]
    pub fn dimensions(mut self, dims: usize) -> Self {
        self.dimensions = Some(dims);
        self
    }

    /// Sets the learning rate. Must be a positive, finite number for
    /// [`Self::build`] to succeed.
    #[inline]
    #[must_use]
    pub fn learning_rate(mut self, lr: f64) -> Self {
        self.learning_rate = Some(lr);
        self
    }

    /// Validates the configuration and creates a model with all weights at
    /// zero and an update count of zero.
    ///
    /// # Errors
    ///
    /// - `ConfigError::Missing("dimensions")` if `dimensions` was never set.
    /// - `ConfigError::Missing("learning_rate")` if `learning_rate` was never set.
    /// - `ConfigError::Invalid(..)` if `dimensions` is zero, or if the learning
    ///   rate is NaN, zero, negative, or infinite.
    #[inline]
    pub fn build(self) -> Result<LogisticRegressionF64, crate::ConfigError> {
        let dims = self
            .dimensions
            .ok_or(crate::ConfigError::Missing("dimensions"))?;
        let lr = self
            .learning_rate
            .ok_or(crate::ConfigError::Missing("learning_rate"))?;
        if dims == 0 {
            return Err(crate::ConfigError::Invalid("dimensions must be >= 1"));
        }
        // NaN compares false against everything, so a bare `lr <= 0.0` would
        // let it through and every weight would turn NaN on the first update.
        if lr.is_nan() || lr <= 0.0 {
            return Err(crate::ConfigError::Invalid(
                "learning_rate must be positive",
            ));
        }
        // An infinite step size drives the weights to inf/NaN immediately.
        if lr.is_infinite() {
            return Err(crate::ConfigError::Invalid(
                "learning_rate must be finite",
            ));
        }
        Ok(LogisticRegressionF64 {
            weights: vec![0.0_f64; dims].into_boxed_slice(),
            learning_rate: lr,
            dims,
            count: 0,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a model with the given settings, panicking if the builder
    /// rejects them.
    fn model(dims: usize, rate: f64) -> LogisticRegressionF64 {
        LogisticRegressionF64::builder()
            .dimensions(dims)
            .learning_rate(rate)
            .build()
            .unwrap()
    }

    /// Asserts that `p` lies in the closed unit interval.
    fn assert_probability(p: f64) {
        assert!((0.0..=1.0).contains(&p), "p = {p}, expected in [0, 1]");
    }

    /// Alternating positive/negative samples around (1, 1) and (-1, -1)
    /// should yield confident predictions further out along the diagonal.
    #[test]
    fn linearly_separable_convergence() {
        let mut classifier = model(2, 0.1);
        for step in 0..2000 {
            let jitter = (step as f64 * 0.37).sin() * 0.3;
            let (sample, label) = if step % 2 == 0 {
                ([1.0 + jitter, 1.0 + jitter], true)
            } else {
                ([-1.0 + jitter, -1.0 + jitter], false)
            };
            classifier.update(&sample, label);
        }
        let p_positive = classifier.predict(&[2.0, 2.0]);
        let p_negative = classifier.predict(&[-2.0, -2.0]);
        assert!(
            p_positive > 0.9,
            "p(true | [2,2]) = {p_positive}, expected > 0.9"
        );
        assert!(
            p_negative < 0.1,
            "p(true | [-2,-2]) = {p_negative}, expected < 0.1"
        );
    }

    /// Predictions stay inside [0, 1] even for extreme inputs.
    #[test]
    fn predict_in_range() {
        let mut classifier = model(2, 0.1);
        assert_probability(classifier.predict(&[100.0, -100.0]));
        classifier.update(&[1.0, 0.0], true);
        assert_probability(classifier.predict(&[1000.0, 0.0]));
        assert_probability(classifier.predict(&[-1000.0, 0.0]));
    }

    /// `reset` zeroes both the weights and the update counter.
    #[test]
    fn reset_clears_weights() {
        let mut classifier = model(2, 0.1);
        classifier.update(&[1.0, 2.0], true);
        assert!(classifier.count() > 0);
        assert!(classifier.weights().iter().any(|&w| w != 0.0));
        classifier.reset();
        assert_eq!(classifier.count(), 0);
        assert!(classifier.weights().iter().all(|&w| w == 0.0));
    }

    #[test]
    #[should_panic(expected = "feature length")]
    fn dimension_mismatch_predict() {
        let classifier = model(3, 0.1);
        let _ = classifier.predict(&[1.0, 2.0]);
    }

    #[test]
    #[should_panic(expected = "feature length")]
    fn dimension_mismatch_update() {
        let mut classifier = model(3, 0.1);
        classifier.update(&[1.0], true);
    }

    #[test]
    fn builder_rejects_zero_dimensions() {
        let outcome = LogisticRegressionF64::builder()
            .dimensions(0)
            .learning_rate(0.1)
            .build();
        assert!(outcome.is_err());
    }

    #[test]
    fn builder_rejects_negative_learning_rate() {
        let outcome = LogisticRegressionF64::builder()
            .dimensions(2)
            .learning_rate(-0.01)
            .build();
        assert!(outcome.is_err());
    }

    #[test]
    fn builder_missing_dimensions() {
        let outcome = LogisticRegressionF64::builder().learning_rate(0.1).build();
        assert!(matches!(
            outcome,
            Err(crate::ConfigError::Missing("dimensions"))
        ));
    }

    #[test]
    fn builder_missing_learning_rate() {
        let outcome = LogisticRegressionF64::builder().dimensions(2).build();
        assert!(matches!(
            outcome,
            Err(crate::ConfigError::Missing("learning_rate"))
        ));
    }

    /// The counter advances by one per update regardless of the label.
    #[test]
    fn count_tracks_updates() {
        let mut classifier = model(1, 0.1);
        assert_eq!(classifier.count(), 0);
        classifier.update(&[1.0], true);
        assert_eq!(classifier.count(), 1);
        classifier.update(&[1.0], false);
        assert_eq!(classifier.count(), 2);
    }
}