docbox_database/
pool.rs

1//! # Database Pool
2//!
3//! This is the docbox solution for managing multiple database connections
4//! and connection pools for each tenant and the root database itself.
5//!
6//! Pools are held in a cache with an expiry time to ensure they don't
7//! hog too many database connections.
8//!
//! Database pools and credentials are stored in Tiny LFU caches. These caches
//! can be flushed using [`DatabasePoolCache::flush`].
11//!
12//! ## Environment Variables
13//!
//! * `DOCBOX_DB_HOST` - Database host (falls back to `POSTGRES_HOST`)
//! * `DOCBOX_DB_PORT` - Database port (falls back to `POSTGRES_PORT`)
16//! * `DOCBOX_DB_CREDENTIAL_NAME` - Secrets manager name for the root database secret
17//! * `DOCBOX_DB_MAX_CONNECTIONS` - Max connections each tenant pool can contain
18//! * `DOCBOX_DB_MAX_ROOT_CONNECTIONS` - Max connections the root "docbox" pool can contain
//! * `DOCBOX_DB_ACQUIRE_TIMEOUT` - Timeout in seconds before acquiring a connection fails
//! * `DOCBOX_DB_IDLE_TIMEOUT` - Timeout in seconds before an idle connection is closed to save resources
//! * `DOCBOX_DB_CACHE_DURATION` - Duration in seconds idle pools should be maintained for before closing
//! * `DOCBOX_DB_CACHE_CAPACITY` - Maximum database pools to hold at once
//! * `DOCBOX_DB_CREDENTIALS_CACHE_DURATION` - Duration in seconds database credentials should be cached for
//! * `DOCBOX_DB_CREDENTIALS_CACHE_CAPACITY` - Maximum database credentials to cache
25
26use crate::{DbErr, DbPool, ROOT_DATABASE_NAME, models::tenant::Tenant};
27use docbox_secrets::{SecretManager, SecretManagerError};
28use moka::{future::Cache, policy::EvictionPolicy};
29use serde::{Deserialize, Serialize};
30use sqlx::{
31    PgPool,
32    postgres::{PgConnectOptions, PgPoolOptions},
33};
34use std::num::ParseIntError;
35use std::sync::Arc;
36use std::time::Duration;
37use thiserror::Error;
38
39///  Config for the database pool
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct DatabasePoolCacheConfig {
42    /// Database host
43    pub host: String,
44    /// Database port
45    pub port: u16,
46
47    /// Name of the secrets manager secret to use when connecting to
48    /// the root "docbox" database
49    pub root_secret_name: String,
50
51    /// Max number of active connections per tenant database pool
52    ///
53    /// This is the maximum number of connections that should be allocated
54    /// for performing all queries against each specific tenant.
55    ///
56    /// Ensure a reasonable amount of connections are allocated but make
57    /// sure that the `max_connections` * your number of tenants stays
58    /// within the limits for your database
59    ///
60    /// Default: 10
61    pub max_connections: Option<u32>,
62
63    /// Max number of active connections per "docbox" database pool
64    ///
65    /// This is the maximum number of connections that should be allocated
66    /// for performing queries like:
67    /// - Listing tenants
68    /// - Getting tenant details
69    ///
70    /// These pools are often short lived and complete their queries very fast
71    /// and thus don't need a huge amount of resources allocated to them
72    ///
73    /// Default: 2
74    pub max_connections_root: Option<u32>,
75
76    /// Timeout before a acquiring a database connection is considered
77    /// a failure
78    ///
79    /// Default: 60s
80    pub acquire_timeout: Option<u64>,
81
82    /// If a connection has been idle for this duration the connection
83    /// will be closed and released back to the database for other
84    /// consumers
85    ///
86    /// Default: 10min
87    pub idle_timeout: Option<u64>,
88
89    /// Duration in seconds idle database pools are allowed to be cached before
90    /// they are closed
91    ///
92    /// Default: 48h
93    pub cache_duration: Option<u64>,
94
95    /// Maximum database pools to maintain in the cache at once. If the
96    /// cache capacity is exceeded old pools will be closed and removed
97    /// from the cache
98    ///
99    /// This capacity should be aligned with your expected number of
100    /// tenants along with your `max_connections` to ensure your database
101    /// has enough connections to accommodate all tenants.
102    ///
103    /// Default: 50
104    pub cache_capacity: Option<u64>,
105
106    /// Duration in seconds database credentials (host, port, password, ..etc)
107    /// are allowed to be cached before they are refresh from the secrets
108    /// manager
109    ///
110    /// Default: 12h
111    pub credentials_cache_duration: Option<u64>,
112
113    /// Maximum database credentials to maintain in the cache at once. If the
114    /// cache capacity is exceeded old credentials will be removed from the cache
115    ///
116    /// Default: 50
117    pub credentials_cache_capacity: Option<u64>,
118}
119
120impl Default for DatabasePoolCacheConfig {
121    fn default() -> Self {
122        Self {
123            host: Default::default(),
124            port: 5432,
125            root_secret_name: Default::default(),
126            max_connections: None,
127            max_connections_root: None,
128            acquire_timeout: None,
129            idle_timeout: None,
130            cache_duration: None,
131            cache_capacity: None,
132            credentials_cache_duration: None,
133            credentials_cache_capacity: None,
134        }
135    }
136}
137
138#[derive(Debug, Error)]
139pub enum DatabasePoolCacheConfigError {
140    #[error("missing DOCBOX_DB_HOST environment variable")]
141    MissingDatabaseHost,
142    #[error("missing DOCBOX_DB_PORT environment variable")]
143    MissingDatabasePort,
144    #[error("invalid DOCBOX_DB_PORT environment variable")]
145    InvalidDatabasePort,
146    #[error("missing DOCBOX_DB_CREDENTIAL_NAME environment variable")]
147    MissingDatabaseSecretName,
148    #[error("invalid DOCBOX_DB_IDLE_TIMEOUT environment variable")]
149    InvalidIdleTimeout(ParseIntError),
150    #[error("invalid DOCBOX_DB_ACQUIRE_TIMEOUT environment variable")]
151    InvalidAcquireTimeout(ParseIntError),
152    #[error("invalid DOCBOX_DB_CACHE_DURATION environment variable")]
153    InvalidCacheDuration(ParseIntError),
154    #[error("invalid DOCBOX_DB_CACHE_CAPACITY environment variable")]
155    InvalidCacheCapacity(ParseIntError),
156    #[error("invalid DOCBOX_DB_CREDENTIALS_CACHE_DURATION environment variable")]
157    InvalidCredentialsCacheDuration(ParseIntError),
158    #[error("invalid DOCBOX_DB_CREDENTIALS_CACHE_CAPACITY environment variable")]
159    InvalidCredentialsCacheCapacity(ParseIntError),
160}
161
162impl DatabasePoolCacheConfig {
163    pub fn from_env() -> Result<DatabasePoolCacheConfig, DatabasePoolCacheConfigError> {
164        let db_host: String = std::env::var("DOCBOX_DB_HOST")
165            .or(std::env::var("POSTGRES_HOST"))
166            .map_err(|_| DatabasePoolCacheConfigError::MissingDatabaseHost)?;
167        let db_port: u16 = std::env::var("DOCBOX_DB_PORT")
168            .or(std::env::var("POSTGRES_PORT"))
169            .map_err(|_| DatabasePoolCacheConfigError::MissingDatabasePort)?
170            .parse()
171            .map_err(|_| DatabasePoolCacheConfigError::InvalidDatabasePort)?;
172        let db_root_secret_name = std::env::var("DOCBOX_DB_CREDENTIAL_NAME")
173            .map_err(|_| DatabasePoolCacheConfigError::MissingDatabaseSecretName)?;
174        let max_connections: Option<u32> = std::env::var("DOCBOX_DB_MAX_CONNECTIONS")
175            .ok()
176            .and_then(|value| value.parse().ok());
177        let max_connections_root: Option<u32> = std::env::var("DOCBOX_DB_MAX_ROOT_CONNECTIONS")
178            .ok()
179            .and_then(|value| value.parse().ok());
180
181        let acquire_timeout: Option<u64> = match std::env::var("DOCBOX_DB_ACQUIRE_TIMEOUT") {
182            Ok(value) => Some(
183                value
184                    .parse::<u64>()
185                    .map_err(DatabasePoolCacheConfigError::InvalidAcquireTimeout)?,
186            ),
187            Err(_) => None,
188        };
189
190        let idle_timeout: Option<u64> = match std::env::var("DOCBOX_DB_IDLE_TIMEOUT") {
191            Ok(value) => Some(
192                value
193                    .parse::<u64>()
194                    .map_err(DatabasePoolCacheConfigError::InvalidIdleTimeout)?,
195            ),
196            Err(_) => None,
197        };
198
199        let cache_duration: Option<u64> = match std::env::var("DOCBOX_DB_CACHE_DURATION") {
200            Ok(value) => Some(
201                value
202                    .parse::<u64>()
203                    .map_err(DatabasePoolCacheConfigError::InvalidCacheDuration)?,
204            ),
205            Err(_) => None,
206        };
207
208        let cache_capacity: Option<u64> = match std::env::var("DOCBOX_DB_CACHE_CAPACITY") {
209            Ok(value) => Some(
210                value
211                    .parse::<u64>()
212                    .map_err(DatabasePoolCacheConfigError::InvalidCacheCapacity)?,
213            ),
214            Err(_) => None,
215        };
216
217        let credentials_cache_duration: Option<u64> =
218            match std::env::var("DOCBOX_DB_CREDENTIALS_CACHE_DURATION") {
219                Ok(value) => Some(
220                    value
221                        .parse::<u64>()
222                        .map_err(DatabasePoolCacheConfigError::InvalidCredentialsCacheDuration)?,
223                ),
224                Err(_) => None,
225            };
226
227        let credentials_cache_capacity: Option<u64> =
228            match std::env::var("DOCBOX_DB_CREDENTIALS_CACHE_CAPACITY") {
229                Ok(value) => Some(
230                    value
231                        .parse::<u64>()
232                        .map_err(DatabasePoolCacheConfigError::InvalidCredentialsCacheCapacity)?,
233                ),
234                Err(_) => None,
235            };
236
237        Ok(DatabasePoolCacheConfig {
238            host: db_host,
239            port: db_port,
240            root_secret_name: db_root_secret_name,
241            max_connections,
242            max_connections_root,
243            acquire_timeout,
244            idle_timeout,
245            cache_duration,
246            cache_capacity,
247            credentials_cache_duration,
248            credentials_cache_capacity,
249        })
250    }
251}
252
253/// Cache for database pools
254pub struct DatabasePoolCache {
255    /// Database host
256    host: String,
257
258    /// Database port
259    port: u16,
260
261    /// Name of the secrets manager secret that contains
262    /// the credentials for the root "docbox" database
263    root_secret_name: String,
264
265    /// Cache from the database name to the pool for that database
266    cache: Cache<String, DbPool>,
267
268    /// Cache for the connection info details, stores the last known
269    /// credentials and the instant that they were obtained at
270    connect_info_cache: Cache<String, DbSecrets>,
271
272    /// Secrets manager access to load credentials
273    secrets_manager: SecretManager,
274
275    /// Max connections per tenant database pool
276    max_connections: u32,
277    /// Max connections per root database pool
278    max_connections_root: u32,
279
280    acquire_timeout: Duration,
281    idle_timeout: Duration,
282}
283
284/// Username and password for a specific database
285#[derive(Debug, Clone, Serialize, Deserialize)]
286pub struct DbSecrets {
287    pub username: String,
288    pub password: String,
289}
290
291#[derive(Debug, Error)]
292pub enum DbConnectErr {
293    #[error("database credentials not found in secrets manager")]
294    MissingCredentials,
295
296    #[error(transparent)]
297    SecretsManager(Box<SecretManagerError>),
298
299    #[error(transparent)]
300    Db(#[from] DbErr),
301
302    #[error(transparent)]
303    Shared(#[from] Arc<DbConnectErr>),
304}
305
306impl DatabasePoolCache {
307    pub fn from_config(config: DatabasePoolCacheConfig, secrets_manager: SecretManager) -> Self {
308        let cache_duration = Duration::from_secs(config.cache_duration.unwrap_or(60 * 60 * 48));
309        let credentials_cache_duration =
310            Duration::from_secs(config.credentials_cache_duration.unwrap_or(60 * 60 * 12));
311
312        let cache_capacity = config.cache_capacity.unwrap_or(50);
313        let credentials_cache_capacity = config.credentials_cache_capacity.unwrap_or(50);
314
315        let cache = Cache::builder()
316            .time_to_idle(cache_duration)
317            .max_capacity(cache_capacity)
318            .eviction_policy(EvictionPolicy::tiny_lfu())
319            .async_eviction_listener(|cache_key: Arc<String>, pool: DbPool, _cause| {
320                Box::pin(async move {
321                    tracing::debug!(?cache_key, "database pool is no longer in use, closing");
322                    pool.close().await
323                })
324            })
325            .build();
326
327        let connect_info_cache = Cache::builder()
328            .time_to_idle(credentials_cache_duration)
329            .max_capacity(credentials_cache_capacity)
330            .eviction_policy(EvictionPolicy::tiny_lfu())
331            .build();
332
333        Self {
334            host: config.host,
335            port: config.port,
336            root_secret_name: config.root_secret_name,
337            cache,
338            connect_info_cache,
339            secrets_manager,
340            max_connections: config.max_connections.unwrap_or(10),
341            max_connections_root: config.max_connections_root.unwrap_or(2),
342            idle_timeout: Duration::from_secs(config.idle_timeout.unwrap_or(60 * 10)),
343            acquire_timeout: Duration::from_secs(config.acquire_timeout.unwrap_or(60)),
344        }
345    }
346
347    /// Request a database pool for the root database
348    pub async fn get_root_pool(&self) -> Result<PgPool, DbConnectErr> {
349        self.get_pool(ROOT_DATABASE_NAME, &self.root_secret_name)
350            .await
351    }
352
353    /// Request a database pool for a specific tenant
354    pub async fn get_tenant_pool(&self, tenant: &Tenant) -> Result<DbPool, DbConnectErr> {
355        self.get_pool(&tenant.db_name, &tenant.db_secret_name).await
356    }
357
358    /// Closes the database pool for the specific tenant if one is
359    /// available and removes the pool from the cache
360    pub async fn close_tenant_pool(&self, tenant: &Tenant) {
361        let cache_key = format!("{}-{}", &tenant.db_name, &tenant.db_secret_name);
362        if let Some(pool) = self.cache.remove(&cache_key).await {
363            pool.close().await;
364        }
365
366        // Run cache async shutdown jobs
367        self.cache.run_pending_tasks().await;
368    }
369
370    /// Empties all the caches
371    pub async fn flush(&self) {
372        // Clear cache
373        self.cache.invalidate_all();
374        self.connect_info_cache.invalidate_all();
375        self.cache.run_pending_tasks().await;
376    }
377
378    /// Close all connections in the pool and invalidate the cache
379    pub async fn close_all(&self) {
380        for (_, value) in self.cache.iter() {
381            value.close().await;
382        }
383
384        self.flush().await;
385    }
386
387    /// Obtains a database pool connection to the database with the provided name
388    async fn get_pool(&self, db_name: &str, secret_name: &str) -> Result<DbPool, DbConnectErr> {
389        let cache_key = format!("{db_name}-{secret_name}");
390
391        let pool = self
392            .cache
393            .try_get_with(cache_key, async {
394                tracing::debug!(?db_name, "acquiring database pool");
395
396                let pool = self
397                    .create_pool(db_name, secret_name)
398                    .await
399                    .map_err(Arc::new)?;
400
401                Ok(pool)
402            })
403            .await?;
404
405        Ok(pool)
406    }
407
408    /// Obtains database connection info
409    async fn get_credentials(&self, secret_name: &str) -> Result<DbSecrets, DbConnectErr> {
410        if let Some(connect_info) = self.connect_info_cache.get(secret_name).await {
411            return Ok(connect_info);
412        }
413
414        // Load new credentials
415        let credentials = self
416            .secrets_manager
417            .parsed_secret::<DbSecrets>(secret_name)
418            .await
419            .map_err(|err| DbConnectErr::SecretsManager(Box::new(err)))?
420            .ok_or(DbConnectErr::MissingCredentials)?;
421
422        // Cache the credential
423        self.connect_info_cache
424            .insert(secret_name.to_string(), credentials.clone())
425            .await;
426
427        Ok(credentials)
428    }
429
430    /// Creates a database pool connection
431    async fn create_pool(&self, db_name: &str, secret_name: &str) -> Result<DbPool, DbConnectErr> {
432        tracing::debug!(?db_name, ?secret_name, "creating db pool connection");
433
434        let credentials = self.get_credentials(secret_name).await?;
435        let options = PgConnectOptions::new()
436            .host(&self.host)
437            .port(self.port)
438            .username(&credentials.username)
439            .password(&credentials.password)
440            .database(db_name);
441
442        let max_connections = match db_name {
443            ROOT_DATABASE_NAME => self.max_connections_root,
444            _ => self.max_connections,
445        };
446
447        match PgPoolOptions::new()
448            .max_connections(max_connections)
449            // Slightly larger acquire timeout for times when lots of files are being processed
450            .acquire_timeout(self.acquire_timeout)
451            // Close any connections that have been idle for more than 30min
452            .idle_timeout(self.idle_timeout)
453            .connect_with(options)
454            .await
455        {
456            // Success case
457            Ok(value) => Ok(value),
458            Err(err) => {
459                // Drop the connect info cache in case the credentials were wrong
460                self.connect_info_cache.remove(secret_name).await;
461                Err(DbConnectErr::Db(err))
462            }
463        }
464    }
465}