use rand::seq::SliceRandom;

/// A tuple pairing a vector of input values with a vector of their expected output values.
type Row = (Vec<f64>, Vec<f64>);

/// A collection of input vectors matched with their expected outputs.
///
/// You can construct a `Dataset` manually like so:
///
/// ```rust
/// // Note that the inputs and target outputs are both vectors, even though the latter has just
/// // one element
/// let data = vec![
///     (vec![0.0, 0.0], vec![0.0]),
///     (vec![0.0, 1.0], vec![1.0]),
///     (vec![1.0, 0.0], vec![1.0]),
///     (vec![1.0, 1.0], vec![0.0]),
/// ];
///
/// let dataset = scholar::Dataset::from(data);
/// ```
#[derive(Debug)]
pub struct Dataset {
    data: Vec<Row>,
}

impl Dataset {
    /// Parses a `Dataset` from a CSV file.
    ///
    /// # Arguments
    ///
    /// * `file_path` - The path to the CSV file
    /// * `includes_headers` - Whether the CSV file has a header row
    /// * `num_inputs` - The number of columns in the CSV that are designated as inputs (to a
    ///   machine learning model)
    ///
    /// # Examples
    /// ```rust
    /// // Parses the first four columns of 'iris.csv' as inputs, and the remaining columns as
    /// // target outputs
    /// let dataset = scholar::Dataset::from_csv("iris.csv", false, 4);
    /// ```
    pub fn from_csv(
        file_path: impl AsRef<std::path::Path>,
        includes_headers: bool,
        num_inputs: usize,
    ) -> Result<Self, ParseCsvError> {
        use std::str::FromStr;

        let file = std::fs::File::open(file_path)?;
        let mut reader = csv::ReaderBuilder::new()
            .has_headers(includes_headers)
            .from_reader(file);

        let data: Result<Vec<Row>, ParseCsvError> = reader
            .records()
            .map(|row| {
                // Propagates a possible CSV parsing error
                let row = row?;
                let row = row
                    .iter()
                    .map(|val| {
                        let val = val.trim();
                        f64::from_str(val)
                    })
                    .collect::<Result<Vec<_>, _>>()?;

                let mut inputs = row;
                // Splits the row into input and output vectors
                let outputs = inputs.split_off(num_inputs);
                Ok((inputs, outputs))
            })
            .collect();
        Ok(Dataset::from(data?))
    }

    /// Splits the dataset into two, with the size of each determined by the given `train_portion`.
    /// This is useful for separating it into training and testing segments.
    ///
    /// # Examples
    ///
    /// ```rust,no_run
    /// let dataset = scholar::Dataset::from_csv("iris.csv", false, 4)?;
    ///
    /// // Randomly allocates 75% of the original dataset to `training_data`, and the rest
    /// // to `testing_data`
    /// let (training_data, testing_data) = dataset.split(0.75);
    /// # Ok::<(), scholar::ParseCsvError>(())
    /// ```
    ///
    /// # Panics
    ///
    /// This method panics if the given `train_portion` isn't between 0 and 1.
    pub fn split(mut self, train_portion: f64) -> (Self, Self) {
        if train_portion < 0.0 || train_portion > 1.0 {
            panic!(
                "training portion must be between 0 and 1 (found {})",
                train_portion
            );
        }

        // Shuffles the dataset to ensure a random split
        self.shuffle();

        let index = self.data.len() as f64 * train_portion;
        let test_split = self.data.split_off(index.round() as usize);

        (self, Self::from(test_split))
    }

    /// Shuffles the rows in the dataset.
    pub(crate) fn shuffle(&mut self) {
        self.data.shuffle(&mut rand::thread_rng());
    }

    /// Returns the number of rows in the dataset.
    ///
    /// # Examples
    ///
    /// ```rust
    /// // Data for the XOR problem
    /// let data = vec![
    ///     (vec![0.0, 0.0], vec![0.0]),
    ///     (vec![0.0, 1.0], vec![1.0]),
    ///     (vec![1.0, 0.0], vec![1.0]),
    ///     (vec![1.0, 1.0], vec![0.0]),
    /// ];
    ///
    /// let dataset = scholar::Dataset::from(data);
    /// assert_eq!(dataset.rows(), 4);
    /// ```
    pub fn rows(&self) -> usize {
        self.data.len()
    }

    /// Returns a reference to the row at the given index, or `None` if the index is out of bounds.
    fn get(&self, index: usize) -> Option<&Row> {
        self.data.get(index)
    }
}

/// An enumeration of the possible errors that can occur when parsing a `Dataset` from a CSV file.
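///
/// # Examples
///
/// A minimal sketch of handling each variant; the `iris.csv` path is illustrative, and the
/// example assumes this type is reachable as `scholar::ParseCsvError`, like `Dataset` in the
/// examples above.
///
/// ```rust,no_run
/// match scholar::Dataset::from_csv("iris.csv", false, 4) {
///     Ok(dataset) => println!("parsed {} rows", dataset.rows()),
///     Err(scholar::ParseCsvError::Read(e)) => eprintln!("failed to read file: {}", e),
///     Err(scholar::ParseCsvError::Parse(e)) => eprintln!("failed to parse CSV: {}", e),
///     Err(scholar::ParseCsvError::Convert(e)) => eprintln!("failed to convert value: {}", e),
/// }
/// ```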
#[derive(thiserror::Error, Debug)]
pub enum ParseCsvError {
    /// When reading from a file fails.
    #[error("failed to read file")]
    Read(#[from] std::io::Error),
    /// When parsing a CSV fails.
    #[error("failed to parse CSV")]
    Parse(#[from] csv::Error),
    /// When converting CSV values to floats fails.
    #[error("failed to convert value into float")]
    Convert(#[from] std::num::ParseFloatError),
}

impl From<Vec<Row>> for Dataset {
    fn from(data: Vec<Row>) -> Self {
        Self { data }
    }
}

impl<'a> IntoIterator for &'a Dataset {
    type Item = &'a Row;
    type IntoIter = DatasetIterator<'a>;

    fn into_iter(self) -> Self::IntoIter {
        DatasetIterator {
            dataset: self,
            index: 0,
        }
    }
}

/// An iterator over the rows of a `Dataset`.
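///
/// # Examples
///
/// A small sketch of iterating over a dataset by reference; the two-row data below is
/// illustrative only.
///
/// ```rust
/// let data = vec![
///     (vec![0.0, 0.0], vec![0.0]),
///     (vec![1.0, 1.0], vec![0.0]),
/// ];
/// let dataset = scholar::Dataset::from(data);
///
/// for (inputs, outputs) in &dataset {
///     assert_eq!(inputs.len(), 2);
///     assert_eq!(outputs.len(), 1);
/// }
/// ```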
pub struct DatasetIterator<'a> {
    dataset: &'a Dataset,
    index: usize,
}

impl<'a> Iterator for DatasetIterator<'a> {
    type Item = &'a Row;
    fn next(&mut self) -> Option<Self::Item> {
        let result = self.dataset.get(self.index);
        self.index += 1;
        result
    }
}
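
// A minimal test sketch exercising `Dataset::split` and `Dataset::rows`; the XOR-style data
// and the 0.75 portion below are illustrative only.
#[cfg(test)]
mod tests {
    use super::*;

    fn xor_data() -> Vec<Row> {
        vec![
            (vec![0.0, 0.0], vec![0.0]),
            (vec![0.0, 1.0], vec![1.0]),
            (vec![1.0, 0.0], vec![1.0]),
            (vec![1.0, 1.0], vec![0.0]),
        ]
    }

    #[test]
    fn split_preserves_total_row_count() {
        let dataset = Dataset::from(xor_data());
        let (training, testing) = dataset.split(0.75);

        // 4 rows * 0.75 rounds to 3 training rows, leaving 1 testing row
        assert_eq!(training.rows(), 3);
        assert_eq!(testing.rows(), 1);
    }

    #[test]
    #[should_panic]
    fn split_rejects_portion_above_one() {
        Dataset::from(xor_data()).split(1.5);
    }
}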