Skip to main content

br_pgsql/
pools.rs

1use crate::config::Config;
2use crate::connect::Connect;
3use crate::error::PgsqlError;
4use log::{error, warn};
5use std::collections::VecDeque;
6use std::sync::{Arc, Mutex, MutexGuard, PoisonError};
7use std::thread;
8use std::time::Duration;
9
10pub static DB_POOL: std::sync::LazyLock<Mutex<VecDeque<Connect>>> =
11    std::sync::LazyLock::new(|| Mutex::new(VecDeque::new()));
12
13fn lock_pool<'a>() -> MutexGuard<'a, VecDeque<Connect>> {
14    DB_POOL.lock().unwrap_or_else(PoisonError::into_inner)
15}
16
/// Handle to the connection pool: the connection settings plus bookkeeping of
/// how many connections this handle considers open.
///
/// Clones share the same counter (it lives behind an `Arc<Mutex<_>>`). Idle
/// connections themselves are stored in the process-wide [`DB_POOL`] queue.
#[derive(Clone)]
pub struct Pools {
    /// Settings used whenever a new connection is opened.
    pub config: Config,
    /// Upper bound on counted connections; also caps the idle queue length
    /// when connections are released.
    max_pools: usize,
    /// Number of connections currently counted as open (idle + borrowed),
    /// shared between all clones of this handle.
    total_connections: Arc<Mutex<usize>>,
}
23
/// Locks a connection counter, tolerating poisoning.
///
/// If some thread panicked while holding the lock, the stored count is still
/// handed out instead of propagating the poison to every later caller.
fn lock_counter(counter: &Mutex<usize>) -> MutexGuard<'_, usize> {
    match counter.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
27
/// Scope guard around a checked-out [`Connect`]: when dropped it hands the
/// connection back through [`Pools::release_conn`].
pub struct ConnectionGuard {
    /// Pool handle the connection is released to.
    pool: Pools,
    /// The borrowed connection; set to `None` when it is handed back in `Drop`.
    conn: Option<Connect>,
}
32
33impl ConnectionGuard {
34    pub fn new(pool: Pools) -> Result<Self, PgsqlError> {
35        let conn = pool.clone().get_connect()?;
36        Ok(Self {
37            pool,
38            conn: Some(conn),
39        })
40    }
41
42    pub fn conn(&mut self) -> &mut Connect {
43        self.conn.as_mut().expect("connection already released")
44    }
45}
46
47impl Drop for ConnectionGuard {
48    fn drop(&mut self) {
49        if let Some(conn) = self.conn.take() {
50            self.pool.release_conn(conn);
51        }
52    }
53}
54
55impl Pools {
56    pub fn get_guard(&mut self) -> Result<ConnectionGuard, PgsqlError> {
57        ConnectionGuard::new(self.clone())
58    }
59
60    pub fn new(config: Config, size: usize) -> Result<Self, PgsqlError> {
61        let mut pool_guard = lock_pool();
62        let init_size = 2.min(size);
63        let mut created = 0;
64        for _ in 0..init_size {
65            match Connect::new(config.clone()) {
66                Ok(conn) => {
67                    pool_guard.push_back(conn);
68                    created += 1;
69                }
70                Err(e) => warn!("初始化连接失败: {e}"),
71            }
72        }
73
74        let pools = Self {
75            config,
76            max_pools: size,
77            total_connections: Arc::new(Mutex::new(created)),
78        };
79
80        Ok(pools)
81    }
82
83    pub fn get_connect(&mut self) -> Result<Connect, PgsqlError> {
84        let mut attempts = 0;
85
86        #[cfg(not(test))]
87        const RETRY_SLEEP: Duration = Duration::from_secs(1);
88        #[cfg(test)]
89        const RETRY_SLEEP: Duration = Duration::from_millis(1);
90
91        loop {
92            if attempts >= 20 {
93                return Err(PgsqlError::Pool("无法连接数据库,重试超时".into()));
94            }
95            let maybe_conn = {
96                let mut pool = lock_pool();
97                pool.pop_front()
98            };
99
100            if let Some(mut conn) = maybe_conn {
101                if conn.is_valid() {
102                    return Ok(conn);
103                } else {
104                    let mut counter = lock_counter(&self.total_connections);
105                    *counter = counter.saturating_sub(1);
106                    drop(counter);
107                    warn!(
108                        "连接失效,尝试重建,当前总连接数量: {}",
109                        self.total_connections()
110                    );
111                    match Connect::new(self.config.clone()) {
112                        Ok(new_conn) => {
113                            *lock_counter(&self.total_connections) += 1;
114                            return Ok(new_conn);
115                        }
116                        Err(e) => {
117                            error!("重建连接失败: {}", e);
118                            attempts += 1;
119                            thread::sleep(RETRY_SLEEP);
120                            continue;
121                        }
122                    }
123                }
124            } else if self.total_connections() < self.max_pools {
125                match Connect::new(self.config.clone()) {
126                    Ok(new_conn) => {
127                        *lock_counter(&self.total_connections) += 1;
128                        return Ok(new_conn);
129                    }
130                    Err(e) => {
131                        error!("创建新连接失败: {}", e);
132                        attempts += 1;
133                        thread::sleep(RETRY_SLEEP);
134                        continue;
135                    }
136                }
137            } else {
138                attempts += 1;
139                thread::sleep(Duration::from_millis(50));
140            }
141        }
142    }
143
144    /// 事务专用连接,不归还到池
145    pub fn get_connect_for_transaction(&mut self) -> Result<Connect, PgsqlError> {
146        let mut attempts = 0;
147
148        #[cfg(not(test))]
149        const RETRY_SLEEP: Duration = Duration::from_secs(1);
150        #[cfg(test)]
151        const RETRY_SLEEP: Duration = Duration::from_millis(1);
152
153        loop {
154            if attempts >= 20 {
155                return Err(PgsqlError::Pool("无法获取事务连接,重试超时".into()));
156            }
157
158            let maybe_conn = {
159                let mut pool = lock_pool();
160                pool.pop_front()
161            };
162
163            if let Some(mut conn) = maybe_conn {
164                if conn.is_valid() {
165                    return Ok(conn);
166                } else {
167                    let mut counter = lock_counter(&self.total_connections);
168                    *counter = counter.saturating_sub(1);
169                    drop(counter);
170                    warn!(
171                        "事务连接失效,尝试重建,当前总连接数量: {}",
172                        self.total_connections()
173                    );
174                }
175            }
176
177            match Connect::new(self.config.clone()) {
178                Ok(new_conn) => {
179                    *lock_counter(&self.total_connections) += 1;
180                    return Ok(new_conn);
181                }
182                Err(e) => {
183                    error!("创建事务连接失败: {}", e);
184                    attempts += 1;
185                    thread::sleep(RETRY_SLEEP);
186                    continue;
187                }
188            }
189        }
190    }
191
192    pub fn release_transaction_conn(&self) {
193        let mut counter = lock_counter(&self.total_connections);
194        *counter = counter.saturating_sub(1);
195    }
196
197    pub fn release_conn(&self, mut conn: Connect) {
198        if conn.is_valid() {
199            let mut pool = lock_pool();
200            if pool.len() < self.max_pools {
201                pool.push_back(conn);
202            } else {
203                let mut counter = lock_counter(&self.total_connections);
204                *counter = counter.saturating_sub(1);
205                warn!("连接池已满,丢弃连接");
206            }
207        } else {
208            let mut counter = lock_counter(&self.total_connections);
209            *counter = counter.saturating_sub(1);
210            warn!("释放时检测到坏连接,已丢弃");
211        }
212    }
213
214    pub fn idle_pool_size(&self) -> usize {
215        let pool = lock_pool();
216        pool.len()
217    }
218
219    pub fn total_connections(&self) -> usize {
220        *lock_counter(&self.total_connections)
221    }
222
223    pub fn borrowed_connections(&self) -> usize {
224        self.total_connections()
225            .saturating_sub(self.idle_pool_size())
226    }
227
228    #[allow(dead_code)]
229    pub fn _cleanup_idle_connections(&self) {
230        let mut pool = lock_pool();
231        println!("当前连接池中的连接数量(清理前): {}", pool.len());
232
233        pool.retain(|conn| {
234            let is_ok = conn.stream.peer_addr().is_ok();
235            if !is_ok {
236                println!("检测到无效连接,已移除");
237            }
238            is_ok
239        });
240
241        println!("当前连接池中的连接数量(清理后): {}", pool.len());
242    }
243}
244
// Tests drive `Pools` against an in-process fake PostgreSQL server built from
// hand-encoded backend protocol messages.
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Read as IoRead, Write as IoWrite};
    use std::net::TcpListener;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Encodes one backend message: tag byte, big-endian length (payload + 4), payload.
    fn pg_msg(tag: u8, payload: &[u8]) -> Vec<u8> {
        let mut m = Vec::with_capacity(5 + payload.len());
        m.push(tag);
        m.extend(&((payload.len() as u32 + 4).to_be_bytes()));
        m.extend_from_slice(payload);
        m
    }

    /// Encodes an authentication ('R') message with the given auth type and extra bytes.
    fn pg_auth(auth_type: u32, extra: &[u8]) -> Vec<u8> {
        let mut body = Vec::new();
        body.extend(&auth_type.to_be_bytes());
        body.extend_from_slice(extra);
        pg_msg(b'R', &body)
    }

    /// Messages sent after the password: auth OK (0), a `server_version` parameter
    /// ('S'), backend key data ('K') and ready-for-query in idle state ('Z', "I").
    fn post_auth_ok() -> Vec<u8> {
        let mut v = Vec::new();
        v.extend(pg_auth(0, &[]));
        v.extend(pg_msg(b'S', b"server_version\x0015.0\x00"));
        let mut k = Vec::new();
        k.extend(&1u32.to_be_bytes());
        k.extend(&2u32.to_be_bytes());
        v.extend(pg_msg(b'K', &k));
        v.extend(pg_msg(b'Z', b"I"));
        v
    }

    /// Canned reply to any query: parse/bind complete ('1', '2'), a one-column row
    /// description ('T', column "c", type OID 23, size 4), one data row ('D') holding
    /// "1", command complete "SELECT 1" ('C') and ready-for-query ('Z').
    fn simple_query_response() -> Vec<u8> {
        let mut r = Vec::new();
        r.extend(pg_msg(b'1', &[]));
        r.extend(pg_msg(b'2', &[]));
        let mut rd = Vec::new();
        rd.extend(&1u16.to_be_bytes());
        rd.extend(b"c\x00");
        rd.extend(&0u32.to_be_bytes());
        rd.extend(&1u16.to_be_bytes());
        rd.extend(&23u32.to_be_bytes());
        rd.extend(&4i16.to_be_bytes());
        rd.extend(&(-1i32).to_be_bytes());
        rd.extend(&0u16.to_be_bytes());
        r.extend(pg_msg(b'T', &rd));
        let mut dr = Vec::new();
        dr.extend(&1u16.to_be_bytes());
        dr.extend(&1u32.to_be_bytes());
        dr.push(b'1');
        r.extend(pg_msg(b'D', &dr));
        r.extend(pg_msg(b'C', b"SELECT 1\x00"));
        r.extend(pg_msg(b'Z', b"I"));
        r
    }

    /// Starts a fake server on an ephemeral localhost port and returns the port.
    ///
    /// The accept loop polls a non-blocking listener until `stop` is set; each
    /// accepted client gets its own thread that reads the startup packet, asks for
    /// a password (auth type 3), reads it, sends `post_auth_ok`, then answers every
    /// subsequent read with `simple_query_response` until EOF, error or `stop`.
    fn spawn_multi_server(stop: Arc<AtomicBool>) -> u16 {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        thread::spawn(move || {
            listener.set_nonblocking(true).unwrap();
            while !stop.load(Ordering::Relaxed) {
                match listener.accept() {
                    Ok((s, _)) => {
                        s.set_nonblocking(false).ok();
                        let stop2 = stop.clone();
                        thread::spawn(move || {
                            s.set_read_timeout(Some(Duration::from_secs(5))).ok();
                            let mut s = s;
                            let mut buf = [0u8; 4096];
                            // Startup message.
                            if s.read(&mut buf).unwrap_or(0) == 0 {
                                return;
                            }
                            let _ = s.write_all(&pg_auth(3, &[]));
                            // Password message.
                            if s.read(&mut buf).unwrap_or(0) == 0 {
                                return;
                            }
                            let _ = s.write_all(&post_auth_ok());
                            // Every further read is treated as one query.
                            while !stop2.load(Ordering::Relaxed) {
                                match s.read(&mut buf) {
                                    Ok(0) | Err(_) => break,
                                    Ok(_) => {
                                        let _ = s.write_all(&simple_query_response());
                                    }
                                }
                            }
                        });
                    }
                    Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                        thread::sleep(Duration::from_millis(5));
                    }
                    Err(_) => break,
                }
            }
        });
        // Give the accept thread a moment to start polling.
        thread::sleep(Duration::from_millis(50));
        port
    }

    /// Config pointing at `127.0.0.1:port` with dummy credentials.
    fn mock_config(port: u16) -> Config {
        Config {
            debug: false,
            hostname: "127.0.0.1".into(),
            hostport: port as i32,
            username: "u".into(),
            userpass: "p".into(),
            database: "d".into(),
            charset: "utf8".into(),
            pool_max: 5,
        }
    }

    // All scenarios live in one test because they share the global `DB_POOL`,
    // which is cleared between scenarios; separate #[test] fns would run in
    // parallel and race on it.
    #[test]
    fn pools_all_paths() {
        lock_pool().clear();
        let stop = Arc::new(AtomicBool::new(false));
        let port = spawn_multi_server(stop.clone());
        let cfg = mock_config(port);

        // Construction pre-opens two connections.
        let mut pools = Pools::new(cfg.clone(), 10).unwrap();
        assert_eq!(pools.total_connections(), 2);
        assert_eq!(pools.idle_pool_size(), 2);
        assert_eq!(pools.borrowed_connections(), 0);

        // Checkout from the idle queue, then release back into it.
        let conn1 = pools.get_connect().unwrap();
        assert_eq!(pools.idle_pool_size(), 1);
        assert!(pools.borrowed_connections() > 0);

        let idle_before = pools.idle_pool_size();
        pools.release_conn(conn1);
        assert!(pools.idle_pool_size() > idle_before);

        // Empty idle queue: a new connection is opened.
        lock_pool().clear();
        let conn2 = pools.get_connect().unwrap();
        drop(conn2);

        // Releasing a closed connection discards it.
        let mut conn3 = pools.get_connect().unwrap();
        let total_before = pools.total_connections();
        conn3._close();
        pools.release_conn(conn3);
        assert!(pools.total_connections() <= total_before);

        // An invalid idle connection is rebuilt on checkout.
        let mut bad = pools.get_connect().unwrap();
        bad._close();
        lock_pool().push_back(bad);
        let rebuilt = pools.get_connect().unwrap();
        drop(rebuilt);

        // borrowed = total - idle.
        lock_pool().clear();
        let p2 = Pools::new(cfg.clone(), 10).unwrap();
        let idle = p2.idle_pool_size();
        let total = p2.total_connections();
        assert_eq!(p2.borrowed_connections(), total - idle);

        // A guard returns its connection to the idle queue on drop.
        lock_pool().clear();
        let mut p3 = Pools::new(cfg.clone(), 10).unwrap();
        {
            let mut guard = p3.get_guard().unwrap();
            let qr = guard.conn().query("SELECT 1");
            assert!(qr.is_ok());
        }
        assert!(p3.idle_pool_size() > 0);

        // Releasing a transaction connection decrements the counter by one.
        lock_pool().clear();
        let mut p4 = Pools::new(cfg.clone(), 10).unwrap();
        let txn = p4.get_connect_for_transaction().unwrap();
        let total_before = p4.total_connections();
        p4.release_transaction_conn();
        assert_eq!(p4.total_connections(), total_before - 1);
        drop(txn);

        // Transaction checkout rebuilds an invalid idle connection.
        lock_pool().clear();
        let mut p5 = Pools::new(cfg.clone(), 10).unwrap();
        let mut bad_txn = p5.get_connect().unwrap();
        bad_txn._close();
        lock_pool().clear();
        lock_pool().push_back(bad_txn);
        let txn2 = p5.get_connect_for_transaction();
        assert!(txn2.is_ok());
        drop(txn2);

        // Releasing into a full idle queue discards the connection.
        lock_pool().clear();
        let mut p6 = Pools::new(cfg.clone(), 3).unwrap();
        let c1 = p6.get_connect().unwrap();
        let c2 = p6.get_connect().unwrap();
        let c3 = p6.get_connect().unwrap();
        p6.release_conn(c1);
        p6.release_conn(c2);
        let total_before = p6.total_connections();
        p6.release_conn(c3);
        assert!(p6.total_connections() <= total_before);

        // Cleanup over healthy idle connections.
        lock_pool().clear();
        let p7 = Pools::new(cfg.clone(), 10).unwrap();
        p7._cleanup_idle_connections();

        // Initial connection count is capped by `size`.
        lock_pool().clear();
        let p8 = Pools::new(cfg.clone(), 1).unwrap();
        assert_eq!(p8.total_connections(), 1);

        // An unreachable server (port 1) still yields a pool, with zero connections.
        lock_pool().clear();
        let bad_cfg = mock_config(1);
        let p9 = Pools::new(bad_cfg, 5);
        assert!(p9.is_ok());
        assert_eq!(p9.unwrap().total_connections(), 0);

        // Exhausted pool: waits MAX_ATTEMPTS rounds of 50 ms, then errors.
        lock_pool().clear();
        let mut p10 = Pools::new(cfg.clone(), 1).unwrap();
        let held = p10.get_connect().unwrap();
        let result = p10.get_connect();
        assert!(result.is_err());
        drop(held);

        // Releasing an extra connection into a full queue does not grow the count.
        lock_pool().clear();
        let p11 = Pools::new(cfg.clone(), 1).unwrap();
        assert_eq!(p11.idle_pool_size(), 1);
        let extra_conn = Connect::new(cfg.clone()).unwrap();
        let total_before = p11.total_connections();
        p11.release_conn(extra_conn);
        assert!(p11.total_connections() <= total_before);

        // Cleanup with a closed connection in the idle queue.
        lock_pool().clear();
        let mut p12 = Pools::new(cfg.clone(), 3).unwrap();
        let mut bad_conn = p12.get_connect().unwrap();
        bad_conn._close();
        lock_pool().push_back(bad_conn);
        p12._cleanup_idle_connections();

        // Invalid idle connection and an unreachable server: rebuild fails, then errors.
        lock_pool().clear();
        let bad_cfg = mock_config(1);
        let mut p13 = Pools::new(bad_cfg.clone(), 5).unwrap();
        let good_conn = Connect::new(cfg.clone()).unwrap();
        let mut bad_c = Connect::new(cfg.clone()).unwrap();
        bad_c._close();
        lock_pool().clear();
        lock_pool().push_back(bad_c);
        *lock_counter(&p13.total_connections) = 1;
        let result = p13.get_connect();
        assert!(result.is_err());
        drop(good_conn);

        // Unreachable server with free slots: opening keeps failing, then errors.
        lock_pool().clear();
        let mut p14 = Pools::new(bad_cfg.clone(), 5).unwrap();
        *lock_counter(&p14.total_connections) = 0;
        let result = p14.get_connect();
        assert!(result.is_err());

        // Same for the transaction path.
        lock_pool().clear();
        let mut p15 = Pools::new(bad_cfg.clone(), 5).unwrap();
        *lock_counter(&p15.total_connections) = 0;
        let result = p15.get_connect_for_transaction();
        assert!(result.is_err());

        // Transaction path with an invalid idle connection and an unreachable server.
        lock_pool().clear();
        let mut p16 = Pools::new(bad_cfg.clone(), 5).unwrap();
        let mut bad_txn = Connect::new(cfg.clone()).unwrap();
        bad_txn._close();
        lock_pool().push_back(bad_txn);
        *lock_counter(&p16.total_connections) = 1;
        let result = p16.get_connect_for_transaction();
        assert!(result.is_err());

        // Shut down the fake server and leave the global queue empty.
        stop.store(true, Ordering::Relaxed);
        lock_pool().clear();
    }
}