use burn::module::{Module, Param};
use burn::tensor::backend::Backend;
use burn::tensor::{Distribution as TensorDistribution, Tensor, TensorData, activation};
use burn_dragon_kernel::kernels::sequence::mamba3::forward::{
Mamba3TensorizedState, tensorized_mamba3_forward,
};
use serde::{Deserialize, Serialize};
use super::config::SequenceMemorySystem;
// Provenance markers for the upstream Mamba reference implementation.
// `dead_code` is allowed because nothing in the visible part of this file reads them.
/// Repository URL recorded for the Mamba-1 upstream reference.
#[allow(dead_code)]
pub const MAMBA1_UPSTREAM_REPO: &str = "https://github.com/state-spaces/mamba";
/// Commit identifier recorded for the Mamba-1 upstream reference.
#[allow(dead_code)]
pub const MAMBA1_UPSTREAM_COMMIT: &str = "c5afbdf";
/// Repository URL recorded for the Mamba-2 upstream reference (same URL as Mamba-1).
#[allow(dead_code)]
pub const MAMBA2_UPSTREAM_REPO: &str = "https://github.com/state-spaces/mamba";
// Serde default providers for `MambaSequenceConfig`. Each one is referenced by
// path from a `#[serde(default = "...")]` attribute and from `Default::default`.

/// Default SSM state size (`d_state`).
fn default_mamba_d_state() -> usize {
    16_usize
}

/// Default depthwise-convolution window (`d_conv`).
fn default_mamba_d_conv() -> usize {
    4_usize
}

/// Default expansion factor (`expand`).
fn default_mamba_expand() -> usize {
    2_usize
}

/// Default lower bound of the time-step range (`dt_min`).
fn default_mamba_dt_min() -> f32 {
    0.001_f32
}

/// Default upper bound of the time-step range (`dt_max`).
fn default_mamba_dt_max() -> f32 {
    0.1_f32
}

/// Default time-step scale (`dt_scale`).
fn default_mamba_dt_scale() -> f32 {
    1.0_f32
}

/// Default per-head width (`headdim`).
fn default_mamba_headdim() -> usize {
    128_usize
}

/// Default group count (`ngroups`); also reused as the default `mimo_rank`.
fn default_mamba_ngroups() -> usize {
    1_usize
}

/// Default lower bound of the `A` initialisation range.
fn default_mamba_a_init_min() -> f32 {
    1.0_f32
}

/// Default upper bound of the `A` initialisation range.
fn default_mamba_a_init_max() -> f32 {
    16.0_f32
}

/// Default normalisation epsilon (`norm_eps`).
fn default_mamba_norm_eps() -> f32 {
    0.00001_f32
}

/// Default rotary fraction (`rope_fraction`).
fn default_mamba_rope_fraction() -> f32 {
    0.5_f32
}

/// Default floor applied to sampled time steps (`dt_init_floor`).
fn default_mamba_dt_init_floor() -> f32 {
    0.0001_f32
}

/// Default `a_floor`.
fn default_mamba_a_floor() -> f32 {
    0.0001_f32
}

/// Default `chunk_size`.
fn default_mamba_chunk_size() -> usize {
    64_usize
}

/// Default for boolean options that are enabled unless configured otherwise.
fn default_true() -> bool {
    true
}
/// User-facing (serde) configuration shared by the Mamba-1/2/3 sequence memory systems.
///
/// Every field has a serde default, so an empty table deserialises successfully.
/// Use `validate` to check a configuration against a memory system and `resolve`
/// to turn it into a `ResolvedMambaSequenceConfig`.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct MambaSequenceConfig {
/// SSM state size; `validate` rejects 0.
#[serde(default = "default_mamba_d_state")]
pub d_state: usize,
/// Depthwise-convolution window; `validate` rejects 0.
#[serde(default = "default_mamba_d_conv")]
pub d_conv: usize,
/// Expansion factor; `resolve` computes `d_inner = d_model * expand`.
#[serde(default = "default_mamba_expand")]
pub expand: usize,
/// Rank of the `dt` projection; when `None`, `resolve` uses `ceil(d_model / 16)`.
#[serde(default)]
pub dt_rank: Option<usize>,
/// Lower bound of the time-step range; must be finite and positive.
#[serde(default = "default_mamba_dt_min")]
pub dt_min: f32,
/// Upper bound of the time-step range; must be finite and `>= dt_min`.
#[serde(default = "default_mamba_dt_max")]
pub dt_max: f32,
/// Scale applied to the `dt` projection weight std in the Mamba-1 initialiser; must be finite and positive.
#[serde(default = "default_mamba_dt_scale")]
pub dt_scale: f32,
/// Whether the Mamba-1/2 initialisers allocate a convolution bias.
#[serde(default = "default_true")]
pub conv_bias: bool,
/// Copied verbatim into the resolved config; not otherwise read in the visible code.
#[serde(default = "default_true")]
pub use_fast_path: bool,
/// Per-head width; for Mamba-2/3 `d_inner` must be divisible by it.
#[serde(default = "default_mamba_headdim")]
pub headdim: usize,
/// Group count; for Mamba-2/3 `nheads` must be divisible by it.
#[serde(default = "default_mamba_ngroups")]
pub ngroups: usize,
/// Lower bound of the uniform range sampled for `A` in the Mamba-2 initialiser.
#[serde(default = "default_mamba_a_init_min")]
pub a_init_min: f32,
/// Upper bound of the uniform range sampled for `A` in the Mamba-2 initialiser.
#[serde(default = "default_mamba_a_init_max")]
pub a_init_max: f32,
/// Normalisation epsilon; validated as finite and positive for Mamba-2/3.
#[serde(default = "default_mamba_norm_eps")]
pub norm_eps: f32,
/// Fraction of `d_state` used for rotary pairs; Mamba-3 accepts only 0.5 or 1.0.
#[serde(default = "default_mamba_rope_fraction")]
pub rope_fraction: f32,
/// Floor applied to the sampled time step in the Mamba-3 initialiser.
#[serde(default = "default_mamba_dt_init_floor")]
pub dt_init_floor: f32,
/// Validated as finite and positive for Mamba-3 and stored on the Mamba-3 parameters.
#[serde(default = "default_mamba_a_floor")]
pub a_floor: f32,
/// Validated as positive for Mamba-3 and stored on the Mamba-3 parameters.
#[serde(default = "default_mamba_chunk_size")]
pub chunk_size: usize,
/// Copied verbatim into the resolved config; not otherwise read in the visible code.
#[serde(default)]
pub is_outproj_norm: bool,
/// MIMO switch; `validate` currently rejects `true` for Mamba-3.
#[serde(default)]
pub is_mimo: bool,
/// MIMO rank. NOTE: its serde default reuses `default_mamba_ngroups` (value 1).
#[serde(default = "default_mamba_ngroups")]
pub mimo_rank: usize,
}
impl Default for MambaSequenceConfig {
    /// Builds the configuration with the same per-field values serde uses when a
    /// field is absent from the input.
    fn default() -> Self {
        // Both enabled-by-default switches come from the same provider.
        let enabled = default_true();
        // `ngroups` and `mimo_rank` share a single default provider.
        let unit = default_mamba_ngroups();
        Self {
            // Core dimensions.
            d_state: default_mamba_d_state(),
            d_conv: default_mamba_d_conv(),
            expand: default_mamba_expand(),
            headdim: default_mamba_headdim(),
            ngroups: unit,
            chunk_size: default_mamba_chunk_size(),
            // Time-step settings.
            dt_rank: None,
            dt_min: default_mamba_dt_min(),
            dt_max: default_mamba_dt_max(),
            dt_scale: default_mamba_dt_scale(),
            dt_init_floor: default_mamba_dt_init_floor(),
            // `A` initialisation and numerical guards.
            a_init_min: default_mamba_a_init_min(),
            a_init_max: default_mamba_a_init_max(),
            a_floor: default_mamba_a_floor(),
            norm_eps: default_mamba_norm_eps(),
            rope_fraction: default_mamba_rope_fraction(),
            // Feature switches.
            conv_bias: enabled,
            use_fast_path: enabled,
            is_outproj_norm: false,
            is_mimo: false,
            mimo_rank: unit,
        }
    }
}
/// Fully resolved Mamba configuration: every optional or derived quantity has a
/// concrete value. The type is `Copy`, so the helper methods take `self` by value.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct ResolvedMambaSequenceConfig {
    /// Model width.
    pub d_model: usize,
    /// Inner width.
    pub d_inner: usize,
    /// SSM state size.
    pub d_state: usize,
    /// Depthwise-convolution window.
    pub d_conv: usize,
    /// Rank of the `dt` projection.
    pub dt_rank: usize,
    /// Lower bound of the time-step range.
    pub dt_min: f32,
    /// Upper bound of the time-step range.
    pub dt_max: f32,
    /// Time-step scale.
    pub dt_scale: f32,
    /// Whether a convolution bias is used.
    pub conv_bias: bool,
    /// Fast-path switch.
    pub use_fast_path: bool,
    /// Per-head width.
    pub headdim: usize,
    /// Group count.
    pub ngroups: usize,
    /// Head count.
    pub nheads: usize,
    /// Lower bound of the `A` initialisation range.
    pub a_init_min: f32,
    /// Upper bound of the `A` initialisation range.
    pub a_init_max: f32,
    /// Normalisation epsilon.
    pub norm_eps: f32,
    /// Rotary fraction.
    pub rope_fraction: f32,
    /// Floor for sampled time steps.
    pub dt_init_floor: f32,
    /// Floor for `A`.
    pub a_floor: f32,
    /// Chunk size.
    pub chunk_size: usize,
    /// Output-projection normalisation switch.
    pub is_outproj_norm: bool,
    /// MIMO switch.
    pub is_mimo: bool,
    /// MIMO rank.
    pub mimo_rank: usize,
    /// Number of rotary angles.
    pub num_rope_angles: usize,
}

impl ResolvedMambaSequenceConfig {
    /// Channel count of the Mamba-2 convolution: `d_inner + 2 * ngroups * d_state`.
    pub fn mamba2_conv_dim(self) -> usize {
        let grouped_state_width = 2 * self.ngroups * self.d_state;
        self.d_inner + grouped_state_width
    }

    /// Output width of the Mamba-2 input projection:
    /// `2 * d_inner + 2 * ngroups * d_state + nheads`.
    pub fn mamba2_in_proj_dim(self) -> usize {
        let inner_width = 2 * self.d_inner;
        let grouped_state_width = 2 * self.ngroups * self.d_state;
        inner_width + grouped_state_width + self.nheads
    }

    /// Output width of the Mamba-3 input projection:
    /// `2 * d_inner + 2 * ngroups * mimo_rank * d_state + 3 * nheads + num_rope_angles`.
    pub fn mamba3_in_proj_dim(self) -> usize {
        let inner_width = 2 * self.d_inner;
        let grouped_state_width = 2 * self.ngroups * self.mimo_rank * self.d_state;
        let per_head_width = 3 * self.nheads;
        inner_width + grouped_state_width + per_head_width + self.num_rope_angles
    }
}
impl MambaSequenceConfig {
/// Checks this configuration against `memory_system` for a model of width `d_model`.
///
/// Checks run in a fixed order and the first failure is returned as a
/// human-readable message:
/// 1. checks common to every memory system (`d_state`, `d_conv`, `expand`, `dt_*`);
/// 2. head/group divisibility and `norm_eps` for Mamba-2 and Mamba-3;
/// 3. the `a_init` range for Mamba-2 only;
/// 4. rotary, floor, chunk and MIMO constraints for Mamba-3 only.
///
/// `d_model` and `expand` are clamped to at least 1 when deriving `d_inner`.
pub fn validate(
    &self,
    memory_system: SequenceMemorySystem,
    d_model: usize,
) -> Result<(), String> {
    let fail = |message: &str| -> Result<(), String> { Err(message.to_string()) };
    let finite_positive = |value: f32| value.is_finite() && value > 0.0;

    // --- Checks shared by every memory system ---
    if self.d_state == 0 {
        return fail("d_state must be positive");
    }
    if self.d_conv == 0 {
        return fail("d_conv must be positive");
    }
    if self.expand == 0 {
        return fail("expand must be positive");
    }
    if !finite_positive(self.dt_min) {
        return fail("dt_min must be finite and positive");
    }
    if !self.dt_max.is_finite() || self.dt_max < self.dt_min {
        return fail("dt_max must be finite and >= dt_min");
    }
    if !finite_positive(self.dt_scale) {
        return fail("dt_scale must be finite and positive");
    }

    let d_inner = d_model.max(1) * self.expand.max(1);
    let is_mamba2 = matches!(memory_system, SequenceMemorySystem::Mamba2StateSpaceDuality);
    let is_mamba3 = matches!(memory_system, SequenceMemorySystem::Mamba3StateSpaceDuality);

    // --- Head layout, shared by Mamba-2 and Mamba-3 ---
    if is_mamba2 || is_mamba3 {
        if self.headdim == 0 {
            return Err(format!("headdim must be positive for {memory_system:?}"));
        }
        if d_inner % self.headdim != 0 {
            return Err(format!(
                "{memory_system:?} requires d_inner divisible by headdim (got d_inner={d_inner} headdim={})",
                self.headdim
            ));
        }
        let nheads = d_inner / self.headdim;
        if self.ngroups == 0 {
            return Err(format!("ngroups must be positive for {memory_system:?}"));
        }
        if nheads % self.ngroups != 0 {
            return Err(format!(
                "{memory_system:?} requires nheads divisible by ngroups (got nheads={nheads} ngroups={})",
                self.ngroups
            ));
        }
        if !finite_positive(self.norm_eps) {
            return fail("norm_eps must be finite and positive");
        }
    }

    // --- Mamba-2 only: `A` initialisation range ---
    if is_mamba2 {
        let a_range_ok = finite_positive(self.a_init_min)
            && self.a_init_max.is_finite()
            && self.a_init_max >= self.a_init_min;
        if !a_range_ok {
            return fail("a_init range must be finite, positive, and ordered");
        }
    }

    // --- Mamba-3 only ---
    if is_mamba3 {
        // A NaN `rope_fraction` is not "far" from either target here; it is rejected
        // by the rotary-pair check further down instead.
        let far_from = |target: f32| (self.rope_fraction - target).abs() > 1.0e-6;
        if far_from(0.5) && far_from(1.0) {
            return fail(
                "mamba3_state_space_duality currently supports rope_fraction = 0.5 or 1.0",
            );
        }
        if !finite_positive(self.dt_init_floor) {
            return fail("dt_init_floor must be finite and positive");
        }
        if !finite_positive(self.a_floor) {
            return fail("a_floor must be finite and positive");
        }
        if self.chunk_size == 0 {
            return fail("chunk_size must be positive for mamba3_state_space_duality");
        }
        if self.is_mimo {
            return fail("mamba3_state_space_duality MIMO is not implemented in burn_dragon yet; set model.mamba.is_mimo = false");
        }
        // Number of complete rotary pairs that fit into `d_state * rope_fraction`.
        let rotary_width = ((self.d_state as f32) * self.rope_fraction).floor() as usize;
        let rotary_pairs = rotary_width / 2;
        if rotary_pairs == 0 {
            return fail("mamba3_state_space_duality requires at least one rotary pair in d_state * rope_fraction");
        }
    }

    Ok(())
}
/// Validates this configuration and expands it into a [`ResolvedMambaSequenceConfig`].
///
/// # Panics
/// Panics with the validation message when `validate` rejects the configuration.
///
/// Values are additionally clamped to safe minimums (sizes to at least 1, floating
/// point knobs to small positive floors). `nheads` is derived only for Mamba-2 and
/// Mamba-3 and `num_rope_angles` only for Mamba-3; both are 0 otherwise.
///
/// NOTE(review): `mimo_rank` is forwarded even when `is_mimo` is false, and it
/// feeds `mamba3_in_proj_dim` — confirm the forward pass expects that.
pub fn resolve(
    &self,
    d_model: usize,
    memory_system: SequenceMemorySystem,
) -> ResolvedMambaSequenceConfig {
    if let Err(message) = self.validate(memory_system, d_model) {
        panic!("{message}");
    }

    let is_mamba3 = matches!(memory_system, SequenceMemorySystem::Mamba3StateSpaceDuality);
    let has_heads =
        is_mamba3 || matches!(memory_system, SequenceMemorySystem::Mamba2StateSpaceDuality);

    // Sizes, each clamped to at least 1.
    let d_model = d_model.max(1);
    let d_state = self.d_state.max(1);
    let d_conv = self.d_conv.max(1);
    let d_inner = d_model * self.expand.max(1);
    let headdim = self.headdim.max(1);
    let dt_rank = match self.dt_rank {
        Some(rank) => rank,
        None => d_model.div_ceil(16),
    }
    .max(1);

    let nheads = if has_heads { d_inner / headdim } else { 0 };

    // One angle per complete rotary pair inside `d_state * rope_fraction`.
    let num_rope_angles = if is_mamba3 {
        let rotary_width = ((d_state as f32) * self.rope_fraction).floor() as usize;
        (rotary_width / 2).max(1)
    } else {
        0
    };

    // Floating point knobs with their lower clamps.
    let dt_min = self.dt_min.max(1.0e-6);
    let dt_max = self.dt_max.max(dt_min);
    let a_init_min = self.a_init_min.max(1.0e-6);
    let a_init_max = self.a_init_max.max(a_init_min);

    ResolvedMambaSequenceConfig {
        d_model,
        d_inner,
        d_state,
        d_conv,
        dt_rank,
        dt_min,
        dt_max,
        dt_scale: self.dt_scale.max(1.0e-6),
        conv_bias: self.conv_bias,
        use_fast_path: self.use_fast_path,
        headdim,
        ngroups: self.ngroups.max(1),
        nheads,
        a_init_min,
        a_init_max,
        norm_eps: self.norm_eps.max(1.0e-8),
        rope_fraction: self.rope_fraction,
        dt_init_floor: self.dt_init_floor.max(1.0e-6),
        a_floor: self.a_floor.max(1.0e-6),
        chunk_size: self.chunk_size.max(1),
        is_outproj_norm: self.is_outproj_norm,
        is_mimo: self.is_mimo,
        mimo_rank: self.mimo_rank.max(1),
        num_rope_angles,
    }
}
}
/// Learnable parameters of a Mamba-1 (selective scan) block.
///
/// Shapes below are the ones produced by `Mamba1SequenceParameters::new`.
#[derive(Module, Debug)]
pub struct Mamba1SequenceParameters<B: Backend> {
/// Model width.
d_model: usize,
/// Inner width.
d_inner: usize,
/// SSM state size.
d_state: usize,
/// Depthwise-convolution window.
d_conv: usize,
/// Rank of the `dt` projection.
dt_rank: usize,
/// Input projection, `[d_model, 2 * d_inner]`.
in_proj: Param<Tensor<B, 2>>,
/// Depthwise convolution weight, `[d_inner, d_conv]`.
conv_weight: Param<Tensor<B, 2>>,
/// Optional convolution bias, `[d_inner]`; present only when the config enables it.
conv_bias: Option<Param<Tensor<B, 1>>>,
/// Projection producing `dt`, `B` and `C`, `[d_inner, dt_rank + 2 * d_state]`.
x_proj: Param<Tensor<B, 2>>,
/// `dt` up-projection weight, `[dt_rank, d_inner]`.
dt_proj_weight: Param<Tensor<B, 2>>,
/// `dt` up-projection bias, `[d_inner]`.
dt_proj_bias: Param<Tensor<B, 1>>,
/// Log of the state-matrix magnitudes, `[d_inner, d_state]`.
a_log: Param<Tensor<B, 2>>,
/// Skip-connection weights, `[d_inner]`.
d_skip: Param<Tensor<B, 1>>,
/// Output projection, `[d_inner, d_model]`.
out_proj: Param<Tensor<B, 2>>,
}
impl<B: Backend> Mamba1SequenceParameters<B> {
pub fn new(config: ResolvedMambaSequenceConfig, device: &B::Device) -> Self {
let in_std = (1.0 / config.d_model.max(1) as f32).sqrt();
let out_std = (1.0 / config.d_inner.max(1) as f32).sqrt();
let conv_std = (1.0 / config.d_conv.max(1) as f32).sqrt();
let dt_weight_std = (1.0 / config.dt_rank.max(1) as f32).sqrt() * config.dt_scale;
let dt_target = (config.dt_min * config.dt_max).sqrt().max(1.0e-6);
let dt_bias = dt_target + (-(-dt_target).exp_m1()).ln();
let in_proj = Param::from_tensor(Tensor::<B, 2>::random(
[config.d_model, config.d_inner * 2],
TensorDistribution::Normal(0.0, in_std as f64),
device,
));
let conv_weight = Param::from_tensor(Tensor::<B, 2>::random(
[config.d_inner, config.d_conv],
TensorDistribution::Normal(0.0, conv_std as f64),
device,
));
let conv_bias = config
.conv_bias
.then(|| Param::from_tensor(Tensor::<B, 1>::zeros([config.d_inner], device)));
let x_proj = Param::from_tensor(Tensor::<B, 2>::random(
[config.d_inner, config.dt_rank + config.d_state * 2],
TensorDistribution::Normal(0.0, out_std as f64),
device,
));
let dt_proj_weight = Param::from_tensor(Tensor::<B, 2>::random(
[config.dt_rank, config.d_inner],
TensorDistribution::Normal(0.0, dt_weight_std as f64),
device,
));
let dt_proj_bias = Param::from_tensor(Tensor::<B, 1>::from_data(
TensorData::new(vec![dt_bias; config.d_inner], [config.d_inner]),
device,
));
let a_values = (0..config.d_inner)
.flat_map(|_| (1..=config.d_state).map(|value| (value as f32).ln()))
.collect::<Vec<_>>();
let a_log = Param::from_tensor(Tensor::<B, 2>::from_data(
TensorData::new(a_values, [config.d_inner, config.d_state]),
device,
));
let d_skip = Param::from_tensor(Tensor::<B, 1>::ones([config.d_inner], device));
let out_proj = Param::from_tensor(Tensor::<B, 2>::random(
[config.d_inner, config.d_model],
TensorDistribution::Normal(0.0, out_std as f64),
device,
));
Self {
d_model: config.d_model,
d_inner: config.d_inner,
d_state: config.d_state,
d_conv: config.d_conv,
dt_rank: config.dt_rank,
in_proj,
conv_weight,
conv_bias,
x_proj,
dt_proj_weight,
dt_proj_bias,
a_log,
d_skip,
out_proj,
}
}
/// Rebuilds a [`ResolvedMambaSequenceConfig`] from what this module stores.
///
/// Only the dimensions and the presence of the convolution bias are recovered.
/// Every other field is filled with the module default or a fixed placeholder,
/// so the result need not equal the configuration passed to `new`.
pub fn config(&self) -> ResolvedMambaSequenceConfig {
    let has_conv_bias = self.conv_bias.is_some();
    ResolvedMambaSequenceConfig {
        // Recovered from stored state.
        d_model: self.d_model,
        d_inner: self.d_inner,
        d_state: self.d_state,
        d_conv: self.d_conv,
        dt_rank: self.dt_rank,
        conv_bias: has_conv_bias,
        // Not stored on the module: defaults.
        dt_min: default_mamba_dt_min(),
        dt_max: default_mamba_dt_max(),
        dt_scale: default_mamba_dt_scale(),
        headdim: default_mamba_headdim(),
        ngroups: default_mamba_ngroups(),
        a_init_min: default_mamba_a_init_min(),
        a_init_max: default_mamba_a_init_max(),
        norm_eps: default_mamba_norm_eps(),
        rope_fraction: default_mamba_rope_fraction(),
        dt_init_floor: default_mamba_dt_init_floor(),
        a_floor: default_mamba_a_floor(),
        chunk_size: default_mamba_chunk_size(),
        // Fixed placeholders.
        use_fast_path: false,
        nheads: 0,
        is_outproj_norm: false,
        is_mimo: false,
        mimo_rank: 1,
        num_rope_angles: 0,
    }
}
// Read accessors: each returns the current value of the corresponding parameter
// as obtained from `Param::val`.

/// Current value of the input projection.
pub fn in_proj_tensor(&self) -> Tensor<B, 2> {
self.in_proj.val()
}
/// Current value of the depthwise convolution weight.
pub fn conv_weight_tensor(&self) -> Tensor<B, 2> {
self.conv_weight.val()
}
/// Current value of the convolution bias, or `None` when the module has no bias.
pub fn conv_bias_tensor(&self) -> Option<Tensor<B, 1>> {
self.conv_bias.as_ref().map(|bias| bias.val())
}
/// Current value of the `dt`/`B`/`C` projection.
pub fn x_proj_tensor(&self) -> Tensor<B, 2> {
self.x_proj.val()
}
/// Current value of the `dt` up-projection weight.
pub fn dt_proj_weight_tensor(&self) -> Tensor<B, 2> {
self.dt_proj_weight.val()
}
/// Current value of the `dt` up-projection bias.
pub fn dt_proj_bias_tensor(&self) -> Tensor<B, 1> {
self.dt_proj_bias.val()
}
/// Current value of `a_log`.
pub fn a_log_tensor(&self) -> Tensor<B, 2> {
self.a_log.val()
}
/// Current value of the skip-connection weights.
pub fn d_skip_tensor(&self) -> Tensor<B, 1> {
self.d_skip.val()
}
/// Current value of the output projection.
pub fn out_proj_tensor(&self) -> Tensor<B, 2> {
self.out_proj.val()
}
pub fn blended_with(&self, fresh: &Self, alpha: f32) -> Self {
Self {
d_model: self.d_model,
d_inner: self.d_inner,
d_state: self.d_state,
d_conv: self.d_conv,
dt_rank: self.dt_rank,
in_proj: Param::from_tensor(MambaSequenceParameters::<B>::blend_param(
self.in_proj.val(),
fresh.in_proj.val(),
alpha,
)),
conv_weight: Param::from_tensor(MambaSequenceParameters::<B>::blend_param(
self.conv_weight.val(),
fresh.conv_weight.val(),
alpha,
)),
conv_bias: self.conv_bias.as_ref().zip(fresh.conv_bias.as_ref()).map(
|(source, fresh)| {
Param::from_tensor(MambaSequenceParameters::<B>::blend_param(
source.val(),
fresh.val(),
alpha,
))
},
),
x_proj: Param::from_tensor(MambaSequenceParameters::<B>::blend_param(
self.x_proj.val(),
fresh.x_proj.val(),
alpha,
)),
dt_proj_weight: Param::from_tensor(MambaSequenceParameters::<B>::blend_param(
self.dt_proj_weight.val(),
fresh.dt_proj_weight.val(),
alpha,
)),
dt_proj_bias: Param::from_tensor(MambaSequenceParameters::<B>::blend_param(
self.dt_proj_bias.val(),
fresh.dt_proj_bias.val(),
alpha,
)),
a_log: Param::from_tensor(MambaSequenceParameters::<B>::blend_param(
self.a_log.val(),
fresh.a_log.val(),
alpha,
)),
d_skip: Param::from_tensor(MambaSequenceParameters::<B>::blend_param(
self.d_skip.val(),
fresh.d_skip.val(),
alpha,
)),
out_proj: Param::from_tensor(MambaSequenceParameters::<B>::blend_param(
self.out_proj.val(),
fresh.out_proj.val(),
alpha,
)),
}
}
pub fn matched_fresh_rms(&self, fresh: &Self) -> Self {
Self {
d_model: self.d_model,
d_inner: self.d_inner,
d_state: self.d_state,
d_conv: self.d_conv,
dt_rank: self.dt_rank,
in_proj: Param::from_tensor(MambaSequenceParameters::<B>::match_fresh_rms(
self.in_proj.val(),
fresh.in_proj.val(),
)),
conv_weight: Param::from_tensor(MambaSequenceParameters::<B>::match_fresh_rms(
self.conv_weight.val(),
fresh.conv_weight.val(),
)),
conv_bias: self.conv_bias.as_ref().zip(fresh.conv_bias.as_ref()).map(
|(source, fresh)| {
Param::from_tensor(MambaSequenceParameters::<B>::match_fresh_rms(
source.val(),
fresh.val(),
))
},
),
x_proj: Param::from_tensor(MambaSequenceParameters::<B>::match_fresh_rms(
self.x_proj.val(),
fresh.x_proj.val(),
)),
dt_proj_weight: Param::from_tensor(MambaSequenceParameters::<B>::match_fresh_rms(
self.dt_proj_weight.val(),
fresh.dt_proj_weight.val(),
)),
dt_proj_bias: Param::from_tensor(MambaSequenceParameters::<B>::match_fresh_rms(
self.dt_proj_bias.val(),
fresh.dt_proj_bias.val(),
)),
a_log: Param::from_tensor(MambaSequenceParameters::<B>::match_fresh_rms(
self.a_log.val(),
fresh.a_log.val(),
)),
d_skip: Param::from_tensor(MambaSequenceParameters::<B>::match_fresh_rms(
self.d_skip.val(),
fresh.d_skip.val(),
)),
out_proj: Param::from_tensor(MambaSequenceParameters::<B>::match_fresh_rms(
self.out_proj.val(),
fresh.out_proj.val(),
)),
}
}
}
/// Learnable parameters of a Mamba-2 (state space duality) block.
///
/// Shapes below are the ones produced by `Mamba2SequenceParameters::new`.
#[derive(Module, Debug)]
pub struct Mamba2SequenceParameters<B: Backend> {
/// Model width.
d_model: usize,
/// Inner width.
d_inner: usize,
/// SSM state size.
d_state: usize,
/// Depthwise-convolution window.
d_conv: usize,
/// Per-head width.
headdim: usize,
/// Group count.
ngroups: usize,
/// Head count.
nheads: usize,
/// Normalisation epsilon.
norm_eps: f32,
/// Input projection, `[d_model, mamba2_in_proj_dim()]`.
in_proj: Param<Tensor<B, 2>>,
/// Depthwise convolution weight, `[mamba2_conv_dim(), d_conv]`.
conv_weight: Param<Tensor<B, 2>>,
/// Optional convolution bias, `[mamba2_conv_dim()]`.
conv_bias: Option<Param<Tensor<B, 1>>>,
/// Per-head time-step bias, `[nheads]`.
dt_bias: Param<Tensor<B, 1>>,
/// Per-head log of the `A` magnitude, `[nheads]`.
a_log: Param<Tensor<B, 1>>,
/// Per-head skip-connection weights, `[nheads]`.
d_skip: Param<Tensor<B, 1>>,
/// Normalisation weight, `[d_inner]`.
norm_weight: Param<Tensor<B, 1>>,
/// Output projection, `[d_inner, d_model]`.
out_proj: Param<Tensor<B, 2>>,
}
impl<B: Backend> Mamba2SequenceParameters<B> {
/// Creates freshly initialised Mamba-2 parameters on `device`.
///
/// * `dt` is sampled log-uniformly in `[dt_min, dt_max]` per head and floored at
///   `1e-4`; `dt_bias` is its inverse softplus.
/// * `a_log` is the log of a per-head uniform sample in `[a_init_min, a_init_max]`.
/// * Projection and convolution matrices are drawn from `Normal(0, 1/sqrt(fan))`.
/// * `d_skip` and `norm_weight` are all ones; the optional conv bias is all zeros.
///
/// The inverse softplus is evaluated on the host as `dt + ln(-expm1(-dt))`, the
/// same numerically stable form used by the Mamba-1 and Mamba-3 initialisers in
/// this file. The naive `ln(exp(dt) - 1)` loses precision to cancellation in
/// `f32` when `dt` is small.
///
/// Random tensors are drawn in the order `dt`, `a_log`, `in_proj`, `conv_weight`,
/// `out_proj`.
///
/// # Panics
/// Panics if the sampled `dt` tensor cannot be read back as `f32`.
pub fn new(config: ResolvedMambaSequenceConfig, device: &B::Device) -> Self {
    let in_std = (1.0 / config.d_model.max(1) as f32).sqrt();
    let out_std = (1.0 / config.d_inner.max(1) as f32).sqrt();
    let conv_std = (1.0 / config.d_conv.max(1) as f32).sqrt();
    let log_dt_min = config.dt_min.ln();
    let log_dt_max = config.dt_max.ln();
    // NOTE(review): some backends may reject a uniform range with equal bounds
    // (`dt_min == dt_max` passes validation) — confirm against the backend in use.
    let dt_sample = Tensor::<B, 1>::random(
        [config.nheads],
        TensorDistribution::Uniform(log_dt_min as f64, log_dt_max as f64),
        device,
    )
    .exp()
    .clamp_min(1.0e-4);
    // Stable inverse softplus, computed per element on the host.
    let dt_bias_values = dt_sample
        .to_data()
        .convert::<f32>()
        .into_vec::<f32>()
        .expect("mamba2 dt bias init")
        .into_iter()
        .map(|dt| dt + (-(-dt).exp_m1()).ln())
        .collect::<Vec<_>>();
    let a_log = Tensor::<B, 1>::random(
        [config.nheads],
        TensorDistribution::Uniform(config.a_init_min as f64, config.a_init_max as f64),
        device,
    )
    .clamp_min(1.0e-6)
    .log();
    let in_proj = Param::from_tensor(Tensor::<B, 2>::random(
        [config.d_model, config.mamba2_in_proj_dim()],
        TensorDistribution::Normal(0.0, in_std as f64),
        device,
    ));
    let conv_weight = Param::from_tensor(Tensor::<B, 2>::random(
        [config.mamba2_conv_dim(), config.d_conv],
        TensorDistribution::Normal(0.0, conv_std as f64),
        device,
    ));
    let conv_bias = config
        .conv_bias
        .then(|| Param::from_tensor(Tensor::<B, 1>::zeros([config.mamba2_conv_dim()], device)));
    let dt_bias = Param::from_tensor(Tensor::<B, 1>::from_data(
        TensorData::new(dt_bias_values, [config.nheads]),
        device,
    ));
    // Round-trip through host data so the parameter starts from a plain tensor.
    let a_log = Param::from_tensor(Tensor::<B, 1>::from_data(a_log.to_data(), device));
    let d_skip = Param::from_tensor(Tensor::<B, 1>::ones([config.nheads], device));
    let norm_weight = Param::from_tensor(Tensor::<B, 1>::ones([config.d_inner], device));
    let out_proj = Param::from_tensor(Tensor::<B, 2>::random(
        [config.d_inner, config.d_model],
        TensorDistribution::Normal(0.0, out_std as f64),
        device,
    ));
    Self {
        d_model: config.d_model,
        d_inner: config.d_inner,
        d_state: config.d_state,
        d_conv: config.d_conv,
        headdim: config.headdim,
        ngroups: config.ngroups,
        nheads: config.nheads,
        norm_eps: config.norm_eps,
        in_proj,
        conv_weight,
        conv_bias,
        dt_bias,
        a_log,
        d_skip,
        norm_weight,
        out_proj,
    }
}
/// Rebuilds a [`ResolvedMambaSequenceConfig`] from what this module stores.
///
/// Dimensions, head layout, `norm_eps` and the presence of the convolution bias
/// are recovered; `dt_rank` is re-derived as `ceil(d_model / 16)`. Every other
/// field is a module default or a fixed placeholder, so the result need not
/// equal the configuration passed to `new`.
pub fn config(&self) -> ResolvedMambaSequenceConfig {
    let has_conv_bias = self.conv_bias.is_some();
    let derived_dt_rank = self.d_model.div_ceil(16);
    ResolvedMambaSequenceConfig {
        // Recovered from stored state.
        d_model: self.d_model,
        d_inner: self.d_inner,
        d_state: self.d_state,
        d_conv: self.d_conv,
        headdim: self.headdim,
        ngroups: self.ngroups,
        nheads: self.nheads,
        norm_eps: self.norm_eps,
        conv_bias: has_conv_bias,
        // Derived.
        dt_rank: derived_dt_rank,
        // Not stored on the module: defaults.
        dt_min: default_mamba_dt_min(),
        dt_max: default_mamba_dt_max(),
        dt_scale: default_mamba_dt_scale(),
        a_init_min: default_mamba_a_init_min(),
        a_init_max: default_mamba_a_init_max(),
        rope_fraction: default_mamba_rope_fraction(),
        dt_init_floor: default_mamba_dt_init_floor(),
        a_floor: default_mamba_a_floor(),
        chunk_size: default_mamba_chunk_size(),
        // Fixed placeholders.
        use_fast_path: false,
        is_outproj_norm: false,
        is_mimo: false,
        mimo_rank: 1,
        num_rope_angles: 0,
    }
}
// Read accessors: each returns the current value of the corresponding parameter
// as obtained from `Param::val`.

/// Current value of the input projection.
pub fn in_proj_tensor(&self) -> Tensor<B, 2> {
self.in_proj.val()
}
/// Current value of the depthwise convolution weight.
pub fn conv_weight_tensor(&self) -> Tensor<B, 2> {
self.conv_weight.val()
}
/// Current value of the convolution bias, or `None` when the module has no bias.
pub fn conv_bias_tensor(&self) -> Option<Tensor<B, 1>> {
self.conv_bias.as_ref().map(|bias| bias.val())
}
/// Current value of the per-head time-step bias.
pub fn dt_bias_tensor(&self) -> Tensor<B, 1> {
self.dt_bias.val()
}
/// Current value of `a_log`.
pub fn a_log_tensor(&self) -> Tensor<B, 1> {
self.a_log.val()
}
/// Current value of the skip-connection weights.
pub fn d_skip_tensor(&self) -> Tensor<B, 1> {
self.d_skip.val()
}
/// Current value of the normalisation weight.
pub fn norm_weight_tensor(&self) -> Tensor<B, 1> {
self.norm_weight.val()
}
/// Current value of the output projection.
pub fn out_proj_tensor(&self) -> Tensor<B, 2> {
self.out_proj.val()
}
pub fn blended_with(&self, fresh: &Self, alpha: f32) -> Self {
Self {
d_model: self.d_model,
d_inner: self.d_inner,
d_state: self.d_state,
d_conv: self.d_conv,
headdim: self.headdim,
ngroups: self.ngroups,
nheads: self.nheads,
norm_eps: self.norm_eps,
in_proj: Param::from_tensor(MambaSequenceParameters::<B>::blend_param(
self.in_proj.val(),
fresh.in_proj.val(),
alpha,
)),
conv_weight: Param::from_tensor(MambaSequenceParameters::<B>::blend_param(
self.conv_weight.val(),
fresh.conv_weight.val(),
alpha,
)),
conv_bias: self.conv_bias.as_ref().zip(fresh.conv_bias.as_ref()).map(
|(source, fresh)| {
Param::from_tensor(MambaSequenceParameters::<B>::blend_param(
source.val(),
fresh.val(),
alpha,
))
},
),
dt_bias: Param::from_tensor(MambaSequenceParameters::<B>::blend_param(
self.dt_bias.val(),
fresh.dt_bias.val(),
alpha,
)),
a_log: Param::from_tensor(MambaSequenceParameters::<B>::blend_param(
self.a_log.val(),
fresh.a_log.val(),
alpha,
)),
d_skip: Param::from_tensor(MambaSequenceParameters::<B>::blend_param(
self.d_skip.val(),
fresh.d_skip.val(),
alpha,
)),
norm_weight: Param::from_tensor(MambaSequenceParameters::<B>::blend_param(
self.norm_weight.val(),
fresh.norm_weight.val(),
alpha,
)),
out_proj: Param::from_tensor(MambaSequenceParameters::<B>::blend_param(
self.out_proj.val(),
fresh.out_proj.val(),
alpha,
)),
}
}
pub fn matched_fresh_rms(&self, fresh: &Self) -> Self {
Self {
d_model: self.d_model,
d_inner: self.d_inner,
d_state: self.d_state,
d_conv: self.d_conv,
headdim: self.headdim,
ngroups: self.ngroups,
nheads: self.nheads,
norm_eps: self.norm_eps,
in_proj: Param::from_tensor(MambaSequenceParameters::<B>::match_fresh_rms(
self.in_proj.val(),
fresh.in_proj.val(),
)),
conv_weight: Param::from_tensor(MambaSequenceParameters::<B>::match_fresh_rms(
self.conv_weight.val(),
fresh.conv_weight.val(),
)),
conv_bias: self.conv_bias.as_ref().zip(fresh.conv_bias.as_ref()).map(
|(source, fresh)| {
Param::from_tensor(MambaSequenceParameters::<B>::match_fresh_rms(
source.val(),
fresh.val(),
))
},
),
dt_bias: Param::from_tensor(MambaSequenceParameters::<B>::match_fresh_rms(
self.dt_bias.val(),
fresh.dt_bias.val(),
)),
a_log: Param::from_tensor(MambaSequenceParameters::<B>::match_fresh_rms(
self.a_log.val(),
fresh.a_log.val(),
)),
d_skip: Param::from_tensor(MambaSequenceParameters::<B>::match_fresh_rms(
self.d_skip.val(),
fresh.d_skip.val(),
)),
norm_weight: Param::from_tensor(MambaSequenceParameters::<B>::match_fresh_rms(
self.norm_weight.val(),
fresh.norm_weight.val(),
)),
out_proj: Param::from_tensor(MambaSequenceParameters::<B>::match_fresh_rms(
self.out_proj.val(),
fresh.out_proj.val(),
)),
}
}
}
/// Learnable parameters of a Mamba-3 block.
///
/// Shapes below are the ones produced by `Mamba3SequenceParameters::new`.
#[derive(Module, Debug)]
pub struct Mamba3SequenceParameters<B: Backend> {
/// Model width.
d_model: usize,
/// Inner width.
d_inner: usize,
/// SSM state size.
d_state: usize,
/// Per-head width.
headdim: usize,
/// Group count.
ngroups: usize,
/// Head count.
nheads: usize,
/// Normalisation epsilon.
norm_eps: f32,
/// Number of rotary angles.
num_rope_angles: usize,
/// Floor for `A`, copied from the resolved config.
a_floor: f32,
/// Chunk size, copied from the resolved config.
chunk_size: usize,
/// Input projection, `[d_model, mamba3_in_proj_dim()]`.
in_proj: Param<Tensor<B, 2>>,
/// Per-head time-step bias, `[nheads]`.
dt_bias: Param<Tensor<B, 1>>,
/// Bias for `B`, `[nheads, d_state]`, initialised to ones.
b_bias: Param<Tensor<B, 2>>,
/// Bias for `C`, `[nheads, d_state]`, initialised to ones.
c_bias: Param<Tensor<B, 2>>,
/// Normalisation weight for `B`, `[d_state]`, initialised to ones.
b_norm_weight: Param<Tensor<B, 1>>,
/// Normalisation weight for `C`, `[d_state]`, initialised to ones.
c_norm_weight: Param<Tensor<B, 1>>,
/// Per-head skip-connection weights, `[nheads]`.
d_skip: Param<Tensor<B, 1>>,
/// Output projection, `[d_inner, d_model]`.
out_proj: Param<Tensor<B, 2>>,
}
impl<B: Backend> Mamba3SequenceParameters<B> {
pub fn new(config: ResolvedMambaSequenceConfig, device: &B::Device) -> Self {
let in_std = (1.0 / config.d_model.max(1) as f32).sqrt();
let out_std = (1.0 / config.d_inner.max(1) as f32).sqrt();
let log_dt_min = config.dt_min.ln();
let log_dt_max = config.dt_max.ln();
let dt_sample = Tensor::<B, 1>::random(
[config.nheads],
TensorDistribution::Uniform(log_dt_min as f64, log_dt_max as f64),
device,
)
.exp()
.clamp_min(config.dt_init_floor);
let dt_bias_values = dt_sample
.to_data()
.convert::<f32>()
.into_vec::<f32>()
.expect("mamba3 dt bias init")
.into_iter()
.map(|dt| dt + (-(-dt).exp_m1()).ln())
.collect::<Vec<_>>();
let in_proj = Param::from_tensor(Tensor::<B, 2>::random(
[config.d_model, config.mamba3_in_proj_dim()],
TensorDistribution::Normal(0.0, in_std as f64),
device,
));
let dt_bias = Param::from_tensor(Tensor::<B, 1>::from_data(
TensorData::new(dt_bias_values, [config.nheads]),
device,
));
let b_bias = Param::from_tensor(Tensor::<B, 2>::ones(
[config.nheads, config.d_state],
device,
));
let c_bias = Param::from_tensor(Tensor::<B, 2>::ones(
[config.nheads, config.d_state],
device,
));
let b_norm_weight = Param::from_tensor(Tensor::<B, 1>::ones([config.d_state], device));
let c_norm_weight = Param::from_tensor(Tensor::<B, 1>::ones([config.d_state], device));
let d_skip = Param::from_tensor(Tensor::<B, 1>::ones([config.nheads], device));
let out_proj = Param::from_tensor(Tensor::<B, 2>::random(
[config.d_inner, config.d_model],
TensorDistribution::Normal(0.0, out_std as f64),
device,
));
Self {
d_model: config.d_model,
d_inner: config.d_inner,
d_state: config.d_state,
headdim: config.headdim,
ngroups: config.ngroups,
nheads: config.nheads,
norm_eps: config.norm_eps,
num_rope_angles: config.num_rope_angles,
a_floor: config.a_floor,
chunk_size: config.chunk_size,
in_proj,
dt_bias,
b_bias,
c_bias,
b_norm_weight,
c_norm_weight,
d_skip,
out_proj,
}
}
/// Rebuilds a [`ResolvedMambaSequenceConfig`] from what this module stores.
///
/// Dimensions, head layout, `norm_eps`, `a_floor`, `chunk_size` and
/// `num_rope_angles` are recovered; `dt_rank` is re-derived as
/// `ceil(d_model / 16)`. Every other field — including `rope_fraction` — is a
/// module default or a fixed placeholder, so the result need not equal the
/// configuration passed to `new`.
pub fn config(&self) -> ResolvedMambaSequenceConfig {
    let derived_dt_rank = self.d_model.div_ceil(16);
    ResolvedMambaSequenceConfig {
        // Recovered from stored state.
        d_model: self.d_model,
        d_inner: self.d_inner,
        d_state: self.d_state,
        headdim: self.headdim,
        ngroups: self.ngroups,
        nheads: self.nheads,
        norm_eps: self.norm_eps,
        a_floor: self.a_floor,
        chunk_size: self.chunk_size,
        num_rope_angles: self.num_rope_angles,
        // Derived.
        dt_rank: derived_dt_rank,
        // Not stored on the module: defaults.
        d_conv: default_mamba_d_conv(),
        dt_min: default_mamba_dt_min(),
        dt_max: default_mamba_dt_max(),
        dt_scale: default_mamba_dt_scale(),
        a_init_min: default_mamba_a_init_min(),
        a_init_max: default_mamba_a_init_max(),
        rope_fraction: default_mamba_rope_fraction(),
        dt_init_floor: default_mamba_dt_init_floor(),
        // Fixed placeholders.
        conv_bias: false,
        use_fast_path: false,
        is_outproj_norm: false,
        is_mimo: false,
        mimo_rank: 1,
    }
}
// Read accessors: each returns the current value of the corresponding parameter
// as obtained from `Param::val`.

/// Current value of the input projection.
pub fn in_proj_tensor(&self) -> Tensor<B, 2> {
self.in_proj.val()
}
/// Current value of the per-head time-step bias.
pub fn dt_bias_tensor(&self) -> Tensor<B, 1> {
self.dt_bias.val()
}
/// Current value of the `B` bias.
pub fn b_bias_tensor(&self) -> Tensor<B, 2> {
self.b_bias.val()
}
/// Current value of the `C` bias.
pub fn c_bias_tensor(&self) -> Tensor<B, 2> {
self.c_bias.val()
}
/// Current value of the `B` normalisation weight.
pub fn b_norm_weight_tensor(&self) -> Tensor<B, 1> {
self.b_norm_weight.val()
}
/// Current value of the `C` normalisation weight.
pub fn c_norm_weight_tensor(&self) -> Tensor<B, 1> {
self.c_norm_weight.val()
}
/// Current value of the skip-connection weights.
pub fn d_skip_tensor(&self) -> Tensor<B, 1> {
self.d_skip.val()
}
/// Current value of the output projection.
pub fn out_proj_tensor(&self) -> Tensor<B, 2> {
self.out_proj.val()
}
pub fn blended_with(&self, fresh: &Self, alpha: f32) -> Self {
Self {
d_model: self.d_model,
d_inner: self.d_inner,
d_state: self.d_state,
headdim: self.headdim,
ngroups: self.ngroups,
nheads: self.nheads,
norm_eps: self.norm_eps,
num_rope_angles: self.num_rope_angles,
a_floor: self.a_floor,
chunk_size: self.chunk_size,
in_proj: Param::from_tensor(MambaSequenceParameters::<B>::blend_param(
self.in_proj.val(),
fresh.in_proj.val(),
alpha,
)),
dt_bias: Param::from_tensor(MambaSequenceParameters::<B>::blend_param(
self.dt_bias.val(),
fresh.dt_bias.val(),
alpha,
)),
b_bias: Param::from_tensor(MambaSequenceParameters::<B>::blend_param(
self.b_bias.val(),
fresh.b_bias.val(),
alpha,
)),
c_bias: Param::from_tensor(MambaSequenceParameters::<B>::blend_param(
self.c_bias.val(),
fresh.c_bias.val(),
alpha,
)),
b_norm_weight: Param::from_tensor(MambaSequenceParameters::<B>::blend_param(
self.b_norm_weight.val(),
fresh.b_norm_weight.val(),
alpha,
)),
c_norm_weight: Param::from_tensor(MambaSequenceParameters::<B>::blend_param(
self.c_norm_weight.val(),
fresh.c_norm_weight.val(),
alpha,
)),
d_skip: Param::from_tensor(MambaSequenceParameters::<B>::blend_param(
self.d_skip.val(),
fresh.d_skip.val(),
alpha,
)),
out_proj: Param::from_tensor(MambaSequenceParameters::<B>::blend_param(
self.out_proj.val(),
fresh.out_proj.val(),
alpha,
)),
}
}
pub fn matched_fresh_rms(&self, fresh: &Self) -> Self {
Self {
d_model: self.d_model,
d_inner: self.d_inner,
d_state: self.d_state,
headdim: self.headdim,
ngroups: self.ngroups,
nheads: self.nheads,
norm_eps: self.norm_eps,
num_rope_angles: self.num_rope_angles,
a_floor: self.a_floor,
chunk_size: self.chunk_size,
in_proj: Param::from_tensor(MambaSequenceParameters::<B>::match_fresh_rms(
self.in_proj.val(),
fresh.in_proj.val(),
)),
dt_bias: Param::from_tensor(MambaSequenceParameters::<B>::match_fresh_rms(
self.dt_bias.val(),
fresh.dt_bias.val(),
)),
b_bias: Param::from_tensor(MambaSequenceParameters::<B>::match_fresh_rms(
self.b_bias.val(),
fresh.b_bias.val(),
)),
c_bias: Param::from_tensor(MambaSequenceParameters::<B>::match_fresh_rms(
self.c_bias.val(),
fresh.c_bias.val(),
)),
b_norm_weight: Param::from_tensor(MambaSequenceParameters::<B>::match_fresh_rms(
self.b_norm_weight.val(),
fresh.b_norm_weight.val(),
)),
c_norm_weight: Param::from_tensor(MambaSequenceParameters::<B>::match_fresh_rms(
self.c_norm_weight.val(),
fresh.c_norm_weight.val(),
)),
d_skip: Param::from_tensor(MambaSequenceParameters::<B>::match_fresh_rms(
self.d_skip.val(),
fresh.d_skip.val(),
)),
out_proj: Param::from_tensor(MambaSequenceParameters::<B>::match_fresh_rms(
self.out_proj.val(),
fresh.out_proj.val(),
)),
}
}
}
/// Container holding the parameters of one Mamba variant.
///
/// `MambaSequenceParameters::new` populates exactly one of the three slots,
/// chosen by the requested memory system, and leaves the other two `None`.
#[derive(Module, Debug)]
pub struct MambaSequenceParameters<B: Backend> {
/// Mamba-1 parameters, when that variant is selected.
mamba1: Option<Mamba1SequenceParameters<B>>,
/// Mamba-2 parameters, when that variant is selected.
mamba2: Option<Mamba2SequenceParameters<B>>,
/// Mamba-3 parameters, when that variant is selected.
mamba3: Option<Mamba3SequenceParameters<B>>,
}
impl<B: Backend> MambaSequenceParameters<B> {
/// Root-mean-square of all elements of `tensor`, read back to the host as `f32`.
///
/// Returns `0.0` if the read-back yields no value.
///
/// # Panics
/// Panics if the reduced tensor cannot be converted to an `f32` vector.
fn param_rms<const D: usize>(tensor: Tensor<B, D>) -> f32 {
    let mean_square = tensor.powf_scalar(2.0).mean();
    let host_values = mean_square
        .to_data()
        .convert::<f32>()
        .into_vec::<f32>()
        .expect("mamba rms scalar");
    match host_values.first() {
        Some(value) => value.sqrt(),
        None => 0.0,
    }
}
/// Linear interpolation `(1 - alpha) * fresh + alpha * source`, detached from
/// the autodiff graph. `alpha` is clamped to `[0, 1]` first.
fn blend_param<const D: usize>(
    source: Tensor<B, D>,
    fresh: Tensor<B, D>,
    alpha: f32,
) -> Tensor<B, D> {
    let source_weight = alpha.clamp(0.0, 1.0);
    let fresh_part = fresh.mul_scalar(1.0 - source_weight);
    let source_part = source.mul_scalar(source_weight);
    fresh_part.add(source_part).detach()
}
/// Rescales `source` so that its RMS equals the RMS of `fresh`.
///
/// When the source RMS is at most `1e-8`, or either RMS is not finite, no
/// meaningful scale exists and `source` is returned unscaled.
///
/// The result is detached on every path, matching `blend_param`. The fallback
/// path previously returned `source` still attached to the originating
/// parameter's autodiff node.
fn match_fresh_rms<const D: usize>(source: Tensor<B, D>, fresh: Tensor<B, D>) -> Tensor<B, D> {
    let source_rms = Self::param_rms(source.clone());
    let fresh_rms = Self::param_rms(fresh);
    if source_rms <= 1.0e-8 || !source_rms.is_finite() || !fresh_rms.is_finite() {
        return source.detach();
    }
    source.mul_scalar(fresh_rms / source_rms).detach()
}
/// Creates the parameter set for `memory_system`, leaving the other two variant
/// slots empty.
///
/// # Panics
/// Panics for any memory system other than the three Mamba variants.
pub fn new(
    config: ResolvedMambaSequenceConfig,
    memory_system: SequenceMemorySystem,
    device: &B::Device,
) -> Self {
    let (mamba1, mamba2, mamba3) = match memory_system {
        SequenceMemorySystem::Mamba1SelectiveScan => (
            Some(Mamba1SequenceParameters::new(config, device)),
            None,
            None,
        ),
        SequenceMemorySystem::Mamba2StateSpaceDuality => (
            None,
            Some(Mamba2SequenceParameters::new(config, device)),
            None,
        ),
        SequenceMemorySystem::Mamba3StateSpaceDuality => (
            None,
            None,
            Some(Mamba3SequenceParameters::new(config, device)),
        ),
        other => panic!("unsupported memory system {other:?} for mamba params"),
    };
    Self {
        mamba1,
        mamba2,
        mamba3,
    }
}
/// Borrows the Mamba-1 parameters, or `None` when that slot is empty.
pub fn mamba1(&self) -> Option<&Mamba1SequenceParameters<B>> {
self.mamba1.as_ref()
}
/// Borrows the Mamba-2 parameters, or `None` when that slot is empty.
pub fn mamba2(&self) -> Option<&Mamba2SequenceParameters<B>> {
self.mamba2.as_ref()
}
/// Borrows the Mamba-3 parameters, or `None` when that slot is empty.
pub fn mamba3(&self) -> Option<&Mamba3SequenceParameters<B>> {
self.mamba3.as_ref()
}
/// Blends each variant slot with the matching slot of `fresh`
/// (see the per-variant `blended_with`).
///
/// A slot is populated in the result only when it is populated on both sides.
pub fn blended_with(&self, fresh: &Self, alpha: f32) -> Self {
    let mamba1 = match (&self.mamba1, &fresh.mamba1) {
        (Some(source), Some(fresh)) => Some(source.blended_with(fresh, alpha)),
        _ => None,
    };
    let mamba2 = match (&self.mamba2, &fresh.mamba2) {
        (Some(source), Some(fresh)) => Some(source.blended_with(fresh, alpha)),
        _ => None,
    };
    let mamba3 = match (&self.mamba3, &fresh.mamba3) {
        (Some(source), Some(fresh)) => Some(source.blended_with(fresh, alpha)),
        _ => None,
    };
    Self {
        mamba1,
        mamba2,
        mamba3,
    }
}
/// Rescales each variant slot against the matching slot of `fresh`
/// (see the per-variant `matched_fresh_rms`).
///
/// A slot is populated in the result only when it is populated on both sides.
pub fn matched_fresh_rms(&self, fresh: &Self) -> Self {
    let mamba1 = match (&self.mamba1, &fresh.mamba1) {
        (Some(source), Some(fresh)) => Some(source.matched_fresh_rms(fresh)),
        _ => None,
    };
    let mamba2 = match (&self.mamba2, &fresh.mamba2) {
        (Some(source), Some(fresh)) => Some(source.matched_fresh_rms(fresh)),
        _ => None,
    };
    let mamba3 = match (&self.mamba3, &fresh.mamba3) {
        (Some(source), Some(fresh)) => Some(source.matched_fresh_rms(fresh)),
        _ => None,
    };
    Self {
        mamba1,
        mamba2,
        mamba3,
    }
}
}
/// Recurrent state bundle used by the reference (step-wise) Mamba implementations.
///
/// NOTE(review): the field semantics below are inferred from names and from the
/// step helpers in this file; confirm against the code that builds this state.
#[derive(Debug, Clone)]
pub struct MambaReferenceState<B: Backend> {
/// Convolution history; presumably laid out like the `conv_state` argument of
/// `mamba_depthwise_conv_step_reference`, whose last axis is `d_conv`.
pub conv: Tensor<B, 4>,
/// SSM state (rank 4).
pub ssm: Tensor<B, 4>,
/// Optional rank-3 tensor; presumably the rotary angle state — TODO confirm.
pub angle: Option<Tensor<B, 3>>,
/// Optional rank-3 tensor; meaning not visible in this part of the file.
pub k: Option<Tensor<B, 3>>,
/// Optional rank-3 tensor; meaning not visible in this part of the file.
pub v: Option<Tensor<B, 3>>,
}
/// SiLU activation: each element multiplied by its own sigmoid.
fn silu<B: Backend, const D: usize>(values: Tensor<B, D>) -> Tensor<B, D> {
    let gate = activation::sigmoid(values.clone());
    values * gate
}
/// Runs one time step of the depthwise causal convolution followed by SiLU.
///
/// * `x_t` – current input of shape `[batch, views, channels]`.
/// * `conv_state` – rolling window whose last axis holds the `d_conv` most recent inputs.
/// * `conv_weight` – per-channel taps, reshaped here to `[1, 1, channels, d_conv]`.
/// * `conv_bias` – optional per-channel bias added before the activation.
///
/// Returns the activated output `[batch, views, channels]` and the shifted window that
/// now ends with `x_t`.
pub(crate) fn mamba_depthwise_conv_step_reference<B: Backend>(
    x_t: Tensor<B, 3>,
    conv_state: Tensor<B, 4>,
    conv_weight: Tensor<B, 2>,
    conv_bias: Option<Tensor<B, 1>>,
) -> (Tensor<B, 3>, Tensor<B, 4>) {
    let [batch, views, channels] = x_t.shape().dims::<3>();
    let window = conv_state.shape().dims::<4>()[3];
    // Drop the oldest tap; a window of at most one entry keeps no history at all.
    let history = match window {
        0 | 1 => Tensor::<B, 4>::zeros([batch, views, channels, 0], &x_t.device()),
        _ => conv_state.slice_dim(3, 1..window),
    };
    let newest = x_t.unsqueeze_dim::<4>(3);
    let next_conv_state = Tensor::cat(vec![history, newest], 3);
    let taps = conv_weight.reshape([1, 1, channels, window]);
    let summed = (next_conv_state.clone() * taps)
        .sum_dim(3)
        .reshape([batch, views, channels]);
    let pre_activation = match conv_bias {
        Some(bias) => summed + bias.reshape([1, 1, channels]),
        None => summed,
    };
    (silu(pre_activation), next_conv_state)
}
/// Advances the Mamba-1 selective scan by a single time step.
///
/// * `u_t` – convolved input of shape `[batch, views, d_inner]`.
/// * `z_t` – gate branch; its SiLU is multiplied element-wise into the scan output.
/// * `ssm_state` – recurrent state, combined element-wise with
///   `[batch, views, d_inner, d_state]` terms.
/// * `params` – Mamba-1 weights (`a_log`, `d_skip`, `x_proj`, `dt_proj_weight`,
///   `dt_proj_bias`) and their resolved configuration.
///
/// Returns `(y_t * silu(z_t), next_ssm_state)`.
///
/// Batch and view axes are flattened together for the projections, so any number of
/// views is supported (previously the reshape only succeeded for `views == 1`).
///
/// # Panics
///
/// Panics when the last dimension of `u_t` differs from the configured `d_inner`.
pub(crate) fn mamba1_selective_scan_step_reference<B: Backend>(
    u_t: Tensor<B, 3>,
    z_t: Tensor<B, 3>,
    ssm_state: Tensor<B, 4>,
    params: &Mamba1SequenceParameters<B>,
) -> (Tensor<B, 3>, Tensor<B, 4>) {
    let [batch, views, d_inner] = u_t.shape().dims::<3>();
    let config = params.config();
    assert_eq!(
        d_inner, config.d_inner,
        "u_t width {} must match mamba d_inner {}",
        d_inner, config.d_inner
    );
    // One projection row per (batch, view) pair.
    let rows = batch * views;
    let a = params
        .a_log
        .val()
        .exp()
        .neg()
        .reshape([1, 1, config.d_inner, config.d_state]);
    let d_skip = params.d_skip.val().reshape([1, 1, config.d_inner]);
    // x_proj yields the concatenated [dt_rank | B | C] columns for every row.
    let x_db = u_t
        .clone()
        .reshape([rows, d_inner])
        .matmul(params.x_proj.val())
        .reshape([rows, config.dt_rank + config.d_state * 2]);
    let dt = activation::softplus(
        x_db.clone()
            .slice_dim(1, 0..config.dt_rank)
            .matmul(params.dt_proj_weight.val())
            .reshape([batch, views, config.d_inner])
            + params.dt_proj_bias.val().reshape([1, 1, config.d_inner]),
        1.0,
    );
    let b_t = x_db
        .clone()
        .slice_dim(1, config.dt_rank..(config.dt_rank + config.d_state))
        .reshape([batch, views, config.d_state]);
    let c_t = x_db
        .slice_dim(
            1,
            (config.dt_rank + config.d_state)..(config.dt_rank + config.d_state * 2),
        )
        .reshape([batch, views, config.d_state]);
    // Discretize: decay = exp(dt * A), input gain = dt * B.
    let d_a = (dt.clone().unsqueeze_dim::<4>(3) * a).exp();
    let d_b = dt.unsqueeze_dim::<4>(3) * b_t.unsqueeze_dim::<4>(2);
    let next_ssm_state = ssm_state * d_a + u_t.clone().unsqueeze_dim::<4>(3) * d_b;
    // Read the state out through C and add the skip connection.
    let y_t = (next_ssm_state.clone() * c_t.unsqueeze_dim::<4>(2))
        .sum_dim(3)
        .reshape([batch, views, config.d_inner])
        + d_skip * u_t;
    (y_t * silu(z_t), next_ssm_state)
}
/// Expands a per-group tensor `[batch, ngroups, d_state]` to `[batch, nheads, d_state]`
/// by repeating every group `nheads / ngroups` times.
///
/// # Panics
///
/// Panics when `nheads` is not a multiple of the group count.
fn repeat_groups_to_heads<B: Backend>(grouped: Tensor<B, 3>, nheads: usize) -> Tensor<B, 3> {
    let [batch, ngroups, d_state] = grouped.shape().dims::<3>();
    assert_eq!(
        nheads % ngroups,
        0,
        "Mamba-2 requires nheads divisible by ngroups"
    );
    let heads_per_group = nheads / ngroups;
    // Insert a length-one axis after the group axis, repeat along it, then fold it away.
    let with_repeat_axis = grouped.reshape([batch, ngroups, 1, d_state]);
    let expanded = with_repeat_axis.repeat_dim(2, heads_per_group);
    expanded.reshape([batch, nheads, d_state])
}
/// RMS-normalizes `y` over its last axis, scales by `weight`, and gates with `silu(z)`.
///
/// `eps` is added to the mean square before the square root. `weight` is broadcast as
/// `[1, 1, width]`.
fn mamba2_rmsnorm_gated_reference<B: Backend>(
    y: Tensor<B, 3>,
    z: Tensor<B, 3>,
    weight: Tensor<B, 1>,
    eps: f32,
) -> Tensor<B, 3> {
    let [width] = weight.shape().dims::<1>();
    let mean_square = y.clone().powf_scalar(2.0).mean_dim(2);
    let rms = mean_square.add_scalar(eps).sqrt();
    let scale = weight.reshape([1, 1, width]);
    let normalized = y / rms;
    normalized * scale * silu(z)
}
/// Advances the Mamba-2 recurrence by a single time step.
///
/// * `x_t` – per-head input `[batch, nheads, headdim]`.
/// * `dt_t` – raw step sizes `[batch, nheads]`; `dt_bias` and softplus are applied here.
/// * `b_t`, `c_t` – per-group input/output projections, repeated to one row per head.
/// * `ssm_state` – recurrent state, combined element-wise with
///   `[batch, nheads, headdim, d_state]` terms.
///
/// Returns the un-gated output `[batch, nheads, headdim]` and the next state.
///
/// # Panics
///
/// Panics when the head count or head width of `x_t` disagrees with the configuration.
pub(crate) fn mamba2_state_space_duality_step_reference<B: Backend>(
    x_t: Tensor<B, 3>,
    dt_t: Tensor<B, 2>,
    b_t: Tensor<B, 3>,
    c_t: Tensor<B, 3>,
    ssm_state: Tensor<B, 4>,
    params: &Mamba2SequenceParameters<B>,
) -> (Tensor<B, 3>, Tensor<B, 4>) {
    let [batch, nheads, headdim] = x_t.shape().dims::<3>();
    let config = params.config();
    assert_eq!(nheads, config.nheads);
    assert_eq!(headdim, config.headdim);
    let head_count = config.nheads;
    let neg_a = params
        .a_log
        .val()
        .exp()
        .neg()
        .reshape([1, head_count, 1, 1]);
    let skip = params.d_skip.val().reshape([1, head_count, 1]);
    let dt_bias = params.dt_bias.val().reshape([1, head_count]);
    let step_size = activation::softplus(dt_t + dt_bias, 1.0);
    let b_heads = repeat_groups_to_heads(b_t, head_count);
    let c_heads = repeat_groups_to_heads(c_t, head_count);
    // One scalar step size per (batch, head), broadcast over headdim and d_state.
    let step_4d: Tensor<B, 4> = step_size.reshape([batch, head_count, 1, 1]);
    let decay = (step_4d.clone() * neg_a).exp();
    let input_term =
        step_4d * b_heads.unsqueeze_dim::<4>(2) * x_t.clone().unsqueeze_dim::<4>(3);
    let next_ssm_state = ssm_state * decay + input_term;
    let readout = (next_ssm_state.clone() * c_heads.unsqueeze_dim::<4>(2))
        .sum_dim(3)
        .reshape([batch, head_count, config.headdim]);
    let y_t = readout + x_t * skip;
    (y_t, next_ssm_state)
}
/// Step-by-step Mamba-1 reference forward pass.
///
/// `hidden_states` is `[batch, 1, time, d_model]`. A previous `state` is reused only
/// when its `conv` / `ssm` tensors have exactly the shapes expected for this batch and
/// configuration; any other state is replaced by zeros. Returns the output
/// `[batch, 1, time, d_model]` and the state after the final step.
///
/// # Panics
///
/// Panics when more than one view is supplied or the hidden width differs from
/// `params.d_model`.
fn mamba1_reference<B: Backend>(
    hidden_states: Tensor<B, 4>,
    params: &Mamba1SequenceParameters<B>,
    state: Option<MambaReferenceState<B>>,
) -> (Tensor<B, 4>, MambaReferenceState<B>) {
    let [batch, views, time, dim] = hidden_states.shape().dims::<4>();
    assert_eq!(
        views, 1,
        "Mamba reference path currently expects a single dense stream view"
    );
    assert_eq!(
        dim, params.d_model,
        "hidden dim {} must match mamba d_model {}",
        dim, params.d_model
    );
    let device = hidden_states.device();
    let config = params.config();
    let d_inner = config.d_inner;
    let conv_shape = [batch, 1, d_inner, config.d_conv];
    let ssm_shape = [batch, 1, d_inner, config.d_state];

    // Carry over each piece of state only if its shape still fits; otherwise start from zero.
    let mut conv_state = state
        .as_ref()
        .filter(|previous| previous.conv.shape().dims::<4>() == conv_shape)
        .map(|previous| previous.conv.clone())
        .unwrap_or_else(|| Tensor::<B, 4>::zeros(conv_shape, &device));
    let mut ssm_state = state
        .as_ref()
        .filter(|previous| previous.ssm.shape().dims::<4>() == ssm_shape)
        .map(|previous| previous.ssm.clone())
        .unwrap_or_else(|| Tensor::<B, 4>::zeros(ssm_shape, &device));

    // Project once for the whole sequence, then split into the x and z halves.
    let projected = hidden_states
        .reshape([batch * time, config.d_model])
        .matmul(params.in_proj.val())
        .reshape([batch, time, d_inner * 2]);
    let x = projected
        .clone()
        .slice_dim(2, 0..d_inner)
        .swap_dims(1, 2)
        .reshape([batch, 1, d_inner, time]);
    let z = projected
        .slice_dim(2, d_inner..(d_inner * 2))
        .swap_dims(1, 2)
        .reshape([batch, 1, d_inner, time]);

    let mut outputs = Vec::with_capacity(time);
    for step in 0..time {
        let x_t = x
            .clone()
            .slice_dim(3, step..step + 1)
            .reshape([batch, 1, d_inner]);
        let z_t = z
            .clone()
            .slice_dim(3, step..step + 1)
            .reshape([batch, 1, d_inner]);
        let (u_t, conv_next) = mamba_depthwise_conv_step_reference(
            x_t,
            conv_state,
            params.conv_weight.val(),
            params.conv_bias.as_ref().map(|bias| bias.val()),
        );
        conv_state = conv_next;
        let (y_t, ssm_next) = mamba1_selective_scan_step_reference(u_t, z_t, ssm_state, params);
        ssm_state = ssm_next;
        let projected_out = y_t
            .reshape([batch, d_inner])
            .matmul(params.out_proj.val())
            .reshape([batch, 1, 1, config.d_model]);
        outputs.push(projected_out);
    }

    let context = Tensor::cat(outputs, 2);
    let final_state = MambaReferenceState {
        conv: conv_state,
        ssm: ssm_state,
        angle: None,
        k: None,
        v: None,
    };
    (context, final_state)
}
/// Step-by-step Mamba-2 reference forward pass.
///
/// `hidden_states` is `[batch, 1, time, d_model]`. The input projection is split into the
/// gate `z`, the convolved `xBC` block and the per-head step sizes `dt`. A previous
/// `state` is reused only when its `conv` / `ssm` tensors have exactly the expected
/// shapes; any other state is replaced by zeros. Returns the output
/// `[batch, 1, time, d_model]` and the state after the final step.
///
/// # Panics
///
/// Panics when more than one view is supplied or the hidden width differs from
/// `params.d_model`.
fn mamba2_reference<B: Backend>(
    hidden_states: Tensor<B, 4>,
    params: &Mamba2SequenceParameters<B>,
    state: Option<MambaReferenceState<B>>,
) -> (Tensor<B, 4>, MambaReferenceState<B>) {
    let [batch, views, time, dim] = hidden_states.shape().dims::<4>();
    assert_eq!(
        views, 1,
        "Mamba-2 reference path currently expects a single dense stream view"
    );
    assert_eq!(
        dim, params.d_model,
        "hidden dim {} must match mamba2 d_model {}",
        dim, params.d_model
    );
    let config = params.config();
    let device = hidden_states.device();
    let d_inner = config.d_inner;
    let conv_dim = config.mamba2_conv_dim();
    let in_proj_dim = config.mamba2_in_proj_dim();
    // Width of the B block (and of the C block) inside the convolved channels.
    let bc_width = config.ngroups * config.d_state;
    let conv_shape = [batch, 1, conv_dim, config.d_conv];
    let ssm_shape = [batch, config.nheads, config.headdim, config.d_state];

    // Carry over each piece of state only if its shape still fits; otherwise start from zero.
    let mut conv_state = state
        .as_ref()
        .filter(|previous| previous.conv.shape().dims::<4>() == conv_shape)
        .map(|previous| previous.conv.clone())
        .unwrap_or_else(|| Tensor::<B, 4>::zeros(conv_shape, &device));
    let mut ssm_state = state
        .as_ref()
        .filter(|previous| previous.ssm.shape().dims::<4>() == ssm_shape)
        .map(|previous| previous.ssm.clone())
        .unwrap_or_else(|| Tensor::<B, 4>::zeros(ssm_shape, &device));

    // Column layout of the projection: [ z | xBC | dt ].
    let zxbcdt = hidden_states
        .reshape([batch * time, config.d_model])
        .matmul(params.in_proj.val())
        .reshape([batch, time, in_proj_dim]);
    let z = zxbcdt
        .clone()
        .slice_dim(2, 0..d_inner)
        .reshape([batch, time, d_inner]);
    let xbc = zxbcdt
        .clone()
        .slice_dim(2, d_inner..(d_inner + conv_dim))
        .reshape([batch, time, conv_dim]);
    let dt = zxbcdt
        .slice_dim(2, (d_inner + conv_dim)..in_proj_dim)
        .reshape([batch, time, config.nheads]);

    let mut outputs = Vec::with_capacity(time);
    for step in 0..time {
        let xbc_t = xbc
            .clone()
            .slice_dim(1, step..step + 1)
            .reshape([batch, 1, conv_dim]);
        let z_t = z
            .clone()
            .slice_dim(1, step..step + 1)
            .reshape([batch, 1, d_inner]);
        let dt_t = dt
            .clone()
            .slice_dim(1, step..step + 1)
            .reshape([batch, config.nheads]);
        let (convolved, conv_next) = mamba_depthwise_conv_step_reference(
            xbc_t,
            conv_state,
            params.conv_weight.val(),
            params.conv_bias.as_ref().map(|bias| bias.val()),
        );
        conv_state = conv_next;

        // Column layout of the convolved block: [ x | B | C ].
        let x_t = convolved
            .clone()
            .slice_dim(2, 0..d_inner)
            .reshape([batch, config.nheads, config.headdim]);
        let b_t = convolved
            .clone()
            .slice_dim(2, d_inner..(d_inner + bc_width))
            .reshape([batch, config.ngroups, config.d_state]);
        let c_t = convolved
            .slice_dim(2, (d_inner + bc_width)..conv_dim)
            .reshape([batch, config.ngroups, config.d_state]);

        let (y_t, ssm_next) =
            mamba2_state_space_duality_step_reference(x_t, dt_t, b_t, c_t, ssm_state, params);
        ssm_state = ssm_next;

        let y_flat = y_t.reshape([batch, 1, d_inner]);
        let gated = mamba2_rmsnorm_gated_reference(
            y_flat,
            z_t,
            params.norm_weight.val(),
            params.norm_eps,
        );
        let projected_out = gated
            .reshape([batch, d_inner])
            .matmul(params.out_proj.val())
            .reshape([batch, 1, 1, config.d_model]);
        outputs.push(projected_out);
    }

    let context = Tensor::cat(outputs, 2);
    let final_state = MambaReferenceState {
        conv: conv_state,
        ssm: ssm_state,
        angle: None,
        k: None,
        v: None,
    };
    (context, final_state)
}
/// Mamba-3 reference forward pass, delegated to `tensorized_mamba3_forward`.
///
/// The returned state carries `ssm`, `angle`, `k` and `v` from the kernel plus an empty
/// `[batch, 1, 0, 0]` placeholder for `conv`.
///
/// # Panics
///
/// Panics when a previous `state` is supplied without its `angle`, `k` or `v` tensor.
fn mamba3_reference<B: Backend>(
    hidden_states: Tensor<B, 4>,
    params: &Mamba3SequenceParameters<B>,
    state: Option<MambaReferenceState<B>>,
) -> (Tensor<B, 4>, MambaReferenceState<B>) {
    let config = params.config();
    let tensorized = tensorized_mamba3_forward(
        hidden_states,
        config.d_inner,
        config.d_state,
        config.headdim,
        config.ngroups,
        config.num_rope_angles,
        config.norm_eps,
        config.a_floor,
        config.chunk_size,
        params.in_proj_tensor(),
        params.dt_bias_tensor(),
        params.b_bias_tensor(),
        params.c_bias_tensor(),
        params.b_norm_weight_tensor(),
        params.c_norm_weight_tensor(),
        params.d_skip_tensor(),
        params.out_proj_tensor(),
        state.map(|previous| Mamba3TensorizedState {
            ssm: previous.ssm,
            angle: previous.angle.expect("mamba3 reference requires angle state"),
            k: previous.k.expect("mamba3 reference requires k state"),
            v: previous.v.expect("mamba3 reference requires v state"),
        }),
    );
    let next = tensorized.state;
    // No convolution window is returned on this path; emit an empty tensor with the batch size.
    let batch = next.ssm.shape().dims::<4>()[0];
    let empty_conv = Tensor::<B, 4>::zeros([batch, 1, 0, 0], &next.ssm.device());
    let final_state = MambaReferenceState {
        conv: empty_conv,
        ssm: next.ssm,
        angle: Some(next.angle),
        k: Some(next.k),
        v: Some(next.v),
    };
    (tensorized.context, final_state)
}
/// Runs the reference forward pass for whichever Mamba variant `params` holds.
///
/// Returns the output tensor together with the recurrent state to feed into the next call.
///
/// # Panics
///
/// Panics unless exactly one of the Mamba-1 / Mamba-2 / Mamba-3 slots is populated.
pub fn mamba_reference<B: Backend>(
    hidden_states: Tensor<B, 4>,
    params: &MambaSequenceParameters<B>,
    state: Option<MambaReferenceState<B>>,
) -> (Tensor<B, 4>, MambaReferenceState<B>) {
    let populated = (params.mamba1(), params.mamba2(), params.mamba3());
    match populated {
        (Some(variant), None, None) => mamba1_reference(hidden_states, variant, state),
        (None, Some(variant), None) => mamba2_reference(hidden_states, variant, state),
        (None, None, Some(variant)) => mamba3_reference(hidden_states, variant, state),
        _ => panic!("invalid mamba parameter bundle"),
    }
}
#[cfg(test)]
mod tests {
use super::*;
use burn::tensor::TensorData;
use burn::tensor::backend::Backend as BackendTrait;
use burn_ndarray::NdArray;
use serde::Deserialize;
type Backend = NdArray<f32>;
/// Reference data deserialized from the pinned Mamba-1 JSON fixture.
///
/// The fixture test loads the weight fields into `Mamba1SequenceParameters` and compares
/// the reference forward pass against the `expected_*` fields.
#[derive(Debug, Deserialize)]
struct MambaFixture {
// Provenance; checked against MAMBA1_UPSTREAM_REPO / MAMBA1_UPSTREAM_COMMIT.
upstream_repo: String,
upstream_commit: String,
// Configuration values fed into `MambaSequenceConfig::resolve`.
d_model: usize,
d_state: usize,
d_conv: usize,
expand: usize,
dt_rank: usize,
// Input dimensions; `hidden` is loaded as `[batch, 1, time, d_model]`.
batch: usize,
time: usize,
hidden: Vec<Vec<Vec<f32>>>,
// Weights, row-major; each is flattened and loaded with the shape the test passes.
in_proj: Vec<Vec<f32>>,
conv_weight: Vec<Vec<f32>>,
conv_bias: Vec<f32>,
x_proj: Vec<Vec<f32>>,
dt_proj_weight: Vec<Vec<f32>>,
dt_proj_bias: Vec<f32>,
a_log: Vec<Vec<f32>>,
d_skip: Vec<f32>,
out_proj: Vec<Vec<f32>>,
// Expected results the reference implementation must reproduce.
expected_output: Vec<Vec<f32>>,
expected_final_conv_state: Vec<Vec<f32>>,
expected_final_ssm_state: Vec<Vec<f32>>,
}
/// Parses the pinned Mamba-1 fixture, which is embedded into the test binary at compile time.
fn fixture() -> MambaFixture {
    let raw = include_str!("../../../tests/data/mamba_c5afbdf_fixture.json");
    serde_json::from_str(raw).expect("parse mamba fixture")
}
/// Concatenates the rows of a nested vector into one flat, row-major vector.
fn flatten2(values: &[Vec<f32>]) -> Vec<f32> {
    values.concat()
}
/// Flattens a three-level nested vector into one vector, outermost index varying slowest.
fn flatten3(values: &[Vec<Vec<f32>>]) -> Vec<f32> {
    values.iter().flatten().flatten().copied().collect()
}
/// Asserts that feeding a sequence one time step at a time, threading the returned state
/// through, reproduces the single full-sequence call.
///
/// Output, final `conv` state and final `ssm` state must each agree within `1.0e-6`.
/// Parameters are built after seeding the backend with `7`; the input is a fixed
/// repeating ramp of shape `[batch, 1, time, d_model]`.
fn assert_step_mode_matches_full_sequence(
    config: MambaSequenceConfig,
    memory_system: SequenceMemorySystem,
    d_model: usize,
    batch: usize,
    time: usize,
) {
    let device = <Backend as BackendTrait>::Device::default();
    <Backend as BackendTrait>::seed(&device, 7);
    let resolved = config.resolve(d_model, memory_system);
    let params = MambaSequenceParameters::<Backend>::new(resolved, memory_system, &device);
    let hidden = Tensor::<Backend, 4>::from_data(
        TensorData::new(
            (0..(batch * time * d_model))
                .map(|idx| ((idx % 17) as f32) / 17.0 - 0.25)
                .collect::<Vec<_>>(),
            [batch, 1, time, d_model],
        ),
        &device,
    );
    // Reference run: the whole sequence in one call, starting from an empty state.
    let (full_out, full_state) = mamba_reference(hidden.clone(), &params, None);
    // Incremental run: one step per call, carrying the state forward.
    let mut outputs = Vec::with_capacity(time);
    let mut state = None;
    for step in 0..time {
        let step_hidden = hidden.clone().slice_dim(2, step..step + 1);
        let (step_out, next_state) = mamba_reference(step_hidden, &params, state);
        outputs.push(step_out);
        state = Some(next_state);
    }
    let step_out = Tensor::cat(outputs, 2);
    let step_state = state.expect("step state");
    let out_diff = step_out.sub(full_out).abs().max().into_scalar();
    let conv_diff = step_state
        .conv
        .sub(full_state.conv)
        .abs()
        .max()
        .into_scalar();
    let ssm_diff = step_state
        .ssm
        .sub(full_state.ssm)
        .abs()
        .max()
        .into_scalar();
    assert!(out_diff <= 1.0e-6, "output diff {out_diff}");
    assert!(conv_diff <= 1.0e-6, "conv diff {conv_diff}");
    assert!(ssm_diff <= 1.0e-6, "ssm diff {ssm_diff}");
}
/// Asserts that splitting a sequence at `prefix` and resuming from the prefix state
/// reproduces the single full-sequence call.
///
/// Output, final `conv` state and final `ssm` state must each agree within `1.0e-6`.
/// Parameters are built after seeding the backend with `11`; the input is a fixed
/// repeating ramp of shape `[batch, 1, time, d_model]`.
fn assert_chunked_state_matches_full_sequence(
    config: MambaSequenceConfig,
    memory_system: SequenceMemorySystem,
    d_model: usize,
    batch: usize,
    time: usize,
    prefix: usize,
) {
    let device = <Backend as BackendTrait>::Device::default();
    <Backend as BackendTrait>::seed(&device, 11);
    let resolved = config.resolve(d_model, memory_system);
    let params = MambaSequenceParameters::<Backend>::new(resolved, memory_system, &device);
    let hidden = Tensor::<Backend, 4>::from_data(
        TensorData::new(
            (0..(batch * time * d_model))
                .map(|idx| ((idx % 23) as f32) / 23.0 - 0.35)
                .collect::<Vec<_>>(),
            [batch, 1, time, d_model],
        ),
        &device,
    );
    // Reference run: the whole sequence in one call, starting from an empty state.
    let (full_out, full_state) = mamba_reference(hidden.clone(), &params, None);
    // Chunked run: first `prefix` steps, then the remainder resumed from the prefix state.
    let (prefix_out, prefix_state) =
        mamba_reference(hidden.clone().slice_dim(2, 0..prefix), &params, None);
    let (suffix_out, suffix_state) = mamba_reference(
        hidden.slice_dim(2, prefix..time),
        &params,
        Some(prefix_state),
    );
    let chunked_out = Tensor::cat(vec![prefix_out, suffix_out], 2);
    let out_diff = chunked_out.sub(full_out).abs().max().into_scalar();
    let conv_diff = suffix_state
        .conv
        .sub(full_state.conv)
        .abs()
        .max()
        .into_scalar();
    let ssm_diff = suffix_state
        .ssm
        .sub(full_state.ssm)
        .abs()
        .max()
        .into_scalar();
    assert!(out_diff <= 1.0e-6, "output diff {out_diff}");
    assert!(conv_diff <= 1.0e-6, "conv diff {conv_diff}");
    assert!(ssm_diff <= 1.0e-6, "ssm diff {ssm_diff}");
}
/// Default configuration resolved for a 256-wide model yields the expected Mamba-1 sizes.
#[test]
fn mamba1_config_resolves_like_upstream_defaults() {
    let d_model = 256;
    let defaults = MambaSequenceConfig::default();
    let resolved = defaults.resolve(d_model, SequenceMemorySystem::Mamba1SelectiveScan);
    assert_eq!(resolved.d_inner, 512);
    assert_eq!(resolved.d_state, 16);
    assert_eq!(resolved.d_conv, 4);
    assert_eq!(resolved.dt_rank, 16);
}
/// With `headdim = 64` and a 256-wide model, Mamba-2 resolves to 8 heads in one group.
#[test]
fn mamba2_config_resolves_heads_from_inner_width() {
    let requested = MambaSequenceConfig {
        headdim: 64,
        ..Default::default()
    };
    let resolved = requested.resolve(256, SequenceMemorySystem::Mamba2StateSpaceDuality);
    assert_eq!(resolved.d_inner, 512);
    assert_eq!(resolved.nheads, 8);
    assert_eq!(resolved.ngroups, 1);
}
/// A head width of 96 is rejected for a 256-wide Mamba-2 model with a "divisible" error.
#[test]
fn mamba2_config_rejects_non_divisible_head_width() {
    let requested = MambaSequenceConfig {
        headdim: 96,
        ..Default::default()
    };
    let outcome = requested.validate(SequenceMemorySystem::Mamba2StateSpaceDuality, 256);
    let err = outcome.expect_err("expected invalid mamba2 config");
    assert!(err.contains("divisible"));
}
/// The Mamba-1 reference keeps the input shape and returns correctly sized conv/ssm states.
#[test]
fn mamba1_reference_returns_expected_shapes() {
    let device = <Backend as BackendTrait>::Device::default();
    let config =
        MambaSequenceConfig::default().resolve(8, SequenceMemorySystem::Mamba1SelectiveScan);
    let params = MambaSequenceParameters::<Backend>::new(
        config,
        SequenceMemorySystem::Mamba1SelectiveScan,
        &device,
    );
    let hidden = Tensor::<Backend, 4>::zeros([2, 1, 5, 8], &device);
    let (output, state) = mamba_reference(hidden, &params, None);
    assert_eq!(output.shape().dims::<4>(), [2, 1, 5, 8]);
    assert_eq!(
        state.conv.shape().dims::<4>(),
        [2, 1, config.d_inner, config.d_conv]
    );
    assert_eq!(
        state.ssm.shape().dims::<4>(),
        [2, 1, config.d_inner, config.d_state]
    );
}
/// The Mamba-2 reference keeps the input shape and returns correctly sized conv/ssm states.
#[test]
fn mamba2_reference_returns_expected_shapes() {
    let device = <Backend as BackendTrait>::Device::default();
    let config = MambaSequenceConfig {
        headdim: 8,
        ..Default::default()
    }
    .resolve(8, SequenceMemorySystem::Mamba2StateSpaceDuality);
    let params = MambaSequenceParameters::<Backend>::new(
        config,
        SequenceMemorySystem::Mamba2StateSpaceDuality,
        &device,
    );
    let hidden = Tensor::<Backend, 4>::zeros([2, 1, 5, 8], &device);
    let (output, state) = mamba_reference(hidden, &params, None);
    assert_eq!(output.shape().dims::<4>(), [2, 1, 5, 8]);
    assert_eq!(
        state.conv.shape().dims::<4>(),
        [2, 1, config.mamba2_conv_dim(), config.d_conv]
    );
    assert_eq!(
        state.ssm.shape().dims::<4>(),
        [2, config.nheads, config.headdim, config.d_state]
    );
}
/// The Mamba-1 reference reproduces the outputs and final states stored in the pinned fixture.
#[test]
fn mamba1_reference_matches_pinned_c5afbdf_fixture() {
    let fixture = fixture();
    assert_eq!(fixture.upstream_repo, MAMBA1_UPSTREAM_REPO);
    assert_eq!(fixture.upstream_commit, MAMBA1_UPSTREAM_COMMIT);
    let device = <Backend as BackendTrait>::Device::default();
    let resolved = MambaSequenceConfig {
        d_state: fixture.d_state,
        d_conv: fixture.d_conv,
        expand: fixture.expand,
        dt_rank: Some(fixture.dt_rank),
        ..Default::default()
    }
    .resolve(fixture.d_model, SequenceMemorySystem::Mamba1SelectiveScan);
    let d_model = fixture.d_model;
    let d_inner = resolved.d_inner;
    let d_state = resolved.d_state;
    let d_conv = resolved.d_conv;
    let dt_rank = resolved.dt_rank;

    // Small builders so each weight assignment below is a single line.
    let matrix = |rows: &[Vec<f32>], shape: [usize; 2]| {
        Param::from_tensor(Tensor::<Backend, 2>::from_data(
            TensorData::new(flatten2(rows), shape),
            &device,
        ))
    };
    let vector = |values: &[f32]| {
        Param::from_tensor(Tensor::<Backend, 1>::from_data(
            TensorData::new(values.to_vec(), [d_inner]),
            &device,
        ))
    };
    let rank4 = |values: Vec<f32>, shape: [usize; 4]| {
        Tensor::<Backend, 4>::from_data(TensorData::new(values, shape), &device)
    };
    let max_abs_diff = |actual: Tensor<Backend, 4>, expected: Tensor<Backend, 4>| {
        actual.sub(expected).abs().max().into_scalar()
    };

    // Overwrite every randomly initialized weight with the fixture's values.
    let mut params = Mamba1SequenceParameters::<Backend>::new(resolved, &device);
    params.in_proj = matrix(&fixture.in_proj, [d_model, d_inner * 2]);
    params.conv_weight = matrix(&fixture.conv_weight, [d_inner, d_conv]);
    params.conv_bias = Some(vector(&fixture.conv_bias));
    params.x_proj = matrix(&fixture.x_proj, [d_inner, dt_rank + d_state * 2]);
    params.dt_proj_weight = matrix(&fixture.dt_proj_weight, [dt_rank, d_inner]);
    params.dt_proj_bias = vector(&fixture.dt_proj_bias);
    params.a_log = matrix(&fixture.a_log, [d_inner, d_state]);
    params.d_skip = vector(&fixture.d_skip);
    params.out_proj = matrix(&fixture.out_proj, [d_inner, d_model]);

    let hidden = rank4(
        flatten3(&fixture.hidden),
        [fixture.batch, 1, fixture.time, d_model],
    );
    let wrapped = MambaSequenceParameters {
        mamba1: Some(params),
        mamba2: None,
        mamba3: None,
    };
    let (output, state) = mamba_reference(hidden, &wrapped, None);

    let expected_output = rank4(
        flatten2(&fixture.expected_output),
        [fixture.batch, 1, fixture.time, d_model],
    );
    let expected_conv = rank4(
        flatten2(&fixture.expected_final_conv_state),
        [fixture.batch, 1, d_inner, d_conv],
    );
    let expected_ssm = rank4(
        flatten2(&fixture.expected_final_ssm_state),
        [fixture.batch, 1, d_inner, d_state],
    );

    let max_output_diff = max_abs_diff(output, expected_output);
    let max_conv_diff = max_abs_diff(state.conv, expected_conv);
    let max_ssm_diff = max_abs_diff(state.ssm, expected_ssm);
    assert!(max_output_diff <= 1.0e-6, "output diff {max_output_diff}");
    assert!(max_conv_diff <= 1.0e-7, "conv diff {max_conv_diff}");
    assert!(max_ssm_diff <= 1.0e-7, "ssm diff {max_ssm_diff}");
}
/// Mamba-1: stepping one token at a time matches the full-sequence reference.
#[test]
fn mamba1_reference_step_mode_matches_full_sequence() {
    let (d_model, batch, time) = (8, 2, 5);
    let config = MambaSequenceConfig::default();
    assert_step_mode_matches_full_sequence(
        config,
        SequenceMemorySystem::Mamba1SelectiveScan,
        d_model,
        batch,
        time,
    );
}
/// Mamba-2: stepping one token at a time matches the full-sequence reference.
#[test]
fn mamba2_reference_step_mode_matches_full_sequence() {
    let (d_model, batch, time) = (8, 2, 5);
    let config = MambaSequenceConfig {
        headdim: 8,
        ..Default::default()
    };
    assert_step_mode_matches_full_sequence(
        config,
        SequenceMemorySystem::Mamba2StateSpaceDuality,
        d_model,
        batch,
        time,
    );
}
/// Mamba-1: resuming from a two-step prefix state matches the full-sequence reference.
#[test]
fn mamba1_reference_chunked_state_matches_full_sequence() {
    let (d_model, batch, time, prefix) = (8, 2, 6, 2);
    let config = MambaSequenceConfig::default();
    assert_chunked_state_matches_full_sequence(
        config,
        SequenceMemorySystem::Mamba1SelectiveScan,
        d_model,
        batch,
        time,
        prefix,
    );
}
/// Mamba-2: resuming from a two-step prefix state matches the full-sequence reference.
#[test]
fn mamba2_reference_chunked_state_matches_full_sequence() {
    let (d_model, batch, time, prefix) = (8, 2, 6, 2);
    let config = MambaSequenceConfig {
        headdim: 8,
        ..Default::default()
    };
    assert_chunked_state_matches_full_sequence(
        config,
        SequenceMemorySystem::Mamba2StateSpaceDuality,
        d_model,
        batch,
        time,
        prefix,
    );
}
}