docbox_database/
lib.rs

1use models::tenant::Tenant;
2use moka::{future::Cache, policy::EvictionPolicy};
3use serde::{Deserialize, Serialize};
4pub use sqlx::postgres::PgSslMode;
5pub use sqlx::{
6    PgPool, Postgres, Transaction,
7    postgres::{PgConnectOptions, PgPoolOptions},
8};
9
10use std::{error::Error, time::Duration};
11use thiserror::Error;
12use tracing::debug;
13
14pub use sqlx;
15pub use sqlx::PgExecutor as DbExecutor;
16
17pub mod create;
18pub mod migrations;
19pub mod models;
20
21/// Type of the database connection pool
22pub type DbPool = PgPool;
23
24/// Short type alias for a database error
25pub type DbErr = sqlx::Error;
26
27/// Type alias for a result where the error is a [DbErr]
28pub type DbResult<T> = Result<T, DbErr>;
29
30/// Type of a database transaction
31pub type DbTransaction<'c> = Transaction<'c, Postgres>;
32
/// Seconds in one hour, used to spell out the cache durations below
const SECS_PER_HOUR: u64 = 60 * 60;

/// Idle time after which a cached database pool is evicted (48 hours)
const DB_CACHE_DURATION: Duration = Duration::from_secs(48 * SECS_PER_HOUR);

/// Idle time after which cached database credentials are evicted (12 hours)
const DB_CONNECT_INFO_CACHE_DURATION: Duration = Duration::from_secs(12 * SECS_PER_HOUR);

/// Name of the root database
pub const ROOT_DATABASE_NAME: &str = "docbox";
41
42/// Cache for database pools
43pub struct DatabasePoolCache<S: DbSecretManager> {
44    /// Database host
45    host: String,
46
47    /// Database port
48    port: u16,
49
50    /// Name of the secrets manager secret that contains
51    /// the credentials for the root "docbox" database
52    root_secret_name: String,
53
54    /// Cache from the database name to the pool for that database
55    cache: Cache<String, DbPool>,
56
57    /// Cache for the connection info details, stores the last known
58    /// credentials and the instant that they were obtained at
59    connect_info_cache: Cache<String, DbSecrets>,
60
61    /// Secrets manager access to load credentials
62    secrets_manager: S,
63}
64
65/// Username and password for a specific database
66#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct DbSecrets {
68    pub username: String,
69    pub password: String,
70}
71
72#[derive(Debug, Error)]
73pub enum DbConnectErr {
74    #[error("database credentials not found in secrets manager")]
75    MissingCredentials,
76
77    #[error(transparent)]
78    SecretsManager(Box<dyn Error + Send + Sync + 'static>),
79
80    #[error(transparent)]
81    Db(#[from] DbErr),
82}
83
84pub trait DbSecretManager: Send + Sync {
85    fn get_secret(
86        &self,
87        name: &str,
88    ) -> impl Future<Output = Result<Option<DbSecrets>, DbConnectErr>> + Send;
89}
90
91impl<S> DatabasePoolCache<S>
92where
93    S: DbSecretManager,
94{
95    pub fn new(host: String, port: u16, root_secret_name: String, secrets_manager: S) -> Self {
96        let cache = Cache::builder()
97            .time_to_idle(DB_CACHE_DURATION)
98            .max_capacity(50)
99            .eviction_policy(EvictionPolicy::tiny_lfu())
100            .build();
101
102        let connect_info_cache = Cache::builder()
103            .time_to_idle(DB_CONNECT_INFO_CACHE_DURATION)
104            .max_capacity(50)
105            .eviction_policy(EvictionPolicy::tiny_lfu())
106            .build();
107
108        Self {
109            host,
110            port,
111            root_secret_name,
112            cache,
113            connect_info_cache,
114            secrets_manager,
115        }
116    }
117
118    /// Request a database pool for the root database
119    pub async fn get_root_pool(&self) -> Result<PgPool, DbConnectErr> {
120        self.get_pool(ROOT_DATABASE_NAME, &self.root_secret_name)
121            .await
122    }
123
124    /// Request a database pool for a specific tenant
125    pub async fn get_tenant_pool(&self, tenant: &Tenant) -> Result<PgPool, DbConnectErr> {
126        self.get_pool(&tenant.db_name, &tenant.db_secret_name).await
127    }
128
129    /// Empties all the caches
130    pub async fn flush(&self) {
131        // Clear cache
132        self.cache.invalidate_all();
133        self.connect_info_cache.invalidate_all();
134    }
135
136    /// Obtains a database pool connection to the database with the provided name
137    async fn get_pool(&self, db_name: &str, secret_name: &str) -> Result<PgPool, DbConnectErr> {
138        let cache_key = format!("{}-{}", db_name, secret_name);
139
140        if let Some(pool) = self.cache.get(&cache_key).await {
141            return Ok(pool);
142        }
143
144        let pool = self.create_pool(db_name, secret_name).await?;
145        self.cache.insert(cache_key, pool.clone()).await;
146
147        Ok(pool)
148    }
149
150    /// Obtains database connection info
151    async fn get_credentials(&self, secret_name: &str) -> Result<DbSecrets, DbConnectErr> {
152        if let Some(connect_info) = self.connect_info_cache.get(secret_name).await {
153            return Ok(connect_info);
154        }
155
156        // Load new credentials
157        let credentials = self
158            .secrets_manager
159            .get_secret(secret_name)
160            .await
161            .map_err(|err| DbConnectErr::SecretsManager(err.into()))?
162            .ok_or(DbConnectErr::MissingCredentials)?;
163
164        // Cache the credential
165        self.connect_info_cache
166            .insert(secret_name.to_string(), credentials.clone())
167            .await;
168
169        Ok(credentials)
170    }
171
172    /// Creates a database pool connection
173    async fn create_pool(&self, db_name: &str, secret_name: &str) -> Result<PgPool, DbConnectErr> {
174        debug!(?db_name, ?secret_name, "creating db pool connection");
175
176        let credentials = self.get_credentials(secret_name).await?;
177        let options = PgConnectOptions::new()
178            .host(&self.host)
179            .port(self.port)
180            .username(&credentials.username)
181            .password(&credentials.password)
182            .database(db_name);
183
184        match PgPoolOptions::new().connect_with(options).await {
185            // Success case
186            Ok(value) => Ok(value),
187            Err(err) => {
188                // Drop the connect info cache incase the credentials were wrong
189                self.connect_info_cache.remove(secret_name).await;
190                Err(DbConnectErr::Db(err))
191            }
192        }
193    }
194}