mod backend;
mod schema_validation;
pub use backend::{AnyConnection, AnyPool, BackendType};
pub use schema_validation::{
escape_password, validate_schema_name, validate_username, SchemaError, UsernameError,
};
#[cfg(feature = "postgres")]
pub use backend::{DbConnection, DbConnectionManager, DbPool};
#[cfg(all(feature = "sqlite", not(feature = "postgres")))]
pub use backend::{DbConnection, DbPool};
use thiserror::Error;
use tracing::info;
use url::Url;
#[cfg(feature = "postgres")]
use deadpool_diesel::postgres::{Manager as PgManager, Pool as PgPool, Runtime as PgRuntime};
#[cfg(feature = "sqlite")]
use deadpool_diesel::sqlite::{
Manager as SqliteManager, Pool as SqlitePool, Runtime as SqliteRuntime,
};
/// Errors produced while constructing or configuring a [`Database`].
///
/// The `Display` text of each variant comes from its `#[error(...)]` attribute.
#[derive(Debug, Error)]
pub enum DatabaseError {
/// The connection pool for a backend could not be built.
#[error("Failed to create {backend} connection pool: {source}")]
PoolCreation {
/// Human-readable backend name; this file passes `"PostgreSQL"` or `"SQLite"`.
backend: &'static str,
/// The underlying error returned by the pool builder, boxed.
#[source]
source: Box<dyn std::error::Error + Send + Sync>,
},
/// The connection string could not be parsed as a URL.
/// Converted automatically from `url::ParseError` via `?`.
#[error("Invalid database URL: {0}")]
InvalidUrl(#[from] url::ParseError),
/// The supplied schema name was rejected by schema validation.
/// Converted automatically from `SchemaError` via `?`.
#[error("Schema validation failed: {0}")]
Schema(#[from] SchemaError),
/// A migration failure described by a message string.
/// NOTE(review): nothing in the visible part of this file constructs this
/// variant; the migration methods here return `String` errors instead.
#[error("Migration failed: {0}")]
Migration(String),
}
/// A database handle bundling a connection pool with the backend it targets
/// and an optional schema name.
#[derive(Clone)]
pub struct Database {
/// Connection pool for the selected backend.
pool: AnyPool,
/// Backend this handle was created for.
backend: BackendType,
/// Schema name supplied at construction, stored after passing
/// `validate_schema_name`; `None` when no schema was requested.
schema: Option<String>,
}
/// Hand-written `Debug` output: prints the backend and schema, and renders the
/// pool as a fixed placeholder string instead of formatting the pool itself.
impl std::fmt::Debug for Database {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            backend,
            schema,
            pool: _,
        } = self;

        let mut out = f.debug_struct("Database");
        out.field("backend", backend);
        out.field("schema", schema);
        out.field("pool", &"<connection pool>");
        out.finish()
    }
}
impl Database {
/// Creates a [`Database`] with no schema.
///
/// Delegates to [`Database::new_with_schema`] with `None` as the schema.
///
/// # Panics
///
/// Panics if the connection pool cannot be created, because the delegate
/// unwraps the result of [`Database::try_new_with_schema`].
pub fn new(connection_string: &str, database_name: &str, max_size: u32) -> Self {
    let no_schema: Option<&str> = None;
    Self::new_with_schema(connection_string, database_name, max_size, no_schema)
}
/// Creates a [`Database`], optionally bound to `schema`.
///
/// This is the panicking counterpart of [`Database::try_new_with_schema`].
///
/// # Panics
///
/// Panics with "Failed to create database connection pool" if
/// [`Database::try_new_with_schema`] returns an error.
pub fn new_with_schema(
    connection_string: &str,
    database_name: &str,
    max_size: u32,
    schema: Option<&str>,
) -> Self {
    let created =
        Self::try_new_with_schema(connection_string, database_name, max_size, schema);
    created.expect("Failed to create database connection pool")
}
/// Fallible constructor: validates the schema, then builds the connection pool
/// for the backend selected by the compiled-in features.
///
/// * `connection_string` - passed to `BackendType::from_url` to pick the backend
///   and used to build the backend-specific connection URL.
/// * `_database_name` - used as the URL path for PostgreSQL; not used for SQLite.
/// * `max_size` - maximum pool size for PostgreSQL; not applied to SQLite, whose
///   pool size is fixed at 1.
/// * `schema` - optional schema name; validated with `validate_schema_name`
///   before any pool is built.
///
/// When only one backend feature is enabled, that backend is always used and
/// the detected backend is ignored.
///
/// # Errors
///
/// * [`DatabaseError::Schema`] if the schema name fails validation.
/// * [`DatabaseError::InvalidUrl`] if the PostgreSQL connection string does not parse.
/// * [`DatabaseError::PoolCreation`] if the pool builder fails.
pub fn try_new_with_schema(
    connection_string: &str,
    _database_name: &str,
    max_size: u32,
    schema: Option<&str>,
) -> Result<Self, DatabaseError> {
    let backend = BackendType::from_url(connection_string);
    // Validate first so an invalid schema fails before any pool is created.
    let validated_schema = schema
        .map(|s| validate_schema_name(s).map(|v| v.to_string()))
        .transpose()?;
    #[cfg(all(feature = "postgres", feature = "sqlite"))]
    match backend {
        BackendType::Postgres => {
            let pool = Self::build_postgres_pool(
                connection_string,
                _database_name,
                max_size,
                validated_schema.as_deref(),
            )?;
            Ok(Self {
                pool: AnyPool::Postgres(pool),
                backend,
                schema: validated_schema,
            })
        }
        BackendType::Sqlite => {
            let pool = Self::build_sqlite_pool(connection_string)?;
            Ok(Self {
                pool: AnyPool::Sqlite(pool),
                backend,
                schema: validated_schema,
            })
        }
    }
    #[cfg(all(feature = "postgres", not(feature = "sqlite")))]
    {
        // Only PostgreSQL is compiled in, so the detected backend is ignored.
        let _ = backend;
        let pool = Self::build_postgres_pool(
            connection_string,
            _database_name,
            max_size,
            validated_schema.as_deref(),
        )?;
        return Ok(Self {
            pool,
            backend: BackendType::Postgres,
            schema: validated_schema,
        });
    }
    #[cfg(all(feature = "sqlite", not(feature = "postgres")))]
    {
        // Only SQLite is compiled in, so the detected backend is ignored.
        let _ = backend;
        let pool = Self::build_sqlite_pool(connection_string)?;
        return Ok(Self {
            pool,
            backend: BackendType::Sqlite,
            schema: validated_schema,
        });
    }
}

/// Builds the PostgreSQL pool for `database_name` and logs its initialization.
///
/// `schema` is only used for the log message.
#[cfg(feature = "postgres")]
fn build_postgres_pool(
    connection_string: &str,
    database_name: &str,
    max_size: u32,
    schema: Option<&str>,
) -> Result<PgPool, DatabaseError> {
    let connection_url = Self::build_postgres_url(connection_string, database_name)?;
    let manager = PgManager::new(connection_url, PgRuntime::Tokio1);
    let pool = PgPool::builder(manager)
        .max_size(max_size as usize)
        .build()
        .map_err(|e| DatabaseError::PoolCreation {
            backend: "PostgreSQL",
            source: Box::new(e),
        })?;
    info!(
        "PostgreSQL connection pool initialized{}",
        schema
            .map(|s| format!(" with schema '{}'", s))
            .unwrap_or_default()
    );
    Ok(pool)
}

/// Builds the SQLite pool and logs its initialization.
///
/// The pool size is fixed at 1; the caller's `max_size` is not applied here.
/// NOTE(review): presumably this serializes access to the SQLite file - confirm.
#[cfg(feature = "sqlite")]
fn build_sqlite_pool(connection_string: &str) -> Result<SqlitePool, DatabaseError> {
    const SQLITE_POOL_SIZE: usize = 1;
    let connection_url = Self::build_sqlite_url(connection_string);
    let manager = SqliteManager::new(connection_url, SqliteRuntime::Tokio1);
    let pool = SqlitePool::builder(manager)
        .max_size(SQLITE_POOL_SIZE)
        .build()
        .map_err(|e| DatabaseError::PoolCreation {
            backend: "SQLite",
            source: Box::new(e),
        })?;
    info!(
        "SQLite connection pool initialized (size: {})",
        SQLITE_POOL_SIZE
    );
    Ok(pool)
}
pub fn backend(&self) -> BackendType {
self.backend
}
/// Returns the schema name stored at construction, or `None` if there is none.
pub fn schema(&self) -> Option<&str> {
    let name = self.schema.as_ref()?;
    Some(name.as_str())
}
pub fn pool(&self) -> AnyPool {
self.pool.clone()
}
/// Returns a clone of the underlying connection pool.
///
/// Despite the name, this does not check out a connection; it returns the
/// same value as [`Database::pool`].
pub fn get_connection(&self) -> AnyPool {
    self.pool()
}
pub fn close(&self) {
tracing::info!("Closing database connection pool");
self.pool.close();
}
/// Parses `base_url` and replaces its path with `database_name`, returning
/// the resulting URL as a string.
///
/// # Errors
///
/// Returns the `url::ParseError` if `base_url` is not a valid URL.
fn build_postgres_url(base_url: &str, database_name: &str) -> Result<String, url::ParseError> {
    Url::parse(base_url).map(|mut parsed| {
        parsed.set_path(database_name);
        parsed.to_string()
    })
}
/// Converts a SQLite connection string into the path handed to the pool manager.
///
/// A single leading `sqlite://` prefix is removed if present; any other input
/// is returned unchanged.
fn build_sqlite_url(connection_string: &str) -> String {
    connection_string
        .strip_prefix("sqlite://")
        .unwrap_or(connection_string)
        .to_string()
}
pub async fn run_migrations(&self) -> Result<(), String> {
use diesel_migrations::MigrationHarness;
#[cfg(all(feature = "postgres", feature = "sqlite"))]
match &self.pool {
AnyPool::Postgres(pool) => {
let conn = pool.get().await.map_err(|e| e.to_string())?;
conn.interact(|conn| {
conn.run_pending_migrations(crate::database::POSTGRES_MIGRATIONS)
.map(|_| ())
.map_err(|e| format!("Failed to run PostgreSQL migrations: {}", e))
})
.await
.map_err(|e| format!("Failed to run migrations: {}", e))??;
}
AnyPool::Sqlite(pool) => {
let conn = pool.get().await.map_err(|e| e.to_string())?;
conn.interact(|conn| {
use diesel::prelude::*;
diesel::sql_query("PRAGMA journal_mode=WAL;")
.execute(conn)
.map_err(|e| format!("Failed to set WAL mode: {}", e))?;
diesel::sql_query("PRAGMA busy_timeout=30000;")
.execute(conn)
.map_err(|e| format!("Failed to set busy_timeout: {}", e))?;
conn.run_pending_migrations(crate::database::SQLITE_MIGRATIONS)
.map(|_| ())
.map_err(|e| format!("Failed to run SQLite migrations: {}", e))
})
.await
.map_err(|e| format!("Failed to run migrations: {}", e))??;
}
}
#[cfg(all(feature = "postgres", not(feature = "sqlite")))]
{
let conn = self.pool.get().await.map_err(|e| e.to_string())?;
conn.interact(|conn| {
conn.run_pending_migrations(crate::database::POSTGRES_MIGRATIONS)
.map(|_| ())
.map_err(|e| format!("Failed to run PostgreSQL migrations: {}", e))
})
.await
.map_err(|e| format!("Failed to run migrations: {}", e))?
.map_err(|e| e)?;
}
#[cfg(all(feature = "sqlite", not(feature = "postgres")))]
{
let conn = self.pool.get().await.map_err(|e| e.to_string())?;
conn.interact(|conn| {
use diesel::prelude::*;
diesel::sql_query("PRAGMA journal_mode=WAL;")
.execute(conn)
.map_err(|e| format!("Failed to set WAL mode: {}", e))?;
diesel::sql_query("PRAGMA busy_timeout=30000;")
.execute(conn)
.map_err(|e| format!("Failed to set busy_timeout: {}", e))?;
conn.run_pending_migrations(crate::database::SQLITE_MIGRATIONS)
.map(|_| ())
.map_err(|e| format!("Failed to run SQLite migrations: {}", e))
})
.await
.map_err(|e| format!("Failed to run migrations: {}", e))?
.map_err(|e| e)?;
}
Ok(())
}
/// Creates `schema` if it does not exist, points the connection's
/// `search_path` at it (followed by `public`), and runs the PostgreSQL
/// migrations on that same connection.
///
/// The schema name is validated with `validate_schema_name` before it is
/// interpolated into SQL.
///
/// # Errors
///
/// Returns a message string if validation fails, if the backend is SQLite,
/// if no connection can be obtained, or if any of the three steps fails.
#[cfg(feature = "postgres")]
pub async fn setup_schema(&self, schema: &str) -> Result<(), String> {
    use diesel::prelude::*;

    let validated_schema = validate_schema_name(schema).map_err(|e| e.to_string())?;

    #[cfg(all(feature = "postgres", feature = "sqlite"))]
    let pool = match &self.pool {
        AnyPool::Postgres(pool) => pool,
        AnyPool::Sqlite(_) => {
            return Err("Schema setup is not supported for SQLite".to_string());
        }
    };
    #[cfg(all(feature = "postgres", not(feature = "sqlite")))]
    let pool = &self.pool;

    let conn = pool.get().await.map_err(|e| e.to_string())?;

    // Build both statements up front so each closure can own its SQL string.
    let schema_name = validated_schema.to_string();
    let create_schema_sql = format!("CREATE SCHEMA IF NOT EXISTS {}", schema_name);
    let set_search_path_sql = format!("SET search_path TO {}, public", schema_name);

    // Step 1: create the schema.
    conn.interact(move |pg| diesel::sql_query(create_schema_sql).execute(pg))
        .await
        .map_err(|e| format!("Failed to create schema: {}", e))?
        .map_err(|e| format!("Failed to create schema: {}", e))?;

    // Step 2: switch the search path on this connection.
    conn.interact(move |pg| diesel::sql_query(set_search_path_sql).execute(pg))
        .await
        .map_err(|e| format!("Failed to set search path: {}", e))?
        .map_err(|e| format!("Failed to set search path: {}", e))?;

    // Step 3: run the migrations on the same connection.
    conn.interact(|pg| {
        use diesel_migrations::MigrationHarness;
        pg.run_pending_migrations(crate::database::POSTGRES_MIGRATIONS)
            .map(|_| ())
            .map_err(|e| format!("Failed to run migrations: {}", e))
    })
    .await
    .map_err(|e| format!("Failed to run migrations in schema: {}", e))??;

    info!("Schema '{}' set up successfully", schema);
    Ok(())
}
/// Checks out a PostgreSQL connection and, when this handle has a schema,
/// attempts to set the connection's `search_path` to that schema followed by
/// `public`.
///
/// Setting the search path is best effort: the connection is returned even if
/// it fails. A failure is logged as a warning rather than silently discarded,
/// because the returned connection will then not be pointing at the schema.
///
/// # Errors
///
/// Returns the pool error if no connection can be obtained.
///
/// # Panics
///
/// With both backends compiled in, panics if this handle wraps a SQLite pool.
#[cfg(feature = "postgres")]
pub async fn get_connection_with_schema(
    &self,
) -> Result<
    deadpool::managed::Object<PgManager>,
    deadpool::managed::PoolError<deadpool_diesel::Error>,
> {
    use diesel::prelude::*;
    #[cfg(all(feature = "postgres", feature = "sqlite"))]
    let pool = match &self.pool {
        AnyPool::Postgres(pool) => pool,
        AnyPool::Sqlite(_) => {
            panic!("get_connection_with_schema called on SQLite backend");
        }
    };
    #[cfg(all(feature = "postgres", not(feature = "sqlite")))]
    let pool = &self.pool;
    let conn = pool.get().await?;
    if let Some(schema) = self.schema.as_deref() {
        // Re-validate before interpolating the name into SQL.
        match validate_schema_name(schema) {
            Ok(validated) => {
                let schema_name = validated.to_string();
                let outcome = conn
                    .interact(move |conn| {
                        let set_search_path_sql =
                            format!("SET search_path TO {}, public", schema_name);
                        diesel::sql_query(&set_search_path_sql).execute(conn)
                    })
                    .await;
                match outcome {
                    Ok(Ok(_)) => {}
                    // The statement itself was rejected by the database.
                    Ok(Err(e)) => {
                        tracing::warn!(
                            "Failed to set search path to schema '{}': {}",
                            schema,
                            e
                        );
                    }
                    // The blocking task running the statement failed.
                    Err(e) => {
                        tracing::warn!(
                            "Failed to set search path to schema '{}': {}",
                            schema,
                            e
                        );
                    }
                }
            }
            Err(e) => {
                tracing::warn!(
                    "Schema '{}' failed validation; search path not set: {}",
                    schema,
                    e
                );
            }
        }
    }
    Ok(conn)
}
/// Checks out a PostgreSQL connection.
///
/// This is an alias for [`Database::get_connection_with_schema`] and shares
/// its errors and panics.
#[cfg(feature = "postgres")]
pub async fn get_postgres_connection(
    &self,
) -> Result<
    deadpool::managed::Object<PgManager>,
    deadpool::managed::PoolError<deadpool_diesel::Error>,
> {
    let checkout = self.get_connection_with_schema();
    checkout.await
}
/// Checks out a SQLite connection and attempts to apply
/// `PRAGMA journal_mode=WAL` and `PRAGMA busy_timeout=30000` to it.
///
/// The PRAGMAs are best effort: both are always attempted and the connection
/// is returned even if they fail. Failures are logged as warnings rather than
/// silently discarded.
///
/// # Errors
///
/// Returns the pool error if no connection can be obtained.
///
/// # Panics
///
/// With both backends compiled in, panics if this handle wraps a PostgreSQL pool.
#[cfg(feature = "sqlite")]
pub async fn get_sqlite_connection(
    &self,
) -> Result<
    deadpool::managed::Object<SqliteManager>,
    deadpool::managed::PoolError<deadpool_diesel::Error>,
> {
    #[cfg(all(feature = "postgres", feature = "sqlite"))]
    let pool = match &self.pool {
        AnyPool::Sqlite(pool) => pool,
        AnyPool::Postgres(_) => {
            panic!("get_sqlite_connection called on PostgreSQL backend");
        }
    };
    #[cfg(all(feature = "sqlite", not(feature = "postgres")))]
    let pool = &self.pool;
    let conn = pool.get().await?;
    let pragma_outcome = conn
        .interact(|conn| {
            use diesel::prelude::*;
            // Run both statements unconditionally; a failure of the first
            // must not prevent the second from being attempted.
            let wal = diesel::sql_query("PRAGMA journal_mode=WAL;").execute(conn);
            let busy_timeout = diesel::sql_query("PRAGMA busy_timeout=30000;").execute(conn);
            (wal, busy_timeout)
        })
        .await;
    match pragma_outcome {
        Ok((wal, busy_timeout)) => {
            if let Err(e) = wal {
                tracing::warn!("Failed to set SQLite journal_mode=WAL: {}", e);
            }
            if let Err(e) = busy_timeout {
                tracing::warn!("Failed to set SQLite busy_timeout=30000: {}", e);
            }
        }
        // The blocking task running the PRAGMAs failed.
        Err(e) => {
            tracing::warn!("Failed to apply SQLite PRAGMAs: {}", e);
        }
    }
    Ok(conn)
}
}
#[cfg(test)]
mod tests {
use super::*;
/// Checks how `Url` parses PostgreSQL-style connection strings after the path
/// is replaced with a database name, the same operation `build_postgres_url`
/// performs.
#[test]
fn test_postgres_url_parsing_scenarios() {
    // Parse `base` and point it at the `test_db` database.
    let with_test_db = |base: &str| {
        let mut parsed = Url::parse(base).unwrap();
        parsed.set_path("test_db");
        parsed
    };

    // Credentials and an explicit port.
    let full = with_test_db("postgres://postgres:postgres@localhost:5432");
    assert_eq!(full.path(), "/test_db");
    assert_eq!(full.scheme(), "postgres");
    assert_eq!(full.host_str(), Some("localhost"));
    assert_eq!(full.port(), Some(5432));
    assert_eq!(full.username(), "postgres");
    assert_eq!(full.password(), Some("postgres"));

    // No port.
    let no_port = with_test_db("postgres://postgres:postgres@localhost");
    assert_eq!(no_port.port(), None);

    // No credentials.
    let no_credentials = with_test_db("postgres://localhost:5432");
    assert_eq!(no_credentials.username(), "");
    assert_eq!(no_credentials.password(), None);

    // Input that is not a URL is rejected.
    assert!(Url::parse("not-a-url").is_err());
}
/// `build_sqlite_url` leaves plain paths untouched and strips a leading
/// `sqlite://` prefix.
#[test]
fn test_sqlite_connection_strings() {
    // (input, expected)
    let cases = [
        ("/path/to/database.db", "/path/to/database.db"),
        (":memory:", ":memory:"),
        ("./database.db", "./database.db"),
        ("sqlite:///path/to/db.sqlite", "/path/to/db.sqlite"),
    ];
    for &(input, expected) in cases.iter() {
        assert_eq!(Database::build_sqlite_url(input), expected);
    }
}
/// `BackendType::from_url` maps each connection-string form to the expected
/// backend, checked per compiled-in feature.
#[test]
fn test_backend_type_detection() {
    #[cfg(feature = "postgres")]
    {
        let postgres_urls = ["postgres://localhost/db", "postgresql://localhost/db"];
        for &url in postgres_urls.iter() {
            assert_eq!(
                BackendType::from_url(url),
                BackendType::Postgres,
                "url: {}",
                url
            );
        }
    }
    #[cfg(feature = "sqlite")]
    {
        let sqlite_urls = [
            "sqlite:///path/to/db",
            "/absolute/path.db",
            "./relative/path.db",
            ":memory:",
            "database.sqlite",
            "database.sqlite3",
            "file:test?mode=memory&cache=shared",
            "file:cloacina_test?mode=memory&cache=shared",
        ];
        for &url in sqlite_urls.iter() {
            assert_eq!(
                BackendType::from_url(url),
                BackendType::Sqlite,
                "url: {}",
                url
            );
        }
    }
}
}