rusty_ai/data/
dataset.rs

1use nalgebra::{DMatrix, DVector};
2use num_traits::{Float, FromPrimitive, Num, ToPrimitive};
3use rand::seq::SliceRandom;
4use rand::Rng;
5use rand::{rngs::StdRng, SeedableRng};
6use std::cmp::PartialOrd;
7use std::error::Error;
8use std::fmt::{self, Display};
9use std::fmt::{Debug, Formatter};
10use std::hash::Hash;
11use std::ops::{AddAssign, DivAssign, MulAssign, SubAssign};
12
13pub trait DataValue:
14    Debug
15    + Clone
16    + Copy
17    + Num
18    + FromPrimitive
19    + ToPrimitive
20    + AddAssign
21    + SubAssign
22    + MulAssign
23    + DivAssign
24    + Send
25    + Sync
26    + Display
27    + 'static
28{
29}
30
31impl<T> DataValue for T where
32    T: Debug
33        + Clone
34        + Copy
35        + Num
36        + FromPrimitive
37        + ToPrimitive
38        + AddAssign
39        + SubAssign
40        + MulAssign
41        + DivAssign
42        + Send
43        + Sync
44        + Display
45        + 'static
46{
47}
48
49pub trait Number: DataValue + PartialOrd {}
50impl<T> Number for T where T: DataValue + PartialOrd {}
51
52pub trait WholeNumber: Number + Eq + Hash {}
53impl<T> WholeNumber for T where T: Number + Eq + Hash {}
54
55pub trait RealNumber: Number + Float {}
56impl<T> RealNumber for T where T: Number + Float {}
57
58pub trait TargetValue: DataValue {}
59impl<T> TargetValue for T where T: DataValue {}
60
61pub struct Dataset<XT: Number, YT: TargetValue> {
62    pub x: DMatrix<XT>,
63    pub y: DVector<YT>,
64}
65
66impl<XT: Number, YT: TargetValue> Debug for Dataset<XT, YT> {
67    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
68        write!(f, "Dataset {{\n    x: [\n")?;
69
70        for i in 0..self.x.nrows() {
71            write!(f, "        [")?;
72            for j in 0..self.x.ncols() {
73                write!(f, "{:?}, ", self.x[(i, j)])?;
74            }
75            writeln!(f, "],")?;
76        }
77
78        write!(f, "    ],\n    y: [")?;
79        for i in 0..self.y.len() {
80            write!(f, "{:?}, ", self.y[i])?;
81        }
82        write!(f, "]\n}}")
83    }
84}
85
86/// Implementation of a generic dataset structure.
87///
88/// This structure represents a dataset consisting of input features (`x`) and target values (`y`).
89/// It provides various methods for manipulating and analyzing the dataset.
90///
91/// # Type Parameters
92///
93/// - `XT`: The type of the input features.
94/// - `YT`: The type of the target values.
95///
96/// # Examples
97///
98/// ```
99/// use nalgebra::{DMatrix, DVector};
100/// use rusty_ai::data::dataset::Dataset;
101/// use rand::prelude::*;
102///
103/// // Define a dataset with input features of type f64 and target values of type u32
104/// let x = DMatrix::from_row_slice(3, 2, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
105/// let y = DVector::from_vec(vec![0, 1, 0]);
106/// let dataset = Dataset::new(x, y);
107///
108/// // Split the dataset into training and testing sets
109/// let (mut train_set, test_set) = dataset.train_test_split(0.8, Some(42)).unwrap();
110///
111/// // Standardize the input features of the dataset
112/// train_set.standardize();
113///
114/// // Split the dataset based on a threshold value
115/// let (left_set, right_set) = dataset.split_on_threshold(0, 3.5);
116///
117/// // Sample a subset of the dataset
118/// let sample_set = dataset.samples(2, Some(123));
119/// ```
120
121impl<XT: Number, YT: TargetValue> Dataset<XT, YT> {
122    /// Creates a new dataset with the given input features and target values.
123    ///
124    /// # Arguments
125    ///
126    /// * `x` - The input features of the dataset.
127    /// * `y` - The target values of the dataset.
128    ///
129    /// # Returns
130    ///
131    /// A new `Dataset` instance.
132    pub fn new(x: DMatrix<XT>, y: DVector<YT>) -> Self {
133        Self { x, y }
134    }
135
136    /// Splits the dataset into its constituent parts.
137    ///
138    /// # Returns
139    ///
140    /// A tuple containing references to the input features and target values of the dataset.
141    pub fn into_parts(&self) -> (&DMatrix<XT>, &DVector<YT>) {
142        (&self.x, &self.y)
143    }
144
145    /// Checks if the dataset is not empty.
146    ///
147    /// # Returns
148    ///
149    /// `true` if the dataset is not empty, `false` otherwise.
150    pub fn is_not_empty(&self) -> bool {
151        !(self.x.is_empty() || self.y.is_empty())
152    }
153
154    /// Returns the number of rows in the dataset.
155    ///
156    /// # Returns
157    ///
158    /// The number of rows in the dataset.
159    pub fn nrows(&self) -> usize {
160        self.x.nrows()
161    }
162
163    /// Standardizes the input features of the dataset.
164    ///
165    /// This method calculates the mean and standard deviation of each input feature and
166    /// standardizes the values by subtracting the mean and dividing by the standard deviation.
167    ///
168    /// # Requirements
169    ///
170    /// The input features (`XT`) must implement the `RealNumber` trait.
171    pub fn standardize(&mut self)
172    where
173        XT: RealNumber,
174    {
175        let (nrows, _) = self.x.shape();
176
177        let means = self
178            .x
179            .column_iter()
180            .map(|col| col.sum() / XT::from_usize(col.len()).unwrap())
181            .collect::<Vec<_>>();
182        let std_devs = self
183            .x
184            .column_iter()
185            .zip(means.iter())
186            .map(|(col, mean)| {
187                let mut sum = XT::from_f64(0.0).unwrap();
188                for val in col.iter() {
189                    sum += (*val - *mean) * (*val - *mean);
190                }
191                (sum / XT::from_usize(nrows).unwrap()).sqrt()
192            })
193            .collect::<Vec<_>>();
194        let standardized_cols = self
195            .x
196            .column_iter()
197            .zip(means.iter())
198            .zip(std_devs.iter())
199            .map(|((col, &mean), &std_dev)| col.map(|val| (val - mean) / std_dev))
200            .collect::<Vec<_>>();
201        self.x = DMatrix::from_columns(&standardized_cols);
202    }
203
204    /// Splits the dataset into training and testing sets.
205    ///
206    /// # Arguments
207    ///
208    /// * `train_size` - The proportion of the dataset to use for training. Should be between 0.0 and 1.0.
209    /// * `seed` - An optional seed value for the random number generator.
210    ///
211    /// # Returns
212    ///
213    /// A result containing the training and testing datasets, or an error if the train size is invalid.
214    pub fn train_test_split(
215        &self,
216        train_size: f64,
217        seed: Option<u64>,
218    ) -> Result<(Self, Self), Box<dyn Error>> {
219        if !(0.0..=1.0).contains(&train_size) {
220            return Err("Train size should be between 0.0 and 1.0".into());
221        }
222        let mut rng = match seed {
223            Some(seed) => StdRng::seed_from_u64(seed),
224            None => StdRng::from_entropy(),
225        };
226
227        let mut indices = (0..self.x.nrows()).collect::<Vec<_>>();
228        indices.shuffle(&mut rng);
229        let train_size = (self.x.nrows() as f64 * train_size).floor() as usize;
230        let train_indices = &indices[..train_size];
231        let test_indices = &indices[train_size..];
232
233        let train_x = train_indices
234            .iter()
235            .map(|&index| self.x.row(index))
236            .collect::<Vec<_>>();
237        let train_y = train_indices
238            .iter()
239            .map(|&index| self.y[index])
240            .collect::<Vec<_>>();
241
242        let test_x = test_indices
243            .iter()
244            .map(|&index| self.x.row(index))
245            .collect::<Vec<_>>();
246        let test_y = test_indices
247            .iter()
248            .map(|&index| self.y[index])
249            .collect::<Vec<_>>();
250
251        let train_dataset = Self::new(DMatrix::from_rows(&train_x), DVector::from_vec(train_y));
252        let test_dataset = Self::new(DMatrix::from_rows(&test_x), DVector::from_vec(test_y));
253
254        Ok((train_dataset, test_dataset))
255    }
256
257    /// Splits the dataset based on a threshold value.
258    ///
259    /// This method partitions the dataset into two subsets based on the specified feature index and threshold value.
260    /// The left subset contains rows where the feature value is less than or equal to the threshold,
261    /// while the right subset contains rows where the feature value is greater than the threshold.
262    ///
263    /// # Arguments
264    ///
265    /// * `feature_index` - The index of the feature to split on.
266    /// * `threshold` - The threshold value for the split.
267    ///
268    /// # Returns
269    ///
270    /// A tuple containing the left and right subsets of the dataset.
271    pub fn split_on_threshold(&self, feature_index: usize, threshold: XT) -> (Self, Self) {
272        let (left_indices, right_indices): (Vec<_>, Vec<_>) = self
273            .x
274            .row_iter()
275            .enumerate()
276            .partition(|(_, row)| row[feature_index] <= threshold);
277
278        let left_x: Vec<_> = left_indices
279            .iter()
280            .map(|&(index, _)| self.x.row(index))
281            .collect();
282        let left_y: Vec<_> = left_indices
283            .iter()
284            .map(|&(index, _)| self.y.row(index))
285            .collect();
286
287        let right_x: Vec<_> = right_indices
288            .iter()
289            .map(|&(index, _)| self.x.row(index))
290            .collect();
291        let right_y: Vec<_> = right_indices
292            .iter()
293            .map(|&(index, _)| self.y.row(index))
294            .collect();
295
296        let left_dataset = if left_x.is_empty() {
297            Self::new(DMatrix::zeros(0, self.x.ncols()), DVector::zeros(0))
298        } else {
299            Self::new(DMatrix::from_rows(&left_x), DVector::from_rows(&left_y))
300        };
301
302        let right_dataset = if right_x.is_empty() {
303            Self::new(DMatrix::zeros(0, self.x.ncols()), DVector::zeros(0))
304        } else {
305            Self::new(DMatrix::from_rows(&right_x), DVector::from_rows(&right_y))
306        };
307
308        (left_dataset, right_dataset)
309    }
310
311    /// Samples a subset of the dataset.
312    ///
313    /// This method randomly selects a specified number of rows from the dataset to create a new subset.
314    ///
315    /// # Arguments
316    ///
317    /// * `sample_size` - The number of rows to sample.
318    /// * `seed` - An optional seed value for the random number generator.
319    ///
320    /// # Returns
321    ///
322    /// A new dataset containing the sampled subset.
323    pub fn samples(&self, sample_size: usize, seed: Option<u64>) -> Self {
324        let mut rng = match seed {
325            Some(seed) => StdRng::seed_from_u64(seed),
326            None => StdRng::from_entropy(),
327        };
328
329        let nrows = self.x.nrows();
330        let sample_indices = (0..sample_size)
331            .map(|_| rng.gen_range(0..nrows))
332            .collect::<Vec<_>>();
333
334        let sample_x = sample_indices
335            .iter()
336            .map(|&index| self.x.row(index))
337            .collect::<Vec<_>>();
338        let sample_y = sample_indices
339            .iter()
340            .map(|&index| self.y[index])
341            .collect::<Vec<_>>();
342
343        Self::new(DMatrix::from_rows(&sample_x), DVector::from_vec(sample_y))
344    }
345}
346
#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;

    use super::*;

    /// Builds the 4x2 integer fixture shared by most tests.
    ///
    /// Row `k` (0-based) is `[2k + 1, 2k + 2]` and its target is `9 + k`.
    fn fixture() -> Dataset<i32, i32> {
        let x = DMatrix::from_row_slice(4, 2, &[1, 2, 3, 4, 5, 6, 7, 8]);
        let y = DVector::from_vec(vec![9, 10, 11, 12]);
        Dataset::new(x, y)
    }

    /// Asserts that every row of `subset` is a row of [`fixture`] that is still
    /// paired with its original target value.
    fn assert_rows_come_from_fixture(subset: &Dataset<i32, i32>) {
        assert_eq!(subset.x.nrows(), subset.y.len());
        assert_eq!(subset.x.ncols(), 2);
        for row in 0..subset.x.nrows() {
            let first = subset.x[(row, 0)];
            assert!(first % 2 == 1 && (1..=7).contains(&first));
            assert_eq!(subset.x[(row, 1)], first + 1);
            assert_eq!(subset.y[row], 9 + (first - 1) / 2);
        }
    }

    #[test]
    fn test_dataset_new() {
        let x = DMatrix::from_row_slice(2, 2, &[1, 2, 3, 4]);
        let y = DVector::from_vec(vec![5, 6]);
        let dataset = Dataset::new(x.clone(), y.clone());
        assert_eq!(dataset.x, x);
        assert_eq!(dataset.y, y);
        assert_eq!(dataset.nrows(), 2);
    }

    #[test]
    fn test_dataset_into_parts() {
        let x = DMatrix::from_row_slice(2, 2, &[1, 2, 3, 4]);
        let y = DVector::from_vec(vec![5, 6]);
        let dataset = Dataset::new(x.clone(), y.clone());
        let (x_parts, y_parts) = dataset.into_parts();
        assert_eq!(x_parts, &x);
        assert_eq!(y_parts, &y);
    }

    #[test]
    fn test_dataset_formatting() {
        // Create a simple dataset
        let x = DMatrix::from_row_slice(2, 2, &[1, 2, 3, 4]);
        let y = DVector::from_vec(vec![5, 6]);
        let dataset = Dataset::new(x, y);

        // Get the string representation of the dataset
        let dataset_str = format!("{:?}", dataset);

        // Define the expected string
        let expected_str = "\
Dataset {
    x: [
        [1, 2, ],
        [3, 4, ],
    ],
    y: [5, 6, ]
}";

        // Compare the generated string with the expected string
        assert_eq!(dataset_str, expected_str);
    }

    #[test]
    fn test_dataset_is_not_empty() {
        let x = DMatrix::from_row_slice(2, 2, &[1, 2, 3, 4]);
        let y = DVector::from_vec(vec![5, 6]);
        let dataset = Dataset::new(x, y);
        assert!(dataset.is_not_empty());

        let empty_x = DMatrix::<f64>::from_row_slice(0, 2, &[]);
        let empty_y = DVector::<f64>::from_vec(vec![]);
        let empty_dataset = Dataset::new(empty_x, empty_y);
        assert!(!empty_dataset.is_not_empty());
    }

    #[test]
    fn test_dataset_standardize() {
        let x = DMatrix::from_row_slice(3, 2, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let y = DVector::from_vec(vec![7.0, 8.0, 9.0]);
        let mut dataset = Dataset::new(x, y.clone());
        dataset.standardize();

        let expected_x = DMatrix::from_row_slice(
            3,
            2,
            &[
                -1.224744871391589,
                -1.224744871391589,
                0.0,
                0.0,
                1.224744871391589,
                1.224744871391589,
            ],
        );
        assert_relative_eq!(dataset.x, expected_x, epsilon = 1e-6);
        // Standardization only touches the features.
        assert_eq!(dataset.y, y);
    }

    #[test]
    fn test_dataset_train_test_split() {
        let dataset = fixture();

        let (train_dataset, test_dataset) = dataset.train_test_split(0.75, None).unwrap();
        assert_eq!(train_dataset.x.nrows(), 3);
        assert_eq!(test_dataset.x.nrows(), 1);
        assert_rows_come_from_fixture(&train_dataset);
        assert_rows_come_from_fixture(&test_dataset);

        // Every original row must land in exactly one of the two parts.
        let mut targets: Vec<i32> = train_dataset
            .y
            .iter()
            .chain(test_dataset.y.iter())
            .copied()
            .collect();
        targets.sort_unstable();
        assert_eq!(targets, vec![9, 10, 11, 12]);
    }

    #[test]
    fn test_dataset_train_test_split_with_seed_is_deterministic() {
        let dataset = fixture();

        let (train_a, test_a) = dataset.train_test_split(0.5, Some(42)).unwrap();
        let (train_b, test_b) = dataset.train_test_split(0.5, Some(42)).unwrap();
        assert_eq!(train_a.x, train_b.x);
        assert_eq!(train_a.y, train_b.y);
        assert_eq!(test_a.x, test_b.x);
        assert_eq!(test_a.y, test_b.y);
    }

    #[test]
    fn test_dataset_train_test_split_invalid_size() {
        let dataset = fixture();

        assert!(dataset.train_test_split(-0.1, None).is_err());
        assert!(dataset.train_test_split(1.1, None).is_err());
        assert!(dataset.train_test_split(f64::NAN, None).is_err());
    }

    #[test]
    fn test_dataset_split_on_threshold() {
        let dataset = fixture();

        let (left_dataset, right_dataset) = dataset.split_on_threshold(0, 4);
        assert_eq!(left_dataset.x, DMatrix::from_row_slice(2, 2, &[1, 2, 3, 4]));
        assert_eq!(left_dataset.y, DVector::from_vec(vec![9, 10]));
        assert_eq!(right_dataset.x, DMatrix::from_row_slice(2, 2, &[5, 6, 7, 8]));
        assert_eq!(right_dataset.y, DVector::from_vec(vec![11, 12]));
    }

    #[test]
    fn test_dataset_split_on_threshold_left_empty() {
        let dataset = fixture();

        let (left_dataset, right_dataset) = dataset.split_on_threshold(0, -1);
        assert_eq!(left_dataset.x.nrows(), 0);
        assert_eq!(left_dataset.x.ncols(), 2);
        assert_eq!(left_dataset.y.len(), 0);
        assert_eq!(right_dataset.x, dataset.x);
        assert_eq!(right_dataset.y, dataset.y);
    }

    #[test]
    fn test_dataset_split_on_threshold_right_empty() {
        let dataset = fixture();

        let (left_dataset, right_dataset) = dataset.split_on_threshold(0, 9);
        assert_eq!(left_dataset.x, dataset.x);
        assert_eq!(left_dataset.y, dataset.y);
        assert_eq!(right_dataset.x.nrows(), 0);
        assert_eq!(right_dataset.x.ncols(), 2);
        assert_eq!(right_dataset.y.len(), 0);
    }

    #[test]
    fn test_dataset_samples() {
        let dataset = fixture();

        let sampled_dataset = dataset.samples(2, None);
        assert_eq!(sampled_dataset.x.nrows(), 2);
        assert_rows_come_from_fixture(&sampled_dataset);
    }

    #[test]
    fn test_dataset_samples_with_seed() {
        let dataset = fixture();

        let sampled_dataset = dataset.samples(2, Some(1000));
        assert_eq!(sampled_dataset.x.nrows(), 2);
        assert_rows_come_from_fixture(&sampled_dataset);

        // The same seed must reproduce the same sample.
        let repeated = dataset.samples(2, Some(1000));
        assert_eq!(sampled_dataset.x, repeated.x);
        assert_eq!(sampled_dataset.y, repeated.y);
    }

    #[test]
    fn test_dataset_samples_more_than_rows() {
        let dataset = fixture();

        // Rows are drawn independently, so asking for more rows than exist is allowed.
        let sampled_dataset = dataset.samples(6, Some(7));
        assert_eq!(sampled_dataset.x.nrows(), 6);
        assert_rows_come_from_fixture(&sampled_dataset);
    }
}