use crate::config::Config;
use crate::connect::Connect;
use crate::error::PgsqlError;
use log::{error, info, warn};
use std::collections::VecDeque;
use std::sync::{Arc, Condvar, Mutex, MutexGuard, PoisonError, Weak};
use std::thread;
use std::time::Duration;
/// Upper bound, in seconds, on `Connect::idle_elapsed()`: connections exceeding it are
/// discarded when released and by the idle sweeps instead of being kept in the pool.
const MAX_IDLE_SECS: u64 = 300;
/// Upper bound, in seconds, on `Connect::age()`: older connections are discarded instead of
/// being handed out, returned to the idle queue, or kept by the idle sweeps.
const MAX_CONN_LIFETIME_SECS: u64 = 1800;
/// Mutable pool state; it only ever lives inside the pool's `Mutex`.
struct PoolInner {
// Connections ready for checkout: taken from the front, returned to the back.
idle: VecDeque<Connect>,
// Slots in use: idle connections plus checked-out ones (and slots reserved while a new
// connection is being opened).
total: usize,
// Upper bound for `total` (the `size` passed to `Pools::new`).
max: usize,
// Slots currently held by transaction checkouts.
txn_total: usize,
// Transaction limit; `acquire_connect` only enforces it while `total >= max`.
txn_max: usize,
}
/// RAII reservation of one pool slot (`total`, plus `txn_total` for transactions) that is
/// held while a new connection is being opened.
///
/// Unless [`SlotGuard::disarm`] is called, dropping the guard gives the slot back and wakes
/// one waiting thread, so a failed (or unwinding) connection attempt cannot leak capacity.
struct SlotGuard<'a> {
    mutex: &'a Mutex<PoolInner>,
    condvar: &'a Condvar,
    /// `true` while the reservation still has to be rolled back on drop.
    active: bool,
    /// Whether the reservation also counts against the transaction quota.
    for_transaction: bool,
}

impl<'a> SlotGuard<'a> {
    /// Arms a guard for a slot the caller has already counted in the pool state.
    fn new(mutex: &'a Mutex<PoolInner>, condvar: &'a Condvar, for_transaction: bool) -> Self {
        SlotGuard {
            mutex,
            condvar,
            active: true,
            for_transaction,
        }
    }

    /// Keeps the reservation: the slot now belongs to a successfully opened connection.
    fn disarm(&mut self) {
        self.active = false;
    }
}

impl Drop for SlotGuard<'_> {
    fn drop(&mut self) {
        if !self.active {
            return;
        }
        {
            let mut state = lock_inner(self.mutex);
            state.total = state.total.saturating_sub(1);
            if self.for_transaction {
                state.txn_total = state.txn_total.saturating_sub(1);
            }
        }
        // Notify only after the lock is released so the woken thread can take it at once.
        self.condvar.notify_one();
    }
}
/// Cloneable handle to a shared connection pool.
///
/// All clones share the same state. The background reaper thread started by `Pools::new`
/// holds only a weak reference, so it exits on its first wake-up after every handle is gone.
#[derive(Clone)]
pub struct Pools {
/// Settings passed (cloned) to `Connect::new` whenever the pool opens a connection.
pub config: Config,
/// Shared pool state plus the condition variable used to wake threads waiting for a slot.
inner: Arc<(Mutex<PoolInner>, Condvar)>,
}
fn lock_inner(mutex: &Mutex<PoolInner>) -> MutexGuard<'_, PoolInner> {
mutex.lock().unwrap_or_else(PoisonError::into_inner)
}
pub struct ConnectionGuard {
pool: Pools,
conn: Option<Connect>,
}
impl ConnectionGuard {
pub fn new(pool: Pools) -> Result<Self, PgsqlError> {
let conn = pool.get_connect()?;
Ok(Self {
pool,
conn: Some(conn),
})
}
pub fn conn(&mut self) -> &mut Connect {
self.conn.as_mut().expect("connection already released")
}
pub fn discard(&mut self) {
if let Some(_conn) = self.conn.take() {
let (ref mutex, ref condvar) = *self.pool.inner;
let mut pool = lock_inner(mutex);
pool.total = pool.total.saturating_sub(1);
drop(pool);
condvar.notify_one();
}
}
}
impl Drop for ConnectionGuard {
fn drop(&mut self) {
if let Some(conn) = self.conn.take() {
self.pool.release_conn(conn);
}
}
}
impl Pools {
pub fn get_guard(&self) -> Result<ConnectionGuard, PgsqlError> {
ConnectionGuard::new(self.clone())
}
/// Creates a pool that holds at most `size` connections opened with `config`.
///
/// - Up to two connections are opened eagerly. Failures are only logged, so the pool may
///   start empty.
/// - The transaction limit is `size / 3`, but never below 1.
/// - A background reaper thread is spawned to periodically drop stale idle connections.
///
/// # Errors
/// This implementation never returns `Err`.
pub fn new(config: Config, size: usize) -> Result<Self, PgsqlError> {
    const EAGER_CONNECTIONS: usize = 2;
    let mut idle = VecDeque::with_capacity(size);
    for _ in 0..size.min(EAGER_CONNECTIONS) {
        match Connect::new(config.clone()) {
            Ok(conn) => idle.push_back(conn),
            Err(e) => warn!("初始化连接失败: {e}"),
        }
    }
    let state = PoolInner {
        // Every successfully opened connection sits in `idle`, so its length is the count.
        total: idle.len(),
        idle,
        max: size,
        txn_total: 0,
        txn_max: std::cmp::max(size / 3, 1),
    };
    let shared = Arc::new((Mutex::new(state), Condvar::new()));
    // The reaper only gets a weak handle so it cannot keep the pool alive.
    let reaper_handle = Arc::downgrade(&shared);
    thread::spawn(move || Self::reaper_loop(reaper_handle));
    Ok(Self {
        config,
        inner: shared,
    })
}
/// Shared checkout logic behind [`Pools::get_connect`] and
/// [`Pools::get_connect_for_transaction`].
///
/// Each iteration inspects the pool under its lock and does one of three things:
/// - takes an idle connection,
/// - reserves a slot and opens a new connection, or
/// - waits on the condition variable for a release.
///
/// An idle connection older than `MAX_CONN_LIFETIME_SECS`, or one that fails
/// `Connect::is_valid`, is discarded (its slot is given back) and the loop retries.
/// A transaction checkout is also held back when `txn_total >= txn_max` while the pool is at
/// `max`.
///
/// # Errors
/// Returns `PgsqlError::Pool` after five failed attempts. An attempt fails when `Connect::new`
/// errors (followed by a capped exponential back-off sleep) or when a wait times out.
fn acquire_connect(&self, for_transaction: bool) -> Result<Connect, PgsqlError> {
    #[cfg(not(test))]
    const BASE_SLEEP_MS: u64 = 200;
    #[cfg(test)]
    const BASE_SLEEP_MS: u64 = 1;
    #[cfg(not(test))]
    const MAX_SLEEP_MS: u64 = 2000;
    #[cfg(test)]
    const MAX_SLEEP_MS: u64 = 5;
    #[cfg(not(test))]
    const WAIT_TIMEOUT: Duration = Duration::from_secs(2);
    #[cfg(test)]
    const WAIT_TIMEOUT: Duration = Duration::from_millis(5);
    // Failed connection attempts plus timed-out waits tolerated before giving up.
    const MAX_ATTEMPTS: u32 = 5;

    let (ref mutex, ref condvar) = *self.inner;
    let label = if for_transaction { "事务" } else { "" };
    let timeout_msg = if for_transaction {
        "无法获取事务连接,重试超时"
    } else {
        "无法连接数据库,重试超时"
    };
    // Gives back the slot (and transaction quota) of an idle connection that turned out to be
    // unusable. This thread immediately loops and re-checks the pool, so nobody is notified.
    let forfeit_slot = || {
        let mut pool = lock_inner(mutex);
        pool.total = pool.total.saturating_sub(1);
        if for_transaction {
            pool.txn_total = pool.txn_total.saturating_sub(1);
        }
    };

    let mut attempts: u32 = 0;
    loop {
        if attempts >= MAX_ATTEMPTS {
            return Err(PgsqlError::Pool(timeout_msg.into()));
        }
        let mut wait_timed_out = false;
        let action = {
            let mut pool = lock_inner(mutex);
            let txn_blocked =
                for_transaction && pool.txn_total >= pool.txn_max && pool.total >= pool.max;
            let ready = if txn_blocked {
                None
            } else if let Some(conn) = pool.idle.pop_front() {
                if for_transaction {
                    pool.txn_total += 1;
                }
                Some(Action::GotConn(Box::new(conn)))
            } else if pool.total < pool.max {
                pool.total += 1;
                if for_transaction {
                    pool.txn_total += 1;
                }
                Some(Action::Create)
            } else {
                None
            };
            match ready {
                Some(action) => action,
                None => {
                    // Wait on the same guard that observed the pool as exhausted.
                    // Dropping the lock and re-locking before waiting would let a
                    // release + notify slip into the gap. This thread would then sleep the
                    // full timeout (and lose an attempt) even though a slot had been freed.
                    let (_pool, wait) = condvar
                        .wait_timeout(pool, WAIT_TIMEOUT)
                        .unwrap_or_else(PoisonError::into_inner);
                    wait_timed_out = wait.timed_out();
                    Action::Wait
                }
            }
        };
        match action {
            Action::GotConn(mut conn) => {
                if conn.age().as_secs() > MAX_CONN_LIFETIME_SECS {
                    forfeit_slot();
                    log::debug!("{}连接存活超过{}秒,已丢弃", label, MAX_CONN_LIFETIME_SECS);
                    continue;
                }
                if conn.is_valid() {
                    conn.touch();
                    return Ok(*conn);
                }
                forfeit_slot();
                warn!(
                    "{}连接失效已丢弃,当前总连接数量: {}",
                    label,
                    self.total_connections()
                );
            }
            Action::Create => {
                // The slot was already counted above; the guard rolls it back on failure.
                let mut slot = SlotGuard::new(mutex, condvar, for_transaction);
                match Connect::new(self.config.clone()) {
                    Ok(new_conn) => {
                        slot.disarm();
                        return Ok(new_conn);
                    }
                    Err(e) => {
                        drop(slot);
                        let sleep_ms = BASE_SLEEP_MS
                            .saturating_mul(1u64 << attempts.min(3))
                            .min(MAX_SLEEP_MS);
                        attempts += 1;
                        error!("创建{}连接失败({}ms后重试): {}", label, sleep_ms, e);
                        thread::sleep(Duration::from_millis(sleep_ms));
                    }
                }
            }
            Action::Wait => {
                // Only a timeout counts as an attempt. A notification (or spurious wake-up)
                // just triggers another look at the pool.
                if wait_timed_out {
                    attempts += 1;
                }
            }
        }
    }
}
/// Checks out a connection for ordinary (non-transaction) use.
///
/// Return it with [`Pools::release_conn`], or use [`Pools::get_guard`] instead. A connection
/// that is simply dropped stays counted in the pool's `total`.
///
/// # Errors
/// Returns `PgsqlError::Pool` when no connection could be obtained within the retry budget.
pub fn get_connect(&self) -> Result<Connect, PgsqlError> {
    let for_transaction = false;
    self.acquire_connect(for_transaction)
}
/// Checks out a connection that also counts against the transaction quota.
///
/// Give the quota back afterwards with one of:
/// - [`Pools::release_transaction_conn_with_conn`], which keeps the connection for reuse;
/// - [`Pools::release_transaction_conn`], which forgets it.
///
/// # Errors
/// Returns `PgsqlError::Pool` when no connection could be obtained within the retry budget.
pub fn get_connect_for_transaction(&self) -> Result<Connect, PgsqlError> {
    let for_transaction = true;
    self.acquire_connect(for_transaction)
}
pub fn release_transaction_conn(&self) {
let (ref mutex, ref condvar) = *self.inner;
let mut pool = lock_inner(mutex);
pool.total = pool.total.saturating_sub(1);
pool.txn_total = pool.txn_total.saturating_sub(1);
drop(pool);
condvar.notify_one();
}
/// Returns a transaction connection to the pool.
///
/// Gives back the transaction-quota slot first, then passes the connection through
/// [`Pools::release_conn`].
pub fn release_transaction_conn_with_conn(&self, conn: Connect) {
    // Scoped so the lock is released before `release_conn` locks it again.
    {
        let mut state = lock_inner(&self.inner.0);
        state.txn_total = state.txn_total.saturating_sub(1);
    }
    self.release_conn(conn);
}
pub fn release_conn(&self, conn: Connect) {
let (ref mutex, ref condvar) = *self.inner;
if !conn.peer_valid() {
let mut pool = lock_inner(mutex);
pool.total = pool.total.saturating_sub(1);
drop(pool);
condvar.notify_one();
warn!("释放时检测到坏连接,已丢弃");
return;
}
if conn.age().as_secs() > MAX_CONN_LIFETIME_SECS {
let mut pool = lock_inner(mutex);
pool.total = pool.total.saturating_sub(1);
drop(pool);
condvar.notify_one();
log::debug!("释放时连接存活超过{}秒,已丢弃", MAX_CONN_LIFETIME_SECS);
return;
}
if conn.idle_elapsed().as_secs() > MAX_IDLE_SECS {
let mut pool = lock_inner(mutex);
pool.total = pool.total.saturating_sub(1);
drop(pool);
condvar.notify_one();
log::debug!("连接空闲超过{}秒,已丢弃", MAX_IDLE_SECS);
return;
}
let mut pool = lock_inner(mutex);
if pool.idle.len() < pool.max {
pool.idle.push_back(conn);
} else {
pool.total = pool.total.saturating_sub(1);
warn!("连接池已满,丢弃连接");
}
drop(pool);
condvar.notify_one();
}
/// Number of connections currently waiting in the idle queue.
pub fn idle_pool_size(&self) -> usize {
    let state = lock_inner(&self.inner.0);
    state.idle.len()
}
/// Number of slots in use: idle connections plus those currently checked out.
pub fn total_connections(&self) -> usize {
    let state = lock_inner(&self.inner.0);
    state.total
}
/// Number of slots not sitting in the idle queue, i.e. `total - idle` (clamped at zero).
pub fn borrowed_connections(&self) -> usize {
    let state = lock_inner(&self.inner.0);
    let idle_count = state.idle.len();
    state.total.saturating_sub(idle_count)
}
pub fn flush_idle(&self) {
let (ref mutex, _) = *self.inner;
let mut pool = lock_inner(mutex);
let flushed = pool.idle.len();
pool.total = pool.total.saturating_sub(flushed);
pool.idle.clear();
if flushed > 0 {
warn!("清空池中 {flushed} 个空闲连接(疑似批量失效)");
}
}
pub fn cleanup_idle_connections(&self) {
let (ref mutex, _) = *self.inner;
let mut pool = lock_inner(mutex);
let before = pool.idle.len();
pool.idle.retain(|conn| {
let peer_ok = conn.peer_valid();
let idle_ok = conn.idle_elapsed().as_secs() <= MAX_IDLE_SECS;
let lifetime_ok = conn.age().as_secs() <= MAX_CONN_LIFETIME_SECS;
if !peer_ok {
log::debug!("检测到无效连接,已移除");
} else if !idle_ok {
log::debug!("检测到空闲超时连接,已移除");
} else if !lifetime_ok {
log::debug!("检测到超过最大生命周期连接,已移除");
}
peer_ok && idle_ok && lifetime_ok
});
let removed = before - pool.idle.len();
pool.total = pool.total.saturating_sub(removed);
if removed > 0 {
log::debug!(
"空闲连接清理完成: 移除 {removed} 个,剩余 {} 个",
pool.idle.len()
);
}
}
fn reaper_loop(weak: Weak<(Mutex<PoolInner>, Condvar)>) {
#[cfg(not(test))]
const INTERVAL: Duration = Duration::from_secs(60);
#[cfg(test)]
const INTERVAL: Duration = Duration::from_millis(50);
loop {
thread::sleep(INTERVAL);
let arc = match weak.upgrade() {
Some(a) => a,
None => {
info!("连接池已释放,回收线程退出");
return;
}
};
let (ref mutex, _) = *arc;
let mut pool = lock_inner(mutex);
let before = pool.idle.len();
pool.idle.retain(|conn| {
conn.peer_valid()
&& conn.idle_elapsed().as_secs() <= MAX_IDLE_SECS
&& conn.age().as_secs() <= MAX_CONN_LIFETIME_SECS
});
let removed = before - pool.idle.len();
pool.total = pool.total.saturating_sub(removed);
if removed > 0 {
info!(
"后台回收: 移除 {removed} 个空闲连接,剩余 {} 个",
pool.idle.len()
);
}
}
}
}
/// Decision taken by one locked inspection of the pool in `Pools::acquire_connect`.
enum Action {
/// An idle connection was taken from the queue (boxed — presumably to keep the enum small).
GotConn(Box<Connect>),
/// A slot was reserved in `total` (and `txn_total` for transactions); the caller opens the
/// connection itself.
Create,
/// Neither an idle connection nor a free slot was usable; the checkout loop waits on the
/// condition variable before retrying.
Wait,
}
// Tests run against an in-process mock server. The mock speaks just enough of the PostgreSQL
// frontend/backend protocol for `Connect` to log in and get a reply to a query.
#[cfg(test)]
mod tests {
use super::*;
use std::io::{Read as IoRead, Write as IoWrite};
use std::net::TcpListener;
use std::sync::atomic::{AtomicBool, Ordering};
/// Frames one backend message. Layout: a 1-byte tag, then a big-endian u32 length that counts
/// itself (hence `+ 4`) but not the tag, then the payload.
fn pg_msg(tag: u8, payload: &[u8]) -> Vec<u8> {
let mut m = Vec::with_capacity(5 + payload.len());
m.push(tag);
m.extend(&((payload.len() as u32 + 4).to_be_bytes()));
m.extend_from_slice(payload);
m
}
/// Builds an Authentication ('R') message carrying the given auth-type code and extra bytes.
fn pg_auth(auth_type: u32, extra: &[u8]) -> Vec<u8> {
let mut body = Vec::new();
body.extend(&auth_type.to_be_bytes());
body.extend_from_slice(extra);
pg_msg(b'R', &body)
}
/// Messages sent once the password arrives, in order:
/// - AuthenticationOk (type 0);
/// - a `server_version` ParameterStatus ('S');
/// - BackendKeyData ('K', values 1 and 2);
/// - ReadyForQuery ('Z', status 'I').
fn post_auth_ok() -> Vec<u8> {
let mut v = Vec::new();
v.extend(pg_auth(0, &[]));
v.extend(pg_msg(b'S', b"server_version\x0015.0\x00"));
let mut k = Vec::new();
k.extend(&1u32.to_be_bytes());
k.extend(&2u32.to_be_bytes());
v.extend(pg_msg(b'K', &k));
v.extend(pg_msg(b'Z', b"I"));
v
}
/// Canned reply written after every successful read once login is done. Messages, in order:
/// - ParseComplete ('1');
/// - BindComplete ('2');
/// - RowDescription ('T') for one column "c" of type OID 23, size 4, text format;
/// - DataRow ('D') with the text value "1";
/// - CommandComplete ("SELECT 1");
/// - ReadyForQuery.
fn simple_query_response() -> Vec<u8> {
let mut r = Vec::new();
r.extend(pg_msg(b'1', &[]));
r.extend(pg_msg(b'2', &[]));
let mut rd = Vec::new();
// Field count, name, table OID, column number, type OID, type size, type modifier, format.
rd.extend(&1u16.to_be_bytes());
rd.extend(b"c\x00");
rd.extend(&0u32.to_be_bytes());
rd.extend(&1u16.to_be_bytes());
rd.extend(&23u32.to_be_bytes());
rd.extend(&4i16.to_be_bytes());
rd.extend(&(-1i32).to_be_bytes());
rd.extend(&0u16.to_be_bytes());
r.extend(pg_msg(b'T', &rd));
let mut dr = Vec::new();
// Column count, then length-prefixed value "1".
dr.extend(&1u16.to_be_bytes());
dr.extend(&1u32.to_be_bytes());
dr.push(b'1');
r.extend(pg_msg(b'D', &dr));
r.extend(pg_msg(b'C', b"SELECT 1\x00"));
r.extend(pg_msg(b'Z', b"I"));
r
}
/// Starts the mock server on an ephemeral 127.0.0.1 port and returns that port.
///
/// Each accepted connection runs on its own thread, which:
/// 1. does one read (startup);
/// 2. requests a cleartext password (auth type 3);
/// 3. does one read (password);
/// 4. sends `post_auth_ok()`;
/// 5. answers every further read with `simple_query_response()` until the peer closes,
///    a read fails, or `stop` is set.
///
/// Reads are not split into individual protocol messages; this is a deliberate shortcut.
fn spawn_multi_server(stop: Arc<AtomicBool>) -> u16 {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let port = listener.local_addr().unwrap().port();
thread::spawn(move || {
// Non-blocking accept lets the loop poll `stop` between connections.
listener.set_nonblocking(true).unwrap();
while !stop.load(Ordering::Relaxed) {
match listener.accept() {
Ok((s, _)) => {
s.set_nonblocking(false).ok();
let stop2 = stop.clone();
thread::spawn(move || {
s.set_read_timeout(Some(Duration::from_secs(5))).ok();
let mut s = s;
let mut buf = [0u8; 4096];
if s.read(&mut buf).unwrap_or(0) == 0 {
return;
}
let _ = s.write_all(&pg_auth(3, &[]));
if s.read(&mut buf).unwrap_or(0) == 0 {
return;
}
let _ = s.write_all(&post_auth_ok());
while !stop2.load(Ordering::Relaxed) {
match s.read(&mut buf) {
Ok(0) | Err(_) => break,
Ok(_) => {
let _ = s.write_all(&simple_query_response());
}
}
}
});
}
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
thread::sleep(Duration::from_millis(5));
}
Err(_) => break,
}
}
});
// Give the accept thread a moment to start before clients connect.
thread::sleep(Duration::from_millis(50));
port
}
/// Builds a `Config` aimed at the mock server on `port`. The mock never inspects the
/// credentials, so dummy values are used.
fn mock_config(port: u16) -> Config {
Config {
debug: false,
hostname: "127.0.0.1".into(),
hostport: port as i32,
username: "u".into(),
userpass: "p".into(),
database: "d".into(),
charset: "utf8".into(),
pool_max: 5,
sslmode: "disable".into(),
}
}
/// Walks the main pool paths in a single test: checkout/release, guards, transaction
/// accounting, exhaustion timeouts, and connection failures.
#[test]
fn pools_all_paths() {
let stop = Arc::new(AtomicBool::new(false));
let port = spawn_multi_server(stop.clone());
let cfg = mock_config(port);
// `Pools::new` eagerly opens min(2, size) connections, all of them idle.
let pools = Pools::new(cfg.clone(), 10).unwrap();
assert_eq!(pools.total_connections(), 2);
assert_eq!(pools.idle_pool_size(), 2);
assert_eq!(pools.borrowed_connections(), 0);
// Checkout pops from the idle queue; release pushes the connection back.
let conn1 = pools.get_connect().unwrap();
assert_eq!(pools.idle_pool_size(), 1);
assert!(pools.borrowed_connections() > 0);
let idle_before = pools.idle_pool_size();
pools.release_conn(conn1);
assert!(pools.idle_pool_size() > idle_before);
// Dropping a `Connect` without `release_conn` bypasses the pool's accounting.
let conn2 = pools.get_connect().unwrap();
drop(conn2);
// A client-closed connection must not grow `total` when released. Assumes `_close`
// makes `peer_valid` fail so the release discards it — TODO confirm.
let mut conn3 = pools.get_connect().unwrap();
let total_before = pools.total_connections();
conn3._close();
pools.release_conn(conn3);
assert!(pools.total_connections() <= total_before);
pools.cleanup_idle_connections();
// A guard runs a query, then returns its connection to the idle queue when dropped.
{
let mut guard = pools.get_guard().unwrap();
let qr = guard.conn().query("SELECT 1");
assert!(qr.is_ok());
}
assert!(pools.idle_pool_size() > 0);
// `release_transaction_conn` frees the transaction connection's slot.
let pools2 = Pools::new(cfg.clone(), 10).unwrap();
let txn = pools2.get_connect_for_transaction().unwrap();
let total_before = pools2.total_connections();
pools2.release_transaction_conn();
assert_eq!(pools2.total_connections(), total_before - 1);
drop(txn);
// With size 1 and the only connection held, a second checkout times out.
let pools3 = Pools::new(cfg.clone(), 1).unwrap();
let held = pools3.get_connect().unwrap();
let result = pools3.get_connect();
assert!(result.is_err());
drop(held);
// Port 1 is expected to refuse connections, so eager creation yields nothing. Plain and
// transaction checkouts then fail after exhausting their connection retries.
let bad_cfg = mock_config(1);
let pools4 = Pools::new(bad_cfg.clone(), 5).unwrap();
assert_eq!(pools4.total_connections(), 0);
let pools5 = Pools::new(bad_cfg.clone(), 5).unwrap();
let result = pools5.get_connect();
assert!(result.is_err());
let pools6 = Pools::new(bad_cfg.clone(), 5).unwrap();
let result = pools6.get_connect_for_transaction();
assert!(result.is_err());
// size 1 caps eager creation at one connection.
let pools7 = Pools::new(cfg.clone(), 1).unwrap();
assert_eq!(pools7.total_connections(), 1);
stop.store(true, Ordering::Relaxed);
}
}