Skip to main content

tank_postgres/
connection.rs

1use crate::{
2    PostgresDriver, PostgresPrepared, PostgresTransaction, ValueWrap,
3    util::{
4        postgres_type_to_value, stream_postgres_row_to_tank_row,
5        stream_postgres_simple_query_message_to_tank_query_result, value_to_postgres_type,
6    },
7};
8use async_stream::try_stream;
9use openssl::ssl::{SslConnector, SslFiletype, SslMethod, SslVerifyMode};
10use postgres_openssl::MakeTlsConnector;
11use postgres_types::ToSql;
12use std::{
13    borrow::Cow,
14    env, mem,
15    path::PathBuf,
16    pin::{Pin, pin},
17    str::FromStr,
18};
19use tank_core::{
20    AsQuery, Connection, Driver, DynQuery, Entity, Error, ErrorContext, Executor, Query,
21    QueryResult, RawQuery, Result, RowsAffected, SqlWriter, Transaction,
22    future::Either,
23    stream::{Stream, StreamExt, TryStreamExt},
24    truncate_long,
25};
26use tokio::{spawn, task::JoinHandle};
27use tokio_postgres::{NoTls, binary_copy::BinaryCopyInWriter};
28
/// PostgreSQL connection.
///
/// Bundles the `tokio_postgres` client with the handle of the spawned task that awaits the
/// corresponding `tokio_postgres` connection future (see `Connection::connect`).
#[derive(Debug)]
pub struct PostgresConnection {
    /// Client used to prepare statements, run raw/prepared queries and `COPY ... FROM` data.
    pub(crate) client: tokio_postgres::Client,
    /// Handle of the task spawned in `connect` that awaits the connection future and logs
    /// non-"closed" errors; `disconnect` drops `client` and then awaits this handle.
    pub(crate) handle: JoinHandle<()>,
    /// Initialised to `false` by `connect`. NOTE(review): not read anywhere in this file —
    /// presumably flags connections owned by a transaction; confirm against its users.
    pub(crate) _transaction: bool,
}
36
37impl Executor for PostgresConnection {
38    type Driver = PostgresDriver;
39
40    async fn do_prepare(&mut self, sql: String) -> Result<Query<PostgresDriver>> {
41        let sql = sql.as_str().trim_end().trim_end_matches(';');
42        Ok(
43            PostgresPrepared::new(self.client.prepare(&sql).await.map_err(|e| {
44                let error = Error::new(e).context(format!(
45                    "While preparing the query:\n{}",
46                    truncate_long!(sql)
47                ));
48                log::error!("{:#}", error);
49                error
50            })?)
51            .into(),
52        )
53    }
54
55    fn run<'s>(
56        &'s mut self,
57        query: impl AsQuery<PostgresDriver> + 's,
58    ) -> impl Stream<Item = Result<QueryResult>> + Send {
59        let mut query = query.as_query();
60        let context = format!("While running the query:\n{}", query.as_mut());
61        let mut owned = mem::take(query.as_mut());
62        match owned {
63            Query::Raw(raw) => Either::Left(try_stream! {
64                let sql = &raw.0;
65                {
66                    let stream = stream_postgres_simple_query_message_to_tank_query_result(
67                        async move || self.client.simple_query_raw(sql).await.map_err(Into::into),
68                    );
69                    let mut stream = pin!(stream);
70                    while let Some(value) = stream.next().await.transpose()? {
71                        yield value;
72                    }
73                }
74                *query.as_mut() = Query::Raw(raw);
75            }),
76            Query::Prepared(..) => Either::Right(try_stream! {
77                let mut transaction = self.begin().await?;
78                {
79                    let mut stream = pin!(transaction.run(&mut owned));
80                    while let Some(value) = stream.next().await.transpose()? {
81                        yield value;
82                    }
83                }
84                transaction.commit().await?;
85                *query.as_mut() = mem::take(&mut owned);
86            }),
87        }
88        .map_err(move |e: Error| {
89            let error = e.context(context.clone());
90            log::error!("{:#}", error);
91            error
92        })
93    }
94
95    fn fetch<'s>(
96        &'s mut self,
97        query: impl AsQuery<PostgresDriver> + 's,
98    ) -> impl Stream<Item = Result<tank_core::RowLabeled>> + Send {
99        let mut query = query.as_query();
100        let context = format!("While fetching the query:\n{}", query.as_mut());
101        let owned = mem::take(query.as_mut());
102        stream_postgres_row_to_tank_row(async move || {
103            let row_stream = match owned {
104                Query::Raw(RawQuery(sql)) => {
105                    let stream = self
106                        .client
107                        .query_raw(&sql, Vec::<ValueWrap>::new())
108                        .await
109                        .map_err(|e| Error::new(e).context(context.clone()))?;
110                    *query.as_mut() = Query::raw(sql);
111                    stream
112                }
113                Query::Prepared(mut prepared) => {
114                    let mut params = prepared.take_params();
115                    let types = prepared.statement.params();
116
117                    for (i, param) in params.iter_mut().enumerate() {
118                        *param = ValueWrap(Cow::Owned(
119                            mem::take(param)
120                                .0
121                                .into_owned()
122                                .try_as(&postgres_type_to_value(&types[i]))?,
123                        ));
124                    }
125                    let stream = self
126                        .client
127                        .query_raw(&prepared.statement, params)
128                        .await
129                        .map_err(|e| Error::new(e).context(context.clone()))?;
130                    *query.as_mut() = Query::Prepared(prepared);
131                    stream
132                }
133            };
134            Ok(row_stream).map_err(|e| {
135                log::error!("{:#}", e);
136                e
137            })
138        })
139    }
140
141    async fn append<'a, E, It>(&mut self, entities: It) -> Result<RowsAffected>
142    where
143        E: Entity + 'a,
144        It: IntoIterator<Item = &'a E> + Send,
145        <It as IntoIterator>::IntoIter: Send,
146    {
147        let writer = self.driver().sql_writer();
148        let context = || {
149            format!(
150                "While appending to the table `{}`",
151                E::table().full_name(writer.separator())
152            )
153        };
154        let mut query = DynQuery::default();
155        writer.write_copy::<E>(&mut query);
156        let sink = match self
157            .client
158            .copy_in(&query.as_str() as &str)
159            .await
160            .with_context(context)
161        {
162            Ok(v) => v,
163            Err(e) => {
164                log::error!("{e:#}");
165                return Err(e);
166            }
167        };
168        let types: Vec<_> = E::columns()
169            .into_iter()
170            .map(|c| value_to_postgres_type(&c.value))
171            .collect();
172        let writer = BinaryCopyInWriter::new(sink, &types);
173        let mut writer = pin!(writer);
174        let columns_len = E::columns().len();
175        let mut values = Vec::<ValueWrap>::with_capacity(columns_len);
176        let mut refs = Vec::<&(dyn ToSql + Sync)>::with_capacity(columns_len);
177        for entity in entities.into_iter() {
178            values.extend(
179                entity
180                    .row_full()
181                    .into_iter()
182                    .map(|v| ValueWrap(Cow::Owned(v))),
183            );
184            refs.extend(
185                values
186                    .iter()
187                    .map(|v| unsafe { &*(v as &(dyn ToSql + Sync) as *const _) }),
188            );
189            match Pin::as_mut(&mut writer)
190                .write(&refs)
191                .await
192                .with_context(context)
193            {
194                Ok(_) => {}
195                Err(e) => {
196                    log::error!("{e:#}");
197                    return Err(e);
198                }
199            };
200            refs.clear();
201            values.clear();
202        }
203        match writer.finish().await.with_context(context) {
204            Ok(v) => Ok(RowsAffected {
205                rows_affected: Some(v),
206                last_affected_id: None,
207            }),
208            Err(e) => {
209                log::error!("{e:#}");
210                return Err(e);
211            }
212        }
213    }
214}
215
216impl Connection for PostgresConnection {
217    async fn connect(url: Cow<'static, str>) -> Result<Self> {
218        let context = format!("While trying to connect to `{}`", truncate_long!(url));
219        let mut url = Self::sanitize_url(url)?;
220        let mut take_url_param = |key: &str, env_var: &str, remove: bool| {
221            let value = url
222                .query_pairs()
223                .find_map(|(k, v)| if k == key { Some(v) } else { None })
224                .map(|v| v.to_string());
225            if remove && let Some(..) = value {
226                let mut result = url.clone();
227                result.set_query(None);
228                result
229                    .query_pairs_mut()
230                    .extend_pairs(url.query_pairs().filter(|(k, _)| k != key));
231                url = result;
232            };
233            value.or_else(|| env::var(env_var).ok().map(Into::into))
234        };
235        let sslmode = take_url_param("sslmode", "PGSSLMODE", false).unwrap_or("disable".into());
236        let (client, handle) = if sslmode == "disable" {
237            let (client, connection) = tokio_postgres::connect(url.as_str(), NoTls).await?;
238            let handle = spawn(async move {
239                if let Err(error) = connection.await
240                    && !error.is_closed()
241                {
242                    log::error!("Postgres connection error: {:#?}", error);
243                }
244            });
245            (client, handle)
246        } else {
247            let mut builder = SslConnector::builder(SslMethod::tls())?;
248            let path = PathBuf::from_str(
249                take_url_param("sslrootcert", "PGSSLROOTCERT", true)
250                    .as_deref()
251                    .unwrap_or("~/.postgresql/root.crt"),
252            )
253            .with_context(|| context.clone())?;
254            if path.exists() {
255                builder.set_ca_file(path)?;
256            }
257            let path = PathBuf::from_str(
258                take_url_param("sslcert", "PGSSLCERT", true)
259                    .as_deref()
260                    .unwrap_or("~/.postgresql/postgresql.crt"),
261            )
262            .with_context(|| context.clone())?;
263            if path.exists() {
264                builder.set_certificate_chain_file(path)?;
265            }
266            let path = PathBuf::from_str(
267                take_url_param("sslkey", "PGSSLKEY", true)
268                    .as_deref()
269                    .unwrap_or("~/.postgresql/postgresql.key"),
270            )
271            .with_context(|| context.clone())?;
272            if path.exists() {
273                builder.set_private_key_file(path, SslFiletype::PEM)?;
274            }
275            builder.set_verify(SslVerifyMode::PEER);
276            let connector = MakeTlsConnector::new(builder.build());
277            let (client, connection) = tokio_postgres::connect(url.as_str(), connector).await?;
278            let handle = spawn(async move {
279                if let Err(error) = connection.await
280                    && !error.is_closed()
281                {
282                    log::error!("Postgres connection error: {:#?}", error);
283                }
284            });
285            (client, handle)
286        };
287        Ok(Self {
288            client,
289            handle,
290            _transaction: false,
291        })
292    }
293
294    fn begin(&mut self) -> impl Future<Output = Result<PostgresTransaction<'_>>> {
295        PostgresTransaction::new(self)
296    }
297
298    async fn disconnect(self) -> Result<()> {
299        drop(self.client);
300        if let Err(e) = self.handle.await {
301            let error = Error::new(e).context("While disconnecting from Postgres");
302            log::error!("{:#}", error);
303            return Err(error);
304        }
305        Ok(())
306    }
307}