1use crate::config::Config;
2use crate::connect::Connect;
3use crate::error::PgsqlError;
4use log::{error, warn};
5use std::collections::VecDeque;
6use std::sync::{Arc, Mutex, MutexGuard, PoisonError};
7use std::thread;
8use std::time::Duration;
9
10pub static DB_POOL: std::sync::LazyLock<Mutex<VecDeque<Connect>>> =
11 std::sync::LazyLock::new(|| Mutex::new(VecDeque::new()));
12
13fn lock_pool<'a>() -> MutexGuard<'a, VecDeque<Connect>> {
14 DB_POOL.lock().unwrap_or_else(PoisonError::into_inner)
15}
16
/// Handle to the connection pool.
///
/// Idle connections live in the process-wide [`DB_POOL`]. This handle carries the
/// connection settings, the size limit and a counter of the connections it
/// accounts for (idle plus borrowed). Clones share the counter because it sits
/// behind an `Arc`.
///
/// NOTE(review): separately constructed `Pools` values get independent counters
/// but still share the single `DB_POOL` queue, so their totals can disagree with
/// what is actually idle.
#[derive(Clone)]
pub struct Pools {
    /// Settings used whenever a new connection is opened.
    pub config: Config,
    /// Cap on connections `get_connect` may open, and on idle connections
    /// `release_conn` keeps in the queue.
    max_pools: usize,
    /// Connections currently accounted to this pool (idle + borrowed), shared by clones.
    total_connections: Arc<Mutex<usize>>,
}
23
/// Locks a connection counter.
///
/// If a previous holder panicked, the guard is taken back out of the poison
/// error. A plain count stays meaningful after a panic.
fn lock_counter(counter: &Mutex<usize>) -> MutexGuard<'_, usize> {
    match counter.lock() {
        Ok(guard) => guard,
        Err(poisoned) => PoisonError::into_inner(poisoned),
    }
}
27
/// Scoped connection borrowed from a [`Pools`].
///
/// When the guard is dropped, the connection is handed back through
/// [`Pools::release_conn`].
pub struct ConnectionGuard {
    /// Pool handle the connection is returned to.
    pool: Pools,
    /// The borrowed connection. Set in `new` and taken only when the guard is dropped.
    conn: Option<Connect>,
}
32
33impl ConnectionGuard {
34 pub fn new(pool: Pools) -> Result<Self, PgsqlError> {
35 let conn = pool.clone().get_connect()?;
36 Ok(Self {
37 pool,
38 conn: Some(conn),
39 })
40 }
41
42 pub fn conn(&mut self) -> &mut Connect {
43 self.conn.as_mut().expect("connection already released")
44 }
45}
46
47impl Drop for ConnectionGuard {
48 fn drop(&mut self) {
49 if let Some(conn) = self.conn.take() {
50 self.pool.release_conn(conn);
51 }
52 }
53}
54
55impl Pools {
56 pub fn get_guard(&mut self) -> Result<ConnectionGuard, PgsqlError> {
57 ConnectionGuard::new(self.clone())
58 }
59
60 pub fn new(config: Config, size: usize) -> Result<Self, PgsqlError> {
61 let mut pool_guard = lock_pool();
62 let init_size = 2.min(size);
63 let mut created = 0;
64 for _ in 0..init_size {
65 match Connect::new(config.clone()) {
66 Ok(conn) => {
67 pool_guard.push_back(conn);
68 created += 1;
69 }
70 Err(e) => warn!("初始化连接失败: {e}"),
71 }
72 }
73
74 let pools = Self {
75 config,
76 max_pools: size,
77 total_connections: Arc::new(Mutex::new(created)),
78 };
79
80 Ok(pools)
81 }
82
83 pub fn get_connect(&mut self) -> Result<Connect, PgsqlError> {
84 let mut attempts = 0;
85
86 #[cfg(not(test))]
87 const RETRY_SLEEP: Duration = Duration::from_secs(1);
88 #[cfg(test)]
89 const RETRY_SLEEP: Duration = Duration::from_millis(1);
90
91 loop {
92 if attempts >= 20 {
93 return Err(PgsqlError::Pool("无法连接数据库,重试超时".into()));
94 }
95 let maybe_conn = {
96 let mut pool = lock_pool();
97 pool.pop_front()
98 };
99
100 if let Some(mut conn) = maybe_conn {
101 if conn.is_valid() {
102 return Ok(conn);
103 } else {
104 let mut counter = lock_counter(&self.total_connections);
105 *counter = counter.saturating_sub(1);
106 drop(counter);
107 warn!(
108 "连接失效,尝试重建,当前总连接数量: {}",
109 self.total_connections()
110 );
111 match Connect::new(self.config.clone()) {
112 Ok(new_conn) => {
113 *lock_counter(&self.total_connections) += 1;
114 return Ok(new_conn);
115 }
116 Err(e) => {
117 error!("重建连接失败: {}", e);
118 attempts += 1;
119 thread::sleep(RETRY_SLEEP);
120 continue;
121 }
122 }
123 }
124 } else if self.total_connections() < self.max_pools {
125 match Connect::new(self.config.clone()) {
126 Ok(new_conn) => {
127 *lock_counter(&self.total_connections) += 1;
128 return Ok(new_conn);
129 }
130 Err(e) => {
131 error!("创建新连接失败: {}", e);
132 attempts += 1;
133 thread::sleep(RETRY_SLEEP);
134 continue;
135 }
136 }
137 } else {
138 attempts += 1;
139 thread::sleep(Duration::from_millis(50));
140 }
141 }
142 }
143
144 pub fn get_connect_for_transaction(&mut self) -> Result<Connect, PgsqlError> {
146 let mut attempts = 0;
147
148 #[cfg(not(test))]
149 const RETRY_SLEEP: Duration = Duration::from_secs(1);
150 #[cfg(test)]
151 const RETRY_SLEEP: Duration = Duration::from_millis(1);
152
153 loop {
154 if attempts >= 20 {
155 return Err(PgsqlError::Pool("无法获取事务连接,重试超时".into()));
156 }
157
158 let maybe_conn = {
159 let mut pool = lock_pool();
160 pool.pop_front()
161 };
162
163 if let Some(mut conn) = maybe_conn {
164 if conn.is_valid() {
165 return Ok(conn);
166 } else {
167 let mut counter = lock_counter(&self.total_connections);
168 *counter = counter.saturating_sub(1);
169 drop(counter);
170 warn!(
171 "事务连接失效,尝试重建,当前总连接数量: {}",
172 self.total_connections()
173 );
174 }
175 }
176
177 match Connect::new(self.config.clone()) {
178 Ok(new_conn) => {
179 *lock_counter(&self.total_connections) += 1;
180 return Ok(new_conn);
181 }
182 Err(e) => {
183 error!("创建事务连接失败: {}", e);
184 attempts += 1;
185 thread::sleep(RETRY_SLEEP);
186 continue;
187 }
188 }
189 }
190 }
191
192 pub fn release_transaction_conn(&self) {
193 let mut counter = lock_counter(&self.total_connections);
194 *counter = counter.saturating_sub(1);
195 }
196
197 pub fn release_conn(&self, mut conn: Connect) {
198 if conn.is_valid() {
199 let mut pool = lock_pool();
200 if pool.len() < self.max_pools {
201 pool.push_back(conn);
202 } else {
203 let mut counter = lock_counter(&self.total_connections);
204 *counter = counter.saturating_sub(1);
205 warn!("连接池已满,丢弃连接");
206 }
207 } else {
208 let mut counter = lock_counter(&self.total_connections);
209 *counter = counter.saturating_sub(1);
210 warn!("释放时检测到坏连接,已丢弃");
211 }
212 }
213
214 pub fn idle_pool_size(&self) -> usize {
215 let pool = lock_pool();
216 pool.len()
217 }
218
219 pub fn total_connections(&self) -> usize {
220 *lock_counter(&self.total_connections)
221 }
222
223 pub fn borrowed_connections(&self) -> usize {
224 self.total_connections()
225 .saturating_sub(self.idle_pool_size())
226 }
227
228 #[allow(dead_code)]
229 pub fn _cleanup_idle_connections(&self) {
230 let mut pool = lock_pool();
231 println!("当前连接池中的连接数量(清理前): {}", pool.len());
232
233 pool.retain(|conn| {
234 let is_ok = conn.stream.peer_addr().is_ok();
235 if !is_ok {
236 println!("检测到无效连接,已移除");
237 }
238 is_ok
239 });
240
241 println!("当前连接池中的连接数量(清理后): {}", pool.len());
242 }
243}
244
// Pool tests run against a local mock server that replies with hand-built,
// PostgreSQL-style backend messages. Every scenario shares the global DB_POOL,
// so each one clears it first, and all scenarios live in a single test function
// to keep them serialized.
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Read as IoRead, Write as IoWrite};
    use std::net::TcpListener;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Frames one server message.
    ///
    /// Layout: a tag byte, then a big-endian `u32` length that counts itself
    /// (4 bytes) plus the payload, then the payload.
    fn pg_msg(tag: u8, payload: &[u8]) -> Vec<u8> {
        let mut m = Vec::with_capacity(5 + payload.len());
        m.push(tag);
        m.extend(&((payload.len() as u32 + 4).to_be_bytes()));
        m.extend_from_slice(payload);
        m
    }

    /// Builds an `'R'` message whose body is `auth_type` (big-endian `u32`)
    /// followed by `extra`.
    fn pg_auth(auth_type: u32, extra: &[u8]) -> Vec<u8> {
        let mut body = Vec::new();
        body.extend(&auth_type.to_be_bytes());
        body.extend_from_slice(extra);
        pg_msg(b'R', &body)
    }

    /// Reply sent after the client's second message.
    ///
    /// Contents: auth type 0 (presumably "authentication OK"), a
    /// `server_version` = `15.0` parameter, a `'K'` message carrying the `u32`s
    /// 1 and 2, and a `'Z'` message with status byte `I`.
    fn post_auth_ok() -> Vec<u8> {
        let mut v = Vec::new();
        v.extend(pg_auth(0, &[]));
        v.extend(pg_msg(b'S', b"server_version\x0015.0\x00"));
        let mut k = Vec::new();
        k.extend(&1u32.to_be_bytes());
        k.extend(&2u32.to_be_bytes());
        v.extend(pg_msg(b'K', &k));
        v.extend(pg_msg(b'Z', b"I"));
        v
    }

    /// Canned reply to every message once logged in.
    ///
    /// Message sequence: `'1'` and `'2'` with empty bodies, a one-column row
    /// description, one data row holding the text `1`, a `SELECT 1` completion
    /// and a `'Z'` message.
    fn simple_query_response() -> Vec<u8> {
        let mut r = Vec::new();
        r.extend(pg_msg(b'1', &[]));
        r.extend(pg_msg(b'2', &[]));
        // Row description: 1 field named "c". Per-field values, in order:
        // 0 and 1 (presumably table OID / column number), type id 23
        // (presumably int4), size 4, modifier -1, format 0.
        let mut rd = Vec::new();
        rd.extend(&1u16.to_be_bytes());
        rd.extend(b"c\x00");
        rd.extend(&0u32.to_be_bytes());
        rd.extend(&1u16.to_be_bytes());
        rd.extend(&23u32.to_be_bytes());
        rd.extend(&4i16.to_be_bytes());
        rd.extend(&(-1i32).to_be_bytes());
        rd.extend(&0u16.to_be_bytes());
        r.extend(pg_msg(b'T', &rd));
        // Data row: 1 column, 1 byte long, value "1".
        let mut dr = Vec::new();
        dr.extend(&1u16.to_be_bytes());
        dr.extend(&1u32.to_be_bytes());
        dr.push(b'1');
        r.extend(pg_msg(b'D', &dr));
        r.extend(pg_msg(b'C', b"SELECT 1\x00"));
        r.extend(pg_msg(b'Z', b"I"));
        r
    }

    /// Starts the mock server on an ephemeral localhost port and returns the port.
    ///
    /// Each accepted socket gets its own thread (5 s read timeout) that runs
    /// this script:
    /// - First client message: answered with an auth request of type 3.
    /// - Second message: answered with [`post_auth_ok`].
    /// - Every later message: answered with [`simple_query_response`] until
    ///   EOF, a read error, or `stop`.
    ///
    /// Message contents are never inspected. The accept loop polls
    /// non-blockingly every 5 ms until `stop` is set.
    fn spawn_multi_server(stop: Arc<AtomicBool>) -> u16 {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        thread::spawn(move || {
            listener.set_nonblocking(true).unwrap();
            while !stop.load(Ordering::Relaxed) {
                match listener.accept() {
                    Ok((s, _)) => {
                        // Accepted sockets may inherit non-blocking mode; the
                        // handler wants blocking reads with a timeout.
                        s.set_nonblocking(false).ok();
                        let stop2 = stop.clone();
                        thread::spawn(move || {
                            s.set_read_timeout(Some(Duration::from_secs(5))).ok();
                            let mut s = s;
                            let mut buf = [0u8; 4096];
                            if s.read(&mut buf).unwrap_or(0) == 0 {
                                return;
                            }
                            let _ = s.write_all(&pg_auth(3, &[]));
                            if s.read(&mut buf).unwrap_or(0) == 0 {
                                return;
                            }
                            let _ = s.write_all(&post_auth_ok());
                            while !stop2.load(Ordering::Relaxed) {
                                match s.read(&mut buf) {
                                    Ok(0) | Err(_) => break,
                                    Ok(_) => {
                                        let _ = s.write_all(&simple_query_response());
                                    }
                                }
                            }
                        });
                    }
                    Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                        thread::sleep(Duration::from_millis(5));
                    }
                    Err(_) => break,
                }
            }
        });
        // Presumably a head start for the accept thread. The listener is already
        // bound, so early connects would queue anyway.
        thread::sleep(Duration::from_millis(50));
        port
    }

    /// Config pointing at the mock server on `port`, with dummy credentials.
    fn mock_config(port: u16) -> Config {
        Config {
            debug: false,
            hostname: "127.0.0.1".into(),
            hostport: port as i32,
            username: "u".into(),
            userpass: "p".into(),
            database: "d".into(),
            charset: "utf8".into(),
            pool_max: 5,
        }
    }

    /// Walks the pool through its success and failure paths in sequence.
    ///
    /// Order matters: every scenario manipulates the shared DB_POOL directly.
    #[test]
    fn pools_all_paths() {
        lock_pool().clear();
        let stop = Arc::new(AtomicBool::new(false));
        let port = spawn_multi_server(stop.clone());
        let cfg = mock_config(port);

        // A new pool of size 10 opens min(2, 10) = 2 idle connections eagerly.
        let mut pools = Pools::new(cfg.clone(), 10).unwrap();
        assert_eq!(pools.total_connections(), 2);
        assert_eq!(pools.idle_pool_size(), 2);
        assert_eq!(pools.borrowed_connections(), 0);

        // Borrowing takes from the idle queue; releasing puts the connection back.
        let conn1 = pools.get_connect().unwrap();
        assert_eq!(pools.idle_pool_size(), 1);
        assert!(pools.borrowed_connections() > 0);

        let idle_before = pools.idle_pool_size();
        pools.release_conn(conn1);
        assert!(pools.idle_pool_size() > idle_before);

        // Empty idle queue: get_connect opens a fresh connection. It is dropped
        // without release, so it stays counted (not asserted here).
        lock_pool().clear();
        let conn2 = pools.get_connect().unwrap();
        drop(conn2);

        // A connection closed before release is discarded by release_conn.
        let mut conn3 = pools.get_connect().unwrap();
        let total_before = pools.total_connections();
        conn3._close();
        pools.release_conn(conn3);
        assert!(pools.total_connections() <= total_before);

        // A closed connection sitting in the idle queue is rebuilt by get_connect.
        let mut bad = pools.get_connect().unwrap();
        bad._close();
        lock_pool().push_back(bad);
        let rebuilt = pools.get_connect().unwrap();
        drop(rebuilt);

        // borrowed_connections is total minus idle.
        lock_pool().clear();
        let p2 = Pools::new(cfg.clone(), 10).unwrap();
        let idle = p2.idle_pool_size();
        let total = p2.total_connections();
        assert_eq!(p2.borrowed_connections(), total - idle);

        // A guard lends a working connection and returns it when dropped.
        lock_pool().clear();
        let mut p3 = Pools::new(cfg.clone(), 10).unwrap();
        {
            let mut guard = p3.get_guard().unwrap();
            let qr = guard.conn().query("SELECT 1");
            assert!(qr.is_ok());
        }
        assert!(p3.idle_pool_size() > 0);

        // release_transaction_conn decrements the counter by exactly one.
        lock_pool().clear();
        let mut p4 = Pools::new(cfg.clone(), 10).unwrap();
        let txn = p4.get_connect_for_transaction().unwrap();
        let total_before = p4.total_connections();
        p4.release_transaction_conn();
        assert_eq!(p4.total_connections(), total_before - 1);
        drop(txn);

        // get_connect_for_transaction replaces a closed idle connection.
        lock_pool().clear();
        let mut p5 = Pools::new(cfg.clone(), 10).unwrap();
        let mut bad_txn = p5.get_connect().unwrap();
        bad_txn._close();
        lock_pool().clear();
        lock_pool().push_back(bad_txn);
        let txn2 = p5.get_connect_for_transaction();
        assert!(txn2.is_ok());
        drop(txn2);

        // Size 3: borrow three connections (the third is newly opened), then
        // return them all. Returning must not raise the total.
        lock_pool().clear();
        let mut p6 = Pools::new(cfg.clone(), 3).unwrap();
        let c1 = p6.get_connect().unwrap();
        let c2 = p6.get_connect().unwrap();
        let c3 = p6.get_connect().unwrap();
        p6.release_conn(c1);
        p6.release_conn(c2);
        let total_before = p6.total_connections();
        p6.release_conn(c3);
        assert!(p6.total_connections() <= total_before);

        // Cleanup on a queue of healthy connections.
        lock_pool().clear();
        let p7 = Pools::new(cfg.clone(), 10).unwrap();
        p7._cleanup_idle_connections();

        // Size 1 caps the eager warm-up at one connection.
        lock_pool().clear();
        let p8 = Pools::new(cfg.clone(), 1).unwrap();
        assert_eq!(p8.total_connections(), 1);

        // Unreachable server (port 1, presumably refused): new still succeeds,
        // with zero connections.
        lock_pool().clear();
        let bad_cfg = mock_config(1);
        let p9 = Pools::new(bad_cfg, 5);
        assert!(p9.is_ok());
        assert_eq!(p9.unwrap().total_connections(), 0);

        // Size 1 with its only connection held: get_connect waits on the full
        // pool until it runs out of attempts, then errors.
        lock_pool().clear();
        let mut p10 = Pools::new(cfg.clone(), 1).unwrap();
        let held = p10.get_connect().unwrap();
        let result = p10.get_connect();
        assert!(result.is_err());
        drop(held);

        // Releasing into an idle queue that is already at max_pools discards the
        // connection.
        lock_pool().clear();
        let p11 = Pools::new(cfg.clone(), 1).unwrap();
        assert_eq!(p11.idle_pool_size(), 1);
        let extra_conn = Connect::new(cfg.clone()).unwrap();
        let total_before = p11.total_connections();
        p11.release_conn(extra_conn);
        assert!(p11.total_connections() <= total_before);

        // Cleanup with a closed connection in the queue.
        lock_pool().clear();
        let mut p12 = Pools::new(cfg.clone(), 3).unwrap();
        let mut bad_conn = p12.get_connect().unwrap();
        bad_conn._close();
        lock_pool().push_back(bad_conn);
        p12._cleanup_idle_connections();

        // Unreachable server plus a closed idle connection (counter forced to 1):
        // every rebuild fails, so get_connect errors.
        lock_pool().clear();
        let bad_cfg = mock_config(1);
        let mut p13 = Pools::new(bad_cfg.clone(), 5).unwrap();
        let good_conn = Connect::new(cfg.clone()).unwrap();
        let mut bad_c = Connect::new(cfg.clone()).unwrap();
        bad_c._close();
        lock_pool().clear();
        lock_pool().push_back(bad_c);
        *lock_counter(&p13.total_connections) = 1;
        let result = p13.get_connect();
        assert!(result.is_err());
        drop(good_conn);

        // Unreachable server, empty queue: get_connect errors.
        lock_pool().clear();
        let mut p14 = Pools::new(bad_cfg.clone(), 5).unwrap();
        *lock_counter(&p14.total_connections) = 0;
        let result = p14.get_connect();
        assert!(result.is_err());

        // Same for the transaction path.
        lock_pool().clear();
        let mut p15 = Pools::new(bad_cfg.clone(), 5).unwrap();
        *lock_counter(&p15.total_connections) = 0;
        let result = p15.get_connect_for_transaction();
        assert!(result.is_err());

        // Transaction path with a closed idle connection and an unreachable
        // server: errors.
        lock_pool().clear();
        let mut p16 = Pools::new(bad_cfg.clone(), 5).unwrap();
        let mut bad_txn = Connect::new(cfg.clone()).unwrap();
        bad_txn._close();
        lock_pool().push_back(bad_txn);
        *lock_counter(&p16.total_connections) = 1;
        let result = p16.get_connect_for_transaction();
        assert!(result.is_err());

        // Stop the mock server and leave the global queue empty for later tests.
        stop.store(true, Ordering::Relaxed);
        lock_pool().clear();
    }
}