use std::collections::HashMap;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::path::Path;
use std::sync::{Arc, RwLock};
use crate::error::{Error, Result};
use super::config::Config;
use super::connect::connect;
use super::database::Database;
/// A `HashMap<String, Database>` split across several independently locked
/// shards so concurrent lookups for different keys rarely contend on the
/// same lock.
struct ShardedMap {
    shards: Vec<RwLock<HashMap<String, Database>>>,
}
impl ShardedMap {
    /// Creates a map with `num_shards` independently locked shards.
    ///
    /// # Panics
    /// Panics if `num_shards` is zero — an empty shard vector would later
    /// make `shard_index` divide by zero with a far less helpful message.
    fn new(num_shards: usize) -> Self {
        assert!(num_shards > 0, "ShardedMap requires at least one shard");
        let shards = (0..num_shards)
            .map(|_| RwLock::new(HashMap::new()))
            .collect();
        Self { shards }
    }

    /// Maps `key` to the index of the shard responsible for it.
    fn shard_index(&self, key: &str) -> usize {
        let mut hasher = DefaultHasher::new();
        key.hash(&mut hasher);
        // Truncating `u64 -> usize` is fine here: we only need a modulus.
        hasher.finish() as usize % self.shards.len()
    }

    /// Returns a clone of the cached database for `key`, if present.
    fn get(&self, key: &str) -> Option<Database> {
        let idx = self.shard_index(key);
        let guard = self.shards[idx].read().expect("pool shard lock poisoned");
        guard.get(key).cloned()
    }

    /// Caches `db` under `key` unless an entry already exists.
    ///
    /// First-writer-wins: if two tasks race to populate the same key, the
    /// entry that lands first is kept, so later `get`s observe one stable
    /// `Database` instead of whichever insert happened to run last.
    fn insert(&self, key: String, db: Database) {
        let idx = self.shard_index(&key);
        let mut guard = self.shards[idx].write().expect("pool shard lock poisoned");
        guard.entry(key).or_insert(db);
    }
}
/// A cheaply clonable handle to a default database plus a lazily populated
/// cache of per-shard databases. All clones share the same underlying state
/// via `Arc`.
#[derive(Clone)]
pub struct DatabasePool {
    inner: Arc<Inner>,
}
/// Shared state behind the pool's `Arc`.
struct Inner {
    // Database returned when no shard name is requested.
    default: Database,
    // Original configuration; cloned per shard with `path` rewritten and
    // `pool` cleared when opening shard databases.
    config: Config,
    // Lazily populated cache of per-shard database connections.
    shards: ShardedMap,
}
impl DatabasePool {
    /// Builds a pool from `config`.
    ///
    /// The default database is opened eagerly; shard databases are opened
    /// lazily on first use in [`DatabasePool::conn`].
    ///
    /// # Errors
    /// Fails when `config.pool` is missing, when `lock_shards` is zero, or
    /// when the default database cannot be opened.
    pub async fn new(config: &Config) -> Result<Self> {
        let pool_config = config
            .pool
            .as_ref()
            .ok_or_else(|| Error::internal("database pool config is required"))?;
        if pool_config.lock_shards == 0 {
            return Err(Error::internal("pool lock_shards must be greater than 0"));
        }
        let default = connect(config).await?;
        let shards = ShardedMap::new(pool_config.lock_shards);
        Ok(Self {
            inner: Arc::new(Inner {
                default,
                config: config.clone(),
                shards,
            }),
        })
    }

    /// Returns a connection for the requested shard, or the default
    /// database when `shard` is `None`. Shard databases are opened on first
    /// use and cached for subsequent calls.
    ///
    /// # Errors
    /// Rejects shard names that are empty or could escape `base_path`, and
    /// reports failures to open the shard database.
    pub async fn conn(&self, shard: Option<&str>) -> Result<Database> {
        let Some(name) = shard else {
            return Ok(self.inner.default.clone());
        };
        // Reject names that could traverse outside `base_path` (leading
        // dot covers "." / ".." / hidden files; separators cover explicit
        // paths) or that embed a NUL byte.
        if name.is_empty()
            || name.starts_with('.')
            || name.contains('/')
            || name.contains('\\')
            || name.contains('\0')
        {
            return Err(Error::bad_request(format!("invalid shard name: {name:?}")));
        }
        if let Some(db) = self.inner.shards.get(name) {
            return Ok(db);
        }
        // Invariant: `new` rejected configs without a pool section, and
        // `Inner.config` is a clone of that config, so `pool` is `Some`.
        let pool_config = self
            .inner
            .config
            .pool
            .as_ref()
            .expect("pool config validated in DatabasePool::new");
        let shard_path = if pool_config.base_path == ":memory:" {
            ":memory:".to_string()
        } else {
            Path::new(&pool_config.base_path)
                .join(format!("{name}.db"))
                .to_string_lossy()
                .into_owned()
        };
        let shard_config = Config {
            path: shard_path,
            // Shard databases are plain databases; they must not recurse
            // into pooling themselves.
            pool: None,
            ..self.inner.config.clone()
        };
        let db = connect(&shard_config).await.map_err(|e| {
            Error::internal(format!("failed to open shard database: {name}")).chain(e)
        })?;
        // The cache was unlocked across the `await`, so another task may
        // have opened this shard concurrently. Prefer an entry that is
        // already cached to narrow the window in which two connections
        // exist for the same shard and to converge callers on one handle.
        if let Some(existing) = self.inner.shards.get(name) {
            return Ok(existing);
        }
        self.inner.shards.insert(name.to_string(), db.clone());
        Ok(db)
    }
}
/// Newtype that lets a [`DatabasePool`] be registered with the runtime and
/// shut down as a managed task.
pub struct ManagedDatabasePool(DatabasePool);
impl crate::runtime::Task for ManagedDatabasePool {
    /// Shuts the pool down by dropping this handle.
    ///
    /// NOTE(review): this releases only this `Arc` clone of the pool; the
    /// underlying databases are freed when the last clone is dropped.
    /// Confirm no other long-lived clones outlive shutdown.
    async fn shutdown(self) -> Result<()> {
        drop(self.0);
        Ok(())
    }
}
/// Wraps `pool` so it can be registered with the runtime as a shutdown task.
pub fn managed_pool(pool: DatabasePool) -> ManagedDatabasePool {
    ManagedDatabasePool(pool)
}
#[cfg(test)]
mod tests {
    use super::*;

    // Opens a fresh in-memory database for exercising `ShardedMap`.
    // `Config` and `connect` are in scope via `use super::*`.
    async fn make_test_db() -> Database {
        let config = Config {
            path: String::from(":memory:"),
            ..Default::default()
        };
        connect(&config).await.unwrap()
    }

    #[test]
    fn sharded_map_get_returns_none_for_missing_key() {
        assert!(ShardedMap::new(4).get("missing").is_none());
    }

    #[tokio::test]
    async fn sharded_map_insert_and_get() {
        let cache = ShardedMap::new(4);
        cache.insert(String::from("tenant_a"), make_test_db().await);
        assert!(cache.get("tenant_a").is_some());
    }

    #[tokio::test]
    async fn sharded_map_different_keys_independent() {
        let cache = ShardedMap::new(4);
        cache.insert(String::from("tenant_a"), make_test_db().await);
        assert!(cache.get("tenant_a").is_some());
        assert!(cache.get("tenant_b").is_none());
    }

    #[tokio::test]
    async fn sharded_map_insert_idempotent() {
        let cache = ShardedMap::new(4);
        let first = make_test_db().await;
        let second = make_test_db().await;
        cache.insert(String::from("key"), first);
        cache.insert(String::from("key"), second);
        assert!(cache.get("key").is_some());
    }
}