// nsq_async_rs/connection_pool.rs

use crate::connection::Connection;
use crate::error::Result;
use crate::protocol::IdentifyConfig;
use dashmap::DashMap;
use log::{debug, info, warn};
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use std::time::{Duration, Instant};

/// Tuning knobs for [`ConnectionPool`].
#[derive(Debug, Clone)]
pub struct ConnectionPoolConfig {
    /// Timeout for establishing a connection.
    ///
    /// NOTE(review): `ConnectionPool` in this file never reads this field;
    /// connections are opened with fixed read/write timeouts. Confirm intent.
    pub connection_timeout: Duration,
    /// Connections unused for longer than this are removed by the idle cleanup.
    pub max_idle_time: Duration,
    /// Interval between background health checks.
    pub health_check_interval: Duration,
    /// Maximum number of pooled connections per fingerprint (address + settings).
    pub max_connections_per_host: usize,
}

22impl Default for ConnectionPoolConfig {
23    fn default() -> Self {
24        Self {
25            connection_timeout: Duration::from_secs(5),
26            max_idle_time: Duration::from_secs(60),
27            health_check_interval: Duration::from_secs(30),
28            max_connections_per_host: 10,
29        }
30    }
31}
32
/// Outcome of the most recent health check on a pooled connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HealthStatus {
    /// No health check has been recorded yet.
    Unknown,
    /// The last health check succeeded.
    Healthy,
    /// The last health check failed.
    Unhealthy,
}

44/// 池化连接
45#[derive(Debug)]
46pub struct PooledConnection {
47    /// 连接实例
48    pub connection: Arc<Connection>,
49    /// 上次使用时间
50    pub last_used: Instant,
51    /// 上次检查时间
52    pub last_checked: Instant,
53    /// 健康状态
54    pub health_status: HealthStatus,
55    /// 连接指纹,用于标识连接
56    pub fingerprint: String,
57}
58
59impl PooledConnection {
60    /// 创建新的池化连接
61    pub fn new(connection: Arc<Connection>, fingerprint: String) -> Self {
62        let now = Instant::now();
63        Self {
64            connection,
65            last_used: now,
66            last_checked: now,
67            health_status: HealthStatus::Unknown,
68            fingerprint,
69        }
70    }
71
72    /// 检查连接健康状态
73    pub async fn check_health(&mut self) -> bool {
74        self.last_checked = Instant::now();
75
76        match self.connection.handle_heartbeat().await {
77            Ok(_) => {
78                self.health_status = HealthStatus::Healthy;
79                true
80            }
81            Err(e) => {
82                warn!("连接健康检查失败: {}, 地址: {}", e, self.connection.addr());
83                self.health_status = HealthStatus::Unhealthy;
84                false
85            }
86        }
87    }
88
89    /// 更新最后使用时间
90    pub fn update_last_used(&mut self) {
91        self.last_used = Instant::now();
92    }
93
94    /// 检查连接是否空闲超时
95    pub fn is_idle(&self, max_idle_time: Duration) -> bool {
96        self.last_used.elapsed() > max_idle_time
97    }
98}
99
100/// 连接池管理器
101#[derive(Debug)]
102pub struct ConnectionPool {
103    /// 连接池配置
104    config: ConnectionPoolConfig,
105    /// 连接存储,使用DashMap提供高效的并发访问
106    connections: DashMap<String, Vec<PooledConnection>>,
107}
108
109impl ConnectionPool {
110    /// 创建新的连接池
111    pub fn new(config: ConnectionPoolConfig) -> Self {
112        Self {
113            config,
114            connections: DashMap::new(),
115        }
116    }
117
118    /// 生成连接指纹
119    fn generate_fingerprint(
120        addr: &str,
121        identify_config: &Option<IdentifyConfig>,
122        auth_secret: &Option<String>,
123    ) -> String {
124        // 简单实现,实际应用中可能需要更复杂的指纹生成算法
125        format!(
126            "{}:{}:{}",
127            addr,
128            identify_config
129                .as_ref()
130                .map(|c| format!("{:?}", c))
131                .unwrap_or_default(),
132            auth_secret.as_ref().unwrap_or(&String::new())
133        )
134    }
135
136    /// 获取连接
137    pub async fn get_connection(
138        &self,
139        addr: &str,
140        identify_config: Option<IdentifyConfig>,
141        auth_secret: Option<String>,
142    ) -> Result<Arc<Connection>> {
143        let fingerprint = Self::generate_fingerprint(addr, &identify_config, &auth_secret);
144
145        // 尝试从连接池获取健康的连接
146        if let Some(mut connections) = self.connections.get_mut(&fingerprint) {
147            // 查找健康的连接
148            for conn in connections.value_mut().iter_mut() {
149                if conn.health_status != HealthStatus::Unhealthy {
150                    conn.update_last_used();
151
152                    return Ok(Arc::clone(&conn.connection));
153                }
154            }
155
156            // 尝试恢复一个不健康的连接
157            for conn in connections.value_mut().iter_mut() {
158                if conn.check_health().await {
159                    conn.update_last_used();
160
161                    return Ok(Arc::clone(&conn.connection));
162                }
163            }
164
165            // 如果连接数未达到上限,创建新连接
166            if connections.value().len() < self.config.max_connections_per_host {
167                return self
168                    .create_and_store_connection(addr, identify_config, auth_secret, &fingerprint)
169                    .await;
170            }
171
172            // 连接池已满,找到最久未使用的连接并替换
173            if let Some(oldest_index) = connections
174                .value()
175                .iter()
176                .enumerate()
177                .min_by_key(|(_, conn)| conn.last_used)
178                .map(|(i, _)| i)
179            {
180                connections.value_mut().remove(oldest_index);
181                return self
182                    .create_and_store_connection(addr, identify_config, auth_secret, &fingerprint)
183                    .await;
184            }
185        }
186
187        // 如果连接池中没有该地址的连接,创建新连接
188        self.create_and_store_connection(addr, identify_config, auth_secret, &fingerprint)
189            .await
190    }
191
192    /// 创建并存储连接
193    async fn create_and_store_connection(
194        &self,
195        addr: &str,
196        identify_config: Option<IdentifyConfig>,
197        auth_secret: Option<String>,
198        fingerprint: &str,
199    ) -> Result<Arc<Connection>> {
200        // 创建新连接
201        let connection = Connection::new(
202            addr,
203            identify_config.clone(),
204            auth_secret.clone(),
205            Duration::from_secs(60), // 默认读超时
206            Duration::from_secs(5),  // 默认写超时
207        )
208        .await?;
209
210        let connection = Arc::new(connection);
211        let pooled_connection =
212            PooledConnection::new(Arc::clone(&connection), fingerprint.to_string());
213
214        // 添加到连接池
215        self.connections
216            .entry(fingerprint.to_string())
217            .or_default()
218            .push(pooled_connection);
219
220        Ok(connection)
221    }
222
223    /// 启动连接池清理任务
224    pub fn start_cleanup_task(pool: Arc<ConnectionPool>) {
225        tokio::spawn(async move {
226            let mut interval = tokio::time::interval(Duration::from_secs(30));
227            loop {
228                interval.tick().await;
229                pool.cleanup_idle_connections().await;
230            }
231        });
232    }
233
234    /// 启动健康检查任务
235    pub fn start_health_check_task(pool: Arc<ConnectionPool>) {
236        tokio::spawn(async move {
237            let mut interval = tokio::time::interval(pool.config.health_check_interval);
238            loop {
239                interval.tick().await;
240                pool.check_connections_health().await;
241            }
242        });
243    }
244
245    /// 清理空闲连接
246    pub async fn cleanup_idle_connections(&self) {
247        let max_idle_time = self.config.max_idle_time;
248
249        for mut entry in self.connections.iter_mut() {
250            let before_count = entry.value().len();
251            entry
252                .value_mut()
253                .retain(|conn| !conn.is_idle(max_idle_time));
254            let after_count = entry.value().len();
255
256            if before_count > after_count {}
257        }
258    }
259
260    /// 检查连接健康状态
261    pub async fn check_connections_health(&self) {
262        for mut entry in self.connections.iter_mut() {
263            for conn in entry.value_mut().iter_mut() {
264                // 只检查超过健康检查间隔的连接
265                if conn.last_checked.elapsed() > self.config.health_check_interval {
266                    let _ = conn.check_health().await;
267                }
268            }
269        }
270    }
271
272    /// 获取连接池统计信息
273    pub fn get_stats(&self) -> ConnectionPoolStats {
274        let mut stats = ConnectionPoolStats::default();
275
276        for entry in self.connections.iter() {
277            stats.total_connections += entry.value().len();
278
279            for conn in entry.value() {
280                match conn.health_status {
281                    HealthStatus::Healthy => stats.healthy_connections += 1,
282                    HealthStatus::Unhealthy => stats.unhealthy_connections += 1,
283                    HealthStatus::Unknown => stats.unknown_status_connections += 1,
284                }
285
286                if conn.is_idle(self.config.max_idle_time) {
287                    stats.idle_connections += 1;
288                }
289            }
290        }
291
292        stats.host_count = self.connections.len();
293        stats
294    }
295}
296
/// Snapshot of pool counters, as returned by `ConnectionPool::get_stats`.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct ConnectionPoolStats {
    /// Total number of pooled connections.
    pub total_connections: usize,
    /// Connections whose last health check succeeded.
    pub healthy_connections: usize,
    /// Connections whose last health check failed.
    pub unhealthy_connections: usize,
    /// Connections with no recorded health check yet.
    pub unknown_status_connections: usize,
    /// Connections unused for longer than the configured maximum idle time.
    pub idle_connections: usize,
    /// Number of distinct fingerprint entries in the pool.
    pub host_count: usize,
}

314/// 创建一个新的连接池实例
315pub fn create_connection_pool(config: ConnectionPoolConfig) -> Arc<ConnectionPool> {
316    let pool = Arc::new(ConnectionPool::new(config));
317
318    // 启动清理和健康检查任务
319    ConnectionPool::start_cleanup_task(Arc::clone(&pool));
320    ConnectionPool::start_health_check_task(Arc::clone(&pool));
321
322    info!("NSQ连接池已初始化");
323    pool
324}
325
326/// 创建默认配置的连接池实例
327pub fn create_default_connection_pool() -> Arc<ConnectionPool> {
328    create_connection_pool(ConnectionPoolConfig::default())
329}