1use forgellm_frontend::ir::*;
8use forgellm_frontend::weight_loader::ModelWeights;
9
10use crate::kernels;
11use crate::kv_cache::KVCache;
12
13pub fn forward(
17 token_id: u32,
18 pos: usize,
19 graph: &Graph,
20 weights: &ModelWeights,
21 cache: &mut KVCache,
22) -> Vec<f32> {
23 let config = graph.config.as_ref().expect("graph must have config");
24
25 let hidden = config.hidden_size;
26 let intermediate = config.intermediate_size;
27 let num_heads = config.num_attention_heads;
28 let num_kv_heads = config.num_kv_heads;
29 let head_dim = config.head_dim;
30 let vocab = config.vocab_size;
31
32 let embed_w = weights.tensor("model.embed_tokens.weight");
34 let mut hidden_state = vec![0.0f32; hidden];
35 let offset = token_id as usize * hidden;
36 hidden_state.copy_from_slice(&embed_w[offset..offset + hidden]);
37
38 let mut normed = vec![0.0f32; hidden];
40 let mut q = vec![0.0f32; num_heads * head_dim];
41 let mut k = vec![0.0f32; num_kv_heads * head_dim];
42 let mut v = vec![0.0f32; num_kv_heads * head_dim];
43 let mut attn_out = vec![0.0f32; num_heads * head_dim];
44 let mut attn_proj = vec![0.0f32; hidden];
45 let mut residual = vec![0.0f32; hidden];
46 let mut gate = vec![0.0f32; intermediate];
47 let mut gate_act = vec![0.0f32; intermediate];
48 let mut up = vec![0.0f32; intermediate];
49 let mut ffn_hidden = vec![0.0f32; intermediate];
50 let mut ffn_out = vec![0.0f32; hidden];
51
52 for layer_idx in 0..config.num_layers {
53 let prefix = format!("model.layers.{layer_idx}");
54
55 let norm_w = weights.tensor(&format!("{prefix}.input_layernorm.weight"));
57 rms_norm(&mut normed, &hidden_state, norm_w, config.rms_norm_eps);
58
59 let q_w = weights.tensor(&format!("{prefix}.self_attn.q_proj.weight"));
61 let k_w = weights.tensor(&format!("{prefix}.self_attn.k_proj.weight"));
62 let v_w = weights.tensor(&format!("{prefix}.self_attn.v_proj.weight"));
63 matmul(&mut q, &normed, q_w, 1, hidden, num_heads * head_dim);
64 matmul(&mut k, &normed, k_w, 1, hidden, num_kv_heads * head_dim);
65 matmul(&mut v, &normed, v_w, 1, hidden, num_kv_heads * head_dim);
66
67 if let Some(q_bias) = weights.get(&format!("{prefix}.self_attn.q_proj.bias")) {
69 elementwise_add_inplace(&mut q, q_bias);
70 }
71 if let Some(k_bias) = weights.get(&format!("{prefix}.self_attn.k_proj.bias")) {
72 elementwise_add_inplace(&mut k, k_bias);
73 }
74 if let Some(v_bias) = weights.get(&format!("{prefix}.self_attn.v_proj.bias")) {
75 elementwise_add_inplace(&mut v, v_bias);
76 }
77
78 rope(&mut q, pos, head_dim, num_heads, config.rope_theta);
80 rope(&mut k, pos, head_dim, num_kv_heads, config.rope_theta);
81
82 cache.append(layer_idx, &k, &v);
84
85 attention(
87 &mut attn_out,
88 &q,
89 cache.k(layer_idx),
90 cache.v(layer_idx),
91 &AttentionParams {
92 seq_len: pos + 1,
93 num_heads,
94 num_kv_heads,
95 head_dim,
96 },
97 );
98
99 let o_w = weights.tensor(&format!("{prefix}.self_attn.o_proj.weight"));
101 matmul(
102 &mut attn_proj,
103 &attn_out,
104 o_w,
105 1,
106 num_heads * head_dim,
107 hidden,
108 );
109
110 elementwise_add(&mut residual, &hidden_state, &attn_proj);
112
113 let ffn_norm_w = weights.tensor(&format!("{prefix}.post_attention_layernorm.weight"));
115 rms_norm(&mut normed, &residual, ffn_norm_w, config.rms_norm_eps);
116
117 let gate_w = weights.tensor(&format!("{prefix}.mlp.gate_proj.weight"));
119 let up_w = weights.tensor(&format!("{prefix}.mlp.up_proj.weight"));
120 let down_w = weights.tensor(&format!("{prefix}.mlp.down_proj.weight"));
121
122 matmul(&mut gate, &normed, gate_w, 1, hidden, intermediate);
123 silu(&mut gate_act, &gate);
124 matmul(&mut up, &normed, up_w, 1, hidden, intermediate);
125 elementwise_mul(&mut ffn_hidden, &gate_act, &up);
126 matmul(&mut ffn_out, &ffn_hidden, down_w, 1, intermediate, hidden);
127
128 elementwise_add(&mut hidden_state, &residual, &ffn_out);
130 }
131
132 let final_norm_w = weights.tensor("model.norm.weight");
134 rms_norm(
135 &mut normed,
136 &hidden_state,
137 final_norm_w,
138 config.rms_norm_eps,
139 );
140
141 let lm_head_w = weights
143 .get("lm_head.weight")
144 .unwrap_or_else(|| weights.tensor("model.embed_tokens.weight"));
145 let mut logits = vec![0.0f32; vocab];
146 matmul(&mut logits, &normed, lm_head_w, 1, hidden, vocab);
147
148 logits
149}
150
151fn rms_norm(output: &mut [f32], input: &[f32], weight: &[f32], eps: f32) {
154 kernels::rms_norm(output, input, weight, eps);
155}
156
157fn matmul(output: &mut [f32], input: &[f32], weight: &[f32], m: usize, k: usize, n: usize) {
158 kernels::matmul(output, input, weight, m, k, n);
159}
160
161fn silu(output: &mut [f32], input: &[f32]) {
162 kernels::silu(output, input);
163}
164
165fn elementwise_mul(output: &mut [f32], a: &[f32], b: &[f32]) {
166 kernels::elementwise_mul(output, a, b);
167}
168
169fn elementwise_add(output: &mut [f32], a: &[f32], b: &[f32]) {
170 kernels::elementwise_add(output, a, b);
171}
172
/// Adds `b` into `a` element-wise (`a[i] += b[i]`); used to apply projection biases.
///
/// # Panics
///
/// If `a` and `b` differ in length. A bias whose size does not match its projection
/// means a malformed checkpoint, and silently ignoring the surplus would hide that.
fn elementwise_add_inplace(a: &mut [f32], b: &[f32]) {
    assert_eq!(a.len(), b.len(), "elementwise_add_inplace: length mismatch");
    for (dst, &src) in a.iter_mut().zip(b) {
        *dst += src;
    }
}
178
/// Applies rotary position embeddings (RoPE) in place to `num_heads` consecutive heads.
///
/// `data` holds `num_heads` heads of `head_dim` contiguous values each. Within every
/// head, adjacent pairs `(x[i], x[i + 1])` for even `i` are rotated by the angle
/// `pos * theta^(-i / head_dim)`. Values beyond `num_heads * head_dim` are untouched.
///
/// NOTE(review): this is the interleaved-pair (GPT-J style) convention. Hugging Face
/// Llama/Qwen checkpoints (whose tensor names `forward` uses) normally pair `x[j]` with
/// `x[j + head_dim / 2]` ("rotate_half"); confirm the loader permutes q/k projections.
///
/// # Panics
///
/// If `head_dim` is odd (a pair would straddle two heads), or `data` is shorter than
/// `num_heads * head_dim`.
fn rope(data: &mut [f32], pos: usize, head_dim: usize, num_heads: usize, theta: f32) {
    let span = num_heads * head_dim;
    if span == 0 {
        return;
    }
    assert!(
        head_dim % 2 == 0,
        "rope: head_dim must be even to form rotation pairs, got {head_dim}"
    );
    assert!(
        data.len() >= span,
        "rope: expected at least {span} values ({num_heads} heads x {head_dim}), got {}",
        data.len()
    );

    // The angle depends only on the pair index within a head, so build the (cos, sin)
    // table once instead of recomputing the trigonometry for every head.
    let rotations: Vec<(f32, f32)> = (0..head_dim)
        .step_by(2)
        .map(|i| {
            let freq = 1.0 / theta.powf(i as f32 / head_dim as f32);
            let angle = pos as f32 * freq;
            (angle.cos(), angle.sin())
        })
        .collect();

    for head in data[..span].chunks_exact_mut(head_dim) {
        for (pair, &(cos_val, sin_val)) in head.chunks_exact_mut(2).zip(&rotations) {
            let (x0, x1) = (pair[0], pair[1]);
            pair[0] = x0 * cos_val - x1 * sin_val;
            pair[1] = x0 * sin_val + x1 * cos_val;
        }
    }
}
194
/// Shape parameters for one attention call over the KV cache.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct AttentionParams {
    /// Number of cached positions attended to (`forward` passes `pos + 1`).
    seq_len: usize,
    /// Number of query heads.
    num_heads: usize,
    /// Number of key/value heads; may be smaller than `num_heads` (grouped-query attention).
    num_kv_heads: usize,
    /// Width of every head.
    head_dim: usize,
}
201
202fn attention(
203 output: &mut [f32],
204 q: &[f32],
205 k_cache: &[f32],
206 v_cache: &[f32],
207 params: &AttentionParams,
208) {
209 kernels::attention(
210 output,
211 q,
212 k_cache,
213 v_cache,
214 params.seq_len,
215 params.num_heads,
216 params.num_kv_heads,
217 params.head_dim,
218 );
219}
220
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// One-layer toy configuration with grouped-query attention (2 query heads, 1 KV head).
    fn tiny_config() -> ModelConfig {
        ModelConfig {
            architecture: Architecture::Llama,
            hidden_size: 8,
            intermediate_size: 16,
            num_layers: 1,
            num_attention_heads: 2,
            num_kv_heads: 1,
            head_dim: 4,
            vocab_size: 16,
            max_seq_len: 32,
            rms_norm_eps: 1e-5,
            rope_theta: 10000.0,
            dtype: DType::F32,
        }
    }

    /// Constant-valued weights with the shapes `forward` expects for `config`.
    fn tiny_weights(config: &ModelConfig) -> ModelWeights {
        let h = config.hidden_size;
        let inter = config.intermediate_size;
        let vocab = config.vocab_size;
        let q_dim = config.num_attention_heads * config.head_dim;
        let kv_dim = config.num_kv_heads * config.head_dim;

        // (tensor name, fill value, element count)
        let mut specs: Vec<(String, f32, usize)> = vec![
            ("model.embed_tokens.weight".to_string(), 0.1, vocab * h),
            ("model.norm.weight".to_string(), 1.0, h),
            ("lm_head.weight".to_string(), 0.01, vocab * h),
        ];
        for layer in 0..config.num_layers {
            let p = format!("model.layers.{layer}");
            specs.extend([
                (format!("{p}.input_layernorm.weight"), 1.0, h),
                (format!("{p}.self_attn.q_proj.weight"), 0.01, q_dim * h),
                (format!("{p}.self_attn.k_proj.weight"), 0.01, kv_dim * h),
                (format!("{p}.self_attn.v_proj.weight"), 0.01, kv_dim * h),
                (format!("{p}.self_attn.o_proj.weight"), 0.01, h * q_dim),
                (format!("{p}.post_attention_layernorm.weight"), 1.0, h),
                (format!("{p}.mlp.gate_proj.weight"), 0.01, inter * h),
                (format!("{p}.mlp.up_proj.weight"), 0.01, inter * h),
                (format!("{p}.mlp.down_proj.weight"), 0.01, h * inter),
            ]);
        }

        let mut tensors = HashMap::new();
        for (name, value, len) in specs {
            tensors.insert(name, vec![value; len]);
        }
        ModelWeights { tensors }
    }

    #[test]
    fn rms_norm_basic() {
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let weight = vec![1.0; 4];
        let mut output = vec![0.0; 4];
        rms_norm(&mut output, &input, &weight, 1e-5);

        // mean(x²) = (1 + 4 + 9 + 16) / 4
        let rms = (30.0f32 / 4.0 + 1e-5).sqrt();
        let expected: Vec<f32> = input.iter().map(|x| x / rms).collect();
        for (a, b) in output.iter().zip(expected.iter()) {
            assert!((a - b).abs() < 1e-5, "got {a}, expected {b}");
        }
    }

    #[test]
    fn matmul_basic() {
        // `weight` holds n = 2 rows of length k = 2: output[j] = dot(input, weight row j).
        let input = vec![1.0, 2.0];
        let weight = vec![1.0, 2.0, 3.0, 4.0];
        let mut output = vec![0.0; 2];
        matmul(&mut output, &input, &weight, 1, 2, 2);
        assert!((output[0] - 5.0).abs() < 1e-6);
        assert!((output[1] - 11.0).abs() < 1e-6);
    }

    #[test]
    fn silu_basic() {
        let input = vec![0.0, 1.0, -1.0];
        let mut output = vec![0.0; 3];
        silu(&mut output, &input);
        assert!((output[0] - 0.0).abs() < 1e-6);
        // silu(x) = x * sigmoid(x)
        assert!((output[1] - 0.7311).abs() < 1e-3);
        assert!((output[2] - (-0.2689)).abs() < 1e-3);
    }

    #[test]
    fn softmax_basic() {
        let mut values = vec![1.0, 2.0, 3.0];
        kernels::softmax(&mut values);
        let sum: f32 = values.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);
        assert!(values[2] > values[1]);
        assert!(values[1] > values[0]);
    }

    #[test]
    fn rope_preserves_magnitude() {
        let mut data = vec![1.0, 0.0, 0.0, 1.0];
        let mag_before: f32 = data.iter().map(|x| x * x).sum::<f32>().sqrt();
        rope(&mut data, 5, 4, 1, 10000.0);
        let mag_after: f32 = data.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(
            (mag_before - mag_after).abs() < 1e-5,
            "RoPE changed magnitude: {mag_before} → {mag_after}"
        );
    }

    #[test]
    fn rope_is_identity_at_position_zero() {
        let original = vec![1.0f32, -2.0, 3.0, 0.5, 4.0, -1.0, 2.0, 7.0];
        let mut data = original.clone();
        rope(&mut data, 0, 4, 2, 10000.0);
        assert_eq!(data, original);
    }

    #[test]
    fn elementwise_add_inplace_adds_elementwise() {
        let mut a = vec![1.0f32, 2.0, 3.0];
        elementwise_add_inplace(&mut a, &[0.5, -2.0, 10.0]);
        assert_eq!(a, vec![1.5, 0.0, 13.0]);
    }

    #[test]
    #[should_panic]
    fn elementwise_add_inplace_rejects_short_rhs() {
        let mut a = vec![1.0f32, 2.0];
        elementwise_add_inplace(&mut a, &[1.0]);
    }

    #[test]
    fn forward_with_tiny_model() {
        let config = tiny_config();
        let graph = forgellm_frontend::graph_builder::build_graph(&config).unwrap();
        let weights = tiny_weights(&config);
        let mut kv_cache = KVCache::new(config.num_layers, config.num_kv_heads, config.head_dim);

        let logits = forward(0, 0, &graph, &weights, &mut kv_cache);

        assert_eq!(logits.len(), config.vocab_size);
        // `forward` appends K/V but leaves advancing the cache to the caller.
        assert_eq!(kv_cache.len(), 0);
        for &l in &logits {
            assert!(l.is_finite(), "logit is not finite: {l}");
        }
    }

    #[test]
    fn forward_multi_token() {
        let config = tiny_config();
        let graph = forgellm_frontend::graph_builder::build_graph(&config).unwrap();
        let weights = tiny_weights(&config);
        let mut cache = KVCache::new(config.num_layers, config.num_kv_heads, config.head_dim);

        for pos in 0..3 {
            let logits = forward(1, pos, &graph, &weights, &mut cache);
            assert_eq!(logits.len(), config.vocab_size);
            assert!(
                logits.iter().all(|l| l.is_finite()),
                "non-finite logit at pos {pos}"
            );
            cache.advance();
        }

        assert_eq!(cache.len(), 3);
    }

    #[test]
    fn forward_falls_back_to_tied_embeddings() {
        let config = tiny_config();
        let graph = forgellm_frontend::graph_builder::build_graph(&config).unwrap();
        let mut weights = tiny_weights(&config);
        weights.tensors.remove("lm_head.weight");
        let mut cache = KVCache::new(config.num_layers, config.num_kv_heads, config.head_dim);

        let logits = forward(2, 0, &graph, &weights, &mut cache);

        assert_eq!(logits.len(), config.vocab_size);
        assert!(logits.iter().all(|l| l.is_finite()));
    }
}