1use models::tenant::Tenant;
2use moka::{future::Cache, policy::EvictionPolicy};
3use serde::{Deserialize, Serialize};
4pub use sqlx::postgres::PgSslMode;
5pub use sqlx::{
6 PgPool, Postgres, Transaction,
7 postgres::{PgConnectOptions, PgPoolOptions},
8};
9
10use std::{error::Error, time::Duration};
11use thiserror::Error;
12use tracing::debug;
13
14pub use sqlx;
15pub use sqlx::PgExecutor as DbExecutor;
16
17pub mod create;
18pub mod migrations;
19pub mod models;
20
21pub type DbPool = PgPool;
23
24pub type DbErr = sqlx::Error;
26
27pub type DbResult<T> = Result<T, DbErr>;
29
30pub type DbTransaction<'c> = Transaction<'c, Postgres>;
32
/// Idle time after which a cached database pool is evicted (48 hours).
const DB_CACHE_DURATION: Duration = Duration::from_secs(48 * 3600);

/// Idle time after which cached database credentials are evicted (12 hours).
const DB_CONNECT_INFO_CACHE_DURATION: Duration = Duration::from_secs(12 * 3600);

/// Name of the root database; its credentials come from the cache's `root_secret_name`.
pub const ROOT_DATABASE_NAME: &str = "docbox";
41
42pub struct DatabasePoolCache<S: DbSecretManager> {
44 host: String,
46
47 port: u16,
49
50 root_secret_name: String,
53
54 cache: Cache<String, DbPool>,
56
57 connect_info_cache: Cache<String, DbSecrets>,
60
61 secrets_manager: S,
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct DbSecrets {
68 pub username: String,
69 pub password: String,
70}
71
72#[derive(Debug, Error)]
73pub enum DbConnectErr {
74 #[error("database credentials not found in secrets manager")]
75 MissingCredentials,
76
77 #[error(transparent)]
78 SecretsManager(Box<dyn Error + Send + Sync + 'static>),
79
80 #[error(transparent)]
81 Db(#[from] DbErr),
82}
83
84pub trait DbSecretManager: Send + Sync {
85 fn get_secret(
86 &self,
87 name: &str,
88 ) -> impl Future<Output = Result<Option<DbSecrets>, DbConnectErr>> + Send;
89}
90
91impl<S> DatabasePoolCache<S>
92where
93 S: DbSecretManager,
94{
95 pub fn new(host: String, port: u16, root_secret_name: String, secrets_manager: S) -> Self {
96 let cache = Cache::builder()
97 .time_to_idle(DB_CACHE_DURATION)
98 .max_capacity(50)
99 .eviction_policy(EvictionPolicy::tiny_lfu())
100 .build();
101
102 let connect_info_cache = Cache::builder()
103 .time_to_idle(DB_CONNECT_INFO_CACHE_DURATION)
104 .max_capacity(50)
105 .eviction_policy(EvictionPolicy::tiny_lfu())
106 .build();
107
108 Self {
109 host,
110 port,
111 root_secret_name,
112 cache,
113 connect_info_cache,
114 secrets_manager,
115 }
116 }
117
118 pub async fn get_root_pool(&self) -> Result<PgPool, DbConnectErr> {
120 self.get_pool(ROOT_DATABASE_NAME, &self.root_secret_name)
121 .await
122 }
123
124 pub async fn get_tenant_pool(&self, tenant: &Tenant) -> Result<PgPool, DbConnectErr> {
126 self.get_pool(&tenant.db_name, &tenant.db_secret_name).await
127 }
128
129 pub async fn flush(&self) {
131 self.cache.invalidate_all();
133 self.connect_info_cache.invalidate_all();
134 }
135
136 async fn get_pool(&self, db_name: &str, secret_name: &str) -> Result<PgPool, DbConnectErr> {
138 let cache_key = format!("{}-{}", db_name, secret_name);
139
140 if let Some(pool) = self.cache.get(&cache_key).await {
141 return Ok(pool);
142 }
143
144 let pool = self.create_pool(db_name, secret_name).await?;
145 self.cache.insert(cache_key, pool.clone()).await;
146
147 Ok(pool)
148 }
149
150 async fn get_credentials(&self, secret_name: &str) -> Result<DbSecrets, DbConnectErr> {
152 if let Some(connect_info) = self.connect_info_cache.get(secret_name).await {
153 return Ok(connect_info);
154 }
155
156 let credentials = self
158 .secrets_manager
159 .get_secret(secret_name)
160 .await
161 .map_err(|err| DbConnectErr::SecretsManager(err.into()))?
162 .ok_or(DbConnectErr::MissingCredentials)?;
163
164 self.connect_info_cache
166 .insert(secret_name.to_string(), credentials.clone())
167 .await;
168
169 Ok(credentials)
170 }
171
172 async fn create_pool(&self, db_name: &str, secret_name: &str) -> Result<PgPool, DbConnectErr> {
174 debug!(?db_name, ?secret_name, "creating db pool connection");
175
176 let credentials = self.get_credentials(secret_name).await?;
177 let options = PgConnectOptions::new()
178 .host(&self.host)
179 .port(self.port)
180 .username(&credentials.username)
181 .password(&credentials.password)
182 .database(db_name);
183
184 match PgPoolOptions::new().connect_with(options).await {
185 Ok(value) => Ok(value),
187 Err(err) => {
188 self.connect_info_cache.remove(secret_name).await;
190 Err(DbConnectErr::Db(err))
191 }
192 }
193 }
194}