morok_ir/uop/core.rs
1//! Core UOp struct and fundamental operations.
2//!
3//! This module contains the [`UOp`] struct definition and its core methods
4//! for accessing operation data, dtype, shape, and graph traversal.
5
6use std::collections::{HashMap, HashSet};
7use std::hash::{Hash, Hasher};
8use std::sync::Arc;
9
10use bon::bon;
11use smallvec::SmallVec;
12
13use crate::op::Op;
14use crate::pattern::{Matcher, RewriteResult};
15use crate::shape;
16use crate::types::{AxisType, ConstValue};
17use morok_dtype::DType;
18
19/// Matcher for `UOp::substitute` — looks up each node in a substitution map.
20///
21/// Equivalent to Tinygrad's `_substitute = PatternMatcher([(UPat(tuple(Ops)), lambda ctx,x: ctx.get(x))])`.
22struct SubstituteMatcher<'a>(&'a HashMap<UOpKey, Arc<UOp>>);
23
24impl Matcher<()> for SubstituteMatcher<'_> {
25 fn rewrite(&self, uop: &Arc<UOp>, _ctx: &mut ()) -> RewriteResult {
26 match self.0.get(&UOpKey(uop.clone())) {
27 Some(replacement) if !Arc::ptr_eq(uop, replacement) => RewriteResult::Rewritten(replacement.clone()),
28 _ => RewriteResult::NoMatch,
29 }
30 }
31}
32
33/// Matcher for `UOp::substitute_gated` — substitution with range-scope gating.
34///
35/// Equivalent to Tinygrad's `_substitute` + `pm_gate_substitute`:
36/// - If a node is in the substitution map, replace it.
37/// - If a node's ranges don't overlap with substitution keys, gate (skip subtree).
38struct SubstituteGatedMatcher<'a> {
39 map: &'a HashMap<UOpKey, Arc<UOp>>,
40 range_keys: &'a HashSet<UOpKey>,
41}
42
43impl Matcher<()> for SubstituteGatedMatcher<'_> {
44 fn rewrite(&self, uop: &Arc<UOp>, _ctx: &mut ()) -> RewriteResult {
45 // Direct substitution lookup
46 if let Some(replacement) = self.map.get(&UOpKey(uop.clone()))
47 && !Arc::ptr_eq(uop, replacement)
48 {
49 return RewriteResult::Rewritten(replacement.clone());
50 }
51 // Gate: skip subtrees whose ranges don't overlap with substitution keys
52 // Tinygrad (rangeify.py:187): `if not any(r in b.ranges for r in ctx.keys()): raise BottomUpGate()`
53 if !uop.in_scope_ranges().iter().any(|r| self.range_keys.contains(r)) {
54 return RewriteResult::Gate(uop.clone());
55 }
56 RewriteResult::NoMatch
57 }
58}
59
60/// Wrapper for Arc<UOp> that implements Hash and Eq based on stable ID.
61///
62/// This allows using Arc<UOp> as HashMap keys without implementing
63/// Hash/Eq on UOp itself (which would be problematic due to OnceCell fields).
64///
65/// Note: While UOp contains OnceCell fields, Hash/Eq are based solely on the
66/// immutable `id` field, making this safe to use as a HashMap key.
67#[allow(clippy::mutable_key_type)]
68#[derive(Clone)]
69pub struct UOpKey(pub Arc<UOp>);
70
71// Custom Debug impl to show only the UOp ID, avoiding recursive printing
72impl std::fmt::Debug for UOpKey {
73 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74 write!(f, "UOpKey(id={})", self.0.id)
75 }
76}
77
78impl PartialEq for UOpKey {
79 fn eq(&self, other: &Self) -> bool {
80 self.0.id == other.0.id
81 }
82}
83
84impl Eq for UOpKey {}
85
86impl Hash for UOpKey {
87 fn hash<H: Hasher>(&self, state: &mut H) {
88 self.0.id.hash(state);
89 }
90}
91
92/// Micro-operation node in the computation graph.
93///
94/// UOps form a DAG where operations reference their inputs through the Op enum.
95/// Hash consing ensures that structurally identical UOps share the same allocation.
96///
97/// Shape inference is lazy and cached - computed on first access via `shape()` method.
98///
99/// Note: Debug uses derive_more with `#[debug(skip)]` on cache fields to prevent
100/// stack overflow from recursive Arc<UOp> references in caches.
101#[derive(derive_more::Debug)]
102pub struct UOp {
103 /// Unique stable ID for this UOp instance.
104 /// Used for identity-based caching instead of fragile raw pointers.
105 pub id: u64,
106 pub(crate) op: Op,
107 pub(crate) dtype: DType,
108 /// Cached shape - computed lazily on first access.
109 /// OnceLock provides thread-safe lazy initialization.
110 #[debug(skip)]
111 pub(crate) shape_cache: std::sync::OnceLock<crate::Result<Option<shape::Shape>>>,
112 /// Cached list of RANGE operations in this UOp's graph.
113 /// Computed lazily via toposort to collect all RANGE ops.
114 #[debug(skip)]
115 pub(crate) ranges_cache: std::sync::OnceLock<Vec<Arc<UOp>>>,
116 /// Cached set of RANGE operations that are in scope at this UOp.
117 /// Unlike ranges_cache which contains ALL ranges in the graph,
118 /// this contains only the ranges that are currently "active" (not yet ended).
119 /// Computed lazily based on Tinygrad's ranges property.
120 /// Uses UOpKey wrapper to enable Hash/Eq based on UOp ID.
121 #[debug(skip)]
122 pub(crate) in_scope_ranges_cache: std::sync::OnceLock<HashSet<UOpKey>>,
123 /// Cached vmin/vmax range analysis values.
124 /// Computed lazily via range propagation through the computation graph.
125 /// Returns (vmin, vmax) as ConstValue types.
126 #[debug(skip)]
127 pub(crate) vmin_vmax_cache: std::sync::OnceLock<(ConstValue, ConstValue)>,
128 /// Sound vmin/vmax: `None` for ops where range analysis is unsound (LOAD, Pow, etc.).
129 /// Used by patterns that must not act on unsound bounds (e.g., vmin_vmax_collapse).
130 #[debug(skip)]
131 pub(crate) sound_vmin_vmax_cache: std::sync::OnceLock<Option<(ConstValue, ConstValue)>>,
132 /// Whether this node or any of its sources is an INDEX op.
133 /// Cached O(1) lookup used by `simplify_valid` to skip And chains inside INDEX trees.
134 #[debug(skip)]
135 pub(crate) has_index_in_sources_cache: std::sync::OnceLock<bool>,
136 /// Cached backward slice: IDs of all nodes reachable from this UOp (including self).
137 /// O(1) membership test via `backward_slice_ids().contains(&target.id)`.
138 #[debug(skip)]
139 pub(crate) backward_slice_cache: std::sync::OnceLock<HashSet<u64>>,
140 /// Structural content hash — deterministic regardless of allocation order.
141 /// Computed at creation time: hash(op_discriminant, dtype, op_data, children_content_hashes).
142 /// O(1) per node since children are already created with their content_hash set.
143 /// Used for schedule-level caching where UOp IDs are not stable across runs.
144 pub content_hash: u64,
145 /// Tag for tracking tensor identity through the rangeify pipeline.
146 ///
147 /// Matches Tinygrad's `UOp.tag` (ops.py:128). Tags are tuples of integer indices
148 /// that track which original tensor UOps map to which final kernel outputs.
149 /// Tags participate in hash consing — different tag = different UOp.
150 ///
151 /// Values:
152 /// - `None` — untagged (default)
153 /// - `Some([])` — empty tag (e.g., RANGE ops)
154 /// - `Some([i])` — single index (assigned by add_tags)
155 /// - `Some([i, j, ...])` — merged indices (from buffer folding)
156 pub tag: Option<SmallVec<[usize; 2]>>,
157 /// Optional metadata attached to this UOp.
158 ///
159 /// Metadata is NOT part of hash consing - attaching metadata creates a new UOp
160 /// instance with a different ID. This is used for kernel info (name, opts) after
161 /// optimization is complete.
162 ///
163 /// Uses Arc<dyn Any> to allow attaching any metadata type without circular
164 /// dependencies (e.g., schedule::KernelInfo).
165 #[debug(skip)]
166 pub(crate) metadata: Option<std::sync::Arc<dyn std::any::Any + Send + Sync>>,
167}
168
169/// Hash implementation for UOp based on content (dtype + op).
170///
171/// This enables content-based hashing for cross-run caching. The hash traverses
172/// the DAG structure since Op contains Arc<UOp> children that also get hashed.
173/// Cache fields are intentionally skipped - they don't affect semantic identity.
174impl Hash for UOp {
175 fn hash<H: Hasher>(&self, state: &mut H) {
176 self.dtype.hash(state);
177 self.op.hash(state);
178 // Intentionally skip: id, caches, metadata
179 }
180}
181
182impl UOp {
183 /// Get the operation.
184 pub fn op(&self) -> &Op {
185 &self.op
186 }
187
188 /// Get the data type.
189 pub fn dtype(&self) -> DType {
190 self.dtype.clone()
191 }
192
193 /// Get the tag.
194 pub fn tag(&self) -> &Option<SmallVec<[usize; 2]>> {
195 &self.tag
196 }
197
198 /// Create a new UOp with the given tag (Tinygrad: `rtag()`).
199 /// Returns self unchanged if tag is already equal.
200 pub fn rtag(self: &Arc<Self>, tag: Option<SmallVec<[usize; 2]>>) -> Arc<Self> {
201 if self.tag == tag {
202 return self.clone();
203 }
204 Self::new_tagged(self.op.clone(), self.dtype.clone(), tag)
205 }
206
207 /// Create a new UOp with the given tag set.
208 pub fn with_tag(self: &Arc<Self>, tag: SmallVec<[usize; 2]>) -> Arc<Self> {
209 self.rtag(Some(tag))
210 }
211
212 /// Check if this UOp has a concrete buffer identity in the graph.
213 ///
214 /// Returns true for BUFFER or RESHAPE/MULTI chains leading to BUFFER.
215 /// These are already contiguous by definition, so wrapping in CONTIGUOUS is a no-op.
216 ///
217 /// Based on Tinygrad's `UOp.has_buffer_identity()` (ops.py:616-619).
218 pub fn has_buffer_identity(&self) -> bool {
219 match &self.op {
220 Op::Reshape { src, .. } => src.has_buffer_identity(),
221 Op::Buffer { .. } => true,
222 _ => false,
223 }
224 }
225
226 /// Get pointer dtype components if this UOp has a Ptr dtype.
227 ///
228 /// Returns `(base, addrspace, size)` for Ptr types, None otherwise.
229 /// This simplifies pattern matching on pointer types.
230 ///
231 /// # Examples
232 ///
233 /// ```rust
234 /// # use morok_ir::UOp;
235 /// # use morok_dtype::{DType, AddrSpace, DeviceSpec};
236 /// let buffer = UOp::new_buffer(DeviceSpec::Cpu, 10, DType::Float32);
237 /// if let Some((base, addrspace, size)) = buffer.ptrdtype() {
238 /// assert_eq!(*base, DType::Float32);
239 /// assert_eq!(addrspace, AddrSpace::Global);
240 /// }
241 /// ```
242 pub fn ptrdtype(&self) -> Option<(&DType, morok_dtype::AddrSpace, Option<usize>)> {
243 match &self.dtype {
244 DType::Ptr { base, addrspace, size, .. } => Some((base.as_ref(), *addrspace, *size)),
245 _ => None,
246 }
247 }
248
249 /// Create a copy of this UOp with a different dtype.
250 ///
251 /// If the dtype is unchanged, returns self (clone of Arc).
252 /// This is the Rust equivalent of Tinygrad's `buf.replace(dtype=x)`.
253 ///
254 /// # Examples
255 ///
256 /// ```rust
257 /// # use std::sync::Arc;
258 /// # use morok_ir::UOp;
259 /// # use morok_dtype::DType;
260 /// let int_const = UOp::const_(DType::Int32, morok_ir::ConstValue::Int(5));
261 /// let float_const = int_const.with_dtype(DType::Float32);
262 /// assert_eq!(float_const.dtype(), DType::Float32);
263 /// ```
264 pub fn with_dtype(self: &Arc<Self>, dtype: DType) -> Arc<Self> {
265 if self.dtype == dtype {
266 return self.clone();
267 }
268 Self::new(self.op.clone(), dtype)
269 }
270
271 /// Walk through AFTER nodes to get the passthrough value.
272 ///
273 /// This is the Rust equivalent of Tinygrad's `.or_after()` pattern.
274 /// Recursively unwraps AFTER nodes to find the underlying value.
275 ///
276 /// # Examples
277 ///
278 /// ```ignore
279 /// // Given: AFTER(AFTER(value, [dep1]), [dep2])
280 /// // Returns: value
281 /// let inner = wrapped.unwrap_after();
282 /// ```
283 pub fn unwrap_after(self: &Arc<Self>) -> Arc<Self> {
284 match self.op() {
285 Op::After { passthrough, .. } => passthrough.unwrap_after(),
286 _ => self.clone(),
287 }
288 }
289
290 /// Walk through CAST nodes to get the inner value.
291 ///
292 /// This is the Rust equivalent of Tinygrad's `.or_casted()` pattern.
293 /// Recursively unwraps CAST nodes to find the underlying value.
294 ///
295 /// # Examples
296 ///
297 /// ```ignore
298 /// // Given: CAST(CAST(value, dtype1), dtype2)
299 /// // Returns: value
300 /// let inner = casted.unwrap_cast();
301 /// ```
302 pub fn unwrap_cast(self: &Arc<Self>) -> Arc<Self> {
303 match self.op() {
304 Op::Cast { src, .. } => src.unwrap_cast(),
305 _ => self.clone(),
306 }
307 }
308
309 /// Get the buffer from a STORE operation (via its INDEX child).
310 ///
311 /// STORE operations reference the buffer indirectly through an INDEX node.
312 /// This helper extracts the buffer from `STORE.index.buffer`.
313 ///
314 /// Returns `None` if:
315 /// - This is not a STORE operation
316 /// - The STORE's index is not an INDEX operation
317 pub fn store_buffer(&self) -> Option<&Arc<UOp>> {
318 match self.op() {
319 Op::Store { index, .. } => match index.op() {
320 Op::Index { buffer, .. } => Some(buffer),
321 _ => None,
322 },
323 _ => None,
324 }
325 }
326
327 /// Get the buffer from a LOAD operation.
328 ///
329 /// Returns `None` if this is not a LOAD operation.
330 pub fn load_buffer(&self) -> Option<Arc<UOp>> {
331 match self.op() {
332 Op::Load { buffer, .. } => Some(buffer.clone()),
333 _ => None,
334 }
335 }
336
337 /// Store a value at this INDEX node.
338 ///
339 /// Convenience method for `self.store(value)`.
340 /// Matches Tinygrad's `idx.store(val)` pattern.
341 ///
342 /// # Panics
343 ///
344 /// Debug-asserts that self is an INDEX operation.
345 pub fn store_value(self: &Arc<Self>, value: Arc<Self>) -> Arc<Self> {
346 debug_assert!(matches!(self.op(), Op::Index { .. }), "store_value requires INDEX");
347 self.store(value)
348 }
349
350 /// Alias for `with_sources()`.
351 ///
352 /// Creates a new UOp with the same operation type and dtype, but with
353 /// the provided sources replacing the original ones.
354 pub fn with_src(self: &Arc<Self>, new_srcs: Vec<Arc<Self>>) -> Arc<Self> {
355 self.with_sources(new_srcs)
356 }
357
358 /// Get the shape of this UOp.
359 ///
360 /// Shape is computed lazily on first access and cached.
361 /// Returns Ok(None) if shape cannot be determined (e.g., for control flow ops).
362 /// Returns Err if there is a shape mismatch error.
363 ///
364 /// # Examples
365 ///
366 /// ```rust
367 /// # use morok_ir::{UOp, ConstValue};
368 /// # use morok_dtype::DType;
369 /// let scalar = UOp::const_(DType::Float32, ConstValue::Float(1.0));
370 /// assert_eq!(scalar.shape().unwrap().as_ref().map(|s| s.len()), Some(0)); // Scalar has empty shape
371 /// ```
372 pub fn shape(self: &Arc<Self>) -> crate::Result<Option<&shape::Shape>> {
373 use crate::uop::cached_property::CachedProperty;
374 use crate::uop::properties::ShapeProperty;
375 match ShapeProperty::get(self) {
376 Ok(opt) => Ok(opt.as_ref()),
377 Err(e) => Err(e.clone()),
378 }
379 }
380
381 /// Get the minimum possible value of this UOp.
382 ///
383 /// Returns the minimum value based on range analysis.
384 /// Computed lazily on first access and cached.
385 ///
386 /// # Examples
387 ///
388 /// ```rust
389 /// # use morok_ir::{UOp, ConstValue};
390 /// # use morok_dtype::DType;
391 /// let five = UOp::const_(DType::Int32, ConstValue::Int(5));
392 /// assert_eq!(five.vmin(), &ConstValue::Int(5));
393 /// ```
394 pub fn vmin(self: &Arc<Self>) -> &ConstValue {
395 use crate::uop::cached_property::CachedProperty;
396 use crate::uop::properties::VminVmaxProperty;
397 &VminVmaxProperty::get(self).0
398 }
399
400 /// Get the maximum possible value of this UOp.
401 ///
402 /// Returns the maximum value based on range analysis.
403 /// Computed lazily on first access and cached.
404 ///
405 /// # Examples
406 ///
407 /// ```rust
408 /// # use morok_ir::{UOp, ConstValue};
409 /// # use morok_dtype::DType;
410 /// let five = UOp::const_(DType::Int32, ConstValue::Int(5));
411 /// assert_eq!(five.vmax(), &ConstValue::Int(5));
412 /// ```
413 pub fn vmax(self: &Arc<Self>) -> &ConstValue {
414 use crate::uop::cached_property::CachedProperty;
415 use crate::uop::properties::VminVmaxProperty;
416 &VminVmaxProperty::get(self).1
417 }
418
419 /// Extract device specification from this UOp graph.
420 ///
421 /// Traverses the graph to find Op::Device nodes following Tinygrad's
422 /// `_device` recursive property (ops.py:585-599):
423 /// - DEVICE: returns the DeviceSpec directly
424 /// - BUFFER: returns device from the device child
425 /// - COPY: returns device from the device child (target device)
426 /// - Otherwise: searches children recursively
427 ///
428 /// # Examples
429 ///
430 /// ```rust
431 /// # use morok_ir::UOp;
432 /// # use morok_dtype::{DType, DeviceSpec};
433 /// let buffer = UOp::new_buffer(DeviceSpec::Cpu, 10, DType::Float32);
434 /// assert_eq!(buffer.device_spec(), Some(DeviceSpec::Cpu));
435 /// ```
436 pub fn device_spec(&self) -> Option<morok_dtype::DeviceSpec> {
437 match self.op() {
438 Op::Device(spec) => Some(spec.clone()),
439 Op::Buffer { device, .. } => {
440 if let Op::Device(spec) = device.op() {
441 Some(spec.clone())
442 } else {
443 None
444 }
445 }
446 Op::Param { device: Some(device), .. } => {
447 if let Op::Device(spec) = device.op() {
448 Some(spec.clone())
449 } else {
450 None
451 }
452 }
453 Op::Param { device: None, .. } => None,
454 Op::Copy { device, .. } => {
455 if let Op::Device(spec) = device.op() {
456 Some(spec.clone())
457 } else {
458 None
459 }
460 }
461 _ => {
462 // Search children for device
463 for child in self.op().children() {
464 if let Some(spec) = child.device_spec() {
465 return Some(spec);
466 }
467 }
468 None
469 }
470 }
471 }
472
473 /// Get the base UOp by walking through movement operations.
474 ///
475 /// Movement operations (RESHAPE, PERMUTE, EXPAND, etc.) are views that don't
476 /// change the underlying data. This method recursively walks through these
477 /// operations to find the actual buffer or computation that owns the data.
478 ///
479 /// Based on Tinygrad's `base` property (ops.py:524-527).
480 ///
481 /// # Examples
482 ///
483 /// ```rust
484 /// # use morok_ir::{UOp, SInt, shape::Shape};
485 /// # use morok_dtype::DType;
486 /// # use morok_dtype::DeviceSpec;
487 /// let buffer = UOp::new_buffer(DeviceSpec::Cpu, 10, DType::Float32);
488 /// let shape = Shape::from_iter([SInt::Const(2), SInt::Const(5)]);
489 /// let reshaped = buffer.try_reshape(&shape).unwrap();
490 ///
491 /// // base() walks through RESHAPE to get the original BUFFER
492 /// assert!(std::sync::Arc::ptr_eq(&reshaped.base(), &buffer));
493 /// ```
494 pub fn base(self: &Arc<Self>) -> Arc<Self> {
495 match &self.op {
496 // Movement operations - recursively get base of source
497 Op::Reshape { src, .. }
498 | Op::Permute { src, .. }
499 | Op::Expand { src, .. }
500 | Op::Pad { src, .. }
501 | Op::Shrink { src, .. }
502 | Op::Flip { src, .. }
503 | Op::Multi { src, .. } => src.base(),
504 // All other operations are their own base
505 _ => self.clone(),
506 }
507 }
508
509 /// Get the underlying buffer UOp, walking through AFTER/MSELECT/MSTACK chains.
510 ///
511 /// Based on Tinygrad's `buf_uop` property (ops.py:601-606).
512 /// This recursively unwraps AFTER chains to find the actual buffer.
513 ///
514 /// # Examples
515 ///
516 /// ```ignore
517 /// use morok_ir::UOp;
518 ///
519 /// // AFTER wrapping a buffer
520 /// let buffer = UOp::new_buffer(...);
521 /// let after = buffer.after(deps);
522 ///
523 /// // buf_uop() walks through AFTER to get the underlying buffer
524 /// assert!(Arc::ptr_eq(&after.buf_uop(), &buffer));
525 /// ```
526 pub fn buf_uop(self: &Arc<Self>) -> Arc<Self> {
527 match self.op() {
528 Op::Buffer { .. } => self.clone(),
529 Op::MSelect { buffer, .. } => buffer.buf_uop(),
530 Op::MStack { buffers } if !buffers.is_empty() => buffers[0].buf_uop(),
531 Op::After { passthrough, .. } => passthrough.buf_uop(),
532 _ => {
533 // For other ops, check if base is AFTER
534 let base = self.base();
535 if matches!(base.op(), Op::After { .. }) { base.buf_uop() } else { self.clone() }
536 }
537 }
538 }
539
540 /// Topological sort of the computation graph.
541 ///
542 /// Returns nodes in an order where all dependencies come before their dependents.
543 pub fn toposort(self: &Arc<Self>) -> Vec<Arc<Self>> {
544 let mut visited = HashSet::new();
545 let mut result = Vec::new();
546 let mut stack = vec![(self.clone(), false)];
547
548 while let Some((node, processed)) = stack.pop() {
549 let ptr = Arc::as_ptr(&node);
550
551 if visited.contains(&ptr) {
552 continue;
553 }
554
555 if processed {
556 visited.insert(ptr);
557 result.push(node);
558 } else {
559 stack.push((node.clone(), true));
560
561 // Use for_each_child for zero-allocation traversal
562 let mut children = Vec::new();
563 node.op.map_child(|child| {
564 if !visited.contains(&Arc::as_ptr(child)) {
565 children.push(child.clone());
566 }
567 });
568
569 // Push in reverse order for proper traversal
570 for child in children.into_iter().rev() {
571 stack.push((child, false));
572 }
573 }
574 }
575
576 result
577 }
578
579 /// Topological sort with gate function (filtered toposort).
580 ///
581 /// Only traverses nodes for which `gate(node)` returns true.
582 /// Nodes for which gate returns false are excluded from the
583 /// traversal entirely (along with their ancestors).
584 ///
585 /// This is a key optimization for cached property computation,
586 /// allowing us to skip nodes that already have a property cached.
587 ///
588 /// # Performance
589 ///
590 /// For a graph with 10,000 nodes where 9,900 already have a cached property:
591 /// - **Full toposort**: 10,000 nodes visited
592 /// - **Filtered toposort**: 100 nodes visited
593 /// - **Speedup**: 100x
594 ///
595 /// # Example
596 ///
597 /// ```ignore
598 /// // Only process nodes that don't have shape cached
599 /// let uncached = uop.toposort_filtered(|node| {
600 /// node.shape_cache.get().is_none()
601 /// });
602 /// ```
603 pub fn toposort_filtered<F>(self: &Arc<Self>, gate: F) -> Vec<Arc<Self>>
604 where
605 F: Fn(&Arc<UOp>) -> bool,
606 {
607 let mut visited = HashSet::new();
608 let mut result = Vec::new();
609 let mut stack = vec![(self.clone(), false)];
610
611 while let Some((node, processed)) = stack.pop() {
612 let ptr = Arc::as_ptr(&node);
613
614 if visited.contains(&ptr) {
615 continue;
616 }
617
618 if processed {
619 visited.insert(ptr);
620 result.push(node);
621 } else {
622 // Key optimization: only traverse nodes that pass the gate
623 if gate(&node) {
624 stack.push((node.clone(), true));
625
626 let mut children = Vec::new();
627 node.op.map_child(|child| {
628 if !visited.contains(&Arc::as_ptr(child)) {
629 children.push(child.clone());
630 }
631 });
632
633 // Push in reverse order for proper traversal
634 for child in children.into_iter().rev() {
635 stack.push((child, false));
636 }
637 }
638 }
639 }
640
641 result
642 }
643
644 /// Check if any node in the backward slice satisfies a predicate.
645 ///
646 /// Early-exit DFS — returns `true` as soon as a matching node is found,
647 /// without building the full toposort Vec. Use this instead of
648 /// `toposort().iter().any(pred)` when you only need an existential check.
649 pub fn any_in_subtree<F>(self: &Arc<Self>, pred: F) -> bool
650 where
651 F: Fn(&Arc<UOp>) -> bool,
652 {
653 let mut visited = HashSet::new();
654 let mut stack = vec![self.clone()];
655 while let Some(node) = stack.pop() {
656 if !visited.insert(Arc::as_ptr(&node)) {
657 continue;
658 }
659 if pred(&node) {
660 return true;
661 }
662 node.op.map_child(|child| {
663 if !visited.contains(&Arc::as_ptr(child)) {
664 stack.push(child.clone());
665 }
666 });
667 }
668 false
669 }
670
671 /// Collect all nodes in the backward slice that match a predicate.
672 ///
673 /// DFS collecting matches — cheaper than `toposort().iter().filter(pred).collect()`
674 /// when you don't need topological ordering.
675 pub fn collect_in_subtree<F>(self: &Arc<Self>, pred: F) -> Vec<Arc<UOp>>
676 where
677 F: Fn(&Arc<UOp>) -> bool,
678 {
679 let mut visited = HashSet::new();
680 let mut stack = vec![self.clone()];
681 let mut result = Vec::new();
682 while let Some(node) = stack.pop() {
683 if !visited.insert(Arc::as_ptr(&node)) {
684 continue;
685 }
686 if pred(&node) {
687 result.push(node.clone());
688 }
689 node.op.map_child(|child| {
690 if !visited.contains(&Arc::as_ptr(child)) {
691 stack.push(child.clone());
692 }
693 });
694 }
695 result
696 }
697
698 /// Count unique nodes in the DAG rooted at this UOp.
699 ///
700 /// Much cheaper than `toposort().len()` — no result Vec, no ordering.
701 /// Uses pointer-based visited set for O(1) identity checks.
702 pub fn node_count(self: &Arc<Self>) -> usize {
703 let mut visited = HashSet::new();
704 let mut stack = vec![self.clone()];
705 while let Some(node) = stack.pop() {
706 if !visited.insert(Arc::as_ptr(&node)) {
707 continue;
708 }
709 node.op.map_child(|child| {
710 if !visited.contains(&Arc::as_ptr(child)) {
711 stack.push(child.clone());
712 }
713 });
714 }
715 visited.len()
716 }
717
718 /// O(1) cached check: does this node or any of its sources contain an INDEX op?
719 ///
720 /// Computed lazily and cached. Each node checks itself and its direct sources'
721 /// cached values, so the total cost across the graph is O(N).
722 pub fn has_index_in_sources(self: &Arc<Self>) -> bool {
723 *self.has_index_in_sources_cache.get_or_init(|| {
724 if matches!(self.op, Op::Index { .. }) {
725 return true;
726 }
727 let mut result = false;
728 self.op.map_child(|child| {
729 if child.has_index_in_sources() {
730 result = true;
731 }
732 });
733 result
734 })
735 }
736
737 /// Render this UOp and its sources as a compact ASCII tree.
738 ///
739 /// Shared nodes (appearing multiple times due to hash-consing) are shown
740 /// as back-references: `[id] → (see above)`
741 ///
742 /// # Example Output
743 ///
744 /// ```text
745 /// [42] STORE : Void
746 /// ├── [10] PARAM(0) : Ptr<Float32> shape=[4]
747 /// ├── [35] INDEX : Ptr<Float32> shape=[4]
748 /// │ ├── [10] → (see above)
749 /// │ └── [30] RANGE(0, Reduce) : Index
750 /// │ └── [5] CONST(Int(4)) : Index
751 /// └── [40] REDUCE(Add) : Float32 shape=[]
752 /// └── [35] → (see above)
753 /// ```
754 pub fn tree(self: &Arc<Self>) -> String {
755 crate::uop::tree::render_tree_compact(self)
756 }
757
758 /// Render this UOp and its sources as a full ASCII tree.
759 ///
760 /// Shared nodes are expanded every time they appear (verbose but complete).
761 /// Use this when you need to see the full subtree at every occurrence.
762 pub fn tree_full(self: &Arc<Self>) -> String {
763 crate::uop::tree::render_tree_full(self)
764 }
765
766 /// Get all RANGE operations in this UOp's computation graph.
767 ///
768 /// Lazily computed and cached. Useful for rangeify pass to track loop variables.
769 pub fn ranges(self: &Arc<Self>) -> &Vec<Arc<Self>> {
770 use crate::uop::cached_property::CachedProperty;
771 use crate::uop::properties::RangesProperty;
772 RangesProperty::get(self)
773 }
774
775 /// Get the RANGE operations that are in scope at this UOp.
776 ///
777 /// Returns only the ranges that are currently "active" (not yet ended).
778 /// This is computed by:
779 /// 1. Merging ranges from all source operations
780 /// 2. Removing ranges that are ended by this operation
781 /// 3. Adding self if this is a RANGE operation
782 ///
783 /// Based on Tinygrad's `ranges` property (ops.py:318-320) and
784 /// `_ranges` recursive property (ops.py:302-315).
785 ///
786 /// # Returns
787 ///
788 /// A HashSet of RANGE UOps that are in scope at this point in the graph.
789 /// The result is cached for performance.
790 ///
791 /// # Examples
792 ///
793 /// ```ignore
794 /// use morok_ir::{UOp, AxisType};
795 ///
796 /// // A simple computation inside a range
797 /// let range = UOp::range(end, 0, AxisType::Loop);
798 /// let value = UOp::const_(...);
799 /// let end_op = value.end(vec![range.clone()]);
800 ///
801 /// // Value has range in scope
802 /// assert!(value.in_scope_ranges().contains(&range));
803 ///
804 /// // After END, range is no longer in scope
805 /// assert!(!end_op.in_scope_ranges().contains(&range));
806 /// ```
807 #[allow(clippy::mutable_key_type)]
808 pub fn in_scope_ranges(self: &Arc<Self>) -> &HashSet<UOpKey> {
809 use crate::uop::cached_property::CachedProperty;
810 use crate::uop::properties::InScopeRangesProperty;
811 InScopeRangesProperty::get(self)
812 }
813
814 /// Check if all in-scope ranges at this UOp have the given AxisType.
815 ///
816 /// Returns true if the in-scope ranges set is empty or all ranges
817 /// match the specified axis type.
818 ///
819 /// # Use Cases
820 ///
821 /// - `all_in_scope_ranges_are(AxisType::Outer)` - Used in split_store
822 /// to determine if we're at a kernel boundary
823 ///
824 /// # Examples
825 ///
826 /// ```ignore
827 /// use morok_ir::{UOp, AxisType};
828 ///
829 /// // At kernel boundary: only OUTER ranges in scope
830 /// assert!(uop.all_in_scope_ranges_are(AxisType::Outer));
831 ///
832 /// // Inside kernel: has non-OUTER ranges
833 /// assert!(!uop.all_in_scope_ranges_are(AxisType::Outer));
834 /// ```
835 #[allow(clippy::mutable_key_type)]
836 pub fn all_in_scope_ranges_are(self: &Arc<Self>, axis_type: AxisType) -> bool {
837 use crate::Op;
838
839 let ranges = self.in_scope_ranges();
840
841 // Empty scope means we're at the top level (treat as all OUTER)
842 if ranges.is_empty() {
843 return true;
844 }
845
846 ranges.iter().all(|r| match r.0.op() {
847 Op::Range { axis_type: at, .. } => *at == axis_type,
848 _ => false, // Should never happen
849 })
850 }
851
852 /// Check if any in-scope range is NOT of the given AxisType.
853 ///
854 /// Inverse of `all_in_scope_ranges_are`. Useful for Tinygrad-style
855 /// filtering: "skip if any range is not OUTER".
856 ///
857 /// # Examples
858 ///
859 /// ```ignore
860 /// use morok_ir::{UOp, AxisType};
861 ///
862 /// // Has non-OUTER ranges: should skip in split_store
863 /// if uop.has_non_outer_ranges() {
864 /// return None; // Don't split here
865 /// }
866 /// ```
867 pub fn has_non_outer_ranges(self: &Arc<Self>) -> bool {
868 !self.all_in_scope_ranges_are(AxisType::Outer)
869 }
870
871 /// Build a consumer map for this UOp's computation graph.
872 ///
873 /// Returns a HashMap where each UOp maps to the list of UOps that consume it.
874 /// Useful for reverse traversal and dependency analysis.
875 #[allow(clippy::mutable_key_type)]
876 pub fn get_consumer_map(self: &Arc<Self>) -> HashMap<UOpKey, Vec<Arc<Self>>> {
877 let mut consumer_map: HashMap<UOpKey, Vec<Arc<Self>>> = HashMap::new();
878
879 for node in self.toposort() {
880 node.op.map_child(|child| {
881 consumer_map.entry(UOpKey(child.clone())).or_default().push(node.clone());
882 });
883 }
884
885 consumer_map
886 }
887
888 /// Reverse topological sort of the computation graph.
889 ///
890 /// Returns nodes in bottom-up order (leaves first, root last).
891 /// Requires a consumer map to traverse from leaves to roots.
892 #[allow(clippy::mutable_key_type)]
893 pub fn reverse_toposort(self: &Arc<Self>, consumer_map: &HashMap<UOpKey, Vec<Arc<Self>>>) -> Vec<Arc<Self>> {
894 let mut visited = HashMap::new(); // Use HashMap to track visited by ID
895 let mut result = Vec::new();
896 let mut stack = vec![(self.clone(), false)];
897
898 while let Some((node, processed)) = stack.pop() {
899 if visited.contains_key(&node.id) {
900 continue;
901 }
902
903 if processed {
904 visited.insert(node.id, ());
905 result.push(node);
906 } else {
907 stack.push((node.clone(), true));
908
909 // Visit consumers (nodes that depend on this node)
910 if let Some(consumers) = consumer_map.get(&UOpKey(node.clone())) {
911 for consumer in consumers {
912 if !visited.contains_key(&consumer.id) {
913 stack.push((consumer.clone(), false));
914 }
915 }
916 }
917 }
918 }
919
920 result
921 }
922
923 /// Replace UOps in the computation graph according to a substitution map.
924 ///
925 /// Delegates to `graph_rewrite_bottom_up` with a wildcard pattern that looks up
926 /// each node in the map — exactly like Tinygrad's `substitute`. The rewrite engine
927 /// provides O(n) memoization via its result cache.
928 #[allow(clippy::mutable_key_type)]
929 pub fn substitute(self: &Arc<Self>, map: &HashMap<UOpKey, Arc<Self>>) -> Arc<Self> {
930 if map.is_empty() {
931 return self.clone();
932 }
933 let matcher = SubstituteMatcher(map);
934 crate::rewrite::graph_rewrite_bottom_up(&matcher, self.clone(), &mut ())
935 }
936
937 /// Replace UOps with range-gated substitution (Tinygrad: `extra_pm=pm_gate_substitute`).
938 ///
939 /// Like `substitute`, but skips subtrees whose `in_scope_ranges()` don't contain
940 /// any of the substitution keys. This prevents substituting ranges in subexpressions
941 /// that don't reference them, matching Tinygrad's `gate_substitute` behavior.
942 #[allow(clippy::mutable_key_type)]
943 pub fn substitute_gated(self: &Arc<Self>, map: &HashMap<UOpKey, Arc<Self>>) -> Arc<Self> {
944 if map.is_empty() {
945 return self.clone();
946 }
947 let range_keys: HashSet<UOpKey> = map.keys().cloned().collect();
948 let matcher = SubstituteGatedMatcher { map, range_keys: &range_keys };
949 crate::rewrite::graph_rewrite_bottom_up(&matcher, self.clone(), &mut ())
950 }
951
952 /// Reconstruct this UOp with new sources.
953 ///
954 /// Creates a new UOp with the same operation type and dtype, but with the provided
955 /// sources replacing the original ones. Hash consing ensures that if an identical
956 /// UOp already exists, it will be reused.
957 ///
958 /// This is used by the graph rewrite engine when sources have been rewritten.
959 ///
960 /// # Panics
961 ///
962 /// Panics if the number of sources doesn't match the operation's arity.
963 ///
964 /// # Examples
965 ///
966 /// ```ignore
967 /// // Original: a + b
968 /// let add = UOp::add(a.clone(), b.clone());
969 ///
970 /// // Rewrite sources: a' + b'
971 /// let new_add = add.with_sources(vec![a_prime, b_prime]);
972 /// ```
973 pub fn with_sources(self: &Arc<Self>, new_srcs: Vec<Arc<Self>>) -> Arc<Self> {
974 use smallvec::SmallVec;
975
976 // Helper to get nth source
977 let src = |n: usize| new_srcs[n].clone();
978
979 let new_op = match &self.op {
980 // Nullary operations - no sources
981 Op::Const(_)
982 | Op::Unique(_)
983 | Op::Device(_)
984 | Op::Noop
985 | Op::Invalid
986 | Op::DefineLocal(_)
987 | Op::VConst { .. }
988 | Op::DefineVar { .. }
989 | Op::DefineReg { .. } => {
990 assert_eq!(new_srcs.len(), 0, "Nullary op should have no sources");
991 return self.clone(); // No sources to replace
992 }
993
994 // Unary operations
995 Op::Unary(op_type, _) => {
996 assert_eq!(new_srcs.len(), 1);
997 Op::Unary(*op_type, src(0))
998 }
999
1000 // Binary operations
1001 Op::Binary(op_type, _, _) => {
1002 assert_eq!(new_srcs.len(), 2);
1003 Op::Binary(*op_type, src(0), src(1))
1004 }
1005
1006 // Ternary operations
1007 Op::Ternary(op_type, _, _, _) => {
1008 assert_eq!(new_srcs.len(), 3);
1009 Op::Ternary(*op_type, src(0), src(1), src(2))
1010 }
1011
1012 // Type operations
1013 Op::Cast { dtype, .. } => {
1014 assert_eq!(new_srcs.len(), 1);
1015 Op::Cast { src: src(0), dtype: dtype.clone() }
1016 }
1017 Op::BitCast { dtype, .. } => {
1018 assert_eq!(new_srcs.len(), 1);
1019 Op::BitCast { src: src(0), dtype: dtype.clone() }
1020 }
1021
1022 // Special operations
1023 Op::MSelect { device_index, .. } => {
1024 assert_eq!(new_srcs.len(), 1);
1025 Op::MSelect { buffer: src(0), device_index: *device_index }
1026 }
1027 Op::Special { name, .. } => {
1028 assert_eq!(new_srcs.len(), 1);
1029 Op::Special { end: src(0), name: name.clone() }
1030 }
1031
1032 // Buffer operations
1033 Op::Buffer { size, .. } => {
1034 assert_eq!(new_srcs.len(), 2);
1035 Op::Buffer { unique: src(0), device: src(1), size: *size }
1036 }
1037 Op::Param { slot, size, device } => {
1038 if device.is_some() {
1039 assert_eq!(new_srcs.len(), 1);
1040 Op::Param { slot: *slot, size: *size, device: Some(src(0)) }
1041 } else {
1042 assert_eq!(new_srcs.len(), 0);
1043 return self.clone();
1044 }
1045 }
1046 Op::BufferView { size, offset, .. } => {
1047 assert_eq!(new_srcs.len(), 1);
1048 Op::BufferView { buffer: src(0), size: *size, offset: *offset }
1049 }
1050 Op::Bufferize { opts, .. } => {
1051 assert!(!new_srcs.is_empty());
1052 Op::Bufferize { compute: src(0), ranges: new_srcs[1..].iter().cloned().collect(), opts: opts.clone() }
1053 }
1054 Op::Index { gate, .. } => {
1055 assert!(!new_srcs.is_empty());
1056 // First source is buffer, rest are indices, last might be gate
1057 let buffer = src(0);
1058 let (indices, gate_new) = if gate.is_some() && new_srcs.len() >= 2 {
1059 let gate_src = new_srcs.last().unwrap().clone();
1060 let indices: SmallVec<[Arc<Self>; 4]> = new_srcs[1..new_srcs.len() - 1].iter().cloned().collect();
1061 (indices, Some(gate_src))
1062 } else {
1063 let indices: SmallVec<[Arc<Self>; 4]> = new_srcs[1..].iter().cloned().collect();
1064 (indices, None)
1065 };
1066 Op::Index { buffer, indices, gate: gate_new }
1067 }
1068 Op::PointerIndex { .. } => {
1069 assert_eq!(new_srcs.len(), 2);
1070 Op::PointerIndex { ptr: src(0), offset: src(1) }
1071 }
1072 Op::Copy { .. } => {
1073 assert_eq!(new_srcs.len(), 2);
1074 Op::Copy { src: src(0), device: src(1) }
1075 }
1076 Op::MStack { .. } => Op::MStack { buffers: new_srcs.iter().cloned().collect() },
1077
1078 // Movement operations
1079 Op::Reshape { .. } => {
1080 assert_eq!(new_srcs.len(), 2);
1081 Op::Reshape { src: src(0), new_shape: src(1) }
1082 }
1083 Op::Permute { axes, .. } => {
1084 assert_eq!(new_srcs.len(), 1);
1085 Op::Permute { src: src(0), axes: axes.clone() }
1086 }
1087 Op::Expand { .. } => {
1088 assert_eq!(new_srcs.len(), 2);
1089 Op::Expand { src: src(0), new_shape: src(1) }
1090 }
1091 Op::Pad { .. } => {
1092 assert_eq!(new_srcs.len(), 3);
1093 Op::Pad { src: src(0), begin_pads: src(1), end_pads: src(2) }
1094 }
1095 Op::Shrink { .. } => {
1096 assert_eq!(new_srcs.len(), 3);
1097 Op::Shrink { src: src(0), begins: src(1), ends: src(2) }
1098 }
1099 Op::Flip { axes, .. } => {
1100 assert_eq!(new_srcs.len(), 1);
1101 Op::Flip { src: src(0), axes: axes.clone() }
1102 }
1103 Op::Multi { axis, .. } => {
1104 assert_eq!(new_srcs.len(), 1);
1105 Op::Multi { src: src(0), axis: *axis }
1106 }
1107
1108 // Reduction operations
1109 Op::ReduceAxis { reduce_op, axes, .. } => {
1110 assert_eq!(new_srcs.len(), 1);
1111 Op::ReduceAxis { src: src(0), reduce_op: *reduce_op, axes: axes.clone() }
1112 }
1113 Op::Reduce { reduce_op, .. } => {
1114 assert!(!new_srcs.is_empty());
1115 Op::Reduce { src: src(0), ranges: new_srcs[1..].iter().cloned().collect(), reduce_op: *reduce_op }
1116 }
1117 Op::AllReduce { reduce_op, .. } => {
1118 assert_eq!(new_srcs.len(), 2);
1119 Op::AllReduce { src: src(0), device: src(1), reduce_op: *reduce_op }
1120 }
1121
1122 // Control flow operations
1123 Op::If { .. } => {
1124 assert!(!new_srcs.is_empty());
1125 Op::If { condition: src(0), body: new_srcs[1..].iter().cloned().collect() }
1126 }
1127 Op::EndIf { .. } => {
1128 assert_eq!(new_srcs.len(), 1);
1129 Op::EndIf { if_op: src(0) }
1130 }
1131 Op::Range { axis_id, axis_type, .. } => {
1132 assert!(!new_srcs.is_empty());
1133 Op::Range {
1134 end: src(0),
1135 axis_id: *axis_id,
1136 axis_type: *axis_type,
1137 deps: new_srcs[1..].iter().cloned().collect(),
1138 }
1139 }
1140 Op::End { .. } => {
1141 assert!(!new_srcs.is_empty());
1142 Op::End { computation: src(0), ranges: new_srcs[1..].iter().cloned().collect() }
1143 }
1144 Op::Barrier { .. } => {
1145 assert!(!new_srcs.is_empty());
1146 Op::Barrier { src: src(0), deps: new_srcs[1..].iter().cloned().collect() }
1147 }
1148
1149 // Vector operations — recompute dtype from new elements when element
1150 // dtype category changed (e.g. Scalar → Ptr during rewrite reconstruction).
1151 // Preserving old dtype is wrong when DEFINE_LOCAL → AFTER(Ptr) changes
1152 // element types from Scalar to Ptr, causing pm_add_loads infinite loops.
1153 Op::Vectorize { .. } => {
1154 let elements: SmallVec<[Arc<Self>; 4]> = new_srcs.iter().cloned().collect();
1155 let elem_dtype = elements[0].dtype();
1156 let new_dtype = match elem_dtype {
1157 DType::Scalar(_) | DType::Ptr { .. } => elem_dtype.vec(elements.len()),
1158 _ => self.dtype.clone(),
1159 };
1160 return Self::new(Op::Vectorize { elements }, new_dtype);
1161 }
1162 Op::Gep { indices, .. } => {
1163 assert_eq!(new_srcs.len(), 1);
1164 Op::Gep { vector: src(0), indices: indices.clone() }
1165 }
1166 Op::Cat { .. } => Op::Cat { sources: new_srcs.iter().cloned().collect() },
1167 Op::PtrCat { .. } => Op::PtrCat { sources: new_srcs.iter().cloned().collect() },
1168
1169 // Symbolic/Define operations
1170 Op::Bind { .. } => {
1171 assert_eq!(new_srcs.len(), 2);
1172 Op::Bind { var: src(0), value: src(1) }
1173 }
1174
1175 // Advanced operations
1176 Op::Wmma { metadata, .. } => {
1177 assert_eq!(new_srcs.len(), 3);
1178 Op::Wmma { a: src(0), b: src(1), c: src(2), metadata: metadata.clone() }
1179 }
1180 Op::Contract { upcast_ranges, .. } => {
1181 assert_eq!(new_srcs.len(), 1);
1182 Op::Contract { src: src(0), upcast_ranges: upcast_ranges.clone() }
1183 }
1184 Op::Unroll { unroll_axes, .. } => {
1185 assert_eq!(new_srcs.len(), 1);
1186 Op::Unroll { src: src(0), unroll_axes: unroll_axes.clone() }
1187 }
1188 Op::Kernel { .. } => {
1189 assert!(!new_srcs.is_empty());
1190 Op::Kernel {
1191 sources: new_srcs[..new_srcs.len() - 1].iter().cloned().collect(),
1192 ast: src(new_srcs.len() - 1),
1193 }
1194 }
1195 Op::Assign { .. } => {
1196 assert!(new_srcs.len() >= 2 && new_srcs.len() <= 3, "Assign requires 2-3 sources");
1197 Op::Assign {
1198 target: src(0),
1199 value: src(1),
1200 movement_ops: if new_srcs.len() > 2 { Some(src(2)) } else { None },
1201 }
1202 }
1203 Op::Detach { .. } => {
1204 assert_eq!(new_srcs.len(), 1);
1205 Op::Detach { src: src(0) }
1206 }
1207 Op::Contiguous { opts, .. } => {
1208 assert_eq!(new_srcs.len(), 1);
1209 Op::Contiguous { src: src(0), opts: opts.clone() }
1210 }
1211 Op::ContiguousBackward { .. } => {
1212 assert_eq!(new_srcs.len(), 1);
1213 Op::ContiguousBackward { src: src(0) }
1214 }
1215 Op::After { .. } => {
1216 assert!(!new_srcs.is_empty());
1217 let passthrough = src(0);
1218 // Validate: AFTER passthrough must not be control flow (Tinygrad semantics)
1219 debug_assert!(
1220 !matches!(passthrough.op(), Op::Range { .. } | Op::End { .. }),
1221 "reconstruct_sources: AFTER passthrough is {:?} (id={}), violates Tinygrad semantics",
1222 passthrough.op(),
1223 passthrough.id
1224 );
1225 Op::After { passthrough, deps: new_srcs[1..].iter().cloned().collect() }
1226 }
1227 Op::Precast { .. } => {
1228 assert_eq!(new_srcs.len(), 1);
1229 Op::Precast { src: src(0) }
1230 }
1231 Op::Custom { code, .. } => Op::Custom { deps: new_srcs.iter().cloned().collect(), code: code.clone() },
1232 Op::CustomI { code, .. } => Op::CustomI { deps: new_srcs.iter().cloned().collect(), code: code.clone() },
1233
1234 // Memory operations
1235 Op::Load { alt, .. } => {
1236 // Load has 2-3 sources: buffer, index, and optionally alt
1237 assert!(new_srcs.len() >= 2 && new_srcs.len() <= 3, "Load requires 2-3 sources");
1238 let new_alt = if new_srcs.len() == 3 { Some(src(2)) } else { alt.clone() };
1239 Op::Load { buffer: src(0), index: src(1), alt: new_alt }
1240 }
1241 Op::Store { .. } => {
1242 assert!(new_srcs.len() >= 2, "Store requires at least 2 sources (index, value)");
1243 Op::Store { index: src(0), value: src(1), ranges: new_srcs[2..].iter().cloned().collect() }
1244 }
1245
1246 // Graph organization
1247 Op::Sink { .. } => Op::Sink { sources: new_srcs.iter().cloned().collect() },
1248 Op::Group { .. } => Op::Group { sources: new_srcs.iter().cloned().collect() },
1249 };
1250
1251 // Preserve original dtype and tag (Tinygrad ops.py:1256: preserves tag through source reconstruction)
1252 Self::new_tagged(new_op, self.dtype.clone(), self.tag.clone())
1253 }
1254}
1255
1256#[bon]
1257impl UOp {
1258 /// Create a modified copy with optional field overrides.
1259 ///
1260 /// Enables concise pattern implementations by allowing selective field modification.
1261 /// Returns `self.clone()` if nothing changed (optimization for hash consing).
1262 ///
1263 /// # Examples
1264 ///
1265 /// ```ignore
1266 /// let new_load = load.replace().dtype(new_dtype).src(new_sources).call();
1267 /// let dtype_only = load.replace().dtype(new_dtype).call();
1268 /// ```
1269 #[builder]
1270 pub fn replace(self: &Arc<Self>, dtype: Option<DType>, src: Option<Vec<Arc<Self>>>) -> Arc<Self> {
1271 let new_dtype = dtype.unwrap_or_else(|| self.dtype());
1272 let new_sources = src.unwrap_or_else(|| self.op().sources().to_vec());
1273
1274 // Short-circuit if nothing changed
1275 let old_sources = self.op().sources();
1276 let sources_unchanged = new_sources.len() == old_sources.len()
1277 && new_sources.iter().zip(old_sources.iter()).all(|(a, b)| Arc::ptr_eq(a, b));
1278
1279 if new_dtype == self.dtype() && sources_unchanged {
1280 return self.clone();
1281 }
1282
1283 self.with_sources(new_sources).with_dtype(new_dtype)
1284 }
1285}
1286
1287impl Clone for UOp {
1288 fn clone(&self) -> Self {
1289 Self {
1290 id: self.id,
1291 op: self.op.clone(),
1292 dtype: self.dtype.clone(),
1293 content_hash: self.content_hash,
1294 tag: self.tag.clone(),
1295 shape_cache: std::sync::OnceLock::new(),
1296 ranges_cache: std::sync::OnceLock::new(),
1297 in_scope_ranges_cache: std::sync::OnceLock::new(),
1298 vmin_vmax_cache: std::sync::OnceLock::new(),
1299 sound_vmin_vmax_cache: std::sync::OnceLock::new(),
1300 has_index_in_sources_cache: std::sync::OnceLock::new(),
1301 backward_slice_cache: std::sync::OnceLock::new(),
1302 metadata: self.metadata.clone(),
1303 }
1304 }
1305}
1306
1307/// Trait for converting scalar values into UOps.
1308///
1309/// This allows operator overloading to work with mixed scalar/UOp operands.
1310/// For example: `uop + 5.0` or `5.0 + uop`.
1311pub trait IntoUOp {
1312 fn into_uop(self, dtype: DType) -> Arc<UOp>;
1313}
1314
1315impl IntoUOp for ConstValue {
1316 fn into_uop(self, dtype: DType) -> Arc<UOp> {
1317 UOp::const_(dtype, self)
1318 }
1319}
1320
1321impl IntoUOp for f32 {
1322 fn into_uop(self, dtype: DType) -> Arc<UOp> {
1323 UOp::const_(dtype, ConstValue::Float(self as f64))
1324 }
1325}
1326
1327impl IntoUOp for f64 {
1328 fn into_uop(self, dtype: DType) -> Arc<UOp> {
1329 UOp::const_(dtype, ConstValue::Float(self))
1330 }
1331}
1332
1333impl IntoUOp for i32 {
1334 fn into_uop(self, dtype: DType) -> Arc<UOp> {
1335 UOp::const_(dtype, ConstValue::Int(self as i64))
1336 }
1337}
1338
1339impl IntoUOp for i64 {
1340 fn into_uop(self, dtype: DType) -> Arc<UOp> {
1341 UOp::const_(dtype, ConstValue::Int(self))
1342 }
1343}
1344
1345impl IntoUOp for u32 {
1346 fn into_uop(self, dtype: DType) -> Arc<UOp> {
1347 UOp::const_(dtype, ConstValue::UInt(self as u64))
1348 }
1349}
1350
1351impl IntoUOp for u64 {
1352 fn into_uop(self, dtype: DType) -> Arc<UOp> {
1353 UOp::const_(dtype, ConstValue::UInt(self))
1354 }
1355}
1356
1357impl IntoUOp for bool {
1358 fn into_uop(self, dtype: DType) -> Arc<UOp> {
1359 UOp::const_(dtype, ConstValue::Bool(self))
1360 }
1361}