use crate::mamba2::config::Mamba2Config;
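/// Errors returned by the Mamba2 forward paths.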
#[derive(Debug, thiserror::Error)]
pub enum Mamba2Error {
#[error("Empty input")]
EmptyInput,
#[error("Dimension mismatch: {0}")]
DimMismatch(String),
}
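/// Numerically stable softplus: `softplus(x) = ln(1 + e^x)`, used to keep the
/// discretization step `dt` positive.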
#[inline]
pub fn softplus(x: f64) -> f64 {
    if x > 20.0 {
        // softplus(x) ~= x for large x; avoids overflow in exp().
        x
    } else if x < -20.0 {
        // softplus(x) ~= exp(x) for very negative x.
        x.exp()
    } else {
        // ln_1p(exp(x)) is more accurate than (1.0 + x.exp()).ln().
        x.exp().ln_1p()
    }
}
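/// SiLU (sigmoid-weighted linear unit): `silu(x) = x * sigma(x)`, used for output gating.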
#[inline]
fn silu(x: f64) -> f64 {
x / (1.0 + (-x).exp())
}
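/// Dense matrix-vector product: `out[i] = sum_j weight[i][j] * x[j]`.
/// Returns `DimMismatch` if `x` does not match the weight's column count.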
fn mat_vec_mul(weight: &[Vec<f64>], x: &[f64]) -> Result<Vec<f64>, Mamba2Error> {
if weight.is_empty() {
return Ok(Vec::new());
}
let in_dim = weight[0].len();
if x.len() != in_dim {
return Err(Mamba2Error::DimMismatch(format!(
"mat_vec_mul: weight cols={} but x len={}",
in_dim,
x.len()
)));
}
let out: Vec<f64> = weight
.iter()
.map(|row| row.iter().zip(x.iter()).map(|(w, v)| w * v).sum())
.collect();
Ok(out)
}
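/// Root-mean-square layer norm: `y_i = w_i * x_i / sqrt(mean(x^2) + eps)`,
/// with a per-channel scale `w` initialized to 1.
///
/// A minimal usage sketch (`ignore`d because the crate path is elided):
///
/// ```ignore
/// let norm = Mamba2RmsNorm::new(4, 1e-5);
/// let y = norm.forward(&[1.0, 2.0, 3.0, 4.0]).unwrap();
/// assert_eq!(y.len(), 4);
/// ```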
pub struct Mamba2RmsNorm {
weight: Vec<f64>,
eps: f64,
}
impl Mamba2RmsNorm {
pub fn new(dim: usize, eps: f64) -> Self {
Self {
weight: vec![1.0; dim],
eps,
}
}
pub fn forward(&self, x: &[f64]) -> Result<Vec<f64>, Mamba2Error> {
if x.is_empty() {
return Err(Mamba2Error::EmptyInput);
}
if x.len() != self.weight.len() {
return Err(Mamba2Error::DimMismatch(format!(
"RmsNorm: weight dim={} but x len={}",
self.weight.len(),
x.len()
)));
}
let mean_sq: f64 = x.iter().map(|v| v * v).sum::<f64>() / x.len() as f64;
let rms = (mean_sq + self.eps).sqrt();
let out = x.iter().zip(self.weight.iter()).map(|(v, w)| v / rms * w).collect();
Ok(out)
}
pub fn dim(&self) -> usize {
self.weight.len()
}
}
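/// Simplified Mamba2 selective SSM layer.
///
/// `in_proj` produces, per token, the fused vector `[ z | x | B | C | dt ]`
/// with widths `inner_dim, inner_dim, nheads*d_state, nheads*d_state, nheads`.
/// `a_log` parameterizes the state decay `A = -exp(a_log)`, `d_bias` is the
/// per-head D skip connection, and `conv_weight` holds depthwise causal
/// convolution kernels of length `d_conv`.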
pub struct Mamba2SSM {
in_proj: Vec<Vec<f64>>,
out_proj: Vec<Vec<f64>>,
a_log: Vec<f64>,
d_bias: Vec<f64>,
dt_bias: Vec<f64>,
conv_weight: Vec<Vec<f64>>,
config: Mamba2Config,
}
impl Mamba2SSM {
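    /// Builds the layer with deterministic placeholder weights (sparse 0.02
    /// projections, moving-average conv kernels); real use would load trained
    /// parameters instead.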
pub fn new(config: &Mamba2Config) -> Self {
let inner_dim = config.inner_dim();
let nheads = config.nheads;
let d_state = config.d_state;
let d_model = config.d_model;
let d_conv = config.d_conv;
let in_proj_out = 2 * inner_dim + 2 * nheads * d_state + nheads;
let in_proj: Vec<Vec<f64>> = (0..in_proj_out)
.map(|i| {
let mut row = vec![0.0f64; d_model];
row[i % d_model] = 0.02;
row
})
.collect();
let out_proj: Vec<Vec<f64>> = (0..d_model)
.map(|i| {
let mut row = vec![0.0f64; inner_dim];
row[i % inner_dim] = 0.02;
row
})
.collect();
let a_log = vec![0.0f64; nheads];
let d_bias = vec![1.0f64; nheads];
let dt_bias = vec![0.0f64; nheads];
let conv_weight: Vec<Vec<f64>> =
(0..inner_dim).map(|_| vec![1.0 / d_conv as f64; d_conv]).collect();
Self {
in_proj,
out_proj,
a_log,
d_bias,
dt_bias,
conv_weight,
config: config.clone(),
}
}
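    /// Depthwise causal convolution over the sequence:
    /// `out[t][c] = sum_k w[c][k] * x[t - k][c]` for `k < d_conv`, zero-padded
    /// on the left. Note: the reference Mamba2 applies a SiLU activation after
    /// this convolution; this simplified version omits it.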
    fn causal_conv(&self, x: &[Vec<f64>], inner_dim: usize) -> Result<Vec<Vec<f64>>, Mamba2Error> {
        let seq_len = x.len();
        if seq_len == 0 {
            return Err(Mamba2Error::EmptyInput);
        }
        let channels = x[0].len();
        if channels != inner_dim {
            return Err(Mamba2Error::DimMismatch(format!(
                "causal_conv: expected {} channels but got {}",
                inner_dim, channels
            )));
        }
        let d_conv = self.config.d_conv;
let mut out = vec![vec![0.0f64; channels]; seq_len];
for t in 0..seq_len {
for c in 0..channels {
let w = &self.conv_weight[c];
let mut val = 0.0f64;
for k in 0..d_conv {
if t >= k {
val += w[k] * x[t - k][c];
}
}
out[t][c] = val;
}
}
Ok(out)
}
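    /// Runs the selective-scan recurrence per head and head dimension:
    ///
    ///   dt_t  = softplus(dt_raw + dt_bias)
    ///   A_bar = exp(-dt_t * exp(a_log))
    ///   h_t   = A_bar * h_{t-1} + B_t * x_t
    ///   y_t   = C_t . h_t + D * x_t
    ///
    /// followed by output gating `y_t * silu(z_t)` and the output projection.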
pub fn forward(&self, x: &[Vec<f64>]) -> Result<Vec<Vec<f64>>, Mamba2Error> {
let seq_len = x.len();
if seq_len == 0 {
return Err(Mamba2Error::EmptyInput);
}
let d_model = self.config.d_model;
if x[0].len() != d_model {
return Err(Mamba2Error::DimMismatch(format!(
"SSM forward: expected d_model={} but got {}",
d_model,
x[0].len()
)));
}
let inner_dim = self.config.inner_dim();
let nheads = self.config.nheads;
let d_state = self.config.d_state;
let headdim = self.config.headdim;
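        // Offsets into the fused in_proj output:
        // [ z (inner_dim) | x (inner_dim) | B (nheads*d_state) | C (nheads*d_state) | dt (nheads) ]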
let z_offset = 0usize;
let x_offset = inner_dim;
let b_offset = 2 * inner_dim;
let c_offset = b_offset + nheads * d_state;
let dt_offset = c_offset + nheads * d_state;
let mut proj_out: Vec<Vec<f64>> = Vec::with_capacity(seq_len);
for token in x.iter() {
proj_out.push(mat_vec_mul(&self.in_proj, token)?);
}
let z_seq: Vec<Vec<f64>> =
proj_out.iter().map(|p| p[z_offset..z_offset + inner_dim].to_vec()).collect();
let x_ssm_raw: Vec<Vec<f64>> =
proj_out.iter().map(|p| p[x_offset..x_offset + inner_dim].to_vec()).collect();
let b_seq: Vec<Vec<f64>> = proj_out
.iter()
.map(|p| p[b_offset..b_offset + nheads * d_state].to_vec())
.collect();
let c_seq: Vec<Vec<f64>> = proj_out
.iter()
.map(|p| p[c_offset..c_offset + nheads * d_state].to_vec())
.collect();
let dt_seq: Vec<Vec<f64>> =
proj_out.iter().map(|p| p[dt_offset..dt_offset + nheads].to_vec()).collect();
let x_ssm = self.causal_conv(&x_ssm_raw, inner_dim)?;
let mut h: Vec<Vec<Vec<f64>>> = vec![vec![vec![0.0f64; d_state]; headdim]; nheads];
let mut y_seq: Vec<Vec<f64>> = Vec::with_capacity(seq_len);
for t in 0..seq_len {
let dt_t = &dt_seq[t];
let b_t = &b_seq[t];
let c_t = &c_seq[t];
let x_t = &x_ssm[t];
let mut y_t = vec![0.0f64; inner_dim];
for head in 0..nheads {
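                // Discretize: dt > 0 via softplus; A = -exp(a_log) gives
                // A_bar = exp(dt * A) in (0, 1), so the state decays stably.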
let dt_val = softplus(dt_t[head] + self.dt_bias[head]);
let a_bar = (-dt_val * self.a_log[head].exp()).exp();
let b_head = &b_t[head * d_state..(head + 1) * d_state];
let c_head = &c_t[head * d_state..(head + 1) * d_state];
let x_head = &x_t[head * headdim..(head + 1) * headdim];
for hd in 0..headdim {
let x_val = x_head[hd];
                    // D skip connection contributes directly from the input.
                    let mut y_val = self.d_bias[head] * x_val;
                    for s in 0..d_state {
                        // State update: h_t = A_bar * h_{t-1} + B_t * x_t.
                        h[head][hd][s] = a_bar * h[head][hd][s] + x_val * b_head[s];
                        // Readout: y_t += C_t . h_t.
                        y_val += c_head[s] * h[head][hd][s];
                    }
y_t[head * headdim + hd] = y_val;
}
}
let z_t = &z_seq[t];
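            // Gate the SSM output with SiLU(z), as in the Mamba2 block design.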
let gated: Vec<f64> = y_t.iter().zip(z_t.iter()).map(|(y, z)| y * silu(*z)).collect();
y_seq.push(gated);
}
let mut result: Vec<Vec<f64>> = Vec::with_capacity(seq_len);
for gated in y_seq.iter() {
result.push(mat_vec_mul(&self.out_proj, gated)?);
}
Ok(result)
}
pub fn a_log(&self) -> &[f64] {
&self.a_log
}
pub fn d_bias(&self) -> &[f64] {
&self.d_bias
}
pub fn config(&self) -> &Mamba2Config {
&self.config
}
}
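/// Pre-norm residual block: `out = x + SSM(RMSNorm(x))`.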
pub struct Mamba2Block {
ssm: Mamba2SSM,
norm: Mamba2RmsNorm,
}
impl Mamba2Block {
pub fn new(config: &Mamba2Config) -> Self {
Self {
ssm: Mamba2SSM::new(config),
norm: Mamba2RmsNorm::new(config.d_model, config.rms_norm_eps),
}
}
pub fn forward(&self, x: &[Vec<f64>]) -> Result<Vec<Vec<f64>>, Mamba2Error> {
let seq_len = x.len();
if seq_len == 0 {
return Err(Mamba2Error::EmptyInput);
}
let mut normed: Vec<Vec<f64>> = Vec::with_capacity(seq_len);
for token in x.iter() {
normed.push(self.norm.forward(token)?);
}
let ssm_out = self.ssm.forward(&normed)?;
let out: Vec<Vec<f64>> = x
.iter()
.zip(ssm_out.iter())
.map(|(res, s)| res.iter().zip(s.iter()).map(|(a, b)| a + b).collect())
.collect();
Ok(out)
}
}
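/// Embedding table plus a stack of Mamba2 blocks and a final RMSNorm.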
pub struct Mamba2Model {
embed_tokens: Vec<Vec<f64>>,
layers: Vec<Mamba2Block>,
norm_f: Mamba2RmsNorm,
config: Mamba2Config,
}
impl Mamba2Model {
pub fn new(config: &Mamba2Config) -> Self {
let embed_tokens: Vec<Vec<f64>> = vec![vec![0.0f64; config.d_model]; config.vocab_size];
let layers: Vec<Mamba2Block> =
(0..config.n_layer).map(|_| Mamba2Block::new(config)).collect();
let norm_f = Mamba2RmsNorm::new(config.d_model, config.rms_norm_eps);
Self {
embed_tokens,
layers,
norm_f,
config: config.clone(),
}
}
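    /// Embeds `input_ids`, runs all layers, and applies the final norm.
    /// Out-of-vocabulary ids fall back to a zero embedding rather than erroring.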
pub fn forward(&self, input_ids: &[usize]) -> Result<Vec<Vec<f64>>, Mamba2Error> {
let seq_len = input_ids.len();
if seq_len == 0 {
return Err(Mamba2Error::EmptyInput);
}
let mut hidden: Vec<Vec<f64>> = input_ids
.iter()
.map(|&id| {
if id < self.embed_tokens.len() {
self.embed_tokens[id].clone()
} else {
vec![0.0f64; self.config.d_model]
}
})
.collect();
for layer in self.layers.iter() {
hidden = layer.forward(&hidden)?;
}
let mut normed: Vec<Vec<f64>> = Vec::with_capacity(seq_len);
for token in hidden.iter() {
normed.push(self.norm_f.forward(token)?);
}
Ok(normed)
}
pub fn num_layers(&self) -> usize {
self.layers.len()
}
}
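/// Causal-LM wrapper: backbone hidden states projected to vocabulary logits.
///
/// A minimal usage sketch (the `small_test` preset comes from `Mamba2Config`;
/// the surrounding crate path is elided, hence `ignore`):
///
/// ```ignore
/// let cfg = Mamba2Config::small_test();
/// let model = Mamba2ForCausalLM::new(&cfg);
/// let logits = model.forward(&[0, 1, 2]).unwrap();
/// assert_eq!(logits.len(), 3);
/// assert_eq!(logits[0].len(), cfg.vocab_size);
/// ```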
pub struct Mamba2ForCausalLM {
backbone: Mamba2Model,
lm_head: Vec<Vec<f64>>,
}
impl Mamba2ForCausalLM {
pub fn new(config: &Mamba2Config) -> Self {
let lm_head: Vec<Vec<f64>> = vec![vec![0.0f64; config.d_model]; config.vocab_size];
Self {
backbone: Mamba2Model::new(config),
lm_head,
}
}
pub fn forward(&self, input_ids: &[usize]) -> Result<Vec<Vec<f64>>, Mamba2Error> {
let hidden = self.backbone.forward(input_ids)?;
let logits: Result<Vec<Vec<f64>>, Mamba2Error> =
hidden.iter().map(|h| mat_vec_mul(&self.lm_head, h)).collect();
logits
}
pub fn config(&self) -> &Mamba2Config {
&self.backbone.config
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::mamba2::config::Mamba2Config;
#[test]
fn test_config_presets_valid() {
let cfg_2_7b = Mamba2Config::mamba2_2_7b();
assert!(cfg_2_7b.d_model > 0);
assert!(cfg_2_7b.n_layer > 0);
assert!(cfg_2_7b.d_state > 0);
let cfg_small = Mamba2Config::small_test();
assert!(cfg_small.d_model > 0);
assert!(cfg_small.n_layer > 0);
}
#[test]
fn test_headdim_consistency_2_7b() {
let cfg = Mamba2Config::mamba2_2_7b();
assert!(
cfg.validate(),
"headdim={} should equal inner_dim={} / nheads={}",
cfg.headdim,
cfg.inner_dim(),
cfg.nheads
);
assert_eq!(cfg.headdim * cfg.nheads, cfg.inner_dim());
}
#[test]
fn test_headdim_consistency_small() {
let cfg = Mamba2Config::small_test();
assert!(cfg.validate(), "small_test config should be valid");
assert_eq!(cfg.headdim * cfg.nheads, cfg.inner_dim());
}
#[test]
fn test_d_model_nheads_headdim_relation() {
let cfg = Mamba2Config::small_test();
let expected_headdim = cfg.d_model * cfg.expand / cfg.nheads;
assert_eq!(cfg.headdim, expected_headdim);
}
#[test]
fn test_rmsnorm_forward() {
let norm = Mamba2RmsNorm::new(4, 1e-5);
let x = vec![1.0, 2.0, 3.0, 4.0];
let out = norm.forward(&x).expect("rmsnorm should succeed");
assert_eq!(out.len(), 4);
let mean_sq: f64 = x.iter().map(|v| v * v).sum::<f64>() / 4.0;
let rms = (mean_sq + 1e-5).sqrt();
let expected: Vec<f64> = x.iter().map(|v| v / rms).collect();
for (got, exp) in out.iter().zip(expected.iter()) {
assert!((got - exp).abs() < 1e-9, "got={} exp={}", got, exp);
}
}
#[test]
fn test_rmsnorm_dimension_mismatch() {
let norm = Mamba2RmsNorm::new(4, 1e-5);
let x = vec![1.0, 2.0];
let result = norm.forward(&x);
        assert!(matches!(result, Err(Mamba2Error::DimMismatch(_))));
}
#[test]
    fn test_causal_conv_output_size() {
let cfg = Mamba2Config::small_test();
let ssm = Mamba2SSM::new(&cfg);
let seq_len = 8usize;
let inner_dim = cfg.inner_dim();
let x: Vec<Vec<f64>> = vec![vec![0.5f64; inner_dim]; seq_len];
let out = ssm.causal_conv(&x, inner_dim).expect("conv should work");
assert_eq!(out.len(), seq_len, "output seq_len should match input");
assert_eq!(
out[0].len(),
inner_dim,
"output channels should match inner_dim"
);
}
#[test]
fn test_ssm_forward_shape() {
let cfg = Mamba2Config::small_test();
let ssm = Mamba2SSM::new(&cfg);
let seq_len = 5usize;
let x: Vec<Vec<f64>> = vec![vec![0.1f64; cfg.d_model]; seq_len];
let out = ssm.forward(&x).expect("ssm forward should succeed");
assert_eq!(out.len(), seq_len);
assert_eq!(out[0].len(), cfg.d_model);
}
#[test]
fn test_recurrence_state_update() {
let cfg = Mamba2Config::small_test();
let ssm = Mamba2SSM::new(&cfg);
let seq_len = 4usize;
let x: Vec<Vec<f64>> =
(0..seq_len).map(|i| vec![(i + 1) as f64 * 0.1; cfg.d_model]).collect();
let out = ssm.forward(&x).expect("ssm forward");
assert_eq!(out.len(), seq_len);
assert_eq!(out[0].len(), cfg.d_model);
}
#[test]
fn test_d_skip_connection_nonzero() {
let cfg = Mamba2Config::small_test();
let ssm = Mamba2SSM::new(&cfg);
let all_nonzero = ssm.d_bias().iter().all(|&v| v != 0.0);
assert!(all_nonzero, "D skip connection should be non-zero");
}
#[test]
fn test_full_model_forward_small() {
let cfg = Mamba2Config::small_test();
let model = Mamba2ForCausalLM::new(&cfg);
let input_ids = vec![0usize, 1, 2, 3];
let logits = model.forward(&input_ids).expect("full model forward");
assert_eq!(logits.len(), 4, "one logit vector per token");
assert_eq!(logits[0].len(), cfg.vocab_size, "logit dim = vocab_size");
}
#[test]
fn test_lm_head_output_shape() {
let cfg = Mamba2Config::small_test();
let model = Mamba2ForCausalLM::new(&cfg);
let input_ids = vec![0usize, 5, 10];
let logits = model.forward(&input_ids).expect("lm_head forward");
assert_eq!(logits.len(), 3);
for row in logits.iter() {
assert_eq!(row.len(), cfg.vocab_size);
}
}
#[test]
fn test_softplus_function() {
let sp0 = softplus(0.0);
assert!((sp0 - std::f64::consts::LN_2).abs() < 1e-9);
assert!(softplus(-10.0) > 0.0);
assert!(softplus(10.0) > 0.0);
assert!((softplus(100.0) - 100.0).abs() < 0.01);
}
#[test]
fn test_discretization_a_bar_less_than_one() {
let dt = 0.0f64;
let dt_bias = 0.0f64;
let a_log = 0.0f64;
let a_bar = (-softplus(dt + dt_bias) * a_log.exp()).exp();
assert!(a_bar < 1.0, "A_bar={} should be < 1 for stability", a_bar);
assert!(a_bar > 0.0, "A_bar should be positive");
let a_bar_large = (-softplus(5.0) * a_log.exp()).exp();
assert!(a_bar_large < a_bar, "larger dt => smaller A_bar");
}
#[test]
fn test_empty_input_error() {
let cfg = Mamba2Config::small_test();
let model = Mamba2ForCausalLM::new(&cfg);
let result = model.forward(&[]);
        assert!(matches!(result, Err(Mamba2Error::EmptyInput)));
}
}