use thiserror::Error;
#[derive(Debug, Error)]
pub enum SsmaxError {
#[error("unknown ssmax kind: {0:?}")]
UnknownKind(String),
#[error("`head_dim` is required for `elementwise=true` SSMax variants")]
MissingHeadDim,
}
/// Selects the scalable-softmax (SSMax) variant used by an attention layer.
///
/// Each variant's doc starts with the textual name accepted by `SsmaxKind::parse`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum SsmaxKind {
    /// `"none"`: scaling disabled (the computed scale is all ones).
    None,
    /// `"ssmax"`: per-head `scales` multiplied by `ln(n)`.
    Ssmax,
    /// `"ssmax-mlp"`: one scale per head, produced by a small MLP applied to `ln(n)`.
    SsmaxMlp,
    /// `"ssmax-mlp-elementwise"`: like `SsmaxMlp`, but one scale per head *and* channel.
    SsmaxMlpElementwise,
    /// `"qassmax-mlp"`: the `SsmaxMlp` scale, modulated per query by `1 + tanh(query_mlp(q))`.
    QassmaxMlp,
    /// `"qassmax-mlp-elementwise"`: the per-channel form of `QassmaxMlp`.
    QassmaxMlpElementwise,
}
impl SsmaxKind {
pub fn parse(s: &str) -> Result<Self, SsmaxError> {
Ok(match s {
"none" => Self::None,
"ssmax" => Self::Ssmax,
"ssmax-mlp" => Self::SsmaxMlp,
"ssmax-mlp-elementwise" => Self::SsmaxMlpElementwise,
"qassmax-mlp" => Self::QassmaxMlp,
"qassmax-mlp-elementwise" => Self::QassmaxMlpElementwise,
other => return Err(SsmaxError::UnknownKind(other.to_string())),
})
}
pub fn from_bool(b: bool) -> Self {
if b {
Self::QassmaxMlpElementwise
} else {
Self::None
}
}
pub fn is_active(self) -> bool {
!matches!(self, Self::None)
}
}
/// Dimensions of an SSMax layer; normally built and validated by `SsmaxSpec::create`.
#[derive(Clone, Debug)]
pub struct SsmaxSpec {
    /// Variant in use.
    pub kind: SsmaxKind,
    /// Number of attention heads.
    pub num_heads: usize,
    /// Per-head width (`embed_dim / num_heads` when built by `SsmaxSpec::create`).
    pub head_dim: usize,
    /// Hidden width of the base and query MLPs.
    pub n_hidden: usize,
}
impl SsmaxSpec {
    /// Builds the spec for an attention layer with `num_heads` heads over `embed_dim` channels.
    ///
    /// Returns `Ok(None)` for [`SsmaxKind::None`] without looking at the dimensions.
    /// Otherwise `head_dim` is `embed_dim / num_heads` and the MLP hidden width is fixed at 64.
    ///
    /// # Errors
    ///
    /// [`SsmaxError::MissingHeadDim`] when `num_heads` or `embed_dim` is zero, or when
    /// `embed_dim` is not a multiple of `num_heads`.
    pub fn create(
        kind: SsmaxKind,
        num_heads: usize,
        embed_dim: usize,
    ) -> Result<Option<Self>, SsmaxError> {
        const HIDDEN_WIDTH: usize = 64;

        if matches!(kind, SsmaxKind::None) {
            return Ok(None);
        }
        // Short-circuits so the divisibility test never sees `num_heads == 0`.
        let head_dim_derivable =
            num_heads != 0 && embed_dim != 0 && embed_dim.is_multiple_of(num_heads);
        if !head_dim_derivable {
            return Err(SsmaxError::MissingHeadDim);
        }
        Ok(Some(Self {
            kind,
            num_heads,
            head_dim: embed_dim / num_heads,
            n_hidden: HIDDEN_WIDTH,
        }))
    }

    /// Width of the base-scale output (rows of `base_w2`). It is one entry per head, or one per
    /// head *and* channel (`num_heads * head_dim`) for the elementwise variants.
    pub fn base_out_dim(&self) -> usize {
        let per_head = if matches!(
            self.kind,
            SsmaxKind::SsmaxMlpElementwise | SsmaxKind::QassmaxMlpElementwise
        ) {
            self.head_dim
        } else {
            1
        };
        self.num_heads * per_head
    }

    /// Width of the query-MLP output. It is 1 for `QassmaxMlp`, `head_dim` for
    /// `QassmaxMlpElementwise`, and 0 (no query MLP) for every other kind.
    pub fn query_out_dim(&self) -> usize {
        match self.kind {
            SsmaxKind::QassmaxMlpElementwise => self.head_dim,
            SsmaxKind::QassmaxMlp => 1,
            SsmaxKind::None
            | SsmaxKind::Ssmax
            | SsmaxKind::SsmaxMlp
            | SsmaxKind::SsmaxMlpElementwise => 0,
        }
    }
}
/// Weights of an SSMax layer. Which fields are read depends on the [`SsmaxKind`] in use.
///
/// The shapes listed below are those allocated by `SsmaxParams::zeros` and requested by
/// `SsmaxParams::load_from`.
#[derive(Clone, Debug, Default)]
pub struct SsmaxParams {
    /// Per-head scale factors (`num_heads` entries), read by the plain `Ssmax` variant.
    pub scales: Vec<f32>,
    /// Base MLP first-layer weight, `(n_hidden, 1)`; its single input is `ln(n)`.
    pub base_w1: ndarray::Array2<f32>,
    /// Base MLP first-layer bias, `n_hidden` entries.
    pub base_b1: Vec<f32>,
    /// Base MLP output weight, `(base_out_dim, n_hidden)`.
    pub base_w2: ndarray::Array2<f32>,
    /// Base MLP output bias, `base_out_dim` entries.
    pub base_b2: Vec<f32>,
    /// Query MLP first-layer weight, `(n_hidden, head_dim)`; `zeros` allocates `(0, 0)` when
    /// the kind has no query MLP.
    pub query_w1: ndarray::Array2<f32>,
    /// Query MLP first-layer bias, `n_hidden` entries (empty without a query MLP).
    pub query_b1: Vec<f32>,
    /// Query MLP output weight, `(query_out_dim, n_hidden)`; `(0, 0)` without a query MLP.
    pub query_w2: ndarray::Array2<f32>,
    /// Query MLP output bias, `query_out_dim` entries (empty without a query MLP).
    pub query_b2: Vec<f32>,
}
impl SsmaxParams {
pub fn zeros(spec: &SsmaxSpec) -> Self {
let base_out = spec.base_out_dim();
let q_out = spec.query_out_dim();
Self {
scales: vec![1.0; spec.num_heads],
base_w1: ndarray::Array2::<f32>::zeros((spec.n_hidden, 1)),
base_b1: vec![0.0; spec.n_hidden],
base_w2: ndarray::Array2::<f32>::zeros((base_out, spec.n_hidden)),
base_b2: vec![0.0; base_out],
query_w1: if q_out > 0 {
ndarray::Array2::<f32>::zeros((spec.n_hidden, spec.head_dim))
} else {
ndarray::Array2::<f32>::zeros((0, 0))
},
query_b1: if q_out > 0 {
vec![0.0; spec.n_hidden]
} else {
Vec::new()
},
query_w2: if q_out > 0 {
ndarray::Array2::<f32>::zeros((q_out, spec.n_hidden))
} else {
ndarray::Array2::<f32>::zeros((0, 0))
},
query_b2: if q_out > 0 {
vec![0.0; q_out]
} else {
Vec::new()
},
}
}
pub fn load_from(
&mut self,
sd: &crate::state_dict::StateDict,
prefix: &str,
spec: &SsmaxSpec,
) -> Result<(), crate::state_dict::StateDictError> {
use crate::state_dict::StateDictError;
let p = format!("{prefix}.ssmax_layer");
let scales_key = format!("{p}.scales");
if sd.tensors.contains_key(&scales_key) {
self.scales = sd.take_vec(&scales_key, spec.num_heads)?;
return Ok(());
}
let base_out = spec.base_out_dim();
self.base_w1 = sd
.take_array2(&format!("{p}.base_mlp.0.weight"), spec.n_hidden, 1)
.map_err(|e| StateDictError::MissingKey(format!("{p}.base_mlp.0.weight: {e}")))?;
self.base_b1 = sd.take_vec(&format!("{p}.base_mlp.0.bias"), spec.n_hidden)?;
self.base_w2 =
sd.take_array2(&format!("{p}.base_mlp.2.weight"), base_out, spec.n_hidden)?;
self.base_b2 = sd.take_vec(&format!("{p}.base_mlp.2.bias"), base_out)?;
let q_out = spec.query_out_dim();
if q_out > 0 {
self.query_w1 = sd.take_array2(
&format!("{p}.query_mlp.0.weight"),
spec.n_hidden,
spec.head_dim,
)?;
self.query_b1 = sd.take_vec(&format!("{p}.query_mlp.0.bias"), spec.n_hidden)?;
self.query_w2 =
sd.take_array2(&format!("{p}.query_mlp.2.weight"), q_out, spec.n_hidden)?;
self.query_b2 = sd.take_vec(&format!("{p}.query_mlp.2.bias"), q_out)?;
}
Ok(())
}
}
/// Computes the multiplicative scale that the SSMax variant `spec.kind` produces for `n_src`
/// sources and queries `q`.
///
/// With `log_n = ln(max(n_src, 1))`:
///
/// * `None`: all ones, shape `(1, h, 1, 1)`.
/// * `Ssmax`: `scales[head] * log_n`, shape `(1, h, 1, 1)`.
/// * `SsmaxMlp` / `SsmaxMlpElementwise`: `Linear(1→H) → GELU → Linear(H→out)` applied to
///   `log_n`. The shape is `(1, h, 1, 1)`, or `(1, h, 1, d)` for the elementwise form.
/// * `QassmaxMlp` / `QassmaxMlpElementwise`: the base scale above multiplied, per query
///   token, by `1 + tanh(Linear → GELU → Linear (q[b, h, t, :]))`. The shape is
///   `(b, h, t, q_out)`, where `q_out` is the number of rows of `params.query_w2`.
///
/// `q` is indexed as `(b, h, t, d)`. For elementwise and query-aware kinds, `d` must agree
/// with the sizes the parameters were built for; a mismatch panics on out-of-bounds indexing.
/// Singleton axes are presumably broadcast by the caller against the attention logits
/// (TODO confirm).
pub fn compute_query_scale(
    spec: &SsmaxSpec,
    params: &SsmaxParams,
    q: ndarray::ArrayView4<f32>,
    n_src: usize,
) -> ndarray::Array4<f32> {
    let (b, h, t, d) = q.dim();
    let log_n = (n_src.max(1) as f32).ln();
    // Exact (erf-based) GELU, shared by the base and query MLPs.
    let gelu = |x: f32| 0.5 * x * (1.0 + erf_ss(x / std::f32::consts::SQRT_2));

    let base_scales: Vec<f32> = match spec.kind {
        SsmaxKind::None => return ndarray::Array4::<f32>::ones((1, h, 1, 1)),
        SsmaxKind::Ssmax => params.scales.iter().map(|s| s * log_n).collect(),
        SsmaxKind::SsmaxMlp
        | SsmaxKind::SsmaxMlpElementwise
        | SsmaxKind::QassmaxMlp
        | SsmaxKind::QassmaxMlpElementwise => {
            let n_hidden = params.base_w1.shape()[0];
            let hidden: Vec<f32> = (0..n_hidden)
                .map(|k| gelu(params.base_w1[(k, 0)] * log_n + params.base_b1[k]))
                .collect();
            (0..params.base_w2.shape()[0])
                .map(|o| {
                    hidden
                        .iter()
                        .enumerate()
                        .fold(params.base_b2[o], |acc, (k, hk)| {
                            acc + params.base_w2[(o, k)] * hk
                        })
                })
                .collect()
        }
    };

    let elementwise = matches!(
        spec.kind,
        SsmaxKind::SsmaxMlpElementwise | SsmaxKind::QassmaxMlpElementwise
    );
    let scale_d = if elementwise { d } else { 1 };
    // `base_scales` is head-major, with `scale_d` consecutive entries per head.
    let base = ndarray::Array4::<f32>::from_shape_fn((1, h, 1, scale_d), |(_, hi, _, di)| {
        base_scales[hi * scale_d + di]
    });

    if !matches!(
        spec.kind,
        SsmaxKind::QassmaxMlp | SsmaxKind::QassmaxMlpElementwise
    ) {
        return base;
    }

    let n_hidden = params.query_w1.shape()[0];
    let q_out = params.query_w2.shape()[0];
    let mut result = ndarray::Array4::<f32>::zeros((b, h, t, q_out));
    // One scratch buffer for the hidden activations, reused for every query token.
    let mut hidden = vec![0.0_f32; n_hidden];
    for bi in 0..b {
        for hi in 0..h {
            for ti in 0..t {
                for (k, hk) in hidden.iter_mut().enumerate() {
                    let mut s = params.query_b1[k];
                    for di in 0..d {
                        s += params.query_w1[(k, di)] * q[(bi, hi, ti, di)];
                    }
                    *hk = gelu(s);
                }
                for o in 0..q_out {
                    let s2 = hidden
                        .iter()
                        .enumerate()
                        .fold(params.query_b2[o], |acc, (k, hk)| {
                            acc + params.query_w2[(o, k)] * hk
                        });
                    // Non-elementwise variants have a single base column per head.
                    let base_v = base[(0, hi, 0, o.min(scale_d - 1))];
                    result[(bi, hi, ti, o)] = base_v * (1.0 + s2.tanh());
                }
            }
        }
    }
    result
}
/// Odd-symmetric rational approximation of the error function `erf(x)`.
///
/// The coefficients are those of Abramowitz & Stegun formula 7.1.26, whose stated absolute
/// error is about 1.5e-7 in exact arithmetic; `f32` rounding adds to that. The formula is
/// evaluated on `|x|` and the sign of `x` is reapplied, so `erf_ss(-x) == -erf_ss(x)`.
fn erf_ss(x: f32) -> f32 {
    const P: f32 = 0.327_591_1;
    const A1: f32 = 0.254_829_6;
    const A2: f32 = 0.284_496_72;
    const A3: f32 = 1.421_413_8;
    const A4: f32 = 1.453_152_1;
    const A5: f32 = 1.061_405_4;

    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);
    // Horner form of a1*t - a2*t^2 + a3*t^3 - a4*t^4 + a5*t^5.
    let poly = ((((A5 * t - A4) * t + A3) * t - A2) * t + A1) * t;
    let erf_of_magnitude = 1.0 - poly * (-magnitude * magnitude).exp();
    x.signum() * erf_of_magnitude
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_round_trip() {
        let cases = [
            ("none", SsmaxKind::None),
            ("ssmax", SsmaxKind::Ssmax),
            ("ssmax-mlp", SsmaxKind::SsmaxMlp),
            ("ssmax-mlp-elementwise", SsmaxKind::SsmaxMlpElementwise),
            ("qassmax-mlp", SsmaxKind::QassmaxMlp),
            ("qassmax-mlp-elementwise", SsmaxKind::QassmaxMlpElementwise),
        ];
        for (s, want) in cases {
            assert_eq!(SsmaxKind::parse(s).unwrap(), want);
        }
        assert!(SsmaxKind::parse("nope").is_err());
    }

    #[test]
    fn bool_shorthand_matches_python_default() {
        assert_eq!(SsmaxKind::from_bool(true), SsmaxKind::QassmaxMlpElementwise);
        assert_eq!(SsmaxKind::from_bool(false), SsmaxKind::None);
    }

    #[test]
    fn dims_match_python_layout() {
        let s = SsmaxSpec::create(SsmaxKind::QassmaxMlpElementwise, 8, 128)
            .unwrap()
            .unwrap();
        assert_eq!(s.head_dim, 16);
        assert_eq!(s.base_out_dim(), 8 * 16);
        assert_eq!(s.query_out_dim(), 16);
        let s = SsmaxSpec::create(SsmaxKind::Ssmax, 8, 128)
            .unwrap()
            .unwrap();
        assert_eq!(s.base_out_dim(), 8);
        assert_eq!(s.query_out_dim(), 0);
    }

    #[test]
    fn create_rejects_underivable_head_dim() {
        assert!(matches!(
            SsmaxSpec::create(SsmaxKind::Ssmax, 3, 128),
            Err(SsmaxError::MissingHeadDim)
        ));
        assert!(matches!(
            SsmaxSpec::create(SsmaxKind::Ssmax, 0, 128),
            Err(SsmaxError::MissingHeadDim)
        ));
        assert!(matches!(
            SsmaxSpec::create(SsmaxKind::SsmaxMlp, 4, 0),
            Err(SsmaxError::MissingHeadDim)
        ));
        // An inactive kind short-circuits before any dimension check.
        assert!(SsmaxSpec::create(SsmaxKind::None, 0, 0).unwrap().is_none());
    }

    #[test]
    fn zeros_shapes_follow_spec() {
        let s = SsmaxSpec::create(SsmaxKind::QassmaxMlpElementwise, 8, 128)
            .unwrap()
            .unwrap();
        let p = SsmaxParams::zeros(&s);
        assert_eq!(p.scales, vec![1.0_f32; 8]);
        assert_eq!(p.base_w1.shape(), &[64, 1]);
        assert_eq!(p.base_w2.shape(), &[128, 64]);
        assert_eq!(p.query_w1.shape(), &[64, 16]);
        assert_eq!(p.query_w2.shape(), &[16, 64]);
        assert_eq!(p.query_b2.len(), 16);

        let s = SsmaxSpec::create(SsmaxKind::Ssmax, 8, 128).unwrap().unwrap();
        let p = SsmaxParams::zeros(&s);
        assert_eq!(p.query_w1.shape(), &[0, 0]);
        assert!(p.query_b1.is_empty() && p.query_b2.is_empty());
    }

    #[test]
    fn inactive_kind_scale_is_ones() {
        let spec = SsmaxSpec { kind: SsmaxKind::None, num_heads: 2, head_dim: 4, n_hidden: 64 };
        let q = ndarray::Array4::<f32>::zeros((1, 2, 3, 4));
        let out = compute_query_scale(&spec, &SsmaxParams::default(), q.view(), 10);
        assert_eq!(out, ndarray::Array4::<f32>::ones((1, 2, 1, 1)));
    }

    #[test]
    fn plain_ssmax_scales_by_log_n() {
        let spec = SsmaxSpec::create(SsmaxKind::Ssmax, 2, 8).unwrap().unwrap();
        let mut params = SsmaxParams::zeros(&spec);
        params.scales = vec![2.0, 3.0];
        let q = ndarray::Array4::<f32>::zeros((1, 2, 3, 4));
        let out = compute_query_scale(&spec, &params, q.view(), 10);
        let log_n = 10.0_f32.ln();
        assert_eq!(out.shape(), &[1, 2, 1, 1]);
        assert!((out[(0, 0, 0, 0)] - 2.0 * log_n).abs() < 1e-5);
        assert!((out[(0, 1, 0, 0)] - 3.0 * log_n).abs() < 1e-5);
        // n_src = 0 is clamped to 1, and ln(1) = 0.
        let out = compute_query_scale(&spec, &params, q.view(), 0);
        assert!(out.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn elementwise_base_scale_is_head_major() {
        let spec = SsmaxSpec::create(SsmaxKind::SsmaxMlpElementwise, 2, 4)
            .unwrap()
            .unwrap();
        let mut params = SsmaxParams::zeros(&spec);
        // Zero weights: the base MLP outputs exactly its bias.
        params.base_b2 = vec![1.0, 2.0, 3.0, 4.0];
        let q = ndarray::Array4::<f32>::zeros((1, 2, 1, 2));
        let out = compute_query_scale(&spec, &params, q.view(), 5);
        assert_eq!(out.shape(), &[1, 2, 1, 2]);
        assert_eq!(out[(0, 0, 0, 0)], 1.0);
        assert_eq!(out[(0, 0, 0, 1)], 2.0);
        assert_eq!(out[(0, 1, 0, 0)], 3.0);
        assert_eq!(out[(0, 1, 0, 1)], 4.0);
    }

    #[test]
    fn query_aware_scale_multiplies_base_by_one_plus_tanh() {
        let spec = SsmaxSpec::create(SsmaxKind::QassmaxMlp, 2, 8).unwrap().unwrap();
        let mut params = SsmaxParams::zeros(&spec);
        // Zero weights: base = bias, query modulation = 1 + tanh(0) = 1.
        params.base_b2 = vec![1.0, 2.0];
        let q = ndarray::Array4::<f32>::from_elem((1, 2, 3, 4), 0.5);
        let out = compute_query_scale(&spec, &params, q.view(), 16);
        assert_eq!(out.shape(), &[1, 2, 3, 1]);
        for ti in 0..3 {
            assert_eq!(out[(0, 0, ti, 0)], 1.0);
            assert_eq!(out[(0, 1, ti, 0)], 2.0);
        }
    }

    #[test]
    fn erf_approximation_is_accurate_and_odd() {
        for (x, want) in [
            (0.0_f32, 0.0_f32),
            (0.5, 0.520_499_9),
            (1.0, 0.842_700_8),
            (2.0, 0.995_322_3),
        ] {
            assert!((erf_ss(x) - want).abs() < 1e-5, "erf({x})");
            assert_eq!(erf_ss(-x), -erf_ss(x));
        }
    }
}