use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, OnceLock, Weak};
use papaya::HashMap as PapayaMap;
use parking_lot::RwLock;
use svod_device::Buffer;
use svod_ir::{Op, UOp, UOpKey};
/// Process-wide source of tensor registry ids. Starts at 0 and is only ever
/// advanced by [`next_tensor_id`].
static TENSOR_ID_COUNTER: AtomicU64 = AtomicU64::new(0);

/// Hands out the next tensor id.
///
/// Returns the counter's value from before the increment, so the first call
/// yields 0, the next 1, and so on. The increment is a single atomic
/// `fetch_add` with `Relaxed` ordering.
fn next_tensor_id() -> u64 {
    const STEP: u64 = 1;
    TENSOR_ID_COUNTER.fetch_add(STEP, Ordering::Relaxed)
}
/// One registered tensor: a registry id, its current UOp, and an optional
/// buffer that can be attached at most once.
pub struct TensorEntry {
/// Registry id, taken from the global counter when the entry is created.
pub id: u64,
/// The tensor's current UOp. Kept behind a lock because
/// `apply_map_to_tensors_inner` swaps it in place after a substitution.
pub uop: RwLock<Arc<UOp>>,
/// Buffer attached to this tensor, if any. Write-once: see `set_buffer`.
buffer: OnceLock<Arc<Buffer>>,
}
impl TensorEntry {
    /// Returns the buffer attached to this tensor, or `None` if none has been
    /// set yet.
    pub fn buffer(&self) -> Option<&Arc<Buffer>> {
        self.buffer.get()
    }

    /// Attaches `buffer` to this tensor if it does not have one yet.
    ///
    /// Returns `true` when the buffer was stored, `false` when a buffer was
    /// already present (the existing one is kept and `buffer` is dropped).
    pub fn set_buffer(&self, buffer: Arc<Buffer>) -> bool {
        let outcome = self.buffer.set(buffer);
        outcome.is_ok()
    }
}
/// Global tensor registry: tensor id -> weak handle to its entry. Holding only
/// `Weak` references means the registry itself never keeps a tensor alive.
/// Lazily created on first access.
static TENSORS: OnceLock<PapayaMap<u64, Weak<TensorEntry>>> = OnceLock::new();
/// Global buffer registry keyed by UOp id. Holds strong `Arc<Buffer>` handles,
/// so a buffer stays alive until its slot is removed. Lazily created on first
/// access.
static BUFFERS: OnceLock<PapayaMap<u64, Arc<Buffer>>> = OnceLock::new();
/// Returns the global tensor registry, creating the empty map on first use.
fn tensors() -> &'static PapayaMap<u64, Weak<TensorEntry>> {
    let registry = TENSORS.get_or_init(PapayaMap::new);
    registry
}
/// Returns the global buffer registry, creating the empty map on first use.
fn buffers() -> &'static PapayaMap<u64, Arc<Buffer>> {
    let registry = BUFFERS.get_or_init(PapayaMap::new);
    registry
}
/// Creates a new [`TensorEntry`] for `uop` (with no buffer attached) and
/// records it in the global tensor registry.
///
/// The registry stores only a weak handle; the returned `Arc` is the owning
/// reference. Once every strong reference is dropped the entry dies and its
/// registry slot becomes stale until `gc_dead_refs` removes it.
pub fn register_tensor(uop: Arc<UOp>) -> Arc<TensorEntry> {
    let entry = Arc::new(TensorEntry {
        id: next_tensor_id(),
        uop: RwLock::new(uop),
        buffer: OnceLock::new(),
    });

    let registry = tensors();
    let pin = registry.guard();
    registry.insert(entry.id, Arc::downgrade(&entry), &pin);

    entry
}
/// Creates a new [`TensorEntry`] for `uop` that already owns `buffer`, and
/// records both the tensor and the buffer in their global registries.
///
/// * `uop` - the tensor's current UOp.
/// * `buffer` - buffer attached to the entry and stored in the buffer registry.
/// * `buffer_uop_id` - key under which `buffer` is stored in the buffer registry.
///
/// The tensor registry keeps only a weak handle to the entry; the buffer
/// registry keeps a strong handle to the buffer.
pub fn register_tensor_with_buffer(uop: Arc<UOp>, buffer: Arc<Buffer>, buffer_uop_id: u64) -> Arc<TensorEntry> {
    let entry = Arc::new(TensorEntry {
        id: next_tensor_id(),
        uop: RwLock::new(uop),
        buffer: OnceLock::from(Arc::clone(&buffer)),
    });

    let tensor_registry = tensors();
    let tensor_pin = tensor_registry.guard();
    tensor_registry.insert(entry.id, Arc::downgrade(&entry), &tensor_pin);

    let buffer_registry = buffers();
    let buffer_pin = buffer_registry.guard();
    buffer_registry.insert(buffer_uop_id, buffer, &buffer_pin);

    entry
}
/// Looks up the buffer registered under `uop_id` and returns a clone of the
/// `Buffer` value itself (not of the `Arc` handle).
///
/// Returns `None` when nothing is registered for `uop_id`. Use
/// `get_buffer_arc` to get the shared handle instead.
pub fn get_buffer(uop_id: u64) -> Option<Buffer> {
    let registry = buffers();
    let pin = registry.guard();
    let shared = registry.get(&uop_id, &pin)?;
    Some((**shared).clone())
}
/// Looks up the buffer registered under `uop_id` and returns another shared
/// handle to it, or `None` when nothing is registered for that id.
pub fn get_buffer_arc(uop_id: u64) -> Option<Arc<Buffer>> {
    let registry = buffers();
    let pin = registry.guard();
    registry.get(&uop_id, &pin).map(Arc::clone)
}
pub fn remove_buffer(uop_id: u64) {
let buf_guard = buffers().guard();
buffers().remove(&uop_id, &buf_guard);
}
/// Returns the number of entries currently in the global buffer registry.
pub fn buffer_count() -> usize {
    let registry = buffers();
    registry.len()
}
/// Stores `buffer` in the buffer registry under `uop_id` and, if tensor
/// `tensor_id` is still alive, tries to attach the same buffer to it.
///
/// The attach result is deliberately not inspected: a tensor that already has
/// a buffer keeps its existing one, and a tensor that is gone (or was never
/// registered) is skipped. The registry insert happens in every case.
pub fn register_buffer(uop_id: u64, tensor_id: u64, buffer: Arc<Buffer>) {
    let registry = buffers();
    let pin = registry.guard();
    registry.insert(uop_id, Arc::clone(&buffer), &pin);

    if let Some(entry) = get_tensor(tensor_id) {
        entry.set_buffer(buffer);
    }
}
/// Stores `buffer` in the buffer registry under `uop_id` without touching any
/// tensor entry.
pub fn register_buffer_by_uop_id(uop_id: u64, buffer: Arc<Buffer>) {
    let registry = buffers();
    let pin = registry.guard();
    registry.insert(uop_id, buffer, &pin);
}
pub fn get_tensor(id: u64) -> Option<Arc<TensorEntry>> {
let guard = tensors().guard();
tensors().get(&id, &guard)?.upgrade()
}
/// Sweeps stale slots out of both global registries.
///
/// 1. Tensor registry: removes every id whose weak handle no longer has a
///    live strong reference.
/// 2. Buffer registry: removes every buffer whose key is not in the set
///    returned by `svod_ir::uop::live_uop_ids()`.
///
/// In both passes the keys to drop are collected first and removed
/// afterwards, so the map is not mutated while it is being iterated.
pub fn gc_dead_refs() {
    let tensor_map = tensors();
    let tensor_guard = tensor_map.guard();
    // `strong_count` only reads the reference count. `upgrade()` would mint a
    // temporary strong `Arc` per entry, and dropping that temporary could make
    // this thread run the entry's destructor if it raced with the last owner.
    let dead_ids: Vec<u64> = tensor_map
        .iter(&tensor_guard)
        .filter(|(_, weak)| weak.strong_count() == 0)
        .map(|(id, _)| *id)
        .collect();
    for id in dead_ids {
        tensor_map.remove(&id, &tensor_guard);
    }

    // NOTE(review): the live-id snapshot is taken before the buffer scan;
    // whether a buffer registered concurrently for a brand-new UOp can be
    // evicted here depends on how callers synchronise GC — confirm.
    let live_uop_ids = svod_ir::uop::live_uop_ids();
    let buf_map = buffers();
    let buf_guard = buf_map.guard();
    let stale_bufs: Vec<u64> = buf_map
        .iter(&buf_guard)
        .filter(|(uop_id, _)| !live_uop_ids.contains(uop_id))
        .map(|(id, _)| *id)
        .collect();
    for uop_id in stale_bufs {
        buf_map.remove(&uop_id, &buf_guard);
    }
}
/// Legacy entry point kept for existing callers; simply forwards to
/// [`gc_dead_refs`].
#[deprecated(
    note = "Tensor registry now uses weak refs - cleanup is automatic. Use gc_dead_refs() to clean registry."
)]
pub fn gc_unused_tensors() {
    gc_dead_refs()
}
/// Rewrites the UOp of every live registered tensor whose graph contains a
/// key of `becomes_map`, using `UOp::substitute`.
///
/// Does nothing when `becomes_map` is empty.
// NOTE(review): the lint is presumably silenced because `UOpKey` wraps an
// `Arc<UOp>` that clippy treats as interior-mutable — confirm.
#[allow(clippy::mutable_key_type)]
pub fn apply_map_to_tensors(becomes_map: &HashMap<UOpKey, Arc<UOp>>) {
    let walk = false;
    apply_map_to_tensors_inner(becomes_map, walk);
}
/// Same as `apply_map_to_tensors`, but performs the rewrite with
/// `UOp::substitute_walk` instead of `UOp::substitute`.
///
/// Does nothing when `becomes_map` is empty.
// NOTE(review): the lint is presumably silenced because `UOpKey` wraps an
// `Arc<UOp>` that clippy treats as interior-mutable — confirm.
#[allow(clippy::mutable_key_type)]
pub fn apply_map_to_tensors_walk(becomes_map: &HashMap<UOpKey, Arc<UOp>>) {
    let walk = true;
    apply_map_to_tensors_inner(becomes_map, walk);
}
/// Shared implementation behind `apply_map_to_tensors` and
/// `apply_map_to_tensors_walk`.
///
/// Steps:
/// 1. Return early if `becomes_map` is empty.
/// 2. Collect every live tensor whose current UOp, or any node in its
///    toposort, is a key of `becomes_map`.
/// 3. Bundle the affected UOps as the sources of one sink and run a single
///    substitution over it (`substitute_walk` when `walk` is true, otherwise
///    `substitute`).
/// 4. If the result is still a sink, store each source that changed back into
///    its tensor entry. Sources that are pointer-identical to the original are
///    left untouched, so no write lock is taken for them.
// NOTE(review): the lint is presumably silenced because `UOpKey` wraps an
// `Arc<UOp>` that clippy treats as interior-mutable — confirm.
#[allow(clippy::mutable_key_type)]
fn apply_map_to_tensors_inner(becomes_map: &HashMap<UOpKey, Arc<UOp>>, walk: bool) {
    if becomes_map.is_empty() {
        return;
    }

    let registry = tensors();
    let pin = registry.guard();

    // Pass 1: find the live tensors whose graph mentions any key of the map.
    let mut affected: Vec<Arc<TensorEntry>> = Vec::new();
    for (_, weak) in registry.iter(&pin) {
        let Some(entry) = weak.upgrade() else {
            continue;
        };
        let touched = {
            // The read lock is held only for the duration of this check.
            let root = entry.uop.read();
            let root_hit = becomes_map.contains_key(&UOpKey(root.clone()));
            root_hit || root.toposort().iter().any(|node| becomes_map.contains_key(&UOpKey(node.clone())))
        };
        if touched {
            affected.push(entry);
        }
    }
    if affected.is_empty() {
        return;
    }

    // Pass 2: substitute once over a sink that bundles all affected graphs.
    let originals: Vec<Arc<UOp>> = affected.iter().map(|entry| entry.uop.read().clone()).collect();
    let combined = UOp::sink(originals.clone());
    let rewritten = if walk {
        combined.substitute_walk(becomes_map)
    } else {
        combined.substitute(becomes_map)
    };

    // Pass 3: write back only the sources that actually changed.
    if let Op::Sink { sources: replaced, .. } = rewritten.op() {
        for ((entry, before), after) in affected.iter().zip(originals.iter()).zip(replaced.iter()) {
            if Arc::ptr_eq(before, after) {
                continue;
            }
            *entry.uop.write() = after.clone();
        }
    }
}
// Unit tests are kept in a separate file and compiled only for test builds.
#[cfg(test)]
#[path = "test/unit/tensor_registry.rs"]
mod tests;