Skip to main content

mlx_core/
graph.rs

//! Lazy computation graph IR.
//!
//! Tensors are handles to nodes in this graph. Computation is deferred until
//! `eval()` is called, at which point the scheduler topologically sorts the
//! graph and dispatches to the active backend.
6
7use crate::types::{DType, Shape};
8use smallvec::SmallVec;
9use std::collections::HashMap;
10use std::hash::{Hash, Hasher};
11
/// Opaque handle naming a single node of the computation graph.
///
/// The wrapped counter is only accessible from within this crate.
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
pub struct NodeId(pub(crate) u64);
15
16/// Metadata about a tensor (known before materialization).
17#[derive(Clone, Debug)]
18pub struct TensorMeta {
19    pub shape: Shape,
20    pub dtype: DType,
21}
22
23/// A node in the lazy computation graph.
24#[derive(Clone, Debug)]
25pub struct Node {
26    pub id: NodeId,
27    pub op: OpKind,
28    pub inputs: SmallVec<[NodeId; 2]>,
29    pub meta: TensorMeta,
30}
31
32/// The set of operations supported by the graph IR.
33#[derive(Clone, Debug)]
34pub enum OpKind {
35    // ── Sources ─────────────────────────────────────────────────────────
36    /// Constant tensor (data already materialized).
37    Constant,
38    /// Parameter (learnable weight, data provided externally).
39    Parameter,
40
41    // ── Elementwise ─────────────────────────────────────────────────────
42    Add,
43    Sub,
44    Mul,
45    Div,
46    Neg,
47    Exp,
48    Log,
49
50    // ── Reductions ──────────────────────────────────────────────────────
51    Sum {
52        axis: Option<i32>,
53    },
54    Mean {
55        axis: Option<i32>,
56    },
57    Max {
58        axis: Option<i32>,
59    },
60
61    // ── Linear algebra ──────────────────────────────────────────────────
62    MatMul,
63
64    // ── Shape manipulation ──────────────────────────────────────────────
65    Reshape {
66        new_shape: Shape,
67    },
68    Transpose {
69        axes: Option<Vec<usize>>,
70    },
71
72    // ── Activations ─────────────────────────────────────────────────────
73    Softmax {
74        axis: i32,
75    },
76    Silu,
77    Gelu,
78
79    // ── Normalization ───────────────────────────────────────────────────
80    LayerNorm {
81        eps: f32,
82    },
83    RmsNorm {
84        eps: f32,
85    },
86
87    // ── Positional encoding ────────────────────────────────────────────
88    /// Rotary positional embeddings (RoPE).
89    /// Applied in-place to interleaved pairs of the last dimension.
90    Rope {
91        rotary_dim: usize,
92        pos_offset: usize,
93        theta: f32,
94    },
95
96    // ── Broadcasting ──────────────────────────────────────────────────
97    /// Broadcast a tensor to a target shape (numpy-style rules).
98    Broadcast {
99        target_shape: Shape,
100    },
101
102    // ── Attention ──────────────────────────────────────────────────
103    /// Fused scale + causal-mask + softmax along last axis.
104    /// Input: scores [Tq, Tk], output: probs [Tq, Tk]
105    ScaledMaskedSoftmax {
106        scale: f32,
107        causal: bool,
108    },
109
110    /// Full single-head attention composition.
111    /// Inputs: [Q, K, V] where Q=[Tq,Dh], K=[Tk,Dh], V=[Tk,Dh]
112    /// Output: Y=[Tq,Dh]
113    Attention {
114        scale: f32,
115        causal: bool,
116    },
117
118    // ── Backward (VJP) ops ──────────────────────────────────────────
119    /// LayerNorm backward: inputs = [grad_output, input], produces grad_input.
120    LayerNormVjp {
121        eps: f32,
122    },
123    /// RmsNorm backward: inputs = [grad_output, input], produces grad_input.
124    RmsNormVjp {
125        eps: f32,
126    },
127    /// Softmax backward: inputs = [grad_output, softmax_output], produces grad_input.
128    SoftmaxVjp {
129        axis: i32,
130    },
131    /// SiLU backward: inputs = [grad_output, original_input], produces grad_input.
132    SiluVjp,
133    /// GELU backward: inputs = [grad_output, original_input], produces grad_input.
134    GeluVjp,
135
136    // ── Elementwise (misc) ──────────────────────────────────────────
137    /// Element-wise square root.
138    Sqrt,
139
140    // ── Rotary Positional Embeddings ───────────────────────────────────
141    #[cfg_attr(target_os = "macos", doc = "Apply rotary positional embeddings.")]
142    RoPE {
143        base: f32,
144        offset: usize,
145        traditional: bool,
146    },
147
148    // ── Indexing / gathering ─────────────────────────────────────────────
149    /// Embedding lookup: gather rows from a weight matrix by index.
150    /// Inputs: [weight [vocab, dim], indices [seq_len]]
151    /// Output: [seq_len, dim]
152    Embedding,
153
154    /// Extract a contiguous slice along an axis.
155    /// Inputs: [input]
156    /// Output: narrowed tensor
157    Narrow {
158        axis: i32,
159        start: i64,
160        length: i64,
161    },
162
163    /// Concatenate tensors along an axis.
164    /// Inputs: [tensor_0, tensor_1, ...]
165    /// Output: concatenated tensor
166    Concatenate {
167        axis: i32,
168    },
169}
170
171/// The computation graph arena.
172#[derive(Debug, Default)]
173pub struct Graph {
174    nodes: Vec<Node>,
175    next_id: u64,
176    cse: HashMap<CseKey, NodeId>,
177    const_payloads: HashMap<NodeId, Vec<f32>>,
178}
179
180impl Graph {
181    pub fn new() -> Self {
182        Self::default()
183    }
184
185    /// Add a node and return its ID.
186    pub fn add_node(
187        &mut self,
188        op: OpKind,
189        inputs: SmallVec<[NodeId; 2]>,
190        meta: TensorMeta,
191    ) -> NodeId {
192        self.add_node_raw(op, inputs, meta)
193    }
194
195    /// Add a node without CSE.
196    pub fn add_node_raw(
197        &mut self,
198        op: OpKind,
199        inputs: SmallVec<[NodeId; 2]>,
200        meta: TensorMeta,
201    ) -> NodeId {
202        let id = NodeId(self.next_id);
203        self.next_id += 1;
204        self.nodes.push(Node {
205            id,
206            op,
207            inputs,
208            meta,
209        });
210        id
211    }
212
213    /// Add a node with CSE. For constants, include a payload hash if available.
214    pub fn intern_node(
215        &mut self,
216        op: OpKind,
217        inputs: SmallVec<[NodeId; 2]>,
218        meta: TensorMeta,
219        const_payload: Option<&[f32]>,
220    ) -> NodeId {
221        if !is_cse_eligible(&op) {
222            return self.add_node_raw(op, inputs, meta);
223        }
224
225        let mut inputs = inputs;
226        normalize_inputs_for_cse(&op, &mut inputs);
227
228        let const_hash = const_payload.map(hash_f32_payload);
229        let key = CseKey {
230            op_key: OpKey::from_op(&op),
231            inputs: inputs.clone(),
232            meta_sig: MetaSig::new(&meta),
233            const_hash,
234        };
235
236        if let Some(&existing) = self.cse.get(&key) {
237            if matches!(op, OpKind::Constant) {
238                if let (Some(payload), Some(existing_payload)) =
239                    (const_payload, self.const_payload(existing))
240                    && existing_payload == payload
241                {
242                    return existing;
243                }
244            } else {
245                return existing;
246            }
247        }
248
249        let id = self.add_node_raw(op, inputs, meta);
250        if matches!(key.op_key, OpKey::Constant)
251            && let Some(payload) = const_payload
252        {
253            self.const_payloads.insert(id, payload.to_vec());
254        }
255        self.cse.insert(key, id);
256        id
257    }
258
259    pub fn const_payload(&self, id: NodeId) -> Option<&[f32]> {
260        self.const_payloads.get(&id).map(|v| v.as_slice())
261    }
262
263    /// Get a node by ID.
264    pub fn get(&self, id: NodeId) -> Option<&Node> {
265        self.nodes.iter().find(|n| n.id == id)
266    }
267
268    /// Topological sort of the graph rooted at `outputs`.
269    pub fn topo_sort(&self, outputs: &[NodeId]) -> Vec<NodeId> {
270        let mut visited = std::collections::HashSet::new();
271        let mut order = Vec::new();
272
273        for &out in outputs {
274            self.topo_visit(out, &mut visited, &mut order);
275        }
276
277        order
278    }
279
280    fn topo_visit(
281        &self,
282        id: NodeId,
283        visited: &mut std::collections::HashSet<NodeId>,
284        order: &mut Vec<NodeId>,
285    ) {
286        if !visited.insert(id) {
287            return;
288        }
289        if let Some(node) = self.get(id) {
290            for &input in &node.inputs {
291                self.topo_visit(input, visited, order);
292            }
293        }
294        order.push(id);
295    }
296
297    /// Number of nodes.
298    pub fn len(&self) -> usize {
299        self.nodes.len()
300    }
301
302    /// Whether the graph is empty.
303    pub fn is_empty(&self) -> bool {
304        self.nodes.is_empty()
305    }
306}
307
308#[derive(Clone, Debug, PartialEq, Eq, Hash)]
309struct MetaSig {
310    dtype: DType,
311    shape: Vec<i64>,
312}
313
314impl MetaSig {
315    fn new(meta: &TensorMeta) -> Self {
316        Self {
317            dtype: meta.dtype,
318            shape: meta.shape.0.clone(),
319        }
320    }
321}
322
/// Hashable mirror of [`OpKind`] used inside CSE keys.
///
/// Floating-point attributes are held as raw `u32` bit patterns and shapes as
/// plain `Vec<i64>` so that `Eq` and `Hash` can be derived.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
enum OpKey {
    // Sources
    Constant,
    Parameter,

    // Elementwise
    Add,
    Sub,
    Mul,
    Div,
    Neg,
    Exp,
    Log,
    Sqrt,

    // Reductions
    Sum {
        axis: Option<i32>,
    },
    Mean {
        axis: Option<i32>,
    },
    Max {
        axis: Option<i32>,
    },

    // Linear algebra
    MatMul,

    // Shape manipulation and broadcasting
    Reshape {
        new_shape: Vec<i64>,
    },
    Transpose {
        axes: Option<Vec<usize>>,
    },
    Broadcast {
        target_shape: Vec<i64>,
    },

    // Activations
    Softmax {
        axis: i32,
    },
    Silu,
    Gelu,

    // Normalization
    LayerNorm {
        eps_bits: u32,
    },
    RmsNorm {
        eps_bits: u32,
    },

    // Rotary positional embeddings (both `OpKind` flavours)
    Rope {
        rotary_dim: usize,
        pos_offset: usize,
        theta_bits: u32,
    },
    RoPE {
        base_bits: u32,
        offset: usize,
        traditional: bool,
    },

    // Attention
    ScaledMaskedSoftmax {
        scale_bits: u32,
        causal: bool,
    },
    Attention {
        scale_bits: u32,
        causal: bool,
    },

    // Backward (VJP) ops
    LayerNormVjp {
        eps_bits: u32,
    },
    RmsNormVjp {
        eps_bits: u32,
    },
    SoftmaxVjp {
        axis: i32,
    },
    SiluVjp,
    GeluVjp,

    // Indexing / gathering
    Embedding,
    Narrow {
        axis: i32,
        start: i64,
        length: i64,
    },
    Concatenate {
        axis: i32,
    },
}
404
405impl OpKey {
406    fn from_op(op: &OpKind) -> Self {
407        match op {
408            OpKind::Constant => OpKey::Constant,
409            OpKind::Parameter => OpKey::Parameter,
410            OpKind::Add => OpKey::Add,
411            OpKind::Sub => OpKey::Sub,
412            OpKind::Mul => OpKey::Mul,
413            OpKind::Div => OpKey::Div,
414            OpKind::Neg => OpKey::Neg,
415            OpKind::Exp => OpKey::Exp,
416            OpKind::Log => OpKey::Log,
417            OpKind::Sum { axis } => OpKey::Sum { axis: *axis },
418            OpKind::Mean { axis } => OpKey::Mean { axis: *axis },
419            OpKind::Max { axis } => OpKey::Max { axis: *axis },
420            OpKind::MatMul => OpKey::MatMul,
421            OpKind::Reshape { new_shape } => OpKey::Reshape {
422                new_shape: new_shape.0.clone(),
423            },
424            OpKind::Transpose { axes } => OpKey::Transpose { axes: axes.clone() },
425            OpKind::Softmax { axis } => OpKey::Softmax { axis: *axis },
426            OpKind::Silu => OpKey::Silu,
427            OpKind::Gelu => OpKey::Gelu,
428            OpKind::LayerNorm { eps } => OpKey::LayerNorm {
429                eps_bits: eps.to_bits(),
430            },
431            OpKind::RmsNorm { eps } => OpKey::RmsNorm {
432                eps_bits: eps.to_bits(),
433            },
434            OpKind::Broadcast { target_shape } => OpKey::Broadcast {
435                target_shape: target_shape.0.clone(),
436            },
437            OpKind::LayerNormVjp { eps } => OpKey::LayerNormVjp {
438                eps_bits: eps.to_bits(),
439            },
440            OpKind::RmsNormVjp { eps } => OpKey::RmsNormVjp {
441                eps_bits: eps.to_bits(),
442            },
443            OpKind::ScaledMaskedSoftmax { scale, causal } => OpKey::ScaledMaskedSoftmax {
444                scale_bits: scale.to_bits(),
445                causal: *causal,
446            },
447            OpKind::Attention { scale, causal } => OpKey::Attention {
448                scale_bits: scale.to_bits(),
449                causal: *causal,
450            },
451            OpKind::Rope {
452                rotary_dim,
453                pos_offset,
454                theta,
455            } => OpKey::Rope {
456                rotary_dim: *rotary_dim,
457                pos_offset: *pos_offset,
458                theta_bits: theta.to_bits(),
459            },
460            OpKind::RoPE {
461                base,
462                offset,
463                traditional,
464            } => OpKey::RoPE {
465                base_bits: base.to_bits(),
466                offset: *offset,
467                traditional: *traditional,
468            },
469            OpKind::SoftmaxVjp { axis } => OpKey::SoftmaxVjp { axis: *axis },
470            OpKind::SiluVjp => OpKey::SiluVjp,
471            OpKind::GeluVjp => OpKey::GeluVjp,
472            OpKind::Sqrt => OpKey::Sqrt,
473            OpKind::Embedding => OpKey::Embedding,
474            OpKind::Narrow {
475                axis,
476                start,
477                length,
478            } => OpKey::Narrow {
479                axis: *axis,
480                start: *start,
481                length: *length,
482            },
483            OpKind::Concatenate { axis } => OpKey::Concatenate { axis: *axis },
484        }
485    }
486}
487
488#[derive(Clone, Debug, PartialEq, Eq, Hash)]
489struct CseKey {
490    op_key: OpKey,
491    inputs: SmallVec<[NodeId; 2]>,
492    meta_sig: MetaSig,
493    const_hash: Option<u64>,
494}
495
496fn is_cse_eligible(op: &OpKind) -> bool {
497    // Constants and Parameters must never be deduplicated: two tensors with
498    // identical data may flow through different parts of the graph and receive
499    // independent gradients during backpropagation.
500    !matches!(op, OpKind::Constant | OpKind::Parameter)
501}
502
/// Hash an `f32` payload.
///
/// The slice length is fed to the hasher first, followed by the raw bit
/// pattern of every element in order, so values are distinguished bit-for-bit
/// (for example `0.0` and `-0.0` contribute different inputs).
pub fn hash_f32_payload(data: &[f32]) -> u64 {
    let mut state = std::collections::hash_map::DefaultHasher::new();
    data.len().hash(&mut state);
    data.iter()
        .map(|value| value.to_bits())
        .for_each(|bits| bits.hash(&mut state));
    state.finish()
}
511
512fn normalize_inputs_for_cse(op: &OpKind, inputs: &mut SmallVec<[NodeId; 2]>) {
513    if matches!(op, OpKind::Add | OpKind::Mul) && inputs.len() == 2 && inputs[0].0 > inputs[1].0 {
514        inputs.swap(0, 1);
515    }
516}
517
#[cfg(test)]
mod tests {
    use super::*;

    /// `F32` metadata with the given shape.
    fn f32_meta(shape: Shape) -> TensorMeta {
        TensorMeta {
            shape,
            dtype: DType::F32,
        }
    }

    /// Intern an input-less constant carrying `data`.
    fn intern_constant(g: &mut Graph, meta: &TensorMeta, data: &[f32]) -> NodeId {
        g.intern_node(OpKind::Constant, SmallVec::new(), meta.clone(), Some(data))
    }

    #[test]
    fn test_graph_topo_sort() {
        let mut g = Graph::new();
        let lhs = g.add_node(
            OpKind::Constant,
            SmallVec::new(),
            f32_meta(Shape::new(vec![2, 3])),
        );
        let rhs = g.add_node(
            OpKind::Constant,
            SmallVec::new(),
            f32_meta(Shape::new(vec![2, 3])),
        );
        let sum = g.add_node(
            OpKind::Add,
            SmallVec::from_slice(&[lhs, rhs]),
            f32_meta(Shape::new(vec![2, 3])),
        );

        let order = g.topo_sort(&[sum]);
        assert_eq!(order.len(), 3);

        // Both operands must be scheduled before the addition.
        let index_of = |target: NodeId| order.iter().position(|&id| id == target).unwrap();
        let (at_lhs, at_rhs, at_sum) = (index_of(lhs), index_of(rhs), index_of(sum));
        assert!(at_lhs < at_sum);
        assert!(at_rhs < at_sum);
    }

    #[test]
    fn test_cse_does_not_dedup_constants() {
        let mut g = Graph::new();
        let meta = f32_meta(Shape::new(vec![2]));

        // Identical constants must stay distinct nodes: each may receive its
        // own gradient during backpropagation.
        let first = intern_constant(&mut g, &meta, &[1.0, 2.0]);
        let second = intern_constant(&mut g, &meta, &[1.0, 2.0]);

        assert_ne!(first, second);
        assert_eq!(g.len(), 2);
    }

    #[test]
    fn test_cse_dedups_ops() {
        let mut g = Graph::new();
        let meta = f32_meta(Shape::new(vec![2]));
        let a = intern_constant(&mut g, &meta, &[1.0, 2.0]);
        let b = intern_constant(&mut g, &meta, &[3.0, 4.0]);

        let mut intern_add = |g: &mut Graph| {
            g.intern_node(
                OpKind::Add,
                SmallVec::from_slice(&[a, b]),
                meta.clone(),
                None,
            )
        };
        let add1 = intern_add(&mut g);
        let add2 = intern_add(&mut g);

        assert_eq!(add1, add2);
        assert_eq!(g.len(), 3); // 2 constants + 1 deduplicated add
    }
}