use crate::error::{CoreError, CoreResult};
use crate::numerics::{safe_exp, softmax_stable};
use crate::simd;
use scirs2_core::ndarray::{s, Array1, Array2, Array3, Axis};
use scirs2_core::random::thread_rng;

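/// Configuration for multi-head SSM attention: model width, number of heads,
/// per-head width, SSM state size, dropout rate, and causal masking flag.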
#[derive(Debug, Clone)]
pub struct MultiHeadSSMConfig {
    pub hidden_dim: usize,
    pub num_heads: usize,
    pub head_dim: usize,
    pub state_dim: usize,
    pub dropout: f32,
    pub causal: bool,
}

impl MultiHeadSSMConfig {
    pub fn new(hidden_dim: usize, num_heads: usize, state_dim: usize) -> CoreResult<Self> {
        if !hidden_dim.is_multiple_of(num_heads) {
            return Err(CoreError::InvalidConfig(format!(
                "hidden_dim ({}) must be divisible by num_heads ({})",
                hidden_dim, num_heads
            )));
        }

        Ok(Self {
            hidden_dim,
            num_heads,
            head_dim: hidden_dim / num_heads,
            state_dim,
            dropout: 0.0,
            causal: true,
        })
    }

    pub fn dropout(mut self, rate: f32) -> Self {
        self.dropout = rate;
        self
    }

    pub fn causal(mut self, causal: bool) -> Self {
        self.causal = causal;
        self
    }
}

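/// Multi-head attention with learned query/key/value/output projections and
/// optional per-projection bias vectors.
///
/// Sketch of intended usage (mirrors the unit tests below; the surrounding
/// module path is assumed):
///
/// ```ignore
/// let config = MultiHeadSSMConfig::new(512, 8, 64)?;
/// let attn = MultiHeadSSMAttention::new(config, false)?;
/// let out = attn.forward_batch(&input, None)?;
/// ```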
#[derive(Debug)]
pub struct MultiHeadSSMAttention {
    config: MultiHeadSSMConfig,
    w_q: Array2<f32>,
    w_k: Array2<f32>,
    w_v: Array2<f32>,
    w_o: Array2<f32>,
    b_q: Option<Array1<f32>>,
    b_k: Option<Array1<f32>>,
    b_v: Option<Array1<f32>>,
    b_o: Option<Array1<f32>>,
}

impl MultiHeadSSMAttention {
    pub fn new(config: MultiHeadSSMConfig, use_bias: bool) -> CoreResult<Self> {
        let hidden_dim = config.hidden_dim;
        let mut rng = thread_rng();
        let scale = (1.0 / hidden_dim as f32).sqrt();

        let w_q = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });
        let w_k = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });
        let w_v = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });
        let w_o = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        let (b_q, b_k, b_v, b_o) = if use_bias {
            (
                Some(Array1::zeros(hidden_dim)),
                Some(Array1::zeros(hidden_dim)),
                Some(Array1::zeros(hidden_dim)),
                Some(Array1::zeros(hidden_dim)),
            )
        } else {
            (None, None, None, None)
        };

        Ok(Self {
            config,
            w_q,
            w_k,
            w_v,
            w_o,
            b_q,
            b_k,
            b_v,
            b_o,
        })
    }

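    /// Single-token attention step: projects `query`, attends each head over
    /// the rows of `key_cache`/`value_cache` (one row per cached position,
    /// heads laid out contiguously along the feature axis), and applies the
    /// output projection to the concatenated per-head contexts.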
    pub fn forward_step(
        &self,
        query: &Array1<f32>,
        key_cache: &Array2<f32>,
        value_cache: &Array2<f32>,
    ) -> CoreResult<Array1<f32>> {
        let num_heads = self.config.num_heads;
        let head_dim = self.config.head_dim;
        let seq_len = key_cache.nrows();

        let q = self.project_qkv(&self.w_q, &self.b_q, query);

        let q_heads = self.reshape_to_heads(&q)?;

        let mut attn_output = Array1::zeros(self.config.hidden_dim);
        let scale = 1.0 / (head_dim as f32).sqrt();

        for h in 0..num_heads {
            let q_h = q_heads.slice(s![h, ..]);

            let mut scores = Array1::zeros(seq_len);
            for i in 0..seq_len {
                let k_i = key_cache.slice(s![i, h * head_dim..(h + 1) * head_dim]);
                scores[i] = simd::dot_view(q_h, k_i) * scale;
            }

            // In single-step decoding every cached position precedes the current
            // query, so causal mode requires no additional masking here.

            let attn_weights = softmax_stable(&scores);

            let mut context = Array1::zeros(head_dim);
            for i in 0..seq_len {
                let v_i = value_cache.slice(s![i, h * head_dim..(h + 1) * head_dim]);
                let weight = attn_weights[i];
                for j in 0..head_dim {
                    context[j] += weight * v_i[j];
                }
            }

            let start = h * head_dim;
            let end = start + head_dim;
            attn_output.slice_mut(s![start..end]).assign(&context);
        }

        let output = if let Some(ref bias) = self.b_o {
            attn_output.dot(&self.w_o) + bias
        } else {
            attn_output.dot(&self.w_o)
        };

        Ok(output)
    }

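    /// Batched self-attention over `input` of shape `(batch, seq_len, hidden_dim)`.
    ///
    /// Projects Q/K/V for every position, restricts each query to positions
    /// `0..=t` when `causal` is set, and applies the optional boolean padding
    /// `mask` (`false` masks a position out) before the softmax.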
    pub fn forward_batch(
        &self,
        input: &Array3<f32>,
        mask: Option<&Array2<bool>>,
    ) -> CoreResult<Array3<f32>> {
        let (batch_size, seq_len, _hidden_dim) = input.dim();
        let num_heads = self.config.num_heads;
        let head_dim = self.config.head_dim;

        let mut output = Array3::zeros((batch_size, seq_len, self.config.hidden_dim));

        for b in 0..batch_size {
            let input_batch = input.index_axis(Axis(0), b);

            // Project queries, keys, and values for every position up front.
            let mut q_all = Array2::zeros((seq_len, self.config.hidden_dim));
            let mut k_all = Array2::zeros((seq_len, self.config.hidden_dim));
            let mut v_all = Array2::zeros((seq_len, self.config.hidden_dim));

            for t in 0..seq_len {
                let x_t = input_batch.index_axis(Axis(0), t).to_owned();
                q_all
                    .index_axis_mut(Axis(0), t)
                    .assign(&self.project_qkv(&self.w_q, &self.b_q, &x_t));
                k_all
                    .index_axis_mut(Axis(0), t)
                    .assign(&self.project_qkv(&self.w_k, &self.b_k, &x_t));
                v_all
                    .index_axis_mut(Axis(0), t)
                    .assign(&self.project_qkv(&self.w_v, &self.b_v, &x_t));
            }

            for t in 0..seq_len {
                let q_t = q_all.index_axis(Axis(0), t).to_owned();
                let q_heads = self.reshape_to_heads(&q_t)?;

                let mut attn_output = Array1::zeros(self.config.hidden_dim);
                let scale = 1.0 / (head_dim as f32).sqrt();

                for h in 0..num_heads {
                    let q_h = q_heads.slice(s![h, ..]);

                    // Causal attention only looks at positions 0..=t.
                    let attend_len = if self.config.causal { t + 1 } else { seq_len };
                    let mut scores = Array1::zeros(attend_len);

                    for i in 0..attend_len {
                        let k_i = k_all.slice(s![i, h * head_dim..(h + 1) * head_dim]);
                        scores[i] = simd::dot_view(q_h, k_i) * scale;
                    }

                    // Padding mask: `false` entries are excluded from the softmax.
                    if let Some(mask_data) = mask {
                        for i in 0..attend_len {
                            if !mask_data[[b, i]] {
                                scores[i] = f32::NEG_INFINITY;
                            }
                        }
                    }

                    let attn_weights = softmax_stable(&scores);

                    let mut context = Array1::zeros(head_dim);
                    for i in 0..attend_len {
                        let v_i = v_all.slice(s![i, h * head_dim..(h + 1) * head_dim]);
                        let weight = attn_weights[i];
                        for j in 0..head_dim {
                            context[j] += weight * v_i[j];
                        }
                    }

                    let start = h * head_dim;
                    let end = start + head_dim;
                    attn_output.slice_mut(s![start..end]).assign(&context);
                }

                let out_t = if let Some(ref bias) = self.b_o {
                    attn_output.dot(&self.w_o) + bias
                } else {
                    attn_output.dot(&self.w_o)
                };

                output
                    .index_axis_mut(Axis(0), b)
                    .index_axis_mut(Axis(0), t)
                    .assign(&out_t);
            }
        }

        Ok(output)
    }

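    /// Applies `input · weight` plus the optional bias; shared helper for the
    /// query/key/value projections.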
    fn project_qkv(
        &self,
        weight: &Array2<f32>,
        bias: &Option<Array1<f32>>,
        input: &Array1<f32>,
    ) -> Array1<f32> {
        if let Some(ref b) = bias {
            input.dot(weight) + b
        } else {
            input.dot(weight)
        }
    }

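    /// Splits a flat `hidden_dim` vector into a `(num_heads, head_dim)` matrix,
    /// returning an error if the input length does not match `hidden_dim`.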
    fn reshape_to_heads(&self, x: &Array1<f32>) -> CoreResult<Array2<f32>> {
        if x.len() != self.config.hidden_dim {
            return Err(CoreError::DimensionMismatch {
                expected: self.config.hidden_dim,
                got: x.len(),
            });
        }

        let mut result = Array2::zeros((self.config.num_heads, self.config.head_dim));
        for h in 0..self.config.num_heads {
            let start = h * self.config.head_dim;
            let end = start + self.config.head_dim;
            result.row_mut(h).assign(&x.slice(s![start..end]));
        }

        Ok(result)
    }

    pub fn config(&self) -> &MultiHeadSSMConfig {
        &self.config
    }

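    /// Total number of learnable scalars across the four projection matrices
    /// and, when biases are enabled, the four bias vectors.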
    pub fn num_parameters(&self) -> usize {
        let weight_params = self.w_q.len() + self.w_k.len() + self.w_v.len() + self.w_o.len();
        let bias_params = if self.b_q.is_some() {
            4 * self.config.hidden_dim
        } else {
            0
        };
        weight_params + bias_params
    }
}

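/// Gated linear attention: maintains a `hidden_dim × hidden_dim` key-value
/// state updated additively at each step, so per-token cost is independent of
/// sequence length.
///
/// Sketch of a decode loop step (input vector `x_t` assumed):
///
/// ```ignore
/// let gla = GatedLinearAttention::new(64)?;
/// let mut state = gla.reset_state();
/// let y_t = gla.forward_step(&x_t, &mut state)?;
/// ```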
#[derive(Debug)]
pub struct GatedLinearAttention {
    hidden_dim: usize,
    w_gate: Array2<f32>,
    w_q: Array2<f32>,
    w_k: Array2<f32>,
    w_o: Array2<f32>,
}

impl GatedLinearAttention {
    pub fn new(hidden_dim: usize) -> CoreResult<Self> {
        let mut rng = thread_rng();
        let scale = (1.0 / hidden_dim as f32).sqrt();

        let w_gate = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });
        let w_q = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });
        let w_k = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });
        let w_o = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        Ok(Self {
            hidden_dim,
            w_gate,
            w_q,
            w_k,
            w_o,
        })
    }

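    /// Single recurrent step: computes q, k, and a sigmoid gate from `input`,
    /// accumulates the outer product `k ⊗ (gate ⊙ input)` into `kv_state`,
    /// reads out `qᵀ · kv_state`, and applies the output projection.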
    pub fn forward_step(
        &self,
        input: &Array1<f32>,
        kv_state: &mut Array2<f32>,
    ) -> CoreResult<Array1<f32>> {
        let q = input.dot(&self.w_q);
        let k = input.dot(&self.w_k);
        let g = input.dot(&self.w_gate);

        // Sigmoid gate controls how much of the current input is written into the state.
        let gate = g.mapv(|x| 1.0 / (1.0 + safe_exp(-x)));

        // Accumulate the outer product of the key and the gated input.
        let gated_value = &gate * input;
        for i in 0..self.hidden_dim {
            for j in 0..self.hidden_dim {
                kv_state[[i, j]] += k[i] * gated_value[j];
            }
        }

        // Read out the state with the query: attn_out = qᵀ · kv_state.
        let mut attn_out = Array1::zeros(self.hidden_dim);
        for j in 0..self.hidden_dim {
            let mut sum = 0.0;
            for i in 0..self.hidden_dim {
                sum += q[i] * kv_state[[i, j]];
            }
            attn_out[j] = sum;
        }

        let output = attn_out.dot(&self.w_o);
        Ok(output)
    }

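    /// Returns a fresh zero-initialized `hidden_dim × hidden_dim` key-value state.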
    pub fn reset_state(&self) -> Array2<f32> {
        Array2::zeros((self.hidden_dim, self.hidden_dim))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_multihead_ssm_config() {
        let config = MultiHeadSSMConfig::new(512, 8, 64).unwrap();
        assert_eq!(config.hidden_dim, 512);
        assert_eq!(config.num_heads, 8);
        assert_eq!(config.head_dim, 64);
    }

    #[test]
    fn test_multihead_ssm_attention() {
        let config = MultiHeadSSMConfig::new(64, 4, 16).unwrap();
        let attn = MultiHeadSSMAttention::new(config, false).unwrap();

        let query = Array1::from_vec(vec![0.1; 64]);
        let key_cache = Array2::from_shape_vec((10, 64), vec![0.1; 640]).unwrap();
        let value_cache = Array2::from_shape_vec((10, 64), vec![0.2; 640]).unwrap();

        let output = attn.forward_step(&query, &key_cache, &value_cache).unwrap();
        assert_eq!(output.len(), 64);
    }

    #[test]
    fn test_gated_linear_attention() {
        let gla = GatedLinearAttention::new(64).unwrap();
        let input = Array1::from_vec(vec![0.1; 64]);
        let mut kv_state = gla.reset_state();

        let output = gla.forward_step(&input, &mut kv_state).unwrap();
        assert_eq!(output.len(), 64);
    }

    #[test]
    fn test_multihead_batch_forward() {
        let config = MultiHeadSSMConfig::new(64, 4, 16).unwrap();
        let attn = MultiHeadSSMAttention::new(config, false).unwrap();

        let batch_size = 2;
        let seq_len = 5;
        let input = Array3::from_shape_vec(
            (batch_size, seq_len, 64),
            vec![0.1; batch_size * seq_len * 64],
        )
        .unwrap();

        let output = attn.forward_batch(&input, None).unwrap();
        assert_eq!(output.dim(), (batch_size, seq_len, 64));
    }
}