// forge_core/workflow/context.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::{Arc, RwLock};
5use std::time::Duration;
6
7use chrono::{DateTime, Utc};
8use serde::de::DeserializeOwned;
9use tokio::sync::mpsc;
10use uuid::Uuid;
11
12use super::parallel::ParallelBuilder;
13use super::step::StepStatus;
14use super::suspend::{SuspendReason, WorkflowEvent};
15use crate::env::{EnvAccess, EnvProvider, RealEnvProvider};
16use crate::function::AuthContext;
17use crate::http::CircuitBreakerClient;
18use crate::{ForgeError, Result};
19
/// Type alias for compensation handler function.
///
/// A handler receives the serialized result of the step being compensated
/// (as stored in [`StepState::result`]) and returns a boxed future that
/// resolves to `Ok(())` when the rollback succeeded.
pub type CompensationHandler = Arc<
    dyn Fn(serde_json::Value) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync,
>;
24
/// Step state stored during execution.
///
/// One instance per named step; snapshots of this struct are persisted to
/// `forge_workflow_steps` so suspended workflows can be resumed.
#[derive(Debug, Clone)]
pub struct StepState {
    /// Step name (unique within a workflow run).
    pub name: String,
    /// Step status.
    pub status: StepStatus,
    /// Serialized result (if completed).
    pub result: Option<serde_json::Value>,
    /// Error message (if failed).
    pub error: Option<String>,
    /// When the step started.
    pub started_at: Option<DateTime<Utc>>,
    /// When the step completed (set on completion or failure).
    pub completed_at: Option<DateTime<Utc>>,
}
41
42impl StepState {
43    /// Create a new pending step state.
44    pub fn new(name: impl Into<String>) -> Self {
45        Self {
46            name: name.into(),
47            status: StepStatus::Pending,
48            result: None,
49            error: None,
50            started_at: None,
51            completed_at: None,
52        }
53    }
54
55    /// Mark step as running.
56    pub fn start(&mut self) {
57        self.status = StepStatus::Running;
58        self.started_at = Some(Utc::now());
59    }
60
61    /// Mark step as completed with result.
62    pub fn complete(&mut self, result: serde_json::Value) {
63        self.status = StepStatus::Completed;
64        self.result = Some(result);
65        self.completed_at = Some(Utc::now());
66    }
67
68    /// Mark step as failed with error.
69    pub fn fail(&mut self, error: impl Into<String>) {
70        self.status = StepStatus::Failed;
71        self.error = Some(error.into());
72        self.completed_at = Some(Utc::now());
73    }
74
75    /// Mark step as compensated.
76    pub fn compensate(&mut self) {
77        self.status = StepStatus::Compensated;
78    }
79}
80
/// Context available to workflow handlers.
///
/// Carries identity (run id, name, version), shared resources (DB pool,
/// HTTP client), the in-memory step/compensation bookkeeping used for
/// resumption, and the channel used to signal suspension to the executor.
pub struct WorkflowContext {
    /// Workflow run ID.
    pub run_id: Uuid,
    /// Workflow name.
    pub workflow_name: String,
    /// Workflow version.
    pub version: u32,
    /// When the workflow started.
    pub started_at: DateTime<Utc>,
    /// Deterministic workflow time (consistent across replays).
    workflow_time: DateTime<Utc>,
    /// Authentication context.
    pub auth: AuthContext,
    /// Database pool.
    db_pool: sqlx::PgPool,
    /// HTTP client.
    http_client: CircuitBreakerClient,
    /// Default timeout for outbound HTTP requests made through the
    /// circuit-breaker client. `None` means unlimited.
    http_timeout: Option<Duration>,
    /// Step states (for resumption).
    step_states: Arc<RwLock<HashMap<String, StepState>>>,
    /// Completed steps in order (for compensation, which runs in reverse).
    completed_steps: Arc<RwLock<Vec<String>>>,
    /// Compensation handlers for completed steps.
    compensation_handlers: Arc<RwLock<HashMap<String, CompensationHandler>>>,
    /// Channel for signaling suspension (sent by workflow, received by executor).
    suspend_tx: Option<mpsc::Sender<SuspendReason>>,
    /// Whether this is a resumed execution.
    is_resumed: bool,
    /// Whether this execution resumed specifically from a sleep (timer expired).
    resumed_from_sleep: bool,
    /// Tenant ID for multi-tenancy.
    tenant_id: Option<Uuid>,
    /// Environment variable provider.
    env_provider: Arc<dyn EnvProvider>,
}
119
120impl WorkflowContext {
121    /// Create a new workflow context.
122    pub fn new(
123        run_id: Uuid,
124        workflow_name: String,
125        version: u32,
126        db_pool: sqlx::PgPool,
127        http_client: CircuitBreakerClient,
128    ) -> Self {
129        let now = Utc::now();
130        Self {
131            run_id,
132            workflow_name,
133            version,
134            started_at: now,
135            workflow_time: now,
136            auth: AuthContext::unauthenticated(),
137            db_pool,
138            http_client,
139            http_timeout: None,
140            step_states: Arc::new(RwLock::new(HashMap::new())),
141            completed_steps: Arc::new(RwLock::new(Vec::new())),
142            compensation_handlers: Arc::new(RwLock::new(HashMap::new())),
143            suspend_tx: None,
144            is_resumed: false,
145            resumed_from_sleep: false,
146            tenant_id: None,
147            env_provider: Arc::new(RealEnvProvider::new()),
148        }
149    }
150
151    /// Create a resumed workflow context.
152    pub fn resumed(
153        run_id: Uuid,
154        workflow_name: String,
155        version: u32,
156        started_at: DateTime<Utc>,
157        db_pool: sqlx::PgPool,
158        http_client: CircuitBreakerClient,
159    ) -> Self {
160        Self {
161            run_id,
162            workflow_name,
163            version,
164            started_at,
165            workflow_time: started_at,
166            auth: AuthContext::unauthenticated(),
167            db_pool,
168            http_client,
169            http_timeout: None,
170            step_states: Arc::new(RwLock::new(HashMap::new())),
171            completed_steps: Arc::new(RwLock::new(Vec::new())),
172            compensation_handlers: Arc::new(RwLock::new(HashMap::new())),
173            suspend_tx: None,
174            is_resumed: true,
175            resumed_from_sleep: false,
176            tenant_id: None,
177            env_provider: Arc::new(RealEnvProvider::new()),
178        }
179    }
180
181    /// Set environment provider.
182    pub fn with_env_provider(mut self, provider: Arc<dyn EnvProvider>) -> Self {
183        self.env_provider = provider;
184        self
185    }
186
187    /// Mark that this context resumed from a sleep (timer expired).
188    pub fn with_resumed_from_sleep(mut self) -> Self {
189        self.resumed_from_sleep = true;
190        self
191    }
192
193    /// Set the suspend channel.
194    pub fn with_suspend_channel(mut self, tx: mpsc::Sender<SuspendReason>) -> Self {
195        self.suspend_tx = Some(tx);
196        self
197    }
198
199    /// Set the tenant ID.
200    pub fn with_tenant(mut self, tenant_id: Uuid) -> Self {
201        self.tenant_id = Some(tenant_id);
202        self
203    }
204
    /// Tenant ID for multi-tenancy, if one was set.
    pub fn tenant_id(&self) -> Option<Uuid> {
        self.tenant_id
    }

    /// Whether this context was created via [`WorkflowContext::resumed`].
    pub fn is_resumed(&self) -> bool {
        self.is_resumed
    }

    /// Deterministic workflow time (pinned at start; consistent across replays).
    pub fn workflow_time(&self) -> DateTime<Utc> {
        self.workflow_time
    }

    /// Database handle wrapping the shared pool.
    pub fn db(&self) -> crate::function::ForgeDb {
        crate::function::ForgeDb::from_pool(&self.db_pool)
    }

    /// Acquire a connection compatible with sqlx compile-time checked macros.
    pub async fn conn(&self) -> sqlx::Result<crate::function::ForgeConn<'static>> {
        Ok(crate::function::ForgeConn::Pool(
            self.db_pool.acquire().await?,
        ))
    }

    /// HTTP client (circuit-breaker protected) with the configured default timeout applied.
    pub fn http(&self) -> crate::http::HttpClient {
        self.http_client.with_timeout(self.http_timeout)
    }

    /// Underlying reqwest client, bypassing the circuit breaker and timeout.
    pub fn raw_http(&self) -> &reqwest::Client {
        self.http_client.inner()
    }

    /// Alias for [`Self::http`]; kept for call-site readability.
    pub fn http_with_circuit_breaker(&self) -> crate::http::HttpClient {
        self.http()
    }

    /// Set the default timeout used by [`Self::http`]. `None` means unlimited.
    pub fn set_http_timeout(&mut self, timeout: Option<Duration>) {
        self.http_timeout = timeout;
    }

    /// Set authentication context.
    pub fn with_auth(mut self, auth: AuthContext) -> Self {
        self.auth = auth;
        self
    }
249
250    /// Restore step states from persisted data.
251    pub fn with_step_states(self, states: HashMap<String, StepState>) -> Self {
252        let completed: Vec<String> = states
253            .iter()
254            .filter(|(_, s)| s.status == StepStatus::Completed)
255            .map(|(name, _)| name.clone())
256            .collect();
257
258        *self.step_states.write().expect("workflow lock poisoned") = states;
259        *self
260            .completed_steps
261            .write()
262            .expect("workflow lock poisoned") = completed;
263        self
264    }
265
266    pub fn get_step_state(&self, name: &str) -> Option<StepState> {
267        self.step_states
268            .read()
269            .expect("workflow lock poisoned")
270            .get(name)
271            .cloned()
272    }
273
274    pub fn is_step_completed(&self, name: &str) -> bool {
275        self.step_states
276            .read()
277            .expect("workflow lock poisoned")
278            .get(name)
279            .map(|s| s.status == StepStatus::Completed)
280            .unwrap_or(false)
281    }
282
283    /// Check if a step has been started (running, completed, or failed).
284    ///
285    /// Use this to guard steps that should only execute once, even across
286    /// workflow suspension and resumption.
287    pub fn is_step_started(&self, name: &str) -> bool {
288        self.step_states
289            .read()
290            .expect("workflow lock poisoned")
291            .get(name)
292            .map(|s| s.status != StepStatus::Pending)
293            .unwrap_or(false)
294    }
295
296    pub fn get_step_result<T: serde::de::DeserializeOwned>(&self, name: &str) -> Option<T> {
297        self.step_states
298            .read()
299            .expect("workflow lock poisoned")
300            .get(name)
301            .and_then(|s| s.result.as_ref())
302            .and_then(|v| serde_json::from_value(v.clone()).ok())
303    }
304
    /// Record step start.
    ///
    /// If the step is already running or beyond (completed/failed), this is a no-op.
    /// This prevents race conditions when resuming workflows.
    pub fn record_step_start(&self, name: &str) {
        let mut states = self.step_states.write().expect("workflow lock poisoned");
        let state = states
            .entry(name.to_string())
            .or_insert_with(|| StepState::new(name));

        // Only update if step is pending - prevents race condition on resume
        // where background DB update could overwrite a completed status
        if state.status != StepStatus::Pending {
            return;
        }

        state.start();
        let state_clone = state.clone();
        drop(states);

        // Persist to database in background (fire-and-forget): step execution
        // must not block on this write. ON CONFLICT DO NOTHING keeps any row
        // already persisted for this (run, step) pair intact.
        let pool = self.db_pool.clone();
        let run_id = self.run_id;
        let step_name = name.to_string();
        tokio::spawn(async move {
            let step_id = Uuid::new_v4();
            if let Err(e) = sqlx::query(
                r#"
                INSERT INTO forge_workflow_steps (id, workflow_run_id, step_name, status, started_at)
                VALUES ($1, $2, $3, $4, $5)
                ON CONFLICT (workflow_run_id, step_name) DO NOTHING
                "#,
            )
            .bind(step_id)
            .bind(run_id)
            .bind(&step_name)
            .bind(state_clone.status.as_str())
            .bind(state_clone.started_at)
            .execute(&pool)
            .await
            {
                // Persistence failure is logged, not propagated: in-memory
                // state remains authoritative for this process.
                tracing::warn!(
                    workflow_run_id = %run_id,
                    step = %step_name,
                    "Failed to persist step start: {}",
                    e
                );
            }
        });
    }
355
356    /// Record step completion (fire-and-forget database update).
357    /// Use `record_step_complete_async` if you need to ensure persistence before continuing.
358    pub fn record_step_complete(&self, name: &str, result: serde_json::Value) {
359        let state_clone = self.update_step_state_complete(name, result);
360
361        // Persist to database in background
362        if let Some(state) = state_clone {
363            let pool = self.db_pool.clone();
364            let run_id = self.run_id;
365            let step_name = name.to_string();
366            tokio::spawn(async move {
367                Self::persist_step_complete(&pool, run_id, &step_name, &state).await;
368            });
369        }
370    }
371
372    /// Record step completion and wait for database persistence.
373    pub async fn record_step_complete_async(&self, name: &str, result: serde_json::Value) {
374        let state_clone = self.update_step_state_complete(name, result);
375
376        // Persist to database synchronously
377        if let Some(state) = state_clone {
378            Self::persist_step_complete(&self.db_pool, self.run_id, name, &state).await;
379        }
380    }
381
382    /// Update in-memory step state to completed.
383    fn update_step_state_complete(
384        &self,
385        name: &str,
386        result: serde_json::Value,
387    ) -> Option<StepState> {
388        let mut states = self.step_states.write().expect("workflow lock poisoned");
389        if let Some(state) = states.get_mut(name) {
390            state.complete(result.clone());
391        }
392        let state_clone = states.get(name).cloned();
393        drop(states);
394
395        let mut completed = self
396            .completed_steps
397            .write()
398            .expect("workflow lock poisoned");
399        if !completed.contains(&name.to_string()) {
400            completed.push(name.to_string());
401        }
402        drop(completed);
403
404        state_clone
405    }
406
    /// Persist step completion to database.
    ///
    /// Failures are logged and swallowed: callers treat the in-memory state
    /// as authoritative and must not fail the workflow on a persistence error.
    async fn persist_step_complete(
        pool: &sqlx::PgPool,
        run_id: Uuid,
        step_name: &str,
        state: &StepState,
    ) {
        // Use UPSERT to handle race condition where persist_step_start hasn't completed yet
        if let Err(e) = sqlx::query(
            r#"
            INSERT INTO forge_workflow_steps (id, workflow_run_id, step_name, status, result, started_at, completed_at)
            VALUES (gen_random_uuid(), $1, $2, $3, $4, $5, $6)
            ON CONFLICT (workflow_run_id, step_name) DO UPDATE
            SET status = $3, result = $4, completed_at = $6
            "#,
        )
        .bind(run_id)
        .bind(step_name)
        .bind(state.status.as_str())
        .bind(&state.result)
        .bind(state.started_at)
        .bind(state.completed_at)
        .execute(pool)
        .await
        {
            tracing::warn!(
                workflow_run_id = %run_id,
                step = %step_name,
                "Failed to persist step completion: {}",
                e
            );
        }
    }
440
441    /// Record step failure.
442    pub fn record_step_failure(&self, name: &str, error: impl Into<String>) {
443        let error_str = error.into();
444        let mut states = self.step_states.write().expect("workflow lock poisoned");
445        if let Some(state) = states.get_mut(name) {
446            state.fail(error_str.clone());
447        }
448        let state_clone = states.get(name).cloned();
449        drop(states);
450
451        // Persist to database in background
452        if let Some(state) = state_clone {
453            let pool = self.db_pool.clone();
454            let run_id = self.run_id;
455            let step_name = name.to_string();
456            tokio::spawn(async move {
457                if let Err(e) = sqlx::query(
458                    r#"
459                    UPDATE forge_workflow_steps
460                    SET status = $3, error = $4, completed_at = $5
461                    WHERE workflow_run_id = $1 AND step_name = $2
462                    "#,
463                )
464                .bind(run_id)
465                .bind(&step_name)
466                .bind(state.status.as_str())
467                .bind(&state.error)
468                .bind(state.completed_at)
469                .execute(&pool)
470                .await
471                {
472                    tracing::warn!(
473                        workflow_run_id = %run_id,
474                        step = %step_name,
475                        "Failed to persist step failure: {}",
476                        e
477                    );
478                }
479            });
480        }
481    }
482
    /// Record step compensation.
    ///
    /// Marks the in-memory state as compensated (no-op if the step is
    /// unknown) and persists the status in the background. Note this issues
    /// a plain UPDATE: it assumes the step row was already persisted by the
    /// start/complete path, which holds because compensation only runs for
    /// completed steps.
    pub fn record_step_compensated(&self, name: &str) {
        let mut states = self.step_states.write().expect("workflow lock poisoned");
        if let Some(state) = states.get_mut(name) {
            state.compensate();
        }
        let state_clone = states.get(name).cloned();
        drop(states);

        // Persist to database in background
        if let Some(state) = state_clone {
            let pool = self.db_pool.clone();
            let run_id = self.run_id;
            let step_name = name.to_string();
            tokio::spawn(async move {
                if let Err(e) = sqlx::query(
                    r#"
                    UPDATE forge_workflow_steps
                    SET status = $3
                    WHERE workflow_run_id = $1 AND step_name = $2
                    "#,
                )
                .bind(run_id)
                .bind(&step_name)
                .bind(state.status.as_str())
                .execute(&pool)
                .await
                {
                    tracing::warn!(
                        workflow_run_id = %run_id,
                        step = %step_name,
                        "Failed to persist step compensation: {}",
                        e
                    );
                }
            });
        }
    }
521
522    pub fn completed_steps_reversed(&self) -> Vec<String> {
523        let completed = self.completed_steps.read().expect("workflow lock poisoned");
524        completed.iter().rev().cloned().collect()
525    }
526
527    pub fn all_step_states(&self) -> HashMap<String, StepState> {
528        self.step_states
529            .read()
530            .expect("workflow lock poisoned")
531            .clone()
532    }
533
534    pub fn elapsed(&self) -> chrono::Duration {
535        Utc::now() - self.started_at
536    }
537
538    /// Register a compensation handler for a step.
539    pub fn register_compensation(&self, step_name: &str, handler: CompensationHandler) {
540        let mut handlers = self
541            .compensation_handlers
542            .write()
543            .expect("workflow lock poisoned");
544        handlers.insert(step_name.to_string(), handler);
545    }
546
547    pub fn get_compensation_handler(&self, step_name: &str) -> Option<CompensationHandler> {
548        self.compensation_handlers
549            .read()
550            .expect("workflow lock poisoned")
551            .get(step_name)
552            .cloned()
553    }
554
555    pub fn has_compensation(&self, step_name: &str) -> bool {
556        self.compensation_handlers
557            .read()
558            .expect("workflow lock poisoned")
559            .contains_key(step_name)
560    }
561
562    /// Run compensation for all completed steps in reverse order.
563    /// Returns a list of (step_name, success) tuples.
564    pub async fn run_compensation(&self) -> Vec<(String, bool)> {
565        let steps = self.completed_steps_reversed();
566        let mut results = Vec::new();
567
568        for step_name in steps {
569            let handler = self.get_compensation_handler(&step_name);
570            let result = self
571                .get_step_state(&step_name)
572                .and_then(|s| s.result.clone());
573
574            if let Some(handler) = handler {
575                let step_result = result.unwrap_or(serde_json::Value::Null);
576                match handler(step_result).await {
577                    Ok(()) => {
578                        self.record_step_compensated(&step_name);
579                        results.push((step_name, true));
580                    }
581                    Err(e) => {
582                        tracing::error!(step = %step_name, error = %e, "Compensation failed");
583                        results.push((step_name, false));
584                    }
585                }
586            } else {
587                // No compensation handler, mark as compensated anyway
588                self.record_step_compensated(&step_name);
589                results.push((step_name, true));
590            }
591        }
592
593        results
594    }
595
596    pub fn compensation_handlers(&self) -> HashMap<String, CompensationHandler> {
597        self.compensation_handlers
598            .read()
599            .expect("workflow lock poisoned")
600            .clone()
601    }
602
603    /// Sleep for a duration.
604    ///
605    /// This suspends the workflow and persists the wake time to the database.
606    /// The workflow scheduler will resume the workflow when the time arrives.
607    ///
608    /// # Example
609    /// ```ignore
610    /// // Sleep for 30 days
611    /// ctx.sleep(Duration::from_secs(30 * 24 * 60 * 60)).await?;
612    /// ```
613    pub async fn sleep(&self, duration: Duration) -> Result<()> {
614        // If we resumed from a sleep, the timer already expired - continue immediately
615        if self.resumed_from_sleep {
616            return Ok(());
617        }
618
619        let wake_at = Utc::now() + chrono::Duration::from_std(duration).unwrap_or_default();
620        self.sleep_until(wake_at).await
621    }
622
    /// Sleep until a specific time.
    ///
    /// If the wake time has already passed, returns immediately.
    ///
    /// On the suspend path this does NOT return `Ok`: `signal_suspend`
    /// always yields `Err(ForgeError::WorkflowSuspended)`, which the
    /// executor catches to park the run until the scheduler wakes it.
    ///
    /// # Example
    /// ```ignore
    /// use chrono::{Utc, Duration};
    /// let renewal_date = Utc::now() + Duration::days(30);
    /// ctx.sleep_until(renewal_date).await?;
    /// ```
    pub async fn sleep_until(&self, wake_at: DateTime<Utc>) -> Result<()> {
        // If we resumed from a sleep, the timer already expired - continue immediately
        if self.resumed_from_sleep {
            return Ok(());
        }

        // If wake time already passed, return immediately
        if wake_at <= Utc::now() {
            return Ok(());
        }

        // Persist the wake time to database BEFORE signaling suspension, so
        // the run is resumable by the scheduler once it suspends.
        self.set_wake_at(wake_at).await?;

        // Signal suspension to executor
        self.signal_suspend(SuspendReason::Sleep { wake_at })
            .await?;

        Ok(())
    }
653
654    /// Wait for an external event with optional timeout.
655    ///
656    /// The workflow suspends until the event arrives or the timeout expires.
657    /// Events are correlated by the workflow run ID.
658    ///
659    /// # Example
660    /// ```ignore
661    /// let payment: PaymentConfirmation = ctx.wait_for_event(
662    ///     "payment_confirmed",
663    ///     Some(Duration::from_secs(7 * 24 * 60 * 60)), // 7 days
664    /// ).await?;
665    /// ```
666    pub async fn wait_for_event<T: DeserializeOwned>(
667        &self,
668        event_name: &str,
669        timeout: Option<Duration>,
670    ) -> Result<T> {
671        let correlation_id = self.run_id.to_string();
672
673        // Check if event already exists (race condition handling)
674        if let Some(event) = self.try_consume_event(event_name, &correlation_id).await? {
675            return serde_json::from_value(event.payload.unwrap_or_default())
676                .map_err(|e| ForgeError::Deserialization(e.to_string()));
677        }
678
679        // Calculate timeout
680        let timeout_at =
681            timeout.map(|d| Utc::now() + chrono::Duration::from_std(d).unwrap_or_default());
682
683        // Persist waiting state
684        self.set_waiting_for_event(event_name, timeout_at).await?;
685
686        // Signal suspension
687        self.signal_suspend(SuspendReason::WaitingEvent {
688            event_name: event_name.to_string(),
689            timeout: timeout_at,
690        })
691        .await?;
692
693        // After resume, try to consume the event
694        self.try_consume_event(event_name, &correlation_id)
695            .await?
696            .and_then(|e| e.payload)
697            .and_then(|p| serde_json::from_value(p).ok())
698            .ok_or_else(|| ForgeError::Timeout(format!("Event '{}' timed out", event_name)))
699    }
700
    /// Try to consume an event from the database.
    ///
    /// Atomically claims the oldest unconsumed event matching the name and
    /// correlation id: the inner SELECT uses `FOR UPDATE SKIP LOCKED` so
    /// concurrent consumers skip rows another transaction already holds, and
    /// the outer UPDATE marks the row consumed by this run in the same
    /// statement. Returns `None` when no unconsumed event is available.
    #[allow(clippy::type_complexity)]
    async fn try_consume_event(
        &self,
        event_name: &str,
        correlation_id: &str,
    ) -> Result<Option<WorkflowEvent>> {
        let result = sqlx::query!(
            r#"
                UPDATE forge_workflow_events
                SET consumed_at = NOW(), consumed_by = $3
                WHERE id = (
                    SELECT id FROM forge_workflow_events
                    WHERE event_name = $1 AND correlation_id = $2 AND consumed_at IS NULL
                    ORDER BY created_at ASC LIMIT 1
                    FOR UPDATE SKIP LOCKED
                )
                RETURNING id, event_name, correlation_id, payload, created_at
                "#,
            event_name,
            correlation_id,
            self.run_id
        )
        .fetch_optional(&self.db_pool)
        .await
        .map_err(|e| ForgeError::Database(e.to_string()))?;

        Ok(result.map(|row| WorkflowEvent {
            id: row.id,
            event_name: row.event_name,
            correlation_id: row.correlation_id,
            payload: row.payload,
            created_at: row.created_at,
        }))
    }
736
    /// Persist wake time to database.
    ///
    /// Marks the run `waiting` with `wake_at` set so the scheduler can
    /// resume it when the time arrives.
    async fn set_wake_at(&self, wake_at: DateTime<Utc>) -> Result<()> {
        sqlx::query(
            r#"
            UPDATE forge_workflow_runs
            SET status = 'waiting', suspended_at = NOW(), wake_at = $2
            WHERE id = $1
            "#,
        )
        .bind(self.run_id)
        .bind(wake_at)
        .execute(&self.db_pool)
        .await
        .map_err(|e| ForgeError::Database(e.to_string()))?;
        Ok(())
    }

    /// Persist waiting for event state to database.
    ///
    /// Marks the run `waiting` on the named event; `timeout_at` of `None`
    /// means wait indefinitely.
    async fn set_waiting_for_event(
        &self,
        event_name: &str,
        timeout_at: Option<DateTime<Utc>>,
    ) -> Result<()> {
        sqlx::query(
            r#"
            UPDATE forge_workflow_runs
            SET status = 'waiting', suspended_at = NOW(), waiting_for_event = $2, event_timeout_at = $3
            WHERE id = $1
            "#,
        )
        .bind(self.run_id)
        .bind(event_name)
        .bind(timeout_at)
        .execute(&self.db_pool)
        .await
        .map_err(|e| ForgeError::Database(e.to_string()))?;
        Ok(())
    }
775
    /// Signal suspension to the executor.
    ///
    /// Sends the reason over the suspend channel (if one is attached) and
    /// then ALWAYS returns `Err(ForgeError::WorkflowSuspended)` — callers
    /// propagate that error with `?` so the executor can unwind the handler
    /// and park the run. This function never returns `Ok`.
    async fn signal_suspend(&self, reason: SuspendReason) -> Result<()> {
        if let Some(ref tx) = self.suspend_tx {
            tx.send(reason)
                .await
                .map_err(|_| ForgeError::Internal("Failed to signal suspension".into()))?;
        }
        // Return a special error that the executor catches
        Err(ForgeError::WorkflowSuspended)
    }
786
    /// Create a parallel builder for executing steps concurrently.
    ///
    /// The builder borrows this context so parallel steps share the same
    /// step-state tracking and compensation registry.
    ///
    /// # Example
    /// ```ignore
    /// let results = ctx.parallel()
    ///     .step("fetch_user", || async { get_user(id).await })
    ///     .step("fetch_orders", || async { get_orders(id).await })
    ///     .step_with_compensate("charge_card",
    ///         || async { charge_card(amount).await },
    ///         |charge| async move { refund(charge.id).await })
    ///     .run().await?;
    ///
    /// let user: User = results.get("fetch_user")?;
    /// let orders: Vec<Order> = results.get("fetch_orders")?;
    /// ```
    pub fn parallel(&self) -> ParallelBuilder<'_> {
        ParallelBuilder::new(self)
    }
805
    /// Create a step runner for executing a workflow step.
    ///
    /// This provides a fluent API for defining steps with retry, compensation,
    /// timeout, and optional behavior. The step's result type `T` must be
    /// serializable so it can be persisted and replayed on resume.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use std::time::Duration;
    ///
    /// // Simple step
    /// let data = ctx.step("fetch_data", || async {
    ///     Ok(fetch_from_api().await?)
    /// }).run().await?;
    ///
    /// // Step with retry (3 attempts, 2 second delay)
    /// ctx.step("send_email", || async {
    ///     send_verification_email(&user.email).await
    /// })
    /// .retry(3, Duration::from_secs(2))
    /// .run()
    /// .await?;
    ///
    /// // Step with compensation (rollback on later failure)
    /// let charge = ctx.step("charge_card", || async {
    ///     charge_credit_card(&card).await
    /// })
    /// .compensate(|charge_result| async move {
    ///     refund_charge(&charge_result.charge_id).await
    /// })
    /// .run()
    /// .await?;
    ///
    /// // Optional step (failure won't trigger compensation)
    /// ctx.step("notify_slack", || async {
    ///     post_to_slack("User signed up!").await
    /// })
    /// .optional()
    /// .run()
    /// .await?;
    ///
    /// // Step with timeout
    /// ctx.step("slow_operation", || async {
    ///     process_large_file().await
    /// })
    /// .timeout(Duration::from_secs(60))
    /// .run()
    /// .await?;
    /// ```
    pub fn step<T, F, Fut>(&self, name: impl Into<String>, f: F) -> super::StepRunner<'_, T>
    where
        T: serde::Serialize + serde::de::DeserializeOwned + Clone + Send + Sync + 'static,
        F: Fn() -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = crate::Result<T>> + Send + 'static,
    {
        super::StepRunner::new(self, name, f)
    }
863}
864
impl EnvAccess for WorkflowContext {
    /// Expose the configured environment provider to `EnvAccess` helpers.
    fn env_provider(&self) -> &dyn EnvProvider {
        self.env_provider.as_ref()
    }
}
870
871#[cfg(test)]
872#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
873mod tests {
874    use super::*;
875
    #[tokio::test]
    async fn test_workflow_context_creation() {
        // connect_lazy never opens a connection, so no database is needed.
        let pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool");

        let run_id = Uuid::new_v4();
        let ctx = WorkflowContext::new(
            run_id,
            "test_workflow".to_string(),
            1,
            pool,
            CircuitBreakerClient::with_defaults(reqwest::Client::new()),
        );

        assert_eq!(ctx.run_id, run_id);
        assert_eq!(ctx.workflow_name, "test_workflow");
        assert_eq!(ctx.version, 1);
    }
896
    #[tokio::test]
    async fn test_step_state_tracking() {
        // connect_lazy never opens a connection; background persistence
        // spawned by record_* will fail harmlessly (it only logs).
        let pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool");

        let ctx = WorkflowContext::new(
            Uuid::new_v4(),
            "test".to_string(),
            1,
            pool,
            CircuitBreakerClient::with_defaults(reqwest::Client::new()),
        );

        ctx.record_step_start("step1");
        assert!(!ctx.is_step_completed("step1"));

        ctx.record_step_complete("step1", serde_json::json!({"result": "ok"}));
        assert!(ctx.is_step_completed("step1"));

        let result: Option<serde_json::Value> = ctx.get_step_result("step1");
        assert!(result.is_some());
    }
921
922    #[test]
923    fn test_step_state_transitions() {
924        let mut state = StepState::new("test");
925        assert_eq!(state.status, StepStatus::Pending);
926
927        state.start();
928        assert_eq!(state.status, StepStatus::Running);
929        assert!(state.started_at.is_some());
930
931        state.complete(serde_json::json!({}));
932        assert_eq!(state.status, StepStatus::Completed);
933        assert!(state.completed_at.is_some());
934    }
935}