#[cfg(any(
feature = "sqlx-postgres",
feature = "sqlx-mysql",
feature = "sqlx-sqlite"
))]
use rustapi_core::health::{HealthCheck, HealthCheckBuilder, HealthStatus};
use rustapi_core::ApiError;
#[cfg(any(
feature = "sqlx-postgres",
feature = "sqlx-mysql",
feature = "sqlx-sqlite"
))]
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
/// Errors produced while configuring or building a SQLx connection pool.
#[derive(Debug, Error)]
pub enum PoolError {
    /// The pool settings were rejected by [`SqlxPoolConfig::validate`]; the payload is a
    /// human-readable reason.
    #[error("Pool configuration error: {0}")]
    Configuration(String),
    /// A database connection problem described by a free-form message.
    ///
    /// NOTE(review): nothing in the visible code constructs this variant; presumably it is
    /// produced elsewhere in the crate — confirm before relying on it.
    #[error("Database connection error: {0}")]
    Connection(String),
    /// An error reported by SQLx itself (for example while connecting a pool). The
    /// `#[from]` attribute lets `?` convert a `sqlx::Error` into this variant.
    #[error("SQLx error: {0}")]
    Sqlx(#[from] sqlx::Error),
}
/// Settings used to open a SQLx connection pool.
///
/// Usually produced by [`SqlxPoolBuilder`]. It can also be written with struct-update syntax
/// over [`Default`]. Call [`SqlxPoolConfig::validate`] before use; the `build_*` methods on the
/// builder do this for you.
#[derive(Debug, Clone)]
pub struct SqlxPoolConfig {
    /// Database connection URL handed to the driver. Validation rejects an empty value.
    pub url: String,
    /// Passed to SQLx's `max_connections` (default 10). Validation rejects 0.
    pub max_connections: u32,
    /// Passed to SQLx's `min_connections` (default 1). Validation rejects values above
    /// `max_connections`.
    pub min_connections: u32,
    /// Passed to SQLx as `acquire_timeout` (default 30 seconds).
    pub connect_timeout: Duration,
    /// Passed to SQLx as `idle_timeout` (default 600 seconds).
    pub idle_timeout: Duration,
    /// Passed to SQLx as `max_lifetime` (default 1800 seconds).
    pub max_lifetime: Duration,
}
impl Default for SqlxPoolConfig {
fn default() -> Self {
Self {
url: String::new(),
max_connections: 10,
min_connections: 1,
connect_timeout: Duration::from_secs(30),
idle_timeout: Duration::from_secs(600),
max_lifetime: Duration::from_secs(1800),
}
}
}
impl SqlxPoolConfig {
    /// Checks that these settings can produce a usable pool.
    ///
    /// # Errors
    ///
    /// Returns [`PoolError::Configuration`] when:
    /// - `url` is empty or contains only whitespace,
    /// - `max_connections` is 0,
    /// - `min_connections` exceeds `max_connections`,
    /// - `connect_timeout` is zero.
    pub fn validate(&self) -> Result<(), PoolError> {
        // A whitespace-only URL can never be parsed by a driver. Report it like an empty
        // one so the mistake surfaces here rather than at connect time.
        if self.url.trim().is_empty() {
            return Err(PoolError::Configuration(
                "Database URL cannot be empty".to_string(),
            ));
        }
        if self.max_connections == 0 {
            return Err(PoolError::Configuration(
                "max_connections must be greater than 0".to_string(),
            ));
        }
        if self.min_connections > self.max_connections {
            return Err(PoolError::Configuration(
                "min_connections cannot exceed max_connections".to_string(),
            ));
        }
        // The builder passes `connect_timeout` to SQLx as the pool's acquire timeout. A zero
        // budget leaves no time to wait for (or open) a connection, so any acquire that
        // cannot be satisfied immediately would time out.
        if self.connect_timeout.is_zero() {
            return Err(PoolError::Configuration(
                "connect_timeout must be greater than 0".to_string(),
            ));
        }
        Ok(())
    }
}
/// Fluent builder for [`SqlxPoolConfig`].
///
/// It can also open a pool for each enabled driver feature (`sqlx-postgres`, `sqlx-mysql`,
/// `sqlx-sqlite`).
#[derive(Debug, Clone)]
pub struct SqlxPoolBuilder {
    /// Settings accumulated by the setter methods; validated only when a pool is built.
    config: SqlxPoolConfig,
}
impl SqlxPoolBuilder {
    /// Starts a builder for `url`; every other setting comes from
    /// [`SqlxPoolConfig::default`].
    pub fn new(url: impl Into<String>) -> Self {
        Self {
            config: SqlxPoolConfig {
                url: url.into(),
                ..Default::default()
            },
        }
    }

    /// Sets [`SqlxPoolConfig::max_connections`].
    #[must_use]
    pub fn max_connections(mut self, n: u32) -> Self {
        self.config.max_connections = n;
        self
    }

    /// Sets [`SqlxPoolConfig::min_connections`].
    #[must_use]
    pub fn min_connections(mut self, n: u32) -> Self {
        self.config.min_connections = n;
        self
    }

    /// Sets [`SqlxPoolConfig::connect_timeout`], which is applied as SQLx's acquire timeout.
    #[must_use]
    pub fn connect_timeout(mut self, d: Duration) -> Self {
        self.config.connect_timeout = d;
        self
    }

    /// Sets [`SqlxPoolConfig::idle_timeout`].
    #[must_use]
    pub fn idle_timeout(mut self, d: Duration) -> Self {
        self.config.idle_timeout = d;
        self
    }

    /// Sets [`SqlxPoolConfig::max_lifetime`].
    #[must_use]
    pub fn max_lifetime(mut self, d: Duration) -> Self {
        self.config.max_lifetime = d;
        self
    }

    /// Returns the settings accumulated so far.
    ///
    /// They have not been validated yet; validation runs inside the `build_*` methods.
    pub fn config(&self) -> &SqlxPoolConfig {
        &self.config
    }

    /// Translates the accumulated settings into driver-agnostic SQLx pool options.
    ///
    /// Every `build_*` method goes through this helper, so the per-driver pools cannot
    /// drift apart when a setting is added or changed.
    #[cfg(any(
        feature = "sqlx-postgres",
        feature = "sqlx-mysql",
        feature = "sqlx-sqlite"
    ))]
    fn pool_options<DB: sqlx::Database>(&self) -> sqlx::pool::PoolOptions<DB> {
        sqlx::pool::PoolOptions::<DB>::new()
            .max_connections(self.config.max_connections)
            .min_connections(self.config.min_connections)
            // SQLx calls this bound `acquire_timeout`; this crate exposes it as
            // `connect_timeout`.
            .acquire_timeout(self.config.connect_timeout)
            .idle_timeout(Some(self.config.idle_timeout))
            .max_lifetime(Some(self.config.max_lifetime))
    }

    /// Validates the settings and connects a PostgreSQL pool.
    ///
    /// # Errors
    ///
    /// Returns [`PoolError::Configuration`] if [`SqlxPoolConfig::validate`] rejects the
    /// settings, or [`PoolError::Sqlx`] if SQLx fails to connect.
    #[cfg(feature = "sqlx-postgres")]
    pub async fn build_postgres(self) -> Result<sqlx::PgPool, PoolError> {
        self.config.validate()?;
        let pool = self
            .pool_options::<sqlx::Postgres>()
            .connect(&self.config.url)
            .await?;
        Ok(pool)
    }

    /// Validates the settings and connects a MySQL pool.
    ///
    /// # Errors
    ///
    /// Returns [`PoolError::Configuration`] if [`SqlxPoolConfig::validate`] rejects the
    /// settings, or [`PoolError::Sqlx`] if SQLx fails to connect.
    #[cfg(feature = "sqlx-mysql")]
    pub async fn build_mysql(self) -> Result<sqlx::MySqlPool, PoolError> {
        self.config.validate()?;
        let pool = self
            .pool_options::<sqlx::MySql>()
            .connect(&self.config.url)
            .await?;
        Ok(pool)
    }

    /// Validates the settings and connects a SQLite pool.
    ///
    /// # Errors
    ///
    /// Returns [`PoolError::Configuration`] if [`SqlxPoolConfig::validate`] rejects the
    /// settings, or [`PoolError::Sqlx`] if SQLx fails to connect.
    #[cfg(feature = "sqlx-sqlite")]
    pub async fn build_sqlite(self) -> Result<sqlx::SqlitePool, PoolError> {
        self.config.validate()?;
        let pool = self
            .pool_options::<sqlx::Sqlite>()
            .connect(&self.config.url)
            .await?;
        Ok(pool)
    }

    /// Builds a [`HealthCheck`] with one check named `"postgres"`.
    ///
    /// The check runs `SELECT 1` on the pool. It reports healthy on success, and unhealthy
    /// with the SQLx error text otherwise.
    ///
    /// NOTE(review): the meaning of the `false` flag passed to `HealthCheckBuilder::new` is
    /// defined in `rustapi_core` — confirm there.
    #[cfg(feature = "sqlx-postgres")]
    pub fn health_check_postgres(pool: Arc<sqlx::PgPool>) -> HealthCheck {
        HealthCheckBuilder::new(false)
            .add_check("postgres", move || {
                // Each invocation gets its own handle so the returned future owns it.
                let pool = pool.clone();
                async move {
                    match sqlx::query("SELECT 1").execute(pool.as_ref()).await {
                        Ok(_) => HealthStatus::healthy(),
                        Err(e) => HealthStatus::unhealthy(format!("Database check failed: {}", e)),
                    }
                }
            })
            .build()
    }

    /// Builds a [`HealthCheck`] with one check named `"mysql"`.
    ///
    /// The check runs `SELECT 1` on the pool. It reports healthy on success, and unhealthy
    /// with the SQLx error text otherwise.
    ///
    /// NOTE(review): the meaning of the `false` flag passed to `HealthCheckBuilder::new` is
    /// defined in `rustapi_core` — confirm there.
    #[cfg(feature = "sqlx-mysql")]
    pub fn health_check_mysql(pool: Arc<sqlx::MySqlPool>) -> HealthCheck {
        HealthCheckBuilder::new(false)
            .add_check("mysql", move || {
                // Each invocation gets its own handle so the returned future owns it.
                let pool = pool.clone();
                async move {
                    match sqlx::query("SELECT 1").execute(pool.as_ref()).await {
                        Ok(_) => HealthStatus::healthy(),
                        Err(e) => HealthStatus::unhealthy(format!("Database check failed: {}", e)),
                    }
                }
            })
            .build()
    }

    /// Builds a [`HealthCheck`] with one check named `"sqlite"`.
    ///
    /// The check runs `SELECT 1` on the pool. It reports healthy on success, and unhealthy
    /// with the SQLx error text otherwise.
    ///
    /// NOTE(review): the meaning of the `false` flag passed to `HealthCheckBuilder::new` is
    /// defined in `rustapi_core` — confirm there.
    #[cfg(feature = "sqlx-sqlite")]
    pub fn health_check_sqlite(pool: Arc<sqlx::SqlitePool>) -> HealthCheck {
        HealthCheckBuilder::new(false)
            .add_check("sqlite", move || {
                // Each invocation gets its own handle so the returned future owns it.
                let pool = pool.clone();
                async move {
                    match sqlx::query("SELECT 1").execute(pool.as_ref()).await {
                        Ok(_) => HealthStatus::healthy(),
                        Err(e) => HealthStatus::unhealthy(format!("Database check failed: {}", e)),
                    }
                }
            })
            .build()
    }
}
/// Extension trait that turns a SQLx error into an HTTP-facing [`ApiError`].
pub trait SqlxErrorExt {
    /// Consumes the error and maps it to an [`ApiError`]; see [`convert_sqlx_error`] for
    /// the mapping.
    fn into_api_error(self) -> ApiError;
}
impl SqlxErrorExt for sqlx::Error {
fn into_api_error(self) -> ApiError {
convert_sqlx_error(self)
}
}
pub fn convert_sqlx_error(err: sqlx::Error) -> ApiError {
match &err {
sqlx::Error::PoolTimedOut => ApiError::new(
http::StatusCode::SERVICE_UNAVAILABLE,
"service_unavailable",
"Database connection pool exhausted",
)
.with_internal(err.to_string()),
sqlx::Error::PoolClosed => ApiError::new(
http::StatusCode::SERVICE_UNAVAILABLE,
"service_unavailable",
"Database connection pool is closed",
)
.with_internal(err.to_string()),
sqlx::Error::RowNotFound => ApiError::not_found("Resource not found"),
sqlx::Error::Database(db_err) => {
if let Some(code) = db_err.code() {
let code_str = code.as_ref();
if code_str == "23505" || code_str == "1062" || code_str == "2067" {
return ApiError::conflict("Resource already exists")
.with_internal(db_err.to_string());
}
if code_str == "23503" || code_str == "1452" || code_str == "787" {
return ApiError::bad_request("Referenced resource does not exist")
.with_internal(db_err.to_string());
}
if code_str == "23514" {
return ApiError::bad_request("Data validation failed")
.with_internal(db_err.to_string());
}
}
ApiError::internal("Database error").with_internal(db_err.to_string())
}
sqlx::Error::Io(_) => ApiError::new(
http::StatusCode::SERVICE_UNAVAILABLE,
"service_unavailable",
"Database connection error",
)
.with_internal(err.to_string()),
sqlx::Error::Tls(_) => ApiError::new(
http::StatusCode::SERVICE_UNAVAILABLE,
"service_unavailable",
"Database TLS error",
)
.with_internal(err.to_string()),
sqlx::Error::Protocol(_) => {
ApiError::internal("Database protocol error").with_internal(err.to_string())
}
sqlx::Error::TypeNotFound { .. } => {
ApiError::internal("Database type error").with_internal(err.to_string())
}
sqlx::Error::ColumnNotFound(_) => {
ApiError::internal("Database column not found").with_internal(err.to_string())
}
sqlx::Error::ColumnIndexOutOfBounds { .. } => {
ApiError::internal("Database column index error").with_internal(err.to_string())
}
sqlx::Error::ColumnDecode { .. } => {
ApiError::internal("Database decode error").with_internal(err.to_string())
}
sqlx::Error::Configuration(_) => {
ApiError::internal("Database configuration error").with_internal(err.to_string())
}
sqlx::Error::Migrate(_) => {
ApiError::internal("Database migration error").with_internal(err.to_string())
}
_ => ApiError::internal("Database error").with_internal(err.to_string()),
}
}
// Unit tests pin specific error mappings and builder/validation behaviour. The proptest
// suites check the same contracts over generated inputs.
#[cfg(test)]
mod tests {
    use super::*;
    use http::StatusCode;
    use proptest::prelude::*;

    #[test]
    fn test_pool_timeout_returns_503() {
        let err = sqlx::Error::PoolTimedOut;
        let api_err = convert_sqlx_error(err);
        assert_eq!(api_err.status, StatusCode::SERVICE_UNAVAILABLE);
        assert_eq!(api_err.error_type, "service_unavailable");
    }

    #[test]
    fn test_pool_closed_returns_503() {
        let err = sqlx::Error::PoolClosed;
        let api_err = convert_sqlx_error(err);
        assert_eq!(api_err.status, StatusCode::SERVICE_UNAVAILABLE);
        assert_eq!(api_err.error_type, "service_unavailable");
    }

    #[test]
    fn test_row_not_found_returns_404() {
        let err = sqlx::Error::RowNotFound;
        let api_err = convert_sqlx_error(err);
        assert_eq!(api_err.status, StatusCode::NOT_FOUND);
        assert_eq!(api_err.error_type, "not_found");
    }

    // `SqlxPoolBuilder::new` must fill every non-URL field from `SqlxPoolConfig::default`.
    #[test]
    fn test_builder_default_values() {
        let builder = SqlxPoolBuilder::new("postgres://localhost/test");
        let config = builder.config();
        assert_eq!(config.url, "postgres://localhost/test");
        assert_eq!(config.max_connections, 10);
        assert_eq!(config.min_connections, 1);
        assert_eq!(config.connect_timeout, Duration::from_secs(30));
        assert_eq!(config.idle_timeout, Duration::from_secs(600));
        assert_eq!(config.max_lifetime, Duration::from_secs(1800));
    }

    #[test]
    fn test_builder_custom_values() {
        let builder = SqlxPoolBuilder::new("postgres://localhost/test")
            .max_connections(20)
            .min_connections(5)
            .connect_timeout(Duration::from_secs(10))
            .idle_timeout(Duration::from_secs(300))
            .max_lifetime(Duration::from_secs(900));
        let config = builder.config();
        assert_eq!(config.max_connections, 20);
        assert_eq!(config.min_connections, 5);
        assert_eq!(config.connect_timeout, Duration::from_secs(10));
        assert_eq!(config.idle_timeout, Duration::from_secs(300));
        assert_eq!(config.max_lifetime, Duration::from_secs(900));
    }

    // `Default` leaves the URL empty, so validation must reject it.
    #[test]
    fn test_config_validation_empty_url() {
        let config = SqlxPoolConfig::default();
        let result = config.validate();
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), PoolError::Configuration(_)));
    }

    // min_connections keeps its default of 1 here, so it also exceeds max; the test only
    // asserts that the config is rejected.
    #[test]
    fn test_config_validation_zero_max_connections() {
        let config = SqlxPoolConfig {
            url: "postgres://localhost/test".to_string(),
            max_connections: 0,
            ..Default::default()
        };
        let result = config.validate();
        assert!(result.is_err());
    }

    #[test]
    fn test_config_validation_min_exceeds_max() {
        let config = SqlxPoolConfig {
            url: "postgres://localhost/test".to_string(),
            max_connections: 5,
            min_connections: 10,
            ..Default::default()
        };
        let result = config.validate();
        assert!(result.is_err());
    }

    #[test]
    fn test_config_validation_valid() {
        let config = SqlxPoolConfig {
            url: "postgres://localhost/test".to_string(),
            max_connections: 10,
            min_connections: 2,
            ..Default::default()
        };
        let result = config.validate();
        assert!(result.is_ok());
    }

    // SQLx error variants the property tests can construct directly. Each is paired with the
    // status and `error_type` that `convert_sqlx_error` is expected to produce for it.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    enum SqlxErrorCategory {
        PoolTimeout,
        PoolClosed,
        RowNotFound,
        Protocol,
        ColumnNotFound,
    }

    impl SqlxErrorCategory {
        // Expected HTTP status for this category.
        fn expected_status(&self) -> StatusCode {
            match self {
                SqlxErrorCategory::PoolTimeout => StatusCode::SERVICE_UNAVAILABLE,
                SqlxErrorCategory::PoolClosed => StatusCode::SERVICE_UNAVAILABLE,
                SqlxErrorCategory::RowNotFound => StatusCode::NOT_FOUND,
                SqlxErrorCategory::Protocol => StatusCode::INTERNAL_SERVER_ERROR,
                SqlxErrorCategory::ColumnNotFound => StatusCode::INTERNAL_SERVER_ERROR,
            }
        }

        // Expected `ApiError::error_type` string for this category.
        fn expected_error_type(&self) -> &'static str {
            match self {
                SqlxErrorCategory::PoolTimeout => "service_unavailable",
                SqlxErrorCategory::PoolClosed => "service_unavailable",
                SqlxErrorCategory::RowNotFound => "not_found",
                SqlxErrorCategory::Protocol => "internal_error",
                SqlxErrorCategory::ColumnNotFound => "internal_error",
            }
        }

        // Builds a representative `sqlx::Error` for this category.
        fn create_error(&self) -> sqlx::Error {
            match self {
                SqlxErrorCategory::PoolTimeout => sqlx::Error::PoolTimedOut,
                SqlxErrorCategory::PoolClosed => sqlx::Error::PoolClosed,
                SqlxErrorCategory::RowNotFound => sqlx::Error::RowNotFound,
                SqlxErrorCategory::Protocol => {
                    sqlx::Error::Protocol("test protocol error".to_string())
                }
                SqlxErrorCategory::ColumnNotFound => {
                    sqlx::Error::ColumnNotFound("test_column".to_string())
                }
            }
        }
    }

    // Uniformly picks one of the categories above.
    fn sqlx_error_category_strategy() -> impl Strategy<Value = SqlxErrorCategory> {
        prop_oneof![
            Just(SqlxErrorCategory::PoolTimeout),
            Just(SqlxErrorCategory::PoolClosed),
            Just(SqlxErrorCategory::RowNotFound),
            Just(SqlxErrorCategory::Protocol),
            Just(SqlxErrorCategory::ColumnNotFound),
        ]
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]
        // Every constructible category maps to its expected status and error_type.
        #[test]
        fn prop_sqlx_error_conversion_produces_appropriate_status(
            category in sqlx_error_category_strategy()
        ) {
            let sqlx_err = category.create_error();
            let api_err = convert_sqlx_error(sqlx_err);
            prop_assert_eq!(
                api_err.status,
                category.expected_status(),
                "SQLx error category {:?} should produce status {:?}, got {:?}",
                category,
                category.expected_status(),
                api_err.status
            );
            prop_assert_eq!(
                api_err.error_type.as_str(),
                category.expected_error_type(),
                "SQLx error category {:?} should have error_type {:?}",
                category,
                category.expected_error_type()
            );
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]
        // Pool-level failures (timeout / closed) must surface as 503 service_unavailable.
        #[test]
        fn prop_connection_errors_return_503(
            category in prop_oneof![
                Just(SqlxErrorCategory::PoolTimeout),
                Just(SqlxErrorCategory::PoolClosed),
            ]
        ) {
            let sqlx_err = category.create_error();
            let api_err = convert_sqlx_error(sqlx_err);
            prop_assert_eq!(
                api_err.status,
                StatusCode::SERVICE_UNAVAILABLE,
                "Connection errors should return 503 Service Unavailable"
            );
            prop_assert_eq!(
                api_err.error_type.as_str(),
                "service_unavailable",
                "Connection errors should have error_type 'service_unavailable'"
            );
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]
        // Builder setters round-trip their values. min_conn is derived as a fraction in
        // [0, 1) of max_conn, so the resulting config is expected to validate.
        #[test]
        fn prop_pool_configuration_respects_limits(
            max_conn in 1u32..100,
            min_conn_factor in 0.0f64..1.0,
            connect_timeout_secs in 1u64..120,
            idle_timeout_secs in 60u64..3600,
            max_lifetime_secs in 300u64..7200,
        ) {
            let min_conn = ((max_conn as f64) * min_conn_factor).floor() as u32;
            let builder = SqlxPoolBuilder::new("postgres://localhost/test")
                .max_connections(max_conn)
                .min_connections(min_conn)
                .connect_timeout(Duration::from_secs(connect_timeout_secs))
                .idle_timeout(Duration::from_secs(idle_timeout_secs))
                .max_lifetime(Duration::from_secs(max_lifetime_secs));
            let config = builder.config();
            prop_assert_eq!(config.max_connections, max_conn);
            prop_assert_eq!(config.min_connections, min_conn);
            prop_assert_eq!(config.connect_timeout, Duration::from_secs(connect_timeout_secs));
            prop_assert_eq!(config.idle_timeout, Duration::from_secs(idle_timeout_secs));
            prop_assert_eq!(config.max_lifetime, Duration::from_secs(max_lifetime_secs));
            prop_assert!(config.validate().is_ok());
            prop_assert!(config.min_connections <= config.max_connections);
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]
        // Any min_connections strictly greater than max_connections must be rejected.
        #[test]
        fn prop_invalid_config_is_rejected(
            max_conn in 1u32..50,
            min_conn_excess in 1u32..50,
        ) {
            let config = SqlxPoolConfig {
                url: "postgres://localhost/test".to_string(),
                max_connections: max_conn,
                min_connections: max_conn + min_conn_excess,
                ..Default::default()
            };
            prop_assert!(config.validate().is_err());
        }
    }
}