1use async_trait::async_trait;
5use futures::{channel::oneshot, prelude::*};
6use std::{
7 convert::{AsMut, AsRef},
8 ops::{Deref, DerefMut},
9};
10use tokio::spawn;
11use tokio_postgres::error::Error;
12use tokio_postgres::{
13 tls::{MakeTlsConnect, TlsConnect},
14 Client, Socket,
15};
16use tracing::{debug, debug_span, info, warn, Instrument};
17
18use std::fmt;
19
20pub struct AsyncConnection {
21 pub client: Client,
22 broken: bool,
23 done_rx: oneshot::Receiver<()>,
24 drop_tx: Option<oneshot::Sender<()>>,
25}
26
27impl Drop for AsyncConnection {
31 fn drop(&mut self) {
32 if let Some(drop_tx) = self.drop_tx.take() {
35 let _ = drop_tx.send(());
36 }
37 }
38}
39
40impl Deref for AsyncConnection {
41 type Target = Client;
42
43 fn deref(&self) -> &Self::Target {
44 &self.client
45 }
46}
47
48impl DerefMut for AsyncConnection {
49 fn deref_mut(&mut self) -> &mut Self::Target {
50 &mut self.client
51 }
52}
53
54impl AsMut<Client> for AsyncConnection {
55 fn as_mut(&mut self) -> &mut Client {
56 &mut self.client
57 }
58}
59
60impl AsRef<Client> for AsyncConnection {
61 fn as_ref(&self) -> &Client {
62 &self.client
63 }
64}
65
66pub struct PostgresConnectionManager<T>
68where
69 T: 'static + MakeTlsConnect<Socket> + Clone + Send + Sync,
70{
71 config: tokio_postgres::Config,
72 make_tls_connect: T,
73}
74
75impl<T> PostgresConnectionManager<T>
76where
77 T: 'static + MakeTlsConnect<Socket> + Clone + Send + Sync,
78{
79 pub fn new(config: tokio_postgres::Config, make_tls_connect: T) -> Self {
81 Self {
82 config,
83 make_tls_connect,
84 }
85 }
86}
87
88#[async_trait]
89impl<T> l337::ManageConnection for PostgresConnectionManager<T>
90where
91 T: 'static + MakeTlsConnect<Socket> + Clone + Send + Sync,
92 T::Stream: Send + Sync,
93 T::TlsConnect: Send,
94 <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
95{
96 type Connection = AsyncConnection;
97 type Error = Error;
98
99 async fn connect(&self) -> Result<Self::Connection, l337::Error<Self::Error>> {
100 let (client, connection) = self
101 .config
102 .connect(self.make_tls_connect.clone())
103 .instrument(debug_span!("connect: open new postgres connection"))
104 .await
105 .map_err(|e| l337::Error::External(e))?;
106
107 let (done_tx, done_rx) = oneshot::channel();
108 let (drop_tx, drop_rx) = oneshot::channel();
109 spawn(async move {
110 debug!("connect: start connection future");
111 let connection = connection.fuse();
112 let drop_rx = drop_rx.fuse();
113
114 futures::pin_mut!(connection, drop_rx);
115
116 futures::select! {
117 result = connection => {
118 if let Err(e) = result {
119 warn!("future backing postgres future ended with an error: {}", e);
120 }
121 }
122 _ = drop_rx => { }
123 }
124
125 let _ = done_tx.send(());
127
128 info!("connect: connection future ended");
129 });
130
131 debug!("connect: postgres connection established");
132 Ok(AsyncConnection {
133 broken: false,
134 client,
135 done_rx,
136 drop_tx: Some(drop_tx),
137 })
138 }
139
140 async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), l337::Error<Self::Error>> {
141 conn.simple_query("")
143 .await
144 .map_err(|e| l337::Error::External(e))?;
145
146 Ok(())
147 }
148
149 fn has_broken(&self, conn: &mut Self::Connection) -> bool {
150 if conn.broken {
151 return true;
152 }
153
154 if conn.client.is_closed() {
155 return true;
156 }
157
158 match conn.done_rx.try_recv() {
162 Ok(Some(_)) => {
165 conn.broken = true;
166 true
167 }
168 Ok(None) => false,
171 Err(error) => {
174 warn!(%error, "cannot receive from connection future");
175 conn.broken = true;
176 true
177 }
178 }
179 }
180
181 fn timed_out(&self) -> l337::Error<Self::Error> {
182 unimplemented!()
183 }
185}
186
187impl<T> fmt::Debug for PostgresConnectionManager<T>
188where
189 T: 'static + MakeTlsConnect<Socket> + Clone + Send + Sync,
190{
191 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
192 f.debug_struct("PostgresConnectionManager")
193 .field("config", &self.config)
194 .finish()
195 }
196}
197
#[cfg(test)]
mod tests {
    //! Integration tests; they require a PostgreSQL server reachable at
    //! `DATABASE_URL`.

    use super::*;
    use l337::{Config, Pool};
    use std::time::Duration;
    use tokio::time::sleep;

    /// Connection string of the PostgreSQL instance these tests run against.
    const DATABASE_URL: &str = "postgres://pass_user:password@localhost:5433/postgres";

    /// Builds a plain-text (no TLS) manager for the test database.
    fn manager() -> PostgresConnectionManager<tokio_postgres::NoTls> {
        PostgresConnectionManager::new(DATABASE_URL.parse().unwrap(), tokio_postgres::NoTls)
    }

    /// Prepares and runs `sql`, asserting it yields exactly one row whose first
    /// column equals `expected`.
    ///
    /// The row count is checked explicitly. A bare `for row in rows` loop would
    /// pass vacuously on an empty result.
    async fn assert_single_int(client: &tokio_postgres::Client, sql: &str, expected: i32) {
        let statement = client.prepare(sql).await.unwrap();
        let rows = client.query(&statement, &[]).await.unwrap();

        assert_eq!(rows.len(), 1);
        let value: i32 = rows[0].get(0);
        assert_eq!(value, expected);
    }

    #[tokio::test]
    async fn it_works() {
        let pool = Pool::new(manager(), Config::default()).await.unwrap();
        let conn = pool.connection().await.unwrap();

        assert_single_int(&conn, "SELECT 1::INT4", 1).await;
    }

    #[tokio::test]
    async fn it_allows_multiple_queries_at_the_same_time() {
        let pool = Pool::new(manager(), Config::default()).await.unwrap();

        // Each future returns its connection, so both stay checked out until
        // `join!` finishes and the two run against distinct connections.
        let q1 = async {
            let conn = pool.connection().await.unwrap();
            assert_single_int(&conn, "SELECT 1::INT4", 1).await;

            sleep(Duration::from_secs(5)).await;

            conn
        };

        let q2 = async {
            let conn = pool.connection().await.unwrap();
            assert_single_int(&conn, "SELECT 2::INT4", 2).await;

            sleep(Duration::from_secs(5)).await;

            conn
        };

        futures::join!(q1, q2);
    }

    #[tokio::test]
    async fn it_reuses_connections() {
        let pool = Pool::new(manager(), Config::default()).await.unwrap();

        let q1 = async {
            let conn = pool.connection().await.unwrap();
            assert_single_int(&conn, "SELECT 1::INT4", 1).await;
        };

        q1.await;

        // Give the pool time to take back the connection released by `q1`.
        // This assumes the pool reclaims it asynchronously (TODO confirm).
        sleep(Duration::from_millis(500)).await;

        let q2 = async {
            let conn = pool.connection().await.unwrap();
            assert_single_int(&conn, "SELECT 2::INT4", 2).await;
        };

        let q3 = async {
            let conn = pool.connection().await.unwrap();
            assert_single_int(&conn, "SELECT 3::INT4", 3).await;
        };

        futures::join!(q2, q3);

        // One connection reused from `q1` plus one new one for the concurrent query.
        assert_eq!(pool.total_conns(), 2);
    }
}