Skip to main content

forgellm_frontend/
graph_builder.rs

//! Graph builder — constructs IR computation graphs from model configs.
//!
//! Takes a `ModelConfig` (extracted from GGUF metadata or a HuggingFace
//! `config.json`) and builds the full transformer computation graph with
//! all layers, weight references, and operations.
6
7use crate::ir::*;
8
9/// Build a complete computation graph for a transformer model.
10///
11/// The graph represents the full forward pass from token IDs to logits.
12/// Weight names follow the HuggingFace convention (model.layers.N.*).
13pub fn build_graph(config: &ModelConfig) -> Result<Graph, GraphBuildError> {
14    match config.architecture {
15        Architecture::Llama | Architecture::Mistral => build_llama_graph(config),
16        Architecture::Qwen2 => build_llama_graph(config),
17        Architecture::Gemma => build_llama_graph(config), // Gemma uses same structure with GeLU instead of SiLU
18        Architecture::StableLM => build_llama_graph(config),
19        Architecture::Phi3 => build_llama_graph(config), // Phi3 is structurally similar
20                                                         // All remaining Llama-family architectures use the same builder
21    }
22}
23
24/// Build a Llama-family computation graph.
25///
26/// Architecture: embedding → N × (attention_norm → attention → residual →
27/// ffn_norm → ffn → residual) → final_norm → lm_head
28fn build_llama_graph(config: &ModelConfig) -> Result<Graph, GraphBuildError> {
29    let mut graph =
30        Graph::new(format!("{}-graph", config.architecture)).with_config(config.clone());
31
32    let hidden = config.hidden_size;
33    let vocab = config.vocab_size;
34    let dtype = config.dtype;
35
36    // Input: token IDs [batch, seq_len]
37    let input_ids = graph.input("input_ids", Shape::new(vec![1, 0]), DType::I32);
38
39    // Token embedding: [vocab_size, hidden_size]
40    let embed_weight = graph.load_weight(
41        "model.embed_tokens.weight",
42        Shape::new(vec![vocab, hidden]),
43        dtype,
44    );
45
46    let tid = graph.alloc_tensor_id();
47    let mut current = graph.add_node(
48        Op::Embedding {
49            vocab_size: vocab,
50            embed_dim: hidden,
51        },
52        vec![input_ids, embed_weight],
53        TensorInfo {
54            id: tid,
55            name: "embed_output".into(),
56            shape: Shape::new(vec![1, 0, hidden]),
57            dtype,
58        },
59    );
60
61    // Transformer layers
62    for layer_idx in 0..config.num_layers {
63        let prefix = format!("model.layers.{layer_idx}");
64        current = build_llama_layer(&mut graph, config, &prefix, current)?;
65    }
66
67    // Final RMSNorm
68    let final_norm_w = graph.load_weight("model.norm.weight", Shape::new(vec![hidden]), dtype);
69    let tid = graph.alloc_tensor_id();
70    let normed = graph.add_node(
71        Op::RMSNorm {
72            eps: config.rms_norm_eps,
73        },
74        vec![current, final_norm_w],
75        TensorInfo {
76            id: tid,
77            name: "final_norm".into(),
78            shape: Shape::new(vec![1, 0, hidden]),
79            dtype,
80        },
81    );
82
83    // LM head (logits projection)
84    let lm_head_weight =
85        graph.load_weight("lm_head.weight", Shape::new(vec![vocab, hidden]), dtype);
86    let tid = graph.alloc_tensor_id();
87    let _logits = graph.add_node(
88        Op::LogitsProjection { vocab_size: vocab },
89        vec![normed, lm_head_weight],
90        TensorInfo {
91            id: tid,
92            name: "logits".into(),
93            shape: Shape::new(vec![1, 0, vocab]),
94            dtype: DType::F32, // Logits are always f32
95        },
96    );
97
98    graph.validate().map_err(GraphBuildError::Validation)?;
99    Ok(graph)
100}
101
102/// Build a single Llama transformer layer.
103///
104/// Structure: input_norm → attention → residual → ffn_norm → FFN → residual
105fn build_llama_layer(
106    graph: &mut Graph,
107    config: &ModelConfig,
108    prefix: &str,
109    input: NodeId,
110) -> Result<NodeId, GraphBuildError> {
111    let hidden = config.hidden_size;
112    let intermediate = config.intermediate_size;
113    let num_heads = config.num_attention_heads;
114    let num_kv_heads = config.num_kv_heads;
115    let head_dim = config.head_dim;
116    let dtype = config.dtype;
117
118    // === Self-Attention ===
119
120    // Input LayerNorm
121    let attn_norm_w = graph.load_weight(
122        format!("{prefix}.input_layernorm.weight"),
123        Shape::new(vec![hidden]),
124        dtype,
125    );
126    let tid = graph.alloc_tensor_id();
127    let normed = graph.add_node(
128        Op::RMSNorm {
129            eps: config.rms_norm_eps,
130        },
131        vec![input, attn_norm_w],
132        TensorInfo {
133            id: tid,
134            name: format!("{prefix}.attn_norm"),
135            shape: Shape::new(vec![1, 0, hidden]),
136            dtype,
137        },
138    );
139
140    // Q, K, V projections
141    let q_weight = graph.load_weight(
142        format!("{prefix}.self_attn.q_proj.weight"),
143        Shape::new(vec![num_heads * head_dim, hidden]),
144        dtype,
145    );
146    let tid = graph.alloc_tensor_id();
147    let q = graph.add_node(
148        Op::MatMul,
149        vec![normed, q_weight],
150        TensorInfo {
151            id: tid,
152            name: format!("{prefix}.q_proj"),
153            shape: Shape::new(vec![1, 0, num_heads * head_dim]),
154            dtype,
155        },
156    );
157
158    let k_weight = graph.load_weight(
159        format!("{prefix}.self_attn.k_proj.weight"),
160        Shape::new(vec![num_kv_heads * head_dim, hidden]),
161        dtype,
162    );
163    let tid = graph.alloc_tensor_id();
164    let k = graph.add_node(
165        Op::MatMul,
166        vec![normed, k_weight],
167        TensorInfo {
168            id: tid,
169            name: format!("{prefix}.k_proj"),
170            shape: Shape::new(vec![1, 0, num_kv_heads * head_dim]),
171            dtype,
172        },
173    );
174
175    let v_weight = graph.load_weight(
176        format!("{prefix}.self_attn.v_proj.weight"),
177        Shape::new(vec![num_kv_heads * head_dim, hidden]),
178        dtype,
179    );
180    let tid = graph.alloc_tensor_id();
181    let v = graph.add_node(
182        Op::MatMul,
183        vec![normed, v_weight],
184        TensorInfo {
185            id: tid,
186            name: format!("{prefix}.v_proj"),
187            shape: Shape::new(vec![1, 0, num_kv_heads * head_dim]),
188            dtype,
189        },
190    );
191
192    // RoPE on Q and K
193    let tid = graph.alloc_tensor_id();
194    let q_rope = graph.add_node(
195        Op::RoPE {
196            max_seq_len: config.max_seq_len,
197            rope_theta: config.rope_theta,
198            head_dim,
199        },
200        vec![q],
201        TensorInfo {
202            id: tid,
203            name: format!("{prefix}.q_rope"),
204            shape: Shape::new(vec![1, 0, num_heads * head_dim]),
205            dtype,
206        },
207    );
208
209    let tid = graph.alloc_tensor_id();
210    let k_rope = graph.add_node(
211        Op::RoPE {
212            max_seq_len: config.max_seq_len,
213            rope_theta: config.rope_theta,
214            head_dim,
215        },
216        vec![k],
217        TensorInfo {
218            id: tid,
219            name: format!("{prefix}.k_rope"),
220            shape: Shape::new(vec![1, 0, num_kv_heads * head_dim]),
221            dtype,
222        },
223    );
224
225    // Attention
226    let tid = graph.alloc_tensor_id();
227    let attn_out = graph.add_node(
228        Op::Attention {
229            num_heads,
230            num_kv_heads,
231            head_dim,
232        },
233        vec![q_rope, k_rope, v],
234        TensorInfo {
235            id: tid,
236            name: format!("{prefix}.attn"),
237            shape: Shape::new(vec![1, 0, num_heads * head_dim]),
238            dtype,
239        },
240    );
241
242    // Output projection
243    let o_weight = graph.load_weight(
244        format!("{prefix}.self_attn.o_proj.weight"),
245        Shape::new(vec![hidden, num_heads * head_dim]),
246        dtype,
247    );
248    let tid = graph.alloc_tensor_id();
249    let attn_proj = graph.add_node(
250        Op::MatMul,
251        vec![attn_out, o_weight],
252        TensorInfo {
253            id: tid,
254            name: format!("{prefix}.o_proj"),
255            shape: Shape::new(vec![1, 0, hidden]),
256            dtype,
257        },
258    );
259
260    // Residual connection
261    let tid = graph.alloc_tensor_id();
262    let after_attn = graph.add_node(
263        Op::Residual,
264        vec![input, attn_proj],
265        TensorInfo {
266            id: tid,
267            name: format!("{prefix}.attn_residual"),
268            shape: Shape::new(vec![1, 0, hidden]),
269            dtype,
270        },
271    );
272
273    // === Feed-Forward Network ===
274
275    // Post-attention LayerNorm
276    let ffn_norm_w = graph.load_weight(
277        format!("{prefix}.post_attention_layernorm.weight"),
278        Shape::new(vec![hidden]),
279        dtype,
280    );
281    let tid = graph.alloc_tensor_id();
282    let ffn_normed = graph.add_node(
283        Op::RMSNorm {
284            eps: config.rms_norm_eps,
285        },
286        vec![after_attn, ffn_norm_w],
287        TensorInfo {
288            id: tid,
289            name: format!("{prefix}.ffn_norm"),
290            shape: Shape::new(vec![1, 0, hidden]),
291            dtype,
292        },
293    );
294
295    // Gate projection (SiLU-gated FFN)
296    let gate_weight = graph.load_weight(
297        format!("{prefix}.mlp.gate_proj.weight"),
298        Shape::new(vec![intermediate, hidden]),
299        dtype,
300    );
301    let tid = graph.alloc_tensor_id();
302    let gate = graph.add_node(
303        Op::MatMul,
304        vec![ffn_normed, gate_weight],
305        TensorInfo {
306            id: tid,
307            name: format!("{prefix}.gate_proj"),
308            shape: Shape::new(vec![1, 0, intermediate]),
309            dtype,
310        },
311    );
312
313    // SiLU activation on gate
314    let tid = graph.alloc_tensor_id();
315    let gate_act = graph.add_node(
316        Op::SiLU,
317        vec![gate],
318        TensorInfo {
319            id: tid,
320            name: format!("{prefix}.gate_silu"),
321            shape: Shape::new(vec![1, 0, intermediate]),
322            dtype,
323        },
324    );
325
326    // Up projection
327    let up_weight = graph.load_weight(
328        format!("{prefix}.mlp.up_proj.weight"),
329        Shape::new(vec![intermediate, hidden]),
330        dtype,
331    );
332    let tid = graph.alloc_tensor_id();
333    let up = graph.add_node(
334        Op::MatMul,
335        vec![ffn_normed, up_weight],
336        TensorInfo {
337            id: tid,
338            name: format!("{prefix}.up_proj"),
339            shape: Shape::new(vec![1, 0, intermediate]),
340            dtype,
341        },
342    );
343
344    // Gate * Up (elementwise multiply)
345    let tid = graph.alloc_tensor_id();
346    let ffn_hidden = graph.add_node(
347        Op::Mul,
348        vec![gate_act, up],
349        TensorInfo {
350            id: tid,
351            name: format!("{prefix}.gate_up_mul"),
352            shape: Shape::new(vec![1, 0, intermediate]),
353            dtype,
354        },
355    );
356
357    // Down projection
358    let down_weight = graph.load_weight(
359        format!("{prefix}.mlp.down_proj.weight"),
360        Shape::new(vec![hidden, intermediate]),
361        dtype,
362    );
363    let tid = graph.alloc_tensor_id();
364    let ffn_out = graph.add_node(
365        Op::MatMul,
366        vec![ffn_hidden, down_weight],
367        TensorInfo {
368            id: tid,
369            name: format!("{prefix}.down_proj"),
370            shape: Shape::new(vec![1, 0, hidden]),
371            dtype,
372        },
373    );
374
375    // Residual connection
376    let tid = graph.alloc_tensor_id();
377    let output = graph.add_node(
378        Op::Residual,
379        vec![after_attn, ffn_out],
380        TensorInfo {
381            id: tid,
382            name: format!("{prefix}.ffn_residual"),
383            shape: Shape::new(vec![1, 0, hidden]),
384            dtype,
385        },
386    );
387
388    Ok(output)
389}
390
391/// Errors in graph construction.
392#[derive(Debug, thiserror::Error)]
393pub enum GraphBuildError {
394    #[error("unsupported architecture: {0}")]
395    UnsupportedArchitecture(String),
396
397    #[error("graph validation failed: {0}")]
398    Validation(#[from] GraphError),
399}
400
#[cfg(test)]
mod tests {
    use super::*;

    /// Llama fixture: 16 layers, hidden size 2048, 32 heads / 8 KV heads, F16.
    fn llama_1b_config() -> ModelConfig {
        ModelConfig {
            architecture: Architecture::Llama,
            hidden_size: 2048,
            intermediate_size: 5632,
            num_layers: 16,
            num_attention_heads: 32,
            num_kv_heads: 8,
            head_dim: 64,
            vocab_size: 32000,
            max_seq_len: 2048,
            rms_norm_eps: 1e-5,
            rope_theta: 10000.0,
            dtype: DType::F16,
        }
    }

    /// Smaller Llama fixture: 30 layers, hidden size 576, 9 heads / 3 KV heads, BF16.
    ///
    /// Architecture, head_dim, max_seq_len, rms_norm_eps and rope_theta are
    /// the same as in `llama_1b_config` and are inherited from it.
    fn smollm_135m_config() -> ModelConfig {
        ModelConfig {
            hidden_size: 576,
            intermediate_size: 1536,
            num_layers: 30,
            num_attention_heads: 9,
            num_kv_heads: 3,
            vocab_size: 49152,
            dtype: DType::BF16,
            ..llama_1b_config()
        }
    }

    /// Minimal single-layer config for the given architecture.
    fn tiny_config(architecture: Architecture) -> ModelConfig {
        ModelConfig {
            architecture,
            hidden_size: 64,
            intermediate_size: 128,
            num_layers: 1,
            num_attention_heads: 4,
            num_kv_heads: 2,
            head_dim: 16,
            vocab_size: 256,
            max_seq_len: 64,
            rms_norm_eps: 1e-5,
            rope_theta: 10000.0,
            dtype: DType::F16,
        }
    }

    #[test]
    fn build_llama_1b_graph() {
        let graph = build_graph(&llama_1b_config()).unwrap();

        assert!(!graph.is_empty());
        assert!(graph.config.is_some());
        assert!(graph.validate().is_ok());

        // Global weights plus a sample from the first and last (index 15) layers.
        let has_weight = |key: &str| graph.weights.contains_key(key);
        assert!(has_weight("model.embed_tokens.weight"));
        assert!(has_weight("model.norm.weight"));
        assert!(has_weight("lm_head.weight"));
        assert!(has_weight("model.layers.0.input_layernorm.weight"));
        assert!(has_weight("model.layers.0.self_attn.q_proj.weight"));
        assert!(has_weight("model.layers.0.mlp.gate_proj.weight"));
        assert!(has_weight("model.layers.15.mlp.down_proj.weight"));
    }

    #[test]
    fn build_smollm_135m_graph() {
        let graph = build_graph(&smollm_135m_config()).unwrap();

        assert!(graph.validate().is_ok());
        // 30 layers, so the last layer index is 29.
        assert!(graph
            .weights
            .contains_key("model.layers.29.mlp.down_proj.weight"));

        // Embedding table is [vocab_size, hidden_size].
        assert_eq!(
            graph.weights["model.embed_tokens.weight"].shape,
            Shape::new(vec![49152, 576])
        );

        // Q projection is [num_heads * head_dim, hidden] = [9 * 64, 576].
        assert_eq!(
            graph.weights["model.layers.0.self_attn.q_proj.weight"].shape,
            Shape::new(vec![576, 576])
        );
    }

    #[test]
    fn graph_node_count() {
        let graph = build_graph(&llama_1b_config()).unwrap();

        // Ops emitted per layer: 2 RMSNorm + 3 Q/K/V projections + 2 RoPE +
        // 1 attention + 1 output projection + 3 FFN projections + 1 SiLU +
        // 1 mul + 2 residuals = 16, backed by 9 weight loads
        // (7 projection weights + 2 norm weights).
        // Outside the layers: 1 input + 3 weight loads + embedding +
        // final norm + logits projection = 7.
        // Assuming every `input` / `load_weight` call adds a node, 16 layers
        // land well above this deliberately loose lower bound.
        assert!(graph.len() > 100);
    }

    #[test]
    fn graph_has_correct_output() {
        let graph = build_graph(&llama_1b_config()).unwrap();

        // The logits projection is emitted last and always produces F32.
        let tail = graph.node(graph.len() - 1);
        assert!(matches!(tail.op, Op::LogitsProjection { .. }));
        assert_eq!(tail.output.dtype, DType::F32);
    }

    #[test]
    fn qwen2_uses_llama_builder() {
        let qwen2 = ModelConfig {
            architecture: Architecture::Qwen2,
            hidden_size: 1536,
            intermediate_size: 8960,
            num_layers: 28,
            num_attention_heads: 12,
            num_kv_heads: 2,
            head_dim: 128,
            vocab_size: 151936,
            max_seq_len: 32768,
            rms_norm_eps: 1e-6,
            rope_theta: 1000000.0,
            dtype: DType::BF16,
        };

        let graph = build_graph(&qwen2).unwrap();
        assert!(graph.validate().is_ok());
        // 28 layers, so the last layer index is 27.
        assert!(graph
            .weights
            .contains_key("model.layers.27.mlp.down_proj.weight"));
    }

    #[test]
    fn all_architectures_supported() {
        // Every architecture listed here must build without error.
        let architectures = [
            Architecture::Llama,
            Architecture::Qwen2,
            Architecture::Mistral,
            Architecture::Phi3,
            Architecture::Gemma,
            Architecture::StableLM,
        ];

        for arch in architectures {
            let result = build_graph(&tiny_config(arch.clone()));
            assert!(result.is_ok(), "failed to build graph for {arch}");
        }
    }

    #[test]
    fn topological_order_is_valid() {
        let graph = build_graph(&smollm_135m_config()).unwrap();

        // Each node may only consume nodes with a strictly smaller id.
        for consumer in &graph.nodes {
            for &producer_id in &consumer.inputs {
                assert!(
                    producer_id < consumer.id,
                    "node {} references future node {}",
                    consumer.id,
                    producer_id
                );
            }
        }
    }
}