use std::any::{Any, TypeId};
use std::fmt::Debug;
use std::mem;
use std::ptr::{self, NonNull};
use crate::DatabaseKeyIndex;
use crate::sync::atomic::{AtomicPtr, Ordering};
use crate::zalsa::MemoIngredientIndex;
use crate::zalsa::Zalsa;
/// Storage for type-erased memos: one [`MemoEntry`] slot per memo ingredient index.
///
/// The table holds only erased pointers. The concrete type of each slot is recorded separately
/// in a [`MemoTableTypes`], which has to be attached (`MemoTableTypes::attach_memos` /
/// `attach_memos_mut`) before memos can be read, written, or freed. No `Drop` impl is defined
/// for this type here; stored memos are released through `MemoTableWithTypesMut::drop` or
/// `MemoTableWithTypesMut::take_memos`.
pub struct MemoTable {
memos: Box<[MemoEntry]>,
}
impl MemoTable {
    /// Creates a table with one empty slot for every type currently registered in `types`.
    ///
    /// # Safety
    ///
    /// The returned table must only ever be paired with this same `types`, through
    /// `MemoTableTypes::attach_memos` / `attach_memos_mut`. The typed accessors index the type
    /// list with `get_unchecked`, using indices that were only bounds-checked against this
    /// table. Pairing the table with a shorter or different type list is therefore undefined
    /// behavior.
    pub unsafe fn new(types: &MemoTableTypes) -> Self {
        let slot_count = types.len();
        let memos = std::iter::repeat_with(MemoEntry::default)
            .take(slot_count)
            .collect();
        Self { memos }
    }

    /// Overwrites every slot with a fresh, empty entry.
    ///
    /// The memos the slots pointed to are NOT freed; their pointers are simply forgotten.
    /// Callers should release them first (e.g. via `MemoTableWithTypesMut::take_memos`).
    pub fn reset(&mut self) {
        self.memos.fill_with(MemoEntry::default);
    }
}
/// A memo that can be stored, type-erased, in a [`MemoTable`] slot.
///
/// The `Any` supertrait implies `'static`, which is what allows `TypeId::of::<M>()` to be used
/// for the per-slot type checks. NOTE(review): `Send + Sync` is presumably required because
/// memos are shared between threads through the atomic slots — confirm.
pub trait Memo: Any + Send + Sync {
/// NOTE(review): presumably discards outputs recorded on behalf of `executor`; the actual
/// semantics are defined by implementors outside this file — confirm.
fn remove_outputs(&self, zalsa: &Zalsa, executor: DatabaseKeyIndex);
/// Reports this memo's memory usage (only compiled with the `salsa_unstable` feature).
#[cfg(feature = "salsa_unstable")]
fn memory_usage(&self) -> crate::database::MemoInfo;
}
/// One slot of a [`MemoTable`]: an erased memo pointer, or null when the slot is empty.
#[derive(Default, Debug)]
struct MemoEntry {
/// Erased `NonNull<M>` (cast to `DummyMemo`); the real `M` is recorded in the matching
/// `MemoEntryType`. Null means "no memo". NOTE(review): the derived default is assumed to be
/// null, as it is for `std`'s `AtomicPtr` — confirm for `crate::sync::atomic`.
atomic_memo: AtomicPtr<DummyMemo>,
}
/// Runtime type information for one memo slot.
#[derive(Clone, Copy, Debug)]
pub struct MemoEntryType {
/// `TypeId` of the concrete memo type; compared against `TypeId::of::<M>()` before any typed
/// cast of the slot's pointer.
type_id: TypeId,
/// Rebuilds a `dyn Memo` fat pointer (with the correct vtable) from an erased thin pointer.
/// This lets the memo be freed or inspected without knowing its type statically.
to_dyn_fn: fn(NonNull<DummyMemo>) -> NonNull<dyn Memo>,
}
impl MemoEntryType {
/// Erases the concrete memo type so the pointer fits in a `MemoEntry`.
fn to_dummy<M: Memo>(memo: NonNull<M>) -> NonNull<DummyMemo> {
memo.cast()
}
/// Casts an erased pointer back to its concrete memo type.
///
/// # Safety
///
/// `memo` must have been produced by `to_dummy::<M>` for this same `M`. The callers in this
/// file compare the slot's `type_id` with `TypeId::of::<M>()` before calling this.
unsafe fn from_dummy<M: Memo>(memo: NonNull<DummyMemo>) -> NonNull<M> {
memo.cast()
}
/// Produces a function that restores the `dyn Memo` vtable for an erased `M` pointer.
///
/// The identity closure `|x| x` performs the unsizing coercion `NonNull<M>` ->
/// `NonNull<dyn Memo>`. The function pointer is then retyped to accept the erased pointer.
const fn to_dyn_fn<M: Memo>() -> fn(NonNull<DummyMemo>) -> NonNull<dyn Memo> {
let f: fn(NonNull<M>) -> NonNull<dyn Memo> = |x| x;
// SAFETY: `M` (an implicitly `Sized` type parameter) and `DummyMemo` are both `Sized`.
// So `NonNull<M>` and `NonNull<DummyMemo>` are thin pointers with the same ABI, which makes
// the two function-pointer types call-compatible. The resulting function must still only be
// called with pointers that were erased from an `M`.
unsafe {
mem::transmute::<
fn(NonNull<M>) -> NonNull<dyn Memo>,
fn(NonNull<DummyMemo>) -> NonNull<dyn Memo>,
>(f)
}
}
/// Captures the runtime type information (type id + vtable restorer) for memo type `M`.
#[inline]
pub fn of<M: Memo>() -> Self {
Self {
type_id: TypeId::of::<M>(),
to_dyn_fn: Self::to_dyn_fn::<M>(),
}
}
}
/// Placeholder pointee type that lets memo pointers of any type share one `AtomicPtr` type.
///
/// Values of this type are not expected to exist. Pointers are only cast to it for storage,
/// and its `Drop` impl panics with `unreachable!`.
#[derive(Debug)]
struct DummyMemo;
impl Memo for DummyMemo {
    /// No-op for the placeholder type.
    fn remove_outputs(&self, _zalsa: &Zalsa, _executor: DatabaseKeyIndex) {
        // Intentionally empty.
    }

    /// Reports an all-zero placeholder entry labelled `"dummy"`.
    #[cfg(feature = "salsa_unstable")]
    fn memory_usage(&self) -> crate::database::MemoInfo {
        let placeholder = crate::database::SlotInfo {
            debug_name: "dummy",
            size_of_metadata: 0,
            size_of_fields: 0,
            heap_size_of_fields: None,
            memos: Vec::new(),
        };
        crate::database::MemoInfo {
            debug_name: "dummy",
            output: placeholder,
        }
    }
}
/// The per-slot type list that describes every [`MemoTable`] created from it.
///
/// Entry `i` is the type of slot `i`; `MemoTable::new` sizes a table from `len()`.
#[derive(Default)]
pub struct MemoTableTypes {
types: Vec<MemoEntryType>,
}
impl MemoTableTypes {
    /// Records `memo_type` as the type of slot `memo_ingredient_index`.
    ///
    /// This uses `Vec::insert`, so its behavior depends on the index:
    /// - equal to the current length: appends;
    /// - smaller: shifts later entries right;
    /// - larger: panics.
    ///
    /// NOTE(review): callers presumably register indices in increasing order (append-only).
    /// Shifting would misalign the types of tables that already exist — confirm.
    pub(crate) fn set(
        &mut self,
        memo_ingredient_index: MemoIngredientIndex,
        memo_type: MemoEntryType,
    ) {
        let position = memo_ingredient_index.as_usize();
        self.types.insert(position, memo_type);
    }

    /// Number of registered slot types, which is also the slot count of tables created now.
    pub fn len(&self) -> usize {
        let Self { types } = self;
        types.len()
    }

    /// Pairs this type list with `memos` for shared (typed) access.
    ///
    /// # Safety
    ///
    /// `memos` must have been created by `MemoTable::new` from this same `MemoTableTypes`.
    /// The returned view indexes `self` unchecked, using indices that are only known to be
    /// valid for `memos`.
    #[inline]
    pub(crate) unsafe fn attach_memos<'a>(
        &'a self,
        memos: &'a MemoTable,
    ) -> MemoTableWithTypes<'a> {
        MemoTableWithTypes { memos, types: self }
    }

    /// Pairs this type list with `memos` for exclusive access (mutating or freeing memos).
    ///
    /// # Safety
    ///
    /// `memos` must have been created by `MemoTable::new` from this same `MemoTableTypes`.
    /// The returned view indexes `self` unchecked and frees memos using these types.
    #[inline]
    pub(crate) unsafe fn attach_memos_mut<'a>(
        &'a self,
        memos: &'a mut MemoTable,
    ) -> MemoTableWithTypesMut<'a> {
        MemoTableWithTypesMut { memos, types: self }
    }
}
/// Shared view that pairs a [`MemoTable`] with the [`MemoTableTypes`] describing its slots.
///
/// Created by the unsafe `MemoTableTypes::attach_memos`. The pairing is what makes typed
/// access to the erased slots possible.
pub struct MemoTableWithTypes<'a> {
types: &'a MemoTableTypes,
memos: &'a MemoTable,
}
impl MemoTableWithTypes<'_> {
    /// Stores `memo` in slot `memo_ingredient_index` and returns the memo it displaced.
    ///
    /// Returns `None` without storing anything if the table has no such slot. Also returns
    /// `None` if the slot was previously empty. The displaced pointer is no longer referenced
    /// by the table.
    ///
    /// # Panics
    ///
    /// Panics (via `type_assert_failed`) if the slot is registered with a type other than `M`.
    pub(crate) fn insert<M: Memo>(
        self,
        memo_ingredient_index: MemoIngredientIndex,
        memo: NonNull<M>,
    ) -> Option<NonNull<M>> {
        let slot = memo_ingredient_index.as_usize();
        let entry = self.memos.memos.get(slot)?;

        // SAFETY: `slot` is in bounds for the memo table. A table is only attached to the
        // type list it was created from, and that list has at least as many entries.
        let registered = unsafe { self.types.types.get_unchecked(slot) };
        if TypeId::of::<M>() != registered.type_id {
            type_assert_failed(memo_ingredient_index);
        }

        let erased = MemoEntryType::to_dummy(memo).as_ptr();
        // AcqRel: publish the new memo to readers and observe the memo being displaced.
        let displaced = NonNull::new(entry.atomic_memo.swap(erased, Ordering::AcqRel))?;

        // SAFETY: the slot's registered type is `M` (checked above), so its previous occupant
        // was stored as an erased `NonNull<M>`.
        Some(unsafe { MemoEntryType::from_dummy::<M>(displaced) })
    }

    /// Returns the memo in slot `memo_ingredient_index`.
    ///
    /// Returns `None` if the table has no such slot or the slot is empty.
    ///
    /// # Panics
    ///
    /// Panics (via `type_assert_failed`) if the slot is registered with a type other than `M`.
    #[inline]
    pub(crate) fn get<M: Memo>(
        self,
        memo_ingredient_index: MemoIngredientIndex,
    ) -> Option<NonNull<M>> {
        let slot = memo_ingredient_index.as_usize();
        let entry = self.memos.memos.get(slot)?;

        // SAFETY: `slot` is in bounds for the memo table, and therefore also for the type list
        // it was created from.
        let registered = unsafe { self.types.types.get_unchecked(slot) };
        if TypeId::of::<M>() != registered.type_id {
            type_assert_failed(memo_ingredient_index);
        }

        // Acquire pairs with the AcqRel swap in `insert`.
        let current = NonNull::new(entry.atomic_memo.load(Ordering::Acquire))?;

        // SAFETY: the slot's registered type is `M` (checked above).
        Some(unsafe { MemoEntryType::from_dummy::<M>(current) })
    }

    /// Collects memory-usage information for every occupied slot, in slot order.
    ///
    /// Occupied slots without a registered type are skipped.
    #[cfg(feature = "salsa_unstable")]
    pub(crate) fn memory_usage(&self) -> Vec<crate::database::MemoInfo> {
        self.memos
            .memos
            .iter()
            .enumerate()
            .filter_map(|(index, entry)| {
                let erased = NonNull::new(entry.atomic_memo.load(Ordering::Acquire))?;
                let type_ = self.types.types.get(index)?;
                // SAFETY: `type_` describes this slot (the table was created from this type
                // list), so the restored fat pointer has the right vtable. NOTE(review):
                // assumes the memo is not freed while it is being inspected.
                let dyn_memo: &dyn Memo = unsafe { (type_.to_dyn_fn)(erased).as_ref() };
                Some(dyn_memo.memory_usage())
            })
            .collect()
    }
}
/// Exclusive view that pairs a [`MemoTable`] with the [`MemoTableTypes`] describing its slots.
///
/// Created by the unsafe `MemoTableTypes::attach_memos_mut`. It allows in-place mutation of
/// memos and freeing them with their correct types.
pub(crate) struct MemoTableWithTypesMut<'a> {
types: &'a MemoTableTypes,
memos: &'a mut MemoTable,
}
impl MemoTableWithTypesMut<'_> {
    /// Runs `f` on the memo in slot `memo_ingredient_index`, if there is one.
    ///
    /// Does nothing when the table has no such slot or the slot is empty.
    ///
    /// # Panics
    ///
    /// Panics (via `type_assert_failed`) if the slot is registered with a type other than `M`.
    pub(crate) fn map_memo<M: Memo>(
        self,
        memo_ingredient_index: MemoIngredientIndex,
        f: impl FnOnce(&mut M),
    ) {
        let slot = memo_ingredient_index.as_usize();
        let Some(entry) = self.memos.memos.get_mut(slot) else {
            return;
        };

        // SAFETY: `slot` is in bounds for the memo table, and therefore also for the type list
        // it was created from.
        let registered = unsafe { self.types.types.get_unchecked(slot) };
        if TypeId::of::<M>() != registered.type_id {
            type_assert_failed(memo_ingredient_index);
        }

        // The table is borrowed exclusively, so the slot is read without an atomic operation.
        if let Some(current) = NonNull::new(*entry.atomic_memo.get_mut()) {
            // SAFETY: the slot's registered type is `M` (checked above).
            f(unsafe { MemoEntryType::from_dummy::<M>(current).as_mut() });
        }
    }

    /// Frees every memo stored in the table and leaves all slots empty.
    ///
    /// This is an inherent method rather than a `Drop` impl, because freeing a memo needs its
    /// `MemoEntryType`, which the table alone does not have. Empty slots are skipped, so
    /// calling this more than once is harmless.
    ///
    /// # Safety
    ///
    /// No references or pointers into the stored memos may still be in use, since they are
    /// freed here.
    #[inline]
    pub unsafe fn drop(&mut self) {
        self.memos
            .memos
            .iter_mut()
            .zip(&self.types.types)
            .for_each(|(entry, type_)| {
                // SAFETY: the table was created from this type list, so `type_` describes
                // `entry`. The returned box is dropped immediately, freeing the memo.
                let _ = unsafe { entry.take(type_) };
            });
    }

    /// Empties every occupied slot and passes each memo, with its slot index, to `f`.
    ///
    /// Slots are visited in index order. Each slot is emptied right before `f` is called for
    /// it; empty slots are skipped.
    ///
    /// # Safety
    ///
    /// No references or pointers into the stored memos may still be in use, since ownership
    /// of each memo moves into `f`.
    pub(crate) unsafe fn take_memos(
        &mut self,
        mut f: impl FnMut(MemoIngredientIndex, Box<dyn Memo>),
    ) {
        let slots = self.memos.memos.iter_mut().zip(self.types.types.iter());
        for (index, (entry, type_)) in slots.enumerate() {
            // SAFETY: the table was created from this type list, so `type_` describes `entry`.
            if let Some(memo) = unsafe { entry.take(type_) } {
                f(MemoIngredientIndex::from_usize(index), memo);
            }
        }
    }
}
/// Panics because a slot was accessed as a memo type other than the one registered for it.
///
/// Kept out of line and marked cold so the type checks on the hot paths stay cheap.
#[cold]
#[inline(never)]
fn type_assert_failed(index: MemoIngredientIndex) -> ! {
    panic!("inconsistent type-id for `{index:?}`")
}
impl MemoEntry {
    /// Empties this slot and returns its memo as an owned `Box<dyn Memo>`.
    ///
    /// Returns `None` if the slot was already empty.
    ///
    /// # Safety
    ///
    /// `type_` must describe the memo type actually stored in this slot. The stored pointer
    /// must also be one `Box::from_raw` may reclaim. NOTE(review): presumably inserters
    /// obtain it from `Box::into_raw` — confirm.
    #[inline]
    unsafe fn take(&mut self, type_: &MemoEntryType) -> Option<Box<dyn Memo>> {
        // `&mut self` gives exclusive access, so the slot is emptied with a plain write.
        let raw = mem::replace(self.atomic_memo.get_mut(), ptr::null_mut());
        NonNull::new(raw).map(|erased| {
            // Rebuild the fat pointer (restoring the vtable) before reclaiming the allocation.
            let restored = (type_.to_dyn_fn)(erased);
            // SAFETY: guaranteed by the caller (see `# Safety`).
            unsafe { Box::from_raw(restored.as_ptr()) }
        })
    }
}
/// Dropping a `DummyMemo` is a logic error: values of this type should never exist.
///
/// Pointers are only cast to `DummyMemo` for storage. When freed, they are first restored
/// to their real type through `MemoEntryType`'s `to_dyn_fn`, never dropped as a `DummyMemo`.
impl Drop for DummyMemo {
fn drop(&mut self) {
unreachable!("should never get here")
}
}
impl std::fmt::Debug for MemoTable {
    /// Prints `MemoTable { .. }`, deliberately omitting the erased slot contents.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("MemoTable");
        out.finish_non_exhaustive()
    }
}