use std::any::Any;
use std::fmt::{Debug, Formatter};
use std::mem::transmute;
use std::ptr::NonNull;
use crate::cycle::{
CycleHeads, CycleHeadsIterator, IterationCount, ProvisionalStatus, empty_cycle_heads,
};
use crate::function::{Configuration, IngredientImpl};
use crate::ingredient::WaitForResult;
use crate::key::DatabaseKeyIndex;
use crate::revision::AtomicRevision;
use crate::sync::atomic::Ordering;
use crate::table::memo::MemoTableWithTypesMut;
use crate::zalsa::{MemoIngredientIndex, Zalsa};
use crate::zalsa_local::{QueryOriginRef, QueryRevisions};
use crate::{Event, EventKind, Id, Revision};
impl<C: Configuration> IngredientImpl<C> {
    /// Inserts `memo` into the memo table of `id` at `memo_ingredient_index`.
    ///
    /// Returns the memo that previously occupied that slot, if any.
    ///
    /// The memo table stores `Memo<'static, C>` pointers, so the `'db` lifetime is erased on
    /// the way in and re-attached to the previous memo on the way out.
    pub(super) fn insert_memo_into_table_for<'db>(
        &self,
        zalsa: &'db Zalsa,
        id: Id,
        memo: NonNull<Memo<'db, C>>,
        memo_ingredient_index: MemoIngredientIndex,
    ) -> Option<NonNull<Memo<'db, C>>> {
        // SAFETY: only the lifetime parameter of the pointee changes; lifetimes do not affect
        // layout. NOTE(review): soundness presumably relies on the memo table never exposing
        // the memo beyond the `'db` borrow of `zalsa` — confirm against the table's contract.
        let erased =
            unsafe { transmute::<NonNull<Memo<'db, C>>, NonNull<Memo<'static, C>>>(memo) };
        zalsa
            .memo_table_for::<C::SalsaStruct<'_>>(id)
            .insert(memo_ingredient_index, erased)
            // SAFETY: lifetime-only transmute, re-attaching `'db` to the displaced memo.
            .map(|previous| unsafe {
                transmute::<NonNull<Memo<'static, C>>, NonNull<Memo<'db, C>>>(previous)
            })
    }

    /// Looks up the memo stored for `id` at `memo_ingredient_index`.
    ///
    /// Returns `None` when the slot is empty; otherwise a reference whose lifetime is tied to
    /// the `'db` borrow of `zalsa`.
    pub(super) fn get_memo_from_table_for<'db>(
        &self,
        zalsa: &'db Zalsa,
        id: Id,
        memo_ingredient_index: MemoIngredientIndex,
    ) -> Option<&'db Memo<'db, C>> {
        zalsa
            .memo_table_for::<C::SalsaStruct<'_>>(id)
            .get(memo_ingredient_index)
            // SAFETY: dereferences the stored pointer and re-attaches the `'db` lifetime that was
            // erased on insertion. NOTE(review): assumes the pointer is valid for `'db` — confirm.
            .map(|stored| unsafe {
                transmute::<&Memo<'static, C>, &'db Memo<'db, C>>(stored.as_ref())
            })
    }

    /// Drops the cached value of the memo at `memo_ingredient_index` in `table`, if any.
    ///
    /// Only memos whose origin is `Derived` lose their value. Memos with an `Assigned` or
    /// `DerivedUntracked` origin are left untouched. NOTE(review): presumably because those
    /// values cannot be reproduced by simply re-executing the query — confirm.
    pub(super) fn evict_value_from_memo_for(
        table: MemoTableWithTypesMut<'_>,
        memo_ingredient_index: MemoIngredientIndex,
    ) {
        table.map_memo(memo_ingredient_index, |memo: &mut Memo<'static, C>| {
            // Exhaustive match so that a new origin kind forces a decision here.
            match memo.revisions.origin.as_ref() {
                QueryOriginRef::Derived(_) => memo.value = None,
                QueryOriginRef::Assigned(_) | QueryOriginRef::DerivedUntracked(_) => {}
            }
        })
    }
}
/// A memoized result for one key of a function ingredient.
///
/// It stores the revision at which the result was last verified and the query metadata
/// (`QueryRevisions`) associated with it.
#[derive(Debug)]
pub struct Memo<'db, C: Configuration> {
    /// The memoized output. `None` when no value is held, for example after
    /// `evict_value_from_memo_for` cleared it for a memo with a `Derived` origin.
    pub(super) value: Option<C::Output<'db>>,
    /// Revision at which this memo was last verified. It is initialised from `revision_now`
    /// in `Memo::new` and overwritten with the current revision by `mark_as_verified`.
    pub(super) verified_at: AtomicRevision,
    /// Query metadata: origin (inputs/outputs), cycle heads, tracked struct ids, and the
    /// `verified_final` flag.
    pub(super) revisions: QueryRevisions,
}
impl<'db, C: Configuration> Memo<'db, C> {
    /// Creates a memo holding `value`, verified at `revision_now`, with the given `revisions`.
    ///
    /// In debug builds this asserts that if `revisions` is already marked `verified_final`,
    /// it records no cycle heads. The converse is not checked.
    pub(super) fn new(
        value: Option<C::Output<'db>>,
        revision_now: Revision,
        revisions: QueryRevisions,
    ) -> Self {
        // The condition is "finalized => no cycle heads"; the message states exactly that.
        debug_assert!(
            !revisions.verified_final.load(Ordering::Relaxed) || revisions.cycle_heads().is_empty(),
            "Memo must not have cycle heads if it is finalized"
        );
        Memo {
            value,
            verified_at: AtomicRevision::from(revision_now),
            revisions,
        }
    }

    /// Returns `true` when this memo holds a value and is final (not provisional).
    pub(super) fn should_serialize(&self) -> bool {
        self.value.is_some() && !self.may_be_provisional()
    }

    /// Returns `true` unless the memo's revisions have been marked `verified_final`.
    #[inline]
    pub(super) fn may_be_provisional(&self) -> bool {
        !self.revisions.verified_final.load(Ordering::Relaxed)
    }

    /// Returns the cycle heads of this memo.
    ///
    /// For a final memo this is always the empty set, even if the revisions recorded heads.
    #[inline(always)]
    pub(super) fn cycle_heads(&self) -> &CycleHeads {
        if self.may_be_provisional() {
            self.revisions.cycle_heads()
        } else {
            empty_cycle_heads()
        }
    }

    /// Returns `true` when the revisions record any cycle heads.
    ///
    /// Unlike [`Self::cycle_heads`], this ignores whether the memo has been finalized.
    #[inline(always)]
    pub(super) fn was_cycle_participant(&self) -> bool {
        !self.revisions.cycle_heads().is_empty()
    }

    /// Records that this memo was validated in the current revision.
    ///
    /// It emits a `DidValidateMemoizedValue` event for `database_key_index` and then stores
    /// the current revision in `verified_at`.
    #[inline]
    pub(super) fn mark_as_verified(&self, zalsa: &Zalsa, database_key_index: DatabaseKeyIndex) {
        zalsa.event(&|| {
            Event::new(EventKind::DidValidateMemoizedValue {
                database_key: database_key_index,
            })
        });
        self.verified_at.store(zalsa.current_revision());
    }

    /// Calls `mark_validated_output` on every output recorded in this memo's origin, passing
    /// `database_key_index` as the executor.
    pub(super) fn mark_outputs_as_verified(
        &self,
        zalsa: &Zalsa,
        database_key_index: DatabaseKeyIndex,
    ) {
        for output in self.revisions.origin.as_ref().outputs() {
            output.mark_validated_output(zalsa, database_key_index);
        }
    }

    /// Returns a `Debug` adapter for tracing.
    ///
    /// The adapter prints only whether a value is present, not the value itself, followed by
    /// `verified_at` and `revisions`.
    pub(super) fn tracing_debug(&self) -> impl std::fmt::Debug + use<'_, 'db, C> {
        struct TracingDebug<'memo, 'db, C: Configuration> {
            memo: &'memo Memo<'db, C>,
        }

        impl<C: Configuration> std::fmt::Debug for TracingDebug<'_, '_, C> {
            fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
                f.debug_struct("Memo")
                    .field(
                        "value",
                        // The output type is not required to be `Debug`, so only report presence.
                        if self.memo.value.is_some() {
                            &"Some(<value>)"
                        } else {
                            &"None"
                        },
                    )
                    .field("verified_at", &self.memo.verified_at)
                    .field("revisions", &self.memo.revisions)
                    .finish()
            }
        }

        TracingDebug { memo: self }
    }
}
impl<C: Configuration> crate::table::memo::Memo for Memo<'static, C>
where
    C::Output<'static>: Send + Sync + Any,
{
    /// Removes everything this memo produced as stale outputs on behalf of `executor`.
    ///
    /// This covers the outputs recorded in the query origin and, in addition, one
    /// `DatabaseKeyIndex` per tracked struct id recorded in the revisions.
    fn remove_outputs(&self, zalsa: &Zalsa, executor: DatabaseKeyIndex) {
        let revisions = &self.revisions;

        for output in revisions.origin.as_ref().outputs() {
            output.remove_stale_output(zalsa, executor);
        }

        for (identity, struct_id) in revisions.tracked_struct_ids() {
            DatabaseKeyIndex::new(identity.ingredient_index(), *struct_id)
                .remove_stale_output(zalsa, executor);
        }
    }

    /// Reports the memory footprint of this memo.
    ///
    /// The metadata size is the memo's own size plus the allocation size of its revisions,
    /// minus the inline size of the output type. The heap size is `Some(0)` when no value is
    /// held; otherwise it is whatever `C::heap_size` reports.
    #[cfg(feature = "salsa_unstable")]
    fn memory_usage(&self) -> crate::database::MemoInfo {
        let output_size = std::mem::size_of::<C::Output<'static>>();
        let total_size = std::mem::size_of::<Self>() + self.revisions.allocation_size();
        let heap_size_of_fields = match &self.value {
            Some(value) => C::heap_size(value),
            None => Some(0),
        };

        crate::database::MemoInfo {
            debug_name: C::DEBUG_NAME,
            output: crate::database::SlotInfo {
                debug_name: std::any::type_name::<C::Output<'static>>(),
                size_of_metadata: total_size - output_size,
                size_of_fields: output_size,
                heap_size_of_fields,
                memos: Vec::new(),
            },
        }
    }
}
// Serialization and deserialization of memos; only compiled with the `persistence` feature.
#[cfg(feature = "persistence")]
mod persistence {
    use crate::function::Configuration;
    use crate::function::memo::Memo;
    use crate::revision::AtomicRevision;
    use crate::zalsa_local::persistence::MappedQueryRevisions;
    use crate::zalsa_local::{QueryOrigin, QueryRevisions};
    use serde::Deserialize;
    use serde::ser::SerializeStruct;

    /// Borrowed, serializable view of a [`Memo`] whose revisions were mapped through
    /// `QueryRevisions::with_origin`.
    pub(crate) struct MappedMemo<'memo, 'db, C: Configuration> {
        /// Borrow of the memo's value, if it holds one.
        pub(crate) value: Option<&'memo C::Output<'db>>,
        /// Copy of the memo's `verified_at` revision, taken when the view was built.
        pub(crate) verified_at: AtomicRevision,
        /// The memo's revisions, mapped with the supplied origin.
        pub(crate) revisions: MappedQueryRevisions<'memo>,
    }

    impl<'db, C: Configuration> Memo<'db, C> {
        /// Builds a [`MappedMemo`] view of this memo for serialization.
        ///
        /// `origin` is handed to `QueryRevisions::with_origin`.
        pub(crate) fn with_origin(&self, origin: QueryOrigin) -> MappedMemo<'_, 'db, C> {
            // Exhaustive destructuring (no `..`): adding a field to `Memo` makes this fail to
            // compile until the new field is handled here.
            let Memo {
                ref verified_at,
                ref value,
                ref revisions,
            } = *self;
            MappedMemo {
                value: value.as_ref(),
                // Copy the atomic by loading its current value into a fresh `AtomicRevision`.
                verified_at: AtomicRevision::from(verified_at.load()),
                revisions: revisions.with_origin(origin),
            }
        }
    }

    /// Serializes as a struct named `"Memo"` with the fields `value`, `verified_at` and
    /// `revisions`, in that order.
    ///
    /// # Panics
    ///
    /// Panics if the view has no value. Callers are expected to check
    /// `Memo::should_serialize` first.
    impl<C> serde::Serialize for MappedMemo<'_, '_, C>
    where
        C: Configuration,
    {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            // Adapter that routes the output value through `C::serialize`.
            struct SerializeValue<'me, 'db, C: Configuration>(&'me C::Output<'db>);
            impl<C> serde::Serialize for SerializeValue<'_, '_, C>
            where
                C: Configuration,
            {
                fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
                where
                    S: serde::Serializer,
                {
                    C::serialize(self.0, serializer)
                }
            }
            let MappedMemo {
                value,
                verified_at,
                revisions,
            } = self;
            let value = value.expect(
                "attempted to serialize memo where `Memo::should_serialize` returned `false`",
            );
            let mut s = serializer.serialize_struct("Memo", 3)?;
            s.serialize_field("value", &SerializeValue::<C>(value))?;
            s.serialize_field("verified_at", &verified_at)?;
            s.serialize_field("revisions", &revisions)?;
            s.end()
        }
    }

    /// Deserializes a struct named `"Memo"` with the fields `value`, `verified_at` and
    /// `revisions`, which is the shape written by `MappedMemo`'s `Serialize` impl.
    ///
    /// The resulting memo always holds `Some(value)`.
    impl<'de, C> serde::Deserialize<'de> for Memo<'static, C>
    where
        C: Configuration,
    {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            // Field order mirrors the `Serialize` impl. NOTE(review): order matters for
            // non-self-describing formats; keep the two in sync.
            #[derive(Deserialize)]
            #[serde(rename = "Memo")]
            pub struct DeserializeMemo<C: Configuration> {
                // Explicit bound replaces the bound serde would otherwise infer for this field.
                #[serde(bound = "C: Configuration")]
                value: DeserializeValue<C>,
                verified_at: AtomicRevision,
                revisions: QueryRevisions,
            }
            // Adapter that routes the output value through `C::deserialize`, converting its
            // error with `serde::de::Error::custom`.
            struct DeserializeValue<C: Configuration>(C::Output<'static>);
            impl<'de, C> serde::Deserialize<'de> for DeserializeValue<C>
            where
                C: Configuration,
            {
                fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
                where
                    D: serde::Deserializer<'de>,
                {
                    C::deserialize(deserializer)
                        .map(DeserializeValue)
                        .map_err(serde::de::Error::custom)
                }
            }
            let memo = DeserializeMemo::<C>::deserialize(deserializer)?;
            Ok(Memo {
                value: Some(memo.value.0),
                verified_at: memo.verified_at,
                revisions: memo.revisions,
            })
        }
    }
}
/// Outcome of trying to claim one cycle head, as yielded by `TryClaimCycleHeadsIter`.
#[derive(Debug)]
pub(super) enum TryClaimHeadsResult {
    /// `wait_for` on the head reported a cycle.
    Cycle {
        /// Iteration count stored in the cycle-heads entry (`head.iteration_count`).
        head_iteration_count: IterationCount,
        /// Iteration reported by the head's `provisional_status` (provisional or final).
        memo_iteration_count: IterationCount,
        /// `verified_at` revision reported by the head's `provisional_status`.
        verified_at: Revision,
    },
    /// `wait_for` on the head returned `WaitForResult::Available`.
    Available,
    /// `wait_for` on the head returned `WaitForResult::Running`.
    Running,
}
/// Iterator that tries to claim each head of a `CycleHeads` set in turn.
///
/// It yields one `TryClaimHeadsResult` per head.
pub(super) struct TryClaimCycleHeadsIter<'a> {
    zalsa: &'a Zalsa,
    // Heads that have not been visited yet.
    cycle_heads: CycleHeadsIterator<'a>,
}
impl<'a> TryClaimCycleHeadsIter<'a> {
pub(super) fn new(zalsa: &'a Zalsa, cycle_heads: &'a CycleHeads) -> Self {
Self {
zalsa,
cycle_heads: cycle_heads.iter(),
}
}
}
impl Iterator for TryClaimCycleHeadsIter<'_> {
    type Item = TryClaimHeadsResult;

    /// Takes the next cycle head and calls `wait_for` on its ingredient.
    ///
    /// The outcome is translated into a `TryClaimHeadsResult`. On a cycle, the head's
    /// `provisional_status` supplies the memo iteration and `verified_at`.
    fn next(&mut self) -> Option<Self::Item> {
        let head = self.cycle_heads.next()?;
        let head_database_key = head.database_key_index;
        let head_key_index = head_database_key.key_index();
        let zalsa = self.zalsa;
        let ingredient = zalsa.lookup_ingredient(head_database_key.ingredient_index());

        let outcome = match ingredient.wait_for(zalsa, head_key_index) {
            WaitForResult::Available => TryClaimHeadsResult::Available,
            WaitForResult::Running(running) => {
                crate::tracing::trace!("Ingredient {head_database_key:?} is running: {running:?}");
                TryClaimHeadsResult::Running
            }
            WaitForResult::Cycle { .. } => {
                crate::tracing::trace!("Waiting for {head_database_key:?} results in a cycle");
                // Provisional and final statuses contribute the same two fields.
                let (memo_iteration_count, verified_at) = match ingredient
                    .provisional_status(zalsa, head_key_index)
                    .expect("cycle head memo to exist")
                {
                    ProvisionalStatus::Provisional {
                        iteration,
                        verified_at,
                        ..
                    }
                    | ProvisionalStatus::Final {
                        iteration,
                        verified_at,
                    } => (iteration, verified_at),
                };
                TryClaimHeadsResult::Cycle {
                    head_iteration_count: head.iteration_count.load(),
                    memo_iteration_count,
                    verified_at,
                }
            }
        };

        Some(outcome)
    }
}
// Compile-time layout check for `Memo`, only on 64-bit targets without the `shuttle` feature.
#[cfg(all(not(feature = "shuttle"), target_pointer_width = "64"))]
mod _memory_usage {
    use crate::cycle::CycleRecoveryStrategy;
    use crate::ingredient::Location;
    use crate::plumbing::{self, IngredientIndices, MemoIngredientSingletonIndex, SalsaStructInDb};
    use crate::table::memo::MemoTableWithTypes;
    use crate::zalsa::Zalsa;
    use crate::{Database, Id, Revision};
    use std::any::TypeId;
    use std::num::NonZeroUsize;

    // Asserts that `Memo<DummyConfiguration>` occupies exactly six `usize`s. The two array
    // types unify only if the lengths are equal, so any size change fails to compile.
    // `Option<NonZeroUsize>` is guaranteed to be the size of a `usize`. NOTE(review): the
    // output type was presumably chosen for that reason, so the check measures the memo's own
    // overhead — confirm.
    const _: [(); std::mem::size_of::<super::Memo<DummyConfiguration>>()] =
        [(); std::mem::size_of::<[usize; 6]>()];

    // Stand-in salsa struct that exists only to satisfy `Configuration::SalsaStruct`.
    // Its methods are `unimplemented!()`, except `entries`, which yields nothing.
    struct DummyStruct;

    impl SalsaStructInDb for DummyStruct {
        type MemoIngredientMap = MemoIngredientSingletonIndex;

        fn lookup_ingredient_index(_: &Zalsa) -> IngredientIndices {
            unimplemented!()
        }

        fn cast(_: Id, _: TypeId) -> Option<Self> {
            unimplemented!()
        }

        unsafe fn memo_table(_: &Zalsa, _: Id, _: Revision) -> MemoTableWithTypes<'_> {
            unimplemented!()
        }

        fn entries(_: &Zalsa) -> impl Iterator<Item = crate::DatabaseKeyIndex> + '_ {
            std::iter::empty()
        }
    }

    // Stand-in configuration used only to name `Memo<DummyConfiguration>` in the size check.
    // It is not meant to be executed: most methods are `unimplemented!()`.
    struct DummyConfiguration;

    impl super::Configuration for DummyConfiguration {
        const DEBUG_NAME: &'static str = "";
        const LOCATION: Location = Location { file: "", line: 0 };
        const PERSIST: bool = false;
        const CYCLE_STRATEGY: CycleRecoveryStrategy = CycleRecoveryStrategy::Panic;
        type DbView = dyn Database;
        type SalsaStruct<'db> = DummyStruct;
        type Input<'db> = ();
        type Output<'db> = NonZeroUsize;
        type Eviction = crate::function::eviction::NoopEviction;

        fn values_equal<'db>(_: &Self::Output<'db>, _: &Self::Output<'db>) -> bool {
            unimplemented!()
        }

        fn id_to_input(_: &Zalsa, _: Id) -> Self::Input<'_> {
            unimplemented!()
        }

        fn execute<'db>(_: &'db Self::DbView, _: Self::Input<'db>) -> Self::Output<'db> {
            unimplemented!()
        }

        fn cycle_initial<'db>(
            _: &'db Self::DbView,
            _: Id,
            _: Self::Input<'db>,
        ) -> Self::Output<'db> {
            unimplemented!()
        }

        // Returns the newly computed value unchanged.
        fn recover_from_cycle<'db>(
            _: &'db Self::DbView,
            _: &crate::Cycle,
            _: &Self::Output<'db>,
            value: Self::Output<'db>,
            _: Self::Input<'db>,
        ) -> Self::Output<'db> {
            value
        }

        fn serialize<S>(_: &Self::Output<'_>, _: S) -> Result<S::Ok, S::Error>
        where
            S: plumbing::serde::Serializer,
        {
            unimplemented!()
        }

        fn deserialize<'de, D>(_: D) -> Result<Self::Output<'static>, D::Error>
        where
            D: plumbing::serde::Deserializer<'de>,
        {
            unimplemented!()
        }
    }
}