irithyll/ssm/
mamba_config.rs1use std::fmt;
19
20use crate::error::ConfigError;
21
/// Validated configuration for the Mamba model in `irithyll::ssm`.
///
/// Build instances with [`MambaConfig::builder`], whose `build` step checks every
/// constrained field. The fields are public, so a value assembled or mutated
/// directly bypasses that validation.
#[derive(Debug, Clone, PartialEq)]
pub struct MambaConfig {
    /// Input feature dimension (required by the builder; must be >= 1).
    pub d_in: usize,
    /// State size; the builder requires >= 1 (default 32).
    /// NOTE(review): presumably the SSM hidden-state dimension — confirm.
    pub n_state: usize,
    /// Forgetting factor; the builder requires a value in `(0, 1]` (default 0.998).
    /// NOTE(review): presumably used by an RLS readout — confirm.
    pub forgetting_factor: f64,
    /// RLS initialization scale; the builder requires a value > 0 (default 100.0).
    /// NOTE(review): presumably the initial covariance `P = delta * I` — confirm.
    pub delta_rls: f64,
    /// Seed value (default 42). NOTE(review): presumably seeds random initialization — confirm.
    pub seed: u64,
    /// Warm-up count; not validated (default 10).
    /// NOTE(review): presumably samples to observe before outputs are trusted — confirm.
    pub warmup: usize,
}

52impl MambaConfig {
53 pub fn builder() -> MambaConfigBuilder {
57 MambaConfigBuilder::default()
58 }
59}
60
61impl fmt::Display for MambaConfig {
62 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
63 write!(
64 f,
65 "MambaConfig(d_in={}, n_state={}, ff={}, delta={}, seed={}, warmup={})",
66 self.d_in, self.n_state, self.forgetting_factor, self.delta_rls, self.seed, self.warmup
67 )
68 }
69}
70
/// Builder for [`MambaConfig`].
///
/// Obtain one from [`MambaConfig::builder`] or [`MambaConfigBuilder::new`], chain the
/// setters, then call [`MambaConfigBuilder::build`], which validates the values.
/// Deriving `Clone` lets a partially configured builder serve as a template for
/// several configurations.
#[derive(Debug, Clone)]
pub struct MambaConfigBuilder {
    /// Input feature dimension; `None` until set, and `build` fails while it is `None`.
    d_in: Option<usize>,
    /// State size; `build` rejects 0.
    n_state: usize,
    /// Forgetting factor; `build` requires a value in `(0, 1]`.
    forgetting_factor: f64,
    /// RLS initialization scale; `build` requires a value > 0.
    delta_rls: f64,
    /// Seed copied verbatim into the built config.
    seed: u64,
    /// Warm-up count copied verbatim into the built config.
    warmup: usize,
}

102impl Default for MambaConfigBuilder {
103 fn default() -> Self {
104 Self {
105 d_in: None,
106 n_state: 32,
107 forgetting_factor: 0.998,
108 delta_rls: 100.0,
109 seed: 42,
110 warmup: 10,
111 }
112 }
113}
114
115impl MambaConfigBuilder {
116 pub fn new() -> Self {
118 Self::default()
119 }
120
121 pub fn d_in(mut self, d_in: usize) -> Self {
123 self.d_in = Some(d_in);
124 self
125 }
126
127 pub fn n_state(mut self, n_state: usize) -> Self {
129 self.n_state = n_state;
130 self
131 }
132
133 pub fn forgetting_factor(mut self, ff: f64) -> Self {
135 self.forgetting_factor = ff;
136 self
137 }
138
139 pub fn delta_rls(mut self, delta: f64) -> Self {
141 self.delta_rls = delta;
142 self
143 }
144
145 pub fn seed(mut self, seed: u64) -> Self {
147 self.seed = seed;
148 self
149 }
150
151 pub fn warmup(mut self, warmup: usize) -> Self {
153 self.warmup = warmup;
154 self
155 }
156
157 pub fn build(self) -> Result<MambaConfig, ConfigError> {
167 let d_in = self.d_in.ok_or_else(|| {
168 ConfigError::invalid("d_in", "d_in must be set (input feature dimension)")
169 })?;
170 if d_in < 1 {
171 return Err(ConfigError::out_of_range("d_in", "must be >= 1", d_in));
172 }
173 if self.n_state < 1 {
174 return Err(ConfigError::out_of_range(
175 "n_state",
176 "must be >= 1",
177 self.n_state,
178 ));
179 }
180 if self.forgetting_factor <= 0.0 || self.forgetting_factor > 1.0 {
181 return Err(ConfigError::out_of_range(
182 "forgetting_factor",
183 "must be in (0, 1]",
184 self.forgetting_factor,
185 ));
186 }
187 if self.delta_rls <= 0.0 {
188 return Err(ConfigError::out_of_range(
189 "delta_rls",
190 "must be > 0",
191 self.delta_rls,
192 ));
193 }
194
195 Ok(MambaConfig {
196 d_in,
197 n_state: self.n_state,
198 forgetting_factor: self.forgetting_factor,
199 delta_rls: self.delta_rls,
200 seed: self.seed,
201 warmup: self.warmup,
202 })
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparing stored `f64` settings against literals.
    const EPS: f64 = 1e-12;

    #[test]
    fn builder_defaults() {
        let config = MambaConfig::builder().d_in(4).build().unwrap();
        assert_eq!(config.d_in, 4);
        assert_eq!(config.n_state, 32);
        assert!((config.forgetting_factor - 0.998).abs() < EPS);
        assert!((config.delta_rls - 100.0).abs() < EPS);
        assert_eq!(config.seed, 42);
        assert_eq!(config.warmup, 10);
    }

    #[test]
    fn builder_custom_values() {
        // Every value differs from its default so each setter is actually exercised.
        let config = MambaConfig::builder()
            .d_in(8)
            .n_state(16)
            .forgetting_factor(0.99)
            .delta_rls(50.0)
            .seed(123)
            .warmup(5)
            .build()
            .unwrap();
        assert_eq!(config.d_in, 8);
        assert_eq!(config.n_state, 16);
        assert!((config.forgetting_factor - 0.99).abs() < EPS);
        assert!((config.delta_rls - 50.0).abs() < EPS);
        assert_eq!(config.seed, 123);
        assert_eq!(config.warmup, 5);
    }

    #[test]
    fn builder_missing_d_in() {
        let result = MambaConfig::builder().build();
        assert!(result.is_err(), "should fail without d_in");
    }

    #[test]
    fn builder_invalid_d_in_zero() {
        let result = MambaConfig::builder().d_in(0).build();
        assert!(result.is_err(), "d_in=0 should be invalid");
    }

    #[test]
    fn builder_invalid_n_state() {
        let result = MambaConfig::builder().d_in(4).n_state(0).build();
        assert!(result.is_err(), "n_state=0 should be invalid");
    }

    #[test]
    fn builder_invalid_forgetting_factor_zero() {
        let result = MambaConfig::builder()
            .d_in(4)
            .forgetting_factor(0.0)
            .build();
        assert!(result.is_err(), "ff=0 should be invalid");
    }

    #[test]
    fn builder_invalid_forgetting_factor_negative() {
        let result = MambaConfig::builder()
            .d_in(4)
            .forgetting_factor(-0.5)
            .build();
        assert!(result.is_err(), "ff=-0.5 should be invalid");
    }

    #[test]
    fn builder_invalid_forgetting_factor_over_one() {
        let result = MambaConfig::builder()
            .d_in(4)
            .forgetting_factor(1.01)
            .build();
        assert!(result.is_err(), "ff=1.01 should be invalid");
    }

    #[test]
    fn builder_forgetting_factor_one_valid() {
        let config = MambaConfig::builder()
            .d_in(4)
            .forgetting_factor(1.0)
            .build()
            .unwrap();
        assert!((config.forgetting_factor - 1.0).abs() < EPS);
    }

    #[test]
    fn builder_invalid_delta_rls() {
        let result = MambaConfig::builder().d_in(4).delta_rls(0.0).build();
        assert!(result.is_err(), "delta_rls=0 should be invalid");
        let result = MambaConfig::builder().d_in(4).delta_rls(-1.0).build();
        assert!(result.is_err(), "delta_rls=-1 should be invalid");
    }

    #[test]
    fn display_format() {
        let config = MambaConfig::builder().d_in(4).build().unwrap();
        let s = config.to_string();
        assert!(s.contains("d_in=4"), "display should contain d_in");
        assert!(s.contains("n_state=32"), "display should contain n_state");
        // Pin the full rendering of the default configuration.
        assert_eq!(
            s,
            "MambaConfig(d_in=4, n_state=32, ff=0.998, delta=100, seed=42, warmup=10)"
        );
    }

    #[test]
    fn config_clone() {
        let config = MambaConfig::builder().d_in(4).seed(99).build().unwrap();
        let cloned = config.clone();
        assert_eq!(cloned.d_in, config.d_in);
        assert_eq!(cloned.seed, config.seed);
    }
}