tank_postgres/
connection.rs

1use crate::{
2    PostgresDriver, PostgresPrepared, PostgresTransaction, ValueWrap,
3    util::{row_to_tank_row, stream_postgres_simple_query_message_to_tank_query_result},
4};
5use async_stream::try_stream;
6use openssl::ssl::{SslConnector, SslFiletype, SslMethod, SslVerifyMode};
7use postgres_openssl::MakeTlsConnector;
8use std::{borrow::Cow, env, mem, path::PathBuf, pin::pin, str::FromStr, sync::Arc};
9use tank_core::{
10    AsQuery, Connection, Driver, Error, ErrorContext, Executor, Query, QueryResult, Result,
11    Transaction,
12    future::Either,
13    stream::{Stream, StreamExt, TryStreamExt},
14    truncate_long,
15};
16use tokio::{spawn, task::JoinHandle};
17use tokio_postgres::NoTls;
18use url::Url;
19use urlencoding::decode;
20
21#[derive(Debug)]
22pub struct PostgresConnection {
23    pub(crate) client: tokio_postgres::Client,
24    pub(crate) handle: JoinHandle<()>,
25    pub(crate) _transaction: bool,
26}
27
28impl Executor for PostgresConnection {
29    type Driver = PostgresDriver;
30
31    fn driver(&self) -> &Self::Driver {
32        &PostgresDriver {}
33    }
34
35    async fn prepare(&mut self, sql: String) -> Result<Query<Self::Driver>> {
36        let sql = sql.trim_end().trim_end_matches(';');
37        Ok(
38            PostgresPrepared::new(self.client.prepare(&sql).await.map_err(|e| {
39                let error = Error::new(e).context(format!(
40                    "While preparing the query:\n{}",
41                    truncate_long!(sql)
42                ));
43                log::error!("{:#}", error);
44                error
45            })?)
46            .into(),
47        )
48    }
49
50    fn run<'s>(
51        &'s mut self,
52        query: impl AsQuery<Self::Driver> + 's,
53    ) -> impl Stream<Item = Result<QueryResult>> + Send {
54        let mut query = query.as_query();
55        let context = Arc::new(format!("While running the query:\n{}", query.as_mut()));
56        let mut owned = mem::take(query.as_mut());
57        match owned {
58            Query::Raw(sql) => {
59                Either::Left(stream_postgres_simple_query_message_to_tank_query_result(
60                    async move || self.client.simple_query_raw(&sql).await.map_err(Into::into),
61                ))
62            }
63            Query::Prepared(..) => Either::Right(try_stream! {
64                let mut transaction = self.begin().await?;
65                {
66                    let mut stream = pin!(transaction.run(&mut owned));
67                    while let Some(value) = stream.next().await.transpose()? {
68                        yield value;
69                    }
70                }
71                transaction.commit().await?;
72                *query.as_mut() = mem::take(&mut owned);
73            }),
74        }
75        .map_err(move |e: Error| {
76            let error = e.context(context.clone());
77            log::error!("{:#}", error);
78            error
79        })
80    }
81
82    fn fetch<'s>(
83        &'s mut self,
84        query: impl AsQuery<Self::Driver> + 's,
85    ) -> impl Stream<Item = Result<tank_core::RowLabeled>> + Send + 's {
86        let mut query = query.as_query();
87        let context = Arc::new(format!("While fetching the query:\n{}", query.as_mut()));
88        let owned = mem::take(query.as_mut());
89        match owned {
90            Query::Raw(sql) => {
91                let stream = async move || {
92                    self.client
93                        .query_raw(&sql, Vec::<ValueWrap>::new())
94                        .await
95                        .map_err(|e| {
96                            let error = Error::new(e).context(context.clone());
97                            log::error!("{:#}", error);
98                            error
99                        })
100                };
101                let stream = try_stream! {
102                    let stream = stream().await?;
103                    let mut stream = pin!(stream);
104                    let mut labels: Option<tank_core::RowNames> = None;
105                    while let Some(row) = stream.next().await.transpose()? {
106                        let labels = labels.get_or_insert_with(|| {
107                            row.columns().iter().map(|c| c.name().to_string()).collect()
108                        });
109                        yield tank_core::RowLabeled {
110                            labels: labels.clone(),
111                            values: row_to_tank_row(row)?.into(),
112                        };
113                    }
114                };
115                Either::Left(stream)
116            }
117            Query::Prepared(prepared) => Either::Right(
118                try_stream! {
119                    let mut transaction = self.begin().await?;
120                    {
121                        let mut query = Query::Prepared(prepared);
122                        let mut stream = pin!(transaction.fetch(&mut query));
123                        while let Some(value) = stream.next().await.transpose()? {
124                            yield value;
125                        }
126                    }
127                    transaction.commit().await?;
128                }
129                .map_err(move |e: Error| {
130                    let error = e.context(context.clone());
131                    log::error!("{:#}", error);
132                    error
133                }),
134            ),
135        }
136    }
137}
138
139impl Connection for PostgresConnection {
140    #[allow(refining_impl_trait)]
141    async fn connect(url: Cow<'static, str>) -> Result<PostgresConnection> {
142        let context = || format!("While trying to connect to `{}`", truncate_long!(url));
143        let url = decode(&url).with_context(context)?;
144        let prefix = format!("{}://", <Self::Driver as Driver>::NAME);
145        if !url.starts_with(&prefix) {
146            let error = Error::msg(format!(
147                "Postgres connection url must start with `{}`",
148                &prefix
149            ))
150            .context(context());
151            log::error!("{:#}", error);
152            return Err(error);
153        }
154        let mut url = Url::parse(&url).with_context(context)?;
155        let mut take_url_param = |key: &str, env_var: &str, remove: bool| {
156            let value = url
157                .query_pairs()
158                .find_map(|(k, v)| if k == key { Some(v) } else { None })
159                .map(|v| v.to_string());
160            if remove && let Some(..) = value {
161                let mut result = url.clone();
162                result.set_query(None);
163                result
164                    .query_pairs_mut()
165                    .extend_pairs(url.query_pairs().filter(|(k, _)| k != key));
166                url = result;
167            };
168            value.or_else(|| env::var(env_var).ok().map(Into::into))
169        };
170        let sslmode = take_url_param("sslmode", "PGSSLMODE", false).unwrap_or("disable".into());
171        let (client, handle) = if sslmode == "disable" {
172            let (client, connection) = tokio_postgres::connect(url.as_str(), NoTls).await?;
173            let handle = spawn(async move {
174                if let Err(error) = connection.await
175                    && !error.is_closed()
176                {
177                    log::error!("Postgres connection error: {:#?}", error);
178                }
179            });
180            (client, handle)
181        } else {
182            let mut builder = SslConnector::builder(SslMethod::tls())?;
183            let path = PathBuf::from_str(
184                take_url_param("sslrootcert", "PGSSLROOTCERT", true)
185                    .as_deref()
186                    .unwrap_or("~/.postgresql/root.crt"),
187            )
188            .context(context())?;
189            if path.exists() {
190                builder.set_ca_file(path)?;
191            }
192            let path = PathBuf::from_str(
193                take_url_param("sslcert", "PGSSLCERT", true)
194                    .as_deref()
195                    .unwrap_or("~/.postgresql/postgresql.crt"),
196            )
197            .context(context())?;
198            if path.exists() {
199                builder.set_certificate_chain_file(path)?;
200            }
201            let path = PathBuf::from_str(
202                take_url_param("sslkey", "PGSSLKEY", true)
203                    .as_deref()
204                    .unwrap_or("~/.postgresql/postgresql.key"),
205            )
206            .context(context())?;
207            if path.exists() {
208                builder.set_private_key_file(path, SslFiletype::PEM)?;
209            }
210            builder.set_verify(SslVerifyMode::PEER);
211            let connector = MakeTlsConnector::new(builder.build());
212            let (client, connection) = tokio_postgres::connect(url.as_str(), connector).await?;
213            let handle = spawn(async move {
214                if let Err(error) = connection.await
215                    && !error.is_closed()
216                {
217                    log::error!("Postgres connection error: {:#?}", error);
218                }
219            });
220            (client, handle)
221        };
222        Ok(Self {
223            client,
224            handle,
225            _transaction: false,
226        })
227    }
228
229    #[allow(refining_impl_trait)]
230    fn begin(&mut self) -> impl Future<Output = Result<PostgresTransaction<'_>>> {
231        PostgresTransaction::new(self)
232    }
233
234    #[allow(refining_impl_trait)]
235    async fn disconnect(self) -> Result<()> {
236        drop(self.client);
237        if let Err(e) = self.handle.await {
238            let error = Error::new(e).context("While disconnecting from Postgres");
239            log::error!("{:#}", error);
240            return Err(error);
241        }
242        Ok(())
243    }
244}