use std::time::Duration;
use sqlx::{postgres::PgRow, PgPool};
use super::executor::{IsRetryable, RetryConfig};
use crate::core::condition::SqlValue;
use crate::core::model::Model;
use crate::core::query::QueryBuilder;
use crate::core::sqlx::pg as sqlx_pg;
pub struct Tx<'c> {
inner: sqlx::Transaction<'c, sqlx::Postgres>,
}
impl<'c> Tx<'c> {
pub async fn begin(pool: &'c PgPool) -> Result<Self, sqlx::Error> {
Ok(Self {
inner: pool.begin().await?,
})
}
pub async fn commit(self) -> Result<(), sqlx::Error> {
self.inner.commit().await
}
pub async fn commit_with_retry(self, _config: &RetryConfig) -> Result<(), sqlx::Error> {
self.inner.commit().await
}
pub async fn run_with_retry<F, Fut, T>(
pool: &'c PgPool,
config: &RetryConfig,
mut f: F,
) -> Result<T, sqlx::Error>
where
F: FnMut(&mut Tx<'c>) -> Fut,
Fut: std::future::Future<Output = Result<T, sqlx::Error>>,
{
let mut attempt = 0u32;
loop {
attempt += 1;
let mut tx = Tx::begin(pool).await?;
let result = f(&mut tx).await;
match result {
Ok(val) => return tx.commit().await.map(|_| val),
Err(e) if e.is_retryable() && attempt < config.max_attempts => {
let delay = config
.base_delay_ms
.saturating_mul(2u64.pow(attempt.saturating_sub(1)))
.min(config.max_delay_ms);
tokio::time::sleep(Duration::from_millis(delay)).await;
}
Err(e) => return Err(e),
}
}
}
pub async fn rollback(self) -> Result<(), sqlx::Error> {
self.inner.rollback().await
}
pub async fn fetch_all<T>(&mut self, builder: QueryBuilder<T>) -> Result<Vec<T>, sqlx::Error>
where
T: Model + for<'r> sqlx::FromRow<'r, PgRow> + Send + Unpin,
{
let (sql, params) = builder.to_sql();
sqlx_pg::build_query_as::<T>(&sql, params)
.fetch_all(&mut *self.inner)
.await
}
pub async fn fetch_optional<T>(
&mut self,
builder: QueryBuilder<T>,
) -> Result<Option<T>, sqlx::Error>
where
T: Model + for<'r> sqlx::FromRow<'r, PgRow> + Send + Unpin,
{
let (sql, params) = builder.to_sql();
sqlx_pg::build_query_as::<T>(&sql, params)
.fetch_optional(&mut *self.inner)
.await
}
pub async fn count<T>(&mut self, builder: QueryBuilder<T>) -> Result<i64, sqlx::Error> {
let (sql, params) = builder.to_count_sql();
let row = sqlx_pg::build_query(&sql, params)
.fetch_one(&mut *self.inner)
.await?;
use sqlx::Row;
row.try_get::<i64, _>(0)
}
pub async fn execute_raw(
&mut self,
sql: &str,
params: Vec<SqlValue>,
) -> Result<u64, sqlx::Error> {
let result = sqlx_pg::build_query(sql, params)
.execute(&mut *self.inner)
.await?;
Ok(result.rows_affected())
}
pub async fn insert<T>(
&mut self,
table: &str,
data: &[(&str, SqlValue)],
) -> Result<u64, sqlx::Error> {
let (sql, params) = QueryBuilder::<T>::insert_sql(table, data);
self.execute_raw(&sql, params).await
}
pub async fn insert_returning<T>(
&mut self,
table: &str,
data: &[(&str, SqlValue)],
) -> Result<T, sqlx::Error>
where
T: Model + for<'r> sqlx::FromRow<'r, PgRow> + Send + Unpin,
{
let (base_sql, params) = QueryBuilder::<T>::insert_sql(table, data);
let sql = format!("{base_sql} RETURNING *");
sqlx_pg::build_query_as::<T>(&sql, params)
.fetch_optional(&mut *self.inner)
.await?
.ok_or(sqlx::Error::RowNotFound)
}
pub async fn bulk_insert<T>(
&mut self,
table: &str,
rows: &[Vec<(&str, SqlValue)>],
) -> Result<u64, sqlx::Error> {
if rows.is_empty() {
return Ok(0);
}
let (sql, params) = QueryBuilder::<T>::bulk_insert_sql(table, rows);
self.execute_raw(&sql, params).await
}
pub async fn update<T>(
&mut self,
builder: QueryBuilder<T>,
data: &[(&str, SqlValue)],
) -> Result<u64, sqlx::Error> {
let (sql, params) = builder.to_update_sql(data);
self.execute_raw(&sql, params).await
}
pub async fn delete<T>(&mut self, builder: QueryBuilder<T>) -> Result<u64, sqlx::Error> {
let (sql, params) = builder.to_delete_sql();
self.execute_raw(&sql, params).await
}
}