1use std::sync::Arc;
56
57use async_trait::async_trait;
58use serde_json::json;
59use sha2::{Digest, Sha256};
60
61use super::{
62 fact_table_version::{
63 FactTableCacheConfig, FactTableVersionProvider, FactTableVersionStrategy,
64 generate_version_key_component,
65 },
66 key::generate_cache_key,
67 result::QueryResultCache,
68};
69use crate::{
70 db::{DatabaseAdapter, DatabaseType, PoolMetrics, WhereClause, types::JsonbValue},
71 error::Result,
72};
73
/// A [`DatabaseAdapter`] wrapper that places a [`QueryResultCache`] in front of
/// an inner adapter.
///
/// Query methods consult the cache first and fall back to the wrapped adapter;
/// the remaining adapter methods are forwarded unchanged.
pub struct CachedDatabaseAdapter<A: DatabaseAdapter> {
    /// The wrapped adapter that executes queries on cache misses and bypasses.
    adapter: A,

    /// Shared result cache consulted before delegating to `adapter`.
    cache: Arc<QueryResultCache>,

    /// Schema version string mixed into every cache key this adapter generates.
    schema_version: String,

    /// Source of the caching strategy used for each fact table.
    fact_table_config: FactTableCacheConfig,

    /// Remembers fact table versions that were already looked up.
    version_provider: Arc<FactTableVersionProvider>,
}
136
137impl<A: DatabaseAdapter> CachedDatabaseAdapter<A> {
138 #[must_use]
164 pub fn new(adapter: A, cache: QueryResultCache, schema_version: String) -> Self {
165 Self {
166 adapter,
167 cache: Arc::new(cache),
168 schema_version,
169 fact_table_config: FactTableCacheConfig::default(),
170 version_provider: Arc::new(FactTableVersionProvider::default()),
171 }
172 }
173
174 #[must_use]
211 pub fn with_fact_table_config(
212 adapter: A,
213 cache: QueryResultCache,
214 schema_version: String,
215 fact_table_config: FactTableCacheConfig,
216 ) -> Self {
217 Self {
218 adapter,
219 cache: Arc::new(cache),
220 schema_version,
221 fact_table_config,
222 version_provider: Arc::new(FactTableVersionProvider::default()),
223 }
224 }
225
/// Invalidates cached entries associated with the given views.
///
/// Delegates directly to the underlying cache and returns the `u64` it
/// reports (the tests in this file treat it as the number of entries removed).
///
/// # Errors
///
/// Propagates any error returned by the cache.
pub fn invalidate_views(&self, views: &[String]) -> Result<u64> {
    self.cache.invalidate_views(views)
}
258
259 pub fn invalidate_cascade_entities(
308 &self,
309 cascade_response: &serde_json::Value,
310 parser: &super::cascade_response_parser::CascadeResponseParser,
311 ) -> Result<u64> {
312 let cascade_entities = parser.parse_cascade_response(cascade_response)?;
314
315 if !cascade_entities.has_changes() {
316 return Ok(0);
318 }
319
320 let mut views_to_invalidate = std::collections::HashSet::new();
325 for entity in cascade_entities.all_affected() {
326 let view_name = format!("v_{}", entity.entity_type.to_lowercase());
328 views_to_invalidate.insert(view_name);
329 }
330
331 let views: Vec<String> = views_to_invalidate.into_iter().collect();
333 self.cache.invalidate_views(&views)
334 }
335
/// Returns a reference to the wrapped adapter.
#[must_use]
pub const fn inner(&self) -> &A {
    &self.adapter
}
354
/// Returns a reference to the query result cache used by this adapter.
#[must_use]
pub fn cache(&self) -> &QueryResultCache {
    &self.cache
}
374
/// Returns the schema version string this adapter mixes into its cache keys.
#[must_use]
pub fn schema_version(&self) -> &str {
    &self.schema_version
}
390
/// Returns the fact-table caching configuration in use.
#[must_use]
pub fn fact_table_config(&self) -> &FactTableCacheConfig {
    &self.fact_table_config
}
396
/// Returns the provider that holds cached fact table versions.
#[must_use]
pub fn version_provider(&self) -> &FactTableVersionProvider {
    &self.version_provider
}
402
/// Extracts the fact table name (`tf_*`) that directly follows the first
/// `FROM` keyword of `sql`.
///
/// Matching is case-insensitive and the returned name is lowercased.
/// `FROM` is only recognised as a standalone keyword: it must not be the tail
/// of an identifier (e.g. `valid_from`) and may be followed by any whitespace,
/// including newlines and tabs. The table name ends at whitespace, `,`, `)`
/// or `;`.
///
/// Returns `None` when there is no `FROM` keyword or when the first one is not
/// followed by a `tf_`-prefixed name (subqueries, regular tables, ...).
fn extract_fact_table_from_sql(sql: &str) -> Option<String> {
    const KEYWORD: &str = "from";

    // All offsets below are computed on, and applied to, this lowercased copy.
    let sql_lower = sql.to_lowercase();
    let bytes = sql_lower.as_bytes();
    let is_ident_byte = |b: u8| b.is_ascii_alphanumeric() || b == b'_';

    // Find the first occurrence of `from` that is a keyword on its own.
    let mut search_from = 0;
    let after_from = loop {
        let start = search_from + sql_lower[search_from..].find(KEYWORD)?;
        let end = start + KEYWORD.len();

        let standalone = start == 0 || !is_ident_byte(bytes[start - 1]);
        let followed_by_space = sql_lower[end..].starts_with(char::is_whitespace);
        if standalone && followed_by_space {
            break &sql_lower[end..];
        }

        // `end` sits right after ASCII text, so it is always a char boundary.
        search_from = end;
    };

    let trimmed = after_from.trim_start();
    if !trimmed.starts_with("tf_") {
        return None;
    }

    // `;` is a terminator too, so `FROM tf_sales;` does not yield `tf_sales;`.
    let end_idx = trimmed
        .find(|c: char| c.is_whitespace() || matches!(c, ',' | ')' | ';'))
        .unwrap_or(trimmed.len());

    Some(trimmed[..end_idx].to_string())
}
427
428 fn generate_aggregation_cache_key(
432 sql: &str,
433 schema_version: &str,
434 version_component: Option<&str>,
435 ) -> String {
436 let mut hasher = Sha256::new();
437 hasher.update(sql.as_bytes());
438 hasher.update(schema_version.as_bytes());
439 if let Some(vc) = version_component {
440 hasher.update(vc.as_bytes());
441 }
442 let result = hasher.finalize();
443 format!("agg:{:x}", result)
444 }
445
446 async fn fetch_table_version(&self, table_name: &str) -> Option<i64> {
450 if let Some(version) = self.version_provider.get_cached_version(table_name) {
452 return Some(version);
453 }
454
455 let sql = format!(
457 "SELECT version FROM tf_versions WHERE table_name = '{}'",
458 table_name.replace('\'', "''") );
460
461 match self.adapter.execute_raw_query(&sql).await {
462 Ok(rows) if !rows.is_empty() => {
463 if let Some(serde_json::Value::Number(n)) = rows[0].get("version") {
464 if let Some(v) = n.as_i64() {
465 self.version_provider.set_cached_version(table_name, v);
466 return Some(v);
467 }
468 }
469 None
470 },
471 _ => None,
472 }
473 }
474
475 pub async fn execute_aggregation_query(
505 &self,
506 sql: &str,
507 ) -> Result<Vec<std::collections::HashMap<String, serde_json::Value>>> {
508 let Some(table_name) = Self::extract_fact_table_from_sql(sql) else {
510 return self.adapter.execute_raw_query(sql).await;
512 };
513
514 let strategy = self.fact_table_config.get_strategy(&table_name);
516
517 if !strategy.is_caching_enabled() {
519 return self.adapter.execute_raw_query(sql).await;
520 }
521
522 let table_version = if matches!(strategy, FactTableVersionStrategy::VersionTable) {
524 self.fetch_table_version(&table_name).await
525 } else {
526 None
527 };
528
529 let version_component = generate_version_key_component(
530 &table_name,
531 strategy,
532 table_version,
533 &self.schema_version,
534 );
535
536 let Some(version_component) = version_component else {
538 return self.adapter.execute_raw_query(sql).await;
540 };
541
542 let cache_key = Self::generate_aggregation_cache_key(
544 sql,
545 &self.schema_version,
546 Some(&version_component),
547 );
548
549 if let Some(cached_result) = self.cache.get(&cache_key)? {
551 let results: Vec<std::collections::HashMap<String, serde_json::Value>> = cached_result
553 .iter()
554 .filter_map(|jv| serde_json::from_value(jv.as_value().clone()).ok())
555 .collect();
556 return Ok(results);
557 }
558
559 let result = self.adapter.execute_raw_query(sql).await?;
561
562 let cached_values: Vec<JsonbValue> = result
564 .iter()
565 .filter_map(|row| serde_json::to_value(row).ok().map(JsonbValue::new))
566 .collect();
567
568 self.cache.put(
569 cache_key,
570 cached_values,
571 vec![table_name], )?;
573
574 Ok(result)
575 }
576}
577
578#[async_trait]
579impl<A: DatabaseAdapter> DatabaseAdapter for CachedDatabaseAdapter<A> {
580 async fn execute_with_projection(
581 &self,
582 view: &str,
583 projection: Option<&crate::schema::SqlProjectionHint>,
584 where_clause: Option<&WhereClause>,
585 limit: Option<u32>,
586 ) -> Result<Vec<JsonbValue>> {
587 if !self.cache.is_enabled() {
589 return self
590 .adapter
591 .execute_with_projection(view, projection, where_clause, limit)
592 .await;
593 }
594
595 let query_string = format!("query {{ {view} }}");
597 let projection_info = projection.map(|p| &p.projection_template[..]).unwrap_or("");
598 let variables = json!({
599 "limit": limit,
600 "projection": projection_info,
601 });
602
603 let cache_key =
604 generate_cache_key(&query_string, &variables, where_clause, &self.schema_version);
605
606 if let Some(cached_result) = self.cache.get(&cache_key)? {
608 return Ok((*cached_result).clone());
609 }
610
611 let result = self
613 .adapter
614 .execute_with_projection(view, projection, where_clause, limit)
615 .await?;
616
617 self.cache.put(cache_key, result.clone(), vec![view.to_string()])?;
619
620 Ok(result)
621 }
622
623 async fn execute_where_query(
624 &self,
625 view: &str,
626 where_clause: Option<&WhereClause>,
627 limit: Option<u32>,
628 offset: Option<u32>,
629 ) -> Result<Vec<JsonbValue>> {
630 if !self.cache.is_enabled() {
632 return self.adapter.execute_where_query(view, where_clause, limit, offset).await;
633 }
634
635 let is_simple_query = where_clause.is_none() && limit.is_none_or(|l| l <= 1000);
646
647 if is_simple_query {
648 return self.adapter.execute_where_query(view, where_clause, limit, offset).await;
651 }
652
653 let query_string = format!("query {{ {view} }}");
655 let variables = json!({
656 "limit": limit,
657 "offset": offset,
658 });
659
660 let cache_key =
661 generate_cache_key(&query_string, &variables, where_clause, &self.schema_version);
662
663 if let Some(cached_result) = self.cache.get(&cache_key)? {
665 return Ok((*cached_result).clone());
667 }
668
669 let result = self.adapter.execute_where_query(view, where_clause, limit, offset).await?;
671
672 self.cache.put(
676 cache_key,
677 result.clone(),
678 vec![view.to_string()], )?;
680
681 Ok(result)
682 }
683
/// Reports the database type of the wrapped adapter.
fn database_type(&self) -> DatabaseType {
    self.adapter.database_type()
}
687
/// Forwards the health check to the wrapped adapter; the cache is not involved.
async fn health_check(&self) -> Result<()> {
    self.adapter.health_check().await
}
691
/// Returns the wrapped adapter's connection pool metrics.
fn pool_metrics(&self) -> PoolMetrics {
    self.adapter.pool_metrics()
}
695
/// Executes raw SQL by routing it through
/// [`CachedDatabaseAdapter::execute_aggregation_query`], so that queries on
/// fact tables can be cached.
async fn execute_raw_query(
    &self,
    sql: &str,
) -> Result<Vec<std::collections::HashMap<String, serde_json::Value>>> {
    self.execute_aggregation_query(sql).await
}
703
/// Forwards a function call to the wrapped adapter without consulting the cache.
async fn execute_function_call(
    &self,
    function_name: &str,
    args: &[serde_json::Value],
) -> Result<Vec<std::collections::HashMap<String, serde_json::Value>>> {
    self.adapter.execute_function_call(function_name, args).await
}
712}
713
714#[cfg(test)]
715mod tests {
716 use super::*;
717 use crate::{cache::CacheConfig, db::WhereOperator};
718
/// Test double that counts how often it is asked to execute a query.
struct MockAdapter {
    /// Number of view queries (`execute_with_projection` / `execute_where_query`).
    call_count: std::sync::atomic::AtomicU32,
    /// Number of raw SQL queries (`execute_raw_query`).
    raw_call_count: std::sync::atomic::AtomicU32,
}

impl MockAdapter {
    /// Creates a mock with both counters at zero.
    fn new() -> Self {
        use std::sync::atomic::AtomicU32;

        Self {
            call_count: AtomicU32::new(0),
            raw_call_count: AtomicU32::new(0),
        }
    }

    /// Total number of queries of either kind executed so far.
    fn call_count(&self) -> u32 {
        use std::sync::atomic::Ordering::SeqCst;

        [&self.call_count, &self.raw_call_count]
            .into_iter()
            .map(|counter| counter.load(SeqCst))
            .sum()
    }
}
741
742 #[async_trait]
743 impl DatabaseAdapter for MockAdapter {
744 async fn execute_with_projection(
745 &self,
746 _view: &str,
747 _projection: Option<&crate::schema::SqlProjectionHint>,
748 _where_clause: Option<&WhereClause>,
749 _limit: Option<u32>,
750 ) -> Result<Vec<JsonbValue>> {
751 self.call_count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
752
753 Ok(vec![
755 JsonbValue::new(json!({"id": 1, "name": "Alice"})),
756 JsonbValue::new(json!({"id": 2, "name": "Bob"})),
757 ])
758 }
759
760 async fn execute_where_query(
761 &self,
762 _view: &str,
763 _where_clause: Option<&WhereClause>,
764 _limit: Option<u32>,
765 _offset: Option<u32>,
766 ) -> Result<Vec<JsonbValue>> {
767 self.call_count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
768
769 Ok(vec![
771 JsonbValue::new(json!({"id": 1, "name": "Alice"})),
772 JsonbValue::new(json!({"id": 2, "name": "Bob"})),
773 ])
774 }
775
776 fn database_type(&self) -> DatabaseType {
777 DatabaseType::PostgreSQL
778 }
779
780 async fn health_check(&self) -> Result<()> {
781 Ok(())
782 }
783
784 fn pool_metrics(&self) -> PoolMetrics {
785 PoolMetrics {
786 total_connections: 10,
787 idle_connections: 5,
788 active_connections: 3,
789 waiting_requests: 0,
790 }
791 }
792
793 async fn execute_raw_query(
794 &self,
795 _sql: &str,
796 ) -> Result<Vec<std::collections::HashMap<String, serde_json::Value>>> {
797 self.raw_call_count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
798 let mut row = std::collections::HashMap::new();
800 row.insert("count".to_string(), json!(42));
801 Ok(vec![row])
802 }
803
804 async fn execute_function_call(
805 &self,
806 _function_name: &str,
807 _args: &[serde_json::Value],
808 ) -> Result<Vec<std::collections::HashMap<String, serde_json::Value>>> {
809 Ok(vec![])
810 }
811 }
812
/// The first filtered query reaches the adapter; an identical repeat is
/// answered from the cache.
#[tokio::test]
async fn test_cache_miss_then_hit() {
    let mock = MockAdapter::new();
    let cache = QueryResultCache::new(CacheConfig::enabled());
    let adapter = CachedDatabaseAdapter::new(mock, cache, "1.0.0".to_string());

    // A WHERE clause is needed so the query is not bypassed as "simple".
    let where_clause = WhereClause::Field {
        path: vec!["active".to_string()],
        operator: WhereOperator::Eq,
        value: json!(true),
    };

    // Miss: the mock is called once.
    let result1 = adapter
        .execute_where_query("v_user", Some(&where_clause), None, None)
        .await
        .unwrap();
    assert_eq!(result1.len(), 2);
    assert_eq!(adapter.inner().call_count(), 1);

    // Hit: same number of rows, no additional mock call.
    let result2 = adapter
        .execute_where_query("v_user", Some(&where_clause), None, None)
        .await
        .unwrap();
    assert_eq!(result2.len(), 2);
    assert_eq!(adapter.inner().call_count(), 1);
}
842
/// Two queries that differ only in the WHERE value must not share a cache entry.
#[tokio::test]
async fn test_different_where_clauses_produce_different_cache_entries() {
    let mock = MockAdapter::new();
    let cache = QueryResultCache::new(CacheConfig::enabled());
    let adapter = CachedDatabaseAdapter::new(mock, cache, "1.0.0".to_string());

    let where1 = WhereClause::Field {
        path: vec!["id".to_string()],
        operator: WhereOperator::Eq,
        value: json!(1),
    };

    let where2 = WhereClause::Field {
        path: vec!["id".to_string()],
        operator: WhereOperator::Eq,
        value: json!(2),
    };

    // First clause: miss.
    adapter.execute_where_query("v_user", Some(&where1), None, None).await.unwrap();
    assert_eq!(adapter.inner().call_count(), 1);

    // Second clause: also a miss, so the adapter is called again.
    adapter.execute_where_query("v_user", Some(&where2), None, None).await.unwrap();
    assert_eq!(adapter.inner().call_count(), 2);
}
869
/// Invalidating a view removes its cached entry, so the next query hits the
/// adapter again.
#[tokio::test]
async fn test_invalidation_clears_cache() {
    let mock = MockAdapter::new();
    let cache = QueryResultCache::new(CacheConfig::enabled());
    let adapter = CachedDatabaseAdapter::new(mock, cache, "1.0.0".to_string());

    let where_clause = WhereClause::Field {
        path: vec!["status".to_string()],
        operator: WhereOperator::Eq,
        value: json!("active"),
    };

    // Populate the cache (miss).
    adapter
        .execute_where_query("v_user", Some(&where_clause), None, None)
        .await
        .unwrap();
    assert_eq!(adapter.inner().call_count(), 1);

    // Confirm the entry is served from the cache (hit).
    adapter
        .execute_where_query("v_user", Some(&where_clause), None, None)
        .await
        .unwrap();
    assert_eq!(adapter.inner().call_count(), 1);

    // Exactly one entry is reported as invalidated.
    let invalidated = adapter.invalidate_views(&["v_user".to_string()]).unwrap();
    assert_eq!(invalidated, 1);

    // The same query is a miss again.
    adapter
        .execute_where_query("v_user", Some(&where_clause), None, None)
        .await
        .unwrap();
    assert_eq!(adapter.inner().call_count(), 2);
}
908
909 #[tokio::test]
910 async fn test_different_limits_produce_different_cache_entries() {
911 let mock = MockAdapter::new();
912 let cache = QueryResultCache::new(CacheConfig::enabled());
913 let adapter = CachedDatabaseAdapter::new(mock, cache, "1.0.0".to_string());
914
915 adapter.execute_where_query("v_user", None, Some(10), None).await.unwrap();
917 assert_eq!(adapter.inner().call_count(), 1);
918
919 adapter.execute_where_query("v_user", None, Some(20), None).await.unwrap();
921 assert_eq!(adapter.inner().call_count(), 2);
922 }
923
/// With a disabled cache every query is forwarded to the adapter.
#[tokio::test]
async fn test_cache_disabled() {
    let mock = MockAdapter::new();
    let cache = QueryResultCache::new(CacheConfig::disabled());
    let adapter = CachedDatabaseAdapter::new(mock, cache, "1.0.0".to_string());

    let where_clause = WhereClause::Field {
        path: vec!["status".to_string()],
        operator: WhereOperator::Eq,
        value: json!("active"),
    };

    // First call reaches the adapter.
    adapter
        .execute_where_query("v_user", Some(&where_clause), None, None)
        .await
        .unwrap();
    assert_eq!(adapter.inner().call_count(), 1);

    // So does the identical second call: nothing was cached.
    adapter
        .execute_where_query("v_user", Some(&where_clause), None, None)
        .await
        .unwrap();
    assert_eq!(adapter.inner().call_count(), 2);
}
951
/// Queries without a WHERE clause and with no limit or a limit of at most
/// 1000 skip the cache; filtered queries and larger limits are cached.
#[tokio::test]
async fn test_simple_queries_bypass_cache() {
    let mock = MockAdapter::new();
    let cache = QueryResultCache::new(CacheConfig::enabled());
    let adapter = CachedDatabaseAdapter::new(mock, cache, "1.0.0".to_string());

    // No WHERE, no limit: bypassed, so every call reaches the adapter.
    adapter.execute_where_query("v_user", None, None, None).await.unwrap();
    assert_eq!(adapter.inner().call_count(), 1);

    adapter.execute_where_query("v_user", None, None, None).await.unwrap();
    assert_eq!(adapter.inner().call_count(), 2);

    // Limit exactly at the 1000 threshold: still bypassed.
    adapter.execute_where_query("v_user", None, Some(1000), None).await.unwrap();
    assert_eq!(adapter.inner().call_count(), 3);

    adapter.execute_where_query("v_user", None, Some(1000), None).await.unwrap();
    assert_eq!(adapter.inner().call_count(), 4);

    // With a WHERE clause the query is cached: miss, then hit.
    let where_clause = WhereClause::Field {
        path: vec!["id".to_string()],
        operator: WhereOperator::Eq,
        value: json!(1),
    };
    adapter
        .execute_where_query("v_user", Some(&where_clause), None, None)
        .await
        .unwrap();
    assert_eq!(adapter.inner().call_count(), 5);

    adapter
        .execute_where_query("v_user", Some(&where_clause), None, None)
        .await
        .unwrap();
    assert_eq!(adapter.inner().call_count(), 5);

    // Limit above the threshold is cached even without a WHERE: miss, then hit.
    adapter.execute_where_query("v_user", None, Some(1001), None).await.unwrap();
    assert_eq!(adapter.inner().call_count(), 6);

    adapter.execute_where_query("v_user", None, Some(1001), None).await.unwrap();
    assert_eq!(adapter.inner().call_count(), 6);
}
1006
1007 #[tokio::test]
1008 async fn test_schema_version_change_invalidates_cache() {
1009 let cache = Arc::new(QueryResultCache::new(CacheConfig::enabled()));
1010 let version_provider = Arc::new(FactTableVersionProvider::default());
1011
1012 let mock1 = MockAdapter::new();
1014 let adapter_v1 = CachedDatabaseAdapter {
1015 adapter: mock1,
1016 cache: Arc::clone(&cache),
1017 schema_version: "1.0.0".to_string(),
1018 fact_table_config: FactTableCacheConfig::default(),
1019 version_provider: Arc::clone(&version_provider),
1020 };
1021
1022 adapter_v1.execute_where_query("v_user", None, None, None).await.unwrap();
1024
1025 let mock2 = MockAdapter::new();
1027 let adapter_v2 = CachedDatabaseAdapter {
1028 adapter: mock2,
1029 cache: Arc::clone(&cache),
1030 schema_version: "2.0.0".to_string(),
1031 fact_table_config: FactTableCacheConfig::default(),
1032 version_provider: Arc::clone(&version_provider),
1033 };
1034
1035 adapter_v2.execute_where_query("v_user", None, None, None).await.unwrap();
1037 assert_eq!(adapter_v2.inner().call_count(), 1); }
1039
/// `database_type` is passed through from the wrapped adapter.
#[tokio::test]
async fn test_forwards_database_type() {
    let mock = MockAdapter::new();
    let cache = QueryResultCache::new(CacheConfig::enabled());
    let adapter = CachedDatabaseAdapter::new(mock, cache, "1.0.0".to_string());

    assert_eq!(adapter.database_type(), DatabaseType::PostgreSQL);
}
1048
/// `health_check` is passed through to the wrapped adapter and succeeds.
#[tokio::test]
async fn test_forwards_health_check() {
    let mock = MockAdapter::new();
    let cache = QueryResultCache::new(CacheConfig::enabled());
    let adapter = CachedDatabaseAdapter::new(mock, cache, "1.0.0".to_string());

    adapter.health_check().await.unwrap();
}
1057
/// `pool_metrics` returns the values reported by the wrapped adapter.
#[tokio::test]
async fn test_forwards_pool_metrics() {
    let mock = MockAdapter::new();
    let cache = QueryResultCache::new(CacheConfig::enabled());
    let adapter = CachedDatabaseAdapter::new(mock, cache, "1.0.0".to_string());

    // These are the fixed numbers the mock hands out.
    let metrics = adapter.pool_metrics();
    assert_eq!(metrics.total_connections, 10);
    assert_eq!(metrics.idle_connections, 5);
}
1068
/// The accessors expose the wrapped adapter, the cache and the schema version.
#[tokio::test]
async fn test_inner_and_cache_accessors() {
    let mock = MockAdapter::new();
    let cache = QueryResultCache::new(CacheConfig::enabled());
    let adapter = CachedDatabaseAdapter::new(mock, cache, "1.0.0".to_string());

    // A fresh mock has not been called yet.
    assert_eq!(adapter.inner().call_count(), 0);

    // A fresh cache has recorded no hits.
    let cache_metrics = adapter.cache().metrics().unwrap();
    assert_eq!(cache_metrics.hits, 0);

    assert_eq!(adapter.schema_version(), "1.0.0");
}
1085
1086 use super::super::cascade_response_parser::CascadeResponseParser;
1089
/// A cascade that updates one `User` invalidates the cached `v_user` entry.
#[tokio::test]
async fn test_invalidate_cascade_entities_with_single_entity() {
    let mock = MockAdapter::new();
    let cache = QueryResultCache::new(CacheConfig::enabled());
    let adapter = CachedDatabaseAdapter::new(mock, cache, "1.0.0".to_string());

    let where_clause = WhereClause::Field {
        path: vec!["status".to_string()],
        operator: WhereOperator::Eq,
        value: json!("active"),
    };

    // Populate the cache (miss).
    adapter
        .execute_where_query("v_user", Some(&where_clause), None, None)
        .await
        .unwrap();
    assert_eq!(adapter.inner().call_count(), 1);

    // Confirm the entry is cached (hit).
    adapter
        .execute_where_query("v_user", Some(&where_clause), None, None)
        .await
        .unwrap();
    assert_eq!(adapter.inner().call_count(), 1);

    // Mutation response whose cascade lists a single updated User.
    let cascade_response = json!({
        "createPost": {
            "cascade": {
                "updated": [
                    {
                        "__typename": "User",
                        "id": "550e8400-e29b-41d4-a716-446655440000"
                    }
                ],
                "deleted": []
            }
        }
    });

    let parser = CascadeResponseParser::new();
    let invalidated = adapter.invalidate_cascade_entities(&cascade_response, &parser).unwrap();

    assert_eq!(invalidated, 1);

    // The v_user query is a miss again.
    adapter
        .execute_where_query("v_user", Some(&where_clause), None, None)
        .await
        .unwrap();
    assert_eq!(adapter.inner().call_count(), 2);
}
1145
1146 #[tokio::test]
1147 async fn test_invalidate_cascade_entities_with_multiple_entities() {
1148 let mock = MockAdapter::new();
1149 let cache = QueryResultCache::new(CacheConfig::enabled());
1150 let adapter = CachedDatabaseAdapter::new(mock, cache, "1.0.0".to_string());
1151
1152 let where_clause = WhereClause::Field {
1154 path: vec!["status".to_string()],
1155 operator: WhereOperator::Eq,
1156 value: json!("active"),
1157 };
1158
1159 adapter
1161 .execute_where_query("v_user", Some(&where_clause), None, None)
1162 .await
1163 .unwrap();
1164 adapter
1165 .execute_where_query("v_post", Some(&where_clause), None, None)
1166 .await
1167 .unwrap();
1168 adapter
1169 .execute_where_query("v_comment", Some(&where_clause), None, None)
1170 .await
1171 .unwrap();
1172 assert_eq!(adapter.inner().call_count(), 3);
1173
1174 let cascade_response = json!({
1176 "updateUser": {
1177 "cascade": {
1178 "updated": [
1179 {"__typename": "User", "id": "u-1"},
1180 {"__typename": "Post", "id": "p-1"},
1181 {"__typename": "Comment", "id": "c-1"}
1182 ],
1183 "deleted": []
1184 }
1185 }
1186 });
1187
1188 let parser = CascadeResponseParser::new();
1189 let invalidated = adapter.invalidate_cascade_entities(&cascade_response, &parser).unwrap();
1190
1191 assert_eq!(invalidated, 3);
1193
1194 adapter
1196 .execute_where_query("v_user", Some(&where_clause), None, None)
1197 .await
1198 .unwrap();
1199 adapter
1200 .execute_where_query("v_post", Some(&where_clause), None, None)
1201 .await
1202 .unwrap();
1203 adapter
1204 .execute_where_query("v_comment", Some(&where_clause), None, None)
1205 .await
1206 .unwrap();
1207 assert_eq!(adapter.inner().call_count(), 6);
1209 }
1210
1211 #[tokio::test]
1212 async fn test_invalidate_cascade_entities_with_deleted_entities() {
1213 let mock = MockAdapter::new();
1214 let cache = QueryResultCache::new(CacheConfig::enabled());
1215 let adapter = CachedDatabaseAdapter::new(mock, cache, "1.0.0".to_string());
1216
1217 let where_clause = WhereClause::Field {
1218 path: vec!["status".to_string()],
1219 operator: WhereOperator::Eq,
1220 value: json!("active"),
1221 };
1222
1223 adapter
1225 .execute_where_query("v_post", Some(&where_clause), None, None)
1226 .await
1227 .unwrap();
1228 adapter
1229 .execute_where_query("v_comment", Some(&where_clause), None, None)
1230 .await
1231 .unwrap();
1232 assert_eq!(adapter.inner().call_count(), 2);
1233
1234 let cascade_response = json!({
1236 "deletePost": {
1237 "cascade": {
1238 "updated": [],
1239 "deleted": [
1240 {"__typename": "Post", "id": "p-123"},
1241 {"__typename": "Comment", "id": "c-456"}
1242 ]
1243 }
1244 }
1245 });
1246
1247 let parser = CascadeResponseParser::new();
1248 let invalidated = adapter.invalidate_cascade_entities(&cascade_response, &parser).unwrap();
1249
1250 assert_eq!(invalidated, 2);
1252
1253 adapter
1255 .execute_where_query("v_post", Some(&where_clause), None, None)
1256 .await
1257 .unwrap();
1258 adapter
1259 .execute_where_query("v_comment", Some(&where_clause), None, None)
1260 .await
1261 .unwrap();
1262 assert_eq!(adapter.inner().call_count(), 4);
1263 }
1264
/// A mutation response without a `cascade` field invalidates nothing.
#[tokio::test]
async fn test_invalidate_cascade_entities_with_no_cascade_field() {
    let mock = MockAdapter::new();
    let cache = QueryResultCache::new(CacheConfig::enabled());
    let adapter = CachedDatabaseAdapter::new(mock, cache, "1.0.0".to_string());

    let where_clause = WhereClause::Field {
        path: vec!["status".to_string()],
        operator: WhereOperator::Eq,
        value: json!("active"),
    };

    // Populate the cache (miss).
    adapter
        .execute_where_query("v_user", Some(&where_clause), None, None)
        .await
        .unwrap();
    assert_eq!(adapter.inner().call_count(), 1);

    // The response only carries the mutation payload, no cascade section.
    let cascade_response = json!({
        "createPost": {
            "post": {
                "id": "p-123",
                "title": "Hello"
            }
        }
    });

    let parser = CascadeResponseParser::new();
    let invalidated = adapter.invalidate_cascade_entities(&cascade_response, &parser).unwrap();

    assert_eq!(invalidated, 0);

    // The cached entry survives: still a hit.
    adapter
        .execute_where_query("v_user", Some(&where_clause), None, None)
        .await
        .unwrap();
    assert_eq!(adapter.inner().call_count(), 1);
}
1308
1309 #[tokio::test]
1310 async fn test_invalidate_cascade_entities_mixed_updated_and_deleted() {
1311 let mock = MockAdapter::new();
1312 let cache = QueryResultCache::new(CacheConfig::enabled());
1313 let adapter = CachedDatabaseAdapter::new(mock, cache, "1.0.0".to_string());
1314
1315 let where_clause = WhereClause::Field {
1317 path: vec!["status".to_string()],
1318 operator: WhereOperator::Eq,
1319 value: json!("active"),
1320 };
1321
1322 adapter
1324 .execute_where_query("v_user", Some(&where_clause), None, None)
1325 .await
1326 .unwrap();
1327 adapter
1328 .execute_where_query("v_post", Some(&where_clause), None, None)
1329 .await
1330 .unwrap();
1331 assert_eq!(adapter.inner().call_count(), 2);
1332
1333 let cascade_response = json!({
1335 "mutation": {
1336 "cascade": {
1337 "updated": [
1338 {"__typename": "User", "id": "u-1"}
1339 ],
1340 "deleted": [
1341 {"__typename": "Post", "id": "p-1"}
1342 ]
1343 }
1344 }
1345 });
1346
1347 let parser = CascadeResponseParser::new();
1348 let invalidated = adapter.invalidate_cascade_entities(&cascade_response, &parser).unwrap();
1349
1350 assert_eq!(invalidated, 2);
1352
1353 adapter
1355 .execute_where_query("v_user", Some(&where_clause), None, None)
1356 .await
1357 .unwrap();
1358 adapter
1359 .execute_where_query("v_post", Some(&where_clause), None, None)
1360 .await
1361 .unwrap();
1362 assert_eq!(adapter.inner().call_count(), 4);
1363 }
1364
/// Several entities of the same type map to one view, which is invalidated once.
#[tokio::test]
async fn test_cascade_invalidation_deduplicates_entity_types() {
    let mock = MockAdapter::new();
    let cache = QueryResultCache::new(CacheConfig::enabled());
    let adapter = CachedDatabaseAdapter::new(mock, cache, "1.0.0".to_string());

    let where_clause = WhereClause::Field {
        path: vec!["status".to_string()],
        operator: WhereOperator::Eq,
        value: json!("active"),
    };

    // Populate a single v_user cache entry.
    adapter
        .execute_where_query("v_user", Some(&where_clause), None, None)
        .await
        .unwrap();
    assert_eq!(adapter.inner().call_count(), 1);

    // Three updated users, all of the same entity type.
    let cascade_response = json!({
        "mutation": {
            "cascade": {
                "updated": [
                    {"__typename": "User", "id": "u-1"},
                    {"__typename": "User", "id": "u-2"},
                    {"__typename": "User", "id": "u-3"}
                ],
                "deleted": []
            }
        }
    });

    let parser = CascadeResponseParser::new();
    let invalidated = adapter.invalidate_cascade_entities(&cascade_response, &parser).unwrap();

    // Only the one cached entry is reported, not one per entity.
    assert_eq!(invalidated, 1);
}
1407
1408 #[tokio::test]
1409 async fn test_cascade_invalidation_vs_view_invalidation_same_result() {
1410 let where_clause = WhereClause::Field {
1411 path: vec!["status".to_string()],
1412 operator: WhereOperator::Eq,
1413 value: json!("active"),
1414 };
1415
1416 let mock1 = MockAdapter::new();
1418 let cache1 = QueryResultCache::new(CacheConfig::enabled());
1419 let adapter1 = CachedDatabaseAdapter::new(mock1, cache1, "1.0.0".to_string());
1420
1421 adapter1
1422 .execute_where_query("v_user", Some(&where_clause), None, None)
1423 .await
1424 .unwrap();
1425 adapter1
1426 .execute_where_query("v_post", Some(&where_clause), None, None)
1427 .await
1428 .unwrap();
1429
1430 let cascade_response = json!({
1431 "mutation": {
1432 "cascade": {
1433 "updated": [
1434 {"__typename": "User", "id": "u-1"},
1435 {"__typename": "Post", "id": "p-1"}
1436 ],
1437 "deleted": []
1438 }
1439 }
1440 });
1441
1442 let parser = CascadeResponseParser::new();
1443 let invalidated_cascade =
1444 adapter1.invalidate_cascade_entities(&cascade_response, &parser).unwrap();
1445
1446 let mock2 = MockAdapter::new();
1448 let cache2 = QueryResultCache::new(CacheConfig::enabled());
1449 let adapter2 = CachedDatabaseAdapter::new(mock2, cache2, "1.0.0".to_string());
1450
1451 adapter2
1452 .execute_where_query("v_user", Some(&where_clause), None, None)
1453 .await
1454 .unwrap();
1455 adapter2
1456 .execute_where_query("v_post", Some(&where_clause), None, None)
1457 .await
1458 .unwrap();
1459
1460 let invalidated_views = adapter2
1461 .invalidate_views(&["v_user".to_string(), "v_post".to_string()])
1462 .unwrap();
1463
1464 assert_eq!(invalidated_cascade, 2);
1466 assert_eq!(invalidated_views, 2);
1467 }
1468
/// A cascade with empty `updated` and `deleted` lists invalidates nothing.
#[tokio::test]
async fn test_cascade_invalidation_with_empty_cascade() {
    let mock = MockAdapter::new();
    let cache = QueryResultCache::new(CacheConfig::enabled());
    let adapter = CachedDatabaseAdapter::new(mock, cache, "1.0.0".to_string());

    let where_clause = WhereClause::Field {
        path: vec!["status".to_string()],
        operator: WhereOperator::Eq,
        value: json!("active"),
    };

    // Populate the cache (miss).
    adapter
        .execute_where_query("v_user", Some(&where_clause), None, None)
        .await
        .unwrap();
    assert_eq!(adapter.inner().call_count(), 1);

    // The cascade section is present but lists no entities.
    let cascade_response = json!({
        "mutation": {
            "cascade": {
                "updated": [],
                "deleted": []
            }
        }
    });

    let parser = CascadeResponseParser::new();
    let invalidated = adapter.invalidate_cascade_entities(&cascade_response, &parser).unwrap();

    assert_eq!(invalidated, 0);

    // The cached entry survives: still a hit.
    adapter
        .execute_where_query("v_user", Some(&where_clause), None, None)
        .await
        .unwrap();
    assert_eq!(adapter.inner().call_count(), 1);
}
1512
/// Fact table extraction: `tf_*` names after `FROM` are returned lowercased;
/// other tables and queries without `FROM` yield `None`.
#[test]
fn test_extract_fact_table_from_sql() {
    // (input SQL, expected table name)
    let cases: [(&str, Option<&str>); 5] = [
        ("SELECT SUM(revenue) FROM tf_sales WHERE year = 2024", Some("tf_sales")),
        ("SELECT COUNT(*) FROM tf_page_views", Some("tf_page_views")),
        ("select sum(x) FROM TF_EVENTS", Some("tf_events")),
        ("SELECT * FROM users WHERE id = 1", None),
        ("SELECT 1 + 1", None),
    ];

    for (sql, expected) in cases {
        let actual = CachedDatabaseAdapter::<MockAdapter>::extract_fact_table_from_sql(sql);
        assert_eq!(actual, expected.map(String::from));
    }
}
1555
/// Aggregation cache keys carry the `agg:` prefix and change whenever the
/// version component or the schema version changes.
#[test]
fn test_generate_aggregation_cache_key() {
    // Same SQL throughout; only schema version and version component vary.
    let key_for = |schema_version: &str, version_component: &str| {
        CachedDatabaseAdapter::<MockAdapter>::generate_aggregation_cache_key(
            "SELECT SUM(x) FROM tf_sales",
            schema_version,
            Some(version_component),
        )
    };

    let key1 = key_for("1.0.0", "tv:42");
    let key2 = key_for("1.0.0", "tv:43");
    let key3 = key_for("2.0.0", "tv:42");

    for key in [&key1, &key2, &key3] {
        assert!(key.starts_with("agg:"));
    }

    assert_ne!(key1, key2);
    assert_ne!(key1, key3);
    assert_ne!(key2, key3);
}
1586
1587 #[tokio::test]
1588 async fn test_aggregation_caching_time_based() {
1589 let mock = MockAdapter::new();
1590 let cache = QueryResultCache::new(CacheConfig::enabled());
1591
1592 let mut ft_config = FactTableCacheConfig::default();
1594 ft_config
1595 .set_strategy("tf_sales", FactTableVersionStrategy::TimeBased { ttl_seconds: 300 });
1596
1597 let adapter = CachedDatabaseAdapter::with_fact_table_config(
1598 mock,
1599 cache,
1600 "1.0.0".to_string(),
1601 ft_config,
1602 );
1603
1604 let _ = adapter
1606 .execute_aggregation_query("SELECT SUM(revenue) FROM tf_sales")
1607 .await
1608 .unwrap();
1609 assert_eq!(adapter.inner().call_count(), 1);
1610
1611 let _ = adapter
1613 .execute_aggregation_query("SELECT SUM(revenue) FROM tf_sales")
1614 .await
1615 .unwrap();
1616 assert_eq!(adapter.inner().call_count(), 1); }
1618
1619 #[tokio::test]
1620 async fn test_aggregation_caching_schema_version() {
1621 let mock = MockAdapter::new();
1622 let cache = QueryResultCache::new(CacheConfig::enabled());
1623
1624 let mut ft_config = FactTableCacheConfig::default();
1626 ft_config.set_strategy("tf_historical_rates", FactTableVersionStrategy::SchemaVersion);
1627
1628 let adapter = CachedDatabaseAdapter::with_fact_table_config(
1629 mock,
1630 cache,
1631 "1.0.0".to_string(),
1632 ft_config,
1633 );
1634
1635 let _ = adapter
1637 .execute_aggregation_query("SELECT AVG(rate) FROM tf_historical_rates")
1638 .await
1639 .unwrap();
1640 assert_eq!(adapter.inner().call_count(), 1);
1641
1642 let _ = adapter
1644 .execute_aggregation_query("SELECT AVG(rate) FROM tf_historical_rates")
1645 .await
1646 .unwrap();
1647 assert_eq!(adapter.inner().call_count(), 1); }
1649
1650 #[tokio::test]
1651 async fn test_aggregation_caching_disabled_by_default() {
1652 let mock = MockAdapter::new();
1653 let cache = QueryResultCache::new(CacheConfig::default());
1654
1655 let adapter = CachedDatabaseAdapter::new(mock, cache, "1.0.0".to_string());
1657
1658 let _ = adapter
1660 .execute_aggregation_query("SELECT SUM(revenue) FROM tf_sales")
1661 .await
1662 .unwrap();
1663 assert_eq!(adapter.inner().call_count(), 1);
1664
1665 let _ = adapter
1667 .execute_aggregation_query("SELECT SUM(revenue) FROM tf_sales")
1668 .await
1669 .unwrap();
1670 assert_eq!(adapter.inner().call_count(), 2); }
1672
1673 #[tokio::test]
1674 async fn test_aggregation_caching_non_fact_table() {
1675 let mock = MockAdapter::new();
1676 let cache = QueryResultCache::new(CacheConfig::enabled());
1677
1678 let ft_config = FactTableCacheConfig::with_default(FactTableVersionStrategy::SchemaVersion);
1680 let adapter = CachedDatabaseAdapter::with_fact_table_config(
1681 mock,
1682 cache,
1683 "1.0.0".to_string(),
1684 ft_config,
1685 );
1686
1687 let _ = adapter.execute_aggregation_query("SELECT COUNT(*) FROM users").await.unwrap();
1689 assert_eq!(adapter.inner().call_count(), 1);
1690
1691 let _ = adapter.execute_aggregation_query("SELECT COUNT(*) FROM users").await.unwrap();
1692 assert_eq!(adapter.inner().call_count(), 2); }
1694
1695 #[tokio::test]
1696 async fn test_aggregation_caching_different_queries() {
1697 let mock = MockAdapter::new();
1698 let cache = QueryResultCache::new(CacheConfig::enabled());
1699
1700 let mut ft_config = FactTableCacheConfig::default();
1701 ft_config.set_strategy("tf_sales", FactTableVersionStrategy::SchemaVersion);
1702
1703 let adapter = CachedDatabaseAdapter::with_fact_table_config(
1704 mock,
1705 cache,
1706 "1.0.0".to_string(),
1707 ft_config,
1708 );
1709
1710 let _ = adapter
1712 .execute_aggregation_query("SELECT SUM(revenue) FROM tf_sales WHERE year = 2024")
1713 .await
1714 .unwrap();
1715 assert_eq!(adapter.inner().call_count(), 1);
1716
1717 let _ = adapter
1719 .execute_aggregation_query("SELECT SUM(revenue) FROM tf_sales WHERE year = 2023")
1720 .await
1721 .unwrap();
1722 assert_eq!(adapter.inner().call_count(), 2); let _ = adapter
1726 .execute_aggregation_query("SELECT SUM(revenue) FROM tf_sales WHERE year = 2024")
1727 .await
1728 .unwrap();
1729 assert_eq!(adapter.inner().call_count(), 2); }
1731
1732 #[tokio::test]
1733 async fn test_fact_table_config_accessor() {
1734 let mock = MockAdapter::new();
1735 let cache = QueryResultCache::new(CacheConfig::enabled());
1736
1737 let mut ft_config = FactTableCacheConfig::default();
1738 ft_config.set_strategy("tf_sales", FactTableVersionStrategy::VersionTable);
1739
1740 let adapter = CachedDatabaseAdapter::with_fact_table_config(
1741 mock,
1742 cache,
1743 "1.0.0".to_string(),
1744 ft_config,
1745 );
1746
1747 assert_eq!(
1749 adapter.fact_table_config().get_strategy("tf_sales"),
1750 &FactTableVersionStrategy::VersionTable
1751 );
1752 assert_eq!(
1753 adapter.fact_table_config().get_strategy("tf_other"),
1754 &FactTableVersionStrategy::Disabled
1755 );
1756 }
1757}