chuchi_postgres/
connection.rs

1// use crate::table::{Table, TableTemplate};
2
3use std::borrow::Borrow;
4use std::fmt::Write;
5
6use deadpool_postgres::Metrics;
7use deadpool_postgres::{ClientWrapper, Object};
8
9use futures_util::StreamExt;
10use futures_util::TryStreamExt;
11use futures_util::pin_mut;
12use postgres_types::{BorrowToSql, ToSql, Type};
13use tokio_postgres::Error as PgError;
14use tokio_postgres::error::SqlState;
15
16pub use deadpool::managed::TimeoutType;
17pub use deadpool_postgres::{Config, ConfigError};
18use tokio_postgres::Statement;
19use tokio_postgres::ToStatement;
20use tracing::error;
21
22use crate::Row;
23use crate::filter::Filter;
24use crate::filter::Limit;
25use crate::filter::WhereFilter;
26use crate::row::NamedColumns;
27use crate::row::RowStream;
28use crate::row::ToRowStatic;
29use crate::row::{FromRowOwned, ToRow};
30use crate::try2;
31
/// Errors returned by the connection helpers in this module.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum Error {
	/// A Postgres error whose SQLSTATE code is `UNIQUE_VIOLATION`.
	#[error("Unique violation {0}")]
	UniqueViolation(PgError),

	/// A query that has to yield one row yielded none, or a query that may
	/// yield at most one row yielded more than one.
	#[error("Expected one row")]
	ExpectedOneRow,

	/// Any Postgres error that is not a unique violation, including errors
	/// which carry no SQLSTATE code at all.
	#[error("Other Postgres error {0}")]
	Other(PgError),

	/// A returned row could not be converted into the requested type.
	#[error("Deserialization error {0}")]
	Deserialize(Box<dyn std::error::Error + Send + Sync>),

	/// An error that fits none of the other variants.
	// NOTE(review): not constructed anywhere in this file, presumably used
	// by other modules of the crate.
	#[error("Unknown error {0}")]
	Unknown(Box<dyn std::error::Error + Send + Sync>),
}
50
51impl From<PgError> for Error {
52	fn from(e: PgError) -> Self {
53		let Some(state) = e.code() else {
54			return Self::Other(e);
55		};
56
57		match state {
58			&SqlState::UNIQUE_VIOLATION => Self::UniqueViolation(e),
59			state => {
60				error!("db error with state {:?}", state);
61				Self::Other(e)
62			}
63		}
64	}
65}
66
67#[derive(Debug)]
68pub struct ConnectionOwned(pub(crate) Object);
69
70impl ConnectionOwned {
71	pub fn connection(&self) -> Connection {
72		Connection {
73			inner: ConnectionInner::Client(&self.0),
74		}
75	}
76
77	pub async fn transaction<'a>(
78		&'a mut self,
79	) -> Result<Transaction<'a>, Error> {
80		Ok(Transaction {
81			inner: self.0.transaction().await.map_err(Error::from)?,
82		})
83	}
84
85	pub fn metrics(&self) -> &Metrics {
86		Object::metrics(&self.0)
87	}
88}
89
/// Integration with the `chuchi` crate, only compiled when the `chuchi`
/// feature is enabled.
#[cfg(feature = "chuchi")]
mod impl_chuchi {
	use chuchi::{
		extractor::Extractor, extractor_extract, extractor_prepare,
		extractor_validate,
	};

	use crate::{Database, database::DatabaseError};

	use super::*;

	// Makes `ConnectionOwned` usable as a chuchi extractor which is backed
	// by the `Database` resource.
	impl<'a, R> Extractor<'a, R> for ConnectionOwned {
		type Error = DatabaseError;
		type Prepared = Self;

		// Panics with "Db resource not found" if no `Database` resource is
		// registered.
		extractor_validate!(|validate| {
			assert!(
				validate.resources.exists::<Database>(),
				"Db resource not found"
			);
		});

		// Calls `get()` on the `Database` resource; the `unwrap` relies on
		// the resource existing, which the validate step above asserts.
		extractor_prepare!(|prepare| {
			let db = prepare.resources.get::<Database>().unwrap();
			db.get().await
		});

		// Hands out the value produced by the prepare step unchanged.
		extractor_extract!(|extract| { Ok(extract.prepared) });
	}
}
120
121#[derive(Debug)]
122pub struct Transaction<'a> {
123	inner: deadpool_postgres::Transaction<'a>,
124}
125
126impl<'a> Transaction<'a> {
127	/// Returns a connection to the database
128	pub fn connection(&self) -> Connection<'_> {
129		Connection {
130			inner: ConnectionInner::Transaction(&self.inner),
131		}
132	}
133
134	/// See [`tokio_postgres::Transaction::commit()`]
135	pub async fn commit(self) -> Result<(), Error> {
136		self.inner.commit().await.map_err(Error::from)
137	}
138
139	/// See [`tokio_postgres::Transaction::rollback()`]
140	pub async fn rollback(self) -> Result<(), Error> {
141		self.inner.rollback().await.map_err(Error::from)
142	}
143}
144
/// A borrowed handle used to run statements, either directly on a client or
/// inside a transaction.
///
/// It only holds a reference and is therefore `Copy`.
#[derive(Debug, Clone, Copy)]
pub struct Connection<'a> {
	// decides where the statements get executed
	inner: ConnectionInner<'a>,
}
149
/// The two targets a [`Connection`] can forward its calls to.
#[derive(Debug, Clone, Copy)]
enum ConnectionInner<'a> {
	/// Statements run directly on the pooled client.
	Client(&'a ClientWrapper),
	/// Statements run inside the referenced transaction.
	Transaction(&'a deadpool_postgres::Transaction<'a>),
}
155
156impl Connection<'_> {
157	// select
158
159	// how about the columns are a separat parameter, which contains
160	// an exact size iterator, and implementors can call
161	// select("table", R::select_columns(), filter)
162	// or select("table", &["column1", "column2"], filter)
163	pub async fn select<R>(
164		&self,
165		table: &str,
166		filter: impl Borrow<Filter<'_>>,
167	) -> Result<Vec<R>, Error>
168	where
169		R: FromRowOwned + NamedColumns,
170	{
171		let sql = format!(
172			"SELECT {} FROM \"{}\"{}",
173			R::select_columns(),
174			table,
175			filter.borrow()
176		);
177		let stmt = self.prepare_cached(&sql).await?;
178
179		self.query_raw(&stmt, filter.borrow().params.iter_to_sql())
180			.await?
181			.map(|row| {
182				row.and_then(|row| {
183					R::from_row_owned(row).map_err(Error::Deserialize)
184				})
185			})
186			.try_collect()
187			.await
188	}
189
190	// select_one
191	pub async fn select_one<R>(
192		&self,
193		table: &str,
194		filter: impl Borrow<Filter<'_>>,
195	) -> Result<R, Error>
196	where
197		R: FromRowOwned + NamedColumns,
198	{
199		let mut formatter = filter.borrow().to_formatter();
200
201		if matches!(formatter.limit, Limit::All) {
202			formatter.limit = &Limit::Fixed(1);
203		}
204
205		let sql = format!(
206			"SELECT {} FROM \"{}\"{}",
207			R::select_columns(),
208			table,
209			formatter
210		);
211		let stmt = self.prepare_cached(&sql).await?;
212
213		let row = self
214			.query_raw_opt(&stmt, filter.borrow().params.iter_to_sql())
215			.await
216			.and_then(|opt| opt.ok_or(Error::ExpectedOneRow))?;
217
218		R::from_row_owned(row).map_err(Error::Deserialize)
219	}
220
221	// select_opt
222	pub async fn select_opt<R>(
223		&self,
224		table: &str,
225		filter: impl Borrow<Filter<'_>>,
226	) -> Result<Option<R>, Error>
227	where
228		R: FromRowOwned + NamedColumns,
229	{
230		let mut formatter = filter.borrow().to_formatter();
231
232		if matches!(formatter.limit, Limit::All) {
233			formatter.limit = &Limit::Fixed(1);
234		}
235
236		let sql = format!(
237			"SELECT {} FROM \"{}\"{}",
238			R::select_columns(),
239			table,
240			formatter
241		);
242		let stmt = self.prepare_cached(&sql).await?;
243
244		self.query_raw_opt(&stmt, filter.borrow().params.iter_to_sql())
245			.await
246	}
247
248	/// count
249	///
250	/// A column is required because you should select a column which has some
251	/// indexes on it, this makes the call a lot cheaper
252	pub async fn count(
253		&self,
254		table: &str,
255		column: &str,
256		filter: impl Borrow<Filter<'_>>,
257	) -> Result<u64, Error> {
258		let sql = format!(
259			"SELECT COUNT(\"{column}\") FROM \"{table}\"{}",
260			filter.borrow()
261		);
262		let stmt = self.prepare_cached(&sql).await?;
263
264		let row: Row = self
265			.query_raw_opt(&stmt, filter.borrow().params.iter_to_sql())
266			.await
267			.and_then(|opt| opt.ok_or(Error::ExpectedOneRow))?;
268
269		Ok(row.get::<_, i64>(0) as u64)
270	}
271
272	// insert one
273	pub async fn insert<U>(&self, table: &str, item: &U) -> Result<(), Error>
274	where
275		U: ToRow,
276	{
277		let mut sql = format!("INSERT INTO \"{table}\" (");
278		item.insert_columns(&mut sql);
279		sql.push_str(") VALUES (");
280		item.insert_values(&mut sql);
281		sql.push(')');
282
283		let stmt = self.prepare_cached(&sql).await?;
284
285		self.execute_raw(&stmt, item.params()).await.map(|_| ())
286	}
287
288	// insert_many
289	pub async fn insert_many<U, I>(
290		&self,
291		table: &str,
292		items: I,
293	) -> Result<(), Error>
294	where
295		U: ToRowStatic,
296		I: IntoIterator,
297		I::Item: Borrow<U>,
298	{
299		let sql = format!(
300			"INSERT INTO \"{}\" ({}) VALUES ({})",
301			table,
302			U::insert_columns(),
303			U::insert_values()
304		);
305		let stmt = self.prepare_cached(&sql).await?;
306
307		for item in items {
308			self.execute_raw(&stmt, item.borrow().params()).await?;
309		}
310
311		Ok(())
312	}
313
314	// update
315	pub async fn update<U>(
316		&self,
317		table: &str,
318		item: &U,
319		filter: impl Borrow<WhereFilter<'_>>,
320	) -> Result<(), Error>
321	where
322		U: ToRow,
323	{
324		let filter = filter.borrow();
325		let mut formatter = filter.whr.to_formatter();
326		formatter.param_start = item.params_len();
327
328		let mut sql = format!("UPDATE \"{table}\" SET ");
329		item.update_columns(&mut sql);
330		write!(&mut sql, "{}", formatter).unwrap();
331
332		let stmt = self.prepare_cached(&sql).await?;
333
334		// we need to merge both params
335
336		self.execute_raw(
337			&stmt,
338			TwoExactSize(item.params(), filter.params.iter_to_sql()),
339		)
340		.await
341		.map(|_| ())
342	}
343
344	// delete
345	pub async fn delete(
346		&self,
347		table: &str,
348		filter: impl Borrow<WhereFilter<'_>>,
349	) -> Result<(), Error> {
350		let sql = format!("DELETE FROM \"{}\"{}", table, filter.borrow());
351		let stmt = self.prepare_cached(&sql).await?;
352
353		self.execute_raw(&stmt, filter.borrow().params.iter_to_sql())
354			.await
355			.map(|_| ())
356	}
357
358	/// Like [`tokio_postgres::Client::prepare_typed()`] but uses a cached
359	/// statement if one exists.
360	pub async fn prepare_cached(
361		&self,
362		query: &str,
363	) -> Result<Statement, Error> {
364		match &self.inner {
365			ConnectionInner::Client(client) => {
366				client.prepare_cached(query).await.map_err(Error::from)
367			}
368			ConnectionInner::Transaction(tr) => {
369				tr.prepare_cached(query).await.map_err(Error::from)
370			}
371		}
372	}
373
374	/// See [`tokio_postgres::Client::prepare()`]
375	pub async fn prepare(&self, query: &str) -> Result<Statement, Error> {
376		match &self.inner {
377			ConnectionInner::Client(client) => {
378				client.prepare(query).await.map_err(Error::from)
379			}
380			ConnectionInner::Transaction(tr) => {
381				tr.prepare(query).await.map_err(Error::from)
382			}
383		}
384	}
385
386	/// Like [`tokio_postgres::Client::prepare_typed()`] but uses a cached
387	/// statement if one exists.
388	pub async fn prepare_typed_cached(
389		&self,
390		query: &str,
391		types: &[Type],
392	) -> Result<Statement, Error> {
393		match &self.inner {
394			ConnectionInner::Client(client) => client
395				.prepare_typed_cached(query, types)
396				.await
397				.map_err(Error::from),
398			ConnectionInner::Transaction(tr) => tr
399				.prepare_typed_cached(query, types)
400				.await
401				.map_err(Error::from),
402		}
403	}
404
405	/// See [`tokio_postgres::Client::prepare_typed()`]
406	pub async fn prepare_typed(
407		&self,
408		query: &str,
409		parameter_types: &[Type],
410	) -> Result<Statement, Error> {
411		match &self.inner {
412			ConnectionInner::Client(client) => client
413				.prepare_typed(query, parameter_types)
414				.await
415				.map_err(Error::from),
416			ConnectionInner::Transaction(tr) => tr
417				.prepare_typed(query, parameter_types)
418				.await
419				.map_err(Error::from),
420		}
421	}
422
423	/// See [`tokio_postgres::Client::query()`]
424	pub async fn query<R, T>(
425		&self,
426		statement: &T,
427		params: &[&(dyn ToSql + Sync)],
428	) -> Result<Vec<R>, Error>
429	where
430		R: FromRowOwned,
431		T: ?Sized + ToStatement,
432	{
433		self.query_raw(statement, slice_iter(params))
434			.await?
435			.map(|row| {
436				row.and_then(|row| {
437					R::from_row_owned(row).map_err(Error::Deserialize)
438				})
439			})
440			.try_collect()
441			.await
442	}
443
444	/// See [`tokio_postgres::Client::query_one()`]
445	pub async fn query_one<R, T>(
446		&self,
447		statement: &T,
448		params: &[&(dyn ToSql + Sync)],
449	) -> Result<R, Error>
450	where
451		R: FromRowOwned,
452		T: ?Sized + ToStatement,
453	{
454		let row = match &self.inner {
455			ConnectionInner::Client(client) => {
456				client.query_one(statement, params).await?
457			}
458			ConnectionInner::Transaction(tr) => {
459				tr.query_one(statement, params).await?
460			}
461		};
462
463		R::from_row_owned(row.into()).map_err(Error::Deserialize)
464	}
465
466	/// See [`tokio_postgres::Client::query_opt()`]
467	pub async fn query_opt<R, T>(
468		&self,
469		statement: &T,
470		params: &[&(dyn ToSql + Sync)],
471	) -> Result<Option<R>, Error>
472	where
473		R: FromRowOwned,
474		T: ?Sized + ToStatement,
475	{
476		let row = match &self.inner {
477			ConnectionInner::Client(client) => {
478				client.query_opt(statement, params).await?
479			}
480			ConnectionInner::Transaction(tr) => {
481				tr.query_opt(statement, params).await?
482			}
483		};
484
485		R::from_row_owned(try2!(row).into())
486			.map(Some)
487			.map_err(Error::Deserialize)
488	}
489
490	/// See [`tokio_postgres::Client::query_opt()`] and [`tokio_postgres::Client::query_raw()`]
491	pub async fn query_raw_opt<R, T, P, I>(
492		&self,
493		statement: &T,
494		params: I,
495	) -> Result<Option<R>, Error>
496	where
497		R: FromRowOwned,
498		T: ?Sized + ToStatement,
499		P: BorrowToSql,
500		I: IntoIterator<Item = P>,
501		I::IntoIter: ExactSizeIterator,
502	{
503		let stream = self.query_raw(statement, params).await?;
504		pin_mut!(stream);
505
506		let row = match stream.try_next().await? {
507			Some(row) => row,
508			None => return Ok(None),
509		};
510
511		if stream.try_next().await?.is_some() {
512			return Err(Error::ExpectedOneRow);
513		}
514
515		R::from_row_owned(row).map(Some).map_err(Error::Deserialize)
516	}
517
518	/// See [`tokio_postgres::Client::query_raw()`]
519	pub async fn query_raw<T, P, I>(
520		&self,
521		statement: &T,
522		params: I,
523	) -> Result<RowStream, Error>
524	where
525		T: ?Sized + ToStatement,
526		P: BorrowToSql,
527		I: IntoIterator<Item = P>,
528		I::IntoIter: ExactSizeIterator,
529	{
530		let row_stream = match &self.inner {
531			ConnectionInner::Client(client) => {
532				client.query_raw(statement, params).await?
533			}
534			ConnectionInner::Transaction(tr) => {
535				tr.query_raw(statement, params).await?
536			}
537		};
538
539		Ok(row_stream.into())
540	}
541
542	/// See [`tokio_postgres::Client::execute()`]
543	pub async fn execute<T>(
544		&self,
545		statement: &T,
546		params: &[&(dyn ToSql + Sync)],
547	) -> Result<u64, Error>
548	where
549		T: ?Sized + ToStatement,
550	{
551		match &self.inner {
552			ConnectionInner::Client(client) => {
553				client.execute(statement, params).await.map_err(Error::from)
554			}
555			ConnectionInner::Transaction(tr) => {
556				tr.execute(statement, params).await.map_err(Error::from)
557			}
558		}
559	}
560
561	/// See [`tokio_postgres::Client::execute_raw()`]
562	pub async fn execute_raw<T, P, I>(
563		&self,
564		statement: &T,
565		params: I,
566	) -> Result<u64, Error>
567	where
568		T: ?Sized + ToStatement,
569		P: BorrowToSql,
570		I: IntoIterator<Item = P>,
571		I::IntoIter: ExactSizeIterator,
572	{
573		match &self.inner {
574			ConnectionInner::Client(client) => client
575				.execute_raw(statement, params)
576				.await
577				.map_err(Error::from),
578			ConnectionInner::Transaction(tr) => {
579				tr.execute_raw(statement, params).await.map_err(Error::from)
580			}
581		}
582	}
583
584	/// See [`tokio_postgres::Client::batch_execute()`]
585	pub async fn batch_execute(&self, query: &str) -> Result<(), Error> {
586		match &self.inner {
587			ConnectionInner::Client(client) => {
588				client.batch_execute(query).await.map_err(Error::from)
589			}
590			ConnectionInner::Transaction(tr) => {
591				tr.batch_execute(query).await.map_err(Error::from)
592			}
593		}
594	}
595}
596
597fn slice_iter<'a>(
598	s: &'a [&'a (dyn ToSql + Sync)],
599) -> impl ExactSizeIterator<Item = &'a dyn ToSql> + 'a {
600	s.iter().map(|s| *s as _)
601}
602
/// Iterator which yields all items of the first iterator followed by all
/// items of the second one, while still being an [`ExactSizeIterator`].
///
/// `std::iter::Chain` does not implement `ExactSizeIterator`, which is why
/// this small adapter exists.
struct TwoExactSize<I, J>(I, J);

impl<I, J, T> Iterator for TwoExactSize<I, J>
where
	I: ExactSizeIterator<Item = T>,
	J: ExactSizeIterator<Item = T>,
{
	type Item = T;

	/// Yields from the first iterator until it returns `None`, then from
	/// the second one.
	fn next(&mut self) -> Option<Self::Item> {
		self.0.next().or_else(|| self.1.next())
	}

	/// Returns the sum of both size hints.
	///
	/// If the sum does not fit into a `usize` the lower bound saturates at
	/// `usize::MAX` and the upper bound is `None` (the same thing
	/// `std::iter::Chain` does) instead of overflowing.
	fn size_hint(&self) -> (usize, Option<usize>) {
		let (first_low, first_high) = self.0.size_hint();
		let (second_low, second_high) = self.1.size_hint();

		let low = first_low.saturating_add(second_low);
		let high = match (first_high, second_high) {
			(Some(a), Some(b)) => a.checked_add(b),
			_ => None,
		};

		(low, high)
	}
}

// `len()` uses the default implementation which is based on `size_hint()`,
// so it is only exact while the combined length fits into a `usize`.
impl<I, J, T> ExactSizeIterator for TwoExactSize<I, J>
where
	I: ExactSizeIterator<Item = T>,
	J: ExactSizeIterator<Item = T>,
{
}