// forge_core/workflow/context.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::{Arc, RwLock};
5use std::time::Duration;
6
7use chrono::{DateTime, Utc};
8use serde::de::DeserializeOwned;
9use tokio::sync::mpsc;
10use uuid::Uuid;
11
12use super::parallel::ParallelBuilder;
13use super::step::StepStatus;
14use super::suspend::{SuspendReason, WorkflowEvent};
15use crate::env::{EnvAccess, EnvProvider, RealEnvProvider};
16use crate::function::AuthContext;
17use crate::{ForgeError, Result};
18
19/// Type alias for compensation handler function.
20pub type CompensationHandler = Arc<
21    dyn Fn(serde_json::Value) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync,
22>;
23
24/// Step state stored during execution.
25#[derive(Debug, Clone)]
26pub struct StepState {
27    /// Step name.
28    pub name: String,
29    /// Step status.
30    pub status: StepStatus,
31    /// Serialized result (if completed).
32    pub result: Option<serde_json::Value>,
33    /// Error message (if failed).
34    pub error: Option<String>,
35    /// When the step started.
36    pub started_at: Option<DateTime<Utc>>,
37    /// When the step completed.
38    pub completed_at: Option<DateTime<Utc>>,
39}
40
41impl StepState {
42    /// Create a new pending step state.
43    pub fn new(name: impl Into<String>) -> Self {
44        Self {
45            name: name.into(),
46            status: StepStatus::Pending,
47            result: None,
48            error: None,
49            started_at: None,
50            completed_at: None,
51        }
52    }
53
54    /// Mark step as running.
55    pub fn start(&mut self) {
56        self.status = StepStatus::Running;
57        self.started_at = Some(Utc::now());
58    }
59
60    /// Mark step as completed with result.
61    pub fn complete(&mut self, result: serde_json::Value) {
62        self.status = StepStatus::Completed;
63        self.result = Some(result);
64        self.completed_at = Some(Utc::now());
65    }
66
67    /// Mark step as failed with error.
68    pub fn fail(&mut self, error: impl Into<String>) {
69        self.status = StepStatus::Failed;
70        self.error = Some(error.into());
71        self.completed_at = Some(Utc::now());
72    }
73
74    /// Mark step as compensated.
75    pub fn compensate(&mut self) {
76        self.status = StepStatus::Compensated;
77    }
78}
79
80/// Context available to workflow handlers.
81pub struct WorkflowContext {
82    /// Workflow run ID.
83    pub run_id: Uuid,
84    /// Workflow name.
85    pub workflow_name: String,
86    /// Workflow version.
87    pub version: u32,
88    /// When the workflow started.
89    pub started_at: DateTime<Utc>,
90    /// Deterministic workflow time (consistent across replays).
91    workflow_time: DateTime<Utc>,
92    /// Authentication context.
93    pub auth: AuthContext,
94    /// Database pool.
95    db_pool: sqlx::PgPool,
96    /// HTTP client.
97    http_client: reqwest::Client,
98    /// Step states (for resumption).
99    step_states: Arc<RwLock<HashMap<String, StepState>>>,
100    /// Completed steps in order (for compensation).
101    completed_steps: Arc<RwLock<Vec<String>>>,
102    /// Compensation handlers for completed steps.
103    compensation_handlers: Arc<RwLock<HashMap<String, CompensationHandler>>>,
104    /// Channel for signaling suspension (sent by workflow, received by executor).
105    suspend_tx: Option<mpsc::Sender<SuspendReason>>,
106    /// Whether this is a resumed execution.
107    is_resumed: bool,
108    /// Whether this execution resumed specifically from a sleep (timer expired).
109    resumed_from_sleep: bool,
110    /// Tenant ID for multi-tenancy.
111    tenant_id: Option<Uuid>,
112    /// Environment variable provider.
113    env_provider: Arc<dyn EnvProvider>,
114}
115
116impl WorkflowContext {
117    /// Create a new workflow context.
118    pub fn new(
119        run_id: Uuid,
120        workflow_name: String,
121        version: u32,
122        db_pool: sqlx::PgPool,
123        http_client: reqwest::Client,
124    ) -> Self {
125        let now = Utc::now();
126        Self {
127            run_id,
128            workflow_name,
129            version,
130            started_at: now,
131            workflow_time: now,
132            auth: AuthContext::unauthenticated(),
133            db_pool,
134            http_client,
135            step_states: Arc::new(RwLock::new(HashMap::new())),
136            completed_steps: Arc::new(RwLock::new(Vec::new())),
137            compensation_handlers: Arc::new(RwLock::new(HashMap::new())),
138            suspend_tx: None,
139            is_resumed: false,
140            resumed_from_sleep: false,
141            tenant_id: None,
142            env_provider: Arc::new(RealEnvProvider::new()),
143        }
144    }
145
146    /// Create a resumed workflow context.
147    pub fn resumed(
148        run_id: Uuid,
149        workflow_name: String,
150        version: u32,
151        started_at: DateTime<Utc>,
152        db_pool: sqlx::PgPool,
153        http_client: reqwest::Client,
154    ) -> Self {
155        Self {
156            run_id,
157            workflow_name,
158            version,
159            started_at,
160            workflow_time: started_at,
161            auth: AuthContext::unauthenticated(),
162            db_pool,
163            http_client,
164            step_states: Arc::new(RwLock::new(HashMap::new())),
165            completed_steps: Arc::new(RwLock::new(Vec::new())),
166            compensation_handlers: Arc::new(RwLock::new(HashMap::new())),
167            suspend_tx: None,
168            is_resumed: true,
169            resumed_from_sleep: false,
170            tenant_id: None,
171            env_provider: Arc::new(RealEnvProvider::new()),
172        }
173    }
174
175    /// Set environment provider.
176    pub fn with_env_provider(mut self, provider: Arc<dyn EnvProvider>) -> Self {
177        self.env_provider = provider;
178        self
179    }
180
181    /// Mark that this context resumed from a sleep (timer expired).
182    pub fn with_resumed_from_sleep(mut self) -> Self {
183        self.resumed_from_sleep = true;
184        self
185    }
186
187    /// Set the suspend channel.
188    pub fn with_suspend_channel(mut self, tx: mpsc::Sender<SuspendReason>) -> Self {
189        self.suspend_tx = Some(tx);
190        self
191    }
192
193    /// Set the tenant ID.
194    pub fn with_tenant(mut self, tenant_id: Uuid) -> Self {
195        self.tenant_id = Some(tenant_id);
196        self
197    }
198
199    pub fn tenant_id(&self) -> Option<Uuid> {
200        self.tenant_id
201    }
202
203    pub fn is_resumed(&self) -> bool {
204        self.is_resumed
205    }
206
207    pub fn workflow_time(&self) -> DateTime<Utc> {
208        self.workflow_time
209    }
210
211    pub fn db(&self) -> &sqlx::PgPool {
212        &self.db_pool
213    }
214
215    /// Returns a `DbConn` wrapping the pool for shared helper functions.
216    pub fn db_conn(&self) -> crate::function::DbConn<'_> {
217        crate::function::DbConn::Pool(&self.db_pool)
218    }
219
220    /// Acquire a connection compatible with sqlx compile-time checked macros.
221    pub async fn conn(&self) -> sqlx::Result<crate::function::ForgeConn<'static>> {
222        Ok(crate::function::ForgeConn::Pool(
223            self.db_pool.acquire().await?,
224        ))
225    }
226
227    pub fn http(&self) -> &reqwest::Client {
228        &self.http_client
229    }
230
231    /// Set authentication context.
232    pub fn with_auth(mut self, auth: AuthContext) -> Self {
233        self.auth = auth;
234        self
235    }
236
237    /// Restore step states from persisted data.
238    pub fn with_step_states(self, states: HashMap<String, StepState>) -> Self {
239        let completed: Vec<String> = states
240            .iter()
241            .filter(|(_, s)| s.status == StepStatus::Completed)
242            .map(|(name, _)| name.clone())
243            .collect();
244
245        *self.step_states.write().expect("workflow lock poisoned") = states;
246        *self
247            .completed_steps
248            .write()
249            .expect("workflow lock poisoned") = completed;
250        self
251    }
252
253    pub fn get_step_state(&self, name: &str) -> Option<StepState> {
254        self.step_states
255            .read()
256            .expect("workflow lock poisoned")
257            .get(name)
258            .cloned()
259    }
260
261    pub fn is_step_completed(&self, name: &str) -> bool {
262        self.step_states
263            .read()
264            .expect("workflow lock poisoned")
265            .get(name)
266            .map(|s| s.status == StepStatus::Completed)
267            .unwrap_or(false)
268    }
269
270    /// Check if a step has been started (running, completed, or failed).
271    ///
272    /// Use this to guard steps that should only execute once, even across
273    /// workflow suspension and resumption.
274    pub fn is_step_started(&self, name: &str) -> bool {
275        self.step_states
276            .read()
277            .expect("workflow lock poisoned")
278            .get(name)
279            .map(|s| s.status != StepStatus::Pending)
280            .unwrap_or(false)
281    }
282
283    pub fn get_step_result<T: serde::de::DeserializeOwned>(&self, name: &str) -> Option<T> {
284        self.step_states
285            .read()
286            .expect("workflow lock poisoned")
287            .get(name)
288            .and_then(|s| s.result.as_ref())
289            .and_then(|v| serde_json::from_value(v.clone()).ok())
290    }
291
292    /// Record step start.
293    ///
294    /// If the step is already running or beyond (completed/failed), this is a no-op.
295    /// This prevents race conditions when resuming workflows.
296    pub fn record_step_start(&self, name: &str) {
297        let mut states = self.step_states.write().expect("workflow lock poisoned");
298        let state = states
299            .entry(name.to_string())
300            .or_insert_with(|| StepState::new(name));
301
302        // Only update if step is pending - prevents race condition on resume
303        // where background DB update could overwrite a completed status
304        if state.status != StepStatus::Pending {
305            return;
306        }
307
308        state.start();
309        let state_clone = state.clone();
310        drop(states);
311
312        // Persist to database in background
313        let pool = self.db_pool.clone();
314        let run_id = self.run_id;
315        let step_name = name.to_string();
316        tokio::spawn(async move {
317            let step_id = Uuid::new_v4();
318            if let Err(e) = sqlx::query(
319                r#"
320                INSERT INTO forge_workflow_steps (id, workflow_run_id, step_name, status, started_at)
321                VALUES ($1, $2, $3, $4, $5)
322                ON CONFLICT (workflow_run_id, step_name) DO NOTHING
323                "#,
324            )
325            .bind(step_id)
326            .bind(run_id)
327            .bind(&step_name)
328            .bind(state_clone.status.as_str())
329            .bind(state_clone.started_at)
330            .execute(&pool)
331            .await
332            {
333                tracing::warn!(
334                    workflow_run_id = %run_id,
335                    step = %step_name,
336                    "Failed to persist step start: {}",
337                    e
338                );
339            }
340        });
341    }
342
343    /// Record step completion (fire-and-forget database update).
344    /// Use `record_step_complete_async` if you need to ensure persistence before continuing.
345    pub fn record_step_complete(&self, name: &str, result: serde_json::Value) {
346        let state_clone = self.update_step_state_complete(name, result);
347
348        // Persist to database in background
349        if let Some(state) = state_clone {
350            let pool = self.db_pool.clone();
351            let run_id = self.run_id;
352            let step_name = name.to_string();
353            tokio::spawn(async move {
354                Self::persist_step_complete(&pool, run_id, &step_name, &state).await;
355            });
356        }
357    }
358
359    /// Record step completion and wait for database persistence.
360    pub async fn record_step_complete_async(&self, name: &str, result: serde_json::Value) {
361        let state_clone = self.update_step_state_complete(name, result);
362
363        // Persist to database synchronously
364        if let Some(state) = state_clone {
365            Self::persist_step_complete(&self.db_pool, self.run_id, name, &state).await;
366        }
367    }
368
369    /// Update in-memory step state to completed.
370    fn update_step_state_complete(
371        &self,
372        name: &str,
373        result: serde_json::Value,
374    ) -> Option<StepState> {
375        let mut states = self.step_states.write().expect("workflow lock poisoned");
376        if let Some(state) = states.get_mut(name) {
377            state.complete(result.clone());
378        }
379        let state_clone = states.get(name).cloned();
380        drop(states);
381
382        let mut completed = self
383            .completed_steps
384            .write()
385            .expect("workflow lock poisoned");
386        if !completed.contains(&name.to_string()) {
387            completed.push(name.to_string());
388        }
389        drop(completed);
390
391        state_clone
392    }
393
394    /// Persist step completion to database.
395    async fn persist_step_complete(
396        pool: &sqlx::PgPool,
397        run_id: Uuid,
398        step_name: &str,
399        state: &StepState,
400    ) {
401        // Use UPSERT to handle race condition where persist_step_start hasn't completed yet
402        if let Err(e) = sqlx::query(
403            r#"
404            INSERT INTO forge_workflow_steps (id, workflow_run_id, step_name, status, result, started_at, completed_at)
405            VALUES (gen_random_uuid(), $1, $2, $3, $4, $5, $6)
406            ON CONFLICT (workflow_run_id, step_name) DO UPDATE
407            SET status = $3, result = $4, completed_at = $6
408            "#,
409        )
410        .bind(run_id)
411        .bind(step_name)
412        .bind(state.status.as_str())
413        .bind(&state.result)
414        .bind(state.started_at)
415        .bind(state.completed_at)
416        .execute(pool)
417        .await
418        {
419            tracing::warn!(
420                workflow_run_id = %run_id,
421                step = %step_name,
422                "Failed to persist step completion: {}",
423                e
424            );
425        }
426    }
427
428    /// Record step failure.
429    pub fn record_step_failure(&self, name: &str, error: impl Into<String>) {
430        let error_str = error.into();
431        let mut states = self.step_states.write().expect("workflow lock poisoned");
432        if let Some(state) = states.get_mut(name) {
433            state.fail(error_str.clone());
434        }
435        let state_clone = states.get(name).cloned();
436        drop(states);
437
438        // Persist to database in background
439        if let Some(state) = state_clone {
440            let pool = self.db_pool.clone();
441            let run_id = self.run_id;
442            let step_name = name.to_string();
443            tokio::spawn(async move {
444                if let Err(e) = sqlx::query(
445                    r#"
446                    UPDATE forge_workflow_steps
447                    SET status = $3, error = $4, completed_at = $5
448                    WHERE workflow_run_id = $1 AND step_name = $2
449                    "#,
450                )
451                .bind(run_id)
452                .bind(&step_name)
453                .bind(state.status.as_str())
454                .bind(&state.error)
455                .bind(state.completed_at)
456                .execute(&pool)
457                .await
458                {
459                    tracing::warn!(
460                        workflow_run_id = %run_id,
461                        step = %step_name,
462                        "Failed to persist step failure: {}",
463                        e
464                    );
465                }
466            });
467        }
468    }
469
470    /// Record step compensation.
471    pub fn record_step_compensated(&self, name: &str) {
472        let mut states = self.step_states.write().expect("workflow lock poisoned");
473        if let Some(state) = states.get_mut(name) {
474            state.compensate();
475        }
476        let state_clone = states.get(name).cloned();
477        drop(states);
478
479        // Persist to database in background
480        if let Some(state) = state_clone {
481            let pool = self.db_pool.clone();
482            let run_id = self.run_id;
483            let step_name = name.to_string();
484            tokio::spawn(async move {
485                if let Err(e) = sqlx::query(
486                    r#"
487                    UPDATE forge_workflow_steps
488                    SET status = $3
489                    WHERE workflow_run_id = $1 AND step_name = $2
490                    "#,
491                )
492                .bind(run_id)
493                .bind(&step_name)
494                .bind(state.status.as_str())
495                .execute(&pool)
496                .await
497                {
498                    tracing::warn!(
499                        workflow_run_id = %run_id,
500                        step = %step_name,
501                        "Failed to persist step compensation: {}",
502                        e
503                    );
504                }
505            });
506        }
507    }
508
509    pub fn completed_steps_reversed(&self) -> Vec<String> {
510        let completed = self.completed_steps.read().expect("workflow lock poisoned");
511        completed.iter().rev().cloned().collect()
512    }
513
514    pub fn all_step_states(&self) -> HashMap<String, StepState> {
515        self.step_states
516            .read()
517            .expect("workflow lock poisoned")
518            .clone()
519    }
520
521    pub fn elapsed(&self) -> chrono::Duration {
522        Utc::now() - self.started_at
523    }
524
525    /// Register a compensation handler for a step.
526    pub fn register_compensation(&self, step_name: &str, handler: CompensationHandler) {
527        let mut handlers = self
528            .compensation_handlers
529            .write()
530            .expect("workflow lock poisoned");
531        handlers.insert(step_name.to_string(), handler);
532    }
533
534    pub fn get_compensation_handler(&self, step_name: &str) -> Option<CompensationHandler> {
535        self.compensation_handlers
536            .read()
537            .expect("workflow lock poisoned")
538            .get(step_name)
539            .cloned()
540    }
541
542    pub fn has_compensation(&self, step_name: &str) -> bool {
543        self.compensation_handlers
544            .read()
545            .expect("workflow lock poisoned")
546            .contains_key(step_name)
547    }
548
549    /// Run compensation for all completed steps in reverse order.
550    /// Returns a list of (step_name, success) tuples.
551    pub async fn run_compensation(&self) -> Vec<(String, bool)> {
552        let steps = self.completed_steps_reversed();
553        let mut results = Vec::new();
554
555        for step_name in steps {
556            let handler = self.get_compensation_handler(&step_name);
557            let result = self
558                .get_step_state(&step_name)
559                .and_then(|s| s.result.clone());
560
561            if let Some(handler) = handler {
562                let step_result = result.unwrap_or(serde_json::Value::Null);
563                match handler(step_result).await {
564                    Ok(()) => {
565                        self.record_step_compensated(&step_name);
566                        results.push((step_name, true));
567                    }
568                    Err(e) => {
569                        tracing::error!(step = %step_name, error = %e, "Compensation failed");
570                        results.push((step_name, false));
571                    }
572                }
573            } else {
574                // No compensation handler, mark as compensated anyway
575                self.record_step_compensated(&step_name);
576                results.push((step_name, true));
577            }
578        }
579
580        results
581    }
582
583    pub fn compensation_handlers(&self) -> HashMap<String, CompensationHandler> {
584        self.compensation_handlers
585            .read()
586            .expect("workflow lock poisoned")
587            .clone()
588    }
589
590    /// Sleep for a duration.
591    ///
592    /// This suspends the workflow and persists the wake time to the database.
593    /// The workflow scheduler will resume the workflow when the time arrives.
594    ///
595    /// # Example
596    /// ```ignore
597    /// // Sleep for 30 days
598    /// ctx.sleep(Duration::from_secs(30 * 24 * 60 * 60)).await?;
599    /// ```
600    pub async fn sleep(&self, duration: Duration) -> Result<()> {
601        // If we resumed from a sleep, the timer already expired - continue immediately
602        if self.resumed_from_sleep {
603            return Ok(());
604        }
605
606        let wake_at = Utc::now() + chrono::Duration::from_std(duration).unwrap_or_default();
607        self.sleep_until(wake_at).await
608    }
609
610    /// Sleep until a specific time.
611    ///
612    /// If the wake time has already passed, returns immediately.
613    ///
614    /// # Example
615    /// ```ignore
616    /// use chrono::{Utc, Duration};
617    /// let renewal_date = Utc::now() + Duration::days(30);
618    /// ctx.sleep_until(renewal_date).await?;
619    /// ```
620    pub async fn sleep_until(&self, wake_at: DateTime<Utc>) -> Result<()> {
621        // If we resumed from a sleep, the timer already expired - continue immediately
622        if self.resumed_from_sleep {
623            return Ok(());
624        }
625
626        // If wake time already passed, return immediately
627        if wake_at <= Utc::now() {
628            return Ok(());
629        }
630
631        // Persist the wake time to database
632        self.set_wake_at(wake_at).await?;
633
634        // Signal suspension to executor
635        self.signal_suspend(SuspendReason::Sleep { wake_at })
636            .await?;
637
638        Ok(())
639    }
640
641    /// Wait for an external event with optional timeout.
642    ///
643    /// The workflow suspends until the event arrives or the timeout expires.
644    /// Events are correlated by the workflow run ID.
645    ///
646    /// # Example
647    /// ```ignore
648    /// let payment: PaymentConfirmation = ctx.wait_for_event(
649    ///     "payment_confirmed",
650    ///     Some(Duration::from_secs(7 * 24 * 60 * 60)), // 7 days
651    /// ).await?;
652    /// ```
653    pub async fn wait_for_event<T: DeserializeOwned>(
654        &self,
655        event_name: &str,
656        timeout: Option<Duration>,
657    ) -> Result<T> {
658        let correlation_id = self.run_id.to_string();
659
660        // Check if event already exists (race condition handling)
661        if let Some(event) = self.try_consume_event(event_name, &correlation_id).await? {
662            return serde_json::from_value(event.payload.unwrap_or_default())
663                .map_err(|e| ForgeError::Deserialization(e.to_string()));
664        }
665
666        // Calculate timeout
667        let timeout_at =
668            timeout.map(|d| Utc::now() + chrono::Duration::from_std(d).unwrap_or_default());
669
670        // Persist waiting state
671        self.set_waiting_for_event(event_name, timeout_at).await?;
672
673        // Signal suspension
674        self.signal_suspend(SuspendReason::WaitingEvent {
675            event_name: event_name.to_string(),
676            timeout: timeout_at,
677        })
678        .await?;
679
680        // After resume, try to consume the event
681        self.try_consume_event(event_name, &correlation_id)
682            .await?
683            .and_then(|e| e.payload)
684            .and_then(|p| serde_json::from_value(p).ok())
685            .ok_or_else(|| ForgeError::Timeout(format!("Event '{}' timed out", event_name)))
686    }
687
688    /// Try to consume an event from the database.
689    #[allow(clippy::type_complexity)]
690    async fn try_consume_event(
691        &self,
692        event_name: &str,
693        correlation_id: &str,
694    ) -> Result<Option<WorkflowEvent>> {
695        let result: Option<(
696            Uuid,
697            String,
698            String,
699            Option<serde_json::Value>,
700            DateTime<Utc>,
701        )> = sqlx::query_as(
702            r#"
703                UPDATE forge_workflow_events
704                SET consumed_at = NOW(), consumed_by = $3
705                WHERE id = (
706                    SELECT id FROM forge_workflow_events
707                    WHERE event_name = $1 AND correlation_id = $2 AND consumed_at IS NULL
708                    ORDER BY created_at ASC LIMIT 1
709                    FOR UPDATE SKIP LOCKED
710                )
711                RETURNING id, event_name, correlation_id, payload, created_at
712                "#,
713        )
714        .bind(event_name)
715        .bind(correlation_id)
716        .bind(self.run_id)
717        .fetch_optional(&self.db_pool)
718        .await
719        .map_err(|e| ForgeError::Database(e.to_string()))?;
720
721        Ok(result.map(
722            |(id, event_name, correlation_id, payload, created_at)| WorkflowEvent {
723                id,
724                event_name,
725                correlation_id,
726                payload,
727                created_at,
728            },
729        ))
730    }
731
732    /// Persist wake time to database.
733    async fn set_wake_at(&self, wake_at: DateTime<Utc>) -> Result<()> {
734        sqlx::query(
735            r#"
736            UPDATE forge_workflow_runs
737            SET status = 'waiting', suspended_at = NOW(), wake_at = $2
738            WHERE id = $1
739            "#,
740        )
741        .bind(self.run_id)
742        .bind(wake_at)
743        .execute(&self.db_pool)
744        .await
745        .map_err(|e| ForgeError::Database(e.to_string()))?;
746        Ok(())
747    }
748
749    /// Persist waiting for event state to database.
750    async fn set_waiting_for_event(
751        &self,
752        event_name: &str,
753        timeout_at: Option<DateTime<Utc>>,
754    ) -> Result<()> {
755        sqlx::query(
756            r#"
757            UPDATE forge_workflow_runs
758            SET status = 'waiting', suspended_at = NOW(), waiting_for_event = $2, event_timeout_at = $3
759            WHERE id = $1
760            "#,
761        )
762        .bind(self.run_id)
763        .bind(event_name)
764        .bind(timeout_at)
765        .execute(&self.db_pool)
766        .await
767        .map_err(|e| ForgeError::Database(e.to_string()))?;
768        Ok(())
769    }
770
771    /// Signal suspension to the executor.
772    async fn signal_suspend(&self, reason: SuspendReason) -> Result<()> {
773        if let Some(ref tx) = self.suspend_tx {
774            tx.send(reason)
775                .await
776                .map_err(|_| ForgeError::Internal("Failed to signal suspension".into()))?;
777        }
778        // Return a special error that the executor catches
779        Err(ForgeError::WorkflowSuspended)
780    }
781
782    /// Create a parallel builder for executing steps concurrently.
783    ///
784    /// # Example
785    /// ```ignore
786    /// let results = ctx.parallel()
787    ///     .step("fetch_user", || async { get_user(id).await })
788    ///     .step("fetch_orders", || async { get_orders(id).await })
789    ///     .step_with_compensate("charge_card",
790    ///         || async { charge_card(amount).await },
791    ///         |charge| async move { refund(charge.id).await })
792    ///     .run().await?;
793    ///
794    /// let user: User = results.get("fetch_user")?;
795    /// let orders: Vec<Order> = results.get("fetch_orders")?;
796    /// ```
797    pub fn parallel(&self) -> ParallelBuilder<'_> {
798        ParallelBuilder::new(self)
799    }
800
801    /// Create a step runner for executing a workflow step.
802    ///
803    /// This provides a fluent API for defining steps with retry, compensation,
804    /// timeout, and optional behavior.
805    ///
806    /// # Examples
807    ///
808    /// ```ignore
809    /// use std::time::Duration;
810    ///
811    /// // Simple step
812    /// let data = ctx.step("fetch_data", || async {
813    ///     Ok(fetch_from_api().await?)
814    /// }).run().await?;
815    ///
816    /// // Step with retry (3 attempts, 2 second delay)
817    /// ctx.step("send_email", || async {
818    ///     send_verification_email(&user.email).await
819    /// })
820    /// .retry(3, Duration::from_secs(2))
821    /// .run()
822    /// .await?;
823    ///
824    /// // Step with compensation (rollback on later failure)
825    /// let charge = ctx.step("charge_card", || async {
826    ///     charge_credit_card(&card).await
827    /// })
828    /// .compensate(|charge_result| async move {
829    ///     refund_charge(&charge_result.charge_id).await
830    /// })
831    /// .run()
832    /// .await?;
833    ///
834    /// // Optional step (failure won't trigger compensation)
835    /// ctx.step("notify_slack", || async {
836    ///     post_to_slack("User signed up!").await
837    /// })
838    /// .optional()
839    /// .run()
840    /// .await?;
841    ///
842    /// // Step with timeout
843    /// ctx.step("slow_operation", || async {
844    ///     process_large_file().await
845    /// })
846    /// .timeout(Duration::from_secs(60))
847    /// .run()
848    /// .await?;
849    /// ```
850    pub fn step<T, F, Fut>(&self, name: impl Into<String>, f: F) -> super::StepRunner<'_, T>
851    where
852        T: serde::Serialize + serde::de::DeserializeOwned + Clone + Send + Sync + 'static,
853        F: Fn() -> Fut + Send + Sync + 'static,
854        Fut: std::future::Future<Output = crate::Result<T>> + Send + 'static,
855    {
856        super::StepRunner::new(self, name, f)
857    }
858}
859
860impl EnvAccess for WorkflowContext {
861    fn env_provider(&self) -> &dyn EnvProvider {
862        self.env_provider.as_ref()
863    }
864}
865
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// Pool whose connection is deferred by `connect_lazy`, so building it needs no
    /// running database. The target address is not expected to be reachable.
    fn lazy_pool() -> sqlx::PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool")
    }

    /// Fresh version-1 context backed by [`lazy_pool`].
    fn fresh_context(run_id: Uuid, workflow: &str) -> WorkflowContext {
        WorkflowContext::new(
            run_id,
            workflow.to_string(),
            1,
            lazy_pool(),
            reqwest::Client::new(),
        )
    }

    #[tokio::test]
    async fn test_workflow_context_creation() {
        let id = Uuid::new_v4();
        let ctx = fresh_context(id, "test_workflow");

        assert_eq!(ctx.run_id, id);
        assert_eq!(ctx.workflow_name, "test_workflow");
        assert_eq!(ctx.version, 1);
    }

    #[tokio::test]
    async fn test_step_state_tracking() {
        let ctx = fresh_context(Uuid::new_v4(), "test");

        ctx.record_step_start("step1");
        assert!(!ctx.is_step_completed("step1"));

        ctx.record_step_complete("step1", serde_json::json!({"result": "ok"}));
        assert!(ctx.is_step_completed("step1"));

        let stored: Option<serde_json::Value> = ctx.get_step_result("step1");
        assert!(stored.is_some());
    }

    #[test]
    fn test_step_state_transitions() {
        let mut step = StepState::new("test");
        assert_eq!(step.status, StepStatus::Pending);

        step.start();
        assert_eq!(step.status, StepStatus::Running);
        assert!(step.started_at.is_some());

        step.complete(serde_json::json!({}));
        assert_eq!(step.status, StepStatus::Completed);
        assert!(step.completed_at.is_some());
    }
}