1use crate::config::Config;
2use crate::connect::Connect;
3use crate::error::PgsqlError;
4use log::{error, warn};
5use std::collections::VecDeque;
6use std::sync::{Arc, Condvar, Mutex, MutexGuard, PoisonError};
7use std::thread;
8use std::time::Duration;
9
/// Connections whose idle time exceeds this many seconds are discarded instead of being pooled.
const MAX_IDLE_SECS: u64 = 5 * 60;
/// Connections older than this many seconds are retired, even if they still respond.
const MAX_CONN_LIFETIME_SECS: u64 = 30 * 60;
14struct PoolInner {
15 idle: VecDeque<Connect>,
16 total: usize,
17 max: usize,
18 txn_total: usize,
20 txn_max: usize,
22}
23
24struct SlotGuard<'a> {
27 mutex: &'a Mutex<PoolInner>,
28 condvar: &'a Condvar,
29 active: bool,
30 for_transaction: bool,
31}
32
33impl<'a> SlotGuard<'a> {
34 fn new(mutex: &'a Mutex<PoolInner>, condvar: &'a Condvar, for_transaction: bool) -> Self {
35 Self {
36 mutex,
37 condvar,
38 active: true,
39 for_transaction,
40 }
41 }
42
43 fn disarm(&mut self) {
45 self.active = false;
46 }
47}
48
49impl Drop for SlotGuard<'_> {
50 fn drop(&mut self) {
51 if self.active {
52 let mut pool = lock_inner(self.mutex);
53 pool.total = pool.total.saturating_sub(1);
54 if self.for_transaction {
55 pool.txn_total = pool.txn_total.saturating_sub(1);
56 }
57 drop(pool);
58 self.condvar.notify_one();
59 }
60 }
61}
/// Cloneable handle to a bounded connection pool.
///
/// Every clone shares the same state through `inner`, so cloning is cheap
/// (an `Arc` bump plus a `Config` clone).
#[derive(Clone)]
pub struct Pools {
    /// Settings used whenever the pool opens a new connection.
    pub config: Config,
    /// Pool bookkeeping plus the condition variable that acquirers block on
    /// until a connection or slot is freed.
    inner: Arc<(Mutex<PoolInner>, Condvar)>,
}
67fn lock_inner(mutex: &Mutex<PoolInner>) -> MutexGuard<'_, PoolInner> {
68 mutex.lock().unwrap_or_else(PoisonError::into_inner)
69}
70pub struct ConnectionGuard {
71 pool: Pools,
72 conn: Option<Connect>,
73}
74impl ConnectionGuard {
75 pub fn new(pool: Pools) -> Result<Self, PgsqlError> {
76 let conn = pool.get_connect()?;
77 Ok(Self {
78 pool,
79 conn: Some(conn),
80 })
81 }
82 pub fn conn(&mut self) -> &mut Connect {
83 self.conn.as_mut().expect("connection already released")
84 }
85 pub fn discard(&mut self) {
87 if let Some(_conn) = self.conn.take() {
88 let (ref mutex, ref condvar) = *self.pool.inner;
89 let mut pool = lock_inner(mutex);
90 pool.total = pool.total.saturating_sub(1);
91 drop(pool);
92 condvar.notify_one();
93 }
94 }
95}
96impl Drop for ConnectionGuard {
97 fn drop(&mut self) {
98 if let Some(conn) = self.conn.take() {
99 self.pool.release_conn(conn);
100 }
101 }
102}
103impl Pools {
104 pub fn get_guard(&self) -> Result<ConnectionGuard, PgsqlError> {
105 ConnectionGuard::new(self.clone())
106 }
107 pub fn new(config: Config, size: usize) -> Result<Self, PgsqlError> {
108 let init_size = 2.min(size);
109 let mut idle = VecDeque::with_capacity(size);
110 let mut created = 0;
111 for _ in 0..init_size {
112 match Connect::new(config.clone()) {
113 Ok(conn) => {
114 idle.push_back(conn);
115 created += 1;
116 }
117 Err(e) => warn!("初始化连接失败: {e}"),
118 }
119 }
120 let txn_max = (size / 3).max(1);
121 let inner = PoolInner {
122 idle,
123 total: created,
124 max: size,
125 txn_total: 0,
126 txn_max,
127 };
128
129 Ok(Self {
130 config,
131 inner: Arc::new((Mutex::new(inner), Condvar::new())),
132 })
133 }
134
135 fn acquire_connect(&self, for_transaction: bool) -> Result<Connect, PgsqlError> {
137 let mut attempts = 0;
138 let (ref mutex, ref condvar) = *self.inner;
139 let label = if for_transaction { "事务" } else { "" };
140 #[cfg(not(test))]
141 const BASE_SLEEP_MS: u64 = 200;
142 #[cfg(test)]
143 const BASE_SLEEP_MS: u64 = 1;
144 #[cfg(not(test))]
145 const MAX_SLEEP_MS: u64 = 2000;
146 #[cfg(test)]
147 const MAX_SLEEP_MS: u64 = 5;
148 #[cfg(not(test))]
149 const WAIT_TIMEOUT: Duration = Duration::from_secs(2);
150 #[cfg(test)]
151 const WAIT_TIMEOUT: Duration = Duration::from_millis(5);
152
153 let timeout_msg = if for_transaction {
154 "无法获取事务连接,重试超时"
155 } else {
156 "无法连接数据库,重试超时"
157 };
158
159 loop {
160 if attempts >= 5 {
161 return Err(PgsqlError::Pool(timeout_msg.into()));
162 }
163
164 let action = {
165 let mut pool = lock_inner(mutex);
166 if for_transaction && pool.txn_total >= pool.txn_max && pool.total >= pool.max {
168 Action::Wait
169 } else if let Some(conn) = pool.idle.pop_front() {
170 if for_transaction {
171 pool.txn_total += 1;
172 }
173 Action::GotConn(Box::new(conn))
174 } else if pool.total < pool.max {
175 pool.total += 1; if for_transaction {
177 pool.txn_total += 1;
178 }
179 Action::Create
180 } else {
181 Action::Wait
182 }
183 };
184
185 match action {
186 Action::GotConn(mut conn) => {
187 if conn.age().as_secs() > MAX_CONN_LIFETIME_SECS {
189 {
190 let mut pool = lock_inner(mutex);
191 pool.total = pool.total.saturating_sub(1);
192 if for_transaction {
193 pool.txn_total = pool.txn_total.saturating_sub(1);
194 }
195 }
196 log::debug!("{}连接存活超过{}秒,已丢弃", label, MAX_CONN_LIFETIME_SECS);
197 continue;
198 }
199 if conn.is_valid() {
201 conn.touch();
202 return Ok(*conn);
203 }
204 {
206 let mut pool = lock_inner(mutex);
207 pool.total = pool.total.saturating_sub(1);
208 if for_transaction {
209 pool.txn_total = pool.txn_total.saturating_sub(1);
210 }
211 }
212 warn!(
213 "{}连接失效已丢弃,当前总连接数量: {}",
214 label,
215 self.total_connections()
216 );
217 continue;
219 }
220
221 Action::Create => {
222 let mut guard = SlotGuard::new(mutex, condvar, for_transaction);
224 match Connect::new(self.config.clone()) {
225 Ok(new_conn) => {
226 guard.disarm();
227 return Ok(new_conn);
228 }
229 Err(e) => {
230 drop(guard);
232 let sleep_ms = BASE_SLEEP_MS
233 .saturating_mul(1u64 << attempts.min(3))
234 .min(MAX_SLEEP_MS);
235 attempts += 1;
236 error!("创建{}连接失败({}ms后重试): {}", label, sleep_ms, e);
237 thread::sleep(Duration::from_millis(sleep_ms));
238 }
239 }
240 }
241 Action::Wait => {
242 let pool = lock_inner(mutex);
243 let (_pool, timeout) = condvar
244 .wait_timeout(pool, WAIT_TIMEOUT)
245 .unwrap_or_else(PoisonError::into_inner);
246 drop(_pool);
247 if timeout.timed_out() {
248 attempts += 1;
249 }
250 }
251 }
252 }
253 }
254 pub fn get_connect(&self) -> Result<Connect, PgsqlError> {
255 self.acquire_connect(false)
256 }
257 pub fn get_connect_for_transaction(&self) -> Result<Connect, PgsqlError> {
259 self.acquire_connect(true)
260 }
261 pub fn release_transaction_conn(&self) {
262 let (ref mutex, ref condvar) = *self.inner;
263 let mut pool = lock_inner(mutex);
264 pool.total = pool.total.saturating_sub(1);
265 pool.txn_total = pool.txn_total.saturating_sub(1);
266 drop(pool);
267 condvar.notify_one();
268 }
269 pub fn release_conn(&self, conn: Connect) {
270 let (ref mutex, ref condvar) = *self.inner;
271 if !conn.peer_valid() {
272 let mut pool = lock_inner(mutex);
273 pool.total = pool.total.saturating_sub(1);
274 drop(pool);
275 condvar.notify_one();
276 warn!("释放时检测到坏连接,已丢弃");
277 return;
278 }
279 if conn.age().as_secs() > MAX_CONN_LIFETIME_SECS {
280 let mut pool = lock_inner(mutex);
281 pool.total = pool.total.saturating_sub(1);
282 drop(pool);
283 condvar.notify_one();
284 log::debug!("释放时连接存活超过{}秒,已丢弃", MAX_CONN_LIFETIME_SECS);
285 return;
286 }
287 if conn.idle_elapsed().as_secs() > MAX_IDLE_SECS {
288 let mut pool = lock_inner(mutex);
289 pool.total = pool.total.saturating_sub(1);
290 drop(pool);
291 condvar.notify_one();
292 log::debug!("连接空闲超过{}秒,已丢弃", MAX_IDLE_SECS);
293 return;
294 }
295 let mut pool = lock_inner(mutex);
296 if pool.idle.len() < pool.max {
297 pool.idle.push_back(conn);
298 } else {
299 pool.total = pool.total.saturating_sub(1);
300 warn!("连接池已满,丢弃连接");
301 }
302 drop(pool);
303 condvar.notify_one();
304 }
305 pub fn idle_pool_size(&self) -> usize {
306 let (ref mutex, _) = *self.inner;
307 let pool = lock_inner(mutex);
308 pool.idle.len()
309 }
310 pub fn total_connections(&self) -> usize {
311 let (ref mutex, _) = *self.inner;
312 let pool = lock_inner(mutex);
313 pool.total
314 }
315 pub fn borrowed_connections(&self) -> usize {
316 let (ref mutex, _) = *self.inner;
317 let pool = lock_inner(mutex);
318 pool.total.saturating_sub(pool.idle.len())
319 }
320 pub fn flush_idle(&self) {
322 let (ref mutex, _) = *self.inner;
323 let mut pool = lock_inner(mutex);
324 let flushed = pool.idle.len();
325 pool.total = pool.total.saturating_sub(flushed);
326 pool.idle.clear();
327 if flushed > 0 {
328 warn!("清空池中 {flushed} 个空闲连接(疑似批量失效)");
329 }
330 }
331 #[allow(dead_code)]
332 pub fn _cleanup_idle_connections(&self) {
333 let (ref mutex, _) = *self.inner;
334 let mut pool = lock_inner(mutex);
335 log::debug!("当前连接池中的连接数量(清理前): {}", pool.idle.len());
336 let before = pool.idle.len();
337 pool.idle.retain(|conn| {
338 let peer_ok = conn.peer_valid();
339 let idle_ok = conn.idle_elapsed().as_secs() <= MAX_IDLE_SECS;
340 if !peer_ok {
341 log::debug!("检测到无效连接,已移除");
342 } else if !idle_ok {
343 log::debug!("检测到空闲超时连接,已移除");
344 }
345 peer_ok && idle_ok
346 });
347 let removed = before - pool.idle.len();
348 pool.total = pool.total.saturating_sub(removed);
349 }
350}
351
352enum Action {
354 GotConn(Box<Connect>),
355 Create,
356 Wait,
357}
358
// Integration-style tests that run the pool against an in-process mock server
// speaking a minimal subset of the PostgreSQL wire protocol.
#[cfg(test)]
mod tests {
    use super::*;
    // Imported under aliases only to bring `read` / `write_all` into scope for TcpStream.
    use std::io::{Read as IoRead, Write as IoWrite};
    use std::net::TcpListener;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Frames one backend message: 1-byte tag, 4-byte big-endian length that
    /// counts itself plus the payload, then the payload.
    fn pg_msg(tag: u8, payload: &[u8]) -> Vec<u8> {
        let mut m = Vec::with_capacity(5 + payload.len());
        m.push(tag);
        m.extend(&((payload.len() as u32 + 4).to_be_bytes()));
        m.extend_from_slice(payload);
        m
    }

    /// Builds an 'R' (authentication) message with the given auth type code.
    fn pg_auth(auth_type: u32, extra: &[u8]) -> Vec<u8> {
        let mut body = Vec::new();
        body.extend(&auth_type.to_be_bytes());
        body.extend_from_slice(extra);
        pg_msg(b'R', &body)
    }

    /// Bytes sent after the client's password: auth OK (type 0), one 'S'
    /// parameter (server_version), 'K' key data, and 'Z' ready with status 'I'.
    fn post_auth_ok() -> Vec<u8> {
        let mut v = Vec::new();
        v.extend(pg_auth(0, &[]));
        v.extend(pg_msg(b'S', b"server_version\x0015.0\x00"));
        let mut k = Vec::new();
        k.extend(&1u32.to_be_bytes());
        k.extend(&2u32.to_be_bytes());
        v.extend(pg_msg(b'K', &k));
        v.extend(pg_msg(b'Z', b"I"));
        v
    }

    /// Canned reply to any query: '1' and '2' acknowledgements, then a 'T'
    /// row description for one column "c" (type OID 23, length 4). This is
    /// followed by one 'D' data row containing "1", 'C' "SELECT 1", and 'Z'.
    fn simple_query_response() -> Vec<u8> {
        let mut r = Vec::new();
        r.extend(pg_msg(b'1', &[]));
        r.extend(pg_msg(b'2', &[]));
        let mut rd = Vec::new();
        rd.extend(&1u16.to_be_bytes());
        rd.extend(b"c\x00");
        rd.extend(&0u32.to_be_bytes());
        rd.extend(&1u16.to_be_bytes());
        rd.extend(&23u32.to_be_bytes());
        rd.extend(&4i16.to_be_bytes());
        rd.extend(&(-1i32).to_be_bytes());
        rd.extend(&0u16.to_be_bytes());
        r.extend(pg_msg(b'T', &rd));
        let mut dr = Vec::new();
        dr.extend(&1u16.to_be_bytes());
        dr.extend(&1u32.to_be_bytes());
        dr.push(b'1');
        r.extend(pg_msg(b'D', &dr));
        r.extend(pg_msg(b'C', b"SELECT 1\x00"));
        r.extend(pg_msg(b'Z', b"I"));
        r
    }

    /// Starts a mock server on an ephemeral localhost port and returns the port.
    ///
    /// Each accepted socket gets its own thread that runs three steps:
    /// - reads the startup packet and asks for a password (auth type 3);
    /// - reads the password and sends `post_auth_ok`;
    /// - answers every further read with `simple_query_response` until `stop`
    ///   is set or the client goes away.
    fn spawn_multi_server(stop: Arc<AtomicBool>) -> u16 {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        thread::spawn(move || {
            // Non-blocking accept so the loop can observe `stop` between polls.
            listener.set_nonblocking(true).unwrap();
            while !stop.load(Ordering::Relaxed) {
                match listener.accept() {
                    Ok((s, _)) => {
                        s.set_nonblocking(false).ok();
                        let stop2 = stop.clone();
                        thread::spawn(move || {
                            s.set_read_timeout(Some(Duration::from_secs(5))).ok();
                            let mut s = s;
                            let mut buf = [0u8; 4096];
                            if s.read(&mut buf).unwrap_or(0) == 0 {
                                return;
                            }
                            let _ = s.write_all(&pg_auth(3, &[]));
                            if s.read(&mut buf).unwrap_or(0) == 0 {
                                return;
                            }
                            let _ = s.write_all(&post_auth_ok());
                            while !stop2.load(Ordering::Relaxed) {
                                match s.read(&mut buf) {
                                    Ok(0) | Err(_) => break,
                                    Ok(_) => {
                                        let _ = s.write_all(&simple_query_response());
                                    }
                                }
                            }
                        });
                    }
                    Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                        thread::sleep(Duration::from_millis(5));
                    }
                    Err(_) => break,
                }
            }
        });
        // Give the accept loop a moment to start before clients connect.
        thread::sleep(Duration::from_millis(50));
        port
    }

    /// Config pointing at `127.0.0.1:port` with dummy credentials.
    fn mock_config(port: u16) -> Config {
        Config {
            debug: false,
            hostname: "127.0.0.1".into(),
            hostport: port as i32,
            username: "u".into(),
            userpass: "p".into(),
            database: "d".into(),
            charset: "utf8".into(),
            pool_max: 5,
        }
    }

    /// Walks the main pool paths in sequence:
    /// - eager init and the counters;
    /// - checkout, release and discard;
    /// - guard usage;
    /// - transaction release;
    /// - exhaustion and unreachable-server failures.
    #[test]
    fn pools_all_paths() {
        let stop = Arc::new(AtomicBool::new(false));
        let port = spawn_multi_server(stop.clone());
        let cfg = mock_config(port);

        // `new` opens min(2, size) connections eagerly.
        let pools = Pools::new(cfg.clone(), 10).unwrap();
        assert_eq!(pools.total_connections(), 2);
        assert_eq!(pools.idle_pool_size(), 2);
        assert_eq!(pools.borrowed_connections(), 0);

        let conn1 = pools.get_connect().unwrap();
        assert_eq!(pools.idle_pool_size(), 1);
        assert!(pools.borrowed_connections() > 0);

        let idle_before = pools.idle_pool_size();
        pools.release_conn(conn1);
        assert!(pools.idle_pool_size() > idle_before);

        // NOTE(review): dropping a raw Connect bypasses the pool; nothing visible
        // here decrements `total` for it.
        let conn2 = pools.get_connect().unwrap();
        drop(conn2);

        // A closed connection should be discarded rather than re-queued on release.
        let mut conn3 = pools.get_connect().unwrap();
        let total_before = pools.total_connections();
        conn3._close();
        pools.release_conn(conn3);
        assert!(pools.total_connections() <= total_before);

        pools._cleanup_idle_connections();

        // The guard returns its connection to the idle queue when it goes out of scope.
        {
            let mut guard = pools.get_guard().unwrap();
            let qr = guard.conn().query("SELECT 1");
            assert!(qr.is_ok());
        }
        assert!(pools.idle_pool_size() > 0);

        // Releasing a transaction slot lowers the total by exactly one.
        let pools2 = Pools::new(cfg.clone(), 10).unwrap();
        let txn = pools2.get_connect_for_transaction().unwrap();
        let total_before = pools2.total_connections();
        pools2.release_transaction_conn();
        assert_eq!(pools2.total_connections(), total_before - 1);
        drop(txn);

        // A size-1 pool with its only connection held: the second checkout times out.
        let pools3 = Pools::new(cfg.clone(), 1).unwrap();
        let held = pools3.get_connect().unwrap();
        let result = pools3.get_connect();
        assert!(result.is_err());
        drop(held);

        // Port 1 is presumably unreachable, so every connect attempt fails.
        let bad_cfg = mock_config(1);
        let pools4 = Pools::new(bad_cfg.clone(), 5).unwrap();
        assert_eq!(pools4.total_connections(), 0);

        let pools5 = Pools::new(bad_cfg.clone(), 5).unwrap();
        let result = pools5.get_connect();
        assert!(result.is_err());

        let pools6 = Pools::new(bad_cfg.clone(), 5).unwrap();
        let result = pools6.get_connect_for_transaction();
        assert!(result.is_err());

        // Eager init is capped by `size`.
        let pools7 = Pools::new(cfg.clone(), 1).unwrap();
        assert_eq!(pools7.total_connections(), 1);

        stop.store(true, Ordering::Relaxed);
    }
}