1use crate::{
2 PostgresDriver, PostgresPrepared, PostgresTransaction, ValueWrap, postgres_type_to_value,
3 util::{
4 stream_postgres_row_to_tank_row, stream_postgres_simple_query_message_to_tank_query_result,
5 },
6};
7use async_stream::try_stream;
8use openssl::ssl::{SslConnector, SslFiletype, SslMethod, SslVerifyMode};
9use postgres_openssl::MakeTlsConnector;
10use std::{borrow::Cow, env, mem, path::PathBuf, pin::pin, str::FromStr, sync::Arc};
11use tank_core::{
12 AsQuery, Connection, Driver, Error, ErrorContext, Executor, Query, QueryResult, Result,
13 Transaction,
14 future::Either,
15 stream::{Stream, StreamExt, TryStreamExt},
16 truncate_long,
17};
18use tokio::{spawn, task::JoinHandle};
19use tokio_postgres::NoTls;
20use url::Url;
21use urlencoding::decode;
22
/// A single Postgres connection backed by `tokio_postgres`.
#[derive(Debug)]
pub struct PostgresConnection {
    /// Client used to issue queries and prepare statements.
    pub(crate) client: tokio_postgres::Client,
    /// Handle of the background task spawned by `connect` to drive the
    /// underlying `tokio_postgres` connection; awaited by `disconnect`.
    pub(crate) handle: JoinHandle<()>,
    /// Set to `false` by `connect`; how (or whether) it is read elsewhere is
    /// not visible here — NOTE(review): confirm it is still needed.
    pub(crate) _transaction: bool,
}
29
30impl Executor for PostgresConnection {
31 type Driver = PostgresDriver;
32
33 fn driver(&self) -> &Self::Driver {
34 &PostgresDriver {}
35 }
36
37 async fn prepare(&mut self, sql: String) -> Result<Query<Self::Driver>> {
38 let sql = sql.trim_end().trim_end_matches(';');
39 Ok(
40 PostgresPrepared::new(self.client.prepare(&sql).await.map_err(|e| {
41 let error = Error::new(e).context(format!(
42 "While preparing the query:\n{}",
43 truncate_long!(sql)
44 ));
45 log::error!("{:#}", error);
46 error
47 })?)
48 .into(),
49 )
50 }
51
52 fn run<'s>(
53 &'s mut self,
54 query: impl AsQuery<Self::Driver> + 's,
55 ) -> impl Stream<Item = Result<QueryResult>> + Send {
56 let mut query = query.as_query();
57 let context = Arc::new(format!("While running the query:\n{}", query.as_mut()));
58 let mut owned = mem::take(query.as_mut());
59 match owned {
60 Query::Raw(sql) => {
61 Either::Left(stream_postgres_simple_query_message_to_tank_query_result(
62 async move || self.client.simple_query_raw(&sql).await.map_err(Into::into),
63 ))
64 }
65 Query::Prepared(..) => Either::Right(try_stream! {
66 let mut transaction = self.begin().await?;
67 {
68 let mut stream = pin!(transaction.run(&mut owned));
69 while let Some(value) = stream.next().await.transpose()? {
70 yield value;
71 }
72 }
73 transaction.commit().await?;
74 *query.as_mut() = mem::take(&mut owned);
75 }),
76 }
77 .map_err(move |e: Error| {
78 let error = e.context(context.clone());
79 log::error!("{:#}", error);
80 error
81 })
82 }
83
84 fn fetch<'s>(
85 &'s mut self,
86 query: impl AsQuery<Self::Driver> + 's,
87 ) -> impl Stream<Item = Result<tank_core::RowLabeled>> + Send + 's {
88 let mut query = query.as_query();
89 let context = Arc::new(format!("While fetching the query:\n{}", query.as_mut()));
90 let owned = mem::take(query.as_mut());
91 stream_postgres_row_to_tank_row(async move || {
92 let row_stream = match owned {
93 Query::Raw(mut sql) => {
94 let stream = self
95 .client
96 .query_raw(&sql, Vec::<ValueWrap>::new())
97 .await
98 .map_err(|e| Error::new(e).context(context.clone()))?;
99 *query.as_mut() = Query::Raw(mem::take(&mut sql));
100 stream
101 }
102 Query::Prepared(mut prepared) => {
103 let mut params = mem::take(&mut prepared.params);
104 let types = prepared.statement.params();
105
106 for (i, param) in params.iter_mut().enumerate() {
107 *param = ValueWrap(
108 mem::take(&mut param.0).try_as(&postgres_type_to_value(&types[i]))?,
109 );
110 }
111 let stream = self
112 .client
113 .query_raw(&prepared.statement, params)
114 .await
115 .map_err(|e| Error::new(e).context(context.clone()))?;
116 *query.as_mut() = Query::Prepared(prepared);
117 stream
118 }
119 };
120 Ok(row_stream).map_err(|e| {
121 log::error!("{:#}", e);
122 e
123 })
124 })
125 }
126}
127
128impl Connection for PostgresConnection {
129 #[allow(refining_impl_trait)]
130 async fn connect(url: Cow<'static, str>) -> Result<PostgresConnection> {
131 let context = || format!("While trying to connect to `{}`", truncate_long!(url));
132 let url = decode(&url).with_context(context)?;
133 let prefix = format!("{}://", <Self::Driver as Driver>::NAME);
134 if !url.starts_with(&prefix) {
135 let error = Error::msg(format!(
136 "Postgres connection url must start with `{}`",
137 &prefix
138 ))
139 .context(context());
140 log::error!("{:#}", error);
141 return Err(error);
142 }
143 let mut url = Url::parse(&url).with_context(context)?;
144 let mut take_url_param = |key: &str, env_var: &str, remove: bool| {
145 let value = url
146 .query_pairs()
147 .find_map(|(k, v)| if k == key { Some(v) } else { None })
148 .map(|v| v.to_string());
149 if remove && let Some(..) = value {
150 let mut result = url.clone();
151 result.set_query(None);
152 result
153 .query_pairs_mut()
154 .extend_pairs(url.query_pairs().filter(|(k, _)| k != key));
155 url = result;
156 };
157 value.or_else(|| env::var(env_var).ok().map(Into::into))
158 };
159 let sslmode = take_url_param("sslmode", "PGSSLMODE", false).unwrap_or("disable".into());
160 let (client, handle) = if sslmode == "disable" {
161 let (client, connection) = tokio_postgres::connect(url.as_str(), NoTls).await?;
162 let handle = spawn(async move {
163 if let Err(error) = connection.await
164 && !error.is_closed()
165 {
166 log::error!("Postgres connection error: {:#?}", error);
167 }
168 });
169 (client, handle)
170 } else {
171 let mut builder = SslConnector::builder(SslMethod::tls())?;
172 let path = PathBuf::from_str(
173 take_url_param("sslrootcert", "PGSSLROOTCERT", true)
174 .as_deref()
175 .unwrap_or("~/.postgresql/root.crt"),
176 )
177 .context(context())?;
178 if path.exists() {
179 builder.set_ca_file(path)?;
180 }
181 let path = PathBuf::from_str(
182 take_url_param("sslcert", "PGSSLCERT", true)
183 .as_deref()
184 .unwrap_or("~/.postgresql/postgresql.crt"),
185 )
186 .context(context())?;
187 if path.exists() {
188 builder.set_certificate_chain_file(path)?;
189 }
190 let path = PathBuf::from_str(
191 take_url_param("sslkey", "PGSSLKEY", true)
192 .as_deref()
193 .unwrap_or("~/.postgresql/postgresql.key"),
194 )
195 .context(context())?;
196 if path.exists() {
197 builder.set_private_key_file(path, SslFiletype::PEM)?;
198 }
199 builder.set_verify(SslVerifyMode::PEER);
200 let connector = MakeTlsConnector::new(builder.build());
201 let (client, connection) = tokio_postgres::connect(url.as_str(), connector).await?;
202 let handle = spawn(async move {
203 if let Err(error) = connection.await
204 && !error.is_closed()
205 {
206 log::error!("Postgres connection error: {:#?}", error);
207 }
208 });
209 (client, handle)
210 };
211 Ok(Self {
212 client,
213 handle,
214 _transaction: false,
215 })
216 }
217
218 #[allow(refining_impl_trait)]
219 fn begin(&mut self) -> impl Future<Output = Result<PostgresTransaction<'_>>> {
220 PostgresTransaction::new(self)
221 }
222
223 #[allow(refining_impl_trait)]
224 async fn disconnect(self) -> Result<()> {
225 drop(self.client);
226 if let Err(e) = self.handle.await {
227 let error = Error::new(e).context("While disconnecting from Postgres");
228 log::error!("{:#}", error);
229 return Err(error);
230 }
231 Ok(())
232 }
233}