use std::sync::Arc;
use arc_swap::ArcSwap;
use dashmap::DashMap;
use fraiseql_core::{db::traits::DatabaseAdapter, runtime::Executor};
use fraiseql_error::FraiseQLError;
pub struct TenantExecutorRegistry<A: DatabaseAdapter> {
default: Arc<ArcSwap<Executor<A>>>,
tenants: DashMap<String, Arc<ArcSwap<Executor<A>>>>,
}
impl<A: DatabaseAdapter> TenantExecutorRegistry<A> {
#[must_use]
pub fn new(default: Arc<ArcSwap<Executor<A>>>) -> Self {
Self {
default,
tenants: DashMap::new(),
}
}
pub fn executor_for(
&self,
tenant_key: Option<&str>,
) -> fraiseql_error::Result<arc_swap::Guard<Arc<Executor<A>>>> {
match tenant_key {
None => Ok(self.default.load()),
Some(key) => {
let entry = self.tenants.get(key).ok_or_else(|| {
FraiseQLError::unauthorized(format!("Tenant '{key}' is not registered"))
})?;
Ok(entry.value().load())
},
}
}
pub fn upsert(&self, key: impl Into<String>, executor: Arc<Executor<A>>) -> bool {
let key = key.into();
if let Some(existing) = self.tenants.get(&key) {
existing.value().store(executor);
false
} else {
self.tenants.insert(key, Arc::new(ArcSwap::from(executor)));
true
}
}
pub fn remove(&self, key: &str) -> fraiseql_error::Result<Arc<ArcSwap<Executor<A>>>> {
self.tenants
.remove(key)
.map(|(_, v)| v)
.ok_or_else(|| FraiseQLError::not_found("tenant", key))
}
#[must_use]
pub fn tenant_keys(&self) -> Vec<String> {
self.tenants.iter().map(|e| e.key().clone()).collect()
}
#[must_use]
pub fn len(&self) -> usize {
self.tenants.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.tenants.is_empty()
}
#[must_use]
pub fn default_executor(&self) -> arc_swap::Guard<Arc<Executor<A>>> {
self.default.load()
}
pub async fn health_check(&self, key: &str) -> fraiseql_error::Result<()> {
let entry = self.tenants.get(key).ok_or_else(|| FraiseQLError::not_found("tenant", key))?;
let executor = entry.value().load();
executor.adapter().health_check().await
}
}
// Unit tests for `TenantExecutorRegistry`, driven by an in-memory stub adapter.
// Tests tell executors apart by the number of queries in their compiled
// schema: the default executor has 0 queries and tenant executors have 1
// (or 2 for the hot-reload "v2" executor).
#[cfg(test)]
mod tests {
    // The crate presumably enables these lints for library code. In tests,
    // unwrap/panic is the intended failure mode, and test items need no
    // public docs.
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::missing_panics_doc)]
    #![allow(clippy::missing_errors_doc)]
    #![allow(missing_docs)]
    use std::sync::Arc;
    use arc_swap::ArcSwap;
    use async_trait::async_trait;
    use fraiseql_core::{
        db::{
            WhereClause,
            traits::DatabaseAdapter,
            types::{DatabaseType, JsonbValue, PoolMetrics},
        },
        error::Result as FraiseQLResult,
        runtime::Executor,
        schema::CompiledSchema,
    };
    use super::*;

    /// No-op `DatabaseAdapter`: every query returns no rows and the health
    /// check always succeeds.
    #[derive(Debug, Clone)]
    struct StubAdapter {
        // Stored only to name the instance; no code in this module reads it.
        _label: &'static str,
    }
    impl StubAdapter {
        fn new(label: &'static str) -> Self {
            Self { _label: label }
        }
    }
    #[async_trait]
    impl DatabaseAdapter for StubAdapter {
        async fn execute_where_query(
            &self,
            _view: &str,
            _where_clause: Option<&WhereClause>,
            _limit: Option<u32>,
            _offset: Option<u32>,
            _order_by: Option<&[fraiseql_core::db::types::OrderByClause]>,
        ) -> FraiseQLResult<Vec<JsonbValue>> {
            Ok(vec![])
        }
        async fn execute_with_projection(
            &self,
            _view: &str,
            _projection: Option<&fraiseql_core::schema::SqlProjectionHint>,
            _where_clause: Option<&WhereClause>,
            _limit: Option<u32>,
            _offset: Option<u32>,
            _order_by: Option<&[fraiseql_core::db::types::OrderByClause]>,
        ) -> FraiseQLResult<Vec<JsonbValue>> {
            Ok(vec![])
        }
        fn database_type(&self) -> DatabaseType {
            DatabaseType::SQLite
        }
        async fn health_check(&self) -> FraiseQLResult<()> {
            Ok(())
        }
        fn pool_metrics(&self) -> PoolMetrics {
            PoolMetrics::default()
        }
        async fn execute_raw_query(
            &self,
            _sql: &str,
        ) -> FraiseQLResult<Vec<std::collections::HashMap<String, serde_json::Value>>> {
            Ok(vec![])
        }
        async fn execute_parameterized_aggregate(
            &self,
            _sql: &str,
            _params: &[serde_json::Value],
        ) -> FraiseQLResult<Vec<std::collections::HashMap<String, serde_json::Value>>> {
            Ok(vec![])
        }
    }

    /// Default executor fixture: an empty schema (0 queries).
    fn default_executor() -> Arc<ArcSwap<Executor<StubAdapter>>> {
        let schema = CompiledSchema::default();
        let executor = Arc::new(Executor::new(schema, Arc::new(StubAdapter::new("default"))));
        Arc::new(ArcSwap::from(executor))
    }

    /// Tenant executor fixture: a schema with exactly one query ("users"), so
    /// it can be told apart from the default executor.
    fn tenant_executor(label: &'static str) -> Arc<Executor<StubAdapter>> {
        let mut schema = CompiledSchema::default();
        schema
            .queries
            .push(fraiseql_core::schema::QueryDefinition::new("users", "User"));
        Arc::new(Executor::new(schema, Arc::new(StubAdapter::new(label))))
    }

    #[test]
    fn test_registry_returns_default_when_no_tenant() {
        let registry = TenantExecutorRegistry::new(default_executor());
        let exec = registry.executor_for(None);
        assert!(exec.is_ok());
        assert_eq!(exec.unwrap().schema().queries.len(), 0);
    }

    #[test]
    fn test_registry_returns_tenant_executor() {
        let registry = TenantExecutorRegistry::new(default_executor());
        registry.upsert("tenant-abc", tenant_executor("abc"));
        let exec = registry.executor_for(Some("tenant-abc"));
        assert!(exec.is_ok());
        assert_eq!(exec.unwrap().schema().queries.len(), 1);
    }

    // Registering a tenant must not change what a keyless lookup returns.
    #[test]
    fn test_registry_falls_back_to_default_for_no_key() {
        let registry = TenantExecutorRegistry::new(default_executor());
        registry.upsert("tenant-abc", tenant_executor("abc"));
        let exec = registry.executor_for(None);
        assert!(exec.is_ok());
        assert_eq!(exec.unwrap().schema().queries.len(), 0);
    }

    // An explicit but unknown key is an authorization failure, not a silent
    // fallback to the default executor.
    #[test]
    fn test_registry_rejects_explicit_but_unregistered_key() {
        let registry = TenantExecutorRegistry::new(default_executor());
        let Err(err) = registry.executor_for(Some("unknown")) else {
            panic!("expected Err for unregistered key");
        };
        assert!(
            matches!(err, FraiseQLError::Authorization { .. }),
            "Expected Authorization error, got: {err:?}"
        );
    }

    #[test]
    fn test_registry_upsert_returns_true_on_insert() {
        let registry = TenantExecutorRegistry::new(default_executor());
        let was_insert = registry.upsert("tenant-abc", tenant_executor("abc"));
        assert!(was_insert);
    }

    #[test]
    fn test_registry_upsert_returns_false_on_update() {
        let registry = TenantExecutorRegistry::new(default_executor());
        registry.upsert("tenant-abc", tenant_executor("abc"));
        let was_insert = registry.upsert("tenant-abc", tenant_executor("abc-v2"));
        assert!(!was_insert);
    }

    #[test]
    fn test_registry_remove_existing() {
        let registry = TenantExecutorRegistry::new(default_executor());
        registry.upsert("tenant-abc", tenant_executor("abc"));
        assert_eq!(registry.len(), 1);
        assert!(registry.remove("tenant-abc").is_ok());
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn test_registry_remove_unknown_returns_error() {
        let registry = TenantExecutorRegistry::new(default_executor());
        let Err(err) = registry.remove("unknown") else {
            panic!("expected Err for unknown key");
        };
        assert!(
            matches!(err, FraiseQLError::NotFound { .. }),
            "Expected NotFound error, got: {err:?}"
        );
    }

    // Keys come back in map order, so sort before comparing.
    #[test]
    fn test_registry_tenant_keys() {
        let registry = TenantExecutorRegistry::new(default_executor());
        registry.upsert("tenant-abc", tenant_executor("abc"));
        registry.upsert("tenant-xyz", tenant_executor("xyz"));
        let mut keys = registry.tenant_keys();
        keys.sort();
        assert_eq!(keys, vec!["tenant-abc", "tenant-xyz"]);
    }

    // `len`/`is_empty` count tenants only; the default executor is excluded.
    #[test]
    fn test_registry_len_and_is_empty() {
        let registry = TenantExecutorRegistry::new(default_executor());
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
        registry.upsert("tenant-abc", tenant_executor("abc"));
        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
    }

    // Hot reload: a guard taken before `upsert` keeps seeing v1 (1 query),
    // while a fresh lookup after `upsert` sees v2 (2 queries).
    #[test]
    fn test_registry_hot_reload_tenant() {
        let registry = TenantExecutorRegistry::new(default_executor());
        registry.upsert("tenant-abc", tenant_executor("abc-v1"));
        let guard_v1 = registry.executor_for(Some("tenant-abc")).unwrap();
        assert_eq!(guard_v1.schema().queries.len(), 1);
        let mut schema_v2 = CompiledSchema::default();
        schema_v2
            .queries
            .push(fraiseql_core::schema::QueryDefinition::new("users", "User"));
        schema_v2
            .queries
            .push(fraiseql_core::schema::QueryDefinition::new("posts", "Post"));
        let executor_v2 = Arc::new(Executor::new(schema_v2, Arc::new(StubAdapter::new("abc-v2"))));
        registry.upsert("tenant-abc", executor_v2);
        assert_eq!(guard_v1.schema().queries.len(), 1);
        let guard_v2 = registry.executor_for(Some("tenant-abc")).unwrap();
        assert_eq!(guard_v2.schema().queries.len(), 2);
    }

    // Removing a tenant must not invalidate a guard that is still in use,
    // but new lookups for that tenant must fail.
    #[test]
    fn test_remove_tenant_in_flight_guard_survives() {
        let registry = TenantExecutorRegistry::new(default_executor());
        registry.upsert("tenant-abc", tenant_executor("abc"));
        let guard = registry.executor_for(Some("tenant-abc")).unwrap();
        let removed = registry.remove("tenant-abc");
        assert!(removed.is_ok());
        assert_eq!(guard.schema().queries.len(), 1);
        let result = registry.executor_for(Some("tenant-abc"));
        assert!(result.is_err());
    }
}