use crate::error::{Error, Result};
/// Dimensions and feature flags describing a Mamba2 block.
///
/// Build one with [`Mamba2Config::new`] plus the `with_*` methods, or with
/// [`Mamba2Config::from_universal`]. [`Mamba2Config::validate`] checks the
/// head and group settings.
#[derive(Debug, Clone)]
pub struct Mamba2Config {
/// Model width; `d_inner()` is `d_model * expand`.
pub d_model: usize,
/// State size; contributes `2 * ngroups * d_state` to `proj_dim()` and
/// `conv_channels()`. `new` sets it to 64.
pub d_state: usize,
/// Number of heads; `validate` rejects 0.
pub nheads: usize,
/// Per-head width. `new`, `with_nheads` and `with_expand` derive it from
/// `d_model * expand / nheads`; `from_universal` copies it from the source
/// config. `validate` rejects 0.
pub headdim: usize,
/// Multiplier applied to `d_model` to obtain `d_inner()`; `new` sets it to 2.
pub expand: usize,
/// Number of groups; `validate` accepts only 1 or `nheads`.
pub ngroups: usize,
/// Convolution kernel size (`new` sets 4; `from_universal` reads
/// `ssm.conv_kernel`).
pub d_conv: usize,
/// Chunk size; `new` sets it to 64.
/// NOTE(review): presumably the chunk length of the scan — not consumed in
/// this file, confirm against the layer implementation.
pub chunk_size: usize,
/// NOTE(review): presumably enables a bias term on dt — not consumed in this
/// file, confirm. Both constructors set it to `true`.
pub use_dt_bias: bool,
/// NOTE(review): presumably enables the `D` skip term — not consumed in this
/// file, confirm. Both constructors set it to `true`.
pub use_d: bool,
/// NOTE(review): presumably applies softplus to dt — not consumed in this
/// file, confirm. Both constructors set it to `true`.
pub dt_softplus: bool,
}
impl Mamba2Config {
pub fn new(d_model: usize) -> Self {
let expand = 2;
let nheads = (d_model / 64).max(1);
let headdim = d_model * expand / nheads;
Self {
d_model,
d_state: 64,
nheads,
headdim,
expand,
ngroups: 1,
d_conv: 4,
chunk_size: 64,
use_dt_bias: true,
use_d: true,
dt_softplus: true,
}
}
pub fn with_nheads(mut self, nheads: usize) -> Self {
self.nheads = nheads;
self.headdim = self.d_model * self.expand / nheads;
self
}
pub fn with_d_state(mut self, d_state: usize) -> Self {
self.d_state = d_state;
self
}
pub fn with_expand(mut self, expand: usize) -> Self {
self.expand = expand;
self.headdim = self.d_model * expand / self.nheads;
self
}
pub fn with_d_conv(mut self, d_conv: usize) -> Self {
self.d_conv = d_conv;
self
}
pub fn with_chunk_size(mut self, chunk_size: usize) -> Self {
self.chunk_size = chunk_size;
self
}
pub fn with_dt_softplus(mut self, dt_softplus: bool) -> Self {
self.dt_softplus = dt_softplus;
self
}
pub fn with_use_d(mut self, use_d: bool) -> Self {
self.use_d = use_d;
self
}
pub fn with_use_dt_bias(mut self, use_dt_bias: bool) -> Self {
self.use_dt_bias = use_dt_bias;
self
}
pub fn from_universal(
config: &crate::model::config::UniversalConfig,
) -> crate::error::Result<Self> {
let ssm = config
.ssm
.as_ref()
.ok_or_else(|| crate::error::Error::ModelError {
reason: "Mamba2 requires ssm config section".into(),
})?;
let expand = ssm.expand;
let nheads = ssm.num_heads;
let headdim = ssm.head_dim;
Ok(Self {
d_model: config.hidden_size,
d_state: ssm.state_size,
nheads,
headdim,
expand,
ngroups: ssm.n_groups,
d_conv: ssm.conv_kernel,
chunk_size: ssm.chunk_size,
use_dt_bias: true,
use_d: true,
dt_softplus: true,
})
}
pub fn d_inner(&self) -> usize {
self.d_model * self.expand
}
pub fn proj_dim(&self) -> usize {
2 * self.d_inner() + 2 * self.ngroups * self.d_state + self.nheads
}
pub fn conv_channels(&self) -> usize {
self.d_inner() + 2 * self.ngroups * self.d_state
}
pub fn validate(&self) -> Result<()> {
if self.ngroups != 1 && self.ngroups != self.nheads {
return Err(Error::ModelError {
reason: format!(
"ngroups must be 1 or nheads ({}), got {}",
self.nheads, self.ngroups
),
});
}
if self.nheads == 0 || self.headdim == 0 {
return Err(Error::ModelError {
reason: "nheads and headdim must be > 0".to_string(),
});
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` derives every setting from the model width alone.
    #[test]
    fn test_mamba2_config_defaults() {
        let cfg = Mamba2Config::new(256);

        // Dimensions.
        assert_eq!(cfg.d_model, 256);
        assert_eq!(cfg.d_state, 64);
        assert_eq!(cfg.expand, 2);
        assert_eq!(cfg.d_inner(), 512);
        assert_eq!(cfg.nheads, 4);
        assert_eq!(cfg.headdim, 128);

        // Feature flags all start enabled.
        assert!(cfg.dt_softplus);
        assert!(cfg.use_dt_bias);
        assert!(cfg.use_d);
    }

    /// Chained builders override the defaults and keep `headdim` in sync.
    #[test]
    fn test_mamba2_config_builders() {
        let base = Mamba2Config::new(256);
        let cfg = base.with_nheads(8).with_d_state(128).with_expand(3);

        assert_eq!(cfg.nheads, 8);
        assert_eq!(cfg.d_state, 128);
        assert_eq!(cfg.expand, 3);
        assert_eq!(cfg.d_inner(), 768);
        assert_eq!(cfg.headdim, 96);
    }

    /// `validate` accepts `ngroups` of 1 or `nheads` and rejects anything else.
    #[test]
    fn test_mamba2_config_validation() {
        // The untouched default configuration is valid.
        assert!(Mamba2Config::new(256).validate().is_ok());

        let mut cfg = Mamba2Config::new(256);
        for accepted in [1, cfg.nheads] {
            cfg.ngroups = accepted;
            assert!(cfg.validate().is_ok());
        }

        // Default nheads is 4, so 2 is neither 1 nor nheads.
        cfg.ngroups = 2;
        assert!(cfg.validate().is_err());
    }
}