use crate::solvers::binomial::{BinomialRegressor, FittedBinomial};
use crate::solvers::traits::{FittedRegressor, RegressionError, Regressor};
use faer::{Col, Mat};
/// Regularization applied to the logistic-regression coefficients.
///
/// The default is [`Penalty::None`] (an unpenalized fit).
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub enum Penalty {
    /// No regularization.
    #[default]
    None,
    /// Ridge (L2) penalty; the payload is the strength handed to the solver as `lambda`.
    L2(f64),
}
/// Binary logistic-regression classifier configuration.
///
/// Fitting delegates to [`BinomialRegressor::logistic`]. This type validates the
/// response, forwards the solver settings, and wraps the result in a
/// [`FittedLogistic`] that applies the configured probability threshold.
/// Construct instances with [`LogisticRegression::builder`].
#[derive(Debug, Clone)]
pub struct LogisticRegression {
    /// Coefficient regularization; an L2 strength is forwarded to the solver as `lambda`.
    penalty: Penalty,
    /// Whether the solver fits an intercept term.
    with_intercept: bool,
    /// Probability cut-off copied into the fitted model for [`FittedLogistic::predict`].
    threshold: f64,
    /// Iteration limit forwarded to the solver.
    max_iterations: usize,
    /// Convergence tolerance forwarded to the solver.
    tolerance: f64,
    /// Whether the solver computes inference statistics.
    compute_inference: bool,
    /// Confidence level forwarded to the solver.
    confidence_level: f64,
}

impl LogisticRegression {
    /// Returns a [`LogisticRegressionBuilder`] initialised with its default settings.
    pub fn builder() -> LogisticRegressionBuilder {
        LogisticRegressionBuilder::default()
    }

    /// Fits the model to the design matrix `x` and the binary response `y`.
    ///
    /// An L2 penalty strength of `0.0` is accepted and forwarded as-is.
    ///
    /// # Errors
    ///
    /// * [`RegressionError::NumericalError`] if any entry of `y` is not exactly
    ///   `0.0` or `1.0` (NaN included). The message reports the first offending
    ///   value and its index.
    /// * [`RegressionError::NumericalError`] if the penalty is [`Penalty::L2`] with a
    ///   negative, NaN or infinite strength.
    /// * Any error returned by the underlying binomial solver, propagated unchanged.
    pub fn fit(&self, x: &Mat<f64>, y: &Col<f64>) -> Result<FittedLogistic, RegressionError> {
        // Report the first non-binary label so callers can locate bad data quickly.
        // NaN compares unequal to both 0.0 and 1.0, so it is rejected here too.
        if let Some((i, &val)) = y
            .iter()
            .enumerate()
            .find(|&(_, &v)| v != 0.0 && v != 1.0)
        {
            return Err(RegressionError::NumericalError(format!(
                "y must be binary (0.0 or 1.0), found {} at index {}",
                val, i
            )));
        }

        // The builder's `penalty()` and `l2()` setters store any f64 unchecked, so an
        // invalid ridge strength is rejected here before it reaches the solver.
        // The match is exhaustive on purpose: a new `Penalty` variant must be handled.
        let lambda = match self.penalty {
            Penalty::None => None,
            Penalty::L2(lambda) if lambda.is_finite() && lambda >= 0.0 => Some(lambda),
            Penalty::L2(lambda) => {
                return Err(RegressionError::NumericalError(format!(
                    "L2 penalty strength must be finite and non-negative, got {}",
                    lambda
                )));
            }
        };

        let mut builder = BinomialRegressor::logistic()
            .with_intercept(self.with_intercept)
            .max_iterations(self.max_iterations)
            .tolerance(self.tolerance)
            .compute_inference(self.compute_inference)
            .confidence_level(self.confidence_level);
        if let Some(lambda) = lambda {
            builder = builder.lambda(lambda);
        }
        let binomial = builder.build();
        let inner = binomial.fit(x, y)?;
        Ok(FittedLogistic {
            inner,
            threshold: self.threshold,
        })
    }
}
/// A logistic-regression model produced by [`LogisticRegression::fit`].
///
/// Holds the fitted binomial model together with the probability threshold used
/// to turn predicted probabilities into class labels.
#[derive(Debug)]
pub struct FittedLogistic {
    /// The fitted binomial (logit-link) model that does the actual prediction.
    inner: FittedBinomial,
    /// Rows with predicted probability `>= threshold` are labelled `1.0`.
    threshold: f64,
}

impl FittedLogistic {
    /// Predicts a class label (`0.0` or `1.0`) for each row of `x`.
    ///
    /// A row is labelled `1.0` when its predicted probability is greater than or
    /// equal to the configured threshold, and `0.0` otherwise. A NaN probability
    /// fails the comparison and therefore yields `0.0`.
    pub fn predict(&self, x: &Mat<f64>) -> Col<f64> {
        let probs = self.predict_proba(x);
        let threshold = self.threshold;
        Col::from_fn(probs.nrows(), |i| if probs[i] >= threshold { 1.0 } else { 0.0 })
    }

    /// Returns, for each row of `x`, the probability that [`predict`](Self::predict)
    /// compares against the threshold (i.e. the class-`1.0` probability).
    pub fn predict_proba(&self, x: &Mat<f64>) -> Col<f64> {
        self.inner.predict_probability(x)
    }

    /// Returns the linear predictor for each row of `x`, as computed by the fitted
    /// binomial model. Under the logit link this is presumably the log-odds of
    /// class `1.0`.
    pub fn decision_function(&self, x: &Mat<f64>) -> Col<f64> {
        self.inner.predict_linear(x)
    }

    /// Returns the classification accuracy of [`predict`](Self::predict) on `x`
    /// against the labels `y`.
    ///
    /// Labels are compared exactly, so `y` is expected to contain `0.0`/`1.0`
    /// values. For empty input the result is `0.0 / 0.0`, i.e. NaN.
    ///
    /// # Panics
    ///
    /// Panics if `x` and `y` have different numbers of rows. An accuracy computed
    /// over mismatched inputs would be meaningless.
    pub fn score(&self, x: &Mat<f64>, y: &Col<f64>) -> f64 {
        // Without this check `zip` would silently stop at the shorter input while the
        // denominator below still uses `y.nrows()`, producing a wrong accuracy.
        assert_eq!(
            x.nrows(),
            y.nrows(),
            "score: x has {} rows but y has {} labels",
            x.nrows(),
            y.nrows()
        );
        let predictions = self.predict(x);
        let correct = predictions
            .iter()
            .zip(y.iter())
            .filter(|(&pred, &actual)| pred == actual)
            .count();
        correct as f64 / y.nrows() as f64
    }

    /// Returns the fitted feature coefficients. The intercept, if any, is not
    /// included; see [`intercept`](Self::intercept).
    pub fn coefficients(&self) -> &Col<f64> {
        &self.inner.result().coefficients
    }

    /// Returns the fitted intercept, or `None` when the model was fitted without one.
    pub fn intercept(&self) -> Option<f64> {
        self.inner.result().intercept
    }

    /// Borrows the underlying fitted binomial model, e.g. to reach the full
    /// [`FittedBinomial::result`].
    pub fn inner(&self) -> &FittedBinomial {
        &self.inner
    }

    /// Returns the iteration count recorded by the underlying binomial solver.
    pub fn n_iter(&self) -> usize {
        self.inner.iterations
    }
}
/// Fluent builder for [`LogisticRegression`].
///
/// Start from [`LogisticRegression::builder`] (or `Default::default()`), chain the
/// setters you need, then finish with [`build`](Self::build). Apart from
/// [`c`](Self::c), the setters store their argument without validation.
#[derive(Debug, Clone)]
pub struct LogisticRegressionBuilder {
    /// Regularization to apply; see [`Penalty`].
    penalty: Penalty,
    /// Whether to fit an intercept.
    with_intercept: bool,
    /// Probability cut-off for class-`1.0` predictions.
    threshold: f64,
    /// Solver iteration limit.
    max_iterations: usize,
    /// Solver convergence tolerance.
    tolerance: f64,
    /// Whether to compute inference statistics.
    compute_inference: bool,
    /// Confidence level used for inference.
    confidence_level: f64,
}

impl Default for LogisticRegressionBuilder {
    /// The defaults are:
    /// * no penalty and an intercept;
    /// * threshold `0.5`;
    /// * at most `100` iterations with tolerance `1e-8`;
    /// * inference enabled at a `0.95` confidence level.
    fn default() -> Self {
        Self {
            // Model specification.
            penalty: Penalty::None,
            with_intercept: true,
            threshold: 0.5,
            // Solver controls.
            max_iterations: 100,
            tolerance: 1e-8,
            // Inference settings.
            compute_inference: true,
            confidence_level: 0.95,
        }
    }
}

impl LogisticRegressionBuilder {
    /// Sets the regularization penalty.
    pub fn penalty(self, penalty: Penalty) -> Self {
        Self { penalty, ..self }
    }

    /// Shorthand for `penalty(Penalty::L2(lambda))`.
    pub fn l2(self, lambda: f64) -> Self {
        self.penalty(Penalty::L2(lambda))
    }

    /// Sets an L2 penalty from an inverse regularization strength `C`, storing
    /// `Penalty::L2(1.0 / c_value)`.
    ///
    /// Whether this matches scikit-learn's `C` exactly depends on how the binomial
    /// solver scales `lambda` — not verified here.
    ///
    /// # Panics
    ///
    /// Panics unless `c_value > 0.0` (so NaN is rejected as well).
    pub fn c(self, c_value: f64) -> Self {
        assert!(c_value > 0.0, "C must be positive, got {}", c_value);
        self.l2(1.0 / c_value)
    }

    /// Sets whether an intercept term is fitted.
    pub fn with_intercept(self, include: bool) -> Self {
        Self {
            with_intercept: include,
            ..self
        }
    }

    /// Sets the probability threshold: `FittedLogistic::predict` labels a row `1.0`
    /// when its probability is `>= threshold`.
    pub fn threshold(self, threshold: f64) -> Self {
        Self { threshold, ..self }
    }

    /// Sets the solver's maximum number of iterations.
    pub fn max_iterations(self, max_iter: usize) -> Self {
        Self {
            max_iterations: max_iter,
            ..self
        }
    }

    /// Sets the solver's convergence tolerance.
    pub fn tolerance(self, tol: f64) -> Self {
        Self {
            tolerance: tol,
            ..self
        }
    }

    /// Sets whether inference statistics are computed during fitting.
    pub fn compute_inference(self, compute: bool) -> Self {
        Self {
            compute_inference: compute,
            ..self
        }
    }

    /// Sets the confidence level used for inference.
    pub fn confidence_level(self, level: f64) -> Self {
        Self {
            confidence_level: level,
            ..self
        }
    }

    /// Consumes the builder and returns the configured [`LogisticRegression`].
    pub fn build(self) -> LogisticRegression {
        let Self {
            penalty,
            with_intercept,
            threshold,
            max_iterations,
            tolerance,
            compute_inference,
            confidence_level,
        } = self;
        LogisticRegression {
            penalty,
            with_intercept,
            threshold,
            max_iterations,
            tolerance,
            compute_inference,
            confidence_level,
        }
    }
}
#[cfg(test)]
mod tests {
// Unit tests for the logistic-regression wrapper. All fitting tests use the
// deterministic data from `create_test_data`, so results are reproducible.
use super::*;
// Builds `n` one-feature samples with x evenly spaced as i / n * 4 - 2 (from -2.0
// up to, but excluding, 2.0). A label is 1.0 when sigmoid(x) exceeds a threshold
// that cycles through 0.40, 0.45, 0.50, 0.55, 0.60 with `i % 5`. This deterministic
// jitter mixes the labels around x = 0. For n = 100 (as used below) the two classes
// overlap, so the unpenalized fit is not facing perfectly separable data.
fn create_test_data(n: usize) -> (Mat<f64>, Col<f64>) {
let x = Mat::from_fn(n, 1, |i, _| (i as f64) / (n as f64) * 4.0 - 2.0);
let y = Col::from_fn(n, |i| {
// Recomputes the same x value as the matrix above for sample i.
let xi = (i as f64) / (n as f64) * 4.0 - 2.0;
let prob = 1.0 / (1.0 + (-xi).exp());
if prob > 0.5 + 0.1 * ((i % 5) as f64 - 2.0) / 2.0 {
1.0
} else {
0.0
}
});
(x, y)
}
// The builder defaults must match the documented configuration.
#[test]
fn test_logistic_defaults() {
let builder = LogisticRegressionBuilder::default();
assert_eq!(builder.penalty, Penalty::None);
assert!(builder.with_intercept);
assert!((builder.threshold - 0.5).abs() < f64::EPSILON);
assert_eq!(builder.max_iterations, 100);
assert!((builder.tolerance - 1e-8).abs() < f64::EPSILON);
assert!(builder.compute_inference);
assert!((builder.confidence_level - 0.95).abs() < f64::EPSILON);
}
// Labels increase with x, so the slope must be positive. Training accuracy
// should be high, and the solver should report at least one iteration.
#[test]
fn test_logistic_binary_classification() {
let (x, y) = create_test_data(100);
let model = LogisticRegression::builder().build();
let fitted = model.fit(&x, &y).expect("model should fit");
assert!(
fitted.coefficients()[0] > 0.0,
"coefficient should be positive, got {}",
fitted.coefficients()[0]
);
assert!(fitted.intercept().is_some());
let acc = fitted.score(&x, &y);
assert!(acc > 0.8, "accuracy should be > 0.8, got {}", acc);
assert!(fitted.n_iter() > 0);
}
// Every predicted probability must lie in the closed interval [0, 1].
#[test]
fn test_predict_proba_range() {
let (x, y) = create_test_data(100);
let fitted = LogisticRegression::builder()
.build()
.fit(&x, &y)
.expect("model should fit");
let probs = fitted.predict_proba(&x);
for i in 0..probs.nrows() {
assert!(
probs[i] >= 0.0 && probs[i] <= 1.0,
"probability at index {} is {}, expected [0, 1]",
i,
probs[i]
);
}
}
// `predict` must only ever emit the two class labels.
#[test]
fn test_predict_class_labels() {
let (x, y) = create_test_data(100);
let fitted = LogisticRegression::builder()
.build()
.fit(&x, &y)
.expect("model should fit");
let labels = fitted.predict(&x);
for i in 0..labels.nrows() {
assert!(
labels[i] == 0.0 || labels[i] == 1.0,
"label at index {} is {}, expected 0.0 or 1.0",
i,
labels[i]
);
}
}
// With the default 0.5 threshold, the sign of the linear predictor should agree
// with the predicted label. This assumes predict_proba is the logistic transform
// of decision_function — TODO confirm against FittedBinomial. A decision of
// exactly 0 (probability exactly 0.5) is skipped.
#[test]
fn test_decision_function_sign() {
let (x, y) = create_test_data(100);
let fitted = LogisticRegression::builder()
.build()
.fit(&x, &y)
.expect("model should fit");
let decision = fitted.decision_function(&x);
let labels = fitted.predict(&x);
for i in 0..decision.nrows() {
if decision[i] > 0.0 {
assert_eq!(
labels[i], 1.0,
"positive decision at index {} should yield class 1",
i
);
} else if decision[i] < 0.0 {
assert_eq!(
labels[i], 0.0,
"negative decision at index {} should yield class 0",
i
);
}
}
}
// `score` must equal accuracy recomputed by hand from `predict`.
#[test]
fn test_score_accuracy() {
let (x, y) = create_test_data(100);
let fitted = LogisticRegression::builder()
.build()
.fit(&x, &y)
.expect("model should fit");
let acc = fitted.score(&x, &y);
let predictions = fitted.predict(&x);
let n = y.nrows();
let correct = predictions
.iter()
.zip(y.iter())
.filter(|(&p, &a)| p == a)
.count();
let expected_acc = correct as f64 / n as f64;
assert!(
(acc - expected_acc).abs() < f64::EPSILON,
"score() returned {}, expected {}",
acc,
expected_acc
);
}
// The threshold only affects prediction, not fitting. On identical data, a higher
// threshold can therefore never label more rows as 1.
#[test]
fn test_threshold() {
let (x, y) = create_test_data(100);
let fitted_default = LogisticRegression::builder()
.threshold(0.5)
.build()
.fit(&x, &y)
.expect("model should fit");
let fitted_high = LogisticRegression::builder()
.threshold(0.9)
.build()
.fit(&x, &y)
.expect("model should fit");
let labels_default = fitted_default.predict(&x);
let labels_high = fitted_high.predict(&x);
let count_ones_default: usize = labels_default.iter().filter(|&&v| v == 1.0).count();
let count_ones_high: usize = labels_high.iter().filter(|&&v| v == 1.0).count();
assert!(
count_ones_high <= count_ones_default,
"higher threshold should predict fewer 1s: {} (threshold=0.9) vs {} (threshold=0.5)",
count_ones_high,
count_ones_default
);
}
// A strong ridge penalty (lambda = 10) should shrink the slope's magnitude
// relative to the unpenalized fit.
#[test]
fn test_l2_regularization() {
let (x, y) = create_test_data(100);
let fitted_no_reg = LogisticRegression::builder()
.build()
.fit(&x, &y)
.expect("model should fit without regularization");
let fitted_l2 = LogisticRegression::builder()
.l2(10.0)
.build()
.fit(&x, &y)
.expect("model should fit with L2 regularization");
let coef_no_reg = fitted_no_reg.coefficients()[0].abs();
let coef_l2 = fitted_l2.coefficients()[0].abs();
assert!(
coef_l2 < coef_no_reg,
"L2 regularization should shrink coefficient: |{}| (L2) should be < |{}| (no reg)",
coef_l2,
coef_no_reg
);
}
// y = 0.0, 0.1, 0.2, ... contains non-binary values (first at index 1), so fit
// must return an error.
#[test]
fn test_invalid_y_values() {
let x = Mat::from_fn(10, 1, |i, _| i as f64);
let y = Col::from_fn(10, |i| i as f64 * 0.1);
let model = LogisticRegression::builder().build();
let result = model.fit(&x, &y);
assert!(result.is_err(), "should reject non-binary y values");
}
// C = 0.1 maps to lambda = 1 / C = 10. In f64, 1.0 / 0.1 rounds to exactly 10.0,
// so exact equality is safe here.
#[test]
fn test_c_convention() {
let builder = LogisticRegressionBuilder::default().c(0.1);
assert_eq!(builder.penalty, Penalty::L2(10.0));
}
// One feature yields one coefficient; the intercept (enabled by default) is not
// counted among the coefficients.
#[test]
fn test_inner_access() {
let (x, y) = create_test_data(100);
let fitted = LogisticRegression::builder()
.build()
.fit(&x, &y)
.expect("model should fit");
let inner = fitted.inner();
let result = inner.result();
assert_eq!(result.coefficients.nrows(), 1);
}
}