1use std::borrow::Borrow;
4use std::fmt::Write;
5
6use deadpool_postgres::Metrics;
7use deadpool_postgres::{ClientWrapper, Object};
8
9use futures_util::StreamExt;
10use futures_util::TryStreamExt;
11use futures_util::pin_mut;
12use postgres_types::{BorrowToSql, ToSql, Type};
13use tokio_postgres::Error as PgError;
14use tokio_postgres::error::SqlState;
15
16pub use deadpool::managed::TimeoutType;
17pub use deadpool_postgres::{Config, ConfigError};
18use tokio_postgres::Statement;
19use tokio_postgres::ToStatement;
20use tracing::error;
21
22use crate::Row;
23use crate::filter::Filter;
24use crate::filter::Limit;
25use crate::filter::WhereFilter;
26use crate::row::NamedColumns;
27use crate::row::RowStream;
28use crate::row::ToRowStatic;
29use crate::row::{FromRowOwned, ToRow};
30use crate::try2;
31
32#[derive(Debug, thiserror::Error)]
33#[non_exhaustive]
34pub enum Error {
35 #[error("Unique violation {constraint:?}")]
36 #[non_exhaustive]
37 UniqueViolation { constraint: Option<String> },
38
39 #[error("Expected one row")]
40 ExpectedOneRow,
41
42 #[error("Other Postgres error {0}")]
43 Other(PgError),
44
45 #[error("Deserialization error {0}")]
46 Deserialize(Box<dyn std::error::Error + Send + Sync>),
47
48 #[error("Unknown error {0}")]
49 Unknown(Box<dyn std::error::Error + Send + Sync>),
50}
51
52impl Error {
53 pub fn unique_violation(constraint: Option<String>) -> Self {
54 Self::UniqueViolation { constraint }
55 }
56}
57
58impl From<PgError> for Error {
59 fn from(e: PgError) -> Self {
60 let Some(db_error) = e.as_db_error() else {
61 return Self::Other(e);
62 };
63
64 match db_error.code() {
65 &SqlState::UNIQUE_VIOLATION => Self::UniqueViolation {
66 constraint: db_error.constraint().map(Into::into),
67 },
68 state => {
69 error!("db error with state {:?}", state);
70 Self::Other(e)
71 }
72 }
73 }
74}
75
76#[derive(Debug)]
77pub struct ConnectionOwned(pub(crate) Object);
78
79impl ConnectionOwned {
80 pub fn connection(&self) -> Connection<'_> {
81 Connection {
82 inner: ConnectionInner::Client(&self.0),
83 }
84 }
85
86 pub async fn transaction<'a>(
87 &'a mut self,
88 ) -> Result<Transaction<'a>, Error> {
89 Ok(Transaction {
90 inner: self.0.transaction().await.map_err(Error::from)?,
91 })
92 }
93
94 pub fn metrics(&self) -> &Metrics {
95 Object::metrics(&self.0)
96 }
97}
98
#[cfg(feature = "chuchi")]
mod impl_chuchi {
    //! Allows a [`ConnectionOwned`] to be requested as a `chuchi` extractor.
    //! The connection is obtained from the registered [`Database`] resource.

    use chuchi::{
        extractor::Extractor, extractor_extract, extractor_prepare,
        extractor_validate,
    };

    use super::*;
    use crate::{Database, database::DatabaseError};

    impl<'a, R> Extractor<'a, R> for ConnectionOwned {
        type Error = DatabaseError;
        type Prepared = Self;

        // Panics if no `Database` resource is registered. The prepare step
        // below relies on this check when it unwraps the lookup.
        extractor_validate!(|validate| {
            let has_database = validate.resources.exists::<Database>();
            assert!(has_database, "Db resource not found");
        });

        // Checks a connection out of the `Database` resource.
        extractor_prepare!(|prepare| {
            let database = prepare.resources.get::<Database>().unwrap();
            database.get().await
        });

        // The prepared connection is handed to the handler unchanged.
        extractor_extract!(|extract| { Ok(extract.prepared) });
    }
}
129
130#[derive(Debug)]
131pub struct Transaction<'a> {
132 inner: deadpool_postgres::Transaction<'a>,
133}
134
135impl<'a> Transaction<'a> {
136 pub fn connection(&self) -> Connection<'_> {
138 Connection {
139 inner: ConnectionInner::Transaction(&self.inner),
140 }
141 }
142
143 pub async fn commit(self) -> Result<(), Error> {
145 self.inner.commit().await.map_err(Error::from)
146 }
147
148 pub async fn rollback(self) -> Result<(), Error> {
150 self.inner.rollback().await.map_err(Error::from)
151 }
152}
153
154#[derive(Debug, Clone, Copy)]
155pub struct Connection<'a> {
156 inner: ConnectionInner<'a>,
157}
158
159#[derive(Debug, Clone, Copy)]
160enum ConnectionInner<'a> {
161 Client(&'a ClientWrapper),
162 Transaction(&'a deadpool_postgres::Transaction<'a>),
163}
164
165impl Connection<'_> {
166 pub async fn select<R>(
173 &self,
174 table: &str,
175 filter: impl Borrow<Filter<'_>>,
176 ) -> Result<Vec<R>, Error>
177 where
178 R: FromRowOwned + NamedColumns,
179 {
180 let sql = format!(
181 "SELECT {} FROM \"{}\"{}",
182 R::select_columns(),
183 table,
184 filter.borrow()
185 );
186 let stmt = self.prepare_cached(&sql).await?;
187
188 self.query_raw(&stmt, filter.borrow().params.iter_to_sql())
189 .await?
190 .map(|row| {
191 row.and_then(|row| {
192 R::from_row_owned(row).map_err(Error::Deserialize)
193 })
194 })
195 .try_collect()
196 .await
197 }
198
199 pub async fn select_one<R>(
201 &self,
202 table: &str,
203 filter: impl Borrow<Filter<'_>>,
204 ) -> Result<R, Error>
205 where
206 R: FromRowOwned + NamedColumns,
207 {
208 let mut formatter = filter.borrow().to_formatter();
209
210 if matches!(formatter.limit, Limit::All) {
211 formatter.limit = &Limit::Fixed(1);
212 }
213
214 let sql = format!(
215 "SELECT {} FROM \"{}\"{}",
216 R::select_columns(),
217 table,
218 formatter
219 );
220 let stmt = self.prepare_cached(&sql).await?;
221
222 let row = self
223 .query_raw_opt(&stmt, filter.borrow().params.iter_to_sql())
224 .await
225 .and_then(|opt| opt.ok_or(Error::ExpectedOneRow))?;
226
227 R::from_row_owned(row).map_err(Error::Deserialize)
228 }
229
230 pub async fn select_opt<R>(
232 &self,
233 table: &str,
234 filter: impl Borrow<Filter<'_>>,
235 ) -> Result<Option<R>, Error>
236 where
237 R: FromRowOwned + NamedColumns,
238 {
239 let mut formatter = filter.borrow().to_formatter();
240
241 if matches!(formatter.limit, Limit::All) {
242 formatter.limit = &Limit::Fixed(1);
243 }
244
245 let sql = format!(
246 "SELECT {} FROM \"{}\"{}",
247 R::select_columns(),
248 table,
249 formatter
250 );
251 let stmt = self.prepare_cached(&sql).await?;
252
253 self.query_raw_opt(&stmt, filter.borrow().params.iter_to_sql())
254 .await
255 }
256
257 pub async fn count(
262 &self,
263 table: &str,
264 column: &str,
265 filter: impl Borrow<Filter<'_>>,
266 ) -> Result<u64, Error> {
267 let sql = format!(
268 "SELECT COUNT(\"{column}\") FROM \"{table}\"{}",
269 filter.borrow()
270 );
271 let stmt = self.prepare_cached(&sql).await?;
272
273 let row: Row = self
274 .query_raw_opt(&stmt, filter.borrow().params.iter_to_sql())
275 .await
276 .and_then(|opt| opt.ok_or(Error::ExpectedOneRow))?;
277
278 Ok(row.get::<_, i64>(0) as u64)
279 }
280
281 pub async fn insert<U>(&self, table: &str, item: &U) -> Result<(), Error>
283 where
284 U: ToRow,
285 {
286 let mut sql = format!("INSERT INTO \"{table}\" (");
287 item.insert_columns(&mut sql);
288 sql.push_str(") VALUES (");
289 item.insert_values(&mut sql);
290 sql.push(')');
291
292 let stmt = self.prepare_cached(&sql).await?;
293
294 self.execute_raw(&stmt, item.params()).await.map(|_| ())
295 }
296
297 pub async fn insert_many<U, I>(
299 &self,
300 table: &str,
301 items: I,
302 ) -> Result<(), Error>
303 where
304 U: ToRowStatic,
305 I: IntoIterator,
306 I::Item: Borrow<U>,
307 {
308 let sql = format!(
309 "INSERT INTO \"{}\" ({}) VALUES ({})",
310 table,
311 U::insert_columns(),
312 U::insert_values()
313 );
314 let stmt = self.prepare_cached(&sql).await?;
315
316 for item in items {
317 self.execute_raw(&stmt, item.borrow().params()).await?;
318 }
319
320 Ok(())
321 }
322
323 pub async fn update<U>(
325 &self,
326 table: &str,
327 item: &U,
328 filter: impl Borrow<WhereFilter<'_>>,
329 ) -> Result<(), Error>
330 where
331 U: ToRow,
332 {
333 let filter = filter.borrow();
334 let mut formatter = filter.whr.to_formatter();
335 formatter.param_start = item.params_len();
336
337 let mut sql = format!("UPDATE \"{table}\" SET ");
338 item.update_columns(&mut sql);
339 write!(&mut sql, "{}", formatter).unwrap();
340
341 let stmt = self.prepare_cached(&sql).await?;
342
343 self.execute_raw(
346 &stmt,
347 TwoExactSize(item.params(), filter.params.iter_to_sql()),
348 )
349 .await
350 .map(|_| ())
351 }
352
353 pub async fn delete(
355 &self,
356 table: &str,
357 filter: impl Borrow<WhereFilter<'_>>,
358 ) -> Result<(), Error> {
359 let sql = format!("DELETE FROM \"{}\"{}", table, filter.borrow());
360 let stmt = self.prepare_cached(&sql).await?;
361
362 self.execute_raw(&stmt, filter.borrow().params.iter_to_sql())
363 .await
364 .map(|_| ())
365 }
366
367 pub async fn prepare_cached(
370 &self,
371 query: &str,
372 ) -> Result<Statement, Error> {
373 match &self.inner {
374 ConnectionInner::Client(client) => {
375 client.prepare_cached(query).await.map_err(Error::from)
376 }
377 ConnectionInner::Transaction(tr) => {
378 tr.prepare_cached(query).await.map_err(Error::from)
379 }
380 }
381 }
382
383 pub async fn prepare(&self, query: &str) -> Result<Statement, Error> {
385 match &self.inner {
386 ConnectionInner::Client(client) => {
387 client.prepare(query).await.map_err(Error::from)
388 }
389 ConnectionInner::Transaction(tr) => {
390 tr.prepare(query).await.map_err(Error::from)
391 }
392 }
393 }
394
395 pub async fn prepare_typed_cached(
398 &self,
399 query: &str,
400 types: &[Type],
401 ) -> Result<Statement, Error> {
402 match &self.inner {
403 ConnectionInner::Client(client) => client
404 .prepare_typed_cached(query, types)
405 .await
406 .map_err(Error::from),
407 ConnectionInner::Transaction(tr) => tr
408 .prepare_typed_cached(query, types)
409 .await
410 .map_err(Error::from),
411 }
412 }
413
414 pub async fn prepare_typed(
416 &self,
417 query: &str,
418 parameter_types: &[Type],
419 ) -> Result<Statement, Error> {
420 match &self.inner {
421 ConnectionInner::Client(client) => client
422 .prepare_typed(query, parameter_types)
423 .await
424 .map_err(Error::from),
425 ConnectionInner::Transaction(tr) => tr
426 .prepare_typed(query, parameter_types)
427 .await
428 .map_err(Error::from),
429 }
430 }
431
432 pub async fn query<R, T>(
434 &self,
435 statement: &T,
436 params: &[&(dyn ToSql + Sync)],
437 ) -> Result<Vec<R>, Error>
438 where
439 R: FromRowOwned,
440 T: ?Sized + ToStatement,
441 {
442 self.query_raw(statement, slice_iter(params))
443 .await?
444 .map(|row| {
445 row.and_then(|row| {
446 R::from_row_owned(row).map_err(Error::Deserialize)
447 })
448 })
449 .try_collect()
450 .await
451 }
452
453 pub async fn query_one<R, T>(
455 &self,
456 statement: &T,
457 params: &[&(dyn ToSql + Sync)],
458 ) -> Result<R, Error>
459 where
460 R: FromRowOwned,
461 T: ?Sized + ToStatement,
462 {
463 let row = match &self.inner {
464 ConnectionInner::Client(client) => {
465 client.query_one(statement, params).await?
466 }
467 ConnectionInner::Transaction(tr) => {
468 tr.query_one(statement, params).await?
469 }
470 };
471
472 R::from_row_owned(row.into()).map_err(Error::Deserialize)
473 }
474
475 pub async fn query_opt<R, T>(
477 &self,
478 statement: &T,
479 params: &[&(dyn ToSql + Sync)],
480 ) -> Result<Option<R>, Error>
481 where
482 R: FromRowOwned,
483 T: ?Sized + ToStatement,
484 {
485 let row = match &self.inner {
486 ConnectionInner::Client(client) => {
487 client.query_opt(statement, params).await?
488 }
489 ConnectionInner::Transaction(tr) => {
490 tr.query_opt(statement, params).await?
491 }
492 };
493
494 R::from_row_owned(try2!(row).into())
495 .map(Some)
496 .map_err(Error::Deserialize)
497 }
498
499 pub async fn query_raw_opt<R, T, P, I>(
501 &self,
502 statement: &T,
503 params: I,
504 ) -> Result<Option<R>, Error>
505 where
506 R: FromRowOwned,
507 T: ?Sized + ToStatement,
508 P: BorrowToSql,
509 I: IntoIterator<Item = P>,
510 I::IntoIter: ExactSizeIterator,
511 {
512 let stream = self.query_raw(statement, params).await?;
513 pin_mut!(stream);
514
515 let row = match stream.try_next().await? {
516 Some(row) => row,
517 None => return Ok(None),
518 };
519
520 if stream.try_next().await?.is_some() {
521 return Err(Error::ExpectedOneRow);
522 }
523
524 R::from_row_owned(row).map(Some).map_err(Error::Deserialize)
525 }
526
527 pub async fn query_raw<T, P, I>(
529 &self,
530 statement: &T,
531 params: I,
532 ) -> Result<RowStream, Error>
533 where
534 T: ?Sized + ToStatement,
535 P: BorrowToSql,
536 I: IntoIterator<Item = P>,
537 I::IntoIter: ExactSizeIterator,
538 {
539 let row_stream = match &self.inner {
540 ConnectionInner::Client(client) => {
541 client.query_raw(statement, params).await?
542 }
543 ConnectionInner::Transaction(tr) => {
544 tr.query_raw(statement, params).await?
545 }
546 };
547
548 Ok(row_stream.into())
549 }
550
551 pub async fn execute<T>(
553 &self,
554 statement: &T,
555 params: &[&(dyn ToSql + Sync)],
556 ) -> Result<u64, Error>
557 where
558 T: ?Sized + ToStatement,
559 {
560 match &self.inner {
561 ConnectionInner::Client(client) => {
562 client.execute(statement, params).await.map_err(Error::from)
563 }
564 ConnectionInner::Transaction(tr) => {
565 tr.execute(statement, params).await.map_err(Error::from)
566 }
567 }
568 }
569
570 pub async fn execute_raw<T, P, I>(
572 &self,
573 statement: &T,
574 params: I,
575 ) -> Result<u64, Error>
576 where
577 T: ?Sized + ToStatement,
578 P: BorrowToSql,
579 I: IntoIterator<Item = P>,
580 I::IntoIter: ExactSizeIterator,
581 {
582 match &self.inner {
583 ConnectionInner::Client(client) => client
584 .execute_raw(statement, params)
585 .await
586 .map_err(Error::from),
587 ConnectionInner::Transaction(tr) => {
588 tr.execute_raw(statement, params).await.map_err(Error::from)
589 }
590 }
591 }
592
593 pub async fn batch_execute(&self, query: &str) -> Result<(), Error> {
595 match &self.inner {
596 ConnectionInner::Client(client) => {
597 client.batch_execute(query).await.map_err(Error::from)
598 }
599 ConnectionInner::Transaction(tr) => {
600 tr.batch_execute(query).await.map_err(Error::from)
601 }
602 }
603 }
604}
605
606fn slice_iter<'a>(
607 s: &'a [&'a (dyn ToSql + Sync)],
608) -> impl ExactSizeIterator<Item = &'a dyn ToSql> + 'a {
609 s.iter().map(|s| *s as _)
610}
611
/// Yields every item of the first iterator, then every item of the second,
/// while remaining an [`ExactSizeIterator`].
///
/// `std::iter::Chain` does not implement `ExactSizeIterator`, but the `*_raw`
/// query and execute methods require one for their parameters.
struct TwoExactSize<I, J>(I, J);

impl<I, J, T> Iterator for TwoExactSize<I, J>
where
    I: ExactSizeIterator<Item = T>,
    J: ExactSizeIterator<Item = T>,
{
    type Item = T;

    fn next(&mut self) -> Option<T> {
        match self.0.next() {
            Some(item) => Some(item),
            None => self.1.next(),
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let (first_low, first_high) = self.0.size_hint();
        let (second_low, second_high) = self.1.size_hint();
        let high = first_high.zip(second_high).map(|(x, y)| x + y);
        (first_low + second_low, high)
    }
}

// The combined size hint is exact because both halves are exact.
impl<I, J, T> ExactSizeIterator for TwoExactSize<I, J>
where
    I: ExactSizeIterator<Item = T>,
    J: ExactSizeIterator<Item = T>,
{
}