use crate::mamba::config::MambaConfig;
use std::io::Read;
use trustformers_core::{
device::Device,
errors::{tensor_op_error, Result},
layers::{Embedding, Linear},
ops::activations::silu,
tensor::Tensor,
traits::{Layer, Model},
};
use scirs2_core::ndarray::s;
/// Root-mean-square layer normalization (no mean subtraction), as used by
/// Mamba-style models: `y = x / rms(x) * weight`.
pub struct RMSNorm {
    weight: Tensor, // learnable per-feature scale, initialized to ones
    eps: f32,       // numerical-stability term added to the mean square
    device: Device, // device recorded at construction (CPU by default)
}
impl RMSNorm {
    /// Creates a CPU-resident norm over the last `normalized_shape` features.
    pub fn new(normalized_shape: usize, eps: f32) -> Result<Self> {
        Self::new_with_device(normalized_shape, eps, Device::CPU)
    }

    /// Creates the norm with its scale vector initialized to ones on the
    /// given device.
    pub fn new_with_device(normalized_shape: usize, eps: f32, device: Device) -> Result<Self> {
        Ok(Self {
            weight: Tensor::ones(&[normalized_shape])?,
            eps,
            device,
        })
    }

    /// Device recorded at construction time.
    pub fn device(&self) -> Device {
        self.device
    }
}
impl Layer for RMSNorm {
    type Input = Tensor;
    type Output = Tensor;

    /// Applies RMS normalization independently to each vector along the last
    /// axis, then scales by the learned per-feature weight.
    ///
    /// # Errors
    /// Returns an error for non-F32 input or weight tensors.
    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
        match &input {
            Tensor::F32(arr) => {
                // BUG FIX: the previous code pooled the mean-square statistic
                // over the WHOLE tensor, mixing statistics across tokens for
                // 2-D (seq, d_model) inputs. RMSNorm must normalize each
                // last-axis vector by its own RMS.
                let row_len = arr.shape().last().copied().unwrap_or(1).max(1);
                // `mapv` yields an owned, standard-layout copy, so the flat
                // slice below walks contiguous last-axis rows.
                let mut normalized = arr.mapv(|x| x);
                match normalized.as_slice_mut() {
                    Some(data) => {
                        for row in data.chunks_mut(row_len) {
                            let mean_sq =
                                row.iter().map(|x| x * x).sum::<f32>() / row.len() as f32;
                            let inv_rms = 1.0 / (mean_sq + self.eps).sqrt();
                            for x in row.iter_mut() {
                                *x *= inv_rms;
                            }
                        }
                    },
                    None => {
                        return Err(tensor_op_error(
                            "tensor_operation",
                            "Non-contiguous tensor layout in RMSNorm",
                        ))
                    },
                }
                match &self.weight {
                    Tensor::F32(weight_arr) => {
                        // Learned per-feature scale; broadcasts over leading axes.
                        let result = &normalized * weight_arr;
                        Ok(Tensor::F32(result))
                    },
                    _ => Err(tensor_op_error(
                        "tensor_operation",
                        "Unsupported weight tensor type for RMSNorm",
                    )),
                }
            },
            _ => Err(tensor_op_error(
                "tensor_operation",
                "Unsupported input tensor type for RMSNorm",
            )),
        }
    }
}
impl RMSNorm {
    /// Number of learnable parameters (the per-feature scale vector).
    pub fn parameter_count(&self) -> usize {
        self.weight.data().map(|d| d.len()).unwrap_or(0)
    }
}
/// Causal 1-D convolution used inside a Mamba block.
///
/// NOTE(review): `forward` is currently an identity passthrough; the weight,
/// bias, `kernel_size` and `padding` fields are stored but not yet applied.
pub struct CausalConv1d {
    weight: Tensor,       // [out_channels, in_channels, kernel_size]
    bias: Option<Tensor>, // [out_channels], present when use_bias is set
    #[allow(dead_code)]
    kernel_size: usize,
    #[allow(dead_code)]
    padding: usize, // kernel_size - 1 (left padding for causality)
    device: Device,
}
impl CausalConv1d {
pub fn new(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
use_bias: bool,
) -> Result<Self> {
Self::new_with_device(
in_channels,
out_channels,
kernel_size,
use_bias,
Device::CPU,
)
}
pub fn new_with_device(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
use_bias: bool,
device: Device,
) -> Result<Self> {
let weight = Tensor::randn(&[out_channels, in_channels, kernel_size])?;
let bias = if use_bias { Some(Tensor::zeros(&[out_channels])?) } else { None };
let padding = kernel_size - 1;
Ok(Self {
weight,
bias,
kernel_size,
padding,
device,
})
}
pub fn device(&self) -> Device {
self.device
}
}
impl Layer for CausalConv1d {
    type Input = Tensor;
    type Output = Tensor;

    /// NOTE(review): placeholder — no convolution is performed yet; F32 input
    /// is passed through unchanged, all other dtypes are rejected.
    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
        if matches!(&input, Tensor::F32(_)) {
            // Identity passthrough. Returning the owned input directly avoids
            // the redundant deep clone the previous version made.
            Ok(input)
        } else {
            Err(tensor_op_error(
                "tensor_operation",
                "Unsupported input tensor type for CausalConv1d",
            ))
        }
    }
}
impl CausalConv1d {
    /// Total learnable parameters: convolution weights plus optional bias.
    pub fn parameter_count(&self) -> usize {
        let weights = self.weight.data().unwrap_or_default().len();
        let biases = self
            .bias
            .as_ref()
            .map(|b| b.data().unwrap_or_default().len())
            .unwrap_or(0);
        weights + biases
    }
}
/// A single Mamba (selective state-space) block: input projection, causal
/// convolution, SSM parameter projections and a gated output projection,
/// wrapped in a pre-norm residual connection.
pub struct MambaBlock {
    config: MambaConfig,
    in_proj: Linear, // d_model -> 2 * d_inner (x and gate z, concatenated)
    conv1d: CausalConv1d,
    x_proj: Linear,  // d_inner -> dt_rank + 2 * d_state (delta, B, C)
    dt_proj: Linear, // dt_rank -> d_inner
    a_log: Tensor,   // [d_inner, d_state]; presumably log-space A — TODO confirm init convention
    d: Tensor,       // [d_inner] skip ("D") parameter, initialized to ones
    out_proj: Linear, // d_inner -> d_model
    norm: RMSNorm,   // pre-norm over d_model
    device: Device,
}
impl MambaBlock {
    /// Builds a block on the CPU with the given configuration.
    pub fn new(config: &MambaConfig) -> Result<Self> {
        Self::new_with_device(config, Device::CPU)
    }

    /// Builds a block, recording the target device on every sub-layer.
    /// Sub-layers are constructed in the same order as before so that any
    /// RNG-backed initialization draws the same sequence.
    pub fn new_with_device(config: &MambaConfig, device: Device) -> Result<Self> {
        let d_inner = config.get_d_inner();
        let dt_rank = config.get_dt_rank();
        Ok(Self {
            config: config.clone(),
            // Projects d_model into the concatenated (x, z) pair.
            in_proj: Linear::new_with_device(config.d_model, 2 * d_inner, config.use_bias, device),
            conv1d: CausalConv1d::new_with_device(
                d_inner,
                d_inner,
                config.d_conv,
                config.use_conv_bias,
                device,
            )?,
            // Produces delta, B and C from the convolved stream.
            x_proj: Linear::new_with_device(d_inner, dt_rank + config.d_state * 2, false, device),
            dt_proj: Linear::new_with_device(dt_rank, d_inner, true, device),
            a_log: Tensor::randn(&[d_inner, config.d_state])?,
            d: Tensor::ones(&[d_inner])?,
            out_proj: Linear::new_with_device(d_inner, config.d_model, config.use_bias, device),
            norm: RMSNorm::new_with_device(config.d_model, config.rms_norm_eps, device)?,
            device,
        })
    }

    /// Device this block was constructed for.
    pub fn device(&self) -> Device {
        self.device
    }

    /// Placeholder for the selective state-space scan: currently it only
    /// applies SiLU to `x` and ignores delta, A, B and C entirely.
    fn selective_ssm(
        &self,
        x: &Tensor,
        _delta: &Tensor,
        _a: &Tensor,
        _b: &Tensor,
        _c: &Tensor,
    ) -> Result<Tensor> {
        x.silu()
    }

    /// Sum of parameter counts over all sub-layers and raw tensors.
    fn parameter_count(&self) -> usize {
        let tensor_len = |t: &Tensor| t.data().unwrap_or_default().len();
        self.in_proj.parameter_count()
            + self.conv1d.parameter_count()
            + self.x_proj.parameter_count()
            + self.dt_proj.parameter_count()
            + tensor_len(&self.a_log)
            + tensor_len(&self.d)
            + self.out_proj.parameter_count()
            + self.norm.parameter_count()
    }
}
impl Layer for MambaBlock {
    type Input = Tensor;
    type Output = Tensor;

    /// Pre-norm Mamba block forward pass:
    /// residual -> RMSNorm -> in_proj (split into x and gate z) ->
    /// causal conv -> SiLU -> x_proj -> selective SSM (currently a stub) ->
    /// SiLU-gated by z -> out_proj -> residual add.
    ///
    /// NOTE(review): `selective_ssm` receives `ssm_out` for delta, B and C
    /// alike and currently ignores them — revisit once the scan is implemented.
    /// Assumes the projected activations are 2-D (seq, 2 * d_inner) — the
    /// shape check below enforces exactly that.
    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
        // Keep the un-normalized input for the residual connection.
        let residual = input.clone();
        let normed = self.norm.forward(input)?;
        let projected = self.in_proj.forward(normed)?;
        let d_inner = self.config.get_d_inner();
        // Split the projection into the SSM stream x and the gate z.
        let (x, z) = match &projected {
            Tensor::F32(arr) => {
                let shape = arr.shape();
                if shape.len() != 2 || shape[1] != 2 * d_inner {
                    return Err(tensor_op_error(
                        "tensor_operation",
                        "Invalid projected tensor shape for splitting",
                    ));
                }
                let x_slice = arr.slice(s![.., ..d_inner]).to_owned().into_dyn();
                let z_slice = arr.slice(s![.., d_inner..]).to_owned().into_dyn();
                (Tensor::F32(x_slice), Tensor::F32(z_slice))
            },
            _ => {
                return Err(tensor_op_error(
                    "tensor_operation",
                    "Unsupported tensor type for splitting",
                ))
            },
        };
        let conv_out = self.conv1d.forward(x)?;
        let activated = silu(&conv_out)?;
        // ssm_out carries the concatenated (delta, B, C) projection.
        let ssm_out = self.x_proj.forward(activated.clone())?;
        let ssm_result =
            self.selective_ssm(&activated, &ssm_out, &self.a_log, &ssm_out, &ssm_out)?;
        // Gate the SSM output with SiLU(z), element-wise.
        let z_activated = silu(&z)?;
        let gated = match (&ssm_result, &z_activated) {
            (Tensor::F32(ssm_arr), Tensor::F32(z_arr)) => {
                let result = ssm_arr * z_arr;
                Tensor::F32(result)
            },
            _ => {
                return Err(tensor_op_error(
                    "tensor_operation",
                    "Tensor type mismatch in gating",
                ))
            },
        };
        let output = self.out_proj.forward(gated)?;
        // Residual connection back to the block input.
        match (&residual, &output) {
            (Tensor::F32(res_arr), Tensor::F32(out_arr)) => {
                let result = res_arr + out_arr;
                Ok(Tensor::F32(result))
            },
            _ => Err(tensor_op_error(
                "tensor_operation",
                "Tensor type mismatch in residual connection",
            )),
        }
    }
}
/// Full Mamba language model: token embeddings, a stack of Mamba blocks,
/// a final RMSNorm and an optional LM head.
pub struct MambaModel {
    config: MambaConfig,
    embeddings: Embedding,
    layers: Vec<MambaBlock>, // config.n_layer blocks, applied in order
    norm_f: RMSNorm,         // final norm applied after the last block
    lm_head: Option<Linear>, // None when config.tie_word_embeddings is set
    device: Device,
}
impl MambaModel {
    /// Builds the model on the CPU.
    pub fn new(config: MambaConfig) -> Result<Self> {
        Self::new_with_device(config, Device::CPU)
    }

    /// Builds embeddings, `n_layer` Mamba blocks, the final norm and —
    /// unless word embeddings are tied — a standalone LM head.
    pub fn new_with_device(config: MambaConfig, device: Device) -> Result<Self> {
        let embeddings =
            Embedding::new_with_device(config.vocab_size, config.d_model, None, device)?;
        let layers = (0..config.n_layer)
            .map(|_| MambaBlock::new_with_device(&config, device))
            .collect::<Result<Vec<_>>>()?;
        let norm_f = RMSNorm::new_with_device(config.d_model, config.rms_norm_eps, device)?;
        // With tied embeddings the output projection reuses the embedding
        // matrix, so no separate head is allocated.
        let lm_head = (!config.tie_word_embeddings)
            .then(|| Linear::new_with_device(config.d_model, config.vocab_size, false, device));
        Ok(Self {
            config,
            embeddings,
            layers,
            norm_f,
            lm_head,
            device,
        })
    }

    /// Device recorded at construction time.
    pub fn device(&self) -> Device {
        self.device
    }

    /// Runs the backbone and, when an LM head exists, projects the hidden
    /// states to vocabulary logits; otherwise returns the hidden states.
    pub fn forward_lm(&self, input_ids: &Tensor) -> Result<Tensor> {
        let hidden_states = self.forward(input_ids.clone())?;
        match &self.lm_head {
            Some(head) => head.forward(hidden_states),
            None => Ok(hidden_states),
        }
    }
}
impl Model for MambaModel {
    type Config = MambaConfig;
    type Input = Tensor;
    type Output = Tensor;

    /// Embeds the token ids, runs every Mamba block in order, and applies the
    /// final RMSNorm.
    ///
    /// # Errors
    /// Fails when the input tensor is neither I64 nor F32, or when an I64 id
    /// is negative or exceeds `u32::MAX` (the previous `as` cast silently
    /// wrapped such ids into arbitrary vocabulary indices).
    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
        let input_ids = match &input {
            Tensor::I64(arr) => arr
                .iter()
                .map(|&x| {
                    u32::try_from(x).map_err(|_| {
                        tensor_op_error(
                            "tensor_operation",
                            "Token id out of u32 range for Mamba model",
                        )
                    })
                })
                .collect::<Result<Vec<u32>>>()?,
            // Float-to-int `as` casts saturate (negatives clamp to 0);
            // assumes ids are already integral — TODO confirm callers never
            // pass fractional values here.
            Tensor::F32(arr) => arr.iter().map(|&x| x as u32).collect::<Vec<u32>>(),
            _ => {
                return Err(tensor_op_error(
                    "tensor_operation",
                    "Unsupported input tensor type for Mamba model",
                ))
            },
        };
        let mut hidden_states = self.embeddings.forward(input_ids)?;
        for layer in &self.layers {
            hidden_states = layer.forward(hidden_states)?;
        }
        self.norm_f.forward(hidden_states)
    }

    /// NOTE(review): weight loading is not implemented yet — this is a no-op
    /// that leaves the randomly initialized parameters in place.
    fn load_pretrained(&mut self, _reader: &mut dyn Read) -> Result<()> {
        Ok(())
    }

    fn get_config(&self) -> &Self::Config {
        &self.config
    }

    /// Total parameters across embeddings, all blocks, the final norm and the
    /// optional LM head.
    fn num_parameters(&self) -> usize {
        let block_params: usize = self.layers.iter().map(|l| l.parameter_count()).sum();
        self.embeddings.parameter_count()
            + block_params
            + self.norm_f.parameter_count()
            + self.lm_head.as_ref().map_or(0, |h| h.parameter_count())
    }
}
impl MambaModel {
pub fn mamba_130m() -> Result<Self> {
Self::new(MambaConfig::mamba_130m())
}
pub fn mamba_130m_with_device(device: Device) -> Result<Self> {
Self::new_with_device(MambaConfig::mamba_130m(), device)
}
pub fn mamba_370m() -> Result<Self> {
Self::new(MambaConfig::mamba_370m())
}
pub fn mamba_370m_with_device(device: Device) -> Result<Self> {
Self::new_with_device(MambaConfig::mamba_370m(), device)
}
pub fn mamba_790m() -> Result<Self> {
Self::new(MambaConfig::mamba_790m())
}
pub fn mamba_790m_with_device(device: Device) -> Result<Self> {
Self::new_with_device(MambaConfig::mamba_790m(), device)
}
pub fn mamba_1_4b() -> Result<Self> {
Self::new(MambaConfig::mamba_1_4b())
}
pub fn mamba_1_4b_with_device(device: Device) -> Result<Self> {
Self::new_with_device(MambaConfig::mamba_1_4b(), device)
}
pub fn mamba_2_8b() -> Result<Self> {
Self::new(MambaConfig::mamba_2_8b())
}
pub fn mamba_2_8b_with_device(device: Device) -> Result<Self> {
Self::new_with_device(MambaConfig::mamba_2_8b(), device)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_mamba_model_creation() {
        // A default configuration must always build.
        assert!(MambaModel::new(MambaConfig::default()).is_ok());
    }

    #[test]
    fn test_mamba_block_creation() {
        let cfg = MambaConfig::default();
        assert!(MambaBlock::new(&cfg).is_ok());
    }

    #[test]
    fn test_rms_norm_creation() {
        assert!(RMSNorm::new(768, 1e-5).is_ok());
    }

    #[test]
    fn test_causal_conv1d_creation() {
        assert!(CausalConv1d::new(768, 768, 4, true).is_ok());
    }

    // Ignored by default: instantiating every preset allocates full-size weights.
    #[test]
    #[ignore]
    fn test_predefined_models() {
        assert!(MambaModel::mamba_130m().is_ok());
        assert!(MambaModel::mamba_370m().is_ok());
        assert!(MambaModel::mamba_790m().is_ok());
        assert!(MambaModel::mamba_1_4b().is_ok());
        assert!(MambaModel::mamba_2_8b().is_ok());
    }

    #[test]
    fn test_forward_pass_shape() {
        let model = MambaModel::new(MambaConfig::default()).expect("operation failed");
        let token_ids: Vec<i64> = (1..=10).collect();
        let input = Tensor::I64(Array1::from(token_ids).into_dyn());
        assert!(model.forward(input).is_ok());
    }

    #[test]
    fn test_device_support() {
        let cfg = MambaConfig::default();

        let cpu_model =
            MambaModel::new_with_device(cfg.clone(), Device::CPU).expect("operation failed");
        assert_eq!(cpu_model.device(), Device::CPU);

        let preset = MambaModel::mamba_130m_with_device(Device::CPU).expect("operation failed");
        assert_eq!(preset.device(), Device::CPU);

        let block = MambaBlock::new_with_device(&cfg, Device::CPU).expect("operation failed");
        assert_eq!(block.device(), Device::CPU);

        let norm = RMSNorm::new_with_device(768, 1e-5, Device::CPU).expect("operation failed");
        assert_eq!(norm.device(), Device::CPU);

        let conv = CausalConv1d::new_with_device(768, 768, 4, true, Device::CPU)
            .expect("operation failed");
        assert_eq!(conv.device(), Device::CPU);
    }

    #[test]
    fn test_metal_device_creation() {
        // Metal may fall back to CPU when unavailable; either device is acceptable.
        let model = MambaModel::new_with_device(MambaConfig::default(), Device::Metal(0))
            .expect("operation failed");
        assert!(model.device() == Device::Metal(0) || model.device() == Device::CPU);
    }

    // Ignored for the same reason as test_predefined_models.
    #[test]
    #[ignore]
    fn test_all_predefined_models_with_device() {
        let device = Device::CPU;
        assert!(MambaModel::mamba_130m_with_device(device).is_ok());
        assert!(MambaModel::mamba_370m_with_device(device).is_ok());
        assert!(MambaModel::mamba_790m_with_device(device).is_ok());
        assert!(MambaModel::mamba_1_4b_with_device(device).is_ok());
        assert!(MambaModel::mamba_2_8b_with_device(device).is_ok());
    }
}