1use crate::{DbErr, DbPool, ROOT_DATABASE_NAME, ROOT_DATABASE_ROLE_NAME, models::tenant::Tenant};
27use aws_credential_types::provider::{ProvideCredentials, error::CredentialsError};
28use aws_sigv4::{
29 http_request::{SignableBody, SignableRequest, SigningError, SigningSettings, sign},
30 sign::v4::signing_params,
31};
32use docbox_secrets::{SecretManager, SecretManagerError};
33use moka::{future::Cache, policy::EvictionPolicy};
34use serde::{Deserialize, Serialize};
35use sqlx::{
36 PgPool,
37 postgres::{PgConnectOptions, PgPoolOptions},
38};
39use std::time::Duration;
40use std::{num::ParseIntError, str::ParseBoolError};
41use std::{sync::Arc, time::SystemTime};
42use thiserror::Error;
43
/// Configuration for a [`DatabasePoolCache`].
///
/// Usually loaded with [`DatabasePoolCacheConfig::from_env`]. Every optional
/// setting left as `None` falls back to a default chosen in
/// [`DatabasePoolCache::from_config`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabasePoolCacheConfig {
    /// Database server hostname (`DOCBOX_DB_HOST`, falling back to `POSTGRES_HOST`).
    pub host: String,

    /// Database server port (`DOCBOX_DB_PORT`, falling back to `POSTGRES_PORT`).
    /// The `Default` impl uses 5432.
    pub port: u16,

    /// Name of the secret holding the root database credentials
    /// (`DOCBOX_DB_CREDENTIAL_NAME`). Used for the root pool when `root_iam`
    /// is false.
    pub root_secret_name: Option<String>,

    /// Connect to the root database with IAM authentication instead of the
    /// stored secret (`DOCBOX_DB_ROOT_IAM`). Treated as `false` when omitted
    /// during deserialization.
    #[serde(default)]
    pub root_iam: bool,

    /// Maximum connections for each pool other than the root database pool
    /// (`DOCBOX_DB_MAX_CONNECTIONS`). `from_config` defaults this to 10.
    pub max_connections: Option<u32>,

    /// Maximum connections for the root database pool
    /// (`DOCBOX_DB_MAX_ROOT_CONNECTIONS`). `from_config` defaults this to 2.
    pub max_connections_root: Option<u32>,

    /// Pool acquire timeout in seconds (`DOCBOX_DB_ACQUIRE_TIMEOUT`).
    /// `from_config` defaults this to 60.
    pub acquire_timeout: Option<u64>,

    /// Pool connection idle timeout in seconds (`DOCBOX_DB_IDLE_TIMEOUT`).
    /// `from_config` defaults this to 600 (10 minutes).
    pub idle_timeout: Option<u64>,

    /// Seconds a cached pool may go unused before it is evicted and closed
    /// (`DOCBOX_DB_CACHE_DURATION`). `from_config` defaults this to 48 hours.
    pub cache_duration: Option<u64>,

    /// Maximum number of cached pools (`DOCBOX_DB_CACHE_CAPACITY`).
    /// `from_config` defaults this to 50.
    pub cache_capacity: Option<u64>,

    /// Seconds cached database credentials may go unused before eviction
    /// (`DOCBOX_DB_CREDENTIALS_CACHE_DURATION`). `from_config` defaults this
    /// to 12 hours.
    pub credentials_cache_duration: Option<u64>,

    /// Maximum number of cached credential entries
    /// (`DOCBOX_DB_CREDENTIALS_CACHE_CAPACITY`). `from_config` defaults this to 50.
    pub credentials_cache_capacity: Option<u64>,
}
129
130impl Default for DatabasePoolCacheConfig {
131 fn default() -> Self {
132 Self {
133 host: Default::default(),
134 port: 5432,
135 root_secret_name: Default::default(),
136 root_iam: false,
137 max_connections: None,
138 max_connections_root: None,
139 acquire_timeout: None,
140 idle_timeout: None,
141 cache_duration: None,
142 cache_capacity: None,
143 credentials_cache_duration: None,
144 credentials_cache_capacity: None,
145 }
146 }
147}
148
149#[derive(Debug, Error)]
150pub enum DatabasePoolCacheConfigError {
151 #[error("missing DOCBOX_DB_HOST environment variable")]
152 MissingDatabaseHost,
153 #[error("missing DOCBOX_DB_PORT environment variable")]
154 MissingDatabasePort,
155 #[error("invalid DOCBOX_DB_PORT environment variable")]
156 InvalidDatabasePort,
157 #[error("missing DOCBOX_DB_CREDENTIAL_NAME environment variable")]
158 MissingDatabaseSecretName,
159 #[error("invalid DOCBOX_DB_IDLE_TIMEOUT environment variable")]
160 InvalidIdleTimeout(ParseIntError),
161 #[error("invalid DOCBOX_DB_ACQUIRE_TIMEOUT environment variable")]
162 InvalidAcquireTimeout(ParseIntError),
163 #[error("invalid DOCBOX_DB_CACHE_DURATION environment variable")]
164 InvalidCacheDuration(ParseIntError),
165 #[error("invalid DOCBOX_DB_CACHE_CAPACITY environment variable")]
166 InvalidCacheCapacity(ParseIntError),
167 #[error("invalid DOCBOX_DB_CREDENTIALS_CACHE_DURATION environment variable")]
168 InvalidCredentialsCacheDuration(ParseIntError),
169 #[error("invalid DOCBOX_DB_CREDENTIALS_CACHE_CAPACITY environment variable")]
170 InvalidCredentialsCacheCapacity(ParseIntError),
171 #[error("invalid DOCBOX_DB_ROOT_IAM environment variable")]
172 InvalidRootIam(ParseBoolError),
173}
174
175impl DatabasePoolCacheConfig {
176 pub fn from_env() -> Result<DatabasePoolCacheConfig, DatabasePoolCacheConfigError> {
177 let db_host: String = std::env::var("DOCBOX_DB_HOST")
178 .or(std::env::var("POSTGRES_HOST"))
179 .map_err(|_| DatabasePoolCacheConfigError::MissingDatabaseHost)?;
180 let db_port: u16 = std::env::var("DOCBOX_DB_PORT")
181 .or(std::env::var("POSTGRES_PORT"))
182 .map_err(|_| DatabasePoolCacheConfigError::MissingDatabasePort)?
183 .parse()
184 .map_err(|_| DatabasePoolCacheConfigError::InvalidDatabasePort)?;
185
186 let db_root_secret_name = std::env::var("DOCBOX_DB_CREDENTIAL_NAME").ok();
187 let db_root_iam = std::env::var("DOCBOX_DB_ROOT_IAM")
188 .ok()
189 .map(|value| value.parse::<bool>())
190 .transpose()
191 .map_err(DatabasePoolCacheConfigError::InvalidRootIam)?
192 .unwrap_or_default();
193
194 if !db_root_iam && db_root_secret_name.is_none() {
196 return Err(DatabasePoolCacheConfigError::MissingDatabaseSecretName);
197 }
198
199 let max_connections: Option<u32> = std::env::var("DOCBOX_DB_MAX_CONNECTIONS")
200 .ok()
201 .and_then(|value| value.parse().ok());
202 let max_connections_root: Option<u32> = std::env::var("DOCBOX_DB_MAX_ROOT_CONNECTIONS")
203 .ok()
204 .and_then(|value| value.parse().ok());
205
206 let acquire_timeout: Option<u64> = match std::env::var("DOCBOX_DB_ACQUIRE_TIMEOUT") {
207 Ok(value) => Some(
208 value
209 .parse::<u64>()
210 .map_err(DatabasePoolCacheConfigError::InvalidAcquireTimeout)?,
211 ),
212 Err(_) => None,
213 };
214
215 let idle_timeout: Option<u64> = match std::env::var("DOCBOX_DB_IDLE_TIMEOUT") {
216 Ok(value) => Some(
217 value
218 .parse::<u64>()
219 .map_err(DatabasePoolCacheConfigError::InvalidIdleTimeout)?,
220 ),
221 Err(_) => None,
222 };
223
224 let cache_duration: Option<u64> = match std::env::var("DOCBOX_DB_CACHE_DURATION") {
225 Ok(value) => Some(
226 value
227 .parse::<u64>()
228 .map_err(DatabasePoolCacheConfigError::InvalidCacheDuration)?,
229 ),
230 Err(_) => None,
231 };
232
233 let cache_capacity: Option<u64> = match std::env::var("DOCBOX_DB_CACHE_CAPACITY") {
234 Ok(value) => Some(
235 value
236 .parse::<u64>()
237 .map_err(DatabasePoolCacheConfigError::InvalidCacheCapacity)?,
238 ),
239 Err(_) => None,
240 };
241
242 let credentials_cache_duration: Option<u64> =
243 match std::env::var("DOCBOX_DB_CREDENTIALS_CACHE_DURATION") {
244 Ok(value) => Some(
245 value
246 .parse::<u64>()
247 .map_err(DatabasePoolCacheConfigError::InvalidCredentialsCacheDuration)?,
248 ),
249 Err(_) => None,
250 };
251
252 let credentials_cache_capacity: Option<u64> =
253 match std::env::var("DOCBOX_DB_CREDENTIALS_CACHE_CAPACITY") {
254 Ok(value) => Some(
255 value
256 .parse::<u64>()
257 .map_err(DatabasePoolCacheConfigError::InvalidCredentialsCacheCapacity)?,
258 ),
259 Err(_) => None,
260 };
261
262 Ok(DatabasePoolCacheConfig {
263 host: db_host,
264 port: db_port,
265 root_iam: db_root_iam,
266 root_secret_name: db_root_secret_name,
267 max_connections,
268 max_connections_root,
269 acquire_timeout,
270 idle_timeout,
271 cache_duration,
272 cache_capacity,
273 credentials_cache_duration,
274 credentials_cache_capacity,
275 })
276 }
277}
278
/// Cache of PostgreSQL connection pools.
///
/// Pools are created on demand. Each one authenticates either with credentials
/// fetched from the secrets manager or with an RDS IAM token, and is closed by
/// the cache's eviction listener when it leaves the cache.
pub struct DatabasePoolCache {
    /// AWS configuration; supplies the credentials provider and region used to
    /// sign RDS IAM tokens.
    aws_config: aws_config::SdkConfig,

    /// Database server hostname.
    host: String,

    /// Database server port.
    port: u16,

    /// Secret holding the root database credentials; used for the root pool
    /// when `root_iam` is false.
    root_secret_name: Option<String>,

    /// Whether the root pool authenticates with IAM (takes precedence over
    /// `root_secret_name`).
    root_iam: bool,

    /// Open pools, keyed by `secret-{db_name}-{secret_name}` or
    /// `user-{db_name}-{role_name}`.
    cache: Cache<String, DbPool>,

    /// Database credentials keyed by secret name, so that new pools do not
    /// always re-fetch them from the secrets manager.
    connect_info_cache: Cache<String, DbSecrets>,

    /// Source of credentials for secret-based pools.
    secrets_manager: SecretManager,

    /// Maximum connections for pools other than the root database pool.
    max_connections: u32,
    /// Maximum connections for the root database pool.
    max_connections_root: u32,

    /// Passed to sqlx `PoolOptions::acquire_timeout` for every pool.
    acquire_timeout: Duration,
    /// Passed to sqlx `PoolOptions::idle_timeout` for every pool.
    idle_timeout: Duration,
}
318
319#[derive(Debug, Clone, Serialize, Deserialize)]
321pub struct DbSecrets {
322 pub username: String,
323 pub password: String,
324}
325
326#[derive(Debug, Error)]
327pub enum DbConnectErr {
328 #[error("database credentials not found in secrets manager")]
329 MissingCredentials,
330
331 #[error(transparent)]
332 SecretsManager(Box<SecretManagerError>),
333
334 #[error(transparent)]
335 Db(#[from] DbErr),
336
337 #[error(transparent)]
338 Shared(#[from] Arc<DbConnectErr>),
339
340 #[error("missing aws credentials provider")]
341 MissingCredentialsProvider,
342
343 #[error("failed to provide aws credentials")]
344 AwsCredentials(#[from] CredentialsError),
345
346 #[error("aws configuration missing region")]
347 MissingRegion,
348
349 #[error("failed to build aws signature")]
350 AwsSigner(#[from] signing_params::BuildError),
351
352 #[error("failed to sign aws request")]
353 AwsRequestSign(#[from] SigningError),
354
355 #[error("failed to parse signed aws url")]
356 AwsSignerInvalidUrl(url::ParseError),
357
358 #[error("failed to connect to tenant missing both IAM and secrets fields")]
359 InvalidTenantConfiguration,
360}
361
362impl DatabasePoolCache {
363 pub fn from_config(
364 aws_config: aws_config::SdkConfig,
365 config: DatabasePoolCacheConfig,
366 secrets_manager: SecretManager,
367 ) -> Self {
368 let cache_duration = Duration::from_secs(config.cache_duration.unwrap_or(60 * 60 * 48));
369 let credentials_cache_duration =
370 Duration::from_secs(config.credentials_cache_duration.unwrap_or(60 * 60 * 12));
371
372 let cache_capacity = config.cache_capacity.unwrap_or(50);
373 let credentials_cache_capacity = config.credentials_cache_capacity.unwrap_or(50);
374
375 let cache = Cache::builder()
376 .time_to_idle(cache_duration)
377 .max_capacity(cache_capacity)
378 .eviction_policy(EvictionPolicy::tiny_lfu())
379 .async_eviction_listener(|cache_key: Arc<String>, pool: DbPool, _cause| {
380 Box::pin(async move {
381 tracing::debug!(?cache_key, "database pool is no longer in use, closing");
382 pool.close().await
383 })
384 })
385 .build();
386
387 let connect_info_cache = Cache::builder()
388 .time_to_idle(credentials_cache_duration)
389 .max_capacity(credentials_cache_capacity)
390 .eviction_policy(EvictionPolicy::tiny_lfu())
391 .build();
392
393 Self {
394 aws_config,
395 host: config.host,
396 port: config.port,
397 root_secret_name: config.root_secret_name,
398 root_iam: config.root_iam,
399 cache,
400 connect_info_cache,
401 secrets_manager,
402 max_connections: config.max_connections.unwrap_or(10),
403 max_connections_root: config.max_connections_root.unwrap_or(2),
404 idle_timeout: Duration::from_secs(config.idle_timeout.unwrap_or(60 * 10)),
405 acquire_timeout: Duration::from_secs(config.acquire_timeout.unwrap_or(60)),
406 }
407 }
408
409 pub async fn get_root_pool(&self) -> Result<PgPool, DbConnectErr> {
411 match (self.root_secret_name.as_ref(), self.root_iam) {
412 (_, true) => {
413 self.get_pool_iam(ROOT_DATABASE_NAME, ROOT_DATABASE_ROLE_NAME)
414 .await
415 }
416
417 (Some(db_secret_name), _) => self.get_pool(ROOT_DATABASE_NAME, db_secret_name).await,
418
419 _ => Err(DbConnectErr::InvalidTenantConfiguration),
420 }
421 }
422
423 pub async fn get_tenant_pool(&self, tenant: &Tenant) -> Result<DbPool, DbConnectErr> {
425 match (
426 tenant.db_iam_user_name.as_ref(),
427 tenant.db_secret_name.as_ref(),
428 ) {
429 (Some(db_iam_user_name), _) => {
430 self.get_pool_iam(&tenant.db_name, db_iam_user_name).await
431 }
432 (_, Some(db_secret_name)) => self.get_pool(&tenant.db_name, db_secret_name).await,
433
434 _ => Err(DbConnectErr::InvalidTenantConfiguration),
435 }
436 }
437
438 pub async fn close_tenant_pool(&self, tenant: &Tenant) {
441 let cache_key = Self::tenant_cache_key(tenant);
442 if let Some(pool) = self.cache.remove(&cache_key).await {
443 pool.close().await;
444 }
445
446 self.cache.run_pending_tasks().await;
448 }
449
450 fn tenant_cache_key(tenant: &Tenant) -> String {
453 match (
454 tenant.db_secret_name.as_ref(),
455 tenant.db_iam_user_name.as_ref(),
456 ) {
457 (Some(db_secret_name), _) => {
458 format!("secret-{}-{}", &tenant.db_name, db_secret_name)
459 }
460 (_, Some(db_iam_user_name)) => {
461 format!("user-{}-{}", &tenant.db_name, db_iam_user_name)
462 }
463
464 _ => format!("db-{}", &tenant.db_name),
465 }
466 }
467
468 pub async fn flush(&self) {
470 self.cache.invalidate_all();
472 self.connect_info_cache.invalidate_all();
473 self.cache.run_pending_tasks().await;
474 }
475
476 pub async fn close_all(&self) {
478 for (_, value) in self.cache.iter() {
479 value.close().await;
480 }
481
482 self.flush().await;
483 }
484
485 async fn get_pool(&self, db_name: &str, secret_name: &str) -> Result<DbPool, DbConnectErr> {
488 let cache_key = format!("secret-{db_name}-{secret_name}");
489
490 let pool = self
491 .cache
492 .try_get_with(cache_key, async {
493 tracing::debug!(?db_name, "acquiring database pool");
494
495 let pool = self
496 .create_pool(db_name, secret_name)
497 .await
498 .map_err(Arc::new)?;
499
500 Ok(pool)
501 })
502 .await?;
503
504 Ok(pool)
505 }
506
507 async fn get_pool_iam(
510 &self,
511 db_name: &str,
512 db_role_name: &str,
513 ) -> Result<DbPool, DbConnectErr> {
514 let cache_key = format!("user-{db_name}-{db_role_name}");
515
516 let pool = self
517 .cache
518 .try_get_with(cache_key, async {
519 tracing::debug!(?db_name, "acquiring database pool (iam)");
520
521 let pool = self
522 .create_pool_iam(db_name, db_role_name)
523 .await
524 .map_err(Arc::new)?;
525
526 Ok(pool)
527 })
528 .await?;
529
530 Ok(pool)
531 }
532
533 async fn get_credentials(&self, secret_name: &str) -> Result<DbSecrets, DbConnectErr> {
535 if let Some(connect_info) = self.connect_info_cache.get(secret_name).await {
536 return Ok(connect_info);
537 }
538
539 let credentials = self
541 .secrets_manager
542 .parsed_secret::<DbSecrets>(secret_name)
543 .await
544 .map_err(|err| DbConnectErr::SecretsManager(Box::new(err)))?
545 .ok_or(DbConnectErr::MissingCredentials)?;
546
547 self.connect_info_cache
549 .insert(secret_name.to_string(), credentials.clone())
550 .await;
551
552 Ok(credentials)
553 }
554
555 async fn create_rds_signed_token(
556 &self,
557 host: &str,
558 port: u16,
559 user: &str,
560 ) -> Result<String, DbConnectErr> {
561 let credentials_provider = self
562 .aws_config
563 .credentials_provider()
564 .ok_or(DbConnectErr::MissingCredentialsProvider)?;
565 let credentials = credentials_provider.provide_credentials().await?;
566 let identity = credentials.into();
567 let region = self
568 .aws_config
569 .region()
570 .ok_or(DbConnectErr::MissingRegion)?;
571
572 let mut signing_settings = SigningSettings::default();
573 signing_settings.expires_in = Some(Duration::from_secs(60 * 15));
574 signing_settings.signature_location =
575 aws_sigv4::http_request::SignatureLocation::QueryParams;
576
577 let signing_params = aws_sigv4::sign::v4::SigningParams::builder()
578 .identity(&identity)
579 .region(region.as_ref())
580 .name("rds-db")
581 .time(SystemTime::now())
582 .settings(signing_settings)
583 .build()?;
584
585 let url = format!("https://{host}:{port}/?Action=connect&DBUser={user}");
586
587 let signable_request =
588 SignableRequest::new("GET", &url, std::iter::empty(), SignableBody::Bytes(&[]))?;
589
590 let (signing_instructions, _signature) =
591 sign(signable_request, &signing_params.into())?.into_parts();
592
593 let mut url = url::Url::parse(&url).map_err(DbConnectErr::AwsSignerInvalidUrl)?;
594 for (name, value) in signing_instructions.params() {
595 url.query_pairs_mut().append_pair(name, value);
596 }
597
598 let response = url.to_string().split_off("https://".len());
599 Ok(response)
600 }
601
602 async fn create_pool_iam(
604 &self,
605 db_name: &str,
606 db_role_name: &str,
607 ) -> Result<DbPool, DbConnectErr> {
608 tracing::debug!(?db_name, ?db_role_name, "creating db pool connection");
609
610 let token = self
611 .create_rds_signed_token(&self.host, self.port, db_role_name)
612 .await?;
613
614 let options = PgConnectOptions::new()
615 .host(&self.host)
616 .port(self.port)
617 .username(db_role_name)
618 .password(&token)
619 .database(db_name);
620
621 let max_connections = match db_name {
622 ROOT_DATABASE_NAME => self.max_connections_root,
623 _ => self.max_connections,
624 };
625
626 PgPoolOptions::new()
627 .max_connections(max_connections)
628 .acquire_timeout(self.acquire_timeout)
630 .idle_timeout(self.idle_timeout)
632 .connect_with(options)
633 .await
634 .map_err(DbConnectErr::Db)
635 }
636
637 async fn create_pool(&self, db_name: &str, secret_name: &str) -> Result<DbPool, DbConnectErr> {
639 tracing::debug!(?db_name, ?secret_name, "creating db pool connection");
640
641 let credentials = self.get_credentials(secret_name).await?;
642 let options = PgConnectOptions::new()
643 .host(&self.host)
644 .port(self.port)
645 .username(&credentials.username)
646 .password(&credentials.password)
647 .database(db_name);
648
649 let max_connections = match db_name {
650 ROOT_DATABASE_NAME => self.max_connections_root,
651 _ => self.max_connections,
652 };
653
654 match PgPoolOptions::new()
655 .max_connections(max_connections)
656 .acquire_timeout(self.acquire_timeout)
658 .idle_timeout(self.idle_timeout)
660 .connect_with(options)
661 .await
662 {
663 Ok(value) => Ok(value),
665 Err(err) => {
666 self.connect_info_cache.remove(secret_name).await;
668 Err(DbConnectErr::Db(err))
669 }
670 }
671 }
672}