forge_core/workflow/
context.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::{Arc, RwLock};
5use std::time::Duration;
6
7use chrono::{DateTime, Utc};
8use serde::de::DeserializeOwned;
9use tokio::sync::mpsc;
10use uuid::Uuid;
11
12use super::parallel::ParallelBuilder;
13use super::step::StepStatus;
14use super::suspend::{SuspendReason, WorkflowEvent};
15use crate::function::AuthContext;
16use crate::{ForgeError, Result};
17
18/// Type alias for compensation handler function.
19pub type CompensationHandler = Arc<
20    dyn Fn(serde_json::Value) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync,
21>;
22
23/// Step state stored during execution.
24#[derive(Debug, Clone)]
25pub struct StepState {
26    /// Step name.
27    pub name: String,
28    /// Step status.
29    pub status: StepStatus,
30    /// Serialized result (if completed).
31    pub result: Option<serde_json::Value>,
32    /// Error message (if failed).
33    pub error: Option<String>,
34    /// When the step started.
35    pub started_at: Option<DateTime<Utc>>,
36    /// When the step completed.
37    pub completed_at: Option<DateTime<Utc>>,
38}
39
40impl StepState {
41    /// Create a new pending step state.
42    pub fn new(name: impl Into<String>) -> Self {
43        Self {
44            name: name.into(),
45            status: StepStatus::Pending,
46            result: None,
47            error: None,
48            started_at: None,
49            completed_at: None,
50        }
51    }
52
53    /// Mark step as running.
54    pub fn start(&mut self) {
55        self.status = StepStatus::Running;
56        self.started_at = Some(Utc::now());
57    }
58
59    /// Mark step as completed with result.
60    pub fn complete(&mut self, result: serde_json::Value) {
61        self.status = StepStatus::Completed;
62        self.result = Some(result);
63        self.completed_at = Some(Utc::now());
64    }
65
66    /// Mark step as failed with error.
67    pub fn fail(&mut self, error: impl Into<String>) {
68        self.status = StepStatus::Failed;
69        self.error = Some(error.into());
70        self.completed_at = Some(Utc::now());
71    }
72
73    /// Mark step as compensated.
74    pub fn compensate(&mut self) {
75        self.status = StepStatus::Compensated;
76    }
77}
78
79/// Context available to workflow handlers.
80pub struct WorkflowContext {
81    /// Workflow run ID.
82    pub run_id: Uuid,
83    /// Workflow name.
84    pub workflow_name: String,
85    /// Workflow version.
86    pub version: u32,
87    /// When the workflow started.
88    pub started_at: DateTime<Utc>,
89    /// Deterministic workflow time (consistent across replays).
90    workflow_time: DateTime<Utc>,
91    /// Authentication context.
92    pub auth: AuthContext,
93    /// Database pool.
94    db_pool: sqlx::PgPool,
95    /// HTTP client.
96    http_client: reqwest::Client,
97    /// Step states (for resumption).
98    step_states: Arc<RwLock<HashMap<String, StepState>>>,
99    /// Completed steps in order (for compensation).
100    completed_steps: Arc<RwLock<Vec<String>>>,
101    /// Compensation handlers for completed steps.
102    compensation_handlers: Arc<RwLock<HashMap<String, CompensationHandler>>>,
103    /// Channel for signaling suspension (sent by workflow, received by executor).
104    suspend_tx: Option<mpsc::Sender<SuspendReason>>,
105    /// Whether this is a resumed execution.
106    is_resumed: bool,
107    /// Whether this execution resumed specifically from a sleep (timer expired).
108    resumed_from_sleep: bool,
109    /// Tenant ID for multi-tenancy.
110    tenant_id: Option<Uuid>,
111}
112
113impl WorkflowContext {
114    /// Create a new workflow context.
115    pub fn new(
116        run_id: Uuid,
117        workflow_name: String,
118        version: u32,
119        db_pool: sqlx::PgPool,
120        http_client: reqwest::Client,
121    ) -> Self {
122        let now = Utc::now();
123        Self {
124            run_id,
125            workflow_name,
126            version,
127            started_at: now,
128            workflow_time: now,
129            auth: AuthContext::unauthenticated(),
130            db_pool,
131            http_client,
132            step_states: Arc::new(RwLock::new(HashMap::new())),
133            completed_steps: Arc::new(RwLock::new(Vec::new())),
134            compensation_handlers: Arc::new(RwLock::new(HashMap::new())),
135            suspend_tx: None,
136            is_resumed: false,
137            resumed_from_sleep: false,
138            tenant_id: None,
139        }
140    }
141
142    /// Create a resumed workflow context.
143    pub fn resumed(
144        run_id: Uuid,
145        workflow_name: String,
146        version: u32,
147        started_at: DateTime<Utc>,
148        db_pool: sqlx::PgPool,
149        http_client: reqwest::Client,
150    ) -> Self {
151        Self {
152            run_id,
153            workflow_name,
154            version,
155            started_at,
156            workflow_time: started_at,
157            auth: AuthContext::unauthenticated(),
158            db_pool,
159            http_client,
160            step_states: Arc::new(RwLock::new(HashMap::new())),
161            completed_steps: Arc::new(RwLock::new(Vec::new())),
162            compensation_handlers: Arc::new(RwLock::new(HashMap::new())),
163            suspend_tx: None,
164            is_resumed: true,
165            resumed_from_sleep: false,
166            tenant_id: None,
167        }
168    }
169
170    /// Mark that this context resumed from a sleep (timer expired).
171    pub fn with_resumed_from_sleep(mut self) -> Self {
172        self.resumed_from_sleep = true;
173        self
174    }
175
176    /// Set the suspend channel.
177    pub fn with_suspend_channel(mut self, tx: mpsc::Sender<SuspendReason>) -> Self {
178        self.suspend_tx = Some(tx);
179        self
180    }
181
182    /// Set the tenant ID.
183    pub fn with_tenant(mut self, tenant_id: Uuid) -> Self {
184        self.tenant_id = Some(tenant_id);
185        self
186    }
187
188    /// Get the tenant ID.
189    pub fn tenant_id(&self) -> Option<Uuid> {
190        self.tenant_id
191    }
192
193    /// Check if this is a resumed execution.
194    pub fn is_resumed(&self) -> bool {
195        self.is_resumed
196    }
197
198    /// Get the deterministic workflow time.
199    pub fn workflow_time(&self) -> DateTime<Utc> {
200        self.workflow_time
201    }
202
203    /// Get the database pool.
204    pub fn db(&self) -> &sqlx::PgPool {
205        &self.db_pool
206    }
207
208    /// Get the HTTP client.
209    pub fn http(&self) -> &reqwest::Client {
210        &self.http_client
211    }
212
213    /// Set authentication context.
214    pub fn with_auth(mut self, auth: AuthContext) -> Self {
215        self.auth = auth;
216        self
217    }
218
219    /// Restore step states from persisted data.
220    pub fn with_step_states(self, states: HashMap<String, StepState>) -> Self {
221        let completed: Vec<String> = states
222            .iter()
223            .filter(|(_, s)| s.status == StepStatus::Completed)
224            .map(|(name, _)| name.clone())
225            .collect();
226
227        *self.step_states.write().unwrap() = states;
228        *self.completed_steps.write().unwrap() = completed;
229        self
230    }
231
232    /// Get step state by name.
233    pub fn get_step_state(&self, name: &str) -> Option<StepState> {
234        self.step_states.read().unwrap().get(name).cloned()
235    }
236
237    /// Check if a step is already completed.
238    pub fn is_step_completed(&self, name: &str) -> bool {
239        self.step_states
240            .read()
241            .unwrap()
242            .get(name)
243            .map(|s| s.status == StepStatus::Completed)
244            .unwrap_or(false)
245    }
246
247    /// Check if a step has been started (running, completed, or failed).
248    ///
249    /// Use this to guard steps that should only execute once, even across
250    /// workflow suspension and resumption.
251    pub fn is_step_started(&self, name: &str) -> bool {
252        self.step_states
253            .read()
254            .unwrap()
255            .get(name)
256            .map(|s| s.status != StepStatus::Pending)
257            .unwrap_or(false)
258    }
259
260    /// Get the result of a completed step.
261    pub fn get_step_result<T: serde::de::DeserializeOwned>(&self, name: &str) -> Option<T> {
262        self.step_states
263            .read()
264            .unwrap()
265            .get(name)
266            .and_then(|s| s.result.as_ref())
267            .and_then(|v| serde_json::from_value(v.clone()).ok())
268    }
269
270    /// Record step start.
271    ///
272    /// If the step is already running or beyond (completed/failed), this is a no-op.
273    /// This prevents race conditions when resuming workflows.
274    pub fn record_step_start(&self, name: &str) {
275        let mut states = self.step_states.write().unwrap();
276        let state = states
277            .entry(name.to_string())
278            .or_insert_with(|| StepState::new(name));
279
280        // Only update if step is pending - prevents race condition on resume
281        // where background DB update could overwrite a completed status
282        if state.status != StepStatus::Pending {
283            return;
284        }
285
286        state.start();
287        let state_clone = state.clone();
288        drop(states);
289
290        // Persist to database in background
291        let pool = self.db_pool.clone();
292        let run_id = self.run_id;
293        let step_name = name.to_string();
294        tokio::spawn(async move {
295            let step_id = Uuid::new_v4();
296            if let Err(e) = sqlx::query(
297                r#"
298                INSERT INTO forge_workflow_steps (id, workflow_run_id, step_name, status, started_at)
299                VALUES ($1, $2, $3, $4, $5)
300                ON CONFLICT (workflow_run_id, step_name) DO NOTHING
301                "#,
302            )
303            .bind(step_id)
304            .bind(run_id)
305            .bind(&step_name)
306            .bind(state_clone.status.as_str())
307            .bind(state_clone.started_at)
308            .execute(&pool)
309            .await
310            {
311                tracing::warn!(
312                    workflow_run_id = %run_id,
313                    step = %step_name,
314                    "Failed to persist step start: {}",
315                    e
316                );
317            }
318        });
319    }
320
321    /// Record step completion (fire-and-forget database update).
322    /// Use `record_step_complete_async` if you need to ensure persistence before continuing.
323    pub fn record_step_complete(&self, name: &str, result: serde_json::Value) {
324        let state_clone = self.update_step_state_complete(name, result);
325
326        // Persist to database in background
327        if let Some(state) = state_clone {
328            let pool = self.db_pool.clone();
329            let run_id = self.run_id;
330            let step_name = name.to_string();
331            tokio::spawn(async move {
332                Self::persist_step_complete(&pool, run_id, &step_name, &state).await;
333            });
334        }
335    }
336
337    /// Record step completion and wait for database persistence.
338    pub async fn record_step_complete_async(&self, name: &str, result: serde_json::Value) {
339        let state_clone = self.update_step_state_complete(name, result);
340
341        // Persist to database synchronously
342        if let Some(state) = state_clone {
343            Self::persist_step_complete(&self.db_pool, self.run_id, name, &state).await;
344        }
345    }
346
347    /// Update in-memory step state to completed.
348    fn update_step_state_complete(
349        &self,
350        name: &str,
351        result: serde_json::Value,
352    ) -> Option<StepState> {
353        let mut states = self.step_states.write().unwrap();
354        if let Some(state) = states.get_mut(name) {
355            state.complete(result.clone());
356        }
357        let state_clone = states.get(name).cloned();
358        drop(states);
359
360        let mut completed = self.completed_steps.write().unwrap();
361        if !completed.contains(&name.to_string()) {
362            completed.push(name.to_string());
363        }
364        drop(completed);
365
366        state_clone
367    }
368
369    /// Persist step completion to database.
370    async fn persist_step_complete(
371        pool: &sqlx::PgPool,
372        run_id: Uuid,
373        step_name: &str,
374        state: &StepState,
375    ) {
376        // Use UPSERT to handle race condition where persist_step_start hasn't completed yet
377        if let Err(e) = sqlx::query(
378            r#"
379            INSERT INTO forge_workflow_steps (id, workflow_run_id, step_name, status, result, started_at, completed_at)
380            VALUES (gen_random_uuid(), $1, $2, $3, $4, $5, $6)
381            ON CONFLICT (workflow_run_id, step_name) DO UPDATE
382            SET status = $3, result = $4, completed_at = $6
383            "#,
384        )
385        .bind(run_id)
386        .bind(step_name)
387        .bind(state.status.as_str())
388        .bind(&state.result)
389        .bind(state.started_at)
390        .bind(state.completed_at)
391        .execute(pool)
392        .await
393        {
394            tracing::warn!(
395                workflow_run_id = %run_id,
396                step = %step_name,
397                "Failed to persist step completion: {}",
398                e
399            );
400        }
401    }
402
403    /// Record step failure.
404    pub fn record_step_failure(&self, name: &str, error: impl Into<String>) {
405        let error_str = error.into();
406        let mut states = self.step_states.write().unwrap();
407        if let Some(state) = states.get_mut(name) {
408            state.fail(error_str.clone());
409        }
410        let state_clone = states.get(name).cloned();
411        drop(states);
412
413        // Persist to database in background
414        if let Some(state) = state_clone {
415            let pool = self.db_pool.clone();
416            let run_id = self.run_id;
417            let step_name = name.to_string();
418            tokio::spawn(async move {
419                if let Err(e) = sqlx::query(
420                    r#"
421                    UPDATE forge_workflow_steps
422                    SET status = $3, error = $4, completed_at = $5
423                    WHERE workflow_run_id = $1 AND step_name = $2
424                    "#,
425                )
426                .bind(run_id)
427                .bind(&step_name)
428                .bind(state.status.as_str())
429                .bind(&state.error)
430                .bind(state.completed_at)
431                .execute(&pool)
432                .await
433                {
434                    tracing::warn!(
435                        workflow_run_id = %run_id,
436                        step = %step_name,
437                        "Failed to persist step failure: {}",
438                        e
439                    );
440                }
441            });
442        }
443    }
444
445    /// Record step compensation.
446    pub fn record_step_compensated(&self, name: &str) {
447        let mut states = self.step_states.write().unwrap();
448        if let Some(state) = states.get_mut(name) {
449            state.compensate();
450        }
451        let state_clone = states.get(name).cloned();
452        drop(states);
453
454        // Persist to database in background
455        if let Some(state) = state_clone {
456            let pool = self.db_pool.clone();
457            let run_id = self.run_id;
458            let step_name = name.to_string();
459            tokio::spawn(async move {
460                if let Err(e) = sqlx::query(
461                    r#"
462                    UPDATE forge_workflow_steps
463                    SET status = $3
464                    WHERE workflow_run_id = $1 AND step_name = $2
465                    "#,
466                )
467                .bind(run_id)
468                .bind(&step_name)
469                .bind(state.status.as_str())
470                .execute(&pool)
471                .await
472                {
473                    tracing::warn!(
474                        workflow_run_id = %run_id,
475                        step = %step_name,
476                        "Failed to persist step compensation: {}",
477                        e
478                    );
479                }
480            });
481        }
482    }
483
484    /// Get completed steps in reverse order (for compensation).
485    pub fn completed_steps_reversed(&self) -> Vec<String> {
486        let completed = self.completed_steps.read().unwrap();
487        completed.iter().rev().cloned().collect()
488    }
489
490    /// Get all step states.
491    pub fn all_step_states(&self) -> HashMap<String, StepState> {
492        self.step_states.read().unwrap().clone()
493    }
494
495    /// Get elapsed time since workflow started.
496    pub fn elapsed(&self) -> chrono::Duration {
497        Utc::now() - self.started_at
498    }
499
500    /// Register a compensation handler for a step.
501    pub fn register_compensation(&self, step_name: &str, handler: CompensationHandler) {
502        let mut handlers = self.compensation_handlers.write().unwrap();
503        handlers.insert(step_name.to_string(), handler);
504    }
505
506    /// Get compensation handler for a step.
507    pub fn get_compensation_handler(&self, step_name: &str) -> Option<CompensationHandler> {
508        self.compensation_handlers
509            .read()
510            .unwrap()
511            .get(step_name)
512            .cloned()
513    }
514
515    /// Check if a step has a compensation handler.
516    pub fn has_compensation(&self, step_name: &str) -> bool {
517        self.compensation_handlers
518            .read()
519            .unwrap()
520            .contains_key(step_name)
521    }
522
523    /// Run compensation for all completed steps in reverse order.
524    /// Returns a list of (step_name, success) tuples.
525    pub async fn run_compensation(&self) -> Vec<(String, bool)> {
526        let steps = self.completed_steps_reversed();
527        let mut results = Vec::new();
528
529        for step_name in steps {
530            let handler = self.get_compensation_handler(&step_name);
531            let result = self
532                .get_step_state(&step_name)
533                .and_then(|s| s.result.clone());
534
535            if let Some(handler) = handler {
536                let step_result = result.unwrap_or(serde_json::Value::Null);
537                match handler(step_result).await {
538                    Ok(()) => {
539                        self.record_step_compensated(&step_name);
540                        results.push((step_name, true));
541                    }
542                    Err(e) => {
543                        tracing::error!(step = %step_name, error = %e, "Compensation failed");
544                        results.push((step_name, false));
545                    }
546                }
547            } else {
548                // No compensation handler, mark as compensated anyway
549                self.record_step_compensated(&step_name);
550                results.push((step_name, true));
551            }
552        }
553
554        results
555    }
556
557    /// Get compensation handlers (for cloning to executor).
558    pub fn compensation_handlers(&self) -> HashMap<String, CompensationHandler> {
559        self.compensation_handlers.read().unwrap().clone()
560    }
561
562    // =========================================================================
563    // DURABLE WORKFLOW API
564    // =========================================================================
565
566    /// Sleep for a duration.
567    ///
568    /// This suspends the workflow and persists the wake time to the database.
569    /// The workflow scheduler will resume the workflow when the time arrives.
570    ///
571    /// # Example
572    /// ```ignore
573    /// // Sleep for 30 days
574    /// ctx.sleep(Duration::from_secs(30 * 24 * 60 * 60)).await?;
575    /// ```
576    pub async fn sleep(&self, duration: Duration) -> Result<()> {
577        // If we resumed from a sleep, the timer already expired - continue immediately
578        if self.resumed_from_sleep {
579            return Ok(());
580        }
581
582        let wake_at = Utc::now() + chrono::Duration::from_std(duration).unwrap_or_default();
583        self.sleep_until(wake_at).await
584    }
585
586    /// Sleep until a specific time.
587    ///
588    /// If the wake time has already passed, returns immediately.
589    ///
590    /// # Example
591    /// ```ignore
592    /// use chrono::{Utc, Duration};
593    /// let renewal_date = Utc::now() + Duration::days(30);
594    /// ctx.sleep_until(renewal_date).await?;
595    /// ```
596    pub async fn sleep_until(&self, wake_at: DateTime<Utc>) -> Result<()> {
597        // If we resumed from a sleep, the timer already expired - continue immediately
598        if self.resumed_from_sleep {
599            return Ok(());
600        }
601
602        // If wake time already passed, return immediately
603        if wake_at <= Utc::now() {
604            return Ok(());
605        }
606
607        // Persist the wake time to database
608        self.set_wake_at(wake_at).await?;
609
610        // Signal suspension to executor
611        self.signal_suspend(SuspendReason::Sleep { wake_at })
612            .await?;
613
614        Ok(())
615    }
616
617    /// Wait for an external event with optional timeout.
618    ///
619    /// The workflow suspends until the event arrives or the timeout expires.
620    /// Events are correlated by the workflow run ID.
621    ///
622    /// # Example
623    /// ```ignore
624    /// let payment: PaymentConfirmation = ctx.wait_for_event(
625    ///     "payment_confirmed",
626    ///     Some(Duration::from_secs(7 * 24 * 60 * 60)), // 7 days
627    /// ).await?;
628    /// ```
629    pub async fn wait_for_event<T: DeserializeOwned>(
630        &self,
631        event_name: &str,
632        timeout: Option<Duration>,
633    ) -> Result<T> {
634        let correlation_id = self.run_id.to_string();
635
636        // Check if event already exists (race condition handling)
637        if let Some(event) = self.try_consume_event(event_name, &correlation_id).await? {
638            return serde_json::from_value(event.payload.unwrap_or_default())
639                .map_err(|e| ForgeError::Deserialization(e.to_string()));
640        }
641
642        // Calculate timeout
643        let timeout_at =
644            timeout.map(|d| Utc::now() + chrono::Duration::from_std(d).unwrap_or_default());
645
646        // Persist waiting state
647        self.set_waiting_for_event(event_name, timeout_at).await?;
648
649        // Signal suspension
650        self.signal_suspend(SuspendReason::WaitingEvent {
651            event_name: event_name.to_string(),
652            timeout: timeout_at,
653        })
654        .await?;
655
656        // After resume, try to consume the event
657        self.try_consume_event(event_name, &correlation_id)
658            .await?
659            .and_then(|e| e.payload)
660            .and_then(|p| serde_json::from_value(p).ok())
661            .ok_or_else(|| ForgeError::Timeout(format!("Event '{}' timed out", event_name)))
662    }
663
664    /// Try to consume an event from the database.
665    #[allow(clippy::type_complexity)]
666    async fn try_consume_event(
667        &self,
668        event_name: &str,
669        correlation_id: &str,
670    ) -> Result<Option<WorkflowEvent>> {
671        let result: Option<(
672            Uuid,
673            String,
674            String,
675            Option<serde_json::Value>,
676            DateTime<Utc>,
677        )> = sqlx::query_as(
678            r#"
679                UPDATE forge_workflow_events
680                SET consumed_at = NOW(), consumed_by = $3
681                WHERE id = (
682                    SELECT id FROM forge_workflow_events
683                    WHERE event_name = $1 AND correlation_id = $2 AND consumed_at IS NULL
684                    ORDER BY created_at ASC LIMIT 1
685                    FOR UPDATE SKIP LOCKED
686                )
687                RETURNING id, event_name, correlation_id, payload, created_at
688                "#,
689        )
690        .bind(event_name)
691        .bind(correlation_id)
692        .bind(self.run_id)
693        .fetch_optional(&self.db_pool)
694        .await
695        .map_err(|e| ForgeError::Database(e.to_string()))?;
696
697        Ok(result.map(
698            |(id, event_name, correlation_id, payload, created_at)| WorkflowEvent {
699                id,
700                event_name,
701                correlation_id,
702                payload,
703                created_at,
704            },
705        ))
706    }
707
708    /// Persist wake time to database.
709    async fn set_wake_at(&self, wake_at: DateTime<Utc>) -> Result<()> {
710        sqlx::query(
711            r#"
712            UPDATE forge_workflow_runs
713            SET status = 'waiting', suspended_at = NOW(), wake_at = $2
714            WHERE id = $1
715            "#,
716        )
717        .bind(self.run_id)
718        .bind(wake_at)
719        .execute(&self.db_pool)
720        .await
721        .map_err(|e| ForgeError::Database(e.to_string()))?;
722        Ok(())
723    }
724
725    /// Persist waiting for event state to database.
726    async fn set_waiting_for_event(
727        &self,
728        event_name: &str,
729        timeout_at: Option<DateTime<Utc>>,
730    ) -> Result<()> {
731        sqlx::query(
732            r#"
733            UPDATE forge_workflow_runs
734            SET status = 'waiting', suspended_at = NOW(), waiting_for_event = $2, event_timeout_at = $3
735            WHERE id = $1
736            "#,
737        )
738        .bind(self.run_id)
739        .bind(event_name)
740        .bind(timeout_at)
741        .execute(&self.db_pool)
742        .await
743        .map_err(|e| ForgeError::Database(e.to_string()))?;
744        Ok(())
745    }
746
747    /// Signal suspension to the executor.
748    async fn signal_suspend(&self, reason: SuspendReason) -> Result<()> {
749        if let Some(ref tx) = self.suspend_tx {
750            tx.send(reason)
751                .await
752                .map_err(|_| ForgeError::Internal("Failed to signal suspension".into()))?;
753        }
754        // Return a special error that the executor catches
755        Err(ForgeError::WorkflowSuspended)
756    }
757
758    // =========================================================================
759    // PARALLEL EXECUTION API
760    // =========================================================================
761
762    /// Create a parallel builder for executing steps concurrently.
763    ///
764    /// # Example
765    /// ```ignore
766    /// let results = ctx.parallel()
767    ///     .step("fetch_user", || async { get_user(id).await })
768    ///     .step("fetch_orders", || async { get_orders(id).await })
769    ///     .step_with_compensate("charge_card",
770    ///         || async { charge_card(amount).await },
771    ///         |charge| async move { refund(charge.id).await })
772    ///     .run().await?;
773    ///
774    /// let user: User = results.get("fetch_user")?;
775    /// let orders: Vec<Order> = results.get("fetch_orders")?;
776    /// ```
777    pub fn parallel(&self) -> ParallelBuilder<'_> {
778        ParallelBuilder::new(self)
779    }
780
781    // =========================================================================
782    // FLUENT STEP API
783    // =========================================================================
784
785    /// Create a step runner for executing a workflow step.
786    ///
787    /// This provides a fluent API for defining steps with retry, compensation,
788    /// timeout, and optional behavior.
789    ///
790    /// # Examples
791    ///
792    /// ```ignore
793    /// use std::time::Duration;
794    ///
795    /// // Simple step
796    /// let data = ctx.step("fetch_data", || async {
797    ///     Ok(fetch_from_api().await?)
798    /// }).run().await?;
799    ///
800    /// // Step with retry (3 attempts, 2 second delay)
801    /// ctx.step("send_email", || async {
802    ///     send_verification_email(&user.email).await
803    /// })
804    /// .retry(3, Duration::from_secs(2))
805    /// .run()
806    /// .await?;
807    ///
808    /// // Step with compensation (rollback on later failure)
809    /// let charge = ctx.step("charge_card", || async {
810    ///     charge_credit_card(&card).await
811    /// })
812    /// .compensate(|charge_result| async move {
813    ///     refund_charge(&charge_result.charge_id).await
814    /// })
815    /// .run()
816    /// .await?;
817    ///
818    /// // Optional step (failure won't trigger compensation)
819    /// ctx.step("notify_slack", || async {
820    ///     post_to_slack("User signed up!").await
821    /// })
822    /// .optional()
823    /// .run()
824    /// .await?;
825    ///
826    /// // Step with timeout
827    /// ctx.step("slow_operation", || async {
828    ///     process_large_file().await
829    /// })
830    /// .timeout(Duration::from_secs(60))
831    /// .run()
832    /// .await?;
833    /// ```
834    pub fn step<T, F, Fut>(&self, name: impl Into<String>, f: F) -> super::StepRunner<'_, T>
835    where
836        T: serde::Serialize + serde::de::DeserializeOwned + Clone + Send + Sync + 'static,
837        F: FnOnce() -> Fut + Send + 'static,
838        Fut: std::future::Future<Output = crate::Result<T>> + Send + 'static,
839    {
840        super::StepRunner::new(self, name, f)
841    }
842}
843
#[cfg(test)]
mod tests {
    use super::*;

    /// Pool pointing at a database that is never reached; `connect_lazy` defers I/O.
    fn lazy_pool() -> sqlx::PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool")
    }

    /// Fresh version-1 context for `name` backed by the lazy pool.
    fn context(run_id: Uuid, name: &str) -> WorkflowContext {
        WorkflowContext::new(run_id, name.to_string(), 1, lazy_pool(), reqwest::Client::new())
    }

    #[tokio::test]
    async fn test_workflow_context_creation() {
        let id = Uuid::new_v4();
        let ctx = context(id, "test_workflow");

        assert_eq!(ctx.run_id, id);
        assert_eq!(ctx.workflow_name, "test_workflow");
        assert_eq!(ctx.version, 1);
    }

    #[tokio::test]
    async fn test_step_state_tracking() {
        let ctx = context(Uuid::new_v4(), "test");

        ctx.record_step_start("step1");
        assert!(!ctx.is_step_completed("step1"));

        ctx.record_step_complete("step1", serde_json::json!({"result": "ok"}));
        assert!(ctx.is_step_completed("step1"));

        let stored: Option<serde_json::Value> = ctx.get_step_result("step1");
        assert!(stored.is_some());
    }

    #[test]
    fn test_step_state_transitions() {
        let mut state = StepState::new("test");
        assert_eq!(state.status, StepStatus::Pending);

        state.start();
        assert_eq!(state.status, StepStatus::Running);
        assert!(state.started_at.is_some());

        state.complete(serde_json::json!({}));
        assert_eq!(state.status, StepStatus::Completed);
        assert!(state.completed_at.is_some());
    }
}