morok_ir/op.rs
1//! Operation enum and implementation.
2//!
3//! The [`Op`] enum defines all possible operations in the IR, from basic arithmetic
4//! to complex control flow and memory operations.
5
6use std::sync::Arc;
7
8use smallvec::SmallVec;
9
10use crate::types::*;
11use crate::uop::UOp;
12use morok_dtype::DType;
13use morok_dtype::DeviceSpec;
14
/// Operation type with typed operands.
///
/// Each operation encodes its operand structure directly in the enum variant.
/// This provides compile-time verification of operand count and types.
///
/// Design choices:
/// - Fixed-arity ops grouped by arity: Unary, Binary, Ternary
/// - Special ops with extra data remain separate: Cast (dtype), MSelect (device_index)
/// - Variable-arity ops use SmallVec: Index { indices: SmallVec<[Arc<UOp>; 4]> }
/// - SmallVec avoids heap allocation for common cases (≤4 children)
/// - Gate is on INDEX (not LOAD/STORE) following Tinygrad's model
///
/// The order in which a variant's `Arc<UOp>` operands are reported as children
/// is defined by [`Op::children`]; fields that are not `Arc<UOp>` (sizes, names,
/// dtypes, axes, ...) are plain data and never appear as children.
///
/// Hash is derived and uses UOp's Hash impl for Arc<UOp> children.
/// UOp hashes by content (dtype + op), enabling content-based hashing for caching.
#[derive(Debug, Clone, Hash)]
#[derive(strum::AsRefStr)]
#[derive(morok_macros::PatternEnum)]
#[pattern(grouped = [Unary, Binary, Ternary])]
pub enum Op {
    // Nullary operations (6 variants)
    // Note: `VConst`, `DefineVar` and `DefineReg` (declared in later sections)
    // also carry no UOp operands.
    Const(ConstValueHash),
    Unique(usize),
    Device(DeviceSpec),
    Noop,
    #[pattern(skip)]
    Invalid,
    DefineLocal(usize),

    // Graph organization operations (2 variants)
    Sink {
        sources: SmallVec<[Arc<UOp>; 4]>,
    },
    Group {
        sources: SmallVec<[Arc<UOp>; 4]>,
    },

    // Grouped operations (3 variants)
    // The leading field selects the concrete operation; the remaining fields
    // are the operands, in order.
    Unary(UnaryOp, Arc<UOp>),
    Binary(BinaryOp, Arc<UOp>, Arc<UOp>),
    Ternary(TernaryOp, Arc<UOp>, Arc<UOp>, Arc<UOp>),

    // Type operations (2 variants)
    Cast {
        src: Arc<UOp>,
        dtype: DType,
    },
    BitCast {
        src: Arc<UOp>,
        dtype: DType,
    },

    // Special operations (2 variants)
    MSelect {
        buffer: Arc<UOp>,
        device_index: usize,
    },
    Special {
        end: Arc<UOp>,
        name: String,
    },

    // Buffer operations (high-level, 8 variants)
    /// Normalized buffer parameter — positional reference to an input/output buffer.
    /// Created by pre-schedule normalization (BUFFER→PARAM) to erase buffer identity,
    /// enabling structural deduplication of identical computations on different buffers.
    /// Matches Tinygrad's Ops.PARAM (engine/schedule.py:125).
    Param {
        slot: usize,
        size: usize,
        /// Optional device operand; it is the only child of this variant and
        /// is reported only when present.
        device: Option<Arc<UOp>>,
    },
    Buffer {
        unique: Arc<UOp>,
        device: Arc<UOp>,
        size: usize,
    },
    BufferView {
        buffer: Arc<UOp>,
        size: usize,
        offset: usize,
    },
    Bufferize {
        compute: Arc<UOp>,
        ranges: SmallVec<[Arc<UOp>; 4]>,
        opts: BufferizeOpts,
    },
    Index {
        buffer: Arc<UOp>,
        indices: SmallVec<[Arc<UOp>; 4]>,
        /// Optional gate. Gating lives here rather than on LOAD/STORE; when
        /// present it is reported after `buffer` and `indices`.
        gate: Option<Arc<UOp>>,
    },
    PointerIndex {
        ptr: Arc<UOp>,
        offset: Arc<UOp>,
    },
    Copy {
        src: Arc<UOp>,
        device: Arc<UOp>,
    },
    MStack {
        buffers: SmallVec<[Arc<UOp>; 4]>,
    },

    // Movement/Reshape operations (7 variants)
    // `Op::is_movement` covers the first six; MULTI is deliberately excluded there.
    Reshape {
        src: Arc<UOp>,
        new_shape: Arc<UOp>,
    },
    Permute {
        src: Arc<UOp>,
        axes: Vec<usize>,
    },
    Expand {
        src: Arc<UOp>,
        new_shape: Arc<UOp>,
    },
    Pad {
        src: Arc<UOp>,
        begin_pads: Arc<UOp>,
        end_pads: Arc<UOp>,
    },
    Shrink {
        src: Arc<UOp>,
        begins: Arc<UOp>,
        ends: Arc<UOp>,
    },
    Flip {
        src: Arc<UOp>,
        axes: Vec<bool>,
    },
    Multi {
        src: Arc<UOp>,
        axis: usize,
    },

    // Reduction operations (3 variants)
    ReduceAxis {
        src: Arc<UOp>,
        reduce_op: ReduceOp,
        axes: Vec<usize>,
    },
    Reduce {
        src: Arc<UOp>,
        ranges: SmallVec<[Arc<UOp>; 4]>,
        reduce_op: ReduceOp,
    },
    AllReduce {
        src: Arc<UOp>,
        device: Arc<UOp>,
        reduce_op: ReduceOp,
    },

    // Control flow operations (5 variants)
    If {
        condition: Arc<UOp>,
        body: SmallVec<[Arc<UOp>; 4]>,
    },
    EndIf {
        if_op: Arc<UOp>,
    },
    Range {
        end: Arc<UOp>,
        axis_id: AxisId,
        axis_type: AxisType,
        /// Inline capacity is 2 here, unlike the 4 used by the other operand lists.
        deps: SmallVec<[Arc<UOp>; 2]>,
    },
    End {
        computation: Arc<UOp>,
        ranges: SmallVec<[Arc<UOp>; 4]>,
    },
    Barrier {
        src: Arc<UOp>,
        deps: SmallVec<[Arc<UOp>; 4]>,
    },

    // Vector operations (5 variants)
    Vectorize {
        elements: SmallVec<[Arc<UOp>; 4]>,
    },
    Gep {
        vector: Arc<UOp>,
        indices: Vec<usize>,
    },
    VConst {
        values: Vec<ConstValue>,
    },
    /// Concatenate vectors into larger vector (expander op).
    /// Like VECTORIZE but sources can be vectors themselves.
    /// Output vcount = sum of all input vcounts.
    Cat {
        sources: SmallVec<[Arc<UOp>; 4]>,
    },
    /// Concatenate pointers into vectorized pointer (expander op).
    /// Used for grouping memory accesses in devectorizer.
    PtrCat {
        sources: SmallVec<[Arc<UOp>; 4]>,
    },

    // Symbolic/Define operations (3 variants)
    DefineVar {
        name: String,
        min_val: i64,
        max_val: i64,
    },
    Bind {
        var: Arc<UOp>,
        value: Arc<UOp>,
    },
    DefineReg {
        size: usize,
        /// Unique accumulator ID for disambiguation.
        /// Without this, two same-dtype reduces would share one DefineReg via hash consing.
        id: usize,
    },

    // Advanced operations (12 variants)
    Wmma {
        a: Arc<UOp>,
        b: Arc<UOp>,
        c: Arc<UOp>,
        metadata: WmmaMetadata,
    },
    Contract {
        src: Arc<UOp>,
        upcast_ranges: Vec<(usize, usize)>,
    },
    Unroll {
        src: Arc<UOp>,
        unroll_axes: Vec<(usize, usize)>,
    },
    /// Children are reported as `sources` first, then `ast`.
    Kernel {
        sources: SmallVec<[Arc<UOp>; 4]>,
        ast: Arc<UOp>,
    },
    Assign {
        target: Arc<UOp>,
        value: Arc<UOp>,
        /// Movement ops chain for shape tracking (third source in Tinygrad).
        /// This is a UOp chain where each node is a movement op, and walking
        /// via src[0] reaches the base INDEX operation. Used during
        /// bufferize_to_store to apply the same transformations to the result buffer.
        movement_ops: Option<Arc<UOp>>,
    },
    Detach {
        src: Arc<UOp>,
    },
    Contiguous {
        src: Arc<UOp>,
        /// Optimization hints (Tinygrad: CONTIGUOUS.arg)
        opts: SmallVec<[crate::types::ContiguousHint; 4]>,
    },
    ContiguousBackward {
        src: Arc<UOp>,
    },
    After {
        passthrough: Arc<UOp>,
        deps: SmallVec<[Arc<UOp>; 4]>,
    },
    Precast {
        src: Arc<UOp>,
    },
    Custom {
        deps: SmallVec<[Arc<UOp>; 4]>,
        code: String,
    },
    CustomI {
        deps: SmallVec<[Arc<UOp>; 4]>,
        code: String,
    },

    // Memory operations (low-level, after kernel splitting, 2 variants)
    // Gate is on INDEX, not LOAD/STORE (following Tinygrad's model)
    /// Load from buffer at index.
    ///
    /// - `buffer`: The buffer to load from
    /// - `index`: The INDEX operation specifying where to load (may be gated)
    /// - `alt`: Optional alternative value for gated loads (used when gate is false)
    ///
    /// When `alt` is Some, the load behaves as: `if gate { load(index) } else { alt }`.
    /// This is used for masked loads in image processing and padding scenarios.
    Load {
        buffer: Arc<UOp>,
        index: Arc<UOp>,
        alt: Option<Arc<UOp>>,
    },
    /// Store `value` at `index`.
    ///
    /// Children are reported as `index`, `value`, then `ranges`, which is why
    /// `Op::range_ending_src_index` reports 2 for this variant.
    Store {
        index: Arc<UOp>,
        value: Arc<UOp>,
        ranges: SmallVec<[Arc<UOp>; 4]>,
    },
}
306
307impl Op {
308 /// Get all child UOps as a Vec of references.
309 ///
310 /// This is the convenient API for traversing the graph.
311 /// Allocates a Vec but is simple to use.
312 pub fn children(&self) -> SmallVec<[&Arc<UOp>; 4]> {
313 match self {
314 // Nullary operations
315 Self::Const(_)
316 | Self::Unique(_)
317 | Self::Device(_)
318 | Self::Noop
319 | Self::Invalid
320 | Self::DefineLocal(_)
321 | Self::VConst { .. }
322 | Self::DefineVar { .. }
323 | Self::DefineReg { .. } => SmallVec::new(),
324
325 // Param has optional device child — pre-kernel PARAMs have device, codegen PARAMs don't
326 Self::Param { device: Some(d), .. } => SmallVec::from_slice(&[d]),
327 Self::Param { device: None, .. } => SmallVec::new(),
328
329 // Graph organization operations
330 Self::Sink { sources } | Self::Group { sources } => sources.iter().collect(),
331
332 // Grouped operations
333 Self::Unary(_, x) => SmallVec::from_slice(&[x]),
334 Self::Binary(_, a, b) => SmallVec::from_slice(&[a, b]),
335 Self::Ternary(_, a, b, c) => SmallVec::from_slice(&[a, b, c]),
336
337 // Type operations
338 Self::Cast { src, .. } | Self::BitCast { src, .. } => SmallVec::from_slice(&[src]),
339
340 // Special operations
341 Self::MSelect { buffer, .. } => SmallVec::from_slice(&[buffer]),
342 Self::Special { end, .. } => SmallVec::from_slice(&[end]),
343
344 // Buffer operations
345 Self::Buffer { unique, device, .. } => SmallVec::from_slice(&[unique, device]),
346 Self::BufferView { buffer, .. } => SmallVec::from_slice(&[buffer]),
347 Self::Bufferize { compute, ranges, .. } => {
348 let mut children = SmallVec::from_slice(&[compute]);
349 children.extend(ranges.iter());
350 children
351 }
352 Self::Index { buffer, indices, gate } => {
353 let mut children = SmallVec::from_slice(&[buffer]);
354 children.extend(indices.iter());
355 children.extend(gate);
356 children
357 }
358 Self::PointerIndex { ptr, offset } => SmallVec::from_slice(&[ptr, offset]),
359 Self::Copy { src, device } => SmallVec::from_slice(&[src, device]),
360 Self::MStack { buffers } => buffers.iter().collect(),
361
362 // Movement operations
363 Self::Reshape { src, new_shape } => SmallVec::from_slice(&[src, new_shape]),
364 Self::Permute { src, .. } | Self::Flip { src, .. } | Self::Multi { src, .. } => {
365 SmallVec::from_slice(&[src])
366 }
367 Self::Expand { src, new_shape } => SmallVec::from_slice(&[src, new_shape]),
368 Self::Pad { src, begin_pads, end_pads } => SmallVec::from_slice(&[src, begin_pads, end_pads]),
369 Self::Shrink { src, begins, ends } => SmallVec::from_slice(&[src, begins, ends]),
370
371 // Reduction operations
372 Self::ReduceAxis { src, .. } => SmallVec::from_slice(&[src]),
373 Self::Reduce { src, ranges, .. } => {
374 let mut children = SmallVec::from_slice(&[src]);
375 children.extend(ranges.iter());
376 children
377 }
378 Self::AllReduce { src, device, .. } => SmallVec::from_slice(&[src, device]),
379
380 // Control flow operations
381 Self::If { condition, body } => {
382 let mut children = SmallVec::from_slice(&[condition]);
383 children.extend(body.iter());
384 children
385 }
386 Self::EndIf { if_op } => SmallVec::from_slice(&[if_op]),
387 Self::Range { end, deps, .. } => {
388 let mut children = SmallVec::from_slice(&[end]);
389 children.extend(deps.iter());
390 children
391 }
392 Self::End { computation, ranges } => {
393 let mut children = SmallVec::from_slice(&[computation]);
394 children.extend(ranges.iter());
395 children
396 }
397 Self::Barrier { src, deps } => {
398 let mut children = SmallVec::from_slice(&[src]);
399 children.extend(deps.iter());
400 children
401 }
402
403 // Vector operations
404 Self::Vectorize { elements } => elements.iter().collect(),
405 Self::Gep { vector, .. } => SmallVec::from_slice(&[vector]),
406 Self::Cat { sources } | Self::PtrCat { sources } => sources.iter().collect(),
407
408 // Symbolic/Define operations
409 Self::Bind { var, value } => SmallVec::from_slice(&[var, value]),
410
411 // Advanced operations
412 Self::Wmma { a, b, c, .. } => SmallVec::from_slice(&[a, b, c]),
413 Self::Contract { src, .. }
414 | Self::Unroll { src, .. }
415 | Self::Detach { src }
416 | Self::Contiguous { src, .. }
417 | Self::ContiguousBackward { src }
418 | Self::Precast { src } => SmallVec::from_slice(&[src]),
419 Self::Kernel { sources, ast } => {
420 let mut children: SmallVec<[&Arc<UOp>; 4]> = sources.iter().collect();
421 children.push(ast);
422 children
423 }
424 Self::Assign { target, value, movement_ops } => {
425 let mut children = SmallVec::from_slice(&[target, value]);
426 if let Some(mops) = movement_ops {
427 children.push(mops);
428 }
429 children
430 }
431 Self::After { passthrough, deps } => {
432 let mut children = SmallVec::from_slice(&[passthrough]);
433 children.extend(deps.iter());
434 children
435 }
436 Self::Custom { deps, .. } | Self::CustomI { deps, .. } => deps.iter().collect(),
437
438 // Memory operations
439 Self::Load { buffer, index, alt } => {
440 let mut children = SmallVec::from_slice(&[buffer, index]);
441 children.extend(alt);
442 children
443 }
444 Self::Store { index, value, ranges } => {
445 let mut children = SmallVec::from_slice(&[index, value]);
446 children.extend(ranges.iter());
447 children
448 }
449 }
450 }
451
452 /// Get all child UOps as a Vec of owned Rcs (cloned).
453 ///
454 /// Similar to `children()` but returns owned Rcs instead of references.
455 /// Useful when you need to reconstruct nodes or store sources.
456 pub fn sources(&self) -> SmallVec<[Arc<UOp>; 4]> {
457 self.children().iter().map(|rc| (*rc).clone()).collect()
458 }
459
460 /// Apply a function to each child UOp.
461 pub fn map_child<F>(&self, mut f: F)
462 where
463 F: FnMut(&Arc<UOp>),
464 {
465 for child in self.children() {
466 f(child);
467 }
468 }
469
470 /// Check if this operation is a movement operation.
471 ///
472 /// Movement operations transform tensor shapes without changing data values:
473 /// - RESHAPE: Change shape with same number of elements
474 /// - PERMUTE: Transpose/reorder axes
475 /// - EXPAND: Broadcast to larger shape
476 /// - PAD: Add padding around tensor
477 /// - SHRINK: Extract sub-region
478 /// - FLIP: Reverse along axes
479 ///
480 /// Note: MULTI is not considered a pure movement op as it has different semantics.
481 pub fn is_movement(&self) -> bool {
482 matches!(
483 self,
484 Self::Reshape { .. }
485 | Self::Permute { .. }
486 | Self::Expand { .. }
487 | Self::Pad { .. }
488 | Self::Shrink { .. }
489 | Self::Flip { .. }
490 )
491 }
492
493 /// Get the source index where ranges start being "ended" by this operation.
494 ///
495 /// Based on Tinygrad's `range_start` dict (ops.py:28).
496 /// Returns `Some(index)` if this operation ends ranges, `None` otherwise.
497 ///
498 /// Operations that end ranges:
499 /// - BUFFERIZE: ranges start at index 1 (compute is 0, ranges are 1+)
500 /// - REDUCE: ranges start at index 1 (src is 0, ranges are 1+)
501 /// - STORE: ranges start at index 2 (index=0, value=1, ranges=2+)
502 /// - WMMA: ranges start at index 3 (a=0, b=1, c=2)
503 /// - END: ranges start at index 1 (computation=0, ranges=1+)
504 ///
505 /// These operations mark range boundaries in the computation graph.
506 /// Any RANGE operations in sources at or after the returned index
507 /// are considered "ended" and removed from scope.
508 ///
509 /// # Examples
510 ///
511 /// ```ignore
512 /// use morok_ir::Op;
513 ///
514 /// // BUFFERIZE ends ranges starting at source index 1
515 /// let bufferize_op = Op::Bufferize { /* ... */ };
516 /// assert_eq!(bufferize_op.range_ending_src_index(), Some(1));
517 ///
518 /// // Regular arithmetic operations don't end ranges
519 /// let binary_op = Op::Binary(/* ... */);
520 /// assert_eq!(binary_op.range_ending_src_index(), None);
521 /// ```
522 pub fn range_ending_src_index(&self) -> Option<usize> {
523 // Source layout for range-ending ops:
524 // - BUFFERIZE: compute=0, ranges=1+
525 // - REDUCE: src=0, ranges=1+
526 // - STORE: index=0, value=1, ranges=2+
527 // - WMMA: a=0, b=1, c=2, (ranges start at 3)
528 // - END: computation=0, ranges=1+
529 match self {
530 Self::Bufferize { .. } => Some(1),
531 Self::Reduce { .. } => Some(1),
532 Self::Store { .. } => Some(2),
533 Self::Wmma { .. } => Some(3),
534 Self::End { .. } => Some(1),
535 _ => None,
536 }
537 }
538
539 /// Check if this operation should be expanded when it has UNROLL inputs.
540 ///
541 /// Based on Tinygrad's expander.py:97-98 pattern which expands:
542 /// - ALU ops (Unary, Binary, Ternary)
543 /// - Type ops (Cast, BitCast)
544 /// - Vector ops (Gep, Vectorize)
545 /// - Tensor core ops (Wmma)
546 /// - Memory ops (Load, Store, Index)
547 /// - Buffer ops (Bufferize)
548 /// - Control flow (Reduce, End, After)
549 ///
550 /// These operations propagate vectorization through the computation graph
551 /// when any of their sources is an UNROLL operation.
552 pub fn is_expandable(&self) -> bool {
553 matches!(
554 self,
555 // ALU operations
556 Self::Unary(..) | Self::Binary(..) | Self::Ternary(..) |
557 // Type operations
558 Self::Cast { .. } | Self::BitCast { .. } |
559 // Vector operations
560 Self::Gep { .. } | Self::Vectorize { .. } |
561 // Tensor core
562 Self::Wmma { .. } |
563 // Memory operations
564 Self::Load { .. } | Self::Store { .. } |
565 Self::Index { .. } | Self::PointerIndex { .. } |
566 // Buffer operations
567 Self::Bufferize { .. } |
568 // Control flow (range-ending ops)
569 Self::Reduce { .. } | Self::End { .. } | Self::After { .. }
570 )
571 }
572
573 /// Get the "ended ranges" for this operation.
574 ///
575 /// These are the RANGE operations (and operations containing ranges)
576 /// that should be removed from scope after this operation.
577 ///
578 /// Based on Tinygrad's `ended_ranges` property (ops.py:296-299).
579 ///
580 /// # Returns
581 ///
582 /// A SmallVec of references to child UOps that represent ended ranges.
583 /// For operations that don't end ranges, returns an empty SmallVec.
584 ///
585 /// # Examples
586 ///
587 /// ```ignore
588 /// use morok_ir::{Op, UOp};
589 ///
590 /// // END operation ends its range arguments
591 /// let range = UOp::range(/* ... */);
592 /// let computation = UOp::const_(/* ... */);
593 /// let end_op = computation.end(vec![range.clone()]);
594 ///
595 /// // ended_ranges() returns the ranges that are closed
596 /// let ended = end_op.op().ended_ranges();
597 /// assert_eq!(ended.len(), 1);
598 /// ```
599 pub fn ended_ranges(&self) -> SmallVec<[&Arc<UOp>; 4]> {
600 if let Some(start_idx) = self.range_ending_src_index() {
601 let children = self.children();
602 children.into_iter().skip(start_idx).collect()
603 } else if let Self::After { deps, .. } = self {
604 // Tinygrad (ops.py:312): flatten([x.ended_ranges for x in self.src[1:]])
605 // AFTER propagates ended ranges from its dependency chain.
606 let mut result = SmallVec::new();
607 for dep in deps {
608 result.extend(dep.op().ended_ranges());
609 }
610 result
611 } else if matches!(self, Self::Copy { .. } | Self::BufferView { .. }) {
612 // Tinygrad (ops.py:314): return self.src[0].ranges
613 // COPY/BUFFER_VIEW ends all ranges from the source.
614 // We return the source itself (not individual ranges) — the
615 // InScopeRangesProperty handles the non-RANGE branch by looking
616 // up the ended node's in_scope_ranges and removing them all.
617 let children = self.children();
618 if children.is_empty() { SmallVec::new() } else { SmallVec::from_elem(children[0], 1) }
619 } else {
620 SmallVec::new()
621 }
622 }
623}