1use std::io::Read;
2
3use csv::ReaderBuilder;
4use flate2::read::GzDecoder;
5use linfa::Dataset;
6use ndarray::prelude::*;
7use ndarray_csv::{Array2Reader, ReadError};
8
9pub fn array_from_gz_csv<R: Read>(
11 gz: R,
12 has_headers: bool,
13 separator: u8,
14) -> Result<Array2<f64>, ReadError> {
15 let file = GzDecoder::new(gz);
17 array_from_csv(file, has_headers, separator)
18}
19
20pub fn array_from_csv<R: Read>(
22 csv: R,
23 has_headers: bool,
24 separator: u8,
25) -> Result<Array2<f64>, ReadError> {
26 let mut reader = ReaderBuilder::new()
28 .has_headers(has_headers)
29 .delimiter(separator)
30 .from_reader(csv);
31
32 reader.deserialize_array2_dynamic()
34}
35
/// Loads the Iris dataset embedded in the crate at build time.
///
/// The bundled `iris.csv.gz` is read with a header row and comma separators.
/// Columns `0..4` become the `f64` records. Column `4` is converted to `usize`
/// class labels with an `as` cast, which truncates toward zero.
///
/// # Panics
///
/// Panics if the embedded file cannot be parsed. This means the bundled data
/// file is broken, not that a recoverable runtime error occurred.
#[cfg(feature = "iris")]
pub fn iris() -> Dataset<f64, usize, Ix1> {
    let data = include_bytes!("../data/iris.csv.gz");
    let array = array_from_gz_csv(&data[..], true, b',')
        .expect("embedded iris.csv.gz should be a valid gzip-compressed CSV");

    let (data, targets) = (
        array.slice(s![.., 0..4]).to_owned(),
        array.column(4).to_owned(),
    );

    let feature_names = vec!["sepal length", "sepal width", "petal length", "petal width"];

    Dataset::new(data, targets)
        .map_targets(|x| *x as usize)
        .with_feature_names(feature_names)
}
54
/// Loads the diabetes regression dataset embedded in the crate at build time.
///
/// Records come from the bundled `diabetes_data.csv.gz` (ten `f64` feature
/// columns). Targets are the first column of `diabetes_target.csv.gz`. Both
/// files are read with a header row and comma separators.
///
/// # Panics
///
/// Panics if either embedded file cannot be parsed. This means the bundled
/// data is broken, not that a recoverable runtime error occurred.
#[cfg(feature = "diabetes")]
pub fn diabetes() -> Dataset<f64, f64, Ix1> {
    // NOTE(review): commonly distributed copies of this dataset have 442
    // samples, while the tests expect 441. Confirm that both CSVs really start
    // with a header row; otherwise `has_headers = true` silently drops the
    // first sample from records and targets alike.
    let data = include_bytes!("../data/diabetes_data.csv.gz");
    let data = array_from_gz_csv(&data[..], true, b',')
        .expect("embedded diabetes_data.csv.gz should be a valid gzip-compressed CSV");

    let targets = include_bytes!("../data/diabetes_target.csv.gz");
    let targets = array_from_gz_csv(&targets[..], true, b',')
        .expect("embedded diabetes_target.csv.gz should be a valid gzip-compressed CSV")
        .column(0)
        .to_owned();

    let feature_names = vec![
        "age",
        "sex",
        "body mass index",
        "blood pressure",
        "t-cells",
        "low-density lipoproteins",
        "high-density lipoproteins",
        "thyroid stimulating hormone",
        "lamotrigine",
        "blood sugar level",
    ];

    Dataset::new(data, targets).with_feature_names(feature_names)
}
82
/// Loads the red wine quality dataset embedded in the crate at build time.
///
/// The bundled `winequality-red.csv.gz` is read with a header row and comma
/// separators. Columns `0..11` become the `f64` records. Column `11` is
/// converted to `usize` labels with an `as` cast, which truncates toward zero.
///
/// # Panics
///
/// Panics if the embedded file cannot be parsed. This means the bundled data
/// file is broken, not that a recoverable runtime error occurred.
#[cfg(feature = "winequality")]
pub fn winequality() -> Dataset<f64, usize, Ix1> {
    let data = include_bytes!("../data/winequality-red.csv.gz");
    let array = array_from_gz_csv(&data[..], true, b',')
        .expect("embedded winequality-red.csv.gz should be a valid gzip-compressed CSV");

    let (data, targets) = (
        array.slice(s![.., 0..11]).to_owned(),
        array.column(11).to_owned(),
    );

    let feature_names = vec![
        "fixed acidity",
        "volatile acidity",
        "citric acid",
        "residual sugar",
        "chlorides",
        "free sulfur dioxide",
        "total sulfur dioxide",
        "density",
        "pH",
        "sulphates",
        "alcohol",
    ];

    Dataset::new(data, targets)
        .map_targets(|x| *x as usize)
        .with_feature_names(feature_names)
}
112
/// Loads the Linnerud multi-output dataset embedded in the crate at build time.
///
/// Records come from the bundled `linnerud_exercise.csv.gz`, with feature names
/// `Chins`, `Situps` and `Jumps`. Every column of `linnerud_physiological.csv.gz`
/// becomes a target, so the targets form a 2-D array. Both files are read with
/// a header row and comma separators.
///
/// # Panics
///
/// Panics if either embedded file cannot be parsed. This means the bundled
/// data is broken, not that a recoverable runtime error occurred.
#[cfg(feature = "linnerud")]
pub fn linnerud() -> Dataset<f64, f64> {
    let input_data = include_bytes!("../data/linnerud_exercise.csv.gz");
    let input_array = array_from_gz_csv(&input_data[..], true, b',')
        .expect("embedded linnerud_exercise.csv.gz should be a valid gzip-compressed CSV");

    let output_data = include_bytes!("../data/linnerud_physiological.csv.gz");
    let output_array = array_from_gz_csv(&output_data[..], true, b',')
        .expect("embedded linnerud_physiological.csv.gz should be a valid gzip-compressed CSV");

    let feature_names = vec!["Chins", "Situps", "Jumps"];

    Dataset::new(input_array, output_array).with_feature_names(feature_names)
}
137
// Sanity checks for the embedded datasets. Each test is gated on the same
// cargo feature as the loader it exercises.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use linfa::prelude::*;

    #[cfg(feature = "iris")]
    #[test]
    fn test_iris() {
        let ds = iris();

        // Expected shape: 150 samples, 4 features, a single target.
        assert_eq!((ds.nsamples(), ds.nfeatures(), ds.ntargets()), (150, 4, 1));

        // Feature names must match those attached by `iris()`.
        assert_eq!(
            ds.feature_names(),
            &["sepal length", "sepal width", "petal length", "petal width"]
        );

        // Three classes of 50 samples each. Only the counts are compared, so
        // the iteration order of `label_frequencies()` does not matter here
        // (all values are equal).
        assert_abs_diff_eq!(
            ds.label_frequencies()
                .into_iter()
                .map(|b| b.1)
                .collect::<Array1<_>>(),
            array![50., 50., 50.]
        );

        // Smoke test only: the correlation result is computed but not inspected.
        let _pcc = ds.pearson_correlation_with_p_value(100);
        // Per-feature means compared against reference values to within 0.01.
        let mean_features = ds.records().mean_axis(Axis(0)).unwrap();
        assert_abs_diff_eq!(
            mean_features,
            array![5.84, 3.05, 3.75, 1.20],
            epsilon = 0.01
        );
    }

    #[cfg(feature = "diabetes")]
    #[test]
    fn test_diabetes() {
        let ds = diabetes();

        // Expected shape: 441 samples, 10 features, a single target.
        assert_eq!((ds.nsamples(), ds.nfeatures(), ds.ntargets()), (441, 10, 1));

        // Smoke test only: the correlation result is computed but not inspected.
        let _pcc = ds.pearson_correlation_with_p_value(100);
        // The test expects every feature column to have a (near-)zero mean.
        let mean_features = ds.records().mean_axis(Axis(0)).unwrap();
        assert_abs_diff_eq!(mean_features, Array1::zeros(10), epsilon = 0.005);
    }

    #[cfg(feature = "winequality")]
    #[test]
    fn test_winequality() {
        use approx::abs_diff_eq;

        let ds = winequality();

        // Expected shape: 1599 samples, 11 features, a single target.
        assert_eq!(
            (ds.nsamples(), ds.nfeatures(), ds.ntargets()),
            (1599, 11, 1)
        );

        // Feature names must match those attached by `winequality()`.
        let feature_names = vec![
            "fixed acidity",
            "volatile acidity",
            "citric acid",
            "residual sugar",
            "chlorides",
            "free sulfur dioxide",
            "total sulfur dioxide",
            "density",
            "pH",
            "sulphates",
            "alcohol",
        ];
        assert_eq!(ds.feature_names(), feature_names);

        // Expected (label, count) pairs. The counts sum to 1599, i.e. every
        // sample is accounted for.
        let compare_to = vec![
            (5, 681.0),
            (7, 199.0),
            (6, 638.0),
            (8, 18.0),
            (3, 10.0),
            (4, 53.0),
        ];

        // Every expected label must be present with the expected count; a
        // missing label makes the check fail. Labels not listed above are not
        // checked here.
        let freqs = ds.label_frequencies();
        assert!(compare_to.into_iter().all(|(key, val)| {
            freqs
                .get(&key)
                .map(|x| abs_diff_eq!(*x, val))
                .unwrap_or(false)
        }));

        // Smoke test only: the correlation result is computed but not inspected.
        let _pcc = ds.pearson_correlation_with_p_value(100);
    }

    #[cfg(feature = "linnerud")]
    #[test]
    fn test_linnerud() {
        let ds = linnerud();

        // Expected shape: 20 samples, 3 features, 3 targets (multi-output).
        assert_eq!((ds.nsamples(), ds.nfeatures(), ds.ntargets()), (20, 3, 3));

        // Feature names must match those attached by `linnerud()`.
        let feature_names = vec!["Chins", "Situps", "Jumps"];
        assert_eq!(ds.feature_names(), feature_names);

        // Per-target means, compared with approx's default (very tight) epsilon.
        let mean_targets = ds.targets().mean_axis(Axis(0)).unwrap();
        assert_abs_diff_eq!(mean_targets, array![178.6, 35.4, 56.1]);
    }
}