Skip to main content

morok_ir/
op.rs

1//! Operation enum and implementation.
2//!
3//! The [`Op`] enum defines all possible operations in the IR, from basic arithmetic
4//! to complex control flow and memory operations.
5
6use std::sync::Arc;
7
8use smallvec::SmallVec;
9
10use crate::types::*;
11use crate::uop::UOp;
12use morok_dtype::DType;
13use morok_dtype::DeviceSpec;
14
/// Operation type with typed operands.
///
/// Each operation encodes its operand structure directly in the enum variant.
/// This provides compile-time verification of operand count and types.
///
/// Design choices:
/// - Fixed-arity ops grouped by arity: Unary, Binary, Ternary
/// - Special ops with extra data remain separate: Cast (dtype), MSelect (device_index)
/// - Variable-arity ops use SmallVec: Index { indices: SmallVec<[Arc<UOp>; 4]> }
/// - SmallVec avoids heap allocation for common cases (≤4 children; `Range::deps` inlines ≤2)
/// - Gate is on INDEX (not LOAD/STORE) following Tinygrad's model
///
/// The canonical source (child) order of every variant is the one returned by
/// [`Op::children`]; [`Op::range_ending_src_index`] indexes into that order. Fields that
/// are not `Arc<UOp>` (dtypes, axes, names, constants, option structs) are payload, not
/// children.
///
/// Hash is derived and uses UOp's Hash impl for Arc<UOp> children.
/// UOp hashes by content (dtype + op), enabling content-based hashing for caching.
/// The derived Hash also hashes the variant discriminant, so reordering variants
/// changes hash values.
#[derive(Debug, Clone, Hash)]
#[derive(strum::AsRefStr)]
#[derive(morok_macros::PatternEnum)]
#[pattern(grouped = [Unary, Binary, Ternary])]
pub enum Op {
    // Nullary operations (6 variants). VConst, DefineVar and DefineReg further below are
    // also leaves (no UOp children) but are grouped with their related operations.
    /// Constant value; leaf node.
    Const(ConstValueHash),
    /// Leaf carrying a numeric id. NOTE(review): presumably what `Buffer::unique` points
    /// at, to keep otherwise-identical buffers distinct — confirm.
    Unique(usize),
    /// Device specification leaf. NOTE(review): presumably what `Buffer::device` /
    /// `Copy::device` point at — confirm.
    Device(DeviceSpec),
    /// No-op leaf.
    Noop,
    /// Invalid marker leaf. Excluded from the `PatternEnum` derive via `#[pattern(skip)]`.
    #[pattern(skip)]
    Invalid,
    /// Leaf node. NOTE(review): the meaning of the `usize` payload is not visible here —
    /// confirm and document.
    DefineLocal(usize),

    // Graph organization operations (2 variants)
    /// Children are `sources`, in order.
    Sink {
        sources: SmallVec<[Arc<UOp>; 4]>,
    },
    /// Children are `sources`, in order.
    Group {
        sources: SmallVec<[Arc<UOp>; 4]>,
    },

    // Grouped operations (3 variants)
    /// One-operand ALU op; the concrete operation is the `UnaryOp` tag.
    Unary(UnaryOp, Arc<UOp>),
    /// Two-operand ALU op; children are `[lhs, rhs]` in field order.
    Binary(BinaryOp, Arc<UOp>, Arc<UOp>),
    /// Three-operand ALU op; children are the three operands in field order.
    Ternary(TernaryOp, Arc<UOp>, Arc<UOp>, Arc<UOp>),

    // Type operations (2 variants)
    /// Convert `src` to `dtype`. Only `src` is a child.
    Cast {
        src: Arc<UOp>,
        dtype: DType,
    },
    /// Reinterpret `src` as `dtype`. Only `src` is a child.
    BitCast {
        src: Arc<UOp>,
        dtype: DType,
    },

    // Special operations (2 variants)
    /// Select entry `device_index` from `buffer` (presumably a multi-device buffer).
    /// Only `buffer` is a child.
    MSelect {
        buffer: Arc<UOp>,
        device_index: usize,
    },
    /// Only `end` is a child; `name` is payload.
    Special {
        end: Arc<UOp>,
        name: String,
    },

    // Buffer operations (high-level, 8 variants)
    /// Normalized buffer parameter — positional reference to an input/output buffer.
    /// Created by pre-schedule normalization (BUFFER→PARAM) to erase buffer identity,
    /// enabling structural deduplication of identical computations on different buffers.
    /// Matches Tinygrad's Ops.PARAM (engine/schedule.py:125).
    /// The only possible child is `device`, present only when it is `Some`.
    Param {
        slot: usize,
        size: usize,
        device: Option<Arc<UOp>>,
    },
    /// Children: `[unique, device]`; `size` is payload.
    Buffer {
        unique: Arc<UOp>,
        device: Arc<UOp>,
        size: usize,
    },
    /// View into `buffer` described by `size`/`offset` (units not shown here).
    /// Only `buffer` is a child; [`Op::ended_ranges`] reports it as ended.
    BufferView {
        buffer: Arc<UOp>,
        size: usize,
        offset: usize,
    },
    /// Children: `[compute, ranges...]`; ends the ranges from source index 1.
    Bufferize {
        compute: Arc<UOp>,
        ranges: SmallVec<[Arc<UOp>; 4]>,
        opts: BufferizeOpts,
    },
    /// Children: `[buffer, indices..., gate?]`. The optional gate lives here, not on
    /// LOAD/STORE.
    Index {
        buffer: Arc<UOp>,
        indices: SmallVec<[Arc<UOp>; 4]>,
        gate: Option<Arc<UOp>>,
    },
    /// Children: `[ptr, offset]`.
    PointerIndex {
        ptr: Arc<UOp>,
        offset: Arc<UOp>,
    },
    /// Children: `[src, device]`; [`Op::ended_ranges`] reports `src` as ended.
    Copy {
        src: Arc<UOp>,
        device: Arc<UOp>,
    },
    /// Children are `buffers`, in order.
    MStack {
        buffers: SmallVec<[Arc<UOp>; 4]>,
    },

    // Movement/Reshape operations (7 variants). All except Multi are reported by
    // `Op::is_movement`.
    /// Target shape is itself a UOp, so children are `[src, new_shape]`.
    Reshape {
        src: Arc<UOp>,
        new_shape: Arc<UOp>,
    },
    /// Reorder the axes of `src` according to `axes`. Only `src` is a child.
    Permute {
        src: Arc<UOp>,
        axes: Vec<usize>,
    },
    /// Target shape is itself a UOp, so children are `[src, new_shape]`.
    Expand {
        src: Arc<UOp>,
        new_shape: Arc<UOp>,
    },
    /// Children: `[src, begin_pads, end_pads]`.
    Pad {
        src: Arc<UOp>,
        begin_pads: Arc<UOp>,
        end_pads: Arc<UOp>,
    },
    /// Children: `[src, begins, ends]`.
    Shrink {
        src: Arc<UOp>,
        begins: Arc<UOp>,
        ends: Arc<UOp>,
    },
    /// Reverse `src` along flagged axes. NOTE(review): `axes` presumably holds one flag
    /// per dimension — confirm. Only `src` is a child.
    Flip {
        src: Arc<UOp>,
        axes: Vec<bool>,
    },
    /// Only `src` is a child; `axis` is payload. Not treated as a movement op by
    /// [`Op::is_movement`].
    Multi {
        src: Arc<UOp>,
        axis: usize,
    },

    // Reduction operations (3 variants)
    /// Reduce `src` over `axes` with `reduce_op`. Only `src` is a child.
    ReduceAxis {
        src: Arc<UOp>,
        reduce_op: ReduceOp,
        axes: Vec<usize>,
    },
    /// Range-based reduction. Children: `[src, ranges...]`; ends the ranges from source
    /// index 1.
    Reduce {
        src: Arc<UOp>,
        ranges: SmallVec<[Arc<UOp>; 4]>,
        reduce_op: ReduceOp,
    },
    /// Children: `[src, device]`.
    AllReduce {
        src: Arc<UOp>,
        device: Arc<UOp>,
        reduce_op: ReduceOp,
    },

    // Control flow operations (5 variants)
    /// Children: `[condition, body...]`.
    If {
        condition: Arc<UOp>,
        body: SmallVec<[Arc<UOp>; 4]>,
    },
    /// Only `if_op` is a child.
    EndIf {
        if_op: Arc<UOp>,
    },
    /// Children: `[end, deps...]`; `axis_id` and `axis_type` are payload.
    Range {
        end: Arc<UOp>,
        axis_id: AxisId,
        axis_type: AxisType,
        deps: SmallVec<[Arc<UOp>; 2]>,
    },
    /// Children: `[computation, ranges...]`; ends the ranges from source index 1.
    End {
        computation: Arc<UOp>,
        ranges: SmallVec<[Arc<UOp>; 4]>,
    },
    /// Children: `[src, deps...]`.
    Barrier {
        src: Arc<UOp>,
        deps: SmallVec<[Arc<UOp>; 4]>,
    },

    // Vector operations (5 variants)
    /// Children are `elements`, in order.
    Vectorize {
        elements: SmallVec<[Arc<UOp>; 4]>,
    },
    /// Take the elements at `indices` from `vector`. Only `vector` is a child.
    Gep {
        vector: Arc<UOp>,
        indices: Vec<usize>,
    },
    /// Vector constant; leaf node (no UOp children).
    VConst {
        values: Vec<ConstValue>,
    },
    /// Concatenate vectors into larger vector (expander op).
    /// Like VECTORIZE but sources can be vectors themselves.
    /// Output vcount = sum of all input vcounts.
    Cat {
        sources: SmallVec<[Arc<UOp>; 4]>,
    },
    /// Concatenate pointers into vectorized pointer (expander op).
    /// Used for grouping memory accesses in devectorizer.
    PtrCat {
        sources: SmallVec<[Arc<UOp>; 4]>,
    },

    // Symbolic/Define operations (3 variants)
    /// Named symbolic variable with declared bounds `min_val`/`max_val` (inclusivity not
    /// shown here). Leaf node.
    DefineVar {
        name: String,
        min_val: i64,
        max_val: i64,
    },
    /// Children: `[var, value]`.
    Bind {
        var: Arc<UOp>,
        value: Arc<UOp>,
    },
    /// Leaf node (no UOp children).
    DefineReg {
        size: usize,
        /// Unique accumulator ID for disambiguation.
        /// Without this, two same-dtype reduces would share one DefineReg via hash consing.
        id: usize,
    },

    // Advanced operations (12 variants)
    /// Tensor-core op. Children: `[a, b, c]`. [`Op::range_ending_src_index`] returns 3 for
    /// it, but there are no children past `c`, so it currently ends no ranges.
    Wmma {
        a: Arc<UOp>,
        b: Arc<UOp>,
        c: Arc<UOp>,
        metadata: WmmaMetadata,
    },
    /// Only `src` is a child. NOTE(review): `upcast_ranges` pairs are presumably
    /// `(axis, size)` — confirm.
    Contract {
        src: Arc<UOp>,
        upcast_ranges: Vec<(usize, usize)>,
    },
    /// Only `src` is a child. NOTE(review): `unroll_axes` pairs are presumably
    /// `(axis, size)` — confirm.
    Unroll {
        src: Arc<UOp>,
        unroll_axes: Vec<(usize, usize)>,
    },
    /// Children: `[sources..., ast]` — note that `ast` comes last.
    Kernel {
        sources: SmallVec<[Arc<UOp>; 4]>,
        ast: Arc<UOp>,
    },
    /// Children: `[target, value, movement_ops?]`.
    Assign {
        target: Arc<UOp>,
        value: Arc<UOp>,
        /// Movement ops chain for shape tracking (third source in Tinygrad).
        /// This is a UOp chain where each node is a movement op, and walking
        /// via src[0] reaches the base INDEX operation. Used during
        /// bufferize_to_store to apply the same transformations to the result buffer.
        movement_ops: Option<Arc<UOp>>,
    },
    /// Only `src` is a child.
    Detach {
        src: Arc<UOp>,
    },
    /// Only `src` is a child; `opts` is payload.
    Contiguous {
        src: Arc<UOp>,
        /// Optimization hints (Tinygrad: CONTIGUOUS.arg)
        opts: SmallVec<[crate::types::ContiguousHint; 4]>,
    },
    /// Only `src` is a child.
    ContiguousBackward {
        src: Arc<UOp>,
    },
    /// Children: `[passthrough, deps...]`. Its ended ranges are the concatenation of the
    /// deps' ended ranges (see [`Op::ended_ranges`]).
    After {
        passthrough: Arc<UOp>,
        deps: SmallVec<[Arc<UOp>; 4]>,
    },
    /// Only `src` is a child.
    Precast {
        src: Arc<UOp>,
    },
    /// Custom code op: children are `deps`, `code` is payload.
    Custom {
        deps: SmallVec<[Arc<UOp>; 4]>,
        code: String,
    },
    /// Custom code op: children are `deps`, `code` is payload.
    CustomI {
        deps: SmallVec<[Arc<UOp>; 4]>,
        code: String,
    },

    // Memory operations (low-level, after kernel splitting, 2 variants)
    // Gate is on INDEX, not LOAD/STORE (following Tinygrad's model)
    /// Load from buffer at index.
    ///
    /// - `buffer`: The buffer to load from
    /// - `index`: The INDEX operation specifying where to load (may be gated)
    /// - `alt`: Optional alternative value for gated loads (used when gate is false)
    ///
    /// When `alt` is Some, the load behaves as: `if gate { load(index) } else { alt }`.
    /// This is used for masked loads in image processing and padding scenarios.
    ///
    /// Children: `[buffer, index, alt?]`.
    Load {
        buffer: Arc<UOp>,
        index: Arc<UOp>,
        alt: Option<Arc<UOp>>,
    },
    /// Store `value` at `index`. Children: `[index, value, ranges...]`; ends the ranges
    /// from source index 2.
    Store {
        index: Arc<UOp>,
        value: Arc<UOp>,
        ranges: SmallVec<[Arc<UOp>; 4]>,
    },
}
306
307impl Op {
308    /// Get all child UOps as a Vec of references.
309    ///
310    /// This is the convenient API for traversing the graph.
311    /// Allocates a Vec but is simple to use.
312    pub fn children(&self) -> SmallVec<[&Arc<UOp>; 4]> {
313        match self {
314            // Nullary operations
315            Self::Const(_)
316            | Self::Unique(_)
317            | Self::Device(_)
318            | Self::Noop
319            | Self::Invalid
320            | Self::DefineLocal(_)
321            | Self::VConst { .. }
322            | Self::DefineVar { .. }
323            | Self::DefineReg { .. } => SmallVec::new(),
324
325            // Param has optional device child — pre-kernel PARAMs have device, codegen PARAMs don't
326            Self::Param { device: Some(d), .. } => SmallVec::from_slice(&[d]),
327            Self::Param { device: None, .. } => SmallVec::new(),
328
329            // Graph organization operations
330            Self::Sink { sources } | Self::Group { sources } => sources.iter().collect(),
331
332            // Grouped operations
333            Self::Unary(_, x) => SmallVec::from_slice(&[x]),
334            Self::Binary(_, a, b) => SmallVec::from_slice(&[a, b]),
335            Self::Ternary(_, a, b, c) => SmallVec::from_slice(&[a, b, c]),
336
337            // Type operations
338            Self::Cast { src, .. } | Self::BitCast { src, .. } => SmallVec::from_slice(&[src]),
339
340            // Special operations
341            Self::MSelect { buffer, .. } => SmallVec::from_slice(&[buffer]),
342            Self::Special { end, .. } => SmallVec::from_slice(&[end]),
343
344            // Buffer operations
345            Self::Buffer { unique, device, .. } => SmallVec::from_slice(&[unique, device]),
346            Self::BufferView { buffer, .. } => SmallVec::from_slice(&[buffer]),
347            Self::Bufferize { compute, ranges, .. } => {
348                let mut children = SmallVec::from_slice(&[compute]);
349                children.extend(ranges.iter());
350                children
351            }
352            Self::Index { buffer, indices, gate } => {
353                let mut children = SmallVec::from_slice(&[buffer]);
354                children.extend(indices.iter());
355                children.extend(gate);
356                children
357            }
358            Self::PointerIndex { ptr, offset } => SmallVec::from_slice(&[ptr, offset]),
359            Self::Copy { src, device } => SmallVec::from_slice(&[src, device]),
360            Self::MStack { buffers } => buffers.iter().collect(),
361
362            // Movement operations
363            Self::Reshape { src, new_shape } => SmallVec::from_slice(&[src, new_shape]),
364            Self::Permute { src, .. } | Self::Flip { src, .. } | Self::Multi { src, .. } => {
365                SmallVec::from_slice(&[src])
366            }
367            Self::Expand { src, new_shape } => SmallVec::from_slice(&[src, new_shape]),
368            Self::Pad { src, begin_pads, end_pads } => SmallVec::from_slice(&[src, begin_pads, end_pads]),
369            Self::Shrink { src, begins, ends } => SmallVec::from_slice(&[src, begins, ends]),
370
371            // Reduction operations
372            Self::ReduceAxis { src, .. } => SmallVec::from_slice(&[src]),
373            Self::Reduce { src, ranges, .. } => {
374                let mut children = SmallVec::from_slice(&[src]);
375                children.extend(ranges.iter());
376                children
377            }
378            Self::AllReduce { src, device, .. } => SmallVec::from_slice(&[src, device]),
379
380            // Control flow operations
381            Self::If { condition, body } => {
382                let mut children = SmallVec::from_slice(&[condition]);
383                children.extend(body.iter());
384                children
385            }
386            Self::EndIf { if_op } => SmallVec::from_slice(&[if_op]),
387            Self::Range { end, deps, .. } => {
388                let mut children = SmallVec::from_slice(&[end]);
389                children.extend(deps.iter());
390                children
391            }
392            Self::End { computation, ranges } => {
393                let mut children = SmallVec::from_slice(&[computation]);
394                children.extend(ranges.iter());
395                children
396            }
397            Self::Barrier { src, deps } => {
398                let mut children = SmallVec::from_slice(&[src]);
399                children.extend(deps.iter());
400                children
401            }
402
403            // Vector operations
404            Self::Vectorize { elements } => elements.iter().collect(),
405            Self::Gep { vector, .. } => SmallVec::from_slice(&[vector]),
406            Self::Cat { sources } | Self::PtrCat { sources } => sources.iter().collect(),
407
408            // Symbolic/Define operations
409            Self::Bind { var, value } => SmallVec::from_slice(&[var, value]),
410
411            // Advanced operations
412            Self::Wmma { a, b, c, .. } => SmallVec::from_slice(&[a, b, c]),
413            Self::Contract { src, .. }
414            | Self::Unroll { src, .. }
415            | Self::Detach { src }
416            | Self::Contiguous { src, .. }
417            | Self::ContiguousBackward { src }
418            | Self::Precast { src } => SmallVec::from_slice(&[src]),
419            Self::Kernel { sources, ast } => {
420                let mut children: SmallVec<[&Arc<UOp>; 4]> = sources.iter().collect();
421                children.push(ast);
422                children
423            }
424            Self::Assign { target, value, movement_ops } => {
425                let mut children = SmallVec::from_slice(&[target, value]);
426                if let Some(mops) = movement_ops {
427                    children.push(mops);
428                }
429                children
430            }
431            Self::After { passthrough, deps } => {
432                let mut children = SmallVec::from_slice(&[passthrough]);
433                children.extend(deps.iter());
434                children
435            }
436            Self::Custom { deps, .. } | Self::CustomI { deps, .. } => deps.iter().collect(),
437
438            // Memory operations
439            Self::Load { buffer, index, alt } => {
440                let mut children = SmallVec::from_slice(&[buffer, index]);
441                children.extend(alt);
442                children
443            }
444            Self::Store { index, value, ranges } => {
445                let mut children = SmallVec::from_slice(&[index, value]);
446                children.extend(ranges.iter());
447                children
448            }
449        }
450    }
451
452    /// Get all child UOps as a Vec of owned Rcs (cloned).
453    ///
454    /// Similar to `children()` but returns owned Rcs instead of references.
455    /// Useful when you need to reconstruct nodes or store sources.
456    pub fn sources(&self) -> SmallVec<[Arc<UOp>; 4]> {
457        self.children().iter().map(|rc| (*rc).clone()).collect()
458    }
459
460    /// Apply a function to each child UOp.
461    pub fn map_child<F>(&self, mut f: F)
462    where
463        F: FnMut(&Arc<UOp>),
464    {
465        for child in self.children() {
466            f(child);
467        }
468    }
469
470    /// Check if this operation is a movement operation.
471    ///
472    /// Movement operations transform tensor shapes without changing data values:
473    /// - RESHAPE: Change shape with same number of elements
474    /// - PERMUTE: Transpose/reorder axes
475    /// - EXPAND: Broadcast to larger shape
476    /// - PAD: Add padding around tensor
477    /// - SHRINK: Extract sub-region
478    /// - FLIP: Reverse along axes
479    ///
480    /// Note: MULTI is not considered a pure movement op as it has different semantics.
481    pub fn is_movement(&self) -> bool {
482        matches!(
483            self,
484            Self::Reshape { .. }
485                | Self::Permute { .. }
486                | Self::Expand { .. }
487                | Self::Pad { .. }
488                | Self::Shrink { .. }
489                | Self::Flip { .. }
490        )
491    }
492
493    /// Get the source index where ranges start being "ended" by this operation.
494    ///
495    /// Based on Tinygrad's `range_start` dict (ops.py:28).
496    /// Returns `Some(index)` if this operation ends ranges, `None` otherwise.
497    ///
498    /// Operations that end ranges:
499    /// - BUFFERIZE: ranges start at index 1 (compute is 0, ranges are 1+)
500    /// - REDUCE: ranges start at index 1 (src is 0, ranges are 1+)
501    /// - STORE: ranges start at index 2 (index=0, value=1, ranges=2+)
502    /// - WMMA: ranges start at index 3 (a=0, b=1, c=2)
503    /// - END: ranges start at index 1 (computation=0, ranges=1+)
504    ///
505    /// These operations mark range boundaries in the computation graph.
506    /// Any RANGE operations in sources at or after the returned index
507    /// are considered "ended" and removed from scope.
508    ///
509    /// # Examples
510    ///
511    /// ```ignore
512    /// use morok_ir::Op;
513    ///
514    /// // BUFFERIZE ends ranges starting at source index 1
515    /// let bufferize_op = Op::Bufferize { /* ... */ };
516    /// assert_eq!(bufferize_op.range_ending_src_index(), Some(1));
517    ///
518    /// // Regular arithmetic operations don't end ranges
519    /// let binary_op = Op::Binary(/* ... */);
520    /// assert_eq!(binary_op.range_ending_src_index(), None);
521    /// ```
522    pub fn range_ending_src_index(&self) -> Option<usize> {
523        // Source layout for range-ending ops:
524        // - BUFFERIZE: compute=0, ranges=1+
525        // - REDUCE: src=0, ranges=1+
526        // - STORE: index=0, value=1, ranges=2+
527        // - WMMA: a=0, b=1, c=2, (ranges start at 3)
528        // - END: computation=0, ranges=1+
529        match self {
530            Self::Bufferize { .. } => Some(1),
531            Self::Reduce { .. } => Some(1),
532            Self::Store { .. } => Some(2),
533            Self::Wmma { .. } => Some(3),
534            Self::End { .. } => Some(1),
535            _ => None,
536        }
537    }
538
539    /// Check if this operation should be expanded when it has UNROLL inputs.
540    ///
541    /// Based on Tinygrad's expander.py:97-98 pattern which expands:
542    /// - ALU ops (Unary, Binary, Ternary)
543    /// - Type ops (Cast, BitCast)
544    /// - Vector ops (Gep, Vectorize)
545    /// - Tensor core ops (Wmma)
546    /// - Memory ops (Load, Store, Index)
547    /// - Buffer ops (Bufferize)
548    /// - Control flow (Reduce, End, After)
549    ///
550    /// These operations propagate vectorization through the computation graph
551    /// when any of their sources is an UNROLL operation.
552    pub fn is_expandable(&self) -> bool {
553        matches!(
554            self,
555            // ALU operations
556            Self::Unary(..) | Self::Binary(..) | Self::Ternary(..) |
557            // Type operations
558            Self::Cast { .. } | Self::BitCast { .. } |
559            // Vector operations
560            Self::Gep { .. } | Self::Vectorize { .. } |
561            // Tensor core
562            Self::Wmma { .. } |
563            // Memory operations
564            Self::Load { .. } | Self::Store { .. } |
565            Self::Index { .. } | Self::PointerIndex { .. } |
566            // Buffer operations
567            Self::Bufferize { .. } |
568            // Control flow (range-ending ops)
569            Self::Reduce { .. } | Self::End { .. } | Self::After { .. }
570        )
571    }
572
573    /// Get the "ended ranges" for this operation.
574    ///
575    /// These are the RANGE operations (and operations containing ranges)
576    /// that should be removed from scope after this operation.
577    ///
578    /// Based on Tinygrad's `ended_ranges` property (ops.py:296-299).
579    ///
580    /// # Returns
581    ///
582    /// A SmallVec of references to child UOps that represent ended ranges.
583    /// For operations that don't end ranges, returns an empty SmallVec.
584    ///
585    /// # Examples
586    ///
587    /// ```ignore
588    /// use morok_ir::{Op, UOp};
589    ///
590    /// // END operation ends its range arguments
591    /// let range = UOp::range(/* ... */);
592    /// let computation = UOp::const_(/* ... */);
593    /// let end_op = computation.end(vec![range.clone()]);
594    ///
595    /// // ended_ranges() returns the ranges that are closed
596    /// let ended = end_op.op().ended_ranges();
597    /// assert_eq!(ended.len(), 1);
598    /// ```
599    pub fn ended_ranges(&self) -> SmallVec<[&Arc<UOp>; 4]> {
600        if let Some(start_idx) = self.range_ending_src_index() {
601            let children = self.children();
602            children.into_iter().skip(start_idx).collect()
603        } else if let Self::After { deps, .. } = self {
604            // Tinygrad (ops.py:312): flatten([x.ended_ranges for x in self.src[1:]])
605            // AFTER propagates ended ranges from its dependency chain.
606            let mut result = SmallVec::new();
607            for dep in deps {
608                result.extend(dep.op().ended_ranges());
609            }
610            result
611        } else if matches!(self, Self::Copy { .. } | Self::BufferView { .. }) {
612            // Tinygrad (ops.py:314): return self.src[0].ranges
613            // COPY/BUFFER_VIEW ends all ranges from the source.
614            // We return the source itself (not individual ranges) — the
615            // InScopeRangesProperty handles the non-RANGE branch by looking
616            // up the ended node's in_scope_ranges and removing them all.
617            let children = self.children();
618            if children.is_empty() { SmallVec::new() } else { SmallVec::from_elem(children[0], 1) }
619        } else {
620            SmallVec::new()
621        }
622    }
623}