Skip to main content

morok_ir/uop/constructors/
control.rs

//! Control flow: ranges, if/endif, scope ends, barriers, symbolic variables.
//!
//! This module contains control flow operations:
//! - Loop constructs: range, range_const, range_outer_const, range_axis
//! - Scope closing: end (closes RANGE/REDUCE scopes)
//! - Conditionals: if_, endif
//! - Synchronization: barrier
//! - Symbolic variables: var, define_var, bind
//! - Special: special (GPU dimension index)
9
10use std::sync::Arc;
11
12use morok_dtype::DType;
13use smallvec::SmallVec;
14
15use crate::op::Op;
16use crate::types::{AxisId, AxisType, ConstValue};
17use crate::uop::UOp;
18
19impl UOp {
20    // =========================================================================
21    // Range Operations
22    // =========================================================================
23
24    /// Create a Range operation with specified axis type.
25    pub fn range_axis(end: Arc<Self>, axis_id: AxisId, axis_type: AxisType) -> Arc<Self> {
26        Self::new(Op::Range { end, axis_id, axis_type, deps: SmallVec::new() }, DType::Index)
27    }
28
29    /// Create a RANGE operation with Loop axis type (convenience for tests).
30    ///
31    /// Uses `AxisId::Renumbered` since tests typically work with renumbered kernels.
32    pub fn range(end: Arc<Self>, axis_id: usize) -> Arc<Self> {
33        Self::range_axis(end, AxisId::Renumbered(axis_id), AxisType::Loop)
34    }
35
36    /// Create a RANGE operation with constant end value (convenience for tests).
37    ///
38    /// Uses `AxisId::Renumbered` since tests typically work with renumbered kernels.
39    /// Creates a `Loop` range (inside kernels).
40    pub fn range_const(end_value: i64, axis_id: usize) -> Arc<Self> {
41        let end = Self::const_(DType::Index, ConstValue::Int(end_value));
42        Self::range_axis(end, AxisId::Renumbered(axis_id), AxisType::Loop)
43    }
44
45    /// Create an OUTER RANGE operation with constant end value (convenience for tests).
46    ///
47    /// Uses `AxisId::Renumbered` since tests typically work with renumbered kernels.
48    /// Creates an `Outer` range (wraps entire kernels).
49    pub fn range_outer_const(end_value: i64, axis_id: usize) -> Arc<Self> {
50        let end = Self::const_(DType::Index, ConstValue::Int(end_value));
51        Self::range_axis(end, AxisId::Renumbered(axis_id), AxisType::Outer)
52    }
53
54    // =========================================================================
55    // Conditional Operations
56    // =========================================================================
57
58    /// Create a conditional block that executes body when condition is true.
59    ///
60    /// Body contains operations to execute; use `endif` to close the block.
61    pub fn if_(condition: Arc<Self>, body: SmallVec<[Arc<Self>; 4]>) -> Arc<Self> {
62        Self::new(Op::If { condition, body }, DType::Void)
63    }
64
65    /// End if block.
66    pub fn endif(if_op: Arc<Self>) -> Arc<Self> {
67        Self::new(Op::EndIf { if_op }, DType::Void)
68    }
69
70    /// End of range or reduce scope.
71    ///
72    /// Wraps self (the computation) and closes the specified ranges.
73    /// This marks the end of RANGE or REDUCE loops.
74    ///
75    /// # Arguments
76    ///
77    /// * `ranges` - The RANGE or REDUCE operations being closed
78    pub fn end(self: &Arc<Self>, ranges: SmallVec<[Arc<Self>; 4]>) -> Arc<Self> {
79        if ranges.is_empty() {
80            return self.clone();
81        }
82        Self::new(Op::End { computation: self.clone(), ranges }, DType::Void)
83    }
84
85    // =========================================================================
86    // Synchronization
87    // =========================================================================
88
89    /// Insert a synchronization barrier.
90    ///
91    /// Self passes through; `deps` are operations that must complete before
92    /// any consumer of this barrier executes.
93    pub fn barrier(self: &Arc<Self>, deps: SmallVec<[Arc<Self>; 4]>) -> Arc<Self> {
94        let dtype = self.dtype();
95        Self::new(Op::Barrier { src: self.clone(), deps }, dtype)
96    }
97
98    // =========================================================================
99    // Symbolic Variables
100    // =========================================================================
101
102    /// Create a DefineVar operation for range-bounded variables.
103    ///
104    /// Used in testing and symbolic analysis to define variables with known ranges.
105    /// Range is [min_val, max_val] inclusive.
106    pub fn var(name: impl Into<String>, dtype: DType, min_val: i64, max_val: i64) -> Arc<Self> {
107        Self::new(Op::DefineVar { name: name.into(), min_val, max_val }, dtype)
108    }
109
110    /// Define a symbolic variable with known bounds for range analysis.
111    ///
112    /// Range is [min_val, max_val] inclusive.
113    pub fn define_var(name: String, min_val: i64, max_val: i64) -> Arc<Self> {
114        Self::new(Op::DefineVar { name, min_val, max_val }, DType::Index)
115    }
116
117    /// Bind concrete value to symbolic variable.
118    pub fn bind(self: &Arc<Self>, value: Arc<Self>) -> Arc<Self> {
119        let dtype = self.dtype();
120        Self::new(Op::Bind { var: self.clone(), value }, dtype)
121    }
122
123    // =========================================================================
124    // Special Operations
125    // =========================================================================
126
127    /// Create a GPU-specific dimension variable (e.g., blockIdx.x, threadIdx.y).
128    ///
129    /// Unlike RANGE which is a loop, SPECIAL represents hardware-provided indices.
130    /// The `name` identifies the dimension (rendered as-is in codegen).
131    pub fn special(end: Arc<Self>, name: String) -> Arc<Self> {
132        Self::new(Op::Special { end, name }, DType::Index)
133    }
134}