use std::hash::{Hash, Hasher};
use std::mem::discriminant;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::{Arc, OnceLock, Weak};
use papaya::HashMap;
use smallvec::SmallVec;
use crate::op::Op;
use crate::types::*;
use crate::uop::core::UOp;
use morok_dtype::DType;
use morok_dtype::DeviceSpec;
/// Process-wide counter backing [`next_unique_id`].
static UNIQUE_COUNTER: AtomicUsize = AtomicUsize::new(0);

/// Returns a fresh id that differs from every earlier return value of this
/// function (until the counter wraps around).
pub(crate) fn next_unique_id() -> usize {
    // Relaxed is enough: callers only need distinct values, not any ordering
    // relative to other memory operations.
    AtomicUsize::fetch_add(&UNIQUE_COUNTER, 1, Ordering::Relaxed)
}
/// Process-wide counter backing [`next_uop_id`].
static UOP_ID_COUNTER: AtomicU64 = AtomicU64::new(0);

/// Returns a fresh UOp id, distinct from every earlier return value of this
/// function (until the counter wraps around).
pub(crate) fn next_uop_id() -> u64 {
    // Only uniqueness matters, so a relaxed increment suffices.
    AtomicU64::fetch_add(&UOP_ID_COUNTER, 1, Ordering::Relaxed)
}
#[derive(Clone)]
struct UOpKey {
op_discriminant: std::mem::Discriminant<Op>,
dtype: DType,
src_hashes: SmallVec<[u64; 4]>,
op_data: OpData,
tag: Option<SmallVec<[usize; 2]>>,
cached_hash: u64,
}
impl Hash for UOpKey {
#[inline]
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
state.write_u64(self.cached_hash);
}
}
impl PartialEq for UOpKey {
fn eq(&self, other: &Self) -> bool {
self.cached_hash == other.cached_hash
&& self.op_discriminant == other.op_discriminant
&& self.dtype == other.dtype
&& self.src_hashes == other.src_hashes
&& self.op_data == other.op_data
&& self.tag == other.tag
}
}
impl Eq for UOpKey {}
/// Op-specific payload that participates in a `UOpKey`'s identity.
///
/// Sources enter the key separately (via their content hashes), as do the op
/// variant and dtype; everything else an `Op` carries that distinguishes it is
/// captured here. Do not reorder variants: the derived `Hash` includes the
/// variant discriminant, so reordering changes every key and content hash.
#[derive(Eq, PartialEq, Hash, Clone)]
enum OpData {
    /// Value of an `Op::Const`.
    Const(ConstValueHash),
    /// Id carried by an `Op::Unique`.
    Unique(usize),
    /// Device of an `Op::Device`.
    Device(DeviceSpec),
    /// `Op::DefineLocal` slot plus a fresh `next_unique_id()` drawn per key.
    DefineLocal(usize, usize),
    /// Operator of an `Op::Unary`.
    Unary(UnaryOp),
    /// Operator of an `Op::Binary`.
    Binary(BinaryOp),
    /// Operator of an `Op::Ternary`.
    Ternary(TernaryOp),
    /// Target dtype of an `Op::Cast`.
    CastDType(DType),
    /// Target dtype of an `Op::BitCast`.
    BitCastDType(DType),
    /// `device_index` of an `Op::MSelect`.
    MSelectIdx(usize),
    /// `name` of an `Op::Special`.
    SpecialName(String),
    /// `Op::Buffer`: id of its unique source, and size.
    BufferData(usize, usize),
    /// `Op::Param`: slot and size.
    ParamData(usize, usize),
    /// `Op::BufferView`: size and offset.
    BufferView(usize, usize),
    /// Options of an `Op::Bufferize`.
    Bufferize(BufferizeOpts),
    /// Axes of an `Op::Permute`.
    PermuteAxes(Vec<usize>),
    /// Axes of an `Op::Flip`.
    FlipAxes(Vec<bool>),
    /// Axis of an `Op::Multi`.
    MultiAxis(usize),
    /// `Op::ReduceAxis`: reduction operator and axes.
    ReduceAxisData(ReduceOp, Vec<usize>),
    /// Reduction operator of an `Op::Reduce`.
    ReduceOp(ReduceOp),
    /// Reduction operator of an `Op::AllReduce`.
    AllReduceOp(ReduceOp),
    /// `Op::Range`: axis id and axis type.
    RangeData(AxisId, AxisType),
    /// Indices of an `Op::Gep`.
    GepIndices(Vec<usize>),
    /// Values of an `Op::VConst`.
    VConstValues(Vec<ConstValueHash>),
    /// `Op::DefineVar`: name, min and max.
    DefineVarData(String, i64, i64),
    /// `Op::DefineReg`: size and id.
    DefineRegData(usize, usize),
    /// Metadata of an `Op::Wmma`.
    WmmaData(Box<WmmaMetadata>),
    /// `upcast_ranges` of an `Op::Contract`.
    ContractRanges(Vec<(usize, usize)>),
    /// `unroll_axes` of an `Op::Unroll`.
    UnrollAxes(Vec<(usize, usize)>),
    /// Code of an `Op::Custom` or `Op::CustomI`.
    CustomCode(String),
    /// Options of an `Op::Contiguous`.
    ContiguousOpts(Vec<crate::types::ContiguousHint>),
    /// Variants fully identified by op kind, dtype and sources.
    None,
}
fn src_hashes(op: &Op) -> SmallVec<[u64; 4]> {
op.children().into_iter().map(|child| child.content_hash).collect()
}
impl UOpKey {
/// Builds the interning key for a would-be UOp `op` with result type `dtype` and `tag`.
///
/// The key captures the op variant, `dtype`, the children's `content_hash`es
/// (via `src_hashes`), the variant-specific payload (`OpData`) and `tag`, and
/// precomputes an xxh64 digest of all of them into `cached_hash`.
///
/// NOTE(review): children enter the key only through their 64-bit
/// `content_hash`, which excludes the child's tag and metadata; children that
/// differ only in those respects yield equal parent keys — confirm intended.
fn new(op: &Op, dtype: DType, tag: &Option<SmallVec<[usize; 2]>>) -> Self {
let op_discriminant = discriminant(op);
let src_hashes = src_hashes(op);
let op_data = match op {
Op::Const(c) => OpData::Const(*c),
Op::Unique(id) => OpData::Unique(*id),
Op::Device(d) => OpData::Device(d.clone()),
// NOTE(review): draws a fresh id on every call, so no two DefineLocal keys
// ever compare equal and DefineLocal is never deduplicated — confirm intended.
Op::DefineLocal(slot) => OpData::DefineLocal(*slot, next_unique_id()),
Op::Unary(unary_op, _) => OpData::Unary(*unary_op),
Op::Binary(binary_op, _, _) => OpData::Binary(*binary_op),
Op::Ternary(ternary_op, _, _, _) => OpData::Ternary(*ternary_op),
Op::Cast { dtype, .. } => OpData::CastDType(dtype.clone()),
Op::BitCast { dtype, .. } => OpData::BitCastDType(dtype.clone()),
Op::MSelect { device_index, .. } => OpData::MSelectIdx(*device_index),
Op::Special { name, .. } => OpData::SpecialName(name.clone()),
// Prefer the Unique op's own id; otherwise fall back to the source UOp's global id.
// NOTE(review): `as usize` truncates the u64 id on 32-bit targets.
Op::Buffer { unique, size, .. } => {
if let Op::Unique(id) = unique.op() {
OpData::BufferData(*id, *size)
} else {
OpData::BufferData(unique.id as usize, *size)
}
}
Op::BufferView { size, offset, .. } => OpData::BufferView(*size, *offset),
Op::Bufferize { opts, .. } => OpData::Bufferize(opts.clone()),
Op::Permute { axes, .. } => OpData::PermuteAxes(axes.clone()),
Op::Flip { axes, .. } => OpData::FlipAxes(axes.clone()),
Op::Multi { axis, .. } => OpData::MultiAxis(*axis),
Op::ReduceAxis { reduce_op, axes, .. } => OpData::ReduceAxisData(*reduce_op, axes.clone()),
Op::Reduce { reduce_op, .. } => OpData::ReduceOp(*reduce_op),
Op::AllReduce { reduce_op, .. } => OpData::AllReduceOp(*reduce_op),
Op::Range { axis_id, axis_type, .. } => OpData::RangeData(*axis_id, *axis_type),
Op::Gep { indices, .. } => OpData::GepIndices(indices.clone()),
Op::VConst { values } => OpData::VConstValues(values.iter().map(|v| ConstValueHash(*v)).collect()),
Op::DefineVar { name, min_val, max_val } => OpData::DefineVarData(name.clone(), *min_val, *max_val),
Op::DefineReg { size, id } => OpData::DefineRegData(*size, *id),
Op::Wmma { metadata, .. } => OpData::WmmaData(metadata.clone().into()),
Op::Contract { upcast_ranges, .. } => OpData::ContractRanges(upcast_ranges.clone()),
Op::Unroll { unroll_axes, .. } => OpData::UnrollAxes(unroll_axes.clone()),
// Custom and CustomI share a payload variant; the op discriminant still tells them apart.
Op::Custom { code, .. } | Op::CustomI { code, .. } => OpData::CustomCode(code.clone()),
Op::Contiguous { opts, .. } => OpData::ContiguousOpts(opts.to_vec()),
Op::Param { slot, size, .. } => OpData::ParamData(*slot, *size),
// The remaining variants carry nothing beyond op kind, dtype and sources
// that this key distinguishes on.
Op::Noop | Op::Invalid => OpData::None,
Op::Sink { .. }
| Op::Group { .. }
| Op::Vectorize { .. }
| Op::Cat { .. }
| Op::PtrCat { .. }
| Op::MStack { .. }
| Op::Barrier { .. } => OpData::None,
Op::Reshape { .. } | Op::Expand { .. } | Op::Pad { .. } | Op::Shrink { .. } => OpData::None,
Op::Index { .. } | Op::PointerIndex { .. } | Op::Copy { .. } | Op::Load { .. } | Op::Store { .. } => {
OpData::None
}
Op::If { .. } | Op::EndIf { .. } | Op::End { .. } | Op::After { .. } => OpData::None,
Op::Detach { .. } | Op::ContiguousBackward { .. } | Op::Precast { .. } => OpData::None,
Op::Bind { .. } | Op::Kernel { .. } | Op::Assign { .. } => OpData::None,
};
// xxh64 (seed 0) over all identity fields. The order of writes defines the
// digest, so changing it changes every key hash.
let cached_hash = {
use xxhash_rust::xxh64::Xxh64;
let mut h = Xxh64::new(0);
op_discriminant.hash(&mut h);
dtype.hash(&mut h);
for id in &src_hashes {
h.write_u64(*id);
}
op_data.hash(&mut h);
tag.hash(&mut h);
h.finish()
};
Self { op_discriminant, dtype, src_hashes, op_data, tag: tag.clone(), cached_hash }
}
}
/// Global intern table: each `UOpKey` maps to a weak handle on its UOp, so
/// the table itself never keeps a UOp alive.
static UOPS: OnceLock<HashMap<UOpKey, Weak<UOp>>> = OnceLock::new();

/// Returns the global intern table, creating it empty on first access.
fn uops() -> &'static HashMap<UOpKey, Weak<UOp>> {
    OnceLock::get_or_init(&UOPS, HashMap::new)
}
pub fn gc_dead_refs() {
let map = uops();
let guard = map.guard();
let to_remove: Vec<UOpKey> =
map.iter(&guard).filter(|(_, weak)| weak.upgrade().is_none()).map(|(k, _)| k.clone()).collect();
for key in to_remove {
map.remove(&key, &guard);
}
}
#[deprecated(note = "UOp cache now uses weak refs - cleanup is automatic. Use gc_dead_refs() to clean cache.")]
pub fn gc_unused_uops() {
gc_dead_refs();
}
/// Returns the ids of every interned UOp that is still alive at scan time.
///
/// Only UOps registered in the intern table are reported; the result is a
/// snapshot and may miss UOps created or dropped concurrently.
pub fn live_uop_ids() -> std::collections::HashSet<u64> {
    let table = uops();
    let guard = table.guard();
    let mut ids = std::collections::HashSet::new();
    for (_, weak) in table.iter(&guard) {
        if let Some(uop) = weak.upgrade() {
            ids.insert(uop.id);
        }
    }
    ids
}
impl UOp {
#[inline]
#[track_caller]
pub fn new(op: Op, dtype: DType) -> Arc<Self> {
Self::new_tagged(op, dtype, None)
}
#[track_caller]
pub fn new_tagged(op: Op, dtype: DType, tag: Option<SmallVec<[usize; 2]>>) -> Arc<Self> {
use papaya::{Compute, Operation};
let caller_location = std::panic::Location::caller();
let key = UOpKey::new(&op, dtype.clone(), &tag);
let guard = uops().guard();
if let Some(weak) = uops().get(&key, &guard)
&& let Some(arc) = weak.upgrade()
{
use crate::provenance::PROVENANCE_TRACKER;
PROVENANCE_TRACKER.with(|tracker| {
tracker.borrow_mut().capture(arc.id, caller_location);
});
return arc;
}
let content_hash = {
use xxhash_rust::xxh64::Xxh64;
let mut h = Xxh64::new(0);
std::mem::discriminant(&op).hash(&mut h);
dtype.hash(&mut h);
for child in op.children() {
h.write_u64(child.content_hash);
}
key.op_data.hash(&mut h);
h.finish()
};
let new_arc = Arc::new(Self {
id: next_uop_id(),
op,
dtype,
content_hash,
tag,
shape_cache: std::sync::OnceLock::new(),
ranges_cache: std::sync::OnceLock::new(),
in_scope_ranges_cache: std::sync::OnceLock::new(),
vmin_vmax_cache: std::sync::OnceLock::new(),
sound_vmin_vmax_cache: std::sync::OnceLock::new(),
has_index_in_sources_cache: std::sync::OnceLock::new(),
backward_slice_cache: std::sync::OnceLock::new(),
metadata: None,
});
let new_weak = Arc::downgrade(&new_arc);
let result = uops().compute(
key,
|entry| match entry {
Some((_, existing_weak)) => {
if let Some(existing_arc) = existing_weak.upgrade() {
Operation::Abort(existing_arc)
} else {
Operation::Insert(new_weak.clone())
}
}
None => Operation::Insert(new_weak.clone()),
},
&guard,
);
let final_arc = match result {
Compute::Inserted(_, _) | Compute::Updated { .. } => new_arc,
Compute::Aborted(existing_arc) => existing_arc,
_ => new_arc,
};
use crate::provenance::PROVENANCE_TRACKER;
PROVENANCE_TRACKER.with(|tracker| {
tracker.borrow_mut().capture(final_arc.id, caller_location);
});
final_arc
}
pub fn with_metadata<T: std::any::Any + Send + Sync + 'static>(self: &Arc<Self>, metadata: T) -> Arc<Self> {
self.with_metadata_raw(Arc::new(metadata))
}
pub fn metadata<T: std::any::Any + Send + Sync>(&self) -> Option<std::sync::Arc<T>> {
self.metadata.as_ref()?.clone().downcast::<T>().ok()
}
pub fn metadata_raw(&self) -> Option<Arc<dyn std::any::Any + Send + Sync>> {
self.metadata.clone()
}
pub fn with_metadata_raw(self: &Arc<Self>, metadata: Arc<dyn std::any::Any + Send + Sync>) -> Arc<Self> {
Arc::new(Self {
id: next_uop_id(),
op: self.op.clone(),
dtype: self.dtype.clone(),
content_hash: self.content_hash, tag: self.tag.clone(),
shape_cache: std::sync::OnceLock::new(),
ranges_cache: std::sync::OnceLock::new(),
in_scope_ranges_cache: std::sync::OnceLock::new(),
vmin_vmax_cache: std::sync::OnceLock::new(),
sound_vmin_vmax_cache: std::sync::OnceLock::new(),
has_index_in_sources_cache: std::sync::OnceLock::new(),
backward_slice_cache: std::sync::OnceLock::new(),
metadata: Some(metadata),
})
}
}