1use crate::{
2 PostgresDriver, PostgresPrepared, PostgresTransaction, ValueWrap,
3 util::{
4 postgres_type_to_value, stream_postgres_row_to_tank_row,
5 stream_postgres_simple_query_message_to_tank_query_result, value_to_postgres_type,
6 },
7};
8use async_stream::try_stream;
9use openssl::ssl::{SslConnector, SslFiletype, SslMethod, SslVerifyMode};
10use postgres_openssl::MakeTlsConnector;
11use postgres_types::ToSql;
12use std::{
13 borrow::Cow,
14 env, mem,
15 path::PathBuf,
16 pin::{Pin, pin},
17 str::FromStr,
18};
19use tank_core::{
20 AsQuery, Connection, Driver, DynQuery, Entity, Error, ErrorContext, Executor, Query,
21 QueryResult, RawQuery, Result, RowsAffected, SqlWriter, Transaction,
22 future::Either,
23 stream::{Stream, StreamExt, TryStreamExt},
24 truncate_long,
25};
26use tokio::{spawn, task::JoinHandle};
27use tokio_postgres::{NoTls, binary_copy::BinaryCopyInWriter};
28
/// A live connection to a Postgres server.
///
/// Built by [`Connection::connect`] in this module. `disconnect` drops `client` and then
/// awaits `handle`.
#[derive(Debug)]
pub struct PostgresConnection {
    /// Client used to prepare statements, run simple/extended queries and `COPY ... IN`.
    pub(crate) client: tokio_postgres::Client,
    /// Handle of the spawned task that drives the underlying `tokio_postgres` connection.
    pub(crate) handle: JoinHandle<()>,
    /// Always `false` for connections created by `connect`.
    // NOTE(review): presumably flags a connection that is owned by a transaction — confirm.
    pub(crate) _transaction: bool,
}
36
37impl Executor for PostgresConnection {
38 type Driver = PostgresDriver;
39
40 async fn do_prepare(&mut self, sql: String) -> Result<Query<PostgresDriver>> {
41 let sql = sql.as_str().trim_end().trim_end_matches(';');
42 Ok(
43 PostgresPrepared::new(self.client.prepare(&sql).await.map_err(|e| {
44 let error = Error::new(e).context(format!(
45 "While preparing the query:\n{}",
46 truncate_long!(sql)
47 ));
48 log::error!("{:#}", error);
49 error
50 })?)
51 .into(),
52 )
53 }
54
55 fn run<'s>(
56 &'s mut self,
57 query: impl AsQuery<PostgresDriver> + 's,
58 ) -> impl Stream<Item = Result<QueryResult>> + Send {
59 let mut query = query.as_query();
60 let context = format!("While running the query:\n{}", query.as_mut());
61 let mut owned = mem::take(query.as_mut());
62 match owned {
63 Query::Raw(raw) => Either::Left(try_stream! {
64 let sql = &raw.0;
65 {
66 let stream = stream_postgres_simple_query_message_to_tank_query_result(
67 async move || self.client.simple_query_raw(sql).await.map_err(Into::into),
68 );
69 let mut stream = pin!(stream);
70 while let Some(value) = stream.next().await.transpose()? {
71 yield value;
72 }
73 }
74 *query.as_mut() = Query::Raw(raw);
75 }),
76 Query::Prepared(..) => Either::Right(try_stream! {
77 let mut transaction = self.begin().await?;
78 {
79 let mut stream = pin!(transaction.run(&mut owned));
80 while let Some(value) = stream.next().await.transpose()? {
81 yield value;
82 }
83 }
84 transaction.commit().await?;
85 *query.as_mut() = mem::take(&mut owned);
86 }),
87 }
88 .map_err(move |e: Error| {
89 let error = e.context(context.clone());
90 log::error!("{:#}", error);
91 error
92 })
93 }
94
95 fn fetch<'s>(
96 &'s mut self,
97 query: impl AsQuery<PostgresDriver> + 's,
98 ) -> impl Stream<Item = Result<tank_core::RowLabeled>> + Send {
99 let mut query = query.as_query();
100 let context = format!("While fetching the query:\n{}", query.as_mut());
101 let owned = mem::take(query.as_mut());
102 stream_postgres_row_to_tank_row(async move || {
103 let row_stream = match owned {
104 Query::Raw(RawQuery(sql)) => {
105 let stream = self
106 .client
107 .query_raw(&sql, Vec::<ValueWrap>::new())
108 .await
109 .map_err(|e| Error::new(e).context(context.clone()))?;
110 *query.as_mut() = Query::raw(sql);
111 stream
112 }
113 Query::Prepared(mut prepared) => {
114 let mut params = prepared.take_params();
115 let types = prepared.statement.params();
116
117 for (i, param) in params.iter_mut().enumerate() {
118 *param = ValueWrap(Cow::Owned(
119 mem::take(param)
120 .0
121 .into_owned()
122 .try_as(&postgres_type_to_value(&types[i]))?,
123 ));
124 }
125 let stream = self
126 .client
127 .query_raw(&prepared.statement, params)
128 .await
129 .map_err(|e| Error::new(e).context(context.clone()))?;
130 *query.as_mut() = Query::Prepared(prepared);
131 stream
132 }
133 };
134 Ok(row_stream).map_err(|e| {
135 log::error!("{:#}", e);
136 e
137 })
138 })
139 }
140
141 async fn append<'a, E, It>(&mut self, entities: It) -> Result<RowsAffected>
142 where
143 E: Entity + 'a,
144 It: IntoIterator<Item = &'a E> + Send,
145 <It as IntoIterator>::IntoIter: Send,
146 {
147 let writer = self.driver().sql_writer();
148 let context = || {
149 format!(
150 "While appending to the table `{}`",
151 E::table().full_name(writer.separator())
152 )
153 };
154 let mut query = DynQuery::default();
155 writer.write_copy::<E>(&mut query);
156 let sink = match self
157 .client
158 .copy_in(&query.as_str() as &str)
159 .await
160 .with_context(context)
161 {
162 Ok(v) => v,
163 Err(e) => {
164 log::error!("{e:#}");
165 return Err(e);
166 }
167 };
168 let types: Vec<_> = E::columns()
169 .into_iter()
170 .map(|c| value_to_postgres_type(&c.value))
171 .collect();
172 let writer = BinaryCopyInWriter::new(sink, &types);
173 let mut writer = pin!(writer);
174 let columns_len = E::columns().len();
175 let mut values = Vec::<ValueWrap>::with_capacity(columns_len);
176 let mut refs = Vec::<&(dyn ToSql + Sync)>::with_capacity(columns_len);
177 for entity in entities.into_iter() {
178 values.extend(
179 entity
180 .row_full()
181 .into_iter()
182 .map(|v| ValueWrap(Cow::Owned(v))),
183 );
184 refs.extend(
185 values
186 .iter()
187 .map(|v| unsafe { &*(v as &(dyn ToSql + Sync) as *const _) }),
188 );
189 match Pin::as_mut(&mut writer)
190 .write(&refs)
191 .await
192 .with_context(context)
193 {
194 Ok(_) => {}
195 Err(e) => {
196 log::error!("{e:#}");
197 return Err(e);
198 }
199 };
200 refs.clear();
201 values.clear();
202 }
203 match writer.finish().await.with_context(context) {
204 Ok(v) => Ok(RowsAffected {
205 rows_affected: Some(v),
206 last_affected_id: None,
207 }),
208 Err(e) => {
209 log::error!("{e:#}");
210 return Err(e);
211 }
212 }
213 }
214}
215
216impl Connection for PostgresConnection {
217 async fn connect(url: Cow<'static, str>) -> Result<Self> {
218 let context = format!("While trying to connect to `{}`", truncate_long!(url));
219 let mut url = Self::sanitize_url(url)?;
220 let mut take_url_param = |key: &str, env_var: &str, remove: bool| {
221 let value = url
222 .query_pairs()
223 .find_map(|(k, v)| if k == key { Some(v) } else { None })
224 .map(|v| v.to_string());
225 if remove && let Some(..) = value {
226 let mut result = url.clone();
227 result.set_query(None);
228 result
229 .query_pairs_mut()
230 .extend_pairs(url.query_pairs().filter(|(k, _)| k != key));
231 url = result;
232 };
233 value.or_else(|| env::var(env_var).ok().map(Into::into))
234 };
235 let sslmode = take_url_param("sslmode", "PGSSLMODE", false).unwrap_or("disable".into());
236 let (client, handle) = if sslmode == "disable" {
237 let (client, connection) = tokio_postgres::connect(url.as_str(), NoTls).await?;
238 let handle = spawn(async move {
239 if let Err(error) = connection.await
240 && !error.is_closed()
241 {
242 log::error!("Postgres connection error: {:#?}", error);
243 }
244 });
245 (client, handle)
246 } else {
247 let mut builder = SslConnector::builder(SslMethod::tls())?;
248 let path = PathBuf::from_str(
249 take_url_param("sslrootcert", "PGSSLROOTCERT", true)
250 .as_deref()
251 .unwrap_or("~/.postgresql/root.crt"),
252 )
253 .with_context(|| context.clone())?;
254 if path.exists() {
255 builder.set_ca_file(path)?;
256 }
257 let path = PathBuf::from_str(
258 take_url_param("sslcert", "PGSSLCERT", true)
259 .as_deref()
260 .unwrap_or("~/.postgresql/postgresql.crt"),
261 )
262 .with_context(|| context.clone())?;
263 if path.exists() {
264 builder.set_certificate_chain_file(path)?;
265 }
266 let path = PathBuf::from_str(
267 take_url_param("sslkey", "PGSSLKEY", true)
268 .as_deref()
269 .unwrap_or("~/.postgresql/postgresql.key"),
270 )
271 .with_context(|| context.clone())?;
272 if path.exists() {
273 builder.set_private_key_file(path, SslFiletype::PEM)?;
274 }
275 builder.set_verify(SslVerifyMode::PEER);
276 let connector = MakeTlsConnector::new(builder.build());
277 let (client, connection) = tokio_postgres::connect(url.as_str(), connector).await?;
278 let handle = spawn(async move {
279 if let Err(error) = connection.await
280 && !error.is_closed()
281 {
282 log::error!("Postgres connection error: {:#?}", error);
283 }
284 });
285 (client, handle)
286 };
287 Ok(Self {
288 client,
289 handle,
290 _transaction: false,
291 })
292 }
293
294 fn begin(&mut self) -> impl Future<Output = Result<PostgresTransaction<'_>>> {
295 PostgresTransaction::new(self)
296 }
297
298 async fn disconnect(self) -> Result<()> {
299 drop(self.client);
300 if let Err(e) = self.handle.await {
301 let error = Error::new(e).context("While disconnecting from Postgres");
302 log::error!("{:#}", error);
303 return Err(error);
304 }
305 Ok(())
306 }
307}