use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use crate::encode::SqlParam;
use crate::error::TypedError;
use crate::row::Row;
/// Process-wide counter from which the transaction `atomic` implementations
/// derive distinct savepoint names (`resolute_sp_{id}`). Callers use
/// `Ordering::Relaxed`, which is sufficient because only the uniqueness of
/// each `fetch_add` result matters, not ordering relative to other memory.
static SAVEPOINT_COUNTER: AtomicU64 = AtomicU64::new(0);
/// Async database operations shared by the connection-like types in this
/// module: plain and pooled clients, their transactions, and the reconnecting
/// client.
///
/// Implementors supply `query`, `execute`, `atomic`, `copy_in` and `copy_out`;
/// every other method is provided on top of those primitives.
pub trait Executor: Send + Sync {
    /// Runs `sql` with positional `params` and resolves to every result row.
    fn query<'a>(
        &'a self,
        sql: &'a str,
        params: &'a [&'a dyn SqlParam],
    ) -> impl Future<Output = Result<Vec<Row>, TypedError>> + Send + 'a;

    /// Runs `sql` with positional `params`, resolving to the `u64` count the
    /// implementation reports (presumably affected rows — confirm per impl).
    fn execute<'a>(
        &'a self,
        sql: &'a str,
        params: &'a [&'a dyn SqlParam],
    ) -> impl Future<Output = Result<u64, TypedError>> + Send + 'a;

    /// Runs `sql` through [`Executor::query`] and requires exactly one row.
    ///
    /// # Errors
    /// Any error from `query`, or `TypedError::NotExactlyOne(n)` carrying the
    /// actual row count `n` whenever it is not 1.
    fn query_one<'a>(
        &'a self,
        sql: &'a str,
        params: &'a [&'a dyn SqlParam],
    ) -> impl Future<Output = Result<Row, TypedError>> + Send + 'a {
        async move {
            let mut found = self.query(sql, params).await?;
            let count = found.len();
            match found.pop() {
                Some(row) if count == 1 => Ok(row),
                _ => Err(TypedError::NotExactlyOne(count)),
            }
        }
    }

    /// Runs `sql` through [`Executor::query`] and accepts zero or one row.
    ///
    /// # Errors
    /// Any error from `query`, or `TypedError::NotExactlyOne(n)` when `n > 1`
    /// rows come back.
    fn query_opt<'a>(
        &'a self,
        sql: &'a str,
        params: &'a [&'a dyn SqlParam],
    ) -> impl Future<Output = Result<Option<Row>, TypedError>> + Send + 'a {
        async move {
            let mut found = self.query(sql, params).await?;
            if found.len() > 1 {
                return Err(TypedError::NotExactlyOne(found.len()));
            }
            // At most one row is left, so `pop` yields it or `None`.
            Ok(found.pop())
        }
    }

    /// Rewrites named placeholders with `crate::named_params::rewrite`, binds
    /// each name to its value from `params`, then delegates to `query`.
    ///
    /// # Errors
    /// `TypedError::MissingParam` for a name with no entry in `params`, or
    /// any error from `query`.
    fn query_named<'a>(
        &'a self,
        sql: &'a str,
        params: &'a [(&'a str, &'a dyn SqlParam)],
    ) -> impl Future<Output = Result<Vec<Row>, TypedError>> + Send + 'a {
        async move {
            let (positional_sql, order) = crate::named_params::rewrite(sql);
            let bound = resolve_named(&order, params)?;
            self.query(&positional_sql, &bound).await
        }
    }

    /// Named-parameter counterpart of [`Executor::execute`]; see
    /// [`Executor::query_named`] for how names are resolved.
    fn execute_named<'a>(
        &'a self,
        sql: &'a str,
        params: &'a [(&'a str, &'a dyn SqlParam)],
    ) -> impl Future<Output = Result<u64, TypedError>> + Send + 'a {
        async move {
            let (positional_sql, order) = crate::named_params::rewrite(sql);
            let bound = resolve_named(&order, params)?;
            self.execute(&positional_sql, &bound).await
        }
    }

    /// Runs `f` as a single all-or-nothing unit of work.
    ///
    /// The implementations in this module wrap `f` in `BEGIN`/`COMMIT`
    /// (rolling back on error) for clients, and in a savepoint for
    /// transactions.
    fn atomic<'a, T: Send + 'a>(
        &'a self,
        f: impl FnOnce(&'a Self) -> Pin<Box<dyn Future<Output = Result<T, TypedError>> + Send + 'a>>
            + Send
            + 'a,
    ) -> impl Future<Output = Result<T, TypedError>> + Send + 'a;

    /// Issues `SELECT 1` through `query`, discarding the rows.
    fn ping<'a>(&'a self) -> impl Future<Output = Result<(), TypedError>> + Send + 'a {
        async move { self.query("SELECT 1", &[]).await.map(|_rows| ()) }
    }

    /// Runs the COPY statement `copy_sql` with `data` as its input, resolving
    /// to a `u64` count (presumably rows copied — confirm per impl).
    fn copy_in<'a>(
        &'a self,
        copy_sql: &'a str,
        data: &'a [u8],
    ) -> impl Future<Output = Result<u64, TypedError>> + Send + 'a;

    /// Runs the COPY statement `copy_sql` and resolves to its output bytes.
    fn copy_out<'a>(
        &'a self,
        copy_sql: &'a str,
    ) -> impl Future<Output = Result<Vec<u8>, TypedError>> + Send + 'a;
}
/// Orders named parameter values to match `names`.
///
/// At both call sites `names` is the name list returned by
/// `crate::named_params::rewrite`. Each name is looked up in `params`; if a
/// name appears more than once in `params`, the first entry wins.
///
/// # Errors
/// `TypedError::MissingParam(name)` for the first name with no entry in
/// `params`.
fn resolve_named<'a>(
    names: &[String],
    params: &[(&str, &'a dyn SqlParam)],
) -> Result<Vec<&'a dyn SqlParam>, TypedError> {
    let mut ordered = Vec::with_capacity(names.len());
    for name in names {
        let wanted = name.as_str();
        let Some(&(_, value)) = params.iter().find(|(key, _)| *key == wanted) else {
            return Err(TypedError::MissingParam(name.to_string()));
        };
        ordered.push(value);
    }
    Ok(ordered)
}
// The methods are written exactly as the trait declares them
// (`-> impl Future + Send`) instead of as `async fn`, which is what
// `manual_async_fn` would otherwise flag.
#[allow(clippy::manual_async_fn)]
impl Executor for crate::query::Client {
    fn query<'a>(
        &'a self,
        sql: &'a str,
        params: &'a [&'a dyn SqlParam],
    ) -> impl Future<Output = Result<Vec<Row>, TypedError>> + Send + 'a {
        crate::query::Client::query(self, sql, params)
    }

    fn execute<'a>(
        &'a self,
        sql: &'a str,
        params: &'a [&'a dyn SqlParam],
    ) -> impl Future<Output = Result<u64, TypedError>> + Send + 'a {
        crate::query::Client::execute(self, sql, params)
    }

    /// Wraps `f` in `BEGIN` … `COMMIT`, issuing `ROLLBACK` instead when `f`
    /// fails. A failed rollback is logged and `f`'s own error is returned; a
    /// failed `COMMIT` is returned in place of `f`'s value.
    fn atomic<'a, T: Send + 'a>(
        &'a self,
        f: impl FnOnce(&'a Self) -> Pin<Box<dyn Future<Output = Result<T, TypedError>> + Send + 'a>>
            + Send
            + 'a,
    ) -> impl Future<Output = Result<T, TypedError>> + Send + 'a {
        async move {
            self.simple_query("BEGIN").await?;
            let outcome = f(self).await;
            if outcome.is_ok() {
                self.simple_query("COMMIT").await?;
            } else if let Err(rb_err) = self.simple_query("ROLLBACK").await {
                tracing::error!(error = %rb_err, "transaction rollback failed");
            }
            outcome
        }
    }

    fn copy_in<'a>(
        &'a self,
        copy_sql: &'a str,
        data: &'a [u8],
    ) -> impl Future<Output = Result<u64, TypedError>> + Send + 'a {
        crate::query::Client::copy_in(self, copy_sql, data)
    }

    fn copy_out<'a>(
        &'a self,
        copy_sql: &'a str,
    ) -> impl Future<Output = Result<Vec<u8>, TypedError>> + Send + 'a {
        crate::query::Client::copy_out(self, copy_sql)
    }
}
#[allow(clippy::manual_async_fn)]
impl Executor for crate::query::Transaction<'_> {
fn query<'a>(
&'a self,
sql: &'a str,
params: &'a [&'a dyn SqlParam],
) -> impl Future<Output = Result<Vec<Row>, TypedError>> + Send + 'a {
self.client.query(sql, params)
}
fn execute<'a>(
&'a self,
sql: &'a str,
params: &'a [&'a dyn SqlParam],
) -> impl Future<Output = Result<u64, TypedError>> + Send + 'a {
self.client.execute(sql, params)
}
fn copy_in<'a>(
&'a self,
copy_sql: &'a str,
data: &'a [u8],
) -> impl Future<Output = Result<u64, TypedError>> + Send + 'a {
self.client.copy_in(copy_sql, data)
}
fn copy_out<'a>(
&'a self,
copy_sql: &'a str,
) -> impl Future<Output = Result<Vec<u8>, TypedError>> + Send + 'a {
self.client.copy_out(copy_sql)
}
fn atomic<'a, T: Send + 'a>(
&'a self,
f: impl FnOnce(&'a Self) -> Pin<Box<dyn Future<Output = Result<T, TypedError>> + Send + 'a>>
+ Send
+ 'a,
) -> impl Future<Output = Result<T, TypedError>> + Send + 'a {
async move {
let id = SAVEPOINT_COUNTER.fetch_add(1, Ordering::Relaxed);
let sp = format!("resolute_sp_{id}");
self.client.simple_query(&format!("SAVEPOINT {sp}")).await?;
match f(self).await {
Ok(val) => {
self.client
.simple_query(&format!("RELEASE SAVEPOINT {sp}"))
.await?;
Ok(val)
}
Err(e) => {
if let Err(rb_err) = self
.client
.simple_query(&format!("ROLLBACK TO SAVEPOINT {sp}"))
.await
{
tracing::error!(error = %rb_err, savepoint = %sp, "savepoint rollback failed");
}
Err(e)
}
}
}
}
}
// The methods are written exactly as the trait declares them
// (`-> impl Future + Send`) instead of as `async fn`, which is what
// `manual_async_fn` would otherwise flag.
#[allow(clippy::manual_async_fn)]
impl Executor for crate::pooled::PooledClient {
    fn query<'a>(
        &'a self,
        sql: &'a str,
        params: &'a [&'a dyn SqlParam],
    ) -> impl Future<Output = Result<Vec<Row>, TypedError>> + Send + 'a {
        crate::pooled::PooledClient::query(self, sql, params)
    }

    fn execute<'a>(
        &'a self,
        sql: &'a str,
        params: &'a [&'a dyn SqlParam],
    ) -> impl Future<Output = Result<u64, TypedError>> + Send + 'a {
        crate::pooled::PooledClient::execute(self, sql, params)
    }

    /// Wraps `f` in `BEGIN` … `COMMIT` on this pooled client, issuing
    /// `ROLLBACK` instead when `f` fails. A failed rollback is logged and
    /// `f`'s own error is returned; a failed `COMMIT` is returned in place of
    /// `f`'s value.
    fn atomic<'a, T: Send + 'a>(
        &'a self,
        f: impl FnOnce(&'a Self) -> Pin<Box<dyn Future<Output = Result<T, TypedError>> + Send + 'a>>
            + Send
            + 'a,
    ) -> impl Future<Output = Result<T, TypedError>> + Send + 'a {
        async move {
            self.simple_query("BEGIN").await?;
            let outcome = f(self).await;
            if outcome.is_ok() {
                self.simple_query("COMMIT").await?;
            } else if let Err(rb_err) = self.simple_query("ROLLBACK").await {
                tracing::error!(error = %rb_err, "transaction rollback failed");
            }
            outcome
        }
    }

    fn copy_in<'a>(
        &'a self,
        copy_sql: &'a str,
        data: &'a [u8],
    ) -> impl Future<Output = Result<u64, TypedError>> + Send + 'a {
        crate::pooled::PooledClient::copy_in(self, copy_sql, data)
    }

    fn copy_out<'a>(
        &'a self,
        copy_sql: &'a str,
    ) -> impl Future<Output = Result<Vec<u8>, TypedError>> + Send + 'a {
        crate::pooled::PooledClient::copy_out(self, copy_sql)
    }
}
// The methods are written exactly as the trait declares them
// (`-> impl Future + Send`) instead of as `async fn`, which is what
// `manual_async_fn` would otherwise flag.
#[allow(clippy::manual_async_fn)]
impl Executor for crate::reconnect::ReconnectingClient {
    fn query<'a>(
        &'a self,
        sql: &'a str,
        params: &'a [&'a dyn SqlParam],
    ) -> impl Future<Output = Result<Vec<Row>, TypedError>> + Send + 'a {
        crate::reconnect::ReconnectingClient::query(self, sql, params)
    }

    fn execute<'a>(
        &'a self,
        sql: &'a str,
        params: &'a [&'a dyn SqlParam],
    ) -> impl Future<Output = Result<u64, TypedError>> + Send + 'a {
        crate::reconnect::ReconnectingClient::execute(self, sql, params)
    }

    /// Wraps `f` in `BEGIN` … `COMMIT`, issuing `ROLLBACK` instead when `f`
    /// fails. All three control statements go to the single handle obtained
    /// from `self.client()` before `BEGIN`.
    fn atomic<'a, T: Send + 'a>(
        &'a self,
        f: impl FnOnce(&'a Self) -> Pin<Box<dyn Future<Output = Result<T, TypedError>> + Send + 'a>>
            + Send
            + 'a,
    ) -> impl Future<Output = Result<T, TypedError>> + Send + 'a {
        async move {
            let conn = self.client();
            conn.simple_query("BEGIN").await?;
            // NOTE(review): `f` receives `self`, so its statements go through
            // `ReconnectingClient::query`/`execute` rather than `conn`. If those
            // can reconnect mid-transaction, later statements would run outside
            // this BEGIN/COMMIT — confirm against the reconnect logic.
            let outcome = f(self).await;
            if outcome.is_ok() {
                conn.simple_query("COMMIT").await?;
            } else if let Err(rb_err) = conn.simple_query("ROLLBACK").await {
                tracing::error!(error = %rb_err, "transaction rollback failed");
            }
            outcome
        }
    }

    fn copy_in<'a>(
        &'a self,
        copy_sql: &'a str,
        data: &'a [u8],
    ) -> impl Future<Output = Result<u64, TypedError>> + Send + 'a {
        async move { self.client().copy_in(copy_sql, data).await }
    }

    fn copy_out<'a>(
        &'a self,
        copy_sql: &'a str,
    ) -> impl Future<Output = Result<Vec<u8>, TypedError>> + Send + 'a {
        async move { self.client().copy_out(copy_sql).await }
    }
}
#[allow(clippy::manual_async_fn)]
impl Executor for crate::pooled::PooledTransaction<'_> {
fn query<'a>(
&'a self,
sql: &'a str,
params: &'a [&'a dyn SqlParam],
) -> impl Future<Output = Result<Vec<Row>, TypedError>> + Send + 'a {
crate::pooled::PooledTransaction::query(self, sql, params)
}
fn execute<'a>(
&'a self,
sql: &'a str,
params: &'a [&'a dyn SqlParam],
) -> impl Future<Output = Result<u64, TypedError>> + Send + 'a {
crate::pooled::PooledTransaction::execute(self, sql, params)
}
fn copy_in<'a>(
&'a self,
copy_sql: &'a str,
data: &'a [u8],
) -> impl Future<Output = Result<u64, TypedError>> + Send + 'a {
async move { self.client().copy_in(copy_sql, data).await }
}
fn copy_out<'a>(
&'a self,
copy_sql: &'a str,
) -> impl Future<Output = Result<Vec<u8>, TypedError>> + Send + 'a {
async move { self.client().copy_out(copy_sql).await }
}
fn atomic<'a, T: Send + 'a>(
&'a self,
f: impl FnOnce(&'a Self) -> Pin<Box<dyn Future<Output = Result<T, TypedError>> + Send + 'a>>
+ Send
+ 'a,
) -> impl Future<Output = Result<T, TypedError>> + Send + 'a {
async move {
let id = SAVEPOINT_COUNTER.fetch_add(1, Ordering::Relaxed);
let sp = format!("resolute_sp_{id}");
self.client()
.simple_query(&format!("SAVEPOINT {sp}"))
.await?;
match f(self).await {
Ok(val) => {
self.client()
.simple_query(&format!("RELEASE SAVEPOINT {sp}"))
.await?;
Ok(val)
}
Err(e) => {
if let Err(rb_err) = self
.client()
.simple_query(&format!("ROLLBACK TO SAVEPOINT {sp}"))
.await
{
tracing::error!(error = %rb_err, savepoint = %sp, "savepoint rollback failed");
}
Err(e)
}
}
}
}
}
impl crate::query::Client {
    /// Runs `f` between `BEGIN` and `COMMIT` on this client, issuing
    /// `ROLLBACK` instead when `f` returns an error.
    ///
    /// Unlike [`Executor::atomic`], the closure `f` itself is not required to
    /// be `Send`.
    ///
    /// # Errors
    /// Returns a `BEGIN` or `COMMIT` failure, or `f`'s own error. A failed
    /// `ROLLBACK` is only logged (at `warn` level), since the original error
    /// is the more useful one to surface.
    pub async fn with_transaction<'a, T: Send + 'a>(
        &'a self,
        f: impl FnOnce(&'a Self) -> Pin<Box<dyn Future<Output = Result<T, TypedError>> + Send + 'a>>,
    ) -> Result<T, TypedError> {
        self.simple_query("BEGIN").await?;
        let outcome = f(self).await;
        if outcome.is_err() {
            if let Err(rollback_err) = self.simple_query("ROLLBACK").await {
                tracing::warn!(
                    error = %rollback_err,
                    "ROLLBACK failed after transaction error; connection may be unhealthy"
                );
            }
        } else {
            self.simple_query("COMMIT").await?;
        }
        outcome
    }
}