1use crate::ir::*;
8
9pub fn build_graph(config: &ModelConfig) -> Result<Graph, GraphBuildError> {
14 match config.architecture {
15 Architecture::Llama | Architecture::Mistral => build_llama_graph(config),
16 Architecture::Qwen2 => build_llama_graph(config),
17 Architecture::Gemma => build_llama_graph(config), Architecture::StableLM => build_llama_graph(config),
19 Architecture::Phi3 => build_llama_graph(config), }
22}
23
24fn build_llama_graph(config: &ModelConfig) -> Result<Graph, GraphBuildError> {
29 let mut graph =
30 Graph::new(format!("{}-graph", config.architecture)).with_config(config.clone());
31
32 let hidden = config.hidden_size;
33 let vocab = config.vocab_size;
34 let dtype = config.dtype;
35
36 let input_ids = graph.input("input_ids", Shape::new(vec![1, 0]), DType::I32);
38
39 let embed_weight = graph.load_weight(
41 "model.embed_tokens.weight",
42 Shape::new(vec![vocab, hidden]),
43 dtype,
44 );
45
46 let tid = graph.alloc_tensor_id();
47 let mut current = graph.add_node(
48 Op::Embedding {
49 vocab_size: vocab,
50 embed_dim: hidden,
51 },
52 vec![input_ids, embed_weight],
53 TensorInfo {
54 id: tid,
55 name: "embed_output".into(),
56 shape: Shape::new(vec![1, 0, hidden]),
57 dtype,
58 },
59 );
60
61 for layer_idx in 0..config.num_layers {
63 let prefix = format!("model.layers.{layer_idx}");
64 current = build_llama_layer(&mut graph, config, &prefix, current)?;
65 }
66
67 let final_norm_w = graph.load_weight("model.norm.weight", Shape::new(vec![hidden]), dtype);
69 let tid = graph.alloc_tensor_id();
70 let normed = graph.add_node(
71 Op::RMSNorm {
72 eps: config.rms_norm_eps,
73 },
74 vec![current, final_norm_w],
75 TensorInfo {
76 id: tid,
77 name: "final_norm".into(),
78 shape: Shape::new(vec![1, 0, hidden]),
79 dtype,
80 },
81 );
82
83 let lm_head_weight =
85 graph.load_weight("lm_head.weight", Shape::new(vec![vocab, hidden]), dtype);
86 let tid = graph.alloc_tensor_id();
87 let _logits = graph.add_node(
88 Op::LogitsProjection { vocab_size: vocab },
89 vec![normed, lm_head_weight],
90 TensorInfo {
91 id: tid,
92 name: "logits".into(),
93 shape: Shape::new(vec![1, 0, vocab]),
94 dtype: DType::F32, },
96 );
97
98 graph.validate().map_err(GraphBuildError::Validation)?;
99 Ok(graph)
100}
101
102fn build_llama_layer(
106 graph: &mut Graph,
107 config: &ModelConfig,
108 prefix: &str,
109 input: NodeId,
110) -> Result<NodeId, GraphBuildError> {
111 let hidden = config.hidden_size;
112 let intermediate = config.intermediate_size;
113 let num_heads = config.num_attention_heads;
114 let num_kv_heads = config.num_kv_heads;
115 let head_dim = config.head_dim;
116 let dtype = config.dtype;
117
118 let attn_norm_w = graph.load_weight(
122 format!("{prefix}.input_layernorm.weight"),
123 Shape::new(vec![hidden]),
124 dtype,
125 );
126 let tid = graph.alloc_tensor_id();
127 let normed = graph.add_node(
128 Op::RMSNorm {
129 eps: config.rms_norm_eps,
130 },
131 vec![input, attn_norm_w],
132 TensorInfo {
133 id: tid,
134 name: format!("{prefix}.attn_norm"),
135 shape: Shape::new(vec![1, 0, hidden]),
136 dtype,
137 },
138 );
139
140 let q_weight = graph.load_weight(
142 format!("{prefix}.self_attn.q_proj.weight"),
143 Shape::new(vec![num_heads * head_dim, hidden]),
144 dtype,
145 );
146 let tid = graph.alloc_tensor_id();
147 let q = graph.add_node(
148 Op::MatMul,
149 vec![normed, q_weight],
150 TensorInfo {
151 id: tid,
152 name: format!("{prefix}.q_proj"),
153 shape: Shape::new(vec![1, 0, num_heads * head_dim]),
154 dtype,
155 },
156 );
157
158 let k_weight = graph.load_weight(
159 format!("{prefix}.self_attn.k_proj.weight"),
160 Shape::new(vec![num_kv_heads * head_dim, hidden]),
161 dtype,
162 );
163 let tid = graph.alloc_tensor_id();
164 let k = graph.add_node(
165 Op::MatMul,
166 vec![normed, k_weight],
167 TensorInfo {
168 id: tid,
169 name: format!("{prefix}.k_proj"),
170 shape: Shape::new(vec![1, 0, num_kv_heads * head_dim]),
171 dtype,
172 },
173 );
174
175 let v_weight = graph.load_weight(
176 format!("{prefix}.self_attn.v_proj.weight"),
177 Shape::new(vec![num_kv_heads * head_dim, hidden]),
178 dtype,
179 );
180 let tid = graph.alloc_tensor_id();
181 let v = graph.add_node(
182 Op::MatMul,
183 vec![normed, v_weight],
184 TensorInfo {
185 id: tid,
186 name: format!("{prefix}.v_proj"),
187 shape: Shape::new(vec![1, 0, num_kv_heads * head_dim]),
188 dtype,
189 },
190 );
191
192 let tid = graph.alloc_tensor_id();
194 let q_rope = graph.add_node(
195 Op::RoPE {
196 max_seq_len: config.max_seq_len,
197 rope_theta: config.rope_theta,
198 head_dim,
199 },
200 vec![q],
201 TensorInfo {
202 id: tid,
203 name: format!("{prefix}.q_rope"),
204 shape: Shape::new(vec![1, 0, num_heads * head_dim]),
205 dtype,
206 },
207 );
208
209 let tid = graph.alloc_tensor_id();
210 let k_rope = graph.add_node(
211 Op::RoPE {
212 max_seq_len: config.max_seq_len,
213 rope_theta: config.rope_theta,
214 head_dim,
215 },
216 vec![k],
217 TensorInfo {
218 id: tid,
219 name: format!("{prefix}.k_rope"),
220 shape: Shape::new(vec![1, 0, num_kv_heads * head_dim]),
221 dtype,
222 },
223 );
224
225 let tid = graph.alloc_tensor_id();
227 let attn_out = graph.add_node(
228 Op::Attention {
229 num_heads,
230 num_kv_heads,
231 head_dim,
232 },
233 vec![q_rope, k_rope, v],
234 TensorInfo {
235 id: tid,
236 name: format!("{prefix}.attn"),
237 shape: Shape::new(vec![1, 0, num_heads * head_dim]),
238 dtype,
239 },
240 );
241
242 let o_weight = graph.load_weight(
244 format!("{prefix}.self_attn.o_proj.weight"),
245 Shape::new(vec![hidden, num_heads * head_dim]),
246 dtype,
247 );
248 let tid = graph.alloc_tensor_id();
249 let attn_proj = graph.add_node(
250 Op::MatMul,
251 vec![attn_out, o_weight],
252 TensorInfo {
253 id: tid,
254 name: format!("{prefix}.o_proj"),
255 shape: Shape::new(vec![1, 0, hidden]),
256 dtype,
257 },
258 );
259
260 let tid = graph.alloc_tensor_id();
262 let after_attn = graph.add_node(
263 Op::Residual,
264 vec![input, attn_proj],
265 TensorInfo {
266 id: tid,
267 name: format!("{prefix}.attn_residual"),
268 shape: Shape::new(vec![1, 0, hidden]),
269 dtype,
270 },
271 );
272
273 let ffn_norm_w = graph.load_weight(
277 format!("{prefix}.post_attention_layernorm.weight"),
278 Shape::new(vec![hidden]),
279 dtype,
280 );
281 let tid = graph.alloc_tensor_id();
282 let ffn_normed = graph.add_node(
283 Op::RMSNorm {
284 eps: config.rms_norm_eps,
285 },
286 vec![after_attn, ffn_norm_w],
287 TensorInfo {
288 id: tid,
289 name: format!("{prefix}.ffn_norm"),
290 shape: Shape::new(vec![1, 0, hidden]),
291 dtype,
292 },
293 );
294
295 let gate_weight = graph.load_weight(
297 format!("{prefix}.mlp.gate_proj.weight"),
298 Shape::new(vec![intermediate, hidden]),
299 dtype,
300 );
301 let tid = graph.alloc_tensor_id();
302 let gate = graph.add_node(
303 Op::MatMul,
304 vec![ffn_normed, gate_weight],
305 TensorInfo {
306 id: tid,
307 name: format!("{prefix}.gate_proj"),
308 shape: Shape::new(vec![1, 0, intermediate]),
309 dtype,
310 },
311 );
312
313 let tid = graph.alloc_tensor_id();
315 let gate_act = graph.add_node(
316 Op::SiLU,
317 vec![gate],
318 TensorInfo {
319 id: tid,
320 name: format!("{prefix}.gate_silu"),
321 shape: Shape::new(vec![1, 0, intermediate]),
322 dtype,
323 },
324 );
325
326 let up_weight = graph.load_weight(
328 format!("{prefix}.mlp.up_proj.weight"),
329 Shape::new(vec![intermediate, hidden]),
330 dtype,
331 );
332 let tid = graph.alloc_tensor_id();
333 let up = graph.add_node(
334 Op::MatMul,
335 vec![ffn_normed, up_weight],
336 TensorInfo {
337 id: tid,
338 name: format!("{prefix}.up_proj"),
339 shape: Shape::new(vec![1, 0, intermediate]),
340 dtype,
341 },
342 );
343
344 let tid = graph.alloc_tensor_id();
346 let ffn_hidden = graph.add_node(
347 Op::Mul,
348 vec![gate_act, up],
349 TensorInfo {
350 id: tid,
351 name: format!("{prefix}.gate_up_mul"),
352 shape: Shape::new(vec![1, 0, intermediate]),
353 dtype,
354 },
355 );
356
357 let down_weight = graph.load_weight(
359 format!("{prefix}.mlp.down_proj.weight"),
360 Shape::new(vec![hidden, intermediate]),
361 dtype,
362 );
363 let tid = graph.alloc_tensor_id();
364 let ffn_out = graph.add_node(
365 Op::MatMul,
366 vec![ffn_hidden, down_weight],
367 TensorInfo {
368 id: tid,
369 name: format!("{prefix}.down_proj"),
370 shape: Shape::new(vec![1, 0, hidden]),
371 dtype,
372 },
373 );
374
375 let tid = graph.alloc_tensor_id();
377 let output = graph.add_node(
378 Op::Residual,
379 vec![after_attn, ffn_out],
380 TensorInfo {
381 id: tid,
382 name: format!("{prefix}.ffn_residual"),
383 shape: Shape::new(vec![1, 0, hidden]),
384 dtype,
385 },
386 );
387
388 Ok(output)
389}
390
391#[derive(Debug, thiserror::Error)]
393pub enum GraphBuildError {
394 #[error("unsupported architecture: {0}")]
395 UnsupportedArchitecture(String),
396
397 #[error("graph validation failed: {0}")]
398 Validation(#[from] GraphError),
399}
400
#[cfg(test)]
mod tests {
    use super::*;

    /// Llama configuration with 16 layers, 32 query heads and 8 KV heads.
    fn llama_1b_config() -> ModelConfig {
        ModelConfig {
            architecture: Architecture::Llama,
            hidden_size: 2048,
            intermediate_size: 5632,
            num_layers: 16,
            num_attention_heads: 32,
            num_kv_heads: 8,
            head_dim: 64,
            vocab_size: 32_000,
            max_seq_len: 2048,
            rms_norm_eps: 1e-5,
            rope_theta: 10_000.0,
            dtype: DType::F16,
        }
    }

    /// Llama configuration with 30 layers, 9 query heads and 3 KV heads, whose
    /// `num_attention_heads * head_dim` equals `hidden_size` (576).
    fn smollm_135m_config() -> ModelConfig {
        ModelConfig {
            architecture: Architecture::Llama,
            hidden_size: 576,
            intermediate_size: 1536,
            num_layers: 30,
            num_attention_heads: 9,
            num_kv_heads: 3,
            head_dim: 64,
            vocab_size: 49_152,
            max_seq_len: 2048,
            rms_norm_eps: 1e-5,
            rope_theta: 10_000.0,
            dtype: DType::BF16,
        }
    }

    /// Minimal single-layer configuration for the given architecture.
    fn tiny_config(architecture: Architecture) -> ModelConfig {
        ModelConfig {
            architecture,
            hidden_size: 64,
            intermediate_size: 128,
            num_layers: 1,
            num_attention_heads: 4,
            num_kv_heads: 2,
            head_dim: 16,
            vocab_size: 256,
            max_seq_len: 64,
            rms_norm_eps: 1e-5,
            rope_theta: 10_000.0,
            dtype: DType::F16,
        }
    }

    /// Reports whether `graph` has a weight registered under `name`.
    fn has_weight(graph: &Graph, name: &str) -> bool {
        graph.weights.contains_key(name)
    }

    #[test]
    fn build_llama_1b_graph() {
        let graph = build_graph(&llama_1b_config()).unwrap();

        assert!(!graph.is_empty());
        assert!(graph.config.is_some());
        assert!(graph.validate().is_ok());

        // Global weights plus a sample from the first and the last layer.
        for name in [
            "model.embed_tokens.weight",
            "model.norm.weight",
            "lm_head.weight",
            "model.layers.0.input_layernorm.weight",
            "model.layers.0.self_attn.q_proj.weight",
            "model.layers.0.mlp.gate_proj.weight",
            "model.layers.15.mlp.down_proj.weight",
        ] {
            assert!(has_weight(&graph, name), "missing weight {name}");
        }
    }

    #[test]
    fn build_smollm_135m_graph() {
        let graph = build_graph(&smollm_135m_config()).unwrap();

        assert!(graph.validate().is_ok());
        assert!(has_weight(&graph, "model.layers.29.mlp.down_proj.weight"));

        let embed = &graph.weights["model.embed_tokens.weight"];
        assert_eq!(embed.shape, Shape::new(vec![49152, 576]));

        // 9 heads * 64 head_dim = 576 rows, hidden_size = 576 columns.
        let q_proj = &graph.weights["model.layers.0.self_attn.q_proj.weight"];
        assert_eq!(q_proj.shape, Shape::new(vec![576, 576]));
    }

    #[test]
    fn graph_node_count() {
        let graph = build_graph(&llama_1b_config()).unwrap();

        assert!(graph.len() > 100);
    }

    #[test]
    fn graph_has_correct_output() {
        let graph = build_graph(&llama_1b_config()).unwrap();

        // The logits projection is the last node the builder appends.
        let last_index = graph.len() - 1;
        let tail = graph.node(last_index);
        assert!(matches!(tail.op, Op::LogitsProjection { .. }));
        assert_eq!(tail.output.dtype, DType::F32);
    }

    #[test]
    fn qwen2_uses_llama_builder() {
        let qwen2 = ModelConfig {
            architecture: Architecture::Qwen2,
            hidden_size: 1536,
            intermediate_size: 8960,
            num_layers: 28,
            num_attention_heads: 12,
            num_kv_heads: 2,
            head_dim: 128,
            vocab_size: 151_936,
            max_seq_len: 32_768,
            rms_norm_eps: 1e-6,
            rope_theta: 1_000_000.0,
            dtype: DType::BF16,
        };

        let graph = build_graph(&qwen2).unwrap();
        assert!(graph.validate().is_ok());
        assert!(has_weight(&graph, "model.layers.27.mlp.down_proj.weight"));
    }

    #[test]
    fn all_architectures_supported() {
        let architectures = [
            Architecture::Llama,
            Architecture::Qwen2,
            Architecture::Mistral,
            Architecture::Phi3,
            Architecture::Gemma,
            Architecture::StableLM,
        ];

        for arch in architectures {
            let config = tiny_config(arch.clone());
            let result = build_graph(&config);
            assert!(result.is_ok(), "failed to build graph for {arch}");
        }
    }

    #[test]
    fn topological_order_is_valid() {
        let graph = build_graph(&smollm_135m_config()).unwrap();

        // Every node may only consume nodes that were appended before it.
        for node in &graph.nodes {
            for &dependency in &node.inputs {
                assert!(
                    dependency < node.id,
                    "node {} references future node {}",
                    node.id,
                    dependency
                );
            }
        }
    }
}