dataset_core/datasets/iris.rs
1use crate::{Dataset, DatasetError, acquire_dataset, download_to};
2use csv::ReaderBuilder;
3use ndarray::{Array1, Array2};
4use std::fs::File;
5
/// Download URL for the Iris CSV file.
///
/// This points at a third-party mirror (a GitHub gist, pinned to one revision by the hash in
/// the path) rather than at the UCI repository itself. The parser in this module expects that
/// file's layout: a header row, four numeric feature columns, and a species column holding the
/// bare names `setosa`, `versicolor` and `virginica`.
///
/// # Citation
///
/// R. A. Fisher. "Iris," UCI Machine Learning Repository, \[Online\].
/// Available: <https://doi.org/10.24432/C56C76>
const IRIS_DATA_URL: &str = "https://gist.githubusercontent.com/curran/a08a1080b88344b0c8a7/raw/0e7a9b0a5d22642a06d3d5b9bcbad9890c8ee534/iris.csv";

/// Name of the Iris CSV file.
///
/// It matches the last path segment of `IRIS_DATA_URL`. The download closure expects the
/// downloaded file to appear under this name inside the temporary directory.
const IRIS_FILENAME: &str = "iris.csv";

/// Expected SHA-256 digest of the Iris CSV file, as 64 lowercase hex characters.
///
/// Passed to `acquire_dataset` as the expected checksum of the acquired file.
const IRIS_SHA256: &str = "c52742e50315a99f956a383faedf7575552675f6409ef0f9a47076dd08479930";

/// Dataset name passed to `acquire_dataset` and to the `DatasetError` constructors in this module.
const IRIS_DATASET_NAME: &str = "iris";
22
23/// A struct representing the Iris dataset with lazy loading.
24///
25/// The dataset is not loaded until you call one of the data accessor methods.
26/// Once loaded, the data is cached for subsequent accesses.
27///
28/// # About Dataset
29///
30/// The Iris dataset is a classic dataset for classification tasks. It includes three iris species
31/// with 50 samples each as well as some properties about each flower. One flower species is
32/// linearly separable from the other two, but the other two are not linearly separable from each other.
33///
34/// Features:
35/// - sepal length in cm
36/// - sepal width in cm
37/// - petal length in cm
38/// - petal width in cm
39///
40/// Labels:
41/// - species name (in `&str`): `"setosa"`, `"versicolor"`, `"virginica"`
42///
43/// See more information at <https://archive.ics.uci.edu/dataset/53/iris>
44///
45/// # Citation
46///
47/// R. A. Fisher. "Iris," UCI Machine Learning Repository, \[Online\].
48/// Available: <https://doi.org/10.24432/C56C76>
49///
50/// # Thread Safety
51///
52/// This struct automatically implements `Send` and `Sync` (All fields implement them), making it safe to share across threads.
53/// The internal [`Dataset`] ensures thread-safe lazy initialization.
54///
55/// # Example
56/// ```rust
57/// use dataset_core::datasets::iris::Iris;
58///
59/// let download_dir = "./iris"; // the code will create the directory if it doesn't exist
60///
61/// let dataset = Iris::new(download_dir);
62/// let features = dataset.features().unwrap();
63/// let labels = dataset.labels().unwrap();
64///
65/// let (features, labels) = dataset.data().unwrap(); // this is also a way to get features and labels
66/// // you can use `.to_owned()` to get owned copies of the data
67/// let mut features_owned = features.to_owned();
68/// let mut labels_owned = labels.to_owned();
69///
70/// // Example: Modify feature values
71/// features_owned[[0, 0]] = 5.5;
72/// labels_owned[0] = "setosa-modified";
73///
74/// assert_eq!(features.shape(), &[150, 4]);
75/// assert_eq!(labels.len(), 150);
76///
77/// // clean up: remove the downloaded files (dispensable)
78/// std::fs::remove_dir_all(download_dir).unwrap();
79/// ```
80#[derive(Debug)]
81pub struct Iris {
82 dataset: Dataset<(Array2<f64>, Array1<&'static str>)>,
83}
84
85impl Iris {
86 /// Create a new Iris instance without loading data.
87 ///
88 /// The dataset will be loaded lazily when you first call any data accessor method.
89 /// This is a lightweight operation that only stores the storage directory.
90 ///
91 /// # Parameters
92 ///
93 /// - `storage_dir` - Directory where the dataset will be stored.
94 ///
95 /// # Returns
96 ///
97 /// - `Self` - `Iris` instance ready for lazy loading.
98 pub fn new(storage_dir: &str) -> Self {
99 Iris {
100 dataset: Dataset::new(storage_dir),
101 }
102 }
103
104 /// Acquire and parse the Iris dataset.
105 fn load_data(dir: &str) -> Result<(Array2<f64>, Array1<&'static str>), DatasetError> {
106 // Prepare the dataset file
107 let file_path = acquire_dataset(
108 dir,
109 IRIS_FILENAME,
110 IRIS_DATASET_NAME,
111 Some(IRIS_SHA256),
112 |temp_path| {
113 download_to(IRIS_DATA_URL, temp_path, None)?;
114 Ok(temp_path.join(IRIS_FILENAME))
115 },
116 )?;
117
118 // Parse the file
119 let file = File::open(&file_path)?;
120 let mut rdr = ReaderBuilder::new().has_headers(true).from_reader(file);
121
122 let mut features = Vec::new();
123 let mut labels = Vec::new();
124 let mut num_features: Option<usize> = None;
125
126 for (idx, result) in rdr.records().enumerate() {
127 let record = result.map_err(|e| DatasetError::csv_read_error(IRIS_DATASET_NAME, e))?;
128 let line_num = idx + 2; // +1 for 0-indexed, +1 for header
129
130 if num_features.is_none() {
131 if record.len() < 2 {
132 return Err(DatasetError::invalid_column_count(
133 IRIS_DATASET_NAME,
134 2,
135 record.len(),
136 line_num,
137 &format!("{:?}", record),
138 ));
139 }
140 num_features = Some(record.len() - 1);
141 }
142
143 let n_features = num_features.unwrap();
144 if record.len() != n_features + 1 {
145 return Err(DatasetError::invalid_column_count(
146 IRIS_DATASET_NAME,
147 n_features + 1,
148 record.len(),
149 line_num,
150 &format!("{:?}", record),
151 ));
152 }
153
154 for i in 0..n_features {
155 features.push(record[i].parse::<f64>().map_err(|e| {
156 let field = format!("feature[{i}]");
157 DatasetError::parse_failed(
158 IRIS_DATASET_NAME,
159 &field,
160 line_num,
161 &format!("{:?}", record),
162 e,
163 )
164 })?);
165 }
166
167 labels.push(match &record[n_features] {
168 "setosa" => "setosa",
169 "versicolor" => "versicolor",
170 "virginica" => "virginica",
171 other => {
172 return Err(DatasetError::invalid_value(
173 IRIS_DATASET_NAME,
174 "label",
175 other,
176 line_num,
177 &format!("{:?}", record),
178 ));
179 }
180 });
181 }
182
183 let n_samples = labels.len();
184 if n_samples == 0 {
185 return Err(DatasetError::empty_dataset(IRIS_DATASET_NAME));
186 }
187
188 let n_features = num_features.unwrap();
189 let features_array = Array2::from_shape_vec((n_samples, n_features), features)
190 .map_err(|e| DatasetError::array_shape_error(IRIS_DATASET_NAME, "features", e))?;
191 let labels_array = Array1::from_vec(labels);
192
193 Ok((features_array, labels_array))
194 }
195
196 /// Get a reference to the feature matrix.
197 ///
198 /// This method triggers lazy loading on first call. Subsequent calls return
199 /// the cached data instantly.
200 ///
201 /// # Returns
202 ///
203 /// - `&Array2<f64>` - Reference to feature matrix with shape `(150, 4)` containing:
204 /// - sepal length in cm
205 /// - sepal width in cm
206 /// - petal length in cm
207 /// - petal width in cm
208 ///
209 /// # Errors
210 ///
211 /// Returns `DatasetError` if:
212 /// - Download fails due to network issues
213 /// - File extraction or I/O operations fail
214 /// - Data format is invalid (wrong number of columns, unparseable values, or invalid labels)
215 /// - Dataset size doesn't match expected dimensions (150 samples, 4 features)
216 pub fn features(&self) -> Result<&Array2<f64>, DatasetError> {
217 Ok(&self.dataset.load(Self::load_data)?.0)
218 }
219
220 /// Get a reference to the labels vector.
221 ///
222 /// This method triggers lazy loading on first call. Subsequent calls return
223 /// the cached data instantly.
224 ///
225 /// # Returns
226 ///
227 /// - `&Array1<&'static str>` - Reference to labels vector with shape `(150,)` containing species names (`"setosa"`, `"versicolor"`, `"virginica"`)
228 ///
229 /// # Errors
230 ///
231 /// Returns `DatasetError` if:
232 /// - Download fails due to network issues
233 /// - File extraction or I/O operations fail
234 /// - Data format is invalid (wrong number of columns, unparseable values, or invalid labels)
235 /// - Dataset size doesn't match expected dimensions (150 samples)
236 pub fn labels(&self) -> Result<&Array1<&'static str>, DatasetError> {
237 Ok(&self.dataset.load(Self::load_data)?.1)
238 }
239
240 /// Get both features and labels as references.
241 ///
242 /// This method triggers lazy loading on first call. Subsequent calls return
243 /// the cached data instantly.
244 ///
245 /// # Returns
246 ///
247 /// - `&Array2<f64>` - Reference to feature matrix with shape `(150, 4)` containing:
248 /// - sepal length in cm
249 /// - sepal width in cm
250 /// - petal length in cm
251 /// - petal width in cm
252 /// - `&Array1<&'static str>` - Reference to labels vector with shape `(150,)` containing species names (`"setosa"`, `"versicolor"`, `"virginica"`)
253 ///
254 /// # Errors
255 ///
256 /// Returns `DatasetError` if:
257 /// - Download fails due to network issues
258 /// - File extraction or I/O operations fail
259 /// - Data format is invalid (wrong number of columns, unparseable values, or invalid labels)
260 /// - Dataset size doesn't match expected dimensions (150 samples, 4 features)
261 pub fn data(&self) -> Result<(&Array2<f64>, &Array1<&'static str>), DatasetError> {
262 let data = self.dataset.load(Self::load_data)?;
263 Ok((&data.0, &data.1))
264 }
265}