// rusty-machine 0.5.4: a machine learning library.
//
// Tests for the linear regression model (`LinRegressor`).
use rm::linalg::Matrix;
use rm::linalg::Vector;
use rm::learning::SupModel;
use rm::learning::lin_reg::LinRegressor;
use libnum::abs;

/// Smoke test for gradient-descent training: after
/// `train_with_optimization`, the model must expose fitted parameters.
#[test]
fn test_optimized_regression() {
    // Three samples of a single feature.
    let x = Matrix::new(3, 1, vec![2.0, 3.0, 4.0]);
    let y = Vector::new(vec![5.0, 6.0, 7.0]);

    let mut model = LinRegressor::default();
    model.train_with_optimization(&x, &y);

    // Panics (failing the test) if training left no parameters behind.
    model.parameters().unwrap();
}

/// Closed-form training on points lying exactly on `y = 3 + x` must recover
/// the intercept 3 and the slope 1 to within 1e-8.
#[test]
fn test_regression() {
    let x = Matrix::new(3, 1, vec![2.0, 3.0, 4.0]);
    let y = Vector::new(vec![5.0, 6.0, 7.0]);

    let mut model = LinRegressor::default();
    model.train(&x, &y).unwrap();

    let fitted = model.parameters().unwrap();

    // Index 0 is the intercept, index 1 the coefficient of the single feature.
    let deviations = [abs(fitted[0] - 3.0), abs(fitted[1] - 1.0)];
    for &deviation in deviations.iter() {
        assert!(deviation < 1e-8);
    }
}

/// A freshly constructed, untrained model has no fitted parameters.
///
/// This asserts the `None` directly. A bare `#[should_panic]` would also
/// pass on an unrelated panic, so it would not pin this contract.
#[test]
fn test_no_train_params() {
    let lin_mod = LinRegressor::default();

    assert!(lin_mod.parameters().is_none());
}

/// Predicting with an untrained model must report an error instead of
/// producing output.
///
/// This checks the `Err` directly rather than relying on `#[should_panic]`,
/// which would accept any panic, including one unrelated to the missing
/// training.
#[test]
fn test_no_train_predict() {
    let lin_mod = LinRegressor::default();
    let inputs = Matrix::new(3, 2, vec![1.0, 2.0, 1.0, 3.0, 1.0, 4.0]);

    assert!(lin_mod.predict(&inputs).is_err());
}

/// Fits `LinRegressor` on the bundled `trees` dataset. It then checks the
/// fitted parameters and the in-sample predictions against reference values.
///
/// Values are compared element-wise within an absolute tolerance, not with
/// exact `f64` equality. The least-squares solve can differ in the last bits
/// across platforms and floating-point code paths.
#[cfg(feature = "datasets")]
#[test]
fn test_regression_datasets_trees() {
    use rm::datasets::trees;

    // Asserts that `actual` has the same length as `expected` and that every
    // element is within 1e-8 of its reference value.
    fn assert_all_close(actual: &Vector<f64>, expected: &[f64]) {
        assert_eq!(actual.size(), expected.len(), "length mismatch");
        for (i, (&got, &want)) in actual.data().iter().zip(expected.iter()).enumerate() {
            assert!(abs(got - want) < 1e-8,
                    "element {}: got {}, expected {}", i, got, want);
        }
    }

    let trees = trees::load();

    let mut lin_mod = LinRegressor::default();
    lin_mod.train(&trees.data(), &trees.target()).unwrap();
    let params = lin_mod.parameters().unwrap();
    assert_all_close(params, &[-57.98765891838409, 4.708160503017506, 0.3392512342447438]);

    let predicted = lin_mod.predict(&trees.data()).unwrap();
    let expected = [4.837659653793278, 4.55385163347481, 4.816981265588826, 15.874115228921276,
                    19.869008437727473, 21.018326956518717, 16.192688074961563, 19.245949183164257,
                    21.413021404689726, 20.187581283767756, 22.015402271048487, 21.468464618616007,
                    21.468464618616007, 20.50615412980805, 23.954109686181766, 27.852202904652785,
                    31.583966481344966, 33.806481916796706, 30.60097760433255, 28.697035014921106,
                    34.388184394951004, 36.008318964043994, 35.38525970948079, 41.76899799551756,
                    44.87770231764652, 50.942867757643015, 52.223751092491256, 53.42851282520877,
                    53.899328875510534, 53.899328875510534, 68.51530482306926];
    assert_all_close(&predicted, &expected);
}