chuchi_postgres/
connection.rs

1// use crate::table::{Table, TableTemplate};
2
3use std::borrow::Borrow;
4use std::fmt::Write;
5
6use deadpool_postgres::Metrics;
7use deadpool_postgres::{ClientWrapper, Object};
8
9use futures_util::StreamExt;
10use futures_util::TryStreamExt;
11use futures_util::pin_mut;
12use postgres_types::{BorrowToSql, ToSql, Type};
13use tokio_postgres::Error as PgError;
14use tokio_postgres::error::SqlState;
15
16pub use deadpool::managed::TimeoutType;
17pub use deadpool_postgres::{Config, ConfigError};
18use tokio_postgres::Statement;
19use tokio_postgres::ToStatement;
20use tracing::error;
21
22use crate::Row;
23use crate::filter::Filter;
24use crate::filter::Limit;
25use crate::filter::WhereFilter;
26use crate::row::NamedColumns;
27use crate::row::RowStream;
28use crate::row::ToRowStatic;
29use crate::row::{FromRowOwned, ToRow};
30use crate::try2;
31
32#[derive(Debug, thiserror::Error)]
33#[non_exhaustive]
34pub enum Error {
35	#[error("Unique violation {constraint:?}")]
36	#[non_exhaustive]
37	UniqueViolation { constraint: Option<String> },
38
39	#[error("Expected one row")]
40	ExpectedOneRow,
41
42	#[error("Other Postgres error {0}")]
43	Other(PgError),
44
45	#[error("Deserialization error {0}")]
46	Deserialize(Box<dyn std::error::Error + Send + Sync>),
47
48	#[error("Unknown error {0}")]
49	Unknown(Box<dyn std::error::Error + Send + Sync>),
50}
51
52impl Error {
53	pub fn unique_violation(constraint: Option<String>) -> Self {
54		Self::UniqueViolation { constraint }
55	}
56}
57
58impl From<PgError> for Error {
59	fn from(e: PgError) -> Self {
60		let Some(db_error) = e.as_db_error() else {
61			return Self::Other(e);
62		};
63
64		match db_error.code() {
65			&SqlState::UNIQUE_VIOLATION => Self::UniqueViolation {
66				constraint: db_error.constraint().map(Into::into),
67			},
68			state => {
69				error!("db error with state {:?}", state);
70				Self::Other(e)
71			}
72		}
73	}
74}
75
76#[derive(Debug)]
77pub struct ConnectionOwned(pub(crate) Object);
78
79impl ConnectionOwned {
80	pub fn connection(&self) -> Connection<'_> {
81		Connection {
82			inner: ConnectionInner::Client(&self.0),
83		}
84	}
85
86	pub async fn transaction<'a>(
87		&'a mut self,
88	) -> Result<Transaction<'a>, Error> {
89		Ok(Transaction {
90			inner: self.0.transaction().await.map_err(Error::from)?,
91		})
92	}
93
94	pub fn metrics(&self) -> &Metrics {
95		Object::metrics(&self.0)
96	}
97}
98
// Implements chuchi's `Extractor` trait for `ConnectionOwned`, presumably so
// request handlers can receive a pooled connection as an argument. Only
// compiled when the `chuchi` feature is enabled.
#[cfg(feature = "chuchi")]
mod impl_chuchi {
	use chuchi::{
		extractor::Extractor, extractor_extract, extractor_prepare,
		extractor_validate,
	};

	use crate::{Database, database::DatabaseError};

	use super::*;

	impl<'a, R> Extractor<'a, R> for ConnectionOwned {
		// Failures while acquiring a connection are reported as `DatabaseError`
		// (presumably `Database::get` returns `Result<ConnectionOwned,
		// DatabaseError>` — confirm).
		type Error = DatabaseError;
		// The connection acquired in `prepare` is passed through unchanged.
		type Prepared = Self;

		// Panics with "Db resource not found" if no `Database` resource was
		// registered.
		extractor_validate!(|validate| {
			assert!(
				validate.resources.exists::<Database>(),
				"Db resource not found"
			);
		});

		// Acquires a connection from the `Database` resource. The `unwrap`
		// relies on the existence check in the validate step above; this
		// assumes chuchi runs validation before preparation — confirm.
		extractor_prepare!(|prepare| {
			let db = prepare.resources.get::<Database>().unwrap();
			db.get().await
		});

		// Hands the prepared connection to the handler.
		extractor_extract!(|extract| { Ok(extract.prepared) });
	}
}
129
130#[derive(Debug)]
131pub struct Transaction<'a> {
132	inner: deadpool_postgres::Transaction<'a>,
133}
134
135impl<'a> Transaction<'a> {
136	/// Returns a connection to the database
137	pub fn connection(&self) -> Connection<'_> {
138		Connection {
139			inner: ConnectionInner::Transaction(&self.inner),
140		}
141	}
142
143	/// See [`tokio_postgres::Transaction::commit()`]
144	pub async fn commit(self) -> Result<(), Error> {
145		self.inner.commit().await.map_err(Error::from)
146	}
147
148	/// See [`tokio_postgres::Transaction::rollback()`]
149	pub async fn rollback(self) -> Result<(), Error> {
150		self.inner.rollback().await.map_err(Error::from)
151	}
152}
153
154#[derive(Debug, Clone, Copy)]
155pub struct Connection<'a> {
156	inner: ConnectionInner<'a>,
157}
158
159#[derive(Debug, Clone, Copy)]
160enum ConnectionInner<'a> {
161	Client(&'a ClientWrapper),
162	Transaction(&'a deadpool_postgres::Transaction<'a>),
163}
164
165impl Connection<'_> {
166	// select
167
168	// how about the columns are a separat parameter, which contains
169	// an exact size iterator, and implementors can call
170	// select("table", R::select_columns(), filter)
171	// or select("table", &["column1", "column2"], filter)
172	pub async fn select<R>(
173		&self,
174		table: &str,
175		filter: impl Borrow<Filter<'_>>,
176	) -> Result<Vec<R>, Error>
177	where
178		R: FromRowOwned + NamedColumns,
179	{
180		let sql = format!(
181			"SELECT {} FROM \"{}\"{}",
182			R::select_columns(),
183			table,
184			filter.borrow()
185		);
186		let stmt = self.prepare_cached(&sql).await?;
187
188		self.query_raw(&stmt, filter.borrow().params.iter_to_sql())
189			.await?
190			.map(|row| {
191				row.and_then(|row| {
192					R::from_row_owned(row).map_err(Error::Deserialize)
193				})
194			})
195			.try_collect()
196			.await
197	}
198
199	// select_one
200	pub async fn select_one<R>(
201		&self,
202		table: &str,
203		filter: impl Borrow<Filter<'_>>,
204	) -> Result<R, Error>
205	where
206		R: FromRowOwned + NamedColumns,
207	{
208		let mut formatter = filter.borrow().to_formatter();
209
210		if matches!(formatter.limit, Limit::All) {
211			formatter.limit = &Limit::Fixed(1);
212		}
213
214		let sql = format!(
215			"SELECT {} FROM \"{}\"{}",
216			R::select_columns(),
217			table,
218			formatter
219		);
220		let stmt = self.prepare_cached(&sql).await?;
221
222		let row = self
223			.query_raw_opt(&stmt, filter.borrow().params.iter_to_sql())
224			.await
225			.and_then(|opt| opt.ok_or(Error::ExpectedOneRow))?;
226
227		R::from_row_owned(row).map_err(Error::Deserialize)
228	}
229
230	// select_opt
231	pub async fn select_opt<R>(
232		&self,
233		table: &str,
234		filter: impl Borrow<Filter<'_>>,
235	) -> Result<Option<R>, Error>
236	where
237		R: FromRowOwned + NamedColumns,
238	{
239		let mut formatter = filter.borrow().to_formatter();
240
241		if matches!(formatter.limit, Limit::All) {
242			formatter.limit = &Limit::Fixed(1);
243		}
244
245		let sql = format!(
246			"SELECT {} FROM \"{}\"{}",
247			R::select_columns(),
248			table,
249			formatter
250		);
251		let stmt = self.prepare_cached(&sql).await?;
252
253		self.query_raw_opt(&stmt, filter.borrow().params.iter_to_sql())
254			.await
255	}
256
257	/// count
258	///
259	/// A column is required because you should select a column which has some
260	/// indexes on it, this makes the call a lot cheaper
261	pub async fn count(
262		&self,
263		table: &str,
264		column: &str,
265		filter: impl Borrow<Filter<'_>>,
266	) -> Result<u64, Error> {
267		let sql = format!(
268			"SELECT COUNT(\"{column}\") FROM \"{table}\"{}",
269			filter.borrow()
270		);
271		let stmt = self.prepare_cached(&sql).await?;
272
273		let row: Row = self
274			.query_raw_opt(&stmt, filter.borrow().params.iter_to_sql())
275			.await
276			.and_then(|opt| opt.ok_or(Error::ExpectedOneRow))?;
277
278		Ok(row.get::<_, i64>(0) as u64)
279	}
280
281	// insert one
282	pub async fn insert<U>(&self, table: &str, item: &U) -> Result<(), Error>
283	where
284		U: ToRow,
285	{
286		let mut sql = format!("INSERT INTO \"{table}\" (");
287		item.insert_columns(&mut sql);
288		sql.push_str(") VALUES (");
289		item.insert_values(&mut sql);
290		sql.push(')');
291
292		let stmt = self.prepare_cached(&sql).await?;
293
294		self.execute_raw(&stmt, item.params()).await.map(|_| ())
295	}
296
297	// insert_many
298	pub async fn insert_many<U, I>(
299		&self,
300		table: &str,
301		items: I,
302	) -> Result<(), Error>
303	where
304		U: ToRowStatic,
305		I: IntoIterator,
306		I::Item: Borrow<U>,
307	{
308		let sql = format!(
309			"INSERT INTO \"{}\" ({}) VALUES ({})",
310			table,
311			U::insert_columns(),
312			U::insert_values()
313		);
314		let stmt = self.prepare_cached(&sql).await?;
315
316		for item in items {
317			self.execute_raw(&stmt, item.borrow().params()).await?;
318		}
319
320		Ok(())
321	}
322
323	// update
324	pub async fn update<U>(
325		&self,
326		table: &str,
327		item: &U,
328		filter: impl Borrow<WhereFilter<'_>>,
329	) -> Result<(), Error>
330	where
331		U: ToRow,
332	{
333		let filter = filter.borrow();
334		let mut formatter = filter.whr.to_formatter();
335		formatter.param_start = item.params_len();
336
337		let mut sql = format!("UPDATE \"{table}\" SET ");
338		item.update_columns(&mut sql);
339		write!(&mut sql, "{}", formatter).unwrap();
340
341		let stmt = self.prepare_cached(&sql).await?;
342
343		// we need to merge both params
344
345		self.execute_raw(
346			&stmt,
347			TwoExactSize(item.params(), filter.params.iter_to_sql()),
348		)
349		.await
350		.map(|_| ())
351	}
352
353	// delete
354	pub async fn delete(
355		&self,
356		table: &str,
357		filter: impl Borrow<WhereFilter<'_>>,
358	) -> Result<(), Error> {
359		let sql = format!("DELETE FROM \"{}\"{}", table, filter.borrow());
360		let stmt = self.prepare_cached(&sql).await?;
361
362		self.execute_raw(&stmt, filter.borrow().params.iter_to_sql())
363			.await
364			.map(|_| ())
365	}
366
367	/// Like [`tokio_postgres::Client::prepare_typed()`] but uses a cached
368	/// statement if one exists.
369	pub async fn prepare_cached(
370		&self,
371		query: &str,
372	) -> Result<Statement, Error> {
373		match &self.inner {
374			ConnectionInner::Client(client) => {
375				client.prepare_cached(query).await.map_err(Error::from)
376			}
377			ConnectionInner::Transaction(tr) => {
378				tr.prepare_cached(query).await.map_err(Error::from)
379			}
380		}
381	}
382
383	/// See [`tokio_postgres::Client::prepare()`]
384	pub async fn prepare(&self, query: &str) -> Result<Statement, Error> {
385		match &self.inner {
386			ConnectionInner::Client(client) => {
387				client.prepare(query).await.map_err(Error::from)
388			}
389			ConnectionInner::Transaction(tr) => {
390				tr.prepare(query).await.map_err(Error::from)
391			}
392		}
393	}
394
395	/// Like [`tokio_postgres::Client::prepare_typed()`] but uses a cached
396	/// statement if one exists.
397	pub async fn prepare_typed_cached(
398		&self,
399		query: &str,
400		types: &[Type],
401	) -> Result<Statement, Error> {
402		match &self.inner {
403			ConnectionInner::Client(client) => client
404				.prepare_typed_cached(query, types)
405				.await
406				.map_err(Error::from),
407			ConnectionInner::Transaction(tr) => tr
408				.prepare_typed_cached(query, types)
409				.await
410				.map_err(Error::from),
411		}
412	}
413
414	/// See [`tokio_postgres::Client::prepare_typed()`]
415	pub async fn prepare_typed(
416		&self,
417		query: &str,
418		parameter_types: &[Type],
419	) -> Result<Statement, Error> {
420		match &self.inner {
421			ConnectionInner::Client(client) => client
422				.prepare_typed(query, parameter_types)
423				.await
424				.map_err(Error::from),
425			ConnectionInner::Transaction(tr) => tr
426				.prepare_typed(query, parameter_types)
427				.await
428				.map_err(Error::from),
429		}
430	}
431
432	/// See [`tokio_postgres::Client::query()`]
433	pub async fn query<R, T>(
434		&self,
435		statement: &T,
436		params: &[&(dyn ToSql + Sync)],
437	) -> Result<Vec<R>, Error>
438	where
439		R: FromRowOwned,
440		T: ?Sized + ToStatement,
441	{
442		self.query_raw(statement, slice_iter(params))
443			.await?
444			.map(|row| {
445				row.and_then(|row| {
446					R::from_row_owned(row).map_err(Error::Deserialize)
447				})
448			})
449			.try_collect()
450			.await
451	}
452
453	/// See [`tokio_postgres::Client::query_one()`]
454	pub async fn query_one<R, T>(
455		&self,
456		statement: &T,
457		params: &[&(dyn ToSql + Sync)],
458	) -> Result<R, Error>
459	where
460		R: FromRowOwned,
461		T: ?Sized + ToStatement,
462	{
463		let row = match &self.inner {
464			ConnectionInner::Client(client) => {
465				client.query_one(statement, params).await?
466			}
467			ConnectionInner::Transaction(tr) => {
468				tr.query_one(statement, params).await?
469			}
470		};
471
472		R::from_row_owned(row.into()).map_err(Error::Deserialize)
473	}
474
475	/// See [`tokio_postgres::Client::query_opt()`]
476	pub async fn query_opt<R, T>(
477		&self,
478		statement: &T,
479		params: &[&(dyn ToSql + Sync)],
480	) -> Result<Option<R>, Error>
481	where
482		R: FromRowOwned,
483		T: ?Sized + ToStatement,
484	{
485		let row = match &self.inner {
486			ConnectionInner::Client(client) => {
487				client.query_opt(statement, params).await?
488			}
489			ConnectionInner::Transaction(tr) => {
490				tr.query_opt(statement, params).await?
491			}
492		};
493
494		R::from_row_owned(try2!(row).into())
495			.map(Some)
496			.map_err(Error::Deserialize)
497	}
498
499	/// See [`tokio_postgres::Client::query_opt()`] and [`tokio_postgres::Client::query_raw()`]
500	pub async fn query_raw_opt<R, T, P, I>(
501		&self,
502		statement: &T,
503		params: I,
504	) -> Result<Option<R>, Error>
505	where
506		R: FromRowOwned,
507		T: ?Sized + ToStatement,
508		P: BorrowToSql,
509		I: IntoIterator<Item = P>,
510		I::IntoIter: ExactSizeIterator,
511	{
512		let stream = self.query_raw(statement, params).await?;
513		pin_mut!(stream);
514
515		let row = match stream.try_next().await? {
516			Some(row) => row,
517			None => return Ok(None),
518		};
519
520		if stream.try_next().await?.is_some() {
521			return Err(Error::ExpectedOneRow);
522		}
523
524		R::from_row_owned(row).map(Some).map_err(Error::Deserialize)
525	}
526
527	/// See [`tokio_postgres::Client::query_raw()`]
528	pub async fn query_raw<T, P, I>(
529		&self,
530		statement: &T,
531		params: I,
532	) -> Result<RowStream, Error>
533	where
534		T: ?Sized + ToStatement,
535		P: BorrowToSql,
536		I: IntoIterator<Item = P>,
537		I::IntoIter: ExactSizeIterator,
538	{
539		let row_stream = match &self.inner {
540			ConnectionInner::Client(client) => {
541				client.query_raw(statement, params).await?
542			}
543			ConnectionInner::Transaction(tr) => {
544				tr.query_raw(statement, params).await?
545			}
546		};
547
548		Ok(row_stream.into())
549	}
550
551	/// See [`tokio_postgres::Client::execute()`]
552	pub async fn execute<T>(
553		&self,
554		statement: &T,
555		params: &[&(dyn ToSql + Sync)],
556	) -> Result<u64, Error>
557	where
558		T: ?Sized + ToStatement,
559	{
560		match &self.inner {
561			ConnectionInner::Client(client) => {
562				client.execute(statement, params).await.map_err(Error::from)
563			}
564			ConnectionInner::Transaction(tr) => {
565				tr.execute(statement, params).await.map_err(Error::from)
566			}
567		}
568	}
569
570	/// See [`tokio_postgres::Client::execute_raw()`]
571	pub async fn execute_raw<T, P, I>(
572		&self,
573		statement: &T,
574		params: I,
575	) -> Result<u64, Error>
576	where
577		T: ?Sized + ToStatement,
578		P: BorrowToSql,
579		I: IntoIterator<Item = P>,
580		I::IntoIter: ExactSizeIterator,
581	{
582		match &self.inner {
583			ConnectionInner::Client(client) => client
584				.execute_raw(statement, params)
585				.await
586				.map_err(Error::from),
587			ConnectionInner::Transaction(tr) => {
588				tr.execute_raw(statement, params).await.map_err(Error::from)
589			}
590		}
591	}
592
593	/// See [`tokio_postgres::Client::batch_execute()`]
594	pub async fn batch_execute(&self, query: &str) -> Result<(), Error> {
595		match &self.inner {
596			ConnectionInner::Client(client) => {
597				client.batch_execute(query).await.map_err(Error::from)
598			}
599			ConnectionInner::Transaction(tr) => {
600				tr.batch_execute(query).await.map_err(Error::from)
601			}
602		}
603	}
604}
605
606fn slice_iter<'a>(
607	s: &'a [&'a (dyn ToSql + Sync)],
608) -> impl ExactSizeIterator<Item = &'a dyn ToSql> + 'a {
609	s.iter().map(|s| *s as _)
610}
611
/// Concatenates two exact-size iterators: yields every item of the first,
/// then every item of the second.
///
/// [`std::iter::Chain`] does not implement [`ExactSizeIterator`], but the
/// `*_raw` methods of [`Connection`] require one for their parameters. This
/// type lets an item's values and a filter's values be bound back to back.
struct TwoExactSize<I, J>(I, J);

impl<I, J, T> Iterator for TwoExactSize<I, J>
where
	I: ExactSizeIterator<Item = T>,
	J: ExactSizeIterator<Item = T>,
{
	type Item = T;

	fn next(&mut self) -> Option<T> {
		match self.0.next() {
			Some(item) => Some(item),
			// NOTE(review): once exhausted, the first iterator is still polled
			// on every call; this assumes it keeps returning `None`.
			None => self.1.next(),
		}
	}

	fn size_hint(&self) -> (usize, Option<usize>) {
		let (first_low, first_high) = self.0.size_hint();
		let (second_low, second_high) = self.1.size_hint();
		let low = first_low + second_low;
		let high = match (first_high, second_high) {
			(Some(a), Some(b)) => Some(a + b),
			_ => None,
		};
		(low, high)
	}
}

// Both halves report exact sizes, so the summed `size_hint` is exact too.
impl<I, J, T> ExactSizeIterator for TwoExactSize<I, J>
where
	I: ExactSizeIterator<Item = T>,
	J: ExactSizeIterator<Item = T>,
{
}