use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use crate::error::{Result, TrustformersError};
use crate::pipeline::{Pipeline, PipelineInput, PipelineOutput};
/// Configuration for the Mamba-2 state-space pipeline.
///
/// Sizes here drive per-layer buffer allocation in `initialize_model`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Mamba2Config {
    /// Model (embedding) width; each layer's hidden state has this length.
    pub d_model: usize,
    /// SSM state dimension; SSM buffers are sized `d_state * d_model`.
    pub d_state: usize,
    /// Inner expansion factor. NOTE(review): not read by the visible scan path — confirm intended use.
    pub expand_factor: usize,
    /// Convolution kernel width; conv state is sized `d_conv * d_model`.
    pub d_conv: usize,
    /// Rank of the delta projection; `None` defaults to `d_model / 16`.
    pub dt_rank: Option<usize>,
    /// Head count. NOTE(review): not read by the visible scan path — confirm intended use.
    pub n_heads: usize,
    /// Activation selection. NOTE(review): not applied by the visible scan path.
    pub activation: Mamba2Activation,
    /// Whether projections carry bias terms (not used by visible code).
    pub bias: bool,
    /// Use the simplified `-ln(i + 1)` A-matrix initialization.
    pub simplified_a_init: bool,
    /// Preferred execution backend.
    pub hardware_strategy: HardwareStrategy,
    /// How long sequences should be chunked (currently advisory only).
    pub chunking_strategy: ChunkingStrategy,
    /// Memory/speed trade-off level.
    pub memory_optimization: MemoryOptimization,
}
impl Default for Mamba2Config {
fn default() -> Self {
Self {
d_model: 768,
d_state: 16,
expand_factor: 2,
d_conv: 4,
dt_rank: None, n_heads: 1,
activation: Mamba2Activation::SiLU,
bias: false,
simplified_a_init: true,
hardware_strategy: HardwareStrategy::Auto,
chunking_strategy: ChunkingStrategy::Adaptive,
memory_optimization: MemoryOptimization::Balanced,
}
}
}
/// Supported activation functions for the Mamba-2 configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Mamba2Activation {
    SiLU,
    GELU,
    ReLU,
    Swish,
}
/// Execution-backend preference for the pipeline.
/// NOTE(review): advisory only — no backend dispatch is visible in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum HardwareStrategy {
    /// Let the runtime pick a backend.
    Auto,
    CPU,
    CUDA,
    Metal,
    /// Favor low memory footprint.
    MemoryConstrained,
    /// Favor raw throughput.
    MaxThroughput,
}
/// How long input sequences should be split into chunks.
/// NOTE(review): currently advisory — `process_with_chunking` ignores it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ChunkingStrategy {
    /// Process the whole sequence in one pass.
    None,
    /// Fixed chunk size in tokens.
    Fixed(usize),
    /// Runtime-chosen chunk sizes.
    Adaptive,
    /// Fixed chunks that overlap by `overlap` tokens.
    Overlapping { chunk_size: usize, overlap: usize },
    /// Multi-level chunk sizes, coarsest first.
    Hierarchical { levels: Vec<usize> },
}
/// Memory/speed trade-off levels, from none to most aggressive.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MemoryOptimization {
    None,
    Balanced,
    Aggressive,
    Ultra,
}
/// Per-layer mutable state and parameters for the selective scan.
///
/// Buffer sizes are established in `Mamba2Pipeline::initialize_model`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Mamba2Layer {
    /// Index of this layer within the model.
    pub layer_id: usize,
    /// Hidden state, length `d_model`.
    pub hidden_state: Vec<f32>,
    /// Convolution state, length `d_conv * d_model`.
    pub conv_state: Vec<f32>,
    /// SSM recurrent state, length `d_state * d_model`; updated in place by `selective_scan`.
    pub ssm_state: Vec<f32>,
    /// Delta projection parameters, length `dt_rank`.
    pub delta: Vec<f32>,
    /// Diagonal A matrix, length `d_state` (see `initialize_a_matrix`).
    pub a_matrix: Vec<f32>,
    /// B matrix, length `d_state * d_model` (zero-initialized).
    pub b_matrix: Vec<f32>,
    /// C matrix, length `d_state * d_model` (zero-initialized).
    pub c_matrix: Vec<f32>,
    /// Skip-connection scale applied to the raw input in the scan.
    pub d_param: f32,
}
/// The full model: configuration, layer stack, and shared trackers.
#[derive(Debug)]
pub struct Mamba2Model {
    config: Mamba2Config,
    /// Layer stack; initialized with 24 layers by `initialize_model`.
    layers: Vec<Mamba2Layer>,
    /// Model-local performance counters (distinct from the pipeline-level monitor).
    performance_tracker: Arc<RwLock<Mamba2PerformanceTracker>>,
    /// Checkpoint/compression bookkeeping. NOTE(review): not read by visible code.
    state_manager: Arc<RwLock<StateManager>>,
}
/// Running performance counters accumulated across pipeline calls.
#[derive(Debug, Default)]
pub struct Mamba2PerformanceTracker {
    /// Total tokens processed since creation.
    pub total_tokens_processed: u64,
    /// Latency per token of the most recent call, in milliseconds.
    pub average_latency_per_token: f32,
    pub memory_usage_mb: f32,
    /// NOTE(review): never written by visible code.
    pub state_size_mb: f32,
    pub throughput_tokens_per_second: f32,
    /// NOTE(review): never written by visible code.
    pub hardware_utilization: f32,
    /// NOTE(review): never written by visible code.
    pub selective_scan_efficiency: f32,
}
/// Checkpointing/compression bookkeeping for SSM state.
/// NOTE(review): fields are declared but not used by any visible code path.
#[derive(Debug, Default)]
pub struct StateManager {
    /// How often (in steps) to snapshot state.
    pub checkpoint_interval: usize,
    /// Serialized state snapshots keyed by step index.
    pub state_checkpoints: HashMap<usize, Vec<u8>>,
    pub compression_enabled: bool,
    pub max_state_memory_mb: f32,
}
/// Result of a single pipeline invocation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Mamba2Output {
    /// Decoded text (first 10 output chunks mapped back to chars).
    pub text: String,
    /// Full output tensor, same length as the padded input tensor.
    pub logits: Vec<f32>,
    /// First `d_model` values of the output tensor.
    pub hidden_states: Vec<f32>,
    /// SSM state snapshot. NOTE(review): currently zero-filled placeholder.
    pub ssm_states: Vec<f32>,
    /// Timing/throughput metrics for this call.
    pub performance: Mamba2PerformanceMetrics,
    /// Per-token weights. NOTE(review): currently a constant placeholder.
    pub attention_weights: Option<Vec<f32>>,
    /// Per-token state snapshots. NOTE(review): currently zero-filled placeholder.
    pub state_trajectory: Option<Vec<Vec<f32>>>,
}
/// Per-call performance metrics (see `compute_performance_metrics`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Mamba2PerformanceMetrics {
    /// Wall-clock inference time, clamped to at least 1 ms.
    pub inference_time_ms: f32,
    pub tokens_per_second: f32,
    /// Estimated from input tensor size, clamped to at least 0.1 MB.
    pub memory_usage_mb: f32,
    /// NOTE(review): hard-coded constant in the current implementation.
    pub state_compression_ratio: f32,
    /// NOTE(review): hard-coded constant in the current implementation.
    pub hardware_efficiency: f32,
    /// NOTE(review): hard-coded constant in the current implementation.
    pub selective_scan_utilization: f32,
}
/// User-facing pipeline wrapping a `Mamba2Model` behind async locks.
pub struct Mamba2Pipeline {
    /// The model; taken as a write lock for the duration of each scan.
    model: Arc<RwLock<Mamba2Model>>,
    /// Copy of the configuration used to build the model.
    config: Mamba2Config,
    /// Pipeline-level running counters, updated after each call.
    performance_monitor: Arc<RwLock<Mamba2PerformanceTracker>>,
}
impl Mamba2Pipeline {
    /// Creates a pipeline from `config`, building the model eagerly.
    ///
    /// # Errors
    /// Propagates any error from model initialization.
    pub fn new(config: Mamba2Config) -> Result<Self> {
        let model = Self::initialize_model(&config)?;
        let performance_monitor = Arc::new(RwLock::new(Mamba2PerformanceTracker::default()));
        Ok(Self {
            model: Arc::new(RwLock::new(model)),
            config,
            performance_monitor,
        })
    }

    /// Builds the model with `NUM_LAYERS` layers and zeroed per-layer state
    /// buffers sized from `config`.
    fn initialize_model(config: &Mamba2Config) -> Result<Mamba2Model> {
        /// Layer count is currently fixed; make configurable if models vary.
        const NUM_LAYERS: usize = 24;
        // dt_rank defaults to d_model / 16 when not set explicitly.
        let dt_rank = config.dt_rank.unwrap_or(config.d_model / 16);
        let mut layers = Vec::with_capacity(NUM_LAYERS);
        for layer_id in 0..NUM_LAYERS {
            layers.push(Mamba2Layer {
                layer_id,
                hidden_state: vec![0.0; config.d_model],
                conv_state: vec![0.0; config.d_conv * config.d_model],
                ssm_state: vec![0.0; config.d_state * config.d_model],
                delta: vec![0.0; dt_rank],
                a_matrix: Self::initialize_a_matrix(config.d_state, config.simplified_a_init),
                b_matrix: vec![0.0; config.d_state * config.d_model],
                c_matrix: vec![0.0; config.d_state * config.d_model],
                d_param: 1.0,
            });
        }
        Ok(Mamba2Model {
            config: config.clone(),
            layers,
            performance_tracker: Arc::new(RwLock::new(Mamba2PerformanceTracker::default())),
            state_manager: Arc::new(RwLock::new(StateManager::default())),
        })
    }

    /// Initializes the diagonal A matrix (`d_state` entries).
    ///
    /// With `simplified == true` entries are `-ln(i + 1)` (all <= 0, entry 0
    /// is exactly 0); otherwise the real projection of a complex-style
    /// initialization is used.
    fn initialize_a_matrix(d_state: usize, simplified: bool) -> Vec<f32> {
        let mut a_matrix = Vec::with_capacity(d_state);
        if simplified {
            for i in 0..d_state {
                let val = -((i + 1) as f32).ln();
                a_matrix.push(val);
            }
        } else {
            for i in 0..d_state {
                let real_part = -((i + 1) as f32 / d_state as f32).ln();
                let imag_part = 2.0 * std::f32::consts::PI * (i as f32 / d_state as f32);
                a_matrix.push(real_part * imag_part.cos());
            }
        }
        a_matrix
    }

    /// Runs the selective scan over `input` for one layer, updating
    /// `layer.ssm_state` in place and returning an output tensor of the
    /// same length as `input`.
    ///
    /// `input` is treated as `seq_len = input.len() / d_model` frames of
    /// width `d_model`; any trailing partial frame is left as zeros in the
    /// output. NOTE(review): panics on `% 0` if `d_state` is 0 (empty
    /// `a_matrix`) — callers must supply a positive `d_state`.
    async fn selective_scan(
        &self,
        input: &[f32],
        layer: &mut Mamba2Layer,
        config: &Mamba2Config,
    ) -> Result<Vec<f32>> {
        let seq_len = input.len() / config.d_model;
        let mut output = vec![0.0; input.len()];
        for t in 0..seq_len {
            let start_idx = t * config.d_model;
            let end_idx = start_idx + config.d_model;
            let x_t = &input[start_idx..end_idx];
            // Input-dependent discretization step (always > 0).
            let delta_t = self.compute_delta(x_t, &layer.delta)?;
            // Recurrent state update: h <- exp(delta*A) * h + delta*B * x.
            for i in 0..config.d_state.min(config.d_model) {
                let a_val = layer.a_matrix[i % layer.a_matrix.len()];
                let b_val = layer.b_matrix[i % layer.b_matrix.len()];
                let discrete_a = (delta_t * a_val).exp();
                let discrete_b = delta_t * b_val;
                layer.ssm_state[i] =
                    discrete_a * layer.ssm_state[i] + discrete_b * x_t[i % x_t.len()];
            }
            // Readout: y = C * h + D * x (D is a skip connection).
            for i in 0..config.d_model {
                let c_val = layer.c_matrix[i % layer.c_matrix.len()];
                let h_val = layer.ssm_state[i % layer.ssm_state.len()];
                output[start_idx + i] = c_val * h_val + layer.d_param * x_t[i];
            }
        }
        Ok(output)
    }

    /// Computes a positive discretization step from the mean of the
    /// input/delta dot product, clamped to at least 0.001.
    ///
    /// An empty `input` yields NaN from the 0/0 division, but `f32::max`
    /// returns the non-NaN operand, so the result is still 0.001.
    fn compute_delta(&self, input: &[f32], delta_params: &[f32]) -> Result<f32> {
        let sum: f32 = input.iter().zip(delta_params.iter().cycle()).map(|(x, d)| x * d).sum();
        Ok((sum / input.len() as f32).max(0.001))
    }

    /// Runs the full layer stack over `input`, feeding each layer's output
    /// into the next layer.
    ///
    /// Bug fix: previously every layer re-processed the ORIGINAL input and
    /// the per-layer results were overwritten, so all layers except the
    /// last were computed and discarded. The layers are now chained.
    ///
    /// NOTE(review): the chunking strategy is not yet applied — the whole
    /// sequence is processed in one pass. TODO: honor `_strategy`.
    async fn process_with_chunking(
        &self,
        input: &[f32],
        _strategy: &ChunkingStrategy,
    ) -> Result<Vec<f32>> {
        let mut model = self.model.write().await;
        let mut current = input.to_vec();
        for layer in &mut model.layers {
            current = self.selective_scan(&current, layer, &self.config).await?;
        }
        Ok(current)
    }

    /// Assembles per-call metrics from elapsed time and input size.
    /// Time is clamped to >= 1 ms so throughput never divides by zero;
    /// the last three ratios are fixed placeholders pending real telemetry.
    async fn compute_performance_metrics(
        &self,
        start_time: std::time::Instant,
        input_tokens: usize,
        memory_used: f32,
    ) -> Mamba2PerformanceMetrics {
        let inference_time_ms = (start_time.elapsed().as_millis() as f32).max(1.0);
        let tokens_per_second = (input_tokens as f32) / (inference_time_ms / 1000.0);
        Mamba2PerformanceMetrics {
            inference_time_ms,
            tokens_per_second,
            memory_usage_mb: memory_used.max(0.1),
            state_compression_ratio: 0.8,
            hardware_efficiency: 0.92,
            selective_scan_utilization: 0.95,
        }
    }
}
impl Pipeline for Mamba2Pipeline {
    type Input = PipelineInput;
    type Output = Mamba2Output;

    /// Synchronous entry point: spins up a fresh tokio runtime and blocks
    /// on the async implementation.
    ///
    /// NOTE(review): building a runtime per call is costly, and this will
    /// panic if invoked from inside an existing tokio runtime — prefer
    /// calling `process_async` directly from async contexts.
    fn __call__(&self, input: Self::Input) -> Result<Self::Output> {
        let runtime = match tokio::runtime::Runtime::new() {
            Ok(rt) => rt,
            Err(e) => {
                return Err(TrustformersError::runtime_error(format!(
                    "Failed to create tokio runtime: {}",
                    e
                )))
            },
        };
        runtime.block_on(self.process_async(input))
    }
}
impl Mamba2Pipeline {
    /// Zero-pads `values` so its length is a multiple of `d_model`
    /// (no-op when it already is, including the empty case).
    fn pad_to_model_width(values: Vec<f32>, d_model: usize) -> Vec<f32> {
        let mut padded = values;
        let remainder = padded.len() % d_model;
        if remainder != 0 {
            padded.resize(padded.len() + (d_model - remainder), 0.0);
        }
        padded
    }

    /// Asynchronously runs the pipeline on `input` and assembles the output.
    ///
    /// Text is mapped char-by-char (capped at 1024 chars) to floats in
    /// [0, 1) and zero-padded to a multiple of `d_model`. Token input is
    /// scaled by 1/32000 and — bug fix — now also padded to a multiple of
    /// `d_model`; previously a trailing partial chunk was silently never
    /// scanned (`seq_len` truncates), unlike the Text path which padded.
    ///
    /// # Errors
    /// Returns `invalid_input` for unsupported input variants and a
    /// pipeline error if processing fails.
    async fn process_async(&self, input: PipelineInput) -> Result<Mamba2Output> {
        let start_time = std::time::Instant::now();
        let input_tensor = match input {
            PipelineInput::Text(text) => {
                let tokens: Vec<f32> = text
                    .chars()
                    .take(1024) // cap sequence length
                    .map(|c| (c as u32 % 32000) as f32 / 32000.0)
                    .collect();
                Self::pad_to_model_width(tokens, self.config.d_model)
            },
            PipelineInput::Tokens(tokens) => {
                let scaled: Vec<f32> = tokens.into_iter().map(|t| t as f32 / 32000.0).collect();
                Self::pad_to_model_width(scaled, self.config.d_model)
            },
            _ => {
                return Err(TrustformersError::invalid_input(
                    "Unsupported input type for Mamba-2".to_string(),
                    None::<String>,
                    None::<String>,
                    None::<String>,
                ))
            },
        };
        let output_tensor = self
            .process_with_chunking(&input_tensor, &self.config.chunking_strategy)
            .await
            .map_err(|e| {
                TrustformersError::pipeline(format!("Mamba-2 processing failed: {}", e), "mamba2")
            })?;
        // Crude decode: average of each of the first 10 chunks mapped back
        // to an ASCII code point.
        let output_text = output_tensor
            .chunks(self.config.d_model)
            .take(10)
            .map(|chunk| {
                let avg = chunk.iter().sum::<f32>() / chunk.len() as f32;
                char::from_u32((avg * 32000.0) as u32 % 127).unwrap_or(' ')
            })
            .collect::<String>();
        let input_tokens = input_tensor.len() / self.config.d_model;
        // 4 bytes per f32, reported in MB.
        let memory_used = (input_tensor.len() * 4) as f32 / (1024.0 * 1024.0);
        let performance =
            self.compute_performance_metrics(start_time, input_tokens, memory_used).await;
        {
            let mut tracker = self.performance_monitor.write().await;
            tracker.total_tokens_processed += input_tokens as u64;
            // Bug fix: guard the divisor so empty input no longer stores
            // inf/NaN in the tracker.
            tracker.average_latency_per_token =
                performance.inference_time_ms / input_tokens.max(1) as f32;
            tracker.throughput_tokens_per_second = performance.tokens_per_second;
            tracker.memory_usage_mb = performance.memory_usage_mb;
        }
        Ok(Mamba2Output {
            text: output_text,
            logits: output_tensor.clone(),
            hidden_states: output_tensor[..self.config.d_model.min(output_tensor.len())].to_vec(),
            // Placeholder values below pending real state export.
            ssm_states: vec![0.0; self.config.d_state * self.config.d_model],
            performance,
            attention_weights: Some(vec![0.5; input_tokens]),
            state_trajectory: Some(vec![vec![0.0; self.config.d_state]; input_tokens]),
        })
    }
}
impl From<Mamba2Output> for PipelineOutput {
fn from(output: Mamba2Output) -> Self {
PipelineOutput::Mamba2(output)
}
}
pub fn create_high_performance_mamba2_pipeline() -> Result<Mamba2Pipeline> {
let config = Mamba2Config {
d_model: 1024,
d_state: 32,
expand_factor: 4,
d_conv: 8,
n_heads: 8,
hardware_strategy: HardwareStrategy::MaxThroughput,
chunking_strategy: ChunkingStrategy::Overlapping {
chunk_size: 2048,
overlap: 256,
},
memory_optimization: MemoryOptimization::Balanced,
..Default::default()
};
Mamba2Pipeline::new(config)
}
pub fn create_memory_efficient_mamba2_pipeline() -> Result<Mamba2Pipeline> {
let config = Mamba2Config {
d_model: 512,
d_state: 8,
expand_factor: 2,
d_conv: 4,
n_heads: 4,
hardware_strategy: HardwareStrategy::MemoryConstrained,
chunking_strategy: ChunkingStrategy::Adaptive,
memory_optimization: MemoryOptimization::Aggressive,
..Default::default()
};
Mamba2Pipeline::new(config)
}
pub fn create_ultra_long_sequence_mamba2_pipeline() -> Result<Mamba2Pipeline> {
let config = Mamba2Config {
d_model: 768,
d_state: 16,
expand_factor: 2,
d_conv: 4,
n_heads: 1,
hardware_strategy: HardwareStrategy::Auto,
chunking_strategy: ChunkingStrategy::Hierarchical {
levels: vec![1024, 256, 64],
},
memory_optimization: MemoryOptimization::Ultra,
..Default::default()
};
Mamba2Pipeline::new(config)
}
#[cfg(test)]
mod tests {
    //! Unit and async integration tests for the Mamba-2 pipeline.
    //! Private helpers (`initialize_model`, `initialize_a_matrix`,
    //! `selective_scan`, `compute_delta`) are reachable here because this
    //! is a child module of the file under test.
    use super::*;

    // --- Configuration defaults ---

    #[test]
    fn test_config_default_d_model() {
        let config = Mamba2Config::default();
        assert_eq!(config.d_model, 768);
    }

    #[test]
    fn test_config_default_d_state() {
        let config = Mamba2Config::default();
        assert_eq!(config.d_state, 16);
    }

    #[test]
    fn test_config_default_n_heads_one() {
        let config = Mamba2Config::default();
        assert_eq!(config.n_heads, 1);
    }

    #[test]
    fn test_config_dt_rank_auto_calculation() {
        let config = Mamba2Config::default();
        let expected_dt_rank = config.d_model / 16;
        let actual = config.dt_rank.unwrap_or(config.d_model / 16);
        assert_eq!(
            actual, expected_dt_rank,
            "dt_rank should default to d_model/16"
        );
    }

    // --- Model initialization: layer count and buffer dimensions ---

    #[test]
    fn test_model_initialisation_24_layers() {
        let config = Mamba2Config::default();
        let model = Mamba2Pipeline::initialize_model(&config).expect("model init should succeed");
        assert_eq!(model.layers.len(), 24, "model should have 24 layers");
    }

    #[test]
    fn test_model_layer_hidden_state_dimension() {
        let config = Mamba2Config::default();
        let model = Mamba2Pipeline::initialize_model(&config).expect("model init should succeed");
        for layer in &model.layers {
            assert_eq!(
                layer.hidden_state.len(),
                config.d_model,
                "each layer hidden state must match d_model"
            );
        }
    }

    #[test]
    fn test_model_layer_ssm_state_dimension() {
        let config = Mamba2Config::default();
        let model = Mamba2Pipeline::initialize_model(&config).expect("model init should succeed");
        for layer in &model.layers {
            assert_eq!(
                layer.ssm_state.len(),
                config.d_state * config.d_model,
                "SSM state size must be d_state × d_model"
            );
        }
    }

    #[test]
    fn test_model_layer_conv_state_dimension() {
        let config = Mamba2Config::default();
        let model = Mamba2Pipeline::initialize_model(&config).expect("model init should succeed");
        for layer in &model.layers {
            assert_eq!(layer.conv_state.len(), config.d_conv * config.d_model);
        }
    }

    // --- A-matrix initialization ---

    #[test]
    fn test_a_matrix_simplified_init_non_positive() {
        let a = Mamba2Pipeline::initialize_a_matrix(8, true);
        for &val in &a {
            assert!(
                val <= 0.0,
                "simplified A-matrix values should be ≤ 0 (negative log of positive index)"
            );
        }
        // Index 0 is -ln(1) == 0; every later entry must be < 0.
        let strictly_negative = a.iter().filter(|&&v| v < 0.0).count();
        assert!(
            strictly_negative >= a.len() - 1,
            "all A-matrix values except index 0 should be strictly negative"
        );
    }

    #[test]
    fn test_a_matrix_simplified_size() {
        let d_state = 16;
        let a = Mamba2Pipeline::initialize_a_matrix(d_state, true);
        assert_eq!(a.len(), d_state);
    }

    #[test]
    fn test_a_matrix_complex_init_size() {
        let d_state = 8;
        let a = Mamba2Pipeline::initialize_a_matrix(d_state, false);
        assert_eq!(a.len(), d_state);
    }

    #[test]
    fn test_a_matrix_values_finite() {
        let a = Mamba2Pipeline::initialize_a_matrix(16, true);
        for &val in &a {
            assert!(val.is_finite(), "all A-matrix values should be finite");
        }
    }

    // --- Selective scan (uses a tiny 4-wide model for speed) ---

    #[tokio::test]
    async fn test_selective_scan_output_length_matches_input() {
        let config = Mamba2Config {
            d_model: 4,
            d_state: 2,
            d_conv: 2,
            ..Default::default()
        };
        let pipeline =
            Mamba2Pipeline::new(config.clone()).expect("pipeline creation should succeed");
        let input = vec![0.1_f32; 2 * config.d_model];
        // Hold the model write lock to get mutable access to a layer.
        let mut model = pipeline.model.write().await;
        let layer = &mut model.layers[0];
        let result = pipeline
            .selective_scan(&input, layer, &config)
            .await
            .expect("selective scan should succeed");
        assert_eq!(
            result.len(),
            input.len(),
            "selective scan output length should match input length"
        );
    }

    #[tokio::test]
    async fn test_selective_scan_output_finite() {
        let config = Mamba2Config {
            d_model: 4,
            d_state: 2,
            d_conv: 2,
            ..Default::default()
        };
        let pipeline =
            Mamba2Pipeline::new(config.clone()).expect("pipeline creation should succeed");
        let input = vec![0.5_f32; 2 * config.d_model];
        let mut model = pipeline.model.write().await;
        let layer = &mut model.layers[0];
        let result = pipeline
            .selective_scan(&input, layer, &config)
            .await
            .expect("selective scan should succeed");
        for &v in &result {
            assert!(
                v.is_finite(),
                "selective scan outputs should be finite numbers"
            );
        }
    }

    // --- Delta computation ---

    #[test]
    fn test_compute_delta_positive() {
        let config = Mamba2Config::default();
        let pipeline = Mamba2Pipeline::new(config).expect("pipeline creation should succeed");
        let input = vec![0.1_f32; 4];
        let delta_params = vec![0.5_f32; 4];
        let delta = pipeline
            .compute_delta(&input, &delta_params)
            .expect("compute_delta should succeed");
        assert!(delta > 0.0, "delta must always be positive");
    }

    #[test]
    fn test_compute_delta_minimum_clamped() {
        let config = Mamba2Config::default();
        let pipeline = Mamba2Pipeline::new(config).expect("pipeline creation should succeed");
        // All-zero inputs exercise the lower clamp.
        let input = vec![0.0_f32; 4];
        let delta_params = vec![0.0_f32; 4];
        let delta = pipeline
            .compute_delta(&input, &delta_params)
            .expect("compute_delta should succeed");
        assert!(delta >= 0.001, "delta should be clamped to at least 0.001");
    }

    // --- End-to-end pipeline behavior ---

    #[tokio::test]
    async fn test_mamba2_basic_functionality() {
        let pipeline = create_memory_efficient_mamba2_pipeline().expect("operation failed in test");
        let input = PipelineInput::Text("Hello, Mamba-2!".to_string());
        let result = pipeline.process_async(input).await;
        assert!(result.is_ok());
        let output = result.expect("operation failed in test");
        assert!(!output.text.is_empty());
        assert!(!output.logits.is_empty());
        assert!(output.performance.tokens_per_second > 0.0);
    }

    #[tokio::test]
    async fn test_mamba2_chunking_strategies() {
        let config = Mamba2Config {
            chunking_strategy: ChunkingStrategy::Fixed(128),
            ..Default::default()
        };
        let pipeline = Mamba2Pipeline::new(config).expect("operation failed in test");
        let long_text = "A".repeat(1000);
        let input = PipelineInput::Text(long_text);
        let result = pipeline.process_async(input).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_mamba2_performance_tracking() {
        let pipeline = create_high_performance_mamba2_pipeline().expect("operation failed in test");
        let input = PipelineInput::Text("Performance test".to_string());
        let result = pipeline.process_async(input).await.expect("async operation failed");
        assert!(result.performance.inference_time_ms > 0.0);
        assert!(result.performance.memory_usage_mb > 0.0);
        assert!(result.performance.hardware_efficiency > 0.0);
    }

    #[tokio::test]
    async fn test_mamba2_token_input() {
        let config = Mamba2Config {
            d_model: 4,
            d_state: 2,
            d_conv: 2,
            ..Default::default()
        };
        let pipeline = Mamba2Pipeline::new(config).expect("pipeline creation should succeed");
        let tokens = vec![100_u32, 200, 300, 400];
        let input = PipelineInput::Tokens(tokens);
        let result = pipeline.process_async(input).await;
        assert!(result.is_ok(), "token input should be accepted");
    }

    #[tokio::test]
    async fn test_mamba2_batch_text_input_rejected() {
        let pipeline =
            create_memory_efficient_mamba2_pipeline().expect("pipeline creation should succeed");
        let input = PipelineInput::BatchText(vec!["a".to_string(), "b".to_string()]);
        let result = pipeline.process_async(input).await;
        assert!(
            result.is_err(),
            "BatchText input should be rejected by Mamba2 pipeline"
        );
    }

    #[tokio::test]
    async fn test_mamba2_output_ssm_states_correct_size() {
        let config = Mamba2Config {
            d_model: 8,
            d_state: 4,
            d_conv: 2,
            ..Default::default()
        };
        let pipeline =
            Mamba2Pipeline::new(config.clone()).expect("pipeline creation should succeed");
        let result = pipeline
            .process_async(PipelineInput::Text("abc".to_string()))
            .await
            .expect("processing should succeed");
        assert_eq!(
            result.ssm_states.len(),
            config.d_state * config.d_model,
            "SSM states size should be d_state × d_model"
        );
    }

    #[tokio::test]
    async fn test_mamba2_output_has_attention_weights() {
        let pipeline =
            create_memory_efficient_mamba2_pipeline().expect("pipeline creation should succeed");
        let result = pipeline
            .process_async(PipelineInput::Text("test".to_string()))
            .await
            .expect("processing should succeed");
        assert!(
            result.attention_weights.is_some(),
            "attention_weights should be present"
        );
    }

    #[tokio::test]
    async fn test_mamba2_performance_metrics_scan_utilization() {
        let pipeline =
            create_memory_efficient_mamba2_pipeline().expect("pipeline creation should succeed");
        let result = pipeline
            .process_async(PipelineInput::Text("scan utilization test".to_string()))
            .await
            .expect("processing should succeed");
        assert!(result.performance.selective_scan_utilization > 0.0);
    }

    #[tokio::test]
    async fn test_ultra_long_sequence_pipeline_basic() {
        let pipeline = create_ultra_long_sequence_mamba2_pipeline()
            .expect("ultra long sequence pipeline should be created");
        let input = PipelineInput::Text("Ultra long test input".to_string());
        let result = pipeline.process_async(input).await;
        assert!(
            result.is_ok(),
            "ultra long sequence pipeline should process successfully"
        );
    }

    #[tokio::test]
    async fn test_from_mamba2_output_to_pipeline_output() {
        let pipeline =
            create_memory_efficient_mamba2_pipeline().expect("pipeline creation should succeed");
        let mamba_out = pipeline
            .process_async(PipelineInput::Text("hi".to_string()))
            .await
            .expect("processing should succeed");
        let pipeline_output: PipelineOutput = mamba_out.into();
        assert!(matches!(pipeline_output, PipelineOutput::Mamba2(_)));
    }
}