1use std::collections::HashMap;
2use std::hash::{DefaultHasher, Hash, Hasher};
3use std::path::Path;
4use std::sync::{Arc, RwLock};
5
6use crate::error::{Error, Result};
7
8use super::config::Config;
9use super::connect::connect;
10use super::database::Database;
11
12struct ShardedMap {
13 shards: Vec<RwLock<HashMap<String, Database>>>,
14}
15
16impl ShardedMap {
17 fn new(num_shards: usize) -> Self {
18 let mut shards = Vec::with_capacity(num_shards);
19 for _ in 0..num_shards {
20 shards.push(RwLock::new(HashMap::new()));
21 }
22 Self { shards }
23 }
24
25 fn shard_index(&self, key: &str) -> usize {
26 let mut hasher = DefaultHasher::new();
27 key.hash(&mut hasher);
28 hasher.finish() as usize % self.shards.len()
29 }
30
31 fn get(&self, key: &str) -> Option<Database> {
34 let idx = self.shard_index(key);
35 let shard = &self.shards[idx];
36 let read = shard.read().expect("pool shard lock poisoned");
37 read.get(key).cloned()
38 }
39
40 fn insert(&self, key: String, db: Database) {
43 let idx = self.shard_index(&key);
44 let shard = &self.shards[idx];
45 let mut write = shard.write().expect("pool shard lock poisoned");
46 write.insert(key, db);
47 }
48}
49
50#[derive(Clone)]
80pub struct DatabasePool {
81 inner: Arc<Inner>,
82}
83
84struct Inner {
85 default: Database,
86 config: Config,
87 shards: ShardedMap,
88}
89
90impl DatabasePool {
91 pub async fn new(config: &Config) -> Result<Self> {
101 let pool_config = config
102 .pool
103 .as_ref()
104 .ok_or_else(|| Error::internal("database pool config is required"))?;
105
106 if pool_config.lock_shards == 0 {
107 return Err(Error::internal("pool lock_shards must be greater than 0"));
108 }
109
110 let default = connect(config).await?;
111 let shards = ShardedMap::new(pool_config.lock_shards);
112
113 Ok(Self {
114 inner: Arc::new(Inner {
115 default,
116 config: config.clone(),
117 shards,
118 }),
119 })
120 }
121
122 pub async fn conn(&self, shard: Option<&str>) -> Result<Database> {
139 let Some(name) = shard else {
140 return Ok(self.inner.default.clone());
141 };
142
143 if name.is_empty()
144 || name.starts_with('.')
145 || name.contains('/')
146 || name.contains('\\')
147 || name.contains('\0')
148 {
149 return Err(Error::bad_request(format!("invalid shard name: {name:?}")));
150 }
151
152 if let Some(db) = self.inner.shards.get(name) {
153 return Ok(db);
154 }
155
156 let pool_config = self.inner.config.pool.as_ref().unwrap();
158 let shard_path = if pool_config.base_path == ":memory:" {
159 ":memory:".to_string()
160 } else {
161 Path::new(&pool_config.base_path)
162 .join(format!("{name}.db"))
163 .to_string_lossy()
164 .into_owned()
165 };
166 let shard_config = Config {
167 path: shard_path,
168 pool: None,
169 ..self.inner.config.clone()
170 };
171
172 let db = connect(&shard_config).await.map_err(|e| {
173 Error::internal(format!("failed to open shard database: {name}")).chain(e)
174 })?;
175
176 self.inner.shards.insert(name.to_string(), db.clone());
177 Ok(db)
178 }
179}
180
181pub struct ManagedDatabasePool(DatabasePool);
189
190impl crate::runtime::Task for ManagedDatabasePool {
191 async fn shutdown(self) -> Result<()> {
192 drop(self.0);
193 Ok(())
194 }
195}
196
197pub fn managed_pool(pool: DatabasePool) -> ManagedDatabasePool {
213 ManagedDatabasePool(pool)
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a fresh `:memory:` database to use as a map value.
    async fn make_test_db() -> Database {
        let config = Config {
            path: String::from(":memory:"),
            ..Default::default()
        };
        connect(&config).await.unwrap()
    }

    /// Looking up a key that was never inserted yields `None`.
    #[test]
    fn sharded_map_get_returns_none_for_missing_key() {
        let map = ShardedMap::new(4);
        let found = map.get("missing");
        assert!(found.is_none());
    }

    /// A value can be read back under the key it was inserted with.
    #[tokio::test]
    async fn sharded_map_insert_and_get() {
        let map = ShardedMap::new(4);
        map.insert(String::from("tenant_a"), make_test_db().await);
        let found = map.get("tenant_a");
        assert!(found.is_some());
    }

    /// Inserting under one key does not make another key visible.
    #[tokio::test]
    async fn sharded_map_different_keys_independent() {
        let map = ShardedMap::new(4);
        map.insert(String::from("tenant_a"), make_test_db().await);
        let present = map.get("tenant_a");
        assert!(present.is_some());
        let absent = map.get("tenant_b");
        assert!(absent.is_none());
    }

    /// Inserting twice under the same key leaves the key present.
    #[tokio::test]
    async fn sharded_map_insert_idempotent() {
        let map = ShardedMap::new(4);
        let first = make_test_db().await;
        let second = make_test_db().await;
        map.insert(String::from("key"), first);
        map.insert(String::from("key"), second);
        let found = map.get("key");
        assert!(found.is_some());
    }
}