1use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::fmt;
10
/// Identifier of a node in a [`Graph`].
///
/// [`Graph::add_node`] hands these out from a counter starting at 0, and
/// [`Graph::node`] uses the id directly as an index into [`Graph::nodes`].
pub type NodeId = usize;

/// Identifier of a tensor, handed out by [`Graph::alloc_tensor_id`] from a
/// counter starting at 0.
pub type TensorId = usize;
16
17#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
19pub enum DType {
20 F32,
21 F16,
22 BF16,
23 F8E4M3,
25 F8E5M2,
27 Q8_0,
29 Q4_0,
31 Q4_1,
33 Q2,
35 NF4,
37 I32,
38 I64,
39}
40
41impl DType {
42 pub fn size_bytes(&self) -> usize {
44 match self {
45 DType::F32 | DType::I32 => 4,
46 DType::F16 | DType::BF16 => 2,
47 DType::F8E4M3 | DType::F8E5M2 | DType::Q8_0 => 1,
48 DType::I64 => 8,
49 DType::Q4_0 | DType::Q4_1 | DType::NF4 => 1,
51 DType::Q2 => 1,
52 }
53 }
54
55 pub fn is_quantized(&self) -> bool {
56 matches!(
57 self,
58 DType::Q8_0 | DType::Q4_0 | DType::Q4_1 | DType::Q2 | DType::NF4
59 )
60 }
61
62 pub fn is_float(&self) -> bool {
63 matches!(
64 self,
65 DType::F32 | DType::F16 | DType::BF16 | DType::F8E4M3 | DType::F8E5M2
66 )
67 }
68}
69
70impl fmt::Display for DType {
71 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
72 match self {
73 DType::F32 => write!(f, "f32"),
74 DType::F16 => write!(f, "f16"),
75 DType::BF16 => write!(f, "bf16"),
76 DType::F8E4M3 => write!(f, "f8e4m3"),
77 DType::F8E5M2 => write!(f, "f8e5m2"),
78 DType::Q8_0 => write!(f, "q8_0"),
79 DType::Q4_0 => write!(f, "q4_0"),
80 DType::Q4_1 => write!(f, "q4_1"),
81 DType::Q2 => write!(f, "q2"),
82 DType::NF4 => write!(f, "nf4"),
83 DType::I32 => write!(f, "i32"),
84 DType::I64 => write!(f, "i64"),
85 }
86 }
87}
88
89#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
91pub struct Shape(pub Vec<usize>);
92
93impl Shape {
94 pub fn new(dims: Vec<usize>) -> Self {
95 Self(dims)
96 }
97
98 pub fn scalar() -> Self {
99 Self(vec![])
100 }
101
102 pub fn ndim(&self) -> usize {
103 self.0.len()
104 }
105
106 pub fn numel(&self) -> usize {
107 self.0.iter().product()
108 }
109
110 pub fn dim(&self, i: usize) -> usize {
112 self.0[i]
113 }
114}
115
116impl fmt::Display for Shape {
117 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
118 write!(f, "[")?;
119 for (i, d) in self.0.iter().enumerate() {
120 if i > 0 {
121 write!(f, ", ")?;
122 }
123 write!(f, "{d}")?;
124 }
125 write!(f, "]")
126 }
127}
128
129#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
131pub struct TensorInfo {
132 pub id: TensorId,
133 pub name: String,
134 pub shape: Shape,
135 pub dtype: DType,
136}
137
138#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
140pub enum Op {
141 LoadWeight {
144 name: String,
145 },
146
147 Input {
149 name: String,
150 },
151
152 MatMul,
155
156 BatchMatMul,
158
159 Add,
161 Mul,
162 SiLU,
164 GeLU,
166 ReLU,
168
169 RMSNorm {
172 eps: f32,
173 },
174
175 LayerNorm {
177 eps: f32,
178 },
179
180 RoPE {
183 max_seq_len: usize,
185 rope_theta: f32,
187 head_dim: usize,
189 },
190
191 Attention {
194 num_heads: usize,
195 num_kv_heads: usize,
196 head_dim: usize,
197 },
198
199 Softmax,
201
202 Reshape {
205 shape: Shape,
206 },
207
208 Transpose {
210 dim0: usize,
211 dim1: usize,
212 },
213
214 Contiguous,
216
217 Embedding {
220 vocab_size: usize,
221 embed_dim: usize,
222 },
223
224 LogitsProjection {
227 vocab_size: usize,
228 },
229
230 Residual,
233
234 Cast {
237 to: DType,
238 },
239}
240
241impl fmt::Display for Op {
242 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
243 match self {
244 Op::LoadWeight { name } => write!(f, "LoadWeight({name})"),
245 Op::Input { name } => write!(f, "Input({name})"),
246 Op::MatMul => write!(f, "MatMul"),
247 Op::BatchMatMul => write!(f, "BatchMatMul"),
248 Op::Add => write!(f, "Add"),
249 Op::Mul => write!(f, "Mul"),
250 Op::SiLU => write!(f, "SiLU"),
251 Op::GeLU => write!(f, "GeLU"),
252 Op::ReLU => write!(f, "ReLU"),
253 Op::RMSNorm { eps } => write!(f, "RMSNorm(eps={eps})"),
254 Op::LayerNorm { eps } => write!(f, "LayerNorm(eps={eps})"),
255 Op::RoPE { head_dim, .. } => write!(f, "RoPE(head_dim={head_dim})"),
256 Op::Attention {
257 num_heads,
258 num_kv_heads,
259 head_dim,
260 } => write!(f, "Attention(h={num_heads},kv={num_kv_heads},d={head_dim})"),
261 Op::Softmax => write!(f, "Softmax"),
262 Op::Reshape { shape } => write!(f, "Reshape({shape})"),
263 Op::Transpose { dim0, dim1 } => write!(f, "Transpose({dim0},{dim1})"),
264 Op::Contiguous => write!(f, "Contiguous"),
265 Op::Embedding {
266 vocab_size,
267 embed_dim,
268 } => write!(f, "Embedding(v={vocab_size},d={embed_dim})"),
269 Op::LogitsProjection { vocab_size } => {
270 write!(f, "LogitsProjection(v={vocab_size})")
271 }
272 Op::Residual => write!(f, "Residual"),
273 Op::Cast { to } => write!(f, "Cast({to})"),
274 }
275 }
276}
277
278#[derive(Debug, Clone, Serialize, Deserialize)]
280pub struct Node {
281 pub id: NodeId,
282 pub op: Op,
284 pub inputs: Vec<NodeId>,
286 pub output: TensorInfo,
288}
289
290#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
292pub enum Architecture {
293 Llama,
294 Qwen2,
295 Mistral,
296 Phi3,
297 Gemma,
298 StableLM,
299}
300
301impl fmt::Display for Architecture {
302 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
303 match self {
304 Architecture::Llama => write!(f, "Llama"),
305 Architecture::Qwen2 => write!(f, "Qwen2"),
306 Architecture::Mistral => write!(f, "Mistral"),
307 Architecture::Phi3 => write!(f, "Phi3"),
308 Architecture::Gemma => write!(f, "Gemma"),
309 Architecture::StableLM => write!(f, "StableLM"),
310 }
311 }
312}
313
314#[derive(Debug, Clone, Serialize, Deserialize)]
316pub struct ModelConfig {
317 pub architecture: Architecture,
318 pub hidden_size: usize,
319 pub intermediate_size: usize,
320 pub num_layers: usize,
321 pub num_attention_heads: usize,
322 pub num_kv_heads: usize,
323 pub head_dim: usize,
324 pub vocab_size: usize,
325 pub max_seq_len: usize,
326 pub rms_norm_eps: f32,
327 pub rope_theta: f32,
328 pub dtype: DType,
329}
330
331#[derive(Debug, Clone, Serialize, Deserialize)]
333pub struct Graph {
334 pub name: String,
335 pub config: Option<ModelConfig>,
336 pub nodes: Vec<Node>,
337 pub weights: HashMap<String, TensorInfo>,
339 next_node_id: NodeId,
340 next_tensor_id: TensorId,
341}
342
343impl Graph {
344 pub fn new(name: impl Into<String>) -> Self {
345 Self {
346 name: name.into(),
347 config: None,
348 nodes: Vec::new(),
349 weights: HashMap::new(),
350 next_node_id: 0,
351 next_tensor_id: 0,
352 }
353 }
354
355 pub fn with_config(mut self, config: ModelConfig) -> Self {
356 self.config = Some(config);
357 self
358 }
359
360 pub fn add_node(&mut self, op: Op, inputs: Vec<NodeId>, output: TensorInfo) -> NodeId {
362 let id = self.next_node_id;
363 self.next_node_id += 1;
364 self.nodes.push(Node {
365 id,
366 op,
367 inputs,
368 output,
369 });
370 id
371 }
372
373 pub fn alloc_tensor_id(&mut self) -> TensorId {
375 let id = self.next_tensor_id;
376 self.next_tensor_id += 1;
377 id
378 }
379
380 pub fn register_weight(&mut self, name: String, shape: Shape, dtype: DType) -> TensorId {
382 let id = self.alloc_tensor_id();
383 let info = TensorInfo {
384 id,
385 name: name.clone(),
386 shape,
387 dtype,
388 };
389 self.weights.insert(name, info);
390 id
391 }
392
393 pub fn load_weight(&mut self, name: impl Into<String>, shape: Shape, dtype: DType) -> NodeId {
395 let name = name.into();
396 let tensor_id = self.alloc_tensor_id();
397 let output = TensorInfo {
398 id: tensor_id,
399 name: name.clone(),
400 shape,
401 dtype,
402 };
403 self.register_weight(name.clone(), output.shape.clone(), output.dtype);
404 self.add_node(Op::LoadWeight { name }, vec![], output)
405 }
406
407 pub fn input(&mut self, name: impl Into<String>, shape: Shape, dtype: DType) -> NodeId {
409 let name = name.into();
410 let tensor_id = self.alloc_tensor_id();
411 let output = TensorInfo {
412 id: tensor_id,
413 name: name.clone(),
414 shape,
415 dtype,
416 };
417 self.add_node(Op::Input { name }, vec![], output)
418 }
419
420 pub fn node(&self, id: NodeId) -> &Node {
422 &self.nodes[id]
423 }
424
425 pub fn output_info(&self, id: NodeId) -> &TensorInfo {
427 &self.nodes[id].output
428 }
429
430 pub fn len(&self) -> usize {
432 self.nodes.len()
433 }
434
435 pub fn is_empty(&self) -> bool {
437 self.nodes.is_empty()
438 }
439
440 pub fn topological_order(&self) -> Vec<NodeId> {
443 (0..self.nodes.len()).collect()
444 }
445
446 pub fn validate(&self) -> Result<(), GraphError> {
449 for node in &self.nodes {
450 for &input_id in &node.inputs {
451 if input_id >= node.id {
452 return Err(GraphError::ForwardReference {
453 node: node.id,
454 input: input_id,
455 });
456 }
457 if input_id >= self.nodes.len() {
458 return Err(GraphError::InvalidNodeReference {
459 node: node.id,
460 input: input_id,
461 });
462 }
463 }
464 }
465 Ok(())
466 }
467}
468
469#[derive(Debug, Clone, thiserror::Error)]
471pub enum GraphError {
472 #[error("node {node} references future node {input} (forward reference)")]
473 ForwardReference { node: NodeId, input: NodeId },
474
475 #[error("node {node} references non-existent node {input}")]
476 InvalidNodeReference { node: NodeId, input: NodeId },
477}
478
#[cfg(test)]
mod tests {
    use super::*;

    /// Llama-style configuration shared by the config tests.
    fn llama_config() -> ModelConfig {
        ModelConfig {
            architecture: Architecture::Llama,
            hidden_size: 2048,
            intermediate_size: 5632,
            num_layers: 16,
            num_attention_heads: 32,
            num_kv_heads: 8,
            head_dim: 64,
            vocab_size: 32000,
            max_seq_len: 2048,
            rms_norm_eps: 1e-5,
            rope_theta: 10000.0,
            dtype: DType::F16,
        }
    }

    /// Builds node output metadata, allocating a fresh tensor id from `graph`.
    fn tensor_info(graph: &mut Graph, name: &str, dims: Vec<usize>, dtype: DType) -> TensorInfo {
        TensorInfo {
            id: graph.alloc_tensor_id(),
            name: name.into(),
            shape: Shape::new(dims),
            dtype,
        }
    }

    #[test]
    fn create_empty_graph() {
        let graph = Graph::new("test");
        assert_eq!(graph.name, "test");
        assert!(graph.is_empty());
        assert_eq!(graph.len(), 0);
        assert!(graph.validate().is_ok());
    }

    #[test]
    fn add_nodes_and_validate() {
        let mut graph = Graph::new("test_model");

        let input = graph.input("tokens", Shape::new(vec![1, 128]), DType::I32);
        let embed_w = graph.load_weight(
            "model.embed_tokens.weight",
            Shape::new(vec![32000, 2048]),
            DType::F16,
        );

        let out = tensor_info(&mut graph, "embed_out", vec![1, 128, 2048], DType::F16);
        let embed = graph.add_node(
            Op::Embedding {
                vocab_size: 32000,
                embed_dim: 2048,
            },
            vec![input, embed_w],
            out,
        );

        assert_eq!(graph.len(), 3);
        assert_eq!(graph.node(embed).inputs, vec![input, embed_w]);
        assert!(graph.validate().is_ok());
    }

    #[test]
    fn validate_detects_forward_reference() {
        let mut graph = Graph::new("bad");
        let first = tensor_info(&mut graph, "first", vec![1], DType::F32);
        let second = tensor_info(&mut graph, "second", vec![1], DType::F32);
        // Push directly, because `add_node` cannot create a node that names a later one.
        // Node 1 exists, so this is a genuine forward reference, not a dangling id.
        graph.nodes.push(Node {
            id: 0,
            op: Op::Add,
            inputs: vec![1],
            output: first,
        });
        graph.nodes.push(Node {
            id: 1,
            op: Op::Add,
            inputs: vec![],
            output: second,
        });
        graph.next_node_id = 2;

        assert!(matches!(
            graph.validate(),
            Err(GraphError::ForwardReference { node: 0, input: 1 })
        ));
    }

    #[test]
    fn validate_rejects_dangling_reference() {
        let mut graph = Graph::new("bad");
        let output = tensor_info(&mut graph, "bad", vec![1], DType::F32);
        // Node 0 names node 1, which was never added.
        graph.nodes.push(Node {
            id: 0,
            op: Op::Add,
            inputs: vec![1],
            output,
        });
        graph.next_node_id = 1;

        assert!(graph.validate().is_err());
    }

    #[test]
    fn shape_operations() {
        let s = Shape::new(vec![2, 3, 4]);
        assert_eq!(s.ndim(), 3);
        assert_eq!(s.numel(), 24);
        assert_eq!(s.dim(1), 3);
        assert_eq!(s.to_string(), "[2, 3, 4]");

        let scalar = Shape::scalar();
        assert_eq!(scalar.ndim(), 0);
        assert_eq!(scalar.numel(), 1);
        assert_eq!(scalar.to_string(), "[]");
    }

    #[test]
    fn dtype_properties() {
        assert!(DType::Q4_0.is_quantized());
        assert!(!DType::F32.is_quantized());
        assert!(DType::F16.is_float());
        assert!(!DType::I32.is_float());
        assert!(!DType::Q8_0.is_float());
        assert_eq!(DType::F32.size_bytes(), 4);
        assert_eq!(DType::I64.size_bytes(), 8);
        assert_eq!(DType::BF16.to_string(), "bf16");
    }

    #[test]
    fn op_display() {
        assert_eq!(Op::MatMul.to_string(), "MatMul");
        assert_eq!(
            Op::Attention {
                num_heads: 32,
                num_kv_heads: 8,
                head_dim: 64,
            }
            .to_string(),
            "Attention(h=32,kv=8,d=64)"
        );
        assert_eq!(Op::Cast { to: DType::BF16 }.to_string(), "Cast(bf16)");
        assert_eq!(
            Op::Reshape {
                shape: Shape::new(vec![2, 3]),
            }
            .to_string(),
            "Reshape([2, 3])"
        );
    }

    #[test]
    fn topological_order() {
        let mut graph = Graph::new("topo");
        let a = graph.input("a", Shape::new(vec![4]), DType::F32);
        let b = graph.input("b", Shape::new(vec![4]), DType::F32);
        let out = tensor_info(&mut graph, "c", vec![4], DType::F32);
        let _c = graph.add_node(Op::Add, vec![a, b], out);
        assert_eq!(graph.topological_order(), vec![0, 1, 2]);
    }

    #[test]
    fn weight_registration() {
        let mut graph = Graph::new("weights");
        graph.register_weight(
            "layer.0.attention.wq.weight".into(),
            Shape::new(vec![2048, 2048]),
            DType::F16,
        );
        assert!(graph.weights.contains_key("layer.0.attention.wq.weight"));
        let info = &graph.weights["layer.0.attention.wq.weight"];
        assert_eq!(info.shape, Shape::new(vec![2048, 2048]));
        assert_eq!(info.dtype, DType::F16);
    }

    #[test]
    fn load_weight_registers_weight() {
        let mut graph = Graph::new("weights");
        let w = graph.load_weight("w", Shape::new(vec![4, 4]), DType::F16);
        let info = &graph.weights["w"];
        assert_eq!(info.shape, Shape::new(vec![4, 4]));
        assert_eq!(info.dtype, DType::F16);
        assert_eq!(graph.node(w).op, Op::LoadWeight { name: "w".into() });
    }

    #[test]
    fn model_config_roundtrip() {
        let json = serde_json::to_string(&llama_config()).unwrap();
        let deserialized: ModelConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.architecture, Architecture::Llama);
        assert_eq!(deserialized.hidden_size, 2048);
        assert_eq!(deserialized.num_kv_heads, 8);
        assert_eq!(deserialized.num_layers, 16);
        assert_eq!(deserialized.dtype, DType::F16);
    }

    #[test]
    fn graph_with_config() {
        let graph = Graph::new("llama-1b").with_config(llama_config());
        let cfg = graph.config.expect("with_config should set the config");
        assert_eq!(cfg.architecture, Architecture::Llama);
    }

    #[test]
    fn build_transformer_layer_fragment() {
        let mut graph = Graph::new("layer_test");
        let hidden = 2048;

        let input = graph.input(
            "hidden_states",
            Shape::new(vec![1, 128, hidden]),
            DType::F16,
        );

        let norm_w = graph.load_weight(
            "model.layers.0.input_layernorm.weight",
            Shape::new(vec![hidden]),
            DType::F16,
        );

        let normed_out = tensor_info(&mut graph, "normed", vec![1, 128, hidden], DType::F16);
        let normed = graph.add_node(Op::RMSNorm { eps: 1e-5 }, vec![input, norm_w], normed_out);

        let q_weight = graph.load_weight(
            "model.layers.0.self_attn.q_proj.weight",
            Shape::new(vec![hidden, hidden]),
            DType::F16,
        );

        let q_out = tensor_info(&mut graph, "q_proj", vec![1, 128, hidden], DType::F16);
        let q_proj = graph.add_node(Op::MatMul, vec![normed, q_weight], q_out);

        assert_eq!(graph.len(), 5);
        assert_eq!(graph.node(q_proj).inputs, vec![normed, q_weight]);
        assert!(graph.validate().is_ok());
    }
}