Skip to main content

br_pgsql/
pools.rs

1use crate::config::Config;
2use crate::connect::Connect;
3use crate::error::PgsqlError;
4use log::{error, warn};
5use std::collections::VecDeque;
6use std::sync::{Arc, Condvar, Mutex, MutexGuard, PoisonError};
7use std::thread;
8use std::time::Duration;
9
/// Longest time, in seconds, a connection may sit unused (5 minutes).
/// Past this age a connection is dropped instead of being put back into the pool.
const MAX_IDLE_SECS: u64 = 5 * 60;
12struct PoolInner {
13    idle: VecDeque<Connect>,
14    total: usize,
15    max: usize,
16    /// 当前被事务持有的连接数
17    txn_total: usize,
18    /// 事务连接上限(防止事务饿死普通查询)
19    txn_max: usize,
20}
21
22/// 预占位守卫:Create/GotConn-rebuild 路径中,如果 Connect::new() panic,
23/// drop 时自动归还 total 计数并唤醒等待者,防止 total 永久膨胀。
24struct SlotGuard<'a> {
25    mutex: &'a Mutex<PoolInner>,
26    condvar: &'a Condvar,
27    active: bool,
28    for_transaction: bool,
29}
30
31impl<'a> SlotGuard<'a> {
32    fn new(mutex: &'a Mutex<PoolInner>, condvar: &'a Condvar, for_transaction: bool) -> Self {
33        Self {
34            mutex,
35            condvar,
36            active: true,
37            for_transaction,
38        }
39    }
40
41    /// 标记成功,不再需要回滚
42    fn disarm(&mut self) {
43        self.active = false;
44    }
45}
46
47impl Drop for SlotGuard<'_> {
48    fn drop(&mut self) {
49        if self.active {
50            let mut pool = lock_inner(self.mutex);
51            pool.total = pool.total.saturating_sub(1);
52            if self.for_transaction {
53                pool.txn_total = pool.txn_total.saturating_sub(1);
54            }
55            drop(pool);
56            self.condvar.notify_one();
57        }
58    }
59}
60#[derive(Clone)]
61pub struct Pools {
62    pub config: Config,
63    inner: Arc<(Mutex<PoolInner>, Condvar)>,
64}
/// Locks `mutex` and returns its guard, even if a previous holder panicked.
///
/// Poisoning is deliberately ignored: the guarded data is returned as-is
/// instead of turning every later pool operation into a panic. The function is
/// generic over the protected type; within this module it guards `PoolInner`.
fn lock_inner<T>(mutex: &Mutex<T>) -> MutexGuard<'_, T> {
    mutex.lock().unwrap_or_else(PoisonError::into_inner)
}
68pub struct ConnectionGuard {
69    pool: Pools,
70    conn: Option<Connect>,
71}
72impl ConnectionGuard {
73    pub fn new(pool: Pools) -> Result<Self, PgsqlError> {
74        let conn = pool.get_connect()?;
75        Ok(Self {
76            pool,
77            conn: Some(conn),
78        })
79    }
80    pub fn conn(&mut self) -> &mut Connect {
81        self.conn.as_mut().expect("connection already released")
82    }
83    /// 丢弃连接(不归还到池),用于连接已断开的场景
84    pub fn discard(&mut self) {
85        if let Some(_conn) = self.conn.take() {
86            let (ref mutex, ref condvar) = *self.pool.inner;
87            let mut pool = lock_inner(mutex);
88            pool.total = pool.total.saturating_sub(1);
89            drop(pool);
90            condvar.notify_one();
91        }
92    }
93}
94impl Drop for ConnectionGuard {
95    fn drop(&mut self) {
96        if let Some(conn) = self.conn.take() {
97            self.pool.release_conn(conn);
98        }
99    }
100}
101impl Pools {
102    pub fn get_guard(&self) -> Result<ConnectionGuard, PgsqlError> {
103        ConnectionGuard::new(self.clone())
104    }
105    pub fn new(config: Config, size: usize) -> Result<Self, PgsqlError> {
106        let init_size = 2.min(size);
107        let mut idle = VecDeque::with_capacity(size);
108        let mut created = 0;
109        for _ in 0..init_size {
110            match Connect::new(config.clone()) {
111                Ok(conn) => {
112                    idle.push_back(conn);
113                    created += 1;
114                }
115                Err(e) => warn!("初始化连接失败: {e}"),
116            }
117        }
118        let txn_max = (size / 3).max(1);
119        let inner = PoolInner {
120            idle,
121            total: created,
122            max: size,
123            txn_total: 0,
124            txn_max,
125        };
126
127        Ok(Self {
128            config,
129            inner: Arc::new((Mutex::new(inner), Condvar::new())),
130        })
131    }
132
133    /// 内部统一获取连接逻辑
134    fn acquire_connect(&self, for_transaction: bool) -> Result<Connect, PgsqlError> {
135        let mut attempts = 0;
136        let (ref mutex, ref condvar) = *self.inner;
137        let label = if for_transaction { "事务" } else { "" };
138        #[cfg(not(test))]
139        const BASE_SLEEP_MS: u64 = 200;
140        #[cfg(test)]
141        const BASE_SLEEP_MS: u64 = 1;
142        #[cfg(not(test))]
143        const MAX_SLEEP_MS: u64 = 2000;
144        #[cfg(test)]
145        const MAX_SLEEP_MS: u64 = 5;
146        #[cfg(not(test))]
147        const WAIT_TIMEOUT: Duration = Duration::from_secs(2);
148        #[cfg(test)]
149        const WAIT_TIMEOUT: Duration = Duration::from_millis(5);
150
151        let timeout_msg = if for_transaction {
152            "无法获取事务连接,重试超时"
153        } else {
154            "无法连接数据库,重试超时"
155        };
156
157        loop {
158            if attempts >= 5 {
159                return Err(PgsqlError::Pool(timeout_msg.into()));
160            }
161
162            let action = {
163                let mut pool = lock_inner(mutex);
164                // 事务连接受 txn_max 限制,防止饿死普通查询
165                if for_transaction && pool.txn_total >= pool.txn_max && pool.total >= pool.max {
166                    Action::Wait
167                } else if let Some(conn) = pool.idle.pop_front() {
168                    if for_transaction {
169                        pool.txn_total += 1;
170                    }
171                    Action::GotConn(Box::new(conn))
172                } else if pool.total < pool.max {
173                    pool.total += 1; // 预占位
174                    if for_transaction {
175                        pool.txn_total += 1;
176                    }
177                    Action::Create
178                } else {
179                    Action::Wait
180                }
181            };
182
183            match action {
184                Action::GotConn(mut conn) => {
185                    // 锁外做健康检查(is_valid 含懒 SELECT 1)
186                    if conn.is_valid() {
187                        conn.touch();
188                        return Ok(*conn);
189                    }
190                    // 连接失效,减计数,尝试重建
191                    {
192                        let mut pool = lock_inner(mutex);
193                        pool.total = pool.total.saturating_sub(1);
194                    }
195                    warn!(
196                        "{}连接失效,尝试重建,当前总连接数量: {}",
197                        label,
198                        self.total_connections()
199                    );
200                    // SlotGuard 保护重建路径的 total 预占位
201                    let mut guard = SlotGuard::new(mutex, condvar, for_transaction);
202                    {
203                        let mut pool = lock_inner(mutex);
204                        pool.total += 1;
205                    }
206                    match Connect::new(self.config.clone()) {
207                        Ok(new_conn) => {
208                            guard.disarm();
209                            return Ok(new_conn);
210                        }
211                        Err(e) => {
212                            // guard drop 会自动 total -= 1 + notify
213                            drop(guard);
214                            let sleep_ms = BASE_SLEEP_MS
215                                .saturating_mul(1u64 << attempts.min(3))
216                                .min(MAX_SLEEP_MS);
217                            attempts += 1;
218                            error!("{}重建连接失败({}ms后重试): {}", label, sleep_ms, e);
219                            thread::sleep(Duration::from_millis(sleep_ms));
220                        }
221                    }
222                }
223                Action::Create => {
224                    // SlotGuard 保护 Create 路径的 total 预占位
225                    let mut guard = SlotGuard::new(mutex, condvar, for_transaction);
226                    match Connect::new(self.config.clone()) {
227                        Ok(new_conn) => {
228                            guard.disarm();
229                            return Ok(new_conn);
230                        }
231                        Err(e) => {
232                            // guard drop 会自动 total -= 1 + notify
233                            drop(guard);
234                            let sleep_ms = BASE_SLEEP_MS
235                                .saturating_mul(1u64 << attempts.min(3))
236                                .min(MAX_SLEEP_MS);
237                            attempts += 1;
238                            error!("创建{}连接失败({}ms后重试): {}", label, sleep_ms, e);
239                            thread::sleep(Duration::from_millis(sleep_ms));
240                        }
241                    }
242                }
243                Action::Wait => {
244                    let pool = lock_inner(mutex);
245                    let (_pool, timeout) = condvar
246                        .wait_timeout(pool, WAIT_TIMEOUT)
247                        .unwrap_or_else(PoisonError::into_inner);
248                    drop(_pool);
249                    if timeout.timed_out() {
250                        attempts += 1;
251                    }
252                }
253            }
254        }
255    }
256    pub fn get_connect(&self) -> Result<Connect, PgsqlError> {
257        self.acquire_connect(false)
258    }
259    /// 事务专用连接,不归还到池
260    pub fn get_connect_for_transaction(&self) -> Result<Connect, PgsqlError> {
261        self.acquire_connect(true)
262    }
263    pub fn release_transaction_conn(&self) {
264        let (ref mutex, ref condvar) = *self.inner;
265        let mut pool = lock_inner(mutex);
266        pool.total = pool.total.saturating_sub(1);
267        pool.txn_total = pool.txn_total.saturating_sub(1);
268        drop(pool);
269        condvar.notify_one();
270    }
271    pub fn release_conn(&self, conn: Connect) {
272        let (ref mutex, ref condvar) = *self.inner;
273        if !conn.peer_valid() {
274            let mut pool = lock_inner(mutex);
275            pool.total = pool.total.saturating_sub(1);
276            drop(pool);
277            condvar.notify_one();
278            warn!("释放时检测到坏连接,已丢弃");
279            return;
280        }
281        if conn.idle_elapsed().as_secs() > MAX_IDLE_SECS {
282            let mut pool = lock_inner(mutex);
283            pool.total = pool.total.saturating_sub(1);
284            drop(pool);
285            condvar.notify_one();
286            log::debug!("连接空闲超过{}秒,已丢弃", MAX_IDLE_SECS);
287            return;
288        }
289        let mut pool = lock_inner(mutex);
290        if pool.idle.len() < pool.max {
291            pool.idle.push_back(conn);
292        } else {
293            pool.total = pool.total.saturating_sub(1);
294            warn!("连接池已满,丢弃连接");
295        }
296        drop(pool);
297        condvar.notify_one();
298    }
299    pub fn idle_pool_size(&self) -> usize {
300        let (ref mutex, _) = *self.inner;
301        let pool = lock_inner(mutex);
302        pool.idle.len()
303    }
304    pub fn total_connections(&self) -> usize {
305        let (ref mutex, _) = *self.inner;
306        let pool = lock_inner(mutex);
307        pool.total
308    }
309    pub fn borrowed_connections(&self) -> usize {
310        let (ref mutex, _) = *self.inner;
311        let pool = lock_inner(mutex);
312        pool.total.saturating_sub(pool.idle.len())
313    }
314    #[allow(dead_code)]
315    pub fn _cleanup_idle_connections(&self) {
316        let (ref mutex, _) = *self.inner;
317        let mut pool = lock_inner(mutex);
318        log::debug!("当前连接池中的连接数量(清理前): {}", pool.idle.len());
319        let before = pool.idle.len();
320        pool.idle.retain(|conn| {
321            let peer_ok = conn.peer_valid();
322            let idle_ok = conn.idle_elapsed().as_secs() <= MAX_IDLE_SECS;
323            if !peer_ok {
324                log::debug!("检测到无效连接,已移除");
325            } else if !idle_ok {
326                log::debug!("检测到空闲超时连接,已移除");
327            }
328            peer_ok && idle_ok
329        });
330        let removed = before - pool.idle.len();
331        pool.total = pool.total.saturating_sub(removed);
332    }
333}
334
335/// acquire_connect 内部决策
336enum Action {
337    GotConn(Box<Connect>),
338    Create,
339    Wait,
340}
341
// End-to-end tests for `Pools`, run against an in-process mock PostgreSQL server.
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Read as IoRead, Write as IoWrite};
    use std::net::TcpListener;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Frames one backend message: tag byte, big-endian length (payload + the
    /// 4 length bytes), then the payload.
    fn pg_msg(tag: u8, payload: &[u8]) -> Vec<u8> {
        let mut m = Vec::with_capacity(5 + payload.len());
        m.push(tag);
        m.extend(&((payload.len() as u32 + 4).to_be_bytes()));
        m.extend_from_slice(payload);
        m
    }

    /// Builds an authentication ('R') message with the given auth-type code
    /// (0 = AuthenticationOk, 3 = cleartext-password request in the PG protocol).
    fn pg_auth(auth_type: u32, extra: &[u8]) -> Vec<u8> {
        let mut body = Vec::new();
        body.extend(&auth_type.to_be_bytes());
        body.extend_from_slice(extra);
        pg_msg(b'R', &body)
    }

    /// Server reply after the password: AuthenticationOk, one ParameterStatus
    /// (server_version), BackendKeyData, and ReadyForQuery(idle).
    fn post_auth_ok() -> Vec<u8> {
        let mut v = Vec::new();
        v.extend(pg_auth(0, &[]));
        v.extend(pg_msg(b'S', b"server_version\x0015.0\x00"));
        let mut k = Vec::new();
        k.extend(&1u32.to_be_bytes());
        k.extend(&2u32.to_be_bytes());
        v.extend(pg_msg(b'K', &k));
        v.extend(pg_msg(b'Z', b"I"));
        v
    }

    /// Canned reply to any query: ParseComplete, BindComplete, a RowDescription
    /// with one int4 (oid 23) column named `c`, one DataRow "1",
    /// CommandComplete "SELECT 1", and ReadyForQuery(idle).
    fn simple_query_response() -> Vec<u8> {
        let mut r = Vec::new();
        r.extend(pg_msg(b'1', &[]));
        r.extend(pg_msg(b'2', &[]));
        let mut rd = Vec::new();
        rd.extend(&1u16.to_be_bytes());
        rd.extend(b"c\x00");
        rd.extend(&0u32.to_be_bytes());
        rd.extend(&1u16.to_be_bytes());
        rd.extend(&23u32.to_be_bytes());
        rd.extend(&4i16.to_be_bytes());
        rd.extend(&(-1i32).to_be_bytes());
        rd.extend(&0u16.to_be_bytes());
        r.extend(pg_msg(b'T', &rd));
        let mut dr = Vec::new();
        dr.extend(&1u16.to_be_bytes());
        dr.extend(&1u32.to_be_bytes());
        dr.push(b'1');
        r.extend(pg_msg(b'D', &dr));
        r.extend(pg_msg(b'C', b"SELECT 1\x00"));
        r.extend(pg_msg(b'Z', b"I"));
        r
    }

    /// Starts a mock server on an ephemeral localhost port and returns the port.
    ///
    /// The accept loop polls the non-blocking listener until `stop` is set.
    /// Each connection gets its own thread that:
    /// 1. reads the startup packet;
    /// 2. requests a cleartext password;
    /// 3. reads the password and sends the post-auth messages;
    /// 4. answers every later read with `simple_query_response()`, whatever
    ///    was sent.
    fn spawn_multi_server(stop: Arc<AtomicBool>) -> u16 {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        thread::spawn(move || {
            listener.set_nonblocking(true).unwrap();
            while !stop.load(Ordering::Relaxed) {
                match listener.accept() {
                    Ok((s, _)) => {
                        s.set_nonblocking(false).ok();
                        let stop2 = stop.clone();
                        thread::spawn(move || {
                            s.set_read_timeout(Some(Duration::from_secs(5))).ok();
                            let mut s = s;
                            let mut buf = [0u8; 4096];
                            if s.read(&mut buf).unwrap_or(0) == 0 {
                                return;
                            }
                            let _ = s.write_all(&pg_auth(3, &[]));
                            if s.read(&mut buf).unwrap_or(0) == 0 {
                                return;
                            }
                            let _ = s.write_all(&post_auth_ok());
                            while !stop2.load(Ordering::Relaxed) {
                                match s.read(&mut buf) {
                                    Ok(0) | Err(_) => break,
                                    Ok(_) => {
                                        let _ = s.write_all(&simple_query_response());
                                    }
                                }
                            }
                        });
                    }
                    Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                        thread::sleep(Duration::from_millis(5));
                    }
                    Err(_) => break,
                }
            }
        });
        thread::sleep(Duration::from_millis(50));
        port
    }

    /// Config pointing at the mock server on `port`.
    fn mock_config(port: u16) -> Config {
        Config {
            debug: false,
            hostname: "127.0.0.1".into(),
            hostport: port as i32,
            username: "u".into(),
            userpass: "p".into(),
            database: "d".into(),
            charset: "utf8".into(),
            pool_max: 5,
        }
    }

    /// Walks through the pool's main code paths in one sequential scenario;
    /// later steps depend on the pool state left by earlier ones.
    #[test]
    fn pools_all_paths() {
        let stop = Arc::new(AtomicBool::new(false));
        let port = spawn_multi_server(stop.clone());
        let cfg = mock_config(port);

        // Basic construction + counters.
        let pools = Pools::new(cfg.clone(), 10).unwrap();
        assert_eq!(pools.total_connections(), 2);
        assert_eq!(pools.idle_pool_size(), 2);
        assert_eq!(pools.borrowed_connections(), 0);

        // Borrow one.
        let conn1 = pools.get_connect().unwrap();
        assert_eq!(pools.idle_pool_size(), 1);
        assert!(pools.borrowed_connections() > 0);

        // Return it.
        let idle_before = pools.idle_pool_size();
        pools.release_conn(conn1);
        assert!(pools.idle_pool_size() > idle_before);

        // Borrow then drop: the connection is not returned to the pool, and
        // `total` is unchanged because dropping a bare Connect does not notify the pool.
        let conn2 = pools.get_connect().unwrap();
        drop(conn2);

        // Returning a broken connection → `total` decreases.
        let mut conn3 = pools.get_connect().unwrap();
        let total_before = pools.total_connections();
        conn3._close();
        pools.release_conn(conn3);
        assert!(pools.total_connections() <= total_before);

        // cleanup
        pools._cleanup_idle_connections();

        // ConnectionGuard
        {
            let mut guard = pools.get_guard().unwrap();
            let qr = guard.conn().query("SELECT 1");
            assert!(qr.is_ok());
        }
        assert!(pools.idle_pool_size() > 0);

        // Transaction connection.
        let pools2 = Pools::new(cfg.clone(), 10).unwrap();
        let txn = pools2.get_connect_for_transaction().unwrap();
        let total_before = pools2.total_connections();
        pools2.release_transaction_conn();
        assert_eq!(pools2.total_connections(), total_before - 1);
        drop(txn);

        // pool_max=1 → get_connect times out while the pool is full.
        let pools3 = Pools::new(cfg.clone(), 1).unwrap();
        let held = pools3.get_connect().unwrap();
        let result = pools3.get_connect();
        assert!(result.is_err());
        drop(held);

        // Bad config → 0 connections.
        let bad_cfg = mock_config(1);
        let pools4 = Pools::new(bad_cfg.clone(), 5).unwrap();
        assert_eq!(pools4.total_connections(), 0);

        // Bad config → get_connect fails.
        let pools5 = Pools::new(bad_cfg.clone(), 5).unwrap();
        let result = pools5.get_connect();
        assert!(result.is_err());

        // Bad config → get_connect_for_transaction fails.
        let pools6 = Pools::new(bad_cfg.clone(), 5).unwrap();
        let result = pools6.get_connect_for_transaction();
        assert!(result.is_err());

        // pool_max=1 initialization.
        let pools7 = Pools::new(cfg.clone(), 1).unwrap();
        assert_eq!(pools7.total_connections(), 1);

        stop.store(true, Ordering::Relaxed);
    }
}