1use std::borrow::Borrow;
4use std::fmt::Write;
5
6use deadpool_postgres::Metrics;
7use deadpool_postgres::{ClientWrapper, Object};
8
9use futures_util::StreamExt;
10use futures_util::TryStreamExt;
11use futures_util::pin_mut;
12use postgres_types::{BorrowToSql, ToSql, Type};
13use tokio_postgres::Error as PgError;
14use tokio_postgres::error::SqlState;
15
16pub use deadpool::managed::TimeoutType;
17pub use deadpool_postgres::{Config, ConfigError};
18use tokio_postgres::Statement;
19use tokio_postgres::ToStatement;
20use tracing::error;
21
22use crate::Row;
23use crate::filter::Filter;
24use crate::filter::Limit;
25use crate::filter::WhereFilter;
26use crate::row::NamedColumns;
27use crate::row::RowStream;
28use crate::row::ToRowStatic;
29use crate::row::{FromRowOwned, ToRow};
30use crate::try2;
31
32#[derive(Debug, thiserror::Error)]
33#[non_exhaustive]
34pub enum Error {
35 #[error("Unique violation {constraint:?}")]
36 #[non_exhaustive]
37 UniqueViolation { constraint: Option<String> },
38
39 #[error("Expected one row")]
40 ExpectedOneRow,
41
42 #[error("Other Postgres error {0}")]
43 Other(PgError),
44
45 #[error("Deserialization error {0}")]
46 Deserialize(Box<dyn std::error::Error + Send + Sync>),
47
48 #[error("Unknown error {0}")]
49 Unknown(Box<dyn std::error::Error + Send + Sync>),
50}
51
52impl From<PgError> for Error {
53 fn from(e: PgError) -> Self {
54 let Some(db_error) = e.as_db_error() else {
55 return Self::Other(e);
56 };
57
58 match db_error.code() {
59 &SqlState::UNIQUE_VIOLATION => Self::UniqueViolation {
60 constraint: db_error.constraint().map(Into::into),
61 },
62 state => {
63 error!("db error with state {:?}", state);
64 Self::Other(e)
65 }
66 }
67 }
68}
69
70#[derive(Debug)]
71pub struct ConnectionOwned(pub(crate) Object);
72
73impl ConnectionOwned {
74 pub fn connection(&self) -> Connection<'_> {
75 Connection {
76 inner: ConnectionInner::Client(&self.0),
77 }
78 }
79
80 pub async fn transaction<'a>(
81 &'a mut self,
82 ) -> Result<Transaction<'a>, Error> {
83 Ok(Transaction {
84 inner: self.0.transaction().await.map_err(Error::from)?,
85 })
86 }
87
88 pub fn metrics(&self) -> &Metrics {
89 Object::metrics(&self.0)
90 }
91}
92
// Integration with the `chuchi` framework (only compiled with the `chuchi`
// feature). It implements chuchi's `Extractor` for `ConnectionOwned`, backed by
// the `Database` resource. Presumably this lets handlers receive a
// `ConnectionOwned` directly; see chuchi's extractor docs to confirm.
#[cfg(feature = "chuchi")]
mod impl_chuchi {
    use chuchi::{
        extractor::Extractor, extractor_extract, extractor_prepare,
        extractor_validate,
    };

    use crate::{Database, database::DatabaseError};

    use super::*;

    impl<'a, R> Extractor<'a, R> for ConnectionOwned {
        // Failures while obtaining the connection are `DatabaseError`s.
        type Error = DatabaseError;
        // The prepared value is already the final `ConnectionOwned`.
        type Prepared = Self;

        // Panics with "Db resource not found" if no `Database` resource is
        // registered.
        extractor_validate!(|validate| {
            assert!(
                validate.resources.exists::<Database>(),
                "Db resource not found"
            );
        });

        // Obtains the connection via `Database::get`.
        // NOTE(review): the `unwrap` presumably relies on the validate step
        // above having run first; confirm against chuchi's extractor
        // lifecycle.
        extractor_prepare!(|prepare| {
            let db = prepare.resources.get::<Database>().unwrap();
            db.get().await
        });

        // Hands the prepared connection over unchanged.
        extractor_extract!(|extract| { Ok(extract.prepared) });
    }
}
123
124#[derive(Debug)]
125pub struct Transaction<'a> {
126 inner: deadpool_postgres::Transaction<'a>,
127}
128
129impl<'a> Transaction<'a> {
130 pub fn connection(&self) -> Connection<'_> {
132 Connection {
133 inner: ConnectionInner::Transaction(&self.inner),
134 }
135 }
136
137 pub async fn commit(self) -> Result<(), Error> {
139 self.inner.commit().await.map_err(Error::from)
140 }
141
142 pub async fn rollback(self) -> Result<(), Error> {
144 self.inner.rollback().await.map_err(Error::from)
145 }
146}
147
148#[derive(Debug, Clone, Copy)]
149pub struct Connection<'a> {
150 inner: ConnectionInner<'a>,
151}
152
153#[derive(Debug, Clone, Copy)]
154enum ConnectionInner<'a> {
155 Client(&'a ClientWrapper),
156 Transaction(&'a deadpool_postgres::Transaction<'a>),
157}
158
159impl Connection<'_> {
160 pub async fn select<R>(
167 &self,
168 table: &str,
169 filter: impl Borrow<Filter<'_>>,
170 ) -> Result<Vec<R>, Error>
171 where
172 R: FromRowOwned + NamedColumns,
173 {
174 let sql = format!(
175 "SELECT {} FROM \"{}\"{}",
176 R::select_columns(),
177 table,
178 filter.borrow()
179 );
180 let stmt = self.prepare_cached(&sql).await?;
181
182 self.query_raw(&stmt, filter.borrow().params.iter_to_sql())
183 .await?
184 .map(|row| {
185 row.and_then(|row| {
186 R::from_row_owned(row).map_err(Error::Deserialize)
187 })
188 })
189 .try_collect()
190 .await
191 }
192
193 pub async fn select_one<R>(
195 &self,
196 table: &str,
197 filter: impl Borrow<Filter<'_>>,
198 ) -> Result<R, Error>
199 where
200 R: FromRowOwned + NamedColumns,
201 {
202 let mut formatter = filter.borrow().to_formatter();
203
204 if matches!(formatter.limit, Limit::All) {
205 formatter.limit = &Limit::Fixed(1);
206 }
207
208 let sql = format!(
209 "SELECT {} FROM \"{}\"{}",
210 R::select_columns(),
211 table,
212 formatter
213 );
214 let stmt = self.prepare_cached(&sql).await?;
215
216 let row = self
217 .query_raw_opt(&stmt, filter.borrow().params.iter_to_sql())
218 .await
219 .and_then(|opt| opt.ok_or(Error::ExpectedOneRow))?;
220
221 R::from_row_owned(row).map_err(Error::Deserialize)
222 }
223
224 pub async fn select_opt<R>(
226 &self,
227 table: &str,
228 filter: impl Borrow<Filter<'_>>,
229 ) -> Result<Option<R>, Error>
230 where
231 R: FromRowOwned + NamedColumns,
232 {
233 let mut formatter = filter.borrow().to_formatter();
234
235 if matches!(formatter.limit, Limit::All) {
236 formatter.limit = &Limit::Fixed(1);
237 }
238
239 let sql = format!(
240 "SELECT {} FROM \"{}\"{}",
241 R::select_columns(),
242 table,
243 formatter
244 );
245 let stmt = self.prepare_cached(&sql).await?;
246
247 self.query_raw_opt(&stmt, filter.borrow().params.iter_to_sql())
248 .await
249 }
250
251 pub async fn count(
256 &self,
257 table: &str,
258 column: &str,
259 filter: impl Borrow<Filter<'_>>,
260 ) -> Result<u64, Error> {
261 let sql = format!(
262 "SELECT COUNT(\"{column}\") FROM \"{table}\"{}",
263 filter.borrow()
264 );
265 let stmt = self.prepare_cached(&sql).await?;
266
267 let row: Row = self
268 .query_raw_opt(&stmt, filter.borrow().params.iter_to_sql())
269 .await
270 .and_then(|opt| opt.ok_or(Error::ExpectedOneRow))?;
271
272 Ok(row.get::<_, i64>(0) as u64)
273 }
274
275 pub async fn insert<U>(&self, table: &str, item: &U) -> Result<(), Error>
277 where
278 U: ToRow,
279 {
280 let mut sql = format!("INSERT INTO \"{table}\" (");
281 item.insert_columns(&mut sql);
282 sql.push_str(") VALUES (");
283 item.insert_values(&mut sql);
284 sql.push(')');
285
286 let stmt = self.prepare_cached(&sql).await?;
287
288 self.execute_raw(&stmt, item.params()).await.map(|_| ())
289 }
290
291 pub async fn insert_many<U, I>(
293 &self,
294 table: &str,
295 items: I,
296 ) -> Result<(), Error>
297 where
298 U: ToRowStatic,
299 I: IntoIterator,
300 I::Item: Borrow<U>,
301 {
302 let sql = format!(
303 "INSERT INTO \"{}\" ({}) VALUES ({})",
304 table,
305 U::insert_columns(),
306 U::insert_values()
307 );
308 let stmt = self.prepare_cached(&sql).await?;
309
310 for item in items {
311 self.execute_raw(&stmt, item.borrow().params()).await?;
312 }
313
314 Ok(())
315 }
316
317 pub async fn update<U>(
319 &self,
320 table: &str,
321 item: &U,
322 filter: impl Borrow<WhereFilter<'_>>,
323 ) -> Result<(), Error>
324 where
325 U: ToRow,
326 {
327 let filter = filter.borrow();
328 let mut formatter = filter.whr.to_formatter();
329 formatter.param_start = item.params_len();
330
331 let mut sql = format!("UPDATE \"{table}\" SET ");
332 item.update_columns(&mut sql);
333 write!(&mut sql, "{}", formatter).unwrap();
334
335 let stmt = self.prepare_cached(&sql).await?;
336
337 self.execute_raw(
340 &stmt,
341 TwoExactSize(item.params(), filter.params.iter_to_sql()),
342 )
343 .await
344 .map(|_| ())
345 }
346
347 pub async fn delete(
349 &self,
350 table: &str,
351 filter: impl Borrow<WhereFilter<'_>>,
352 ) -> Result<(), Error> {
353 let sql = format!("DELETE FROM \"{}\"{}", table, filter.borrow());
354 let stmt = self.prepare_cached(&sql).await?;
355
356 self.execute_raw(&stmt, filter.borrow().params.iter_to_sql())
357 .await
358 .map(|_| ())
359 }
360
361 pub async fn prepare_cached(
364 &self,
365 query: &str,
366 ) -> Result<Statement, Error> {
367 match &self.inner {
368 ConnectionInner::Client(client) => {
369 client.prepare_cached(query).await.map_err(Error::from)
370 }
371 ConnectionInner::Transaction(tr) => {
372 tr.prepare_cached(query).await.map_err(Error::from)
373 }
374 }
375 }
376
377 pub async fn prepare(&self, query: &str) -> Result<Statement, Error> {
379 match &self.inner {
380 ConnectionInner::Client(client) => {
381 client.prepare(query).await.map_err(Error::from)
382 }
383 ConnectionInner::Transaction(tr) => {
384 tr.prepare(query).await.map_err(Error::from)
385 }
386 }
387 }
388
389 pub async fn prepare_typed_cached(
392 &self,
393 query: &str,
394 types: &[Type],
395 ) -> Result<Statement, Error> {
396 match &self.inner {
397 ConnectionInner::Client(client) => client
398 .prepare_typed_cached(query, types)
399 .await
400 .map_err(Error::from),
401 ConnectionInner::Transaction(tr) => tr
402 .prepare_typed_cached(query, types)
403 .await
404 .map_err(Error::from),
405 }
406 }
407
408 pub async fn prepare_typed(
410 &self,
411 query: &str,
412 parameter_types: &[Type],
413 ) -> Result<Statement, Error> {
414 match &self.inner {
415 ConnectionInner::Client(client) => client
416 .prepare_typed(query, parameter_types)
417 .await
418 .map_err(Error::from),
419 ConnectionInner::Transaction(tr) => tr
420 .prepare_typed(query, parameter_types)
421 .await
422 .map_err(Error::from),
423 }
424 }
425
426 pub async fn query<R, T>(
428 &self,
429 statement: &T,
430 params: &[&(dyn ToSql + Sync)],
431 ) -> Result<Vec<R>, Error>
432 where
433 R: FromRowOwned,
434 T: ?Sized + ToStatement,
435 {
436 self.query_raw(statement, slice_iter(params))
437 .await?
438 .map(|row| {
439 row.and_then(|row| {
440 R::from_row_owned(row).map_err(Error::Deserialize)
441 })
442 })
443 .try_collect()
444 .await
445 }
446
447 pub async fn query_one<R, T>(
449 &self,
450 statement: &T,
451 params: &[&(dyn ToSql + Sync)],
452 ) -> Result<R, Error>
453 where
454 R: FromRowOwned,
455 T: ?Sized + ToStatement,
456 {
457 let row = match &self.inner {
458 ConnectionInner::Client(client) => {
459 client.query_one(statement, params).await?
460 }
461 ConnectionInner::Transaction(tr) => {
462 tr.query_one(statement, params).await?
463 }
464 };
465
466 R::from_row_owned(row.into()).map_err(Error::Deserialize)
467 }
468
469 pub async fn query_opt<R, T>(
471 &self,
472 statement: &T,
473 params: &[&(dyn ToSql + Sync)],
474 ) -> Result<Option<R>, Error>
475 where
476 R: FromRowOwned,
477 T: ?Sized + ToStatement,
478 {
479 let row = match &self.inner {
480 ConnectionInner::Client(client) => {
481 client.query_opt(statement, params).await?
482 }
483 ConnectionInner::Transaction(tr) => {
484 tr.query_opt(statement, params).await?
485 }
486 };
487
488 R::from_row_owned(try2!(row).into())
489 .map(Some)
490 .map_err(Error::Deserialize)
491 }
492
493 pub async fn query_raw_opt<R, T, P, I>(
495 &self,
496 statement: &T,
497 params: I,
498 ) -> Result<Option<R>, Error>
499 where
500 R: FromRowOwned,
501 T: ?Sized + ToStatement,
502 P: BorrowToSql,
503 I: IntoIterator<Item = P>,
504 I::IntoIter: ExactSizeIterator,
505 {
506 let stream = self.query_raw(statement, params).await?;
507 pin_mut!(stream);
508
509 let row = match stream.try_next().await? {
510 Some(row) => row,
511 None => return Ok(None),
512 };
513
514 if stream.try_next().await?.is_some() {
515 return Err(Error::ExpectedOneRow);
516 }
517
518 R::from_row_owned(row).map(Some).map_err(Error::Deserialize)
519 }
520
521 pub async fn query_raw<T, P, I>(
523 &self,
524 statement: &T,
525 params: I,
526 ) -> Result<RowStream, Error>
527 where
528 T: ?Sized + ToStatement,
529 P: BorrowToSql,
530 I: IntoIterator<Item = P>,
531 I::IntoIter: ExactSizeIterator,
532 {
533 let row_stream = match &self.inner {
534 ConnectionInner::Client(client) => {
535 client.query_raw(statement, params).await?
536 }
537 ConnectionInner::Transaction(tr) => {
538 tr.query_raw(statement, params).await?
539 }
540 };
541
542 Ok(row_stream.into())
543 }
544
545 pub async fn execute<T>(
547 &self,
548 statement: &T,
549 params: &[&(dyn ToSql + Sync)],
550 ) -> Result<u64, Error>
551 where
552 T: ?Sized + ToStatement,
553 {
554 match &self.inner {
555 ConnectionInner::Client(client) => {
556 client.execute(statement, params).await.map_err(Error::from)
557 }
558 ConnectionInner::Transaction(tr) => {
559 tr.execute(statement, params).await.map_err(Error::from)
560 }
561 }
562 }
563
564 pub async fn execute_raw<T, P, I>(
566 &self,
567 statement: &T,
568 params: I,
569 ) -> Result<u64, Error>
570 where
571 T: ?Sized + ToStatement,
572 P: BorrowToSql,
573 I: IntoIterator<Item = P>,
574 I::IntoIter: ExactSizeIterator,
575 {
576 match &self.inner {
577 ConnectionInner::Client(client) => client
578 .execute_raw(statement, params)
579 .await
580 .map_err(Error::from),
581 ConnectionInner::Transaction(tr) => {
582 tr.execute_raw(statement, params).await.map_err(Error::from)
583 }
584 }
585 }
586
587 pub async fn batch_execute(&self, query: &str) -> Result<(), Error> {
589 match &self.inner {
590 ConnectionInner::Client(client) => {
591 client.batch_execute(query).await.map_err(Error::from)
592 }
593 ConnectionInner::Transaction(tr) => {
594 tr.batch_execute(query).await.map_err(Error::from)
595 }
596 }
597 }
598}
599
600fn slice_iter<'a>(
601 s: &'a [&'a (dyn ToSql + Sync)],
602) -> impl ExactSizeIterator<Item = &'a dyn ToSql> + 'a {
603 s.iter().map(|s| *s as _)
604}
605
/// Concatenates two [`ExactSizeIterator`]s over the same item type.
///
/// Yields every item of the first iterator and then every item of the second,
/// like [`Iterator::chain`]. Unlike `std::iter::Chain`, it also implements
/// [`ExactSizeIterator`], so the result satisfies the exact-size bound of the
/// `*_raw` query methods.
struct TwoExactSize<I, J>(I, J);

impl<I, J, T> Iterator for TwoExactSize<I, J>
where
    I: ExactSizeIterator<Item = T>,
    J: ExactSizeIterator<Item = T>,
{
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        // Drain the first iterator, then fall through to the second.
        self.0.next().or_else(|| self.1.next())
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let (lo_first, hi_first) = self.0.size_hint();
        let (lo_second, hi_second) = self.1.size_hint();

        // Same policy as `std::iter::Chain`: saturate the lower bound and drop
        // the upper bound on overflow, instead of panicking (debug builds) or
        // wrapping (release builds).
        let lower = lo_first.saturating_add(lo_second);
        let upper = match (hi_first, hi_second) {
            (Some(a), Some(b)) => a.checked_add(b),
            _ => None,
        };

        (lower, upper)
    }
}

impl<I, J, T> ExactSizeIterator for TwoExactSize<I, J>
where
    I: ExactSizeIterator<Item = T>,
    J: ExactSizeIterator<Item = T>,
{
}
630}