1use crate::{DbErr, DbPool, ROOT_DATABASE_NAME, ROOT_DATABASE_ROLE_NAME, models::tenant::Tenant};
27use aws_credential_types::provider::{ProvideCredentials, error::CredentialsError};
28use aws_sigv4::{
29 http_request::{SignableBody, SignableRequest, SigningError, SigningSettings, sign},
30 sign::v4::signing_params,
31};
32use docbox_secrets::{SecretManager, SecretManagerError};
33use moka::{future::Cache, policy::EvictionPolicy};
34use serde::{Deserialize, Serialize};
35use sqlx::{
36 PgPool,
37 postgres::{PgConnectOptions, PgPoolOptions},
38};
39use std::time::Duration;
40use std::{num::ParseIntError, str::ParseBoolError};
41use std::{sync::Arc, time::SystemTime};
42use thiserror::Error;
43
/// Configuration for a [`DatabasePoolCache`].
///
/// Every `Option` field may be left as `None`, in which case
/// `DatabasePoolCache::from_config` substitutes a built-in default.
/// All durations are expressed in whole seconds.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabasePoolCacheConfig {
    /// Database server host (`DOCBOX_DB_HOST`, falling back to `POSTGRES_HOST`).
    pub host: String,
    /// Database server port (`DOCBOX_DB_PORT`, falling back to `POSTGRES_PORT`).
    pub port: u16,

    /// Name of the secret holding the root database credentials
    /// (`DOCBOX_DB_CREDENTIAL_NAME`). Required by `from_env` unless `root_iam` is set.
    pub root_secret_name: Option<String>,

    /// Connect to the root database using IAM authentication instead of a
    /// stored secret (`DOCBOX_DB_ROOT_IAM`). Defaults to `false` when absent.
    #[serde(default)]
    pub root_iam: bool,

    /// Maximum connections per tenant database pool
    /// (`DOCBOX_DB_MAX_CONNECTIONS`, default 10).
    pub max_connections: Option<u32>,

    /// Maximum connections for the root database pool
    /// (`DOCBOX_DB_MAX_ROOT_CONNECTIONS`, default 2).
    pub max_connections_root: Option<u32>,

    /// Seconds to wait when acquiring a connection from a pool
    /// (`DOCBOX_DB_ACQUIRE_TIMEOUT`, default 60).
    pub acquire_timeout: Option<u64>,

    /// Seconds a pooled connection may sit idle before it is closed
    /// (`DOCBOX_DB_IDLE_TIMEOUT`, default 600).
    pub idle_timeout: Option<u64>,

    /// Requested time-to-live, in seconds, for pools held in the cache
    /// (`DOCBOX_DB_POOL_TIMEOUT`). When unset and `root_iam` is enabled a
    /// 15 minute lifetime is used.
    pub pool_timeout: Option<u64>,

    /// Seconds a cached pool may go unused before it is evicted
    /// (`DOCBOX_DB_CACHE_DURATION`, default 48 hours).
    pub cache_duration: Option<u64>,

    /// Maximum number of pools kept in the cache
    /// (`DOCBOX_DB_CACHE_CAPACITY`, default 50).
    pub cache_capacity: Option<u64>,

    /// Seconds cached database credentials may go unused before eviction
    /// (`DOCBOX_DB_CREDENTIALS_CACHE_DURATION`, default 12 hours).
    pub credentials_cache_duration: Option<u64>,

    /// Maximum number of credential entries kept in the cache
    /// (`DOCBOX_DB_CREDENTIALS_CACHE_CAPACITY`, default 50).
    pub credentials_cache_capacity: Option<u64>,
}
135
136impl Default for DatabasePoolCacheConfig {
137 fn default() -> Self {
138 Self {
139 host: Default::default(),
140 port: 5432,
141 root_secret_name: Default::default(),
142 root_iam: false,
143 max_connections: None,
144 max_connections_root: None,
145 acquire_timeout: None,
146 idle_timeout: None,
147 pool_timeout: None,
148 cache_duration: None,
149 cache_capacity: None,
150 credentials_cache_duration: None,
151 credentials_cache_capacity: None,
152 }
153 }
154}
155
156#[derive(Debug, Error)]
157pub enum DatabasePoolCacheConfigError {
158 #[error("missing DOCBOX_DB_HOST environment variable")]
159 MissingDatabaseHost,
160 #[error("missing DOCBOX_DB_PORT environment variable")]
161 MissingDatabasePort,
162 #[error("invalid DOCBOX_DB_PORT environment variable")]
163 InvalidDatabasePort,
164 #[error("missing DOCBOX_DB_CREDENTIAL_NAME environment variable")]
165 MissingDatabaseSecretName,
166 #[error("invalid DOCBOX_DB_POOL_TIMEOUT environment variable")]
167 InvalidPoolTimeout(ParseIntError),
168 #[error("invalid DOCBOX_DB_IDLE_TIMEOUT environment variable")]
169 InvalidIdleTimeout(ParseIntError),
170 #[error("invalid DOCBOX_DB_ACQUIRE_TIMEOUT environment variable")]
171 InvalidAcquireTimeout(ParseIntError),
172 #[error("invalid DOCBOX_DB_CACHE_DURATION environment variable")]
173 InvalidCacheDuration(ParseIntError),
174 #[error("invalid DOCBOX_DB_CACHE_CAPACITY environment variable")]
175 InvalidCacheCapacity(ParseIntError),
176 #[error("invalid DOCBOX_DB_CREDENTIALS_CACHE_DURATION environment variable")]
177 InvalidCredentialsCacheDuration(ParseIntError),
178 #[error("invalid DOCBOX_DB_CREDENTIALS_CACHE_CAPACITY environment variable")]
179 InvalidCredentialsCacheCapacity(ParseIntError),
180 #[error("invalid DOCBOX_DB_ROOT_IAM environment variable")]
181 InvalidRootIam(ParseBoolError),
182}
183
184impl DatabasePoolCacheConfig {
185 pub fn from_env() -> Result<DatabasePoolCacheConfig, DatabasePoolCacheConfigError> {
186 let db_host: String = std::env::var("DOCBOX_DB_HOST")
187 .or(std::env::var("POSTGRES_HOST"))
188 .map_err(|_| DatabasePoolCacheConfigError::MissingDatabaseHost)?;
189 let db_port: u16 = std::env::var("DOCBOX_DB_PORT")
190 .or(std::env::var("POSTGRES_PORT"))
191 .map_err(|_| DatabasePoolCacheConfigError::MissingDatabasePort)?
192 .parse()
193 .map_err(|_| DatabasePoolCacheConfigError::InvalidDatabasePort)?;
194
195 let db_root_secret_name = std::env::var("DOCBOX_DB_CREDENTIAL_NAME").ok();
196 let db_root_iam = std::env::var("DOCBOX_DB_ROOT_IAM")
197 .ok()
198 .map(|value| value.parse::<bool>())
199 .transpose()
200 .map_err(DatabasePoolCacheConfigError::InvalidRootIam)?
201 .unwrap_or_default();
202
203 if !db_root_iam && db_root_secret_name.is_none() {
205 return Err(DatabasePoolCacheConfigError::MissingDatabaseSecretName);
206 }
207
208 let max_connections: Option<u32> = std::env::var("DOCBOX_DB_MAX_CONNECTIONS")
209 .ok()
210 .and_then(|value| value.parse().ok());
211 let max_connections_root: Option<u32> = std::env::var("DOCBOX_DB_MAX_ROOT_CONNECTIONS")
212 .ok()
213 .and_then(|value| value.parse().ok());
214
215 let acquire_timeout: Option<u64> = match std::env::var("DOCBOX_DB_ACQUIRE_TIMEOUT") {
216 Ok(value) => Some(
217 value
218 .parse::<u64>()
219 .map_err(DatabasePoolCacheConfigError::InvalidAcquireTimeout)?,
220 ),
221 Err(_) => None,
222 };
223
224 let pool_timeout: Option<u64> = match std::env::var("DOCBOX_DB_POOL_TIMEOUT") {
225 Ok(value) => Some(
226 value
227 .parse::<u64>()
228 .map_err(DatabasePoolCacheConfigError::InvalidPoolTimeout)?,
229 ),
230 Err(_) => None,
231 };
232
233 let idle_timeout: Option<u64> = match std::env::var("DOCBOX_DB_IDLE_TIMEOUT") {
234 Ok(value) => Some(
235 value
236 .parse::<u64>()
237 .map_err(DatabasePoolCacheConfigError::InvalidIdleTimeout)?,
238 ),
239 Err(_) => None,
240 };
241
242 let cache_duration: Option<u64> = match std::env::var("DOCBOX_DB_CACHE_DURATION") {
243 Ok(value) => Some(
244 value
245 .parse::<u64>()
246 .map_err(DatabasePoolCacheConfigError::InvalidCacheDuration)?,
247 ),
248 Err(_) => None,
249 };
250
251 let cache_capacity: Option<u64> = match std::env::var("DOCBOX_DB_CACHE_CAPACITY") {
252 Ok(value) => Some(
253 value
254 .parse::<u64>()
255 .map_err(DatabasePoolCacheConfigError::InvalidCacheCapacity)?,
256 ),
257 Err(_) => None,
258 };
259
260 let credentials_cache_duration: Option<u64> =
261 match std::env::var("DOCBOX_DB_CREDENTIALS_CACHE_DURATION") {
262 Ok(value) => Some(
263 value
264 .parse::<u64>()
265 .map_err(DatabasePoolCacheConfigError::InvalidCredentialsCacheDuration)?,
266 ),
267 Err(_) => None,
268 };
269
270 let credentials_cache_capacity: Option<u64> =
271 match std::env::var("DOCBOX_DB_CREDENTIALS_CACHE_CAPACITY") {
272 Ok(value) => Some(
273 value
274 .parse::<u64>()
275 .map_err(DatabasePoolCacheConfigError::InvalidCredentialsCacheCapacity)?,
276 ),
277 Err(_) => None,
278 };
279
280 Ok(DatabasePoolCacheConfig {
281 host: db_host,
282 port: db_port,
283 root_iam: db_root_iam,
284 root_secret_name: db_root_secret_name,
285 max_connections,
286 max_connections_root,
287 acquire_timeout,
288 pool_timeout,
289 idle_timeout,
290 cache_duration,
291 cache_capacity,
292 credentials_cache_duration,
293 credentials_cache_capacity,
294 })
295 }
296}
297
/// Cache of database connection pools, along with the database credentials
/// used to create them.
pub struct DatabasePoolCache {
    /// AWS configuration supplying the credentials provider and region used
    /// when signing IAM authentication tokens.
    aws_config: aws_config::SdkConfig,

    /// Database server host.
    host: String,

    /// Database server port.
    port: u16,

    /// Name of the secret holding the root database credentials, used when
    /// `root_iam` is disabled.
    root_secret_name: Option<String>,

    /// Whether the root database is accessed using IAM authentication.
    root_iam: bool,

    /// Connection pools keyed by `secret-{db}-{secret}` or `user-{db}-{role}`.
    cache: Cache<String, DbPool>,

    /// Database credentials keyed by secret name.
    connect_info_cache: Cache<String, DbSecrets>,

    /// Secrets manager the database credentials are loaded from.
    secrets_manager: SecretManager,

    /// Maximum connections for pools other than the root database pool.
    max_connections: u32,
    /// Maximum connections for the root database pool.
    max_connections_root: u32,

    /// Time to wait when acquiring a connection from a pool.
    acquire_timeout: Duration,
    /// Time a pooled connection may sit idle before it is closed.
    idle_timeout: Duration,
}
337
338#[derive(Debug, Clone, Serialize, Deserialize)]
340pub struct DbSecrets {
341 pub username: String,
342 pub password: String,
343}
344
345#[derive(Debug, Error)]
346pub enum DbConnectErr {
347 #[error("database credentials not found in secrets manager")]
348 MissingCredentials,
349
350 #[error(transparent)]
351 SecretsManager(Box<SecretManagerError>),
352
353 #[error(transparent)]
354 Db(#[from] DbErr),
355
356 #[error(transparent)]
357 Shared(#[from] Arc<DbConnectErr>),
358
359 #[error("missing aws credentials provider")]
360 MissingCredentialsProvider,
361
362 #[error("failed to provide aws credentials")]
363 AwsCredentials(#[from] CredentialsError),
364
365 #[error("aws configuration missing region")]
366 MissingRegion,
367
368 #[error("failed to build aws signature")]
369 AwsSigner(#[from] signing_params::BuildError),
370
371 #[error("failed to sign aws request")]
372 AwsRequestSign(#[from] SigningError),
373
374 #[error("failed to parse signed aws url")]
375 AwsSignerInvalidUrl(url::ParseError),
376
377 #[error("failed to connect to tenant missing both IAM and secrets fields")]
378 InvalidTenantConfiguration,
379}
380
381impl DatabasePoolCache {
382 pub fn from_config(
383 aws_config: aws_config::SdkConfig,
384 config: DatabasePoolCacheConfig,
385 secrets_manager: SecretManager,
386 ) -> Self {
387 let mut pool_timeout = Duration::from_secs(config.cache_duration.unwrap_or(60 * 60 * 48));
388 let cache_duration = Duration::from_secs(config.cache_duration.unwrap_or(60 * 60 * 48));
389 let credentials_cache_duration =
390 Duration::from_secs(config.credentials_cache_duration.unwrap_or(60 * 60 * 12));
391
392 if config.root_iam && config.pool_timeout.is_none() {
395 tracing::debug!(
396 "IAM database auth is enabled with no pool timeout, setting short pool timeout within token duration"
397 );
398 pool_timeout = Duration::from_secs(60 * 15);
399 }
400
401 let cache_capacity = config.cache_capacity.unwrap_or(50);
402 let credentials_cache_capacity = config.credentials_cache_capacity.unwrap_or(50);
403
404 let cache = Cache::builder()
405 .time_to_live(pool_timeout)
406 .time_to_idle(cache_duration)
407 .max_capacity(cache_capacity)
408 .eviction_policy(EvictionPolicy::tiny_lfu())
409 .async_eviction_listener(|cache_key: Arc<String>, pool: DbPool, _cause| {
410 Box::pin(async move {
411 tracing::debug!(?cache_key, "database pool is no longer in use, closing");
412 pool.close().await
413 })
414 })
415 .build();
416
417 let connect_info_cache = Cache::builder()
418 .time_to_idle(credentials_cache_duration)
419 .max_capacity(credentials_cache_capacity)
420 .eviction_policy(EvictionPolicy::tiny_lfu())
421 .build();
422
423 Self {
424 aws_config,
425 host: config.host,
426 port: config.port,
427 root_secret_name: config.root_secret_name,
428 root_iam: config.root_iam,
429 cache,
430 connect_info_cache,
431 secrets_manager,
432 max_connections: config.max_connections.unwrap_or(10),
433 max_connections_root: config.max_connections_root.unwrap_or(2),
434 idle_timeout: Duration::from_secs(config.idle_timeout.unwrap_or(60 * 10)),
435 acquire_timeout: Duration::from_secs(config.acquire_timeout.unwrap_or(60)),
436 }
437 }
438
439 pub async fn get_root_pool(&self) -> Result<PgPool, DbConnectErr> {
441 match (self.root_secret_name.as_ref(), self.root_iam) {
442 (_, true) => {
443 self.get_pool_iam(ROOT_DATABASE_NAME, ROOT_DATABASE_ROLE_NAME)
444 .await
445 }
446
447 (Some(db_secret_name), _) => self.get_pool(ROOT_DATABASE_NAME, db_secret_name).await,
448
449 _ => Err(DbConnectErr::InvalidTenantConfiguration),
450 }
451 }
452
453 pub async fn get_tenant_pool(&self, tenant: &Tenant) -> Result<DbPool, DbConnectErr> {
455 match (
456 tenant.db_iam_user_name.as_ref(),
457 tenant.db_secret_name.as_ref(),
458 ) {
459 (Some(db_iam_user_name), _) => {
460 self.get_pool_iam(&tenant.db_name, db_iam_user_name).await
461 }
462 (_, Some(db_secret_name)) => self.get_pool(&tenant.db_name, db_secret_name).await,
463
464 _ => Err(DbConnectErr::InvalidTenantConfiguration),
465 }
466 }
467
468 pub async fn close_tenant_pool(&self, tenant: &Tenant) {
471 let cache_key = Self::tenant_cache_key(tenant);
472 if let Some(pool) = self.cache.remove(&cache_key).await {
473 pool.close().await;
474 }
475
476 self.cache.run_pending_tasks().await;
478 }
479
480 fn tenant_cache_key(tenant: &Tenant) -> String {
483 match (
484 tenant.db_secret_name.as_ref(),
485 tenant.db_iam_user_name.as_ref(),
486 ) {
487 (Some(db_secret_name), _) => {
488 format!("secret-{}-{}", &tenant.db_name, db_secret_name)
489 }
490 (_, Some(db_iam_user_name)) => {
491 format!("user-{}-{}", &tenant.db_name, db_iam_user_name)
492 }
493
494 _ => format!("db-{}", &tenant.db_name),
495 }
496 }
497
498 pub async fn flush(&self) {
500 self.cache.invalidate_all();
502 self.connect_info_cache.invalidate_all();
503 self.cache.run_pending_tasks().await;
504 }
505
506 pub async fn close_all(&self) {
508 for (_, value) in self.cache.iter() {
509 value.close().await;
510 }
511
512 self.flush().await;
513 }
514
515 async fn get_pool(&self, db_name: &str, secret_name: &str) -> Result<DbPool, DbConnectErr> {
518 let cache_key = format!("secret-{db_name}-{secret_name}");
519
520 let pool = self
521 .cache
522 .try_get_with(cache_key, async {
523 tracing::debug!(?db_name, "acquiring database pool");
524
525 let pool = self
526 .create_pool(db_name, secret_name)
527 .await
528 .map_err(Arc::new)?;
529
530 Ok(pool)
531 })
532 .await?;
533
534 Ok(pool)
535 }
536
537 async fn get_pool_iam(
540 &self,
541 db_name: &str,
542 db_role_name: &str,
543 ) -> Result<DbPool, DbConnectErr> {
544 let cache_key = format!("user-{db_name}-{db_role_name}");
545
546 let pool = self
547 .cache
548 .try_get_with(cache_key, async {
549 tracing::debug!(?db_name, "acquiring database pool (iam)");
550
551 let pool = self
552 .create_pool_iam(db_name, db_role_name)
553 .await
554 .map_err(Arc::new)?;
555
556 Ok(pool)
557 })
558 .await?;
559
560 Ok(pool)
561 }
562
563 async fn get_credentials(&self, secret_name: &str) -> Result<DbSecrets, DbConnectErr> {
565 if let Some(connect_info) = self.connect_info_cache.get(secret_name).await {
566 return Ok(connect_info);
567 }
568
569 let credentials = self
571 .secrets_manager
572 .parsed_secret::<DbSecrets>(secret_name)
573 .await
574 .map_err(|err| DbConnectErr::SecretsManager(Box::new(err)))?
575 .ok_or(DbConnectErr::MissingCredentials)?;
576
577 self.connect_info_cache
579 .insert(secret_name.to_string(), credentials.clone())
580 .await;
581
582 Ok(credentials)
583 }
584
585 async fn create_rds_signed_token(
586 &self,
587 host: &str,
588 port: u16,
589 user: &str,
590 ) -> Result<String, DbConnectErr> {
591 let credentials_provider = self
592 .aws_config
593 .credentials_provider()
594 .ok_or(DbConnectErr::MissingCredentialsProvider)?;
595 let credentials = credentials_provider.provide_credentials().await?;
596 let identity = credentials.into();
597 let region = self
598 .aws_config
599 .region()
600 .ok_or(DbConnectErr::MissingRegion)?;
601
602 let mut signing_settings = SigningSettings::default();
603 signing_settings.expires_in = Some(Duration::from_secs(60 * 30));
604 signing_settings.signature_location =
605 aws_sigv4::http_request::SignatureLocation::QueryParams;
606
607 let signing_params = aws_sigv4::sign::v4::SigningParams::builder()
608 .identity(&identity)
609 .region(region.as_ref())
610 .name("rds-db")
611 .time(SystemTime::now())
612 .settings(signing_settings)
613 .build()?;
614
615 let url = format!("https://{host}:{port}/?Action=connect&DBUser={user}");
616
617 let signable_request =
618 SignableRequest::new("GET", &url, std::iter::empty(), SignableBody::Bytes(&[]))?;
619
620 let (signing_instructions, _signature) =
621 sign(signable_request, &signing_params.into())?.into_parts();
622
623 let mut url = url::Url::parse(&url).map_err(DbConnectErr::AwsSignerInvalidUrl)?;
624 for (name, value) in signing_instructions.params() {
625 url.query_pairs_mut().append_pair(name, value);
626 }
627
628 let response = url.to_string().split_off("https://".len());
629 Ok(response)
630 }
631
632 async fn create_pool_iam(
634 &self,
635 db_name: &str,
636 db_role_name: &str,
637 ) -> Result<DbPool, DbConnectErr> {
638 tracing::debug!(?db_name, ?db_role_name, "creating db pool connection");
639
640 let token = self
641 .create_rds_signed_token(&self.host, self.port, db_role_name)
642 .await?;
643
644 let options = PgConnectOptions::new()
645 .host(&self.host)
646 .port(self.port)
647 .username(db_role_name)
648 .password(&token)
649 .database(db_name);
650
651 let max_connections = match db_name {
652 ROOT_DATABASE_NAME => self.max_connections_root,
653 _ => self.max_connections,
654 };
655
656 PgPoolOptions::new()
657 .max_connections(max_connections)
658 .acquire_timeout(self.acquire_timeout)
660 .idle_timeout(self.idle_timeout)
662 .connect_with(options)
663 .await
664 .map_err(DbConnectErr::Db)
665 }
666
667 async fn create_pool(&self, db_name: &str, secret_name: &str) -> Result<DbPool, DbConnectErr> {
669 tracing::debug!(?db_name, ?secret_name, "creating db pool connection");
670
671 let credentials = self.get_credentials(secret_name).await?;
672 let options = PgConnectOptions::new()
673 .host(&self.host)
674 .port(self.port)
675 .username(&credentials.username)
676 .password(&credentials.password)
677 .database(db_name);
678
679 let max_connections = match db_name {
680 ROOT_DATABASE_NAME => self.max_connections_root,
681 _ => self.max_connections,
682 };
683
684 match PgPoolOptions::new()
685 .max_connections(max_connections)
686 .acquire_timeout(self.acquire_timeout)
688 .idle_timeout(self.idle_timeout)
690 .connect_with(options)
691 .await
692 {
693 Ok(value) => Ok(value),
695 Err(err) => {
696 self.connect_info_cache.remove(secret_name).await;
698 Err(DbConnectErr::Db(err))
699 }
700 }
701 }
702}