Skip to main content

br_pgsql/
pools.rs

1use crate::config::Config;
2use crate::connect::Connect;
3use crate::error::PgsqlError;
4use log::{error, warn};
5use std::collections::VecDeque;
6use std::sync::{Arc, Condvar, Mutex, MutexGuard, PoisonError};
7use std::thread;
8use std::time::Duration;
9
/// Maximum idle time, in seconds (5 minutes): a connection idle for longer is dropped
/// instead of being put back into the pool.
const MAX_IDLE_SECS: u64 = 5 * 60;
/// Maximum connection lifetime, in seconds (30 minutes): older connections are dropped
/// so that long-lived sessions do not keep accumulating server-side state.
const MAX_CONN_LIFETIME_SECS: u64 = 30 * 60;
14struct PoolInner {
15    idle: VecDeque<Connect>,
16    total: usize,
17    max: usize,
18    /// 当前被事务持有的连接数
19    txn_total: usize,
20    /// 事务连接上限(防止事务饿死普通查询)
21    txn_max: usize,
22}
23
24/// 预占位守卫:Create/GotConn-rebuild 路径中,如果 Connect::new() panic,
25/// drop 时自动归还 total 计数并唤醒等待者,防止 total 永久膨胀。
26struct SlotGuard<'a> {
27    mutex: &'a Mutex<PoolInner>,
28    condvar: &'a Condvar,
29    active: bool,
30    for_transaction: bool,
31}
32
33impl<'a> SlotGuard<'a> {
34    fn new(mutex: &'a Mutex<PoolInner>, condvar: &'a Condvar, for_transaction: bool) -> Self {
35        Self {
36            mutex,
37            condvar,
38            active: true,
39            for_transaction,
40        }
41    }
42
43    /// 标记成功,不再需要回滚
44    fn disarm(&mut self) {
45        self.active = false;
46    }
47}
48
49impl Drop for SlotGuard<'_> {
50    fn drop(&mut self) {
51        if self.active {
52            let mut pool = lock_inner(self.mutex);
53            pool.total = pool.total.saturating_sub(1);
54            if self.for_transaction {
55                pool.txn_total = pool.txn_total.saturating_sub(1);
56            }
57            drop(pool);
58            self.condvar.notify_one();
59        }
60    }
61}
62#[derive(Clone)]
63pub struct Pools {
64    pub config: Config,
65    inner: Arc<(Mutex<PoolInner>, Condvar)>,
66}
67fn lock_inner(mutex: &Mutex<PoolInner>) -> MutexGuard<'_, PoolInner> {
68    mutex.lock().unwrap_or_else(PoisonError::into_inner)
69}
70pub struct ConnectionGuard {
71    pool: Pools,
72    conn: Option<Connect>,
73}
74impl ConnectionGuard {
75    pub fn new(pool: Pools) -> Result<Self, PgsqlError> {
76        let conn = pool.get_connect()?;
77        Ok(Self {
78            pool,
79            conn: Some(conn),
80        })
81    }
82    pub fn conn(&mut self) -> &mut Connect {
83        self.conn.as_mut().expect("connection already released")
84    }
85    /// 丢弃连接(不归还到池),用于连接已断开的场景
86    pub fn discard(&mut self) {
87        if let Some(_conn) = self.conn.take() {
88            let (ref mutex, ref condvar) = *self.pool.inner;
89            let mut pool = lock_inner(mutex);
90            pool.total = pool.total.saturating_sub(1);
91            drop(pool);
92            condvar.notify_one();
93        }
94    }
95}
96impl Drop for ConnectionGuard {
97    fn drop(&mut self) {
98        if let Some(conn) = self.conn.take() {
99            self.pool.release_conn(conn);
100        }
101    }
102}
103impl Pools {
104    pub fn get_guard(&self) -> Result<ConnectionGuard, PgsqlError> {
105        ConnectionGuard::new(self.clone())
106    }
107    pub fn new(config: Config, size: usize) -> Result<Self, PgsqlError> {
108        let init_size = 2.min(size);
109        let mut idle = VecDeque::with_capacity(size);
110        let mut created = 0;
111        for _ in 0..init_size {
112            match Connect::new(config.clone()) {
113                Ok(conn) => {
114                    idle.push_back(conn);
115                    created += 1;
116                }
117                Err(e) => warn!("初始化连接失败: {e}"),
118            }
119        }
120        let txn_max = (size / 3).max(1);
121        let inner = PoolInner {
122            idle,
123            total: created,
124            max: size,
125            txn_total: 0,
126            txn_max,
127        };
128
129        Ok(Self {
130            config,
131            inner: Arc::new((Mutex::new(inner), Condvar::new())),
132        })
133    }
134
135    /// 内部统一获取连接逻辑
136    fn acquire_connect(&self, for_transaction: bool) -> Result<Connect, PgsqlError> {
137        let mut attempts = 0;
138        let (ref mutex, ref condvar) = *self.inner;
139        let label = if for_transaction { "事务" } else { "" };
140        #[cfg(not(test))]
141        const BASE_SLEEP_MS: u64 = 200;
142        #[cfg(test)]
143        const BASE_SLEEP_MS: u64 = 1;
144        #[cfg(not(test))]
145        const MAX_SLEEP_MS: u64 = 2000;
146        #[cfg(test)]
147        const MAX_SLEEP_MS: u64 = 5;
148        #[cfg(not(test))]
149        const WAIT_TIMEOUT: Duration = Duration::from_secs(2);
150        #[cfg(test)]
151        const WAIT_TIMEOUT: Duration = Duration::from_millis(5);
152
153        let timeout_msg = if for_transaction {
154            "无法获取事务连接,重试超时"
155        } else {
156            "无法连接数据库,重试超时"
157        };
158
159        loop {
160            if attempts >= 5 {
161                return Err(PgsqlError::Pool(timeout_msg.into()));
162            }
163
164            let action = {
165                let mut pool = lock_inner(mutex);
166                // 事务连接受 txn_max 限制,防止饿死普通查询
167                if for_transaction && pool.txn_total >= pool.txn_max && pool.total >= pool.max {
168                    Action::Wait
169                } else if let Some(conn) = pool.idle.pop_front() {
170                    if for_transaction {
171                        pool.txn_total += 1;
172                    }
173                    Action::GotConn(Box::new(conn))
174                } else if pool.total < pool.max {
175                    pool.total += 1; // 预占位
176                    if for_transaction {
177                        pool.txn_total += 1;
178                    }
179                    Action::Create
180                } else {
181                    Action::Wait
182                }
183            };
184
185            match action {
186                Action::GotConn(mut conn) => {
187                    // 超过最大生命周期的连接直接丢弃
188                    if conn.age().as_secs() > MAX_CONN_LIFETIME_SECS {
189                        {
190                            let mut pool = lock_inner(mutex);
191                            pool.total = pool.total.saturating_sub(1);
192                            if for_transaction {
193                                pool.txn_total = pool.txn_total.saturating_sub(1);
194                            }
195                        }
196                        log::debug!("{}连接存活超过{}秒,已丢弃", label, MAX_CONN_LIFETIME_SECS);
197                        continue;
198                    }
199                    // 锁外做健康检查(is_valid 含懒 SELECT 1)
200                    if conn.is_valid() {
201                        conn.touch();
202                        return Ok(*conn);
203                    }
204                    // 连接失效,丢弃并重新循环(不消耗重试次数,参考 mysql crate 模式)
205                    {
206                        let mut pool = lock_inner(mutex);
207                        pool.total = pool.total.saturating_sub(1);
208                        if for_transaction {
209                            pool.txn_total = pool.txn_total.saturating_sub(1);
210                        }
211                    }
212                    warn!(
213                        "{}连接失效已丢弃,当前总连接数量: {}",
214                        label,
215                        self.total_connections()
216                    );
217                    // 不增加 attempts,直接重新循环获取连接
218                    continue;
219                }
220
221                Action::Create => {
222                    // SlotGuard 保护 Create 路径的 total 预占位
223                    let mut guard = SlotGuard::new(mutex, condvar, for_transaction);
224                    match Connect::new(self.config.clone()) {
225                        Ok(new_conn) => {
226                            guard.disarm();
227                            return Ok(new_conn);
228                        }
229                        Err(e) => {
230                            // guard drop 会自动 total -= 1 + notify
231                            drop(guard);
232                            let sleep_ms = BASE_SLEEP_MS
233                                .saturating_mul(1u64 << attempts.min(3))
234                                .min(MAX_SLEEP_MS);
235                            attempts += 1;
236                            error!("创建{}连接失败({}ms后重试): {}", label, sleep_ms, e);
237                            thread::sleep(Duration::from_millis(sleep_ms));
238                        }
239                    }
240                }
241                Action::Wait => {
242                    let pool = lock_inner(mutex);
243                    let (_pool, timeout) = condvar
244                        .wait_timeout(pool, WAIT_TIMEOUT)
245                        .unwrap_or_else(PoisonError::into_inner);
246                    drop(_pool);
247                    if timeout.timed_out() {
248                        attempts += 1;
249                    }
250                }
251            }
252        }
253    }
254    pub fn get_connect(&self) -> Result<Connect, PgsqlError> {
255        self.acquire_connect(false)
256    }
257    /// 事务专用连接,不归还到池
258    pub fn get_connect_for_transaction(&self) -> Result<Connect, PgsqlError> {
259        self.acquire_connect(true)
260    }
261    pub fn release_transaction_conn(&self) {
262        let (ref mutex, ref condvar) = *self.inner;
263        let mut pool = lock_inner(mutex);
264        pool.total = pool.total.saturating_sub(1);
265        pool.txn_total = pool.txn_total.saturating_sub(1);
266        drop(pool);
267        condvar.notify_one();
268    }
269    /// 归还事务连接到连接池(而非销毁),同时递减 txn_total
270    pub fn release_transaction_conn_with_conn(&self, conn: Connect) {
271        let (ref mutex, _) = *self.inner;
272        {
273            let mut pool = lock_inner(mutex);
274            pool.txn_total = pool.txn_total.saturating_sub(1);
275        }
276        self.release_conn(conn);
277    }
278    pub fn release_conn(&self, conn: Connect) {
279        let (ref mutex, ref condvar) = *self.inner;
280        if !conn.peer_valid() {
281            let mut pool = lock_inner(mutex);
282            pool.total = pool.total.saturating_sub(1);
283            drop(pool);
284            condvar.notify_one();
285            warn!("释放时检测到坏连接,已丢弃");
286            return;
287        }
288        if conn.age().as_secs() > MAX_CONN_LIFETIME_SECS {
289            let mut pool = lock_inner(mutex);
290            pool.total = pool.total.saturating_sub(1);
291            drop(pool);
292            condvar.notify_one();
293            log::debug!("释放时连接存活超过{}秒,已丢弃", MAX_CONN_LIFETIME_SECS);
294            return;
295        }
296        if conn.idle_elapsed().as_secs() > MAX_IDLE_SECS {
297            let mut pool = lock_inner(mutex);
298            pool.total = pool.total.saturating_sub(1);
299            drop(pool);
300            condvar.notify_one();
301            log::debug!("连接空闲超过{}秒,已丢弃", MAX_IDLE_SECS);
302            return;
303        }
304        let mut pool = lock_inner(mutex);
305        if pool.idle.len() < pool.max {
306            pool.idle.push_back(conn);
307        } else {
308            pool.total = pool.total.saturating_sub(1);
309            warn!("连接池已满,丢弃连接");
310        }
311        drop(pool);
312        condvar.notify_one();
313    }
314    pub fn idle_pool_size(&self) -> usize {
315        let (ref mutex, _) = *self.inner;
316        let pool = lock_inner(mutex);
317        pool.idle.len()
318    }
319    pub fn total_connections(&self) -> usize {
320        let (ref mutex, _) = *self.inner;
321        let pool = lock_inner(mutex);
322        pool.total
323    }
324    pub fn borrowed_connections(&self) -> usize {
325        let (ref mutex, _) = *self.inner;
326        let pool = lock_inner(mutex);
327        pool.total.saturating_sub(pool.idle.len())
328    }
329    /// 清空池中所有空闲连接(failover 场景:所有连接同时死亡时调用)
330    pub fn flush_idle(&self) {
331        let (ref mutex, _) = *self.inner;
332        let mut pool = lock_inner(mutex);
333        let flushed = pool.idle.len();
334        pool.total = pool.total.saturating_sub(flushed);
335        pool.idle.clear();
336        if flushed > 0 {
337            warn!("清空池中 {flushed} 个空闲连接(疑似批量失效)");
338        }
339    }
340    #[allow(dead_code)]
341    pub fn _cleanup_idle_connections(&self) {
342        let (ref mutex, _) = *self.inner;
343        let mut pool = lock_inner(mutex);
344        log::debug!("当前连接池中的连接数量(清理前): {}", pool.idle.len());
345        let before = pool.idle.len();
346        pool.idle.retain(|conn| {
347            let peer_ok = conn.peer_valid();
348            let idle_ok = conn.idle_elapsed().as_secs() <= MAX_IDLE_SECS;
349            if !peer_ok {
350                log::debug!("检测到无效连接,已移除");
351            } else if !idle_ok {
352                log::debug!("检测到空闲超时连接,已移除");
353            }
354            peer_ok && idle_ok
355        });
356        let removed = before - pool.idle.len();
357        pool.total = pool.total.saturating_sub(removed);
358    }
359}
360
361/// acquire_connect 内部决策
362enum Action {
363    GotConn(Box<Connect>),
364    Create,
365    Wait,
366}
367
/// Tests for the pool, run against an in-process mock server that speaks just enough of
/// the PostgreSQL wire protocol for `Connect` to log in and answer a query.
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Read as IoRead, Write as IoWrite};
    use std::net::TcpListener;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Frames one backend message: a 1-byte tag, then a big-endian u32 length that
    /// counts itself (`+ 4`) but not the tag, then the payload.
    fn pg_msg(tag: u8, payload: &[u8]) -> Vec<u8> {
        let mut m = Vec::with_capacity(5 + payload.len());
        m.push(tag);
        m.extend(&((payload.len() as u32 + 4).to_be_bytes()));
        m.extend_from_slice(payload);
        m
    }

    /// Builds an 'R' message whose body is `auth_type` (big-endian u32) followed by
    /// `extra`. Presumably 0 = AuthenticationOk and 3 = cleartext-password request in
    /// the PostgreSQL protocol.
    fn pg_auth(auth_type: u32, extra: &[u8]) -> Vec<u8> {
        let mut body = Vec::new();
        body.extend(&auth_type.to_be_bytes());
        body.extend_from_slice(extra);
        pg_msg(b'R', &body)
    }

    /// Everything the mock sends after receiving the password: auth ok ('R' 0), an 'S'
    /// message setting server_version=15.0, a 'K' message with two u32s (1, 2; presumably
    /// backend pid / secret key), and 'Z' with status 'I' (presumably ReadyForQuery, idle).
    fn post_auth_ok() -> Vec<u8> {
        let mut v = Vec::new();
        v.extend(pg_auth(0, &[]));
        v.extend(pg_msg(b'S', b"server_version\x0015.0\x00"));
        let mut k = Vec::new();
        k.extend(&1u32.to_be_bytes());
        k.extend(&2u32.to_be_bytes());
        v.extend(pg_msg(b'K', &k));
        v.extend(pg_msg(b'Z', b"I"));
        v
    }

    /// Canned reply the mock sends for any client message after login: one column named
    /// "c" and one row whose text value is "1", then "SELECT 1" and a 'Z' status 'I'.
    /// The leading '1'/'2' messages presumably stand for Parse/BindComplete, and type
    /// OID 23 for int4 — confirm against the protocol docs if this mock is extended.
    fn simple_query_response() -> Vec<u8> {
        let mut r = Vec::new();
        r.extend(pg_msg(b'1', &[]));
        r.extend(pg_msg(b'2', &[]));
        // 'T' row description: field count, name, table oid, attnum, type oid,
        // type size, type modifier, format code.
        let mut rd = Vec::new();
        rd.extend(&1u16.to_be_bytes());
        rd.extend(b"c\x00");
        rd.extend(&0u32.to_be_bytes());
        rd.extend(&1u16.to_be_bytes());
        rd.extend(&23u32.to_be_bytes());
        rd.extend(&4i16.to_be_bytes());
        rd.extend(&(-1i32).to_be_bytes());
        rd.extend(&0u16.to_be_bytes());
        r.extend(pg_msg(b'T', &rd));
        // 'D' data row: column count, value length, value bytes.
        let mut dr = Vec::new();
        dr.extend(&1u16.to_be_bytes());
        dr.extend(&1u32.to_be_bytes());
        dr.push(b'1');
        r.extend(pg_msg(b'D', &dr));
        r.extend(pg_msg(b'C', b"SELECT 1\x00"));
        r.extend(pg_msg(b'Z', b"I"));
        r
    }

    /// Starts a mock server on an ephemeral localhost port and returns the port.
    ///
    /// Each accepted client gets its own thread. The thread reads the startup message,
    /// requests a password (auth type 3), reads it, sends `post_auth_ok`, and then
    /// answers every read with `simple_query_response` until EOF, a read error or `stop`.
    fn spawn_multi_server(stop: Arc<AtomicBool>) -> u16 {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        thread::spawn(move || {
            // Non-blocking accept so the loop can observe `stop`.
            listener.set_nonblocking(true).unwrap();
            while !stop.load(Ordering::Relaxed) {
                match listener.accept() {
                    Ok((s, _)) => {
                        s.set_nonblocking(false).ok();
                        let stop2 = stop.clone();
                        thread::spawn(move || {
                            s.set_read_timeout(Some(Duration::from_secs(5))).ok();
                            let mut s = s;
                            let mut buf = [0u8; 4096];
                            // Startup message.
                            if s.read(&mut buf).unwrap_or(0) == 0 {
                                return;
                            }
                            let _ = s.write_all(&pg_auth(3, &[]));
                            // Password message.
                            if s.read(&mut buf).unwrap_or(0) == 0 {
                                return;
                            }
                            let _ = s.write_all(&post_auth_ok());
                            while !stop2.load(Ordering::Relaxed) {
                                match s.read(&mut buf) {
                                    Ok(0) | Err(_) => break,
                                    Ok(_) => {
                                        let _ = s.write_all(&simple_query_response());
                                    }
                                }
                            }
                        });
                    }
                    Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                        thread::sleep(Duration::from_millis(5));
                    }
                    Err(_) => break,
                }
            }
        });
        // The listener is already bound; this only gives the accept thread time to start.
        thread::sleep(Duration::from_millis(50));
        port
    }

    /// Config pointing at 127.0.0.1:`port` with dummy credentials. The pool size in these
    /// tests comes from the `size` argument of `Pools::new`, not from `pool_max`.
    fn mock_config(port: u16) -> Config {
        Config {
            debug: false,
            hostname: "127.0.0.1".into(),
            hostport: port as i32,
            username: "u".into(),
            userpass: "p".into(),
            database: "d".into(),
            charset: "utf8".into(),
            pool_max: 5,
        }
    }

    /// Walks through the main pool paths in one test so that a single mock server is shared.
    #[test]
    fn pools_all_paths() {
        let stop = Arc::new(AtomicBool::new(false));
        let port = spawn_multi_server(stop.clone());
        let cfg = mock_config(port);

        // Basic construction + counters (two connections are opened eagerly).
        let pools = Pools::new(cfg.clone(), 10).unwrap();
        assert_eq!(pools.total_connections(), 2);
        assert_eq!(pools.idle_pool_size(), 2);
        assert_eq!(pools.borrowed_connections(), 0);

        // Borrow one.
        let conn1 = pools.get_connect().unwrap();
        assert_eq!(pools.idle_pool_size(), 1);
        assert!(pools.borrowed_connections() > 0);

        // Give it back.
        let idle_before = pools.idle_pool_size();
        pools.release_conn(conn1);
        assert!(pools.idle_pool_size() > idle_before);

        // Borrow and just drop: the pool is not told, so `total` stays unchanged
        // (the slot stays counted).
        let conn2 = pools.get_connect().unwrap();
        drop(conn2);

        // Returning a broken connection must not increase `total`.
        let mut conn3 = pools.get_connect().unwrap();
        let total_before = pools.total_connections();
        conn3._close();
        pools.release_conn(conn3);
        assert!(pools.total_connections() <= total_before);

        // Cleanup pass over idle connections.
        pools._cleanup_idle_connections();

        // ConnectionGuard: query through the guard, then it returns the connection on drop.
        {
            let mut guard = pools.get_guard().unwrap();
            let qr = guard.conn().query("SELECT 1");
            assert!(qr.is_ok());
        }
        assert!(pools.idle_pool_size() > 0);

        // Transaction connection: releasing the transaction frees one slot.
        let pools2 = Pools::new(cfg.clone(), 10).unwrap();
        let txn = pools2.get_connect_for_transaction().unwrap();
        let total_before = pools2.total_connections();
        pools2.release_transaction_conn();
        assert_eq!(pools2.total_connections(), total_before - 1);
        drop(txn);

        // Size 1: while the only connection is held, get_connect times out.
        let pools3 = Pools::new(cfg.clone(), 1).unwrap();
        let held = pools3.get_connect().unwrap();
        let result = pools3.get_connect();
        assert!(result.is_err());
        drop(held);

        // Bad config (port 1, presumably nothing listening) → 0 initial connections.
        let bad_cfg = mock_config(1);
        let pools4 = Pools::new(bad_cfg.clone(), 5).unwrap();
        assert_eq!(pools4.total_connections(), 0);

        // Bad config: get_connect fails.
        let pools5 = Pools::new(bad_cfg.clone(), 5).unwrap();
        let result = pools5.get_connect();
        assert!(result.is_err());

        // Bad config: get_connect_for_transaction fails.
        let pools6 = Pools::new(bad_cfg.clone(), 5).unwrap();
        let result = pools6.get_connect_for_transaction();
        assert!(result.is_err());

        // Size 1: only one connection is opened at construction.
        let pools7 = Pools::new(cfg.clone(), 1).unwrap();
        assert_eq!(pools7.total_connections(), 1);

        stop.store(true, Ordering::Relaxed);
    }
}