use crate::database::config::redact_database_url;
use crate::error::{Result, TidewayError};
use crate::traits::database::{DatabaseConnection, DatabasePool};
use async_trait::async_trait;
use sea_orm::{ConnectOptions, Database, DatabaseConnection as SeaOrmConnection};
use std::sync::Arc;
use std::time::Duration;
/// Wrapper around a SeaORM connection handle.
///
/// The wrapper exists so that a SeaORM connection can be handed out behind
/// the crate's own [`DatabaseConnection`] trait.
pub struct SeaOrmConnectionWrapper {
/// The wrapped SeaORM connection handle.
pub conn: SeaOrmConnection,
}
impl DatabaseConnection for SeaOrmConnectionWrapper {
/// Reports whether the connection is considered valid.
///
/// This implementation unconditionally returns `true`; it does not
/// contact the database or inspect the wrapped connection.
fn is_valid(&self) -> bool {
true
}
}
/// Lets a `SeaOrmConnectionWrapper` be used wherever a reference to the
/// underlying SeaORM connection is expected.
impl std::ops::Deref for SeaOrmConnectionWrapper {
    type Target = SeaOrmConnection;

    /// Borrows the wrapped SeaORM connection.
    fn deref(&self) -> &SeaOrmConnection {
        let Self { conn } = self;
        conn
    }
}
/// SeaORM-backed implementation of the crate's `DatabasePool` trait.
pub struct SeaOrmPool {
/// Reference-counted handle to the SeaORM connection.
conn: Arc<SeaOrmConnection>,
/// Result of passing the connection URL through `redact_database_url`;
/// this is the value returned by `connection_url`.
redacted_url: String,
/// Outcome of the most recent `ping`; starts as `true` and is read by
/// `is_healthy`.
health_status: Arc<std::sync::atomic::AtomicBool>,
}
impl SeaOrmPool {
    /// Wraps an already established SeaORM connection.
    ///
    /// Only the redacted form of `url` (as produced by
    /// `redact_database_url`) is retained. The health flag starts as `true`.
    pub fn new(conn: SeaOrmConnection, url: String) -> Self {
        let conn = Arc::new(conn);
        let redacted_url = redact_database_url(&url);
        let health_status = Arc::new(std::sync::atomic::AtomicBool::new(true));
        Self {
            conn,
            redacted_url,
            health_status,
        }
    }

    /// Validates `config`, connects to the database and builds a pool.
    ///
    /// # Errors
    ///
    /// Returns a `bad_request` error when one of these rules is violated
    /// (checked in this order, first violation wins):
    ///
    /// * `max_connections` must be greater than 0,
    /// * `max_connections` must not exceed 1000,
    /// * `min_connections` must not be greater than `max_connections`,
    /// * `connect_timeout` (seconds) must be greater than 0,
    /// * `connect_timeout` (seconds) must not exceed 300.
    ///
    /// Returns an `internal` error when the connection attempt itself fails.
    pub async fn from_config(config: &crate::database::DatabaseConfig) -> Result<Self> {
        // Ordered table of (rule violated?, message); the order matters
        // because only the first violated rule is reported.
        let rules = [
            (
                config.max_connections == 0,
                "max_connections must be greater than 0",
            ),
            (
                config.max_connections > 1000,
                "max_connections cannot exceed 1000",
            ),
            (
                config.min_connections > config.max_connections,
                "min_connections cannot be greater than max_connections",
            ),
            (
                config.connect_timeout == 0,
                "connect_timeout must be greater than 0",
            ),
            (
                config.connect_timeout > 300,
                "connect_timeout cannot exceed 300 seconds",
            ),
        ];
        for &(violated, message) in &rules {
            if violated {
                return Err(TidewayError::bad_request(message));
            }
        }

        // Both timeouts are given in whole seconds.
        let connect_timeout = Duration::from_secs(config.connect_timeout);
        let idle_timeout = Duration::from_secs(config.idle_timeout);

        let mut options = ConnectOptions::new(&config.url);
        options.max_connections(config.max_connections);
        options.min_connections(config.min_connections);
        options.connect_timeout(connect_timeout);
        options.idle_timeout(idle_timeout);
        options.sqlx_logging(true);

        let attempt = Database::connect(options).await;
        let conn = attempt.map_err(|err| {
            TidewayError::internal(format!("Failed to connect to database: {}", err))
        })?;

        tracing::info!(
            "Database connected with {} max connections",
            config.max_connections
        );

        let pool = Self::new(conn, config.url.clone());
        Ok(pool)
    }

    /// Borrows the underlying SeaORM connection.
    pub fn inner(&self) -> &SeaOrmConnection {
        &*self.conn
    }

    /// Pings the database and records the outcome in the health flag.
    ///
    /// Returns `true` when the ping succeeded. A failure is logged at
    /// `warn` level and yields `false`.
    pub async fn ping(&self) -> bool {
        use std::sync::atomic::Ordering;

        let outcome = self.conn.ping().await;
        if let Err(err) = &outcome {
            tracing::warn!("Database ping failed: {}", err);
        }

        let alive = outcome.is_ok();
        self.health_status.store(alive, Ordering::Release);
        alive
    }
}
#[async_trait]
impl DatabasePool for SeaOrmPool {
async fn connection(&self) -> Result<Box<dyn DatabaseConnection>> {
Ok(Box::new(SeaOrmConnectionWrapper {
conn: (*self.conn).clone(),
}))
}
fn is_healthy(&self) -> bool {
self.health_status
.load(std::sync::atomic::Ordering::Acquire)
}
async fn close(self: Box<Self>) -> Result<()> {
drop(self);
Ok(())
}
fn connection_url(&self) -> Option<&str> {
Some(&self.redacted_url)
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
}