1use std::collections::HashMap;
2use std::hash::{DefaultHasher, Hash, Hasher};
3use std::path::Path;
4use std::sync::{Arc, RwLock};
5
6use crate::error::{Error, Result};
7
8use super::config::Config;
9use super::connect::connect;
10use super::database::Database;
11
12struct ShardedMap {
13 shards: Vec<RwLock<HashMap<String, Database>>>,
14}
15
16impl ShardedMap {
17 fn new(num_shards: usize) -> Self {
18 let mut shards = Vec::with_capacity(num_shards);
19 for _ in 0..num_shards {
20 shards.push(RwLock::new(HashMap::new()));
21 }
22 Self { shards }
23 }
24
25 fn shard_index(&self, key: &str) -> usize {
26 let mut hasher = DefaultHasher::new();
27 key.hash(&mut hasher);
28 hasher.finish() as usize % self.shards.len()
29 }
30
31 fn get(&self, key: &str) -> Option<Database> {
34 let idx = self.shard_index(key);
35 let shard = &self.shards[idx];
36 let read = shard.read().expect("pool shard lock poisoned");
37 read.get(key).cloned()
38 }
39
40 fn insert(&self, key: String, db: Database) {
43 let idx = self.shard_index(&key);
44 let shard = &self.shards[idx];
45 let mut write = shard.write().expect("pool shard lock poisoned");
46 write.insert(key, db);
47 }
48}
49
50#[derive(Clone)]
78pub struct DatabasePool {
79 inner: Arc<Inner>,
80}
81
82struct Inner {
83 default: Database,
84 config: Config,
85 shards: ShardedMap,
86}
87
88impl DatabasePool {
89 pub async fn new(config: &Config) -> Result<Self> {
99 let pool_config = config
100 .pool
101 .as_ref()
102 .ok_or_else(|| Error::internal("database pool config is required"))?;
103
104 if pool_config.lock_shards == 0 {
105 return Err(Error::internal("pool lock_shards must be greater than 0"));
106 }
107
108 let default = connect(config).await?;
109 let shards = ShardedMap::new(pool_config.lock_shards);
110
111 Ok(Self {
112 inner: Arc::new(Inner {
113 default,
114 config: config.clone(),
115 shards,
116 }),
117 })
118 }
119
120 pub async fn conn(&self, shard: Option<&str>) -> Result<Database> {
137 let Some(name) = shard else {
138 return Ok(self.inner.default.clone());
139 };
140
141 if name.is_empty()
142 || name.starts_with('.')
143 || name.contains('/')
144 || name.contains('\\')
145 || name.contains('\0')
146 {
147 return Err(Error::bad_request(format!("invalid shard name: {name:?}")));
148 }
149
150 if let Some(db) = self.inner.shards.get(name) {
151 return Ok(db);
152 }
153
154 let pool_config = self.inner.config.pool.as_ref().unwrap();
156 let shard_path = if pool_config.base_path == ":memory:" {
157 ":memory:".to_string()
158 } else {
159 Path::new(&pool_config.base_path)
160 .join(format!("{name}.db"))
161 .to_string_lossy()
162 .into_owned()
163 };
164 let shard_config = Config {
165 path: shard_path,
166 pool: None,
167 ..self.inner.config.clone()
168 };
169
170 let db = connect(&shard_config).await.map_err(|e| {
171 Error::internal(format!("failed to open shard database: {name}")).chain(e)
172 })?;
173
174 self.inner.shards.insert(name.to_string(), db.clone());
175 Ok(db)
176 }
177}
178
179pub struct ManagedDatabasePool(DatabasePool);
187
188impl crate::runtime::Task for ManagedDatabasePool {
189 async fn shutdown(self) -> Result<()> {
190 drop(self.0);
191 Ok(())
192 }
193}
194
195pub fn managed_pool(pool: DatabasePool) -> ManagedDatabasePool {
211 ManagedDatabasePool(pool)
212}
213
#[cfg(test)]
mod tests {
    use super::*;

    /// Shard count used by every `ShardedMap` in these tests.
    const SHARD_COUNT: usize = 4;

    /// Connects a throwaway database using the `:memory:` path.
    async fn in_memory_db() -> Database {
        let cfg = Config {
            path: ":memory:".to_string(),
            ..Default::default()
        };
        connect(&cfg).await.unwrap()
    }

    #[test]
    fn sharded_map_get_returns_none_for_missing_key() {
        let cache = ShardedMap::new(SHARD_COUNT);
        assert!(cache.get("missing").is_none(), "empty map must not report hits");
    }

    #[tokio::test]
    async fn sharded_map_insert_and_get() {
        let cache = ShardedMap::new(SHARD_COUNT);
        cache.insert(String::from("tenant_a"), in_memory_db().await);
        assert!(cache.get("tenant_a").is_some());
    }

    #[tokio::test]
    async fn sharded_map_different_keys_independent() {
        let cache = ShardedMap::new(SHARD_COUNT);
        cache.insert(String::from("tenant_a"), in_memory_db().await);
        assert!(cache.get("tenant_a").is_some());
        assert!(cache.get("tenant_b").is_none());
    }

    #[tokio::test]
    async fn sharded_map_insert_idempotent() {
        let cache = ShardedMap::new(SHARD_COUNT);
        let first = in_memory_db().await;
        let second = in_memory_db().await;
        for db in [first, second] {
            cache.insert(String::from("key"), db);
        }
        assert!(cache.get("key").is_some());
    }
}