docbox_database/
lib.rs

1use docbox_secrets::AppSecretManager;
2use models::tenant::Tenant;
3use moka::{future::Cache, policy::EvictionPolicy};
4use serde::{Deserialize, Serialize};
5pub use sqlx::postgres::PgSslMode;
6pub use sqlx::{
7    PgPool, Postgres, Transaction,
8    postgres::{PgConnectOptions, PgPoolOptions},
9};
10
11use std::sync::Arc;
12use std::{error::Error, time::Duration};
13use thiserror::Error;
14use tracing::debug;
15
16pub use sqlx;
17pub use sqlx::PgExecutor as DbExecutor;
18
19pub mod create;
20pub mod migrations;
21pub mod models;
22
23/// Type of the database connection pool
24pub type DbPool = PgPool;
25
26/// Short type alias for a database error
27pub type DbErr = sqlx::Error;
28
29/// Type alias for a result where the error is a [DbErr]
30pub type DbResult<T> = Result<T, DbErr>;
31
32/// Type of a database transaction
33pub type DbTransaction<'c> = Transaction<'c, Postgres>;
34
/// Time without access after which a cached database pool is evicted (48 hours)
const DB_CACHE_DURATION: Duration = Duration::from_secs(48 * 60 * 60);

/// Duration that loaded database credentials are cached for (12 hours)
const DB_CONNECT_INFO_CACHE_DURATION: Duration = Duration::from_secs(12 * 60 * 60);

/// Name of the root "docbox" database, connected to using the root credentials secret
pub const ROOT_DATABASE_NAME: &str = "docbox";
43
44///  Config for the database pool
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct DatabasePoolCacheConfig {
47    pub host: String,
48    pub port: u16,
49    pub root_secret_name: String,
50}
51
52#[derive(Debug, Error)]
53pub enum DatabasePoolCacheConfigError {
54    #[error("missing DOCBOX_DB_HOST environment variable")]
55    MissingDatabaseHost,
56    #[error("missing DOCBOX_DB_PORT environment variable")]
57    MissingDatabasePort,
58    #[error("invalid DOCBOX_DB_PORT environment variable")]
59    InvalidDatabasePort,
60    #[error("missing DOCBOX_DB_CREDENTIAL_NAME environment variable")]
61    MissingDatabaseSecretName,
62}
63
64impl DatabasePoolCacheConfig {
65    pub fn from_env() -> Result<DatabasePoolCacheConfig, DatabasePoolCacheConfigError> {
66        let db_host: String = std::env::var("DOCBOX_DB_HOST")
67            .or(std::env::var("POSTGRES_HOST"))
68            .map_err(|_| DatabasePoolCacheConfigError::MissingDatabaseHost)?;
69        let db_port: u16 = std::env::var("DOCBOX_DB_PORT")
70            .or(std::env::var("POSTGRES_PORT"))
71            .map_err(|_| DatabasePoolCacheConfigError::MissingDatabasePort)?
72            .parse()
73            .map_err(|_| DatabasePoolCacheConfigError::InvalidDatabasePort)?;
74        let db_root_secret_name = std::env::var("DOCBOX_DB_CREDENTIAL_NAME")
75            .map_err(|_| DatabasePoolCacheConfigError::MissingDatabaseSecretName)?;
76
77        Ok(DatabasePoolCacheConfig {
78            host: db_host,
79            port: db_port,
80            root_secret_name: db_root_secret_name,
81        })
82    }
83}
84
85/// Cache for database pools
86pub struct DatabasePoolCache {
87    /// Database host
88    host: String,
89
90    /// Database port
91    port: u16,
92
93    /// Name of the secrets manager secret that contains
94    /// the credentials for the root "docbox" database
95    root_secret_name: String,
96
97    /// Cache from the database name to the pool for that database
98    cache: Cache<String, DbPool>,
99
100    /// Cache for the connection info details, stores the last known
101    /// credentials and the instant that they were obtained at
102    connect_info_cache: Cache<String, DbSecrets>,
103
104    /// Secrets manager access to load credentials
105    secrets_manager: Arc<AppSecretManager>,
106}
107
108/// Username and password for a specific database
109#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct DbSecrets {
111    pub username: String,
112    pub password: String,
113}
114
115#[derive(Debug, Error)]
116pub enum DbConnectErr {
117    #[error("database credentials not found in secrets manager")]
118    MissingCredentials,
119
120    #[error(transparent)]
121    SecretsManager(Box<dyn Error + Send + Sync + 'static>),
122
123    #[error(transparent)]
124    Db(#[from] DbErr),
125}
126
127impl DatabasePoolCache {
128    pub fn from_config(
129        config: DatabasePoolCacheConfig,
130        secrets_manager: Arc<AppSecretManager>,
131    ) -> Self {
132        Self::new(
133            config.host,
134            config.port,
135            config.root_secret_name,
136            secrets_manager,
137        )
138    }
139
140    pub fn new(
141        host: String,
142        port: u16,
143        root_secret_name: String,
144        secrets_manager: Arc<AppSecretManager>,
145    ) -> Self {
146        let cache = Cache::builder()
147            .time_to_idle(DB_CACHE_DURATION)
148            .max_capacity(50)
149            .eviction_policy(EvictionPolicy::tiny_lfu())
150            .build();
151
152        let connect_info_cache = Cache::builder()
153            .time_to_idle(DB_CONNECT_INFO_CACHE_DURATION)
154            .max_capacity(50)
155            .eviction_policy(EvictionPolicy::tiny_lfu())
156            .build();
157
158        Self {
159            host,
160            port,
161            root_secret_name,
162            cache,
163            connect_info_cache,
164            secrets_manager,
165        }
166    }
167
168    /// Request a database pool for the root database
169    pub async fn get_root_pool(&self) -> Result<PgPool, DbConnectErr> {
170        self.get_pool(ROOT_DATABASE_NAME, &self.root_secret_name)
171            .await
172    }
173
174    /// Request a database pool for a specific tenant
175    pub async fn get_tenant_pool(&self, tenant: &Tenant) -> Result<PgPool, DbConnectErr> {
176        self.get_pool(&tenant.db_name, &tenant.db_secret_name).await
177    }
178
179    /// Empties all the caches
180    pub async fn flush(&self) {
181        // Clear cache
182        self.cache.invalidate_all();
183        self.connect_info_cache.invalidate_all();
184    }
185
186    /// Obtains a database pool connection to the database with the provided name
187    async fn get_pool(&self, db_name: &str, secret_name: &str) -> Result<PgPool, DbConnectErr> {
188        let cache_key = format!("{db_name}-{secret_name}");
189
190        if let Some(pool) = self.cache.get(&cache_key).await {
191            return Ok(pool);
192        }
193
194        let pool = self.create_pool(db_name, secret_name).await?;
195        self.cache.insert(cache_key, pool.clone()).await;
196
197        Ok(pool)
198    }
199
200    /// Obtains database connection info
201    async fn get_credentials(&self, secret_name: &str) -> Result<DbSecrets, DbConnectErr> {
202        if let Some(connect_info) = self.connect_info_cache.get(secret_name).await {
203            return Ok(connect_info);
204        }
205
206        // Load new credentials
207        let credentials = self
208            .secrets_manager
209            .parsed_secret::<DbSecrets>(secret_name)
210            .await
211            .map_err(|err| DbConnectErr::SecretsManager(err.into()))?
212            .ok_or(DbConnectErr::MissingCredentials)?;
213
214        // Cache the credential
215        self.connect_info_cache
216            .insert(secret_name.to_string(), credentials.clone())
217            .await;
218
219        Ok(credentials)
220    }
221
222    /// Creates a database pool connection
223    async fn create_pool(&self, db_name: &str, secret_name: &str) -> Result<PgPool, DbConnectErr> {
224        debug!(?db_name, ?secret_name, "creating db pool connection");
225
226        let credentials = self.get_credentials(secret_name).await?;
227        let options = PgConnectOptions::new()
228            .host(&self.host)
229            .port(self.port)
230            .username(&credentials.username)
231            .password(&credentials.password)
232            .database(db_name);
233
234        match PgPoolOptions::new().connect_with(options).await {
235            // Success case
236            Ok(value) => Ok(value),
237            Err(err) => {
238                // Drop the connect info cache in case the credentials were wrong
239                self.connect_info_cache.remove(secret_name).await;
240                Err(DbConnectErr::Db(err))
241            }
242        }
243    }
244}