1use async_trait::async_trait;
7use std::collections::{HashMap, HashSet};
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::{broadcast, RwLock};
11
/// Monotonic clock type: `std::time::Instant` natively, `web_time::Instant`
/// on wasm32 where the std clock is unavailable.
#[cfg(not(target_arch = "wasm32"))]
type MonotonicInstant = std::time::Instant;
#[cfg(target_arch = "wasm32")]
type MonotonicInstant = web_time::Instant;

/// Returns the current instant of the platform's monotonic clock.
///
/// The `allow` suppresses clippy's `disallowed_methods` lint for the direct
/// clock read performed here.
#[allow(clippy::disallowed_methods)]
fn monotonic_now() -> MonotonicInstant {
    <MonotonicInstant>::now()
}
21
22use aura_core::domain::journal::FactValue;
23use aura_core::domain::ConsistencyMap;
24use aura_core::effects::reactive::SignalId;
25use aura_core::effects::{
26 indexed::IndexedJournalEffects,
27 query::{QueryEffects, QueryError, QuerySubscription},
28 reactive::{ReactiveEffects, Signal},
29};
30use aura_core::query::{
31 ConsensusId, DatalogBindings, DatalogProgram, FactPredicate, Query, QueryCapability,
32 QueryIsolation, QueryStats,
33};
34use aura_core::{Hash32, ResourceScope};
35
36use crate::database::query::AuraQuery;
37use crate::reactive::ReactiveHandler;
38
39use super::datalog::{format_rule, parse_fact_to_row};
40
41#[async_trait]
46trait QueryRegistration: Send + Sync {
47 fn signal_id(&self) -> &SignalId;
48 fn dependencies(&self) -> &[FactPredicate];
49 async fn refresh(&self, handler: &QueryHandler) -> Result<(), QueryError>;
50}
51
52struct QueryRegistrationImpl<Q: Query> {
53 signal: Signal<Q::Result>,
54 query: Q,
55 deps: Vec<FactPredicate>,
56}
57
58#[async_trait]
59impl<Q: Query> QueryRegistration for QueryRegistrationImpl<Q> {
60 fn signal_id(&self) -> &SignalId {
61 self.signal.id()
62 }
63
64 fn dependencies(&self) -> &[FactPredicate] {
65 &self.deps
66 }
67
68 async fn refresh(&self, handler: &QueryHandler) -> Result<(), QueryError> {
69 let result = handler.query(&self.query).await?;
70 handler
71 .reactive
72 .emit(&self.signal, result)
73 .await
74 .map_err(|e| QueryError::execution_error(e.to_string()))
75 }
76}
77
78#[derive(Debug, Default, Clone)]
84pub(super) struct QueryFacts {
85 facts: HashMap<String, Vec<Vec<String>>>,
87}
88
89impl QueryFacts {
90 pub fn add(&mut self, predicate: &str, args: Vec<String>) {
92 self.facts
93 .entry(predicate.to_string())
94 .or_default()
95 .push(args);
96 }
97
98 pub fn clear(&mut self) {
100 self.facts.clear();
101 }
102
103 #[cfg(test)]
105 pub fn is_empty(&self) -> bool {
106 self.facts.is_empty()
107 }
108
109 pub fn len(&self) -> usize {
111 self.facts.values().map(|v| v.len()).sum()
112 }
113
114 pub fn load_into(&self, query: &mut AuraQuery) {
119 for (predicate, rows) in &self.facts {
120 for args in rows {
121 let terms: Vec<crate::database::query::FactTerm> =
122 args.iter().map(|s| s.clone().into()).collect();
123 let _ = query.add_fact(predicate, terms);
124 }
125 }
126 }
127}
128
/// How the capability checker treats capabilities that were never granted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CapabilityPolicy {
    /// Every capability check passes, granted or not.
    AllowAll,
    /// A capability passes only if a matching one was explicitly granted.
    /// This is the default.
    #[default]
    DenyUnlessGranted,
}
142
143#[derive(Debug)]
148pub(super) struct CapabilityChecker {
149 granted: Vec<QueryCapability>,
151 policy: CapabilityPolicy,
153}
154
155impl CapabilityChecker {
156 pub fn check(&self, cap: &QueryCapability) -> bool {
162 if self.policy == CapabilityPolicy::AllowAll {
163 return true;
164 }
165
166 self.granted
167 .iter()
168 .any(|g| g.resource == cap.resource && g.action == cap.action)
169 }
170
171 pub fn grant(&mut self, cap: QueryCapability) {
173 self.granted.push(cap);
174 }
175
176 pub fn set_policy(&mut self, policy: CapabilityPolicy) {
178 self.policy = policy;
179 }
180}
181
182#[derive(Debug)]
191pub struct ConsensusTracker {
192 completed: HashSet<ConsensusId>,
194 notify_tx: broadcast::Sender<ConsensusId>,
196}
197
198impl Default for ConsensusTracker {
199 fn default() -> Self {
200 let (notify_tx, _) = broadcast::channel(256);
201 Self {
202 completed: HashSet::new(),
203 notify_tx,
204 }
205 }
206}
207
208impl ConsensusTracker {
209 pub fn mark_completed(&mut self, id: ConsensusId) {
213 self.completed.insert(id);
214 let _ = self.notify_tx.send(id);
216 }
217
218 pub fn is_completed(&self, id: &ConsensusId) -> bool {
220 self.completed.contains(id)
221 }
222
223 pub fn subscribe(&self) -> broadcast::Receiver<ConsensusId> {
225 self.notify_tx.subscribe()
226 }
227
228 pub fn all_completed(&self, ids: &[ConsensusId]) -> bool {
230 ids.iter().all(|id| self.completed.contains(id))
231 }
232
233 pub fn from_core_id(id: &aura_core::query::ConsensusId) -> ConsensusId {
235 ConsensusId::new(id.0)
236 }
237
238 pub fn to_core_id(id: &ConsensusId) -> aura_core::query::ConsensusId {
240 aura_core::query::ConsensusId::new(*id.as_bytes())
241 }
242}
243
244#[derive(Debug, Default)]
254pub struct SnapshotStore {
255 snapshots: HashMap<Hash32, QueryFacts>,
257 max_snapshots: usize,
259 creation_order: Vec<Hash32>,
261}
262
263impl SnapshotStore {
264 pub fn new(max_snapshots: usize) -> Self {
266 Self {
267 snapshots: HashMap::new(),
268 max_snapshots,
269 creation_order: Vec::new(),
270 }
271 }
272
273 pub fn create_snapshot(&mut self, prestate_hash: Hash32, facts: &QueryFacts) {
275 let snapshot = QueryFacts {
277 facts: facts.facts.clone(),
278 };
279
280 while self.snapshots.len() >= self.max_snapshots && !self.creation_order.is_empty() {
282 let oldest = self.creation_order.remove(0);
283 self.snapshots.remove(&oldest);
284 }
285
286 self.snapshots.insert(prestate_hash, snapshot);
287 self.creation_order.push(prestate_hash);
288 }
289
290 pub fn get_snapshot(&self, prestate_hash: &Hash32) -> Option<&QueryFacts> {
292 self.snapshots.get(prestate_hash)
293 }
294
295 pub fn has_snapshot(&self, prestate_hash: &Hash32) -> bool {
297 self.snapshots.contains_key(prestate_hash)
298 }
299
300 pub fn remove_snapshot(&mut self, prestate_hash: &Hash32) {
302 self.snapshots.remove(prestate_hash);
303 self.creation_order.retain(|h| h != prestate_hash);
304 }
305}
306
307#[derive(Debug, Default)]
313pub struct PendingConsensusTracker {
314 pending_by_scope: HashMap<ResourceScope, HashSet<ConsensusId>>,
316}
317
318impl PendingConsensusTracker {
319 pub fn register_pending(&mut self, scope: ResourceScope, id: ConsensusId) {
321 self.pending_by_scope.entry(scope).or_default().insert(id);
322 }
323
324 pub fn mark_completed(&mut self, id: &ConsensusId) {
326 for pending_set in self.pending_by_scope.values_mut() {
327 pending_set.remove(id);
328 }
329 }
330
331 pub fn pending_for_scope(&self, scope: &ResourceScope) -> Vec<ConsensusId> {
333 self.pending_by_scope
334 .get(scope)
335 .map(|set| set.iter().copied().collect())
336 .unwrap_or_default()
337 }
338}
339
340fn fact_value_to_args(value: &FactValue) -> Vec<String> {
349 match value {
350 FactValue::String(s) => vec![s.clone()],
351 FactValue::Number(n) => vec![n.to_string()],
352 FactValue::Bytes(b) => {
353 let hex: String = b.iter().map(|byte| format!("{byte:02x}")).collect();
355 vec![hex]
356 }
357 FactValue::Set(s) => s.iter().cloned().collect(),
358 FactValue::Nested(fact) => {
359 vec![format!("{:?}", fact)]
361 }
362 }
363}
364
/// Default upper bound on how long a query waits for consensus completion
/// before failing with a consensus-timeout error (30 seconds).
const DEFAULT_CONSENSUS_TIMEOUT: Duration = Duration::from_millis(30_000);

/// Default number of prestate snapshots retained before the oldest is evicted.
const DEFAULT_MAX_SNAPSHOTS: usize = 100;
374
375pub struct QueryHandler {
404 reactive: Arc<ReactiveHandler>,
406 pub(super) facts: Arc<RwLock<QueryFacts>>,
408 capabilities: Arc<RwLock<CapabilityChecker>>,
410 indexed_journal: Option<Arc<dyn IndexedJournalEffects + Send + Sync>>,
412 consensus_tracker: Arc<RwLock<ConsensusTracker>>,
414 snapshot_store: Arc<RwLock<SnapshotStore>>,
416 pending_consensus: Arc<RwLock<PendingConsensusTracker>>,
418 query_bindings: Arc<RwLock<HashMap<SignalId, Box<dyn QueryRegistration>>>>,
420 consensus_timeout: Duration,
422}
423
424impl QueryHandler {
425 pub fn new(reactive: Arc<ReactiveHandler>) -> Self {
427 Self::new_with_policy(reactive, CapabilityPolicy::default())
428 }
429
430 pub fn new_with_policy(reactive: Arc<ReactiveHandler>, policy: CapabilityPolicy) -> Self {
432 Self {
433 reactive,
434 facts: Arc::new(RwLock::new(QueryFacts::default())),
435 capabilities: Arc::new(RwLock::new(CapabilityChecker {
436 granted: Vec::new(),
437 policy,
438 })),
439 indexed_journal: None,
440 consensus_tracker: Arc::new(RwLock::new(ConsensusTracker::default())),
441 snapshot_store: Arc::new(RwLock::new(SnapshotStore::new(DEFAULT_MAX_SNAPSHOTS))),
442 pending_consensus: Arc::new(RwLock::new(PendingConsensusTracker::default())),
443 query_bindings: Arc::new(RwLock::new(HashMap::new())),
444 consensus_timeout: DEFAULT_CONSENSUS_TIMEOUT,
445 }
446 }
447
448 pub fn with_indexed_journal(
454 reactive: Arc<ReactiveHandler>,
455 indexed_journal: Arc<dyn IndexedJournalEffects + Send + Sync>,
456 ) -> Self {
457 Self::with_indexed_journal_with_policy(
458 reactive,
459 indexed_journal,
460 CapabilityPolicy::default(),
461 )
462 }
463
464 pub fn with_indexed_journal_with_policy(
466 reactive: Arc<ReactiveHandler>,
467 indexed_journal: Arc<dyn IndexedJournalEffects + Send + Sync>,
468 policy: CapabilityPolicy,
469 ) -> Self {
470 Self {
471 reactive,
472 facts: Arc::new(RwLock::new(QueryFacts::default())),
473 capabilities: Arc::new(RwLock::new(CapabilityChecker {
474 granted: Vec::new(),
475 policy,
476 })),
477 indexed_journal: Some(indexed_journal),
478 consensus_tracker: Arc::new(RwLock::new(ConsensusTracker::default())),
479 snapshot_store: Arc::new(RwLock::new(SnapshotStore::new(DEFAULT_MAX_SNAPSHOTS))),
480 pending_consensus: Arc::new(RwLock::new(PendingConsensusTracker::default())),
481 query_bindings: Arc::new(RwLock::new(HashMap::new())),
482 consensus_timeout: DEFAULT_CONSENSUS_TIMEOUT,
483 }
484 }
485
486 pub fn with_consensus_timeout(mut self, timeout: Duration) -> Self {
488 self.consensus_timeout = timeout;
489 self
490 }
491
492 pub async fn add_fact(&self, predicate: &str, args: Vec<String>) {
497 let mut facts = self.facts.write().await;
498 facts.add(predicate, args);
499 }
500
501 pub async fn add_facts(&self, entries: Vec<(String, Vec<String>)>) {
503 let mut facts = self.facts.write().await;
504 for (predicate, args) in entries {
505 facts.add(&predicate, args);
506 }
507 }
508
509 pub async fn clear_facts(&self) {
511 let mut facts = self.facts.write().await;
512 facts.clear();
513 }
514
515 pub async fn grant_capability(&self, cap: QueryCapability) {
521 let mut checker = self.capabilities.write().await;
522 checker.grant(cap);
523 }
524
525 pub async fn set_capability_policy(&self, policy: CapabilityPolicy) {
529 let mut checker = self.capabilities.write().await;
530 checker.set_policy(policy);
531 }
532
533 pub async fn load_facts_for_predicate(&self, predicate: &str) -> Result<usize, QueryError> {
540 let Some(ref indexed) = self.indexed_journal else {
541 return Ok(0); };
543
544 let indexed_facts = indexed
545 .facts_by_predicate(predicate)
546 .await
547 .map_err(|e| QueryError::execution_error(e.to_string()))?;
548
549 let count = indexed_facts.len();
550 let mut facts = self.facts.write().await;
551
552 for fact in indexed_facts {
553 let args = fact_value_to_args(&fact.value);
554 facts.add(&fact.predicate, args);
555 }
556
557 Ok(count)
558 }
559
560 pub fn might_contain_fact(&self, predicate: &str, value: &FactValue) -> bool {
568 match &self.indexed_journal {
569 Some(indexed) => indexed.might_contain(predicate, value),
570 None => true, }
572 }
573
574 pub async fn register_query_binding<Q: Query>(
578 &self,
579 signal: &Signal<Q::Result>,
580 query: Q,
581 ) -> Result<(), QueryError> {
582 let deps = query.dependencies();
583 let registration = QueryRegistrationImpl {
584 signal: signal.clone(),
585 query: query.clone(),
586 deps,
587 };
588
589 self.query_bindings
590 .write()
591 .await
592 .insert(signal.id().clone(), Box::new(registration));
593
594 let result = self.query(&query).await?;
595 self.reactive
596 .emit(signal, result)
597 .await
598 .map_err(|e| QueryError::execution_error(e.to_string()))?;
599
600 Ok(())
601 }
602
603 async fn refresh_queries_for_predicate(&self, predicate: &FactPredicate) {
604 let bindings = self.query_bindings.read().await;
605 for registration in bindings.values() {
606 if registration
607 .dependencies()
608 .iter()
609 .any(|dep| dep.matches(predicate))
610 {
611 if let Err(err) = registration.refresh(self).await {
612 tracing::warn!(
613 error = %err,
614 signal_id = %registration.signal_id(),
615 "Failed to refresh query-bound signal"
616 );
617 }
618 }
619 }
620 }
621
622 pub async fn mark_consensus_completed(&self, id: ConsensusId) {
631 let mut tracker = self.consensus_tracker.write().await;
632 tracker.mark_completed(id);
633
634 let mut pending = self.pending_consensus.write().await;
636 pending.mark_completed(&id);
637 }
638
639 pub async fn register_pending_consensus(&self, scope: ResourceScope, id: ConsensusId) {
644 let mut pending = self.pending_consensus.write().await;
645 pending.register_pending(scope, id);
646 }
647
648 pub async fn create_snapshot(&self, prestate_hash: Hash32) {
653 let facts = self.facts.read().await;
654 let mut store = self.snapshot_store.write().await;
655 store.create_snapshot(prestate_hash, &facts);
656 }
657
658 pub async fn remove_snapshot(&self, prestate_hash: Hash32) {
660 let mut store = self.snapshot_store.write().await;
661 store.remove_snapshot(&prestate_hash);
662 }
663
664 pub async fn has_snapshot(&self, prestate_hash: &Hash32) -> bool {
666 let store = self.snapshot_store.read().await;
667 store.has_snapshot(prestate_hash)
668 }
669
670 async fn wait_for_consensus(&self, ids: &[ConsensusId]) -> Result<(), QueryError> {
675 if ids.is_empty() {
676 return Ok(());
677 }
678
679 {
681 let tracker = self.consensus_tracker.read().await;
682 if tracker.all_completed(ids) {
683 return Ok(());
684 }
685 }
686
687 let mut receiver = {
689 let tracker = self.consensus_tracker.read().await;
690 tracker.subscribe()
691 };
692
693 let deadline = monotonic_now() + self.consensus_timeout;
694
695 loop {
696 {
698 let tracker = self.consensus_tracker.read().await;
699 if tracker.all_completed(ids) {
700 return Ok(());
701 }
702 }
703
704 let remaining = deadline.saturating_duration_since(monotonic_now());
706 if remaining.is_zero() {
707 let tracker = self.consensus_tracker.read().await;
709 for id in ids {
710 if !tracker.is_completed(id) {
711 return Err(QueryError::consensus_timeout(ConsensusTracker::to_core_id(
712 id,
713 )));
714 }
715 }
716 return Ok(()); }
718
719 match tokio::time::timeout(remaining, receiver.recv()).await {
720 Ok(Ok(_completed_id)) => {
721 continue;
723 }
724 Ok(Err(_)) => {
725 return Err(QueryError::internal(
727 "Consensus tracker channel closed unexpectedly",
728 ));
729 }
730 Err(_) => {
731 let tracker = self.consensus_tracker.read().await;
733 for id in ids {
734 if !tracker.is_completed(id) {
735 return Err(QueryError::consensus_timeout(
736 ConsensusTracker::to_core_id(id),
737 ));
738 }
739 }
740 return Ok(());
741 }
742 }
743 }
744 }
745
746 async fn wait_for_scope_consensus(&self, scope: &ResourceScope) -> Result<(), QueryError> {
748 let pending_ids = {
749 let pending = self.pending_consensus.read().await;
750 pending.pending_for_scope(scope)
751 };
752
753 if pending_ids.is_empty() {
754 return Ok(());
755 }
756
757 self.wait_for_consensus(&pending_ids).await
758 }
759
760 async fn execute_snapshot_query<Q: Query>(
762 &self,
763 query: &Q,
764 prestate_hash: &Hash32,
765 ) -> Result<Q::Result, QueryError> {
766 let required = query.required_capabilities();
768 self.check_capabilities(&required).await?;
769
770 let snapshot = {
772 let store = self.snapshot_store.read().await;
773 store
774 .get_snapshot(prestate_hash)
775 .ok_or_else(|| QueryError::snapshot_not_available(*prestate_hash))?
776 .clone()
777 };
778
779 let program = query.to_datalog();
781 let bindings = self.execute_program_with_facts(&program, &snapshot).await?;
782
783 Q::parse(bindings).map_err(QueryError::from)
784 }
785
786 async fn execute_program_with_facts(
788 &self,
789 program: &DatalogProgram,
790 facts: &QueryFacts,
791 ) -> Result<DatalogBindings, QueryError> {
792 let mut aura_query = AuraQuery::new();
793
794 facts.load_into(&mut aura_query);
796
797 let mut all_rows = Vec::new();
799
800 for rule in &program.rules {
801 let rule_string = format_rule(rule);
802
803 match aura_query.query(&rule_string) {
804 Ok(result) => {
805 for fact_strings in result.facts {
806 let row = parse_fact_to_row(&fact_strings);
807 all_rows.push(row);
808 }
809 }
810 Err(e) => {
811 tracing::warn!(rule = %rule_string, error = %e, "Rule execution failed");
812 }
813 }
814 }
815
816 Ok(DatalogBindings { rows: all_rows })
817 }
818
819 async fn execute_program(
821 &self,
822 program: &DatalogProgram,
823 ) -> Result<DatalogBindings, QueryError> {
824 let facts = self.facts.read().await;
825 self.execute_program_with_facts(program, &facts).await
826 }
827}
828
829impl Default for QueryHandler {
830 fn default() -> Self {
831 Self::new(Arc::new(ReactiveHandler::new()))
832 }
833}
834
835impl Clone for QueryHandler {
836 fn clone(&self) -> Self {
837 Self {
838 reactive: self.reactive.clone(),
839 facts: self.facts.clone(),
840 capabilities: self.capabilities.clone(),
841 indexed_journal: self.indexed_journal.clone(),
842 consensus_tracker: self.consensus_tracker.clone(),
843 snapshot_store: self.snapshot_store.clone(),
844 pending_consensus: self.pending_consensus.clone(),
845 query_bindings: self.query_bindings.clone(),
846 consensus_timeout: self.consensus_timeout,
847 }
848 }
849}
850
851#[async_trait]
856impl QueryEffects for QueryHandler {
857 async fn query<Q: Query>(&self, query: &Q) -> Result<Q::Result, QueryError> {
858 let required = query.required_capabilities();
860 self.check_capabilities(&required).await?;
861
862 let program = query.to_datalog();
864
865 let bindings = self.execute_program(&program).await?;
867
868 Q::parse(bindings).map_err(QueryError::from)
870 }
871
872 async fn query_raw(&self, program: &DatalogProgram) -> Result<DatalogBindings, QueryError> {
873 self.execute_program(program).await
874 }
875
876 fn subscribe<Q: Query>(&self, query: &Q) -> QuerySubscription<Q::Result> {
877 let signal_name = format!("query:{}:{}", std::any::type_name::<Q>(), query.query_id());
879 let signal: Signal<Q::Result> = Signal::new(signal_name.as_str());
880
881 let stream = match self.reactive.subscribe(&signal) {
883 Ok(stream) => stream,
884 Err(error) => {
885 tracing::error!(
886 signal_id = %signal.id(),
887 query_id = %query.query_id(),
888 error = %error,
889 "query subscription requested before the reactive signal was registered"
890 );
891 let (_tx, receiver) = broadcast::channel(1);
892 aura_core::effects::reactive::SignalStream::new(receiver, signal.id().clone())
893 }
894 };
895
896 QuerySubscription::new(stream, query.query_id())
898 }
899
900 async fn check_capabilities(&self, capabilities: &[QueryCapability]) -> Result<(), QueryError> {
901 let checker = self.capabilities.read().await;
902
903 for cap in capabilities {
904 if !checker.check(cap) {
905 return Err(QueryError::missing_capability(cap));
906 }
907 }
908
909 Ok(())
910 }
911
912 async fn invalidate(&self, predicate: &FactPredicate) {
913 self.reactive.invalidate_queries(predicate).await;
914 self.refresh_queries_for_predicate(predicate).await;
915 }
916
917 async fn query_with_isolation<Q: Query>(
918 &self,
919 query: &Q,
920 isolation: QueryIsolation,
921 ) -> Result<Q::Result, QueryError> {
922 match &isolation {
923 QueryIsolation::ReadUncommitted => {
924 self.query(query).await
926 }
927 QueryIsolation::ReadCommitted { wait_for } => {
928 let local_ids: Vec<ConsensusId> = wait_for
930 .iter()
931 .map(ConsensusTracker::from_core_id)
932 .collect();
933
934 self.wait_for_consensus(&local_ids).await?;
936
937 self.query(query).await
939 }
940 QueryIsolation::Snapshot { prestate_hash } => {
941 self.execute_snapshot_query(query, prestate_hash).await
943 }
944 QueryIsolation::ReadLatest { scope } => {
945 self.wait_for_scope_consensus(scope).await?;
947
948 self.query(query).await
950 }
951 }
952 }
953
954 #[allow(clippy::disallowed_methods)] async fn query_with_stats<Q: Query>(
956 &self,
957 query: &Q,
958 ) -> Result<(Q::Result, QueryStats), QueryError> {
959 let start = crate::time::monotonic_now();
960
961 let result = self.query(query).await?;
963
964 let stats = QueryStats::new(start.elapsed())
966 .with_facts_scanned(self.facts.read().await.len() as u32)
967 .with_isolation(QueryIsolation::ReadUncommitted);
968
969 Ok((result, stats))
970 }
971
972 async fn query_with_consistency<Q: Query>(
973 &self,
974 query: &Q,
975 ) -> Result<(Q::Result, ConsistencyMap), QueryError> {
976 let result = self.query(query).await?;
978
979 let consistency = ConsistencyMap::new();
985
986 Ok((result, consistency))
987 }
988
989 #[allow(clippy::disallowed_methods)] async fn query_full<Q: Query>(
991 &self,
992 query: &Q,
993 isolation: QueryIsolation,
994 ) -> Result<(Q::Result, QueryStats), QueryError> {
995 let start = crate::time::monotonic_now();
996
997 let result = self.query_with_isolation(query, isolation.clone()).await?;
999
1000 let stats = QueryStats::new(start.elapsed())
1002 .with_facts_scanned(self.facts.read().await.len() as u32)
1003 .with_isolation(isolation);
1004
1005 Ok((result, stats))
1006 }
1007
1008 async fn register_query_binding<Q: Query>(
1009 &self,
1010 signal: &Signal<Q::Result>,
1011 query: Q,
1012 ) -> Result<(), QueryError> {
1013 QueryHandler::register_query_binding(self, signal, query).await
1014 }
1015}
1016
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_handler_creation() {
        let handler = QueryHandler::default();
        assert!(handler.facts.read().await.is_empty());
    }

    #[tokio::test]
    async fn test_add_fact() {
        let handler = QueryHandler::default();

        handler
            .add_fact("user", vec!["alice".to_string(), "admin".to_string()])
            .await;

        let facts = handler.facts.read().await;
        assert!(!facts.is_empty());
        // Exactly one row, regardless of its arity.
        assert_eq!(facts.len(), 1);
    }

    #[tokio::test]
    async fn test_add_multiple_facts() {
        let handler = QueryHandler::default();

        handler
            .add_facts(vec![
                ("user".to_string(), vec!["alice".to_string()]),
                ("user".to_string(), vec!["bob".to_string()]),
                ("role".to_string(), vec!["admin".to_string()]),
            ])
            .await;

        let facts = handler.facts.read().await;
        assert!(!facts.is_empty());
        // Rows are counted across predicates: two `user` + one `role`.
        assert_eq!(facts.len(), 3);
    }

    #[tokio::test]
    async fn test_clear_facts() {
        let handler = QueryHandler::default();

        handler.add_fact("test", vec!["value".to_string()]).await;
        handler.clear_facts().await;

        assert!(handler.facts.read().await.is_empty());
    }

    #[tokio::test]
    async fn test_check_capabilities_empty() {
        let handler = QueryHandler::default();

        let result = handler.check_capabilities(&[]).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_check_capabilities_denied_by_default() {
        let handler = QueryHandler::default();

        let cap = QueryCapability::read("messages");
        let result = handler.check_capabilities(&[cap]).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_grant_capability() {
        let handler = QueryHandler::default();

        let cap = QueryCapability::read("messages");
        handler.grant_capability(cap.clone()).await;

        let result = handler.check_capabilities(&[cap]).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_allow_all_policy_permits_ungranted_capability() {
        let handler = QueryHandler::new_with_policy(
            Arc::new(ReactiveHandler::new()),
            CapabilityPolicy::AllowAll,
        );

        let cap = QueryCapability::read("messages");
        assert!(handler.check_capabilities(&[cap]).await.is_ok());
    }

    #[tokio::test]
    async fn test_consensus_tracker_completion() {
        let handler = QueryHandler::default();

        let consensus_id = ConsensusId::new([1u8; 32]);

        {
            let tracker = handler.consensus_tracker.read().await;
            assert!(!tracker.is_completed(&consensus_id));
        }

        handler.mark_consensus_completed(consensus_id).await;

        {
            let tracker = handler.consensus_tracker.read().await;
            assert!(tracker.is_completed(&consensus_id));
        }
    }

    #[tokio::test]
    async fn test_snapshot_create_and_retrieve() {
        let handler = QueryHandler::default();

        handler.add_fact("user", vec!["alice".to_string()]).await;
        handler.add_fact("user", vec!["bob".to_string()]).await;

        let prestate_hash = Hash32([42u8; 32]);
        handler.create_snapshot(prestate_hash).await;

        assert!(handler.has_snapshot(&prestate_hash).await);

        {
            let store = handler.snapshot_store.read().await;
            let snapshot = store.get_snapshot(&prestate_hash).unwrap();
            assert_eq!(snapshot.len(), 2);
        }
    }

    #[tokio::test]
    async fn test_snapshot_removal() {
        let handler = QueryHandler::default();

        let prestate_hash = Hash32([42u8; 32]);
        handler.create_snapshot(prestate_hash).await;

        assert!(handler.has_snapshot(&prestate_hash).await);

        handler.remove_snapshot(prestate_hash).await;

        assert!(!handler.has_snapshot(&prestate_hash).await);
    }

    #[test]
    fn test_snapshot_store_evicts_oldest_at_capacity() {
        let mut store = SnapshotStore::new(2);
        let facts = QueryFacts::default();
        let first = Hash32([1u8; 32]);
        let second = Hash32([2u8; 32]);
        let third = Hash32([3u8; 32]);

        store.create_snapshot(first, &facts);
        store.create_snapshot(second, &facts);
        store.create_snapshot(third, &facts);

        assert!(!store.has_snapshot(&first));
        assert!(store.has_snapshot(&second));
        assert!(store.has_snapshot(&third));
    }

    #[tokio::test]
    async fn test_pending_consensus_registration() {
        let handler = QueryHandler::default();

        let consensus_id = ConsensusId::new([1u8; 32]);
        let scope = ResourceScope::Authority {
            authority_id: aura_core::AuthorityId::new_from_entropy([1u8; 32]),
            operation: aura_core::AuthorityOp::UpdateTree,
        };

        handler
            .register_pending_consensus(scope.clone(), consensus_id)
            .await;

        {
            let pending = handler.pending_consensus.read().await;
            let ids = pending.pending_for_scope(&scope);
            assert_eq!(ids.len(), 1);
            assert_eq!(ids[0], consensus_id);
        }

        handler.mark_consensus_completed(consensus_id).await;

        {
            let pending = handler.pending_consensus.read().await;
            let ids = pending.pending_for_scope(&scope);
            assert!(ids.is_empty());
        }
    }

    #[tokio::test]
    async fn test_wait_for_consensus_already_completed() {
        let handler = QueryHandler::default();

        let consensus_id = ConsensusId::new([1u8; 32]);

        handler.mark_consensus_completed(consensus_id).await;

        let result = handler.wait_for_consensus(&[consensus_id]).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_wait_for_consensus_empty_list() {
        let handler = QueryHandler::default();

        let result = handler.wait_for_consensus(&[]).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_wait_for_consensus_times_out() {
        let handler = QueryHandler::default().with_consensus_timeout(Duration::from_millis(10));
        let consensus_id = ConsensusId::new([7u8; 32]);

        let result = handler.wait_for_consensus(&[consensus_id]).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_wait_for_consensus_wakes_on_completion() {
        let handler = QueryHandler::default();
        let consensus_id = ConsensusId::new([9u8; 32]);
        let ids = [consensus_id];

        let waiter = handler.wait_for_consensus(&ids);
        let completer = async {
            tokio::time::sleep(Duration::from_millis(5)).await;
            handler.mark_consensus_completed(consensus_id).await;
        };

        let (result, ()) = tokio::join!(waiter, completer);
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_consensus_id_conversion_roundtrip() {
        let original_bytes = [42u8; 32];
        let core_id = aura_core::query::ConsensusId::new(original_bytes);

        let local_id = ConsensusTracker::from_core_id(&core_id);
        let back_to_core = ConsensusTracker::to_core_id(&local_id);

        assert_eq!(core_id, back_to_core);
        assert_eq!(local_id.as_bytes(), &original_bytes);
    }
}