dataset_ml/iris.rs
1//! Iris flower dataset.
2//!
3//! The classic Fisher Iris dataset for multi-class classification.
4//! It contains measurements for three Iris species: `setosa`, `versicolor`,
5//! and `virginica`.
6//!
7//! **Features (4):**
8//! - `sepal_length` - sepal length in cm
9//! - `sepal_width` - sepal width in cm
10//! - `petal_length` - petal length in cm
11//! - `petal_width` - petal width in cm
12//!
13//! **Target:** `species` - one of `setosa`, `versicolor`, or `virginica`
14//!
//! **Samples:** 150 total, with 50 samples per species.
//!
//! **Application:** Multi-class classification / species recognition
17//!
18//! **Source:** UCI Machine Learning Repository
19//! <https://doi.org/10.24432/C56C76>
20
21use dataset_core::{Dataset, DatasetError, acquire_dataset, download_to};
22use csv::ReaderBuilder;
23use ndarray::{Array1, Array2};
24use std::fs::File;
25
/// Short identifier for this dataset, passed to `acquire_dataset` and to the
/// `DatasetError` constructors used while parsing.
const IRIS_DATASET_NAME: &str = "iris";

/// File name under which the downloaded CSV is stored inside the storage directory.
const IRIS_FILENAME: &str = "iris.csv";

/// Download location of the Iris CSV (hosted as a GitHub gist).
///
/// # Citation
///
/// R. A. Fisher. "Iris," UCI Machine Learning Repository, \[Online\].
/// Available: <https://doi.org/10.24432/C56C76>
const IRIS_DATA_URL: &str = concat!(
    "https://gist.githubusercontent.com/curran/a08a1080b88344b0c8a7/raw/",
    "0e7a9b0a5d22642a06d3d5b9bcbad9890c8ee534/iris.csv"
);

/// Expected SHA-256 digest (lowercase hex) of the downloaded CSV, handed to
/// `acquire_dataset` alongside the file name.
const IRIS_SHA256: &str =
    "c52742e50315a99f956a383faedf7575552675f6409ef0f9a47076dd08479930";
42
43/// A struct representing the Iris dataset with lazy loading.
44///
45/// The dataset is not loaded until you call one of the data accessor methods.
46/// Once loaded, the data is cached for subsequent accesses.
47///
48/// # About Dataset
49///
50/// The Iris dataset is a classic dataset for classification tasks. It includes three iris species
51/// with 50 samples each as well as some properties about each flower. One flower species is
52/// linearly separable from the other two, but the other two are not linearly separable from each other.
53///
54/// Features:
55/// - sepal length in cm
56/// - sepal width in cm
57/// - petal length in cm
58/// - petal width in cm
59///
60/// Labels:
61/// - species name (in `&str`): `"setosa"`, `"versicolor"`, `"virginica"`
62///
63/// See more information at <https://archive.ics.uci.edu/dataset/53/iris>
64///
65/// # Citation
66///
67/// R. A. Fisher. "Iris," UCI Machine Learning Repository, \[Online\].
68/// Available: <https://doi.org/10.24432/C56C76>
69///
70/// # Thread Safety
71///
72/// This struct automatically implements `Send` and `Sync` (All fields implement them), making it safe to share across threads.
73/// The internal [`Dataset`] ensures thread-safe lazy initialization.
74///
75/// # Example
76/// ```no_run
77/// use dataset_ml::iris::Iris;
78///
79/// let download_dir = "./iris"; // the code will create the directory if it doesn't exist
80///
81/// let dataset = Iris::new(download_dir);
82/// let features = dataset.features().unwrap();
83/// let labels = dataset.labels().unwrap();
84///
85/// let (features, labels) = dataset.data().unwrap(); // this is also a way to get features and labels
86/// // you can use `.to_owned()` to get owned copies of the data
87/// let mut features_owned = features.to_owned();
88/// let mut labels_owned = labels.to_owned();
89///
90/// // Example: Modify feature values
91/// features_owned[[0, 0]] = 5.5;
92/// labels_owned[0] = "setosa-modified";
93///
94/// assert_eq!(features.shape(), &[150, 4]);
95/// assert_eq!(labels.len(), 150);
96/// ```
97#[derive(Debug)]
98pub struct Iris {
99 dataset: Dataset<(Array2<f64>, Array1<&'static str>)>,
100}
101
102impl Iris {
103 /// Create a new Iris instance without loading data.
104 ///
105 /// The dataset will be loaded lazily when you first call any data accessor method.
106 /// This is a lightweight operation that only stores the storage directory.
107 ///
108 /// # Parameters
109 ///
110 /// - `storage_dir` - Directory where the dataset will be stored.
111 ///
112 /// # Returns
113 ///
114 /// - `Self` - `Iris` instance ready for lazy loading.
115 pub fn new(storage_dir: &str) -> Self {
116 Iris {
117 dataset: Dataset::new(storage_dir),
118 }
119 }
120
121 /// Acquire and parse the Iris dataset.
122 fn load_data(dir: &str) -> Result<(Array2<f64>, Array1<&'static str>), DatasetError> {
123 // Prepare the dataset file
124 let file_path = acquire_dataset(
125 dir,
126 IRIS_FILENAME,
127 IRIS_DATASET_NAME,
128 Some(IRIS_SHA256),
129 |temp_path| {
130 download_to(IRIS_DATA_URL, temp_path, None)?;
131 Ok(temp_path.join(IRIS_FILENAME))
132 },
133 )?;
134
135 // Parse the file
136 let file = File::open(&file_path)?;
137 let mut rdr = ReaderBuilder::new().has_headers(true).from_reader(file);
138
139 let mut features = Vec::new();
140 let mut labels = Vec::new();
141 let mut num_features: Option<usize> = None;
142
143 for (idx, result) in rdr.records().enumerate() {
144 let record = result.map_err(|e| DatasetError::csv_read_error(IRIS_DATASET_NAME, e))?;
145 let line_num = idx + 2; // +1 for 0-indexed, +1 for header
146
147 if num_features.is_none() {
148 if record.len() < 2 {
149 return Err(DatasetError::invalid_column_count(
150 IRIS_DATASET_NAME,
151 2,
152 record.len(),
153 line_num,
154 ));
155 }
156 num_features = Some(record.len() - 1);
157 }
158
159 let n_features = num_features.unwrap();
160 if record.len() != n_features + 1 {
161 return Err(DatasetError::invalid_column_count(
162 IRIS_DATASET_NAME,
163 n_features + 1,
164 record.len(),
165 line_num,
166 ));
167 }
168
169 for i in 0..n_features {
170 features.push(record[i].parse::<f64>().map_err(|e| {
171 let field = format!("feature[{i}]");
172 DatasetError::parse_failed(
173 IRIS_DATASET_NAME,
174 &field,
175 line_num,
176 e,
177 )
178 })?);
179 }
180
181 labels.push(match &record[n_features] {
182 "setosa" => "setosa",
183 "versicolor" => "versicolor",
184 "virginica" => "virginica",
185 other => {
186 return Err(DatasetError::invalid_value(
187 IRIS_DATASET_NAME,
188 "label",
189 other,
190 line_num,
191 ));
192 }
193 });
194 }
195
196 let n_samples = labels.len();
197 if n_samples == 0 {
198 return Err(DatasetError::empty_dataset(IRIS_DATASET_NAME));
199 }
200
201 let n_features = num_features.unwrap();
202 let features_array = Array2::from_shape_vec((n_samples, n_features), features)
203 .map_err(|e| DatasetError::array_shape_error(IRIS_DATASET_NAME, "features", e))?;
204 let labels_array = Array1::from_vec(labels);
205
206 Ok((features_array, labels_array))
207 }
208
209 /// Get a reference to the feature matrix.
210 ///
211 /// This method triggers lazy loading on first call. Subsequent calls return
212 /// the cached data instantly.
213 ///
214 /// # Returns
215 ///
216 /// - `&Array2<f64>` - Reference to feature matrix with shape `(150, 4)` containing:
217 /// - sepal length in cm
218 /// - sepal width in cm
219 /// - petal length in cm
220 /// - petal width in cm
221 ///
222 /// # Errors
223 ///
224 /// Returns `DatasetError` if:
225 /// - Download fails due to network issues
226 /// - File extraction or I/O operations fail
227 /// - Data format is invalid (wrong number of columns, unparseable values, or invalid labels)
228 /// - Dataset size doesn't match expected dimensions (150 samples, 4 features)
229 pub fn features(&self) -> Result<&Array2<f64>, DatasetError> {
230 Ok(&self.dataset.load(Self::load_data)?.0)
231 }
232
233 /// Get a reference to the labels vector.
234 ///
235 /// This method triggers lazy loading on first call. Subsequent calls return
236 /// the cached data instantly.
237 ///
238 /// # Returns
239 ///
240 /// - `&Array1<&'static str>` - Reference to labels vector with shape `(150,)` containing species names (`"setosa"`, `"versicolor"`, `"virginica"`)
241 ///
242 /// # Errors
243 ///
244 /// Returns `DatasetError` if:
245 /// - Download fails due to network issues
246 /// - File extraction or I/O operations fail
247 /// - Data format is invalid (wrong number of columns, unparseable values, or invalid labels)
248 /// - Dataset size doesn't match expected dimensions (150 samples)
249 pub fn labels(&self) -> Result<&Array1<&'static str>, DatasetError> {
250 Ok(&self.dataset.load(Self::load_data)?.1)
251 }
252
253 /// Get both features and labels as references.
254 ///
255 /// This method triggers lazy loading on first call. Subsequent calls return
256 /// the cached data instantly.
257 ///
258 /// # Returns
259 ///
260 /// - `&Array2<f64>` - Reference to feature matrix with shape `(150, 4)` containing:
261 /// - sepal length in cm
262 /// - sepal width in cm
263 /// - petal length in cm
264 /// - petal width in cm
265 /// - `&Array1<&'static str>` - Reference to labels vector with shape `(150,)` containing species names (`"setosa"`, `"versicolor"`, `"virginica"`)
266 ///
267 /// # Errors
268 ///
269 /// Returns `DatasetError` if:
270 /// - Download fails due to network issues
271 /// - File extraction or I/O operations fail
272 /// - Data format is invalid (wrong number of columns, unparseable values, or invalid labels)
273 /// - Dataset size doesn't match expected dimensions (150 samples, 4 features)
274 pub fn data(&self) -> Result<(&Array2<f64>, &Array1<&'static str>), DatasetError> {
275 let data = self.dataset.load(Self::load_data)?;
276 Ok((&data.0, &data.1))
277 }
278}