1use std::borrow::Borrow;
4use std::fmt::Write;
5
6use deadpool_postgres::Metrics;
7use deadpool_postgres::{ClientWrapper, Object};
8
9use futures_util::StreamExt;
10use futures_util::TryStreamExt;
11use futures_util::pin_mut;
12use postgres_types::{BorrowToSql, ToSql, Type};
13use tokio_postgres::Error as PgError;
14use tokio_postgres::error::SqlState;
15
16pub use deadpool::managed::TimeoutType;
17pub use deadpool_postgres::{Config, ConfigError};
18use tokio_postgres::Statement;
19use tokio_postgres::ToStatement;
20use tracing::error;
21
22use crate::Row;
23use crate::filter::Filter;
24use crate::filter::Limit;
25use crate::filter::WhereFilter;
26use crate::row::NamedColumns;
27use crate::row::RowStream;
28use crate::row::ToRowStatic;
29use crate::row::{FromRowOwned, ToRow};
30use crate::try2;
31
/// Errors produced by the database helpers in this module.
///
/// Raw Postgres errors are classified by the `From<PgError>` implementation;
/// the remaining variants are constructed directly by the query helpers.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum Error {
    /// The Postgres error carried the `UNIQUE_VIOLATION` SQLSTATE code.
    #[error("Unique violation {0}")]
    UniqueViolation(PgError),

    /// A query that must yield a single row yielded none, or yielded more
    /// than one.
    #[error("Expected one row")]
    ExpectedOneRow,

    /// Any Postgres error that is not mapped to a more specific variant.
    #[error("Other Postgres error {0}")]
    Other(PgError),

    /// A returned row could not be converted into the requested Rust type.
    #[error("Deserialization error {0}")]
    Deserialize(Box<dyn std::error::Error + Send + Sync>),

    /// Catch-all for failures that do not fit another variant.
    #[error("Unknown error {0}")]
    Unknown(Box<dyn std::error::Error + Send + Sync>),
}
50
51impl From<PgError> for Error {
52 fn from(e: PgError) -> Self {
53 let Some(state) = e.code() else {
54 return Self::Other(e);
55 };
56
57 match state {
58 &SqlState::UNIQUE_VIOLATION => Self::UniqueViolation(e),
59 state => {
60 error!("db error with state {:?}", state);
61 Self::Other(e)
62 }
63 }
64 }
65}
66
/// An owned database client: a newtype around a `deadpool_postgres::Object`.
///
/// Use [`ConnectionOwned::connection`] to borrow it as a [`Connection`], or
/// [`ConnectionOwned::transaction`] to start a transaction on it.
#[derive(Debug)]
pub struct ConnectionOwned(pub(crate) Object);
69
70impl ConnectionOwned {
71 pub fn connection(&self) -> Connection {
72 Connection {
73 inner: ConnectionInner::Client(&self.0),
74 }
75 }
76
77 pub async fn transaction<'a>(
78 &'a mut self,
79 ) -> Result<Transaction<'a>, Error> {
80 Ok(Transaction {
81 inner: self.0.transaction().await.map_err(Error::from)?,
82 })
83 }
84
85 pub fn metrics(&self) -> &Metrics {
86 Object::metrics(&self.0)
87 }
88}
89
/// `chuchi` integration; only compiled when the `chuchi` feature is enabled.
#[cfg(feature = "chuchi")]
mod impl_chuchi {
    use chuchi::{
        extractor::Extractor, extractor_extract, extractor_prepare,
        extractor_validate,
    };

    use crate::{Database, database::DatabaseError};

    use super::*;

    /// Implements chuchi's `Extractor` for [`ConnectionOwned`], sourcing the
    /// value from the `Database` resource.
    impl<'a, R> Extractor<'a, R> for ConnectionOwned {
        type Error = DatabaseError;
        type Prepared = Self;

        // Validation step: panics with "Db resource not found" unless a
        // `Database` resource is registered.
        extractor_validate!(|validate| {
            assert!(
                validate.resources.exists::<Database>(),
                "Db resource not found"
            );
        });

        // Preparation step: fetch the `Database` resource and await
        // `db.get()`. The `unwrap` presumably relies on the validation step
        // having already confirmed the resource exists (TODO confirm that
        // chuchi always runs validation first).
        extractor_prepare!(|prepare| {
            let db = prepare.resources.get::<Database>().unwrap();
            db.get().await
        });

        // Extraction step: hand the prepared value over unchanged.
        extractor_extract!(|extract| { Ok(extract.prepared) });
    }
}
120
/// A running database transaction, wrapping a
/// `deadpool_postgres::Transaction`.
///
/// Obtain a [`Connection`] with [`Transaction::connection`] to run statements
/// inside it, then finish with [`Transaction::commit`] or
/// [`Transaction::rollback`].
#[derive(Debug)]
pub struct Transaction<'a> {
    /// The underlying pooled transaction.
    inner: deadpool_postgres::Transaction<'a>,
}
125
126impl<'a> Transaction<'a> {
127 pub fn connection(&self) -> Connection<'_> {
129 Connection {
130 inner: ConnectionInner::Transaction(&self.inner),
131 }
132 }
133
134 pub async fn commit(self) -> Result<(), Error> {
136 self.inner.commit().await.map_err(Error::from)
137 }
138
139 pub async fn rollback(self) -> Result<(), Error> {
141 self.inner.rollback().await.map_err(Error::from)
142 }
143}
144
/// A borrowed, copyable handle used to run statements.
///
/// It refers either to a plain pooled client or to an open transaction (see
/// [`ConnectionInner`]); every method dispatches to whichever one is wrapped.
#[derive(Debug, Clone, Copy)]
pub struct Connection<'a> {
    /// The borrowed client or transaction that statements are sent to.
    inner: ConnectionInner<'a>,
}
149
/// The two kinds of backend a [`Connection`] can borrow.
#[derive(Debug, Clone, Copy)]
enum ConnectionInner<'a> {
    /// A pooled client, used outside of an explicit transaction.
    Client(&'a ClientWrapper),
    /// An open transaction.
    Transaction(&'a deadpool_postgres::Transaction<'a>),
}
155
156impl Connection<'_> {
157 pub async fn select<R>(
164 &self,
165 table: &str,
166 filter: impl Borrow<Filter<'_>>,
167 ) -> Result<Vec<R>, Error>
168 where
169 R: FromRowOwned + NamedColumns,
170 {
171 let sql = format!(
172 "SELECT {} FROM \"{}\"{}",
173 R::select_columns(),
174 table,
175 filter.borrow()
176 );
177 let stmt = self.prepare_cached(&sql).await?;
178
179 self.query_raw(&stmt, filter.borrow().params.iter_to_sql())
180 .await?
181 .map(|row| {
182 row.and_then(|row| {
183 R::from_row_owned(row).map_err(Error::Deserialize)
184 })
185 })
186 .try_collect()
187 .await
188 }
189
190 pub async fn select_one<R>(
192 &self,
193 table: &str,
194 filter: impl Borrow<Filter<'_>>,
195 ) -> Result<R, Error>
196 where
197 R: FromRowOwned + NamedColumns,
198 {
199 let mut formatter = filter.borrow().to_formatter();
200
201 if matches!(formatter.limit, Limit::All) {
202 formatter.limit = &Limit::Fixed(1);
203 }
204
205 let sql = format!(
206 "SELECT {} FROM \"{}\"{}",
207 R::select_columns(),
208 table,
209 formatter
210 );
211 let stmt = self.prepare_cached(&sql).await?;
212
213 let row = self
214 .query_raw_opt(&stmt, filter.borrow().params.iter_to_sql())
215 .await
216 .and_then(|opt| opt.ok_or(Error::ExpectedOneRow))?;
217
218 R::from_row_owned(row).map_err(Error::Deserialize)
219 }
220
221 pub async fn select_opt<R>(
223 &self,
224 table: &str,
225 filter: impl Borrow<Filter<'_>>,
226 ) -> Result<Option<R>, Error>
227 where
228 R: FromRowOwned + NamedColumns,
229 {
230 let mut formatter = filter.borrow().to_formatter();
231
232 if matches!(formatter.limit, Limit::All) {
233 formatter.limit = &Limit::Fixed(1);
234 }
235
236 let sql = format!(
237 "SELECT {} FROM \"{}\"{}",
238 R::select_columns(),
239 table,
240 formatter
241 );
242 let stmt = self.prepare_cached(&sql).await?;
243
244 self.query_raw_opt(&stmt, filter.borrow().params.iter_to_sql())
245 .await
246 }
247
248 pub async fn count(
253 &self,
254 table: &str,
255 column: &str,
256 filter: impl Borrow<Filter<'_>>,
257 ) -> Result<u64, Error> {
258 let sql = format!(
259 "SELECT COUNT(\"{column}\") FROM \"{table}\"{}",
260 filter.borrow()
261 );
262 let stmt = self.prepare_cached(&sql).await?;
263
264 let row: Row = self
265 .query_raw_opt(&stmt, filter.borrow().params.iter_to_sql())
266 .await
267 .and_then(|opt| opt.ok_or(Error::ExpectedOneRow))?;
268
269 Ok(row.get::<_, i64>(0) as u64)
270 }
271
272 pub async fn insert<U>(&self, table: &str, item: &U) -> Result<(), Error>
274 where
275 U: ToRow,
276 {
277 let mut sql = format!("INSERT INTO \"{table}\" (");
278 item.insert_columns(&mut sql);
279 sql.push_str(") VALUES (");
280 item.insert_values(&mut sql);
281 sql.push(')');
282
283 let stmt = self.prepare_cached(&sql).await?;
284
285 self.execute_raw(&stmt, item.params()).await.map(|_| ())
286 }
287
288 pub async fn insert_many<U, I>(
290 &self,
291 table: &str,
292 items: I,
293 ) -> Result<(), Error>
294 where
295 U: ToRowStatic,
296 I: IntoIterator,
297 I::Item: Borrow<U>,
298 {
299 let sql = format!(
300 "INSERT INTO \"{}\" ({}) VALUES ({})",
301 table,
302 U::insert_columns(),
303 U::insert_values()
304 );
305 let stmt = self.prepare_cached(&sql).await?;
306
307 for item in items {
308 self.execute_raw(&stmt, item.borrow().params()).await?;
309 }
310
311 Ok(())
312 }
313
314 pub async fn update<U>(
316 &self,
317 table: &str,
318 item: &U,
319 filter: impl Borrow<WhereFilter<'_>>,
320 ) -> Result<(), Error>
321 where
322 U: ToRow,
323 {
324 let filter = filter.borrow();
325 let mut formatter = filter.whr.to_formatter();
326 formatter.param_start = item.params_len();
327
328 let mut sql = format!("UPDATE \"{table}\" SET ");
329 item.update_columns(&mut sql);
330 write!(&mut sql, "{}", formatter).unwrap();
331
332 let stmt = self.prepare_cached(&sql).await?;
333
334 self.execute_raw(
337 &stmt,
338 TwoExactSize(item.params(), filter.params.iter_to_sql()),
339 )
340 .await
341 .map(|_| ())
342 }
343
344 pub async fn delete(
346 &self,
347 table: &str,
348 filter: impl Borrow<WhereFilter<'_>>,
349 ) -> Result<(), Error> {
350 let sql = format!("DELETE FROM \"{}\"{}", table, filter.borrow());
351 let stmt = self.prepare_cached(&sql).await?;
352
353 self.execute_raw(&stmt, filter.borrow().params.iter_to_sql())
354 .await
355 .map(|_| ())
356 }
357
358 pub async fn prepare_cached(
361 &self,
362 query: &str,
363 ) -> Result<Statement, Error> {
364 match &self.inner {
365 ConnectionInner::Client(client) => {
366 client.prepare_cached(query).await.map_err(Error::from)
367 }
368 ConnectionInner::Transaction(tr) => {
369 tr.prepare_cached(query).await.map_err(Error::from)
370 }
371 }
372 }
373
374 pub async fn prepare(&self, query: &str) -> Result<Statement, Error> {
376 match &self.inner {
377 ConnectionInner::Client(client) => {
378 client.prepare(query).await.map_err(Error::from)
379 }
380 ConnectionInner::Transaction(tr) => {
381 tr.prepare(query).await.map_err(Error::from)
382 }
383 }
384 }
385
386 pub async fn prepare_typed_cached(
389 &self,
390 query: &str,
391 types: &[Type],
392 ) -> Result<Statement, Error> {
393 match &self.inner {
394 ConnectionInner::Client(client) => client
395 .prepare_typed_cached(query, types)
396 .await
397 .map_err(Error::from),
398 ConnectionInner::Transaction(tr) => tr
399 .prepare_typed_cached(query, types)
400 .await
401 .map_err(Error::from),
402 }
403 }
404
405 pub async fn prepare_typed(
407 &self,
408 query: &str,
409 parameter_types: &[Type],
410 ) -> Result<Statement, Error> {
411 match &self.inner {
412 ConnectionInner::Client(client) => client
413 .prepare_typed(query, parameter_types)
414 .await
415 .map_err(Error::from),
416 ConnectionInner::Transaction(tr) => tr
417 .prepare_typed(query, parameter_types)
418 .await
419 .map_err(Error::from),
420 }
421 }
422
423 pub async fn query<R, T>(
425 &self,
426 statement: &T,
427 params: &[&(dyn ToSql + Sync)],
428 ) -> Result<Vec<R>, Error>
429 where
430 R: FromRowOwned,
431 T: ?Sized + ToStatement,
432 {
433 self.query_raw(statement, slice_iter(params))
434 .await?
435 .map(|row| {
436 row.and_then(|row| {
437 R::from_row_owned(row).map_err(Error::Deserialize)
438 })
439 })
440 .try_collect()
441 .await
442 }
443
444 pub async fn query_one<R, T>(
446 &self,
447 statement: &T,
448 params: &[&(dyn ToSql + Sync)],
449 ) -> Result<R, Error>
450 where
451 R: FromRowOwned,
452 T: ?Sized + ToStatement,
453 {
454 let row = match &self.inner {
455 ConnectionInner::Client(client) => {
456 client.query_one(statement, params).await?
457 }
458 ConnectionInner::Transaction(tr) => {
459 tr.query_one(statement, params).await?
460 }
461 };
462
463 R::from_row_owned(row.into()).map_err(Error::Deserialize)
464 }
465
466 pub async fn query_opt<R, T>(
468 &self,
469 statement: &T,
470 params: &[&(dyn ToSql + Sync)],
471 ) -> Result<Option<R>, Error>
472 where
473 R: FromRowOwned,
474 T: ?Sized + ToStatement,
475 {
476 let row = match &self.inner {
477 ConnectionInner::Client(client) => {
478 client.query_opt(statement, params).await?
479 }
480 ConnectionInner::Transaction(tr) => {
481 tr.query_opt(statement, params).await?
482 }
483 };
484
485 R::from_row_owned(try2!(row).into())
486 .map(Some)
487 .map_err(Error::Deserialize)
488 }
489
490 pub async fn query_raw_opt<R, T, P, I>(
492 &self,
493 statement: &T,
494 params: I,
495 ) -> Result<Option<R>, Error>
496 where
497 R: FromRowOwned,
498 T: ?Sized + ToStatement,
499 P: BorrowToSql,
500 I: IntoIterator<Item = P>,
501 I::IntoIter: ExactSizeIterator,
502 {
503 let stream = self.query_raw(statement, params).await?;
504 pin_mut!(stream);
505
506 let row = match stream.try_next().await? {
507 Some(row) => row,
508 None => return Ok(None),
509 };
510
511 if stream.try_next().await?.is_some() {
512 return Err(Error::ExpectedOneRow);
513 }
514
515 R::from_row_owned(row).map(Some).map_err(Error::Deserialize)
516 }
517
518 pub async fn query_raw<T, P, I>(
520 &self,
521 statement: &T,
522 params: I,
523 ) -> Result<RowStream, Error>
524 where
525 T: ?Sized + ToStatement,
526 P: BorrowToSql,
527 I: IntoIterator<Item = P>,
528 I::IntoIter: ExactSizeIterator,
529 {
530 let row_stream = match &self.inner {
531 ConnectionInner::Client(client) => {
532 client.query_raw(statement, params).await?
533 }
534 ConnectionInner::Transaction(tr) => {
535 tr.query_raw(statement, params).await?
536 }
537 };
538
539 Ok(row_stream.into())
540 }
541
542 pub async fn execute<T>(
544 &self,
545 statement: &T,
546 params: &[&(dyn ToSql + Sync)],
547 ) -> Result<u64, Error>
548 where
549 T: ?Sized + ToStatement,
550 {
551 match &self.inner {
552 ConnectionInner::Client(client) => {
553 client.execute(statement, params).await.map_err(Error::from)
554 }
555 ConnectionInner::Transaction(tr) => {
556 tr.execute(statement, params).await.map_err(Error::from)
557 }
558 }
559 }
560
561 pub async fn execute_raw<T, P, I>(
563 &self,
564 statement: &T,
565 params: I,
566 ) -> Result<u64, Error>
567 where
568 T: ?Sized + ToStatement,
569 P: BorrowToSql,
570 I: IntoIterator<Item = P>,
571 I::IntoIter: ExactSizeIterator,
572 {
573 match &self.inner {
574 ConnectionInner::Client(client) => client
575 .execute_raw(statement, params)
576 .await
577 .map_err(Error::from),
578 ConnectionInner::Transaction(tr) => {
579 tr.execute_raw(statement, params).await.map_err(Error::from)
580 }
581 }
582 }
583
584 pub async fn batch_execute(&self, query: &str) -> Result<(), Error> {
586 match &self.inner {
587 ConnectionInner::Client(client) => {
588 client.batch_execute(query).await.map_err(Error::from)
589 }
590 ConnectionInner::Transaction(tr) => {
591 tr.batch_execute(query).await.map_err(Error::from)
592 }
593 }
594 }
595}
596
597fn slice_iter<'a>(
598 s: &'a [&'a (dyn ToSql + Sync)],
599) -> impl ExactSizeIterator<Item = &'a dyn ToSql> + 'a {
600 s.iter().map(|s| *s as _)
601}
602
/// Iterator adaptor that yields every item of the first iterator, then every
/// item of the second.
///
/// Unlike [`std::iter::Chain`], it implements [`ExactSizeIterator`], which
/// the raw parameter APIs in this file require of their parameter iterator.
struct TwoExactSize<I, J>(I, J);

impl<I, J, T> Iterator for TwoExactSize<I, J>
where
    I: ExactSizeIterator<Item = T>,
    J: ExactSizeIterator<Item = T>,
{
    type Item = T;

    /// Drains the first iterator before moving on to the second.
    fn next(&mut self) -> Option<Self::Item> {
        self.0.next().or_else(|| self.1.next())
    }

    /// Sums the hints of both halves.
    ///
    /// The additions are overflow-safe (as in `std::iter::Chain`): the lower
    /// bound saturates and the upper bound becomes `None` when the sum does
    /// not fit into a `usize`, instead of panicking in debug builds.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let (first_lower, first_upper) = self.0.size_hint();
        let (second_lower, second_upper) = self.1.size_hint();

        let lower = first_lower.saturating_add(second_lower);
        let upper = first_upper
            .zip(second_upper)
            .and_then(|(first, second)| first.checked_add(second));

        (lower, upper)
    }
}

/// Both halves know their exact length, so their concatenation does too (as
/// long as the combined length fits into a `usize`).
impl<I, J, T> ExactSizeIterator for TwoExactSize<I, J>
where
    I: ExactSizeIterator<Item = T>,
    J: ExactSizeIterator<Item = T>,
{
}
627}