Skip to main content

br_pgsql/
pools.rs

1use crate::config::Config;
2use crate::connect::Connect;
3use crate::error::PgsqlError;
4use log::{error, info, warn};
5use std::collections::VecDeque;
6use std::sync::{Arc, Condvar, Mutex, MutexGuard, PoisonError, Weak};
7use std::thread;
8use std::time::Duration;
9
/// Maximum time (5 minutes) a connection may stay idle; connections idle for
/// longer are dropped instead of being returned to the pool.
const MAX_IDLE_SECS: u64 = 5 * 60;
/// Maximum lifetime (30 minutes) of a connection; older connections are dropped
/// so that long-lived sessions do not accumulate server-side state.
const MAX_CONN_LIFETIME_SECS: u64 = 30 * 60;
14struct PoolInner {
15    idle: VecDeque<Connect>,
16    total: usize,
17    max: usize,
18    /// 当前被事务持有的连接数
19    txn_total: usize,
20    /// 事务连接上限(防止事务饿死普通查询)
21    txn_max: usize,
22}
23
24/// 预占位守卫:Create/GotConn-rebuild 路径中,如果 Connect::new() panic,
25/// drop 时自动归还 total 计数并唤醒等待者,防止 total 永久膨胀。
26struct SlotGuard<'a> {
27    mutex: &'a Mutex<PoolInner>,
28    condvar: &'a Condvar,
29    active: bool,
30    for_transaction: bool,
31}
32
33impl<'a> SlotGuard<'a> {
34    fn new(mutex: &'a Mutex<PoolInner>, condvar: &'a Condvar, for_transaction: bool) -> Self {
35        Self {
36            mutex,
37            condvar,
38            active: true,
39            for_transaction,
40        }
41    }
42
43    /// 标记成功,不再需要回滚
44    fn disarm(&mut self) {
45        self.active = false;
46    }
47}
48
49impl Drop for SlotGuard<'_> {
50    fn drop(&mut self) {
51        if self.active {
52            let mut pool = lock_inner(self.mutex);
53            pool.total = pool.total.saturating_sub(1);
54            if self.for_transaction {
55                pool.txn_total = pool.txn_total.saturating_sub(1);
56            }
57            drop(pool);
58            self.condvar.notify_one();
59        }
60    }
61}
62#[derive(Clone)]
63pub struct Pools {
64    pub config: Config,
65    inner: Arc<(Mutex<PoolInner>, Condvar)>,
66}
67fn lock_inner(mutex: &Mutex<PoolInner>) -> MutexGuard<'_, PoolInner> {
68    mutex.lock().unwrap_or_else(PoisonError::into_inner)
69}
70pub struct ConnectionGuard {
71    pool: Pools,
72    conn: Option<Connect>,
73}
74impl ConnectionGuard {
75    pub fn new(pool: Pools) -> Result<Self, PgsqlError> {
76        let conn = pool.get_connect()?;
77        Ok(Self {
78            pool,
79            conn: Some(conn),
80        })
81    }
82    pub fn conn(&mut self) -> &mut Connect {
83        self.conn.as_mut().expect("connection already released")
84    }
85    /// 丢弃连接(不归还到池),用于连接已断开的场景
86    pub fn discard(&mut self) {
87        if let Some(_conn) = self.conn.take() {
88            let (ref mutex, ref condvar) = *self.pool.inner;
89            let mut pool = lock_inner(mutex);
90            pool.total = pool.total.saturating_sub(1);
91            drop(pool);
92            condvar.notify_one();
93        }
94    }
95}
96impl Drop for ConnectionGuard {
97    fn drop(&mut self) {
98        if let Some(conn) = self.conn.take() {
99            self.pool.release_conn(conn);
100        }
101    }
102}
103impl Pools {
104    pub fn get_guard(&self) -> Result<ConnectionGuard, PgsqlError> {
105        ConnectionGuard::new(self.clone())
106    }
107    pub fn new(config: Config, size: usize) -> Result<Self, PgsqlError> {
108        let init_size = 2.min(size);
109        let mut idle = VecDeque::with_capacity(size);
110        let mut created = 0;
111        for _ in 0..init_size {
112            match Connect::new(config.clone()) {
113                Ok(conn) => {
114                    idle.push_back(conn);
115                    created += 1;
116                }
117                Err(e) => warn!("初始化连接失败: {e}"),
118            }
119        }
120        let txn_max = (size / 3).max(1);
121        let inner = PoolInner {
122            idle,
123            total: created,
124            max: size,
125            txn_total: 0,
126            txn_max,
127        };
128
129        let arc = Arc::new((Mutex::new(inner), Condvar::new()));
130
131        // 后台空闲连接回收线程(每60秒清理一次,Weak 引用自动退出)
132        let weak = Arc::downgrade(&arc);
133        thread::spawn(move || {
134            Self::reaper_loop(weak);
135        });
136
137        Ok(Self { config, inner: arc })
138    }
139
140    /// 内部统一获取连接逻辑
141    fn acquire_connect(&self, for_transaction: bool) -> Result<Connect, PgsqlError> {
142        let mut attempts = 0;
143        let (ref mutex, ref condvar) = *self.inner;
144        let label = if for_transaction { "事务" } else { "" };
145        #[cfg(not(test))]
146        const BASE_SLEEP_MS: u64 = 200;
147        #[cfg(test)]
148        const BASE_SLEEP_MS: u64 = 1;
149        #[cfg(not(test))]
150        const MAX_SLEEP_MS: u64 = 2000;
151        #[cfg(test)]
152        const MAX_SLEEP_MS: u64 = 5;
153        #[cfg(not(test))]
154        const WAIT_TIMEOUT: Duration = Duration::from_secs(2);
155        #[cfg(test)]
156        const WAIT_TIMEOUT: Duration = Duration::from_millis(5);
157
158        let timeout_msg = if for_transaction {
159            "无法获取事务连接,重试超时"
160        } else {
161            "无法连接数据库,重试超时"
162        };
163
164        loop {
165            if attempts >= 5 {
166                return Err(PgsqlError::Pool(timeout_msg.into()));
167            }
168
169            let action = {
170                let mut pool = lock_inner(mutex);
171                // 事务连接受 txn_max 限制,防止饿死普通查询
172                if for_transaction && pool.txn_total >= pool.txn_max && pool.total >= pool.max {
173                    Action::Wait
174                } else if let Some(conn) = pool.idle.pop_front() {
175                    if for_transaction {
176                        pool.txn_total += 1;
177                    }
178                    Action::GotConn(Box::new(conn))
179                } else if pool.total < pool.max {
180                    pool.total += 1; // 预占位
181                    if for_transaction {
182                        pool.txn_total += 1;
183                    }
184                    Action::Create
185                } else {
186                    Action::Wait
187                }
188            };
189
190            match action {
191                Action::GotConn(mut conn) => {
192                    // 超过最大生命周期的连接直接丢弃
193                    if conn.age().as_secs() > MAX_CONN_LIFETIME_SECS {
194                        {
195                            let mut pool = lock_inner(mutex);
196                            pool.total = pool.total.saturating_sub(1);
197                            if for_transaction {
198                                pool.txn_total = pool.txn_total.saturating_sub(1);
199                            }
200                        }
201                        log::debug!("{}连接存活超过{}秒,已丢弃", label, MAX_CONN_LIFETIME_SECS);
202                        continue;
203                    }
204                    // 锁外做健康检查(is_valid 含懒 SELECT 1)
205                    if conn.is_valid() {
206                        conn.touch();
207                        return Ok(*conn);
208                    }
209                    // 连接失效,丢弃并重新循环(不消耗重试次数,参考 mysql crate 模式)
210                    {
211                        let mut pool = lock_inner(mutex);
212                        pool.total = pool.total.saturating_sub(1);
213                        if for_transaction {
214                            pool.txn_total = pool.txn_total.saturating_sub(1);
215                        }
216                    }
217                    warn!(
218                        "{}连接失效已丢弃,当前总连接数量: {}",
219                        label,
220                        self.total_connections()
221                    );
222                    // 不增加 attempts,直接重新循环获取连接
223                    continue;
224                }
225
226                Action::Create => {
227                    // SlotGuard 保护 Create 路径的 total 预占位
228                    let mut guard = SlotGuard::new(mutex, condvar, for_transaction);
229                    match Connect::new(self.config.clone()) {
230                        Ok(new_conn) => {
231                            guard.disarm();
232                            return Ok(new_conn);
233                        }
234                        Err(e) => {
235                            // guard drop 会自动 total -= 1 + notify
236                            drop(guard);
237                            let sleep_ms = BASE_SLEEP_MS
238                                .saturating_mul(1u64 << attempts.min(3))
239                                .min(MAX_SLEEP_MS);
240                            attempts += 1;
241                            error!("创建{}连接失败({}ms后重试): {}", label, sleep_ms, e);
242                            thread::sleep(Duration::from_millis(sleep_ms));
243                        }
244                    }
245                }
246                Action::Wait => {
247                    let pool = lock_inner(mutex);
248                    let (_pool, timeout) = condvar
249                        .wait_timeout(pool, WAIT_TIMEOUT)
250                        .unwrap_or_else(PoisonError::into_inner);
251                    drop(_pool);
252                    if timeout.timed_out() {
253                        attempts += 1;
254                    }
255                }
256            }
257        }
258    }
259    pub fn get_connect(&self) -> Result<Connect, PgsqlError> {
260        self.acquire_connect(false)
261    }
262    /// 事务专用连接,不归还到池
263    pub fn get_connect_for_transaction(&self) -> Result<Connect, PgsqlError> {
264        self.acquire_connect(true)
265    }
266    pub fn release_transaction_conn(&self) {
267        let (ref mutex, ref condvar) = *self.inner;
268        let mut pool = lock_inner(mutex);
269        pool.total = pool.total.saturating_sub(1);
270        pool.txn_total = pool.txn_total.saturating_sub(1);
271        drop(pool);
272        condvar.notify_one();
273    }
274    /// 归还事务连接到连接池(而非销毁),同时递减 txn_total
275    pub fn release_transaction_conn_with_conn(&self, conn: Connect) {
276        let (ref mutex, _) = *self.inner;
277        {
278            let mut pool = lock_inner(mutex);
279            pool.txn_total = pool.txn_total.saturating_sub(1);
280        }
281        self.release_conn(conn);
282    }
283    pub fn release_conn(&self, conn: Connect) {
284        let (ref mutex, ref condvar) = *self.inner;
285        if !conn.peer_valid() {
286            let mut pool = lock_inner(mutex);
287            pool.total = pool.total.saturating_sub(1);
288            drop(pool);
289            condvar.notify_one();
290            warn!("释放时检测到坏连接,已丢弃");
291            return;
292        }
293        if conn.age().as_secs() > MAX_CONN_LIFETIME_SECS {
294            let mut pool = lock_inner(mutex);
295            pool.total = pool.total.saturating_sub(1);
296            drop(pool);
297            condvar.notify_one();
298            log::debug!("释放时连接存活超过{}秒,已丢弃", MAX_CONN_LIFETIME_SECS);
299            return;
300        }
301        if conn.idle_elapsed().as_secs() > MAX_IDLE_SECS {
302            let mut pool = lock_inner(mutex);
303            pool.total = pool.total.saturating_sub(1);
304            drop(pool);
305            condvar.notify_one();
306            log::debug!("连接空闲超过{}秒,已丢弃", MAX_IDLE_SECS);
307            return;
308        }
309        let mut pool = lock_inner(mutex);
310        if pool.idle.len() < pool.max {
311            pool.idle.push_back(conn);
312        } else {
313            pool.total = pool.total.saturating_sub(1);
314            warn!("连接池已满,丢弃连接");
315        }
316        drop(pool);
317        condvar.notify_one();
318    }
319    pub fn idle_pool_size(&self) -> usize {
320        let (ref mutex, _) = *self.inner;
321        let pool = lock_inner(mutex);
322        pool.idle.len()
323    }
324    pub fn total_connections(&self) -> usize {
325        let (ref mutex, _) = *self.inner;
326        let pool = lock_inner(mutex);
327        pool.total
328    }
329    pub fn borrowed_connections(&self) -> usize {
330        let (ref mutex, _) = *self.inner;
331        let pool = lock_inner(mutex);
332        pool.total.saturating_sub(pool.idle.len())
333    }
334    /// 清空池中所有空闲连接(failover 场景:所有连接同时死亡时调用)
335    pub fn flush_idle(&self) {
336        let (ref mutex, _) = *self.inner;
337        let mut pool = lock_inner(mutex);
338        let flushed = pool.idle.len();
339        pool.total = pool.total.saturating_sub(flushed);
340        pool.idle.clear();
341        if flushed > 0 {
342            warn!("清空池中 {flushed} 个空闲连接(疑似批量失效)");
343        }
344    }
345    pub fn cleanup_idle_connections(&self) {
346        let (ref mutex, _) = *self.inner;
347        let mut pool = lock_inner(mutex);
348        let before = pool.idle.len();
349        pool.idle.retain(|conn| {
350            let peer_ok = conn.peer_valid();
351            let idle_ok = conn.idle_elapsed().as_secs() <= MAX_IDLE_SECS;
352            let lifetime_ok = conn.age().as_secs() <= MAX_CONN_LIFETIME_SECS;
353            if !peer_ok {
354                log::debug!("检测到无效连接,已移除");
355            } else if !idle_ok {
356                log::debug!("检测到空闲超时连接,已移除");
357            } else if !lifetime_ok {
358                log::debug!("检测到超过最大生命周期连接,已移除");
359            }
360            peer_ok && idle_ok && lifetime_ok
361        });
362        let removed = before - pool.idle.len();
363        pool.total = pool.total.saturating_sub(removed);
364        if removed > 0 {
365            log::debug!(
366                "空闲连接清理完成: 移除 {removed} 个,剩余 {} 个",
367                pool.idle.len()
368            );
369        }
370    }
371    /// 后台回收线程主循环:每60秒清理空闲/过期连接,Weak 引用失效时自动退出
372    fn reaper_loop(weak: Weak<(Mutex<PoolInner>, Condvar)>) {
373        #[cfg(not(test))]
374        const INTERVAL: Duration = Duration::from_secs(60);
375        #[cfg(test)]
376        const INTERVAL: Duration = Duration::from_millis(50);
377        loop {
378            thread::sleep(INTERVAL);
379            let arc = match weak.upgrade() {
380                Some(a) => a,
381                None => {
382                    info!("连接池已释放,回收线程退出");
383                    return;
384                }
385            };
386            let (ref mutex, _) = *arc;
387            let mut pool = lock_inner(mutex);
388            let before = pool.idle.len();
389            pool.idle.retain(|conn| {
390                conn.peer_valid()
391                    && conn.idle_elapsed().as_secs() <= MAX_IDLE_SECS
392                    && conn.age().as_secs() <= MAX_CONN_LIFETIME_SECS
393            });
394            let removed = before - pool.idle.len();
395            pool.total = pool.total.saturating_sub(removed);
396            if removed > 0 {
397                info!(
398                    "后台回收: 移除 {removed} 个空闲连接,剩余 {} 个",
399                    pool.idle.len()
400                );
401            }
402        }
403    }
404}
405
406/// acquire_connect 内部决策
407enum Action {
408    GotConn(Box<Connect>),
409    Create,
410    Wait,
411}
412
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Read as IoRead, Write as IoWrite};
    use std::net::TcpListener;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Frames one backend message: tag byte, big-endian length (payload + 4),
    /// then the payload.
    fn pg_msg(tag: u8, payload: &[u8]) -> Vec<u8> {
        let length_field = (payload.len() as u32 + 4).to_be_bytes();
        let mut frame = Vec::with_capacity(5 + payload.len());
        frame.push(tag);
        frame.extend_from_slice(&length_field);
        frame.extend_from_slice(payload);
        frame
    }

    /// Builds an `R` (authentication) message with the given type code and
    /// trailing bytes.
    fn pg_auth(auth_type: u32, extra: &[u8]) -> Vec<u8> {
        let body = [&auth_type.to_be_bytes()[..], extra].concat();
        pg_msg(b'R', &body)
    }

    /// Messages sent after a successful login: auth OK, one parameter status,
    /// backend key data, ready-for-query.
    fn post_auth_ok() -> Vec<u8> {
        let key_data = [1u32.to_be_bytes(), 2u32.to_be_bytes()].concat();
        [
            pg_auth(0, &[]),
            pg_msg(b'S', b"server_version\x0015.0\x00"),
            pg_msg(b'K', &key_data),
            pg_msg(b'Z', b"I"),
        ]
        .concat()
    }

    /// Canned reply to any query: a single row with a single column holding `1`.
    fn simple_query_response() -> Vec<u8> {
        // Row description for one column named "c".
        let mut row_desc = Vec::new();
        row_desc.extend(&1u16.to_be_bytes());
        row_desc.extend(b"c\x00");
        row_desc.extend(&0u32.to_be_bytes());
        row_desc.extend(&1u16.to_be_bytes());
        row_desc.extend(&23u32.to_be_bytes());
        row_desc.extend(&4i16.to_be_bytes());
        row_desc.extend(&(-1i32).to_be_bytes());
        row_desc.extend(&0u16.to_be_bytes());

        // Data row: one column, one byte long, containing the text "1".
        let mut data_row = Vec::new();
        data_row.extend(&1u16.to_be_bytes());
        data_row.extend(&1u32.to_be_bytes());
        data_row.push(b'1');

        let mut reply = Vec::new();
        reply.extend(pg_msg(b'1', &[]));
        reply.extend(pg_msg(b'2', &[]));
        reply.extend(pg_msg(b'T', &row_desc));
        reply.extend(pg_msg(b'D', &data_row));
        reply.extend(pg_msg(b'C', b"SELECT 1\x00"));
        reply.extend(pg_msg(b'Z', b"I"));
        reply
    }

    /// Serves one client: answers the first packet with an auth request, the
    /// second with the post-login messages, and every later packet with the
    /// canned query response, until the peer closes, a read fails, or `stop`
    /// is set.
    fn serve_client(mut stream: std::net::TcpStream, stop: Arc<AtomicBool>) {
        stream.set_read_timeout(Some(Duration::from_secs(5))).ok();
        let mut buf = [0u8; 4096];
        if stream.read(&mut buf).unwrap_or(0) == 0 {
            return;
        }
        let _ = stream.write_all(&pg_auth(3, &[]));
        if stream.read(&mut buf).unwrap_or(0) == 0 {
            return;
        }
        let _ = stream.write_all(&post_auth_ok());
        while !stop.load(Ordering::Relaxed) {
            match stream.read(&mut buf) {
                Ok(0) | Err(_) => break,
                Ok(_) => {
                    let _ = stream.write_all(&simple_query_response());
                }
            }
        }
    }

    /// Starts a mock server on an ephemeral local port that accepts any number
    /// of clients until `stop` is set; returns the port.
    fn spawn_multi_server(stop: Arc<AtomicBool>) -> u16 {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        thread::spawn(move || {
            // Non-blocking accept so the loop can observe `stop`.
            listener.set_nonblocking(true).unwrap();
            while !stop.load(Ordering::Relaxed) {
                match listener.accept() {
                    Ok((stream, _)) => {
                        stream.set_nonblocking(false).ok();
                        let client_stop = stop.clone();
                        thread::spawn(move || serve_client(stream, client_stop));
                    }
                    Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                        thread::sleep(Duration::from_millis(5));
                    }
                    Err(_) => break,
                }
            }
        });
        // Give the accept thread a moment to start.
        thread::sleep(Duration::from_millis(50));
        port
    }

    /// Configuration pointing at a local server on `port`.
    fn mock_config(port: u16) -> Config {
        Config {
            debug: false,
            hostname: "127.0.0.1".into(),
            hostport: i32::from(port),
            username: "u".into(),
            userpass: "p".into(),
            database: "d".into(),
            charset: "utf8".into(),
            pool_max: 5,
            sslmode: "disable".into(),
        }
    }

    #[test]
    fn pools_all_paths() {
        let stop = Arc::new(AtomicBool::new(false));
        let port = spawn_multi_server(stop.clone());
        let cfg = mock_config(port);

        // Basic creation and counters.
        let pools = Pools::new(cfg.clone(), 10).unwrap();
        assert_eq!(pools.total_connections(), 2);
        assert_eq!(pools.idle_pool_size(), 2);
        assert_eq!(pools.borrowed_connections(), 0);

        // Borrow one connection.
        let borrowed = pools.get_connect().unwrap();
        assert_eq!(pools.idle_pool_size(), 1);
        assert!(pools.borrowed_connections() > 0);

        // Return it.
        let idle_before = pools.idle_pool_size();
        pools.release_conn(borrowed);
        assert!(pools.idle_pool_size() > idle_before);

        // Borrow and drop without releasing: total stays unchanged because
        // dropping a raw connection does not inform the pool.
        let leaked = pools.get_connect().unwrap();
        drop(leaked);

        // Returning a broken connection reduces total.
        let mut broken = pools.get_connect().unwrap();
        let total_before = pools.total_connections();
        broken._close();
        pools.release_conn(broken);
        assert!(pools.total_connections() <= total_before);

        // Cleanup of idle connections.
        pools.cleanup_idle_connections();

        // ConnectionGuard returns the connection on drop.
        {
            let mut guard = pools.get_guard().unwrap();
            let qr = guard.conn().query("SELECT 1");
            assert!(qr.is_ok());
        }
        assert!(pools.idle_pool_size() > 0);

        // Transaction connection.
        let txn_pool = Pools::new(cfg.clone(), 10).unwrap();
        let txn = txn_pool.get_connect_for_transaction().unwrap();
        let total_before = txn_pool.total_connections();
        txn_pool.release_transaction_conn();
        assert_eq!(txn_pool.total_connections(), total_before - 1);
        drop(txn);

        // Pool of size 1: a second get_connect times out while the first is held.
        let single_pool = Pools::new(cfg.clone(), 1).unwrap();
        let held = single_pool.get_connect().unwrap();
        let result = single_pool.get_connect();
        assert!(result.is_err());
        drop(held);

        // Bad configuration: pool starts with zero connections.
        let bad_cfg = mock_config(1);
        let empty_pool = Pools::new(bad_cfg.clone(), 5).unwrap();
        assert_eq!(empty_pool.total_connections(), 0);

        // Bad configuration: get_connect fails.
        let failing_pool = Pools::new(bad_cfg.clone(), 5).unwrap();
        let result = failing_pool.get_connect();
        assert!(result.is_err());

        // Bad configuration: get_connect_for_transaction fails.
        let failing_txn_pool = Pools::new(bad_cfg.clone(), 5).unwrap();
        let result = failing_txn_pool.get_connect_for_transaction();
        assert!(result.is_err());

        // Pool of size 1 only pre-creates one connection.
        let tiny_pool = Pools::new(cfg.clone(), 1).unwrap();
        assert_eq!(tiny_pool.total_connections(), 1);

        stop.store(true, Ordering::Relaxed);
    }
}