Skip to main content

forgellm_runtime/
interpreter.rs

1//! Interpreter — executes IR graphs directly on CPU.
2//!
3//! Uses optimized kernels from the `kernels` module for compute-heavy
4//! operations (matmul, rms_norm). Validates correctness and serves as
5//! the primary inference path until the AOT codegen is ready.
6
7use forgellm_frontend::ir::*;
8use forgellm_frontend::weight_loader::ModelWeights;
9
10use crate::kernels;
11use crate::kv_cache::KVCache;
12
13/// Run a single forward pass for one token through the model.
14///
15/// Returns logits of shape `[vocab_size]`.
16pub fn forward(
17    token_id: u32,
18    pos: usize,
19    graph: &Graph,
20    weights: &ModelWeights,
21    cache: &mut KVCache,
22) -> Vec<f32> {
23    let config = graph.config.as_ref().expect("graph must have config");
24
25    let hidden = config.hidden_size;
26    let intermediate = config.intermediate_size;
27    let num_heads = config.num_attention_heads;
28    let num_kv_heads = config.num_kv_heads;
29    let head_dim = config.head_dim;
30    let vocab = config.vocab_size;
31
32    // Embedding lookup
33    let embed_w = weights.tensor("model.embed_tokens.weight");
34    let mut hidden_state = vec![0.0f32; hidden];
35    let offset = token_id as usize * hidden;
36    hidden_state.copy_from_slice(&embed_w[offset..offset + hidden]);
37
38    // Pre-allocate buffers
39    let mut normed = vec![0.0f32; hidden];
40    let mut q = vec![0.0f32; num_heads * head_dim];
41    let mut k = vec![0.0f32; num_kv_heads * head_dim];
42    let mut v = vec![0.0f32; num_kv_heads * head_dim];
43    let mut attn_out = vec![0.0f32; num_heads * head_dim];
44    let mut attn_proj = vec![0.0f32; hidden];
45    let mut residual = vec![0.0f32; hidden];
46    let mut gate = vec![0.0f32; intermediate];
47    let mut gate_act = vec![0.0f32; intermediate];
48    let mut up = vec![0.0f32; intermediate];
49    let mut ffn_hidden = vec![0.0f32; intermediate];
50    let mut ffn_out = vec![0.0f32; hidden];
51
52    for layer_idx in 0..config.num_layers {
53        let prefix = format!("model.layers.{layer_idx}");
54
55        // Attention norm
56        let norm_w = weights.tensor(&format!("{prefix}.input_layernorm.weight"));
57        rms_norm(&mut normed, &hidden_state, norm_w, config.rms_norm_eps);
58
59        // QKV projections
60        let q_w = weights.tensor(&format!("{prefix}.self_attn.q_proj.weight"));
61        let k_w = weights.tensor(&format!("{prefix}.self_attn.k_proj.weight"));
62        let v_w = weights.tensor(&format!("{prefix}.self_attn.v_proj.weight"));
63        matmul(&mut q, &normed, q_w, 1, hidden, num_heads * head_dim);
64        matmul(&mut k, &normed, k_w, 1, hidden, num_kv_heads * head_dim);
65        matmul(&mut v, &normed, v_w, 1, hidden, num_kv_heads * head_dim);
66
67        // Add QKV biases if present (Qwen2 uses biases on QKV)
68        if let Some(q_bias) = weights.get(&format!("{prefix}.self_attn.q_proj.bias")) {
69            elementwise_add_inplace(&mut q, q_bias);
70        }
71        if let Some(k_bias) = weights.get(&format!("{prefix}.self_attn.k_proj.bias")) {
72            elementwise_add_inplace(&mut k, k_bias);
73        }
74        if let Some(v_bias) = weights.get(&format!("{prefix}.self_attn.v_proj.bias")) {
75            elementwise_add_inplace(&mut v, v_bias);
76        }
77
78        // RoPE
79        rope(&mut q, pos, head_dim, num_heads, config.rope_theta);
80        rope(&mut k, pos, head_dim, num_kv_heads, config.rope_theta);
81
82        // Update KV cache
83        cache.append(layer_idx, &k, &v);
84
85        // Attention
86        attention(
87            &mut attn_out,
88            &q,
89            cache.k(layer_idx),
90            cache.v(layer_idx),
91            &AttentionParams {
92                seq_len: pos + 1,
93                num_heads,
94                num_kv_heads,
95                head_dim,
96            },
97        );
98
99        // Output projection
100        let o_w = weights.tensor(&format!("{prefix}.self_attn.o_proj.weight"));
101        matmul(
102            &mut attn_proj,
103            &attn_out,
104            o_w,
105            1,
106            num_heads * head_dim,
107            hidden,
108        );
109
110        // Residual
111        elementwise_add(&mut residual, &hidden_state, &attn_proj);
112
113        // FFN norm
114        let ffn_norm_w = weights.tensor(&format!("{prefix}.post_attention_layernorm.weight"));
115        rms_norm(&mut normed, &residual, ffn_norm_w, config.rms_norm_eps);
116
117        // FFN
118        let gate_w = weights.tensor(&format!("{prefix}.mlp.gate_proj.weight"));
119        let up_w = weights.tensor(&format!("{prefix}.mlp.up_proj.weight"));
120        let down_w = weights.tensor(&format!("{prefix}.mlp.down_proj.weight"));
121
122        matmul(&mut gate, &normed, gate_w, 1, hidden, intermediate);
123        silu(&mut gate_act, &gate);
124        matmul(&mut up, &normed, up_w, 1, hidden, intermediate);
125        elementwise_mul(&mut ffn_hidden, &gate_act, &up);
126        matmul(&mut ffn_out, &ffn_hidden, down_w, 1, intermediate, hidden);
127
128        // Residual
129        elementwise_add(&mut hidden_state, &residual, &ffn_out);
130    }
131
132    // Final norm
133    let final_norm_w = weights.tensor("model.norm.weight");
134    rms_norm(
135        &mut normed,
136        &hidden_state,
137        final_norm_w,
138        config.rms_norm_eps,
139    );
140
141    // Logits projection (may use tied embeddings)
142    let lm_head_w = weights
143        .get("lm_head.weight")
144        .unwrap_or_else(|| weights.tensor("model.embed_tokens.weight"));
145    let mut logits = vec![0.0f32; vocab];
146    matmul(&mut logits, &normed, lm_head_w, 1, hidden, vocab);
147
148    logits
149}
150
151// --- Kernel wrappers (delegate to optimized kernels module) ---
152
153fn rms_norm(output: &mut [f32], input: &[f32], weight: &[f32], eps: f32) {
154    kernels::rms_norm(output, input, weight, eps);
155}
156
157fn matmul(output: &mut [f32], input: &[f32], weight: &[f32], m: usize, k: usize, n: usize) {
158    kernels::matmul(output, input, weight, m, k, n);
159}
160
161fn silu(output: &mut [f32], input: &[f32]) {
162    kernels::silu(output, input);
163}
164
165fn elementwise_mul(output: &mut [f32], a: &[f32], b: &[f32]) {
166    kernels::elementwise_mul(output, a, b);
167}
168
169fn elementwise_add(output: &mut [f32], a: &[f32], b: &[f32]) {
170    kernels::elementwise_add(output, a, b);
171}
172
/// Add `b` into `a` element-wise (`a[i] += b[i]`), e.g. to apply a projection bias.
///
/// # Panics
///
/// Panics if `a` and `b` differ in length. A mismatched bias means the weights do
/// not fit the model config, and must not be silently truncated or ignored.
fn elementwise_add_inplace(a: &mut [f32], b: &[f32]) {
    assert_eq!(
        a.len(),
        b.len(),
        "elementwise_add_inplace: length mismatch ({} vs {})",
        a.len(),
        b.len()
    );
    for (dst, &src) in a.iter_mut().zip(b) {
        *dst += src;
    }
}
178
/// Apply rotary position embeddings (RoPE) in place.
///
/// `data` holds `num_heads` heads of width `head_dim` stored contiguously at its start.
/// Each adjacent pair `(2j, 2j + 1)` inside a head is rotated by the angle
/// `pos * theta^(-2j / head_dim)`. The angles depend only on the pair index, so their
/// cos/sin are computed once and reused for every head. Elements past
/// `num_heads * head_dim` are left untouched.
///
/// NOTE(review): this rotates *adjacent* pairs (interleaved GPT-J / Meta layout).
/// HF-format Llama/Qwen2 checkpoints usually pair `(j, j + head_dim / 2)` ("rotate_half")
/// unless the q/k projection rows are permuted at load time. Confirm which convention
/// the weight loader produces.
///
/// # Panics
///
/// - If `head_dim` is odd: the last pair would spill into the next head.
/// - If `data` holds fewer than `num_heads * head_dim` elements.
fn rope(data: &mut [f32], pos: usize, head_dim: usize, num_heads: usize, theta: f32) {
    assert!(
        head_dim % 2 == 0,
        "RoPE requires an even head_dim, got {head_dim}"
    );
    let rotated = num_heads * head_dim;
    assert!(
        data.len() >= rotated,
        "RoPE input has {} elements, expected at least {rotated} ({num_heads} heads x {head_dim})",
        data.len()
    );
    if rotated == 0 {
        return;
    }

    let pos = pos as f32;
    // (cos, sin) for each pair index, shared by all heads.
    let rotations: Vec<(f32, f32)> = (0..head_dim)
        .step_by(2)
        .map(|i| {
            let freq = 1.0 / theta.powf(i as f32 / head_dim as f32);
            let angle = pos * freq;
            (angle.cos(), angle.sin())
        })
        .collect();

    for head in data[..rotated].chunks_exact_mut(head_dim) {
        for (pair, &(cos_val, sin_val)) in head.chunks_exact_mut(2).zip(&rotations) {
            let (x0, x1) = (pair[0], pair[1]);
            pair[0] = x0 * cos_val - x1 * sin_val;
            pair[1] = x0 * sin_val + x1 * cos_val;
        }
    }
}
194
/// Shape parameters for one single-token attention call.
struct AttentionParams {
    /// Number of cached positions attended over, current token included.
    seq_len: usize,
    /// Number of query heads.
    num_heads: usize,
    /// Number of key/value heads; may be smaller than `num_heads` (shared K/V heads).
    num_kv_heads: usize,
    /// Width of every head vector.
    head_dim: usize,
}
201
202fn attention(
203    output: &mut [f32],
204    q: &[f32],
205    k_cache: &[f32],
206    v_cache: &[f32],
207    params: &AttentionParams,
208) {
209    kernels::attention(
210        output,
211        q,
212        k_cache,
213        v_cache,
214        params.seq_len,
215        params.num_heads,
216        params.num_kv_heads,
217        params.head_dim,
218    );
219}
220
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    // Dimensions of the tiny synthetic model shared by the forward-pass tests.
    const HIDDEN: usize = 8;
    const INTER: usize = 16;
    const VOCAB: usize = 16;
    const NUM_HEADS: usize = 2;
    const NUM_KV_HEADS: usize = 1;
    const HEAD_DIM: usize = 4;

    /// One-layer Llama-style config with 2 query heads sharing 1 KV head.
    fn tiny_config() -> ModelConfig {
        ModelConfig {
            architecture: Architecture::Llama,
            hidden_size: HIDDEN,
            intermediate_size: INTER,
            num_layers: 1,
            num_attention_heads: NUM_HEADS,
            num_kv_heads: NUM_KV_HEADS,
            head_dim: HEAD_DIM,
            vocab_size: VOCAB,
            max_seq_len: 32,
            rms_norm_eps: 1e-5,
            rope_theta: 10000.0,
            dtype: DType::F32,
        }
    }

    /// Constant-filled weights matching `tiny_config`.
    ///
    /// Projection weights are stored as `[out, in]`.
    fn tiny_weights() -> ModelWeights {
        let q_dim = NUM_HEADS * HEAD_DIM;
        let kv_dim = NUM_KV_HEADS * HEAD_DIM;
        let mut tensors = HashMap::new();
        for (name, data) in [
            ("model.embed_tokens.weight", vec![0.1f32; VOCAB * HIDDEN]),
            ("model.layers.0.input_layernorm.weight", vec![1.0; HIDDEN]),
            ("model.layers.0.self_attn.q_proj.weight", vec![0.01; q_dim * HIDDEN]),
            ("model.layers.0.self_attn.k_proj.weight", vec![0.01; kv_dim * HIDDEN]),
            ("model.layers.0.self_attn.v_proj.weight", vec![0.01; kv_dim * HIDDEN]),
            ("model.layers.0.self_attn.o_proj.weight", vec![0.01; HIDDEN * q_dim]),
            ("model.layers.0.post_attention_layernorm.weight", vec![1.0; HIDDEN]),
            ("model.layers.0.mlp.gate_proj.weight", vec![0.01; INTER * HIDDEN]),
            ("model.layers.0.mlp.up_proj.weight", vec![0.01; INTER * HIDDEN]),
            ("model.layers.0.mlp.down_proj.weight", vec![0.01; HIDDEN * INTER]),
            ("model.norm.weight", vec![1.0; HIDDEN]),
            ("lm_head.weight", vec![0.01; VOCAB * HIDDEN]),
        ] {
            tensors.insert(name.into(), data);
        }
        ModelWeights { tensors }
    }

    #[test]
    fn rms_norm_basic() {
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let weight = vec![1.0; 4];
        let mut output = vec![0.0; 4];
        rms_norm(&mut output, &input, &weight, 1e-5);

        // RMS = sqrt((1+4+9+16)/4) = sqrt(7.5) ≈ 2.7386
        let rms = (30.0f32 / 4.0 + 1e-5).sqrt();
        let expected: Vec<f32> = input.iter().map(|x| x / rms).collect();
        for (a, b) in output.iter().zip(expected.iter()) {
            assert!((a - b).abs() < 1e-5, "got {a}, expected {b}");
        }
    }

    #[test]
    fn matmul_basic() {
        // Weights are stored row-major as [n, k]: W = [[1, 2], [3, 4]].
        // output = input · Wᵀ = [1, 2] · [[1, 3], [2, 4]] = [1*1 + 2*2, 1*3 + 2*4] = [5, 11]
        let input = vec![1.0, 2.0];
        let weight = vec![1.0, 2.0, 3.0, 4.0]; // row 0: [1,2], row 1: [3,4]
        let mut output = vec![0.0; 2];
        matmul(&mut output, &input, &weight, 1, 2, 2);
        assert!((output[0] - 5.0).abs() < 1e-6);
        assert!((output[1] - 11.0).abs() < 1e-6);
    }

    #[test]
    fn silu_basic() {
        let input = vec![0.0, 1.0, -1.0];
        let mut output = vec![0.0; 3];
        silu(&mut output, &input);
        // silu(x) = x * sigmoid(x): silu(0) = 0, silu(1) ≈ 0.7311, silu(-1) ≈ -0.2689
        assert!((output[0] - 0.0).abs() < 1e-6);
        assert!((output[1] - 0.7311).abs() < 1e-3);
        assert!((output[2] - (-0.2689)).abs() < 1e-3);
    }

    #[test]
    fn softmax_basic() {
        let mut values = vec![1.0, 2.0, 3.0];
        kernels::softmax(&mut values);
        let sum: f32 = values.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);
        assert!(values[2] > values[1]);
        assert!(values[1] > values[0]);
    }

    #[test]
    fn rope_preserves_magnitude() {
        // RoPE is a rotation, so it should preserve vector magnitude
        let mut data = vec![1.0, 0.0, 0.0, 1.0]; // 1 head, dim=4
        let mag_before: f32 = data.iter().map(|x| x * x).sum::<f32>().sqrt();
        rope(&mut data, 5, 4, 1, 10000.0);
        let mag_after: f32 = data.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(
            (mag_before - mag_after).abs() < 1e-5,
            "RoPE changed magnitude: {mag_before} → {mag_after}"
        );
    }

    #[test]
    fn forward_with_tiny_model() {
        // Verify the interpreter runs a minimal model end to end without panicking.
        let config = tiny_config();
        let graph = forgellm_frontend::graph_builder::build_graph(&config).unwrap();
        let weights = tiny_weights();
        let mut kv_cache = KVCache::new(1, NUM_KV_HEADS, HEAD_DIM);

        let logits = forward(0, 0, &graph, &weights, &mut kv_cache);
        assert_eq!(logits.len(), VOCAB);
        assert_eq!(kv_cache.len(), 0); // advance not called by forward

        // Logits should be finite
        for &l in &logits {
            assert!(l.is_finite(), "logit is not finite: {l}");
        }
    }

    #[test]
    fn forward_multi_token() {
        let config = tiny_config();
        let graph = forgellm_frontend::graph_builder::build_graph(&config).unwrap();
        let weights = tiny_weights();
        let mut cache = KVCache::new(1, NUM_KV_HEADS, HEAD_DIM);

        // Generate 3 tokens, advancing the cache after each one.
        for pos in 0..3 {
            let logits = forward(1, pos, &graph, &weights, &mut cache);
            assert_eq!(logits.len(), VOCAB);
            cache.advance();
        }

        assert_eq!(cache.len(), 3);
    }
}