1use crate::config::Config;
2use crate::connect::Connect;
3use crate::error::PgsqlError;
4use log::{error, warn};
5use std::collections::VecDeque;
6use std::sync::{Arc, Condvar, Mutex, MutexGuard, PoisonError};
7use std::thread;
8use std::time::Duration;
9
/// Seconds (5 minutes) after which an unused connection is discarded instead of pooled.
const MAX_IDLE_SECS: u64 = 5 * 60;
/// Seconds (30 minutes) after which a connection is retired regardless of use.
const MAX_CONN_LIFETIME_SECS: u64 = 30 * 60;
14struct PoolInner {
15 idle: VecDeque<Connect>,
16 total: usize,
17 max: usize,
18 txn_total: usize,
20 txn_max: usize,
22}
23
24struct SlotGuard<'a> {
27 mutex: &'a Mutex<PoolInner>,
28 condvar: &'a Condvar,
29 active: bool,
30 for_transaction: bool,
31}
32
33impl<'a> SlotGuard<'a> {
34 fn new(mutex: &'a Mutex<PoolInner>, condvar: &'a Condvar, for_transaction: bool) -> Self {
35 Self {
36 mutex,
37 condvar,
38 active: true,
39 for_transaction,
40 }
41 }
42
43 fn disarm(&mut self) {
45 self.active = false;
46 }
47}
48
49impl Drop for SlotGuard<'_> {
50 fn drop(&mut self) {
51 if self.active {
52 let mut pool = lock_inner(self.mutex);
53 pool.total = pool.total.saturating_sub(1);
54 if self.for_transaction {
55 pool.txn_total = pool.txn_total.saturating_sub(1);
56 }
57 drop(pool);
58 self.condvar.notify_one();
59 }
60 }
61}
62#[derive(Clone)]
63pub struct Pools {
64 pub config: Config,
65 inner: Arc<(Mutex<PoolInner>, Condvar)>,
66}
67fn lock_inner(mutex: &Mutex<PoolInner>) -> MutexGuard<'_, PoolInner> {
68 mutex.lock().unwrap_or_else(PoisonError::into_inner)
69}
70pub struct ConnectionGuard {
71 pool: Pools,
72 conn: Option<Connect>,
73}
74impl ConnectionGuard {
75 pub fn new(pool: Pools) -> Result<Self, PgsqlError> {
76 let conn = pool.get_connect()?;
77 Ok(Self {
78 pool,
79 conn: Some(conn),
80 })
81 }
82 pub fn conn(&mut self) -> &mut Connect {
83 self.conn.as_mut().expect("connection already released")
84 }
85 pub fn discard(&mut self) {
87 if let Some(_conn) = self.conn.take() {
88 let (ref mutex, ref condvar) = *self.pool.inner;
89 let mut pool = lock_inner(mutex);
90 pool.total = pool.total.saturating_sub(1);
91 drop(pool);
92 condvar.notify_one();
93 }
94 }
95}
96impl Drop for ConnectionGuard {
97 fn drop(&mut self) {
98 if let Some(conn) = self.conn.take() {
99 self.pool.release_conn(conn);
100 }
101 }
102}
103impl Pools {
104 pub fn get_guard(&self) -> Result<ConnectionGuard, PgsqlError> {
105 ConnectionGuard::new(self.clone())
106 }
107 pub fn new(config: Config, size: usize) -> Result<Self, PgsqlError> {
108 let init_size = 2.min(size);
109 let mut idle = VecDeque::with_capacity(size);
110 let mut created = 0;
111 for _ in 0..init_size {
112 match Connect::new(config.clone()) {
113 Ok(conn) => {
114 idle.push_back(conn);
115 created += 1;
116 }
117 Err(e) => warn!("初始化连接失败: {e}"),
118 }
119 }
120 let txn_max = (size / 3).max(1);
121 let inner = PoolInner {
122 idle,
123 total: created,
124 max: size,
125 txn_total: 0,
126 txn_max,
127 };
128
129 Ok(Self {
130 config,
131 inner: Arc::new((Mutex::new(inner), Condvar::new())),
132 })
133 }
134
135 fn acquire_connect(&self, for_transaction: bool) -> Result<Connect, PgsqlError> {
137 let mut attempts = 0;
138 let (ref mutex, ref condvar) = *self.inner;
139 let label = if for_transaction { "事务" } else { "" };
140 #[cfg(not(test))]
141 const BASE_SLEEP_MS: u64 = 200;
142 #[cfg(test)]
143 const BASE_SLEEP_MS: u64 = 1;
144 #[cfg(not(test))]
145 const MAX_SLEEP_MS: u64 = 2000;
146 #[cfg(test)]
147 const MAX_SLEEP_MS: u64 = 5;
148 #[cfg(not(test))]
149 const WAIT_TIMEOUT: Duration = Duration::from_secs(2);
150 #[cfg(test)]
151 const WAIT_TIMEOUT: Duration = Duration::from_millis(5);
152
153 let timeout_msg = if for_transaction {
154 "无法获取事务连接,重试超时"
155 } else {
156 "无法连接数据库,重试超时"
157 };
158
159 loop {
160 if attempts >= 5 {
161 return Err(PgsqlError::Pool(timeout_msg.into()));
162 }
163
164 let action = {
165 let mut pool = lock_inner(mutex);
166 if for_transaction && pool.txn_total >= pool.txn_max && pool.total >= pool.max {
168 Action::Wait
169 } else if let Some(conn) = pool.idle.pop_front() {
170 if for_transaction {
171 pool.txn_total += 1;
172 }
173 Action::GotConn(Box::new(conn))
174 } else if pool.total < pool.max {
175 pool.total += 1; if for_transaction {
177 pool.txn_total += 1;
178 }
179 Action::Create
180 } else {
181 Action::Wait
182 }
183 };
184
185 match action {
186 Action::GotConn(mut conn) => {
187 if conn.age().as_secs() > MAX_CONN_LIFETIME_SECS {
189 {
190 let mut pool = lock_inner(mutex);
191 pool.total = pool.total.saturating_sub(1);
192 if for_transaction {
193 pool.txn_total = pool.txn_total.saturating_sub(1);
194 }
195 }
196 log::debug!("{}连接存活超过{}秒,已丢弃", label, MAX_CONN_LIFETIME_SECS);
197 continue;
198 }
199 if conn.is_valid() {
201 conn.touch();
202 return Ok(*conn);
203 }
204 {
206 let mut pool = lock_inner(mutex);
207 pool.total = pool.total.saturating_sub(1);
208 if for_transaction {
209 pool.txn_total = pool.txn_total.saturating_sub(1);
210 }
211 }
212 warn!(
213 "{}连接失效已丢弃,当前总连接数量: {}",
214 label,
215 self.total_connections()
216 );
217 continue;
219 }
220
221 Action::Create => {
222 let mut guard = SlotGuard::new(mutex, condvar, for_transaction);
224 match Connect::new(self.config.clone()) {
225 Ok(new_conn) => {
226 guard.disarm();
227 return Ok(new_conn);
228 }
229 Err(e) => {
230 drop(guard);
232 let sleep_ms = BASE_SLEEP_MS
233 .saturating_mul(1u64 << attempts.min(3))
234 .min(MAX_SLEEP_MS);
235 attempts += 1;
236 error!("创建{}连接失败({}ms后重试): {}", label, sleep_ms, e);
237 thread::sleep(Duration::from_millis(sleep_ms));
238 }
239 }
240 }
241 Action::Wait => {
242 let pool = lock_inner(mutex);
243 let (_pool, timeout) = condvar
244 .wait_timeout(pool, WAIT_TIMEOUT)
245 .unwrap_or_else(PoisonError::into_inner);
246 drop(_pool);
247 if timeout.timed_out() {
248 attempts += 1;
249 }
250 }
251 }
252 }
253 }
254 pub fn get_connect(&self) -> Result<Connect, PgsqlError> {
255 self.acquire_connect(false)
256 }
257 pub fn get_connect_for_transaction(&self) -> Result<Connect, PgsqlError> {
259 self.acquire_connect(true)
260 }
261 pub fn release_transaction_conn(&self) {
262 let (ref mutex, ref condvar) = *self.inner;
263 let mut pool = lock_inner(mutex);
264 pool.total = pool.total.saturating_sub(1);
265 pool.txn_total = pool.txn_total.saturating_sub(1);
266 drop(pool);
267 condvar.notify_one();
268 }
269 pub fn release_transaction_conn_with_conn(&self, conn: Connect) {
271 let (ref mutex, _) = *self.inner;
272 {
273 let mut pool = lock_inner(mutex);
274 pool.txn_total = pool.txn_total.saturating_sub(1);
275 }
276 self.release_conn(conn);
277 }
278 pub fn release_conn(&self, conn: Connect) {
279 let (ref mutex, ref condvar) = *self.inner;
280 if !conn.peer_valid() {
281 let mut pool = lock_inner(mutex);
282 pool.total = pool.total.saturating_sub(1);
283 drop(pool);
284 condvar.notify_one();
285 warn!("释放时检测到坏连接,已丢弃");
286 return;
287 }
288 if conn.age().as_secs() > MAX_CONN_LIFETIME_SECS {
289 let mut pool = lock_inner(mutex);
290 pool.total = pool.total.saturating_sub(1);
291 drop(pool);
292 condvar.notify_one();
293 log::debug!("释放时连接存活超过{}秒,已丢弃", MAX_CONN_LIFETIME_SECS);
294 return;
295 }
296 if conn.idle_elapsed().as_secs() > MAX_IDLE_SECS {
297 let mut pool = lock_inner(mutex);
298 pool.total = pool.total.saturating_sub(1);
299 drop(pool);
300 condvar.notify_one();
301 log::debug!("连接空闲超过{}秒,已丢弃", MAX_IDLE_SECS);
302 return;
303 }
304 let mut pool = lock_inner(mutex);
305 if pool.idle.len() < pool.max {
306 pool.idle.push_back(conn);
307 } else {
308 pool.total = pool.total.saturating_sub(1);
309 warn!("连接池已满,丢弃连接");
310 }
311 drop(pool);
312 condvar.notify_one();
313 }
314 pub fn idle_pool_size(&self) -> usize {
315 let (ref mutex, _) = *self.inner;
316 let pool = lock_inner(mutex);
317 pool.idle.len()
318 }
319 pub fn total_connections(&self) -> usize {
320 let (ref mutex, _) = *self.inner;
321 let pool = lock_inner(mutex);
322 pool.total
323 }
324 pub fn borrowed_connections(&self) -> usize {
325 let (ref mutex, _) = *self.inner;
326 let pool = lock_inner(mutex);
327 pool.total.saturating_sub(pool.idle.len())
328 }
329 pub fn flush_idle(&self) {
331 let (ref mutex, _) = *self.inner;
332 let mut pool = lock_inner(mutex);
333 let flushed = pool.idle.len();
334 pool.total = pool.total.saturating_sub(flushed);
335 pool.idle.clear();
336 if flushed > 0 {
337 warn!("清空池中 {flushed} 个空闲连接(疑似批量失效)");
338 }
339 }
340 #[allow(dead_code)]
341 pub fn _cleanup_idle_connections(&self) {
342 let (ref mutex, _) = *self.inner;
343 let mut pool = lock_inner(mutex);
344 log::debug!("当前连接池中的连接数量(清理前): {}", pool.idle.len());
345 let before = pool.idle.len();
346 pool.idle.retain(|conn| {
347 let peer_ok = conn.peer_valid();
348 let idle_ok = conn.idle_elapsed().as_secs() <= MAX_IDLE_SECS;
349 if !peer_ok {
350 log::debug!("检测到无效连接,已移除");
351 } else if !idle_ok {
352 log::debug!("检测到空闲超时连接,已移除");
353 }
354 peer_ok && idle_ok
355 });
356 let removed = before - pool.idle.len();
357 pool.total = pool.total.saturating_sub(removed);
358 }
359}
360
361enum Action {
363 GotConn(Box<Connect>),
364 Create,
365 Wait,
366}
367
#[cfg(test)]
mod tests {
    use super::*;
    // Aliased imports only bring the `read`/`write_all` trait methods into scope.
    use std::io::{Read as IoRead, Write as IoWrite};
    use std::net::TcpListener;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Frames one backend message: a 1-byte tag, a big-endian u32 length that
    /// counts itself (payload length + 4), then the payload.
    fn pg_msg(tag: u8, payload: &[u8]) -> Vec<u8> {
        let mut m = Vec::with_capacity(5 + payload.len());
        m.push(tag);
        m.extend(&((payload.len() as u32 + 4).to_be_bytes()));
        m.extend_from_slice(payload);
        m
    }

    /// Builds an `R` (authentication) message with the given auth type code
    /// followed by `extra` bytes. In the PostgreSQL protocol, 0 is
    /// AuthenticationOk and 3 is a cleartext-password request.
    fn pg_auth(auth_type: u32, extra: &[u8]) -> Vec<u8> {
        let mut body = Vec::new();
        body.extend(&auth_type.to_be_bytes());
        body.extend_from_slice(extra);
        pg_msg(b'R', &body)
    }

    /// Everything the mock server sends once the password has been received.
    /// In order: auth OK, a `server_version` parameter, backend key data
    /// (1, 2), and ready-for-query with status `I`.
    fn post_auth_ok() -> Vec<u8> {
        let mut v = Vec::new();
        v.extend(pg_auth(0, &[]));
        v.extend(pg_msg(b'S', b"server_version\x0015.0\x00"));
        let mut k = Vec::new();
        k.extend(&1u32.to_be_bytes());
        k.extend(&2u32.to_be_bytes());
        v.extend(pg_msg(b'K', &k));
        v.extend(pg_msg(b'Z', b"I"));
        v
    }

    /// Canned reply sent for every client request, without parsing it.
    ///
    /// Message sequence:
    /// - `1`, `2`: parse and bind complete;
    /// - `T`: a one-column row description (column "c", type oid 23, length 4);
    /// - `D`: one data row containing the text "1";
    /// - `C`: "SELECT 1";
    /// - `Z`: ready-for-query.
    fn simple_query_response() -> Vec<u8> {
        let mut r = Vec::new();
        r.extend(pg_msg(b'1', &[]));
        r.extend(pg_msg(b'2', &[]));
        let mut rd = Vec::new();
        rd.extend(&1u16.to_be_bytes());
        rd.extend(b"c\x00");
        rd.extend(&0u32.to_be_bytes());
        rd.extend(&1u16.to_be_bytes());
        rd.extend(&23u32.to_be_bytes());
        rd.extend(&4i16.to_be_bytes());
        rd.extend(&(-1i32).to_be_bytes());
        rd.extend(&0u16.to_be_bytes());
        r.extend(pg_msg(b'T', &rd));
        let mut dr = Vec::new();
        dr.extend(&1u16.to_be_bytes());
        dr.extend(&1u32.to_be_bytes());
        dr.push(b'1');
        r.extend(pg_msg(b'D', &dr));
        r.extend(pg_msg(b'C', b"SELECT 1\x00"));
        r.extend(pg_msg(b'Z', b"I"));
        r
    }

    /// Starts a mock server on an ephemeral 127.0.0.1 port and returns the port.
    ///
    /// The accept loop polls a non-blocking listener every 5 ms until `stop` is
    /// set. Each accepted client gets its own thread, which:
    /// - reads the startup message and requests a cleartext password;
    /// - reads the password and replies with `post_auth_ok`;
    /// - answers every further read with `simple_query_response` until EOF, a
    ///   read error (5 s read timeout), or `stop`.
    ///
    /// Sleeps 50 ms before returning to give the accept thread time to start.
    fn spawn_multi_server(stop: Arc<AtomicBool>) -> u16 {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        thread::spawn(move || {
            listener.set_nonblocking(true).unwrap();
            while !stop.load(Ordering::Relaxed) {
                match listener.accept() {
                    Ok((s, _)) => {
                        s.set_nonblocking(false).ok();
                        let stop2 = stop.clone();
                        thread::spawn(move || {
                            s.set_read_timeout(Some(Duration::from_secs(5))).ok();
                            let mut s = s;
                            let mut buf = [0u8; 4096];
                            // Startup message.
                            if s.read(&mut buf).unwrap_or(0) == 0 {
                                return;
                            }
                            let _ = s.write_all(&pg_auth(3, &[]));
                            // Password message.
                            if s.read(&mut buf).unwrap_or(0) == 0 {
                                return;
                            }
                            let _ = s.write_all(&post_auth_ok());
                            while !stop2.load(Ordering::Relaxed) {
                                match s.read(&mut buf) {
                                    Ok(0) | Err(_) => break,
                                    Ok(_) => {
                                        let _ = s.write_all(&simple_query_response());
                                    }
                                }
                            }
                        });
                    }
                    Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                        thread::sleep(Duration::from_millis(5));
                    }
                    Err(_) => break,
                }
            }
        });
        thread::sleep(Duration::from_millis(50));
        port
    }

    /// Config pointing at 127.0.0.1:`port` with dummy credentials.
    fn mock_config(port: u16) -> Config {
        Config {
            debug: false,
            hostname: "127.0.0.1".into(),
            hostport: port as i32,
            username: "u".into(),
            userpass: "p".into(),
            database: "d".into(),
            charset: "utf8".into(),
            pool_max: 5,
        }
    }

    /// Walks the pool through its main paths against the mock server, then
    /// against an unreachable port.
    #[test]
    fn pools_all_paths() {
        let stop = Arc::new(AtomicBool::new(false));
        let port = spawn_multi_server(stop.clone());
        let cfg = mock_config(port);

        // `new` eagerly opens min(2, size) connections.
        let pools = Pools::new(cfg.clone(), 10).unwrap();
        assert_eq!(pools.total_connections(), 2);
        assert_eq!(pools.idle_pool_size(), 2);
        assert_eq!(pools.borrowed_connections(), 0);

        // Borrowing takes from the idle queue.
        let conn1 = pools.get_connect().unwrap();
        assert_eq!(pools.idle_pool_size(), 1);
        assert!(pools.borrowed_connections() > 0);

        // Releasing a healthy connection puts it back.
        let idle_before = pools.idle_pool_size();
        pools.release_conn(conn1);
        assert!(pools.idle_pool_size() > idle_before);

        // Dropped without release_conn: its slot stays counted in `total`.
        let conn2 = pools.get_connect().unwrap();
        drop(conn2);

        // A closed connection is discarded on release instead of re-pooled.
        let mut conn3 = pools.get_connect().unwrap();
        let total_before = pools.total_connections();
        conn3._close();
        pools.release_conn(conn3);
        assert!(pools.total_connections() <= total_before);

        pools._cleanup_idle_connections();

        // ConnectionGuard returns its connection to the idle queue on drop.
        {
            let mut guard = pools.get_guard().unwrap();
            let qr = guard.conn().query("SELECT 1");
            assert!(qr.is_ok());
        }
        assert!(pools.idle_pool_size() > 0);

        // release_transaction_conn frees one slot.
        let pools2 = Pools::new(cfg.clone(), 10).unwrap();
        let txn = pools2.get_connect_for_transaction().unwrap();
        let total_before = pools2.total_connections();
        pools2.release_transaction_conn();
        assert_eq!(pools2.total_connections(), total_before - 1);
        drop(txn);

        // A full pool of size 1 times out for a second borrower.
        let pools3 = Pools::new(cfg.clone(), 1).unwrap();
        let held = pools3.get_connect().unwrap();
        let result = pools3.get_connect();
        assert!(result.is_err());
        drop(held);

        // Port 1 is assumed to refuse connections, so every open fails.
        let bad_cfg = mock_config(1);
        let pools4 = Pools::new(bad_cfg.clone(), 5).unwrap();
        assert_eq!(pools4.total_connections(), 0);

        let pools5 = Pools::new(bad_cfg.clone(), 5).unwrap();
        let result = pools5.get_connect();
        assert!(result.is_err());

        let pools6 = Pools::new(bad_cfg.clone(), 5).unwrap();
        let result = pools6.get_connect_for_transaction();
        assert!(result.is_err());

        // Eager opening is capped by `size`.
        let pools7 = Pools::new(cfg.clone(), 1).unwrap();
        assert_eq!(pools7.total_connections(), 1);

        stop.store(true, Ordering::Relaxed);
    }
}
562}