linfa_datasets/
dataset.rs

1use std::io::Read;
2
3use csv::ReaderBuilder;
4use flate2::read::GzDecoder;
5use linfa::Dataset;
6use ndarray::prelude::*;
7use ndarray_csv::{Array2Reader, ReadError};
8
9/// Convert Gzipped CSV bytes into 2D array
10pub fn array_from_gz_csv<R: Read>(
11    gz: R,
12    has_headers: bool,
13    separator: u8,
14) -> Result<Array2<f64>, ReadError> {
15    // unzip file
16    let file = GzDecoder::new(gz);
17    array_from_csv(file, has_headers, separator)
18}
19
20/// Convert CSV bytes into 2D array
21pub fn array_from_csv<R: Read>(
22    csv: R,
23    has_headers: bool,
24    separator: u8,
25) -> Result<Array2<f64>, ReadError> {
26    // parse CSV
27    let mut reader = ReaderBuilder::new()
28        .has_headers(has_headers)
29        .delimiter(separator)
30        .from_reader(csv);
31
32    // extract ndarray
33    reader.deserialize_array2_dynamic()
34}
35
#[cfg(feature = "iris")]
/// Load the iris-flower dataset embedded in the crate.
///
/// The first four columns of the bundled CSV become the records (sepal length,
/// sepal width, petal length, petal width); the fifth column holds the targets,
/// which are cast to `usize`.
///
/// # Panics
///
/// Panics if the embedded archive cannot be decoded into an array.
// The `.csv` data is two dimensional: Axis(0) denotes y-axis (rows), Axis(1) denotes x-axis (columns)
pub fn iris() -> Dataset<f64, usize, Ix1> {
    let raw = include_bytes!("../data/iris.csv.gz");
    let table = array_from_gz_csv(&raw[..], true, b',').unwrap();

    // columns 0..4 are the measurements, column 4 is the class label
    let records = table.slice(s![.., ..4]).to_owned();
    let labels = table.column(4).to_owned();

    Dataset::new(records, labels)
        .map_targets(|label| *label as usize)
        .with_feature_names(vec![
            "sepal length",
            "sepal width",
            "petal length",
            "petal width",
        ])
}
54
#[cfg(feature = "diabetes")]
/// Load the diabetes dataset embedded in the crate.
///
/// Records and targets live in two separate bundled CSV files; the first column
/// of the target file is used as the single target variable.
///
/// # Panics
///
/// Panics if either embedded archive cannot be decoded into an array.
pub fn diabetes() -> Dataset<f64, f64, Ix1> {
    let record_bytes = include_bytes!("../data/diabetes_data.csv.gz");
    let target_bytes = include_bytes!("../data/diabetes_target.csv.gz");

    let records = array_from_gz_csv(&record_bytes[..], true, b',').unwrap();

    // the target file is parsed as a 2D table; keep only its first column
    let target_table = array_from_gz_csv(&target_bytes[..], true, b',').unwrap();
    let targets = target_table.column(0).to_owned();

    Dataset::new(records, targets).with_feature_names(vec![
        "age",
        "sex",
        "body mass index",
        "blood pressure",
        "t-cells",
        "low-density lipoproteins",
        "high-density lipoproteins",
        "thyroid stimulating hormone",
        "lamotrigine",
        "blood sugar level",
    ])
}
82
#[cfg(feature = "winequality")]
/// Load the red-wine quality dataset embedded in the crate.
///
/// The first eleven columns of the bundled CSV become the records; the twelfth
/// column holds the targets, which are cast to `usize`.
///
/// # Panics
///
/// Panics if the embedded archive cannot be decoded into an array.
pub fn winequality() -> Dataset<f64, usize, Ix1> {
    let raw = include_bytes!("../data/winequality-red.csv.gz");
    let table = array_from_gz_csv(&raw[..], true, b',').unwrap();

    // columns 0..11 are the measurements, column 11 is the quality label
    let records = table.slice(s![.., ..11]).to_owned();
    let labels = table.column(11).to_owned();

    Dataset::new(records, labels)
        .map_targets(|label| *label as usize)
        .with_feature_names(vec![
            "fixed acidity",
            "volatile acidity",
            "citric acid",
            "residual sugar",
            "chlorides",
            "free sulfur dioxide",
            "total sulfur dioxide",
            "density",
            "pH",
            "sulphates",
            "alcohol",
        ])
}
112
#[cfg(feature = "linnerud")]
/// Load the Linnerud physical exercise dataset embedded in the crate.
///
/// The Linnerud dataset contains 20 samples collected from 20 middle-aged men in a fitness club.
///
/// ## Features:
/// 3 exercise measurements: Chins, Situps, Jumps
///
/// ## Targets:
/// 3 physiological measurements: Weight, Waist, Pulse
///
/// # Panics
///
/// Panics if either embedded archive cannot be decoded into an array.
///
/// # Reference:
/// Tenenhaus (1998). La regression PLS: theorie et pratique. Paris: Editions Technip. Table p 15.
pub fn linnerud() -> Dataset<f64, f64> {
    let exercise_bytes = include_bytes!("../data/linnerud_exercise.csv.gz");
    let physiological_bytes = include_bytes!("../data/linnerud_physiological.csv.gz");

    // exercise measurements are the records, physiological measurements the targets
    let records = array_from_gz_csv(&exercise_bytes[..], true, b',').unwrap();
    let targets = array_from_gz_csv(&physiological_bytes[..], true, b',').unwrap();

    Dataset::new(records, targets).with_feature_names(vec!["Chins", "Situps", "Jumps"])
}
137
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use linfa::prelude::*;

    /// Iris: shape, feature names, label balance and per-feature means.
    #[cfg(feature = "iris")]
    #[test]
    fn test_iris() {
        let dataset = iris();

        // the dataset must have the expected dimensions
        let shape = (dataset.nsamples(), dataset.nfeatures(), dataset.ntargets());
        assert_eq!(shape, (150, 4, 1));

        // the feature names must be attached in column order
        assert_eq!(
            dataset.feature_names(),
            &["sepal length", "sepal width", "petal length", "petal width"]
        );

        // every label must occur equally often
        let frequencies = dataset
            .label_frequencies()
            .into_iter()
            .map(|(_, frequency)| frequency)
            .collect::<Array1<_>>();
        assert_abs_diff_eq!(frequencies, array![50., 50., 50.]);

        // perform correlation analysis and assert that petal length and width are correlated
        let _pcc = dataset.pearson_correlation_with_p_value(100);
        // TODO: wait for pearson correlation to accept rng
        // assert_abs_diff_eq!(pcc.get_p_values().unwrap()[5], 0.04, epsilon = 0.04);

        // compare the mean of each feature column against reference values
        let feature_means = dataset.records().mean_axis(Axis(0)).unwrap();
        let expected_means = array![5.84, 3.05, 3.75, 1.20];
        assert_abs_diff_eq!(feature_means, expected_means, epsilon = 0.01);
    }

    /// Diabetes: shape and zero-centred feature columns.
    #[cfg(feature = "diabetes")]
    #[test]
    fn test_diabetes() {
        let dataset = diabetes();

        // the dataset must have the expected dimensions
        let shape = (dataset.nsamples(), dataset.nfeatures(), dataset.ntargets());
        assert_eq!(shape, (441, 10, 1));

        // perform correlation analysis and assert that T-Cells and low-density lipoproteins are
        // correlated
        let _pcc = dataset.pearson_correlation_with_p_value(100);
        //assert_abs_diff_eq!(pcc.get_p_values().unwrap()[30], 0.02, epsilon = 0.02);

        // the data should be normalized, so every feature mean is close to zero
        let feature_means = dataset.records().mean_axis(Axis(0)).unwrap();
        assert_abs_diff_eq!(feature_means, Array1::zeros(10), epsilon = 0.005);
    }

    /// Wine quality: shape, feature names and label frequencies.
    #[cfg(feature = "winequality")]
    #[test]
    fn test_winequality() {
        use approx::abs_diff_eq;

        let dataset = winequality();

        // the dataset must have the expected dimensions
        let shape = (dataset.nsamples(), dataset.nfeatures(), dataset.ntargets());
        assert_eq!(shape, (1599, 11, 1));

        // the feature names must be attached in column order
        let expected_names = vec![
            "fixed acidity",
            "volatile acidity",
            "citric acid",
            "residual sugar",
            "chlorides",
            "free sulfur dioxide",
            "total sulfur dioxide",
            "density",
            "pH",
            "sulphates",
            "alcohol",
        ];
        assert_eq!(dataset.feature_names(), expected_names);

        // expected (label, frequency) pairs
        let expected_frequencies = vec![
            (5, 681.0),
            (7, 199.0),
            (6, 638.0),
            (8, 18.0),
            (3, 10.0),
            (4, 53.0),
        ];

        // every expected label must be present with a matching frequency
        let frequencies = dataset.label_frequencies();
        let all_match = expected_frequencies
            .into_iter()
            .all(|(label, expected)| {
                frequencies
                    .get(&label)
                    .map_or(false, |actual| abs_diff_eq!(*actual, expected))
            });
        assert!(all_match);

        // perform correlation analysis and assert that fixed acidity and citric acid are
        // correlated
        let _pcc = dataset.pearson_correlation_with_p_value(100);
        //assert_abs_diff_eq!(pcc.get_p_values().unwrap()[1], 0.05, epsilon = 0.05);
    }

    /// Linnerud: shape, feature names and per-target means.
    #[cfg(feature = "linnerud")]
    #[test]
    fn test_linnerud() {
        let dataset = linnerud();

        // the dataset must have the expected dimensions
        let shape = (dataset.nsamples(), dataset.nfeatures(), dataset.ntargets());
        assert_eq!(shape, (20, 3, 3));

        // the feature names must be attached in column order
        let expected_names = vec!["Chins", "Situps", "Jumps"];
        assert_eq!(dataset.feature_names(), expected_names);

        // mean of each target column: Weight, Waist, Pulse
        let target_means = dataset.targets().mean_axis(Axis(0)).unwrap();
        assert_abs_diff_eq!(target_means, array![178.6, 35.4, 56.1]);
    }
}