use std::{
collections::{HashMap, HashSet},
sync::{Arc, Mutex},
};
use async_trait::async_trait;
use super::{
cascade_invalidator::CascadeInvalidator,
fact_table_version::{FactTableCacheConfig, FactTableVersionProvider},
result::QueryResultCache,
};
use crate::{
cache::config::RlsEnforcement,
db::{
DatabaseAdapter, DatabaseType, PoolMetrics, SupportsMutations, WhereClause,
types::{JsonbValue, OrderByClause},
},
error::{FraiseQLError, Result},
schema::CompiledSchema,
};
mod mutation;
mod query;
#[cfg(test)]
mod tests;
pub use query::view_name_to_entity_type;
/// A [`DatabaseAdapter`] wrapper that adds a query-result cache in front of an
/// inner adapter `A`.
///
/// Instances are built with `new` / `with_fact_table_config` and refined with the
/// `with_*` builder methods. Cloning (when `A: Clone`) shares the result cache,
/// the fact-table version provider and the cascade invalidator between clones.
pub struct CachedDatabaseAdapter<A: DatabaseAdapter> {
    /// The wrapped adapter that performs the actual database work.
    pub(super) adapter: A,
    /// Result cache; held in an `Arc` so clones of this adapter share it.
    pub(super) cache: Arc<QueryResultCache>,
    /// Schema version string supplied at construction.
    /// NOTE(review): presumably folded into cache keys so a schema change
    /// invalidates old entries — confirm in the `query` submodule.
    pub(super) schema_version: String,
    /// Per-view TTL overrides, keyed by view name.
    /// NOTE(review): presumably seconds (populated from `cache_ttl_seconds`).
    pub(super) view_ttl_overrides: HashMap<String, u64>,
    /// Views that are eligible for caching when `opt_in_mode` is set.
    pub(super) cacheable_views: HashSet<String>,
    /// Set by the TTL-override builders; when `true`, caching is restricted to
    /// `cacheable_views` (assumed — the consuming logic lives outside this file).
    pub(super) opt_in_mode: bool,
    /// Caching configuration for fact tables.
    pub(super) fact_table_config: FactTableCacheConfig,
    /// Fact-table version source; shared between clones via `Arc`.
    pub(super) version_provider: Arc<FactTableVersionProvider>,
    /// Optional cascade invalidator, shared between clones behind `Arc<Mutex<_>>`.
    pub(super) cascade_invalidator: Option<Arc<Mutex<CascadeInvalidator>>>,
}
/// Construction, builder-style configuration, accessors, and Row-Level
/// Security checks for [`CachedDatabaseAdapter`].
impl<A: DatabaseAdapter> CachedDatabaseAdapter<A> {
    /// Wraps `adapter` with the given result `cache`.
    ///
    /// The adapter starts with:
    /// - no per-view TTL overrides;
    /// - opt-in mode disabled;
    /// - the default [`FactTableCacheConfig`];
    /// - a fresh default [`FactTableVersionProvider`];
    /// - no cascade invalidator.
    #[must_use]
    pub fn new(adapter: A, cache: QueryResultCache, schema_version: String) -> Self {
        Self {
            adapter,
            cache: Arc::new(cache),
            schema_version,
            view_ttl_overrides: HashMap::new(),
            cacheable_views: HashSet::new(),
            opt_in_mode: false,
            fact_table_config: FactTableCacheConfig::default(),
            version_provider: Arc::new(FactTableVersionProvider::default()),
            cascade_invalidator: None,
        }
    }

    /// Enables opt-in caching for exactly the views named in `overrides`.
    ///
    /// This *replaces* any previously configured cacheable views and TTL
    /// overrides. It does not merge with them. Afterwards the keys of
    /// `overrides` are the cacheable views, each mapped to its TTL value.
    #[must_use]
    pub fn with_view_ttl_overrides(mut self, overrides: HashMap<String, u64>) -> Self {
        self.cacheable_views = overrides.keys().cloned().collect();
        self.view_ttl_overrides = overrides;
        self.opt_in_mode = true;
        self
    }

    /// Attaches a [`CascadeInvalidator`].
    ///
    /// The invalidator is stored behind `Arc<Mutex<_>>`, so clones of this
    /// adapter share the same instance.
    #[must_use]
    pub fn with_cascade_invalidator(mut self, invalidator: CascadeInvalidator) -> Self {
        self.cascade_invalidator = Some(Arc::new(Mutex::new(invalidator)));
        self
    }

    /// Enables opt-in caching for every schema query that declares both a SQL
    /// source view and `cache_ttl_seconds`.
    ///
    /// Unlike `with_view_ttl_overrides`, qualifying views are *added* to the
    /// existing configuration. Opt-in mode is switched on even when no query
    /// qualifies.
    #[must_use]
    pub fn with_ttl_overrides_from_schema(mut self, schema: &CompiledSchema) -> Self {
        for query in &schema.queries {
            if let (Some(view), Some(ttl)) = (&query.sql_source, query.cache_ttl_seconds) {
                self.cacheable_views.insert(view.clone());
                self.view_ttl_overrides.insert(view.clone(), ttl);
            }
        }
        self.opt_in_mode = true;
        self
    }

    /// Same as `new`, but with an explicit [`FactTableCacheConfig`] instead of
    /// the default one.
    #[must_use]
    pub fn with_fact_table_config(
        adapter: A,
        cache: QueryResultCache,
        schema_version: String,
        fact_table_config: FactTableCacheConfig,
    ) -> Self {
        // Delegate to `new` so the two constructors cannot drift apart when
        // fields are added to the struct.
        let mut cached = Self::new(adapter, cache, schema_version);
        cached.fact_table_config = fact_table_config;
        cached
    }

    /// Borrows the wrapped (uncached) adapter.
    #[must_use]
    pub const fn inner(&self) -> &A {
        &self.adapter
    }

    /// Borrows the shared query-result cache.
    #[must_use]
    pub fn cache(&self) -> &QueryResultCache {
        &self.cache
    }

    /// Returns the schema version string supplied at construction.
    #[must_use]
    pub fn schema_version(&self) -> &str {
        &self.schema_version
    }

    /// Borrows the fact-table caching configuration.
    #[must_use]
    pub const fn fact_table_config(&self) -> &FactTableCacheConfig {
        &self.fact_table_config
    }

    /// Borrows the shared fact-table version provider.
    #[must_use]
    pub fn version_provider(&self) -> &FactTableVersionProvider {
        &self.version_provider
    }

    /// Checks that the database reports Row-Level Security as active.
    ///
    /// Runs `current_setting('row_security', true)` on the wrapped adapter and
    /// accepts the values `"on"` or `"force"`.
    ///
    /// NOTE(review): in PostgreSQL `row_security` defaults to `on`. It only
    /// controls whether policies are applied or raise an error, so it does not
    /// prove that any table actually has RLS enabled. Consider also checking
    /// `pg_class.relrowsecurity` for cached views' base tables. On non-PostgreSQL
    /// backends the probe presumably fails, which is treated as "inactive".
    ///
    /// # Errors
    ///
    /// Returns [`FraiseQLError::Configuration`] when the setting is missing or
    /// has any other value, or when the probe query itself fails. A query
    /// failure is logged first, so the underlying cause is not lost.
    pub async fn validate_rls_active(&self) -> Result<()> {
        // Query the wrapped adapter directly so the probe bypasses this
        // wrapper's result cache.
        let probe = self
            .adapter
            .execute_raw_query("SELECT current_setting('row_security', true) AS rls_setting")
            .await;
        let rls_active = match probe {
            Ok(rows) => rows
                .first()
                .and_then(|row| row.get("rls_setting"))
                .and_then(serde_json::Value::as_str)
                .is_some_and(|s| s == "on" || s == "force"),
            Err(e) => {
                // Fail closed (treat RLS as inactive), but surface the cause:
                // the Configuration error below does not carry it.
                tracing::warn!("RLS probe query failed; treating RLS as inactive: {}", e);
                false
            },
        };
        if rls_active {
            Ok(())
        } else {
            Err(FraiseQLError::Configuration {
                message: "Caching is enabled in a multi-tenant schema but Row-Level Security \
                          does not appear to be active on the database. This would allow \
                          cross-tenant data leakage through the cache. \
                          Either disable caching, enable RLS, or set \
                          `rls_enforcement = \"off\"` in CacheConfig for single-tenant \
                          deployments."
                    .to_string(),
            })
        }
    }

    /// Applies the configured RLS enforcement policy.
    ///
    /// - `Off`: returns `Ok(())` without querying the database.
    /// - `Error`: propagates the error from `validate_rls_active`.
    /// - `Warn`: logs that error as a warning and returns `Ok(())`.
    ///
    /// # Errors
    ///
    /// Only with `RlsEnforcement::Error`, when `validate_rls_active` fails.
    pub async fn enforce_rls(&self, enforcement: RlsEnforcement) -> Result<()> {
        if enforcement == RlsEnforcement::Off {
            return Ok(());
        }
        let Err(e) = self.validate_rls_active().await else {
            return Ok(());
        };
        match enforcement {
            RlsEnforcement::Error => Err(e),
            RlsEnforcement::Warn => {
                tracing::warn!(
                    "RLS check failed (rls_enforcement = \"warn\"): {}. \
                     Cross-tenant cache leakage is possible.",
                    e
                );
                Ok(())
            },
            // Unreachable because of the early return above; listed so the
            // match stays exhaustive without a wildcard.
            RlsEnforcement::Off => Ok(()),
        }
    }
}
/// Clones the wrapper and the inner adapter.
///
/// The clone shares the result cache, the fact-table version provider and the
/// cascade invalidator with the original (all held behind `Arc`). The owned
/// configuration collections are copied.
impl<A: DatabaseAdapter + Clone> Clone for CachedDatabaseAdapter<A> {
    fn clone(&self) -> Self {
        // Exhaustive destructuring: adding a field to the struct becomes a
        // compile error here until the new field is handled.
        let Self {
            adapter,
            cache,
            schema_version,
            view_ttl_overrides,
            cacheable_views,
            opt_in_mode,
            fact_table_config,
            version_provider,
            cascade_invalidator,
        } = self;

        Self {
            adapter: adapter.clone(),
            cache: Arc::clone(cache),
            schema_version: schema_version.clone(),
            view_ttl_overrides: view_ttl_overrides.clone(),
            cacheable_views: cacheable_views.clone(),
            opt_in_mode: *opt_in_mode,
            fact_table_config: fact_table_config.clone(),
            version_provider: Arc::clone(version_provider),
            cascade_invalidator: cascade_invalidator.clone(),
        }
    }
}
/// `DatabaseAdapter` implementation for the caching wrapper.
///
/// - Projection and `WHERE` queries delegate to the crate-private `*_impl`
///   helpers (defined outside this file).
/// - `execute_raw_query` is routed through `execute_aggregation_query`.
/// - Parameterized aggregates and function calls go straight to the wrapped
///   adapter.
/// - Invalidation calls forward to the inherent methods of the same name.
#[async_trait]
impl<A: DatabaseAdapter> DatabaseAdapter for CachedDatabaseAdapter<A> {
    /// Owned-result form of `execute_with_projection_arc`.
    async fn execute_with_projection(
        &self,
        view: &str,
        projection: Option<&crate::schema::SqlProjectionHint>,
        where_clause: Option<&WhereClause>,
        limit: Option<u32>,
        offset: Option<u32>,
        order_by: Option<&[OrderByClause]>,
    ) -> Result<Vec<JsonbValue>> {
        let shared = self
            .execute_with_projection_impl(view, projection, where_clause, limit, offset, order_by)
            .await?;
        // Moves the rows out when this is the last reference, otherwise clones them.
        Ok(Arc::unwrap_or_clone(shared))
    }

    /// Owned-result form of `execute_where_query_arc`.
    async fn execute_where_query(
        &self,
        view: &str,
        where_clause: Option<&WhereClause>,
        limit: Option<u32>,
        offset: Option<u32>,
        order_by: Option<&[OrderByClause]>,
    ) -> Result<Vec<JsonbValue>> {
        let shared = self
            .execute_where_query_impl(view, where_clause, limit, offset, order_by)
            .await?;
        Ok(Arc::unwrap_or_clone(shared))
    }

    /// Returns the shared result set produced by `execute_with_projection_impl`.
    async fn execute_with_projection_arc(
        &self,
        view: &str,
        projection: Option<&crate::schema::SqlProjectionHint>,
        where_clause: Option<&WhereClause>,
        limit: Option<u32>,
        offset: Option<u32>,
        order_by: Option<&[OrderByClause]>,
    ) -> Result<Arc<Vec<JsonbValue>>> {
        let outcome = self
            .execute_with_projection_impl(view, projection, where_clause, limit, offset, order_by)
            .await;
        outcome
    }

    /// Returns the shared result set produced by `execute_where_query_impl`.
    async fn execute_where_query_arc(
        &self,
        view: &str,
        where_clause: Option<&WhereClause>,
        limit: Option<u32>,
        offset: Option<u32>,
        order_by: Option<&[OrderByClause]>,
    ) -> Result<Arc<Vec<JsonbValue>>> {
        let outcome =
            self.execute_where_query_impl(view, where_clause, limit, offset, order_by).await;
        outcome
    }

    /// Reports the wrapped adapter's database type.
    fn database_type(&self) -> DatabaseType {
        let inner = &self.adapter;
        inner.database_type()
    }

    /// Forwards the health check to the wrapped adapter.
    async fn health_check(&self) -> Result<()> {
        let inner = &self.adapter;
        inner.health_check().await
    }

    /// Reports the wrapped adapter's connection-pool metrics.
    fn pool_metrics(&self) -> PoolMetrics {
        let inner = &self.adapter;
        inner.pool_metrics()
    }

    /// Routed through `execute_aggregation_query`, not straight to the wrapped adapter.
    async fn execute_raw_query(&self, sql: &str) -> Result<Vec<HashMap<String, serde_json::Value>>> {
        let rows = self.execute_aggregation_query(sql).await;
        rows
    }

    /// Forwarded directly to the wrapped adapter (no caching in this wrapper).
    async fn execute_parameterized_aggregate(
        &self,
        sql: &str,
        params: &[serde_json::Value],
    ) -> Result<Vec<HashMap<String, serde_json::Value>>> {
        let inner = &self.adapter;
        inner.execute_parameterized_aggregate(sql, params).await
    }

    /// Forwarded directly to the wrapped adapter (no caching in this wrapper).
    async fn execute_function_call(
        &self,
        function_name: &str,
        args: &[serde_json::Value],
    ) -> Result<Vec<HashMap<String, serde_json::Value>>> {
        let inner = &self.adapter;
        inner.execute_function_call(function_name, args).await
    }

    /// Delegates to the synchronous inherent `invalidate_views`.
    async fn invalidate_views(&self, views: &[String]) -> Result<u64> {
        let removed = CachedDatabaseAdapter::invalidate_views(self, views);
        removed
    }

    /// Delegates to the synchronous inherent `invalidate_by_entity`.
    async fn invalidate_by_entity(&self, entity_type: &str, entity_id: &str) -> Result<u64> {
        let removed = CachedDatabaseAdapter::invalidate_by_entity(self, entity_type, entity_id);
        removed
    }

    /// Delegates to the synchronous inherent `invalidate_list_queries`.
    async fn invalidate_list_queries(&self, views: &[String]) -> Result<u64> {
        let removed = CachedDatabaseAdapter::invalidate_list_queries(self, views);
        removed
    }

    /// Delegates to `bump_fact_table_versions_impl`.
    async fn bump_fact_table_versions(&self, tables: &[String]) -> Result<()> {
        let bumped = self.bump_fact_table_versions_impl(tables).await;
        bumped
    }
}
/// Marks the caching wrapper as mutation-capable whenever the wrapped adapter is.
impl<A> SupportsMutations for CachedDatabaseAdapter<A> where A: SupportsMutations + Send + Sync {}