docbox_database/
lib.rs

1use models::tenant::Tenant;
2use moka::{future::Cache, policy::EvictionPolicy};
3use serde::{Deserialize, Serialize};
4pub use sqlx::postgres::PgSslMode;
5pub use sqlx::{
6    PgPool, Postgres, Transaction,
7    postgres::{PgConnectOptions, PgPoolOptions},
8};
9
10use std::{error::Error, time::Duration};
11use thiserror::Error;
12use tracing::debug;
13
14pub use sqlx;
15pub use sqlx::PgExecutor as DbExecutor;
16
17pub mod create;
18pub mod migrations;
19pub mod models;
20
21/// Type of the database connection pool
22pub type DbPool = PgPool;
23
24/// Short type alias for a database error
25pub type DbErr = sqlx::Error;
26
27/// Type alias for a result where the error is a [DbErr]
28pub type DbResult<T> = Result<T, DbErr>;
29
30/// Type of a database transaction
31pub type DbTransaction<'c> = Transaction<'c, Postgres>;
32
/// Idle time after which a cached database pool is dropped from the pool cache (48 hours).
const DB_CACHE_DURATION: Duration = Duration::from_secs(48 * 60 * 60);

/// Idle time after which cached database credentials are dropped (12 hours).
const DB_CONNECT_INFO_CACHE_DURATION: Duration = Duration::from_secs(12 * 60 * 60);

/// Database name used for the root (non-tenant) connection pool.
pub const ROOT_DATABASE_NAME: &str = "docbox";
41
42///  Config for the database pool
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct DatabasePoolCacheConfig {
45    pub host: String,
46    pub port: u16,
47    pub root_secret_name: String,
48}
49
50#[derive(Debug, Error)]
51pub enum DatabasePoolCacheConfigError {
52    #[error("missing DOCBOX_DB_HOST environment variable")]
53    MissingDatabaseHost,
54    #[error("missing DOCBOX_DB_PORT environment variable")]
55    MissingDatabasePort,
56    #[error("invalid DOCBOX_DB_PORT environment variable")]
57    InvalidDatabasePort,
58    #[error("missing DOCBOX_DB_CREDENTIAL_NAME environment variable")]
59    MissingDatabaseSecretName,
60}
61
62impl DatabasePoolCacheConfig {
63    pub fn from_env() -> Result<DatabasePoolCacheConfig, DatabasePoolCacheConfigError> {
64        let db_host: String = std::env::var("DOCBOX_DB_HOST")
65            .or(std::env::var("POSTGRES_HOST"))
66            .map_err(|_| DatabasePoolCacheConfigError::MissingDatabaseHost)?;
67        let db_port: u16 = std::env::var("DOCBOX_DB_PORT")
68            .or(std::env::var("POSTGRES_PORT"))
69            .map_err(|_| DatabasePoolCacheConfigError::MissingDatabasePort)?
70            .parse()
71            .map_err(|_| DatabasePoolCacheConfigError::InvalidDatabasePort)?;
72        let db_root_secret_name = std::env::var("DOCBOX_DB_CREDENTIAL_NAME")
73            .map_err(|_| DatabasePoolCacheConfigError::MissingDatabaseSecretName)?;
74
75        Ok(DatabasePoolCacheConfig {
76            host: db_host,
77            port: db_port,
78            root_secret_name: db_root_secret_name,
79        })
80    }
81}
82
83/// Cache for database pools
84pub struct DatabasePoolCache<S: DbSecretManager> {
85    /// Database host
86    host: String,
87
88    /// Database port
89    port: u16,
90
91    /// Name of the secrets manager secret that contains
92    /// the credentials for the root "docbox" database
93    root_secret_name: String,
94
95    /// Cache from the database name to the pool for that database
96    cache: Cache<String, DbPool>,
97
98    /// Cache for the connection info details, stores the last known
99    /// credentials and the instant that they were obtained at
100    connect_info_cache: Cache<String, DbSecrets>,
101
102    /// Secrets manager access to load credentials
103    secrets_manager: S,
104}
105
106/// Username and password for a specific database
107#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct DbSecrets {
109    pub username: String,
110    pub password: String,
111}
112
113#[derive(Debug, Error)]
114pub enum DbConnectErr {
115    #[error("database credentials not found in secrets manager")]
116    MissingCredentials,
117
118    #[error(transparent)]
119    SecretsManager(Box<dyn Error + Send + Sync + 'static>),
120
121    #[error(transparent)]
122    Db(#[from] DbErr),
123}
124
125pub trait DbSecretManager: Send + Sync {
126    fn get_secret(
127        &self,
128        name: &str,
129    ) -> impl Future<Output = Result<Option<DbSecrets>, DbConnectErr>> + Send;
130}
131
132impl<S> DatabasePoolCache<S>
133where
134    S: DbSecretManager,
135{
136    pub fn from_config(config: DatabasePoolCacheConfig, secrets_manager: S) -> Self {
137        Self::new(
138            config.host,
139            config.port,
140            config.root_secret_name,
141            secrets_manager,
142        )
143    }
144
145    pub fn new(host: String, port: u16, root_secret_name: String, secrets_manager: S) -> Self {
146        let cache = Cache::builder()
147            .time_to_idle(DB_CACHE_DURATION)
148            .max_capacity(50)
149            .eviction_policy(EvictionPolicy::tiny_lfu())
150            .build();
151
152        let connect_info_cache = Cache::builder()
153            .time_to_idle(DB_CONNECT_INFO_CACHE_DURATION)
154            .max_capacity(50)
155            .eviction_policy(EvictionPolicy::tiny_lfu())
156            .build();
157
158        Self {
159            host,
160            port,
161            root_secret_name,
162            cache,
163            connect_info_cache,
164            secrets_manager,
165        }
166    }
167
168    /// Request a database pool for the root database
169    pub async fn get_root_pool(&self) -> Result<PgPool, DbConnectErr> {
170        self.get_pool(ROOT_DATABASE_NAME, &self.root_secret_name)
171            .await
172    }
173
174    /// Request a database pool for a specific tenant
175    pub async fn get_tenant_pool(&self, tenant: &Tenant) -> Result<PgPool, DbConnectErr> {
176        self.get_pool(&tenant.db_name, &tenant.db_secret_name).await
177    }
178
179    /// Empties all the caches
180    pub async fn flush(&self) {
181        // Clear cache
182        self.cache.invalidate_all();
183        self.connect_info_cache.invalidate_all();
184    }
185
186    /// Obtains a database pool connection to the database with the provided name
187    async fn get_pool(&self, db_name: &str, secret_name: &str) -> Result<PgPool, DbConnectErr> {
188        let cache_key = format!("{db_name}-{secret_name}");
189
190        if let Some(pool) = self.cache.get(&cache_key).await {
191            return Ok(pool);
192        }
193
194        let pool = self.create_pool(db_name, secret_name).await?;
195        self.cache.insert(cache_key, pool.clone()).await;
196
197        Ok(pool)
198    }
199
200    /// Obtains database connection info
201    async fn get_credentials(&self, secret_name: &str) -> Result<DbSecrets, DbConnectErr> {
202        if let Some(connect_info) = self.connect_info_cache.get(secret_name).await {
203            return Ok(connect_info);
204        }
205
206        // Load new credentials
207        let credentials = self
208            .secrets_manager
209            .get_secret(secret_name)
210            .await
211            .map_err(|err| DbConnectErr::SecretsManager(err.into()))?
212            .ok_or(DbConnectErr::MissingCredentials)?;
213
214        // Cache the credential
215        self.connect_info_cache
216            .insert(secret_name.to_string(), credentials.clone())
217            .await;
218
219        Ok(credentials)
220    }
221
222    /// Creates a database pool connection
223    async fn create_pool(&self, db_name: &str, secret_name: &str) -> Result<PgPool, DbConnectErr> {
224        debug!(?db_name, ?secret_name, "creating db pool connection");
225
226        let credentials = self.get_credentials(secret_name).await?;
227        let options = PgConnectOptions::new()
228            .host(&self.host)
229            .port(self.port)
230            .username(&credentials.username)
231            .password(&credentials.password)
232            .database(db_name);
233
234        match PgPoolOptions::new().connect_with(options).await {
235            // Success case
236            Ok(value) => Ok(value),
237            Err(err) => {
238                // Drop the connect info cache incase the credentials were wrong
239                self.connect_info_cache.remove(secret_name).await;
240                Err(DbConnectErr::Db(err))
241            }
242        }
243    }
244}