1use std::sync::atomic::{AtomicUsize, Ordering};
8use std::sync::Arc;
9
10use tokio::sync::RwLock;
11
12use crate::async_conn::AsyncConn;
13use crate::connection::WireConn;
14use crate::error::PgWireError;
15use crate::protocol::types::RawRow;
16use crate::tls::TlsMode;
17
18#[derive(Clone)]
23#[non_exhaustive]
24pub struct ConnConfig {
25 pub addr: String,
27 pub user: String,
29 pub password: String,
31 pub database: String,
33 pub tls_mode: TlsMode,
35}
36
37impl std::fmt::Debug for ConnConfig {
38 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39 f.debug_struct("ConnConfig")
40 .field("addr", &self.addr)
41 .field("user", &self.user)
42 .field("password", &"<redacted>")
43 .field("database", &self.database)
44 .field("tls_mode", &self.tls_mode)
45 .finish()
46 }
47}
48
49pub struct AsyncPool {
52 conns: Vec<RwLock<Arc<AsyncConn>>>,
53 config: ConnConfig,
54 counter: AtomicUsize,
55}
56
57impl std::fmt::Debug for AsyncPool {
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 f.debug_struct("AsyncPool")
60 .field("size", &self.conns.len())
61 .field("config", &self.config)
62 .finish()
63 }
64}
65
66impl AsyncPool {
67 pub async fn connect(
69 addr: &str,
70 user: &str,
71 password: &str,
72 database: &str,
73 size: usize,
74 ) -> Result<Arc<Self>, PgWireError> {
75 Self::connect_with_tls(addr, user, password, database, size, TlsMode::default()).await
76 }
77
78 pub async fn connect_with_tls(
80 addr: &str,
81 user: &str,
82 password: &str,
83 database: &str,
84 size: usize,
85 tls_mode: TlsMode,
86 ) -> Result<Arc<Self>, PgWireError> {
87 if size == 0 {
88 return Err(PgWireError::Protocol("pool size must be >= 1".into()));
89 }
90 let config = ConnConfig {
91 addr: addr.to_string(),
92 user: user.to_string(),
93 password: password.to_string(),
94 database: database.to_string(),
95 tls_mode,
96 };
97
98 let mut conns = Vec::with_capacity(size);
99 for _ in 0..size {
100 let wire =
101 WireConn::connect_with_options(addr, user, password, database, &[], tls_mode)
102 .await?;
103 conns.push(RwLock::new(Arc::new(AsyncConn::new(wire))));
104 }
105
106 let pool = Arc::new(Self {
107 conns,
108 config,
109 counter: AtomicUsize::new(0),
110 });
111
112 {
115 let pool_weak = Arc::downgrade(&pool);
116 tokio::spawn(async move {
117 health_monitor(pool_weak).await;
118 });
119 }
120
121 Ok(pool)
122 }
123
124 pub async fn get_async(&self) -> Arc<AsyncConn> {
126 let len = self.conns.len();
127 let start = self.counter.fetch_add(1, Ordering::Relaxed) % len;
128
129 for i in 0..len {
130 let idx = (start + i) % len;
131 let conn = self.conns[idx].read().await;
132 if conn.is_alive() {
133 return Arc::clone(&conn);
134 }
135 }
136
137 let conn = self.conns[start % len].read().await;
139 Arc::clone(&conn)
140 }
141
142 async fn reconnect(&self, idx: usize) -> Result<(), PgWireError> {
144 let wire = WireConn::connect_with_options(
145 &self.config.addr,
146 &self.config.user,
147 &self.config.password,
148 &self.config.database,
149 &[],
150 self.config.tls_mode,
151 )
152 .await?;
153 let new_conn = Arc::new(AsyncConn::new(wire));
154
155 let mut slot = self.conns[idx].write().await;
156 *slot = new_conn;
157 tracing::info!("pg-wired: reconnected slot {idx}");
158 Ok(())
159 }
160
161 pub fn size(&self) -> usize {
163 self.conns.len()
164 }
165
166 pub async fn alive_count(&self) -> usize {
168 let mut count = 0;
169 for slot in &self.conns {
170 let conn = slot.read().await;
171 if conn.is_alive() {
172 count += 1;
173 }
174 }
175 count
176 }
177
178 pub async fn close(&self) -> Result<(), PgWireError> {
182 for slot in &self.conns {
183 let conn = slot.read().await;
184 let _ = conn.close().await;
185 }
186 Ok(())
187 }
188
189 pub async fn exec_transaction(
191 &self,
192 setup_sql: &str,
193 query_sql: &str,
194 params: &[Option<&[u8]>],
195 param_oids: &[u32],
196 ) -> Result<Vec<RawRow>, PgWireError> {
197 self.get_async()
198 .await
199 .exec_transaction(setup_sql, query_sql, params, param_oids)
200 .await
201 }
202
203 pub async fn exec_query(
205 &self,
206 sql: &str,
207 params: &[Option<&[u8]>],
208 param_oids: &[u32],
209 ) -> Result<Vec<RawRow>, PgWireError> {
210 self.get_async()
211 .await
212 .exec_query(sql, params, param_oids)
213 .await
214 }
215
216 pub async fn exec_query_with_formats(
219 &self,
220 sql: &str,
221 params: &[Option<&[u8]>],
222 param_oids: &[u32],
223 param_formats: &[crate::protocol::types::FormatCode],
224 result_formats: &[crate::protocol::types::FormatCode],
225 ) -> Result<Vec<RawRow>, PgWireError> {
226 self.get_async()
227 .await
228 .exec_query_with_formats(sql, params, param_oids, param_formats, result_formats)
229 .await
230 }
231}
232
233async fn health_monitor(pool_weak: std::sync::Weak<AsyncPool>) {
236 let mut interval = tokio::time::interval(std::time::Duration::from_secs(5));
237 loop {
238 interval.tick().await;
239
240 let pool = match pool_weak.upgrade() {
241 Some(p) => p,
242 None => {
243 tracing::debug!("pg-wired: health monitor stopping (pool dropped)");
244 return;
245 }
246 };
247
248 for idx in 0..pool.conns.len() {
249 let is_dead = {
250 let conn = pool.conns[idx].read().await;
251 !conn.is_alive()
252 };
253
254 if is_dead {
255 tracing::warn!("pg-wired: slot {idx} is dead, reconnecting...");
256 match pool.reconnect(idx).await {
257 Ok(()) => {}
258 Err(e) => {
259 tracing::error!("pg-wired: reconnect slot {idx} failed: {e}");
260 }
261 }
262 }
263 }
264 }
265}