1use arrow::array::{BooleanArray, Float64Array, UInt8Array, UInt16Array};
10use rand::seq::SliceRandom;
11use rand::{SeedableRng, rngs::StdRng};
12use std::cmp::Ordering;
13use std::error::Error;
14use std::fmt::{Display, Formatter};
15
/// Largest number of bins a numeric feature may be quantized into.
///
/// `NumericBins::fixed` rejects requests above this value and `NumericBins::Auto`
/// reports it as its cap.
pub const MAX_NUMERIC_BINS: usize = 128;
/// Number of canary copies used by the `new` constructors, which take no explicit count.
const DEFAULT_CANARIES: usize = 2;

/// Output of `preprocess_rows`: column-major feature values, the target array,
/// the row count and the feature count, in that order.
type PreprocessedRows = (Vec<Vec<f64>>, Float64Array, usize, usize);
20
/// Read-only accessors shared by `DenseTable`, `SparseTable` and `Table`.
///
/// Implementors must be `Sync`.
pub trait TableAccess: Sync {
    /// Number of rows in the table.
    fn n_rows(&self) -> usize;
    /// Number of real (non-canary) features.
    fn n_features(&self) -> usize;
    /// Number of canary copies that were requested at construction time.
    fn canaries(&self) -> usize;
    /// Upper bound on the number of numeric bins (see `NumericBins::cap`).
    fn numeric_bin_cap(&self) -> usize;
    /// Total number of binned columns, i.e. real columns plus canary columns.
    fn binned_feature_count(&self) -> usize;
    /// Raw value of feature `feature_index` at row `row_index`.
    fn feature_value(&self, feature_index: usize, row_index: usize) -> f64;
    /// Whether the raw feature at `index` is stored as a binary column.
    fn is_binary_feature(&self, index: usize) -> bool;
    /// Bin assigned to binned column `feature_index` at row `row_index`.
    fn binned_value(&self, feature_index: usize, row_index: usize) -> u16;
    /// Boolean value of a binned column; the implementations in this file return
    /// `None` when the binned column is numeric rather than binary.
    fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool>;
    /// Whether binned column `index` is a real column or a canary, and which source it came from.
    fn binned_column_kind(&self, index: usize) -> BinnedColumnKind;
    /// Whether the binned column at `index` is stored as a binary column.
    fn is_binary_binned_feature(&self, index: usize) -> bool;
    /// Target value at row `row_index`.
    fn target_value(&self, row_index: usize) -> f64;

    /// Convenience default: true when `binned_column_kind(index)` is a `Canary`.
    fn is_canary_binned_feature(&self, index: usize) -> bool {
        matches!(
            self.binned_column_kind(index),
            BinnedColumnKind::Canary { .. }
        )
    }
}
48
/// Storage layout of a `Table`, as reported by `Table::kind`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TableKind {
    /// The table wraps a `DenseTable`.
    Dense,
    /// The table wraps a `SparseTable`.
    Sparse,
}
56
/// How many bins numeric features are quantized into.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum NumericBins {
    /// Let the table pick the bin count from the data (the default).
    #[default]
    Auto,
    /// Use at most the given number of bins. Prefer `NumericBins::fixed`, which
    /// validates the count; constructing the variant directly skips that check.
    Fixed(usize),
}
65
66impl NumericBins {
67 pub fn fixed(requested: usize) -> Result<Self, DenseTableError> {
68 if requested == 0 || requested > MAX_NUMERIC_BINS {
69 return Err(DenseTableError::InvalidBinCount { requested });
70 }
71 Ok(Self::Fixed(requested))
72 }
73
74 pub fn cap(self) -> usize {
75 match self {
76 NumericBins::Auto => MAX_NUMERIC_BINS,
77 NumericBins::Fixed(requested) => requested,
78 }
79 }
80}
81
/// Column-oriented table backed by Arrow arrays; features may be numeric or binary.
#[derive(Debug, Clone)]
pub struct DenseTable {
    /// Raw feature columns, one per real feature.
    feature_columns: Vec<FeatureColumn>,
    /// Binned columns: the real columns first, followed by the canary columns.
    binned_feature_columns: Vec<BinnedFeatureColumn>,
    /// Kind of each entry of `binned_feature_columns`, index-aligned with it.
    binned_column_kinds: Vec<BinnedColumnKind>,
    /// Target values, one per row.
    target: Float64Array,
    /// Number of rows.
    n_rows: usize,
    /// Number of real features.
    n_features: usize,
    /// Number of canary copies requested per feature.
    canaries: usize,
    /// Binning setting used to build the numeric binned columns.
    numeric_bins: NumericBins,
}
98
/// Table whose features are all binary, stored as sorted lists of set row indices.
#[derive(Debug, Clone)]
pub struct SparseTable {
    /// Raw feature columns, one per real feature.
    feature_columns: Vec<SparseBinaryColumn>,
    /// Binned columns: clones of the real columns first, followed by the canary columns.
    binned_feature_columns: Vec<SparseBinaryColumn>,
    /// Kind of each entry of `binned_feature_columns`, index-aligned with it.
    binned_column_kinds: Vec<BinnedColumnKind>,
    /// Target values, one per row.
    target: Float64Array,
    /// Number of rows.
    n_rows: usize,
    /// Number of real features.
    n_features: usize,
    /// Number of canary copies requested per feature.
    canaries: usize,
    /// Binning setting; only reported through `numeric_bin_cap` in the code shown here.
    numeric_bins: NumericBins,
}
114
/// Binary column stored as the list of row indices whose value is `true`.
///
/// `value` relies on `row_indices` being sorted in ascending order.
#[derive(Debug, Clone)]
struct SparseBinaryColumn {
    row_indices: Vec<usize>,
}

impl SparseBinaryColumn {
    /// Returns `true` when `row_index` is one of the stored (set) rows.
    fn value(&self, row_index: usize) -> bool {
        let lookup = self.row_indices.binary_search(&row_index);
        matches!(lookup, Ok(_))
    }
}
125
/// A feature table in either dense or sparse layout.
///
/// `Table::with_options` picks `Sparse` when every column is binary and `Dense` otherwise.
#[derive(Debug, Clone)]
pub enum Table {
    /// Arrow-backed layout that supports numeric and binary features.
    Dense(DenseTable),
    /// Index-list layout used when all features are binary.
    Sparse(SparseTable),
}
133
/// Raw (unbinned) storage of one dense feature column.
#[derive(Debug, Clone)]
enum FeatureColumn {
    /// Arbitrary floating-point values.
    Numeric(Float64Array),
    /// Column whose values were all exactly `0.0` or `1.0` (see `is_binary_column`).
    Binary(BooleanArray),
}
139
/// Binned storage of one dense column (real or canary).
#[derive(Debug, Clone)]
enum BinnedFeatureColumn {
    /// Numeric bins that all fit in a `u8`.
    NumericU8(UInt8Array),
    /// Numeric bins where at least one value exceeds `u8::MAX`.
    NumericU16(UInt16Array),
    /// Binary column kept as booleans instead of bins.
    Binary(BooleanArray),
}
146
/// Borrowed view of a raw dense feature column, returned by `DenseTable::feature_column`.
#[derive(Debug, Clone, Copy)]
pub enum FeatureColumnRef<'a> {
    /// Borrowed numeric column.
    Numeric(&'a Float64Array),
    /// Borrowed binary column.
    Binary(&'a BooleanArray),
}
152
/// Borrowed view of a binned dense column, returned by `DenseTable::binned_feature_column`.
#[derive(Debug, Clone, Copy)]
pub enum BinnedFeatureColumnRef<'a> {
    /// Borrowed numeric bins stored as `u8`.
    NumericU8(&'a UInt8Array),
    /// Borrowed numeric bins stored as `u16`.
    NumericU16(&'a UInt16Array),
    /// Borrowed binary column.
    Binary(&'a BooleanArray),
}
159
/// Origin of a binned column.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinnedColumnKind {
    /// The binned form of real feature `source_index`.
    Real { source_index: usize },
    /// A shuffled copy of the binned column of feature `source_index`.
    Canary {
        /// Index of the real feature this canary was shuffled from.
        source_index: usize,
        /// Which copy this is, counted from zero up to the requested canary count.
        copy_index: usize,
    },
}
170
/// Errors raised while building tables or bin settings.
///
/// Despite the name, the `SparseTable` constructors return this type as well.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DenseTableError {
    /// The number of rows in `X` (`x`) differs from the number of targets (`y`).
    MismatchedLengths {
        x: usize,
        y: usize,
    },
    /// Row `row` has `actual` columns where `expected` were required.
    /// `SparseTable::from_sparse_binary_columns_with_options` also reuses this
    /// variant when the declared feature count differs from the supplied column count.
    RaggedRows {
        row: usize,
        expected: usize,
        actual: usize,
    },
    /// Column `column` cannot be stored in a `SparseTable`: it contains a value
    /// other than 0/1, or (for index-list input) a row index outside the table.
    NonBinaryColumn {
        column: usize,
    },
    /// `NumericBins::fixed` was called with zero or more than `MAX_NUMERIC_BINS`.
    InvalidBinCount {
        requested: usize,
    },
}
189
190impl Display for DenseTableError {
191 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
192 match self {
193 DenseTableError::MismatchedLengths { x, y } => write!(
194 f,
195 "Mismatched lengths: X has {} rows while y has {} values.",
196 x, y
197 ),
198 DenseTableError::RaggedRows {
199 row,
200 expected,
201 actual,
202 } => write!(
203 f,
204 "Ragged row at index {}: expected {} columns, found {}.",
205 row, expected, actual
206 ),
207 DenseTableError::NonBinaryColumn { column } => write!(
208 f,
209 "SparseTable requires binary features, but column {} contains non-binary values.",
210 column
211 ),
212 DenseTableError::InvalidBinCount { requested } => write!(
213 f,
214 "Invalid bins value {}. Expected 'auto' or an integer between 1 and {}.",
215 requested, MAX_NUMERIC_BINS
216 ),
217 }
218 }
219}
220
// Marker impl: the default `Error` methods are used, so no source error is exposed.
impl Error for DenseTableError {}
222
223impl DenseTable {
224 pub fn new(x: Vec<Vec<f64>>, y: Vec<f64>) -> Result<Self, DenseTableError> {
227 Self::with_canaries(x, y, DEFAULT_CANARIES)
228 }
229
230 pub fn with_canaries(
232 x: Vec<Vec<f64>>,
233 y: Vec<f64>,
234 canaries: usize,
235 ) -> Result<Self, DenseTableError> {
236 Self::with_options(x, y, canaries, NumericBins::Auto)
237 }
238
239 pub fn with_options(
241 x: Vec<Vec<f64>>,
242 y: Vec<f64>,
243 canaries: usize,
244 numeric_bins: NumericBins,
245 ) -> Result<Self, DenseTableError> {
246 let (columns, target, n_rows, n_features) = preprocess_rows(&x, y)?;
247 Ok(Self::from_columns(
248 &columns,
249 target,
250 n_rows,
251 n_features,
252 canaries,
253 numeric_bins,
254 ))
255 }
256
257 fn from_columns(
258 columns: &[Vec<f64>],
259 target: Float64Array,
260 n_rows: usize,
261 n_features: usize,
262 canaries: usize,
263 numeric_bins: NumericBins,
264 ) -> Self {
265 let feature_columns = columns
266 .iter()
267 .map(|column| build_feature_column(column))
268 .collect();
269
270 let real_binned_columns: Vec<BinnedFeatureColumn> = columns
271 .iter()
272 .map(|column| build_binned_feature_column(column, numeric_bins))
273 .collect();
274 let canary_columns: Vec<(BinnedColumnKind, BinnedFeatureColumn)> = (0..canaries)
275 .flat_map(|copy_index| {
276 real_binned_columns
277 .iter()
278 .enumerate()
279 .map(move |(source_index, column)| {
280 (
281 BinnedColumnKind::Canary {
282 source_index,
283 copy_index,
284 },
285 shuffle_canary_column(column, copy_index, source_index),
289 )
290 })
291 })
292 .collect();
293
294 let (binned_column_kinds, binned_feature_columns): (Vec<_>, Vec<_>) = (0..n_features)
295 .map(|source_index| BinnedColumnKind::Real { source_index })
296 .zip(real_binned_columns)
297 .chain(canary_columns)
298 .unzip();
299
300 Self {
301 feature_columns,
302 binned_feature_columns,
303 binned_column_kinds,
304 target,
305 n_rows,
306 n_features,
307 canaries,
308 numeric_bins,
309 }
310 }
311
312 #[inline]
313 pub fn n_rows(&self) -> usize {
314 self.n_rows
315 }
316
317 #[inline]
318 pub fn n_features(&self) -> usize {
319 self.n_features
320 }
321
322 #[inline]
323 pub fn canaries(&self) -> usize {
324 self.canaries
325 }
326
327 #[inline]
328 pub fn numeric_bin_cap(&self) -> usize {
329 self.numeric_bins.cap()
330 }
331
332 #[inline]
333 pub fn binned_feature_count(&self) -> usize {
334 self.binned_feature_columns.len()
335 }
336
337 #[inline]
338 pub fn feature_column(&self, index: usize) -> FeatureColumnRef<'_> {
339 match &self.feature_columns[index] {
340 FeatureColumn::Numeric(column) => FeatureColumnRef::Numeric(column),
341 FeatureColumn::Binary(column) => FeatureColumnRef::Binary(column),
342 }
343 }
344
345 #[inline]
346 pub fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
347 match &self.feature_columns[feature_index] {
348 FeatureColumn::Numeric(column) => column.value(row_index),
349 FeatureColumn::Binary(column) => f64::from(u8::from(column.value(row_index))),
350 }
351 }
352
353 #[inline]
354 pub fn is_binary_feature(&self, index: usize) -> bool {
355 matches!(self.feature_columns[index], FeatureColumn::Binary(_))
356 }
357
358 #[inline]
359 pub fn binned_feature_column(&self, index: usize) -> BinnedFeatureColumnRef<'_> {
360 match &self.binned_feature_columns[index] {
361 BinnedFeatureColumn::NumericU8(column) => BinnedFeatureColumnRef::NumericU8(column),
362 BinnedFeatureColumn::NumericU16(column) => BinnedFeatureColumnRef::NumericU16(column),
363 BinnedFeatureColumn::Binary(column) => BinnedFeatureColumnRef::Binary(column),
364 }
365 }
366
367 #[inline]
368 pub fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
369 match &self.binned_feature_columns[feature_index] {
370 BinnedFeatureColumn::NumericU8(column) => u16::from(column.value(row_index)),
371 BinnedFeatureColumn::NumericU16(column) => column.value(row_index),
372 BinnedFeatureColumn::Binary(column) => u16::from(u8::from(column.value(row_index))),
373 }
374 }
375
376 #[inline]
377 pub fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
378 match &self.binned_feature_columns[feature_index] {
379 BinnedFeatureColumn::Binary(column) => Some(column.value(row_index)),
380 BinnedFeatureColumn::NumericU8(_) | BinnedFeatureColumn::NumericU16(_) => None,
381 }
382 }
383
384 #[inline]
385 pub fn binned_column_kind(&self, index: usize) -> BinnedColumnKind {
386 self.binned_column_kinds[index]
387 }
388
389 #[inline]
390 pub fn is_canary_binned_feature(&self, index: usize) -> bool {
391 matches!(
392 self.binned_column_kinds[index],
393 BinnedColumnKind::Canary { .. }
394 )
395 }
396
397 #[inline]
398 pub fn is_binary_binned_feature(&self, index: usize) -> bool {
399 matches!(
400 self.binned_feature_columns[index],
401 BinnedFeatureColumn::Binary(_)
402 )
403 }
404
405 #[inline]
406 pub fn target(&self) -> &Float64Array {
407 &self.target
408 }
409}
410
411impl SparseTable {
412 pub fn new(x: Vec<Vec<f64>>, y: Vec<f64>) -> Result<Self, DenseTableError> {
413 Self::with_canaries(x, y, DEFAULT_CANARIES)
414 }
415
416 pub fn with_canaries(
417 x: Vec<Vec<f64>>,
418 y: Vec<f64>,
419 canaries: usize,
420 ) -> Result<Self, DenseTableError> {
421 Self::with_options(x, y, canaries, NumericBins::Auto)
422 }
423
424 pub fn with_options(
425 x: Vec<Vec<f64>>,
426 y: Vec<f64>,
427 canaries: usize,
428 numeric_bins: NumericBins,
429 ) -> Result<Self, DenseTableError> {
430 let (columns, target, n_rows, n_features) = preprocess_rows(&x, y)?;
431 validate_binary_columns(&columns)?;
432 Ok(Self::from_columns(
433 &columns,
434 target,
435 n_rows,
436 n_features,
437 canaries,
438 numeric_bins,
439 ))
440 }
441
442 fn from_columns(
443 columns: &[Vec<f64>],
444 target: Float64Array,
445 n_rows: usize,
446 n_features: usize,
447 canaries: usize,
448 numeric_bins: NumericBins,
449 ) -> Self {
450 let feature_columns: Vec<SparseBinaryColumn> = columns
451 .iter()
452 .map(|column| sparse_binary_column_from_values(column))
453 .collect();
454
455 let canary_columns: Vec<(BinnedColumnKind, SparseBinaryColumn)> = (0..canaries)
456 .flat_map(|copy_index| {
457 feature_columns
458 .iter()
459 .enumerate()
460 .map(move |(source_index, column)| {
461 (
462 BinnedColumnKind::Canary {
463 source_index,
464 copy_index,
465 },
466 shuffle_sparse_binary_column(column, n_rows, copy_index, source_index),
467 )
468 })
469 })
470 .collect();
471
472 let (binned_column_kinds, binned_feature_columns): (Vec<_>, Vec<_>) = (0..n_features)
473 .map(|source_index| BinnedColumnKind::Real { source_index })
474 .zip(feature_columns.iter().cloned())
475 .chain(canary_columns)
476 .unzip();
477
478 Self {
479 feature_columns,
480 binned_feature_columns,
481 binned_column_kinds,
482 target,
483 n_rows,
484 n_features,
485 canaries,
486 numeric_bins,
487 }
488 }
489
490 pub fn from_sparse_binary_columns(
491 n_rows: usize,
492 n_features: usize,
493 columns: Vec<Vec<usize>>,
494 y: Vec<f64>,
495 canaries: usize,
496 ) -> Result<Self, DenseTableError> {
497 Self::from_sparse_binary_columns_with_options(
498 n_rows,
499 n_features,
500 columns,
501 y,
502 canaries,
503 NumericBins::Auto,
504 )
505 }
506
507 pub fn from_sparse_binary_columns_with_options(
508 n_rows: usize,
509 n_features: usize,
510 columns: Vec<Vec<usize>>,
511 y: Vec<f64>,
512 canaries: usize,
513 numeric_bins: NumericBins,
514 ) -> Result<Self, DenseTableError> {
515 if n_rows != y.len() {
516 return Err(DenseTableError::MismatchedLengths {
517 x: n_rows,
518 y: y.len(),
519 });
520 }
521 if n_features != columns.len() {
522 return Err(DenseTableError::RaggedRows {
523 row: columns.len(),
524 expected: n_features,
525 actual: columns.len(),
526 });
527 }
528
529 let feature_columns = columns
530 .into_iter()
531 .enumerate()
532 .map(|(column_idx, mut row_indices)| {
533 row_indices.sort_unstable();
534 row_indices.dedup();
535 if row_indices.iter().any(|row_idx| *row_idx >= n_rows) {
536 return Err(DenseTableError::NonBinaryColumn { column: column_idx });
537 }
538 Ok(SparseBinaryColumn { row_indices })
539 })
540 .collect::<Result<Vec<_>, _>>()?;
541
542 let canary_columns: Vec<(BinnedColumnKind, SparseBinaryColumn)> = (0..canaries)
543 .flat_map(|copy_index| {
544 feature_columns
545 .iter()
546 .enumerate()
547 .map(move |(source_index, column)| {
548 (
549 BinnedColumnKind::Canary {
550 source_index,
551 copy_index,
552 },
553 shuffle_sparse_binary_column(column, n_rows, copy_index, source_index),
554 )
555 })
556 })
557 .collect();
558
559 let (binned_column_kinds, binned_feature_columns): (Vec<_>, Vec<_>) = (0..n_features)
560 .map(|source_index| BinnedColumnKind::Real { source_index })
561 .zip(feature_columns.iter().cloned())
562 .chain(canary_columns)
563 .unzip();
564
565 Ok(Self {
566 feature_columns,
567 binned_feature_columns,
568 binned_column_kinds,
569 target: Float64Array::from(y),
570 n_rows,
571 n_features,
572 canaries,
573 numeric_bins,
574 })
575 }
576
577 #[inline]
578 pub fn n_rows(&self) -> usize {
579 self.n_rows
580 }
581
582 #[inline]
583 pub fn n_features(&self) -> usize {
584 self.n_features
585 }
586
587 #[inline]
588 pub fn canaries(&self) -> usize {
589 self.canaries
590 }
591
592 #[inline]
593 pub fn numeric_bin_cap(&self) -> usize {
594 self.numeric_bins.cap()
595 }
596
597 #[inline]
598 pub fn binned_feature_count(&self) -> usize {
599 self.binned_feature_columns.len()
600 }
601
602 #[inline]
603 pub fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
604 f64::from(u8::from(
605 self.feature_columns[feature_index].value(row_index),
606 ))
607 }
608
609 #[inline]
610 pub fn is_binary_feature(&self, _index: usize) -> bool {
611 true
612 }
613
614 #[inline]
615 pub fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
616 u16::from(u8::from(
617 self.binned_feature_columns[feature_index].value(row_index),
618 ))
619 }
620
621 #[inline]
622 pub fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
623 Some(self.binned_feature_columns[feature_index].value(row_index))
624 }
625
626 #[inline]
627 pub fn binned_column_kind(&self, index: usize) -> BinnedColumnKind {
628 self.binned_column_kinds[index]
629 }
630
631 #[inline]
632 pub fn is_canary_binned_feature(&self, index: usize) -> bool {
633 matches!(
634 self.binned_column_kinds[index],
635 BinnedColumnKind::Canary { .. }
636 )
637 }
638
639 #[inline]
640 pub fn is_binary_binned_feature(&self, _index: usize) -> bool {
641 true
642 }
643
644 #[inline]
645 pub fn target(&self) -> &Float64Array {
646 &self.target
647 }
648}
649
650impl Table {
651 pub fn new(x: Vec<Vec<f64>>, y: Vec<f64>) -> Result<Self, DenseTableError> {
652 Self::with_canaries(x, y, DEFAULT_CANARIES)
653 }
654
655 pub fn with_canaries(
656 x: Vec<Vec<f64>>,
657 y: Vec<f64>,
658 canaries: usize,
659 ) -> Result<Self, DenseTableError> {
660 Self::with_options(x, y, canaries, NumericBins::Auto)
661 }
662
663 pub fn with_options(
664 x: Vec<Vec<f64>>,
665 y: Vec<f64>,
666 canaries: usize,
667 numeric_bins: NumericBins,
668 ) -> Result<Self, DenseTableError> {
669 let (columns, target, n_rows, n_features) = preprocess_rows(&x, y)?;
670
671 if columns.iter().all(|column| is_binary_column(column)) {
672 Ok(Self::Sparse(SparseTable::from_columns(
673 &columns,
674 target,
675 n_rows,
676 n_features,
677 canaries,
678 numeric_bins,
679 )))
680 } else {
681 Ok(Self::Dense(DenseTable::from_columns(
682 &columns,
683 target,
684 n_rows,
685 n_features,
686 canaries,
687 numeric_bins,
688 )))
689 }
690 }
691
692 pub fn kind(&self) -> TableKind {
693 match self {
694 Table::Dense(_) => TableKind::Dense,
695 Table::Sparse(_) => TableKind::Sparse,
696 }
697 }
698
699 pub fn as_dense(&self) -> Option<&DenseTable> {
700 match self {
701 Table::Dense(table) => Some(table),
702 Table::Sparse(_) => None,
703 }
704 }
705
706 pub fn as_sparse(&self) -> Option<&SparseTable> {
707 match self {
708 Table::Dense(_) => None,
709 Table::Sparse(table) => Some(table),
710 }
711 }
712}
713
714impl TableAccess for DenseTable {
715 fn n_rows(&self) -> usize {
716 self.n_rows()
717 }
718
719 fn n_features(&self) -> usize {
720 self.n_features()
721 }
722
723 fn canaries(&self) -> usize {
724 self.canaries()
725 }
726
727 fn numeric_bin_cap(&self) -> usize {
728 self.numeric_bin_cap()
729 }
730
731 fn binned_feature_count(&self) -> usize {
732 self.binned_feature_count()
733 }
734
735 fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
736 self.feature_value(feature_index, row_index)
737 }
738
739 fn is_binary_feature(&self, index: usize) -> bool {
740 self.is_binary_feature(index)
741 }
742
743 fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
744 self.binned_value(feature_index, row_index)
745 }
746
747 fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
748 self.binned_boolean_value(feature_index, row_index)
749 }
750
751 fn binned_column_kind(&self, index: usize) -> BinnedColumnKind {
752 self.binned_column_kind(index)
753 }
754
755 fn is_binary_binned_feature(&self, index: usize) -> bool {
756 self.is_binary_binned_feature(index)
757 }
758
759 fn target_value(&self, row_index: usize) -> f64 {
760 self.target().value(row_index)
761 }
762}
763
764impl TableAccess for SparseTable {
765 fn n_rows(&self) -> usize {
766 self.n_rows()
767 }
768
769 fn n_features(&self) -> usize {
770 self.n_features()
771 }
772
773 fn canaries(&self) -> usize {
774 self.canaries()
775 }
776
777 fn numeric_bin_cap(&self) -> usize {
778 self.numeric_bin_cap()
779 }
780
781 fn binned_feature_count(&self) -> usize {
782 self.binned_feature_count()
783 }
784
785 fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
786 self.feature_value(feature_index, row_index)
787 }
788
789 fn is_binary_feature(&self, index: usize) -> bool {
790 self.is_binary_feature(index)
791 }
792
793 fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
794 self.binned_value(feature_index, row_index)
795 }
796
797 fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
798 self.binned_boolean_value(feature_index, row_index)
799 }
800
801 fn binned_column_kind(&self, index: usize) -> BinnedColumnKind {
802 self.binned_column_kind(index)
803 }
804
805 fn is_binary_binned_feature(&self, index: usize) -> bool {
806 self.is_binary_binned_feature(index)
807 }
808
809 fn target_value(&self, row_index: usize) -> f64 {
810 self.target().value(row_index)
811 }
812}
813
814impl TableAccess for Table {
815 fn n_rows(&self) -> usize {
816 match self {
817 Table::Dense(table) => table.n_rows(),
818 Table::Sparse(table) => table.n_rows(),
819 }
820 }
821
822 fn n_features(&self) -> usize {
823 match self {
824 Table::Dense(table) => table.n_features(),
825 Table::Sparse(table) => table.n_features(),
826 }
827 }
828
829 fn canaries(&self) -> usize {
830 match self {
831 Table::Dense(table) => table.canaries(),
832 Table::Sparse(table) => table.canaries(),
833 }
834 }
835
836 fn numeric_bin_cap(&self) -> usize {
837 match self {
838 Table::Dense(table) => table.numeric_bin_cap(),
839 Table::Sparse(table) => table.numeric_bin_cap(),
840 }
841 }
842
843 fn binned_feature_count(&self) -> usize {
844 match self {
845 Table::Dense(table) => table.binned_feature_count(),
846 Table::Sparse(table) => table.binned_feature_count(),
847 }
848 }
849
850 fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
851 match self {
852 Table::Dense(table) => table.feature_value(feature_index, row_index),
853 Table::Sparse(table) => table.feature_value(feature_index, row_index),
854 }
855 }
856
857 fn is_binary_feature(&self, index: usize) -> bool {
858 match self {
859 Table::Dense(table) => table.is_binary_feature(index),
860 Table::Sparse(table) => table.is_binary_feature(index),
861 }
862 }
863
864 fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
865 match self {
866 Table::Dense(table) => table.binned_value(feature_index, row_index),
867 Table::Sparse(table) => table.binned_value(feature_index, row_index),
868 }
869 }
870
871 fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
872 match self {
873 Table::Dense(table) => table.binned_boolean_value(feature_index, row_index),
874 Table::Sparse(table) => table.binned_boolean_value(feature_index, row_index),
875 }
876 }
877
878 fn binned_column_kind(&self, index: usize) -> BinnedColumnKind {
879 match self {
880 Table::Dense(table) => table.binned_column_kind(index),
881 Table::Sparse(table) => table.binned_column_kind(index),
882 }
883 }
884
885 fn is_binary_binned_feature(&self, index: usize) -> bool {
886 match self {
887 Table::Dense(table) => table.is_binary_binned_feature(index),
888 Table::Sparse(table) => table.is_binary_binned_feature(index),
889 }
890 }
891
892 fn target_value(&self, row_index: usize) -> f64 {
893 match self {
894 Table::Dense(table) => table.target().value(row_index),
895 Table::Sparse(table) => table.target().value(row_index),
896 }
897 }
898}
899
900fn preprocess_rows(x: &[Vec<f64>], y: Vec<f64>) -> Result<PreprocessedRows, DenseTableError> {
901 validate_shape(x, &y)?;
902 let n_rows = x.len();
903 let n_features = x.first().map_or(0, Vec::len);
904 let columns = collect_columns(x, n_features);
905 Ok((columns, Float64Array::from(y), n_rows, n_features))
906}
907
908fn validate_shape(x: &[Vec<f64>], y: &[f64]) -> Result<(), DenseTableError> {
909 if x.len() != y.len() {
910 return Err(DenseTableError::MismatchedLengths {
911 x: x.len(),
912 y: y.len(),
913 });
914 }
915
916 let n_features = x.first().map_or(0, Vec::len);
917 for (row_idx, row) in x.iter().enumerate() {
918 if row.len() != n_features {
919 return Err(DenseTableError::RaggedRows {
920 row: row_idx,
921 expected: n_features,
922 actual: row.len(),
923 });
924 }
925 }
926
927 Ok(())
928}
929
/// Transposes row-major `x` into `n_features` column vectors.
///
/// # Panics
/// Panics when a row has fewer than `n_features` entries.
fn collect_columns(x: &[Vec<f64>], n_features: usize) -> Vec<Vec<f64>> {
    let mut columns = Vec::with_capacity(n_features);
    for col_idx in 0..n_features {
        let mut column = Vec::with_capacity(x.len());
        for row in x {
            column.push(row[col_idx]);
        }
        columns.push(column);
    }
    columns
}
935
936fn validate_binary_columns(columns: &[Vec<f64>]) -> Result<(), DenseTableError> {
937 for (column_idx, column) in columns.iter().enumerate() {
938 if !is_binary_column(column) {
939 return Err(DenseTableError::NonBinaryColumn { column: column_idx });
940 }
941 }
942
943 Ok(())
944}
945
946fn build_feature_column(values: &[f64]) -> FeatureColumn {
947 if is_binary_column(values) {
948 FeatureColumn::Binary(BooleanArray::from(to_binary_values(values)))
949 } else {
950 FeatureColumn::Numeric(Float64Array::from(values.to_vec()))
951 }
952}
953
954fn build_binned_feature_column(values: &[f64], numeric_bins: NumericBins) -> BinnedFeatureColumn {
955 if is_binary_column(values) {
956 BinnedFeatureColumn::Binary(BooleanArray::from(to_binary_values(values)))
957 } else {
958 let bins = bin_numeric_column(values, numeric_bins);
959 if bins.iter().all(|value| *value <= u16::from(u8::MAX)) {
960 BinnedFeatureColumn::NumericU8(UInt8Array::from(
961 bins.into_iter()
962 .map(|value| value as u8)
963 .collect::<Vec<_>>(),
964 ))
965 } else {
966 BinnedFeatureColumn::NumericU16(UInt16Array::from(bins))
967 }
968 }
969}
970
/// Returns `true` when every value is exactly `0.0` or `1.0` under `total_cmp`.
///
/// Because `total_cmp` is used, `-0.0` and NaN do not count as binary.
/// An empty column is binary.
fn is_binary_column(values: &[f64]) -> bool {
    let is_exactly =
        |value: f64, expected: f64| value.total_cmp(&expected) == Ordering::Equal;

    !values
        .iter()
        .any(|&value| !is_exactly(value, 0.0) && !is_exactly(value, 1.0))
}
977
/// Maps each value to `true` when it is exactly `1.0` (under `total_cmp`) and `false` otherwise.
fn to_binary_values(values: &[f64]) -> Vec<bool> {
    let mut flags = Vec::with_capacity(values.len());
    for value in values {
        flags.push(matches!(value.total_cmp(&1.0), Ordering::Equal));
    }
    flags
}
984
985fn sparse_binary_column_from_values(values: &[f64]) -> SparseBinaryColumn {
986 SparseBinaryColumn {
987 row_indices: values
988 .iter()
989 .enumerate()
990 .filter_map(|(row_idx, value)| {
991 (value.total_cmp(&1.0) == Ordering::Equal).then_some(row_idx)
992 })
993 .collect(),
994 }
995}
996
997pub fn numeric_bin_boundaries(values: &[f64], numeric_bins: NumericBins) -> Vec<(u16, f64)> {
998 if values.is_empty() {
999 return Vec::new();
1000 }
1001
1002 let mut ranked_values: Vec<(usize, f64)> = values.iter().copied().enumerate().collect();
1003 ranked_values.sort_by(|left, right| left.1.total_cmp(&right.1));
1004
1005 let unique_value_count = ranked_values
1006 .iter()
1007 .map(|(_row_idx, value)| *value)
1008 .fold(Vec::<f64>::new(), |mut unique_values, value| {
1009 let is_new_value = unique_values
1010 .last()
1011 .is_none_or(|last_value| last_value.total_cmp(&value) != Ordering::Equal);
1012 if is_new_value {
1013 unique_values.push(value);
1014 }
1015 unique_values
1016 })
1017 .len();
1018
1019 let bin_count = resolved_numeric_bin_count(values.len(), unique_value_count, numeric_bins);
1020 let mut unique_rank = 0usize;
1021 let mut start = 0usize;
1022 let mut boundaries = Vec::new();
1023
1024 while start < ranked_values.len() {
1025 let current_value = ranked_values[start].1;
1026 let end = ranked_values[start..]
1027 .iter()
1028 .position(|(_row_idx, value)| value.total_cmp(¤t_value) != Ordering::Equal)
1029 .map_or(ranked_values.len(), |offset| start + offset);
1030
1031 let bin = match numeric_bins {
1032 NumericBins::Auto => ((start * bin_count) / values.len()) as u16,
1033 NumericBins::Fixed(_) => {
1034 let max_bin = (bin_count - 1) as u16;
1035 if unique_value_count == 1 {
1036 0
1037 } else {
1038 ((unique_rank * usize::from(max_bin)) / (unique_value_count - 1)) as u16
1039 }
1040 }
1041 };
1042
1043 if let Some((last_bin, last_upper_bound)) = boundaries.last_mut() {
1044 if *last_bin == bin {
1045 *last_upper_bound = current_value;
1046 } else {
1047 boundaries.push((bin, current_value));
1048 }
1049 } else {
1050 boundaries.push((bin, current_value));
1051 }
1052
1053 unique_rank += 1;
1054 start = end;
1055 }
1056
1057 boundaries
1058}
1059
1060fn bin_numeric_column(values: &[f64], numeric_bins: NumericBins) -> Vec<u16> {
1061 if values.is_empty() {
1062 return Vec::new();
1063 }
1064
1065 let mut ranked_values: Vec<(usize, f64)> = values.iter().copied().enumerate().collect();
1066 ranked_values.sort_by(|left, right| left.1.total_cmp(&right.1));
1067
1068 let unique_value_count = ranked_values
1069 .iter()
1070 .map(|(_row_idx, value)| *value)
1071 .fold(Vec::<f64>::new(), |mut unique_values, value| {
1072 let is_new_value = unique_values
1073 .last()
1074 .is_none_or(|last_value| last_value.total_cmp(&value) != Ordering::Equal);
1075 if is_new_value {
1076 unique_values.push(value);
1077 }
1078 unique_values
1079 })
1080 .len();
1081
1082 let mut bins = vec![0u16; values.len()];
1083 let bin_count = resolved_numeric_bin_count(values.len(), unique_value_count, numeric_bins);
1084 let mut unique_rank = 0usize;
1085 let mut start = 0usize;
1086
1087 while start < ranked_values.len() {
1088 let current_value = ranked_values[start].1;
1089 let end = ranked_values[start..]
1090 .iter()
1091 .position(|(_row_idx, value)| value.total_cmp(¤t_value) != Ordering::Equal)
1092 .map_or(ranked_values.len(), |offset| start + offset);
1093
1094 let bin = match numeric_bins {
1095 NumericBins::Auto => ((start * bin_count) / values.len()) as u16,
1096 NumericBins::Fixed(_) => {
1097 let max_bin = (bin_count - 1) as u16;
1098 if unique_value_count == 1 {
1099 0
1100 } else {
1101 ((unique_rank * usize::from(max_bin)) / (unique_value_count - 1)) as u16
1102 }
1103 }
1104 };
1105
1106 for (row_idx, _value) in &ranked_values[start..end] {
1107 bins[*row_idx] = bin;
1108 }
1109
1110 unique_rank += 1;
1111 start = end;
1112 }
1113
1114 bins
1115}
1116
1117fn resolved_numeric_bin_count(
1118 value_count: usize,
1119 unique_value_count: usize,
1120 numeric_bins: NumericBins,
1121) -> usize {
1122 match numeric_bins {
1123 NumericBins::Auto => {
1124 let populated_bin_cap = (value_count / 2).max(1);
1125 let capped_unique_values = unique_value_count
1126 .min(MAX_NUMERIC_BINS)
1127 .min(populated_bin_cap)
1128 .max(1);
1129 highest_power_of_two_at_most(capped_unique_values)
1130 }
1131 NumericBins::Fixed(requested) => requested.min(unique_value_count).max(1),
1132 }
1133}
1134
/// Largest power of two that is `<= value`; returns 1 for `value` of 0 or 1.
fn highest_power_of_two_at_most(value: usize) -> usize {
    match value.checked_ilog2() {
        Some(exponent) => 1usize << exponent,
        // `checked_ilog2` is `None` only for zero.
        None => 1,
    }
}
1142
1143fn shuffle_canary_column(
1144 values: &BinnedFeatureColumn,
1145 copy_index: usize,
1146 source_index: usize,
1147) -> BinnedFeatureColumn {
1148 match values {
1149 BinnedFeatureColumn::NumericU8(values) => {
1150 let mut shuffled = (0..values.len())
1151 .map(|idx| values.value(idx))
1152 .collect::<Vec<_>>();
1153 shuffle_values(&mut shuffled, copy_index, source_index);
1154 BinnedFeatureColumn::NumericU8(UInt8Array::from(shuffled))
1155 }
1156 BinnedFeatureColumn::NumericU16(values) => {
1157 let mut shuffled = (0..values.len())
1158 .map(|idx| values.value(idx))
1159 .collect::<Vec<_>>();
1160 shuffle_values(&mut shuffled, copy_index, source_index);
1161 BinnedFeatureColumn::NumericU16(UInt16Array::from(shuffled))
1162 }
1163 BinnedFeatureColumn::Binary(values) => {
1164 BinnedFeatureColumn::Binary(shuffle_boolean_array(values, copy_index, source_index))
1165 }
1166 }
1167}
1168
1169fn shuffle_boolean_array(
1170 values: &BooleanArray,
1171 copy_index: usize,
1172 source_index: usize,
1173) -> BooleanArray {
1174 let mut shuffled = (0..values.len())
1175 .map(|idx| values.value(idx))
1176 .collect::<Vec<_>>();
1177 shuffle_values(&mut shuffled, copy_index, source_index);
1178 BooleanArray::from(shuffled)
1179}
1180
1181fn shuffle_sparse_binary_column(
1182 values: &SparseBinaryColumn,
1183 n_rows: usize,
1184 copy_index: usize,
1185 source_index: usize,
1186) -> SparseBinaryColumn {
1187 let mut dense = vec![false; n_rows];
1188 for row_idx in &values.row_indices {
1189 dense[*row_idx] = true;
1190 }
1191 shuffle_values(&mut dense, copy_index, source_index);
1192 SparseBinaryColumn {
1193 row_indices: dense
1194 .into_iter()
1195 .enumerate()
1196 .filter_map(|(row_idx, value)| value.then_some(row_idx))
1197 .collect(),
1198 }
1199}
1200
1201fn shuffle_values<T>(values: &mut [T], copy_index: usize, source_index: usize) {
1202 let seed = 0xA11CE5EED_u64
1203 ^ ((copy_index as u64) << 32)
1204 ^ (source_index as u64)
1205 ^ ((values.len() as u64) << 16);
1206 let mut rng = StdRng::seed_from_u64(seed);
1207 values.shuffle(&mut rng);
1208}
1209
1210#[cfg(test)]
1211mod tests {
1212 use super::*;
1213 use std::collections::{BTreeMap, BTreeSet};
1214
    // Smoke test for `DenseTable::new`: shape accessors, raw values, targets,
    // and the real/canary split of the binned columns.
    #[test]
    fn builds_arrow_backed_dense_table() {
        let table =
            DenseTable::new(vec![vec![0.0, 10.0], vec![1.0, 20.0]], vec![3.0, 5.0]).unwrap();

        assert_eq!(table.n_rows(), 2);
        assert_eq!(table.n_features(), 2);
        // `new` uses DEFAULT_CANARIES.
        assert_eq!(table.canaries(), 2);
        // 2 real columns + 2 canary copies of each of the 2 features.
        assert_eq!(table.binned_feature_count(), 6);
        assert_eq!(table.feature_value(0, 0), 0.0);
        assert_eq!(table.feature_value(0, 1), 1.0);
        assert_eq!(table.target().value(0), 3.0);
        assert_eq!(table.target().value(1), 5.0);
        // Real columns come first; the canaries start right after them.
        assert!(!table.is_canary_binned_feature(0));
        assert!(table.is_canary_binned_feature(2));
    }
1231
    // When every feature is 0/1, `Table` must select the sparse layout.
    #[test]
    fn builds_sparse_table_for_all_binary_features() {
        let table = Table::with_canaries(
            vec![vec![0.0, 1.0], vec![1.0, 0.0], vec![1.0, 1.0]],
            vec![0.0, 1.0, 1.0],
            1,
        )
        .unwrap();

        assert_eq!(table.kind(), TableKind::Sparse);
        assert!(table.is_binary_feature(0));
        assert!(table.is_binary_feature(1));
        assert!(table.is_binary_binned_feature(0));
        // 2 real columns + 1 canary copy of each.
        assert_eq!(table.binned_feature_count(), 4);
    }
1247
1248 #[test]
1249 fn builds_dense_table_when_any_feature_is_non_binary() {
1250 let table = Table::with_canaries(
1251 vec![vec![0.0, 1.5], vec![1.0, 0.0], vec![1.0, 2.0]],
1252 vec![0.0, 1.0, 1.0],
1253 1,
1254 )
1255 .unwrap();
1256
1257 assert_eq!(table.kind(), TableKind::Dense);
1258 assert!(table.is_binary_feature(0));
1259 assert!(!table.is_binary_feature(1));
1260 }
1261
1262 #[test]
1263 fn sparse_table_rejects_non_binary_columns() {
1264 let err =
1265 SparseTable::with_canaries(vec![vec![0.0, 2.0], vec![1.0, 0.0]], vec![0.0, 1.0], 0)
1266 .unwrap_err();
1267
1268 assert_eq!(err, DenseTableError::NonBinaryColumn { column: 1 });
1269 }
1270
1271 #[test]
1272 fn auto_bins_numeric_columns_into_power_of_two_bins_up_to_128() {
1273 let x: Vec<Vec<f64>> = (0..1024).map(|value| vec![value as f64]).collect();
1274 let y: Vec<f64> = vec![1.0; 1024];
1275
1276 let table = DenseTable::with_canaries(x, y, 0).unwrap();
1277
1278 assert_eq!(table.binned_value(0, 0), 0);
1279 assert_eq!(table.binned_value(0, 1023), 127);
1280 assert!((1..1024).all(|idx| table.binned_value(0, idx - 1) <= table.binned_value(0, idx)));
1281 assert_eq!(
1282 (0..1024)
1283 .map(|idx| table.binned_value(0, idx))
1284 .collect::<BTreeSet<_>>()
1285 .len(),
1286 128
1287 );
1288 }
1289
1290 #[test]
1291 fn auto_bins_choose_highest_populated_power_of_two() {
1292 let x: Vec<Vec<f64>> = (0..300).map(|value| vec![value as f64]).collect();
1293 let y = vec![0.0; 300];
1294
1295 let table = DenseTable::with_canaries(x, y, 0).unwrap();
1296
1297 assert_eq!(
1298 (0..300)
1299 .map(|idx| table.binned_value(0, idx))
1300 .collect::<BTreeSet<_>>()
1301 .len(),
1302 128
1303 );
1304 }
1305
1306 #[test]
1307 fn auto_bins_require_at_least_two_rows_per_bin() {
1308 let x: Vec<Vec<f64>> = (0..8).map(|value| vec![value as f64]).collect();
1309 let y = vec![0.0; 8];
1310
1311 let table = DenseTable::with_canaries(x, y, 0).unwrap();
1312 let counts = (0..table.n_rows()).fold(BTreeMap::new(), |mut counts, row_idx| {
1313 *counts
1314 .entry(table.binned_value(0, row_idx))
1315 .or_insert(0usize) += 1;
1316 counts
1317 });
1318
1319 assert_eq!(counts.len(), 4);
1320 assert!(counts.values().all(|count| *count >= 2));
1321 }
1322
1323 #[test]
1324 fn fixed_bins_cap_numeric_columns_to_requested_limit() {
1325 let x: Vec<Vec<f64>> = (0..300).map(|value| vec![value as f64]).collect();
1326 let y = vec![0.0; 300];
1327
1328 let table = DenseTable::with_options(x, y, 0, NumericBins::Fixed(64)).unwrap();
1329
1330 assert_eq!(
1331 (0..300)
1332 .map(|idx| table.binned_value(0, idx))
1333 .collect::<BTreeSet<_>>()
1334 .len(),
1335 64
1336 );
1337 }
1338
1339 #[test]
1340 fn rejects_invalid_fixed_bin_count() {
1341 assert_eq!(
1342 NumericBins::fixed(0).unwrap_err(),
1343 DenseTableError::InvalidBinCount { requested: 0 }
1344 );
1345 assert_eq!(
1346 NumericBins::fixed(513).unwrap_err(),
1347 DenseTableError::InvalidBinCount { requested: 513 }
1348 );
1349 }
1350
1351 #[test]
1352 fn keeps_equal_values_in_the_same_bin() {
1353 let table = DenseTable::with_canaries(
1354 vec![vec![0.0], vec![0.0], vec![1.0], vec![1.0], vec![2.0]],
1355 vec![0.0; 5],
1356 0,
1357 )
1358 .unwrap();
1359
1360 assert_eq!(table.binned_value(0, 0), table.binned_value(0, 1));
1361 assert_eq!(table.binned_value(0, 2), table.binned_value(0, 3));
1362 assert!(table.binned_value(0, 1) <= table.binned_value(0, 2));
1363 assert!(table.binned_value(0, 3) < table.binned_value(0, 4));
1364 }
1365
1366 #[test]
1367 fn stores_binary_columns_as_booleans() {
1368 let table = DenseTable::with_canaries(
1369 vec![vec![0.0, 2.0], vec![1.0, 3.0], vec![0.0, 4.0]],
1370 vec![0.0; 3],
1371 1,
1372 )
1373 .unwrap();
1374
1375 assert!(table.is_binary_feature(0));
1376 assert!(!table.is_binary_feature(1));
1377 assert!(table.is_binary_binned_feature(0));
1378 assert!(!table.is_binary_binned_feature(1));
1379 assert!(table.is_binary_binned_feature(2));
1380 assert_eq!(table.feature_value(0, 0), 0.0);
1381 assert_eq!(table.feature_value(0, 1), 1.0);
1382 assert_eq!(table.binned_boolean_value(0, 0), Some(false));
1383 assert_eq!(table.binned_boolean_value(0, 1), Some(true));
1384 }
1385
1386 #[test]
1387 fn stores_small_auto_binned_numeric_columns_as_u8() {
1388 let table = DenseTable::with_canaries(
1389 (0..8).map(|value| vec![value as f64]).collect(),
1390 vec![0.0; 8],
1391 0,
1392 )
1393 .unwrap();
1394
1395 assert!(matches!(
1396 table.binned_feature_column(0),
1397 BinnedFeatureColumnRef::NumericU8(_)
1398 ));
1399 }
1400
1401 #[test]
1402 fn creates_canary_columns_as_shuffled_binned_copies() {
1403 let table = DenseTable::with_canaries(
1404 vec![vec![0.0], vec![1.0], vec![2.0], vec![3.0], vec![4.0]],
1405 vec![0.0; 5],
1406 1,
1407 )
1408 .unwrap();
1409
1410 assert!(matches!(
1411 table.binned_column_kind(1),
1412 BinnedColumnKind::Canary {
1413 source_index: 0,
1414 copy_index: 0
1415 }
1416 ));
1417 assert_eq!(
1418 (0..table.n_rows())
1419 .map(|idx| table.binned_value(0, idx))
1420 .collect::<BTreeSet<_>>(),
1421 (0..table.n_rows())
1422 .map(|idx| table.binned_value(1, idx))
1423 .collect::<BTreeSet<_>>()
1424 );
1425 assert_ne!(
1426 (0..table.n_rows())
1427 .map(|idx| table.binned_value(0, idx))
1428 .collect::<Vec<_>>(),
1429 (0..table.n_rows())
1430 .map(|idx| table.binned_value(1, idx))
1431 .collect::<Vec<_>>()
1432 );
1433 }
1434
1435 #[test]
1436 fn rejects_ragged_rows() {
1437 let err = DenseTable::new(vec![vec![1.0, 2.0], vec![3.0]], vec![1.0, 2.0]).unwrap_err();
1438
1439 assert_eq!(
1440 err,
1441 DenseTableError::RaggedRows {
1442 row: 1,
1443 expected: 2,
1444 actual: 1,
1445 }
1446 );
1447 }
1448
1449 #[test]
1450 fn rejects_mismatched_lengths() {
1451 let err = DenseTable::new(vec![vec![1.0], vec![2.0]], vec![1.0]).unwrap_err();
1452
1453 assert_eq!(err, DenseTableError::MismatchedLengths { x: 2, y: 1 });
1454 }
1455
1456 #[test]
1457 fn canary_generation_is_deterministic_for_identical_inputs() {
1458 let left = DenseTable::with_canaries(
1459 vec![vec![0.0], vec![1.0], vec![2.0], vec![3.0], vec![4.0]],
1460 vec![0.0; 5],
1461 2,
1462 )
1463 .unwrap();
1464 let right = DenseTable::with_canaries(
1465 vec![vec![0.0], vec![1.0], vec![2.0], vec![3.0], vec![4.0]],
1466 vec![0.0; 5],
1467 2,
1468 )
1469 .unwrap();
1470
1471 let left_values = binned_snapshot(&left);
1472 let right_values = binned_snapshot(&right);
1473
1474 assert_eq!(left_values, right_values);
1475 }
1476
1477 #[test]
1478 fn binary_canaries_remain_boolean_and_preserve_value_counts() {
1479 let table = DenseTable::with_canaries(
1480 vec![
1481 vec![0.0],
1482 vec![1.0],
1483 vec![0.0],
1484 vec![1.0],
1485 vec![1.0],
1486 vec![0.0],
1487 ],
1488 vec![0.0; 6],
1489 2,
1490 )
1491 .unwrap();
1492
1493 let real_true_count = (0..table.n_rows())
1494 .filter(|row_idx| table.binned_boolean_value(0, *row_idx) == Some(true))
1495 .count();
1496
1497 for feature_index in 1..table.binned_feature_count() {
1498 assert!(table.is_binary_binned_feature(feature_index));
1499 let canary_true_count = (0..table.n_rows())
1500 .filter(|row_idx| table.binned_boolean_value(feature_index, *row_idx) == Some(true))
1501 .count();
1502 assert_eq!(canary_true_count, real_true_count);
1503 }
1504 }
1505
1506 #[test]
1507 fn numeric_bin_boundaries_capture_training_bin_upper_bounds() {
1508 let boundaries = numeric_bin_boundaries(&[1.0, 1.0, 2.0, 10.0], NumericBins::Auto);
1509
1510 assert_eq!(boundaries, vec![(0, 1.0), (1, 10.0)]);
1511 }
1512
1513 fn binned_snapshot(table: &DenseTable) -> Vec<u16> {
1514 let mut values = Vec::new();
1515
1516 for feature_idx in 0..table.binned_feature_count() {
1517 for row_idx in 0..table.n_rows() {
1518 values.push(table.binned_value(feature_idx, row_idx));
1519 }
1520 }
1521
1522 values
1523 }
1524}