use crate::error::{Error, Result};
use serde::{Deserialize, Serialize};
/// Dimension and feature settings for an SSM block, (de)serializable through serde.
///
/// NOTE(review): "SSM" presumably means state-space model, given the Mamba2/mamba3
/// references elsewhere in this file — confirm against the model code.
///
/// Fields without a `#[serde(default ...)]` attribute must be present in the
/// serialized input. Call `validate` after loading to check the dimensions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SsmConfig {
/// Variant name. `is_mamba3` compares this against the exact string `"mamba3"`.
pub variant: String,
/// Number of heads. `validate` rejects zero.
pub num_heads: usize,
/// Per-head dimension. `validate` rejects zero and requires
/// `num_heads * head_dim == hidden_size * expand`.
pub head_dim: usize,
/// State size. `validate` rejects zero.
pub state_size: usize,
/// Chunk size. `validate` rejects zero.
pub chunk_size: usize,
/// Group count; falls back to `default_n_groups()` (1) when absent from the input.
#[serde(default = "default_n_groups")]
pub n_groups: usize,
/// Convolution kernel size; falls back to `default_conv_kernel()` (4) when absent.
#[serde(default = "default_conv_kernel")]
pub conv_kernel: usize,
/// Multiplier applied to `hidden_size` (see `intermediate_dim`); falls back to
/// `default_expand()` (2) when absent.
#[serde(default = "default_expand")]
pub expand: usize,
/// Optional flag; `None` when absent from the input.
/// NOTE(review): not read in this file — semantics are defined by the consumer.
#[serde(default)]
pub complex_rope: Option<bool>,
/// Optional rank setting; `None` when absent from the input.
/// NOTE(review): not read in this file — semantics are defined by the consumer.
#[serde(default)]
pub mimo_rank: Option<usize>,
/// Optional flag; `None` when absent from the input.
/// NOTE(review): not read in this file — semantics are defined by the consumer.
#[serde(default)]
pub use_conv: Option<bool>,
}
/// Serde fallback used for `SsmConfig::n_groups` when the field is missing.
fn default_n_groups() -> usize {
    const FALLBACK_N_GROUPS: usize = 1;
    FALLBACK_N_GROUPS
}
/// Serde fallback used for `SsmConfig::conv_kernel` when the field is missing.
fn default_conv_kernel() -> usize {
    const FALLBACK_CONV_KERNEL: usize = 4;
    FALLBACK_CONV_KERNEL
}
/// Serde fallback used for `SsmConfig::expand` when the field is missing.
fn default_expand() -> usize {
    const FALLBACK_EXPAND: usize = 2;
    FALLBACK_EXPAND
}
impl SsmConfig {
pub fn validate(&self, hidden_size: usize) -> Result<()> {
if self.num_heads == 0 {
return Err(Error::ModelError {
reason: "ssm.num_heads must be > 0".into(),
});
}
if self.head_dim == 0 {
return Err(Error::ModelError {
reason: "ssm.head_dim must be > 0".into(),
});
}
if self.state_size == 0 {
return Err(Error::ModelError {
reason: "ssm.state_size must be > 0".into(),
});
}
if self.chunk_size == 0 {
return Err(Error::ModelError {
reason: "ssm.chunk_size must be > 0".into(),
});
}
let expected = self.num_heads * self.head_dim;
let actual = hidden_size * self.expand;
if actual != expected {
return Err(Error::ModelError {
reason: format!(
"Mamba2 constraint violated: hidden_size * expand ({actual}) != num_heads * head_dim ({expected})"
),
});
}
Ok(())
}
pub fn is_mamba3(&self) -> bool {
self.variant == "mamba3"
}
pub fn intermediate_dim(&self, hidden_size: usize) -> usize {
hidden_size * self.expand
}
}