use std::alloc::{Layout, alloc, dealloc, handle_alloc_error};
use std::cell::{RefCell, UnsafeCell};
use std::fmt;
use std::fmt::Formatter;
use std::marker::PhantomData;
use std::mem::ManuallyDrop;
use std::panic::UnwindSafe;
use std::ptr::{self, NonNull};
use std::sync::Arc;
use std::sync::atomic::{AtomicU8, Ordering};
use rustc_hash::FxHashMap;
use thin_vec::ThinVec;
#[cfg(feature = "accumulator")]
use crate::accumulator::{
Accumulator,
accumulated_map::{AccumulatedMap, AtomicInputAccumulatedValues},
};
use crate::active_query::{CompletedQuery, DetachedInputOutputs, QueryCompletion, QueryStack};
use crate::cycle::{AtomicIterationStamp, CycleHeads, IterationStamp, empty_cycle_heads};
use crate::durability::Durability;
use crate::key::DatabaseKeyIndex;
use crate::runtime::Stamp;
use crate::sync::atomic::AtomicBool;
use crate::table::{PageIndex, Slot, Table};
use crate::tracked_struct::{Disambiguator, Identity, IdentityHash};
use crate::zalsa::{IngredientIndex, Zalsa};
use crate::{Cancelled, Id, Revision};
pub struct ZalsaLocal {
query_stack: RefCell<QueryStack>,
most_recent_pages: UnsafeCell<FxHashMap<IngredientIndex, PageIndex>>,
cancelled: CancellationToken,
}
#[derive(Default, Clone, Debug)]
pub struct CancellationToken(Arc<AtomicU8>);
impl CancellationToken {
const CANCELLED_MASK: u8 = 0b01;
const DISABLED_MASK: u8 = 0b10;
pub fn cancel(&self) {
self.0.fetch_or(Self::CANCELLED_MASK, Ordering::Relaxed);
}
pub fn is_cancelled(&self) -> bool {
self.0.load(Ordering::Relaxed) & Self::CANCELLED_MASK != 0
}
#[inline]
fn set_cancellation_disabled(&self, disabled: bool) -> bool {
let previous_disabled_bit = if disabled {
self.0.fetch_or(Self::DISABLED_MASK, Ordering::Relaxed)
} else {
self.0.fetch_and(!Self::DISABLED_MASK, Ordering::Relaxed)
};
previous_disabled_bit & Self::DISABLED_MASK != 0
}
fn should_trigger_local_cancellation(&self) -> bool {
self.0.load(Ordering::Relaxed) == Self::CANCELLED_MASK
}
fn reset(&self) {
self.0.store(0, Ordering::Relaxed);
}
}
impl ZalsaLocal {
pub(crate) fn new() -> Self {
ZalsaLocal {
query_stack: RefCell::new(QueryStack::default()),
most_recent_pages: UnsafeCell::new(FxHashMap::default()),
cancelled: CancellationToken::default(),
}
}
pub(crate) fn record_unfilled_pages(&mut self, table: &Table) {
let most_recent_pages = self.most_recent_pages.get_mut();
most_recent_pages
.drain()
.for_each(|(ingredient, page)| table.record_unfilled_page(ingredient, page));
}
pub(crate) fn allocate<'db, T: Slot>(
&self,
zalsa: &'db Zalsa,
ingredient: IngredientIndex,
mut value: impl FnOnce(Id) -> T,
) -> (Id, &'db T) {
let most_recent_pages = unsafe { &mut *self.most_recent_pages.get() };
if let Some(&page) = most_recent_pages.get(&ingredient) {
let page_ref = zalsa.table().page::<T>(page);
match unsafe { page_ref.allocate(page, value) } {
Ok((id, value)) => return (id, value),
Err(v) => value = v,
}
}
self.allocate_cold(zalsa, ingredient, value)
}
#[cold]
#[inline(never)]
pub(crate) fn allocate_cold<'db, T: Slot>(
&self,
zalsa: &'db Zalsa,
ingredient: IngredientIndex,
mut value: impl FnOnce(Id) -> T,
) -> (Id, &'db T) {
let memo_types = || {
zalsa
.lookup_ingredient(ingredient)
.memo_table_types()
.clone()
};
let most_recent_pages = unsafe { &mut *self.most_recent_pages.get() };
let mut page = *most_recent_pages.entry(ingredient).or_insert_with(|| {
zalsa
.table()
.fetch_or_push_page::<T>(ingredient, memo_types)
});
loop {
let page_ref = zalsa.table().page::<T>(page);
match unsafe { page_ref.allocate(page, value) } {
Ok((id, value)) => return (id, value),
Err(v) => {
value = v;
page = zalsa.table().push_page::<T>(ingredient, memo_types());
most_recent_pages.insert(ingredient, page);
}
}
}
}
#[inline]
pub(crate) fn push_query(&self, database_key_index: DatabaseKeyIndex) -> ActiveQueryGuard<'_> {
unsafe {
self.with_query_stack_unchecked_mut(|stack| {
stack.push_new_query(database_key_index);
ActiveQueryGuard {
local_state: self,
database_key_index,
#[cfg(debug_assertions)]
push_len: stack.len(),
}
})
}
}
#[inline(always)]
pub(crate) unsafe fn with_query_stack_unchecked_mut<R>(
&self,
f: impl UnwindSafe + FnOnce(&mut QueryStack) -> R,
) -> R {
unsafe { f(&mut self.query_stack.try_borrow_mut().unwrap_unchecked()) }
}
#[inline(always)]
pub(crate) unsafe fn with_query_stack_unchecked<R>(
&self,
f: impl UnwindSafe + FnOnce(&QueryStack) -> R,
) -> R {
unsafe { f(&self.query_stack.try_borrow().unwrap_unchecked()) }
}
#[inline(always)]
pub(crate) fn try_with_query_stack<R>(
&self,
f: impl UnwindSafe + FnOnce(&QueryStack) -> R,
) -> Option<R> {
self.query_stack
.try_borrow()
.ok()
.as_ref()
.map(|stack| f(stack))
}
pub(crate) fn active_query(&self) -> Option<(DatabaseKeyIndex, Stamp)> {
unsafe {
self.with_query_stack_unchecked(|stack| {
stack
.last()
.map(|active_query| (active_query.database_key_index, active_query.stamp()))
})
}
}
#[cfg(feature = "accumulator")]
pub(crate) fn accumulate<A: Accumulator>(
&self,
index: IngredientIndex,
value: A,
) -> Result<(), ()> {
unsafe {
self.with_query_stack_unchecked_mut(|stack| {
if let Some(top_query) = stack.last_mut() {
top_query.accumulate(index, value);
Ok(())
} else {
Err(())
}
})
}
}
pub(crate) fn add_output(&self, entity: DatabaseKeyIndex) {
unsafe {
self.with_query_stack_unchecked_mut(|stack| {
if let Some(top_query) = stack.last_mut() {
top_query.add_output(entity)
}
})
}
}
pub(crate) fn is_tracked_struct_of_active_query(&self, entity: DatabaseKeyIndex) -> bool {
unsafe {
self.with_query_stack_unchecked_mut(|stack| {
stack
.last_mut()
.is_some_and(|top_query| top_query.tracked_struct_ids().is_active(entity))
})
}
}
#[inline(always)]
pub(crate) fn report_tracked_read(
&self,
input: DatabaseKeyIndex,
durability: Durability,
changed_at: Revision,
cycle_heads: &CycleHeads,
#[cfg(feature = "accumulator")] has_accumulated: bool,
#[cfg(feature = "accumulator")] accumulated_inputs: &AtomicInputAccumulatedValues,
) {
crate::tracing::debug!(
"report_tracked_read(input={:?}, durability={:?}, changed_at={:?})",
input,
durability,
changed_at
);
unsafe {
self.with_query_stack_unchecked_mut(|stack| {
if let Some(top_query) = stack.last_mut() {
top_query.add_read(
input,
durability,
changed_at,
cycle_heads,
#[cfg(feature = "accumulator")]
has_accumulated,
#[cfg(feature = "accumulator")]
accumulated_inputs,
);
}
})
}
}
#[inline(always)]
pub(crate) fn report_tracked_read_simple(
&self,
input: DatabaseKeyIndex,
durability: Durability,
changed_at: Revision,
) {
crate::tracing::debug!(
"report_tracked_read(input={:?}, durability={:?}, changed_at={:?})",
input,
durability,
changed_at
);
unsafe {
self.with_query_stack_unchecked_mut(|stack| {
if let Some(top_query) = stack.last_mut() {
top_query.add_read_simple(input, durability, changed_at);
}
})
}
}
#[inline(always)]
pub(crate) fn report_tracked_read_revision(&self, changed_at: Revision) {
unsafe {
self.with_query_stack_unchecked_mut(|stack| {
if let Some(top_query) = stack.last_mut() {
top_query.add_changed_at(changed_at);
}
})
}
}
#[inline(always)]
pub(crate) fn report_untracked_read(&self, current_revision: Revision) {
unsafe {
self.with_query_stack_unchecked_mut(|stack| {
if let Some(top_query) = stack.last_mut() {
top_query.add_untracked_read(current_revision);
}
})
}
}
#[track_caller]
pub(crate) fn disambiguate(&self, key: IdentityHash) -> (Stamp, Disambiguator) {
unsafe {
self.with_query_stack_unchecked_mut(|stack| {
let top_query = stack.last_mut().expect(
"cannot create a tracked struct disambiguator outside of a tracked function",
);
let disambiguator = top_query.disambiguate(key);
(top_query.stamp(), disambiguator)
})
}
}
#[track_caller]
pub(crate) fn tracked_struct_id(&self, identity: &Identity) -> Option<Id> {
unsafe {
self.with_query_stack_unchecked_mut(|stack| {
let top_query = stack
.last_mut()
.expect("cannot create a tracked struct ID outside of a tracked function");
top_query.tracked_struct_ids_mut().reuse(identity)
})
}
}
#[track_caller]
pub(crate) fn store_tracked_struct_id(&self, identity: Identity, id: Id) {
unsafe {
self.with_query_stack_unchecked_mut(|stack| {
let top_query = stack
.last_mut()
.expect("cannot store a tracked struct ID outside of a tracked function");
top_query.tracked_struct_ids_mut().insert(identity, id);
})
}
}
#[inline]
pub(crate) fn cancellation_token(&self) -> CancellationToken {
self.cancelled.clone()
}
#[inline]
pub(crate) fn uncancel(&self) {
self.cancelled.reset();
}
#[inline]
pub fn should_trigger_local_cancellation(&self) -> bool {
self.cancelled.should_trigger_local_cancellation()
}
#[cold]
pub(crate) fn unwind_pending_write(&self) {
Cancelled::PendingWrite.throw();
}
#[cold]
pub(crate) fn unwind_cancelled(&self) {
Cancelled::Local.throw();
}
#[inline]
pub(crate) fn set_cancellation_disabled(&self, was_disabled: bool) -> bool {
self.cancelled.set_cancellation_disabled(was_disabled)
}
}
impl std::panic::RefUnwindSafe for ZalsaLocal {}
#[derive(Debug)]
pub(crate) struct QueryRevisions {
pub(crate) changed_at: Revision,
pub(crate) durability: Durability,
pub(crate) origin_and_extra: OriginAndExtra,
#[cfg(feature = "accumulator")]
pub(super) accumulated_inputs: AtomicInputAccumulatedValues,
pub(super) verified_final: AtomicBool,
}
impl QueryRevisions {
#[inline]
pub(crate) const fn origin(&self) -> QueryOriginRef<'_> {
self.origin_and_extra.origin()
}
#[inline]
pub(crate) const fn is_derived_untracked(&self) -> bool {
self.origin_and_extra.is_derived_untracked()
}
#[cfg(not(feature = "persistence"))]
pub(crate) fn discard_edges_if_never_change(&mut self) {
if self.durability != Durability::NEVER_CHANGE
|| !matches!(self.origin(), QueryOriginRef::Derived(_))
|| !self.cycle_heads().is_empty()
{
return;
}
#[cfg(feature = "accumulator")]
if self.accumulated_inputs.load().is_any() {
return;
}
self.origin_and_extra.clear_edges();
}
#[cfg(feature = "salsa_unstable")]
pub(crate) fn allocation_size(&self) -> usize {
let QueryRevisions {
changed_at: _,
durability: _,
verified_final: _,
origin_and_extra,
#[cfg(feature = "accumulator")]
accumulated_inputs: _,
} = self;
origin_and_extra.allocation_size()
}
}
#[derive(Debug, Default)]
#[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "persistence", serde(transparent))]
pub(crate) struct QueryRevisionsExtra(Option<QueryRevisionsExtraInner>);
impl QueryRevisionsExtra {
pub fn new(
#[cfg(feature = "accumulator")] accumulated: AccumulatedMap,
mut tracked_struct_ids: ThinVec<(Identity, Id)>,
cycle_heads: CycleHeads,
iteration: IterationStamp,
force_extra: bool,
) -> Self {
#[cfg(feature = "accumulator")]
let accumulated_is_empty = accumulated.is_empty();
#[cfg(not(feature = "accumulator"))]
let accumulated_is_empty = true;
let inner = if !force_extra
&& accumulated_is_empty
&& tracked_struct_ids.is_empty()
&& cycle_heads.is_empty()
&& iteration.is_default()
{
None
} else {
tracked_struct_ids.shrink_to_fit();
Some(QueryRevisionsExtraInner {
#[cfg(feature = "accumulator")]
accumulated,
cycle_heads,
tracked_struct_ids,
iteration: iteration.into(),
cycle_converged: false,
})
};
Self(inner)
}
}
#[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))]
struct QueryRevisionsExtraInner {
#[cfg(feature = "accumulator")]
#[cfg_attr(feature = "persistence", serde(skip))] accumulated: AccumulatedMap,
tracked_struct_ids: ThinVec<(Identity, Id)>,
cycle_heads: CycleHeads,
#[cfg_attr(feature = "persistence", serde(skip))]
iteration: AtomicIterationStamp,
#[cfg_attr(feature = "persistence", serde(skip))]
cycle_converged: bool,
}
impl QueryRevisionsExtraInner {
fn empty() -> Self {
QueryRevisionsExtraInner {
#[cfg(feature = "accumulator")]
accumulated: AccumulatedMap::default(),
tracked_struct_ids: ThinVec::default(),
cycle_heads: empty_cycle_heads().clone(),
iteration: IterationStamp::default().into(),
cycle_converged: false,
}
}
#[cfg(feature = "salsa_unstable")]
fn allocation_size(&self) -> usize {
let QueryRevisionsExtraInner {
#[cfg(feature = "accumulator")]
accumulated,
tracked_struct_ids,
cycle_heads,
iteration: _,
cycle_converged: _,
} = self;
#[cfg(feature = "accumulator")]
let b = accumulated.allocation_size();
#[cfg(not(feature = "accumulator"))]
let b = 0;
b + cycle_heads.allocation_size() + std::mem::size_of_val(tracked_struct_ids.as_slice())
}
}
impl fmt::Debug for QueryRevisionsExtraInner {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
struct FmtTrackedStructIds<'a>(&'a ThinVec<(Identity, Id)>);
impl fmt::Debug for FmtTrackedStructIds<'_> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
let mut f = f.debug_list();
if self.0.len() > 5 {
f.entries(&self.0[..5]);
f.finish_non_exhaustive()
} else {
f.entries(self.0);
f.finish()
}
}
}
let mut f = f.debug_struct("QueryRevisionsExtraInner");
f.field("cycle_heads", &self.cycle_heads)
.field("iteration", &self.iteration)
.field("cycle_converged", &self.cycle_converged);
#[cfg(feature = "accumulator")]
{
f.field("accumulated", &self.accumulated);
}
f.field(
"tracked_struct_ids",
&FmtTrackedStructIds(&self.tracked_struct_ids),
);
f.finish()
}
}
#[cfg(not(feature = "shuttle"))]
#[cfg(target_pointer_width = "64")]
const _: [(); std::mem::size_of::<QueryRevisions>()] = [(); std::mem::size_of::<[usize; 3]>()];
#[cfg(not(feature = "shuttle"))]
#[cfg(target_pointer_width = "64")]
const _: [(); std::mem::size_of::<QueryRevisionsExtraInner>()] =
[(); std::mem::size_of::<[usize; if cfg!(feature = "accumulator") { 7 } else { 3 }]>()];
impl QueryRevisions {
pub(crate) fn fixpoint_initial(query: DatabaseKeyIndex, iteration: IterationStamp) -> Self {
Self {
changed_at: Revision::start(),
durability: Durability::MAX,
origin_and_extra: OriginAndExtra::derived(
std::iter::empty(),
QueryRevisionsExtra::new(
#[cfg(feature = "accumulator")]
AccumulatedMap::default(),
ThinVec::default(),
CycleHeads::initial(query, iteration),
iteration,
false,
),
),
#[cfg(feature = "accumulator")]
accumulated_inputs: Default::default(),
verified_final: AtomicBool::new(false),
}
}
#[cfg(feature = "accumulator")]
pub(crate) fn accumulated(&self) -> Option<&AccumulatedMap> {
self.origin_and_extra
.extra()
.map(|extra| &extra.accumulated)
.filter(|map| !map.is_empty())
}
pub(crate) fn cycle_heads(&self) -> &CycleHeads {
match self.origin_and_extra.extra() {
Some(extra) => &extra.cycle_heads,
None => empty_cycle_heads(),
}
}
pub(crate) fn set_cycle_heads(&mut self, cycle_heads: CycleHeads, iteration: IterationStamp) {
let extra = self.origin_and_extra.get_or_insert_extra();
extra.cycle_heads = cycle_heads;
extra.iteration = iteration.into();
}
pub(crate) const fn cycle_converged(&self) -> bool {
match self.origin_and_extra.extra() {
Some(extra) => extra.cycle_converged,
None => false,
}
}
pub(crate) const fn set_cycle_converged(&mut self, cycle_converged: bool) {
if let Some(extra) = self.origin_and_extra.extra_mut() {
extra.cycle_converged = cycle_converged
}
}
pub(crate) fn iteration(&self) -> IterationStamp {
match self.origin_and_extra.extra() {
Some(extra) => extra.iteration.load(),
None => IterationStamp::default(),
}
}
pub(crate) fn set_iteration_count(
&self,
database_key_index: DatabaseKeyIndex,
iteration: IterationStamp,
) {
let Some(extra) = self.origin_and_extra.extra() else {
return;
};
debug_assert!(extra.iteration.load() <= iteration);
extra.iteration.store_iteration(iteration);
extra
.cycle_heads
.update_iteration_count(database_key_index, iteration);
}
const fn extra(&self) -> Option<&QueryRevisionsExtraInner> {
self.origin_and_extra.extra()
}
pub fn tracked_struct_ids(&self) -> &[(Identity, Id)] {
self.extra()
.map(|extra| &*extra.tracked_struct_ids)
.unwrap_or_default()
}
pub fn tracked_struct_ids_mut(&mut self) -> Option<&mut ThinVec<(Identity, Id)>> {
self.origin_and_extra
.extra_mut()
.map(|extra| &mut extra.tracked_struct_ids)
.filter(|tracked_struct_ids| !tracked_struct_ids.is_empty())
}
}
#[derive(Debug, Clone, Copy)]
pub enum QueryOriginRef<'a> {
Assigned(DatabaseKeyIndex),
Derived(QueryEdges<'a>),
DerivedUntracked(QueryEdges<'a>),
}
impl<'a> QueryOriginRef<'a> {
#[inline]
pub(crate) fn inputs(self) -> impl DoubleEndedIterator<Item = DatabaseKeyIndex> + use<'a> {
self.edges().iter().filter_map(|edge| match edge.kind() {
QueryEdgeKind::Input => Some(edge.key()),
QueryEdgeKind::Output => None,
})
}
pub(crate) fn outputs(self) -> impl DoubleEndedIterator<Item = DatabaseKeyIndex> + use<'a> {
let opt_edges = match self {
QueryOriginRef::Derived(edges) | QueryOriginRef::DerivedUntracked(edges) => Some(edges),
QueryOriginRef::Assigned(_) => None,
};
opt_edges.into_iter().flat_map(output_edges)
}
#[inline]
pub(crate) const fn edges(self) -> QueryEdges<'a> {
match self {
QueryOriginRef::Derived(edges) | QueryOriginRef::DerivedUntracked(edges) => edges,
QueryOriginRef::Assigned(_) => QueryEdges::wide(&[]),
}
}
}
#[derive(Clone, Copy)]
#[repr(u8)]
enum QueryOriginKind {
Assigned = 0b01,
Derived = 0b11,
DerivedUntracked = 0b10,
}
#[derive(Clone, Copy)]
#[repr(u8)]
enum DerivedOriginKind {
Derived = QueryOriginKind::Derived as u8,
DerivedUntracked = QueryOriginKind::DerivedUntracked as u8,
}
#[derive(Clone, Copy)]
#[repr(u8)]
enum QueryEdgeLayout {
Packed = 0b000,
Wide = 0b100,
}
#[derive(Clone, Copy)]
#[repr(transparent)]
struct QueryOriginTag(u8);
impl QueryOriginTag {
const KIND_MASK: u8 = 0b011;
const LAYOUT_MASK: u8 = 0b100;
const fn assigned() -> Self {
QueryOriginTag(QueryOriginKind::Assigned as u8)
}
const fn derived(kind: DerivedOriginKind, layout: QueryEdgeLayout) -> Self {
QueryOriginTag(kind as u8 | layout as u8)
}
const fn kind(self) -> QueryOriginKind {
match self.0 & Self::KIND_MASK {
0b01 => QueryOriginKind::Assigned,
0b11 => QueryOriginKind::Derived,
0b10 => QueryOriginKind::DerivedUntracked,
_ => panic!("invalid query origin kind"),
}
}
const fn layout(self) -> QueryEdgeLayout {
if self.0 & Self::LAYOUT_MASK == 0 {
QueryEdgeLayout::Packed
} else {
QueryEdgeLayout::Wide
}
}
}
#[repr(Rust, packed)]
pub(crate) struct OriginAndExtra {
tag: OriginAndExtraTag,
payload: OriginAndExtraPayload,
metadata: u32,
}
const _: [(); std::mem::size_of::<OriginAndExtra>()] = [(); 13];
unsafe impl Send for OriginAndExtra
where
Id: Send,
QueryRevisionsExtraInner: Send,
PackedQueryEdge: Send,
QueryEdge: Send,
{
}
unsafe impl Sync for OriginAndExtra
where
Id: Sync,
QueryRevisionsExtraInner: Sync,
PackedQueryEdge: Sync,
QueryEdge: Sync,
{
}
impl OriginAndExtra {
#[inline]
pub(crate) fn derived<I>(input_outputs: I, extra: QueryRevisionsExtra) -> Self
where
I: ExactSizeIterator<Item = QueryEdge>,
{
Self::new_derived_with_kind(input_outputs, DerivedOriginKind::Derived, extra)
}
#[inline]
pub(crate) fn derived_untracked<I>(input_outputs: I, extra: QueryRevisionsExtra) -> Self
where
I: ExactSizeIterator<Item = QueryEdge>,
{
Self::new_derived_with_kind(input_outputs, DerivedOriginKind::DerivedUntracked, extra)
}
pub(crate) const fn assigned(key: DatabaseKeyIndex) -> Self {
Self {
tag: OriginAndExtraTag::without_extra(QueryOriginTag::assigned()),
payload: OriginAndExtraPayload {
index: key.key_index(),
},
metadata: key.ingredient_index().as_u32(),
}
}
fn assigned_with_extra(key: DatabaseKeyIndex, extra: QueryRevisionsExtraInner) -> Self {
let allocation = Box::into_raw(Box::new(AssignedOriginAndExtra {
extra,
index: key.key_index(),
}));
Self {
tag: OriginAndExtraTag::with_extra(QueryOriginTag::assigned()),
payload: OriginAndExtraPayload {
allocation: NonNull::new(allocation).unwrap().cast(),
},
metadata: key.ingredient_index().as_u32(),
}
}
#[inline]
fn new_derived_with_kind<I>(
input_outputs: I,
kind: DerivedOriginKind,
extra: QueryRevisionsExtra,
) -> Self
where
I: ExactSizeIterator<Item = QueryEdge>,
{
match extra.0 {
Some(extra) => Self::new_derived_with_extra(input_outputs, kind, extra),
None => Self::new_derived_without_extra(input_outputs, kind),
}
}
#[inline]
fn new_derived_without_extra(
input_outputs: impl ExactSizeIterator<Item = QueryEdge>,
kind: DerivedOriginKind,
) -> Self {
let (edge_layout, allocation, metadata) =
unsafe { Self::allocate_derived_with_header(input_outputs, ()) };
Self {
tag: OriginAndExtraTag::without_extra(QueryOriginTag::derived(kind, edge_layout)),
payload: OriginAndExtraPayload { allocation },
metadata,
}
}
#[inline]
fn new_derived_with_extra(
input_outputs: impl ExactSizeIterator<Item = QueryEdge>,
kind: DerivedOriginKind,
extra: QueryRevisionsExtraInner,
) -> Self {
let (edge_layout, allocation, metadata) =
unsafe { Self::allocate_derived_with_header(input_outputs, extra) };
Self {
tag: OriginAndExtraTag::with_extra(QueryOriginTag::derived(kind, edge_layout)),
payload: OriginAndExtraPayload { allocation },
metadata,
}
}
#[inline]
unsafe fn allocate_derived_with_header<H, I>(
mut input_outputs: I,
header: H,
) -> (QueryEdgeLayout, NonNull<()>, u32)
where
I: ExactSizeIterator<Item = QueryEdge>,
{
let length = input_outputs.len();
let metadata = u32::try_from(length)
.expect("exceeded more than `u32::MAX` query edges; this should never happen");
let mut packed = SliceWithHeader::allocate(length);
for edge in input_outputs.by_ref() {
let Some(edge) = PackedQueryEdge::new(edge) else {
let mut wide = SliceWithHeader::allocate(length);
wide.extend(
packed
.initialized_slice()
.iter()
.copied()
.map(PackedQueryEdge::edge),
);
drop(packed);
wide.push(edge);
wide.extend(input_outputs);
let allocation = wide.finish(header).into_raw();
return (QueryEdgeLayout::Wide, allocation, metadata);
};
packed.push(edge);
}
let allocation = packed.finish(header).into_raw();
(QueryEdgeLayout::Packed, allocation, metadata)
}
const fn extra(&self) -> Option<&QueryRevisionsExtraInner> {
match self.tag.layout() {
OriginAndExtraLayout::WithExtra => {
Some(unsafe {
self.payload
.allocation
.cast::<QueryRevisionsExtraInner>()
.as_ref()
})
}
OriginAndExtraLayout::WithoutExtra => None,
}
}
const fn extra_mut(&mut self) -> Option<&mut QueryRevisionsExtraInner> {
match self.tag.layout() {
OriginAndExtraLayout::WithExtra => {
Some(unsafe {
self.payload
.allocation
.cast::<QueryRevisionsExtraInner>()
.as_mut()
})
}
OriginAndExtraLayout::WithoutExtra => None,
}
}
fn get_or_insert_extra(&mut self) -> &mut QueryRevisionsExtraInner {
if matches!(self.tag.layout(), OriginAndExtraLayout::WithoutExtra) {
let extra = QueryRevisionsExtraInner::empty();
let replacement = match self.origin() {
QueryOriginRef::Assigned(key) => Self::assigned_with_extra(key, extra),
QueryOriginRef::Derived(edges) => {
Self::derived(edges.iter(), QueryRevisionsExtra(Some(extra)))
}
QueryOriginRef::DerivedUntracked(edges) => {
Self::derived_untracked(edges.iter(), QueryRevisionsExtra(Some(extra)))
}
};
*self = replacement;
}
self.extra_mut().unwrap()
}
#[cfg(not(feature = "persistence"))]
fn clear_edges(&mut self) {
if self.metadata == 0 {
return;
}
let kind = match self.tag.origin().kind() {
QueryOriginKind::Assigned => panic!("assigned query origins have no edges"),
QueryOriginKind::Derived => DerivedOriginKind::Derived,
QueryOriginKind::DerivedUntracked => DerivedOriginKind::DerivedUntracked,
};
let extra = self
.extra_mut()
.map(|extra| std::mem::replace(extra, QueryRevisionsExtraInner::empty()));
*self = Self::new_derived_with_kind([].into_iter(), kind, QueryRevisionsExtra(extra));
}
const fn is_derived_untracked(&self) -> bool {
matches!(self.tag.origin().kind(), QueryOriginKind::DerivedUntracked)
}
const fn origin(&self) -> QueryOriginRef<'_> {
let tag = self.tag.origin();
match tag.kind() {
QueryOriginKind::Assigned => {
let index = match self.tag.layout() {
OriginAndExtraLayout::WithoutExtra => {
unsafe { self.payload.index }
}
OriginAndExtraLayout::WithExtra => {
unsafe {
self.payload
.allocation
.cast::<AssignedOriginAndExtra>()
.as_ref()
.index
}
}
};
let ingredient = unsafe { IngredientIndex::new_unchecked(self.metadata) };
QueryOriginRef::Assigned(DatabaseKeyIndex::new(ingredient, index))
}
QueryOriginKind::Derived | QueryOriginKind::DerivedUntracked => {
let length = self.metadata as usize;
let edges = unsafe {
let allocation = self.payload.allocation;
match (self.tag.layout(), tag.layout()) {
(OriginAndExtraLayout::WithoutExtra, QueryEdgeLayout::Packed) => {
QueryEdges::packed(
SliceWithHeader::<(), PackedQueryEdge>::slice(allocation, length)
.as_ref(),
)
}
(OriginAndExtraLayout::WithoutExtra, QueryEdgeLayout::Wide) => {
QueryEdges::wide(
SliceWithHeader::<(), QueryEdge>::slice(allocation, length).as_ref(),
)
}
(OriginAndExtraLayout::WithExtra, QueryEdgeLayout::Packed) => {
QueryEdges::packed(
SliceWithHeader::<QueryRevisionsExtraInner, PackedQueryEdge>::slice(
allocation, length,
)
.as_ref(),
)
}
(OriginAndExtraLayout::WithExtra, QueryEdgeLayout::Wide) => {
QueryEdges::wide(
SliceWithHeader::<QueryRevisionsExtraInner, QueryEdge>::slice(
allocation, length,
)
.as_ref(),
)
}
}
};
match tag.kind() {
QueryOriginKind::Derived => QueryOriginRef::Derived(edges),
QueryOriginKind::DerivedUntracked => QueryOriginRef::DerivedUntracked(edges),
QueryOriginKind::Assigned => unreachable!(),
}
}
}
}
#[cfg(any(test, feature = "salsa_unstable"))]
fn allocation_size(&self) -> usize {
let tag = self.tag.origin();
let memory = match (self.tag.layout(), tag.kind(), tag.layout()) {
(OriginAndExtraLayout::WithoutExtra, QueryOriginKind::Assigned, _) => 0,
(OriginAndExtraLayout::WithExtra, QueryOriginKind::Assigned, _) => {
std::mem::size_of::<AssignedOriginAndExtra>()
}
(_, _, QueryEdgeLayout::Packed) => self.derived_allocation_size::<PackedQueryEdge>(),
(_, _, QueryEdgeLayout::Wide) => self.derived_allocation_size::<QueryEdge>(),
};
#[cfg(feature = "salsa_unstable")]
if let Some(extra) = self.extra() {
return memory + extra.allocation_size();
}
memory
}
#[cfg(any(test, feature = "salsa_unstable"))]
const fn derived_allocation_size<E: Copy>(&self) -> usize {
let length = self.metadata as usize;
match self.tag.layout() {
OriginAndExtraLayout::WithoutExtra => SliceWithHeader::<(), E>::layout(length).0.size(),
OriginAndExtraLayout::WithExtra => {
SliceWithHeader::<QueryRevisionsExtraInner, E>::layout(length)
.0
.size()
}
}
}
}
impl Drop for OriginAndExtra {
fn drop(&mut self) {
let tag = self.tag.origin();
match (tag.kind(), self.tag.layout(), tag.layout()) {
(QueryOriginKind::Assigned, OriginAndExtraLayout::WithoutExtra, _) => {}
(QueryOriginKind::Assigned, OriginAndExtraLayout::WithExtra, _) => {
let allocation = unsafe { self.payload.allocation };
drop(unsafe {
Box::from_raw(allocation.cast::<AssignedOriginAndExtra>().as_ptr())
});
}
(QueryOriginKind::Derived | QueryOriginKind::DerivedUntracked, layout, edge_layout) => {
let length = self.metadata as usize;
unsafe {
let allocation = self.payload.allocation;
match (layout, edge_layout) {
(OriginAndExtraLayout::WithoutExtra, QueryEdgeLayout::Packed) => drop(
SliceWithHeader::<(), PackedQueryEdge>::from_raw_parts(
allocation, length,
),
),
(OriginAndExtraLayout::WithoutExtra, QueryEdgeLayout::Wide) => drop(
SliceWithHeader::<(), QueryEdge>::from_raw_parts(allocation, length),
),
(OriginAndExtraLayout::WithExtra, QueryEdgeLayout::Packed) => drop(
SliceWithHeader::<QueryRevisionsExtraInner, PackedQueryEdge>::from_raw_parts(
allocation, length,
),
),
(OriginAndExtraLayout::WithExtra, QueryEdgeLayout::Wide) => drop(
SliceWithHeader::<QueryRevisionsExtraInner, QueryEdge>::from_raw_parts(
allocation, length,
),
),
}
}
}
}
}
}
impl std::fmt::Debug for OriginAndExtra {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if let Some(extra) = self.extra() {
f.debug_struct("OriginAndExtra")
.field("origin", &self.origin())
.field("extra", extra)
.finish()
} else {
self.origin().fmt(f)
}
}
}
#[derive(Clone, Copy)]
#[repr(u8)]
enum OriginAndExtraLayout {
WithoutExtra = 0b0000,
WithExtra = 0b1000,
}
#[derive(Clone, Copy)]
#[repr(transparent)]
struct OriginAndExtraTag(u8);
impl OriginAndExtraTag {
const WITH_EXTRA_MASK: u8 = OriginAndExtraLayout::WithExtra as u8;
const fn without_extra(origin: QueryOriginTag) -> Self {
OriginAndExtraTag(origin.0)
}
const fn with_extra(origin: QueryOriginTag) -> Self {
OriginAndExtraTag(OriginAndExtraLayout::WithExtra as u8 | origin.0)
}
const fn layout(self) -> OriginAndExtraLayout {
if self.0 & Self::WITH_EXTRA_MASK == 0 {
OriginAndExtraLayout::WithoutExtra
} else {
OriginAndExtraLayout::WithExtra
}
}
const fn origin(self) -> QueryOriginTag {
QueryOriginTag(self.0 & !Self::WITH_EXTRA_MASK)
}
}
union OriginAndExtraPayload {
allocation: NonNull<()>,
index: Id,
}
#[repr(C)]
struct AssignedOriginAndExtra {
extra: QueryRevisionsExtraInner,
index: Id,
}
struct SliceWithHeader<H, T: Copy> {
allocation: NonNull<()>,
layout: Layout,
marker: PhantomData<(H, T)>,
}
impl<H, T: Copy> SliceWithHeader<H, T> {
#[inline]
fn allocate(length: usize) -> SliceWithHeaderBuilder<H, T> {
let (layout, slice_offset) = Self::layout(length);
let allocation = if layout.size() == 0 {
if std::mem::align_of::<H>() >= std::mem::align_of::<T>() {
NonNull::<H>::dangling().cast()
} else {
NonNull::<T>::dangling().cast()
}
} else {
unsafe { NonNull::new(alloc(layout)) }.unwrap_or_else(|| handle_alloc_error(layout))
}
.cast();
let slice = unsafe { allocation.cast::<u8>().byte_add(slice_offset).cast::<T>() };
SliceWithHeaderBuilder {
allocation,
slice,
layout,
length,
initialized: 0,
marker: PhantomData,
}
}
const fn layout(length: usize) -> (Layout, usize) {
let slice = match Layout::array::<T>(length) {
Ok(slice) => slice,
Err(_) => panic!("slice allocation is too large"),
};
let (layout, slice_offset) = match Layout::new::<H>().extend(slice) {
Ok(layout) => layout,
Err(_) => panic!("slice-with-header allocation is too large"),
};
(layout, slice_offset)
}
const unsafe fn slice(allocation: NonNull<()>, length: usize) -> NonNull<[T]> {
let (_, slice_offset) = Self::layout(length);
let slice = unsafe { allocation.cast::<u8>().byte_add(slice_offset).cast::<T>() };
NonNull::slice_from_raw_parts(slice, length)
}
const unsafe fn from_raw_parts(allocation: NonNull<()>, length: usize) -> Self {
SliceWithHeader {
allocation,
layout: Self::layout(length).0,
marker: PhantomData,
}
}
fn into_raw(self) -> NonNull<()> {
ManuallyDrop::new(self).allocation
}
}
impl<H, T: Copy> Drop for SliceWithHeader<H, T> {
fn drop(&mut self) {
unsafe { ptr::drop_in_place(self.allocation.cast::<H>().as_ptr()) };
if self.layout.size() != 0 {
unsafe { dealloc(self.allocation.cast().as_ptr(), self.layout) };
}
}
}
struct SliceWithHeaderBuilder<H, T: Copy> {
allocation: NonNull<()>,
slice: NonNull<T>,
layout: Layout,
length: usize,
initialized: usize,
marker: PhantomData<fn() -> H>,
}
impl<H, T: Copy> SliceWithHeaderBuilder<H, T> {
#[inline]
const fn push(&mut self, item: T) {
assert!(
self.initialized < self.length,
"attempted to initialize more slice elements than allocated"
);
unsafe { self.slice.add(self.initialized).write(item) };
self.initialized += 1;
}
#[inline]
fn extend(&mut self, items: impl IntoIterator<Item = T>) {
for item in items {
self.push(item);
}
}
#[inline]
const fn initialized_slice(&self) -> &[T] {
unsafe { std::slice::from_raw_parts(self.slice.as_ptr(), self.initialized) }
}
#[inline]
fn finish(self, header: H) -> SliceWithHeader<H, T> {
assert_eq!(
self.initialized, self.length,
"not all allocated slice elements were initialized"
);
let this = ManuallyDrop::new(self);
unsafe { this.allocation.cast().write(header) };
SliceWithHeader {
allocation: this.allocation,
layout: this.layout,
marker: PhantomData,
}
}
}
impl<H, T: Copy> Drop for SliceWithHeaderBuilder<H, T> {
fn drop(&mut self) {
if self.layout.size() != 0 {
unsafe { dealloc(self.allocation.cast().as_ptr(), self.layout) };
}
}
}
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub struct QueryEdge {
index: u32,
generation: u32,
ingredient: IngredientIndex,
}
const _: [(); std::mem::size_of::<QueryEdge>()] = [(); 12];
impl QueryEdge {
pub const fn input(key: DatabaseKeyIndex) -> QueryEdge {
let id = key.key_index();
QueryEdge {
index: id.index(),
generation: id.generation(),
ingredient: key.ingredient_index(),
}
}
pub const fn output(key: DatabaseKeyIndex) -> QueryEdge {
let id = key.key_index();
QueryEdge {
index: id.index(),
generation: id.generation(),
ingredient: key.ingredient_index().with_tag(true),
}
}
pub const fn key(self) -> DatabaseKeyIndex {
DatabaseKeyIndex::new(self.ingredient.with_tag(false), self.id())
}
pub const fn kind(self) -> QueryEdgeKind {
if self.ingredient.tag() {
QueryEdgeKind::Output
} else {
QueryEdgeKind::Input
}
}
const fn id(self) -> Id {
unsafe { Id::from_index(self.index) }.with_generation(self.generation)
}
}
impl std::fmt::Debug for QueryEdge {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut tuple = match self.kind() {
QueryEdgeKind::Input => f.debug_tuple("Input"),
QueryEdgeKind::Output => f.debug_tuple("Output"),
};
tuple.field(&self.key()).finish()
}
}
#[derive(Copy, Clone)]
struct PackedQueryEdge {
index: u32,
metadata: u32,
}
const _: [(); std::mem::size_of::<PackedQueryEdge>()] = [(); 8];
impl PackedQueryEdge {
const INGREDIENT_SHIFT: u32 = 20;
const GENERATION_MASK: u32 = 0xFFFFF;
const INGREDIENT_MASK: u32 = 0xFFF;
const fn new(edge: QueryEdge) -> Option<Self> {
let ingredient = edge.ingredient.as_u32();
if ingredient > Self::INGREDIENT_MASK || edge.generation > Self::GENERATION_MASK {
return None;
}
Some(PackedQueryEdge {
index: edge.index,
metadata: edge.generation | (ingredient << Self::INGREDIENT_SHIFT),
})
}
const fn edge(self) -> QueryEdge {
QueryEdge {
index: self.index,
generation: self.metadata & Self::GENERATION_MASK,
ingredient: unsafe {
IngredientIndex::new_unchecked(self.metadata >> Self::INGREDIENT_SHIFT)
},
}
}
}
#[derive(Copy, Clone)]
pub struct QueryEdges<'a> {
data: QueryEdgesData<'a>,
}
#[derive(Copy, Clone)]
enum QueryEdgesData<'a> {
Packed(&'a [PackedQueryEdge]),
Wide(&'a [QueryEdge]),
}
impl<'a> QueryEdges<'a> {
const fn packed(edges: &'a [PackedQueryEdge]) -> Self {
QueryEdges {
data: QueryEdgesData::Packed(edges),
}
}
const fn wide(edges: &'a [QueryEdge]) -> Self {
QueryEdges {
data: QueryEdgesData::Wide(edges),
}
}
#[cfg(feature = "accumulator")]
pub(crate) const fn len(self) -> usize {
match self.data {
QueryEdgesData::Packed(edges) => edges.len(),
QueryEdgesData::Wide(edges) => edges.len(),
}
}
#[cfg(test)]
pub(crate) const fn allocation_size(self) -> usize {
match self.data {
QueryEdgesData::Packed(edges) => std::mem::size_of_val(edges),
QueryEdgesData::Wide(edges) => std::mem::size_of_val(edges),
}
}
pub(crate) fn iter(self) -> QueryEdgeIter<'a> {
let data = match self.data {
QueryEdgesData::Packed(edges) => QueryEdgeIterData::Packed(edges.iter()),
QueryEdgesData::Wide(edges) => QueryEdgeIterData::Wide(edges.iter()),
};
QueryEdgeIter { data }
}
pub(crate) fn iter_outputs(self) -> impl DoubleEndedIterator<Item = QueryEdge> + use<'a> {
let wide_edges = match self.data {
QueryEdgesData::Packed(_) => &[][..],
QueryEdgesData::Wide(edges) => edges,
};
wide_edges
.iter()
.copied()
.filter(|edge| matches!(edge.kind(), QueryEdgeKind::Output))
}
#[cfg(test)]
const fn is_packed(self) -> bool {
matches!(self.data, QueryEdgesData::Packed(_))
}
}
pub struct QueryEdgeIter<'a> {
data: QueryEdgeIterData<'a>,
}
enum QueryEdgeIterData<'a> {
Packed(std::slice::Iter<'a, PackedQueryEdge>),
Wide(std::slice::Iter<'a, QueryEdge>),
}
impl Iterator for QueryEdgeIter<'_> {
type Item = QueryEdge;
fn next(&mut self) -> Option<Self::Item> {
match &mut self.data {
QueryEdgeIterData::Packed(edges) => edges.next().copied().map(PackedQueryEdge::edge),
QueryEdgeIterData::Wide(edges) => edges.next().copied(),
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
let len = self.len();
(len, Some(len))
}
}
impl DoubleEndedIterator for QueryEdgeIter<'_> {
fn next_back(&mut self) -> Option<Self::Item> {
match &mut self.data {
QueryEdgeIterData::Packed(edges) => {
edges.next_back().copied().map(PackedQueryEdge::edge)
}
QueryEdgeIterData::Wide(edges) => edges.next_back().copied(),
}
}
}
impl ExactSizeIterator for QueryEdgeIter<'_> {
fn len(&self) -> usize {
match &self.data {
QueryEdgeIterData::Packed(edges) => edges.len(),
QueryEdgeIterData::Wide(edges) => edges.len(),
}
}
}
impl std::iter::FusedIterator for QueryEdgeIter<'_> {}
impl<'a> IntoIterator for QueryEdges<'a> {
type Item = QueryEdge;
type IntoIter = QueryEdgeIter<'a>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
impl std::fmt::Debug for QueryEdges<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_list().entries(self.iter()).finish()
}
}
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum QueryEdgeKind {
Input,
Output,
}
pub(crate) fn output_edges(
input_outputs: QueryEdges<'_>,
) -> impl DoubleEndedIterator<Item = DatabaseKeyIndex> + use<'_> {
input_outputs.iter_outputs().map(QueryEdge::key)
}
pub(crate) struct ActiveQueryGuard<'me> {
local_state: &'me ZalsaLocal,
#[cfg(debug_assertions)]
push_len: usize,
pub(crate) database_key_index: DatabaseKeyIndex,
}
impl<'me> ActiveQueryGuard<'me> {
pub(crate) fn seed_tracked_struct_ids(&self, tracked_struct_ids: &[(Identity, Id)]) {
unsafe {
self.local_state.with_query_stack_unchecked_mut(|stack| {
#[cfg(debug_assertions)]
assert_eq!(stack.len(), self.push_len, "mismatched push and pop");
let frame = stack.last_mut().unwrap();
frame.tracked_struct_ids_mut().seed(tracked_struct_ids);
})
}
}
pub(crate) fn seed_iteration(&self, previous: &QueryRevisions) {
let durability = previous.durability;
let changed_at = previous.changed_at;
let edges = previous.origin().edges();
let untracked_read = previous.is_derived_untracked();
let tracked_ids = previous.tracked_struct_ids();
unsafe {
self.local_state.with_query_stack_unchecked_mut(|stack| {
#[cfg(debug_assertions)]
assert_eq!(stack.len(), self.push_len, "mismatched push and pop");
let frame = stack.last_mut().unwrap();
frame.seed_iteration(durability, changed_at, edges, untracked_read, tracked_ids);
})
}
}
pub(crate) fn take_cycle_heads(&mut self) -> CycleHeads {
unsafe {
self.local_state.with_query_stack_unchecked_mut(|stack| {
#[cfg(debug_assertions)]
assert_eq!(stack.len(), self.push_len);
let frame = stack.last_mut().unwrap();
frame.take_cycle_heads()
})
}
}
pub(crate) fn detach(self) -> DetachedQuery<'me> {
let input_outputs = unsafe {
self.local_state.with_query_stack_unchecked_mut(|stack| {
#[cfg(debug_assertions)]
assert_eq!(stack.len(), self.push_len);
stack.last_mut().unwrap().detach_input_outputs()
})
};
DetachedQuery {
guard: self,
input_outputs,
}
}
fn complete(self, iteration: IterationStamp) -> CompletedQuery {
let query = unsafe {
self.local_state.with_query_stack_unchecked_mut(|stack| {
stack.pop_into_revisions(
self.database_key_index,
iteration,
#[cfg(debug_assertions)]
self.push_len,
)
})
};
std::mem::forget(self);
query
}
#[inline]
pub(crate) fn pop(self, iteration: IterationStamp) -> CompletedQuery {
self.complete(iteration)
}
fn pop_detached_completion(
self,
iteration: IterationStamp,
detached_input_outputs: DetachedInputOutputs,
force_extra: bool,
) -> QueryCompletion {
let completion = unsafe {
self.local_state
.with_query_stack_unchecked_mut(move |stack| {
stack.pop_detached_completion(
self.database_key_index,
iteration,
detached_input_outputs,
force_extra,
#[cfg(debug_assertions)]
self.push_len,
)
})
};
std::mem::forget(self);
completion
}
}
pub(crate) struct DetachedQuery<'me> {
guard: ActiveQueryGuard<'me>,
input_outputs: DetachedInputOutputs,
}
impl DetachedQuery<'_> {
pub(crate) fn input_outputs(&self) -> &crate::hash::FxIndexSet<QueryEdge> {
self.input_outputs.input_outputs()
}
pub(crate) fn pop_completion(
self,
iteration: IterationStamp,
force_extra: bool,
) -> QueryCompletion {
let Self {
guard,
input_outputs,
} = self;
guard.pop_detached_completion(iteration, input_outputs, force_extra)
}
}
impl Drop for ActiveQueryGuard<'_> {
fn drop(&mut self) {
unsafe {
self.local_state.with_query_stack_unchecked_mut(|stack| {
stack.pop(
self.database_key_index,
#[cfg(debug_assertions)]
self.push_len,
);
})
};
}
}
#[cfg(feature = "persistence")]
pub(crate) mod persistence {
#[cfg(test)]
use super::QueryOriginRef;
use super::{OriginAndExtra, QueryEdge, QueryEdges, QueryRevisions, QueryRevisionsExtra};
use crate::DatabaseKeyIndex;
use crate::sync::atomic::{AtomicBool, Ordering};
use crate::{Durability, Revision};
impl OriginAndExtra {
pub(crate) fn new(origin: PersistentQueryOrigin, extra: QueryRevisionsExtra) -> Self {
match origin {
PersistentQueryOrigin::Assigned(key) => match extra.0 {
Some(extra) => Self::assigned_with_extra(key, extra),
None => Self::assigned(key),
},
PersistentQueryOrigin::Derived(edges) => {
Self::derived(edges.iter().copied(), extra)
}
PersistentQueryOrigin::DerivedUntracked(edges) => {
Self::derived_untracked(edges.iter().copied(), extra)
}
}
}
}
impl QueryEdge {
const fn raw_key(self) -> DatabaseKeyIndex {
DatabaseKeyIndex::new(self.ingredient, self.id())
}
}
impl serde::Serialize for QueryEdge {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serde::Serialize::serialize(&self.raw_key(), serializer)
}
}
impl<'de> serde::Deserialize<'de> for QueryEdge {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let key: DatabaseKeyIndex = serde::Deserialize::deserialize(deserializer)?;
let id = key.key_index();
Ok(QueryEdge {
index: id.index(),
generation: id.generation(),
ingredient: key.ingredient_index(),
})
}
}
impl serde::Serialize for QueryEdges<'_> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.collect_seq(self.iter())
}
}
#[derive(Debug, serde::Serialize, serde::Deserialize)]
#[serde(rename = "QueryOrigin")]
pub(crate) enum PersistentQueryOrigin {
Assigned(DatabaseKeyIndex),
Derived(Box<[QueryEdge]>),
DerivedUntracked(Box<[QueryEdge]>),
}
impl PersistentQueryOrigin {
pub(crate) fn derived<I>(input_outputs: I) -> Self
where
I: IntoIterator<Item = QueryEdge>,
I::IntoIter: ExactSizeIterator,
{
Self::Derived(input_outputs.into_iter().collect())
}
pub(crate) fn derived_untracked<I>(input_outputs: I) -> Self
where
I: IntoIterator<Item = QueryEdge>,
I::IntoIter: ExactSizeIterator,
{
Self::DerivedUntracked(input_outputs.into_iter().collect())
}
pub(crate) const fn assigned(key: DatabaseKeyIndex) -> Self {
Self::Assigned(key)
}
#[cfg(test)]
pub(crate) const fn as_ref(&self) -> QueryOriginRef<'_> {
match self {
Self::Assigned(key) => QueryOriginRef::Assigned(*key),
Self::Derived(edges) => QueryOriginRef::Derived(QueryEdges::wide(edges)),
Self::DerivedUntracked(edges) => {
QueryOriginRef::DerivedUntracked(QueryEdges::wide(edges))
}
}
}
}
#[derive(serde::Serialize)]
pub(crate) struct MappedQueryRevisions<'a> {
changed_at: Revision,
durability: Durability,
#[serde(rename = "origin")]
origin: PersistentQueryOrigin,
#[serde(with = "verified_final")]
verified_final: AtomicBool,
extra: Option<&'a super::QueryRevisionsExtraInner>,
}
impl QueryRevisions {
pub(crate) fn with_origin(
&self,
serialized_origin: PersistentQueryOrigin,
) -> MappedQueryRevisions<'_> {
let QueryRevisions {
changed_at,
durability,
ref verified_final,
#[cfg(feature = "accumulator")]
accumulated_inputs: _, ref origin_and_extra,
} = *self;
MappedQueryRevisions {
changed_at,
durability,
extra: origin_and_extra.extra(),
origin: serialized_origin,
verified_final: AtomicBool::new(verified_final.load(Ordering::Relaxed)),
}
}
}
impl<'de> serde::Deserialize<'de> for QueryRevisions {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
#[derive(serde::Deserialize)]
struct DeserializeQueryRevisions {
changed_at: Revision,
durability: Durability,
#[serde(rename = "origin")]
origin: PersistentQueryOrigin,
#[serde(with = "verified_final")]
verified_final: AtomicBool,
extra: QueryRevisionsExtra,
}
let revisions = DeserializeQueryRevisions::deserialize(deserializer)?;
Ok(QueryRevisions {
changed_at: revisions.changed_at,
durability: revisions.durability,
origin_and_extra: OriginAndExtra::new(revisions.origin, revisions.extra),
#[cfg(feature = "accumulator")]
accumulated_inputs: Default::default(), verified_final: revisions.verified_final,
})
}
}
pub(super) mod verified_final {
use crate::sync::atomic::{AtomicBool, Ordering};
pub fn serialize<S>(value: &AtomicBool, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serde::Serialize::serialize(&value.load(Ordering::Relaxed), serializer)
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<AtomicBool, D::Error>
where
D: serde::Deserializer<'de>,
{
serde::Deserialize::deserialize(deserializer).map(AtomicBool::new)
}
}
}
#[cfg(test)]
mod tests {
use std::mem::size_of;
#[cfg(feature = "persistence")]
use super::QueryEdgeKind;
#[cfg(feature = "persistence")]
use super::persistence::PersistentQueryOrigin;
#[cfg(not(feature = "shuttle"))]
use super::{
AssignedOriginAndExtra, QueryRevisionsExtra, QueryRevisionsExtraInner, SliceWithHeader,
};
use super::{OriginAndExtra, PackedQueryEdge, QueryEdge, QueryOriginRef};
use crate::{DatabaseKeyIndex, Id, IngredientIndex};
#[test]
fn query_origin_packs_edges_that_fit() {
let input = QueryEdge::input(key(231, 10_842_122, 41));
let other_input = QueryEdge::input(key(232, 10_842_123, 42));
let origin = OriginAndExtra::derived([input, other_input].into_iter(), Default::default());
let QueryOriginRef::Derived(edges) = origin.origin() else {
panic!("expected derived origin");
};
assert!(edges.is_packed());
assert_eq!(edges.allocation_size(), 2 * size_of::<PackedQueryEdge>());
assert_eq!(edges.iter().collect::<Vec<_>>(), vec![input, other_input]);
assert_eq!(
edges.iter().rev().collect::<Vec<_>>(),
vec![other_input, input]
);
assert_eq!(
origin.origin().inputs().collect::<Vec<_>>(),
vec![input.key(), other_input.key()]
);
assert_eq!(origin.origin().outputs().collect::<Vec<_>>(), vec![]);
}
#[test]
fn query_origin_spills_all_edges_if_it_contains_an_output() {
let input = QueryEdge::input(key(231, 10_842_122, 41));
let output = QueryEdge::output(key(232, 10_842_123, 42));
let origin = OriginAndExtra::derived([input, output].into_iter(), Default::default());
let QueryOriginRef::Derived(edges) = origin.origin() else {
panic!("expected derived origin");
};
assert!(!edges.is_packed());
assert_eq!(edges.allocation_size(), 2 * size_of::<QueryEdge>());
assert_eq!(edges.iter().collect::<Vec<_>>(), vec![input, output]);
assert_eq!(
origin.origin().inputs().collect::<Vec<_>>(),
vec![input.key()]
);
assert_eq!(
origin.origin().outputs().collect::<Vec<_>>(),
vec![output.key()]
);
}
#[test]
fn query_origin_packs_largest_supported_generation() {
let input = QueryEdge::input(key(231, 10_842_122, PackedQueryEdge::GENERATION_MASK));
let origin = OriginAndExtra::derived([input].into_iter(), Default::default());
let QueryOriginRef::Derived(edges) = origin.origin() else {
panic!("expected derived origin");
};
assert!(edges.is_packed());
assert_eq!(edges.iter().collect::<Vec<_>>(), vec![input]);
}
#[test]
fn query_origin_spills_all_edges_if_generation_does_not_fit() {
let packed = QueryEdge::input(key(231, 10_842_122, 41));
let wide = QueryEdge::input(key(232, 10_842_123, PackedQueryEdge::GENERATION_MASK + 1));
let origin = OriginAndExtra::derived([packed, wide].into_iter(), Default::default());
let QueryOriginRef::Derived(edges) = origin.origin() else {
panic!("expected derived origin");
};
assert!(!edges.is_packed());
assert_eq!(edges.allocation_size(), 2 * size_of::<QueryEdge>());
assert_eq!(edges.iter().collect::<Vec<_>>(), vec![packed, wide]);
}
#[test]
fn query_origin_spills_if_ingredient_does_not_fit() {
let wide = QueryEdge::input(key(PackedQueryEdge::INGREDIENT_MASK + 1, 10_842_122, 41));
let origin = OriginAndExtra::derived([wide].into_iter(), Default::default());
let QueryOriginRef::Derived(edges) = origin.origin() else {
panic!("expected derived origin");
};
assert!(!edges.is_packed());
assert_eq!(edges.iter().collect::<Vec<_>>(), vec![wide]);
}
#[cfg(all(not(feature = "persistence"), not(feature = "shuttle")))]
#[test]
fn clearing_edges_preserves_co_allocated_extra() {
let packed = QueryEdge::input(key(231, 10_842_122, 41));
let wide = QueryEdge::input(key(PackedQueryEdge::INGREDIENT_MASK + 1, 10_842_123, 42));
for edge in [packed, wide] {
let mut extra = QueryRevisionsExtraInner::empty();
extra.cycle_converged = true;
let mut origin = OriginAndExtra::derived_untracked(
[edge].into_iter(),
QueryRevisionsExtra(Some(extra)),
);
origin.clear_edges();
assert!(origin.extra().unwrap().cycle_converged);
let QueryOriginRef::DerivedUntracked(edges) = origin.origin() else {
panic!("expected untracked derived origin");
};
assert!(edges.is_packed());
assert!(edges.iter().next().is_none());
assert_eq!(
origin.allocation_size(),
SliceWithHeader::<QueryRevisionsExtraInner, PackedQueryEdge>::layout(0)
.0
.size()
);
}
}
#[cfg(not(feature = "shuttle"))]
#[test]
fn stored_derived_origin_inlines_extra_before_packed_edges() {
let input = QueryEdge::input(key(231, 10_842_122, 41));
let other_input = QueryEdge::input(key(232, 10_842_123, 42));
let mut origin =
OriginAndExtra::derived([input, other_input].into_iter(), Default::default());
origin.get_or_insert_extra().cycle_converged = true;
assert_eq!(
origin.allocation_size(),
SliceWithHeader::<QueryRevisionsExtraInner, PackedQueryEdge>::layout(2)
.0
.size()
);
assert!(origin.extra().unwrap().cycle_converged);
assert_eq!(
origin.origin().edges().iter().collect::<Vec<_>>(),
vec![input, other_input]
);
}
#[cfg(not(feature = "shuttle"))]
#[test]
fn stored_derived_origin_builds_directly_with_extra() {
let input = QueryEdge::input(key(231, 10_842_122, 41));
let mut extra = QueryRevisionsExtraInner::empty();
extra.cycle_converged = true;
let origin = OriginAndExtra::derived([input].into_iter(), QueryRevisionsExtra(Some(extra)));
assert!(origin.extra().unwrap().cycle_converged);
assert_eq!(
origin.origin().edges().iter().collect::<Vec<_>>(),
vec![input]
);
assert_eq!(
origin.allocation_size(),
SliceWithHeader::<QueryRevisionsExtraInner, PackedQueryEdge>::layout(1)
.0
.size()
);
}
#[cfg(not(feature = "shuttle"))]
#[test]
fn stored_wide_origin_builds_directly_with_extra() {
let input = QueryEdge::input(key(PackedQueryEdge::INGREDIENT_MASK + 1, 10_842_122, 41));
let origin = OriginAndExtra::derived(
[input].into_iter(),
QueryRevisionsExtra(Some(QueryRevisionsExtraInner::empty())),
);
assert_eq!(
origin.origin().edges().iter().collect::<Vec<_>>(),
vec![input]
);
assert_eq!(
origin.allocation_size(),
SliceWithHeader::<QueryRevisionsExtraInner, QueryEdge>::layout(1)
.0
.size()
);
}
#[cfg(not(feature = "shuttle"))]
#[test]
fn assigned_origin_with_extra_keeps_the_key() {
let key = key(231, 10_842_122, 41);
let mut origin = OriginAndExtra::assigned(key);
origin.get_or_insert_extra().cycle_converged = true;
assert_eq!(
origin.allocation_size(),
size_of::<AssignedOriginAndExtra>()
);
assert!(origin.extra().unwrap().cycle_converged);
let QueryOriginRef::Assigned(stored_key) = origin.origin() else {
panic!("expected assigned origin");
};
assert_eq!(stored_key, key);
}
#[cfg(feature = "persistence")]
#[test]
fn query_origin_serde_round_trip_preserves_edges() {
let input = QueryEdge::input(key(231, 10_842_122, 41));
let output = QueryEdge::output(key(232, 10_842_123, PackedQueryEdge::GENERATION_MASK + 1));
let origin = PersistentQueryOrigin::derived_untracked([input, output]);
let serialized = serde_json::to_string(&origin).unwrap();
let deserialized_origin: PersistentQueryOrigin = serde_json::from_str(&serialized).unwrap();
let QueryOriginRef::DerivedUntracked(edges) = deserialized_origin.as_ref() else {
panic!("expected untracked derived origin");
};
assert!(!edges.is_packed());
assert_eq!(edges.iter().collect::<Vec<_>>(), vec![input, output]);
assert_eq!(
edges.iter().map(QueryEdge::kind).collect::<Vec<_>>(),
vec![QueryEdgeKind::Input, QueryEdgeKind::Output]
);
}
fn key(ingredient: u32, index: u32, generation: u32) -> DatabaseKeyIndex {
let id = unsafe { Id::from_index(index) }.with_generation(generation);
DatabaseKeyIndex::new(IngredientIndex::new(ingredient), id)
}
}