use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use serde::{Serialize, de::DeserializeOwned};
use super::context::WorkflowContext;
use crate::Result;
type StepFn<T> = Arc<dyn Fn() -> Pin<Box<dyn Future<Output = Result<T>> + Send>> + Send + Sync>;
type CompensateFn<T> =
Arc<dyn Fn(T) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync>;
/// Builder-style executor for a single named workflow step.
///
/// A runner is created with a step function, optionally configured with a
/// compensation action, a per-attempt timeout, a retry policy and an
/// "optional" flag, and is then consumed by `run`.
pub struct StepRunner<'a, T: Serialize + DeserializeOwned + Clone + Send + Sync + 'static> {
    /// Workflow context used for result caching, step bookkeeping and
    /// compensation registration.
    ctx: &'a WorkflowContext,
    /// Step name; also the key under which the step is recorded in `ctx`.
    name: String,
    /// The step body. Set by `new` and taken out by `run`.
    step_fn: Option<StepFn<T>>,
    /// Action registered with `ctx` after a successful run, if any.
    compensate_fn: Option<CompensateFn<T>>,
    /// Per-attempt time limit; `None` means attempts may run indefinitely.
    timeout: Option<Duration>,
    /// Number of additional attempts after the first failure.
    retry_count: u32,
    /// Pause between consecutive attempts.
    retry_delay: Duration,
    /// When `true`, exhausting all attempts yields `Ok(None)` instead of an error.
    optional: bool,
}
impl<'a, T> StepRunner<'a, T>
where
    T: Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
{
    /// Creates a runner for the step called `name` whose body is `f`.
    ///
    /// The runner starts with no compensation, no timeout and no retries, and
    /// it is not optional. Use the builder methods to change any of these.
    pub(crate) fn new<F, Fut>(ctx: &'a WorkflowContext, name: impl Into<String>, f: F) -> Self
    where
        F: Fn() -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<T>> + Send + 'static,
    {
        // Erase the concrete future type so the runner can store any step body.
        let step_fn: StepFn<T> = Arc::new(move || Box::pin(f()));
        Self {
            ctx,
            name: name.into(),
            step_fn: Some(step_fn),
            compensate_fn: None,
            timeout: None,
            retry_count: 0,
            retry_delay: Duration::ZERO,
            optional: false,
        }
    }

    /// Sets the compensation action for this step.
    ///
    /// When the step succeeds during [`run`](Self::run), `f` is registered with
    /// the workflow context via `register_compensation`. It receives a clone of
    /// the value the step produced. Calling this again replaces the previous
    /// action.
    pub fn compensate<F, Fut>(mut self, f: F) -> Self
    where
        F: Fn(T) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<()>> + Send + 'static,
    {
        self.compensate_fn = Some(Arc::new(move |result| Box::pin(f(result))));
        self
    }

    /// Limits each execution attempt to `duration`.
    ///
    /// An attempt that exceeds the limit fails with `ForgeError::Timeout` and
    /// is retried like any other failure.
    pub fn timeout(mut self, duration: Duration) -> Self {
        self.timeout = Some(duration);
        self
    }

    /// Retries a failing step up to `count` additional times, sleeping `delay`
    /// before each retry.
    pub fn retry(mut self, count: u32, delay: Duration) -> Self {
        self.retry_count = count;
        self.retry_delay = delay;
        self
    }

    /// Marks the step as optional.
    ///
    /// If every attempt fails, [`run`](Self::run) still records the failure.
    /// It then returns `Ok(None)` instead of the error, so the workflow can
    /// continue.
    pub fn optional(mut self) -> Self {
        self.optional = true;
        self
    }

    /// Executes the step, honouring cached results, timeouts and retries.
    ///
    /// If the context already holds a completed, deserializable result for
    /// this step name, that result is returned without running the step.
    /// Otherwise the step is attempted up to `retry_count + 1` times
    /// (saturating at `u32::MAX`). On success, three things happen: the result
    /// is recorded in the context, any compensation is registered, and
    /// `Ok(Some(value))` is returned.
    ///
    /// # Errors
    ///
    /// Returns the error from the final attempt when every attempt fails. The
    /// exception is an [`optional`](Self::optional) step, which returns
    /// `Ok(None)` instead.
    ///
    /// # Panics
    ///
    /// Panics if the runner has no step function. `new` always installs one,
    /// so this indicates a bug.
    pub async fn run(mut self) -> Result<Option<T>> {
        let step_fn = self
            .step_fn
            .take()
            .expect("StepRunner::run called without step function");

        if self.ctx.is_step_completed(&self.name) {
            if let Some(result) = self.ctx.get_step_result::<T>(&self.name) {
                tracing::debug!(
                    step = %self.name,
                    "Step already completed, returning cached result"
                );
                // NOTE(review): the compensation is not registered on this
                // replay path. If compensations are kept only in memory, a
                // resumed workflow may never roll this step back. Confirm
                // against `WorkflowContext`.
                return Ok(Some(result));
            }
            // The step is marked complete but has no usable result (it is
            // missing, or does not deserialize into `T`). Execute it again, as
            // before, but make the re-execution visible.
            tracing::warn!(
                step = %self.name,
                "Step marked completed but cached result is unavailable, re-executing"
            );
        }

        self.ctx.record_step_start(&self.name);

        // Saturating so `retry(u32::MAX, _)` cannot overflow. A plain `+ 1`
        // panics in debug builds and wraps to zero attempts in release builds.
        let total_attempts = self.retry_count.saturating_add(1);
        // Loop-invariant. `as_millis` returns u128, so clamp instead of truncating.
        let delay_ms = u64::try_from(self.retry_delay.as_millis()).unwrap_or(u64::MAX);

        let mut attempt: u32 = 1;
        let error = loop {
            let result = self.execute_with_timeout(&step_fn).await;
            match result {
                Ok(value) => {
                    self.record_success(&value);
                    return Ok(Some(value));
                }
                Err(err) if attempt < total_attempts => {
                    tracing::warn!(
                        step = %self.name,
                        attempt = attempt,
                        max_attempts = total_attempts,
                        delay_ms = delay_ms,
                        error = %err,
                        "Step failed, retrying"
                    );
                    tokio::time::sleep(self.retry_delay).await;
                    // Cannot overflow: `attempt < total_attempts <= u32::MAX`.
                    attempt += 1;
                }
                // Final attempt: its error becomes the outcome of the step.
                Err(err) => break err,
            }
        };

        let error_msg = error.to_string();
        self.ctx.record_step_failure(&self.name, &error_msg);
        if self.optional {
            tracing::warn!(
                step = %self.name,
                error = %error_msg,
                attempts = total_attempts,
                "Optional step failed after all retries, continuing workflow"
            );
            return Ok(None);
        }
        Err(error)
    }

    /// Records a successful result in the context and registers the
    /// compensation action, if one was configured.
    fn record_success(&mut self, value: &T) {
        // A result that fails to serialize is still recorded as `null`, so the
        // step counts as complete. The failure is logged rather than silently
        // swallowed.
        let json_value = serde_json::to_value(value).unwrap_or_else(|err| {
            tracing::warn!(
                step = %self.name,
                error = %err,
                "Failed to serialize step result, recording null"
            );
            serde_json::Value::Null
        });
        self.ctx.record_step_complete(&self.name, json_value);

        if let Some(compensate_fn) = self.compensate_fn.take() {
            let value = value.clone();
            self.ctx.register_compensation(
                &self.name,
                Arc::new(move |_| compensate_fn(value.clone())),
            );
        }
    }

    /// Runs one attempt of the step, enforcing `self.timeout` if one is set.
    ///
    /// # Errors
    ///
    /// Returns the step's own error, or `ForgeError::Timeout` when the attempt
    /// exceeds the configured timeout.
    async fn execute_with_timeout(&self, step_fn: &StepFn<T>) -> Result<T> {
        let fut = step_fn();
        let Some(limit) = self.timeout else {
            return fut.await;
        };
        match tokio::time::timeout(limit, fut).await {
            Ok(result) => result,
            Err(_elapsed) => Err(crate::ForgeError::Timeout(format!(
                "Step '{}' timed out after {:?}",
                self.name, limit
            ))),
        }
    }
}
#[cfg(test)]
// Allows `unwrap` and indexing inside tests. Presumably these lints are denied
// for non-test code elsewhere in the crate — confirm.
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    // The glob import is currently unused because the placeholder test below
    // has an empty body.
    #[allow(unused_imports)]
    use super::*;

    // TODO(review): placeholder. This test asserts nothing and always passes.
    // Exercising the builder requires a `WorkflowContext`, which this module
    // does not construct yet.
    #[test]
    fn test_step_runner_builder_pattern() {
    }
}