Skip to main content

irithyll/ssm/
mamba_config.rs

1//! Configuration and builder for [`StreamingMamba`](super::StreamingMamba).
2//!
3//! [`MambaConfig`] holds all hyperparameters for the streaming Mamba model.
4//! Use [`MambaConfigBuilder`] (via [`MambaConfig::builder()`]) for validated
5//! construction with sensible defaults.
6//!
7//! # Defaults
8//!
9//! | Parameter | Default | Description |
10//! |-----------|---------|-------------|
11//! | `d_in` | (required) | Input feature dimension |
12//! | `n_state` | 32 | Hidden state dimension per channel |
13//! | `forgetting_factor` | 0.998 | RLS exponential forgetting |
14//! | `delta_rls` | 100.0 | Initial P matrix diagonal for RLS |
15//! | `seed` | 42 | PRNG seed for SSM weight initialization |
16//! | `warmup` | 10 | Samples before RLS predictions are trusted |
17
18use std::fmt;
19
20use crate::error::ConfigError;
21
/// Hyperparameters of a [`StreamingMamba`](super::StreamingMamba) model.
///
/// The usual way to obtain one is the builder, which supplies defaults and
/// validates every value in [`MambaConfigBuilder::build`]:
///
/// ```
/// use irithyll::ssm::MambaConfig;
///
/// let config = MambaConfig::builder()
///     .d_in(8)
///     .n_state(16)
///     .forgetting_factor(0.998)
///     .build()
///     .unwrap();
/// ```
///
/// All fields are public, so a struct literal can also be written directly;
/// doing so skips the builder's validation.
#[derive(Clone, Debug)]
pub struct MambaConfig {
    /// Feature dimension of the input and output. Required; must be >= 1.
    pub d_in: usize,
    /// Size of the hidden state kept per channel. Default 32; must be >= 1.
    pub n_state: usize,
    /// Forgetting factor used by RLS. Default 0.998; must lie in (0, 1].
    pub forgetting_factor: f64,
    /// Value placed on the diagonal of the initial RLS P matrix.
    /// Default 100.0; must be > 0.
    pub delta_rls: f64,
    /// Seed for the random initialization of the SSM weights. Default 42.
    pub seed: u64,
    /// How many samples to see before predictions are trusted.
    /// Default 10; any value (including 0) is accepted.
    pub warmup: usize,
}
51
52impl MambaConfig {
53    /// Create a new builder with default values.
54    ///
55    /// Only `d_in` is required; all other parameters have sensible defaults.
56    pub fn builder() -> MambaConfigBuilder {
57        MambaConfigBuilder::default()
58    }
59}
60
61impl fmt::Display for MambaConfig {
62    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
63        write!(
64            f,
65            "MambaConfig(d_in={}, n_state={}, ff={}, delta={}, seed={}, warmup={})",
66            self.d_in, self.n_state, self.forgetting_factor, self.delta_rls, self.seed, self.warmup
67        )
68    }
69}
70
/// Validating builder that produces a [`MambaConfig`].
///
/// # Required Parameters
///
/// - `d_in` -- has no default; `build()` returns an error unless it was set
///
/// # Example
///
/// ```
/// use irithyll::ssm::MambaConfig;
///
/// let config = MambaConfig::builder()
///     .d_in(4)
///     .n_state(32)
///     .seed(123)
///     .build()
///     .unwrap();
///
/// assert_eq!(config.d_in, 4);
/// assert_eq!(config.n_state, 32);
/// ```
#[derive(Debug)]
pub struct MambaConfigBuilder {
    // Stays `None` until the caller provides the input dimension.
    d_in: Option<usize>,
    // The remaining fields always hold a value: either the default or
    // whatever the corresponding setter stored last.
    n_state: usize,
    forgetting_factor: f64,
    delta_rls: f64,
    seed: u64,
    warmup: usize,
}
101
102impl Default for MambaConfigBuilder {
103    fn default() -> Self {
104        Self {
105            d_in: None,
106            n_state: 32,
107            forgetting_factor: 0.998,
108            delta_rls: 100.0,
109            seed: 42,
110            warmup: 10,
111        }
112    }
113}
114
115impl MambaConfigBuilder {
116    /// Create a new builder with default values.
117    pub fn new() -> Self {
118        Self::default()
119    }
120
121    /// Set the input feature dimension (required, >= 1).
122    pub fn d_in(mut self, d_in: usize) -> Self {
123        self.d_in = Some(d_in);
124        self
125    }
126
127    /// Set the hidden state dimension per channel (default: 32, >= 1).
128    pub fn n_state(mut self, n_state: usize) -> Self {
129        self.n_state = n_state;
130        self
131    }
132
133    /// Set the RLS forgetting factor (default: 0.998, must be in (0, 1]).
134    pub fn forgetting_factor(mut self, ff: f64) -> Self {
135        self.forgetting_factor = ff;
136        self
137    }
138
139    /// Set the initial P matrix diagonal for RLS (default: 100.0, must be > 0).
140    pub fn delta_rls(mut self, delta: f64) -> Self {
141        self.delta_rls = delta;
142        self
143    }
144
145    /// Set the random seed for SSM weight initialization (default: 42).
146    pub fn seed(mut self, seed: u64) -> Self {
147        self.seed = seed;
148        self
149    }
150
151    /// Set the warmup period in samples (default: 10).
152    pub fn warmup(mut self, warmup: usize) -> Self {
153        self.warmup = warmup;
154        self
155    }
156
157    /// Build the config, validating all parameters.
158    ///
159    /// # Errors
160    ///
161    /// Returns [`ConfigError`] if:
162    /// - `d_in` was not set or is 0
163    /// - `n_state` is 0
164    /// - `forgetting_factor` is not in (0, 1]
165    /// - `delta_rls` is not positive
166    pub fn build(self) -> Result<MambaConfig, ConfigError> {
167        let d_in = self.d_in.ok_or_else(|| {
168            ConfigError::invalid("d_in", "d_in must be set (input feature dimension)")
169        })?;
170        if d_in < 1 {
171            return Err(ConfigError::out_of_range("d_in", "must be >= 1", d_in));
172        }
173        if self.n_state < 1 {
174            return Err(ConfigError::out_of_range(
175                "n_state",
176                "must be >= 1",
177                self.n_state,
178            ));
179        }
180        if self.forgetting_factor <= 0.0 || self.forgetting_factor > 1.0 {
181            return Err(ConfigError::out_of_range(
182                "forgetting_factor",
183                "must be in (0, 1]",
184                self.forgetting_factor,
185            ));
186        }
187        if self.delta_rls <= 0.0 {
188            return Err(ConfigError::out_of_range(
189                "delta_rls",
190                "must be > 0",
191                self.delta_rls,
192            ));
193        }
194
195        Ok(MambaConfig {
196            d_in,
197            n_state: self.n_state,
198            forgetting_factor: self.forgetting_factor,
199            delta_rls: self.delta_rls,
200            seed: self.seed,
201            warmup: self.warmup,
202        })
203    }
204}
205
// Unit tests for the builder's defaults, setters, validation errors, and the
// `Display` / `Clone` impls of `MambaConfig`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn builder_defaults() {
        let config = MambaConfig::builder().d_in(4).build().unwrap();
        assert_eq!(config.d_in, 4);
        assert_eq!(config.n_state, 32);
        assert!((config.forgetting_factor - 0.998).abs() < 1e-12);
        assert!((config.delta_rls - 100.0).abs() < 1e-12);
        assert_eq!(config.seed, 42);
        assert_eq!(config.warmup, 10);
    }

    #[test]
    fn builder_custom_values() {
        // Every value here differs from its default so that a setter which
        // silently does nothing would make the test fail.
        let config = MambaConfig::builder()
            .d_in(8)
            .n_state(16)
            .forgetting_factor(0.99)
            .delta_rls(50.0)
            .seed(123)
            .warmup(5)
            .build()
            .unwrap();
        assert_eq!(config.d_in, 8);
        assert_eq!(config.n_state, 16);
        assert!((config.forgetting_factor - 0.99).abs() < 1e-12);
        assert!((config.delta_rls - 50.0).abs() < 1e-12);
        assert_eq!(config.seed, 123);
        assert_eq!(config.warmup, 5);
    }

    #[test]
    fn builder_missing_d_in() {
        let result = MambaConfig::builder().build();
        assert!(result.is_err(), "should fail without d_in");
    }

    #[test]
    fn builder_invalid_d_in_zero() {
        let result = MambaConfig::builder().d_in(0).build();
        assert!(result.is_err(), "d_in=0 should be invalid");
    }

    #[test]
    fn builder_invalid_n_state() {
        let result = MambaConfig::builder().d_in(4).n_state(0).build();
        assert!(result.is_err(), "n_state=0 should be invalid");
    }

    #[test]
    fn builder_invalid_forgetting_factor_zero() {
        let result = MambaConfig::builder()
            .d_in(4)
            .forgetting_factor(0.0)
            .build();
        assert!(result.is_err(), "ff=0 should be invalid");
    }

    #[test]
    fn builder_invalid_forgetting_factor_negative() {
        let result = MambaConfig::builder()
            .d_in(4)
            .forgetting_factor(-0.5)
            .build();
        assert!(result.is_err(), "ff=-0.5 should be invalid");
    }

    #[test]
    fn builder_invalid_forgetting_factor_over_one() {
        let result = MambaConfig::builder()
            .d_in(4)
            .forgetting_factor(1.01)
            .build();
        assert!(result.is_err(), "ff=1.01 should be invalid");
    }

    #[test]
    fn builder_forgetting_factor_one_valid() {
        let config = MambaConfig::builder()
            .d_in(4)
            .forgetting_factor(1.0)
            .build()
            .unwrap();
        assert!((config.forgetting_factor - 1.0).abs() < 1e-12);
    }

    #[test]
    fn builder_invalid_delta_rls() {
        let result = MambaConfig::builder().d_in(4).delta_rls(0.0).build();
        assert!(result.is_err(), "delta_rls=0 should be invalid");
        let result = MambaConfig::builder().d_in(4).delta_rls(-1.0).build();
        assert!(result.is_err(), "delta_rls=-1 should be invalid");
    }

    #[test]
    fn builder_warmup_zero_valid() {
        // `warmup` has no lower bound, so zero must be accepted.
        let config = MambaConfig::builder().d_in(4).warmup(0).build().unwrap();
        assert_eq!(config.warmup, 0);
    }

    #[test]
    fn display_format() {
        let config = MambaConfig::builder().d_in(4).build().unwrap();
        let s = config.to_string();
        assert!(s.contains("d_in=4"), "display should contain d_in");
        assert!(s.contains("n_state=32"), "display should contain n_state");
        assert!(s.contains("seed=42"), "display should contain seed");
        assert!(s.contains("warmup=10"), "display should contain warmup");
    }

    #[test]
    fn config_clone() {
        let config = MambaConfig::builder().d_in(4).seed(99).build().unwrap();
        let cloned = config.clone();
        assert_eq!(cloned.d_in, config.d_in);
        assert_eq!(cloned.seed, config.seed);
    }
}