1use crate::{DbErr, DbPool, ROOT_DATABASE_NAME, models::tenant::Tenant};
27use docbox_secrets::{SecretManager, SecretManagerError};
28use moka::{future::Cache, policy::EvictionPolicy};
29use serde::{Deserialize, Serialize};
30use sqlx::{
31 PgPool,
32 postgres::{PgConnectOptions, PgPoolOptions},
33};
34use std::num::ParseIntError;
35use std::sync::Arc;
36use std::time::Duration;
37use thiserror::Error;
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct DatabasePoolCacheConfig {
42 pub host: String,
44 pub port: u16,
46
47 pub root_secret_name: String,
50
51 pub max_connections: Option<u32>,
62
63 pub max_connections_root: Option<u32>,
75
76 pub acquire_timeout: Option<u64>,
81
82 pub idle_timeout: Option<u64>,
88
89 pub cache_duration: Option<u64>,
94
95 pub cache_capacity: Option<u64>,
105
106 pub credentials_cache_duration: Option<u64>,
112
113 pub credentials_cache_capacity: Option<u64>,
118}
119
120impl Default for DatabasePoolCacheConfig {
121 fn default() -> Self {
122 Self {
123 host: Default::default(),
124 port: 5432,
125 root_secret_name: Default::default(),
126 max_connections: None,
127 max_connections_root: None,
128 acquire_timeout: None,
129 idle_timeout: None,
130 cache_duration: None,
131 cache_capacity: None,
132 credentials_cache_duration: None,
133 credentials_cache_capacity: None,
134 }
135 }
136}
137
138#[derive(Debug, Error)]
139pub enum DatabasePoolCacheConfigError {
140 #[error("missing DOCBOX_DB_HOST environment variable")]
141 MissingDatabaseHost,
142 #[error("missing DOCBOX_DB_PORT environment variable")]
143 MissingDatabasePort,
144 #[error("invalid DOCBOX_DB_PORT environment variable")]
145 InvalidDatabasePort,
146 #[error("missing DOCBOX_DB_CREDENTIAL_NAME environment variable")]
147 MissingDatabaseSecretName,
148 #[error("invalid DOCBOX_DB_IDLE_TIMEOUT environment variable")]
149 InvalidIdleTimeout(ParseIntError),
150 #[error("invalid DOCBOX_DB_ACQUIRE_TIMEOUT environment variable")]
151 InvalidAcquireTimeout(ParseIntError),
152 #[error("invalid DOCBOX_DB_CACHE_DURATION environment variable")]
153 InvalidCacheDuration(ParseIntError),
154 #[error("invalid DOCBOX_DB_CACHE_CAPACITY environment variable")]
155 InvalidCacheCapacity(ParseIntError),
156 #[error("invalid DOCBOX_DB_CREDENTIALS_CACHE_DURATION environment variable")]
157 InvalidCredentialsCacheDuration(ParseIntError),
158 #[error("invalid DOCBOX_DB_CREDENTIALS_CACHE_CAPACITY environment variable")]
159 InvalidCredentialsCacheCapacity(ParseIntError),
160}
161
162impl DatabasePoolCacheConfig {
163 pub fn from_env() -> Result<DatabasePoolCacheConfig, DatabasePoolCacheConfigError> {
164 let db_host: String = std::env::var("DOCBOX_DB_HOST")
165 .or(std::env::var("POSTGRES_HOST"))
166 .map_err(|_| DatabasePoolCacheConfigError::MissingDatabaseHost)?;
167 let db_port: u16 = std::env::var("DOCBOX_DB_PORT")
168 .or(std::env::var("POSTGRES_PORT"))
169 .map_err(|_| DatabasePoolCacheConfigError::MissingDatabasePort)?
170 .parse()
171 .map_err(|_| DatabasePoolCacheConfigError::InvalidDatabasePort)?;
172 let db_root_secret_name = std::env::var("DOCBOX_DB_CREDENTIAL_NAME")
173 .map_err(|_| DatabasePoolCacheConfigError::MissingDatabaseSecretName)?;
174 let max_connections: Option<u32> = std::env::var("DOCBOX_DB_MAX_CONNECTIONS")
175 .ok()
176 .and_then(|value| value.parse().ok());
177 let max_connections_root: Option<u32> = std::env::var("DOCBOX_DB_MAX_ROOT_CONNECTIONS")
178 .ok()
179 .and_then(|value| value.parse().ok());
180
181 let acquire_timeout: Option<u64> = match std::env::var("DOCBOX_DB_ACQUIRE_TIMEOUT") {
182 Ok(value) => Some(
183 value
184 .parse::<u64>()
185 .map_err(DatabasePoolCacheConfigError::InvalidAcquireTimeout)?,
186 ),
187 Err(_) => None,
188 };
189
190 let idle_timeout: Option<u64> = match std::env::var("DOCBOX_DB_IDLE_TIMEOUT") {
191 Ok(value) => Some(
192 value
193 .parse::<u64>()
194 .map_err(DatabasePoolCacheConfigError::InvalidIdleTimeout)?,
195 ),
196 Err(_) => None,
197 };
198
199 let cache_duration: Option<u64> = match std::env::var("DOCBOX_DB_CACHE_DURATION") {
200 Ok(value) => Some(
201 value
202 .parse::<u64>()
203 .map_err(DatabasePoolCacheConfigError::InvalidCacheDuration)?,
204 ),
205 Err(_) => None,
206 };
207
208 let cache_capacity: Option<u64> = match std::env::var("DOCBOX_DB_CACHE_CAPACITY") {
209 Ok(value) => Some(
210 value
211 .parse::<u64>()
212 .map_err(DatabasePoolCacheConfigError::InvalidCacheCapacity)?,
213 ),
214 Err(_) => None,
215 };
216
217 let credentials_cache_duration: Option<u64> =
218 match std::env::var("DOCBOX_DB_CREDENTIALS_CACHE_DURATION") {
219 Ok(value) => Some(
220 value
221 .parse::<u64>()
222 .map_err(DatabasePoolCacheConfigError::InvalidCredentialsCacheDuration)?,
223 ),
224 Err(_) => None,
225 };
226
227 let credentials_cache_capacity: Option<u64> =
228 match std::env::var("DOCBOX_DB_CREDENTIALS_CACHE_CAPACITY") {
229 Ok(value) => Some(
230 value
231 .parse::<u64>()
232 .map_err(DatabasePoolCacheConfigError::InvalidCredentialsCacheCapacity)?,
233 ),
234 Err(_) => None,
235 };
236
237 Ok(DatabasePoolCacheConfig {
238 host: db_host,
239 port: db_port,
240 root_secret_name: db_root_secret_name,
241 max_connections,
242 max_connections_root,
243 acquire_timeout,
244 idle_timeout,
245 cache_duration,
246 cache_capacity,
247 credentials_cache_duration,
248 credentials_cache_capacity,
249 })
250 }
251}
252
253pub struct DatabasePoolCache {
255 host: String,
257
258 port: u16,
260
261 root_secret_name: String,
264
265 cache: Cache<String, DbPool>,
267
268 connect_info_cache: Cache<String, DbSecrets>,
271
272 secrets_manager: SecretManager,
274
275 max_connections: u32,
277 max_connections_root: u32,
279
280 acquire_timeout: Duration,
281 idle_timeout: Duration,
282}
283
284#[derive(Debug, Clone, Serialize, Deserialize)]
286pub struct DbSecrets {
287 pub username: String,
288 pub password: String,
289}
290
291#[derive(Debug, Error)]
292pub enum DbConnectErr {
293 #[error("database credentials not found in secrets manager")]
294 MissingCredentials,
295
296 #[error(transparent)]
297 SecretsManager(Box<SecretManagerError>),
298
299 #[error(transparent)]
300 Db(#[from] DbErr),
301
302 #[error(transparent)]
303 Shared(#[from] Arc<DbConnectErr>),
304}
305
306impl DatabasePoolCache {
307 pub fn from_config(config: DatabasePoolCacheConfig, secrets_manager: SecretManager) -> Self {
308 let cache_duration = Duration::from_secs(config.cache_duration.unwrap_or(60 * 60 * 48));
309 let credentials_cache_duration =
310 Duration::from_secs(config.credentials_cache_duration.unwrap_or(60 * 60 * 12));
311
312 let cache_capacity = config.cache_capacity.unwrap_or(50);
313 let credentials_cache_capacity = config.credentials_cache_capacity.unwrap_or(50);
314
315 let cache = Cache::builder()
316 .time_to_idle(cache_duration)
317 .max_capacity(cache_capacity)
318 .eviction_policy(EvictionPolicy::tiny_lfu())
319 .async_eviction_listener(|cache_key: Arc<String>, pool: DbPool, _cause| {
320 Box::pin(async move {
321 tracing::debug!(?cache_key, "database pool is no longer in use, closing");
322 pool.close().await
323 })
324 })
325 .build();
326
327 let connect_info_cache = Cache::builder()
328 .time_to_idle(credentials_cache_duration)
329 .max_capacity(credentials_cache_capacity)
330 .eviction_policy(EvictionPolicy::tiny_lfu())
331 .build();
332
333 Self {
334 host: config.host,
335 port: config.port,
336 root_secret_name: config.root_secret_name,
337 cache,
338 connect_info_cache,
339 secrets_manager,
340 max_connections: config.max_connections.unwrap_or(10),
341 max_connections_root: config.max_connections_root.unwrap_or(2),
342 idle_timeout: Duration::from_secs(config.idle_timeout.unwrap_or(60 * 10)),
343 acquire_timeout: Duration::from_secs(config.acquire_timeout.unwrap_or(60)),
344 }
345 }
346
347 pub async fn get_root_pool(&self) -> Result<PgPool, DbConnectErr> {
349 self.get_pool(ROOT_DATABASE_NAME, &self.root_secret_name)
350 .await
351 }
352
353 pub async fn get_tenant_pool(&self, tenant: &Tenant) -> Result<DbPool, DbConnectErr> {
355 self.get_pool(&tenant.db_name, &tenant.db_secret_name).await
356 }
357
358 pub async fn close_tenant_pool(&self, tenant: &Tenant) {
361 let cache_key = format!("{}-{}", &tenant.db_name, &tenant.db_secret_name);
362 if let Some(pool) = self.cache.remove(&cache_key).await {
363 pool.close().await;
364 }
365
366 self.cache.run_pending_tasks().await;
368 }
369
370 pub async fn flush(&self) {
372 self.cache.invalidate_all();
374 self.connect_info_cache.invalidate_all();
375 self.cache.run_pending_tasks().await;
376 }
377
378 pub async fn close_all(&self) {
380 for (_, value) in self.cache.iter() {
381 value.close().await;
382 }
383
384 self.flush().await;
385 }
386
387 async fn get_pool(&self, db_name: &str, secret_name: &str) -> Result<DbPool, DbConnectErr> {
389 let cache_key = format!("{db_name}-{secret_name}");
390
391 let pool = self
392 .cache
393 .try_get_with(cache_key, async {
394 tracing::debug!(?db_name, "acquiring database pool");
395
396 let pool = self
397 .create_pool(db_name, secret_name)
398 .await
399 .map_err(Arc::new)?;
400
401 Ok(pool)
402 })
403 .await?;
404
405 Ok(pool)
406 }
407
408 async fn get_credentials(&self, secret_name: &str) -> Result<DbSecrets, DbConnectErr> {
410 if let Some(connect_info) = self.connect_info_cache.get(secret_name).await {
411 return Ok(connect_info);
412 }
413
414 let credentials = self
416 .secrets_manager
417 .parsed_secret::<DbSecrets>(secret_name)
418 .await
419 .map_err(|err| DbConnectErr::SecretsManager(Box::new(err)))?
420 .ok_or(DbConnectErr::MissingCredentials)?;
421
422 self.connect_info_cache
424 .insert(secret_name.to_string(), credentials.clone())
425 .await;
426
427 Ok(credentials)
428 }
429
430 async fn create_pool(&self, db_name: &str, secret_name: &str) -> Result<DbPool, DbConnectErr> {
432 tracing::debug!(?db_name, ?secret_name, "creating db pool connection");
433
434 let credentials = self.get_credentials(secret_name).await?;
435 let options = PgConnectOptions::new()
436 .host(&self.host)
437 .port(self.port)
438 .username(&credentials.username)
439 .password(&credentials.password)
440 .database(db_name);
441
442 let max_connections = match db_name {
443 ROOT_DATABASE_NAME => self.max_connections_root,
444 _ => self.max_connections,
445 };
446
447 match PgPoolOptions::new()
448 .max_connections(max_connections)
449 .acquire_timeout(self.acquire_timeout)
451 .idle_timeout(self.idle_timeout)
453 .connect_with(options)
454 .await
455 {
456 Ok(value) => Ok(value),
458 Err(err) => {
459 self.connect_info_cache.remove(secret_name).await;
461 Err(DbConnectErr::Db(err))
462 }
463 }
464 }
465}