use ndarray::{Array2, Array3, Array4, ArrayView4};
use serde::{Deserialize, Serialize};
use crate::encoders::{EncoderStack, MabConfig};
use crate::layers::layer_norm_last;
use crate::rope::RopeConfig;
use crate::state_dict::{StateDict, StateDictError};
use crate::tabicl::Activation;
/// Hyper-parameters for the [`RowInteraction`] stage.
///
/// Most fields are forwarded unchanged to the encoder stack / RoPE
/// configuration built in [`RowInteraction::new`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RowInteractionConfig {
/// Width of each token embedding; used as the encoder `d_model` and checked
/// against the last axis of the input in `forward`.
pub embed_dim: usize,
/// Number of blocks passed to `EncoderStack::new`.
pub num_blocks: usize,
/// Number of attention heads; `embed_dim / nhead` is the per-head width.
pub nhead: usize,
/// Feed-forward width forwarded to the encoder blocks.
pub dim_feedforward: usize,
/// Number of leading token slots that are overwritten with learned CLS
/// tokens and later read back as the row representation.
pub num_cls: usize,
/// Base forwarded to `RopeConfig::base`.
pub rope_base: f32,
/// Flag forwarded to `RopeConfig::interleaved`.
pub rope_interleaved: bool,
/// Dropout value forwarded to the encoder blocks.
pub dropout: f32,
/// Activation forwarded to the encoder blocks.
pub activation: Activation,
/// Forwarded to the encoder blocks; additionally decides whether the output
/// layer-norm parameters exist and are loaded.
pub norm_first: bool,
/// Forwarded to the encoder blocks; when set, `RowInteractionParams::zeros`
/// creates no output layer-norm bias.
pub bias_free_ln: bool,
/// NOTE(review): not read anywhere in this file — presumably consumed by
/// other code (e.g. a recompute/checkpointing switch); confirm.
pub recompute: bool,
}
/// Default hyper-parameters.
///
/// NOTE(review): the unit test `defaults_match_python_signature` suggests these
/// values are meant to mirror a Python reference implementation — confirm
/// against that reference before changing any of them.
impl Default for RowInteractionConfig {
fn default() -> Self {
Self {
embed_dim: 128,
num_blocks: 3,
nhead: 8,
dim_feedforward: 256,
num_cls: 4,
rope_base: 100_000.0,
rope_interleaved: true,
dropout: 0.0,
activation: Activation::Gelu,
norm_first: true,
bias_free_ln: false,
recompute: false,
}
}
}
impl RowInteractionConfig {
    /// Per-head width: `embed_dim` divided (integer division) by `nhead`.
    ///
    /// # Panics
    /// Panics if `nhead` is zero.
    pub fn head_dim(&self) -> usize {
        let Self {
            embed_dim, nhead, ..
        } = self;
        *embed_dim / *nhead
    }

    /// Width of the flattened row representation: `num_cls * embed_dim`.
    pub fn repr_dim(&self) -> usize {
        let Self {
            num_cls, embed_dim, ..
        } = self;
        *num_cls * *embed_dim
    }
}
/// Parameters owned directly by [`RowInteraction`] (the encoder stack keeps
/// its own parameters).
#[derive(Debug, Clone)]
pub struct RowInteractionParams {
/// CLS token embeddings, created with shape `(num_cls, embed_dim)`.
pub cls_tokens: Array2<f32>,
/// Scale of the output layer norm; `None` disables the output layer norm in
/// `forward`.
pub out_ln_gamma: Option<Vec<f32>>,
/// Optional shift of the output layer norm; only consulted when
/// `out_ln_gamma` is `Some`.
pub out_ln_beta: Option<Vec<f32>>,
}
impl RowInteractionParams {
    /// Builds placeholder parameters sized for `cfg`.
    ///
    /// * `cls_tokens` is an all-zero `(num_cls, embed_dim)` matrix.
    /// * `out_ln_gamma` is all ones when `cfg.norm_first` is set, else `None`.
    /// * `out_ln_beta` is all zeros when `cfg.norm_first` is set and
    ///   `cfg.bias_free_ln` is not, else `None`.
    pub fn zeros(cfg: &RowInteractionConfig) -> Self {
        let width = cfg.embed_dim;
        let has_out_ln = cfg.norm_first;
        let has_out_ln_bias = has_out_ln && !cfg.bias_free_ln;

        let cls_tokens = Array2::<f32>::zeros((cfg.num_cls, width));
        let out_ln_gamma = has_out_ln.then(|| vec![1.0; width]);
        let out_ln_beta = has_out_ln_bias.then(|| vec![0.0; width]);

        Self {
            cls_tokens,
            out_ln_gamma,
            out_ln_beta,
        }
    }
}
/// Row-interaction stage: overwrites the leading `num_cls` token slots of each
/// row with CLS tokens, runs the encoder stack over the row's tokens, and
/// returns the (optionally layer-normalised) CLS outputs flattened per row.
#[derive(Debug, Clone)]
pub struct RowInteraction {
/// Hyper-parameters this instance was built with.
pub config: RowInteractionConfig,
/// CLS tokens and output layer-norm parameters.
pub params: RowInteractionParams,
/// Encoder stack applied to each row's token sequence.
pub encoder: EncoderStack,
}
impl RowInteraction {
pub fn load_from(&mut self, sd: &StateDict, prefix: &str) -> Result<(), StateDictError> {
let cls_key = format!("{prefix}.cls_tokens");
self.params.cls_tokens =
sd.take_array2(&cls_key, self.config.num_cls, self.config.embed_dim)?;
if self.config.norm_first {
self.params.out_ln_gamma =
Some(sd.take_vec(&format!("{prefix}.out_ln.weight"), self.config.embed_dim)?);
let beta_key = format!("{prefix}.out_ln.bias");
if sd.tensors.contains_key(&beta_key) {
self.params.out_ln_beta = Some(sd.take_vec(&beta_key, self.config.embed_dim)?);
}
}
self.encoder.load_from(sd, &format!("{prefix}.tf_row"))?;
Ok(())
}
pub fn new(config: RowInteractionConfig) -> Self {
let params = RowInteractionParams::zeros(&config);
let mab_cfg = MabConfig {
d_model: config.embed_dim,
nhead: config.nhead,
dim_feedforward: config.dim_feedforward,
dropout: config.dropout,
activation: config.activation,
norm_first: config.norm_first,
bias_free_ln: config.bias_free_ln,
};
let rope = Some(RopeConfig {
head_dim: config.head_dim(),
base: config.rope_base,
interleaved: config.rope_interleaved,
});
let encoder = EncoderStack::new(config.num_blocks, mab_cfg, rope)
.expect("RowInteraction: d_model must be divisible by nhead");
Self {
config,
params,
encoder,
}
}
pub fn forward(&self, embeddings: ArrayView4<f32>) -> Array3<f32> {
let (b, t, hc, e) = (
embeddings.shape()[0],
embeddings.shape()[1],
embeddings.shape()[2],
embeddings.shape()[3],
);
assert_eq!(e, self.config.embed_dim, "embed_dim mismatch");
assert!(hc >= self.config.num_cls, "fewer tokens than CLS slots");
let mut buf = embeddings.to_owned();
for bi in 0..b {
for ti in 0..t {
for ci in 0..self.config.num_cls {
for ei in 0..e {
buf[(bi, ti, ci, ei)] = self.params.cls_tokens[(ci, ei)];
}
}
}
}
let bt = b * t;
let mut flat = Array3::<f32>::zeros((bt, hc, e));
for bi in 0..b {
for ti in 0..t {
for hi in 0..hc {
for ei in 0..e {
flat[(bi * t + ti, hi, ei)] = buf[(bi, ti, hi, ei)];
}
}
}
}
let out_flat = self.encoder.forward(flat.view());
let mut cls_out = Array3::<f32>::zeros((bt, self.config.num_cls, e));
for bti in 0..bt {
for ci in 0..self.config.num_cls {
for ei in 0..e {
cls_out[(bti, ci, ei)] = out_flat[(bti, ci, ei)];
}
}
}
let cls_norm = match (&self.params.out_ln_gamma, &self.params.out_ln_beta) {
(Some(g), beta) => layer_norm_last(cls_out.view(), g, beta.as_deref(), 1e-5),
_ => cls_out, };
let repr_dim = self.config.repr_dim();
let mut out = Array3::<f32>::zeros((b, t, repr_dim));
for bi in 0..b {
for ti in 0..t {
for ci in 0..self.config.num_cls {
for ei in 0..e {
out[(bi, ti, ci * e + ei)] = cls_norm[(bi * t + ti, ci, ei)];
}
}
}
}
out
}
}
// NOTE(review): this function is never called; it appears to exist only to
// reference `Array4` so the top-of-file import is not reported as unused.
// Confirm, and consider removing both the import and this shim.
#[allow(dead_code)]
fn _silence(_a: Array4<f32>) {}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array;

    /// Single-block, pre-norm configuration shared by the forward tests; only
    /// the size-related fields vary between them.
    fn single_block_cfg(
        embed_dim: usize,
        nhead: usize,
        dim_feedforward: usize,
        num_cls: usize,
    ) -> RowInteractionConfig {
        RowInteractionConfig {
            embed_dim,
            num_blocks: 1,
            nhead,
            dim_feedforward,
            num_cls,
            rope_base: 100_000.0,
            rope_interleaved: false,
            dropout: 0.0,
            activation: Activation::Gelu,
            norm_first: true,
            bias_free_ln: false,
            recompute: false,
        }
    }

    /// The output is `(b, t, num_cls * embed_dim)` regardless of `hc`.
    #[test]
    fn forward_output_shape() {
        let model = RowInteraction::new(single_block_cfg(8, 2, 16, 4));
        let input = Array::from_shape_fn((2, 3, 6, 8), |(b, t, h, e)| {
            ((b * 100 + t * 10 + h) as f32) * 0.001 + (e as f32) * 0.0001
        });
        let result = model.forward(input.view());
        assert_eq!(result.shape(), &[2, 3, 4 * 8]);
    }

    /// With constant CLS rows, every output element is expected to be ~0 after
    /// the output layer norm.
    #[test]
    fn forward_cls_tokens_are_overwritten_then_propagated() {
        let mut model = RowInteraction::new(single_block_cfg(4, 2, 8, 2));
        model.params.cls_tokens.row_mut(0).fill(1.0);
        model.params.cls_tokens.row_mut(1).fill(-1.0);

        let input = Array::from_shape_fn((1, 2, 4, 4), |(_, _, h, e)| (h * 10 + e) as f32);
        let result = model.forward(input.view());

        // `result` has shape (1, 2, 8), so this visits every (b, t, k).
        for ((b, t, k), &value) in result.indexed_iter() {
            assert!(
                value.abs() < 1e-4,
                "constant CLS row should LN to zero: out[{b},{t},{k}] = {}",
                value
            );
        }
    }

    /// Pins the default hyper-parameters and the derived dimensions.
    #[test]
    fn defaults_match_python_signature() {
        let cfg = RowInteractionConfig::default();
        assert_eq!(cfg.num_cls, 4);
        assert_eq!(cfg.embed_dim, 128);
        assert_eq!(cfg.head_dim(), 16);
        assert_eq!(cfg.repr_dim(), 4 * 128);
        assert!(cfg.rope_interleaved);
        assert!(cfg.norm_first);
    }

    /// `norm_first` gates both layer-norm parameters; `bias_free_ln` removes
    /// only the bias.
    #[test]
    fn norm_first_controls_out_ln_params() {
        let with_bias = RowInteractionConfig {
            norm_first: true,
            bias_free_ln: false,
            ..RowInteractionConfig::default()
        };
        let params = RowInteractionParams::zeros(&with_bias);
        assert!(params.out_ln_gamma.is_some());
        assert!(params.out_ln_beta.is_some());

        let bias_free = RowInteractionConfig {
            bias_free_ln: true,
            ..with_bias
        };
        let params = RowInteractionParams::zeros(&bias_free);
        assert!(params.out_ln_gamma.is_some());
        assert!(params.out_ln_beta.is_none());

        let no_norm = RowInteractionConfig {
            norm_first: false,
            ..bias_free
        };
        let params = RowInteractionParams::zeros(&no_norm);
        assert!(params.out_ln_gamma.is_none());
        assert!(params.out_ln_beta.is_none());
    }
}