1#![forbid(unsafe_code)]
2#![forbid(clippy::unwrap_used)]
3#![deny(clippy::pedantic)]
4#![deny(clippy::get_unwrap)]
5#![allow(clippy::module_name_repetitions)]
6
7mod error;
8mod path_config_connection;
9mod path_role_create;
10mod path_roles;
11mod secret_creds;
12mod store;
13
14use std::sync::Arc;
15
16use covert_storage::{
17 migrator::{migration_scripts, MigrationError},
18 BackendStoragePool,
19};
20use error::{Error, ErrorType};
21use rust_embed::RustEmbed;
22use secret_creds::secret_creds_renew;
23use sqlx::{postgres::PgPoolOptions, PgPool, Pool, Postgres};
24use store::{connection::ConnectionStore, role::RoleStore};
25use tokio::sync::{RwLock, RwLockReadGuard};
26
27use covert_framework::{extract::Extension, read, revoke, update, Backend, Router};
28use covert_types::{
29 backend::{BackendCategory, BackendType},
30 psql::ConnectionConfig,
31};
32use tracing::debug;
33
34use self::{
35 path_config_connection::{path_connection_read, path_connection_write},
36 path_role_create::generate_role_credentials,
37 path_roles::path_role_create,
38 secret_creds::secret_creds_revoke,
39};
40
/// Migration scripts for this backend, taken from the `migrations/` folder via
/// `RustEmbed` and turned into the backend's migration list by
/// `migration_scripts::<Migrations>()` in `new_psql_backend`.
#[derive(RustEmbed)]
#[folder = "migrations/"]
struct Migrations;
44
/// Shared state of the PostgreSQL backend, attached to the router as an
/// `Extension` so route handlers can reach it.
pub struct Context {
    /// Pool for the configured Postgres server; `None` while no pool has been
    /// (successfully) created from the stored connection configuration.
    db: RwLock<Option<PgPool>>,
    /// Backend storage for the connection configuration.
    connection_repo: ConnectionStore,
    /// Backend storage for roles (see `RoleStore`).
    role_repo: RoleStore,
}
50
51#[tracing::instrument(skip_all)]
57pub async fn new_psql_backend(storage: BackendStoragePool) -> Result<Backend, MigrationError> {
58 let ctx = Arc::new(Context {
59 db: RwLock::default(),
60 connection_repo: ConnectionStore::new(storage.clone()),
61 role_repo: RoleStore::new(storage),
62 });
63
64 if ctx.set_pool().await.is_ok() {
66 debug!("Configured pool from previosuly stored connection configuration");
67 }
68
69 let router = Router::new()
70 .route(
71 "/config/connection",
72 read(path_connection_read)
73 .update(path_connection_write)
74 .create(path_connection_write),
75 )
76 .route("/creds/:name", update(generate_role_credentials))
77 .route(
78 "/roles/:name",
79 update(path_role_create).create(path_role_create),
80 )
81 .route(
82 "/creds",
83 revoke(secret_creds_revoke).renew(secret_creds_renew),
84 )
85 .layer(Extension(ctx))
86 .build()
87 .into_service();
88
89 let migrations = migration_scripts::<Migrations>()?;
90
91 Ok(Backend {
92 handler: router,
93 category: BackendCategory::Logical,
94 variant: BackendType::Postgres,
95 migrations,
96 })
97}
98
99impl Context {
100 #[tracing::instrument(skip_all)]
101 async fn handle_missing_pool_for_configured_connection<'a>(
102 &self,
103 ) -> Result<RwLockReadGuard<'a, PgPool>, Error> {
104 self.reset_db().await;
106 self.connection_repo.remove().await?;
107 Err(ErrorType::MissingConnection.into())
108 }
109
110 pub async fn pool(&self) -> Result<RwLockReadGuard<'_, PgPool>, Error> {
116 let pool_l = self.db.read().await;
117 match RwLockReadGuard::try_map(pool_l, |maybe_pool| match maybe_pool {
118 Some(pool) => Some(pool),
119 None => None,
120 }) {
121 Ok(res) => Ok(res),
122 Err(lock) => {
123 drop(lock);
124 match self.connection_repo.get().await? {
125 Some(_) => self.handle_missing_pool_for_configured_connection().await,
126 None => Err(ErrorType::MissingConnection.into()),
127 }
128 }
129 }
130 }
131
132 pub async fn set_pool(&self) -> Result<(), Error> {
139 self.reset_db().await;
141
142 let conn_config = self
143 .connection_repo
144 .get()
145 .await?
146 .ok_or(ErrorType::MissingConnection)?;
147 let pool = pool_from_config(&conn_config).await?;
148
149 let mut pool_wl = self.db.write().await;
150 *pool_wl = Some(pool);
151
152 Ok(())
153 }
154
155 async fn reset_db(&self) {
156 let mut pool = self.db.write().await;
157 if let Some(pool) = pool.as_ref() {
158 pool.close().await;
159 }
160 *pool = None;
161 }
162}
163
164pub(crate) async fn pool_from_config(config: &ConnectionConfig) -> Result<Pool<Postgres>, Error> {
165 let mut connection_url = config.connection_url.clone();
166
167 if connection_url.starts_with("postgres://") || connection_url.starts_with("postgresql://") {
169 if connection_url.contains('?') {
170 connection_url = format!("{connection_url}&timezone=utc");
171 } else {
172 connection_url = format!("{connection_url}?timezone=utc");
173 }
174 } else {
175 connection_url = format!("{connection_url} timezone=utc");
176 }
177
178 PgPoolOptions::new()
179 .max_connections(config.max_open_connections)
182 .test_before_acquire(true)
183 .connect(&connection_url)
184 .await
185 .map_err(|_| ErrorType::InvalidConnectionString.into())
186}