1#![warn(missing_docs)]
2#![deny(warnings, clippy::pedantic, clippy::nursery, unused_crate_dependencies)]
3#![allow(clippy::future_not_send)]
4
5#[cfg(feature = "async")]
11use {
12 deadpool_postgres::Pool as Deadpool,
13 std::{future::Future, pin::Pin},
14 tokio_postgres::error::SqlState,
15};
16
17#[cfg(feature = "sync")]
18use {
19 r2d2_postgres::PostgresConnectionManager,
20 tokio::task::block_in_place,
21 tokio_postgres::{
22 tls::{MakeTlsConnect, TlsConnect},
23 Socket,
24 },
25};
26
/// Savepoint name `cockroach_restart`.
///
/// This is the savepoint name used by CockroachDB's client-side transaction
/// retry protocol; it can be supplied as the `savepoint` argument of the
/// transaction helpers in this crate.
pub const COCKROACH_SAVEPOINT: &str = "cockroach_restart";
29
/// Error returned by the transaction helpers in this crate.
///
/// `E` is the caller's own error type and is carried by [`Error::Other`].
/// The remaining variants wrap failures coming from the connection pool or
/// the database driver; each of them only exists when the matching cargo
/// feature is enabled.
///
/// The type implements [`std::fmt::Debug`], [`std::fmt::Display`] and
/// [`std::error::Error`] (given suitable bounds on `E`). `Display` and
/// `source` are transparent: they forward to the wrapped error.
#[derive(Debug)]
pub enum Error<E> {
    /// Failure reported by the deadpool connection pool.
    #[cfg(feature = "async")]
    Pool(deadpool::managed::PoolError<tokio_postgres::Error>),
    /// Failure reported by the Postgres driver.
    #[cfg(feature = "async")]
    Postgres(tokio_postgres::Error),
    /// Failure reported by the r2d2 connection pool.
    #[cfg(feature = "sync")]
    R2d2(r2d2::Error),
    /// Error produced by caller-supplied code.
    Other(E),
}

impl<E: std::fmt::Display> std::fmt::Display for Error<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Fully qualified calls: the `Display` trait is not imported here and
        // `Debug::fmt` would otherwise be an equally valid candidate.
        match self {
            #[cfg(feature = "async")]
            Self::Pool(e) => std::fmt::Display::fmt(e, f),
            #[cfg(feature = "async")]
            Self::Postgres(e) => std::fmt::Display::fmt(e, f),
            #[cfg(feature = "sync")]
            Self::R2d2(e) => std::fmt::Display::fmt(e, f),
            Self::Other(e) => std::fmt::Display::fmt(e, f),
        }
    }
}

impl<E: std::error::Error + 'static> std::error::Error for Error<E> {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        // `Display` already prints the wrapped error, so expose the wrapped
        // error's own source to avoid repeating it in an error chain.
        match self {
            #[cfg(feature = "async")]
            Self::Pool(e) => std::error::Error::source(e),
            #[cfg(feature = "async")]
            Self::Postgres(e) => std::error::Error::source(e),
            #[cfg(feature = "sync")]
            Self::R2d2(e) => std::error::Error::source(e),
            Self::Other(e) => std::error::Error::source(e),
        }
    }
}
44
/// Wraps a deadpool pool error in [`Error::Pool`], so `?` can be applied to
/// results carrying that error type.
#[cfg(feature = "async")]
impl<E> From<deadpool::managed::PoolError<tokio_postgres::Error>> for Error<E> {
    #[inline]
    fn from(pool_err: deadpool::managed::PoolError<tokio_postgres::Error>) -> Self {
        Self::Pool(pool_err)
    }
}
52
/// Wraps a Postgres driver error in [`Error::Postgres`], so `?` can be
/// applied to results carrying that error type.
#[cfg(feature = "async")]
impl<E> From<tokio_postgres::Error> for Error<E> {
    #[inline]
    fn from(pg_err: tokio_postgres::Error) -> Self {
        Self::Postgres(pg_err)
    }
}
60
/// Wraps an r2d2 pool error in [`Error::R2d2`], so `?` can be applied to
/// results carrying that error type.
#[cfg(feature = "sync")]
impl<E> From<r2d2::Error> for Error<E> {
    #[inline]
    fn from(pool_err: r2d2::Error) -> Self {
        Self::R2d2(pool_err)
    }
}
68
/// An r2d2 connection pool managed by a `PostgresConnectionManager<T>`.
///
/// This is the pool type accepted by `tx_sync`, which requires `T` to be a
/// TLS connector (`MakeTlsConnect<Socket>`).
#[cfg(feature = "sync")]
pub type Pool<T> = r2d2::Pool<PostgresConnectionManager<T>>;
72
/// A pinned, boxed, `Send` future that resolves to `Result<T, I>` and may
/// borrow data for the lifetime `'a`.
///
/// This is the return type required of the closure passed to `tx`.
#[cfg(feature = "async")]
pub type AsyncResult<'a, T, I> = Pin<Box<dyn Future<Output = Result<T, I>> + Send + 'a>>;
76
/// Runs `f` inside a retrying transaction on a client taken from `pool`.
///
/// One transaction is opened on the client. Every attempt then creates a
/// savepoint named `savepoint` inside that transaction, calls `f` with it and
/// commits the savepoint if `f` succeeded. An attempt that fails with a
/// Postgres error whose SQLSTATE is `T_R_SERIALIZATION_FAILURE` is repeated;
/// the number of attempts is not bounded. Any other outcome ends the loop.
/// After a successful attempt the outer transaction is committed and the
/// value produced by `f` is returned.
///
/// `f` may be called several times, so it should be safe to re-run.
///
/// # Errors
///
/// Returns an error when a client cannot be obtained from the pool, when the
/// transaction or a savepoint cannot be started or committed, or when `f`
/// fails with anything other than a serialization failure (the error is
/// converted through `I: Into<Error<E>>`).
#[cfg(feature = "async")]
#[inline]
pub async fn tx<T, E, I, S, F>(pool: &Deadpool, savepoint: S, f: F) -> Result<T, Error<E>>
where
    I: Into<Error<E>>,
    S: AsRef<str>,
    for<'a> F: Fn(&'a tokio_postgres::Transaction<'a>) -> AsyncResult<'a, T, I>,
{
    let mut client = pool.get().await?;
    let mut transaction = client.transaction().await?;
    let name = savepoint.as_ref();

    let value = loop {
        let attempt: Result<T, Error<E>> = execute_fn(&mut transaction, name, &f).await;

        // Only a serialization failure reported by the driver triggers
        // another attempt; successes and all other errors leave the loop.
        let retryable = matches!(
            &attempt,
            Err(Error::Postgres(e)) if e.code() == Some(&SqlState::T_R_SERIALIZATION_FAILURE)
        );
        if !retryable {
            break attempt;
        }
    }?;

    transaction.commit().await?;

    Ok(value)
}
105
/// Performs a single attempt of `f` inside a savepoint of `tx`.
///
/// A savepoint called `savepoint` is created, handed to `f`, and committed
/// only if `f` succeeded. On failure the savepoint is dropped without being
/// committed and the error (converted into [`Error`]) is returned, leaving
/// the retry decision to the caller.
#[cfg(feature = "async")]
#[inline]
async fn execute_fn<T, E, I, F>(
    tx: &mut tokio_postgres::Transaction<'_>,
    savepoint: &str,
    f: &F,
) -> Result<T, Error<E>>
where
    I: Into<Error<E>>,
    for<'a> F: Fn(&'a tokio_postgres::Transaction<'a>) -> AsyncResult<'a, T, I>,
{
    let nested = tx.savepoint(savepoint).await?;

    // Bind the outcome first so the future (and its borrow of `nested`) is
    // gone before `nested` is consumed by `commit`.
    let outcome = f(&nested).await;

    match outcome {
        Ok(value) => {
            nested.commit().await?;
            Ok(value)
        }
        Err(err) => Err(err.into()),
    }
}
124
/// Blocking counterpart of the async transaction helper: runs `f` inside a
/// retrying transaction on a connection taken from the r2d2 `pool`.
///
/// The whole operation is executed inside `tokio::task::block_in_place`.
/// One transaction is opened on the connection. Every attempt then creates a
/// savepoint named `savepoint` inside it, calls `f` with that savepoint and
/// commits the savepoint if `f` succeeded. An attempt that fails with a
/// Postgres error whose SQLSTATE is `T_R_SERIALIZATION_FAILURE` is repeated;
/// the number of attempts is not bounded. After a successful attempt the
/// outer transaction is committed and the value produced by `f` is returned.
///
/// `f` may be called several times, so it should be safe to re-run.
///
/// # Errors
///
/// Returns an error when a connection cannot be obtained from the pool, when
/// the transaction or a savepoint cannot be started or committed, or when `f`
/// fails with anything other than a serialization failure (the error is
/// converted through `I: Into<Error<E>>`).
///
/// # Panics
///
/// According to the tokio documentation, `block_in_place` panics when it is
/// called from a `current_thread` runtime.
#[cfg(feature = "sync")]
#[inline]
pub fn tx_sync<T, M, E, I, S, F>(pool: &Pool<M>, savepoint: S, mut f: F) -> Result<T, Error<E>>
where
    M: MakeTlsConnect<Socket> + Clone + 'static + Sync + Send,
    M::TlsConnect: Send,
    M::Stream: Send,
    <M::TlsConnect as TlsConnect<Socket>>::Future: Send,
    I: Into<Error<E>>,
    S: AsRef<str>,
    F: FnMut(&mut postgres::Transaction<'_>) -> Result<T, I>,
{
    block_in_place(|| {
        let mut con = pool.get()?;
        let mut outer = con.transaction()?;
        let name = savepoint.as_ref();

        let value = loop {
            let mut sp = outer.savepoint(name)?;

            // The savepoint is committed only when `f` succeeded; otherwise
            // it is dropped uncommitted at the end of this iteration.
            let attempt: Result<T, Error<E>> = match f(&mut sp) {
                Ok(value) => sp.commit().map(|()| value).map_err(Error::from),
                Err(err) => Err(err.into()),
            };

            // NOTE(review): `SqlState`, `Error::Postgres` and the
            // `From<tokio_postgres::Error>` impl are gated on the `async`
            // feature in this file — confirm that `sync` implies `async`.
            let retryable = matches!(
                &attempt,
                Err(Error::Postgres(e)) if e.code() == Some(&SqlState::T_R_SERIALIZATION_FAILURE)
            );
            if !retryable {
                break attempt;
            }
        }?;

        outer.commit()?;

        Ok(value)
    })
}