use std::io::Read;
use csv::ReaderBuilder;
use flate2::read::GzDecoder;
use linfa::Dataset;
use ndarray::prelude::*;
use ndarray_csv::{Array2Reader, ReadError};
/// Parses a gzip-compressed CSV stream into a two-dimensional `f64` array.
///
/// The stream `gz` is wrapped in a [`GzDecoder`] and the decompressed bytes are
/// handed to [`array_from_csv`] with the same `has_headers` / `separator` settings.
///
/// # Errors
///
/// Returns the [`ReadError`] produced while reading or deserializing the CSV.
pub fn array_from_gz_csv<R: Read>(
    gz: R,
    has_headers: bool,
    separator: u8,
) -> Result<Array2<f64>, ReadError> {
    array_from_csv(GzDecoder::new(gz), has_headers, separator)
}
/// Parses a CSV stream into a two-dimensional `f64` array.
///
/// * `csv` – any byte source implementing [`Read`].
/// * `has_headers` – when `true`, the first record is treated as a header row.
/// * `separator` – the field delimiter byte (e.g. `b','`).
///
/// No shape is supplied up front: the array's dimensions are determined by
/// `deserialize_array2_dynamic` from the records that are read.
///
/// # Errors
///
/// Returns a [`ReadError`] if a record cannot be read or deserialized.
pub fn array_from_csv<R: Read>(
    csv: R,
    has_headers: bool,
    separator: u8,
) -> Result<Array2<f64>, ReadError> {
    let mut builder = ReaderBuilder::new();
    builder.has_headers(has_headers);
    builder.delimiter(separator);

    let mut csv_reader = builder.from_reader(csv);
    csv_reader.deserialize_array2_dynamic()
}
/// Loads the Iris dataset bundled with this crate.
///
/// The gzip-compressed CSV (with a header row) is embedded at compile time.
/// Columns `0..4` become the records and column `4` becomes the target, which
/// is converted from `f64` to a `usize` label with an `as` cast.
///
/// # Panics
///
/// Panics if the embedded file cannot be decompressed or parsed.
#[cfg(feature = "iris")]
pub fn iris() -> Dataset<f64, usize, Ix1> {
    let raw = include_bytes!("../data/iris.csv.gz");
    let table = array_from_gz_csv(&raw[..], true, b',').unwrap();

    // Columns 0..4 are the features; column 4 holds the label.
    let records = table.slice(s![.., 0..4]).to_owned();
    let labels = table.column(4).to_owned();

    let names = vec!["sepal length", "sepal width", "petal length", "petal width"];
    Dataset::new(records, labels)
        .map_targets(|label| *label as usize)
        .with_feature_names(names)
}
/// Loads the diabetes regression dataset bundled with this crate.
///
/// Records come from the embedded `diabetes_data.csv.gz`. Targets are the first
/// column of the embedded `diabetes_target.csv.gz`. Both files are parsed as
/// comma-separated with a header row and are kept as `f64`.
///
/// # Panics
///
/// Panics if either embedded file cannot be decompressed or parsed.
#[cfg(feature = "diabetes")]
pub fn diabetes() -> Dataset<f64, f64, Ix1> {
    let record_bytes = include_bytes!("../data/diabetes_data.csv.gz");
    let target_bytes = include_bytes!("../data/diabetes_target.csv.gz");

    let records = array_from_gz_csv(&record_bytes[..], true, b',').unwrap();
    let target_table = array_from_gz_csv(&target_bytes[..], true, b',').unwrap();
    let targets = target_table.column(0).to_owned();

    // NOTE(review): "t-cells", "thyroid stimulating hormone" and "lamotrigine"
    // may be misreadings of serum-measurement abbreviations (e.g. tc/tch/ltg in
    // the scikit-learn description) — verify against the dataset's source
    // before relying on these labels.
    let names = vec![
        "age",
        "sex",
        "body mass index",
        "blood pressure",
        "t-cells",
        "low-density lipoproteins",
        "high-density lipoproteins",
        "thyroid stimulating hormone",
        "lamotrigine",
        "blood sugar level",
    ];

    Dataset::new(records, targets).with_feature_names(names)
}
/// Loads the red wine quality dataset bundled with this crate.
///
/// The embedded `winequality-red.csv.gz` (comma-separated, with a header row)
/// provides eleven feature columns (`0..11`). Column `11` is the target, which
/// is converted from `f64` to a `usize` label with an `as` cast.
///
/// # Panics
///
/// Panics if the embedded file cannot be decompressed or parsed.
#[cfg(feature = "winequality")]
pub fn winequality() -> Dataset<f64, usize, Ix1> {
    let raw = include_bytes!("../data/winequality-red.csv.gz");
    let table = array_from_gz_csv(&raw[..], true, b',').unwrap();

    // Columns 0..11 are the features; column 11 holds the label.
    let records = table.slice(s![.., 0..11]).to_owned();
    let labels = table.column(11).to_owned();

    let names = vec![
        "fixed acidity",
        "volatile acidity",
        "citric acid",
        "residual sugar",
        "chlorides",
        "free sulfur dioxide",
        "total sulfur dioxide",
        "density",
        "pH",
        "sulphates",
        "alcohol",
    ];

    Dataset::new(records, labels)
        .map_targets(|label| *label as usize)
        .with_feature_names(names)
}
/// Loads the Linnerud multi-output dataset bundled with this crate.
///
/// The embedded exercise table (features "Chins", "Situps", "Jumps") becomes
/// the records. The embedded physiological table becomes the (multi-column)
/// targets. Both files are comma-separated with a header row.
///
/// # Panics
///
/// Panics if either embedded file cannot be decompressed or parsed.
#[cfg(feature = "linnerud")]
pub fn linnerud() -> Dataset<f64, f64> {
    // Both tables share the same CSV layout, so parse them the same way.
    let parse = |bytes: &[u8]| array_from_gz_csv(bytes, true, b',').unwrap();

    let exercise = parse(&include_bytes!("../data/linnerud_exercise.csv.gz")[..]);
    let physiological = parse(&include_bytes!("../data/linnerud_physiological.csv.gz")[..]);

    Dataset::new(exercise, physiological).with_feature_names(vec!["Chins", "Situps", "Jumps"])
}
#[cfg(test)]
mod tests {
    // Sanity checks for the bundled datasets: shapes, feature names, label
    // counts and a few summary statistics. Each test is only compiled when the
    // matching cargo feature is enabled.
    use super::*;
    use approx::assert_abs_diff_eq;
    use linfa::prelude::*;

    #[cfg(feature = "iris")]
    #[test]
    fn test_iris() {
        let ds = iris();
        assert_eq!((ds.nsamples(), ds.nfeatures(), ds.ntargets()), (150, 4, 1));
        assert_eq!(
            ds.feature_names(),
            &["sepal length", "sepal width", "petal length", "petal width"]
        );
        // Every class has the same count, so HashMap iteration order does not matter.
        assert_abs_diff_eq!(
            ds.label_frequencies()
                .into_iter()
                .map(|b| b.1)
                .collect::<Array1<_>>(),
            array![50., 50., 50.]
        );
        let _pcc = ds.pearson_correlation_with_p_value(100);
        let mean_features = ds.records().mean_axis(Axis(0)).unwrap();
        assert_abs_diff_eq!(
            mean_features,
            array![5.84, 3.05, 3.75, 1.20],
            epsilon = 0.01
        );
    }

    #[cfg(feature = "diabetes")]
    #[test]
    fn test_diabetes() {
        let ds = diabetes();
        assert_eq!((ds.nsamples(), ds.nfeatures(), ds.ntargets()), (441, 10, 1));
        let _pcc = ds.pearson_correlation_with_p_value(100);
        let mean_features = ds.records().mean_axis(Axis(0)).unwrap();
        assert_abs_diff_eq!(mean_features, Array1::zeros(10), epsilon = 0.005);
    }

    #[cfg(feature = "winequality")]
    #[test]
    fn test_winequality() {
        let ds = winequality();
        assert_eq!(
            (ds.nsamples(), ds.nfeatures(), ds.ntargets()),
            (1599, 11, 1)
        );
        let feature_names = vec![
            "fixed acidity",
            "volatile acidity",
            "citric acid",
            "residual sugar",
            "chlorides",
            "free sulfur dioxide",
            "total sulfur dioxide",
            "density",
            "pH",
            "sulphates",
            "alcohol",
        ];
        assert_eq!(ds.feature_names(), feature_names);

        // Expected number of samples per label; the counts sum to 1599.
        let compare_to = vec![
            (5, 681.0),
            (7, 199.0),
            (6, 638.0),
            (8, 18.0),
            (3, 10.0),
            (4, 53.0),
        ];
        let freqs = ds.label_frequencies();
        // No labels beyond the ones listed above.
        assert_eq!(freqs.len(), compare_to.len());
        // Check labels one by one so a failure reports which label is wrong.
        for (label, expected) in compare_to {
            let count = freqs
                .get(&label)
                .unwrap_or_else(|| panic!("label {} missing from winequality targets", label));
            assert_abs_diff_eq!(*count, expected);
        }
        let _pcc = ds.pearson_correlation_with_p_value(100);
    }

    #[cfg(feature = "linnerud")]
    #[test]
    fn test_linnerud() {
        let ds = linnerud();
        assert_eq!((ds.nsamples(), ds.nfeatures(), ds.ntargets()), (20, 3, 3));
        let feature_names = vec!["Chins", "Situps", "Jumps"];
        assert_eq!(ds.feature_names(), feature_names);
        let mean_targets = ds.targets().mean_axis(Axis(0)).unwrap();
        assert_abs_diff_eq!(mean_targets, array![178.6, 35.4, 56.1]);
    }
}