randomforest/
table.rs

//! Table data which contains feature columns and a target column.
2use ordered_float::OrderedFloat;
3use rand::seq::SliceRandom;
4use rand::Rng;
5use std::ops::Range;
6use thiserror::Error;
7
/// Kind of a feature column; it decides how a split value is interpreted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum ColumnType {
    /// Numerical column: values are ordered and split by a threshold.
    Numerical = 0,

    /// Categorical column: values are category codes and split by equality.
    Categorical = 1,
}

impl ColumnType {
    /// Tells whether the value `x` goes to the left side of a split at `split_value`.
    ///
    /// * Numerical: `x <= split_value` goes left.
    /// * Categorical: only values within `EPSILON` of `split_value` go left.
    ///
    /// A NaN `x` never goes left for either kind.
    pub(crate) fn is_left(self, x: f64, split_value: f64) -> bool {
        match self {
            ColumnType::Categorical => {
                let distance = (split_value - x).abs();
                distance < std::f64::EPSILON
            }
            ColumnType::Numerical => x <= split_value,
        }
    }
}
27
28/// `Table` builder.
29#[derive(Debug)]
30pub struct TableBuilder {
31    column_types: Vec<ColumnType>,
32    columns: Vec<Vec<f64>>,
33}
34
35impl TableBuilder {
36    /// Makes a new `TableBuilder` instance.
37    pub fn new() -> Self {
38        Self {
39            column_types: Vec::new(),
40            columns: Vec::new(),
41        }
42    }
43
44    /// Sets the types of feature columns.
45    ///
46    /// In the default, the feature columns are regarded as numerical.
47    pub fn set_feature_column_types(&mut self, types: &[ColumnType]) -> Result<(), TableError> {
48        if self.columns.is_empty() {
49            self.columns = vec![Vec::new(); types.len() + 1];
50        }
51
52        if self.columns.len() != types.len() + 1 {
53            return Err(TableError::ColumnSizeMismatch);
54        }
55
56        self.column_types = types.to_owned();
57        Ok(())
58    }
59
60    /// Adds a row to the table.
61    pub fn add_row(&mut self, features: &[f64], target: f64) -> Result<(), TableError> {
62        if self.columns.is_empty() {
63            self.columns = vec![Vec::new(); features.len() + 1];
64        }
65
66        if self.columns.len() != features.len() + 1 {
67            return Err(TableError::ColumnSizeMismatch);
68        }
69
70        if !target.is_finite() {
71            return Err(TableError::NonFiniteTarget);
72        }
73
74        if self.column_types.is_empty() {
75            self.column_types = (0..features.len()).map(|_| ColumnType::Numerical).collect();
76        }
77
78        for (column, value) in self
79            .columns
80            .iter_mut()
81            .zip(features.iter().copied().chain(std::iter::once(target)))
82        {
83            column.push(value);
84        }
85
86        Ok(())
87    }
88
89    /// Builds a `Table` instance.
90    pub fn build(&self) -> Result<Table, TableError> {
91        if self.columns.is_empty() || self.columns[0].is_empty() {
92            return Err(TableError::EmptyTable);
93        }
94
95        let rows_len = self.columns[0].len();
96        Ok(Table {
97            row_index: (0..rows_len).collect(),
98            row_range: Range {
99                start: 0,
100                end: rows_len,
101            },
102            column_types: &self.column_types,
103            columns: &self.columns,
104        })
105    }
106}
107
108impl Default for TableBuilder {
109    fn default() -> Self {
110        Self::new()
111    }
112}
113
114/// A table.
115#[derive(Debug, Clone)]
116pub struct Table<'a> {
117    row_index: Vec<usize>,
118    row_range: Range<usize>,
119    column_types: &'a [ColumnType],
120    columns: &'a [Vec<f64>],
121}
122
123impl<'a> Table<'a> {
124    /// Returns an iterator over all rows of the table.
125    ///
126    /// The last element of each row is the target value.
127    pub fn rows<'b>(&'b self) -> impl 'b + Iterator<Item = Vec<f64>> + Clone {
128        self.row_indices().map(move |i| {
129            (0..self.columns.len())
130                .map(|j| self.columns[j][i])
131                .collect()
132        })
133    }
134
135    /// Removes rows which don't match the given condition from the table.
136    ///
137    /// Note that after calling this method the order of rows isn't preserved.
138    pub fn filter<F>(&mut self, f: F) -> usize
139    where
140        F: Fn(&[f64]) -> bool,
141    {
142        let mut n = 0;
143        let mut i = self.row_range.start;
144        while i < self.row_range.end {
145            let row_i = self.row_index[i];
146            let row = (0..self.columns.len())
147                .map(|j| self.columns[j][row_i])
148                .collect::<Vec<_>>();
149            if f(&row) {
150                i += 1;
151            } else {
152                self.row_index.swap(i, self.row_range.end - 1);
153                self.row_range.end -= 1;
154                n += 1;
155            }
156        }
157        n
158    }
159
160    /// Splits the table into train and test datasets.
161    pub fn train_test_split<R: Rng + ?Sized>(
162        mut self,
163        rng: &mut R,
164        test_rate: f64,
165    ) -> (Self, Self) {
166        (&mut self.row_index[self.row_range.start..self.row_range.end]).shuffle(rng);
167        let test_num = (self.rows_len() as f64 * test_rate).round() as usize;
168
169        let mut train = self.clone();
170        let mut test = self;
171        test.row_range.end = test.row_range.start + test_num;
172        train.row_range.start = test.row_range.end;
173
174        (train, test)
175    }
176
177    pub(crate) fn target<'b>(&'b self) -> impl 'b + Iterator<Item = f64> + Clone {
178        self.column(self.columns.len() - 1)
179    }
180
181    pub(crate) fn column<'b>(
182        &'b self,
183        column_index: usize,
184    ) -> impl 'b + Iterator<Item = f64> + Clone {
185        self.row_indices()
186            .map(move |i| self.columns[column_index][i])
187    }
188
189    pub(crate) fn features_len(&self) -> usize {
190        self.columns.len() - 1
191    }
192
193    pub(crate) fn rows_len(&self) -> usize {
194        self.row_range.end - self.row_range.start
195    }
196
197    pub(crate) fn column_types(&self) -> &'a [ColumnType] {
198        self.column_types
199    }
200
201    fn row_indices<'b>(&'b self) -> impl 'b + Iterator<Item = usize> + Clone {
202        self.row_index[self.row_range.start..self.row_range.end]
203            .iter()
204            .copied()
205    }
206
207    pub(crate) fn sort_rows_by_column(&mut self, column: usize) {
208        let columns = &self.columns;
209        (&mut self.row_index[self.row_range.start..self.row_range.end])
210            .sort_by_key(|&x| OrderedFloat(columns[column][x]))
211    }
212
213    pub(crate) fn sort_rows_by_categorical_column(&mut self, column: usize, value: f64) {
214        let columns = &self.columns;
215        (&mut self.row_index[self.row_range.start..self.row_range.end]).sort_by_key(|&x| {
216            if (columns[column][x] - value).abs() < std::f64::EPSILON {
217                0
218            } else {
219                1
220            }
221        })
222    }
223
224    pub(crate) fn bootstrap_sample<R: Rng + ?Sized>(
225        &self,
226        rng: &mut R,
227        max_samples: usize,
228    ) -> Self {
229        let samples = std::cmp::min(max_samples, self.rows_len());
230        let row_index = (0..samples)
231            .map(|_| self.row_index[rng.gen_range(self.row_range.start, self.row_range.end)])
232            .collect::<Vec<_>>();
233        let row_range = Range {
234            start: 0,
235            end: samples,
236        };
237        Self {
238            row_index,
239            row_range,
240            column_types: self.column_types,
241            columns: self.columns,
242        }
243    }
244
245    pub(crate) fn split_points<'b>(
246        &'b self,
247        column_index: usize,
248    ) -> impl 'b + Iterator<Item = (Range<usize>, f64)> {
249        // Assumption: `self.columns[column]` has been sorted.
250        let column = &self.columns[column_index];
251        let categorical = self.column_types[column_index] == ColumnType::Categorical;
252        self.row_indices()
253            .map(move |i| column[i])
254            .enumerate()
255            .scan(None, move |prev, (i, x)| {
256                if prev.is_none() {
257                    *prev = Some((x, i));
258                    Some(None)
259                } else if prev.map_or(false, |(y, _)| (y - x).abs() > std::f64::EPSILON) {
260                    let (y, j) = prev.expect("never fails");
261                    *prev = Some((x, i));
262                    if categorical {
263                        let r = Range { start: j, end: i };
264                        Some(Some((r, y)))
265                    } else {
266                        let r = Range { start: 0, end: i };
267                        Some(Some((r, (x + y) / 2.0)))
268                    }
269                } else {
270                    Some(None)
271                }
272            })
273            .filter_map(|t| t)
274    }
275
276    pub(crate) fn with_split<F, T>(&mut self, row: usize, mut f: F) -> (T, T)
277    where
278        F: FnMut(&mut Self) -> T,
279    {
280        let row = row + self.row_range.start;
281        let original = self.row_range.clone();
282
283        self.row_range.end = row;
284        let left = f(self);
285        self.row_range.end = original.end;
286
287        self.row_range.start = row;
288        let right = f(self);
289        self.row_range.start = original.start;
290
291        (left, right)
292    }
293}
294
295/// Error kinds which could be returned during buidling a table.
296#[derive(Debug, Error, Clone, PartialEq, Eq, Hash)]
297pub enum TableError {
298    /// Table must have at least one column and one row.
299    #[error("table must have at least one column and one row")]
300    EmptyTable,
301
302    /// Some of rows have a different column count from others.
303    #[error("some of rows have a different column count from others")]
304    ColumnSizeMismatch,
305
306    /// Target column contains non finite numbers.
307    #[error("target column contains non finite numbers")]
308    NonFiniteTarget,
309}
310
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn error_check_works() -> anyhow::Result<()> {
        // A builder without any rows cannot produce a table.
        let empty = TableBuilder::default();
        assert_eq!(empty.build().err(), Some(TableError::EmptyTable));

        // One feature column is declared, but the row carries two features.
        let mut declared = TableBuilder::default();
        declared.set_feature_column_types(&[ColumnType::Numerical])?;
        let outcome = declared.add_row(&[1.0, 1.0], 10.0);
        assert_eq!(outcome.err(), Some(TableError::ColumnSizeMismatch));

        // Infinite targets are rejected.
        let mut fresh = TableBuilder::default();
        let outcome = fresh.add_row(&[1.0], std::f64::INFINITY);
        assert_eq!(outcome.err(), Some(TableError::NonFiniteTarget));

        Ok(())
    }

    #[test]
    fn train_test_split_works() -> anyhow::Result<()> {
        let mut builder = TableBuilder::new();
        (0..100).try_for_each(|_| builder.add_row(&[0.0], 1.0))?;
        let table = builder.build()?;
        assert_eq!(table.rows_len(), 100);

        let mut rng = rand::thread_rng();
        let (train, test) = table.train_test_split(&mut rng, 0.25);
        assert_eq!(train.rows_len(), 75);
        assert_eq!(test.rows_len(), 25);

        Ok(())
    }

    #[test]
    fn filter_works() -> anyhow::Result<()> {
        let mut builder = TableBuilder::new();
        for i in 0_u32..100 {
            builder.add_row(&[0.0], f64::from(i))?;
        }
        let mut table = builder.build()?;
        assert_eq!(table.rows_len(), 100);

        // Keep only the rows whose target (the last element) is below 10.
        let removed = table.filter(|row| row.last().map_or(false, |&target| target < 10.0));
        assert_eq!(removed, 90);
        assert_eq!(table.rows_len(), 10);

        Ok(())
    }
}