Skip to main content

forge_core/workflow/
context.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::{Arc, RwLock};
5use std::time::Duration;
6
7use chrono::{DateTime, Utc};
8use serde::de::DeserializeOwned;
9use tokio::sync::mpsc;
10use uuid::Uuid;
11
12use super::parallel::ParallelBuilder;
13use super::step::StepStatus;
14use super::suspend::{SuspendReason, WorkflowEvent};
15use crate::env::{EnvAccess, EnvProvider, RealEnvProvider};
16use crate::function::AuthContext;
17use crate::http::CircuitBreakerClient;
18use crate::{ForgeError, Result};
19
/// Type alias for compensation handler function.
///
/// The handler receives the serialized result of the completed step and
/// performs the rollback work. It is stored behind an `Arc` so it can be
/// shared between the context's handler map and callers.
pub type CompensationHandler = Arc<
    dyn Fn(serde_json::Value) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync,
>;
24
/// Step state stored during execution.
///
/// One instance is kept per step in the context's state map; it is also the
/// shape persisted to `forge_workflow_steps` so suspended runs can resume.
#[derive(Debug, Clone)]
pub struct StepState {
    /// Step name (unique within a workflow run).
    pub name: String,
    /// Step status.
    pub status: StepStatus,
    /// Serialized result (if completed).
    pub result: Option<serde_json::Value>,
    /// Error message (if failed).
    pub error: Option<String>,
    /// When the step started.
    pub started_at: Option<DateTime<Utc>>,
    /// When the step completed.
    pub completed_at: Option<DateTime<Utc>>,
}
41
42impl StepState {
43    /// Create a new pending step state.
44    pub fn new(name: impl Into<String>) -> Self {
45        Self {
46            name: name.into(),
47            status: StepStatus::Pending,
48            result: None,
49            error: None,
50            started_at: None,
51            completed_at: None,
52        }
53    }
54
55    /// Mark step as running.
56    pub fn start(&mut self) {
57        self.status = StepStatus::Running;
58        self.started_at = Some(Utc::now());
59    }
60
61    /// Mark step as completed with result.
62    pub fn complete(&mut self, result: serde_json::Value) {
63        self.status = StepStatus::Completed;
64        self.result = Some(result);
65        self.completed_at = Some(Utc::now());
66    }
67
68    /// Mark step as failed with error.
69    pub fn fail(&mut self, error: impl Into<String>) {
70        self.status = StepStatus::Failed;
71        self.error = Some(error.into());
72        self.completed_at = Some(Utc::now());
73    }
74
75    /// Mark step as compensated.
76    pub fn compensate(&mut self) {
77        self.status = StepStatus::Compensated;
78    }
79}
80
/// Context available to workflow handlers.
///
/// Holds run identity, clocks, database/HTTP handles, and the in-memory step
/// bookkeeping used for resumption and compensation. Step state is shared
/// behind `Arc<RwLock<…>>` so background persistence tasks can snapshot it.
pub struct WorkflowContext {
    /// Workflow run ID.
    pub run_id: Uuid,
    /// Workflow name.
    pub workflow_name: String,
    /// When the workflow started.
    pub started_at: DateTime<Utc>,
    /// Deterministic workflow time (consistent across replays).
    workflow_time: DateTime<Utc>,
    /// Authentication context.
    pub auth: AuthContext,
    /// Database pool.
    db_pool: sqlx::PgPool,
    /// HTTP client.
    http_client: CircuitBreakerClient,
    /// Default timeout for outbound HTTP requests made through the
    /// circuit-breaker client. `None` means unlimited.
    http_timeout: Option<Duration>,
    /// Step states (for resumption).
    step_states: Arc<RwLock<HashMap<String, StepState>>>,
    /// Completed steps in order (for compensation).
    completed_steps: Arc<RwLock<Vec<String>>>,
    /// Compensation handlers for completed steps.
    compensation_handlers: Arc<RwLock<HashMap<String, CompensationHandler>>>,
    /// Channel for signaling suspension (sent by workflow, received by executor).
    suspend_tx: Option<mpsc::Sender<SuspendReason>>,
    /// Whether this is a resumed execution.
    is_resumed: bool,
    /// Whether this execution resumed specifically from a sleep (timer expired).
    resumed_from_sleep: bool,
    /// Tenant ID for multi-tenancy.
    tenant_id: Option<Uuid>,
    /// Environment variable provider.
    env_provider: Arc<dyn EnvProvider>,
}
117
118impl WorkflowContext {
119    /// Create a new workflow context.
120    pub fn new(
121        run_id: Uuid,
122        workflow_name: String,
123        db_pool: sqlx::PgPool,
124        http_client: CircuitBreakerClient,
125    ) -> Self {
126        let now = Utc::now();
127        Self {
128            run_id,
129            workflow_name,
130            started_at: now,
131            workflow_time: now,
132            auth: AuthContext::unauthenticated(),
133            db_pool,
134            http_client,
135            http_timeout: None,
136            step_states: Arc::new(RwLock::new(HashMap::new())),
137            completed_steps: Arc::new(RwLock::new(Vec::new())),
138            compensation_handlers: Arc::new(RwLock::new(HashMap::new())),
139            suspend_tx: None,
140            is_resumed: false,
141            resumed_from_sleep: false,
142            tenant_id: None,
143            env_provider: Arc::new(RealEnvProvider::new()),
144        }
145    }
146
147    /// Create a resumed workflow context.
148    pub fn resumed(
149        run_id: Uuid,
150        workflow_name: String,
151        started_at: DateTime<Utc>,
152        db_pool: sqlx::PgPool,
153        http_client: CircuitBreakerClient,
154    ) -> Self {
155        Self {
156            run_id,
157            workflow_name,
158            started_at,
159            workflow_time: started_at,
160            auth: AuthContext::unauthenticated(),
161            db_pool,
162            http_client,
163            http_timeout: None,
164            step_states: Arc::new(RwLock::new(HashMap::new())),
165            completed_steps: Arc::new(RwLock::new(Vec::new())),
166            compensation_handlers: Arc::new(RwLock::new(HashMap::new())),
167            suspend_tx: None,
168            is_resumed: true,
169            resumed_from_sleep: false,
170            tenant_id: None,
171            env_provider: Arc::new(RealEnvProvider::new()),
172        }
173    }
174
175    /// Set environment provider.
176    pub fn with_env_provider(mut self, provider: Arc<dyn EnvProvider>) -> Self {
177        self.env_provider = provider;
178        self
179    }
180
181    /// Mark that this context resumed from a sleep (timer expired).
182    pub fn with_resumed_from_sleep(mut self) -> Self {
183        self.resumed_from_sleep = true;
184        self
185    }
186
187    /// Set the suspend channel.
188    pub fn with_suspend_channel(mut self, tx: mpsc::Sender<SuspendReason>) -> Self {
189        self.suspend_tx = Some(tx);
190        self
191    }
192
193    /// Set the tenant ID.
194    pub fn with_tenant(mut self, tenant_id: Uuid) -> Self {
195        self.tenant_id = Some(tenant_id);
196        self
197    }
198
    /// Tenant ID for multi-tenant deployments, if one was set.
    pub fn tenant_id(&self) -> Option<Uuid> {
        self.tenant_id
    }

    /// Whether this context was created via [`WorkflowContext::resumed`].
    pub fn is_resumed(&self) -> bool {
        self.is_resumed
    }

    /// Deterministic workflow time. Set to `started_at` at construction, so
    /// it stays stable across suspensions/replays (unlike `Utc::now()`).
    pub fn workflow_time(&self) -> DateTime<Utc> {
        self.workflow_time
    }
210
    /// Database handle wrapping the run's connection pool.
    pub fn db(&self) -> crate::function::ForgeDb {
        crate::function::ForgeDb::from_pool(&self.db_pool)
    }

    /// Get a `DbConn` for use in shared helper functions.
    pub fn db_conn(&self) -> crate::function::DbConn<'_> {
        crate::function::DbConn::Pool(self.db_pool.clone())
    }

    /// Acquire a connection compatible with sqlx compile-time checked macros.
    pub async fn conn(&self) -> sqlx::Result<crate::function::ForgeConn<'static>> {
        Ok(crate::function::ForgeConn::Pool(
            self.db_pool.acquire().await?,
        ))
    }

    /// HTTP client with the context's default timeout applied
    /// (`None` = unlimited).
    pub fn http(&self) -> crate::http::HttpClient {
        self.http_client.with_timeout(self.http_timeout)
    }

    /// Inner `reqwest::Client` of the circuit-breaker wrapper.
    pub fn raw_http(&self) -> &reqwest::Client {
        self.http_client.inner()
    }

    /// Set the default timeout used by [`Self::http`]; `None` disables it.
    pub fn set_http_timeout(&mut self, timeout: Option<Duration>) {
        self.http_timeout = timeout;
    }
238
239    /// Set authentication context.
240    pub fn with_auth(mut self, auth: AuthContext) -> Self {
241        self.auth = auth;
242        self
243    }
244
245    /// Restore step states from persisted data.
246    pub fn with_step_states(self, states: HashMap<String, StepState>) -> Self {
247        let completed: Vec<String> = states
248            .iter()
249            .filter(|(_, s)| s.status == StepStatus::Completed)
250            .map(|(name, _)| name.clone())
251            .collect();
252
253        *self.step_states.write().expect("workflow lock poisoned") = states;
254        *self
255            .completed_steps
256            .write()
257            .expect("workflow lock poisoned") = completed;
258        self
259    }
260
261    pub fn get_step_state(&self, name: &str) -> Option<StepState> {
262        self.step_states
263            .read()
264            .expect("workflow lock poisoned")
265            .get(name)
266            .cloned()
267    }
268
269    pub fn is_step_completed(&self, name: &str) -> bool {
270        self.step_states
271            .read()
272            .expect("workflow lock poisoned")
273            .get(name)
274            .map(|s| s.status == StepStatus::Completed)
275            .unwrap_or(false)
276    }
277
278    /// Check if a step has been started (running, completed, or failed).
279    ///
280    /// Use this to guard steps that should only execute once, even across
281    /// workflow suspension and resumption.
282    pub fn is_step_started(&self, name: &str) -> bool {
283        self.step_states
284            .read()
285            .expect("workflow lock poisoned")
286            .get(name)
287            .map(|s| s.status != StepStatus::Pending)
288            .unwrap_or(false)
289    }
290
291    pub fn get_step_result<T: serde::de::DeserializeOwned>(&self, name: &str) -> Option<T> {
292        self.step_states
293            .read()
294            .expect("workflow lock poisoned")
295            .get(name)
296            .and_then(|s| s.result.as_ref())
297            .and_then(|v| serde_json::from_value(v.clone()).ok())
298    }
299
    /// Record step start.
    ///
    /// If the step is already running or beyond (completed/failed), this is a no-op.
    /// This prevents race conditions when resuming workflows.
    ///
    /// Persistence is fire-and-forget: the insert runs on a spawned task and
    /// failures are only logged, never surfaced to the workflow.
    pub fn record_step_start(&self, name: &str) {
        let mut states = self.step_states.write().expect("workflow lock poisoned");
        let state = states
            .entry(name.to_string())
            .or_insert_with(|| StepState::new(name));

        // Only update if step is pending - prevents race condition on resume
        // where background DB update could overwrite a completed status
        if state.status != StepStatus::Pending {
            return;
        }

        state.start();
        let state_clone = state.clone();
        // Release the lock before spawning so the background task never
        // contends with workflow code for the state map.
        drop(states);

        // Persist to database in background
        let pool = self.db_pool.clone();
        let run_id = self.run_id;
        let step_name = name.to_string();
        tokio::spawn(async move {
            let step_id = Uuid::new_v4();
            // ON CONFLICT DO NOTHING: the row may already exist from a prior
            // execution of this run (resume case).
            if let Err(e) = sqlx::query!(
                r#"
                INSERT INTO forge_workflow_steps (id, workflow_run_id, step_name, status, started_at)
                VALUES ($1, $2, $3, $4, $5)
                ON CONFLICT (workflow_run_id, step_name) DO NOTHING
                "#,
                step_id,
                run_id,
                step_name,
                state_clone.status.as_str(),
                state_clone.started_at,
            )
            .execute(&pool)
            .await
            {
                tracing::warn!(
                    workflow_run_id = %run_id,
                    step = %step_name,
                    "Failed to persist step start: {}",
                    e
                );
            }
        });
    }
350
351    /// Record step completion (fire-and-forget database update).
352    /// Use `record_step_complete_async` if you need to ensure persistence before continuing.
353    pub fn record_step_complete(&self, name: &str, result: serde_json::Value) {
354        let state_clone = self.update_step_state_complete(name, result);
355
356        // Persist to database in background
357        if let Some(state) = state_clone {
358            let pool = self.db_pool.clone();
359            let run_id = self.run_id;
360            let step_name = name.to_string();
361            tokio::spawn(async move {
362                Self::persist_step_complete(&pool, run_id, &step_name, &state).await;
363            });
364        }
365    }
366
367    /// Record step completion and wait for database persistence.
368    pub async fn record_step_complete_async(&self, name: &str, result: serde_json::Value) {
369        let state_clone = self.update_step_state_complete(name, result);
370
371        // Persist to database synchronously
372        if let Some(state) = state_clone {
373            Self::persist_step_complete(&self.db_pool, self.run_id, name, &state).await;
374        }
375    }
376
377    /// Update in-memory step state to completed.
378    fn update_step_state_complete(
379        &self,
380        name: &str,
381        result: serde_json::Value,
382    ) -> Option<StepState> {
383        let mut states = self.step_states.write().expect("workflow lock poisoned");
384        if let Some(state) = states.get_mut(name) {
385            state.complete(result.clone());
386        }
387        let state_clone = states.get(name).cloned();
388        drop(states);
389
390        let mut completed = self
391            .completed_steps
392            .write()
393            .expect("workflow lock poisoned");
394        if !completed.contains(&name.to_string()) {
395            completed.push(name.to_string());
396        }
397        drop(completed);
398
399        state_clone
400    }
401
    /// Persist step completion to database.
    ///
    /// Failures are logged at `warn` and otherwise swallowed: persistence is
    /// best-effort and must not fail the workflow itself.
    async fn persist_step_complete(
        pool: &sqlx::PgPool,
        run_id: Uuid,
        step_name: &str,
        state: &StepState,
    ) {
        // Use UPSERT to handle race condition where persist_step_start hasn't completed yet
        if let Err(e) = sqlx::query!(
            r#"
            INSERT INTO forge_workflow_steps (id, workflow_run_id, step_name, status, result, started_at, completed_at)
            VALUES (gen_random_uuid(), $1, $2, $3, $4, $5, $6)
            ON CONFLICT (workflow_run_id, step_name) DO UPDATE
            SET status = $3, result = $4, completed_at = $6
            "#,
            run_id,
            step_name,
            state.status.as_str(),
            state.result as _,
            state.started_at,
            state.completed_at,
        )
        .execute(pool)
        .await
        {
            tracing::warn!(
                workflow_run_id = %run_id,
                step = %step_name,
                "Failed to persist step completion: {}",
                e
            );
        }
    }
435
    /// Record step failure.
    ///
    /// Updates the in-memory state and persists it in a background task.
    /// If the step name was never recorded, this is a no-op (no state entry
    /// means nothing is cloned and no DB update is spawned).
    pub fn record_step_failure(&self, name: &str, error: impl Into<String>) {
        let error_str = error.into();
        let mut states = self.step_states.write().expect("workflow lock poisoned");
        if let Some(state) = states.get_mut(name) {
            state.fail(error_str.clone());
        }
        let state_clone = states.get(name).cloned();
        drop(states);

        // Persist to database in background
        if let Some(state) = state_clone {
            let pool = self.db_pool.clone();
            let run_id = self.run_id;
            let step_name = name.to_string();
            tokio::spawn(async move {
                // Plain UPDATE (not UPSERT): if the start row was never
                // persisted this silently affects zero rows.
                if let Err(e) = sqlx::query!(
                    r#"
                    UPDATE forge_workflow_steps
                    SET status = $3, error = $4, completed_at = $5
                    WHERE workflow_run_id = $1 AND step_name = $2
                    "#,
                    run_id,
                    step_name,
                    state.status.as_str(),
                    state.error as _,
                    state.completed_at,
                )
                .execute(&pool)
                .await
                {
                    tracing::warn!(
                        workflow_run_id = %run_id,
                        step = %step_name,
                        "Failed to persist step failure: {}",
                        e
                    );
                }
            });
        }
    }
477
    /// Record step compensation.
    ///
    /// Marks the step `Compensated` in memory and persists the new status in
    /// a background task. A no-op for unknown step names.
    pub fn record_step_compensated(&self, name: &str) {
        let mut states = self.step_states.write().expect("workflow lock poisoned");
        if let Some(state) = states.get_mut(name) {
            state.compensate();
        }
        let state_clone = states.get(name).cloned();
        drop(states);

        // Persist to database in background
        if let Some(state) = state_clone {
            let pool = self.db_pool.clone();
            let run_id = self.run_id;
            let step_name = name.to_string();
            tokio::spawn(async move {
                if let Err(e) = sqlx::query!(
                    r#"
                    UPDATE forge_workflow_steps
                    SET status = $3
                    WHERE workflow_run_id = $1 AND step_name = $2
                    "#,
                    run_id,
                    step_name,
                    state.status.as_str(),
                )
                .execute(&pool)
                .await
                {
                    tracing::warn!(
                        workflow_run_id = %run_id,
                        step = %step_name,
                        "Failed to persist step compensation: {}",
                        e
                    );
                }
            });
        }
    }
516
517    pub fn completed_steps_reversed(&self) -> Vec<String> {
518        let completed = self.completed_steps.read().expect("workflow lock poisoned");
519        completed.iter().rev().cloned().collect()
520    }
521
522    pub fn all_step_states(&self) -> HashMap<String, StepState> {
523        self.step_states
524            .read()
525            .expect("workflow lock poisoned")
526            .clone()
527    }
528
529    pub fn elapsed(&self) -> chrono::Duration {
530        Utc::now() - self.started_at
531    }
532
533    /// Register a compensation handler for a step.
534    pub fn register_compensation(&self, step_name: &str, handler: CompensationHandler) {
535        let mut handlers = self
536            .compensation_handlers
537            .write()
538            .expect("workflow lock poisoned");
539        handlers.insert(step_name.to_string(), handler);
540    }
541
542    pub fn get_compensation_handler(&self, step_name: &str) -> Option<CompensationHandler> {
543        self.compensation_handlers
544            .read()
545            .expect("workflow lock poisoned")
546            .get(step_name)
547            .cloned()
548    }
549
550    pub fn has_compensation(&self, step_name: &str) -> bool {
551        self.compensation_handlers
552            .read()
553            .expect("workflow lock poisoned")
554            .contains_key(step_name)
555    }
556
557    /// Run compensation for all completed steps in reverse order.
558    /// Returns a list of (step_name, success) tuples.
559    pub async fn run_compensation(&self) -> Vec<(String, bool)> {
560        let steps = self.completed_steps_reversed();
561        let mut results = Vec::new();
562
563        for step_name in steps {
564            let handler = self.get_compensation_handler(&step_name);
565            let result = self
566                .get_step_state(&step_name)
567                .and_then(|s| s.result.clone());
568
569            if let Some(handler) = handler {
570                let step_result = result.unwrap_or(serde_json::Value::Null);
571                match handler(step_result).await {
572                    Ok(()) => {
573                        self.record_step_compensated(&step_name);
574                        results.push((step_name, true));
575                    }
576                    Err(e) => {
577                        tracing::error!(step = %step_name, error = %e, "Compensation failed");
578                        results.push((step_name, false));
579                    }
580                }
581            } else {
582                // No compensation handler, mark as compensated anyway
583                self.record_step_compensated(&step_name);
584                results.push((step_name, true));
585            }
586        }
587
588        results
589    }
590
591    pub fn compensation_handlers(&self) -> HashMap<String, CompensationHandler> {
592        self.compensation_handlers
593            .read()
594            .expect("workflow lock poisoned")
595            .clone()
596    }
597
598    /// Sleep for a duration.
599    ///
600    /// This suspends the workflow and persists the wake time to the database.
601    /// The workflow scheduler will resume the workflow when the time arrives.
602    ///
603    /// # Example
604    /// ```ignore
605    /// // Sleep for 30 days
606    /// ctx.sleep(Duration::from_secs(30 * 24 * 60 * 60)).await?;
607    /// ```
608    pub async fn sleep(&self, duration: Duration) -> Result<()> {
609        // If we resumed from a sleep, the timer already expired - continue immediately
610        if self.resumed_from_sleep {
611            return Ok(());
612        }
613
614        let wake_at = Utc::now() + chrono::Duration::from_std(duration).unwrap_or_default();
615        self.sleep_until(wake_at).await
616    }
617
    /// Sleep until a specific time.
    ///
    /// If the wake time has already passed, returns immediately.
    ///
    /// On the suspension path this returns `Err(ForgeError::WorkflowSuspended)`
    /// propagated from `signal_suspend` (which always returns that error), so
    /// the trailing `Ok(())` is only reached on the early-return paths.
    ///
    /// # Example
    /// ```ignore
    /// use chrono::{Utc, Duration};
    /// let renewal_date = Utc::now() + Duration::days(30);
    /// ctx.sleep_until(renewal_date).await?;
    /// ```
    pub async fn sleep_until(&self, wake_at: DateTime<Utc>) -> Result<()> {
        // If we resumed from a sleep, the timer already expired - continue immediately
        if self.resumed_from_sleep {
            return Ok(());
        }

        // If wake time already passed, return immediately
        if wake_at <= Utc::now() {
            return Ok(());
        }

        // Persist the wake time to database
        self.set_wake_at(wake_at).await?;

        // Signal suspension to executor
        self.signal_suspend(SuspendReason::Sleep { wake_at })
            .await?;

        Ok(())
    }
648
649    /// Wait for an external event with optional timeout.
650    ///
651    /// The workflow suspends until the event arrives or the timeout expires.
652    /// Events are correlated by the workflow run ID.
653    ///
654    /// # Example
655    /// ```ignore
656    /// let payment: PaymentConfirmation = ctx.wait_for_event(
657    ///     "payment_confirmed",
658    ///     Some(Duration::from_secs(7 * 24 * 60 * 60)), // 7 days
659    /// ).await?;
660    /// ```
661    pub async fn wait_for_event<T: DeserializeOwned>(
662        &self,
663        event_name: &str,
664        timeout: Option<Duration>,
665    ) -> Result<T> {
666        let correlation_id = self.run_id.to_string();
667
668        // On resume: check if the scheduler already consumed an event for this run.
669        // This happens when the scheduler wakes the workflow after an event arrives.
670        if self.is_resumed
671            && let Some(event) = self
672                .find_consumed_event(event_name, &correlation_id)
673                .await?
674        {
675            return serde_json::from_value(event.payload.unwrap_or_default())
676                .map_err(|e| ForgeError::Deserialization(e.to_string()));
677        }
678
679        // Check if event already exists but not yet consumed (race condition handling)
680        if let Some(event) = self.try_consume_event(event_name, &correlation_id).await? {
681            return serde_json::from_value(event.payload.unwrap_or_default())
682                .map_err(|e| ForgeError::Deserialization(e.to_string()));
683        }
684
685        // Calculate timeout
686        let timeout_at =
687            timeout.map(|d| Utc::now() + chrono::Duration::from_std(d).unwrap_or_default());
688
689        // Persist waiting state
690        self.set_waiting_for_event(event_name, timeout_at).await?;
691
692        // Signal suspension
693        self.signal_suspend(SuspendReason::WaitingEvent {
694            event_name: event_name.to_string(),
695            timeout: timeout_at,
696        })
697        .await?;
698
699        // After resume, try to consume the event
700        self.try_consume_event(event_name, &correlation_id)
701            .await?
702            .and_then(|e| e.payload)
703            .and_then(|p| serde_json::from_value(p).ok())
704            .ok_or_else(|| ForgeError::Timeout(format!("Event '{}' timed out", event_name)))
705    }
706
    /// Try to consume an event from the database.
    ///
    /// Atomically claims the oldest unconsumed matching event: the inner
    /// `FOR UPDATE SKIP LOCKED` select ensures two concurrent consumers
    /// cannot claim the same row. Returns `None` when no unconsumed event
    /// matches.
    #[allow(clippy::type_complexity)]
    async fn try_consume_event(
        &self,
        event_name: &str,
        correlation_id: &str,
    ) -> Result<Option<WorkflowEvent>> {
        let result = sqlx::query!(
            r#"
                UPDATE forge_workflow_events
                SET consumed_at = NOW(), consumed_by = $3
                WHERE id = (
                    SELECT id FROM forge_workflow_events
                    WHERE event_name = $1 AND correlation_id = $2 AND consumed_at IS NULL
                    ORDER BY created_at ASC LIMIT 1
                    FOR UPDATE SKIP LOCKED
                )
                RETURNING id, event_name, correlation_id, payload, created_at
                "#,
            event_name,
            correlation_id,
            self.run_id
        )
        .fetch_optional(&self.db_pool)
        .await
        .map_err(|e| ForgeError::Database(e.to_string()))?;

        Ok(result.map(|row| WorkflowEvent {
            id: row.id,
            event_name: row.event_name,
            correlation_id: row.correlation_id,
            payload: row.payload,
            created_at: row.created_at,
        }))
    }
742
    /// Find an event that was already consumed by the scheduler for this run.
    /// Used on resume to retrieve the event payload without re-consuming.
    ///
    /// Read-only: unlike `try_consume_event`, this never mutates event rows;
    /// it returns the most recently created event already attributed to this
    /// run via `consumed_by`.
    async fn find_consumed_event(
        &self,
        event_name: &str,
        correlation_id: &str,
    ) -> Result<Option<WorkflowEvent>> {
        let result = sqlx::query!(
            r#"
                SELECT id, event_name, correlation_id, payload, created_at
                FROM forge_workflow_events
                WHERE event_name = $1 AND correlation_id = $2 AND consumed_by = $3
                ORDER BY created_at DESC LIMIT 1
                "#,
            event_name,
            correlation_id,
            self.run_id
        )
        .fetch_optional(&self.db_pool)
        .await
        .map_err(|e| ForgeError::Database(e.to_string()))?;

        Ok(result.map(|row| WorkflowEvent {
            id: row.id,
            event_name: row.event_name,
            correlation_id: row.correlation_id,
            payload: row.payload,
            created_at: row.created_at,
        }))
    }
773
    /// Persist wake time to database.
    ///
    /// Marks the run `waiting` with a `wake_at` so the scheduler can resume
    /// it once the time arrives.
    async fn set_wake_at(&self, wake_at: DateTime<Utc>) -> Result<()> {
        sqlx::query!(
            r#"
            UPDATE forge_workflow_runs
            SET status = 'waiting', suspended_at = NOW(), wake_at = $2
            WHERE id = $1
            "#,
            self.run_id,
            wake_at,
        )
        .execute(&self.db_pool)
        .await
        .map_err(|e| ForgeError::Database(e.to_string()))?;
        Ok(())
    }
790
    /// Persist waiting for event state to database.
    ///
    /// Marks the run `waiting` with the event name and optional absolute
    /// timeout (`None` = wait indefinitely) so the scheduler knows what to
    /// watch for.
    async fn set_waiting_for_event(
        &self,
        event_name: &str,
        timeout_at: Option<DateTime<Utc>>,
    ) -> Result<()> {
        sqlx::query!(
            r#"
            UPDATE forge_workflow_runs
            SET status = 'waiting', suspended_at = NOW(), waiting_for_event = $2, event_timeout_at = $3
            WHERE id = $1
            "#,
            self.run_id,
            event_name,
            timeout_at,
        )
        .execute(&self.db_pool)
        .await
        .map_err(|e| ForgeError::Database(e.to_string()))?;
        Ok(())
    }
812
    /// Signal suspension to the executor.
    ///
    /// ALWAYS returns `Err(ForgeError::WorkflowSuspended)` — even when no
    /// suspend channel is configured — so that `?` at the call site unwinds
    /// the workflow body; the executor catches this sentinel error.
    async fn signal_suspend(&self, reason: SuspendReason) -> Result<()> {
        if let Some(ref tx) = self.suspend_tx {
            tx.send(reason)
                .await
                .map_err(|_| ForgeError::Internal("Failed to signal suspension".into()))?;
        }
        // Return a special error that the executor catches
        Err(ForgeError::WorkflowSuspended)
    }
823
    /// Create a parallel builder for executing steps concurrently.
    ///
    /// The builder borrows this context, so it must be consumed before the
    /// context is next mutated.
    ///
    /// # Example
    /// ```ignore
    /// let results = ctx.parallel()
    ///     .step("fetch_user", || async { get_user(id).await })
    ///     .step("fetch_orders", || async { get_orders(id).await })
    ///     .step_with_compensate("charge_card",
    ///         || async { charge_card(amount).await },
    ///         |charge| async move { refund(charge.id).await })
    ///     .run().await?;
    ///
    /// let user: User = results.get("fetch_user")?;
    /// let orders: Vec<Order> = results.get("fetch_orders")?;
    /// ```
    pub fn parallel(&self) -> ParallelBuilder<'_> {
        ParallelBuilder::new(self)
    }
842
    /// Create a step runner for executing a workflow step.
    ///
    /// This provides a fluent API for defining steps with retry, compensation,
    /// timeout, and optional behavior.
    ///
    /// The result type `T` must round-trip through serde so it can be
    /// persisted and restored across suspensions.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use std::time::Duration;
    ///
    /// // Simple step
    /// let data = ctx.step("fetch_data", || async {
    ///     Ok(fetch_from_api().await?)
    /// }).run().await?;
    ///
    /// // Step with retry (3 attempts, 2 second delay)
    /// ctx.step("send_email", || async {
    ///     send_verification_email(&user.email).await
    /// })
    /// .retry(3, Duration::from_secs(2))
    /// .run()
    /// .await?;
    ///
    /// // Step with compensation (rollback on later failure)
    /// let charge = ctx.step("charge_card", || async {
    ///     charge_credit_card(&card).await
    /// })
    /// .compensate(|charge_result| async move {
    ///     refund_charge(&charge_result.charge_id).await
    /// })
    /// .run()
    /// .await?;
    ///
    /// // Optional step (failure won't trigger compensation)
    /// ctx.step("notify_slack", || async {
    ///     post_to_slack("User signed up!").await
    /// })
    /// .optional()
    /// .run()
    /// .await?;
    ///
    /// // Step with timeout
    /// ctx.step("slow_operation", || async {
    ///     process_large_file().await
    /// })
    /// .timeout(Duration::from_secs(60))
    /// .run()
    /// .await?;
    /// ```
    pub fn step<T, F, Fut>(&self, name: impl Into<String>, f: F) -> super::StepRunner<'_, T>
    where
        T: serde::Serialize + serde::de::DeserializeOwned + Clone + Send + Sync + 'static,
        F: Fn() -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = crate::Result<T>> + Send + 'static,
    {
        super::StepRunner::new(self, name, f)
    }
900}
901
// Expose the configured environment provider through the shared EnvAccess
// trait so helpers generic over EnvAccess work with workflow contexts.
impl EnvAccess for WorkflowContext {
    fn env_provider(&self) -> &dyn EnvProvider {
        self.env_provider.as_ref()
    }
}
907
908#[cfg(test)]
909#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
910mod tests {
911    use super::*;
912
913    #[tokio::test]
914    async fn test_workflow_context_creation() {
915        let pool = sqlx::postgres::PgPoolOptions::new()
916            .max_connections(1)
917            .connect_lazy("postgres://localhost/nonexistent")
918            .expect("Failed to create mock pool");
919
920        let run_id = Uuid::new_v4();
921        let ctx = WorkflowContext::new(
922            run_id,
923            "test_workflow".to_string(),
924            pool,
925            CircuitBreakerClient::with_defaults(reqwest::Client::new()),
926        );
927
928        assert_eq!(ctx.run_id, run_id);
929        assert_eq!(ctx.workflow_name, "test_workflow");
930    }
931
932    #[tokio::test]
933    async fn test_step_state_tracking() {
934        let pool = sqlx::postgres::PgPoolOptions::new()
935            .max_connections(1)
936            .connect_lazy("postgres://localhost/nonexistent")
937            .expect("Failed to create mock pool");
938
939        let ctx = WorkflowContext::new(
940            Uuid::new_v4(),
941            "test".to_string(),
942            pool,
943            CircuitBreakerClient::with_defaults(reqwest::Client::new()),
944        );
945
946        ctx.record_step_start("step1");
947        assert!(!ctx.is_step_completed("step1"));
948
949        ctx.record_step_complete("step1", serde_json::json!({"result": "ok"}));
950        assert!(ctx.is_step_completed("step1"));
951
952        let result: Option<serde_json::Value> = ctx.get_step_result("step1");
953        assert!(result.is_some());
954    }
955
956    #[test]
957    fn test_step_state_transitions() {
958        let mut state = StepState::new("test");
959        assert_eq!(state.status, StepStatus::Pending);
960
961        state.start();
962        assert_eq!(state.status, StepStatus::Running);
963        assert!(state.started_at.is_some());
964
965        state.complete(serde_json::json!({}));
966        assert_eq!(state.status, StepStatus::Completed);
967        assert!(state.completed_at.is_some());
968    }
969}