use crate::entity_client::persisted_request_id;
use crate::error::ClusterError;
use crate::message_storage::MessageStorage;
use crate::reply::ExitResult;
use crate::types::{EntityId, EntityType};
use async_trait::async_trait;
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::any::Any;
use std::future::Future;
use std::marker::PhantomData;
use std::sync::Arc;
use std::time::Duration;
/// Well-known name of the workflow interrupt signal (`"Workflow/InterruptSignal"`).
///
/// NOTE(review): nothing in this module reads this constant; presumably engine
/// implementations use it as the signal/deferred name behind
/// `WorkflowEngine::on_interrupt` — confirm against those implementations.
pub const INTERRUPT_SIGNAL: &str = "Workflow/InterruptSignal";
tokio::task_local! {
    // Request id of the workflow execution running on the current task. Installed by
    // `WorkflowScope::run` and read back through `WorkflowScope::current`.
    static WORKFLOW_REQUEST_ID: i64;
    // Journal storage keys touched while a `WorkflowScope::run` is active. Appended to by
    // `WorkflowScope::register_journal_key` and drained when `run` finishes. Wrapped in a
    // `RefCell` because task-locals are only handed out by shared reference.
    static WORKFLOW_JOURNAL_KEYS: std::cell::RefCell<Vec<String>>;
}
/// Task-local context for the workflow execution currently being driven.
///
/// Code running inside [`WorkflowScope::run`] can look up the workflow request id and
/// record the journal keys it reads or writes.
pub struct WorkflowScope;

impl WorkflowScope {
    /// Drives the future produced by `f` with `request_id` installed as the current
    /// workflow request id and an empty journal-key collector.
    ///
    /// Returns the future's output together with every key handed to
    /// [`WorkflowScope::register_journal_key`] while it ran, in registration order.
    pub async fn run<F, Fut, T>(request_id: i64, f: F) -> (T, Vec<String>)
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = T>,
    {
        // Lazily evaluated: `f` is only invoked once both scopes below are in place.
        let body = async {
            let output = f().await;
            let collected =
                WORKFLOW_JOURNAL_KEYS.with(|cell| std::mem::take(&mut *cell.borrow_mut()));
            (output, collected)
        };
        let with_collector =
            WORKFLOW_JOURNAL_KEYS.scope(std::cell::RefCell::new(Vec::new()), body);
        WORKFLOW_REQUEST_ID.scope(request_id, with_collector).await
    }

    /// Returns the request id of the enclosing [`WorkflowScope::run`], or `None` when
    /// called outside of one.
    pub fn current() -> Option<i64> {
        match WORKFLOW_REQUEST_ID.try_with(|id| *id) {
            Ok(id) => Some(id),
            Err(_) => None,
        }
    }

    /// Records `key` as touched by the enclosing [`WorkflowScope::run`].
    ///
    /// Best effort: outside a scope there is no collector, so the key is silently dropped.
    pub fn register_journal_key(key: String) {
        let _ = WORKFLOW_JOURNAL_KEYS.try_with(move |keys| keys.borrow_mut().push(key));
    }
}
/// Asynchronous key/value store backing workflow state such as journal entries.
///
/// Implementations must be shareable across tasks (`Send + Sync`).
#[async_trait]
pub trait WorkflowStorage: Send + Sync {
    /// Loads the bytes stored under `key`; `Ok(None)` when nothing is stored there.
    async fn load(&self, key: &str) -> Result<Option<Vec<u8>>, ClusterError>;
    /// Persists `value` under `key`.
    ///
    /// NOTE(review): presumably overwrites an existing value — confirm with implementations.
    async fn save(&self, key: &str, value: &[u8]) -> Result<(), ClusterError>;
    /// Deletes the value stored under `key`.
    async fn delete(&self, key: &str) -> Result<(), ClusterError>;
    /// Lists stored keys for `prefix`.
    ///
    /// NOTE(review): presumably a plain string-prefix match — confirm.
    async fn list_keys(&self, prefix: &str) -> Result<Vec<String>, ClusterError>;
    /// Flags `key` as completed.
    ///
    /// NOTE(review): semantics are implementation-defined; presumably completed entries
    /// become eligible for [`WorkflowStorage::cleanup`] — confirm.
    async fn mark_completed(&self, key: &str) -> Result<(), ClusterError>;
    /// Purges entries older than `older_than`.
    ///
    /// NOTE(review): the returned `u64` is presumably the number of entries removed — confirm.
    async fn cleanup(&self, older_than: Duration) -> Result<u64, ClusterError>;
    /// Opens a transaction over this storage.
    ///
    /// The default implementation is NOT atomic. It wraps [`WorkflowStorage::as_arc`] in a
    /// pass-through transaction: `save`/`delete` write straight to this storage, and
    /// `commit`/`rollback` do nothing, so a rollback does not undo earlier writes.
    /// Override this method for real transactional semantics.
    ///
    /// # Panics
    ///
    /// The default implementation panics if `as_arc` has not been overridden.
    async fn begin_transaction(&self) -> Result<Box<dyn StorageTransaction>, ClusterError> {
        Ok(Box::new(NoopTransaction {
            storage: self.as_arc(),
        }))
    }
    /// Returns an owned, shared handle to this storage.
    ///
    /// Within this module it is only called by the default
    /// [`WorkflowStorage::begin_transaction`].
    ///
    /// # Panics
    ///
    /// The default implementation always panics. Implementors relying on the default
    /// `begin_transaction` must override it.
    fn as_arc(&self) -> Arc<dyn WorkflowStorage> {
        panic!("WorkflowStorage::as_arc() must be implemented for default begin_transaction()")
    }
    /// Returns the Postgres connection pool behind this storage, if it has one.
    ///
    /// Defaults to `None`.
    fn sql_pool(&self) -> Option<&sqlx::PgPool> {
        None
    }
}
/// Unit of work over a [`WorkflowStorage`], finished by `commit` or `rollback`.
///
/// Atomicity depends on the implementation. The one returned by the default
/// `WorkflowStorage::begin_transaction` provides none.
#[async_trait]
pub trait StorageTransaction: Send + Sync {
    /// Writes `value` under `key` as part of this transaction.
    async fn save(&mut self, key: &str, value: &[u8]) -> Result<(), ClusterError>;
    /// Deletes `key` as part of this transaction.
    async fn delete(&mut self, key: &str) -> Result<(), ClusterError>;
    /// Finishes the transaction, consuming it (expected to make its writes durable).
    async fn commit(self: Box<Self>) -> Result<(), ClusterError>;
    /// Abandons the transaction, consuming it.
    async fn rollback(self: Box<Self>) -> Result<(), ClusterError>;
    /// Exposes the concrete transaction so callers can downcast to a backend-specific type.
    fn as_any_mut(&mut self) -> &mut dyn Any;
}
/// Pass-through [`StorageTransaction`] handed out by the default
/// `WorkflowStorage::begin_transaction`.
///
/// It gives no atomicity. Every write reaches the wrapped storage immediately, and both
/// `commit` and `rollback` are no-ops (so a rollback cannot undo earlier writes).
struct NoopTransaction {
    /// Storage that every operation is forwarded to.
    storage: Arc<dyn WorkflowStorage>,
}

#[async_trait]
impl StorageTransaction for NoopTransaction {
    async fn save(&mut self, key: &str, value: &[u8]) -> Result<(), ClusterError> {
        // Applied immediately; nothing is staged for commit.
        let target = &self.storage;
        target.save(key, value).await
    }

    async fn delete(&mut self, key: &str) -> Result<(), ClusterError> {
        let target = &self.storage;
        target.delete(key).await
    }

    async fn commit(self: Box<Self>) -> Result<(), ClusterError> {
        // Writes were already applied by `save`/`delete`; there is nothing left to flush.
        Ok(())
    }

    async fn rollback(self: Box<Self>) -> Result<(), ClusterError> {
        // Nothing was staged, so there is nothing to discard.
        Ok(())
    }

    fn as_any_mut(&mut self) -> &mut dyn Any {
        self as &mut dyn Any
    }
}
/// Typed name of a deferred value.
///
/// `T` is the type of the value awaited or resolved under this name. It exists only at
/// the type level (no `T` is ever stored), so keys can be declared as `const`s and passed
/// by value or by reference to `DurableContext::await_deferred` /
/// `DurableContext::resolve_deferred`.
pub struct DeferredKey<T> {
    /// Name under which the deferred value is tracked by the workflow engine.
    pub name: &'static str,
    // `fn() -> T` keeps the key covariant in `T` (as `PhantomData<T>` did). It also makes
    // the key `Send + Sync + Unpin` regardless of `T`: the key never owns or shares a `T`.
    _marker: PhantomData<fn() -> T>,
}

impl<T> DeferredKey<T> {
    /// Creates a key for the deferred named `name`.
    pub const fn new(name: &'static str) -> Self {
        Self {
            name,
            _marker: PhantomData,
        }
    }
}

// The standard traits are written by hand instead of derived. `#[derive]` would add
// `T: Clone`, `T: Copy`, `T: Debug`, ... bounds even though no `T` is stored, which made
// e.g. `DeferredKey<SomeNonCloneType>` neither `Copy` nor `Clone`. Only `name` matters.
impl<T> Clone for DeferredKey<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for DeferredKey<T> {}

impl<T> PartialEq for DeferredKey<T> {
    fn eq(&self, other: &Self) -> bool {
        self.name == other.name
    }
}

impl<T> Eq for DeferredKey<T> {}

impl<T> std::hash::Hash for DeferredKey<T> {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        // Same output as the former derived impl, whose `PhantomData` field hashed nothing.
        std::hash::Hash::hash(&self.name, state);
    }
}

impl<T> std::fmt::Debug for DeferredKey<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DeferredKey")
            .field("name", &self.name)
            .field("value_type", &std::any::type_name::<T>())
            .finish()
    }
}

/// Anything that can name a deferred value of type `T`.
///
/// Implemented for typed [`DeferredKey`]s (by value or by reference) and for plain
/// string names (`&str`, `String`, `&String`), which accept any `T`.
pub trait DeferredKeyLike<T> {
    /// Returns the deferred's name.
    fn name(&self) -> &str;
}

impl<T> DeferredKeyLike<T> for DeferredKey<T> {
    fn name(&self) -> &str {
        self.name
    }
}

impl<T> DeferredKeyLike<T> for &DeferredKey<T> {
    fn name(&self) -> &str {
        self.name
    }
}

impl<T> DeferredKeyLike<T> for &str {
    fn name(&self) -> &str {
        self
    }
}

impl<T> DeferredKeyLike<T> for String {
    fn name(&self) -> &str {
        self.as_str()
    }
}

impl<T> DeferredKeyLike<T> for &String {
    fn name(&self) -> &str {
        self.as_str()
    }
}
/// Per-execution handle a workflow body uses to sleep, await/resolve deferred values,
/// react to interrupts, and journal the results of steps run through `run`.
pub struct DurableContext {
    /// Engine implementing the durable primitives (`sleep`, deferreds, interrupts).
    engine: Arc<dyn WorkflowEngine>,
    /// Workflow name; `entity_type` is derived from it.
    workflow_name: String,
    /// Execution id; `entity_id` is derived from it.
    execution_id: String,
    /// Its presence is what switches journaling on in `run` / `has_journal`.
    message_storage: Option<Arc<dyn MessageStorage>>,
    /// Where journal entries live: read by `check_journal`, written by `run` (unless an
    /// active `ActivityScope` buffers the write instead).
    workflow_storage: Option<Arc<dyn WorkflowStorage>>,
    /// Entity type derived from `workflow_name`; an input to journal keys.
    entity_type: EntityType,
    /// Entity id derived from `execution_id`; an input to journal keys.
    entity_id: EntityId,
}
impl DurableContext {
    /// Builds a context with journaling disabled.
    ///
    /// The entity type and id used for journal keys are derived from `workflow_name`
    /// and `execution_id` respectively.
    pub fn new(
        engine: Arc<dyn WorkflowEngine>,
        workflow_name: impl Into<String>,
        execution_id: impl Into<String>,
    ) -> Self {
        let workflow_name: String = workflow_name.into();
        let execution_id: String = execution_id.into();
        let entity_type = EntityType::new(&workflow_name);
        let entity_id = EntityId::new(&execution_id);
        Self {
            engine,
            workflow_name,
            execution_id,
            message_storage: None,
            workflow_storage: None,
            entity_type,
            entity_id,
        }
    }

    /// Builds a context with journaling enabled: step results produced by
    /// [`DurableContext::run`] are persisted to `workflow_storage` and replayed from it.
    pub fn with_journal_storage(
        engine: Arc<dyn WorkflowEngine>,
        workflow_name: impl Into<String>,
        execution_id: impl Into<String>,
        message_storage: Arc<dyn MessageStorage>,
        workflow_storage: Arc<dyn WorkflowStorage>,
    ) -> Self {
        Self {
            message_storage: Some(message_storage),
            workflow_storage: Some(workflow_storage),
            ..Self::new(engine, workflow_name, execution_id)
        }
    }

    /// Durably sleeps for `duration` under the sleep name `name` (delegates to the engine).
    #[tracing::instrument(level = "debug", skip(self))]
    pub async fn sleep(&self, name: &str, duration: Duration) -> Result<(), ClusterError> {
        let (workflow, execution) = (self.workflow_name.as_str(), self.execution_id.as_str());
        self.engine.sleep(workflow, execution, name, duration).await
    }

    /// Waits for the deferred named by `key` and decodes its MessagePack payload as `T`.
    ///
    /// # Errors
    ///
    /// Engine errors are propagated; an undecodable payload yields
    /// `ClusterError::PersistenceError`.
    #[tracing::instrument(level = "debug", skip(self, key))]
    pub async fn await_deferred<T, K>(&self, key: K) -> Result<T, ClusterError>
    where
        T: Serialize + DeserializeOwned,
        K: DeferredKeyLike<T>,
    {
        let name = key.name().to_string();
        let payload = self
            .engine
            .await_deferred(&self.workflow_name, &self.execution_id, &name)
            .await?;
        match rmp_serde::from_slice::<T>(&payload) {
            Ok(value) => Ok(value),
            Err(err) => Err(ClusterError::PersistenceError {
                reason: format!("failed to deserialize deferred '{name}': {err}"),
                source: Some(Box::new(err)),
            }),
        }
    }

    /// Encodes `value` as MessagePack and resolves the deferred named by `key` with it.
    ///
    /// # Errors
    ///
    /// `ClusterError::PersistenceError` if `value` cannot be encoded; engine errors are
    /// propagated.
    #[tracing::instrument(level = "debug", skip(self, key, value))]
    pub async fn resolve_deferred<T, K>(&self, key: K, value: &T) -> Result<(), ClusterError>
    where
        T: Serialize,
        K: DeferredKeyLike<T>,
    {
        let name = key.name().to_string();
        let encoded = match rmp_serde::to_vec(value) {
            Ok(bytes) => bytes,
            Err(err) => {
                return Err(ClusterError::PersistenceError {
                    reason: format!("failed to serialize deferred value: {err}"),
                    source: Some(Box::new(err)),
                });
            }
        };
        self.engine
            .resolve_deferred(&self.workflow_name, &self.execution_id, &name, encoded)
            .await
    }

    /// Delegates to the engine's interrupt hook for this execution.
    #[tracing::instrument(level = "debug", skip(self))]
    pub async fn on_interrupt(&self) -> Result<(), ClusterError> {
        let engine = &self.engine;
        engine
            .on_interrupt(&self.workflow_name, &self.execution_id)
            .await
    }

    /// Looks up a journaled result for the step identified by `name` and `key_bytes`.
    ///
    /// Returns `Ok(None)` when no workflow storage is configured or nothing is journaled.
    /// On a hit, the key is registered with the enclosing `WorkflowScope` (if any) before
    /// the entry is decoded.
    ///
    /// # Errors
    ///
    /// Storage errors are propagated. A journaled failure, or an entry that cannot be
    /// decoded, yields `ClusterError::PersistenceError`.
    #[tracing::instrument(level = "debug", skip(self, key_bytes))]
    pub async fn check_journal<T: DeserializeOwned>(
        &self,
        name: &str,
        key_bytes: &[u8],
    ) -> Result<Option<T>, ClusterError> {
        if let Some(storage) = &self.workflow_storage {
            let storage_key =
                Self::journal_storage_key(name, key_bytes, &self.entity_type, &self.entity_id);
            match storage.load(&storage_key).await? {
                Some(bytes) => {
                    WorkflowScope::register_journal_key(storage_key);
                    Self::deserialize_journal_result::<T>(&bytes).map(Some)
                }
                None => Ok(None),
            }
        } else {
            Ok(None)
        }
    }

    /// Derives the storage key (`"__journal/<request id>"`) for a journaled step.
    ///
    /// The request id comes from `persisted_request_id`, fed with the entity, the tag
    /// `"__journal/<name>"` and `key_bytes`.
    pub fn journal_storage_key(
        name: &str,
        key_bytes: &[u8],
        entity_type: &EntityType,
        entity_id: &EntityId,
    ) -> String {
        let request_id = persisted_request_id(
            entity_type,
            entity_id,
            &format!("__journal/{name}"),
            key_bytes,
        );
        format!("__journal/{}", request_id.0)
    }

    /// Encodes a step outcome as a MessagePack `ExitResult`.
    ///
    /// Successful values are MessagePack-encoded into `ExitResult::Success`. Errors are
    /// reduced to their `Display` text in `ExitResult::Failure`.
    pub fn serialize_journal_result<T: Serialize>(
        result: &Result<T, ClusterError>,
    ) -> Result<Vec<u8>, ClusterError> {
        let exit = match result {
            Err(failure) => ExitResult::Failure(failure.to_string()),
            Ok(value) => ExitResult::Success(rmp_serde::to_vec(value).map_err(|err| {
                ClusterError::PersistenceError {
                    reason: format!("failed to serialize journal result: {err}"),
                    source: Some(Box::new(err)),
                }
            })?),
        };
        match rmp_serde::to_vec(&exit) {
            Ok(bytes) => Ok(bytes),
            Err(err) => Err(ClusterError::PersistenceError {
                reason: format!("failed to serialize journal exit: {err}"),
                source: Some(Box::new(err)),
            }),
        }
    }

    /// Decodes bytes produced by [`DurableContext::serialize_journal_result`].
    ///
    /// # Errors
    ///
    /// A journaled `ExitResult::Failure`, or bytes that cannot be decoded, yield
    /// `ClusterError::PersistenceError`.
    pub fn deserialize_journal_result<T: DeserializeOwned>(
        bytes: &[u8],
    ) -> Result<T, ClusterError> {
        let exit = rmp_serde::from_slice::<ExitResult>(bytes).map_err(|err| {
            ClusterError::PersistenceError {
                reason: format!("failed to deserialize journal exit: {err}"),
                source: Some(Box::new(err)),
            }
        })?;
        match exit {
            ExitResult::Failure(msg) => Err(ClusterError::PersistenceError {
                reason: format!("cached journal result was a failure: {msg}"),
                source: None,
            }),
            ExitResult::Success(data) => rmp_serde::from_slice::<T>(&data).map_err(|err| {
                ClusterError::PersistenceError {
                    reason: format!("failed to deserialize cached journal result: {err}"),
                    source: Some(Box::new(err)),
                }
            }),
        }
    }

    /// Whether journaling is enabled, i.e. a message storage was supplied.
    pub fn has_journal(&self) -> bool {
        matches!(self.message_storage, Some(_))
    }

    /// Entity type derived from the workflow name.
    pub fn entity_type(&self) -> &EntityType {
        &self.entity_type
    }

    /// Entity id derived from the execution id.
    pub fn entity_id(&self) -> &EntityId {
        &self.entity_id
    }

    /// Runs the step identified by `name` and `key_bytes` at most once per journal.
    ///
    /// Behavior:
    /// - If a journaled result exists, it is returned and `f` is not called.
    /// - Otherwise `f` runs. When journaling is enabled, its outcome (success or failure)
    ///   is encoded and its key is registered with the enclosing `WorkflowScope`.
    /// - The encoded entry is buffered via `ActivityScope` when one is active, or else
    ///   saved to the workflow storage.
    ///
    /// NOTE(review): a journaled failure is replayed as `ClusterError::PersistenceError`
    /// (see `deserialize_journal_result`) rather than re-running `f` — confirm this is
    /// intended for transient errors.
    #[tracing::instrument(skip(self, key_bytes, f))]
    pub async fn run<T, F, Fut>(
        &self,
        name: &str,
        key_bytes: &[u8],
        f: F,
    ) -> Result<T, ClusterError>
    where
        T: Serialize + DeserializeOwned,
        F: FnOnce() -> Fut,
        Fut: Future<Output = Result<T, ClusterError>>,
    {
        // Replay: a journaled outcome short-circuits the step entirely.
        if let Some(cached) = self.check_journal::<T>(name, key_bytes).await? {
            return Ok(cached);
        }
        // Journaling disabled: just execute the step.
        if !self.has_journal() {
            return f().await;
        }
        let outcome = f().await;
        let storage_key =
            Self::journal_storage_key(name, key_bytes, &self.entity_type, &self.entity_id);
        let encoded = Self::serialize_journal_result(&outcome)?;
        WorkflowScope::register_journal_key(storage_key.clone());
        if crate::state_guard::ActivityScope::is_active() {
            // Inside an activity, the write is deferred to the activity's buffered writes.
            crate::state_guard::ActivityScope::buffer_write(storage_key, encoded);
        } else if let Some(storage) = self.workflow_storage.as_ref() {
            storage.save(&storage_key, &encoded).await?;
        }
        outcome
    }
}
/// Upper bound, in seconds, that [`compute_retry_backoff`] applies to exponential delays.
pub const MAX_RETRY_BACKOFF_SECS: u64 = 60;

/// Computes the delay before retry number `attempt`.
///
/// Strategies:
/// - `"constant"`: always `base_secs`, uncapped.
/// - anything else: exponential, `base_secs * 2^attempt`. This saturates instead of
///   overflowing and is capped at [`MAX_RETRY_BACKOFF_SECS`].
///
/// Equivalent to [`compute_retry_backoff_capped`] with `max_secs = MAX_RETRY_BACKOFF_SECS`.
pub fn compute_retry_backoff(attempt: u32, backoff_strategy: &str, base_secs: u64) -> Duration {
    compute_retry_backoff_capped(attempt, backoff_strategy, base_secs, MAX_RETRY_BACKOFF_SECS)
}

/// Like [`compute_retry_backoff`], but with a caller-chosen cap (`max_secs`) for the
/// exponential strategy. The `"constant"` strategy is never capped.
pub fn compute_retry_backoff_capped(
    attempt: u32,
    backoff_strategy: &str,
    base_secs: u64,
    max_secs: u64,
) -> Duration {
    let secs = match backoff_strategy {
        "constant" => base_secs,
        _ => {
            // 2^attempt. `checked_shl` yields None once attempt >= 64, where the true
            // factor no longer fits in a u64, so saturate instead of wrapping.
            let factor = 1u64.checked_shl(attempt).unwrap_or(u64::MAX);
            base_secs.saturating_mul(factor).min(max_secs)
        }
    };
    Duration::from_secs(secs)
}
/// Backend implementing the durable primitives exposed through `DurableContext`.
///
/// Every method identifies the execution by `workflow_name` plus `execution_id`.
#[async_trait]
pub trait WorkflowEngine: Send + Sync {
    /// Durable sleep named `name` lasting `duration`.
    ///
    /// NOTE(review): replay/resume semantics of a named sleep are implementation-defined —
    /// confirm with the engine implementations.
    async fn sleep(
        &self,
        workflow_name: &str,
        execution_id: &str,
        name: &str,
        duration: Duration,
    ) -> Result<(), ClusterError>;
    /// Waits for the deferred `name` and returns its raw payload.
    ///
    /// `DurableContext::await_deferred` decodes the payload as MessagePack.
    async fn await_deferred(
        &self,
        workflow_name: &str,
        execution_id: &str,
        name: &str,
    ) -> Result<Vec<u8>, ClusterError>;
    /// Resolves the deferred `name` with the raw payload `value`.
    ///
    /// `DurableContext::resolve_deferred` passes MessagePack-encoded bytes.
    async fn resolve_deferred(
        &self,
        workflow_name: &str,
        execution_id: &str,
        name: &str,
        value: Vec<u8>,
    ) -> Result<(), ClusterError>;
    /// Interrupt hook for this execution.
    ///
    /// NOTE(review): presumably waits for or checks the interrupt signal
    /// (`INTERRUPT_SIGNAL`) — confirm with the engine implementations.
    async fn on_interrupt(
        &self,
        workflow_name: &str,
        execution_id: &str,
    ) -> Result<(), ClusterError>;
}