1use std::sync::Arc;
17
18use arrow::{
19 array::{Array, Float64Array, RecordBatch, StringArray},
20 datatypes::{DataType, Field, Schema},
21};
22
23use super::CanonicalDataset;
24use crate::{ArrowDataset, Dataset, Result};
25
26pub fn iris() -> Result<IrisDataset> {
53 IrisDataset::load()
54}
55
56#[derive(Debug, Clone)]
61pub struct IrisDataset {
62 data: ArrowDataset,
63}
64
65impl IrisDataset {
66 pub fn load() -> Result<Self> {
72 let schema = Arc::new(Schema::new(vec![
73 Field::new("sepal_length", DataType::Float64, false),
74 Field::new("sepal_width", DataType::Float64, false),
75 Field::new("petal_length", DataType::Float64, false),
76 Field::new("petal_width", DataType::Float64, false),
77 Field::new("species", DataType::Utf8, false),
78 ]));
79
80 let (sepal_length, sepal_width, petal_length, petal_width, species) = iris_data();
83
84 let batch = RecordBatch::try_new(
85 schema,
86 vec![
87 Arc::new(Float64Array::from(sepal_length)),
88 Arc::new(Float64Array::from(sepal_width)),
89 Arc::new(Float64Array::from(petal_length)),
90 Arc::new(Float64Array::from(petal_width)),
91 Arc::new(StringArray::from(species)),
92 ],
93 )
94 .map_err(crate::Error::Arrow)?;
95
96 let data = ArrowDataset::from_batch(batch)?;
97
98 Ok(Self { data })
99 }
100
101 #[must_use]
103 pub fn into_inner(self) -> ArrowDataset {
104 self.data
105 }
106
107 pub fn features(&self) -> Result<ArrowDataset> {
113 use crate::transform::{Select, Transform};
114 let select = Select::new(vec![
115 "sepal_length",
116 "sepal_width",
117 "petal_length",
118 "petal_width",
119 ]);
120 let batch = select.apply(
121 self.data
122 .get_batch(0)
123 .ok_or_else(|| crate::Error::empty_dataset("Iris dataset is empty"))?
124 .clone(),
125 )?;
126 ArrowDataset::from_batch(batch)
127 }
128
129 #[must_use]
131 pub fn labels(&self) -> Vec<String> {
132 if let Some(batch) = self.data.get_batch(0) {
133 if let Some(col) = batch.column_by_name("species") {
134 if let Some(arr) = col.as_any().downcast_ref::<StringArray>() {
135 return (0..arr.len()).map(|i| arr.value(i).to_string()).collect();
136 }
137 }
138 }
139 Vec::new()
140 }
141
142 #[must_use]
144 pub fn labels_numeric(&self) -> Vec<i32> {
145 self.labels()
146 .iter()
147 .map(|s| match s.as_str() {
148 "setosa" => 0,
149 "versicolor" => 1,
150 "virginica" => 2,
151 _ => -1,
152 })
153 .collect()
154 }
155}
156
157impl CanonicalDataset for IrisDataset {
158 fn data(&self) -> &ArrowDataset {
159 &self.data
160 }
161
162 fn num_features(&self) -> usize {
163 4
164 }
165
166 fn num_classes(&self) -> usize {
167 3
168 }
169
170 fn feature_names(&self) -> &'static [&'static str] {
171 &["sepal_length", "sepal_width", "petal_length", "petal_width"]
172 }
173
174 fn target_name(&self) -> &'static str {
175 "species"
176 }
177
178 fn description(&self) -> &'static str {
179 "Iris flower dataset (Fisher, 1936). 150 samples of 3 iris species \
180 (setosa, versicolor, virginica) with 4 features: sepal length/width \
181 and petal length/width in centimeters."
182 }
183}
184
/// Column-oriented iris table produced by [`iris_data`]:
/// `(sepal_length, sepal_width, petal_length, petal_width, species)`.
type IrisColumns = (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>, Vec<&'static str>);

/// Builds the raw iris measurements as five parallel columns of 150 rows each.
///
/// Rows are grouped by species:
/// - indices `0..50` are setosa;
/// - indices `50..100` are versicolor;
/// - indices `100..150` are virginica.
///
/// Each per-species column is a fixed-size `[f64; SAMPLES_PER_SPECIES]` array.
/// A missing or extra value is therefore a compile-time error, rather than a
/// column-length mismatch discovered when the record batch is assembled.
///
/// NOTE(review): the values appear to follow the corrected table distributed
/// with R's `datasets::iris` (for example, row 38 is `4.9, 3.6, 1.4, 0.1`)
/// rather than the UCI `iris.data` file. Confirm this if exact UCI parity matters.
// The per-species bindings (`setosa_sl`, `setosa_sw`, ...) deliberately share
// prefixes; `similar_names` would otherwise flag them.
#[allow(clippy::similar_names)]
fn iris_data() -> IrisColumns {
    const SAMPLES_PER_SPECIES: usize = 50;

    let setosa_sl: [f64; SAMPLES_PER_SPECIES] = [
        5.1, 4.9, 4.7, 4.6, 5.0, 5.4, 4.6, 5.0, 4.4, 4.9, 5.4, 4.8, 4.8, 4.3, 5.8, 5.7, 5.4, 5.1,
        5.7, 5.1, 5.4, 5.1, 4.6, 5.1, 4.8, 5.0, 5.0, 5.2, 5.2, 4.7, 4.8, 5.4, 5.2, 5.5, 4.9, 5.0,
        5.5, 4.9, 4.4, 5.1, 5.0, 4.5, 4.4, 5.0, 5.1, 4.8, 5.1, 4.6, 5.3, 5.0,
    ];
    let setosa_sw: [f64; SAMPLES_PER_SPECIES] = [
        3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 2.9, 3.1, 3.7, 3.4, 3.0, 3.0, 4.0, 4.4, 3.9, 3.5,
        3.8, 3.8, 3.4, 3.7, 3.6, 3.3, 3.4, 3.0, 3.4, 3.5, 3.4, 3.2, 3.1, 3.4, 4.1, 4.2, 3.1, 3.2,
        3.5, 3.6, 3.0, 3.4, 3.5, 2.3, 3.2, 3.5, 3.8, 3.0, 3.8, 3.2, 3.7, 3.3,
    ];
    let setosa_pl: [f64; SAMPLES_PER_SPECIES] = [
        1.4, 1.4, 1.3, 1.5, 1.4, 1.7, 1.4, 1.5, 1.4, 1.5, 1.5, 1.6, 1.4, 1.1, 1.2, 1.5, 1.3, 1.4,
        1.7, 1.5, 1.7, 1.5, 1.0, 1.7, 1.9, 1.6, 1.6, 1.5, 1.4, 1.6, 1.6, 1.5, 1.5, 1.4, 1.5, 1.2,
        1.3, 1.4, 1.3, 1.5, 1.3, 1.3, 1.3, 1.6, 1.9, 1.4, 1.6, 1.4, 1.5, 1.4,
    ];
    let setosa_pw: [f64; SAMPLES_PER_SPECIES] = [
        0.2, 0.2, 0.2, 0.2, 0.2, 0.4, 0.3, 0.2, 0.2, 0.1, 0.2, 0.2, 0.1, 0.1, 0.2, 0.4, 0.4, 0.3,
        0.3, 0.3, 0.2, 0.4, 0.2, 0.5, 0.2, 0.2, 0.4, 0.2, 0.2, 0.2, 0.2, 0.4, 0.1, 0.2, 0.2, 0.2,
        0.2, 0.1, 0.2, 0.2, 0.3, 0.3, 0.2, 0.6, 0.4, 0.3, 0.2, 0.2, 0.2, 0.2,
    ];

    let versicolor_sl: [f64; SAMPLES_PER_SPECIES] = [
        7.0, 6.4, 6.9, 5.5, 6.5, 5.7, 6.3, 4.9, 6.6, 5.2, 5.0, 5.9, 6.0, 6.1, 5.6, 6.7, 5.6, 5.8,
        6.2, 5.6, 5.9, 6.1, 6.3, 6.1, 6.4, 6.6, 6.8, 6.7, 6.0, 5.7, 5.5, 5.5, 5.8, 6.0, 5.4, 6.0,
        6.7, 6.3, 5.6, 5.5, 5.5, 6.1, 5.8, 5.0, 5.6, 5.7, 5.7, 6.2, 5.1, 5.7,
    ];
    let versicolor_sw: [f64; SAMPLES_PER_SPECIES] = [
        3.2, 3.2, 3.1, 2.3, 2.8, 2.8, 3.3, 2.4, 2.9, 2.7, 2.0, 3.0, 2.2, 2.9, 2.9, 3.1, 3.0, 2.7,
        2.2, 2.5, 3.2, 2.8, 2.5, 2.8, 2.9, 3.0, 2.8, 3.0, 2.9, 2.6, 2.4, 2.4, 2.7, 2.7, 3.0, 3.4,
        3.1, 2.3, 3.0, 2.5, 2.6, 3.0, 2.6, 2.3, 2.7, 3.0, 2.9, 2.9, 2.5, 2.8,
    ];
    let versicolor_pl: [f64; SAMPLES_PER_SPECIES] = [
        4.7, 4.5, 4.9, 4.0, 4.6, 4.5, 4.7, 3.3, 4.6, 3.9, 3.5, 4.2, 4.0, 4.7, 3.6, 4.4, 4.5, 4.1,
        4.5, 3.9, 4.8, 4.0, 4.9, 4.7, 4.3, 4.4, 4.8, 5.0, 4.5, 3.5, 3.8, 3.7, 3.9, 5.1, 4.5, 4.5,
        4.7, 4.4, 4.1, 4.0, 4.4, 4.6, 4.0, 3.3, 4.2, 4.2, 4.2, 4.3, 3.0, 4.1,
    ];
    let versicolor_pw: [f64; SAMPLES_PER_SPECIES] = [
        1.4, 1.5, 1.5, 1.3, 1.5, 1.3, 1.6, 1.0, 1.3, 1.4, 1.0, 1.5, 1.0, 1.4, 1.3, 1.4, 1.5, 1.0,
        1.5, 1.1, 1.8, 1.3, 1.5, 1.2, 1.3, 1.4, 1.4, 1.7, 1.5, 1.0, 1.1, 1.0, 1.2, 1.6, 1.5, 1.6,
        1.5, 1.3, 1.3, 1.3, 1.2, 1.4, 1.2, 1.0, 1.3, 1.2, 1.3, 1.3, 1.1, 1.3,
    ];

    let virginica_sl: [f64; SAMPLES_PER_SPECIES] = [
        6.3, 5.8, 7.1, 6.3, 6.5, 7.6, 4.9, 7.3, 6.7, 7.2, 6.5, 6.4, 6.8, 5.7, 5.8, 6.4, 6.5, 7.7,
        7.7, 6.0, 6.9, 5.6, 7.7, 6.3, 6.7, 7.2, 6.2, 6.1, 6.4, 7.2, 7.4, 7.9, 6.4, 6.3, 6.1, 7.7,
        6.3, 6.4, 6.0, 6.9, 6.7, 6.9, 5.8, 6.8, 6.7, 6.7, 6.3, 6.5, 6.2, 5.9,
    ];
    let virginica_sw: [f64; SAMPLES_PER_SPECIES] = [
        3.3, 2.7, 3.0, 2.9, 3.0, 3.0, 2.5, 2.9, 2.5, 3.6, 3.2, 2.7, 3.0, 2.5, 2.8, 3.2, 3.0, 3.8,
        2.6, 2.2, 3.2, 2.8, 2.8, 2.7, 3.3, 3.2, 2.8, 3.0, 2.8, 3.0, 2.8, 3.8, 2.8, 2.8, 2.6, 3.0,
        3.4, 3.1, 3.0, 3.1, 3.1, 3.1, 2.7, 3.2, 3.3, 3.0, 2.5, 3.0, 3.4, 3.0,
    ];
    let virginica_pl: [f64; SAMPLES_PER_SPECIES] = [
        6.0, 5.1, 5.9, 5.6, 5.8, 6.6, 4.5, 6.3, 5.8, 6.1, 5.1, 5.3, 5.5, 5.0, 5.1, 5.3, 5.5, 6.7,
        6.9, 5.0, 5.7, 4.9, 6.7, 4.9, 5.7, 6.0, 4.8, 4.9, 5.6, 5.8, 6.1, 6.4, 5.6, 5.1, 5.6, 6.1,
        5.6, 5.5, 4.8, 5.4, 5.6, 5.1, 5.1, 5.9, 5.7, 5.2, 5.0, 5.2, 5.4, 5.1,
    ];
    let virginica_pw: [f64; SAMPLES_PER_SPECIES] = [
        2.5, 1.9, 2.1, 1.8, 2.2, 2.1, 1.7, 1.8, 1.8, 2.5, 2.0, 1.9, 2.1, 2.0, 2.4, 2.3, 1.8, 2.2,
        2.3, 1.5, 2.3, 2.0, 2.0, 1.8, 2.1, 1.8, 1.8, 1.8, 2.1, 1.6, 1.9, 2.0, 2.2, 1.5, 1.4, 2.3,
        2.4, 1.8, 1.8, 2.1, 2.4, 2.3, 1.9, 2.3, 2.5, 2.3, 1.9, 2.0, 2.3, 1.8,
    ];

    // Stack the species blocks in a fixed order: setosa, versicolor, virginica.
    let sepal_length = [setosa_sl, versicolor_sl, virginica_sl].concat();
    let sepal_width = [setosa_sw, versicolor_sw, virginica_sw].concat();
    let petal_length = [setosa_pl, versicolor_pl, virginica_pl].concat();
    let petal_width = [setosa_pw, versicolor_pw, virginica_pw].concat();

    // Labels line up with the measurement blocks above.
    let species = [
        ["setosa"; SAMPLES_PER_SPECIES],
        ["versicolor"; SAMPLES_PER_SPECIES],
        ["virginica"; SAMPLES_PER_SPECIES],
    ]
    .concat();

    (
        sepal_length,
        sepal_width,
        petal_length,
        petal_width,
        species,
    )
}
285
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Dataset;

    /// Loads the dataset, failing the calling test with the underlying error.
    fn load_iris() -> IrisDataset {
        iris().unwrap_or_else(|e| panic!("Failed to load iris: {e}"))
    }

    #[test]
    fn test_iris_load() {
        let dataset = load_iris();
        assert_eq!(dataset.len(), 150);
    }

    #[test]
    fn test_iris_features() {
        let dataset = load_iris();
        assert_eq!(dataset.num_features(), 4);
        assert_eq!(dataset.num_classes(), 3);
    }

    #[test]
    fn test_iris_labels() {
        let labels = load_iris().labels();
        assert_eq!(labels.len(), 150);

        let count = |species: &str| labels.iter().filter(|s| *s == species).count();
        assert_eq!(count("setosa"), 50);
        assert_eq!(count("versicolor"), 50);
        assert_eq!(count("virginica"), 50);
    }

    #[test]
    fn test_iris_labels_numeric() {
        let labels = load_iris().labels_numeric();
        assert_eq!(labels.len(), 150);

        assert!(labels[0..50].iter().all(|&x| x == 0));
        assert!(labels[50..100].iter().all(|&x| x == 1));
        assert!(labels[100..150].iter().all(|&x| x == 2));
    }

    #[test]
    fn test_iris_labels_numeric_matches_labels() {
        let dataset = load_iris();
        let names = dataset.labels();
        let codes = dataset.labels_numeric();
        assert_eq!(names.len(), codes.len());

        for (name, &code) in names.iter().zip(&codes) {
            let expected = match name.as_str() {
                "setosa" => 0,
                "versicolor" => 1,
                "virginica" => 2,
                other => panic!("unexpected species {other:?}"),
            };
            assert_eq!(code, expected, "wrong code for {name}");
        }
    }

    #[test]
    fn test_iris_schema() {
        let dataset = load_iris();
        let schema = dataset.data().schema();

        assert_eq!(schema.fields().len(), 5);
        for name in [
            "sepal_length",
            "sepal_width",
            "petal_length",
            "petal_width",
            "species",
        ] {
            assert!(schema.field_with_name(name).is_ok(), "missing column {name}");
        }
    }

    #[test]
    fn test_iris_feature_extraction() {
        let features = load_iris()
            .features()
            .unwrap_or_else(|e| panic!("Failed: {e}"));
        assert_eq!(features.schema().fields().len(), 4);
        assert!(features.schema().field_with_name("species").is_err());
    }

    #[test]
    fn test_iris_feature_columns_match_feature_names() {
        let dataset = load_iris();
        let features = dataset
            .features()
            .unwrap_or_else(|e| panic!("Failed: {e}"));
        let columns: Vec<String> = features
            .schema()
            .fields()
            .iter()
            .map(|field| field.name().to_string())
            .collect();
        assert_eq!(columns, dataset.feature_names());
    }

    #[test]
    fn test_iris_description() {
        let dataset = load_iris();
        let description = dataset.description();
        assert!(description.contains("Fisher"));
        assert!(description.contains("150"));
    }

    #[test]
    fn test_iris_canonical_trait() {
        let dataset = load_iris();
        assert_eq!(dataset.feature_names().len(), 4);
        assert_eq!(dataset.target_name(), "species");
        assert!(!dataset.is_empty());
    }

    #[test]
    fn test_iris_into_inner() {
        let inner = load_iris().into_inner();
        assert_eq!(inner.len(), 150);
    }

    #[test]
    fn test_iris_clone() {
        let dataset = load_iris();
        let cloned = dataset.clone();
        assert_eq!(cloned.len(), dataset.len());
    }

    #[test]
    fn test_iris_debug() {
        let dataset = load_iris();
        let debug = format!("{dataset:?}");
        assert!(debug.contains("IrisDataset"));
    }

    #[test]
    fn test_iris_data_access() {
        let dataset = load_iris();
        assert_eq!(dataset.data().len(), 150);
    }

    #[test]
    fn test_iris_data_function() {
        let (sl, sw, pl, pw, species) = iris_data();
        assert_eq!(sl.len(), 150);
        assert_eq!(sw.len(), 150);
        assert_eq!(pl.len(), 150);
        assert_eq!(pw.len(), 150);
        assert_eq!(species.len(), 150);
    }

    #[test]
    fn test_iris_data_species_distribution() {
        let (_, _, _, _, species) = iris_data();
        let count = |name: &str| species.iter().filter(|&&s| s == name).count();
        assert_eq!(count("setosa"), 50);
        assert_eq!(count("versicolor"), 50);
        assert_eq!(count("virginica"), 50);
    }

    #[test]
    fn test_iris_data_reference_sums() {
        // Per-species column sums: 50 x the well-known per-species means.
        let (sl, sw, pl, pw, _) = iris_data();
        let expected = [
            (&sl, [250.3, 296.8, 329.4]),
            (&sw, [171.4, 138.5, 148.7]),
            (&pl, [73.1, 213.0, 277.6]),
            (&pw, [12.3, 66.3, 101.3]),
        ];
        for (col, sums) in expected {
            for (k, want) in sums.iter().enumerate() {
                let got: f64 = col[k * 50..(k + 1) * 50].iter().sum();
                assert!(
                    (got - want).abs() < 1e-9,
                    "species block {k}: got {got}, want {want}"
                );
            }
        }
    }

    #[test]
    fn test_iris_sepal_length_range() {
        let (sepal_length, _, _, _, _) = iris_data();
        for &val in &sepal_length {
            assert!(
                (4.0..=8.0).contains(&val),
                "Sepal length {val} out of typical range"
            );
        }
    }

    #[test]
    fn test_iris_sepal_width_range() {
        let (_, sepal_width, _, _, _) = iris_data();
        for &val in &sepal_width {
            assert!(
                (2.0..=5.0).contains(&val),
                "Sepal width {val} out of typical range"
            );
        }
    }

    #[test]
    fn test_iris_feature_names_content() {
        let dataset = load_iris();
        let names = dataset.feature_names();
        assert!(names.contains(&"sepal_length"));
        assert!(names.contains(&"sepal_width"));
        assert!(names.contains(&"petal_length"));
        assert!(names.contains(&"petal_width"));
    }
}