1use std::any::TypeId;
4use std::cell::RefCell;
5use std::ops::Deref;
6use std::sync::Arc;
7use std::time::Instant;
8
9use whale::{Durability, RevisionCounter, Runtime as WhaleRuntime};
10
11use crate::asset::{AssetKey, AssetLocator, DurabilityLevel, FullAssetKey, PendingAsset};
12use crate::db::Db;
13use crate::key::FullCacheKey;
14use crate::loading::AssetLoadingState;
15use crate::query::Query;
16use crate::storage::{
17 AssetKeyRegistry, AssetState, AssetStorage, CachedEntry, CachedValue, LocatorStorage,
18 PendingStorage, QueryRegistry, VerifierStorage,
19};
20use crate::tracer::{
21 ExecutionResult, InvalidationReason, NoopTracer, SpanId, Tracer, TracerAssetKey,
22 TracerAssetState, TracerQueryKey,
23};
24use crate::QueryError;
25
/// Decides whether a freshly produced user error is "equal" to the user error
/// previously cached for the same query.
///
/// When it returns `true`, `execute_query` keeps the previously cached
/// revision (early cutoff) instead of registering a new entry.
pub type ErrorComparator = fn(&anyhow::Error, &anyhow::Error) -> bool;

/// Number of durability levels the dependency graph is instantiated with
/// (const generic argument of `WhaleRuntime` in `QueryRuntime`).
const DURABILITY_LEVELS: usize = 4;

thread_local! {
    // Keys of the queries currently executing on this thread, innermost last.
    // `query_internal` scans it for cycle detection and pushes/pops around
    // `execute_query`.
    static QUERY_STACK: RefCell<Vec<FullCacheKey>> = const { RefCell::new(Vec::new()) };
}
39
40#[derive(Clone, Copy)]
44pub struct ExecutionContext {
45 span_id: SpanId,
46}
47
48impl ExecutionContext {
49 #[inline]
51 pub fn new(span_id: SpanId) -> Self {
52 Self { span_id }
53 }
54
55 #[inline]
57 pub fn span_id(&self) -> SpanId {
58 self.span_id
59 }
60}
61
62#[derive(Debug, Clone)]
83pub struct Polled<T> {
84 pub value: T,
86 pub revision: RevisionCounter,
90}
91
92impl<T: Deref> Deref for Polled<T> {
93 type Target = T::Target;
94
95 fn deref(&self) -> &Self::Target {
96 &self.value
97 }
98}
99
/// Incremental query runtime.
///
/// Query results and asset states are stored as nodes of a dependency graph
/// (`whale`); cached results are reused while their dependencies are still
/// valid, and queries are re-executed otherwise. All storages are behind
/// `Arc`, so `clone()` produces a handle that shares them.
pub struct QueryRuntime<T: Tracer = NoopTracer> {
    /// Dependency graph holding one node per query / asset / sentinel key;
    /// query nodes carry the cached entry (`Some`), others carry `None`.
    whale: WhaleRuntime<FullCacheKey, Option<CachedEntry>, DURABILITY_LEVELS>,
    /// Current state (loading / ready / not found) of every known asset.
    assets: Arc<AssetStorage>,
    /// Registered asset locators, looked up by the asset key's `TypeId`.
    locators: Arc<LocatorStorage>,
    /// Assets that are waiting to be resolved by the caller.
    pending: Arc<PendingStorage>,
    /// Every query value that has been executed, per query type
    /// (backs `list_queries`).
    query_registry: Arc<QueryRegistry>,
    /// Every asset key that has been resolved, per key type
    /// (backs `list_asset_keys`).
    asset_key_registry: Arc<AssetKeyRegistry>,
    /// Per-query verifiers used to re-validate dependencies on the slow path
    /// of `query_internal`.
    verifiers: Arc<VerifierStorage>,
    /// Equality used for early cutoff of cached user errors.
    error_comparator: ErrorComparator,
    /// Receiver of all tracing events.
    tracer: Arc<T>,
}
143
144impl Default for QueryRuntime<NoopTracer> {
145 fn default() -> Self {
146 Self::new()
147 }
148}
149
150impl<T: Tracer> Clone for QueryRuntime<T> {
151 fn clone(&self) -> Self {
152 Self {
153 whale: self.whale.clone(),
154 assets: self.assets.clone(),
155 locators: self.locators.clone(),
156 pending: self.pending.clone(),
157 query_registry: self.query_registry.clone(),
158 asset_key_registry: self.asset_key_registry.clone(),
159 verifiers: self.verifiers.clone(),
160 error_comparator: self.error_comparator,
161 tracer: self.tracer.clone(),
162 }
163 }
164}
165
166fn default_error_comparator(_a: &anyhow::Error, _b: &anyhow::Error) -> bool {
170 false
171}
172
173impl<T: Tracer> QueryRuntime<T> {
174 fn get_cached_with_revision<Q: Query>(
176 &self,
177 key: &FullCacheKey,
178 ) -> Option<(CachedValue<Arc<Q::Output>>, RevisionCounter)> {
179 let node = self.whale.get(key)?;
180 let revision = node.changed_at;
181 let entry = node.data.as_ref()?;
182 let cached = entry.to_cached_value::<Q::Output>()?;
183 Some((cached, revision))
184 }
185
186 #[inline]
188 pub fn tracer(&self) -> &T {
189 &self.tracer
190 }
191}
192
193impl QueryRuntime<NoopTracer> {
194 pub fn new() -> Self {
196 Self::with_tracer(NoopTracer)
197 }
198
199 pub fn builder() -> QueryRuntimeBuilder<NoopTracer> {
215 QueryRuntimeBuilder::new()
216 }
217}
218
219impl<T: Tracer> QueryRuntime<T> {
220 pub fn with_tracer(tracer: T) -> Self {
222 QueryRuntimeBuilder::new().tracer(tracer).build()
223 }
224
225 pub fn query<Q: Query>(&self, query: Q) -> Result<Arc<Q::Output>, QueryError> {
234 self.query_internal(query)
235 .and_then(|(inner_result, _)| inner_result.map_err(QueryError::UserError))
236 }
237
238 #[allow(clippy::type_complexity)]
243 fn query_internal<Q: Query>(
244 &self,
245 query: Q,
246 ) -> Result<(Result<Arc<Q::Output>, Arc<anyhow::Error>>, RevisionCounter), QueryError> {
247 let key = query.cache_key();
248 let full_key = FullCacheKey::new::<Q, _>(&key);
249
250 let span_id = self.tracer.new_span_id();
252 let exec_ctx = ExecutionContext::new(span_id);
253 let start_time = Instant::now();
254 let query_key = TracerQueryKey::new(std::any::type_name::<Q>(), full_key.debug_repr());
255
256 self.tracer.on_query_start(span_id, query_key.clone());
257
258 let cycle_detected = QUERY_STACK.with(|stack| {
260 let stack = stack.borrow();
261 stack.iter().any(|k| k == &full_key)
262 });
263
264 if cycle_detected {
265 let path = QUERY_STACK.with(|stack| {
266 let stack = stack.borrow();
267 let mut path: Vec<String> =
268 stack.iter().map(|k| k.debug_repr().to_string()).collect();
269 path.push(full_key.debug_repr().to_string());
270 path
271 });
272
273 self.tracer.on_cycle_detected(
274 path.iter()
275 .map(|s| TracerQueryKey::new("", s.clone()))
276 .collect(),
277 );
278 self.tracer.on_query_end(
279 span_id,
280 query_key.clone(),
281 ExecutionResult::CycleDetected,
282 start_time.elapsed(),
283 );
284
285 return Err(QueryError::Cycle { path });
286 }
287
288 let current_rev = self.whale.current_revision();
290
291 if self.whale.is_verified_at(&full_key, ¤t_rev) {
293 if let Some((cached, revision)) = self.get_cached_with_revision::<Q>(&full_key) {
295 self.tracer.on_cache_check(span_id, query_key.clone(), true);
296 self.tracer.on_query_end(
297 span_id,
298 query_key.clone(),
299 ExecutionResult::CacheHit,
300 start_time.elapsed(),
301 );
302
303 return match cached {
304 CachedValue::Ok(output) => Ok((Ok(output), revision)),
305 CachedValue::UserError(err) => Ok((Err(err), revision)),
306 };
307 }
308 }
309
310 if self.whale.is_valid(&full_key) {
312 if let Some((cached, revision)) = self.get_cached_with_revision::<Q>(&full_key) {
314 let mut deps_verified = true;
316 if let Some(deps) = self.whale.get_dependency_ids(&full_key) {
317 for dep in deps {
318 if let Some(verifier) = self.verifiers.get(&dep) {
319 if verifier.verify(self as &dyn std::any::Any).is_err() {
321 deps_verified = false;
322 break;
323 }
324 }
325 }
326 }
327
328 if deps_verified && self.whale.is_valid(&full_key) {
330 self.whale.mark_verified(&full_key, ¤t_rev);
332
333 self.tracer.on_cache_check(span_id, query_key.clone(), true);
334 self.tracer.on_query_end(
335 span_id,
336 query_key.clone(),
337 ExecutionResult::CacheHit,
338 start_time.elapsed(),
339 );
340
341 return match cached {
342 CachedValue::Ok(output) => Ok((Ok(output), revision)),
343 CachedValue::UserError(err) => Ok((Err(err), revision)),
344 };
345 }
346 }
348 }
349
350 self.tracer
351 .on_cache_check(span_id, query_key.clone(), false);
352
353 QUERY_STACK.with(|stack| {
355 stack.borrow_mut().push(full_key.clone());
356 });
357
358 let result = self.execute_query::<Q>(&query, &full_key, exec_ctx);
359
360 QUERY_STACK.with(|stack| {
361 stack.borrow_mut().pop();
362 });
363
364 let exec_result = match &result {
366 Ok((_, true, _)) => ExecutionResult::Changed,
367 Ok((_, false, _)) => ExecutionResult::Unchanged,
368 Err(QueryError::Suspend { .. }) => ExecutionResult::Suspended,
369 Err(QueryError::Cycle { .. }) => ExecutionResult::CycleDetected,
370 Err(e) => ExecutionResult::Error {
371 message: format!("{:?}", e),
372 },
373 };
374 self.tracer.on_query_end(
375 span_id,
376 query_key.clone(),
377 exec_result,
378 start_time.elapsed(),
379 );
380
381 result.map(|(inner_result, _, revision)| (inner_result, revision))
382 }
383
384 #[allow(clippy::type_complexity)]
390 fn execute_query<Q: Query>(
391 &self,
392 query: &Q,
393 full_key: &FullCacheKey,
394 exec_ctx: ExecutionContext,
395 ) -> Result<
396 (
397 Result<Arc<Q::Output>, Arc<anyhow::Error>>,
398 bool,
399 RevisionCounter,
400 ),
401 QueryError,
402 > {
403 let ctx = QueryContext {
405 runtime: self,
406 current_key: full_key.clone(),
407 parent_query_type: std::any::type_name::<Q>(),
408 exec_ctx,
409 deps: RefCell::new(Vec::new()),
410 };
411
412 let result = query.clone().query(&ctx);
414
415 let deps: Vec<FullCacheKey> = ctx.deps.borrow().clone();
417
418 let durability =
420 Durability::new(query.durability() as usize).unwrap_or(Durability::volatile());
421
422 match result {
423 Ok(output) => {
424 let output = Arc::new(output);
425
426 let existing_revision = if let Some((CachedValue::Ok(old), rev)) =
429 self.get_cached_with_revision::<Q>(full_key)
430 {
431 if Q::output_eq(&old, &output) {
432 Some(rev) } else {
434 None }
436 } else {
437 None };
439 let output_changed = existing_revision.is_none();
440
441 self.tracer.on_early_cutoff_check(
443 exec_ctx.span_id(),
444 TracerQueryKey::new(std::any::type_name::<Q>(), full_key.debug_repr()),
445 output_changed,
446 );
447
448 let entry = CachedEntry::Ok(output.clone() as Arc<dyn std::any::Any + Send + Sync>);
450 let revision = if let Some(existing_rev) = existing_revision {
451 let _ = self.whale.confirm_unchanged(full_key, deps);
453 existing_rev
454 } else {
455 match self
457 .whale
458 .register(full_key.clone(), Some(entry), durability, deps)
459 {
460 Ok(result) => result.new_rev,
461 Err(missing) => {
462 return Err(QueryError::DependenciesRemoved {
463 missing_keys: missing,
464 })
465 }
466 }
467 };
468
469 let is_new_query = self.query_registry.register(query);
471 if is_new_query {
472 let sentinel = FullCacheKey::query_set_sentinel::<Q>();
473 let _ = self
474 .whale
475 .register(sentinel, None, Durability::volatile(), vec![]);
476 }
477
478 self.verifiers
480 .insert::<Q, T>(full_key.clone(), query.clone());
481
482 Ok((Ok(output), output_changed, revision))
483 }
484 Err(QueryError::UserError(err)) => {
485 let existing_revision = if let Some((CachedValue::UserError(old_err), rev)) =
488 self.get_cached_with_revision::<Q>(full_key)
489 {
490 if (self.error_comparator)(old_err.as_ref(), err.as_ref()) {
491 Some(rev) } else {
493 None }
495 } else {
496 None };
498 let output_changed = existing_revision.is_none();
499
500 self.tracer.on_early_cutoff_check(
502 exec_ctx.span_id(),
503 TracerQueryKey::new(std::any::type_name::<Q>(), full_key.debug_repr()),
504 output_changed,
505 );
506
507 let entry = CachedEntry::UserError(err.clone());
509 let revision = if let Some(existing_rev) = existing_revision {
510 let _ = self.whale.confirm_unchanged(full_key, deps);
512 existing_rev
513 } else {
514 match self
516 .whale
517 .register(full_key.clone(), Some(entry), durability, deps)
518 {
519 Ok(result) => result.new_rev,
520 Err(missing) => {
521 return Err(QueryError::DependenciesRemoved {
522 missing_keys: missing,
523 })
524 }
525 }
526 };
527
528 let is_new_query = self.query_registry.register(query);
530 if is_new_query {
531 let sentinel = FullCacheKey::query_set_sentinel::<Q>();
532 let _ = self
533 .whale
534 .register(sentinel, None, Durability::volatile(), vec![]);
535 }
536
537 self.verifiers
539 .insert::<Q, T>(full_key.clone(), query.clone());
540
541 Ok((Err(err), output_changed, revision))
542 }
543 Err(e) => {
544 Err(e)
546 }
547 }
548 }
549
550 pub fn invalidate<Q: Query>(&self, key: &Q::CacheKey) {
554 let full_key = FullCacheKey::new::<Q, _>(key);
555
556 self.tracer.on_query_invalidated(
557 TracerQueryKey::new(std::any::type_name::<Q>(), full_key.debug_repr()),
558 InvalidationReason::ManualInvalidation,
559 );
560
561 let _ = self
563 .whale
564 .register(full_key, None, Durability::volatile(), vec![]);
565 }
566
567 pub fn remove_query<Q: Query>(&self, key: &Q::CacheKey) {
575 let full_key = FullCacheKey::new::<Q, _>(key);
576
577 self.tracer.on_query_invalidated(
578 TracerQueryKey::new(std::any::type_name::<Q>(), full_key.debug_repr()),
579 InvalidationReason::ManualInvalidation,
580 );
581
582 self.verifiers.remove(&full_key);
584
585 self.whale.remove(&full_key);
587
588 if self.query_registry.remove::<Q>(key) {
590 let sentinel = FullCacheKey::query_set_sentinel::<Q>();
591 let _ = self
592 .whale
593 .register(sentinel, None, Durability::volatile(), vec![]);
594 }
595 }
596
597 pub fn clear_cache(&self) {
601 let keys = self.whale.keys();
602 for key in keys {
603 self.whale.remove(&key);
604 }
605 }
606
607 #[allow(clippy::type_complexity)]
641 pub fn poll<Q: Query>(
642 &self,
643 query: Q,
644 ) -> Result<Polled<Result<Arc<Q::Output>, Arc<anyhow::Error>>>, QueryError> {
645 let (value, revision) = self.query_internal(query)?;
646 Ok(Polled { value, revision })
647 }
648
649 pub fn changed_at<Q: Query>(&self, key: &Q::CacheKey) -> Option<RevisionCounter> {
668 let full_key = FullCacheKey::new::<Q, _>(key);
669 self.whale.get(&full_key).map(|node| node.changed_at)
670 }
671}
672
/// Builder for [`QueryRuntime`].
///
/// Starts with the no-op tracer and the default error comparator (which never
/// treats two user errors as equal).
pub struct QueryRuntimeBuilder<T: Tracer = NoopTracer> {
    /// Comparator used for early cutoff of cached user errors.
    error_comparator: ErrorComparator,
    /// Tracer the built runtime will report to (wrapped in an `Arc` by `build`).
    tracer: T,
}
694
695impl Default for QueryRuntimeBuilder<NoopTracer> {
696 fn default() -> Self {
697 Self::new()
698 }
699}
700
701impl QueryRuntimeBuilder<NoopTracer> {
702 pub fn new() -> Self {
704 Self {
705 error_comparator: default_error_comparator,
706 tracer: NoopTracer,
707 }
708 }
709}
710
711impl<T: Tracer> QueryRuntimeBuilder<T> {
712 pub fn error_comparator(mut self, f: ErrorComparator) -> Self {
730 self.error_comparator = f;
731 self
732 }
733
734 pub fn tracer<U: Tracer>(self, tracer: U) -> QueryRuntimeBuilder<U> {
736 QueryRuntimeBuilder {
737 error_comparator: self.error_comparator,
738 tracer,
739 }
740 }
741
742 pub fn build(self) -> QueryRuntime<T> {
744 QueryRuntime {
745 whale: WhaleRuntime::new(),
746 assets: Arc::new(AssetStorage::new()),
747 locators: Arc::new(LocatorStorage::new()),
748 pending: Arc::new(PendingStorage::new()),
749 query_registry: Arc::new(QueryRegistry::new()),
750 asset_key_registry: Arc::new(AssetKeyRegistry::new()),
751 verifiers: Arc::new(VerifierStorage::new()),
752 error_comparator: self.error_comparator,
753 tracer: Arc::new(self.tracer),
754 }
755 }
756}
757
758impl<T: Tracer> QueryRuntime<T> {
763 pub fn register_asset_locator<K, L>(&self, locator: L)
775 where
776 K: AssetKey,
777 L: AssetLocator<K>,
778 {
779 self.locators.insert::<K, L>(locator);
780 }
781
782 pub fn pending_assets(&self) -> Vec<PendingAsset> {
798 self.pending.get_all()
799 }
800
801 pub fn pending_assets_of<K: AssetKey>(&self) -> Vec<K> {
803 self.pending.get_of_type::<K>()
804 }
805
806 pub fn has_pending_assets(&self) -> bool {
808 !self.pending.is_empty()
809 }
810
811 pub fn resolve_asset<K: AssetKey>(&self, key: K, value: K::Asset) {
828 let durability = key.durability();
829 self.resolve_asset_internal(key, value, durability);
830 }
831
832 pub fn resolve_asset_with_durability<K: AssetKey>(
836 &self,
837 key: K,
838 value: K::Asset,
839 durability: DurabilityLevel,
840 ) {
841 self.resolve_asset_internal(key, value, durability);
842 }
843
844 fn resolve_asset_internal<K: AssetKey>(
845 &self,
846 key: K,
847 value: K::Asset,
848 durability_level: DurabilityLevel,
849 ) {
850 let full_asset_key = FullAssetKey::new(&key);
851 let full_cache_key = FullCacheKey::from_asset_key(&full_asset_key);
852
853 let changed = if let Some(old_value) = self.assets.get_ready::<K>(&full_asset_key) {
855 !K::asset_eq(&old_value, &value)
856 } else {
857 true };
859
860 self.tracer.on_asset_resolved(
862 TracerAssetKey::new(std::any::type_name::<K>(), format!("{:?}", key)),
863 changed,
864 );
865
866 self.assets
868 .insert_ready::<K>(full_asset_key.clone(), Arc::new(value));
869
870 self.pending.remove(&full_asset_key);
872
873 let durability =
875 Durability::new(durability_level.as_u8() as usize).unwrap_or(Durability::volatile());
876
877 if changed {
878 let _ = self
880 .whale
881 .register(full_cache_key, None, durability, vec![]);
882 } else {
883 let _ = self.whale.confirm_unchanged(&full_cache_key, vec![]);
885 }
886
887 let is_new_asset = self.asset_key_registry.register(&key);
889 if is_new_asset {
890 let sentinel = FullCacheKey::asset_key_set_sentinel::<K>();
892 let _ = self
893 .whale
894 .register(sentinel, None, Durability::volatile(), vec![]);
895 }
896 }
897
898 pub fn invalidate_asset<K: AssetKey>(&self, key: &K) {
912 let full_asset_key = FullAssetKey::new(key);
913 let full_cache_key = FullCacheKey::from_asset_key(&full_asset_key);
914
915 self.tracer.on_asset_invalidated(TracerAssetKey::new(
917 std::any::type_name::<K>(),
918 format!("{:?}", key),
919 ));
920
921 self.assets
923 .insert(full_asset_key.clone(), AssetState::Loading);
924
925 self.pending.insert::<K>(full_asset_key, key.clone());
927
928 let _ = self
930 .whale
931 .register(full_cache_key, None, Durability::volatile(), vec![]);
932 }
933
934 pub fn remove_asset<K: AssetKey>(&self, key: &K) {
939 let full_asset_key = FullAssetKey::new(key);
940 let full_cache_key = FullCacheKey::from_asset_key(&full_asset_key);
941
942 let _ = self
945 .whale
946 .register(full_cache_key.clone(), None, Durability::volatile(), vec![]);
947
948 self.assets.remove(&full_asset_key);
950 self.pending.remove(&full_asset_key);
951
952 self.whale.remove(&full_cache_key);
954
955 if self.asset_key_registry.remove::<K>(key) {
957 let sentinel = FullCacheKey::asset_key_set_sentinel::<K>();
958 let _ = self
959 .whale
960 .register(sentinel, None, Durability::volatile(), vec![]);
961 }
962 }
963
964 pub fn get_asset<K: AssetKey>(&self, key: K) -> Result<AssetLoadingState<K>, QueryError> {
976 self.get_asset_internal(key)
977 }
978
979 fn get_asset_internal<K: AssetKey>(&self, key: K) -> Result<AssetLoadingState<K>, QueryError> {
981 let full_asset_key = FullAssetKey::new(&key);
982 let full_cache_key = FullCacheKey::from_asset_key(&full_asset_key);
983
984 let emit_requested = |tracer: &T, key: &K, state: TracerAssetState| {
986 tracer.on_asset_requested(
987 TracerAssetKey::new(std::any::type_name::<K>(), format!("{:?}", key)),
988 state,
989 );
990 };
991
992 if let Some(state) = self.assets.get(&full_asset_key) {
994 if self.whale.is_valid(&full_cache_key) {
996 return match state {
997 AssetState::Loading => {
998 emit_requested(&self.tracer, &key, TracerAssetState::Loading);
999 Ok(AssetLoadingState::loading(key))
1000 }
1001 AssetState::Ready(arc) => {
1002 emit_requested(&self.tracer, &key, TracerAssetState::Ready);
1003 match arc.downcast::<K::Asset>() {
1004 Ok(value) => Ok(AssetLoadingState::ready(key, value)),
1005 Err(_) => Err(QueryError::MissingDependency {
1006 description: format!("Asset type mismatch: {:?}", key),
1007 }),
1008 }
1009 }
1010 AssetState::NotFound => {
1011 emit_requested(&self.tracer, &key, TracerAssetState::NotFound);
1012 Err(QueryError::MissingDependency {
1013 description: format!("Asset not found: {:?}", key),
1014 })
1015 }
1016 };
1017 }
1018 }
1019
1020 if let Some(locator) = self.locators.get(TypeId::of::<K>()) {
1022 if let Some(state) = locator.locate_any(&key) {
1023 self.assets.insert(full_asset_key.clone(), state.clone());
1025
1026 match state {
1027 AssetState::Ready(arc) => {
1028 emit_requested(&self.tracer, &key, TracerAssetState::Ready);
1029
1030 let durability = Durability::new(key.durability().as_u8() as usize)
1032 .unwrap_or(Durability::volatile());
1033 self.whale
1034 .register(full_cache_key, None, durability, vec![])
1035 .expect("register with no dependencies cannot fail");
1036
1037 match arc.downcast::<K::Asset>() {
1038 Ok(value) => return Ok(AssetLoadingState::ready(key, value)),
1039 Err(_) => {
1040 return Err(QueryError::MissingDependency {
1041 description: format!("Asset type mismatch: {:?}", key),
1042 })
1043 }
1044 }
1045 }
1046 AssetState::Loading => {
1047 emit_requested(&self.tracer, &key, TracerAssetState::Loading);
1048 self.pending.insert::<K>(full_asset_key, key.clone());
1049
1050 self.whale
1052 .register(full_cache_key, None, Durability::volatile(), vec![])
1053 .expect("register with no dependencies cannot fail");
1054
1055 return Ok(AssetLoadingState::loading(key));
1056 }
1057 AssetState::NotFound => {
1058 emit_requested(&self.tracer, &key, TracerAssetState::NotFound);
1059 return Err(QueryError::MissingDependency {
1060 description: format!("Asset not found: {:?}", key),
1061 });
1062 }
1063 }
1064 }
1065 }
1066
1067 emit_requested(&self.tracer, &key, TracerAssetState::Loading);
1069 self.assets
1070 .insert(full_asset_key.clone(), AssetState::Loading);
1071 self.pending
1072 .insert::<K>(full_asset_key.clone(), key.clone());
1073
1074 self.whale
1076 .register(full_cache_key, None, Durability::volatile(), vec![])
1077 .expect("register with no dependencies cannot fail");
1078
1079 Ok(AssetLoadingState::loading(key))
1080 }
1081}
1082
1083impl<T: Tracer> Db for QueryRuntime<T> {
1084 fn query<Q: Query>(&self, query: Q) -> Result<Arc<Q::Output>, QueryError> {
1085 QueryRuntime::query(self, query)
1086 }
1087
1088 fn asset<K: AssetKey>(&self, key: K) -> Result<AssetLoadingState<K>, QueryError> {
1089 self.get_asset_internal(key)
1090 }
1091
1092 fn list_queries<Q: Query>(&self) -> Vec<Q> {
1093 self.query_registry.get_all::<Q>()
1094 }
1095
1096 fn list_asset_keys<K: AssetKey>(&self) -> Vec<K> {
1097 self.asset_key_registry.get_all::<K>()
1098 }
1099}
1100
/// Database handle passed to a query while it executes.
///
/// Every query, asset or listing it performs is reported to the tracer and
/// recorded in `deps`. `execute_query` registers those recorded keys as the
/// query's dependencies.
pub struct QueryContext<'a, T: Tracer = NoopTracer> {
    /// Runtime executing the query.
    runtime: &'a QueryRuntime<T>,
    /// Cache key of the query being executed.
    current_key: FullCacheKey,
    /// Type name of the query being executed (used in tracer keys).
    parent_query_type: &'static str,
    /// Tracing context (span) of this execution.
    exec_ctx: ExecutionContext,
    /// Keys read so far by this execution, in access order.
    deps: RefCell<Vec<FullCacheKey>>,
}
1111
1112impl<'a, T: Tracer> QueryContext<'a, T> {
1113 pub fn query<Q: Query>(&self, query: Q) -> Result<Arc<Q::Output>, QueryError> {
1126 let key = query.cache_key();
1127 let full_key = FullCacheKey::new::<Q, _>(&key);
1128
1129 self.runtime.tracer.on_dependency_registered(
1131 self.exec_ctx.span_id(),
1132 TracerQueryKey::new(self.parent_query_type, self.current_key.debug_repr()),
1133 TracerQueryKey::new(std::any::type_name::<Q>(), full_key.debug_repr()),
1134 );
1135
1136 self.deps.borrow_mut().push(full_key.clone());
1138
1139 self.runtime.query(query)
1141 }
1142
1143 pub fn asset<K: AssetKey>(&self, key: K) -> Result<AssetLoadingState<K>, QueryError> {
1167 let full_asset_key = FullAssetKey::new(&key);
1168 let full_cache_key = FullCacheKey::from_asset_key(&full_asset_key);
1169
1170 self.runtime.tracer.on_asset_dependency_registered(
1172 self.exec_ctx.span_id(),
1173 TracerQueryKey::new(self.parent_query_type, self.current_key.debug_repr()),
1174 TracerAssetKey::new(std::any::type_name::<K>(), format!("{:?}", key)),
1175 );
1176
1177 self.deps.borrow_mut().push(full_cache_key);
1179
1180 let result = self.runtime.get_asset_internal(key);
1182
1183 if let Err(QueryError::MissingDependency { ref description }) = result {
1185 self.runtime.tracer.on_missing_dependency(
1186 TracerQueryKey::new(self.parent_query_type, self.current_key.debug_repr()),
1187 description.clone(),
1188 );
1189 }
1190
1191 result
1192 }
1193
1194 pub fn list_queries<Q: Query>(&self) -> Vec<Q> {
1217 let sentinel = FullCacheKey::query_set_sentinel::<Q>();
1219
1220 self.runtime.tracer.on_dependency_registered(
1221 self.exec_ctx.span_id(),
1222 TracerQueryKey::new(self.parent_query_type, self.current_key.debug_repr()),
1223 TracerQueryKey::new("QuerySet", sentinel.debug_repr()),
1224 );
1225
1226 if self.runtime.whale.get(&sentinel).is_none() {
1228 let _ =
1229 self.runtime
1230 .whale
1231 .register(sentinel.clone(), None, Durability::volatile(), vec![]);
1232 }
1233
1234 self.deps.borrow_mut().push(sentinel);
1235
1236 self.runtime.query_registry.get_all::<Q>()
1238 }
1239
1240 pub fn list_asset_keys<K: AssetKey>(&self) -> Vec<K> {
1265 let sentinel = FullCacheKey::asset_key_set_sentinel::<K>();
1267
1268 self.runtime.tracer.on_asset_dependency_registered(
1269 self.exec_ctx.span_id(),
1270 TracerQueryKey::new(self.parent_query_type, self.current_key.debug_repr()),
1271 TracerAssetKey::new("AssetKeySet", sentinel.debug_repr()),
1272 );
1273
1274 if self.runtime.whale.get(&sentinel).is_none() {
1276 let _ =
1277 self.runtime
1278 .whale
1279 .register(sentinel.clone(), None, Durability::volatile(), vec![]);
1280 }
1281
1282 self.deps.borrow_mut().push(sentinel);
1283
1284 self.runtime.asset_key_registry.get_all::<K>()
1286 }
1287}
1288
1289impl<'a, T: Tracer> Db for QueryContext<'a, T> {
1290 fn query<Q: Query>(&self, query: Q) -> Result<Arc<Q::Output>, QueryError> {
1291 QueryContext::query(self, query)
1292 }
1293
1294 fn asset<K: AssetKey>(&self, key: K) -> Result<AssetLoadingState<K>, QueryError> {
1295 QueryContext::asset(self, key)
1296 }
1297
1298 fn list_queries<Q: Query>(&self) -> Vec<Q> {
1299 QueryContext::list_queries(self)
1300 }
1301
1302 fn list_asset_keys<K: AssetKey>(&self) -> Vec<K> {
1303 QueryContext::list_asset_keys(self)
1304 }
1305}
1306
#[cfg(test)]
mod tests {
    use super::*;

    // A query with no dependencies returns its computed value; a repeated call
    // with the same key returns the same value.
    #[test]
    fn test_simple_query() {
        #[derive(Clone)]
        struct Add {
            a: i32,
            b: i32,
        }

        impl Query for Add {
            type CacheKey = (i32, i32);
            type Output = i32;

            fn cache_key(&self) -> Self::CacheKey {
                (self.a, self.b)
            }

            fn query(self, _db: &impl Db) -> Result<Self::Output, QueryError> {
                Ok(self.a + self.b)
            }

            fn output_eq(old: &Self::Output, new: &Self::Output) -> bool {
                old == new
            }
        }

        let runtime = QueryRuntime::new();

        let result = runtime.query(Add { a: 1, b: 2 }).unwrap();
        assert_eq!(*result, 3);

        // Same key again: expected to be served from the cache (only the value
        // is asserted here).
        let result2 = runtime.query(Add { a: 1, b: 2 }).unwrap();
        assert_eq!(*result2, 3);
    }

    // A query that reads another query through `db.query` sees its output.
    #[test]
    fn test_dependent_queries() {
        #[derive(Clone)]
        struct Base {
            value: i32,
        }

        impl Query for Base {
            type CacheKey = i32;
            type Output = i32;

            fn cache_key(&self) -> Self::CacheKey {
                self.value
            }

            fn query(self, _db: &impl Db) -> Result<Self::Output, QueryError> {
                Ok(self.value * 2)
            }

            fn output_eq(old: &Self::Output, new: &Self::Output) -> bool {
                old == new
            }
        }

        #[derive(Clone)]
        struct Derived {
            base_value: i32,
        }

        impl Query for Derived {
            type CacheKey = i32;
            type Output = i32;

            fn cache_key(&self) -> Self::CacheKey {
                self.base_value
            }

            fn query(self, db: &impl Db) -> Result<Self::Output, QueryError> {
                let base = db.query(Base {
                    value: self.base_value,
                })?;
                Ok(*base + 10)
            }

            fn output_eq(old: &Self::Output, new: &Self::Output) -> bool {
                old == new
            }
        }

        let runtime = QueryRuntime::new();

        let result = runtime.query(Derived { base_value: 5 }).unwrap();
        // 5 * 2 (Base) + 10 (Derived)
        assert_eq!(*result, 20);
    }

    // CycleA -> CycleB -> CycleA with the same id must be reported as a cycle
    // instead of recursing forever.
    #[test]
    fn test_cycle_detection() {
        #[derive(Clone)]
        struct CycleA {
            id: i32,
        }

        #[derive(Clone)]
        struct CycleB {
            id: i32,
        }

        impl Query for CycleA {
            type CacheKey = i32;
            type Output = i32;

            fn cache_key(&self) -> Self::CacheKey {
                self.id
            }

            fn query(self, db: &impl Db) -> Result<Self::Output, QueryError> {
                let b = db.query(CycleB { id: self.id })?;
                Ok(*b + 1)
            }

            fn output_eq(old: &Self::Output, new: &Self::Output) -> bool {
                old == new
            }
        }

        impl Query for CycleB {
            type CacheKey = i32;
            type Output = i32;

            fn cache_key(&self) -> Self::CacheKey {
                self.id
            }

            fn query(self, db: &impl Db) -> Result<Self::Output, QueryError> {
                let a = db.query(CycleA { id: self.id })?;
                Ok(*a + 1)
            }

            fn output_eq(old: &Self::Output, new: &Self::Output) -> bool {
                old == new
            }
        }

        let runtime = QueryRuntime::new();

        let result = runtime.query(CycleA { id: 1 });
        assert!(matches!(result, Err(QueryError::Cycle { .. })));
    }

    // Domain failures can be modelled inside `Output` (here a `Result`), so
    // they are cached as regular values instead of surfacing as `QueryError`.
    #[test]
    fn test_fallible_query() {
        #[derive(Clone)]
        struct ParseInt {
            input: String,
        }

        impl Query for ParseInt {
            type CacheKey = String;
            type Output = Result<i32, std::num::ParseIntError>;

            fn cache_key(&self) -> Self::CacheKey {
                self.input.clone()
            }

            fn query(self, _db: &impl Db) -> Result<Self::Output, QueryError> {
                Ok(self.input.parse())
            }

            fn output_eq(old: &Self::Output, new: &Self::Output) -> bool {
                old == new
            }
        }

        let runtime = QueryRuntime::new();

        let result = runtime
            .query(ParseInt {
                input: "42".to_string(),
            })
            .unwrap();
        assert_eq!(*result, Ok(42));

        // The outer `unwrap` succeeds; the parse failure is inside the output.
        let result = runtime
            .query(ParseInt {
                input: "not_a_number".to_string(),
            })
            .unwrap();
        assert!(result.is_err());
    }

    // Tests for the `#[query]` attribute macro. The generated struct names
    // (e.g. `Add`, `WithDurability`) derive from the fn names, so those names
    // must not change.
    mod macro_tests {
        use super::*;
        use crate::query;

        #[query]
        fn add(db: &impl Db, a: i32, b: i32) -> Result<i32, QueryError> {
            // `db` is unused; bind it to silence the warning.
            let _ = db;
            Ok(a + b)
        }

        #[test]
        fn test_macro_basic() {
            let runtime = QueryRuntime::new();
            let result = runtime.query(Add::new(1, 2)).unwrap();
            assert_eq!(*result, 3);
        }

        #[query(durability = 2)]
        fn with_durability(db: &impl Db, x: i32) -> Result<i32, QueryError> {
            let _ = db;
            Ok(x * 2)
        }

        #[test]
        fn test_macro_durability() {
            let runtime = QueryRuntime::new();
            let result = runtime.query(WithDurability::new(5)).unwrap();
            assert_eq!(*result, 10);
        }

        // Only `id` is part of the cache key; `include_extra` is not.
        #[query(keys(id))]
        fn with_key_selection(
            db: &impl Db,
            id: u32,
            include_extra: bool,
        ) -> Result<String, QueryError> {
            let _ = db;
            Ok(format!("id={}, extra={}", id, include_extra))
        }

        #[test]
        fn test_macro_key_selection() {
            let runtime = QueryRuntime::new();

            let r1 = runtime.query(WithKeySelection::new(1, true)).unwrap();
            // Same key (`id = 1`), different non-key argument.
            let r2 = runtime.query(WithKeySelection::new(1, false)).unwrap();

            assert_eq!(*r1, "id=1, extra=true");
            // The second call is answered from the cached first result, so
            // `extra=false` is never observed.
            assert_eq!(*r2, "id=1, extra=true");
        }

        #[query]
        fn dependent(db: &impl Db, a: i32, b: i32) -> Result<i32, QueryError> {
            let sum = db.query(Add::new(a, b))?;
            Ok(*sum * 2)
        }

        #[test]
        fn test_macro_dependencies() {
            let runtime = QueryRuntime::new();
            let result = runtime.query(Dependent::new(3, 4)).unwrap();
            // (3 + 4) * 2
            assert_eq!(*result, 14);
        }

        #[query(output_eq)]
        fn with_output_eq(db: &impl Db, x: i32) -> Result<i32, QueryError> {
            let _ = db;
            Ok(x * 2)
        }

        #[test]
        fn test_macro_output_eq() {
            let runtime = QueryRuntime::new();
            let result = runtime.query(WithOutputEq::new(5)).unwrap();
            assert_eq!(*result, 10);
        }

        // `name = "..."` overrides the generated struct name.
        #[query(name = "CustomName")]
        fn original_name(db: &impl Db, x: i32) -> Result<i32, QueryError> {
            let _ = db;
            Ok(x)
        }

        #[test]
        fn test_macro_custom_name() {
            let runtime = QueryRuntime::new();
            let result = runtime.query(CustomName::new(42)).unwrap();
            assert_eq!(*result, 42);
        }

        // Attributes written around `#[query]` must be carried through by the
        // macro; this test depends on their exact placement and order.
        #[allow(unused_variables)]
        #[inline]
        #[query]
        fn with_attributes(db: &impl Db, x: i32) -> Result<i32, QueryError> {
            // Would warn without the preserved `#[allow(unused_variables)]`.
            let unused_var = 42;
            Ok(x * 2)
        }

        #[test]
        fn test_macro_preserves_attributes() {
            let runtime = QueryRuntime::new();
            let result = runtime.query(WithAttributes::new(5)).unwrap();
            assert_eq!(*result, 10);
        }
    }

    // Tests for `poll` / `changed_at`, i.e. revision-based change detection.
    mod poll_tests {
        use super::*;

        #[derive(Clone)]
        struct Counter {
            id: i32,
        }

        impl Query for Counter {
            type CacheKey = i32;
            type Output = i32;

            fn cache_key(&self) -> Self::CacheKey {
                self.id
            }

            fn query(self, _db: &impl Db) -> Result<Self::Output, QueryError> {
                Ok(self.id * 10)
            }

            fn output_eq(old: &Self::Output, new: &Self::Output) -> bool {
                old == new
            }
        }

        #[test]
        fn test_poll_returns_value_and_revision() {
            let runtime = QueryRuntime::new();

            let result = runtime.poll(Counter { id: 1 }).unwrap();

            // `value` holds the query's own Result; unwrap it and the Arc.
            assert_eq!(**result.value.as_ref().unwrap(), 10);

            // A freshly executed query has a non-zero revision.
            assert!(result.revision > 0);
        }

        #[test]
        fn test_poll_revision_stable_on_cache_hit() {
            let runtime = QueryRuntime::new();

            let result1 = runtime.poll(Counter { id: 1 }).unwrap();
            let rev1 = result1.revision;

            let result2 = runtime.poll(Counter { id: 1 }).unwrap();
            let rev2 = result2.revision;

            // Nothing changed between the polls, so the revision is unchanged.
            assert_eq!(rev1, rev2);
        }

        #[test]
        fn test_poll_revision_changes_on_invalidate() {
            let runtime = QueryRuntime::new();

            let result1 = runtime.poll(Counter { id: 1 }).unwrap();
            let rev1 = result1.revision;

            // Drop the cached entry; the next poll re-executes the query.
            runtime.invalidate::<Counter>(&1);
            let result2 = runtime.poll(Counter { id: 1 }).unwrap();
            let rev2 = result2.revision;

            // Re-execution yields the same value.
            assert_eq!(**result2.value.as_ref().unwrap(), 10);

            // The revision never goes backwards after an invalidation.
            assert!(rev2 >= rev1);
        }

        #[test]
        fn test_changed_at_returns_none_for_unexecuted_query() {
            let runtime = QueryRuntime::new();

            let rev = runtime.changed_at::<Counter>(&1);
            assert!(rev.is_none());
        }

        #[test]
        fn test_changed_at_returns_revision_after_execution() {
            let runtime = QueryRuntime::new();

            let _ = runtime.query(Counter { id: 1 }).unwrap();

            let rev = runtime.changed_at::<Counter>(&1);
            assert!(rev.is_some());
            assert!(rev.unwrap() > 0);
        }

        #[test]
        fn test_changed_at_matches_poll_revision() {
            let runtime = QueryRuntime::new();

            let result = runtime.poll(Counter { id: 1 }).unwrap();

            let rev = runtime.changed_at::<Counter>(&1);
            assert_eq!(rev, Some(result.revision));
        }

        #[test]
        fn test_poll_value_access() {
            let runtime = QueryRuntime::new();

            let result = runtime.poll(Counter { id: 5 }).unwrap();

            // `&Arc<i32>` coerces to `&i32` through deref coercion.
            let value: &i32 = result.value.as_ref().unwrap();
            assert_eq!(*value, 50);

            // The Arc itself is also reachable.
            let arc: &Arc<i32> = result.value.as_ref().unwrap();
            assert_eq!(**arc, 50);
        }

        // Typical subscriber loop: notify only when the revision advances.
        // Three polls of an unchanged query produce exactly one notification.
        #[test]
        fn test_subscription_pattern() {
            let runtime = QueryRuntime::new();

            let mut last_revision: RevisionCounter = 0;
            let mut notifications = 0;

            let result = runtime.poll(Counter { id: 1 }).unwrap();
            if result.revision > last_revision {
                notifications += 1;
                last_revision = result.revision;
            }

            let result = runtime.poll(Counter { id: 1 }).unwrap();
            if result.revision > last_revision {
                notifications += 1;
                last_revision = result.revision;
            }

            let result = runtime.poll(Counter { id: 1 }).unwrap();
            if result.revision > last_revision {
                notifications += 1;
                #[allow(unused_assignments)]
                {
                    last_revision = result.revision;
                }
            }

            assert_eq!(notifications, 1);
        }
    }
}