//! tinylearn 0.1.0
//!
//! Machine learning in WebAssembly and `no_std` environments.
//!
//! Integration tests for `tinylearn::lm`.
use approx::assert_abs_diff_eq;
use ndarray::array;
use ndarray::Array1;
use ndarray::Array2;
use ndarray::Axis;
use tinylearn::lm;

#[test]
fn test_linear_regression() {
    tracing_subscriber::fmt::init();

    let mut reader = csv::Reader::from_path("tests/basic.csv").unwrap();

    let headers = reader.headers().unwrap().clone();

    let record_count = reader.records().count();
    reader = csv::Reader::from_path("tests/basic.csv").unwrap();

    let mut ys = Array1::<f64>::zeros(record_count);
    let mut xs = Array2::<f64>::zeros((record_count, headers.len() - 1));

    for (i, result) in reader.records().enumerate() {
        let record = result.unwrap();
        ys[i] = record[0].parse::<f64>().unwrap();
        for j in 1..record.len() {
            xs[[i, j - 1]] = record[j].parse::<f64>().unwrap();
        }
    }
    for row in xs.axis_iter(Axis(0)) {
        tracing::info!("row: {:?}", row);
    }
    tracing::info!("ys: {:?}", ys);

    let model = lm::LinearRegression {
        fit_intercept: true,
    };
    let model = model.fit(&xs, &ys);
    tracing::info!("model: {:?}", model);
    assert_abs_diff_eq!(
        model.coefficients,
        &array![-0.51499, 0.51175],
        epsilon = 1e-3
    );
    assert_abs_diff_eq!(model.intercept, 19.24392220, epsilon = 1e-8);

    let model = lm::LinearRegression {
        fit_intercept: false,
    };
    let model = model.fit(&xs, &ys);
    tracing::info!("model: {:?}", model);
    assert_abs_diff_eq!(
        model.coefficients,
        &array![1.95896584, -0.20944023],
        epsilon = 1e-7
    );
    assert_abs_diff_eq!(model.intercept, 0.0, epsilon = 1e-8);
}