Skip to main content

alimentar/datasets/
iris.rs

1//! Iris dataset loader
2//!
3//! The classic Iris flower dataset by Ronald Fisher (1936).
4//! Contains 150 samples of 3 iris species with 4 features each.
5//!
6//! # Example
7//!
8//! ```
9//! use alimentar::datasets::{iris, CanonicalDataset};
10//!
11//! let dataset = iris().unwrap();
12//! assert_eq!(dataset.len(), 150);
13//! assert_eq!(dataset.num_classes(), 3);
14//! ```
15
16use std::sync::Arc;
17
18use arrow::{
19    array::{Array, Float64Array, RecordBatch, StringArray},
20    datatypes::{DataType, Field, Schema},
21};
22
23use super::CanonicalDataset;
24use crate::{ArrowDataset, Dataset, Result};
25
26/// Load the Iris dataset
27///
28/// Returns a dataset with 150 samples and 5 columns:
29/// - sepal_length (f64)
30/// - sepal_width (f64)
31/// - petal_length (f64)
32/// - petal_width (f64)
33/// - species (string: "setosa", "versicolor", "virginica")
34///
35/// # Errors
36///
37/// Returns an error if the dataset cannot be constructed (should never happen
38/// for embedded data).
39///
40/// # Example
41///
42/// ```
43/// use alimentar::datasets::{iris, CanonicalDataset};
44///
45/// let dataset = iris().unwrap();
46/// println!(
47///     "Iris dataset: {} samples, {} features",
48///     dataset.len(),
49///     dataset.num_features()
50/// );
51/// ```
52pub fn iris() -> Result<IrisDataset> {
53    IrisDataset::load()
54}
55
56/// The Iris flower dataset
57///
58/// A classic dataset for classification containing measurements of 150 iris
59/// flowers from 3 species (setosa, versicolor, virginica).
60#[derive(Debug, Clone)]
61pub struct IrisDataset {
62    data: ArrowDataset,
63}
64
65impl IrisDataset {
66    /// Load the embedded Iris dataset
67    ///
68    /// # Errors
69    ///
70    /// Returns an error if dataset construction fails.
71    pub fn load() -> Result<Self> {
72        let schema = Arc::new(Schema::new(vec![
73            Field::new("sepal_length", DataType::Float64, false),
74            Field::new("sepal_width", DataType::Float64, false),
75            Field::new("petal_length", DataType::Float64, false),
76            Field::new("petal_width", DataType::Float64, false),
77            Field::new("species", DataType::Utf8, false),
78        ]));
79
80        // Embedded Iris data (150 samples)
81        // Data from UCI ML Repository / scikit-learn
82        let (sepal_length, sepal_width, petal_length, petal_width, species) = iris_data();
83
84        let batch = RecordBatch::try_new(
85            schema,
86            vec![
87                Arc::new(Float64Array::from(sepal_length)),
88                Arc::new(Float64Array::from(sepal_width)),
89                Arc::new(Float64Array::from(petal_length)),
90                Arc::new(Float64Array::from(petal_width)),
91                Arc::new(StringArray::from(species)),
92            ],
93        )
94        .map_err(crate::Error::Arrow)?;
95
96        let data = ArrowDataset::from_batch(batch)?;
97
98        Ok(Self { data })
99    }
100
101    /// Get the underlying Arrow dataset
102    #[must_use]
103    pub fn into_inner(self) -> ArrowDataset {
104        self.data
105    }
106
107    /// Get feature columns as a new dataset (excludes species)
108    ///
109    /// # Errors
110    ///
111    /// Returns an error if transform fails.
112    pub fn features(&self) -> Result<ArrowDataset> {
113        use crate::transform::{Select, Transform};
114        let select = Select::new(vec![
115            "sepal_length",
116            "sepal_width",
117            "petal_length",
118            "petal_width",
119        ]);
120        let batch = select.apply(
121            self.data
122                .get_batch(0)
123                .ok_or_else(|| crate::Error::empty_dataset("Iris dataset is empty"))?
124                .clone(),
125        )?;
126        ArrowDataset::from_batch(batch)
127    }
128
129    /// Get species labels as string array
130    #[must_use]
131    pub fn labels(&self) -> Vec<String> {
132        if let Some(batch) = self.data.get_batch(0) {
133            if let Some(col) = batch.column_by_name("species") {
134                if let Some(arr) = col.as_any().downcast_ref::<StringArray>() {
135                    return (0..arr.len()).map(|i| arr.value(i).to_string()).collect();
136                }
137            }
138        }
139        Vec::new()
140    }
141
142    /// Get species labels as numeric (0=setosa, 1=versicolor, 2=virginica)
143    #[must_use]
144    pub fn labels_numeric(&self) -> Vec<i32> {
145        self.labels()
146            .iter()
147            .map(|s| match s.as_str() {
148                "setosa" => 0,
149                "versicolor" => 1,
150                "virginica" => 2,
151                _ => -1,
152            })
153            .collect()
154    }
155}
156
157impl CanonicalDataset for IrisDataset {
158    fn data(&self) -> &ArrowDataset {
159        &self.data
160    }
161
162    fn num_features(&self) -> usize {
163        4
164    }
165
166    fn num_classes(&self) -> usize {
167        3
168    }
169
170    fn feature_names(&self) -> &'static [&'static str] {
171        &["sepal_length", "sepal_width", "petal_length", "petal_width"]
172    }
173
174    fn target_name(&self) -> &'static str {
175        "species"
176    }
177
178    fn description(&self) -> &'static str {
179        "Iris flower dataset (Fisher, 1936). 150 samples of 3 iris species \
180         (setosa, versicolor, virginica) with 4 features: sepal length/width \
181         and petal length/width in centimeters."
182    }
183}
184
/// Returns the embedded Iris measurements as parallel columns.
///
/// The tuple is `(sepal_length, sepal_width, petal_length, petal_width,
/// species)`. Every vector has 150 entries: 50 setosa rows, then 50
/// versicolor rows, then 50 virginica rows. Index `i` of each vector therefore
/// describes the same flower.
///
/// The values match scikit-learn's copy of Fisher's table. This includes the
/// corrected setosa rows 35 `(4.9, 3.1, 1.5, 0.2)` and 38 `(4.9, 3.6, 1.4, 0.1)`
/// (1-based), which are recorded differently in the UCI `iris.data` file.
///
/// Each per-species block is a fixed-size `[f64; PER_SPECIES]` array. A
/// missing or extra value is therefore a compile error, rather than a silent
/// misalignment between measurements and species labels.
// `type_complexity`: the private 5-tuple return type is intentional.
// `similar_names`: per-species/per-column names differ by one token by design.
#[allow(clippy::type_complexity, clippy::similar_names)]
fn iris_data() -> (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>, Vec<&'static str>) {
    // Rows per species; also the length of every per-species array below.
    const PER_SPECIES: usize = 50;

    // Iris setosa
    const SETOSA_SL: [f64; PER_SPECIES] = [
        5.1, 4.9, 4.7, 4.6, 5.0, 5.4, 4.6, 5.0, 4.4, 4.9, 5.4, 4.8, 4.8, 4.3, 5.8, 5.7, 5.4, 5.1,
        5.7, 5.1, 5.4, 5.1, 4.6, 5.1, 4.8, 5.0, 5.0, 5.2, 5.2, 4.7, 4.8, 5.4, 5.2, 5.5, 4.9, 5.0,
        5.5, 4.9, 4.4, 5.1, 5.0, 4.5, 4.4, 5.0, 5.1, 4.8, 5.1, 4.6, 5.3, 5.0,
    ];
    const SETOSA_SW: [f64; PER_SPECIES] = [
        3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 2.9, 3.1, 3.7, 3.4, 3.0, 3.0, 4.0, 4.4, 3.9, 3.5,
        3.8, 3.8, 3.4, 3.7, 3.6, 3.3, 3.4, 3.0, 3.4, 3.5, 3.4, 3.2, 3.1, 3.4, 4.1, 4.2, 3.1, 3.2,
        3.5, 3.6, 3.0, 3.4, 3.5, 2.3, 3.2, 3.5, 3.8, 3.0, 3.8, 3.2, 3.7, 3.3,
    ];
    const SETOSA_PL: [f64; PER_SPECIES] = [
        1.4, 1.4, 1.3, 1.5, 1.4, 1.7, 1.4, 1.5, 1.4, 1.5, 1.5, 1.6, 1.4, 1.1, 1.2, 1.5, 1.3, 1.4,
        1.7, 1.5, 1.7, 1.5, 1.0, 1.7, 1.9, 1.6, 1.6, 1.5, 1.4, 1.6, 1.6, 1.5, 1.5, 1.4, 1.5, 1.2,
        1.3, 1.4, 1.3, 1.5, 1.3, 1.3, 1.3, 1.6, 1.9, 1.4, 1.6, 1.4, 1.5, 1.4,
    ];
    const SETOSA_PW: [f64; PER_SPECIES] = [
        0.2, 0.2, 0.2, 0.2, 0.2, 0.4, 0.3, 0.2, 0.2, 0.1, 0.2, 0.2, 0.1, 0.1, 0.2, 0.4, 0.4, 0.3,
        0.3, 0.3, 0.2, 0.4, 0.2, 0.5, 0.2, 0.2, 0.4, 0.2, 0.2, 0.2, 0.2, 0.4, 0.1, 0.2, 0.2, 0.2,
        0.2, 0.1, 0.2, 0.2, 0.3, 0.3, 0.2, 0.6, 0.4, 0.3, 0.2, 0.2, 0.2, 0.2,
    ];

    // Iris versicolor
    const VERSICOLOR_SL: [f64; PER_SPECIES] = [
        7.0, 6.4, 6.9, 5.5, 6.5, 5.7, 6.3, 4.9, 6.6, 5.2, 5.0, 5.9, 6.0, 6.1, 5.6, 6.7, 5.6, 5.8,
        6.2, 5.6, 5.9, 6.1, 6.3, 6.1, 6.4, 6.6, 6.8, 6.7, 6.0, 5.7, 5.5, 5.5, 5.8, 6.0, 5.4, 6.0,
        6.7, 6.3, 5.6, 5.5, 5.5, 6.1, 5.8, 5.0, 5.6, 5.7, 5.7, 6.2, 5.1, 5.7,
    ];
    const VERSICOLOR_SW: [f64; PER_SPECIES] = [
        3.2, 3.2, 3.1, 2.3, 2.8, 2.8, 3.3, 2.4, 2.9, 2.7, 2.0, 3.0, 2.2, 2.9, 2.9, 3.1, 3.0, 2.7,
        2.2, 2.5, 3.2, 2.8, 2.5, 2.8, 2.9, 3.0, 2.8, 3.0, 2.9, 2.6, 2.4, 2.4, 2.7, 2.7, 3.0, 3.4,
        3.1, 2.3, 3.0, 2.5, 2.6, 3.0, 2.6, 2.3, 2.7, 3.0, 2.9, 2.9, 2.5, 2.8,
    ];
    const VERSICOLOR_PL: [f64; PER_SPECIES] = [
        4.7, 4.5, 4.9, 4.0, 4.6, 4.5, 4.7, 3.3, 4.6, 3.9, 3.5, 4.2, 4.0, 4.7, 3.6, 4.4, 4.5, 4.1,
        4.5, 3.9, 4.8, 4.0, 4.9, 4.7, 4.3, 4.4, 4.8, 5.0, 4.5, 3.5, 3.8, 3.7, 3.9, 5.1, 4.5, 4.5,
        4.7, 4.4, 4.1, 4.0, 4.4, 4.6, 4.0, 3.3, 4.2, 4.2, 4.2, 4.3, 3.0, 4.1,
    ];
    const VERSICOLOR_PW: [f64; PER_SPECIES] = [
        1.4, 1.5, 1.5, 1.3, 1.5, 1.3, 1.6, 1.0, 1.3, 1.4, 1.0, 1.5, 1.0, 1.4, 1.3, 1.4, 1.5, 1.0,
        1.5, 1.1, 1.8, 1.3, 1.5, 1.2, 1.3, 1.4, 1.4, 1.7, 1.5, 1.0, 1.1, 1.0, 1.2, 1.6, 1.5, 1.6,
        1.5, 1.3, 1.3, 1.3, 1.2, 1.4, 1.2, 1.0, 1.3, 1.2, 1.3, 1.3, 1.1, 1.3,
    ];

    // Iris virginica
    const VIRGINICA_SL: [f64; PER_SPECIES] = [
        6.3, 5.8, 7.1, 6.3, 6.5, 7.6, 4.9, 7.3, 6.7, 7.2, 6.5, 6.4, 6.8, 5.7, 5.8, 6.4, 6.5, 7.7,
        7.7, 6.0, 6.9, 5.6, 7.7, 6.3, 6.7, 7.2, 6.2, 6.1, 6.4, 7.2, 7.4, 7.9, 6.4, 6.3, 6.1, 7.7,
        6.3, 6.4, 6.0, 6.9, 6.7, 6.9, 5.8, 6.8, 6.7, 6.7, 6.3, 6.5, 6.2, 5.9,
    ];
    const VIRGINICA_SW: [f64; PER_SPECIES] = [
        3.3, 2.7, 3.0, 2.9, 3.0, 3.0, 2.5, 2.9, 2.5, 3.6, 3.2, 2.7, 3.0, 2.5, 2.8, 3.2, 3.0, 3.8,
        2.6, 2.2, 3.2, 2.8, 2.8, 2.7, 3.3, 3.2, 2.8, 3.0, 2.8, 3.0, 2.8, 3.8, 2.8, 2.8, 2.6, 3.0,
        3.4, 3.1, 3.0, 3.1, 3.1, 3.1, 2.7, 3.2, 3.3, 3.0, 2.5, 3.0, 3.4, 3.0,
    ];
    const VIRGINICA_PL: [f64; PER_SPECIES] = [
        6.0, 5.1, 5.9, 5.6, 5.8, 6.6, 4.5, 6.3, 5.8, 6.1, 5.1, 5.3, 5.5, 5.0, 5.1, 5.3, 5.5, 6.7,
        6.9, 5.0, 5.7, 4.9, 6.7, 4.9, 5.7, 6.0, 4.8, 4.9, 5.6, 5.8, 6.1, 6.4, 5.6, 5.1, 5.6, 6.1,
        5.6, 5.5, 4.8, 5.4, 5.6, 5.1, 5.1, 5.9, 5.7, 5.2, 5.0, 5.2, 5.4, 5.1,
    ];
    const VIRGINICA_PW: [f64; PER_SPECIES] = [
        2.5, 1.9, 2.1, 1.8, 2.2, 2.1, 1.7, 1.8, 1.8, 2.5, 2.0, 1.9, 2.1, 2.0, 2.4, 2.3, 1.8, 2.2,
        2.3, 1.5, 2.3, 2.0, 2.0, 1.8, 2.1, 1.8, 1.8, 1.8, 2.1, 1.6, 1.9, 2.0, 2.2, 1.5, 1.4, 2.3,
        2.4, 1.8, 1.8, 2.1, 2.4, 2.3, 1.9, 2.3, 2.5, 2.3, 1.9, 2.0, 2.3, 1.8,
    ];

    // Concatenate species blocks in the same order the labels are emitted below.
    let sepal_length = [SETOSA_SL, VERSICOLOR_SL, VIRGINICA_SL].concat();
    let sepal_width = [SETOSA_SW, VERSICOLOR_SW, VIRGINICA_SW].concat();
    let petal_length = [SETOSA_PL, VERSICOLOR_PL, VIRGINICA_PL].concat();
    let petal_width = [SETOSA_PW, VERSICOLOR_PW, VIRGINICA_PW].concat();

    let species: Vec<&'static str> = ["setosa", "versicolor", "virginica"]
        .iter()
        .flat_map(|&name| std::iter::repeat(name).take(PER_SPECIES))
        .collect();

    (
        sepal_length,
        sepal_width,
        petal_length,
        petal_width,
        species,
    )
}
285
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Dataset;

    /// Load the dataset, failing the test with the underlying error on failure.
    fn load_iris() -> IrisDataset {
        iris().unwrap_or_else(|e| panic!("Failed to load iris: {e}"))
    }

    /// Float comparison with a small absolute tolerance.
    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() < 1e-12
    }

    #[test]
    fn test_iris_load() {
        let dataset = load_iris();
        assert_eq!(dataset.len(), 150);
    }

    #[test]
    fn test_iris_features() {
        let dataset = load_iris();
        assert_eq!(dataset.num_features(), 4);
        assert_eq!(dataset.num_classes(), 3);
    }

    #[test]
    fn test_iris_labels() {
        let dataset = load_iris();
        let labels = dataset.labels();
        assert_eq!(labels.len(), 150);

        // Check distribution: 50 of each species
        let setosa_count = labels.iter().filter(|s| *s == "setosa").count();
        let versicolor_count = labels.iter().filter(|s| *s == "versicolor").count();
        let virginica_count = labels.iter().filter(|s| *s == "virginica").count();

        assert_eq!(setosa_count, 50);
        assert_eq!(versicolor_count, 50);
        assert_eq!(virginica_count, 50);
    }

    #[test]
    fn test_iris_labels_numeric() {
        let dataset = load_iris();
        let labels = dataset.labels_numeric();
        assert_eq!(labels.len(), 150);

        // First 50 should be 0 (setosa)
        assert!(labels[0..50].iter().all(|&x| x == 0));
        // Next 50 should be 1 (versicolor)
        assert!(labels[50..100].iter().all(|&x| x == 1));
        // Last 50 should be 2 (virginica)
        assert!(labels[100..150].iter().all(|&x| x == 2));
    }

    #[test]
    fn test_iris_schema() {
        let dataset = load_iris();
        let schema = dataset.data().schema();

        assert_eq!(schema.fields().len(), 5);
        assert!(schema.field_with_name("sepal_length").is_ok());
        assert!(schema.field_with_name("sepal_width").is_ok());
        assert!(schema.field_with_name("petal_length").is_ok());
        assert!(schema.field_with_name("petal_width").is_ok());
        assert!(schema.field_with_name("species").is_ok());
    }

    #[test]
    fn test_iris_feature_extraction() {
        let dataset = load_iris();
        let features = dataset
            .features()
            .unwrap_or_else(|e| panic!("Failed to extract features: {e}"));

        assert_eq!(features.len(), 150);
        let schema = features.schema();
        assert_eq!(schema.fields().len(), 4);
        assert!(schema.field_with_name("species").is_err());

        // Feature columns come out in the order advertised by `feature_names`.
        let names: Vec<&str> = schema.fields().iter().map(|f| f.name().as_str()).collect();
        assert_eq!(names, dataset.feature_names());
    }

    #[test]
    fn test_iris_description() {
        let dataset = load_iris();
        assert!(dataset.description().contains("Fisher"));
        assert!(dataset.description().contains("150"));
    }

    #[test]
    fn test_iris_canonical_trait() {
        let dataset = load_iris();

        assert_eq!(dataset.feature_names().len(), 4);
        assert_eq!(dataset.target_name(), "species");
        assert!(!dataset.is_empty());
    }

    #[test]
    fn test_iris_into_inner() {
        let inner = load_iris().into_inner();
        assert_eq!(inner.len(), 150);
    }

    #[test]
    fn test_iris_clone() {
        let dataset = load_iris();
        let cloned = dataset.clone();
        assert_eq!(cloned.len(), dataset.len());
    }

    #[test]
    fn test_iris_debug() {
        let dataset = load_iris();
        let debug = format!("{dataset:?}");
        assert!(debug.contains("IrisDataset"));
    }

    #[test]
    fn test_iris_data_access() {
        let dataset = load_iris();
        assert_eq!(dataset.data().len(), 150);
    }

    #[test]
    fn test_iris_data_function() {
        let (sl, sw, pl, pw, species) = iris_data();
        assert_eq!(sl.len(), 150);
        assert_eq!(sw.len(), 150);
        assert_eq!(pl.len(), 150);
        assert_eq!(pw.len(), 150);
        assert_eq!(species.len(), 150);
    }

    #[test]
    fn test_iris_data_species_distribution() {
        let (_, _, _, _, species) = iris_data();
        let setosa_count = species.iter().filter(|&&s| s == "setosa").count();
        let versicolor_count = species.iter().filter(|&&s| s == "versicolor").count();
        let virginica_count = species.iter().filter(|&&s| s == "virginica").count();
        assert_eq!(setosa_count, 50);
        assert_eq!(versicolor_count, 50);
        assert_eq!(virginica_count, 50);
    }

    /// Pin concrete rows so a transcription error in the embedded table is caught,
    /// including the two setosa rows that differ from the UCI `iris.data` file.
    #[test]
    fn test_iris_data_known_rows() {
        let (sl, sw, pl, pw, species) = iris_data();
        // (index, sepal_length, sepal_width, petal_length, petal_width, species)
        let expected: [(usize, f64, f64, f64, f64, &str); 6] = [
            (0, 5.1, 3.5, 1.4, 0.2, "setosa"),
            (34, 4.9, 3.1, 1.5, 0.2, "setosa"),
            (37, 4.9, 3.6, 1.4, 0.1, "setosa"),
            (50, 7.0, 3.2, 4.7, 1.4, "versicolor"),
            (100, 6.3, 3.3, 6.0, 2.5, "virginica"),
            (149, 5.9, 3.0, 5.1, 1.8, "virginica"),
        ];
        for (i, e_sl, e_sw, e_pl, e_pw, e_species) in expected {
            assert!(approx_eq(sl[i], e_sl), "sepal_length[{i}] = {}", sl[i]);
            assert!(approx_eq(sw[i], e_sw), "sepal_width[{i}] = {}", sw[i]);
            assert!(approx_eq(pl[i], e_pl), "petal_length[{i}] = {}", pl[i]);
            assert!(approx_eq(pw[i], e_pw), "petal_width[{i}] = {}", pw[i]);
            assert_eq!(species[i], e_species, "species[{i}]");
        }
    }

    #[test]
    fn test_iris_sepal_length_range() {
        let (sepal_length, _, _, _, _) = iris_data();
        for &val in &sepal_length {
            assert!(
                (4.0..=8.0).contains(&val),
                "Sepal length {} out of typical range",
                val
            );
        }
    }

    #[test]
    fn test_iris_sepal_width_range() {
        let (_, sepal_width, _, _, _) = iris_data();
        for &val in &sepal_width {
            assert!(
                (2.0..=5.0).contains(&val),
                "Sepal width {} out of typical range",
                val
            );
        }
    }

    #[test]
    fn test_iris_petal_length_range() {
        let (_, _, petal_length, _, _) = iris_data();
        for &val in &petal_length {
            assert!(
                (1.0..=7.0).contains(&val),
                "Petal length {} out of typical range",
                val
            );
        }
    }

    #[test]
    fn test_iris_petal_width_range() {
        let (_, _, _, petal_width, _) = iris_data();
        for &val in &petal_width {
            assert!(
                (0.1..=2.5).contains(&val),
                "Petal width {} out of typical range",
                val
            );
        }
    }

    #[test]
    fn test_iris_feature_names_content() {
        let dataset = load_iris();
        let names = dataset.feature_names();
        assert!(names.contains(&"sepal_length"));
        assert!(names.contains(&"sepal_width"));
        assert!(names.contains(&"petal_length"));
        assert!(names.contains(&"petal_width"));
    }
}