use sea_orm::{ConnectOptions, Database, DatabaseConnection};
use std::time::Duration;
use crate::error::{Error, IntoApiError};
/// Connection-pool settings consumed by `DatabaseConfig::connect`.
///
/// `Debug` is implemented by hand rather than derived so that a password
/// embedded in `url` (e.g. `postgres://user:password@host/db`) is masked as
/// `***` whenever the config is printed or logged.
#[derive(Clone)]
pub struct DatabaseConfig {
    /// Connection URL handed to `sea_orm::ConnectOptions::new`.
    pub url: String,
    /// Maximum number of pooled connections.
    pub max_connections: u32,
    /// Minimum number of pooled connections.
    pub min_connections: u32,
    /// Connect timeout, in seconds.
    pub connect_timeout: u64,
    /// Idle timeout, in seconds.
    pub idle_timeout: u64,
    /// Whether SQLx statement logging is enabled.
    pub sqlx_logging: bool,
}

impl std::fmt::Debug for DatabaseConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Replaces the password in `scheme://user:password@host...` with `***`.
        // URLs without a `scheme://` prefix, without userinfo, or without a
        // password component are returned unchanged.
        fn redact_password(url: &str) -> String {
            let authority_start = match url.find("://") {
                Some(idx) => idx + 3,
                None => return url.to_owned(),
            };
            let rest = &url[authority_start..];
            // The authority ends at the first path, query or fragment delimiter.
            let authority_end = rest
                .find(|c: char| matches!(c, '/' | '?' | '#'))
                .unwrap_or(rest.len());
            let authority = &rest[..authority_end];
            // Userinfo ends at the *last* '@' so an unescaped '@' inside the
            // password is still masked.
            let at = match authority.rfind('@') {
                Some(idx) => idx,
                None => return url.to_owned(),
            };
            let colon = match authority[..at].find(':') {
                Some(idx) => idx,
                None => return url.to_owned(),
            };
            format!(
                "{}***{}",
                &url[..authority_start + colon + 1],
                &url[authority_start + at..]
            )
        }

        f.debug_struct("DatabaseConfig")
            .field("url", &redact_password(&self.url))
            .field("max_connections", &self.max_connections)
            .field("min_connections", &self.min_connections)
            .field("connect_timeout", &self.connect_timeout)
            .field("idle_timeout", &self.idle_timeout)
            .field("sqlx_logging", &self.sqlx_logging)
            .finish()
    }
}
impl DatabaseConfig {
pub fn new(url: impl Into<String>) -> Self {
Self {
url: url.into(),
max_connections: 10,
min_connections: 1,
connect_timeout: 30,
idle_timeout: 600,
sqlx_logging: cfg!(debug_assertions),
}
}
pub fn from_env() -> Result<Self, std::io::Error> {
let url = std::env::var("DATABASE_URL").map_err(|_| {
std::io::Error::new(
std::io::ErrorKind::NotFound,
"DATABASE_URL environment variable not set",
)
})?;
let max_connections = std::env::var("DATABASE_MAX_CONNECTIONS")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(10);
let min_connections = std::env::var("DATABASE_MIN_CONNECTIONS")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(1);
let connect_timeout = std::env::var("DATABASE_CONNECT_TIMEOUT")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(30);
let idle_timeout = std::env::var("DATABASE_IDLE_TIMEOUT")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(600);
let sqlx_logging = std::env::var("DATABASE_LOGGING")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(cfg!(debug_assertions));
Ok(Self {
url,
max_connections,
min_connections,
connect_timeout,
idle_timeout,
sqlx_logging,
})
}
pub fn max_connections(mut self, n: u32) -> Self {
self.max_connections = n;
self
}
pub fn min_connections(mut self, n: u32) -> Self {
self.min_connections = n;
self
}
pub fn connect_timeout(mut self, secs: u64) -> Self {
self.connect_timeout = secs;
self
}
pub fn idle_timeout(mut self, secs: u64) -> Self {
self.idle_timeout = secs;
self
}
pub fn sqlx_logging(mut self, enabled: bool) -> Self {
self.sqlx_logging = enabled;
self
}
pub async fn connect(&self) -> Result<DatabaseConnection, DbError> {
let mut opts = ConnectOptions::new(&self.url);
opts.max_connections(self.max_connections)
.min_connections(self.min_connections)
.connect_timeout(Duration::from_secs(self.connect_timeout))
.idle_timeout(Duration::from_secs(self.idle_timeout))
.sqlx_logging(self.sqlx_logging);
Database::connect(opts).await.map_err(DbError)
}
}
/// Newtype wrapper around [`sea_orm::DbErr`] that can be converted into the
/// crate's API error type via `IntoApiError`.
#[derive(Debug)]
pub struct DbError(pub sea_orm::DbErr);
impl std::fmt::Display for DbError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
/// `DbError` is a transparent wrapper. Its `Display` output is the wrapped
/// `DbErr`'s message, so `source` skips that level and forwards to the wrapped
/// error's own source. Returning the `DbErr` itself would make error-chain
/// reporters print the same message twice.
impl std::error::Error for DbError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        std::error::Error::source(&self.0)
    }
}
impl IntoApiError for DbError {
fn into_api_error(self) -> Error {
use sea_orm::DbErr;
match &self.0 {
DbErr::RecordNotFound(msg) => Error::not_found(msg.clone()),
DbErr::RecordNotInserted => Error::internal("failed to insert record"),
DbErr::RecordNotUpdated => Error::internal("failed to update record"),
DbErr::Custom(msg) => Error::internal(msg.clone()),
DbErr::Query(err) => {
tracing::error!(error = %err, "database query error");
Error::internal("database query failed")
}
DbErr::Conn(err) => {
tracing::error!(error = %err, "database connection error");
Error::internal("database connection failed")
}
DbErr::Exec(err) => {
tracing::error!(error = %err, "database execution error");
Error::internal("database operation failed")
}
_ => {
tracing::error!(error = %self.0, "database error");
Error::internal("database error")
}
}
}
}
impl From<sea_orm::DbErr> for DbError {
fn from(err: sea_orm::DbErr) -> Self {
DbError(err)
}
}
/// Cloneable newtype around a SeaORM [`DatabaseConnection`].
///
/// The inner connection is reachable through `Db::conn`, `AsRef` and `Deref`.
#[derive(Debug, Clone)]
pub struct Db(DatabaseConnection);
impl Db {
pub fn new(conn: DatabaseConnection) -> Self {
Self(conn)
}
pub fn conn(&self) -> &DatabaseConnection {
&self.0
}
pub fn into_inner(self) -> DatabaseConnection {
self.0
}
}
impl AsRef<DatabaseConnection> for Db {
    // Same borrow as the inherent `Db::conn` accessor.
    fn as_ref(&self) -> &DatabaseConnection {
        self.conn()
    }
}
impl std::ops::Deref for Db {
type Target = DatabaseConnection;
fn deref(&self) -> &Self::Target {
&self.0
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_database_config_new() {
        let config = DatabaseConfig::new("postgres://localhost/test");
        assert_eq!(config.url, "postgres://localhost/test");
        assert_eq!(config.max_connections, 10);
        assert_eq!(config.min_connections, 1);
        // Pin the remaining defaults too, so a change to them is caught.
        assert_eq!(config.connect_timeout, 30);
        assert_eq!(config.idle_timeout, 600);
        assert_eq!(config.sqlx_logging, cfg!(debug_assertions));
    }

    #[test]
    fn test_database_config_builder() {
        let config = DatabaseConfig::new("postgres://localhost/test")
            .max_connections(50)
            .min_connections(5)
            .connect_timeout(60)
            .idle_timeout(300)
            .sqlx_logging(false);
        assert_eq!(config.max_connections, 50);
        assert_eq!(config.min_connections, 5);
        assert_eq!(config.connect_timeout, 60);
        assert_eq!(config.idle_timeout, 300);
        assert!(!config.sqlx_logging);
    }

    #[test]
    fn test_db_error_not_found() {
        let err = DbError(sea_orm::DbErr::RecordNotFound("user".to_string()));
        let api_err = err.into_api_error();
        assert_eq!(api_err.status, 404);
        assert_eq!(api_err.code, "NOT_FOUND");
    }

    #[test]
    fn test_db_error_custom() {
        let err = DbError(sea_orm::DbErr::Custom("something went wrong".to_string()));
        let api_err = err.into_api_error();
        assert_eq!(api_err.status, 500);
        assert_eq!(api_err.message, "something went wrong");
    }

    #[test]
    fn test_db_error_record_not_inserted() {
        let api_err = DbError(sea_orm::DbErr::RecordNotInserted).into_api_error();
        assert_eq!(api_err.status, 500);
        assert_eq!(api_err.message, "failed to insert record");
    }

    #[test]
    fn test_db_error_record_not_updated() {
        let api_err = DbError(sea_orm::DbErr::RecordNotUpdated).into_api_error();
        assert_eq!(api_err.status, 500);
        assert_eq!(api_err.message, "failed to update record");
    }

    #[test]
    fn test_db_error_from_and_display() {
        let err: DbError = sea_orm::DbErr::Custom("boom".to_string()).into();
        // `DbError` renders exactly like the error it wraps.
        assert_eq!(
            err.to_string(),
            sea_orm::DbErr::Custom("boom".to_string()).to_string()
        );
    }
}