iris_data/
lib.rs

1use serde::{Deserialize, Serialize};
2
3/// This famous (Fisher's or Anderson's) iris data set gives the measurements in centimeters of the variables sepal length and width and petal length and width, respectively, for 50 flowers from each of 3 species of iris. The species are Iris setosa, versicolor, and virginica.
4#[derive(Serialize, Deserialize, PartialEq, Debug)]
5pub struct Iris {
6    /// Sepal length in cm
7    #[serde(rename = "Sepal.Length")]
8    pub sepal_length: f64,
9    /// Sepal width in cm
10    #[serde(rename = "Sepal.Width")]
11    pub sepal_width: f64,
12    /// Petal length in cm
13    #[serde(rename = "Petal.Length")]
14    pub petal_length: f64,
15    /// Petal width in cm
16    #[serde(rename = "Petal.Width")]
17    pub petal_width: f64,
18    /// Class of iris plant
19    #[serde(rename = "Species")]
20    pub species: Species,
21}
22
23/// The species of iris. The species are Iris setosa, versicolor, and virginica.
24#[derive(Serialize, Deserialize, PartialEq, Eq, Debug)]
25#[serde(rename_all = "lowercase")]
26pub enum Species {
27    Setosa,
28    Versicolor,
29    Virginica,
30}
31
#[cfg(test)]
mod tests {
    use super::*;
    use csv::Reader;
    use std::error::Error;

    /// CSV text made of the header row followed by a single data row.
    const FIRST_RECORD_DATA: &str =
        "Sepal.Length,Sepal.Width,Petal.Length,Petal.Width,Species\n5.1,3.5,1.4,0.2,setosa";

    /// Typed counterpart of the data row in `FIRST_RECORD_DATA`
    /// (the first record in the dataset).
    fn first_record() -> Iris {
        let species = Species::Setosa;
        Iris {
            sepal_length: 5.1,
            sepal_width: 3.5,
            petal_length: 1.4,
            petal_width: 0.2,
            species,
        }
    }

    /// Reading the CSV text must yield exactly the value built by `first_record`.
    #[test]
    fn test_deserialize() -> Result<(), Box<dyn Error>> {
        let mut reader = Reader::from_reader(FIRST_RECORD_DATA.as_bytes());

        // The fixture contains one data row, so the first item must exist; a parse
        // failure of that row is propagated as the test's error.
        let parsed = reader.deserialize::<Iris>().next().unwrap()?;

        assert_eq!(parsed, first_record());

        Ok(())
    }

    /// Writing `first_record` must produce the expected header row and data row.
    ///
    /// The data row is checked against a literal. Original author's note: writing is
    /// not exactly the same due to differences in decimal points.
    #[test]
    fn test_serialize() -> Result<(), Box<dyn Error>> {
        let mut csv_writer = csv::Writer::from_writer(Vec::new());
        csv_writer.serialize(first_record())?;

        // The writer emits a header row before the record.
        let bytes = csv_writer.into_inner()?;
        let written = String::from_utf8(bytes)?;

        // Drop the trailing newline the writer adds, then separate header from record.
        let (header_line, record_line) = written.trim_end().split_once('\n').unwrap();

        // The fixture itself has no trailing newline; only its header half is needed.
        let (expected_header, _) = FIRST_RECORD_DATA.split_once('\n').unwrap();

        assert_eq!(header_line, expected_header);
        assert_eq!(record_line, "5.1,3.5,1.4,0.2,setosa");

        Ok(())
    }
}