1use crate::config::Config;
2use crate::connect::Connect;
3use crate::error::PgsqlError;
4use log::{error, warn};
5use std::collections::VecDeque;
6use std::sync::{Arc, Condvar, Mutex, MutexGuard, PoisonError};
7use std::thread;
8use std::time::Duration;
9
/// Idle-time limit, in seconds (five minutes): a connection idle for longer
/// than this is dropped instead of being kept in the pool.
const MAX_IDLE_SECS: u64 = 5 * 60;
12struct PoolInner {
13 idle: VecDeque<Connect>,
14 total: usize,
15 max: usize,
16}
17
18struct SlotGuard<'a> {
21 mutex: &'a Mutex<PoolInner>,
22 condvar: &'a Condvar,
23 active: bool,
24}
25
26impl<'a> SlotGuard<'a> {
27 fn new(mutex: &'a Mutex<PoolInner>, condvar: &'a Condvar) -> Self {
28 Self {
29 mutex,
30 condvar,
31 active: true,
32 }
33 }
34
35 fn disarm(&mut self) {
37 self.active = false;
38 }
39}
40
41impl Drop for SlotGuard<'_> {
42 fn drop(&mut self) {
43 if self.active {
44 let mut pool = lock_inner(self.mutex);
45 pool.total = pool.total.saturating_sub(1);
46 drop(pool);
47 self.condvar.notify_one();
48 }
49 }
50}
51#[derive(Clone)]
52pub struct Pools {
53 pub config: Config,
54 inner: Arc<(Mutex<PoolInner>, Condvar)>,
55}
56fn lock_inner(mutex: &Mutex<PoolInner>) -> MutexGuard<'_, PoolInner> {
57 mutex.lock().unwrap_or_else(PoisonError::into_inner)
58}
59pub struct ConnectionGuard {
60 pool: Pools,
61 conn: Option<Connect>,
62}
63impl ConnectionGuard {
64 pub fn new(pool: Pools) -> Result<Self, PgsqlError> {
65 let conn = pool.get_connect()?;
66 Ok(Self {
67 pool,
68 conn: Some(conn),
69 })
70 }
71 pub fn conn(&mut self) -> &mut Connect {
72 self.conn.as_mut().expect("connection already released")
73 }
74 pub fn discard(&mut self) {
76 if let Some(_conn) = self.conn.take() {
77 let (ref mutex, ref condvar) = *self.pool.inner;
78 let mut pool = lock_inner(mutex);
79 pool.total = pool.total.saturating_sub(1);
80 drop(pool);
81 condvar.notify_one();
82 }
83 }
84}
85impl Drop for ConnectionGuard {
86 fn drop(&mut self) {
87 if let Some(conn) = self.conn.take() {
88 self.pool.release_conn(conn);
89 }
90 }
91}
92impl Pools {
93 pub fn get_guard(&self) -> Result<ConnectionGuard, PgsqlError> {
94 ConnectionGuard::new(self.clone())
95 }
96 pub fn new(config: Config, size: usize) -> Result<Self, PgsqlError> {
97 let init_size = 2.min(size);
98 let mut idle = VecDeque::with_capacity(size);
99 let mut created = 0;
100 for _ in 0..init_size {
101 match Connect::new(config.clone()) {
102 Ok(conn) => {
103 idle.push_back(conn);
104 created += 1;
105 }
106 Err(e) => warn!("初始化连接失败: {e}"),
107 }
108 }
109 let inner = PoolInner {
110 idle,
111 total: created,
112 max: size,
113 };
114
115 Ok(Self {
116 config,
117 inner: Arc::new((Mutex::new(inner), Condvar::new())),
118 })
119 }
120
121 fn acquire_connect(&self, for_transaction: bool) -> Result<Connect, PgsqlError> {
123 let mut attempts = 0;
124 let (ref mutex, ref condvar) = *self.inner;
125 let label = if for_transaction { "事务" } else { "" };
126 #[cfg(not(test))]
127 const BASE_SLEEP_MS: u64 = 1000;
128 #[cfg(test)]
129 const BASE_SLEEP_MS: u64 = 1;
130
131 #[cfg(not(test))]
132 const MAX_SLEEP_MS: u64 = 8000;
133 #[cfg(test)]
134 const MAX_SLEEP_MS: u64 = 5;
135
136 #[cfg(not(test))]
137 const WAIT_TIMEOUT: Duration = Duration::from_secs(3);
138 #[cfg(test)]
139 const WAIT_TIMEOUT: Duration = Duration::from_millis(5);
140
141 let timeout_msg = if for_transaction {
142 "无法获取事务连接,重试超时"
143 } else {
144 "无法连接数据库,重试超时"
145 };
146
147 loop {
148 if attempts >= 20 {
149 return Err(PgsqlError::Pool(timeout_msg.into()));
150 }
151
152 let action = {
153 let mut pool = lock_inner(mutex);
154 if let Some(conn) = pool.idle.pop_front() {
155 Action::GotConn(Box::new(conn))
156 } else if pool.total < pool.max {
157 pool.total += 1; Action::Create
159 } else {
160 Action::Wait
161 }
162 };
163
164 match action {
165 Action::GotConn(mut conn) => {
166 if conn.is_valid() {
168 conn.touch();
169 return Ok(*conn);
170 }
171 {
173 let mut pool = lock_inner(mutex);
174 pool.total = pool.total.saturating_sub(1);
175 }
176 warn!(
177 "{}连接失效,尝试重建,当前总连接数量: {}",
178 label,
179 self.total_connections()
180 );
181 let mut guard = SlotGuard::new(mutex, condvar);
183 {
184 let mut pool = lock_inner(mutex);
185 pool.total += 1;
186 }
187 match Connect::new(self.config.clone()) {
188 Ok(new_conn) => {
189 guard.disarm();
190 return Ok(new_conn);
191 }
192 Err(e) => {
193 drop(guard);
195 let sleep_ms = BASE_SLEEP_MS
196 .saturating_mul(1u64 << attempts.min(3))
197 .min(MAX_SLEEP_MS);
198 attempts += 1;
199 error!("{}重建连接失败({}ms后重试): {}", label, sleep_ms, e);
200 thread::sleep(Duration::from_millis(sleep_ms));
201 }
202 }
203 }
204 Action::Create => {
205 let mut guard = SlotGuard::new(mutex, condvar);
207 match Connect::new(self.config.clone()) {
208 Ok(new_conn) => {
209 guard.disarm();
210 return Ok(new_conn);
211 }
212 Err(e) => {
213 drop(guard);
215 let sleep_ms = BASE_SLEEP_MS
216 .saturating_mul(1u64 << attempts.min(3))
217 .min(MAX_SLEEP_MS);
218 attempts += 1;
219 error!("创建{}连接失败({}ms后重试): {}", label, sleep_ms, e);
220 thread::sleep(Duration::from_millis(sleep_ms));
221 }
222 }
223 }
224 Action::Wait => {
225 let pool = lock_inner(mutex);
226 let (_pool, timeout) = condvar
227 .wait_timeout(pool, WAIT_TIMEOUT)
228 .unwrap_or_else(PoisonError::into_inner);
229 drop(_pool);
230 if timeout.timed_out() {
231 attempts += 1;
232 }
233 }
234 }
235 }
236 }
237 pub fn get_connect(&self) -> Result<Connect, PgsqlError> {
238 self.acquire_connect(false)
239 }
240 pub fn get_connect_for_transaction(&self) -> Result<Connect, PgsqlError> {
242 self.acquire_connect(true)
243 }
244 pub fn release_transaction_conn(&self) {
245 let (ref mutex, ref condvar) = *self.inner;
246 let mut pool = lock_inner(mutex);
247 pool.total = pool.total.saturating_sub(1);
248 drop(pool);
249 condvar.notify_one();
250 }
251 pub fn release_conn(&self, conn: Connect) {
252 let (ref mutex, ref condvar) = *self.inner;
253 if !conn.peer_valid() {
254 let mut pool = lock_inner(mutex);
255 pool.total = pool.total.saturating_sub(1);
256 drop(pool);
257 condvar.notify_one();
258 warn!("释放时检测到坏连接,已丢弃");
259 return;
260 }
261 if conn.idle_elapsed().as_secs() > MAX_IDLE_SECS {
262 let mut pool = lock_inner(mutex);
263 pool.total = pool.total.saturating_sub(1);
264 drop(pool);
265 condvar.notify_one();
266 log::debug!("连接空闲超过{}秒,已丢弃", MAX_IDLE_SECS);
267 return;
268 }
269 let mut pool = lock_inner(mutex);
270 if pool.idle.len() < pool.max {
271 pool.idle.push_back(conn);
272 } else {
273 pool.total = pool.total.saturating_sub(1);
274 warn!("连接池已满,丢弃连接");
275 }
276 drop(pool);
277 condvar.notify_one();
278 }
279 pub fn idle_pool_size(&self) -> usize {
280 let (ref mutex, _) = *self.inner;
281 let pool = lock_inner(mutex);
282 pool.idle.len()
283 }
284 pub fn total_connections(&self) -> usize {
285 let (ref mutex, _) = *self.inner;
286 let pool = lock_inner(mutex);
287 pool.total
288 }
289 pub fn borrowed_connections(&self) -> usize {
290 let (ref mutex, _) = *self.inner;
291 let pool = lock_inner(mutex);
292 pool.total.saturating_sub(pool.idle.len())
293 }
294 #[allow(dead_code)]
295 pub fn _cleanup_idle_connections(&self) {
296 let (ref mutex, _) = *self.inner;
297 let mut pool = lock_inner(mutex);
298 log::debug!("当前连接池中的连接数量(清理前): {}", pool.idle.len());
299 let before = pool.idle.len();
300 pool.idle.retain(|conn| {
301 let peer_ok = conn.peer_valid();
302 let idle_ok = conn.idle_elapsed().as_secs() <= MAX_IDLE_SECS;
303 if !peer_ok {
304 log::debug!("检测到无效连接,已移除");
305 } else if !idle_ok {
306 log::debug!("检测到空闲超时连接,已移除");
307 }
308 peer_ok && idle_ok
309 });
310 let removed = before - pool.idle.len();
311 pool.total = pool.total.saturating_sub(removed);
312 }
313}
314
315enum Action {
317 GotConn(Box<Connect>),
318 Create,
319 Wait,
320}
321
#[cfg(test)]
mod tests {
    // Scenario tests for `Pools`, run against an in-process mock server that
    // replies with hand-built PostgreSQL backend messages.
    use super::*;
    use std::io::{Read as IoRead, Write as IoWrite};
    use std::net::TcpListener;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Frames one backend message: a 1-byte `tag`, then a big-endian `u32`
    /// length that counts itself (4 bytes) plus the payload, then `payload`.
    fn pg_msg(tag: u8, payload: &[u8]) -> Vec<u8> {
        let mut m = Vec::with_capacity(5 + payload.len());
        m.push(tag);
        m.extend(&((payload.len() as u32 + 4).to_be_bytes()));
        m.extend_from_slice(payload);
        m
    }

    /// Builds an `'R'` (Authentication) message: big-endian `auth_type`
    /// followed by `extra`.
    fn pg_auth(auth_type: u32, extra: &[u8]) -> Vec<u8> {
        let mut body = Vec::new();
        body.extend(&auth_type.to_be_bytes());
        body.extend_from_slice(extra);
        pg_msg(b'R', &body)
    }

    /// Messages the mock sends once the password has arrived, in this order:
    /// - AuthenticationOk (auth type 0);
    /// - a `server_version = 15.0` ParameterStatus (`'S'`);
    /// - BackendKeyData (`'K'`, process id 1, secret key 2);
    /// - ReadyForQuery (`'Z'`, status `I` = idle).
    fn post_auth_ok() -> Vec<u8> {
        let mut v = Vec::new();
        v.extend(pg_auth(0, &[]));
        v.extend(pg_msg(b'S', b"server_version\x0015.0\x00"));
        let mut k = Vec::new();
        k.extend(&1u32.to_be_bytes());
        k.extend(&2u32.to_be_bytes());
        v.extend(pg_msg(b'K', &k));
        v.extend(pg_msg(b'Z', b"I"));
        v
    }

    /// Canned reply sent for every client message after the handshake, in
    /// this order:
    /// - ParseComplete (`'1'`) and BindComplete (`'2'`);
    /// - a RowDescription (`'T'`) with one column `c` of type OID 23;
    /// - one DataRow (`'D'`) holding the text `1`;
    /// - CommandComplete (`SELECT 1`) and ReadyForQuery.
    ///
    /// NOTE(review): despite the name, this is an extended-query style reply
    /// (Parse/Bind completions); presumably that is the flow `Connect::query`
    /// uses — confirm.
    fn simple_query_response() -> Vec<u8> {
        let mut r = Vec::new();
        r.extend(pg_msg(b'1', &[]));
        r.extend(pg_msg(b'2', &[]));
        // RowDescription body, in this order:
        // - field count (1);
        // - name "c\0";
        // - table OID 0;
        // - column attribute number 1;
        // - type OID 23;
        // - type size 4;
        // - type modifier -1;
        // - format code 0 (text).
        let mut rd = Vec::new();
        rd.extend(&1u16.to_be_bytes());
        rd.extend(b"c\x00");
        rd.extend(&0u32.to_be_bytes());
        rd.extend(&1u16.to_be_bytes());
        rd.extend(&23u32.to_be_bytes());
        rd.extend(&4i16.to_be_bytes());
        rd.extend(&(-1i32).to_be_bytes());
        rd.extend(&0u16.to_be_bytes());
        r.extend(pg_msg(b'T', &rd));
        // DataRow body: column count 1, value length 1, value b"1".
        let mut dr = Vec::new();
        dr.extend(&1u16.to_be_bytes());
        dr.extend(&1u32.to_be_bytes());
        dr.push(b'1');
        r.extend(pg_msg(b'D', &dr));
        r.extend(pg_msg(b'C', b"SELECT 1\x00"));
        r.extend(pg_msg(b'Z', b"I"));
        r
    }

    /// Starts a mock server on an ephemeral 127.0.0.1 port and returns the
    /// port.
    ///
    /// The accept loop is non-blocking and polls `stop` every 5 ms. Each
    /// client gets its own thread, which:
    /// - reads the first message and requests a password (auth type 3);
    /// - reads the reply and sends `post_auth_ok()`;
    /// - answers every further read with `simple_query_response()`, until
    ///   EOF, an error, or `stop`.
    ///
    /// Assumes each client message arrives in a single `read`. Sleeps 50 ms
    /// before returning so the listener thread is running.
    fn spawn_multi_server(stop: Arc<AtomicBool>) -> u16 {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        thread::spawn(move || {
            listener.set_nonblocking(true).unwrap();
            while !stop.load(Ordering::Relaxed) {
                match listener.accept() {
                    Ok((s, _)) => {
                        // Accepted sockets are switched back to blocking reads.
                        s.set_nonblocking(false).ok();
                        let stop2 = stop.clone();
                        thread::spawn(move || {
                            s.set_read_timeout(Some(Duration::from_secs(5))).ok();
                            let mut s = s;
                            let mut buf = [0u8; 4096];
                            // First client message (startup).
                            if s.read(&mut buf).unwrap_or(0) == 0 {
                                return;
                            }
                            let _ = s.write_all(&pg_auth(3, &[]));
                            // Client's answer to the password request.
                            if s.read(&mut buf).unwrap_or(0) == 0 {
                                return;
                            }
                            let _ = s.write_all(&post_auth_ok());
                            while !stop2.load(Ordering::Relaxed) {
                                match s.read(&mut buf) {
                                    Ok(0) | Err(_) => break,
                                    Ok(_) => {
                                        let _ = s.write_all(&simple_query_response());
                                    }
                                }
                            }
                        });
                    }
                    Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                        thread::sleep(Duration::from_millis(5));
                    }
                    Err(_) => break,
                }
            }
        });
        thread::sleep(Duration::from_millis(50));
        port
    }

    /// `Config` pointing at 127.0.0.1:`port` with dummy credentials.
    fn mock_config(port: u16) -> Config {
        Config {
            debug: false,
            hostname: "127.0.0.1".into(),
            hostport: port as i32,
            username: "u".into(),
            userpass: "p".into(),
            database: "d".into(),
            charset: "utf8".into(),
            pool_max: 5,
        }
    }

    /// Walks the pool through borrow/release, dead-connection release, cleanup,
    /// guard use, transaction release, exhaustion, and unreachable servers.
    #[test]
    fn pools_all_paths() {
        let stop = Arc::new(AtomicBool::new(false));
        let port = spawn_multi_server(stop.clone());
        let cfg = mock_config(port);

        // Construction opens two connections eagerly.
        let pools = Pools::new(cfg.clone(), 10).unwrap();
        assert_eq!(pools.total_connections(), 2);
        assert_eq!(pools.idle_pool_size(), 2);
        assert_eq!(pools.borrowed_connections(), 0);

        // Borrowing takes one from the idle queue.
        let conn1 = pools.get_connect().unwrap();
        assert_eq!(pools.idle_pool_size(), 1);
        assert!(pools.borrowed_connections() > 0);

        // Releasing puts it back.
        let idle_before = pools.idle_pool_size();
        pools.release_conn(conn1);
        assert!(pools.idle_pool_size() > idle_before);

        // Dropped without release_conn: the pool keeps counting its slot.
        let conn2 = pools.get_connect().unwrap();
        drop(conn2);

        // A closed connection is discarded on release rather than re-queued
        // (presumably `_close` makes the `peer_valid` check fail).
        let mut conn3 = pools.get_connect().unwrap();
        let total_before = pools.total_connections();
        conn3._close();
        pools.release_conn(conn3);
        assert!(pools.total_connections() <= total_before);

        pools._cleanup_idle_connections();

        // A guard runs a query and returns its connection on drop.
        {
            let mut guard = pools.get_guard().unwrap();
            let qr = guard.conn().query("SELECT 1");
            assert!(qr.is_ok());
        }
        assert!(pools.idle_pool_size() > 0);

        // release_transaction_conn frees exactly one slot.
        let pools2 = Pools::new(cfg.clone(), 10).unwrap();
        let txn = pools2.get_connect_for_transaction().unwrap();
        let total_before = pools2.total_connections();
        pools2.release_transaction_conn();
        assert_eq!(pools2.total_connections(), total_before - 1);
        drop(txn);

        // With max = 1 and that connection held, a second borrow times out.
        let pools3 = Pools::new(cfg.clone(), 1).unwrap();
        let held = pools3.get_connect().unwrap();
        let result = pools3.get_connect();
        assert!(result.is_err());
        drop(held);

        // Port 1 is assumed to have no listener, so every connect fails.
        let bad_cfg = mock_config(1);
        let pools4 = Pools::new(bad_cfg.clone(), 5).unwrap();
        assert_eq!(pools4.total_connections(), 0);

        let pools5 = Pools::new(bad_cfg.clone(), 5).unwrap();
        let result = pools5.get_connect();
        assert!(result.is_err());

        let pools6 = Pools::new(bad_cfg.clone(), 5).unwrap();
        let result = pools6.get_connect_for_transaction();
        assert!(result.is_err());

        // Eager connections are capped by the pool size.
        let pools7 = Pools::new(cfg.clone(), 1).unwrap();
        assert_eq!(pools7.total_connections(), 1);

        stop.store(true, Ordering::Relaxed);
    }
}