inferust 0.1.12

Statistical modeling for Rust — OLS/WLS regression, GLM, survival analysis, ARIMA/VAR, nonparametric tests, and more. A statsmodels-style library.
Documentation
use std::env;
use std::hint::black_box;
use std::time::{Duration, Instant};

use inferust::regression::{Ols, OlsSolver};

/// Benchmarks repeated OLS fits on a synthetic dataset and prints one summary line.
///
/// Command-line options (each accepted as `--flag value` or `--flag=value`):
///
/// * `--rows` – number of observations (default 10 000)
/// * `--features` – number of predictors (default 8)
/// * `--repeats` – number of timed fits, must be at least 1 (default 10)
/// * `--warmups` – number of untimed fits run first (default 2)
/// * `--solver` – solver selection, see `arg_solver`
///
/// The output line reports the minimum, median and mean fit time in
/// milliseconds plus a checksum of the fitted coefficients.
///
/// # Panics
///
/// Panics on malformed arguments, when `--repeats` is 0, or when a fit fails.
fn main() {
    let n = arg_usize("--rows", 10_000);
    let p = arg_usize("--features", 8);
    let repeats = arg_usize("--repeats", 10);
    let warmups = arg_usize("--warmups", 2);
    let solver = arg_solver();

    // With zero timed runs there is no minimum or median to report; reject the
    // request up front instead of hitting an index-out-of-bounds panic below.
    assert!(
        repeats > 0,
        "invalid value for --repeats: must be at least 1"
    );

    let (x, y) = dataset(n, p);
    let feature_names = (0..p).map(|i| format!("x{}", i + 1)).collect::<Vec<_>>();
    let model = Ols::new()
        .with_feature_names(feature_names)
        .with_solver(solver);

    // Untimed runs; `black_box` keeps the optimizer from discarding the result.
    for _ in 0..warmups {
        black_box(model.fit(&x, &y).expect("warmup fit should succeed"));
    }

    let mut timings = Vec::with_capacity(repeats);
    let mut checksum = 0.0;
    for _ in 0..repeats {
        let started = Instant::now();
        let result = model.fit(&x, &y).expect("benchmark fit should succeed");
        let elapsed = started.elapsed();
        // Summing the coefficients happens outside the timed region; printing
        // the total makes every fit's output observable.
        checksum += result.coefficients.iter().sum::<f64>();
        timings.push(elapsed);
    }

    // Ascending order: index 0 is the minimum; `len / 2` is the median (the
    // upper of the two middle elements when the count is even).
    timings.sort_unstable();
    println!(
        "engine=rust-inferust solver={} rows={} features={} repeats={} warmups={} min_ms={:.3} median_ms={:.3} mean_ms={:.3} checksum={:.8}",
        solver_name(solver),
        n,
        p,
        repeats,
        warmups,
        millis(timings[0]),
        millis(timings[timings.len() / 2]),
        mean_ms(&timings),
        checksum
    );
}

/// Generates a deterministic synthetic regression problem.
///
/// Returns `(x, y)` where `x` holds `n` rows of `p` feature values and `y`
/// holds `n` responses. Each feature value is a sum of sine, cosine and a
/// small sawtooth term driven by the row and column indices. Each response is
/// an intercept of `1.5`, plus the features weighted by `(j + 1) / p`, plus a
/// small sinusoidal perturbation. No randomness is involved, so repeated calls
/// with the same arguments produce the same data.
fn dataset(n: usize, p: usize) -> (Vec<Vec<f64>>, Vec<f64>) {
    // True slope for column `j` is `(j + 1) / p`.
    let weights: Vec<f64> = (0..p).map(|j| (j as f64 + 1.0) / p as f64).collect();

    (0..n)
        .map(|i| {
            let row: Vec<f64> = (0..p)
                .map(|j| {
                    ((i + 1) as f64 * (j + 1) as f64 * 0.001).sin()
                        + ((i + j + 3) as f64 * 0.017).cos()
                        + ((i % 97) as f64) * 0.0001
                })
                .collect();

            // Accumulate left to right starting from the intercept so the
            // floating-point rounding sequence is fixed.
            let signal = weights
                .iter()
                .zip(&row)
                .fold(1.5, |acc, (weight, value)| acc + weight * value);
            let perturbation = ((i + 11) as f64 * 0.037).sin() * 0.01;

            (row, signal + perturbation)
        })
        .unzip()
}

/// Selects the OLS solver from the command line.
///
/// Looks for the first `--solver VALUE` or `--solver=VALUE` among the process
/// arguments (program name excluded). `cholesky`/`fast` select
/// `OlsSolver::Cholesky`, `svd`/`stable` select `OlsSolver::Svd`. When the
/// flag is absent the Cholesky solver is used.
///
/// # Panics
///
/// Panics when `--solver` is the last argument (no value follows) or when the
/// value is not one of the accepted names.
fn arg_solver() -> OlsSolver {
    let mut cli = env::args().skip(1);

    // Pull out the raw text of the first `--solver` occurrence.
    let value = loop {
        let token = match cli.next() {
            Some(token) => token,
            // Flag never appeared: fall back to the default solver.
            None => return OlsSolver::Cholesky,
        };

        if token == "--solver" {
            break match cli.next() {
                Some(next) => next,
                None => panic!("missing value for --solver"),
            };
        }
        if let Some(inline) = token.strip_prefix("--solver=") {
            break inline.to_string();
        }
    };

    match value.as_str() {
        "cholesky" | "fast" => OlsSolver::Cholesky,
        "svd" | "stable" => OlsSolver::Svd,
        _ => panic!("invalid value for --solver: {value}"),
    }
}

/// Returns the lowercase label printed in the `solver=` field of the report.
///
/// The match has no catch-all arm, so adding a variant to `OlsSolver` forces
/// this function to be updated.
fn solver_name(solver: OlsSolver) -> &'static str {
    match solver {
        OlsSolver::Svd => "svd",
        OlsSolver::Cholesky => "cholesky",
    }
}

/// Reads a `usize` option from the command line.
///
/// Scans the process arguments (program name excluded) for the first
/// occurrence of `flag`, written either as `flag VALUE` or `flag=VALUE`, and
/// parses the value as a `usize`. Returns `default` when the flag does not
/// appear.
///
/// # Panics
///
/// Panics when `flag` is the last argument (no value follows) or when the
/// value cannot be parsed as a `usize`; the message names the flag and, for
/// parse failures, the offending text.
fn arg_usize(flag: &str, default: usize) -> usize {
    let mut args = env::args().skip(1);
    while let Some(arg) = args.next() {
        let value = if arg == flag {
            args.next()
                .unwrap_or_else(|| panic!("missing value for {flag}"))
        } else if let Some(inline) = arg
            .strip_prefix(flag)
            // Match `flag=` without building a new string for every argument.
            .and_then(|rest| rest.strip_prefix('='))
        {
            inline.to_string()
        } else {
            continue;
        };

        return value
            .parse()
            .unwrap_or_else(|_| panic!("invalid value for {flag}: {value}"));
    }
    default
}

/// Converts a [`Duration`] to fractional milliseconds.
fn millis(duration: Duration) -> f64 {
    let seconds = duration.as_secs_f64();
    seconds * 1_000.0
}

fn mean_ms(durations: &[Duration]) -> f64 {
    durations.iter().map(|d| millis(*d)).sum::<f64>() / durations.len() as f64
}