use std::cell::RefCell;
use std::sync::Arc;
use tokio::sync::Mutex as TokioMutex;
use crate::durable::StorageTransaction;
use crate::durable::WorkflowStorage;
use crate::error::ClusterError;
/// An owned Postgres transaction guarded by an async mutex, so that a shared
/// `&ActivityTx` can be handed to sqlx as an executor.
pub struct ActivityTx(TokioMutex<sqlx::Transaction<'static, sqlx::Postgres>>);

impl ActivityTx {
    /// Wraps an already-begun transaction.
    pub fn new(tx: sqlx::Transaction<'static, sqlx::Postgres>) -> Self {
        let guarded = TokioMutex::new(tx);
        ActivityTx(guarded)
    }

    /// Consumes the wrapper and hands back the underlying transaction
    /// (e.g. so the caller can commit or roll it back).
    ///
    /// Nothing inside is awaited: owning `self` already guarantees no lock
    /// guard is outstanding. The `async` signature is kept unchanged for callers.
    pub async fn into_inner(self) -> sqlx::Transaction<'static, sqlx::Postgres> {
        let ActivityTx(mutex) = self;
        mutex.into_inner()
    }
}
/// Debug output is just the type name; the wrapped transaction is never printed.
impl std::fmt::Debug for ActivityTx {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A field-less `debug_struct(..).finish()` emits exactly the name, so
        // writing the name directly produces identical output.
        f.write_str("ActivityTx")
    }
}
/// Lets a shared `&ActivityTx` be used wherever sqlx expects an executor.
///
/// Every method below locks the inner `TokioMutex` for the duration of the call
/// and forwards to the wrapped `sqlx::Transaction`. Overlapping calls on the same
/// `ActivityTx` therefore run one after another, never concurrently. Only these
/// four methods are written here; any other `Executor` methods come from the
/// trait's provided implementations.
impl<'c> sqlx::Executor<'c> for &'c ActivityTx {
    type Database = sqlx::Postgres;

    /// Executes `query` and yields its results as a stream.
    ///
    /// NOTE(review): all results are collected into a `Vec` while the lock is
    /// held, and only then re-emitted via `futures::stream::iter`. This is
    /// presumably done so the mutex guard never has to live inside the returned
    /// stream. As a consequence, nothing is yielded until the query has fully
    /// completed, and the whole result set is held in memory at once.
    fn fetch_many<'e, 'q: 'e, E>(
        self,
        query: E,
    ) -> futures::stream::BoxStream<
        'e,
        Result<sqlx::Either<sqlx::postgres::PgQueryResult, sqlx::postgres::PgRow>, sqlx::Error>,
    >
    where
        'c: 'e,
        E: sqlx::Execute<'q, sqlx::Postgres> + 'q,
    {
        use futures::{FutureExt, StreamExt};
        // Build a future that produces a stream, then flatten it into a single stream.
        async move {
            let mut guard = self.0.lock().await;
            // Drain the inner stream completely while the guard is held.
            let results: Vec<_> = (&mut **guard).fetch_many(query).collect().await;
            futures::stream::iter(results)
        }
        .into_stream()
        .flatten()
        .boxed()
    }

    /// Executes `query` on the wrapped transaction and returns at most one row.
    fn fetch_optional<'e, 'q: 'e, E>(
        self,
        query: E,
    ) -> futures::future::BoxFuture<'e, Result<Option<sqlx::postgres::PgRow>, sqlx::Error>>
    where
        'c: 'e,
        E: sqlx::Execute<'q, sqlx::Postgres> + 'q,
    {
        Box::pin(async move {
            let mut guard = self.0.lock().await;
            (&mut **guard).fetch_optional(query).await
        })
    }

    /// Prepares `sql` with the given parameter types on the wrapped transaction.
    fn prepare_with<'e>(
        self,
        sql: sqlx::SqlStr,
        parameters: &'e [<sqlx::Postgres as sqlx::Database>::TypeInfo],
    ) -> futures::future::BoxFuture<
        'e,
        Result<<sqlx::Postgres as sqlx::Database>::Statement, sqlx::Error>,
    >
    where
        'c: 'e,
    {
        Box::pin(async move {
            let mut guard = self.0.lock().await;
            (&mut **guard).prepare_with(sql, parameters).await
        })
    }

    /// Describes `sql` via the wrapped transaction.
    fn describe<'e>(
        self,
        sql: sqlx::SqlStr,
    ) -> futures::future::BoxFuture<'e, Result<sqlx::Describe<sqlx::Postgres>, sqlx::Error>>
    where
        'c: 'e,
    {
        Box::pin(async move {
            let mut guard = self.0.lock().await;
            (&mut **guard).describe(sql).await
        })
    }
}
/// Key/value writes buffered by the running activity, as `(key, bytes)` pairs.
type PendingWrites = Arc<parking_lot::Mutex<Vec<(String, Vec<u8>)>>>;
/// The storage transaction shared between `ActivityScope::run` and any
/// `SqlTransactionHandle`s handed out while the activity body runs.
type SharedTransaction = Arc<TokioMutex<Box<dyn StorageTransaction>>>;
tokio::task_local! {
    // Transaction state for the activity currently executing on this task, if any.
    // Set for the duration of `ActivityScope::run`'s body.
    static ACTIVE_TRANSACTION: RefCell<Option<ActiveTransaction>>;
}
/// Per-activity state stored in `ACTIVE_TRANSACTION` while an activity runs.
struct ActiveTransaction {
    // Writes to persist into `transaction` if the activity succeeds.
    pending_writes: PendingWrites,
    // The storage transaction the activity runs inside.
    transaction: SharedTransaction,
}
/// Entry points for running an activity body inside a storage transaction and
/// for reaching that transaction from within the body.
///
/// [`ActivityScope::run`] installs the transaction in a task-local; the other
/// associated functions read it from there.
pub struct ActivityScope;

impl ActivityScope {
    /// Runs `f` inside a newly begun storage transaction.
    ///
    /// While the future returned by `f` executes, the following functions all
    /// observe this transaction:
    /// - [`ActivityScope::is_active`]
    /// - [`ActivityScope::buffer_write`]
    /// - [`ActivityScope::db`]
    /// - [`ActivityScope::sql_transaction`]
    ///
    /// * If `f` succeeds, every buffered write is saved into the transaction
    ///   (in the order its key was first buffered), and the transaction is
    ///   committed. If a save fails, the transaction is rolled back (best
    ///   effort; the rollback result is ignored) and the save error is returned.
    /// * If `f` fails, the transaction is rolled back (best effort) and `f`'s
    ///   error is returned unchanged.
    ///
    /// # Errors
    ///
    /// Returns errors from `begin_transaction`, `save`, `commit`, or `f`.
    /// Returns `ClusterError::PersistenceError` if, after `f` succeeds,
    /// something (e.g. a [`SqlTransactionHandle`]) still holds a reference to
    /// the transaction.
    #[tracing::instrument(skip(storage, f))]
    pub async fn run<F, Fut, T>(storage: &Arc<dyn WorkflowStorage>, f: F) -> Result<T, ClusterError>
    where
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<T, ClusterError>>,
    {
        let tx = storage.begin_transaction().await?;
        let transaction = Arc::new(TokioMutex::new(tx));
        let pending_writes = Arc::new(parking_lot::Mutex::new(Vec::new()));
        let active = ActiveTransaction {
            pending_writes: pending_writes.clone(),
            transaction: transaction.clone(),
        };

        // `f` is called inside the async block rather than eagerly. This way, any
        // code `f` runs before returning its future also executes within the
        // task-local scope.
        let result = ACTIVE_TRANSACTION
            .scope(RefCell::new(Some(active)), async { f().await })
            .await;

        match result {
            Ok(value) => {
                let writes = std::mem::take(&mut *pending_writes.lock());
                let mut tx = Arc::try_unwrap(transaction)
                    .map_err(|_| ClusterError::PersistenceError {
                        reason: "transaction still in use after activity completed".to_string(),
                        source: None,
                    })?
                    .into_inner();

                // Persist buffered writes. On the first failure, roll back explicitly
                // instead of relying on whatever dropping the boxed transaction does.
                let saved = async {
                    for (key, bytes) in writes.iter() {
                        tx.save(key, bytes).await?;
                    }
                    Ok::<(), ClusterError>(())
                }
                .await;
                if let Err(err) = saved {
                    let _ = tx.rollback().await;
                    return Err(err);
                }

                tx.commit().await?;
                Ok(value)
            }
            Err(e) => {
                // Best effort: if a clone of the transaction is still alive, it
                // cannot be reclaimed here and is left to its remaining holders.
                if let Ok(tx_arc) = Arc::try_unwrap(transaction) {
                    let tx = tx_arc.into_inner();
                    let _ = tx.rollback().await;
                }
                Err(e)
            }
        }
    }

    /// Returns `true` when called from within the body of [`ActivityScope::run`]
    /// on the current task; `false` otherwise.
    pub fn is_active() -> bool {
        ACTIVE_TRANSACTION
            .try_with(|cell| cell.borrow().is_some())
            .unwrap_or(false)
    }

    /// Buffers a write of `value` under `key` for the active activity.
    ///
    /// Writing the same key again replaces the earlier value but keeps the
    /// key's original position in the buffer. The call is silently ignored
    /// when no activity transaction is active.
    pub fn buffer_write(key: String, value: Vec<u8>) {
        let _ = ACTIVE_TRANSACTION.try_with(|cell| {
            if let Some(active) = cell.borrow().as_ref() {
                let mut writes = active.pending_writes.lock();
                if let Some(pos) = writes.iter().position(|(k, _)| k == &key) {
                    writes[pos].1 = value;
                } else {
                    writes.push((key, value));
                }
            }
        });
    }

    /// Returns a handle to the active SQL transaction.
    ///
    /// # Panics
    ///
    /// Panics if [`ActivityScope::sql_transaction`] returns `None`. This
    /// happens when called outside an activity, or when the storage
    /// transaction is not SQL-backed.
    pub async fn db() -> SqlTransactionHandle {
        Self::sql_transaction()
            .await
            .expect("db() requires an active SQL transaction; are you inside an #[activity] with SQL storage?")
    }

    /// Returns a handle to the active transaction if it is a
    /// `SqlJournalTransaction`.
    ///
    /// Returns `None` when no activity transaction is active, or when it has a
    /// different concrete type. Briefly locks the shared transaction to inspect
    /// its type.
    pub async fn sql_transaction() -> Option<SqlTransactionHandle> {
        let transaction = ACTIVE_TRANSACTION
            .try_with(|cell| cell.borrow().as_ref().map(|a| a.transaction.clone()))
            .ok()
            .flatten()?;

        // Scope the guard so the lock is released before the handle is returned.
        let is_sql = {
            let mut guard = transaction.lock().await;
            guard
                .as_any_mut()
                .downcast_ref::<crate::storage::sql_workflow_journal::SqlJournalTransaction>()
                .is_some()
        };

        is_sql.then(|| SqlTransactionHandle { transaction })
    }
}
pub struct SqlTransactionHandle {
transaction: SharedTransaction,
}
impl SqlTransactionHandle {
pub async fn execute(
&self,
query: sqlx::query::Query<'_, sqlx::Postgres, sqlx::postgres::PgArguments>,
) -> Result<sqlx::postgres::PgQueryResult, ClusterError> {
let mut guard = self.transaction.lock().await;
let tx = guard
.as_any_mut()
.downcast_mut::<crate::storage::sql_workflow_journal::SqlJournalTransaction>()
.ok_or_else(|| ClusterError::PersistenceError {
reason: "transaction is not a SQL transaction".to_string(),
source: None,
})?;
tx.execute(query).await
}
pub async fn fetch_one<'q, O>(
&self,
query: sqlx::query::QueryAs<'q, sqlx::Postgres, O, sqlx::postgres::PgArguments>,
) -> Result<O, ClusterError>
where
O: Send + Unpin + for<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow>,
{
let mut guard = self.transaction.lock().await;
let tx = guard
.as_any_mut()
.downcast_mut::<crate::storage::sql_workflow_journal::SqlJournalTransaction>()
.ok_or_else(|| ClusterError::PersistenceError {
reason: "transaction is not a SQL transaction".to_string(),
source: None,
})?;
tx.fetch_one(query).await
}
pub async fn fetch_optional<'q, O>(
&self,
query: sqlx::query::QueryAs<'q, sqlx::Postgres, O, sqlx::postgres::PgArguments>,
) -> Result<Option<O>, ClusterError>
where
O: Send + Unpin + for<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow>,
{
let mut guard = self.transaction.lock().await;
let tx = guard
.as_any_mut()
.downcast_mut::<crate::storage::sql_workflow_journal::SqlJournalTransaction>()
.ok_or_else(|| ClusterError::PersistenceError {
reason: "transaction is not a SQL transaction".to_string(),
source: None,
})?;
tx.fetch_optional(query).await
}
pub async fn fetch_all<'q, O>(
&self,
query: sqlx::query::QueryAs<'q, sqlx::Postgres, O, sqlx::postgres::PgArguments>,
) -> Result<Vec<O>, ClusterError>
where
O: Send + Unpin + for<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow>,
{
let mut guard = self.transaction.lock().await;
let tx = guard
.as_any_mut()
.downcast_mut::<crate::storage::sql_workflow_journal::SqlJournalTransaction>()
.ok_or_else(|| ClusterError::PersistenceError {
reason: "transaction is not a SQL transaction".to_string(),
source: None,
})?;
tx.fetch_all(query).await
}
}