/// Plain settings bundle for an S4 layer.
///
/// The type only stores values: nothing in this definition validates them or
/// shows how consumers interpret them. The per-field notes below are inferred
/// from the field names and should be confirmed against the code that reads
/// this config.
#[derive(Debug, Clone)]
pub struct S4Config {
    /// Presumably the model (feature) width. Default: `64`.
    pub d_model: usize,
    /// Presumably the state-space dimension. Default: `16`.
    pub d_state: usize,
    /// Presumably the sequence length the layer is set up for. Default: `128`.
    pub seq_len: usize,
    /// Lower member of the `dt_min`/`dt_max` pair (looks like a step-size
    /// bound — TODO confirm). Default: `0.001`.
    pub dt_min: f64,
    /// Upper member of the `dt_min`/`dt_max` pair. Default: `0.1`.
    pub dt_max: f64,
    /// Flag whose name suggests processing in both directions. Default: `false`.
    pub bidirectional: bool,
    /// Presumably a dropout rate. Default: `0.0`.
    pub dropout: f64,
}

impl Default for S4Config {
    /// Returns the baseline configuration: `d_model = 64`, `d_state = 16`,
    /// `seq_len = 128`, `dt_min = 0.001`, `dt_max = 0.1`,
    /// `bidirectional = false`, `dropout = 0.0`.
    fn default() -> Self {
        // Integer-valued settings.
        let (d_model, d_state, seq_len) = (64, 16, 128);
        // The `dt_*` pair.
        let (dt_min, dt_max) = (0.001, 0.1);

        S4Config {
            d_model,
            d_state,
            seq_len,
            dt_min,
            dt_max,
            bidirectional: false,
            dropout: 0.0,
        }
    }
}
/// Plain settings bundle for a Mamba model.
///
/// The type only stores values; nothing here validates them. Apart from the
/// relationships established by the methods below (`d_inner`, the derived
/// `dt_rank`), the per-field notes are inferred from the field names and
/// should be confirmed against the code that consumes this config.
#[derive(Debug, Clone)]
pub struct MambaConfig {
    /// Presumably the model (feature) width. Default: `64`.
    pub d_model: usize,
    /// Presumably the state-space dimension. Default: `16`.
    pub d_state: usize,
    /// Presumably the convolution kernel width. Default: `4`.
    pub d_conv: usize,
    /// Multiplier applied to `d_model` by [`MambaConfig::d_inner`]. Default: `2`.
    pub expand: usize,
    /// Rank value; the provided constructors set it to `ceil(d_model / 16)`.
    pub dt_rank: usize,
    /// Presumably the sequence length the model is set up for. Default: `512`.
    pub seq_len: usize,
    /// Presumably the number of stacked layers. Default: `4`.
    pub n_layers: usize,
    /// Presumably a dropout rate. Default: `0.0`.
    pub dropout: f64,
}

impl Default for MambaConfig {
    /// Returns the baseline configuration: `d_model = 64`, `d_state = 16`,
    /// `n_layers = 4`, with every other field taken from
    /// [`MambaConfig::with_auto_dt_rank`] (so `dt_rank = 4`).
    fn default() -> Self {
        // Delegate so the fixed fields and the rank formula live in one place.
        Self::with_auto_dt_rank(64, 16, 4)
    }
}

impl MambaConfig {
    /// Returns `expand * d_model`.
    #[inline]
    pub fn d_inner(&self) -> usize {
        self.expand * self.d_model
    }

    /// Builds a config from the three given sizes, deriving `dt_rank` as
    /// `ceil(d_model / 16)` and fixing the remaining fields to
    /// `d_conv = 4`, `expand = 2`, `seq_len = 512`, `dropout = 0.0`.
    ///
    /// The rank computation cannot overflow, even for `d_model` values close
    /// to `usize::MAX`.
    pub fn with_auto_dt_rank(d_model: usize, d_state: usize, n_layers: usize) -> Self {
        Self {
            d_model,
            d_state,
            d_conv: 4,
            expand: 2,
            dt_rank: Self::auto_dt_rank(d_model),
            seq_len: 512,
            n_layers,
            dropout: 0.0,
        }
    }

    /// Ceiling of `d_model / 16`.
    ///
    /// Written as quotient plus a remainder correction rather than
    /// `(d_model + 15) / 16`, because the latter overflows when `d_model` is
    /// within 15 of `usize::MAX`.
    fn auto_dt_rank(d_model: usize) -> usize {
        d_model / 16 + usize::from(d_model % 16 != 0)
    }
}