1use std::sync::Arc;
15
16use deadpool_postgres::{Config, ManagerConfig, RecyclingMethod, Runtime};
17use tokio_postgres::NoTls;
18use tokio_postgres::types::ToSql;
19
20use crate::error::{BsqlError, BsqlResult, ConnectError};
21use crate::singleflight::{FlightStatus, Singleflight, sql_key};
22use crate::stream::QueryStream;
23use crate::transaction::Transaction;
24
25pub struct Pool {
30 primary: deadpool_postgres::Pool,
31 replicas: Vec<deadpool_postgres::Pool>,
34 replica_idx: std::sync::atomic::AtomicUsize,
36 pgbouncer: PgBouncerInfo,
37 singleflight: Singleflight,
38}
39
/// What was learned about the server behind a pool when the pool was created.
#[derive(Debug, Clone, Copy)]
pub(crate) struct PgBouncerInfo {
    /// Whether the server responded like PgBouncer during detection.
    detected: bool,
    /// Whether named (server-side) prepared statements may be used.
    supports_named_stmts: bool,
}

impl PgBouncerInfo {
    /// Profile of a direct PostgreSQL connection: no PgBouncer in front, and named
    /// prepared statements are allowed.
    const DIRECT: Self = Self {
        supports_named_stmts: true,
        detected: false,
    };
}
56
/// Fluent configuration for a [`Pool`] with optional read replicas.
///
/// Obtain one from [`Pool::builder`], chain the setters, then call `build`.
pub struct PoolBuilder {
    // Primary server connection parameters; `None` leaves the setting unset.
    host: Option<String>,
    port: Option<u16>,
    dbname: Option<String>,
    user: Option<String>,
    password: Option<String>,
    /// Connection URLs of read replicas, in registration order.
    replica_urls: Vec<String>,
    /// Maximum number of connections per pool (primary and each replica).
    max_size: usize,
    /// Connect timeout for primary connections, in seconds.
    connect_timeout_secs: u64,
}
68
69impl PoolBuilder {
70 pub fn host(mut self, host: &str) -> Self {
71 self.host = Some(host.into());
72 self
73 }
74
75 pub fn port(mut self, port: u16) -> Self {
76 self.port = Some(port);
77 self
78 }
79
80 pub fn dbname(mut self, dbname: &str) -> Self {
81 self.dbname = Some(dbname.into());
82 self
83 }
84
85 pub fn user(mut self, user: &str) -> Self {
86 self.user = Some(user.into());
87 self
88 }
89
90 pub fn password(mut self, password: &str) -> Self {
91 self.password = Some(password.into());
92 self
93 }
94
95 pub fn max_size(mut self, size: usize) -> Self {
96 self.max_size = size;
97 self
98 }
99
100 pub fn connect_timeout(mut self, secs: u64) -> Self {
103 self.connect_timeout_secs = secs;
104 self
105 }
106
107 pub fn replica(mut self, url: &str) -> Self {
115 self.replica_urls.push(url.into());
116 self
117 }
118
119 pub async fn build(self) -> BsqlResult<Pool> {
120 let mut cfg = Config::new();
121 cfg.host = self.host;
122 cfg.port = self.port;
123 cfg.dbname = self.dbname;
124 cfg.user = self.user;
125 cfg.password = self.password;
126 cfg.connect_timeout = Some(std::time::Duration::from_secs(self.connect_timeout_secs));
127 cfg.manager = Some(ManagerConfig {
128 recycling_method: RecyclingMethod::Fast,
129 });
130 cfg.pool = Some(deadpool_postgres::PoolConfig {
132 max_size: self.max_size,
133 timeouts: deadpool_postgres::Timeouts {
134 wait: Some(std::time::Duration::ZERO),
135 create: None,
136 recycle: None,
137 },
138 ..Default::default()
139 });
140
141 let pool = cfg
142 .create_pool(Some(Runtime::Tokio1), NoTls)
143 .map_err(|e| ConnectError::create(e.to_string()))?;
144
145 let pgbouncer = detect_pgbouncer(&pool).await?;
147
148 let mut replicas = Vec::with_capacity(self.replica_urls.len());
150 for url in &self.replica_urls {
151 let replica_pool = create_pool_from_url(url, self.max_size).await?;
152 replicas.push(replica_pool);
153 }
154
155 Ok(Pool {
156 primary: pool,
157 replicas,
158 replica_idx: std::sync::atomic::AtomicUsize::new(0),
159 pgbouncer,
160 singleflight: Singleflight::new(),
161 })
162 }
163}
164
165impl Pool {
166 pub async fn connect(url: &str) -> BsqlResult<Self> {
170 let config: tokio_postgres::Config = url
171 .parse()
172 .map_err(|e: tokio_postgres::Error| ConnectError::create(e.to_string()))?;
173
174 let mut cfg = Config::new();
175 cfg.host = config.get_hosts().first().map(|h| match h {
176 tokio_postgres::config::Host::Tcp(s) => s.clone(),
177 #[cfg(unix)]
178 tokio_postgres::config::Host::Unix(p) => p.to_string_lossy().into_owned(),
179 });
180 cfg.port = config.get_ports().first().copied();
181 cfg.dbname = config.get_dbname().map(String::from);
182 cfg.user = config.get_user().map(String::from);
183 cfg.password =
184 match config.get_password() {
185 Some(p) => Some(String::from_utf8(p.to_vec()).map_err(|_| {
186 ConnectError::create("database password contains invalid UTF-8")
187 })?),
188 None => None,
189 };
190 cfg.connect_timeout = Some(std::time::Duration::from_secs(5));
191 cfg.manager = Some(ManagerConfig {
192 recycling_method: RecyclingMethod::Fast,
193 });
194 cfg.pool = Some(deadpool_postgres::PoolConfig {
196 max_size: 16,
197 timeouts: deadpool_postgres::Timeouts {
198 wait: Some(std::time::Duration::ZERO),
199 create: None,
200 recycle: None,
201 },
202 ..Default::default()
203 });
204
205 let pool = cfg
206 .create_pool(Some(Runtime::Tokio1), NoTls)
207 .map_err(|e| ConnectError::create(e.to_string()))?;
208
209 let pgbouncer = detect_pgbouncer(&pool).await?;
211
212 Ok(Pool {
213 primary: pool,
214 replicas: Vec::new(),
215 replica_idx: std::sync::atomic::AtomicUsize::new(0),
216 pgbouncer,
217 singleflight: Singleflight::new(),
218 })
219 }
220
221 pub fn builder() -> PoolBuilder {
223 PoolBuilder {
224 host: None,
225 port: None,
226 dbname: None,
227 user: None,
228 password: None,
229 max_size: 16,
230 connect_timeout_secs: 5,
231 replica_urls: Vec::new(),
232 }
233 }
234
235 pub async fn acquire(&self) -> BsqlResult<PoolConnection> {
240 let conn = self.primary.get().await.map_err(BsqlError::from)?;
241
242 Ok(PoolConnection {
243 inner: conn,
244 pgbouncer: self.pgbouncer,
245 })
246 }
247
248 pub fn is_pgbouncer(&self) -> bool {
250 self.pgbouncer.detected
251 }
252
253 pub fn supports_named_statements(&self) -> bool {
257 self.pgbouncer.supports_named_stmts
258 }
259
260 pub fn has_replicas(&self) -> bool {
262 !self.replicas.is_empty()
263 }
264
265 pub async fn begin(&self) -> BsqlResult<Transaction> {
273 let conn = self.acquire().await?;
274 conn.inner
275 .batch_execute("BEGIN")
276 .await
277 .map_err(BsqlError::from)?;
278 Ok(Transaction::new(conn))
279 }
280
281 pub async fn query_stream(
295 &self,
296 sql: &str,
297 params: &[&(dyn ToSql + Sync)],
298 ) -> BsqlResult<QueryStream> {
299 let conn = self.acquire().await?;
300 let stmt = conn
301 .inner
302 .prepare_cached(sql)
303 .await
304 .map_err(BsqlError::from)?;
305
306 let row_stream = conn
307 .inner
308 .query_raw(&stmt, params.iter().copied())
309 .await
310 .map_err(BsqlError::from)?;
311
312 Ok(QueryStream::new(conn, row_stream))
313 }
314
315 pub fn status(&self) -> PoolStatus {
317 let status = self.primary.status();
318 PoolStatus {
319 available: status.available,
320 size: status.size,
321 max_size: status.max_size,
322 }
323 }
324
325 pub(crate) async fn query_raw_primary(
329 &self,
330 sql: &str,
331 params: &[&(dyn ToSql + Sync)],
332 ) -> BsqlResult<Arc<Vec<tokio_postgres::Row>>> {
333 if params.is_empty() {
337 let key = sql_key(sql);
338 self.query_with_singleflight(key, sql, params, false).await
339 } else {
340 self.execute_on_pool(sql, params, false).await
341 }
342 }
343
344 pub(crate) async fn query_raw_read(
347 &self,
348 sql: &str,
349 params: &[&(dyn ToSql + Sync)],
350 ) -> BsqlResult<Arc<Vec<tokio_postgres::Row>>> {
351 if self.replicas.is_empty() {
352 return self.query_raw_primary(sql, params).await;
353 }
354
355 if params.is_empty() {
356 let key = sql_key(sql);
357 match self.query_with_singleflight(key, sql, params, true).await {
359 Ok(rows) => Ok(rows),
360 Err(_) => self.query_with_singleflight(key, sql, params, false).await,
361 }
362 } else {
363 match self.execute_on_pool(sql, params, true).await {
365 Ok(rows) => Ok(rows),
366 Err(_) => self.execute_on_pool(sql, params, false).await,
367 }
368 }
369 }
370
371 async fn query_with_singleflight(
373 &self,
374 key: u64,
375 sql: &str,
376 params: &[&(dyn ToSql + Sync)],
377 use_replica: bool,
378 ) -> BsqlResult<Arc<Vec<tokio_postgres::Row>>> {
379 match self.singleflight.try_join(key) {
380 FlightStatus::Follower(mut rx) => {
381 match rx.recv().await {
383 Ok(rows) => Ok(rows),
384 Err(_) => {
385 self.execute_on_pool(sql, params, use_replica).await
387 }
388 }
389 }
390 FlightStatus::Leader => match self.execute_on_pool(sql, params, use_replica).await {
391 Ok(rows) => {
392 self.singleflight.complete(key, Arc::clone(&rows));
393 Ok(rows)
394 }
395 Err(e) => {
396 self.singleflight.abandon(key);
397 Err(e)
398 }
399 },
400 }
401 }
402
403 async fn execute_on_pool(
405 &self,
406 sql: &str,
407 params: &[&(dyn ToSql + Sync)],
408 use_replica: bool,
409 ) -> BsqlResult<Arc<Vec<tokio_postgres::Row>>> {
410 let raw_conn = if use_replica && !self.replicas.is_empty() {
411 let idx = self
412 .replica_idx
413 .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
414 % self.replicas.len();
415 self.replicas[idx].get().await.map_err(BsqlError::from)?
416 } else {
417 self.primary.get().await.map_err(BsqlError::from)?
418 };
419
420 let stmt = raw_conn
421 .prepare_cached(sql)
422 .await
423 .map_err(BsqlError::from)?;
424
425 let rows = raw_conn
426 .query(&stmt, params)
427 .await
428 .map_err(BsqlError::from)?;
429
430 Ok(Arc::new(rows))
431 }
432}
433
434pub struct PoolConnection {
438 pub(crate) inner: deadpool_postgres::Object,
439 pub(crate) pgbouncer: PgBouncerInfo,
440}
441
442impl PoolConnection {
443 pub fn supports_named_statements(&self) -> bool {
445 self.pgbouncer.supports_named_stmts
446 }
447}
448
/// Snapshot of the primary pool's occupancy, as returned by `Pool::status`.
///
/// Implements `PartialEq`/`Eq`/`Hash` so snapshots can be compared, for example in tests
/// or to detect changes between polls.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PoolStatus {
    /// Connections available for checkout, as reported by the pool.
    pub available: usize,
    /// Current number of connections managed by the pool.
    pub size: usize,
    /// Configured maximum pool size.
    pub max_size: usize,
}
456
457async fn create_pool_from_url(url: &str, max_size: usize) -> BsqlResult<deadpool_postgres::Pool> {
461 let config: tokio_postgres::Config = url
462 .parse()
463 .map_err(|e: tokio_postgres::Error| ConnectError::create(e.to_string()))?;
464
465 let mut cfg = Config::new();
466 cfg.host = config.get_hosts().first().map(|h| match h {
467 tokio_postgres::config::Host::Tcp(s) => s.clone(),
468 #[cfg(unix)]
469 tokio_postgres::config::Host::Unix(p) => p.to_string_lossy().into_owned(),
470 });
471 cfg.port = config.get_ports().first().copied();
472 cfg.dbname = config.get_dbname().map(String::from);
473 cfg.user = config.get_user().map(String::from);
474 cfg.password = match config.get_password() {
475 Some(p) => Some(
476 String::from_utf8(p.to_vec())
477 .map_err(|_| ConnectError::create("database password contains invalid UTF-8"))?,
478 ),
479 None => None,
480 };
481 cfg.connect_timeout = Some(std::time::Duration::from_secs(5));
482 cfg.manager = Some(ManagerConfig {
483 recycling_method: RecyclingMethod::Fast,
484 });
485 cfg.pool = Some(deadpool_postgres::PoolConfig {
486 max_size,
487 timeouts: deadpool_postgres::Timeouts {
488 wait: Some(std::time::Duration::ZERO),
489 create: None,
490 recycle: None,
491 },
492 ..Default::default()
493 });
494
495 let pool = cfg
496 .create_pool(Some(Runtime::Tokio1), NoTls)
497 .map_err(|e| ConnectError::create(e.to_string()))?;
498
499 let _conn = pool
501 .get()
502 .await
503 .map_err(|e| ConnectError::with_source(format!("failed to connect to replica: {e}"), e))?;
504
505 Ok(pool)
506}
507
508async fn detect_pgbouncer(pool: &deadpool_postgres::Pool) -> BsqlResult<PgBouncerInfo> {
516 let conn = pool.get().await.map_err(|e| {
517 ConnectError::with_source(format!("failed to establish initial connection: {e}"), e)
518 })?;
519
520 let is_pgbouncer = conn.simple_query("SHOW POOLS").await.is_ok();
522
523 if !is_pgbouncer {
524 return Ok(PgBouncerInfo::DIRECT);
525 }
526
527 let supports_named = match conn.simple_query("SHOW CONFIG").await {
529 Ok(messages) => messages.iter().any(|msg| {
530 if let tokio_postgres::SimpleQueryMessage::Row(row) = msg {
531 row.get(0) == Some("prepared_statements") && row.get(1) == Some("yes")
532 } else {
533 false
534 }
535 }),
536 Err(_) => false,
537 };
538
539 Ok(PgBouncerInfo {
540 detected: true,
541 supports_named_stmts: supports_named,
542 })
543}
544
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh builder starts with 16 connections, a 5 second timeout and no replicas.
    #[test]
    fn builder_defaults() {
        let builder = Pool::builder();
        assert_eq!(builder.max_size, 16);
        assert_eq!(builder.connect_timeout_secs, 5);
        assert_eq!(builder.replica_urls.len(), 0);
    }

    /// Every setter stores its value on the builder.
    #[test]
    fn builder_config() {
        let builder = Pool::builder()
            .host("localhost")
            .port(5432)
            .dbname("test")
            .user("app")
            .password("secret")
            .max_size(8)
            .connect_timeout(10);

        assert_eq!(builder.host.as_deref(), Some("localhost"));
        assert_eq!(builder.port, Some(5432));
        assert_eq!(builder.dbname.as_deref(), Some("test"));
        assert_eq!(builder.user.as_deref(), Some("app"));
        assert_eq!(builder.password.as_deref(), Some("secret"));
        assert_eq!(builder.max_size, 8);
        assert_eq!(builder.connect_timeout_secs, 10);
    }

    /// Each `replica` call appends one URL.
    #[test]
    fn builder_replicas() {
        let urls = ["postgres://replica1:5432/db", "postgres://replica2:5432/db"];
        let builder = urls
            .into_iter()
            .fold(Pool::builder(), |acc, url| acc.replica(url));
        assert_eq!(builder.replica_urls.len(), 2);
    }

    /// The direct-connection profile reports no PgBouncer and allows named statements.
    #[test]
    fn pgbouncer_direct_defaults() {
        let PgBouncerInfo {
            detected,
            supports_named_stmts,
        } = PgBouncerInfo::DIRECT;
        assert!(!detected);
        assert!(supports_named_stmts);
    }

    /// `PoolStatus` must stay `Copy`; this fails to compile otherwise.
    #[test]
    fn pool_status_type_is_copy() {
        fn require_copy<T: Copy>() {}
        require_copy::<PoolStatus>();
    }
}