use std::fmt;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use redb::{Database, ReadableDatabase as _, TableDefinition};
use serde::Serialize;
use serde::de::DeserializeOwned;
use tokio::sync::watch;
use tracing::{info, info_span};
use crate::engine::WorkflowState;
use crate::error::{EngineError, StepError};
/// redb table named "steps". In this file its keys are composite strings of the form
/// `{workflow_name}/{instance_id}/{step_key}` and its values are the bytes produced by
/// `serialize_step`.
pub(crate) const STEPS: TableDefinition<&str, &[u8]> = TableDefinition::new("steps");
/// redb table named "timers". `Context::timer` inserts entries keyed by `(deadline, serial)`,
/// where `deadline` is in milliseconds since the Unix epoch and `serial` comes from the shared
/// timer counter; the value is a postcard-encoded `TimerEntry`.
pub(crate) const TIMERS: TableDefinition<(u64, u64), &[u8]> = TableDefinition::new("timers");
/// Value stored in the `TIMERS` table: identifies the step that a pending timer belongs to.
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
pub(crate) struct TimerEntry {
/// Name of the workflow that set the timer.
pub workflow_name: String,
/// Instance id of the workflow that set the timer.
pub instance_id: String,
/// The un-prefixed step key that was passed to `Context::timer`.
pub step_key: String,
}
/// Persisted state of a single step, suspend point or timer.
// NOTE(review): values are encoded with postcard; presumably reordering variants or fields
// would make previously stored records unreadable — confirm before changing the layout.
#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub(crate) enum StepData<T> {
/// The step finished. `result` is its value; `status` is the in-progress status message that
/// was current when the record was written (if any), restored when the record is replayed.
Completed { result: T, status: Option<String> },
/// Placeholder written when a suspend point or timer is first reached and no result exists yet.
Suspended,
}
/// Outer wrapper stored in the `STEPS` table around an encoded `StepData`.
#[derive(serde::Serialize, serde::Deserialize)]
pub(crate) struct StepEnvelope {
/// `std::any::type_name` of the result type for completed records; `None` for suspended ones.
pub type_tag: Option<String>,
/// postcard-encoded `StepData<T>` payload.
data: Vec<u8>,
}
pub(crate) fn serialize_step<T: Serialize>(
data: &StepData<T>,
key: &str,
) -> Result<Vec<u8>, EngineError> {
let type_tag = match data {
StepData::Completed { .. } => Some(std::any::type_name::<T>().to_string()),
StepData::Suspended => None,
};
let inner = postcard::to_allocvec(data).map_err(|e| EngineError::Serialization {
key: key.to_string(),
source: Box::new(e),
})?;
let envelope = StepEnvelope {
type_tag,
data: inner,
};
postcard::to_allocvec(&envelope).map_err(|e| EngineError::Serialization {
key: key.to_string(),
source: Box::new(e),
})
}
pub(crate) fn deserialize_step<T: DeserializeOwned>(
bytes: &[u8],
key: &str,
) -> Result<StepData<T>, EngineError> {
let envelope: StepEnvelope =
postcard::from_bytes(bytes).map_err(|e| EngineError::Serialization {
key: key.to_string(),
source: Box::new(e),
})?;
if let Some(ref stored) = envelope.type_tag {
let expected = std::any::type_name::<T>();
if stored != expected {
return Err(EngineError::TypeMismatch {
key: key.to_string(),
expected: expected.to_string(),
found: stored.clone(),
});
}
}
postcard::from_bytes(&envelope.data).map_err(|e| EngineError::Serialization {
key: key.to_string(),
source: Box::new(e),
})
}
/// Decodes only the outer `StepEnvelope`, leaving its payload untouched.
///
/// # Errors
///
/// Returns `EngineError::Serialization` (labelled with `key`) if `bytes` is not a valid
/// postcard-encoded envelope.
pub(crate) fn deserialize_envelope(bytes: &[u8], key: &str) -> Result<StepEnvelope, EngineError> {
    match postcard::from_bytes::<StepEnvelope>(bytes) {
        Ok(envelope) => Ok(envelope),
        Err(e) => Err(EngineError::Serialization {
            key: key.to_string(),
            source: Box::new(e),
        }),
    }
}
/// Handle for one workflow instance: builds steps, suspend points and timers against the shared
/// database and publishes status updates.
pub struct Context {
/// Workflow name; first segment of every composite step key.
workflow_name: String,
/// Instance id; second segment of every composite step key.
instance_id: String,
/// Shared redb database holding the `STEPS` and `TIMERS` tables.
db: Arc<Database>,
/// Watch channel through which the current `WorkflowState` is published.
status_tx: watch::Sender<WorkflowState>,
/// Starts as `true`; cleared when a step misses the cache or a suspend point resumes with a
/// result. While set, `set_status` updates the channel value without notifying receivers.
replaying: AtomicBool,
/// Shared counter supplying the second component of `TIMERS` keys.
timer_serial: Arc<AtomicU64>,
}
impl Context {
/// Builds a context for the given workflow instance.
///
/// The `replaying` flag always starts out set.
pub(crate) fn new(
    workflow_name: String,
    instance_id: String,
    db: Arc<Database>,
    status_tx: watch::Sender<WorkflowState>,
    timer_serial: Arc<AtomicU64>,
) -> Self {
    let replaying = AtomicBool::new(true);
    Self {
        workflow_name,
        instance_id,
        db,
        status_tx,
        replaying,
        timer_serial,
    }
}
/// Returns the name of the workflow this context belongs to.
#[must_use]
pub fn workflow_name(&self) -> &str {
    self.workflow_name.as_str()
}
/// Returns the id of the workflow instance this context belongs to.
#[must_use]
pub fn instance_id(&self) -> &str {
    self.instance_id.as_str()
}
/// Publishes `msg` as the workflow's in-progress status.
///
/// While the `replaying` flag is set, the channel value is replaced but the update closure
/// reports "not modified", so the change is stored without being announced. Once the flag is
/// cleared the status is sent normally; a send error is ignored.
pub fn set_status(&self, msg: impl fmt::Display) {
    let next = WorkflowState::InProgress(msg.to_string());

    if !self.replaying.load(Ordering::Acquire) {
        let _ = self.status_tx.send(next);
        return;
    }

    // Replay: overwrite the stored state quietly.
    self.status_tx.send_if_modified(move |current| {
        *current = next;
        false
    });
}
/// Starts building a step identified by `key`, with no timeout configured.
///
/// # Panics
///
/// Panics if `key` contains `'/'`, which is the separator used in composite keys.
#[must_use]
pub fn step<'a>(&'a self, key: &'a str) -> StepBuilder<'a> {
    let has_separator = key.contains('/');
    assert!(!has_separator, "step key must not contain '/': '{key}'");

    let timeout = None;
    StepBuilder {
        ctx: self,
        key,
        timeout,
    }
}
/// Starts building a suspend point identified by `key` that resolves to a value of type `T`.
///
/// The returned builder has no status message set.
///
/// # Panics
///
/// Panics if `key` contains `'/'`, which is the separator used in composite keys.
pub fn suspend<'a, T>(&'a self, key: &'a str) -> SuspendBuilder<'a, T>
where
    T: Serialize + DeserializeOwned + Send,
{
    let has_separator = key.contains('/');
    assert!(
        !has_separator,
        "suspend key must not contain '/': '{key}'"
    );

    let status_msg = None;
    SuspendBuilder {
        ctx: self,
        key,
        status_msg,
        _marker: PhantomData,
    }
}
/// Sets a durable timer under `key` and suspends the workflow until it has fired.
///
/// On the first call a `Suspended` step record and a `TIMERS` entry keyed by
/// `(deadline, serial)` are written in a single transaction, the status channel is set to
/// `Suspended`, and `EngineError::Suspended` is returned. On later calls the stored record
/// decides the outcome: `Completed` yields `Ok(())`, a still-`Suspended` record yields
/// `EngineError::Suspended` again.
///
/// The deadline is the current wall-clock time plus `duration`, in milliseconds since the Unix
/// epoch. A duration too large to represent saturates at `u64::MAX`, so the timer effectively
/// never fires, instead of panicking or wrapping around to a deadline in the past.
///
/// # Errors
///
/// * `EngineError::InvalidKey` if `key` contains `'/'`.
/// * `EngineError::Suspended` while the timer has not fired.
/// * Storage and serialization errors from reading or writing the records.
///
/// # Panics
///
/// Panics if the system clock is before the Unix epoch or its millisecond value does not fit
/// in a `u64`.
pub fn timer(&self, key: &str, duration: Duration) -> Result<(), EngineError> {
    if key.contains('/') {
        return Err(EngineError::InvalidKey {
            label: "step_key",
            value: key.to_string(),
        });
    }
    let composite_key = format!("{}/{}/{key}", self.workflow_name, self.instance_id);
    let span = info_span!("timer", key, composite_key = %composite_key);

    // A record already exists: the timer was set on an earlier run.
    if let Some(bytes) = self.read_step(&composite_key)? {
        let data: StepData<()> = deserialize_step(&bytes, key)?;
        return match data {
            StepData::Completed { .. } => {
                span.in_scope(|| info!("timer already fired — resuming"));
                Ok(())
            }
            StepData::Suspended => {
                span.in_scope(|| info!("timer still pending — re-suspending"));
                Err(EngineError::Suspended {
                    key: key.to_string(),
                })
            }
        };
    }

    let now_ms = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("system clock before unix epoch")
        .as_millis();
    let now_ms = u64::try_from(now_ms).expect("system clock overflows u64 millis");
    // The duration is caller-supplied, so clamp it rather than panic or wrap.
    let delay_ms = u64::try_from(duration.as_millis()).unwrap_or(u64::MAX);
    let deadline = now_ms.saturating_add(delay_ms);

    let data = StepData::<()>::Suspended;
    let step_bytes = serialize_step(&data, key)?;
    // The serial keeps TIMERS keys distinct when two timers share a deadline.
    let serial = self.timer_serial.fetch_add(1, Ordering::Relaxed);
    let entry = TimerEntry {
        workflow_name: self.workflow_name.clone(),
        instance_id: self.instance_id.clone(),
        step_key: key.to_string(),
    };
    let timer_bytes =
        postcard::to_allocvec(&entry).map_err(|e| EngineError::Serialization {
            key: key.to_string(),
            source: Box::new(e),
        })?;

    // Write the step placeholder and the timer entry together so neither exists alone.
    let write_txn = self.db.begin_write()?;
    {
        let mut steps = write_txn.open_table(STEPS)?;
        steps.insert(composite_key.as_str(), step_bytes.as_slice())?;
        let mut timers = write_txn.open_table(TIMERS)?;
        timers.insert((deadline, serial), timer_bytes.as_slice())?;
    }
    write_txn.commit()?;

    let msg = format!("timer {key} (deadline {deadline})");
    span.in_scope(|| info!(deadline, "timer set — suspending"));
    let _ = self.status_tx.send(WorkflowState::Suspended(msg));
    Err(EngineError::Suspended {
        key: key.to_string(),
    })
}
/// Evaluates the suspend point `key`, returning its value once one has been recorded.
///
/// * No record yet: a `Suspended` placeholder is written, the status channel is set to
///   `Suspended` (using `status_msg`, or `key` when none was given), and
///   `EngineError::Suspended` is returned.
/// * Record still `Suspended`: `EngineError::Suspended` is returned again.
/// * Record `Completed`: the `replaying` flag is cleared, any stored status is put back into
///   the channel without notifying receivers, and the result is returned.
///
/// # Errors
///
/// Besides `EngineError::Suspended`, propagates storage, serialization and type-mismatch
/// errors from reading or writing the record.
fn execute_suspend<T>(&self, key: &str, status_msg: Option<&str>) -> Result<T, EngineError>
where
    T: Serialize + DeserializeOwned + Send,
{
    let composite_key = format!("{}/{}/{key}", self.workflow_name, self.instance_id);
    let span = info_span!("suspend", key, composite_key = %composite_key);

    let Some(stored) = self.read_step(&composite_key)? else {
        // First visit: persist the placeholder before announcing the suspension.
        let placeholder = serialize_step(&StepData::<T>::Suspended, key)?;
        self.write_step(&composite_key, &placeholder)?;
        let msg = status_msg.unwrap_or(key).to_string();
        span.in_scope(|| info!(status = %msg, "suspending"));
        let _ = self.status_tx.send(WorkflowState::Suspended(msg));
        return Err(EngineError::Suspended {
            key: key.to_string(),
        });
    };

    match deserialize_step::<T>(&stored, key)? {
        StepData::Suspended => {
            span.in_scope(|| info!("still suspended — awaiting signal"));
            Err(EngineError::Suspended {
                key: key.to_string(),
            })
        }
        StepData::Completed { result, status } => {
            span.in_scope(|| info!("signal received — resuming"));
            self.replaying.store(false, Ordering::Release);
            if let Some(restored) = status {
                self.status_tx.send_if_modified(|state| {
                    *state = WorkflowState::InProgress(restored);
                    false
                });
            }
            Ok(result)
        }
    }
}
/// Returns a copy of the current status message if the channel holds an `InProgress` state,
/// and `None` for any other state.
fn current_status_string(&self) -> Option<String> {
    let state = self.status_tx.borrow();
    if let WorkflowState::InProgress(msg) = &*state {
        Some(msg.clone())
    } else {
        None
    }
}
/// Fetches the raw bytes stored under `composite_key` in the `STEPS` table.
///
/// Returns `Ok(None)` when the key is absent or when the table does not exist yet.
///
/// # Errors
///
/// Propagates any other database error as an `EngineError`.
fn read_step(&self, composite_key: &str) -> Result<Option<Vec<u8>>, EngineError> {
    let txn = self.db.begin_read()?;
    let steps = match txn.open_table(STEPS) {
        // A missing table is treated the same as a missing key.
        Err(redb::TableError::TableDoesNotExist(_)) => return Ok(None),
        Err(other) => return Err(EngineError::from(other)),
        Ok(table) => table,
    };
    // Copy the value out so nothing borrowed from the transaction escapes.
    let stored = steps
        .get(composite_key)?
        .map(|guard| guard.value().to_vec());
    Ok(stored)
}
/// Stores `value` under `composite_key` in the `STEPS` table and commits the transaction.
///
/// # Errors
///
/// Propagates any database error from opening the transaction or table, inserting, or
/// committing.
fn write_step(&self, composite_key: &str, value: &[u8]) -> Result<(), EngineError> {
    let txn = self.db.begin_write()?;
    let mut steps = txn.open_table(STEPS)?;
    steps.insert(composite_key, value)?;
    // Release the table handle before the transaction is committed.
    drop(steps);
    txn.commit()?;
    Ok(())
}
}
/// Builder for a suspend point that resolves to a `T`; created by `Context::suspend` and run by
/// awaiting it.
pub struct SuspendBuilder<'a, T> {
/// Context the suspend point is evaluated against.
ctx: &'a Context,
/// Un-prefixed key of the suspend point.
key: &'a str,
/// Optional status message published when suspending; the key is used when this is `None`.
status_msg: Option<&'a str>,
/// Carries the result type without storing a value of it.
_marker: PhantomData<T>,
}
impl<'a, T> SuspendBuilder<'a, T>
where
    T: Serialize + DeserializeOwned + Send,
{
    /// Sets the status message to publish when this suspend point suspends the workflow.
    #[must_use]
    pub fn status(self, msg: &'a str) -> Self {
        Self {
            status_msg: Some(msg),
            ..self
        }
    }
}
/// Lets a `SuspendBuilder` be awaited directly; the suspend point is evaluated when the
/// returned future is polled.
impl<'a, T> IntoFuture for SuspendBuilder<'a, T>
where
    T: Serialize + DeserializeOwned + Send + 'a,
{
    type Output = Result<T, EngineError>;
    type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'a>>;

    fn into_future(self) -> Self::IntoFuture {
        let Self {
            ctx,
            key,
            status_msg,
            ..
        } = self;
        Box::pin(async move { ctx.execute_suspend(key, status_msg) })
    }
}
/// Builder for a cached step; created by `Context::step` and executed with `run`.
pub struct StepBuilder<'a> {
/// Context the step is executed against.
ctx: &'a Context,
/// Un-prefixed key of the step.
key: &'a str,
/// Optional limit on how long the step's future may run; `None` means no limit.
timeout: Option<Duration>,
}
impl StepBuilder<'_> {
#[must_use]
pub fn timeout(mut self, duration: Duration) -> Self {
self.timeout = Some(duration);
self
}
pub async fn run<F, T>(self, f: F) -> Result<T, EngineError>
where
F: AsyncFnOnce() -> Result<T, StepError> + Send,
T: Serialize + DeserializeOwned + Send,
{
let composite_key = format!(
"{}/{}/{}",
self.ctx.workflow_name, self.ctx.instance_id, self.key
);
let span = info_span!("step", key = self.key, composite_key = %composite_key);
if let Some(bytes) = self.ctx.read_step(&composite_key)? {
span.in_scope(|| info!("cache hit"));
let data: StepData<T> = deserialize_step(&bytes, self.key)?;
match data {
StepData::Completed { result, status } => {
if let Some(status) = status {
self.ctx.status_tx.send_if_modified(|state| {
*state = WorkflowState::InProgress(status);
false
});
}
return Ok(result);
}
StepData::Suspended => {
return Err(EngineError::SuspendedStepConflict {
key: self.key.to_string(),
});
}
}
}
self.ctx.replaying.store(false, Ordering::Release);
span.in_scope(|| info!("cache miss — executing"));
let step_result = if let Some(duration) = self.timeout {
tokio::time::timeout(duration, f())
.await
.map_err(|_| EngineError::StepTimeout {
key: self.key.to_string(),
duration,
})?
} else {
f().await
};
let result = step_result.map_err(|e| match e {
StepError::Retryable(inner) => EngineError::StepFailed {
key: self.key.to_string(),
source: inner,
retryable: true,
},
StepError::Permanent(inner) => EngineError::StepFailed {
key: self.key.to_string(),
source: inner,
retryable: false,
},
})?;
let data = StepData::Completed {
result,
status: self.ctx.current_status_string(),
};
let bytes = serialize_step(&data, self.key)?;
let StepData::Completed { result, .. } = data else {
unreachable!()
};
self.ctx.write_step(&composite_key, &bytes)?;
span.in_scope(|| info!("persisted"));
Ok(result)
}
}