// chuchi_postgres/connection.rs

1// use crate::table::{Table, TableTemplate};
2
3use std::borrow::Borrow;
4use std::fmt::Write;
5
6use deadpool_postgres::Metrics;
7use deadpool_postgres::{ClientWrapper, Object};
8
9use futures_util::StreamExt;
10use futures_util::TryStreamExt;
11use futures_util::pin_mut;
12use postgres_types::{BorrowToSql, ToSql, Type};
13use tokio_postgres::Error as PgError;
14use tokio_postgres::error::SqlState;
15
16pub use deadpool::managed::TimeoutType;
17pub use deadpool_postgres::{Config, ConfigError};
18use tokio_postgres::Statement;
19use tokio_postgres::ToStatement;
20use tracing::error;
21
22use crate::Row;
23use crate::filter::Filter;
24use crate::filter::Limit;
25use crate::filter::WhereFilter;
26use crate::row::NamedColumns;
27use crate::row::RowStream;
28use crate::row::ToRowStatic;
29use crate::row::{FromRowOwned, ToRow};
30use crate::try2;
31
32#[derive(Debug, thiserror::Error)]
33#[non_exhaustive]
34pub enum Error {
35	#[error("Unique violation {constraint:?}")]
36	#[non_exhaustive]
37	UniqueViolation { constraint: Option<String> },
38
39	#[error("Expected one row")]
40	ExpectedOneRow,
41
42	#[error("Other Postgres error {0}")]
43	Other(PgError),
44
45	#[error("Deserialization error {0}")]
46	Deserialize(Box<dyn std::error::Error + Send + Sync>),
47
48	#[error("Unknown error {0}")]
49	Unknown(Box<dyn std::error::Error + Send + Sync>),
50}
51
52impl From<PgError> for Error {
53	fn from(e: PgError) -> Self {
54		let Some(db_error) = e.as_db_error() else {
55			return Self::Other(e);
56		};
57
58		match db_error.code() {
59			&SqlState::UNIQUE_VIOLATION => Self::UniqueViolation {
60				constraint: db_error.constraint().map(Into::into),
61			},
62			state => {
63				error!("db error with state {:?}", state);
64				Self::Other(e)
65			}
66		}
67	}
68}
69
70#[derive(Debug)]
71pub struct ConnectionOwned(pub(crate) Object);
72
73impl ConnectionOwned {
74	pub fn connection(&self) -> Connection<'_> {
75		Connection {
76			inner: ConnectionInner::Client(&self.0),
77		}
78	}
79
80	pub async fn transaction<'a>(
81		&'a mut self,
82	) -> Result<Transaction<'a>, Error> {
83		Ok(Transaction {
84			inner: self.0.transaction().await.map_err(Error::from)?,
85		})
86	}
87
88	pub fn metrics(&self) -> &Metrics {
89		Object::metrics(&self.0)
90	}
91}
92
#[cfg(feature = "chuchi")]
mod impl_chuchi {
	use super::*;

	use crate::{Database, database::DatabaseError};

	use chuchi::{
		extractor::Extractor, extractor_extract, extractor_prepare,
		extractor_validate,
	};

	/// Lets a chuchi handler receive a [`ConnectionOwned`] checked out of the
	/// [`Database`] resource.
	impl<'a, R> Extractor<'a, R> for ConnectionOwned {
		type Error = DatabaseError;
		type Prepared = Self;

		// The Database resource must be registered; otherwise this panics.
		extractor_validate!(|validate| {
			assert!(
				validate.resources.exists::<Database>(),
				"Db resource not found"
			);
		});

		// Obtain a connection from the Database resource. The unwrap relies
		// on the validate step above having confirmed the resource exists.
		extractor_prepare!(|prepare| {
			let database = prepare.resources.get::<Database>().unwrap();
			database.get().await
		});

		// The prepared connection is handed over unchanged.
		extractor_extract!(|extract| { Ok(extract.prepared) });
	}
}
123
124#[derive(Debug)]
125pub struct Transaction<'a> {
126	inner: deadpool_postgres::Transaction<'a>,
127}
128
129impl<'a> Transaction<'a> {
130	/// Returns a connection to the database
131	pub fn connection(&self) -> Connection<'_> {
132		Connection {
133			inner: ConnectionInner::Transaction(&self.inner),
134		}
135	}
136
137	/// See [`tokio_postgres::Transaction::commit()`]
138	pub async fn commit(self) -> Result<(), Error> {
139		self.inner.commit().await.map_err(Error::from)
140	}
141
142	/// See [`tokio_postgres::Transaction::rollback()`]
143	pub async fn rollback(self) -> Result<(), Error> {
144		self.inner.rollback().await.map_err(Error::from)
145	}
146}
147
148#[derive(Debug, Clone, Copy)]
149pub struct Connection<'a> {
150	inner: ConnectionInner<'a>,
151}
152
153#[derive(Debug, Clone, Copy)]
154enum ConnectionInner<'a> {
155	Client(&'a ClientWrapper),
156	Transaction(&'a deadpool_postgres::Transaction<'a>),
157}
158
159impl Connection<'_> {
160	// select
161
162	// how about the columns are a separat parameter, which contains
163	// an exact size iterator, and implementors can call
164	// select("table", R::select_columns(), filter)
165	// or select("table", &["column1", "column2"], filter)
166	pub async fn select<R>(
167		&self,
168		table: &str,
169		filter: impl Borrow<Filter<'_>>,
170	) -> Result<Vec<R>, Error>
171	where
172		R: FromRowOwned + NamedColumns,
173	{
174		let sql = format!(
175			"SELECT {} FROM \"{}\"{}",
176			R::select_columns(),
177			table,
178			filter.borrow()
179		);
180		let stmt = self.prepare_cached(&sql).await?;
181
182		self.query_raw(&stmt, filter.borrow().params.iter_to_sql())
183			.await?
184			.map(|row| {
185				row.and_then(|row| {
186					R::from_row_owned(row).map_err(Error::Deserialize)
187				})
188			})
189			.try_collect()
190			.await
191	}
192
193	// select_one
194	pub async fn select_one<R>(
195		&self,
196		table: &str,
197		filter: impl Borrow<Filter<'_>>,
198	) -> Result<R, Error>
199	where
200		R: FromRowOwned + NamedColumns,
201	{
202		let mut formatter = filter.borrow().to_formatter();
203
204		if matches!(formatter.limit, Limit::All) {
205			formatter.limit = &Limit::Fixed(1);
206		}
207
208		let sql = format!(
209			"SELECT {} FROM \"{}\"{}",
210			R::select_columns(),
211			table,
212			formatter
213		);
214		let stmt = self.prepare_cached(&sql).await?;
215
216		let row = self
217			.query_raw_opt(&stmt, filter.borrow().params.iter_to_sql())
218			.await
219			.and_then(|opt| opt.ok_or(Error::ExpectedOneRow))?;
220
221		R::from_row_owned(row).map_err(Error::Deserialize)
222	}
223
224	// select_opt
225	pub async fn select_opt<R>(
226		&self,
227		table: &str,
228		filter: impl Borrow<Filter<'_>>,
229	) -> Result<Option<R>, Error>
230	where
231		R: FromRowOwned + NamedColumns,
232	{
233		let mut formatter = filter.borrow().to_formatter();
234
235		if matches!(formatter.limit, Limit::All) {
236			formatter.limit = &Limit::Fixed(1);
237		}
238
239		let sql = format!(
240			"SELECT {} FROM \"{}\"{}",
241			R::select_columns(),
242			table,
243			formatter
244		);
245		let stmt = self.prepare_cached(&sql).await?;
246
247		self.query_raw_opt(&stmt, filter.borrow().params.iter_to_sql())
248			.await
249	}
250
251	/// count
252	///
253	/// A column is required because you should select a column which has some
254	/// indexes on it, this makes the call a lot cheaper
255	pub async fn count(
256		&self,
257		table: &str,
258		column: &str,
259		filter: impl Borrow<Filter<'_>>,
260	) -> Result<u64, Error> {
261		let sql = format!(
262			"SELECT COUNT(\"{column}\") FROM \"{table}\"{}",
263			filter.borrow()
264		);
265		let stmt = self.prepare_cached(&sql).await?;
266
267		let row: Row = self
268			.query_raw_opt(&stmt, filter.borrow().params.iter_to_sql())
269			.await
270			.and_then(|opt| opt.ok_or(Error::ExpectedOneRow))?;
271
272		Ok(row.get::<_, i64>(0) as u64)
273	}
274
275	// insert one
276	pub async fn insert<U>(&self, table: &str, item: &U) -> Result<(), Error>
277	where
278		U: ToRow,
279	{
280		let mut sql = format!("INSERT INTO \"{table}\" (");
281		item.insert_columns(&mut sql);
282		sql.push_str(") VALUES (");
283		item.insert_values(&mut sql);
284		sql.push(')');
285
286		let stmt = self.prepare_cached(&sql).await?;
287
288		self.execute_raw(&stmt, item.params()).await.map(|_| ())
289	}
290
291	// insert_many
292	pub async fn insert_many<U, I>(
293		&self,
294		table: &str,
295		items: I,
296	) -> Result<(), Error>
297	where
298		U: ToRowStatic,
299		I: IntoIterator,
300		I::Item: Borrow<U>,
301	{
302		let sql = format!(
303			"INSERT INTO \"{}\" ({}) VALUES ({})",
304			table,
305			U::insert_columns(),
306			U::insert_values()
307		);
308		let stmt = self.prepare_cached(&sql).await?;
309
310		for item in items {
311			self.execute_raw(&stmt, item.borrow().params()).await?;
312		}
313
314		Ok(())
315	}
316
317	// update
318	pub async fn update<U>(
319		&self,
320		table: &str,
321		item: &U,
322		filter: impl Borrow<WhereFilter<'_>>,
323	) -> Result<(), Error>
324	where
325		U: ToRow,
326	{
327		let filter = filter.borrow();
328		let mut formatter = filter.whr.to_formatter();
329		formatter.param_start = item.params_len();
330
331		let mut sql = format!("UPDATE \"{table}\" SET ");
332		item.update_columns(&mut sql);
333		write!(&mut sql, "{}", formatter).unwrap();
334
335		let stmt = self.prepare_cached(&sql).await?;
336
337		// we need to merge both params
338
339		self.execute_raw(
340			&stmt,
341			TwoExactSize(item.params(), filter.params.iter_to_sql()),
342		)
343		.await
344		.map(|_| ())
345	}
346
347	// delete
348	pub async fn delete(
349		&self,
350		table: &str,
351		filter: impl Borrow<WhereFilter<'_>>,
352	) -> Result<(), Error> {
353		let sql = format!("DELETE FROM \"{}\"{}", table, filter.borrow());
354		let stmt = self.prepare_cached(&sql).await?;
355
356		self.execute_raw(&stmt, filter.borrow().params.iter_to_sql())
357			.await
358			.map(|_| ())
359	}
360
361	/// Like [`tokio_postgres::Client::prepare_typed()`] but uses a cached
362	/// statement if one exists.
363	pub async fn prepare_cached(
364		&self,
365		query: &str,
366	) -> Result<Statement, Error> {
367		match &self.inner {
368			ConnectionInner::Client(client) => {
369				client.prepare_cached(query).await.map_err(Error::from)
370			}
371			ConnectionInner::Transaction(tr) => {
372				tr.prepare_cached(query).await.map_err(Error::from)
373			}
374		}
375	}
376
377	/// See [`tokio_postgres::Client::prepare()`]
378	pub async fn prepare(&self, query: &str) -> Result<Statement, Error> {
379		match &self.inner {
380			ConnectionInner::Client(client) => {
381				client.prepare(query).await.map_err(Error::from)
382			}
383			ConnectionInner::Transaction(tr) => {
384				tr.prepare(query).await.map_err(Error::from)
385			}
386		}
387	}
388
389	/// Like [`tokio_postgres::Client::prepare_typed()`] but uses a cached
390	/// statement if one exists.
391	pub async fn prepare_typed_cached(
392		&self,
393		query: &str,
394		types: &[Type],
395	) -> Result<Statement, Error> {
396		match &self.inner {
397			ConnectionInner::Client(client) => client
398				.prepare_typed_cached(query, types)
399				.await
400				.map_err(Error::from),
401			ConnectionInner::Transaction(tr) => tr
402				.prepare_typed_cached(query, types)
403				.await
404				.map_err(Error::from),
405		}
406	}
407
408	/// See [`tokio_postgres::Client::prepare_typed()`]
409	pub async fn prepare_typed(
410		&self,
411		query: &str,
412		parameter_types: &[Type],
413	) -> Result<Statement, Error> {
414		match &self.inner {
415			ConnectionInner::Client(client) => client
416				.prepare_typed(query, parameter_types)
417				.await
418				.map_err(Error::from),
419			ConnectionInner::Transaction(tr) => tr
420				.prepare_typed(query, parameter_types)
421				.await
422				.map_err(Error::from),
423		}
424	}
425
426	/// See [`tokio_postgres::Client::query()`]
427	pub async fn query<R, T>(
428		&self,
429		statement: &T,
430		params: &[&(dyn ToSql + Sync)],
431	) -> Result<Vec<R>, Error>
432	where
433		R: FromRowOwned,
434		T: ?Sized + ToStatement,
435	{
436		self.query_raw(statement, slice_iter(params))
437			.await?
438			.map(|row| {
439				row.and_then(|row| {
440					R::from_row_owned(row).map_err(Error::Deserialize)
441				})
442			})
443			.try_collect()
444			.await
445	}
446
447	/// See [`tokio_postgres::Client::query_one()`]
448	pub async fn query_one<R, T>(
449		&self,
450		statement: &T,
451		params: &[&(dyn ToSql + Sync)],
452	) -> Result<R, Error>
453	where
454		R: FromRowOwned,
455		T: ?Sized + ToStatement,
456	{
457		let row = match &self.inner {
458			ConnectionInner::Client(client) => {
459				client.query_one(statement, params).await?
460			}
461			ConnectionInner::Transaction(tr) => {
462				tr.query_one(statement, params).await?
463			}
464		};
465
466		R::from_row_owned(row.into()).map_err(Error::Deserialize)
467	}
468
469	/// See [`tokio_postgres::Client::query_opt()`]
470	pub async fn query_opt<R, T>(
471		&self,
472		statement: &T,
473		params: &[&(dyn ToSql + Sync)],
474	) -> Result<Option<R>, Error>
475	where
476		R: FromRowOwned,
477		T: ?Sized + ToStatement,
478	{
479		let row = match &self.inner {
480			ConnectionInner::Client(client) => {
481				client.query_opt(statement, params).await?
482			}
483			ConnectionInner::Transaction(tr) => {
484				tr.query_opt(statement, params).await?
485			}
486		};
487
488		R::from_row_owned(try2!(row).into())
489			.map(Some)
490			.map_err(Error::Deserialize)
491	}
492
493	/// See [`tokio_postgres::Client::query_opt()`] and [`tokio_postgres::Client::query_raw()`]
494	pub async fn query_raw_opt<R, T, P, I>(
495		&self,
496		statement: &T,
497		params: I,
498	) -> Result<Option<R>, Error>
499	where
500		R: FromRowOwned,
501		T: ?Sized + ToStatement,
502		P: BorrowToSql,
503		I: IntoIterator<Item = P>,
504		I::IntoIter: ExactSizeIterator,
505	{
506		let stream = self.query_raw(statement, params).await?;
507		pin_mut!(stream);
508
509		let row = match stream.try_next().await? {
510			Some(row) => row,
511			None => return Ok(None),
512		};
513
514		if stream.try_next().await?.is_some() {
515			return Err(Error::ExpectedOneRow);
516		}
517
518		R::from_row_owned(row).map(Some).map_err(Error::Deserialize)
519	}
520
521	/// See [`tokio_postgres::Client::query_raw()`]
522	pub async fn query_raw<T, P, I>(
523		&self,
524		statement: &T,
525		params: I,
526	) -> Result<RowStream, Error>
527	where
528		T: ?Sized + ToStatement,
529		P: BorrowToSql,
530		I: IntoIterator<Item = P>,
531		I::IntoIter: ExactSizeIterator,
532	{
533		let row_stream = match &self.inner {
534			ConnectionInner::Client(client) => {
535				client.query_raw(statement, params).await?
536			}
537			ConnectionInner::Transaction(tr) => {
538				tr.query_raw(statement, params).await?
539			}
540		};
541
542		Ok(row_stream.into())
543	}
544
545	/// See [`tokio_postgres::Client::execute()`]
546	pub async fn execute<T>(
547		&self,
548		statement: &T,
549		params: &[&(dyn ToSql + Sync)],
550	) -> Result<u64, Error>
551	where
552		T: ?Sized + ToStatement,
553	{
554		match &self.inner {
555			ConnectionInner::Client(client) => {
556				client.execute(statement, params).await.map_err(Error::from)
557			}
558			ConnectionInner::Transaction(tr) => {
559				tr.execute(statement, params).await.map_err(Error::from)
560			}
561		}
562	}
563
564	/// See [`tokio_postgres::Client::execute_raw()`]
565	pub async fn execute_raw<T, P, I>(
566		&self,
567		statement: &T,
568		params: I,
569	) -> Result<u64, Error>
570	where
571		T: ?Sized + ToStatement,
572		P: BorrowToSql,
573		I: IntoIterator<Item = P>,
574		I::IntoIter: ExactSizeIterator,
575	{
576		match &self.inner {
577			ConnectionInner::Client(client) => client
578				.execute_raw(statement, params)
579				.await
580				.map_err(Error::from),
581			ConnectionInner::Transaction(tr) => {
582				tr.execute_raw(statement, params).await.map_err(Error::from)
583			}
584		}
585	}
586
587	/// See [`tokio_postgres::Client::batch_execute()`]
588	pub async fn batch_execute(&self, query: &str) -> Result<(), Error> {
589		match &self.inner {
590			ConnectionInner::Client(client) => {
591				client.batch_execute(query).await.map_err(Error::from)
592			}
593			ConnectionInner::Transaction(tr) => {
594				tr.batch_execute(query).await.map_err(Error::from)
595			}
596		}
597	}
598}
599
600fn slice_iter<'a>(
601	s: &'a [&'a (dyn ToSql + Sync)],
602) -> impl ExactSizeIterator<Item = &'a dyn ToSql> + 'a {
603	s.iter().map(|s| *s as _)
604}
605
/// Concatenates two [`ExactSizeIterator`]s and stays an
/// [`ExactSizeIterator`].
///
/// [`Iterator::chain`] is not usable here: `Chain` does not implement
/// `ExactSizeIterator` (the combined length could overflow `usize`), while
/// the `*_raw` query methods require one. In this module the halves are
/// statement parameter lists, so overflow is not expected. Even so,
/// `size_hint` saturates/checks the sum like `Chain` does, so it can never
/// panic on overflow in debug builds.
struct TwoExactSize<I, J>(I, J);

impl<I, J, T> Iterator for TwoExactSize<I, J>
where
	I: ExactSizeIterator<Item = T>,
	J: ExactSizeIterator<Item = T>,
{
	type Item = T;

	/// Yields every item of the first iterator, then every item of the
	/// second.
	fn next(&mut self) -> Option<Self::Item> {
		self.0.next().or_else(|| self.1.next())
	}

	fn size_hint(&self) -> (usize, Option<usize>) {
		let (lower_a, upper_a) = self.0.size_hint();
		let (lower_b, upper_b) = self.1.size_hint();

		let lower = lower_a.saturating_add(lower_b);
		// An overflowing upper bound is reported as "unknown" (`None`).
		let upper = match (upper_a, upper_b) {
			(Some(a), Some(b)) => a.checked_add(b),
			_ => None,
		};

		(lower, upper)
	}
}

impl<I, J, T> ExactSizeIterator for TwoExactSize<I, J>
where
	I: ExactSizeIterator<Item = T>,
	J: ExactSizeIterator<Item = T>,
{
}