1use deadpool_postgres::{Config, ManagerConfig, RecyclingMethod, Runtime};
10use tokio_postgres::NoTls;
11use tokio_postgres::types::ToSql;
12
13use crate::error::{BsqlError, BsqlResult, ConnectError};
14use crate::stream::QueryStream;
15use crate::transaction::Transaction;
16
17pub struct Pool {
22 inner: deadpool_postgres::Pool,
23 pgbouncer: PgBouncerInfo,
24}
25
/// Outcome of probing the server for PgBouncer when a pool is created.
#[derive(Debug, Clone, Copy)]
pub(crate) struct PgBouncerInfo {
    /// `true` when the PgBouncer probe (`SHOW POOLS`) succeeded.
    detected: bool,
    /// Whether named prepared statements are considered usable: `true` for
    /// direct connections, otherwise derived from PgBouncer's `SHOW CONFIG`
    /// output.
    supports_named_stmts: bool,
}
35
36impl PgBouncerInfo {
37 const DIRECT: Self = Self {
38 detected: false,
39 supports_named_stmts: true,
40 };
41}
42
/// Step-by-step configuration for a `Pool`.
///
/// Obtained from `Pool::builder()`; each setter consumes and returns the
/// builder, and `build` turns it into a live pool.
pub struct PoolBuilder {
    /// Server host; left unset (`None`) unless `host` is called.
    host: Option<String>,
    /// Server port; left unset (`None`) unless `port` is called.
    port: Option<u16>,
    /// Database name; left unset (`None`) unless `dbname` is called.
    dbname: Option<String>,
    /// Login user; left unset (`None`) unless `user` is called.
    user: Option<String>,
    /// Login password; left unset (`None`) unless `password` is called.
    password: Option<String>,
    /// Upper bound on the number of pooled connections.
    max_size: usize,
    /// Connect timeout, in whole seconds.
    connect_timeout_secs: u64,
}
53
54impl PoolBuilder {
55 pub fn host(mut self, host: &str) -> Self {
56 self.host = Some(host.into());
57 self
58 }
59
60 pub fn port(mut self, port: u16) -> Self {
61 self.port = Some(port);
62 self
63 }
64
65 pub fn dbname(mut self, dbname: &str) -> Self {
66 self.dbname = Some(dbname.into());
67 self
68 }
69
70 pub fn user(mut self, user: &str) -> Self {
71 self.user = Some(user.into());
72 self
73 }
74
75 pub fn password(mut self, password: &str) -> Self {
76 self.password = Some(password.into());
77 self
78 }
79
80 pub fn max_size(mut self, size: usize) -> Self {
81 self.max_size = size;
82 self
83 }
84
85 pub fn connect_timeout(mut self, secs: u64) -> Self {
88 self.connect_timeout_secs = secs;
89 self
90 }
91
92 pub async fn build(self) -> BsqlResult<Pool> {
93 let mut cfg = Config::new();
94 cfg.host = self.host;
95 cfg.port = self.port;
96 cfg.dbname = self.dbname;
97 cfg.user = self.user;
98 cfg.password = self.password;
99 cfg.connect_timeout = Some(std::time::Duration::from_secs(self.connect_timeout_secs));
100 cfg.manager = Some(ManagerConfig {
101 recycling_method: RecyclingMethod::Fast,
102 });
103 cfg.pool = Some(deadpool_postgres::PoolConfig {
105 max_size: self.max_size,
106 timeouts: deadpool_postgres::Timeouts {
107 wait: Some(std::time::Duration::ZERO),
108 create: None,
109 recycle: None,
110 },
111 ..Default::default()
112 });
113
114 let pool = cfg
115 .create_pool(Some(Runtime::Tokio1), NoTls)
116 .map_err(|e| ConnectError::create(e.to_string()))?;
117
118 let pgbouncer = detect_pgbouncer(&pool).await?;
120
121 Ok(Pool {
122 inner: pool,
123 pgbouncer,
124 })
125 }
126}
127
128impl Pool {
129 pub async fn connect(url: &str) -> BsqlResult<Self> {
133 let config: tokio_postgres::Config = url
134 .parse()
135 .map_err(|e: tokio_postgres::Error| ConnectError::create(e.to_string()))?;
136
137 let mut cfg = Config::new();
138 cfg.host = config.get_hosts().first().map(|h| match h {
139 tokio_postgres::config::Host::Tcp(s) => s.clone(),
140 #[cfg(unix)]
141 tokio_postgres::config::Host::Unix(p) => p.to_string_lossy().into_owned(),
142 });
143 cfg.port = config.get_ports().first().copied();
144 cfg.dbname = config.get_dbname().map(String::from);
145 cfg.user = config.get_user().map(String::from);
146 cfg.password = config
147 .get_password()
148 .map(|p| String::from_utf8_lossy(p).into_owned());
149 cfg.connect_timeout = Some(std::time::Duration::from_secs(5));
150 cfg.manager = Some(ManagerConfig {
151 recycling_method: RecyclingMethod::Fast,
152 });
153 cfg.pool = Some(deadpool_postgres::PoolConfig {
155 max_size: 16,
156 timeouts: deadpool_postgres::Timeouts {
157 wait: Some(std::time::Duration::ZERO),
158 create: None,
159 recycle: None,
160 },
161 ..Default::default()
162 });
163
164 let pool = cfg
165 .create_pool(Some(Runtime::Tokio1), NoTls)
166 .map_err(|e| ConnectError::create(e.to_string()))?;
167
168 let pgbouncer = detect_pgbouncer(&pool).await?;
170
171 Ok(Pool {
172 inner: pool,
173 pgbouncer,
174 })
175 }
176
177 pub fn builder() -> PoolBuilder {
179 PoolBuilder {
180 host: None,
181 port: None,
182 dbname: None,
183 user: None,
184 password: None,
185 max_size: 16,
186 connect_timeout_secs: 5,
187 }
188 }
189
190 pub async fn acquire(&self) -> BsqlResult<PoolConnection> {
195 let conn = self.inner.get().await.map_err(BsqlError::from)?;
196
197 Ok(PoolConnection {
198 inner: conn,
199 pgbouncer: self.pgbouncer,
200 })
201 }
202
203 pub fn is_pgbouncer(&self) -> bool {
205 self.pgbouncer.detected
206 }
207
208 pub fn supports_named_statements(&self) -> bool {
212 self.pgbouncer.supports_named_stmts
213 }
214
215 pub async fn begin(&self) -> BsqlResult<Transaction> {
223 let conn = self.acquire().await?;
224 conn.inner
225 .batch_execute("BEGIN")
226 .await
227 .map_err(BsqlError::from)?;
228 Ok(Transaction::new(conn))
229 }
230
231 pub async fn query_stream(
245 &self,
246 sql: &str,
247 params: &[&(dyn ToSql + Sync)],
248 ) -> BsqlResult<QueryStream> {
249 let conn = self.acquire().await?;
250 let stmt = conn
251 .inner
252 .prepare_cached(sql)
253 .await
254 .map_err(BsqlError::from)?;
255
256 let row_stream = conn
257 .inner
258 .query_raw(&stmt, params.iter().copied())
259 .await
260 .map_err(BsqlError::from)?;
261
262 Ok(QueryStream::new(conn, row_stream))
263 }
264
265 pub fn status(&self) -> PoolStatus {
267 let status = self.inner.status();
268 PoolStatus {
269 available: status.available,
270 size: status.size,
271 max_size: status.max_size,
272 }
273 }
274}
275
276pub struct PoolConnection {
280 pub(crate) inner: deadpool_postgres::Object,
281 pub(crate) pgbouncer: PgBouncerInfo,
282}
283
284impl PoolConnection {
285 pub fn supports_named_statements(&self) -> bool {
287 self.pgbouncer.supports_named_stmts
288 }
289}
290
/// Point-in-time snapshot of pool counters, as reported by the underlying
/// deadpool pool's `status()`.
#[derive(Debug, Clone, Copy)]
pub struct PoolStatus {
    /// The underlying pool's `available` counter.
    pub available: usize,
    /// The underlying pool's `size` counter.
    pub size: usize,
    /// The underlying pool's `max_size` setting.
    pub max_size: usize,
}
298
299async fn detect_pgbouncer(pool: &deadpool_postgres::Pool) -> BsqlResult<PgBouncerInfo> {
307 let conn = pool.get().await.map_err(|e| {
308 ConnectError::with_source(format!("failed to establish initial connection: {e}"), e)
309 })?;
310
311 let is_pgbouncer = conn.simple_query("SHOW POOLS").await.is_ok();
313
314 if !is_pgbouncer {
315 return Ok(PgBouncerInfo::DIRECT);
316 }
317
318 let supports_named = match conn.simple_query("SHOW CONFIG").await {
320 Ok(messages) => messages.iter().any(|msg| {
321 if let tokio_postgres::SimpleQueryMessage::Row(row) = msg {
322 row.get(0) == Some("prepared_statements") && row.get(1) == Some("yes")
323 } else {
324 false
325 }
326 }),
327 Err(_) => false,
328 };
329
330 Ok(PgBouncerInfo {
331 detected: true,
332 supports_named_stmts: supports_named,
333 })
334}
335
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh builder starts with a pool size of 16 and a 5 second timeout.
    #[test]
    fn builder_defaults() {
        let PoolBuilder {
            max_size,
            connect_timeout_secs,
            ..
        } = Pool::builder();

        assert_eq!(max_size, 16);
        assert_eq!(connect_timeout_secs, 5);
    }

    /// Every setter stores exactly the value it was given.
    #[test]
    fn builder_config() {
        let configured = Pool::builder()
            .host("localhost")
            .port(5432)
            .dbname("test")
            .user("app")
            .password("secret")
            .max_size(8)
            .connect_timeout(10);

        assert_eq!(configured.host.as_deref(), Some("localhost"));
        assert_eq!(configured.port, Some(5432));
        assert_eq!(configured.dbname.as_deref(), Some("test"));
        assert_eq!(configured.user.as_deref(), Some("app"));
        assert_eq!(configured.password.as_deref(), Some("secret"));
        assert_eq!(configured.max_size, 8);
        assert_eq!(configured.connect_timeout_secs, 10);
    }

    /// `DIRECT` means "no PgBouncer, named statements available".
    #[test]
    fn pgbouncer_direct_defaults() {
        let PgBouncerInfo {
            detected,
            supports_named_stmts,
        } = PgBouncerInfo::DIRECT;

        assert!(!detected);
        assert!(supports_named_stmts);
    }

    /// Compile-time check that `PoolStatus` stays `Copy`.
    #[test]
    fn pool_status_type_is_copy() {
        fn requires_copy<T: Copy>() {}
        requires_copy::<PoolStatus>();
    }
}