tank_postgres/
connection.rs

1use crate::{
2    PostgresDriver, PostgresPrepared, PostgresTransaction, ValueWrap,
3    util::{
4        stream_postgres_row_to_tank_row, stream_postgres_simple_query_message_to_tank_query_result,
5    },
6};
7use async_stream::try_stream;
8use openssl::ssl::{SslConnector, SslFiletype, SslMethod, SslVerifyMode};
9use postgres_openssl::MakeTlsConnector;
10use std::{borrow::Cow, env, path::PathBuf, pin::pin, str::FromStr, sync::Arc};
11use tank_core::{
12    Connection, Driver, Error, ErrorContext, Executor, Query, QueryResult, Result, Transaction,
13    future::Either,
14    stream::{Stream, StreamExt, TryStreamExt},
15    truncate_long,
16};
17use tokio::{spawn, task::JoinHandle};
18use tokio_postgres::NoTls;
19use url::Url;
20use urlencoding::decode;
21
22#[derive(Debug)]
23pub struct PostgresConnection {
24    pub(crate) client: tokio_postgres::Client,
25    pub(crate) handle: JoinHandle<()>,
26    pub(crate) _transaction: bool,
27}
28
29impl Executor for PostgresConnection {
30    type Driver = PostgresDriver;
31
32    fn driver(&self) -> &Self::Driver {
33        &PostgresDriver {}
34    }
35
36    async fn prepare(&mut self, sql: String) -> Result<Query<Self::Driver>> {
37        let sql = sql.trim_end().trim_end_matches(';');
38        Ok(
39            PostgresPrepared::new(self.client.prepare(&sql).await.map_err(|e| {
40                let e = Error::new(e).context(format!(
41                    "While preparing the query:\n{}",
42                    truncate_long!(sql)
43                ));
44                log::error!("{:#}", e);
45                e
46            })?)
47            .into(),
48        )
49    }
50
51    fn run(
52        &mut self,
53        query: Query<Self::Driver>,
54    ) -> impl Stream<Item = Result<QueryResult>> + Send {
55        let context = Arc::new(format!("While running the query:\n{}", query));
56        match query {
57            Query::Raw(sql) => {
58                Either::Left(stream_postgres_simple_query_message_to_tank_query_result(
59                    async move || self.client.simple_query_raw(&sql).await.map_err(Into::into),
60                ))
61            }
62            Query::Prepared(..) => Either::Right(try_stream! {
63                let mut transaction = self.begin().await?;
64                {
65                    let mut stream = pin!(transaction.run(query));
66                    while let Some(value) = stream.next().await.transpose()? {
67                        yield value;
68                    }
69                }
70                transaction.commit().await?;
71            }),
72        }
73        .map_err(move |e: Error| {
74            let e = e.context(context.clone());
75            log::error!("{:#}", e);
76            e
77        })
78    }
79
80    fn fetch<'s>(
81        &'s mut self,
82        query: Query<Self::Driver>,
83    ) -> impl Stream<Item = Result<tank_core::RowLabeled>> + Send + 's {
84        let context = Arc::new(format!("While fetching the query:\n{}", query));
85        match query {
86            Query::Raw(sql) => Either::Left(stream_postgres_row_to_tank_row(async move || {
87                self.client
88                    .query_raw(&sql, Vec::<ValueWrap>::new())
89                    .await
90                    .map_err(|e| {
91                        let e = Error::new(e).context(context.clone());
92                        log::error!("{:#}", e);
93                        e
94                    })
95            })),
96            Query::Prepared(..) => Either::Right(
97                try_stream! {
98                    let mut transaction = self.begin().await?;
99                    {
100                        let mut stream = pin!(transaction.fetch(query));
101                        while let Some(value) = stream.next().await.transpose()? {
102                            yield value;
103                        }
104                    }
105                    transaction.commit().await?;
106                }
107                .map_err(move |e: Error| {
108                    let e = e.context(context.clone());
109                    log::error!("{:#}", e);
110                    e
111                }),
112            ),
113        }
114    }
115}
116
117impl Connection for PostgresConnection {
118    #[allow(refining_impl_trait)]
119    async fn connect(url: Cow<'static, str>) -> Result<PostgresConnection> {
120        let context = || format!("While trying to connect to `{}`", url);
121        let url = decode(&url).with_context(context)?;
122        let prefix = format!("{}://", <Self::Driver as Driver>::NAME);
123        if !url.starts_with(&prefix) {
124            let error = Error::msg(format!(
125                "Postgres connection url must start with `{}`",
126                &prefix
127            ))
128            .context(context());
129            log::error!("{:#}", error);
130            return Err(error);
131        }
132        let mut url = Url::parse(&url).with_context(context)?;
133        let mut take_url_param = |key: &str, env_var: &str, remove: bool| {
134            let value = url
135                .query_pairs()
136                .find_map(|(k, v)| if k == key { Some(v) } else { None })
137                .map(|v| v.to_string());
138            if remove && let Some(..) = value {
139                let mut result = url.clone();
140                result.set_query(None);
141                result
142                    .query_pairs_mut()
143                    .extend_pairs(url.query_pairs().filter(|(k, _)| k != key));
144                url = result;
145            };
146            value.or_else(|| env::var(env_var).ok().map(Into::into))
147        };
148        let sslmode = take_url_param("sslmode", "PGSSLMODE", false).unwrap_or("disable".into());
149        let (client, handle) = if sslmode == "disable" {
150            let (client, connection) = tokio_postgres::connect(url.as_str(), NoTls).await?;
151            let handle = spawn(async move {
152                if let Err(e) = connection.await
153                    && !e.is_closed()
154                {
155                    log::error!("Postgres connection error: {:#}", e);
156                }
157            });
158            (client, handle)
159        } else {
160            let mut builder = SslConnector::builder(SslMethod::tls())?;
161            let path = PathBuf::from_str(
162                take_url_param("sslrootcert", "PGSSLROOTCERT", true)
163                    .as_deref()
164                    .unwrap_or("~/.postgresql/root.crt"),
165            )
166            .context(context())?;
167            if path.exists() {
168                builder.set_ca_file(path)?;
169            }
170            let path = PathBuf::from_str(
171                take_url_param("sslcert", "PGSSLCERT", true)
172                    .as_deref()
173                    .unwrap_or("~/.postgresql/postgresql.crt"),
174            )
175            .context(context())?;
176            if path.exists() {
177                builder.set_certificate_chain_file(path)?;
178            }
179            let path = PathBuf::from_str(
180                take_url_param("sslkey", "PGSSLKEY", true)
181                    .as_deref()
182                    .unwrap_or("~/.postgresql/postgresql.key"),
183            )
184            .context(context())?;
185            if path.exists() {
186                builder.set_private_key_file(path, SslFiletype::PEM)?;
187            }
188            builder.set_verify(SslVerifyMode::PEER);
189            let connector = MakeTlsConnector::new(builder.build());
190            let (client, connection) = tokio_postgres::connect(url.as_str(), connector).await?;
191            let handle = spawn(async move {
192                if let Err(e) = connection.await
193                    && !e.is_closed()
194                {
195                    log::error!("Postgres connection error: {:#}", e);
196                }
197            });
198            (client, handle)
199        };
200        Ok(Self {
201            client,
202            handle,
203            _transaction: false,
204        })
205    }
206
207    #[allow(refining_impl_trait)]
208    fn begin(&mut self) -> impl Future<Output = Result<PostgresTransaction<'_>>> {
209        PostgresTransaction::new(self)
210    }
211
212    #[allow(refining_impl_trait)]
213    async fn disconnect(self) -> Result<()> {
214        drop(self.client);
215        if let Err(e) = self.handle.await {
216            let e = Error::new(e).context("While disconnecting from Postgres");
217            log::error!("{:#}", e);
218            return Err(e);
219        }
220        Ok(())
221    }
222}