use std::any::{Any, TypeId};
use std::fmt::Debug;
use std::mem;
use std::ptr::{self, NonNull};
use crate::DatabaseKeyIndex;
use crate::sync::atomic::{AtomicPtr, Ordering};
use crate::zalsa::MemoIngredientIndex;
use crate::zalsa::Zalsa;
/// Storage for the memos of every memo ingredient, one slot per `MemoIngredientIndex`.
///
/// The table does not know the concrete memo types; it must be paired with the
/// [`MemoTableTypes`] it was created from (via `MemoTableTypes::attach_memos` /
/// `attach_memos_mut`) before memos can be read or written.
pub struct MemoTable {
// Slot array, allocated lazily on the first insert.
memos: LazyMemoEntries,
}
// Compile-time check that `MemoTable` stays two words wide (slot pointer + slot count).
// Skipped under the `shuttle` feature, where `crate::sync::atomic` presumably swaps in atomic
// types with a different layout.
#[cfg(not(feature = "shuttle"))]
const _: [(); mem::size_of::<MemoTable>()] = [(); 2 * mem::size_of::<usize>()];
impl MemoTable {
    /// Creates an empty table with one (not yet allocated) slot per type in `types`.
    ///
    /// # Safety
    ///
    /// The table must only ever be attached to the same `types` it was created from.
    /// Accessors index the attached type list without bounds checks for every index that is
    /// in bounds for this table.
    pub unsafe fn new(types: &MemoTableTypes) -> Self {
        let slot_count = types.len();
        let memos = LazyMemoEntries::new(slot_count);
        Self { memos }
    }

    /// Frees the slot array, leaving the table empty. It is re-allocated on the next insert.
    ///
    /// Only the slots are freed. Memos still referenced by them are *not* dropped, so they
    /// should be taken out first (e.g. with `MemoTableWithTypesMut::take_memos`) or they leak.
    pub fn reset(&mut self) {
        LazyMemoEntries::clear(&mut self.memos);
    }
}
/// A memoized value that can be stored, type-erased, in a [`MemoTable`].
///
/// Implementors must be `'static` (via `Any`) and `Send + Sync`.
pub trait Memo: Any + Send + Sync {
/// Invoked with the database and a `DatabaseKeyIndex` named `executor`.
///
/// NOTE(review): from the name, presumably discards outputs recorded on this memo for the
/// query `executor`; confirm against implementors. `DummyMemo`'s impl is a no-op.
fn remove_outputs(&self, zalsa: &Zalsa, executor: DatabaseKeyIndex);
/// Reports memory-usage information for this memo (only with `salsa_unstable`).
#[cfg(feature = "salsa_unstable")]
fn memory_usage(&self) -> crate::database::MemoInfo;
}
/// A single memo slot: a type-erased pointer to the stored memo, null when empty.
///
/// The real pointee type is recovered through the `MemoEntryType` registered for the slot's
/// index. `MemoEntry::take` rebuilds a `Box` from the pointer, so stored pointers are
/// presumably obtained from `Box::into_raw` by callers — verify at call sites.
#[derive(Default, Debug)]
struct MemoEntry {
// Null (the `Default`) when empty; otherwise points to a memo of the slot's registered type.
atomic_memo: AtomicPtr<DummyMemo>,
}
/// A fixed-length array of [`MemoEntry`] slots that is only allocated on first write.
///
/// `ptr` is null until `initialize` runs. Afterwards it points to a leaked
/// `Box<[MemoEntry]>` of exactly `len` entries, which `clear` (also run on drop) frees.
struct LazyMemoEntries {
// Null until allocated; published with `Release` and read with `Acquire`.
ptr: AtomicPtr<MemoEntry>,
// Number of slots; never changed after construction.
len: usize,
}
impl LazyMemoEntries {
/// Creates an unallocated array that will hold `len` slots once initialized.
fn new(len: usize) -> Self {
Self {
ptr: AtomicPtr::new(ptr::null_mut()),
len,
}
}
/// Returns slot `index`, or `None` if the slots are unallocated or `index >= len`.
/// Never allocates.
#[inline]
fn get(&self, index: usize) -> Option<&MemoEntry> {
self.as_slice()?.get(index)
}
/// Returns slot `index`, allocating the slot array first if needed.
///
/// Returns `None` only when `index >= len`. The bounds check runs before allocation, so an
/// out-of-range index never allocates.
#[inline]
fn get_or_init(&self, index: usize) -> Option<&MemoEntry> {
if index >= self.len {
return None;
}
let memos = self.as_slice().unwrap_or_else(|| self.initialize());
Some(&memos[index])
}
/// Mutable counterpart of `get`; never allocates.
#[inline]
fn get_mut(&mut self, index: usize) -> Option<&mut MemoEntry> {
self.as_mut_slice()?.get_mut(index)
}
/// Iterates over all slots; yields nothing while unallocated.
fn iter(&self) -> std::slice::Iter<'_, MemoEntry> {
self.as_slice().unwrap_or_default().iter()
}
/// Mutably iterates over all slots; yields nothing while unallocated.
fn iter_mut(&mut self) -> std::slice::IterMut<'_, MemoEntry> {
self.as_mut_slice().unwrap_or_default().iter_mut()
}
/// Views the slot array, or `None` if it has not been allocated.
#[inline]
fn as_slice(&self) -> Option<&[MemoEntry]> {
// `Acquire` pairs with the `Release` publish in `initialize`, so the freshly written
// entries are visible to this thread.
let ptr = NonNull::new(self.ptr.load(Ordering::Acquire))?;
// SAFETY: a non-null `ptr` was produced by `initialize` from a `Box<[MemoEntry]>` of
// exactly `self.len` entries. It is only freed by `clear`, which needs `&mut self`, so it
// outlives this shared borrow.
Some(unsafe { std::slice::from_raw_parts(ptr.as_ptr(), self.len) })
}
/// Mutably views the slot array, or `None` if it has not been allocated.
#[inline]
fn as_mut_slice(&mut self) -> Option<&mut [MemoEntry]> {
// `&mut self` means no concurrent access, so a plain (non-atomic) read is enough.
let ptr = NonNull::new(*self.ptr.get_mut())?;
// SAFETY: same allocation as in `as_slice`; `&mut self` guarantees exclusivity.
Some(unsafe { std::slice::from_raw_parts_mut(ptr.as_ptr(), self.len) })
}
/// Allocates and publishes the slot array, returning whichever array won the race.
///
/// Several threads may call this concurrently. Each allocates its own array, but only the
/// first successful compare-exchange publishes. The losers free theirs and use the winner's.
#[cold]
fn initialize(&self) -> &[MemoEntry] {
let new_memos: Box<[MemoEntry]> = (0..self.len).map(|_| MemoEntry::default()).collect();
let new_memos = Box::into_raw(new_memos);
let new_memos_ptr = new_memos.cast::<MemoEntry>();
// Success: `Release` publishes the initialized entries.
// Failure: `Acquire` synchronizes with the winner's `Release`.
let ptr = match self.ptr.compare_exchange(
ptr::null_mut(),
new_memos_ptr,
Ordering::Release,
Ordering::Acquire,
) {
Ok(_) => new_memos_ptr,
Err(ptr) => {
// SAFETY: our allocation was never published, so nothing else references it.
unsafe { drop(Box::from_raw(new_memos)) };
ptr
}
};
// SAFETY: `ptr` is either our array or the winner's; both hold `self.len` entries and stay
// alive until `clear`, which needs `&mut self`.
unsafe { std::slice::from_raw_parts(ptr, self.len) }
}
/// Frees the slot array (if allocated) and resets to the unallocated state.
///
/// The memos the slots point to are NOT dropped: `MemoEntry` only holds a raw pointer.
fn clear(&mut self) {
let ptr = mem::replace(self.ptr.get_mut(), ptr::null_mut());
if ptr.is_null() {
return;
}
// SAFETY: `ptr` came from `Box::into_raw` on a `[MemoEntry]` of length `self.len`.
// Null was swapped in above, so it cannot be freed twice.
unsafe { drop(Box::from_raw(ptr::slice_from_raw_parts_mut(ptr, self.len))) };
}
}
impl Drop for LazyMemoEntries {
fn drop(&mut self) {
self.clear();
}
}
/// Runtime description of the concrete memo type stored in one memo slot.
///
/// Pairs the `TypeId` used to check typed accesses with a function that turns an erased
/// `NonNull<DummyMemo>` back into a `NonNull<dyn Memo>` carrying the right vtable.
#[derive(Clone, Copy, Debug)]
pub struct MemoEntryType {
// `TypeId` of the concrete memo type; compared against the caller's `M` on typed accesses.
type_id: TypeId,
// Unsizes an erased pointer to `dyn Memo`. Only valid for pointers that really point to the
// concrete type (see `MemoEntryType::to_dyn_fn`).
to_dyn_fn: fn(NonNull<DummyMemo>) -> NonNull<dyn Memo>,
}
impl MemoEntryType {
/// Erases the concrete type of a memo pointer so it can be stored in a `MemoEntry`.
fn to_dummy<M: Memo>(memo: NonNull<M>) -> NonNull<DummyMemo> {
memo.cast()
}
/// Restores the concrete type of an erased memo pointer.
///
/// # Safety
///
/// `memo` must actually point to an `M`, i.e. it must have been erased with
/// `to_dummy::<M>` (callers verify this by comparing against the slot's `type_id`).
unsafe fn from_dummy<M: Memo>(memo: NonNull<DummyMemo>) -> NonNull<M> {
memo.cast()
}
/// Builds the `NonNull<M> -> NonNull<dyn Memo>` unsizing coercion for `M` and retypes its
/// parameter as `NonNull<DummyMemo>`.
///
/// `M` is `Sized`, so `NonNull<M>` and `NonNull<DummyMemo>` are both thin pointers and the
/// two fn-pointer types share a calling convention. The returned function is only meaningful
/// for pointers that really point to an `M`.
const fn to_dyn_fn<M: Memo>() -> fn(NonNull<DummyMemo>) -> NonNull<dyn Memo> {
let f: fn(NonNull<M>) -> NonNull<dyn Memo> = |x| x;
// SAFETY: see above — only the pointee type of a thin-pointer parameter changes.
unsafe {
mem::transmute::<
fn(NonNull<M>) -> NonNull<dyn Memo>,
fn(NonNull<DummyMemo>) -> NonNull<dyn Memo>,
>(f)
}
}
/// Describes memo type `M`.
#[inline]
pub fn of<M: Memo>() -> Self {
Self {
type_id: TypeId::of::<M>(),
to_dyn_fn: Self::to_dyn_fn::<M>(),
}
}
}
/// Placeholder pointee type that lets memo pointers of any concrete type share the same
/// `AtomicPtr<DummyMemo>` slots. Values are not expected to exist; its `Drop` impl is
/// `unreachable!`.
#[derive(Debug)]
struct DummyMemo;
impl Memo for DummyMemo {
    /// Does nothing.
    fn remove_outputs(&self, _zalsa: &Zalsa, _executor: DatabaseKeyIndex) {}

    /// Reports an empty placeholder entry named `"dummy"` with zero sizes and no nested memos.
    #[cfg(feature = "salsa_unstable")]
    fn memory_usage(&self) -> crate::database::MemoInfo {
        let placeholder_slot = crate::database::SlotInfo {
            debug_name: "dummy",
            size_of_metadata: 0,
            size_of_fields: 0,
            heap_size_of_fields: None,
            memos: Vec::new(),
        };
        crate::database::MemoInfo {
            debug_name: "dummy",
            output: placeholder_slot,
        }
    }
}
/// The registered memo type for each memo ingredient index, in index order.
///
/// A [`MemoTable`] is sized from this list when created (`MemoTable::new`). Afterwards it must
/// only be interpreted through the same list, because accessors index it without bounds checks.
#[derive(Default)]
pub struct MemoTableTypes {
// Indexed by `MemoIngredientIndex::as_usize()`.
types: Vec<MemoEntryType>,
}
impl MemoTableTypes {
    /// Registers `memo_type` as the type of the memo stored at `memo_ingredient_index`.
    ///
    /// The list is append-only. Existing `MemoTable`s index it without bounds checks and drop
    /// their memos through the `MemoEntryType` at the same position. Moving an already
    /// registered type would therefore make memos be dropped through the wrong vtable.
    ///
    /// # Panics
    ///
    /// Panics unless `memo_ingredient_index` is exactly the next free index (`self.len()`).
    pub(crate) fn set(
        &mut self,
        memo_ingredient_index: MemoIngredientIndex,
        memo_type: MemoEntryType,
    ) {
        let index = memo_ingredient_index.as_usize();
        // `Vec::insert` with `index < len` would silently shift every later type by one slot.
        // The only sound registration is an append, so enforce that explicitly.
        assert_eq!(
            index,
            self.types.len(),
            "memo types must be registered in order of their memo ingredient index"
        );
        self.types.push(memo_type);
    }

    /// Number of registered memo types (and thus slots in tables created from this list).
    pub fn len(&self) -> usize {
        self.types.len()
    }

    /// Pairs `memos` with this type list for shared, type-checked access.
    ///
    /// # Safety
    ///
    /// `memos` must have been created by `MemoTable::new` from this same `MemoTableTypes`.
    /// The returned view indexes `self.types` without bounds checks for every index in bounds
    /// for `memos`.
    #[inline]
    pub(crate) unsafe fn attach_memos<'a>(
        &'a self,
        memos: &'a MemoTable,
    ) -> MemoTableWithTypes<'a> {
        MemoTableWithTypes { types: self, memos }
    }

    /// Pairs `memos` with this type list for exclusive, type-checked access.
    ///
    /// # Safety
    ///
    /// Same contract as [`Self::attach_memos`].
    #[inline]
    pub(crate) unsafe fn attach_memos_mut<'a>(
        &'a self,
        memos: &'a mut MemoTable,
    ) -> MemoTableWithTypesMut<'a> {
        MemoTableWithTypesMut { types: self, memos }
    }
}
/// A shared view pairing a [`MemoTable`] with the [`MemoTableTypes`] describing its slots.
///
/// Created by the unsafe `MemoTableTypes::attach_memos`. Typed accesses through it are checked
/// against the registered `TypeId`.
pub struct MemoTableWithTypes<'a> {
types: &'a MemoTableTypes,
memos: &'a MemoTable,
}
impl MemoTableWithTypes<'_> {
    /// Panics unless the type registered at `memo_ingredient_index` is `M`.
    ///
    /// # Safety
    ///
    /// `memo_ingredient_index` must be in bounds for `self.memos`. By the `attach_memos`
    /// contract, that makes it in bounds for `self.types` as well.
    #[inline]
    unsafe fn assert_memo_type<M: Memo>(&self, memo_ingredient_index: MemoIngredientIndex) {
        let slot = memo_ingredient_index.as_usize();
        // SAFETY: the caller guarantees `slot` is in bounds for the memo table, whose length
        // never exceeds that of the attached (append-only) type list.
        let registered = unsafe { self.types.types.get_unchecked(slot) };
        if registered.type_id != TypeId::of::<M>() {
            type_assert_failed(memo_ingredient_index);
        }
    }

    /// Stores `memo` in the slot for `memo_ingredient_index` and returns the memo it replaced.
    ///
    /// Allocates the table's slots on first use. Returns `None` when the slot was empty. Also
    /// returns `None` when the index is out of range for this table; `memo` is not stored then.
    ///
    /// # Panics
    ///
    /// Panics if `M` is not the type registered for `memo_ingredient_index`.
    pub(crate) fn insert<M: Memo>(
        self,
        memo_ingredient_index: MemoIngredientIndex,
        memo: NonNull<M>,
    ) -> Option<NonNull<M>> {
        let entry = self
            .memos
            .memos
            .get_or_init(memo_ingredient_index.as_usize())?;
        // SAFETY: `get_or_init` only returns a slot for in-bounds indices.
        unsafe { self.assert_memo_type::<M>(memo_ingredient_index) };
        let erased = MemoEntryType::to_dummy(memo).as_ptr();
        let previous = entry.atomic_memo.swap(erased, Ordering::AcqRel);
        // SAFETY: the slot only ever holds memos of its registered type, just checked to be `M`.
        NonNull::new(previous).map(|previous| unsafe { MemoEntryType::from_dummy(previous) })
    }

    /// Returns the memo stored for `memo_ingredient_index`, if any.
    ///
    /// Never allocates. Returns `None` when the slots were never allocated, when the index is
    /// out of range, or when the slot is empty.
    ///
    /// # Panics
    ///
    /// Panics if the slot exists and `M` is not the type registered for it.
    #[inline]
    pub(crate) fn get<M: Memo>(
        self,
        memo_ingredient_index: MemoIngredientIndex,
    ) -> Option<NonNull<M>> {
        let entry = self.memos.memos.get(memo_ingredient_index.as_usize())?;
        // SAFETY: `get` only returns a slot for in-bounds indices.
        unsafe { self.assert_memo_type::<M>(memo_ingredient_index) };
        let raw = entry.atomic_memo.load(Ordering::Acquire);
        // SAFETY: the slot only ever holds memos of its registered type, just checked to be `M`.
        NonNull::new(raw).map(|memo| unsafe { MemoEntryType::from_dummy(memo) })
    }

    /// Collects memory-usage information for every non-empty slot that has a registered type.
    #[cfg(feature = "salsa_unstable")]
    pub(crate) fn memory_usage(&self) -> Vec<crate::database::MemoInfo> {
        self.memos
            .memos
            .iter()
            .enumerate()
            .filter_map(|(index, entry)| {
                let memo = NonNull::new(entry.atomic_memo.load(Ordering::Acquire))?;
                let type_ = self.types.types.get(index)?;
                // SAFETY: a non-null slot holds a memo of its registered type, so the type's
                // `to_dyn_fn` produces a valid `dyn Memo` pointer.
                let dyn_memo: &dyn Memo = unsafe { (type_.to_dyn_fn)(memo).as_ref() };
                Some(dyn_memo.memory_usage())
            })
            .collect()
    }
}
/// An exclusive view pairing a [`MemoTable`] with the [`MemoTableTypes`] describing its slots.
///
/// Created by the unsafe `MemoTableTypes::attach_memos_mut`.
pub(crate) struct MemoTableWithTypesMut<'a> {
types: &'a MemoTableTypes,
memos: &'a mut MemoTable,
}
impl MemoTableWithTypesMut<'_> {
    /// Calls `f` with mutable access to the memo stored for `memo_ingredient_index`.
    ///
    /// Does nothing if the slots were never allocated, the index is out of range, or the slot
    /// is empty.
    ///
    /// # Panics
    ///
    /// Panics if the slot exists and `M` is not the type registered for it.
    pub(crate) fn map_memo<M: Memo>(
        self,
        memo_ingredient_index: MemoIngredientIndex,
        f: impl FnOnce(&mut M),
    ) {
        let slot = memo_ingredient_index.as_usize();
        let Some(entry) = self.memos.memos.get_mut(slot) else {
            return;
        };
        // SAFETY: `get_mut` succeeded, so `slot` is in bounds for the memo table, which per
        // the `attach_memos_mut` contract is no longer than the attached type list.
        let registered = unsafe { self.types.types.get_unchecked(slot) };
        if registered.type_id != TypeId::of::<M>() {
            type_assert_failed(memo_ingredient_index);
        }
        if let Some(memo) = NonNull::new(*entry.atomic_memo.get_mut()) {
            // SAFETY: the slot holds an `M` (type just checked), and `&mut self` access to the
            // table means no other reference to the memo is live through this table.
            f(unsafe { MemoEntryType::from_dummy(memo).as_mut() });
        }
    }

    /// Drops every stored memo through its registered type, leaving all slots empty.
    ///
    /// # Safety
    ///
    /// NOTE(review): the contract is not stated in this file. Presumably no pointers previously
    /// returned by `get`/`insert` may still be in use — confirm with callers.
    #[inline]
    pub unsafe fn drop(&mut self) {
        let types = &self.types.types;
        for (memo, type_) in self.memos.memos.iter_mut().zip(types) {
            // The taken `Box<dyn Memo>` is a temporary, dropped at the end of this statement.
            // SAFETY: `type_` is the type registered for this slot.
            unsafe { memo.take(type_) };
        }
    }

    /// Empties every slot, passing each memo that was present to `f` together with its index.
    ///
    /// # Safety
    ///
    /// NOTE(review): same unstated contract as `drop` — confirm with callers.
    pub(crate) unsafe fn take_memos(
        &mut self,
        mut f: impl FnMut(MemoIngredientIndex, Box<dyn Memo>),
    ) {
        let slots = self.memos.memos.iter_mut().zip(self.types.types.iter());
        for (index, (memo, type_)) in slots.enumerate() {
            // SAFETY: `type_` is the type registered for this slot.
            if let Some(memo) = unsafe { memo.take(type_) } {
                f(MemoIngredientIndex::from_usize(index), memo);
            }
        }
    }
}
/// Panics because a typed memo access used a type that differs from the `TypeId` registered
/// for `memo_ingredient_index`.
///
/// `#[cold]` and `#[inline(never)]` keep this panic path out of the accessors' common path.
#[cold]
#[inline(never)]
fn type_assert_failed(memo_ingredient_index: MemoIngredientIndex) -> ! {
panic!("inconsistent type-id for `{memo_ingredient_index:?}`")
}
impl MemoEntry {
    /// Empties the slot and returns its memo as an owned `Box<dyn Memo>`.
    ///
    /// Returns `None` if the slot was already empty.
    ///
    /// # Safety
    ///
    /// `type_` must be the type registered for this slot. The stored pointer must have come
    /// from a `Box` of that type, since it is rebuilt into one here.
    #[inline]
    unsafe fn take(&mut self, type_: &MemoEntryType) -> Option<Box<dyn Memo>> {
        let slot = self.atomic_memo.get_mut();
        let raw = mem::replace(slot, ptr::null_mut());
        NonNull::new(raw).map(|memo| {
            let dyn_ptr = (type_.to_dyn_fn)(memo).as_ptr();
            // SAFETY: per the contract above, `dyn_ptr` owns a boxed memo of the registered
            // type, and the slot was cleared so it cannot be taken twice.
            unsafe { Box::from_raw(dyn_ptr) }
        })
    }
}
impl Drop for DummyMemo {
    /// `DummyMemo` exists only as an erased pointee type and is not expected to be dropped.
    /// Reaching this indicates an erased pointer was dropped without being converted back
    /// through its `MemoEntryType`.
    fn drop(&mut self) {
        unreachable!("should never get here");
    }
}
impl std::fmt::Debug for MemoTable {
    /// Prints only the type name (`MemoTable { .. }`); the memos are type-erased and cannot be
    /// formatted here.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("MemoTable");
        builder.finish_non_exhaustive()
    }
}