Skip to main content

forge_core/workflow/
context.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::{Arc, RwLock};
5use std::time::Duration;
6
7use chrono::{DateTime, Utc};
8use serde::de::DeserializeOwned;
9use uuid::Uuid;
10
11use super::step::StepStatus;
12use super::suspend::{SuspendReason, WorkflowEvent};
13use crate::env::{EnvAccess, EnvProvider, RealEnvProvider};
14use crate::function::{AuthContext, KvHandle};
15use crate::http::CircuitBreakerClient;
16use crate::{ForgeError, Result};
17
18const LOCK_POISONED: &str = "workflow lock poisoned";
19
20/// Type alias for compensation handler function.
21pub type CompensationHandler = Arc<
22    dyn Fn(serde_json::Value) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync,
23>;
24
25/// Step state stored during execution.
26#[derive(Debug, Clone)]
27pub struct StepState {
28    /// Step name.
29    pub name: String,
30    /// Step status.
31    pub status: StepStatus,
32    /// Serialized result (if completed).
33    pub result: Option<serde_json::Value>,
34    /// Error message (if failed).
35    pub error: Option<String>,
36    /// When the step started.
37    pub started_at: Option<DateTime<Utc>>,
38    /// When the step completed.
39    pub completed_at: Option<DateTime<Utc>>,
40}
41
42impl StepState {
43    /// Create a new pending step state.
44    pub fn new(name: impl Into<String>) -> Self {
45        Self {
46            name: name.into(),
47            status: StepStatus::Pending,
48            result: None,
49            error: None,
50            started_at: None,
51            completed_at: None,
52        }
53    }
54
55    /// Mark step as running.
56    pub fn start(&mut self) {
57        self.status = StepStatus::Running;
58        self.started_at = Some(Utc::now());
59    }
60
61    /// Mark step as completed with result.
62    pub fn complete(&mut self, result: serde_json::Value) {
63        self.status = StepStatus::Completed;
64        self.result = Some(result);
65        self.completed_at = Some(Utc::now());
66    }
67
68    /// Mark step as failed with error.
69    pub fn fail(&mut self, error: impl Into<String>) {
70        self.status = StepStatus::Failed;
71        self.error = Some(error.into());
72        self.completed_at = Some(Utc::now());
73    }
74
75    /// Mark step as compensated.
76    pub fn compensate(&mut self) {
77        self.status = StepStatus::Compensated;
78    }
79}
80
81/// Context available to workflow handlers.
82#[non_exhaustive]
83pub struct WorkflowContext {
84    /// Workflow run ID.
85    pub run_id: Uuid,
86    /// Workflow name.
87    pub workflow_name: String,
88    /// When the workflow started.
89    pub started_at: DateTime<Utc>,
90    /// Deterministic workflow time (consistent across replays).
91    workflow_time: DateTime<Utc>,
92    /// Authentication context.
93    pub auth: AuthContext,
94    /// Database pool.
95    db_pool: sqlx::PgPool,
96    /// HTTP client.
97    http_client: CircuitBreakerClient,
98    /// Default timeout for outbound HTTP requests made through the
99    /// circuit-breaker client. `None` means unlimited.
100    http_timeout: Option<Duration>,
101    step_states: Arc<RwLock<HashMap<String, StepState>>>,
102    /// Ordered list of completed step names, used to drive compensation in reverse.
103    completed_steps: Arc<RwLock<Vec<String>>>,
104    compensation_handlers: Arc<RwLock<HashMap<String, CompensationHandler>>>,
105    is_resumed: bool,
106    resumed_from_sleep: bool,
107    tenant_id: Option<Uuid>,
108    env_provider: Arc<dyn EnvProvider>,
109    /// Persists across suspension points.
110    saved_state: Arc<RwLock<HashMap<String, serde_json::Value>>>,
111    kv: Option<Arc<dyn KvHandle>>,
112    /// When `false` (default), `record_step_start` skips the DB write;
113    /// `persist_step_complete`'s upsert covers the missing start row, saving
114    /// one roundtrip per step. Set `true` for long-running steps where
115    /// in-progress observability matters.
116    persist_step_start: bool,
117    /// Set by `signal_suspend()` so the executor can read the reason without
118    /// a `ForgeError` variant as a side-channel.
119    suspend_reason: Arc<std::sync::Mutex<Option<SuspendReason>>>,
120}
121
122impl WorkflowContext {
123    /// Create a new workflow context.
124    pub fn new(
125        run_id: Uuid,
126        workflow_name: String,
127        db_pool: sqlx::PgPool,
128        http_client: CircuitBreakerClient,
129    ) -> Self {
130        let now = Utc::now();
131        Self {
132            run_id,
133            workflow_name,
134            started_at: now,
135            workflow_time: now,
136            auth: AuthContext::unauthenticated(),
137            db_pool,
138            http_client,
139            http_timeout: None,
140            step_states: Arc::new(RwLock::new(HashMap::new())),
141            completed_steps: Arc::new(RwLock::new(Vec::new())),
142            compensation_handlers: Arc::new(RwLock::new(HashMap::new())),
143
144            is_resumed: false,
145            resumed_from_sleep: false,
146            tenant_id: None,
147            env_provider: Arc::new(RealEnvProvider::new()),
148            saved_state: Arc::new(RwLock::new(HashMap::new())),
149            kv: None,
150            persist_step_start: false,
151            suspend_reason: Arc::new(std::sync::Mutex::new(None)),
152        }
153    }
154
155    /// Enable DB writes for `record_step_start`. By default, only
156    /// `record_step_complete` writes to the database (its upsert handles
157    /// the missing start row). Enable this for long-running steps where
158    /// observing in-progress state is important.
159    pub fn with_persist_step_start(mut self, persist: bool) -> Self {
160        self.persist_step_start = persist;
161        self
162    }
163
164    /// Create a resumed workflow context.
165    pub fn resumed(
166        run_id: Uuid,
167        workflow_name: String,
168        started_at: DateTime<Utc>,
169        db_pool: sqlx::PgPool,
170        http_client: CircuitBreakerClient,
171    ) -> Self {
172        Self {
173            run_id,
174            workflow_name,
175            started_at,
176            workflow_time: started_at,
177            auth: AuthContext::unauthenticated(),
178            db_pool,
179            http_client,
180            http_timeout: None,
181            step_states: Arc::new(RwLock::new(HashMap::new())),
182            completed_steps: Arc::new(RwLock::new(Vec::new())),
183            compensation_handlers: Arc::new(RwLock::new(HashMap::new())),
184
185            is_resumed: true,
186            resumed_from_sleep: false,
187            tenant_id: None,
188            env_provider: Arc::new(RealEnvProvider::new()),
189            saved_state: Arc::new(RwLock::new(HashMap::new())),
190            kv: None,
191            persist_step_start: false,
192            suspend_reason: Arc::new(std::sync::Mutex::new(None)),
193        }
194    }
195
196    /// Attach a KV store handle. Called by the runtime before handing the
197    /// context to the handler.
198    pub fn with_kv(mut self, kv: Arc<dyn KvHandle>) -> Self {
199        self.kv = Some(kv);
200        self
201    }
202
203    /// Access the KV store.
204    pub fn kv(&self) -> crate::error::Result<&dyn KvHandle> {
205        self.kv
206            .as_deref()
207            .ok_or_else(|| crate::error::ForgeError::internal("KV store not available"))
208    }
209
210    /// Set environment provider.
211    pub fn with_env_provider(mut self, provider: Arc<dyn EnvProvider>) -> Self {
212        self.env_provider = provider;
213        self
214    }
215
216    /// Mark that this context resumed from a sleep (timer expired).
217    pub fn with_resumed_from_sleep(mut self) -> Self {
218        self.resumed_from_sleep = true;
219        self
220    }
221
222    /// Set the tenant ID.
223    pub fn with_tenant(mut self, tenant_id: Uuid) -> Self {
224        self.tenant_id = Some(tenant_id);
225        self
226    }
227
228    pub fn tenant_id(&self) -> Option<Uuid> {
229        self.tenant_id
230    }
231
232    pub fn is_resumed(&self) -> bool {
233        self.is_resumed
234    }
235
236    pub fn workflow_time(&self) -> DateTime<Utc> {
237        self.workflow_time
238    }
239
240    pub fn db(&self) -> crate::function::ForgeDb {
241        crate::function::ForgeDb::from_pool(&self.db_pool)
242    }
243
244    /// Get a `DbConn` for use in shared helper functions.
245    pub fn db_conn(&self) -> crate::function::DbConn<'_> {
246        crate::function::DbConn::Pool(self.db_pool.clone())
247    }
248
249    /// Acquire a connection compatible with sqlx compile-time checked macros.
250    pub async fn conn(&self) -> sqlx::Result<crate::function::ForgeConn<'static>> {
251        Ok(crate::function::ForgeConn::Pool(
252            self.db_pool.acquire().await?,
253        ))
254    }
255
256    pub fn http(&self) -> crate::http::HttpClient {
257        self.http_client.with_timeout(self.http_timeout)
258    }
259
260    pub fn raw_http(&self) -> &reqwest::Client {
261        self.http_client.inner()
262    }
263
264    pub fn set_http_timeout(&mut self, timeout: Option<Duration>) {
265        self.http_timeout = timeout;
266    }
267
268    /// Set authentication context.
269    pub fn with_auth(mut self, auth: AuthContext) -> Self {
270        self.auth = auth;
271        self
272    }
273
274    /// Restore saved state from persisted data (used on resume).
275    pub fn with_saved_state(self, state: HashMap<String, serde_json::Value>) -> Self {
276        *self.saved_state.write().expect(LOCK_POISONED) = state;
277        self
278    }
279
280    /// Save arbitrary state that persists across suspension points.
281    pub fn save_state(&self, key: &str, value: impl serde::Serialize) -> crate::Result<()> {
282        let json = serde_json::to_value(value)
283            .map_err(|e| crate::ForgeError::Serialization(e.to_string()))?;
284        self.saved_state
285            .write()
286            .expect(LOCK_POISONED)
287            .insert(key.to_string(), json);
288        Ok(())
289    }
290
291    /// Load previously saved state. Returns `None` if the key doesn't exist.
292    pub fn load_state<T: serde::de::DeserializeOwned>(
293        &self,
294        key: &str,
295    ) -> crate::Result<Option<T>> {
296        let guard = self.saved_state.read().expect(LOCK_POISONED);
297        match guard.get(key) {
298            Some(value) => {
299                let result = serde_json::from_value(value.clone())
300                    .map_err(|e| crate::ForgeError::Deserialization(e.to_string()))?;
301                Ok(Some(result))
302            }
303            None => Ok(None),
304        }
305    }
306
307    /// Get a snapshot of all saved state for persistence.
308    pub fn take_saved_state(&self) -> HashMap<String, serde_json::Value> {
309        self.saved_state.read().expect(LOCK_POISONED).clone()
310    }
311
312    /// Restore step states from persisted data.
313    pub fn with_step_states(self, states: HashMap<String, StepState>) -> Self {
314        let completed: Vec<String> = states
315            .iter()
316            .filter(|(_, s)| s.status == StepStatus::Completed)
317            .map(|(name, _)| name.clone())
318            .collect();
319
320        *self.step_states.write().expect(LOCK_POISONED) = states;
321        *self.completed_steps.write().expect(LOCK_POISONED) = completed;
322        self
323    }
324
325    pub fn get_step_state(&self, name: &str) -> Option<StepState> {
326        self.step_states
327            .read()
328            .expect(LOCK_POISONED)
329            .get(name)
330            .cloned()
331    }
332
333    pub fn is_step_completed(&self, name: &str) -> bool {
334        self.step_states
335            .read()
336            .expect(LOCK_POISONED)
337            .get(name)
338            .map(|s| s.status == StepStatus::Completed)
339            .unwrap_or(false)
340    }
341
342    /// Check if a step has been started (running, completed, or failed).
343    ///
344    /// Use this to guard steps that should only execute once, even across
345    /// workflow suspension and resumption.
346    pub fn is_step_started(&self, name: &str) -> bool {
347        self.step_states
348            .read()
349            .expect(LOCK_POISONED)
350            .get(name)
351            .map(|s| s.status != StepStatus::Pending)
352            .unwrap_or(false)
353    }
354
355    pub fn get_step_result<T: serde::de::DeserializeOwned>(&self, name: &str) -> Option<T> {
356        self.step_states
357            .read()
358            .expect(LOCK_POISONED)
359            .get(name)
360            .and_then(|s| s.result.as_ref())
361            .and_then(|v| serde_json::from_value(v.clone()).ok())
362    }
363
364    /// Record step start and persist to the database before returning.
365    ///
366    /// If the step is already running or beyond (completed/failed), this is a no-op.
367    ///
368    /// **`name` is part of the workflow's persisted contract.** The `#[workflow]` macro
369    /// hashes every step name (along with wait keys, timeout, and type names) into a
370    /// signature stored at run creation. If you rename a step, the next deploy produces
371    /// a different signature, and any in-flight run that tries to resume will be blocked
372    /// with `WorkflowStatus::BlockedSignatureMismatch`. Treat step names as stable
373    /// public identifiers — change them only under a new workflow version.
374    ///
375    /// Persistence errors are propagated. A swallowed error here would let
376    /// the workflow continue running while its on-disk state diverged from
377    /// memory, producing a "completed" run with no recorded step rows.
378    pub async fn record_step_start(&self, name: &str) -> crate::Result<()> {
379        let state_clone = {
380            let mut states = self.step_states.write().expect(LOCK_POISONED);
381            let state = states
382                .entry(name.to_string())
383                .or_insert_with(|| StepState::new(name));
384
385            if state.status != StepStatus::Pending {
386                return Ok(());
387            }
388
389            state.start();
390            state.clone()
391        };
392
393        if !self.persist_step_start {
394            return Ok(());
395        }
396
397        let step_id = Uuid::new_v4();
398        let step_name = name.to_string();
399        sqlx::query!(
400            r#"
401                INSERT INTO forge_workflow_steps (id, workflow_run_id, step_name, status, started_at)
402                VALUES ($1, $2, $3, $4, $5)
403                ON CONFLICT (workflow_run_id, step_name) DO NOTHING
404                "#,
405            step_id,
406            self.run_id,
407            step_name,
408            state_clone.status.as_str(),
409            state_clone.started_at,
410        )
411        .execute(&self.db_pool)
412        .await
413        .map_err(crate::ForgeError::Database)?;
414        Ok(())
415    }
416
417    /// Record step completion and persist to the database before returning.
418    ///
419    /// Errors from the persist call are propagated so the workflow can react
420    /// rather than continuing past a step the database never observed.
421    pub async fn record_step_complete(
422        &self,
423        name: &str,
424        result: serde_json::Value,
425    ) -> crate::Result<()> {
426        let state_clone = self.update_step_state_complete(name, result);
427
428        if let Some(state) = state_clone {
429            Self::persist_step_complete(&self.db_pool, self.run_id, name, &state).await?;
430        }
431        Ok(())
432    }
433
434    /// Update in-memory step state to completed.
435    fn update_step_state_complete(
436        &self,
437        name: &str,
438        result: serde_json::Value,
439    ) -> Option<StepState> {
440        let mut states = self.step_states.write().expect(LOCK_POISONED);
441        if let Some(state) = states.get_mut(name) {
442            state.complete(result.clone());
443        }
444        let state_clone = states.get(name).cloned();
445        drop(states);
446
447        let mut completed = self.completed_steps.write().expect(LOCK_POISONED);
448        if !completed.contains(&name.to_string()) {
449            completed.push(name.to_string());
450        }
451        drop(completed);
452
453        state_clone
454    }
455
456    /// Persist step completion to database.
457    async fn persist_step_complete(
458        pool: &sqlx::PgPool,
459        run_id: Uuid,
460        step_name: &str,
461        state: &StepState,
462    ) -> crate::Result<()> {
463        // Use UPSERT to handle race condition where persist_step_start hasn't completed yet
464        sqlx::query!(
465            r#"
466            INSERT INTO forge_workflow_steps (id, workflow_run_id, step_name, status, result, started_at, completed_at)
467            VALUES (gen_random_uuid(), $1, $2, $3, $4, $5, $6)
468            ON CONFLICT (workflow_run_id, step_name) DO UPDATE
469            SET status = $3, result = $4, completed_at = $6
470            "#,
471            run_id,
472            step_name,
473            state.status.as_str(),
474            state.result as _,
475            state.started_at,
476            state.completed_at,
477        )
478        .execute(pool)
479        .await
480        .map_err(crate::ForgeError::Database)?;
481        Ok(())
482    }
483
484    /// Record step failure and persist to the database before returning.
485    ///
486    /// Errors from the persist call are propagated so the workflow doesn't
487    /// declare a step "failed" only in memory while the row still claims it
488    /// is running.
489    pub async fn record_step_failure(
490        &self,
491        name: &str,
492        error: impl Into<String>,
493    ) -> crate::Result<()> {
494        let error_str = error.into();
495        let state_clone = {
496            let mut states = self.step_states.write().expect(LOCK_POISONED);
497            if let Some(state) = states.get_mut(name) {
498                state.fail(error_str.clone());
499            }
500            states.get(name).cloned()
501        };
502
503        if let Some(state) = state_clone {
504            let step_name = name.to_string();
505            sqlx::query!(
506                r#"
507                    UPDATE forge_workflow_steps
508                    SET status = $3, error = $4, completed_at = $5
509                    WHERE workflow_run_id = $1 AND step_name = $2
510                    "#,
511                self.run_id,
512                step_name,
513                state.status.as_str(),
514                state.error as _,
515                state.completed_at,
516            )
517            .execute(&self.db_pool)
518            .await
519            .map_err(crate::ForgeError::Database)?;
520        }
521        Ok(())
522    }
523
524    /// Record step compensation and persist to the database before returning.
525    ///
526    /// Persistence is inline (not `tokio::spawn`'d): if the process crashes
527    /// after the in-memory state changes but before the row update lands, a
528    /// later resume would see the step as still completed and re-run its
529    /// compensation handler. Inline await ties durability to the caller and
530    /// surfaces failures through `Result`.
531    pub async fn record_step_compensated(&self, name: &str) -> crate::Result<()> {
532        let state_clone = {
533            let mut states = self.step_states.write().expect(LOCK_POISONED);
534            if let Some(state) = states.get_mut(name) {
535                state.compensate();
536            }
537            states.get(name).cloned()
538        };
539
540        if let Some(state) = state_clone {
541            let step_name = name.to_string();
542            sqlx::query!(
543                r#"
544                    UPDATE forge_workflow_steps
545                    SET status = $3
546                    WHERE workflow_run_id = $1 AND step_name = $2
547                    "#,
548                self.run_id,
549                step_name,
550                state.status.as_str(),
551            )
552            .execute(&self.db_pool)
553            .await
554            .map_err(crate::ForgeError::Database)?;
555        }
556        Ok(())
557    }
558
559    pub fn completed_steps_reversed(&self) -> Vec<String> {
560        let completed = self.completed_steps.read().expect(LOCK_POISONED);
561        completed.iter().rev().cloned().collect()
562    }
563
564    pub fn all_step_states(&self) -> HashMap<String, StepState> {
565        self.step_states.read().expect(LOCK_POISONED).clone()
566    }
567
568    pub fn elapsed(&self) -> chrono::Duration {
569        Utc::now() - self.started_at
570    }
571
572    /// Register a compensation handler for a step.
573    ///
574    /// Limitation: compensation handlers are in-memory closures and cannot
575    /// survive a process restart. If the process crashes between step completion
576    /// and workflow termination, compensation handlers for completed steps are
577    /// lost. The `WorkflowExecutor::cancel` method detects this and fails the
578    /// workflow with a clear message indicating manual remediation is required.
579    /// This is an inherent constraint of closure-based compensation; a durable
580    /// alternative would require serializable compensation descriptors (e.g.,
581    /// naming a registered handler + captured args as JSON).
582    pub fn register_compensation(&self, step_name: &str, handler: CompensationHandler) {
583        let mut handlers = self.compensation_handlers.write().expect(LOCK_POISONED);
584        handlers.insert(step_name.to_string(), handler);
585    }
586
587    pub fn get_compensation_handler(&self, step_name: &str) -> Option<CompensationHandler> {
588        self.compensation_handlers
589            .read()
590            .expect(LOCK_POISONED)
591            .get(step_name)
592            .cloned()
593    }
594
595    pub fn has_compensation(&self, step_name: &str) -> bool {
596        self.compensation_handlers
597            .read()
598            .expect(LOCK_POISONED)
599            .contains_key(step_name)
600    }
601
602    /// Run compensation for all completed steps in reverse order.
603    /// Returns a list of (step_name, success) tuples.
604    pub async fn run_compensation(&self) -> Vec<(String, bool)> {
605        let steps = self.completed_steps_reversed();
606        let mut results = Vec::new();
607
608        for step_name in steps {
609            let handler = self.get_compensation_handler(&step_name);
610            let result = self
611                .get_step_state(&step_name)
612                .and_then(|s| s.result.clone());
613
614            if let Some(handler) = handler {
615                let step_result = result.unwrap_or(serde_json::Value::Null);
616                match handler(step_result).await {
617                    Ok(()) => match self.record_step_compensated(&step_name).await {
618                        Ok(()) => results.push((step_name, true)),
619                        Err(e) => {
620                            tracing::error!(
621                                step = %step_name,
622                                error = %e,
623                                "Failed to persist step compensation; marking compensation as failed",
624                            );
625                            results.push((step_name, false));
626                        }
627                    },
628                    Err(e) => {
629                        tracing::error!(step = %step_name, error = %e, "Compensation failed");
630                        results.push((step_name, false));
631                    }
632                }
633            } else {
634                // No compensation handler, mark as compensated anyway.
635                match self.record_step_compensated(&step_name).await {
636                    Ok(()) => results.push((step_name, true)),
637                    Err(e) => {
638                        tracing::error!(
639                            step = %step_name,
640                            error = %e,
641                            "Failed to persist step compensation",
642                        );
643                        results.push((step_name, false));
644                    }
645                }
646            }
647        }
648
649        results
650    }
651
652    pub fn compensation_handlers(&self) -> HashMap<String, CompensationHandler> {
653        self.compensation_handlers
654            .read()
655            .expect(LOCK_POISONED)
656            .clone()
657    }
658
659    /// Sleep for a duration.
660    ///
661    /// This suspends the workflow and persists the wake time to the database.
662    /// The workflow scheduler will resume the workflow when the time arrives.
663    ///
664    /// # Example
665    /// ```ignore
666    /// // Sleep for 30 days
667    /// ctx.sleep(Duration::from_secs(30 * 24 * 60 * 60)).await?;
668    /// ```
669    pub async fn sleep(&self, duration: Duration) -> Result<()> {
670        // If we resumed from a sleep, the timer already expired - continue immediately
671        if self.resumed_from_sleep {
672            return Ok(());
673        }
674
675        let wake_at = Utc::now() + chrono::Duration::from_std(duration).unwrap_or_default();
676        self.sleep_until(wake_at).await
677    }
678
679    /// Sleep until a specific time.
680    ///
681    /// If the wake time has already passed, returns immediately.
682    ///
683    /// # Example
684    /// ```ignore
685    /// use chrono::{Utc, Duration};
686    /// let renewal_date = Utc::now() + Duration::days(30);
687    /// ctx.sleep_until(renewal_date).await?;
688    /// ```
689    pub async fn sleep_until(&self, wake_at: DateTime<Utc>) -> Result<()> {
690        // If we resumed from a sleep, the timer already expired - continue immediately
691        if self.resumed_from_sleep {
692            return Ok(());
693        }
694
695        // If wake time already passed, return immediately
696        if wake_at <= Utc::now() {
697            return Ok(());
698        }
699
700        // Persist the wake time to database
701        self.set_wake_at(wake_at).await?;
702
703        // Signal suspension to executor
704        self.signal_suspend(SuspendReason::Sleep { wake_at })
705            .await?;
706
707        Ok(())
708    }
709
710    /// Wait for an external event with optional timeout.
711    ///
712    /// The workflow suspends until the event arrives or the timeout expires.
713    /// Events are correlated by the workflow run ID.
714    ///
715    /// # Example
716    /// ```ignore
717    /// let payment: PaymentConfirmation = ctx.wait_for_event(
718    ///     "payment_confirmed",
719    ///     Some(Duration::from_secs(7 * 24 * 60 * 60)), // 7 days
720    /// ).await?;
721    /// ```
722    pub async fn wait_for_event<T: DeserializeOwned>(
723        &self,
724        event_name: &str,
725        timeout: Option<Duration>,
726    ) -> Result<T> {
727        let correlation_id = self.run_id.to_string();
728
729        // On resume: check if the scheduler already consumed an event for this run.
730        // This happens when the scheduler wakes the workflow after an event arrives.
731        if self.is_resumed
732            && let Some(event) = self
733                .find_consumed_event(event_name, &correlation_id)
734                .await?
735        {
736            return serde_json::from_value(event.payload.unwrap_or_default())
737                .map_err(|e| ForgeError::Deserialization(e.to_string()));
738        }
739
740        // Check if event already exists but not yet consumed (race condition handling)
741        if let Some(event) = self.try_consume_event(event_name, &correlation_id).await? {
742            return serde_json::from_value(event.payload.unwrap_or_default())
743                .map_err(|e| ForgeError::Deserialization(e.to_string()));
744        }
745
746        // Calculate timeout
747        let timeout_at =
748            timeout.map(|d| Utc::now() + chrono::Duration::from_std(d).unwrap_or_default());
749
750        // Persist waiting state
751        self.set_waiting_for_event(event_name, timeout_at).await?;
752
753        // Signal suspension
754        self.signal_suspend(SuspendReason::WaitingEvent {
755            event_name: event_name.to_string(),
756            timeout: timeout_at,
757        })
758        .await?;
759
760        // After resume, try to consume the event
761        self.try_consume_event(event_name, &correlation_id)
762            .await?
763            .and_then(|e| e.payload)
764            .and_then(|p| serde_json::from_value(p).ok())
765            .ok_or_else(|| ForgeError::Timeout(format!("Event '{}' timed out", event_name)))
766    }
767
768    /// Try to consume an event from the database.
769    #[allow(clippy::type_complexity)]
770    async fn try_consume_event(
771        &self,
772        event_name: &str,
773        correlation_id: &str,
774    ) -> Result<Option<WorkflowEvent>> {
775        let result = sqlx::query!(
776            r#"
777                UPDATE forge_workflow_events
778                SET consumed_at = NOW(), consumed_by = $3
779                WHERE id = (
780                    SELECT id FROM forge_workflow_events
781                    WHERE event_name = $1 AND correlation_id = $2 AND consumed_at IS NULL
782                    ORDER BY created_at ASC LIMIT 1
783                    FOR UPDATE SKIP LOCKED
784                )
785                RETURNING id, event_name, correlation_id, payload, created_at
786                "#,
787            event_name,
788            correlation_id,
789            self.run_id
790        )
791        .fetch_optional(&self.db_pool)
792        .await
793        .map_err(ForgeError::Database)?;
794
795        Ok(result.map(|row| WorkflowEvent {
796            id: row.id,
797            event_name: row.event_name,
798            correlation_id: row.correlation_id,
799            payload: row.payload,
800            created_at: row.created_at,
801        }))
802    }
803
804    /// Find an event that was already consumed by the scheduler for this run.
805    /// Used on resume to retrieve the event payload without re-consuming.
806    async fn find_consumed_event(
807        &self,
808        event_name: &str,
809        correlation_id: &str,
810    ) -> Result<Option<WorkflowEvent>> {
811        let result = sqlx::query!(
812            r#"
813                SELECT id, event_name, correlation_id, payload, created_at
814                FROM forge_workflow_events
815                WHERE event_name = $1 AND correlation_id = $2 AND consumed_by = $3
816                ORDER BY created_at DESC LIMIT 1
817                "#,
818            event_name,
819            correlation_id,
820            self.run_id
821        )
822        .fetch_optional(&self.db_pool)
823        .await
824        .map_err(ForgeError::Database)?;
825
826        Ok(result.map(|row| WorkflowEvent {
827            id: row.id,
828            event_name: row.event_name,
829            correlation_id: row.correlation_id,
830            payload: row.payload,
831            created_at: row.created_at,
832        }))
833    }
834
835    /// Persist wake time to database and notify the scheduler.
836    async fn set_wake_at(&self, wake_at: DateTime<Utc>) -> Result<()> {
837        sqlx::query!(
838            r#"
839            UPDATE forge_workflow_runs
840            SET status = 'sleeping', suspended_at = NOW(), wake_at = $2
841            WHERE id = $1
842            "#,
843            self.run_id,
844            wake_at,
845        )
846        .execute(&self.db_pool)
847        .await
848        .map_err(ForgeError::Database)?;
849
850        // Notify the scheduler so it can pick up this workflow without
851        // waiting for the next poll interval.
852        #[allow(clippy::disallowed_methods)]
853        if let Err(e) = sqlx::query("SELECT pg_notify('forge_workflow_wakeup', $1::text)")
854            .bind(self.run_id.to_string())
855            .execute(&self.db_pool)
856            .await
857        {
858            tracing::debug!(
859                workflow_run_id = %self.run_id,
860                error = %e,
861                "Failed to send workflow wakeup notify (scheduler will poll)",
862            );
863        }
864
865        Ok(())
866    }
867
868    /// Persist waiting for event state to database.
869    async fn set_waiting_for_event(
870        &self,
871        event_name: &str,
872        timeout_at: Option<DateTime<Utc>>,
873    ) -> Result<()> {
874        sqlx::query!(
875            r#"
876            UPDATE forge_workflow_runs
877            SET status = 'waiting', suspended_at = NOW(), waiting_for_event = $2, event_timeout_at = $3
878            WHERE id = $1
879            "#,
880            self.run_id,
881            event_name,
882            timeout_at,
883        )
884        .execute(&self.db_pool)
885        .await
886        .map_err(ForgeError::Database)?;
887        Ok(())
888    }
889
890    /// Signal suspension to the executor.
891    ///
892    /// Stores the reason in the context so the executor can retrieve it via
893    /// `take_suspend_reason()` and returns a typed
894    /// [`ForgeError::WorkflowSuspended`] so the handler short-circuits via `?`.
895    /// The executor matches on the variant — no string parsing, no risk of a
896    /// real internal error being misclassified as a suspension.
897    async fn signal_suspend(&self, reason: SuspendReason) -> Result<()> {
898        *self.suspend_reason.lock().expect(LOCK_POISONED) = Some(reason.clone());
899        Err(ForgeError::WorkflowSuspended(reason))
900    }
901
902    /// Take the stored suspension reason, if any.
903    ///
904    /// Called by the executor after the handler returns an error to determine
905    /// whether the error represents a suspension or a real failure.
906    pub fn take_suspend_reason(&self) -> Option<SuspendReason> {
907        self.suspend_reason.lock().expect(LOCK_POISONED).take()
908    }
909}
910
911impl EnvAccess for WorkflowContext {
912    fn env_provider(&self) -> &dyn EnvProvider {
913        self.env_provider.as_ref()
914    }
915}
916
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// Build a lazily-connecting pool aimed at a database that does not
    /// exist.
    ///
    /// No connection is attempted at construction time, so tests that only
    /// touch in-memory state can run without Postgres. Any test that does hit
    /// the pool fails fast thanks to the 1 ms acquire timeout.
    ///
    /// NOTE(review): every caller is a `#[tokio::test]`; presumably
    /// `connect_lazy` needs a runtime to be present — confirm before turning
    /// any of them into plain `#[test]`s.
    fn lazy_pool() -> sqlx::postgres::PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .acquire_timeout(std::time::Duration::from_millis(1))
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool")
    }

    /// Build a `WorkflowContext` with a random run id on top of [`lazy_pool`].
    fn lazy_ctx() -> WorkflowContext {
        WorkflowContext::new(
            Uuid::new_v4(),
            "test".to_string(),
            lazy_pool(),
            CircuitBreakerClient::with_defaults(reqwest::Client::new()),
        )
    }

    #[tokio::test]
    async fn test_workflow_context_creation() {
        // Built by hand (rather than via `lazy_ctx`) because the assertions
        // need to know the exact run id and workflow name that went in.
        let run_id = Uuid::new_v4();
        let ctx = WorkflowContext::new(
            run_id,
            "test_workflow".to_string(),
            lazy_pool(),
            CircuitBreakerClient::with_defaults(reqwest::Client::new()),
        );

        assert_eq!(ctx.run_id, run_id);
        assert_eq!(ctx.workflow_name, "test_workflow");
    }

    #[tokio::test]
    async fn test_step_state_tracking() {
        let ctx = lazy_ctx();

        // `persist_step_start` defaults to false, so record_step_start does
        // not touch the database and is safe with a non-routable pool.
        ctx.record_step_start("step1")
            .await
            .expect("record_step_start should not touch db when persist disabled");
        assert!(!ctx.is_step_completed("step1"));

        // record_step_complete must always persist (it's the only path that
        // writes the row), so without a live database it must surface an
        // error rather than silently succeeding.
        let complete_err = ctx
            .record_step_complete("step1", serde_json::json!({"result": "ok"}))
            .await
            .expect_err("record_step_complete should propagate db errors");
        assert!(
            matches!(complete_err, crate::ForgeError::Database(_)),
            "expected Database error, got {complete_err:?}",
        );
        // In-memory state still moved forward — the DB error doesn't roll it back.
        assert!(ctx.is_step_completed("step1"));

        let result: Option<serde_json::Value> = ctx.get_step_result("step1");
        assert!(result.is_some());
    }

    #[test]
    fn test_step_state_transitions() {
        let mut state = StepState::new("test");
        assert_eq!(state.status, StepStatus::Pending);

        state.start();
        assert_eq!(state.status, StepStatus::Running);
        assert!(state.started_at.is_some());

        state.complete(serde_json::json!({}));
        assert_eq!(state.status, StepStatus::Completed);
        assert!(state.completed_at.is_some());
    }

    #[test]
    fn step_state_fail_records_error_and_completion() {
        let mut state = StepState::new("step");
        state.start();
        state.fail("boom");
        assert_eq!(state.status, StepStatus::Failed);
        assert_eq!(state.error.as_deref(), Some("boom"));
        assert!(state.completed_at.is_some());
    }

    #[test]
    fn step_state_compensate_only_flips_status() {
        let mut state = StepState::new("step");
        state.complete(serde_json::json!({"ok": true}));
        let completed_at = state.completed_at;
        state.compensate();
        assert_eq!(state.status, StepStatus::Compensated);
        // compensate() must not erase or overwrite the completion timestamp.
        assert_eq!(state.completed_at, completed_at);
    }

    #[tokio::test]
    async fn save_state_and_load_state_round_trip() {
        let ctx = lazy_ctx();
        ctx.save_state("count", 42_u32).unwrap();
        let v: Option<u32> = ctx.load_state("count").unwrap();
        assert_eq!(v, Some(42));
    }

    #[tokio::test]
    async fn load_state_returns_none_for_unknown_key() {
        let ctx = lazy_ctx();
        let v: Option<String> = ctx.load_state("missing").unwrap();
        assert!(v.is_none());
    }

    #[tokio::test]
    async fn load_state_returns_deserialization_error_on_type_mismatch() {
        let ctx = lazy_ctx();
        ctx.save_state("k", "a string").unwrap();
        let err = ctx.load_state::<u32>("k").unwrap_err();
        assert!(matches!(err, ForgeError::Deserialization(_)));
    }

    #[tokio::test]
    async fn take_saved_state_returns_snapshot_of_all_entries() {
        let ctx = lazy_ctx();
        ctx.save_state("a", 1_u32).unwrap();
        ctx.save_state("b", "two").unwrap();
        let snap = ctx.take_saved_state();
        assert_eq!(snap.len(), 2);
        assert_eq!(snap.get("a"), Some(&serde_json::json!(1)));
        assert_eq!(snap.get("b"), Some(&serde_json::json!("two")));
    }

    #[tokio::test]
    async fn tenant_id_defaults_to_none_and_with_tenant_sets_it() {
        let ctx = lazy_ctx();
        assert!(ctx.tenant_id().is_none());
        let tenant = Uuid::new_v4();
        let ctx = ctx.with_tenant(tenant);
        assert_eq!(ctx.tenant_id(), Some(tenant));
    }

    #[tokio::test]
    async fn is_resumed_defaults_to_false() {
        let ctx = lazy_ctx();
        assert!(!ctx.is_resumed());
    }

    #[tokio::test]
    async fn is_step_completed_and_started_return_false_for_unknown_steps() {
        let ctx = lazy_ctx();
        assert!(!ctx.is_step_completed("nope"));
        assert!(!ctx.is_step_started("nope"));
    }

    #[tokio::test]
    async fn get_step_result_returns_none_for_unknown_step() {
        let ctx = lazy_ctx();
        let v: Option<serde_json::Value> = ctx.get_step_result("nope");
        assert!(v.is_none());
    }

    #[tokio::test]
    async fn with_step_states_rebuilds_completed_steps_from_status() {
        let ctx = lazy_ctx();
        let mut s = HashMap::new();
        let mut completed = StepState::new("done");
        completed.complete(serde_json::json!({"v": 1}));
        s.insert("done".to_string(), completed);
        let pending = StepState::new("pending");
        s.insert("pending".to_string(), pending);

        let ctx = ctx.with_step_states(s);
        assert!(ctx.is_step_completed("done"));
        assert!(!ctx.is_step_completed("pending"));

        let reversed = ctx.completed_steps_reversed();
        assert_eq!(reversed, vec!["done".to_string()]);
    }

    #[tokio::test]
    async fn completed_steps_reversed_is_empty_initially() {
        let ctx = lazy_ctx();
        assert!(ctx.completed_steps_reversed().is_empty());
    }

    #[tokio::test]
    async fn elapsed_is_non_negative() {
        let ctx = lazy_ctx();
        let e = ctx.elapsed();
        // started_at was set at construction, so elapsed must be >= 0.
        assert!(e.num_milliseconds() >= 0);
    }

    #[tokio::test]
    async fn register_and_has_compensation_round_trip() {
        let ctx = lazy_ctx();
        assert!(!ctx.has_compensation("step1"));
        let handler: CompensationHandler =
            Arc::new(|_v| Box::pin(async { Ok::<(), ForgeError>(()) }));
        ctx.register_compensation("step1", handler);
        assert!(ctx.has_compensation("step1"));
        assert!(ctx.get_compensation_handler("step1").is_some());
        assert!(ctx.get_compensation_handler("step2").is_none());
    }

    #[tokio::test]
    async fn all_step_states_returns_independent_clone() {
        let ctx = lazy_ctx();
        let mut s = HashMap::new();
        s.insert("a".to_string(), StepState::new("a"));
        let ctx = ctx.with_step_states(s);

        let mut snap = ctx.all_step_states();
        assert_eq!(snap.len(), 1);
        assert!(snap.contains_key("a"));

        // Independence: emptying the snapshot must not affect what the
        // context itself still holds.
        snap.clear();
        let fresh = ctx.all_step_states();
        assert_eq!(fresh.len(), 1);
        assert!(fresh.contains_key("a"));
    }
}