Skip to main content

forestfire_data/
lib.rs

//! Data layer shared by training and inference.
//!
//! The key design choice in this crate is that both dense and sparse storage are
//! exposed through the same [`TableAccess`] trait, so trainers reason about
//! "rows, features, bins, and targets" rather than about concrete storage
//! formats. This keeps sampling views, tables reconstructed at inference time,
//! and Arrow-backed tables interoperable.
8
9use arrow::array::{Array, BooleanArray, Float64Array, UInt8Array, UInt16Array};
10use rand::seq::SliceRandom;
11use rand::{SeedableRng, rngs::StdRng};
12use std::cmp::Ordering;
13use std::error::Error;
14use std::fmt::{Display, Formatter};
15
/// Upper bound on the number of bins a numeric feature may use; it is also the
/// cap reported for `NumericBins::Auto`.
pub const MAX_NUMERIC_BINS: usize = 128;
/// Canary copies appended per feature by the `new` constructors.
const DEFAULT_CANARIES: usize = 2;
/// Bin index reported for a missing entry in a dense binary column (present
/// entries bin to `0` or `1`).
pub const BINARY_MISSING_BIN: u16 = 2;
19
20type PreprocessedRows = (Vec<Vec<f64>>, Float64Array, usize, usize);
21
22/// Common interface consumed by tree trainers and predictors.
23///
24/// The trait exposes both raw floating-point values and pre-binned values.
25/// Training mostly works on the binned representation for speed and predictable
26/// threshold semantics; the raw values are still available for export and some
27/// user-facing reconstruction tasks.
28pub trait TableAccess: Sync {
29    fn n_rows(&self) -> usize;
30    fn n_features(&self) -> usize;
31    fn canaries(&self) -> usize;
32    fn numeric_bin_cap(&self) -> usize;
33    fn binned_feature_count(&self) -> usize;
34    fn feature_value(&self, feature_index: usize, row_index: usize) -> f64;
35    fn is_missing(&self, feature_index: usize, row_index: usize) -> bool;
36    fn is_binary_feature(&self, index: usize) -> bool;
37    fn binned_value(&self, feature_index: usize, row_index: usize) -> u16;
38    fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool>;
39    fn binned_column_kind(&self, index: usize) -> BinnedColumnKind;
40    fn is_binary_binned_feature(&self, index: usize) -> bool;
41    fn target_value(&self, row_index: usize) -> f64;
42
43    fn is_canary_binned_feature(&self, index: usize) -> bool {
44        matches!(
45            self.binned_column_kind(index),
46            BinnedColumnKind::Canary { .. }
47        )
48    }
49}
50
/// Storage layout backing a `Table`, as reported by `Table::kind`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TableKind {
    /// Column-oriented storage for numeric and/or binary features.
    Dense,
    /// Storage for all-binary features that keeps only the rows holding `1`.
    Sparse,
}
58
59#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
60pub enum NumericBins {
61    /// Use the crate-wide default bin cap.
62    #[default]
63    Auto,
64    /// Use a fixed number of bins for every numeric feature.
65    Fixed(usize),
66}
67
68impl NumericBins {
69    pub fn fixed(requested: usize) -> Result<Self, DenseTableError> {
70        if requested == 0 || requested > MAX_NUMERIC_BINS {
71            return Err(DenseTableError::InvalidBinCount { requested });
72        }
73        Ok(Self::Fixed(requested))
74    }
75
76    pub fn cap(self) -> usize {
77        match self {
78            NumericBins::Auto => MAX_NUMERIC_BINS,
79            NumericBins::Fixed(requested) => requested,
80        }
81    }
82}
83
84/// Arrow-backed dense table for tabular regression/classification data.
85///
86/// Features are ingested once, converted into a raw Arrow representation, and
87/// then binned into compact integer columns. The model trainers work primarily
88/// from those binned columns.
89#[derive(Debug, Clone)]
90pub struct DenseTable {
91    feature_columns: Vec<FeatureColumn>,
92    binned_feature_columns: Vec<BinnedFeatureColumn>,
93    binned_column_kinds: Vec<BinnedColumnKind>,
94    target: Float64Array,
95    n_rows: usize,
96    n_features: usize,
97    canaries: usize,
98    numeric_bins: NumericBins,
99}
100
101/// Arrow-backed sparse table specialized for binary feature matrices.
102///
103/// This is intentionally specialized to binary data. The sparse path is mainly
104/// about avoiding the memory cost of materializing dense one-hot-like matrices.
105#[derive(Debug, Clone)]
106pub struct SparseTable {
107    feature_columns: Vec<SparseBinaryColumn>,
108    binned_feature_columns: Vec<SparseBinaryColumn>,
109    binned_column_kinds: Vec<BinnedColumnKind>,
110    target: Float64Array,
111    n_rows: usize,
112    n_features: usize,
113    canaries: usize,
114    numeric_bins: NumericBins,
115}
116
/// Binary column stored as the list of rows whose value is `true`.
#[derive(Debug, Clone)]
struct SparseBinaryColumn {
    /// Rows holding `true`; must be sorted ascending because lookups use a
    /// binary search.
    row_indices: Vec<usize>,
}

impl SparseBinaryColumn {
    /// Returns `true` when `row_index` is listed; unlisted rows read as `false`.
    fn value(&self, row_index: usize) -> bool {
        let lookup = self.row_indices.binary_search(&row_index);
        matches!(lookup, Ok(_))
    }
}
127
128#[derive(Debug, Clone)]
129pub enum Table {
130    /// Dense numeric/binary table.
131    Dense(DenseTable),
132    /// Sparse binary table.
133    Sparse(SparseTable),
134}
135
136#[derive(Debug, Clone)]
137enum FeatureColumn {
138    Numeric(Float64Array),
139    Binary(BooleanArray),
140}
141
142#[derive(Debug, Clone)]
143enum BinnedFeatureColumn {
144    NumericU8(UInt8Array),
145    NumericU16(UInt16Array),
146    Binary(BooleanArray),
147}
148
149#[derive(Debug, Clone, Copy)]
150pub enum FeatureColumnRef<'a> {
151    Numeric(&'a Float64Array),
152    Binary(&'a BooleanArray),
153}
154
155#[derive(Debug, Clone, Copy)]
156pub enum BinnedFeatureColumnRef<'a> {
157    NumericU8(&'a UInt8Array),
158    NumericU16(&'a UInt16Array),
159    Binary(&'a BooleanArray),
160}
161
/// Provenance of a binned column.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinnedColumnKind {
    /// Binned form of input feature `source_index`.
    Real { source_index: usize },
    /// Canary number `copy_index` built from input feature `source_index`: a
    /// shuffled copy meant to look plausible while carrying no real signal.
    Canary {
        source_index: usize,
        copy_index: usize,
    },
}
172
173#[derive(Debug, Clone, PartialEq, Eq)]
174pub enum DenseTableError {
175    MismatchedLengths {
176        x: usize,
177        y: usize,
178    },
179    RaggedRows {
180        row: usize,
181        expected: usize,
182        actual: usize,
183    },
184    NonBinaryColumn {
185        column: usize,
186    },
187    InvalidBinCount {
188        requested: usize,
189    },
190}
191
192impl Display for DenseTableError {
193    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
194        match self {
195            DenseTableError::MismatchedLengths { x, y } => write!(
196                f,
197                "Mismatched lengths: X has {} rows while y has {} values.",
198                x, y
199            ),
200            DenseTableError::RaggedRows {
201                row,
202                expected,
203                actual,
204            } => write!(
205                f,
206                "Ragged row at index {}: expected {} columns, found {}.",
207                row, expected, actual
208            ),
209            DenseTableError::NonBinaryColumn { column } => write!(
210                f,
211                "SparseTable requires binary features, but column {} contains non-binary values.",
212                column
213            ),
214            DenseTableError::InvalidBinCount { requested } => write!(
215                f,
216                "Invalid bins value {}. Expected 'auto' or an integer between 1 and {}.",
217                requested, MAX_NUMERIC_BINS
218            ),
219        }
220    }
221}
222
223impl Error for DenseTableError {}
224
225impl DenseTable {
226    /// Construct a dense table with the default canary count and automatic
227    /// numeric binning.
228    pub fn new(x: Vec<Vec<f64>>, y: Vec<f64>) -> Result<Self, DenseTableError> {
229        Self::with_canaries(x, y, DEFAULT_CANARIES)
230    }
231
232    /// Construct a dense table while choosing how many canary copies to append.
233    pub fn with_canaries(
234        x: Vec<Vec<f64>>,
235        y: Vec<f64>,
236        canaries: usize,
237    ) -> Result<Self, DenseTableError> {
238        Self::with_options(x, y, canaries, NumericBins::Auto)
239    }
240
241    /// Construct a dense table with full preprocessing control.
242    pub fn with_options(
243        x: Vec<Vec<f64>>,
244        y: Vec<f64>,
245        canaries: usize,
246        numeric_bins: NumericBins,
247    ) -> Result<Self, DenseTableError> {
248        let (columns, target, n_rows, n_features) = preprocess_rows(&x, y)?;
249        Ok(Self::from_columns(
250            &columns,
251            target,
252            n_rows,
253            n_features,
254            canaries,
255            numeric_bins,
256        ))
257    }
258
259    fn from_columns(
260        columns: &[Vec<f64>],
261        target: Float64Array,
262        n_rows: usize,
263        n_features: usize,
264        canaries: usize,
265        numeric_bins: NumericBins,
266    ) -> Self {
267        let feature_columns = columns
268            .iter()
269            .map(|column| build_feature_column(column))
270            .collect();
271
272        let real_binned_columns: Vec<BinnedFeatureColumn> = columns
273            .iter()
274            .map(|column| build_binned_feature_column(column, numeric_bins))
275            .collect();
276        let canary_columns: Vec<(BinnedColumnKind, BinnedFeatureColumn)> = (0..canaries)
277            .flat_map(|copy_index| {
278                real_binned_columns
279                    .iter()
280                    .enumerate()
281                    .map(move |(source_index, column)| {
282                        (
283                            BinnedColumnKind::Canary {
284                                source_index,
285                                copy_index,
286                            },
287                            // Canary columns are shuffled copies of real columns.
288                            // They are designed to look statistically plausible
289                            // while carrying no real signal.
290                            shuffle_canary_column(column, copy_index, source_index),
291                        )
292                    })
293            })
294            .collect();
295
296        let (binned_column_kinds, binned_feature_columns): (Vec<_>, Vec<_>) = (0..n_features)
297            .map(|source_index| BinnedColumnKind::Real { source_index })
298            .zip(real_binned_columns)
299            .chain(canary_columns)
300            .unzip();
301
302        Self {
303            feature_columns,
304            binned_feature_columns,
305            binned_column_kinds,
306            target,
307            n_rows,
308            n_features,
309            canaries,
310            numeric_bins,
311        }
312    }
313
314    #[inline]
315    pub fn n_rows(&self) -> usize {
316        self.n_rows
317    }
318
319    #[inline]
320    pub fn n_features(&self) -> usize {
321        self.n_features
322    }
323
324    #[inline]
325    pub fn canaries(&self) -> usize {
326        self.canaries
327    }
328
329    #[inline]
330    pub fn numeric_bin_cap(&self) -> usize {
331        self.numeric_bins.cap()
332    }
333
334    #[inline]
335    pub fn binned_feature_count(&self) -> usize {
336        self.binned_feature_columns.len()
337    }
338
339    #[inline]
340    pub fn feature_column(&self, index: usize) -> FeatureColumnRef<'_> {
341        match &self.feature_columns[index] {
342            FeatureColumn::Numeric(column) => FeatureColumnRef::Numeric(column),
343            FeatureColumn::Binary(column) => FeatureColumnRef::Binary(column),
344        }
345    }
346
347    #[inline]
348    pub fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
349        match &self.feature_columns[feature_index] {
350            FeatureColumn::Numeric(column) => column.value(row_index),
351            FeatureColumn::Binary(column) => {
352                if column.is_null(row_index) {
353                    f64::NAN
354                } else {
355                    f64::from(u8::from(column.value(row_index)))
356                }
357            }
358        }
359    }
360
361    #[inline]
362    pub fn is_missing(&self, feature_index: usize, row_index: usize) -> bool {
363        match &self.feature_columns[feature_index] {
364            FeatureColumn::Numeric(column) => column.value(row_index).is_nan(),
365            FeatureColumn::Binary(column) => column.is_null(row_index),
366        }
367    }
368
369    #[inline]
370    pub fn is_binary_feature(&self, index: usize) -> bool {
371        matches!(self.feature_columns[index], FeatureColumn::Binary(_))
372    }
373
374    #[inline]
375    pub fn binned_feature_column(&self, index: usize) -> BinnedFeatureColumnRef<'_> {
376        match &self.binned_feature_columns[index] {
377            BinnedFeatureColumn::NumericU8(column) => BinnedFeatureColumnRef::NumericU8(column),
378            BinnedFeatureColumn::NumericU16(column) => BinnedFeatureColumnRef::NumericU16(column),
379            BinnedFeatureColumn::Binary(column) => BinnedFeatureColumnRef::Binary(column),
380        }
381    }
382
383    #[inline]
384    pub fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
385        match &self.binned_feature_columns[feature_index] {
386            BinnedFeatureColumn::NumericU8(column) => u16::from(column.value(row_index)),
387            BinnedFeatureColumn::NumericU16(column) => column.value(row_index),
388            BinnedFeatureColumn::Binary(column) => {
389                if column.is_null(row_index) {
390                    BINARY_MISSING_BIN
391                } else {
392                    u16::from(u8::from(column.value(row_index)))
393                }
394            }
395        }
396    }
397
398    #[inline]
399    pub fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
400        match &self.binned_feature_columns[feature_index] {
401            BinnedFeatureColumn::Binary(column) => {
402                (!column.is_null(row_index)).then(|| column.value(row_index))
403            }
404            BinnedFeatureColumn::NumericU8(_) | BinnedFeatureColumn::NumericU16(_) => None,
405        }
406    }
407
408    #[inline]
409    pub fn binned_column_kind(&self, index: usize) -> BinnedColumnKind {
410        self.binned_column_kinds[index]
411    }
412
413    #[inline]
414    pub fn is_canary_binned_feature(&self, index: usize) -> bool {
415        matches!(
416            self.binned_column_kinds[index],
417            BinnedColumnKind::Canary { .. }
418        )
419    }
420
421    #[inline]
422    pub fn is_binary_binned_feature(&self, index: usize) -> bool {
423        matches!(
424            self.binned_feature_columns[index],
425            BinnedFeatureColumn::Binary(_)
426        )
427    }
428
429    #[inline]
430    pub fn target(&self) -> &Float64Array {
431        &self.target
432    }
433}
434
435impl SparseTable {
436    pub fn new(x: Vec<Vec<f64>>, y: Vec<f64>) -> Result<Self, DenseTableError> {
437        Self::with_canaries(x, y, DEFAULT_CANARIES)
438    }
439
440    pub fn with_canaries(
441        x: Vec<Vec<f64>>,
442        y: Vec<f64>,
443        canaries: usize,
444    ) -> Result<Self, DenseTableError> {
445        Self::with_options(x, y, canaries, NumericBins::Auto)
446    }
447
448    pub fn with_options(
449        x: Vec<Vec<f64>>,
450        y: Vec<f64>,
451        canaries: usize,
452        numeric_bins: NumericBins,
453    ) -> Result<Self, DenseTableError> {
454        let (columns, target, n_rows, n_features) = preprocess_rows(&x, y)?;
455        validate_binary_columns(&columns)?;
456        Ok(Self::from_columns(
457            &columns,
458            target,
459            n_rows,
460            n_features,
461            canaries,
462            numeric_bins,
463        ))
464    }
465
466    fn from_columns(
467        columns: &[Vec<f64>],
468        target: Float64Array,
469        n_rows: usize,
470        n_features: usize,
471        canaries: usize,
472        numeric_bins: NumericBins,
473    ) -> Self {
474        let feature_columns: Vec<SparseBinaryColumn> = columns
475            .iter()
476            .map(|column| sparse_binary_column_from_values(column))
477            .collect();
478
479        let canary_columns: Vec<(BinnedColumnKind, SparseBinaryColumn)> = (0..canaries)
480            .flat_map(|copy_index| {
481                feature_columns
482                    .iter()
483                    .enumerate()
484                    .map(move |(source_index, column)| {
485                        (
486                            BinnedColumnKind::Canary {
487                                source_index,
488                                copy_index,
489                            },
490                            shuffle_sparse_binary_column(column, n_rows, copy_index, source_index),
491                        )
492                    })
493            })
494            .collect();
495
496        let (binned_column_kinds, binned_feature_columns): (Vec<_>, Vec<_>) = (0..n_features)
497            .map(|source_index| BinnedColumnKind::Real { source_index })
498            .zip(feature_columns.iter().cloned())
499            .chain(canary_columns)
500            .unzip();
501
502        Self {
503            feature_columns,
504            binned_feature_columns,
505            binned_column_kinds,
506            target,
507            n_rows,
508            n_features,
509            canaries,
510            numeric_bins,
511        }
512    }
513
514    pub fn from_sparse_binary_columns(
515        n_rows: usize,
516        n_features: usize,
517        columns: Vec<Vec<usize>>,
518        y: Vec<f64>,
519        canaries: usize,
520    ) -> Result<Self, DenseTableError> {
521        Self::from_sparse_binary_columns_with_options(
522            n_rows,
523            n_features,
524            columns,
525            y,
526            canaries,
527            NumericBins::Auto,
528        )
529    }
530
531    pub fn from_sparse_binary_columns_with_options(
532        n_rows: usize,
533        n_features: usize,
534        columns: Vec<Vec<usize>>,
535        y: Vec<f64>,
536        canaries: usize,
537        numeric_bins: NumericBins,
538    ) -> Result<Self, DenseTableError> {
539        if n_rows != y.len() {
540            return Err(DenseTableError::MismatchedLengths {
541                x: n_rows,
542                y: y.len(),
543            });
544        }
545        if n_features != columns.len() {
546            return Err(DenseTableError::RaggedRows {
547                row: columns.len(),
548                expected: n_features,
549                actual: columns.len(),
550            });
551        }
552
553        let feature_columns = columns
554            .into_iter()
555            .enumerate()
556            .map(|(column_idx, mut row_indices)| {
557                row_indices.sort_unstable();
558                row_indices.dedup();
559                if row_indices.iter().any(|row_idx| *row_idx >= n_rows) {
560                    return Err(DenseTableError::NonBinaryColumn { column: column_idx });
561                }
562                Ok(SparseBinaryColumn { row_indices })
563            })
564            .collect::<Result<Vec<_>, _>>()?;
565
566        let canary_columns: Vec<(BinnedColumnKind, SparseBinaryColumn)> = (0..canaries)
567            .flat_map(|copy_index| {
568                feature_columns
569                    .iter()
570                    .enumerate()
571                    .map(move |(source_index, column)| {
572                        (
573                            BinnedColumnKind::Canary {
574                                source_index,
575                                copy_index,
576                            },
577                            shuffle_sparse_binary_column(column, n_rows, copy_index, source_index),
578                        )
579                    })
580            })
581            .collect();
582
583        let (binned_column_kinds, binned_feature_columns): (Vec<_>, Vec<_>) = (0..n_features)
584            .map(|source_index| BinnedColumnKind::Real { source_index })
585            .zip(feature_columns.iter().cloned())
586            .chain(canary_columns)
587            .unzip();
588
589        Ok(Self {
590            feature_columns,
591            binned_feature_columns,
592            binned_column_kinds,
593            target: Float64Array::from(y),
594            n_rows,
595            n_features,
596            canaries,
597            numeric_bins,
598        })
599    }
600
601    #[inline]
602    pub fn n_rows(&self) -> usize {
603        self.n_rows
604    }
605
606    #[inline]
607    pub fn n_features(&self) -> usize {
608        self.n_features
609    }
610
611    #[inline]
612    pub fn canaries(&self) -> usize {
613        self.canaries
614    }
615
616    #[inline]
617    pub fn numeric_bin_cap(&self) -> usize {
618        self.numeric_bins.cap()
619    }
620
621    #[inline]
622    pub fn binned_feature_count(&self) -> usize {
623        self.binned_feature_columns.len()
624    }
625
626    #[inline]
627    pub fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
628        f64::from(u8::from(
629            self.feature_columns[feature_index].value(row_index),
630        ))
631    }
632
633    #[inline]
634    pub fn is_missing(&self, _feature_index: usize, _row_index: usize) -> bool {
635        false
636    }
637
638    #[inline]
639    pub fn is_binary_feature(&self, _index: usize) -> bool {
640        true
641    }
642
643    #[inline]
644    pub fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
645        u16::from(u8::from(
646            self.binned_feature_columns[feature_index].value(row_index),
647        ))
648    }
649
650    #[inline]
651    pub fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
652        Some(self.binned_feature_columns[feature_index].value(row_index))
653    }
654
655    #[inline]
656    pub fn binned_column_kind(&self, index: usize) -> BinnedColumnKind {
657        self.binned_column_kinds[index]
658    }
659
660    #[inline]
661    pub fn is_canary_binned_feature(&self, index: usize) -> bool {
662        matches!(
663            self.binned_column_kinds[index],
664            BinnedColumnKind::Canary { .. }
665        )
666    }
667
668    #[inline]
669    pub fn is_binary_binned_feature(&self, _index: usize) -> bool {
670        true
671    }
672
673    #[inline]
674    pub fn target(&self) -> &Float64Array {
675        &self.target
676    }
677}
678
679impl Table {
680    pub fn new(x: Vec<Vec<f64>>, y: Vec<f64>) -> Result<Self, DenseTableError> {
681        Self::with_canaries(x, y, DEFAULT_CANARIES)
682    }
683
684    pub fn with_canaries(
685        x: Vec<Vec<f64>>,
686        y: Vec<f64>,
687        canaries: usize,
688    ) -> Result<Self, DenseTableError> {
689        Self::with_options(x, y, canaries, NumericBins::Auto)
690    }
691
692    pub fn with_options(
693        x: Vec<Vec<f64>>,
694        y: Vec<f64>,
695        canaries: usize,
696        numeric_bins: NumericBins,
697    ) -> Result<Self, DenseTableError> {
698        let (columns, target, n_rows, n_features) = preprocess_rows(&x, y)?;
699
700        if columns
701            .iter()
702            .all(|column| is_sparse_binary_eligible_column(column))
703        {
704            Ok(Self::Sparse(SparseTable::from_columns(
705                &columns,
706                target,
707                n_rows,
708                n_features,
709                canaries,
710                numeric_bins,
711            )))
712        } else {
713            Ok(Self::Dense(DenseTable::from_columns(
714                &columns,
715                target,
716                n_rows,
717                n_features,
718                canaries,
719                numeric_bins,
720            )))
721        }
722    }
723
724    pub fn kind(&self) -> TableKind {
725        match self {
726            Table::Dense(_) => TableKind::Dense,
727            Table::Sparse(_) => TableKind::Sparse,
728        }
729    }
730
731    pub fn as_dense(&self) -> Option<&DenseTable> {
732        match self {
733            Table::Dense(table) => Some(table),
734            Table::Sparse(_) => None,
735        }
736    }
737
738    pub fn as_sparse(&self) -> Option<&SparseTable> {
739        match self {
740            Table::Dense(_) => None,
741            Table::Sparse(table) => Some(table),
742        }
743    }
744}
745
746impl TableAccess for DenseTable {
747    fn n_rows(&self) -> usize {
748        self.n_rows()
749    }
750
751    fn n_features(&self) -> usize {
752        self.n_features()
753    }
754
755    fn canaries(&self) -> usize {
756        self.canaries()
757    }
758
759    fn numeric_bin_cap(&self) -> usize {
760        self.numeric_bin_cap()
761    }
762
763    fn binned_feature_count(&self) -> usize {
764        self.binned_feature_count()
765    }
766
767    fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
768        self.feature_value(feature_index, row_index)
769    }
770
771    fn is_missing(&self, feature_index: usize, row_index: usize) -> bool {
772        self.is_missing(feature_index, row_index)
773    }
774
775    fn is_binary_feature(&self, index: usize) -> bool {
776        self.is_binary_feature(index)
777    }
778
779    fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
780        self.binned_value(feature_index, row_index)
781    }
782
783    fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
784        self.binned_boolean_value(feature_index, row_index)
785    }
786
787    fn binned_column_kind(&self, index: usize) -> BinnedColumnKind {
788        self.binned_column_kind(index)
789    }
790
791    fn is_binary_binned_feature(&self, index: usize) -> bool {
792        self.is_binary_binned_feature(index)
793    }
794
795    fn target_value(&self, row_index: usize) -> f64 {
796        self.target().value(row_index)
797    }
798}
799
800impl TableAccess for SparseTable {
801    fn n_rows(&self) -> usize {
802        self.n_rows()
803    }
804
805    fn n_features(&self) -> usize {
806        self.n_features()
807    }
808
809    fn canaries(&self) -> usize {
810        self.canaries()
811    }
812
813    fn numeric_bin_cap(&self) -> usize {
814        self.numeric_bin_cap()
815    }
816
817    fn binned_feature_count(&self) -> usize {
818        self.binned_feature_count()
819    }
820
821    fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
822        self.feature_value(feature_index, row_index)
823    }
824
825    fn is_missing(&self, feature_index: usize, row_index: usize) -> bool {
826        self.is_missing(feature_index, row_index)
827    }
828
829    fn is_binary_feature(&self, index: usize) -> bool {
830        self.is_binary_feature(index)
831    }
832
833    fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
834        self.binned_value(feature_index, row_index)
835    }
836
837    fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
838        self.binned_boolean_value(feature_index, row_index)
839    }
840
841    fn binned_column_kind(&self, index: usize) -> BinnedColumnKind {
842        self.binned_column_kind(index)
843    }
844
845    fn is_binary_binned_feature(&self, index: usize) -> bool {
846        self.is_binary_binned_feature(index)
847    }
848
849    fn target_value(&self, row_index: usize) -> f64 {
850        self.target().value(row_index)
851    }
852}
853
854impl TableAccess for Table {
855    fn n_rows(&self) -> usize {
856        match self {
857            Table::Dense(table) => table.n_rows(),
858            Table::Sparse(table) => table.n_rows(),
859        }
860    }
861
862    fn n_features(&self) -> usize {
863        match self {
864            Table::Dense(table) => table.n_features(),
865            Table::Sparse(table) => table.n_features(),
866        }
867    }
868
869    fn canaries(&self) -> usize {
870        match self {
871            Table::Dense(table) => table.canaries(),
872            Table::Sparse(table) => table.canaries(),
873        }
874    }
875
876    fn numeric_bin_cap(&self) -> usize {
877        match self {
878            Table::Dense(table) => table.numeric_bin_cap(),
879            Table::Sparse(table) => table.numeric_bin_cap(),
880        }
881    }
882
883    fn binned_feature_count(&self) -> usize {
884        match self {
885            Table::Dense(table) => table.binned_feature_count(),
886            Table::Sparse(table) => table.binned_feature_count(),
887        }
888    }
889
890    fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
891        match self {
892            Table::Dense(table) => table.feature_value(feature_index, row_index),
893            Table::Sparse(table) => table.feature_value(feature_index, row_index),
894        }
895    }
896
897    fn is_missing(&self, feature_index: usize, row_index: usize) -> bool {
898        match self {
899            Table::Dense(table) => table.is_missing(feature_index, row_index),
900            Table::Sparse(table) => table.is_missing(feature_index, row_index),
901        }
902    }
903
904    fn is_binary_feature(&self, index: usize) -> bool {
905        match self {
906            Table::Dense(table) => table.is_binary_feature(index),
907            Table::Sparse(table) => table.is_binary_feature(index),
908        }
909    }
910
911    fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
912        match self {
913            Table::Dense(table) => table.binned_value(feature_index, row_index),
914            Table::Sparse(table) => table.binned_value(feature_index, row_index),
915        }
916    }
917
918    fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
919        match self {
920            Table::Dense(table) => table.binned_boolean_value(feature_index, row_index),
921            Table::Sparse(table) => table.binned_boolean_value(feature_index, row_index),
922        }
923    }
924
925    fn binned_column_kind(&self, index: usize) -> BinnedColumnKind {
926        match self {
927            Table::Dense(table) => table.binned_column_kind(index),
928            Table::Sparse(table) => table.binned_column_kind(index),
929        }
930    }
931
932    fn is_binary_binned_feature(&self, index: usize) -> bool {
933        match self {
934            Table::Dense(table) => table.is_binary_binned_feature(index),
935            Table::Sparse(table) => table.is_binary_binned_feature(index),
936        }
937    }
938
939    fn target_value(&self, row_index: usize) -> f64 {
940        match self {
941            Table::Dense(table) => table.target().value(row_index),
942            Table::Sparse(table) => table.target().value(row_index),
943        }
944    }
945}
946
947fn preprocess_rows(x: &[Vec<f64>], y: Vec<f64>) -> Result<PreprocessedRows, DenseTableError> {
948    validate_shape(x, &y)?;
949    let n_rows = x.len();
950    let n_features = x.first().map_or(0, Vec::len);
951    let columns = collect_columns(x, n_features);
952    Ok((columns, Float64Array::from(y), n_rows, n_features))
953}
954
955fn validate_shape(x: &[Vec<f64>], y: &[f64]) -> Result<(), DenseTableError> {
956    if x.len() != y.len() {
957        return Err(DenseTableError::MismatchedLengths {
958            x: x.len(),
959            y: y.len(),
960        });
961    }
962
963    let n_features = x.first().map_or(0, Vec::len);
964    for (row_idx, row) in x.iter().enumerate() {
965        if row.len() != n_features {
966            return Err(DenseTableError::RaggedRows {
967                row: row_idx,
968                expected: n_features,
969                actual: row.len(),
970            });
971        }
972    }
973
974    Ok(())
975}
976
/// Transposes row-major `x` into `n_features` column vectors.
///
/// # Panics
///
/// Panics if any row is shorter than `n_features`; callers validate the shape
/// first.
fn collect_columns(x: &[Vec<f64>], n_features: usize) -> Vec<Vec<f64>> {
    let mut columns = Vec::with_capacity(n_features);
    for feature in 0..n_features {
        let column: Vec<f64> = x.iter().map(|row| row[feature]).collect();
        columns.push(column);
    }
    columns
}
982
983fn validate_binary_columns(columns: &[Vec<f64>]) -> Result<(), DenseTableError> {
984    for (column_idx, column) in columns.iter().enumerate() {
985        if !is_sparse_binary_eligible_column(column) {
986            return Err(DenseTableError::NonBinaryColumn { column: column_idx });
987        }
988    }
989
990    Ok(())
991}
992
993fn build_feature_column(values: &[f64]) -> FeatureColumn {
994    if is_binary_column(values) {
995        FeatureColumn::Binary(BooleanArray::from(to_binary_values(values)))
996    } else {
997        FeatureColumn::Numeric(Float64Array::from(values.to_vec()))
998    }
999}
1000
1001fn build_binned_feature_column(values: &[f64], numeric_bins: NumericBins) -> BinnedFeatureColumn {
1002    if is_binary_column(values) {
1003        BinnedFeatureColumn::Binary(BooleanArray::from(to_binary_values(values)))
1004    } else {
1005        let bins = bin_numeric_column(values, numeric_bins);
1006        if bins.iter().all(|value| *value <= u16::from(u8::MAX)) {
1007            BinnedFeatureColumn::NumericU8(UInt8Array::from(
1008                bins.into_iter()
1009                    .map(|value| value as u8)
1010                    .collect::<Vec<_>>(),
1011            ))
1012        } else {
1013            BinnedFeatureColumn::NumericU16(UInt16Array::from(bins))
1014        }
1015    }
1016}
1017
/// Returns `true` when every entry of `values` is `0`, `1`, or NaN (missing).
///
/// Such columns are stored as nullable booleans instead of numeric values.
/// An empty column is vacuously binary.
///
/// Zero is detected on `value.abs()` so that negative zero (`-0.0`, easily
/// produced by arithmetic such as `0.0 * -1.0`) counts as `0`. A plain
/// `total_cmp(&0.0)` orders `-0.0` below `+0.0` and would misclassify the
/// whole column as non-binary.
fn is_binary_column(values: &[f64]) -> bool {
    values.iter().all(|value| {
        value.is_nan() || value.abs().total_cmp(&0.0).is_eq() || value.total_cmp(&1.0).is_eq()
    })
}
1025
/// Returns `true` when every entry of `values` is exactly `0` or `1`.
///
/// Unlike `is_binary_column`, missing values (NaN) are rejected, because the
/// sparse representation records only the rows that equal `1`.
///
/// Zero is detected on `value.abs()` so that negative zero (`-0.0`) is accepted
/// as `0`. A plain `total_cmp(&0.0)` orders `-0.0` below `+0.0` and would
/// wrongly reject the column.
fn is_sparse_binary_eligible_column(values: &[f64]) -> bool {
    values
        .iter()
        .all(|value| value.abs().total_cmp(&0.0).is_eq() || value.total_cmp(&1.0).is_eq())
}
1032
/// Converts a binary-valued column into nullable booleans.
///
/// NaN becomes `None`; an entry equal to `1.0` becomes `Some(true)`; any other
/// value becomes `Some(false)`.
fn to_binary_values(values: &[f64]) -> Vec<Option<bool>> {
    let mut flags = Vec::with_capacity(values.len());
    for value in values {
        let flag = if value.is_nan() {
            None
        } else {
            Some(value.total_cmp(&1.0).is_eq())
        };
        flags.push(flag);
    }
    flags
}
1045
1046fn sparse_binary_column_from_values(values: &[f64]) -> SparseBinaryColumn {
1047    SparseBinaryColumn {
1048        row_indices: values
1049            .iter()
1050            .enumerate()
1051            .filter_map(|(row_idx, value)| {
1052                (value.total_cmp(&1.0) == Ordering::Equal).then_some(row_idx)
1053            })
1054            .collect(),
1055    }
1056}
1057
1058pub fn numeric_bin_boundaries(values: &[f64], numeric_bins: NumericBins) -> Vec<(u16, f64)> {
1059    if values.is_empty() {
1060        return Vec::new();
1061    }
1062
1063    let mut ranked_values: Vec<(usize, f64)> = values
1064        .iter()
1065        .copied()
1066        .enumerate()
1067        .filter(|(_, value)| !value.is_nan())
1068        .collect();
1069    if ranked_values.is_empty() {
1070        return Vec::new();
1071    }
1072    ranked_values.sort_by(|left, right| left.1.total_cmp(&right.1));
1073
1074    let unique_value_count = ranked_values
1075        .iter()
1076        .map(|(_row_idx, value)| *value)
1077        .fold(Vec::<f64>::new(), |mut unique_values, value| {
1078            let is_new_value = unique_values
1079                .last()
1080                .is_none_or(|last_value| last_value.total_cmp(&value) != Ordering::Equal);
1081            if is_new_value {
1082                unique_values.push(value);
1083            }
1084            unique_values
1085        })
1086        .len();
1087
1088    let bin_count =
1089        resolved_numeric_bin_count(ranked_values.len(), unique_value_count, numeric_bins);
1090    let mut unique_rank = 0usize;
1091    let mut start = 0usize;
1092    let mut boundaries = Vec::new();
1093
1094    while start < ranked_values.len() {
1095        let current_value = ranked_values[start].1;
1096        let end = ranked_values[start..]
1097            .iter()
1098            .position(|(_row_idx, value)| value.total_cmp(&current_value) != Ordering::Equal)
1099            .map_or(ranked_values.len(), |offset| start + offset);
1100
1101        let bin = match numeric_bins {
1102            NumericBins::Auto => ((start * bin_count) / ranked_values.len()) as u16,
1103            NumericBins::Fixed(_) => {
1104                let max_bin = (bin_count - 1) as u16;
1105                if unique_value_count == 1 {
1106                    0
1107                } else {
1108                    ((unique_rank * usize::from(max_bin)) / (unique_value_count - 1)) as u16
1109                }
1110            }
1111        };
1112
1113        if let Some((last_bin, last_upper_bound)) = boundaries.last_mut() {
1114            if *last_bin == bin {
1115                *last_upper_bound = current_value;
1116            } else {
1117                boundaries.push((bin, current_value));
1118            }
1119        } else {
1120            boundaries.push((bin, current_value));
1121        }
1122
1123        unique_rank += 1;
1124        start = end;
1125    }
1126
1127    boundaries
1128}
1129
1130fn bin_numeric_column(values: &[f64], numeric_bins: NumericBins) -> Vec<u16> {
1131    if values.is_empty() {
1132        return Vec::new();
1133    }
1134
1135    let missing_bin = numeric_missing_bin(numeric_bins);
1136    let mut bins = vec![missing_bin; values.len()];
1137    let mut ranked_values: Vec<(usize, f64)> = values
1138        .iter()
1139        .copied()
1140        .enumerate()
1141        .filter(|(_, value)| !value.is_nan())
1142        .collect();
1143    if ranked_values.is_empty() {
1144        return bins;
1145    }
1146    ranked_values.sort_by(|left, right| left.1.total_cmp(&right.1));
1147
1148    let unique_value_count = ranked_values
1149        .iter()
1150        .map(|(_row_idx, value)| *value)
1151        .fold(Vec::<f64>::new(), |mut unique_values, value| {
1152            let is_new_value = unique_values
1153                .last()
1154                .is_none_or(|last_value| last_value.total_cmp(&value) != Ordering::Equal);
1155            if is_new_value {
1156                unique_values.push(value);
1157            }
1158            unique_values
1159        })
1160        .len();
1161
1162    let bin_count =
1163        resolved_numeric_bin_count(ranked_values.len(), unique_value_count, numeric_bins);
1164    let mut unique_rank = 0usize;
1165    let mut start = 0usize;
1166
1167    while start < ranked_values.len() {
1168        let current_value = ranked_values[start].1;
1169        let end = ranked_values[start..]
1170            .iter()
1171            .position(|(_row_idx, value)| value.total_cmp(&current_value) != Ordering::Equal)
1172            .map_or(ranked_values.len(), |offset| start + offset);
1173
1174        let bin = match numeric_bins {
1175            NumericBins::Auto => ((start * bin_count) / ranked_values.len()) as u16,
1176            NumericBins::Fixed(_) => {
1177                let max_bin = (bin_count - 1) as u16;
1178                if unique_value_count == 1 {
1179                    0
1180                } else {
1181                    ((unique_rank * usize::from(max_bin)) / (unique_value_count - 1)) as u16
1182                }
1183            }
1184        };
1185
1186        for (row_idx, _value) in &ranked_values[start..end] {
1187            bins[*row_idx] = bin;
1188        }
1189
1190        unique_rank += 1;
1191        start = end;
1192    }
1193
1194    bins
1195}
1196
1197fn resolved_numeric_bin_count(
1198    value_count: usize,
1199    unique_value_count: usize,
1200    numeric_bins: NumericBins,
1201) -> usize {
1202    match numeric_bins {
1203        NumericBins::Auto => {
1204            let populated_bin_cap = (value_count / 2).max(1);
1205            let capped_unique_values = unique_value_count
1206                .min(MAX_NUMERIC_BINS)
1207                .min(populated_bin_cap)
1208                .max(1);
1209            highest_power_of_two_at_most(capped_unique_values)
1210        }
1211        NumericBins::Fixed(requested) => requested.min(unique_value_count).max(1),
1212    }
1213}
1214
1215pub fn numeric_missing_bin(numeric_bins: NumericBins) -> u16 {
1216    numeric_bins.cap() as u16
1217}
1218
/// Returns the largest power of two that is `<= value`, or `1` when `value` is
/// `0` or `1`.
fn highest_power_of_two_at_most(value: usize) -> usize {
    // `checked_ilog2` is `None` only for zero; for every positive value,
    // `1 << floor(log2(value))` is the highest power of two not above it.
    value.checked_ilog2().map_or(1, |exponent| 1usize << exponent)
}
1226
1227fn shuffle_canary_column(
1228    values: &BinnedFeatureColumn,
1229    copy_index: usize,
1230    source_index: usize,
1231) -> BinnedFeatureColumn {
1232    match values {
1233        BinnedFeatureColumn::NumericU8(values) => {
1234            let mut shuffled = (0..values.len())
1235                .map(|idx| values.value(idx))
1236                .collect::<Vec<_>>();
1237            shuffle_values(&mut shuffled, copy_index, source_index);
1238            BinnedFeatureColumn::NumericU8(UInt8Array::from(shuffled))
1239        }
1240        BinnedFeatureColumn::NumericU16(values) => {
1241            let mut shuffled = (0..values.len())
1242                .map(|idx| values.value(idx))
1243                .collect::<Vec<_>>();
1244            shuffle_values(&mut shuffled, copy_index, source_index);
1245            BinnedFeatureColumn::NumericU16(UInt16Array::from(shuffled))
1246        }
1247        BinnedFeatureColumn::Binary(values) => {
1248            BinnedFeatureColumn::Binary(shuffle_boolean_array(values, copy_index, source_index))
1249        }
1250    }
1251}
1252
1253fn shuffle_boolean_array(
1254    values: &BooleanArray,
1255    copy_index: usize,
1256    source_index: usize,
1257) -> BooleanArray {
1258    let mut shuffled = (0..values.len())
1259        .map(|idx| values.value(idx))
1260        .collect::<Vec<_>>();
1261    shuffle_values(&mut shuffled, copy_index, source_index);
1262    BooleanArray::from(shuffled)
1263}
1264
1265fn shuffle_sparse_binary_column(
1266    values: &SparseBinaryColumn,
1267    n_rows: usize,
1268    copy_index: usize,
1269    source_index: usize,
1270) -> SparseBinaryColumn {
1271    let mut dense = vec![false; n_rows];
1272    for row_idx in &values.row_indices {
1273        dense[*row_idx] = true;
1274    }
1275    shuffle_values(&mut dense, copy_index, source_index);
1276    SparseBinaryColumn {
1277        row_indices: dense
1278            .into_iter()
1279            .enumerate()
1280            .filter_map(|(row_idx, value)| value.then_some(row_idx))
1281            .collect(),
1282    }
1283}
1284
1285fn shuffle_values<T>(values: &mut [T], copy_index: usize, source_index: usize) {
1286    let seed = 0xA11CE5EED_u64
1287        ^ ((copy_index as u64) << 32)
1288        ^ (source_index as u64)
1289        ^ ((values.len() as u64) << 16);
1290    let mut rng = StdRng::seed_from_u64(seed);
1291    values.shuffle(&mut rng);
1292}
1293
#[cfg(test)]
mod tests {
    // Unit tests for table construction, numeric binning, and canary generation.
    // Several expected constants below follow directly from the binning rules in
    // `resolved_numeric_bin_count`; the comments spell out the arithmetic.
    use super::*;
    use std::collections::{BTreeMap, BTreeSet};

    #[test]
    fn builds_arrow_backed_dense_table() {
        let table =
            DenseTable::new(vec![vec![0.0, 10.0], vec![1.0, 20.0]], vec![3.0, 5.0]).unwrap();

        // These assertions expect two canaries per feature (matching
        // DEFAULT_CANARIES): 2 real features + 2 * 2 canaries = 6 binned columns,
        // with the real features first and canaries starting at binned index 2.
        assert_eq!(table.n_rows(), 2);
        assert_eq!(table.n_features(), 2);
        assert_eq!(table.canaries(), 2);
        assert_eq!(table.binned_feature_count(), 6);
        assert_eq!(table.feature_value(0, 0), 0.0);
        assert_eq!(table.feature_value(0, 1), 1.0);
        assert_eq!(table.target().value(0), 3.0);
        assert_eq!(table.target().value(1), 5.0);
        assert!(!table.is_canary_binned_feature(0));
        assert!(table.is_canary_binned_feature(2));
    }

    #[test]
    fn builds_sparse_table_for_all_binary_features() {
        let table = Table::with_canaries(
            vec![vec![0.0, 1.0], vec![1.0, 0.0], vec![1.0, 1.0]],
            vec![0.0, 1.0, 1.0],
            1,
        )
        .unwrap();

        // Every feature is 0/1, so the sparse representation is chosen.
        // 2 features + 2 * 1 canary = 4 binned columns.
        assert_eq!(table.kind(), TableKind::Sparse);
        assert!(table.is_binary_feature(0));
        assert!(table.is_binary_feature(1));
        assert!(table.is_binary_binned_feature(0));
        assert_eq!(table.binned_feature_count(), 4);
    }

    #[test]
    fn builds_dense_table_when_any_feature_is_non_binary() {
        // Feature 1 contains 1.5 and 2.0, which forces the dense representation.
        let table = Table::with_canaries(
            vec![vec![0.0, 1.5], vec![1.0, 0.0], vec![1.0, 2.0]],
            vec![0.0, 1.0, 1.0],
            1,
        )
        .unwrap();

        assert_eq!(table.kind(), TableKind::Dense);
        assert!(table.is_binary_feature(0));
        assert!(!table.is_binary_feature(1));
    }

    #[test]
    fn sparse_table_rejects_non_binary_columns() {
        // Column 1 contains 2.0, which `validate_binary_columns` must reject.
        let err =
            SparseTable::with_canaries(vec![vec![0.0, 2.0], vec![1.0, 0.0]], vec![0.0, 1.0], 0)
                .unwrap_err();

        assert_eq!(err, DenseTableError::NonBinaryColumn { column: 1 });
    }

    #[test]
    fn auto_bins_numeric_columns_into_power_of_two_bins_up_to_128() {
        // 1024 distinct values: the two-rows-per-bin cap is 512, so the
        // MAX_NUMERIC_BINS cap (128) wins and exactly 128 ordered bins are used.
        let x: Vec<Vec<f64>> = (0..1024).map(|value| vec![value as f64]).collect();
        let y: Vec<f64> = vec![1.0; 1024];

        let table = DenseTable::with_canaries(x, y, 0).unwrap();

        assert_eq!(table.binned_value(0, 0), 0);
        assert_eq!(table.binned_value(0, 1023), 127);
        assert!((1..1024).all(|idx| table.binned_value(0, idx - 1) <= table.binned_value(0, idx)));
        assert_eq!(
            (0..1024)
                .map(|idx| table.binned_value(0, idx))
                .collect::<BTreeSet<_>>()
                .len(),
            128
        );
    }

    #[test]
    fn auto_bins_choose_highest_populated_power_of_two() {
        // 300 distinct values: min(300, 128, 300 / 2 = 150) = 128, already a power of two.
        let x: Vec<Vec<f64>> = (0..300).map(|value| vec![value as f64]).collect();
        let y = vec![0.0; 300];

        let table = DenseTable::with_canaries(x, y, 0).unwrap();

        assert_eq!(
            (0..300)
                .map(|idx| table.binned_value(0, idx))
                .collect::<BTreeSet<_>>()
                .len(),
            128
        );
    }

    #[test]
    fn auto_bins_require_at_least_two_rows_per_bin() {
        // 8 distinct values: min(8, 128, 8 / 2 = 4) = 4 bins of two rows each.
        let x: Vec<Vec<f64>> = (0..8).map(|value| vec![value as f64]).collect();
        let y = vec![0.0; 8];

        let table = DenseTable::with_canaries(x, y, 0).unwrap();
        let counts = (0..table.n_rows()).fold(BTreeMap::new(), |mut counts, row_idx| {
            *counts
                .entry(table.binned_value(0, row_idx))
                .or_insert(0usize) += 1;
            counts
        });

        assert_eq!(counts.len(), 4);
        assert!(counts.values().all(|count| *count >= 2));
    }

    #[test]
    fn fixed_bins_cap_numeric_columns_to_requested_limit() {
        // Fixed(64) with 300 distinct values: min(64, 300) = 64 bins.
        let x: Vec<Vec<f64>> = (0..300).map(|value| vec![value as f64]).collect();
        let y = vec![0.0; 300];

        let table = DenseTable::with_options(x, y, 0, NumericBins::Fixed(64)).unwrap();

        assert_eq!(
            (0..300)
                .map(|idx| table.binned_value(0, idx))
                .collect::<BTreeSet<_>>()
                .len(),
            64
        );
    }

    #[test]
    fn rejects_invalid_fixed_bin_count() {
        // The validated constructor rejects both 0 and 513 (the tests pin these two
        // values; the exact accepted range is defined by `NumericBins::fixed`).
        assert_eq!(
            NumericBins::fixed(0).unwrap_err(),
            DenseTableError::InvalidBinCount { requested: 0 }
        );
        assert_eq!(
            NumericBins::fixed(513).unwrap_err(),
            DenseTableError::InvalidBinCount { requested: 513 }
        );
    }

    #[test]
    fn keeps_equal_values_in_the_same_bin() {
        // Ties must never be split across bins, and bins must not decrease with value.
        let table = DenseTable::with_canaries(
            vec![vec![0.0], vec![0.0], vec![1.0], vec![1.0], vec![2.0]],
            vec![0.0; 5],
            0,
        )
        .unwrap();

        assert_eq!(table.binned_value(0, 0), table.binned_value(0, 1));
        assert_eq!(table.binned_value(0, 2), table.binned_value(0, 3));
        assert!(table.binned_value(0, 1) <= table.binned_value(0, 2));
        assert!(table.binned_value(0, 3) < table.binned_value(0, 4));
    }

    #[test]
    fn stores_binary_columns_as_booleans() {
        // Feature 0 is 0/1 (binary); feature 1 is numeric. With one canary per
        // feature, binned index 2 is feature 0's canary and must stay binary.
        let table = DenseTable::with_canaries(
            vec![vec![0.0, 2.0], vec![1.0, 3.0], vec![0.0, 4.0]],
            vec![0.0; 3],
            1,
        )
        .unwrap();

        assert!(table.is_binary_feature(0));
        assert!(!table.is_binary_feature(1));
        assert!(table.is_binary_binned_feature(0));
        assert!(!table.is_binary_binned_feature(1));
        assert!(table.is_binary_binned_feature(2));
        assert_eq!(table.feature_value(0, 0), 0.0);
        assert_eq!(table.feature_value(0, 1), 1.0);
        assert_eq!(table.binned_boolean_value(0, 0), Some(false));
        assert_eq!(table.binned_boolean_value(0, 1), Some(true));
    }

    #[test]
    fn stores_small_auto_binned_numeric_columns_as_u8() {
        // Auto binning of 8 rows yields at most 4 bins plus the missing bin, all <= 255.
        let table = DenseTable::with_canaries(
            (0..8).map(|value| vec![value as f64]).collect(),
            vec![0.0; 8],
            0,
        )
        .unwrap();

        assert!(matches!(
            table.binned_feature_column(0),
            BinnedFeatureColumnRef::NumericU8(_)
        ));
    }

    #[test]
    fn creates_canary_columns_as_shuffled_binned_copies() {
        // The canary must hold the same bin values as its source but in a
        // different row order; the fixed seed makes this outcome deterministic.
        let table = DenseTable::with_canaries(
            vec![vec![0.0], vec![1.0], vec![2.0], vec![3.0], vec![4.0]],
            vec![0.0; 5],
            1,
        )
        .unwrap();

        assert!(matches!(
            table.binned_column_kind(1),
            BinnedColumnKind::Canary {
                source_index: 0,
                copy_index: 0
            }
        ));
        assert_eq!(
            (0..table.n_rows())
                .map(|idx| table.binned_value(0, idx))
                .collect::<BTreeSet<_>>(),
            (0..table.n_rows())
                .map(|idx| table.binned_value(1, idx))
                .collect::<BTreeSet<_>>()
        );
        assert_ne!(
            (0..table.n_rows())
                .map(|idx| table.binned_value(0, idx))
                .collect::<Vec<_>>(),
            (0..table.n_rows())
                .map(|idx| table.binned_value(1, idx))
                .collect::<Vec<_>>()
        );
    }

    #[test]
    fn rejects_ragged_rows() {
        let err = DenseTable::new(vec![vec![1.0, 2.0], vec![3.0]], vec![1.0, 2.0]).unwrap_err();

        assert_eq!(
            err,
            DenseTableError::RaggedRows {
                row: 1,
                expected: 2,
                actual: 1,
            }
        );
    }

    #[test]
    fn rejects_mismatched_lengths() {
        let err = DenseTable::new(vec![vec![1.0], vec![2.0]], vec![1.0]).unwrap_err();

        assert_eq!(err, DenseTableError::MismatchedLengths { x: 2, y: 1 });
    }

    #[test]
    fn canary_generation_is_deterministic_for_identical_inputs() {
        // Canary permutations are seeded from indices and length only, so two
        // tables built from identical inputs must be bin-for-bin identical.
        let left = DenseTable::with_canaries(
            vec![vec![0.0], vec![1.0], vec![2.0], vec![3.0], vec![4.0]],
            vec![0.0; 5],
            2,
        )
        .unwrap();
        let right = DenseTable::with_canaries(
            vec![vec![0.0], vec![1.0], vec![2.0], vec![3.0], vec![4.0]],
            vec![0.0; 5],
            2,
        )
        .unwrap();

        let left_values = binned_snapshot(&left);
        let right_values = binned_snapshot(&right);

        assert_eq!(left_values, right_values);
    }

    #[test]
    fn binary_canaries_remain_boolean_and_preserve_value_counts() {
        let table = DenseTable::with_canaries(
            vec![
                vec![0.0],
                vec![1.0],
                vec![0.0],
                vec![1.0],
                vec![1.0],
                vec![0.0],
            ],
            vec![0.0; 6],
            2,
        )
        .unwrap();

        let real_true_count = (0..table.n_rows())
            .filter(|row_idx| table.binned_boolean_value(0, *row_idx) == Some(true))
            .count();

        // Binned indices 1.. are the two canaries of feature 0; a permutation
        // keeps the number of `true` rows unchanged.
        for feature_index in 1..table.binned_feature_count() {
            assert!(table.is_binary_binned_feature(feature_index));
            let canary_true_count = (0..table.n_rows())
                .filter(|row_idx| table.binned_boolean_value(feature_index, *row_idx) == Some(true))
                .count();
            assert_eq!(canary_true_count, real_true_count);
        }
    }

    #[test]
    fn numeric_bin_boundaries_capture_training_bin_upper_bounds() {
        // 4 values, 3 distinct: Auto uses min(3, 128, 4 / 2 = 2) = 2 bins.
        // Run starts 0, 2, 3 map to bins 0, (2*2)/4 = 1, (3*2)/4 = 1, so bin 0 ends
        // at 1.0 and bin 1 ends at 10.0.
        let boundaries = numeric_bin_boundaries(&[1.0, 1.0, 2.0, 10.0], NumericBins::Auto);

        assert_eq!(boundaries, vec![(0, 1.0), (1, 10.0)]);
    }

    /// Flattens every binned column of `table`, column by column, for equality checks.
    fn binned_snapshot(table: &DenseTable) -> Vec<u16> {
        let mut values = Vec::new();

        for feature_idx in 0..table.binned_feature_count() {
            for row_idx in 0..table.n_rows() {
                values.push(table.binned_value(feature_idx, row_idx));
            }
        }

        values
    }
}