1use forgellm_frontend::ir::*;
8use forgellm_frontend::weight_loader::ModelWeights;
9
10use crate::kernels;
11use crate::kv_cache::KVCache;
12
13pub fn forward(
17 token_id: u32,
18 pos: usize,
19 graph: &Graph,
20 weights: &ModelWeights,
21 cache: &mut KVCache,
22) -> Vec<f32> {
23 let config = graph.config.as_ref().expect("graph must have config");
24
25 let hidden = config.hidden_size;
26 let intermediate = config.intermediate_size;
27 let num_heads = config.num_attention_heads;
28 let num_kv_heads = config.num_kv_heads;
29 let head_dim = config.head_dim;
30 let vocab = config.vocab_size;
31
32 let embed_w = weights.tensor("model.embed_tokens.weight");
34 let mut hidden_state = vec![0.0f32; hidden];
35 let offset = token_id as usize * hidden;
36 hidden_state.copy_from_slice(&embed_w[offset..offset + hidden]);
37
38 let mut normed = vec![0.0f32; hidden];
40 let mut q = vec![0.0f32; num_heads * head_dim];
41 let mut k = vec![0.0f32; num_kv_heads * head_dim];
42 let mut v = vec![0.0f32; num_kv_heads * head_dim];
43 let mut attn_out = vec![0.0f32; num_heads * head_dim];
44 let mut attn_proj = vec![0.0f32; hidden];
45 let mut residual = vec![0.0f32; hidden];
46 let mut gate = vec![0.0f32; intermediate];
47 let mut gate_act = vec![0.0f32; intermediate];
48 let mut up = vec![0.0f32; intermediate];
49 let mut ffn_hidden = vec![0.0f32; intermediate];
50 let mut ffn_out = vec![0.0f32; hidden];
51
52 for layer_idx in 0..config.num_layers {
53 let prefix = format!("model.layers.{layer_idx}");
54
55 let norm_w = weights.tensor(&format!("{prefix}.input_layernorm.weight"));
57 rms_norm(&mut normed, &hidden_state, norm_w, config.rms_norm_eps);
58
59 let q_w = weights.tensor(&format!("{prefix}.self_attn.q_proj.weight"));
61 let k_w = weights.tensor(&format!("{prefix}.self_attn.k_proj.weight"));
62 let v_w = weights.tensor(&format!("{prefix}.self_attn.v_proj.weight"));
63 matmul(&mut q, &normed, q_w, 1, hidden, num_heads * head_dim);
64 matmul(&mut k, &normed, k_w, 1, hidden, num_kv_heads * head_dim);
65 matmul(&mut v, &normed, v_w, 1, hidden, num_kv_heads * head_dim);
66
67 if let Some(q_bias) = weights.get(&format!("{prefix}.self_attn.q_proj.bias")) {
69 elementwise_add_inplace(&mut q, q_bias);
70 }
71 if let Some(k_bias) = weights.get(&format!("{prefix}.self_attn.k_proj.bias")) {
72 elementwise_add_inplace(&mut k, k_bias);
73 }
74 if let Some(v_bias) = weights.get(&format!("{prefix}.self_attn.v_proj.bias")) {
75 elementwise_add_inplace(&mut v, v_bias);
76 }
77
78 rope(&mut q, pos, head_dim, num_heads, config.rope_theta);
80 rope(&mut k, pos, head_dim, num_kv_heads, config.rope_theta);
81
82 cache.append(layer_idx, &k, &v);
84
85 attention(
87 &mut attn_out,
88 &q,
89 cache.k(layer_idx),
90 cache.v(layer_idx),
91 &AttentionParams {
92 seq_len: pos + 1,
93 num_heads,
94 num_kv_heads,
95 head_dim,
96 },
97 );
98
99 let o_w = weights.tensor(&format!("{prefix}.self_attn.o_proj.weight"));
101 matmul(
102 &mut attn_proj,
103 &attn_out,
104 o_w,
105 1,
106 num_heads * head_dim,
107 hidden,
108 );
109
110 elementwise_add(&mut residual, &hidden_state, &attn_proj);
112
113 let ffn_norm_w = weights.tensor(&format!("{prefix}.post_attention_layernorm.weight"));
115 rms_norm(&mut normed, &residual, ffn_norm_w, config.rms_norm_eps);
116
117 let gate_w = weights.tensor(&format!("{prefix}.mlp.gate_proj.weight"));
119 let up_w = weights.tensor(&format!("{prefix}.mlp.up_proj.weight"));
120 let down_w = weights.tensor(&format!("{prefix}.mlp.down_proj.weight"));
121
122 matmul(&mut gate, &normed, gate_w, 1, hidden, intermediate);
123 silu(&mut gate_act, &gate);
124 matmul(&mut up, &normed, up_w, 1, hidden, intermediate);
125 elementwise_mul(&mut ffn_hidden, &gate_act, &up);
126 matmul(&mut ffn_out, &ffn_hidden, down_w, 1, intermediate, hidden);
127
128 elementwise_add(&mut hidden_state, &residual, &ffn_out);
130 }
131
132 let final_norm_w = weights.tensor("model.norm.weight");
134 rms_norm(
135 &mut normed,
136 &hidden_state,
137 final_norm_w,
138 config.rms_norm_eps,
139 );
140
141 let lm_head_w = weights
143 .get("lm_head.weight")
144 .unwrap_or_else(|| weights.tensor("model.embed_tokens.weight"));
145 let mut logits = vec![0.0f32; vocab];
146 matmul(&mut logits, &normed, lm_head_w, 1, hidden, vocab);
147
148 logits
149}
150
151fn rms_norm(output: &mut [f32], input: &[f32], weight: &[f32], eps: f32) {
154 kernels::rms_norm(output, input, weight, eps);
155}
156
157fn matmul(output: &mut [f32], input: &[f32], weight: &[f32], m: usize, k: usize, n: usize) {
158 kernels::matmul(output, input, weight, m, k, n);
159}
160
161fn silu(output: &mut [f32], input: &[f32]) {
162 kernels::silu(output, input);
163}
164
165fn elementwise_mul(output: &mut [f32], a: &[f32], b: &[f32]) {
166 kernels::elementwise_mul(output, a, b);
167}
168
169fn elementwise_add(output: &mut [f32], a: &[f32], b: &[f32]) {
170 kernels::elementwise_add(output, a, b);
171}
172
/// Adds `b` into `a` element-wise (`a[i] += b[i]`), e.g. to apply a projection bias.
///
/// # Panics
///
/// Panics if `a` and `b` differ in length. A mismatched bias is a shape bug, and it must not
/// be silently truncated.
fn elementwise_add_inplace(a: &mut [f32], b: &[f32]) {
    assert_eq!(
        a.len(),
        b.len(),
        "elementwise_add_inplace: length mismatch"
    );
    for (dst, &src) in a.iter_mut().zip(b) {
        *dst += src;
    }
}
178
/// Applies rotary position embeddings (RoPE) in place to the first `num_heads` heads of `data`.
///
/// `data` holds `num_heads` contiguous heads of `head_dim` values each. Each head is split
/// into a low half and a high half. Element `j` is rotated together with element
/// `j + head_dim / 2` by the angle `pos * theta^(-2j / head_dim)`.
///
/// This is the "rotate-half" convention of Hugging Face Llama-family checkpoints
/// (`rotate_half` in `transformers`), whose q/k projection rows are stored in that order.
/// Rotating adjacent pairs `(2j, 2j + 1)` only matches checkpoints in Meta's original layout.
///
/// NOTE(review): if the weight loader ever permutes q/k rows back to the Meta layout (as
/// llama.cpp's converter does), the adjacent-pair variant is needed instead.
///
/// # Panics
///
/// Panics if `head_dim` is odd, or if `data` holds fewer than `num_heads * head_dim` values.
fn rope(data: &mut [f32], pos: usize, head_dim: usize, num_heads: usize, theta: f32) {
    assert!(
        head_dim % 2 == 0,
        "RoPE needs an even head_dim, got {head_dim}"
    );
    let half = head_dim / 2;
    for j in 0..half {
        // The angle depends only on (pos, j), so compute it once and reuse it for every head.
        let inv_freq = 1.0 / theta.powf((2 * j) as f32 / head_dim as f32);
        let (sin_val, cos_val) = (pos as f32 * inv_freq).sin_cos();
        for head in data[..num_heads * head_dim].chunks_exact_mut(head_dim) {
            let (x0, x1) = (head[j], head[j + half]);
            head[j] = x0 * cos_val - x1 * sin_val;
            head[j + half] = x0 * sin_val + x1 * cos_val;
        }
    }
}
194
/// Shape parameters for one call of `attention`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct AttentionParams {
    /// Number of cached positions to attend over (`pos + 1` in `forward`).
    seq_len: usize,
    /// Number of query heads.
    num_heads: usize,
    /// Number of key/value heads. This may be smaller than `num_heads`; the tiny test model
    /// uses 2 query heads and 1 KV head.
    num_kv_heads: usize,
    /// Width of a single head.
    head_dim: usize,
}
201
202fn attention(
203 output: &mut [f32],
204 q: &[f32],
205 k_cache: &[f32],
206 v_cache: &[f32],
207 params: &AttentionParams,
208) {
209 kernels::attention(
210 output,
211 q,
212 k_cache,
213 v_cache,
214 params.seq_len,
215 params.num_heads,
216 params.num_kv_heads,
217 params.head_dim,
218 );
219}
220
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    // Dimensions of the one-layer tiny model used by the forward-pass tests. It uses
    // grouped-query attention: two query heads share a single key/value head.
    const HIDDEN: usize = 8;
    const INTER: usize = 16;
    const VOCAB: usize = 16;
    const NUM_HEADS: usize = 2;
    const NUM_KV_HEADS: usize = 1;
    const HEAD_DIM: usize = 4;
    const Q_DIM: usize = NUM_HEADS * HEAD_DIM;
    const KV_DIM: usize = NUM_KV_HEADS * HEAD_DIM;

    /// Config shared by every tiny-model test: one layer, no biases, no sliding window.
    fn tiny_config() -> ModelConfig {
        ModelConfig {
            architecture: Architecture::Llama,
            hidden_size: HIDDEN,
            intermediate_size: INTER,
            num_layers: 1,
            num_attention_heads: NUM_HEADS,
            num_kv_heads: NUM_KV_HEADS,
            head_dim: HEAD_DIM,
            vocab_size: VOCAB,
            max_seq_len: 32,
            rms_norm_eps: 1e-5,
            rope_theta: 10000.0,
            dtype: DType::F32,
            sliding_window_size: None,
            qkv_bias: false,
        }
    }

    /// Builds the tiny model's graph from `config`.
    fn tiny_graph(config: &ModelConfig) -> Graph {
        forgellm_frontend::graph_builder::build_graph(config).unwrap()
    }

    /// A fresh one-layer KV cache sized for the tiny model.
    fn tiny_cache() -> KVCache {
        KVCache::new(1, NUM_KV_HEADS, HEAD_DIM)
    }

    /// Wraps the tiny model's matrices into `ModelWeights`, adding all-ones RMSNorm weights.
    ///
    /// `attn` is `[q, k, v, o]` and `mlp` is `[gate, up, down]`, each in `[out, in]` layout.
    fn assemble_weights(
        embed: Vec<f32>,
        attn: [Vec<f32>; 4],
        mlp: [Vec<f32>; 3],
        lm_head: Vec<f32>,
    ) -> ModelWeights {
        let [q, k, v, o] = attn;
        let [gate, up, down] = mlp;
        let mut tensors = HashMap::new();
        for (name, data) in [
            ("model.embed_tokens.weight", embed),
            ("model.layers.0.input_layernorm.weight", vec![1.0f32; HIDDEN]),
            ("model.layers.0.self_attn.q_proj.weight", q),
            ("model.layers.0.self_attn.k_proj.weight", k),
            ("model.layers.0.self_attn.v_proj.weight", v),
            ("model.layers.0.self_attn.o_proj.weight", o),
            ("model.layers.0.post_attention_layernorm.weight", vec![1.0f32; HIDDEN]),
            ("model.layers.0.mlp.gate_proj.weight", gate),
            ("model.layers.0.mlp.up_proj.weight", up),
            ("model.layers.0.mlp.down_proj.weight", down),
            ("model.norm.weight", vec![1.0f32; HIDDEN]),
            ("lm_head.weight", lm_head),
        ] {
            tensors.insert(name.into(), data);
        }
        ModelWeights { tensors }
    }

    /// Tiny-model weights with constant values: embeddings 0.1, every projection 0.01.
    fn uniform_weights() -> ModelWeights {
        let proj = |len: usize| vec![0.01f32; len];
        assemble_weights(
            vec![0.1f32; VOCAB * HIDDEN],
            [
                proj(Q_DIM * HIDDEN),
                proj(KV_DIM * HIDDEN),
                proj(KV_DIM * HIDDEN),
                proj(HIDDEN * Q_DIM),
            ],
            [
                proj(INTER * HIDDEN),
                proj(INTER * HIDDEN),
                proj(HIDDEN * INTER),
            ],
            proj(VOCAB * HIDDEN),
        )
    }

    /// Tiny model whose tensors follow distinct deterministic patterns, so that different
    /// tokens and positions produce measurably different logits.
    fn tiny_model_with_varied_weights() -> (ModelConfig, Graph, ModelWeights) {
        let config = tiny_config();
        let graph = tiny_graph(&config);

        // ((i % modulus) + 1) * 0.01: small positive values that vary across rows and columns.
        let cyclic = |len: usize, modulus: usize| -> Vec<f32> {
            (0..len)
                .map(|i| ((i % modulus) as f32 + 1.0) * 0.01)
                .collect()
        };
        // Every embedding entry is distinct, so every token has a different embedding.
        let embed: Vec<f32> = (0..VOCAB * HIDDEN)
            .map(|i| (i as f32 + 1.0) * 0.05)
            .collect();
        // Mixed-sign LM head so the logits are not all the same sign.
        let lm_head: Vec<f32> = (0..VOCAB * HIDDEN)
            .map(|i| ((i % 17) as f32 - 8.0) * 0.02)
            .collect();

        let weights = assemble_weights(
            embed,
            [
                cyclic(Q_DIM * HIDDEN, 7),
                cyclic(KV_DIM * HIDDEN, 5),
                cyclic(KV_DIM * HIDDEN, 3),
                cyclic(HIDDEN * Q_DIM, 11),
            ],
            [
                cyclic(INTER * HIDDEN, 13),
                cyclic(INTER * HIDDEN, 9),
                cyclic(HIDDEN * INTER, 7),
            ],
            lm_head,
        );
        (config, graph, weights)
    }

    #[test]
    fn rms_norm_basic() {
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let weight = vec![1.0; 4];
        let mut output = vec![0.0; 4];
        rms_norm(&mut output, &input, &weight, 1e-5);

        // Sum of squares: 1 + 4 + 9 + 16 = 30.
        let rms = (30.0f32 / 4.0 + 1e-5).sqrt();
        let expected: Vec<f32> = input.iter().map(|x| x / rms).collect();
        for (a, b) in output.iter().zip(expected.iter()) {
            assert!((a - b).abs() < 1e-5, "got {a}, expected {b}");
        }
    }

    #[test]
    fn matmul_basic() {
        let input = vec![1.0, 2.0];
        // Two output rows of the [out, in] weight: [1, 2] and [3, 4].
        let weight = vec![1.0, 2.0, 3.0, 4.0];
        let mut output = vec![0.0; 2];
        matmul(&mut output, &input, &weight, 1, 2, 2);
        assert!((output[0] - 5.0).abs() < 1e-6);
        assert!((output[1] - 11.0).abs() < 1e-6);
    }

    #[test]
    fn silu_basic() {
        let input = vec![0.0, 1.0, -1.0];
        let mut output = vec![0.0; 3];
        silu(&mut output, &input);
        assert!((output[0] - 0.0).abs() < 1e-6);
        // silu(x) = x * sigmoid(x)
        assert!((output[1] - 0.7311).abs() < 1e-3);
        assert!((output[2] - (-0.2689)).abs() < 1e-3);
    }

    #[test]
    fn softmax_basic() {
        let mut values = vec![1.0, 2.0, 3.0];
        kernels::softmax(&mut values);
        let sum: f32 = values.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);
        assert!(values[2] > values[1]);
        assert!(values[1] > values[0]);
    }

    #[test]
    fn rope_preserves_magnitude() {
        let mut data = vec![1.0, 0.0, 0.0, 1.0];
        let mag_before: f32 = data.iter().map(|x| x * x).sum::<f32>().sqrt();
        rope(&mut data, 5, 4, 1, 10000.0);
        let mag_after: f32 = data.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(
            (mag_before - mag_after).abs() < 1e-5,
            "RoPE changed magnitude: {mag_before} → {mag_after}"
        );
    }

    #[test]
    fn rope_at_pos_zero_is_identity() {
        // At position 0 every angle is zero, so the rotation must leave the data unchanged.
        let original = vec![0.3f32, -1.2, 2.5, 0.7, -0.4, 1.1, 0.0, -2.0];
        let mut data = original.clone();
        rope(&mut data, 0, HEAD_DIM, 2, 10000.0);
        assert_eq!(data, original);
    }

    #[test]
    fn elementwise_add_inplace_adds_bias() {
        let mut values = vec![1.0f32, 2.0, 3.0];
        elementwise_add_inplace(&mut values, &[0.5, -2.0, 10.0]);
        assert_eq!(values, vec![1.5, 0.0, 13.0]);
    }

    #[test]
    fn forward_with_tiny_model() {
        let graph = tiny_graph(&tiny_config());
        let weights = uniform_weights();
        let mut kv_cache = tiny_cache();

        let logits = forward(0, 0, &graph, &weights, &mut kv_cache);

        assert_eq!(logits.len(), VOCAB);
        // `forward` only appends; the cache length changes when the caller advances it.
        assert_eq!(kv_cache.len(), 0);
        for &l in &logits {
            assert!(l.is_finite(), "logit is not finite: {l}");
        }
    }

    #[test]
    fn forward_multi_token() {
        let graph = tiny_graph(&tiny_config());
        let weights = uniform_weights();
        let mut cache = tiny_cache();

        for pos in 0..3 {
            let logits = forward(1, pos, &graph, &weights, &mut cache);
            assert_eq!(logits.len(), VOCAB);
            cache.advance();
        }

        assert_eq!(cache.len(), 3);
    }

    #[test]
    #[should_panic]
    fn forward_panics_on_out_of_vocab_token() {
        let graph = tiny_graph(&tiny_config());
        let weights = uniform_weights();
        let mut cache = tiny_cache();
        let _ = forward(VOCAB as u32, 0, &graph, &weights, &mut cache);
    }

    #[test]
    fn different_prompts_produce_different_logits() {
        let (_config, graph, weights) = tiny_model_with_varied_weights();

        let mut cache1 = tiny_cache();
        let logits1 = forward(0, 0, &graph, &weights, &mut cache1);

        let mut cache2 = tiny_cache();
        let logits2 = forward(5, 0, &graph, &weights, &mut cache2);

        for &l in &logits1 {
            assert!(l.is_finite(), "logits1 contains non-finite value: {l}");
        }
        for &l in &logits2 {
            assert!(l.is_finite(), "logits2 contains non-finite value: {l}");
        }

        let differs = logits1
            .iter()
            .zip(logits2.iter())
            .any(|(a, b)| (a - b).abs() > 1e-6);
        assert!(
            differs,
            "different input tokens should produce different logit distributions"
        );
    }

    #[test]
    fn cache_reset_produces_same_logits() {
        let (_config, graph, weights) = tiny_model_with_varied_weights();

        let mut cache = tiny_cache();
        let logits_fresh = forward(3, 0, &graph, &weights, &mut cache);

        // Run a second step so the cache holds history, then clear it.
        cache.advance();
        let _ = forward(7, 1, &graph, &weights, &mut cache);
        cache.advance();
        assert_eq!(cache.len(), 2);

        cache.clear();
        assert_eq!(cache.len(), 0);
        let logits_after_reset = forward(3, 0, &graph, &weights, &mut cache);

        for (i, (a, b)) in logits_fresh
            .iter()
            .zip(logits_after_reset.iter())
            .enumerate()
        {
            assert!(
                (a - b).abs() < 1e-6,
                "logit[{i}] differs after reset: fresh={a}, after_reset={b}"
            );
        }
    }

    #[test]
    fn forward_at_pos_zero_no_nan() {
        let (_config, graph, weights) = tiny_model_with_varied_weights();
        let mut cache = tiny_cache();

        let logits = forward(0, 0, &graph, &weights, &mut cache);
        assert_eq!(logits.len(), VOCAB);

        for (i, &l) in logits.iter().enumerate() {
            assert!(
                !l.is_nan(),
                "logit[{i}] is NaN at pos=0 — likely a softmax or attention bug"
            );
            assert!(!l.is_infinite(), "logit[{i}] is infinite at pos=0");
        }
    }
}