use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::{Duration, Instant};
use parking_lot::Mutex;
use rusqlite::Connection;
use super::mutation::TenantId;
use super::trait_def::CommitError;
use crate::migrations::MigrationRunner;
/// Callback that maps a tenant to the filesystem path of its SQLite database.
///
/// The `Send + Sync` bounds allow the pool that stores this resolver to itself be
/// `Send + Sync`.
pub type PathResolver = Arc<dyn Fn(TenantId) -> PathBuf + Sync + Send>;
/// Cache of open SQLite connections, one database file per tenant.
///
/// Connections are handed out as shared `Arc<Mutex<Connection>>` handles and are
/// evicted least-recently-used first once `max_size` entries are cached.
pub struct TenantCommitConnectionPool {
    /// Cached connections keyed by tenant, all guarded by one lock.
    conns: Mutex<HashMap<TenantId, CachedConnection>>,
    /// Capacity of `conns` before least-recently-used eviction kicks in.
    max_size: usize,
    /// Maps a tenant to the path of its database file.
    resolver: PathResolver,
}
/// One pooled connection plus the bookkeeping needed for eviction.
struct CachedConnection {
    /// Refreshed whenever the connection is inserted or returned from the cache;
    /// the oldest value is the eviction candidate.
    last_used: Instant,
    /// Shared handle; the pool keeps one clone and callers receive others.
    conn: Arc<Mutex<Connection>>,
}
impl TenantCommitConnectionPool {
pub const DEFAULT_MAX_SIZE: usize = 256;
pub const DEFAULT_IDLE_THRESHOLD: Duration = Duration::from_secs(5 * 60);
pub fn new(resolver: PathResolver) -> Self {
Self {
conns: Mutex::new(HashMap::new()),
resolver,
max_size: Self::DEFAULT_MAX_SIZE,
}
}
pub fn with_max_size(mut self, max_size: usize) -> Self {
self.max_size = max_size;
self
}
pub fn for_tenant(&self, tenant_id: TenantId) -> Result<Arc<Mutex<Connection>>, CommitError> {
let mut map = self.conns.lock();
if let Some(entry) = map.get_mut(&tenant_id) {
entry.last_used = Instant::now();
return Ok(Arc::clone(&entry.conn));
}
if map.len() >= self.max_size {
if let Some(victim) = map.iter().min_by_key(|(_, c)| c.last_used).map(|(t, _)| *t) {
map.remove(&victim);
}
}
let path = (self.resolver)(tenant_id);
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent).map_err(|e| CommitError::StorageFailure {
message: format!("create tenant dir {parent:?}: {e}"),
})?;
}
let mut conn = Connection::open(&path).map_err(|e| CommitError::StorageFailure {
message: format!("open tenant db {path:?}: {e}"),
})?;
Self::configure_pragmas(&conn)?;
MigrationRunner::run_pending(&mut conn).map_err(|e| CommitError::StorageFailure {
message: format!("run migrations on {path:?}: {e}"),
})?;
let arc = Arc::new(Mutex::new(conn));
map.insert(
tenant_id,
CachedConnection {
conn: Arc::clone(&arc),
last_used: Instant::now(),
},
);
Ok(arc)
}
pub fn close_idle(&self, idle_threshold: Duration) -> usize {
let cutoff = Instant::now()
.checked_sub(idle_threshold)
.unwrap_or_else(Instant::now);
let mut map = self.conns.lock();
let to_evict: Vec<TenantId> = map
.iter()
.filter(|(_, c)| c.last_used < cutoff)
.map(|(t, _)| *t)
.collect();
for t in &to_evict {
map.remove(t);
}
to_evict.len()
}
pub fn open_count(&self) -> usize {
self.conns.lock().len()
}
pub fn close_all(&self) {
self.conns.lock().clear();
}
fn configure_pragmas(conn: &Connection) -> Result<(), CommitError> {
conn.execute_batch(
"PRAGMA journal_mode=WAL;\n\
PRAGMA synchronous=NORMAL;\n\
PRAGMA foreign_keys=ON;",
)
.map_err(|e| CommitError::StorageFailure {
message: format!("PRAGMA setup failed: {e}"),
})?;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Pool whose tenants live at `<dir>/tenant_<id>/yantrik.db`.
    fn build_pool(dir: &TempDir) -> TenantCommitConnectionPool {
        let base = dir.path().to_path_buf();
        TenantCommitConnectionPool::new(Arc::new(move |tid: TenantId| {
            base.join(format!("tenant_{}", tid.0)).join("yantrik.db")
        }))
    }

    #[test]
    fn first_open_creates_file_and_runs_migrations() {
        let dir = TempDir::new().unwrap();
        let pool = build_pool(&dir);
        let conn_arc = pool.for_tenant(TenantId::new(1)).unwrap();
        assert!(dir.path().join("tenant_1/yantrik.db").exists());
        let conn = conn_arc.lock();
        let table_count: i64 = conn
            .query_row(
                "SELECT COUNT(*) FROM sqlite_master \
                 WHERE type='table' AND name='memory_commit_log'",
                [],
                |row| row.get(0),
            )
            .unwrap();
        assert_eq!(table_count, 1);
        let meta_count: i64 = conn
            .query_row(
                "SELECT COUNT(*) FROM _yantrikdb_meta_migrations WHERE id = 1",
                [],
                |row| row.get(0),
            )
            .unwrap();
        assert_eq!(meta_count, 1);
    }

    #[test]
    fn repeat_open_returns_cached_arc() {
        let dir = TempDir::new().unwrap();
        let pool = build_pool(&dir);
        let a = pool.for_tenant(TenantId::new(7)).unwrap();
        let b = pool.for_tenant(TenantId::new(7)).unwrap();
        assert!(Arc::ptr_eq(&a, &b));
        assert_eq!(pool.open_count(), 1);
    }

    #[test]
    fn different_tenants_get_distinct_connections() {
        let dir = TempDir::new().unwrap();
        let pool = build_pool(&dir);
        let a = pool.for_tenant(TenantId::new(1)).unwrap();
        let b = pool.for_tenant(TenantId::new(2)).unwrap();
        assert!(!Arc::ptr_eq(&a, &b));
        assert_eq!(pool.open_count(), 2);
        assert!(dir.path().join("tenant_1/yantrik.db").exists());
        assert!(dir.path().join("tenant_2/yantrik.db").exists());
    }

    #[test]
    fn migration_idempotent_across_reopens() {
        let dir = TempDir::new().unwrap();
        {
            let pool = build_pool(&dir);
            pool.for_tenant(TenantId::new(1)).unwrap();
        }
        let pool2 = build_pool(&dir);
        let conn_arc = pool2.for_tenant(TenantId::new(1)).unwrap();
        let conn = conn_arc.lock();
        let meta_count: i64 = conn
            .query_row(
                "SELECT COUNT(*) FROM _yantrikdb_meta_migrations WHERE id = 1",
                [],
                |row| row.get(0),
            )
            .unwrap();
        assert_eq!(meta_count, 1, "m001 must record exactly once");
    }

    #[test]
    fn lru_eviction_at_max_size() {
        let dir = TempDir::new().unwrap();
        let pool = build_pool(&dir).with_max_size(2);
        let a = pool.for_tenant(TenantId::new(1)).unwrap();
        let _b = pool.for_tenant(TenantId::new(2)).unwrap();
        // Touch tenant 2 so tenant 1 is the least-recently-used entry.
        let _b2 = pool.for_tenant(TenantId::new(2)).unwrap();
        let c = pool.for_tenant(TenantId::new(3)).unwrap();
        assert_eq!(pool.open_count(), 2);

        // Tenant 1 was evicted, so asking for it again must open a new connection.
        let a2 = pool.for_tenant(TenantId::new(1)).unwrap();
        assert_eq!(pool.open_count(), 2);
        assert!(!Arc::ptr_eq(&a2, &a), "evicted tenant must be reopened");

        // That reopen evicted tenant 2 (now the LRU entry); tenant 3 must survive.
        let c2 = pool.for_tenant(TenantId::new(3)).unwrap();
        assert!(Arc::ptr_eq(&c2, &c), "recently used tenant must stay cached");
        assert_eq!(pool.open_count(), 2);
    }

    #[test]
    fn close_idle_evicts_old_connections() {
        let dir = TempDir::new().unwrap();
        let pool = build_pool(&dir);
        let _a = pool.for_tenant(TenantId::new(1)).unwrap();
        let _b = pool.for_tenant(TenantId::new(2)).unwrap();
        assert_eq!(pool.open_count(), 2);
        std::thread::sleep(Duration::from_millis(2));
        let evicted = pool.close_idle(Duration::from_millis(1));
        assert_eq!(evicted, 2);
        assert_eq!(pool.open_count(), 0);
    }

    #[test]
    fn close_all_drops_every_connection() {
        let dir = TempDir::new().unwrap();
        let pool = build_pool(&dir);
        let _a = pool.for_tenant(TenantId::new(1)).unwrap();
        let _b = pool.for_tenant(TenantId::new(7)).unwrap();
        assert_eq!(pool.open_count(), 2);
        pool.close_all();
        assert_eq!(pool.open_count(), 0);
    }

    #[test]
    fn pragmas_are_set_on_first_open() {
        let dir = TempDir::new().unwrap();
        let pool = build_pool(&dir);
        let conn_arc = pool.for_tenant(TenantId::new(1)).unwrap();
        let conn = conn_arc.lock();
        let mode: String = conn
            .query_row("PRAGMA journal_mode", [], |row| row.get(0))
            .unwrap();
        assert_eq!(mode.to_lowercase(), "wal");
    }

    #[test]
    fn parent_dir_created_if_missing() {
        let dir = TempDir::new().unwrap();
        let base = dir.path().to_path_buf();
        let pool = TenantCommitConnectionPool::new(Arc::new(move |tid: TenantId| {
            base.join(format!("never/seen/this/dir/tenant_{}/yantrik.db", tid.0))
        }));
        let _conn = pool.for_tenant(TenantId::new(99)).unwrap();
        assert!(dir
            .path()
            .join("never/seen/this/dir/tenant_99/yantrik.db")
            .exists());
    }

    // Compile-time check: the pool must stay `Send + Sync`. Never called, hence
    // the dead_code allowances.
    #[allow(dead_code)]
    fn _send_sync_compile_check<T: Send + Sync>(_: T) {}
    #[allow(dead_code)]
    fn _pool_is_send_sync(p: TenantCommitConnectionPool) {
        _send_sync_compile_check(p);
    }
}