use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, RwLock};
use std::time::Duration;
use chrono::{DateTime, Utc};
use serde::de::DeserializeOwned;
use uuid::Uuid;
use super::step::StepStatus;
use super::suspend::{SuspendReason, WorkflowEvent};
use crate::env::{EnvAccess, EnvProvider, RealEnvProvider};
use crate::function::{AuthContext, KvHandle};
use crate::http::CircuitBreakerClient;
use crate::{ForgeError, Result};
/// Panic message used with `.expect(..)` whenever one of the context's `RwLock`/`Mutex`
/// guards turns out to be poisoned (i.e. a thread panicked while holding it).
const LOCK_POISONED: &str = "workflow lock poisoned";
/// Shared, thread-safe async compensation callback registered per step.
///
/// `run_compensation` invokes it with the step's recorded result, or
/// `serde_json::Value::Null` when no result was recorded, and treats `Ok(())` as a
/// successful compensation.
pub type CompensationHandler = Arc<
dyn Fn(serde_json::Value) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync,
>;
/// In-memory record of a single workflow step's progress.
#[derive(Debug, Clone)]
pub struct StepState {
/// Step name; normally the same string the state is keyed under in the context's step map.
pub name: String,
/// Current lifecycle status (`Pending` → `Running` → `Completed`/`Failed`, optionally `Compensated`).
pub status: StepStatus,
/// JSON result stored by `complete`.
pub result: Option<serde_json::Value>,
/// Error message stored by `fail`.
pub error: Option<String>,
/// Set by `start`.
pub started_at: Option<DateTime<Utc>>,
/// Set by `complete` or `fail`.
pub completed_at: Option<DateTime<Utc>>,
}
impl StepState {
    /// Creates a `Pending` step with no result, error, or timestamps.
    pub fn new(name: impl Into<String>) -> Self {
        let name = name.into();
        Self {
            name,
            status: StepStatus::Pending,
            result: None,
            error: None,
            started_at: None,
            completed_at: None,
        }
    }

    /// Moves the step to `Running` and stamps `started_at` with the current time.
    pub fn start(&mut self) {
        self.started_at = Some(Utc::now());
        self.status = StepStatus::Running;
    }

    /// Moves the step to `Completed`, storing `result` and stamping `completed_at`.
    ///
    /// `error` and `started_at` are left as they are.
    pub fn complete(&mut self, result: serde_json::Value) {
        self.completed_at = Some(Utc::now());
        self.result = Some(result);
        self.status = StepStatus::Completed;
    }

    /// Moves the step to `Failed`, storing the error message and stamping `completed_at`.
    ///
    /// Any previously stored `result` is left untouched.
    pub fn fail(&mut self, error: impl Into<String>) {
        self.completed_at = Some(Utc::now());
        self.error = Some(error.into());
        self.status = StepStatus::Failed;
    }

    /// Marks the step `Compensated`; timestamps, result, and error are not modified.
    pub fn compensate(&mut self) {
        self.status = StepStatus::Compensated;
    }
}
/// Execution context handed to a workflow run.
///
/// Tracks per-step state and the completion order used for compensation, holds the
/// database pool and HTTP client, user-saved state, and the reason the run last asked
/// to suspend. The shared maps live behind `Arc<RwLock<..>>` so they can be updated
/// through `&self`.
#[non_exhaustive]
pub struct WorkflowContext {
/// Identifier of this run; also used as the event correlation id in `wait_for_event`.
pub run_id: Uuid,
pub workflow_name: String,
pub started_at: DateTime<Utc>,
/// Initialised to the run's `started_at` by both constructors.
workflow_time: DateTime<Utc>,
pub auth: AuthContext,
db_pool: sqlx::PgPool,
http_client: CircuitBreakerClient,
/// Optional per-request timeout applied by `http()`.
http_timeout: Option<Duration>,
/// Step name → in-memory state.
step_states: Arc<RwLock<HashMap<String, StepState>>>,
/// Names of completed steps in the order they were recorded; reversed for compensation.
completed_steps: Arc<RwLock<Vec<String>>>,
/// Step name → compensation callback.
compensation_handlers: Arc<RwLock<HashMap<String, CompensationHandler>>>,
/// True only for contexts built with `resumed`.
is_resumed: bool,
/// When true, `sleep`/`sleep_until` return immediately.
resumed_from_sleep: bool,
tenant_id: Option<Uuid>,
env_provider: Arc<dyn EnvProvider>,
/// User key/value state written by `save_state` and read by `load_state`.
saved_state: Arc<RwLock<HashMap<String, serde_json::Value>>>,
kv: Option<Arc<dyn KvHandle>>,
/// When true, `record_step_start` also inserts a database row.
persist_step_start: bool,
/// Reason recorded by the most recent suspend request; drained by `take_suspend_reason`.
suspend_reason: Arc<std::sync::Mutex<Option<SuspendReason>>>,
}
impl WorkflowContext {
/// Creates a context for a brand-new (non-resumed) run.
///
/// `started_at` and `workflow_time` are both stamped with the current time. Auth starts
/// unauthenticated, environment access goes through `RealEnvProvider`, and KV, tenant,
/// HTTP timeout, and step-start persistence are unset/disabled until configured.
pub fn new(
    run_id: Uuid,
    workflow_name: String,
    db_pool: sqlx::PgPool,
    http_client: CircuitBreakerClient,
) -> Self {
    let created_at = Utc::now();
    Self {
        run_id,
        workflow_name,
        started_at: created_at,
        workflow_time: created_at,
        auth: AuthContext::unauthenticated(),
        db_pool,
        http_client,
        http_timeout: None,
        step_states: Arc::default(),
        completed_steps: Arc::default(),
        compensation_handlers: Arc::default(),
        is_resumed: false,
        resumed_from_sleep: false,
        tenant_id: None,
        env_provider: Arc::new(RealEnvProvider::new()),
        saved_state: Arc::default(),
        kv: None,
        persist_step_start: false,
        suspend_reason: Arc::default(),
    }
}
pub fn with_persist_step_start(mut self, persist: bool) -> Self {
self.persist_step_start = persist;
self
}
pub fn resumed(
run_id: Uuid,
workflow_name: String,
started_at: DateTime<Utc>,
db_pool: sqlx::PgPool,
http_client: CircuitBreakerClient,
) -> Self {
Self {
run_id,
workflow_name,
started_at,
workflow_time: started_at,
auth: AuthContext::unauthenticated(),
db_pool,
http_client,
http_timeout: None,
step_states: Arc::new(RwLock::new(HashMap::new())),
completed_steps: Arc::new(RwLock::new(Vec::new())),
compensation_handlers: Arc::new(RwLock::new(HashMap::new())),
is_resumed: true,
resumed_from_sleep: false,
tenant_id: None,
env_provider: Arc::new(RealEnvProvider::new()),
saved_state: Arc::new(RwLock::new(HashMap::new())),
kv: None,
persist_step_start: false,
suspend_reason: Arc::new(std::sync::Mutex::new(None)),
}
}
pub fn with_kv(mut self, kv: Arc<dyn KvHandle>) -> Self {
self.kv = Some(kv);
self
}
pub fn kv(&self) -> crate::error::Result<&dyn KvHandle> {
self.kv
.as_deref()
.ok_or_else(|| crate::error::ForgeError::internal("KV store not available"))
}
pub fn with_env_provider(mut self, provider: Arc<dyn EnvProvider>) -> Self {
self.env_provider = provider;
self
}
pub fn with_resumed_from_sleep(mut self) -> Self {
self.resumed_from_sleep = true;
self
}
pub fn with_tenant(mut self, tenant_id: Uuid) -> Self {
self.tenant_id = Some(tenant_id);
self
}
pub fn tenant_id(&self) -> Option<Uuid> {
self.tenant_id
}
pub fn is_resumed(&self) -> bool {
self.is_resumed
}
pub fn workflow_time(&self) -> DateTime<Utc> {
self.workflow_time
}
pub fn db(&self) -> crate::function::ForgeDb {
crate::function::ForgeDb::from_pool(&self.db_pool)
}
/// Returns a pool-backed `DbConn` holding a clone of this context's pool handle.
pub fn db_conn(&self) -> crate::function::DbConn<'_> {
    let pool = self.db_pool.clone();
    crate::function::DbConn::Pool(pool)
}
/// Acquires a dedicated connection from the pool.
///
/// # Errors
/// Propagates the `sqlx` error if a connection cannot be acquired.
pub async fn conn(&self) -> sqlx::Result<crate::function::ForgeConn<'static>> {
    self.db_pool
        .acquire()
        .await
        .map(crate::function::ForgeConn::Pool)
}
/// HTTP client that goes through the circuit breaker and uses the configured timeout, if any.
pub fn http(&self) -> crate::http::HttpClient {
    let timeout = self.http_timeout;
    self.http_client.with_timeout(timeout)
}
/// Underlying `reqwest` client, bypassing the circuit-breaker wrapper.
pub fn raw_http(&self) -> &reqwest::Client {
    let breaker = &self.http_client;
    breaker.inner()
}
pub fn set_http_timeout(&mut self, timeout: Option<Duration>) {
self.http_timeout = timeout;
}
pub fn with_auth(mut self, auth: AuthContext) -> Self {
self.auth = auth;
self
}
/// Replaces all saved key/value state with `state`.
pub fn with_saved_state(self, state: HashMap<String, serde_json::Value>) -> Self {
    {
        let mut saved = self.saved_state.write().expect(LOCK_POISONED);
        *saved = state;
    }
    self
}
/// Serialises `value` to JSON and stores it under `key`, overwriting any previous value.
///
/// # Errors
/// `ForgeError::Serialization` if `value` cannot be converted to JSON.
pub fn save_state(&self, key: &str, value: impl serde::Serialize) -> crate::Result<()> {
    let encoded = serde_json::to_value(value)
        .map_err(|err| crate::ForgeError::Serialization(err.to_string()))?;
    let mut saved = self.saved_state.write().expect(LOCK_POISONED);
    saved.insert(key.to_owned(), encoded);
    Ok(())
}
/// Loads the value saved under `key` and deserialises it as `T`.
///
/// Returns `Ok(None)` when nothing is stored under `key`.
///
/// # Errors
/// `ForgeError::Deserialization` when the stored JSON does not match `T`.
pub fn load_state<T: serde::de::DeserializeOwned>(
    &self,
    key: &str,
) -> crate::Result<Option<T>> {
    let saved = self.saved_state.read().expect(LOCK_POISONED);
    saved
        .get(key)
        .map(|value| {
            // `&serde_json::Value` is itself a deserializer, so there is no need to
            // clone the stored JSON tree before decoding it.
            T::deserialize(value)
                .map_err(|e| crate::ForgeError::Deserialization(e.to_string()))
        })
        .transpose()
}
/// Returns a snapshot of all saved key/value state.
///
/// Despite the name, the entries are cloned, not drained: the context keeps its copy.
pub fn take_saved_state(&self) -> HashMap<String, serde_json::Value> {
    let saved = self.saved_state.read().expect(LOCK_POISONED);
    (*saved).clone()
}
pub fn with_step_states(self, states: HashMap<String, StepState>) -> Self {
let completed: Vec<String> = states
.iter()
.filter(|(_, s)| s.status == StepStatus::Completed)
.map(|(name, _)| name.clone())
.collect();
*self.step_states.write().expect(LOCK_POISONED) = states;
*self.completed_steps.write().expect(LOCK_POISONED) = completed;
self
}
/// Returns a copy of the state recorded for step `name`, if any.
pub fn get_step_state(&self, name: &str) -> Option<StepState> {
    let states = self.step_states.read().expect(LOCK_POISONED);
    states.get(name).cloned()
}
/// True when step `name` is known and its status is `Completed`.
pub fn is_step_completed(&self, name: &str) -> bool {
    self.step_states
        .read()
        .expect(LOCK_POISONED)
        .get(name)
        .is_some_and(|state| state.status == StepStatus::Completed)
}
/// True when step `name` is known and has left `Pending`.
///
/// Any later status (running, completed, failed, compensated) counts as started.
pub fn is_step_started(&self, name: &str) -> bool {
    self.step_states
        .read()
        .expect(LOCK_POISONED)
        .get(name)
        .is_some_and(|state| state.status != StepStatus::Pending)
}
/// Returns the recorded result of step `name`, decoded as `T`.
///
/// Yields `None` in three cases: the step is unknown, it has no result yet, or the stored
/// JSON does not decode into `T`. The decode error is discarded; use `get_step_state` to
/// inspect the raw value.
pub fn get_step_result<T: serde::de::DeserializeOwned>(&self, name: &str) -> Option<T> {
    let states = self.step_states.read().expect(LOCK_POISONED);
    let raw = states.get(name)?.result.as_ref()?;
    // Decode straight from the borrowed JSON instead of cloning the tree first.
    T::deserialize(raw).ok()
}
/// Marks step `name` as running.
///
/// Behavior:
/// - An unknown step is first created as `Pending`.
/// - A step that has already left `Pending` is left untouched, and nothing is written.
/// - When `persist_step_start` is enabled, a `forge_workflow_steps` row is inserted for
///   the start; an existing row for the same `(workflow_run_id, step_name)` is kept as is.
///
/// # Errors
/// `ForgeError::Database` if the insert fails.
pub async fn record_step_start(&self, name: &str) -> crate::Result<()> {
    // Do the in-memory transition in its own scope so the lock guard is released
    // before any `.await`.
    let snapshot = {
        let mut states = self.step_states.write().expect(LOCK_POISONED);
        let entry = states
            .entry(name.to_string())
            .or_insert_with(|| StepState::new(name));
        if entry.status == StepStatus::Pending {
            entry.start();
            Some(entry.clone())
        } else {
            None
        }
    };
    let Some(snapshot) = snapshot else {
        return Ok(());
    };
    if !self.persist_step_start {
        return Ok(());
    }
    sqlx::query!(
        r#"
INSERT INTO forge_workflow_steps (id, workflow_run_id, step_name, status, started_at)
VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (workflow_run_id, step_name) DO NOTHING
"#,
        Uuid::new_v4(),
        self.run_id,
        name,
        snapshot.status.as_str(),
        snapshot.started_at,
    )
    .execute(&self.db_pool)
    .await
    .map_err(crate::ForgeError::Database)?;
    Ok(())
}
/// Marks step `name` as completed with `result`, then persists it.
///
/// The in-memory update happens first, so the step counts as completed even if the
/// database write fails afterwards.
///
/// # Errors
/// The database error from the upsert is returned to the caller.
pub async fn record_step_complete(
    &self,
    name: &str,
    result: serde_json::Value,
) -> crate::Result<()> {
    match self.update_step_state_complete(name, result) {
        Some(state) => {
            Self::persist_step_complete(&self.db_pool, self.run_id, name, &state).await
        }
        None => Ok(()),
    }
}
fn update_step_state_complete(
&self,
name: &str,
result: serde_json::Value,
) -> Option<StepState> {
let mut states = self.step_states.write().expect(LOCK_POISONED);
if let Some(state) = states.get_mut(name) {
state.complete(result.clone());
}
let state_clone = states.get(name).cloned();
drop(states);
let mut completed = self.completed_steps.write().expect(LOCK_POISONED);
if !completed.contains(&name.to_string()) {
completed.push(name.to_string());
}
drop(completed);
state_clone
}
/// Upserts the completed state of `step_name` into `forge_workflow_steps`.
///
/// On conflict with an existing `(workflow_run_id, step_name)` row, only `status`,
/// `result`, and `completed_at` are overwritten; `started_at` keeps its stored value.
///
/// # Errors
/// `ForgeError::Database` if the statement fails.
async fn persist_step_complete(
    pool: &sqlx::PgPool,
    run_id: Uuid,
    step_name: &str,
    state: &StepState,
) -> crate::Result<()> {
    let status = state.status.as_str();
    sqlx::query!(
        r#"
INSERT INTO forge_workflow_steps (id, workflow_run_id, step_name, status, result, started_at, completed_at)
VALUES (gen_random_uuid(), $1, $2, $3, $4, $5, $6)
ON CONFLICT (workflow_run_id, step_name) DO UPDATE
SET status = $3, result = $4, completed_at = $6
"#,
        run_id,
        step_name,
        status,
        state.result as _,
        state.started_at,
        state.completed_at,
    )
    .execute(pool)
    .await
    .map(|_| ())
    .map_err(crate::ForgeError::Database)
}
/// Marks a known step `name` as failed with `error` and writes the failure to its row.
///
/// Unknown steps are ignored, and nothing is written for them.
///
/// NOTE(review): this is an UPDATE only. When no row exists yet, the failure is silently
/// not persisted. That happens, for example, when `persist_step_start` is off and the
/// step never completed. Confirm that this is intended.
///
/// # Errors
/// `ForgeError::Database` if the update fails.
pub async fn record_step_failure(
    &self,
    name: &str,
    error: impl Into<String>,
) -> crate::Result<()> {
    let message: String = error.into();
    let snapshot = {
        let mut states = self.step_states.write().expect(LOCK_POISONED);
        states.get_mut(name).map(|state| {
            state.fail(message);
            state.clone()
        })
    };
    let Some(state) = snapshot else {
        return Ok(());
    };
    sqlx::query!(
        r#"
UPDATE forge_workflow_steps
SET status = $3, error = $4, completed_at = $5
WHERE workflow_run_id = $1 AND step_name = $2
"#,
        self.run_id,
        name,
        state.status.as_str(),
        state.error as _,
        state.completed_at,
    )
    .execute(&self.db_pool)
    .await
    .map_err(crate::ForgeError::Database)?;
    Ok(())
}
/// Marks a known step `name` as compensated and updates the status on its row.
///
/// Unknown steps are ignored, and nothing is written for them.
///
/// # Errors
/// `ForgeError::Database` if the update fails.
pub async fn record_step_compensated(&self, name: &str) -> crate::Result<()> {
    let snapshot = {
        let mut states = self.step_states.write().expect(LOCK_POISONED);
        states.get_mut(name).map(|state| {
            state.compensate();
            state.clone()
        })
    };
    let Some(state) = snapshot else {
        return Ok(());
    };
    sqlx::query!(
        r#"
UPDATE forge_workflow_steps
SET status = $3
WHERE workflow_run_id = $1 AND step_name = $2
"#,
        self.run_id,
        name,
        state.status.as_str(),
    )
    .execute(&self.db_pool)
    .await
    .map_err(crate::ForgeError::Database)?;
    Ok(())
}
/// Names of completed steps, most recently recorded first.
pub fn completed_steps_reversed(&self) -> Vec<String> {
    let mut newest_first = self.completed_steps.read().expect(LOCK_POISONED).clone();
    newest_first.reverse();
    newest_first
}
/// Independent copy of every step state, keyed by step name.
pub fn all_step_states(&self) -> HashMap<String, StepState> {
    let states = self.step_states.read().expect(LOCK_POISONED);
    (*states).clone()
}
/// Wall-clock time since `started_at` (for resumed runs, since the original start).
pub fn elapsed(&self) -> chrono::Duration {
    Utc::now().signed_duration_since(self.started_at)
}
/// Registers `handler` as the compensation for `step_name`, replacing any previous one.
pub fn register_compensation(&self, step_name: &str, handler: CompensationHandler) {
    self.compensation_handlers
        .write()
        .expect(LOCK_POISONED)
        .insert(step_name.to_owned(), handler);
}
pub fn get_compensation_handler(&self, step_name: &str) -> Option<CompensationHandler> {
self.compensation_handlers
.read()
.expect(LOCK_POISONED)
.get(step_name)
.cloned()
}
/// Whether a compensation is registered for `step_name`.
pub fn has_compensation(&self, step_name: &str) -> bool {
    let handlers = self.compensation_handlers.read().expect(LOCK_POISONED);
    handlers.contains_key(step_name)
}
/// Compensates every completed step, walking the recorded completion order backwards.
///
/// For each step:
/// - If a handler is registered, it is called with the step's recorded result, or
///   `Value::Null` when there is none. If it succeeds, the step is marked compensated.
/// - If no handler is registered, the step is marked compensated directly.
///
/// A handler error or a failure to persist the compensated status is logged, and the
/// step is reported as `false`. Processing always continues with the next step.
///
/// Returns `(step_name, succeeded)` pairs in processing order.
pub async fn run_compensation(&self) -> Vec<(String, bool)> {
    let mut outcomes = Vec::new();
    for step_name in self.completed_steps_reversed() {
        let succeeded = match self.get_compensation_handler(&step_name) {
            Some(handler) => {
                let input = self
                    .get_step_state(&step_name)
                    .and_then(|state| state.result)
                    .unwrap_or(serde_json::Value::Null);
                if let Err(e) = handler(input).await {
                    tracing::error!(step = %step_name, error = %e, "Compensation failed");
                    false
                } else if let Err(e) = self.record_step_compensated(&step_name).await {
                    tracing::error!(
                        step = %step_name,
                        error = %e,
                        "Failed to persist step compensation; marking compensation as failed",
                    );
                    false
                } else {
                    true
                }
            }
            None => match self.record_step_compensated(&step_name).await {
                Ok(()) => true,
                Err(e) => {
                    tracing::error!(
                        step = %step_name,
                        error = %e,
                        "Failed to persist step compensation",
                    );
                    false
                }
            },
        };
        outcomes.push((step_name, succeeded));
    }
    outcomes
}
/// Copy of the step-name → handler map. The handlers themselves are shared `Arc`s.
pub fn compensation_handlers(&self) -> HashMap<String, CompensationHandler> {
    let handlers = self.compensation_handlers.read().expect(LOCK_POISONED);
    (*handlers).clone()
}
pub async fn sleep(&self, duration: Duration) -> Result<()> {
if self.resumed_from_sleep {
return Ok(());
}
let wake_at = Utc::now() + chrono::Duration::from_std(duration).unwrap_or_default();
self.sleep_until(wake_at).await
}
/// Suspends the workflow until `wake_at`.
///
/// Returns `Ok(())` without suspending when either:
/// - the context was resumed from a sleep, or
/// - `wake_at` is not in the future.
///
/// Otherwise the run is marked sleeping with `wake_at`, and the
/// `ForgeError::WorkflowSuspended` error from `signal_suspend` is returned.
pub async fn sleep_until(&self, wake_at: DateTime<Utc>) -> Result<()> {
    if self.resumed_from_sleep || wake_at <= Utc::now() {
        return Ok(());
    }
    self.set_wake_at(wake_at).await?;
    self.signal_suspend(SuspendReason::Sleep { wake_at }).await
}
pub async fn wait_for_event<T: DeserializeOwned>(
&self,
event_name: &str,
timeout: Option<Duration>,
) -> Result<T> {
let correlation_id = self.run_id.to_string();
if self.is_resumed
&& let Some(event) = self
.find_consumed_event(event_name, &correlation_id)
.await?
{
return serde_json::from_value(event.payload.unwrap_or_default())
.map_err(|e| ForgeError::Deserialization(e.to_string()));
}
if let Some(event) = self.try_consume_event(event_name, &correlation_id).await? {
return serde_json::from_value(event.payload.unwrap_or_default())
.map_err(|e| ForgeError::Deserialization(e.to_string()));
}
let timeout_at =
timeout.map(|d| Utc::now() + chrono::Duration::from_std(d).unwrap_or_default());
self.set_waiting_for_event(event_name, timeout_at).await?;
self.signal_suspend(SuspendReason::WaitingEvent {
event_name: event_name.to_string(),
timeout: timeout_at,
})
.await?;
self.try_consume_event(event_name, &correlation_id)
.await?
.and_then(|e| e.payload)
.and_then(|p| serde_json::from_value(p).ok())
.ok_or_else(|| ForgeError::Timeout(format!("Event '{}' timed out", event_name)))
}
/// Claims the oldest unconsumed event named `event_name` for `correlation_id`.
///
/// The claimed row is stamped with `consumed_at = NOW()` and `consumed_by = run_id`.
/// `FOR UPDATE SKIP LOCKED` skips candidate rows locked by another transaction.
/// Returns `Ok(None)` when no matching event is available.
///
/// # Errors
/// `ForgeError::Database` on query failure.
// NOTE(review): the reason for this allow is not visible here; presumably the
// `query!`-generated record type — confirm before removing.
#[allow(clippy::type_complexity)]
async fn try_consume_event(
    &self,
    event_name: &str,
    correlation_id: &str,
) -> Result<Option<WorkflowEvent>> {
    let claimed = sqlx::query!(
        r#"
UPDATE forge_workflow_events
SET consumed_at = NOW(), consumed_by = $3
WHERE id = (
SELECT id FROM forge_workflow_events
WHERE event_name = $1 AND correlation_id = $2 AND consumed_at IS NULL
ORDER BY created_at ASC LIMIT 1
FOR UPDATE SKIP LOCKED
)
RETURNING id, event_name, correlation_id, payload, created_at
"#,
        event_name,
        correlation_id,
        self.run_id
    )
    .fetch_optional(&self.db_pool)
    .await
    .map_err(ForgeError::Database)?;
    let Some(row) = claimed else {
        return Ok(None);
    };
    Ok(Some(WorkflowEvent {
        id: row.id,
        event_name: row.event_name,
        correlation_id: row.correlation_id,
        payload: row.payload,
        created_at: row.created_at,
    }))
}
/// Finds the most recent event named `event_name` for `correlation_id` that this run
/// already consumed (`consumed_by = run_id`). Used to replay events on resumed runs.
///
/// # Errors
/// `ForgeError::Database` on query failure.
async fn find_consumed_event(
    &self,
    event_name: &str,
    correlation_id: &str,
) -> Result<Option<WorkflowEvent>> {
    let found = sqlx::query!(
        r#"
SELECT id, event_name, correlation_id, payload, created_at
FROM forge_workflow_events
WHERE event_name = $1 AND correlation_id = $2 AND consumed_by = $3
ORDER BY created_at DESC LIMIT 1
"#,
        event_name,
        correlation_id,
        self.run_id
    )
    .fetch_optional(&self.db_pool)
    .await
    .map_err(ForgeError::Database)?;
    let Some(row) = found else {
        return Ok(None);
    };
    Ok(Some(WorkflowEvent {
        id: row.id,
        event_name: row.event_name,
        correlation_id: row.correlation_id,
        payload: row.payload,
        created_at: row.created_at,
    }))
}
/// Marks the run as `sleeping` until `wake_at`, then sends a best-effort
/// `forge_workflow_wakeup` NOTIFY carrying the run id.
///
/// A failed NOTIFY is only logged at debug level. Per the log message, the scheduler
/// falls back to polling.
///
/// # Errors
/// `ForgeError::Database` if the status update fails.
async fn set_wake_at(&self, wake_at: DateTime<Utc>) -> Result<()> {
    sqlx::query!(
        r#"
UPDATE forge_workflow_runs
SET status = 'sleeping', suspended_at = NOW(), wake_at = $2
WHERE id = $1
"#,
        self.run_id,
        wake_at,
    )
    .execute(&self.db_pool)
    .await
    .map_err(ForgeError::Database)?;

    #[allow(clippy::disallowed_methods)]
    let notified = sqlx::query("SELECT pg_notify('forge_workflow_wakeup', $1::text)")
        .bind(self.run_id.to_string())
        .execute(&self.db_pool)
        .await;
    if let Err(e) = notified {
        tracing::debug!(
            workflow_run_id = %self.run_id,
            error = %e,
            "Failed to send workflow wakeup notify (scheduler will poll)",
        );
    }
    Ok(())
}
/// Marks the run as `waiting` for `event_name`, with an optional timeout deadline.
///
/// # Errors
/// `ForgeError::Database` if the update fails.
async fn set_waiting_for_event(
    &self,
    event_name: &str,
    timeout_at: Option<DateTime<Utc>>,
) -> Result<()> {
    sqlx::query!(
        r#"
UPDATE forge_workflow_runs
SET status = 'waiting', suspended_at = NOW(), waiting_for_event = $2, event_timeout_at = $3
WHERE id = $1
"#,
        self.run_id,
        event_name,
        timeout_at,
    )
    .execute(&self.db_pool)
    .await
    .map(|_| ())
    .map_err(ForgeError::Database)
}
/// Records `reason` for `take_suspend_reason` and returns it as
/// `ForgeError::WorkflowSuspended`.
///
/// This never returns `Ok`: callers propagate the error with `?` to unwind out of the
/// workflow body.
async fn signal_suspend(&self, reason: SuspendReason) -> Result<()> {
    {
        let mut slot = self.suspend_reason.lock().expect(LOCK_POISONED);
        *slot = Some(reason.clone());
    }
    Err(ForgeError::WorkflowSuspended(reason))
}
/// Removes and returns the most recently recorded suspend reason, leaving `None` behind.
pub fn take_suspend_reason(&self) -> Option<SuspendReason> {
    std::mem::take(&mut *self.suspend_reason.lock().expect(LOCK_POISONED))
}
}
impl EnvAccess for WorkflowContext {
    /// Environment provider in use: `RealEnvProvider` unless replaced via `with_env_provider`.
    fn env_provider(&self) -> &dyn EnvProvider {
        &*self.env_provider
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// Lazily-connecting pool aimed at a database that does not exist.
    ///
    /// Construction succeeds because no connection is opened up front. Code paths that
    /// touch the database are expected to fail with a database error.
    fn mock_pool() -> sqlx::PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .acquire_timeout(std::time::Duration::from_millis(1))
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool")
    }

    /// Fresh context with the given run id and workflow name over `mock_pool`.
    fn ctx_for(run_id: Uuid, workflow_name: &str) -> WorkflowContext {
        WorkflowContext::new(
            run_id,
            workflow_name.to_string(),
            mock_pool(),
            CircuitBreakerClient::with_defaults(reqwest::Client::new()),
        )
    }

    /// Fresh context named "test" with a random run id.
    fn lazy_ctx() -> WorkflowContext {
        ctx_for(Uuid::new_v4(), "test")
    }

    #[tokio::test]
    async fn test_workflow_context_creation() {
        let run_id = Uuid::new_v4();
        let ctx = ctx_for(run_id, "test_workflow");
        assert_eq!(ctx.run_id, run_id);
        assert_eq!(ctx.workflow_name, "test_workflow");
    }

    #[tokio::test]
    async fn test_step_state_tracking() {
        let ctx = ctx_for(Uuid::new_v4(), "test");
        ctx.record_step_start("step1")
            .await
            .expect("record_step_start should not touch db when persist disabled");
        assert!(!ctx.is_step_completed("step1"));

        let complete_err = ctx
            .record_step_complete("step1", serde_json::json!({"result": "ok"}))
            .await
            .expect_err("record_step_complete should propagate db errors");
        assert!(
            matches!(complete_err, crate::ForgeError::Database(_)),
            "expected Database error, got {complete_err:?}",
        );

        // The in-memory transition happens before the (failing) write.
        assert!(ctx.is_step_completed("step1"));
        let result: Option<serde_json::Value> = ctx.get_step_result("step1");
        assert!(result.is_some());
    }

    #[test]
    fn test_step_state_transitions() {
        let mut state = StepState::new("test");
        assert_eq!(state.status, StepStatus::Pending);

        state.start();
        assert_eq!(state.status, StepStatus::Running);
        assert!(state.started_at.is_some());

        state.complete(serde_json::json!({}));
        assert_eq!(state.status, StepStatus::Completed);
        assert!(state.completed_at.is_some());
    }

    #[test]
    fn step_state_fail_records_error_and_completion() {
        let mut state = StepState::new("step");
        state.start();
        state.fail("boom");
        assert_eq!(state.status, StepStatus::Failed);
        assert_eq!(state.error.as_deref(), Some("boom"));
        assert!(state.completed_at.is_some());
    }

    #[test]
    fn step_state_compensate_only_flips_status() {
        let mut state = StepState::new("step");
        state.complete(serde_json::json!({"ok": true}));
        let before = state.completed_at;
        state.compensate();
        assert_eq!(state.status, StepStatus::Compensated);
        assert_eq!(state.completed_at, before);
    }

    #[tokio::test]
    async fn save_state_and_load_state_round_trip() {
        let ctx = lazy_ctx();
        ctx.save_state("count", 42_u32).unwrap();
        let loaded: Option<u32> = ctx.load_state("count").unwrap();
        assert_eq!(loaded, Some(42));
    }

    #[tokio::test]
    async fn load_state_returns_none_for_unknown_key() {
        let ctx = lazy_ctx();
        let loaded: Option<String> = ctx.load_state("missing").unwrap();
        assert!(loaded.is_none());
    }

    #[tokio::test]
    async fn load_state_returns_deserialization_error_on_type_mismatch() {
        let ctx = lazy_ctx();
        ctx.save_state("k", "a string").unwrap();
        let err = ctx.load_state::<u32>("k").unwrap_err();
        assert!(matches!(err, ForgeError::Deserialization(_)));
    }

    #[tokio::test]
    async fn take_saved_state_returns_snapshot_of_all_entries() {
        let ctx = lazy_ctx();
        ctx.save_state("a", 1_u32).unwrap();
        ctx.save_state("b", "two").unwrap();
        let snapshot = ctx.take_saved_state();
        assert_eq!(snapshot.len(), 2);
        assert_eq!(snapshot.get("a"), Some(&serde_json::json!(1)));
        assert_eq!(snapshot.get("b"), Some(&serde_json::json!("two")));
    }

    #[tokio::test]
    async fn tenant_id_defaults_to_none_and_with_tenant_sets_it() {
        let ctx = lazy_ctx();
        assert!(ctx.tenant_id().is_none());
        let tenant = Uuid::new_v4();
        let ctx = ctx.with_tenant(tenant);
        assert_eq!(ctx.tenant_id(), Some(tenant));
    }

    #[tokio::test]
    async fn is_resumed_defaults_to_false() {
        assert!(!lazy_ctx().is_resumed());
    }

    #[tokio::test]
    async fn is_step_completed_and_started_return_false_for_unknown_steps() {
        let ctx = lazy_ctx();
        assert!(!ctx.is_step_completed("nope"));
        assert!(!ctx.is_step_started("nope"));
    }

    #[tokio::test]
    async fn get_step_result_returns_none_for_unknown_step() {
        let ctx = lazy_ctx();
        let result: Option<serde_json::Value> = ctx.get_step_result("nope");
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn with_step_states_rebuilds_completed_steps_from_status() {
        let mut done = StepState::new("done");
        done.complete(serde_json::json!({"v": 1}));
        let states = HashMap::from([
            ("done".to_string(), done),
            ("pending".to_string(), StepState::new("pending")),
        ]);
        let ctx = lazy_ctx().with_step_states(states);
        assert!(ctx.is_step_completed("done"));
        assert!(!ctx.is_step_completed("pending"));
        assert_eq!(ctx.completed_steps_reversed(), vec!["done".to_string()]);
    }

    #[tokio::test]
    async fn completed_steps_reversed_is_empty_initially() {
        assert!(lazy_ctx().completed_steps_reversed().is_empty());
    }

    #[tokio::test]
    async fn elapsed_is_non_negative() {
        let elapsed = lazy_ctx().elapsed();
        assert!(elapsed.num_milliseconds() >= 0);
    }

    #[tokio::test]
    async fn register_and_has_compensation_round_trip() {
        let ctx = lazy_ctx();
        assert!(!ctx.has_compensation("step1"));
        let handler: CompensationHandler =
            Arc::new(|_v| Box::pin(async { Ok::<(), ForgeError>(()) }));
        ctx.register_compensation("step1", handler);
        assert!(ctx.has_compensation("step1"));
        assert!(ctx.get_compensation_handler("step1").is_some());
        assert!(ctx.get_compensation_handler("step2").is_none());
    }

    #[tokio::test]
    async fn all_step_states_returns_independent_clone() {
        let states = HashMap::from([("a".to_string(), StepState::new("a"))]);
        let ctx = lazy_ctx().with_step_states(states);
        let snapshot = ctx.all_step_states();
        assert_eq!(snapshot.len(), 1);
        assert!(snapshot.contains_key("a"));
    }
}