Skip to main content

docbox_database/
pool.rs

//! # Database Pool
//!
//! This is the docbox solution for managing multiple database connections
//! and connection pools for each tenant and the root database itself.
//!
//! Pools are held in a cache with an expiry time to ensure they don't
//! hog too many database connections.
//!
//! Database pools and credentials are stored in Tiny LFU caches; these caches
//! can be flushed using [DatabasePoolCache::flush].
//!
//! ## Environment Variables
//!
//! All durations are in seconds.
//!
//! * `DOCBOX_DB_HOST` - Database host (falls back to `POSTGRES_HOST`)
//! * `DOCBOX_DB_PORT` - Database port (falls back to `POSTGRES_PORT`)
//! * `DOCBOX_DB_CREDENTIAL_NAME` - Secrets manager name for the root database secret
//!   (required unless `DOCBOX_DB_ROOT_IAM` is `true`)
//! * `DOCBOX_DB_ROOT_IAM` - Whether to use IAM authentication for the root database (`true`/`false`)
//! * `DOCBOX_DB_MAX_CONNECTIONS` - Max connections each tenant pool can contain
//! * `DOCBOX_DB_MAX_ROOT_CONNECTIONS` - Max connections the root "docbox" pool can contain
//! * `DOCBOX_DB_ACQUIRE_TIMEOUT` - Timeout before acquiring a connection fails
//! * `DOCBOX_DB_IDLE_TIMEOUT` - Timeout before an idle connection is closed to save resources
//! * `DOCBOX_DB_POOL_TIMEOUT` - Maximum time a pool may stay in the cache before it is closed
//! * `DOCBOX_DB_CACHE_DURATION` - Duration idle pools should be maintained for before closing
//! * `DOCBOX_DB_CACHE_CAPACITY` - Maximum database pools to hold at once
//! * `DOCBOX_DB_CREDENTIALS_CACHE_DURATION` - Duration database credentials should be cached for
//! * `DOCBOX_DB_CREDENTIALS_CACHE_CAPACITY` - Maximum database credentials to cache
25
26use crate::{DbErr, DbPool, ROOT_DATABASE_NAME, ROOT_DATABASE_ROLE_NAME, models::tenant::Tenant};
27use aws_config::SdkConfig;
28use aws_credential_types::provider::{ProvideCredentials, error::CredentialsError};
29use aws_sigv4::{
30    http_request::{SignableBody, SignableRequest, SigningError, SigningSettings, sign},
31    sign::v4::signing_params,
32};
33use docbox_secrets::{SecretManager, SecretManagerError};
34use moka::{future::Cache, policy::EvictionPolicy};
35use serde::{Deserialize, Serialize};
36use sqlx::{
37    PgPool,
38    postgres::{PgConnectOptions, PgPoolOptions},
39};
40use std::time::Duration;
41use std::{num::ParseIntError, str::ParseBoolError};
42use std::{sync::Arc, time::SystemTime};
43use thiserror::Error;
44use tokio::time::sleep;
45
46///  Config for the database pool
47#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct DatabasePoolCacheConfig {
49    /// Database host
50    pub host: String,
51    /// Database port
52    pub port: u16,
53
54    /// Name of the secrets manager secret to use when connecting to
55    /// the root "docbox" database if using secret based authentication
56    pub root_secret_name: Option<String>,
57
58    /// Whether to use IAM authentication to connect to the
59    /// root database instead of secrets
60    #[serde(default)]
61    pub root_iam: bool,
62
63    /// Max number of active connections per tenant database pool
64    ///
65    /// This is the maximum number of connections that should be allocated
66    /// for performing all queries against each specific tenant.
67    ///
68    /// Ensure a reasonable amount of connections are allocated but make
69    /// sure that the `max_connections` * your number of tenants stays
70    /// within the limits for your database
71    ///
72    /// Default: 10
73    pub max_connections: Option<u32>,
74
75    /// Max number of active connections per "docbox" database pool
76    ///
77    /// This is the maximum number of connections that should be allocated
78    /// for performing queries like:
79    /// - Listing tenants
80    /// - Getting tenant details
81    ///
82    /// These pools are often short lived and complete their queries very fast
83    /// and thus don't need a huge amount of resources allocated to them
84    ///
85    /// Default: 2
86    pub max_connections_root: Option<u32>,
87
88    /// Timeout before a acquiring a database connection is considered
89    /// a failure
90    ///
91    /// Default: 60s
92    pub acquire_timeout: Option<u64>,
93
94    /// If a connection has been idle for this duration the connection
95    /// will be closed and released back to the database for other
96    /// consumers
97    ///
98    /// Default: 10min
99    pub idle_timeout: Option<u64>,
100
101    /// Maximum time pool are allowed to stay within the database
102    /// cache before they are automatically removed
103    ///
104    /// Default: 48h
105    pub pool_timeout: Option<u64>,
106
107    /// Duration in seconds idle database pools are allowed to be cached before
108    /// they are closed
109    ///
110    /// Default: 48h
111    pub cache_duration: Option<u64>,
112
113    /// Maximum database pools to maintain in the cache at once. If the
114    /// cache capacity is exceeded old pools will be closed and removed
115    /// from the cache
116    ///
117    /// This capacity should be aligned with your expected number of
118    /// tenants along with your `max_connections` to ensure your database
119    /// has enough connections to accommodate all tenants.
120    ///
121    /// Default: 50
122    pub cache_capacity: Option<u64>,
123
124    /// Duration in seconds database credentials (host, port, password, ..etc)
125    /// are allowed to be cached before they are refresh from the secrets
126    /// manager
127    ///
128    /// Default: 12h
129    pub credentials_cache_duration: Option<u64>,
130
131    /// Maximum database credentials to maintain in the cache at once. If the
132    /// cache capacity is exceeded old credentials will be removed from the cache
133    ///
134    /// Default: 50
135    pub credentials_cache_capacity: Option<u64>,
136}
137
138impl Default for DatabasePoolCacheConfig {
139    fn default() -> Self {
140        Self {
141            host: Default::default(),
142            port: 5432,
143            root_secret_name: Default::default(),
144            root_iam: false,
145            max_connections: None,
146            max_connections_root: None,
147            acquire_timeout: None,
148            idle_timeout: None,
149            pool_timeout: None,
150            cache_duration: None,
151            cache_capacity: None,
152            credentials_cache_duration: None,
153            credentials_cache_capacity: None,
154        }
155    }
156}
157
158#[derive(Debug, Error)]
159pub enum DatabasePoolCacheConfigError {
160    #[error("missing DOCBOX_DB_HOST environment variable")]
161    MissingDatabaseHost,
162    #[error("missing DOCBOX_DB_PORT environment variable")]
163    MissingDatabasePort,
164    #[error("invalid DOCBOX_DB_PORT environment variable")]
165    InvalidDatabasePort,
166    #[error("missing DOCBOX_DB_CREDENTIAL_NAME environment variable")]
167    MissingDatabaseSecretName,
168    #[error("invalid DOCBOX_DB_POOL_TIMEOUT environment variable")]
169    InvalidPoolTimeout(ParseIntError),
170    #[error("invalid DOCBOX_DB_IDLE_TIMEOUT environment variable")]
171    InvalidIdleTimeout(ParseIntError),
172    #[error("invalid DOCBOX_DB_ACQUIRE_TIMEOUT environment variable")]
173    InvalidAcquireTimeout(ParseIntError),
174    #[error("invalid DOCBOX_DB_CACHE_DURATION environment variable")]
175    InvalidCacheDuration(ParseIntError),
176    #[error("invalid DOCBOX_DB_CACHE_CAPACITY environment variable")]
177    InvalidCacheCapacity(ParseIntError),
178    #[error("invalid DOCBOX_DB_CREDENTIALS_CACHE_DURATION environment variable")]
179    InvalidCredentialsCacheDuration(ParseIntError),
180    #[error("invalid DOCBOX_DB_CREDENTIALS_CACHE_CAPACITY environment variable")]
181    InvalidCredentialsCacheCapacity(ParseIntError),
182    #[error("invalid DOCBOX_DB_ROOT_IAM environment variable")]
183    InvalidRootIam(ParseBoolError),
184}
185
186impl DatabasePoolCacheConfig {
187    pub fn from_env() -> Result<DatabasePoolCacheConfig, DatabasePoolCacheConfigError> {
188        let db_host: String = std::env::var("DOCBOX_DB_HOST")
189            .or(std::env::var("POSTGRES_HOST"))
190            .map_err(|_| DatabasePoolCacheConfigError::MissingDatabaseHost)?;
191        let db_port: u16 = std::env::var("DOCBOX_DB_PORT")
192            .or(std::env::var("POSTGRES_PORT"))
193            .map_err(|_| DatabasePoolCacheConfigError::MissingDatabasePort)?
194            .parse()
195            .map_err(|_| DatabasePoolCacheConfigError::InvalidDatabasePort)?;
196
197        let db_root_secret_name = std::env::var("DOCBOX_DB_CREDENTIAL_NAME").ok();
198        let db_root_iam = std::env::var("DOCBOX_DB_ROOT_IAM")
199            .ok()
200            .map(|value| value.parse::<bool>())
201            .transpose()
202            .map_err(DatabasePoolCacheConfigError::InvalidRootIam)?
203            .unwrap_or_default();
204
205        // Root secret name is required when not using IAM
206        if !db_root_iam && db_root_secret_name.is_none() {
207            return Err(DatabasePoolCacheConfigError::MissingDatabaseSecretName);
208        }
209
210        let max_connections: Option<u32> = std::env::var("DOCBOX_DB_MAX_CONNECTIONS")
211            .ok()
212            .and_then(|value| value.parse().ok());
213        let max_connections_root: Option<u32> = std::env::var("DOCBOX_DB_MAX_ROOT_CONNECTIONS")
214            .ok()
215            .and_then(|value| value.parse().ok());
216
217        let acquire_timeout: Option<u64> = match std::env::var("DOCBOX_DB_ACQUIRE_TIMEOUT") {
218            Ok(value) => Some(
219                value
220                    .parse::<u64>()
221                    .map_err(DatabasePoolCacheConfigError::InvalidAcquireTimeout)?,
222            ),
223            Err(_) => None,
224        };
225
226        let pool_timeout: Option<u64> = match std::env::var("DOCBOX_DB_POOL_TIMEOUT") {
227            Ok(value) => Some(
228                value
229                    .parse::<u64>()
230                    .map_err(DatabasePoolCacheConfigError::InvalidPoolTimeout)?,
231            ),
232            Err(_) => None,
233        };
234
235        let idle_timeout: Option<u64> = match std::env::var("DOCBOX_DB_IDLE_TIMEOUT") {
236            Ok(value) => Some(
237                value
238                    .parse::<u64>()
239                    .map_err(DatabasePoolCacheConfigError::InvalidIdleTimeout)?,
240            ),
241            Err(_) => None,
242        };
243
244        let cache_duration: Option<u64> = match std::env::var("DOCBOX_DB_CACHE_DURATION") {
245            Ok(value) => Some(
246                value
247                    .parse::<u64>()
248                    .map_err(DatabasePoolCacheConfigError::InvalidCacheDuration)?,
249            ),
250            Err(_) => None,
251        };
252
253        let cache_capacity: Option<u64> = match std::env::var("DOCBOX_DB_CACHE_CAPACITY") {
254            Ok(value) => Some(
255                value
256                    .parse::<u64>()
257                    .map_err(DatabasePoolCacheConfigError::InvalidCacheCapacity)?,
258            ),
259            Err(_) => None,
260        };
261
262        let credentials_cache_duration: Option<u64> =
263            match std::env::var("DOCBOX_DB_CREDENTIALS_CACHE_DURATION") {
264                Ok(value) => Some(
265                    value
266                        .parse::<u64>()
267                        .map_err(DatabasePoolCacheConfigError::InvalidCredentialsCacheDuration)?,
268                ),
269                Err(_) => None,
270            };
271
272        let credentials_cache_capacity: Option<u64> =
273            match std::env::var("DOCBOX_DB_CREDENTIALS_CACHE_CAPACITY") {
274                Ok(value) => Some(
275                    value
276                        .parse::<u64>()
277                        .map_err(DatabasePoolCacheConfigError::InvalidCredentialsCacheCapacity)?,
278                ),
279                Err(_) => None,
280            };
281
282        Ok(DatabasePoolCacheConfig {
283            host: db_host,
284            port: db_port,
285            root_iam: db_root_iam,
286            root_secret_name: db_root_secret_name,
287            max_connections,
288            max_connections_root,
289            acquire_timeout,
290            pool_timeout,
291            idle_timeout,
292            cache_duration,
293            cache_capacity,
294            credentials_cache_duration,
295            credentials_cache_capacity,
296        })
297    }
298}
299
300/// Cache for database pools
301pub struct DatabasePoolCache {
302    /// AWS config
303    aws_config: aws_config::SdkConfig,
304
305    /// Database host
306    host: String,
307
308    /// Database port
309    port: u16,
310
311    /// Name of the secrets manager secret that contains
312    /// the credentials for the root "docbox" database
313    ///
314    /// Only present if using secrets based authentication
315    root_secret_name: Option<String>,
316
317    /// Whether to use IAM authentication to connect to the
318    /// root database instead of secrets
319    root_iam: bool,
320
321    /// Cache from the database name to the pool for that database
322    cache: Cache<String, DbPool>,
323
324    /// Cache for the connection info details, stores the last known
325    /// credentials and the instant that they were obtained at
326    connect_info_cache: Cache<String, DbSecrets>,
327
328    /// Secrets manager access to load credentials
329    secrets_manager: SecretManager,
330
331    /// Max connections per tenant database pool
332    max_connections: u32,
333    /// Max connections per root database pool
334    max_connections_root: u32,
335
336    acquire_timeout: Duration,
337    idle_timeout: Duration,
338}
339
340/// Username and password for a specific database
341#[derive(Debug, Clone, Serialize, Deserialize)]
342pub struct DbSecrets {
343    pub username: String,
344    pub password: String,
345}
346
347#[derive(Debug, Error)]
348pub enum DbConnectErr {
349    #[error("database credentials not found in secrets manager")]
350    MissingCredentials,
351
352    #[error(transparent)]
353    SecretsManager(Box<SecretManagerError>),
354
355    #[error(transparent)]
356    Db(#[from] DbErr),
357
358    #[error(transparent)]
359    Shared(#[from] Arc<DbConnectErr>),
360
361    #[error("missing aws credentials provider")]
362    MissingCredentialsProvider,
363
364    #[error("failed to provide aws credentials")]
365    AwsCredentials(#[from] CredentialsError),
366
367    #[error("aws configuration missing region")]
368    MissingRegion,
369
370    #[error("failed to build aws signature")]
371    AwsSigner(#[from] signing_params::BuildError),
372
373    #[error("failed to sign aws request")]
374    AwsRequestSign(#[from] SigningError),
375
376    #[error("failed to parse signed aws url")]
377    AwsSignerInvalidUrl(url::ParseError),
378
379    #[error("failed to connect to tenant missing both IAM and secrets fields")]
380    InvalidTenantConfiguration,
381}
382
383impl DatabasePoolCache {
384    pub fn from_config(
385        aws_config: aws_config::SdkConfig,
386        config: DatabasePoolCacheConfig,
387        secrets_manager: SecretManager,
388    ) -> Self {
389        let mut pool_timeout = Duration::from_secs(config.cache_duration.unwrap_or(60 * 60 * 48));
390        let cache_duration = Duration::from_secs(config.cache_duration.unwrap_or(60 * 60 * 48));
391        let credentials_cache_duration =
392            Duration::from_secs(config.credentials_cache_duration.unwrap_or(60 * 60 * 12));
393
394        // When using IAM ensure the pool timeout is less than the expiration time
395        // of the temporary access tokens
396        if config.root_iam && config.pool_timeout.is_none() {
397            tracing::debug!(
398                "IAM database auth is enabled with no pool timeout, setting short pool timeout within token duration"
399            );
400            pool_timeout = Duration::from_secs(60 * 10);
401        }
402
403        let cache_capacity = config.cache_capacity.unwrap_or(50);
404        let credentials_cache_capacity = config.credentials_cache_capacity.unwrap_or(50);
405
406        let cache = Cache::builder()
407            .time_to_live(pool_timeout)
408            .time_to_idle(cache_duration)
409            .max_capacity(cache_capacity)
410            .eviction_policy(EvictionPolicy::tiny_lfu())
411            .async_eviction_listener(|cache_key: Arc<String>, pool: DbPool, _cause| {
412                Box::pin(async move {
413                    tracing::debug!(?cache_key, "database pool is no longer in use, closing");
414                    pool.close().await
415                })
416            })
417            .build();
418
419        let connect_info_cache = Cache::builder()
420            .time_to_idle(credentials_cache_duration)
421            .max_capacity(credentials_cache_capacity)
422            .eviction_policy(EvictionPolicy::tiny_lfu())
423            .build();
424
425        Self {
426            aws_config,
427            host: config.host,
428            port: config.port,
429            root_secret_name: config.root_secret_name,
430            root_iam: config.root_iam,
431            cache,
432            connect_info_cache,
433            secrets_manager,
434            max_connections: config.max_connections.unwrap_or(10),
435            max_connections_root: config.max_connections_root.unwrap_or(2),
436            idle_timeout: Duration::from_secs(config.idle_timeout.unwrap_or(60 * 10)),
437            acquire_timeout: Duration::from_secs(config.acquire_timeout.unwrap_or(60)),
438        }
439    }
440
441    /// Request a database pool for the root database
442    pub async fn get_root_pool(&self) -> Result<PgPool, DbConnectErr> {
443        match (self.root_secret_name.as_ref(), self.root_iam) {
444            (_, true) => {
445                self.get_pool_iam(ROOT_DATABASE_NAME, ROOT_DATABASE_ROLE_NAME)
446                    .await
447            }
448
449            (Some(db_secret_name), _) => self.get_pool(ROOT_DATABASE_NAME, db_secret_name).await,
450
451            _ => Err(DbConnectErr::InvalidTenantConfiguration),
452        }
453    }
454
455    /// Request a database pool for a specific tenant
456    pub async fn get_tenant_pool(&self, tenant: &Tenant) -> Result<DbPool, DbConnectErr> {
457        match (
458            tenant.db_iam_user_name.as_ref(),
459            tenant.db_secret_name.as_ref(),
460        ) {
461            (Some(db_iam_user_name), _) => {
462                self.get_pool_iam(&tenant.db_name, db_iam_user_name).await
463            }
464            (_, Some(db_secret_name)) => self.get_pool(&tenant.db_name, db_secret_name).await,
465
466            _ => Err(DbConnectErr::InvalidTenantConfiguration),
467        }
468    }
469
470    /// Closes the database pool for the specific tenant if one is
471    /// available and removes the pool from the cache
472    pub async fn close_tenant_pool(&self, tenant: &Tenant) {
473        let cache_key = Self::tenant_cache_key(tenant);
474        if let Some(pool) = self.cache.remove(&cache_key).await {
475            pool.close().await;
476        }
477
478        // Run cache async shutdown jobs
479        self.cache.run_pending_tasks().await;
480    }
481
482    /// Compute the pool cache key for a tenant based on the specific
483    /// authentication methods for that tenant
484    fn tenant_cache_key(tenant: &Tenant) -> String {
485        match (
486            tenant.db_secret_name.as_ref(),
487            tenant.db_iam_user_name.as_ref(),
488        ) {
489            (Some(db_secret_name), _) => {
490                format!("secret-{}-{}", &tenant.db_name, db_secret_name)
491            }
492            (_, Some(db_iam_user_name)) => {
493                format!("user-{}-{}", &tenant.db_name, db_iam_user_name)
494            }
495
496            _ => format!("db-{}", &tenant.db_name),
497        }
498    }
499
500    /// Empties all the caches
501    pub async fn flush(&self) {
502        // Clear cache
503        self.cache.invalidate_all();
504        self.connect_info_cache.invalidate_all();
505        self.cache.run_pending_tasks().await;
506    }
507
508    /// Close all connections in the pool and invalidate the cache
509    pub async fn close_all(&self) {
510        for (_, value) in self.cache.iter() {
511            value.close().await;
512        }
513
514        self.flush().await;
515    }
516
517    /// Obtains a database pool connection to the database with the provided name
518    /// using secrets manager based credentials
519    async fn get_pool(&self, db_name: &str, secret_name: &str) -> Result<DbPool, DbConnectErr> {
520        let cache_key = format!("secret-{db_name}-{secret_name}");
521
522        let pool = self
523            .cache
524            .try_get_with(cache_key, async {
525                tracing::debug!(?db_name, "acquiring database pool");
526
527                let pool = self
528                    .create_pool(db_name, secret_name)
529                    .await
530                    .map_err(Arc::new)?;
531
532                Ok(pool)
533            })
534            .await?;
535
536        Ok(pool)
537    }
538
539    /// Obtains a database pool connection to the database with the provided name
540    /// using IAM based credentials
541    async fn get_pool_iam(
542        &self,
543        db_name: &str,
544        db_role_name: &str,
545    ) -> Result<DbPool, DbConnectErr> {
546        let cache_key = format!("user-{db_name}-{db_role_name}");
547
548        let pool = self
549            .cache
550            .try_get_with(cache_key, async {
551                tracing::debug!(?db_name, "acquiring database pool (iam)");
552
553                let pool = self
554                    .create_pool_iam(db_name, db_role_name)
555                    .await
556                    .map_err(Arc::new)?;
557
558                Ok(pool)
559            })
560            .await?;
561
562        Ok(pool)
563    }
564
565    /// Obtains database connection info
566    async fn get_credentials(&self, secret_name: &str) -> Result<DbSecrets, DbConnectErr> {
567        if let Some(connect_info) = self.connect_info_cache.get(secret_name).await {
568            return Ok(connect_info);
569        }
570
571        // Load new credentials
572        let credentials = self
573            .secrets_manager
574            .parsed_secret::<DbSecrets>(secret_name)
575            .await
576            .map_err(|err| DbConnectErr::SecretsManager(Box::new(err)))?
577            .ok_or(DbConnectErr::MissingCredentials)?;
578
579        // Cache the credential
580        self.connect_info_cache
581            .insert(secret_name.to_string(), credentials.clone())
582            .await;
583
584        Ok(credentials)
585    }
586
587    /// Creates a database pool connection using IAM based authentication
588    async fn create_pool_iam(
589        &self,
590        db_name: &str,
591        db_role_name: &str,
592    ) -> Result<DbPool, DbConnectErr> {
593        tracing::debug!(?db_name, ?db_role_name, "creating db pool connection");
594
595        let options = iam_pool_connect_options(
596            &self.aws_config,
597            &self.host,
598            self.port,
599            db_name,
600            db_role_name,
601        )
602        .await?;
603
604        let max_connections = match db_name {
605            ROOT_DATABASE_NAME => self.max_connections_root,
606            _ => self.max_connections,
607        };
608
609        let pool = PgPoolOptions::new()
610            .max_connections(max_connections)
611            // Slightly larger acquire timeout for times when lots of files are being processed
612            .acquire_timeout(self.acquire_timeout)
613            // Close any connections that have been idle for more than 30min
614            .idle_timeout(self.idle_timeout)
615            .connect_with(options)
616            .await
617            .map_err(DbConnectErr::Db)?;
618
619        tokio::spawn(iam_pool_maintenance_task(
620            pool.clone(),
621            self.aws_config.clone(),
622            self.host.clone(),
623            self.port,
624            db_name.to_string(),
625            db_role_name.to_string(),
626        ));
627
628        Ok(pool)
629    }
630
631    /// Creates a database pool connection
632    async fn create_pool(&self, db_name: &str, secret_name: &str) -> Result<DbPool, DbConnectErr> {
633        tracing::debug!(?db_name, ?secret_name, "creating db pool connection");
634
635        let credentials = self.get_credentials(secret_name).await?;
636        let options = PgConnectOptions::new()
637            .host(&self.host)
638            .port(self.port)
639            .username(&credentials.username)
640            .password(&credentials.password)
641            .database(db_name);
642
643        let max_connections = match db_name {
644            ROOT_DATABASE_NAME => self.max_connections_root,
645            _ => self.max_connections,
646        };
647
648        match PgPoolOptions::new()
649            .max_connections(max_connections)
650            // Slightly larger acquire timeout for times when lots of files are being processed
651            .acquire_timeout(self.acquire_timeout)
652            // Close any connections that have been idle for more than 30min
653            .idle_timeout(self.idle_timeout)
654            .connect_with(options)
655            .await
656        {
657            // Success case
658            Ok(value) => Ok(value),
659            Err(err) => {
660                // Drop the connect info cache in case the credentials were wrong
661                self.connect_info_cache.remove(secret_name).await;
662                Err(DbConnectErr::Db(err))
663            }
664        }
665    }
666}
667
668async fn iam_pool_connect_options(
669    aws_config: &SdkConfig,
670    host: &str,
671    port: u16,
672    db_name: &str,
673    db_role_name: &str,
674) -> Result<PgConnectOptions, DbConnectErr> {
675    let token = create_rds_signed_token(aws_config, host, port, db_role_name).await?;
676
677    let options = PgConnectOptions::new()
678        .host(host)
679        .port(port)
680        .username(db_role_name)
681        .password(&token)
682        .database(db_name);
683
684    Ok(options)
685}
686
687async fn create_rds_signed_token(
688    aws_config: &SdkConfig,
689    host: &str,
690    port: u16,
691    user: &str,
692) -> Result<String, DbConnectErr> {
693    let credentials_provider = aws_config
694        .credentials_provider()
695        .ok_or(DbConnectErr::MissingCredentialsProvider)?;
696    let credentials = credentials_provider.provide_credentials().await?;
697    let identity = credentials.into();
698    let region = aws_config.region().ok_or(DbConnectErr::MissingRegion)?;
699
700    let mut signing_settings = SigningSettings::default();
701    signing_settings.expires_in = Some(Duration::from_secs(60 * 15));
702    signing_settings.signature_location = aws_sigv4::http_request::SignatureLocation::QueryParams;
703
704    let signing_params = aws_sigv4::sign::v4::SigningParams::builder()
705        .identity(&identity)
706        .region(region.as_ref())
707        .name("rds-db")
708        .time(SystemTime::now())
709        .settings(signing_settings)
710        .build()?;
711
712    let url = format!("https://{host}:{port}/?Action=connect&DBUser={user}");
713
714    let signable_request =
715        SignableRequest::new("GET", &url, std::iter::empty(), SignableBody::Bytes(&[]))?;
716
717    let (signing_instructions, _signature) =
718        sign(signable_request, &signing_params.into())?.into_parts();
719
720    let mut url = url::Url::parse(&url).map_err(DbConnectErr::AwsSignerInvalidUrl)?;
721    for (name, value) in signing_instructions.params() {
722        url.query_pairs_mut().append_pair(name, value);
723    }
724
725    let response = url.to_string().split_off("https://".len());
726    Ok(response)
727}
728
729/// Background task spawned for IAM pools running every 10minutes to ensure that the pool
730/// has an up-to-date temporary authentication token
731async fn iam_pool_maintenance_task(
732    db: DbPool,
733    aws_config: SdkConfig,
734    host: String,
735    port: u16,
736    db_name: String,
737    db_role_name: String,
738) {
739    let interval = Duration::from_secs(60 * 10);
740
741    loop {
742        if db.is_closed() {
743            return;
744        }
745
746        match iam_pool_connect_options(&aws_config, &host, port, &db_name, &db_role_name).await {
747            Ok(options) => {
748                db.set_connect_options(options);
749            }
750            Err(error) => {
751                tracing::error!(?error, "failed to refresh IAM pool connect options");
752            }
753        }
754
755        sleep(interval).await;
756    }
757}