// mamba_rs/lib.rs
//! # mamba-rs
//!
//! Mamba SSM (Selective State Space Model) implementation in Rust.
//!
//! Provides both CPU and GPU (CUDA) paths for inference and training,
//! including full backward pass with BPTT through recurrent state.
//!
//! Based on: Gu & Dao, "Mamba: Linear-Time Sequence Modeling with
//! Selective State Spaces" (COLM 2024).
//!
//! ## Module Structure
//!
//! - [`mamba_ssm`] — Mamba-1 implementation (CPU + GPU)
//!   - `cpu/` — inference, forward, backward
//!   - `gpu/` — CUDA inference, forward, backward
//! - [`ops`] — shared operations (dims, BLAS, math)
//! - [`module`] — high-level MambaBackbone API
//! - [`config`], [`state`], [`weights`], [`serialize`] — data types

20pub mod config;
21pub mod mamba_ssm;
22pub mod module;
23pub mod ops;
24pub mod serialize;
25pub mod state;
26pub mod weights;

// Re-export old paths for backward compatibility during transition.
// These will be removed once all external users migrate.
30pub mod inference {
31    pub use crate::mamba_ssm::cpu::inference::*;
32}
33pub mod train {
34    pub use crate::mamba_ssm::cpu::backward;
35    pub use crate::mamba_ssm::cpu::backward_ops;
36    pub use crate::mamba_ssm::cpu::flat;
37    pub use crate::mamba_ssm::cpu::forward;
38    pub use crate::mamba_ssm::cpu::parallel;
39    pub use crate::mamba_ssm::cpu::scratch;
40    pub use crate::mamba_ssm::cpu::target;
41    pub use crate::mamba_ssm::cpu::weights;
42
43    // Re-export shared ops that were previously in train/
44    pub use crate::ops::blas;
45    pub use crate::ops::fast_math;
46}
/// Deprecated location: forwards everything from [`mamba_ssm::gpu`].
/// Only compiled when the `cuda` feature is enabled.
#[cfg(feature = "cuda")]
pub mod gpu {
    pub use crate::mamba_ssm::gpu::*;
}
52
53pub use config::MambaConfig;
54pub use mamba_ssm::cpu::inference::{
55    MambaLayerScratch, MambaStepScratch, mamba_block_step, mamba_layer_step, mamba_step,
56};
57pub use module::MambaBackbone;
58pub use state::{MambaLayerState, MambaState};
59pub use weights::{MambaLayerWeights, MambaWeights};