1use docbox_secrets::AppSecretManager;
2use models::tenant::Tenant;
3use moka::{future::Cache, policy::EvictionPolicy};
4use serde::{Deserialize, Serialize};
5pub use sqlx::postgres::PgSslMode;
6pub use sqlx::{
7 PgPool, Postgres, Transaction,
8 postgres::{PgConnectOptions, PgPoolOptions},
9};
10
11use std::sync::Arc;
12use std::{error::Error, time::Duration};
13use thiserror::Error;
14use tracing::debug;
15
16pub use sqlx;
17pub use sqlx::PgExecutor as DbExecutor;
18
19pub mod create;
20pub mod migrations;
21pub mod models;
22
/// Database connection pool type (alias for the Postgres [`PgPool`]).
pub type DbPool = PgPool;

/// Database error type (alias for [`sqlx::Error`]).
pub type DbErr = sqlx::Error;

/// Result type whose error variant is a [`DbErr`].
pub type DbResult<T> = Result<T, DbErr>;

/// Transaction on a Postgres connection, with lifetime `'c`.
pub type DbTransaction<'c> = Transaction<'c, Postgres>;
34
/// Time-to-idle (48 hours) configured for entries of the connection pool cache.
const DB_CACHE_DURATION: Duration = Duration::from_secs(60 * 60 * 48);

/// Time-to-idle (12 hours) configured for entries of the credentials cache.
const DB_CONNECT_INFO_CACHE_DURATION: Duration = Duration::from_secs(60 * 60 * 12);

/// Name of the root database, the one `DatabasePoolCache::get_root_pool` connects to.
pub const ROOT_DATABASE_NAME: &str = "docbox";
43
/// Settings needed to construct a [`DatabasePoolCache`].
///
/// Can be built from environment variables with [`DatabasePoolCacheConfig::from_env`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabasePoolCacheConfig {
    /// Host name of the database server.
    pub host: String,
    /// Port of the database server.
    pub port: u16,
    /// Name of the secret that holds the credentials used for the root database.
    pub root_secret_name: String,
}
51
/// Errors produced by [`DatabasePoolCacheConfig::from_env`].
#[derive(Debug, Error)]
pub enum DatabasePoolCacheConfigError {
    /// Neither `DOCBOX_DB_HOST` nor the `POSTGRES_HOST` fallback could be read.
    #[error("missing DOCBOX_DB_HOST environment variable")]
    MissingDatabaseHost,
    /// Neither `DOCBOX_DB_PORT` nor the `POSTGRES_PORT` fallback could be read.
    #[error("missing DOCBOX_DB_PORT environment variable")]
    MissingDatabasePort,
    /// The port variable was present but could not be parsed as a `u16`.
    #[error("invalid DOCBOX_DB_PORT environment variable")]
    InvalidDatabasePort,
    /// `DOCBOX_DB_CREDENTIAL_NAME` could not be read.
    #[error("missing DOCBOX_DB_CREDENTIAL_NAME environment variable")]
    MissingDatabaseSecretName,
}
63
64impl DatabasePoolCacheConfig {
65 pub fn from_env() -> Result<DatabasePoolCacheConfig, DatabasePoolCacheConfigError> {
66 let db_host: String = std::env::var("DOCBOX_DB_HOST")
67 .or(std::env::var("POSTGRES_HOST"))
68 .map_err(|_| DatabasePoolCacheConfigError::MissingDatabaseHost)?;
69 let db_port: u16 = std::env::var("DOCBOX_DB_PORT")
70 .or(std::env::var("POSTGRES_PORT"))
71 .map_err(|_| DatabasePoolCacheConfigError::MissingDatabasePort)?
72 .parse()
73 .map_err(|_| DatabasePoolCacheConfigError::InvalidDatabasePort)?;
74 let db_root_secret_name = std::env::var("DOCBOX_DB_CREDENTIAL_NAME")
75 .map_err(|_| DatabasePoolCacheConfigError::MissingDatabaseSecretName)?;
76
77 Ok(DatabasePoolCacheConfig {
78 host: db_host,
79 port: db_port,
80 root_secret_name: db_root_secret_name,
81 })
82 }
83}
84
/// Cache of database connection pools and of the credentials used to open them.
pub struct DatabasePoolCache {
    /// Database server host used for every connection.
    host: String,

    /// Database server port used for every connection.
    port: u16,

    /// Name of the secret holding the credentials for the root database.
    root_secret_name: String,

    /// Connection pools, keyed by a string built from the database name and
    /// the secret name.
    cache: Cache<String, DbPool>,

    /// Database credentials, keyed by secret name.
    connect_info_cache: Cache<String, DbSecrets>,

    /// Secrets manager the credentials are loaded from.
    secrets_manager: Arc<AppSecretManager>,
}
107
108#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct DbSecrets {
111 pub username: String,
112 pub password: String,
113}
114
/// Errors that can occur while obtaining a database connection pool.
#[derive(Debug, Error)]
pub enum DbConnectErr {
    /// The secrets manager had no value for the requested secret.
    #[error("database credentials not found in secrets manager")]
    MissingCredentials,

    /// The secrets manager failed while loading or parsing the secret.
    #[error(transparent)]
    SecretsManager(Box<dyn Error + Send + Sync + 'static>),

    /// The database layer returned an error (for example when connecting).
    #[error(transparent)]
    Db(#[from] DbErr),
}
126
127impl DatabasePoolCache {
128 pub fn from_config(
129 config: DatabasePoolCacheConfig,
130 secrets_manager: Arc<AppSecretManager>,
131 ) -> Self {
132 Self::new(
133 config.host,
134 config.port,
135 config.root_secret_name,
136 secrets_manager,
137 )
138 }
139
140 pub fn new(
141 host: String,
142 port: u16,
143 root_secret_name: String,
144 secrets_manager: Arc<AppSecretManager>,
145 ) -> Self {
146 let cache = Cache::builder()
147 .time_to_idle(DB_CACHE_DURATION)
148 .max_capacity(50)
149 .eviction_policy(EvictionPolicy::tiny_lfu())
150 .build();
151
152 let connect_info_cache = Cache::builder()
153 .time_to_idle(DB_CONNECT_INFO_CACHE_DURATION)
154 .max_capacity(50)
155 .eviction_policy(EvictionPolicy::tiny_lfu())
156 .build();
157
158 Self {
159 host,
160 port,
161 root_secret_name,
162 cache,
163 connect_info_cache,
164 secrets_manager,
165 }
166 }
167
168 pub async fn get_root_pool(&self) -> Result<PgPool, DbConnectErr> {
170 self.get_pool(ROOT_DATABASE_NAME, &self.root_secret_name)
171 .await
172 }
173
174 pub async fn get_tenant_pool(&self, tenant: &Tenant) -> Result<PgPool, DbConnectErr> {
176 self.get_pool(&tenant.db_name, &tenant.db_secret_name).await
177 }
178
179 pub async fn flush(&self) {
181 self.cache.invalidate_all();
183 self.connect_info_cache.invalidate_all();
184 }
185
186 async fn get_pool(&self, db_name: &str, secret_name: &str) -> Result<PgPool, DbConnectErr> {
188 let cache_key = format!("{db_name}-{secret_name}");
189
190 if let Some(pool) = self.cache.get(&cache_key).await {
191 return Ok(pool);
192 }
193
194 let pool = self.create_pool(db_name, secret_name).await?;
195 self.cache.insert(cache_key, pool.clone()).await;
196
197 Ok(pool)
198 }
199
200 async fn get_credentials(&self, secret_name: &str) -> Result<DbSecrets, DbConnectErr> {
202 if let Some(connect_info) = self.connect_info_cache.get(secret_name).await {
203 return Ok(connect_info);
204 }
205
206 let credentials = self
208 .secrets_manager
209 .parsed_secret::<DbSecrets>(secret_name)
210 .await
211 .map_err(|err| DbConnectErr::SecretsManager(err.into()))?
212 .ok_or(DbConnectErr::MissingCredentials)?;
213
214 self.connect_info_cache
216 .insert(secret_name.to_string(), credentials.clone())
217 .await;
218
219 Ok(credentials)
220 }
221
222 async fn create_pool(&self, db_name: &str, secret_name: &str) -> Result<PgPool, DbConnectErr> {
224 debug!(?db_name, ?secret_name, "creating db pool connection");
225
226 let credentials = self.get_credentials(secret_name).await?;
227 let options = PgConnectOptions::new()
228 .host(&self.host)
229 .port(self.port)
230 .username(&credentials.username)
231 .password(&credentials.password)
232 .database(db_name);
233
234 match PgPoolOptions::new().connect_with(options).await {
235 Ok(value) => Ok(value),
237 Err(err) => {
238 self.connect_info_cache.remove(secret_name).await;
240 Err(DbConnectErr::Db(err))
241 }
242 }
243 }
244}