l337_postgres/
lib.rs

//! Postgres adapter for the l3-37 connection pool.
2// #![deny(missing_docs, missing_debug_implementations)]
3
4use async_trait::async_trait;
5use futures::{channel::oneshot, prelude::*};
6use std::{
7    convert::{AsMut, AsRef},
8    ops::{Deref, DerefMut},
9};
10use tokio::spawn;
11use tokio_postgres::error::Error;
12use tokio_postgres::{
13    tls::{MakeTlsConnect, TlsConnect},
14    Client, Socket,
15};
16use tracing::{debug, debug_span, info, warn, Instrument};
17
18use std::fmt;
19
20pub struct AsyncConnection {
21    pub client: Client,
22    broken: bool,
23    done_rx: oneshot::Receiver<()>,
24    drop_tx: Option<oneshot::Sender<()>>,
25}
26
27// Connections can be dropped when they report an error from is_valid, or return
28// true from has_broken. The channel is used here to ensure that the async
29// driver task spawned in PostgresConnectionManager::connect is ended.
30impl Drop for AsyncConnection {
31    fn drop(&mut self) {
32        // If the receiver is gone here, it means the task is already finished,
33        // and it's no problem.
34        if let Some(drop_tx) = self.drop_tx.take() {
35            let _ = drop_tx.send(());
36        }
37    }
38}
39
40impl Deref for AsyncConnection {
41    type Target = Client;
42
43    fn deref(&self) -> &Self::Target {
44        &self.client
45    }
46}
47
48impl DerefMut for AsyncConnection {
49    fn deref_mut(&mut self) -> &mut Self::Target {
50        &mut self.client
51    }
52}
53
54impl AsMut<Client> for AsyncConnection {
55    fn as_mut(&mut self) -> &mut Client {
56        &mut self.client
57    }
58}
59
60impl AsRef<Client> for AsyncConnection {
61    fn as_ref(&self) -> &Client {
62        &self.client
63    }
64}
65
66/// A `ManageConnection` for `tokio_postgres::Connection`s.
67pub struct PostgresConnectionManager<T>
68where
69    T: 'static + MakeTlsConnect<Socket> + Clone + Send + Sync,
70{
71    config: tokio_postgres::Config,
72    make_tls_connect: T,
73}
74
75impl<T> PostgresConnectionManager<T>
76where
77    T: 'static + MakeTlsConnect<Socket> + Clone + Send + Sync,
78{
79    /// Create a new `PostgresConnectionManager`.
80    pub fn new(config: tokio_postgres::Config, make_tls_connect: T) -> Self {
81        Self {
82            config,
83            make_tls_connect,
84        }
85    }
86}
87
88#[async_trait]
89impl<T> l337::ManageConnection for PostgresConnectionManager<T>
90where
91    T: 'static + MakeTlsConnect<Socket> + Clone + Send + Sync,
92    T::Stream: Send + Sync,
93    T::TlsConnect: Send,
94    <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
95{
96    type Connection = AsyncConnection;
97    type Error = Error;
98
99    async fn connect(&self) -> Result<Self::Connection, l337::Error<Self::Error>> {
100        let (client, connection) = self
101            .config
102            .connect(self.make_tls_connect.clone())
103            .instrument(debug_span!("connect: open new postgres connection"))
104            .await
105            .map_err(|e| l337::Error::External(e))?;
106
107        let (done_tx, done_rx) = oneshot::channel();
108        let (drop_tx, drop_rx) = oneshot::channel();
109        spawn(async move {
110            debug!("connect: start connection future");
111            let connection = connection.fuse();
112            let drop_rx = drop_rx.fuse();
113
114            futures::pin_mut!(connection, drop_rx);
115
116            futures::select! {
117                result = connection => {
118                    if let Err(e) = result {
119                        warn!("future backing postgres future ended with an error: {}", e);
120                    }
121                }
122                _ = drop_rx => { }
123            }
124
125            // If this fails to send, the connection object was already dropped and does not need to be notified
126            let _ = done_tx.send(());
127
128            info!("connect: connection future ended");
129        });
130
131        debug!("connect: postgres connection established");
132        Ok(AsyncConnection {
133            broken: false,
134            client,
135            done_rx,
136            drop_tx: Some(drop_tx),
137        })
138    }
139
140    async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), l337::Error<Self::Error>> {
141        // If we can execute this without erroring, we're definitely still connected to the database
142        conn.simple_query("")
143            .await
144            .map_err(|e| l337::Error::External(e))?;
145
146        Ok(())
147    }
148
149    fn has_broken(&self, conn: &mut Self::Connection) -> bool {
150        if conn.broken {
151            return true;
152        }
153
154        if conn.client.is_closed() {
155            return true;
156        }
157
158        // Use try_recv() as `has_broken` can be called via Drop and not have a
159        // future Context to poll on.
160        // https://docs.rs/futures/0.3.1/futures/channel/oneshot/struct.Receiver.html#method.try_recv
161        match conn.done_rx.try_recv() {
162            // If we get any message, the connection task stopped, which means this connection is
163            // now dead
164            Ok(Some(_)) => {
165                conn.broken = true;
166                true
167            }
168            // If the future isn't ready, then we haven't sent a value which means the future is
169            // still successfully running
170            Ok(None) => false,
171            // This can happen if the future that the connection was
172            // spawned in panicked or was dropped.
173            Err(error) => {
174                warn!(%error, "cannot receive from connection future");
175                conn.broken = true;
176                true
177            }
178        }
179    }
180
181    fn timed_out(&self) -> l337::Error<Self::Error> {
182        unimplemented!()
183        // Error::io(io::ErrorKind::TimedOut.into())
184    }
185}
186
187impl<T> fmt::Debug for PostgresConnectionManager<T>
188where
189    T: 'static + MakeTlsConnect<Socket> + Clone + Send + Sync,
190{
191    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
192        f.debug_struct("PostgresConnectionManager")
193            .field("config", &self.config)
194            .finish()
195    }
196}
197
#[cfg(test)]
mod tests {
    //! Integration tests: these need a Postgres server reachable at
    //! [`DATABASE_URL`].

    use super::*;
    use l337::{Config, Pool};
    use std::time::Duration;
    use tokio::time::sleep;

    /// Server the tests connect to.
    const DATABASE_URL: &str = "postgres://pass_user:password@localhost:5433/postgres";

    /// Builds a manager for [`DATABASE_URL`] without TLS.
    fn manager() -> PostgresConnectionManager<tokio_postgres::NoTls> {
        PostgresConnectionManager::new(
            DATABASE_URL
                .parse()
                .expect("DATABASE_URL is a valid postgres connection string"),
            tokio_postgres::NoTls,
        )
    }

    /// Runs `SELECT <value>::INT4` on `client` and asserts that exactly one
    /// row comes back holding `value`.
    ///
    /// The row count is checked explicitly so that an empty result set fails
    /// instead of passing vacuously.
    async fn assert_selects(client: &Client, value: i32) {
        let select = client
            .prepare(&format!("SELECT {}::INT4", value))
            .await
            .unwrap();
        let rows = client.query(&select, &[]).await.unwrap();

        assert_eq!(rows.len(), 1);
        let got: i32 = rows[0].get(0);
        assert_eq!(got, value);
    }

    #[tokio::test]
    async fn it_works() {
        let pool = Pool::new(manager(), Config::default()).await.unwrap();
        let conn = pool.connection().await.unwrap();

        assert_selects(&conn, 1).await;
    }

    #[tokio::test]
    async fn it_allows_multiple_queries_at_the_same_time() {
        let pool = Pool::new(manager(), Config::default()).await.unwrap();

        // Each future returns its connection, so both stay checked out until
        // `join!` completes.
        let q1 = async {
            let conn = pool.connection().await.unwrap();
            assert_selects(&conn, 1).await;
            sleep(Duration::from_secs(5)).await;
            conn
        };

        let q2 = async {
            let conn = pool.connection().await.unwrap();
            assert_selects(&conn, 2).await;
            sleep(Duration::from_secs(5)).await;
            conn
        };

        futures::join!(q1, q2);
    }

    #[tokio::test]
    async fn it_reuses_connections() {
        let pool = Pool::new(manager(), Config::default()).await.unwrap();

        let q1 = async {
            let conn = pool.connection().await.unwrap();
            assert_selects(&conn, 1).await;
        };

        q1.await;

        // This delay gives the pool time to take the connection back after it
        // is dropped. NOTE(review): this relies on the pool returning
        // connections asynchronously on drop.
        sleep(Duration::from_millis(500)).await;

        let q2 = async {
            let conn = pool.connection().await.unwrap();
            assert_selects(&conn, 2).await;
        };

        let q3 = async {
            let conn = pool.connection().await.unwrap();
            assert_selects(&conn, 3).await;
        };

        futures::join!(q2, q3);

        assert_eq!(pool.total_conns(), 2);
    }
}