tank_postgres/
connection.rs

1use crate::{
2    PostgresDriver, PostgresPrepared, PostgresTransaction, ValueWrap, postgres_type_to_value,
3    util::{
4        stream_postgres_row_to_tank_row, stream_postgres_simple_query_message_to_tank_query_result,
5    },
6};
7use async_stream::try_stream;
8use openssl::ssl::{SslConnector, SslFiletype, SslMethod, SslVerifyMode};
9use postgres_openssl::MakeTlsConnector;
10use std::{borrow::Cow, env, mem, path::PathBuf, pin::pin, str::FromStr, sync::Arc};
11use tank_core::{
12    AsQuery, Connection, Driver, Error, ErrorContext, Executor, Query, QueryResult, Result,
13    Transaction,
14    future::Either,
15    stream::{Stream, StreamExt, TryStreamExt},
16    truncate_long,
17};
18use tokio::{spawn, task::JoinHandle};
19use tokio_postgres::NoTls;
20use url::Url;
21use urlencoding::decode;
22
23#[derive(Debug)]
24pub struct PostgresConnection {
25    pub(crate) client: tokio_postgres::Client,
26    pub(crate) handle: JoinHandle<()>,
27    pub(crate) _transaction: bool,
28}
29
30impl Executor for PostgresConnection {
31    type Driver = PostgresDriver;
32
33    fn driver(&self) -> &Self::Driver {
34        &PostgresDriver {}
35    }
36
37    async fn prepare(&mut self, sql: String) -> Result<Query<Self::Driver>> {
38        let sql = sql.trim_end().trim_end_matches(';');
39        Ok(
40            PostgresPrepared::new(self.client.prepare(&sql).await.map_err(|e| {
41                let error = Error::new(e).context(format!(
42                    "While preparing the query:\n{}",
43                    truncate_long!(sql)
44                ));
45                log::error!("{:#}", error);
46                error
47            })?)
48            .into(),
49        )
50    }
51
52    fn run<'s>(
53        &'s mut self,
54        query: impl AsQuery<Self::Driver> + 's,
55    ) -> impl Stream<Item = Result<QueryResult>> + Send {
56        let mut query = query.as_query();
57        let context = Arc::new(format!("While running the query:\n{}", query.as_mut()));
58        let mut owned = mem::take(query.as_mut());
59        match owned {
60            Query::Raw(sql) => {
61                Either::Left(stream_postgres_simple_query_message_to_tank_query_result(
62                    async move || self.client.simple_query_raw(&sql).await.map_err(Into::into),
63                ))
64            }
65            Query::Prepared(..) => Either::Right(try_stream! {
66                let mut transaction = self.begin().await?;
67                {
68                    let mut stream = pin!(transaction.run(&mut owned));
69                    while let Some(value) = stream.next().await.transpose()? {
70                        yield value;
71                    }
72                }
73                transaction.commit().await?;
74                *query.as_mut() = mem::take(&mut owned);
75            }),
76        }
77        .map_err(move |e: Error| {
78            let error = e.context(context.clone());
79            log::error!("{:#}", error);
80            error
81        })
82    }
83
84    fn fetch<'s>(
85        &'s mut self,
86        query: impl AsQuery<Self::Driver> + 's,
87    ) -> impl Stream<Item = Result<tank_core::RowLabeled>> + Send + 's {
88        let mut query = query.as_query();
89        let context = Arc::new(format!("While fetching the query:\n{}", query.as_mut()));
90        let owned = mem::take(query.as_mut());
91        stream_postgres_row_to_tank_row(async move || {
92            let row_stream = match owned {
93                Query::Raw(mut sql) => {
94                    let stream = self
95                        .client
96                        .query_raw(&sql, Vec::<ValueWrap>::new())
97                        .await
98                        .map_err(|e| Error::new(e).context(context.clone()))?;
99                    *query.as_mut() = Query::Raw(mem::take(&mut sql));
100                    stream
101                }
102                Query::Prepared(mut prepared) => {
103                    let mut params = mem::take(&mut prepared.params);
104                    let types = prepared.statement.params();
105
106                    for (i, param) in params.iter_mut().enumerate() {
107                        *param = ValueWrap(
108                            mem::take(&mut param.0).try_as(&postgres_type_to_value(&types[i]))?,
109                        );
110                    }
111                    let stream = self
112                        .client
113                        .query_raw(&prepared.statement, params)
114                        .await
115                        .map_err(|e| Error::new(e).context(context.clone()))?;
116                    *query.as_mut() = Query::Prepared(prepared);
117                    stream
118                }
119            };
120            Ok(row_stream).map_err(|e| {
121                log::error!("{:#}", e);
122                e
123            })
124        })
125    }
126}
127
128impl Connection for PostgresConnection {
129    #[allow(refining_impl_trait)]
130    async fn connect(url: Cow<'static, str>) -> Result<PostgresConnection> {
131        let context = || format!("While trying to connect to `{}`", truncate_long!(url));
132        let url = decode(&url).with_context(context)?;
133        let prefix = format!("{}://", <Self::Driver as Driver>::NAME);
134        if !url.starts_with(&prefix) {
135            let error = Error::msg(format!(
136                "Postgres connection url must start with `{}`",
137                &prefix
138            ))
139            .context(context());
140            log::error!("{:#}", error);
141            return Err(error);
142        }
143        let mut url = Url::parse(&url).with_context(context)?;
144        let mut take_url_param = |key: &str, env_var: &str, remove: bool| {
145            let value = url
146                .query_pairs()
147                .find_map(|(k, v)| if k == key { Some(v) } else { None })
148                .map(|v| v.to_string());
149            if remove && let Some(..) = value {
150                let mut result = url.clone();
151                result.set_query(None);
152                result
153                    .query_pairs_mut()
154                    .extend_pairs(url.query_pairs().filter(|(k, _)| k != key));
155                url = result;
156            };
157            value.or_else(|| env::var(env_var).ok().map(Into::into))
158        };
159        let sslmode = take_url_param("sslmode", "PGSSLMODE", false).unwrap_or("disable".into());
160        let (client, handle) = if sslmode == "disable" {
161            let (client, connection) = tokio_postgres::connect(url.as_str(), NoTls).await?;
162            let handle = spawn(async move {
163                if let Err(error) = connection.await
164                    && !error.is_closed()
165                {
166                    log::error!("Postgres connection error: {:#?}", error);
167                }
168            });
169            (client, handle)
170        } else {
171            let mut builder = SslConnector::builder(SslMethod::tls())?;
172            let path = PathBuf::from_str(
173                take_url_param("sslrootcert", "PGSSLROOTCERT", true)
174                    .as_deref()
175                    .unwrap_or("~/.postgresql/root.crt"),
176            )
177            .context(context())?;
178            if path.exists() {
179                builder.set_ca_file(path)?;
180            }
181            let path = PathBuf::from_str(
182                take_url_param("sslcert", "PGSSLCERT", true)
183                    .as_deref()
184                    .unwrap_or("~/.postgresql/postgresql.crt"),
185            )
186            .context(context())?;
187            if path.exists() {
188                builder.set_certificate_chain_file(path)?;
189            }
190            let path = PathBuf::from_str(
191                take_url_param("sslkey", "PGSSLKEY", true)
192                    .as_deref()
193                    .unwrap_or("~/.postgresql/postgresql.key"),
194            )
195            .context(context())?;
196            if path.exists() {
197                builder.set_private_key_file(path, SslFiletype::PEM)?;
198            }
199            builder.set_verify(SslVerifyMode::PEER);
200            let connector = MakeTlsConnector::new(builder.build());
201            let (client, connection) = tokio_postgres::connect(url.as_str(), connector).await?;
202            let handle = spawn(async move {
203                if let Err(error) = connection.await
204                    && !error.is_closed()
205                {
206                    log::error!("Postgres connection error: {:#?}", error);
207                }
208            });
209            (client, handle)
210        };
211        Ok(Self {
212            client,
213            handle,
214            _transaction: false,
215        })
216    }
217
218    #[allow(refining_impl_trait)]
219    fn begin(&mut self) -> impl Future<Output = Result<PostgresTransaction<'_>>> {
220        PostgresTransaction::new(self)
221    }
222
223    #[allow(refining_impl_trait)]
224    async fn disconnect(self) -> Result<()> {
225        drop(self.client);
226        if let Err(e) = self.handle.await {
227            let error = Error::new(e).context("While disconnecting from Postgres");
228            log::error!("{:#}", error);
229            return Err(error);
230        }
231        Ok(())
232    }
233}