1use ordered_float::OrderedFloat;
3use rand::seq::SliceRandom;
4use rand::Rng;
5use std::ops::Range;
6use thiserror::Error;
7
/// Kind of data held by a feature column.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum ColumnType {
    /// Values are compared by order: `x <= split_value` goes to the left.
    Numerical = 0,

    /// Values are compared by (approximate) equality with the split value.
    Categorical = 1,
}

impl ColumnType {
    /// Returns `true` when `x` falls on the left side of a split at `split_value`.
    ///
    /// * `Numerical`: left means `x <= split_value`.
    /// * `Categorical`: left means `x` is within `f64::EPSILON` of `split_value`.
    ///
    /// Any comparison involving NaN yields `false`.
    pub(crate) fn is_left(self, x: f64, split_value: f64) -> bool {
        match self {
            Self::Categorical => {
                let distance = (x - split_value).abs();
                distance < std::f64::EPSILON
            }
            Self::Numerical => split_value >= x,
        }
    }
}
27
28#[derive(Debug)]
30pub struct TableBuilder {
31 column_types: Vec<ColumnType>,
32 columns: Vec<Vec<f64>>,
33}
34
35impl TableBuilder {
36 pub fn new() -> Self {
38 Self {
39 column_types: Vec::new(),
40 columns: Vec::new(),
41 }
42 }
43
44 pub fn set_feature_column_types(&mut self, types: &[ColumnType]) -> Result<(), TableError> {
48 if self.columns.is_empty() {
49 self.columns = vec![Vec::new(); types.len() + 1];
50 }
51
52 if self.columns.len() != types.len() + 1 {
53 return Err(TableError::ColumnSizeMismatch);
54 }
55
56 self.column_types = types.to_owned();
57 Ok(())
58 }
59
60 pub fn add_row(&mut self, features: &[f64], target: f64) -> Result<(), TableError> {
62 if self.columns.is_empty() {
63 self.columns = vec![Vec::new(); features.len() + 1];
64 }
65
66 if self.columns.len() != features.len() + 1 {
67 return Err(TableError::ColumnSizeMismatch);
68 }
69
70 if !target.is_finite() {
71 return Err(TableError::NonFiniteTarget);
72 }
73
74 if self.column_types.is_empty() {
75 self.column_types = (0..features.len()).map(|_| ColumnType::Numerical).collect();
76 }
77
78 for (column, value) in self
79 .columns
80 .iter_mut()
81 .zip(features.iter().copied().chain(std::iter::once(target)))
82 {
83 column.push(value);
84 }
85
86 Ok(())
87 }
88
89 pub fn build(&self) -> Result<Table, TableError> {
91 if self.columns.is_empty() || self.columns[0].is_empty() {
92 return Err(TableError::EmptyTable);
93 }
94
95 let rows_len = self.columns[0].len();
96 Ok(Table {
97 row_index: (0..rows_len).collect(),
98 row_range: Range {
99 start: 0,
100 end: rows_len,
101 },
102 column_types: &self.column_types,
103 columns: &self.columns,
104 })
105 }
106}
107
108impl Default for TableBuilder {
109 fn default() -> Self {
110 Self::new()
111 }
112}
113
114#[derive(Debug, Clone)]
116pub struct Table<'a> {
117 row_index: Vec<usize>,
118 row_range: Range<usize>,
119 column_types: &'a [ColumnType],
120 columns: &'a [Vec<f64>],
121}
122
123impl<'a> Table<'a> {
124 pub fn rows<'b>(&'b self) -> impl 'b + Iterator<Item = Vec<f64>> + Clone {
128 self.row_indices().map(move |i| {
129 (0..self.columns.len())
130 .map(|j| self.columns[j][i])
131 .collect()
132 })
133 }
134
135 pub fn filter<F>(&mut self, f: F) -> usize
139 where
140 F: Fn(&[f64]) -> bool,
141 {
142 let mut n = 0;
143 let mut i = self.row_range.start;
144 while i < self.row_range.end {
145 let row_i = self.row_index[i];
146 let row = (0..self.columns.len())
147 .map(|j| self.columns[j][row_i])
148 .collect::<Vec<_>>();
149 if f(&row) {
150 i += 1;
151 } else {
152 self.row_index.swap(i, self.row_range.end - 1);
153 self.row_range.end -= 1;
154 n += 1;
155 }
156 }
157 n
158 }
159
160 pub fn train_test_split<R: Rng + ?Sized>(
162 mut self,
163 rng: &mut R,
164 test_rate: f64,
165 ) -> (Self, Self) {
166 (&mut self.row_index[self.row_range.start..self.row_range.end]).shuffle(rng);
167 let test_num = (self.rows_len() as f64 * test_rate).round() as usize;
168
169 let mut train = self.clone();
170 let mut test = self;
171 test.row_range.end = test.row_range.start + test_num;
172 train.row_range.start = test.row_range.end;
173
174 (train, test)
175 }
176
177 pub(crate) fn target<'b>(&'b self) -> impl 'b + Iterator<Item = f64> + Clone {
178 self.column(self.columns.len() - 1)
179 }
180
181 pub(crate) fn column<'b>(
182 &'b self,
183 column_index: usize,
184 ) -> impl 'b + Iterator<Item = f64> + Clone {
185 self.row_indices()
186 .map(move |i| self.columns[column_index][i])
187 }
188
189 pub(crate) fn features_len(&self) -> usize {
190 self.columns.len() - 1
191 }
192
193 pub(crate) fn rows_len(&self) -> usize {
194 self.row_range.end - self.row_range.start
195 }
196
197 pub(crate) fn column_types(&self) -> &'a [ColumnType] {
198 self.column_types
199 }
200
201 fn row_indices<'b>(&'b self) -> impl 'b + Iterator<Item = usize> + Clone {
202 self.row_index[self.row_range.start..self.row_range.end]
203 .iter()
204 .copied()
205 }
206
207 pub(crate) fn sort_rows_by_column(&mut self, column: usize) {
208 let columns = &self.columns;
209 (&mut self.row_index[self.row_range.start..self.row_range.end])
210 .sort_by_key(|&x| OrderedFloat(columns[column][x]))
211 }
212
213 pub(crate) fn sort_rows_by_categorical_column(&mut self, column: usize, value: f64) {
214 let columns = &self.columns;
215 (&mut self.row_index[self.row_range.start..self.row_range.end]).sort_by_key(|&x| {
216 if (columns[column][x] - value).abs() < std::f64::EPSILON {
217 0
218 } else {
219 1
220 }
221 })
222 }
223
224 pub(crate) fn bootstrap_sample<R: Rng + ?Sized>(
225 &self,
226 rng: &mut R,
227 max_samples: usize,
228 ) -> Self {
229 let samples = std::cmp::min(max_samples, self.rows_len());
230 let row_index = (0..samples)
231 .map(|_| self.row_index[rng.gen_range(self.row_range.start, self.row_range.end)])
232 .collect::<Vec<_>>();
233 let row_range = Range {
234 start: 0,
235 end: samples,
236 };
237 Self {
238 row_index,
239 row_range,
240 column_types: self.column_types,
241 columns: self.columns,
242 }
243 }
244
245 pub(crate) fn split_points<'b>(
246 &'b self,
247 column_index: usize,
248 ) -> impl 'b + Iterator<Item = (Range<usize>, f64)> {
249 let column = &self.columns[column_index];
251 let categorical = self.column_types[column_index] == ColumnType::Categorical;
252 self.row_indices()
253 .map(move |i| column[i])
254 .enumerate()
255 .scan(None, move |prev, (i, x)| {
256 if prev.is_none() {
257 *prev = Some((x, i));
258 Some(None)
259 } else if prev.map_or(false, |(y, _)| (y - x).abs() > std::f64::EPSILON) {
260 let (y, j) = prev.expect("never fails");
261 *prev = Some((x, i));
262 if categorical {
263 let r = Range { start: j, end: i };
264 Some(Some((r, y)))
265 } else {
266 let r = Range { start: 0, end: i };
267 Some(Some((r, (x + y) / 2.0)))
268 }
269 } else {
270 Some(None)
271 }
272 })
273 .filter_map(|t| t)
274 }
275
276 pub(crate) fn with_split<F, T>(&mut self, row: usize, mut f: F) -> (T, T)
277 where
278 F: FnMut(&mut Self) -> T,
279 {
280 let row = row + self.row_range.start;
281 let original = self.row_range.clone();
282
283 self.row_range.end = row;
284 let left = f(self);
285 self.row_range.end = original.end;
286
287 self.row_range.start = row;
288 let right = f(self);
289 self.row_range.start = original.start;
290
291 (left, right)
292 }
293}
294
295#[derive(Debug, Error, Clone, PartialEq, Eq, Hash)]
297pub enum TableError {
298 #[error("table must have at least one column and one row")]
300 EmptyTable,
301
302 #[error("some of rows have a different column count from others")]
304 ColumnSizeMismatch,
305
306 #[error("target column contains non finite numbers")]
308 NonFiniteTarget,
309}
310
#[cfg(test)]
mod tests {
    use super::*;

    /// Each builder failure mode reports its dedicated error variant.
    #[test]
    fn error_check_works() -> anyhow::Result<()> {
        // A builder without any row cannot produce a table.
        let empty = TableBuilder::default();
        assert_eq!(empty.build().err(), Some(TableError::EmptyTable));

        // One declared feature, but two supplied.
        let mut builder = TableBuilder::default();
        builder.set_feature_column_types(&[ColumnType::Numerical])?;
        let mismatch = builder.add_row(&[1.0, 1.0], 10.0);
        assert_eq!(mismatch.err(), Some(TableError::ColumnSizeMismatch));

        // Infinite targets are rejected.
        let non_finite = TableBuilder::default().add_row(&[1.0], std::f64::INFINITY);
        assert_eq!(non_finite.err(), Some(TableError::NonFiniteTarget));

        Ok(())
    }

    /// A 25% test rate over 100 rows yields a 75/25 split.
    #[test]
    fn train_test_split_works() -> anyhow::Result<()> {
        let mut builder = TableBuilder::new();
        (0..100).try_for_each(|_| builder.add_row(&[0.0], 1.0))?;

        let table = builder.build()?;
        assert_eq!(table.rows_len(), 100);

        let mut rng = rand::thread_rng();
        let (train, test) = table.train_test_split(&mut rng, 0.25);
        assert_eq!(train.rows_len(), 75);
        assert_eq!(test.rows_len(), 25);

        Ok(())
    }

    /// Filtering on the target keeps only the rows accepted by the predicate.
    #[test]
    fn filter_works() -> anyhow::Result<()> {
        let mut builder = TableBuilder::new();
        for target in 0..100 {
            builder.add_row(&[0.0], f64::from(target))?;
        }

        let mut table = builder.build()?;
        assert_eq!(table.rows_len(), 100);

        // The target is the last value of every row.
        let removed = table.filter(|row| row.last().map_or(false, |&target| target < 10.0));
        assert_eq!(removed, 90);
        assert_eq!(table.rows_len(), 10);

        Ok(())
    }
}