1use super::{PgConnection, PgError, PgResult};
7use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10use tokio::sync::{Mutex, Semaphore};
11
/// Settings used to open and size a pool of PostgreSQL connections.
///
/// Start from [`PoolConfig::new`] and adjust with the chainable setters; each
/// setter takes the config by value and hands back the modified copy.
#[derive(Clone)]
pub struct PoolConfig {
    /// Server host name or address.
    pub host: String,
    /// Server TCP port.
    pub port: u16,
    /// Role used to log in.
    pub user: String,
    /// Database to connect to.
    pub database: String,
    /// Login password; `None` (the default) connects without one.
    pub password: Option<String>,
    /// Maximum number of connections the pool hands out at once (default: 10).
    pub max_connections: usize,
    /// Number of connections opened eagerly when the pool is created (default: 1).
    pub min_connections: usize,
    /// Idle connections older than this are discarded instead of reused (default: 600 s).
    pub idle_timeout: Duration,
    /// How long `acquire` waits for a free slot before failing (default: 30 s).
    pub acquire_timeout: Duration,
    /// Time budget for opening a new server connection (default: 10 s).
    /// NOTE(review): confirm the pool's connect path actually applies it.
    pub connect_timeout: Duration,
    /// Optional cap on a connection's total age; `None` (the default) means unlimited.
    pub max_lifetime: Option<Duration>,
    /// Whether idle connections should be health-checked before being handed
    /// out (default: `false`).
    /// NOTE(review): confirm `acquire` honours this flag.
    pub test_on_acquire: bool,
}

impl PoolConfig {
    /// Builds a config for `user@host:port/database` with every pool setting at
    /// its default value (no password, 10 max / 1 min connections, 600 s idle,
    /// 30 s acquire and 10 s connect timeouts, no max lifetime, no testing).
    pub fn new(host: &str, port: u16, user: &str, database: &str) -> Self {
        Self {
            host: host.to_owned(),
            port,
            user: user.to_owned(),
            database: database.to_owned(),
            password: None,
            max_connections: 10,
            min_connections: 1,
            idle_timeout: Duration::from_secs(10 * 60),
            acquire_timeout: Duration::from_secs(30),
            connect_timeout: Duration::from_secs(10),
            max_lifetime: None,
            test_on_acquire: false,
        }
    }

    /// Sets the login password.
    pub fn password(self, password: &str) -> Self {
        Self {
            password: Some(password.to_owned()),
            ..self
        }
    }

    /// Sets the maximum number of simultaneously checked-out connections.
    pub fn max_connections(self, max: usize) -> Self {
        Self {
            max_connections: max,
            ..self
        }
    }

    /// Sets how many connections are opened up front.
    pub fn min_connections(self, min: usize) -> Self {
        Self {
            min_connections: min,
            ..self
        }
    }

    /// Sets how long a connection may sit idle before it is discarded.
    pub fn idle_timeout(self, timeout: Duration) -> Self {
        Self {
            idle_timeout: timeout,
            ..self
        }
    }

    /// Sets how long `acquire` may wait for a free slot.
    pub fn acquire_timeout(self, timeout: Duration) -> Self {
        Self {
            acquire_timeout: timeout,
            ..self
        }
    }

    /// Sets the time budget for opening a new connection.
    pub fn connect_timeout(self, timeout: Duration) -> Self {
        Self {
            connect_timeout: timeout,
            ..self
        }
    }

    /// Caps the total age of a pooled connection.
    pub fn max_lifetime(self, lifetime: Duration) -> Self {
        Self {
            max_lifetime: Some(lifetime),
            ..self
        }
    }

    /// Enables or disables health-checking idle connections on acquire.
    pub fn test_on_acquire(self, enabled: bool) -> Self {
        Self {
            test_on_acquire: enabled,
            ..self
        }
    }
}
94
/// Point-in-time snapshot of pool usage, as produced by the pool's `stats` method.
#[derive(Clone, Debug, Default)]
pub struct PoolStats {
    /// Connections currently checked out by callers.
    pub active: usize,
    /// Connections sitting in the idle list, ready for reuse.
    pub idle: usize,
    /// Slots taken from the pool's semaphore that are not counted in `active`
    /// (presumably connections in the middle of being opened or returned).
    pub pending: usize,
    /// The configured `max_connections`.
    pub max_size: usize,
    /// Number of server connections opened since the pool was created.
    pub total_created: usize,
}
105
106struct PooledConn {
108 conn: PgConnection,
109 created_at: Instant,
110 last_used: Instant,
111}
112
113pub struct PooledConnection {
119 conn: Option<PgConnection>,
120 pool: Arc<PgPoolInner>,
121 rls_dirty: bool,
122}
123
124impl PooledConnection {
125 pub fn get_mut(&mut self) -> &mut PgConnection {
127 self.conn
128 .as_mut()
129 .expect("Connection should always be present")
130 }
131
132 pub fn cancel_token(&self) -> crate::driver::CancelToken {
134 let (process_id, secret_key) = self.conn.as_ref().expect("Connection missing").get_cancel_key();
135 crate::driver::CancelToken {
136 host: self.pool.config.host.clone(),
137 port: self.pool.config.port,
138 process_id,
139 secret_key,
140 }
141 }
142
143 pub async fn fetch_all_uncached(&mut self, cmd: &qail_core::ast::Qail) -> PgResult<Vec<super::PgRow>> {
146 use crate::protocol::AstEncoder;
147 use super::ColumnInfo;
148
149 let conn = self.conn.as_mut().expect("Connection should always be present");
150
151 let wire_bytes = AstEncoder::encode_cmd_reuse(
152 cmd,
153 &mut conn.sql_buf,
154 &mut conn.params_buf,
155 );
156
157 conn.send_bytes(&wire_bytes).await?;
158
159 let mut rows: Vec<super::PgRow> = Vec::new();
160 let mut column_info: Option<Arc<ColumnInfo>> = None;
161 let mut error: Option<PgError> = None;
162
163 loop {
164 let msg = conn.recv().await?;
165 match msg {
166 crate::protocol::BackendMessage::ParseComplete
167 | crate::protocol::BackendMessage::BindComplete => {}
168 crate::protocol::BackendMessage::RowDescription(fields) => {
169 column_info = Some(Arc::new(ColumnInfo::from_fields(&fields)));
170 }
171 crate::protocol::BackendMessage::DataRow(data) => {
172 if error.is_none() {
173 rows.push(super::PgRow {
174 columns: data,
175 column_info: column_info.clone(),
176 });
177 }
178 }
179 crate::protocol::BackendMessage::CommandComplete(_) => {}
180 crate::protocol::BackendMessage::ReadyForQuery(_) => {
181 if let Some(err) = error {
182 return Err(err);
183 }
184 return Ok(rows);
185 }
186 crate::protocol::BackendMessage::ErrorResponse(err) => {
187 if error.is_none() {
188 error = Some(PgError::Query(err.message));
189 }
190 }
191 _ => {}
192 }
193 }
194 }
195}
196
197impl Drop for PooledConnection {
198 fn drop(&mut self) {
199 if let Some(conn) = self.conn.take() {
200 let pool = self.pool.clone();
201 let rls_dirty = self.rls_dirty;
202 tokio::spawn(async move {
203 if rls_dirty {
204 let mut conn = conn;
208 let _ = conn.execute_simple(super::rls::reset_sql()).await;
209 pool.return_connection(conn).await;
210 } else {
211 pool.return_connection(conn).await;
212 }
213 });
214 }
215 }
216}
217
218impl std::ops::Deref for PooledConnection {
219 type Target = PgConnection;
220
221 fn deref(&self) -> &Self::Target {
222 self.conn
223 .as_ref()
224 .expect("Connection should always be present")
225 }
226}
227
228impl std::ops::DerefMut for PooledConnection {
229 fn deref_mut(&mut self) -> &mut Self::Target {
230 self.conn
231 .as_mut()
232 .expect("Connection should always be present")
233 }
234}
235
236struct PgPoolInner {
238 config: PoolConfig,
239 connections: Mutex<Vec<PooledConn>>,
240 semaphore: Semaphore,
241 closed: AtomicBool,
242 active_count: AtomicUsize,
243 total_created: AtomicUsize,
244}
245
246impl PgPoolInner {
247 async fn return_connection(&self, conn: PgConnection) {
248
249 self.active_count.fetch_sub(1, Ordering::Relaxed);
250
251
252 if self.closed.load(Ordering::Relaxed) {
253 return;
254 }
255
256 let mut connections = self.connections.lock().await;
257 if connections.len() < self.config.max_connections {
258 connections.push(PooledConn {
259 conn,
260 created_at: Instant::now(),
261 last_used: Instant::now(),
262 });
263 }
264
265 self.semaphore.add_permits(1);
266 }
267
268 async fn get_healthy_connection(&self) -> Option<PgConnection> {
270 let mut connections = self.connections.lock().await;
271
272 while let Some(pooled) = connections.pop() {
273 if pooled.last_used.elapsed() > self.config.idle_timeout {
274 continue;
276 }
277
278 if let Some(max_life) = self.config.max_lifetime
279 && pooled.created_at.elapsed() > max_life
280 {
281 continue;
283 }
284
285 return Some(pooled.conn);
286 }
287
288 None
289 }
290}
291
292#[derive(Clone)]
303pub struct PgPool {
304 inner: Arc<PgPoolInner>,
305}
306
307impl PgPool {
308 pub async fn connect(config: PoolConfig) -> PgResult<Self> {
310 let semaphore = Semaphore::new(config.max_connections);
312
313 let mut initial_connections = Vec::new();
314 for _ in 0..config.min_connections {
315 let conn = Self::create_connection(&config).await?;
316 initial_connections.push(PooledConn {
317 conn,
318 created_at: Instant::now(),
319 last_used: Instant::now(),
320 });
321 }
322
323 let initial_count = initial_connections.len();
324
325 let inner = Arc::new(PgPoolInner {
326 config,
327 connections: Mutex::new(initial_connections),
328 semaphore,
329 closed: AtomicBool::new(false),
330 active_count: AtomicUsize::new(0),
331 total_created: AtomicUsize::new(initial_count),
332 });
333
334 Ok(Self { inner })
335 }
336
337 pub async fn acquire(&self) -> PgResult<PooledConnection> {
339 if self.inner.closed.load(Ordering::Relaxed) {
340 return Err(PgError::Connection("Pool is closed".to_string()));
341 }
342
343 let acquire_timeout = self.inner.config.acquire_timeout;
345 let permit = tokio::time::timeout(acquire_timeout, self.inner.semaphore.acquire())
346 .await
347 .map_err(|_| {
348 PgError::Connection(format!(
349 "Timed out waiting for connection ({}s)",
350 acquire_timeout.as_secs()
351 ))
352 })?
353 .map_err(|_| PgError::Connection("Pool closed".to_string()))?;
354 permit.forget();
355
356 let conn = if let Some(conn) = self.inner.get_healthy_connection().await {
358 conn
359 } else {
360 let conn = Self::create_connection(&self.inner.config).await?;
361 self.inner.total_created.fetch_add(1, Ordering::Relaxed);
362 conn
363 };
364
365
366 self.inner.active_count.fetch_add(1, Ordering::Relaxed);
367
368 Ok(PooledConnection {
369 conn: Some(conn),
370 pool: self.inner.clone(),
371 rls_dirty: false,
372 })
373 }
374
375 pub async fn acquire_with_rls(
391 &self,
392 ctx: qail_core::rls::RlsContext,
393 ) -> PgResult<PooledConnection> {
394 let mut conn = self.acquire().await?;
395
396 let sql = super::rls::context_to_sql(&ctx);
398 let pg_conn = conn.get_mut();
399 pg_conn.execute_simple(&sql).await?;
400
401 conn.rls_dirty = true;
403
404 Ok(conn)
405 }
406
407 pub async fn idle_count(&self) -> usize {
409 self.inner.connections.lock().await.len()
410 }
411
412 pub fn active_count(&self) -> usize {
414 self.inner.active_count.load(Ordering::Relaxed)
415 }
416
417 pub fn max_connections(&self) -> usize {
419 self.inner.config.max_connections
420 }
421
422 pub async fn stats(&self) -> PoolStats {
424 let idle = self.inner.connections.lock().await.len();
425 PoolStats {
426 active: self.inner.active_count.load(Ordering::Relaxed),
427 idle,
428 pending: self.inner.config.max_connections
429 - self.inner.semaphore.available_permits()
430 - self.active_count(),
431 max_size: self.inner.config.max_connections,
432 total_created: self.inner.total_created.load(Ordering::Relaxed),
433 }
434 }
435
436 pub fn is_closed(&self) -> bool {
438 self.inner.closed.load(Ordering::Relaxed)
439 }
440
441 pub async fn close(&self) {
443 self.inner.closed.store(true, Ordering::Relaxed);
444
445 let mut connections = self.inner.connections.lock().await;
446 connections.clear();
447 }
448
449 async fn create_connection(config: &PoolConfig) -> PgResult<PgConnection> {
451 match &config.password {
452 Some(password) => {
453 PgConnection::connect_with_password(
454 &config.host,
455 config.port,
456 &config.user,
457 &config.database,
458 Some(password),
459 )
460 .await
461 }
462 None => {
463 PgConnection::connect(&config.host, config.port, &config.user, &config.database)
464 .await
465 }
466 }
467 }
468}
469
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Explicit setters override the corresponding fields.
    #[test]
    fn test_pool_config() {
        let config = PoolConfig::new("localhost", 5432, "user", "testdb")
            .password("secret123")
            .max_connections(20)
            .min_connections(5);

        assert_eq!(config.host, "localhost");
        assert_eq!(config.port, 5432);
        assert_eq!(config.user, "user");
        assert_eq!(config.database, "testdb");
        assert_eq!(config.password, Some("secret123".to_string()));
        assert_eq!(config.max_connections, 20);
        assert_eq!(config.min_connections, 5);
    }

    /// `PoolConfig::new` fills in the documented defaults.
    #[test]
    fn test_pool_config_defaults() {
        let config = PoolConfig::new("localhost", 5432, "user", "testdb");

        assert_eq!(config.password, None);
        assert_eq!(config.max_connections, 10);
        assert_eq!(config.min_connections, 1);
        assert_eq!(config.idle_timeout, Duration::from_secs(600));
        assert_eq!(config.acquire_timeout, Duration::from_secs(30));
        assert_eq!(config.connect_timeout, Duration::from_secs(10));
        assert_eq!(config.max_lifetime, None);
        assert!(!config.test_on_acquire);
    }

    /// The timeout, lifetime and health-check setters are applied.
    #[test]
    fn test_pool_config_timeouts_and_flags() {
        let config = PoolConfig::new("localhost", 5432, "user", "testdb")
            .idle_timeout(Duration::from_secs(5))
            .acquire_timeout(Duration::from_millis(250))
            .connect_timeout(Duration::from_secs(2))
            .max_lifetime(Duration::from_secs(3600))
            .test_on_acquire(true);

        assert_eq!(config.idle_timeout, Duration::from_secs(5));
        assert_eq!(config.acquire_timeout, Duration::from_millis(250));
        assert_eq!(config.connect_timeout, Duration::from_secs(2));
        assert_eq!(config.max_lifetime, Some(Duration::from_secs(3600)));
        assert!(config.test_on_acquire);
    }

    /// A default `PoolStats` reports an empty pool.
    #[test]
    fn test_pool_stats_default_is_empty() {
        let stats = PoolStats::default();

        assert_eq!(stats.active, 0);
        assert_eq!(stats.idle, 0);
        assert_eq!(stats.pending, 0);
        assert_eq!(stats.max_size, 0);
        assert_eq!(stats.total_created, 0);
    }
}