// morok_ir/uop/constructors/compute.rs

//! Mathematical operations: arithmetic, transcendental, bitwise, comparison.
//!
//! This module contains all computational operations:
//! - Arithmetic: add, sub, mul, div, mod, pow, max, neg, abs, square, sign
//! - Transcendental: sqrt, rsqrt, exp, exp2, log, log2, sin, cos, tan, erf, reciprocal
//! - Rounding: trunc, floor, ceil, round
//! - Bitwise: and, or, xor, shl, shr, not
//! - Comparison: lt, le, eq, ne, gt, ge
//! - Ternary: where, mulacc
//! - Random: threefry
//! - Scalar convenience: try_add_scalar, try_sub_scalar, try_mul_scalar, try_mod_scalar
12
13use std::sync::Arc;
14
15use morok_dtype::DType;
16use snafu::ensure;
17
18use crate::error::{InvalidDTypeForUnaryOpSnafu, WhereConditionNotBoolSnafu};
19use crate::op::Op;
20use crate::types::{BinaryOp, TernaryOp, UnaryOp};
21use crate::uop::UOp;
22use crate::{IntoUOp, Result};
23
24// =========================================================================
25// Macro Definitions
26// =========================================================================
27
/// Generates fallible binary arithmetic constructors.
///
/// Every `name => Variant` pair expands to a method that runs both operands through
/// `promote_and_cast`, validates their shapes for `BinaryOp::Variant`, and builds an
/// `Op::Binary` node typed with the dtype returned by the promotion step.
macro_rules! binary_arith_ops {
    ($($name:ident => $variant:ident),+ $(,)?) => {
        $(
            #[track_caller]
            pub fn $name(self: &Arc<Self>, rhs: &Arc<Self>) -> Result<Arc<Self>> {
                let (left, right, promoted) =
                    Self::promote_and_cast(Arc::clone(self), Arc::clone(rhs))?;
                Self::validate_binary_shapes(&left, &right, BinaryOp::$variant)?;
                let node = Op::Binary(BinaryOp::$variant, left, right);
                Ok(Self::new(node, promoted))
            }
        )+
    };
}
41
/// Generates division-like constructors.
///
/// Identical to the plain arithmetic constructors except that the divisor is passed
/// to `check_division_by_zero` before any promotion or shape validation happens.
macro_rules! division_ops {
    ($($name:ident => $variant:ident),+ $(,)?) => {
        $(
            #[track_caller]
            pub fn $name(self: &Arc<Self>, rhs: &Arc<Self>) -> Result<Arc<Self>> {
                // Reject a bad divisor first, before touching the operands.
                Self::check_division_by_zero(rhs)?;

                let (dividend, divisor, promoted) =
                    Self::promote_and_cast(Arc::clone(self), Arc::clone(rhs))?;
                Self::validate_binary_shapes(&dividend, &divisor, BinaryOp::$variant)?;
                let node = Op::Binary(BinaryOp::$variant, dividend, divisor);
                Ok(Self::new(node, promoted))
            }
        )+
    };
}
56
/// Macro for bitwise binary operations with type promotion and dtype validation.
///
/// Each generated method promotes both operands via `promote_and_cast`, checks the
/// promoted dtype with `check_bitwise_dtype`, validates operand shapes, and builds
/// the `Op::Binary` node with the promoted dtype.
///
/// The generated methods are `#[track_caller]`, like the other constructor macros in
/// this file, so the `#[track_caller]` panicking wrappers that call them keep
/// forwarding the caller location.
macro_rules! bitwise_binary_ops {
    ($($method:ident => $op:ident),+ $(,)?) => {
        $(
            #[track_caller]
            pub fn $method(self: &Arc<Self>, rhs: &Arc<Self>) -> Result<Arc<Self>> {
                let (lhs, rhs, dtype) = Self::promote_and_cast(self.clone(), rhs.clone())?;
                // The dtype check runs on the promoted dtype, not on the raw operands.
                Self::check_bitwise_dtype(dtype.clone(), BinaryOp::$op)?;
                Self::validate_binary_shapes(&lhs, &rhs, BinaryOp::$op)?;
                Ok(Self::new(Op::Binary(BinaryOp::$op, lhs, rhs), dtype))
            }
        )+
    };
}
70
/// Macro for shift operations that only check LHS dtype.
///
/// Unlike the other binary constructors, no type promotion is applied: the result
/// takes the dtype of `self`, and the dtype of `rhs` is neither checked nor cast.
/// Operand shapes are still validated.
///
/// The generated methods are `#[track_caller]`, like the other constructor macros in
/// this file, so the `#[track_caller]` panicking wrappers that call them keep
/// forwarding the caller location.
macro_rules! shift_ops {
    ($($method:ident => $op:ident),+ $(,)?) => {
        $(
            #[track_caller]
            pub fn $method(self: &Arc<Self>, rhs: &Arc<Self>) -> Result<Arc<Self>> {
                let dtype = self.dtype();
                Self::check_bitwise_dtype(dtype.clone(), BinaryOp::$op)?;
                Self::validate_binary_shapes(self, rhs, BinaryOp::$op)?;
                Ok(Self::new(Op::Binary(BinaryOp::$op, self.clone(), rhs.clone()), dtype))
            }
        )+
    };
}
84
/// Generates comparison constructors.
///
/// Operands go through the same promotion and shape validation as arithmetic ops,
/// but the result dtype is boolean. Vector width is carried over:
/// <N x T> cmp <N x T> → <N x bool>.
macro_rules! cmp_ops {
    ($($name:ident => $variant:ident),+ $(,)?) => {
        $(
            #[track_caller]
            pub fn $name(self: &Arc<Self>, rhs: &Arc<Self>) -> Result<Arc<Self>> {
                // Promotion both validates the operand types and yields the common type.
                let (left, right, promoted) =
                    Self::promote_and_cast(Arc::clone(self), Arc::clone(rhs))?;
                Self::validate_binary_shapes(&left, &right, BinaryOp::$variant)?;

                // Scalar operands give a plain bool; vectors give a bool vector of equal width.
                let bool_dtype = match promoted.vcount() {
                    lanes if lanes > 1 => DType::Bool.vec(lanes),
                    _ => DType::Bool,
                };
                Ok(Self::new(Op::Binary(BinaryOp::$variant, left, right), bool_dtype))
            }
        )+
    };
}
103
/// Generates float-only unary constructors (transcendental functions).
///
/// Each generated method fails with `InvalidDTypeForUnaryOp` unless the operand's
/// dtype is a float; on success the node keeps the operand's dtype.
macro_rules! transcendental_ops {
    ($($name:ident => $variant:ident),+ $(,)?) => {
        $(
            #[track_caller]
            pub fn $name(self: &Arc<Self>) -> Result<Arc<Self>> {
                let dtype = self.dtype();
                ensure!(
                    dtype.is_float(),
                    InvalidDTypeForUnaryOpSnafu { operation: UnaryOp::$variant, dtype }
                );
                let node = Op::Unary(UnaryOp::$variant, Arc::clone(self));
                Ok(Self::new(node, dtype))
            }
        )+
    };
}
117
/// Macro for scalar convenience wrappers.
///
/// Each generated function converts `rhs` into a `UOp` using the dtype of `lhs`
/// (via `IntoUOp::into_uop`) and then forwards to the named binary method.
///
/// The wrappers are `#[track_caller]` because the methods they forward to are, so
/// the caller location keeps being forwarded through this extra hop.
macro_rules! scalar_ops {
    ($($method:ident => $op_method:ident),+ $(,)?) => {
        $(
            #[track_caller]
            pub fn $method<T: IntoUOp>(lhs: Arc<Self>, rhs: T) -> Result<Arc<Self>> {
                let rhs_uop = rhs.into_uop(lhs.dtype());
                lhs.$op_method(&rhs_uop)
            }
        )+
    };
}
129
130// =========================================================================
131// Panicking Wrapper Macro
132// =========================================================================
133
/// Generates infallible counterparts of the `try_*` binary constructors.
///
/// Intended for pattern rewrites, where operand types have already been validated.
/// Each wrapper unwraps the fallible method's result and panics if it is an error;
/// `#[track_caller]` is applied so the panic is attributed to the wrapper's caller.
macro_rules! panicking_binary_wrapper {
    ($($name:ident => $fallible:ident),+ $(,)?) => {
        $(
            #[doc = concat!("Panicking version of `", stringify!($fallible), "`.")]
            #[doc = ""]
            #[doc = "For use in pattern rewrites where types are validated."]
            #[doc = "Panics on type mismatch."]
            #[track_caller]
            pub fn $name(self: &Arc<Self>, rhs: &Arc<Self>) -> Arc<Self> {
                let outcome = self.$fallible(rhs);
                outcome.expect(concat!(stringify!($name), ": type mismatch"))
            }
        )+
    };
}
152
153impl UOp {
154    // =========================================================================
155    // Arithmetic Operations
156    // =========================================================================
157
158    binary_arith_ops! {
159        try_add => Add,
160        try_sub => Sub,
161        try_mul => Mul,
162    }
163
164    division_ops! {
165        try_mod => Mod,
166    }
167
168    /// Division with automatic type-based operator selection.
169    ///
170    /// Uses Idiv for integer types and Fdiv for float types.
171    #[track_caller]
172    pub fn try_div(self: &Arc<Self>, rhs: &Arc<Self>) -> Result<Arc<Self>> {
173        Self::check_division_by_zero(rhs)?;
174        let (lhs, rhs, dtype) = Self::promote_and_cast(self.clone(), rhs.clone())?;
175
176        // Choose division operator based on dtype
177        let op = if dtype.is_float() { BinaryOp::Fdiv } else { BinaryOp::Idiv };
178
179        Self::validate_binary_shapes(&lhs, &rhs, op)?;
180        Ok(Self::new(Op::Binary(op, lhs, rhs), dtype))
181    }
182
183    /// Maximum of two values: max(a, b).
184    pub fn try_max(self: &Arc<Self>, rhs: &Arc<Self>) -> Result<Arc<Self>> {
185        let (lhs, rhs, dtype) = Self::promote_and_cast(self.clone(), rhs.clone())?;
186        Self::validate_binary_shapes(&lhs, &rhs, BinaryOp::Max)?;
187        Ok(Self::new(Op::Binary(BinaryOp::Max, lhs, rhs), dtype))
188    }
189
190    /// Power: a^b.
191    pub fn try_pow(self: &Arc<Self>, rhs: &Arc<Self>) -> Result<Arc<Self>> {
192        let (lhs, rhs, dtype) = Self::promote_and_cast(self.clone(), rhs.clone())?;
193        Self::validate_binary_shapes(&lhs, &rhs, BinaryOp::Pow)?;
194        Ok(Self::new(Op::Binary(BinaryOp::Pow, lhs, rhs), dtype))
195    }
196
197    /// Negation: -x.
198    ///
199    /// Produces `MUL(x, -1)` instead of `Unary(Neg, x)`, matching Tinygrad's approach.
200    /// `Unary(Neg)` is reintroduced late in codegen decompositions (`pm_neg_from_mul`)
201    /// AFTER `pm_lower_index_dtype` has resolved all Invalid nodes. This ensures
202    /// `propagate_invalid` (which only handles Binary ops) can push WHERE+Invalid
203    /// through negation.
204    ///
205    /// If `self` has a shape, broadcasts -1 to match (RESHAPE+EXPAND), matching
206    /// Tinygrad's Tensor-level `_broadcasted()`. If shapeless (schedule/symbolic
207    /// context), uses a scalar const directly.
208    #[track_caller]
209    pub fn neg(self: &Arc<Self>) -> Arc<Self> {
210        // Tinygrad: logical_not for bool, MUL(-1) for everything else
211        if self.dtype.is_bool() {
212            return self.not();
213        }
214        use crate::types::ConstValue;
215        let dtype = self.dtype.clone();
216        // Use Int(-1) or Float(-1.0) and let const_() handle dtype cast (wraps for unsigned).
217        // Matches Tinygrad where Python's -1 is cast via dtypes.as_const(-1, dtype).
218        let neg_one = if dtype.is_float() { ConstValue::Float(-1.0) } else { ConstValue::Int(-1) };
219        let mut neg_one_uop = Self::const_(dtype.clone(), neg_one);
220
221        // Broadcast scalar -1 to match self's shape if present.
222        // Matches Tinygrad's _broadcasted: reshape to (1,)*ndim then expand to shape.
223        if let Ok(Some(shape)) = self.shape()
224            && !shape.is_empty()
225        {
226            use crate::sint::SInt;
227            use smallvec::SmallVec;
228            let ones: SmallVec<[SInt; 4]> = shape.iter().map(|_| SInt::from(1)).collect();
229            neg_one_uop = neg_one_uop.try_reshape(&ones).expect("neg: reshape failed");
230            neg_one_uop = neg_one_uop.try_expand(shape).expect("neg: expand failed");
231        }
232
233        self.mul(&neg_one_uop)
234    }
235
236    /// Absolute value: |x|.
237    #[track_caller]
238    pub fn abs(self: &Arc<Self>) -> Arc<Self> {
239        let dtype = self.dtype.clone();
240        Self::new(Op::Unary(UnaryOp::Abs, self.clone()), dtype)
241    }
242
243    /// Square: x².
244    #[track_caller]
245    pub fn square(self: &Arc<Self>) -> Arc<Self> {
246        let dtype = self.dtype();
247        Self::new(Op::Unary(UnaryOp::Square, self.clone()), dtype)
248    }
249
250    /// Sign: -1 for negative, 0 for zero, 1 for positive.
251    pub fn sign(self: &Arc<Self>) -> Arc<Self> {
252        let dtype = self.dtype();
253        Self::new(Op::Unary(UnaryOp::Sign, self.clone()), dtype)
254    }
255
256    // =========================================================================
257    // Scalar Convenience Methods
258    // =========================================================================
259
260    scalar_ops! {
261        try_add_scalar => try_add,
262        try_sub_scalar => try_sub,
263        try_mul_scalar => try_mul,
264        try_mod_scalar => try_mod,
265    }
266
267    // =========================================================================
268    // Transcendental Operations
269    // =========================================================================
270
271    transcendental_ops! {
272        try_sqrt => Sqrt,
273        try_rsqrt => Rsqrt,
274        try_exp => Exp,
275        try_exp2 => Exp2,
276        try_log => Log,
277        try_log2 => Log2,
278        try_sin => Sin,
279        try_cos => Cos,
280        try_tan => Tan,
281    }
282
283    /// Error function: erf(x) - requires float dtype.
284    #[track_caller]
285    pub fn erf(self: &Arc<Self>) -> Result<Arc<Self>> {
286        let dtype = self.dtype();
287        ensure!(dtype.is_float(), InvalidDTypeForUnaryOpSnafu { operation: UnaryOp::Erf, dtype });
288        Ok(Self::new(Op::Unary(UnaryOp::Erf, self.clone()), dtype))
289    }
290
291    /// Reciprocal: 1/x - requires float dtype.
292    #[track_caller]
293    pub fn try_reciprocal(operand: &Arc<Self>) -> Result<Arc<Self>> {
294        let dtype = operand.dtype();
295        ensure!(dtype.is_float(), InvalidDTypeForUnaryOpSnafu { operation: UnaryOp::Reciprocal, dtype });
296        Ok(Self::new(Op::Unary(UnaryOp::Reciprocal, operand.clone()), dtype))
297    }
298
299    // =========================================================================
300    // Rounding Operations
301    // =========================================================================
302
303    /// Truncate towards zero.
304    #[track_caller]
305    pub fn trunc(operand: Arc<Self>) -> Arc<Self> {
306        let dtype = operand.dtype();
307        Self::new(Op::Unary(UnaryOp::Trunc, operand), dtype)
308    }
309
310    /// Floor: round towards -∞.
311    #[track_caller]
312    pub fn floor(operand: Arc<Self>) -> Arc<Self> {
313        let dtype = operand.dtype();
314        Self::new(Op::Unary(UnaryOp::Floor, operand), dtype)
315    }
316
317    /// Ceiling: round towards +∞.
318    #[track_caller]
319    pub fn ceil(operand: Arc<Self>) -> Arc<Self> {
320        let dtype = operand.dtype();
321        Self::new(Op::Unary(UnaryOp::Ceil, operand), dtype)
322    }
323
324    /// Round: round to nearest integer (half to even).
325    pub fn round(operand: Arc<Self>) -> Arc<Self> {
326        let dtype = operand.dtype();
327        Self::new(Op::Unary(UnaryOp::Round, operand), dtype)
328    }
329
330    // =========================================================================
331    // Bitwise Operations
332    // =========================================================================
333
334    bitwise_binary_ops! {
335        try_and_op => And,
336        try_or_op => Or,
337        try_xor_op => Xor,
338    }
339
340    shift_ops! {
341        try_shl_op => Shl,
342        try_shr_op => Shr,
343    }
344
345    /// Logical not: !x.
346    #[track_caller]
347    pub fn not(self: &Arc<Self>) -> Arc<Self> {
348        let dtype = self.dtype.clone();
349        Self::new(Op::Unary(UnaryOp::Not, self.clone()), dtype)
350    }
351
352    // =========================================================================
353    // Comparison Operations
354    // =========================================================================
355
356    cmp_ops! {
357        try_cmplt => Lt,
358        try_cmple => Le,
359        try_cmpeq => Eq,
360        try_cmpne => Ne,
361        try_cmpgt => Gt,
362        try_cmpge => Ge,
363    }
364
365    // =========================================================================
366    // Ternary Operations
367    // =========================================================================
368
369    /// Conditional selection: condition ? true_val : false_val.
370    ///
371    /// # Errors
372    /// - `WhereConditionNotBool` if condition dtype is not bool
373    #[track_caller]
374    pub fn try_where(condition: Arc<Self>, true_val: Arc<Self>, false_val: Arc<Self>) -> Result<Arc<Self>> {
375        let cond_dtype = condition.dtype();
376        ensure!(cond_dtype.is_bool(), WhereConditionNotBoolSnafu { actual: cond_dtype });
377
378        // Determine result dtype from the non-INVALID branch.
379        // INVALID is always created with Index type but may appear in WHERE with
380        // different branch dtype after propagate_invalid pushes CAST/ALU through WHERE.
381        let dtype = if matches!(true_val.op, Op::Invalid) { false_val.dtype() } else { true_val.dtype() };
382        let true_val = if matches!(true_val.op, Op::Invalid) && true_val.dtype() != dtype {
383            Self::new(Op::Invalid, dtype.clone())
384        } else {
385            true_val
386        };
387        let false_val = if matches!(false_val.op, Op::Invalid) && false_val.dtype() != dtype {
388            Self::new(Op::Invalid, dtype.clone())
389        } else {
390            false_val
391        };
392        Self::validate_ternary_shapes(&true_val, &false_val)?;
393        Ok(Self::new(Op::Ternary(TernaryOp::Where, condition, true_val, false_val), dtype))
394    }
395
396    /// Multiply-accumulate: a * b + c (fused operation).
397    ///
398    /// All operands must have matching dtypes (including vcount) for valid codegen.
399    /// Returns None if vcounts don't match - caller should fall back to Add(Mul(a,b), c).
400    pub fn try_mulacc(a: Arc<Self>, b: Arc<Self>, c: Arc<Self>) -> Result<Arc<Self>> {
401        // Validate all operands have matching dtypes (including vcount) for valid fmuladd
402        if a.dtype() != b.dtype() || a.dtype() != c.dtype() {
403            return crate::error::MulAccDtypeMismatchSnafu {
404                a_dtype: a.dtype(),
405                b_dtype: b.dtype(),
406                c_dtype: c.dtype(),
407            }
408            .fail();
409        }
410        let dtype = a.dtype();
411        // Validate all three operands have matching shapes
412        Self::validate_ternary_shapes(&a, &b)?;
413        Self::validate_ternary_shapes(&a, &c)?;
414        Ok(Self::new(Op::Ternary(TernaryOp::MulAcc, a, b, c), dtype))
415    }
416
417    // =========================================================================
418    // Panicking Wrappers (for pattern rewrites)
419    // =========================================================================
420    //
421    // These methods are for use in pattern rewrites where types are already
422    // validated by the pattern matcher. They panic on type mismatch, which
423    // indicates a bug in the pattern rather than a user error.
424    //
425    // Note: Using trailing underscore for `and_`, `or_`, `mod_` to avoid
426    // Rust keyword conflicts.
427
428    panicking_binary_wrapper! {
429        // Arithmetic
430        add => try_add,
431        sub => try_sub,
432        mul => try_mul,
433        idiv => try_div,
434        mod_ => try_mod,
435        max => try_max,
436
437        // Bitwise
438        and_ => try_and_op,
439        or_ => try_or_op,
440        xor => try_xor_op,
441        shl => try_shl_op,
442        shr => try_shr_op,
443
444        // Comparison
445        lt => try_cmplt,
446        le => try_cmple,
447        gt => try_cmpgt,
448        ge => try_cmpge,
449        eq => try_cmpeq,
450        ne => try_cmpne,
451    }
452
453    /// Low-level binary op constructor that auto-selects result dtype.
454    ///
455    /// Comparisons produce Bool; everything else inherits `lhs` dtype.
456    /// Matches Tinygrad's `UOp.alu()`. No type promotion or validation —
457    /// use only in rewrites where types are already correct.
458    pub fn alu(op: BinaryOp, lhs: Arc<Self>, rhs: Arc<Self>) -> Arc<Self> {
459        let dtype = if op.is_comparison() { DType::Bool } else { lhs.dtype() };
460        Self::new(Op::Binary(op, lhs, rhs), dtype)
461    }
462
463    // =========================================================================
464    // Random Operations
465    // =========================================================================
466
467    /// Threefry PRNG: threefry(x, key).
468    pub fn threefry(lhs: Arc<Self>, rhs: Arc<Self>) -> Result<Arc<Self>> {
469        let dtype = DType::UInt64; // Threefry always returns uint64
470        Self::validate_binary_shapes(&lhs, &rhs, BinaryOp::Threefry)?;
471        Ok(Self::new(Op::Binary(BinaryOp::Threefry, lhs, rhs), dtype))
472    }
473}