1use models::tenant::Tenant;
2use moka::{future::Cache, policy::EvictionPolicy};
3use serde::{Deserialize, Serialize};
4pub use sqlx::postgres::PgSslMode;
5pub use sqlx::{
6 PgPool, Postgres, Transaction,
7 postgres::{PgConnectOptions, PgPoolOptions},
8};
9
10use std::{error::Error, time::Duration};
11use thiserror::Error;
12use tracing::debug;
13
14pub use sqlx;
15pub use sqlx::PgExecutor as DbExecutor;
16
17pub mod create;
18pub mod migrations;
19pub mod models;
20
/// Postgres connection pool type (alias of [`PgPool`]).
pub type DbPool = PgPool;

/// Database error type (alias of [`sqlx::Error`]).
pub type DbErr = sqlx::Error;

/// Result of a database operation, failing with [`DbErr`].
pub type DbResult<T> = Result<T, DbErr>;

/// Postgres transaction (alias of sqlx's `Transaction` specialised to [`Postgres`]).
pub type DbTransaction<'c> = Transaction<'c, Postgres>;
32
/// Idle time (48 hours) after which an unused connection pool is dropped from
/// the pool cache (used as the cache's `time_to_idle`).
const DB_CACHE_DURATION: Duration = Duration::from_secs(48 * 60 * 60);

/// Idle time (12 hours) after which unused credentials are dropped from the
/// credentials cache (used as that cache's `time_to_idle`).
const DB_CONNECT_INFO_CACHE_DURATION: Duration = Duration::from_secs(12 * 60 * 60);

/// Name of the database that the root pool connects to.
pub const ROOT_DATABASE_NAME: &str = "docbox";
41
/// Connection settings for a [`DatabasePoolCache`].
///
/// Can be loaded from environment variables with
/// [`DatabasePoolCacheConfig::from_env`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabasePoolCacheConfig {
    /// Host name of the database server.
    pub host: String,
    /// Port of the database server.
    pub port: u16,
    /// Name of the secret holding the credentials used for the root database.
    pub root_secret_name: String,
}
49
/// Errors produced by [`DatabasePoolCacheConfig::from_env`].
#[derive(Debug, Error)]
pub enum DatabasePoolCacheConfigError {
    /// Neither `DOCBOX_DB_HOST` nor its fallback `POSTGRES_HOST` could be read.
    #[error("missing DOCBOX_DB_HOST environment variable")]
    MissingDatabaseHost,
    /// Neither `DOCBOX_DB_PORT` nor its fallback `POSTGRES_PORT` could be read.
    #[error("missing DOCBOX_DB_PORT environment variable")]
    MissingDatabasePort,
    /// The port variable was present but did not parse as a `u16`.
    #[error("invalid DOCBOX_DB_PORT environment variable")]
    InvalidDatabasePort,
    /// `DOCBOX_DB_CREDENTIAL_NAME` could not be read (it has no fallback).
    #[error("missing DOCBOX_DB_CREDENTIAL_NAME environment variable")]
    MissingDatabaseSecretName,
}
61
62impl DatabasePoolCacheConfig {
63 pub fn from_env() -> Result<DatabasePoolCacheConfig, DatabasePoolCacheConfigError> {
64 let db_host: String = std::env::var("DOCBOX_DB_HOST")
65 .or(std::env::var("POSTGRES_HOST"))
66 .map_err(|_| DatabasePoolCacheConfigError::MissingDatabaseHost)?;
67 let db_port: u16 = std::env::var("DOCBOX_DB_PORT")
68 .or(std::env::var("POSTGRES_PORT"))
69 .map_err(|_| DatabasePoolCacheConfigError::MissingDatabasePort)?
70 .parse()
71 .map_err(|_| DatabasePoolCacheConfigError::InvalidDatabasePort)?;
72 let db_root_secret_name = std::env::var("DOCBOX_DB_CREDENTIAL_NAME")
73 .map_err(|_| DatabasePoolCacheConfigError::MissingDatabaseSecretName)?;
74
75 Ok(DatabasePoolCacheConfig {
76 host: db_host,
77 port: db_port,
78 root_secret_name: db_root_secret_name,
79 })
80 }
81}
82
/// Cache of Postgres connection pools, keyed by database name and credentials
/// secret name, plus a cache of credentials fetched through `S`.
pub struct DatabasePoolCache<S: DbSecretManager> {
    /// Host name of the database server every pool connects to.
    host: String,

    /// Port of the database server every pool connects to.
    port: u16,

    /// Name of the secret holding the credentials for the root database.
    root_secret_name: String,

    /// Connection pools, keyed by a string derived from database name and secret name.
    cache: Cache<String, DbPool>,

    /// Credentials loaded from the secrets manager, keyed by secret name.
    connect_info_cache: Cache<String, DbSecrets>,

    /// Source of database credentials.
    secrets_manager: S,
}
105
106#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct DbSecrets {
109 pub username: String,
110 pub password: String,
111}
112
/// Errors that can occur while obtaining a database connection pool.
#[derive(Debug, Error)]
pub enum DbConnectErr {
    /// The secrets manager returned no credentials for the requested secret name.
    #[error("database credentials not found in secrets manager")]
    MissingCredentials,

    /// Error originating from the secrets manager backend.
    #[error(transparent)]
    SecretsManager(Box<dyn Error + Send + Sync + 'static>),

    /// Error from the database driver, e.g. while connecting.
    #[error(transparent)]
    Db(#[from] DbErr),
}
124
/// Source of database credentials, looked up by secret name.
///
/// [`DatabasePoolCache`] queries this when the credentials for a secret are
/// not already in its credentials cache.
pub trait DbSecretManager: Send + Sync {
    /// Fetches the credentials stored under the secret `name`.
    ///
    /// `Ok(None)` means no credentials exist under that name; the pool cache
    /// reports this as [`DbConnectErr::MissingCredentials`].
    fn get_secret(
        &self,
        name: &str,
    ) -> impl Future<Output = Result<Option<DbSecrets>, DbConnectErr>> + Send;
}
131
132impl<S> DatabasePoolCache<S>
133where
134 S: DbSecretManager,
135{
136 pub fn from_config(config: DatabasePoolCacheConfig, secrets_manager: S) -> Self {
137 Self::new(
138 config.host,
139 config.port,
140 config.root_secret_name,
141 secrets_manager,
142 )
143 }
144
145 pub fn new(host: String, port: u16, root_secret_name: String, secrets_manager: S) -> Self {
146 let cache = Cache::builder()
147 .time_to_idle(DB_CACHE_DURATION)
148 .max_capacity(50)
149 .eviction_policy(EvictionPolicy::tiny_lfu())
150 .build();
151
152 let connect_info_cache = Cache::builder()
153 .time_to_idle(DB_CONNECT_INFO_CACHE_DURATION)
154 .max_capacity(50)
155 .eviction_policy(EvictionPolicy::tiny_lfu())
156 .build();
157
158 Self {
159 host,
160 port,
161 root_secret_name,
162 cache,
163 connect_info_cache,
164 secrets_manager,
165 }
166 }
167
168 pub async fn get_root_pool(&self) -> Result<PgPool, DbConnectErr> {
170 self.get_pool(ROOT_DATABASE_NAME, &self.root_secret_name)
171 .await
172 }
173
174 pub async fn get_tenant_pool(&self, tenant: &Tenant) -> Result<PgPool, DbConnectErr> {
176 self.get_pool(&tenant.db_name, &tenant.db_secret_name).await
177 }
178
179 pub async fn flush(&self) {
181 self.cache.invalidate_all();
183 self.connect_info_cache.invalidate_all();
184 }
185
186 async fn get_pool(&self, db_name: &str, secret_name: &str) -> Result<PgPool, DbConnectErr> {
188 let cache_key = format!("{db_name}-{secret_name}");
189
190 if let Some(pool) = self.cache.get(&cache_key).await {
191 return Ok(pool);
192 }
193
194 let pool = self.create_pool(db_name, secret_name).await?;
195 self.cache.insert(cache_key, pool.clone()).await;
196
197 Ok(pool)
198 }
199
200 async fn get_credentials(&self, secret_name: &str) -> Result<DbSecrets, DbConnectErr> {
202 if let Some(connect_info) = self.connect_info_cache.get(secret_name).await {
203 return Ok(connect_info);
204 }
205
206 let credentials = self
208 .secrets_manager
209 .get_secret(secret_name)
210 .await
211 .map_err(|err| DbConnectErr::SecretsManager(err.into()))?
212 .ok_or(DbConnectErr::MissingCredentials)?;
213
214 self.connect_info_cache
216 .insert(secret_name.to_string(), credentials.clone())
217 .await;
218
219 Ok(credentials)
220 }
221
222 async fn create_pool(&self, db_name: &str, secret_name: &str) -> Result<PgPool, DbConnectErr> {
224 debug!(?db_name, ?secret_name, "creating db pool connection");
225
226 let credentials = self.get_credentials(secret_name).await?;
227 let options = PgConnectOptions::new()
228 .host(&self.host)
229 .port(self.port)
230 .username(&credentials.username)
231 .password(&credentials.password)
232 .database(db_name);
233
234 match PgPoolOptions::new().connect_with(options).await {
235 Ok(value) => Ok(value),
237 Err(err) => {
238 self.connect_info_cache.remove(secret_name).await;
240 Err(DbConnectErr::Db(err))
241 }
242 }
243 }
244}