1use std::any::TypeId;
4use std::cell::RefCell;
5use std::ops::Deref;
6use std::sync::Arc;
7
8use whale::{Durability, RevisionCounter, Runtime as WhaleRuntime};
9
10use crate::asset::{AssetKey, AssetLocator, DurabilityLevel, FullAssetKey, PendingAsset};
11use crate::db::Db;
12use crate::key::FullCacheKey;
13use crate::loading::AssetLoadingState;
14use crate::query::Query;
15use crate::storage::{
16 AssetKeyRegistry, AssetState, AssetStorage, CachedEntry, CachedValue, LocatorStorage,
17 PendingStorage, QueryRegistry, VerifierStorage,
18};
19use crate::tracer::{
20 ExecutionResult, InvalidationReason, NoopTracer, SpanId, Tracer, TracerAssetKey,
21 TracerAssetState, TracerQueryKey,
22};
23use crate::QueryError;
24
25pub type ErrorComparator = fn(&anyhow::Error, &anyhow::Error) -> bool;
30
31const DURABILITY_LEVELS: usize = 4;
33
34thread_local! {
36 static QUERY_STACK: RefCell<Vec<FullCacheKey>> = const { RefCell::new(Vec::new()) };
37}
38
39#[derive(Clone, Copy)]
43pub struct ExecutionContext {
44 span_id: SpanId,
45}
46
47impl ExecutionContext {
48 #[inline]
50 pub fn new(span_id: SpanId) -> Self {
51 Self { span_id }
52 }
53
54 #[inline]
56 pub fn span_id(&self) -> SpanId {
57 self.span_id
58 }
59}
60
61#[derive(Debug, Clone)]
82pub struct Polled<T> {
83 pub value: T,
85 pub revision: RevisionCounter,
89}
90
91impl<T: Deref> Deref for Polled<T> {
92 type Target = T::Target;
93
94 fn deref(&self) -> &Self::Target {
95 &self.value
96 }
97}
98
99pub struct QueryRuntime<T: Tracer = NoopTracer> {
122 whale: WhaleRuntime<FullCacheKey, Option<CachedEntry>, DURABILITY_LEVELS>,
125 assets: Arc<AssetStorage>,
127 locators: Arc<LocatorStorage>,
129 pending: Arc<PendingStorage>,
131 query_registry: Arc<QueryRegistry>,
133 asset_key_registry: Arc<AssetKeyRegistry>,
135 verifiers: Arc<VerifierStorage>,
137 error_comparator: ErrorComparator,
139 tracer: Arc<T>,
141}
142
143impl Default for QueryRuntime<NoopTracer> {
144 fn default() -> Self {
145 Self::new()
146 }
147}
148
149impl<T: Tracer> Clone for QueryRuntime<T> {
150 fn clone(&self) -> Self {
151 Self {
152 whale: self.whale.clone(),
153 assets: self.assets.clone(),
154 locators: self.locators.clone(),
155 pending: self.pending.clone(),
156 query_registry: self.query_registry.clone(),
157 asset_key_registry: self.asset_key_registry.clone(),
158 verifiers: self.verifiers.clone(),
159 error_comparator: self.error_comparator,
160 tracer: self.tracer.clone(),
161 }
162 }
163}
164
165fn default_error_comparator(_a: &anyhow::Error, _b: &anyhow::Error) -> bool {
169 false
170}
171
172impl<T: Tracer> QueryRuntime<T> {
173 fn get_cached_with_revision<Q: Query>(
175 &self,
176 key: &FullCacheKey,
177 ) -> Option<(CachedValue<Arc<Q::Output>>, RevisionCounter)> {
178 let node = self.whale.get(key)?;
179 let revision = node.changed_at;
180 let entry = node.data.as_ref()?;
181 let cached = entry.to_cached_value::<Q::Output>()?;
182 Some((cached, revision))
183 }
184
185 #[inline]
187 pub fn tracer(&self) -> &T {
188 &self.tracer
189 }
190}
191
192impl QueryRuntime<NoopTracer> {
193 pub fn new() -> Self {
195 Self::with_tracer(NoopTracer)
196 }
197
198 pub fn builder() -> QueryRuntimeBuilder<NoopTracer> {
214 QueryRuntimeBuilder::new()
215 }
216}
217
218impl<T: Tracer> QueryRuntime<T> {
219 pub fn with_tracer(tracer: T) -> Self {
221 QueryRuntimeBuilder::new().tracer(tracer).build()
222 }
223
224 pub fn query<Q: Query>(&self, query: Q) -> Result<Arc<Q::Output>, QueryError> {
233 self.query_internal(query)
234 .and_then(|(inner_result, _)| inner_result.map_err(QueryError::UserError))
235 }
236
237 #[allow(clippy::type_complexity)]
242 fn query_internal<Q: Query>(
243 &self,
244 query: Q,
245 ) -> Result<(Result<Arc<Q::Output>, Arc<anyhow::Error>>, RevisionCounter), QueryError> {
246 let key = query.cache_key();
247 let full_key = FullCacheKey::new::<Q, _>(&key);
248
249 let span_id = self.tracer.new_span_id();
251 let exec_ctx = ExecutionContext::new(span_id);
252 let query_key = TracerQueryKey::new(std::any::type_name::<Q>(), full_key.debug_repr());
253
254 self.tracer.on_query_start(span_id, query_key.clone());
255
256 let cycle_detected = QUERY_STACK.with(|stack| {
258 let stack = stack.borrow();
259 stack.iter().any(|k| k == &full_key)
260 });
261
262 if cycle_detected {
263 let path = QUERY_STACK.with(|stack| {
264 let stack = stack.borrow();
265 let mut path: Vec<String> =
266 stack.iter().map(|k| k.debug_repr().to_string()).collect();
267 path.push(full_key.debug_repr().to_string());
268 path
269 });
270
271 self.tracer.on_cycle_detected(
272 path.iter()
273 .map(|s| TracerQueryKey::new("", s.clone()))
274 .collect(),
275 );
276 self.tracer
277 .on_query_end(span_id, query_key.clone(), ExecutionResult::CycleDetected);
278
279 return Err(QueryError::Cycle { path });
280 }
281
282 let current_rev = self.whale.current_revision();
284
285 if self.whale.is_verified_at(&full_key, ¤t_rev) {
287 if let Some((cached, revision)) = self.get_cached_with_revision::<Q>(&full_key) {
289 self.tracer.on_cache_check(span_id, query_key.clone(), true);
290 self.tracer
291 .on_query_end(span_id, query_key.clone(), ExecutionResult::CacheHit);
292
293 return match cached {
294 CachedValue::Ok(output) => Ok((Ok(output), revision)),
295 CachedValue::UserError(err) => Ok((Err(err), revision)),
296 };
297 }
298 }
299
300 if self.whale.is_valid(&full_key) {
302 if let Some((cached, revision)) = self.get_cached_with_revision::<Q>(&full_key) {
304 let mut deps_verified = true;
306 if let Some(deps) = self.whale.get_dependency_ids(&full_key) {
307 for dep in deps {
308 if let Some(verifier) = self.verifiers.get(&dep) {
309 if verifier.verify(self as &dyn std::any::Any).is_err() {
311 deps_verified = false;
312 break;
313 }
314 }
315 }
316 }
317
318 if deps_verified && self.whale.is_valid(&full_key) {
320 self.whale.mark_verified(&full_key, ¤t_rev);
322
323 self.tracer.on_cache_check(span_id, query_key.clone(), true);
324 self.tracer
325 .on_query_end(span_id, query_key.clone(), ExecutionResult::CacheHit);
326
327 return match cached {
328 CachedValue::Ok(output) => Ok((Ok(output), revision)),
329 CachedValue::UserError(err) => Ok((Err(err), revision)),
330 };
331 }
332 }
334 }
335
336 self.tracer
337 .on_cache_check(span_id, query_key.clone(), false);
338
339 QUERY_STACK.with(|stack| {
341 stack.borrow_mut().push(full_key.clone());
342 });
343
344 let result = self.execute_query::<Q>(&query, &full_key, exec_ctx);
345
346 QUERY_STACK.with(|stack| {
347 stack.borrow_mut().pop();
348 });
349
350 let exec_result = match &result {
352 Ok((_, true, _)) => ExecutionResult::Changed,
353 Ok((_, false, _)) => ExecutionResult::Unchanged,
354 Err(QueryError::Suspend { .. }) => ExecutionResult::Suspended,
355 Err(QueryError::Cycle { .. }) => ExecutionResult::CycleDetected,
356 Err(e) => ExecutionResult::Error {
357 message: format!("{:?}", e),
358 },
359 };
360 self.tracer
361 .on_query_end(span_id, query_key.clone(), exec_result);
362
363 result.map(|(inner_result, _, revision)| (inner_result, revision))
364 }
365
366 #[allow(clippy::type_complexity)]
372 fn execute_query<Q: Query>(
373 &self,
374 query: &Q,
375 full_key: &FullCacheKey,
376 exec_ctx: ExecutionContext,
377 ) -> Result<
378 (
379 Result<Arc<Q::Output>, Arc<anyhow::Error>>,
380 bool,
381 RevisionCounter,
382 ),
383 QueryError,
384 > {
385 let ctx = QueryContext {
387 runtime: self,
388 current_key: full_key.clone(),
389 parent_query_type: std::any::type_name::<Q>(),
390 exec_ctx,
391 deps: RefCell::new(Vec::new()),
392 };
393
394 let result = query.clone().query(&ctx);
396
397 let deps: Vec<FullCacheKey> = ctx.deps.borrow().clone();
399
400 let durability =
402 Durability::new(query.durability() as usize).unwrap_or(Durability::volatile());
403
404 match result {
405 Ok(output) => {
406 let output = Arc::new(output);
407
408 let existing_revision = if let Some((CachedValue::Ok(old), rev)) =
411 self.get_cached_with_revision::<Q>(full_key)
412 {
413 if Q::output_eq(&old, &output) {
414 Some(rev) } else {
416 None }
418 } else {
419 None };
421 let output_changed = existing_revision.is_none();
422
423 self.tracer.on_early_cutoff_check(
425 exec_ctx.span_id(),
426 TracerQueryKey::new(std::any::type_name::<Q>(), full_key.debug_repr()),
427 output_changed,
428 );
429
430 let entry = CachedEntry::Ok(output.clone() as Arc<dyn std::any::Any + Send + Sync>);
432 let revision = if let Some(existing_rev) = existing_revision {
433 let _ = self.whale.confirm_unchanged(full_key, deps);
435 existing_rev
436 } else {
437 match self
439 .whale
440 .register(full_key.clone(), Some(entry), durability, deps)
441 {
442 Ok(result) => result.new_rev,
443 Err(missing) => {
444 return Err(QueryError::DependenciesRemoved {
445 missing_keys: missing,
446 })
447 }
448 }
449 };
450
451 let is_new_query = self.query_registry.register(query);
453 if is_new_query {
454 let sentinel = FullCacheKey::query_set_sentinel::<Q>();
455 let _ = self
456 .whale
457 .register(sentinel, None, Durability::volatile(), vec![]);
458 }
459
460 self.verifiers
462 .insert::<Q, T>(full_key.clone(), query.clone());
463
464 Ok((Ok(output), output_changed, revision))
465 }
466 Err(QueryError::UserError(err)) => {
467 let existing_revision = if let Some((CachedValue::UserError(old_err), rev)) =
470 self.get_cached_with_revision::<Q>(full_key)
471 {
472 if (self.error_comparator)(old_err.as_ref(), err.as_ref()) {
473 Some(rev) } else {
475 None }
477 } else {
478 None };
480 let output_changed = existing_revision.is_none();
481
482 self.tracer.on_early_cutoff_check(
484 exec_ctx.span_id(),
485 TracerQueryKey::new(std::any::type_name::<Q>(), full_key.debug_repr()),
486 output_changed,
487 );
488
489 let entry = CachedEntry::UserError(err.clone());
491 let revision = if let Some(existing_rev) = existing_revision {
492 let _ = self.whale.confirm_unchanged(full_key, deps);
494 existing_rev
495 } else {
496 match self
498 .whale
499 .register(full_key.clone(), Some(entry), durability, deps)
500 {
501 Ok(result) => result.new_rev,
502 Err(missing) => {
503 return Err(QueryError::DependenciesRemoved {
504 missing_keys: missing,
505 })
506 }
507 }
508 };
509
510 let is_new_query = self.query_registry.register(query);
512 if is_new_query {
513 let sentinel = FullCacheKey::query_set_sentinel::<Q>();
514 let _ = self
515 .whale
516 .register(sentinel, None, Durability::volatile(), vec![]);
517 }
518
519 self.verifiers
521 .insert::<Q, T>(full_key.clone(), query.clone());
522
523 Ok((Err(err), output_changed, revision))
524 }
525 Err(e) => {
526 Err(e)
528 }
529 }
530 }
531
532 pub fn invalidate<Q: Query>(&self, key: &Q::CacheKey) {
536 let full_key = FullCacheKey::new::<Q, _>(key);
537
538 self.tracer.on_query_invalidated(
539 TracerQueryKey::new(std::any::type_name::<Q>(), full_key.debug_repr()),
540 InvalidationReason::ManualInvalidation,
541 );
542
543 let _ = self
545 .whale
546 .register(full_key, None, Durability::volatile(), vec![]);
547 }
548
549 pub fn remove_query<Q: Query>(&self, key: &Q::CacheKey) {
557 let full_key = FullCacheKey::new::<Q, _>(key);
558
559 self.tracer.on_query_invalidated(
560 TracerQueryKey::new(std::any::type_name::<Q>(), full_key.debug_repr()),
561 InvalidationReason::ManualInvalidation,
562 );
563
564 self.verifiers.remove(&full_key);
566
567 self.whale.remove(&full_key);
569
570 if self.query_registry.remove::<Q>(key) {
572 let sentinel = FullCacheKey::query_set_sentinel::<Q>();
573 let _ = self
574 .whale
575 .register(sentinel, None, Durability::volatile(), vec![]);
576 }
577 }
578
579 pub fn clear_cache(&self) {
583 let keys = self.whale.keys();
584 for key in keys {
585 self.whale.remove(&key);
586 }
587 }
588
589 #[allow(clippy::type_complexity)]
623 pub fn poll<Q: Query>(
624 &self,
625 query: Q,
626 ) -> Result<Polled<Result<Arc<Q::Output>, Arc<anyhow::Error>>>, QueryError> {
627 let (value, revision) = self.query_internal(query)?;
628 Ok(Polled { value, revision })
629 }
630
631 pub fn changed_at<Q: Query>(&self, key: &Q::CacheKey) -> Option<RevisionCounter> {
650 let full_key = FullCacheKey::new::<Q, _>(key);
651 self.whale.get(&full_key).map(|node| node.changed_at)
652 }
653}
654
655pub struct QueryRuntimeBuilder<T: Tracer = NoopTracer> {
673 error_comparator: ErrorComparator,
674 tracer: T,
675}
676
677impl Default for QueryRuntimeBuilder<NoopTracer> {
678 fn default() -> Self {
679 Self::new()
680 }
681}
682
683impl QueryRuntimeBuilder<NoopTracer> {
684 pub fn new() -> Self {
686 Self {
687 error_comparator: default_error_comparator,
688 tracer: NoopTracer,
689 }
690 }
691}
692
693impl<T: Tracer> QueryRuntimeBuilder<T> {
694 pub fn error_comparator(mut self, f: ErrorComparator) -> Self {
712 self.error_comparator = f;
713 self
714 }
715
716 pub fn tracer<U: Tracer>(self, tracer: U) -> QueryRuntimeBuilder<U> {
718 QueryRuntimeBuilder {
719 error_comparator: self.error_comparator,
720 tracer,
721 }
722 }
723
724 pub fn build(self) -> QueryRuntime<T> {
726 QueryRuntime {
727 whale: WhaleRuntime::new(),
728 assets: Arc::new(AssetStorage::new()),
729 locators: Arc::new(LocatorStorage::new()),
730 pending: Arc::new(PendingStorage::new()),
731 query_registry: Arc::new(QueryRegistry::new()),
732 asset_key_registry: Arc::new(AssetKeyRegistry::new()),
733 verifiers: Arc::new(VerifierStorage::new()),
734 error_comparator: self.error_comparator,
735 tracer: Arc::new(self.tracer),
736 }
737 }
738}
739
740impl<T: Tracer> QueryRuntime<T> {
745 pub fn register_asset_locator<K, L>(&self, locator: L)
757 where
758 K: AssetKey,
759 L: AssetLocator<K>,
760 {
761 self.locators.insert::<K, L>(locator);
762 }
763
764 pub fn pending_assets(&self) -> Vec<PendingAsset> {
780 self.pending.get_all()
781 }
782
783 pub fn pending_assets_of<K: AssetKey>(&self) -> Vec<K> {
785 self.pending.get_of_type::<K>()
786 }
787
788 pub fn has_pending_assets(&self) -> bool {
790 !self.pending.is_empty()
791 }
792
793 pub fn resolve_asset<K: AssetKey>(&self, key: K, value: K::Asset) {
810 let durability = key.durability();
811 self.resolve_asset_internal(key, value, durability);
812 }
813
814 pub fn resolve_asset_with_durability<K: AssetKey>(
818 &self,
819 key: K,
820 value: K::Asset,
821 durability: DurabilityLevel,
822 ) {
823 self.resolve_asset_internal(key, value, durability);
824 }
825
826 fn resolve_asset_internal<K: AssetKey>(
827 &self,
828 key: K,
829 value: K::Asset,
830 durability_level: DurabilityLevel,
831 ) {
832 let full_asset_key = FullAssetKey::new(&key);
833 let full_cache_key = FullCacheKey::from_asset_key(&full_asset_key);
834
835 let changed = if let Some(old_value) = self.assets.get_ready::<K>(&full_asset_key) {
837 !K::asset_eq(&old_value, &value)
838 } else {
839 true };
841
842 self.tracer.on_asset_resolved(
844 TracerAssetKey::new(std::any::type_name::<K>(), format!("{:?}", key)),
845 changed,
846 );
847
848 self.assets
850 .insert_ready::<K>(full_asset_key.clone(), Arc::new(value));
851
852 self.pending.remove(&full_asset_key);
854
855 let durability =
857 Durability::new(durability_level.as_u8() as usize).unwrap_or(Durability::volatile());
858
859 if changed {
860 let _ = self
862 .whale
863 .register(full_cache_key, None, durability, vec![]);
864 } else {
865 let _ = self.whale.confirm_unchanged(&full_cache_key, vec![]);
867 }
868
869 let is_new_asset = self.asset_key_registry.register(&key);
871 if is_new_asset {
872 let sentinel = FullCacheKey::asset_key_set_sentinel::<K>();
874 let _ = self
875 .whale
876 .register(sentinel, None, Durability::volatile(), vec![]);
877 }
878 }
879
880 pub fn invalidate_asset<K: AssetKey>(&self, key: &K) {
894 let full_asset_key = FullAssetKey::new(key);
895 let full_cache_key = FullCacheKey::from_asset_key(&full_asset_key);
896
897 self.tracer.on_asset_invalidated(TracerAssetKey::new(
899 std::any::type_name::<K>(),
900 format!("{:?}", key),
901 ));
902
903 self.assets
905 .insert(full_asset_key.clone(), AssetState::Loading);
906
907 self.pending.insert::<K>(full_asset_key, key.clone());
909
910 let _ = self
912 .whale
913 .register(full_cache_key, None, Durability::volatile(), vec![]);
914 }
915
916 pub fn remove_asset<K: AssetKey>(&self, key: &K) {
921 let full_asset_key = FullAssetKey::new(key);
922 let full_cache_key = FullCacheKey::from_asset_key(&full_asset_key);
923
924 let _ = self
927 .whale
928 .register(full_cache_key.clone(), None, Durability::volatile(), vec![]);
929
930 self.assets.remove(&full_asset_key);
932 self.pending.remove(&full_asset_key);
933
934 self.whale.remove(&full_cache_key);
936
937 if self.asset_key_registry.remove::<K>(key) {
939 let sentinel = FullCacheKey::asset_key_set_sentinel::<K>();
940 let _ = self
941 .whale
942 .register(sentinel, None, Durability::volatile(), vec![]);
943 }
944 }
945
946 pub fn get_asset<K: AssetKey>(&self, key: K) -> Result<AssetLoadingState<K>, QueryError> {
958 self.get_asset_internal(key)
959 }
960
961 fn get_asset_internal<K: AssetKey>(&self, key: K) -> Result<AssetLoadingState<K>, QueryError> {
963 let full_asset_key = FullAssetKey::new(&key);
964 let full_cache_key = FullCacheKey::from_asset_key(&full_asset_key);
965
966 let emit_requested = |tracer: &T, key: &K, state: TracerAssetState| {
968 tracer.on_asset_requested(
969 TracerAssetKey::new(std::any::type_name::<K>(), format!("{:?}", key)),
970 state,
971 );
972 };
973
974 if let Some(state) = self.assets.get(&full_asset_key) {
976 if self.whale.is_valid(&full_cache_key) {
978 return match state {
979 AssetState::Loading => {
980 emit_requested(&self.tracer, &key, TracerAssetState::Loading);
981 Ok(AssetLoadingState::loading(key))
982 }
983 AssetState::Ready(arc) => {
984 emit_requested(&self.tracer, &key, TracerAssetState::Ready);
985 match arc.downcast::<K::Asset>() {
986 Ok(value) => Ok(AssetLoadingState::ready(key, value)),
987 Err(_) => Err(QueryError::MissingDependency {
988 description: format!("Asset type mismatch: {:?}", key),
989 }),
990 }
991 }
992 AssetState::NotFound => {
993 emit_requested(&self.tracer, &key, TracerAssetState::NotFound);
994 Err(QueryError::MissingDependency {
995 description: format!("Asset not found: {:?}", key),
996 })
997 }
998 };
999 }
1000 }
1001
1002 if let Some(locator) = self.locators.get(TypeId::of::<K>()) {
1004 if let Some(state) = locator.locate_any(&key) {
1005 self.assets.insert(full_asset_key.clone(), state.clone());
1007
1008 match state {
1009 AssetState::Ready(arc) => {
1010 emit_requested(&self.tracer, &key, TracerAssetState::Ready);
1011
1012 let durability = Durability::new(key.durability().as_u8() as usize)
1014 .unwrap_or(Durability::volatile());
1015 self.whale
1016 .register(full_cache_key, None, durability, vec![])
1017 .expect("register with no dependencies cannot fail");
1018
1019 match arc.downcast::<K::Asset>() {
1020 Ok(value) => return Ok(AssetLoadingState::ready(key, value)),
1021 Err(_) => {
1022 return Err(QueryError::MissingDependency {
1023 description: format!("Asset type mismatch: {:?}", key),
1024 })
1025 }
1026 }
1027 }
1028 AssetState::Loading => {
1029 emit_requested(&self.tracer, &key, TracerAssetState::Loading);
1030 self.pending.insert::<K>(full_asset_key, key.clone());
1031
1032 self.whale
1034 .register(full_cache_key, None, Durability::volatile(), vec![])
1035 .expect("register with no dependencies cannot fail");
1036
1037 return Ok(AssetLoadingState::loading(key));
1038 }
1039 AssetState::NotFound => {
1040 emit_requested(&self.tracer, &key, TracerAssetState::NotFound);
1041 return Err(QueryError::MissingDependency {
1042 description: format!("Asset not found: {:?}", key),
1043 });
1044 }
1045 }
1046 }
1047 }
1048
1049 emit_requested(&self.tracer, &key, TracerAssetState::Loading);
1051 self.assets
1052 .insert(full_asset_key.clone(), AssetState::Loading);
1053 self.pending
1054 .insert::<K>(full_asset_key.clone(), key.clone());
1055
1056 self.whale
1058 .register(full_cache_key, None, Durability::volatile(), vec![])
1059 .expect("register with no dependencies cannot fail");
1060
1061 Ok(AssetLoadingState::loading(key))
1062 }
1063}
1064
1065impl<T: Tracer> Db for QueryRuntime<T> {
1066 fn query<Q: Query>(&self, query: Q) -> Result<Arc<Q::Output>, QueryError> {
1067 QueryRuntime::query(self, query)
1068 }
1069
1070 fn asset<K: AssetKey>(&self, key: K) -> Result<AssetLoadingState<K>, QueryError> {
1071 self.get_asset_internal(key)
1072 }
1073
1074 fn list_queries<Q: Query>(&self) -> Vec<Q> {
1075 self.query_registry.get_all::<Q>()
1076 }
1077
1078 fn list_asset_keys<K: AssetKey>(&self) -> Vec<K> {
1079 self.asset_key_registry.get_all::<K>()
1080 }
1081}
1082
1083pub struct QueryContext<'a, T: Tracer = NoopTracer> {
1087 runtime: &'a QueryRuntime<T>,
1088 current_key: FullCacheKey,
1089 parent_query_type: &'static str,
1090 exec_ctx: ExecutionContext,
1091 deps: RefCell<Vec<FullCacheKey>>,
1092}
1093
1094impl<'a, T: Tracer> QueryContext<'a, T> {
1095 pub fn query<Q: Query>(&self, query: Q) -> Result<Arc<Q::Output>, QueryError> {
1108 let key = query.cache_key();
1109 let full_key = FullCacheKey::new::<Q, _>(&key);
1110
1111 self.runtime.tracer.on_dependency_registered(
1113 self.exec_ctx.span_id(),
1114 TracerQueryKey::new(self.parent_query_type, self.current_key.debug_repr()),
1115 TracerQueryKey::new(std::any::type_name::<Q>(), full_key.debug_repr()),
1116 );
1117
1118 self.deps.borrow_mut().push(full_key.clone());
1120
1121 self.runtime.query(query)
1123 }
1124
1125 pub fn asset<K: AssetKey>(&self, key: K) -> Result<AssetLoadingState<K>, QueryError> {
1149 let full_asset_key = FullAssetKey::new(&key);
1150 let full_cache_key = FullCacheKey::from_asset_key(&full_asset_key);
1151
1152 self.runtime.tracer.on_asset_dependency_registered(
1154 self.exec_ctx.span_id(),
1155 TracerQueryKey::new(self.parent_query_type, self.current_key.debug_repr()),
1156 TracerAssetKey::new(std::any::type_name::<K>(), format!("{:?}", key)),
1157 );
1158
1159 self.deps.borrow_mut().push(full_cache_key);
1161
1162 let result = self.runtime.get_asset_internal(key);
1164
1165 if let Err(QueryError::MissingDependency { ref description }) = result {
1167 self.runtime.tracer.on_missing_dependency(
1168 TracerQueryKey::new(self.parent_query_type, self.current_key.debug_repr()),
1169 description.clone(),
1170 );
1171 }
1172
1173 result
1174 }
1175
1176 pub fn list_queries<Q: Query>(&self) -> Vec<Q> {
1199 let sentinel = FullCacheKey::query_set_sentinel::<Q>();
1201
1202 self.runtime.tracer.on_dependency_registered(
1203 self.exec_ctx.span_id(),
1204 TracerQueryKey::new(self.parent_query_type, self.current_key.debug_repr()),
1205 TracerQueryKey::new("QuerySet", sentinel.debug_repr()),
1206 );
1207
1208 if self.runtime.whale.get(&sentinel).is_none() {
1210 let _ =
1211 self.runtime
1212 .whale
1213 .register(sentinel.clone(), None, Durability::volatile(), vec![]);
1214 }
1215
1216 self.deps.borrow_mut().push(sentinel);
1217
1218 self.runtime.query_registry.get_all::<Q>()
1220 }
1221
1222 pub fn list_asset_keys<K: AssetKey>(&self) -> Vec<K> {
1247 let sentinel = FullCacheKey::asset_key_set_sentinel::<K>();
1249
1250 self.runtime.tracer.on_asset_dependency_registered(
1251 self.exec_ctx.span_id(),
1252 TracerQueryKey::new(self.parent_query_type, self.current_key.debug_repr()),
1253 TracerAssetKey::new("AssetKeySet", sentinel.debug_repr()),
1254 );
1255
1256 if self.runtime.whale.get(&sentinel).is_none() {
1258 let _ =
1259 self.runtime
1260 .whale
1261 .register(sentinel.clone(), None, Durability::volatile(), vec![]);
1262 }
1263
1264 self.deps.borrow_mut().push(sentinel);
1265
1266 self.runtime.asset_key_registry.get_all::<K>()
1268 }
1269}
1270
1271impl<'a, T: Tracer> Db for QueryContext<'a, T> {
1272 fn query<Q: Query>(&self, query: Q) -> Result<Arc<Q::Output>, QueryError> {
1273 QueryContext::query(self, query)
1274 }
1275
1276 fn asset<K: AssetKey>(&self, key: K) -> Result<AssetLoadingState<K>, QueryError> {
1277 QueryContext::asset(self, key)
1278 }
1279
1280 fn list_queries<Q: Query>(&self) -> Vec<Q> {
1281 QueryContext::list_queries(self)
1282 }
1283
1284 fn list_asset_keys<K: AssetKey>(&self) -> Vec<K> {
1285 QueryContext::list_asset_keys(self)
1286 }
1287}
1288
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simple_query() {
        #[derive(Clone)]
        struct Add {
            a: i32,
            b: i32,
        }

        impl Query for Add {
            type CacheKey = (i32, i32);
            type Output = i32;

            fn cache_key(&self) -> Self::CacheKey {
                (self.a, self.b)
            }

            fn query(self, _db: &impl Db) -> Result<Self::Output, QueryError> {
                Ok(self.a + self.b)
            }

            fn output_eq(old: &Self::Output, new: &Self::Output) -> bool {
                old == new
            }
        }

        let runtime = QueryRuntime::new();

        // Round one executes the query; round two is answered from the cache.
        for _round in 0..2 {
            let sum = runtime.query(Add { a: 1, b: 2 }).unwrap();
            assert_eq!(*sum, 3);
        }
    }

    #[test]
    fn test_dependent_queries() {
        #[derive(Clone)]
        struct Base {
            value: i32,
        }

        impl Query for Base {
            type CacheKey = i32;
            type Output = i32;

            fn cache_key(&self) -> Self::CacheKey {
                self.value
            }

            fn query(self, _db: &impl Db) -> Result<Self::Output, QueryError> {
                Ok(self.value * 2)
            }

            fn output_eq(old: &Self::Output, new: &Self::Output) -> bool {
                old == new
            }
        }

        #[derive(Clone)]
        struct Derived {
            base_value: i32,
        }

        impl Query for Derived {
            type CacheKey = i32;
            type Output = i32;

            fn cache_key(&self) -> Self::CacheKey {
                self.base_value
            }

            fn query(self, db: &impl Db) -> Result<Self::Output, QueryError> {
                db.query(Base {
                    value: self.base_value,
                })
                .map(|base| *base + 10)
            }

            fn output_eq(old: &Self::Output, new: &Self::Output) -> bool {
                old == new
            }
        }

        let runtime = QueryRuntime::new();

        // Base(5) = 10, so Derived(5) = 10 + 10.
        let derived = runtime.query(Derived { base_value: 5 }).unwrap();
        assert_eq!(*derived, 20);
    }

    #[test]
    fn test_cycle_detection() {
        #[derive(Clone)]
        struct CycleA {
            id: i32,
        }

        #[derive(Clone)]
        struct CycleB {
            id: i32,
        }

        impl Query for CycleA {
            type CacheKey = i32;
            type Output = i32;

            fn cache_key(&self) -> Self::CacheKey {
                self.id
            }

            fn query(self, db: &impl Db) -> Result<Self::Output, QueryError> {
                db.query(CycleB { id: self.id }).map(|b| *b + 1)
            }

            fn output_eq(old: &Self::Output, new: &Self::Output) -> bool {
                old == new
            }
        }

        impl Query for CycleB {
            type CacheKey = i32;
            type Output = i32;

            fn cache_key(&self) -> Self::CacheKey {
                self.id
            }

            fn query(self, db: &impl Db) -> Result<Self::Output, QueryError> {
                db.query(CycleA { id: self.id }).map(|a| *a + 1)
            }

            fn output_eq(old: &Self::Output, new: &Self::Output) -> bool {
                old == new
            }
        }

        let runtime = QueryRuntime::new();

        let outcome = runtime.query(CycleA { id: 1 });
        assert!(
            matches!(outcome, Err(QueryError::Cycle { .. })),
            "CycleA -> CycleB -> CycleA must be reported as a cycle"
        );
    }

    #[test]
    fn test_fallible_query() {
        #[derive(Clone)]
        struct ParseInt {
            input: String,
        }

        impl Query for ParseInt {
            type CacheKey = String;
            type Output = Result<i32, std::num::ParseIntError>;

            fn cache_key(&self) -> Self::CacheKey {
                self.input.clone()
            }

            fn query(self, _db: &impl Db) -> Result<Self::Output, QueryError> {
                Ok(self.input.parse())
            }

            fn output_eq(old: &Self::Output, new: &Self::Output) -> bool {
                old == new
            }
        }

        let runtime = QueryRuntime::new();

        // A domain-level failure is a successful query whose output is an `Err`.
        let parse = |text: &str| {
            runtime
                .query(ParseInt {
                    input: text.to_string(),
                })
                .unwrap()
        };

        assert_eq!(*parse("42"), Ok(42));
        assert!(parse("not_a_number").is_err());
    }

    mod macro_tests {
        use super::*;
        use crate::query;

        #[query]
        fn add(db: &impl Db, a: i32, b: i32) -> Result<i32, QueryError> {
            let _ = db;
            Ok(a + b)
        }

        #[test]
        fn test_macro_basic() {
            let runtime = QueryRuntime::new();
            assert_eq!(*runtime.query(Add::new(1, 2)).unwrap(), 3);
        }

        #[query(durability = 2)]
        fn with_durability(db: &impl Db, x: i32) -> Result<i32, QueryError> {
            let _ = db;
            Ok(x * 2)
        }

        #[test]
        fn test_macro_durability() {
            let runtime = QueryRuntime::new();
            assert_eq!(*runtime.query(WithDurability::new(5)).unwrap(), 10);
        }

        #[query(keys(id))]
        fn with_key_selection(
            db: &impl Db,
            id: u32,
            include_extra: bool,
        ) -> Result<String, QueryError> {
            let _ = db;
            Ok(format!("id={}, extra={}", id, include_extra))
        }

        #[test]
        fn test_macro_key_selection() {
            let runtime = QueryRuntime::new();

            let first = runtime.query(WithKeySelection::new(1, true)).unwrap();
            // Only `id` is part of the cache key, so this request hits the entry created above.
            let second = runtime.query(WithKeySelection::new(1, false)).unwrap();

            assert_eq!(*first, "id=1, extra=true");
            assert_eq!(*second, "id=1, extra=true");
        }

        #[query]
        fn dependent(db: &impl Db, a: i32, b: i32) -> Result<i32, QueryError> {
            let sum = db.query(Add::new(a, b))?;
            Ok(*sum * 2)
        }

        #[test]
        fn test_macro_dependencies() {
            let runtime = QueryRuntime::new();
            // (3 + 4) * 2
            assert_eq!(*runtime.query(Dependent::new(3, 4)).unwrap(), 14);
        }

        #[query(output_eq)]
        fn with_output_eq(db: &impl Db, x: i32) -> Result<i32, QueryError> {
            let _ = db;
            Ok(x * 2)
        }

        #[test]
        fn test_macro_output_eq() {
            let runtime = QueryRuntime::new();
            assert_eq!(*runtime.query(WithOutputEq::new(5)).unwrap(), 10);
        }

        #[query(name = "CustomName")]
        fn original_name(db: &impl Db, x: i32) -> Result<i32, QueryError> {
            let _ = db;
            Ok(x)
        }

        #[test]
        fn test_macro_custom_name() {
            let runtime = QueryRuntime::new();
            assert_eq!(*runtime.query(CustomName::new(42)).unwrap(), 42);
        }

        // The macro must keep these attributes; `unused_var` exists to exercise the `allow`.
        #[allow(unused_variables)]
        #[inline]
        #[query]
        fn with_attributes(db: &impl Db, x: i32) -> Result<i32, QueryError> {
            let unused_var = 42;
            Ok(x * 2)
        }

        #[test]
        fn test_macro_preserves_attributes() {
            let runtime = QueryRuntime::new();
            assert_eq!(*runtime.query(WithAttributes::new(5)).unwrap(), 10);
        }
    }

    mod poll_tests {
        use super::*;

        #[derive(Clone)]
        struct Counter {
            id: i32,
        }

        impl Query for Counter {
            type CacheKey = i32;
            type Output = i32;

            fn cache_key(&self) -> Self::CacheKey {
                self.id
            }

            fn query(self, _db: &impl Db) -> Result<Self::Output, QueryError> {
                Ok(self.id * 10)
            }

            fn output_eq(old: &Self::Output, new: &Self::Output) -> bool {
                old == new
            }
        }

        #[test]
        fn test_poll_returns_value_and_revision() {
            let runtime = QueryRuntime::new();

            let polled = runtime.poll(Counter { id: 1 }).unwrap();
            let value = polled.value.as_ref().unwrap();

            assert_eq!(**value, 10);
            assert!(polled.revision > 0);
        }

        #[test]
        fn test_poll_revision_stable_on_cache_hit() {
            let runtime = QueryRuntime::new();

            let first = runtime.poll(Counter { id: 1 }).unwrap().revision;
            let second = runtime.poll(Counter { id: 1 }).unwrap().revision;

            assert_eq!(first, second);
        }

        #[test]
        fn test_poll_revision_changes_on_invalidate() {
            let runtime = QueryRuntime::new();

            let before = runtime.poll(Counter { id: 1 }).unwrap().revision;
            runtime.invalidate::<Counter>(&1);
            let after = runtime.poll(Counter { id: 1 }).unwrap();

            assert_eq!(**after.value.as_ref().unwrap(), 10);
            assert!(after.revision >= before);
        }

        #[test]
        fn test_changed_at_returns_none_for_unexecuted_query() {
            let runtime = QueryRuntime::new();
            assert!(runtime.changed_at::<Counter>(&1).is_none());
        }

        #[test]
        fn test_changed_at_returns_revision_after_execution() {
            let runtime = QueryRuntime::new();

            runtime.query(Counter { id: 1 }).unwrap();

            let rev = runtime.changed_at::<Counter>(&1);
            assert!(matches!(rev, Some(r) if r > 0));
        }

        #[test]
        fn test_changed_at_matches_poll_revision() {
            let runtime = QueryRuntime::new();

            let polled = runtime.poll(Counter { id: 1 }).unwrap();

            assert_eq!(runtime.changed_at::<Counter>(&1), Some(polled.revision));
        }

        #[test]
        fn test_poll_value_access() {
            let runtime = QueryRuntime::new();

            let polled = runtime.poll(Counter { id: 5 }).unwrap();

            // Both the inner value and the `Arc` itself are reachable through `as_ref`.
            let value: &i32 = polled.value.as_ref().unwrap();
            assert_eq!(*value, 50);

            let arc: &Arc<i32> = polled.value.as_ref().unwrap();
            assert_eq!(**arc, 50);
        }

        #[test]
        fn test_subscription_pattern() {
            let runtime = QueryRuntime::new();

            let mut last_revision: RevisionCounter = 0;
            let mut notifications = 0;

            // A subscriber polls repeatedly and is notified only when the revision advances.
            for _poll in 0..3 {
                let polled = runtime.poll(Counter { id: 1 }).unwrap();
                if polled.revision > last_revision {
                    notifications += 1;
                    last_revision = polled.revision;
                }
            }

            assert_eq!(notifications, 1);
        }
    }
}