Skip to main content

forestfire_data/
lib.rs

1use arrow::array::{BooleanArray, Float64Array, UInt8Array, UInt16Array};
2use rand::seq::SliceRandom;
3use rand::{SeedableRng, rngs::StdRng};
4use std::cmp::Ordering;
5use std::error::Error;
6use std::fmt::{Display, Formatter};
7
/// Largest bin count accepted by `NumericBins::fixed`, and the cap reported
/// for `NumericBins::Auto`.
pub const MAX_NUMERIC_BINS: usize = 128;
/// Canary count used by the `new` constructors, which take no explicit count.
const DEFAULT_CANARIES: usize = 2;

/// Result of `preprocess_rows`: column-major feature values, the target
/// array, the row count and the feature count, in that order.
type PreprocessedRows = (Vec<Vec<f64>>, Float64Array, usize, usize);
12
13pub trait TableAccess: Sync {
14    fn n_rows(&self) -> usize;
15    fn n_features(&self) -> usize;
16    fn canaries(&self) -> usize;
17    fn numeric_bin_cap(&self) -> usize;
18    fn binned_feature_count(&self) -> usize;
19    fn feature_value(&self, feature_index: usize, row_index: usize) -> f64;
20    fn is_binary_feature(&self, index: usize) -> bool;
21    fn binned_value(&self, feature_index: usize, row_index: usize) -> u16;
22    fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool>;
23    fn binned_column_kind(&self, index: usize) -> BinnedColumnKind;
24    fn is_binary_binned_feature(&self, index: usize) -> bool;
25    fn target_value(&self, row_index: usize) -> f64;
26
27    fn is_canary_binned_feature(&self, index: usize) -> bool {
28        matches!(
29            self.binned_column_kind(index),
30            BinnedColumnKind::Canary { .. }
31        )
32    }
33}
34
/// Storage layout of a `Table`, as reported by `Table::kind`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TableKind {
    /// The table wraps a `DenseTable`.
    Dense,
    /// The table wraps a `SparseTable`.
    Sparse,
}
40
/// Binning policy for numeric features.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum NumericBins {
    /// Default policy; `cap` reports `MAX_NUMERIC_BINS` for it.
    #[default]
    Auto,
    /// Explicit bin count. `NumericBins::fixed` validates the count; building
    /// this variant directly bypasses that check.
    Fixed(usize),
}
47
48impl NumericBins {
49    pub fn fixed(requested: usize) -> Result<Self, DenseTableError> {
50        if requested == 0 || requested > MAX_NUMERIC_BINS {
51            return Err(DenseTableError::InvalidBinCount { requested });
52        }
53        Ok(Self::Fixed(requested))
54    }
55
56    pub fn cap(self) -> usize {
57        match self {
58            NumericBins::Auto => MAX_NUMERIC_BINS,
59            NumericBins::Fixed(requested) => requested,
60        }
61    }
62}
63
/// Arrow-backed dense table for tabular regression/classification data.
///
/// Keeps the raw feature columns alongside a binned representation that also
/// contains canary copies of every feature.
#[derive(Debug, Clone)]
pub struct DenseTable {
    // Raw columns, one per input feature (binary or numeric).
    feature_columns: Vec<FeatureColumn>,
    // Binned columns: the real features first, then the canary copies.
    binned_feature_columns: Vec<BinnedFeatureColumn>,
    // Provenance of each binned column, index-aligned with `binned_feature_columns`.
    binned_column_kinds: Vec<BinnedColumnKind>,
    // Target values, one per row.
    target: Float64Array,
    // Number of rows.
    n_rows: usize,
    // Number of real (non-canary) features.
    n_features: usize,
    // Number of canary copies generated per feature.
    canaries: usize,
    // Binning policy the table was built with.
    numeric_bins: NumericBins,
}
76
/// Arrow-backed sparse table specialized for binary feature matrices.
///
/// Each column stores only the indices of the rows whose value is `1`; the
/// target is kept in an Arrow `Float64Array`.
#[derive(Debug, Clone)]
pub struct SparseTable {
    // Raw binary columns, one per input feature.
    feature_columns: Vec<SparseBinaryColumn>,
    // Clones of the real columns first, then the canary copies.
    binned_feature_columns: Vec<SparseBinaryColumn>,
    // Provenance of each binned column, index-aligned with `binned_feature_columns`.
    binned_column_kinds: Vec<BinnedColumnKind>,
    // Target values, one per row.
    target: Float64Array,
    // Number of rows.
    n_rows: usize,
    // Number of real (non-canary) features.
    n_features: usize,
    // Number of canary copies generated per feature.
    canaries: usize,
    // Binning policy recorded for the table; only its cap is read in this impl.
    numeric_bins: NumericBins,
}
89
/// Binary column stored as the indices of the rows that are set.
#[derive(Debug, Clone)]
struct SparseBinaryColumn {
    // Must be sorted ascending: `value` relies on binary search.
    row_indices: Vec<usize>,
}

impl SparseBinaryColumn {
    /// Returns `true` when `row_index` is one of the stored (set) rows.
    fn value(&self, row_index: usize) -> bool {
        match self.row_indices.binary_search(&row_index) {
            Ok(_) => true,
            Err(_) => false,
        }
    }
}
100
/// Table that is either dense or sparse; `Table::with_options` picks the
/// sparse form when every feature column is binary.
#[derive(Debug, Clone)]
pub enum Table {
    /// General-purpose storage for numeric and binary features.
    Dense(DenseTable),
    /// Index-based storage used when all features are binary.
    Sparse(SparseTable),
}
106
/// Raw (unbinned) storage for one feature of a `DenseTable`.
#[derive(Debug, Clone)]
enum FeatureColumn {
    /// Column kept as `f64` values.
    Numeric(Float64Array),
    /// Column whose values were all exactly `0.0` or `1.0`, kept as booleans.
    Binary(BooleanArray),
}
112
/// Binned storage for one column of a `DenseTable`.
#[derive(Debug, Clone)]
enum BinnedFeatureColumn {
    /// Numeric bin codes that all fit in a `u8`.
    NumericU8(UInt8Array),
    /// Numeric bin codes where at least one exceeds `u8::MAX`.
    NumericU16(UInt16Array),
    /// Binary column kept as booleans.
    Binary(BooleanArray),
}
119
/// Borrowed view of a raw feature column, returned by
/// `DenseTable::feature_column`.
#[derive(Debug, Clone, Copy)]
pub enum FeatureColumnRef<'a> {
    /// Borrowed numeric column.
    Numeric(&'a Float64Array),
    /// Borrowed binary column.
    Binary(&'a BooleanArray),
}
125
/// Borrowed view of a binned column, returned by
/// `DenseTable::binned_feature_column`.
#[derive(Debug, Clone, Copy)]
pub enum BinnedFeatureColumnRef<'a> {
    /// Borrowed numeric bin codes stored as `u8`.
    NumericU8(&'a UInt8Array),
    /// Borrowed numeric bin codes stored as `u16`.
    NumericU16(&'a UInt16Array),
    /// Borrowed binary column.
    Binary(&'a BooleanArray),
}
132
/// Provenance of a binned column: a real feature or a canary copy of one.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinnedColumnKind {
    /// Binned form of the real feature at `source_index`.
    Real {
        source_index: usize,
    },
    /// Canary copy number `copy_index` derived from the feature at
    /// `source_index`.
    Canary {
        source_index: usize,
        copy_index: usize,
    },
}
143
/// Errors reported while building a `DenseTable`, `SparseTable` or `Table`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DenseTableError {
    /// The feature matrix has `x` rows but the target has `y` values.
    MismatchedLengths {
        x: usize,
        y: usize,
    },
    /// Row `row` has `actual` columns where `expected` were required.
    /// Also reused by `SparseTable::from_sparse_binary_columns_with_options`
    /// when the number of supplied columns differs from `n_features`.
    RaggedRows {
        row: usize,
        expected: usize,
        actual: usize,
    },
    /// Column `column` is not usable as a binary column of a sparse table.
    NonBinaryColumn {
        column: usize,
    },
    /// `requested` is not a valid fixed bin count.
    InvalidBinCount {
        requested: usize,
    },
}
162
163impl Display for DenseTableError {
164    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
165        match self {
166            DenseTableError::MismatchedLengths { x, y } => write!(
167                f,
168                "Mismatched lengths: X has {} rows while y has {} values.",
169                x, y
170            ),
171            DenseTableError::RaggedRows {
172                row,
173                expected,
174                actual,
175            } => write!(
176                f,
177                "Ragged row at index {}: expected {} columns, found {}.",
178                row, expected, actual
179            ),
180            DenseTableError::NonBinaryColumn { column } => write!(
181                f,
182                "SparseTable requires binary features, but column {} contains non-binary values.",
183                column
184            ),
185            DenseTableError::InvalidBinCount { requested } => write!(
186                f,
187                "Invalid bins value {}. Expected 'auto' or an integer between 1 and {}.",
188                requested, MAX_NUMERIC_BINS
189            ),
190        }
191    }
192}
193
194impl Error for DenseTableError {}
195
196impl DenseTable {
197    pub fn new(x: Vec<Vec<f64>>, y: Vec<f64>) -> Result<Self, DenseTableError> {
198        Self::with_canaries(x, y, DEFAULT_CANARIES)
199    }
200
201    pub fn with_canaries(
202        x: Vec<Vec<f64>>,
203        y: Vec<f64>,
204        canaries: usize,
205    ) -> Result<Self, DenseTableError> {
206        Self::with_options(x, y, canaries, NumericBins::Auto)
207    }
208
209    pub fn with_options(
210        x: Vec<Vec<f64>>,
211        y: Vec<f64>,
212        canaries: usize,
213        numeric_bins: NumericBins,
214    ) -> Result<Self, DenseTableError> {
215        let (columns, target, n_rows, n_features) = preprocess_rows(&x, y)?;
216        Ok(Self::from_columns(
217            &columns,
218            target,
219            n_rows,
220            n_features,
221            canaries,
222            numeric_bins,
223        ))
224    }
225
226    fn from_columns(
227        columns: &[Vec<f64>],
228        target: Float64Array,
229        n_rows: usize,
230        n_features: usize,
231        canaries: usize,
232        numeric_bins: NumericBins,
233    ) -> Self {
234        let feature_columns = columns
235            .iter()
236            .map(|column| build_feature_column(column))
237            .collect();
238
239        let real_binned_columns: Vec<BinnedFeatureColumn> = columns
240            .iter()
241            .map(|column| build_binned_feature_column(column, numeric_bins))
242            .collect();
243        let canary_columns: Vec<(BinnedColumnKind, BinnedFeatureColumn)> = (0..canaries)
244            .flat_map(|copy_index| {
245                real_binned_columns
246                    .iter()
247                    .enumerate()
248                    .map(move |(source_index, column)| {
249                        (
250                            BinnedColumnKind::Canary {
251                                source_index,
252                                copy_index,
253                            },
254                            shuffle_canary_column(column, copy_index, source_index),
255                        )
256                    })
257            })
258            .collect();
259
260        let (binned_column_kinds, binned_feature_columns): (Vec<_>, Vec<_>) = (0..n_features)
261            .map(|source_index| BinnedColumnKind::Real { source_index })
262            .zip(real_binned_columns)
263            .chain(canary_columns)
264            .unzip();
265
266        Self {
267            feature_columns,
268            binned_feature_columns,
269            binned_column_kinds,
270            target,
271            n_rows,
272            n_features,
273            canaries,
274            numeric_bins,
275        }
276    }
277
278    #[inline]
279    pub fn n_rows(&self) -> usize {
280        self.n_rows
281    }
282
283    #[inline]
284    pub fn n_features(&self) -> usize {
285        self.n_features
286    }
287
288    #[inline]
289    pub fn canaries(&self) -> usize {
290        self.canaries
291    }
292
293    #[inline]
294    pub fn numeric_bin_cap(&self) -> usize {
295        self.numeric_bins.cap()
296    }
297
298    #[inline]
299    pub fn binned_feature_count(&self) -> usize {
300        self.binned_feature_columns.len()
301    }
302
303    #[inline]
304    pub fn feature_column(&self, index: usize) -> FeatureColumnRef<'_> {
305        match &self.feature_columns[index] {
306            FeatureColumn::Numeric(column) => FeatureColumnRef::Numeric(column),
307            FeatureColumn::Binary(column) => FeatureColumnRef::Binary(column),
308        }
309    }
310
311    #[inline]
312    pub fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
313        match &self.feature_columns[feature_index] {
314            FeatureColumn::Numeric(column) => column.value(row_index),
315            FeatureColumn::Binary(column) => f64::from(u8::from(column.value(row_index))),
316        }
317    }
318
319    #[inline]
320    pub fn is_binary_feature(&self, index: usize) -> bool {
321        matches!(self.feature_columns[index], FeatureColumn::Binary(_))
322    }
323
324    #[inline]
325    pub fn binned_feature_column(&self, index: usize) -> BinnedFeatureColumnRef<'_> {
326        match &self.binned_feature_columns[index] {
327            BinnedFeatureColumn::NumericU8(column) => BinnedFeatureColumnRef::NumericU8(column),
328            BinnedFeatureColumn::NumericU16(column) => BinnedFeatureColumnRef::NumericU16(column),
329            BinnedFeatureColumn::Binary(column) => BinnedFeatureColumnRef::Binary(column),
330        }
331    }
332
333    #[inline]
334    pub fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
335        match &self.binned_feature_columns[feature_index] {
336            BinnedFeatureColumn::NumericU8(column) => u16::from(column.value(row_index)),
337            BinnedFeatureColumn::NumericU16(column) => column.value(row_index),
338            BinnedFeatureColumn::Binary(column) => u16::from(u8::from(column.value(row_index))),
339        }
340    }
341
342    #[inline]
343    pub fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
344        match &self.binned_feature_columns[feature_index] {
345            BinnedFeatureColumn::Binary(column) => Some(column.value(row_index)),
346            BinnedFeatureColumn::NumericU8(_) | BinnedFeatureColumn::NumericU16(_) => None,
347        }
348    }
349
350    #[inline]
351    pub fn binned_column_kind(&self, index: usize) -> BinnedColumnKind {
352        self.binned_column_kinds[index]
353    }
354
355    #[inline]
356    pub fn is_canary_binned_feature(&self, index: usize) -> bool {
357        matches!(
358            self.binned_column_kinds[index],
359            BinnedColumnKind::Canary { .. }
360        )
361    }
362
363    #[inline]
364    pub fn is_binary_binned_feature(&self, index: usize) -> bool {
365        matches!(
366            self.binned_feature_columns[index],
367            BinnedFeatureColumn::Binary(_)
368        )
369    }
370
371    #[inline]
372    pub fn target(&self) -> &Float64Array {
373        &self.target
374    }
375}
376
377impl SparseTable {
378    pub fn new(x: Vec<Vec<f64>>, y: Vec<f64>) -> Result<Self, DenseTableError> {
379        Self::with_canaries(x, y, DEFAULT_CANARIES)
380    }
381
382    pub fn with_canaries(
383        x: Vec<Vec<f64>>,
384        y: Vec<f64>,
385        canaries: usize,
386    ) -> Result<Self, DenseTableError> {
387        Self::with_options(x, y, canaries, NumericBins::Auto)
388    }
389
390    pub fn with_options(
391        x: Vec<Vec<f64>>,
392        y: Vec<f64>,
393        canaries: usize,
394        numeric_bins: NumericBins,
395    ) -> Result<Self, DenseTableError> {
396        let (columns, target, n_rows, n_features) = preprocess_rows(&x, y)?;
397        validate_binary_columns(&columns)?;
398        Ok(Self::from_columns(
399            &columns,
400            target,
401            n_rows,
402            n_features,
403            canaries,
404            numeric_bins,
405        ))
406    }
407
408    fn from_columns(
409        columns: &[Vec<f64>],
410        target: Float64Array,
411        n_rows: usize,
412        n_features: usize,
413        canaries: usize,
414        numeric_bins: NumericBins,
415    ) -> Self {
416        let feature_columns: Vec<SparseBinaryColumn> = columns
417            .iter()
418            .map(|column| sparse_binary_column_from_values(column))
419            .collect();
420
421        let canary_columns: Vec<(BinnedColumnKind, SparseBinaryColumn)> = (0..canaries)
422            .flat_map(|copy_index| {
423                feature_columns
424                    .iter()
425                    .enumerate()
426                    .map(move |(source_index, column)| {
427                        (
428                            BinnedColumnKind::Canary {
429                                source_index,
430                                copy_index,
431                            },
432                            shuffle_sparse_binary_column(column, n_rows, copy_index, source_index),
433                        )
434                    })
435            })
436            .collect();
437
438        let (binned_column_kinds, binned_feature_columns): (Vec<_>, Vec<_>) = (0..n_features)
439            .map(|source_index| BinnedColumnKind::Real { source_index })
440            .zip(feature_columns.iter().cloned())
441            .chain(canary_columns)
442            .unzip();
443
444        Self {
445            feature_columns,
446            binned_feature_columns,
447            binned_column_kinds,
448            target,
449            n_rows,
450            n_features,
451            canaries,
452            numeric_bins,
453        }
454    }
455
456    pub fn from_sparse_binary_columns(
457        n_rows: usize,
458        n_features: usize,
459        columns: Vec<Vec<usize>>,
460        y: Vec<f64>,
461        canaries: usize,
462    ) -> Result<Self, DenseTableError> {
463        Self::from_sparse_binary_columns_with_options(
464            n_rows,
465            n_features,
466            columns,
467            y,
468            canaries,
469            NumericBins::Auto,
470        )
471    }
472
473    pub fn from_sparse_binary_columns_with_options(
474        n_rows: usize,
475        n_features: usize,
476        columns: Vec<Vec<usize>>,
477        y: Vec<f64>,
478        canaries: usize,
479        numeric_bins: NumericBins,
480    ) -> Result<Self, DenseTableError> {
481        if n_rows != y.len() {
482            return Err(DenseTableError::MismatchedLengths {
483                x: n_rows,
484                y: y.len(),
485            });
486        }
487        if n_features != columns.len() {
488            return Err(DenseTableError::RaggedRows {
489                row: columns.len(),
490                expected: n_features,
491                actual: columns.len(),
492            });
493        }
494
495        let feature_columns = columns
496            .into_iter()
497            .enumerate()
498            .map(|(column_idx, mut row_indices)| {
499                row_indices.sort_unstable();
500                row_indices.dedup();
501                if row_indices.iter().any(|row_idx| *row_idx >= n_rows) {
502                    return Err(DenseTableError::NonBinaryColumn { column: column_idx });
503                }
504                Ok(SparseBinaryColumn { row_indices })
505            })
506            .collect::<Result<Vec<_>, _>>()?;
507
508        let canary_columns: Vec<(BinnedColumnKind, SparseBinaryColumn)> = (0..canaries)
509            .flat_map(|copy_index| {
510                feature_columns
511                    .iter()
512                    .enumerate()
513                    .map(move |(source_index, column)| {
514                        (
515                            BinnedColumnKind::Canary {
516                                source_index,
517                                copy_index,
518                            },
519                            shuffle_sparse_binary_column(column, n_rows, copy_index, source_index),
520                        )
521                    })
522            })
523            .collect();
524
525        let (binned_column_kinds, binned_feature_columns): (Vec<_>, Vec<_>) = (0..n_features)
526            .map(|source_index| BinnedColumnKind::Real { source_index })
527            .zip(feature_columns.iter().cloned())
528            .chain(canary_columns)
529            .unzip();
530
531        Ok(Self {
532            feature_columns,
533            binned_feature_columns,
534            binned_column_kinds,
535            target: Float64Array::from(y),
536            n_rows,
537            n_features,
538            canaries,
539            numeric_bins,
540        })
541    }
542
543    #[inline]
544    pub fn n_rows(&self) -> usize {
545        self.n_rows
546    }
547
548    #[inline]
549    pub fn n_features(&self) -> usize {
550        self.n_features
551    }
552
553    #[inline]
554    pub fn canaries(&self) -> usize {
555        self.canaries
556    }
557
558    #[inline]
559    pub fn numeric_bin_cap(&self) -> usize {
560        self.numeric_bins.cap()
561    }
562
563    #[inline]
564    pub fn binned_feature_count(&self) -> usize {
565        self.binned_feature_columns.len()
566    }
567
568    #[inline]
569    pub fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
570        f64::from(u8::from(
571            self.feature_columns[feature_index].value(row_index),
572        ))
573    }
574
575    #[inline]
576    pub fn is_binary_feature(&self, _index: usize) -> bool {
577        true
578    }
579
580    #[inline]
581    pub fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
582        u16::from(u8::from(
583            self.binned_feature_columns[feature_index].value(row_index),
584        ))
585    }
586
587    #[inline]
588    pub fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
589        Some(self.binned_feature_columns[feature_index].value(row_index))
590    }
591
592    #[inline]
593    pub fn binned_column_kind(&self, index: usize) -> BinnedColumnKind {
594        self.binned_column_kinds[index]
595    }
596
597    #[inline]
598    pub fn is_canary_binned_feature(&self, index: usize) -> bool {
599        matches!(
600            self.binned_column_kinds[index],
601            BinnedColumnKind::Canary { .. }
602        )
603    }
604
605    #[inline]
606    pub fn is_binary_binned_feature(&self, _index: usize) -> bool {
607        true
608    }
609
610    #[inline]
611    pub fn target(&self) -> &Float64Array {
612        &self.target
613    }
614}
615
616impl Table {
617    pub fn new(x: Vec<Vec<f64>>, y: Vec<f64>) -> Result<Self, DenseTableError> {
618        Self::with_canaries(x, y, DEFAULT_CANARIES)
619    }
620
621    pub fn with_canaries(
622        x: Vec<Vec<f64>>,
623        y: Vec<f64>,
624        canaries: usize,
625    ) -> Result<Self, DenseTableError> {
626        Self::with_options(x, y, canaries, NumericBins::Auto)
627    }
628
629    pub fn with_options(
630        x: Vec<Vec<f64>>,
631        y: Vec<f64>,
632        canaries: usize,
633        numeric_bins: NumericBins,
634    ) -> Result<Self, DenseTableError> {
635        let (columns, target, n_rows, n_features) = preprocess_rows(&x, y)?;
636
637        if columns.iter().all(|column| is_binary_column(column)) {
638            Ok(Self::Sparse(SparseTable::from_columns(
639                &columns,
640                target,
641                n_rows,
642                n_features,
643                canaries,
644                numeric_bins,
645            )))
646        } else {
647            Ok(Self::Dense(DenseTable::from_columns(
648                &columns,
649                target,
650                n_rows,
651                n_features,
652                canaries,
653                numeric_bins,
654            )))
655        }
656    }
657
658    pub fn kind(&self) -> TableKind {
659        match self {
660            Table::Dense(_) => TableKind::Dense,
661            Table::Sparse(_) => TableKind::Sparse,
662        }
663    }
664
665    pub fn as_dense(&self) -> Option<&DenseTable> {
666        match self {
667            Table::Dense(table) => Some(table),
668            Table::Sparse(_) => None,
669        }
670    }
671
672    pub fn as_sparse(&self) -> Option<&SparseTable> {
673        match self {
674            Table::Dense(_) => None,
675            Table::Sparse(table) => Some(table),
676        }
677    }
678}
679
680impl TableAccess for DenseTable {
681    fn n_rows(&self) -> usize {
682        self.n_rows()
683    }
684
685    fn n_features(&self) -> usize {
686        self.n_features()
687    }
688
689    fn canaries(&self) -> usize {
690        self.canaries()
691    }
692
693    fn numeric_bin_cap(&self) -> usize {
694        self.numeric_bin_cap()
695    }
696
697    fn binned_feature_count(&self) -> usize {
698        self.binned_feature_count()
699    }
700
701    fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
702        self.feature_value(feature_index, row_index)
703    }
704
705    fn is_binary_feature(&self, index: usize) -> bool {
706        self.is_binary_feature(index)
707    }
708
709    fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
710        self.binned_value(feature_index, row_index)
711    }
712
713    fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
714        self.binned_boolean_value(feature_index, row_index)
715    }
716
717    fn binned_column_kind(&self, index: usize) -> BinnedColumnKind {
718        self.binned_column_kind(index)
719    }
720
721    fn is_binary_binned_feature(&self, index: usize) -> bool {
722        self.is_binary_binned_feature(index)
723    }
724
725    fn target_value(&self, row_index: usize) -> f64 {
726        self.target().value(row_index)
727    }
728}
729
730impl TableAccess for SparseTable {
731    fn n_rows(&self) -> usize {
732        self.n_rows()
733    }
734
735    fn n_features(&self) -> usize {
736        self.n_features()
737    }
738
739    fn canaries(&self) -> usize {
740        self.canaries()
741    }
742
743    fn numeric_bin_cap(&self) -> usize {
744        self.numeric_bin_cap()
745    }
746
747    fn binned_feature_count(&self) -> usize {
748        self.binned_feature_count()
749    }
750
751    fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
752        self.feature_value(feature_index, row_index)
753    }
754
755    fn is_binary_feature(&self, index: usize) -> bool {
756        self.is_binary_feature(index)
757    }
758
759    fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
760        self.binned_value(feature_index, row_index)
761    }
762
763    fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
764        self.binned_boolean_value(feature_index, row_index)
765    }
766
767    fn binned_column_kind(&self, index: usize) -> BinnedColumnKind {
768        self.binned_column_kind(index)
769    }
770
771    fn is_binary_binned_feature(&self, index: usize) -> bool {
772        self.is_binary_binned_feature(index)
773    }
774
775    fn target_value(&self, row_index: usize) -> f64 {
776        self.target().value(row_index)
777    }
778}
779
780impl TableAccess for Table {
781    fn n_rows(&self) -> usize {
782        match self {
783            Table::Dense(table) => table.n_rows(),
784            Table::Sparse(table) => table.n_rows(),
785        }
786    }
787
788    fn n_features(&self) -> usize {
789        match self {
790            Table::Dense(table) => table.n_features(),
791            Table::Sparse(table) => table.n_features(),
792        }
793    }
794
795    fn canaries(&self) -> usize {
796        match self {
797            Table::Dense(table) => table.canaries(),
798            Table::Sparse(table) => table.canaries(),
799        }
800    }
801
802    fn numeric_bin_cap(&self) -> usize {
803        match self {
804            Table::Dense(table) => table.numeric_bin_cap(),
805            Table::Sparse(table) => table.numeric_bin_cap(),
806        }
807    }
808
809    fn binned_feature_count(&self) -> usize {
810        match self {
811            Table::Dense(table) => table.binned_feature_count(),
812            Table::Sparse(table) => table.binned_feature_count(),
813        }
814    }
815
816    fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
817        match self {
818            Table::Dense(table) => table.feature_value(feature_index, row_index),
819            Table::Sparse(table) => table.feature_value(feature_index, row_index),
820        }
821    }
822
823    fn is_binary_feature(&self, index: usize) -> bool {
824        match self {
825            Table::Dense(table) => table.is_binary_feature(index),
826            Table::Sparse(table) => table.is_binary_feature(index),
827        }
828    }
829
830    fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
831        match self {
832            Table::Dense(table) => table.binned_value(feature_index, row_index),
833            Table::Sparse(table) => table.binned_value(feature_index, row_index),
834        }
835    }
836
837    fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
838        match self {
839            Table::Dense(table) => table.binned_boolean_value(feature_index, row_index),
840            Table::Sparse(table) => table.binned_boolean_value(feature_index, row_index),
841        }
842    }
843
844    fn binned_column_kind(&self, index: usize) -> BinnedColumnKind {
845        match self {
846            Table::Dense(table) => table.binned_column_kind(index),
847            Table::Sparse(table) => table.binned_column_kind(index),
848        }
849    }
850
851    fn is_binary_binned_feature(&self, index: usize) -> bool {
852        match self {
853            Table::Dense(table) => table.is_binary_binned_feature(index),
854            Table::Sparse(table) => table.is_binary_binned_feature(index),
855        }
856    }
857
858    fn target_value(&self, row_index: usize) -> f64 {
859        match self {
860            Table::Dense(table) => table.target().value(row_index),
861            Table::Sparse(table) => table.target().value(row_index),
862        }
863    }
864}
865
866fn preprocess_rows(x: &[Vec<f64>], y: Vec<f64>) -> Result<PreprocessedRows, DenseTableError> {
867    validate_shape(x, &y)?;
868    let n_rows = x.len();
869    let n_features = x.first().map_or(0, Vec::len);
870    let columns = collect_columns(x, n_features);
871    Ok((columns, Float64Array::from(y), n_rows, n_features))
872}
873
874fn validate_shape(x: &[Vec<f64>], y: &[f64]) -> Result<(), DenseTableError> {
875    if x.len() != y.len() {
876        return Err(DenseTableError::MismatchedLengths {
877            x: x.len(),
878            y: y.len(),
879        });
880    }
881
882    let n_features = x.first().map_or(0, Vec::len);
883    for (row_idx, row) in x.iter().enumerate() {
884        if row.len() != n_features {
885            return Err(DenseTableError::RaggedRows {
886                row: row_idx,
887                expected: n_features,
888                actual: row.len(),
889            });
890        }
891    }
892
893    Ok(())
894}
895
/// Transposes the first `n_features` entries of every row of `x` into
/// column-major form.
///
/// # Panics
///
/// Panics if any row has fewer than `n_features` entries.
fn collect_columns(x: &[Vec<f64>], n_features: usize) -> Vec<Vec<f64>> {
    let mut columns = Vec::with_capacity(n_features);
    for col_idx in 0..n_features {
        let mut column = Vec::with_capacity(x.len());
        for row in x {
            column.push(row[col_idx]);
        }
        columns.push(column);
    }
    columns
}
901
902fn validate_binary_columns(columns: &[Vec<f64>]) -> Result<(), DenseTableError> {
903    for (column_idx, column) in columns.iter().enumerate() {
904        if !is_binary_column(column) {
905            return Err(DenseTableError::NonBinaryColumn { column: column_idx });
906        }
907    }
908
909    Ok(())
910}
911
912fn build_feature_column(values: &[f64]) -> FeatureColumn {
913    if is_binary_column(values) {
914        FeatureColumn::Binary(BooleanArray::from(to_binary_values(values)))
915    } else {
916        FeatureColumn::Numeric(Float64Array::from(values.to_vec()))
917    }
918}
919
920fn build_binned_feature_column(values: &[f64], numeric_bins: NumericBins) -> BinnedFeatureColumn {
921    if is_binary_column(values) {
922        BinnedFeatureColumn::Binary(BooleanArray::from(to_binary_values(values)))
923    } else {
924        let bins = bin_numeric_column(values, numeric_bins);
925        if bins.iter().all(|value| *value <= u16::from(u8::MAX)) {
926            BinnedFeatureColumn::NumericU8(UInt8Array::from(
927                bins.into_iter()
928                    .map(|value| value as u8)
929                    .collect::<Vec<_>>(),
930            ))
931        } else {
932            BinnedFeatureColumn::NumericU16(UInt16Array::from(bins))
933        }
934    }
935}
936
/// Returns `true` when every value is exactly `0.0` or `1.0`; an empty slice
/// therefore counts as binary.
///
/// The comparison uses `f64::total_cmp`, so `-0.0` and NaN are not accepted.
fn is_binary_column(values: &[f64]) -> bool {
    for value in values {
        let is_zero = value.total_cmp(&0.0) == Ordering::Equal;
        let is_one = value.total_cmp(&1.0) == Ordering::Equal;
        if !(is_zero || is_one) {
            return false;
        }
    }
    true
}
943
/// Maps each value to `true` when it is exactly `1.0` (per `f64::total_cmp`)
/// and to `false` otherwise.
fn to_binary_values(values: &[f64]) -> Vec<bool> {
    let mut flags = Vec::with_capacity(values.len());
    for value in values {
        flags.push(matches!(value.total_cmp(&1.0), Ordering::Equal));
    }
    flags
}
950
951fn sparse_binary_column_from_values(values: &[f64]) -> SparseBinaryColumn {
952    SparseBinaryColumn {
953        row_indices: values
954            .iter()
955            .enumerate()
956            .filter_map(|(row_idx, value)| {
957                (value.total_cmp(&1.0) == Ordering::Equal).then_some(row_idx)
958            })
959            .collect(),
960    }
961}
962
963pub fn numeric_bin_boundaries(values: &[f64], numeric_bins: NumericBins) -> Vec<(u16, f64)> {
964    if values.is_empty() {
965        return Vec::new();
966    }
967
968    let mut ranked_values: Vec<(usize, f64)> = values.iter().copied().enumerate().collect();
969    ranked_values.sort_by(|left, right| left.1.total_cmp(&right.1));
970
971    let unique_value_count = ranked_values
972        .iter()
973        .map(|(_row_idx, value)| *value)
974        .fold(Vec::<f64>::new(), |mut unique_values, value| {
975            let is_new_value = unique_values
976                .last()
977                .is_none_or(|last_value| last_value.total_cmp(&value) != Ordering::Equal);
978            if is_new_value {
979                unique_values.push(value);
980            }
981            unique_values
982        })
983        .len();
984
985    let bin_count = resolved_numeric_bin_count(values.len(), unique_value_count, numeric_bins);
986    let mut unique_rank = 0usize;
987    let mut start = 0usize;
988    let mut boundaries = Vec::new();
989
990    while start < ranked_values.len() {
991        let current_value = ranked_values[start].1;
992        let end = ranked_values[start..]
993            .iter()
994            .position(|(_row_idx, value)| value.total_cmp(&current_value) != Ordering::Equal)
995            .map_or(ranked_values.len(), |offset| start + offset);
996
997        let bin = match numeric_bins {
998            NumericBins::Auto => ((start * bin_count) / values.len()) as u16,
999            NumericBins::Fixed(_) => {
1000                let max_bin = (bin_count - 1) as u16;
1001                if unique_value_count == 1 {
1002                    0
1003                } else {
1004                    ((unique_rank * usize::from(max_bin)) / (unique_value_count - 1)) as u16
1005                }
1006            }
1007        };
1008
1009        if let Some((last_bin, last_upper_bound)) = boundaries.last_mut() {
1010            if *last_bin == bin {
1011                *last_upper_bound = current_value;
1012            } else {
1013                boundaries.push((bin, current_value));
1014            }
1015        } else {
1016            boundaries.push((bin, current_value));
1017        }
1018
1019        unique_rank += 1;
1020        start = end;
1021    }
1022
1023    boundaries
1024}
1025
1026fn bin_numeric_column(values: &[f64], numeric_bins: NumericBins) -> Vec<u16> {
1027    if values.is_empty() {
1028        return Vec::new();
1029    }
1030
1031    let mut ranked_values: Vec<(usize, f64)> = values.iter().copied().enumerate().collect();
1032    ranked_values.sort_by(|left, right| left.1.total_cmp(&right.1));
1033
1034    let unique_value_count = ranked_values
1035        .iter()
1036        .map(|(_row_idx, value)| *value)
1037        .fold(Vec::<f64>::new(), |mut unique_values, value| {
1038            let is_new_value = unique_values
1039                .last()
1040                .is_none_or(|last_value| last_value.total_cmp(&value) != Ordering::Equal);
1041            if is_new_value {
1042                unique_values.push(value);
1043            }
1044            unique_values
1045        })
1046        .len();
1047
1048    let mut bins = vec![0u16; values.len()];
1049    let bin_count = resolved_numeric_bin_count(values.len(), unique_value_count, numeric_bins);
1050    let mut unique_rank = 0usize;
1051    let mut start = 0usize;
1052
1053    while start < ranked_values.len() {
1054        let current_value = ranked_values[start].1;
1055        let end = ranked_values[start..]
1056            .iter()
1057            .position(|(_row_idx, value)| value.total_cmp(&current_value) != Ordering::Equal)
1058            .map_or(ranked_values.len(), |offset| start + offset);
1059
1060        let bin = match numeric_bins {
1061            NumericBins::Auto => ((start * bin_count) / values.len()) as u16,
1062            NumericBins::Fixed(_) => {
1063                let max_bin = (bin_count - 1) as u16;
1064                if unique_value_count == 1 {
1065                    0
1066                } else {
1067                    ((unique_rank * usize::from(max_bin)) / (unique_value_count - 1)) as u16
1068                }
1069            }
1070        };
1071
1072        for (row_idx, _value) in &ranked_values[start..end] {
1073            bins[*row_idx] = bin;
1074        }
1075
1076        unique_rank += 1;
1077        start = end;
1078    }
1079
1080    bins
1081}
1082
1083fn resolved_numeric_bin_count(
1084    value_count: usize,
1085    unique_value_count: usize,
1086    numeric_bins: NumericBins,
1087) -> usize {
1088    match numeric_bins {
1089        NumericBins::Auto => {
1090            let populated_bin_cap = (value_count / 2).max(1);
1091            let capped_unique_values = unique_value_count
1092                .min(MAX_NUMERIC_BINS)
1093                .min(populated_bin_cap)
1094                .max(1);
1095            highest_power_of_two_at_most(capped_unique_values)
1096        }
1097        NumericBins::Fixed(requested) => requested.min(unique_value_count).max(1),
1098    }
1099}
1100
/// Returns the largest power of two that does not exceed `value`.
///
/// Inputs of `0` and `1` both yield `1`.
fn highest_power_of_two_at_most(value: usize) -> usize {
    match value {
        0 | 1 => 1,
        // `ilog2` is the index of the highest set bit, so shifting one by it
        // keeps only that bit.
        _ => 1usize << value.ilog2(),
    }
}
1108
1109fn shuffle_canary_column(
1110    values: &BinnedFeatureColumn,
1111    copy_index: usize,
1112    source_index: usize,
1113) -> BinnedFeatureColumn {
1114    match values {
1115        BinnedFeatureColumn::NumericU8(values) => {
1116            let mut shuffled = (0..values.len())
1117                .map(|idx| values.value(idx))
1118                .collect::<Vec<_>>();
1119            shuffle_values(&mut shuffled, copy_index, source_index);
1120            BinnedFeatureColumn::NumericU8(UInt8Array::from(shuffled))
1121        }
1122        BinnedFeatureColumn::NumericU16(values) => {
1123            let mut shuffled = (0..values.len())
1124                .map(|idx| values.value(idx))
1125                .collect::<Vec<_>>();
1126            shuffle_values(&mut shuffled, copy_index, source_index);
1127            BinnedFeatureColumn::NumericU16(UInt16Array::from(shuffled))
1128        }
1129        BinnedFeatureColumn::Binary(values) => {
1130            BinnedFeatureColumn::Binary(shuffle_boolean_array(values, copy_index, source_index))
1131        }
1132    }
1133}
1134
1135fn shuffle_boolean_array(
1136    values: &BooleanArray,
1137    copy_index: usize,
1138    source_index: usize,
1139) -> BooleanArray {
1140    let mut shuffled = (0..values.len())
1141        .map(|idx| values.value(idx))
1142        .collect::<Vec<_>>();
1143    shuffle_values(&mut shuffled, copy_index, source_index);
1144    BooleanArray::from(shuffled)
1145}
1146
1147fn shuffle_sparse_binary_column(
1148    values: &SparseBinaryColumn,
1149    n_rows: usize,
1150    copy_index: usize,
1151    source_index: usize,
1152) -> SparseBinaryColumn {
1153    let mut dense = vec![false; n_rows];
1154    for row_idx in &values.row_indices {
1155        dense[*row_idx] = true;
1156    }
1157    shuffle_values(&mut dense, copy_index, source_index);
1158    SparseBinaryColumn {
1159        row_indices: dense
1160            .into_iter()
1161            .enumerate()
1162            .filter_map(|(row_idx, value)| value.then_some(row_idx))
1163            .collect(),
1164    }
1165}
1166
1167fn shuffle_values<T>(values: &mut [T], copy_index: usize, source_index: usize) {
1168    let seed = 0xA11CE5EED_u64
1169        ^ ((copy_index as u64) << 32)
1170        ^ (source_index as u64)
1171        ^ ((values.len() as u64) << 16);
1172    let mut rng = StdRng::seed_from_u64(seed);
1173    values.shuffle(&mut rng);
1174}
1175
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::{BTreeMap, BTreeSet};

    /// A dense table built with `new` exposes its shape, raw values, targets,
    /// and marks the columns after the real features as canaries.
    #[test]
    fn builds_arrow_backed_dense_table() {
        let table =
            DenseTable::new(vec![vec![0.0, 10.0], vec![1.0, 20.0]], vec![3.0, 5.0]).unwrap();

        assert_eq!(table.n_rows(), 2);
        assert_eq!(table.n_features(), 2);
        assert_eq!(table.canaries(), 2);
        assert_eq!(table.binned_feature_count(), 6);
        assert_eq!(table.feature_value(0, 0), 0.0);
        assert_eq!(table.feature_value(0, 1), 1.0);
        assert_eq!(table.target().value(0), 3.0);
        assert_eq!(table.target().value(1), 5.0);
        assert!(!table.is_canary_binned_feature(0));
        assert!(table.is_canary_binned_feature(2));
    }

    /// When every feature is binary, `Table` selects the sparse representation.
    #[test]
    fn builds_sparse_table_for_all_binary_features() {
        let table = Table::with_canaries(
            vec![vec![0.0, 1.0], vec![1.0, 0.0], vec![1.0, 1.0]],
            vec![0.0, 1.0, 1.0],
            1,
        )
        .unwrap();

        assert_eq!(table.kind(), TableKind::Sparse);
        assert!(table.is_binary_feature(0));
        assert!(table.is_binary_feature(1));
        assert!(table.is_binary_binned_feature(0));
        assert_eq!(table.binned_feature_count(), 4);
    }

    /// A single non-binary feature is enough to fall back to the dense table.
    #[test]
    fn builds_dense_table_when_any_feature_is_non_binary() {
        let table = Table::with_canaries(
            vec![vec![0.0, 1.5], vec![1.0, 0.0], vec![1.0, 2.0]],
            vec![0.0, 1.0, 1.0],
            1,
        )
        .unwrap();

        assert_eq!(table.kind(), TableKind::Dense);
        assert!(table.is_binary_feature(0));
        assert!(!table.is_binary_feature(1));
    }

    /// Constructing a `SparseTable` directly with a non-binary column reports
    /// the offending column index.
    #[test]
    fn sparse_table_rejects_non_binary_columns() {
        let err =
            SparseTable::with_canaries(vec![vec![0.0, 2.0], vec![1.0, 0.0]], vec![0.0, 1.0], 0)
                .unwrap_err();

        assert_eq!(err, DenseTableError::NonBinaryColumn { column: 1 });
    }

    /// 1024 distinct values are auto-binned into 128 monotonically
    /// non-decreasing bins.
    #[test]
    fn auto_bins_numeric_columns_into_power_of_two_bins_up_to_128() {
        let x: Vec<Vec<f64>> = (0..1024).map(|value| vec![value as f64]).collect();
        let y: Vec<f64> = vec![1.0; 1024];

        let table = DenseTable::with_canaries(x, y, 0).unwrap();

        assert_eq!(table.binned_value(0, 0), 0);
        assert_eq!(table.binned_value(0, 1023), 127);
        assert!((1..1024).all(|idx| table.binned_value(0, idx - 1) <= table.binned_value(0, idx)));
        assert_eq!(
            (0..1024)
                .map(|idx| table.binned_value(0, idx))
                .collect::<BTreeSet<_>>()
                .len(),
            128
        );
    }

    /// 300 distinct values still resolve to 128 occupied bins in auto mode.
    #[test]
    fn auto_bins_choose_highest_populated_power_of_two() {
        let x: Vec<Vec<f64>> = (0..300).map(|value| vec![value as f64]).collect();
        let y = vec![0.0; 300];

        let table = DenseTable::with_canaries(x, y, 0).unwrap();

        assert_eq!(
            (0..300)
                .map(|idx| table.binned_value(0, idx))
                .collect::<BTreeSet<_>>()
                .len(),
            128
        );
    }

    /// With eight distinct rows, auto mode uses four bins of at least two rows.
    #[test]
    fn auto_bins_require_at_least_two_rows_per_bin() {
        let x: Vec<Vec<f64>> = (0..8).map(|value| vec![value as f64]).collect();
        let y = vec![0.0; 8];

        let table = DenseTable::with_canaries(x, y, 0).unwrap();
        let counts = (0..table.n_rows()).fold(BTreeMap::new(), |mut counts, row_idx| {
            *counts
                .entry(table.binned_value(0, row_idx))
                .or_insert(0usize) += 1;
            counts
        });

        assert_eq!(counts.len(), 4);
        assert!(counts.values().all(|count| *count >= 2));
    }

    /// A fixed bin request of 64 yields exactly 64 occupied bins for 300
    /// distinct values.
    #[test]
    fn fixed_bins_cap_numeric_columns_to_requested_limit() {
        let x: Vec<Vec<f64>> = (0..300).map(|value| vec![value as f64]).collect();
        let y = vec![0.0; 300];

        let table = DenseTable::with_options(x, y, 0, NumericBins::Fixed(64)).unwrap();

        assert_eq!(
            (0..300)
                .map(|idx| table.binned_value(0, idx))
                .collect::<BTreeSet<_>>()
                .len(),
            64
        );
    }

    /// `NumericBins::fixed` rejects zero and anything above `MAX_NUMERIC_BINS`,
    /// and accepts the maximum itself.
    #[test]
    fn rejects_invalid_fixed_bin_count() {
        assert_eq!(
            NumericBins::fixed(0).unwrap_err(),
            DenseTableError::InvalidBinCount { requested: 0 }
        );
        // Pin the upper boundary exactly, so a change to the limit cannot
        // slip past a far-out-of-range probe.
        assert_eq!(
            NumericBins::fixed(MAX_NUMERIC_BINS + 1).unwrap_err(),
            DenseTableError::InvalidBinCount {
                requested: MAX_NUMERIC_BINS + 1
            }
        );
        assert_eq!(
            NumericBins::fixed(513).unwrap_err(),
            DenseTableError::InvalidBinCount { requested: 513 }
        );
        assert_eq!(
            NumericBins::fixed(MAX_NUMERIC_BINS),
            Ok(NumericBins::Fixed(MAX_NUMERIC_BINS))
        );
    }

    /// Rows holding the same value share a bin, and bin order follows value
    /// order.
    #[test]
    fn keeps_equal_values_in_the_same_bin() {
        let table = DenseTable::with_canaries(
            vec![vec![0.0], vec![0.0], vec![1.0], vec![1.0], vec![2.0]],
            vec![0.0; 5],
            0,
        )
        .unwrap();

        assert_eq!(table.binned_value(0, 0), table.binned_value(0, 1));
        assert_eq!(table.binned_value(0, 2), table.binned_value(0, 3));
        assert!(table.binned_value(0, 1) <= table.binned_value(0, 2));
        assert!(table.binned_value(0, 3) < table.binned_value(0, 4));
    }

    /// Binary features (and their canaries) are exposed as boolean binned
    /// columns; numeric features are not.
    #[test]
    fn stores_binary_columns_as_booleans() {
        let table = DenseTable::with_canaries(
            vec![vec![0.0, 2.0], vec![1.0, 3.0], vec![0.0, 4.0]],
            vec![0.0; 3],
            1,
        )
        .unwrap();

        assert!(table.is_binary_feature(0));
        assert!(!table.is_binary_feature(1));
        assert!(table.is_binary_binned_feature(0));
        assert!(!table.is_binary_binned_feature(1));
        assert!(table.is_binary_binned_feature(2));
        assert_eq!(table.feature_value(0, 0), 0.0);
        assert_eq!(table.feature_value(0, 1), 1.0);
        assert_eq!(table.binned_boolean_value(0, 0), Some(false));
        assert_eq!(table.binned_boolean_value(0, 1), Some(true));
    }

    /// A small auto-binned numeric column is stored with the `u8` variant.
    #[test]
    fn stores_small_auto_binned_numeric_columns_as_u8() {
        let table = DenseTable::with_canaries(
            (0..8).map(|value| vec![value as f64]).collect(),
            vec![0.0; 8],
            0,
        )
        .unwrap();

        assert!(matches!(
            table.binned_feature_column(0),
            BinnedFeatureColumnRef::NumericU8(_)
        ));
    }

    /// A canary column holds the same set of bin values as its source column,
    /// but in a different row order.
    #[test]
    fn creates_canary_columns_as_shuffled_binned_copies() {
        let table = DenseTable::with_canaries(
            vec![vec![0.0], vec![1.0], vec![2.0], vec![3.0], vec![4.0]],
            vec![0.0; 5],
            1,
        )
        .unwrap();

        assert!(matches!(
            table.binned_column_kind(1),
            BinnedColumnKind::Canary {
                source_index: 0,
                copy_index: 0
            }
        ));
        assert_eq!(
            (0..table.n_rows())
                .map(|idx| table.binned_value(0, idx))
                .collect::<BTreeSet<_>>(),
            (0..table.n_rows())
                .map(|idx| table.binned_value(1, idx))
                .collect::<BTreeSet<_>>()
        );
        assert_ne!(
            (0..table.n_rows())
                .map(|idx| table.binned_value(0, idx))
                .collect::<Vec<_>>(),
            (0..table.n_rows())
                .map(|idx| table.binned_value(1, idx))
                .collect::<Vec<_>>()
        );
    }

    /// Rows of differing width are rejected with the offending row reported.
    #[test]
    fn rejects_ragged_rows() {
        let err = DenseTable::new(vec![vec![1.0, 2.0], vec![3.0]], vec![1.0, 2.0]).unwrap_err();

        assert_eq!(
            err,
            DenseTableError::RaggedRows {
                row: 1,
                expected: 2,
                actual: 1,
            }
        );
    }

    /// Feature and target row counts must match.
    #[test]
    fn rejects_mismatched_lengths() {
        let err = DenseTable::new(vec![vec![1.0], vec![2.0]], vec![1.0]).unwrap_err();

        assert_eq!(err, DenseTableError::MismatchedLengths { x: 2, y: 1 });
    }

    /// Building the same table twice yields identical binned columns,
    /// canaries included.
    #[test]
    fn canary_generation_is_deterministic_for_identical_inputs() {
        let left = DenseTable::with_canaries(
            vec![vec![0.0], vec![1.0], vec![2.0], vec![3.0], vec![4.0]],
            vec![0.0; 5],
            2,
        )
        .unwrap();
        let right = DenseTable::with_canaries(
            vec![vec![0.0], vec![1.0], vec![2.0], vec![3.0], vec![4.0]],
            vec![0.0; 5],
            2,
        )
        .unwrap();

        let left_values = binned_snapshot(&left);
        let right_values = binned_snapshot(&right);

        assert_eq!(left_values, right_values);
    }

    /// Canaries of a binary feature stay boolean and keep the same number of
    /// `true` rows as the source feature.
    #[test]
    fn binary_canaries_remain_boolean_and_preserve_value_counts() {
        let table = DenseTable::with_canaries(
            vec![
                vec![0.0],
                vec![1.0],
                vec![0.0],
                vec![1.0],
                vec![1.0],
                vec![0.0],
            ],
            vec![0.0; 6],
            2,
        )
        .unwrap();

        let real_true_count = (0..table.n_rows())
            .filter(|row_idx| table.binned_boolean_value(0, *row_idx) == Some(true))
            .count();

        for feature_index in 1..table.binned_feature_count() {
            assert!(table.is_binary_binned_feature(feature_index));
            let canary_true_count = (0..table.n_rows())
                .filter(|row_idx| table.binned_boolean_value(feature_index, *row_idx) == Some(true))
                .count();
            assert_eq!(canary_true_count, real_true_count);
        }
    }

    /// Boundaries report, per occupied bin, the largest value assigned to it.
    #[test]
    fn numeric_bin_boundaries_capture_training_bin_upper_bounds() {
        let boundaries = numeric_bin_boundaries(&[1.0, 1.0, 2.0, 10.0], NumericBins::Auto);

        assert_eq!(boundaries, vec![(0, 1.0), (1, 10.0)]);
    }

    /// Flattens every binned column of `table` (feature-major, then row) into
    /// a single vector so two tables can be compared with one assertion.
    fn binned_snapshot(table: &DenseTable) -> Vec<u16> {
        let mut values = Vec::new();

        for feature_idx in 0..table.binned_feature_count() {
            for row_idx in 0..table.n_rows() {
                values.push(table.binned_value(feature_idx, row_idx));
            }
        }

        values
    }
}