Skip to main content

morok_ir/uop/
helpers.rs

1//! Helper methods for UOp pattern matching and simplification.
2//!
3//! These methods support symbolic pattern matching, based on Tinygrad's ops.py.
4
5use std::collections::{HashMap, HashSet};
6use std::sync::Arc;
7
8use crate::op::Op;
9use crate::types::{AxisType, BinaryOp, ConstValue};
10use crate::uop::UOp;
11
12impl UOp {
13    /// Returns the largest known integer that divides this UOp.
14    ///
15    /// Based on Tinygrad's `const_factor()` (ops.py:693-700).
16    /// For MUL, only checks immediate CONST children (not recursive).
17    pub fn const_factor(&self) -> i64 {
18        match &self.op {
19            Op::Const(cv) => match &cv.0 {
20                ConstValue::Int(i) => *i,
21                ConstValue::UInt(u) => *u as i64,
22                _ => 1,
23            },
24            // VCONST: GCD of all elements (Tinygrad ops.py:697)
25            Op::VConst { values } => values
26                .iter()
27                .filter_map(|v| match v {
28                    ConstValue::Int(i) => Some(*i),
29                    ConstValue::UInt(u) => Some(*u as i64),
30                    _ => None,
31                })
32                .map(|v| v.abs())
33                .reduce(gcd)
34                .unwrap_or(1),
35            // MUL: only immediate CONST child, matching Tinygrad exactly
36            Op::Binary(BinaryOp::Mul, a, b) => {
37                if let Op::Const(cv) = &a.op
38                    && let ConstValue::Int(i) = cv.0
39                {
40                    return i;
41                }
42                if let Op::Const(cv) = &b.op
43                    && let ConstValue::Int(i) = cv.0
44                {
45                    return i;
46                }
47                1
48            }
49            Op::Binary(BinaryOp::Add, a, b) => gcd(a.const_factor().abs(), b.const_factor().abs()),
50            _ => 1,
51        }
52    }
53
54    /// Returns `self / v` if `v` divides `self` exactly, otherwise None.
55    ///
56    /// Based on Tinygrad's `divides()` (ops.py lines 703-711).
57    /// Delegates to [`divides_int`] for constant divisors.
58    pub fn divides(self: &Arc<Self>, v: &Arc<Self>) -> Option<Arc<Self>> {
59        if let Op::Const(cv) = v.op()
60            && let ConstValue::Int(divisor) = cv.0
61        {
62            return self.divides_int(divisor);
63        }
64        None
65    }
66
67    /// Returns `self / v` if integer `v` divides all terms exactly, otherwise None.
68    ///
69    /// Based on Tinygrad's `divides(v: int)` (ops.py:701-709).
70    /// Recursively handles Const, Add, and Mul operations.
71    pub fn divides_int(self: &Arc<Self>, v: i64) -> Option<Arc<Self>> {
72        if v == 1 {
73            return Some(Arc::clone(self));
74        }
75        if v == 0 {
76            return None;
77        }
78        match self.op() {
79            Op::Const(cv) => {
80                let ConstValue::Int(val) = cv.0 else { return None };
81                if val % v == 0 { Some(Self::const_(self.dtype(), ConstValue::Int(val / v))) } else { None }
82            }
83            // VCONST: divide each element if all are divisible (Tinygrad ops.py:704)
84            Op::VConst { values } => {
85                let divided: Option<Vec<ConstValue>> = values
86                    .iter()
87                    .map(|val| match val {
88                        ConstValue::Int(i) if i % v == 0 => Some(ConstValue::Int(i / v)),
89                        _ => None,
90                    })
91                    .collect();
92                divided.map(|v| UOp::vconst(v, self.dtype().scalar_dtype()))
93            }
94            Op::Binary(BinaryOp::Add, a, b) => {
95                let d0 = a.divides_int(v)?;
96                let d1 = b.divides_int(v)?;
97                d0.try_add(&d1).ok()
98            }
99            Op::Binary(BinaryOp::Mul, a, b) => {
100                if let Some(d0) = a.divides_int(v) {
101                    return d0.try_mul(b).ok();
102                }
103                if let Some(d1) = b.divides_int(v) {
104                    return a.try_mul(&d1).ok();
105                }
106                None
107            }
108            _ => None,
109        }
110    }
111
112    /// Returns `self / v` if exact division by UOp `v` is possible.
113    ///
114    /// Based on Tinygrad's `divide_exact(v: UOp)` (ops.py:717-726).
115    /// Handles identity, constant divisors, Add recursion, and Mul factoring.
116    pub fn divide_exact(self: &Arc<Self>, v: &Arc<Self>) -> Option<Arc<Self>> {
117        if Arc::ptr_eq(self, v) {
118            return Some(self.const_like(1i64));
119        }
120        if let Op::Const(cv) = v.op()
121            && let ConstValue::Int(d) = cv.0
122        {
123            return self.divides_int(d);
124        }
125        if let Op::Binary(BinaryOp::Add, a, b) = self.op() {
126            let d0 = a.divide_exact(v)?;
127            let d1 = b.divide_exact(v)?;
128            return d0.try_add(&d1).ok();
129        }
130        if let Op::Binary(BinaryOp::Mul, a, b) = self.op() {
131            if let Some(d) = a.divide_exact(v) {
132                return d.try_mul(b).ok();
133            }
134            if let Some(d) = b.divide_exact(v) {
135                return a.try_mul(&d).ok();
136            }
137        }
138        None
139    }
140
141    /// Computes the symbolic GCD of multiple UOps, returning a UOp.
142    ///
143    /// Based on Tinygrad's `UOp.gcd()` (ops.py:713-716).
144    /// Finds both numeric GCD of const_factors AND common symbolic MUL factors.
145    ///
146    /// For inputs `6*a*b` and `4*a*c`, returns `2*a` (numeric GCD=2, common factor=a).
147    pub fn symbolic_gcd(uops: &[Arc<Self>]) -> Arc<Self> {
148        assert!(!uops.is_empty(), "symbolic_gcd requires at least one uop");
149
150        // Step 1: decompose each uop into (term, factor) where term = uop / factor
151        let decomp: Vec<(Arc<Self>, i64)> = uops
152            .iter()
153            .map(|u| {
154                let f = u.const_factor();
155                let term = if f == 1 || f == 0 {
156                    Arc::clone(u)
157                } else {
158                    u.divides_int(f).unwrap_or_else(|| u.const_like(1i64))
159                };
160                (term, f)
161            })
162            .collect();
163
164        // Step 2: split each term into MUL factors, build Counter (ptr → count)
165        let counters: Vec<HashMap<*const Self, (Arc<Self>, usize)>> = decomp
166            .iter()
167            .map(|(term, _)| {
168                let mut counter: HashMap<*const Self, (Arc<Self>, usize)> = HashMap::new();
169                for factor in term.split_uop(BinaryOp::Mul) {
170                    let ptr = Arc::as_ptr(&factor);
171                    counter.entry(ptr).and_modify(|(_, c)| *c += 1).or_insert((factor, 1));
172                }
173                counter
174            })
175            .collect();
176
177        // Step 3: intersect counters (keep factors present in ALL terms with min count)
178        let mut common = counters[0].clone();
179        for other in &counters[1..] {
180            common.retain(|ptr, (_, count)| {
181                if let Some((_, other_count)) = other.get(ptr) {
182                    *count = (*count).min(*other_count);
183                    true
184                } else {
185                    false
186                }
187            });
188        }
189
190        // Step 4: numeric GCD of all const_factors
191        let numeric = decomp.iter().map(|(_, f)| f.abs()).reduce(gcd).unwrap_or(1);
192
193        // Step 5: multiply common symbolic factors with numeric GCD
194        let mut result = uops[0].const_like(numeric);
195        for (factor, count) in common.values() {
196            // Skip CONST(1) factors from divides_int normalization
197            if let Op::Const(cv) = factor.op()
198                && matches!(cv.0, ConstValue::Int(1))
199            {
200                continue;
201            }
202            for _ in 0..*count {
203                result = result.try_mul(factor).expect("symbolic_gcd: mul failed");
204            }
205        }
206
207        result
208    }
209
210    /// Separates a constant term from a binary expression.
211    ///
212    /// Returns (non_const_part, const_value).
213    /// Based on Tinygrad's `pop_const()` (ops.py lines 712-713).
214    ///
215    /// # Examples
216    ///
217    /// ```ignore
218    /// // (x + 5).pop_const(ADD) = (x, Some(Int(5)))
219    /// // (x + y).pop_const(ADD) = (x + y, None)
220    /// // x.pop_const(ADD) = (x, None)
221    /// ```
222    pub fn pop_const(self: &Arc<Self>, op: BinaryOp) -> (Arc<Self>, Option<ConstValue>) {
223        if let Op::Binary(self_op, a, b) = self.op()
224            && *self_op == op
225        {
226            // Check if right operand is constant
227            if let Op::Const(cv) = b.op() {
228                return (a.clone(), Some(cv.0));
229            }
230            // Check if left operand is constant (for commutative ops)
231            if op.is_commutative()
232                && let Op::Const(cv) = a.op()
233            {
234                return (b.clone(), Some(cv.0));
235            }
236        }
237
238        (self.clone(), None)
239    }
240
241    /// Splits an associative operation chain into its individual terms.
242    ///
243    /// Based on Tinygrad's `split_uop()` (ops.py lines 464-467).
244    ///
245    /// # Examples
246    ///
247    /// ```ignore
248    /// // (x + y + z).split_uop(ADD) = [x, y, z]
249    /// // (x + y).split_uop(ADD) = [x, y]
250    /// // x.split_uop(ADD) = [x]
251    /// ```
252    pub fn split_uop(self: &Arc<Self>, sep: BinaryOp) -> Vec<Arc<Self>> {
253        let mut result = Vec::new();
254        let mut stack = vec![self.clone()];
255
256        while let Some(node) = stack.pop() {
257            if let Op::Binary(op, a, b) = node.op()
258                && *op == sep
259            {
260                // Add operands to stack in reverse order to maintain left-to-right
261                stack.push(b.clone());
262                stack.push(a.clone());
263                continue;
264            }
265            result.push(node);
266        }
267
268        result
269    }
270
271    /// Cached backward slice: set of all node IDs reachable from this UOp.
272    ///
273    /// O(1) membership test via `contains()`. Computed once and cached per-node.
274    /// Prefer this over `backward_slice()` when you only need to check if a
275    /// node is in the dependency set.
276    pub fn backward_slice_ids(self: &Arc<Self>) -> &HashSet<u64> {
277        use crate::uop::cached_property::CachedProperty;
278        use crate::uop::properties::BackwardSliceProperty;
279        BackwardSliceProperty::get(self)
280    }
281
282    /// Returns all nodes that this UOp depends on (backward slice / dependency set).
283    ///
284    /// For membership tests, prefer [`backward_slice_ids()`] which returns a
285    /// cached `HashSet<u64>` with O(1) lookup.
286    pub fn backward_slice(self: &Arc<Self>) -> Vec<Arc<Self>> {
287        let mut visited = HashSet::new();
288        let mut result = Vec::new();
289        let mut stack = vec![self.clone()];
290
291        while let Some(node) = stack.pop() {
292            let ptr = Arc::as_ptr(&node);
293
294            if visited.contains(&ptr) {
295                continue;
296            }
297
298            visited.insert(ptr);
299            result.push(node.clone());
300
301            // Add all children to stack
302            node.op.map_child(|child| {
303                stack.push(child.clone());
304            });
305        }
306
307        result
308    }
309
310    /// Check if this UOp's size is divisible by the given amount.
311    ///
312    /// Returns `Some(quotient)` if divisible, `None` otherwise.
313    /// This is a convenience method for the optimizer to validate transformations.
314    ///
315    /// # Examples
316    ///
317    /// ```ignore
318    /// let range = UOp::range(SInt::Const(16), 0, AxisType::Loop);
319    /// assert_eq!(range.divisible_by(4), Some(4)); // 16 / 4 = 4
320    /// assert_eq!(range.divisible_by(5), None);    // 16 not divisible by 5
321    /// ```
322    pub fn divisible_by(self: &Arc<Self>, amount: usize) -> Option<usize> {
323        // For RANGE operations, check the end (size) field
324        if let Op::Range { end, .. } = self.op() {
325            // Check if end is a constant
326            if let Op::Const(cv) = end.op()
327                && let ConstValue::Int(sz) = cv.0
328                && sz > 0
329                && (sz as usize).is_multiple_of(amount)
330            {
331                return Some((sz as usize) / amount);
332            }
333
334            // Check using const_factor
335            let factor = end.const_factor();
336            if factor > 0 && (factor as usize).is_multiple_of(amount) {
337                return Some((factor as usize) / amount);
338            }
339        }
340
341        // For constants, check the value directly
342        if let Op::Const(cv) = self.op()
343            && let ConstValue::Int(val) = cv.0
344            && val > 0
345            && (val as usize).is_multiple_of(amount)
346        {
347            return Some((val as usize) / amount);
348        }
349
350        None
351    }
352
353    /// Create a new RANGE UOp with a different axis type.
354    ///
355    /// This is a convenience method for the optimizer to convert ranges between
356    /// axis types (e.g., LOOP → GLOBAL for parallelization).
357    ///
358    /// # Panics
359    ///
360    /// Panics if called on a non-RANGE operation.
361    ///
362    /// # Examples
363    ///
364    /// ```ignore
365    /// let loop_range = UOp::range_axis(UOp::index_const(16), 0, AxisType::Loop);
366    /// let global_range = loop_range.with_axis_type(AxisType::Global);
367    /// // global_range has same size and axis_id, but different axis type
368    /// ```
369    pub fn with_axis_type(self: &Arc<Self>, new_type: AxisType) -> Arc<Self> {
370        if let Op::Range { end, axis_id, .. } = self.op() {
371            Self::range_axis(end.clone(), *axis_id, new_type)
372        } else {
373            panic!("with_axis_type() called on non-RANGE operation: {:?}", self.op);
374        }
375    }
376
377    /// Extract the actual index from a range, stripping validity checks.
378    ///
379    /// If the range is a WHERE(valid, idx, invalid_marker), returns idx.
380    /// Otherwise, returns the range itself.
381    ///
382    /// This is used for range merging when comparing indexing patterns across
383    /// multiple consumers.
384    ///
385    /// Based on Tinygrad's `get_idx()` (ops.py:438-439).
386    ///
387    /// # Examples
388    ///
389    /// ```ignore
390    /// // Range with padding: WHERE(i < 5, i, SENTINEL)
391    /// let padded_range = UOp::where_op(valid, idx.clone(), invalid_marker)?;
392    /// assert!(Arc::ptr_eq(&padded_range.get_idx(), &idx));
393    ///
394    /// // Plain range: returns itself
395    /// let plain_range = UOp::range_axis(...);
396    /// assert!(Arc::ptr_eq(&plain_range.get_idx(), &plain_range));
397    /// ```
398    pub fn get_idx(self: &Arc<Self>) -> Arc<Self> {
399        use crate::types::TernaryOp;
400
401        match self.op() {
402            Op::Ternary(TernaryOp::Where, _, true_val, false_val) if Self::is_invalid_marker(false_val) => {
403                // WHERE(valid, idx, INVALID) → return idx
404                true_val.clone()
405            }
406            _ => self.clone(),
407        }
408    }
409
410    /// Extract the validity mask from a range.
411    ///
412    /// If the range is a WHERE(valid, idx, invalid_marker), returns valid.
413    /// Otherwise, returns constant true (always valid).
414    ///
415    /// This is used for range merging to combine validity conditions when
416    /// multiple consumers share compatible indexing patterns.
417    ///
418    /// Based on Tinygrad's `get_valid()` (ops.py:440-441).
419    ///
420    /// # Examples
421    ///
422    /// ```ignore
423    /// // Range with padding: WHERE(i < 5, i, SENTINEL)
424    /// let padded_range = UOp::where_op(valid.clone(), idx, invalid_marker)?;
425    /// assert!(Arc::ptr_eq(&padded_range.get_valid(), &valid));
426    ///
427    /// // Plain range: returns constant true
428    /// let plain_range = UOp::range_axis(...);
429    /// if let Op::Const(cv) = plain_range.get_valid().op() {
430    ///     assert_eq!(cv.0, ConstValue::Bool(true));
431    /// }
432    /// ```
433    pub fn get_valid(self: &Arc<Self>) -> Arc<Self> {
434        use crate::types::TernaryOp;
435        use morok_dtype::DType;
436
437        match self.op() {
438            Op::Ternary(TernaryOp::Where, cond, _, false_val) if Self::is_invalid_marker(false_val) => {
439                // WHERE(valid, idx, INVALID) → return valid
440                cond.clone()
441            }
442            Op::Invalid => {
443                // Bare Invalid is NOT valid (Tinygrad: self.arg is not Invalid → False)
444                Self::const_(DType::Bool, ConstValue::Bool(false))
445            }
446            _ => {
447                // Non-Invalid, non-WHERE: always valid
448                Self::const_(DType::Bool, ConstValue::Bool(true))
449            }
450        }
451    }
452
453    /// Check if a UOp represents an invalid index marker.
454    ///
455    /// Matches both scalar `Op::Invalid` and vectorized `VECTORIZE(Invalid, ..., Invalid)`
456    /// where ALL elements are Invalid. The vectorized form appears after expansion
457    /// broadcasts scalar Invalid across lanes.
458    ///
459    /// Uses `all()` semantics (entire vector must be Invalid). This differs from
460    /// `has_invalid()` in symbolic patterns which uses `any()` for guard semantics.
461    pub fn is_invalid_marker(uop: &Arc<Self>) -> bool {
462        match uop.op() {
463            Op::Invalid => true,
464            Op::Vectorize { elements } => {
465                !elements.is_empty() && elements.iter().all(|e| matches!(e.op(), Op::Invalid))
466            }
467            _ => false,
468        }
469    }
470
471    /// Create an invalid index marker.
472    ///
473    /// Invalid markers are used with WHERE operations to indicate out-of-bounds
474    /// or padded regions. The value is undefined and should never be used directly -
475    /// it exists only to be masked away by validity checks.
476    ///
477    /// # Returns
478    ///
479    /// A UOp representing an invalid index value.
480    ///
481    /// # Examples
482    ///
483    /// ```ignore
484    /// // Padding: WHERE(i < actual_size, i, invalid)
485    /// let invalid = UOp::invalid_marker();
486    /// let padded = UOp::where_op(valid, actual_idx, invalid)?;
487    /// ```
488    pub fn invalid_marker() -> Arc<Self> {
489        use morok_dtype::DType;
490
491        // Invalid marker for out-of-bounds indices (used in padding/masking)
492        Self::new(Op::Invalid, DType::Index)
493    }
494
495    /// Check if this UOp is a monotonically increasing function of its inputs.
496    ///
497    /// Returns true for:
498    /// - Irreducible ops (RANGE, CONST, DEFINE_VAR)
499    /// - ADD of increasing ops
500    /// - MUL/IDIV by non-negative constants
501    ///
502    /// Based on Tinygrad's `is_increasing()` (ops.py:689-694).
503    ///
504    /// # Examples
505    ///
506    /// ```ignore
507    /// // Constants are increasing
508    /// let c = UOp::const_(DType::Int32, ConstValue::Int(5));
509    /// assert!(c.is_increasing());
510    ///
511    /// // Range variables are increasing
512    /// let range = UOp::range_axis(UOp::index_const(16), 0, AxisType::Loop);
513    /// assert!(range.is_increasing());
514    ///
515    /// // x + y is increasing if both x and y are increasing
516    /// let sum = range.try_add(&c).unwrap();
517    /// assert!(sum.is_increasing());
518    ///
519    /// // x * 2 is increasing if x is increasing
520    /// let two = UOp::const_(DType::Index, ConstValue::Int(2));
521    /// let scaled = range.try_mul(&two).unwrap();
522    /// assert!(scaled.is_increasing());
523    /// ```
524    pub fn is_increasing(self: &Arc<Self>) -> bool {
525        match self.op() {
526            // Irreducible: RANGE, CONST, DEFINE_VAR
527            Op::Range { .. } | Op::Const(_) | Op::DefineVar { .. } => true,
528
529            // ADD: both operands must be increasing
530            Op::Binary(BinaryOp::Add, a, b) => a.is_increasing() && b.is_increasing(),
531
532            // MUL/IDIV by non-negative constant
533            Op::Binary(BinaryOp::Mul | BinaryOp::Idiv, a, b) => {
534                if let Op::Const(cv) = b.op() {
535                    matches!(cv.0, ConstValue::Int(n) if n >= 0) && a.is_increasing()
536                } else {
537                    false
538                }
539            }
540
541            _ => false,
542        }
543    }
544}
545
/// Computes the greatest common divisor using Euclid's algorithm.
///
/// Signs are ignored, so the result is non-negative, and `gcd(0, 0) == 0`.
/// The arithmetic runs on `u64` magnitudes so `i64::MIN` inputs cannot
/// overflow (e.g. `gcd(i64::MIN, 6) == 2`). There is a single exception to the
/// non-negative result. The value 2^63 does not fit in `i64`, and it is produced
/// only by `gcd(i64::MIN, 0)`, `gcd(0, i64::MIN)` and `gcd(i64::MIN, i64::MIN)`.
/// In those cases the function returns `i64::MIN`, the two's-complement wrap of 2^63.
pub fn gcd(a: i64, b: i64) -> i64 {
    // `unsigned_abs` is total (unlike `abs`, which overflows on i64::MIN).
    let (mut x, mut y) = (a.unsigned_abs(), b.unsigned_abs());
    while y != 0 {
        (x, y) = (y, x % y);
    }
    // Every result except 2^63 fits in i64; 2^63 wraps to i64::MIN.
    i64::try_from(x).unwrap_or(i64::MIN)
}
557
558/// Extension trait for BinaryOp to check if it's commutative.
559#[allow(dead_code)] // Used in pop_const for commutative check
560trait BinaryOpExt {
561    fn is_commutative(&self) -> bool;
562}
563
564impl BinaryOpExt for BinaryOp {
565    fn is_commutative(&self) -> bool {
566        matches!(
567            self,
568            BinaryOp::Add
569                | BinaryOp::Mul
570                | BinaryOp::And
571                | BinaryOp::Or
572                | BinaryOp::Xor
573                | BinaryOp::Max
574                | BinaryOp::Eq
575                | BinaryOp::Ne
576        )
577    }
578}
579
#[cfg(test)]
mod tests {
    use super::*;
    use morok_dtype::DType;

    /// Returns the payload of a scalar `CONST` node, panicking (with the tree) otherwise.
    fn const_value(u: &Arc<UOp>) -> ConstValue {
        match u.op() {
            Op::Const(cv) => cv.0,
            _ => panic!("Expected constant, got: {}", u.tree()),
        }
    }

    // ---- const_factor ----------------------------------------------------

    #[test]
    fn test_const_factor_constant() {
        let c = UOp::const_(DType::Int32, ConstValue::Int(6));
        assert_eq!(c.const_factor(), 6);
    }

    #[test]
    fn test_const_factor_multiplication() {
        let x = UOp::var("x", DType::Int32, 0, 100);
        let c = UOp::const_(DType::Int32, ConstValue::Int(6));
        let mul = x.try_mul(&c).unwrap();
        assert_eq!(mul.const_factor(), 6);
    }

    #[test]
    fn test_const_factor_addition() {
        let c1 = UOp::const_(DType::Int32, ConstValue::Int(6));
        let c2 = UOp::const_(DType::Int32, ConstValue::Int(9));
        let add = c1.try_add(&c2).unwrap();
        assert_eq!(add.const_factor(), 3); // GCD(6, 9) = 3
    }

    #[test]
    fn test_const_factor_mul_only_immediate() {
        // (x * 6) * (y * 4) — const_factor should be 1 (no immediate CONST child)
        let x = UOp::var("x", DType::Index, 0, 10);
        let y = UOp::var("y", DType::Index, 0, 10);
        let six = UOp::const_(DType::Index, ConstValue::Int(6));
        let four = UOp::const_(DType::Index, ConstValue::Int(4));
        let a = x.try_mul(&six).unwrap(); // x*6
        let b = y.try_mul(&four).unwrap(); // y*4
        let ab = a.try_mul(&b).unwrap(); // (x*6) * (y*4)
        // Tinygrad: neither immediate child is CONST → returns 1
        assert_eq!(ab.const_factor(), 1);
    }

    #[test]
    fn test_const_factor_vconst() {
        let vc = UOp::vconst(
            vec![ConstValue::Int(6), ConstValue::Int(12), ConstValue::Int(18), ConstValue::Int(24)],
            DType::Int64,
        );
        assert_eq!(vc.const_factor(), 6); // GCD(6, 12, 18, 24) = 6
    }

    #[test]
    fn test_const_factor_vconst_no_common() {
        let vc = UOp::vconst(vec![ConstValue::Int(7), ConstValue::Int(11)], DType::Int64);
        assert_eq!(vc.const_factor(), 1); // GCD(7, 11) = 1
    }

    // ---- divides / divides_int ---------------------------------------------

    #[test]
    fn test_divides_constant_exact() {
        let c = UOp::const_(DType::Int32, ConstValue::Int(12));
        let divisor = UOp::const_(DType::Int32, ConstValue::Int(3));
        let quotient = c.divides(&divisor).expect("12 is divisible by 3");
        assert_eq!(const_value(&quotient), ConstValue::Int(4));
    }

    #[test]
    fn test_divides_constant_not_exact() {
        let c = UOp::const_(DType::Int32, ConstValue::Int(10));
        let divisor = UOp::const_(DType::Int32, ConstValue::Int(3));
        assert!(c.divides(&divisor).is_none());
    }

    #[test]
    fn test_divides_int_by_one_is_identity() {
        let x = UOp::var("x", DType::Int32, 0, 100);
        let same = x.divides_int(1).expect("division by 1 always succeeds");
        assert!(Arc::ptr_eq(&same, &x));
    }

    #[test]
    fn test_divides_int_by_zero_is_none() {
        let c = UOp::const_(DType::Int32, ConstValue::Int(12));
        assert!(c.divides_int(0).is_none());
    }

    #[test]
    fn test_divides_int_through_mul() {
        // (x * 6) / 3 → x * 2: the constant operand absorbs the division
        let x = UOp::var("x", DType::Int32, 0, 100);
        let six = UOp::const_(DType::Int32, ConstValue::Int(6));
        let mul = x.try_mul(&six).unwrap();
        let quotient = mul.divides_int(3).expect("x*6 is divisible by 3");
        assert_eq!(quotient.const_factor(), 2);
        // A bare variable has no known factor
        assert!(x.divides_int(3).is_none());
    }

    #[test]
    fn test_divides_int_vconst() {
        let vc = UOp::vconst(vec![ConstValue::Int(6), ConstValue::Int(12)], DType::Int64);
        let quotient = vc.divides_int(3).expect("6 and 12 are divisible by 3");
        match quotient.op() {
            Op::VConst { values } => assert_eq!(values, &[ConstValue::Int(2), ConstValue::Int(4)]),
            _ => panic!("Expected VConst result"),
        }
    }

    #[test]
    fn test_divides_int_vconst_not_divisible() {
        let vc = UOp::vconst(
            vec![
                ConstValue::Int(6),
                ConstValue::Int(7), // 7 not divisible by 3
            ],
            DType::Int64,
        );
        assert!(vc.divides_int(3).is_none());
    }

    // ---- pop_const / split_uop ---------------------------------------------

    #[test]
    fn test_pop_const_with_constant() {
        let x = UOp::var("x", DType::Int32, 0, 100);
        let c = UOp::const_(DType::Int32, ConstValue::Int(5));
        let add = x.try_add(&c).unwrap();

        let (rest, const_val) = add.pop_const(BinaryOp::Add);

        assert!(Arc::ptr_eq(&rest, &x));
        assert_eq!(const_val, Some(ConstValue::Int(5)));
    }

    #[test]
    fn test_pop_const_without_constant() {
        let x = UOp::var("x", DType::Int32, 0, 100);
        let y = UOp::var("y", DType::Int32, 0, 100);
        let add = x.try_add(&y).unwrap();

        let (rest, const_val) = add.pop_const(BinaryOp::Add);

        assert!(Arc::ptr_eq(&rest, &add));
        assert_eq!(const_val, None);
    }

    #[test]
    fn test_pop_const_other_op_is_untouched() {
        // Asking to pop from a MUL when the node is an ADD returns it unchanged
        let x = UOp::var("x", DType::Int32, 0, 100);
        let c = UOp::const_(DType::Int32, ConstValue::Int(5));
        let add = x.try_add(&c).unwrap();

        let (rest, const_val) = add.pop_const(BinaryOp::Mul);

        assert!(Arc::ptr_eq(&rest, &add));
        assert_eq!(const_val, None);
    }

    #[test]
    fn test_split_uop_chain() {
        let x = UOp::var("x", DType::Int32, 0, 100);
        let y = UOp::var("y", DType::Int32, 0, 100);
        let z = UOp::var("z", DType::Int32, 0, 100);

        // Build: x + y + z = (x + y) + z
        let xyz = x.try_add(&y).unwrap().try_add(&z).unwrap();

        let terms = xyz.split_uop(BinaryOp::Add);

        assert_eq!(terms.len(), 3);
        assert!(Arc::ptr_eq(&terms[0], &x));
        assert!(Arc::ptr_eq(&terms[1], &y));
        assert!(Arc::ptr_eq(&terms[2], &z));
    }

    #[test]
    fn test_split_uop_single() {
        let x = UOp::var("x", DType::Int32, 0, 100);
        let terms = x.split_uop(BinaryOp::Add);

        assert_eq!(terms.len(), 1);
        assert!(Arc::ptr_eq(&terms[0], &x));
    }

    #[test]
    fn test_split_uop_mul_chain() {
        let x = UOp::var("x", DType::Int32, 0, 100);
        let y = UOp::var("y", DType::Int32, 0, 100);
        let z = UOp::var("z", DType::Int32, 0, 100);
        let xyz = x.try_mul(&y).unwrap().try_mul(&z).unwrap();

        let factors = xyz.split_uop(BinaryOp::Mul);
        assert_eq!(factors.len(), 3);
        assert!(Arc::ptr_eq(&factors[0], &x));
        assert!(Arc::ptr_eq(&factors[1], &y));
        assert!(Arc::ptr_eq(&factors[2], &z));

        // Splitting on a different operator leaves the product whole
        let terms = xyz.split_uop(BinaryOp::Add);
        assert_eq!(terms.len(), 1);
        assert!(Arc::ptr_eq(&terms[0], &xyz));
    }

    // ---- gcd / symbolic_gcd ----------------------------------------------------

    #[test]
    fn test_gcd() {
        assert_eq!(gcd(12, 8), 4);
        assert_eq!(gcd(17, 19), 1);
        assert_eq!(gcd(100, 50), 50);
        assert_eq!(gcd(-12, 8), 4);
        assert_eq!(gcd(12, -8), 4);
        assert_eq!(gcd(-12, -8), 4);
    }

    #[test]
    fn test_symbolic_gcd_numeric_only() {
        // GCD of 6*x and 4*y → numeric GCD is 2
        let x = UOp::var("x", DType::Index, 0, 10);
        let y = UOp::var("y", DType::Index, 0, 10);
        let six = UOp::const_(DType::Index, ConstValue::Int(6));
        let four = UOp::const_(DType::Index, ConstValue::Int(4));
        let a = x.try_mul(&six).unwrap(); // 6*x
        let b = y.try_mul(&four).unwrap(); // 4*y
        let g = UOp::symbolic_gcd(&[a, b]);
        assert_eq!(const_value(&g), ConstValue::Int(2));
    }

    #[test]
    fn test_symbolic_gcd_with_common_factor() {
        // GCD of 6*x and 4*x → 2*x (common symbolic factor x, numeric GCD 2)
        let x = UOp::var("x", DType::Index, 0, 10);
        let six = UOp::const_(DType::Index, ConstValue::Int(6));
        let four = UOp::const_(DType::Index, ConstValue::Int(4));
        let a = x.try_mul(&six).unwrap(); // 6*x (= x*6 internally)
        let b = x.try_mul(&four).unwrap(); // 4*x (= x*4 internally)
        let g = UOp::symbolic_gcd(&[a, b]);
        // Should be 2*x — a MUL node
        assert!(matches!(g.op(), Op::Binary(BinaryOp::Mul, _, _)), "Expected MUL, got: {}", g.tree());
    }

    #[test]
    fn test_symbolic_gcd_single_constant() {
        // A lone constant is its own GCD; the normalized CONST(1) term is dropped
        let six = UOp::const_(DType::Index, ConstValue::Int(6));
        let g = UOp::symbolic_gcd(&[six]);
        assert_eq!(const_value(&g), ConstValue::Int(6));
    }

    // ---- get_idx / get_valid / invalid markers -------------------------------

    #[test]
    fn test_get_idx_and_get_valid_on_plain_uop() {
        let x = UOp::var("x", DType::Index, 0, 10);
        assert!(Arc::ptr_eq(&x.get_idx(), &x));
        assert_eq!(const_value(&x.get_valid()), ConstValue::Bool(true));
    }

    #[test]
    fn test_invalid_marker_is_never_valid() {
        let invalid = UOp::invalid_marker();
        assert!(UOp::is_invalid_marker(&invalid));
        assert_eq!(const_value(&invalid.get_valid()), ConstValue::Bool(false));

        let x = UOp::var("x", DType::Index, 0, 10);
        assert!(!UOp::is_invalid_marker(&x));
    }

    // ---- backward_slice ------------------------------------------------------

    #[test]
    fn test_backward_slice_starts_with_self_and_reaches_operands() {
        let x = UOp::var("x", DType::Int32, 0, 100);
        let y = UOp::var("y", DType::Int32, 0, 100);
        let add = x.try_add(&y).unwrap();

        let slice = add.backward_slice();

        assert!(Arc::ptr_eq(&slice[0], &add));
        assert!(slice.iter().any(|n| Arc::ptr_eq(n, &x)));
        assert!(slice.iter().any(|n| Arc::ptr_eq(n, &y)));
    }

    // ---- is_increasing -------------------------------------------------------

    #[test]
    fn test_is_increasing_const() {
        let c = UOp::const_(DType::Int32, ConstValue::Int(5));
        assert!(c.is_increasing());

        let neg = UOp::const_(DType::Int32, ConstValue::Int(-5));
        assert!(neg.is_increasing()); // Constants are always "increasing" (irreducible)
    }

    #[test]
    fn test_is_increasing_add() {
        let a = UOp::const_(DType::Int32, ConstValue::Int(5));
        let b = UOp::const_(DType::Int32, ConstValue::Int(3));
        let sum = a.try_add(&b).unwrap();
        assert!(sum.is_increasing());
    }

    #[test]
    fn test_is_increasing_mul_positive_const() {
        let x = UOp::var("x", DType::Int32, 0, 100);
        let two = UOp::const_(DType::Int32, ConstValue::Int(2));
        let scaled = x.try_mul(&two).unwrap();
        assert!(scaled.is_increasing());
    }

    #[test]
    fn test_is_increasing_mul_negative_const() {
        let x = UOp::var("x", DType::Int32, 0, 100);
        let neg = UOp::const_(DType::Int32, ConstValue::Int(-2));
        let scaled = x.try_mul(&neg).unwrap();
        assert!(!scaled.is_increasing()); // Multiplying by negative is not increasing
    }

    #[test]
    fn test_is_increasing_mul_by_variable() {
        // A non-constant right operand is "not known to be increasing"
        let x = UOp::var("x", DType::Int32, 0, 100);
        let y = UOp::var("y", DType::Int32, 0, 100);
        let product = x.try_mul(&y).unwrap();
        assert!(!product.is_increasing());
    }

    #[test]
    fn test_is_increasing_idiv_positive_const() {
        let x = UOp::var("x", DType::Int32, 0, 100);
        let two = UOp::const_(DType::Int32, ConstValue::Int(2));
        let divided = x.idiv(&two);
        assert!(divided.is_increasing());
    }

    #[test]
    fn test_is_increasing_complex() {
        // (x + 5) * 2 should be increasing
        let x = UOp::var("x", DType::Int32, 0, 100);
        let five = UOp::const_(DType::Int32, ConstValue::Int(5));
        let two = UOp::const_(DType::Int32, ConstValue::Int(2));
        let sum = x.try_add(&five).unwrap();
        let scaled = sum.try_mul(&two).unwrap();
        assert!(scaled.is_increasing());
    }
}