use super::types::{SsmInput, var_contiguous};
use crate::error::{Error, Result};
use numr::autograd::{
Var, var_add, var_cat, var_exp, var_matmul, var_mul, var_narrow, var_reshape,
};
use numr::dtype::DType;
use numr::ops::{ActivationOps, ReduceOps, ScalarOps, TensorOps, UnaryOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
pub fn ssm_forward_sequential<R, C>(client: &C, input: &SsmInput<'_, R>) -> Result<Var<R>>
where
R: Runtime<DType = DType>,
C: RuntimeClient<R> + TensorOps<R> + ScalarOps<R> + UnaryOps<R> + ActivationOps<R>,
R::Client: TensorOps<R> + ScalarOps<R>,
{
let shape = input.x.shape();
let batch = shape[0];
let seq_len = shape[1];
let nheads = input.config.nheads;
let headdim = input.config.headdim;
let d_state = input.config.d_state;
let ngroups = input.config.ngroups;
let device = input.x.tensor().device();
let h_tensor = Tensor::<R>::zeros(
&[batch, nheads, headdim, d_state],
input.x.tensor().dtype(),
device,
);
let mut h = Var::new(h_tensor, false);
let mut outputs: Vec<Var<R>> = Vec::with_capacity(seq_len);
for t in 0..seq_len {
let (y_t, h_new) = ssm_step(
client,
input.x,
input.a,
input.b,
input.c,
input.d_param,
input.dt,
&h,
t,
batch,
nheads,
headdim,
d_state,
ngroups,
)?;
h = h_new;
outputs.push(y_t);
}
let output_refs: Vec<&Var<R>> = outputs.iter().collect();
var_cat(&output_refs, 1, client).map_err(Error::Numr)
}
pub fn ssm_forward_chunked<R, C>(client: &C, input: &SsmInput<'_, R>) -> Result<Var<R>>
where
R: Runtime<DType = DType>,
C: RuntimeClient<R>
+ TensorOps<R>
+ ScalarOps<R>
+ UnaryOps<R>
+ ActivationOps<R>
+ ReduceOps<R>,
R::Client: TensorOps<R> + ScalarOps<R>,
{
let shape = input.x.shape();
let batch = shape[0];
let seq_len = shape[1];
let chunk_size = input.config.chunk_size;
if seq_len <= chunk_size {
return ssm_forward_sequential(client, input);
}
let nchunks = seq_len.div_ceil(chunk_size);
let nheads = input.config.nheads;
let headdim = input.config.headdim;
let d_state = input.config.d_state;
let device = input.x.tensor().device();
let dtype = input.x.tensor().dtype();
let h_tensor = Tensor::<R>::zeros(&[batch, nheads, headdim, d_state], dtype, device);
let mut h = Var::new(h_tensor, false);
let mut chunk_outputs: Vec<Var<R>> = Vec::with_capacity(nchunks);
for chunk_idx in 0..nchunks {
let start = chunk_idx * chunk_size;
let len = (seq_len - start).min(chunk_size);
let x_chunk = var_narrow(input.x, 1, start, len).map_err(Error::Numr)?;
let b_chunk = var_narrow(input.b, 1, start, len).map_err(Error::Numr)?;
let c_chunk = var_narrow(input.c, 1, start, len).map_err(Error::Numr)?;
let dt_chunk = var_narrow(input.dt, 1, start, len).map_err(Error::Numr)?;
let chunk_input = SsmInput {
x: &x_chunk,
a: input.a,
b: &b_chunk,
c: &c_chunk,
d_param: input.d_param,
dt: &dt_chunk,
config: input.config,
};
let (chunk_out, h_new) = ssm_chunk_with_state(client, &chunk_input, &h)?;
chunk_outputs.push(chunk_out);
h = h_new;
}
let chunk_refs: Vec<&Var<R>> = chunk_outputs.iter().collect();
var_cat(&chunk_refs, 1, client).map_err(Error::Numr)
}
fn ssm_chunk_with_state<R, C>(
client: &C,
input: &SsmInput<'_, R>,
h_init: &Var<R>,
) -> Result<(Var<R>, Var<R>)>
where
R: Runtime<DType = DType>,
C: RuntimeClient<R> + TensorOps<R> + ScalarOps<R> + UnaryOps<R> + ActivationOps<R>,
R::Client: TensorOps<R> + ScalarOps<R>,
{
let shape = input.x.shape();
let batch = shape[0];
let seq_len = shape[1];
let nheads = input.config.nheads;
let headdim = input.config.headdim;
let d_state = input.config.d_state;
let ngroups = input.config.ngroups;
let mut h = h_init.clone();
let mut outputs: Vec<Var<R>> = Vec::with_capacity(seq_len);
for t in 0..seq_len {
let (y_t, h_new) = ssm_step(
client,
input.x,
input.a,
input.b,
input.c,
input.d_param,
input.dt,
&h,
t,
batch,
nheads,
headdim,
d_state,
ngroups,
)?;
h = h_new;
outputs.push(y_t);
}
let output_refs: Vec<&Var<R>> = outputs.iter().collect();
let out = var_cat(&output_refs, 1, client).map_err(Error::Numr)?;
Ok((out, h))
}
/// Executes one timestep of the selective-state-space recurrence.
///
/// Discretized state update (per head):
///   `h_new = exp(dt * A) * h + (dt * x) * B`
/// Output:
///   `y_t = h_new @ C`, plus the skip term `D * x` when `d_param` is present.
///
/// Returns `(y_t, h_new)` where `y_t` has shape `[batch, 1, nheads, headdim]`
/// (sequence dim of length 1 so steps can be concatenated along dim 1) and
/// `h_new` has shape `[batch, nheads, headdim, d_state]`.
#[allow(clippy::too_many_arguments)]
fn ssm_step<R, C>(
    client: &C,
    x: &Var<R>,
    a: &Var<R>,
    b: &Var<R>,
    c: &Var<R>,
    d_param: Option<&Var<R>>,
    dt: &Var<R>,
    h: &Var<R>,
    t: usize,
    batch: usize,
    nheads: usize,
    headdim: usize,
    d_state: usize,
    ngroups: usize,
) -> Result<(Var<R>, Var<R>)>
where
    R: Runtime<DType = DType>,
    C: RuntimeClient<R> + TensorOps<R> + ScalarOps<R> + UnaryOps<R> + ActivationOps<R>,
    R::Client: TensorOps<R> + ScalarOps<R>,
{
    // Slice out timestep `t` (narrow keeps dim 1 at length 1), then reshape
    // to per-head layouts. `var_contiguous` forces a packed layout first so
    // the reshapes are valid on the narrowed view.
    let x_t = var_contiguous(&var_narrow(x, 1, t, 1).map_err(Error::Numr)?);
    let x_t = var_reshape(&x_t, &[batch, nheads, headdim]).map_err(Error::Numr)?;
    let dt_t = var_contiguous(&var_narrow(dt, 1, t, 1).map_err(Error::Numr)?);
    // Trailing singleton dims let dt broadcast over (headdim, d_state).
    let dt_t = var_reshape(&dt_t, &[batch, nheads, 1, 1]).map_err(Error::Numr)?;
    let b_t = var_contiguous(&var_narrow(b, 1, t, 1).map_err(Error::Numr)?);
    let b_t = var_reshape(&b_t, &[batch, ngroups, 1, d_state]).map_err(Error::Numr)?;
    let c_t = var_contiguous(&var_narrow(c, 1, t, 1).map_err(Error::Numr)?);
    let c_t = var_reshape(&c_t, &[batch, ngroups, d_state]).map_err(Error::Numr)?;
    // A is a per-head scalar; broadcast it against the state shape.
    let a_broad = var_reshape(a, &[1, nheads, 1, 1]).map_err(Error::Numr)?;
    // Decay factor exp(dt * A). Presumably A < 0 so this lies in (0, 1) —
    // TODO confirm that invariant is enforced at the call site.
    let dt_a = var_mul(&dt_t, &a_broad, client).map_err(Error::Numr)?;
    let decay = var_exp(&dt_a, client).map_err(Error::Numr)?;
    let mut h = var_mul(&decay, h, client).map_err(Error::Numr)?;
    let b_t_expanded = if ngroups == nheads {
        // Same shape as `b_t` when ngroups == nheads, so this reshape is a
        // no-op kept for explicitness. NOTE(review): in the else branch the
        // raw [batch, ngroups, 1, d_state] tensor must broadcast against
        // nheads — verify numr's broadcasting supports 1 < ngroups < nheads.
        var_reshape(&b_t, &[batch, nheads, 1, d_state]).map_err(Error::Numr)?
    } else {
        b_t
    };
    // Input injection: (dt * x) combined with B, broadcast to
    // [batch, nheads, headdim, d_state].
    let x_t_col = var_reshape(&x_t, &[batch, nheads, headdim, 1]).map_err(Error::Numr)?;
    let dt_x = var_mul(&dt_t, &x_t_col, client).map_err(Error::Numr)?;
    let input_term = var_mul(&dt_x, &b_t_expanded, client).map_err(Error::Numr)?;
    h = var_add(&h, &input_term, client).map_err(Error::Numr)?;
    // Readout: batched matmul contracts the trailing d_state dim of h with C.
    let c_t_col = var_reshape(&c_t, &[batch, ngroups, d_state, 1]).map_err(Error::Numr)?;
    let y_t = var_matmul(&h, &c_t_col, client).map_err(Error::Numr)?;
    let mut y_t = var_reshape(&y_t, &[batch, nheads, headdim]).map_err(Error::Numr)?;
    // Optional skip connection: y += D * x with a per-head scalar D.
    if let Some(d_var) = d_param {
        let d_broad = var_reshape(d_var, &[1, nheads, 1]).map_err(Error::Numr)?;
        let d_x = var_mul(&d_broad, &x_t, client).map_err(Error::Numr)?;
        y_t = var_add(&y_t, &d_x, client).map_err(Error::Numr)?;
    }
    // Reinstate a length-1 sequence dim so per-step outputs concatenate
    // along dim 1.
    let y_t = var_reshape(&y_t, &[batch, 1, nheads, headdim]).map_err(Error::Numr)?;
    Ok((y_t, h))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::mamba::mamba2::Mamba2Config;
    use crate::test_utils::cpu_setup;
    use numr::autograd::Var;
    use numr::runtime::cpu::CpuRuntime;

    // Smoke test: a tiny 2-step sequence (1 head, headdim 4, d_state 2) with
    // one-hot inputs and B should run the sequential scan and produce finite
    // outputs of the expected [batch, seq, nheads, headdim] shape.
    #[test]
    fn test_sequential_ssm_known_values() {
        let (client, device) = cpu_setup();
        // Minimal config: no D skip, raw dt (softplus and dt-bias disabled)
        // so the scan uses the provided dt values directly.
        let config = Mamba2Config::new(4)
            .with_nheads(1)
            .with_d_state(2)
            .with_expand(1)
            .with_use_d(false)
            .with_dt_softplus(false)
            .with_use_dt_bias(false);
        // x: one-hot per timestep — t=0 excites channel 0, t=1 excites channel 1.
        let x = Var::new(
            Tensor::<CpuRuntime>::from_slice(
                &[1.0f32, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
                &[1, 2, 1, 4],
                &device,
            ),
            false,
        );
        // A = -1 gives a decay factor exp(-dt) in (0, 1).
        let a = Var::new(
            Tensor::<CpuRuntime>::from_slice(&[-1.0f32], &[1], &device),
            false,
        );
        // B: identity-like per-step state injection.
        let b = Var::new(
            Tensor::<CpuRuntime>::from_slice(&[1.0f32, 0.0, 0.0, 1.0], &[1, 2, 1, 2], &device),
            false,
        );
        // C: all-ones readout — output is the sum over the state dim.
        let c = Var::new(
            Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[1, 2, 1, 2], &device),
            false,
        );
        let dt = Var::new(
            Tensor::<CpuRuntime>::from_slice(&[0.5f32, 0.5], &[1, 2, 1], &device),
            false,
        );
        let input = SsmInput {
            x: &x,
            a: &a,
            b: &b,
            c: &c,
            d_param: None,
            dt: &dt,
            config: &config,
        };
        let out = ssm_forward_sequential(&client, &input).unwrap();
        assert_eq!(out.shape(), &[1, 2, 1, 4]);
        let data: Vec<f32> = out.tensor().to_vec();
        // Only checks finiteness; exact values are not pinned here.
        assert!(data.iter().all(|v| v.is_finite()));
    }

    // Equivalence test: with seq_len = 6 and chunk_size = 3 (two full chunks),
    // the chunked path must reproduce the sequential path element-wise, which
    // exercises the hidden-state carry across the chunk boundary.
    #[test]
    fn test_chunked_matches_sequential() {
        let (client, device) = cpu_setup();
        // D skip enabled here so the chunked path is tested with the skip term.
        let config = Mamba2Config::new(4)
            .with_nheads(1)
            .with_d_state(2)
            .with_expand(1)
            .with_use_d(true)
            .with_dt_softplus(false)
            .with_use_dt_bias(false)
            .with_chunk_size(3);
        let seq_len = 6;
        // Deterministic ramps for x/B/C so every timestep is distinct.
        let x_data: Vec<f32> = (0..24).map(|i| (i as f32) * 0.1).collect();
        let x = Var::new(
            Tensor::<CpuRuntime>::from_slice(&x_data, &[1, seq_len, 1, 4], &device),
            false,
        );
        let a = Var::new(
            Tensor::<CpuRuntime>::from_slice(&[-0.5f32], &[1], &device),
            false,
        );
        let b_data: Vec<f32> = (0..12).map(|i| (i as f32) * 0.05 + 0.1).collect();
        let b = Var::new(
            Tensor::<CpuRuntime>::from_slice(&b_data, &[1, seq_len, 1, 2], &device),
            false,
        );
        let c_data: Vec<f32> = (0..12).map(|i| (i as f32) * 0.03 + 0.2).collect();
        let c = Var::new(
            Tensor::<CpuRuntime>::from_slice(&c_data, &[1, seq_len, 1, 2], &device),
            false,
        );
        let d_param = Var::new(
            Tensor::<CpuRuntime>::from_slice(&[0.5f32], &[1], &device),
            false,
        );
        // Varying dt per timestep (repeats across the two chunks).
        let dt_data: Vec<f32> = vec![0.1, 0.2, 0.3, 0.1, 0.2, 0.3];
        let dt = Var::new(
            Tensor::<CpuRuntime>::from_slice(&dt_data, &[1, seq_len, 1], &device),
            false,
        );
        let input = SsmInput {
            x: &x,
            a: &a,
            b: &b,
            c: &c,
            d_param: Some(&d_param),
            dt: &dt,
            config: &config,
        };
        let out_seq = ssm_forward_sequential(&client, &input).unwrap();
        let out_chunk = ssm_forward_chunked(&client, &input).unwrap();
        let seq_data: Vec<f32> = out_seq.tensor().to_vec();
        let chunk_data: Vec<f32> = out_chunk.tensor().to_vec();
        assert_eq!(seq_data.len(), chunk_data.len());
        // Loose f32 tolerance: both paths accumulate in a different op order.
        for (i, (s, c)) in seq_data.iter().zip(chunk_data.iter()).enumerate() {
            assert!((s - c).abs() < 1e-4, "mismatch at {i}: seq={s}, chunk={c}");
        }
    }
}