Skip to main content

docbox_database/
pool.rs

//! # Database Pool
//!
//! This is the docbox solution for managing multiple database connections
//! and connection pools for each tenant and the root database itself.
//!
//! Pools are held in a cache with an expiry time to ensure they don't
//! hog too many database connections.
//!
//! Database pools and credentials are stored in Tiny LFU caches; these caches
//! can be flushed using [DatabasePoolCache::flush].
//!
//! ## Environment Variables
//!
//! * `DOCBOX_DB_HOST` - Database host (falls back to `POSTGRES_HOST`)
//! * `DOCBOX_DB_PORT` - Database port (falls back to `POSTGRES_PORT`)
//! * `DOCBOX_DB_CREDENTIAL_NAME` - Secrets manager name for the root database secret (required unless `DOCBOX_DB_ROOT_IAM` is `true`)
//! * `DOCBOX_DB_ROOT_IAM` - Whether to use IAM authentication for the root "docbox" database (`true`/`false`, default `false`)
//! * `DOCBOX_DB_MAX_CONNECTIONS` - Max connections each tenant pool can contain
//! * `DOCBOX_DB_MAX_ROOT_CONNECTIONS` - Max connections the root "docbox" pool can contain
//! * `DOCBOX_DB_ACQUIRE_TIMEOUT` - Timeout in seconds before acquiring a connection fails
//! * `DOCBOX_DB_POOL_TIMEOUT` - Maximum time in seconds a pool can live in the cache for
//! * `DOCBOX_DB_IDLE_TIMEOUT` - Timeout in seconds before an idle connection is closed to save resources
//! * `DOCBOX_DB_CACHE_DURATION` - Duration in seconds pools can remain in the cache untouched before they are closed and removed
//! * `DOCBOX_DB_CACHE_CAPACITY` - Maximum database pools to hold at once
//! * `DOCBOX_DB_CREDENTIALS_CACHE_DURATION` - Duration in seconds database credentials should be cached for
//! * `DOCBOX_DB_CREDENTIALS_CACHE_CAPACITY` - Maximum database credentials to cache
26
27use crate::{DbErr, DbPool, ROOT_DATABASE_NAME, ROOT_DATABASE_ROLE_NAME, models::tenant::Tenant};
28use aws_config::SdkConfig;
29use aws_credential_types::provider::{ProvideCredentials, error::CredentialsError};
30use aws_sigv4::{
31    http_request::{SignableBody, SignableRequest, SigningError, SigningSettings, sign},
32    sign::v4::signing_params,
33};
34use docbox_secrets::{SecretManager, SecretManagerError};
35use moka::{future::Cache, policy::EvictionPolicy};
36use serde::{Deserialize, Serialize};
37use sqlx::{
38    PgPool,
39    postgres::{PgConnectOptions, PgPoolOptions},
40};
41use std::time::Duration;
42use std::{num::ParseIntError, str::ParseBoolError};
43use std::{sync::Arc, time::SystemTime};
44use thiserror::Error;
45use tokio::time::sleep;
46
47///  Config for the database pool
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct DatabasePoolCacheConfig {
50    /// Database host
51    pub host: String,
52    /// Database port
53    pub port: u16,
54
55    /// Name of the secrets manager secret to use when connecting to
56    /// the root "docbox" database if using secret based authentication
57    pub root_secret_name: Option<String>,
58
59    /// Whether to use IAM authentication to connect to the
60    /// root database instead of secrets
61    #[serde(default)]
62    pub root_iam: bool,
63
64    /// Max number of active connections per tenant database pool
65    ///
66    /// This is the maximum number of connections that should be allocated
67    /// for performing all queries against each specific tenant.
68    ///
69    /// Ensure a reasonable amount of connections are allocated but make
70    /// sure that the `max_connections` * your number of tenants stays
71    /// within the limits for your database
72    ///
73    /// Default: 10
74    pub max_connections: Option<u32>,
75
76    /// Max number of active connections per "docbox" database pool
77    ///
78    /// This is the maximum number of connections that should be allocated
79    /// for performing queries like:
80    /// - Listing tenants
81    /// - Getting tenant details
82    ///
83    /// These pools are often short lived and complete their queries very fast
84    /// and thus don't need a huge amount of resources allocated to them
85    ///
86    /// Default: 2
87    pub max_connections_root: Option<u32>,
88
89    /// Timeout before a acquiring a database connection is considered
90    /// a failure
91    ///
92    /// Default: 60s
93    pub acquire_timeout: Option<u64>,
94
95    /// If a connection has been idle for this duration the connection
96    /// will be closed and released back to the database for other
97    /// consumers
98    ///
99    /// Default: 10min
100    pub idle_timeout: Option<u64>,
101
102    /// Maximum time pool are allowed to stay within the database
103    /// cache before they are automatically removed
104    ///
105    /// Default: 48h
106    pub pool_timeout: Option<u64>,
107
108    /// Duration in seconds idle database pools are allowed to be cached before
109    /// they are closed
110    ///
111    /// Default: 48h
112    pub cache_duration: Option<u64>,
113
114    /// Maximum database pools to maintain in the cache at once. If the
115    /// cache capacity is exceeded old pools will be closed and removed
116    /// from the cache
117    ///
118    /// This capacity should be aligned with your expected number of
119    /// tenants along with your `max_connections` to ensure your database
120    /// has enough connections to accommodate all tenants.
121    ///
122    /// Default: 50
123    pub cache_capacity: Option<u64>,
124
125    /// Duration in seconds database credentials (host, port, password, ..etc)
126    /// are allowed to be cached before they are refresh from the secrets
127    /// manager
128    ///
129    /// Default: 12h
130    pub credentials_cache_duration: Option<u64>,
131
132    /// Maximum database credentials to maintain in the cache at once. If the
133    /// cache capacity is exceeded old credentials will be removed from the cache
134    ///
135    /// Default: 50
136    pub credentials_cache_capacity: Option<u64>,
137}
138
139impl Default for DatabasePoolCacheConfig {
140    fn default() -> Self {
141        Self {
142            host: Default::default(),
143            port: 5432,
144            root_secret_name: Default::default(),
145            root_iam: false,
146            max_connections: None,
147            max_connections_root: None,
148            acquire_timeout: None,
149            idle_timeout: None,
150            pool_timeout: None,
151            cache_duration: None,
152            cache_capacity: None,
153            credentials_cache_duration: None,
154            credentials_cache_capacity: None,
155        }
156    }
157}
158
159#[derive(Debug, Error)]
160pub enum DatabasePoolCacheConfigError {
161    #[error("missing DOCBOX_DB_HOST environment variable")]
162    MissingDatabaseHost,
163    #[error("missing DOCBOX_DB_PORT environment variable")]
164    MissingDatabasePort,
165    #[error("invalid DOCBOX_DB_PORT environment variable")]
166    InvalidDatabasePort,
167    #[error("missing DOCBOX_DB_CREDENTIAL_NAME environment variable")]
168    MissingDatabaseSecretName,
169    #[error("invalid DOCBOX_DB_POOL_TIMEOUT environment variable")]
170    InvalidPoolTimeout(ParseIntError),
171    #[error("invalid DOCBOX_DB_IDLE_TIMEOUT environment variable")]
172    InvalidIdleTimeout(ParseIntError),
173    #[error("invalid DOCBOX_DB_ACQUIRE_TIMEOUT environment variable")]
174    InvalidAcquireTimeout(ParseIntError),
175    #[error("invalid DOCBOX_DB_CACHE_DURATION environment variable")]
176    InvalidCacheDuration(ParseIntError),
177    #[error("invalid DOCBOX_DB_CACHE_CAPACITY environment variable")]
178    InvalidCacheCapacity(ParseIntError),
179    #[error("invalid DOCBOX_DB_CREDENTIALS_CACHE_DURATION environment variable")]
180    InvalidCredentialsCacheDuration(ParseIntError),
181    #[error("invalid DOCBOX_DB_CREDENTIALS_CACHE_CAPACITY environment variable")]
182    InvalidCredentialsCacheCapacity(ParseIntError),
183    #[error("invalid DOCBOX_DB_ROOT_IAM environment variable")]
184    InvalidRootIam(ParseBoolError),
185}
186
187impl DatabasePoolCacheConfig {
188    pub fn from_env() -> Result<DatabasePoolCacheConfig, DatabasePoolCacheConfigError> {
189        let db_host: String = std::env::var("DOCBOX_DB_HOST")
190            .or(std::env::var("POSTGRES_HOST"))
191            .map_err(|_| DatabasePoolCacheConfigError::MissingDatabaseHost)?;
192        let db_port: u16 = std::env::var("DOCBOX_DB_PORT")
193            .or(std::env::var("POSTGRES_PORT"))
194            .map_err(|_| DatabasePoolCacheConfigError::MissingDatabasePort)?
195            .parse()
196            .map_err(|_| DatabasePoolCacheConfigError::InvalidDatabasePort)?;
197
198        let db_root_secret_name = std::env::var("DOCBOX_DB_CREDENTIAL_NAME").ok();
199        let db_root_iam = std::env::var("DOCBOX_DB_ROOT_IAM")
200            .ok()
201            .map(|value| value.parse::<bool>())
202            .transpose()
203            .map_err(DatabasePoolCacheConfigError::InvalidRootIam)?
204            .unwrap_or_default();
205
206        // Root secret name is required when not using IAM
207        if !db_root_iam && db_root_secret_name.is_none() {
208            return Err(DatabasePoolCacheConfigError::MissingDatabaseSecretName);
209        }
210
211        let max_connections: Option<u32> = std::env::var("DOCBOX_DB_MAX_CONNECTIONS")
212            .ok()
213            .and_then(|value| value.parse().ok());
214        let max_connections_root: Option<u32> = std::env::var("DOCBOX_DB_MAX_ROOT_CONNECTIONS")
215            .ok()
216            .and_then(|value| value.parse().ok());
217
218        let acquire_timeout: Option<u64> = match std::env::var("DOCBOX_DB_ACQUIRE_TIMEOUT") {
219            Ok(value) => Some(
220                value
221                    .parse::<u64>()
222                    .map_err(DatabasePoolCacheConfigError::InvalidAcquireTimeout)?,
223            ),
224            Err(_) => None,
225        };
226
227        let pool_timeout: Option<u64> = match std::env::var("DOCBOX_DB_POOL_TIMEOUT") {
228            Ok(value) => Some(
229                value
230                    .parse::<u64>()
231                    .map_err(DatabasePoolCacheConfigError::InvalidPoolTimeout)?,
232            ),
233            Err(_) => None,
234        };
235
236        let idle_timeout: Option<u64> = match std::env::var("DOCBOX_DB_IDLE_TIMEOUT") {
237            Ok(value) => Some(
238                value
239                    .parse::<u64>()
240                    .map_err(DatabasePoolCacheConfigError::InvalidIdleTimeout)?,
241            ),
242            Err(_) => None,
243        };
244
245        let cache_duration: Option<u64> = match std::env::var("DOCBOX_DB_CACHE_DURATION") {
246            Ok(value) => Some(
247                value
248                    .parse::<u64>()
249                    .map_err(DatabasePoolCacheConfigError::InvalidCacheDuration)?,
250            ),
251            Err(_) => None,
252        };
253
254        let cache_capacity: Option<u64> = match std::env::var("DOCBOX_DB_CACHE_CAPACITY") {
255            Ok(value) => Some(
256                value
257                    .parse::<u64>()
258                    .map_err(DatabasePoolCacheConfigError::InvalidCacheCapacity)?,
259            ),
260            Err(_) => None,
261        };
262
263        let credentials_cache_duration: Option<u64> =
264            match std::env::var("DOCBOX_DB_CREDENTIALS_CACHE_DURATION") {
265                Ok(value) => Some(
266                    value
267                        .parse::<u64>()
268                        .map_err(DatabasePoolCacheConfigError::InvalidCredentialsCacheDuration)?,
269                ),
270                Err(_) => None,
271            };
272
273        let credentials_cache_capacity: Option<u64> =
274            match std::env::var("DOCBOX_DB_CREDENTIALS_CACHE_CAPACITY") {
275                Ok(value) => Some(
276                    value
277                        .parse::<u64>()
278                        .map_err(DatabasePoolCacheConfigError::InvalidCredentialsCacheCapacity)?,
279                ),
280                Err(_) => None,
281            };
282
283        Ok(DatabasePoolCacheConfig {
284            host: db_host,
285            port: db_port,
286            root_iam: db_root_iam,
287            root_secret_name: db_root_secret_name,
288            max_connections,
289            max_connections_root,
290            acquire_timeout,
291            pool_timeout,
292            idle_timeout,
293            cache_duration,
294            cache_capacity,
295            credentials_cache_duration,
296            credentials_cache_capacity,
297        })
298    }
299}
300
301/// Cache for database pools
302pub struct DatabasePoolCache {
303    /// AWS config
304    aws_config: aws_config::SdkConfig,
305
306    /// Database host
307    host: String,
308
309    /// Database port
310    port: u16,
311
312    /// Name of the secrets manager secret that contains
313    /// the credentials for the root "docbox" database
314    ///
315    /// Only present if using secrets based authentication
316    root_secret_name: Option<String>,
317
318    /// Whether to use IAM authentication to connect to the
319    /// root database instead of secrets
320    root_iam: bool,
321
322    /// Cache from the database name to the pool for that database
323    cache: Cache<String, DbPool>,
324
325    /// Cache for the connection info details, stores the last known
326    /// credentials and the instant that they were obtained at
327    connect_info_cache: Cache<String, DbSecrets>,
328
329    /// Secrets manager access to load credentials
330    secrets_manager: SecretManager,
331
332    /// Max connections per tenant database pool
333    max_connections: u32,
334    /// Max connections per root database pool
335    max_connections_root: u32,
336
337    acquire_timeout: Duration,
338    idle_timeout: Duration,
339}
340
341/// Username and password for a specific database
342#[derive(Debug, Clone, Serialize, Deserialize)]
343pub struct DbSecrets {
344    pub username: String,
345    pub password: String,
346}
347
348#[derive(Debug, Error)]
349pub enum DbConnectErr {
350    #[error("database credentials not found in secrets manager")]
351    MissingCredentials,
352
353    #[error(transparent)]
354    SecretsManager(Box<SecretManagerError>),
355
356    #[error(transparent)]
357    Db(#[from] DbErr),
358
359    #[error(transparent)]
360    Shared(#[from] Arc<DbConnectErr>),
361
362    #[error("missing aws credentials provider")]
363    MissingCredentialsProvider,
364
365    #[error("failed to provide aws credentials")]
366    AwsCredentials(#[from] CredentialsError),
367
368    #[error("aws configuration missing region")]
369    MissingRegion,
370
371    #[error("failed to build aws signature")]
372    AwsSigner(#[from] signing_params::BuildError),
373
374    #[error("failed to sign aws request")]
375    AwsRequestSign(#[from] SigningError),
376
377    #[error("failed to parse signed aws url")]
378    AwsSignerInvalidUrl(url::ParseError),
379
380    #[error("failed to connect to tenant missing both IAM and secrets fields")]
381    InvalidTenantConfiguration,
382}
383
384impl DatabasePoolCache {
385    pub fn from_config(
386        aws_config: aws_config::SdkConfig,
387        config: DatabasePoolCacheConfig,
388        secrets_manager: SecretManager,
389    ) -> Self {
390        let mut pool_timeout = Duration::from_secs(config.cache_duration.unwrap_or(60 * 60 * 48));
391        let cache_duration = Duration::from_secs(config.cache_duration.unwrap_or(60 * 60 * 48));
392        let credentials_cache_duration =
393            Duration::from_secs(config.credentials_cache_duration.unwrap_or(60 * 60 * 12));
394
395        // When using IAM ensure the pool timeout is less than the expiration time
396        // of the temporary access tokens
397        if config.root_iam && config.pool_timeout.is_none() {
398            tracing::debug!(
399                "IAM database auth is enabled with no pool timeout, setting short pool timeout within token duration"
400            );
401            pool_timeout = Duration::from_secs(60 * 10);
402        }
403
404        let cache_capacity = config.cache_capacity.unwrap_or(50);
405        let credentials_cache_capacity = config.credentials_cache_capacity.unwrap_or(50);
406
407        let cache = Cache::builder()
408            .time_to_live(pool_timeout)
409            .time_to_idle(cache_duration)
410            .max_capacity(cache_capacity)
411            .eviction_policy(EvictionPolicy::tiny_lfu())
412            .async_eviction_listener(|cache_key: Arc<String>, pool: DbPool, _cause| {
413                Box::pin(async move {
414                    tracing::debug!(?cache_key, "database pool is no longer in use, closing");
415                    pool.close().await
416                })
417            })
418            .build();
419
420        let connect_info_cache = Cache::builder()
421            .time_to_idle(credentials_cache_duration)
422            .max_capacity(credentials_cache_capacity)
423            .eviction_policy(EvictionPolicy::tiny_lfu())
424            .build();
425
426        Self {
427            aws_config,
428            host: config.host,
429            port: config.port,
430            root_secret_name: config.root_secret_name,
431            root_iam: config.root_iam,
432            cache,
433            connect_info_cache,
434            secrets_manager,
435            max_connections: config.max_connections.unwrap_or(10),
436            max_connections_root: config.max_connections_root.unwrap_or(2),
437            idle_timeout: Duration::from_secs(config.idle_timeout.unwrap_or(60 * 10)),
438            acquire_timeout: Duration::from_secs(config.acquire_timeout.unwrap_or(60)),
439        }
440    }
441
442    /// Request a database pool for the root database
443    pub async fn get_root_pool(&self) -> Result<PgPool, DbConnectErr> {
444        match (self.root_secret_name.as_ref(), self.root_iam) {
445            (_, true) => {
446                self.get_pool_iam(ROOT_DATABASE_NAME, ROOT_DATABASE_ROLE_NAME)
447                    .await
448            }
449
450            (Some(db_secret_name), _) => self.get_pool(ROOT_DATABASE_NAME, db_secret_name).await,
451
452            _ => Err(DbConnectErr::InvalidTenantConfiguration),
453        }
454    }
455
456    /// Request a database pool for a specific tenant
457    pub async fn get_tenant_pool(&self, tenant: &Tenant) -> Result<DbPool, DbConnectErr> {
458        match (
459            tenant.db_iam_user_name.as_ref(),
460            tenant.db_secret_name.as_ref(),
461        ) {
462            (Some(db_iam_user_name), _) => {
463                self.get_pool_iam(&tenant.db_name, db_iam_user_name).await
464            }
465            (_, Some(db_secret_name)) => self.get_pool(&tenant.db_name, db_secret_name).await,
466
467            _ => Err(DbConnectErr::InvalidTenantConfiguration),
468        }
469    }
470
471    /// Closes the database pool for the specific tenant if one is
472    /// available and removes the pool from the cache
473    pub async fn close_tenant_pool(&self, tenant: &Tenant) {
474        let cache_key = Self::tenant_cache_key(tenant);
475        if let Some(pool) = self.cache.remove(&cache_key).await {
476            pool.close().await;
477        }
478
479        // Run cache async shutdown jobs
480        self.cache.run_pending_tasks().await;
481    }
482
483    /// Compute the pool cache key for a tenant based on the specific
484    /// authentication methods for that tenant
485    fn tenant_cache_key(tenant: &Tenant) -> String {
486        match (
487            tenant.db_secret_name.as_ref(),
488            tenant.db_iam_user_name.as_ref(),
489        ) {
490            (Some(db_secret_name), _) => {
491                format!("secret-{}-{}", &tenant.db_name, db_secret_name)
492            }
493            (_, Some(db_iam_user_name)) => {
494                format!("user-{}-{}", &tenant.db_name, db_iam_user_name)
495            }
496
497            _ => format!("db-{}", &tenant.db_name),
498        }
499    }
500
501    /// Empties all the caches
502    pub async fn flush(&self) {
503        // Clear cache
504        self.cache.invalidate_all();
505        self.connect_info_cache.invalidate_all();
506        self.cache.run_pending_tasks().await;
507    }
508
509    /// Close all connections in the pool and invalidate the cache
510    pub async fn close_all(&self) {
511        for (_, value) in self.cache.iter() {
512            value.close().await;
513        }
514
515        self.flush().await;
516    }
517
518    /// Obtains a database pool connection to the database with the provided name
519    /// using secrets manager based credentials
520    async fn get_pool(&self, db_name: &str, secret_name: &str) -> Result<DbPool, DbConnectErr> {
521        let cache_key = format!("secret-{db_name}-{secret_name}");
522
523        let pool = self
524            .cache
525            .try_get_with(cache_key, async {
526                tracing::debug!(?db_name, "acquiring database pool");
527
528                let pool = self
529                    .create_pool(db_name, secret_name)
530                    .await
531                    .map_err(Arc::new)?;
532
533                Ok(pool)
534            })
535            .await?;
536
537        Ok(pool)
538    }
539
540    /// Obtains a database pool connection to the database with the provided name
541    /// using IAM based credentials
542    async fn get_pool_iam(
543        &self,
544        db_name: &str,
545        db_role_name: &str,
546    ) -> Result<DbPool, DbConnectErr> {
547        let cache_key = format!("user-{db_name}-{db_role_name}");
548
549        let pool = self
550            .cache
551            .try_get_with(cache_key, async {
552                tracing::debug!(?db_name, "acquiring database pool (iam)");
553
554                let pool = self
555                    .create_pool_iam(db_name, db_role_name)
556                    .await
557                    .map_err(Arc::new)?;
558
559                Ok(pool)
560            })
561            .await?;
562
563        Ok(pool)
564    }
565
566    /// Obtains database connection info
567    async fn get_credentials(&self, secret_name: &str) -> Result<DbSecrets, DbConnectErr> {
568        if let Some(connect_info) = self.connect_info_cache.get(secret_name).await {
569            return Ok(connect_info);
570        }
571
572        // Load new credentials
573        let credentials = self
574            .secrets_manager
575            .parsed_secret::<DbSecrets>(secret_name)
576            .await
577            .map_err(|err| DbConnectErr::SecretsManager(Box::new(err)))?
578            .ok_or(DbConnectErr::MissingCredentials)?;
579
580        // Cache the credential
581        self.connect_info_cache
582            .insert(secret_name.to_string(), credentials.clone())
583            .await;
584
585        Ok(credentials)
586    }
587
588    /// Creates a database pool connection using IAM based authentication
589    async fn create_pool_iam(
590        &self,
591        db_name: &str,
592        db_role_name: &str,
593    ) -> Result<DbPool, DbConnectErr> {
594        tracing::debug!(?db_name, ?db_role_name, "creating db pool connection");
595
596        let options = iam_pool_connect_options(
597            &self.aws_config,
598            &self.host,
599            self.port,
600            db_name,
601            db_role_name,
602        )
603        .await?;
604
605        let max_connections = match db_name {
606            ROOT_DATABASE_NAME => self.max_connections_root,
607            _ => self.max_connections,
608        };
609
610        let pool = PgPoolOptions::new()
611            .max_connections(max_connections)
612            // Slightly larger acquire timeout for times when lots of files are being processed
613            .acquire_timeout(self.acquire_timeout)
614            // Close any connections that have been idle for more than 30min
615            .idle_timeout(self.idle_timeout)
616            .connect_with(options)
617            .await
618            .map_err(DbConnectErr::Db)?;
619
620        tokio::spawn(iam_pool_maintenance_task(
621            pool.clone(),
622            self.aws_config.clone(),
623            self.host.clone(),
624            self.port,
625            db_name.to_string(),
626            db_role_name.to_string(),
627        ));
628
629        Ok(pool)
630    }
631
632    /// Creates a database pool connection
633    async fn create_pool(&self, db_name: &str, secret_name: &str) -> Result<DbPool, DbConnectErr> {
634        tracing::debug!(?db_name, ?secret_name, "creating db pool connection");
635
636        let credentials = self.get_credentials(secret_name).await?;
637        let options = PgConnectOptions::new()
638            .host(&self.host)
639            .port(self.port)
640            .username(&credentials.username)
641            .password(&credentials.password)
642            .database(db_name);
643
644        let max_connections = match db_name {
645            ROOT_DATABASE_NAME => self.max_connections_root,
646            _ => self.max_connections,
647        };
648
649        match PgPoolOptions::new()
650            .max_connections(max_connections)
651            // Slightly larger acquire timeout for times when lots of files are being processed
652            .acquire_timeout(self.acquire_timeout)
653            // Close any connections that have been idle for more than 30min
654            .idle_timeout(self.idle_timeout)
655            .connect_with(options)
656            .await
657        {
658            // Success case
659            Ok(value) => Ok(value),
660            Err(err) => {
661                // Drop the connect info cache in case the credentials were wrong
662                self.connect_info_cache.remove(secret_name).await;
663                Err(DbConnectErr::Db(err))
664            }
665        }
666    }
667}
668
669async fn iam_pool_connect_options(
670    aws_config: &SdkConfig,
671    host: &str,
672    port: u16,
673    db_name: &str,
674    db_role_name: &str,
675) -> Result<PgConnectOptions, DbConnectErr> {
676    let token = create_rds_signed_token(aws_config, host, port, db_role_name).await?;
677
678    let options = PgConnectOptions::new()
679        .host(host)
680        .port(port)
681        .username(db_role_name)
682        .password(&token)
683        .database(db_name);
684
685    Ok(options)
686}
687
688async fn create_rds_signed_token(
689    aws_config: &SdkConfig,
690    host: &str,
691    port: u16,
692    user: &str,
693) -> Result<String, DbConnectErr> {
694    let credentials_provider = aws_config
695        .credentials_provider()
696        .ok_or(DbConnectErr::MissingCredentialsProvider)?;
697    let credentials = credentials_provider.provide_credentials().await?;
698    let identity = credentials.into();
699    let region = aws_config.region().ok_or(DbConnectErr::MissingRegion)?;
700
701    let mut signing_settings = SigningSettings::default();
702    signing_settings.expires_in = Some(Duration::from_secs(60 * 15));
703    signing_settings.signature_location = aws_sigv4::http_request::SignatureLocation::QueryParams;
704
705    let signing_params = aws_sigv4::sign::v4::SigningParams::builder()
706        .identity(&identity)
707        .region(region.as_ref())
708        .name("rds-db")
709        .time(SystemTime::now())
710        .settings(signing_settings)
711        .build()?;
712
713    let url = format!("https://{host}:{port}/?Action=connect&DBUser={user}");
714
715    let signable_request =
716        SignableRequest::new("GET", &url, std::iter::empty(), SignableBody::Bytes(&[]))?;
717
718    let (signing_instructions, _signature) =
719        sign(signable_request, &signing_params.into())?.into_parts();
720
721    let mut url = url::Url::parse(&url).map_err(DbConnectErr::AwsSignerInvalidUrl)?;
722    for (name, value) in signing_instructions.params() {
723        url.query_pairs_mut().append_pair(name, value);
724    }
725
726    let response = url.to_string().split_off("https://".len());
727    Ok(response)
728}
729
730/// Background task spawned for IAM pools running every 10minutes to ensure that the pool
731/// has an up-to-date temporary authentication token
732async fn iam_pool_maintenance_task(
733    db: DbPool,
734    aws_config: SdkConfig,
735    host: String,
736    port: u16,
737    db_name: String,
738    db_role_name: String,
739) {
740    let interval = Duration::from_secs(60 * 10);
741
742    loop {
743        if db.is_closed() {
744            return;
745        }
746
747        match iam_pool_connect_options(&aws_config, &host, port, &db_name, &db_role_name).await {
748            Ok(options) => {
749                db.set_connect_options(options);
750            }
751            Err(error) => {
752                tracing::error!(?error, "failed to refresh IAM pool connect options");
753            }
754        }
755
756        sleep(interval).await;
757    }
758}