use crate::error::{ModelError, ModelResult};
use crate::{AutoregressiveModel, ModelType};
use kizzasi_core::{silu, softmax, CoreResult, HiddenState, SignalPredictor};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::{rng, Rng};
use std::collections::VecDeque;

#[allow(unused_imports)]
use tracing::{debug, instrument, trace};

/// The kind of layer at a given position in the hybrid stack.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LayerType {
    /// Selective state-space (Mamba) layer.
    Mamba,
    /// Multi-head attention layer.
    Attention,
}

/// Configuration for a hybrid Mamba + attention model.
#[derive(Debug, Clone)]
pub struct HybridConfig {
    /// Input (and output) feature dimension.
    pub input_dim: usize,
    /// Hidden dimension shared by every layer.
    pub hidden_dim: usize,
    /// State dimension of each Mamba block.
    pub state_dim: usize,
    /// Total number of layers in the stack.
    pub num_layers: usize,
    /// Number of attention heads.
    pub num_heads: usize,
    /// Maximum sequence length (attention KV-cache capacity).
    pub max_seq_len: usize,
    /// Per-layer type pattern; must contain `num_layers` entries.
    pub layer_pattern: Vec<LayerType>,
}

impl HybridConfig {
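    /// Builds a config whose layers strictly alternate, starting with Mamba
    /// at layer 0: `num_layers = 4` yields the pattern
    /// `[Mamba, Attention, Mamba, Attention]`. The remaining fields use the
    /// defaults set below (`state_dim = 64`, `max_seq_len = 2048`).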
    pub fn alternating(
        input_dim: usize,
        hidden_dim: usize,
        num_layers: usize,
        num_heads: usize,
    ) -> Self {
        let layer_pattern = (0..num_layers)
            .map(|i| {
                if i % 2 == 0 {
                    LayerType::Mamba
                } else {
                    LayerType::Attention
                }
            })
            .collect();

        Self {
            input_dim,
            hidden_dim,
            state_dim: 64,
            num_layers,
            num_heads,
            max_seq_len: 2048,
            layer_pattern,
        }
    }

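    /// Builds a Mamba-heavy config in which every fourth layer (indices 3,
    /// 7, ...) is attention and the rest are Mamba: `num_layers = 8` yields
    /// `[M, M, M, A, M, M, M, A]`, a 3:1 Mamba-to-attention ratio. Defaults
    /// match `alternating` (`state_dim = 64`, `max_seq_len = 2048`).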
    pub fn mamba_heavy(
        input_dim: usize,
        hidden_dim: usize,
        num_layers: usize,
        num_heads: usize,
    ) -> Self {
        let layer_pattern = (0..num_layers)
            .map(|i| {
                // Every fourth layer is attention; the rest are Mamba.
                if i % 4 == 3 {
                    LayerType::Attention
                } else {
                    LayerType::Mamba
                }
            })
            .collect();

        Self {
            input_dim,
            hidden_dim,
            state_dim: 64,
            num_layers,
            num_heads,
            max_seq_len: 2048,
            layer_pattern,
        }
    }

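    /// Checks the config for internal consistency: all dimensions must be
    /// nonzero, `hidden_dim` must be divisible by `num_heads`, and
    /// `layer_pattern` must contain exactly `num_layers` entries.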
    pub fn validate(&self) -> ModelResult<()> {
        if self.hidden_dim == 0 {
            return Err(ModelError::invalid_config("hidden_dim must be > 0"));
        }
        if self.state_dim == 0 {
            return Err(ModelError::invalid_config("state_dim must be > 0"));
        }
        if self.num_layers == 0 {
            return Err(ModelError::invalid_config("num_layers must be > 0"));
        }
        if self.num_heads == 0 {
            return Err(ModelError::invalid_config("num_heads must be > 0"));
        }
        if !self.hidden_dim.is_multiple_of(self.num_heads) {
            return Err(ModelError::invalid_config(
                "hidden_dim must be divisible by num_heads",
            ));
        }
        if self.layer_pattern.len() != self.num_layers {
            return Err(ModelError::invalid_config(
                "layer_pattern length must equal num_layers",
            ));
        }
        Ok(())
    }
}

/// A single simplified Mamba (selective state-space) block that carries a
/// per-sequence recurrent state.
#[allow(dead_code)]
struct MambaBlock {
    hidden_dim: usize,
    state_dim: usize,
    /// Input projection (hidden_dim x hidden_dim).
    proj_in: Array2<f32>,
    /// Output projection (hidden_dim x hidden_dim).
    proj_out: Array2<f32>,
    /// Log-magnitude of the state transition: `a_log[i] = ln(i + 1)`.
    a_log: Array1<f32>,
    /// Input-to-state matrix B (state_dim x hidden_dim).
    b_matrix: Array2<f32>,
    /// State-to-output matrix C (hidden_dim x state_dim).
    c_matrix: Array2<f32>,
    /// Recurrent state, zeroed on reset.
    state: Array1<f32>,
}

impl MambaBlock {
    fn new(hidden_dim: usize, state_dim: usize) -> Self {
        let mut rng = rng();

        // Xavier-style uniform init for the square projections.
        let scale = (2.0 / (hidden_dim + hidden_dim) as f32).sqrt();
        let proj_in = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        let proj_out = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        // S4/Mamba-style init: a_log[i] = ln(i + 1), so the continuous-time
        // transition A = -exp(a_log) = -(i + 1) is strictly negative and the
        // discretized recurrence below stays stable.
        let a_log = Array1::from_shape_fn(state_dim, |i| ((i + 1) as f32).ln());

        let scale = (1.0 / state_dim as f32).sqrt();
        let b_matrix = Array2::from_shape_fn((state_dim, hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        let c_matrix = Array2::from_shape_fn((hidden_dim, state_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        let state = Array1::zeros(state_dim);

        Self {
            hidden_dim,
            state_dim,
            proj_in,
            proj_out,
            a_log,
            b_matrix,
            c_matrix,
            state,
        }
    }

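    // What `forward` computes, one step at a time (Δ = 0.001 is the fixed
    // discretization step used below):
    //
    //   x_p  = x · W_in                    input projection
    //   ā    = exp(-Δ · exp(a_log))        discrete transition, 0 < ā < 1
    //   s'   = ā ⊙ s + Δ · (B · x_p)       recurrent state update
    //   y    = C · s'                      state readout
    //   out  = (silu(x_p) ⊙ y) · W_out     gated output projection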
    fn forward(&mut self, x: &Array1<f32>) -> Array1<f32> {
        // Project the input into the hidden space.
        let projected = x.dot(&self.proj_in);

        // Discretize the transition with step size Δ = 0.001; ā < 1 keeps
        // the recurrence from diverging over long sequences.
        let a_bar = self.a_log.mapv(|a| (-0.001 * a.exp()).exp());
        self.state = &self.state * &a_bar + self.b_matrix.dot(&projected) * 0.001;

        // Read the state out and gate it with the SiLU of the projection.
        let ssm_out = self.c_matrix.dot(&self.state);

        let gated = silu(&projected) * &ssm_out;

        gated.dot(&self.proj_out)
    }

    fn reset(&mut self) {
        self.state.fill(0.0);
    }
}

/// A causal attention block that attends over a sliding window of cached
/// keys and values (one entry per processed step).
#[allow(dead_code)]
struct AttentionBlock {
    hidden_dim: usize,
    num_heads: usize,
    head_dim: usize,
    /// Query/key/value/output projections (hidden_dim x hidden_dim).
    q_proj: Array2<f32>,
    k_proj: Array2<f32>,
    v_proj: Array2<f32>,
    o_proj: Array2<f32>,
    /// Sliding KV cache, bounded by `max_cache_len`.
    k_cache: VecDeque<Array1<f32>>,
    v_cache: VecDeque<Array1<f32>>,
    max_cache_len: usize,
}

impl AttentionBlock {
    fn new(hidden_dim: usize, num_heads: usize, max_seq_len: usize) -> Self {
        let mut rng = rng();
        let head_dim = hidden_dim / num_heads;

        // Xavier-style uniform init for the four projections.
        let scale = (2.0 / (hidden_dim + hidden_dim) as f32).sqrt();
        let q_proj = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });
        let k_proj = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });
        let v_proj = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });
        let o_proj = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        Self {
            hidden_dim,
            num_heads,
            head_dim,
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            k_cache: VecDeque::new(),
            v_cache: VecDeque::new(),
            max_cache_len: max_seq_len,
        }
    }

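    // One decoding step of KV-cached attention: project q/k/v from the
    // current input, append k and v to the sliding cache, score q against
    // every cached key (scaled by √head_dim), softmax the scores, and return
    // the weighted sum of cached values through the output projection. Note
    // that heads are never actually split: each score is one dot product
    // over the full hidden vector, so this behaves as single-head attention
    // with a multi-head scaling factor.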
    fn forward(&mut self, x: &Array1<f32>) -> Array1<f32> {
        // Project the current input into query, key, and value vectors.
        let q = x.dot(&self.q_proj);
        let k = x.dot(&self.k_proj);
        let v = x.dot(&self.v_proj);

        // Append to the cache, evicting the oldest entries beyond capacity.
        self.k_cache.push_back(k.clone());
        self.v_cache.push_back(v.clone());

        while self.k_cache.len() > self.max_cache_len {
            self.k_cache.pop_front();
            self.v_cache.pop_front();
        }

        let cache_len = self.k_cache.len();
        let mut attention_out = Array1::zeros(self.hidden_dim);

        if cache_len > 0 {
            // Scaled dot-product scores against every cached key.
            let mut scores = Vec::with_capacity(cache_len);
            for k_cached in &self.k_cache {
                let score = q.dot(k_cached) / (self.head_dim as f32).sqrt();
                scores.push(score);
            }

            let scores_array = Array1::from_vec(scores);
            let attn_weights = softmax(&scores_array);

            // Weighted sum of the cached values.
            for (weight, v_cached) in attn_weights.iter().zip(self.v_cache.iter()) {
                attention_out = attention_out + v_cached * *weight;
            }
        }

        attention_out.dot(&self.o_proj)
    }

    fn reset(&mut self) {
        self.k_cache.clear();
        self.v_cache.clear();
    }
}

/// A layer in the hybrid stack: either a Mamba block or an attention block.
enum HybridLayer {
    Mamba(MambaBlock),
    Attention(AttentionBlock),
}

impl HybridLayer {
    fn forward(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
        match self {
            HybridLayer::Mamba(mamba) => Ok(mamba.forward(x)),
            HybridLayer::Attention(attn) => Ok(attn.forward(x)),
        }
    }

    fn reset(&mut self) {
        match self {
            HybridLayer::Mamba(mamba) => mamba.reset(),
            HybridLayer::Attention(attn) => attn.reset(),
        }
    }
}

/// Hybrid autoregressive model that interleaves Mamba and attention layers
/// according to `HybridConfig::layer_pattern`, with linear input and output
/// projections on either side of the stack.
pub struct HybridModel {
    config: HybridConfig,
    layers: Vec<HybridLayer>,
    /// Projects `input_dim` features into the hidden space.
    input_proj: Array2<f32>,
    /// Projects hidden features back to `input_dim`.
    output_proj: Array2<f32>,
}

impl HybridModel {
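    /// Validates the config and builds the layer stack with randomly
    /// initialized weights.
    ///
    /// Usage sketch, as an ignored doctest (crate paths and a
    /// `Result`-returning caller are assumed):
    ///
    /// ```ignore
    /// let config = HybridConfig::alternating(32, 64, 4, 4);
    /// let mut model = HybridModel::new(config)?;
    /// let output = model.step(&Array1::zeros(32))?; // one autoregressive step
    /// ```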
    #[instrument(skip(config), fields(num_layers = config.num_layers))]
    pub fn new(config: HybridConfig) -> ModelResult<Self> {
        debug!("Creating new Hybrid Mamba+Attention model");
        config.validate()?;

        let mut rng = rng();

        // Xavier-style init for the input projection.
        let scale = (2.0 / (config.input_dim + config.hidden_dim) as f32).sqrt();
        let input_proj = Array2::from_shape_fn((config.input_dim, config.hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        // ...and for the output projection.
        let scale = (2.0 / (config.hidden_dim + config.input_dim) as f32).sqrt();
        let output_proj = Array2::from_shape_fn((config.hidden_dim, config.input_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        // Build the stack following the configured layer pattern.
        let mut layers = Vec::with_capacity(config.num_layers);
        for (i, &layer_type) in config.layer_pattern.iter().enumerate() {
            trace!("Initializing hybrid layer {} as {:?}", i, layer_type);
            let layer = match layer_type {
                LayerType::Mamba => {
                    HybridLayer::Mamba(MambaBlock::new(config.hidden_dim, config.state_dim))
                }
                LayerType::Attention => HybridLayer::Attention(AttentionBlock::new(
                    config.hidden_dim,
                    config.num_heads,
                    config.max_seq_len,
                )),
            };
            layers.push(layer);
        }

        debug!(
            "Hybrid model created successfully with {} layers",
            layers.len()
        );
        Ok(Self {
            config,
            layers,
            input_proj,
            output_proj,
        })
    }

    /// Returns the model configuration.
    pub fn config(&self) -> &HybridConfig {
        &self.config
    }

    /// Returns `(mamba_count, attention_count)` for the configured pattern.
    pub fn layer_counts(&self) -> (usize, usize) {
        let mamba_count = self
            .config
            .layer_pattern
            .iter()
            .filter(|&&t| t == LayerType::Mamba)
            .count();
        let attention_count = self.config.num_layers - mamba_count;
        (mamba_count, attention_count)
    }
}

impl SignalPredictor for HybridModel {
    #[instrument(skip(self, input))]
    fn step(&mut self, input: &Array1<f32>) -> CoreResult<Array1<f32>> {
        // Lift the input into the hidden space.
        let mut hidden = input.dot(&self.input_proj);

        // Thread the hidden vector through every layer in order.
        for layer in &mut self.layers {
            hidden = layer.forward(&hidden)?;
        }

        // Project back to the input dimension.
        let output = hidden.dot(&self.output_proj);
        Ok(output)
    }

    #[instrument(skip(self))]
    fn reset(&mut self) {
        debug!("Resetting Hybrid model state");
        for layer in &mut self.layers {
            layer.reset();
        }
    }

    fn context_window(&self) -> usize {
        self.config.max_seq_len
    }
}

impl AutoregressiveModel for HybridModel {
    fn hidden_dim(&self) -> usize {
        self.config.hidden_dim
    }

    fn state_dim(&self) -> usize {
        self.config.state_dim
    }

    fn num_layers(&self) -> usize {
        self.config.num_layers
    }

    fn model_type(&self) -> ModelType {
        // Hybrid stacks report as Mamba-type, their backbone layer family.
        ModelType::Mamba
    }

    fn get_states(&self) -> Vec<HiddenState> {
        // Internal layer states are not exported; return freshly initialized
        // placeholder states, one per layer.
        (0..self.config.num_layers)
            .map(|_| HiddenState::new(self.config.hidden_dim, self.config.state_dim))
            .collect()
    }

    fn set_states(&mut self, states: Vec<HiddenState>) -> ModelResult<()> {
        if states.len() != self.config.num_layers {
            return Err(ModelError::state_count_mismatch(
                "Hybrid",
                self.config.num_layers,
                states.len(),
            ));
        }
        // Only the state count is validated; hybrid layers keep managing
        // their state internally.
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hybrid_creation_alternating() {
        let config = HybridConfig::alternating(32, 64, 4, 4);
        let model = HybridModel::new(config);
        assert!(model.is_ok());
    }

    #[test]
    fn test_hybrid_creation_mamba_heavy() {
        let config = HybridConfig::mamba_heavy(32, 64, 8, 4);
        let model = HybridModel::new(config);
        assert!(model.is_ok());
    }

    #[test]
    fn test_hybrid_forward() {
        let config = HybridConfig::alternating(32, 64, 4, 4);
        let mut model = HybridModel::new(config).expect("Failed to create HybridModel");

        let input = Array1::from_vec(vec![1.0; 32]);
        let output = model.step(&input);
        assert!(output.is_ok());
        assert_eq!(output.expect("Failed to get output").len(), 32);
    }

    #[test]
    fn test_hybrid_layer_counts() {
        let config = HybridConfig::alternating(32, 64, 6, 4);
        let model = HybridModel::new(config).expect("Failed to create HybridModel");
        let (mamba, attn) = model.layer_counts();
        assert_eq!(mamba, 3);
        assert_eq!(attn, 3);
    }

    #[test]
    fn test_hybrid_mamba_heavy_counts() {
        let config = HybridConfig::mamba_heavy(32, 64, 8, 4);
        let model = HybridModel::new(config).expect("Failed to create HybridModel");
        let (mamba, attn) = model.layer_counts();
        assert_eq!(mamba, 6);
        assert_eq!(attn, 2);
    }

    #[test]
    fn test_hybrid_reset() {
        let config = HybridConfig::alternating(32, 64, 4, 4);
        let mut model = HybridModel::new(config).expect("Failed to create HybridModel");

        let input = Array1::from_vec(vec![0.5; 32]);
        let _ = model.step(&input).expect("Failed to step model");

        model.reset();

        let output = model.step(&input).expect("Failed to get output");
        assert_eq!(output.len(), 32);
    }

    #[test]
    fn test_invalid_config() {
        let mut config = HybridConfig::alternating(32, 64, 4, 4);
        config.layer_pattern.push(LayerType::Mamba);
        assert!(config.validate().is_err());
    }
}