use arrow::array::{Array, BooleanArray, Float64Array, UInt8Array, UInt16Array};
use rand::seq::SliceRandom;
use rand::{SeedableRng, rngs::StdRng};
use std::cmp::Ordering;
use std::error::Error;
use std::fmt::{Display, Formatter};
/// Maximum number of bins a numeric feature may be split into.
pub const MAX_NUMERIC_BINS: usize = 128;
/// Number of shuffled "canary" copies of each feature added by default.
const DEFAULT_CANARIES: usize = 2;
/// Bin id reserved for missing (null) entries of binary binned columns;
/// real values occupy bins 0 and 1.
pub const BINARY_MISSING_BIN: u16 = 2;
/// (column-major features, target array, n_rows, n_features) as produced by `preprocess_rows`.
type PreprocessedRows = (Vec<Vec<f64>>, Float64Array, usize, usize);
/// Read access to a training table: raw feature values, their binned
/// representation (real features plus appended canary copies), and the target.
pub trait TableAccess: Sync {
    /// Number of rows (samples).
    fn n_rows(&self) -> usize;
    /// Number of real (non-canary) features.
    fn n_features(&self) -> usize;
    /// Number of shuffled canary copies appended per feature.
    fn canaries(&self) -> usize;
    /// Upper bound on the number of numeric bins.
    fn numeric_bin_cap(&self) -> usize;
    /// Total binned columns: real features plus all canary copies.
    fn binned_feature_count(&self) -> usize;
    /// Raw value of a feature at a row (NaN encodes missing in dense tables).
    fn feature_value(&self, feature_index: usize, row_index: usize) -> f64;
    /// True when the value at (feature, row) is missing.
    fn is_missing(&self, feature_index: usize, row_index: usize) -> bool;
    /// True when the raw feature column is binary.
    fn is_binary_feature(&self, index: usize) -> bool;
    /// Bin id of a binned column at a row.
    fn binned_value(&self, feature_index: usize, row_index: usize) -> u16;
    /// Boolean value of a binary binned column; `None` when the column is
    /// numeric or (dense tables) the entry is missing.
    fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool>;
    /// Provenance of a binned column: real feature or canary copy.
    fn binned_column_kind(&self, index: usize) -> BinnedColumnKind;
    /// True when the binned column stores booleans.
    fn is_binary_binned_feature(&self, index: usize) -> bool;
    /// Target (label) value at a row.
    fn target_value(&self, row_index: usize) -> f64;
    /// True when the binned column at `index` is a shuffled canary copy.
    fn is_canary_binned_feature(&self, index: usize) -> bool {
        matches!(
            self.binned_column_kind(index),
            BinnedColumnKind::Canary { .. }
        )
    }
}
/// Discriminates which storage backend a [`Table`] chose.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TableKind {
    Dense,
    Sparse,
}
/// Binning policy for numeric columns.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum NumericBins {
    /// Choose a power-of-two bin count automatically (capped at `MAX_NUMERIC_BINS`).
    #[default]
    Auto,
    /// Use exactly the requested number of bins (validated by `NumericBins::fixed`).
    Fixed(usize),
}
impl NumericBins {
    /// Validated constructor for a fixed bin count: rejects zero and anything
    /// above `MAX_NUMERIC_BINS`.
    pub fn fixed(requested: usize) -> Result<Self, DenseTableError> {
        match requested {
            0 => Err(DenseTableError::InvalidBinCount { requested }),
            n if n > MAX_NUMERIC_BINS => Err(DenseTableError::InvalidBinCount { requested }),
            _ => Ok(Self::Fixed(requested)),
        }
    }
    /// Upper bound on the number of numeric bins this setting allows.
    pub fn cap(self) -> usize {
        if let NumericBins::Fixed(requested) = self {
            requested
        } else {
            MAX_NUMERIC_BINS
        }
    }
}
/// Column-major training table backed by Arrow arrays, holding raw feature
/// columns, pre-binned columns (real features followed by canary copies), and
/// the target.
#[derive(Debug, Clone)]
pub struct DenseTable {
    // Raw per-feature storage: numeric f64 or boolean.
    feature_columns: Vec<FeatureColumn>,
    // Binned columns: one per real feature, then `canaries` shuffled copies each.
    binned_feature_columns: Vec<BinnedFeatureColumn>,
    // Parallel to `binned_feature_columns`: real vs canary provenance.
    binned_column_kinds: Vec<BinnedColumnKind>,
    // Target values, one per row.
    target: Float64Array,
    n_rows: usize,
    n_features: usize,
    // Number of shuffled canary copies built per feature.
    canaries: usize,
    // Binning configuration the binned columns were built with.
    numeric_bins: NumericBins,
}
/// Training table for all-binary features, stored as sorted row-index lists.
/// Mirrors `DenseTable`'s layout: real binned columns first, canaries after.
#[derive(Debug, Clone)]
pub struct SparseTable {
    // Raw binary columns (row indices where the value is 1).
    feature_columns: Vec<SparseBinaryColumn>,
    // Binned view: clones of the real columns followed by shuffled canaries.
    binned_feature_columns: Vec<SparseBinaryColumn>,
    // Parallel to `binned_feature_columns`: real vs canary provenance.
    binned_column_kinds: Vec<BinnedColumnKind>,
    // Target values, one per row.
    target: Float64Array,
    n_rows: usize,
    n_features: usize,
    // Number of shuffled canary copies built per feature.
    canaries: usize,
    // Kept for API parity with DenseTable; sparse columns are never numeric-binned.
    numeric_bins: NumericBins,
}
/// A binary feature stored as the sorted, deduplicated list of row indices
/// whose value is 1; every other row is implicitly 0.
#[derive(Debug, Clone)]
struct SparseBinaryColumn {
    row_indices: Vec<usize>,
}
impl SparseBinaryColumn {
    /// True when `row_index` is set. Relies on `row_indices` being sorted so a
    /// binary search suffices.
    fn value(&self, row_index: usize) -> bool {
        matches!(self.row_indices.binary_search(&row_index), Ok(_))
    }
}
/// Storage-agnostic table: sparse when every feature is strictly 0/1 with no
/// missing values, dense otherwise.
#[derive(Debug, Clone)]
pub enum Table {
    Dense(DenseTable),
    Sparse(SparseTable),
}
/// Raw storage for one dense feature column.
#[derive(Debug, Clone)]
enum FeatureColumn {
    Numeric(Float64Array),
    /// Binary column; Arrow nulls encode missing values.
    Binary(BooleanArray),
}
/// Binned storage for one dense column; u8 is used whenever every bin id fits,
/// u16 otherwise.
#[derive(Debug, Clone)]
enum BinnedFeatureColumn {
    NumericU8(UInt8Array),
    NumericU16(UInt16Array),
    /// Binary column; Arrow nulls encode missing values.
    Binary(BooleanArray),
}
/// Borrowed view of a raw dense feature column.
#[derive(Debug, Clone, Copy)]
pub enum FeatureColumnRef<'a> {
    Numeric(&'a Float64Array),
    Binary(&'a BooleanArray),
}
/// Borrowed view of a binned dense feature column.
#[derive(Debug, Clone, Copy)]
pub enum BinnedFeatureColumnRef<'a> {
    NumericU8(&'a UInt8Array),
    NumericU16(&'a UInt16Array),
    Binary(&'a BooleanArray),
}
/// Provenance of a binned column.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinnedColumnKind {
    /// A real feature; `source_index` is its position among the raw features.
    Real { source_index: usize },
    /// A shuffled copy of feature `source_index`; `copy_index` distinguishes
    /// the multiple canary copies of the same feature.
    Canary {
        source_index: usize,
        copy_index: usize,
    },
}
/// Errors raised while constructing a table.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DenseTableError {
    /// X and y have different row counts.
    MismatchedLengths {
        x: usize,
        y: usize,
    },
    /// A row's column count differs from the first row's.
    RaggedRows {
        row: usize,
        expected: usize,
        actual: usize,
    },
    /// A sparse table was requested but a column contains non-0/1 values.
    NonBinaryColumn {
        column: usize,
    },
    /// A fixed bin count outside 1..=MAX_NUMERIC_BINS was requested.
    InvalidBinCount {
        requested: usize,
    },
}
impl Display for DenseTableError {
    /// Renders a human-readable description of each error variant.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            DenseTableError::MismatchedLengths { x, y } => format!(
                "Mismatched lengths: X has {} rows while y has {} values.",
                x, y
            ),
            DenseTableError::RaggedRows {
                row,
                expected,
                actual,
            } => format!(
                "Ragged row at index {}: expected {} columns, found {}.",
                row, expected, actual
            ),
            DenseTableError::NonBinaryColumn { column } => format!(
                "SparseTable requires binary features, but column {} contains non-binary values.",
                column
            ),
            DenseTableError::InvalidBinCount { requested } => format!(
                "Invalid bins value {}. Expected 'auto' or an integer between 1 and {}.",
                requested, MAX_NUMERIC_BINS
            ),
        };
        f.write_str(&message)
    }
}
// Marker impl so `DenseTableError` works as a `dyn Error`; `Display` and
// `Debug` supply the required behaviour.
impl Error for DenseTableError {}
impl DenseTable {
    /// Builds a dense table with the default number of canary copies.
    pub fn new(x: Vec<Vec<f64>>, y: Vec<f64>) -> Result<Self, DenseTableError> {
        Self::with_canaries(x, y, DEFAULT_CANARIES)
    }
    /// Builds a dense table with `canaries` shuffled copies per feature and
    /// automatic numeric binning.
    pub fn with_canaries(
        x: Vec<Vec<f64>>,
        y: Vec<f64>,
        canaries: usize,
    ) -> Result<Self, DenseTableError> {
        Self::with_options(x, y, canaries, NumericBins::Auto)
    }
    /// Builds a dense table from row-major data after validating its shape.
    pub fn with_options(
        x: Vec<Vec<f64>>,
        y: Vec<f64>,
        canaries: usize,
        numeric_bins: NumericBins,
    ) -> Result<Self, DenseTableError> {
        let (columns, target, n_rows, n_features) = preprocess_rows(&x, y)?;
        Ok(Self::from_columns(
            &columns,
            target,
            n_rows,
            n_features,
            canaries,
            numeric_bins,
        ))
    }
    /// Assembles the table from column-major data: raw columns, their binned
    /// counterparts, and `canaries` deterministically shuffled copies of every
    /// binned column, appended after the real columns and grouped by copy.
    fn from_columns(
        columns: &[Vec<f64>],
        target: Float64Array,
        n_rows: usize,
        n_features: usize,
        canaries: usize,
        numeric_bins: NumericBins,
    ) -> Self {
        let feature_columns = columns
            .iter()
            .map(|column| build_feature_column(column))
            .collect();
        let real_binned_columns: Vec<BinnedFeatureColumn> = columns
            .iter()
            .map(|column| build_binned_feature_column(column, numeric_bins))
            .collect();
        // For each copy index, shuffle every real binned column once; the
        // shuffle is seeded from (copy_index, source_index), so it is
        // deterministic for identical inputs.
        let canary_columns: Vec<(BinnedColumnKind, BinnedFeatureColumn)> = (0..canaries)
            .flat_map(|copy_index| {
                real_binned_columns
                    .iter()
                    .enumerate()
                    .map(move |(source_index, column)| {
                        (
                            BinnedColumnKind::Canary {
                                source_index,
                                copy_index,
                            },
                            shuffle_canary_column(column, copy_index, source_index),
                        )
                    })
            })
            .collect();
        // Layout: [real 0..n_features) then canaries (copy 0, copy 1, ...).
        let (binned_column_kinds, binned_feature_columns): (Vec<_>, Vec<_>) = (0..n_features)
            .map(|source_index| BinnedColumnKind::Real { source_index })
            .zip(real_binned_columns)
            .chain(canary_columns)
            .unzip();
        Self {
            feature_columns,
            binned_feature_columns,
            binned_column_kinds,
            target,
            n_rows,
            n_features,
            canaries,
            numeric_bins,
        }
    }
    /// Number of rows (samples).
    #[inline]
    pub fn n_rows(&self) -> usize {
        self.n_rows
    }
    /// Number of real (non-canary) features.
    #[inline]
    pub fn n_features(&self) -> usize {
        self.n_features
    }
    /// Number of canary copies per feature.
    #[inline]
    pub fn canaries(&self) -> usize {
        self.canaries
    }
    /// Upper bound on numeric bin ids.
    #[inline]
    pub fn numeric_bin_cap(&self) -> usize {
        self.numeric_bins.cap()
    }
    /// Total binned columns: real features plus canary copies.
    #[inline]
    pub fn binned_feature_count(&self) -> usize {
        self.binned_feature_columns.len()
    }
    /// Borrowed view of a raw feature column.
    #[inline]
    pub fn feature_column(&self, index: usize) -> FeatureColumnRef<'_> {
        match &self.feature_columns[index] {
            FeatureColumn::Numeric(column) => FeatureColumnRef::Numeric(column),
            FeatureColumn::Binary(column) => FeatureColumnRef::Binary(column),
        }
    }
    /// Raw value at (feature, row); binary columns surface nulls as NaN and
    /// non-null booleans as 0.0/1.0.
    #[inline]
    pub fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
        match &self.feature_columns[feature_index] {
            FeatureColumn::Numeric(column) => column.value(row_index),
            FeatureColumn::Binary(column) => {
                if column.is_null(row_index) {
                    f64::NAN
                } else {
                    f64::from(u8::from(column.value(row_index)))
                }
            }
        }
    }
    /// True when the value at (feature, row) is missing (NaN or Arrow null).
    #[inline]
    pub fn is_missing(&self, feature_index: usize, row_index: usize) -> bool {
        match &self.feature_columns[feature_index] {
            FeatureColumn::Numeric(column) => column.value(row_index).is_nan(),
            FeatureColumn::Binary(column) => column.is_null(row_index),
        }
    }
    /// True when the raw feature column is stored as booleans.
    #[inline]
    pub fn is_binary_feature(&self, index: usize) -> bool {
        matches!(self.feature_columns[index], FeatureColumn::Binary(_))
    }
    /// Borrowed view of a binned column (real or canary).
    #[inline]
    pub fn binned_feature_column(&self, index: usize) -> BinnedFeatureColumnRef<'_> {
        match &self.binned_feature_columns[index] {
            BinnedFeatureColumn::NumericU8(column) => BinnedFeatureColumnRef::NumericU8(column),
            BinnedFeatureColumn::NumericU16(column) => BinnedFeatureColumnRef::NumericU16(column),
            BinnedFeatureColumn::Binary(column) => BinnedFeatureColumnRef::Binary(column),
        }
    }
    /// Bin id at (binned column, row); binary columns map nulls to
    /// `BINARY_MISSING_BIN`.
    #[inline]
    pub fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
        match &self.binned_feature_columns[feature_index] {
            BinnedFeatureColumn::NumericU8(column) => u16::from(column.value(row_index)),
            BinnedFeatureColumn::NumericU16(column) => column.value(row_index),
            BinnedFeatureColumn::Binary(column) => {
                if column.is_null(row_index) {
                    BINARY_MISSING_BIN
                } else {
                    u16::from(u8::from(column.value(row_index)))
                }
            }
        }
    }
    /// Boolean value at (binned column, row); `None` for numeric columns or
    /// missing entries.
    #[inline]
    pub fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
        match &self.binned_feature_columns[feature_index] {
            BinnedFeatureColumn::Binary(column) => {
                (!column.is_null(row_index)).then(|| column.value(row_index))
            }
            BinnedFeatureColumn::NumericU8(_) | BinnedFeatureColumn::NumericU16(_) => None,
        }
    }
    /// Provenance (real vs canary) of a binned column.
    #[inline]
    pub fn binned_column_kind(&self, index: usize) -> BinnedColumnKind {
        self.binned_column_kinds[index]
    }
    /// True when the binned column at `index` is a canary copy.
    #[inline]
    pub fn is_canary_binned_feature(&self, index: usize) -> bool {
        matches!(
            self.binned_column_kinds[index],
            BinnedColumnKind::Canary { .. }
        )
    }
    /// True when the binned column stores booleans.
    #[inline]
    pub fn is_binary_binned_feature(&self, index: usize) -> bool {
        matches!(
            self.binned_feature_columns[index],
            BinnedFeatureColumn::Binary(_)
        )
    }
    /// Target values, one per row.
    #[inline]
    pub fn target(&self) -> &Float64Array {
        &self.target
    }
}
impl SparseTable {
    /// Builds a sparse table with the default number of canary copies.
    /// Fails if any column is not strictly 0/1.
    pub fn new(x: Vec<Vec<f64>>, y: Vec<f64>) -> Result<Self, DenseTableError> {
        Self::with_canaries(x, y, DEFAULT_CANARIES)
    }
    /// Builds a sparse table with `canaries` shuffled copies per feature.
    pub fn with_canaries(
        x: Vec<Vec<f64>>,
        y: Vec<f64>,
        canaries: usize,
    ) -> Result<Self, DenseTableError> {
        Self::with_options(x, y, canaries, NumericBins::Auto)
    }
    /// Builds a sparse table from row-major data, rejecting non-binary columns.
    pub fn with_options(
        x: Vec<Vec<f64>>,
        y: Vec<f64>,
        canaries: usize,
        numeric_bins: NumericBins,
    ) -> Result<Self, DenseTableError> {
        let (columns, target, n_rows, n_features) = preprocess_rows(&x, y)?;
        validate_binary_columns(&columns)?;
        Ok(Self::from_columns(
            &columns,
            target,
            n_rows,
            n_features,
            canaries,
            numeric_bins,
        ))
    }
    /// Assembles the table from validated binary columns; the binned view is a
    /// clone of the real columns followed by shuffled canary copies.
    fn from_columns(
        columns: &[Vec<f64>],
        target: Float64Array,
        n_rows: usize,
        n_features: usize,
        canaries: usize,
        numeric_bins: NumericBins,
    ) -> Self {
        let feature_columns: Vec<SparseBinaryColumn> = columns
            .iter()
            .map(|column| sparse_binary_column_from_values(column))
            .collect();
        // Canaries are deterministic shuffles seeded from (copy, source).
        let canary_columns: Vec<(BinnedColumnKind, SparseBinaryColumn)> = (0..canaries)
            .flat_map(|copy_index| {
                feature_columns
                    .iter()
                    .enumerate()
                    .map(move |(source_index, column)| {
                        (
                            BinnedColumnKind::Canary {
                                source_index,
                                copy_index,
                            },
                            shuffle_sparse_binary_column(column, n_rows, copy_index, source_index),
                        )
                    })
            })
            .collect();
        // Layout: [real 0..n_features) then canaries (copy 0, copy 1, ...).
        let (binned_column_kinds, binned_feature_columns): (Vec<_>, Vec<_>) = (0..n_features)
            .map(|source_index| BinnedColumnKind::Real { source_index })
            .zip(feature_columns.iter().cloned())
            .chain(canary_columns)
            .unzip();
        Self {
            feature_columns,
            binned_feature_columns,
            binned_column_kinds,
            target,
            n_rows,
            n_features,
            canaries,
            numeric_bins,
        }
    }
    /// Builds a sparse table directly from per-column row-index lists (the
    /// rows where each feature equals 1), skipping the dense representation.
    pub fn from_sparse_binary_columns(
        n_rows: usize,
        n_features: usize,
        columns: Vec<Vec<usize>>,
        y: Vec<f64>,
        canaries: usize,
    ) -> Result<Self, DenseTableError> {
        Self::from_sparse_binary_columns_with_options(
            n_rows,
            n_features,
            columns,
            y,
            canaries,
            NumericBins::Auto,
        )
    }
    /// Like `from_sparse_binary_columns` with an explicit binning setting.
    /// Row-index lists are sorted and deduplicated; out-of-range indices are
    /// rejected.
    pub fn from_sparse_binary_columns_with_options(
        n_rows: usize,
        n_features: usize,
        columns: Vec<Vec<usize>>,
        y: Vec<f64>,
        canaries: usize,
        numeric_bins: NumericBins,
    ) -> Result<Self, DenseTableError> {
        if n_rows != y.len() {
            return Err(DenseTableError::MismatchedLengths {
                x: n_rows,
                y: y.len(),
            });
        }
        // NOTE(review): a column-count mismatch is reported by reusing the
        // RaggedRows variant — confirm callers expect this mapping.
        if n_features != columns.len() {
            return Err(DenseTableError::RaggedRows {
                row: columns.len(),
                expected: n_features,
                actual: columns.len(),
            });
        }
        let feature_columns = columns
            .into_iter()
            .enumerate()
            .map(|(column_idx, mut row_indices)| {
                row_indices.sort_unstable();
                row_indices.dedup();
                // NOTE(review): out-of-range row indices reuse NonBinaryColumn.
                if row_indices.iter().any(|row_idx| *row_idx >= n_rows) {
                    return Err(DenseTableError::NonBinaryColumn { column: column_idx });
                }
                Ok(SparseBinaryColumn { row_indices })
            })
            .collect::<Result<Vec<_>, _>>()?;
        let canary_columns: Vec<(BinnedColumnKind, SparseBinaryColumn)> = (0..canaries)
            .flat_map(|copy_index| {
                feature_columns
                    .iter()
                    .enumerate()
                    .map(move |(source_index, column)| {
                        (
                            BinnedColumnKind::Canary {
                                source_index,
                                copy_index,
                            },
                            shuffle_sparse_binary_column(column, n_rows, copy_index, source_index),
                        )
                    })
            })
            .collect();
        let (binned_column_kinds, binned_feature_columns): (Vec<_>, Vec<_>) = (0..n_features)
            .map(|source_index| BinnedColumnKind::Real { source_index })
            .zip(feature_columns.iter().cloned())
            .chain(canary_columns)
            .unzip();
        Ok(Self {
            feature_columns,
            binned_feature_columns,
            binned_column_kinds,
            target: Float64Array::from(y),
            n_rows,
            n_features,
            canaries,
            numeric_bins,
        })
    }
    /// Number of rows (samples).
    #[inline]
    pub fn n_rows(&self) -> usize {
        self.n_rows
    }
    /// Number of real (non-canary) features.
    #[inline]
    pub fn n_features(&self) -> usize {
        self.n_features
    }
    /// Number of canary copies per feature.
    #[inline]
    pub fn canaries(&self) -> usize {
        self.canaries
    }
    /// Upper bound on numeric bin ids (API parity with DenseTable).
    #[inline]
    pub fn numeric_bin_cap(&self) -> usize {
        self.numeric_bins.cap()
    }
    /// Total binned columns: real features plus canary copies.
    #[inline]
    pub fn binned_feature_count(&self) -> usize {
        self.binned_feature_columns.len()
    }
    /// Raw value at (feature, row): always 0.0 or 1.0.
    #[inline]
    pub fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
        f64::from(u8::from(
            self.feature_columns[feature_index].value(row_index),
        ))
    }
    /// Sparse tables cannot represent missing values, so nothing is missing.
    #[inline]
    pub fn is_missing(&self, _feature_index: usize, _row_index: usize) -> bool {
        false
    }
    /// Every sparse feature is binary by construction.
    #[inline]
    pub fn is_binary_feature(&self, _index: usize) -> bool {
        true
    }
    /// Bin id at (binned column, row): 0 or 1.
    #[inline]
    pub fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
        u16::from(u8::from(
            self.binned_feature_columns[feature_index].value(row_index),
        ))
    }
    /// Boolean value at (binned column, row); never missing, so always `Some`.
    #[inline]
    pub fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
        Some(self.binned_feature_columns[feature_index].value(row_index))
    }
    /// Provenance (real vs canary) of a binned column.
    #[inline]
    pub fn binned_column_kind(&self, index: usize) -> BinnedColumnKind {
        self.binned_column_kinds[index]
    }
    /// True when the binned column at `index` is a canary copy.
    #[inline]
    pub fn is_canary_binned_feature(&self, index: usize) -> bool {
        matches!(
            self.binned_column_kinds[index],
            BinnedColumnKind::Canary { .. }
        )
    }
    /// Every sparse binned column is binary by construction.
    #[inline]
    pub fn is_binary_binned_feature(&self, _index: usize) -> bool {
        true
    }
    /// Target values, one per row.
    #[inline]
    pub fn target(&self) -> &Float64Array {
        &self.target
    }
}
impl Table {
    /// Builds a table with the default number of canary copies, choosing the
    /// storage backend automatically.
    pub fn new(x: Vec<Vec<f64>>, y: Vec<f64>) -> Result<Self, DenseTableError> {
        Self::with_canaries(x, y, DEFAULT_CANARIES)
    }
    /// Builds a table with `canaries` shuffled copies per feature.
    pub fn with_canaries(
        x: Vec<Vec<f64>>,
        y: Vec<f64>,
        canaries: usize,
    ) -> Result<Self, DenseTableError> {
        Self::with_options(x, y, canaries, NumericBins::Auto)
    }
    /// Builds a table from row-major data: sparse when every column is
    /// strictly 0/1 (no NaN), dense otherwise.
    pub fn with_options(
        x: Vec<Vec<f64>>,
        y: Vec<f64>,
        canaries: usize,
        numeric_bins: NumericBins,
    ) -> Result<Self, DenseTableError> {
        let (columns, target, n_rows, n_features) = preprocess_rows(&x, y)?;
        if columns
            .iter()
            .all(|column| is_sparse_binary_eligible_column(column))
        {
            Ok(Self::Sparse(SparseTable::from_columns(
                &columns,
                target,
                n_rows,
                n_features,
                canaries,
                numeric_bins,
            )))
        } else {
            Ok(Self::Dense(DenseTable::from_columns(
                &columns,
                target,
                n_rows,
                n_features,
                canaries,
                numeric_bins,
            )))
        }
    }
    /// Which backend this table uses.
    pub fn kind(&self) -> TableKind {
        match self {
            Table::Dense(_) => TableKind::Dense,
            Table::Sparse(_) => TableKind::Sparse,
        }
    }
    /// The dense backend, if that is what was chosen.
    pub fn as_dense(&self) -> Option<&DenseTable> {
        match self {
            Table::Dense(table) => Some(table),
            Table::Sparse(_) => None,
        }
    }
    /// The sparse backend, if that is what was chosen.
    pub fn as_sparse(&self) -> Option<&SparseTable> {
        match self {
            Table::Dense(_) => None,
            Table::Sparse(table) => Some(table),
        }
    }
}
// Trivial delegation of every `TableAccess` method to the inherent
// `DenseTable` implementation.
impl TableAccess for DenseTable {
    fn n_rows(&self) -> usize {
        self.n_rows()
    }
    fn n_features(&self) -> usize {
        self.n_features()
    }
    fn canaries(&self) -> usize {
        self.canaries()
    }
    fn numeric_bin_cap(&self) -> usize {
        self.numeric_bin_cap()
    }
    fn binned_feature_count(&self) -> usize {
        self.binned_feature_count()
    }
    fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
        self.feature_value(feature_index, row_index)
    }
    fn is_missing(&self, feature_index: usize, row_index: usize) -> bool {
        self.is_missing(feature_index, row_index)
    }
    fn is_binary_feature(&self, index: usize) -> bool {
        self.is_binary_feature(index)
    }
    fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
        self.binned_value(feature_index, row_index)
    }
    fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
        self.binned_boolean_value(feature_index, row_index)
    }
    fn binned_column_kind(&self, index: usize) -> BinnedColumnKind {
        self.binned_column_kind(index)
    }
    fn is_binary_binned_feature(&self, index: usize) -> bool {
        self.is_binary_binned_feature(index)
    }
    fn target_value(&self, row_index: usize) -> f64 {
        self.target().value(row_index)
    }
}
// Trivial delegation of every `TableAccess` method to the inherent
// `SparseTable` implementation.
impl TableAccess for SparseTable {
    fn n_rows(&self) -> usize {
        self.n_rows()
    }
    fn n_features(&self) -> usize {
        self.n_features()
    }
    fn canaries(&self) -> usize {
        self.canaries()
    }
    fn numeric_bin_cap(&self) -> usize {
        self.numeric_bin_cap()
    }
    fn binned_feature_count(&self) -> usize {
        self.binned_feature_count()
    }
    fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
        self.feature_value(feature_index, row_index)
    }
    fn is_missing(&self, feature_index: usize, row_index: usize) -> bool {
        self.is_missing(feature_index, row_index)
    }
    fn is_binary_feature(&self, index: usize) -> bool {
        self.is_binary_feature(index)
    }
    fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
        self.binned_value(feature_index, row_index)
    }
    fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
        self.binned_boolean_value(feature_index, row_index)
    }
    fn binned_column_kind(&self, index: usize) -> BinnedColumnKind {
        self.binned_column_kind(index)
    }
    fn is_binary_binned_feature(&self, index: usize) -> bool {
        self.is_binary_binned_feature(index)
    }
    fn target_value(&self, row_index: usize) -> f64 {
        self.target().value(row_index)
    }
}
// Dispatches every `TableAccess` method to whichever backend the table holds.
impl TableAccess for Table {
    fn n_rows(&self) -> usize {
        match self {
            Table::Dense(table) => table.n_rows(),
            Table::Sparse(table) => table.n_rows(),
        }
    }
    fn n_features(&self) -> usize {
        match self {
            Table::Dense(table) => table.n_features(),
            Table::Sparse(table) => table.n_features(),
        }
    }
    fn canaries(&self) -> usize {
        match self {
            Table::Dense(table) => table.canaries(),
            Table::Sparse(table) => table.canaries(),
        }
    }
    fn numeric_bin_cap(&self) -> usize {
        match self {
            Table::Dense(table) => table.numeric_bin_cap(),
            Table::Sparse(table) => table.numeric_bin_cap(),
        }
    }
    fn binned_feature_count(&self) -> usize {
        match self {
            Table::Dense(table) => table.binned_feature_count(),
            Table::Sparse(table) => table.binned_feature_count(),
        }
    }
    fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
        match self {
            Table::Dense(table) => table.feature_value(feature_index, row_index),
            Table::Sparse(table) => table.feature_value(feature_index, row_index),
        }
    }
    fn is_missing(&self, feature_index: usize, row_index: usize) -> bool {
        match self {
            Table::Dense(table) => table.is_missing(feature_index, row_index),
            Table::Sparse(table) => table.is_missing(feature_index, row_index),
        }
    }
    fn is_binary_feature(&self, index: usize) -> bool {
        match self {
            Table::Dense(table) => table.is_binary_feature(index),
            Table::Sparse(table) => table.is_binary_feature(index),
        }
    }
    fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
        match self {
            Table::Dense(table) => table.binned_value(feature_index, row_index),
            Table::Sparse(table) => table.binned_value(feature_index, row_index),
        }
    }
    fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
        match self {
            Table::Dense(table) => table.binned_boolean_value(feature_index, row_index),
            Table::Sparse(table) => table.binned_boolean_value(feature_index, row_index),
        }
    }
    fn binned_column_kind(&self, index: usize) -> BinnedColumnKind {
        match self {
            Table::Dense(table) => table.binned_column_kind(index),
            Table::Sparse(table) => table.binned_column_kind(index),
        }
    }
    fn is_binary_binned_feature(&self, index: usize) -> bool {
        match self {
            Table::Dense(table) => table.is_binary_binned_feature(index),
            Table::Sparse(table) => table.is_binary_binned_feature(index),
        }
    }
    fn target_value(&self, row_index: usize) -> f64 {
        match self {
            Table::Dense(table) => table.target().value(row_index),
            Table::Sparse(table) => table.target().value(row_index),
        }
    }
}
/// Validates row-major input and transposes it into column-major storage,
/// also wrapping the target in an Arrow array.
fn preprocess_rows(x: &[Vec<f64>], y: Vec<f64>) -> Result<PreprocessedRows, DenseTableError> {
    validate_shape(x, &y)?;
    let n_features = x.first().map_or(0, Vec::len);
    let n_rows = x.len();
    let columns = collect_columns(x, n_features);
    let target = Float64Array::from(y);
    Ok((columns, target, n_rows, n_features))
}
/// Checks that X and y have matching row counts and that every row of X has
/// the same number of columns as the first row.
fn validate_shape(x: &[Vec<f64>], y: &[f64]) -> Result<(), DenseTableError> {
    if x.len() != y.len() {
        return Err(DenseTableError::MismatchedLengths {
            x: x.len(),
            y: y.len(),
        });
    }
    let expected = x.first().map_or(0, Vec::len);
    // Report the first ragged row, matching the original scan order.
    match x.iter().enumerate().find(|(_, row)| row.len() != expected) {
        Some((row, offending)) => Err(DenseTableError::RaggedRows {
            row,
            expected,
            actual: offending.len(),
        }),
        None => Ok(()),
    }
}
/// Transposes row-major data into `n_features` column vectors.
/// Panics if any row is shorter than `n_features` (shape is validated upstream).
fn collect_columns(x: &[Vec<f64>], n_features: usize) -> Vec<Vec<f64>> {
    let mut columns: Vec<Vec<f64>> = (0..n_features)
        .map(|_| Vec::with_capacity(x.len()))
        .collect();
    for row in x {
        for (col_idx, column) in columns.iter_mut().enumerate() {
            column.push(row[col_idx]);
        }
    }
    columns
}
/// Fails with `NonBinaryColumn` for the first column that is not strictly 0/1.
fn validate_binary_columns(columns: &[Vec<f64>]) -> Result<(), DenseTableError> {
    match columns
        .iter()
        .position(|column| !is_sparse_binary_eligible_column(column))
    {
        Some(column) => Err(DenseTableError::NonBinaryColumn { column }),
        None => Ok(()),
    }
}
/// Wraps a raw column in Arrow storage: boolean (with nulls for NaN) when the
/// column is binary, f64 otherwise.
fn build_feature_column(values: &[f64]) -> FeatureColumn {
    if !is_binary_column(values) {
        return FeatureColumn::Numeric(Float64Array::from(values.to_vec()));
    }
    FeatureColumn::Binary(BooleanArray::from(to_binary_values(values)))
}
/// Builds the binned representation of a column: booleans for binary columns,
/// otherwise rank-binned ids stored as u8 when every id fits, else u16.
fn build_binned_feature_column(values: &[f64], numeric_bins: NumericBins) -> BinnedFeatureColumn {
    if is_binary_column(values) {
        return BinnedFeatureColumn::Binary(BooleanArray::from(to_binary_values(values)));
    }
    let bins = bin_numeric_column(values, numeric_bins);
    let fits_in_u8 = bins.iter().all(|value| u8::try_from(*value).is_ok());
    if fits_in_u8 {
        let narrow: Vec<u8> = bins.into_iter().map(|value| value as u8).collect();
        BinnedFeatureColumn::NumericU8(UInt8Array::from(narrow))
    } else {
        BinnedFeatureColumn::NumericU16(UInt16Array::from(bins))
    }
}
/// True when every value is NaN (missing), exactly +0.0, or exactly 1.0.
/// NOTE(review): `total_cmp` distinguishes -0.0 from +0.0, so -0.0 is treated
/// as non-binary — confirm that is intentional.
fn is_binary_column(values: &[f64]) -> bool {
    values.iter().all(|value| {
        if value.is_nan() {
            return true;
        }
        let is_zero = value.total_cmp(&0.0) == Ordering::Equal;
        let is_one = value.total_cmp(&1.0) == Ordering::Equal;
        is_zero || is_one
    })
}
/// True when every value is exactly +0.0 or exactly 1.0. Stricter than binary
/// eligibility: NaN disqualifies, because sparse storage cannot represent
/// missing values.
fn is_sparse_binary_eligible_column(values: &[f64]) -> bool {
    values.iter().all(|value| {
        let is_zero = value.total_cmp(&0.0) == Ordering::Equal;
        let is_one = value.total_cmp(&1.0) == Ordering::Equal;
        is_zero || is_one
    })
}
/// Maps a binary column to Arrow-style optional booleans:
/// NaN -> None (missing), exactly 1.0 -> Some(true), anything else -> Some(false).
fn to_binary_values(values: &[f64]) -> Vec<Option<bool>> {
    let mut out = Vec::with_capacity(values.len());
    for value in values {
        let entry = if value.is_nan() {
            None
        } else {
            Some(value.total_cmp(&1.0) == Ordering::Equal)
        };
        out.push(entry);
    }
    out
}
/// Converts a dense 0/1 column into its sparse form: the sorted list of row
/// indices whose value is exactly 1.0.
fn sparse_binary_column_from_values(values: &[f64]) -> SparseBinaryColumn {
    let mut row_indices = Vec::new();
    for (row_idx, value) in values.iter().enumerate() {
        if value.total_cmp(&1.0) == Ordering::Equal {
            row_indices.push(row_idx);
        }
    }
    SparseBinaryColumn { row_indices }
}
/// Computes, for each training bin of a numeric column, the largest raw value
/// assigned to that bin (the bin's inclusive upper bound).
///
/// NaN values are treated as missing and excluded; returns an empty vector when
/// `values` is empty or all-NaN. Entries are emitted in ascending value order
/// with consecutive equal bin ids collapsed, so each bin appears exactly once.
/// Must mirror the bin assignment logic of `bin_numeric_column`.
pub fn numeric_bin_boundaries(values: &[f64], numeric_bins: NumericBins) -> Vec<(u16, f64)> {
    if values.is_empty() {
        return Vec::new();
    }
    let mut ranked_values: Vec<(usize, f64)> = values
        .iter()
        .copied()
        .enumerate()
        .filter(|(_, value)| !value.is_nan())
        .collect();
    if ranked_values.is_empty() {
        return Vec::new();
    }
    ranked_values.sort_by(|left, right| left.1.total_cmp(&right.1));
    // Count distinct values by counting adjacent inequalities in the sorted
    // run; avoids materialising a vector of unique values just for its length.
    let unique_value_count = 1 + ranked_values
        .windows(2)
        .filter(|pair| pair[0].1.total_cmp(&pair[1].1) != Ordering::Equal)
        .count();
    let bin_count =
        resolved_numeric_bin_count(ranked_values.len(), unique_value_count, numeric_bins);
    let mut unique_rank = 0usize;
    let mut start = 0usize;
    let mut boundaries: Vec<(u16, f64)> = Vec::new();
    while start < ranked_values.len() {
        let current_value = ranked_values[start].1;
        // First index past the run of values equal to `current_value`.
        let end = ranked_values[start..]
            .iter()
            .position(|(_row_idx, value)| value.total_cmp(&current_value) != Ordering::Equal)
            .map_or(ranked_values.len(), |offset| start + offset);
        let bin = match numeric_bins {
            // Auto: equal-frequency binning by row rank.
            NumericBins::Auto => ((start * bin_count) / ranked_values.len()) as u16,
            // Fixed: spread distinct values evenly across the requested bins.
            NumericBins::Fixed(_) => {
                let max_bin = (bin_count - 1) as u16;
                if unique_value_count == 1 {
                    0
                } else {
                    ((unique_rank * usize::from(max_bin)) / (unique_value_count - 1)) as u16
                }
            }
        };
        // Extend the open boundary when this run lands in the same bin,
        // otherwise start a new (bin, upper_bound) entry.
        if let Some((last_bin, last_upper_bound)) = boundaries.last_mut() {
            if *last_bin == bin {
                *last_upper_bound = current_value;
            } else {
                boundaries.push((bin, current_value));
            }
        } else {
            boundaries.push((bin, current_value));
        }
        unique_rank += 1;
        start = end;
    }
    boundaries
}
/// Assigns each value of a numeric column to a bin id.
///
/// NaN values are treated as missing and mapped to `numeric_missing_bin`.
/// Non-missing values are rank-binned: `Auto` uses equal-frequency binning over
/// all rows, `Fixed` spreads the distinct values evenly across the requested
/// number of bins. Equal values (by `total_cmp`) always share a bin.
/// Must mirror the bin assignment logic of `numeric_bin_boundaries`.
fn bin_numeric_column(values: &[f64], numeric_bins: NumericBins) -> Vec<u16> {
    if values.is_empty() {
        return Vec::new();
    }
    let missing_bin = numeric_missing_bin(numeric_bins);
    let mut bins = vec![missing_bin; values.len()];
    let mut ranked_values: Vec<(usize, f64)> = values
        .iter()
        .copied()
        .enumerate()
        .filter(|(_, value)| !value.is_nan())
        .collect();
    if ranked_values.is_empty() {
        return bins;
    }
    ranked_values.sort_by(|left, right| left.1.total_cmp(&right.1));
    // Count distinct values by counting adjacent inequalities in the sorted
    // run; avoids materialising a vector of unique values just for its length.
    let unique_value_count = 1 + ranked_values
        .windows(2)
        .filter(|pair| pair[0].1.total_cmp(&pair[1].1) != Ordering::Equal)
        .count();
    let bin_count =
        resolved_numeric_bin_count(ranked_values.len(), unique_value_count, numeric_bins);
    let mut unique_rank = 0usize;
    let mut start = 0usize;
    while start < ranked_values.len() {
        let current_value = ranked_values[start].1;
        // First index past the run of values equal to `current_value`.
        let end = ranked_values[start..]
            .iter()
            .position(|(_row_idx, value)| value.total_cmp(&current_value) != Ordering::Equal)
            .map_or(ranked_values.len(), |offset| start + offset);
        let bin = match numeric_bins {
            // Auto: equal-frequency binning by row rank.
            NumericBins::Auto => ((start * bin_count) / ranked_values.len()) as u16,
            // Fixed: spread distinct values evenly across the requested bins.
            NumericBins::Fixed(_) => {
                let max_bin = (bin_count - 1) as u16;
                if unique_value_count == 1 {
                    0
                } else {
                    ((unique_rank * usize::from(max_bin)) / (unique_value_count - 1)) as u16
                }
            }
        };
        // Write the bin id back to every original row in this run of equals.
        for (row_idx, _value) in &ranked_values[start..end] {
            bins[*row_idx] = bin;
        }
        unique_rank += 1;
        start = end;
    }
    bins
}
/// Number of bins actually used for a numeric column.
///
/// `Auto` keeps at least two rows per bin, never exceeds `MAX_NUMERIC_BINS` or
/// the distinct-value count, and rounds down to a power of two. `Fixed` caps
/// the request by the distinct-value count; both paths return at least 1.
fn resolved_numeric_bin_count(
    value_count: usize,
    unique_value_count: usize,
    numeric_bins: NumericBins,
) -> usize {
    match numeric_bins {
        NumericBins::Auto => {
            let populated_bin_cap = (value_count / 2).max(1);
            let cap = populated_bin_cap.min(MAX_NUMERIC_BINS);
            let capped_unique_values = unique_value_count.clamp(1, cap);
            highest_power_of_two_at_most(capped_unique_values)
        }
        NumericBins::Fixed(requested) => usize::max(1, usize::min(requested, unique_value_count)),
    }
}
/// Bin id reserved for missing (NaN) values in numeric columns: one past the
/// largest real bin id (real bins are `0..cap`).
/// NOTE(review): `NumericBins::Fixed` is a public variant, so an unvalidated
/// cap above `u16::MAX` would silently truncate here — confirm callers always
/// construct it via `NumericBins::fixed`.
pub fn numeric_missing_bin(numeric_bins: NumericBins) -> u16 {
    numeric_bins.cap() as u16
}
/// Largest power of two less than or equal to `value`, clamped to at least 1.
fn highest_power_of_two_at_most(value: usize) -> usize {
    if value == 0 {
        1
    } else {
        // ilog2 is the floor of log2, so shifting 1 by it gives the largest
        // power of two not exceeding `value`.
        1usize << value.ilog2()
    }
}
/// Produces a deterministically shuffled copy of a binned column for use as a
/// canary; the RNG seed is derived from (copy_index, source_index, len) inside
/// `shuffle_values`, so identical inputs yield identical canaries.
fn shuffle_canary_column(
    values: &BinnedFeatureColumn,
    copy_index: usize,
    source_index: usize,
) -> BinnedFeatureColumn {
    match values {
        BinnedFeatureColumn::NumericU8(column) => {
            let mut data: Vec<u8> = (0..column.len()).map(|idx| column.value(idx)).collect();
            shuffle_values(&mut data, copy_index, source_index);
            BinnedFeatureColumn::NumericU8(UInt8Array::from(data))
        }
        BinnedFeatureColumn::NumericU16(column) => {
            let mut data: Vec<u16> = (0..column.len()).map(|idx| column.value(idx)).collect();
            shuffle_values(&mut data, copy_index, source_index);
            BinnedFeatureColumn::NumericU16(UInt16Array::from(data))
        }
        BinnedFeatureColumn::Binary(column) => {
            BinnedFeatureColumn::Binary(shuffle_boolean_array(column, copy_index, source_index))
        }
    }
}
/// Deterministically shuffles a boolean column for use as a canary.
///
/// Collects entries as `Option<bool>` so null (missing) slots survive the
/// shuffle: reading `value(idx)` on a null slot would return whatever bit the
/// underlying buffer holds, silently converting missing entries into concrete
/// booleans and distorting the canary's missing-value distribution. For
/// null-free columns the output is unchanged (same seed, same permutation).
fn shuffle_boolean_array(
    values: &BooleanArray,
    copy_index: usize,
    source_index: usize,
) -> BooleanArray {
    let mut shuffled: Vec<Option<bool>> = values.iter().collect();
    shuffle_values(&mut shuffled, copy_index, source_index);
    BooleanArray::from(shuffled)
}
/// Deterministically shuffles a sparse binary column for use as a canary:
/// densify to an `n_rows`-long mask, shuffle the mask (this exact shape keeps
/// the RNG stream identical run to run), then re-sparsify.
fn shuffle_sparse_binary_column(
    values: &SparseBinaryColumn,
    n_rows: usize,
    copy_index: usize,
    source_index: usize,
) -> SparseBinaryColumn {
    let mut mask = vec![false; n_rows];
    for &row_idx in &values.row_indices {
        mask[row_idx] = true;
    }
    shuffle_values(&mut mask, copy_index, source_index);
    let row_indices = mask
        .iter()
        .enumerate()
        .filter(|(_, set)| **set)
        .map(|(row_idx, _)| row_idx)
        .collect();
    SparseBinaryColumn { row_indices }
}
/// Shuffles `values` in place with a seed derived from the canary's identity
/// and the slice length, so canaries are reproducible across runs while each
/// (copy, source) pair gets a distinct permutation.
fn shuffle_values<T>(values: &mut [T], copy_index: usize, source_index: usize) {
    const BASE_SEED: u64 = 0xA11CE5EED;
    // XOR is order-independent, so this accumulation matches the original
    // single-expression seed exactly.
    let mut seed = BASE_SEED;
    seed ^= (copy_index as u64) << 32;
    seed ^= source_index as u64;
    seed ^= (values.len() as u64) << 16;
    let mut rng = StdRng::seed_from_u64(seed);
    values.shuffle(&mut rng);
}
#[cfg(test)]
mod tests {
    //! Unit tests covering table construction, dense/sparse selection,
    //! numeric binning, and canary determinism.
    use super::*;
    use std::collections::{BTreeMap, BTreeSet};
    #[test]
    fn builds_arrow_backed_dense_table() {
        let table =
            DenseTable::new(vec![vec![0.0, 10.0], vec![1.0, 20.0]], vec![3.0, 5.0]).unwrap();
        assert_eq!(table.n_rows(), 2);
        assert_eq!(table.n_features(), 2);
        assert_eq!(table.canaries(), 2);
        assert_eq!(table.binned_feature_count(), 6);
        assert_eq!(table.feature_value(0, 0), 0.0);
        assert_eq!(table.feature_value(0, 1), 1.0);
        assert_eq!(table.target().value(0), 3.0);
        assert_eq!(table.target().value(1), 5.0);
        assert!(!table.is_canary_binned_feature(0));
        assert!(table.is_canary_binned_feature(2));
    }
    #[test]
    fn builds_sparse_table_for_all_binary_features() {
        let table = Table::with_canaries(
            vec![vec![0.0, 1.0], vec![1.0, 0.0], vec![1.0, 1.0]],
            vec![0.0, 1.0, 1.0],
            1,
        )
        .unwrap();
        assert_eq!(table.kind(), TableKind::Sparse);
        assert!(table.is_binary_feature(0));
        assert!(table.is_binary_feature(1));
        assert!(table.is_binary_binned_feature(0));
        assert_eq!(table.binned_feature_count(), 4);
    }
    #[test]
    fn builds_dense_table_when_any_feature_is_non_binary() {
        let table = Table::with_canaries(
            vec![vec![0.0, 1.5], vec![1.0, 0.0], vec![1.0, 2.0]],
            vec![0.0, 1.0, 1.0],
            1,
        )
        .unwrap();
        assert_eq!(table.kind(), TableKind::Dense);
        assert!(table.is_binary_feature(0));
        assert!(!table.is_binary_feature(1));
    }
    #[test]
    fn sparse_table_rejects_non_binary_columns() {
        let err =
            SparseTable::with_canaries(vec![vec![0.0, 2.0], vec![1.0, 0.0]], vec![0.0, 1.0], 0)
                .unwrap_err();
        assert_eq!(err, DenseTableError::NonBinaryColumn { column: 1 });
    }
    #[test]
    fn auto_bins_numeric_columns_into_power_of_two_bins_up_to_128() {
        let x: Vec<Vec<f64>> = (0..1024).map(|value| vec![value as f64]).collect();
        let y: Vec<f64> = vec![1.0; 1024];
        let table = DenseTable::with_canaries(x, y, 0).unwrap();
        assert_eq!(table.binned_value(0, 0), 0);
        assert_eq!(table.binned_value(0, 1023), 127);
        assert!((1..1024).all(|idx| table.binned_value(0, idx - 1) <= table.binned_value(0, idx)));
        assert_eq!(
            (0..1024)
                .map(|idx| table.binned_value(0, idx))
                .collect::<BTreeSet<_>>()
                .len(),
            128
        );
    }
    #[test]
    fn auto_bins_choose_highest_populated_power_of_two() {
        let x: Vec<Vec<f64>> = (0..300).map(|value| vec![value as f64]).collect();
        let y = vec![0.0; 300];
        let table = DenseTable::with_canaries(x, y, 0).unwrap();
        assert_eq!(
            (0..300)
                .map(|idx| table.binned_value(0, idx))
                .collect::<BTreeSet<_>>()
                .len(),
            128
        );
    }
    #[test]
    fn auto_bins_require_at_least_two_rows_per_bin() {
        let x: Vec<Vec<f64>> = (0..8).map(|value| vec![value as f64]).collect();
        let y = vec![0.0; 8];
        let table = DenseTable::with_canaries(x, y, 0).unwrap();
        let counts = (0..table.n_rows()).fold(BTreeMap::new(), |mut counts, row_idx| {
            *counts
                .entry(table.binned_value(0, row_idx))
                .or_insert(0usize) += 1;
            counts
        });
        assert_eq!(counts.len(), 4);
        assert!(counts.values().all(|count| *count >= 2));
    }
    #[test]
    fn fixed_bins_cap_numeric_columns_to_requested_limit() {
        let x: Vec<Vec<f64>> = (0..300).map(|value| vec![value as f64]).collect();
        let y = vec![0.0; 300];
        let table = DenseTable::with_options(x, y, 0, NumericBins::Fixed(64)).unwrap();
        assert_eq!(
            (0..300)
                .map(|idx| table.binned_value(0, idx))
                .collect::<BTreeSet<_>>()
                .len(),
            64
        );
    }
    #[test]
    fn rejects_invalid_fixed_bin_count() {
        assert_eq!(
            NumericBins::fixed(0).unwrap_err(),
            DenseTableError::InvalidBinCount { requested: 0 }
        );
        assert_eq!(
            NumericBins::fixed(513).unwrap_err(),
            DenseTableError::InvalidBinCount { requested: 513 }
        );
    }
    #[test]
    fn keeps_equal_values_in_the_same_bin() {
        let table = DenseTable::with_canaries(
            vec![vec![0.0], vec![0.0], vec![1.0], vec![1.0], vec![2.0]],
            vec![0.0; 5],
            0,
        )
        .unwrap();
        assert_eq!(table.binned_value(0, 0), table.binned_value(0, 1));
        assert_eq!(table.binned_value(0, 2), table.binned_value(0, 3));
        assert!(table.binned_value(0, 1) <= table.binned_value(0, 2));
        assert!(table.binned_value(0, 3) < table.binned_value(0, 4));
    }
    #[test]
    fn stores_binary_columns_as_booleans() {
        let table = DenseTable::with_canaries(
            vec![vec![0.0, 2.0], vec![1.0, 3.0], vec![0.0, 4.0]],
            vec![0.0; 3],
            1,
        )
        .unwrap();
        assert!(table.is_binary_feature(0));
        assert!(!table.is_binary_feature(1));
        assert!(table.is_binary_binned_feature(0));
        assert!(!table.is_binary_binned_feature(1));
        assert!(table.is_binary_binned_feature(2));
        assert_eq!(table.feature_value(0, 0), 0.0);
        assert_eq!(table.feature_value(0, 1), 1.0);
        assert_eq!(table.binned_boolean_value(0, 0), Some(false));
        assert_eq!(table.binned_boolean_value(0, 1), Some(true));
    }
    #[test]
    fn stores_small_auto_binned_numeric_columns_as_u8() {
        let table = DenseTable::with_canaries(
            (0..8).map(|value| vec![value as f64]).collect(),
            vec![0.0; 8],
            0,
        )
        .unwrap();
        assert!(matches!(
            table.binned_feature_column(0),
            BinnedFeatureColumnRef::NumericU8(_)
        ));
    }
    #[test]
    fn creates_canary_columns_as_shuffled_binned_copies() {
        let table = DenseTable::with_canaries(
            vec![vec![0.0], vec![1.0], vec![2.0], vec![3.0], vec![4.0]],
            vec![0.0; 5],
            1,
        )
        .unwrap();
        assert!(matches!(
            table.binned_column_kind(1),
            BinnedColumnKind::Canary {
                source_index: 0,
                copy_index: 0
            }
        ));
        // Canary must be a permutation of the real column (same bin multiset)…
        assert_eq!(
            (0..table.n_rows())
                .map(|idx| table.binned_value(0, idx))
                .collect::<BTreeSet<_>>(),
            (0..table.n_rows())
                .map(|idx| table.binned_value(1, idx))
                .collect::<BTreeSet<_>>()
        );
        // …but not the identity permutation.
        assert_ne!(
            (0..table.n_rows())
                .map(|idx| table.binned_value(0, idx))
                .collect::<Vec<_>>(),
            (0..table.n_rows())
                .map(|idx| table.binned_value(1, idx))
                .collect::<Vec<_>>()
        );
    }
    #[test]
    fn rejects_ragged_rows() {
        let err = DenseTable::new(vec![vec![1.0, 2.0], vec![3.0]], vec![1.0, 2.0]).unwrap_err();
        assert_eq!(
            err,
            DenseTableError::RaggedRows {
                row: 1,
                expected: 2,
                actual: 1,
            }
        );
    }
    #[test]
    fn rejects_mismatched_lengths() {
        let err = DenseTable::new(vec![vec![1.0], vec![2.0]], vec![1.0]).unwrap_err();
        assert_eq!(err, DenseTableError::MismatchedLengths { x: 2, y: 1 });
    }
    #[test]
    fn canary_generation_is_deterministic_for_identical_inputs() {
        let left = DenseTable::with_canaries(
            vec![vec![0.0], vec![1.0], vec![2.0], vec![3.0], vec![4.0]],
            vec![0.0; 5],
            2,
        )
        .unwrap();
        let right = DenseTable::with_canaries(
            vec![vec![0.0], vec![1.0], vec![2.0], vec![3.0], vec![4.0]],
            vec![0.0; 5],
            2,
        )
        .unwrap();
        let left_values = binned_snapshot(&left);
        let right_values = binned_snapshot(&right);
        assert_eq!(left_values, right_values);
    }
    #[test]
    fn binary_canaries_remain_boolean_and_preserve_value_counts() {
        let table = DenseTable::with_canaries(
            vec![
                vec![0.0],
                vec![1.0],
                vec![0.0],
                vec![1.0],
                vec![1.0],
                vec![0.0],
            ],
            vec![0.0; 6],
            2,
        )
        .unwrap();
        let real_true_count = (0..table.n_rows())
            .filter(|row_idx| table.binned_boolean_value(0, *row_idx) == Some(true))
            .count();
        for feature_index in 1..table.binned_feature_count() {
            assert!(table.is_binary_binned_feature(feature_index));
            let canary_true_count = (0..table.n_rows())
                .filter(|row_idx| table.binned_boolean_value(feature_index, *row_idx) == Some(true))
                .count();
            assert_eq!(canary_true_count, real_true_count);
        }
    }
    #[test]
    fn numeric_bin_boundaries_capture_training_bin_upper_bounds() {
        let boundaries = numeric_bin_boundaries(&[1.0, 1.0, 2.0, 10.0], NumericBins::Auto);
        assert_eq!(boundaries, vec![(0, 1.0), (1, 10.0)]);
    }
    // Flattens every binned column (real and canary) into one Vec for
    // whole-table equality comparisons.
    fn binned_snapshot(table: &DenseTable) -> Vec<u16> {
        let mut values = Vec::new();
        for feature_idx in 0..table.binned_feature_count() {
            for row_idx in 0..table.n_rows() {
                values.push(table.binned_value(feature_idx, row_idx));
            }
        }
        values
    }
}