Skip to main content

morok_ir/uop/
core.rs

1//! Core UOp struct and fundamental operations.
2//!
3//! This module contains the [`UOp`] struct definition and its core methods
4//! for accessing operation data, dtype, shape, and graph traversal.
5
6use std::collections::{HashMap, HashSet};
7use std::hash::{Hash, Hasher};
8use std::sync::Arc;
9
10use bon::bon;
11use smallvec::SmallVec;
12
13use crate::op::Op;
14use crate::pattern::{Matcher, RewriteResult};
15use crate::shape;
16use crate::types::{AxisType, ConstValue};
17use morok_dtype::DType;
18
19/// Matcher for `UOp::substitute` — looks up each node in a substitution map.
20///
21/// Equivalent to Tinygrad's `_substitute = PatternMatcher([(UPat(tuple(Ops)), lambda ctx,x: ctx.get(x))])`.
22struct SubstituteMatcher<'a>(&'a HashMap<UOpKey, Arc<UOp>>);
23
24impl Matcher<()> for SubstituteMatcher<'_> {
25    fn rewrite(&self, uop: &Arc<UOp>, _ctx: &mut ()) -> RewriteResult {
26        match self.0.get(&UOpKey(uop.clone())) {
27            Some(replacement) if !Arc::ptr_eq(uop, replacement) => RewriteResult::Rewritten(replacement.clone()),
28            _ => RewriteResult::NoMatch,
29        }
30    }
31}
32
33/// Matcher for `UOp::substitute_gated` — substitution with range-scope gating.
34///
35/// Equivalent to Tinygrad's `_substitute` + `pm_gate_substitute`:
36/// - If a node is in the substitution map, replace it.
37/// - If a node's ranges don't overlap with substitution keys, gate (skip subtree).
38struct SubstituteGatedMatcher<'a> {
39    map: &'a HashMap<UOpKey, Arc<UOp>>,
40    range_keys: &'a HashSet<UOpKey>,
41}
42
43impl Matcher<()> for SubstituteGatedMatcher<'_> {
44    fn rewrite(&self, uop: &Arc<UOp>, _ctx: &mut ()) -> RewriteResult {
45        // Direct substitution lookup
46        if let Some(replacement) = self.map.get(&UOpKey(uop.clone()))
47            && !Arc::ptr_eq(uop, replacement)
48        {
49            return RewriteResult::Rewritten(replacement.clone());
50        }
51        // Gate: skip subtrees whose ranges don't overlap with substitution keys
52        // Tinygrad (rangeify.py:187): `if not any(r in b.ranges for r in ctx.keys()): raise BottomUpGate()`
53        if !uop.in_scope_ranges().iter().any(|r| self.range_keys.contains(r)) {
54            return RewriteResult::Gate(uop.clone());
55        }
56        RewriteResult::NoMatch
57    }
58}
59
60/// Wrapper for Arc<UOp> that implements Hash and Eq based on stable ID.
61///
62/// This allows using Arc<UOp> as HashMap keys without implementing
63/// Hash/Eq on UOp itself (which would be problematic due to OnceCell fields).
64///
65/// Note: While UOp contains OnceCell fields, Hash/Eq are based solely on the
66/// immutable `id` field, making this safe to use as a HashMap key.
67#[allow(clippy::mutable_key_type)]
68#[derive(Clone)]
69pub struct UOpKey(pub Arc<UOp>);
70
71// Custom Debug impl to show only the UOp ID, avoiding recursive printing
72impl std::fmt::Debug for UOpKey {
73    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74        write!(f, "UOpKey(id={})", self.0.id)
75    }
76}
77
78impl PartialEq for UOpKey {
79    fn eq(&self, other: &Self) -> bool {
80        self.0.id == other.0.id
81    }
82}
83
84impl Eq for UOpKey {}
85
86impl Hash for UOpKey {
87    fn hash<H: Hasher>(&self, state: &mut H) {
88        self.0.id.hash(state);
89    }
90}
91
/// Micro-operation node in the computation graph.
///
/// UOps form a DAG where operations reference their inputs through the Op enum.
/// Hash consing ensures that structurally identical UOps share the same allocation.
///
/// Shape inference is lazy and cached - computed on first access via `shape()` method.
/// Other derived properties (ranges, in-scope ranges, vmin/vmax, backward slice, ...) are
/// cached the same way in `OnceLock` fields.
///
/// Note: Debug uses derive_more with `#[debug(skip)]` on the cache fields (and on `metadata`)
/// to prevent stack overflow from recursive Arc<UOp> references in caches.
#[derive(derive_more::Debug)]
pub struct UOp {
    /// Unique stable ID for this UOp instance.
    /// Used for identity-based caching instead of fragile raw pointers; `UOpKey` equality
    /// and hashing are based solely on this field.
    pub id: u64,
    /// The operation this node performs; its source UOps live inside the op's fields.
    pub(crate) op: Op,
    /// Data type of this node.
    pub(crate) dtype: DType,
    /// Cached shape - computed lazily on first access.
    /// OnceLock provides thread-safe lazy initialization.
    #[debug(skip)]
    pub(crate) shape_cache: std::sync::OnceLock<crate::Result<Option<shape::Shape>>>,
    /// Cached list of RANGE operations in this UOp's graph.
    /// Computed lazily via toposort to collect all RANGE ops.
    #[debug(skip)]
    pub(crate) ranges_cache: std::sync::OnceLock<Vec<Arc<UOp>>>,
    /// Cached set of RANGE operations that are in scope at this UOp.
    /// Unlike ranges_cache which contains ALL ranges in the graph,
    /// this contains only the ranges that are currently "active" (not yet ended).
    /// Computed lazily based on Tinygrad's ranges property.
    /// Uses UOpKey wrapper to enable Hash/Eq based on UOp ID.
    #[debug(skip)]
    pub(crate) in_scope_ranges_cache: std::sync::OnceLock<HashSet<UOpKey>>,
    /// Cached vmin/vmax range analysis values.
    /// Computed lazily via range propagation through the computation graph.
    /// Holds (vmin, vmax) as ConstValue types.
    #[debug(skip)]
    pub(crate) vmin_vmax_cache: std::sync::OnceLock<(ConstValue, ConstValue)>,
    /// Sound vmin/vmax: `None` for ops where range analysis is unsound (LOAD, Pow, etc.).
    /// Used by patterns that must not act on unsound bounds (e.g., vmin_vmax_collapse).
    #[debug(skip)]
    pub(crate) sound_vmin_vmax_cache: std::sync::OnceLock<Option<(ConstValue, ConstValue)>>,
    /// Whether this node or any of its sources is an INDEX op.
    /// Cached O(1) lookup used by `simplify_valid` to skip And chains inside INDEX trees.
    #[debug(skip)]
    pub(crate) has_index_in_sources_cache: std::sync::OnceLock<bool>,
    /// Cached backward slice: IDs of all nodes reachable from this UOp (including self).
    /// O(1) membership test via `backward_slice_ids().contains(&target.id)`.
    #[debug(skip)]
    pub(crate) backward_slice_cache: std::sync::OnceLock<HashSet<u64>>,
    /// Structural content hash — deterministic regardless of allocation order.
    /// Computed at creation time: hash(op_discriminant, dtype, op_data, children_content_hashes).
    /// O(1) per node since children are already created with their content_hash set.
    /// Used for schedule-level caching where UOp IDs are not stable across runs.
    pub content_hash: u64,
    /// Tag for tracking tensor identity through the rangeify pipeline.
    ///
    /// Matches Tinygrad's `UOp.tag` (ops.py:128). Tags are tuples of integer indices
    /// that track which original tensor UOps map to which final kernel outputs.
    /// Tags participate in hash consing — different tag = different UOp.
    ///
    /// Values:
    /// - `None` — untagged (default)
    /// - `Some([])` — empty tag (e.g., RANGE ops)
    /// - `Some([i])` — single index (assigned by add_tags)
    /// - `Some([i, j, ...])` — merged indices (from buffer folding)
    pub tag: Option<SmallVec<[usize; 2]>>,
    /// Optional metadata attached to this UOp.
    ///
    /// Metadata is NOT part of hash consing - attaching metadata creates a new UOp
    /// instance with a different ID. This is used for kernel info (name, opts) after
    /// optimization is complete.
    ///
    /// Uses Arc<dyn Any> to allow attaching any metadata type without circular
    /// dependencies (e.g., schedule::KernelInfo). Excluded from Debug output.
    #[debug(skip)]
    pub(crate) metadata: Option<std::sync::Arc<dyn std::any::Any + Send + Sync>>,
}
168
169/// Hash implementation for UOp based on content (dtype + op).
170///
171/// This enables content-based hashing for cross-run caching. The hash traverses
172/// the DAG structure since Op contains Arc<UOp> children that also get hashed.
173/// Cache fields are intentionally skipped - they don't affect semantic identity.
174impl Hash for UOp {
175    fn hash<H: Hasher>(&self, state: &mut H) {
176        self.dtype.hash(state);
177        self.op.hash(state);
178        // Intentionally skip: id, caches, metadata
179    }
180}
181
182impl UOp {
183    /// Get the operation.
184    pub fn op(&self) -> &Op {
185        &self.op
186    }
187
188    /// Get the data type.
189    pub fn dtype(&self) -> DType {
190        self.dtype.clone()
191    }
192
193    /// Get the tag.
194    pub fn tag(&self) -> &Option<SmallVec<[usize; 2]>> {
195        &self.tag
196    }
197
198    /// Create a new UOp with the given tag (Tinygrad: `rtag()`).
199    /// Returns self unchanged if tag is already equal.
200    pub fn rtag(self: &Arc<Self>, tag: Option<SmallVec<[usize; 2]>>) -> Arc<Self> {
201        if self.tag == tag {
202            return self.clone();
203        }
204        Self::new_tagged(self.op.clone(), self.dtype.clone(), tag)
205    }
206
207    /// Create a new UOp with the given tag set.
208    pub fn with_tag(self: &Arc<Self>, tag: SmallVec<[usize; 2]>) -> Arc<Self> {
209        self.rtag(Some(tag))
210    }
211
212    /// Check if this UOp has a concrete buffer identity in the graph.
213    ///
214    /// Returns true for BUFFER or RESHAPE/MULTI chains leading to BUFFER.
215    /// These are already contiguous by definition, so wrapping in CONTIGUOUS is a no-op.
216    ///
217    /// Based on Tinygrad's `UOp.has_buffer_identity()` (ops.py:616-619).
218    pub fn has_buffer_identity(&self) -> bool {
219        match &self.op {
220            Op::Reshape { src, .. } => src.has_buffer_identity(),
221            Op::Buffer { .. } => true,
222            _ => false,
223        }
224    }
225
226    /// Get pointer dtype components if this UOp has a Ptr dtype.
227    ///
228    /// Returns `(base, addrspace, size)` for Ptr types, None otherwise.
229    /// This simplifies pattern matching on pointer types.
230    ///
231    /// # Examples
232    ///
233    /// ```rust
234    /// # use morok_ir::UOp;
235    /// # use morok_dtype::{DType, AddrSpace, DeviceSpec};
236    /// let buffer = UOp::new_buffer(DeviceSpec::Cpu, 10, DType::Float32);
237    /// if let Some((base, addrspace, size)) = buffer.ptrdtype() {
238    ///     assert_eq!(*base, DType::Float32);
239    ///     assert_eq!(addrspace, AddrSpace::Global);
240    /// }
241    /// ```
242    pub fn ptrdtype(&self) -> Option<(&DType, morok_dtype::AddrSpace, Option<usize>)> {
243        match &self.dtype {
244            DType::Ptr { base, addrspace, size, .. } => Some((base.as_ref(), *addrspace, *size)),
245            _ => None,
246        }
247    }
248
249    /// Create a copy of this UOp with a different dtype.
250    ///
251    /// If the dtype is unchanged, returns self (clone of Arc).
252    /// This is the Rust equivalent of Tinygrad's `buf.replace(dtype=x)`.
253    ///
254    /// # Examples
255    ///
256    /// ```rust
257    /// # use std::sync::Arc;
258    /// # use morok_ir::UOp;
259    /// # use morok_dtype::DType;
260    /// let int_const = UOp::const_(DType::Int32, morok_ir::ConstValue::Int(5));
261    /// let float_const = int_const.with_dtype(DType::Float32);
262    /// assert_eq!(float_const.dtype(), DType::Float32);
263    /// ```
264    pub fn with_dtype(self: &Arc<Self>, dtype: DType) -> Arc<Self> {
265        if self.dtype == dtype {
266            return self.clone();
267        }
268        Self::new(self.op.clone(), dtype)
269    }
270
271    /// Walk through AFTER nodes to get the passthrough value.
272    ///
273    /// This is the Rust equivalent of Tinygrad's `.or_after()` pattern.
274    /// Recursively unwraps AFTER nodes to find the underlying value.
275    ///
276    /// # Examples
277    ///
278    /// ```ignore
279    /// // Given: AFTER(AFTER(value, [dep1]), [dep2])
280    /// // Returns: value
281    /// let inner = wrapped.unwrap_after();
282    /// ```
283    pub fn unwrap_after(self: &Arc<Self>) -> Arc<Self> {
284        match self.op() {
285            Op::After { passthrough, .. } => passthrough.unwrap_after(),
286            _ => self.clone(),
287        }
288    }
289
290    /// Walk through CAST nodes to get the inner value.
291    ///
292    /// This is the Rust equivalent of Tinygrad's `.or_casted()` pattern.
293    /// Recursively unwraps CAST nodes to find the underlying value.
294    ///
295    /// # Examples
296    ///
297    /// ```ignore
298    /// // Given: CAST(CAST(value, dtype1), dtype2)
299    /// // Returns: value
300    /// let inner = casted.unwrap_cast();
301    /// ```
302    pub fn unwrap_cast(self: &Arc<Self>) -> Arc<Self> {
303        match self.op() {
304            Op::Cast { src, .. } => src.unwrap_cast(),
305            _ => self.clone(),
306        }
307    }
308
309    /// Get the buffer from a STORE operation (via its INDEX child).
310    ///
311    /// STORE operations reference the buffer indirectly through an INDEX node.
312    /// This helper extracts the buffer from `STORE.index.buffer`.
313    ///
314    /// Returns `None` if:
315    /// - This is not a STORE operation
316    /// - The STORE's index is not an INDEX operation
317    pub fn store_buffer(&self) -> Option<&Arc<UOp>> {
318        match self.op() {
319            Op::Store { index, .. } => match index.op() {
320                Op::Index { buffer, .. } => Some(buffer),
321                _ => None,
322            },
323            _ => None,
324        }
325    }
326
327    /// Get the buffer from a LOAD operation.
328    ///
329    /// Returns `None` if this is not a LOAD operation.
330    pub fn load_buffer(&self) -> Option<Arc<UOp>> {
331        match self.op() {
332            Op::Load { buffer, .. } => Some(buffer.clone()),
333            _ => None,
334        }
335    }
336
337    /// Store a value at this INDEX node.
338    ///
339    /// Convenience method for `self.store(value)`.
340    /// Matches Tinygrad's `idx.store(val)` pattern.
341    ///
342    /// # Panics
343    ///
344    /// Debug-asserts that self is an INDEX operation.
345    pub fn store_value(self: &Arc<Self>, value: Arc<Self>) -> Arc<Self> {
346        debug_assert!(matches!(self.op(), Op::Index { .. }), "store_value requires INDEX");
347        self.store(value)
348    }
349
350    /// Alias for `with_sources()`.
351    ///
352    /// Creates a new UOp with the same operation type and dtype, but with
353    /// the provided sources replacing the original ones.
354    pub fn with_src(self: &Arc<Self>, new_srcs: Vec<Arc<Self>>) -> Arc<Self> {
355        self.with_sources(new_srcs)
356    }
357
358    /// Get the shape of this UOp.
359    ///
360    /// Shape is computed lazily on first access and cached.
361    /// Returns Ok(None) if shape cannot be determined (e.g., for control flow ops).
362    /// Returns Err if there is a shape mismatch error.
363    ///
364    /// # Examples
365    ///
366    /// ```rust
367    /// # use morok_ir::{UOp, ConstValue};
368    /// # use morok_dtype::DType;
369    /// let scalar = UOp::const_(DType::Float32, ConstValue::Float(1.0));
370    /// assert_eq!(scalar.shape().unwrap().as_ref().map(|s| s.len()), Some(0)); // Scalar has empty shape
371    /// ```
372    pub fn shape(self: &Arc<Self>) -> crate::Result<Option<&shape::Shape>> {
373        use crate::uop::cached_property::CachedProperty;
374        use crate::uop::properties::ShapeProperty;
375        match ShapeProperty::get(self) {
376            Ok(opt) => Ok(opt.as_ref()),
377            Err(e) => Err(e.clone()),
378        }
379    }
380
381    /// Get the minimum possible value of this UOp.
382    ///
383    /// Returns the minimum value based on range analysis.
384    /// Computed lazily on first access and cached.
385    ///
386    /// # Examples
387    ///
388    /// ```rust
389    /// # use morok_ir::{UOp, ConstValue};
390    /// # use morok_dtype::DType;
391    /// let five = UOp::const_(DType::Int32, ConstValue::Int(5));
392    /// assert_eq!(five.vmin(), &ConstValue::Int(5));
393    /// ```
394    pub fn vmin(self: &Arc<Self>) -> &ConstValue {
395        use crate::uop::cached_property::CachedProperty;
396        use crate::uop::properties::VminVmaxProperty;
397        &VminVmaxProperty::get(self).0
398    }
399
400    /// Get the maximum possible value of this UOp.
401    ///
402    /// Returns the maximum value based on range analysis.
403    /// Computed lazily on first access and cached.
404    ///
405    /// # Examples
406    ///
407    /// ```rust
408    /// # use morok_ir::{UOp, ConstValue};
409    /// # use morok_dtype::DType;
410    /// let five = UOp::const_(DType::Int32, ConstValue::Int(5));
411    /// assert_eq!(five.vmax(), &ConstValue::Int(5));
412    /// ```
413    pub fn vmax(self: &Arc<Self>) -> &ConstValue {
414        use crate::uop::cached_property::CachedProperty;
415        use crate::uop::properties::VminVmaxProperty;
416        &VminVmaxProperty::get(self).1
417    }
418
419    /// Extract device specification from this UOp graph.
420    ///
421    /// Traverses the graph to find Op::Device nodes following Tinygrad's
422    /// `_device` recursive property (ops.py:585-599):
423    /// - DEVICE: returns the DeviceSpec directly
424    /// - BUFFER: returns device from the device child
425    /// - COPY: returns device from the device child (target device)
426    /// - Otherwise: searches children recursively
427    ///
428    /// # Examples
429    ///
430    /// ```rust
431    /// # use morok_ir::UOp;
432    /// # use morok_dtype::{DType, DeviceSpec};
433    /// let buffer = UOp::new_buffer(DeviceSpec::Cpu, 10, DType::Float32);
434    /// assert_eq!(buffer.device_spec(), Some(DeviceSpec::Cpu));
435    /// ```
436    pub fn device_spec(&self) -> Option<morok_dtype::DeviceSpec> {
437        match self.op() {
438            Op::Device(spec) => Some(spec.clone()),
439            Op::Buffer { device, .. } => {
440                if let Op::Device(spec) = device.op() {
441                    Some(spec.clone())
442                } else {
443                    None
444                }
445            }
446            Op::Param { device: Some(device), .. } => {
447                if let Op::Device(spec) = device.op() {
448                    Some(spec.clone())
449                } else {
450                    None
451                }
452            }
453            Op::Param { device: None, .. } => None,
454            Op::Copy { device, .. } => {
455                if let Op::Device(spec) = device.op() {
456                    Some(spec.clone())
457                } else {
458                    None
459                }
460            }
461            _ => {
462                // Search children for device
463                for child in self.op().children() {
464                    if let Some(spec) = child.device_spec() {
465                        return Some(spec);
466                    }
467                }
468                None
469            }
470        }
471    }
472
473    /// Get the base UOp by walking through movement operations.
474    ///
475    /// Movement operations (RESHAPE, PERMUTE, EXPAND, etc.) are views that don't
476    /// change the underlying data. This method recursively walks through these
477    /// operations to find the actual buffer or computation that owns the data.
478    ///
479    /// Based on Tinygrad's `base` property (ops.py:524-527).
480    ///
481    /// # Examples
482    ///
483    /// ```rust
484    /// # use morok_ir::{UOp, SInt, shape::Shape};
485    /// # use morok_dtype::DType;
486    /// # use morok_dtype::DeviceSpec;
487    /// let buffer = UOp::new_buffer(DeviceSpec::Cpu, 10, DType::Float32);
488    /// let shape = Shape::from_iter([SInt::Const(2), SInt::Const(5)]);
489    /// let reshaped = buffer.try_reshape(&shape).unwrap();
490    ///
491    /// // base() walks through RESHAPE to get the original BUFFER
492    /// assert!(std::sync::Arc::ptr_eq(&reshaped.base(), &buffer));
493    /// ```
494    pub fn base(self: &Arc<Self>) -> Arc<Self> {
495        match &self.op {
496            // Movement operations - recursively get base of source
497            Op::Reshape { src, .. }
498            | Op::Permute { src, .. }
499            | Op::Expand { src, .. }
500            | Op::Pad { src, .. }
501            | Op::Shrink { src, .. }
502            | Op::Flip { src, .. }
503            | Op::Multi { src, .. } => src.base(),
504            // All other operations are their own base
505            _ => self.clone(),
506        }
507    }
508
509    /// Get the underlying buffer UOp, walking through AFTER/MSELECT/MSTACK chains.
510    ///
511    /// Based on Tinygrad's `buf_uop` property (ops.py:601-606).
512    /// This recursively unwraps AFTER chains to find the actual buffer.
513    ///
514    /// # Examples
515    ///
516    /// ```ignore
517    /// use morok_ir::UOp;
518    ///
519    /// // AFTER wrapping a buffer
520    /// let buffer = UOp::new_buffer(...);
521    /// let after = buffer.after(deps);
522    ///
523    /// // buf_uop() walks through AFTER to get the underlying buffer
524    /// assert!(Arc::ptr_eq(&after.buf_uop(), &buffer));
525    /// ```
526    pub fn buf_uop(self: &Arc<Self>) -> Arc<Self> {
527        match self.op() {
528            Op::Buffer { .. } => self.clone(),
529            Op::MSelect { buffer, .. } => buffer.buf_uop(),
530            Op::MStack { buffers } if !buffers.is_empty() => buffers[0].buf_uop(),
531            Op::After { passthrough, .. } => passthrough.buf_uop(),
532            _ => {
533                // For other ops, check if base is AFTER
534                let base = self.base();
535                if matches!(base.op(), Op::After { .. }) { base.buf_uop() } else { self.clone() }
536            }
537        }
538    }
539
540    /// Topological sort of the computation graph.
541    ///
542    /// Returns nodes in an order where all dependencies come before their dependents.
543    pub fn toposort(self: &Arc<Self>) -> Vec<Arc<Self>> {
544        let mut visited = HashSet::new();
545        let mut result = Vec::new();
546        let mut stack = vec![(self.clone(), false)];
547
548        while let Some((node, processed)) = stack.pop() {
549            let ptr = Arc::as_ptr(&node);
550
551            if visited.contains(&ptr) {
552                continue;
553            }
554
555            if processed {
556                visited.insert(ptr);
557                result.push(node);
558            } else {
559                stack.push((node.clone(), true));
560
561                // Use for_each_child for zero-allocation traversal
562                let mut children = Vec::new();
563                node.op.map_child(|child| {
564                    if !visited.contains(&Arc::as_ptr(child)) {
565                        children.push(child.clone());
566                    }
567                });
568
569                // Push in reverse order for proper traversal
570                for child in children.into_iter().rev() {
571                    stack.push((child, false));
572                }
573            }
574        }
575
576        result
577    }
578
579    /// Topological sort with gate function (filtered toposort).
580    ///
581    /// Only traverses nodes for which `gate(node)` returns true.
582    /// Nodes for which gate returns false are excluded from the
583    /// traversal entirely (along with their ancestors).
584    ///
585    /// This is a key optimization for cached property computation,
586    /// allowing us to skip nodes that already have a property cached.
587    ///
588    /// # Performance
589    ///
590    /// For a graph with 10,000 nodes where 9,900 already have a cached property:
591    /// - **Full toposort**: 10,000 nodes visited
592    /// - **Filtered toposort**: 100 nodes visited
593    /// - **Speedup**: 100x
594    ///
595    /// # Example
596    ///
597    /// ```ignore
598    /// // Only process nodes that don't have shape cached
599    /// let uncached = uop.toposort_filtered(|node| {
600    ///     node.shape_cache.get().is_none()
601    /// });
602    /// ```
603    pub fn toposort_filtered<F>(self: &Arc<Self>, gate: F) -> Vec<Arc<Self>>
604    where
605        F: Fn(&Arc<UOp>) -> bool,
606    {
607        let mut visited = HashSet::new();
608        let mut result = Vec::new();
609        let mut stack = vec![(self.clone(), false)];
610
611        while let Some((node, processed)) = stack.pop() {
612            let ptr = Arc::as_ptr(&node);
613
614            if visited.contains(&ptr) {
615                continue;
616            }
617
618            if processed {
619                visited.insert(ptr);
620                result.push(node);
621            } else {
622                // Key optimization: only traverse nodes that pass the gate
623                if gate(&node) {
624                    stack.push((node.clone(), true));
625
626                    let mut children = Vec::new();
627                    node.op.map_child(|child| {
628                        if !visited.contains(&Arc::as_ptr(child)) {
629                            children.push(child.clone());
630                        }
631                    });
632
633                    // Push in reverse order for proper traversal
634                    for child in children.into_iter().rev() {
635                        stack.push((child, false));
636                    }
637                }
638            }
639        }
640
641        result
642    }
643
644    /// Check if any node in the backward slice satisfies a predicate.
645    ///
646    /// Early-exit DFS — returns `true` as soon as a matching node is found,
647    /// without building the full toposort Vec. Use this instead of
648    /// `toposort().iter().any(pred)` when you only need an existential check.
649    pub fn any_in_subtree<F>(self: &Arc<Self>, pred: F) -> bool
650    where
651        F: Fn(&Arc<UOp>) -> bool,
652    {
653        let mut visited = HashSet::new();
654        let mut stack = vec![self.clone()];
655        while let Some(node) = stack.pop() {
656            if !visited.insert(Arc::as_ptr(&node)) {
657                continue;
658            }
659            if pred(&node) {
660                return true;
661            }
662            node.op.map_child(|child| {
663                if !visited.contains(&Arc::as_ptr(child)) {
664                    stack.push(child.clone());
665                }
666            });
667        }
668        false
669    }
670
671    /// Collect all nodes in the backward slice that match a predicate.
672    ///
673    /// DFS collecting matches — cheaper than `toposort().iter().filter(pred).collect()`
674    /// when you don't need topological ordering.
675    pub fn collect_in_subtree<F>(self: &Arc<Self>, pred: F) -> Vec<Arc<UOp>>
676    where
677        F: Fn(&Arc<UOp>) -> bool,
678    {
679        let mut visited = HashSet::new();
680        let mut stack = vec![self.clone()];
681        let mut result = Vec::new();
682        while let Some(node) = stack.pop() {
683            if !visited.insert(Arc::as_ptr(&node)) {
684                continue;
685            }
686            if pred(&node) {
687                result.push(node.clone());
688            }
689            node.op.map_child(|child| {
690                if !visited.contains(&Arc::as_ptr(child)) {
691                    stack.push(child.clone());
692                }
693            });
694        }
695        result
696    }
697
698    /// Count unique nodes in the DAG rooted at this UOp.
699    ///
700    /// Much cheaper than `toposort().len()` — no result Vec, no ordering.
701    /// Uses pointer-based visited set for O(1) identity checks.
702    pub fn node_count(self: &Arc<Self>) -> usize {
703        let mut visited = HashSet::new();
704        let mut stack = vec![self.clone()];
705        while let Some(node) = stack.pop() {
706            if !visited.insert(Arc::as_ptr(&node)) {
707                continue;
708            }
709            node.op.map_child(|child| {
710                if !visited.contains(&Arc::as_ptr(child)) {
711                    stack.push(child.clone());
712                }
713            });
714        }
715        visited.len()
716    }
717
718    /// O(1) cached check: does this node or any of its sources contain an INDEX op?
719    ///
720    /// Computed lazily and cached. Each node checks itself and its direct sources'
721    /// cached values, so the total cost across the graph is O(N).
722    pub fn has_index_in_sources(self: &Arc<Self>) -> bool {
723        *self.has_index_in_sources_cache.get_or_init(|| {
724            if matches!(self.op, Op::Index { .. }) {
725                return true;
726            }
727            let mut result = false;
728            self.op.map_child(|child| {
729                if child.has_index_in_sources() {
730                    result = true;
731                }
732            });
733            result
734        })
735    }
736
737    /// Render this UOp and its sources as a compact ASCII tree.
738    ///
739    /// Shared nodes (appearing multiple times due to hash-consing) are shown
740    /// as back-references: `[id] → (see above)`
741    ///
742    /// # Example Output
743    ///
744    /// ```text
745    /// [42] STORE : Void
746    /// ├── [10] PARAM(0) : Ptr<Float32> shape=[4]
747    /// ├── [35] INDEX : Ptr<Float32> shape=[4]
748    /// │   ├── [10] → (see above)
749    /// │   └── [30] RANGE(0, Reduce) : Index
750    /// │       └── [5] CONST(Int(4)) : Index
751    /// └── [40] REDUCE(Add) : Float32 shape=[]
752    ///     └── [35] → (see above)
753    /// ```
754    pub fn tree(self: &Arc<Self>) -> String {
755        crate::uop::tree::render_tree_compact(self)
756    }
757
758    /// Render this UOp and its sources as a full ASCII tree.
759    ///
760    /// Shared nodes are expanded every time they appear (verbose but complete).
761    /// Use this when you need to see the full subtree at every occurrence.
762    pub fn tree_full(self: &Arc<Self>) -> String {
763        crate::uop::tree::render_tree_full(self)
764    }
765
766    /// Get all RANGE operations in this UOp's computation graph.
767    ///
768    /// Lazily computed and cached. Useful for rangeify pass to track loop variables.
769    pub fn ranges(self: &Arc<Self>) -> &Vec<Arc<Self>> {
770        use crate::uop::cached_property::CachedProperty;
771        use crate::uop::properties::RangesProperty;
772        RangesProperty::get(self)
773    }
774
775    /// Get the RANGE operations that are in scope at this UOp.
776    ///
777    /// Returns only the ranges that are currently "active" (not yet ended).
778    /// This is computed by:
779    /// 1. Merging ranges from all source operations
780    /// 2. Removing ranges that are ended by this operation
781    /// 3. Adding self if this is a RANGE operation
782    ///
783    /// Based on Tinygrad's `ranges` property (ops.py:318-320) and
784    /// `_ranges` recursive property (ops.py:302-315).
785    ///
786    /// # Returns
787    ///
788    /// A HashSet of RANGE UOps that are in scope at this point in the graph.
789    /// The result is cached for performance.
790    ///
791    /// # Examples
792    ///
793    /// ```ignore
794    /// use morok_ir::{UOp, AxisType};
795    ///
796    /// // A simple computation inside a range
797    /// let range = UOp::range(end, 0, AxisType::Loop);
798    /// let value = UOp::const_(...);
799    /// let end_op = value.end(vec![range.clone()]);
800    ///
801    /// // Value has range in scope
802    /// assert!(value.in_scope_ranges().contains(&range));
803    ///
804    /// // After END, range is no longer in scope
805    /// assert!(!end_op.in_scope_ranges().contains(&range));
806    /// ```
807    #[allow(clippy::mutable_key_type)]
808    pub fn in_scope_ranges(self: &Arc<Self>) -> &HashSet<UOpKey> {
809        use crate::uop::cached_property::CachedProperty;
810        use crate::uop::properties::InScopeRangesProperty;
811        InScopeRangesProperty::get(self)
812    }
813
814    /// Check if all in-scope ranges at this UOp have the given AxisType.
815    ///
816    /// Returns true if the in-scope ranges set is empty or all ranges
817    /// match the specified axis type.
818    ///
819    /// # Use Cases
820    ///
821    /// - `all_in_scope_ranges_are(AxisType::Outer)` - Used in split_store
822    ///   to determine if we're at a kernel boundary
823    ///
824    /// # Examples
825    ///
826    /// ```ignore
827    /// use morok_ir::{UOp, AxisType};
828    ///
829    /// // At kernel boundary: only OUTER ranges in scope
830    /// assert!(uop.all_in_scope_ranges_are(AxisType::Outer));
831    ///
832    /// // Inside kernel: has non-OUTER ranges
833    /// assert!(!uop.all_in_scope_ranges_are(AxisType::Outer));
834    /// ```
835    #[allow(clippy::mutable_key_type)]
836    pub fn all_in_scope_ranges_are(self: &Arc<Self>, axis_type: AxisType) -> bool {
837        use crate::Op;
838
839        let ranges = self.in_scope_ranges();
840
841        // Empty scope means we're at the top level (treat as all OUTER)
842        if ranges.is_empty() {
843            return true;
844        }
845
846        ranges.iter().all(|r| match r.0.op() {
847            Op::Range { axis_type: at, .. } => *at == axis_type,
848            _ => false, // Should never happen
849        })
850    }
851
852    /// Check if any in-scope range is NOT of the given AxisType.
853    ///
854    /// Inverse of `all_in_scope_ranges_are`. Useful for Tinygrad-style
855    /// filtering: "skip if any range is not OUTER".
856    ///
857    /// # Examples
858    ///
859    /// ```ignore
860    /// use morok_ir::{UOp, AxisType};
861    ///
862    /// // Has non-OUTER ranges: should skip in split_store
863    /// if uop.has_non_outer_ranges() {
864    ///     return None;  // Don't split here
865    /// }
866    /// ```
867    pub fn has_non_outer_ranges(self: &Arc<Self>) -> bool {
868        !self.all_in_scope_ranges_are(AxisType::Outer)
869    }
870
871    /// Build a consumer map for this UOp's computation graph.
872    ///
873    /// Returns a HashMap where each UOp maps to the list of UOps that consume it.
874    /// Useful for reverse traversal and dependency analysis.
875    #[allow(clippy::mutable_key_type)]
876    pub fn get_consumer_map(self: &Arc<Self>) -> HashMap<UOpKey, Vec<Arc<Self>>> {
877        let mut consumer_map: HashMap<UOpKey, Vec<Arc<Self>>> = HashMap::new();
878
879        for node in self.toposort() {
880            node.op.map_child(|child| {
881                consumer_map.entry(UOpKey(child.clone())).or_default().push(node.clone());
882            });
883        }
884
885        consumer_map
886    }
887
888    /// Reverse topological sort of the computation graph.
889    ///
890    /// Returns nodes in bottom-up order (leaves first, root last).
891    /// Requires a consumer map to traverse from leaves to roots.
892    #[allow(clippy::mutable_key_type)]
893    pub fn reverse_toposort(self: &Arc<Self>, consumer_map: &HashMap<UOpKey, Vec<Arc<Self>>>) -> Vec<Arc<Self>> {
894        let mut visited = HashMap::new(); // Use HashMap to track visited by ID
895        let mut result = Vec::new();
896        let mut stack = vec![(self.clone(), false)];
897
898        while let Some((node, processed)) = stack.pop() {
899            if visited.contains_key(&node.id) {
900                continue;
901            }
902
903            if processed {
904                visited.insert(node.id, ());
905                result.push(node);
906            } else {
907                stack.push((node.clone(), true));
908
909                // Visit consumers (nodes that depend on this node)
910                if let Some(consumers) = consumer_map.get(&UOpKey(node.clone())) {
911                    for consumer in consumers {
912                        if !visited.contains_key(&consumer.id) {
913                            stack.push((consumer.clone(), false));
914                        }
915                    }
916                }
917            }
918        }
919
920        result
921    }
922
923    /// Replace UOps in the computation graph according to a substitution map.
924    ///
925    /// Delegates to `graph_rewrite_bottom_up` with a wildcard pattern that looks up
926    /// each node in the map — exactly like Tinygrad's `substitute`. The rewrite engine
927    /// provides O(n) memoization via its result cache.
928    #[allow(clippy::mutable_key_type)]
929    pub fn substitute(self: &Arc<Self>, map: &HashMap<UOpKey, Arc<Self>>) -> Arc<Self> {
930        if map.is_empty() {
931            return self.clone();
932        }
933        let matcher = SubstituteMatcher(map);
934        crate::rewrite::graph_rewrite_bottom_up(&matcher, self.clone(), &mut ())
935    }
936
937    /// Replace UOps with range-gated substitution (Tinygrad: `extra_pm=pm_gate_substitute`).
938    ///
939    /// Like `substitute`, but skips subtrees whose `in_scope_ranges()` don't contain
940    /// any of the substitution keys. This prevents substituting ranges in subexpressions
941    /// that don't reference them, matching Tinygrad's `gate_substitute` behavior.
942    #[allow(clippy::mutable_key_type)]
943    pub fn substitute_gated(self: &Arc<Self>, map: &HashMap<UOpKey, Arc<Self>>) -> Arc<Self> {
944        if map.is_empty() {
945            return self.clone();
946        }
947        let range_keys: HashSet<UOpKey> = map.keys().cloned().collect();
948        let matcher = SubstituteGatedMatcher { map, range_keys: &range_keys };
949        crate::rewrite::graph_rewrite_bottom_up(&matcher, self.clone(), &mut ())
950    }
951
952    /// Reconstruct this UOp with new sources.
953    ///
954    /// Creates a new UOp with the same operation type and dtype, but with the provided
955    /// sources replacing the original ones. Hash consing ensures that if an identical
956    /// UOp already exists, it will be reused.
957    ///
958    /// This is used by the graph rewrite engine when sources have been rewritten.
959    ///
960    /// # Panics
961    ///
962    /// Panics if the number of sources doesn't match the operation's arity.
963    ///
964    /// # Examples
965    ///
966    /// ```ignore
967    /// // Original: a + b
968    /// let add = UOp::add(a.clone(), b.clone());
969    ///
970    /// // Rewrite sources: a' + b'
971    /// let new_add = add.with_sources(vec![a_prime, b_prime]);
972    /// ```
973    pub fn with_sources(self: &Arc<Self>, new_srcs: Vec<Arc<Self>>) -> Arc<Self> {
974        use smallvec::SmallVec;
975
976        // Helper to get nth source
977        let src = |n: usize| new_srcs[n].clone();
978
979        let new_op = match &self.op {
980            // Nullary operations - no sources
981            Op::Const(_)
982            | Op::Unique(_)
983            | Op::Device(_)
984            | Op::Noop
985            | Op::Invalid
986            | Op::DefineLocal(_)
987            | Op::VConst { .. }
988            | Op::DefineVar { .. }
989            | Op::DefineReg { .. } => {
990                assert_eq!(new_srcs.len(), 0, "Nullary op should have no sources");
991                return self.clone(); // No sources to replace
992            }
993
994            // Unary operations
995            Op::Unary(op_type, _) => {
996                assert_eq!(new_srcs.len(), 1);
997                Op::Unary(*op_type, src(0))
998            }
999
1000            // Binary operations
1001            Op::Binary(op_type, _, _) => {
1002                assert_eq!(new_srcs.len(), 2);
1003                Op::Binary(*op_type, src(0), src(1))
1004            }
1005
1006            // Ternary operations
1007            Op::Ternary(op_type, _, _, _) => {
1008                assert_eq!(new_srcs.len(), 3);
1009                Op::Ternary(*op_type, src(0), src(1), src(2))
1010            }
1011
1012            // Type operations
1013            Op::Cast { dtype, .. } => {
1014                assert_eq!(new_srcs.len(), 1);
1015                Op::Cast { src: src(0), dtype: dtype.clone() }
1016            }
1017            Op::BitCast { dtype, .. } => {
1018                assert_eq!(new_srcs.len(), 1);
1019                Op::BitCast { src: src(0), dtype: dtype.clone() }
1020            }
1021
1022            // Special operations
1023            Op::MSelect { device_index, .. } => {
1024                assert_eq!(new_srcs.len(), 1);
1025                Op::MSelect { buffer: src(0), device_index: *device_index }
1026            }
1027            Op::Special { name, .. } => {
1028                assert_eq!(new_srcs.len(), 1);
1029                Op::Special { end: src(0), name: name.clone() }
1030            }
1031
1032            // Buffer operations
1033            Op::Buffer { size, .. } => {
1034                assert_eq!(new_srcs.len(), 2);
1035                Op::Buffer { unique: src(0), device: src(1), size: *size }
1036            }
1037            Op::Param { slot, size, device } => {
1038                if device.is_some() {
1039                    assert_eq!(new_srcs.len(), 1);
1040                    Op::Param { slot: *slot, size: *size, device: Some(src(0)) }
1041                } else {
1042                    assert_eq!(new_srcs.len(), 0);
1043                    return self.clone();
1044                }
1045            }
1046            Op::BufferView { size, offset, .. } => {
1047                assert_eq!(new_srcs.len(), 1);
1048                Op::BufferView { buffer: src(0), size: *size, offset: *offset }
1049            }
1050            Op::Bufferize { opts, .. } => {
1051                assert!(!new_srcs.is_empty());
1052                Op::Bufferize { compute: src(0), ranges: new_srcs[1..].iter().cloned().collect(), opts: opts.clone() }
1053            }
1054            Op::Index { gate, .. } => {
1055                assert!(!new_srcs.is_empty());
1056                // First source is buffer, rest are indices, last might be gate
1057                let buffer = src(0);
1058                let (indices, gate_new) = if gate.is_some() && new_srcs.len() >= 2 {
1059                    let gate_src = new_srcs.last().unwrap().clone();
1060                    let indices: SmallVec<[Arc<Self>; 4]> = new_srcs[1..new_srcs.len() - 1].iter().cloned().collect();
1061                    (indices, Some(gate_src))
1062                } else {
1063                    let indices: SmallVec<[Arc<Self>; 4]> = new_srcs[1..].iter().cloned().collect();
1064                    (indices, None)
1065                };
1066                Op::Index { buffer, indices, gate: gate_new }
1067            }
1068            Op::PointerIndex { .. } => {
1069                assert_eq!(new_srcs.len(), 2);
1070                Op::PointerIndex { ptr: src(0), offset: src(1) }
1071            }
1072            Op::Copy { .. } => {
1073                assert_eq!(new_srcs.len(), 2);
1074                Op::Copy { src: src(0), device: src(1) }
1075            }
1076            Op::MStack { .. } => Op::MStack { buffers: new_srcs.iter().cloned().collect() },
1077
1078            // Movement operations
1079            Op::Reshape { .. } => {
1080                assert_eq!(new_srcs.len(), 2);
1081                Op::Reshape { src: src(0), new_shape: src(1) }
1082            }
1083            Op::Permute { axes, .. } => {
1084                assert_eq!(new_srcs.len(), 1);
1085                Op::Permute { src: src(0), axes: axes.clone() }
1086            }
1087            Op::Expand { .. } => {
1088                assert_eq!(new_srcs.len(), 2);
1089                Op::Expand { src: src(0), new_shape: src(1) }
1090            }
1091            Op::Pad { .. } => {
1092                assert_eq!(new_srcs.len(), 3);
1093                Op::Pad { src: src(0), begin_pads: src(1), end_pads: src(2) }
1094            }
1095            Op::Shrink { .. } => {
1096                assert_eq!(new_srcs.len(), 3);
1097                Op::Shrink { src: src(0), begins: src(1), ends: src(2) }
1098            }
1099            Op::Flip { axes, .. } => {
1100                assert_eq!(new_srcs.len(), 1);
1101                Op::Flip { src: src(0), axes: axes.clone() }
1102            }
1103            Op::Multi { axis, .. } => {
1104                assert_eq!(new_srcs.len(), 1);
1105                Op::Multi { src: src(0), axis: *axis }
1106            }
1107
1108            // Reduction operations
1109            Op::ReduceAxis { reduce_op, axes, .. } => {
1110                assert_eq!(new_srcs.len(), 1);
1111                Op::ReduceAxis { src: src(0), reduce_op: *reduce_op, axes: axes.clone() }
1112            }
1113            Op::Reduce { reduce_op, .. } => {
1114                assert!(!new_srcs.is_empty());
1115                Op::Reduce { src: src(0), ranges: new_srcs[1..].iter().cloned().collect(), reduce_op: *reduce_op }
1116            }
1117            Op::AllReduce { reduce_op, .. } => {
1118                assert_eq!(new_srcs.len(), 2);
1119                Op::AllReduce { src: src(0), device: src(1), reduce_op: *reduce_op }
1120            }
1121
1122            // Control flow operations
1123            Op::If { .. } => {
1124                assert!(!new_srcs.is_empty());
1125                Op::If { condition: src(0), body: new_srcs[1..].iter().cloned().collect() }
1126            }
1127            Op::EndIf { .. } => {
1128                assert_eq!(new_srcs.len(), 1);
1129                Op::EndIf { if_op: src(0) }
1130            }
1131            Op::Range { axis_id, axis_type, .. } => {
1132                assert!(!new_srcs.is_empty());
1133                Op::Range {
1134                    end: src(0),
1135                    axis_id: *axis_id,
1136                    axis_type: *axis_type,
1137                    deps: new_srcs[1..].iter().cloned().collect(),
1138                }
1139            }
1140            Op::End { .. } => {
1141                assert!(!new_srcs.is_empty());
1142                Op::End { computation: src(0), ranges: new_srcs[1..].iter().cloned().collect() }
1143            }
1144            Op::Barrier { .. } => {
1145                assert!(!new_srcs.is_empty());
1146                Op::Barrier { src: src(0), deps: new_srcs[1..].iter().cloned().collect() }
1147            }
1148
1149            // Vector operations — recompute dtype from new elements when element
1150            // dtype category changed (e.g. Scalar → Ptr during rewrite reconstruction).
1151            // Preserving old dtype is wrong when DEFINE_LOCAL → AFTER(Ptr) changes
1152            // element types from Scalar to Ptr, causing pm_add_loads infinite loops.
1153            Op::Vectorize { .. } => {
1154                let elements: SmallVec<[Arc<Self>; 4]> = new_srcs.iter().cloned().collect();
1155                let elem_dtype = elements[0].dtype();
1156                let new_dtype = match elem_dtype {
1157                    DType::Scalar(_) | DType::Ptr { .. } => elem_dtype.vec(elements.len()),
1158                    _ => self.dtype.clone(),
1159                };
1160                return Self::new(Op::Vectorize { elements }, new_dtype);
1161            }
1162            Op::Gep { indices, .. } => {
1163                assert_eq!(new_srcs.len(), 1);
1164                Op::Gep { vector: src(0), indices: indices.clone() }
1165            }
1166            Op::Cat { .. } => Op::Cat { sources: new_srcs.iter().cloned().collect() },
1167            Op::PtrCat { .. } => Op::PtrCat { sources: new_srcs.iter().cloned().collect() },
1168
1169            // Symbolic/Define operations
1170            Op::Bind { .. } => {
1171                assert_eq!(new_srcs.len(), 2);
1172                Op::Bind { var: src(0), value: src(1) }
1173            }
1174
1175            // Advanced operations
1176            Op::Wmma { metadata, .. } => {
1177                assert_eq!(new_srcs.len(), 3);
1178                Op::Wmma { a: src(0), b: src(1), c: src(2), metadata: metadata.clone() }
1179            }
1180            Op::Contract { upcast_ranges, .. } => {
1181                assert_eq!(new_srcs.len(), 1);
1182                Op::Contract { src: src(0), upcast_ranges: upcast_ranges.clone() }
1183            }
1184            Op::Unroll { unroll_axes, .. } => {
1185                assert_eq!(new_srcs.len(), 1);
1186                Op::Unroll { src: src(0), unroll_axes: unroll_axes.clone() }
1187            }
1188            Op::Kernel { .. } => {
1189                assert!(!new_srcs.is_empty());
1190                Op::Kernel {
1191                    sources: new_srcs[..new_srcs.len() - 1].iter().cloned().collect(),
1192                    ast: src(new_srcs.len() - 1),
1193                }
1194            }
1195            Op::Assign { .. } => {
1196                assert!(new_srcs.len() >= 2 && new_srcs.len() <= 3, "Assign requires 2-3 sources");
1197                Op::Assign {
1198                    target: src(0),
1199                    value: src(1),
1200                    movement_ops: if new_srcs.len() > 2 { Some(src(2)) } else { None },
1201                }
1202            }
1203            Op::Detach { .. } => {
1204                assert_eq!(new_srcs.len(), 1);
1205                Op::Detach { src: src(0) }
1206            }
1207            Op::Contiguous { opts, .. } => {
1208                assert_eq!(new_srcs.len(), 1);
1209                Op::Contiguous { src: src(0), opts: opts.clone() }
1210            }
1211            Op::ContiguousBackward { .. } => {
1212                assert_eq!(new_srcs.len(), 1);
1213                Op::ContiguousBackward { src: src(0) }
1214            }
1215            Op::After { .. } => {
1216                assert!(!new_srcs.is_empty());
1217                let passthrough = src(0);
1218                // Validate: AFTER passthrough must not be control flow (Tinygrad semantics)
1219                debug_assert!(
1220                    !matches!(passthrough.op(), Op::Range { .. } | Op::End { .. }),
1221                    "reconstruct_sources: AFTER passthrough is {:?} (id={}), violates Tinygrad semantics",
1222                    passthrough.op(),
1223                    passthrough.id
1224                );
1225                Op::After { passthrough, deps: new_srcs[1..].iter().cloned().collect() }
1226            }
1227            Op::Precast { .. } => {
1228                assert_eq!(new_srcs.len(), 1);
1229                Op::Precast { src: src(0) }
1230            }
1231            Op::Custom { code, .. } => Op::Custom { deps: new_srcs.iter().cloned().collect(), code: code.clone() },
1232            Op::CustomI { code, .. } => Op::CustomI { deps: new_srcs.iter().cloned().collect(), code: code.clone() },
1233
1234            // Memory operations
1235            Op::Load { alt, .. } => {
1236                // Load has 2-3 sources: buffer, index, and optionally alt
1237                assert!(new_srcs.len() >= 2 && new_srcs.len() <= 3, "Load requires 2-3 sources");
1238                let new_alt = if new_srcs.len() == 3 { Some(src(2)) } else { alt.clone() };
1239                Op::Load { buffer: src(0), index: src(1), alt: new_alt }
1240            }
1241            Op::Store { .. } => {
1242                assert!(new_srcs.len() >= 2, "Store requires at least 2 sources (index, value)");
1243                Op::Store { index: src(0), value: src(1), ranges: new_srcs[2..].iter().cloned().collect() }
1244            }
1245
1246            // Graph organization
1247            Op::Sink { .. } => Op::Sink { sources: new_srcs.iter().cloned().collect() },
1248            Op::Group { .. } => Op::Group { sources: new_srcs.iter().cloned().collect() },
1249        };
1250
1251        // Preserve original dtype and tag (Tinygrad ops.py:1256: preserves tag through source reconstruction)
1252        Self::new_tagged(new_op, self.dtype.clone(), self.tag.clone())
1253    }
1254}
1255
1256#[bon]
1257impl UOp {
1258    /// Create a modified copy with optional field overrides.
1259    ///
1260    /// Enables concise pattern implementations by allowing selective field modification.
1261    /// Returns `self.clone()` if nothing changed (optimization for hash consing).
1262    ///
1263    /// # Examples
1264    ///
1265    /// ```ignore
1266    /// let new_load = load.replace().dtype(new_dtype).src(new_sources).call();
1267    /// let dtype_only = load.replace().dtype(new_dtype).call();
1268    /// ```
1269    #[builder]
1270    pub fn replace(self: &Arc<Self>, dtype: Option<DType>, src: Option<Vec<Arc<Self>>>) -> Arc<Self> {
1271        let new_dtype = dtype.unwrap_or_else(|| self.dtype());
1272        let new_sources = src.unwrap_or_else(|| self.op().sources().to_vec());
1273
1274        // Short-circuit if nothing changed
1275        let old_sources = self.op().sources();
1276        let sources_unchanged = new_sources.len() == old_sources.len()
1277            && new_sources.iter().zip(old_sources.iter()).all(|(a, b)| Arc::ptr_eq(a, b));
1278
1279        if new_dtype == self.dtype() && sources_unchanged {
1280            return self.clone();
1281        }
1282
1283        self.with_sources(new_sources).with_dtype(new_dtype)
1284    }
1285}
1286
1287impl Clone for UOp {
1288    fn clone(&self) -> Self {
1289        Self {
1290            id: self.id,
1291            op: self.op.clone(),
1292            dtype: self.dtype.clone(),
1293            content_hash: self.content_hash,
1294            tag: self.tag.clone(),
1295            shape_cache: std::sync::OnceLock::new(),
1296            ranges_cache: std::sync::OnceLock::new(),
1297            in_scope_ranges_cache: std::sync::OnceLock::new(),
1298            vmin_vmax_cache: std::sync::OnceLock::new(),
1299            sound_vmin_vmax_cache: std::sync::OnceLock::new(),
1300            has_index_in_sources_cache: std::sync::OnceLock::new(),
1301            backward_slice_cache: std::sync::OnceLock::new(),
1302            metadata: self.metadata.clone(),
1303        }
1304    }
1305}
1306
1307/// Trait for converting scalar values into UOps.
1308///
1309/// This allows operator overloading to work with mixed scalar/UOp operands.
1310/// For example: `uop + 5.0` or `5.0 + uop`.
1311pub trait IntoUOp {
1312    fn into_uop(self, dtype: DType) -> Arc<UOp>;
1313}
1314
1315impl IntoUOp for ConstValue {
1316    fn into_uop(self, dtype: DType) -> Arc<UOp> {
1317        UOp::const_(dtype, self)
1318    }
1319}
1320
1321impl IntoUOp for f32 {
1322    fn into_uop(self, dtype: DType) -> Arc<UOp> {
1323        UOp::const_(dtype, ConstValue::Float(self as f64))
1324    }
1325}
1326
1327impl IntoUOp for f64 {
1328    fn into_uop(self, dtype: DType) -> Arc<UOp> {
1329        UOp::const_(dtype, ConstValue::Float(self))
1330    }
1331}
1332
1333impl IntoUOp for i32 {
1334    fn into_uop(self, dtype: DType) -> Arc<UOp> {
1335        UOp::const_(dtype, ConstValue::Int(self as i64))
1336    }
1337}
1338
1339impl IntoUOp for i64 {
1340    fn into_uop(self, dtype: DType) -> Arc<UOp> {
1341        UOp::const_(dtype, ConstValue::Int(self))
1342    }
1343}
1344
1345impl IntoUOp for u32 {
1346    fn into_uop(self, dtype: DType) -> Arc<UOp> {
1347        UOp::const_(dtype, ConstValue::UInt(self as u64))
1348    }
1349}
1350
1351impl IntoUOp for u64 {
1352    fn into_uop(self, dtype: DType) -> Arc<UOp> {
1353        UOp::const_(dtype, ConstValue::UInt(self))
1354    }
1355}
1356
1357impl IntoUOp for bool {
1358    fn into_uop(self, dtype: DType) -> Arc<UOp> {
1359        UOp::const_(dtype, ConstValue::Bool(self))
1360    }
1361}