1use crate::config::Config;
2use crate::connect::Connect;
3use crate::error::PgsqlError;
4use log::{error, warn};
5use std::collections::VecDeque;
6use std::sync::{Arc, Condvar, Mutex, MutexGuard, PoisonError};
7use std::thread;
8use std::time::Duration;
9
/// Idle-age limit in seconds. A connection whose `idle_elapsed()` exceeds this
/// is dropped instead of being kept (checked in `Pools::release_conn` and
/// `Pools::_cleanup_idle_connections`).
const MAX_IDLE_SECS: u64 = 5 * 60;
12struct PoolInner {
13 idle: VecDeque<Connect>,
14 total: usize,
15 max: usize,
16 txn_total: usize,
18 txn_max: usize,
20}
21
22struct SlotGuard<'a> {
25 mutex: &'a Mutex<PoolInner>,
26 condvar: &'a Condvar,
27 active: bool,
28 for_transaction: bool,
29}
30
31impl<'a> SlotGuard<'a> {
32 fn new(mutex: &'a Mutex<PoolInner>, condvar: &'a Condvar, for_transaction: bool) -> Self {
33 Self {
34 mutex,
35 condvar,
36 active: true,
37 for_transaction,
38 }
39 }
40
41 fn disarm(&mut self) {
43 self.active = false;
44 }
45}
46
47impl Drop for SlotGuard<'_> {
48 fn drop(&mut self) {
49 if self.active {
50 let mut pool = lock_inner(self.mutex);
51 pool.total = pool.total.saturating_sub(1);
52 if self.for_transaction {
53 pool.txn_total = pool.txn_total.saturating_sub(1);
54 }
55 drop(pool);
56 self.condvar.notify_one();
57 }
58 }
59}
60#[derive(Clone)]
61pub struct Pools {
62 pub config: Config,
63 inner: Arc<(Mutex<PoolInner>, Condvar)>,
64}
65fn lock_inner(mutex: &Mutex<PoolInner>) -> MutexGuard<'_, PoolInner> {
66 mutex.lock().unwrap_or_else(PoisonError::into_inner)
67}
68pub struct ConnectionGuard {
69 pool: Pools,
70 conn: Option<Connect>,
71}
72impl ConnectionGuard {
73 pub fn new(pool: Pools) -> Result<Self, PgsqlError> {
74 let conn = pool.get_connect()?;
75 Ok(Self {
76 pool,
77 conn: Some(conn),
78 })
79 }
80 pub fn conn(&mut self) -> &mut Connect {
81 self.conn.as_mut().expect("connection already released")
82 }
83 pub fn discard(&mut self) {
85 if let Some(_conn) = self.conn.take() {
86 let (ref mutex, ref condvar) = *self.pool.inner;
87 let mut pool = lock_inner(mutex);
88 pool.total = pool.total.saturating_sub(1);
89 drop(pool);
90 condvar.notify_one();
91 }
92 }
93}
94impl Drop for ConnectionGuard {
95 fn drop(&mut self) {
96 if let Some(conn) = self.conn.take() {
97 self.pool.release_conn(conn);
98 }
99 }
100}
101impl Pools {
102 pub fn get_guard(&self) -> Result<ConnectionGuard, PgsqlError> {
103 ConnectionGuard::new(self.clone())
104 }
105 pub fn new(config: Config, size: usize) -> Result<Self, PgsqlError> {
106 let init_size = 2.min(size);
107 let mut idle = VecDeque::with_capacity(size);
108 let mut created = 0;
109 for _ in 0..init_size {
110 match Connect::new(config.clone()) {
111 Ok(conn) => {
112 idle.push_back(conn);
113 created += 1;
114 }
115 Err(e) => warn!("初始化连接失败: {e}"),
116 }
117 }
118 let txn_max = (size / 3).max(1);
119 let inner = PoolInner {
120 idle,
121 total: created,
122 max: size,
123 txn_total: 0,
124 txn_max,
125 };
126
127 Ok(Self {
128 config,
129 inner: Arc::new((Mutex::new(inner), Condvar::new())),
130 })
131 }
132
133 fn acquire_connect(&self, for_transaction: bool) -> Result<Connect, PgsqlError> {
135 let mut attempts = 0;
136 let (ref mutex, ref condvar) = *self.inner;
137 let label = if for_transaction { "事务" } else { "" };
138 #[cfg(not(test))]
139 const BASE_SLEEP_MS: u64 = 200;
140 #[cfg(test)]
141 const BASE_SLEEP_MS: u64 = 1;
142 #[cfg(not(test))]
143 const MAX_SLEEP_MS: u64 = 2000;
144 #[cfg(test)]
145 const MAX_SLEEP_MS: u64 = 5;
146 #[cfg(not(test))]
147 const WAIT_TIMEOUT: Duration = Duration::from_secs(2);
148 #[cfg(test)]
149 const WAIT_TIMEOUT: Duration = Duration::from_millis(5);
150
151 let timeout_msg = if for_transaction {
152 "无法获取事务连接,重试超时"
153 } else {
154 "无法连接数据库,重试超时"
155 };
156
157 loop {
158 if attempts >= 5 {
159 return Err(PgsqlError::Pool(timeout_msg.into()));
160 }
161
162 let action = {
163 let mut pool = lock_inner(mutex);
164 if for_transaction && pool.txn_total >= pool.txn_max && pool.total >= pool.max {
166 Action::Wait
167 } else if let Some(conn) = pool.idle.pop_front() {
168 if for_transaction {
169 pool.txn_total += 1;
170 }
171 Action::GotConn(Box::new(conn))
172 } else if pool.total < pool.max {
173 pool.total += 1; if for_transaction {
175 pool.txn_total += 1;
176 }
177 Action::Create
178 } else {
179 Action::Wait
180 }
181 };
182
183 match action {
184 Action::GotConn(mut conn) => {
185 if conn.is_valid() {
187 conn.touch();
188 return Ok(*conn);
189 }
190 {
192 let mut pool = lock_inner(mutex);
193 pool.total = pool.total.saturating_sub(1);
194 }
195 warn!(
196 "{}连接失效,尝试重建,当前总连接数量: {}",
197 label,
198 self.total_connections()
199 );
200 let mut guard = SlotGuard::new(mutex, condvar, for_transaction);
202 {
203 let mut pool = lock_inner(mutex);
204 pool.total += 1;
205 }
206 match Connect::new(self.config.clone()) {
207 Ok(new_conn) => {
208 guard.disarm();
209 return Ok(new_conn);
210 }
211 Err(e) => {
212 drop(guard);
214 let sleep_ms = BASE_SLEEP_MS
215 .saturating_mul(1u64 << attempts.min(3))
216 .min(MAX_SLEEP_MS);
217 attempts += 1;
218 error!("{}重建连接失败({}ms后重试): {}", label, sleep_ms, e);
219 thread::sleep(Duration::from_millis(sleep_ms));
220 }
221 }
222 }
223 Action::Create => {
224 let mut guard = SlotGuard::new(mutex, condvar, for_transaction);
226 match Connect::new(self.config.clone()) {
227 Ok(new_conn) => {
228 guard.disarm();
229 return Ok(new_conn);
230 }
231 Err(e) => {
232 drop(guard);
234 let sleep_ms = BASE_SLEEP_MS
235 .saturating_mul(1u64 << attempts.min(3))
236 .min(MAX_SLEEP_MS);
237 attempts += 1;
238 error!("创建{}连接失败({}ms后重试): {}", label, sleep_ms, e);
239 thread::sleep(Duration::from_millis(sleep_ms));
240 }
241 }
242 }
243 Action::Wait => {
244 let pool = lock_inner(mutex);
245 let (_pool, timeout) = condvar
246 .wait_timeout(pool, WAIT_TIMEOUT)
247 .unwrap_or_else(PoisonError::into_inner);
248 drop(_pool);
249 if timeout.timed_out() {
250 attempts += 1;
251 }
252 }
253 }
254 }
255 }
256 pub fn get_connect(&self) -> Result<Connect, PgsqlError> {
257 self.acquire_connect(false)
258 }
259 pub fn get_connect_for_transaction(&self) -> Result<Connect, PgsqlError> {
261 self.acquire_connect(true)
262 }
263 pub fn release_transaction_conn(&self) {
264 let (ref mutex, ref condvar) = *self.inner;
265 let mut pool = lock_inner(mutex);
266 pool.total = pool.total.saturating_sub(1);
267 pool.txn_total = pool.txn_total.saturating_sub(1);
268 drop(pool);
269 condvar.notify_one();
270 }
271 pub fn release_conn(&self, conn: Connect) {
272 let (ref mutex, ref condvar) = *self.inner;
273 if !conn.peer_valid() {
274 let mut pool = lock_inner(mutex);
275 pool.total = pool.total.saturating_sub(1);
276 drop(pool);
277 condvar.notify_one();
278 warn!("释放时检测到坏连接,已丢弃");
279 return;
280 }
281 if conn.idle_elapsed().as_secs() > MAX_IDLE_SECS {
282 let mut pool = lock_inner(mutex);
283 pool.total = pool.total.saturating_sub(1);
284 drop(pool);
285 condvar.notify_one();
286 log::debug!("连接空闲超过{}秒,已丢弃", MAX_IDLE_SECS);
287 return;
288 }
289 let mut pool = lock_inner(mutex);
290 if pool.idle.len() < pool.max {
291 pool.idle.push_back(conn);
292 } else {
293 pool.total = pool.total.saturating_sub(1);
294 warn!("连接池已满,丢弃连接");
295 }
296 drop(pool);
297 condvar.notify_one();
298 }
299 pub fn idle_pool_size(&self) -> usize {
300 let (ref mutex, _) = *self.inner;
301 let pool = lock_inner(mutex);
302 pool.idle.len()
303 }
304 pub fn total_connections(&self) -> usize {
305 let (ref mutex, _) = *self.inner;
306 let pool = lock_inner(mutex);
307 pool.total
308 }
309 pub fn borrowed_connections(&self) -> usize {
310 let (ref mutex, _) = *self.inner;
311 let pool = lock_inner(mutex);
312 pool.total.saturating_sub(pool.idle.len())
313 }
314 #[allow(dead_code)]
315 pub fn _cleanup_idle_connections(&self) {
316 let (ref mutex, _) = *self.inner;
317 let mut pool = lock_inner(mutex);
318 log::debug!("当前连接池中的连接数量(清理前): {}", pool.idle.len());
319 let before = pool.idle.len();
320 pool.idle.retain(|conn| {
321 let peer_ok = conn.peer_valid();
322 let idle_ok = conn.idle_elapsed().as_secs() <= MAX_IDLE_SECS;
323 if !peer_ok {
324 log::debug!("检测到无效连接,已移除");
325 } else if !idle_ok {
326 log::debug!("检测到空闲超时连接,已移除");
327 }
328 peer_ok && idle_ok
329 });
330 let removed = before - pool.idle.len();
331 pool.total = pool.total.saturating_sub(removed);
332 }
333}
334
335enum Action {
337 GotConn(Box<Connect>),
338 Create,
339 Wait,
340}
341
/// Pool tests that run against an in-process mock PostgreSQL-like server.
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Read as IoRead, Write as IoWrite};
    use std::net::TcpListener;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Frames one server message: a one-byte tag, then a big-endian `u32`
    /// length that counts itself (hence `+ 4`) but not the tag, then the payload.
    fn pg_msg(tag: u8, payload: &[u8]) -> Vec<u8> {
        let mut m = Vec::with_capacity(5 + payload.len());
        m.push(tag);
        m.extend(&((payload.len() as u32 + 4).to_be_bytes()));
        m.extend_from_slice(payload);
        m
    }

    /// Frames an `'R'` (authentication) message: the big-endian `auth_type`
    /// code followed by `extra` bytes.
    fn pg_auth(auth_type: u32, extra: &[u8]) -> Vec<u8> {
        let mut body = Vec::new();
        body.extend(&auth_type.to_be_bytes());
        body.extend_from_slice(extra);
        pg_msg(b'R', &body)
    }

    /// Everything the mock sends once the client has answered the auth request.
    ///
    /// In PostgreSQL protocol terms (assumed to be what `Connect` parses):
    /// * AuthenticationOk (`'R'`, code 0);
    /// * ParameterStatus `server_version=15.0` (`'S'`);
    /// * BackendKeyData with process id 1 and secret 2 (`'K'`);
    /// * ReadyForQuery, idle (`'Z'`, `I`).
    fn post_auth_ok() -> Vec<u8> {
        let mut v = Vec::new();
        v.extend(pg_auth(0, &[]));
        v.extend(pg_msg(b'S', b"server_version\x0015.0\x00"));
        let mut k = Vec::new();
        k.extend(&1u32.to_be_bytes());
        k.extend(&2u32.to_be_bytes());
        v.extend(pg_msg(b'K', &k));
        v.extend(pg_msg(b'Z', b"I"));
        v
    }

    /// Canned reply sent for every client message after startup: a one-row,
    /// one-column result whose value is the text `1`.
    ///
    /// In PostgreSQL protocol terms:
    /// * ParseComplete (`'1'`) and BindComplete (`'2'`);
    /// * RowDescription (`'T'`) for a column `c` of type oid 23, size 4,
    ///   typmod -1, text format;
    /// * DataRow (`'D'`);
    /// * CommandComplete `SELECT 1` (`'C'`);
    /// * ReadyForQuery (`'Z'`).
    fn simple_query_response() -> Vec<u8> {
        let mut r = Vec::new();
        r.extend(pg_msg(b'1', &[]));
        r.extend(pg_msg(b'2', &[]));
        // RowDescription body: field count, then name, table oid, column
        // number, type oid, type size, type modifier, format code.
        let mut rd = Vec::new();
        rd.extend(&1u16.to_be_bytes());
        rd.extend(b"c\x00");
        rd.extend(&0u32.to_be_bytes());
        rd.extend(&1u16.to_be_bytes());
        rd.extend(&23u32.to_be_bytes());
        rd.extend(&4i16.to_be_bytes());
        rd.extend(&(-1i32).to_be_bytes());
        rd.extend(&0u16.to_be_bytes());
        r.extend(pg_msg(b'T', &rd));
        // DataRow body: column count, then a length-prefixed value "1".
        let mut dr = Vec::new();
        dr.extend(&1u16.to_be_bytes());
        dr.extend(&1u32.to_be_bytes());
        dr.push(b'1');
        r.extend(pg_msg(b'D', &dr));
        r.extend(pg_msg(b'C', b"SELECT 1\x00"));
        r.extend(pg_msg(b'Z', b"I"));
        r
    }

    /// Starts the mock server on an ephemeral loopback port and returns that port.
    ///
    /// Each accepted socket gets its own thread, which runs these steps:
    /// 1. Read one message (startup) and reply with an auth request (code 3).
    /// 2. Read one more message (its content is ignored) and send `post_auth_ok()`.
    /// 3. Answer every further read with `simple_query_response()`, until the
    ///    peer closes, a read fails or times out (5 s), or `stop` is set.
    ///
    /// Each step assumes a client message arrives in a single `read()`.
    fn spawn_multi_server(stop: Arc<AtomicBool>) -> u16 {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        thread::spawn(move || {
            // Non-blocking accept so the loop can poll `stop` every few ms.
            listener.set_nonblocking(true).unwrap();
            while !stop.load(Ordering::Relaxed) {
                match listener.accept() {
                    Ok((s, _)) => {
                        s.set_nonblocking(false).ok();
                        let stop2 = stop.clone();
                        thread::spawn(move || {
                            s.set_read_timeout(Some(Duration::from_secs(5))).ok();
                            let mut s = s;
                            let mut buf = [0u8; 4096];
                            // Startup message.
                            if s.read(&mut buf).unwrap_or(0) == 0 {
                                return;
                            }
                            let _ = s.write_all(&pg_auth(3, &[]));
                            // Client's auth reply (accepted unconditionally).
                            if s.read(&mut buf).unwrap_or(0) == 0 {
                                return;
                            }
                            let _ = s.write_all(&post_auth_ok());
                            while !stop2.load(Ordering::Relaxed) {
                                match s.read(&mut buf) {
                                    Ok(0) | Err(_) => break,
                                    Ok(_) => {
                                        let _ = s.write_all(&simple_query_response());
                                    }
                                }
                            }
                        });
                    }
                    Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                        thread::sleep(Duration::from_millis(5));
                    }
                    Err(_) => break,
                }
            }
        });
        // Give the accept thread a moment to start. The listener is already
        // bound, so early connects would presumably be queued anyway.
        thread::sleep(Duration::from_millis(50));
        port
    }

    /// Config pointing at 127.0.0.1:`port` with dummy credentials.
    fn mock_config(port: u16) -> Config {
        Config {
            debug: false,
            hostname: "127.0.0.1".into(),
            hostport: port as i32,
            username: "u".into(),
            userpass: "p".into(),
            database: "d".into(),
            charset: "utf8".into(),
            pool_max: 5,
        }
    }

    /// Drives the pool through its main paths against the mock server, and
    /// through its failure paths against an unreachable port.
    #[test]
    fn pools_all_paths() {
        let stop = Arc::new(AtomicBool::new(false));
        let port = spawn_multi_server(stop.clone());
        let cfg = mock_config(port);

        // Construction opens min(2, size) connections eagerly.
        let pools = Pools::new(cfg.clone(), 10).unwrap();
        assert_eq!(pools.total_connections(), 2);
        assert_eq!(pools.idle_pool_size(), 2);
        assert_eq!(pools.borrowed_connections(), 0);

        // Borrowing pops from the idle queue.
        let conn1 = pools.get_connect().unwrap();
        assert_eq!(pools.idle_pool_size(), 1);
        assert!(pools.borrowed_connections() > 0);

        // Releasing a healthy connection puts it back in the idle queue.
        let idle_before = pools.idle_pool_size();
        pools.release_conn(conn1);
        assert!(pools.idle_pool_size() > idle_before);

        // Dropping a Connect directly (no release_conn) never returns it; the
        // pool keeps counting it in `total`.
        let conn2 = pools.get_connect().unwrap();
        drop(conn2);

        // A closed connection (presumably failing peer_valid()) is discarded
        // on release instead of being re-queued.
        let mut conn3 = pools.get_connect().unwrap();
        let total_before = pools.total_connections();
        conn3._close();
        pools.release_conn(conn3);
        assert!(pools.total_connections() <= total_before);

        // Smoke test only; the result is not asserted.
        pools._cleanup_idle_connections();

        // ConnectionGuard: query through it; on drop it re-pools the connection.
        {
            let mut guard = pools.get_guard().unwrap();
            let qr = guard.conn().query("SELECT 1");
            assert!(qr.is_ok());
        }
        assert!(pools.idle_pool_size() > 0);

        // Transaction accounting: release_transaction_conn frees one slot.
        let pools2 = Pools::new(cfg.clone(), 10).unwrap();
        let txn = pools2.get_connect_for_transaction().unwrap();
        let total_before = pools2.total_connections();
        pools2.release_transaction_conn();
        assert_eq!(pools2.total_connections(), total_before - 1);
        drop(txn);

        // Exhausted size-1 pool: a second borrow gives up after its wait attempts.
        let pools3 = Pools::new(cfg.clone(), 1).unwrap();
        let held = pools3.get_connect().unwrap();
        let result = pools3.get_connect();
        assert!(result.is_err());
        drop(held);

        // Unreachable server (port 1, assumed closed): eager init opens nothing...
        let bad_cfg = mock_config(1);
        let pools4 = Pools::new(bad_cfg.clone(), 5).unwrap();
        assert_eq!(pools4.total_connections(), 0);

        // ...a plain borrow fails after its retries...
        let pools5 = Pools::new(bad_cfg.clone(), 5).unwrap();
        let result = pools5.get_connect();
        assert!(result.is_err());

        // ...and so does a transaction borrow.
        let pools6 = Pools::new(bad_cfg.clone(), 5).unwrap();
        let result = pools6.get_connect_for_transaction();
        assert!(result.is_err());

        // A size-1 pool opens exactly one connection eagerly.
        let pools7 = Pools::new(cfg.clone(), 1).unwrap();
        assert_eq!(pools7.total_connections(), 1);

        // Signal the mock server's threads to wind down.
        stop.store(true, Ordering::Relaxed);
    }
}