use crate::error::{ModelError, ModelResult};
use crate::{AutoregressiveModel, ModelType};
use kizzasi_core::{gelu, softmax, CoreResult, HiddenState, LayerNorm, NormType, SignalPredictor};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::{rng, Rng};
use std::collections::VecDeque;
#[allow(unused_imports)]
use tracing::{debug, instrument, trace};

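/// Configuration for the autoregressive `Transformer` signal predictor.
///
/// `head_dim` is derived as `hidden_dim / num_heads`; the builder methods
/// below keep it in sync, and `validate()` rejects inconsistent settings.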
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TransformerConfig {
    pub input_dim: usize,
    pub hidden_dim: usize,
    pub num_heads: usize,
    pub head_dim: usize,
    pub ff_dim: usize,
    pub num_layers: usize,
    pub max_seq_len: usize,
    pub dropout: f32,
    pub use_rms_norm: bool,
    pub causal: bool,
}

impl Default for TransformerConfig {
    fn default() -> Self {
        let hidden_dim = 512;
        let num_heads = 8;
        Self {
            input_dim: 1,
            hidden_dim,
            num_heads,
            head_dim: hidden_dim / num_heads,
            ff_dim: hidden_dim * 4,
            num_layers: 6,
            max_seq_len: 2048,
            dropout: 0.1,
            use_rms_norm: true,
            causal: true,
        }
    }
}

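// Builder-style construction; `hidden_dim` and `num_heads` each refresh
// `head_dim`, so they may be chained in either order, e.g. (illustrative):
//
//     let config = TransformerConfig::new().hidden_dim(256).num_heads(8);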
impl TransformerConfig {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn input_dim(mut self, dim: usize) -> Self {
        self.input_dim = dim;
        self
    }

    pub fn hidden_dim(mut self, dim: usize) -> Self {
        self.hidden_dim = dim;
        self.head_dim = dim / self.num_heads;
        self
    }

    pub fn num_heads(mut self, n: usize) -> Self {
        self.num_heads = n;
        self.head_dim = self.hidden_dim / n;
        self
    }

    pub fn num_layers(mut self, n: usize) -> Self {
        self.num_layers = n;
        self
    }

    pub fn max_seq_len(mut self, len: usize) -> Self {
        self.max_seq_len = len;
        self
    }

    pub fn validate(&self) -> ModelResult<()> {
        if self.hidden_dim == 0 {
            return Err(ModelError::invalid_config("hidden_dim must be > 0"));
        }
        if self.num_heads == 0 {
            return Err(ModelError::invalid_config("num_heads must be > 0"));
        }
        if !self.hidden_dim.is_multiple_of(self.num_heads) {
            return Err(ModelError::invalid_config(
                "hidden_dim must be divisible by num_heads",
            ));
        }
        if self.num_layers == 0 {
            return Err(ModelError::invalid_config("num_layers must be > 0"));
        }
        if self.max_seq_len == 0 {
            return Err(ModelError::invalid_config("max_seq_len must be > 0"));
        }
        Ok(())
    }
}

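// Single-token multi-head attention with a rolling key/value cache.
//
// Each call to `forward` scores the new query against every cached key,
// per head: score[pos] = (q_h · k_h[pos]) / sqrt(head_dim), then applies a
// softmax over positions and mixes the cached values with those weights.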
struct MultiHeadAttention {
    num_heads: usize,
    head_dim: usize,
    hidden_dim: usize,

    q_proj: Array2<f32>,
    k_proj: Array2<f32>,
    v_proj: Array2<f32>,

    o_proj: Array2<f32>,

    key_cache: VecDeque<Array1<f32>>,
    value_cache: VecDeque<Array1<f32>>,
    max_cache_len: usize,
}

impl MultiHeadAttention {
    fn new(config: &TransformerConfig) -> ModelResult<Self> {
        let mut rng = rng();
        let scale = (2.0 / config.hidden_dim as f32).sqrt();

        let q_proj = Array2::from_shape_fn((config.hidden_dim, config.hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });
        let k_proj = Array2::from_shape_fn((config.hidden_dim, config.hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });
        let v_proj = Array2::from_shape_fn((config.hidden_dim, config.hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });
        let o_proj = Array2::from_shape_fn((config.hidden_dim, config.hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        Ok(Self {
            num_heads: config.num_heads,
            head_dim: config.head_dim,
            hidden_dim: config.hidden_dim,
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            key_cache: VecDeque::new(),
            value_cache: VecDeque::new(),
            max_cache_len: config.max_seq_len,
        })
    }

    fn forward(&mut self, x: &Array1<f32>, causal: bool) -> CoreResult<Array1<f32>> {
        let batch_size = x.len().min(self.hidden_dim);

        let q = self.project(x, &self.q_proj);
        let k = self.project(x, &self.k_proj);
        let v = self.project(x, &self.v_proj);

        self.key_cache.push_back(k.clone());
        self.value_cache.push_back(v.clone());

        while self.key_cache.len() > self.max_cache_len {
            self.key_cache.pop_front();
            self.value_cache.pop_front();
        }

        let seq_len = self.key_cache.len();
        let scale = (self.head_dim as f32).sqrt();

        let mut attention_output = Array1::zeros(batch_size);

        for h in 0..self.num_heads {
            let head_start = h * self.head_dim;
            let _head_end = (head_start + self.head_dim).min(batch_size);

            let mut scores = Vec::with_capacity(seq_len);
            for pos in 0..seq_len {
                let k_cached = &self.key_cache[pos];
                let mut score = 0.0;

                for i in 0..self.head_dim {
                    let q_idx = head_start + i;
                    let k_idx = head_start + i;
                    if q_idx < q.len() && k_idx < k_cached.len() {
                        score += q[q_idx] * k_cached[k_idx];
                    }
                }
                score /= scale;

                // The query is always the newest cached position (seq_len - 1),
                // so `pos < seq_len` holds for every cached entry and the causal
                // branch never masks anything in this single-query setting.
                if !causal || pos < seq_len {
                    scores.push(score);
                } else {
                    scores.push(f32::NEG_INFINITY);
                }
            }

            let attention_weights = softmax(&Array1::from_vec(scores));

            for i in 0..self.head_dim {
                let out_idx = head_start + i;
                if out_idx >= attention_output.len() {
                    break;
                }

                let mut weighted_value = 0.0;
                for (pos, &weight) in attention_weights.iter().enumerate() {
                    let v_cached = &self.value_cache[pos];
                    let v_idx = head_start + i;
                    if v_idx < v_cached.len() {
                        weighted_value += weight * v_cached[v_idx];
                    }
                }
                attention_output[out_idx] = weighted_value;
            }
        }

        let output = self.project(&attention_output, &self.o_proj);
        Ok(output)
    }

    fn project(&self, x: &Array1<f32>, weight: &Array2<f32>) -> Array1<f32> {
        let out_dim = weight.shape()[0];
        let mut output = Array1::zeros(out_dim.min(x.len()));
        for i in 0..output.len() {
            let mut sum = 0.0;
            for j in 0..x.len().min(weight.shape()[1]) {
                sum += weight[[i, j]] * x[j];
            }
            output[i] = sum;
        }
        output
    }

    fn reset(&mut self) {
        self.key_cache.clear();
        self.value_cache.clear();
    }
}

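// Position-wise feed-forward block: a GELU between two dense layers, with
// `fc1: (hidden_dim, ff_dim)` expanding and `fc2: (ff_dim, hidden_dim)`
// projecting back down.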
struct FeedForward {
    fc1: Array2<f32>,
    fc2: Array2<f32>,
}

impl FeedForward {
    fn new(config: &TransformerConfig) -> ModelResult<Self> {
        let mut rng = rng();
        let scale1 = (2.0 / config.hidden_dim as f32).sqrt();
        let scale2 = (2.0 / config.ff_dim as f32).sqrt();

        let fc1 = Array2::from_shape_fn((config.hidden_dim, config.ff_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale1
        });
        let fc2 = Array2::from_shape_fn((config.ff_dim, config.hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale2
        });

        Ok(Self { fc1, fc2 })
    }

    fn forward(&self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
        let mut hidden = Array1::zeros(self.fc1.shape()[1]);
        for i in 0..hidden.len() {
            let mut sum = 0.0;
            for j in 0..x.len().min(self.fc1.shape()[0]) {
                sum += self.fc1[[j, i]] * x[j];
            }
            hidden[i] = sum;
        }

        hidden = gelu(&hidden);

        let mut output = Array1::zeros(x.len().min(self.fc2.shape()[1]));
        for i in 0..output.len() {
            let mut sum = 0.0;
            for j in 0..hidden.len().min(self.fc2.shape()[0]) {
                sum += self.fc2[[j, i]] * hidden[j];
            }
            output[i] = sum;
        }

        Ok(output)
    }
}

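// Pre-norm residual block: x + Attention(LN1(x)), then + FeedForward(LN2(·)).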
struct TransformerLayer {
    ln1: LayerNorm,
    ln2: LayerNorm,
    attention: MultiHeadAttention,
    feed_forward: FeedForward,
    causal: bool,
}

impl TransformerLayer {
    fn new(config: &TransformerConfig) -> ModelResult<Self> {
        let norm_type = if config.use_rms_norm {
            NormType::RMSNorm
        } else {
            NormType::LayerNorm
        };

        let ln1 = LayerNorm::new(config.hidden_dim, norm_type).with_eps(1e-5);
        let ln2 = LayerNorm::new(config.hidden_dim, norm_type).with_eps(1e-5);
        let attention = MultiHeadAttention::new(config)?;
        let feed_forward = FeedForward::new(config)?;

        Ok(Self {
            ln1,
            ln2,
            attention,
            feed_forward,
            causal: config.causal,
        })
    }

    fn forward(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
        let x_norm = self.ln1.forward(x);
        let attn_out = self.attention.forward(&x_norm, self.causal)?;
        let mut x_attn = x.clone();
        for i in 0..x_attn.len().min(attn_out.len()) {
            x_attn[i] += attn_out[i];
        }

        let x_norm2 = self.ln2.forward(&x_attn);
        let ff_out = self.feed_forward.forward(&x_norm2)?;
        let mut output = x_attn;
        for i in 0..output.len().min(ff_out.len()) {
            output[i] += ff_out[i];
        }

        Ok(output)
    }

    fn reset(&mut self) {
        self.attention.reset();
    }
}

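/// Autoregressive transformer for step-wise signal prediction.
///
/// A minimal usage sketch (values are illustrative; error handling elided):
///
/// ```ignore
/// let config = TransformerConfig::new().hidden_dim(64).num_heads(4).num_layers(2);
/// let mut model = Transformer::new(config).expect("invalid config");
/// let next = model.step(&Array1::from_vec(vec![0.5]));
/// ```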
pub struct Transformer {
    config: TransformerConfig,
    layers: Vec<TransformerLayer>,
    ln_out: LayerNorm,
    input_proj: Array2<f32>,
    output_proj: Array2<f32>,
}

impl Transformer {
    pub fn new(config: TransformerConfig) -> ModelResult<Self> {
        config.validate()?;

        let mut layers = Vec::with_capacity(config.num_layers);
        for _ in 0..config.num_layers {
            layers.push(TransformerLayer::new(&config)?);
        }

        let norm_type = if config.use_rms_norm {
            NormType::RMSNorm
        } else {
            NormType::LayerNorm
        };
        let ln_out = LayerNorm::new(config.hidden_dim, norm_type).with_eps(1e-5);

        let mut rng = rng();
        let scale = (2.0 / (config.input_dim + config.hidden_dim) as f32).sqrt();
        let input_proj = Array2::from_shape_fn((config.input_dim, config.hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        let scale = (2.0 / (config.hidden_dim + config.input_dim) as f32).sqrt();
        let output_proj = Array2::from_shape_fn((config.hidden_dim, config.input_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        Ok(Self {
            config,
            layers,
            ln_out,
            input_proj,
            output_proj,
        })
    }

    pub fn config(&self) -> &TransformerConfig {
        &self.config
    }

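    /// Loads any tensors present in `loader`, leaving the random initialization
    /// in place for tensors that are missing. Expected names: `input_proj`,
    /// `output_proj`, `ln_out.{weight,bias}`, and per layer
    /// `layers.{i}.ln1/ln2.{weight,bias}`, `layers.{i}.attention.{q,k,v,o}_proj`,
    /// and `layers.{i}.feed_forward.fc{1,2}`.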
    pub fn load_weights(&mut self, loader: &crate::loader::ModelLoader) -> ModelResult<()> {
        if loader.has_tensor("input_proj") {
            self.input_proj = loader.load_array2("input_proj")?;
        }
        if loader.has_tensor("output_proj") {
            self.output_proj = loader.load_array2("output_proj")?;
        }

        if loader.has_tensor("ln_out.weight") {
            let weight = loader.load_array1("ln_out.weight")?;
            self.ln_out.set_gamma(weight);
        }
        if loader.has_tensor("ln_out.bias") {
            let bias = loader.load_array1("ln_out.bias")?;
            self.ln_out.set_beta(bias);
        }

        for (i, layer) in self.layers.iter_mut().enumerate() {
            let prefix = format!("layers.{}", i);

            if loader.has_tensor(&format!("{}.ln1.weight", prefix)) {
                let weight = loader.load_array1(&format!("{}.ln1.weight", prefix))?;
                layer.ln1.set_gamma(weight);
            }
            if loader.has_tensor(&format!("{}.ln1.bias", prefix)) {
                let bias = loader.load_array1(&format!("{}.ln1.bias", prefix))?;
                layer.ln1.set_beta(bias);
            }

            if loader.has_tensor(&format!("{}.ln2.weight", prefix)) {
                let weight = loader.load_array1(&format!("{}.ln2.weight", prefix))?;
                layer.ln2.set_gamma(weight);
            }
            if loader.has_tensor(&format!("{}.ln2.bias", prefix)) {
                let bias = loader.load_array1(&format!("{}.ln2.bias", prefix))?;
                layer.ln2.set_beta(bias);
            }

            let attn_prefix = format!("{}.attention", prefix);
            if loader.has_tensor(&format!("{}.q_proj", attn_prefix)) {
                layer.attention.q_proj = loader.load_array2(&format!("{}.q_proj", attn_prefix))?;
            }
            if loader.has_tensor(&format!("{}.k_proj", attn_prefix)) {
                layer.attention.k_proj = loader.load_array2(&format!("{}.k_proj", attn_prefix))?;
            }
            if loader.has_tensor(&format!("{}.v_proj", attn_prefix)) {
                layer.attention.v_proj = loader.load_array2(&format!("{}.v_proj", attn_prefix))?;
            }
            if loader.has_tensor(&format!("{}.o_proj", attn_prefix)) {
                layer.attention.o_proj = loader.load_array2(&format!("{}.o_proj", attn_prefix))?;
            }

            let ff_prefix = format!("{}.feed_forward", prefix);
            if loader.has_tensor(&format!("{}.fc1", ff_prefix)) {
                layer.feed_forward.fc1 = loader.load_array2(&format!("{}.fc1", ff_prefix))?;
            }
            if loader.has_tensor(&format!("{}.fc2", ff_prefix)) {
                layer.feed_forward.fc2 = loader.load_array2(&format!("{}.fc2", ff_prefix))?;
            }
        }

        Ok(())
    }

    #[allow(unused_variables)]
    pub fn save_weights(&self, path: &str) -> ModelResult<()> {
        Err(ModelError::simple_load_error(
            "Transformer save_weights not yet implemented".to_string(),
        ))
    }
}

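// One autoregressive step: project the input sample into the hidden space,
// run every layer (each appends to its KV cache), normalize, and project back
// to the input dimension.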
impl SignalPredictor for Transformer {
    #[instrument(skip(self, input))]
    fn step(&mut self, input: &Array1<f32>) -> CoreResult<Array1<f32>> {
        let mut hidden = input.dot(&self.input_proj);

        for layer in &mut self.layers {
            hidden = layer.forward(&hidden)?;
        }

        hidden = self.ln_out.forward(&hidden);

        let output = hidden.dot(&self.output_proj);
        Ok(output)
    }

    fn reset(&mut self) {
        for layer in &mut self.layers {
            layer.reset();
        }
    }

    fn context_window(&self) -> usize {
        self.config.max_seq_len
    }
}

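// State snapshots cover only each layer's cached keys: `get_states` packs the
// key cache into one matrix per layer, and `set_states` rebuilds the caches
// from those matrices.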
impl AutoregressiveModel for Transformer {
    fn hidden_dim(&self) -> usize {
        self.config.hidden_dim
    }

    fn state_dim(&self) -> usize {
        self.config.hidden_dim
    }

    fn num_layers(&self) -> usize {
        self.config.num_layers
    }

    fn model_type(&self) -> ModelType {
        ModelType::Transformer
    }

    fn get_states(&self) -> Vec<HiddenState> {
        self.layers
            .iter()
            .map(|layer| {
                let cache_len = layer.attention.key_cache.len();
                let mut combined = Array2::zeros((cache_len.max(1), self.config.hidden_dim));

                for (i, k) in layer.attention.key_cache.iter().enumerate() {
                    for j in 0..k.len().min(self.config.hidden_dim) {
                        combined[[i, j]] = k[j];
                    }
                }

                let mut hs = HiddenState::new(combined.shape()[0], combined.shape()[1]);
                hs.update(combined);
                hs
            })
            .collect()
    }

    fn set_states(&mut self, states: Vec<HiddenState>) -> ModelResult<()> {
        if states.len() != self.config.num_layers {
            return Err(ModelError::state_count_mismatch(
                "Transformer",
                self.config.num_layers,
                states.len(),
            ));
        }

        for (layer_idx, layer) in self.layers.iter_mut().enumerate() {
            let combined = states[layer_idx].state();

            layer.attention.key_cache.clear();
            layer.attention.value_cache.clear();
            for i in 0..combined.shape()[0] {
                let mut k = Array1::zeros(self.config.hidden_dim);
                for j in 0..self.config.hidden_dim.min(combined.shape()[1]) {
                    k[j] = combined[[i, j]];
                }
                layer.attention.key_cache.push_back(k);
                // `get_states` captures only keys, so values are restored as zeros to
                // keep both caches the same length; `forward` indexes them in lockstep.
                layer.attention.value_cache.push_back(Array1::zeros(self.config.hidden_dim));
            }
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_transformer_config() {
        let config = TransformerConfig::new()
            .hidden_dim(256)
            .num_heads(8)
            .num_layers(4);

        assert_eq!(config.hidden_dim, 256);
        assert_eq!(config.num_heads, 8);
        assert_eq!(config.head_dim, 32);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_transformer_creation() {
        let config = TransformerConfig::new().hidden_dim(128).num_heads(4);
        let model = Transformer::new(config);
        assert!(model.is_ok());
    }

    #[test]
    fn test_transformer_forward() {
        let config = TransformerConfig::new()
            .hidden_dim(64)
            .num_heads(4)
            .num_layers(2)
            .max_seq_len(128);
        let mut model = Transformer::new(config).expect("Failed to create Transformer");

        let input = Array1::from_vec(vec![0.5]);
        let output = model.step(&input);
        assert!(output.is_ok());
    }

    #[test]
    fn test_invalid_heads() {
        let config = TransformerConfig::new().hidden_dim(100).num_heads(3);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_context_window() {
        let config = TransformerConfig::new()
            .hidden_dim(64)
            .num_heads(4)
            .num_layers(2)
            .max_seq_len(512);
        let model = Transformer::new(config).expect("Failed to create Transformer");
        assert_eq!(model.context_window(), 512);
    }
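
    // Round-trips the per-layer key-cache snapshot through get_states/set_states.
    // A minimal sketch using only the API exercised above; dimensions are illustrative.
    #[test]
    fn test_state_roundtrip() {
        let config = TransformerConfig::new()
            .hidden_dim(32)
            .num_heads(4)
            .num_layers(2)
            .max_seq_len(16);
        let mut model = Transformer::new(config).expect("Failed to create Transformer");
        assert!(model.step(&Array1::from_vec(vec![0.1])).is_ok());

        let states = model.get_states();
        assert_eq!(states.len(), 2);
        assert!(model.set_states(states).is_ok());
    }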
}