kizzasi_core/attention.rs

//! Multi-head SSM Attention mechanisms
//!
//! Implements multi-head attention patterns optimized for State Space Models.
//! Includes both standard multi-head attention and specialized SSM variants.

use crate::error::{CoreError, CoreResult};
use crate::numerics::{safe_exp, softmax_stable};
use crate::simd;
use scirs2_core::ndarray::{s, Array1, Array2, Array3, Axis};
use scirs2_core::random::thread_rng;

/// Multi-head SSM Attention configuration
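///
/// # Example
///
/// A minimal construction sketch (the `kizzasi_core::attention` path is an
/// assumption, so the block is marked `ignore`):
///
/// ```ignore
/// use kizzasi_core::attention::MultiHeadSSMConfig;
///
/// let config = MultiHeadSSMConfig::new(512, 8, 64)
///     .unwrap()
///     .dropout(0.1)
///     .causal(true);
/// assert_eq!(config.head_dim, 64);
/// ```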
#[derive(Debug, Clone)]
pub struct MultiHeadSSMConfig {
    /// Model dimension (d_model)
    pub hidden_dim: usize,
    /// Number of attention heads
    pub num_heads: usize,
    /// Head dimension (d_model / num_heads)
    pub head_dim: usize,
    /// State dimension per head
    pub state_dim: usize,
    /// Dropout rate
    pub dropout: f32,
    /// Use causal masking
    pub causal: bool,
}

impl MultiHeadSSMConfig {
    /// Create a new configuration
    pub fn new(hidden_dim: usize, num_heads: usize, state_dim: usize) -> CoreResult<Self> {
        if !hidden_dim.is_multiple_of(num_heads) {
            return Err(CoreError::InvalidConfig(format!(
                "hidden_dim ({}) must be divisible by num_heads ({})",
                hidden_dim, num_heads
            )));
        }

        Ok(Self {
            hidden_dim,
            num_heads,
            head_dim: hidden_dim / num_heads,
            state_dim,
            dropout: 0.0,
            causal: true,
        })
    }

    /// Set dropout rate
    pub fn dropout(mut self, rate: f32) -> Self {
        self.dropout = rate;
        self
    }

    /// Set causal masking
    pub fn causal(mut self, causal: bool) -> Self {
        self.causal = causal;
        self
    }
}

/// Multi-head SSM Attention layer
///
/// Implements multi-head attention specifically optimized for SSMs:
/// - Supports both standard attention and SSM-specific variants
/// - SIMD-optimized matrix operations
/// - Memory-efficient causal masking
/// - Compatible with linear-time SSM inference
#[derive(Debug)]
pub struct MultiHeadSSMAttention {
    config: MultiHeadSSMConfig,
    // Query, Key, Value projections
    w_q: Array2<f32>,
    w_k: Array2<f32>,
    w_v: Array2<f32>,
    // Output projection
    w_o: Array2<f32>,
    // Optional bias terms
    b_q: Option<Array1<f32>>,
    b_k: Option<Array1<f32>>,
    b_v: Option<Array1<f32>>,
    b_o: Option<Array1<f32>>,
}

impl MultiHeadSSMAttention {
    /// Create a new multi-head SSM attention layer
    pub fn new(config: MultiHeadSSMConfig, use_bias: bool) -> CoreResult<Self> {
        let hidden_dim = config.hidden_dim;
        let mut rng = thread_rng();
        let scale = (1.0 / hidden_dim as f32).sqrt();
        // Initialize projection matrices with uniform values in [-1/sqrt(d_model), 1/sqrt(d_model)]
        let w_q = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });
        let w_k = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });
        let w_v = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });
        let w_o = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        // Optional bias terms
        let (b_q, b_k, b_v, b_o) = if use_bias {
            (
                Some(Array1::zeros(hidden_dim)),
                Some(Array1::zeros(hidden_dim)),
                Some(Array1::zeros(hidden_dim)),
                Some(Array1::zeros(hidden_dim)),
            )
        } else {
            (None, None, None, None)
        };

        Ok(Self {
            config,
            w_q,
            w_k,
            w_v,
            w_o,
            b_q,
            b_k,
            b_v,
            b_o,
        })
    }

    /// Forward pass for a single query vector (incremental inference mode)
    ///
    /// `key_cache` and `value_cache` hold the already-projected keys and values
    /// for past positions, one row per position, laid out head-by-head
    /// (`num_heads * head_dim` columns). Each call attends over all cached
    /// positions, so the per-step cost grows linearly with the cache length
    /// (plus the fixed projection cost).
    pub fn forward_step(
        &self,
        query: &Array1<f32>,
        key_cache: &Array2<f32>,
        value_cache: &Array2<f32>,
    ) -> CoreResult<Array1<f32>> {
        let num_heads = self.config.num_heads;
        let head_dim = self.config.head_dim;
        let seq_len = key_cache.nrows();

        // Project query
        let q = self.project_qkv(&self.w_q, &self.b_q, query);

        // Reshape to multi-head: (hidden_dim,) -> (num_heads, head_dim)
        let q_heads = self.reshape_to_heads(&q)?;

        // Compute attention scores for each head
        let mut attn_output = Array1::zeros(self.config.hidden_dim);
        let scale = 1.0 / (head_dim as f32).sqrt();

        for h in 0..num_heads {
            let q_h = q_heads.slice(s![h, ..]);

            // Compute attention scores: scores[i] = q · k[i]
            let mut scores = Array1::zeros(seq_len);
            for i in 0..seq_len {
                let k_i = key_cache.slice(s![i, h * head_dim..(h + 1) * head_dim]);
                scores[i] = simd::dot_view(q_h, k_i) * scale;
            }

            // No causal mask is needed here: every cached key/value comes from a
            // past position, so the single-step query only attends to the past.

            // Softmax over scores
            let attn_weights = softmax_stable(&scores);

            // Weighted sum of values
            let mut context = Array1::zeros(head_dim);
            for i in 0..seq_len {
                let v_i = value_cache.slice(s![i, h * head_dim..(h + 1) * head_dim]);
                let weight = attn_weights[i];
                for j in 0..head_dim {
                    context[j] += weight * v_i[j];
                }
            }

            // Copy to output
            let start = h * head_dim;
            let end = start + head_dim;
            attn_output.slice_mut(s![start..end]).assign(&context);
        }

        // Output projection
        let output = if let Some(ref bias) = self.b_o {
            attn_output.dot(&self.w_o) + bias
        } else {
            attn_output.dot(&self.w_o)
        };

        Ok(output)
    }

    /// Forward pass for a batch of sequences (training mode)
    ///
    /// Input shape: (batch_size, seq_len, hidden_dim)
    /// Output shape: (batch_size, seq_len, hidden_dim)
    ///
    /// If `mask` is provided, `mask[[b, i]] == false` excludes position `i` of
    /// batch item `b` from attention.
    pub fn forward_batch(
        &self,
        input: &Array3<f32>,
        mask: Option<&Array2<bool>>,
    ) -> CoreResult<Array3<f32>> {
        let (batch_size, seq_len, hidden_dim) = input.dim();
        if hidden_dim != self.config.hidden_dim {
            return Err(CoreError::DimensionMismatch {
                expected: self.config.hidden_dim,
                got: hidden_dim,
            });
        }
        let num_heads = self.config.num_heads;
        let head_dim = self.config.head_dim;

        let mut output = Array3::zeros((batch_size, seq_len, self.config.hidden_dim));

        // Process each batch item
        for b in 0..batch_size {
            let input_batch = input.index_axis(Axis(0), b);

            // Project Q, K, V for all positions
            let mut q_all = Array2::zeros((seq_len, self.config.hidden_dim));
            let mut k_all = Array2::zeros((seq_len, self.config.hidden_dim));
            let mut v_all = Array2::zeros((seq_len, self.config.hidden_dim));

            for t in 0..seq_len {
                let x_t = input_batch.index_axis(Axis(0), t).to_owned();
                q_all
                    .index_axis_mut(Axis(0), t)
                    .assign(&self.project_qkv(&self.w_q, &self.b_q, &x_t));
                k_all
                    .index_axis_mut(Axis(0), t)
                    .assign(&self.project_qkv(&self.w_k, &self.b_k, &x_t));
                v_all
                    .index_axis_mut(Axis(0), t)
                    .assign(&self.project_qkv(&self.w_v, &self.b_v, &x_t));
            }

            // Compute attention for each position
            for t in 0..seq_len {
                let q_t = q_all.index_axis(Axis(0), t).to_owned();
                let q_heads = self.reshape_to_heads(&q_t)?;

                let mut attn_output = Array1::zeros(self.config.hidden_dim);
                let scale = 1.0 / (head_dim as f32).sqrt();

                for h in 0..num_heads {
                    let q_h = q_heads.slice(s![h, ..]);

                    // Compute attention scores
                    let attend_len = if self.config.causal { t + 1 } else { seq_len };
                    let mut scores = Array1::zeros(attend_len);

                    for i in 0..attend_len {
                        let k_i = k_all.slice(s![i, h * head_dim..(h + 1) * head_dim]);
                        scores[i] = simd::dot_view(q_h, k_i) * scale;
                    }

                    // Apply mask if provided
                    if let Some(mask_data) = mask {
                        for i in 0..attend_len {
                            if !mask_data[[b, i]] {
                                scores[i] = f32::NEG_INFINITY;
                            }
                        }
                    }

                    // Softmax
                    let attn_weights = softmax_stable(&scores);

                    // Weighted sum of values
                    let mut context = Array1::zeros(head_dim);
                    for i in 0..attend_len {
                        let v_i = v_all.slice(s![i, h * head_dim..(h + 1) * head_dim]);
                        let weight = attn_weights[i];
                        for j in 0..head_dim {
                            context[j] += weight * v_i[j];
                        }
                    }

                    // Copy to output
                    let start = h * head_dim;
                    let end = start + head_dim;
                    attn_output.slice_mut(s![start..end]).assign(&context);
                }

                // Output projection
                let out_t = if let Some(ref bias) = self.b_o {
                    attn_output.dot(&self.w_o) + bias
                } else {
                    attn_output.dot(&self.w_o)
                };

                output
                    .index_axis_mut(Axis(0), b)
                    .index_axis_mut(Axis(0), t)
                    .assign(&out_t);
            }
        }

        Ok(output)
    }

    /// Project the input through one of the Q/K/V projection matrices,
    /// adding the optional bias
    fn project_qkv(
        &self,
        weight: &Array2<f32>,
        bias: &Option<Array1<f32>>,
        input: &Array1<f32>,
    ) -> Array1<f32> {
        if let Some(b) = bias {
            input.dot(weight) + b
        } else {
            input.dot(weight)
        }
    }

    /// Reshape flat vector to multi-head format
    /// Input: (hidden_dim,) -> Output: (num_heads, head_dim)
    fn reshape_to_heads(&self, x: &Array1<f32>) -> CoreResult<Array2<f32>> {
        if x.len() != self.config.hidden_dim {
            return Err(CoreError::DimensionMismatch {
                expected: self.config.hidden_dim,
                got: x.len(),
            });
        }

        let mut result = Array2::zeros((self.config.num_heads, self.config.head_dim));
        for h in 0..self.config.num_heads {
            let start = h * self.config.head_dim;
            let end = start + self.config.head_dim;
            result.row_mut(h).assign(&x.slice(s![start..end]));
        }

        Ok(result)
    }

    /// Get configuration
    pub fn config(&self) -> &MultiHeadSSMConfig {
        &self.config
    }

    /// Get number of parameters
    pub fn num_parameters(&self) -> usize {
        let weight_params = self.w_q.len() + self.w_k.len() + self.w_v.len() + self.w_o.len();
        let bias_params = if self.b_q.is_some() {
            4 * self.config.hidden_dim
        } else {
            0
        };
        weight_params + bias_params
    }
}

/// Gated Linear Attention (Griffin-style)
///
/// Implements efficient gated linear attention for SSMs:
/// - Linear complexity in sequence length
/// - Gating mechanism for selective attention
/// - Compatible with SSM recurrence
#[derive(Debug)]
pub struct GatedLinearAttention {
    hidden_dim: usize,
    // Gate projection
    w_gate: Array2<f32>,
    // Query/Key projections
    w_q: Array2<f32>,
    w_k: Array2<f32>,
    // Output projection
    w_o: Array2<f32>,
}

impl GatedLinearAttention {
    /// Create a new gated linear attention layer
    pub fn new(hidden_dim: usize) -> CoreResult<Self> {
        let mut rng = thread_rng();
        let scale = (1.0 / hidden_dim as f32).sqrt();

        let w_gate = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });
        let w_q = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });
        let w_k = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });
        let w_o = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        Ok(Self {
            hidden_dim,
            w_gate,
            w_q,
            w_k,
            w_o,
        })
    }

    /// Forward step with linear attention
    ///
    /// Uses linear attention: O(d^2) per step instead of O(n*d)
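    ///
    /// As a sketch, the recurrence implemented below is (`⊙` elementwise,
    /// `⊗` outer product; the raw input vector acts as the value):
    ///
    /// ```text
    /// g_t = sigmoid(x_t · W_gate)
    /// S_t = S_{t-1} + k_t ⊗ (g_t ⊙ x_t)   // kv_state update
    /// y_t = S_t^T q_t                      // readout
    /// out = y_t · W_o
    /// ```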
    pub fn forward_step(
        &self,
        input: &Array1<f32>,
        kv_state: &mut Array2<f32>,
    ) -> CoreResult<Array1<f32>> {
        // Project to query, key, gate
        let q = input.dot(&self.w_q);
        let k = input.dot(&self.w_k);
        let g = input.dot(&self.w_gate);

        // Apply gating (sigmoid)
        let gate = g.mapv(|x| 1.0 / (1.0 + safe_exp(-x)));

        // Update KV state: kv_state += k ⊗ (gate ⊙ input); the raw input acts as the value
        let gated_value = &gate * input;
        for i in 0..self.hidden_dim {
            for j in 0..self.hidden_dim {
                kv_state[[i, j]] += k[i] * gated_value[j];
            }
        }

        // Attention output: q^T * kv_state
        let mut attn_out = Array1::zeros(self.hidden_dim);
        for j in 0..self.hidden_dim {
            let mut sum = 0.0;
            for i in 0..self.hidden_dim {
                sum += q[i] * kv_state[[i, j]];
            }
            attn_out[j] = sum;
        }

        // Output projection
        let output = attn_out.dot(&self.w_o);
        Ok(output)
    }

    /// Return a zeroed KV state, used to reset or initialize the recurrence
    pub fn reset_state(&self) -> Array2<f32> {
        Array2::zeros((self.hidden_dim, self.hidden_dim))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_multihead_ssm_config() {
        let config = MultiHeadSSMConfig::new(512, 8, 64).unwrap();
        assert_eq!(config.hidden_dim, 512);
        assert_eq!(config.num_heads, 8);
        assert_eq!(config.head_dim, 64);
    }
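
    // Illustrative sketch: exercises the builder setters and the divisibility
    // check in `MultiHeadSSMConfig::new`.
    #[test]
    fn test_multihead_ssm_config_builder_and_validation() {
        let config = MultiHeadSSMConfig::new(512, 8, 64)
            .unwrap()
            .dropout(0.1)
            .causal(false);
        assert!((config.dropout - 0.1).abs() < 1e-6);
        assert!(!config.causal);

        // hidden_dim not divisible by num_heads must be rejected
        assert!(MultiHeadSSMConfig::new(100, 3, 16).is_err());
    }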

    #[test]
    fn test_multihead_ssm_attention() {
        let config = MultiHeadSSMConfig::new(64, 4, 16).unwrap();
        let attn = MultiHeadSSMAttention::new(config, false).unwrap();

        let query = Array1::from_vec(vec![0.1; 64]);
        let key_cache = Array2::from_shape_vec((10, 64), vec![0.1; 640]).unwrap();
        let value_cache = Array2::from_shape_vec((10, 64), vec![0.2; 640]).unwrap();

        let output = attn.forward_step(&query, &key_cache, &value_cache).unwrap();
        assert_eq!(output.len(), 64);
    }
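
    // Illustrative sketch: checks the count reported by `num_parameters`
    // (4 weight matrices, plus 4 bias vectors when biases are enabled).
    #[test]
    fn test_num_parameters() {
        let config = MultiHeadSSMConfig::new(64, 4, 16).unwrap();
        let no_bias = MultiHeadSSMAttention::new(config.clone(), false).unwrap();
        assert_eq!(no_bias.num_parameters(), 4 * 64 * 64);

        let with_bias = MultiHeadSSMAttention::new(config, true).unwrap();
        assert_eq!(with_bias.num_parameters(), 4 * 64 * 64 + 4 * 64);
    }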

    #[test]
    fn test_gated_linear_attention() {
        let gla = GatedLinearAttention::new(64).unwrap();
        let input = Array1::from_vec(vec![0.1; 64]);
        let mut kv_state = gla.reset_state();

        let output = gla.forward_step(&input, &mut kv_state).unwrap();
        assert_eq!(output.len(), 64);
    }
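
    // Illustrative sketch: the KV state accumulates across steps, which is
    // what keeps the per-step cost independent of the sequence length.
    #[test]
    fn test_gated_linear_attention_state_accumulates() {
        let gla = GatedLinearAttention::new(32).unwrap();
        let input = Array1::from_vec(vec![0.5; 32]);
        let mut kv_state = gla.reset_state();

        let first = gla.forward_step(&input, &mut kv_state).unwrap();
        let state_after_first = kv_state.clone();
        let second = gla.forward_step(&input, &mut kv_state).unwrap();

        assert_eq!(first.len(), 32);
        assert_eq!(second.len(), 32);
        // After two identical steps the state holds the first-step update twice,
        // so it differs from the single-step snapshot (unless the random
        // projections are exactly zero, which is not a realistic case).
        assert_ne!(kv_state, state_after_first);
    }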

    #[test]
    fn test_multihead_batch_forward() {
        let config = MultiHeadSSMConfig::new(64, 4, 16).unwrap();
        let attn = MultiHeadSSMAttention::new(config, false).unwrap();

        let batch_size = 2;
        let seq_len = 5;
        let input = Array3::from_shape_vec(
            (batch_size, seq_len, 64),
            vec![0.1; batch_size * seq_len * 64],
        )
        .unwrap();

        let output = attn.forward_batch(&input, None).unwrap();
        assert_eq!(output.dim(), (batch_size, seq_len, 64));
    }
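
    // Illustrative sketch: forward_batch with an explicit boolean mask, where
    // `false` entries are excluded from attention.
    #[test]
    fn test_multihead_batch_forward_with_mask() {
        let config = MultiHeadSSMConfig::new(64, 4, 16).unwrap();
        let attn = MultiHeadSSMAttention::new(config, false).unwrap();

        let batch_size = 2;
        let seq_len = 5;
        let input = Array3::from_shape_vec(
            (batch_size, seq_len, 64),
            vec![0.1; batch_size * seq_len * 64],
        )
        .unwrap();

        // Mask out the last position of every batch item.
        let mask = Array2::from_shape_fn((batch_size, seq_len), |(_, i)| i < seq_len - 1);

        let output = attn.forward_batch(&input, Some(&mask)).unwrap();
        assert_eq!(output.dim(), (batch_size, seq_len, 64));
        assert!(output.iter().all(|v| v.is_finite()));
    }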
}