1use arrow::array::{BooleanArray, Float64Array, UInt8Array, UInt16Array};
2use rand::seq::SliceRandom;
3use rand::{SeedableRng, rngs::StdRng};
4use std::cmp::Ordering;
5use std::error::Error;
6use std::fmt::{Display, Formatter};
7
/// Upper bound on the number of bins a numeric feature column may be discretized into.
///
/// [`NumericBins::fixed`] rejects larger requests and [`NumericBins::Auto`] uses this value as
/// its cap.
pub const MAX_NUMERIC_BINS: usize = 128;
/// Number of shuffled canary copies per feature used by the `new` constructors.
const DEFAULT_CANARIES: usize = 2;

/// Output of `preprocess_rows`: `(columns, target, n_rows, n_features)`, where `columns` is the
/// input matrix transposed into one `Vec<f64>` per feature.
type PreprocessedRows = (Vec<Vec<f64>>, Float64Array, usize, usize);
12
/// Read-only accessors shared by [`DenseTable`], [`SparseTable`] and [`Table`].
///
/// Indices passed to the `binned_*` methods address the binned layout built by the table
/// constructors in this module: the first `n_features()` entries are the real features,
/// followed by `canaries()` rounds of shuffled copies of every feature (see
/// [`BinnedColumnKind`]).
pub trait TableAccess: Sync {
    /// Number of rows in the table.
    fn n_rows(&self) -> usize;
    /// Number of real (non-canary) features.
    fn n_features(&self) -> usize;
    /// Number of canary copies generated per real feature.
    fn canaries(&self) -> usize;
    /// Maximum number of bins a numeric column may use, as configured by [`NumericBins`].
    fn numeric_bin_cap(&self) -> usize;
    /// Total number of binned columns (real features plus canaries).
    fn binned_feature_count(&self) -> usize;
    /// Raw value of a real feature; binary features are reported as `0.0` or `1.0`.
    fn feature_value(&self, feature_index: usize, row_index: usize) -> f64;
    /// Whether the real feature at `index` is stored as a binary column.
    fn is_binary_feature(&self, index: usize) -> bool;
    /// Bin index of a binned column at `row_index`; binary columns yield `0` or `1`.
    fn binned_value(&self, feature_index: usize, row_index: usize) -> u16;
    /// Boolean value of a binary binned column, or `None` for a numeric binned column.
    fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool>;
    /// Whether the binned column at `index` is a real feature or a canary copy.
    fn binned_column_kind(&self, index: usize) -> BinnedColumnKind;
    /// Whether the binned column at `index` is stored as booleans.
    fn is_binary_binned_feature(&self, index: usize) -> bool;
    /// Target value of row `row_index`.
    fn target_value(&self, row_index: usize) -> f64;

    /// Returns `true` when the binned column at `index` is a shuffled canary copy.
    fn is_canary_binned_feature(&self, index: usize) -> bool {
        matches!(
            self.binned_column_kind(index),
            BinnedColumnKind::Canary { .. }
        )
    }
}
34
/// Storage layout of a [`Table`], as reported by [`Table::kind`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TableKind {
    /// Arrow column storage; supports numeric and binary features.
    Dense,
    /// Row-index lists; used when every feature is binary.
    Sparse,
}
40
/// Strategy for discretizing numeric feature columns into bins.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum NumericBins {
    /// Pick a power-of-two bin count from the data, capped by [`MAX_NUMERIC_BINS`], by the
    /// number of distinct values and by half the number of rows.
    #[default]
    Auto,
    /// Use at most the given number of bins (further limited by the number of distinct
    /// values). Prefer [`NumericBins::fixed`], which validates the requested count.
    Fixed(usize),
}
47
48impl NumericBins {
49 pub fn fixed(requested: usize) -> Result<Self, DenseTableError> {
50 if requested == 0 || requested > MAX_NUMERIC_BINS {
51 return Err(DenseTableError::InvalidBinCount { requested });
52 }
53 Ok(Self::Fixed(requested))
54 }
55
56 pub fn cap(self) -> usize {
57 match self {
58 NumericBins::Auto => MAX_NUMERIC_BINS,
59 NumericBins::Fixed(requested) => requested,
60 }
61 }
62}
63
/// Column-oriented table backed by Arrow arrays.
///
/// Stores the raw feature columns and a binned copy of every feature, followed by `canaries`
/// rounds of shuffled (canary) copies of the binned features.
#[derive(Debug, Clone)]
pub struct DenseTable {
    /// Raw features; all-`0`/`1` columns are stored as booleans, the rest as `f64`.
    feature_columns: Vec<FeatureColumn>,
    /// Binned real features followed by the canary copies.
    binned_feature_columns: Vec<BinnedFeatureColumn>,
    /// Parallel to `binned_feature_columns`: origin of each binned column.
    binned_column_kinds: Vec<BinnedColumnKind>,
    target: Float64Array,
    n_rows: usize,
    n_features: usize,
    canaries: usize,
    numeric_bins: NumericBins,
}
76
/// Table for all-binary feature matrices; each column stores the sorted indices of the rows
/// whose value is `1`.
#[derive(Debug, Clone)]
pub struct SparseTable {
    /// One column per real feature.
    feature_columns: Vec<SparseBinaryColumn>,
    /// Copies of the real features followed by the shuffled canary copies.
    binned_feature_columns: Vec<SparseBinaryColumn>,
    /// Parallel to `binned_feature_columns`: origin of each binned column.
    binned_column_kinds: Vec<BinnedColumnKind>,
    target: Float64Array,
    n_rows: usize,
    n_features: usize,
    canaries: usize,
    /// Kept only so `numeric_bin_cap` reports the configured value; all columns are binary.
    numeric_bins: NumericBins,
}
89
/// A boolean column represented by the indices of its `true` rows.
#[derive(Debug, Clone)]
struct SparseBinaryColumn {
    /// Rows whose value is `true`, sorted ascending (required by the binary search in `value`).
    row_indices: Vec<usize>,
}

impl SparseBinaryColumn {
    /// Returns `true` when `row_index` is one of the stored rows.
    fn value(&self, row_index: usize) -> bool {
        matches!(self.row_indices.binary_search(&row_index), Ok(_))
    }
}
100
/// A table whose layout is chosen by [`Table::with_options`]: [`Table::Sparse`] when every
/// feature column is binary, [`Table::Dense`] otherwise.
#[derive(Debug, Clone)]
pub enum Table {
    Dense(DenseTable),
    Sparse(SparseTable),
}
106
/// Storage of one raw feature column of a [`DenseTable`].
#[derive(Debug, Clone)]
enum FeatureColumn {
    /// Column that is not binary; values are kept as-is.
    Numeric(Float64Array),
    /// Column whose values are all `0`/`1`, stored as booleans.
    Binary(BooleanArray),
}
112
/// Storage of one binned column of a [`DenseTable`].
///
/// Numeric bins use `u8` storage when every bin index fits, `u16` otherwise.
#[derive(Debug, Clone)]
enum BinnedFeatureColumn {
    NumericU8(UInt8Array),
    NumericU16(UInt16Array),
    Binary(BooleanArray),
}
119
/// Borrowed view of a raw feature column, returned by [`DenseTable::feature_column`].
#[derive(Debug, Clone, Copy)]
pub enum FeatureColumnRef<'a> {
    Numeric(&'a Float64Array),
    Binary(&'a BooleanArray),
}
125
/// Borrowed view of a binned column, returned by [`DenseTable::binned_feature_column`].
#[derive(Debug, Clone, Copy)]
pub enum BinnedFeatureColumnRef<'a> {
    NumericU8(&'a UInt8Array),
    NumericU16(&'a UInt16Array),
    Binary(&'a BooleanArray),
}
132
/// Origin of a binned column.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinnedColumnKind {
    /// Binned form of real feature `source_index`.
    Real {
        source_index: usize,
    },
    /// Shuffled copy number `copy_index` of feature `source_index`.
    Canary {
        source_index: usize,
        copy_index: usize,
    },
}
143
/// Errors raised while building a [`DenseTable`], [`SparseTable`] or [`Table`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DenseTableError {
    /// The feature matrix has `x` rows but the target has `y` values.
    MismatchedLengths {
        x: usize,
        y: usize,
    },
    /// Row `row` has `actual` columns instead of `expected`.
    ///
    /// `SparseTable::from_sparse_binary_columns_with_options` also reports a wrong number of
    /// columns through this variant, with `row` set to the supplied column count.
    RaggedRows {
        row: usize,
        expected: usize,
        actual: usize,
    },
    /// Column `column` contains a value other than `0`/`1`, or (for sparse column input) a row
    /// index that is out of range.
    NonBinaryColumn {
        column: usize,
    },
    /// A fixed bin count outside `1..=MAX_NUMERIC_BINS` was requested.
    InvalidBinCount {
        requested: usize,
    },
}
162
163impl Display for DenseTableError {
164 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
165 match self {
166 DenseTableError::MismatchedLengths { x, y } => write!(
167 f,
168 "Mismatched lengths: X has {} rows while y has {} values.",
169 x, y
170 ),
171 DenseTableError::RaggedRows {
172 row,
173 expected,
174 actual,
175 } => write!(
176 f,
177 "Ragged row at index {}: expected {} columns, found {}.",
178 row, expected, actual
179 ),
180 DenseTableError::NonBinaryColumn { column } => write!(
181 f,
182 "SparseTable requires binary features, but column {} contains non-binary values.",
183 column
184 ),
185 DenseTableError::InvalidBinCount { requested } => write!(
186 f,
187 "Invalid bins value {}. Expected 'auto' or an integer between 1 and {}.",
188 requested, MAX_NUMERIC_BINS
189 ),
190 }
191 }
192}
193
194impl Error for DenseTableError {}
195
196impl DenseTable {
197 pub fn new(x: Vec<Vec<f64>>, y: Vec<f64>) -> Result<Self, DenseTableError> {
198 Self::with_canaries(x, y, DEFAULT_CANARIES)
199 }
200
201 pub fn with_canaries(
202 x: Vec<Vec<f64>>,
203 y: Vec<f64>,
204 canaries: usize,
205 ) -> Result<Self, DenseTableError> {
206 Self::with_options(x, y, canaries, NumericBins::Auto)
207 }
208
209 pub fn with_options(
210 x: Vec<Vec<f64>>,
211 y: Vec<f64>,
212 canaries: usize,
213 numeric_bins: NumericBins,
214 ) -> Result<Self, DenseTableError> {
215 let (columns, target, n_rows, n_features) = preprocess_rows(&x, y)?;
216 Ok(Self::from_columns(
217 &columns,
218 target,
219 n_rows,
220 n_features,
221 canaries,
222 numeric_bins,
223 ))
224 }
225
226 fn from_columns(
227 columns: &[Vec<f64>],
228 target: Float64Array,
229 n_rows: usize,
230 n_features: usize,
231 canaries: usize,
232 numeric_bins: NumericBins,
233 ) -> Self {
234 let feature_columns = columns
235 .iter()
236 .map(|column| build_feature_column(column))
237 .collect();
238
239 let real_binned_columns: Vec<BinnedFeatureColumn> = columns
240 .iter()
241 .map(|column| build_binned_feature_column(column, numeric_bins))
242 .collect();
243 let canary_columns: Vec<(BinnedColumnKind, BinnedFeatureColumn)> = (0..canaries)
244 .flat_map(|copy_index| {
245 real_binned_columns
246 .iter()
247 .enumerate()
248 .map(move |(source_index, column)| {
249 (
250 BinnedColumnKind::Canary {
251 source_index,
252 copy_index,
253 },
254 shuffle_canary_column(column, copy_index, source_index),
255 )
256 })
257 })
258 .collect();
259
260 let (binned_column_kinds, binned_feature_columns): (Vec<_>, Vec<_>) = (0..n_features)
261 .map(|source_index| BinnedColumnKind::Real { source_index })
262 .zip(real_binned_columns)
263 .chain(canary_columns)
264 .unzip();
265
266 Self {
267 feature_columns,
268 binned_feature_columns,
269 binned_column_kinds,
270 target,
271 n_rows,
272 n_features,
273 canaries,
274 numeric_bins,
275 }
276 }
277
278 #[inline]
279 pub fn n_rows(&self) -> usize {
280 self.n_rows
281 }
282
283 #[inline]
284 pub fn n_features(&self) -> usize {
285 self.n_features
286 }
287
288 #[inline]
289 pub fn canaries(&self) -> usize {
290 self.canaries
291 }
292
293 #[inline]
294 pub fn numeric_bin_cap(&self) -> usize {
295 self.numeric_bins.cap()
296 }
297
298 #[inline]
299 pub fn binned_feature_count(&self) -> usize {
300 self.binned_feature_columns.len()
301 }
302
303 #[inline]
304 pub fn feature_column(&self, index: usize) -> FeatureColumnRef<'_> {
305 match &self.feature_columns[index] {
306 FeatureColumn::Numeric(column) => FeatureColumnRef::Numeric(column),
307 FeatureColumn::Binary(column) => FeatureColumnRef::Binary(column),
308 }
309 }
310
311 #[inline]
312 pub fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
313 match &self.feature_columns[feature_index] {
314 FeatureColumn::Numeric(column) => column.value(row_index),
315 FeatureColumn::Binary(column) => f64::from(u8::from(column.value(row_index))),
316 }
317 }
318
319 #[inline]
320 pub fn is_binary_feature(&self, index: usize) -> bool {
321 matches!(self.feature_columns[index], FeatureColumn::Binary(_))
322 }
323
324 #[inline]
325 pub fn binned_feature_column(&self, index: usize) -> BinnedFeatureColumnRef<'_> {
326 match &self.binned_feature_columns[index] {
327 BinnedFeatureColumn::NumericU8(column) => BinnedFeatureColumnRef::NumericU8(column),
328 BinnedFeatureColumn::NumericU16(column) => BinnedFeatureColumnRef::NumericU16(column),
329 BinnedFeatureColumn::Binary(column) => BinnedFeatureColumnRef::Binary(column),
330 }
331 }
332
333 #[inline]
334 pub fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
335 match &self.binned_feature_columns[feature_index] {
336 BinnedFeatureColumn::NumericU8(column) => u16::from(column.value(row_index)),
337 BinnedFeatureColumn::NumericU16(column) => column.value(row_index),
338 BinnedFeatureColumn::Binary(column) => u16::from(u8::from(column.value(row_index))),
339 }
340 }
341
342 #[inline]
343 pub fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
344 match &self.binned_feature_columns[feature_index] {
345 BinnedFeatureColumn::Binary(column) => Some(column.value(row_index)),
346 BinnedFeatureColumn::NumericU8(_) | BinnedFeatureColumn::NumericU16(_) => None,
347 }
348 }
349
350 #[inline]
351 pub fn binned_column_kind(&self, index: usize) -> BinnedColumnKind {
352 self.binned_column_kinds[index]
353 }
354
355 #[inline]
356 pub fn is_canary_binned_feature(&self, index: usize) -> bool {
357 matches!(
358 self.binned_column_kinds[index],
359 BinnedColumnKind::Canary { .. }
360 )
361 }
362
363 #[inline]
364 pub fn is_binary_binned_feature(&self, index: usize) -> bool {
365 matches!(
366 self.binned_feature_columns[index],
367 BinnedFeatureColumn::Binary(_)
368 )
369 }
370
371 #[inline]
372 pub fn target(&self) -> &Float64Array {
373 &self.target
374 }
375}
376
377impl SparseTable {
378 pub fn new(x: Vec<Vec<f64>>, y: Vec<f64>) -> Result<Self, DenseTableError> {
379 Self::with_canaries(x, y, DEFAULT_CANARIES)
380 }
381
382 pub fn with_canaries(
383 x: Vec<Vec<f64>>,
384 y: Vec<f64>,
385 canaries: usize,
386 ) -> Result<Self, DenseTableError> {
387 Self::with_options(x, y, canaries, NumericBins::Auto)
388 }
389
390 pub fn with_options(
391 x: Vec<Vec<f64>>,
392 y: Vec<f64>,
393 canaries: usize,
394 numeric_bins: NumericBins,
395 ) -> Result<Self, DenseTableError> {
396 let (columns, target, n_rows, n_features) = preprocess_rows(&x, y)?;
397 validate_binary_columns(&columns)?;
398 Ok(Self::from_columns(
399 &columns,
400 target,
401 n_rows,
402 n_features,
403 canaries,
404 numeric_bins,
405 ))
406 }
407
408 fn from_columns(
409 columns: &[Vec<f64>],
410 target: Float64Array,
411 n_rows: usize,
412 n_features: usize,
413 canaries: usize,
414 numeric_bins: NumericBins,
415 ) -> Self {
416 let feature_columns: Vec<SparseBinaryColumn> = columns
417 .iter()
418 .map(|column| sparse_binary_column_from_values(column))
419 .collect();
420
421 let canary_columns: Vec<(BinnedColumnKind, SparseBinaryColumn)> = (0..canaries)
422 .flat_map(|copy_index| {
423 feature_columns
424 .iter()
425 .enumerate()
426 .map(move |(source_index, column)| {
427 (
428 BinnedColumnKind::Canary {
429 source_index,
430 copy_index,
431 },
432 shuffle_sparse_binary_column(column, n_rows, copy_index, source_index),
433 )
434 })
435 })
436 .collect();
437
438 let (binned_column_kinds, binned_feature_columns): (Vec<_>, Vec<_>) = (0..n_features)
439 .map(|source_index| BinnedColumnKind::Real { source_index })
440 .zip(feature_columns.iter().cloned())
441 .chain(canary_columns)
442 .unzip();
443
444 Self {
445 feature_columns,
446 binned_feature_columns,
447 binned_column_kinds,
448 target,
449 n_rows,
450 n_features,
451 canaries,
452 numeric_bins,
453 }
454 }
455
456 pub fn from_sparse_binary_columns(
457 n_rows: usize,
458 n_features: usize,
459 columns: Vec<Vec<usize>>,
460 y: Vec<f64>,
461 canaries: usize,
462 ) -> Result<Self, DenseTableError> {
463 Self::from_sparse_binary_columns_with_options(
464 n_rows,
465 n_features,
466 columns,
467 y,
468 canaries,
469 NumericBins::Auto,
470 )
471 }
472
473 pub fn from_sparse_binary_columns_with_options(
474 n_rows: usize,
475 n_features: usize,
476 columns: Vec<Vec<usize>>,
477 y: Vec<f64>,
478 canaries: usize,
479 numeric_bins: NumericBins,
480 ) -> Result<Self, DenseTableError> {
481 if n_rows != y.len() {
482 return Err(DenseTableError::MismatchedLengths {
483 x: n_rows,
484 y: y.len(),
485 });
486 }
487 if n_features != columns.len() {
488 return Err(DenseTableError::RaggedRows {
489 row: columns.len(),
490 expected: n_features,
491 actual: columns.len(),
492 });
493 }
494
495 let feature_columns = columns
496 .into_iter()
497 .enumerate()
498 .map(|(column_idx, mut row_indices)| {
499 row_indices.sort_unstable();
500 row_indices.dedup();
501 if row_indices.iter().any(|row_idx| *row_idx >= n_rows) {
502 return Err(DenseTableError::NonBinaryColumn { column: column_idx });
503 }
504 Ok(SparseBinaryColumn { row_indices })
505 })
506 .collect::<Result<Vec<_>, _>>()?;
507
508 let canary_columns: Vec<(BinnedColumnKind, SparseBinaryColumn)> = (0..canaries)
509 .flat_map(|copy_index| {
510 feature_columns
511 .iter()
512 .enumerate()
513 .map(move |(source_index, column)| {
514 (
515 BinnedColumnKind::Canary {
516 source_index,
517 copy_index,
518 },
519 shuffle_sparse_binary_column(column, n_rows, copy_index, source_index),
520 )
521 })
522 })
523 .collect();
524
525 let (binned_column_kinds, binned_feature_columns): (Vec<_>, Vec<_>) = (0..n_features)
526 .map(|source_index| BinnedColumnKind::Real { source_index })
527 .zip(feature_columns.iter().cloned())
528 .chain(canary_columns)
529 .unzip();
530
531 Ok(Self {
532 feature_columns,
533 binned_feature_columns,
534 binned_column_kinds,
535 target: Float64Array::from(y),
536 n_rows,
537 n_features,
538 canaries,
539 numeric_bins,
540 })
541 }
542
543 #[inline]
544 pub fn n_rows(&self) -> usize {
545 self.n_rows
546 }
547
548 #[inline]
549 pub fn n_features(&self) -> usize {
550 self.n_features
551 }
552
553 #[inline]
554 pub fn canaries(&self) -> usize {
555 self.canaries
556 }
557
558 #[inline]
559 pub fn numeric_bin_cap(&self) -> usize {
560 self.numeric_bins.cap()
561 }
562
563 #[inline]
564 pub fn binned_feature_count(&self) -> usize {
565 self.binned_feature_columns.len()
566 }
567
568 #[inline]
569 pub fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
570 f64::from(u8::from(
571 self.feature_columns[feature_index].value(row_index),
572 ))
573 }
574
575 #[inline]
576 pub fn is_binary_feature(&self, _index: usize) -> bool {
577 true
578 }
579
580 #[inline]
581 pub fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
582 u16::from(u8::from(
583 self.binned_feature_columns[feature_index].value(row_index),
584 ))
585 }
586
587 #[inline]
588 pub fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
589 Some(self.binned_feature_columns[feature_index].value(row_index))
590 }
591
592 #[inline]
593 pub fn binned_column_kind(&self, index: usize) -> BinnedColumnKind {
594 self.binned_column_kinds[index]
595 }
596
597 #[inline]
598 pub fn is_canary_binned_feature(&self, index: usize) -> bool {
599 matches!(
600 self.binned_column_kinds[index],
601 BinnedColumnKind::Canary { .. }
602 )
603 }
604
605 #[inline]
606 pub fn is_binary_binned_feature(&self, _index: usize) -> bool {
607 true
608 }
609
610 #[inline]
611 pub fn target(&self) -> &Float64Array {
612 &self.target
613 }
614}
615
616impl Table {
617 pub fn new(x: Vec<Vec<f64>>, y: Vec<f64>) -> Result<Self, DenseTableError> {
618 Self::with_canaries(x, y, DEFAULT_CANARIES)
619 }
620
621 pub fn with_canaries(
622 x: Vec<Vec<f64>>,
623 y: Vec<f64>,
624 canaries: usize,
625 ) -> Result<Self, DenseTableError> {
626 Self::with_options(x, y, canaries, NumericBins::Auto)
627 }
628
629 pub fn with_options(
630 x: Vec<Vec<f64>>,
631 y: Vec<f64>,
632 canaries: usize,
633 numeric_bins: NumericBins,
634 ) -> Result<Self, DenseTableError> {
635 let (columns, target, n_rows, n_features) = preprocess_rows(&x, y)?;
636
637 if columns.iter().all(|column| is_binary_column(column)) {
638 Ok(Self::Sparse(SparseTable::from_columns(
639 &columns,
640 target,
641 n_rows,
642 n_features,
643 canaries,
644 numeric_bins,
645 )))
646 } else {
647 Ok(Self::Dense(DenseTable::from_columns(
648 &columns,
649 target,
650 n_rows,
651 n_features,
652 canaries,
653 numeric_bins,
654 )))
655 }
656 }
657
658 pub fn kind(&self) -> TableKind {
659 match self {
660 Table::Dense(_) => TableKind::Dense,
661 Table::Sparse(_) => TableKind::Sparse,
662 }
663 }
664
665 pub fn as_dense(&self) -> Option<&DenseTable> {
666 match self {
667 Table::Dense(table) => Some(table),
668 Table::Sparse(_) => None,
669 }
670 }
671
672 pub fn as_sparse(&self) -> Option<&SparseTable> {
673 match self {
674 Table::Dense(_) => None,
675 Table::Sparse(table) => Some(table),
676 }
677 }
678}
679
680impl TableAccess for DenseTable {
681 fn n_rows(&self) -> usize {
682 self.n_rows()
683 }
684
685 fn n_features(&self) -> usize {
686 self.n_features()
687 }
688
689 fn canaries(&self) -> usize {
690 self.canaries()
691 }
692
693 fn numeric_bin_cap(&self) -> usize {
694 self.numeric_bin_cap()
695 }
696
697 fn binned_feature_count(&self) -> usize {
698 self.binned_feature_count()
699 }
700
701 fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
702 self.feature_value(feature_index, row_index)
703 }
704
705 fn is_binary_feature(&self, index: usize) -> bool {
706 self.is_binary_feature(index)
707 }
708
709 fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
710 self.binned_value(feature_index, row_index)
711 }
712
713 fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
714 self.binned_boolean_value(feature_index, row_index)
715 }
716
717 fn binned_column_kind(&self, index: usize) -> BinnedColumnKind {
718 self.binned_column_kind(index)
719 }
720
721 fn is_binary_binned_feature(&self, index: usize) -> bool {
722 self.is_binary_binned_feature(index)
723 }
724
725 fn target_value(&self, row_index: usize) -> f64 {
726 self.target().value(row_index)
727 }
728}
729
730impl TableAccess for SparseTable {
731 fn n_rows(&self) -> usize {
732 self.n_rows()
733 }
734
735 fn n_features(&self) -> usize {
736 self.n_features()
737 }
738
739 fn canaries(&self) -> usize {
740 self.canaries()
741 }
742
743 fn numeric_bin_cap(&self) -> usize {
744 self.numeric_bin_cap()
745 }
746
747 fn binned_feature_count(&self) -> usize {
748 self.binned_feature_count()
749 }
750
751 fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
752 self.feature_value(feature_index, row_index)
753 }
754
755 fn is_binary_feature(&self, index: usize) -> bool {
756 self.is_binary_feature(index)
757 }
758
759 fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
760 self.binned_value(feature_index, row_index)
761 }
762
763 fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
764 self.binned_boolean_value(feature_index, row_index)
765 }
766
767 fn binned_column_kind(&self, index: usize) -> BinnedColumnKind {
768 self.binned_column_kind(index)
769 }
770
771 fn is_binary_binned_feature(&self, index: usize) -> bool {
772 self.is_binary_binned_feature(index)
773 }
774
775 fn target_value(&self, row_index: usize) -> f64 {
776 self.target().value(row_index)
777 }
778}
779
780impl TableAccess for Table {
781 fn n_rows(&self) -> usize {
782 match self {
783 Table::Dense(table) => table.n_rows(),
784 Table::Sparse(table) => table.n_rows(),
785 }
786 }
787
788 fn n_features(&self) -> usize {
789 match self {
790 Table::Dense(table) => table.n_features(),
791 Table::Sparse(table) => table.n_features(),
792 }
793 }
794
795 fn canaries(&self) -> usize {
796 match self {
797 Table::Dense(table) => table.canaries(),
798 Table::Sparse(table) => table.canaries(),
799 }
800 }
801
802 fn numeric_bin_cap(&self) -> usize {
803 match self {
804 Table::Dense(table) => table.numeric_bin_cap(),
805 Table::Sparse(table) => table.numeric_bin_cap(),
806 }
807 }
808
809 fn binned_feature_count(&self) -> usize {
810 match self {
811 Table::Dense(table) => table.binned_feature_count(),
812 Table::Sparse(table) => table.binned_feature_count(),
813 }
814 }
815
816 fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
817 match self {
818 Table::Dense(table) => table.feature_value(feature_index, row_index),
819 Table::Sparse(table) => table.feature_value(feature_index, row_index),
820 }
821 }
822
823 fn is_binary_feature(&self, index: usize) -> bool {
824 match self {
825 Table::Dense(table) => table.is_binary_feature(index),
826 Table::Sparse(table) => table.is_binary_feature(index),
827 }
828 }
829
830 fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
831 match self {
832 Table::Dense(table) => table.binned_value(feature_index, row_index),
833 Table::Sparse(table) => table.binned_value(feature_index, row_index),
834 }
835 }
836
837 fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
838 match self {
839 Table::Dense(table) => table.binned_boolean_value(feature_index, row_index),
840 Table::Sparse(table) => table.binned_boolean_value(feature_index, row_index),
841 }
842 }
843
844 fn binned_column_kind(&self, index: usize) -> BinnedColumnKind {
845 match self {
846 Table::Dense(table) => table.binned_column_kind(index),
847 Table::Sparse(table) => table.binned_column_kind(index),
848 }
849 }
850
851 fn is_binary_binned_feature(&self, index: usize) -> bool {
852 match self {
853 Table::Dense(table) => table.is_binary_binned_feature(index),
854 Table::Sparse(table) => table.is_binary_binned_feature(index),
855 }
856 }
857
858 fn target_value(&self, row_index: usize) -> f64 {
859 match self {
860 Table::Dense(table) => table.target().value(row_index),
861 Table::Sparse(table) => table.target().value(row_index),
862 }
863 }
864}
865
866fn preprocess_rows(x: &[Vec<f64>], y: Vec<f64>) -> Result<PreprocessedRows, DenseTableError> {
867 validate_shape(x, &y)?;
868 let n_rows = x.len();
869 let n_features = x.first().map_or(0, Vec::len);
870 let columns = collect_columns(x, n_features);
871 Ok((columns, Float64Array::from(y), n_rows, n_features))
872}
873
874fn validate_shape(x: &[Vec<f64>], y: &[f64]) -> Result<(), DenseTableError> {
875 if x.len() != y.len() {
876 return Err(DenseTableError::MismatchedLengths {
877 x: x.len(),
878 y: y.len(),
879 });
880 }
881
882 let n_features = x.first().map_or(0, Vec::len);
883 for (row_idx, row) in x.iter().enumerate() {
884 if row.len() != n_features {
885 return Err(DenseTableError::RaggedRows {
886 row: row_idx,
887 expected: n_features,
888 actual: row.len(),
889 });
890 }
891 }
892
893 Ok(())
894}
895
/// Transposes row-major `x` into `n_features` column vectors.
///
/// # Panics
/// Panics if a row has fewer than `n_features` entries; `preprocess_rows` validates the shape
/// before calling this.
fn collect_columns(x: &[Vec<f64>], n_features: usize) -> Vec<Vec<f64>> {
    let mut columns = Vec::with_capacity(n_features);
    for col_idx in 0..n_features {
        let column: Vec<f64> = x.iter().map(|row| row[col_idx]).collect();
        columns.push(column);
    }
    columns
}
901
902fn validate_binary_columns(columns: &[Vec<f64>]) -> Result<(), DenseTableError> {
903 for (column_idx, column) in columns.iter().enumerate() {
904 if !is_binary_column(column) {
905 return Err(DenseTableError::NonBinaryColumn { column: column_idx });
906 }
907 }
908
909 Ok(())
910}
911
912fn build_feature_column(values: &[f64]) -> FeatureColumn {
913 if is_binary_column(values) {
914 FeatureColumn::Binary(BooleanArray::from(to_binary_values(values)))
915 } else {
916 FeatureColumn::Numeric(Float64Array::from(values.to_vec()))
917 }
918}
919
920fn build_binned_feature_column(values: &[f64], numeric_bins: NumericBins) -> BinnedFeatureColumn {
921 if is_binary_column(values) {
922 BinnedFeatureColumn::Binary(BooleanArray::from(to_binary_values(values)))
923 } else {
924 let bins = bin_numeric_column(values, numeric_bins);
925 if bins.iter().all(|value| *value <= u16::from(u8::MAX)) {
926 BinnedFeatureColumn::NumericU8(UInt8Array::from(
927 bins.into_iter()
928 .map(|value| value as u8)
929 .collect::<Vec<_>>(),
930 ))
931 } else {
932 BinnedFeatureColumn::NumericU16(UInt16Array::from(bins))
933 }
934 }
935}
936
/// Returns `true` when every value is `0` or `1`; an empty slice counts as binary.
///
/// Zero is compared numerically so that `-0.0` counts as `0` (`f64::total_cmp` orders `-0.0`
/// below `0.0`, which previously made such columns non-binary). `NaN` is never binary.
fn is_binary_column(values: &[f64]) -> bool {
    values
        .iter()
        .all(|value| *value == 0.0 || value.total_cmp(&1.0) == Ordering::Equal)
}
943
/// Maps each value to `true` when it is exactly `1.0` and to `false` otherwise.
fn to_binary_values(values: &[f64]) -> Vec<bool> {
    let mut flags = Vec::with_capacity(values.len());
    for value in values {
        flags.push(matches!(value.total_cmp(&1.0), Ordering::Equal));
    }
    flags
}
950
951fn sparse_binary_column_from_values(values: &[f64]) -> SparseBinaryColumn {
952 SparseBinaryColumn {
953 row_indices: values
954 .iter()
955 .enumerate()
956 .filter_map(|(row_idx, value)| {
957 (value.total_cmp(&1.0) == Ordering::Equal).then_some(row_idx)
958 })
959 .collect(),
960 }
961}
962
963pub fn numeric_bin_boundaries(values: &[f64], numeric_bins: NumericBins) -> Vec<(u16, f64)> {
964 if values.is_empty() {
965 return Vec::new();
966 }
967
968 let mut ranked_values: Vec<(usize, f64)> = values.iter().copied().enumerate().collect();
969 ranked_values.sort_by(|left, right| left.1.total_cmp(&right.1));
970
971 let unique_value_count = ranked_values
972 .iter()
973 .map(|(_row_idx, value)| *value)
974 .fold(Vec::<f64>::new(), |mut unique_values, value| {
975 let is_new_value = unique_values
976 .last()
977 .is_none_or(|last_value| last_value.total_cmp(&value) != Ordering::Equal);
978 if is_new_value {
979 unique_values.push(value);
980 }
981 unique_values
982 })
983 .len();
984
985 let bin_count = resolved_numeric_bin_count(values.len(), unique_value_count, numeric_bins);
986 let mut unique_rank = 0usize;
987 let mut start = 0usize;
988 let mut boundaries = Vec::new();
989
990 while start < ranked_values.len() {
991 let current_value = ranked_values[start].1;
992 let end = ranked_values[start..]
993 .iter()
994 .position(|(_row_idx, value)| value.total_cmp(¤t_value) != Ordering::Equal)
995 .map_or(ranked_values.len(), |offset| start + offset);
996
997 let bin = match numeric_bins {
998 NumericBins::Auto => ((start * bin_count) / values.len()) as u16,
999 NumericBins::Fixed(_) => {
1000 let max_bin = (bin_count - 1) as u16;
1001 if unique_value_count == 1 {
1002 0
1003 } else {
1004 ((unique_rank * usize::from(max_bin)) / (unique_value_count - 1)) as u16
1005 }
1006 }
1007 };
1008
1009 if let Some((last_bin, last_upper_bound)) = boundaries.last_mut() {
1010 if *last_bin == bin {
1011 *last_upper_bound = current_value;
1012 } else {
1013 boundaries.push((bin, current_value));
1014 }
1015 } else {
1016 boundaries.push((bin, current_value));
1017 }
1018
1019 unique_rank += 1;
1020 start = end;
1021 }
1022
1023 boundaries
1024}
1025
1026fn bin_numeric_column(values: &[f64], numeric_bins: NumericBins) -> Vec<u16> {
1027 if values.is_empty() {
1028 return Vec::new();
1029 }
1030
1031 let mut ranked_values: Vec<(usize, f64)> = values.iter().copied().enumerate().collect();
1032 ranked_values.sort_by(|left, right| left.1.total_cmp(&right.1));
1033
1034 let unique_value_count = ranked_values
1035 .iter()
1036 .map(|(_row_idx, value)| *value)
1037 .fold(Vec::<f64>::new(), |mut unique_values, value| {
1038 let is_new_value = unique_values
1039 .last()
1040 .is_none_or(|last_value| last_value.total_cmp(&value) != Ordering::Equal);
1041 if is_new_value {
1042 unique_values.push(value);
1043 }
1044 unique_values
1045 })
1046 .len();
1047
1048 let mut bins = vec![0u16; values.len()];
1049 let bin_count = resolved_numeric_bin_count(values.len(), unique_value_count, numeric_bins);
1050 let mut unique_rank = 0usize;
1051 let mut start = 0usize;
1052
1053 while start < ranked_values.len() {
1054 let current_value = ranked_values[start].1;
1055 let end = ranked_values[start..]
1056 .iter()
1057 .position(|(_row_idx, value)| value.total_cmp(¤t_value) != Ordering::Equal)
1058 .map_or(ranked_values.len(), |offset| start + offset);
1059
1060 let bin = match numeric_bins {
1061 NumericBins::Auto => ((start * bin_count) / values.len()) as u16,
1062 NumericBins::Fixed(_) => {
1063 let max_bin = (bin_count - 1) as u16;
1064 if unique_value_count == 1 {
1065 0
1066 } else {
1067 ((unique_rank * usize::from(max_bin)) / (unique_value_count - 1)) as u16
1068 }
1069 }
1070 };
1071
1072 for (row_idx, _value) in &ranked_values[start..end] {
1073 bins[*row_idx] = bin;
1074 }
1075
1076 unique_rank += 1;
1077 start = end;
1078 }
1079
1080 bins
1081}
1082
1083fn resolved_numeric_bin_count(
1084 value_count: usize,
1085 unique_value_count: usize,
1086 numeric_bins: NumericBins,
1087) -> usize {
1088 match numeric_bins {
1089 NumericBins::Auto => {
1090 let populated_bin_cap = (value_count / 2).max(1);
1091 let capped_unique_values = unique_value_count
1092 .min(MAX_NUMERIC_BINS)
1093 .min(populated_bin_cap)
1094 .max(1);
1095 highest_power_of_two_at_most(capped_unique_values)
1096 }
1097 NumericBins::Fixed(requested) => requested.min(unique_value_count).max(1),
1098 }
1099}
1100
/// Returns the largest power of two that is `<= value`; `0` is treated like `1`.
fn highest_power_of_two_at_most(value: usize) -> usize {
    match value {
        0 | 1 => 1,
        _ => 1 << value.ilog2(),
    }
}
1108
1109fn shuffle_canary_column(
1110 values: &BinnedFeatureColumn,
1111 copy_index: usize,
1112 source_index: usize,
1113) -> BinnedFeatureColumn {
1114 match values {
1115 BinnedFeatureColumn::NumericU8(values) => {
1116 let mut shuffled = (0..values.len())
1117 .map(|idx| values.value(idx))
1118 .collect::<Vec<_>>();
1119 shuffle_values(&mut shuffled, copy_index, source_index);
1120 BinnedFeatureColumn::NumericU8(UInt8Array::from(shuffled))
1121 }
1122 BinnedFeatureColumn::NumericU16(values) => {
1123 let mut shuffled = (0..values.len())
1124 .map(|idx| values.value(idx))
1125 .collect::<Vec<_>>();
1126 shuffle_values(&mut shuffled, copy_index, source_index);
1127 BinnedFeatureColumn::NumericU16(UInt16Array::from(shuffled))
1128 }
1129 BinnedFeatureColumn::Binary(values) => {
1130 BinnedFeatureColumn::Binary(shuffle_boolean_array(values, copy_index, source_index))
1131 }
1132 }
1133}
1134
1135fn shuffle_boolean_array(
1136 values: &BooleanArray,
1137 copy_index: usize,
1138 source_index: usize,
1139) -> BooleanArray {
1140 let mut shuffled = (0..values.len())
1141 .map(|idx| values.value(idx))
1142 .collect::<Vec<_>>();
1143 shuffle_values(&mut shuffled, copy_index, source_index);
1144 BooleanArray::from(shuffled)
1145}
1146
1147fn shuffle_sparse_binary_column(
1148 values: &SparseBinaryColumn,
1149 n_rows: usize,
1150 copy_index: usize,
1151 source_index: usize,
1152) -> SparseBinaryColumn {
1153 let mut dense = vec![false; n_rows];
1154 for row_idx in &values.row_indices {
1155 dense[*row_idx] = true;
1156 }
1157 shuffle_values(&mut dense, copy_index, source_index);
1158 SparseBinaryColumn {
1159 row_indices: dense
1160 .into_iter()
1161 .enumerate()
1162 .filter_map(|(row_idx, value)| value.then_some(row_idx))
1163 .collect(),
1164 }
1165}
1166
1167fn shuffle_values<T>(values: &mut [T], copy_index: usize, source_index: usize) {
1168 let seed = 0xA11CE5EED_u64
1169 ^ ((copy_index as u64) << 32)
1170 ^ (source_index as u64)
1171 ^ ((values.len() as u64) << 16);
1172 let mut rng = StdRng::seed_from_u64(seed);
1173 values.shuffle(&mut rng);
1174}
1175
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::{BTreeMap, BTreeSet};

    #[test]
    fn builds_arrow_backed_dense_table() {
        let table =
            DenseTable::new(vec![vec![0.0, 10.0], vec![1.0, 20.0]], vec![3.0, 5.0]).unwrap();

        // 2 real features, each with 2 canary copies, gives 6 binned columns.
        assert_eq!(table.n_rows(), 2);
        assert_eq!(table.n_features(), 2);
        assert_eq!(table.canaries(), 2);
        assert_eq!(table.binned_feature_count(), 6);
        assert_eq!(table.feature_value(0, 0), 0.0);
        assert_eq!(table.feature_value(0, 1), 1.0);
        assert_eq!(table.target().value(0), 3.0);
        assert_eq!(table.target().value(1), 5.0);
        assert!(!table.is_canary_binned_feature(0));
        assert!(table.is_canary_binned_feature(2));
    }

    #[test]
    fn builds_sparse_table_for_all_binary_features() {
        // Every feature column only holds 0.0 / 1.0, so the sparse layout is chosen.
        let table = Table::with_canaries(
            vec![vec![0.0, 1.0], vec![1.0, 0.0], vec![1.0, 1.0]],
            vec![0.0, 1.0, 1.0],
            1,
        )
        .unwrap();

        assert_eq!(table.kind(), TableKind::Sparse);
        assert!(table.is_binary_feature(0));
        assert!(table.is_binary_feature(1));
        assert!(table.is_binary_binned_feature(0));
        assert_eq!(table.binned_feature_count(), 4);
    }

    #[test]
    fn builds_dense_table_when_any_feature_is_non_binary() {
        // A single non-binary column (feature 1) forces the dense layout.
        let table = Table::with_canaries(
            vec![vec![0.0, 1.5], vec![1.0, 0.0], vec![1.0, 2.0]],
            vec![0.0, 1.0, 1.0],
            1,
        )
        .unwrap();

        assert_eq!(table.kind(), TableKind::Dense);
        assert!(table.is_binary_feature(0));
        assert!(!table.is_binary_feature(1));
    }

    #[test]
    fn sparse_table_rejects_non_binary_columns() {
        let err =
            SparseTable::with_canaries(vec![vec![0.0, 2.0], vec![1.0, 0.0]], vec![0.0, 1.0], 0)
                .unwrap_err();

        assert_eq!(err, DenseTableError::NonBinaryColumn { column: 1 });
    }

    #[test]
    fn auto_bins_numeric_columns_into_power_of_two_bins_up_to_128() {
        // 1024 distinct values: auto binning is expected to hit the 128-bin cap with
        // monotone (order-preserving) bin assignments.
        let x: Vec<Vec<f64>> = (0..1024).map(|value| vec![value as f64]).collect();
        let y: Vec<f64> = vec![1.0; 1024];

        let table = DenseTable::with_canaries(x, y, 0).unwrap();

        assert_eq!(table.binned_value(0, 0), 0);
        assert_eq!(table.binned_value(0, 1023), 127);
        assert!((1..1024).all(|idx| table.binned_value(0, idx - 1) <= table.binned_value(0, idx)));
        assert_eq!(
            (0..1024)
                .map(|idx| table.binned_value(0, idx))
                .collect::<BTreeSet<_>>()
                .len(),
            128
        );
    }

    #[test]
    fn auto_bins_choose_highest_populated_power_of_two() {
        // Expected: 300 rows end up in exactly 128 distinct bins.
        let x: Vec<Vec<f64>> = (0..300).map(|value| vec![value as f64]).collect();
        let y = vec![0.0; 300];

        let table = DenseTable::with_canaries(x, y, 0).unwrap();

        assert_eq!(
            (0..300)
                .map(|idx| table.binned_value(0, idx))
                .collect::<BTreeSet<_>>()
                .len(),
            128
        );
    }

    #[test]
    fn auto_bins_require_at_least_two_rows_per_bin() {
        // 8 distinct values: expected to produce 4 bins of at least 2 rows each.
        let x: Vec<Vec<f64>> = (0..8).map(|value| vec![value as f64]).collect();
        let y = vec![0.0; 8];

        let table = DenseTable::with_canaries(x, y, 0).unwrap();
        let counts = (0..table.n_rows()).fold(BTreeMap::new(), |mut counts, row_idx| {
            *counts
                .entry(table.binned_value(0, row_idx))
                .or_insert(0usize) += 1;
            counts
        });

        assert_eq!(counts.len(), 4);
        assert!(counts.values().all(|count| *count >= 2));
    }

    #[test]
    fn fixed_bins_cap_numeric_columns_to_requested_limit() {
        let x: Vec<Vec<f64>> = (0..300).map(|value| vec![value as f64]).collect();
        let y = vec![0.0; 300];

        let table = DenseTable::with_options(x, y, 0, NumericBins::Fixed(64)).unwrap();

        assert_eq!(
            (0..300)
                .map(|idx| table.binned_value(0, idx))
                .collect::<BTreeSet<_>>()
                .len(),
            64
        );
    }

    #[test]
    fn rejects_invalid_fixed_bin_count() {
        assert_eq!(
            NumericBins::fixed(0).unwrap_err(),
            DenseTableError::InvalidBinCount { requested: 0 }
        );
        // The first value past the cap must be rejected, not only values far beyond it.
        assert_eq!(
            NumericBins::fixed(MAX_NUMERIC_BINS + 1).unwrap_err(),
            DenseTableError::InvalidBinCount {
                requested: MAX_NUMERIC_BINS + 1
            }
        );
        assert_eq!(
            NumericBins::fixed(513).unwrap_err(),
            DenseTableError::InvalidBinCount { requested: 513 }
        );
    }

    #[test]
    fn accepts_fixed_bin_counts_within_range() {
        // Both ends of the valid range `1..=MAX_NUMERIC_BINS` are accepted.
        assert_eq!(NumericBins::fixed(1), Ok(NumericBins::Fixed(1)));
        assert_eq!(
            NumericBins::fixed(MAX_NUMERIC_BINS),
            Ok(NumericBins::Fixed(MAX_NUMERIC_BINS))
        );
        assert_eq!(
            NumericBins::fixed(MAX_NUMERIC_BINS).unwrap().cap(),
            MAX_NUMERIC_BINS
        );
        assert_eq!(NumericBins::Auto.cap(), MAX_NUMERIC_BINS);
    }

    #[test]
    fn keeps_equal_values_in_the_same_bin() {
        let table = DenseTable::with_canaries(
            vec![vec![0.0], vec![0.0], vec![1.0], vec![1.0], vec![2.0]],
            vec![0.0; 5],
            0,
        )
        .unwrap();

        assert_eq!(table.binned_value(0, 0), table.binned_value(0, 1));
        assert_eq!(table.binned_value(0, 2), table.binned_value(0, 3));
        assert!(table.binned_value(0, 1) <= table.binned_value(0, 2));
        assert!(table.binned_value(0, 3) < table.binned_value(0, 4));
    }

    #[test]
    fn stores_binary_columns_as_booleans() {
        // Feature 0 is binary, feature 1 is numeric. Binned column 2 is a canary and
        // is expected to be binary as well.
        let table = DenseTable::with_canaries(
            vec![vec![0.0, 2.0], vec![1.0, 3.0], vec![0.0, 4.0]],
            vec![0.0; 3],
            1,
        )
        .unwrap();

        assert!(table.is_binary_feature(0));
        assert!(!table.is_binary_feature(1));
        assert!(table.is_binary_binned_feature(0));
        assert!(!table.is_binary_binned_feature(1));
        assert!(table.is_binary_binned_feature(2));
        assert_eq!(table.feature_value(0, 0), 0.0);
        assert_eq!(table.feature_value(0, 1), 1.0);
        assert_eq!(table.binned_boolean_value(0, 0), Some(false));
        assert_eq!(table.binned_boolean_value(0, 1), Some(true));
    }

    #[test]
    fn stores_small_auto_binned_numeric_columns_as_u8() {
        let table = DenseTable::with_canaries(
            (0..8).map(|value| vec![value as f64]).collect(),
            vec![0.0; 8],
            0,
        )
        .unwrap();

        assert!(matches!(
            table.binned_feature_column(0),
            BinnedFeatureColumnRef::NumericU8(_)
        ));
    }

    #[test]
    fn creates_canary_columns_as_shuffled_binned_copies() {
        let table = DenseTable::with_canaries(
            vec![vec![0.0], vec![1.0], vec![2.0], vec![3.0], vec![4.0]],
            vec![0.0; 5],
            1,
        )
        .unwrap();

        assert!(matches!(
            table.binned_column_kind(1),
            BinnedColumnKind::Canary {
                source_index: 0,
                copy_index: 0
            }
        ));
        // The canary holds the same set of bins as its source...
        assert_eq!(
            (0..table.n_rows())
                .map(|idx| table.binned_value(0, idx))
                .collect::<BTreeSet<_>>(),
            (0..table.n_rows())
                .map(|idx| table.binned_value(1, idx))
                .collect::<BTreeSet<_>>()
        );
        // ...but in a different row order.
        assert_ne!(
            (0..table.n_rows())
                .map(|idx| table.binned_value(0, idx))
                .collect::<Vec<_>>(),
            (0..table.n_rows())
                .map(|idx| table.binned_value(1, idx))
                .collect::<Vec<_>>()
        );
    }

    #[test]
    fn rejects_ragged_rows() {
        let err = DenseTable::new(vec![vec![1.0, 2.0], vec![3.0]], vec![1.0, 2.0]).unwrap_err();

        assert_eq!(
            err,
            DenseTableError::RaggedRows {
                row: 1,
                expected: 2,
                actual: 1,
            }
        );
    }

    #[test]
    fn rejects_mismatched_lengths() {
        let err = DenseTable::new(vec![vec![1.0], vec![2.0]], vec![1.0]).unwrap_err();

        assert_eq!(err, DenseTableError::MismatchedLengths { x: 2, y: 1 });
    }

    #[test]
    fn canary_generation_is_deterministic_for_identical_inputs() {
        let left = DenseTable::with_canaries(
            vec![vec![0.0], vec![1.0], vec![2.0], vec![3.0], vec![4.0]],
            vec![0.0; 5],
            2,
        )
        .unwrap();
        let right = DenseTable::with_canaries(
            vec![vec![0.0], vec![1.0], vec![2.0], vec![3.0], vec![4.0]],
            vec![0.0; 5],
            2,
        )
        .unwrap();

        let left_values = binned_snapshot(&left);
        let right_values = binned_snapshot(&right);

        assert_eq!(left_values, right_values);
    }

    #[test]
    fn binary_canaries_remain_boolean_and_preserve_value_counts() {
        let table = DenseTable::with_canaries(
            vec![
                vec![0.0],
                vec![1.0],
                vec![0.0],
                vec![1.0],
                vec![1.0],
                vec![0.0],
            ],
            vec![0.0; 6],
            2,
        )
        .unwrap();

        let real_true_count = (0..table.n_rows())
            .filter(|row_idx| table.binned_boolean_value(0, *row_idx) == Some(true))
            .count();

        // Every binned column after the single real feature is a canary. Shuffling
        // must keep it boolean and keep its number of `true` rows.
        for feature_index in 1..table.binned_feature_count() {
            assert!(table.is_binary_binned_feature(feature_index));
            let canary_true_count = (0..table.n_rows())
                .filter(|row_idx| table.binned_boolean_value(feature_index, *row_idx) == Some(true))
                .count();
            assert_eq!(canary_true_count, real_true_count);
        }
    }

    #[test]
    fn numeric_bin_boundaries_capture_training_bin_upper_bounds() {
        let boundaries = numeric_bin_boundaries(&[1.0, 1.0, 2.0, 10.0], NumericBins::Auto);

        assert_eq!(boundaries, vec![(0, 1.0), (1, 10.0)]);
    }

    /// Flattens every binned value of `table` into one vector, column by column, so
    /// two tables can be compared for exact equality.
    fn binned_snapshot(table: &DenseTable) -> Vec<u16> {
        let mut values = Vec::new();

        for feature_idx in 0..table.binned_feature_count() {
            for row_idx in 0..table.n_rows() {
                values.push(table.binned_value(feature_idx, row_idx));
            }
        }

        values
    }
}