Skip to main content

pg_wired/
async_pool.rs

1//! Pool of AsyncConns for spreading load across multiple PostgreSQL backends.
2//!
3//! Each AsyncConn maintains its own TCP connection, writer task, and reader task.
4//! The pool dispatches requests round-robin across connections using an atomic counter.
//! Dead connections are skipped when dispatching and replaced by a background health monitor.
6
7use std::sync::atomic::{AtomicUsize, Ordering};
8use std::sync::Arc;
9
10use tokio::sync::RwLock;
11
12use crate::async_conn::AsyncConn;
13use crate::connection::WireConn;
14use crate::error::PgWireError;
15use crate::protocol::types::RawRow;
16use crate::tls::TlsMode;
17
18/// Connection configuration for reconnection.
19///
20/// Marked `#[non_exhaustive]` so additional connection options (compression,
21/// statement_timeout, etc.) can be added without breaking downstream construction.
22#[derive(Clone)]
23#[non_exhaustive]
24pub struct ConnConfig {
25    /// Server address as `host:port` (e.g., `"127.0.0.1:5432"`).
26    pub addr: String,
27    /// PostgreSQL role to authenticate as.
28    pub user: String,
29    /// Password for the role; ignored when the server requests trust auth.
30    pub password: String,
31    /// Database name to attach to after authentication.
32    pub database: String,
33    /// TLS preference (plain, prefer, require, verify-ca, verify-full).
34    pub tls_mode: TlsMode,
35}
36
37impl std::fmt::Debug for ConnConfig {
38    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39        f.debug_struct("ConnConfig")
40            .field("addr", &self.addr)
41            .field("user", &self.user)
42            .field("password", &"<redacted>")
43            .field("database", &self.database)
44            .field("tls_mode", &self.tls_mode)
45            .finish()
46    }
47}
48
49/// A pool of N AsyncConns for parallel PostgreSQL backend utilization.
50/// Detects dead connections and replaces them automatically.
51pub struct AsyncPool {
52    conns: Vec<RwLock<Arc<AsyncConn>>>,
53    config: ConnConfig,
54    counter: AtomicUsize,
55}
56
57impl std::fmt::Debug for AsyncPool {
58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59        f.debug_struct("AsyncPool")
60            .field("size", &self.conns.len())
61            .field("config", &self.config)
62            .finish()
63    }
64}
65
66impl AsyncPool {
67    /// Create a pool of `size` AsyncConns, each with its own TCP connection.
68    pub async fn connect(
69        addr: &str,
70        user: &str,
71        password: &str,
72        database: &str,
73        size: usize,
74    ) -> Result<Arc<Self>, PgWireError> {
75        Self::connect_with_tls(addr, user, password, database, size, TlsMode::default()).await
76    }
77
78    /// Create a pool with an explicit TLS mode.
79    pub async fn connect_with_tls(
80        addr: &str,
81        user: &str,
82        password: &str,
83        database: &str,
84        size: usize,
85        tls_mode: TlsMode,
86    ) -> Result<Arc<Self>, PgWireError> {
87        if size == 0 {
88            return Err(PgWireError::Protocol("pool size must be >= 1".into()));
89        }
90        let config = ConnConfig {
91            addr: addr.to_string(),
92            user: user.to_string(),
93            password: password.to_string(),
94            database: database.to_string(),
95            tls_mode,
96        };
97
98        let mut conns = Vec::with_capacity(size);
99        for _ in 0..size {
100            let wire =
101                WireConn::connect_with_options(addr, user, password, database, &[], tls_mode)
102                    .await?;
103            conns.push(RwLock::new(Arc::new(AsyncConn::new(wire))));
104        }
105
106        let pool = Arc::new(Self {
107            conns,
108            config,
109            counter: AtomicUsize::new(0),
110        });
111
112        // Spawn background health monitor. Uses a Weak reference so
113        // the monitor stops when the pool is dropped.
114        {
115            let pool_weak = Arc::downgrade(&pool);
116            tokio::spawn(async move {
117                health_monitor(pool_weak).await;
118            });
119        }
120
121        Ok(pool)
122    }
123
124    /// Get the next alive AsyncConn via round-robin.
125    pub async fn get_async(&self) -> Arc<AsyncConn> {
126        let len = self.conns.len();
127        let start = self.counter.fetch_add(1, Ordering::Relaxed) % len;
128
129        for i in 0..len {
130            let idx = (start + i) % len;
131            let conn = self.conns[idx].read().await;
132            if conn.is_alive() {
133                return Arc::clone(&conn);
134            }
135        }
136
137        // All dead — return first anyway, request will fail and trigger reconnect.
138        let conn = self.conns[start % len].read().await;
139        Arc::clone(&conn)
140    }
141
142    /// Replace a dead connection at the given index.
143    async fn reconnect(&self, idx: usize) -> Result<(), PgWireError> {
144        let wire = WireConn::connect_with_options(
145            &self.config.addr,
146            &self.config.user,
147            &self.config.password,
148            &self.config.database,
149            &[],
150            self.config.tls_mode,
151        )
152        .await?;
153        let new_conn = Arc::new(AsyncConn::new(wire));
154
155        let mut slot = self.conns[idx].write().await;
156        *slot = new_conn;
157        tracing::info!("pg-wired: reconnected slot {idx}");
158        Ok(())
159    }
160
161    /// Number of connections in the pool.
162    pub fn size(&self) -> usize {
163        self.conns.len()
164    }
165
166    /// Number of alive connections.
167    pub async fn alive_count(&self) -> usize {
168        let mut count = 0;
169        for slot in &self.conns {
170            let conn = slot.read().await;
171            if conn.is_alive() {
172                count += 1;
173            }
174        }
175        count
176    }
177
178    /// Close all connections in the pool by sending Terminate on each and
179    /// waiting for the writer/reader tasks to exit. Idempotent: dead slots
180    /// are skipped.
181    pub async fn close(&self) -> Result<(), PgWireError> {
182        for slot in &self.conns {
183            let conn = slot.read().await;
184            let _ = conn.close().await;
185        }
186        Ok(())
187    }
188
189    /// Execute a pipelined transaction on the next available connection.
190    pub async fn exec_transaction(
191        &self,
192        setup_sql: &str,
193        query_sql: &str,
194        params: &[Option<&[u8]>],
195        param_oids: &[u32],
196    ) -> Result<Vec<RawRow>, PgWireError> {
197        self.get_async()
198            .await
199            .exec_transaction(setup_sql, query_sql, params, param_oids)
200            .await
201    }
202
203    /// Execute a parameterized query on the next available connection.
204    pub async fn exec_query(
205        &self,
206        sql: &str,
207        params: &[Option<&[u8]>],
208        param_oids: &[u32],
209    ) -> Result<Vec<RawRow>, PgWireError> {
210        self.get_async()
211            .await
212            .exec_query(sql, params, param_oids)
213            .await
214    }
215
216    /// Execute a parameterized query with explicit per-param and per-result
217    /// format codes on the next available connection.
218    pub async fn exec_query_with_formats(
219        &self,
220        sql: &str,
221        params: &[Option<&[u8]>],
222        param_oids: &[u32],
223        param_formats: &[crate::protocol::types::FormatCode],
224        result_formats: &[crate::protocol::types::FormatCode],
225    ) -> Result<Vec<RawRow>, PgWireError> {
226        self.get_async()
227            .await
228            .exec_query_with_formats(sql, params, param_oids, param_formats, result_formats)
229            .await
230    }
231}
232
233/// Background task that checks connection health and reconnects dead ones.
234/// Stops automatically when the pool is dropped (Weak becomes invalid).
235async fn health_monitor(pool_weak: std::sync::Weak<AsyncPool>) {
236    let mut interval = tokio::time::interval(std::time::Duration::from_secs(5));
237    loop {
238        interval.tick().await;
239
240        let pool = match pool_weak.upgrade() {
241            Some(p) => p,
242            None => {
243                tracing::debug!("pg-wired: health monitor stopping (pool dropped)");
244                return;
245            }
246        };
247
248        for idx in 0..pool.conns.len() {
249            let is_dead = {
250                let conn = pool.conns[idx].read().await;
251                !conn.is_alive()
252            };
253
254            if is_dead {
255                tracing::warn!("pg-wired: slot {idx} is dead, reconnecting...");
256                match pool.reconnect(idx).await {
257                    Ok(()) => {}
258                    Err(e) => {
259                        tracing::error!("pg-wired: reconnect slot {idx} failed: {e}");
260                    }
261                }
262            }
263        }
264    }
265}