inferust 0.1.12

Statistical modeling for Rust — OLS/WLS regression, GLM, survival analysis, ARIMA/VAR, nonparametric tests, and more. A statsmodels-style library.
Documentation
use crate::error::{InferustError, Result};

/// Summary of how closely a set of predictions matches observed values.
///
/// Produced by [`regression_metrics`]. Note that the derived `PartialEq` follows IEEE
/// float semantics, so two values whose `mape` is `NaN` never compare equal.
#[derive(Debug, Clone, PartialEq)]
pub struct RegressionMetrics {
    /// Mean absolute error: the average of `|actual - predicted|`.
    pub mae: f64,
    /// Mean squared error: the average of `(actual - predicted)^2`.
    pub mse: f64,
    /// Root mean squared error, `mse.sqrt()`.
    pub rmse: f64,
    /// Mean absolute percentage error expressed as a fraction (0.1 means 10%), averaged
    /// over observations whose absolute actual value exceeds `f64::EPSILON`; `NaN` when
    /// there are no such observations.
    pub mape: f64,
    /// Coefficient of determination (R²), `1 - SS_res / SS_tot`.
    pub r_squared: f64,
}

/// Outcome counts of a binary classifier evaluated at a single probability threshold.
///
/// Produced by [`confusion_matrix`].
#[derive(Debug, Clone, PartialEq)]
pub struct ConfusionMatrix {
    /// Cutoff that was applied: a prediction is positive when `probability >= threshold`.
    pub threshold: f64,
    /// Predicted positive, and the actual label is `1.0`.
    pub true_positives: usize,
    /// Predicted positive, but the actual label is not `1.0`.
    pub false_positives: usize,
    /// Predicted negative, and the actual label is not `1.0`.
    pub true_negatives: usize,
    /// Predicted negative, but the actual label is `1.0`.
    pub false_negatives: usize,
}

/// Point estimate of a mean together with a bootstrap percentile interval around it.
///
/// Produced by [`bootstrap_mean_interval`].
#[derive(Debug, Clone, PartialEq)]
pub struct BootstrapInterval {
    /// Mean of the original (non-resampled) data.
    pub estimate: f64,
    /// Lower bound: the `alpha / 2` empirical quantile of the resample means.
    pub lower: f64,
    /// Upper bound: the `1 - alpha / 2` empirical quantile of the resample means.
    pub upper: f64,
}

/// Computes common regression accuracy metrics for `y_pred` against `y_true`.
///
/// The returned [`RegressionMetrics`] contains:
///
/// * `mae` – mean absolute error;
/// * `mse` / `rmse` – mean squared error and its square root;
/// * `mape` – mean absolute percentage error as a fraction (0.1 means 10%), averaged
///   only over observations with `|actual| > f64::EPSILON`; `NaN` when there are none;
/// * `r_squared` – coefficient of determination, `1 - SS_res / SS_tot`.
///
/// R² is undefined when `y_true` is constant (`SS_tot == 0`). In that case this
/// function follows scikit-learn's finite convention and reports `1.0` for a perfect
/// prediction and `0.0` otherwise.
///
/// # Errors
///
/// Returns an error when the slices have different lengths or are empty.
pub fn regression_metrics(y_true: &[f64], y_pred: &[f64]) -> Result<RegressionMetrics> {
    validate_same_len(y_true, y_pred)?;
    let n = y_true.len() as f64;
    let mean = y_true.iter().sum::<f64>() / n;

    // Residuals are re-derived lazily for each aggregate instead of being buffered.
    let residuals = || y_true.iter().zip(y_pred).map(|(actual, pred)| actual - pred);
    let mae = residuals().map(f64::abs).sum::<f64>() / n;
    let ss_res = residuals().map(|residual| residual * residual).sum::<f64>();
    let mse = ss_res / n;

    // MAPE skips (near-)zero actuals, which would otherwise divide by zero.
    let (ape_sum, ape_count) = y_true
        .iter()
        .zip(y_pred)
        .filter(|(actual, _)| actual.abs() > f64::EPSILON)
        .fold((0.0_f64, 0_usize), |(sum, count), (actual, pred)| {
            (sum + ((actual - pred) / actual).abs(), count + 1)
        });
    let mape = if ape_count == 0 {
        f64::NAN
    } else {
        ape_sum / ape_count as f64
    };

    let ss_tot = y_true
        .iter()
        .map(|actual| (actual - mean).powi(2))
        .sum::<f64>();
    // A constant target makes SS_tot zero in exact arithmetic, but the floating-point
    // mean can be off by an ulp and leave a tiny nonzero SS_tot, so the degenerate case
    // is detected from the data itself. Clamping SS_tot to a fixed floor instead would
    // distort R² for legitimately small-scale targets (R² must be scale invariant).
    let constant_target = ss_tot == 0.0 || y_true.iter().all(|&value| value == y_true[0]);
    let r_squared = if constant_target {
        if ss_res == 0.0 {
            1.0
        } else {
            0.0
        }
    } else {
        1.0 - ss_res / ss_tot
    };

    Ok(RegressionMetrics {
        mae,
        mse,
        rmse: mse.sqrt(),
        mape,
        r_squared,
    })
}

/// Tallies a binary confusion matrix by thresholding predicted probabilities.
///
/// An observation is predicted positive when its probability is `>= threshold` (a NaN
/// probability therefore always counts as predicted negative). It is actually positive
/// only when its label is exactly `1.0`; every other label value, including `0.0`, is
/// counted as an actual negative.
///
/// # Errors
///
/// Returns an error if the slices differ in length or are empty, or if `threshold`
/// lies outside the closed interval `[0, 1]` (or is NaN).
pub fn confusion_matrix(
    y_true: &[f64],
    probabilities: &[f64],
    threshold: f64,
) -> Result<ConfusionMatrix> {
    validate_same_len(y_true, probabilities)?;
    // Inclusive upper bound: 1.0 is a valid (strictest) cutoff and matches the error
    // message. `contains` is false for NaN, so a NaN threshold is rejected as well.
    if !(0.0..=1.0).contains(&threshold) {
        return Err(InferustError::InvalidInput(
            "threshold must be between 0 and 1".into(),
        ));
    }
    let mut matrix = ConfusionMatrix {
        threshold,
        true_positives: 0,
        false_positives: 0,
        true_negatives: 0,
        false_negatives: 0,
    };
    for (&actual, &probability) in y_true.iter().zip(probabilities) {
        match (probability >= threshold, actual == 1.0) {
            (true, true) => matrix.true_positives += 1,
            (true, false) => matrix.false_positives += 1,
            (false, false) => matrix.true_negatives += 1,
            (false, true) => matrix.false_negatives += 1,
        }
    }
    Ok(matrix)
}

/// Estimates a percentile bootstrap confidence interval for the mean of `data`.
///
/// Draws `resamples` samples of `data.len()` values with replacement and computes each
/// sample's mean. It then sorts those means and reads the `alpha / 2` and
/// `1 - alpha / 2` empirical quantiles, using the rounded index into the sorted means.
/// `estimate` is the plain mean of `data`.
///
/// The resampling generator uses a fixed seed, so the result is deterministic for a
/// given input.
///
/// # Errors
///
/// * [`InferustError::InsufficientData`] if `data` is empty or `resamples < 2`.
/// * [`InferustError::InvalidInput`] if `alpha` is not in `[0, 1)` (or is NaN), or if
///   `data` contains NaN or infinite values.
pub fn bootstrap_mean_interval(
    data: &[f64],
    resamples: usize,
    alpha: f64,
) -> Result<BootstrapInterval> {
    if data.is_empty() {
        return Err(InferustError::InsufficientData { needed: 1, got: 0 });
    }
    if resamples < 2 {
        return Err(InferustError::InsufficientData {
            needed: 2,
            got: resamples,
        });
    }
    if !(0.0..1.0).contains(&alpha) {
        return Err(InferustError::InvalidInput(
            "alpha must be between 0 and 1".into(),
        ));
    }
    // Non-finite inputs would yield NaN resample means, which have no meaningful
    // position in the percentile ordering.
    if data.iter().any(|value| !value.is_finite()) {
        return Err(InferustError::InvalidInput(
            "data must contain only finite values".into(),
        ));
    }

    let n = data.len();
    let estimate = data.iter().sum::<f64>() / n as f64;
    // Fixed seed keeps the interval reproducible across calls.
    let mut rng = Lcg::new(0x5eed_u64);
    let mut means: Vec<f64> = (0..resamples)
        .map(|_| {
            let total = (0..n).fold(0.0, |acc, _| acc + data[rng.next_index(n)]);
            total / n as f64
        })
        .collect();
    // `total_cmp` is a lawful total order, so the sort can never observe an inconsistent
    // comparator (which `sort_by`/`sort_unstable_by` are allowed to panic on).
    means.sort_unstable_by(f64::total_cmp);

    // Both indices stay within 0..resamples because alpha is in [0, 1).
    let last = (resamples - 1) as f64;
    let lower_idx = ((alpha / 2.0) * last).round() as usize;
    let upper_idx = ((1.0 - alpha / 2.0) * last).round() as usize;
    Ok(BootstrapInterval {
        estimate,
        lower: means[lower_idx],
        upper: means[upper_idx],
    })
}

/// Checks that two paired slices can be evaluated together.
///
/// A length mismatch is reported first: `DimensionMismatch` carries `right`'s length as
/// `x_rows` and `left`'s length as `y_len`. Equal but empty slices yield
/// `InsufficientData { needed: 1, got: 0 }`.
fn validate_same_len(left: &[f64], right: &[f64]) -> Result<()> {
    let (left_len, right_len) = (left.len(), right.len());
    match (left_len == right_len, left_len) {
        (false, _) => Err(InferustError::DimensionMismatch {
            x_rows: right_len,
            y_len: left_len,
        }),
        (true, 0) => Err(InferustError::InsufficientData { needed: 1, got: 0 }),
        (true, _) => Ok(()),
    }
}

/// Small deterministic 64-bit linear congruential generator used for resampling.
///
/// The recurrence is `state = state * MULTIPLIER + 1 (mod 2^64)`.
struct Lcg {
    /// Current generator state; advanced on every draw.
    state: u64,
}

impl Lcg {
    /// Multiplier of the LCG recurrence.
    const MULTIPLIER: u64 = 6_364_136_223_846_793_005;

    /// Creates a generator whose initial state is exactly `seed`.
    fn new(seed: u64) -> Self {
        Lcg { state: seed }
    }

    /// Advances the generator and returns an index in `0..len`.
    ///
    /// Only the upper 32 bits of the new state are used, because the low-order bits of
    /// a power-of-two-modulus LCG have short periods.
    ///
    /// # Panics
    ///
    /// Panics if `len == 0` (remainder by zero).
    fn next_index(&mut self, len: usize) -> usize {
        let advanced = self.state.wrapping_mul(Self::MULTIPLIER).wrapping_add(1);
        self.state = advanced;
        let upper_bits = (advanced >> 32) as usize;
        upper_bits % len
    }
}

#[cfg(test)]
mod tests {
    use super::{bootstrap_mean_interval, confusion_matrix, regression_metrics};

    /// Smoke test exercising each public entry point on small, hand-written inputs.
    #[test]
    fn computes_evaluation_metrics() {
        let actual = [1.0, 2.0, 3.0, 4.0];
        let predicted = [1.1, 1.9, 3.2, 3.8];

        let metrics = regression_metrics(&actual, &predicted).unwrap();
        assert!(metrics.rmse > 0.0);

        let labels = [0.0, 1.0, 1.0];
        let scores = [0.2, 0.7, 0.4];
        let matrix = confusion_matrix(&labels, &scores, 0.5).unwrap();
        assert_eq!(matrix.true_positives, 1);

        let interval = bootstrap_mean_interval(&actual, 50, 0.05).unwrap();
        assert!(interval.lower <= interval.upper);
    }

    /// When every actual value is zero, MAPE has no valid terms and is reported as NaN.
    #[test]
    fn regression_metrics_handles_zero_actual_mape() {
        let all_zero = [0.0, 0.0];
        let predicted = [0.1, 0.2];
        let metrics = regression_metrics(&all_zero, &predicted).unwrap();
        assert!(metrics.mape.is_nan());
    }
}