use ndarray::{Array3, ArrayView2, ArrayView3, Axis};
use thiserror::Error;
use crate::attention::{AttentionConfig, AttentionParams};
use crate::layers::{layer_norm_last, linear3d};
use crate::rope::{RopeConfig, RopeTables};
use crate::state_dict::{StateDict, StateDictError};
use crate::tabicl::Activation;
/// Errors reported while constructing the encoder stacks defined in this file.
#[derive(Debug, Error)]
pub enum EncoderError {
/// `d_model` is not a multiple of `nhead`; returned by `EncoderStack::new`
/// and `SetTransformerStack::new`.
#[error("d_model ({d_model}) must be divisible by nhead ({nhead})")]
BadDims { d_model: usize, nhead: usize },
}
/// Hyper-parameters shared by every multi-head attention block (MAB) in a stack.
#[derive(Debug, Clone)]
pub struct MabConfig {
/// Feature width; used as the attention `embed_dim` and as the layer-norm length.
pub d_model: usize,
/// Number of attention heads; forwarded to `AttentionConfig::num_heads`.
pub nhead: usize,
/// Row count of `linear1` (and column count of `linear2`) in the feed-forward block.
pub dim_feedforward: usize,
/// Dropout value forwarded unchanged to `AttentionConfig`.
pub dropout: f32,
/// Non-linearity applied between the two feed-forward projections.
pub activation: Activation,
/// `true` selects the pre-norm branch of `mab_forward_qkv_masked`, `false` the post-norm branch.
pub norm_first: bool,
/// When `true`, `MabParams::zeros` creates the layer norms without a beta (bias) vector.
pub bias_free_ln: bool,
}
impl MabConfig {
    /// Width of a single attention head, i.e. `d_model / nhead` (integer division).
    ///
    /// # Panics
    /// Panics if `nhead` is zero.
    pub fn head_dim(&self) -> usize {
        let (width, heads) = (self.d_model, self.nhead);
        width / heads
    }

    /// Builds the attention configuration for this block; projections always carry a bias.
    pub fn attention_cfg(&self) -> AttentionConfig {
        let Self {
            d_model,
            nhead,
            dropout,
            ..
        } = *self;
        AttentionConfig {
            embed_dim: d_model,
            num_heads: nhead,
            dropout,
            bias: true,
        }
    }
}
/// Pairs an `SsmaxSpec` with the `SsmaxParams` loaded for it; handed to the
/// attention forward pass through `MabParams::ssmax`.
#[derive(Debug, Clone)]
pub struct MabSsmax {
/// Specification produced by `SsmaxSpec::create`.
pub spec: crate::ssmax::SsmaxSpec,
/// Parameters allocated from `spec` and filled from the state dict.
pub params: crate::ssmax::SsmaxParams,
}
/// Learnable tensors of one multi-head attention block (MAB).
#[derive(Debug, Clone)]
pub struct MabParams {
/// Scale vector of the first layer norm (length `d_model`).
pub norm1_gamma: Vec<f32>,
/// Optional shift vector of the first layer norm.
pub norm1_beta: Option<Vec<f32>>,
/// Scale vector of the second layer norm (length `d_model`).
pub norm2_gamma: Vec<f32>,
/// Optional shift vector of the second layer norm.
pub norm2_beta: Option<Vec<f32>>,
/// Input/output projection weights of the attention layer.
pub attn: AttentionParams,
/// Optional ssmax spec and parameters; `None` unless `load_from_with_ssmax` produced a spec.
pub ssmax: Option<MabSsmax>,
/// First feed-forward weight, shape `(dim_feedforward, d_model)`.
pub linear1: ndarray::Array2<f32>,
/// Optional bias of the first feed-forward projection (length `dim_feedforward`).
pub linear1_bias: Option<Vec<f32>>,
/// Second feed-forward weight, shape `(d_model, dim_feedforward)`.
pub linear2: ndarray::Array2<f32>,
/// Optional bias of the second feed-forward projection (length `d_model`).
pub linear2_bias: Option<Vec<f32>>,
}
impl MabParams {
    /// Loads the block like [`MabParams::load_from`] and then, when
    /// `SsmaxSpec::create` yields a spec for `ssmax_kind`, also loads the ssmax
    /// parameters stored under `{prefix}.attn`.
    ///
    /// A failing or empty `SsmaxSpec::create` result is treated as "no ssmax":
    /// `self.ssmax` is then left untouched.
    ///
    /// # Errors
    /// Propagates any `StateDictError` raised while reading tensors.
    pub fn load_from_with_ssmax(
        &mut self,
        sd: &StateDict,
        prefix: &str,
        cfg: &MabConfig,
        ssmax_kind: crate::ssmax::SsmaxKind,
    ) -> Result<(), StateDictError> {
        self.load_from(sd, prefix, cfg)?;

        let maybe_spec = crate::ssmax::SsmaxSpec::create(ssmax_kind, cfg.nhead, cfg.d_model)
            .ok()
            .flatten();
        let Some(spec) = maybe_spec else {
            return Ok(());
        };

        let attn_prefix = format!("{prefix}.attn");
        let mut params = crate::ssmax::SsmaxParams::zeros(&spec);
        params.load_from(sd, &attn_prefix, &spec)?;
        self.ssmax = Some(MabSsmax { spec, params });
        Ok(())
    }

    /// Fills the layer-norm, attention and feed-forward tensors from `sd`, reading
    /// keys rooted at `prefix` (`norm1`, `norm2`, `attn`, `linear1`, `linear2`).
    ///
    /// Weight tensors are mandatory. Bias tensors are optional: when a bias key is
    /// missing from `sd.tensors`, the current value of that field is kept.
    ///
    /// # Errors
    /// Propagates any `StateDictError` raised while reading tensors; fields
    /// assigned before the failure keep their newly loaded values.
    pub fn load_from(
        &mut self,
        sd: &StateDict,
        prefix: &str,
        cfg: &MabConfig,
    ) -> Result<(), StateDictError> {
        /// Reads `key` as a vector of `len` values if the state dict contains it.
        fn optional_vec(
            sd: &StateDict,
            key: String,
            len: usize,
        ) -> Result<Option<Vec<f32>>, StateDictError> {
            if sd.tensors.contains_key(&key) {
                Ok(Some(sd.take_vec(&key, len)?))
            } else {
                Ok(None)
            }
        }

        let (width, hidden) = (cfg.d_model, cfg.dim_feedforward);

        self.norm1_gamma = sd.take_vec(&format!("{prefix}.norm1.weight"), width)?;
        if let Some(beta) = optional_vec(sd, format!("{prefix}.norm1.bias"), width)? {
            self.norm1_beta = Some(beta);
        }

        self.norm2_gamma = sd.take_vec(&format!("{prefix}.norm2.weight"), width)?;
        if let Some(beta) = optional_vec(sd, format!("{prefix}.norm2.bias"), width)? {
            self.norm2_beta = Some(beta);
        }

        self.attn.load_from(sd, &format!("{prefix}.attn"), width)?;

        self.linear1 = sd.take_array2(&format!("{prefix}.linear1.weight"), hidden, width)?;
        if let Some(bias) = optional_vec(sd, format!("{prefix}.linear1.bias"), hidden)? {
            self.linear1_bias = Some(bias);
        }

        self.linear2 = sd.take_array2(&format!("{prefix}.linear2.weight"), width, hidden)?;
        if let Some(bias) = optional_vec(sd, format!("{prefix}.linear2.bias"), width)? {
            self.linear2_bias = Some(bias);
        }

        Ok(())
    }

    /// Creates placeholder parameters sized for `cfg`: layer-norm gammas are one,
    /// every other tensor is zero, and no ssmax is attached.
    ///
    /// Layer-norm betas are omitted when `cfg.bias_free_ln` is set; the attention
    /// and feed-forward biases are always present.
    pub fn zeros(cfg: &MabConfig) -> Self {
        let (width, hidden) = (cfg.d_model, cfg.dim_feedforward);
        let ln_beta = || (!cfg.bias_free_ln).then(|| vec![0.0_f32; width]);
        let zero_matrix =
            |rows: usize, cols: usize| ndarray::Array2::<f32>::zeros((rows, cols));

        let attn = AttentionParams {
            in_proj_weight: zero_matrix(3 * width, width),
            in_proj_bias: Some(vec![0.0; 3 * width]),
            out_proj_weight: zero_matrix(width, width),
            out_proj_bias: Some(vec![0.0; width]),
        };

        Self {
            norm1_gamma: vec![1.0; width],
            norm1_beta: ln_beta(),
            norm2_gamma: vec![1.0; width],
            norm2_beta: ln_beta(),
            attn,
            ssmax: None,
            linear1: zero_matrix(hidden, width),
            linear1_bias: Some(vec![0.0; hidden]),
            linear2: zero_matrix(width, hidden),
            linear2_bias: Some(vec![0.0; width]),
        }
    }
}
/// Applies the element-wise non-linearity `kind` to `x` in place.
///
/// * `Relu` replaces strictly negative values with `0.0`.
/// * `Gelu` uses the erf form `0.5 * x * (1 + erf(x / sqrt(2)))`.
/// * `Silu` computes `x / (1 + exp(-x))`.
fn apply_activation(x: &mut Array3<f32>, kind: Activation) {
    /// Overwrites every element of `x` with `f(element)`.
    fn map_in_place(x: &mut Array3<f32>, f: impl Fn(f32) -> f32) {
        x.iter_mut().for_each(|slot| *slot = f(*slot));
    }

    match kind {
        // Written as a comparison so NaN inputs pass through unchanged.
        Activation::Relu => map_in_place(x, |value| if value < 0.0 { 0.0 } else { value }),
        Activation::Gelu => map_in_place(x, |value| {
            0.5 * value * (1.0 + erf(value / std::f32::consts::SQRT_2))
        }),
        Activation::Silu => map_in_place(x, |value| value / (1.0 + (-value).exp())),
    }
}
/// Position-wise feed-forward block: `linear2(activation(linear1(x)))`.
///
/// The activation comes from `cfg.activation`; both projections use their bias
/// when one is present in `params`.
pub fn ff_block(x: ArrayView3<f32>, cfg: &MabConfig, params: &MabParams) -> Array3<f32> {
    let up_bias = params.linear1_bias.as_deref();
    let down_bias = params.linear2_bias.as_deref();

    let mut hidden = linear3d(x, params.linear1.view(), up_bias);
    apply_activation(&mut hidden, cfg.activation);
    linear3d(hidden.view(), params.linear2.view(), down_bias)
}
/// Layer normalisation over the last axis of `x` with a fixed epsilon of `1e-5`.
fn ln_3d(x: ArrayView3<f32>, gamma: &[f32], beta: Option<&[f32]>) -> Array3<f32> {
    const LAYER_NORM_EPS: f32 = 1e-5;
    layer_norm_last(x, gamma, beta, LAYER_NORM_EPS)
}
pub fn mab_forward(
x: ArrayView3<f32>,
cfg: &MabConfig,
params: &MabParams,
rope: Option<&RopeTables>,
) -> Array3<f32> {
mab_forward_qkv(x, x, x, cfg, params, rope)
}
/// MAB forward pass in which every row of `x` is a query but keys and values
/// are limited to the first `train_size` rows along axis 1.
///
/// `train_size` is clamped to the length of axis 1; `None` means plain
/// self-attention over all rows.
pub fn mab_forward_train_size(
    x: ArrayView3<f32>,
    cfg: &MabConfig,
    params: &MabParams,
    rope: Option<&RopeTables>,
    train_size: Option<usize>,
) -> Array3<f32> {
    let rows = x.shape()[1];
    // Keys and values share one view so the callee can recognise them as identical.
    let context = match train_size {
        Some(limit) => x.slice(ndarray::s![.., ..limit.min(rows), ..]),
        None => x.view(),
    };
    mab_forward_qkv_masked(x, context, context, cfg, params, rope, None)
}
pub fn mab_forward_qkv(
q: ArrayView3<f32>,
k: ArrayView3<f32>,
v: ArrayView3<f32>,
cfg: &MabConfig,
params: &MabParams,
rope: Option<&RopeTables>,
) -> Array3<f32> {
mab_forward_qkv_masked(q, k, v, cfg, params, rope, None)
}
/// Returns `true` when both views address exactly the same elements: same base
/// pointer, same shape and same strides.
///
/// Strides are part of the comparison because two views can share a base
/// pointer and a shape while stepping through memory differently.
fn views_alias(a: &ArrayView3<f32>, b: &ArrayView3<f32>) -> bool {
    a.shape() == b.shape() && a.strides() == b.strides() && std::ptr::eq(a.as_ptr(), b.as_ptr())
}

/// Full MAB forward pass with separate query/key/value tensors and an optional
/// attention mask.
///
/// * Pre-norm (`cfg.norm_first == true`):
///   `h = q + Attn(LN1(q), LN1(k), LN1(v))`, result `h + FF(LN2(h))`.
/// * Post-norm (`cfg.norm_first == false`):
///   `z = LN1(q + Attn(q, k, v))`, result `LN2(z + FF(z))`.
///
/// The residual connection always uses `q`. `rope`, `attn_mask` and
/// `params.ssmax` are forwarded unchanged to the attention kernel.
pub fn mab_forward_qkv_masked(
    q: ArrayView3<f32>,
    k: ArrayView3<f32>,
    v: ArrayView3<f32>,
    cfg: &MabConfig,
    params: &MabParams,
    rope: Option<&RopeTables>,
    attn_mask: Option<ArrayView2<f32>>,
) -> Array3<f32> {
    let attn_cfg = cfg.attention_cfg();

    if cfg.norm_first {
        let gamma1 = &params.norm1_gamma;
        let beta1 = params.norm1_beta.as_deref();

        let q_normed = ln_3d(q, gamma1, beta1);

        // Normalise each distinct input once; aliased inputs reuse the existing
        // result through a view instead of recomputing or cloning it.
        let k_owned = if views_alias(&k, &q) {
            None
        } else {
            Some(ln_3d(k, gamma1, beta1))
        };
        let k_normed = match k_owned.as_ref() {
            Some(fresh) => fresh.view(),
            None => q_normed.view(),
        };

        let v_is_k = views_alias(&v, &k);
        let v_is_q = views_alias(&v, &q);
        let v_owned = if v_is_k || v_is_q {
            None
        } else {
            Some(ln_3d(v, gamma1, beta1))
        };
        let v_normed = match v_owned.as_ref() {
            Some(fresh) => fresh.view(),
            None if v_is_k => k_normed,
            None => q_normed.view(),
        };

        let attn_out = crate::attention::multi_head_attention_forward_with_ssmax(
            q_normed.view(),
            k_normed,
            v_normed,
            &params.attn,
            &attn_cfg,
            rope,
            attn_mask,
            params.ssmax.as_ref(),
        );

        // The residual uses the un-normalised query.
        let mut after_attn = q.to_owned() + &attn_out;
        let ff_in = ln_3d(
            after_attn.view(),
            &params.norm2_gamma,
            params.norm2_beta.as_deref(),
        );
        let ff_out = ff_block(ff_in.view(), cfg, params);
        after_attn += &ff_out;
        after_attn
    } else {
        let attn_out = crate::attention::multi_head_attention_forward_with_ssmax(
            q,
            k,
            v,
            &params.attn,
            &attn_cfg,
            rope,
            attn_mask,
            params.ssmax.as_ref(),
        );
        let after_attn_pre = q.to_owned() + &attn_out;
        let z = ln_3d(
            after_attn_pre.view(),
            &params.norm1_gamma,
            params.norm1_beta.as_deref(),
        );
        let ff_out = ff_block(z.view(), cfg, params);
        let combined = &z + &ff_out;
        ln_3d(
            combined.view(),
            &params.norm2_gamma,
            params.norm2_beta.as_deref(),
        )
    }
}
/// A sequence of MAB blocks that all share one configuration.
#[derive(Debug, Clone)]
pub struct EncoderStack {
/// Configuration applied to every block.
pub mab_cfg: MabConfig,
/// Per-block parameters, applied in order by `forward`.
pub blocks: Vec<MabParams>,
/// Optional rotary-embedding configuration; `RopeTables` are built from it on each forward call.
pub rope: Option<RopeConfig>,
}
impl EncoderStack {
pub fn load_from(&mut self, sd: &StateDict, prefix: &str) -> Result<(), StateDictError> {
self.load_from_with_ssmax(sd, prefix, crate::ssmax::SsmaxKind::None)
}
pub fn load_from_with_ssmax(
&mut self,
sd: &StateDict,
prefix: &str,
ssmax_kind: crate::ssmax::SsmaxKind,
) -> Result<(), StateDictError> {
let cfg = self.mab_cfg.clone();
for (i, block) in self.blocks.iter_mut().enumerate() {
block.load_from_with_ssmax(sd, &format!("{prefix}.blocks.{i}"), &cfg, ssmax_kind)?;
}
Ok(())
}
pub fn new(
num_blocks: usize,
mab_cfg: MabConfig,
rope: Option<RopeConfig>,
) -> Result<Self, EncoderError> {
if !mab_cfg.d_model.is_multiple_of(mab_cfg.nhead) {
return Err(EncoderError::BadDims {
d_model: mab_cfg.d_model,
nhead: mab_cfg.nhead,
});
}
let blocks = (0..num_blocks)
.map(|_| MabParams::zeros(&mab_cfg))
.collect();
Ok(Self {
mab_cfg,
blocks,
rope,
})
}
pub fn forward(&self, x: ArrayView3<f32>) -> Array3<f32> {
self.forward_train_size(x, None)
}
pub fn forward_train_size(&self, x: ArrayView3<f32>, train_size: Option<usize>) -> Array3<f32> {
let seq_len = x.shape()[x.ndim() - 2];
let rope_tables = self.rope.map(|cfg| RopeTables::new(cfg, seq_len));
let mut cur = x.to_owned();
for block in &self.blocks {
cur = mab_forward_train_size(
cur.view(),
&self.mab_cfg,
block,
rope_tables.as_ref(),
train_size,
);
}
cur
}
}
/// Polynomial approximation of the error function for `f32`.
///
/// Evaluates `1 - poly(t) * t * exp(-x^2)` with `t = 1 / (1 + P * |x|)` and
/// restores the sign of `x` afterwards, so the result is odd in `x`.
fn erf(x: f32) -> f32 {
    const P: f32 = 0.3275911_f32;
    // Coefficients from highest to lowest order, consumed by Horner's scheme.
    const COEFFS: [f32; 5] = [
        1.061_405_4_f32,
        -1.453_152_1_f32,
        1.421_413_8_f32,
        -0.284_496_72_f32,
        0.254_829_6_f32,
    ];

    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);
    let poly = COEFFS[1..]
        .iter()
        .fold(COEFFS[0], |acc, &coeff| acc * t + coeff);
    let tail = poly * t * (-magnitude * magnitude).exp();
    x.signum() * (1.0 - tail)
}
/// Parameters of one induced set attention block (ISAB): two MABs plus the
/// learned inducing vectors.
#[derive(Debug, Clone)]
pub struct IsabParams {
/// First MAB: the inducing vectors are the queries, the source provides keys and values.
pub mab1: MabParams,
/// Second MAB: the source is the query, the output of `mab1` provides keys and values.
pub mab2: MabParams,
/// Inducing vectors, shape `(num_inds, d_model)`.
pub ind_vectors: ndarray::Array2<f32>,
}
impl IsabParams {
    /// Creates placeholder parameters: both MABs from `MabParams::zeros` and
    /// `num_inds` all-zero inducing vectors of width `cfg.d_model`.
    pub fn zeros(cfg: &MabConfig, num_inds: usize) -> Self {
        let ind_vectors = ndarray::Array2::<f32>::zeros((num_inds, cfg.d_model));
        let mab1 = MabParams::zeros(cfg);
        let mab2 = MabParams::zeros(cfg);
        Self {
            mab1,
            mab2,
            ind_vectors,
        }
    }

    /// Loads the block from `sd` without ssmax on either MAB.
    ///
    /// # Errors
    /// Propagates any `StateDictError` raised while reading tensors.
    pub fn load_from(
        &mut self,
        sd: &StateDict,
        prefix: &str,
        cfg: &MabConfig,
    ) -> Result<(), StateDictError> {
        self.load_from_with_ssmax(sd, prefix, cfg, crate::ssmax::SsmaxKind::None)
    }

    /// Loads `mab1` from `{prefix}.multihead_attn1` (with `mab1_ssmax`), `mab2`
    /// from `{prefix}.multihead_attn2` (never with ssmax) and the inducing
    /// vectors from `{prefix}.ind_vectors`.
    ///
    /// The row count requested for the inducing vectors is the current row count
    /// of `self.ind_vectors`.
    ///
    /// # Errors
    /// Propagates any `StateDictError` raised while reading tensors.
    pub fn load_from_with_ssmax(
        &mut self,
        sd: &StateDict,
        prefix: &str,
        cfg: &MabConfig,
        mab1_ssmax: crate::ssmax::SsmaxKind,
    ) -> Result<(), StateDictError> {
        let attn1_prefix = format!("{prefix}.multihead_attn1");
        self.mab1
            .load_from_with_ssmax(sd, &attn1_prefix, cfg, mab1_ssmax)?;

        let attn2_prefix = format!("{prefix}.multihead_attn2");
        self.mab2.load_from(sd, &attn2_prefix, cfg)?;

        let expected_rows = self.ind_vectors.shape()[0];
        let ind_key = format!("{prefix}.ind_vectors");
        self.ind_vectors = sd.take_array2(&ind_key, expected_rows, cfg.d_model)?;
        Ok(())
    }
}
/// ISAB forward pass in which the inducing vectors attend to every row of `src`.
pub fn isab_forward(src: ArrayView3<f32>, cfg: &MabConfig, params: &IsabParams) -> Array3<f32> {
    let whole_set: Option<usize> = None;
    isab_forward_train_size(src, cfg, params, whole_set)
}
pub fn isab_forward_train_size(
src: ArrayView3<f32>,
cfg: &MabConfig,
params: &IsabParams,
train_size: Option<usize>,
) -> Array3<f32> {
let (b, _n, d) = (src.shape()[0], src.shape()[1], src.shape()[2]);
let m = params.ind_vectors.shape()[0];
assert_eq!(params.ind_vectors.shape()[1], d);
let mut ind = Array3::<f32>::zeros((b, m, d));
for bi in 0..b {
for mi in 0..m {
for di in 0..d {
ind[(bi, mi, di)] = params.ind_vectors[(mi, di)];
}
}
}
let hidden = match train_size {
Some(k) => {
let k = k.min(src.shape()[1]);
let src_train = src.slice(ndarray::s![.., ..k, ..]);
mab_forward_qkv(ind.view(), src_train, src_train, cfg, ¶ms.mab1, None)
}
None => mab_forward_qkv(ind.view(), src, src, cfg, ¶ms.mab1, None),
};
mab_forward_qkv(src, hidden.view(), hidden.view(), cfg, ¶ms.mab2, None)
}
/// A sequence of ISAB blocks that all share one MAB configuration.
#[derive(Debug, Clone)]
pub struct SetTransformerStack {
/// Configuration applied to both MABs of every block.
pub mab_cfg: MabConfig,
/// Number of inducing vectors each block is created with.
pub num_inds: usize,
/// Per-block parameters, applied in order by `forward`.
pub blocks: Vec<IsabParams>,
}
impl SetTransformerStack {
pub fn load_from(&mut self, sd: &StateDict, prefix: &str) -> Result<(), StateDictError> {
self.load_from_with_ssmax(sd, prefix, crate::ssmax::SsmaxKind::None)
}
pub fn load_from_with_ssmax(
&mut self,
sd: &StateDict,
prefix: &str,
mab1_ssmax: crate::ssmax::SsmaxKind,
) -> Result<(), StateDictError> {
let cfg = self.mab_cfg.clone();
for (i, block) in self.blocks.iter_mut().enumerate() {
block.load_from_with_ssmax(sd, &format!("{prefix}.blocks.{i}"), &cfg, mab1_ssmax)?;
}
Ok(())
}
pub fn new(
num_blocks: usize,
mab_cfg: MabConfig,
num_inds: usize,
) -> Result<Self, EncoderError> {
if !mab_cfg.d_model.is_multiple_of(mab_cfg.nhead) {
return Err(EncoderError::BadDims {
d_model: mab_cfg.d_model,
nhead: mab_cfg.nhead,
});
}
let blocks = (0..num_blocks)
.map(|_| IsabParams::zeros(&mab_cfg, num_inds))
.collect();
Ok(Self {
mab_cfg,
num_inds,
blocks,
})
}
pub fn forward(&self, src: ArrayView3<f32>) -> Array3<f32> {
self.forward_train_size(src, None)
}
pub fn forward_train_size(
&self,
src: ArrayView3<f32>,
train_size: Option<usize>,
) -> Array3<f32> {
let mut cur = src.to_owned();
for block in &self.blocks {
cur = isab_forward_train_size(cur.view(), &self.mab_cfg, block, train_size);
}
cur
}
}
// No-op that mentions `Axis` in its signature; presumably kept so the `Axis`
// import at the top of the file counts as used. `dead_code` is allowed because
// nothing calls it.
#[allow(dead_code)]
fn _silence(_a: Axis) {}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array;

    /// Pre-norm GELU configuration without dropout, shared by every test below.
    fn test_cfg(d_model: usize, nhead: usize, dim_feedforward: usize) -> MabConfig {
        MabConfig {
            d_model,
            nhead,
            dim_feedforward,
            dropout: 0.0,
            activation: Activation::Gelu,
            norm_first: true,
            bias_free_ln: false,
        }
    }

    /// All-zero attention parameters of width `d`.
    // Currently unused; kept as a fixture for future attention tests.
    #[allow(dead_code)]
    fn ident_attn_params(d: usize) -> AttentionParams {
        let w = ndarray::Array2::<f32>::zeros((3 * d, d));
        AttentionParams {
            in_proj_weight: w,
            in_proj_bias: Some(vec![0.0; 3 * d]),
            out_proj_weight: ndarray::Array2::<f32>::zeros((d, d)),
            out_proj_bias: Some(vec![0.0; d]),
        }
    }

    /// With zero attention and feed-forward weights both residual branches add
    /// nothing, so a pre-norm block must return its input.
    #[test]
    fn pre_norm_with_zero_attn_and_ff_is_identity() {
        let d = 4;
        let cfg = test_cfg(d, 2, 8);
        let params = MabParams::zeros(&cfg);
        let x = Array::from_shape_fn((2, 3, d), |(b, t, k)| {
            (b as f32) * 0.1 + (t as f32) * 0.01 + (k as f32) * 0.001
        });
        let y = mab_forward(x.view(), &cfg, &params, None);
        assert_eq!(y.shape(), x.shape());
        for (a, b) in x.iter().zip(y.iter()) {
            assert!((a - b).abs() < 1e-5, "{} vs {}", a, b);
        }
    }

    /// A stack of zero-initialised blocks is the identity as well.
    #[test]
    fn encoder_stack_with_zero_blocks_is_identity() {
        let d = 4;
        let stack = EncoderStack::new(3, test_cfg(d, 2, 8), None).unwrap();
        let x = Array::from_shape_fn((1, 5, d), |(_, t, k)| (t as f32) * 0.1 + (k as f32) * 0.01);
        let y = stack.forward(x.view());
        for (a, b) in x.iter().zip(y.iter()) {
            assert!((a - b).abs() < 1e-4, "{} vs {}", a, b);
        }
    }

    /// `d_model` that is not a multiple of `nhead` is rejected at construction.
    #[test]
    fn stack_rejects_bad_dims() {
        let err = EncoderStack::new(1, test_cfg(5, 2, 8), None).unwrap_err();
        assert!(matches!(err, EncoderError::BadDims { .. }));
    }

    /// Enabling rope must neither change the output shape nor break the
    /// zero-weight identity.
    #[test]
    fn stack_with_rope_runs() {
        let d = 8;
        let rope = RopeConfig {
            head_dim: 4,
            base: 100_000.0,
            interleaved: false,
        };
        let stack = EncoderStack::new(2, test_cfg(d, 2, 16), Some(rope)).unwrap();
        let x = Array::from_shape_fn((1, 6, d), |(_, t, k)| (t * d + k) as f32 * 0.01);
        let y = stack.forward(x.view());
        assert_eq!(y.shape(), x.shape());
        for (a, b) in x.iter().zip(y.iter()) {
            assert!((a - b).abs() < 1e-4, "{} vs {}", a, b);
        }
    }

    /// Spot-checks the erf-based GELU against reference values at 0, 1 and -1.
    #[test]
    fn gelu_matches_pytorch_at_known_points() {
        let mut x = Array::from_shape_vec((1, 1, 3), vec![0.0_f32, 1.0, -1.0]).unwrap();
        apply_activation(&mut x, Activation::Gelu);
        assert!(x[(0, 0, 0)].abs() < 1e-5);
        assert!((x[(0, 0, 1)] - 0.8413).abs() < 1e-3);
        assert!((x[(0, 0, 2)] + 0.1587).abs() < 1e-3);
    }

    /// ReLU keeps non-negative values and zeroes negative ones.
    #[test]
    fn relu_clamps_negatives() {
        let mut x = Array::from_shape_vec((1, 1, 3), vec![1.0_f32, 0.0, -1.0]).unwrap();
        apply_activation(&mut x, Activation::Relu);
        assert_eq!(x[(0, 0, 0)], 1.0);
        assert_eq!(x[(0, 0, 1)], 0.0);
        assert_eq!(x[(0, 0, 2)], 0.0);
    }

    /// Zero-initialised ISAB blocks return their input unchanged.
    #[test]
    fn isab_zero_init_preserves_input() {
        let d = 4;
        let stack = SetTransformerStack::new(2, test_cfg(d, 2, 8), 3).unwrap();
        let src = Array::from_shape_fn((2, 5, d), |(b, n, k)| (b * 100 + n * 10 + k) as f32 * 0.01);
        let out = stack.forward(src.view());
        for (a, b) in src.iter().zip(out.iter()) {
            assert!((a - b).abs() < 1e-4, "ISAB zero-init drift: {} vs {}", a, b);
        }
    }

    /// The ISAB output has the shape of the source, not of the inducing vectors.
    #[test]
    fn isab_output_shape_matches_input() {
        let d = 8;
        let stack = SetTransformerStack::new(1, test_cfg(d, 2, 16), 4).unwrap();
        let src = Array::from_shape_fn((1, 10, d), |(_, n, k)| (n * d + k) as f32 * 0.001);
        let out = stack.forward(src.view());
        assert_eq!(out.shape(), src.shape());
    }

    /// Restricting keys and values to the first two rows must change the output
    /// of a later row compared with full self-attention.
    #[test]
    fn train_size_mask_blocks_test_to_test_attention() {
        let d = 4;
        let cfg = test_cfg(d, 1, 8);
        let mut params = MabParams::zeros(&cfg);
        // Identity projections for Q, K, V and the output.
        for i in 0..d {
            params.attn.in_proj_weight[(i, i)] = 1.0;
            params.attn.in_proj_weight[(d + i, i)] = 1.0;
            params.attn.in_proj_weight[(2 * d + i, i)] = 1.0;
            params.attn.out_proj_weight[(i, i)] = 1.0;
        }
        let raw = [
            [1.0_f32, 2.0, 3.0, 4.0],
            [4.0, 3.0, 2.0, 1.0],
            [1.0, 5.0, 1.0, 5.0],
            [-2.0, 1.0, 4.0, 0.5],
        ];
        let x = Array::from_shape_fn((1, 4, d), |(_, t, e)| raw[t][e]);
        let y_unmasked = mab_forward(x.view(), &cfg, &params, None);
        let y_masked = mab_forward_train_size(x.view(), &cfg, &params, None, Some(2));
        let row3_unmasked: Vec<f32> = (0..d).map(|e| y_unmasked[(0, 3, e)]).collect();
        let row3_masked: Vec<f32> = (0..d).map(|e| y_masked[(0, 3, e)]).collect();
        let differs = row3_unmasked
            .iter()
            .zip(row3_masked.iter())
            .any(|(a, b)| (a - b).abs() > 1e-3);
        assert!(differs, "train-size masking did not affect test row output");
    }

    /// With identity V and output projections, a value tensor distinct from the
    /// query must show up in the block output.
    #[test]
    fn mab_cross_attention_uses_separate_kv() {
        let d = 4;
        let cfg = test_cfg(d, 2, 8);
        let mut params = MabParams::zeros(&cfg);
        for i in 0..d {
            params.attn.in_proj_weight[(2 * d + i, i)] = 1.0;
            params.attn.out_proj_weight[(i, i)] = 1.0;
        }
        let q = Array::from_shape_vec((1, 1, d), vec![5.0_f32, 6.0, 7.0, 8.0]).unwrap();
        let k = Array::from_shape_vec((1, 1, d), vec![0.0_f32; 4]).unwrap();
        let v = Array::from_shape_vec((1, 1, d), vec![1.0_f32, 2.0, 3.0, 4.0]).unwrap();
        let out = mab_forward_qkv(q.view(), k.view(), v.view(), &cfg, &params, None);
        let differs = (0..d).any(|i| (out[(0, 0, i)] - q[(0, 0, i)]).abs() > 1e-3);
        assert!(differs, "expected V to contribute to output via attn");
    }
}