use super::types::SsmInferenceInput;
use crate::error::{Error, Result};
use numr::dtype::DType;
use numr::ops::{ActivationOps, ScalarOps, TensorOps, UnaryOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
pub fn ssm_step_inference<R, C>(
_client: &C,
input: &SsmInferenceInput<'_, R>,
h: &Tensor<R>,
) -> Result<(Tensor<R>, Tensor<R>)>
where
R: Runtime<DType = DType>,
C: RuntimeClient<R> + TensorOps<R> + ScalarOps<R> + UnaryOps<R> + ActivationOps<R>,
R::Client: TensorOps<R> + ScalarOps<R> + numr::ops::BinaryOps<R>,
{
let shape = input.x.shape();
let batch = shape[0];
let nheads = input.config.nheads;
let headdim = input.config.headdim;
let d_state = input.config.d_state;
let ngroups = input.config.ngroups;
let x_t = input
.x
.reshape(&[batch, nheads, headdim])
.map_err(Error::Numr)?;
let x_t = x_t.contiguous();
let dt_t = input
.dt
.reshape(&[batch, nheads, 1, 1])
.map_err(Error::Numr)?;
let dt_t = dt_t.contiguous();
let b_t = input
.b
.reshape(&[batch, ngroups, 1, d_state])
.map_err(Error::Numr)?;
let b_t = b_t.contiguous();
let c_t = input
.c
.reshape(&[batch, ngroups, d_state])
.map_err(Error::Numr)?;
let c_t = c_t.contiguous();
let a_broad = input.a.reshape(&[1, nheads, 1, 1]).map_err(Error::Numr)?;
let dt_a = dt_t.mul(&a_broad).map_err(Error::Numr)?;
let decay = dt_a.exp().map_err(Error::Numr)?;
let mut h_new = decay.mul(h).map_err(Error::Numr)?;
let b_t_expanded = if ngroups == nheads {
b_t.reshape(&[batch, nheads, 1, d_state])
.map_err(Error::Numr)?
} else {
b_t };
let x_t_col = x_t
.reshape(&[batch, nheads, headdim, 1])
.map_err(Error::Numr)?;
let dt_x = dt_t.mul(&x_t_col).map_err(Error::Numr)?;
let input_term = dt_x.mul(&b_t_expanded).map_err(Error::Numr)?;
h_new = h_new.add(&input_term).map_err(Error::Numr)?;
let c_t_col = c_t
.reshape(&[batch, ngroups, d_state, 1])
.map_err(Error::Numr)?;
let y_t = h_new.matmul(&c_t_col).map_err(Error::Numr)?;
let mut y_t = y_t
.reshape(&[batch, nheads, headdim])
.map_err(Error::Numr)?;
if let Some(d_param) = input.d_param {
let d_broad = d_param.reshape(&[1, nheads, 1]).map_err(Error::Numr)?;
let d_x = d_broad.mul(&x_t).map_err(Error::Numr)?;
y_t = y_t.add(&d_x).map_err(Error::Numr)?;
}
let y_t = y_t
.reshape(&[batch, 1, nheads, headdim])
.map_err(Error::Numr)?;
Ok((y_t, h_new))
}
pub fn ssm_forward_sequential_inference<R, C>(
client: &C,
input: &SsmInferenceInput<'_, R>,
h_init: &Tensor<R>,
) -> Result<(Tensor<R>, Tensor<R>)>
where
R: Runtime<DType = DType>,
C: RuntimeClient<R> + TensorOps<R> + ScalarOps<R> + UnaryOps<R> + ActivationOps<R>,
R::Client: TensorOps<R> + ScalarOps<R> + numr::ops::BinaryOps<R> + numr::ops::ShapeOps<R>,
{
let shape = input.x.shape();
let seq_len = shape[1];
let mut h = h_init.clone();
let mut outputs: Vec<Tensor<R>> = Vec::with_capacity(seq_len);
for t in 0..seq_len {
let x_t = input.x.narrow(1, t, 1).map_err(Error::Numr)?.contiguous();
let b_t = input.b.narrow(1, t, 1).map_err(Error::Numr)?.contiguous();
let c_t = input.c.narrow(1, t, 1).map_err(Error::Numr)?.contiguous();
let dt_t = input.dt.narrow(1, t, 1).map_err(Error::Numr)?.contiguous();
let step_input = SsmInferenceInput {
x: &x_t,
a: input.a,
b: &b_t,
c: &c_t,
d_param: input.d_param,
dt: &dt_t,
config: input.config,
};
let (y_t, h_new) = ssm_step_inference(client, &step_input, &h)?;
h = h_new;
outputs.push(y_t);
}
let output_refs: Vec<&Tensor<R>> = outputs.iter().collect();
let output = Tensor::cat(&output_refs, 1).map_err(Error::Numr)?;
Ok((output, h))
}
#[cfg(test)]
mod tests {
    use super::super::scan::ssm_forward_sequential;
    use super::super::types::SsmInput;
    use super::*;
    use crate::model::mamba::mamba2::Mamba2Config;
    use crate::test_utils::cpu_setup;
    use numr::autograd::Var;
    use numr::dtype::DType;
    use numr::runtime::cpu::CpuRuntime;

    /// Asserts that two `f32` slices have equal length and agree element-wise
    /// within `tol`.
    fn assert_all_close(actual: &[f32], expected: &[f32], tol: f32) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (a, e)) in actual.iter().zip(expected.iter()).enumerate() {
            assert!(
                (a - e).abs() < tol,
                "mismatch at {i}: actual={a}, expected={e}"
            );
        }
    }

    /// Single-step check against hand-computed values.
    ///
    /// With a zero initial state the decay term vanishes, so
    /// `h_new[i][j] = dt * x[i] * b[j]` and
    /// `y[i] = dt * x[i] * (b · c) + d * x[i]`.
    /// Here `dt = 0.5`, `b · c = 0.5 * 0.2 + 0.3 * 0.8 = 0.34` and `d = 0.1`,
    /// giving `y = 0.27 * x`.
    #[test]
    fn test_ssm_step_inference_matches_var() {
        let (client, device) = cpu_setup();
        let config = Mamba2Config::new(4)
            .with_nheads(1)
            .with_d_state(2)
            .with_expand(1)
            .with_use_d(true)
            .with_dt_softplus(false)
            .with_use_dt_bias(false);
        let x_data = [1.0f32, 0.5, 0.0, -0.5];
        let x_t = Tensor::<CpuRuntime>::from_slice(&x_data, &[1, 1, 1, 4], &device);
        let a_t = Tensor::<CpuRuntime>::from_slice(&[-1.0f32], &[1], &device);
        let b_t = Tensor::<CpuRuntime>::from_slice(&[0.5f32, 0.3], &[1, 1, 1, 2], &device);
        let c_t = Tensor::<CpuRuntime>::from_slice(&[0.2f32, 0.8], &[1, 1, 1, 2], &device);
        let d_p = Tensor::<CpuRuntime>::from_slice(&[0.1f32], &[1], &device);
        let dt_t = Tensor::<CpuRuntime>::from_slice(&[0.5f32], &[1, 1, 1], &device);
        let h = Tensor::<CpuRuntime>::zeros(&[1, 1, 4, 2], DType::F32, &device);
        let input = SsmInferenceInput {
            x: &x_t,
            a: &a_t,
            b: &b_t,
            c: &c_t,
            d_param: Some(&d_p),
            dt: &dt_t,
            config: &config,
        };
        let (y, h_new) = ssm_step_inference(&client, &input, &h).unwrap();
        assert_eq!(y.shape(), &[1, 1, 1, 4]);
        assert_eq!(h_new.shape(), &[1, 1, 4, 2]);

        let y_data: Vec<f32> = y.to_vec();
        assert!(y_data.iter().all(|v| v.is_finite()));
        // y = 0.27 * x
        assert_all_close(&y_data, &[0.27, 0.135, 0.0, -0.135], 1e-6);

        // h_new[i][j] = 0.5 * x[i] * b[j], laid out row-major as [headdim, d_state].
        let h_data: Vec<f32> = h_new.to_vec();
        assert_all_close(
            &h_data,
            &[0.25, 0.15, 0.125, 0.075, 0.0, 0.0, -0.125, -0.075],
            1e-6,
        );
    }

    /// The tensor-only sequential path must reproduce the autograd (`Var`)
    /// sequential scan on a two-step sequence.
    #[test]
    fn test_sequential_inference_matches_var() {
        let (client, device) = cpu_setup();
        let config = Mamba2Config::new(4)
            .with_nheads(1)
            .with_d_state(2)
            .with_expand(1)
            .with_use_d(false)
            .with_dt_softplus(false)
            .with_use_dt_bias(false);
        let x_data: Vec<f32> = (0..8).map(|i| i as f32 * 0.1).collect();
        let a_data = [-1.0f32];
        let b_data = [1.0f32, 0.0, 0.0, 1.0];
        let c_data = [1.0f32, 1.0, 1.0, 1.0];
        let dt_data = [0.5f32, 0.5];

        // Reference: autograd path.
        let x_var = Var::new(
            Tensor::<CpuRuntime>::from_slice(&x_data, &[1, 2, 1, 4], &device),
            false,
        );
        let a_var = Var::new(
            Tensor::<CpuRuntime>::from_slice(&a_data, &[1], &device),
            false,
        );
        let b_var = Var::new(
            Tensor::<CpuRuntime>::from_slice(&b_data, &[1, 2, 1, 2], &device),
            false,
        );
        let c_var = Var::new(
            Tensor::<CpuRuntime>::from_slice(&c_data, &[1, 2, 1, 2], &device),
            false,
        );
        let dt_var = Var::new(
            Tensor::<CpuRuntime>::from_slice(&dt_data, &[1, 2, 1], &device),
            false,
        );
        let var_input = SsmInput {
            x: &x_var,
            a: &a_var,
            b: &b_var,
            c: &c_var,
            d_param: None,
            dt: &dt_var,
            config: &config,
        };
        let out_var = ssm_forward_sequential(&client, &var_input).unwrap();

        // Path under test: plain tensors, starting from a zero state.
        let x_t = Tensor::<CpuRuntime>::from_slice(&x_data, &[1, 2, 1, 4], &device);
        let a_t = Tensor::<CpuRuntime>::from_slice(&a_data, &[1], &device);
        let b_t = Tensor::<CpuRuntime>::from_slice(&b_data, &[1, 2, 1, 2], &device);
        let c_t = Tensor::<CpuRuntime>::from_slice(&c_data, &[1, 2, 1, 2], &device);
        let dt_t = Tensor::<CpuRuntime>::from_slice(&dt_data, &[1, 2, 1], &device);
        let h_init = Tensor::<CpuRuntime>::zeros(&[1, 1, 4, 2], DType::F32, &device);
        let inf_input = SsmInferenceInput {
            x: &x_t,
            a: &a_t,
            b: &b_t,
            c: &c_t,
            d_param: None,
            dt: &dt_t,
            config: &config,
        };
        let (out_inf, _) = ssm_forward_sequential_inference(&client, &inf_input, &h_init).unwrap();

        let var_data: Vec<f32> = out_var.tensor().to_vec();
        let inf_data: Vec<f32> = out_inf.to_vec();
        assert_eq!(var_data.len(), inf_data.len());
        for (i, (v, t)) in var_data.iter().zip(inf_data.iter()).enumerate() {
            assert!((v - t).abs() < 1e-5, "mismatch at {i}: var={v}, inf={t}");
        }
    }
}