Skip to main content

forgellm_frontend/
ir.rs

1//! Intermediate representation for transformer computation graphs.
2//!
3//! The IR is the central abstraction in ForgeLLM: all frontends produce it,
4//! all backends consume it. It represents a static computation graph with
5//! typed tensor operations specialized for transformer architectures.
6
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::fmt;
10
11/// Unique identifier for a node in the computation graph.
///
/// `Graph::add_node` hands these out sequentially starting at 0, and
/// `Graph::node` / `Graph::output_info` use the value directly as an index
/// into `Graph::nodes`.
12pub type NodeId = usize;
13
14/// Unique identifier for a tensor.
///
/// Allocated sequentially, starting at 0, by `Graph::alloc_tensor_id`.
15pub type TensorId = usize;
16
17/// Scalar data type for tensor elements.
///
/// `DType::is_float` and `DType::is_quantized` classify the variants; the
/// integer types (`I32`, `I64`) belong to neither class. The `Display` impl
/// prints the lowercase short name (e.g. `f16`, `q4_0`).
18#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
19pub enum DType {
    /// 32-bit float.
20    F32,
    /// 16-bit float.
21    F16,
    /// 16-bit float in the `bf16` format.
22    BF16,
23    /// 8-bit float (E4M3)
24    F8E4M3,
25    /// 8-bit float (E5M2)
26    F8E5M2,
27    /// 8-bit quantized (block-wise, with scale factors)
28    Q8_0,
29    /// 4-bit quantized (block-wise, with scale factors)
30    Q4_0,
31    /// 4-bit quantized (with scale + min)
32    Q4_1,
33    /// 2-bit quantized
34    Q2,
35    /// 4-bit NormalFloat (for QLoRA)
36    NF4,
    /// 32-bit integer.
37    I32,
    /// 64-bit integer.
38    I64,
39}
40
41impl DType {
42    /// Size in bytes for non-quantized types, or effective bits per element for quantized.
43    pub fn size_bytes(&self) -> usize {
44        match self {
45            DType::F32 | DType::I32 => 4,
46            DType::F16 | DType::BF16 => 2,
47            DType::F8E4M3 | DType::F8E5M2 | DType::Q8_0 => 1,
48            DType::I64 => 8,
49            // Quantized types: return 1 as a placeholder; actual size depends on block size
50            DType::Q4_0 | DType::Q4_1 | DType::NF4 => 1,
51            DType::Q2 => 1,
52        }
53    }
54
55    pub fn is_quantized(&self) -> bool {
56        matches!(
57            self,
58            DType::Q8_0 | DType::Q4_0 | DType::Q4_1 | DType::Q2 | DType::NF4
59        )
60    }
61
62    pub fn is_float(&self) -> bool {
63        matches!(
64            self,
65            DType::F32 | DType::F16 | DType::BF16 | DType::F8E4M3 | DType::F8E5M2
66        )
67    }
68}
69
70impl fmt::Display for DType {
71    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
72        match self {
73            DType::F32 => write!(f, "f32"),
74            DType::F16 => write!(f, "f16"),
75            DType::BF16 => write!(f, "bf16"),
76            DType::F8E4M3 => write!(f, "f8e4m3"),
77            DType::F8E5M2 => write!(f, "f8e5m2"),
78            DType::Q8_0 => write!(f, "q8_0"),
79            DType::Q4_0 => write!(f, "q4_0"),
80            DType::Q4_1 => write!(f, "q4_1"),
81            DType::Q2 => write!(f, "q2"),
82            DType::NF4 => write!(f, "nf4"),
83            DType::I32 => write!(f, "i32"),
84            DType::I64 => write!(f, "i64"),
85        }
86    }
87}
88
89/// Shape of a tensor — a list of dimension sizes.
///
/// An empty list is used for scalars (see `Shape::scalar`). The inner vector
/// is public, so callers may read or build it directly.
90#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
91pub struct Shape(pub Vec<usize>);
92
93impl Shape {
94    pub fn new(dims: Vec<usize>) -> Self {
95        Self(dims)
96    }
97
98    pub fn scalar() -> Self {
99        Self(vec![])
100    }
101
102    pub fn ndim(&self) -> usize {
103        self.0.len()
104    }
105
106    pub fn numel(&self) -> usize {
107        self.0.iter().product()
108    }
109
110    /// Returns the size of a specific dimension.
111    pub fn dim(&self, i: usize) -> usize {
112        self.0[i]
113    }
114}
115
116impl fmt::Display for Shape {
117    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
118        write!(f, "[")?;
119        for (i, d) in self.0.iter().enumerate() {
120            if i > 0 {
121                write!(f, ", ")?;
122            }
123            write!(f, "{d}")?;
124        }
125        write!(f, "]")
126    }
127}
128
129/// Metadata about a tensor (shape + dtype), without the actual data.
///
/// Used both as the `output` of every `Node` and as the value type of
/// `Graph::weights`.
130#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
131pub struct TensorInfo {
    /// Identifier of this tensor (see `Graph::alloc_tensor_id`).
132    pub id: TensorId,
    /// Human-readable name; for weights and inputs this is the name passed
    /// to `Graph::load_weight` / `Graph::input`.
133    pub name: String,
    /// Dimension sizes of the tensor.
134    pub shape: Shape,
    /// Element type of the tensor.
135    pub dtype: DType,
136}
137
138/// An operation in the computation graph.
///
/// Each `Node` carries exactly one `Op`; the operands are the node's
/// `inputs`, not fields of the variant. Variant fields hold only static
/// attributes of the operation. The `Display` impl renders a compact,
/// human-readable form of each variant.
139#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
140pub enum Op {
141    // --- Tensor creation ---
142    /// Load a constant weight tensor (by name/id).
143    LoadWeight {
        /// Name of the weight; `Graph::load_weight` registers the same
        /// name in `Graph::weights`.
144        name: String,
145    },
146
147    /// External input (e.g., token ids).
148    Input {
        /// Name of the input.
149        name: String,
150    },
151
152    // --- Core linear algebra ---
153    /// Matrix multiplication: (M, K) x (K, N) -> (M, N)
154    MatMul,
155
156    /// Batched matrix multiplication.
157    BatchMatMul,
158
159    // --- Elementwise operations ---
    /// Elementwise addition.
160    Add,
    /// Elementwise multiplication.
161    Mul,
162    /// Sigmoid Linear Unit: x * sigmoid(x)
163    SiLU,
164    /// Gaussian Error Linear Unit
165    GeLU,
166    /// Rectified Linear Unit
167    ReLU,
168
169    // --- Normalization ---
170    /// Root Mean Square Layer Normalization with epsilon.
171    RMSNorm {
        /// Epsilon used by the normalization.
172        eps: f32,
173    },
174
175    /// Layer Normalization with epsilon.
176    LayerNorm {
        /// Epsilon used by the normalization.
177        eps: f32,
178    },
179
180    // --- Attention-specific ---
181    /// Rotary Position Embedding.
182    RoPE {
183        /// Maximum sequence length.
184        max_seq_len: usize,
185        /// Base frequency (typically 10000.0 or 500000.0).
186        rope_theta: f32,
187        /// Head dimension.
188        head_dim: usize,
189    },
190
191    /// Scaled dot-product attention.
192    /// Inputs: Q, K, V, optional mask.
193    Attention {
        /// Number of attention heads.
194        num_heads: usize,
        /// Number of key/value heads.
        // NOTE(review): presumably smaller than `num_heads` for grouped-query
        // attention — nothing in this file enforces a relationship; confirm.
195        num_kv_heads: usize,
        /// Per-head dimension.
196        head_dim: usize,
197    },
198
199    /// Softmax along the last dimension.
200    Softmax,
201
202    // --- Shape operations ---
203    /// Reshape tensor to new shape.
204    Reshape {
        /// Target shape.
205        shape: Shape,
206    },
207
208    /// Transpose dimensions.
209    Transpose {
        /// First of the two dimension indices involved.
210        dim0: usize,
        /// Second of the two dimension indices involved.
211        dim1: usize,
212    },
213
214    /// Contiguous memory layout.
215    Contiguous,
216
217    // --- Embedding ---
218    /// Token embedding lookup.
219    Embedding {
        /// Vocabulary size.
220        vocab_size: usize,
        /// Embedding dimension.
221        embed_dim: usize,
222    },
223
224    // --- Output ---
225    /// Final logits projection (often tied to embedding weights).
226    LogitsProjection {
        /// Vocabulary size.
227        vocab_size: usize,
228    },
229
230    // --- Residual ---
231    /// Residual connection (just an Add, but semantically distinct for fusion).
232    Residual,
233
234    // --- Cast ---
235    /// Cast tensor to a different dtype.
236    Cast {
        /// Destination element type.
237        to: DType,
238    },
239}
240
241impl fmt::Display for Op {
242    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
243        match self {
244            Op::LoadWeight { name } => write!(f, "LoadWeight({name})"),
245            Op::Input { name } => write!(f, "Input({name})"),
246            Op::MatMul => write!(f, "MatMul"),
247            Op::BatchMatMul => write!(f, "BatchMatMul"),
248            Op::Add => write!(f, "Add"),
249            Op::Mul => write!(f, "Mul"),
250            Op::SiLU => write!(f, "SiLU"),
251            Op::GeLU => write!(f, "GeLU"),
252            Op::ReLU => write!(f, "ReLU"),
253            Op::RMSNorm { eps } => write!(f, "RMSNorm(eps={eps})"),
254            Op::LayerNorm { eps } => write!(f, "LayerNorm(eps={eps})"),
255            Op::RoPE { head_dim, .. } => write!(f, "RoPE(head_dim={head_dim})"),
256            Op::Attention {
257                num_heads,
258                num_kv_heads,
259                head_dim,
260            } => write!(f, "Attention(h={num_heads},kv={num_kv_heads},d={head_dim})"),
261            Op::Softmax => write!(f, "Softmax"),
262            Op::Reshape { shape } => write!(f, "Reshape({shape})"),
263            Op::Transpose { dim0, dim1 } => write!(f, "Transpose({dim0},{dim1})"),
264            Op::Contiguous => write!(f, "Contiguous"),
265            Op::Embedding {
266                vocab_size,
267                embed_dim,
268            } => write!(f, "Embedding(v={vocab_size},d={embed_dim})"),
269            Op::LogitsProjection { vocab_size } => {
270                write!(f, "LogitsProjection(v={vocab_size})")
271            }
272            Op::Residual => write!(f, "Residual"),
273            Op::Cast { to } => write!(f, "Cast({to})"),
274        }
275    }
276}
277
278/// A node in the computation graph.
///
/// Nodes are normally created through `Graph::add_node`, which assigns the
/// `id`; all fields are public, so a node can also be built by hand.
279#[derive(Debug, Clone, Serialize, Deserialize)]
280pub struct Node {
    /// Identifier of this node. `Graph::node` indexes `Graph::nodes` by this
    /// value, so it is expected to equal the node's position in that vector.
281    pub id: NodeId,
282    /// The operation this node performs.
283    pub op: Op,
284    /// Input node IDs (ordered).
285    pub inputs: Vec<NodeId>,
286    /// Output tensor info.
287    pub output: TensorInfo,
288}
289
290/// Model architecture type, used to select the right graph-building strategy.
///
/// The `Display` impl prints each variant under its own name (e.g. `Llama`).
291#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
292pub enum Architecture {
    /// Llama-family models.
293    Llama,
    /// Qwen2-family models.
294    Qwen2,
    /// Mistral-family models.
295    Mistral,
    /// Phi-3-family models.
296    Phi3,
    /// Gemma-family models.
297    Gemma,
    /// StableLM-family models.
298    StableLM,
299}
300
301impl fmt::Display for Architecture {
302    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
303        match self {
304            Architecture::Llama => write!(f, "Llama"),
305            Architecture::Qwen2 => write!(f, "Qwen2"),
306            Architecture::Mistral => write!(f, "Mistral"),
307            Architecture::Phi3 => write!(f, "Phi3"),
308            Architecture::Gemma => write!(f, "Gemma"),
309            Architecture::StableLM => write!(f, "StableLM"),
310        }
311    }
312}
313
314/// Configuration describing a transformer model's hyperparameters.
///
/// This is plain data: nothing in this file validates the fields or checks
/// them against each other. A graph can carry one via `Graph::with_config`.
315#[derive(Debug, Clone, Serialize, Deserialize)]
316pub struct ModelConfig {
    /// Which model family the hyperparameters describe.
317    pub architecture: Architecture,
    /// Hidden (model) dimension.
318    pub hidden_size: usize,
    /// Intermediate dimension.
    // NOTE(review): presumably the feed-forward inner width — confirm.
319    pub intermediate_size: usize,
    /// Number of layers.
320    pub num_layers: usize,
    /// Number of attention heads.
321    pub num_attention_heads: usize,
    /// Number of key/value heads.
322    pub num_kv_heads: usize,
    /// Per-head dimension.
323    pub head_dim: usize,
    /// Vocabulary size.
324    pub vocab_size: usize,
    /// Maximum sequence length.
325    pub max_seq_len: usize,
    /// Epsilon for RMS normalization.
326    pub rms_norm_eps: f32,
    /// RoPE base frequency.
327    pub rope_theta: f32,
    /// Element type associated with the model.
    // NOTE(review): unclear from this file whether this is the weight or the
    // compute dtype — confirm against the frontends.
328    pub dtype: DType,
329}
330
331/// The computation graph — the central IR artifact.
///
/// Nodes are stored in insertion order. The two private counters hand out
/// fresh node and tensor IDs; they are part of the serialized form because
/// the struct derives `Serialize`/`Deserialize` as a whole.
332#[derive(Debug, Clone, Serialize, Deserialize)]
333pub struct Graph {
    /// Name of the graph.
334    pub name: String,
    /// Optional model hyperparameters; `None` until `Graph::with_config` is used.
335    pub config: Option<ModelConfig>,
    /// All nodes, in the order they were added.
336    pub nodes: Vec<Node>,
337    /// Maps weight names to their tensor info.
338    pub weights: HashMap<String, TensorInfo>,
    /// ID that the next `Graph::add_node` call will assign.
339    next_node_id: NodeId,
    /// ID that the next `Graph::alloc_tensor_id` call will return.
340    next_tensor_id: TensorId,
341}
342
343impl Graph {
344    pub fn new(name: impl Into<String>) -> Self {
345        Self {
346            name: name.into(),
347            config: None,
348            nodes: Vec::new(),
349            weights: HashMap::new(),
350            next_node_id: 0,
351            next_tensor_id: 0,
352        }
353    }
354
355    pub fn with_config(mut self, config: ModelConfig) -> Self {
356        self.config = Some(config);
357        self
358    }
359
360    /// Add a node to the graph and return its ID.
361    pub fn add_node(&mut self, op: Op, inputs: Vec<NodeId>, output: TensorInfo) -> NodeId {
362        let id = self.next_node_id;
363        self.next_node_id += 1;
364        self.nodes.push(Node {
365            id,
366            op,
367            inputs,
368            output,
369        });
370        id
371    }
372
373    /// Allocate a new tensor ID.
374    pub fn alloc_tensor_id(&mut self) -> TensorId {
375        let id = self.next_tensor_id;
376        self.next_tensor_id += 1;
377        id
378    }
379
380    /// Register a weight tensor in the graph.
381    pub fn register_weight(&mut self, name: String, shape: Shape, dtype: DType) -> TensorId {
382        let id = self.alloc_tensor_id();
383        let info = TensorInfo {
384            id,
385            name: name.clone(),
386            shape,
387            dtype,
388        };
389        self.weights.insert(name, info);
390        id
391    }
392
393    /// Add a weight-loading node.
394    pub fn load_weight(&mut self, name: impl Into<String>, shape: Shape, dtype: DType) -> NodeId {
395        let name = name.into();
396        let tensor_id = self.alloc_tensor_id();
397        let output = TensorInfo {
398            id: tensor_id,
399            name: name.clone(),
400            shape,
401            dtype,
402        };
403        self.register_weight(name.clone(), output.shape.clone(), output.dtype);
404        self.add_node(Op::LoadWeight { name }, vec![], output)
405    }
406
407    /// Add an input node.
408    pub fn input(&mut self, name: impl Into<String>, shape: Shape, dtype: DType) -> NodeId {
409        let name = name.into();
410        let tensor_id = self.alloc_tensor_id();
411        let output = TensorInfo {
412            id: tensor_id,
413            name: name.clone(),
414            shape,
415            dtype,
416        };
417        self.add_node(Op::Input { name }, vec![], output)
418    }
419
420    /// Get a node by ID.
421    pub fn node(&self, id: NodeId) -> &Node {
422        &self.nodes[id]
423    }
424
425    /// Get the output tensor info for a node.
426    pub fn output_info(&self, id: NodeId) -> &TensorInfo {
427        &self.nodes[id].output
428    }
429
430    /// Number of nodes in the graph.
431    pub fn len(&self) -> usize {
432        self.nodes.len()
433    }
434
435    /// Whether the graph has no nodes.
436    pub fn is_empty(&self) -> bool {
437        self.nodes.is_empty()
438    }
439
440    /// Return node IDs in topological order.
441    /// Since nodes are added in dependency order, this is just 0..n.
442    pub fn topological_order(&self) -> Vec<NodeId> {
443        (0..self.nodes.len()).collect()
444    }
445
446    /// Validate the graph: check that all input references are valid
447    /// and precede the node that uses them.
448    pub fn validate(&self) -> Result<(), GraphError> {
449        for node in &self.nodes {
450            for &input_id in &node.inputs {
451                if input_id >= node.id {
452                    return Err(GraphError::ForwardReference {
453                        node: node.id,
454                        input: input_id,
455                    });
456                }
457                if input_id >= self.nodes.len() {
458                    return Err(GraphError::InvalidNodeReference {
459                        node: node.id,
460                        input: input_id,
461                    });
462                }
463            }
464        }
465        Ok(())
466    }
467}
468
469/// Errors in graph construction or validation.
///
/// Returned by `Graph::validate`. Both variants carry the ID of the node
/// holding the bad reference (`node`) and the offending input ID (`input`).
470#[derive(Debug, Clone, thiserror::Error)]
471pub enum GraphError {
    /// An input ID is not strictly smaller than the ID of the node using it.
472    #[error("node {node} references future node {input} (forward reference)")]
473    ForwardReference { node: NodeId, input: NodeId },
474
    /// An input ID does not correspond to any node in the graph.
475    #[error("node {node} references non-existent node {input}")]
476    InvalidNodeReference { node: NodeId, input: NodeId },
477}
478
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for the output metadata of a hand-built node.
    fn tensor(id: TensorId, name: &str, dims: Vec<usize>, dtype: DType) -> TensorInfo {
        TensorInfo {
            id,
            name: name.into(),
            shape: Shape::new(dims),
            dtype,
        }
    }

    /// Hyperparameters shared by the config-related tests.
    fn sample_llama_config() -> ModelConfig {
        ModelConfig {
            architecture: Architecture::Llama,
            hidden_size: 2048,
            intermediate_size: 5632,
            num_layers: 16,
            num_attention_heads: 32,
            num_kv_heads: 8,
            head_dim: 64,
            vocab_size: 32000,
            max_seq_len: 2048,
            rms_norm_eps: 1e-5,
            rope_theta: 10000.0,
            dtype: DType::F16,
        }
    }

    #[test]
    fn create_empty_graph() {
        let graph = Graph::new("test");
        assert!(graph.is_empty());
        assert_eq!(graph.name, "test");
    }

    #[test]
    fn add_nodes_and_validate() {
        let mut graph = Graph::new("test_model");

        // Token IDs feeding the embedding table.
        let tokens = graph.input("tokens", Shape::new(vec![1, 128]), DType::I32);
        let table = graph.load_weight(
            "model.embed_tokens.weight",
            Shape::new(vec![32000, 2048]),
            DType::F16,
        );

        // Embedding lookup consuming both of the above.
        let out_id = graph.alloc_tensor_id();
        let lookup = graph.add_node(
            Op::Embedding {
                vocab_size: 32000,
                embed_dim: 2048,
            },
            vec![tokens, table],
            tensor(out_id, "embed_out", vec![1, 128, 2048], DType::F16),
        );

        assert_eq!(graph.len(), 3);
        assert_eq!(graph.node(lookup).inputs, vec![tokens, table]);
        assert!(graph.validate().is_ok());
    }

    #[test]
    fn validate_detects_forward_reference() {
        let mut graph = Graph::new("bad");
        let out_id = graph.alloc_tensor_id();
        // Bypass `add_node` so node 0 can name input 1, an ID that is not
        // smaller than its own.
        graph.nodes.push(Node {
            id: 0,
            op: Op::Add,
            inputs: vec![1],
            output: tensor(out_id, "bad", vec![1], DType::F32),
        });
        graph.next_node_id = 1;

        assert!(graph.validate().is_err());
    }

    #[test]
    fn shape_operations() {
        let shape = Shape::new(vec![2, 3, 4]);
        assert_eq!(shape.ndim(), 3);
        assert_eq!(shape.numel(), 24);
        assert_eq!(shape.dim(1), 3);
        assert_eq!(shape.to_string(), "[2, 3, 4]");
    }

    #[test]
    fn dtype_properties() {
        assert!(DType::Q4_0.is_quantized());
        assert!(!DType::F32.is_quantized());
        assert!(DType::F16.is_float());
        assert!(!DType::I32.is_float());
        assert_eq!(DType::F32.size_bytes(), 4);
    }

    #[test]
    fn topological_order() {
        let mut graph = Graph::new("topo");
        let lhs = graph.input("a", Shape::new(vec![4]), DType::F32);
        let rhs = graph.input("b", Shape::new(vec![4]), DType::F32);
        let out_id = graph.alloc_tensor_id();
        let _sum = graph.add_node(
            Op::Add,
            vec![lhs, rhs],
            tensor(out_id, "c", vec![4], DType::F32),
        );
        assert_eq!(graph.topological_order(), vec![0, 1, 2]);
    }

    #[test]
    fn weight_registration() {
        const NAME: &str = "layer.0.attention.wq.weight";

        let mut graph = Graph::new("weights");
        graph.register_weight(NAME.into(), Shape::new(vec![2048, 2048]), DType::F16);

        assert!(graph.weights.contains_key(NAME));
        let info = &graph.weights[NAME];
        assert_eq!(info.shape, Shape::new(vec![2048, 2048]));
        assert_eq!(info.dtype, DType::F16);
    }

    #[test]
    fn model_config_roundtrip() {
        let config = sample_llama_config();

        let json = serde_json::to_string(&config).unwrap();
        let deserialized: ModelConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.architecture, Architecture::Llama);
        assert_eq!(deserialized.hidden_size, 2048);
        assert_eq!(deserialized.num_kv_heads, 8);
    }

    #[test]
    fn graph_with_config() {
        let graph = Graph::new("llama-1b").with_config(sample_llama_config());
        assert!(graph.config.is_some());
        let cfg = graph.config.unwrap();
        assert_eq!(cfg.architecture, Architecture::Llama);
    }

    #[test]
    fn build_transformer_layer_fragment() {
        let mut graph = Graph::new("layer_test");
        let hidden = 2048;

        // Simulate: input -> RMSNorm -> Q projection (matmul)
        let input = graph.input(
            "hidden_states",
            Shape::new(vec![1, 128, hidden]),
            DType::F16,
        );
        let norm_w = graph.load_weight(
            "model.layers.0.input_layernorm.weight",
            Shape::new(vec![hidden]),
            DType::F16,
        );

        let normed_id = graph.alloc_tensor_id();
        let normed = graph.add_node(
            Op::RMSNorm { eps: 1e-5 },
            vec![input, norm_w],
            tensor(normed_id, "normed", vec![1, 128, hidden], DType::F16),
        );

        let q_weight = graph.load_weight(
            "model.layers.0.self_attn.q_proj.weight",
            Shape::new(vec![hidden, hidden]),
            DType::F16,
        );

        let q_id = graph.alloc_tensor_id();
        let q_proj = graph.add_node(
            Op::MatMul,
            vec![normed, q_weight],
            tensor(q_id, "q_proj", vec![1, 128, hidden], DType::F16),
        );

        assert_eq!(graph.len(), 5); // input, norm_w, normed, q_weight, q_proj
        assert_eq!(graph.node(q_proj).inputs, vec![normed, q_weight]);
        assert!(graph.validate().is_ok());
    }
}
696}