1use crate::error::{NeuralDecoderError, Result};
14use ndarray::{Array1, Array2, ArrayView1};
15use rand::Rng;
16use rand_distr::{Distribution, Normal};
17use serde::{Deserialize, Serialize};
18
/// Configuration for the Mamba-style selective state-space decoder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MambaConfig {
    /// Dimensionality of each input embedding vector.
    pub input_dim: usize,
    /// Dimensionality of the recurrent SSM hidden state.
    pub state_dim: usize,
    /// Number of output logits / probabilities produced by the decoder.
    pub output_dim: usize,
}
29
30impl Default for MambaConfig {
31 fn default() -> Self {
32 Self {
33 input_dim: 128,
34 state_dim: 64,
35 output_dim: 25, }
37 }
38}
39
/// Recurrent hidden state carried across decoding steps.
#[derive(Debug, Clone)]
pub struct MambaState {
    /// Current hidden-state vector (length == `dim`).
    pub hidden: Vec<f32>,
    /// Dimensionality of `hidden`.
    pub dim: usize,
    /// Number of updates applied since creation or the last reset.
    pub steps: usize,
}
50
51impl MambaState {
52 pub fn new(dim: usize) -> Self {
54 Self {
55 hidden: vec![0.0; dim],
56 dim,
57 steps: 0,
58 }
59 }
60
61 pub fn reset(&mut self) {
63 self.hidden.fill(0.0);
64 self.steps = 0;
65 }
66
67 pub fn get(&self) -> &[f32] {
69 &self.hidden
70 }
71
72 pub fn update(&mut self, new_state: Vec<f32>) {
74 assert_eq!(new_state.len(), self.dim);
75 self.hidden = new_state;
76 self.steps += 1;
77 }
78}
79
/// A dense (fully-connected) layer computing `y = W x + b`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Linear {
    /// Weight matrix with shape `(output_dim, input_dim)`.
    weights: Array2<f32>,
    /// Bias vector with `output_dim` elements.
    bias: Array1<f32>,
}
86
87impl Linear {
88 fn new(input_dim: usize, output_dim: usize) -> Self {
89 let mut rng = rand::thread_rng();
90 let scale = (2.0 / (input_dim + output_dim) as f32).sqrt();
91 let normal = Normal::new(0.0, scale as f64).unwrap();
92
93 let weights = Array2::from_shape_fn(
94 (output_dim, input_dim),
95 |_| normal.sample(&mut rng) as f32
96 );
97 let bias = Array1::zeros(output_dim);
98
99 Self { weights, bias }
100 }
101
102 fn forward(&self, input: &[f32]) -> Vec<f32> {
103 let x = ArrayView1::from(input);
104 let output = self.weights.dot(&x) + &self.bias;
105 output.to_vec()
106 }
107}
108
/// Input-dependent (selective) state-space scan parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SelectiveScan {
    /// Projects the input to the raw per-channel step size (delta).
    delta_proj: Linear,
    /// Projects the input to the input-to-state coefficients (B).
    b_proj: Linear,
    /// Projects the input to the state-to-output coefficients (C).
    c_proj: Linear,
    /// Scale applied to the raw delta before the softplus.
    delta_scale: f32,
    /// Dimensionality of the recurrent state.
    state_dim: usize,
}
123
124impl SelectiveScan {
125 fn new(input_dim: usize, state_dim: usize) -> Self {
126 Self {
127 delta_proj: Linear::new(input_dim, state_dim),
128 b_proj: Linear::new(input_dim, state_dim),
129 c_proj: Linear::new(input_dim, state_dim),
130 delta_scale: 0.1,
131 state_dim,
132 }
133 }
134
135 fn step(&self, input: &[f32], state: &[f32]) -> (Vec<f32>, Vec<f32>) {
137 let delta_raw = self.delta_proj.forward(input);
139 let b = self.b_proj.forward(input);
140 let c = self.c_proj.forward(input);
141
142 let delta: Vec<f32> = delta_raw.iter()
144 .map(|&x| (1.0 + (x * self.delta_scale).exp()).ln())
145 .collect();
146
147 let input_norm: f32 = input.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-6);
150
151 let mut new_state = vec![0.0; self.state_dim];
152 for i in 0..self.state_dim {
153 let decay = (-delta[i]).exp();
154 let input_contrib = delta[i] * b[i] * (input_norm / (self.state_dim as f32).sqrt());
155 new_state[i] = decay * state[i] + input_contrib;
156 }
157
158 let output: f32 = c.iter().zip(new_state.iter())
160 .map(|(ci, xi)| ci * xi)
161 .sum();
162
163 let output_vec = vec![output / (self.state_dim as f32).sqrt(); input.len()];
165
166 (new_state, output_vec)
167 }
168}
169
/// One Mamba block: input projection, selective scan, sigmoid gating,
/// output projection, and a residual connection.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct MambaBlock {
    /// Projects the input embedding into the SSM's state dimension.
    in_proj: Linear,
    /// The selective state-space recurrence.
    ssm: SelectiveScan,
    /// Produces the (pre-sigmoid) gate values from the raw input.
    gate_proj: Linear,
    /// Projects the gated SSM output back to the input dimension.
    out_proj: Linear,
    /// Per-channel scale, initialized to ones. NOTE(review): never read in
    /// `forward` as written — confirm whether normalization was intended.
    norm: Array1<f32>,
    /// Dimensionality of the recurrent state.
    state_dim: usize,
}
186
187impl MambaBlock {
188 fn new(input_dim: usize, state_dim: usize) -> Self {
189 Self {
190 in_proj: Linear::new(input_dim, state_dim),
191 ssm: SelectiveScan::new(state_dim, state_dim),
192 gate_proj: Linear::new(input_dim, state_dim),
193 out_proj: Linear::new(state_dim, input_dim),
194 norm: Array1::ones(state_dim),
195 state_dim,
196 }
197 }
198
199 fn forward(&self, input: &[f32], state: &[f32]) -> (Vec<f32>, Vec<f32>) {
200 let x = self.in_proj.forward(input);
202
203 let (new_state, ssm_out) = self.ssm.step(&x, state);
205
206 let gate_raw = self.gate_proj.forward(input);
208 let gate: Vec<f32> = gate_raw.iter()
209 .map(|&g| 1.0 / (1.0 + (-g).exp()))
210 .collect();
211
212 let gated: Vec<f32> = ssm_out.iter().zip(gate.iter().cycle())
214 .map(|(s, g)| s * g)
215 .collect();
216
217 let output_raw = self.out_proj.forward(&gated[..self.state_dim.min(gated.len())]);
219
220 let output: Vec<f32> = input.iter().zip(output_raw.iter().cycle())
222 .map(|(i, o)| i + o)
223 .collect();
224
225 (new_state, output)
226 }
227}
228
/// Sequence decoder that runs a single Mamba block over node embeddings
/// and projects the mean-pooled result to per-class probabilities.
#[derive(Debug, Clone)]
pub struct MambaDecoder {
    /// Decoder hyperparameters.
    config: MambaConfig,
    /// The recurrent Mamba block.
    block: MambaBlock,
    /// Final projection from `input_dim` to `output_dim` logits.
    output_proj: Linear,
    /// Recurrent state carried across `decode`/`decode_step` calls.
    state: MambaState,
}
237
238impl MambaDecoder {
239 pub fn new(config: MambaConfig) -> Self {
241 let block = MambaBlock::new(config.input_dim, config.state_dim);
242 let output_proj = Linear::new(config.input_dim, config.output_dim);
243 let state = MambaState::new(config.state_dim);
244
245 Self {
246 config,
247 block,
248 output_proj,
249 state,
250 }
251 }
252
253 pub fn decode(&mut self, embeddings: &Array2<f32>) -> Result<Array1<f32>> {
255 if embeddings.shape()[0] == 0 {
256 return Err(NeuralDecoderError::EmptyGraph);
257 }
258
259 let expected_dim = self.config.input_dim;
260 let actual_dim = embeddings.shape()[1];
261
262 if actual_dim != expected_dim {
263 return Err(NeuralDecoderError::embed_dim(expected_dim, actual_dim));
264 }
265
266 let mut aggregated = vec![0.0; self.config.input_dim];
268
269 for row in embeddings.rows() {
270 let input: Vec<f32> = row.to_vec();
271
272 let (new_state, output) = self.block.forward(&input, self.state.get());
274 self.state.update(new_state);
275
276 for (agg, out) in aggregated.iter_mut().zip(output.iter()) {
278 *agg += out;
279 }
280 }
281
282 let num_nodes = embeddings.shape()[0] as f32;
284 for val in &mut aggregated {
285 *val /= num_nodes;
286 }
287
288 let logits = self.output_proj.forward(&aggregated);
290
291 let probs: Vec<f32> = logits.iter()
293 .map(|&x| 1.0 / (1.0 + (-x).exp()))
294 .collect();
295
296 Ok(Array1::from_vec(probs))
297 }
298
299 pub fn decode_step(&mut self, embedding: &[f32]) -> Result<Vec<f32>> {
301 if embedding.len() != self.config.input_dim {
302 return Err(NeuralDecoderError::embed_dim(
303 self.config.input_dim,
304 embedding.len()
305 ));
306 }
307
308 let (new_state, output) = self.block.forward(embedding, self.state.get());
309 self.state.update(new_state);
310
311 Ok(output)
312 }
313
314 pub fn state(&self) -> &MambaState {
316 &self.state
317 }
318
319 pub fn reset(&mut self) {
321 self.state.reset();
322 }
323
324 pub fn config(&self) -> &MambaConfig {
326 &self.config
327 }
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    /// Small configuration shared by most decoder tests.
    fn small_config() -> MambaConfig {
        MambaConfig {
            input_dim: 32,
            state_dim: 16,
            output_dim: 9,
        }
    }

    /// A 9x32 embedding matrix filled with 0.5.
    fn constant_embeddings() -> Array2<f32> {
        Array2::from_elem((9, 32), 0.5)
    }

    #[test]
    fn test_mamba_config_default() {
        let cfg = MambaConfig::default();
        assert_eq!(cfg.input_dim, 128);
        assert_eq!(cfg.state_dim, 64);
        assert_eq!(cfg.output_dim, 25);
    }

    #[test]
    fn test_mamba_state_creation() {
        let state = MambaState::new(64);
        assert_eq!(state.dim, 64);
        assert_eq!(state.steps, 0);
        assert_eq!(state.get().len(), 64);
        // Freshly created state must be all zeros.
        assert!(state.get().iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_mamba_state_update() {
        let mut state = MambaState::new(4);
        let values = vec![1.0, 2.0, 3.0, 4.0];
        state.update(values.clone());

        assert_eq!(state.steps, 1);
        assert_eq!(state.get(), values.as_slice());
    }

    #[test]
    fn test_mamba_state_reset() {
        let mut state = MambaState::new(4);
        state.update(vec![1.0, 2.0, 3.0, 4.0]);
        state.update(vec![5.0, 6.0, 7.0, 8.0]);
        assert_eq!(state.steps, 2);

        state.reset();

        assert_eq!(state.steps, 0);
        assert!(state.get().iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_mamba_decoder_creation() {
        let decoder = MambaDecoder::new(MambaConfig::default());
        assert_eq!(decoder.config().input_dim, 128);
        assert_eq!(decoder.state().dim, 64);
    }

    #[test]
    fn test_mamba_decode() {
        let mut decoder = MambaDecoder::new(small_config());

        let output = decoder.decode(&constant_embeddings()).unwrap();

        assert_eq!(output.len(), 9);
        // Sigmoid outputs must all be valid probabilities.
        assert!(output.iter().all(|&p| (0.0..=1.0).contains(&p)));
    }

    #[test]
    fn test_mamba_decode_updates_state() {
        let mut decoder = MambaDecoder::new(small_config());
        assert_eq!(decoder.state().steps, 0);

        decoder.decode(&constant_embeddings()).unwrap();

        // One recurrent step per embedding row.
        assert_eq!(decoder.state().steps, 9);
    }

    #[test]
    fn test_mamba_decode_step() {
        let mut decoder = MambaDecoder::new(small_config());

        let output = decoder.decode_step(&[0.5; 32]).unwrap();

        assert_eq!(output.len(), 32);
        assert_eq!(decoder.state().steps, 1);
    }

    #[test]
    fn test_mamba_decode_wrong_dimension() {
        let mut decoder = MambaDecoder::new(small_config());

        // 64 columns against an expected input_dim of 32.
        let wide = Array2::from_elem((9, 64), 0.5);
        assert!(decoder.decode(&wide).is_err());
    }

    #[test]
    fn test_mamba_decode_empty() {
        let mut decoder = MambaDecoder::new(small_config());

        let empty: Array2<f32> = Array2::zeros((0, 32));
        assert!(decoder.decode(&empty).is_err());
    }

    #[test]
    fn test_mamba_reset() {
        let mut decoder = MambaDecoder::new(small_config());
        decoder.decode(&constant_embeddings()).unwrap();
        assert_eq!(decoder.state().steps, 9);

        decoder.reset();

        assert_eq!(decoder.state().steps, 0);
    }

    #[test]
    fn test_mamba_sequential_decode() {
        let config = MambaConfig {
            input_dim: 16,
            state_dim: 8,
            output_dim: 4,
        };
        let mut decoder = MambaDecoder::new(config);

        let outputs: Vec<Vec<f32>> = (0..5)
            .map(|i| decoder.decode_step(&vec![i as f32 * 0.1; 16]).unwrap())
            .collect();

        assert_eq!(outputs.len(), 5);
        assert_eq!(decoder.state().steps, 5);
    }

    #[test]
    fn test_mamba_state_evolution() {
        let config = MambaConfig {
            input_dim: 8,
            state_dim: 4,
            output_dim: 2,
        };
        let mut decoder = MambaDecoder::new(config);

        decoder.decode_step(&[1.0; 8]).unwrap();
        let before = decoder.state().get().to_vec();

        decoder.decode_step(&[0.0; 8]).unwrap();
        let after = decoder.state().get().to_vec();

        // Feeding a different input must move the state.
        let total_change: f32 = before
            .iter()
            .zip(after.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(total_change > 0.0);
    }

    #[test]
    fn test_selective_scan_step() {
        let ssm = SelectiveScan::new(8, 4);

        let (new_state, output) = ssm.step(&[0.5; 8], &[0.0; 4]);

        assert_eq!(new_state.len(), 4);
        assert_eq!(output.len(), 8);
    }

    #[test]
    fn test_mamba_block_forward() {
        let block = MambaBlock::new(8, 4);

        let (new_state, output) = block.forward(&[0.5; 8], &[0.0; 4]);

        assert_eq!(new_state.len(), 4);
        assert_eq!(output.len(), 8);
    }
}