1use forgellm_frontend::ir::*;
8use forgellm_frontend::weight_loader::ModelWeights;
9
10use crate::kernels;
11use crate::kv_cache::KVCache;
12
13pub fn forward(
17 token_id: u32,
18 pos: usize,
19 graph: &Graph,
20 weights: &ModelWeights,
21 cache: &mut KVCache,
22) -> Vec<f32> {
23 let config = graph.config.as_ref().expect("graph must have config");
24
25 let hidden = config.hidden_size;
26 let intermediate = config.intermediate_size;
27 let num_heads = config.num_attention_heads;
28 let num_kv_heads = config.num_kv_heads;
29 let head_dim = config.head_dim;
30 let vocab = config.vocab_size;
31
32 let embed_w = weights.tensor("model.embed_tokens.weight");
34 let mut hidden_state = vec![0.0f32; hidden];
35 let offset = token_id as usize * hidden;
36 hidden_state.copy_from_slice(&embed_w[offset..offset + hidden]);
37
38 if matches!(
42 config.architecture,
43 forgellm_frontend::ir::Architecture::Gemma
44 ) {
45 let scale = (hidden as f32).sqrt();
46 for v in hidden_state.iter_mut() {
47 *v *= scale;
48 }
49 }
50
51 let mut normed = vec![0.0f32; hidden];
53 let mut q = vec![0.0f32; num_heads * head_dim];
54 let mut k = vec![0.0f32; num_kv_heads * head_dim];
55 let mut v = vec![0.0f32; num_kv_heads * head_dim];
56 let mut attn_out = vec![0.0f32; num_heads * head_dim];
57 let mut attn_proj = vec![0.0f32; hidden];
58 let mut residual = vec![0.0f32; hidden];
59 let mut gate = vec![0.0f32; intermediate];
60 let mut gate_act = vec![0.0f32; intermediate];
61 let mut up = vec![0.0f32; intermediate];
62 let mut ffn_hidden = vec![0.0f32; intermediate];
63 let mut ffn_out = vec![0.0f32; hidden];
64
65 for layer_idx in 0..config.num_layers {
66 let prefix = format!("model.layers.{layer_idx}");
67
68 let norm_w = weights.tensor(&format!("{prefix}.input_layernorm.weight"));
70 rms_norm(&mut normed, &hidden_state, norm_w, config.rms_norm_eps);
71
72 let q_w = weights.tensor(&format!("{prefix}.self_attn.q_proj.weight"));
74 let k_w = weights.tensor(&format!("{prefix}.self_attn.k_proj.weight"));
75 let v_w = weights.tensor(&format!("{prefix}.self_attn.v_proj.weight"));
76 matmul(&mut q, &normed, q_w, 1, hidden, num_heads * head_dim);
77 matmul(&mut k, &normed, k_w, 1, hidden, num_kv_heads * head_dim);
78 matmul(&mut v, &normed, v_w, 1, hidden, num_kv_heads * head_dim);
79
80 if let Some(q_bias) = weights.get(&format!("{prefix}.self_attn.q_proj.bias")) {
82 elementwise_add_inplace(&mut q, q_bias);
83 }
84 if let Some(k_bias) = weights.get(&format!("{prefix}.self_attn.k_proj.bias")) {
85 elementwise_add_inplace(&mut k, k_bias);
86 }
87 if let Some(v_bias) = weights.get(&format!("{prefix}.self_attn.v_proj.bias")) {
88 elementwise_add_inplace(&mut v, v_bias);
89 }
90
91 rope(&mut q, pos, head_dim, num_heads, config.rope_theta);
93 rope(&mut k, pos, head_dim, num_kv_heads, config.rope_theta);
94
95 cache.append(layer_idx, &k, &v);
97
98 attention(
100 &mut attn_out,
101 &q,
102 cache.k(layer_idx),
103 cache.v(layer_idx),
104 &AttentionParams {
105 seq_len: pos + 1,
106 num_heads,
107 num_kv_heads,
108 head_dim,
109 },
110 );
111
112 let o_w = weights.tensor(&format!("{prefix}.self_attn.o_proj.weight"));
114 matmul(
115 &mut attn_proj,
116 &attn_out,
117 o_w,
118 1,
119 num_heads * head_dim,
120 hidden,
121 );
122
123 elementwise_add(&mut residual, &hidden_state, &attn_proj);
125
126 let ffn_norm_w = weights.tensor(&format!("{prefix}.post_attention_layernorm.weight"));
128 rms_norm(&mut normed, &residual, ffn_norm_w, config.rms_norm_eps);
129
130 let gate_w = weights.tensor(&format!("{prefix}.mlp.gate_proj.weight"));
132 let up_w = weights.tensor(&format!("{prefix}.mlp.up_proj.weight"));
133 let down_w = weights.tensor(&format!("{prefix}.mlp.down_proj.weight"));
134
135 matmul(&mut gate, &normed, gate_w, 1, hidden, intermediate);
136 match config.hidden_activation {
137 forgellm_frontend::ir::HiddenActivation::SiLU => silu(&mut gate_act, &gate),
138 forgellm_frontend::ir::HiddenActivation::GeluApprox => {
139 kernels::gelu(&mut gate_act, &gate)
140 }
141 }
142 matmul(&mut up, &normed, up_w, 1, hidden, intermediate);
143 elementwise_mul(&mut ffn_hidden, &gate_act, &up);
144 matmul(&mut ffn_out, &ffn_hidden, down_w, 1, intermediate, hidden);
145
146 elementwise_add(&mut hidden_state, &residual, &ffn_out);
148 }
149
150 let final_norm_w = weights.tensor("model.norm.weight");
152 rms_norm(
153 &mut normed,
154 &hidden_state,
155 final_norm_w,
156 config.rms_norm_eps,
157 );
158
159 let lm_head_w = weights
161 .get("lm_head.weight")
162 .unwrap_or_else(|| weights.tensor("model.embed_tokens.weight"));
163 let mut logits = vec![0.0f32; vocab];
164 matmul(&mut logits, &normed, lm_head_w, 1, hidden, vocab);
165
166 logits
167}
168
169fn rms_norm(output: &mut [f32], input: &[f32], weight: &[f32], eps: f32) {
172 kernels::rms_norm(output, input, weight, eps);
173}
174
175fn matmul(output: &mut [f32], input: &[f32], weight: &[f32], m: usize, k: usize, n: usize) {
176 kernels::matmul(output, input, weight, m, k, n);
177}
178
179fn silu(output: &mut [f32], input: &[f32]) {
180 kernels::silu(output, input);
181}
182
183fn elementwise_mul(output: &mut [f32], a: &[f32], b: &[f32]) {
184 kernels::elementwise_mul(output, a, b);
185}
186
187fn elementwise_add(output: &mut [f32], a: &[f32], b: &[f32]) {
188 kernels::elementwise_add(output, a, b);
189}
190
/// Adds `b` into `a` element by element (`a[i] += b[i]`).
///
/// Used to apply optional projection biases.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths. Previously a longer `b` was
/// silently truncated, which hid mis-sized bias tensors.
fn elementwise_add_inplace(a: &mut [f32], b: &[f32]) {
    assert_eq!(
        a.len(),
        b.len(),
        "elementwise_add_inplace: length mismatch ({} vs {})",
        a.len(),
        b.len()
    );
    for (dst, &src) in a.iter_mut().zip(b) {
        *dst += src;
    }
}
196
/// Applies rotary position embeddings in place to `num_heads` consecutive heads.
///
/// Each head occupies `head_dim` contiguous values at the start of `data`. Within a
/// head, every adjacent pair `(x[i], x[i + 1])` with even `i` is rotated by the angle
/// `pos * theta^(-i / head_dim)`. The angles depend only on the pair index, so the
/// cos/sin table is computed once and shared by all heads.
///
/// NOTE(review): this uses the adjacent-pair (interleaved) convention. HF-format Llama
/// checkpoints use the rotate-half pairing `(i, i + head_dim / 2)` unless q/k weights
/// are permuted at load time — confirm against the weight loader.
///
/// # Panics
///
/// Panics if `head_dim` is odd, or if `data` holds fewer than `num_heads * head_dim`
/// values. With an odd `head_dim`, the last pair of a head would otherwise spill into
/// the next head.
fn rope(data: &mut [f32], pos: usize, head_dim: usize, num_heads: usize, theta: f32) {
    assert!(
        head_dim % 2 == 0,
        "rope: head_dim must be even, got {head_dim}"
    );
    let span = num_heads * head_dim;
    assert!(
        data.len() >= span,
        "rope: data has {} values, expected at least {span}",
        data.len()
    );
    if head_dim == 0 {
        return;
    }

    // (cos, sin) for each rotated pair; identical for every head.
    let rotations: Vec<(f32, f32)> = (0..head_dim)
        .step_by(2)
        .map(|i| {
            let freq = 1.0 / theta.powf(i as f32 / head_dim as f32);
            let angle = pos as f32 * freq;
            (angle.cos(), angle.sin())
        })
        .collect();

    for head in data[..span].chunks_exact_mut(head_dim) {
        for (pair, &(cos_val, sin_val)) in head.chunks_exact_mut(2).zip(&rotations) {
            let (x0, x1) = (pair[0], pair[1]);
            pair[0] = x0 * cos_val - x1 * sin_val;
            pair[1] = x0 * sin_val + x1 * cos_val;
        }
    }
}
212
/// Shape parameters for one `attention` call.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct AttentionParams {
    /// Number of cached positions to attend over (the decode step passes `pos + 1`).
    seq_len: usize,
    /// Number of query heads.
    num_heads: usize,
    /// Number of key/value heads stored in the cache.
    num_kv_heads: usize,
    /// Width of every head.
    head_dim: usize,
}
219
220fn attention(
221 output: &mut [f32],
222 q: &[f32],
223 k_cache: &[f32],
224 v_cache: &[f32],
225 params: &AttentionParams,
226) {
227 kernels::attention(
228 output,
229 q,
230 k_cache,
231 v_cache,
232 params.seq_len,
233 params.num_heads,
234 params.num_kv_heads,
235 params.head_dim,
236 );
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    // Dimensions shared by every tiny test model in this module.
    const HIDDEN: usize = 8;
    const INTERMEDIATE: usize = 16;
    const VOCAB: usize = 16;
    const NUM_HEADS: usize = 2;
    const NUM_KV_HEADS: usize = 1;
    const HEAD_DIM: usize = 4;

    /// One-layer Llama-style config used by every forward-pass test.
    fn tiny_config() -> ModelConfig {
        ModelConfig {
            architecture: Architecture::Llama,
            hidden_size: HIDDEN,
            intermediate_size: INTERMEDIATE,
            num_layers: 1,
            num_attention_heads: NUM_HEADS,
            num_kv_heads: NUM_KV_HEADS,
            head_dim: HEAD_DIM,
            vocab_size: VOCAB,
            max_seq_len: 32,
            rms_norm_eps: 1e-5,
            rope_theta: 10000.0,
            dtype: DType::F32,
            sliding_window_size: None,
            qkv_bias: false,
            hidden_activation: HiddenActivation::SiLU,
        }
    }

    /// Empty KV cache sized for `tiny_config`.
    fn tiny_cache() -> KVCache {
        KVCache::new(1, NUM_KV_HEADS, HEAD_DIM)
    }

    /// Packs `(name, data)` pairs into a `ModelWeights`.
    fn weights_from(specs: Vec<(&'static str, Vec<f32>)>) -> ModelWeights {
        let mut tensors = HashMap::new();
        for (name, data) in specs {
            tensors.insert(name.into(), data);
        }
        ModelWeights { tensors }
    }

    /// `len` values cycling through `((i % modulus) + offset) * scale`.
    fn cyclic(len: usize, modulus: usize, offset: f32, scale: f32) -> Vec<f32> {
        (0..len)
            .map(|i| ((i % modulus) as f32 + offset) * scale)
            .collect()
    }

    /// Tiny model in which every tensor is filled with a single constant.
    fn tiny_model_with_uniform_weights() -> (Graph, ModelWeights) {
        let graph = forgellm_frontend::graph_builder::build_graph(&tiny_config()).unwrap();
        let q_dim = NUM_HEADS * HEAD_DIM;
        let kv_dim = NUM_KV_HEADS * HEAD_DIM;

        let weights = weights_from(vec![
            ("model.embed_tokens.weight", vec![0.1; VOCAB * HIDDEN]),
            ("model.layers.0.input_layernorm.weight", vec![1.0; HIDDEN]),
            ("model.layers.0.self_attn.q_proj.weight", vec![0.01; q_dim * HIDDEN]),
            ("model.layers.0.self_attn.k_proj.weight", vec![0.01; kv_dim * HIDDEN]),
            ("model.layers.0.self_attn.v_proj.weight", vec![0.01; kv_dim * HIDDEN]),
            ("model.layers.0.self_attn.o_proj.weight", vec![0.01; HIDDEN * q_dim]),
            ("model.layers.0.post_attention_layernorm.weight", vec![1.0; HIDDEN]),
            ("model.layers.0.mlp.gate_proj.weight", vec![0.01; INTERMEDIATE * HIDDEN]),
            ("model.layers.0.mlp.up_proj.weight", vec![0.01; INTERMEDIATE * HIDDEN]),
            ("model.layers.0.mlp.down_proj.weight", vec![0.01; HIDDEN * INTERMEDIATE]),
            ("model.norm.weight", vec![1.0; HIDDEN]),
            ("lm_head.weight", vec![0.01; VOCAB * HIDDEN]),
        ]);
        (graph, weights)
    }

    /// Tiny model with deterministic, non-constant weights. Different tokens and
    /// positions therefore produce distinguishable activations.
    fn tiny_model_with_varied_weights() -> (ModelConfig, Graph, ModelWeights) {
        let config = tiny_config();
        let graph = forgellm_frontend::graph_builder::build_graph(&config).unwrap();
        let q_dim = NUM_HEADS * HEAD_DIM;
        let kv_dim = NUM_KV_HEADS * HEAD_DIM;

        // Every embedding entry is distinct: value = (flat_index + 1) * 0.05.
        let embed: Vec<f32> = (0..VOCAB * HIDDEN)
            .map(|i| (i as f32 + 1.0) * 0.05)
            .collect();

        let weights = weights_from(vec![
            ("model.embed_tokens.weight", embed),
            ("model.layers.0.input_layernorm.weight", vec![1.0; HIDDEN]),
            (
                "model.layers.0.self_attn.q_proj.weight",
                cyclic(q_dim * HIDDEN, 7, 1.0, 0.01),
            ),
            (
                "model.layers.0.self_attn.k_proj.weight",
                cyclic(kv_dim * HIDDEN, 5, 1.0, 0.01),
            ),
            (
                "model.layers.0.self_attn.v_proj.weight",
                cyclic(kv_dim * HIDDEN, 3, 1.0, 0.01),
            ),
            (
                "model.layers.0.self_attn.o_proj.weight",
                cyclic(HIDDEN * q_dim, 11, 1.0, 0.01),
            ),
            ("model.layers.0.post_attention_layernorm.weight", vec![1.0; HIDDEN]),
            (
                "model.layers.0.mlp.gate_proj.weight",
                cyclic(INTERMEDIATE * HIDDEN, 13, 1.0, 0.01),
            ),
            (
                "model.layers.0.mlp.up_proj.weight",
                cyclic(INTERMEDIATE * HIDDEN, 9, 1.0, 0.01),
            ),
            (
                "model.layers.0.mlp.down_proj.weight",
                cyclic(HIDDEN * INTERMEDIATE, 7, 1.0, 0.01),
            ),
            ("model.norm.weight", vec![1.0; HIDDEN]),
            // Centred around zero so logits take both signs.
            ("lm_head.weight", cyclic(VOCAB * HIDDEN, 17, -8.0, 0.02)),
        ]);
        (config, graph, weights)
    }

    #[test]
    fn rms_norm_basic() {
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let weight = vec![1.0; 4];
        let mut output = vec![0.0; 4];
        rms_norm(&mut output, &input, &weight, 1e-5);

        // mean(x^2) = (1 + 4 + 9 + 16) / 4
        let rms = (30.0f32 / 4.0 + 1e-5).sqrt();
        let expected: Vec<f32> = input.iter().map(|x| x / rms).collect();
        for (a, b) in output.iter().zip(expected.iter()) {
            assert!((a - b).abs() < 1e-5, "got {a}, expected {b}");
        }
    }

    #[test]
    fn matmul_basic() {
        let input = vec![1.0, 2.0];
        // Two output rows of length k = 2: [1, 2] and [3, 4].
        let weight = vec![1.0, 2.0, 3.0, 4.0];
        let mut output = vec![0.0; 2];
        matmul(&mut output, &input, &weight, 1, 2, 2);
        assert!((output[0] - 5.0).abs() < 1e-6);
        assert!((output[1] - 11.0).abs() < 1e-6);
    }

    #[test]
    fn silu_basic() {
        let input = vec![0.0, 1.0, -1.0];
        let mut output = vec![0.0; 3];
        silu(&mut output, &input);
        assert!((output[0] - 0.0).abs() < 1e-6);
        // silu(x) = x * sigmoid(x)
        assert!((output[1] - 0.7311).abs() < 1e-3);
        assert!((output[2] - (-0.2689)).abs() < 1e-3);
    }

    #[test]
    fn softmax_basic() {
        let mut values = vec![1.0, 2.0, 3.0];
        kernels::softmax(&mut values);
        let sum: f32 = values.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);
        assert!(values[2] > values[1]);
        assert!(values[1] > values[0]);
    }

    #[test]
    fn rope_preserves_magnitude() {
        let mut data = vec![1.0, 0.0, 0.0, 1.0];
        let mag_before: f32 = data.iter().map(|x| x * x).sum::<f32>().sqrt();
        rope(&mut data, 5, 4, 1, 10000.0);
        let mag_after: f32 = data.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(
            (mag_before - mag_after).abs() < 1e-5,
            "RoPE changed magnitude: {mag_before} → {mag_after}"
        );
    }

    #[test]
    fn forward_with_tiny_model() {
        let (graph, weights) = tiny_model_with_uniform_weights();
        let mut kv_cache = tiny_cache();

        let logits = forward(0, 0, &graph, &weights, &mut kv_cache);

        assert_eq!(logits.len(), VOCAB);
        // forward() appends K/V but leaves advancing the cache to the caller.
        assert_eq!(kv_cache.len(), 0);
        for &l in &logits {
            assert!(l.is_finite(), "logit is not finite: {l}");
        }
    }

    #[test]
    fn forward_multi_token() {
        let (graph, weights) = tiny_model_with_uniform_weights();
        let mut cache = tiny_cache();

        for pos in 0..3 {
            let logits = forward(1, pos, &graph, &weights, &mut cache);
            assert_eq!(logits.len(), VOCAB);
            cache.advance();
        }

        assert_eq!(cache.len(), 3);
    }

    #[test]
    #[should_panic]
    fn forward_rejects_out_of_range_token() {
        let (graph, weights) = tiny_model_with_uniform_weights();
        let mut cache = tiny_cache();
        let _ = forward(VOCAB as u32, 0, &graph, &weights, &mut cache);
    }

    #[test]
    fn different_prompts_produce_different_logits() {
        let (_config, graph, weights) = tiny_model_with_varied_weights();

        let mut cache1 = tiny_cache();
        let logits1 = forward(0, 0, &graph, &weights, &mut cache1);

        let mut cache2 = tiny_cache();
        let logits2 = forward(5, 0, &graph, &weights, &mut cache2);

        for &l in &logits1 {
            assert!(l.is_finite(), "logits1 contains non-finite value: {l}");
        }
        for &l in &logits2 {
            assert!(l.is_finite(), "logits2 contains non-finite value: {l}");
        }

        let differs = logits1
            .iter()
            .zip(logits2.iter())
            .any(|(a, b)| (a - b).abs() > 1e-6);
        assert!(
            differs,
            "different input tokens should produce different logit distributions"
        );
    }

    #[test]
    fn cache_reset_produces_same_logits() {
        let (_config, graph, weights) = tiny_model_with_varied_weights();
        let mut cache = tiny_cache();

        let logits_fresh = forward(3, 0, &graph, &weights, &mut cache);
        cache.advance();

        let _ = forward(7, 1, &graph, &weights, &mut cache);
        cache.advance();
        assert_eq!(cache.len(), 2);

        // After clearing, the same token at pos 0 must reproduce the first result.
        cache.clear();
        assert_eq!(cache.len(), 0);
        let logits_after_reset = forward(3, 0, &graph, &weights, &mut cache);

        for (i, (a, b)) in logits_fresh
            .iter()
            .zip(logits_after_reset.iter())
            .enumerate()
        {
            assert!(
                (a - b).abs() < 1e-6,
                "logit[{i}] differs after reset: fresh={a}, after_reset={b}"
            );
        }
    }

    #[test]
    fn forward_at_pos_zero_no_nan() {
        let (_config, graph, weights) = tiny_model_with_varied_weights();
        let mut cache = tiny_cache();

        let logits = forward(0, 0, &graph, &weights, &mut cache);
        assert_eq!(logits.len(), VOCAB);

        for (i, &l) in logits.iter().enumerate() {
            assert!(
                !l.is_nan(),
                "logit[{i}] is NaN at pos=0 — likely a softmax or attention bug"
            );
            assert!(!l.is_infinite(), "logit[{i}] is infinite at pos=0");
        }
    }
}