1use crate::{DbErr, DbPool, ROOT_DATABASE_NAME, ROOT_DATABASE_ROLE_NAME, models::tenant::Tenant};
27use aws_config::SdkConfig;
28use aws_credential_types::provider::{ProvideCredentials, error::CredentialsError};
29use aws_sigv4::{
30 http_request::{SignableBody, SignableRequest, SigningError, SigningSettings, sign},
31 sign::v4::signing_params,
32};
33use docbox_secrets::{SecretManager, SecretManagerError};
34use moka::{future::Cache, policy::EvictionPolicy};
35use serde::{Deserialize, Serialize};
36use sqlx::{
37 PgPool,
38 postgres::{PgConnectOptions, PgPoolOptions},
39};
40use std::time::Duration;
41use std::{num::ParseIntError, str::ParseBoolError};
42use std::{sync::Arc, time::SystemTime};
43use thiserror::Error;
44use tokio::time::sleep;
45
/// Configuration for [`DatabasePoolCache`].
///
/// Can be deserialized directly or loaded from environment variables with
/// [`DatabasePoolCacheConfig::from_env`]. Timeouts and durations are expressed in
/// seconds; `None` values fall back to the defaults applied by
/// [`DatabasePoolCache::from_config`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabasePoolCacheConfig {
    /// Database host to connect to (`DOCBOX_DB_HOST`, falling back to `POSTGRES_HOST`).
    pub host: String,
    /// Database port (`DOCBOX_DB_PORT`, falling back to `POSTGRES_PORT`).
    /// The `Default` implementation uses 5432.
    pub port: u16,

    /// Name of the secrets manager secret holding the root database credentials
    /// (`DOCBOX_DB_CREDENTIAL_NAME`). `from_env` requires this unless `root_iam` is set.
    pub root_secret_name: Option<String>,

    /// Connect to the root database using IAM authentication instead of
    /// `root_secret_name` (`DOCBOX_DB_ROOT_IAM`). Defaults to `false` when omitted.
    #[serde(default)]
    pub root_iam: bool,

    /// Maximum number of connections in each tenant pool
    /// (`DOCBOX_DB_MAX_CONNECTIONS`). Defaults to 10.
    pub max_connections: Option<u32>,

    /// Maximum number of connections in the root database pool
    /// (`DOCBOX_DB_MAX_ROOT_CONNECTIONS`). Defaults to 2.
    pub max_connections_root: Option<u32>,

    /// Seconds to wait when acquiring a connection from a pool
    /// (`DOCBOX_DB_ACQUIRE_TIMEOUT`). Defaults to 60.
    pub acquire_timeout: Option<u64>,

    /// Seconds an individual pooled connection may stay idle before being closed
    /// (`DOCBOX_DB_IDLE_TIMEOUT`). Defaults to 600 (10 minutes).
    pub idle_timeout: Option<u64>,

    /// Intended maximum lifetime in seconds of a cached pool, independent of use
    /// (`DOCBOX_DB_POOL_TIMEOUT`). When unset and `root_iam` is enabled, pools are
    /// recycled after 10 minutes so they stay within the IAM token lifetime.
    pub pool_timeout: Option<u64>,

    /// Seconds a cached pool may go unused before it is evicted and closed
    /// (`DOCBOX_DB_CACHE_DURATION`). Defaults to 48 hours.
    pub cache_duration: Option<u64>,

    /// Maximum number of pools kept in the cache (`DOCBOX_DB_CACHE_CAPACITY`).
    /// Defaults to 50.
    pub cache_capacity: Option<u64>,

    /// Seconds cached database credentials may go unused before they are evicted
    /// (`DOCBOX_DB_CREDENTIALS_CACHE_DURATION`). Defaults to 12 hours.
    pub credentials_cache_duration: Option<u64>,

    /// Maximum number of cached credential entries
    /// (`DOCBOX_DB_CREDENTIALS_CACHE_CAPACITY`). Defaults to 50.
    pub credentials_cache_capacity: Option<u64>,
}
137
138impl Default for DatabasePoolCacheConfig {
139 fn default() -> Self {
140 Self {
141 host: Default::default(),
142 port: 5432,
143 root_secret_name: Default::default(),
144 root_iam: false,
145 max_connections: None,
146 max_connections_root: None,
147 acquire_timeout: None,
148 idle_timeout: None,
149 pool_timeout: None,
150 cache_duration: None,
151 cache_capacity: None,
152 credentials_cache_duration: None,
153 credentials_cache_capacity: None,
154 }
155 }
156}
157
158#[derive(Debug, Error)]
159pub enum DatabasePoolCacheConfigError {
160 #[error("missing DOCBOX_DB_HOST environment variable")]
161 MissingDatabaseHost,
162 #[error("missing DOCBOX_DB_PORT environment variable")]
163 MissingDatabasePort,
164 #[error("invalid DOCBOX_DB_PORT environment variable")]
165 InvalidDatabasePort,
166 #[error("missing DOCBOX_DB_CREDENTIAL_NAME environment variable")]
167 MissingDatabaseSecretName,
168 #[error("invalid DOCBOX_DB_POOL_TIMEOUT environment variable")]
169 InvalidPoolTimeout(ParseIntError),
170 #[error("invalid DOCBOX_DB_IDLE_TIMEOUT environment variable")]
171 InvalidIdleTimeout(ParseIntError),
172 #[error("invalid DOCBOX_DB_ACQUIRE_TIMEOUT environment variable")]
173 InvalidAcquireTimeout(ParseIntError),
174 #[error("invalid DOCBOX_DB_CACHE_DURATION environment variable")]
175 InvalidCacheDuration(ParseIntError),
176 #[error("invalid DOCBOX_DB_CACHE_CAPACITY environment variable")]
177 InvalidCacheCapacity(ParseIntError),
178 #[error("invalid DOCBOX_DB_CREDENTIALS_CACHE_DURATION environment variable")]
179 InvalidCredentialsCacheDuration(ParseIntError),
180 #[error("invalid DOCBOX_DB_CREDENTIALS_CACHE_CAPACITY environment variable")]
181 InvalidCredentialsCacheCapacity(ParseIntError),
182 #[error("invalid DOCBOX_DB_ROOT_IAM environment variable")]
183 InvalidRootIam(ParseBoolError),
184}
185
186impl DatabasePoolCacheConfig {
187 pub fn from_env() -> Result<DatabasePoolCacheConfig, DatabasePoolCacheConfigError> {
188 let db_host: String = std::env::var("DOCBOX_DB_HOST")
189 .or(std::env::var("POSTGRES_HOST"))
190 .map_err(|_| DatabasePoolCacheConfigError::MissingDatabaseHost)?;
191 let db_port: u16 = std::env::var("DOCBOX_DB_PORT")
192 .or(std::env::var("POSTGRES_PORT"))
193 .map_err(|_| DatabasePoolCacheConfigError::MissingDatabasePort)?
194 .parse()
195 .map_err(|_| DatabasePoolCacheConfigError::InvalidDatabasePort)?;
196
197 let db_root_secret_name = std::env::var("DOCBOX_DB_CREDENTIAL_NAME").ok();
198 let db_root_iam = std::env::var("DOCBOX_DB_ROOT_IAM")
199 .ok()
200 .map(|value| value.parse::<bool>())
201 .transpose()
202 .map_err(DatabasePoolCacheConfigError::InvalidRootIam)?
203 .unwrap_or_default();
204
205 if !db_root_iam && db_root_secret_name.is_none() {
207 return Err(DatabasePoolCacheConfigError::MissingDatabaseSecretName);
208 }
209
210 let max_connections: Option<u32> = std::env::var("DOCBOX_DB_MAX_CONNECTIONS")
211 .ok()
212 .and_then(|value| value.parse().ok());
213 let max_connections_root: Option<u32> = std::env::var("DOCBOX_DB_MAX_ROOT_CONNECTIONS")
214 .ok()
215 .and_then(|value| value.parse().ok());
216
217 let acquire_timeout: Option<u64> = match std::env::var("DOCBOX_DB_ACQUIRE_TIMEOUT") {
218 Ok(value) => Some(
219 value
220 .parse::<u64>()
221 .map_err(DatabasePoolCacheConfigError::InvalidAcquireTimeout)?,
222 ),
223 Err(_) => None,
224 };
225
226 let pool_timeout: Option<u64> = match std::env::var("DOCBOX_DB_POOL_TIMEOUT") {
227 Ok(value) => Some(
228 value
229 .parse::<u64>()
230 .map_err(DatabasePoolCacheConfigError::InvalidPoolTimeout)?,
231 ),
232 Err(_) => None,
233 };
234
235 let idle_timeout: Option<u64> = match std::env::var("DOCBOX_DB_IDLE_TIMEOUT") {
236 Ok(value) => Some(
237 value
238 .parse::<u64>()
239 .map_err(DatabasePoolCacheConfigError::InvalidIdleTimeout)?,
240 ),
241 Err(_) => None,
242 };
243
244 let cache_duration: Option<u64> = match std::env::var("DOCBOX_DB_CACHE_DURATION") {
245 Ok(value) => Some(
246 value
247 .parse::<u64>()
248 .map_err(DatabasePoolCacheConfigError::InvalidCacheDuration)?,
249 ),
250 Err(_) => None,
251 };
252
253 let cache_capacity: Option<u64> = match std::env::var("DOCBOX_DB_CACHE_CAPACITY") {
254 Ok(value) => Some(
255 value
256 .parse::<u64>()
257 .map_err(DatabasePoolCacheConfigError::InvalidCacheCapacity)?,
258 ),
259 Err(_) => None,
260 };
261
262 let credentials_cache_duration: Option<u64> =
263 match std::env::var("DOCBOX_DB_CREDENTIALS_CACHE_DURATION") {
264 Ok(value) => Some(
265 value
266 .parse::<u64>()
267 .map_err(DatabasePoolCacheConfigError::InvalidCredentialsCacheDuration)?,
268 ),
269 Err(_) => None,
270 };
271
272 let credentials_cache_capacity: Option<u64> =
273 match std::env::var("DOCBOX_DB_CREDENTIALS_CACHE_CAPACITY") {
274 Ok(value) => Some(
275 value
276 .parse::<u64>()
277 .map_err(DatabasePoolCacheConfigError::InvalidCredentialsCacheCapacity)?,
278 ),
279 Err(_) => None,
280 };
281
282 Ok(DatabasePoolCacheConfig {
283 host: db_host,
284 port: db_port,
285 root_iam: db_root_iam,
286 root_secret_name: db_root_secret_name,
287 max_connections,
288 max_connections_root,
289 acquire_timeout,
290 pool_timeout,
291 idle_timeout,
292 cache_duration,
293 cache_capacity,
294 credentials_cache_duration,
295 credentials_cache_capacity,
296 })
297 }
298}
299
/// Cache of database connection pools for the root database and tenant databases.
///
/// Pools are created lazily on first use. They authenticate either with
/// credentials loaded from the secrets manager or with IAM auth tokens, and are
/// closed by the cache's eviction listener when they are evicted.
pub struct DatabasePoolCache {
    /// AWS configuration used to sign IAM database auth tokens.
    aws_config: aws_config::SdkConfig,

    /// Database host to connect to.
    host: String,

    /// Database port to connect to.
    port: u16,

    /// Secrets manager secret holding the root database credentials; used when
    /// `root_iam` is disabled.
    root_secret_name: Option<String>,

    /// Whether the root database is accessed with IAM authentication.
    root_iam: bool,

    /// Open pools keyed by `secret-{db_name}-{secret_name}` (secret auth) or
    /// `user-{db_name}-{role_name}` (IAM auth).
    cache: Cache<String, DbPool>,

    /// Database credentials fetched from the secrets manager, keyed by secret name.
    connect_info_cache: Cache<String, DbSecrets>,

    /// Secrets manager used to load database credentials.
    secrets_manager: SecretManager,

    /// Maximum connections per tenant pool.
    max_connections: u32,
    /// Maximum connections for the root database pool.
    max_connections_root: u32,

    /// Timeout applied when acquiring a connection from a pool.
    acquire_timeout: Duration,
    /// Idle timeout applied to individual pooled connections.
    idle_timeout: Duration,
}
339
340#[derive(Debug, Clone, Serialize, Deserialize)]
342pub struct DbSecrets {
343 pub username: String,
344 pub password: String,
345}
346
347#[derive(Debug, Error)]
348pub enum DbConnectErr {
349 #[error("database credentials not found in secrets manager")]
350 MissingCredentials,
351
352 #[error(transparent)]
353 SecretsManager(Box<SecretManagerError>),
354
355 #[error(transparent)]
356 Db(#[from] DbErr),
357
358 #[error(transparent)]
359 Shared(#[from] Arc<DbConnectErr>),
360
361 #[error("missing aws credentials provider")]
362 MissingCredentialsProvider,
363
364 #[error("failed to provide aws credentials")]
365 AwsCredentials(#[from] CredentialsError),
366
367 #[error("aws configuration missing region")]
368 MissingRegion,
369
370 #[error("failed to build aws signature")]
371 AwsSigner(#[from] signing_params::BuildError),
372
373 #[error("failed to sign aws request")]
374 AwsRequestSign(#[from] SigningError),
375
376 #[error("failed to parse signed aws url")]
377 AwsSignerInvalidUrl(url::ParseError),
378
379 #[error("failed to connect to tenant missing both IAM and secrets fields")]
380 InvalidTenantConfiguration,
381}
382
383impl DatabasePoolCache {
384 pub fn from_config(
385 aws_config: aws_config::SdkConfig,
386 config: DatabasePoolCacheConfig,
387 secrets_manager: SecretManager,
388 ) -> Self {
389 let mut pool_timeout = Duration::from_secs(config.cache_duration.unwrap_or(60 * 60 * 48));
390 let cache_duration = Duration::from_secs(config.cache_duration.unwrap_or(60 * 60 * 48));
391 let credentials_cache_duration =
392 Duration::from_secs(config.credentials_cache_duration.unwrap_or(60 * 60 * 12));
393
394 if config.root_iam && config.pool_timeout.is_none() {
397 tracing::debug!(
398 "IAM database auth is enabled with no pool timeout, setting short pool timeout within token duration"
399 );
400 pool_timeout = Duration::from_secs(60 * 10);
401 }
402
403 let cache_capacity = config.cache_capacity.unwrap_or(50);
404 let credentials_cache_capacity = config.credentials_cache_capacity.unwrap_or(50);
405
406 let cache = Cache::builder()
407 .time_to_live(pool_timeout)
408 .time_to_idle(cache_duration)
409 .max_capacity(cache_capacity)
410 .eviction_policy(EvictionPolicy::tiny_lfu())
411 .async_eviction_listener(|cache_key: Arc<String>, pool: DbPool, _cause| {
412 Box::pin(async move {
413 tracing::debug!(?cache_key, "database pool is no longer in use, closing");
414 pool.close().await
415 })
416 })
417 .build();
418
419 let connect_info_cache = Cache::builder()
420 .time_to_idle(credentials_cache_duration)
421 .max_capacity(credentials_cache_capacity)
422 .eviction_policy(EvictionPolicy::tiny_lfu())
423 .build();
424
425 Self {
426 aws_config,
427 host: config.host,
428 port: config.port,
429 root_secret_name: config.root_secret_name,
430 root_iam: config.root_iam,
431 cache,
432 connect_info_cache,
433 secrets_manager,
434 max_connections: config.max_connections.unwrap_or(10),
435 max_connections_root: config.max_connections_root.unwrap_or(2),
436 idle_timeout: Duration::from_secs(config.idle_timeout.unwrap_or(60 * 10)),
437 acquire_timeout: Duration::from_secs(config.acquire_timeout.unwrap_or(60)),
438 }
439 }
440
441 pub async fn get_root_pool(&self) -> Result<PgPool, DbConnectErr> {
443 match (self.root_secret_name.as_ref(), self.root_iam) {
444 (_, true) => {
445 self.get_pool_iam(ROOT_DATABASE_NAME, ROOT_DATABASE_ROLE_NAME)
446 .await
447 }
448
449 (Some(db_secret_name), _) => self.get_pool(ROOT_DATABASE_NAME, db_secret_name).await,
450
451 _ => Err(DbConnectErr::InvalidTenantConfiguration),
452 }
453 }
454
455 pub async fn get_tenant_pool(&self, tenant: &Tenant) -> Result<DbPool, DbConnectErr> {
457 match (
458 tenant.db_iam_user_name.as_ref(),
459 tenant.db_secret_name.as_ref(),
460 ) {
461 (Some(db_iam_user_name), _) => {
462 self.get_pool_iam(&tenant.db_name, db_iam_user_name).await
463 }
464 (_, Some(db_secret_name)) => self.get_pool(&tenant.db_name, db_secret_name).await,
465
466 _ => Err(DbConnectErr::InvalidTenantConfiguration),
467 }
468 }
469
470 pub async fn close_tenant_pool(&self, tenant: &Tenant) {
473 let cache_key = Self::tenant_cache_key(tenant);
474 if let Some(pool) = self.cache.remove(&cache_key).await {
475 pool.close().await;
476 }
477
478 self.cache.run_pending_tasks().await;
480 }
481
482 fn tenant_cache_key(tenant: &Tenant) -> String {
485 match (
486 tenant.db_secret_name.as_ref(),
487 tenant.db_iam_user_name.as_ref(),
488 ) {
489 (Some(db_secret_name), _) => {
490 format!("secret-{}-{}", &tenant.db_name, db_secret_name)
491 }
492 (_, Some(db_iam_user_name)) => {
493 format!("user-{}-{}", &tenant.db_name, db_iam_user_name)
494 }
495
496 _ => format!("db-{}", &tenant.db_name),
497 }
498 }
499
500 pub async fn flush(&self) {
502 self.cache.invalidate_all();
504 self.connect_info_cache.invalidate_all();
505 self.cache.run_pending_tasks().await;
506 }
507
508 pub async fn close_all(&self) {
510 for (_, value) in self.cache.iter() {
511 value.close().await;
512 }
513
514 self.flush().await;
515 }
516
517 async fn get_pool(&self, db_name: &str, secret_name: &str) -> Result<DbPool, DbConnectErr> {
520 let cache_key = format!("secret-{db_name}-{secret_name}");
521
522 let pool = self
523 .cache
524 .try_get_with(cache_key, async {
525 tracing::debug!(?db_name, "acquiring database pool");
526
527 let pool = self
528 .create_pool(db_name, secret_name)
529 .await
530 .map_err(Arc::new)?;
531
532 Ok(pool)
533 })
534 .await?;
535
536 Ok(pool)
537 }
538
539 async fn get_pool_iam(
542 &self,
543 db_name: &str,
544 db_role_name: &str,
545 ) -> Result<DbPool, DbConnectErr> {
546 let cache_key = format!("user-{db_name}-{db_role_name}");
547
548 let pool = self
549 .cache
550 .try_get_with(cache_key, async {
551 tracing::debug!(?db_name, "acquiring database pool (iam)");
552
553 let pool = self
554 .create_pool_iam(db_name, db_role_name)
555 .await
556 .map_err(Arc::new)?;
557
558 Ok(pool)
559 })
560 .await?;
561
562 Ok(pool)
563 }
564
565 async fn get_credentials(&self, secret_name: &str) -> Result<DbSecrets, DbConnectErr> {
567 if let Some(connect_info) = self.connect_info_cache.get(secret_name).await {
568 return Ok(connect_info);
569 }
570
571 let credentials = self
573 .secrets_manager
574 .parsed_secret::<DbSecrets>(secret_name)
575 .await
576 .map_err(|err| DbConnectErr::SecretsManager(Box::new(err)))?
577 .ok_or(DbConnectErr::MissingCredentials)?;
578
579 self.connect_info_cache
581 .insert(secret_name.to_string(), credentials.clone())
582 .await;
583
584 Ok(credentials)
585 }
586
587 async fn create_pool_iam(
589 &self,
590 db_name: &str,
591 db_role_name: &str,
592 ) -> Result<DbPool, DbConnectErr> {
593 tracing::debug!(?db_name, ?db_role_name, "creating db pool connection");
594
595 let options = iam_pool_connect_options(
596 &self.aws_config,
597 &self.host,
598 self.port,
599 db_name,
600 db_role_name,
601 )
602 .await?;
603
604 let max_connections = match db_name {
605 ROOT_DATABASE_NAME => self.max_connections_root,
606 _ => self.max_connections,
607 };
608
609 let pool = PgPoolOptions::new()
610 .max_connections(max_connections)
611 .acquire_timeout(self.acquire_timeout)
613 .idle_timeout(self.idle_timeout)
615 .connect_with(options)
616 .await
617 .map_err(DbConnectErr::Db)?;
618
619 tokio::spawn(iam_pool_maintenance_task(
620 pool.clone(),
621 self.aws_config.clone(),
622 self.host.clone(),
623 self.port,
624 db_name.to_string(),
625 db_role_name.to_string(),
626 ));
627
628 Ok(pool)
629 }
630
631 async fn create_pool(&self, db_name: &str, secret_name: &str) -> Result<DbPool, DbConnectErr> {
633 tracing::debug!(?db_name, ?secret_name, "creating db pool connection");
634
635 let credentials = self.get_credentials(secret_name).await?;
636 let options = PgConnectOptions::new()
637 .host(&self.host)
638 .port(self.port)
639 .username(&credentials.username)
640 .password(&credentials.password)
641 .database(db_name);
642
643 let max_connections = match db_name {
644 ROOT_DATABASE_NAME => self.max_connections_root,
645 _ => self.max_connections,
646 };
647
648 match PgPoolOptions::new()
649 .max_connections(max_connections)
650 .acquire_timeout(self.acquire_timeout)
652 .idle_timeout(self.idle_timeout)
654 .connect_with(options)
655 .await
656 {
657 Ok(value) => Ok(value),
659 Err(err) => {
660 self.connect_info_cache.remove(secret_name).await;
662 Err(DbConnectErr::Db(err))
663 }
664 }
665 }
666}
667
668async fn iam_pool_connect_options(
669 aws_config: &SdkConfig,
670 host: &str,
671 port: u16,
672 db_name: &str,
673 db_role_name: &str,
674) -> Result<PgConnectOptions, DbConnectErr> {
675 let token = create_rds_signed_token(aws_config, host, port, db_role_name).await?;
676
677 let options = PgConnectOptions::new()
678 .host(host)
679 .port(port)
680 .username(db_role_name)
681 .password(&token)
682 .database(db_name);
683
684 Ok(options)
685}
686
687async fn create_rds_signed_token(
688 aws_config: &SdkConfig,
689 host: &str,
690 port: u16,
691 user: &str,
692) -> Result<String, DbConnectErr> {
693 let credentials_provider = aws_config
694 .credentials_provider()
695 .ok_or(DbConnectErr::MissingCredentialsProvider)?;
696 let credentials = credentials_provider.provide_credentials().await?;
697 let identity = credentials.into();
698 let region = aws_config.region().ok_or(DbConnectErr::MissingRegion)?;
699
700 let mut signing_settings = SigningSettings::default();
701 signing_settings.expires_in = Some(Duration::from_secs(60 * 15));
702 signing_settings.signature_location = aws_sigv4::http_request::SignatureLocation::QueryParams;
703
704 let signing_params = aws_sigv4::sign::v4::SigningParams::builder()
705 .identity(&identity)
706 .region(region.as_ref())
707 .name("rds-db")
708 .time(SystemTime::now())
709 .settings(signing_settings)
710 .build()?;
711
712 let url = format!("https://{host}:{port}/?Action=connect&DBUser={user}");
713
714 let signable_request =
715 SignableRequest::new("GET", &url, std::iter::empty(), SignableBody::Bytes(&[]))?;
716
717 let (signing_instructions, _signature) =
718 sign(signable_request, &signing_params.into())?.into_parts();
719
720 let mut url = url::Url::parse(&url).map_err(DbConnectErr::AwsSignerInvalidUrl)?;
721 for (name, value) in signing_instructions.params() {
722 url.query_pairs_mut().append_pair(name, value);
723 }
724
725 let response = url.to_string().split_off("https://".len());
726 Ok(response)
727}
728
729async fn iam_pool_maintenance_task(
732 db: DbPool,
733 aws_config: SdkConfig,
734 host: String,
735 port: u16,
736 db_name: String,
737 db_role_name: String,
738) {
739 let interval = Duration::from_secs(60 * 10);
740
741 loop {
742 if db.is_closed() {
743 return;
744 }
745
746 match iam_pool_connect_options(&aws_config, &host, port, &db_name, &db_role_name).await {
747 Ok(options) => {
748 db.set_connect_options(options);
749 }
750 Err(error) => {
751 tracing::error!(?error, "failed to refresh IAM pool connect options");
752 }
753 }
754
755 sleep(interval).await;
756 }
757}