Skip to main content

br_pgsql/
pools.rs

1use crate::config::Config;
2use crate::connect::Connect;
3use crate::error::PgsqlError;
4use log::{error, warn};
5use std::collections::VecDeque;
6use std::sync::{Arc, Condvar, Mutex, MutexGuard, PoisonError};
7use std::thread;
8use std::time::Duration;
9
/// Longest time, in seconds, a connection may sit idle (5 minutes). Past this
/// age it is dropped instead of being returned to or kept in the pool.
const MAX_IDLE_SECS: u64 = 5 * 60;
12struct PoolInner {
13    idle: VecDeque<Connect>,
14    total: usize,
15    max: usize,
16}
17
18/// 预占位守卫:Create/GotConn-rebuild 路径中,如果 Connect::new() panic,
19/// drop 时自动归还 total 计数并唤醒等待者,防止 total 永久膨胀。
20struct SlotGuard<'a> {
21    mutex: &'a Mutex<PoolInner>,
22    condvar: &'a Condvar,
23    active: bool,
24}
25
26impl<'a> SlotGuard<'a> {
27    fn new(mutex: &'a Mutex<PoolInner>, condvar: &'a Condvar) -> Self {
28        Self {
29            mutex,
30            condvar,
31            active: true,
32        }
33    }
34
35    /// 标记成功,不再需要回滚
36    fn disarm(&mut self) {
37        self.active = false;
38    }
39}
40
41impl Drop for SlotGuard<'_> {
42    fn drop(&mut self) {
43        if self.active {
44            let mut pool = lock_inner(self.mutex);
45            pool.total = pool.total.saturating_sub(1);
46            drop(pool);
47            self.condvar.notify_one();
48        }
49    }
50}
51#[derive(Clone)]
52pub struct Pools {
53    pub config: Config,
54    inner: Arc<(Mutex<PoolInner>, Condvar)>,
55}
/// Locks `mutex`, recovering the guard even if a previous holder panicked.
///
/// Poisoning is deliberately ignored so that one panicking thread does not
/// make the whole pool unusable for every other caller.
///
/// The function is generic over the protected type. Existing calls with
/// `&Mutex<PoolInner>` infer `T = PoolInner` and behave exactly as before.
fn lock_inner<T>(mutex: &Mutex<T>) -> MutexGuard<'_, T> {
    mutex.lock().unwrap_or_else(PoisonError::into_inner)
}
59pub struct ConnectionGuard {
60    pool: Pools,
61    conn: Option<Connect>,
62}
63impl ConnectionGuard {
64    pub fn new(pool: Pools) -> Result<Self, PgsqlError> {
65        let conn = pool.get_connect()?;
66        Ok(Self {
67            pool,
68            conn: Some(conn),
69        })
70    }
71    pub fn conn(&mut self) -> &mut Connect {
72        self.conn.as_mut().expect("connection already released")
73    }
74    /// 丢弃连接(不归还到池),用于连接已断开的场景
75    pub fn discard(&mut self) {
76        if let Some(_conn) = self.conn.take() {
77            let (ref mutex, ref condvar) = *self.pool.inner;
78            let mut pool = lock_inner(mutex);
79            pool.total = pool.total.saturating_sub(1);
80            drop(pool);
81            condvar.notify_one();
82        }
83    }
84}
85impl Drop for ConnectionGuard {
86    fn drop(&mut self) {
87        if let Some(conn) = self.conn.take() {
88            self.pool.release_conn(conn);
89        }
90    }
91}
92impl Pools {
93    pub fn get_guard(&self) -> Result<ConnectionGuard, PgsqlError> {
94        ConnectionGuard::new(self.clone())
95    }
96    pub fn new(config: Config, size: usize) -> Result<Self, PgsqlError> {
97        let init_size = 2.min(size);
98        let mut idle = VecDeque::with_capacity(size);
99        let mut created = 0;
100        for _ in 0..init_size {
101            match Connect::new(config.clone()) {
102                Ok(conn) => {
103                    idle.push_back(conn);
104                    created += 1;
105                }
106                Err(e) => warn!("初始化连接失败: {e}"),
107            }
108        }
109        let inner = PoolInner {
110            idle,
111            total: created,
112            max: size,
113        };
114
115        Ok(Self {
116            config,
117            inner: Arc::new((Mutex::new(inner), Condvar::new())),
118        })
119    }
120
121    /// 内部统一获取连接逻辑
122    fn acquire_connect(&self, for_transaction: bool) -> Result<Connect, PgsqlError> {
123        let mut attempts = 0;
124        let (ref mutex, ref condvar) = *self.inner;
125        let label = if for_transaction { "事务" } else { "" };
126        #[cfg(not(test))]
127        const BASE_SLEEP_MS: u64 = 1000;
128        #[cfg(test)]
129        const BASE_SLEEP_MS: u64 = 1;
130
131        #[cfg(not(test))]
132        const MAX_SLEEP_MS: u64 = 8000;
133        #[cfg(test)]
134        const MAX_SLEEP_MS: u64 = 5;
135
136        #[cfg(not(test))]
137        const WAIT_TIMEOUT: Duration = Duration::from_secs(3);
138        #[cfg(test)]
139        const WAIT_TIMEOUT: Duration = Duration::from_millis(5);
140
141        let timeout_msg = if for_transaction {
142            "无法获取事务连接,重试超时"
143        } else {
144            "无法连接数据库,重试超时"
145        };
146
147        loop {
148            if attempts >= 20 {
149                return Err(PgsqlError::Pool(timeout_msg.into()));
150            }
151
152            let action = {
153                let mut pool = lock_inner(mutex);
154                if let Some(conn) = pool.idle.pop_front() {
155                    Action::GotConn(Box::new(conn))
156                } else if pool.total < pool.max {
157                    pool.total += 1; // 预占位
158                    Action::Create
159                } else {
160                    Action::Wait
161                }
162            };
163
164            match action {
165                Action::GotConn(mut conn) => {
166                    // 锁外做健康检查(is_valid 含懒 SELECT 1)
167                    if conn.is_valid() {
168                        conn.touch();
169                        return Ok(*conn);
170                    }
171                    // 连接失效,减计数,尝试重建
172                    {
173                        let mut pool = lock_inner(mutex);
174                        pool.total = pool.total.saturating_sub(1);
175                    }
176                    warn!(
177                        "{}连接失效,尝试重建,当前总连接数量: {}",
178                        label,
179                        self.total_connections()
180                    );
181                    // SlotGuard 保护重建路径的 total 预占位
182                    let mut guard = SlotGuard::new(mutex, condvar);
183                    {
184                        let mut pool = lock_inner(mutex);
185                        pool.total += 1;
186                    }
187                    match Connect::new(self.config.clone()) {
188                        Ok(new_conn) => {
189                            guard.disarm();
190                            return Ok(new_conn);
191                        }
192                        Err(e) => {
193                            // guard drop 会自动 total -= 1 + notify
194                            drop(guard);
195                            let sleep_ms = BASE_SLEEP_MS
196                                .saturating_mul(1u64 << attempts.min(3))
197                                .min(MAX_SLEEP_MS);
198                            attempts += 1;
199                            error!("{}重建连接失败({}ms后重试): {}", label, sleep_ms, e);
200                            thread::sleep(Duration::from_millis(sleep_ms));
201                        }
202                    }
203                }
204                Action::Create => {
205                    // SlotGuard 保护 Create 路径的 total 预占位
206                    let mut guard = SlotGuard::new(mutex, condvar);
207                    match Connect::new(self.config.clone()) {
208                        Ok(new_conn) => {
209                            guard.disarm();
210                            return Ok(new_conn);
211                        }
212                        Err(e) => {
213                            // guard drop 会自动 total -= 1 + notify
214                            drop(guard);
215                            let sleep_ms = BASE_SLEEP_MS
216                                .saturating_mul(1u64 << attempts.min(3))
217                                .min(MAX_SLEEP_MS);
218                            attempts += 1;
219                            error!("创建{}连接失败({}ms后重试): {}", label, sleep_ms, e);
220                            thread::sleep(Duration::from_millis(sleep_ms));
221                        }
222                    }
223                }
224                Action::Wait => {
225                    let pool = lock_inner(mutex);
226                    let (_pool, timeout) = condvar
227                        .wait_timeout(pool, WAIT_TIMEOUT)
228                        .unwrap_or_else(PoisonError::into_inner);
229                    drop(_pool);
230                    if timeout.timed_out() {
231                        attempts += 1;
232                    }
233                }
234            }
235        }
236    }
237    pub fn get_connect(&self) -> Result<Connect, PgsqlError> {
238        self.acquire_connect(false)
239    }
240    /// 事务专用连接,不归还到池
241    pub fn get_connect_for_transaction(&self) -> Result<Connect, PgsqlError> {
242        self.acquire_connect(true)
243    }
244    pub fn release_transaction_conn(&self) {
245        let (ref mutex, ref condvar) = *self.inner;
246        let mut pool = lock_inner(mutex);
247        pool.total = pool.total.saturating_sub(1);
248        drop(pool);
249        condvar.notify_one();
250    }
251    pub fn release_conn(&self, conn: Connect) {
252        let (ref mutex, ref condvar) = *self.inner;
253        if !conn.peer_valid() {
254            let mut pool = lock_inner(mutex);
255            pool.total = pool.total.saturating_sub(1);
256            drop(pool);
257            condvar.notify_one();
258            warn!("释放时检测到坏连接,已丢弃");
259            return;
260        }
261        if conn.idle_elapsed().as_secs() > MAX_IDLE_SECS {
262            let mut pool = lock_inner(mutex);
263            pool.total = pool.total.saturating_sub(1);
264            drop(pool);
265            condvar.notify_one();
266            log::debug!("连接空闲超过{}秒,已丢弃", MAX_IDLE_SECS);
267            return;
268        }
269        let mut pool = lock_inner(mutex);
270        if pool.idle.len() < pool.max {
271            pool.idle.push_back(conn);
272        } else {
273            pool.total = pool.total.saturating_sub(1);
274            warn!("连接池已满,丢弃连接");
275        }
276        drop(pool);
277        condvar.notify_one();
278    }
279    pub fn idle_pool_size(&self) -> usize {
280        let (ref mutex, _) = *self.inner;
281        let pool = lock_inner(mutex);
282        pool.idle.len()
283    }
284    pub fn total_connections(&self) -> usize {
285        let (ref mutex, _) = *self.inner;
286        let pool = lock_inner(mutex);
287        pool.total
288    }
289    pub fn borrowed_connections(&self) -> usize {
290        let (ref mutex, _) = *self.inner;
291        let pool = lock_inner(mutex);
292        pool.total.saturating_sub(pool.idle.len())
293    }
294    #[allow(dead_code)]
295    pub fn _cleanup_idle_connections(&self) {
296        let (ref mutex, _) = *self.inner;
297        let mut pool = lock_inner(mutex);
298        log::debug!("当前连接池中的连接数量(清理前): {}", pool.idle.len());
299        let before = pool.idle.len();
300        pool.idle.retain(|conn| {
301            let peer_ok = conn.peer_valid();
302            let idle_ok = conn.idle_elapsed().as_secs() <= MAX_IDLE_SECS;
303            if !peer_ok {
304                log::debug!("检测到无效连接,已移除");
305            } else if !idle_ok {
306                log::debug!("检测到空闲超时连接,已移除");
307            }
308            peer_ok && idle_ok
309        });
310        let removed = before - pool.idle.len();
311        pool.total = pool.total.saturating_sub(removed);
312    }
313}
314
315/// acquire_connect 内部决策
316enum Action {
317    GotConn(Box<Connect>),
318    Create,
319    Wait,
320}
321
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Read as IoRead, Write as IoWrite};
    use std::net::TcpListener;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Frames one backend message: a 1-byte tag, a big-endian length that
    /// counts itself but not the tag, then the payload.
    fn pg_msg(tag: u8, payload: &[u8]) -> Vec<u8> {
        let mut m = Vec::with_capacity(5 + payload.len());
        m.push(tag);
        m.extend(&((payload.len() as u32 + 4).to_be_bytes()));
        m.extend_from_slice(payload);
        m
    }

    /// Builds an authentication ('R') message with the given auth type code.
    fn pg_auth(auth_type: u32, extra: &[u8]) -> Vec<u8> {
        let mut body = Vec::new();
        body.extend(&auth_type.to_be_bytes());
        body.extend_from_slice(extra);
        pg_msg(b'R', &body)
    }

    /// Messages sent after a successful login: AuthenticationOk, one
    /// ParameterStatus, BackendKeyData, and ReadyForQuery ('I', idle).
    fn post_auth_ok() -> Vec<u8> {
        let mut v = Vec::new();
        v.extend(pg_auth(0, &[]));
        v.extend(pg_msg(b'S', b"server_version\x0015.0\x00"));
        let mut k = Vec::new();
        k.extend(&1u32.to_be_bytes());
        k.extend(&2u32.to_be_bytes());
        v.extend(pg_msg(b'K', &k));
        v.extend(pg_msg(b'Z', b"I"));
        v
    }

    /// Canned reply to any query. It contains, in order:
    /// - ParseComplete and BindComplete,
    /// - a RowDescription for one column `c` (type OID 23),
    /// - a single DataRow `"1"`,
    /// - CommandComplete `SELECT 1`, and ReadyForQuery.
    fn simple_query_response() -> Vec<u8> {
        let mut r = Vec::new();
        r.extend(pg_msg(b'1', &[]));
        r.extend(pg_msg(b'2', &[]));
        let mut rd = Vec::new();
        rd.extend(&1u16.to_be_bytes());
        rd.extend(b"c\x00");
        rd.extend(&0u32.to_be_bytes());
        rd.extend(&1u16.to_be_bytes());
        rd.extend(&23u32.to_be_bytes());
        rd.extend(&4i16.to_be_bytes());
        rd.extend(&(-1i32).to_be_bytes());
        rd.extend(&0u16.to_be_bytes());
        r.extend(pg_msg(b'T', &rd));
        let mut dr = Vec::new();
        dr.extend(&1u16.to_be_bytes());
        dr.extend(&1u32.to_be_bytes());
        dr.push(b'1');
        r.extend(pg_msg(b'D', &dr));
        r.extend(pg_msg(b'C', b"SELECT 1\x00"));
        r.extend(pg_msg(b'Z', b"I"));
        r
    }

    /// Starts a mock server on an ephemeral local port and returns the port.
    ///
    /// Each client is handled on its own thread. The server requests a
    /// cleartext password (auth type 3), accepts whatever is sent, and then
    /// answers every further read with `simple_query_response()`. New clients
    /// stop being accepted once `stop` is set.
    fn spawn_multi_server(stop: Arc<AtomicBool>) -> u16 {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        thread::spawn(move || {
            listener.set_nonblocking(true).unwrap();
            while !stop.load(Ordering::Relaxed) {
                match listener.accept() {
                    Ok((s, _)) => {
                        s.set_nonblocking(false).ok();
                        let stop2 = stop.clone();
                        thread::spawn(move || {
                            s.set_read_timeout(Some(Duration::from_secs(5))).ok();
                            let mut s = s;
                            let mut buf = [0u8; 4096];
                            if s.read(&mut buf).unwrap_or(0) == 0 {
                                return;
                            }
                            let _ = s.write_all(&pg_auth(3, &[]));
                            if s.read(&mut buf).unwrap_or(0) == 0 {
                                return;
                            }
                            let _ = s.write_all(&post_auth_ok());
                            while !stop2.load(Ordering::Relaxed) {
                                match s.read(&mut buf) {
                                    Ok(0) | Err(_) => break,
                                    Ok(_) => {
                                        let _ = s.write_all(&simple_query_response());
                                    }
                                }
                            }
                        });
                    }
                    Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                        thread::sleep(Duration::from_millis(5));
                    }
                    Err(_) => break,
                }
            }
        });
        thread::sleep(Duration::from_millis(50));
        port
    }

    /// Configuration pointing at the mock server on `port`.
    fn mock_config(port: u16) -> Config {
        Config {
            debug: false,
            hostname: "127.0.0.1".into(),
            hostport: port as i32,
            username: "u".into(),
            userpass: "p".into(),
            database: "d".into(),
            charset: "utf8".into(),
            pool_max: 5,
        }
    }

    #[test]
    fn pools_all_paths() {
        let stop = Arc::new(AtomicBool::new(false));
        let port = spawn_multi_server(stop.clone());
        let cfg = mock_config(port);

        // Construction eagerly opens two connections.
        let pools = Pools::new(cfg.clone(), 10).unwrap();
        assert_eq!(pools.total_connections(), 2);
        assert_eq!(pools.idle_pool_size(), 2);
        assert_eq!(pools.borrowed_connections(), 0);

        // Borrow one.
        let conn1 = pools.get_connect().unwrap();
        assert_eq!(pools.idle_pool_size(), 1);
        assert!(pools.borrowed_connections() > 0);

        // Return it.
        let idle_before = pools.idle_pool_size();
        pools.release_conn(conn1);
        assert!(pools.idle_pool_size() > idle_before);

        // Dropping a raw Connect bypasses the pool entirely. Its slot stays
        // counted in `total` (it is leaked) because nothing notifies the pool.
        let conn2 = pools.get_connect().unwrap();
        let total_with_conn2 = pools.total_connections();
        drop(conn2);
        assert_eq!(pools.total_connections(), total_with_conn2);

        // Releasing a closed connection must not increase `total`.
        let mut conn3 = pools.get_connect().unwrap();
        let total_before = pools.total_connections();
        conn3._close();
        pools.release_conn(conn3);
        assert!(pools.total_connections() <= total_before);

        // Idle cleanup.
        pools._cleanup_idle_connections();

        // ConnectionGuard returns its connection to the pool on drop.
        {
            let mut guard = pools.get_guard().unwrap();
            let qr = guard.conn().query("SELECT 1");
            assert!(qr.is_ok());
        }
        assert!(pools.idle_pool_size() > 0);

        // Transaction connections free their slot via release_transaction_conn.
        let pools2 = Pools::new(cfg.clone(), 10).unwrap();
        let txn = pools2.get_connect_for_transaction().unwrap();
        let total_before = pools2.total_connections();
        pools2.release_transaction_conn();
        assert_eq!(pools2.total_connections(), total_before - 1);
        drop(txn);

        // pool_max = 1: a second borrow times out while the first is held.
        let pools3 = Pools::new(cfg.clone(), 1).unwrap();
        let held = pools3.get_connect().unwrap();
        let result = pools3.get_connect();
        assert!(result.is_err());
        drop(held);

        // Unreachable server: construction succeeds with zero connections.
        let bad_cfg = mock_config(1);
        let pools4 = Pools::new(bad_cfg.clone(), 5).unwrap();
        assert_eq!(pools4.total_connections(), 0);

        // Unreachable server: get_connect fails.
        let pools5 = Pools::new(bad_cfg.clone(), 5).unwrap();
        let result = pools5.get_connect();
        assert!(result.is_err());

        // Unreachable server: get_connect_for_transaction fails.
        let pools6 = Pools::new(bad_cfg.clone(), 5).unwrap();
        let result = pools6.get_connect_for_transaction();
        assert!(result.is_err());

        // pool_max = 1: only one connection is opened eagerly.
        let pools7 = Pools::new(cfg.clone(), 1).unwrap();
        assert_eq!(pools7.total_connections(), 1);

        stop.store(true, Ordering::Relaxed);
    }

    /// Several threads compete for a two-connection pool. `total` must never
    /// exceed `max`, and once every guard is dropped nothing may remain
    /// borrowed (no leaked slots).
    #[test]
    fn pools_respect_max_under_contention() {
        let stop = Arc::new(AtomicBool::new(false));
        let port = spawn_multi_server(stop.clone());
        let pools = Pools::new(mock_config(port), 2).unwrap();

        let workers: Vec<_> = (0..4)
            .map(|_| {
                let pools = pools.clone();
                thread::spawn(move || {
                    for _ in 0..5 {
                        // Acquisition may time out under the short test
                        // timeouts; only the bookkeeping invariants matter here.
                        if let Ok(mut guard) = pools.get_guard() {
                            assert!(pools.total_connections() <= 2);
                            let _ = guard.conn().query("SELECT 1");
                        }
                    }
                })
            })
            .collect();
        for worker in workers {
            worker.join().unwrap();
        }

        assert!(pools.total_connections() <= 2);
        assert_eq!(pools.borrowed_connections(), 0);
        stop.store(true, Ordering::Relaxed);
    }
}