morok_ir/uop/constructors/control.rs
//! Control flow: range, if/end, barrier, symbolic variables.
//!
//! This module contains control flow operations:
//! - Loop constructs: range_axis, range, range_const, range_outer_const
//! - Conditionals: if_, endif, end
//! - Synchronization: barrier
//! - Symbolic variables: var, define_var, bind
//! - Special: special (GPU dimension index)
9
10use std::sync::Arc;
11
12use morok_dtype::DType;
13use smallvec::SmallVec;
14
15use crate::op::Op;
16use crate::types::{AxisId, AxisType, ConstValue};
17use crate::uop::UOp;
18
19impl UOp {
20 // =========================================================================
21 // Range Operations
22 // =========================================================================
23
24 /// Create a Range operation with specified axis type.
25 pub fn range_axis(end: Arc<Self>, axis_id: AxisId, axis_type: AxisType) -> Arc<Self> {
26 Self::new(Op::Range { end, axis_id, axis_type, deps: SmallVec::new() }, DType::Index)
27 }
28
29 /// Create a RANGE operation with Loop axis type (convenience for tests).
30 ///
31 /// Uses `AxisId::Renumbered` since tests typically work with renumbered kernels.
32 pub fn range(end: Arc<Self>, axis_id: usize) -> Arc<Self> {
33 Self::range_axis(end, AxisId::Renumbered(axis_id), AxisType::Loop)
34 }
35
36 /// Create a RANGE operation with constant end value (convenience for tests).
37 ///
38 /// Uses `AxisId::Renumbered` since tests typically work with renumbered kernels.
39 /// Creates a `Loop` range (inside kernels).
40 pub fn range_const(end_value: i64, axis_id: usize) -> Arc<Self> {
41 let end = Self::const_(DType::Index, ConstValue::Int(end_value));
42 Self::range_axis(end, AxisId::Renumbered(axis_id), AxisType::Loop)
43 }
44
45 /// Create an OUTER RANGE operation with constant end value (convenience for tests).
46 ///
47 /// Uses `AxisId::Renumbered` since tests typically work with renumbered kernels.
48 /// Creates an `Outer` range (wraps entire kernels).
49 pub fn range_outer_const(end_value: i64, axis_id: usize) -> Arc<Self> {
50 let end = Self::const_(DType::Index, ConstValue::Int(end_value));
51 Self::range_axis(end, AxisId::Renumbered(axis_id), AxisType::Outer)
52 }
53
54 // =========================================================================
55 // Conditional Operations
56 // =========================================================================
57
58 /// Create a conditional block that executes body when condition is true.
59 ///
60 /// Body contains operations to execute; use `endif` to close the block.
61 pub fn if_(condition: Arc<Self>, body: SmallVec<[Arc<Self>; 4]>) -> Arc<Self> {
62 Self::new(Op::If { condition, body }, DType::Void)
63 }
64
65 /// End if block.
66 pub fn endif(if_op: Arc<Self>) -> Arc<Self> {
67 Self::new(Op::EndIf { if_op }, DType::Void)
68 }
69
70 /// End of range or reduce scope.
71 ///
72 /// Wraps self (the computation) and closes the specified ranges.
73 /// This marks the end of RANGE or REDUCE loops.
74 ///
75 /// # Arguments
76 ///
77 /// * `ranges` - The RANGE or REDUCE operations being closed
78 pub fn end(self: &Arc<Self>, ranges: SmallVec<[Arc<Self>; 4]>) -> Arc<Self> {
79 if ranges.is_empty() {
80 return self.clone();
81 }
82 Self::new(Op::End { computation: self.clone(), ranges }, DType::Void)
83 }
84
85 // =========================================================================
86 // Synchronization
87 // =========================================================================
88
89 /// Insert a synchronization barrier.
90 ///
91 /// Self passes through; `deps` are operations that must complete before
92 /// any consumer of this barrier executes.
93 pub fn barrier(self: &Arc<Self>, deps: SmallVec<[Arc<Self>; 4]>) -> Arc<Self> {
94 let dtype = self.dtype();
95 Self::new(Op::Barrier { src: self.clone(), deps }, dtype)
96 }
97
98 // =========================================================================
99 // Symbolic Variables
100 // =========================================================================
101
102 /// Create a DefineVar operation for range-bounded variables.
103 ///
104 /// Used in testing and symbolic analysis to define variables with known ranges.
105 /// Range is [min_val, max_val] inclusive.
106 pub fn var(name: impl Into<String>, dtype: DType, min_val: i64, max_val: i64) -> Arc<Self> {
107 Self::new(Op::DefineVar { name: name.into(), min_val, max_val }, dtype)
108 }
109
110 /// Define a symbolic variable with known bounds for range analysis.
111 ///
112 /// Range is [min_val, max_val] inclusive.
113 pub fn define_var(name: String, min_val: i64, max_val: i64) -> Arc<Self> {
114 Self::new(Op::DefineVar { name, min_val, max_val }, DType::Index)
115 }
116
117 /// Bind concrete value to symbolic variable.
118 pub fn bind(self: &Arc<Self>, value: Arc<Self>) -> Arc<Self> {
119 let dtype = self.dtype();
120 Self::new(Op::Bind { var: self.clone(), value }, dtype)
121 }
122
123 // =========================================================================
124 // Special Operations
125 // =========================================================================
126
127 /// Create a GPU-specific dimension variable (e.g., blockIdx.x, threadIdx.y).
128 ///
129 /// Unlike RANGE which is a loop, SPECIAL represents hardware-provided indices.
130 /// The `name` identifies the dimension (rendered as-is in codegen).
131 pub fn special(end: Arc<Self>, name: String) -> Arc<Self> {
132 Self::new(Op::Special { end, name }, DType::Index)
133 }
134}