use std::any::{Any, TypeId};
use std::cell::RefCell;
use std::ops::Deref;
use std::rc::Rc;
use std::sync::Arc;
use whale::{Durability, GetOrInsertResult, RevisionCounter, Runtime as WhaleRuntime};
use crate::asset::{AssetKey, AssetLocator, DurabilityLevel, PendingAsset};
use crate::db::Db;
use crate::key::{
AssetCacheKey, AssetKeySetSentinelKey, FullCacheKey, QueryCacheKey, QuerySetSentinelKey,
};
use crate::loading::AssetLoadingState;
use crate::query::Query;
use crate::storage::{
AssetKeyRegistry, CachedEntry, CachedValue, ErasedLocateResult, LocatorStorage, PendingStorage,
QueryRegistry, VerifierStorage,
};
use crate::tracer::{
ExecutionResult, InvalidationReason, NoopTracer, SpanContext, SpanId, TraceId, Tracer,
TracerAssetState,
};
use crate::QueryError;
/// Decides whether two user errors are "equal" for early-cutoff purposes:
/// returning `true` lets the runtime keep the old revision so dependents are
/// not re-executed (see `execute_query` and `resolve_asset_error`).
pub type ErrorComparator = fn(&anyhow::Error, &anyhow::Error) -> bool;
/// Number of durability levels the underlying whale runtime is parameterized
/// with; `DurabilityLevel::as_u8` values must map into this range.
const DURABILITY_LEVELS: usize = 4;
thread_local! {
    // Keys currently being computed on this thread, innermost last; used for
    // cycle detection (`check_cycle`) and maintained by `StackGuard`.
    static QUERY_STACK: RefCell<Vec<FullCacheKey>> = const { RefCell::new(Vec::new()) };
    // Tracker installed for the duration of a query execution so leaf-asset
    // reads can be checked against the revision the query started at.
    static CONSISTENCY_TRACKER: RefCell<Option<Rc<ConsistencyTracker>>> = const { RefCell::new(None) };
    // Active tracing trace/span stack, used to parent newly created spans.
    static SPAN_STACK: RefCell<SpanStack> = const { RefCell::new(SpanStack::Empty) };
}
/// Per-thread tracing state: either no trace is active, or one trace id plus
/// the stack of currently open span ids (innermost last).
enum SpanStack {
    Empty,
    Active(TraceId, Vec<SpanId>),
}
fn check_leaf_asset_consistency(dep_changed_at: RevisionCounter) -> Result<(), QueryError> {
CONSISTENCY_TRACKER.with(|tracker| {
if let Some(ref t) = *tracker.borrow() {
t.check_leaf_asset(dep_changed_at)
} else {
Ok(())
}
})
}
/// RAII guard that installs a [`ConsistencyTracker`] into the thread-local
/// slot, remembering whatever was there before and restoring it on drop so
/// nested query executions each see their own tracker.
struct ConsistencyTrackerGuard {
    previous: Option<Rc<ConsistencyTracker>>,
}
impl ConsistencyTrackerGuard {
    /// Swaps `tracker` into `CONSISTENCY_TRACKER`, capturing the displaced
    /// tracker (if any) for restoration.
    fn new(tracker: Rc<ConsistencyTracker>) -> Self {
        ConsistencyTrackerGuard {
            previous: CONSISTENCY_TRACKER.with(|slot| slot.borrow_mut().replace(tracker)),
        }
    }
}
impl Drop for ConsistencyTrackerGuard {
    fn drop(&mut self) {
        // Put the previously installed tracker (or None) back.
        let restored = self.previous.take();
        CONSISTENCY_TRACKER.with(|slot| *slot.borrow_mut() = restored);
    }
}
/// Returns `QueryError::Cycle` (with the full in-flight path, ending at `key`)
/// when `key` is already on this thread's query stack; `Ok(())` otherwise.
///
/// Fix: the previous version accessed the `QUERY_STACK` thread-local twice —
/// once to detect the cycle, then again to rebuild the path — and re-collected
/// the stack element by element. The whole check now runs under a single
/// borrow, and the path is produced with `Vec::clone`.
fn check_cycle(key: &FullCacheKey) -> Result<(), QueryError> {
    QUERY_STACK.with(|stack| {
        let stack = stack.borrow();
        if stack.iter().any(|k| k == key) {
            // Report the chain of in-flight keys plus the offending key itself.
            let mut path = stack.clone();
            path.push(key.clone());
            Err(QueryError::Cycle { path })
        } else {
            Ok(())
        }
    })
}
/// RAII guard that keeps a key on the thread-local `QUERY_STACK` for the
/// duration of its execution; popping happens automatically on drop, even on
/// early return or unwind.
struct StackGuard;
impl StackGuard {
    /// Pushes `key` onto the in-flight stack and returns the guard that will
    /// pop it.
    fn push(key: FullCacheKey) -> Self {
        QUERY_STACK.with(|stack| stack.borrow_mut().push(key));
        Self
    }
}
impl Drop for StackGuard {
    fn drop(&mut self) {
        // Pop whatever is on top; guards are strictly nested, so this is the
        // key pushed by the matching `push`.
        QUERY_STACK.with(|stack| {
            let _ = stack.borrow_mut().pop();
        });
    }
}
/// RAII guard maintaining the thread-local tracing span stack: pushing opens a
/// span (starting a new trace when none is active), dropping closes it and
/// clears the trace once the last span is gone.
struct SpanStackGuard;
impl SpanStackGuard {
    /// Records `span_id` as the innermost active span. When no trace is
    /// active, `trace_id` becomes the thread's current trace.
    fn push(trace_id: TraceId, span_id: SpanId) -> Self {
        SPAN_STACK.with(|cell| {
            let mut stack = cell.borrow_mut();
            if let SpanStack::Active(_, spans) = &mut *stack {
                spans.push(span_id);
            } else {
                *stack = SpanStack::Active(trace_id, vec![span_id]);
            }
        });
        Self
    }
}
impl Drop for SpanStackGuard {
    fn drop(&mut self) {
        SPAN_STACK.with(|cell| {
            let mut stack = cell.borrow_mut();
            let SpanStack::Active(_, spans) = &mut *stack else {
                return;
            };
            spans.pop();
            // Last span closed: the trace is over.
            if spans.is_empty() {
                *stack = SpanStack::Empty;
            }
        });
    }
}
/// Tracing context threaded through a single query execution so dependency
/// registrations can be attributed to the right span.
#[derive(Clone, Copy)]
pub struct ExecutionContext {
    span_ctx: SpanContext,
}
impl ExecutionContext {
    /// Wraps the span context for the current execution.
    #[inline]
    pub fn new(span_ctx: SpanContext) -> Self {
        Self { span_ctx }
    }
    /// The span this execution reports tracer events under.
    #[inline]
    pub fn span_ctx(&self) -> &SpanContext {
        &self.span_ctx
    }
}
/// A value paired with the revision at which it last changed, as returned by
/// [`QueryRuntime::poll`]; lets callers detect staleness without re-running.
#[derive(Debug, Clone)]
pub struct Polled<T> {
    pub value: T,
    pub revision: RevisionCounter,
}
impl<T: Deref> Deref for Polled<T> {
    type Target = T::Target;
    // Forward derefs to the wrapped value, e.g. `Polled<Arc<O>>` derefs to `O`.
    fn deref(&self) -> &Self::Target {
        &self.value
    }
}
/// Incremental query/asset runtime layered on the `whale` dependency graph.
/// Cheap to clone: all storage is shared (`Arc`s plus the cloned whale handle).
pub struct QueryRuntime<T: Tracer = NoopTracer> {
    // Dependency graph + memoized entries, keyed by the type-erased cache key.
    whale: WhaleRuntime<FullCacheKey, Option<CachedEntry>, DURABILITY_LEVELS>,
    // Registered asset locators, keyed by asset-key TypeId.
    locators: Arc<LocatorStorage<T>>,
    // Assets awaiting external resolution via `resolve_asset`.
    pending: Arc<PendingStorage>,
    // All query instances ever registered, for `list_queries`.
    query_registry: Arc<QueryRegistry>,
    // All asset keys ever registered, for `list_asset_keys`.
    asset_key_registry: Arc<AssetKeyRegistry>,
    // Per-key re-validation hooks consulted on cache hits.
    verifiers: Arc<VerifierStorage>,
    // Equality predicate for user errors (early cutoff on error results).
    error_comparator: ErrorComparator,
    tracer: Arc<T>,
}
#[test]
fn test_runtime_send_sync() {
    // Compile-time proof that the runtime can be shared across threads.
    fn assert_send_sync<T: Send + Sync>() {}
    assert_send_sync::<QueryRuntime<NoopTracer>>();
}
impl Default for QueryRuntime<NoopTracer> {
    fn default() -> Self {
        Self::new()
    }
}
// Manual Clone: a derive would require `T: Clone`, but every field clones
// independently of `T` (Arcs and the whale handle).
impl<T: Tracer> Clone for QueryRuntime<T> {
    fn clone(&self) -> Self {
        Self {
            whale: self.whale.clone(),
            locators: self.locators.clone(),
            pending: self.pending.clone(),
            query_registry: self.query_registry.clone(),
            asset_key_registry: self.asset_key_registry.clone(),
            verifiers: self.verifiers.clone(),
            error_comparator: self.error_comparator,
            tracer: self.tracer.clone(),
        }
    }
}
/// Default error comparator: never treats two errors as equal, so an error
/// result always counts as "changed" (no early cutoff on errors).
fn default_error_comparator(_a: &anyhow::Error, _b: &anyhow::Error) -> bool {
    false
}
impl<T: Tracer> QueryRuntime<T> {
    /// Fetches the cached value for `key`, typed as `Q::Output`, together with
    /// the revision at which the node last changed. `None` when the node is
    /// absent, holds no data, or holds a value of a different type.
    fn get_cached_with_revision<Q: Query>(
        &self,
        key: &FullCacheKey,
    ) -> Option<(CachedValue<Arc<Q::Output>>, RevisionCounter)> {
        let node = self.whale.get(key)?;
        let revision = node.changed_at;
        let entry = node.data.as_ref()?;
        let cached = entry.to_cached_value::<Q::Output>()?;
        Some((cached, revision))
    }
    /// The tracer this runtime reports events to.
    #[inline]
    pub fn tracer(&self) -> &T {
        &self.tracer
    }
}
impl QueryRuntime<NoopTracer> {
    /// Creates a runtime with the no-op tracer and default settings.
    pub fn new() -> Self {
        Self::with_tracer(NoopTracer)
    }
    /// Entry point for customized construction (tracer, error comparator).
    pub fn builder() -> QueryRuntimeBuilder<NoopTracer> {
        QueryRuntimeBuilder::new()
    }
}
impl<T: Tracer> QueryRuntime<T> {
    /// Creates a runtime that reports to `tracer`, with default settings
    /// otherwise.
    pub fn with_tracer(tracer: T) -> Self {
        QueryRuntimeBuilder::new().tracer(tracer).build()
    }
    /// Runs (or fetches from cache) `query`, flattening cached user errors
    /// into `QueryError::UserError`.
    pub fn query<Q: Query>(&self, query: Q) -> Result<Arc<Q::Output>, QueryError> {
        self.query_internal(query)
            .and_then(|(inner_result, _)| inner_result.map_err(QueryError::UserError))
    }
    /// Core query path shared by [`query`](Self::query) and
    /// [`poll`](Self::poll): returns the (possibly cached) user result plus
    /// the revision at which it last changed.
    ///
    /// Order of attempts: cycle check → verified-at-current-revision fast path
    /// → valid-but-unverified path (dependency verifiers re-run) → execute.
    #[allow(clippy::type_complexity)]
    fn query_internal<Q: Query>(
        &self,
        query: Q,
    ) -> Result<(Result<Arc<Q::Output>, Arc<anyhow::Error>>, RevisionCounter), QueryError> {
        let query_cache_key = QueryCacheKey::new(query.clone());
        let full_key: FullCacheKey = query_cache_key.clone().into();
        // Open a span parented to whatever span is currently active on this
        // thread, or start a fresh trace if none is.
        let span_id = self.tracer.new_span_id();
        let (trace_id, parent_span_id) = SPAN_STACK.with(|stack| match &*stack.borrow() {
            SpanStack::Empty => (self.tracer.new_trace_id(), None),
            SpanStack::Active(tid, spans) => (*tid, spans.last().copied()),
        });
        let span_ctx = SpanContext {
            span_id,
            trace_id,
            parent_span_id,
        };
        let _span_guard = SpanStackGuard::push(trace_id, span_id);
        let exec_ctx = ExecutionContext::new(span_ctx);
        self.tracer.on_query_start(&span_ctx, &query_cache_key);
        // Cycle detection against the thread-local in-flight stack. Done
        // inline (rather than via `check_cycle`) because the path is needed
        // for the tracer callbacks before returning.
        let cycle_detected = QUERY_STACK.with(|stack| {
            let stack = stack.borrow();
            stack.iter().any(|k| k == &full_key)
        });
        if cycle_detected {
            let path = QUERY_STACK.with(|stack| {
                let stack = stack.borrow();
                let mut path: Vec<FullCacheKey> = stack.iter().cloned().collect();
                path.push(full_key.clone());
                path
            });
            self.tracer.on_cycle_detected(&path);
            self.tracer
                .on_query_end(&span_ctx, &query_cache_key, ExecutionResult::CycleDetected);
            return Err(QueryError::Cycle { path });
        }
        // Fast path: the node was already verified at the current revision, so
        // the cached value can be returned without touching dependencies.
        let current_rev = self.whale.current_revision();
        if self.whale.is_verified_at(&full_key, &current_rev) {
            if let Some((cached, revision)) = self.get_cached_with_revision::<Q>(&full_key) {
                self.tracer
                    .on_cache_check(&span_ctx, &query_cache_key, true);
                self.tracer
                    .on_query_end(&span_ctx, &query_cache_key, ExecutionResult::CacheHit);
                return match cached {
                    CachedValue::Ok(output) => Ok((Ok(output), revision)),
                    CachedValue::UserError(err) => Ok((Err(err), revision)),
                };
            }
        }
        // Slower path: node is valid but unverified at this revision — re-run
        // each dependency's verifier before trusting the cached value.
        if self.whale.is_valid(&full_key) {
            if let Some((cached, revision)) = self.get_cached_with_revision::<Q>(&full_key) {
                let mut deps_verified = true;
                if let Some(deps) = self.whale.get_dependency_ids(&full_key) {
                    for dep in deps {
                        if let Some(verifier) = self.verifiers.get(&dep) {
                            if verifier.verify(self as &dyn std::any::Any).is_err() {
                                deps_verified = false;
                                break;
                            }
                        }
                    }
                }
                // `is_valid` is re-checked here — the double check suggests
                // running verifiers can invalidate this node (NOTE(review):
                // inferred from the re-check; confirm against verifier impl).
                if deps_verified && self.whale.is_valid(&full_key) {
                    self.whale.mark_verified(&full_key, &current_rev);
                    self.tracer
                        .on_cache_check(&span_ctx, &query_cache_key, true);
                    self.tracer.on_query_end(
                        &span_ctx,
                        &query_cache_key,
                        ExecutionResult::CacheHit,
                    );
                    return match cached {
                        CachedValue::Ok(output) => Ok((Ok(output), revision)),
                        CachedValue::UserError(err) => Ok((Err(err), revision)),
                    };
                }
            }
        }
        // Cache miss: execute with this key on the in-flight stack so nested
        // queries detect cycles through it.
        self.tracer
            .on_cache_check(&span_ctx, &query_cache_key, false);
        let _guard = StackGuard::push(full_key.clone());
        let result = self.execute_query::<Q>(&query, &query_cache_key, &full_key, exec_ctx);
        drop(_guard);
        // Translate the three-tuple result into a tracer-facing summary.
        let exec_result = match &result {
            Ok((_, true, _)) => ExecutionResult::Changed,
            Ok((_, false, _)) => ExecutionResult::Unchanged,
            Err(QueryError::Suspend { .. }) => ExecutionResult::Suspended,
            Err(QueryError::Cycle { .. }) => ExecutionResult::CycleDetected,
            Err(e) => ExecutionResult::Error {
                message: format!("{:?}", e),
            },
        };
        self.tracer
            .on_query_end(&span_ctx, &query_cache_key, exec_result);
        result.map(|(inner_result, _, revision)| (inner_result, revision))
    }
    /// Executes `query` for real (cache miss), records the dependencies it
    /// registered, stores the result, and applies early cutoff: when the new
    /// output compares equal to the cached one, the old `changed_at` revision
    /// is kept so dependents do not re-run.
    ///
    /// Returns `(user_result, output_changed, revision)`.
    #[allow(clippy::type_complexity)]
    fn execute_query<Q: Query>(
        &self,
        query: &Q,
        query_cache_key: &QueryCacheKey,
        full_key: &FullCacheKey,
        exec_ctx: ExecutionContext,
    ) -> Result<
        (
            Result<Arc<Q::Output>, Arc<anyhow::Error>>,
            bool,
            RevisionCounter,
        ),
        QueryError,
    > {
        // Install a consistency tracker pinned to the revision observed at
        // start so leaf-asset reads during execution can be validated.
        let start_revision = self.whale.current_revision().get(Durability::volatile());
        let tracker = Rc::new(ConsistencyTracker::new(start_revision));
        let _tracker_guard = ConsistencyTrackerGuard::new(tracker);
        let ctx = QueryContext {
            runtime: self,
            current_key: full_key.clone(),
            exec_ctx,
            deps: RefCell::new(Vec::new()),
        };
        let db = DbDispatch::QueryContext(&ctx);
        let result = query.clone().query(&db);
        // Everything the query touched while running becomes a dependency.
        let deps: Vec<FullCacheKey> = ctx.deps.borrow().clone();
        let durability = Durability::stable();
        match result {
            Ok(output) => {
                // Early cutoff: keep the old revision when the output is
                // unchanged per `Q::output_eq`.
                let existing_revision = if let Some((CachedValue::Ok(old), rev)) =
                    self.get_cached_with_revision::<Q>(full_key)
                {
                    if Q::output_eq(&*old, &*output) {
                        Some(rev)
                    } else {
                        None
                    }
                } else {
                    None
                };
                let output_changed = existing_revision.is_none();
                self.tracer.on_early_cutoff_check(
                    exec_ctx.span_ctx(),
                    query_cache_key,
                    output_changed,
                );
                let entry = CachedEntry::Ok(output.clone() as Arc<dyn std::any::Any + Send + Sync>);
                let revision = if let Some(existing_rev) = existing_revision {
                    // Unchanged: only refresh the dependency edges.
                    let _ = self.whale.confirm_unchanged(full_key, deps);
                    existing_rev
                } else {
                    match self
                        .whale
                        .register(full_key.clone(), Some(entry), durability, deps)
                    {
                        Ok(result) => result.new_rev,
                        Err(missing) => {
                            // A dependency vanished between execution and
                            // registration.
                            return Err(QueryError::DependenciesRemoved {
                                missing_keys: missing,
                            })
                        }
                    }
                };
                // First sighting of this query type/value: bump the query-set
                // sentinel so `list_queries` consumers re-run.
                let is_new_query = self.query_registry.register(query);
                if is_new_query {
                    let sentinel = QuerySetSentinelKey::new::<Q>().into();
                    let _ = self
                        .whale
                        .register(sentinel, None, Durability::stable(), vec![]);
                }
                // Install a verifier so cache hits can re-validate this node.
                self.verifiers
                    .insert::<Q, T>(full_key.clone(), query.clone());
                Ok((Ok(output), output_changed, revision))
            }
            Err(QueryError::UserError(err)) => {
                // Mirror of the Ok arm, but equality is decided by the
                // runtime's error comparator.
                let existing_revision = if let Some((CachedValue::UserError(old_err), rev)) =
                    self.get_cached_with_revision::<Q>(full_key)
                {
                    if (self.error_comparator)(old_err.as_ref(), err.as_ref()) {
                        Some(rev)
                    } else {
                        None
                    }
                } else {
                    None
                };
                let output_changed = existing_revision.is_none();
                self.tracer.on_early_cutoff_check(
                    exec_ctx.span_ctx(),
                    query_cache_key,
                    output_changed,
                );
                let entry = CachedEntry::UserError(err.clone());
                let revision = if let Some(existing_rev) = existing_revision {
                    let _ = self.whale.confirm_unchanged(full_key, deps);
                    existing_rev
                } else {
                    match self
                        .whale
                        .register(full_key.clone(), Some(entry), durability, deps)
                    {
                        Ok(result) => result.new_rev,
                        Err(missing) => {
                            return Err(QueryError::DependenciesRemoved {
                                missing_keys: missing,
                            })
                        }
                    }
                };
                let is_new_query = self.query_registry.register(query);
                if is_new_query {
                    let sentinel = QuerySetSentinelKey::new::<Q>().into();
                    let _ = self
                        .whale
                        .register(sentinel, None, Durability::stable(), vec![]);
                }
                self.verifiers
                    .insert::<Q, T>(full_key.clone(), query.clone());
                Ok((Err(err), output_changed, revision))
            }
            Err(e) => {
                // Suspend / cycle / internal errors are not cached.
                Err(e)
            }
        }
    }
pub fn invalidate<Q: Query>(&self, query: &Q) {
let query_cache_key = QueryCacheKey::new(query.clone());
let full_key: FullCacheKey = query_cache_key.clone().into();
self.tracer
.on_query_invalidated(&query_cache_key, InvalidationReason::ManualInvalidation);
let _ = self
.whale
.register(full_key, None, Durability::stable(), vec![]);
}
pub fn remove_query<Q: Query>(&self, query: &Q) {
let query_cache_key = QueryCacheKey::new(query.clone());
let full_key: FullCacheKey = query_cache_key.clone().into();
self.tracer
.on_query_invalidated(&query_cache_key, InvalidationReason::ManualInvalidation);
self.verifiers.remove(&full_key);
self.whale.remove(&full_key);
if self.query_registry.remove::<Q>(query) {
let sentinel = QuerySetSentinelKey::new::<Q>().into();
let _ = self
.whale
.register(sentinel, None, Durability::stable(), vec![]);
}
}
pub fn clear_cache(&self) {
let keys = self.whale.keys();
for key in keys {
self.whale.remove(&key);
}
}
    /// Like [`query`](Self::query), but also reports the revision at which the
    /// result last changed, and returns user errors inside the `Polled` value
    /// instead of flattening them.
    #[allow(clippy::type_complexity)]
    pub fn poll<Q: Query>(
        &self,
        query: Q,
    ) -> Result<Polled<Result<Arc<Q::Output>, Arc<anyhow::Error>>>, QueryError> {
        let (value, revision) = self.query_internal(query)?;
        Ok(Polled { value, revision })
    }
    /// Revision at which `query`'s cached node last changed, or `None` if the
    /// query has never been computed.
    pub fn changed_at<Q: Query>(&self, query: &Q) -> Option<RevisionCounter> {
        let full_key = QueryCacheKey::new(query.clone()).into();
        self.whale.get(&full_key).map(|node| node.changed_at)
    }
}
impl<T: Tracer> QueryRuntime<T> {
    /// All keys currently present in the graph (queries, assets, sentinels).
    pub fn query_keys(&self) -> Vec<FullCacheKey> {
        self.whale.keys()
    }
    /// Removes `query`'s node only when nothing depends on it; returns whether
    /// removal happened.
    pub fn remove_query_if_unused<Q: Query>(&self, query: &Q) -> bool {
        let full_key = QueryCacheKey::new(query.clone()).into();
        self.remove_if_unused(&full_key)
    }
    /// Unconditionally removes a node and its verifier; returns whether a
    /// node existed.
    pub fn remove(&self, key: &FullCacheKey) -> bool {
        self.verifiers.remove(key);
        self.whale.remove(key).is_some()
    }
    /// Removes a node only if it has no dependents; the verifier is dropped
    /// only when the node was actually removed.
    pub fn remove_if_unused(&self, key: &FullCacheKey) -> bool {
        if self.whale.remove_if_unused(key.clone()).is_some() {
            self.verifiers.remove(key);
            true
        } else {
            false
        }
    }
}
/// Builder for [`QueryRuntime`]: lets callers swap the tracer and the error
/// comparator before construction.
pub struct QueryRuntimeBuilder<T: Tracer = NoopTracer> {
    error_comparator: ErrorComparator,
    tracer: T,
}
impl Default for QueryRuntimeBuilder<NoopTracer> {
    fn default() -> Self {
        Self::new()
    }
}
impl QueryRuntimeBuilder<NoopTracer> {
    /// Starts a builder with a no-op tracer and a comparator that treats all
    /// errors as unequal (no early cutoff on errors).
    pub fn new() -> Self {
        Self {
            error_comparator: default_error_comparator,
            tracer: NoopTracer,
        }
    }
}
impl<T: Tracer> QueryRuntimeBuilder<T> {
    /// Overrides how user errors are compared for early cutoff.
    pub fn error_comparator(mut self, f: ErrorComparator) -> Self {
        self.error_comparator = f;
        self
    }
    /// Swaps the tracer, changing the builder's type parameter.
    pub fn tracer<U: Tracer>(self, tracer: U) -> QueryRuntimeBuilder<U> {
        QueryRuntimeBuilder {
            error_comparator: self.error_comparator,
            tracer,
        }
    }
    /// Builds a runtime with fresh, empty storage.
    pub fn build(self) -> QueryRuntime<T> {
        QueryRuntime {
            whale: WhaleRuntime::new(),
            locators: Arc::new(LocatorStorage::new()),
            pending: Arc::new(PendingStorage::new()),
            query_registry: Arc::new(QueryRegistry::new()),
            asset_key_registry: Arc::new(AssetKeyRegistry::new()),
            verifiers: Arc::new(VerifierStorage::new()),
            error_comparator: self.error_comparator,
            tracer: Arc::new(self.tracer),
        }
    }
}
impl<T: Tracer> QueryRuntime<T> {
    /// Registers `locator` as the resolver for asset keys of type `K`.
    pub fn register_asset_locator<K, L>(&self, locator: L)
    where
        K: AssetKey,
        L: AssetLocator<K>,
    {
        self.locators.insert::<K, L>(locator);
    }
    /// All assets currently awaiting external resolution.
    pub fn pending_assets(&self) -> Vec<PendingAsset> {
        self.pending.get_all()
    }
    /// Pending assets restricted to key type `K`.
    pub fn pending_assets_of<K: AssetKey>(&self) -> Vec<K> {
        self.pending.get_of_type::<K>()
    }
    /// Whether any asset is awaiting external resolution.
    pub fn has_pending_assets(&self) -> bool {
        !self.pending.is_empty()
    }
    /// Resolves a (typically pending) asset to a concrete value at the given
    /// durability level.
    pub fn resolve_asset<K: AssetKey>(&self, key: K, value: K::Asset, durability: DurabilityLevel) {
        self.resolve_asset_internal(key, value, durability);
    }
    /// Resolves an asset to an error. Early cutoff uses the runtime's
    /// `error_comparator`: when old and new errors compare equal, the node is
    /// not marked changed and dependents do not re-run.
    pub fn resolve_asset_error<K: AssetKey>(
        &self,
        key: K,
        error: impl Into<anyhow::Error>,
        durability: DurabilityLevel,
    ) {
        let asset_cache_key = AssetCacheKey::new(key.clone());
        self.pending.remove(&asset_cache_key);
        let error_arc = Arc::new(error.into());
        let entry = CachedEntry::AssetError(error_arc.clone());
        // Map the public durability level onto whale's; out-of-range values
        // fall back to volatile.
        let durability =
            Durability::new(durability.as_u8() as usize).unwrap_or(Durability::volatile());
        let result = self
            .whale
            .update_with_compare(
                asset_cache_key.into(),
                Some(entry),
                |old_data, _new_data| {
                    // "Changed" unless the previous entry was an equal error.
                    match old_data.and_then(|d| d.as_ref()) {
                        Some(CachedEntry::AssetError(old_err)) => {
                            !(self.error_comparator)(old_err.as_ref(), error_arc.as_ref())
                        }
                        _ => true,
                    }
                },
                durability,
                vec![],
            )
            .expect("update_with_compare with no dependencies cannot fail");
        // Rebuilt because the first key was moved into the whale call above.
        let asset_cache_key = AssetCacheKey::new(key.clone());
        self.tracer
            .on_asset_resolved(&asset_cache_key, result.changed);
        let is_new_asset = self.asset_key_registry.register(&key);
        if is_new_asset {
            // First sighting of this key: bump the asset-key-set sentinel so
            // `list_asset_keys` consumers re-run.
            let sentinel = AssetKeySetSentinelKey::new::<K>().into();
            let _ = self
                .whale
                .register(sentinel, None, Durability::stable(), vec![]);
        }
    }
    /// Shared implementation of [`resolve_asset`]: clears the pending entry,
    /// writes the value with early cutoff via `K::asset_eq`, notifies the
    /// tracer, and maintains the asset-key-set sentinel.
    fn resolve_asset_internal<K: AssetKey>(
        &self,
        key: K,
        value: K::Asset,
        durability_level: DurabilityLevel,
    ) {
        let asset_cache_key = AssetCacheKey::new(key.clone());
        self.pending.remove(&asset_cache_key);
        let value_arc: Arc<K::Asset> = Arc::new(value);
        let entry = CachedEntry::AssetReady(value_arc.clone() as Arc<dyn Any + Send + Sync>);
        // Map the public durability level onto whale's; out-of-range values
        // fall back to volatile.
        let durability =
            Durability::new(durability_level.as_u8() as usize).unwrap_or(Durability::volatile());
        let result = self
            .whale
            .update_with_compare(
                asset_cache_key.into(),
                Some(entry),
                |old_data, _new_data| {
                    // "Changed" unless the old entry holds an equal value of
                    // the same concrete type.
                    match old_data.and_then(|d| d.as_ref()) {
                        Some(CachedEntry::AssetReady(old_arc)) => {
                            match old_arc.clone().downcast::<K::Asset>() {
                                Ok(old_value) => !K::asset_eq(&old_value, &value_arc),
                                Err(_) => true,
                            }
                        }
                        _ => true,
                    }
                },
                durability,
                vec![],
            )
            .expect("update_with_compare with no dependencies cannot fail");
        // Rebuilt because the first key was moved into the whale call above.
        let asset_cache_key = AssetCacheKey::new(key.clone());
        self.tracer
            .on_asset_resolved(&asset_cache_key, result.changed);
        let is_new_asset = self.asset_key_registry.register(&key);
        if is_new_asset {
            // First sighting of this key: bump the asset-key-set sentinel so
            // `list_asset_keys` consumers re-run.
            let sentinel = AssetKeySetSentinelKey::new::<K>().into();
            let _ = self
                .whale
                .register(sentinel, None, Durability::stable(), vec![]);
        }
    }
pub fn invalidate_asset<K: AssetKey>(&self, key: &K) {
let asset_cache_key = AssetCacheKey::new(key.clone());
let full_cache_key: FullCacheKey = asset_cache_key.clone().into();
self.tracer.on_asset_invalidated(&asset_cache_key);
self.pending
.insert::<K>(asset_cache_key.clone(), key.clone());
let _ = self
.whale
.register(full_cache_key, None, Durability::stable(), vec![]);
}
pub fn remove_asset<K: AssetKey>(&self, key: &K) {
let asset_cache_key = AssetCacheKey::new(key.clone());
let full_cache_key: FullCacheKey = asset_cache_key.clone().into();
self.pending.remove(&asset_cache_key);
self.whale.remove(&full_cache_key);
if self.asset_key_registry.remove::<K>(key) {
let sentinel = AssetKeySetSentinelKey::new::<K>().into();
let _ = self
.whale
.register(sentinel, None, Durability::stable(), vec![]);
}
}
    /// Public asset accessor: returns the loading state (ready, loading, or a
    /// cached error), resolving via a registered locator when possible.
    pub fn get_asset<K: AssetKey>(&self, key: K) -> Result<AssetLoadingState<K>, QueryError> {
        self.get_asset_internal(key)
    }
    /// Resolves an asset, returning its loading state plus the revision at
    /// which its cache node last changed.
    ///
    /// Resolution order: (1) a valid, dependency-verified cache hit; (2) a
    /// registered locator for `K`; (3) no locator — park the asset on the
    /// pending list and report `Loading`.
    fn get_asset_with_revision<K: AssetKey>(
        &self,
        key: K,
    ) -> Result<(AssetLoadingState<K>, RevisionCounter), QueryError> {
        let asset_cache_key = AssetCacheKey::new(key.clone());
        let full_cache_key: FullCacheKey = asset_cache_key.clone().into();
        // Open a span parented to the currently active span, if any.
        let asset_span_id = self.tracer.new_span_id();
        let (trace_id, parent_span_id) = SPAN_STACK.with(|stack| match &*stack.borrow() {
            SpanStack::Empty => (self.tracer.new_trace_id(), None),
            SpanStack::Active(tid, spans) => (*tid, spans.last().copied()),
        });
        let span_ctx = SpanContext {
            span_id: asset_span_id,
            trace_id,
            parent_span_id,
        };
        let _span_guard = SpanStackGuard::push(trace_id, asset_span_id);
        // Cache-hit path: taken only when the node is valid AND every
        // dependency's verifier still accepts it.
        if let Some(node) = self.whale.get(&full_cache_key) {
            let changed_at = node.changed_at;
            if self.whale.is_valid(&full_cache_key) {
                let mut deps_verified = true;
                if let Some(deps) = self.whale.get_dependency_ids(&full_cache_key) {
                    for dep in deps {
                        if let Some(verifier) = self.verifiers.get(&dep) {
                            if verifier.verify(self as &dyn std::any::Any).is_err() {
                                deps_verified = false;
                                break;
                            }
                        }
                    }
                }
                // `is_valid` re-checked: running verifiers may have
                // invalidated the node (NOTE(review): inferred from the
                // double check, mirrors `query_internal`).
                if deps_verified && self.whale.is_valid(&full_cache_key) {
                    // Assets without locator dependencies are "leaves";
                    // inside a query, leaf reads must not observe values
                    // newer than the query's start revision.
                    let has_locator_deps = self
                        .whale
                        .get_dependency_ids(&full_cache_key)
                        .is_some_and(|deps| !deps.is_empty());
                    match &node.data {
                        Some(CachedEntry::AssetReady(arc)) => {
                            if !has_locator_deps {
                                check_leaf_asset_consistency(changed_at)?;
                            }
                            self.tracer.on_asset_requested(&span_ctx, &asset_cache_key);
                            self.tracer.on_asset_located(
                                &span_ctx,
                                &asset_cache_key,
                                TracerAssetState::Ready,
                            );
                            match arc.clone().downcast::<K::Asset>() {
                                Ok(value) => {
                                    return Ok((AssetLoadingState::ready(key, value), changed_at))
                                }
                                Err(_) => {
                                    // Entry was stored under this key type;
                                    // a downcast failure is a bug.
                                    unreachable!("Asset type mismatch: {:?}", key)
                                }
                            }
                        }
                        Some(CachedEntry::AssetError(err)) => {
                            if !has_locator_deps {
                                check_leaf_asset_consistency(changed_at)?;
                            }
                            self.tracer.on_asset_requested(&span_ctx, &asset_cache_key);
                            self.tracer.on_asset_located(
                                &span_ctx,
                                &asset_cache_key,
                                TracerAssetState::NotFound,
                            );
                            return Err(QueryError::UserError(err.clone()));
                        }
                        None => {
                            // Node exists but has no data yet: still loading.
                            self.tracer.on_asset_requested(&span_ctx, &asset_cache_key);
                            self.tracer.on_asset_located(
                                &span_ctx,
                                &asset_cache_key,
                                TracerAssetState::Loading,
                            );
                            return Ok((AssetLoadingState::loading(key), changed_at));
                        }
                        _ => {
                            // Non-asset entry under an asset key: fall
                            // through to locator resolution.
                        }
                    }
                }
            }
        }
        // Miss: run the locator (if one is registered), guarding against
        // re-entrant resolution of the same asset.
        check_cycle(&full_cache_key)?;
        let _guard = StackGuard::push(full_cache_key.clone());
        self.tracer.on_asset_requested(&span_ctx, &asset_cache_key);
        let locator_ctx = LocatorContext::new(self, full_cache_key.clone());
        let locator_result =
            self.locators
                .locate_with_locator_ctx(TypeId::of::<K>(), &locator_ctx, &key);
        if let Some(result) = locator_result {
            // Everything the locator read becomes a dependency of this node.
            let locator_deps = locator_ctx.into_deps();
            match result {
                Ok(ErasedLocateResult::Ready {
                    value: arc,
                    durability: durability_level,
                }) => {
                    self.tracer.on_asset_located(
                        &span_ctx,
                        &asset_cache_key,
                        TracerAssetState::Ready,
                    );
                    let typed_value: Arc<K::Asset> = match arc.downcast::<K::Asset>() {
                        Ok(v) => v,
                        Err(_) => {
                            unreachable!("Asset type mismatch: {:?}", key);
                        }
                    };
                    let entry = CachedEntry::AssetReady(typed_value.clone());
                    let durability = Durability::new(durability_level.as_u8() as usize)
                        .unwrap_or(Durability::volatile());
                    let new_value = typed_value.clone();
                    // Early cutoff: only mark changed when the value differs
                    // per `K::asset_eq`.
                    let result = self
                        .whale
                        .update_with_compare(
                            full_cache_key.clone(),
                            Some(entry),
                            |old_data, _new_data| {
                                let Some(CachedEntry::AssetReady(old_arc)) =
                                    old_data.and_then(|d| d.as_ref())
                                else {
                                    return true;
                                };
                                let Ok(old_value) = old_arc.clone().downcast::<K::Asset>() else {
                                    return true;
                                };
                                !K::asset_eq(&old_value, &new_value)
                            },
                            durability,
                            locator_deps,
                        )
                        .expect("update_with_compare should succeed");
                    // Install a verifier so cache hits can re-validate this
                    // locator-resolved value.
                    self.verifiers
                        .insert_asset::<K, T>(full_cache_key, key.clone());
                    return Ok((AssetLoadingState::ready(key, typed_value), result.revision));
                }
                Ok(ErasedLocateResult::Pending) => {
                    self.tracer.on_asset_located(
                        &span_ctx,
                        &asset_cache_key,
                        TracerAssetState::Loading,
                    );
                    self.pending
                        .insert::<K>(asset_cache_key.clone(), key.clone());
                    // `get_or_insert` — NOTE(review): presumably so a value
                    // written between the cache check and here is not
                    // clobbered with `None`; confirm against whale docs.
                    match self
                        .whale
                        .get_or_insert(full_cache_key, None, Durability::volatile(), locator_deps)
                        .expect("get_or_insert should succeed")
                    {
                        GetOrInsertResult::Inserted(node) => {
                            return Ok((AssetLoadingState::loading(key), node.changed_at));
                        }
                        GetOrInsertResult::Existing(node) => {
                            let changed_at = node.changed_at;
                            match &node.data {
                                Some(CachedEntry::AssetReady(arc)) => {
                                    match arc.clone().downcast::<K::Asset>() {
                                        Ok(value) => {
                                            return Ok((
                                                AssetLoadingState::ready(key, value),
                                                changed_at,
                                            ))
                                        }
                                        Err(_) => {
                                            return Ok((
                                                AssetLoadingState::loading(key),
                                                changed_at,
                                            ))
                                        }
                                    }
                                }
                                Some(CachedEntry::AssetError(err)) => {
                                    return Err(QueryError::UserError(err.clone()));
                                }
                                _ => return Ok((AssetLoadingState::loading(key), changed_at)),
                            }
                        }
                    }
                }
                Err(QueryError::UserError(err)) => {
                    // Locator failed: cache the error so dependents see it.
                    self.tracer.on_asset_located(
                        &span_ctx,
                        &asset_cache_key,
                        TracerAssetState::NotFound,
                    );
                    let entry = CachedEntry::AssetError(err.clone());
                    let _ = self.whale.register(
                        full_cache_key,
                        Some(entry),
                        Durability::volatile(),
                        locator_deps,
                    );
                    return Err(QueryError::UserError(err));
                }
                Err(e) => {
                    // Cycle/suspend/internal errors propagate uncached.
                    return Err(e);
                }
            }
        }
        // No locator for `K`: the asset must be resolved externally through
        // `resolve_asset`, so park it on the pending list.
        self.tracer
            .on_asset_located(&span_ctx, &asset_cache_key, TracerAssetState::Loading);
        self.pending
            .insert::<K>(asset_cache_key.clone(), key.clone());
        match self
            .whale
            .get_or_insert(full_cache_key, None, Durability::volatile(), vec![])
            .expect("get_or_insert with no dependencies cannot fail")
        {
            GetOrInsertResult::Inserted(node) => {
                Ok((AssetLoadingState::loading(key), node.changed_at))
            }
            GetOrInsertResult::Existing(node) => {
                let changed_at = node.changed_at;
                match &node.data {
                    Some(CachedEntry::AssetReady(arc)) => {
                        match arc.clone().downcast::<K::Asset>() {
                            Ok(value) => Ok((AssetLoadingState::ready(key, value), changed_at)),
                            Err(_) => Ok((AssetLoadingState::loading(key), changed_at)),
                        }
                    }
                    Some(CachedEntry::AssetError(err)) => Err(QueryError::UserError(err.clone())),
                    _ => Ok((AssetLoadingState::loading(key), changed_at)),
                }
            }
        }
    }
fn get_asset_with_revision_ctx<K: AssetKey>(
&self,
key: K,
_ctx: &QueryContext<'_, T>,
) -> Result<(AssetLoadingState<K>, RevisionCounter), QueryError> {
let asset_cache_key = AssetCacheKey::new(key.clone());
let full_cache_key: FullCacheKey = asset_cache_key.clone().into();
let asset_span_id = self.tracer.new_span_id();
let (trace_id, parent_span_id) = SPAN_STACK.with(|stack| match &*stack.borrow() {
SpanStack::Empty => (self.tracer.new_trace_id(), None),
SpanStack::Active(tid, spans) => (*tid, spans.last().copied()),
});
let span_ctx = SpanContext {
span_id: asset_span_id,
trace_id,
parent_span_id,
};
let _span_guard = SpanStackGuard::push(trace_id, asset_span_id);
if let Some(node) = self.whale.get(&full_cache_key) {
let changed_at = node.changed_at;
if self.whale.is_valid(&full_cache_key) {
let mut deps_verified = true;
if let Some(deps) = self.whale.get_dependency_ids(&full_cache_key) {
for dep in deps {
if let Some(verifier) = self.verifiers.get(&dep) {
if verifier.verify(self as &dyn std::any::Any).is_err() {
deps_verified = false;
break;
}
}
}
}
if deps_verified && self.whale.is_valid(&full_cache_key) {
let has_locator_deps = self
.whale
.get_dependency_ids(&full_cache_key)
.is_some_and(|deps| !deps.is_empty());
match &node.data {
Some(CachedEntry::AssetReady(arc)) => {
if !has_locator_deps {
check_leaf_asset_consistency(changed_at)?;
}
self.tracer.on_asset_requested(&span_ctx, &asset_cache_key);
self.tracer.on_asset_located(
&span_ctx,
&asset_cache_key,
TracerAssetState::Ready,
);
match arc.clone().downcast::<K::Asset>() {
Ok(value) => {
return Ok((AssetLoadingState::ready(key, value), changed_at))
}
Err(_) => {
unreachable!("Asset type mismatch: {:?}", key)
}
}
}
Some(CachedEntry::AssetError(err)) => {
if !has_locator_deps {
check_leaf_asset_consistency(changed_at)?;
}
self.tracer.on_asset_requested(&span_ctx, &asset_cache_key);
self.tracer.on_asset_located(
&span_ctx,
&asset_cache_key,
TracerAssetState::NotFound,
);
return Err(QueryError::UserError(err.clone()));
}
None => {
self.tracer.on_asset_requested(&span_ctx, &asset_cache_key);
self.tracer.on_asset_located(
&span_ctx,
&asset_cache_key,
TracerAssetState::Loading,
);
return Ok((AssetLoadingState::loading(key), changed_at));
}
_ => {
}
}
}
}
}
check_cycle(&full_cache_key)?;
let _guard = StackGuard::push(full_cache_key.clone());
self.tracer.on_asset_requested(&span_ctx, &asset_cache_key);
let locator_ctx = LocatorContext::new(self, full_cache_key.clone());
let locator_result =
self.locators
.locate_with_locator_ctx(TypeId::of::<K>(), &locator_ctx, &key);
if let Some(result) = locator_result {
let locator_deps = locator_ctx.into_deps();
match result {
Ok(ErasedLocateResult::Ready {
value: arc,
durability: durability_level,
}) => {
self.tracer.on_asset_located(
&span_ctx,
&asset_cache_key,
TracerAssetState::Ready,
);
let typed_value: Arc<K::Asset> = match arc.downcast::<K::Asset>() {
Ok(v) => v,
Err(_) => {
unreachable!("Asset type mismatch: {:?}", key);
}
};
let entry = CachedEntry::AssetReady(typed_value.clone());
let durability = Durability::new(durability_level.as_u8() as usize)
.unwrap_or(Durability::volatile());
let new_value = typed_value.clone();
let result = self
.whale
.update_with_compare(
full_cache_key.clone(),
Some(entry),
|old_data, _new_data| {
let Some(CachedEntry::AssetReady(old_arc)) =
old_data.and_then(|d| d.as_ref())
else {
return true;
};
let Ok(old_value) = old_arc.clone().downcast::<K::Asset>() else {
return true;
};
!K::asset_eq(&old_value, &new_value)
},
durability,
locator_deps,
)
.expect("update_with_compare should succeed");
self.verifiers
.insert_asset::<K, T>(full_cache_key, key.clone());
return Ok((AssetLoadingState::ready(key, typed_value), result.revision));
}
Ok(ErasedLocateResult::Pending) => {
self.tracer.on_asset_located(
&span_ctx,
&asset_cache_key,
TracerAssetState::Loading,
);
self.pending
.insert::<K>(asset_cache_key.clone(), key.clone());
match self
.whale
.get_or_insert(full_cache_key, None, Durability::volatile(), locator_deps)
.expect("get_or_insert should succeed")
{
GetOrInsertResult::Inserted(node) => {
return Ok((AssetLoadingState::loading(key), node.changed_at));
}
GetOrInsertResult::Existing(node) => {
let changed_at = node.changed_at;
match &node.data {
Some(CachedEntry::AssetReady(arc)) => {
match arc.clone().downcast::<K::Asset>() {
Ok(value) => {
return Ok((
AssetLoadingState::ready(key, value),
changed_at,
));
}
Err(_) => {
return Ok((
AssetLoadingState::loading(key),
changed_at,
))
}
}
}
Some(CachedEntry::AssetError(err)) => {
return Err(QueryError::UserError(err.clone()));
}
_ => return Ok((AssetLoadingState::loading(key), changed_at)),
}
}
}
}
Err(QueryError::UserError(err)) => {
self.tracer.on_asset_located(
&span_ctx,
&asset_cache_key,
TracerAssetState::NotFound,
);
let entry = CachedEntry::AssetError(err.clone());
let _ = self.whale.register(
full_cache_key,
Some(entry),
Durability::volatile(),
locator_deps,
);
return Err(QueryError::UserError(err));
}
Err(e) => {
return Err(e);
}
}
}
self.tracer
.on_asset_located(&span_ctx, &asset_cache_key, TracerAssetState::Loading);
self.pending
.insert::<K>(asset_cache_key.clone(), key.clone());
match self
.whale
.get_or_insert(full_cache_key, None, Durability::volatile(), vec![])
.expect("get_or_insert with no dependencies cannot fail")
{
GetOrInsertResult::Inserted(node) => {
Ok((AssetLoadingState::loading(key), node.changed_at))
}
GetOrInsertResult::Existing(node) => {
let changed_at = node.changed_at;
match &node.data {
Some(CachedEntry::AssetReady(arc)) => {
match arc.clone().downcast::<K::Asset>() {
Ok(value) => Ok((AssetLoadingState::ready(key, value), changed_at)),
Err(_) => Ok((AssetLoadingState::loading(key), changed_at)),
}
}
Some(CachedEntry::AssetError(err)) => Err(QueryError::UserError(err.clone())),
_ => Ok((AssetLoadingState::loading(key), changed_at)),
}
}
}
}
    /// Convenience wrapper over [`get_asset_with_revision`] that discards the
    /// revision.
    fn get_asset_internal<K: AssetKey>(&self, key: K) -> Result<AssetLoadingState<K>, QueryError> {
        self.get_asset_with_revision(key).map(|(state, _)| state)
    }
}
// The runtime itself acts as a `Db` for top-level (non-query) callers. Unlike
// `QueryContext`, no dependency tracking happens here.
impl<T: Tracer> Db for QueryRuntime<T> {
    fn query<Q: Query>(&self, query: Q) -> Result<Arc<Q::Output>, QueryError> {
        QueryRuntime::query(self, query)
    }
    // Converts a non-ready state into an error via `suspend()`.
    // NOTE(review): exact suspend semantics live in
    // `AssetLoadingState::suspend`, not visible in this file.
    fn asset<K: AssetKey>(&self, key: K) -> Result<Arc<K::Asset>, QueryError> {
        self.get_asset_internal(key)?.suspend()
    }
    fn asset_state<K: AssetKey>(&self, key: K) -> Result<AssetLoadingState<K>, QueryError> {
        self.get_asset_internal(key)
    }
    fn list_queries<Q: Query>(&self) -> Vec<Q> {
        self.query_registry.get_all::<Q>()
    }
    fn list_asset_keys<K: AssetKey>(&self) -> Vec<K> {
        self.asset_key_registry.get_all::<K>()
    }
}
/// Records the revision at which a query began executing so leaf-asset reads
/// during that execution can be checked for consistency.
#[derive(Debug)]
pub(crate) struct ConsistencyTracker {
    start_revision: RevisionCounter,
}
impl ConsistencyTracker {
    /// Creates a tracker pinned to the revision observed when the query
    /// started.
    pub fn new(start_revision: RevisionCounter) -> Self {
        ConsistencyTracker { start_revision }
    }
    /// Fails when a leaf asset changed after this query started executing —
    /// the query's view of the world would be inconsistent.
    pub fn check_leaf_asset(&self, dep_changed_at: RevisionCounter) -> Result<(), QueryError> {
        if dep_changed_at <= self.start_revision {
            Ok(())
        } else {
            Err(QueryError::InconsistentAssetResolution)
        }
    }
}
/// `Db` implementation handed to an executing query: forwards calls to the
/// runtime while recording every accessed key for dependency tracking.
pub(crate) struct QueryContext<'a, T: Tracer = NoopTracer> {
    runtime: &'a QueryRuntime<T>,
    // Key of the query currently executing (the dependent).
    current_key: FullCacheKey,
    // Span context used to attribute tracer events to this execution.
    exec_ctx: ExecutionContext,
    // Keys read during execution; harvested by `execute_query`.
    deps: RefCell<Vec<FullCacheKey>>,
}
impl<'a, T: Tracer> QueryContext<'a, T> {
    /// Ensures a sentinel node exists in the graph so it can be depended on.
    /// Fix: `list_queries` and `list_asset_keys` previously duplicated this
    /// existence-check-plus-register sequence inline.
    fn ensure_sentinel(&self, sentinel: &FullCacheKey) {
        if self.runtime.whale.get(sentinel).is_none() {
            let _ = self.runtime.whale.register(
                sentinel.clone(),
                None,
                Durability::volatile(),
                vec![],
            );
        }
    }
    /// Runs a nested query, recording it as a dependency of the current one.
    pub fn query<Q: Query>(&self, query: Q) -> Result<Arc<Q::Output>, QueryError> {
        let full_key: FullCacheKey = QueryCacheKey::new(query.clone()).into();
        self.runtime.tracer.on_dependency_registered(
            self.exec_ctx.span_ctx(),
            &self.current_key,
            &full_key,
        );
        self.deps.borrow_mut().push(full_key);
        self.runtime.query(query)
    }
    /// Loads an asset, suspending the query when it is not ready yet.
    pub fn asset<K: AssetKey>(&self, key: K) -> Result<Arc<K::Asset>, QueryError> {
        self.asset_state(key)?.suspend()
    }
    /// Returns the asset's loading state, recording it as a dependency.
    pub fn asset_state<K: AssetKey>(&self, key: K) -> Result<AssetLoadingState<K>, QueryError> {
        let full_cache_key: FullCacheKey = AssetCacheKey::new(key.clone()).into();
        self.runtime.tracer.on_asset_dependency_registered(
            self.exec_ctx.span_ctx(),
            &self.current_key,
            &full_cache_key,
        );
        self.deps.borrow_mut().push(full_cache_key);
        let (state, _changed_at) = self.runtime.get_asset_with_revision_ctx(key, self)?;
        Ok(state)
    }
    /// Lists all registered queries of type `Q`, depending on the query-set
    /// sentinel so this query re-runs when the set changes.
    pub fn list_queries<Q: Query>(&self) -> Vec<Q> {
        let sentinel: FullCacheKey = QuerySetSentinelKey::new::<Q>().into();
        self.runtime.tracer.on_dependency_registered(
            self.exec_ctx.span_ctx(),
            &self.current_key,
            &sentinel,
        );
        self.ensure_sentinel(&sentinel);
        self.deps.borrow_mut().push(sentinel);
        self.runtime.query_registry.get_all::<Q>()
    }
    /// Lists all registered asset keys of type `K`, depending on the
    /// asset-key-set sentinel so this query re-runs when the set changes.
    pub fn list_asset_keys<K: AssetKey>(&self) -> Vec<K> {
        let sentinel: FullCacheKey = AssetKeySetSentinelKey::new::<K>().into();
        self.runtime.tracer.on_asset_dependency_registered(
            self.exec_ctx.span_ctx(),
            &self.current_key,
            &sentinel,
        );
        self.ensure_sentinel(&sentinel);
        self.deps.borrow_mut().push(sentinel);
        self.runtime.asset_key_registry.get_all::<K>()
    }
}
/// `Db` trait surface for `QueryContext`, forwarding to the inherent
/// methods. Method-call syntax resolves to the inherent impls (inherent
/// methods take precedence over trait methods), so nothing here recurses.
impl<'a, T: Tracer> Db for QueryContext<'a, T> {
    fn query<Q: Query>(&self, query: Q) -> Result<Arc<Q::Output>, QueryError> {
        self.query(query)
    }
    fn asset<K: AssetKey>(&self, key: K) -> Result<Arc<K::Asset>, QueryError> {
        self.asset(key)
    }
    fn asset_state<K: AssetKey>(&self, key: K) -> Result<AssetLoadingState<K>, QueryError> {
        self.asset_state(key)
    }
    fn list_queries<Q: Query>(&self) -> Vec<Q> {
        self.list_queries()
    }
    fn list_asset_keys<K: AssetKey>(&self) -> Vec<K> {
        self.list_asset_keys()
    }
}
/// Dependency-recording context passed to asset locators. Unlike
/// `QueryContext` it carries no current key or tracing span; it only
/// accumulates the cache keys touched while locating an asset.
pub(crate) struct LocatorContext<'a, T: Tracer> {
    // Runtime used to service nested query/asset reads.
    runtime: &'a QueryRuntime<T>,
    // Dependencies observed during location, in access order.
    deps: RefCell<Vec<FullCacheKey>>,
}

impl<'a, T: Tracer> LocatorContext<'a, T> {
    /// Build a fresh context with an empty dependency list.
    // NOTE(review): `_asset_key` is accepted but unused — presumably kept
    // for signature parity or future tracing; confirm before removing.
    pub(crate) fn new(runtime: &'a QueryRuntime<T>, _asset_key: FullCacheKey) -> Self {
        LocatorContext {
            runtime,
            deps: RefCell::default(),
        }
    }

    /// Consume the context and hand back the recorded dependency list.
    pub(crate) fn into_deps(self) -> Vec<FullCacheKey> {
        self.deps.into_inner()
    }
}
impl<T: Tracer> Db for LocatorContext<'_, T> {
fn query<Q: Query>(&self, query: Q) -> Result<Arc<Q::Output>, QueryError> {
let full_key = QueryCacheKey::new(query.clone()).into();
self.deps.borrow_mut().push(full_key);
self.runtime.query(query)
}
fn asset<K: AssetKey>(&self, key: K) -> Result<Arc<K::Asset>, QueryError> {
self.asset_state(key)?.suspend()
}
fn asset_state<K: AssetKey>(&self, key: K) -> Result<AssetLoadingState<K>, QueryError> {
let full_cache_key = AssetCacheKey::new(key.clone()).into();
self.deps.borrow_mut().push(full_cache_key);
let (state, _changed_at) = self.runtime.get_asset_with_revision(key)?;
Ok(state)
}
fn list_queries<Q: Query>(&self) -> Vec<Q> {
self.runtime.list_queries()
}
fn list_asset_keys<K: AssetKey>(&self) -> Vec<K> {
self.runtime.list_asset_keys()
}
}
/// Dispatch wrapper over the two `Db`-implementing context types, letting
/// callers hold "some context" without a generic parameter per context.
pub(crate) enum DbDispatch<'a, T: Tracer = NoopTracer> {
// Executing inside a query; dependencies are traced and recorded.
QueryContext(&'a QueryContext<'a, T>),
// Executing inside an asset locator; dependencies are recorded only.
LocatorContext(&'a LocatorContext<'a, T>),
}
/// `Db` for `DbDispatch`: every operation forwards verbatim to whichever
/// wrapped context variant is present.
impl<T: Tracer> Db for DbDispatch<'_, T> {
    fn query<Q: Query>(&self, query: Q) -> Result<Arc<Q::Output>, QueryError> {
        match self {
            Self::QueryContext(inner) => inner.query(query),
            Self::LocatorContext(inner) => inner.query(query),
        }
    }

    fn asset<K: AssetKey>(&self, key: K) -> Result<Arc<K::Asset>, QueryError> {
        match self {
            Self::QueryContext(inner) => inner.asset(key),
            Self::LocatorContext(inner) => inner.asset(key),
        }
    }

    fn asset_state<K: AssetKey>(&self, key: K) -> Result<AssetLoadingState<K>, QueryError> {
        match self {
            Self::QueryContext(inner) => inner.asset_state(key),
            Self::LocatorContext(inner) => inner.asset_state(key),
        }
    }

    fn list_queries<Q: Query>(&self) -> Vec<Q> {
        match self {
            Self::QueryContext(inner) => inner.list_queries(),
            Self::LocatorContext(inner) => inner.list_queries(),
        }
    }

    fn list_asset_keys<K: AssetKey>(&self) -> Vec<K> {
        match self {
            Self::QueryContext(inner) => inner.list_asset_keys(),
            Self::LocatorContext(inner) => inner.list_asset_keys(),
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_simple_query() {
    // Minimal query: sums its two fields.
    #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    struct Add {
        a: i32,
        b: i32,
    }
    impl Query for Add {
        type Output = i32;
        fn query(self, _db: &impl Db) -> Result<Arc<Self::Output>, QueryError> {
            Ok(Arc::new(self.a + self.b))
        }
        fn output_eq(old: &Self::Output, new: &Self::Output) -> bool {
            old == new
        }
    }
    let runtime = QueryRuntime::new();
    // First evaluation computes the value...
    assert_eq!(*runtime.query(Add { a: 1, b: 2 }).unwrap(), 3);
    // ...and a repeat call with the same key yields the same result.
    assert_eq!(*runtime.query(Add { a: 1, b: 2 }).unwrap(), 3);
}
#[test]
fn test_dependent_queries() {
    // Leaf query: doubles its input.
    #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    struct Base {
        value: i32,
    }
    impl Query for Base {
        type Output = i32;
        fn query(self, _db: &impl Db) -> Result<Arc<Self::Output>, QueryError> {
            Ok(Arc::new(self.value * 2))
        }
        fn output_eq(old: &Self::Output, new: &Self::Output) -> bool {
            old == new
        }
    }
    // Derived query: reads `Base` through the db handle.
    #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    struct Derived {
        base_value: i32,
    }
    impl Query for Derived {
        type Output = i32;
        fn query(self, db: &impl Db) -> Result<Arc<Self::Output>, QueryError> {
            let doubled = db.query(Base {
                value: self.base_value,
            })?;
            Ok(Arc::new(*doubled + 10))
        }
        fn output_eq(old: &Self::Output, new: &Self::Output) -> bool {
            old == new
        }
    }
    // 5 * 2 + 10 == 20, computed through the dependency chain.
    let runtime = QueryRuntime::new();
    let value = runtime.query(Derived { base_value: 5 }).unwrap();
    assert_eq!(*value, 20);
}
#[test]
fn test_cycle_detection() {
    // Two queries that invoke each other, forming a dependency cycle.
    #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    struct CycleA {
        id: i32,
    }
    #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    struct CycleB {
        id: i32,
    }
    impl Query for CycleA {
        type Output = i32;
        fn query(self, db: &impl Db) -> Result<Arc<Self::Output>, QueryError> {
            let other = db.query(CycleB { id: self.id })?;
            Ok(Arc::new(*other + 1))
        }
        fn output_eq(old: &Self::Output, new: &Self::Output) -> bool {
            old == new
        }
    }
    impl Query for CycleB {
        type Output = i32;
        fn query(self, db: &impl Db) -> Result<Arc<Self::Output>, QueryError> {
            let other = db.query(CycleA { id: self.id })?;
            Ok(Arc::new(*other + 1))
        }
        fn output_eq(old: &Self::Output, new: &Self::Output) -> bool {
            old == new
        }
    }
    // The runtime must detect A -> B -> A and report a Cycle error
    // instead of recursing forever.
    let runtime = QueryRuntime::new();
    let outcome = runtime.query(CycleA { id: 1 });
    assert!(matches!(outcome, Err(QueryError::Cycle { .. })));
}
#[test]
fn test_fallible_query() {
    // A query whose *output* is itself a Result: domain-level parse
    // failures are cached values, not `QueryError`s.
    #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    struct ParseInt {
        input: String,
    }
    impl Query for ParseInt {
        type Output = Result<i32, std::num::ParseIntError>;
        fn query(self, _db: &impl Db) -> Result<Arc<Self::Output>, QueryError> {
            Ok(Arc::new(self.input.parse()))
        }
        fn output_eq(old: &Self::Output, new: &Self::Output) -> bool {
            old == new
        }
    }
    let runtime = QueryRuntime::new();
    // Valid input parses successfully.
    let parsed = runtime
        .query(ParseInt {
            input: "42".to_string(),
        })
        .unwrap();
    assert_eq!(*parsed, Ok(42));
    // Invalid input still "succeeds" at the query level: the cached
    // output carries the parse error.
    let failed = runtime
        .query(ParseInt {
            input: "not_a_number".to_string(),
        })
        .unwrap();
    assert!(failed.is_err());
}
// Tests for the `#[query]` proc macro, which turns a free function into a
// query type (UpperCamelCase of the fn name) with a `new(...)` constructor.
mod macro_tests {
use super::*;
use crate::query;
// Basic expansion: `add` becomes an `Add` query type.
#[query]
fn add(db: &impl Db, a: i32, b: i32) -> Result<i32, QueryError> {
let _ = db; Ok(a + b)
}
#[test]
fn test_macro_basic() {
let runtime = QueryRuntime::new();
let result = runtime.query(Add::new(1, 2)).unwrap();
assert_eq!(*result, 3);
}
// snake_case fn name maps to an UpperCamelCase query type.
#[query]
fn simple_double(db: &impl Db, x: i32) -> Result<i32, QueryError> {
let _ = db;
Ok(x * 2)
}
#[test]
fn test_macro_simple() {
let runtime = QueryRuntime::new();
let result = runtime.query(SimpleDouble::new(5)).unwrap();
assert_eq!(*result, 10);
}
// `keys(id)` restricts the cache key to `id` alone.
#[query(keys(id))]
fn with_key_selection(
db: &impl Db,
id: u32,
include_extra: bool,
) -> Result<String, QueryError> {
let _ = db;
Ok(format!("id={}, extra={}", id, include_extra))
}
#[test]
fn test_macro_key_selection() {
let runtime = QueryRuntime::new();
let r1 = runtime.query(WithKeySelection::new(1, true)).unwrap();
let r2 = runtime.query(WithKeySelection::new(1, false)).unwrap();
assert_eq!(*r1, "id=1, extra=true");
// Same key (id=1), so the second call returns the cached result even
// though `include_extra` differs.
assert_eq!(*r2, "id=1, extra=true"); }
// Macro-generated queries can depend on other macro-generated queries.
#[query]
fn dependent(db: &impl Db, a: i32, b: i32) -> Result<i32, QueryError> {
let sum = db.query(Add::new(a, b))?;
Ok(*sum * 2)
}
#[test]
fn test_macro_dependencies() {
let runtime = QueryRuntime::new();
let result = runtime.query(Dependent::new(3, 4)).unwrap();
// (3 + 4) * 2 == 14.
assert_eq!(*result, 14); }
// `output_eq` option — presumably wires up output comparison; this test
// only verifies the query still compiles and runs. TODO confirm option
// semantics against the macro docs.
#[query(output_eq)]
fn with_output_eq(db: &impl Db, x: i32) -> Result<i32, QueryError> {
let _ = db;
Ok(x * 2)
}
#[test]
fn test_macro_output_eq() {
let runtime = QueryRuntime::new();
let result = runtime.query(WithOutputEq::new(5)).unwrap();
assert_eq!(*result, 10);
}
// `name = "..."` overrides the generated type's name.
#[query(name = "CustomName")]
fn original_name(db: &impl Db, x: i32) -> Result<i32, QueryError> {
let _ = db;
Ok(x)
}
#[test]
fn test_macro_custom_name() {
let runtime = QueryRuntime::new();
let result = runtime.query(CustomName::new(42)).unwrap();
assert_eq!(*result, 42);
}
// Attributes written alongside `#[query]` must survive expansion.
#[allow(unused_variables)]
#[inline]
#[query]
fn with_attributes(db: &impl Db, x: i32) -> Result<i32, QueryError> {
let unused_var = 42;
Ok(x * 2)
}
#[test]
fn test_macro_preserves_attributes() {
let runtime = QueryRuntime::new();
let result = runtime.query(WithAttributes::new(5)).unwrap();
assert_eq!(*result, 10);
}
}
// Tests for `poll`/`changed_at`: `poll` returns the cached value together
// with the revision at which it last changed.
mod poll_tests {
use super::*;
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
struct Counter {
id: i32,
}
impl Query for Counter {
type Output = i32;
fn query(self, _db: &impl Db) -> Result<Arc<Self::Output>, QueryError> {
Ok(Arc::new(self.id * 10))
}
fn output_eq(old: &Self::Output, new: &Self::Output) -> bool {
old == new
}
}
#[test]
fn test_poll_returns_value_and_revision() {
let runtime = QueryRuntime::new();
let result = runtime.poll(Counter { id: 1 }).unwrap();
assert_eq!(**result.value.as_ref().unwrap(), 10);
// A computed value always carries a revision above zero.
assert!(result.revision > 0);
}
#[test]
fn test_poll_revision_stable_on_cache_hit() {
let runtime = QueryRuntime::new();
let result1 = runtime.poll(Counter { id: 1 }).unwrap();
let rev1 = result1.revision;
let result2 = runtime.poll(Counter { id: 1 }).unwrap();
let rev2 = result2.revision;
// A cache hit must not bump the revision.
assert_eq!(rev1, rev2);
}
#[test]
fn test_poll_revision_changes_on_invalidate() {
let runtime = QueryRuntime::new();
let result1 = runtime.poll(Counter { id: 1 }).unwrap();
let rev1 = result1.revision;
runtime.invalidate(&Counter { id: 1 });
let result2 = runtime.poll(Counter { id: 1 }).unwrap();
let rev2 = result2.revision;
assert_eq!(**result2.value.as_ref().unwrap(), 10);
// Recomputation yields an equal value, so the revision is only
// required not to go backwards.
assert!(rev2 >= rev1);
}
#[test]
fn test_changed_at_returns_none_for_unexecuted_query() {
let runtime = QueryRuntime::new();
let rev = runtime.changed_at(&Counter { id: 1 });
// Nothing cached yet => no revision to report.
assert!(rev.is_none());
}
#[test]
fn test_changed_at_returns_revision_after_execution() {
let runtime = QueryRuntime::new();
let _ = runtime.query(Counter { id: 1 }).unwrap();
let rev = runtime.changed_at(&Counter { id: 1 });
assert!(rev.is_some());
assert!(rev.unwrap() > 0);
}
#[test]
fn test_changed_at_matches_poll_revision() {
let runtime = QueryRuntime::new();
let result = runtime.poll(Counter { id: 1 }).unwrap();
let rev = runtime.changed_at(&Counter { id: 1 });
// `changed_at` and `poll` must agree on the revision.
assert_eq!(rev, Some(result.revision));
}
#[test]
fn test_poll_value_access() {
let runtime = QueryRuntime::new();
let result = runtime.poll(Counter { id: 5 }).unwrap();
// The polled value can be borrowed either as the inner &i32 or as the
// &Arc<i32> wrapper, as both annotations below demonstrate.
let value: &i32 = result.value.as_ref().unwrap();
assert_eq!(*value, 50);
let arc: &Arc<i32> = result.value.as_ref().unwrap();
assert_eq!(**arc, 50);
}
#[test]
fn test_subscription_pattern() {
// Simulates a subscriber that re-polls and reacts only when the
// revision advances: an unchanged value notifies exactly once.
let runtime = QueryRuntime::new();
let mut last_revision: RevisionCounter = 0;
let mut notifications = 0;
let result = runtime.poll(Counter { id: 1 }).unwrap();
if result.revision > last_revision {
notifications += 1;
last_revision = result.revision;
}
let result = runtime.poll(Counter { id: 1 }).unwrap();
if result.revision > last_revision {
notifications += 1;
last_revision = result.revision;
}
let result = runtime.poll(Counter { id: 1 }).unwrap();
if result.revision > last_revision {
notifications += 1;
// Final write is never read; silenced for the pattern's symmetry.
#[allow(unused_assignments)]
{
last_revision = result.revision;
}
}
assert_eq!(notifications, 1);
}
}
// Tests for cache eviction (`remove`, `remove_if_unused`) and the tracer
// access hooks that a GC policy would use to pick eviction candidates.
mod gc_tests {
use super::*;
use crate::tracer::{SpanContext, SpanId, TraceId};
use std::collections::HashSet;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Mutex;
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
struct Leaf {
id: i32,
}
impl Query for Leaf {
type Output = i32;
fn query(self, _db: &impl Db) -> Result<Arc<Self::Output>, QueryError> {
Ok(Arc::new(self.id * 10))
}
fn output_eq(old: &Self::Output, new: &Self::Output) -> bool {
old == new
}
}
// Parent depends on Leaf, giving Leaf a dependent edge.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
struct Parent {
child_id: i32,
}
impl Query for Parent {
type Output = i32;
fn query(self, db: &impl Db) -> Result<Arc<Self::Output>, QueryError> {
let child = db.query(Leaf { id: self.child_id })?;
Ok(Arc::new(*child + 1))
}
fn output_eq(old: &Self::Output, new: &Self::Output) -> bool {
old == new
}
}
#[test]
fn test_query_keys_returns_all_cached_queries() {
let runtime = QueryRuntime::new();
let _ = runtime.query(Leaf { id: 1 }).unwrap();
let _ = runtime.query(Leaf { id: 2 }).unwrap();
let _ = runtime.query(Leaf { id: 3 }).unwrap();
let keys = runtime.query_keys();
// At least the three Leaf entries; `>=` leaves room for any
// bookkeeping keys the runtime may also cache.
assert!(keys.len() >= 3);
}
#[test]
fn test_remove_removes_query() {
let runtime = QueryRuntime::new();
let _ = runtime.query(Leaf { id: 1 }).unwrap();
let full_key = QueryCacheKey::new(Leaf { id: 1 }).into();
assert!(runtime.changed_at(&Leaf { id: 1 }).is_some());
// Unconditional removal evicts the entry regardless of dependents.
assert!(runtime.remove(&full_key));
assert!(runtime.changed_at(&Leaf { id: 1 }).is_none());
}
#[test]
fn test_remove_if_unused_removes_leaf_query() {
let runtime = QueryRuntime::new();
let _ = runtime.query(Leaf { id: 1 }).unwrap();
// No dependents => eligible for conditional removal.
assert!(runtime.remove_query_if_unused(&Leaf { id: 1 }));
assert!(runtime.changed_at(&Leaf { id: 1 }).is_none());
}
#[test]
fn test_remove_if_unused_does_not_remove_query_with_dependents() {
let runtime = QueryRuntime::new();
let _ = runtime.query(Parent { child_id: 1 }).unwrap();
// Leaf is still referenced by Parent, so it must survive...
assert!(!runtime.remove_query_if_unused(&Leaf { id: 1 }));
assert!(runtime.changed_at(&Leaf { id: 1 }).is_some());
// ...while Parent has no dependents and can be removed.
assert!(runtime.remove_query_if_unused(&Parent { child_id: 1 }));
}
#[test]
fn test_remove_if_unused_with_full_cache_key() {
let runtime = QueryRuntime::new();
let _ = runtime.query(Leaf { id: 1 }).unwrap();
let full_key = QueryCacheKey::new(Leaf { id: 1 }).into();
assert!(runtime.remove_if_unused(&full_key));
assert!(runtime.changed_at(&Leaf { id: 1 }).is_none());
}
// Tracer that records which queries were started and how many times —
// the access data an eviction policy would need.
struct GcTracker {
accessed_keys: Mutex<HashSet<String>>,
access_count: AtomicUsize,
}
impl GcTracker {
fn new() -> Self {
Self {
accessed_keys: Mutex::new(HashSet::new()),
access_count: AtomicUsize::new(0),
}
}
}
impl Tracer for GcTracker {
fn new_span_id(&self) -> SpanId {
SpanId(1)
}
fn new_trace_id(&self) -> TraceId {
TraceId(1)
}
// Records the key's debug representation and bumps the access counter.
fn on_query_start(&self, _ctx: &SpanContext, query_key: &QueryCacheKey) {
self.accessed_keys
.lock()
.unwrap()
.insert(query_key.debug_repr().to_string());
self.access_count.fetch_add(1, Ordering::Relaxed);
}
}
#[test]
fn test_tracer_receives_on_query_start() {
let tracker = GcTracker::new();
let runtime = QueryRuntime::with_tracer(tracker);
let _ = runtime.query(Leaf { id: 1 }).unwrap();
let _ = runtime.query(Leaf { id: 2 }).unwrap();
let count = runtime.tracer().access_count.load(Ordering::Relaxed);
assert_eq!(count, 2);
let keys = runtime.tracer().accessed_keys.lock().unwrap();
assert!(keys.iter().any(|k| k.contains("Leaf")));
}
#[test]
fn test_tracer_receives_on_query_start_for_cache_hits() {
let tracker = GcTracker::new();
let runtime = QueryRuntime::with_tracer(tracker);
let _ = runtime.query(Leaf { id: 1 }).unwrap();
let _ = runtime.query(Leaf { id: 1 }).unwrap();
let count = runtime.tracer().access_count.load(Ordering::Relaxed);
// The hook fires on cache hits too, so access tracking sees every read.
assert_eq!(count, 2);
}
#[test]
fn test_gc_workflow() {
// End-to-end sweep: enumerate keys, drop everything without dependents.
let tracker = GcTracker::new();
let runtime = QueryRuntime::with_tracer(tracker);
let _ = runtime.query(Leaf { id: 1 }).unwrap();
let _ = runtime.query(Leaf { id: 2 }).unwrap();
let _ = runtime.query(Leaf { id: 3 }).unwrap();
let mut removed = 0;
for key in runtime.query_keys() {
if runtime.remove_if_unused(&key) {
removed += 1;
}
}
assert!(removed >= 3);
assert!(runtime.changed_at(&Leaf { id: 1 }).is_none());
assert!(runtime.changed_at(&Leaf { id: 2 }).is_none());
assert!(runtime.changed_at(&Leaf { id: 3 }).is_none());
}
}
}