Skip to main content

docbox_database/
pool.rs

1//! # Database Pool
2//!
3//! This is the docbox solution for managing multiple database connections
4//! and connection pools for each tenant and the root database itself.
5//!
6//! Pools are held in a cache with an expiry time to ensure they don't
7//! hog too many database connections.
8//!
//! Database pools and credentials are stored in Tiny LFU caches; these caches
//! can be flushed using [`DatabasePoolCache::flush`].
11//!
12//! ## Environment Variables
13//!
14//! * `DOCBOX_DB_HOST` - Database host
15//! * `DOCBOX_DB_PORT` - Database port
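//! * `DOCBOX_DB_CREDENTIAL_NAME` - Secrets manager name for the root database secret (required unless `DOCBOX_DB_ROOT_IAM` is `true`)
//! * `DOCBOX_DB_ROOT_IAM` - Whether to connect to the root database using IAM authentication instead of a secret (`true` / `false`)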
17//! * `DOCBOX_DB_MAX_CONNECTIONS` - Max connections each tenant pool can contain
18//! * `DOCBOX_DB_MAX_ROOT_CONNECTIONS` - Max connections the root "docbox" pool can contain
19//! * `DOCBOX_DB_ACQUIRE_TIMEOUT` - Timeout before acquiring a connection fails
//! * `DOCBOX_DB_IDLE_TIMEOUT` - Timeout before an idle connection is closed to save resources
21//! * `DOCBOX_DB_CACHE_DURATION` - Duration idle pools should be maintained for before closing
22//! * `DOCBOX_DB_CACHE_CAPACITY` - Maximum database pools to hold at once
23//! * `DOCBOX_DB_CREDENTIALS_CACHE_DURATION` - Duration database credentials should be cached for
24//! * `DOCBOX_DB_CREDENTIALS_CACHE_CAPACITY` - Maximum database credentials to cache
25
26use crate::{DbErr, DbPool, ROOT_DATABASE_NAME, ROOT_DATABASE_ROLE_NAME, models::tenant::Tenant};
27use aws_credential_types::provider::{ProvideCredentials, error::CredentialsError};
28use aws_sigv4::{
29    http_request::{SignableBody, SignableRequest, SigningError, SigningSettings, sign},
30    sign::v4::signing_params,
31};
32use docbox_secrets::{SecretManager, SecretManagerError};
33use moka::{future::Cache, policy::EvictionPolicy};
34use serde::{Deserialize, Serialize};
35use sqlx::{
36    PgPool,
37    postgres::{PgConnectOptions, PgPoolOptions},
38};
39use std::time::Duration;
40use std::{num::ParseIntError, str::ParseBoolError};
41use std::{sync::Arc, time::SystemTime};
42use thiserror::Error;
43
/// Config for the database pool cache
///
/// Every `Option` field falls back to the documented default when it is
/// `None`; the defaults are applied in [`DatabasePoolCache::from_config`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabasePoolCacheConfig {
    /// Database host
    pub host: String,
    /// Database port
    pub port: u16,

    /// Name of the secrets manager secret to use when connecting to
    /// the root "docbox" database if using secret based authentication
    pub root_secret_name: Option<String>,

    /// Whether to use IAM authentication to connect to the
    /// root database instead of secrets
    ///
    /// Defaults to `false` when absent from the serialized config
    #[serde(default)]
    pub root_iam: bool,

    /// Max number of active connections per tenant database pool
    ///
    /// This is the maximum number of connections that should be allocated
    /// for performing all queries against each specific tenant.
    ///
    /// Ensure a reasonable amount of connections are allocated but make
    /// sure that the `max_connections` * your number of tenants stays
    /// within the limits for your database
    ///
    /// Default: 10
    pub max_connections: Option<u32>,

    /// Max number of active connections per "docbox" database pool
    ///
    /// This is the maximum number of connections that should be allocated
    /// for performing queries like:
    /// - Listing tenants
    /// - Getting tenant details
    ///
    /// These pools are often short lived and complete their queries very fast
    /// and thus don't need a huge amount of resources allocated to them
    ///
    /// Default: 2
    pub max_connections_root: Option<u32>,

    /// Timeout in seconds before acquiring a database connection is
    /// considered a failure
    ///
    /// Default: 60s
    pub acquire_timeout: Option<u64>,

    /// If a connection has been idle for this duration (in seconds) the
    /// connection will be closed and released back to the database for
    /// other consumers
    ///
    /// Default: 10min
    pub idle_timeout: Option<u64>,

    /// Duration in seconds idle database pools are allowed to be cached before
    /// they are closed
    ///
    /// Default: 48h
    pub cache_duration: Option<u64>,

    /// Maximum database pools to maintain in the cache at once. If the
    /// cache capacity is exceeded old pools will be closed and removed
    /// from the cache
    ///
    /// This capacity should be aligned with your expected number of
    /// tenants along with your `max_connections` to ensure your database
    /// has enough connections to accommodate all tenants.
    ///
    /// Default: 50
    pub cache_capacity: Option<u64>,

    /// Duration in seconds database credentials (username, password)
    /// are allowed to sit idle in the cache before they are refreshed
    /// from the secrets manager
    ///
    /// Default: 12h
    pub credentials_cache_duration: Option<u64>,

    /// Maximum database credentials to maintain in the cache at once. If the
    /// cache capacity is exceeded old credentials will be removed from the cache
    ///
    /// Default: 50
    pub credentials_cache_capacity: Option<u64>,
}
129
130impl Default for DatabasePoolCacheConfig {
131    fn default() -> Self {
132        Self {
133            host: Default::default(),
134            port: 5432,
135            root_secret_name: Default::default(),
136            root_iam: false,
137            max_connections: None,
138            max_connections_root: None,
139            acquire_timeout: None,
140            idle_timeout: None,
141            cache_duration: None,
142            cache_capacity: None,
143            credentials_cache_duration: None,
144            credentials_cache_capacity: None,
145        }
146    }
147}
148
149#[derive(Debug, Error)]
150pub enum DatabasePoolCacheConfigError {
151    #[error("missing DOCBOX_DB_HOST environment variable")]
152    MissingDatabaseHost,
153    #[error("missing DOCBOX_DB_PORT environment variable")]
154    MissingDatabasePort,
155    #[error("invalid DOCBOX_DB_PORT environment variable")]
156    InvalidDatabasePort,
157    #[error("missing DOCBOX_DB_CREDENTIAL_NAME environment variable")]
158    MissingDatabaseSecretName,
159    #[error("invalid DOCBOX_DB_IDLE_TIMEOUT environment variable")]
160    InvalidIdleTimeout(ParseIntError),
161    #[error("invalid DOCBOX_DB_ACQUIRE_TIMEOUT environment variable")]
162    InvalidAcquireTimeout(ParseIntError),
163    #[error("invalid DOCBOX_DB_CACHE_DURATION environment variable")]
164    InvalidCacheDuration(ParseIntError),
165    #[error("invalid DOCBOX_DB_CACHE_CAPACITY environment variable")]
166    InvalidCacheCapacity(ParseIntError),
167    #[error("invalid DOCBOX_DB_CREDENTIALS_CACHE_DURATION environment variable")]
168    InvalidCredentialsCacheDuration(ParseIntError),
169    #[error("invalid DOCBOX_DB_CREDENTIALS_CACHE_CAPACITY environment variable")]
170    InvalidCredentialsCacheCapacity(ParseIntError),
171    #[error("invalid DOCBOX_DB_ROOT_IAM environment variable")]
172    InvalidRootIam(ParseBoolError),
173}
174
175impl DatabasePoolCacheConfig {
176    pub fn from_env() -> Result<DatabasePoolCacheConfig, DatabasePoolCacheConfigError> {
177        let db_host: String = std::env::var("DOCBOX_DB_HOST")
178            .or(std::env::var("POSTGRES_HOST"))
179            .map_err(|_| DatabasePoolCacheConfigError::MissingDatabaseHost)?;
180        let db_port: u16 = std::env::var("DOCBOX_DB_PORT")
181            .or(std::env::var("POSTGRES_PORT"))
182            .map_err(|_| DatabasePoolCacheConfigError::MissingDatabasePort)?
183            .parse()
184            .map_err(|_| DatabasePoolCacheConfigError::InvalidDatabasePort)?;
185
186        let db_root_secret_name = std::env::var("DOCBOX_DB_CREDENTIAL_NAME").ok();
187        let db_root_iam = std::env::var("DOCBOX_DB_ROOT_IAM")
188            .ok()
189            .map(|value| value.parse::<bool>())
190            .transpose()
191            .map_err(DatabasePoolCacheConfigError::InvalidRootIam)?
192            .unwrap_or_default();
193
194        // Root secret name is required when not using IAM
195        if !db_root_iam && db_root_secret_name.is_none() {
196            return Err(DatabasePoolCacheConfigError::MissingDatabaseSecretName);
197        }
198
199        let max_connections: Option<u32> = std::env::var("DOCBOX_DB_MAX_CONNECTIONS")
200            .ok()
201            .and_then(|value| value.parse().ok());
202        let max_connections_root: Option<u32> = std::env::var("DOCBOX_DB_MAX_ROOT_CONNECTIONS")
203            .ok()
204            .and_then(|value| value.parse().ok());
205
206        let acquire_timeout: Option<u64> = match std::env::var("DOCBOX_DB_ACQUIRE_TIMEOUT") {
207            Ok(value) => Some(
208                value
209                    .parse::<u64>()
210                    .map_err(DatabasePoolCacheConfigError::InvalidAcquireTimeout)?,
211            ),
212            Err(_) => None,
213        };
214
215        let idle_timeout: Option<u64> = match std::env::var("DOCBOX_DB_IDLE_TIMEOUT") {
216            Ok(value) => Some(
217                value
218                    .parse::<u64>()
219                    .map_err(DatabasePoolCacheConfigError::InvalidIdleTimeout)?,
220            ),
221            Err(_) => None,
222        };
223
224        let cache_duration: Option<u64> = match std::env::var("DOCBOX_DB_CACHE_DURATION") {
225            Ok(value) => Some(
226                value
227                    .parse::<u64>()
228                    .map_err(DatabasePoolCacheConfigError::InvalidCacheDuration)?,
229            ),
230            Err(_) => None,
231        };
232
233        let cache_capacity: Option<u64> = match std::env::var("DOCBOX_DB_CACHE_CAPACITY") {
234            Ok(value) => Some(
235                value
236                    .parse::<u64>()
237                    .map_err(DatabasePoolCacheConfigError::InvalidCacheCapacity)?,
238            ),
239            Err(_) => None,
240        };
241
242        let credentials_cache_duration: Option<u64> =
243            match std::env::var("DOCBOX_DB_CREDENTIALS_CACHE_DURATION") {
244                Ok(value) => Some(
245                    value
246                        .parse::<u64>()
247                        .map_err(DatabasePoolCacheConfigError::InvalidCredentialsCacheDuration)?,
248                ),
249                Err(_) => None,
250            };
251
252        let credentials_cache_capacity: Option<u64> =
253            match std::env::var("DOCBOX_DB_CREDENTIALS_CACHE_CAPACITY") {
254                Ok(value) => Some(
255                    value
256                        .parse::<u64>()
257                        .map_err(DatabasePoolCacheConfigError::InvalidCredentialsCacheCapacity)?,
258                ),
259                Err(_) => None,
260            };
261
262        Ok(DatabasePoolCacheConfig {
263            host: db_host,
264            port: db_port,
265            root_iam: db_root_iam,
266            root_secret_name: db_root_secret_name,
267            max_connections,
268            max_connections_root,
269            acquire_timeout,
270            idle_timeout,
271            cache_duration,
272            cache_capacity,
273            credentials_cache_duration,
274            credentials_cache_capacity,
275        })
276    }
277}
278
/// Cache for database pools
///
/// Holds one connection pool per database / credential pair along with
/// the credentials loaded from the secrets manager, so that repeated
/// requests reuse existing pools instead of reconnecting.
pub struct DatabasePoolCache {
    /// AWS config, used to obtain the credentials provider and region
    /// when signing IAM authentication tokens
    aws_config: aws_config::SdkConfig,

    /// Database host
    host: String,

    /// Database port
    port: u16,

    /// Name of the secrets manager secret that contains
    /// the credentials for the root "docbox" database
    ///
    /// Only present if using secrets based authentication
    root_secret_name: Option<String>,

    /// Whether to use IAM authentication to connect to the
    /// root database instead of secrets
    root_iam: bool,

    /// Cache of database pools. Keys are `secret-{db_name}-{secret_name}`
    /// for secret based pools and `user-{db_name}-{role_name}` for IAM
    /// based pools
    cache: Cache<String, DbPool>,

    /// Cache from the secret name to the last credentials loaded
    /// from the secrets manager for that secret
    connect_info_cache: Cache<String, DbSecrets>,

    /// Secrets manager access to load credentials
    secrets_manager: SecretManager,

    /// Max connections per tenant database pool
    max_connections: u32,
    /// Max connections per root database pool
    max_connections_root: u32,

    /// Acquire timeout applied to every pool created by this cache
    acquire_timeout: Duration,
    /// Idle connection timeout applied to every pool created by this cache
    idle_timeout: Duration,
}
318
319/// Username and password for a specific database
320#[derive(Debug, Clone, Serialize, Deserialize)]
321pub struct DbSecrets {
322    pub username: String,
323    pub password: String,
324}
325
326#[derive(Debug, Error)]
327pub enum DbConnectErr {
328    #[error("database credentials not found in secrets manager")]
329    MissingCredentials,
330
331    #[error(transparent)]
332    SecretsManager(Box<SecretManagerError>),
333
334    #[error(transparent)]
335    Db(#[from] DbErr),
336
337    #[error(transparent)]
338    Shared(#[from] Arc<DbConnectErr>),
339
340    #[error("missing aws credentials provider")]
341    MissingCredentialsProvider,
342
343    #[error("failed to provide aws credentials")]
344    AwsCredentials(#[from] CredentialsError),
345
346    #[error("aws configuration missing region")]
347    MissingRegion,
348
349    #[error("failed to build aws signature")]
350    AwsSigner(#[from] signing_params::BuildError),
351
352    #[error("failed to sign aws request")]
353    AwsRequestSign(#[from] SigningError),
354
355    #[error("failed to parse signed aws url")]
356    AwsSignerInvalidUrl(url::ParseError),
357
358    #[error("failed to connect to tenant missing both IAM and secrets fields")]
359    InvalidTenantConfiguration,
360}
361
362impl DatabasePoolCache {
363    pub fn from_config(
364        aws_config: aws_config::SdkConfig,
365        config: DatabasePoolCacheConfig,
366        secrets_manager: SecretManager,
367    ) -> Self {
368        let cache_duration = Duration::from_secs(config.cache_duration.unwrap_or(60 * 60 * 48));
369        let credentials_cache_duration =
370            Duration::from_secs(config.credentials_cache_duration.unwrap_or(60 * 60 * 12));
371
372        let cache_capacity = config.cache_capacity.unwrap_or(50);
373        let credentials_cache_capacity = config.credentials_cache_capacity.unwrap_or(50);
374
375        let cache = Cache::builder()
376            .time_to_idle(cache_duration)
377            .max_capacity(cache_capacity)
378            .eviction_policy(EvictionPolicy::tiny_lfu())
379            .async_eviction_listener(|cache_key: Arc<String>, pool: DbPool, _cause| {
380                Box::pin(async move {
381                    tracing::debug!(?cache_key, "database pool is no longer in use, closing");
382                    pool.close().await
383                })
384            })
385            .build();
386
387        let connect_info_cache = Cache::builder()
388            .time_to_idle(credentials_cache_duration)
389            .max_capacity(credentials_cache_capacity)
390            .eviction_policy(EvictionPolicy::tiny_lfu())
391            .build();
392
393        Self {
394            aws_config,
395            host: config.host,
396            port: config.port,
397            root_secret_name: config.root_secret_name,
398            root_iam: config.root_iam,
399            cache,
400            connect_info_cache,
401            secrets_manager,
402            max_connections: config.max_connections.unwrap_or(10),
403            max_connections_root: config.max_connections_root.unwrap_or(2),
404            idle_timeout: Duration::from_secs(config.idle_timeout.unwrap_or(60 * 10)),
405            acquire_timeout: Duration::from_secs(config.acquire_timeout.unwrap_or(60)),
406        }
407    }
408
409    /// Request a database pool for the root database
410    pub async fn get_root_pool(&self) -> Result<PgPool, DbConnectErr> {
411        match (self.root_secret_name.as_ref(), self.root_iam) {
412            (_, true) => {
413                self.get_pool_iam(ROOT_DATABASE_NAME, ROOT_DATABASE_ROLE_NAME)
414                    .await
415            }
416
417            (Some(db_secret_name), _) => self.get_pool(ROOT_DATABASE_NAME, db_secret_name).await,
418
419            _ => Err(DbConnectErr::InvalidTenantConfiguration),
420        }
421    }
422
423    /// Request a database pool for a specific tenant
424    pub async fn get_tenant_pool(&self, tenant: &Tenant) -> Result<DbPool, DbConnectErr> {
425        match (
426            tenant.db_iam_user_name.as_ref(),
427            tenant.db_secret_name.as_ref(),
428        ) {
429            (Some(db_iam_user_name), _) => {
430                self.get_pool_iam(&tenant.db_name, db_iam_user_name).await
431            }
432            (_, Some(db_secret_name)) => self.get_pool(&tenant.db_name, db_secret_name).await,
433
434            _ => Err(DbConnectErr::InvalidTenantConfiguration),
435        }
436    }
437
438    /// Closes the database pool for the specific tenant if one is
439    /// available and removes the pool from the cache
440    pub async fn close_tenant_pool(&self, tenant: &Tenant) {
441        let cache_key = Self::tenant_cache_key(tenant);
442        if let Some(pool) = self.cache.remove(&cache_key).await {
443            pool.close().await;
444        }
445
446        // Run cache async shutdown jobs
447        self.cache.run_pending_tasks().await;
448    }
449
450    /// Compute the pool cache key for a tenant based on the specific
451    /// authentication methods for that tenant
452    fn tenant_cache_key(tenant: &Tenant) -> String {
453        match (
454            tenant.db_secret_name.as_ref(),
455            tenant.db_iam_user_name.as_ref(),
456        ) {
457            (Some(db_secret_name), _) => {
458                format!("secret-{}-{}", &tenant.db_name, db_secret_name)
459            }
460            (_, Some(db_iam_user_name)) => {
461                format!("user-{}-{}", &tenant.db_name, db_iam_user_name)
462            }
463
464            _ => format!("db-{}", &tenant.db_name),
465        }
466    }
467
468    /// Empties all the caches
469    pub async fn flush(&self) {
470        // Clear cache
471        self.cache.invalidate_all();
472        self.connect_info_cache.invalidate_all();
473        self.cache.run_pending_tasks().await;
474    }
475
476    /// Close all connections in the pool and invalidate the cache
477    pub async fn close_all(&self) {
478        for (_, value) in self.cache.iter() {
479            value.close().await;
480        }
481
482        self.flush().await;
483    }
484
485    /// Obtains a database pool connection to the database with the provided name
486    /// using secrets manager based credentials
487    async fn get_pool(&self, db_name: &str, secret_name: &str) -> Result<DbPool, DbConnectErr> {
488        let cache_key = format!("secret-{db_name}-{secret_name}");
489
490        let pool = self
491            .cache
492            .try_get_with(cache_key, async {
493                tracing::debug!(?db_name, "acquiring database pool");
494
495                let pool = self
496                    .create_pool(db_name, secret_name)
497                    .await
498                    .map_err(Arc::new)?;
499
500                Ok(pool)
501            })
502            .await?;
503
504        Ok(pool)
505    }
506
507    /// Obtains a database pool connection to the database with the provided name
508    /// using IAM based credentials
509    async fn get_pool_iam(
510        &self,
511        db_name: &str,
512        db_role_name: &str,
513    ) -> Result<DbPool, DbConnectErr> {
514        let cache_key = format!("user-{db_name}-{db_role_name}");
515
516        let pool = self
517            .cache
518            .try_get_with(cache_key, async {
519                tracing::debug!(?db_name, "acquiring database pool (iam)");
520
521                let pool = self
522                    .create_pool_iam(db_name, db_role_name)
523                    .await
524                    .map_err(Arc::new)?;
525
526                Ok(pool)
527            })
528            .await?;
529
530        Ok(pool)
531    }
532
533    /// Obtains database connection info
534    async fn get_credentials(&self, secret_name: &str) -> Result<DbSecrets, DbConnectErr> {
535        if let Some(connect_info) = self.connect_info_cache.get(secret_name).await {
536            return Ok(connect_info);
537        }
538
539        // Load new credentials
540        let credentials = self
541            .secrets_manager
542            .parsed_secret::<DbSecrets>(secret_name)
543            .await
544            .map_err(|err| DbConnectErr::SecretsManager(Box::new(err)))?
545            .ok_or(DbConnectErr::MissingCredentials)?;
546
547        // Cache the credential
548        self.connect_info_cache
549            .insert(secret_name.to_string(), credentials.clone())
550            .await;
551
552        Ok(credentials)
553    }
554
    /// Creates a signed authentication token for connecting to the database
    /// at `host`:`port` as `user` using IAM authentication.
    ///
    /// The token is a SigV4 query-string signed `connect` request for the
    /// `rds-db` service with the `https://` scheme removed.
    ///
    /// # Errors
    ///
    /// Fails when the AWS config has no credentials provider or region,
    /// when credentials cannot be provided, or when building the signing
    /// parameters, signing the request or parsing the URL fails.
    async fn create_rds_signed_token(
        &self,
        host: &str,
        port: u16,
        user: &str,
    ) -> Result<String, DbConnectErr> {
        // Resolve the identity and region the request is signed with
        let credentials_provider = self
            .aws_config
            .credentials_provider()
            .ok_or(DbConnectErr::MissingCredentialsProvider)?;
        let credentials = credentials_provider.provide_credentials().await?;
        let identity = credentials.into();
        let region = self
            .aws_config
            .region()
            .ok_or(DbConnectErr::MissingRegion)?;

        // The signature is carried in the query string and is valid for 15 minutes
        let mut signing_settings = SigningSettings::default();
        signing_settings.expires_in = Some(Duration::from_secs(60 * 15));
        signing_settings.signature_location =
            aws_sigv4::http_request::SignatureLocation::QueryParams;

        let signing_params = aws_sigv4::sign::v4::SigningParams::builder()
            .identity(&identity)
            .region(region.as_ref())
            .name("rds-db")
            .time(SystemTime::now())
            .settings(signing_settings)
            .build()?;

        // NOTE(review): `user` is interpolated without URL encoding — confirm
        // role names can never contain characters that need escaping
        let url = format!("https://{host}:{port}/?Action=connect&DBUser={user}");

        let signable_request =
            SignableRequest::new("GET", &url, std::iter::empty(), SignableBody::Bytes(&[]))?;

        let (signing_instructions, _signature) =
            sign(signable_request, &signing_params.into())?.into_parts();

        // Append the signature query parameters to the request URL
        let mut url = url::Url::parse(&url).map_err(DbConnectErr::AwsSignerInvalidUrl)?;
        for (name, value) in signing_instructions.params() {
            url.query_pairs_mut().append_pair(name, value);
        }

        // The token is the signed URL without its scheme
        let response = url.to_string().split_off("https://".len());
        Ok(response)
    }
601
602    /// Creates a database pool connection using IAM based authentication
603    async fn create_pool_iam(
604        &self,
605        db_name: &str,
606        db_role_name: &str,
607    ) -> Result<DbPool, DbConnectErr> {
608        tracing::debug!(?db_name, ?db_role_name, "creating db pool connection");
609
610        let token = self
611            .create_rds_signed_token(&self.host, self.port, db_role_name)
612            .await?;
613
614        let options = PgConnectOptions::new()
615            .host(&self.host)
616            .port(self.port)
617            .username(db_role_name)
618            .password(&token)
619            .database(db_name);
620
621        let max_connections = match db_name {
622            ROOT_DATABASE_NAME => self.max_connections_root,
623            _ => self.max_connections,
624        };
625
626        PgPoolOptions::new()
627            .max_connections(max_connections)
628            // Slightly larger acquire timeout for times when lots of files are being processed
629            .acquire_timeout(self.acquire_timeout)
630            // Close any connections that have been idle for more than 30min
631            .idle_timeout(self.idle_timeout)
632            .connect_with(options)
633            .await
634            .map_err(DbConnectErr::Db)
635    }
636
637    /// Creates a database pool connection
638    async fn create_pool(&self, db_name: &str, secret_name: &str) -> Result<DbPool, DbConnectErr> {
639        tracing::debug!(?db_name, ?secret_name, "creating db pool connection");
640
641        let credentials = self.get_credentials(secret_name).await?;
642        let options = PgConnectOptions::new()
643            .host(&self.host)
644            .port(self.port)
645            .username(&credentials.username)
646            .password(&credentials.password)
647            .database(db_name);
648
649        let max_connections = match db_name {
650            ROOT_DATABASE_NAME => self.max_connections_root,
651            _ => self.max_connections,
652        };
653
654        match PgPoolOptions::new()
655            .max_connections(max_connections)
656            // Slightly larger acquire timeout for times when lots of files are being processed
657            .acquire_timeout(self.acquire_timeout)
658            // Close any connections that have been idle for more than 30min
659            .idle_timeout(self.idle_timeout)
660            .connect_with(options)
661            .await
662        {
663            // Success case
664            Ok(value) => Ok(value),
665            Err(err) => {
666                // Drop the connect info cache in case the credentials were wrong
667                self.connect_info_cache.remove(secret_name).await;
668                Err(DbConnectErr::Db(err))
669            }
670        }
671    }
672}