1use arrow::array::{Array, BooleanArray, Float64Array, UInt8Array, UInt16Array};
10use rand::seq::SliceRandom;
11use rand::{SeedableRng, rngs::StdRng};
12use std::cmp::Ordering;
13use std::error::Error;
14use std::fmt::{Display, Formatter};
15
/// Largest bin count accepted by [`NumericBins::fixed`]; also the cap that
/// [`NumericBins::cap`] reports for [`NumericBins::Auto`].
pub const MAX_NUMERIC_BINS: usize = 128;
/// Canary copy count used by the `new` constructors, which take no explicit
/// `canaries` argument.
const DEFAULT_CANARIES: usize = 2;
/// Bin index that `DenseTable::binned_value` returns for a null entry of a
/// binary binned column.
pub const BINARY_MISSING_BIN: u16 = 2;

/// Result of `preprocess_rows`: column-major feature values, the target array,
/// the row count and the feature count, in that order.
type PreprocessedRows = (Vec<Vec<f64>>, Float64Array, usize, usize);
21
/// Read-only view shared by [`DenseTable`], [`SparseTable`] and [`Table`].
///
/// "Feature" indices address the original input columns; "binned" indices
/// address the binned columns, which hold the real columns followed by their
/// canary copies.
pub trait TableAccess: Sync {
    /// Number of rows in the table.
    fn n_rows(&self) -> usize;
    /// Number of original (non-canary) feature columns.
    fn n_features(&self) -> usize;
    /// Number of canary copies requested at construction time.
    fn canaries(&self) -> usize;
    /// Bin cap of the table's [`NumericBins`] setting.
    fn numeric_bin_cap(&self) -> usize;
    /// Total number of binned columns (real plus canary).
    fn binned_feature_count(&self) -> usize;
    /// Raw value of feature `feature_index` at row `row_index`.
    fn feature_value(&self, feature_index: usize, row_index: usize) -> f64;
    /// Whether the raw value at the given position is missing.
    fn is_missing(&self, feature_index: usize, row_index: usize) -> bool;
    /// Whether the raw feature column `index` is stored as a binary column.
    fn is_binary_feature(&self, index: usize) -> bool;
    /// Bin index of binned column `feature_index` at row `row_index`.
    fn binned_value(&self, feature_index: usize, row_index: usize) -> u16;
    /// Boolean value of a binary binned column; `None` when no boolean is
    /// available for that position.
    fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool>;
    /// Whether binned column `index` is a real column or a canary copy, and of
    /// which source column.
    fn binned_column_kind(&self, index: usize) -> BinnedColumnKind;
    /// Whether binned column `index` is stored as a binary column.
    fn is_binary_binned_feature(&self, index: usize) -> bool;
    /// Target value at row `row_index`.
    fn target_value(&self, row_index: usize) -> f64;

    /// Whether binned column `index` is a canary copy; derived from
    /// [`TableAccess::binned_column_kind`].
    fn is_canary_binned_feature(&self, index: usize) -> bool {
        matches!(
            self.binned_column_kind(index),
            BinnedColumnKind::Canary { .. }
        )
    }
}
50
/// Storage layout of a [`Table`], as reported by `Table::kind`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TableKind {
    /// The table wraps a [`DenseTable`].
    Dense,
    /// The table wraps a [`SparseTable`].
    Sparse,
}
58
/// How many bins numeric feature columns are split into.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum NumericBins {
    /// Bin count is derived from the data, capped at [`MAX_NUMERIC_BINS`].
    #[default]
    Auto,
    /// Caller-requested bin count. [`NumericBins::fixed`] is the validating
    /// constructor; building the variant directly skips that range check.
    Fixed(usize),
}
67
68impl NumericBins {
69 pub fn fixed(requested: usize) -> Result<Self, DenseTableError> {
70 if requested == 0 || requested > MAX_NUMERIC_BINS {
71 return Err(DenseTableError::InvalidBinCount { requested });
72 }
73 Ok(Self::Fixed(requested))
74 }
75
76 pub fn cap(self) -> usize {
77 match self {
78 NumericBins::Auto => MAX_NUMERIC_BINS,
79 NumericBins::Fixed(requested) => requested,
80 }
81 }
82}
83
/// Table that keeps every feature as an arrow array, alongside binned copies
/// of the real columns and their shuffled canary copies.
#[derive(Debug, Clone)]
pub struct DenseTable {
    // Raw feature columns, one per input feature.
    feature_columns: Vec<FeatureColumn>,
    // Binned columns: the real columns first, then the canary copies.
    binned_feature_columns: Vec<BinnedFeatureColumn>,
    // Parallel to `binned_feature_columns`: real vs. canary and the source column.
    binned_column_kinds: Vec<BinnedColumnKind>,
    // Target values, one per row.
    target: Float64Array,
    n_rows: usize,
    n_features: usize,
    // Number of canary copies requested at construction.
    canaries: usize,
    // Binning setting the numeric columns were binned with.
    numeric_bins: NumericBins,
}
100
/// Table whose features are all binary and stored as lists of set row indices.
#[derive(Debug, Clone)]
pub struct SparseTable {
    // Raw feature columns, one per input feature.
    feature_columns: Vec<SparseBinaryColumn>,
    // Clones of the real columns first, then the shuffled canary copies.
    binned_feature_columns: Vec<SparseBinaryColumn>,
    // Parallel to `binned_feature_columns`: real vs. canary and the source column.
    binned_column_kinds: Vec<BinnedColumnKind>,
    // Target values, one per row.
    target: Float64Array,
    n_rows: usize,
    n_features: usize,
    // Number of canary copies requested at construction.
    canaries: usize,
    // Stored so `numeric_bin_cap` can report it; the binary columns themselves
    // are not binned with it.
    numeric_bins: NumericBins,
}
116
/// Binary column represented by the indices of the rows whose value is set.
///
/// Lookups use a binary search, so `row_indices` is expected to be sorted in
/// ascending order.
#[derive(Debug, Clone)]
struct SparseBinaryColumn {
    row_indices: Vec<usize>,
}

impl SparseBinaryColumn {
    /// Returns `true` when `row_index` is one of the stored row indices.
    fn value(&self, row_index: usize) -> bool {
        let lookup = self.row_indices.binary_search(&row_index);
        lookup.is_ok()
    }
}
127
/// Either a dense or a sparse table; `Table::with_options` picks the variant
/// from the input data.
#[derive(Debug, Clone)]
pub enum Table {
    /// Arrow-backed table that supports numeric and binary columns.
    Dense(DenseTable),
    /// Table whose columns are all binary and stored as set-row index lists.
    Sparse(SparseTable),
}
135
/// Raw storage for one feature column of a [`DenseTable`].
#[derive(Debug, Clone)]
enum FeatureColumn {
    /// Arbitrary floating-point values; missing entries are stored as NaN.
    Numeric(Float64Array),
    /// Column containing only 0, 1 and missing values; missing entries are nulls.
    Binary(BooleanArray),
}
141
/// Binned storage for one column of a [`DenseTable`].
#[derive(Debug, Clone)]
enum BinnedFeatureColumn {
    /// Numeric bin indices that all fit in a `u8`.
    NumericU8(UInt8Array),
    /// Numeric bin indices where at least one exceeds `u8::MAX`.
    NumericU16(UInt16Array),
    /// Binary column; missing entries are nulls.
    Binary(BooleanArray),
}
148
/// Borrowed view of a raw feature column, returned by `DenseTable::feature_column`.
#[derive(Debug, Clone, Copy)]
pub enum FeatureColumnRef<'a> {
    /// Borrow of a numeric column.
    Numeric(&'a Float64Array),
    /// Borrow of a binary column.
    Binary(&'a BooleanArray),
}
154
/// Borrowed view of a binned column, returned by
/// `DenseTable::binned_feature_column`.
#[derive(Debug, Clone, Copy)]
pub enum BinnedFeatureColumnRef<'a> {
    /// Borrow of a numeric column whose bins are stored as `u8`.
    NumericU8(&'a UInt8Array),
    /// Borrow of a numeric column whose bins are stored as `u16`.
    NumericU16(&'a UInt16Array),
    /// Borrow of a binary column.
    Binary(&'a BooleanArray),
}
161
/// Provenance of a binned column.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinnedColumnKind {
    /// Binned form of the input feature at `source_index`.
    Real { source_index: usize },
    /// Shuffled copy of the binned column for feature `source_index`.
    Canary {
        /// Index of the feature this canary was shuffled from.
        source_index: usize,
        /// Which of the `canaries` copies this is, counted from zero.
        copy_index: usize,
    },
}
172
/// Errors raised while building a table or a [`NumericBins`] value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DenseTableError {
    /// The feature matrix and the target vector have different row counts.
    MismatchedLengths {
        /// Number of feature rows.
        x: usize,
        /// Number of target values.
        y: usize,
    },
    /// A row's length differs from the first row's length.
    RaggedRows {
        /// Index of the offending row.
        row: usize,
        /// Expected number of columns.
        expected: usize,
        /// Number of columns actually found.
        actual: usize,
    },
    /// A column handed to a `SparseTable` constructor was rejected.
    NonBinaryColumn {
        /// Index of the offending column.
        column: usize,
    },
    /// A fixed bin count outside `1..=MAX_NUMERIC_BINS` was requested.
    InvalidBinCount {
        /// The rejected bin count.
        requested: usize,
    },
}
191
192impl Display for DenseTableError {
193 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
194 match self {
195 DenseTableError::MismatchedLengths { x, y } => write!(
196 f,
197 "Mismatched lengths: X has {} rows while y has {} values.",
198 x, y
199 ),
200 DenseTableError::RaggedRows {
201 row,
202 expected,
203 actual,
204 } => write!(
205 f,
206 "Ragged row at index {}: expected {} columns, found {}.",
207 row, expected, actual
208 ),
209 DenseTableError::NonBinaryColumn { column } => write!(
210 f,
211 "SparseTable requires binary features, but column {} contains non-binary values.",
212 column
213 ),
214 DenseTableError::InvalidBinCount { requested } => write!(
215 f,
216 "Invalid bins value {}. Expected 'auto' or an integer between 1 and {}.",
217 requested, MAX_NUMERIC_BINS
218 ),
219 }
220 }
221}
222
// Marker impl: the `Display` and `Debug` impls supply everything `Error` needs.
impl Error for DenseTableError {}
224
225impl DenseTable {
226 pub fn new(x: Vec<Vec<f64>>, y: Vec<f64>) -> Result<Self, DenseTableError> {
229 Self::with_canaries(x, y, DEFAULT_CANARIES)
230 }
231
232 pub fn with_canaries(
234 x: Vec<Vec<f64>>,
235 y: Vec<f64>,
236 canaries: usize,
237 ) -> Result<Self, DenseTableError> {
238 Self::with_options(x, y, canaries, NumericBins::Auto)
239 }
240
241 pub fn with_options(
243 x: Vec<Vec<f64>>,
244 y: Vec<f64>,
245 canaries: usize,
246 numeric_bins: NumericBins,
247 ) -> Result<Self, DenseTableError> {
248 let (columns, target, n_rows, n_features) = preprocess_rows(&x, y)?;
249 Ok(Self::from_columns(
250 &columns,
251 target,
252 n_rows,
253 n_features,
254 canaries,
255 numeric_bins,
256 ))
257 }
258
259 fn from_columns(
260 columns: &[Vec<f64>],
261 target: Float64Array,
262 n_rows: usize,
263 n_features: usize,
264 canaries: usize,
265 numeric_bins: NumericBins,
266 ) -> Self {
267 let feature_columns = columns
268 .iter()
269 .map(|column| build_feature_column(column))
270 .collect();
271
272 let real_binned_columns: Vec<BinnedFeatureColumn> = columns
273 .iter()
274 .map(|column| build_binned_feature_column(column, numeric_bins))
275 .collect();
276 let canary_columns: Vec<(BinnedColumnKind, BinnedFeatureColumn)> = (0..canaries)
277 .flat_map(|copy_index| {
278 real_binned_columns
279 .iter()
280 .enumerate()
281 .map(move |(source_index, column)| {
282 (
283 BinnedColumnKind::Canary {
284 source_index,
285 copy_index,
286 },
287 shuffle_canary_column(column, copy_index, source_index),
291 )
292 })
293 })
294 .collect();
295
296 let (binned_column_kinds, binned_feature_columns): (Vec<_>, Vec<_>) = (0..n_features)
297 .map(|source_index| BinnedColumnKind::Real { source_index })
298 .zip(real_binned_columns)
299 .chain(canary_columns)
300 .unzip();
301
302 Self {
303 feature_columns,
304 binned_feature_columns,
305 binned_column_kinds,
306 target,
307 n_rows,
308 n_features,
309 canaries,
310 numeric_bins,
311 }
312 }
313
314 #[inline]
315 pub fn n_rows(&self) -> usize {
316 self.n_rows
317 }
318
319 #[inline]
320 pub fn n_features(&self) -> usize {
321 self.n_features
322 }
323
324 #[inline]
325 pub fn canaries(&self) -> usize {
326 self.canaries
327 }
328
329 #[inline]
330 pub fn numeric_bin_cap(&self) -> usize {
331 self.numeric_bins.cap()
332 }
333
334 #[inline]
335 pub fn binned_feature_count(&self) -> usize {
336 self.binned_feature_columns.len()
337 }
338
339 #[inline]
340 pub fn feature_column(&self, index: usize) -> FeatureColumnRef<'_> {
341 match &self.feature_columns[index] {
342 FeatureColumn::Numeric(column) => FeatureColumnRef::Numeric(column),
343 FeatureColumn::Binary(column) => FeatureColumnRef::Binary(column),
344 }
345 }
346
347 #[inline]
348 pub fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
349 match &self.feature_columns[feature_index] {
350 FeatureColumn::Numeric(column) => column.value(row_index),
351 FeatureColumn::Binary(column) => {
352 if column.is_null(row_index) {
353 f64::NAN
354 } else {
355 f64::from(u8::from(column.value(row_index)))
356 }
357 }
358 }
359 }
360
361 #[inline]
362 pub fn is_missing(&self, feature_index: usize, row_index: usize) -> bool {
363 match &self.feature_columns[feature_index] {
364 FeatureColumn::Numeric(column) => column.value(row_index).is_nan(),
365 FeatureColumn::Binary(column) => column.is_null(row_index),
366 }
367 }
368
369 #[inline]
370 pub fn is_binary_feature(&self, index: usize) -> bool {
371 matches!(self.feature_columns[index], FeatureColumn::Binary(_))
372 }
373
374 #[inline]
375 pub fn binned_feature_column(&self, index: usize) -> BinnedFeatureColumnRef<'_> {
376 match &self.binned_feature_columns[index] {
377 BinnedFeatureColumn::NumericU8(column) => BinnedFeatureColumnRef::NumericU8(column),
378 BinnedFeatureColumn::NumericU16(column) => BinnedFeatureColumnRef::NumericU16(column),
379 BinnedFeatureColumn::Binary(column) => BinnedFeatureColumnRef::Binary(column),
380 }
381 }
382
383 #[inline]
384 pub fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
385 match &self.binned_feature_columns[feature_index] {
386 BinnedFeatureColumn::NumericU8(column) => u16::from(column.value(row_index)),
387 BinnedFeatureColumn::NumericU16(column) => column.value(row_index),
388 BinnedFeatureColumn::Binary(column) => {
389 if column.is_null(row_index) {
390 BINARY_MISSING_BIN
391 } else {
392 u16::from(u8::from(column.value(row_index)))
393 }
394 }
395 }
396 }
397
398 #[inline]
399 pub fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
400 match &self.binned_feature_columns[feature_index] {
401 BinnedFeatureColumn::Binary(column) => {
402 (!column.is_null(row_index)).then(|| column.value(row_index))
403 }
404 BinnedFeatureColumn::NumericU8(_) | BinnedFeatureColumn::NumericU16(_) => None,
405 }
406 }
407
408 #[inline]
409 pub fn binned_column_kind(&self, index: usize) -> BinnedColumnKind {
410 self.binned_column_kinds[index]
411 }
412
413 #[inline]
414 pub fn is_canary_binned_feature(&self, index: usize) -> bool {
415 matches!(
416 self.binned_column_kinds[index],
417 BinnedColumnKind::Canary { .. }
418 )
419 }
420
421 #[inline]
422 pub fn is_binary_binned_feature(&self, index: usize) -> bool {
423 matches!(
424 self.binned_feature_columns[index],
425 BinnedFeatureColumn::Binary(_)
426 )
427 }
428
429 #[inline]
430 pub fn target(&self) -> &Float64Array {
431 &self.target
432 }
433}
434
435impl SparseTable {
436 pub fn new(x: Vec<Vec<f64>>, y: Vec<f64>) -> Result<Self, DenseTableError> {
437 Self::with_canaries(x, y, DEFAULT_CANARIES)
438 }
439
440 pub fn with_canaries(
441 x: Vec<Vec<f64>>,
442 y: Vec<f64>,
443 canaries: usize,
444 ) -> Result<Self, DenseTableError> {
445 Self::with_options(x, y, canaries, NumericBins::Auto)
446 }
447
448 pub fn with_options(
449 x: Vec<Vec<f64>>,
450 y: Vec<f64>,
451 canaries: usize,
452 numeric_bins: NumericBins,
453 ) -> Result<Self, DenseTableError> {
454 let (columns, target, n_rows, n_features) = preprocess_rows(&x, y)?;
455 validate_binary_columns(&columns)?;
456 Ok(Self::from_columns(
457 &columns,
458 target,
459 n_rows,
460 n_features,
461 canaries,
462 numeric_bins,
463 ))
464 }
465
466 fn from_columns(
467 columns: &[Vec<f64>],
468 target: Float64Array,
469 n_rows: usize,
470 n_features: usize,
471 canaries: usize,
472 numeric_bins: NumericBins,
473 ) -> Self {
474 let feature_columns: Vec<SparseBinaryColumn> = columns
475 .iter()
476 .map(|column| sparse_binary_column_from_values(column))
477 .collect();
478
479 let canary_columns: Vec<(BinnedColumnKind, SparseBinaryColumn)> = (0..canaries)
480 .flat_map(|copy_index| {
481 feature_columns
482 .iter()
483 .enumerate()
484 .map(move |(source_index, column)| {
485 (
486 BinnedColumnKind::Canary {
487 source_index,
488 copy_index,
489 },
490 shuffle_sparse_binary_column(column, n_rows, copy_index, source_index),
491 )
492 })
493 })
494 .collect();
495
496 let (binned_column_kinds, binned_feature_columns): (Vec<_>, Vec<_>) = (0..n_features)
497 .map(|source_index| BinnedColumnKind::Real { source_index })
498 .zip(feature_columns.iter().cloned())
499 .chain(canary_columns)
500 .unzip();
501
502 Self {
503 feature_columns,
504 binned_feature_columns,
505 binned_column_kinds,
506 target,
507 n_rows,
508 n_features,
509 canaries,
510 numeric_bins,
511 }
512 }
513
514 pub fn from_sparse_binary_columns(
515 n_rows: usize,
516 n_features: usize,
517 columns: Vec<Vec<usize>>,
518 y: Vec<f64>,
519 canaries: usize,
520 ) -> Result<Self, DenseTableError> {
521 Self::from_sparse_binary_columns_with_options(
522 n_rows,
523 n_features,
524 columns,
525 y,
526 canaries,
527 NumericBins::Auto,
528 )
529 }
530
531 pub fn from_sparse_binary_columns_with_options(
532 n_rows: usize,
533 n_features: usize,
534 columns: Vec<Vec<usize>>,
535 y: Vec<f64>,
536 canaries: usize,
537 numeric_bins: NumericBins,
538 ) -> Result<Self, DenseTableError> {
539 if n_rows != y.len() {
540 return Err(DenseTableError::MismatchedLengths {
541 x: n_rows,
542 y: y.len(),
543 });
544 }
545 if n_features != columns.len() {
546 return Err(DenseTableError::RaggedRows {
547 row: columns.len(),
548 expected: n_features,
549 actual: columns.len(),
550 });
551 }
552
553 let feature_columns = columns
554 .into_iter()
555 .enumerate()
556 .map(|(column_idx, mut row_indices)| {
557 row_indices.sort_unstable();
558 row_indices.dedup();
559 if row_indices.iter().any(|row_idx| *row_idx >= n_rows) {
560 return Err(DenseTableError::NonBinaryColumn { column: column_idx });
561 }
562 Ok(SparseBinaryColumn { row_indices })
563 })
564 .collect::<Result<Vec<_>, _>>()?;
565
566 let canary_columns: Vec<(BinnedColumnKind, SparseBinaryColumn)> = (0..canaries)
567 .flat_map(|copy_index| {
568 feature_columns
569 .iter()
570 .enumerate()
571 .map(move |(source_index, column)| {
572 (
573 BinnedColumnKind::Canary {
574 source_index,
575 copy_index,
576 },
577 shuffle_sparse_binary_column(column, n_rows, copy_index, source_index),
578 )
579 })
580 })
581 .collect();
582
583 let (binned_column_kinds, binned_feature_columns): (Vec<_>, Vec<_>) = (0..n_features)
584 .map(|source_index| BinnedColumnKind::Real { source_index })
585 .zip(feature_columns.iter().cloned())
586 .chain(canary_columns)
587 .unzip();
588
589 Ok(Self {
590 feature_columns,
591 binned_feature_columns,
592 binned_column_kinds,
593 target: Float64Array::from(y),
594 n_rows,
595 n_features,
596 canaries,
597 numeric_bins,
598 })
599 }
600
601 #[inline]
602 pub fn n_rows(&self) -> usize {
603 self.n_rows
604 }
605
606 #[inline]
607 pub fn n_features(&self) -> usize {
608 self.n_features
609 }
610
611 #[inline]
612 pub fn canaries(&self) -> usize {
613 self.canaries
614 }
615
616 #[inline]
617 pub fn numeric_bin_cap(&self) -> usize {
618 self.numeric_bins.cap()
619 }
620
621 #[inline]
622 pub fn binned_feature_count(&self) -> usize {
623 self.binned_feature_columns.len()
624 }
625
626 #[inline]
627 pub fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
628 f64::from(u8::from(
629 self.feature_columns[feature_index].value(row_index),
630 ))
631 }
632
633 #[inline]
634 pub fn is_missing(&self, _feature_index: usize, _row_index: usize) -> bool {
635 false
636 }
637
638 #[inline]
639 pub fn is_binary_feature(&self, _index: usize) -> bool {
640 true
641 }
642
643 #[inline]
644 pub fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
645 u16::from(u8::from(
646 self.binned_feature_columns[feature_index].value(row_index),
647 ))
648 }
649
650 #[inline]
651 pub fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
652 Some(self.binned_feature_columns[feature_index].value(row_index))
653 }
654
655 #[inline]
656 pub fn binned_column_kind(&self, index: usize) -> BinnedColumnKind {
657 self.binned_column_kinds[index]
658 }
659
660 #[inline]
661 pub fn is_canary_binned_feature(&self, index: usize) -> bool {
662 matches!(
663 self.binned_column_kinds[index],
664 BinnedColumnKind::Canary { .. }
665 )
666 }
667
668 #[inline]
669 pub fn is_binary_binned_feature(&self, _index: usize) -> bool {
670 true
671 }
672
673 #[inline]
674 pub fn target(&self) -> &Float64Array {
675 &self.target
676 }
677}
678
679impl Table {
680 pub fn new(x: Vec<Vec<f64>>, y: Vec<f64>) -> Result<Self, DenseTableError> {
681 Self::with_canaries(x, y, DEFAULT_CANARIES)
682 }
683
684 pub fn with_canaries(
685 x: Vec<Vec<f64>>,
686 y: Vec<f64>,
687 canaries: usize,
688 ) -> Result<Self, DenseTableError> {
689 Self::with_options(x, y, canaries, NumericBins::Auto)
690 }
691
692 pub fn with_options(
693 x: Vec<Vec<f64>>,
694 y: Vec<f64>,
695 canaries: usize,
696 numeric_bins: NumericBins,
697 ) -> Result<Self, DenseTableError> {
698 let (columns, target, n_rows, n_features) = preprocess_rows(&x, y)?;
699
700 if columns
701 .iter()
702 .all(|column| is_sparse_binary_eligible_column(column))
703 {
704 Ok(Self::Sparse(SparseTable::from_columns(
705 &columns,
706 target,
707 n_rows,
708 n_features,
709 canaries,
710 numeric_bins,
711 )))
712 } else {
713 Ok(Self::Dense(DenseTable::from_columns(
714 &columns,
715 target,
716 n_rows,
717 n_features,
718 canaries,
719 numeric_bins,
720 )))
721 }
722 }
723
724 pub fn kind(&self) -> TableKind {
725 match self {
726 Table::Dense(_) => TableKind::Dense,
727 Table::Sparse(_) => TableKind::Sparse,
728 }
729 }
730
731 pub fn as_dense(&self) -> Option<&DenseTable> {
732 match self {
733 Table::Dense(table) => Some(table),
734 Table::Sparse(_) => None,
735 }
736 }
737
738 pub fn as_sparse(&self) -> Option<&SparseTable> {
739 match self {
740 Table::Dense(_) => None,
741 Table::Sparse(table) => Some(table),
742 }
743 }
744}
745
746impl TableAccess for DenseTable {
747 fn n_rows(&self) -> usize {
748 self.n_rows()
749 }
750
751 fn n_features(&self) -> usize {
752 self.n_features()
753 }
754
755 fn canaries(&self) -> usize {
756 self.canaries()
757 }
758
759 fn numeric_bin_cap(&self) -> usize {
760 self.numeric_bin_cap()
761 }
762
763 fn binned_feature_count(&self) -> usize {
764 self.binned_feature_count()
765 }
766
767 fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
768 self.feature_value(feature_index, row_index)
769 }
770
771 fn is_missing(&self, feature_index: usize, row_index: usize) -> bool {
772 self.is_missing(feature_index, row_index)
773 }
774
775 fn is_binary_feature(&self, index: usize) -> bool {
776 self.is_binary_feature(index)
777 }
778
779 fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
780 self.binned_value(feature_index, row_index)
781 }
782
783 fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
784 self.binned_boolean_value(feature_index, row_index)
785 }
786
787 fn binned_column_kind(&self, index: usize) -> BinnedColumnKind {
788 self.binned_column_kind(index)
789 }
790
791 fn is_binary_binned_feature(&self, index: usize) -> bool {
792 self.is_binary_binned_feature(index)
793 }
794
795 fn target_value(&self, row_index: usize) -> f64 {
796 self.target().value(row_index)
797 }
798}
799
800impl TableAccess for SparseTable {
801 fn n_rows(&self) -> usize {
802 self.n_rows()
803 }
804
805 fn n_features(&self) -> usize {
806 self.n_features()
807 }
808
809 fn canaries(&self) -> usize {
810 self.canaries()
811 }
812
813 fn numeric_bin_cap(&self) -> usize {
814 self.numeric_bin_cap()
815 }
816
817 fn binned_feature_count(&self) -> usize {
818 self.binned_feature_count()
819 }
820
821 fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
822 self.feature_value(feature_index, row_index)
823 }
824
825 fn is_missing(&self, feature_index: usize, row_index: usize) -> bool {
826 self.is_missing(feature_index, row_index)
827 }
828
829 fn is_binary_feature(&self, index: usize) -> bool {
830 self.is_binary_feature(index)
831 }
832
833 fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
834 self.binned_value(feature_index, row_index)
835 }
836
837 fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
838 self.binned_boolean_value(feature_index, row_index)
839 }
840
841 fn binned_column_kind(&self, index: usize) -> BinnedColumnKind {
842 self.binned_column_kind(index)
843 }
844
845 fn is_binary_binned_feature(&self, index: usize) -> bool {
846 self.is_binary_binned_feature(index)
847 }
848
849 fn target_value(&self, row_index: usize) -> f64 {
850 self.target().value(row_index)
851 }
852}
853
854impl TableAccess for Table {
855 fn n_rows(&self) -> usize {
856 match self {
857 Table::Dense(table) => table.n_rows(),
858 Table::Sparse(table) => table.n_rows(),
859 }
860 }
861
862 fn n_features(&self) -> usize {
863 match self {
864 Table::Dense(table) => table.n_features(),
865 Table::Sparse(table) => table.n_features(),
866 }
867 }
868
869 fn canaries(&self) -> usize {
870 match self {
871 Table::Dense(table) => table.canaries(),
872 Table::Sparse(table) => table.canaries(),
873 }
874 }
875
876 fn numeric_bin_cap(&self) -> usize {
877 match self {
878 Table::Dense(table) => table.numeric_bin_cap(),
879 Table::Sparse(table) => table.numeric_bin_cap(),
880 }
881 }
882
883 fn binned_feature_count(&self) -> usize {
884 match self {
885 Table::Dense(table) => table.binned_feature_count(),
886 Table::Sparse(table) => table.binned_feature_count(),
887 }
888 }
889
890 fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
891 match self {
892 Table::Dense(table) => table.feature_value(feature_index, row_index),
893 Table::Sparse(table) => table.feature_value(feature_index, row_index),
894 }
895 }
896
897 fn is_missing(&self, feature_index: usize, row_index: usize) -> bool {
898 match self {
899 Table::Dense(table) => table.is_missing(feature_index, row_index),
900 Table::Sparse(table) => table.is_missing(feature_index, row_index),
901 }
902 }
903
904 fn is_binary_feature(&self, index: usize) -> bool {
905 match self {
906 Table::Dense(table) => table.is_binary_feature(index),
907 Table::Sparse(table) => table.is_binary_feature(index),
908 }
909 }
910
911 fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
912 match self {
913 Table::Dense(table) => table.binned_value(feature_index, row_index),
914 Table::Sparse(table) => table.binned_value(feature_index, row_index),
915 }
916 }
917
918 fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
919 match self {
920 Table::Dense(table) => table.binned_boolean_value(feature_index, row_index),
921 Table::Sparse(table) => table.binned_boolean_value(feature_index, row_index),
922 }
923 }
924
925 fn binned_column_kind(&self, index: usize) -> BinnedColumnKind {
926 match self {
927 Table::Dense(table) => table.binned_column_kind(index),
928 Table::Sparse(table) => table.binned_column_kind(index),
929 }
930 }
931
932 fn is_binary_binned_feature(&self, index: usize) -> bool {
933 match self {
934 Table::Dense(table) => table.is_binary_binned_feature(index),
935 Table::Sparse(table) => table.is_binary_binned_feature(index),
936 }
937 }
938
939 fn target_value(&self, row_index: usize) -> f64 {
940 match self {
941 Table::Dense(table) => table.target().value(row_index),
942 Table::Sparse(table) => table.target().value(row_index),
943 }
944 }
945}
946
947fn preprocess_rows(x: &[Vec<f64>], y: Vec<f64>) -> Result<PreprocessedRows, DenseTableError> {
948 validate_shape(x, &y)?;
949 let n_rows = x.len();
950 let n_features = x.first().map_or(0, Vec::len);
951 let columns = collect_columns(x, n_features);
952 Ok((columns, Float64Array::from(y), n_rows, n_features))
953}
954
955fn validate_shape(x: &[Vec<f64>], y: &[f64]) -> Result<(), DenseTableError> {
956 if x.len() != y.len() {
957 return Err(DenseTableError::MismatchedLengths {
958 x: x.len(),
959 y: y.len(),
960 });
961 }
962
963 let n_features = x.first().map_or(0, Vec::len);
964 for (row_idx, row) in x.iter().enumerate() {
965 if row.len() != n_features {
966 return Err(DenseTableError::RaggedRows {
967 row: row_idx,
968 expected: n_features,
969 actual: row.len(),
970 });
971 }
972 }
973
974 Ok(())
975}
976
/// Transposes row-major `x` into `n_features` column vectors.
///
/// # Panics
/// Panics when a row is shorter than `n_features`.
fn collect_columns(x: &[Vec<f64>], n_features: usize) -> Vec<Vec<f64>> {
    let mut columns = Vec::with_capacity(n_features);
    for col_idx in 0..n_features {
        let mut column = Vec::with_capacity(x.len());
        for row in x {
            column.push(row[col_idx]);
        }
        columns.push(column);
    }
    columns
}
982
983fn validate_binary_columns(columns: &[Vec<f64>]) -> Result<(), DenseTableError> {
984 for (column_idx, column) in columns.iter().enumerate() {
985 if !is_sparse_binary_eligible_column(column) {
986 return Err(DenseTableError::NonBinaryColumn { column: column_idx });
987 }
988 }
989
990 Ok(())
991}
992
993fn build_feature_column(values: &[f64]) -> FeatureColumn {
994 if is_binary_column(values) {
995 FeatureColumn::Binary(BooleanArray::from(to_binary_values(values)))
996 } else {
997 FeatureColumn::Numeric(Float64Array::from(values.to_vec()))
998 }
999}
1000
1001fn build_binned_feature_column(values: &[f64], numeric_bins: NumericBins) -> BinnedFeatureColumn {
1002 if is_binary_column(values) {
1003 BinnedFeatureColumn::Binary(BooleanArray::from(to_binary_values(values)))
1004 } else {
1005 let bins = bin_numeric_column(values, numeric_bins);
1006 if bins.iter().all(|value| *value <= u16::from(u8::MAX)) {
1007 BinnedFeatureColumn::NumericU8(UInt8Array::from(
1008 bins.into_iter()
1009 .map(|value| value as u8)
1010 .collect::<Vec<_>>(),
1011 ))
1012 } else {
1013 BinnedFeatureColumn::NumericU16(UInt16Array::from(bins))
1014 }
1015 }
1016}
1017
/// Returns `true` when every value is NaN (missing), zero, or one.
///
/// Values are compared numerically, so `-0.0` counts as zero. A bitwise
/// ordering such as `f64::total_cmp` treats `-0.0` and `0.0` as different and
/// would reject an otherwise binary column that contains a negative zero.
/// An empty slice is considered binary.
fn is_binary_column(values: &[f64]) -> bool {
    values
        .iter()
        .all(|&value| value.is_nan() || value == 0.0 || value == 1.0)
}
1025
/// Returns `true` when every value is exactly zero or one.
///
/// Unlike `is_binary_column`, NaN is rejected: the sparse representation
/// stores only the set rows and has no way to mark a missing value.
///
/// Values are compared numerically, so `-0.0` counts as zero. A bitwise
/// ordering such as `f64::total_cmp` treats `-0.0` and `0.0` as different and
/// would reject a column that contains a negative zero. An empty slice is
/// considered eligible.
fn is_sparse_binary_eligible_column(values: &[f64]) -> bool {
    values.iter().all(|&value| value == 0.0 || value == 1.0)
}
1032
/// Converts float values to optional booleans: NaN becomes `None`, `1.0`
/// becomes `Some(true)`, and every other value becomes `Some(false)`.
fn to_binary_values(values: &[f64]) -> Vec<Option<bool>> {
    let mut flags = Vec::with_capacity(values.len());
    for value in values {
        let flag = match value.is_nan() {
            true => None,
            false => Some(matches!(value.total_cmp(&1.0), Ordering::Equal)),
        };
        flags.push(flag);
    }
    flags
}
1045
1046fn sparse_binary_column_from_values(values: &[f64]) -> SparseBinaryColumn {
1047 SparseBinaryColumn {
1048 row_indices: values
1049 .iter()
1050 .enumerate()
1051 .filter_map(|(row_idx, value)| {
1052 (value.total_cmp(&1.0) == Ordering::Equal).then_some(row_idx)
1053 })
1054 .collect(),
1055 }
1056}
1057
1058pub fn numeric_bin_boundaries(values: &[f64], numeric_bins: NumericBins) -> Vec<(u16, f64)> {
1059 if values.is_empty() {
1060 return Vec::new();
1061 }
1062
1063 let mut ranked_values: Vec<(usize, f64)> = values
1064 .iter()
1065 .copied()
1066 .enumerate()
1067 .filter(|(_, value)| !value.is_nan())
1068 .collect();
1069 if ranked_values.is_empty() {
1070 return Vec::new();
1071 }
1072 ranked_values.sort_by(|left, right| left.1.total_cmp(&right.1));
1073
1074 let unique_value_count = ranked_values
1075 .iter()
1076 .map(|(_row_idx, value)| *value)
1077 .fold(Vec::<f64>::new(), |mut unique_values, value| {
1078 let is_new_value = unique_values
1079 .last()
1080 .is_none_or(|last_value| last_value.total_cmp(&value) != Ordering::Equal);
1081 if is_new_value {
1082 unique_values.push(value);
1083 }
1084 unique_values
1085 })
1086 .len();
1087
1088 let bin_count =
1089 resolved_numeric_bin_count(ranked_values.len(), unique_value_count, numeric_bins);
1090 let mut unique_rank = 0usize;
1091 let mut start = 0usize;
1092 let mut boundaries = Vec::new();
1093
1094 while start < ranked_values.len() {
1095 let current_value = ranked_values[start].1;
1096 let end = ranked_values[start..]
1097 .iter()
1098 .position(|(_row_idx, value)| value.total_cmp(¤t_value) != Ordering::Equal)
1099 .map_or(ranked_values.len(), |offset| start + offset);
1100
1101 let bin = match numeric_bins {
1102 NumericBins::Auto => ((start * bin_count) / ranked_values.len()) as u16,
1103 NumericBins::Fixed(_) => {
1104 let max_bin = (bin_count - 1) as u16;
1105 if unique_value_count == 1 {
1106 0
1107 } else {
1108 ((unique_rank * usize::from(max_bin)) / (unique_value_count - 1)) as u16
1109 }
1110 }
1111 };
1112
1113 if let Some((last_bin, last_upper_bound)) = boundaries.last_mut() {
1114 if *last_bin == bin {
1115 *last_upper_bound = current_value;
1116 } else {
1117 boundaries.push((bin, current_value));
1118 }
1119 } else {
1120 boundaries.push((bin, current_value));
1121 }
1122
1123 unique_rank += 1;
1124 start = end;
1125 }
1126
1127 boundaries
1128}
1129
1130fn bin_numeric_column(values: &[f64], numeric_bins: NumericBins) -> Vec<u16> {
1131 if values.is_empty() {
1132 return Vec::new();
1133 }
1134
1135 let missing_bin = numeric_missing_bin(numeric_bins);
1136 let mut bins = vec![missing_bin; values.len()];
1137 let mut ranked_values: Vec<(usize, f64)> = values
1138 .iter()
1139 .copied()
1140 .enumerate()
1141 .filter(|(_, value)| !value.is_nan())
1142 .collect();
1143 if ranked_values.is_empty() {
1144 return bins;
1145 }
1146 ranked_values.sort_by(|left, right| left.1.total_cmp(&right.1));
1147
1148 let unique_value_count = ranked_values
1149 .iter()
1150 .map(|(_row_idx, value)| *value)
1151 .fold(Vec::<f64>::new(), |mut unique_values, value| {
1152 let is_new_value = unique_values
1153 .last()
1154 .is_none_or(|last_value| last_value.total_cmp(&value) != Ordering::Equal);
1155 if is_new_value {
1156 unique_values.push(value);
1157 }
1158 unique_values
1159 })
1160 .len();
1161
1162 let bin_count =
1163 resolved_numeric_bin_count(ranked_values.len(), unique_value_count, numeric_bins);
1164 let mut unique_rank = 0usize;
1165 let mut start = 0usize;
1166
1167 while start < ranked_values.len() {
1168 let current_value = ranked_values[start].1;
1169 let end = ranked_values[start..]
1170 .iter()
1171 .position(|(_row_idx, value)| value.total_cmp(¤t_value) != Ordering::Equal)
1172 .map_or(ranked_values.len(), |offset| start + offset);
1173
1174 let bin = match numeric_bins {
1175 NumericBins::Auto => ((start * bin_count) / ranked_values.len()) as u16,
1176 NumericBins::Fixed(_) => {
1177 let max_bin = (bin_count - 1) as u16;
1178 if unique_value_count == 1 {
1179 0
1180 } else {
1181 ((unique_rank * usize::from(max_bin)) / (unique_value_count - 1)) as u16
1182 }
1183 }
1184 };
1185
1186 for (row_idx, _value) in &ranked_values[start..end] {
1187 bins[*row_idx] = bin;
1188 }
1189
1190 unique_rank += 1;
1191 start = end;
1192 }
1193
1194 bins
1195}
1196
1197fn resolved_numeric_bin_count(
1198 value_count: usize,
1199 unique_value_count: usize,
1200 numeric_bins: NumericBins,
1201) -> usize {
1202 match numeric_bins {
1203 NumericBins::Auto => {
1204 let populated_bin_cap = (value_count / 2).max(1);
1205 let capped_unique_values = unique_value_count
1206 .min(MAX_NUMERIC_BINS)
1207 .min(populated_bin_cap)
1208 .max(1);
1209 highest_power_of_two_at_most(capped_unique_values)
1210 }
1211 NumericBins::Fixed(requested) => requested.min(unique_value_count).max(1),
1212 }
1213}
1214
1215pub fn numeric_missing_bin(numeric_bins: NumericBins) -> u16 {
1216 numeric_bins.cap() as u16
1217}
1218
/// Largest power of two that is `<= value`; returns `1` for `0` and `1`.
fn highest_power_of_two_at_most(value: usize) -> usize {
    match value {
        0 | 1 => 1,
        // `ilog2` is floor(log2(n)), i.e. the position of the highest set bit.
        n => 1usize << n.ilog2(),
    }
}
1226
1227fn shuffle_canary_column(
1228 values: &BinnedFeatureColumn,
1229 copy_index: usize,
1230 source_index: usize,
1231) -> BinnedFeatureColumn {
1232 match values {
1233 BinnedFeatureColumn::NumericU8(values) => {
1234 let mut shuffled = (0..values.len())
1235 .map(|idx| values.value(idx))
1236 .collect::<Vec<_>>();
1237 shuffle_values(&mut shuffled, copy_index, source_index);
1238 BinnedFeatureColumn::NumericU8(UInt8Array::from(shuffled))
1239 }
1240 BinnedFeatureColumn::NumericU16(values) => {
1241 let mut shuffled = (0..values.len())
1242 .map(|idx| values.value(idx))
1243 .collect::<Vec<_>>();
1244 shuffle_values(&mut shuffled, copy_index, source_index);
1245 BinnedFeatureColumn::NumericU16(UInt16Array::from(shuffled))
1246 }
1247 BinnedFeatureColumn::Binary(values) => {
1248 BinnedFeatureColumn::Binary(shuffle_boolean_array(values, copy_index, source_index))
1249 }
1250 }
1251}
1252
1253fn shuffle_boolean_array(
1254 values: &BooleanArray,
1255 copy_index: usize,
1256 source_index: usize,
1257) -> BooleanArray {
1258 let mut shuffled = (0..values.len())
1259 .map(|idx| values.value(idx))
1260 .collect::<Vec<_>>();
1261 shuffle_values(&mut shuffled, copy_index, source_index);
1262 BooleanArray::from(shuffled)
1263}
1264
1265fn shuffle_sparse_binary_column(
1266 values: &SparseBinaryColumn,
1267 n_rows: usize,
1268 copy_index: usize,
1269 source_index: usize,
1270) -> SparseBinaryColumn {
1271 let mut dense = vec![false; n_rows];
1272 for row_idx in &values.row_indices {
1273 dense[*row_idx] = true;
1274 }
1275 shuffle_values(&mut dense, copy_index, source_index);
1276 SparseBinaryColumn {
1277 row_indices: dense
1278 .into_iter()
1279 .enumerate()
1280 .filter_map(|(row_idx, value)| value.then_some(row_idx))
1281 .collect(),
1282 }
1283}
1284
1285fn shuffle_values<T>(values: &mut [T], copy_index: usize, source_index: usize) {
1286 let seed = 0xA11CE5EED_u64
1287 ^ ((copy_index as u64) << 32)
1288 ^ (source_index as u64)
1289 ^ ((values.len() as u64) << 16);
1290 let mut rng = StdRng::seed_from_u64(seed);
1291 values.shuffle(&mut rng);
1292}
1293
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::{BTreeMap, BTreeSet};

    /// A dense table built with the default settings reports its shape, raw values and
    /// targets, and exposes canary columns after the real ones.
    #[test]
    fn builds_arrow_backed_dense_table() {
        let table =
            DenseTable::new(vec![vec![0.0, 10.0], vec![1.0, 20.0]], vec![3.0, 5.0]).unwrap();

        assert_eq!(table.n_rows(), 2);
        assert_eq!(table.n_features(), 2);
        assert_eq!(table.canaries(), 2);
        // Two real features plus two canary copies of each.
        assert_eq!(table.binned_feature_count(), 6);
        assert_eq!(table.feature_value(0, 0), 0.0);
        assert_eq!(table.feature_value(0, 1), 1.0);
        assert_eq!(table.target().value(0), 3.0);
        assert_eq!(table.target().value(1), 5.0);
        assert!(!table.is_canary_binned_feature(0));
        assert!(table.is_canary_binned_feature(2));
    }

    /// When every feature only takes the values 0 and 1, `Table` picks the sparse layout.
    #[test]
    fn builds_sparse_table_for_all_binary_features() {
        let table = Table::with_canaries(
            vec![vec![0.0, 1.0], vec![1.0, 0.0], vec![1.0, 1.0]],
            vec![0.0, 1.0, 1.0],
            1,
        )
        .unwrap();

        assert_eq!(table.kind(), TableKind::Sparse);
        assert!(table.is_binary_feature(0));
        assert!(table.is_binary_feature(1));
        assert!(table.is_binary_binned_feature(0));
        // Two real features plus one canary copy of each.
        assert_eq!(table.binned_feature_count(), 4);
    }

    /// A single non-binary feature is enough for `Table` to pick the dense layout.
    #[test]
    fn builds_dense_table_when_any_feature_is_non_binary() {
        let table = Table::with_canaries(
            vec![vec![0.0, 1.5], vec![1.0, 0.0], vec![1.0, 2.0]],
            vec![0.0, 1.0, 1.0],
            1,
        )
        .unwrap();

        assert_eq!(table.kind(), TableKind::Dense);
        assert!(table.is_binary_feature(0));
        assert!(!table.is_binary_feature(1));
    }

    /// Constructing a `SparseTable` directly fails on a column containing a non-binary value.
    #[test]
    fn sparse_table_rejects_non_binary_columns() {
        let err =
            SparseTable::with_canaries(vec![vec![0.0, 2.0], vec![1.0, 0.0]], vec![0.0, 1.0], 0)
                .unwrap_err();

        assert_eq!(err, DenseTableError::NonBinaryColumn { column: 1 });
    }

    /// With plenty of distinct values, automatic binning uses all 128 bins and the bin
    /// index never decreases as the raw value grows.
    #[test]
    fn auto_bins_numeric_columns_into_power_of_two_bins_up_to_128() {
        let x: Vec<Vec<f64>> = (0..1024).map(|value| vec![value as f64]).collect();
        let y: Vec<f64> = vec![1.0; 1024];

        let table = DenseTable::with_canaries(x, y, 0).unwrap();

        assert_eq!(table.binned_value(0, 0), 0);
        assert_eq!(table.binned_value(0, 1023), 127);
        assert!((1..1024).all(|idx| table.binned_value(0, idx - 1) <= table.binned_value(0, idx)));
        assert_eq!(
            (0..1024)
                .map(|idx| table.binned_value(0, idx))
                .collect::<BTreeSet<_>>()
                .len(),
            128
        );
    }

    /// 300 distinct values allow at most 150 populated bins, which rounds down to 128.
    #[test]
    fn auto_bins_choose_highest_populated_power_of_two() {
        let x: Vec<Vec<f64>> = (0..300).map(|value| vec![value as f64]).collect();
        let y = vec![0.0; 300];

        let table = DenseTable::with_canaries(x, y, 0).unwrap();

        assert_eq!(
            (0..300)
                .map(|idx| table.binned_value(0, idx))
                .collect::<BTreeSet<_>>()
                .len(),
            128
        );
    }

    /// Eight distinct values are spread over four bins so that each bin holds at least two rows.
    #[test]
    fn auto_bins_require_at_least_two_rows_per_bin() {
        let x: Vec<Vec<f64>> = (0..8).map(|value| vec![value as f64]).collect();
        let y = vec![0.0; 8];

        let table = DenseTable::with_canaries(x, y, 0).unwrap();
        let counts = (0..table.n_rows()).fold(BTreeMap::new(), |mut counts, row_idx| {
            *counts
                .entry(table.binned_value(0, row_idx))
                .or_insert(0usize) += 1;
            counts
        });

        assert_eq!(counts.len(), 4);
        assert!(counts.values().all(|count| *count >= 2));
    }

    /// A fixed bin count limits the number of distinct bins to the requested value.
    #[test]
    fn fixed_bins_cap_numeric_columns_to_requested_limit() {
        let x: Vec<Vec<f64>> = (0..300).map(|value| vec![value as f64]).collect();
        let y = vec![0.0; 300];

        let table = DenseTable::with_options(x, y, 0, NumericBins::Fixed(64)).unwrap();

        assert_eq!(
            (0..300)
                .map(|idx| table.binned_value(0, idx))
                .collect::<BTreeSet<_>>()
                .len(),
            64
        );
    }

    /// `NumericBins::fixed` rejects zero and anything above `MAX_NUMERIC_BINS`, and accepts
    /// the maximum itself.
    #[test]
    fn rejects_invalid_fixed_bin_count() {
        assert_eq!(
            NumericBins::fixed(0).unwrap_err(),
            DenseTableError::InvalidBinCount { requested: 0 }
        );
        // Probe the first invalid value so an off-by-one in the bound is caught.
        assert_eq!(
            NumericBins::fixed(MAX_NUMERIC_BINS + 1).unwrap_err(),
            DenseTableError::InvalidBinCount {
                requested: MAX_NUMERIC_BINS + 1
            }
        );
        assert_eq!(
            NumericBins::fixed(MAX_NUMERIC_BINS).unwrap(),
            NumericBins::Fixed(MAX_NUMERIC_BINS)
        );
    }

    /// Rows with identical raw values always share a bin, and bins follow value order.
    #[test]
    fn keeps_equal_values_in_the_same_bin() {
        let table = DenseTable::with_canaries(
            vec![vec![0.0], vec![0.0], vec![1.0], vec![1.0], vec![2.0]],
            vec![0.0; 5],
            0,
        )
        .unwrap();

        assert_eq!(table.binned_value(0, 0), table.binned_value(0, 1));
        assert_eq!(table.binned_value(0, 2), table.binned_value(0, 3));
        assert!(table.binned_value(0, 1) <= table.binned_value(0, 2));
        assert!(table.binned_value(0, 3) < table.binned_value(0, 4));
    }

    /// Binary features (and their canaries) are exposed as boolean binned columns, while
    /// the raw feature values stay readable as floats.
    #[test]
    fn stores_binary_columns_as_booleans() {
        let table = DenseTable::with_canaries(
            vec![vec![0.0, 2.0], vec![1.0, 3.0], vec![0.0, 4.0]],
            vec![0.0; 3],
            1,
        )
        .unwrap();

        assert!(table.is_binary_feature(0));
        assert!(!table.is_binary_feature(1));
        assert!(table.is_binary_binned_feature(0));
        assert!(!table.is_binary_binned_feature(1));
        assert!(table.is_binary_binned_feature(2));
        assert_eq!(table.feature_value(0, 0), 0.0);
        assert_eq!(table.feature_value(0, 1), 1.0);
        assert_eq!(table.binned_boolean_value(0, 0), Some(false));
        assert_eq!(table.binned_boolean_value(0, 1), Some(true));
    }

    /// A numeric column with only a handful of bins is stored with `u8` bin indices.
    #[test]
    fn stores_small_auto_binned_numeric_columns_as_u8() {
        let table = DenseTable::with_canaries(
            (0..8).map(|value| vec![value as f64]).collect(),
            vec![0.0; 8],
            0,
        )
        .unwrap();

        assert!(matches!(
            table.binned_feature_column(0),
            BinnedFeatureColumnRef::NumericU8(_)
        ));
    }

    /// A canary column contains the same set of bins as its source, in a different row order.
    #[test]
    fn creates_canary_columns_as_shuffled_binned_copies() {
        let table = DenseTable::with_canaries(
            vec![vec![0.0], vec![1.0], vec![2.0], vec![3.0], vec![4.0]],
            vec![0.0; 5],
            1,
        )
        .unwrap();

        assert!(matches!(
            table.binned_column_kind(1),
            BinnedColumnKind::Canary {
                source_index: 0,
                copy_index: 0
            }
        ));
        assert_eq!(
            (0..table.n_rows())
                .map(|idx| table.binned_value(0, idx))
                .collect::<BTreeSet<_>>(),
            (0..table.n_rows())
                .map(|idx| table.binned_value(1, idx))
                .collect::<BTreeSet<_>>()
        );
        assert_ne!(
            (0..table.n_rows())
                .map(|idx| table.binned_value(0, idx))
                .collect::<Vec<_>>(),
            (0..table.n_rows())
                .map(|idx| table.binned_value(1, idx))
                .collect::<Vec<_>>()
        );
    }

    /// Rows of differing length are reported with the offending row and both lengths.
    #[test]
    fn rejects_ragged_rows() {
        let err = DenseTable::new(vec![vec![1.0, 2.0], vec![3.0]], vec![1.0, 2.0]).unwrap_err();

        assert_eq!(
            err,
            DenseTableError::RaggedRows {
                row: 1,
                expected: 2,
                actual: 1,
            }
        );
    }

    /// A target vector whose length differs from the row count is rejected.
    #[test]
    fn rejects_mismatched_lengths() {
        let err = DenseTable::new(vec![vec![1.0], vec![2.0]], vec![1.0]).unwrap_err();

        assert_eq!(err, DenseTableError::MismatchedLengths { x: 2, y: 1 });
    }

    /// Building two tables from identical inputs yields identical binned values,
    /// canary columns included.
    #[test]
    fn canary_generation_is_deterministic_for_identical_inputs() {
        let left = DenseTable::with_canaries(
            vec![vec![0.0], vec![1.0], vec![2.0], vec![3.0], vec![4.0]],
            vec![0.0; 5],
            2,
        )
        .unwrap();
        let right = DenseTable::with_canaries(
            vec![vec![0.0], vec![1.0], vec![2.0], vec![3.0], vec![4.0]],
            vec![0.0; 5],
            2,
        )
        .unwrap();

        let left_values = binned_snapshot(&left);
        let right_values = binned_snapshot(&right);

        assert_eq!(left_values, right_values);
    }

    /// Canaries of a binary feature stay boolean and keep the source's number of `true` rows.
    #[test]
    fn binary_canaries_remain_boolean_and_preserve_value_counts() {
        let table = DenseTable::with_canaries(
            vec![
                vec![0.0],
                vec![1.0],
                vec![0.0],
                vec![1.0],
                vec![1.0],
                vec![0.0],
            ],
            vec![0.0; 6],
            2,
        )
        .unwrap();

        let real_true_count = (0..table.n_rows())
            .filter(|row_idx| table.binned_boolean_value(0, *row_idx) == Some(true))
            .count();

        // Index 0 is the real feature; every later binned column is one of its canaries.
        for feature_index in 1..table.binned_feature_count() {
            assert!(table.is_binary_binned_feature(feature_index));
            let canary_true_count = (0..table.n_rows())
                .filter(|row_idx| table.binned_boolean_value(feature_index, *row_idx) == Some(true))
                .count();
            assert_eq!(canary_true_count, real_true_count);
        }
    }

    /// Pins the `(bin, boundary)` pairs produced for a small column under automatic binning.
    #[test]
    fn numeric_bin_boundaries_capture_training_bin_upper_bounds() {
        let boundaries = numeric_bin_boundaries(&[1.0, 1.0, 2.0, 10.0], NumericBins::Auto);

        assert_eq!(boundaries, vec![(0, 1.0), (1, 10.0)]);
    }

    /// Flattens every binned column of `table` (feature by feature, row by row) into one
    /// vector so two tables can be compared with a single assertion.
    fn binned_snapshot(table: &DenseTable) -> Vec<u16> {
        let mut values = Vec::new();

        for feature_idx in 0..table.binned_feature_count() {
            for row_idx in 0..table.n_rows() {
                values.push(table.binned_value(feature_idx, row_idx));
            }
        }

        values
    }
}