Skip to main content

br_pgsql/
pools.rs

1use crate::config::Config;
2use crate::connect::Connect;
3use crate::error::PgsqlError;
4use log::{error, warn};
5use std::collections::VecDeque;
6use std::sync::{Arc, Condvar, Mutex, MutexGuard, PoisonError};
7use std::thread;
8use std::time::Duration;
9
/// Connections idle for longer than this many seconds (5 minutes) are dropped
/// instead of being put back into the pool.
const MAX_IDLE_SECS: u64 = 5 * 60;
/// Maximum lifetime of any connection in seconds (30 minutes). Older connections
/// are dropped so long-lived sessions cannot accumulate server-side state.
const MAX_CONN_LIFETIME_SECS: u64 = 30 * 60;
14struct PoolInner {
15    idle: VecDeque<Connect>,
16    total: usize,
17    max: usize,
18    /// 当前被事务持有的连接数
19    txn_total: usize,
20    /// 事务连接上限(防止事务饿死普通查询)
21    txn_max: usize,
22}
23
24/// 预占位守卫:Create/GotConn-rebuild 路径中,如果 Connect::new() panic,
25/// drop 时自动归还 total 计数并唤醒等待者,防止 total 永久膨胀。
26struct SlotGuard<'a> {
27    mutex: &'a Mutex<PoolInner>,
28    condvar: &'a Condvar,
29    active: bool,
30    for_transaction: bool,
31}
32
33impl<'a> SlotGuard<'a> {
34    fn new(mutex: &'a Mutex<PoolInner>, condvar: &'a Condvar, for_transaction: bool) -> Self {
35        Self {
36            mutex,
37            condvar,
38            active: true,
39            for_transaction,
40        }
41    }
42
43    /// 标记成功,不再需要回滚
44    fn disarm(&mut self) {
45        self.active = false;
46    }
47}
48
49impl Drop for SlotGuard<'_> {
50    fn drop(&mut self) {
51        if self.active {
52            let mut pool = lock_inner(self.mutex);
53            pool.total = pool.total.saturating_sub(1);
54            if self.for_transaction {
55                pool.txn_total = pool.txn_total.saturating_sub(1);
56            }
57            drop(pool);
58            self.condvar.notify_one();
59        }
60    }
61}
62#[derive(Clone)]
63pub struct Pools {
64    pub config: Config,
65    inner: Arc<(Mutex<PoolInner>, Condvar)>,
66}
67fn lock_inner(mutex: &Mutex<PoolInner>) -> MutexGuard<'_, PoolInner> {
68    mutex.lock().unwrap_or_else(PoisonError::into_inner)
69}
70pub struct ConnectionGuard {
71    pool: Pools,
72    conn: Option<Connect>,
73}
74impl ConnectionGuard {
75    pub fn new(pool: Pools) -> Result<Self, PgsqlError> {
76        let conn = pool.get_connect()?;
77        Ok(Self {
78            pool,
79            conn: Some(conn),
80        })
81    }
82    pub fn conn(&mut self) -> &mut Connect {
83        self.conn.as_mut().expect("connection already released")
84    }
85    /// 丢弃连接(不归还到池),用于连接已断开的场景
86    pub fn discard(&mut self) {
87        if let Some(_conn) = self.conn.take() {
88            let (ref mutex, ref condvar) = *self.pool.inner;
89            let mut pool = lock_inner(mutex);
90            pool.total = pool.total.saturating_sub(1);
91            drop(pool);
92            condvar.notify_one();
93        }
94    }
95}
96impl Drop for ConnectionGuard {
97    fn drop(&mut self) {
98        if let Some(conn) = self.conn.take() {
99            self.pool.release_conn(conn);
100        }
101    }
102}
103impl Pools {
104    pub fn get_guard(&self) -> Result<ConnectionGuard, PgsqlError> {
105        ConnectionGuard::new(self.clone())
106    }
107    pub fn new(config: Config, size: usize) -> Result<Self, PgsqlError> {
108        let init_size = 2.min(size);
109        let mut idle = VecDeque::with_capacity(size);
110        let mut created = 0;
111        for _ in 0..init_size {
112            match Connect::new(config.clone()) {
113                Ok(conn) => {
114                    idle.push_back(conn);
115                    created += 1;
116                }
117                Err(e) => warn!("初始化连接失败: {e}"),
118            }
119        }
120        let txn_max = (size / 3).max(1);
121        let inner = PoolInner {
122            idle,
123            total: created,
124            max: size,
125            txn_total: 0,
126            txn_max,
127        };
128
129        Ok(Self {
130            config,
131            inner: Arc::new((Mutex::new(inner), Condvar::new())),
132        })
133    }
134
135    /// 内部统一获取连接逻辑
136    fn acquire_connect(&self, for_transaction: bool) -> Result<Connect, PgsqlError> {
137        let mut attempts = 0;
138        let (ref mutex, ref condvar) = *self.inner;
139        let label = if for_transaction { "事务" } else { "" };
140        #[cfg(not(test))]
141        const BASE_SLEEP_MS: u64 = 200;
142        #[cfg(test)]
143        const BASE_SLEEP_MS: u64 = 1;
144        #[cfg(not(test))]
145        const MAX_SLEEP_MS: u64 = 2000;
146        #[cfg(test)]
147        const MAX_SLEEP_MS: u64 = 5;
148        #[cfg(not(test))]
149        const WAIT_TIMEOUT: Duration = Duration::from_secs(2);
150        #[cfg(test)]
151        const WAIT_TIMEOUT: Duration = Duration::from_millis(5);
152
153        let timeout_msg = if for_transaction {
154            "无法获取事务连接,重试超时"
155        } else {
156            "无法连接数据库,重试超时"
157        };
158
159        loop {
160            if attempts >= 5 {
161                return Err(PgsqlError::Pool(timeout_msg.into()));
162            }
163
164            let action = {
165                let mut pool = lock_inner(mutex);
166                // 事务连接受 txn_max 限制,防止饿死普通查询
167                if for_transaction && pool.txn_total >= pool.txn_max && pool.total >= pool.max {
168                    Action::Wait
169                } else if let Some(conn) = pool.idle.pop_front() {
170                    if for_transaction {
171                        pool.txn_total += 1;
172                    }
173                    Action::GotConn(Box::new(conn))
174                } else if pool.total < pool.max {
175                    pool.total += 1; // 预占位
176                    if for_transaction {
177                        pool.txn_total += 1;
178                    }
179                    Action::Create
180                } else {
181                    Action::Wait
182                }
183            };
184
185            match action {
186                Action::GotConn(mut conn) => {
187                    // 超过最大生命周期的连接直接丢弃
188                    if conn.age().as_secs() > MAX_CONN_LIFETIME_SECS {
189                        {
190                            let mut pool = lock_inner(mutex);
191                            pool.total = pool.total.saturating_sub(1);
192                            if for_transaction {
193                                pool.txn_total = pool.txn_total.saturating_sub(1);
194                            }
195                        }
196                        log::debug!("{}连接存活超过{}秒,已丢弃", label, MAX_CONN_LIFETIME_SECS);
197                        continue;
198                    }
199                    // 锁外做健康检查(is_valid 含懒 SELECT 1)
200                    if conn.is_valid() {
201                        conn.touch();
202                        return Ok(*conn);
203                    }
204                    // 连接失效,丢弃并重新循环(不消耗重试次数,参考 mysql crate 模式)
205                    {
206                        let mut pool = lock_inner(mutex);
207                        pool.total = pool.total.saturating_sub(1);
208                        if for_transaction {
209                            pool.txn_total = pool.txn_total.saturating_sub(1);
210                        }
211                    }
212                    warn!(
213                        "{}连接失效已丢弃,当前总连接数量: {}",
214                        label,
215                        self.total_connections()
216                    );
217                    // 不增加 attempts,直接重新循环获取连接
218                    continue;
219                }
220
221                Action::Create => {
222                    // SlotGuard 保护 Create 路径的 total 预占位
223                    let mut guard = SlotGuard::new(mutex, condvar, for_transaction);
224                    match Connect::new(self.config.clone()) {
225                        Ok(new_conn) => {
226                            guard.disarm();
227                            return Ok(new_conn);
228                        }
229                        Err(e) => {
230                            // guard drop 会自动 total -= 1 + notify
231                            drop(guard);
232                            let sleep_ms = BASE_SLEEP_MS
233                                .saturating_mul(1u64 << attempts.min(3))
234                                .min(MAX_SLEEP_MS);
235                            attempts += 1;
236                            error!("创建{}连接失败({}ms后重试): {}", label, sleep_ms, e);
237                            thread::sleep(Duration::from_millis(sleep_ms));
238                        }
239                    }
240                }
241                Action::Wait => {
242                    let pool = lock_inner(mutex);
243                    let (_pool, timeout) = condvar
244                        .wait_timeout(pool, WAIT_TIMEOUT)
245                        .unwrap_or_else(PoisonError::into_inner);
246                    drop(_pool);
247                    if timeout.timed_out() {
248                        attempts += 1;
249                    }
250                }
251            }
252        }
253    }
254    pub fn get_connect(&self) -> Result<Connect, PgsqlError> {
255        self.acquire_connect(false)
256    }
257    /// 事务专用连接,不归还到池
258    pub fn get_connect_for_transaction(&self) -> Result<Connect, PgsqlError> {
259        self.acquire_connect(true)
260    }
261    pub fn release_transaction_conn(&self) {
262        let (ref mutex, ref condvar) = *self.inner;
263        let mut pool = lock_inner(mutex);
264        pool.total = pool.total.saturating_sub(1);
265        pool.txn_total = pool.txn_total.saturating_sub(1);
266        drop(pool);
267        condvar.notify_one();
268    }
269    pub fn release_conn(&self, conn: Connect) {
270        let (ref mutex, ref condvar) = *self.inner;
271        if !conn.peer_valid() {
272            let mut pool = lock_inner(mutex);
273            pool.total = pool.total.saturating_sub(1);
274            drop(pool);
275            condvar.notify_one();
276            warn!("释放时检测到坏连接,已丢弃");
277            return;
278        }
279        if conn.age().as_secs() > MAX_CONN_LIFETIME_SECS {
280            let mut pool = lock_inner(mutex);
281            pool.total = pool.total.saturating_sub(1);
282            drop(pool);
283            condvar.notify_one();
284            log::debug!("释放时连接存活超过{}秒,已丢弃", MAX_CONN_LIFETIME_SECS);
285            return;
286        }
287        if conn.idle_elapsed().as_secs() > MAX_IDLE_SECS {
288            let mut pool = lock_inner(mutex);
289            pool.total = pool.total.saturating_sub(1);
290            drop(pool);
291            condvar.notify_one();
292            log::debug!("连接空闲超过{}秒,已丢弃", MAX_IDLE_SECS);
293            return;
294        }
295        let mut pool = lock_inner(mutex);
296        if pool.idle.len() < pool.max {
297            pool.idle.push_back(conn);
298        } else {
299            pool.total = pool.total.saturating_sub(1);
300            warn!("连接池已满,丢弃连接");
301        }
302        drop(pool);
303        condvar.notify_one();
304    }
305    pub fn idle_pool_size(&self) -> usize {
306        let (ref mutex, _) = *self.inner;
307        let pool = lock_inner(mutex);
308        pool.idle.len()
309    }
310    pub fn total_connections(&self) -> usize {
311        let (ref mutex, _) = *self.inner;
312        let pool = lock_inner(mutex);
313        pool.total
314    }
315    pub fn borrowed_connections(&self) -> usize {
316        let (ref mutex, _) = *self.inner;
317        let pool = lock_inner(mutex);
318        pool.total.saturating_sub(pool.idle.len())
319    }
320    /// 清空池中所有空闲连接(failover 场景:所有连接同时死亡时调用)
321    pub fn flush_idle(&self) {
322        let (ref mutex, _) = *self.inner;
323        let mut pool = lock_inner(mutex);
324        let flushed = pool.idle.len();
325        pool.total = pool.total.saturating_sub(flushed);
326        pool.idle.clear();
327        if flushed > 0 {
328            warn!("清空池中 {flushed} 个空闲连接(疑似批量失效)");
329        }
330    }
331    #[allow(dead_code)]
332    pub fn _cleanup_idle_connections(&self) {
333        let (ref mutex, _) = *self.inner;
334        let mut pool = lock_inner(mutex);
335        log::debug!("当前连接池中的连接数量(清理前): {}", pool.idle.len());
336        let before = pool.idle.len();
337        pool.idle.retain(|conn| {
338            let peer_ok = conn.peer_valid();
339            let idle_ok = conn.idle_elapsed().as_secs() <= MAX_IDLE_SECS;
340            if !peer_ok {
341                log::debug!("检测到无效连接,已移除");
342            } else if !idle_ok {
343                log::debug!("检测到空闲超时连接,已移除");
344            }
345            peer_ok && idle_ok
346        });
347        let removed = before - pool.idle.len();
348        pool.total = pool.total.saturating_sub(removed);
349    }
350}
351
352/// acquire_connect 内部决策
353enum Action {
354    GotConn(Box<Connect>),
355    Create,
356    Wait,
357}
358
// Tests run the pool against an in-process mock server (`spawn_multi_server`)
// that speaks a minimal subset of the PostgreSQL wire protocol.
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Read as IoRead, Write as IoWrite};
    use std::net::TcpListener;
    use std::sync::atomic::{AtomicBool, Ordering};

    // Frames one backend message: the 1-byte `tag`, a big-endian u32 length that
    // includes the length field itself (payload length + 4), then `payload`.
    fn pg_msg(tag: u8, payload: &[u8]) -> Vec<u8> {
        let mut m = Vec::with_capacity(5 + payload.len());
        m.push(tag);
        m.extend(&((payload.len() as u32 + 4).to_be_bytes()));
        m.extend_from_slice(payload);
        m
    }

    // Authentication message ('R'): big-endian `auth_type` followed by `extra`.
    // The tests use 3 (cleartext password request) and 0 (authentication OK).
    fn pg_auth(auth_type: u32, extra: &[u8]) -> Vec<u8> {
        let mut body = Vec::new();
        body.extend(&auth_type.to_be_bytes());
        body.extend_from_slice(extra);
        pg_msg(b'R', &body)
    }

    // Everything the mock sends once it has received the password:
    // AuthenticationOk, a ParameterStatus ('S') with server_version = 15.0,
    // BackendKeyData ('K') with process id 1 / secret key 2, and
    // ReadyForQuery ('Z') with status 'I' (idle).
    fn post_auth_ok() -> Vec<u8> {
        let mut v = Vec::new();
        v.extend(pg_auth(0, &[]));
        v.extend(pg_msg(b'S', b"server_version\x0015.0\x00"));
        let mut k = Vec::new();
        k.extend(&1u32.to_be_bytes());
        k.extend(&2u32.to_be_bytes());
        v.extend(pg_msg(b'K', &k));
        v.extend(pg_msg(b'Z', b"I"));
        v
    }

    // Canned reply the mock sends for every client message after login:
    // ParseComplete ('1'), BindComplete ('2'), a RowDescription ('T') for one
    // column "c" (table OID 0, attnum 1, type OID 23 = int4, size 4, typmod -1,
    // text format), one DataRow ('D') holding the text value "1",
    // CommandComplete ("SELECT 1") and ReadyForQuery ('Z', idle).
    fn simple_query_response() -> Vec<u8> {
        let mut r = Vec::new();
        r.extend(pg_msg(b'1', &[]));
        r.extend(pg_msg(b'2', &[]));
        let mut rd = Vec::new();
        rd.extend(&1u16.to_be_bytes());
        rd.extend(b"c\x00");
        rd.extend(&0u32.to_be_bytes());
        rd.extend(&1u16.to_be_bytes());
        rd.extend(&23u32.to_be_bytes());
        rd.extend(&4i16.to_be_bytes());
        rd.extend(&(-1i32).to_be_bytes());
        rd.extend(&0u16.to_be_bytes());
        r.extend(pg_msg(b'T', &rd));
        let mut dr = Vec::new();
        dr.extend(&1u16.to_be_bytes());
        dr.extend(&1u32.to_be_bytes());
        dr.push(b'1');
        r.extend(pg_msg(b'D', &dr));
        r.extend(pg_msg(b'C', b"SELECT 1\x00"));
        r.extend(pg_msg(b'Z', b"I"));
        r
    }

    // Starts the mock server on an ephemeral 127.0.0.1 port and returns that port.
    // The accept loop runs non-blocking (polling every 5 ms) so it can notice
    // `stop`. Each accepted socket is served on its own thread, which:
    // 1. reads the first client message (presumably the startup message; its
    //    contents are not checked),
    // 2. requests a cleartext password (auth type 3),
    // 3. reads the password message without checking it,
    // 4. sends `post_auth_ok`,
    // 5. answers every further read with `simple_query_response` until EOF,
    //    a read error (including the 5 s read timeout) or `stop`.
    fn spawn_multi_server(stop: Arc<AtomicBool>) -> u16 {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        thread::spawn(move || {
            listener.set_nonblocking(true).unwrap();
            while !stop.load(Ordering::Relaxed) {
                match listener.accept() {
                    Ok((s, _)) => {
                        // Accepted sockets may inherit non-blocking mode; force blocking reads.
                        s.set_nonblocking(false).ok();
                        let stop2 = stop.clone();
                        thread::spawn(move || {
                            s.set_read_timeout(Some(Duration::from_secs(5))).ok();
                            let mut s = s;
                            let mut buf = [0u8; 4096];
                            // First client message.
                            if s.read(&mut buf).unwrap_or(0) == 0 {
                                return;
                            }
                            let _ = s.write_all(&pg_auth(3, &[]));
                            // Password message.
                            if s.read(&mut buf).unwrap_or(0) == 0 {
                                return;
                            }
                            let _ = s.write_all(&post_auth_ok());
                            // Same canned response for every subsequent read.
                            while !stop2.load(Ordering::Relaxed) {
                                match s.read(&mut buf) {
                                    Ok(0) | Err(_) => break,
                                    Ok(_) => {
                                        let _ = s.write_all(&simple_query_response());
                                    }
                                }
                            }
                        });
                    }
                    Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                        thread::sleep(Duration::from_millis(5));
                    }
                    Err(_) => break,
                }
            }
        });
        // The listener is already bound; this pause presumably just lets the
        // accept thread start before the tests connect.
        thread::sleep(Duration::from_millis(50));
        port
    }

    // Config pointing at the mock server on `port`. The credentials are arbitrary
    // because the mock never checks them; pool sizes are passed to `Pools::new`
    // explicitly in the test below.
    fn mock_config(port: u16) -> Config {
        Config {
            debug: false,
            hostname: "127.0.0.1".into(),
            hostport: port as i32,
            username: "u".into(),
            userpass: "p".into(),
            database: "d".into(),
            charset: "utf8".into(),
            pool_max: 5,
        }
    }

    // One sequential scenario covering the main pool paths; later steps depend
    // on the counters left by earlier ones.
    #[test]
    fn pools_all_paths() {
        let stop = Arc::new(AtomicBool::new(false));
        let port = spawn_multi_server(stop.clone());
        let cfg = mock_config(port);

        // Basic construction and counters: `new` eagerly opens two connections.
        let pools = Pools::new(cfg.clone(), 10).unwrap();
        assert_eq!(pools.total_connections(), 2);
        assert_eq!(pools.idle_pool_size(), 2);
        assert_eq!(pools.borrowed_connections(), 0);

        // Borrow one connection.
        let conn1 = pools.get_connect().unwrap();
        assert_eq!(pools.idle_pool_size(), 1);
        assert!(pools.borrowed_connections() > 0);

        // Return it.
        let idle_before = pools.idle_pool_size();
        pools.release_conn(conn1);
        assert!(pools.idle_pool_size() > idle_before);

        // Borrow and then drop without returning it. Nothing notifies the pool,
        // so `total` stays unchanged.
        // NOTE(review): no assertion here; the dropped connection's slot stays
        // counted in `total`.
        let conn2 = pools.get_connect().unwrap();
        drop(conn2);

        // Returning a broken connection → total decreases.
        // NOTE(review): `<=` also passes if nothing was discarded; asserting
        // `total_before - 1` would pin the behaviour, assuming `_close` makes
        // `peer_valid()` return false — TODO confirm.
        let mut conn3 = pools.get_connect().unwrap();
        let total_before = pools.total_connections();
        conn3._close();
        pools.release_conn(conn3);
        assert!(pools.total_connections() <= total_before);

        // Cleanup of idle connections (smoke test only; nothing asserted).
        pools._cleanup_idle_connections();

        // ConnectionGuard: the query succeeds and the connection returns to idle on drop.
        {
            let mut guard = pools.get_guard().unwrap();
            let qr = guard.conn().query("SELECT 1");
            assert!(qr.is_ok());
        }
        assert!(pools.idle_pool_size() > 0);

        // Transaction connections: releasing frees exactly one slot.
        let pools2 = Pools::new(cfg.clone(), 10).unwrap();
        let txn = pools2.get_connect_for_transaction().unwrap();
        let total_before = pools2.total_connections();
        pools2.release_transaction_conn();
        assert_eq!(pools2.total_connections(), total_before - 1);
        drop(txn);

        // Pool size 1 → get_connect times out while the only connection is held.
        let pools3 = Pools::new(cfg.clone(), 1).unwrap();
        let held = pools3.get_connect().unwrap();
        let result = pools3.get_connect();
        assert!(result.is_err());
        drop(held);

        // Bad config → 0 connections.
        // (Port 1 is assumed to have no listener — TODO confirm on CI hosts.)
        let bad_cfg = mock_config(1);
        let pools4 = Pools::new(bad_cfg.clone(), 5).unwrap();
        assert_eq!(pools4.total_connections(), 0);

        // Bad config: get_connect fails.
        let pools5 = Pools::new(bad_cfg.clone(), 5).unwrap();
        let result = pools5.get_connect();
        assert!(result.is_err());

        // Bad config: get_connect_for_transaction fails.
        let pools6 = Pools::new(bad_cfg.clone(), 5).unwrap();
        let result = pools6.get_connect_for_transaction();
        assert!(result.is_err());

        // Pool size 1: only one connection is opened eagerly.
        let pools7 = Pools::new(cfg.clone(), 1).unwrap();
        assert_eq!(pools7.total_connections(), 1);

        // Ask the mock server threads to exit (not reached if an assertion above panics).
        stop.store(true, Ordering::Relaxed);
    }
}