1use sea_orm::{ConnectionTrait, DatabaseConnection, DatabaseBackend};
6use serde::{Deserialize, Serialize};
7use thiserror::Error;
8use tracing::debug;
9
10#[derive(Error, Debug)]
12pub enum PoolError {
13 #[error("连接池信息获取失败: {0}")]
14 InfoFailed(String),
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct PoolStats {
20 pub max_connections: u32,
22 pub min_connections: u32,
24 pub active_connections: Option<u32>,
26 pub idle_connections: Option<u32>,
28 pub waiting_requests: Option<u32>,
30}
31
32pub struct ConnectionPoolService;
34
35impl ConnectionPoolService {
36 pub async fn get_pool_stats(
47 db: &DatabaseConnection,
48 max_connections: Option<u32>,
49 min_connections: Option<u32>,
50 ) -> Result<PoolStats, PoolError> {
51 let backend = db.get_database_backend();
52
53 let (active, idle, waiting) = match backend {
55 DatabaseBackend::Postgres => {
56 match Self::get_postgres_pool_stats(db).await {
58 Ok(stats) => stats,
59 Err(e) => {
60 debug!("无法获取 PostgreSQL 连接池统计信息: {}", e);
61 (None, None, None)
62 }
63 }
64 }
65 _ => {
66 (None, None, None)
68 }
69 };
70
71 Ok(PoolStats {
72 max_connections: max_connections.unwrap_or(100),
73 min_connections: min_connections.unwrap_or(5),
74 active_connections: active,
75 idle_connections: idle,
76 waiting_requests: waiting,
77 })
78 }
79
80 async fn get_postgres_pool_stats(
87 _db: &DatabaseConnection,
88 ) -> Result<(Option<u32>, Option<u32>, Option<u32>), PoolError> {
89 debug!("PostgreSQL 连接池统计信息需要通过 sqlx 连接池获取");
98 Ok((None, None, None))
99 }
100
101 pub async fn health_check(db: &DatabaseConnection) -> Result<bool, PoolError> {
103 db.execute_unprepared("SELECT 1")
105 .await
106 .map_err(|e| PoolError::InfoFailed(format!("连接池健康检查失败: {}", e)))?;
107
108 Ok(true)
109 }
110}
111
#[cfg(test)]
mod tests {
    use super::*;

    /// A fully populated `PoolStats` serializes to JSON containing the
    /// `max_connections` key.
    #[test]
    fn test_pool_stats_serialization() {
        let serialized = serde_json::to_string(&PoolStats {
            max_connections: 100,
            min_connections: 5,
            active_connections: Some(10),
            idle_connections: Some(5),
            waiting_requests: Some(0),
        })
        .unwrap();

        assert!(serialized.contains("max_connections"));
    }
}