1use crate::config::Config;
2use crate::connect::Connect;
3use crate::error::PgsqlError;
4use log::{error, info, warn};
5use std::collections::VecDeque;
6use std::sync::{Arc, Condvar, Mutex, MutexGuard, PoisonError, Weak};
7use std::thread;
8use std::time::Duration;
9
/// Connections idle (untouched) for longer than this many seconds (5 minutes) are discarded.
const MAX_IDLE_SECS: u64 = 5 * 60;
/// Connections older than this many seconds (30 minutes) are discarded regardless of health.
const MAX_CONN_LIFETIME_SECS: u64 = 30 * 60;
14struct PoolInner {
15 idle: VecDeque<Connect>,
16 total: usize,
17 max: usize,
18 txn_total: usize,
20 txn_max: usize,
22}
23
24struct SlotGuard<'a> {
27 mutex: &'a Mutex<PoolInner>,
28 condvar: &'a Condvar,
29 active: bool,
30 for_transaction: bool,
31}
32
33impl<'a> SlotGuard<'a> {
34 fn new(mutex: &'a Mutex<PoolInner>, condvar: &'a Condvar, for_transaction: bool) -> Self {
35 Self {
36 mutex,
37 condvar,
38 active: true,
39 for_transaction,
40 }
41 }
42
43 fn disarm(&mut self) {
45 self.active = false;
46 }
47}
48
49impl Drop for SlotGuard<'_> {
50 fn drop(&mut self) {
51 if self.active {
52 let mut pool = lock_inner(self.mutex);
53 pool.total = pool.total.saturating_sub(1);
54 if self.for_transaction {
55 pool.txn_total = pool.txn_total.saturating_sub(1);
56 }
57 drop(pool);
58 self.condvar.notify_one();
59 }
60 }
61}
/// Handle to a bounded pool of `Connect`s.
///
/// Cloning yields another handle to the same pool state (shared through the `Arc`); the
/// `config` is cloned along with it.
#[derive(Clone)]
pub struct Pools {
    /// Settings used whenever the pool opens a new connection.
    pub config: Config,
    /// Pool state together with the condition variable that waiting callers block on.
    inner: Arc<(Mutex<PoolInner>, Condvar)>,
}
67fn lock_inner(mutex: &Mutex<PoolInner>) -> MutexGuard<'_, PoolInner> {
68 mutex.lock().unwrap_or_else(PoisonError::into_inner)
69}
/// Scoped borrow of a pooled connection.
///
/// Dropping the guard hands the connection back through `Pools::release_conn`, unless
/// `discard` was called first.
pub struct ConnectionGuard {
    /// Pool the connection was borrowed from and is returned to.
    pool: Pools,
    /// `None` once the connection has been released or discarded.
    conn: Option<Connect>,
}
74impl ConnectionGuard {
75 pub fn new(pool: Pools) -> Result<Self, PgsqlError> {
76 let conn = pool.get_connect()?;
77 Ok(Self {
78 pool,
79 conn: Some(conn),
80 })
81 }
82 pub fn conn(&mut self) -> &mut Connect {
83 self.conn.as_mut().expect("connection already released")
84 }
85 pub fn discard(&mut self) {
87 if let Some(_conn) = self.conn.take() {
88 let (ref mutex, ref condvar) = *self.pool.inner;
89 let mut pool = lock_inner(mutex);
90 pool.total = pool.total.saturating_sub(1);
91 drop(pool);
92 condvar.notify_one();
93 }
94 }
95}
96impl Drop for ConnectionGuard {
97 fn drop(&mut self) {
98 if let Some(conn) = self.conn.take() {
99 self.pool.release_conn(conn);
100 }
101 }
102}
103impl Pools {
104 pub fn get_guard(&self) -> Result<ConnectionGuard, PgsqlError> {
105 ConnectionGuard::new(self.clone())
106 }
107 pub fn new(config: Config, size: usize) -> Result<Self, PgsqlError> {
108 let init_size = 2.min(size);
109 let mut idle = VecDeque::with_capacity(size);
110 let mut created = 0;
111 for _ in 0..init_size {
112 match Connect::new(config.clone()) {
113 Ok(conn) => {
114 idle.push_back(conn);
115 created += 1;
116 }
117 Err(e) => warn!("初始化连接失败: {e}"),
118 }
119 }
120 let txn_max = (size / 3).max(1);
121 let inner = PoolInner {
122 idle,
123 total: created,
124 max: size,
125 txn_total: 0,
126 txn_max,
127 };
128
129 let arc = Arc::new((Mutex::new(inner), Condvar::new()));
130
131 let weak = Arc::downgrade(&arc);
133 thread::spawn(move || {
134 Self::reaper_loop(weak);
135 });
136
137 Ok(Self { config, inner: arc })
138 }
139
140 fn acquire_connect(&self, for_transaction: bool) -> Result<Connect, PgsqlError> {
142 let mut attempts = 0;
143 let (ref mutex, ref condvar) = *self.inner;
144 let label = if for_transaction { "事务" } else { "" };
145 #[cfg(not(test))]
146 const BASE_SLEEP_MS: u64 = 200;
147 #[cfg(test)]
148 const BASE_SLEEP_MS: u64 = 1;
149 #[cfg(not(test))]
150 const MAX_SLEEP_MS: u64 = 2000;
151 #[cfg(test)]
152 const MAX_SLEEP_MS: u64 = 5;
153 #[cfg(not(test))]
154 const WAIT_TIMEOUT: Duration = Duration::from_secs(2);
155 #[cfg(test)]
156 const WAIT_TIMEOUT: Duration = Duration::from_millis(5);
157
158 let timeout_msg = if for_transaction {
159 "无法获取事务连接,重试超时"
160 } else {
161 "无法连接数据库,重试超时"
162 };
163
164 loop {
165 if attempts >= 5 {
166 return Err(PgsqlError::Pool(timeout_msg.into()));
167 }
168
169 let action = {
170 let mut pool = lock_inner(mutex);
171 if for_transaction && pool.txn_total >= pool.txn_max && pool.total >= pool.max {
173 Action::Wait
174 } else if let Some(conn) = pool.idle.pop_front() {
175 if for_transaction {
176 pool.txn_total += 1;
177 }
178 Action::GotConn(Box::new(conn))
179 } else if pool.total < pool.max {
180 pool.total += 1; if for_transaction {
182 pool.txn_total += 1;
183 }
184 Action::Create
185 } else {
186 Action::Wait
187 }
188 };
189
190 match action {
191 Action::GotConn(mut conn) => {
192 if conn.age().as_secs() > MAX_CONN_LIFETIME_SECS {
194 {
195 let mut pool = lock_inner(mutex);
196 pool.total = pool.total.saturating_sub(1);
197 if for_transaction {
198 pool.txn_total = pool.txn_total.saturating_sub(1);
199 }
200 }
201 log::debug!("{}连接存活超过{}秒,已丢弃", label, MAX_CONN_LIFETIME_SECS);
202 continue;
203 }
204 if conn.is_valid() {
206 conn.touch();
207 return Ok(*conn);
208 }
209 {
211 let mut pool = lock_inner(mutex);
212 pool.total = pool.total.saturating_sub(1);
213 if for_transaction {
214 pool.txn_total = pool.txn_total.saturating_sub(1);
215 }
216 }
217 warn!(
218 "{}连接失效已丢弃,当前总连接数量: {}",
219 label,
220 self.total_connections()
221 );
222 continue;
224 }
225
226 Action::Create => {
227 let mut guard = SlotGuard::new(mutex, condvar, for_transaction);
229 match Connect::new(self.config.clone()) {
230 Ok(new_conn) => {
231 guard.disarm();
232 return Ok(new_conn);
233 }
234 Err(e) => {
235 drop(guard);
237 let sleep_ms = BASE_SLEEP_MS
238 .saturating_mul(1u64 << attempts.min(3))
239 .min(MAX_SLEEP_MS);
240 attempts += 1;
241 error!("创建{}连接失败({}ms后重试): {}", label, sleep_ms, e);
242 thread::sleep(Duration::from_millis(sleep_ms));
243 }
244 }
245 }
246 Action::Wait => {
247 let pool = lock_inner(mutex);
248 let (_pool, timeout) = condvar
249 .wait_timeout(pool, WAIT_TIMEOUT)
250 .unwrap_or_else(PoisonError::into_inner);
251 drop(_pool);
252 if timeout.timed_out() {
253 attempts += 1;
254 }
255 }
256 }
257 }
258 }
259 pub fn get_connect(&self) -> Result<Connect, PgsqlError> {
260 self.acquire_connect(false)
261 }
262 pub fn get_connect_for_transaction(&self) -> Result<Connect, PgsqlError> {
264 self.acquire_connect(true)
265 }
266 pub fn release_transaction_conn(&self) {
267 let (ref mutex, ref condvar) = *self.inner;
268 let mut pool = lock_inner(mutex);
269 pool.total = pool.total.saturating_sub(1);
270 pool.txn_total = pool.txn_total.saturating_sub(1);
271 drop(pool);
272 condvar.notify_one();
273 }
274 pub fn release_transaction_conn_with_conn(&self, conn: Connect) {
276 let (ref mutex, _) = *self.inner;
277 {
278 let mut pool = lock_inner(mutex);
279 pool.txn_total = pool.txn_total.saturating_sub(1);
280 }
281 self.release_conn(conn);
282 }
283 pub fn release_conn(&self, conn: Connect) {
284 let (ref mutex, ref condvar) = *self.inner;
285 if !conn.peer_valid() {
286 let mut pool = lock_inner(mutex);
287 pool.total = pool.total.saturating_sub(1);
288 drop(pool);
289 condvar.notify_one();
290 warn!("释放时检测到坏连接,已丢弃");
291 return;
292 }
293 if conn.age().as_secs() > MAX_CONN_LIFETIME_SECS {
294 let mut pool = lock_inner(mutex);
295 pool.total = pool.total.saturating_sub(1);
296 drop(pool);
297 condvar.notify_one();
298 log::debug!("释放时连接存活超过{}秒,已丢弃", MAX_CONN_LIFETIME_SECS);
299 return;
300 }
301 if conn.idle_elapsed().as_secs() > MAX_IDLE_SECS {
302 let mut pool = lock_inner(mutex);
303 pool.total = pool.total.saturating_sub(1);
304 drop(pool);
305 condvar.notify_one();
306 log::debug!("连接空闲超过{}秒,已丢弃", MAX_IDLE_SECS);
307 return;
308 }
309 let mut pool = lock_inner(mutex);
310 if pool.idle.len() < pool.max {
311 pool.idle.push_back(conn);
312 } else {
313 pool.total = pool.total.saturating_sub(1);
314 warn!("连接池已满,丢弃连接");
315 }
316 drop(pool);
317 condvar.notify_one();
318 }
319 pub fn idle_pool_size(&self) -> usize {
320 let (ref mutex, _) = *self.inner;
321 let pool = lock_inner(mutex);
322 pool.idle.len()
323 }
324 pub fn total_connections(&self) -> usize {
325 let (ref mutex, _) = *self.inner;
326 let pool = lock_inner(mutex);
327 pool.total
328 }
329 pub fn borrowed_connections(&self) -> usize {
330 let (ref mutex, _) = *self.inner;
331 let pool = lock_inner(mutex);
332 pool.total.saturating_sub(pool.idle.len())
333 }
334 pub fn flush_idle(&self) {
336 let (ref mutex, _) = *self.inner;
337 let mut pool = lock_inner(mutex);
338 let flushed = pool.idle.len();
339 pool.total = pool.total.saturating_sub(flushed);
340 pool.idle.clear();
341 if flushed > 0 {
342 warn!("清空池中 {flushed} 个空闲连接(疑似批量失效)");
343 }
344 }
345 pub fn cleanup_idle_connections(&self) {
346 let (ref mutex, _) = *self.inner;
347 let mut pool = lock_inner(mutex);
348 let before = pool.idle.len();
349 pool.idle.retain(|conn| {
350 let peer_ok = conn.peer_valid();
351 let idle_ok = conn.idle_elapsed().as_secs() <= MAX_IDLE_SECS;
352 let lifetime_ok = conn.age().as_secs() <= MAX_CONN_LIFETIME_SECS;
353 if !peer_ok {
354 log::debug!("检测到无效连接,已移除");
355 } else if !idle_ok {
356 log::debug!("检测到空闲超时连接,已移除");
357 } else if !lifetime_ok {
358 log::debug!("检测到超过最大生命周期连接,已移除");
359 }
360 peer_ok && idle_ok && lifetime_ok
361 });
362 let removed = before - pool.idle.len();
363 pool.total = pool.total.saturating_sub(removed);
364 if removed > 0 {
365 log::debug!(
366 "空闲连接清理完成: 移除 {removed} 个,剩余 {} 个",
367 pool.idle.len()
368 );
369 }
370 }
371 fn reaper_loop(weak: Weak<(Mutex<PoolInner>, Condvar)>) {
373 #[cfg(not(test))]
374 const INTERVAL: Duration = Duration::from_secs(60);
375 #[cfg(test)]
376 const INTERVAL: Duration = Duration::from_millis(50);
377 loop {
378 thread::sleep(INTERVAL);
379 let arc = match weak.upgrade() {
380 Some(a) => a,
381 None => {
382 info!("连接池已释放,回收线程退出");
383 return;
384 }
385 };
386 let (ref mutex, _) = *arc;
387 let mut pool = lock_inner(mutex);
388 let before = pool.idle.len();
389 pool.idle.retain(|conn| {
390 conn.peer_valid()
391 && conn.idle_elapsed().as_secs() <= MAX_IDLE_SECS
392 && conn.age().as_secs() <= MAX_CONN_LIFETIME_SECS
393 });
394 let removed = before - pool.idle.len();
395 pool.total = pool.total.saturating_sub(removed);
396 if removed > 0 {
397 info!(
398 "后台回收: 移除 {removed} 个空闲连接,剩余 {} 个",
399 pool.idle.len()
400 );
401 }
402 }
403 }
404}
405
406enum Action {
408 GotConn(Box<Connect>),
409 Create,
410 Wait,
411}
412
// End-to-end tests for `Pools`, driven against an in-process mock server speaking a minimal
// subset of the PostgreSQL backend protocol.
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Read as IoRead, Write as IoWrite};
    use std::net::TcpListener;
    use std::sync::atomic::{AtomicBool, Ordering};

    // Frames one backend message: a 1-byte tag, a big-endian u32 length that counts itself
    // plus the payload (but not the tag), then the payload.
    fn pg_msg(tag: u8, payload: &[u8]) -> Vec<u8> {
        let mut m = Vec::with_capacity(5 + payload.len());
        m.push(tag);
        m.extend(&((payload.len() as u32 + 4).to_be_bytes()));
        m.extend_from_slice(payload);
        m
    }

    // Builds an 'R' (authentication) message carrying `auth_type` followed by `extra`.
    fn pg_auth(auth_type: u32, extra: &[u8]) -> Vec<u8> {
        let mut body = Vec::new();
        body.extend(&auth_type.to_be_bytes());
        body.extend_from_slice(extra);
        pg_msg(b'R', &body)
    }

    // Reply sent once the client has answered the password request: auth type 0, one
    // 'S' parameter (server_version), a 'K' key pair, and a 'Z' ready-for-query ('I').
    fn post_auth_ok() -> Vec<u8> {
        let mut v = Vec::new();
        v.extend(pg_auth(0, &[]));
        v.extend(pg_msg(b'S', b"server_version\x0015.0\x00"));
        let mut k = Vec::new();
        k.extend(&1u32.to_be_bytes());
        k.extend(&2u32.to_be_bytes());
        v.extend(pg_msg(b'K', &k));
        v.extend(pg_msg(b'Z', b"I"));
        v
    }

    // Canned reply to every client message after authentication. It contains, in order:
    // - '1' and '2' messages;
    // - a 'T' row description with one column "c" (type OID 23, size 4, modifier -1);
    // - a 'D' row whose single value is the text "1";
    // - 'C' "SELECT 1";
    // - 'Z' ready-for-query.
    fn simple_query_response() -> Vec<u8> {
        let mut r = Vec::new();
        r.extend(pg_msg(b'1', &[]));
        r.extend(pg_msg(b'2', &[]));
        let mut rd = Vec::new();
        rd.extend(&1u16.to_be_bytes());
        rd.extend(b"c\x00");
        rd.extend(&0u32.to_be_bytes());
        rd.extend(&1u16.to_be_bytes());
        rd.extend(&23u32.to_be_bytes());
        rd.extend(&4i16.to_be_bytes());
        rd.extend(&(-1i32).to_be_bytes());
        rd.extend(&0u16.to_be_bytes());
        r.extend(pg_msg(b'T', &rd));
        let mut dr = Vec::new();
        dr.extend(&1u16.to_be_bytes());
        dr.extend(&1u32.to_be_bytes());
        dr.push(b'1');
        r.extend(pg_msg(b'D', &dr));
        r.extend(pg_msg(b'C', b"SELECT 1\x00"));
        r.extend(pg_msg(b'Z', b"I"));
        r
    }

    // Starts a mock server on an ephemeral localhost port and returns the port.
    //
    // Each accepted client gets its own thread, which:
    // - reads the startup message and answers with auth type 3 (a password request);
    // - reads the reply and sends `post_auth_ok`;
    // - answers every further read with `simple_query_response` until EOF, a read error, or
    //   `stop` being set.
    //
    // The accept loop polls a non-blocking listener every 5 ms until `stop` is set. The
    // final 50 ms sleep gives the listener thread time to start before the port is used.
    fn spawn_multi_server(stop: Arc<AtomicBool>) -> u16 {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        thread::spawn(move || {
            listener.set_nonblocking(true).unwrap();
            while !stop.load(Ordering::Relaxed) {
                match listener.accept() {
                    Ok((s, _)) => {
                        s.set_nonblocking(false).ok();
                        let stop2 = stop.clone();
                        thread::spawn(move || {
                            s.set_read_timeout(Some(Duration::from_secs(5))).ok();
                            let mut s = s;
                            let mut buf = [0u8; 4096];
                            if s.read(&mut buf).unwrap_or(0) == 0 {
                                return;
                            }
                            let _ = s.write_all(&pg_auth(3, &[]));
                            if s.read(&mut buf).unwrap_or(0) == 0 {
                                return;
                            }
                            let _ = s.write_all(&post_auth_ok());
                            while !stop2.load(Ordering::Relaxed) {
                                match s.read(&mut buf) {
                                    Ok(0) | Err(_) => break,
                                    Ok(_) => {
                                        let _ = s.write_all(&simple_query_response());
                                    }
                                }
                            }
                        });
                    }
                    Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                        thread::sleep(Duration::from_millis(5));
                    }
                    Err(_) => break,
                }
            }
        });
        thread::sleep(Duration::from_millis(50));
        port
    }

    // Config pointing at the mock server on `port`; credentials are dummies the mock ignores.
    fn mock_config(port: u16) -> Config {
        Config {
            debug: false,
            hostname: "127.0.0.1".into(),
            hostport: port as i32,
            username: "u".into(),
            userpass: "p".into(),
            database: "d".into(),
            charset: "utf8".into(),
            pool_max: 5,
            sslmode: "disable".into(),
        }
    }

    // Walks the pool through its main code paths in one test, sharing a single mock server.
    #[test]
    fn pools_all_paths() {
        let stop = Arc::new(AtomicBool::new(false));
        let port = spawn_multi_server(stop.clone());
        let cfg = mock_config(port);

        // Construction pre-opens two connections (min(2, size)).
        let pools = Pools::new(cfg.clone(), 10).unwrap();
        assert_eq!(pools.total_connections(), 2);
        assert_eq!(pools.idle_pool_size(), 2);
        assert_eq!(pools.borrowed_connections(), 0);

        // Borrowing takes one idle connection.
        let conn1 = pools.get_connect().unwrap();
        assert_eq!(pools.idle_pool_size(), 1);
        assert!(pools.borrowed_connections() > 0);

        // Releasing a healthy connection puts it back in the idle queue.
        let idle_before = pools.idle_pool_size();
        pools.release_conn(conn1);
        assert!(pools.idle_pool_size() > idle_before);

        // Dropping a raw connection without releasing it; its slot stays counted.
        let conn2 = pools.get_connect().unwrap();
        drop(conn2);

        // A closed connection must be discarded on release, not re-queued.
        let mut conn3 = pools.get_connect().unwrap();
        let total_before = pools.total_connections();
        conn3._close();
        pools.release_conn(conn3);
        assert!(pools.total_connections() <= total_before);

        pools.cleanup_idle_connections();

        // The RAII guard runs a query and returns the connection when it goes out of scope.
        {
            let mut guard = pools.get_guard().unwrap();
            let qr = guard.conn().query("SELECT 1");
            assert!(qr.is_ok());
        }
        assert!(pools.idle_pool_size() > 0);

        // Forgetting a transaction connection frees exactly one pool slot.
        let pools2 = Pools::new(cfg.clone(), 10).unwrap();
        let txn = pools2.get_connect_for_transaction().unwrap();
        let total_before = pools2.total_connections();
        pools2.release_transaction_conn();
        assert_eq!(pools2.total_connections(), total_before - 1);
        drop(txn);

        // A saturated pool of size 1 fails the second borrow once its retries run out.
        let pools3 = Pools::new(cfg.clone(), 1).unwrap();
        let held = pools3.get_connect().unwrap();
        let result = pools3.get_connect();
        assert!(result.is_err());
        drop(held);

        // Unreachable server (port 1, assumed to have no listener): construction still
        // succeeds with zero connections, and both borrow paths error out after retrying.
        let bad_cfg = mock_config(1);
        let pools4 = Pools::new(bad_cfg.clone(), 5).unwrap();
        assert_eq!(pools4.total_connections(), 0);

        let pools5 = Pools::new(bad_cfg.clone(), 5).unwrap();
        let result = pools5.get_connect();
        assert!(result.is_err());

        let pools6 = Pools::new(bad_cfg.clone(), 5).unwrap();
        let result = pools6.get_connect_for_transaction();
        assert!(result.is_err());

        // With size 1, only one connection is pre-opened.
        let pools7 = Pools::new(cfg.clone(), 1).unwrap();
        assert_eq!(pools7.total_connections(), 1);

        stop.store(true, Ordering::Relaxed);
    }
}