//! Workflow execution context: per-run step tracking and persistence,
//! saga-style compensation, and durable suspend/resume primitives
//! (`sleep`, `wait_for_event`).

use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, RwLock};
use std::time::Duration;
use chrono::{DateTime, Utc};
use serde::de::DeserializeOwned;
use tokio::sync::mpsc;
use uuid::Uuid;
use super::parallel::ParallelBuilder;
use super::step::StepStatus;
use super::suspend::{SuspendReason, WorkflowEvent};
use crate::env::{EnvAccess, EnvProvider, RealEnvProvider};
use crate::function::AuthContext;
use crate::http::CircuitBreakerClient;
use crate::{ForgeError, Result};
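/// Type-erased rollback callback for a completed step: it receives the
/// step's JSON result and undoes the step's effects.
///
/// The boxed-and-pinned future makes construction a little awkward; a
/// minimal sketch (`undo_step` is a hypothetical rollback action):
///
/// ```ignore
/// let handler: CompensationHandler = Arc::new(|payload: serde_json::Value| {
///     Box::pin(async move { undo_step(payload).await })
///         as Pin<Box<dyn Future<Output = Result<()>> + Send>>
/// });
/// ```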
pub type CompensationHandler = Arc<
dyn Fn(serde_json::Value) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync,
>;
/// In-memory record of a single workflow step's lifecycle, mirrored to the
/// `forge_workflow_steps` table.
#[derive(Debug, Clone)]
pub struct StepState {
pub name: String,
pub status: StepStatus,
pub result: Option<serde_json::Value>,
pub error: Option<String>,
pub started_at: Option<DateTime<Utc>>,
pub completed_at: Option<DateTime<Utc>>,
}
impl StepState {
pub fn new(name: impl Into<String>) -> Self {
Self {
name: name.into(),
status: StepStatus::Pending,
result: None,
error: None,
started_at: None,
completed_at: None,
}
}
pub fn start(&mut self) {
self.status = StepStatus::Running;
self.started_at = Some(Utc::now());
}
pub fn complete(&mut self, result: serde_json::Value) {
self.status = StepStatus::Completed;
self.result = Some(result);
self.completed_at = Some(Utc::now());
}
pub fn fail(&mut self, error: impl Into<String>) {
self.status = StepStatus::Failed;
self.error = Some(error.into());
self.completed_at = Some(Utc::now());
}
    /// Marks a previously completed step as rolled back.
    pub fn compensate(&mut self) {
self.status = StepStatus::Compensated;
}
}
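/// Per-run execution context handed to workflow code.
///
/// Tracks step state in memory, mirrors transitions to Postgres, and exposes
/// durable primitives (`sleep`, `wait_for_event`) plus saga-style
/// compensation. A minimal construction sketch, assuming an existing `pool`
/// and a hypothetical `tenant_id`:
///
/// ```ignore
/// let ctx = WorkflowContext::new(
///     Uuid::new_v4(),
///     "order_fulfillment".to_string(),
///     pool,
///     CircuitBreakerClient::with_defaults(reqwest::Client::new()),
/// )
/// .with_tenant(tenant_id);
/// ```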
pub struct WorkflowContext {
pub run_id: Uuid,
pub workflow_name: String,
pub started_at: DateTime<Utc>,
    /// Logical "now" for the run: pinned to `started_at` so resumed
    /// executions observe a stable timestamp.
    workflow_time: DateTime<Utc>,
pub auth: AuthContext,
db_pool: sqlx::PgPool,
http_client: CircuitBreakerClient,
http_timeout: Option<Duration>,
step_states: Arc<RwLock<HashMap<String, StepState>>>,
completed_steps: Arc<RwLock<Vec<String>>>,
compensation_handlers: Arc<RwLock<HashMap<String, CompensationHandler>>>,
suspend_tx: Option<mpsc::Sender<SuspendReason>>,
is_resumed: bool,
resumed_from_sleep: bool,
tenant_id: Option<Uuid>,
env_provider: Arc<dyn EnvProvider>,
}
impl WorkflowContext {
    /// Creates a fresh context for the first execution of a workflow run.
    pub fn new(
run_id: Uuid,
workflow_name: String,
db_pool: sqlx::PgPool,
http_client: CircuitBreakerClient,
) -> Self {
let now = Utc::now();
Self {
run_id,
workflow_name,
started_at: now,
workflow_time: now,
auth: AuthContext::unauthenticated(),
db_pool,
http_client,
http_timeout: None,
step_states: Arc::new(RwLock::new(HashMap::new())),
completed_steps: Arc::new(RwLock::new(Vec::new())),
compensation_handlers: Arc::new(RwLock::new(HashMap::new())),
suspend_tx: None,
is_resumed: false,
resumed_from_sleep: false,
tenant_id: None,
env_provider: Arc::new(RealEnvProvider::new()),
}
}
    /// Rebuilds a context for a resumed run; previously persisted step
    /// states are restored separately via `with_step_states`.
    pub fn resumed(
run_id: Uuid,
workflow_name: String,
started_at: DateTime<Utc>,
db_pool: sqlx::PgPool,
http_client: CircuitBreakerClient,
) -> Self {
Self {
run_id,
workflow_name,
started_at,
workflow_time: started_at,
auth: AuthContext::unauthenticated(),
db_pool,
http_client,
http_timeout: None,
step_states: Arc::new(RwLock::new(HashMap::new())),
completed_steps: Arc::new(RwLock::new(Vec::new())),
compensation_handlers: Arc::new(RwLock::new(HashMap::new())),
suspend_tx: None,
is_resumed: true,
resumed_from_sleep: false,
tenant_id: None,
env_provider: Arc::new(RealEnvProvider::new()),
}
}
pub fn with_env_provider(mut self, provider: Arc<dyn EnvProvider>) -> Self {
self.env_provider = provider;
self
}
    /// Flags that this execution resumed from a sleep, so `sleep` and
    /// `sleep_until` return immediately instead of suspending again.
    pub fn with_resumed_from_sleep(mut self) -> Self {
self.resumed_from_sleep = true;
self
}
    /// Attaches the channel used to notify the runner when the workflow suspends.
    pub fn with_suspend_channel(mut self, tx: mpsc::Sender<SuspendReason>) -> Self {
self.suspend_tx = Some(tx);
self
}
pub fn with_tenant(mut self, tenant_id: Uuid) -> Self {
self.tenant_id = Some(tenant_id);
self
}
pub fn tenant_id(&self) -> Option<Uuid> {
self.tenant_id
}
pub fn is_resumed(&self) -> bool {
self.is_resumed
}
pub fn workflow_time(&self) -> DateTime<Utc> {
self.workflow_time
}
pub fn db(&self) -> crate::function::ForgeDb {
crate::function::ForgeDb::from_pool(&self.db_pool)
}
pub fn db_conn(&self) -> crate::function::DbConn<'_> {
crate::function::DbConn::Pool(self.db_pool.clone())
}
pub async fn conn(&self) -> sqlx::Result<crate::function::ForgeConn<'static>> {
Ok(crate::function::ForgeConn::Pool(
self.db_pool.acquire().await?,
))
}
pub fn http(&self) -> crate::http::HttpClient {
self.http_client.with_timeout(self.http_timeout)
}
pub fn raw_http(&self) -> &reqwest::Client {
self.http_client.inner()
}
pub fn set_http_timeout(&mut self, timeout: Option<Duration>) {
self.http_timeout = timeout;
}
pub fn with_auth(mut self, auth: AuthContext) -> Self {
self.auth = auth;
self
}
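    /// Restores previously persisted step states and rebuilds the
    /// completed-step list that drives compensation ordering. Typically
    /// paired with `WorkflowContext::resumed`; a sketch assuming
    /// `saved_states` was loaded from `forge_workflow_steps`:
    ///
    /// ```ignore
    /// let ctx = WorkflowContext::resumed(run_id, name, started_at, pool, http)
    ///     .with_step_states(saved_states);
    /// assert!(ctx.is_resumed());
    /// ```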
pub fn with_step_states(self, states: HashMap<String, StepState>) -> Self {
let completed: Vec<String> = states
.iter()
.filter(|(_, s)| s.status == StepStatus::Completed)
.map(|(name, _)| name.clone())
.collect();
*self.step_states.write().expect("workflow lock poisoned") = states;
*self
.completed_steps
.write()
.expect("workflow lock poisoned") = completed;
self
}
pub fn get_step_state(&self, name: &str) -> Option<StepState> {
self.step_states
.read()
.expect("workflow lock poisoned")
.get(name)
.cloned()
}
pub fn is_step_completed(&self, name: &str) -> bool {
self.step_states
.read()
.expect("workflow lock poisoned")
.get(name)
.map(|s| s.status == StepStatus::Completed)
.unwrap_or(false)
}
    /// True once the step has left `Pending` (running, completed, failed, or
    /// compensated).
    pub fn is_step_started(&self, name: &str) -> bool {
self.step_states
.read()
.expect("workflow lock poisoned")
.get(name)
.map(|s| s.status != StepStatus::Pending)
.unwrap_or(false)
}
    /// Returns the cached result of a step, if present and deserializable.
    pub fn get_step_result<T: DeserializeOwned>(&self, name: &str) -> Option<T> {
self.step_states
.read()
.expect("workflow lock poisoned")
.get(name)
.and_then(|s| s.result.as_ref())
.and_then(|v| serde_json::from_value(v.clone()).ok())
}
    /// Records a step as started, persisting asynchronously (fire-and-forget).
    /// Idempotent: a step that already left `Pending` is untouched.
    pub fn record_step_start(&self, name: &str) {
let mut states = self.step_states.write().expect("workflow lock poisoned");
let state = states
.entry(name.to_string())
.or_insert_with(|| StepState::new(name));
        if state.status != StepStatus::Pending {
            // Already started on a prior attempt (e.g. resume); keep the recorded state.
            return;
        }
state.start();
let state_clone = state.clone();
drop(states);
let pool = self.db_pool.clone();
let run_id = self.run_id;
let step_name = name.to_string();
tokio::spawn(async move {
let step_id = Uuid::new_v4();
if let Err(e) = sqlx::query!(
r#"
INSERT INTO forge_workflow_steps (id, workflow_run_id, step_name, status, started_at)
VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (workflow_run_id, step_name) DO NOTHING
"#,
step_id,
run_id,
step_name,
state_clone.status.as_str(),
state_clone.started_at,
)
.execute(&pool)
.await
{
tracing::warn!(
workflow_run_id = %run_id,
step = %step_name,
"Failed to persist step start: {}",
e
);
}
});
}
    /// Marks a step completed; persistence happens on a spawned task.
    pub fn record_step_complete(&self, name: &str, result: serde_json::Value) {
let state_clone = self.update_step_state_complete(name, result);
if let Some(state) = state_clone {
let pool = self.db_pool.clone();
let run_id = self.run_id;
let step_name = name.to_string();
tokio::spawn(async move {
Self::persist_step_complete(&pool, run_id, &step_name, &state).await;
});
}
}
    /// Like `record_step_complete`, but awaits persistence; use when the
    /// completion must be durable before the workflow continues.
    pub async fn record_step_complete_async(&self, name: &str, result: serde_json::Value) {
let state_clone = self.update_step_state_complete(name, result);
if let Some(state) = state_clone {
Self::persist_step_complete(&self.db_pool, self.run_id, name, &state).await;
}
}
fn update_step_state_complete(
&self,
name: &str,
result: serde_json::Value,
) -> Option<StepState> {
let mut states = self.step_states.write().expect("workflow lock poisoned");
        if let Some(state) = states.get_mut(name) {
            state.complete(result);
        }
let state_clone = states.get(name).cloned();
drop(states);
let mut completed = self
.completed_steps
.write()
.expect("workflow lock poisoned");
if !completed.contains(&name.to_string()) {
completed.push(name.to_string());
}
drop(completed);
state_clone
}
async fn persist_step_complete(
pool: &sqlx::PgPool,
run_id: Uuid,
step_name: &str,
state: &StepState,
) {
if let Err(e) = sqlx::query!(
r#"
INSERT INTO forge_workflow_steps (id, workflow_run_id, step_name, status, result, started_at, completed_at)
VALUES (gen_random_uuid(), $1, $2, $3, $4, $5, $6)
ON CONFLICT (workflow_run_id, step_name) DO UPDATE
SET status = $3, result = $4, completed_at = $6
"#,
run_id,
step_name,
state.status.as_str(),
state.result as _,
state.started_at,
state.completed_at,
)
.execute(pool)
.await
{
tracing::warn!(
workflow_run_id = %run_id,
step = %step_name,
"Failed to persist step completion: {}",
e
);
}
}
    /// Marks a step failed and persists the failure on a spawned task (the
    /// step row is assumed to already exist from `record_step_start`).
    pub fn record_step_failure(&self, name: &str, error: impl Into<String>) {
let error_str = error.into();
let mut states = self.step_states.write().expect("workflow lock poisoned");
if let Some(state) = states.get_mut(name) {
state.fail(error_str.clone());
}
let state_clone = states.get(name).cloned();
drop(states);
if let Some(state) = state_clone {
let pool = self.db_pool.clone();
let run_id = self.run_id;
let step_name = name.to_string();
tokio::spawn(async move {
if let Err(e) = sqlx::query!(
r#"
UPDATE forge_workflow_steps
SET status = $3, error = $4, completed_at = $5
WHERE workflow_run_id = $1 AND step_name = $2
"#,
run_id,
step_name,
state.status.as_str(),
state.error as _,
state.completed_at,
)
.execute(&pool)
.await
{
tracing::warn!(
workflow_run_id = %run_id,
step = %step_name,
"Failed to persist step failure: {}",
e
);
}
});
}
}
    /// Marks a step compensated and persists the status change on a spawned task.
    pub fn record_step_compensated(&self, name: &str) {
let mut states = self.step_states.write().expect("workflow lock poisoned");
if let Some(state) = states.get_mut(name) {
state.compensate();
}
let state_clone = states.get(name).cloned();
drop(states);
if let Some(state) = state_clone {
let pool = self.db_pool.clone();
let run_id = self.run_id;
let step_name = name.to_string();
tokio::spawn(async move {
if let Err(e) = sqlx::query!(
r#"
UPDATE forge_workflow_steps
SET status = $3
WHERE workflow_run_id = $1 AND step_name = $2
"#,
run_id,
step_name,
state.status.as_str(),
)
.execute(&pool)
.await
{
tracing::warn!(
workflow_run_id = %run_id,
step = %step_name,
"Failed to persist step compensation: {}",
e
);
}
});
}
}
    /// Completed step names in reverse completion order (compensation order).
    pub fn completed_steps_reversed(&self) -> Vec<String> {
let completed = self.completed_steps.read().expect("workflow lock poisoned");
completed.iter().rev().cloned().collect()
}
pub fn all_step_states(&self) -> HashMap<String, StepState> {
self.step_states
.read()
.expect("workflow lock poisoned")
.clone()
}
pub fn elapsed(&self) -> chrono::Duration {
Utc::now() - self.started_at
}
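    /// Registers a rollback handler for a step; handlers run in reverse
    /// completion order via `run_compensation`. A sketch, where
    /// `cancel_shipment` is a hypothetical rollback action:
    ///
    /// ```ignore
    /// ctx.register_compensation(
    ///     "create_shipment",
    ///     Arc::new(|result: serde_json::Value| {
    ///         Box::pin(async move { cancel_shipment(result).await })
    ///             as Pin<Box<dyn Future<Output = Result<()>> + Send>>
    ///     }),
    /// );
    /// ```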
pub fn register_compensation(&self, step_name: &str, handler: CompensationHandler) {
let mut handlers = self
.compensation_handlers
.write()
.expect("workflow lock poisoned");
handlers.insert(step_name.to_string(), handler);
}
pub fn get_compensation_handler(&self, step_name: &str) -> Option<CompensationHandler> {
self.compensation_handlers
.read()
.expect("workflow lock poisoned")
.get(step_name)
.cloned()
}
pub fn has_compensation(&self, step_name: &str) -> bool {
self.compensation_handlers
.read()
.expect("workflow lock poisoned")
.contains_key(step_name)
}
    /// Runs compensation for all completed steps in reverse completion
    /// order. Steps without a handler are simply marked compensated. Returns
    /// `(step_name, succeeded)` pairs; a failed handler is logged and does
    /// not halt the remaining compensations.
    pub async fn run_compensation(&self) -> Vec<(String, bool)> {
let steps = self.completed_steps_reversed();
let mut results = Vec::new();
for step_name in steps {
let handler = self.get_compensation_handler(&step_name);
let result = self
.get_step_state(&step_name)
.and_then(|s| s.result.clone());
if let Some(handler) = handler {
let step_result = result.unwrap_or(serde_json::Value::Null);
match handler(step_result).await {
Ok(()) => {
self.record_step_compensated(&step_name);
results.push((step_name, true));
}
Err(e) => {
tracing::error!(step = %step_name, error = %e, "Compensation failed");
results.push((step_name, false));
}
}
} else {
self.record_step_compensated(&step_name);
results.push((step_name, true));
}
}
results
}
pub fn compensation_handlers(&self) -> HashMap<String, CompensationHandler> {
self.compensation_handlers
.read()
.expect("workflow lock poisoned")
.clone()
}
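    /// Durably sleeps by suspending the run: the wake time is persisted and
    /// the call returns `Err(ForgeError::WorkflowSuspended)` via
    /// `signal_suspend`. When the run was just resumed from a sleep, the
    /// call returns `Ok(())` immediately instead of suspending again.
    ///
    /// ```ignore
    /// ctx.sleep(Duration::from_secs(24 * 60 * 60)).await?; // suspends here
    /// // On wake the workflow re-executes with restored step states.
    /// ```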
pub async fn sleep(&self, duration: Duration) -> Result<()> {
if self.resumed_from_sleep {
return Ok(());
}
        // Durations outside chrono's representable range collapse to zero,
        // i.e. the workflow wakes immediately rather than erroring.
        let wake_at = Utc::now() + chrono::Duration::from_std(duration).unwrap_or_default();
self.sleep_until(wake_at).await
}
    /// Durably sleeps until an absolute time; times in the past (or a run
    /// just resumed from sleep) return immediately without suspending.
    pub async fn sleep_until(&self, wake_at: DateTime<Utc>) -> Result<()> {
if self.resumed_from_sleep {
return Ok(());
}
if wake_at <= Utc::now() {
return Ok(());
}
self.set_wake_at(wake_at).await?;
self.signal_suspend(SuspendReason::Sleep { wake_at })
.await?;
Ok(())
}
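    /// Waits for an external event correlated to this run's id. A matching
    /// unconsumed event is consumed inline; otherwise the run suspends until
    /// one arrives or the timeout elapses. A sketch, where `PaymentConfirmed`
    /// is a hypothetical payload type:
    ///
    /// ```ignore
    /// let payment: PaymentConfirmed = ctx
    ///     .wait_for_event("payment.confirmed", Some(Duration::from_secs(3600)))
    ///     .await?;
    /// ```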
pub async fn wait_for_event<T: DeserializeOwned>(
&self,
event_name: &str,
timeout: Option<Duration>,
) -> Result<T> {
let correlation_id = self.run_id.to_string();
if self.is_resumed
&& let Some(event) = self
.find_consumed_event(event_name, &correlation_id)
.await?
{
return serde_json::from_value(event.payload.unwrap_or_default())
.map_err(|e| ForgeError::Deserialization(e.to_string()));
}
if let Some(event) = self.try_consume_event(event_name, &correlation_id).await? {
return serde_json::from_value(event.payload.unwrap_or_default())
.map_err(|e| ForgeError::Deserialization(e.to_string()));
}
let timeout_at =
timeout.map(|d| Utc::now() + chrono::Duration::from_std(d).unwrap_or_default());
self.set_waiting_for_event(event_name, timeout_at).await?;
self.signal_suspend(SuspendReason::WaitingEvent {
event_name: event_name.to_string(),
timeout: timeout_at,
})
.await?;
        // In practice `signal_suspend` always returns `WorkflowSuspended`, so
        // this tail only runs if suspension is ever made non-terminal: it
        // re-checks once before reporting a timeout.
        self.try_consume_event(event_name, &correlation_id)
            .await?
            .and_then(|e| e.payload)
            .and_then(|p| serde_json::from_value(p).ok())
            .ok_or_else(|| ForgeError::Timeout(format!("Event '{}' timed out", event_name)))
}
    /// Atomically claims the oldest unconsumed matching event for this run,
    /// using `FOR UPDATE SKIP LOCKED` so concurrent consumers cannot race.
    async fn try_consume_event(
&self,
event_name: &str,
correlation_id: &str,
) -> Result<Option<WorkflowEvent>> {
let result = sqlx::query!(
r#"
UPDATE forge_workflow_events
SET consumed_at = NOW(), consumed_by = $3
WHERE id = (
SELECT id FROM forge_workflow_events
WHERE event_name = $1 AND correlation_id = $2 AND consumed_at IS NULL
ORDER BY created_at ASC LIMIT 1
FOR UPDATE SKIP LOCKED
)
RETURNING id, event_name, correlation_id, payload, created_at
"#,
event_name,
correlation_id,
self.run_id
)
.fetch_optional(&self.db_pool)
.await
.map_err(|e| ForgeError::Database(e.to_string()))?;
Ok(result.map(|row| WorkflowEvent {
id: row.id,
event_name: row.event_name,
correlation_id: row.correlation_id,
payload: row.payload,
created_at: row.created_at,
}))
}
    /// Finds an event this run already consumed, so a resumed run can replay
    /// the same payload to the waiting step.
    async fn find_consumed_event(
&self,
event_name: &str,
correlation_id: &str,
) -> Result<Option<WorkflowEvent>> {
let result = sqlx::query!(
r#"
SELECT id, event_name, correlation_id, payload, created_at
FROM forge_workflow_events
WHERE event_name = $1 AND correlation_id = $2 AND consumed_by = $3
ORDER BY created_at DESC LIMIT 1
"#,
event_name,
correlation_id,
self.run_id
)
.fetch_optional(&self.db_pool)
.await
.map_err(|e| ForgeError::Database(e.to_string()))?;
Ok(result.map(|row| WorkflowEvent {
id: row.id,
event_name: row.event_name,
correlation_id: row.correlation_id,
payload: row.payload,
created_at: row.created_at,
}))
}
    /// Persists the suspension: marks the run as `waiting` with a wake time.
    async fn set_wake_at(&self, wake_at: DateTime<Utc>) -> Result<()> {
sqlx::query!(
r#"
UPDATE forge_workflow_runs
SET status = 'waiting', suspended_at = NOW(), wake_at = $2
WHERE id = $1
"#,
self.run_id,
wake_at,
)
.execute(&self.db_pool)
.await
.map_err(|e| ForgeError::Database(e.to_string()))?;
Ok(())
}
    /// Persists the suspension: marks the run as `waiting` on a named event,
    /// with an optional timeout.
    async fn set_waiting_for_event(
&self,
event_name: &str,
timeout_at: Option<DateTime<Utc>>,
) -> Result<()> {
sqlx::query!(
r#"
UPDATE forge_workflow_runs
SET status = 'waiting', suspended_at = NOW(), waiting_for_event = $2, event_timeout_at = $3
WHERE id = $1
"#,
self.run_id,
event_name,
timeout_at,
)
.execute(&self.db_pool)
.await
.map_err(|e| ForgeError::Database(e.to_string()))?;
Ok(())
}
    /// Notifies the runner (when a suspend channel is attached) and always
    /// returns `Err(ForgeError::WorkflowSuspended)` so the workflow future
    /// unwinds at the suspension point.
    async fn signal_suspend(&self, reason: SuspendReason) -> Result<()> {
if let Some(ref tx) = self.suspend_tx {
tx.send(reason)
.await
.map_err(|_| ForgeError::Internal("Failed to signal suspension".into()))?;
}
Err(ForgeError::WorkflowSuspended)
}
    /// Starts building a set of steps to execute concurrently.
    pub fn parallel(&self) -> ParallelBuilder<'_> {
ParallelBuilder::new(self)
}
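    /// Starts building a named step whose result is cached by name, so a
    /// resumed run can skip work it already finished. A sketch: the exact
    /// execution API lives on `super::StepRunner`; the `.await` below is
    /// illustrative and `create_user` is hypothetical:
    ///
    /// ```ignore
    /// let user: User = ctx
    ///     .step("create_user", || async { create_user().await })
    ///     .await?;
    /// ```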
pub fn step<T, F, Fut>(&self, name: impl Into<String>, f: F) -> super::StepRunner<'_, T>
where
        T: serde::Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
        F: Fn() -> Fut + Send + Sync + 'static,
        Fut: Future<Output = crate::Result<T>> + Send + 'static,
{
super::StepRunner::new(self, name, f)
}
}
impl EnvAccess for WorkflowContext {
fn env_provider(&self) -> &dyn EnvProvider {
self.env_provider.as_ref()
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
use super::*;
#[tokio::test]
async fn test_workflow_context_creation() {
let pool = sqlx::postgres::PgPoolOptions::new()
.max_connections(1)
.connect_lazy("postgres://localhost/nonexistent")
.expect("Failed to create mock pool");
let run_id = Uuid::new_v4();
let ctx = WorkflowContext::new(
run_id,
"test_workflow".to_string(),
pool,
CircuitBreakerClient::with_defaults(reqwest::Client::new()),
);
assert_eq!(ctx.run_id, run_id);
assert_eq!(ctx.workflow_name, "test_workflow");
}
#[tokio::test]
async fn test_step_state_tracking() {
let pool = sqlx::postgres::PgPoolOptions::new()
.max_connections(1)
.connect_lazy("postgres://localhost/nonexistent")
.expect("Failed to create mock pool");
let ctx = WorkflowContext::new(
Uuid::new_v4(),
"test".to_string(),
pool,
CircuitBreakerClient::with_defaults(reqwest::Client::new()),
);
ctx.record_step_start("step1");
assert!(!ctx.is_step_completed("step1"));
ctx.record_step_complete("step1", serde_json::json!({"result": "ok"}));
assert!(ctx.is_step_completed("step1"));
let result: Option<serde_json::Value> = ctx.get_step_result("step1");
assert!(result.is_some());
}
#[test]
fn test_step_state_transitions() {
let mut state = StepState::new("test");
assert_eq!(state.status, StepStatus::Pending);
state.start();
assert_eq!(state.status, StepStatus::Running);
assert!(state.started_at.is_some());
state.complete(serde_json::json!({}));
assert_eq!(state.status, StepStatus::Completed);
assert!(state.completed_at.is_some());
}
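    // A sketch exercising the saga bookkeeping end-to-end. It touches only
    // in-memory state: the lazy pool never connects, so the spawned
    // persistence tasks merely log warnings.
    #[tokio::test]
    async fn test_compensation_runs_in_reverse_order() {
        let pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool");
        let ctx = WorkflowContext::new(
            Uuid::new_v4(),
            "test".to_string(),
            pool,
            CircuitBreakerClient::with_defaults(reqwest::Client::new()),
        );
        for name in ["step1", "step2"] {
            ctx.record_step_start(name);
            ctx.record_step_complete(name, serde_json::json!({}));
            let handler: CompensationHandler = Arc::new(|_payload: serde_json::Value| {
                Box::pin(async { Ok(()) }) as Pin<Box<dyn Future<Output = Result<()>> + Send>>
            });
            ctx.register_compensation(name, handler);
        }
        let results = ctx.run_compensation().await;
        // Handlers run in reverse completion order: step2 before step1.
        assert_eq!(results[0].0, "step2");
        assert_eq!(results[1].0, "step1");
        assert!(results.iter().all(|(_, ok)| *ok));
        assert_eq!(
            ctx.get_step_state("step1").unwrap().status,
            StepStatus::Compensated
        );
    }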
}