use diesel::{ConnectionError, ConnectionResult};
use diesel_async::pooled_connection::bb8::Pool;
use diesel_async::pooled_connection::{AsyncDieselConnectionManager, ManagerConfig};
use diesel_async::{AsyncMigrationHarness, AsyncPgConnection};
use diesel_migrations::{EmbeddedMigrations, MigrationHarness as _, embed_migrations};
use futures_util::FutureExt;
use futures_util::future::BoxFuture;
use rustls::pki_types::CertificateDer;
use rustls::pki_types::pem::PemObject;
use rustls::{ClientConfig, RootCertStore};
use std::env::var;
use std::fs::read;
use tokio::time::Duration;
use vial_shared::{CreateSecretRequest, EncryptedPayload};
use crate::db::models::Secret;
use crate::errors::ServerError;
/// Cloneable handle through which the rest of the server reaches the database.
///
/// Wraps a bb8 pool of [`AsyncPgConnection`]s; every query method on this type
/// checks a connection out of that pool. Clones are handed to the background
/// cleanup tasks.
#[derive(Clone)]
pub struct Handler {
// Connection pool; constructed and configured in `get_connection`.
conn: Pool<AsyncPgConnection>,
}
/// Migrations embedded by `embed_migrations!` from the `src/migrations` directory.
///
/// `get_connection` applies the pending ones each time it runs.
pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("src/migrations");
/// Builds the database [`Handler`] for the PostgreSQL server described by `url`.
///
/// Steps, in order:
/// 1. make sure a process-wide rustls crypto provider is installed,
/// 2. build a bb8 pool (max 10 connections, at least 5 idle, 24 h maximum
///    lifetime, 2 min idle timeout) whose connections are opened by
///    [`establish_connection`],
/// 3. run the pending [`MIGRATIONS`] on a dedicated connection opened outside
///    the pool,
/// 4. spawn the background task that periodically clears expired secrets.
///
/// # Panics
///
/// Panics if the pool cannot be built, if the migration connection cannot be
/// established, or if running the migrations fails.
pub async fn get_connection(url: &str) -> Handler {
    // `install_default` only reports an error when a provider has already been
    // installed for this process (e.g. this function ran before, as happens in
    // test binaries). An existing provider is perfectly usable, so a repeated
    // call must not bring the process down.
    let _ = rustls::crypto::ring::default_provider().install_default();

    // Every pooled connection is opened through our TLS-aware setup function.
    let mut config = ManagerConfig::default();
    config.custom_setup = Box::new(establish_connection);
    let mgr = AsyncDieselConnectionManager::<AsyncPgConnection>::new_with_config(url, config);

    let conn = Pool::builder()
        .max_size(10)
        .min_idle(Some(5))
        .max_lifetime(Some(Duration::from_secs(60 * 60 * 24)))
        .idle_timeout(Some(Duration::from_secs(60 * 2)))
        .build(mgr)
        .await
        .unwrap_or_else(|e| panic!("Failed to create DB connection. Error: {e}"));

    {
        // Migrations run on their own connection, which is released as soon as
        // they are done.
        let async_connection = establish_connection(url)
            .await
            .expect("Failed to establish_connection to DB");
        let mut harness = AsyncMigrationHarness::new(async_connection);
        // NOTE(review): this call is synchronous (no `.await`); presumably the
        // harness blocks the current thread while migrating — confirm the
        // runtime flavour this requires.
        harness
            .run_pending_migrations(MIGRATIONS)
            .expect("Failed to run migrations");
        drop(harness.into_inner());
    }

    let handler = Handler { conn };

    // The cleanup loop never returns, so it gets its own task and its own
    // clone of the handler.
    let handler_clone = handler.clone();
    tokio::spawn(async move {
        handler_clone.initiate_expired_cleanup().await;
    });

    handler
}
fn establish_connection(config: &str) -> BoxFuture<'_, ConnectionResult<AsyncPgConnection>> {
let fut = async {
let mut root_store = RootCertStore::empty();
if let Ok(cert_location) = var("CERT_LOCATION") {
let file_bytes = read(&cert_location)
.unwrap_or_else(|e| panic!("Failed to read {cert_location}. Error: {e}"));
let cert = CertificateDer::from_pem_slice(&file_bytes)
.unwrap_or_else(|e| panic!("Failed to create cert. Error: {e}"));
root_store.add(cert).unwrap();
}
let rustls_config = ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth();
let tls = tokio_postgres_rustls::MakeRustlsConnect::new(rustls_config);
let (client, conn) = tokio_postgres::connect(config, tls)
.await
.map_err(|e| ConnectionError::BadConnection(e.to_string()))?;
AsyncPgConnection::try_from_client_and_connection(client, conn).await
};
fut.boxed()
}
impl Handler {
    /// Looks up the secret stored under `id`.
    ///
    /// Returns `Ok(None)` when `Secret::get_secret` yields nothing; otherwise
    /// the payload obtained through `Secret::get_payload`.
    ///
    /// # Errors
    ///
    /// [`ServerError::DatabaseError`] if no pooled connection can be obtained
    /// or the lookup fails.
    pub async fn get_secret(&self, id: &str) -> Result<Option<EncryptedPayload>, ServerError> {
        let mut conn = self
            .conn
            .get()
            .await
            .map_err(|e| ServerError::DatabaseError(e.to_string()))?;
        Secret::get_secret(id, &mut conn)
            .await
            .map_err(|e| ServerError::DatabaseError(e.to_string()))
            .map(|opt| opt.map(Secret::get_payload))
    }

    /// Runs `Secret::clear_expired` on a pooled connection, discarding whatever
    /// value it reports on success.
    ///
    /// # Errors
    ///
    /// [`ServerError::DatabaseError`] if no pooled connection can be obtained
    /// or the cleanup statement fails.
    pub async fn clear_expired(&self) -> Result<(), ServerError> {
        let mut conn = self
            .conn
            .get()
            .await
            .map_err(|e| ServerError::DatabaseError(e.to_string()))?;
        Secret::clear_expired(&mut conn)
            .await
            .map_err(|e| ServerError::DatabaseError(e.to_string()))?;
        Ok(())
    }

    /// Runs `Secret::clear_expired_days` with the given `days` on a pooled
    /// connection, discarding whatever value it reports on success.
    ///
    /// # Errors
    ///
    /// [`ServerError::DatabaseError`] if no pooled connection can be obtained
    /// or the cleanup statement fails.
    pub async fn clear_expired_days(&self, days: i32) -> Result<(), ServerError> {
        let mut conn = self
            .conn
            .get()
            .await
            .map_err(|e| ServerError::DatabaseError(e.to_string()))?;
        Secret::clear_expired_days(days, &mut conn)
            .await
            .map_err(|e| ServerError::DatabaseError(e.to_string()))?;
        Ok(())
    }

    /// Creates a secret from `new_secret`, stores it and returns its id.
    ///
    /// # Errors
    ///
    /// Propagates the error of `Secret::new` when the request is rejected, and
    /// returns [`ServerError::DatabaseError`] if no pooled connection can be
    /// obtained or the insert fails.
    pub async fn new_secret(&self, new_secret: CreateSecretRequest) -> Result<String, ServerError> {
        // Build (and thereby validate) the secret before touching the pool, so
        // a rejected request neither occupies a connection nor has to wait for
        // one when the pool is exhausted.
        let secret = Secret::new(
            new_secret.ciphertext,
            new_secret.expires_at,
            new_secret.max_views,
        )?;
        let secret_id = secret.get_id();

        let mut conn = self
            .conn
            .get()
            .await
            .map_err(|e| ServerError::DatabaseError(e.to_string()))?;
        secret
            .insert(&mut conn)
            .await
            .map_err(|e| ServerError::DatabaseError(e.to_string()))?;
        Ok(secret_id)
    }

    /// Spawns a detached background task that calls
    /// [`Handler::clear_expired_days`] with `days` once every 60 seconds, and
    /// returns right after spawning it.
    ///
    /// The first run happens after the first 60-second sleep. Failures are
    /// printed to stdout and the loop keeps going; the task has no exit
    /// condition.
    pub async fn initiate_days_cleanup(&self, days: i32) {
        // The task must own its handler because it outlives this call.
        let self_clone = self.clone();
        tokio::spawn(async move {
            loop {
                tokio::time::sleep(Duration::from_secs(60)).await;
                if let Err(e) = self_clone.clear_expired_days(days).await {
                    println!("Failed to clear expired secrets. Error: {e}");
                }
            }
        });
    }

    /// Calls [`Handler::clear_expired`] once every 60 seconds, forever.
    ///
    /// This future never completes, so the caller is expected to run it on its
    /// own task (as `get_connection` does). Failures are printed to stdout and
    /// the loop keeps going.
    async fn initiate_expired_cleanup(&self) {
        loop {
            tokio::time::sleep(Duration::from_secs(60)).await;
            if let Err(e) = self.clear_expired().await {
                println!("Failed to clear expired secrets. Error: {e}");
            }
        }
    }
}