use super::pool::{TenantError, TenantInfo};
use crate::config::Config;
use crate::db::PoolConfig;
use sea_orm::{ConnectOptions, ConnectionTrait, Database, DatabaseConnection, DbErr, Statement};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
/// Upper bound on the number of per-droplet connections kept in the pool cache.
const MAX_CACHED_POOLS: usize = 50;
/// How long a cached tenant lookup stays valid (five minutes).
const TENANT_CACHE_TTL: Duration = Duration::from_secs(5 * 60);
/// A cached tenant lookup together with the moment it was stored.
struct TenantCacheEntry {
    /// The cached tenant record.
    info: TenantInfo,
    /// Creation time; compared against `TENANT_CACHE_TTL` to decide staleness.
    cached_at: Instant,
}

impl TenantCacheEntry {
    /// Wraps `info`, stamping it with the current time.
    fn new(info: TenantInfo) -> Self {
        let cached_at = Instant::now();
        Self { info, cached_at }
    }

    /// Returns `true` once strictly more than `TENANT_CACHE_TTL` has passed since creation.
    fn is_expired(&self) -> bool {
        let age = self.cached_at.elapsed();
        age > TENANT_CACHE_TTL
    }
}
/// A cached per-droplet connection plus its last access time, used for LRU eviction.
struct PoolEntry {
    /// Connection handed out (cloned) to callers for this droplet.
    pool: DatabaseConnection,
    /// Refreshed on every cache hit; the entry with the oldest value is evicted first.
    last_used: Instant,
}
/// Resolves tenants and hands out database connections for them.
///
/// Keeps a lazily opened master connection, a per-droplet connection cache bounded by
/// `MAX_CACHED_POOLS`, and a tenant-info cache whose entries expire after
/// `TENANT_CACHE_TTL`.
pub struct TenantManager {
    /// Cached droplet connections, keyed by droplet id.
    pools: RwLock<HashMap<i64, PoolEntry>>,
    /// Master database connection, opened on first use.
    master_connection: RwLock<Option<DatabaseConnection>>,
    /// Tenant lookups keyed by the tenant name as passed by callers.
    tenant_cache: RwLock<HashMap<String, TenantCacheEntry>>,
    /// Options applied to every connection this manager opens.
    pool_config: PoolConfig,
    /// Resolver for tenants and droplet URLs; `None` until one is attached.
    tenant_fetcher: Option<Arc<dyn TenantFetcher>>,
}
/// Pluggable lookup used by [`TenantManager`] to resolve tenants and droplet connections.
///
/// Both methods are handed the master database connection owned by the manager.
#[async_trait::async_trait]
pub trait TenantFetcher: Send + Sync {
    /// Looks up `tenant_name`.
    ///
    /// Return `Ok(None)` when the tenant does not exist; the manager reports that case as
    /// `TenantError::NotFound`.
    async fn fetch_tenant(
        &self,
        master_db: &DatabaseConnection,
        tenant_name: &str,
    ) -> Result<Option<TenantInfo>, TenantError>;

    /// Produces the connection URL for the droplet `droplet_id`.
    ///
    /// The manager feeds the result to `PoolConfig::to_connect_options` when it opens a
    /// new droplet connection.
    async fn build_connection_url(
        &self,
        master_db: &DatabaseConnection,
        droplet_id: i64,
    ) -> Result<String, TenantError>;
}
impl TenantManager {
    /// Creates a manager with default pool settings and no tenant fetcher.
    ///
    /// A fetcher must be attached with [`TenantManager::with_tenant_fetcher`] before tenant
    /// lookups or droplet connection creation can succeed.
    pub fn new() -> Self {
        Self {
            pools: RwLock::new(HashMap::new()),
            master_connection: RwLock::new(None),
            tenant_cache: RwLock::new(HashMap::new()),
            pool_config: PoolConfig::default(),
            tenant_fetcher: None,
        }
    }

    /// Replaces the options used for both the master and the droplet connections.
    pub fn with_pool_config(mut self, config: PoolConfig) -> Self {
        self.pool_config = config;
        self
    }

    /// Attaches the fetcher used to resolve tenants and droplet connection URLs.
    pub fn with_tenant_fetcher(mut self, fetcher: Arc<dyn TenantFetcher>) -> Self {
        self.tenant_fetcher = Some(fetcher);
        self
    }

    /// Returns the attached tenant fetcher.
    ///
    /// # Errors
    /// [`TenantError::Internal`] when no fetcher has been attached.
    fn fetcher(&self) -> Result<&Arc<dyn TenantFetcher>, TenantError> {
        self.tenant_fetcher
            .as_ref()
            .ok_or_else(|| TenantError::Internal("Tenant fetcher not configured".to_string()))
    }

    /// Returns the master database connection, opening it on first use.
    ///
    /// The URL is taken from the global [`Config`]; the connection is stored and cloned for
    /// every subsequent caller.
    ///
    /// # Errors
    /// - [`TenantError::Internal`] if the global config has not been initialized.
    /// - [`TenantError::Database`] if connecting fails.
    pub async fn get_master_connection(&self) -> Result<DatabaseConnection, TenantError> {
        // Fast path: only a shared lock when the connection already exists.
        {
            let guard = self.master_connection.read().await;
            if let Some(conn) = guard.as_ref() {
                return Ok(conn.clone());
            }
        }
        let mut guard = self.master_connection.write().await;
        // Re-check: another task may have connected while we waited for the write lock.
        if let Some(conn) = guard.as_ref() {
            return Ok(conn.clone());
        }
        let config = Config::try_get()
            .ok_or_else(|| TenantError::Internal("Config not initialized".to_string()))?;
        tracing::info!("Initializing master database connection");
        let opt = self.pool_config.to_connect_options(&config.database.url());
        let db = Database::connect(opt).await.map_err(TenantError::Database)?;
        *guard = Some(db.clone());
        Ok(db)
    }

    /// Returns the tenant record for `tenant`, served from cache while fresh.
    ///
    /// On a miss or an expired entry, the record is fetched through the tenant fetcher
    /// (using the master connection) and cached for `TENANT_CACHE_TTL`. Failed lookups are
    /// not cached.
    ///
    /// # Errors
    /// - Errors from [`TenantManager::get_master_connection`].
    /// - [`TenantError::Internal`] if no tenant fetcher is attached.
    /// - [`TenantError::NotFound`] if the fetcher reports no such tenant.
    /// - Any error returned by [`TenantFetcher::fetch_tenant`].
    pub async fn get_tenant_info(&self, tenant: &str) -> Result<TenantInfo, TenantError> {
        {
            let cache = self.tenant_cache.read().await;
            if let Some(entry) = cache.get(tenant) {
                if !entry.is_expired() {
                    return Ok(entry.info.clone());
                }
            }
        }
        let master_db = self.get_master_connection().await?;
        let info = self
            .fetcher()?
            .fetch_tenant(&master_db, tenant)
            .await?
            .ok_or_else(|| TenantError::NotFound(tenant.to_string()))?;
        {
            let mut cache = self.tenant_cache.write().await;
            cache.insert(tenant.to_string(), TenantCacheEntry::new(info.clone()));
        }
        Ok(info)
    }

    /// Succeeds if `tenant` can be resolved; see [`TenantManager::get_tenant_info`] for errors.
    pub async fn validate_tenant_in_master(&self, tenant: &str) -> Result<(), TenantError> {
        self.get_tenant_info(tenant).await?;
        Ok(())
    }

    /// Resolves `tenant`, obtains its droplet connection and switches it to the tenant's
    /// database with `USE`.
    ///
    /// NOTE(review): the returned `DatabaseConnection` is shared per droplet and is opened
    /// with `PoolConfig` options, so it is presumably a connection pool. If so, the `USE`
    /// only affects whichever pooled connection `execute` checked out. Later queries through
    /// the returned handle may run on a different physical connection, possibly one another
    /// tenant switched. Confirm that callers tolerate this (e.g. single-connection pools, or
    /// schema-qualified queries).
    ///
    /// # Errors
    /// - [`TenantError::NoDropletAssigned`] if the tenant has no droplet.
    /// - [`TenantError::InvalidName`] if the tenant name fails validation.
    /// - Errors from tenant resolution, connection creation or the `USE` statement.
    pub async fn get_connection(
        &self,
        tenant: &str,
    ) -> Result<(DatabaseConnection, TenantInfo), TenantError> {
        let tenant_info = self.get_tenant_info(tenant).await?;
        let droplet_id = tenant_info
            .droplet_id
            .ok_or(TenantError::NoDropletAssigned)?;
        let pool = self.get_droplet_pool(droplet_id).await?;
        self.switch_database(&pool, tenant).await?;
        Ok((pool, tenant_info))
    }

    /// Returns the cached connection for `droplet_id`, opening one on a cache miss.
    ///
    /// On a miss, the URL is built by the tenant fetcher and the connection is opened without
    /// holding the cache lock. Before inserting, the cache is re-checked under the write
    /// lock. If a concurrent caller already cached a connection for this droplet, that one
    /// is reused and the freshly opened one is dropped. When `MAX_CACHED_POOLS` entries are
    /// present, the least recently used entry is evicted first.
    async fn get_droplet_pool(&self, droplet_id: i64) -> Result<DatabaseConnection, TenantError> {
        {
            // A write lock even on a hit, because `last_used` is bumped for LRU eviction.
            let mut guard = self.pools.write().await;
            if let Some(entry) = guard.get_mut(&droplet_id) {
                entry.last_used = Instant::now();
                return Ok(entry.pool.clone());
            }
        }

        let master_db = self.get_master_connection().await?;
        let url = self
            .fetcher()?
            .build_connection_url(&master_db, droplet_id)
            .await?;
        tracing::info!("Creating pool for droplet {}", droplet_id);
        let opt = self.pool_config.to_connect_options(&url);
        let pool = Database::connect(opt).await.map_err(TenantError::Database)?;

        let mut guard = self.pools.write().await;
        // Re-check: another task may have cached this droplet while we were connecting.
        // Reuse its entry instead of overwriting it, and never evict an unrelated droplet
        // to make room for a key that is already present.
        if let Some(entry) = guard.get_mut(&droplet_id) {
            tracing::debug!(
                "Pool for droplet {} was created concurrently; reusing it",
                droplet_id
            );
            entry.last_used = Instant::now();
            return Ok(entry.pool.clone());
        }
        if guard.len() >= MAX_CACHED_POOLS {
            self.evict_oldest_pool(&mut guard);
        }
        guard.insert(
            droplet_id,
            PoolEntry {
                pool: pool.clone(),
                last_used: Instant::now(),
            },
        );
        Ok(pool)
    }

    /// Runs ``USE `tenant` `` on `pool`.
    ///
    /// # Errors
    /// - [`TenantError::InvalidName`] if `tenant` fails `validate_tenant_name`.
    /// - [`TenantError::Database`] if the statement fails.
    async fn switch_database(
        &self,
        pool: &DatabaseConnection,
        tenant: &str,
    ) -> Result<(), TenantError> {
        if crate::validation::validate_tenant_name(tenant).is_err() {
            return Err(TenantError::InvalidName(tenant.to_string()));
        }
        // Identifiers cannot be bound as parameters, so the name is interpolated. It has been
        // validated above. Doubling backticks (MySQL's escape inside a quoted identifier)
        // additionally keeps it inside the quotes even if the validator ever admits a
        // backtick.
        let sql = format!("USE `{}`", tenant.replace('`', "``"));
        pool.execute(Statement::from_string(
            sea_orm::DatabaseBackend::MySql,
            sql,
        ))
        .await
        .map_err(TenantError::Database)?;
        Ok(())
    }

    /// Removes the entry with the oldest `last_used` from `cache`; no-op when empty.
    fn evict_oldest_pool(&self, cache: &mut HashMap<i64, PoolEntry>) {
        let oldest = cache
            .iter()
            .min_by_key(|(_, entry)| entry.last_used)
            .map(|(&id, _)| id);
        if let Some(oldest_id) = oldest {
            tracing::debug!("Evicting pool for droplet: {}", oldest_id);
            cache.remove(&oldest_id);
        }
    }

    /// Drops the cached record for `tenant` so the next lookup refetches it.
    pub async fn invalidate_tenant_cache(&self, tenant: &str) {
        let mut cache = self.tenant_cache.write().await;
        cache.remove(tenant);
    }

    /// Drops every cached tenant record.
    pub async fn clear_tenant_cache(&self) {
        let mut cache = self.tenant_cache.write().await;
        cache.clear();
    }
}
impl Default for TenantManager {
fn default() -> Self {
Self::new()
}
}
/// Lazily initialized, process-wide tenant manager.
static TENANT_MANAGER: tokio::sync::OnceCell<Arc<TenantManager>> =
    tokio::sync::OnceCell::const_new();

/// Returns the global tenant manager, creating a default one (no tenant fetcher) if none has
/// been installed yet.
///
/// Once this has created the default, a later `init_tenant_manager` call cannot replace it,
/// so install a configured manager first when one is needed.
pub async fn get_tenant_manager() -> &'static Arc<TenantManager> {
    let make_default = || async { Arc::new(TenantManager::new()) };
    TENANT_MANAGER.get_or_init(make_default).await
}
pub async fn init_tenant_manager(manager: TenantManager) -> &'static Arc<TenantManager> {
TENANT_MANAGER
.get_or_init(|| async { Arc::new(manager) })
.await
}