use crate::model::mamba::mamba2::Mamba2Config;
use numr::dtype::DType;
use numr::runtime::Runtime;
use numr::tensor::Tensor;
/// Per-layer recurrent state for a Mamba-2 SSM block.
///
/// Holds the SSM hidden state `h` and the rolling convolution buffer
/// `conv_state` for a single layer, plus a flag recording whether the
/// state has been written since construction (or the last reset).
pub struct SsmState<R: Runtime> {
// SSM hidden state; allocated as [batch, nheads, headdim, d_state] in `new`.
h: Tensor<R>,
// Convolution carry-over; allocated as [batch, conv_channels, d_conv - 1]
// in `new` — presumably the lookback window for the causal conv (d_conv - 1
// past positions); TODO confirm against the conv kernel that consumes it.
conv_state: Tensor<R>,
// True once `update_h` has been called; cleared by `reset`.
initialized: bool,
}
impl<R: Runtime<DType = DType>> SsmState<R> {
    /// Allocates zeroed state tensors for a single layer.
    ///
    /// `h` gets shape `[batch_size, nheads, headdim, d_state]`;
    /// `conv_state` gets shape `[batch_size, conv_channels, d_conv - 1]`
    /// (the carry-over window for the causal convolution).
    pub fn new(batch_size: usize, config: &Mamba2Config, dtype: DType, device: &R::Device) -> Self {
        let h = Tensor::<R>::zeros(
            &[batch_size, config.nheads, config.headdim, config.d_state],
            dtype,
            device,
        );
        let conv_state = Tensor::<R>::zeros(
            &[batch_size, config.conv_channels(), config.d_conv - 1],
            dtype,
            device,
        );
        Self {
            h,
            conv_state,
            initialized: false,
        }
    }

    /// Borrows the SSM hidden state.
    pub fn h(&self) -> &Tensor<R> {
        &self.h
    }

    /// Borrows the convolution carry-over buffer.
    pub fn conv_state(&self) -> &Tensor<R> {
        &self.conv_state
    }

    /// Whether `update_h` has been called since construction or the last reset.
    pub fn is_initialized(&self) -> bool {
        self.initialized
    }

    /// Replaces the hidden state and marks the state as initialized.
    pub fn update_h(&mut self, new_h: Tensor<R>) {
        self.h = new_h;
        self.initialized = true;
    }

    /// Replaces the convolution carry-over buffer.
    ///
    /// Note: this deliberately does NOT set `initialized`; only `update_h`
    /// does, matching the original behavior.
    pub fn update_conv_state(&mut self, new_conv: Tensor<R>) {
        self.conv_state = new_conv;
    }

    /// Zeroes both state tensors and clears the initialized flag.
    ///
    /// The replacement tensors are built directly from the current tensors'
    /// shape/dtype/device; the RHS borrows of `self.h` / `self.conv_state`
    /// end before each assignment, so the previous `.to_vec()` and
    /// `.clone()` copies were unnecessary.
    pub fn reset(&mut self) {
        let dtype = self.h.dtype();
        self.h = Tensor::<R>::zeros(self.h.shape(), dtype, self.h.device());
        self.conv_state = Tensor::<R>::zeros(
            self.conv_state.shape(),
            dtype,
            self.conv_state.device(),
        );
        self.initialized = false;
    }
}
/// Stack of per-layer [`SsmState`]s, one entry per Mamba-2 layer.
pub struct LayeredSsmState<R: Runtime> {
// Index in this Vec corresponds to the layer index used by `layer`/`layer_mut`.
layers: Vec<SsmState<R>>,
}
impl<R: Runtime<DType = DType>> LayeredSsmState<R> {
    /// Builds `num_layers` independent zero-initialized states, all sharing
    /// the same batch size, config, dtype, and device.
    pub fn new(
        num_layers: usize,
        batch_size: usize,
        config: &Mamba2Config,
        dtype: DType,
        device: &R::Device,
    ) -> Self {
        let mut layers = Vec::with_capacity(num_layers);
        for _ in 0..num_layers {
            layers.push(SsmState::new(batch_size, config, dtype, device));
        }
        Self { layers }
    }

    /// Mutable access to the state of layer `idx`, or `None` if out of range.
    pub fn layer_mut(&mut self, idx: usize) -> Option<&mut SsmState<R>> {
        self.layers.get_mut(idx)
    }

    /// Shared access to the state of layer `idx`, or `None` if out of range.
    pub fn layer(&self, idx: usize) -> Option<&SsmState<R>> {
        self.layers.get(idx)
    }

    /// Number of per-layer states held.
    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }

    /// Resets every layer back to zeros (see [`SsmState::reset`]).
    pub fn reset(&mut self) {
        self.layers.iter_mut().for_each(SsmState::reset);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use numr::runtime::cpu::CpuRuntime;

    // Shared fixture: the small config every test below uses.
    fn small_config() -> Mamba2Config {
        Mamba2Config::new(64)
            .with_nheads(2)
            .with_d_state(16)
            .with_expand(2)
    }

    #[test]
    fn test_ssm_state_create() {
        let device = numr::runtime::cpu::CpuDevice::new();
        let config = small_config();
        let state = SsmState::<CpuRuntime>::new(1, &config, DType::F32, &device);
        assert_eq!(state.h().shape(), &[1, 2, 64, 16]);
        assert_eq!(state.conv_state().shape(), &[1, config.conv_channels(), 3]);
        assert!(!state.is_initialized());
    }

    #[test]
    fn test_layered_ssm_state() {
        let device = numr::runtime::cpu::CpuDevice::new();
        let state =
            LayeredSsmState::<CpuRuntime>::new(4, 1, &small_config(), DType::F32, &device);
        assert_eq!(state.num_layers(), 4);
        assert!(state.layer(0).is_some());
        assert!(state.layer(4).is_none());
    }

    #[test]
    fn test_ssm_state_reset() {
        let device = numr::runtime::cpu::CpuDevice::new();
        let mut state = SsmState::<CpuRuntime>::new(1, &small_config(), DType::F32, &device);
        state.update_h(Tensor::<CpuRuntime>::ones(&[1, 2, 64, 16], DType::F32, &device));
        assert!(state.is_initialized());
        state.reset();
        assert!(!state.is_initialized());
    }
}