1use crate::{DbErr, DbPool, ROOT_DATABASE_NAME, ROOT_DATABASE_ROLE_NAME, models::tenant::Tenant};
28use aws_config::SdkConfig;
29use aws_credential_types::provider::{ProvideCredentials, error::CredentialsError};
30use aws_sigv4::{
31 http_request::{SignableBody, SignableRequest, SigningError, SigningSettings, sign},
32 sign::v4::signing_params,
33};
34use docbox_secrets::{SecretManager, SecretManagerError};
35use moka::{future::Cache, policy::EvictionPolicy};
36use serde::{Deserialize, Serialize};
37use sqlx::{
38 PgPool,
39 postgres::{PgConnectOptions, PgPoolOptions},
40};
41use std::time::Duration;
42use std::{num::ParseIntError, str::ParseBoolError};
43use std::{sync::Arc, time::SystemTime};
44use thiserror::Error;
45use tokio::time::sleep;
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct DatabasePoolCacheConfig {
50 pub host: String,
52 pub port: u16,
54
55 pub root_secret_name: Option<String>,
58
59 #[serde(default)]
62 pub root_iam: bool,
63
64 pub max_connections: Option<u32>,
75
76 pub max_connections_root: Option<u32>,
88
89 pub acquire_timeout: Option<u64>,
94
95 pub idle_timeout: Option<u64>,
101
102 pub pool_timeout: Option<u64>,
107
108 pub cache_duration: Option<u64>,
113
114 pub cache_capacity: Option<u64>,
124
125 pub credentials_cache_duration: Option<u64>,
131
132 pub credentials_cache_capacity: Option<u64>,
137}
138
139impl Default for DatabasePoolCacheConfig {
140 fn default() -> Self {
141 Self {
142 host: Default::default(),
143 port: 5432,
144 root_secret_name: Default::default(),
145 root_iam: false,
146 max_connections: None,
147 max_connections_root: None,
148 acquire_timeout: None,
149 idle_timeout: None,
150 pool_timeout: None,
151 cache_duration: None,
152 cache_capacity: None,
153 credentials_cache_duration: None,
154 credentials_cache_capacity: None,
155 }
156 }
157}
158
159#[derive(Debug, Error)]
160pub enum DatabasePoolCacheConfigError {
161 #[error("missing DOCBOX_DB_HOST environment variable")]
162 MissingDatabaseHost,
163 #[error("missing DOCBOX_DB_PORT environment variable")]
164 MissingDatabasePort,
165 #[error("invalid DOCBOX_DB_PORT environment variable")]
166 InvalidDatabasePort,
167 #[error("missing DOCBOX_DB_CREDENTIAL_NAME environment variable")]
168 MissingDatabaseSecretName,
169 #[error("invalid DOCBOX_DB_POOL_TIMEOUT environment variable")]
170 InvalidPoolTimeout(ParseIntError),
171 #[error("invalid DOCBOX_DB_IDLE_TIMEOUT environment variable")]
172 InvalidIdleTimeout(ParseIntError),
173 #[error("invalid DOCBOX_DB_ACQUIRE_TIMEOUT environment variable")]
174 InvalidAcquireTimeout(ParseIntError),
175 #[error("invalid DOCBOX_DB_CACHE_DURATION environment variable")]
176 InvalidCacheDuration(ParseIntError),
177 #[error("invalid DOCBOX_DB_CACHE_CAPACITY environment variable")]
178 InvalidCacheCapacity(ParseIntError),
179 #[error("invalid DOCBOX_DB_CREDENTIALS_CACHE_DURATION environment variable")]
180 InvalidCredentialsCacheDuration(ParseIntError),
181 #[error("invalid DOCBOX_DB_CREDENTIALS_CACHE_CAPACITY environment variable")]
182 InvalidCredentialsCacheCapacity(ParseIntError),
183 #[error("invalid DOCBOX_DB_ROOT_IAM environment variable")]
184 InvalidRootIam(ParseBoolError),
185}
186
187impl DatabasePoolCacheConfig {
188 pub fn from_env() -> Result<DatabasePoolCacheConfig, DatabasePoolCacheConfigError> {
189 let db_host: String = std::env::var("DOCBOX_DB_HOST")
190 .or(std::env::var("POSTGRES_HOST"))
191 .map_err(|_| DatabasePoolCacheConfigError::MissingDatabaseHost)?;
192 let db_port: u16 = std::env::var("DOCBOX_DB_PORT")
193 .or(std::env::var("POSTGRES_PORT"))
194 .map_err(|_| DatabasePoolCacheConfigError::MissingDatabasePort)?
195 .parse()
196 .map_err(|_| DatabasePoolCacheConfigError::InvalidDatabasePort)?;
197
198 let db_root_secret_name = std::env::var("DOCBOX_DB_CREDENTIAL_NAME").ok();
199 let db_root_iam = std::env::var("DOCBOX_DB_ROOT_IAM")
200 .ok()
201 .map(|value| value.parse::<bool>())
202 .transpose()
203 .map_err(DatabasePoolCacheConfigError::InvalidRootIam)?
204 .unwrap_or_default();
205
206 if !db_root_iam && db_root_secret_name.is_none() {
208 return Err(DatabasePoolCacheConfigError::MissingDatabaseSecretName);
209 }
210
211 let max_connections: Option<u32> = std::env::var("DOCBOX_DB_MAX_CONNECTIONS")
212 .ok()
213 .and_then(|value| value.parse().ok());
214 let max_connections_root: Option<u32> = std::env::var("DOCBOX_DB_MAX_ROOT_CONNECTIONS")
215 .ok()
216 .and_then(|value| value.parse().ok());
217
218 let acquire_timeout: Option<u64> = match std::env::var("DOCBOX_DB_ACQUIRE_TIMEOUT") {
219 Ok(value) => Some(
220 value
221 .parse::<u64>()
222 .map_err(DatabasePoolCacheConfigError::InvalidAcquireTimeout)?,
223 ),
224 Err(_) => None,
225 };
226
227 let pool_timeout: Option<u64> = match std::env::var("DOCBOX_DB_POOL_TIMEOUT") {
228 Ok(value) => Some(
229 value
230 .parse::<u64>()
231 .map_err(DatabasePoolCacheConfigError::InvalidPoolTimeout)?,
232 ),
233 Err(_) => None,
234 };
235
236 let idle_timeout: Option<u64> = match std::env::var("DOCBOX_DB_IDLE_TIMEOUT") {
237 Ok(value) => Some(
238 value
239 .parse::<u64>()
240 .map_err(DatabasePoolCacheConfigError::InvalidIdleTimeout)?,
241 ),
242 Err(_) => None,
243 };
244
245 let cache_duration: Option<u64> = match std::env::var("DOCBOX_DB_CACHE_DURATION") {
246 Ok(value) => Some(
247 value
248 .parse::<u64>()
249 .map_err(DatabasePoolCacheConfigError::InvalidCacheDuration)?,
250 ),
251 Err(_) => None,
252 };
253
254 let cache_capacity: Option<u64> = match std::env::var("DOCBOX_DB_CACHE_CAPACITY") {
255 Ok(value) => Some(
256 value
257 .parse::<u64>()
258 .map_err(DatabasePoolCacheConfigError::InvalidCacheCapacity)?,
259 ),
260 Err(_) => None,
261 };
262
263 let credentials_cache_duration: Option<u64> =
264 match std::env::var("DOCBOX_DB_CREDENTIALS_CACHE_DURATION") {
265 Ok(value) => Some(
266 value
267 .parse::<u64>()
268 .map_err(DatabasePoolCacheConfigError::InvalidCredentialsCacheDuration)?,
269 ),
270 Err(_) => None,
271 };
272
273 let credentials_cache_capacity: Option<u64> =
274 match std::env::var("DOCBOX_DB_CREDENTIALS_CACHE_CAPACITY") {
275 Ok(value) => Some(
276 value
277 .parse::<u64>()
278 .map_err(DatabasePoolCacheConfigError::InvalidCredentialsCacheCapacity)?,
279 ),
280 Err(_) => None,
281 };
282
283 Ok(DatabasePoolCacheConfig {
284 host: db_host,
285 port: db_port,
286 root_iam: db_root_iam,
287 root_secret_name: db_root_secret_name,
288 max_connections,
289 max_connections_root,
290 acquire_timeout,
291 pool_timeout,
292 idle_timeout,
293 cache_duration,
294 cache_capacity,
295 credentials_cache_duration,
296 credentials_cache_capacity,
297 })
298 }
299}
300
301pub struct DatabasePoolCache {
303 aws_config: aws_config::SdkConfig,
305
306 host: String,
308
309 port: u16,
311
312 root_secret_name: Option<String>,
317
318 root_iam: bool,
321
322 cache: Cache<String, DbPool>,
324
325 connect_info_cache: Cache<String, DbSecrets>,
328
329 secrets_manager: SecretManager,
331
332 max_connections: u32,
334 max_connections_root: u32,
336
337 acquire_timeout: Duration,
338 idle_timeout: Duration,
339}
340
341#[derive(Debug, Clone, Serialize, Deserialize)]
343pub struct DbSecrets {
344 pub username: String,
345 pub password: String,
346}
347
348#[derive(Debug, Error)]
349pub enum DbConnectErr {
350 #[error("database credentials not found in secrets manager")]
351 MissingCredentials,
352
353 #[error(transparent)]
354 SecretsManager(Box<SecretManagerError>),
355
356 #[error(transparent)]
357 Db(#[from] DbErr),
358
359 #[error(transparent)]
360 Shared(#[from] Arc<DbConnectErr>),
361
362 #[error("missing aws credentials provider")]
363 MissingCredentialsProvider,
364
365 #[error("failed to provide aws credentials")]
366 AwsCredentials(#[from] CredentialsError),
367
368 #[error("aws configuration missing region")]
369 MissingRegion,
370
371 #[error("failed to build aws signature")]
372 AwsSigner(#[from] signing_params::BuildError),
373
374 #[error("failed to sign aws request")]
375 AwsRequestSign(#[from] SigningError),
376
377 #[error("failed to parse signed aws url")]
378 AwsSignerInvalidUrl(url::ParseError),
379
380 #[error("failed to connect to tenant missing both IAM and secrets fields")]
381 InvalidTenantConfiguration,
382}
383
384impl DatabasePoolCache {
385 pub fn from_config(
386 aws_config: aws_config::SdkConfig,
387 config: DatabasePoolCacheConfig,
388 secrets_manager: SecretManager,
389 ) -> Self {
390 let mut pool_timeout = Duration::from_secs(config.cache_duration.unwrap_or(60 * 60 * 48));
391 let cache_duration = Duration::from_secs(config.cache_duration.unwrap_or(60 * 60 * 48));
392 let credentials_cache_duration =
393 Duration::from_secs(config.credentials_cache_duration.unwrap_or(60 * 60 * 12));
394
395 if config.root_iam && config.pool_timeout.is_none() {
398 tracing::debug!(
399 "IAM database auth is enabled with no pool timeout, setting short pool timeout within token duration"
400 );
401 pool_timeout = Duration::from_secs(60 * 10);
402 }
403
404 let cache_capacity = config.cache_capacity.unwrap_or(50);
405 let credentials_cache_capacity = config.credentials_cache_capacity.unwrap_or(50);
406
407 let cache = Cache::builder()
408 .time_to_live(pool_timeout)
409 .time_to_idle(cache_duration)
410 .max_capacity(cache_capacity)
411 .eviction_policy(EvictionPolicy::tiny_lfu())
412 .async_eviction_listener(|cache_key: Arc<String>, pool: DbPool, _cause| {
413 Box::pin(async move {
414 tracing::debug!(?cache_key, "database pool is no longer in use, closing");
415 pool.close().await
416 })
417 })
418 .build();
419
420 let connect_info_cache = Cache::builder()
421 .time_to_idle(credentials_cache_duration)
422 .max_capacity(credentials_cache_capacity)
423 .eviction_policy(EvictionPolicy::tiny_lfu())
424 .build();
425
426 Self {
427 aws_config,
428 host: config.host,
429 port: config.port,
430 root_secret_name: config.root_secret_name,
431 root_iam: config.root_iam,
432 cache,
433 connect_info_cache,
434 secrets_manager,
435 max_connections: config.max_connections.unwrap_or(10),
436 max_connections_root: config.max_connections_root.unwrap_or(2),
437 idle_timeout: Duration::from_secs(config.idle_timeout.unwrap_or(60 * 10)),
438 acquire_timeout: Duration::from_secs(config.acquire_timeout.unwrap_or(60)),
439 }
440 }
441
442 pub async fn get_root_pool(&self) -> Result<PgPool, DbConnectErr> {
444 match (self.root_secret_name.as_ref(), self.root_iam) {
445 (_, true) => {
446 self.get_pool_iam(ROOT_DATABASE_NAME, ROOT_DATABASE_ROLE_NAME)
447 .await
448 }
449
450 (Some(db_secret_name), _) => self.get_pool(ROOT_DATABASE_NAME, db_secret_name).await,
451
452 _ => Err(DbConnectErr::InvalidTenantConfiguration),
453 }
454 }
455
456 pub async fn get_tenant_pool(&self, tenant: &Tenant) -> Result<DbPool, DbConnectErr> {
458 match (
459 tenant.db_iam_user_name.as_ref(),
460 tenant.db_secret_name.as_ref(),
461 ) {
462 (Some(db_iam_user_name), _) => {
463 self.get_pool_iam(&tenant.db_name, db_iam_user_name).await
464 }
465 (_, Some(db_secret_name)) => self.get_pool(&tenant.db_name, db_secret_name).await,
466
467 _ => Err(DbConnectErr::InvalidTenantConfiguration),
468 }
469 }
470
471 pub async fn close_tenant_pool(&self, tenant: &Tenant) {
474 let cache_key = Self::tenant_cache_key(tenant);
475 if let Some(pool) = self.cache.remove(&cache_key).await {
476 pool.close().await;
477 }
478
479 self.cache.run_pending_tasks().await;
481 }
482
483 fn tenant_cache_key(tenant: &Tenant) -> String {
486 match (
487 tenant.db_secret_name.as_ref(),
488 tenant.db_iam_user_name.as_ref(),
489 ) {
490 (Some(db_secret_name), _) => {
491 format!("secret-{}-{}", &tenant.db_name, db_secret_name)
492 }
493 (_, Some(db_iam_user_name)) => {
494 format!("user-{}-{}", &tenant.db_name, db_iam_user_name)
495 }
496
497 _ => format!("db-{}", &tenant.db_name),
498 }
499 }
500
501 pub async fn flush(&self) {
503 self.cache.invalidate_all();
505 self.connect_info_cache.invalidate_all();
506 self.cache.run_pending_tasks().await;
507 }
508
509 pub async fn close_all(&self) {
511 for (_, value) in self.cache.iter() {
512 value.close().await;
513 }
514
515 self.flush().await;
516 }
517
518 async fn get_pool(&self, db_name: &str, secret_name: &str) -> Result<DbPool, DbConnectErr> {
521 let cache_key = format!("secret-{db_name}-{secret_name}");
522
523 let pool = self
524 .cache
525 .try_get_with(cache_key, async {
526 tracing::debug!(?db_name, "acquiring database pool");
527
528 let pool = self
529 .create_pool(db_name, secret_name)
530 .await
531 .map_err(Arc::new)?;
532
533 Ok(pool)
534 })
535 .await?;
536
537 Ok(pool)
538 }
539
540 async fn get_pool_iam(
543 &self,
544 db_name: &str,
545 db_role_name: &str,
546 ) -> Result<DbPool, DbConnectErr> {
547 let cache_key = format!("user-{db_name}-{db_role_name}");
548
549 let pool = self
550 .cache
551 .try_get_with(cache_key, async {
552 tracing::debug!(?db_name, "acquiring database pool (iam)");
553
554 let pool = self
555 .create_pool_iam(db_name, db_role_name)
556 .await
557 .map_err(Arc::new)?;
558
559 Ok(pool)
560 })
561 .await?;
562
563 Ok(pool)
564 }
565
566 async fn get_credentials(&self, secret_name: &str) -> Result<DbSecrets, DbConnectErr> {
568 if let Some(connect_info) = self.connect_info_cache.get(secret_name).await {
569 return Ok(connect_info);
570 }
571
572 let credentials = self
574 .secrets_manager
575 .parsed_secret::<DbSecrets>(secret_name)
576 .await
577 .map_err(|err| DbConnectErr::SecretsManager(Box::new(err)))?
578 .ok_or(DbConnectErr::MissingCredentials)?;
579
580 self.connect_info_cache
582 .insert(secret_name.to_string(), credentials.clone())
583 .await;
584
585 Ok(credentials)
586 }
587
588 async fn create_pool_iam(
590 &self,
591 db_name: &str,
592 db_role_name: &str,
593 ) -> Result<DbPool, DbConnectErr> {
594 tracing::debug!(?db_name, ?db_role_name, "creating db pool connection");
595
596 let options = iam_pool_connect_options(
597 &self.aws_config,
598 &self.host,
599 self.port,
600 db_name,
601 db_role_name,
602 )
603 .await?;
604
605 let max_connections = match db_name {
606 ROOT_DATABASE_NAME => self.max_connections_root,
607 _ => self.max_connections,
608 };
609
610 let pool = PgPoolOptions::new()
611 .max_connections(max_connections)
612 .acquire_timeout(self.acquire_timeout)
614 .idle_timeout(self.idle_timeout)
616 .connect_with(options)
617 .await
618 .map_err(DbConnectErr::Db)?;
619
620 tokio::spawn(iam_pool_maintenance_task(
621 pool.clone(),
622 self.aws_config.clone(),
623 self.host.clone(),
624 self.port,
625 db_name.to_string(),
626 db_role_name.to_string(),
627 ));
628
629 Ok(pool)
630 }
631
632 async fn create_pool(&self, db_name: &str, secret_name: &str) -> Result<DbPool, DbConnectErr> {
634 tracing::debug!(?db_name, ?secret_name, "creating db pool connection");
635
636 let credentials = self.get_credentials(secret_name).await?;
637 let options = PgConnectOptions::new()
638 .host(&self.host)
639 .port(self.port)
640 .username(&credentials.username)
641 .password(&credentials.password)
642 .database(db_name);
643
644 let max_connections = match db_name {
645 ROOT_DATABASE_NAME => self.max_connections_root,
646 _ => self.max_connections,
647 };
648
649 match PgPoolOptions::new()
650 .max_connections(max_connections)
651 .acquire_timeout(self.acquire_timeout)
653 .idle_timeout(self.idle_timeout)
655 .connect_with(options)
656 .await
657 {
658 Ok(value) => Ok(value),
660 Err(err) => {
661 self.connect_info_cache.remove(secret_name).await;
663 Err(DbConnectErr::Db(err))
664 }
665 }
666 }
667}
668
669async fn iam_pool_connect_options(
670 aws_config: &SdkConfig,
671 host: &str,
672 port: u16,
673 db_name: &str,
674 db_role_name: &str,
675) -> Result<PgConnectOptions, DbConnectErr> {
676 let token = create_rds_signed_token(aws_config, host, port, db_role_name).await?;
677
678 let options = PgConnectOptions::new()
679 .host(host)
680 .port(port)
681 .username(db_role_name)
682 .password(&token)
683 .database(db_name);
684
685 Ok(options)
686}
687
688async fn create_rds_signed_token(
689 aws_config: &SdkConfig,
690 host: &str,
691 port: u16,
692 user: &str,
693) -> Result<String, DbConnectErr> {
694 let credentials_provider = aws_config
695 .credentials_provider()
696 .ok_or(DbConnectErr::MissingCredentialsProvider)?;
697 let credentials = credentials_provider.provide_credentials().await?;
698 let identity = credentials.into();
699 let region = aws_config.region().ok_or(DbConnectErr::MissingRegion)?;
700
701 let mut signing_settings = SigningSettings::default();
702 signing_settings.expires_in = Some(Duration::from_secs(60 * 15));
703 signing_settings.signature_location = aws_sigv4::http_request::SignatureLocation::QueryParams;
704
705 let signing_params = aws_sigv4::sign::v4::SigningParams::builder()
706 .identity(&identity)
707 .region(region.as_ref())
708 .name("rds-db")
709 .time(SystemTime::now())
710 .settings(signing_settings)
711 .build()?;
712
713 let url = format!("https://{host}:{port}/?Action=connect&DBUser={user}");
714
715 let signable_request =
716 SignableRequest::new("GET", &url, std::iter::empty(), SignableBody::Bytes(&[]))?;
717
718 let (signing_instructions, _signature) =
719 sign(signable_request, &signing_params.into())?.into_parts();
720
721 let mut url = url::Url::parse(&url).map_err(DbConnectErr::AwsSignerInvalidUrl)?;
722 for (name, value) in signing_instructions.params() {
723 url.query_pairs_mut().append_pair(name, value);
724 }
725
726 let response = url.to_string().split_off("https://".len());
727 Ok(response)
728}
729
730async fn iam_pool_maintenance_task(
733 db: DbPool,
734 aws_config: SdkConfig,
735 host: String,
736 port: u16,
737 db_name: String,
738 db_role_name: String,
739) {
740 let interval = Duration::from_secs(60 * 10);
741
742 loop {
743 if db.is_closed() {
744 return;
745 }
746
747 match iam_pool_connect_options(&aws_config, &host, port, &db_name, &db_role_name).await {
748 Ok(options) => {
749 db.set_connect_options(options);
750 }
751 Err(error) => {
752 tracing::error!(?error, "failed to refresh IAM pool connect options");
753 }
754 }
755
756 sleep(interval).await;
757 }
758}