use std::fmt;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use redb::{Database, ReadableDatabase as _, TableDefinition};
use serde::Serialize;
use serde::de::DeserializeOwned;
use tokio::sync::watch;
use tracing::{info, info_span};
use crate::engine::WorkflowState;
use crate::error::{EngineError, StepError};
use crate::retry::RetryPolicy;
/// A named suspension point whose payload type `T` is fixed at compile time.
///
/// Only the key is stored; `T` is tracked through [`PhantomData`] so that a
/// suspend built from this point yields a value of the matching type.
pub struct SuspendPoint<T> {
    /// Key identifying this suspend point.
    key: &'static str,
    /// Records `T` at the type level without storing a value of it.
    _marker: PhantomData<T>,
}

impl<T> SuspendPoint<T> {
    /// Builds a suspend point for `key`, validating the key on the way.
    ///
    /// `'/'` is rejected because step keys are joined into `/`-separated
    /// composite keys, and a leading `'_'` is rejected because such keys are
    /// reserved. Being a `const fn`, an invalid key used in a `const` item is
    /// reported at compile time.
    ///
    /// # Panics
    ///
    /// Panics if `key` contains `'/'` or starts with `'_'`.
    #[must_use]
    pub const fn new(key: &'static str) -> Self {
        // Iterator adaptors are unavailable in `const fn`, so peel the byte
        // slice one element at a time instead.
        let mut rest = key.as_bytes();
        while let [head, tail @ ..] = rest {
            assert!(*head != b'/', "suspend point key must not contain '/'");
            rest = tail;
        }
        assert!(
            !matches!(key.as_bytes(), [b'_', ..]),
            "suspend point key must not start with '_'"
        );
        Self {
            _marker: PhantomData,
            key,
        }
    }

    /// Returns the key this suspend point was created with.
    #[must_use]
    pub const fn key(&self) -> &'static str {
        self.key
    }
}
/// redb table holding persisted step records as raw bytes, keyed by string.
///
/// Code in this file uses composite keys of the form
/// `"{workflow_name}/{instance_id}/{step_key}"` and stores the output of
/// `serialize_step` as the value.
pub(crate) const STEPS: TableDefinition<&str, &[u8]> = TableDefinition::new("steps");
/// redb table holding pending timers as raw bytes, keyed by a `(u64, u64)` pair.
///
/// `Context::timer` inserts entries under `(deadline, serial)`, where the
/// deadline is in milliseconds since the Unix epoch, with a postcard-encoded
/// `TimerEntry` as the value.
pub(crate) const TIMERS: TableDefinition<(u64, u64), &[u8]> = TableDefinition::new("timers");
/// Payload stored in the `TIMERS` table for one pending timer.
///
/// `Context::timer` fills it with the owning workflow's name, the instance
/// id, and the key the timer was registered under.
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
pub(crate) struct TimerEntry {
/// Name of the workflow that set the timer.
pub workflow_name: String,
/// Id of the workflow instance that set the timer.
pub instance_id: String,
/// Step key the timer was registered under.
pub step_key: String,
}
/// Per-step retry choice recorded by `StepBuilder::retry` and
/// `StepBuilder::no_retry`.
///
/// When set, it takes precedence over the context's default retry policy
/// (see `StepBuilder::effective_retry`).
#[derive(Debug, Clone)]
enum RetryOverride {
/// Retries are switched off for this step, regardless of the default policy.
Disabled,
/// This step uses the given policy instead of the default one.
Custom(RetryPolicy),
}
/// Persisted record of a single step, generic over the step's result type `T`.
#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub(crate) enum StepData<T> {
/// The step finished. `result` is its value; `status` is the in-progress
/// status string captured when the step completed (if any), which is
/// re-applied to the workflow state when the record is replayed.
Completed { result: T, status: Option<String> },
/// The workflow is parked at this key. Written by `Context::timer` and
/// `Context::execute_suspend`.
Suspended,
/// The step failed; `error` is the display text of the failure.
Failed { error: String },
}
/// Payload-free discriminant mirroring the variants of `StepData`.
///
/// It is stored in `StepEnvelope::state`, so the state of a record is
/// available from `deserialize_envelope` without naming the result type.
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum StepState {
/// Corresponds to `StepData::Suspended`.
Suspended,
/// Corresponds to `StepData::Completed`.
Completed,
/// Corresponds to `StepData::Failed`.
Failed,
}
/// Outer wrapper around a stored step record.
///
/// Produced by `serialize_step`: the envelope is what is actually
/// postcard-encoded and written to the `STEPS` table.
#[derive(serde::Serialize, serde::Deserialize)]
pub(crate) struct StepEnvelope {
/// Which `StepData` variant the wrapped record is.
pub state: StepState,
/// `std::any::type_name` of the result type for completed records, `None`
/// for suspended and failed ones. Checked by `deserialize_step`.
pub type_tag: Option<String>,
/// The postcard-encoded `StepData<T>` itself.
data: Vec<u8>,
}
/// Serializes a step record into its stored form.
///
/// The record is postcard-encoded and wrapped in a [`StepEnvelope`] carrying
/// the record's [`StepState`] and, for completed records only, the type name
/// of `T`. The envelope is then postcard-encoded in turn.
///
/// `key` is used solely to label errors.
///
/// # Errors
///
/// Returns [`EngineError::Serialization`] if either encoding pass fails.
pub(crate) fn serialize_step<T: Serialize>(
    data: &StepData<T>,
    key: &str,
) -> Result<Vec<u8>, EngineError> {
    let state = match data {
        StepData::Completed { .. } => StepState::Completed,
        StepData::Suspended => StepState::Suspended,
        StepData::Failed { .. } => StepState::Failed,
    };
    // Only a completed record holds a `T`, so only it gets a type tag.
    let type_tag =
        matches!(state, StepState::Completed).then(|| std::any::type_name::<T>().to_string());

    let payload = postcard::to_allocvec(data).map_err(|err| EngineError::Serialization {
        key: key.to_owned(),
        source: Box::new(err),
    })?;
    let envelope = StepEnvelope {
        state,
        type_tag,
        data: payload,
    };
    postcard::to_allocvec(&envelope).map_err(|err| EngineError::Serialization {
        key: key.to_owned(),
        source: Box::new(err),
    })
}
/// Decodes a stored step record back into a `StepData<T>`.
///
/// The outer [`StepEnvelope`] is decoded first. If it carries a type tag, the
/// tag must equal `std::any::type_name::<T>()`; records without a tag skip
/// that check. The inner payload is decoded last.
///
/// `key` is used solely to label errors.
///
/// # Errors
///
/// * [`EngineError::Serialization`] if the envelope or the payload cannot be
///   decoded.
/// * [`EngineError::TypeMismatch`] if the stored type tag differs from `T`.
pub(crate) fn deserialize_step<T: DeserializeOwned>(
    bytes: &[u8],
    key: &str,
) -> Result<StepData<T>, EngineError> {
    let envelope = postcard::from_bytes::<StepEnvelope>(bytes).map_err(|err| {
        EngineError::Serialization {
            key: key.to_owned(),
            source: Box::new(err),
        }
    })?;

    let expected = std::any::type_name::<T>();
    match envelope.type_tag.as_deref() {
        Some(found) if found != expected => Err(EngineError::TypeMismatch {
            key: key.to_owned(),
            expected: expected.to_owned(),
            found: found.to_owned(),
        }),
        // Untagged, or tagged with the expected type: decode the payload.
        _ => postcard::from_bytes(&envelope.data).map_err(|err| EngineError::Serialization {
            key: key.to_owned(),
            source: Box::new(err),
        }),
    }
}
/// Decodes only the outer [`StepEnvelope`] of a stored step record, leaving
/// the typed payload untouched.
///
/// `key` is used solely to label errors.
///
/// # Errors
///
/// Returns [`EngineError::Serialization`] if `bytes` is not a valid envelope.
pub(crate) fn deserialize_envelope(bytes: &[u8], key: &str) -> Result<StepEnvelope, EngineError> {
    let decoded: Result<StepEnvelope, _> = postcard::from_bytes(bytes);
    decoded.map_err(|err| EngineError::Serialization {
        key: key.to_owned(),
        source: Box::new(err),
    })
}
/// Execution context for one workflow instance.
///
/// It gives workflow code access to persisted steps, suspend points, timers,
/// and status reporting for that instance.
pub struct Context {
/// Name of the workflow; first segment of every composite step key.
workflow_name: String,
/// Id of this workflow instance; second segment of every composite step key.
instance_id: String,
/// Shared handle to the redb database holding the `STEPS` and `TIMERS` tables.
db: Arc<Database>,
/// Watch channel through which the workflow state is published.
status_tx: watch::Sender<WorkflowState>,
/// `true` until a step body actually executes or a suspend point resumes
/// with a value. While `true`, `set_status` updates the state without
/// notifying watchers.
replaying: AtomicBool,
/// Shared counter supplying the second half of each `TIMERS` key.
timer_serial: Arc<AtomicU64>,
/// Retry policy applied to steps that do not override it.
default_retry: Option<RetryPolicy>,
}
impl Context {
/// Creates a context for the given workflow instance.
///
/// The context starts out in replay mode (`replaying == true`); every other
/// field is taken from the corresponding argument unchanged.
pub(crate) fn new(
    workflow_name: String,
    instance_id: String,
    db: Arc<Database>,
    status_tx: watch::Sender<WorkflowState>,
    timer_serial: Arc<AtomicU64>,
    default_retry: Option<RetryPolicy>,
) -> Self {
    // A fresh context has not executed anything yet, so it is replaying.
    let replaying = AtomicBool::new(true);
    Self {
        replaying,
        workflow_name,
        instance_id,
        db,
        status_tx,
        timer_serial,
        default_retry,
    }
}
/// Returns the name of the workflow this context belongs to.
#[must_use]
pub fn workflow_name(&self) -> &str {
    self.workflow_name.as_str()
}
/// Returns the id of the workflow instance this context belongs to.
#[must_use]
pub fn instance_id(&self) -> &str {
    self.instance_id.as_str()
}
/// Reads the workflow input stored under the reserved `_input` key.
///
/// Returns `Ok(None)` when no record exists, or when the record is not a
/// completed one.
///
/// # Errors
///
/// Returns an error if the database read fails, or if the stored record
/// cannot be decoded as `T`.
pub fn input<T: DeserializeOwned>(&self) -> Result<Option<T>, EngineError> {
    let composite_key = format!("{}/{}/_input", self.workflow_name, self.instance_id);
    match self.read_step(&composite_key)? {
        None => Ok(None),
        Some(raw) => match deserialize_step::<T>(&raw, "_input")? {
            StepData::Completed { result, .. } => Ok(Some(result)),
            StepData::Suspended | StepData::Failed { .. } => Ok(None),
        },
    }
}
/// Sets the workflow state to `InProgress` with the given message.
///
/// While the context is replaying, the stored state is replaced but the
/// closure reports "not modified", so watchers are not woken. Otherwise the
/// new state is sent normally; a send error is deliberately ignored.
pub fn set_status(&self, msg: impl fmt::Display) {
    let next = WorkflowState::InProgress(msg.to_string());
    let replaying = self.replaying.load(Ordering::Acquire);
    if !replaying {
        let _ = self.status_tx.send(next);
        return;
    }
    self.status_tx.send_if_modified(move |current| {
        *current = next;
        false
    });
}
/// Starts building a step identified by `key`.
///
/// The returned builder has no timeout and no retry override.
///
/// # Panics
///
/// Panics if `key` contains `'/'` (the composite-key separator) or starts
/// with `'_'` (reserved keys).
#[must_use]
pub fn step<'a>(&'a self, key: &'a str) -> StepBuilder<'a> {
    let has_separator = key.contains('/');
    assert!(!has_separator, "step key must not contain '/': '{key}'");
    let is_reserved = key.starts_with('_');
    assert!(
        !is_reserved,
        "step keys starting with '_' are reserved: '{key}'"
    );
    StepBuilder {
        retry_override: None,
        timeout: None,
        key,
        ctx: self,
    }
}
/// Starts building a suspension at `point`.
///
/// The returned builder uses the point's key and has no custom status
/// message; awaiting it performs the suspend.
pub fn suspend<'a, T>(&'a self, point: &'a SuspendPoint<T>) -> SuspendBuilder<'a, T>
where
    T: Serialize + DeserializeOwned + Send,
{
    let key = point.key();
    SuspendBuilder {
        _marker: PhantomData,
        status_msg: None,
        key,
        ctx: self,
    }
}
/// Durably parks the workflow at `key` until `duration` has elapsed.
///
/// What happens depends on the record already stored under this key:
///
/// * a completed record: the timer has fired, and `Ok(())` is returned;
/// * a suspended record: the timer is still pending, and
///   [`EngineError::Suspended`] is returned without writing anything;
/// * no record, or a failed one: a new timer is registered. The suspended
///   step record and the `TIMERS` entry are written in a single transaction,
///   the workflow state is set to `Suspended` without waking watchers, and
///   [`EngineError::Suspended`] is returned.
///
/// The deadline is stored as milliseconds since the Unix epoch. It saturates
/// at `u64::MAX` rather than overflowing, so an extremely large `duration`
/// yields a timer that is effectively never due instead of a panic.
///
/// # Errors
///
/// * [`EngineError::InvalidKey`] if `key` contains `'/'` or starts with `'_'`.
/// * [`EngineError::Suspended`] whenever the timer has not fired yet.
/// * Storage or serialization errors from reading or writing the records.
///
/// # Panics
///
/// Panics if the system clock reports a time before the Unix epoch.
pub fn timer(&self, key: &str, duration: Duration) -> Result<(), EngineError> {
    if key.contains('/') || key.starts_with('_') {
        return Err(EngineError::InvalidKey {
            label: "step_key",
            value: key.to_string(),
        });
    }
    let composite_key = format!("{}/{}/{key}", self.workflow_name, self.instance_id);
    let span = info_span!("timer", key, composite_key = %composite_key);
    if let Some(bytes) = self.read_step(&composite_key)? {
        let data: StepData<()> = deserialize_step(&bytes, key)?;
        match data {
            StepData::Completed { .. } => {
                span.in_scope(|| info!("timer already fired — resuming"));
                return Ok(());
            }
            StepData::Suspended => {
                span.in_scope(|| info!("timer still pending — re-suspending"));
                return Err(EngineError::Suspended {
                    key: key.to_string(),
                });
            }
            // A failed record is replaced below by a freshly registered timer.
            StepData::Failed { .. } => {}
        }
    }
    let now_ms = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("system clock before unix epoch")
        .as_millis();
    // `duration` is caller-supplied, so clamp instead of panicking: both the
    // u128 -> u64 narrowing and the addition saturate at `u64::MAX`.
    let now_ms = u64::try_from(now_ms).unwrap_or(u64::MAX);
    let delay_ms = u64::try_from(duration.as_millis()).unwrap_or(u64::MAX);
    let deadline = now_ms.saturating_add(delay_ms);
    let data = StepData::<()>::Suspended;
    let step_bytes = serialize_step(&data, key)?;
    // The serial keeps `TIMERS` keys unique when two timers share a deadline.
    let serial = self.timer_serial.fetch_add(1, Ordering::Relaxed);
    let entry = TimerEntry {
        workflow_name: self.workflow_name.clone(),
        instance_id: self.instance_id.clone(),
        step_key: key.to_string(),
    };
    let timer_bytes =
        postcard::to_allocvec(&entry).map_err(|e| EngineError::Serialization {
            key: key.to_string(),
            source: Box::new(e),
        })?;
    // Write the step record and the timer entry atomically, so neither can
    // exist without the other.
    let write_txn = self.db.begin_write()?;
    {
        let mut steps = write_txn.open_table(STEPS)?;
        steps.insert(composite_key.as_str(), step_bytes.as_slice())?;
        let mut timers = write_txn.open_table(TIMERS)?;
        timers.insert((deadline, serial), timer_bytes.as_slice())?;
    }
    write_txn.commit()?;
    let msg = format!("timer {key} (deadline {deadline})");
    span.in_scope(|| info!(deadline, "timer set — suspending"));
    // Returning `false` stores the new state without notifying watchers.
    self.status_tx.send_if_modified(|state| {
        *state = WorkflowState::Suspended {
            key: key.to_string(),
            status: msg,
        };
        false
    });
    Err(EngineError::Suspended {
        key: key.to_string(),
    })
}
/// Performs the suspend for `key`, or resumes from it if a value has been
/// recorded.
///
/// * Completed record: leaves replay mode, silently restores the recorded
///   status (if any), and returns the stored value.
/// * Suspended record: returns [`EngineError::Suspended`] without writing.
/// * No record, or a failed one: writes a suspended record, sets the workflow
///   state to `Suspended` (using `status_msg`, or `key` when none was given)
///   without waking watchers, and returns [`EngineError::Suspended`].
///
/// # Errors
///
/// Besides [`EngineError::Suspended`], returns storage and serialization
/// errors from reading or writing the record.
fn execute_suspend<T>(&self, key: &str, status_msg: Option<&str>) -> Result<T, EngineError>
where
    T: Serialize + DeserializeOwned + Send,
{
    let composite_key = format!("{}/{}/{key}", self.workflow_name, self.instance_id);
    let span = info_span!("suspend", key, composite_key = %composite_key);

    let existing = match self.read_step(&composite_key)? {
        Some(raw) => Some(deserialize_step::<T>(&raw, key)?),
        None => None,
    };
    match existing {
        Some(StepData::Completed { result, status }) => {
            span.in_scope(|| info!("signal received — resuming"));
            self.replaying.store(false, Ordering::Release);
            if let Some(status) = status {
                self.status_tx.send_if_modified(|state| {
                    *state = WorkflowState::InProgress(status);
                    false
                });
            }
            return Ok(result);
        }
        Some(StepData::Suspended) => {
            span.in_scope(|| info!("still suspended — awaiting signal"));
            return Err(EngineError::Suspended {
                key: key.to_string(),
            });
        }
        // Nothing recorded yet, or only a failure: (re)write the marker below.
        Some(StepData::Failed { .. }) | None => {}
    }

    let marker = serialize_step(&StepData::<T>::Suspended, key)?;
    self.write_step(&composite_key, &marker)?;

    let msg = status_msg.unwrap_or(key).to_string();
    span.in_scope(|| info!(status = %msg, "suspending"));
    self.status_tx.send_if_modified(|state| {
        *state = WorkflowState::Suspended {
            key: key.to_string(),
            status: msg.clone(),
        };
        false
    });
    Err(EngineError::Suspended {
        key: key.to_string(),
    })
}
/// Returns a copy of the current status message if the workflow state is
/// `InProgress`, and `None` for any other state.
fn current_status_string(&self) -> Option<String> {
    let state = self.status_tx.borrow();
    if let WorkflowState::InProgress(msg) = &*state {
        Some(msg.clone())
    } else {
        None
    }
}
/// Fetches the raw bytes stored under `composite_key` in the `STEPS` table.
///
/// Returns `Ok(None)` when the key is absent, and also when the table itself
/// does not exist yet.
///
/// # Errors
///
/// Returns an error if the read transaction, the table open (for any reason
/// other than a missing table), or the lookup fails.
fn read_step(&self, composite_key: &str) -> Result<Option<Vec<u8>>, EngineError> {
    let txn = self.db.begin_read()?;
    let steps = match txn.open_table(STEPS) {
        // A database that has never been written to has no table: treat it
        // as "no record" rather than as a failure.
        Err(redb::TableError::TableDoesNotExist(_)) => return Ok(None),
        Err(other) => return Err(EngineError::from(other)),
        Ok(table) => table,
    };
    let found = steps
        .get(composite_key)?
        .map(|guard| guard.value().to_vec());
    Ok(found)
}
/// Stores `value` under `composite_key` in the `STEPS` table and commits,
/// replacing any record already present.
///
/// # Errors
///
/// Returns an error if beginning the transaction, opening the table, the
/// insert, or the commit fails.
fn write_step(&self, composite_key: &str, value: &[u8]) -> Result<(), EngineError> {
    let txn = self.db.begin_write()?;
    let mut steps = txn.open_table(STEPS)?;
    steps.insert(composite_key, value)?;
    // The table handle borrows the transaction and must go before commit.
    drop(steps);
    txn.commit()?;
    Ok(())
}
}
/// Builder for a suspension, created by `Context::suspend`.
///
/// Awaiting the builder (via its `IntoFuture` impl) performs the suspend.
pub struct SuspendBuilder<'a, T> {
/// Context the suspend will run against.
ctx: &'a Context,
/// Key of the suspend point.
key: &'a str,
/// Optional status text to publish while suspended; the key is used when
/// this is `None`.
status_msg: Option<&'a str>,
/// Records the payload type `T` without storing a value.
_marker: PhantomData<T>,
}
impl<'a, T> SuspendBuilder<'a, T>
where
    T: Serialize + DeserializeOwned + Send,
{
    /// Sets the status text published while the workflow is suspended,
    /// replacing any previously set text.
    #[must_use]
    pub fn status(self, msg: &'a str) -> Self {
        Self {
            status_msg: Some(msg),
            ..self
        }
    }
}
impl<'a, T> IntoFuture for SuspendBuilder<'a, T>
where
    T: Serialize + DeserializeOwned + Send + 'a,
{
    type Output = Result<T, EngineError>;
    type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'a>>;

    /// Turns the builder into a boxed future that runs the suspend.
    ///
    /// The future contains no await points: it calls
    /// `Context::execute_suspend` and completes the first time it is polled.
    fn into_future(self) -> Self::IntoFuture {
        let Self {
            ctx,
            key,
            status_msg,
            ..
        } = self;
        Box::pin(async move { ctx.execute_suspend(key, status_msg) })
    }
}
/// Builder for a single step, created by `Context::step` and consumed by
/// `StepBuilder::run`.
pub struct StepBuilder<'a> {
/// Context the step will run against.
ctx: &'a Context,
/// Step key, already validated by `Context::step`.
key: &'a str,
/// Optional time limit applied to each attempt of the step body.
timeout: Option<Duration>,
/// Optional per-step retry choice; `None` means "use the context default".
retry_override: Option<RetryOverride>,
}
impl StepBuilder<'_> {
#[must_use]
pub fn timeout(mut self, duration: Duration) -> Self {
self.timeout = Some(duration);
self
}
#[must_use]
pub fn retry(mut self, policy: RetryPolicy) -> Self {
self.retry_override = Some(RetryOverride::Custom(policy));
self
}
#[must_use]
pub fn no_retry(mut self) -> Self {
self.retry_override = Some(RetryOverride::Disabled);
self
}
/// Resolves the retry policy that applies to this step.
///
/// A per-step override wins: `Disabled` yields `None`, `Custom` yields its
/// policy. Without an override, the context's default policy (if any) is
/// used.
fn effective_retry(&self) -> Option<&RetryPolicy> {
    let Some(choice) = self.retry_override.as_ref() else {
        return self.ctx.default_retry.as_ref();
    };
    match choice {
        RetryOverride::Disabled => None,
        RetryOverride::Custom(policy) => Some(policy),
    }
}
/// Runs the step body `f`, or replays its recorded outcome.
///
/// If a record already exists under this step's key:
///
/// * completed: the stored result is returned without calling `f`, and the
///   recorded status (if any) is restored without waking watchers;
/// * suspended: the key is in use by a suspend point or timer, so
///   [`EngineError::SuspendedStepConflict`] is returned;
/// * failed: the step is executed again.
///
/// Otherwise the context leaves replay mode and `f` is called up to
/// `max_retries + 1` times, where `max_retries` comes from the effective
/// retry policy (0 when there is none). The attempt count saturates instead
/// of overflowing, so a policy with the maximum representable `max_retries`
/// cannot panic or skip execution.
///
/// Per attempt:
///
/// * success: the result is persisted together with the current status
///   string and returned;
/// * [`StepError::Permanent`]: a failed record is persisted and
///   [`EngineError::StepFailed`] (`retryable: false`) is returned;
/// * [`StepError::Retryable`] on the last attempt: a failed record is
///   persisted, then [`EngineError::StepFailed`] (`retryable: true`) is
///   returned if no retries were configured, or
///   [`EngineError::RetriesExhausted`] otherwise;
/// * [`StepError::Retryable`] otherwise: sleeps for the policy's delay and
///   tries again.
///
/// # Errors
///
/// In addition to the cases above, returns [`EngineError::StepTimeout`] as
/// soon as an attempt exceeds the configured timeout, and any storage or
/// serialization error hit while reading or writing the record.
#[expect(clippy::too_many_lines)]
pub async fn run<F, T>(self, mut f: F) -> Result<T, EngineError>
where
    F: AsyncFnMut() -> Result<T, StepError> + Send,
    T: Serialize + DeserializeOwned + Send,
{
    let composite_key = format!(
        "{}/{}/{}",
        self.ctx.workflow_name, self.ctx.instance_id, self.key
    );
    let span = info_span!("step", key = self.key, composite_key = %composite_key);
    if let Some(bytes) = self.ctx.read_step(&composite_key)? {
        span.in_scope(|| info!("cache hit"));
        let data: StepData<T> = deserialize_step(&bytes, self.key)?;
        match data {
            StepData::Completed { result, status } => {
                if let Some(status) = status {
                    // Returning `false` restores the status without waking
                    // watchers.
                    self.ctx.status_tx.send_if_modified(|state| {
                        *state = WorkflowState::InProgress(status);
                        false
                    });
                }
                return Ok(result);
            }
            StepData::Suspended => {
                return Err(EngineError::SuspendedStepConflict {
                    key: self.key.to_string(),
                });
            }
            StepData::Failed { .. } => {
                span.in_scope(|| info!("dead-letter found — re-executing"));
            }
        }
    }
    // From here on the step body really runs, so replay is over.
    self.ctx.replaying.store(false, Ordering::Release);
    let max_retries = self.effective_retry().map_or(0, |p| p.max_retries);
    // Saturate: `max_retries + 1` would overflow at the type's maximum,
    // which in release builds wraps to zero attempts and falls through to
    // the `unreachable!` below.
    let total_attempts = max_retries.saturating_add(1);
    for attempt in 0..total_attempts {
        span.in_scope(|| {
            if attempt > 0 {
                tracing::warn!(attempt = attempt + 1, max = total_attempts, "retrying step");
            } else {
                info!("executing");
            }
        });
        // NOTE(review): a timeout returns immediately via `?` — it is neither
        // retried nor recorded as a failed step. Confirm this is intended.
        let step_result = if let Some(duration) = self.timeout {
            tokio::time::timeout(duration, f())
                .await
                .map_err(|_| EngineError::StepTimeout {
                    key: self.key.to_string(),
                    duration,
                })?
        } else {
            f().await
        };
        match step_result {
            Ok(result) => {
                let data = StepData::Completed {
                    result,
                    status: self.ctx.current_status_string(),
                };
                let bytes = serialize_step(&data, self.key)?;
                // Take the result back out of the record it was moved into.
                let StepData::Completed { result, .. } = data else {
                    unreachable!()
                };
                self.ctx.write_step(&composite_key, &bytes)?;
                span.in_scope(|| info!("persisted"));
                return Ok(result);
            }
            Err(StepError::Permanent(inner)) => {
                let failed = StepData::<T>::Failed {
                    error: inner.to_string(),
                };
                let bytes = serialize_step(&failed, self.key)?;
                self.ctx.write_step(&composite_key, &bytes)?;
                return Err(EngineError::StepFailed {
                    key: self.key.to_string(),
                    source: inner,
                    retryable: false,
                });
            }
            Err(StepError::Retryable(inner)) => {
                let is_last = attempt + 1 >= total_attempts;
                if is_last {
                    let failed = StepData::<T>::Failed {
                        error: inner.to_string(),
                    };
                    let bytes = serialize_step(&failed, self.key)?;
                    self.ctx.write_step(&composite_key, &bytes)?;
                    // Without configured retries, report a plain failure
                    // rather than "retries exhausted".
                    if max_retries == 0 {
                        return Err(EngineError::StepFailed {
                            key: self.key.to_string(),
                            source: inner,
                            retryable: true,
                        });
                    }
                    return Err(EngineError::RetriesExhausted {
                        key: self.key.to_string(),
                        attempts: total_attempts,
                        source: inner,
                    });
                }
                if let Some(policy) = self.effective_retry() {
                    let delay = policy.delay_for(attempt);
                    span.in_scope(|| {
                        tracing::warn!(
                            attempt = attempt + 1,
                            max = total_attempts,
                            error = %inner,
                            delay_ms = u64::try_from(delay.as_millis()).unwrap_or(u64::MAX),
                            "retryable error — backing off"
                        );
                    });
                    tokio::time::sleep(delay).await;
                }
            }
        }
    }
    unreachable!("loop should return before exhausting iterations")
}
}