use ndarray::{Array2, Array3, ArrayView2, ArrayView3};
use serde::{Deserialize, Serialize};
use crate::encoders::{EncoderStack, MabConfig};
use crate::layers::{OneHotAndLinear, layer_norm_last, linear3d};
use crate::state_dict::{StateDict, StateDictError};
use crate::tabicl::Activation;
/// Hyper-parameters of the in-context-learning (ICL) head defined in this file.
///
/// Derives serde `Serialize`/`Deserialize`; field order is therefore part of
/// the serialized layout for order-sensitive formats.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ICLearningConfig {
/// Number of classes handled by the label encoder; `0` selects regression
/// (see `is_regression`).
pub max_classes: usize,
/// Output width of the decoder's second linear layer.
pub out_dim: usize,
/// Feature width; `forward` asserts that the last axis of its input equals it.
pub d_model: usize,
/// Number of blocks passed to `EncoderStack::new`.
pub num_blocks: usize,
/// Forwarded unchanged to `MabConfig`.
pub nhead: usize,
/// Forwarded unchanged to `MabConfig`.
pub dim_feedforward: usize,
/// Forwarded unchanged to `MabConfig`.
pub dropout: f32,
/// Forwarded unchanged to `MabConfig`.
pub activation: Activation,
/// Forwarded to `MabConfig`; also enables the layer norm applied before the
/// decoder (its scale is created/loaded only when this is set).
pub norm_first: bool,
/// Forwarded to `MabConfig`; when set, `ICLearningParams::zeros` creates no
/// layer-norm shift vector.
pub bias_free_ln: bool,
/// Name handed to `SsmaxKind::parse` in `load_from`; values that do not parse
/// fall back to `SsmaxKind::None`.
pub ssmax: String,
/// Not read anywhere in this file — presumably kept for parity with the
/// reference configuration; confirm before relying on it.
pub recompute: bool,
}
impl ICLearningConfig {
    /// Returns `true` when the configuration describes a regression task,
    /// which is encoded as `max_classes == 0`.
    pub fn is_regression(&self) -> bool {
        matches!(self.max_classes, 0)
    }
}
/// Learnable tensors owned directly by the ICL head (the encoder stack keeps
/// its own parameters).
#[derive(Debug, Clone)]
pub struct ICLearningParams {
/// Label encoder used for classification; `zeros` creates it only when
/// `max_classes > 0`.
pub y_one_hot: Option<OneHotAndLinear>,
/// `(weight, optional bias)` of the label encoder used for regression;
/// `zeros` creates the weight with shape `(d_model, 1)`.
pub y_linear: Option<(Array2<f32>, Option<Vec<f32>>)>,
/// Scale of the layer norm applied before the decoder; `None` makes
/// `forward` skip that normalisation entirely.
pub ln_gamma: Option<Vec<f32>>,
/// Optional shift of the layer norm applied before the decoder.
pub ln_beta: Option<Vec<f32>>,
/// First decoder weight, shape `(2 * d_model, d_model)`.
pub decoder_w1: Array2<f32>,
/// Optional first decoder bias, length `2 * d_model`.
pub decoder_b1: Option<Vec<f32>>,
/// Second decoder weight, shape `(out_dim, 2 * d_model)`.
pub decoder_w2: Array2<f32>,
/// Optional second decoder bias, length `out_dim`.
pub decoder_b2: Option<Vec<f32>>,
}
impl ICLearningParams {
    /// Builds a parameter set whose shapes follow `cfg` and whose values are
    /// all zero, except the layer-norm scale which is all ones.
    ///
    /// * The one-hot label encoder exists only when `cfg.max_classes > 0`;
    ///   the linear label encoder exists only when `cfg.is_regression()`.
    /// * The layer-norm scale exists only when `cfg.norm_first`; the shift
    ///   additionally requires `!cfg.bias_free_ln`.
    /// * The decoder hidden width is `2 * cfg.d_model` and both decoder biases
    ///   are always present.
    pub fn zeros(cfg: &ICLearningConfig) -> Self {
        let width = cfg.d_model;
        let hidden = width * 2;
        let outputs = cfg.out_dim;

        let y_one_hot = (cfg.max_classes > 0).then(|| {
            OneHotAndLinear::from_raw_weight(
                Array2::<f32>::zeros((width, cfg.max_classes)),
                Some(vec![0.0; width]),
            )
        });
        let y_linear = cfg
            .is_regression()
            .then(|| (Array2::<f32>::zeros((width, 1)), Some(vec![0.0_f32; width])));
        let ln_gamma = cfg.norm_first.then(|| vec![1.0_f32; width]);
        let ln_beta = (cfg.norm_first && !cfg.bias_free_ln).then(|| vec![0.0_f32; width]);

        Self {
            y_one_hot,
            y_linear,
            ln_gamma,
            ln_beta,
            decoder_w1: Array2::<f32>::zeros((hidden, width)),
            decoder_b1: Some(vec![0.0_f32; hidden]),
            decoder_w2: Array2::<f32>::zeros((outputs, hidden)),
            decoder_b2: Some(vec![0.0_f32; outputs]),
        }
    }
}
/// In-context-learning head: a label encoder, an encoder stack, an optional
/// layer norm and a two-layer decoder.
#[derive(Debug, Clone)]
pub struct ICLearning {
/// Hyper-parameters this instance was built from.
pub config: ICLearningConfig,
/// Label-encoder, layer-norm and decoder tensors.
pub params: ICLearningParams,
/// Encoder stack applied to the label-augmented representations.
pub encoder: EncoderStack,
}
impl ICLearning {
    /// Loads every tensor of this head from `sd`, reading keys below `prefix`.
    ///
    /// Keys (all prefixed with `"{prefix}."`):
    /// * `tf_icl` — passed as the prefix for the encoder stack's own loader;
    /// * `ln.weight` and, if present, `ln.bias` — only when
    ///   `config.norm_first` is set;
    /// * `y_encoder` — prefix for the one-hot label encoder, or
    ///   `y_encoder.weight` (+ optional `y_encoder.bias`) for regression;
    /// * `decoder.0.weight`, `decoder.2.weight` and their optional `.bias`.
    ///
    /// An optional bias that is missing from `sd` leaves the current value of
    /// the corresponding field untouched.
    ///
    /// # Errors
    /// Returns the first error raised by a state-dict accessor or a sub-module
    /// loader. Fields assigned before the failure keep their new values.
    pub fn load_from(&mut self, sd: &StateDict, prefix: &str) -> Result<(), StateDictError> {
        let key = |suffix: &str| format!("{prefix}.{suffix}");
        // Optional tensors: `Ok(None)` when the key is absent, so the caller
        // can keep whatever value the field currently holds.
        let take_optional = |name: String,
                             len: usize|
         -> Result<Option<Vec<f32>>, StateDictError> {
            if sd.tensors.contains_key(&name) {
                Ok(Some(sd.take_vec(&name, len)?))
            } else {
                Ok(None)
            }
        };

        let d = self.config.d_model;
        let hid = d * 2;
        let out_dim = self.config.out_dim;

        // NOTE(review): an `ssmax` string that does not parse silently falls
        // back to `SsmaxKind::None` — confirm this best-effort is intended.
        let ssmax_kind = crate::ssmax::SsmaxKind::parse(&self.config.ssmax)
            .unwrap_or(crate::ssmax::SsmaxKind::None);
        self.encoder
            .load_from_with_ssmax(sd, &key("tf_icl"), ssmax_kind)?;

        if self.config.norm_first {
            self.params.ln_gamma = Some(sd.take_vec(&key("ln.weight"), d)?);
            if let Some(beta) = take_optional(key("ln.bias"), d)? {
                self.params.ln_beta = Some(beta);
            }
        }

        if let Some(enc) = self.params.y_one_hot.as_mut() {
            enc.load_from(sd, &key("y_encoder"))?;
        } else if let Some((w, bias)) = self.params.y_linear.as_mut() {
            *w = sd.take_array2(&key("y_encoder.weight"), d, 1)?;
            if let Some(loaded) = take_optional(key("y_encoder.bias"), d)? {
                *bias = Some(loaded);
            }
        }

        self.params.decoder_w1 = sd.take_array2(&key("decoder.0.weight"), hid, d)?;
        if let Some(b1) = take_optional(key("decoder.0.bias"), hid)? {
            self.params.decoder_b1 = Some(b1);
        }
        self.params.decoder_w2 = sd.take_array2(&key("decoder.2.weight"), out_dim, hid)?;
        if let Some(b2) = take_optional(key("decoder.2.bias"), out_dim)? {
            self.params.decoder_b2 = Some(b2);
        }
        Ok(())
    }

    /// Creates a head with zero-initialised parameters (see
    /// `ICLearningParams::zeros`) and a freshly constructed encoder stack.
    ///
    /// # Panics
    /// Panics if `EncoderStack::new` rejects the configuration; the message
    /// attributes this to `d_model` not being divisible by `nhead`.
    pub fn new(config: ICLearningConfig) -> Self {
        let params = ICLearningParams::zeros(&config);
        let encoder = EncoderStack::new(
            config.num_blocks,
            MabConfig {
                d_model: config.d_model,
                nhead: config.nhead,
                dim_feedforward: config.dim_feedforward,
                dropout: config.dropout,
                activation: config.activation,
                norm_first: config.norm_first,
                bias_free_ln: config.bias_free_ln,
            },
            None,
        )
        .expect("ICLearning: d_model must be divisible by nhead");
        Self {
            config,
            params,
            encoder,
        }
    }

    /// Runs the head on row representations `r` of shape `(batch, rows, d_model)`.
    ///
    /// Exactly one of `y_train_class` / `y_train_reg` must be supplied and it
    /// must match the task type of this instance. The labels are 2-D
    /// (`(batch, train_size)`); they are embedded to width `d_model` and added
    /// to the first `train_size` rows of `r`. The result goes through the
    /// encoder stack (which also receives `train_size`), the optional layer
    /// norm, and the decoder `linear -> GELU -> linear`.
    ///
    /// Returns the output of the second decoder layer.
    ///
    /// # Panics
    /// * if the last axis of `r` differs from `config.d_model`;
    /// * if the supplied labels do not match the task type (classification
    ///   vs regression), or both/neither label arguments are given;
    /// * if the embedded labels disagree with `r` on the batch or feature
    ///   axis, or if `train_size` exceeds the number of rows in `r`.
    pub fn forward(
        &self,
        r: ArrayView3<f32>,
        y_train_class: Option<ArrayView2<usize>>,
        y_train_reg: Option<ArrayView2<f32>>,
    ) -> Array3<f32> {
        let (b, t, d) = r.dim();
        assert_eq!(d, self.config.d_model);

        let ry_train: Array3<f32> = match (
            &self.params.y_one_hot,
            &self.params.y_linear,
            y_train_class,
            y_train_reg,
        ) {
            (Some(enc), _, Some(y_cls), None) => enc.forward(y_cls),
            (_, Some((w, bias)), None, Some(y_reg)) => {
                // Lift the scalar targets to a trailing feature axis of size 1
                // so they can be fed through the linear label encoder.
                let (rows, cols) = y_reg.dim();
                let y3 =
                    Array3::<f32>::from_shape_fn((rows, cols, 1), |(bi, ti, _)| y_reg[(bi, ti)]);
                linear3d(y3.view(), w.view(), bias.as_deref())
            }
            _ => panic!("y_train shape doesn't match task type (classification vs regression)"),
        };

        let (train_batch, train_size, train_dim) = ry_train.dim();
        // Without these checks a mismatch either silently drops label rows or
        // features, or surfaces as an opaque index-out-of-bounds panic.
        assert_eq!(
            train_batch, b,
            "y_train batch size {train_batch} != r batch size {b}"
        );
        assert_eq!(
            train_dim, d,
            "encoded y_train width {train_dim} != d_model {d}"
        );
        assert!(train_size <= t, "train_size {train_size} > total {t}");

        // Add the label embeddings onto the leading `train_size` rows only.
        let mut r_aug = r.to_owned();
        for (idx, delta) in ry_train.indexed_iter() {
            r_aug[idx] += *delta;
        }

        let src = self
            .encoder
            .forward_train_size(r_aug.view(), Some(train_size));
        let src_normed = match &self.params.ln_gamma {
            Some(g) => layer_norm_last(src.view(), g, self.params.ln_beta.as_deref(), 1e-5),
            None => src,
        };

        let mut h = linear3d(
            src_normed.view(),
            self.params.decoder_w1.view(),
            self.params.decoder_b1.as_deref(),
        );
        // erf-based GELU: 0.5 * x * (1 + erf(x / sqrt(2))).
        h.mapv_inplace(|x| 0.5 * x * (1.0 + erf_f32(x / std::f32::consts::SQRT_2)));

        linear3d(
            h.view(),
            self.params.decoder_w2.view(),
            self.params.decoder_b2.as_deref(),
        )
    }
}
/// Single-precision approximation of the error function.
///
/// Uses the five-term rational/exponential fit whose constants are those of
/// Abramowitz & Stegun formula 7.1.26: the fit is evaluated on `|x|` and the
/// sign is restored with `x.signum()`, so the result is exactly odd in `x`.
/// Because `signum(+0.0)` is `1.0`, `erf_f32(0.0)` is the fit's tiny residual
/// rather than a guaranteed exact zero. `NaN` inputs yield `NaN`.
fn erf_f32(x: f32) -> f32 {
    const P: f32 = 0.327_591_1;
    const LEADING: f32 = 1.061_405_4;
    // Remaining coefficients in Horner order (highest degree first).
    const REST: [f32; 4] = [-1.453_152_1, 1.421_413_8, -0.284_496_72, 0.254_829_6];

    let magnitude = x.abs();
    let t = (1.0 + P * magnitude).recip();
    let horner = REST.iter().fold(LEADING, |acc, &coeff| acc * t + coeff);
    let tail = (horner * t) * (-magnitude * magnitude).exp();
    x.signum() * (1.0 - tail)
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array;

    /// Tiny configuration (`d_model = 4`, one block, two heads) shared by the
    /// tests; `max_classes == 0` yields a regression head.
    fn small_cfg(max_classes: usize, out_dim: usize) -> ICLearningConfig {
        ICLearningConfig {
            max_classes,
            out_dim,
            d_model: 4,
            num_blocks: 1,
            nhead: 2,
            dim_feedforward: 8,
            dropout: 0.0,
            activation: Activation::Gelu,
            norm_first: true,
            bias_free_ln: false,
            ssmax: "none".into(),
            recompute: false,
        }
    }

    /// Classification: output keeps batch and row axes and ends in `out_dim`.
    #[test]
    fn classification_forward_output_shape() {
        let cfg = small_cfg(3, 3);
        let icl = ICLearning::new(cfg);
        let r = Array::from_shape_fn((2, 5, 4), |(b, t, d)| (b * 100 + t * 10 + d) as f32 * 0.001);
        let y_train: Array2<usize> =
            Array::from_shape_vec((2, 3), vec![0_usize, 1, 2, 2, 1, 0]).unwrap();
        let out = icl.forward(r.view(), Some(y_train.view()), None);
        assert_eq!(out.shape(), &[2, 5, 3]);
    }

    /// Regression: `out_dim` is independent of `max_classes == 0`.
    #[test]
    fn regression_forward_output_shape() {
        let cfg = small_cfg(0, 999);
        let icl = ICLearning::new(cfg);
        let r = Array::from_shape_fn((1, 4, 4), |(b, t, d)| (b * 16 + t * 4 + d) as f32 * 0.01);
        let y_train: Array2<f32> = Array::from_shape_vec((1, 2), vec![0.5_f32, 1.5]).unwrap();
        let out = icl.forward(r.view(), None, Some(y_train.view()));
        assert_eq!(out.shape(), &[1, 4, 999]);
    }

    /// With an all-zero decoder every output value must be (numerically) zero.
    #[test]
    fn zero_init_decoder_gives_zero_logits() {
        let cfg = small_cfg(4, 4);
        let icl = ICLearning::new(cfg);
        let r = Array::from_shape_fn((1, 3, 4), |(_, t, d)| (t * 4 + d) as f32 * 0.01);
        let y_train: Array2<usize> = Array::from_shape_vec((1, 2), vec![0_usize, 1]).unwrap();
        let out = icl.forward(r.view(), Some(y_train.view()), None);
        for v in out.iter() {
            assert!(v.abs() < 1e-5, "expected zero logit, got {}", v);
        }
    }

    /// Feeding regression targets to a classification head must hit the
    /// task-type check specifically, not just any panic.
    #[test]
    #[should_panic(expected = "y_train shape doesn't match task type")]
    fn rejects_mismatched_task_type() {
        let cfg = small_cfg(3, 3);
        let icl = ICLearning::new(cfg);
        let r = Array::from_shape_vec((1, 2, 4), vec![0.0_f32; 8]).unwrap();
        let y_reg = Array::from_shape_vec((1, 1), vec![1.0_f32]).unwrap();
        let _ = icl.forward(r.view(), None, Some(y_reg.view()));
    }
}