use serde::{Deserialize, Serialize};
/// Activation selector carried by [`TabICLConfig`] and forwarded unchanged to
/// every stage configuration built in `TabICL::new`.
///
/// With serde the variants are written and read as their lowercase names
/// (`"gelu"`, `"relu"`, `"silu"`). The enum is `#[non_exhaustive]`, so code in
/// other crates must keep a wildcard arm when matching on it.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
#[non_exhaustive]
pub enum Activation {
    /// Default variant.
    #[default]
    Gelu,
    Relu,
    Silu,
}
/// Feature-grouping mode handed to the column embedder (`feature_group` in
/// its configuration).
///
/// With serde the variants are written and read as their lowercase names
/// (`"none"`, `"same"`, `"valid"`). The enum is `#[non_exhaustive]`, so code in
/// other crates must keep a wildcard arm when matching on it.
// NOTE(review): the exact grouping semantics of each variant live in the
// column embedder and are not visible from this file.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
#[non_exhaustive]
pub enum ColFeatureGroup {
    None,
    /// Default variant.
    #[default]
    Same,
    Valid,
}
/// Hyper-parameters for a [`TabICL`] model.
///
/// `TabICL::new` fans these values out into the configurations of the three
/// stages (column embedder, row interactor, ICL predictor). Fields prefixed
/// `col_`, `row_` and `icl_` go to the respective stage only; unprefixed
/// fields are shared by all stages.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TabICLConfig {
/// Number of classes the model can predict. `0` switches the model to
/// regression (see `is_regression`).
pub max_classes: usize,
/// Output width used instead of `max_classes` when `max_classes == 0`
/// (see `out_dim`).
pub num_quantiles: usize,
/// Embedding width of the column and row stages. The ICL stage width is
/// `embed_dim * row_num_cls` (see `icl_dim`).
pub embed_dim: usize,
/// Column embedder: `num_blocks`.
pub col_num_blocks: usize,
/// Column embedder: `nhead`.
pub col_nhead: usize,
/// Column embedder: `num_inds`.
pub col_num_inds: usize,
/// Column embedder: `affine`.
pub col_affine: bool,
/// Column embedder: `feature_group`.
pub col_feature_group: ColFeatureGroup,
/// Column embedder: `feature_group_size`.
pub col_feature_group_size: usize,
/// Column embedder: `target_aware`.
pub col_target_aware: bool,
/// Column embedder: `ssmax` setting, passed through as a string.
pub col_ssmax: String,
/// Row interactor: `num_blocks`.
pub row_num_blocks: usize,
/// Row interactor: `nhead`.
pub row_nhead: usize,
/// Row interactor: `num_cls`. Also passed to the column embedder as
/// `reserve_cls_tokens` and used as the multiplier in `icl_dim`.
pub row_num_cls: usize,
/// Row interactor: `rope_base`.
pub row_rope_base: f32,
/// Row interactor: `rope_interleaved`.
pub row_rope_interleaved: bool,
/// ICL predictor: `num_blocks`.
pub icl_num_blocks: usize,
/// ICL predictor: `nhead`.
pub icl_nhead: usize,
/// ICL predictor: `ssmax` setting, passed through as a string.
pub icl_ssmax: String,
/// Every stage's `dim_feedforward` is its model width times this factor.
pub ff_factor: usize,
/// Shared by all stages: `dropout`.
pub dropout: f32,
/// Shared by all stages: `activation`.
pub activation: Activation,
/// Shared by all stages: `norm_first`.
pub norm_first: bool,
/// Shared by all stages: `bias_free_ln`.
pub bias_free_ln: bool,
/// Shared by all stages: `recompute`.
pub recompute: bool,
}
impl Default for TabICLConfig {
    /// Builds the stock configuration: a 10-class classifier with 128-wide
    /// embeddings, 3 column blocks, 3 row blocks and 12 ICL blocks.
    ///
    /// The values are presumably meant to mirror the reference Python
    /// implementation (the `defaults_match_python` test pins a subset).
    fn default() -> Self {
        // `col_ssmax` and `icl_ssmax` start out with the same setting.
        const DEFAULT_SSMAX: &str = "qassmax-mlp-elementwise";

        Self {
            // Output head.
            max_classes: 10,
            num_quantiles: 999,
            embed_dim: 128,

            // Column embedder.
            col_num_blocks: 3,
            col_nhead: 8,
            col_num_inds: 128,
            col_affine: false,
            col_feature_group: ColFeatureGroup::Same,
            col_feature_group_size: 3,
            col_target_aware: true,
            col_ssmax: String::from(DEFAULT_SSMAX),

            // Row interactor.
            row_num_blocks: 3,
            row_nhead: 8,
            row_num_cls: 4,
            row_rope_base: 100_000.0,
            row_rope_interleaved: false,

            // ICL predictor.
            icl_num_blocks: 12,
            icl_nhead: 8,
            icl_ssmax: String::from(DEFAULT_SSMAX),

            // Settings shared by every stage.
            ff_factor: 2,
            dropout: 0.0,
            activation: Activation::Gelu,
            norm_first: true,
            bias_free_ln: false,
            recompute: false,
        }
    }
}
impl TabICLConfig {
    /// Width of the prediction output: `num_quantiles` when `max_classes` is
    /// zero, otherwise `max_classes`.
    pub fn out_dim(&self) -> usize {
        match self.max_classes {
            0 => self.num_quantiles,
            class_count => class_count,
        }
    }

    /// Model width of the ICL stage: `embed_dim * row_num_cls`.
    pub fn icl_dim(&self) -> usize {
        let per_token_width = self.embed_dim;
        per_token_width * self.row_num_cls
    }

    /// `true` when the configuration describes a regression model, which is
    /// encoded as `max_classes == 0`.
    pub fn is_regression(&self) -> bool {
        matches!(self.max_classes, 0)
    }
}
/// The full model: a column embedder, a row interactor and an ICL predictor,
/// together with the configuration they were built from.
///
/// `forward` runs the stages in the order `col` -> `row` -> `icl`.
#[derive(Debug, Clone)]
pub struct TabICL {
/// Configuration passed to `TabICL::new`.
pub config: TabICLConfig,
/// First stage; weights are loaded under the `col_embedder` prefix.
pub col: crate::embedding::ColEmbedding,
/// Second stage; weights are loaded under the `row_interactor` prefix.
pub row: crate::interaction::RowInteraction,
/// Final stage; weights are loaded under the `icl_predictor` prefix.
pub icl: crate::learning::ICLearning,
}
impl TabICL {
    /// Loads the weights of all three stages from an in-memory state dict.
    ///
    /// The stages are loaded in the order column embedder (`col_embedder`),
    /// row interactor (`row_interactor`), ICL predictor (`icl_predictor`).
    ///
    /// # Errors
    /// Returns the first error reported by a stage; later stages are not
    /// attempted.
    pub fn load_from(&mut self, sd: &crate::StateDict) -> Result<(), crate::StateDictError> {
        self.col.load_from(sd, "col_embedder")?;
        self.row.load_from(sd, "row_interactor")?;
        self.icl.load_from(sd, "icl_predictor")?;
        Ok(())
    }

    /// Reads a state dict from `path` and loads it via [`TabICL::load_from`].
    ///
    /// # Errors
    /// Returns an error if the file cannot be loaded or if any stage rejects
    /// the loaded state dict.
    pub fn load_from_file(
        &mut self,
        path: impl AsRef<std::path::Path>,
    ) -> Result<(), crate::StateDictError> {
        let sd = crate::state_dict::load(path)?;
        self.load_from(&sd)
    }

    /// Builds a model from `config`, constructing each of the three stages
    /// from the matching subset of settings.
    pub fn new(config: TabICLConfig) -> Self {
        // The column and row stages run at `embed_dim`; the ICL stage runs at
        // `icl_dim()` (= embed_dim * row_num_cls). Each stage's feed-forward
        // width is its own model width times `ff_factor`.
        let ff_dim = config.embed_dim * config.ff_factor;
        let icl_dim = config.icl_dim();

        let col = crate::embedding::ColEmbedding::new(crate::embedding::ColEmbeddingConfig {
            embed_dim: config.embed_dim,
            num_blocks: config.col_num_blocks,
            nhead: config.col_nhead,
            dim_feedforward: ff_dim,
            num_inds: config.col_num_inds,
            dropout: config.dropout,
            activation: config.activation,
            norm_first: config.norm_first,
            bias_free_ln: config.bias_free_ln,
            affine: config.col_affine,
            feature_group: config.col_feature_group,
            feature_group_size: config.col_feature_group_size,
            target_aware: config.col_target_aware,
            max_classes: config.max_classes,
            // The column stage reserves as many CLS slots as the row stage uses.
            reserve_cls_tokens: config.row_num_cls,
            ssmax: config.col_ssmax.clone(),
            // Always enabled; not exposed through `TabICLConfig`.
            mixed_radix_ensemble: true,
            recompute: config.recompute,
        });

        let row =
            crate::interaction::RowInteraction::new(crate::interaction::RowInteractionConfig {
                embed_dim: config.embed_dim,
                num_blocks: config.row_num_blocks,
                nhead: config.row_nhead,
                dim_feedforward: ff_dim,
                num_cls: config.row_num_cls,
                rope_base: config.row_rope_base,
                rope_interleaved: config.row_rope_interleaved,
                dropout: config.dropout,
                activation: config.activation,
                norm_first: config.norm_first,
                bias_free_ln: config.bias_free_ln,
                recompute: config.recompute,
            });

        let icl = crate::learning::ICLearning::new(crate::learning::ICLearningConfig {
            max_classes: config.max_classes,
            out_dim: config.out_dim(),
            d_model: icl_dim,
            num_blocks: config.icl_num_blocks,
            nhead: config.icl_nhead,
            dim_feedforward: icl_dim * config.ff_factor,
            dropout: config.dropout,
            activation: config.activation,
            norm_first: config.norm_first,
            bias_free_ln: config.bias_free_ln,
            ssmax: config.icl_ssmax.clone(),
            recompute: config.recompute,
        });

        Self {
            config,
            col,
            row,
            icl,
        }
    }

    /// Runs the full pipeline (column embedding, row interaction, ICL
    /// prediction) on a batch.
    ///
    /// * `x` - input features; presumably `(batch, rows, features)`, matching
    ///   the single-batch layout built by [`TabICL::build_repr_cache`].
    /// * `y_train_class` / `y_train_reg` - optional training targets. The
    ///   number of training rows is taken from axis 1 of `y_train_class` if
    ///   present, else from `y_train_reg`, else it is `0`.
    ///
    /// The in-file tests observe an output of shape `(batch, rows, out_dim)`.
    ///
    /// # Errors
    /// Propagates any error from the column embedder.
    pub fn forward(
        &self,
        x: ndarray::ArrayView3<f32>,
        y_train_class: Option<ndarray::ArrayView2<usize>>,
        y_train_reg: Option<ndarray::ArrayView2<f32>>,
    ) -> Result<ndarray::Array3<f32>, crate::embedding::EmbeddingError> {
        let train_size = y_train_class
            .map(|y| y.ncols())
            .or_else(|| y_train_reg.map(|y| y.ncols()))
            .unwrap_or(0);
        let emb = self
            .col
            .forward_with_targets(x, y_train_class, y_train_reg, train_size)?;
        Ok(self.row_then_icl(emb.view(), y_train_class, y_train_reg))
    }

    /// Computes the row representations of a training set and stores them in
    /// a cache for later use by [`TabICL::forward_with_cache`].
    ///
    /// `x_train` is `(n_train, features)`; it and the optional targets are
    /// lifted to a batch of one before being fed through the column embedder
    /// and the row interactor. The cache records the shape
    /// `(1, n_train, features)` and, for classification models
    /// (`max_classes > 0`), the class count.
    ///
    /// # Errors
    /// Propagates any error from the column embedder.
    pub fn build_repr_cache(
        &self,
        x_train: ndarray::ArrayView2<f32>,
        y_train_class: Option<ndarray::ArrayView1<usize>>,
        y_train_reg: Option<ndarray::ArrayView1<f32>>,
    ) -> Result<crate::kv_cache::TabICLCache, crate::embedding::EmbeddingError> {
        let (n_train, h) = x_train.dim();

        // Copy into a freshly allocated standard-layout buffer with a leading
        // batch axis of length one.
        let mut x_b = ndarray::Array3::<f32>::zeros((1, n_train, h));
        x_b.index_axis_mut(ndarray::Axis(0), 0).assign(&x_train);

        let y_cls_b = Self::batch_targets(y_train_class);
        let y_reg_b = Self::batch_targets(y_train_reg);

        let emb = self.col.forward_with_targets(
            x_b.view(),
            y_cls_b.as_ref().map(|a| a.view()),
            y_reg_b.as_ref().map(|a| a.view()),
            n_train,
        )?;
        let r = self.row.forward(emb.view());

        let num_classes = (self.config.max_classes > 0).then_some(self.config.max_classes);
        Ok(crate::kv_cache::TabICLCache::from_row_repr(
            r,
            (1, n_train, h),
            num_classes,
        ))
    }

    /// Predicts for `x_test` using training-row representations taken from
    /// `cache` instead of recomputing them.
    ///
    /// The number of training rows is axis 1 of the cached representation.
    /// `x_test` is `(n_test, features)`. The result covers all
    /// `n_train + n_test` rows, training rows first.
    ///
    /// # Errors
    /// Propagates any error from the column embedder.
    ///
    /// # Panics
    /// Panics if `cache.row_repr` is `None`, i.e. the cache was not produced
    /// by [`TabICL::build_repr_cache`], or if the cached representation is
    /// wider than the freshly computed one.
    pub fn forward_with_cache(
        &self,
        cache: &crate::kv_cache::TabICLCache,
        x_test: ndarray::ArrayView2<f32>,
        y_train_class: Option<ndarray::ArrayView1<usize>>,
        y_train_reg: Option<ndarray::ArrayView1<f32>>,
    ) -> Result<ndarray::Array3<f32>, crate::embedding::EmbeddingError> {
        let r_train = cache
            .row_repr
            .as_ref()
            .expect("cache must be a repr-cache (call build_repr_cache first)");
        let n_train = r_train.shape()[1];
        let (n_test, h) = x_test.dim();

        // Training rows are left at zero: their row representations are
        // replaced by the cached ones further down.
        // NOTE(review): the column embedder therefore sees all-zero training
        // features here. Whether its output for the test rows depends on the
        // training rows is not visible from this file - confirm that test
        // embeddings match those of a full `forward` call.
        let mut x_b = ndarray::Array3::<f32>::zeros((1, n_train + n_test, h));
        x_b.slice_mut(ndarray::s![0, n_train.., ..]).assign(&x_test);

        let y_cls_b = Self::batch_targets(y_train_class);
        let y_reg_b = Self::batch_targets(y_train_reg);

        let emb = self.col.forward_with_targets(
            x_b.view(),
            y_cls_b.as_ref().map(|a| a.view()),
            y_reg_b.as_ref().map(|a| a.view()),
            n_train,
        )?;
        let mut r = self.row.forward(emb.view());

        // Overwrite the leading `n_train` rows of batch 0 with the cached
        // representations, limited to the cached width.
        let cached_width = r_train.shape()[2];
        r.slice_mut(ndarray::s![0, ..n_train, ..cached_width])
            .assign(&r_train.slice(ndarray::s![0, .., ..]));

        Ok(self.icl.forward(
            r.view(),
            y_cls_b.as_ref().map(|a| a.view()),
            y_reg_b.as_ref().map(|a| a.view()),
        ))
    }

    /// Runs the last two stages (row interaction, then ICL prediction) on
    /// already computed column embeddings.
    pub fn row_then_icl(
        &self,
        embeddings: ndarray::ArrayView4<f32>,
        y_train_class: Option<ndarray::ArrayView2<usize>>,
        y_train_reg: Option<ndarray::ArrayView2<f32>>,
    ) -> ndarray::Array3<f32> {
        let r = self.row.forward(embeddings);
        self.icl.forward(r.view(), y_train_class, y_train_reg)
    }

    /// Copies an optional 1-D target vector into an owned, standard-layout
    /// `(1, len)` array, i.e. a batch containing that single vector.
    fn batch_targets<T: Copy>(
        y: Option<ndarray::ArrayView1<'_, T>>,
    ) -> Option<ndarray::Array2<T>> {
        y.map(|v| ndarray::Array2::<T>::from_shape_fn((1, v.len()), |(_, i)| v[i]))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Small classification config with every field spelled out, so the tests
    /// built on it do not depend on `TabICLConfig::default()`.
    fn tiny_explicit_config() -> TabICLConfig {
        TabICLConfig {
            max_classes: 3,
            num_quantiles: 999,
            embed_dim: 8,
            col_num_blocks: 0,
            col_nhead: 2,
            col_num_inds: 4,
            col_affine: false,
            col_feature_group: ColFeatureGroup::Same,
            col_feature_group_size: 3,
            col_target_aware: true,
            col_ssmax: "none".into(),
            row_num_blocks: 1,
            row_nhead: 2,
            row_num_cls: 2,
            row_rope_base: 100_000.0,
            row_rope_interleaved: false,
            icl_num_blocks: 1,
            icl_nhead: 2,
            icl_ssmax: "none".into(),
            ff_factor: 2,
            dropout: 0.0,
            activation: Activation::Gelu,
            norm_first: true,
            bias_free_ln: false,
            recompute: false,
        }
    }

    /// Default config shrunk to a size that runs quickly; everything not
    /// listed keeps its default value.
    fn tiny_default_config() -> TabICLConfig {
        TabICLConfig {
            embed_dim: 8,
            col_num_blocks: 1,
            col_nhead: 2,
            col_num_inds: 4,
            row_num_blocks: 1,
            row_nhead: 2,
            row_num_cls: 2,
            icl_num_blocks: 1,
            icl_nhead: 2,
            max_classes: 3,
            ..TabICLConfig::default()
        }
    }

    #[test]
    fn defaults_match_python() {
        let defaults = TabICLConfig::default();
        assert_eq!(defaults.max_classes, 10);
        assert_eq!(defaults.num_quantiles, 999);
        assert_eq!(defaults.embed_dim, 128);
        assert_eq!(defaults.row_num_cls, 4);
        assert_eq!(defaults.icl_num_blocks, 12);
        assert_eq!(defaults.icl_dim(), 128 * 4);
        assert_eq!(defaults.out_dim(), 10);
        assert!(defaults.col_target_aware);
        assert!(!defaults.col_affine);
    }

    #[test]
    fn regression_out_dim() {
        let regression = TabICLConfig {
            max_classes: 0,
            ..TabICLConfig::default()
        };
        assert!(regression.is_regression());
        assert_eq!(regression.out_dim(), 999);
    }

    #[test]
    fn row_then_icl_runs_end_to_end_classification() {
        let model = TabICL::new(tiny_explicit_config());
        let embeddings = ndarray::Array::from_shape_fn((1, 3, 4, 8), |(b, t, h, e)| {
            (b * 100 + t * 10 + h) as f32 * 0.001 + (e as f32) * 0.0001
        });
        let labels = ndarray::arr2(&[[0_usize, 1]]);
        let logits = model.row_then_icl(embeddings.view(), Some(labels.view()), None);
        assert_eq!(logits.dim(), (1, 3, 3));
    }

    #[test]
    fn end_to_end_col_row_icl_classification() {
        let model = TabICL::new(TabICLConfig {
            col_num_blocks: 1,
            col_feature_group: ColFeatureGroup::None,
            col_target_aware: false,
            ..tiny_explicit_config()
        });
        let features =
            ndarray::Array::from_shape_fn((1, 5, 4), |(_, t, h)| (t * 4 + h) as f32 * 0.1 - 0.5);
        let labels = ndarray::arr2(&[[0_usize, 1, 2]]);
        let logits = model
            .forward(features.view(), Some(labels.view()), None)
            .unwrap();
        assert_eq!(logits.dim(), (1, 5, 3));
    }

    #[test]
    fn repr_cache_roundtrip_runs_end_to_end() {
        let model = TabICL::new(TabICLConfig {
            col_feature_group: ColFeatureGroup::None,
            ..tiny_default_config()
        });
        let train_x = ndarray::Array::from_shape_fn((4, 3), |(i, j)| (i * 3 + j) as f32 * 0.1);
        let train_y = ndarray::arr1(&[0_usize, 1, 2, 0]);
        let cache = model
            .build_repr_cache(train_x.view(), Some(train_y.view()), None)
            .unwrap();
        assert!(!cache.is_empty());
        assert!(cache.row_repr.is_some());

        let test_x = ndarray::Array::from_shape_fn((2, 3), |(_, j)| j as f32);
        let logits = model
            .forward_with_cache(&cache, test_x.view(), Some(train_y.view()), None)
            .unwrap();
        // 4 training rows followed by 2 test rows.
        assert_eq!(logits.dim(), (1, 6, 3));
    }

    #[test]
    fn end_to_end_target_aware_default_config_runs() {
        let model = TabICL::new(tiny_default_config());
        let features =
            ndarray::Array::from_shape_fn((1, 5, 4), |(_, t, h)| (t * 4 + h) as f32 * 0.1);
        let labels = ndarray::arr2(&[[0_usize, 1, 2]]);
        let logits = model
            .forward(features.view(), Some(labels.view()), None)
            .unwrap();
        assert_eq!(logits.dim(), (1, 5, 3));
    }

    #[test]
    fn row_then_icl_regression_path() {
        let model = TabICL::new(TabICLConfig {
            max_classes: 0,
            num_quantiles: 5,
            ..tiny_explicit_config()
        });
        assert!(model.config.is_regression());
        let embeddings =
            ndarray::Array::from_shape_fn((1, 3, 4, 8), |(_, t, _, e)| (t * 8 + e) as f32 * 0.01);
        let targets = ndarray::arr2(&[[0.5_f32, 1.5]]);
        let quantiles = model.row_then_icl(embeddings.view(), None, Some(targets.view()));
        assert_eq!(quantiles.dim(), (1, 3, 5));
    }
}