1use std::io::Read;
2
3use csv::ReaderBuilder;
4use flate2::read::GzDecoder;
5use linfa::Dataset;
6use ndarray::prelude::*;
7use ndarray_csv::{Array2Reader, ReadError};
8
9pub fn array_from_gz_csv<R: Read>(
11    gz: R,
12    has_headers: bool,
13    separator: u8,
14) -> Result<Array2<f64>, ReadError> {
15    let file = GzDecoder::new(gz);
17    array_from_csv(file, has_headers, separator)
18}
19
20pub fn array_from_csv<R: Read>(
22    csv: R,
23    has_headers: bool,
24    separator: u8,
25) -> Result<Array2<f64>, ReadError> {
26    let mut reader = ReaderBuilder::new()
28        .has_headers(has_headers)
29        .delimiter(separator)
30        .from_reader(csv);
31
32    reader.deserialize_array2_dynamic()
34}
35
/// Load the Iris dataset bundled with the crate.
///
/// The first four CSV columns become the records (named below). The fifth
/// column holds the class label, which is cast from `f64` to `usize`.
///
/// # Panics
/// Panics if the embedded `iris.csv.gz` cannot be parsed.
#[cfg(feature = "iris")]
pub fn iris() -> Dataset<f64, usize, Ix1> {
    let bytes = include_bytes!("../data/iris.csv.gz");
    let table = array_from_gz_csv(&bytes[..], true, b',').unwrap();

    let records = table.slice(s![.., 0..4]).to_owned();
    let labels = table.column(4).to_owned();

    Dataset::new(records, labels)
        .map_targets(|label| *label as usize)
        .with_feature_names(vec![
            "sepal length",
            "sepal width",
            "petal length",
            "petal width",
        ])
}
54
/// Load the diabetes regression dataset bundled with the crate.
///
/// Records are read from `diabetes_data.csv.gz`. The continuous targets are
/// the first column of `diabetes_target.csv.gz`.
///
/// Feature names follow the scikit-learn description of this dataset:
/// `age`, `sex`, `bmi`, `bp`, then six blood serum measurements:
/// `s1` = total serum cholesterol (tc), `s2` = LDL, `s3` = HDL,
/// `s4` = total cholesterol / HDL (tch), `s5` = (possibly) log of serum
/// triglycerides level (ltg) and `s6` = blood sugar level (glu).
///
/// NOTE(review): the unit test expects 441 samples, while scikit-learn's copy
/// of this dataset has 442. Both files are parsed with `has_headers = true`.
/// Confirm that they really start with a header row, so that no sample is
/// silently dropped.
///
/// # Panics
/// Panics if either embedded file cannot be parsed as CSV.
#[cfg(feature = "diabetes")]
pub fn diabetes() -> Dataset<f64, f64, Ix1> {
    let data = include_bytes!("../data/diabetes_data.csv.gz");
    // The file is compiled into the binary, so a parse failure is a packaging bug.
    let data = array_from_gz_csv(&data[..], true, b',')
        .expect("embedded diabetes_data.csv.gz must be a valid numeric CSV");

    let targets = include_bytes!("../data/diabetes_target.csv.gz");
    let targets = array_from_gz_csv(&targets[..], true, b',')
        .expect("embedded diabetes_target.csv.gz must be a valid numeric CSV")
        .column(0)
        .to_owned();

    // Earlier names ("t-cells", "thyroid stimulating hormone", "lamotrigine")
    // were misreadings of the abbreviations tc, tch and ltg.
    let feature_names = vec![
        "age",
        "sex",
        "body mass index",
        "blood pressure",
        "total serum cholesterol",
        "low-density lipoproteins",
        "high-density lipoproteins",
        "total cholesterol / HDL",
        "log of serum triglycerides level",
        "blood sugar level",
    ];

    Dataset::new(data, targets).with_feature_names(feature_names)
}
82
/// Load the red wine quality dataset bundled with the crate.
///
/// The first eleven CSV columns are the features named below. The twelfth
/// column is the quality score, which is cast from `f64` to `usize` so it can
/// serve as a class label.
///
/// # Panics
/// Panics if the embedded `winequality-red.csv.gz` cannot be parsed.
#[cfg(feature = "winequality")]
pub fn winequality() -> Dataset<f64, usize, Ix1> {
    // Number of feature columns; the quality score follows immediately after.
    const N_FEATURES: usize = 11;

    let bytes = include_bytes!("../data/winequality-red.csv.gz");
    let table = array_from_gz_csv(&bytes[..], true, b',').unwrap();

    let records = table.slice(s![.., 0..N_FEATURES]).to_owned();
    let quality = table.column(N_FEATURES).to_owned();

    let names = vec![
        "fixed acidity",
        "volatile acidity",
        "citric acid",
        "residual sugar",
        "chlorides",
        "free sulfur dioxide",
        "total sulfur dioxide",
        "density",
        "pH",
        "sulphates",
        "alcohol",
    ];

    Dataset::new(records, quality)
        .map_targets(|score| *score as usize)
        .with_feature_names(names)
}
112
/// Load the Linnerud multi-output regression dataset bundled with the crate.
///
/// Records are the exercise measurements ("Chins", "Situps", "Jumps"), and
/// targets are the physiological measurements ("Weight", "Waist", "Pulse").
/// Each set is read from its own embedded CSV file.
///
/// # Panics
/// Panics if either embedded file cannot be parsed.
#[cfg(feature = "linnerud")]
pub fn linnerud() -> Dataset<f64, f64> {
    let exercise = include_bytes!("../data/linnerud_exercise.csv.gz");
    let physiological = include_bytes!("../data/linnerud_physiological.csv.gz");

    // Both files share the same layout: gzip-compressed, comma separated, with a header.
    let parse = |bytes: &[u8]| array_from_gz_csv(bytes, true, b',').unwrap();

    Dataset::new(parse(&exercise[..]), parse(&physiological[..]))
        .with_feature_names(vec!["Chins", "Situps", "Jumps"])
        .with_target_names(vec!["Weight", "Waist", "Pulse"])
}
140
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use linfa::prelude::*;

    #[cfg(feature = "iris")]
    #[test]
    fn test_iris() {
        let ds = iris();

        assert_eq!((ds.nsamples(), ds.nfeatures(), ds.ntargets()), (150, 4, 1));

        assert_eq!(
            ds.feature_names(),
            &["sepal length", "sepal width", "petal length", "petal width"]
        );

        // Three classes with 50 samples each.
        assert_abs_diff_eq!(
            ds.label_frequencies()
                .into_iter()
                .map(|b| b.1)
                .collect::<Array1<_>>(),
            array![50., 50., 50.]
        );

        // Smoke test: the correlation computation must run on this data.
        let _pcc = ds.pearson_correlation_with_p_value(100);

        let mean_features = ds.records().mean_axis(Axis(0)).unwrap();
        assert_abs_diff_eq!(
            mean_features,
            array![5.84, 3.05, 3.75, 1.20],
            epsilon = 0.01
        );
    }

    #[cfg(feature = "diabetes")]
    #[test]
    fn test_diabetes() {
        let ds = diabetes();

        assert_eq!((ds.nsamples(), ds.nfeatures(), ds.ntargets()), (441, 10, 1));

        // Smoke test: the correlation computation must run on this data.
        let _pcc = ds.pearson_correlation_with_p_value(100);

        // The bundled features are expected to be (approximately) mean-centred.
        let mean_features = ds.records().mean_axis(Axis(0)).unwrap();
        assert_abs_diff_eq!(mean_features, Array1::zeros(10), epsilon = 0.005);
    }

    #[cfg(feature = "winequality")]
    #[test]
    fn test_winequality() {
        use approx::abs_diff_eq;

        let ds = winequality();

        assert_eq!(
            (ds.nsamples(), ds.nfeatures(), ds.ntargets()),
            (1599, 11, 1)
        );

        let feature_names = vec![
            "fixed acidity",
            "volatile acidity",
            "citric acid",
            "residual sugar",
            "chlorides",
            "free sulfur dioxide",
            "total sulfur dioxide",
            "density",
            "pH",
            "sulphates",
            "alcohol",
        ];
        assert_eq!(ds.feature_names(), feature_names);

        // Expected number of samples for each quality score (sums to 1599).
        let compare_to = vec![
            (5, 681.0),
            (7, 199.0),
            (6, 638.0),
            (8, 18.0),
            (3, 10.0),
            (4, 53.0),
        ];

        let freqs = ds.label_frequencies();
        // Reject unexpected extra labels, not only missing or miscounted ones.
        assert_eq!(freqs.len(), compare_to.len());
        for (label, expected) in compare_to {
            match freqs.get(&label) {
                Some(&count) => assert!(
                    abs_diff_eq!(count, expected),
                    "label {}: expected frequency {}, found {}",
                    label,
                    expected,
                    count
                ),
                None => panic!("label {} missing from label frequencies", label),
            }
        }

        // Smoke test: the correlation computation must run on this data.
        let _pcc = ds.pearson_correlation_with_p_value(100);
    }

    #[cfg(feature = "linnerud")]
    #[test]
    fn test_linnerud() {
        let ds = linnerud();

        assert_eq!((ds.nsamples(), ds.nfeatures(), ds.ntargets()), (20, 3, 3));

        let feature_names = vec!["Chins", "Situps", "Jumps"];
        assert_eq!(ds.feature_names(), feature_names);

        let target_names = vec!["Weight", "Waist", "Pulse"];
        assert_eq!(ds.target_names(), target_names);

        let mean_targets = ds.targets().mean_axis(Axis(0)).unwrap();
        assert_abs_diff_eq!(mean_targets, array![178.6, 35.4, 56.1]);
    }
}