1use std::hash::{Hash, Hasher};
22use std::mem::discriminant;
23use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
24use std::sync::{Arc, OnceLock, Weak};
25
26use papaya::HashMap;
27use smallvec::SmallVec;
28
29use crate::op::Op;
30use crate::types::*;
31use crate::uop::core::UOp;
32use morok_dtype::DType;
33use morok_dtype::DeviceSpec;
34
/// Process-wide counter backing [`next_unique_id`]; it starts at zero.
static UNIQUE_COUNTER: AtomicUsize = AtomicUsize::new(0);

/// Returns the current value of [`UNIQUE_COUNTER`] and advances it by one.
///
/// The value handed out is the one the counter held *before* the increment,
/// so successive calls observe distinct, increasing values (until the
/// counter wraps around on overflow).
pub(crate) fn next_unique_id() -> usize {
    // Only the atomicity of the increment matters here; no other memory is
    // published through this counter, hence the relaxed ordering.
    UNIQUE_COUNTER.fetch_add(1, Ordering::Relaxed)
}
44
/// Process-wide counter backing [`next_uop_id`]; it starts at zero.
static UOP_ID_COUNTER: AtomicU64 = AtomicU64::new(0);

/// Returns the current value of [`UOP_ID_COUNTER`] and advances it by one.
///
/// The returned value is the counter's content before the increment, so
/// every call yields a different id (until the 64-bit counter wraps).
pub(crate) fn next_uop_id() -> u64 {
    // Atomicity of the read-modify-write is all that is required; the
    // counter does not guard any other data, so `Relaxed` is used.
    UOP_ID_COUNTER.fetch_add(1, Ordering::Relaxed)
}
54
55#[derive(Clone)]
64struct UOpKey {
65 op_discriminant: std::mem::Discriminant<Op>,
66 dtype: DType,
67 src_hashes: SmallVec<[u64; 4]>,
68 op_data: OpData,
69 tag: Option<SmallVec<[usize; 2]>>,
70 cached_hash: u64,
72}
73
74impl Hash for UOpKey {
75 #[inline]
76 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
77 state.write_u64(self.cached_hash);
79 }
80}
81
82impl PartialEq for UOpKey {
83 fn eq(&self, other: &Self) -> bool {
84 self.cached_hash == other.cached_hash
86 && self.op_discriminant == other.op_discriminant
87 && self.dtype == other.dtype
88 && self.src_hashes == other.src_hashes
89 && self.op_data == other.op_data
90 && self.tag == other.tag
91 }
92}
93
94impl Eq for UOpKey {}
95
96#[derive(Eq, PartialEq, Hash, Clone)]
102enum OpData {
103 Const(ConstValueHash),
105 Unique(usize),
106 Device(DeviceSpec),
107 DefineLocal(usize, usize), Unary(UnaryOp),
112 Binary(BinaryOp),
113 Ternary(TernaryOp),
114
115 CastDType(DType),
117 BitCastDType(DType),
118
119 MSelectIdx(usize),
121 SpecialName(String),
122
123 BufferData(usize, usize), ParamData(usize, usize), BufferView(usize, usize),
127 Bufferize(BufferizeOpts),
128
129 PermuteAxes(Vec<usize>),
131 FlipAxes(Vec<bool>),
132 MultiAxis(usize),
133
134 ReduceAxisData(ReduceOp, Vec<usize>),
136 ReduceOp(ReduceOp),
137 AllReduceOp(ReduceOp),
138
139 RangeData(AxisId, AxisType),
141
142 GepIndices(Vec<usize>),
144 VConstValues(Vec<ConstValueHash>),
145
146 DefineVarData(String, i64, i64), DefineRegData(usize, usize), WmmaData(Box<WmmaMetadata>),
152 ContractRanges(Vec<(usize, usize)>),
153 UnrollAxes(Vec<(usize, usize)>),
154 CustomCode(String),
155
156 ContiguousOpts(Vec<crate::types::ContiguousHint>),
158
159 None,
161}
162
163fn src_hashes(op: &Op) -> SmallVec<[u64; 4]> {
173 op.children().into_iter().map(|child| child.content_hash).collect()
174}
175
176impl UOpKey {
177 fn new(op: &Op, dtype: DType, tag: &Option<SmallVec<[usize; 2]>>) -> Self {
178 let op_discriminant = discriminant(op);
179 let src_hashes = src_hashes(op);
180
181 let op_data = match op {
182 Op::Const(c) => OpData::Const(*c),
183 Op::Unique(id) => OpData::Unique(*id),
184 Op::Device(d) => OpData::Device(d.clone()),
185 Op::DefineLocal(slot) => OpData::DefineLocal(*slot, next_unique_id()),
186 Op::Unary(unary_op, _) => OpData::Unary(*unary_op),
187 Op::Binary(binary_op, _, _) => OpData::Binary(*binary_op),
188 Op::Ternary(ternary_op, _, _, _) => OpData::Ternary(*ternary_op),
189 Op::Cast { dtype, .. } => OpData::CastDType(dtype.clone()),
190 Op::BitCast { dtype, .. } => OpData::BitCastDType(dtype.clone()),
191 Op::MSelect { device_index, .. } => OpData::MSelectIdx(*device_index),
192 Op::Special { name, .. } => OpData::SpecialName(name.clone()),
193 Op::Buffer { unique, size, .. } => {
194 if let Op::Unique(id) = unique.op() {
195 OpData::BufferData(*id, *size)
196 } else {
197 OpData::BufferData(unique.id as usize, *size)
199 }
200 }
201 Op::BufferView { size, offset, .. } => OpData::BufferView(*size, *offset),
202 Op::Bufferize { opts, .. } => OpData::Bufferize(opts.clone()),
203 Op::Permute { axes, .. } => OpData::PermuteAxes(axes.clone()),
204 Op::Flip { axes, .. } => OpData::FlipAxes(axes.clone()),
205 Op::Multi { axis, .. } => OpData::MultiAxis(*axis),
206 Op::ReduceAxis { reduce_op, axes, .. } => OpData::ReduceAxisData(*reduce_op, axes.clone()),
207 Op::Reduce { reduce_op, .. } => OpData::ReduceOp(*reduce_op),
208 Op::AllReduce { reduce_op, .. } => OpData::AllReduceOp(*reduce_op),
209 Op::Range { axis_id, axis_type, .. } => OpData::RangeData(*axis_id, *axis_type),
210 Op::Gep { indices, .. } => OpData::GepIndices(indices.clone()),
211 Op::VConst { values } => OpData::VConstValues(values.iter().map(|v| ConstValueHash(*v)).collect()),
212 Op::DefineVar { name, min_val, max_val } => OpData::DefineVarData(name.clone(), *min_val, *max_val),
213 Op::DefineReg { size, id } => OpData::DefineRegData(*size, *id),
214 Op::Wmma { metadata, .. } => OpData::WmmaData(metadata.clone().into()),
215 Op::Contract { upcast_ranges, .. } => OpData::ContractRanges(upcast_ranges.clone()),
216 Op::Unroll { unroll_axes, .. } => OpData::UnrollAxes(unroll_axes.clone()),
217 Op::Custom { code, .. } | Op::CustomI { code, .. } => OpData::CustomCode(code.clone()),
218 Op::Contiguous { opts, .. } => OpData::ContiguousOpts(opts.to_vec()),
219 Op::Param { slot, size, .. } => OpData::ParamData(*slot, *size),
220 Op::Noop | Op::Invalid => OpData::None,
223 Op::Sink { .. }
225 | Op::Group { .. }
226 | Op::Vectorize { .. }
227 | Op::Cat { .. }
228 | Op::PtrCat { .. }
229 | Op::MStack { .. }
230 | Op::Barrier { .. } => OpData::None,
231 Op::Reshape { .. } | Op::Expand { .. } | Op::Pad { .. } | Op::Shrink { .. } => OpData::None,
233 Op::Index { .. } | Op::PointerIndex { .. } | Op::Copy { .. } | Op::Load { .. } | Op::Store { .. } => {
235 OpData::None
236 }
237 Op::If { .. } | Op::EndIf { .. } | Op::End { .. } | Op::After { .. } => OpData::None,
238 Op::Detach { .. } | Op::ContiguousBackward { .. } | Op::Precast { .. } => OpData::None,
240 Op::Bind { .. } | Op::Kernel { .. } | Op::Assign { .. } => OpData::None,
242 };
243
244 let cached_hash = {
248 use xxhash_rust::xxh64::Xxh64;
249 let mut h = Xxh64::new(0);
250 op_discriminant.hash(&mut h);
251 dtype.hash(&mut h);
252 for id in &src_hashes {
253 h.write_u64(*id);
254 }
255 op_data.hash(&mut h);
256 tag.hash(&mut h);
257 h.finish()
258 };
259
260 Self { op_discriminant, dtype, src_hashes, op_data, tag: tag.clone(), cached_hash }
261 }
262}
263
264static UOPS: OnceLock<HashMap<UOpKey, Weak<UOp>>> = OnceLock::new();
278
279fn uops() -> &'static HashMap<UOpKey, Weak<UOp>> {
280 UOPS.get_or_init(HashMap::new)
281}
282
283pub fn gc_dead_refs() {
295 let map = uops();
296 let guard = map.guard();
297
298 let to_remove: Vec<UOpKey> =
300 map.iter(&guard).filter(|(_, weak)| weak.upgrade().is_none()).map(|(k, _)| k.clone()).collect();
301
302 for key in to_remove {
304 map.remove(&key, &guard);
305 }
306}
307
308#[deprecated(note = "UOp cache now uses weak refs - cleanup is automatic. Use gc_dead_refs() to clean cache.")]
313pub fn gc_unused_uops() {
314 gc_dead_refs();
315}
316
317pub fn live_uop_ids() -> std::collections::HashSet<u64> {
326 let map = uops();
327 let guard = map.guard();
328 map.iter(&guard).filter_map(|(_, weak)| weak.upgrade().map(|arc| arc.id)).collect()
329}
330
331impl UOp {
332 #[inline]
347 #[track_caller]
348 pub fn new(op: Op, dtype: DType) -> Arc<Self> {
349 Self::new_tagged(op, dtype, None)
350 }
351
352 #[track_caller]
355 pub fn new_tagged(op: Op, dtype: DType, tag: Option<SmallVec<[usize; 2]>>) -> Arc<Self> {
356 use papaya::{Compute, Operation};
357
358 let caller_location = std::panic::Location::caller();
359 let key = UOpKey::new(&op, dtype.clone(), &tag);
360 let guard = uops().guard();
361
362 if let Some(weak) = uops().get(&key, &guard)
364 && let Some(arc) = weak.upgrade()
365 {
366 use crate::provenance::PROVENANCE_TRACKER;
367 PROVENANCE_TRACKER.with(|tracker| {
368 tracker.borrow_mut().capture(arc.id, caller_location);
369 });
370 return arc;
371 }
372
373 let content_hash = {
374 use xxhash_rust::xxh64::Xxh64;
375 let mut h = Xxh64::new(0);
376 std::mem::discriminant(&op).hash(&mut h);
377 dtype.hash(&mut h);
378 for child in op.children() {
379 h.write_u64(child.content_hash);
380 }
381 key.op_data.hash(&mut h);
382 h.finish()
383 };
384
385 let new_arc = Arc::new(Self {
386 id: next_uop_id(),
387 op,
388 dtype,
389 content_hash,
390 tag,
391 shape_cache: std::sync::OnceLock::new(),
392 ranges_cache: std::sync::OnceLock::new(),
393 in_scope_ranges_cache: std::sync::OnceLock::new(),
394 vmin_vmax_cache: std::sync::OnceLock::new(),
395 sound_vmin_vmax_cache: std::sync::OnceLock::new(),
396 has_index_in_sources_cache: std::sync::OnceLock::new(),
397 backward_slice_cache: std::sync::OnceLock::new(),
398 metadata: None,
399 });
400 let new_weak = Arc::downgrade(&new_arc);
401
402 let result = uops().compute(
403 key,
404 |entry| match entry {
405 Some((_, existing_weak)) => {
406 if let Some(existing_arc) = existing_weak.upgrade() {
407 Operation::Abort(existing_arc)
408 } else {
409 Operation::Insert(new_weak.clone())
410 }
411 }
412 None => Operation::Insert(new_weak.clone()),
413 },
414 &guard,
415 );
416
417 let final_arc = match result {
418 Compute::Inserted(_, _) | Compute::Updated { .. } => new_arc,
419 Compute::Aborted(existing_arc) => existing_arc,
420 _ => new_arc,
421 };
422
423 use crate::provenance::PROVENANCE_TRACKER;
424 PROVENANCE_TRACKER.with(|tracker| {
425 tracker.borrow_mut().capture(final_arc.id, caller_location);
426 });
427
428 final_arc
429 }
430
431 pub fn with_metadata<T: std::any::Any + Send + Sync + 'static>(self: &Arc<Self>, metadata: T) -> Arc<Self> {
444 self.with_metadata_raw(Arc::new(metadata))
445 }
446
447 pub fn metadata<T: std::any::Any + Send + Sync>(&self) -> Option<std::sync::Arc<T>> {
459 self.metadata.as_ref()?.clone().downcast::<T>().ok()
460 }
461
462 pub fn metadata_raw(&self) -> Option<Arc<dyn std::any::Any + Send + Sync>> {
466 self.metadata.clone()
467 }
468
469 pub fn with_metadata_raw(self: &Arc<Self>, metadata: Arc<dyn std::any::Any + Send + Sync>) -> Arc<Self> {
473 Arc::new(Self {
474 id: next_uop_id(),
475 op: self.op.clone(),
476 dtype: self.dtype.clone(),
477 content_hash: self.content_hash, tag: self.tag.clone(),
479 shape_cache: std::sync::OnceLock::new(),
480 ranges_cache: std::sync::OnceLock::new(),
481 in_scope_ranges_cache: std::sync::OnceLock::new(),
482 vmin_vmax_cache: std::sync::OnceLock::new(),
483 sound_vmin_vmax_cache: std::sync::OnceLock::new(),
484 has_index_in_sources_cache: std::sync::OnceLock::new(),
485 backward_slice_cache: std::sync::OnceLock::new(),
486 metadata: Some(metadata),
487 })
488 }
489}