scry-learn 0.1.0

Machine learning toolkit in pure Rust
Documentation
#![allow(clippy::needless_range_loop)]
//! Golden reference tests: scry-learn vs sklearn on real UCI datasets.
//!
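//! These tests load the canonical Iris, Wine, Breast Cancer, Digits, and
//! California Housing datasets from CSV fixtures generated by sklearn,
//! train scry-learn models with matching hyperparameters, and assert that
//! the resulting metrics (accuracy, R², inertia, explained variance ratio)
//! agree with sklearn's recorded values within tolerance or, for Digits,
//! clear a fixed threshold.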

use std::path::PathBuf;

// ─── Fixture loading helpers ─────────────────────────────────────────────

/// Directory holding the sklearn-generated fixtures: `<crate root>/tests/fixtures`.
///
/// The crate root is taken from `CARGO_MANIFEST_DIR` at compile time.
fn fixtures_dir() -> PathBuf {
    let mut dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    dir.push("tests");
    dir.push("fixtures");
    dir
}

/// Read the feature CSV `name` from the fixtures directory.
///
/// Returns `(columns, headers)`: `columns[j][i]` is the value of feature `j`
/// for sample `i` (column-major, as passed to `Dataset::new`), and `headers`
/// holds the header names in file order.
///
/// # Panics
/// Panics if the file cannot be opened, a record cannot be read, or a field
/// does not parse as `f64`.
fn load_features_csv(name: &str) -> (Vec<Vec<f64>>, Vec<String>) {
    let csv_path = fixtures_dir().join(name);
    let mut reader = match csv::Reader::from_path(&csv_path) {
        Ok(r) => r,
        Err(e) => panic!("Failed to open {}: {e}", csv_path.display()),
    };

    let headers: Vec<String> = reader
        .headers()
        .unwrap()
        .iter()
        .map(|h| h.to_string())
        .collect();
    let n_cols = headers.len();

    // Parse every record into one row of floats (row-major for now).
    let rows: Vec<Vec<f64>> = reader
        .records()
        .map(|rec| {
            let record = rec.unwrap();
            record
                .iter()
                .map(|field| field.parse::<f64>().unwrap())
                .collect::<Vec<f64>>()
        })
        .collect();

    // Flip rows into columns: one vector per feature, one entry per sample.
    let n_rows = rows.len();
    let mut columns = vec![vec![0.0; n_rows]; n_cols];
    for (sample, row) in rows.iter().enumerate() {
        for (feature, value) in row.iter().copied().enumerate() {
            columns[feature][sample] = value;
        }
    }

    (columns, headers)
}

/// Read the target CSV `name` from the fixtures directory.
///
/// Only the first field of each record is used; it is parsed as `f64`.
///
/// # Panics
/// Panics if the file cannot be opened, a record cannot be read, or the
/// first field does not parse as `f64`.
fn load_target_csv(name: &str) -> Vec<f64> {
    let csv_path = fixtures_dir().join(name);
    let mut reader = match csv::Reader::from_path(&csv_path) {
        Ok(r) => r,
        Err(e) => panic!("Failed to open {}: {e}", csv_path.display()),
    };

    let targets: Vec<f64> = reader
        .records()
        .map(|rec| {
            let record = rec.unwrap();
            record[0].parse::<f64>().unwrap()
        })
        .collect();
    targets
}

/// Parse `sklearn_predictions.json` from the fixtures directory.
///
/// # Panics
/// Panics if the file cannot be read or is not valid JSON.
fn load_sklearn_json() -> serde_json::Value {
    let json_path = fixtures_dir().join("sklearn_predictions.json");
    let raw = match std::fs::read_to_string(&json_path) {
        Ok(text) => text,
        Err(e) => panic!("Failed to read {}: {e}", json_path.display()),
    };
    serde_json::from_str::<serde_json::Value>(&raw).unwrap()
}

/// Fraction of samples whose predicted label matches the true label.
///
/// Labels are stored as `f64`, so two labels count as equal when they differ
/// by less than 0.5. Empty input yields `NaN` (0 / 0).
///
/// # Panics
/// Panics if the two slices have different lengths.
fn accuracy(y_true: &[f64], y_pred: &[f64]) -> f64 {
    assert_eq!(y_true.len(), y_pred.len());

    let mut hits = 0usize;
    for (truth, guess) in y_true.iter().zip(y_pred) {
        if (truth - guess).abs() < 0.5 {
            hits += 1;
        }
    }

    hits as f64 / y_true.len() as f64
}

/// Compute the coefficient of determination, `R² = 1 - SS_res / SS_tot`.
///
/// `SS_res` is the sum of squared residuals between `y_true` and `y_pred`;
/// `SS_tot` is the sum of squared deviations of `y_true` from its mean.
///
/// The result is not finite when `SS_tot` is zero (constant or empty
/// `y_true`), because the division is performed as-is.
///
/// # Panics
/// Panics if the two slices have different lengths. Without this check `zip`
/// would silently drop the unmatched tail while `SS_tot` still covered all of
/// `y_true`, producing a meaningless score.
fn r2_score(y_true: &[f64], y_pred: &[f64]) -> f64 {
    assert_eq!(
        y_true.len(),
        y_pred.len(),
        "r2_score: y_true and y_pred must have the same length"
    );

    let mean = y_true.iter().sum::<f64>() / y_true.len() as f64;
    let ss_res: f64 = y_true
        .iter()
        .zip(y_pred)
        .map(|(t, p)| (t - p).powi(2))
        .sum();
    let ss_tot: f64 = y_true.iter().map(|t| (t - mean).powi(2)).sum();
    1.0 - ss_res / ss_tot
}

// ─── Dataset loading ─────────────────────────────────────────────────────

/// Build a `Dataset` from the fixture pair `{base}_features.csv` and
/// `{base}_target.csv`, naming the target column `"target"`.
fn load_dataset(base: &str) -> scry_learn::dataset::Dataset {
    let features_file = format!("{base}_features.csv");
    let target_file = format!("{base}_target.csv");

    let (columns, column_names) = load_features_csv(&features_file);
    let labels = load_target_csv(&target_file);

    scry_learn::dataset::Dataset::new(columns, labels, column_names, "target")
}

// ═════════════════════════════════════════════════════════════════════════
// 1. Decision Tree tests (deterministic — should match closely)
// ═════════════════════════════════════════════════════════════════════════

/// Depth-5 decision tree on Iris: training-set accuracy may be at most 0.02
/// below the value stored under `dt_iris` in the sklearn fixture.
#[test]
fn golden_dt_iris() {
    let dataset = load_dataset("iris");
    let reference = load_sklearn_json();
    let sklearn_acc = reference["dt_iris"]["accuracy"].as_f64().unwrap();

    let mut tree = scry_learn::tree::DecisionTreeClassifier::new().max_depth(5);
    tree.fit(&dataset).unwrap();

    // Score on the same samples the tree was fitted on.
    let inputs = dataset.feature_matrix();
    let predicted = tree.predict(&inputs).unwrap();
    let acc = accuracy(&dataset.target, &predicted);

    let floor = sklearn_acc - 0.02;
    assert!(
        acc >= floor,
        "DT Iris: scry {acc:.4} vs sklearn {sklearn_acc:.4} — gap > 2%"
    );
}

/// Depth-5 decision tree on Wine: training-set accuracy may be at most 0.02
/// below the value stored under `dt_wine` in the sklearn fixture.
#[test]
fn golden_dt_wine() {
    let dataset = load_dataset("wine");
    let reference = load_sklearn_json();
    let sklearn_acc = reference["dt_wine"]["accuracy"].as_f64().unwrap();

    let mut tree = scry_learn::tree::DecisionTreeClassifier::new().max_depth(5);
    tree.fit(&dataset).unwrap();

    // Score on the same samples the tree was fitted on.
    let inputs = dataset.feature_matrix();
    let predicted = tree.predict(&inputs).unwrap();
    let acc = accuracy(&dataset.target, &predicted);

    let floor = sklearn_acc - 0.02;
    assert!(
        acc >= floor,
        "DT Wine: scry {acc:.4} vs sklearn {sklearn_acc:.4} — gap > 2%"
    );
}

/// Depth-5 decision tree on Breast Cancer: training-set accuracy may be at
/// most 0.02 below the value stored under `dt_breast_cancer` in the sklearn
/// fixture.
#[test]
fn golden_dt_breast_cancer() {
    let dataset = load_dataset("breast_cancer");
    let reference = load_sklearn_json();
    let sklearn_acc = reference["dt_breast_cancer"]["accuracy"].as_f64().unwrap();

    let mut tree = scry_learn::tree::DecisionTreeClassifier::new().max_depth(5);
    tree.fit(&dataset).unwrap();

    // Score on the same samples the tree was fitted on.
    let inputs = dataset.feature_matrix();
    let predicted = tree.predict(&inputs).unwrap();
    let acc = accuracy(&dataset.target, &predicted);

    let floor = sklearn_acc - 0.02;
    assert!(
        acc >= floor,
        "DT Breast Cancer: scry {acc:.4} vs sklearn {sklearn_acc:.4} — gap > 2%"
    );
}

// ═════════════════════════════════════════════════════════════════════════
// 2. KNN tests
// ═════════════════════════════════════════════════════════════════════════

/// k = 5 nearest neighbours on Iris: training-set accuracy may be at most
/// 0.02 below the value stored under `knn_iris` in the sklearn fixture.
#[test]
fn golden_knn_iris() {
    let dataset = load_dataset("iris");
    let reference = load_sklearn_json();
    let sklearn_acc = reference["knn_iris"]["accuracy"].as_f64().unwrap();

    let mut classifier = scry_learn::neighbors::KnnClassifier::new().k(5);
    classifier.fit(&dataset).unwrap();

    // Score on the same samples the classifier was fitted on.
    let inputs = dataset.feature_matrix();
    let predicted = classifier.predict(&inputs).unwrap();
    let acc = accuracy(&dataset.target, &predicted);

    let floor = sklearn_acc - 0.02;
    assert!(
        acc >= floor,
        "KNN Iris: scry {acc:.4} vs sklearn {sklearn_acc:.4} — gap > 2%"
    );
}

/// k = 5 nearest neighbours on Wine: training-set accuracy may be at most
/// 0.05 below the value stored under `knn_wine` in the sklearn fixture.
#[test]
fn golden_knn_wine() {
    let dataset = load_dataset("wine");
    let reference = load_sklearn_json();
    let sklearn_acc = reference["knn_wine"]["accuracy"].as_f64().unwrap();

    let mut classifier = scry_learn::neighbors::KnnClassifier::new().k(5);
    classifier.fit(&dataset).unwrap();

    // Score on the same samples the classifier was fitted on.
    let inputs = dataset.feature_matrix();
    let predicted = classifier.predict(&inputs).unwrap();
    let acc = accuracy(&dataset.target, &predicted);

    // The Wine features are not scaled here, so results may differ more
    // between implementations; hence the wider 0.05 margin.
    let floor = sklearn_acc - 0.05;
    assert!(
        acc >= floor,
        "KNN Wine: scry {acc:.4} vs sklearn {sklearn_acc:.4} — gap > 5%"
    );
}

// ═════════════════════════════════════════════════════════════════════════
// 3. Logistic Regression tests (scaled data)
// ═════════════════════════════════════════════════════════════════════════

/// Logistic regression on standardized Iris: training-set accuracy may be at
/// most 0.03 below the value stored under `logreg_iris` in the sklearn fixture.
#[test]
fn golden_logreg_iris() {
    let mut dataset = load_dataset("iris");

    // Standardize the features first; per the original note, the sklearn
    // fixture was produced with a StandardScaler as well.
    let mut standardizer = scry_learn::preprocess::StandardScaler::new();
    scry_learn::preprocess::Transformer::fit(&mut standardizer, &dataset).unwrap();
    scry_learn::preprocess::Transformer::transform(&standardizer, &mut dataset).unwrap();

    let reference = load_sklearn_json();
    let sklearn_acc = reference["logreg_iris"]["accuracy"].as_f64().unwrap();

    let mut model = scry_learn::linear::LogisticRegression::new()
        .alpha(0.0)
        .max_iter(200);
    model.fit(&dataset).unwrap();

    // Score on the same (scaled) samples the model was fitted on.
    let inputs = dataset.feature_matrix();
    let predicted = model.predict(&inputs).unwrap();
    let acc = accuracy(&dataset.target, &predicted);

    let floor = sklearn_acc - 0.03;
    assert!(
        acc >= floor,
        "LogReg Iris: scry {acc:.4} vs sklearn {sklearn_acc:.4} — gap > 3%"
    );
}

// ═════════════════════════════════════════════════════════════════════════
// 4. Linear Regression on California Housing
// ═════════════════════════════════════════════════════════════════════════

/// Linear regression on standardized California Housing: the training-set R²
/// must be within 0.02 (in either direction) of the value stored under
/// `linreg_california` in the sklearn fixture.
#[test]
fn golden_linreg_california() {
    let mut dataset = load_dataset("california");

    // Standardize the features first; per the original note, sklearn used a
    // StandardScaler for this fixture.
    let mut standardizer = scry_learn::preprocess::StandardScaler::new();
    scry_learn::preprocess::Transformer::fit(&mut standardizer, &dataset).unwrap();
    scry_learn::preprocess::Transformer::transform(&standardizer, &mut dataset).unwrap();

    let reference = load_sklearn_json();
    let sklearn_r2 = reference["linreg_california"]["r2_score"]
        .as_f64()
        .unwrap();

    let mut model = scry_learn::linear::LinearRegression::new();
    model.fit(&dataset).unwrap();

    // Score on the same (scaled) samples the model was fitted on.
    let inputs = dataset.feature_matrix();
    let predicted = model.predict(&inputs).unwrap();
    let r2 = r2_score(&dataset.target, &predicted);

    let gap = (r2 - sklearn_r2).abs();
    assert!(
        gap < 0.02,
        "LinReg California: scry R²={r2:.4} vs sklearn R²={sklearn_r2:.4} — gap > 0.02"
    );
}

// ═════════════════════════════════════════════════════════════════════════
// 5. KMeans on Iris
// ═════════════════════════════════════════════════════════════════════════

/// KMeans with 3 clusters on Iris: the final inertia must be comparable to
/// the value stored under `kmeans_iris` in the sklearn fixture.
#[test]
fn golden_kmeans_iris() {
    let dataset = load_dataset("iris");
    let reference = load_sklearn_json();
    let sklearn_inertia = reference["kmeans_iris"]["inertia"].as_f64().unwrap();

    let mut clusterer = scry_learn::cluster::KMeans::new(3)
        .seed(42)
        .max_iter(300)
        .n_init(10);
    clusterer.fit(&dataset).unwrap();
    let inertia = clusterer.inertia();

    // KMeans is stochastic, so only the ballpark is compared: the
    // scry/sklearn inertia ratio must lie in [0.8, 1.3], i.e. up to 20%
    // lower or 30% higher than sklearn (about 78.85 per the original note).
    let ratio = inertia / sklearn_inertia;
    let within_band = (0.8..=1.3).contains(&ratio);
    assert!(
        within_band,
        "KMeans Iris: scry inertia={inertia:.1} vs sklearn {sklearn_inertia:.1} — ratio {ratio:.2} out of [0.8, 1.3]"
    );
}

// ═════════════════════════════════════════════════════════════════════════
// 6. StandardScaler on Iris (transform parity)
// ═════════════════════════════════════════════════════════════════════════

/// StandardScaler parity on Iris.
///
/// Checks two things against the `scaler_iris` entry of the sklearn fixture:
///
/// 1. The per-feature mean and population standard deviation computed by
///    hand from the raw columns. This guards the CSV fixtures and the loader.
/// 2. The output of scry-learn's `StandardScaler`: every transformed value
///    must equal `(x - mean) / std` using sklearn's statistics. Previously the
///    scaler was fitted but its result was never inspected, so the test could
///    not fail on a wrong scaler.
#[test]
fn golden_scaler_iris() {
    let data = load_dataset("iris");
    let sklearn = load_sklearn_json();
    let sklearn_means: Vec<f64> = sklearn["scaler_iris"]["means"]
        .as_array()
        .unwrap()
        .iter()
        .map(|v| v.as_f64().unwrap())
        .collect();
    let sklearn_stds: Vec<f64> = sklearn["scaler_iris"]["stds"]
        .as_array()
        .unwrap()
        .iter()
        .map(|v| v.as_f64().unwrap())
        .collect();

    // The loops below index the reference vectors by feature; fail with a
    // clear message instead of an out-of-bounds panic if the fixture is short.
    assert_eq!(
        sklearn_means.len(),
        data.n_features(),
        "Scaler Iris: fixture has a different number of means than features"
    );
    assert_eq!(
        sklearn_stds.len(),
        data.n_features(),
        "Scaler Iris: fixture has a different number of stds than features"
    );

    // Check means of the raw columns.
    for j in 0..data.n_features() {
        let mean = data.features[j].iter().sum::<f64>() / data.n_samples() as f64;
        assert!(
            (mean - sklearn_means[j]).abs() < 1e-6,
            "Scaler Iris mean[{j}]: scry {mean:.6} vs sklearn {:.6}",
            sklearn_means[j]
        );
    }

    // Check population stds (divide by n, not n - 1) of the raw columns.
    for j in 0..data.n_features() {
        let mean = data.features[j].iter().sum::<f64>() / data.n_samples() as f64;
        let var = data.features[j]
            .iter()
            .map(|v| (v - mean).powi(2))
            .sum::<f64>()
            / data.n_samples() as f64;
        let std = var.sqrt();
        assert!(
            (std - sklearn_stds[j]).abs() < 1e-4,
            "Scaler Iris std[{j}]: scry {std:.6} vs sklearn {:.6}",
            sklearn_stds[j]
        );
    }

    // Fit and apply the scaler to a second copy so the raw values in `data`
    // stay available for computing the expected result.
    // NOTE(review): assumes `transform` rewrites `features` in place, as the
    // fit/transform usage elsewhere in this file suggests — confirm.
    let mut scaled = load_dataset("iris");
    let mut scaler = scry_learn::preprocess::StandardScaler::new();
    scry_learn::preprocess::Transformer::fit(&mut scaler, &scaled).unwrap();
    scry_learn::preprocess::Transformer::transform(&scaler, &mut scaled).unwrap();

    // Check the transformed values against sklearn's mean/std.
    for j in 0..data.n_features() {
        // Guard against a vacuous comparison if the scaled column is short.
        assert_eq!(
            scaled.features[j].iter().count(),
            data.n_samples(),
            "Scaler Iris transform[{j}]: scaled column has wrong length"
        );
        for (i, (raw, got)) in data.features[j]
            .iter()
            .zip(scaled.features[j].iter())
            .enumerate()
        {
            let expected = (raw - sklearn_means[j]) / sklearn_stds[j];
            // 1e-3 absorbs the 1e-4 slack allowed on the reference stds above.
            assert!(
                (got - expected).abs() < 1e-3,
                "Scaler Iris transform[{j}][{i}]: scry {got:.6} vs sklearn {expected:.6}"
            );
        }
    }
}

// ═════════════════════════════════════════════════════════════════════════
// 7. PCA on Iris (explained variance ratio)
// ═════════════════════════════════════════════════════════════════════════

/// PCA with 2 components on Iris: each explained-variance ratio must be
/// within 0.02 of the corresponding value stored under `pca_iris` in the
/// sklearn fixture.
#[test]
fn golden_pca_iris() {
    // Number of principal components requested from scry-learn.
    const N_COMPONENTS: usize = 2;

    let data = load_dataset("iris");
    let sklearn = load_sklearn_json();
    let sklearn_evr: Vec<f64> = sklearn["pca_iris"]["explained_variance_ratio"]
        .as_array()
        .unwrap()
        .iter()
        .map(|v| v.as_f64().unwrap())
        .collect();

    let mut pca = scry_learn::preprocess::Pca::with_n_components(N_COMPONENTS);
    scry_learn::preprocess::Transformer::fit(&mut pca, &data).unwrap();
    let evr = pca.explained_variance_ratio();

    // Compare explained variance ratios within tolerance, counting the pairs
    // actually checked: `zip` stops at the shorter side, so an empty or short
    // list would otherwise let the test pass without comparing anything.
    let mut compared = 0usize;
    for (i, (&scry_ev, &sk_ev)) in evr.iter().zip(sklearn_evr.iter()).enumerate() {
        assert!(
            (scry_ev - sk_ev).abs() < 0.02,
            "PCA Iris EVR[{i}]: scry {scry_ev:.4} vs sklearn {sk_ev:.4} — gap > 0.02"
        );
        compared += 1;
    }
    assert!(
        compared >= N_COMPONENTS,
        "PCA Iris: only {compared} explained-variance ratios were compared, expected at least {N_COMPONENTS}"
    );
}

// ═════════════════════════════════════════════════════════════════════════
// 8. Digits dataset — high-dimensional multiclass
// ═════════════════════════════════════════════════════════════════════════

/// Depth-15 decision tree on Digits: training-set accuracy must reach 0.95.
/// No sklearn reference value is consulted for this check.
#[test]
fn golden_dt_digits() {
    let dataset = load_dataset("digits");

    let mut tree = scry_learn::tree::DecisionTreeClassifier::new().max_depth(15);
    tree.fit(&dataset).unwrap();

    // Score on the same samples the tree was fitted on; a tree this deep is
    // expected to overfit the training set to roughly 100%.
    let inputs = dataset.feature_matrix();
    let predicted = tree.predict(&inputs).unwrap();
    let acc = accuracy(&dataset.target, &predicted);

    let reaches_threshold = acc >= 0.95;
    assert!(
        reaches_threshold,
        "DT Digits: scry accuracy {acc:.4} < 95% — expected high training accuracy"
    );
}

/// k = 5 nearest neighbours on Digits: training-set accuracy must reach 0.98.
/// No sklearn reference value is consulted for this check.
#[test]
fn golden_knn_digits() {
    let dataset = load_dataset("digits");

    let mut classifier = scry_learn::neighbors::KnnClassifier::new().k(5);
    classifier.fit(&dataset).unwrap();

    // Score on the same samples the classifier was fitted on.
    let inputs = dataset.feature_matrix();
    let predicted = classifier.predict(&inputs).unwrap();
    let acc = accuracy(&dataset.target, &predicted);

    let reaches_threshold = acc >= 0.98;
    assert!(reaches_threshold, "KNN Digits: scry accuracy {acc:.4} < 98%");
}