kizzasi_model/hybrid.rs

//! Hybrid Mamba+Attention Model
//!
//! This module implements a hybrid architecture combining Mamba's selective SSM
//! with multi-head attention to balance efficiency and expressiveness.
//!
//! # Architecture Strategy
//!
//! The hybrid model interleaves Mamba layers (local, efficient processing) and
//! Attention layers (global context) according to a configurable layer pattern,
//! combining the strengths of both:
//!
//! - **Mamba layers**: O(1) per-step inference, selective state dynamics
//! - **Attention layers**: Global context, explicit long-range dependencies
//!
//! # Layer Configuration
//!
//! ```text
//! Input → [Mamba] → [Attention] → [Mamba] → [Attention] → ... → Output
//! ```
//!
//! Or Mamba-heavy, with attention every fourth layer:
//! ```text
//! Input → [Mamba] → [Mamba] → [Mamba] → [Attention] → [Mamba] → [Mamba] → [Mamba] → [Attention] → ...
//! ```
//!
//! # Use Cases
//!
//! - **Long sequences**: Attention provides global context while Mamba handles local patterns
//! - **Few-shot learning**: Attention for in-context learning, Mamba for parameter efficiency
//! - **Multimodal**: Different modalities can use different layer types
//!
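//! # Example
//!
//! A minimal usage sketch (not compiled as a doctest; it assumes `HybridConfig`,
//! `HybridModel`, and the `SignalPredictor` trait are in scope, e.g. via the
//! crate's re-exports):
//!
//! ```ignore
//! use scirs2_core::ndarray::Array1;
//!
//! // 4 layers alternating Mamba/Attention: 32-dim input, 64-dim hidden, 4 heads.
//! let config = HybridConfig::alternating(32, 64, 4, 4);
//! let mut model = HybridModel::new(config).expect("valid config");
//!
//! // One autoregressive step: project in, run every layer, project back out.
//! let input = Array1::from_vec(vec![0.0_f32; 32]);
//! let output = model.step(&input).expect("forward step");
//! assert_eq!(output.len(), 32);
//! ```
//!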
//! # References
//!
//! - Combines ideas from Mamba and Transformer architectures
//! - Inspired by hybrid models like Jamba (AI21 Labs)

use crate::error::{ModelError, ModelResult};
use crate::{AutoregressiveModel, ModelType};
use kizzasi_core::{silu, softmax, CoreResult, HiddenState, SignalPredictor};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::{rng, Rng};
use std::collections::VecDeque;

#[allow(unused_imports)]
use tracing::{debug, instrument, trace};

/// Layer type in hybrid model
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LayerType {
    /// Mamba selective SSM layer
    Mamba,
    /// Multi-head attention layer
    Attention,
}

/// Configuration for hybrid Mamba+Attention model
#[derive(Debug, Clone)]
pub struct HybridConfig {
    /// Input dimension
    pub input_dim: usize,
    /// Hidden dimension
    pub hidden_dim: usize,
    /// State dimension for Mamba layers
    pub state_dim: usize,
    /// Total number of layers
    pub num_layers: usize,
    /// Number of attention heads
    pub num_heads: usize,
    /// Maximum sequence length for attention
    pub max_seq_len: usize,
    /// Layer pattern (e.g., [Mamba, Mamba, Attention, ...])
    pub layer_pattern: Vec<LayerType>,
}

impl HybridConfig {
    /// Create a new hybrid config with alternating layers
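    ///
    /// Layers at even indices are Mamba and odd indices Attention, starting with
    /// Mamba. Defaults of `state_dim = 64` and `max_seq_len = 2048` are used.
    ///
    /// A minimal sketch of the resulting pattern (not compiled as a doctest):
    ///
    /// ```ignore
    /// let config = HybridConfig::alternating(32, 64, 4, 4);
    /// assert_eq!(
    ///     config.layer_pattern,
    ///     vec![
    ///         LayerType::Mamba,
    ///         LayerType::Attention,
    ///         LayerType::Mamba,
    ///         LayerType::Attention,
    ///     ]
    /// );
    /// ```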
    pub fn alternating(
        input_dim: usize,
        hidden_dim: usize,
        num_layers: usize,
        num_heads: usize,
    ) -> Self {
        let layer_pattern = (0..num_layers)
            .map(|i| {
                if i % 2 == 0 {
                    LayerType::Mamba
                } else {
                    LayerType::Attention
                }
            })
            .collect();

        Self {
            input_dim,
            hidden_dim,
            state_dim: 64,
            num_layers,
            num_heads,
            max_seq_len: 2048,
            layer_pattern,
        }
    }

    /// Create a config with mostly Mamba, occasional attention
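    ///
    /// Every fourth layer (indices 3, 7, 11, ...) is an Attention layer; the rest
    /// are Mamba. Defaults of `state_dim = 64` and `max_seq_len = 2048` are used.
    ///
    /// A minimal sketch (not compiled as a doctest):
    ///
    /// ```ignore
    /// let config = HybridConfig::mamba_heavy(32, 64, 8, 4);
    /// let attention_layers = config
    ///     .layer_pattern
    ///     .iter()
    ///     .filter(|&&t| t == LayerType::Attention)
    ///     .count();
    /// assert_eq!(attention_layers, 2); // indices 3 and 7
    /// ```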
    pub fn mamba_heavy(
        input_dim: usize,
        hidden_dim: usize,
        num_layers: usize,
        num_heads: usize,
    ) -> Self {
        let layer_pattern = (0..num_layers)
            .map(|i| {
                // Attention every 4 layers
                if i % 4 == 3 {
                    LayerType::Attention
                } else {
                    LayerType::Mamba
                }
            })
            .collect();

        Self {
            input_dim,
            hidden_dim,
            state_dim: 64,
            num_layers,
            num_heads,
            max_seq_len: 2048,
            layer_pattern,
        }
    }

    /// Validate configuration
    pub fn validate(&self) -> ModelResult<()> {
        if self.hidden_dim == 0 {
            return Err(ModelError::invalid_config("hidden_dim must be > 0"));
        }
        if self.state_dim == 0 {
            return Err(ModelError::invalid_config("state_dim must be > 0"));
        }
        if self.num_layers == 0 {
            return Err(ModelError::invalid_config("num_layers must be > 0"));
        }
        if self.num_heads == 0 {
            return Err(ModelError::invalid_config("num_heads must be > 0"));
        }
        if !self.hidden_dim.is_multiple_of(self.num_heads) {
            return Err(ModelError::invalid_config(
                "hidden_dim must be divisible by num_heads",
            ));
        }
        if self.layer_pattern.len() != self.num_layers {
            return Err(ModelError::invalid_config(
                "layer_pattern length must equal num_layers",
            ));
        }
        Ok(())
    }
}

/// Simplified Mamba layer for hybrid model
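///
/// Per step this block applies an input projection, a diagonal linear
/// state-space recurrence `h ← a_bar ⊙ h + Δ·(B·x_proj)`, a read-out `C·h`
/// gated by `silu(x_proj)`, and an output projection. The step size Δ is a
/// fixed constant here, so unlike full Mamba there is no input-dependent
/// (selective) parameterization.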
#[allow(dead_code)]
struct MambaBlock {
    hidden_dim: usize,
    state_dim: usize,
    /// Projection matrices
    proj_in: Array2<f32>,
    proj_out: Array2<f32>,
    /// SSM parameters (simplified)
    a_log: Array1<f32>,
    b_matrix: Array2<f32>,
    c_matrix: Array2<f32>,
    /// Current state
    state: Array1<f32>,
}

impl MambaBlock {
    fn new(hidden_dim: usize, state_dim: usize) -> Self {
        let mut rng = rng();

        let scale = (2.0 / (hidden_dim + hidden_dim) as f32).sqrt();
        let proj_in = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        let proj_out = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        // Initialize SSM parameters
        let a_log = Array1::from_shape_fn(state_dim, |i| -((i + 1) as f32).ln());

        let scale = (1.0 / state_dim as f32).sqrt();
        let b_matrix = Array2::from_shape_fn((state_dim, hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        let c_matrix = Array2::from_shape_fn((hidden_dim, state_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        let state = Array1::zeros(state_dim);

        Self {
            hidden_dim,
            state_dim,
            proj_in,
            proj_out,
            a_log,
            b_matrix,
            c_matrix,
            state,
        }
    }

    fn forward(&mut self, x: &Array1<f32>) -> Array1<f32> {
        // Input projection
        let projected = x.dot(&self.proj_in);

        // SSM recurrence with a fixed step size (delta = 0.001); this simplified
        // block has no input-dependent (selective) parameters. The negation keeps
        // a_bar in (0, 1) so the state contracts instead of diverging.
        let a_bar = self.a_log.mapv(|a| (-0.001 * a.exp()).exp());
        self.state = &self.state * &a_bar + self.b_matrix.dot(&projected) * 0.001;

        // Output
        let ssm_out = self.c_matrix.dot(&self.state);

        // Gate with SiLU
        let gated = silu(&projected) * &ssm_out;

        // Output projection
        gated.dot(&self.proj_out)
    }

    fn reset(&mut self) {
        self.state.fill(0.0);
    }
}

/// Simplified attention layer for hybrid model
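///
/// Each `forward` call appends the new key/value pair to a bounded KV cache
/// (at most `max_cache_len` entries) and attends causally over the whole cache
/// with `softmax(q·k / sqrt(head_dim))` weights. Despite the `num_heads` and
/// `head_dim` fields, the computation below is a single-head simplification
/// over the full hidden dimension.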
#[allow(dead_code)]
struct AttentionBlock {
    hidden_dim: usize,
    num_heads: usize,
    head_dim: usize,
    /// Query, Key, Value projections
    q_proj: Array2<f32>,
    k_proj: Array2<f32>,
    v_proj: Array2<f32>,
    /// Output projection
    o_proj: Array2<f32>,
    /// KV cache
    k_cache: VecDeque<Array1<f32>>,
    v_cache: VecDeque<Array1<f32>>,
    max_cache_len: usize,
}

impl AttentionBlock {
    fn new(hidden_dim: usize, num_heads: usize, max_seq_len: usize) -> Self {
        let mut rng = rng();
        let head_dim = hidden_dim / num_heads;

        let scale = (2.0 / (hidden_dim + hidden_dim) as f32).sqrt();
        let q_proj = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });
        let k_proj = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });
        let v_proj = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });
        let o_proj = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        Self {
            hidden_dim,
            num_heads,
            head_dim,
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            k_cache: VecDeque::new(),
            v_cache: VecDeque::new(),
            max_cache_len: max_seq_len,
        }
    }

    fn forward(&mut self, x: &Array1<f32>) -> Array1<f32> {
        // Compute Q, K, V
        let q = x.dot(&self.q_proj);
        let k = x.dot(&self.k_proj);
        let v = x.dot(&self.v_proj);

        // Add to cache
        self.k_cache.push_back(k.clone());
        self.v_cache.push_back(v.clone());

        // Trim cache
        while self.k_cache.len() > self.max_cache_len {
            self.k_cache.pop_front();
            self.v_cache.pop_front();
        }

        // Compute attention (simplified single-head version)
        let cache_len = self.k_cache.len();
        let mut attention_out = Array1::zeros(self.hidden_dim);

        if cache_len > 0 {
            // Compute attention scores
            let mut scores = Vec::with_capacity(cache_len);
            for k_cached in &self.k_cache {
                let score = q.dot(k_cached) / (self.head_dim as f32).sqrt();
                scores.push(score);
            }

            // Softmax
            let scores_array = Array1::from_vec(scores);
            let attn_weights = softmax(&scores_array);

            // Weighted sum of values
            for (weight, v_cached) in attn_weights.iter().zip(self.v_cache.iter()) {
                attention_out = attention_out + v_cached * *weight;
            }
        }

        // Output projection
        attention_out.dot(&self.o_proj)
    }

    fn reset(&mut self) {
        self.k_cache.clear();
        self.v_cache.clear();
    }
}

/// Enum for hybrid layer
enum HybridLayer {
    Mamba(MambaBlock),
    Attention(AttentionBlock),
}

impl HybridLayer {
    fn forward(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
        match self {
            HybridLayer::Mamba(mamba) => Ok(mamba.forward(x)),
            HybridLayer::Attention(attn) => Ok(attn.forward(x)),
        }
    }

    fn reset(&mut self) {
        match self {
            HybridLayer::Mamba(mamba) => mamba.reset(),
            HybridLayer::Attention(attn) => attn.reset(),
        }
    }
}

/// Hybrid Mamba+Attention model
pub struct HybridModel {
    config: HybridConfig,
    layers: Vec<HybridLayer>,
    /// Input/output projections
    input_proj: Array2<f32>,
    output_proj: Array2<f32>,
}

impl HybridModel {
    /// Create a new hybrid model
    #[instrument(skip(config), fields(num_layers = config.num_layers))]
    pub fn new(config: HybridConfig) -> ModelResult<Self> {
        debug!("Creating new Hybrid Mamba+Attention model");
        config.validate()?;

        let mut rng = rng();

        // Input projection
        let scale = (2.0 / (config.input_dim + config.hidden_dim) as f32).sqrt();
        let input_proj = Array2::from_shape_fn((config.input_dim, config.hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        // Output projection
        let scale = (2.0 / (config.hidden_dim + config.input_dim) as f32).sqrt();
        let output_proj = Array2::from_shape_fn((config.hidden_dim, config.input_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        // Create layers based on pattern
        let mut layers = Vec::with_capacity(config.num_layers);
        for (i, &layer_type) in config.layer_pattern.iter().enumerate() {
            trace!("Initializing hybrid layer {} as {:?}", i, layer_type);
            let layer = match layer_type {
                LayerType::Mamba => {
                    HybridLayer::Mamba(MambaBlock::new(config.hidden_dim, config.state_dim))
                }
                LayerType::Attention => HybridLayer::Attention(AttentionBlock::new(
                    config.hidden_dim,
                    config.num_heads,
                    config.max_seq_len,
                )),
            };
            layers.push(layer);
        }

        debug!(
            "Hybrid model created successfully with {} layers",
            layers.len()
        );
        Ok(Self {
            config,
            layers,
            input_proj,
            output_proj,
        })
    }

    /// Get configuration
    pub fn config(&self) -> &HybridConfig {
        &self.config
    }

    /// Count layers of each type, returned as `(mamba_count, attention_count)`
    pub fn layer_counts(&self) -> (usize, usize) {
        let mamba_count = self
            .config
            .layer_pattern
            .iter()
            .filter(|&&t| t == LayerType::Mamba)
            .count();
        let attention_count = self.config.num_layers - mamba_count;
        (mamba_count, attention_count)
    }
}

impl SignalPredictor for HybridModel {
    #[instrument(skip(self, input))]
    fn step(&mut self, input: &Array1<f32>) -> CoreResult<Array1<f32>> {
        // Project input
        let mut hidden = input.dot(&self.input_proj);

        // Pass through hybrid layers
        for layer in &mut self.layers {
            hidden = layer.forward(&hidden)?;
        }

        // Project output
        let output = hidden.dot(&self.output_proj);
        Ok(output)
    }

    #[instrument(skip(self))]
    fn reset(&mut self) {
        debug!("Resetting Hybrid model state");
        for layer in &mut self.layers {
            layer.reset();
        }
    }

    fn context_window(&self) -> usize {
        // Context window is determined by attention layers
        self.config.max_seq_len
    }
}

impl AutoregressiveModel for HybridModel {
    fn hidden_dim(&self) -> usize {
        self.config.hidden_dim
    }

    fn state_dim(&self) -> usize {
        self.config.state_dim
    }

    fn num_layers(&self) -> usize {
        self.config.num_layers
    }

    fn model_type(&self) -> ModelType {
        ModelType::Mamba // Hybrid, but Mamba-based
    }

    fn get_states(&self) -> Vec<HiddenState> {
        // Simplified: returns freshly initialized placeholder states rather than
        // extracting the live Mamba states / attention caches of each layer.
        (0..self.config.num_layers)
            .map(|_| HiddenState::new(self.config.hidden_dim, self.config.state_dim))
            .collect()
    }

    fn set_states(&mut self, states: Vec<HiddenState>) -> ModelResult<()> {
        if states.len() != self.config.num_layers {
            return Err(ModelError::state_count_mismatch(
                "Hybrid",
                self.config.num_layers,
                states.len(),
            ));
        }
        // Restoring states would require per-layer handling of the different
        // layer types; the states are validated but not applied here.
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hybrid_creation_alternating() {
        let config = HybridConfig::alternating(32, 64, 4, 4);
        let model = HybridModel::new(config);
        assert!(model.is_ok());
    }

    #[test]
    fn test_hybrid_creation_mamba_heavy() {
        let config = HybridConfig::mamba_heavy(32, 64, 8, 4);
        let model = HybridModel::new(config);
        assert!(model.is_ok());
    }

    #[test]
    fn test_hybrid_forward() {
        let config = HybridConfig::alternating(32, 64, 4, 4);
        let mut model = HybridModel::new(config).expect("Failed to create HybridModel");

        let input = Array1::from_vec(vec![1.0; 32]);
        let output = model.step(&input);
        assert!(output.is_ok());
        assert_eq!(output.expect("Failed to get output").len(), 32);
    }

    #[test]
    fn test_hybrid_layer_counts() {
        let config = HybridConfig::alternating(32, 64, 6, 4);
        let model = HybridModel::new(config).expect("Failed to create HybridModel");
        let (mamba, attn) = model.layer_counts();
        assert_eq!(mamba, 3);
        assert_eq!(attn, 3);
    }

    #[test]
    fn test_hybrid_mamba_heavy_counts() {
        let config = HybridConfig::mamba_heavy(32, 64, 8, 4);
        let model = HybridModel::new(config).expect("Failed to create HybridModel");
        let (mamba, attn) = model.layer_counts();
        assert_eq!(mamba, 6);
        assert_eq!(attn, 2);
    }

    #[test]
    fn test_hybrid_reset() {
        let config = HybridConfig::alternating(32, 64, 4, 4);
        let mut model = HybridModel::new(config).expect("Failed to create HybridModel");

        let input = Array1::from_vec(vec![0.5; 32]);
        let _ = model.step(&input).expect("Failed to step model");

        model.reset();

        let output = model.step(&input).expect("Failed to get output");
        assert_eq!(output.len(), 32);
    }

    #[test]
    fn test_invalid_config() {
        let mut config = HybridConfig::alternating(32, 64, 4, 4);
        config.layer_pattern.push(LayerType::Mamba); // Mismatch with num_layers
        assert!(config.validate().is_err());
    }
}