Skip to main content

morok_ir/uop/
hash_consing.rs

1//! Hash consing infrastructure for UOp deduplication.
2//!
3//! This module implements the caching system that ensures structurally identical
4//! UOps share the same memory allocation (hash consing).
5//!
6//! # Thread Safety
7//!
8//! Uses a global lock-free concurrent HashMap (papaya) for cross-thread deduplication.
9//! Creating the same UOp in different threads returns the same `Arc<UOp>`, so
10//! `Arc::ptr_eq` works correctly across thread boundaries.
11//!
12//! # Memory Management (Tinygrad-aligned)
13//!
//! UOps are stored as `Weak<UOp>` references in the cache, so the cache itself never
//! keeps a UOp alive: once the last strong reference is dropped, the UOp is dropped.
//! The now-dead cache entry is overwritten the next time the same key is constructed,
//! or removed in bulk by `gc_dead_refs()`; until then it still occupies cache memory.
17//!
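//! Like Tinygrad's UOp cache, values (not keys) are held weakly: the mapping is
//! `UOpKey -> Weak<UOp>`. User code never needs to call a cleanup function for
//! correctness; `gc_dead_refs()` only reclaims memory held by dead entries.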
20
21use std::hash::{Hash, Hasher};
22use std::mem::discriminant;
23use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
24use std::sync::{Arc, OnceLock, Weak};
25
26use papaya::HashMap;
27use smallvec::SmallVec;
28
29use crate::op::Op;
30use crate::types::*;
31use crate::uop::core::UOp;
32use morok_dtype::DType;
33use morok_dtype::DeviceSpec;
34
/// Hand out a process-wide unique `usize`.
///
/// Backed by a function-local atomic counter, so every call (from any thread)
/// returns a value no earlier call has returned, until the counter wraps.
/// `Relaxed` ordering suffices because callers rely only on uniqueness, not on
/// ordering relative to other memory operations.
pub(crate) fn next_unique_id() -> usize {
    static NEXT: AtomicUsize = AtomicUsize::new(0);
    NEXT.fetch_add(1, Ordering::Relaxed)
}

/// Hand out a stable `u64` identifier for a freshly allocated UOp.
///
/// The counter is 64 bits wide and only ever increments, so an id is never
/// handed out twice in practice. Unlike a pointer address, which the allocator
/// may recycle, an id can never be confused with one that belonged to a UOp
/// that has since been freed (no ABA problem).
pub(crate) fn next_uop_id() -> u64 {
    static NEXT: AtomicU64 = AtomicU64::new(0);
    NEXT.fetch_add(1, Ordering::Relaxed)
}
54
55/// Cache key for hash consing.
56///
57/// Uses stable UOp IDs for child UOps to avoid infinite recursion during hashing.
58/// IDs are monotonic and never reused, eliminating ABA problem from pointer-based approach.
59///
60/// Performance: hash is pre-computed during construction and cached in `cached_hash`.
61/// This avoids re-hashing on every HashMap lookup (the previous bottleneck: 57% of CPU
62/// in xxhash). Follows Tinygrad's approach where UOp hash is `id()`-based (~nanoseconds).
63#[derive(Clone)]
64struct UOpKey {
65    op_discriminant: std::mem::Discriminant<Op>,
66    dtype: DType,
67    src_hashes: SmallVec<[u64; 4]>,
68    op_data: OpData,
69    tag: Option<SmallVec<[usize; 2]>>,
70    /// Pre-computed hash — avoids re-hashing on every HashMap operation.
71    cached_hash: u64,
72}
73
74impl Hash for UOpKey {
75    #[inline]
76    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
77        // Use pre-computed hash directly — O(1) regardless of OpData complexity
78        state.write_u64(self.cached_hash);
79    }
80}
81
82impl PartialEq for UOpKey {
83    fn eq(&self, other: &Self) -> bool {
84        // Fast path: different hashes → definitely not equal
85        self.cached_hash == other.cached_hash
86            && self.op_discriminant == other.op_discriminant
87            && self.dtype == other.dtype
88            && self.src_hashes == other.src_hashes
89            && self.op_data == other.op_data
90            && self.tag == other.tag
91    }
92}
93
94impl Eq for UOpKey {}
95
/// Non-child payload of an `Op`, extracted for hashing and equality.
///
/// `std::mem::discriminant` identifies only the `Op` variant and ignores the data
/// carried inside it. Variants such as `Op::Binary` group many distinct operations
/// (Add and Mul share one discriminant), so without this payload they would be
/// hash-consed into the same UOp. Child UOps are deliberately NOT stored here;
/// they are represented by `UOpKey::src_hashes`.
///
/// NOTE: the derived `Hash` also feeds `UOp::content_hash` (see `UOp::new_tagged`),
/// so reordering variants changes content hashes.
#[derive(Eq, PartialEq, Hash, Clone)]
enum OpData {
    // Nullary operations
    Const(ConstValueHash),
    Unique(usize),
    Device(DeviceSpec),
    // The unique_id is drawn fresh from `next_unique_id()` every time a key is
    // built, so no two DefineLocal UOps are ever deduplicated.
    DefineLocal(usize, usize), // (slot, unique_id)

    // Grouped operations: the concrete op kind within the variant.
    Unary(UnaryOp),
    Binary(BinaryOp),
    Ternary(TernaryOp),

    // Type operations: the target dtype.
    CastDType(DType),
    BitCastDType(DType),

    // Special operations
    MSelectIdx(usize),
    SpecialName(String),

    // Buffer operations
    BufferData(usize, usize), // (unique_id, size) - each buffer is unique
    ParamData(usize, usize),  // (slot, size) — dedup by structure, matching Tinygrad's UOp cache
    BufferView(usize, usize), // (size, offset)
    Bufferize(BufferizeOpts),

    // Movement/Reshape operations
    PermuteAxes(Vec<usize>),
    FlipAxes(Vec<bool>),
    MultiAxis(usize),

    // Reduction operations
    ReduceAxisData(ReduceOp, Vec<usize>), // (reduce_op, axes)
    ReduceOp(ReduceOp),
    AllReduceOp(ReduceOp),

    // Control flow operations
    RangeData(AxisId, AxisType),

    // Vector operations
    GepIndices(Vec<usize>),
    VConstValues(Vec<ConstValueHash>),

    // Symbolic/Define operations
    DefineVarData(String, i64, i64), // (name, min_val, max_val)
    DefineRegData(usize, usize),     // (size, id)

    // Advanced operations (Wmma payload is boxed to keep this enum small)
    WmmaData(Box<WmmaMetadata>),
    ContractRanges(Vec<(usize, usize)>), // Op::Contract's upcast_ranges
    UnrollAxes(Vec<(usize, usize)>),     // Op::Unroll's unroll_axes
    CustomCode(String),                  // shared by Op::Custom and Op::CustomI

    // Movement operations with extra data
    ContiguousOpts(Vec<crate::types::ContiguousHint>),

    // Operations whose semantics are carried entirely by their children
    None,
}
162
163/// Get child UOp structural hashes for hash consing.
164///
165/// Uses `content_hash` (structural) instead of `id` (identity) so that
166/// structurally identical children produce the same key — even if they're
167/// different `Arc` pointers. This makes hash consing truly structural,
168/// matching Tinygrad's behavior where `id()` works because hash consing
169/// guarantees same structure = same object.
170///
171/// Returns SmallVec of hashes, optimized for common case of ≤4 children (inline storage).
172fn src_hashes(op: &Op) -> SmallVec<[u64; 4]> {
173    op.children().into_iter().map(|child| child.content_hash).collect()
174}
175
176impl UOpKey {
177    fn new(op: &Op, dtype: DType, tag: &Option<SmallVec<[usize; 2]>>) -> Self {
178        let op_discriminant = discriminant(op);
179        let src_hashes = src_hashes(op);
180
181        let op_data = match op {
182            Op::Const(c) => OpData::Const(*c),
183            Op::Unique(id) => OpData::Unique(*id),
184            Op::Device(d) => OpData::Device(d.clone()),
185            Op::DefineLocal(slot) => OpData::DefineLocal(*slot, next_unique_id()),
186            Op::Unary(unary_op, _) => OpData::Unary(*unary_op),
187            Op::Binary(binary_op, _, _) => OpData::Binary(*binary_op),
188            Op::Ternary(ternary_op, _, _, _) => OpData::Ternary(*ternary_op),
189            Op::Cast { dtype, .. } => OpData::CastDType(dtype.clone()),
190            Op::BitCast { dtype, .. } => OpData::BitCastDType(dtype.clone()),
191            Op::MSelect { device_index, .. } => OpData::MSelectIdx(*device_index),
192            Op::Special { name, .. } => OpData::SpecialName(name.clone()),
193            Op::Buffer { unique, size, .. } => {
194                if let Op::Unique(id) = unique.op() {
195                    OpData::BufferData(*id, *size)
196                } else {
197                    // Fallback: use UOp's stable id
198                    OpData::BufferData(unique.id as usize, *size)
199                }
200            }
201            Op::BufferView { size, offset, .. } => OpData::BufferView(*size, *offset),
202            Op::Bufferize { opts, .. } => OpData::Bufferize(opts.clone()),
203            Op::Permute { axes, .. } => OpData::PermuteAxes(axes.clone()),
204            Op::Flip { axes, .. } => OpData::FlipAxes(axes.clone()),
205            Op::Multi { axis, .. } => OpData::MultiAxis(*axis),
206            Op::ReduceAxis { reduce_op, axes, .. } => OpData::ReduceAxisData(*reduce_op, axes.clone()),
207            Op::Reduce { reduce_op, .. } => OpData::ReduceOp(*reduce_op),
208            Op::AllReduce { reduce_op, .. } => OpData::AllReduceOp(*reduce_op),
209            Op::Range { axis_id, axis_type, .. } => OpData::RangeData(*axis_id, *axis_type),
210            Op::Gep { indices, .. } => OpData::GepIndices(indices.clone()),
211            Op::VConst { values } => OpData::VConstValues(values.iter().map(|v| ConstValueHash(*v)).collect()),
212            Op::DefineVar { name, min_val, max_val } => OpData::DefineVarData(name.clone(), *min_val, *max_val),
213            Op::DefineReg { size, id } => OpData::DefineRegData(*size, *id),
214            Op::Wmma { metadata, .. } => OpData::WmmaData(metadata.clone().into()),
215            Op::Contract { upcast_ranges, .. } => OpData::ContractRanges(upcast_ranges.clone()),
216            Op::Unroll { unroll_axes, .. } => OpData::UnrollAxes(unroll_axes.clone()),
217            Op::Custom { code, .. } | Op::CustomI { code, .. } => OpData::CustomCode(code.clone()),
218            Op::Contiguous { opts, .. } => OpData::ContiguousOpts(opts.to_vec()),
219            Op::Param { slot, size, .. } => OpData::ParamData(*slot, *size),
220            // All remaining ops encode semantic data entirely through children
221            // (captured by src_hashes) — no extra OpData needed.
222            Op::Noop | Op::Invalid => OpData::None,
223            // Multi-child ops: children ARE the data
224            Op::Sink { .. }
225            | Op::Group { .. }
226            | Op::Vectorize { .. }
227            | Op::Cat { .. }
228            | Op::PtrCat { .. }
229            | Op::MStack { .. }
230            | Op::Barrier { .. } => OpData::None,
231            // Movement ops: shape/bounds are Arc<UOp> children
232            Op::Reshape { .. } | Op::Expand { .. } | Op::Pad { .. } | Op::Shrink { .. } => OpData::None,
233            // Memory/control: all fields are Arc<UOp> children
234            Op::Index { .. } | Op::PointerIndex { .. } | Op::Copy { .. } | Op::Load { .. } | Op::Store { .. } => {
235                OpData::None
236            }
237            Op::If { .. } | Op::EndIf { .. } | Op::End { .. } | Op::After { .. } => OpData::None,
238            // Single-source ops with no extra data
239            Op::Detach { .. } | Op::ContiguousBackward { .. } | Op::Precast { .. } => OpData::None,
240            // Binding/kernel: children encode all semantics
241            Op::Bind { .. } | Op::Kernel { .. } | Op::Assign { .. } => OpData::None,
242        };
243
244        // Pre-compute hash using xxhash (fast, non-cryptographic).
245        // Cached to avoid re-hashing on every HashMap lookup — the previous
246        // bottleneck was 57% of CPU time spent in xxhash due to repeated hashing.
247        let cached_hash = {
248            use xxhash_rust::xxh64::Xxh64;
249            let mut h = Xxh64::new(0);
250            op_discriminant.hash(&mut h);
251            dtype.hash(&mut h);
252            for id in &src_hashes {
253                h.write_u64(*id);
254            }
255            op_data.hash(&mut h);
256            tag.hash(&mut h);
257            h.finish()
258        };
259
260        Self { op_discriminant, dtype, src_hashes, op_data, tag: tag.clone(), cached_hash }
261    }
262}
263
264// Global hash consing cache using lock-free concurrent HashMap.
265//
266// Design: Stores Weak<UOp> for automatic memory management (Tinygrad-aligned).
267// - Cross-thread deduplication: same UOpKey → same Arc<UOp> across all threads
268// - Lock-free reads and writes via papaya's epoch-based reclamation
269// - Automatic cleanup: when no strong refs remain, weak ref becomes dead
270// - Dead refs cleaned lazily on next access or via gc_dead_refs()
271//
272// Memory lifecycle (matches Tinygrad's weakref.WeakKeyDictionary):
273// 1. UOps created via UOp::new() store Weak refs in cache
274// 2. Strong refs held by Tensor, Scheduler, etc. keep UOps alive
275// 3. When all strong refs dropped, UOp deallocated, weak ref becomes dead
276// 4. Dead weak refs cleaned up lazily or via gc_dead_refs()
277static UOPS: OnceLock<HashMap<UOpKey, Weak<UOp>>> = OnceLock::new();
278
279fn uops() -> &'static HashMap<UOpKey, Weak<UOp>> {
280    UOPS.get_or_init(HashMap::new)
281}
282
283/// Remove dead weak references from the cache.
284///
285/// This is optional - dead refs are also cleaned lazily on next access.
286/// Call this if you want to proactively free cache memory.
287///
288/// # Example
289///
290/// ```ignore
291/// // After dropping many tensors, optionally clean up cache
292/// gc_dead_refs();
293/// ```
294pub fn gc_dead_refs() {
295    let map = uops();
296    let guard = map.guard();
297
298    // Collect keys with dead weak refs
299    let to_remove: Vec<UOpKey> =
300        map.iter(&guard).filter(|(_, weak)| weak.upgrade().is_none()).map(|(k, _)| k.clone()).collect();
301
302    // Remove dead entries
303    for key in to_remove {
304        map.remove(&key, &guard);
305    }
306}
307
308/// Legacy alias for gc_dead_refs (for compatibility).
309///
310/// With weak references, UOps are automatically cleaned up when no longer
311/// referenced. This function now just cleans up dead weak refs in the cache.
312#[deprecated(note = "UOp cache now uses weak refs - cleanup is automatic. Use gc_dead_refs() to clean cache.")]
313pub fn gc_unused_uops() {
314    gc_dead_refs();
315}
316
317/// Get the set of IDs for UOps currently alive in the cache.
318///
319/// This is used by kernel cache GC to determine which compiled kernels
320/// can be safely removed (those whose AST IDs are no longer live).
321///
322/// # Returns
323///
324/// A HashSet containing the IDs of all currently cached UOps (only live ones).
325pub fn live_uop_ids() -> std::collections::HashSet<u64> {
326    let map = uops();
327    let guard = map.guard();
328    map.iter(&guard).filter_map(|(_, weak)| weak.upgrade().map(|arc| arc.id)).collect()
329}
330
331impl UOp {
332    /// Create a new UOp with hash consing.
333    ///
334    /// If an identical UOp already exists (in any thread) and is still alive,
335    /// returns a reference to it. Otherwise, creates a new UOp and caches it.
336    ///
337    /// # Thread Safety
338    ///
339    /// This function is thread-safe. Creating the same UOp from different threads
340    /// will return the same `Arc<UOp>`, so `Arc::ptr_eq` works across threads.
341    ///
342    /// # Memory Management
343    ///
344    /// The cache stores weak references. UOps are automatically cleaned up when
345    /// no strong references remain (Tinygrad-aligned behavior).
346    #[inline]
347    #[track_caller]
348    pub fn new(op: Op, dtype: DType) -> Arc<Self> {
349        Self::new_tagged(op, dtype, None)
350    }
351
352    /// Create a UOp with an explicit tag (Tinygrad: `UOp(op, dtype, src, arg, tag)`).
353    /// Tag participates in hash consing — same structure + different tag = different UOp.
354    #[track_caller]
355    pub fn new_tagged(op: Op, dtype: DType, tag: Option<SmallVec<[usize; 2]>>) -> Arc<Self> {
356        use papaya::{Compute, Operation};
357
358        let caller_location = std::panic::Location::caller();
359        let key = UOpKey::new(&op, dtype.clone(), &tag);
360        let guard = uops().guard();
361
362        // Fast path: check if valid entry exists
363        if let Some(weak) = uops().get(&key, &guard)
364            && let Some(arc) = weak.upgrade()
365        {
366            use crate::provenance::PROVENANCE_TRACKER;
367            PROVENANCE_TRACKER.with(|tracker| {
368                tracker.borrow_mut().capture(arc.id, caller_location);
369            });
370            return arc;
371        }
372
373        let content_hash = {
374            use xxhash_rust::xxh64::Xxh64;
375            let mut h = Xxh64::new(0);
376            std::mem::discriminant(&op).hash(&mut h);
377            dtype.hash(&mut h);
378            for child in op.children() {
379                h.write_u64(child.content_hash);
380            }
381            key.op_data.hash(&mut h);
382            h.finish()
383        };
384
385        let new_arc = Arc::new(Self {
386            id: next_uop_id(),
387            op,
388            dtype,
389            content_hash,
390            tag,
391            shape_cache: std::sync::OnceLock::new(),
392            ranges_cache: std::sync::OnceLock::new(),
393            in_scope_ranges_cache: std::sync::OnceLock::new(),
394            vmin_vmax_cache: std::sync::OnceLock::new(),
395            sound_vmin_vmax_cache: std::sync::OnceLock::new(),
396            has_index_in_sources_cache: std::sync::OnceLock::new(),
397            backward_slice_cache: std::sync::OnceLock::new(),
398            metadata: None,
399        });
400        let new_weak = Arc::downgrade(&new_arc);
401
402        let result = uops().compute(
403            key,
404            |entry| match entry {
405                Some((_, existing_weak)) => {
406                    if let Some(existing_arc) = existing_weak.upgrade() {
407                        Operation::Abort(existing_arc)
408                    } else {
409                        Operation::Insert(new_weak.clone())
410                    }
411                }
412                None => Operation::Insert(new_weak.clone()),
413            },
414            &guard,
415        );
416
417        let final_arc = match result {
418            Compute::Inserted(_, _) | Compute::Updated { .. } => new_arc,
419            Compute::Aborted(existing_arc) => existing_arc,
420            _ => new_arc,
421        };
422
423        use crate::provenance::PROVENANCE_TRACKER;
424        PROVENANCE_TRACKER.with(|tracker| {
425            tracker.borrow_mut().capture(final_arc.id, caller_location);
426        });
427
428        final_arc
429    }
430
431    /// Attach metadata to this UOp, creating a new instance.
432    ///
433    /// Metadata is NOT part of hash consing - this method creates a new UOp
434    /// with a different ID but the same operation structure. This allows
435    /// attaching metadata (like kernel info) after optimization.
436    ///
437    /// # Examples
438    ///
439    /// ```ignore
440    /// let ast = /* ... optimized AST ... */;
441    /// let with_info = ast.with_metadata(KernelInfo::new("r_g16l16", vec![], false));
442    /// ```
443    pub fn with_metadata<T: std::any::Any + Send + Sync + 'static>(self: &Arc<Self>, metadata: T) -> Arc<Self> {
444        self.with_metadata_raw(Arc::new(metadata))
445    }
446
447    /// Get metadata of a specific type if it exists.
448    ///
449    /// Returns `None` if no metadata is attached or if the metadata is of a different type.
450    ///
451    /// # Examples
452    ///
453    /// ```ignore
454    /// if let Some(info) = ast.metadata::<KernelInfo>() {
455    ///     println!("Kernel name: {}", info.name);
456    /// }
457    /// ```
458    pub fn metadata<T: std::any::Any + Send + Sync>(&self) -> Option<std::sync::Arc<T>> {
459        self.metadata.as_ref()?.clone().downcast::<T>().ok()
460    }
461
462    /// Get raw metadata (type-erased).
463    ///
464    /// Used to preserve metadata across graph rewrites that create new root nodes.
465    pub fn metadata_raw(&self) -> Option<Arc<dyn std::any::Any + Send + Sync>> {
466        self.metadata.clone()
467    }
468
469    /// Attach raw metadata (type-erased), creating a new instance.
470    ///
471    /// Used to re-attach metadata that was saved before graph rewrites.
472    pub fn with_metadata_raw(self: &Arc<Self>, metadata: Arc<dyn std::any::Any + Send + Sync>) -> Arc<Self> {
473        Arc::new(Self {
474            id: next_uop_id(),
475            op: self.op.clone(),
476            dtype: self.dtype.clone(),
477            content_hash: self.content_hash, // same structure, same content hash
478            tag: self.tag.clone(),
479            shape_cache: std::sync::OnceLock::new(),
480            ranges_cache: std::sync::OnceLock::new(),
481            in_scope_ranges_cache: std::sync::OnceLock::new(),
482            vmin_vmax_cache: std::sync::OnceLock::new(),
483            sound_vmin_vmax_cache: std::sync::OnceLock::new(),
484            has_index_in_sources_cache: std::sync::OnceLock::new(),
485            backward_slice_cache: std::sync::OnceLock::new(),
486            metadata: Some(metadata),
487        })
488    }
489}