// covert_psql/lib.rs

1#![forbid(unsafe_code)]
2#![forbid(clippy::unwrap_used)]
3#![deny(clippy::pedantic)]
4#![deny(clippy::get_unwrap)]
5#![allow(clippy::module_name_repetitions)]
6
7mod error;
8mod path_config_connection;
9mod path_role_create;
10mod path_roles;
11mod secret_creds;
12mod store;
13
14use std::sync::Arc;
15
16use covert_storage::{
17    migrator::{migration_scripts, MigrationError},
18    BackendStoragePool,
19};
20use error::{Error, ErrorType};
21use rust_embed::RustEmbed;
22use secret_creds::secret_creds_renew;
23use sqlx::{postgres::PgPoolOptions, PgPool, Pool, Postgres};
24use store::{connection::ConnectionStore, role::RoleStore};
25use tokio::sync::{RwLock, RwLockReadGuard};
26
27use covert_framework::{extract::Extension, read, revoke, update, Backend, Router};
28use covert_types::{
29    backend::{BackendCategory, BackendType},
30    psql::ConnectionConfig,
31};
32use tracing::debug;
33
34use self::{
35    path_config_connection::{path_connection_read, path_connection_write},
36    path_role_create::generate_role_credentials,
37    path_roles::path_role_create,
38    secret_creds::secret_creds_revoke,
39};
40
/// Migration scripts shipped with this backend, taken from the crate's
/// `migrations/` folder through the `RustEmbed` derive.
///
/// Loaded with `migration_scripts::<Migrations>()` in `new_psql_backend` and
/// returned to the caller in `Backend::migrations`.
#[derive(RustEmbed)]
#[folder = "migrations/"]
struct Migrations;
44
/// Shared state of the `PostgreSQL` secret engine.
///
/// `new_psql_backend` wraps it in an `Arc` and installs it on the router as an
/// `Extension`, so every route handler sees the same instance.
pub struct Context {
    /// Connection pool to the configured `PostgreSQL` server. `None` while no
    /// pool is configured; populated by `Context::set_pool`.
    db: RwLock<Option<PgPool>>,
    /// Store holding the connection configuration that pools are built from.
    connection_repo: ConnectionStore,
    /// Role store, backed by the same backend storage pool.
    role_repo: RoleStore,
}
50
51/// Returns a new `PostgreSQL` secret engine.
52///
53/// # Errors
54///
55/// Returns an error if it fails to read the migration scripts.
56#[tracing::instrument(skip_all)]
57pub async fn new_psql_backend(storage: BackendStoragePool) -> Result<Backend, MigrationError> {
58    let ctx = Arc::new(Context {
59        db: RwLock::default(),
60        connection_repo: ConnectionStore::new(storage.clone()),
61        role_repo: RoleStore::new(storage),
62    });
63
64    // Try to recover pool from the connection config if it is configured.
65    if ctx.set_pool().await.is_ok() {
66        debug!("Configured pool from previosuly stored connection configuration");
67    }
68
69    let router = Router::new()
70        .route(
71            "/config/connection",
72            read(path_connection_read)
73                .update(path_connection_write)
74                .create(path_connection_write),
75        )
76        .route("/creds/:name", update(generate_role_credentials))
77        .route(
78            "/roles/:name",
79            update(path_role_create).create(path_role_create),
80        )
81        .route(
82            "/creds",
83            revoke(secret_creds_revoke).renew(secret_creds_renew),
84        )
85        .layer(Extension(ctx))
86        .build()
87        .into_service();
88
89    let migrations = migration_scripts::<Migrations>()?;
90
91    Ok(Backend {
92        handler: router,
93        category: BackendCategory::Logical,
94        variant: BackendType::Postgres,
95        migrations,
96    })
97}
98
99impl Context {
100    #[tracing::instrument(skip_all)]
101    async fn handle_missing_pool_for_configured_connection<'a>(
102        &self,
103    ) -> Result<RwLockReadGuard<'a, PgPool>, Error> {
104        // Something is wrong with the pool, so close it and reset connection
105        self.reset_db().await;
106        self.connection_repo.remove().await?;
107        Err(ErrorType::MissingConnection.into())
108    }
109
110    /// Return a psql connection pool.
111    ///
112    /// # Errors
113    ///
114    /// Fails if the pool is not yet configured.
115    pub async fn pool(&self) -> Result<RwLockReadGuard<'_, PgPool>, Error> {
116        let pool_l = self.db.read().await;
117        match RwLockReadGuard::try_map(pool_l, |maybe_pool| match maybe_pool {
118            Some(pool) => Some(pool),
119            None => None,
120        }) {
121            Ok(res) => Ok(res),
122            Err(lock) => {
123                drop(lock);
124                match self.connection_repo.get().await? {
125                    Some(_) => self.handle_missing_pool_for_configured_connection().await,
126                    None => Err(ErrorType::MissingConnection.into()),
127                }
128            }
129        }
130    }
131
132    /// Set the psql connection pool.
133    ///
134    /// # Errors
135    ///
136    /// Fails if it fails to establish a connection to the database from the
137    /// connection configuration.
138    pub async fn set_pool(&self) -> Result<(), Error> {
139        // Reset any existing pool
140        self.reset_db().await;
141
142        let conn_config = self
143            .connection_repo
144            .get()
145            .await?
146            .ok_or(ErrorType::MissingConnection)?;
147        let pool = pool_from_config(&conn_config).await?;
148
149        let mut pool_wl = self.db.write().await;
150        *pool_wl = Some(pool);
151
152        Ok(())
153    }
154
155    async fn reset_db(&self) {
156        let mut pool = self.db.write().await;
157        if let Some(pool) = pool.as_ref() {
158            pool.close().await;
159        }
160        *pool = None;
161    }
162}
163
164pub(crate) async fn pool_from_config(config: &ConnectionConfig) -> Result<Pool<Postgres>, Error> {
165    let mut connection_url = config.connection_url.clone();
166
167    // Ensure timezone is set to UTC for all the connections
168    if connection_url.starts_with("postgres://") || connection_url.starts_with("postgresql://") {
169        if connection_url.contains('?') {
170            connection_url = format!("{connection_url}&timezone=utc");
171        } else {
172            connection_url = format!("{connection_url}?timezone=utc");
173        }
174    } else {
175        connection_url = format!("{connection_url} timezone=utc");
176    }
177
178    PgPoolOptions::new()
179        // Set some connection pool settings. We don't need much of this,
180        // since the request rate shouldn't be high.
181        .max_connections(config.max_open_connections)
182        .test_before_acquire(true)
183        .connect(&connection_url)
184        .await
185        .map_err(|_| ErrorType::InvalidConnectionString.into())
186}