/// Hyper-parameters shared by the selective state-space layers in this file.
#[derive(Debug, Clone)]
pub struct SSMConfig {
    /// Width of every input and output token vector.
    pub d_model: usize,
    /// Number of recurrent state entries kept per inner channel.
    pub d_state: usize,
    /// Number of taps in the causal convolution.
    pub d_conv: usize,
    /// Multiplier applied to `d_model` to obtain the inner width.
    pub expand_factor: usize,
    /// Number of rows in the step-size (`dt`) projection.
    pub dt_rank: usize,
}

impl SSMConfig {
    /// Builds a configuration for `d_model`-wide tokens.
    ///
    /// Defaults: `d_state = 16`, `d_conv = 4`, `expand_factor = 2` and
    /// `dt_rank = ceil(d_model / 16)`.
    pub fn new(d_model: usize) -> Self {
        Self {
            d_model,
            d_state: 16,
            d_conv: 4,
            expand_factor: 2,
            // Integer ceiling of d_model / 16.
            dt_rank: (d_model + 15) / 16,
        }
    }

    /// Inner channel count: `d_model * expand_factor`.
    pub fn d_inner(&self) -> usize {
        let width = self.d_model;
        width * self.expand_factor
    }

    /// Checks that every dimension is non-zero.
    ///
    /// # Errors
    ///
    /// Returns the message for the first zero field, checked in the order
    /// `d_model`, `d_state`, `d_conv`, `expand_factor`, `dt_rank`.
    pub fn validate(&self) -> Result<(), &'static str> {
        let checks = [
            (self.d_model, "d_model must be > 0"),
            (self.d_state, "d_state must be > 0"),
            (self.d_conv, "d_conv must be > 0"),
            (self.expand_factor, "expand_factor must be > 0"),
            (self.dt_rank, "dt_rank must be > 0"),
        ];
        match checks.iter().find(|(value, _)| *value == 0) {
            Some(&(_, message)) => Err(message),
            None => Ok(()),
        }
    }
}
/// Softplus activation, `ln(1 + e^x)`.
///
/// For `x > 20` the result equals `x` to within `f32` precision, so `x` is
/// returned directly (this also avoids overflowing `exp`). Otherwise the value
/// is computed with `ln_1p`, which keeps the result accurate and strictly
/// positive for very negative finite `x`, where `1.0 + e^x` would round to
/// exactly `1.0`.
#[inline]
pub fn softplus(x: f32) -> f32 {
    if x > 20.0 {
        x
    } else {
        x.exp().ln_1p()
    }
}
/// SiLU activation, evaluated as `x / (1 + e^(-x))`.
#[inline]
pub fn silu(x: f32) -> f32 {
    let denominator = 1.0 + f32::exp(-x);
    x / denominator
}
/// Root-mean-square normalisation of one vector.
///
/// Each output element is `x[i] / sqrt(mean(x^2) + eps) * weight[i]`.
///
/// # Panics
///
/// Panics if `x` and `weight` have different lengths.
pub fn rms_norm(x: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
    let n = x.len();
    assert_eq!(n, weight.len(), "rms_norm: x and weight must match in size");

    let sum_of_squares: f32 = x.iter().map(|v| v * v).sum();
    let scale = 1.0 / (sum_of_squares / n as f32 + eps).sqrt();

    let mut normed = Vec::with_capacity(n);
    for (value, gain) in x.iter().zip(weight) {
        normed.push(value * scale * gain);
    }
    normed
}
/// Multiplies a row-major `rows x cols` matrix by a `cols`-long vector.
///
/// # Panics
///
/// Panics if `matrix.len() != rows * cols` or `x.len() != cols`.
fn matvec(matrix: &[f32], x: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    assert_eq!(matrix.len(), rows * cols);
    assert_eq!(x.len(), cols);

    let mut product = Vec::with_capacity(rows);
    for r in 0..rows {
        let start = r * cols;
        let dot: f32 = matrix[start..start + cols]
            .iter()
            .zip(x)
            .map(|(m, v)| m * v)
            .sum();
        product.push(dot);
    }
    product
}
/// A single selective state-space layer.
///
/// Each token goes through: input projection into a gate half `z` and a signal
/// half `x`, a per-channel causal convolution followed by SiLU, a selective
/// scan over a per-channel recurrent state, gating by `silu(z)`, and an output
/// projection back to `d_model`.
///
/// All matrices are stored row-major in flat vectors.
pub struct SelectiveSSM {
    config: SSMConfig,
    /// `d_inner x d_state`; the state decay used is `-exp(-a_log)`.
    a_log: Vec<f32>,
    /// `d_inner x d_conv`; tap `k` weights the input from `k` tokens ago.
    conv_weight: Vec<f32>,
    /// One convolution bias per inner channel.
    conv_bias: Vec<f32>,
    /// `(2 * d_inner) x d_model`; first `d_inner` rows give `z`, the rest `x`.
    in_proj: Vec<f32>,
    /// `dt_rank x d_inner` step-size projection.
    w_dt: Vec<f32>,
    /// One step-size bias per inner channel.
    dt_bias: Vec<f32>,
    /// `d_state x d_inner` projection producing the input matrix `B`.
    w_b: Vec<f32>,
    /// `d_state x d_inner` projection producing the output matrix `C`.
    w_c: Vec<f32>,
    /// `d_model x d_inner` output projection.
    out_proj: Vec<f32>,
}

impl SelectiveSSM {
    /// Creates a layer with deterministic constant weights.
    ///
    /// `a_log` and both biases are zero, every convolution tap is `1 / d_conv`,
    /// and every projection entry is `1 / sqrt(d_model)`.
    ///
    /// # Panics
    ///
    /// Panics with `"invalid SSMConfig"` if `config.validate()` fails.
    pub fn new(config: SSMConfig) -> Self {
        config.validate().expect("invalid SSMConfig");
        let d_inner = config.d_inner();
        let d_state = config.d_state;
        let d_model = config.d_model;
        let d_conv = config.d_conv;
        let dt_rank = config.dt_rank;
        let scale = 1.0 / (d_model as f32).sqrt();
        Self {
            a_log: vec![0.0_f32; d_inner * d_state],
            conv_weight: vec![1.0 / d_conv as f32; d_inner * d_conv],
            conv_bias: vec![0.0; d_inner],
            in_proj: vec![scale; 2 * d_inner * d_model],
            w_dt: vec![scale; d_inner * dt_rank],
            dt_bias: vec![0.0; d_inner],
            w_b: vec![scale; d_state * d_inner],
            w_c: vec![scale; d_state * d_inner],
            out_proj: vec![scale; d_model * d_inner],
            config,
        }
    }

    /// The configuration this layer was built with.
    pub fn config(&self) -> &SSMConfig {
        &self.config
    }

    /// Runs the layer over a whole sequence, starting from a zero state.
    ///
    /// `input` is `seq_len` tokens of `d_model` values laid out back to back;
    /// the result has the same layout. Feeding the same tokens one at a time
    /// through [`SelectiveSSM::step`] with a fresh state yields the same values.
    ///
    /// # Panics
    ///
    /// Panics if `input.len()` is not a multiple of `d_model`.
    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
        let d_model = self.config.d_model;
        let seq_len = input.len() / d_model;
        assert_eq!(input.len(), seq_len * d_model, "input not divisible by d_model");
        let d_inner = self.config.d_inner();

        let mut z_seq = Vec::with_capacity(seq_len * d_inner);
        let mut xc_seq = Vec::with_capacity(seq_len * d_inner);
        for x_t in input.chunks_exact(d_model) {
            let projected = matvec(&self.in_proj, x_t, 2 * d_inner, d_model);
            z_seq.extend_from_slice(&projected[..d_inner]);
            xc_seq.extend_from_slice(&projected[d_inner..]);
        }

        let xc_conv = self.causal_conv(&xc_seq, seq_len, d_inner);
        let y_seq = self.selective_scan(&xc_conv, seq_len, d_inner);

        let mut output = Vec::with_capacity(seq_len * d_model);
        for (y_t, z_t) in y_seq.chunks_exact(d_inner).zip(z_seq.chunks_exact(d_inner)) {
            output.extend(self.gate_and_project(y_t, z_t));
        }
        output
    }

    /// Applies the causal convolution + SiLU to every token of `xc`
    /// (`seq_len` rows of `d_inner` values).
    fn causal_conv(&self, xc: &[f32], seq_len: usize, d_inner: usize) -> Vec<f32> {
        let d_conv = self.config.d_conv;
        let mut out = Vec::with_capacity(seq_len * d_inner);
        for t in 0..seq_len {
            // Only tokens 0..=t exist, so early positions see a shorter window.
            let rows = (t + 1).min(d_conv);
            let window = &xc[(t + 1 - rows) * d_inner..(t + 1) * d_inner];
            out.extend(self.conv_token(window, d_inner));
        }
        out
    }

    /// Convolution + SiLU for a single token.
    ///
    /// `window` holds between 1 and `d_conv` rows of `d_inner` values, oldest
    /// first, with the current token as the last row. Tap `k` is applied to the
    /// row `k` positions before the current one.
    fn conv_token(&self, window: &[f32], d_inner: usize) -> Vec<f32> {
        let d_conv = self.config.d_conv;
        let rows = (window.len() / d_inner).min(d_conv);
        let first_row = window.len() / d_inner - rows;
        (0..d_inner)
            .map(|i| {
                let mut acc = self.conv_bias[i];
                for k in 0..rows {
                    let w = self.conv_weight[i * d_conv + k];
                    acc += w * window[(first_row + rows - 1 - k) * d_inner + i];
                }
                silu(acc)
            })
            .collect()
    }

    /// Per-entry state decay, `-exp(-a_log)`.
    ///
    /// NOTE(review): reference Mamba implementations use `-exp(A_log)`; the
    /// sign inside the exponent here is kept as found. Confirm which
    /// parameterisation any loaded weights expect.
    fn state_matrix(&self) -> Vec<f32> {
        self.a_log.iter().map(|&log_a| -(-log_a).exp()).collect()
    }

    /// Runs the recurrence over `seq_len` rows of `x`, starting from zero state.
    fn selective_scan(&self, x: &[f32], seq_len: usize, d_inner: usize) -> Vec<f32> {
        // The decay matrix does not depend on the token; compute it once.
        let a = self.state_matrix();
        let mut h = vec![0.0_f32; d_inner * self.config.d_state];
        let mut y_seq = Vec::with_capacity(seq_len * d_inner);
        for t in 0..seq_len {
            let x_t = &x[t * d_inner..(t + 1) * d_inner];
            y_seq.extend(self.scan_token(&a, x_t, &mut h));
        }
        y_seq
    }

    /// Advances the recurrent state `h` by one token and returns that token's
    /// `d_inner` scan outputs.
    ///
    /// For channel `i` and state entry `j`:
    /// `h = exp(delta_i * a) * h + delta_i * B_j * x_i`, `y_i += C_j * h`.
    fn scan_token(&self, a: &[f32], x_t: &[f32], h: &mut [f32]) -> Vec<f32> {
        let d_inner = self.config.d_inner();
        let d_state = self.config.d_state;
        let dt_rank = self.config.dt_rank;

        let dt_pre = matvec(&self.w_dt, x_t, dt_rank, d_inner);
        let b_t = matvec(&self.w_b, x_t, d_state, d_inner);
        let c_t = matvec(&self.w_c, x_t, d_state, d_inner);

        let mut y_t = vec![0.0_f32; d_inner];
        for i in 0..d_inner {
            // The dt_rank projected values are reused cyclically across channels.
            let delta = softplus(dt_pre[i % dt_rank] + self.dt_bias[i]);
            for j in 0..d_state {
                let idx = i * d_state + j;
                let a_bar = (delta * a[idx]).exp();
                let b_bar = delta * b_t[j];
                h[idx] = a_bar * h[idx] + b_bar * x_t[i];
                y_t[i] += c_t[j] * h[idx];
            }
        }
        y_t
    }

    /// Gates the scan output with `silu(z)` and projects back to `d_model`.
    fn gate_and_project(&self, y: &[f32], z: &[f32]) -> Vec<f32> {
        let gated: Vec<f32> = y.iter().zip(z).map(|(&yi, &zi)| yi * silu(zi)).collect();
        matvec(&self.out_proj, &gated, self.config.d_model, self.config.d_inner())
    }

    /// Creates a zeroed recurrent state sized for this layer.
    pub fn init_state(&self) -> SSMState {
        let d_inner = self.config.d_inner();
        SSMState {
            h: vec![0.0; d_inner * self.config.d_state],
            d_inner,
            d_state: self.config.d_state,
            conv_window: Vec::with_capacity(self.config.d_conv * d_inner),
        }
    }

    /// Processes one token, updating `state` in place.
    ///
    /// Applies the same convolution, scan, gating and projection as
    /// [`SelectiveSSM::forward`]; the convolution history is carried in
    /// `state`, so stepping through a sequence matches the batched result.
    ///
    /// # Panics
    ///
    /// Panics if `token.len() != d_model`.
    pub fn step(&self, token: &[f32], state: &mut SSMState) -> Vec<f32> {
        let d_model = self.config.d_model;
        let d_inner = self.config.d_inner();
        assert_eq!(token.len(), d_model);

        let projected = matvec(&self.in_proj, token, 2 * d_inner, d_model);
        let (z, x_raw) = projected.split_at(d_inner);

        // Append the current pre-convolution row and keep at most d_conv rows.
        state.conv_window.extend_from_slice(x_raw);
        let max_len = self.config.d_conv * d_inner;
        if state.conv_window.len() > max_len {
            let excess = state.conv_window.len() - max_len;
            state.conv_window.drain(..excess);
        }
        let xc = self.conv_token(&state.conv_window, d_inner);

        let a = self.state_matrix();
        let y = self.scan_token(&a, &xc, &mut state.h);
        self.gate_and_project(&y, z)
    }
}

/// Recurrent state carried between calls to [`SelectiveSSM::step`].
#[derive(Debug, Clone)]
pub struct SSMState {
    /// Scan state, `d_inner x d_state`, row-major.
    pub h: Vec<f32>,
    d_inner: usize,
    d_state: usize,
    /// Up to `d_conv` most recent pre-convolution rows, oldest first.
    conv_window: Vec<f32>,
}

impl SSMState {
    /// Returns the state to its initial condition: zero scan state and no
    /// convolution history.
    pub fn reset(&mut self) {
        self.h.fill(0.0);
        self.conv_window.clear();
    }

    /// `(d_inner, d_state)` dimensions of the scan state.
    pub fn shape(&self) -> (usize, usize) {
        (self.d_inner, self.d_state)
    }
}
/// Pre-norm residual wrapper around a [`SelectiveSSM`].
///
/// Each token is RMS-normalised, passed through the SSM, and the SSM output is
/// added back onto the un-normalised input.
pub struct MambaBlock {
    ssm: SelectiveSSM,
    /// Per-feature gain for the RMS norm; initialised to all ones.
    norm_weight: Vec<f32>,
    /// Epsilon added to the mean square inside the RMS norm.
    norm_eps: f32,
}

impl MambaBlock {
    /// Builds a block with unit norm weights and `norm_eps = 1e-5`.
    ///
    /// Construction of the inner SSM is delegated to `SelectiveSSM::new`.
    pub fn new(config: SSMConfig) -> Self {
        let d = config.d_model;
        Self {
            ssm: SelectiveSSM::new(config),
            norm_weight: vec![1.0; d],
            norm_eps: 1e-5,
        }
    }

    /// Runs the block over a sequence of `d_model`-wide tokens laid out back
    /// to back and returns `input + ssm(norm(input))` in the same layout.
    ///
    /// # Panics
    ///
    /// Panics if `input.len()` is not a multiple of `d_model`. Without this
    /// check a trailing partial token would be silently dropped and the output
    /// would be shorter than the input.
    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
        let d = self.ssm.config().d_model;
        assert_eq!(input.len() % d, 0, "input not divisible by d_model");

        let mut normed = Vec::with_capacity(input.len());
        for tok in input.chunks_exact(d) {
            normed.extend(rms_norm(tok, &self.norm_weight, self.norm_eps));
        }
        let ssm_out = self.ssm.forward(&normed);
        input.iter().zip(&ssm_out).map(|(a, b)| a + b).collect()
    }

    /// Processes one token, updating `state` in place, and returns
    /// `token + ssm_step(norm(token))`.
    ///
    /// # Panics
    ///
    /// Panics (inside `rms_norm`) if `token.len()` differs from `d_model`.
    pub fn step(&self, token: &[f32], state: &mut SSMState) -> Vec<f32> {
        let normed = rms_norm(token, &self.norm_weight, self.norm_eps);
        let out = self.ssm.step(&normed, state);
        token.iter().zip(&out).map(|(a, b)| a + b).collect()
    }
}
/// The kind of layer occupying one slot of a hybrid stack.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LayerKind {
    /// A state-space (Mamba-style) layer.
    SSM,
    /// An attention layer.
    Attention,
}
/// Configuration of a stack that interleaves SSM and attention layers.
#[derive(Debug, Clone)]
pub struct HybridConfig {
    /// Configuration used for every SSM layer in the stack.
    pub ssm: SSMConfig,
    /// Total number of layers in the stack.
    pub num_layers: usize,
    /// Target fraction of attention layers; one attention layer is placed
    /// every `round(1 / hybrid_ratio)` layers.
    pub hybrid_ratio: f32,
}

impl HybridConfig {
    /// Returns the kind of each of the `num_layers` layers, in order.
    ///
    /// With period `p = max(round(1 / hybrid_ratio), 1)`, layer `i` is
    /// attention when `i % p == p - 1`, i.e. the last layer of every group of
    /// `p`. A ratio that is zero, negative or NaN yields no attention layers.
    pub fn layer_schedule(&self) -> Vec<LayerKind> {
        // The period does not depend on the layer index; compute it once.
        // `> 0.0` is false for NaN, so an invalid ratio disables attention.
        let attn_every = if self.hybrid_ratio > 0.0 {
            // The float-to-usize cast saturates, so a tiny ratio gives a
            // period that no layer index can reach.
            Some((1.0 / self.hybrid_ratio).round().max(1.0) as usize)
        } else {
            None
        };

        (0..self.num_layers)
            .map(|i| match attn_every {
                Some(period) if i % period == period - 1 => LayerKind::Attention,
                _ => LayerKind::SSM,
            })
            .collect()
    }
}
/// A stack of layers following a [`HybridConfig`] schedule.
///
/// Only the SSM layers are instantiated; attention slots are counted but carry
/// no parameters here.
pub struct HybridBlock {
    /// Kind of every layer, in execution order.
    schedule: Vec<LayerKind>,
    /// One block per `LayerKind::SSM` entry of `schedule`, in the same order.
    ssm_layers: Vec<MambaBlock>,
    /// Number of `LayerKind::Attention` entries in `schedule`.
    num_attention_layers: usize,
}

impl HybridBlock {
    /// Builds the stack described by `config`, creating one `MambaBlock` per
    /// SSM slot from a clone of `config.ssm`.
    pub fn new(config: HybridConfig) -> Self {
        let schedule = config.layer_schedule();
        let ssm_layers: Vec<MambaBlock> = schedule
            .iter()
            .filter(|kind| **kind == LayerKind::SSM)
            .map(|_| MambaBlock::new(config.ssm.clone()))
            .collect();
        let num_attention_layers = schedule.len() - ssm_layers.len();
        Self {
            schedule,
            ssm_layers,
            num_attention_layers,
        }
    }

    /// The layer kinds in execution order.
    pub fn schedule(&self) -> &[LayerKind] {
        &self.schedule
    }

    /// How many slots of the schedule are attention layers.
    pub fn attention_layer_count(&self) -> usize {
        self.num_attention_layers
    }

    /// Feeds `input` through the stack in schedule order.
    ///
    /// SSM slots apply the corresponding `MambaBlock::forward`; attention
    /// slots currently pass the activations through unchanged.
    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
        let mut next_ssm = 0;
        self.schedule
            .iter()
            .fold(input.to_vec(), |activations, kind| match kind {
                LayerKind::SSM => {
                    let updated = self.ssm_layers[next_ssm].forward(&activations);
                    next_ssm += 1;
                    updated
                }
                LayerKind::Attention => activations,
            })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `actual` lies strictly within `tol` of `expected`.
    fn close(actual: f32, expected: f32, tol: f32) -> bool {
        (actual - expected).abs() < tol
    }

    /// True when every element is neither NaN nor infinite.
    fn all_finite(values: &[f32]) -> bool {
        values.iter().all(|v| v.is_finite())
    }

    #[test]
    fn test_config_defaults() {
        let cfg = SSMConfig::new(64);
        assert_eq!(cfg.d_model, 64);
        assert_eq!(cfg.d_state, 16);
        assert_eq!(cfg.d_conv, 4);
        assert_eq!(cfg.expand_factor, 2);
        assert_eq!(cfg.d_inner(), 128);
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn test_config_validation_errors() {
        let mut cfg = SSMConfig::new(64);

        cfg.d_model = 0;
        assert!(cfg.validate().is_err());
        cfg.d_model = 64;

        cfg.d_state = 0;
        assert!(cfg.validate().is_err());
        cfg.d_state = 16;

        cfg.d_conv = 0;
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_softplus_values() {
        assert!(close(softplus(0.0), 0.6931, 1e-3));
        assert!(close(softplus(1.0), 1.3133, 1e-3));
        // Large inputs are passed through unchanged.
        assert!(close(softplus(25.0), 25.0, 1e-3));
        assert!(softplus(-25.0) < 1e-3);
    }

    #[test]
    fn test_silu_values() {
        assert!(close(silu(0.0), 0.0, 1e-6));
        assert!(close(silu(1.0), 0.7311, 1e-3));
        assert!(silu(-5.0) < 0.0);
    }

    #[test]
    fn test_rms_norm() {
        let values = vec![3.0, 4.0];
        let gains = vec![1.0, 1.0];
        let normed = rms_norm(&values, &gains, 1e-8);
        // mean(3^2, 4^2) = 12.5
        let rms = (12.5_f32).sqrt();
        assert!(close(normed[0], 3.0 / rms, 1e-4));
        assert!(close(normed[1], 4.0 / rms, 1e-4));
    }

    #[test]
    fn test_selective_scan_single_step() {
        let ssm = SelectiveSSM::new(SSMConfig::new(4));
        let input = vec![1.0; 4];
        let output = ssm.forward(&input);
        assert_eq!(output.len(), 4);
        assert!(all_finite(&output));
    }

    #[test]
    fn test_selective_scan_sequence() {
        let ssm = SelectiveSSM::new(SSMConfig::new(4));
        let seq_len = 5;
        let input = vec![0.5; seq_len * 4];
        let output = ssm.forward(&input);
        assert_eq!(output.len(), seq_len * 4);
        assert!(all_finite(&output));
    }

    #[test]
    fn test_state_recurrence_consistency() {
        let ssm = SelectiveSSM::new(SSMConfig::new(4));
        let token = vec![1.0; 4];

        let batch_out = ssm.forward(&token);

        let mut state = ssm.init_state();
        let step_out = ssm.step(&token, &mut state);

        assert_eq!(batch_out.len(), step_out.len());
        assert!(all_finite(&step_out));
    }

    #[test]
    fn test_mamba_block_forward() {
        let block = MambaBlock::new(SSMConfig::new(8));
        let input = vec![1.0; 3 * 8];
        let output = block.forward(&input);
        assert_eq!(output.len(), 3 * 8);
        assert!(all_finite(&output));
        assert!(output.iter().any(|v| *v != 0.0));
    }

    #[test]
    fn test_hybrid_routing() {
        let hybrid = HybridConfig {
            ssm: SSMConfig::new(4),
            num_layers: 8,
            hybrid_ratio: 0.25,
        };
        let schedule = hybrid.layer_schedule();
        assert_eq!(schedule.len(), 8);

        let attn_count = schedule
            .iter()
            .filter(|kind| **kind == LayerKind::Attention)
            .count();
        assert_eq!(attn_count, 2);
        assert_eq!(schedule[3], LayerKind::Attention);
        assert_eq!(schedule[7], LayerKind::Attention);
    }

    #[test]
    fn test_hybrid_block_forward() {
        let hybrid = HybridConfig {
            ssm: SSMConfig::new(4),
            num_layers: 4,
            hybrid_ratio: 0.25,
        };
        let block = HybridBlock::new(hybrid);
        assert_eq!(block.attention_layer_count(), 1);

        let input = vec![1.0; 2 * 4];
        let output = block.forward(&input);
        assert_eq!(output.len(), 2 * 4);
        assert!(all_finite(&output));
    }

    #[test]
    fn test_inference_step_updates_state() {
        let ssm = SelectiveSSM::new(SSMConfig::new(4));
        let mut state = ssm.init_state();
        assert!(state.h.iter().all(|v| *v == 0.0));

        let token = vec![1.0; 4];
        let _ = ssm.step(&token, &mut state);
        assert!(state.h.iter().any(|v| *v != 0.0));

        let h_after_first = state.h.clone();
        let _ = ssm.step(&token, &mut state);
        assert_ne!(state.h, h_after_first);
    }

    #[test]
    fn test_ssm_state_reset() {
        let ssm = SelectiveSSM::new(SSMConfig::new(4));
        let mut state = ssm.init_state();

        let _ = ssm.step(&[1.0; 4], &mut state);
        assert!(state.h.iter().any(|v| *v != 0.0));

        state.reset();
        assert!(state.h.iter().all(|v| *v == 0.0));
        // d_inner = 4 * 2, d_state = 16
        assert_eq!(state.shape(), (8, 16));
    }
}