use std::fmt;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use redb::{Database, ReadableDatabase as _, TableDefinition};
use serde::Serialize;
use serde::de::DeserializeOwned;
use tokio::sync::watch;
use tracing::{info, info_span};
use crate::engine::WorkflowState;
use crate::error::{EngineError, StepError};
/// redb table `"steps"`: string key to opaque bytes.
///
/// In this file the key is always the composite
/// `"{workflow_name}/{instance_id}/{step_key}"` and the value is a
/// postcard-encoded [`StepData`].
pub(crate) const STEPS: TableDefinition<&str, &[u8]> = TableDefinition::new("steps");
/// redb table `"timers"`: `(deadline, serial)` to opaque bytes.
///
/// `deadline` is whole seconds since the Unix epoch and `serial` comes from
/// the shared `timer_serial` counter, so two timers with the same deadline
/// still get distinct keys. The value is a postcard-encoded [`TimerEntry`].
pub(crate) const TIMERS: TableDefinition<(u64, u64), &[u8]> = TableDefinition::new("timers");
/// Payload stored in the [`TIMERS`] table for one pending timer.
///
/// It carries the three components of the composite step key
/// (`"{workflow_name}/{instance_id}/{step_key}"`) that [`Context::timer`]
/// marked as suspended.
// NOTE(review): the consumer of these entries is not in this file; presumably
// it uses the fields to find and complete the matching step — confirm.
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
pub(crate) struct TimerEntry {
/// Name of the workflow that registered the timer.
pub workflow_name: String,
/// Instance of that workflow that registered the timer.
pub instance_id: String,
/// The caller-supplied key passed to [`Context::timer`].
pub step_key: String,
}
/// Persisted state of a single step, suspend point, or timer.
///
/// Values are postcard-encoded and stored in the [`STEPS`] table.
#[derive(serde::Serialize, serde::Deserialize)]
pub(crate) enum StepData<T> {
/// The step finished. `result` is its output; `status` is the
/// `InProgress` status message to restore (without notifying watchers)
/// when the entry is replayed, if there was one.
Completed { result: T, status: Option<String> },
/// The workflow is parked at this key and is waiting for something to
/// replace this entry with `Completed`.
Suspended,
}
/// Per-instance handle a workflow body uses to run durable steps, suspend
/// points, and timers, and to publish its status.
pub struct Context {
/// Workflow name; first component of every composite step key.
workflow_name: String,
/// Instance id; second component of every composite step key.
instance_id: String,
/// Shared redb database holding the `steps` and `timers` tables.
db: Arc<Database>,
/// Watch channel on which the current [`WorkflowState`] is published.
status_tx: watch::Sender<WorkflowState>,
/// Starts as `true`. Cleared on the first step cache miss or when a
/// completed suspend point is resumed. While set, `set_status` updates the
/// watched state without notifying receivers.
replaying: AtomicBool,
/// Shared counter supplying the second half of each `TIMERS` key.
timer_serial: Arc<AtomicU64>,
}
impl Context {
/// Builds a context for one workflow instance.
///
/// The context always starts in replay mode (`replaying == true`); the flag
/// is cleared later by the first step that misses the cache or by a resumed
/// suspend point.
pub(crate) fn new(
    workflow_name: String,
    instance_id: String,
    db: Arc<Database>,
    status_tx: watch::Sender<WorkflowState>,
    timer_serial: Arc<AtomicU64>,
) -> Self {
    let replaying = AtomicBool::new(true);
    Self {
        db,
        status_tx,
        timer_serial,
        replaying,
        workflow_name,
        instance_id,
    }
}
/// Returns the name of the workflow this context belongs to.
#[must_use]
pub fn workflow_name(&self) -> &str {
    self.workflow_name.as_str()
}
/// Returns the id of the workflow instance this context belongs to.
#[must_use]
pub fn instance_id(&self) -> &str {
    self.instance_id.as_str()
}
/// Publishes `msg` as the workflow's `InProgress` status.
///
/// While the context is still replaying, the watched value is overwritten
/// in place and receivers are *not* woken (the closure passed to
/// `send_if_modified` reports "not modified"). Once replay has ended the
/// value is sent normally; a send failure (no receivers) is ignored.
pub fn set_status(&self, msg: impl fmt::Display) {
    let next = WorkflowState::InProgress(msg.to_string());
    let still_replaying = self.replaying.load(Ordering::Acquire);

    if !still_replaying {
        // Live execution: notify every receiver.
        let _ = self.status_tx.send(next);
        return;
    }

    // Replay: keep the stored state current but stay silent.
    self.status_tx.send_if_modified(move |current| {
        *current = next;
        false
    });
}
/// Starts building a durable step identified by `key`.
///
/// The returned builder has no timeout configured; nothing is executed
/// until [`StepBuilder::run`] is called.
#[must_use]
pub fn step<'a>(&'a self, key: &'a str) -> StepBuilder<'a> {
    let timeout = None;
    StepBuilder {
        timeout,
        key,
        ctx: self,
    }
}
/// Starts building a suspend point identified by `key` that will
/// eventually yield a value of type `T`.
///
/// The returned builder carries no custom status message; awaiting it
/// performs the actual suspend/resume logic.
pub fn suspend<'a, T>(&'a self, key: &'a str) -> SuspendBuilder<'a, T>
where
    T: Serialize + DeserializeOwned + Send,
{
    let status_msg = None;
    SuspendBuilder {
        _marker: PhantomData,
        status_msg,
        key,
        ctx: self,
    }
}
pub fn timer(&self, key: &str, duration: Duration) -> Result<(), EngineError> {
let composite_key = format!("{}/{}/{key}", self.workflow_name, self.instance_id);
let span = info_span!("timer", key, composite_key = %composite_key);
if let Some(bytes) = self.read_step(&composite_key)? {
let data: StepData<()> =
postcard::from_bytes(&bytes).map_err(|e| EngineError::Serialization {
key: key.to_string(),
source: Box::new(e),
})?;
match data {
StepData::Completed { .. } => {
span.in_scope(|| info!("timer already fired — resuming"));
return Ok(());
}
StepData::Suspended => {
span.in_scope(|| info!("timer still pending — re-suspending"));
return Err(EngineError::Suspended {
key: key.to_string(),
});
}
}
}
let deadline = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("system clock before unix epoch")
.as_secs()
+ duration.as_secs();
let data = StepData::<()>::Suspended;
let bytes = postcard::to_allocvec(&data).map_err(|e| EngineError::Serialization {
key: key.to_string(),
source: Box::new(e),
})?;
self.write_step(&composite_key, &bytes)?;
let serial = self.timer_serial.fetch_add(1, Ordering::Relaxed);
let entry = TimerEntry {
workflow_name: self.workflow_name.clone(),
instance_id: self.instance_id.clone(),
step_key: key.to_string(),
};
let entry_bytes =
postcard::to_allocvec(&entry).map_err(|e| EngineError::Serialization {
key: key.to_string(),
source: Box::new(e),
})?;
let write_txn = self.db.begin_write()?;
{
let mut table = write_txn.open_table(TIMERS)?;
table.insert((deadline, serial), entry_bytes.as_slice())?;
}
write_txn.commit()?;
let msg = format!("timer {key} (deadline {deadline})");
span.in_scope(|| info!(deadline, "timer set — suspending"));
let _ = self.status_tx.send(WorkflowState::Suspended(msg.clone()));
Err(EngineError::Suspended {
key: key.to_string(),
})
}
fn execute_suspend<T>(&self, key: &str, status_msg: Option<&str>) -> Result<T, EngineError>
where
T: Serialize + DeserializeOwned + Send,
{
let composite_key = format!("{}/{}/{key}", self.workflow_name, self.instance_id);
let span = info_span!("suspend", key, composite_key = %composite_key);
if let Some(bytes) = self.read_step(&composite_key)? {
let data: StepData<T> =
postcard::from_bytes(&bytes).map_err(|e| EngineError::Serialization {
key: key.to_string(),
source: Box::new(e),
})?;
match data {
StepData::Completed { result, status } => {
span.in_scope(|| info!("signal received — resuming"));
self.replaying.store(false, Ordering::Release);
if let Some(status) = status {
self.status_tx.send_if_modified(|state| {
*state = WorkflowState::InProgress(status);
false
});
}
return Ok(result);
}
StepData::Suspended => {
span.in_scope(|| info!("still suspended — awaiting signal"));
return Err(EngineError::Suspended {
key: key.to_string(),
});
}
}
}
let data = StepData::<T>::Suspended;
let bytes = postcard::to_allocvec(&data).map_err(|e| EngineError::Serialization {
key: key.to_string(),
source: Box::new(e),
})?;
self.write_step(&composite_key, &bytes)?;
let msg = status_msg.unwrap_or(key).to_string();
span.in_scope(|| info!(status = %msg, "suspending"));
let _ = self.status_tx.send(WorkflowState::Suspended(msg.clone()));
Err(EngineError::Suspended {
key: key.to_string(),
})
}
/// Returns a copy of the current status message when the published state
/// is `InProgress`, and `None` for every other state.
fn current_status_string(&self) -> Option<String> {
    let state = self.status_tx.borrow();
    if let WorkflowState::InProgress(msg) = &*state {
        Some(msg.clone())
    } else {
        None
    }
}
/// Looks up `composite_key` in the `steps` table and returns a copy of the
/// stored bytes.
///
/// Returns `Ok(None)` both when the key is absent and when the table has
/// not been created yet.
///
/// # Errors
///
/// Any other database error, converted into [`EngineError`].
fn read_step(&self, composite_key: &str) -> Result<Option<Vec<u8>>, EngineError> {
    let txn = self.db.begin_read()?;
    let steps = match txn.open_table(STEPS) {
        // A missing table simply means nothing has been persisted yet.
        Err(redb::TableError::TableDoesNotExist(_)) => return Ok(None),
        opened => opened?,
    };
    let stored = steps.get(composite_key)?;
    Ok(stored.map(|guard| guard.value().to_vec()))
}
/// Stores `value` under `composite_key` in the `steps` table, replacing any
/// previous value, and commits in its own write transaction.
///
/// # Errors
///
/// Any database error, converted into [`EngineError`].
fn write_step(&self, composite_key: &str, value: &[u8]) -> Result<(), EngineError> {
    let txn = self.db.begin_write()?;
    let mut steps = txn.open_table(STEPS)?;
    steps.insert(composite_key, value)?;
    // The table borrows the transaction; release it before committing.
    drop(steps);
    txn.commit()?;
    Ok(())
}
}
/// Builder for a suspend point, created by [`Context::suspend`].
///
/// Awaiting the builder (through its `IntoFuture` impl) runs the suspend
/// logic and yields `Result<T, EngineError>`.
pub struct SuspendBuilder<'a, T> {
/// Context the suspend point belongs to.
ctx: &'a Context,
/// Caller-supplied key identifying this suspend point.
key: &'a str,
/// Optional status message to publish while suspended; when `None` the
/// key itself is used.
status_msg: Option<&'a str>,
/// Ties the builder to the result type `T` without storing a value.
_marker: PhantomData<T>,
}
impl<'a, T> SuspendBuilder<'a, T>
where
    T: Serialize + DeserializeOwned + Send,
{
    /// Sets the status message published while the workflow is suspended
    /// at this point, replacing any message set earlier.
    #[must_use]
    pub fn status(self, msg: &'a str) -> Self {
        Self {
            status_msg: Some(msg),
            ..self
        }
    }
}
/// Lets a [`SuspendBuilder`] be `.await`ed directly.
impl<'a, T> IntoFuture for SuspendBuilder<'a, T>
where
    T: Serialize + DeserializeOwned + Send + 'a,
{
    type Output = Result<T, EngineError>;
    type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'a>>;

    /// Wraps the suspend logic in a boxed future; nothing runs until the
    /// future is first polled.
    fn into_future(self) -> Self::IntoFuture {
        let Self {
            ctx,
            key,
            status_msg,
            ..
        } = self;
        Box::pin(async move { ctx.execute_suspend(key, status_msg) })
    }
}
/// Builder for a durable step, created by [`Context::step`].
///
/// Call `run` to execute the step, or to replay its cached result.
pub struct StepBuilder<'a> {
/// Context the step belongs to.
ctx: &'a Context,
/// Caller-supplied key identifying this step.
key: &'a str,
/// Optional upper bound on how long the step body may run.
timeout: Option<Duration>,
}
impl StepBuilder<'_> {
#[must_use]
pub fn timeout(mut self, duration: Duration) -> Self {
self.timeout = Some(duration);
self
}
pub async fn run<F, T>(self, f: F) -> Result<T, EngineError>
where
F: AsyncFnOnce() -> Result<T, StepError> + Send,
T: Serialize + DeserializeOwned + Send,
{
let composite_key = format!(
"{}/{}/{}",
self.ctx.workflow_name, self.ctx.instance_id, self.key
);
let span = info_span!("step", key = self.key, composite_key = %composite_key);
if let Some(bytes) = self.ctx.read_step(&composite_key)? {
span.in_scope(|| info!("cache hit"));
let data: StepData<T> =
postcard::from_bytes(&bytes).map_err(|e| EngineError::Serialization {
key: self.key.to_string(),
source: Box::new(e),
})?;
match data {
StepData::Completed { result, status } => {
if let Some(status) = status {
self.ctx.status_tx.send_if_modified(|state| {
*state = WorkflowState::InProgress(status);
false
});
}
return Ok(result);
}
StepData::Suspended => {
span.in_scope(|| {
info!("found suspended entry in step table — unexpected");
});
}
}
}
self.ctx.replaying.store(false, Ordering::Release);
span.in_scope(|| info!("cache miss — executing"));
let step_result = if let Some(duration) = self.timeout {
tokio::time::timeout(duration, f())
.await
.map_err(|_| EngineError::StepTimeout {
key: self.key.to_string(),
duration,
})?
} else {
f().await
};
let result = step_result.map_err(|e| EngineError::StepFailed {
key: self.key.to_string(),
source: match e {
StepError::Retryable(inner) | StepError::Permanent(inner) => inner,
},
})?;
let data = StepData::Completed {
result,
status: self.ctx.current_status_string(),
};
let bytes = postcard::to_allocvec(&data).map_err(|e| EngineError::Serialization {
key: self.key.to_string(),
source: Box::new(e),
})?;
let StepData::Completed { result, .. } = data else {
unreachable!()
};
self.ctx.write_step(&composite_key, &bytes)?;
span.in_scope(|| info!("persisted"));
Ok(result)
}
}