// forge_core/workflow/context.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::{Arc, RwLock};
5use std::time::Duration;
6
7use chrono::{DateTime, Utc};
8use serde::de::DeserializeOwned;
9use tokio::sync::mpsc;
10use uuid::Uuid;
11
12use super::parallel::ParallelBuilder;
13use super::step::StepStatus;
14use super::suspend::{SuspendReason, WorkflowEvent};
15use crate::env::{EnvAccess, EnvProvider, RealEnvProvider};
16use crate::function::AuthContext;
17use crate::{ForgeError, Result};
18
19/// Type alias for compensation handler function.
20pub type CompensationHandler = Arc<
21    dyn Fn(serde_json::Value) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync,
22>;
23
24/// Step state stored during execution.
25#[derive(Debug, Clone)]
26pub struct StepState {
27    /// Step name.
28    pub name: String,
29    /// Step status.
30    pub status: StepStatus,
31    /// Serialized result (if completed).
32    pub result: Option<serde_json::Value>,
33    /// Error message (if failed).
34    pub error: Option<String>,
35    /// When the step started.
36    pub started_at: Option<DateTime<Utc>>,
37    /// When the step completed.
38    pub completed_at: Option<DateTime<Utc>>,
39}
40
41impl StepState {
42    /// Create a new pending step state.
43    pub fn new(name: impl Into<String>) -> Self {
44        Self {
45            name: name.into(),
46            status: StepStatus::Pending,
47            result: None,
48            error: None,
49            started_at: None,
50            completed_at: None,
51        }
52    }
53
54    /// Mark step as running.
55    pub fn start(&mut self) {
56        self.status = StepStatus::Running;
57        self.started_at = Some(Utc::now());
58    }
59
60    /// Mark step as completed with result.
61    pub fn complete(&mut self, result: serde_json::Value) {
62        self.status = StepStatus::Completed;
63        self.result = Some(result);
64        self.completed_at = Some(Utc::now());
65    }
66
67    /// Mark step as failed with error.
68    pub fn fail(&mut self, error: impl Into<String>) {
69        self.status = StepStatus::Failed;
70        self.error = Some(error.into());
71        self.completed_at = Some(Utc::now());
72    }
73
74    /// Mark step as compensated.
75    pub fn compensate(&mut self) {
76        self.status = StepStatus::Compensated;
77    }
78}
79
80/// Context available to workflow handlers.
81pub struct WorkflowContext {
82    /// Workflow run ID.
83    pub run_id: Uuid,
84    /// Workflow name.
85    pub workflow_name: String,
86    /// Workflow version.
87    pub version: u32,
88    /// When the workflow started.
89    pub started_at: DateTime<Utc>,
90    /// Deterministic workflow time (consistent across replays).
91    workflow_time: DateTime<Utc>,
92    /// Authentication context.
93    pub auth: AuthContext,
94    /// Database pool.
95    db_pool: sqlx::PgPool,
96    /// HTTP client.
97    http_client: reqwest::Client,
98    /// Step states (for resumption).
99    step_states: Arc<RwLock<HashMap<String, StepState>>>,
100    /// Completed steps in order (for compensation).
101    completed_steps: Arc<RwLock<Vec<String>>>,
102    /// Compensation handlers for completed steps.
103    compensation_handlers: Arc<RwLock<HashMap<String, CompensationHandler>>>,
104    /// Channel for signaling suspension (sent by workflow, received by executor).
105    suspend_tx: Option<mpsc::Sender<SuspendReason>>,
106    /// Whether this is a resumed execution.
107    is_resumed: bool,
108    /// Whether this execution resumed specifically from a sleep (timer expired).
109    resumed_from_sleep: bool,
110    /// Tenant ID for multi-tenancy.
111    tenant_id: Option<Uuid>,
112    /// Environment variable provider.
113    env_provider: Arc<dyn EnvProvider>,
114}
115
116impl WorkflowContext {
117    /// Create a new workflow context.
118    pub fn new(
119        run_id: Uuid,
120        workflow_name: String,
121        version: u32,
122        db_pool: sqlx::PgPool,
123        http_client: reqwest::Client,
124    ) -> Self {
125        let now = Utc::now();
126        Self {
127            run_id,
128            workflow_name,
129            version,
130            started_at: now,
131            workflow_time: now,
132            auth: AuthContext::unauthenticated(),
133            db_pool,
134            http_client,
135            step_states: Arc::new(RwLock::new(HashMap::new())),
136            completed_steps: Arc::new(RwLock::new(Vec::new())),
137            compensation_handlers: Arc::new(RwLock::new(HashMap::new())),
138            suspend_tx: None,
139            is_resumed: false,
140            resumed_from_sleep: false,
141            tenant_id: None,
142            env_provider: Arc::new(RealEnvProvider::new()),
143        }
144    }
145
146    /// Create a resumed workflow context.
147    pub fn resumed(
148        run_id: Uuid,
149        workflow_name: String,
150        version: u32,
151        started_at: DateTime<Utc>,
152        db_pool: sqlx::PgPool,
153        http_client: reqwest::Client,
154    ) -> Self {
155        Self {
156            run_id,
157            workflow_name,
158            version,
159            started_at,
160            workflow_time: started_at,
161            auth: AuthContext::unauthenticated(),
162            db_pool,
163            http_client,
164            step_states: Arc::new(RwLock::new(HashMap::new())),
165            completed_steps: Arc::new(RwLock::new(Vec::new())),
166            compensation_handlers: Arc::new(RwLock::new(HashMap::new())),
167            suspend_tx: None,
168            is_resumed: true,
169            resumed_from_sleep: false,
170            tenant_id: None,
171            env_provider: Arc::new(RealEnvProvider::new()),
172        }
173    }
174
175    /// Set environment provider.
176    pub fn with_env_provider(mut self, provider: Arc<dyn EnvProvider>) -> Self {
177        self.env_provider = provider;
178        self
179    }
180
181    /// Mark that this context resumed from a sleep (timer expired).
182    pub fn with_resumed_from_sleep(mut self) -> Self {
183        self.resumed_from_sleep = true;
184        self
185    }
186
187    /// Set the suspend channel.
188    pub fn with_suspend_channel(mut self, tx: mpsc::Sender<SuspendReason>) -> Self {
189        self.suspend_tx = Some(tx);
190        self
191    }
192
193    /// Set the tenant ID.
194    pub fn with_tenant(mut self, tenant_id: Uuid) -> Self {
195        self.tenant_id = Some(tenant_id);
196        self
197    }
198
199    /// Get the tenant ID.
200    pub fn tenant_id(&self) -> Option<Uuid> {
201        self.tenant_id
202    }
203
204    /// Check if this is a resumed execution.
205    pub fn is_resumed(&self) -> bool {
206        self.is_resumed
207    }
208
209    /// Get the deterministic workflow time.
210    pub fn workflow_time(&self) -> DateTime<Utc> {
211        self.workflow_time
212    }
213
214    /// Get the database pool.
215    pub fn db(&self) -> &sqlx::PgPool {
216        &self.db_pool
217    }
218
219    /// Get the HTTP client.
220    pub fn http(&self) -> &reqwest::Client {
221        &self.http_client
222    }
223
224    /// Set authentication context.
225    pub fn with_auth(mut self, auth: AuthContext) -> Self {
226        self.auth = auth;
227        self
228    }
229
230    /// Restore step states from persisted data.
231    pub fn with_step_states(self, states: HashMap<String, StepState>) -> Self {
232        let completed: Vec<String> = states
233            .iter()
234            .filter(|(_, s)| s.status == StepStatus::Completed)
235            .map(|(name, _)| name.clone())
236            .collect();
237
238        *self.step_states.write().unwrap() = states;
239        *self.completed_steps.write().unwrap() = completed;
240        self
241    }
242
243    /// Get step state by name.
244    pub fn get_step_state(&self, name: &str) -> Option<StepState> {
245        self.step_states.read().unwrap().get(name).cloned()
246    }
247
248    /// Check if a step is already completed.
249    pub fn is_step_completed(&self, name: &str) -> bool {
250        self.step_states
251            .read()
252            .unwrap()
253            .get(name)
254            .map(|s| s.status == StepStatus::Completed)
255            .unwrap_or(false)
256    }
257
258    /// Check if a step has been started (running, completed, or failed).
259    ///
260    /// Use this to guard steps that should only execute once, even across
261    /// workflow suspension and resumption.
262    pub fn is_step_started(&self, name: &str) -> bool {
263        self.step_states
264            .read()
265            .unwrap()
266            .get(name)
267            .map(|s| s.status != StepStatus::Pending)
268            .unwrap_or(false)
269    }
270
271    /// Get the result of a completed step.
272    pub fn get_step_result<T: serde::de::DeserializeOwned>(&self, name: &str) -> Option<T> {
273        self.step_states
274            .read()
275            .unwrap()
276            .get(name)
277            .and_then(|s| s.result.as_ref())
278            .and_then(|v| serde_json::from_value(v.clone()).ok())
279    }
280
281    /// Record step start.
282    ///
283    /// If the step is already running or beyond (completed/failed), this is a no-op.
284    /// This prevents race conditions when resuming workflows.
285    pub fn record_step_start(&self, name: &str) {
286        let mut states = self.step_states.write().unwrap();
287        let state = states
288            .entry(name.to_string())
289            .or_insert_with(|| StepState::new(name));
290
291        // Only update if step is pending - prevents race condition on resume
292        // where background DB update could overwrite a completed status
293        if state.status != StepStatus::Pending {
294            return;
295        }
296
297        state.start();
298        let state_clone = state.clone();
299        drop(states);
300
301        // Persist to database in background
302        let pool = self.db_pool.clone();
303        let run_id = self.run_id;
304        let step_name = name.to_string();
305        tokio::spawn(async move {
306            let step_id = Uuid::new_v4();
307            if let Err(e) = sqlx::query(
308                r#"
309                INSERT INTO forge_workflow_steps (id, workflow_run_id, step_name, status, started_at)
310                VALUES ($1, $2, $3, $4, $5)
311                ON CONFLICT (workflow_run_id, step_name) DO NOTHING
312                "#,
313            )
314            .bind(step_id)
315            .bind(run_id)
316            .bind(&step_name)
317            .bind(state_clone.status.as_str())
318            .bind(state_clone.started_at)
319            .execute(&pool)
320            .await
321            {
322                tracing::warn!(
323                    workflow_run_id = %run_id,
324                    step = %step_name,
325                    "Failed to persist step start: {}",
326                    e
327                );
328            }
329        });
330    }
331
332    /// Record step completion (fire-and-forget database update).
333    /// Use `record_step_complete_async` if you need to ensure persistence before continuing.
334    pub fn record_step_complete(&self, name: &str, result: serde_json::Value) {
335        let state_clone = self.update_step_state_complete(name, result);
336
337        // Persist to database in background
338        if let Some(state) = state_clone {
339            let pool = self.db_pool.clone();
340            let run_id = self.run_id;
341            let step_name = name.to_string();
342            tokio::spawn(async move {
343                Self::persist_step_complete(&pool, run_id, &step_name, &state).await;
344            });
345        }
346    }
347
348    /// Record step completion and wait for database persistence.
349    pub async fn record_step_complete_async(&self, name: &str, result: serde_json::Value) {
350        let state_clone = self.update_step_state_complete(name, result);
351
352        // Persist to database synchronously
353        if let Some(state) = state_clone {
354            Self::persist_step_complete(&self.db_pool, self.run_id, name, &state).await;
355        }
356    }
357
358    /// Update in-memory step state to completed.
359    fn update_step_state_complete(
360        &self,
361        name: &str,
362        result: serde_json::Value,
363    ) -> Option<StepState> {
364        let mut states = self.step_states.write().unwrap();
365        if let Some(state) = states.get_mut(name) {
366            state.complete(result.clone());
367        }
368        let state_clone = states.get(name).cloned();
369        drop(states);
370
371        let mut completed = self.completed_steps.write().unwrap();
372        if !completed.contains(&name.to_string()) {
373            completed.push(name.to_string());
374        }
375        drop(completed);
376
377        state_clone
378    }
379
380    /// Persist step completion to database.
381    async fn persist_step_complete(
382        pool: &sqlx::PgPool,
383        run_id: Uuid,
384        step_name: &str,
385        state: &StepState,
386    ) {
387        // Use UPSERT to handle race condition where persist_step_start hasn't completed yet
388        if let Err(e) = sqlx::query(
389            r#"
390            INSERT INTO forge_workflow_steps (id, workflow_run_id, step_name, status, result, started_at, completed_at)
391            VALUES (gen_random_uuid(), $1, $2, $3, $4, $5, $6)
392            ON CONFLICT (workflow_run_id, step_name) DO UPDATE
393            SET status = $3, result = $4, completed_at = $6
394            "#,
395        )
396        .bind(run_id)
397        .bind(step_name)
398        .bind(state.status.as_str())
399        .bind(&state.result)
400        .bind(state.started_at)
401        .bind(state.completed_at)
402        .execute(pool)
403        .await
404        {
405            tracing::warn!(
406                workflow_run_id = %run_id,
407                step = %step_name,
408                "Failed to persist step completion: {}",
409                e
410            );
411        }
412    }
413
414    /// Record step failure.
415    pub fn record_step_failure(&self, name: &str, error: impl Into<String>) {
416        let error_str = error.into();
417        let mut states = self.step_states.write().unwrap();
418        if let Some(state) = states.get_mut(name) {
419            state.fail(error_str.clone());
420        }
421        let state_clone = states.get(name).cloned();
422        drop(states);
423
424        // Persist to database in background
425        if let Some(state) = state_clone {
426            let pool = self.db_pool.clone();
427            let run_id = self.run_id;
428            let step_name = name.to_string();
429            tokio::spawn(async move {
430                if let Err(e) = sqlx::query(
431                    r#"
432                    UPDATE forge_workflow_steps
433                    SET status = $3, error = $4, completed_at = $5
434                    WHERE workflow_run_id = $1 AND step_name = $2
435                    "#,
436                )
437                .bind(run_id)
438                .bind(&step_name)
439                .bind(state.status.as_str())
440                .bind(&state.error)
441                .bind(state.completed_at)
442                .execute(&pool)
443                .await
444                {
445                    tracing::warn!(
446                        workflow_run_id = %run_id,
447                        step = %step_name,
448                        "Failed to persist step failure: {}",
449                        e
450                    );
451                }
452            });
453        }
454    }
455
456    /// Record step compensation.
457    pub fn record_step_compensated(&self, name: &str) {
458        let mut states = self.step_states.write().unwrap();
459        if let Some(state) = states.get_mut(name) {
460            state.compensate();
461        }
462        let state_clone = states.get(name).cloned();
463        drop(states);
464
465        // Persist to database in background
466        if let Some(state) = state_clone {
467            let pool = self.db_pool.clone();
468            let run_id = self.run_id;
469            let step_name = name.to_string();
470            tokio::spawn(async move {
471                if let Err(e) = sqlx::query(
472                    r#"
473                    UPDATE forge_workflow_steps
474                    SET status = $3
475                    WHERE workflow_run_id = $1 AND step_name = $2
476                    "#,
477                )
478                .bind(run_id)
479                .bind(&step_name)
480                .bind(state.status.as_str())
481                .execute(&pool)
482                .await
483                {
484                    tracing::warn!(
485                        workflow_run_id = %run_id,
486                        step = %step_name,
487                        "Failed to persist step compensation: {}",
488                        e
489                    );
490                }
491            });
492        }
493    }
494
495    /// Get completed steps in reverse order (for compensation).
496    pub fn completed_steps_reversed(&self) -> Vec<String> {
497        let completed = self.completed_steps.read().unwrap();
498        completed.iter().rev().cloned().collect()
499    }
500
501    /// Get all step states.
502    pub fn all_step_states(&self) -> HashMap<String, StepState> {
503        self.step_states.read().unwrap().clone()
504    }
505
506    /// Get elapsed time since workflow started.
507    pub fn elapsed(&self) -> chrono::Duration {
508        Utc::now() - self.started_at
509    }
510
511    /// Register a compensation handler for a step.
512    pub fn register_compensation(&self, step_name: &str, handler: CompensationHandler) {
513        let mut handlers = self.compensation_handlers.write().unwrap();
514        handlers.insert(step_name.to_string(), handler);
515    }
516
517    /// Get compensation handler for a step.
518    pub fn get_compensation_handler(&self, step_name: &str) -> Option<CompensationHandler> {
519        self.compensation_handlers
520            .read()
521            .unwrap()
522            .get(step_name)
523            .cloned()
524    }
525
526    /// Check if a step has a compensation handler.
527    pub fn has_compensation(&self, step_name: &str) -> bool {
528        self.compensation_handlers
529            .read()
530            .unwrap()
531            .contains_key(step_name)
532    }
533
534    /// Run compensation for all completed steps in reverse order.
535    /// Returns a list of (step_name, success) tuples.
536    pub async fn run_compensation(&self) -> Vec<(String, bool)> {
537        let steps = self.completed_steps_reversed();
538        let mut results = Vec::new();
539
540        for step_name in steps {
541            let handler = self.get_compensation_handler(&step_name);
542            let result = self
543                .get_step_state(&step_name)
544                .and_then(|s| s.result.clone());
545
546            if let Some(handler) = handler {
547                let step_result = result.unwrap_or(serde_json::Value::Null);
548                match handler(step_result).await {
549                    Ok(()) => {
550                        self.record_step_compensated(&step_name);
551                        results.push((step_name, true));
552                    }
553                    Err(e) => {
554                        tracing::error!(step = %step_name, error = %e, "Compensation failed");
555                        results.push((step_name, false));
556                    }
557                }
558            } else {
559                // No compensation handler, mark as compensated anyway
560                self.record_step_compensated(&step_name);
561                results.push((step_name, true));
562            }
563        }
564
565        results
566    }
567
568    /// Get compensation handlers (for cloning to executor).
569    pub fn compensation_handlers(&self) -> HashMap<String, CompensationHandler> {
570        self.compensation_handlers.read().unwrap().clone()
571    }
572
573    // =========================================================================
574    // DURABLE WORKFLOW API
575    // =========================================================================
576
577    /// Sleep for a duration.
578    ///
579    /// This suspends the workflow and persists the wake time to the database.
580    /// The workflow scheduler will resume the workflow when the time arrives.
581    ///
582    /// # Example
583    /// ```ignore
584    /// // Sleep for 30 days
585    /// ctx.sleep(Duration::from_secs(30 * 24 * 60 * 60)).await?;
586    /// ```
587    pub async fn sleep(&self, duration: Duration) -> Result<()> {
588        // If we resumed from a sleep, the timer already expired - continue immediately
589        if self.resumed_from_sleep {
590            return Ok(());
591        }
592
593        let wake_at = Utc::now() + chrono::Duration::from_std(duration).unwrap_or_default();
594        self.sleep_until(wake_at).await
595    }
596
597    /// Sleep until a specific time.
598    ///
599    /// If the wake time has already passed, returns immediately.
600    ///
601    /// # Example
602    /// ```ignore
603    /// use chrono::{Utc, Duration};
604    /// let renewal_date = Utc::now() + Duration::days(30);
605    /// ctx.sleep_until(renewal_date).await?;
606    /// ```
607    pub async fn sleep_until(&self, wake_at: DateTime<Utc>) -> Result<()> {
608        // If we resumed from a sleep, the timer already expired - continue immediately
609        if self.resumed_from_sleep {
610            return Ok(());
611        }
612
613        // If wake time already passed, return immediately
614        if wake_at <= Utc::now() {
615            return Ok(());
616        }
617
618        // Persist the wake time to database
619        self.set_wake_at(wake_at).await?;
620
621        // Signal suspension to executor
622        self.signal_suspend(SuspendReason::Sleep { wake_at })
623            .await?;
624
625        Ok(())
626    }
627
628    /// Wait for an external event with optional timeout.
629    ///
630    /// The workflow suspends until the event arrives or the timeout expires.
631    /// Events are correlated by the workflow run ID.
632    ///
633    /// # Example
634    /// ```ignore
635    /// let payment: PaymentConfirmation = ctx.wait_for_event(
636    ///     "payment_confirmed",
637    ///     Some(Duration::from_secs(7 * 24 * 60 * 60)), // 7 days
638    /// ).await?;
639    /// ```
640    pub async fn wait_for_event<T: DeserializeOwned>(
641        &self,
642        event_name: &str,
643        timeout: Option<Duration>,
644    ) -> Result<T> {
645        let correlation_id = self.run_id.to_string();
646
647        // Check if event already exists (race condition handling)
648        if let Some(event) = self.try_consume_event(event_name, &correlation_id).await? {
649            return serde_json::from_value(event.payload.unwrap_or_default())
650                .map_err(|e| ForgeError::Deserialization(e.to_string()));
651        }
652
653        // Calculate timeout
654        let timeout_at =
655            timeout.map(|d| Utc::now() + chrono::Duration::from_std(d).unwrap_or_default());
656
657        // Persist waiting state
658        self.set_waiting_for_event(event_name, timeout_at).await?;
659
660        // Signal suspension
661        self.signal_suspend(SuspendReason::WaitingEvent {
662            event_name: event_name.to_string(),
663            timeout: timeout_at,
664        })
665        .await?;
666
667        // After resume, try to consume the event
668        self.try_consume_event(event_name, &correlation_id)
669            .await?
670            .and_then(|e| e.payload)
671            .and_then(|p| serde_json::from_value(p).ok())
672            .ok_or_else(|| ForgeError::Timeout(format!("Event '{}' timed out", event_name)))
673    }
674
675    /// Try to consume an event from the database.
676    #[allow(clippy::type_complexity)]
677    async fn try_consume_event(
678        &self,
679        event_name: &str,
680        correlation_id: &str,
681    ) -> Result<Option<WorkflowEvent>> {
682        let result: Option<(
683            Uuid,
684            String,
685            String,
686            Option<serde_json::Value>,
687            DateTime<Utc>,
688        )> = sqlx::query_as(
689            r#"
690                UPDATE forge_workflow_events
691                SET consumed_at = NOW(), consumed_by = $3
692                WHERE id = (
693                    SELECT id FROM forge_workflow_events
694                    WHERE event_name = $1 AND correlation_id = $2 AND consumed_at IS NULL
695                    ORDER BY created_at ASC LIMIT 1
696                    FOR UPDATE SKIP LOCKED
697                )
698                RETURNING id, event_name, correlation_id, payload, created_at
699                "#,
700        )
701        .bind(event_name)
702        .bind(correlation_id)
703        .bind(self.run_id)
704        .fetch_optional(&self.db_pool)
705        .await
706        .map_err(|e| ForgeError::Database(e.to_string()))?;
707
708        Ok(result.map(
709            |(id, event_name, correlation_id, payload, created_at)| WorkflowEvent {
710                id,
711                event_name,
712                correlation_id,
713                payload,
714                created_at,
715            },
716        ))
717    }
718
719    /// Persist wake time to database.
720    async fn set_wake_at(&self, wake_at: DateTime<Utc>) -> Result<()> {
721        sqlx::query(
722            r#"
723            UPDATE forge_workflow_runs
724            SET status = 'waiting', suspended_at = NOW(), wake_at = $2
725            WHERE id = $1
726            "#,
727        )
728        .bind(self.run_id)
729        .bind(wake_at)
730        .execute(&self.db_pool)
731        .await
732        .map_err(|e| ForgeError::Database(e.to_string()))?;
733        Ok(())
734    }
735
736    /// Persist waiting for event state to database.
737    async fn set_waiting_for_event(
738        &self,
739        event_name: &str,
740        timeout_at: Option<DateTime<Utc>>,
741    ) -> Result<()> {
742        sqlx::query(
743            r#"
744            UPDATE forge_workflow_runs
745            SET status = 'waiting', suspended_at = NOW(), waiting_for_event = $2, event_timeout_at = $3
746            WHERE id = $1
747            "#,
748        )
749        .bind(self.run_id)
750        .bind(event_name)
751        .bind(timeout_at)
752        .execute(&self.db_pool)
753        .await
754        .map_err(|e| ForgeError::Database(e.to_string()))?;
755        Ok(())
756    }
757
758    /// Signal suspension to the executor.
759    async fn signal_suspend(&self, reason: SuspendReason) -> Result<()> {
760        if let Some(ref tx) = self.suspend_tx {
761            tx.send(reason)
762                .await
763                .map_err(|_| ForgeError::Internal("Failed to signal suspension".into()))?;
764        }
765        // Return a special error that the executor catches
766        Err(ForgeError::WorkflowSuspended)
767    }
768
769    // =========================================================================
770    // PARALLEL EXECUTION API
771    // =========================================================================
772
773    /// Create a parallel builder for executing steps concurrently.
774    ///
775    /// # Example
776    /// ```ignore
777    /// let results = ctx.parallel()
778    ///     .step("fetch_user", || async { get_user(id).await })
779    ///     .step("fetch_orders", || async { get_orders(id).await })
780    ///     .step_with_compensate("charge_card",
781    ///         || async { charge_card(amount).await },
782    ///         |charge| async move { refund(charge.id).await })
783    ///     .run().await?;
784    ///
785    /// let user: User = results.get("fetch_user")?;
786    /// let orders: Vec<Order> = results.get("fetch_orders")?;
787    /// ```
788    pub fn parallel(&self) -> ParallelBuilder<'_> {
789        ParallelBuilder::new(self)
790    }
791
792    // =========================================================================
793    // FLUENT STEP API
794    // =========================================================================
795
796    /// Create a step runner for executing a workflow step.
797    ///
798    /// This provides a fluent API for defining steps with retry, compensation,
799    /// timeout, and optional behavior.
800    ///
801    /// # Examples
802    ///
803    /// ```ignore
804    /// use std::time::Duration;
805    ///
806    /// // Simple step
807    /// let data = ctx.step("fetch_data", || async {
808    ///     Ok(fetch_from_api().await?)
809    /// }).run().await?;
810    ///
811    /// // Step with retry (3 attempts, 2 second delay)
812    /// ctx.step("send_email", || async {
813    ///     send_verification_email(&user.email).await
814    /// })
815    /// .retry(3, Duration::from_secs(2))
816    /// .run()
817    /// .await?;
818    ///
819    /// // Step with compensation (rollback on later failure)
820    /// let charge = ctx.step("charge_card", || async {
821    ///     charge_credit_card(&card).await
822    /// })
823    /// .compensate(|charge_result| async move {
824    ///     refund_charge(&charge_result.charge_id).await
825    /// })
826    /// .run()
827    /// .await?;
828    ///
829    /// // Optional step (failure won't trigger compensation)
830    /// ctx.step("notify_slack", || async {
831    ///     post_to_slack("User signed up!").await
832    /// })
833    /// .optional()
834    /// .run()
835    /// .await?;
836    ///
837    /// // Step with timeout
838    /// ctx.step("slow_operation", || async {
839    ///     process_large_file().await
840    /// })
841    /// .timeout(Duration::from_secs(60))
842    /// .run()
843    /// .await?;
844    /// ```
845    pub fn step<T, F, Fut>(&self, name: impl Into<String>, f: F) -> super::StepRunner<'_, T>
846    where
847        T: serde::Serialize + serde::de::DeserializeOwned + Clone + Send + Sync + 'static,
848        F: FnOnce() -> Fut + Send + 'static,
849        Fut: std::future::Future<Output = crate::Result<T>> + Send + 'static,
850    {
851        super::StepRunner::new(self, name, f)
852    }
853}
854
855impl EnvAccess for WorkflowContext {
856    fn env_provider(&self) -> &dyn EnvProvider {
857        self.env_provider.as_ref()
858    }
859}
860
#[cfg(test)]
mod tests {
    use super::*;

    /// Pool whose connection is deferred by `connect_lazy`, so building it
    /// does not require a running Postgres server.
    fn lazy_pool() -> sqlx::PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool")
    }

    /// Fresh version-1 context for `run_id` with the given workflow name.
    fn test_context(run_id: Uuid, workflow_name: &str) -> WorkflowContext {
        WorkflowContext::new(
            run_id,
            workflow_name.to_string(),
            1,
            lazy_pool(),
            reqwest::Client::new(),
        )
    }

    #[tokio::test]
    async fn test_workflow_context_creation() {
        let run_id = Uuid::new_v4();
        let ctx = test_context(run_id, "test_workflow");

        assert_eq!(ctx.run_id, run_id);
        assert_eq!(ctx.workflow_name, "test_workflow");
        assert_eq!(ctx.version, 1);
    }

    #[tokio::test]
    async fn test_step_state_tracking() {
        let ctx = test_context(Uuid::new_v4(), "test");

        ctx.record_step_start("step1");
        assert!(!ctx.is_step_completed("step1"));

        ctx.record_step_complete("step1", serde_json::json!({"result": "ok"}));
        assert!(ctx.is_step_completed("step1"));

        assert!(ctx.get_step_result::<serde_json::Value>("step1").is_some());
    }

    #[test]
    fn test_step_state_transitions() {
        let mut state = StepState::new("test");
        assert_eq!(state.status, StepStatus::Pending);

        state.start();
        assert_eq!(state.status, StepStatus::Running);
        assert!(state.started_at.is_some());

        state.complete(serde_json::json!({}));
        assert_eq!(state.status, StepStatus::Completed);
        assert!(state.completed_at.is_some());
    }
}