Skip to main content

forgellm_runtime/
interpreter.rs

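//! Interpreter — runs the transformer forward pass directly on CPU.
//!
//! The decoder-layer structure is hard-coded here; only its hyperparameters
//! are read from the IR graph's `config`. Compute-heavy operations (matmul,
//! rms_norm, attention, ...) delegate to the optimized kernels in the
//! `kernels` module. This path serves as the correctness reference and as
//! the primary inference path until the AOT codegen is ready.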
6
7use forgellm_frontend::ir::*;
8use forgellm_frontend::weight_loader::ModelWeights;
9
10use crate::kernels;
11use crate::kv_cache::KVCache;
12
13/// Run a single forward pass for one token through the model.
14///
15/// Returns logits of shape `[vocab_size]`.
16pub fn forward(
17    token_id: u32,
18    pos: usize,
19    graph: &Graph,
20    weights: &ModelWeights,
21    cache: &mut KVCache,
22) -> Vec<f32> {
23    let config = graph.config.as_ref().expect("graph must have config");
24
25    let hidden = config.hidden_size;
26    let intermediate = config.intermediate_size;
27    let num_heads = config.num_attention_heads;
28    let num_kv_heads = config.num_kv_heads;
29    let head_dim = config.head_dim;
30    let vocab = config.vocab_size;
31
32    // Embedding lookup
33    let embed_w = weights.tensor("model.embed_tokens.weight");
34    let mut hidden_state = vec![0.0f32; hidden];
35    let offset = token_id as usize * hidden;
36    hidden_state.copy_from_slice(&embed_w[offset..offset + hidden]);
37
38    // Pre-allocate buffers
39    let mut normed = vec![0.0f32; hidden];
40    let mut q = vec![0.0f32; num_heads * head_dim];
41    let mut k = vec![0.0f32; num_kv_heads * head_dim];
42    let mut v = vec![0.0f32; num_kv_heads * head_dim];
43    let mut attn_out = vec![0.0f32; num_heads * head_dim];
44    let mut attn_proj = vec![0.0f32; hidden];
45    let mut residual = vec![0.0f32; hidden];
46    let mut gate = vec![0.0f32; intermediate];
47    let mut gate_act = vec![0.0f32; intermediate];
48    let mut up = vec![0.0f32; intermediate];
49    let mut ffn_hidden = vec![0.0f32; intermediate];
50    let mut ffn_out = vec![0.0f32; hidden];
51
52    for layer_idx in 0..config.num_layers {
53        let prefix = format!("model.layers.{layer_idx}");
54
55        // Attention norm
56        let norm_w = weights.tensor(&format!("{prefix}.input_layernorm.weight"));
57        rms_norm(&mut normed, &hidden_state, norm_w, config.rms_norm_eps);
58
59        // QKV projections
60        let q_w = weights.tensor(&format!("{prefix}.self_attn.q_proj.weight"));
61        let k_w = weights.tensor(&format!("{prefix}.self_attn.k_proj.weight"));
62        let v_w = weights.tensor(&format!("{prefix}.self_attn.v_proj.weight"));
63        matmul(&mut q, &normed, q_w, 1, hidden, num_heads * head_dim);
64        matmul(&mut k, &normed, k_w, 1, hidden, num_kv_heads * head_dim);
65        matmul(&mut v, &normed, v_w, 1, hidden, num_kv_heads * head_dim);
66
67        // Add QKV biases if present (Qwen2 uses biases on QKV)
68        if let Some(q_bias) = weights.get(&format!("{prefix}.self_attn.q_proj.bias")) {
69            elementwise_add_inplace(&mut q, q_bias);
70        }
71        if let Some(k_bias) = weights.get(&format!("{prefix}.self_attn.k_proj.bias")) {
72            elementwise_add_inplace(&mut k, k_bias);
73        }
74        if let Some(v_bias) = weights.get(&format!("{prefix}.self_attn.v_proj.bias")) {
75            elementwise_add_inplace(&mut v, v_bias);
76        }
77
78        // RoPE
79        rope(&mut q, pos, head_dim, num_heads, config.rope_theta);
80        rope(&mut k, pos, head_dim, num_kv_heads, config.rope_theta);
81
82        // Update KV cache
83        cache.append(layer_idx, &k, &v);
84
85        // Attention
86        attention(
87            &mut attn_out,
88            &q,
89            cache.k(layer_idx),
90            cache.v(layer_idx),
91            &AttentionParams {
92                seq_len: pos + 1,
93                num_heads,
94                num_kv_heads,
95                head_dim,
96            },
97        );
98
99        // Output projection
100        let o_w = weights.tensor(&format!("{prefix}.self_attn.o_proj.weight"));
101        matmul(
102            &mut attn_proj,
103            &attn_out,
104            o_w,
105            1,
106            num_heads * head_dim,
107            hidden,
108        );
109
110        // Residual
111        elementwise_add(&mut residual, &hidden_state, &attn_proj);
112
113        // FFN norm
114        let ffn_norm_w = weights.tensor(&format!("{prefix}.post_attention_layernorm.weight"));
115        rms_norm(&mut normed, &residual, ffn_norm_w, config.rms_norm_eps);
116
117        // FFN
118        let gate_w = weights.tensor(&format!("{prefix}.mlp.gate_proj.weight"));
119        let up_w = weights.tensor(&format!("{prefix}.mlp.up_proj.weight"));
120        let down_w = weights.tensor(&format!("{prefix}.mlp.down_proj.weight"));
121
122        matmul(&mut gate, &normed, gate_w, 1, hidden, intermediate);
123        silu(&mut gate_act, &gate);
124        matmul(&mut up, &normed, up_w, 1, hidden, intermediate);
125        elementwise_mul(&mut ffn_hidden, &gate_act, &up);
126        matmul(&mut ffn_out, &ffn_hidden, down_w, 1, intermediate, hidden);
127
128        // Residual
129        elementwise_add(&mut hidden_state, &residual, &ffn_out);
130    }
131
132    // Final norm
133    let final_norm_w = weights.tensor("model.norm.weight");
134    rms_norm(
135        &mut normed,
136        &hidden_state,
137        final_norm_w,
138        config.rms_norm_eps,
139    );
140
141    // Logits projection (may use tied embeddings)
142    let lm_head_w = weights
143        .get("lm_head.weight")
144        .unwrap_or_else(|| weights.tensor("model.embed_tokens.weight"));
145    let mut logits = vec![0.0f32; vocab];
146    matmul(&mut logits, &normed, lm_head_w, 1, hidden, vocab);
147
148    logits
149}
150
151// --- Kernel wrappers (delegate to optimized kernels module) ---
152
153fn rms_norm(output: &mut [f32], input: &[f32], weight: &[f32], eps: f32) {
154    kernels::rms_norm(output, input, weight, eps);
155}
156
157fn matmul(output: &mut [f32], input: &[f32], weight: &[f32], m: usize, k: usize, n: usize) {
158    kernels::matmul(output, input, weight, m, k, n);
159}
160
161fn silu(output: &mut [f32], input: &[f32]) {
162    kernels::silu(output, input);
163}
164
165fn elementwise_mul(output: &mut [f32], a: &[f32], b: &[f32]) {
166    kernels::elementwise_mul(output, a, b);
167}
168
169fn elementwise_add(output: &mut [f32], a: &[f32], b: &[f32]) {
170    kernels::elementwise_add(output, a, b);
171}
172
/// Adds `b` into `a` element-wise (`a[i] += b[i]`), e.g. to apply a
/// projection bias.
///
/// # Panics
///
/// Panics if `a` and `b` differ in length. A mis-shaped bias tensor is
/// reported rather than partially applied.
fn elementwise_add_inplace(a: &mut [f32], b: &[f32]) {
    assert_eq!(
        a.len(),
        b.len(),
        "elementwise_add_inplace: length mismatch"
    );
    for (dst, &src) in a.iter_mut().zip(b) {
        *dst += src;
    }
}
178
/// Applies rotary position embeddings (RoPE) in place to the first
/// `num_heads` heads of `data`, each `head_dim` values long.
///
/// Dimensions are rotated in adjacent pairs `(2j, 2j + 1)` by the angle
/// `pos * theta^(-2j / head_dim)`. The cos/sin table depends only on `pos`
/// and the pair index, not on the head. It is therefore computed once and
/// shared by every head. Values past `num_heads * head_dim` are left
/// untouched.
///
/// NOTE(review): adjacent-pair ("interleaved") rotation only matches q/k
/// projections laid out for it. Hugging Face Llama/Qwen checkpoints use the
/// `rotate_half` convention (pairing `j` with `j + head_dim / 2`) unless the
/// loader permutes the weights — confirm in `weight_loader`.
///
/// # Panics
///
/// Panics if `head_dim` is odd or `data` holds fewer than
/// `num_heads * head_dim` values.
fn rope(data: &mut [f32], pos: usize, head_dim: usize, num_heads: usize, theta: f32) {
    // An odd head_dim would make the last pair of each head spill into the next head.
    assert!(
        head_dim % 2 == 0,
        "rope: head_dim must be even, got {head_dim}"
    );
    let rotated_len = num_heads * head_dim;
    assert!(
        data.len() >= rotated_len,
        "rope: expected at least {rotated_len} values ({num_heads} heads x {head_dim}), got {}",
        data.len()
    );
    if rotated_len == 0 {
        // Also guards `chunks_exact_mut(0)`, which would panic.
        return;
    }

    // (cos, sin) per pair, computed with exactly the same float ops as a
    // per-head loop would use, so results are bit-identical.
    let rotations: Vec<(f32, f32)> = (0..head_dim)
        .step_by(2)
        .map(|i| {
            let freq = 1.0 / theta.powf(i as f32 / head_dim as f32);
            let angle = pos as f32 * freq;
            (angle.cos(), angle.sin())
        })
        .collect();

    for head in data[..rotated_len].chunks_exact_mut(head_dim) {
        for (pair, &(cos_val, sin_val)) in head.chunks_exact_mut(2).zip(&rotations) {
            let (x0, x1) = (pair[0], pair[1]);
            pair[0] = x0 * cos_val - x1 * sin_val;
            pair[1] = x0 * sin_val + x1 * cos_val;
        }
    }
}
194
/// Shape parameters for one single-query attention call over the KV cache.
#[derive(Debug, Clone, Copy)]
struct AttentionParams {
    /// Number of cached positions to attend over, including the current token.
    seq_len: usize,
    /// Number of query heads.
    num_heads: usize,
    /// Number of key/value heads (may be fewer than `num_heads` under
    /// grouped-query attention).
    num_kv_heads: usize,
    /// Width of each head.
    head_dim: usize,
}
201
202fn attention(
203    output: &mut [f32],
204    q: &[f32],
205    k_cache: &[f32],
206    v_cache: &[f32],
207    params: &AttentionParams,
208) {
209    kernels::attention(
210        output,
211        q,
212        k_cache,
213        v_cache,
214        params.seq_len,
215        params.num_heads,
216        params.num_kv_heads,
217        params.head_dim,
218    );
219}
220
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    // Dimensions of the one-layer model shared by the forward-pass tests.
    const HIDDEN: usize = 8;
    const INTERMEDIATE: usize = 16;
    const NUM_HEADS: usize = 2;
    const NUM_KV_HEADS: usize = 1;
    const HEAD_DIM: usize = 4;
    const VOCAB: usize = 16;

    #[test]
    fn rms_norm_basic() {
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let weight = vec![1.0; 4];
        let mut output = vec![0.0; 4];
        rms_norm(&mut output, &input, &weight, 1e-5);

        // RMS = sqrt((1+4+9+16)/4) = sqrt(7.5) ≈ 2.7386
        let rms = (30.0f32 / 4.0 + 1e-5).sqrt();
        let expected: Vec<f32> = input.iter().map(|x| x / rms).collect();
        for (a, b) in output.iter().zip(expected.iter()) {
            assert!((a - b).abs() < 1e-5, "got {a}, expected {b}");
        }
    }

    #[test]
    fn matmul_basic() {
        // weight is stored as [n, k] = [[1, 2], [3, 4]] and
        // output[j] = Σ_i input[i] * weight[j][i]
        //           = [1*1 + 2*2, 1*3 + 2*4] = [5, 11]
        let input = vec![1.0, 2.0];
        let weight = vec![1.0, 2.0, 3.0, 4.0]; // row 0: [1,2], row 1: [3,4]
        let mut output = vec![0.0; 2];
        matmul(&mut output, &input, &weight, 1, 2, 2);
        assert!((output[0] - 5.0).abs() < 1e-6);
        assert!((output[1] - 11.0).abs() < 1e-6);
    }

    #[test]
    fn silu_basic() {
        let input = vec![0.0, 1.0, -1.0];
        let mut output = vec![0.0; 3];
        silu(&mut output, &input);
        // silu(0) = 0, silu(1) = 1/(1+e^-1) ≈ 0.7311
        assert!((output[0] - 0.0).abs() < 1e-6);
        assert!((output[1] - 0.7311).abs() < 1e-3);
        assert!((output[2] - (-0.2689)).abs() < 1e-3);
    }

    #[test]
    fn softmax_basic() {
        let mut values = vec![1.0, 2.0, 3.0];
        kernels::softmax(&mut values);
        let sum: f32 = values.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);
        assert!(values[2] > values[1]);
        assert!(values[1] > values[0]);
    }

    #[test]
    fn rope_preserves_magnitude() {
        // RoPE is a rotation, so it should preserve vector magnitude
        let mut data = vec![1.0, 0.0, 0.0, 1.0]; // 1 head, dim=4
        let mag_before: f32 = data.iter().map(|x| x * x).sum::<f32>().sqrt();
        rope(&mut data, 5, 4, 1, 10000.0);
        let mag_after: f32 = data.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(
            (mag_before - mag_after).abs() < 1e-5,
            "RoPE changed magnitude: {mag_before} → {mag_after}"
        );
    }

    #[test]
    fn rope_at_pos_zero_is_identity() {
        // At pos=0 every angle is 0, so cos=1 and sin=0 leave the data unchanged.
        let original: Vec<f32> = (0..8).map(|i| i as f32 - 3.5).collect();
        let mut data = original.clone();
        rope(&mut data, 0, 4, 2, 10000.0);
        assert_eq!(data, original);
    }

    // ── Tiny-model fixtures ──────────────────────────────────────────────

    /// Config for a one-layer Llama model with the dimensions above.
    fn tiny_config() -> ModelConfig {
        ModelConfig {
            architecture: Architecture::Llama,
            hidden_size: HIDDEN,
            intermediate_size: INTERMEDIATE,
            num_layers: 1,
            num_attention_heads: NUM_HEADS,
            num_kv_heads: NUM_KV_HEADS,
            head_dim: HEAD_DIM,
            vocab_size: VOCAB,
            max_seq_len: 32,
            rms_norm_eps: 1e-5,
            rope_theta: 10000.0,
            dtype: DType::F32,
            sliding_window_size: None,
            qkv_bias: false,
        }
    }

    /// Builds the IR graph for `config`.
    fn tiny_graph(config: &ModelConfig) -> Graph {
        forgellm_frontend::graph_builder::build_graph(config).unwrap()
    }

    /// Empty KV cache sized for `tiny_config()`.
    fn tiny_cache() -> KVCache {
        KVCache::new(1, NUM_KV_HEADS, HEAD_DIM)
    }

    /// Assembles every tensor `forward` reads for `tiny_config()`.
    /// `init(name, len)` supplies the contents of each tensor.
    fn tiny_weights(init: impl Fn(&str, usize) -> Vec<f32>) -> ModelWeights {
        let q_dim = NUM_HEADS * HEAD_DIM;
        let kv_dim = NUM_KV_HEADS * HEAD_DIM;
        let shapes = [
            ("model.embed_tokens.weight", VOCAB * HIDDEN),
            ("model.layers.0.input_layernorm.weight", HIDDEN),
            ("model.layers.0.self_attn.q_proj.weight", q_dim * HIDDEN),
            ("model.layers.0.self_attn.k_proj.weight", kv_dim * HIDDEN),
            ("model.layers.0.self_attn.v_proj.weight", kv_dim * HIDDEN),
            ("model.layers.0.self_attn.o_proj.weight", HIDDEN * q_dim),
            ("model.layers.0.post_attention_layernorm.weight", HIDDEN),
            ("model.layers.0.mlp.gate_proj.weight", INTERMEDIATE * HIDDEN),
            ("model.layers.0.mlp.up_proj.weight", INTERMEDIATE * HIDDEN),
            ("model.layers.0.mlp.down_proj.weight", HIDDEN * INTERMEDIATE),
            ("model.norm.weight", HIDDEN),
            ("lm_head.weight", VOCAB * HIDDEN),
        ];

        let mut tensors = HashMap::new();
        for &(name, len) in shapes.iter() {
            tensors.insert(name.into(), init(name, len));
        }
        ModelWeights { tensors }
    }

    /// Constant weights: 0.1 embeddings, unit norm weights, 0.01 everywhere else.
    fn uniform_tiny_weights() -> ModelWeights {
        tiny_weights(|name, len| {
            let value = if name == "model.embed_tokens.weight" {
                0.1
            } else if name.ends_with("norm.weight") {
                1.0
            } else {
                0.01
            };
            vec![value; len]
        })
    }

    /// Build a tiny model with distinguishable per-token embeddings so that
    /// different token IDs produce different logit distributions.
    fn tiny_model_with_varied_weights() -> (ModelConfig, Graph, ModelWeights) {
        let config = tiny_config();
        let graph = tiny_graph(&config);

        let weights = tiny_weights(|name, len| {
            // Small positive values cycling with a per-tensor period, so the
            // model isn't degenerate.
            let cyclic = |period: usize| -> Vec<f32> {
                (0..len)
                    .map(|i| ((i % period) as f32 + 1.0) * 0.01)
                    .collect()
            };
            match name {
                // Every token gets a distinct embedding vector.
                "model.embed_tokens.weight" => {
                    (0..len).map(|i| (i as f32 + 1.0) * 0.05).collect()
                }
                "model.layers.0.self_attn.q_proj.weight" => cyclic(7),
                "model.layers.0.self_attn.k_proj.weight" => cyclic(5),
                "model.layers.0.self_attn.v_proj.weight" => cyclic(3),
                "model.layers.0.self_attn.o_proj.weight" => cyclic(11),
                "model.layers.0.mlp.gate_proj.weight" => cyclic(13),
                "model.layers.0.mlp.up_proj.weight" => cyclic(9),
                "model.layers.0.mlp.down_proj.weight" => cyclic(7),
                // Signed values so different hidden states map to different logits.
                "lm_head.weight" => (0..len)
                    .map(|i| ((i % 17) as f32 - 8.0) * 0.02)
                    .collect(),
                "model.layers.0.input_layernorm.weight"
                | "model.layers.0.post_attention_layernorm.weight"
                | "model.norm.weight" => vec![1.0; len],
                other => panic!("no initializer for tensor {other}"),
            }
        });

        (config, graph, weights)
    }

    // ── Forward-pass tests ───────────────────────────────────────────────

    #[test]
    fn forward_with_tiny_model() {
        // Build a minimal model to verify the interpreter runs without panicking
        let config = tiny_config();
        let graph = tiny_graph(&config);
        let weights = uniform_tiny_weights();
        let mut kv_cache = tiny_cache();

        let logits = forward(0, 0, &graph, &weights, &mut kv_cache);
        assert_eq!(logits.len(), VOCAB);
        assert_eq!(kv_cache.len(), 0); // advance not called by forward

        for &l in &logits {
            assert!(l.is_finite(), "logit is not finite: {l}");
        }
    }

    #[test]
    fn forward_multi_token() {
        let config = tiny_config();
        let graph = tiny_graph(&config);
        let weights = uniform_tiny_weights();
        let mut cache = tiny_cache();

        // Generate 3 tokens
        for pos in 0..3 {
            let logits = forward(1, pos, &graph, &weights, &mut cache);
            assert_eq!(logits.len(), VOCAB);
            cache.advance();
        }

        assert_eq!(cache.len(), 3);
    }

    // ── Real-world validation tests ──────────────────────────────────────

    #[test]
    fn different_prompts_produce_different_logits() {
        // Two different token IDs at pos=0 should produce different logit vectors.
        // This validates that the model distinguishes inputs (not degenerate).
        let (_config, graph, weights) = tiny_model_with_varied_weights();

        let mut cache1 = tiny_cache();
        let logits1 = forward(0, 0, &graph, &weights, &mut cache1);

        let mut cache2 = tiny_cache();
        let logits2 = forward(5, 0, &graph, &weights, &mut cache2);

        for &l in &logits1 {
            assert!(l.is_finite(), "logits1 contains non-finite value: {l}");
        }
        for &l in &logits2 {
            assert!(l.is_finite(), "logits2 contains non-finite value: {l}");
        }

        let differs = logits1
            .iter()
            .zip(logits2.iter())
            .any(|(a, b)| (a - b).abs() > 1e-6);
        assert!(
            differs,
            "different input tokens should produce different logit distributions"
        );
    }

    #[test]
    fn cache_reset_produces_same_logits() {
        // After clearing the cache, running the same token at pos=0 should
        // produce identical logits as a fresh run. This validates that clear()
        // truly resets all state for independent multi-request serving.
        let (_config, graph, weights) = tiny_model_with_varied_weights();

        // First run: fresh cache
        let mut cache = tiny_cache();
        let logits_fresh = forward(3, 0, &graph, &weights, &mut cache);

        // Advance the cache with more tokens to build up state
        cache.advance();
        let _ = forward(7, 1, &graph, &weights, &mut cache);
        cache.advance();
        assert_eq!(cache.len(), 2);

        // Clear and re-run
        cache.clear();
        assert_eq!(cache.len(), 0);
        let logits_after_reset = forward(3, 0, &graph, &weights, &mut cache);

        for (i, (a, b)) in logits_fresh
            .iter()
            .zip(logits_after_reset.iter())
            .enumerate()
        {
            assert!(
                (a - b).abs() < 1e-6,
                "logit[{i}] differs after reset: fresh={a}, after_reset={b}"
            );
        }
    }

    #[test]
    fn forward_at_pos_zero_no_nan() {
        // pos=0 is the first token where seq_len=1 in attention.
        // This is a common edge case: softmax over a single element,
        // RoPE with angle=0, and KV cache with one entry.
        let (_config, graph, weights) = tiny_model_with_varied_weights();
        let mut cache = tiny_cache();

        let logits = forward(0, 0, &graph, &weights, &mut cache);
        assert_eq!(logits.len(), VOCAB);

        for (i, &l) in logits.iter().enumerate() {
            assert!(
                !l.is_nan(),
                "logit[{i}] is NaN at pos=0 — likely a softmax or attention bug"
            );
            assert!(!l.is_infinite(), "logit[{i}] is infinite at pos=0");
        }
    }
}