use sea_orm::{ConnectOptions, ConnectionTrait, Database, DatabaseConnection, DatabaseBackend};
use serde::{Deserialize, Serialize};
use std::time::Duration;
use thiserror::Error;
use tracing::{error, info, warn};
/// Errors produced while configuring or opening a database connection.
#[derive(Error, Debug)]
pub enum DatabaseError {
    /// Opening the connection pool failed, or the `SELECT 1` connectivity check failed.
    #[error("数据库连接失败: {0}")]
    ConnectionFailed(String),
    /// A session-level `SET` statement (`search_path` or `extra_float_digits`) failed.
    #[error("设置 Schema 失败: {0}")]
    SchemaSetFailed(String),
    /// The connection configuration or a schema name failed validation.
    #[error("无效的配置: {0}")]
    InvalidConfig(String),
}
/// Connection settings for one database, typically deserialized from a configuration file.
///
/// `Debug` is implemented by hand (instead of derived) so that `{:?}` never prints the
/// password, e.g. when a config is logged while diagnosing a connection problem.
#[derive(Clone, Serialize, Deserialize)]
pub struct DatabaseConfig {
    /// Backend name, also used as the URL scheme (e.g. "postgres", "mysql", "sqlite").
    pub database_type: String,
    /// Server host name or address; a bare IPv6 address is bracketed when the URL is built.
    pub host: String,
    /// Server TCP port.
    pub port: u16,
    /// Login user name.
    pub username: String,
    /// Login password. Redacted from `Debug` output.
    pub password: String,
    /// Database name (for SQLite, the database file path).
    pub database_name: String,
    /// Postgres schema placed on `search_path`; defaults to "public".
    #[serde(default = "default_schema")]
    pub schema: String,
    /// Level name for sqlx statement logging ("off".."trace"); defaults to "info".
    #[serde(default = "default_logging_level")]
    pub logging_level: String,
    /// When true, `extra_float_digits` is not set and a failed `SET search_path` is only
    /// logged as a warning. Defaults to false.
    #[serde(default = "default_use_pgbouncer")]
    pub use_pgbouncer: bool,
}

impl std::fmt::Debug for DatabaseConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DatabaseConfig")
            .field("database_type", &self.database_type)
            .field("host", &self.host)
            .field("port", &self.port)
            .field("username", &self.username)
            // Never expose the secret in diagnostics.
            .field("password", &"<redacted>")
            .field("database_name", &self.database_name)
            .field("schema", &self.schema)
            .field("logging_level", &self.logging_level)
            .field("use_pgbouncer", &self.use_pgbouncer)
            .finish()
    }
}
/// Serde default for `DatabaseConfig::schema`: the standard Postgres schema.
fn default_schema() -> String {
    String::from("public")
}
/// Serde default for `DatabaseConfig::logging_level`.
fn default_logging_level() -> String {
    "info".to_owned()
}
/// Serde default for `DatabaseConfig::use_pgbouncer`: PgBouncer handling is opt-in (`false`).
fn default_use_pgbouncer() -> bool {
    bool::default()
}
/// Connection-pool tuning passed to `sea_orm::ConnectOptions` when a connection is opened.
///
/// All timeout/lifetime fields are whole seconds (they are converted with
/// `Duration::from_secs` when the pool is configured).
#[derive(Debug, Clone)]
pub struct ConnectionOptions {
    /// Upper bound on the number of pooled connections.
    pub max_connections: u32,
    /// Minimum number of pooled connections.
    pub min_connections: u32,
    /// Timeout for establishing a new connection, in seconds.
    pub connect_timeout: u64,
    /// Timeout for acquiring a connection from the pool, in seconds.
    pub acquire_timeout: u64,
    /// Idle time after which a pooled connection may be closed, in seconds.
    pub idle_timeout: u64,
    /// Maximum lifetime of a pooled connection, in seconds.
    pub max_lifetime: u64,
    /// Whether sqlx statement logging is enabled.
    pub sqlx_logging: bool,
}

impl Default for ConnectionOptions {
    /// Up to 300 connections (5 minimum), 8 s connect/acquire timeouts, 10 min idle timeout,
    /// 30 min maximum lifetime, and sqlx statement logging enabled.
    fn default() -> Self {
        const TEN_MINUTES_SECS: u64 = 10 * 60;
        const THIRTY_MINUTES_SECS: u64 = 30 * 60;
        Self {
            sqlx_logging: true,
            max_connections: 300,
            min_connections: 5,
            connect_timeout: 8,
            acquire_timeout: 8,
            idle_timeout: TEN_MINUTES_SECS,
            max_lifetime: THIRTY_MINUTES_SECS,
        }
    }
}
pub struct DatabaseService;
impl DatabaseService {
pub fn build_database_url(config: &DatabaseConfig) -> String {
let encoded_username = urlencoding::encode(&config.username);
let encoded_password = urlencoding::encode(&config.password);
let encoded_host = if config.host.contains(':') && !config.host.starts_with('[') {
format!("[{}]", config.host)
} else {
config.host.clone()
};
let encoded_database = urlencoding::encode(&config.database_name);
format!(
"{}://{}:{}@{}:{}/{}",
config.database_type,
encoded_username,
encoded_password,
encoded_host,
config.port,
encoded_database
)
}
pub async fn create_connection(
config: &DatabaseConfig,
schema: Option<&str>,
options: Option<ConnectionOptions>,
) -> Result<DatabaseConnection, DatabaseError> {
let database_url = Self::build_database_url(config);
let target_schema = schema.unwrap_or(&config.schema);
info!("正在连接数据库...");
info!(
"数据库地址: {}@{}:{}",
config.database_name, config.host, config.port
);
info!("使用 Schema: {}", target_schema);
let opts = options.unwrap_or_default();
let mut connect_options = ConnectOptions::new(&database_url);
let log_level = Self::parse_log_level(&config.logging_level);
connect_options
.max_connections(opts.max_connections)
.min_connections(opts.min_connections)
.connect_timeout(Duration::from_secs(opts.connect_timeout))
.acquire_timeout(Duration::from_secs(opts.acquire_timeout))
.idle_timeout(Duration::from_secs(opts.idle_timeout))
.max_lifetime(Duration::from_secs(opts.max_lifetime))
.sqlx_logging(opts.sqlx_logging)
.sqlx_logging_level(log_level);
let db = Database::connect(connect_options).await.map_err(|e| {
error!("数据库连接失败: {}", e);
DatabaseError::ConnectionFailed(format!("数据库连接失败: {}", e))
})?;
let backend = db.get_database_backend();
match backend {
DatabaseBackend::Postgres => {
if !config.use_pgbouncer {
Self::set_extra_float_digits(&db).await?;
} else {
info!("使用 PgBouncer,跳过设置 extra_float_digits");
}
match Self::set_schema(&db, target_schema).await {
Ok(_) => {}
Err(e) => {
if config.use_pgbouncer {
warn!(
"设置 schema 失败(可能是 PgBouncer transaction 模式导致): {}",
e
);
warn!(
"建议:1) 改用 session 模式,或 2) 在数据库层面设置: ALTER DATABASE {} SET search_path TO {}",
config.database_name, target_schema
);
} else {
return Err(e);
}
}
}
}
DatabaseBackend::MySql => {
info!("MySQL 数据库连接已建立");
}
DatabaseBackend::Sqlite => {
info!("SQLite 数据库连接已建立");
}
_ => {
info!("数据库连接已建立");
}
}
info!("✓ 数据库连接成功");
Ok(db)
}
pub async fn init(
config: &DatabaseConfig,
options: Option<ConnectionOptions>,
) -> Result<DatabaseConnection, DatabaseError> {
Self::create_connection(config, None, options).await
}
pub async fn set_extra_float_digits(db: &DatabaseConnection) -> Result<(), DatabaseError> {
db.execute_unprepared("SET extra_float_digits = 0")
.await
.map_err(|e| {
error!("设置 extra_float_digits 失败: {}", e);
DatabaseError::SchemaSetFailed(format!("设置 extra_float_digits 失败: {}", e))
})?;
info!("✓ 已设置 extra_float_digits = 0");
Ok(())
}
fn validate_schema_name(schema: &str) -> Result<(), DatabaseError> {
if schema.is_empty() {
return Err(DatabaseError::InvalidConfig(
"Schema 名称不能为空".to_string(),
));
}
if !schema
.chars()
.all(|c| c.is_alphanumeric() || c == '_' || c == '-')
{
return Err(DatabaseError::InvalidConfig(format!(
"Schema 名称包含非法字符: {},只允许字母、数字、下划线和连字符",
schema
)));
}
if schema.len() > 63 {
return Err(DatabaseError::InvalidConfig(format!(
"Schema 名称过长: {},最大长度为 63 字符",
schema
)));
}
Ok(())
}
pub async fn set_schema(db: &DatabaseConnection, schema: &str) -> Result<(), DatabaseError> {
Self::validate_schema_name(schema)?;
let sql = format!("SET search_path TO {}", schema);
db.execute_unprepared(&sql).await.map_err(|e| {
error!("设置 schema 失败: {}", e);
DatabaseError::SchemaSetFailed(format!("设置 schema 失败: {}", e))
})?;
info!("✓ 已设置 search_path 到 schema: {}", schema);
Ok(())
}
pub async fn test_connection(db: &DatabaseConnection) -> Result<(), DatabaseError> {
db.execute_unprepared("SELECT 1").await.map_err(|e| {
error!("数据库连接测试失败: {}", e);
DatabaseError::ConnectionFailed(format!("数据库连接测试失败: {}", e))
})?;
info!("✓ 数据库连接测试成功");
Ok(())
}
fn parse_log_level(level: &str) -> log::LevelFilter {
match level.to_lowercase().as_str() {
"off" => log::LevelFilter::Off,
"error" => log::LevelFilter::Error,
"warn" => log::LevelFilter::Warn,
"info" => log::LevelFilter::Info,
"debug" => log::LevelFilter::Debug,
"trace" => log::LevelFilter::Trace,
_ => {
warn!("未知的日志级别 '{}', 使用默认级别 Info", level);
log::LevelFilter::Info
}
}
}
pub fn validate_config(config: &DatabaseConfig) -> Result<(), DatabaseError> {
if config.host.is_empty() {
return Err(DatabaseError::InvalidConfig(
"数据库主机不能为空".to_string(),
));
}
if config.port == 0 {
return Err(DatabaseError::InvalidConfig(
format!("数据库端口无效: {},有效范围: 1-65535", config.port)
));
}
if config.database_name.is_empty() {
return Err(DatabaseError::InvalidConfig(
"数据库名称不能为空".to_string(),
));
}
if config.database_type != "sqlite" {
if config.username.is_empty() {
return Err(DatabaseError::InvalidConfig(
"数据库用户名不能为空".to_string(),
));
}
}
if !["postgres", "mysql", "sqlite"].contains(&config.database_type.as_str()) {
return Err(DatabaseError::InvalidConfig(
format!(
"不支持的数据库类型: {},支持的类型: postgres, mysql, sqlite",
config.database_type
)
));
}
let valid_log_levels = ["off", "error", "warn", "info", "debug", "trace"];
if !valid_log_levels.contains(&config.logging_level.to_lowercase().as_str()) {
warn!(
"无效的日志级别: {},将使用默认级别 info",
config.logging_level
);
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A valid local Postgres configuration shared by the tests below.
    fn sample_config() -> DatabaseConfig {
        DatabaseConfig {
            database_type: "postgres".to_string(),
            host: "localhost".to_string(),
            port: 5432,
            username: "user".to_string(),
            password: "pass".to_string(),
            database_name: "testdb".to_string(),
            schema: "public".to_string(),
            logging_level: "info".to_string(),
            use_pgbouncer: false,
        }
    }

    #[test]
    fn test_build_database_url() {
        let config = sample_config();
        assert_eq!(
            DatabaseService::build_database_url(&config),
            "postgres://user:pass@localhost:5432/testdb"
        );
    }

    #[test]
    fn test_validate_config() {
        assert!(DatabaseService::validate_config(&sample_config()).is_ok());

        // Identical config except for an empty host must be rejected.
        let missing_host = DatabaseConfig {
            host: String::new(),
            ..sample_config()
        };
        assert!(DatabaseService::validate_config(&missing_host).is_err());
    }
}