use std::fmt;
use crate::common::PlasticityConfig;
use crate::error::ConfigError;
/// Selects which Mamba-style state-space variant a [`MambaConfig`] describes.
///
/// The builder defaults to [`MambaVersion::V1`]. Variant-specific constraints are
/// enforced by [`MambaConfigBuilder::build`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum MambaVersion {
    /// Baseline variant; built configs store `n_groups = 1` and `block_size = 1`.
    V1,
    /// Grouped variant; the built `n_groups` divides `d_in`.
    V3,
    /// Grouped variant carrying a `use_bcnorm` flag
    /// (NOTE(review): presumably toggles B/C normalisation — confirm).
    V3Exp { use_bcnorm: bool },
    /// Grouped variant with a rank; a built config's `rank` lies in `1..=16`.
    V3Mimo { rank: usize, use_bcnorm: bool },
    /// Block-diagonal variant; a built config's `block_size` lies in `2..=16`
    /// and divides `d_in`.
    BlockDiagonal { block_size: usize },
}
/// Configuration for a streaming Mamba model, normally produced by
/// [`MambaConfigBuilder::build`], which validates and resolves every field.
#[derive(Debug, Clone)]
pub struct MambaConfig {
    /// Input feature dimension (the builder requires it and rejects 0).
    pub d_in: usize,
    /// State dimension (the builder rejects 0; builder default 32).
    pub n_state: usize,
    /// Forgetting factor; the builder only accepts values in `(0, 1]`
    /// (builder default 0.998).
    pub forgetting_factor: f64,
    /// RLS delta; the builder only accepts positive values (builder default 100).
    pub delta_rls: f64,
    /// Seed value (builder default 42).
    /// NOTE(review): presumably seeds random initialisation — confirm.
    pub seed: u64,
    /// Warm-up length (builder default 10); units are not visible here — TODO confirm.
    pub warmup: usize,
    /// Selected variant, with its payload as resolved by the builder.
    pub version: MambaVersion,
    /// Group count: 1 for `V1` and `BlockDiagonal`, a divisor of `d_in` otherwise.
    pub n_groups: usize,
    /// Block size: the `BlockDiagonal` payload, or 1 for every other variant.
    pub block_size: usize,
    /// Optional plasticity settings, passed through from the builder unchanged.
    pub plasticity: Option<PlasticityConfig>,
}
impl MambaConfig {
    /// Starts a [`MambaConfigBuilder`] holding the default settings.
    ///
    /// `d_in` has no default; it must be set before calling
    /// [`MambaConfigBuilder::build`], otherwise `build` returns an error.
    pub fn builder() -> MambaConfigBuilder {
        MambaConfigBuilder::new()
    }
}
impl fmt::Display for MambaConfig {
    /// Renders a one-line summary in three parts: a variant tag with `d_in` and
    /// `n_state`, then the variant-specific fields, then `ff`, `delta`, `seed`
    /// and `warmup`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let tag = match self.version {
            MambaVersion::V1 => "v1",
            MambaVersion::V3 => "v3",
            MambaVersion::V3Exp { .. } => "v3exp",
            MambaVersion::V3Mimo { .. } => "v3mimo",
            MambaVersion::BlockDiagonal { .. } => "bd",
        };
        write!(
            f,
            "MambaConfig({}, d_in={}, n_state={}",
            tag, self.d_in, self.n_state
        )?;

        // Variant-specific fields sit between the shared dimensions and the
        // shared tail; V1 contributes nothing here.
        match self.version {
            MambaVersion::V1 => {}
            MambaVersion::V3 => write!(f, ", n_groups={}", self.n_groups)?,
            MambaVersion::V3Exp { use_bcnorm } => {
                write!(f, ", n_groups={}, bcnorm={}", self.n_groups, use_bcnorm)?
            }
            MambaVersion::V3Mimo { rank, use_bcnorm } => write!(
                f,
                ", n_groups={}, rank={}, bcnorm={}",
                self.n_groups, rank, use_bcnorm
            )?,
            MambaVersion::BlockDiagonal { block_size } => {
                write!(f, ", block_size={}", block_size)?
            }
        }

        write!(
            f,
            ", ff={}, delta={}, seed={}, warmup={})",
            self.forgetting_factor, self.delta_rls, self.seed, self.warmup
        )
    }
}
/// Consuming builder for [`MambaConfig`]; obtain one via [`MambaConfig::builder`].
///
/// Fields hold raw, unvalidated requests. [`MambaConfigBuilder::build`] checks
/// them and resolves the variant-dependent values.
#[derive(Debug)]
pub struct MambaConfigBuilder {
    /// Required input dimension; `None` until `d_in(..)` is called.
    d_in: Option<usize>,
    n_state: usize,
    forgetting_factor: f64,
    delta_rls: f64,
    seed: u64,
    warmup: usize,
    version: MambaVersion,
    /// Requested group count for the V3-family variants (0 = auto-derive).
    n_groups: usize,
    /// Block size, read only when building a `BlockDiagonal` config.
    block_size: usize,
    /// Rank, read only when building a `V3Mimo` config.
    rank: usize,
    plasticity: Option<PlasticityConfig>,
}
impl Default for MambaConfigBuilder {
fn default() -> Self {
Self {
d_in: None,
n_state: 32,
forgetting_factor: 0.998,
delta_rls: 100.0,
seed: 42,
warmup: 10,
version: MambaVersion::V1,
n_groups: 1,
block_size: 4,
rank: 1,
plasticity: None,
}
}
}
impl MambaConfigBuilder {
    /// Creates a builder with the default settings (same as [`Default::default`]).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the input feature dimension. Required; must be at least 1.
    #[must_use]
    pub fn d_in(mut self, d_in: usize) -> Self {
        self.d_in = Some(d_in);
        self
    }

    /// Sets the state dimension. Must be at least 1.
    #[must_use]
    pub fn n_state(mut self, n_state: usize) -> Self {
        self.n_state = n_state;
        self
    }

    /// Sets the forgetting factor. Must lie in `(0, 1]`.
    #[must_use]
    pub fn forgetting_factor(mut self, ff: f64) -> Self {
        self.forgetting_factor = ff;
        self
    }

    /// Sets the RLS delta. Must be finite and strictly positive.
    #[must_use]
    pub fn delta_rls(mut self, delta: f64) -> Self {
        self.delta_rls = delta;
        self
    }

    /// Sets the seed value stored in the config.
    #[must_use]
    pub fn seed(mut self, seed: u64) -> Self {
        self.seed = seed;
        self
    }

    /// Sets the warm-up length stored in the config.
    #[must_use]
    pub fn warmup(mut self, warmup: usize) -> Self {
        self.warmup = warmup;
        self
    }

    /// Selects the model variant.
    ///
    /// Payload-carrying variants also seed the matching builder field:
    /// `V3Mimo { rank, .. }` sets [`rank`](Self::rank) and
    /// `BlockDiagonal { block_size }` sets [`block_size`](Self::block_size).
    /// The last call wins, so a later `.rank(..)` / `.block_size(..)` still
    /// overrides the payload.
    #[must_use]
    pub fn version(mut self, version: MambaVersion) -> Self {
        // Without this sync, e.g. `BlockDiagonal { block_size: 8 }` on its own
        // would be silently replaced by the builder's default block size at build time.
        match version {
            MambaVersion::V3Mimo { rank, .. } => self.rank = rank,
            MambaVersion::BlockDiagonal { block_size } => self.block_size = block_size,
            MambaVersion::V1 | MambaVersion::V3 | MambaVersion::V3Exp { .. } => {}
        }
        self.version = version;
        self
    }

    /// Sets the requested group count for the V3-family variants.
    ///
    /// `0` asks [`build`](Self::build) to derive a value. `V1` and
    /// `BlockDiagonal` ignore this setting and store 1.
    #[must_use]
    pub fn n_groups(mut self, n_groups: usize) -> Self {
        self.n_groups = n_groups;
        self
    }

    /// Sets the block size used by `BlockDiagonal` (must be in `2..=16` and divide `d_in`).
    #[must_use]
    pub fn block_size(mut self, block_size: usize) -> Self {
        self.block_size = block_size;
        self
    }

    /// Sets the rank used by `V3Mimo` (must be in `1..=16`).
    #[must_use]
    pub fn rank(mut self, rank: usize) -> Self {
        self.rank = rank;
        self
    }

    /// Sets (or, with `None`, clears) the optional plasticity configuration.
    /// The value is passed through to the built config unchanged.
    #[must_use]
    pub fn plasticity(mut self, p: Option<PlasticityConfig>) -> Self {
        self.plasticity = p;
        self
    }

    /// Validates the settings and produces a [`MambaConfig`].
    ///
    /// Variant-specific resolution:
    /// - `V1`: `n_groups` and `block_size` are stored as 1.
    /// - `V3`, `V3Exp`, `V3Mimo`: an `n_groups` of 0 auto-derives the largest
    ///   divisor of `d_in` that does not exceed `max(d_in / 4, 1)`. Any explicit
    ///   value must divide `d_in` evenly. `block_size` is stored as 1.
    /// - `V3Mimo`: the builder's `rank` must be in `1..=16`; it is written into
    ///   the resulting variant.
    /// - `BlockDiagonal`: the builder's `block_size` must be in `2..=16` and
    ///   divide `d_in`; it is written into the resulting variant. `n_groups` is
    ///   stored as 1.
    ///
    /// # Errors
    ///
    /// Returns a [`ConfigError`] when:
    /// - `d_in` is unset or 0, or `n_state` is 0;
    /// - `forgetting_factor` is outside `(0, 1]` (NaN included);
    /// - `delta_rls` is not a finite positive number;
    /// - any variant-specific constraint above fails.
    pub fn build(self) -> Result<MambaConfig, ConfigError> {
        let d_in = self.d_in.ok_or_else(|| {
            ConfigError::invalid("d_in", "d_in must be set (input feature dimension)")
        })?;
        if d_in == 0 {
            return Err(ConfigError::out_of_range("d_in", "must be >= 1", d_in));
        }
        if self.n_state == 0 {
            return Err(ConfigError::out_of_range(
                "n_state",
                "must be >= 1",
                self.n_state,
            ));
        }

        // Both float checks are written as positive "is valid" tests. Every
        // comparison with NaN is false, so the old `<= 0.0 || > 1.0` form let NaN through.
        let ff = self.forgetting_factor;
        if !(ff > 0.0 && ff <= 1.0) {
            return Err(ConfigError::out_of_range(
                "forgetting_factor",
                "must be in (0, 1]",
                ff,
            ));
        }
        let delta = self.delta_rls;
        if !(delta > 0.0 && delta.is_finite()) {
            return Err(ConfigError::out_of_range(
                "delta_rls",
                "must be finite and > 0",
                delta,
            ));
        }

        // Resolves `n_groups` for the V3-family variants. A request of 0
        // auto-derives the largest divisor of `d_in` not exceeding `max(d_in / 4, 1)`.
        // An explicit request must divide `d_in` evenly.
        let derive_n_groups =
            |requested: usize, version_name: &'static str| -> Result<usize, ConfigError> {
                if requested == 0 {
                    let target = (d_in / 4).max(1);
                    // 1 always divides d_in, so the search never comes up empty.
                    return Ok((1..=target).rev().find(|&g| d_in % g == 0).unwrap_or(1));
                }
                if d_in % requested != 0 {
                    return Err(ConfigError::invalid(
                        "n_groups",
                        format!(
                            "n_groups ({}) must divide d_in ({}) evenly for {}",
                            requested, d_in, version_name
                        ),
                    ));
                }
                Ok(requested)
            };

        let (n_groups, block_size, version) = match self.version {
            MambaVersion::V1 => (1, 1, MambaVersion::V1),
            MambaVersion::V3 => {
                let g = derive_n_groups(self.n_groups, "V3")?;
                (g, 1, MambaVersion::V3)
            }
            MambaVersion::V3Exp { use_bcnorm } => {
                let g = derive_n_groups(self.n_groups, "V3Exp")?;
                (g, 1, MambaVersion::V3Exp { use_bcnorm })
            }
            MambaVersion::V3Mimo { use_bcnorm, .. } => {
                let g = derive_n_groups(self.n_groups, "V3Mimo")?;
                let rank = self.rank;
                if rank == 0 {
                    return Err(ConfigError::out_of_range(
                        "rank",
                        "must be >= 1 for V3Mimo (rank=1 is standard outer product)",
                        rank,
                    ));
                }
                if rank > 16 {
                    return Err(ConfigError::out_of_range(
                        "rank",
                        "must be <= 16 for V3Mimo (parameter count scales with rank)",
                        rank,
                    ));
                }
                (g, 1, MambaVersion::V3Mimo { rank, use_bcnorm })
            }
            MambaVersion::BlockDiagonal { .. } => {
                let bs = self.block_size;
                if bs < 2 {
                    return Err(ConfigError::out_of_range(
                        "block_size",
                        "must be >= 2 for BlockDiagonal (use V1 for block_size=1)",
                        bs,
                    ));
                }
                if bs > 16 {
                    return Err(ConfigError::out_of_range(
                        "block_size",
                        "must be <= 16 for BlockDiagonal (dense matmul cost is O(m^2))",
                        bs,
                    ));
                }
                if d_in % bs != 0 {
                    return Err(ConfigError::invalid(
                        "block_size",
                        format!(
                            "block_size ({}) must divide d_in ({}) evenly for BlockDiagonal",
                            bs, d_in
                        ),
                    ));
                }
                (1, bs, MambaVersion::BlockDiagonal { block_size: bs })
            }
        };

        Ok(MambaConfig {
            d_in,
            n_state: self.n_state,
            forgetting_factor: ff,
            delta_rls: delta,
            seed: self.seed,
            warmup: self.warmup,
            version,
            n_groups,
            block_size,
            // `self` is consumed by `build`, so the plasticity settings are moved, not cloned.
            plasticity: self.plasticity,
        })
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `MambaConfigBuilder` validation and `MambaConfig` display.
    use super::*;

    #[test]
    fn builder_defaults() {
        let config = MambaConfig::builder().d_in(4).build().unwrap();
        assert_eq!(config.d_in, 4);
        assert_eq!(config.n_state, 32);
        assert!((config.forgetting_factor - 0.998).abs() < 1e-12);
        assert!((config.delta_rls - 100.0).abs() < 1e-12);
        assert_eq!(config.seed, 42);
        assert_eq!(config.warmup, 10);
        assert_eq!(
            config.version,
            MambaVersion::V1,
            "default version should be V1"
        );
        assert_eq!(config.n_groups, 1, "V1 n_groups should always be 1");
    }

    #[test]
    fn builder_custom_values() {
        // Every value differs from the builder default, so a setter that
        // silently did nothing would be caught.
        let config = MambaConfig::builder()
            .d_in(8)
            .n_state(16)
            .forgetting_factor(0.99)
            .delta_rls(50.0)
            .seed(123)
            .warmup(5)
            .build()
            .unwrap();
        assert_eq!(config.d_in, 8);
        assert_eq!(config.n_state, 16);
        assert!((config.forgetting_factor - 0.99).abs() < 1e-12);
        assert!((config.delta_rls - 50.0).abs() < 1e-12);
        assert_eq!(config.seed, 123);
        assert_eq!(config.warmup, 5);
    }

    #[test]
    fn builder_missing_d_in() {
        let result = MambaConfig::builder().build();
        assert!(result.is_err(), "should fail without d_in");
    }

    #[test]
    fn builder_invalid_n_state() {
        let result = MambaConfig::builder().d_in(4).n_state(0).build();
        assert!(result.is_err(), "n_state=0 should be invalid");
    }

    #[test]
    fn builder_invalid_forgetting_factor_zero() {
        let result = MambaConfig::builder()
            .d_in(4)
            .forgetting_factor(0.0)
            .build();
        assert!(result.is_err(), "ff=0 should be invalid");
    }

    #[test]
    fn builder_invalid_forgetting_factor_negative() {
        let result = MambaConfig::builder()
            .d_in(4)
            .forgetting_factor(-0.5)
            .build();
        assert!(result.is_err(), "ff=-0.5 should be invalid");
    }

    #[test]
    fn builder_invalid_forgetting_factor_over_one() {
        let result = MambaConfig::builder()
            .d_in(4)
            .forgetting_factor(1.01)
            .build();
        assert!(result.is_err(), "ff=1.01 should be invalid");
    }

    #[test]
    fn builder_forgetting_factor_one_valid() {
        let config = MambaConfig::builder()
            .d_in(4)
            .forgetting_factor(1.0)
            .build()
            .unwrap();
        assert!((config.forgetting_factor - 1.0).abs() < 1e-12);
    }

    #[test]
    fn builder_invalid_delta_rls() {
        let result = MambaConfig::builder().d_in(4).delta_rls(0.0).build();
        assert!(result.is_err(), "delta_rls=0 should be invalid");
        let result = MambaConfig::builder().d_in(4).delta_rls(-1.0).build();
        assert!(result.is_err(), "delta_rls=-1 should be invalid");
    }

    #[test]
    fn display_format() {
        let config = MambaConfig::builder().d_in(4).build().unwrap();
        let s = format!("{}", config);
        assert!(s.contains("d_in=4"), "display should contain d_in");
        assert!(s.contains("n_state=32"), "display should contain n_state");
    }

    #[test]
    fn display_format_exact_v1() {
        let config = MambaConfig::builder().d_in(4).build().unwrap();
        assert_eq!(
            config.to_string(),
            "MambaConfig(v1, d_in=4, n_state=32, ff=0.998, delta=100, seed=42, warmup=10)"
        );
    }

    #[test]
    fn display_format_exact_v3() {
        let config = MambaConfig::builder()
            .d_in(8)
            .version(MambaVersion::V3)
            .n_groups(2)
            .build()
            .unwrap();
        assert_eq!(
            config.to_string(),
            "MambaConfig(v3, d_in=8, n_state=32, n_groups=2, ff=0.998, delta=100, seed=42, warmup=10)"
        );
    }

    #[test]
    fn config_clone() {
        let config = MambaConfig::builder().d_in(4).seed(99).build().unwrap();
        let cloned = config.clone();
        assert_eq!(cloned.d_in, config.d_in);
        assert_eq!(cloned.seed, config.seed);
    }

    #[test]
    fn mamba_version_default_is_v1() {
        let config = MambaConfig::builder().d_in(8).build().unwrap();
        assert_eq!(
            config.version,
            MambaVersion::V1,
            "default version should be V1"
        );
        assert_eq!(config.n_groups, 1, "V1 should have n_groups=1");
    }

    #[test]
    fn v3_explicit_n_groups() {
        let config = MambaConfig::builder()
            .d_in(8)
            .version(MambaVersion::V3)
            .n_groups(2)
            .build()
            .unwrap();
        assert_eq!(config.version, MambaVersion::V3);
        assert_eq!(config.n_groups, 2, "should use explicit n_groups=2");
    }

    #[test]
    fn v3_auto_derive_n_groups() {
        let config = MambaConfig::builder()
            .d_in(16)
            .version(MambaVersion::V3)
            .n_groups(0)
            .build()
            .unwrap();
        assert_eq!(
            config.n_groups, 4,
            "auto-derived n_groups should be d_in/4 = 4"
        );
    }

    #[test]
    fn v3_auto_derive_n_groups_small_d_in() {
        let config = MambaConfig::builder()
            .d_in(2)
            .version(MambaVersion::V3)
            .n_groups(0)
            .build()
            .unwrap();
        assert_eq!(
            config.n_groups, 1,
            "auto-derived n_groups should clamp to 1 for small d_in"
        );
    }

    #[test]
    fn v3_auto_derive_picks_largest_divisor_not_above_quarter() {
        // (d_in, expected): 18/4 = 4 does not divide 18, so the search falls to 3.
        for (d_in, expected) in [(18, 3), (12, 3), (10, 2), (7, 1)] {
            let config = MambaConfig::builder()
                .d_in(d_in)
                .version(MambaVersion::V3)
                .n_groups(0)
                .build()
                .unwrap();
            assert_eq!(config.n_groups, expected, "d_in={}", d_in);
            assert_eq!(d_in % config.n_groups, 0, "derived n_groups must divide d_in");
        }
    }

    #[test]
    fn v3_n_groups_must_divide_d_in() {
        let result = MambaConfig::builder()
            .d_in(7)
            .version(MambaVersion::V3)
            .n_groups(3)
            .build();
        assert!(result.is_err(), "n_groups=3 should not divide d_in=7");
    }

    #[test]
    fn v1_ignores_n_groups() {
        let config = MambaConfig::builder()
            .d_in(8)
            .version(MambaVersion::V1)
            .n_groups(4)
            .build()
            .unwrap();
        assert_eq!(config.n_groups, 1, "V1 should ignore n_groups and store 1");
    }

    #[test]
    fn grouped_variants_store_unit_block_size() {
        let config = MambaConfig::builder()
            .d_in(8)
            .version(MambaVersion::V3)
            .n_groups(2)
            .block_size(8)
            .build()
            .unwrap();
        assert_eq!(config.block_size, 1, "non-BD variants should store block_size=1");
    }

    #[test]
    fn display_format_v3() {
        let config = MambaConfig::builder()
            .d_in(8)
            .version(MambaVersion::V3)
            .n_groups(2)
            .build()
            .unwrap();
        let s = format!("{}", config);
        assert!(s.contains("v3"), "V3 display should contain 'v3'");
        assert!(
            s.contains("n_groups=2"),
            "V3 display should contain n_groups"
        );
    }

    #[test]
    fn bd_basic_config() {
        let config = MambaConfig::builder()
            .d_in(8)
            .n_state(16)
            .version(MambaVersion::BlockDiagonal { block_size: 4 })
            .block_size(4)
            .build()
            .unwrap();
        assert_eq!(
            config.version,
            MambaVersion::BlockDiagonal { block_size: 4 }
        );
        assert_eq!(config.block_size, 4);
        assert_eq!(config.d_in, 8);
    }

    #[test]
    fn bd_stores_single_group() {
        let config = MambaConfig::builder()
            .d_in(8)
            .version(MambaVersion::BlockDiagonal { block_size: 4 })
            .block_size(4)
            .n_groups(2)
            .build()
            .unwrap();
        assert_eq!(config.n_groups, 1, "BlockDiagonal should store n_groups=1");
    }

    #[test]
    fn bd_block_size_must_divide_d_in() {
        let result = MambaConfig::builder()
            .d_in(7)
            .version(MambaVersion::BlockDiagonal { block_size: 4 })
            .block_size(4)
            .build();
        assert!(result.is_err(), "block_size=4 should not divide d_in=7");
    }

    #[test]
    fn bd_block_size_too_small() {
        let result = MambaConfig::builder()
            .d_in(4)
            .version(MambaVersion::BlockDiagonal { block_size: 1 })
            .block_size(1)
            .build();
        assert!(result.is_err(), "block_size=1 should be invalid (use V1)");
    }

    #[test]
    fn bd_block_size_too_large() {
        let result = MambaConfig::builder()
            .d_in(32)
            .version(MambaVersion::BlockDiagonal { block_size: 32 })
            .block_size(32)
            .build();
        assert!(result.is_err(), "block_size=32 should exceed maximum 16");
    }

    #[test]
    fn bd_display_format() {
        let config = MambaConfig::builder()
            .d_in(8)
            .version(MambaVersion::BlockDiagonal { block_size: 4 })
            .block_size(4)
            .build()
            .unwrap();
        let s = format!("{}", config);
        assert!(s.contains("bd"), "BD display should contain 'bd'");
        assert!(
            s.contains("block_size=4"),
            "BD display should contain block_size"
        );
    }

    #[test]
    fn bd_various_block_sizes() {
        for bs in [2, 4, 8] {
            let config = MambaConfig::builder()
                .d_in(8)
                .version(MambaVersion::BlockDiagonal { block_size: bs })
                .block_size(bs)
                .build()
                .unwrap();
            assert_eq!(config.block_size, bs, "block_size should be {}", bs);
        }
    }

    #[test]
    fn v3exp_basic_config() {
        let config = MambaConfig::builder()
            .d_in(8)
            .n_state(16)
            .version(MambaVersion::V3Exp { use_bcnorm: false })
            .n_groups(2)
            .build()
            .unwrap();
        assert_eq!(config.version, MambaVersion::V3Exp { use_bcnorm: false });
        assert_eq!(config.n_groups, 2);
        assert_eq!(config.d_in, 8);
    }

    #[test]
    fn v3exp_with_bcnorm() {
        let config = MambaConfig::builder()
            .d_in(8)
            .version(MambaVersion::V3Exp { use_bcnorm: true })
            .n_groups(2)
            .build()
            .unwrap();
        assert_eq!(config.version, MambaVersion::V3Exp { use_bcnorm: true });
    }

    #[test]
    fn v3exp_auto_derive_n_groups() {
        let config = MambaConfig::builder()
            .d_in(16)
            .version(MambaVersion::V3Exp { use_bcnorm: false })
            .n_groups(0)
            .build()
            .unwrap();
        assert_eq!(config.n_groups, 4, "auto-derived n_groups=d_in/4=4");
    }

    #[test]
    fn v3exp_n_groups_must_divide_d_in() {
        let result = MambaConfig::builder()
            .d_in(7)
            .version(MambaVersion::V3Exp { use_bcnorm: false })
            .n_groups(3)
            .build();
        assert!(
            result.is_err(),
            "n_groups=3 must not divide d_in=7 for V3Exp"
        );
    }

    #[test]
    fn v3exp_display_format() {
        let config = MambaConfig::builder()
            .d_in(8)
            .version(MambaVersion::V3Exp { use_bcnorm: false })
            .n_groups(2)
            .build()
            .unwrap();
        let s = format!("{}", config);
        assert!(s.contains("v3exp"), "V3Exp display should contain 'v3exp'");
        assert!(
            s.contains("n_groups=2"),
            "V3Exp display should contain n_groups"
        );
        assert!(
            s.contains("bcnorm="),
            "V3Exp display should contain bcnorm flag"
        );
    }

    #[test]
    fn v3mimo_basic_config() {
        let config = MambaConfig::builder()
            .d_in(8)
            .n_state(16)
            .version(MambaVersion::V3Mimo {
                rank: 1,
                use_bcnorm: false,
            })
            .n_groups(2)
            .rank(1)
            .build()
            .unwrap();
        assert_eq!(
            config.version,
            MambaVersion::V3Mimo {
                rank: 1,
                use_bcnorm: false
            }
        );
        assert_eq!(config.n_groups, 2);
    }

    #[test]
    fn v3mimo_rank_4() {
        let config = MambaConfig::builder()
            .d_in(8)
            .version(MambaVersion::V3Mimo {
                rank: 4,
                use_bcnorm: true,
            })
            .n_groups(2)
            .rank(4)
            .build()
            .unwrap();
        assert_eq!(
            config.version,
            MambaVersion::V3Mimo {
                rank: 4,
                use_bcnorm: true
            }
        );
    }

    #[test]
    fn v3mimo_rank_too_large() {
        let result = MambaConfig::builder()
            .d_in(8)
            .version(MambaVersion::V3Mimo {
                rank: 1,
                use_bcnorm: false,
            })
            .rank(32)
            .n_groups(2)
            .build();
        assert!(result.is_err(), "rank=32 should be invalid (> 16)");
    }

    #[test]
    fn v3mimo_rank_zero_invalid() {
        let result = MambaConfig::builder()
            .d_in(8)
            .version(MambaVersion::V3Mimo {
                rank: 1,
                use_bcnorm: false,
            })
            .rank(0)
            .n_groups(2)
            .build();
        assert!(result.is_err(), "rank=0 should be invalid");
    }

    #[test]
    fn v3mimo_display_format() {
        let config = MambaConfig::builder()
            .d_in(8)
            .version(MambaVersion::V3Mimo {
                rank: 2,
                use_bcnorm: false,
            })
            .n_groups(2)
            .rank(2)
            .build()
            .unwrap();
        let s = format!("{}", config);
        assert!(
            s.contains("v3mimo"),
            "V3Mimo display should contain 'v3mimo'"
        );
        assert!(s.contains("rank=2"), "V3Mimo display should contain rank");
        assert!(
            s.contains("n_groups=2"),
            "V3Mimo display should contain n_groups"
        );
    }

    #[test]
    fn v3mimo_readout_dim_smaller_than_v3exp() {
        use crate::StreamingMamba;
        let config_exp = MambaConfig::builder()
            .d_in(8)
            .version(MambaVersion::V3Exp { use_bcnorm: false })
            .n_groups(2)
            .build()
            .unwrap();
        let config_mimo = MambaConfig::builder()
            .d_in(8)
            .version(MambaVersion::V3Mimo {
                rank: 1,
                use_bcnorm: false,
            })
            .n_groups(2)
            .rank(1)
            .build()
            .unwrap();
        let m_exp = StreamingMamba::new(config_exp);
        let m_mimo = StreamingMamba::new(config_mimo);
        assert!(
            m_exp.last_features().len() > m_mimo.last_features().len(),
            "V3Exp readout dim ({}) should exceed V3Mimo readout dim ({}) — \
             V3Exp surfaces complex-state features and a random tanh lift \
             that V3Mimo does not.",
            m_exp.last_features().len(),
            m_mimo.last_features().len()
        );
        assert_eq!(
            m_mimo.last_features().len(),
            10,
            "V3Mimo readout dim should be d_in+n_groups = 10"
        );
    }
}