use super::layer::Mamba2;
use crate::error::{Error, Result};
use crate::model::mamba::ssm::{SsmInput, var_contiguous};
use numr::autograd::{
Var, var_add, var_exp, var_mul, var_narrow, var_neg, var_reshape, var_silu, var_softplus,
var_transpose,
};
use numr::dtype::DType;
use numr::ops::{
ActivationOps, BinaryOps, ConvOps, NormalizationOps, ReduceOps, ScalarOps, ShapeOps, TensorOps,
UnaryOps,
};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
impl<R: Runtime> Mamba2<R> {
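/// Autograd forward pass over a full sequence.
///
/// `x` must have shape `[batch, seq_len, d_model]`; the output keeps the same
/// batch and sequence dimensions with `d_model` output features.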
pub fn forward<C>(&self, client: &C, x: &Var<R>) -> Result<Var<R>>
where
R: Runtime<DType = DType>,
C: RuntimeClient<R>
+ TensorOps<R>
+ ScalarOps<R>
+ UnaryOps<R>
+ ActivationOps<R>
+ ConvOps<R>
+ NormalizationOps<R>
+ ReduceOps<R>
+ ShapeOps<R>,
R::Client: TensorOps<R>
+ ScalarOps<R>
+ ActivationOps<R>
+ ConvOps<R>
+ ReduceOps<R>
+ BinaryOps<R>,
{
let shape = x.shape();
if shape.len() != 3 {
return Err(Error::ModelError {
reason: format!("expected [batch, seq_len, d_model], got shape {:?}", shape),
});
}
let batch = shape[0];
let seq_len = shape[1];
if shape[2] != self.config.d_model {
return Err(Error::ModelError {
reason: format!(
"d_model mismatch: expected {}, got {}",
self.config.d_model, shape[2]
),
});
}
let d_inner = self.config.d_inner();
let n_groups_d_state = self.config.ngroups * self.config.d_state;
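// One fused input projection produces [z | xBC | dt] along the feature axis;
// the three chunks are split back out with `narrow`.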
let projected = self.in_proj.forward(client, x)?;
let xbc_len = d_inner + 2 * n_groups_d_state;
let z = var_contiguous(&var_narrow(&projected, 2, 0, d_inner).map_err(Error::Numr)?);
let xbc =
var_contiguous(&var_narrow(&projected, 2, d_inner, xbc_len).map_err(Error::Numr)?);
let dt = var_contiguous(
&var_narrow(&projected, 2, d_inner + xbc_len, self.config.nheads)
.map_err(Error::Numr)?,
);
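// The depthwise causal conv1d expects [batch, channels, length], so transpose
// to NCL, convolve, transpose back, then apply SiLU.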
let xbc_ncl = var_contiguous(&var_transpose(&xbc).map_err(Error::Numr)?);
let xbc_conv = self.conv1d.forward(client, &xbc_ncl)?;
let xbc = var_contiguous(&var_transpose(&xbc_conv).map_err(Error::Numr)?);
let xbc = var_silu(&xbc, client).map_err(Error::Numr)?;
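// Split the convolved stream into the SSM input x and the input-dependent
// B and C projections (one d_state-sized block per group).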
let x_ssm = var_contiguous(&var_narrow(&xbc, 2, 0, d_inner).map_err(Error::Numr)?);
let b_proj =
var_contiguous(&var_narrow(&xbc, 2, d_inner, n_groups_d_state).map_err(Error::Numr)?);
let c_proj = var_contiguous(
&var_narrow(&xbc, 2, d_inner + n_groups_d_state, n_groups_d_state)
.map_err(Error::Numr)?,
);
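// Reshape flat features into the per-head / per-group layout the SSM scan
// expects: x -> [b, l, nheads, headdim], B/C -> [b, l, ngroups, d_state].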
let x_ssm = var_reshape(
&x_ssm,
&[batch, seq_len, self.config.nheads, self.config.headdim],
)
.map_err(Error::Numr)?;
let b_proj = var_reshape(
&b_proj,
&[batch, seq_len, self.config.ngroups, self.config.d_state],
)
.map_err(Error::Numr)?;
let c_proj = var_reshape(
&c_proj,
&[batch, seq_len, self.config.ngroups, self.config.d_state],
)
.map_err(Error::Numr)?;
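// A is parameterized through a_log; recover A = -exp(a_log) so the
// state-transition decay stays negative.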
let a = var_neg(&var_exp(&self.a_log, client).map_err(Error::Numr)?, client)
.map_err(Error::Numr)?;
let mut dt = dt;
// Add the learned bias before the softplus, matching the reference Mamba2
// parameterization dt = softplus(dt_raw + dt_bias).
if let Some(ref bias) = self.dt_bias {
dt = var_add(&dt, bias, client).map_err(Error::Numr)?;
}
if self.config.dt_softplus {
dt = var_softplus(&dt, client).map_err(Error::Numr)?;
}
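// Bundle the scan operands and run the SSM recurrence sequentially over the
// whole sequence.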
let ssm_input = SsmInput {
x: &x_ssm,
a: &a,
b: &b_proj,
c: &c_proj,
d_param: self.d_param.as_ref(),
dt: &dt,
config: &self.config,
};
let out = crate::model::mamba::ssm::ssm_forward_sequential(client, &ssm_input)?;
let out = var_reshape(&out, &[batch, seq_len, d_inner]).map_err(Error::Numr)?;
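// Gate the scan output with SiLU(z), then apply the optional norm before the
// final output projection.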
let z_gate = var_silu(&z, client).map_err(Error::Numr)?;
let out = var_mul(&out, &z_gate, client).map_err(Error::Numr)?;
let out = if let Some(ref norm) = self.norm {
norm.forward(client, &out)?
} else {
out
};
self.out_proj.forward(client, &out)
}
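/// Inference forward pass that threads recurrent state across calls.
///
/// `seq_len > 1` takes the prefill path (full conv over the prompt); a single
/// token takes the cached-window decode path. Both update `state` in place.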
pub fn forward_inference<C>(
&self,
client: &C,
x: &Tensor<R>,
state: &mut crate::inference::SsmState<R>,
) -> Result<Tensor<R>>
where
R: Runtime<DType = DType>,
C: RuntimeClient<R>
+ TensorOps<R>
+ ScalarOps<R>
+ UnaryOps<R>
+ ActivationOps<R>
+ ConvOps<R>
+ NormalizationOps<R>
+ ReduceOps<R>
+ ShapeOps<R>,
R::Client: TensorOps<R>
+ ScalarOps<R>
+ ActivationOps<R>
+ ConvOps<R>
+ ReduceOps<R>
+ BinaryOps<R>,
{
let shape = x.shape();
if shape.len() != 3 || shape[2] != self.config.d_model {
return Err(Error::ModelError {
reason: format!(
"expected [batch, seq_len, {}], got shape {:?}",
self.config.d_model, shape
),
});
}
let batch = shape[0];
let seq_len = shape[1];
let d_inner = self.config.d_inner();
let n_groups_d_state = self.config.ngroups * self.config.d_state;
let x_var = Var::new(x.clone(), false);
let projected = self.in_proj.forward(client, &x_var)?;
let projected = projected.tensor().clone().contiguous();
let xbc_len = d_inner + 2 * n_groups_d_state;
let z = projected
.narrow(2, 0, d_inner)
.map_err(Error::Numr)?
.contiguous();
let xbc = projected
.narrow(2, d_inner, xbc_len)
.map_err(Error::Numr)?
.contiguous();
let dt = projected
.narrow(2, d_inner + xbc_len, self.config.nheads)
.map_err(Error::Numr)?
.contiguous();
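// Prefill runs the full causal conv over the prompt; single-token decode
// instead uses the conv window cached in `state`.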
let xbc_ncl = xbc.transpose(-1, -2).map_err(Error::Numr)?.contiguous();
let xbc_conv = if seq_len > 1 {
self.prefill_conv(client, &xbc_ncl, seq_len, batch, x, state)?
} else {
self.decode_conv(&xbc_ncl, batch, x, state)?
};
let xbc = xbc_conv
.transpose(-1, -2)
.map_err(Error::Numr)?
.contiguous();
let xbc = xbc.silu().map_err(Error::Numr)?;
let x_ssm = xbc.narrow(2, 0, d_inner).map_err(Error::Numr)?.contiguous();
let b_proj = xbc
.narrow(2, d_inner, n_groups_d_state)
.map_err(Error::Numr)?
.contiguous();
let c_proj = xbc
.narrow(2, d_inner + n_groups_d_state, n_groups_d_state)
.map_err(Error::Numr)?
.contiguous();
let x_ssm = x_ssm
.reshape(&[batch, seq_len, self.config.nheads, self.config.headdim])
.map_err(Error::Numr)?;
let b_proj = b_proj
.reshape(&[batch, seq_len, self.config.ngroups, self.config.d_state])
.map_err(Error::Numr)?;
let c_proj = c_proj
.reshape(&[batch, seq_len, self.config.ngroups, self.config.d_state])
.map_err(Error::Numr)?;
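// A = -exp(a_log); negation is done by multiplying with a [-1] scalar tensor,
// since this path operates on plain tensors rather than autograd Vars.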
let a = self.a_log.tensor().exp().map_err(Error::Numr)?;
let neg_one = Tensor::<R>::from_slice(&[-1.0f32], &[1], x.device());
let a = a.mul(&neg_one).map_err(Error::Numr)?;
let mut dt = dt;
// Same ordering as the training path: bias first, then softplus.
if let Some(ref bias) = self.dt_bias {
dt = dt.add(bias.tensor()).map_err(Error::Numr)?;
}
if self.config.dt_softplus {
dt = client.softplus(&dt).map_err(Error::Numr)?;
}
let d_tensor = self.d_param.as_ref().map(|d| d.tensor().clone());
let ssm_input = crate::model::mamba::ssm::SsmInferenceInput {
x: &x_ssm,
a: &a,
b: &b_proj,
c: &c_proj,
d_param: d_tensor.as_ref(),
dt: &dt,
config: &self.config,
};
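// Run the recurrence from the cached hidden state and persist the final
// state for the next step.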
let (out, final_h) = crate::model::mamba::ssm::ssm_forward_sequential_inference(
client,
&ssm_input,
state.h(),
)?;
state.update_h(final_h);
let out = out
.reshape(&[batch, seq_len, d_inner])
.map_err(Error::Numr)?;
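// Same gating as the training path: SiLU(z) multiply, optional norm, then
// the output projection.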
let z_gate = z.silu().map_err(Error::Numr)?;
let out = out.mul(&z_gate).map_err(Error::Numr)?;
let out_var = Var::new(out, false);
let out_var = if let Some(ref norm) = self.norm {
norm.forward(client, &out_var)?
} else {
out_var
};
let out_var = self.out_proj.forward(client, &out_var)?;
Ok(out_var.tensor().clone())
}
}