Skip to main content

forge_runtime/workflow/
executor.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use tokio::sync::RwLock;
7use uuid::Uuid;
8
9use super::registry::WorkflowRegistry;
10use super::state::{WorkflowRecord, WorkflowStepRecord};
11use forge_core::CircuitBreakerClient;
12use forge_core::function::WorkflowDispatch;
13use forge_core::workflow::{CompensationHandler, StepStatus, WorkflowContext, WorkflowStatus};
14
15/// Workflow execution result.
16#[derive(Debug)]
17pub enum WorkflowResult {
18    /// Workflow completed successfully.
19    Completed(serde_json::Value),
20    /// Workflow is waiting for an external event.
21    Waiting { event_type: String },
22    /// Workflow failed.
23    Failed { error: String },
24    /// Workflow was compensated.
25    Compensated,
26}
27
28/// Compensation state for a running workflow.
29struct CompensationState {
30    handlers: HashMap<String, CompensationHandler>,
31    completed_steps: Vec<String>,
32}
33
34/// Executes workflows.
35pub struct WorkflowExecutor {
36    registry: Arc<WorkflowRegistry>,
37    pool: sqlx::PgPool,
38    http_client: CircuitBreakerClient,
39    /// Compensation state for active workflows (run_id -> state).
40    compensation_state: Arc<RwLock<HashMap<Uuid, CompensationState>>>,
41}
42
43impl WorkflowExecutor {
44    /// Create a new workflow executor.
45    pub fn new(
46        registry: Arc<WorkflowRegistry>,
47        pool: sqlx::PgPool,
48        http_client: CircuitBreakerClient,
49    ) -> Self {
50        Self {
51            registry,
52            pool,
53            http_client,
54            compensation_state: Arc::new(RwLock::new(HashMap::new())),
55        }
56    }
57
58    /// Start a new workflow.
59    /// Returns immediately with the run_id; workflow executes in the background.
60    pub async fn start<I: serde::Serialize>(
61        &self,
62        workflow_name: &str,
63        input: I,
64        owner_subject: Option<String>,
65    ) -> forge_core::Result<Uuid> {
66        let entry = self.registry.get(workflow_name).ok_or_else(|| {
67            forge_core::ForgeError::NotFound(format!("Workflow '{}' not found", workflow_name))
68        })?;
69
70        let input_value = serde_json::to_value(input)?;
71
72        let record = WorkflowRecord::new(
73            workflow_name,
74            entry.info.version,
75            input_value.clone(),
76            owner_subject,
77        );
78        let run_id = record.id;
79
80        // Clone entry data for background execution
81        let entry_info = entry.info.clone();
82        let entry_handler = entry.handler.clone();
83
84        // Persist workflow record
85        self.save_workflow(&record).await?;
86
87        // Execute workflow in background
88        let registry = self.registry.clone();
89        let pool = self.pool.clone();
90        let http_client = self.http_client.clone();
91        let compensation_state = self.compensation_state.clone();
92
93        tokio::spawn(async move {
94            let executor = WorkflowExecutor {
95                registry,
96                pool,
97                http_client,
98                compensation_state,
99            };
100            let entry = super::registry::WorkflowEntry {
101                info: entry_info,
102                handler: entry_handler,
103            };
104            if let Err(e) = executor.execute_workflow(run_id, &entry, input_value).await {
105                tracing::error!(
106                    workflow_run_id = %run_id,
107                    error = %e,
108                    "Workflow execution failed"
109                );
110            }
111        });
112
113        Ok(run_id)
114    }
115
116    /// Execute a workflow.
117    async fn execute_workflow(
118        &self,
119        run_id: Uuid,
120        entry: &super::registry::WorkflowEntry,
121        input: serde_json::Value,
122    ) -> forge_core::Result<WorkflowResult> {
123        // Update status to running
124        self.update_workflow_status(run_id, WorkflowStatus::Running)
125            .await?;
126
127        // Create workflow context
128        let ctx = WorkflowContext::new(
129            run_id,
130            entry.info.name.to_string(),
131            entry.info.version,
132            self.pool.clone(),
133            self.http_client.inner().clone(),
134        );
135
136        // Execute workflow with timeout
137        let handler = entry.handler.clone();
138        let result = tokio::time::timeout(entry.info.timeout, handler(&ctx, input)).await;
139
140        // Capture compensation state after execution
141        let compensation_state = CompensationState {
142            handlers: ctx.compensation_handlers(),
143            completed_steps: ctx.completed_steps_reversed().into_iter().rev().collect(),
144        };
145        self.compensation_state
146            .write()
147            .await
148            .insert(run_id, compensation_state);
149
150        match result {
151            Ok(Ok(output)) => {
152                // Mark as completed, clean up compensation state
153                self.complete_workflow(run_id, output.clone()).await?;
154                self.compensation_state.write().await.remove(&run_id);
155                Ok(WorkflowResult::Completed(output))
156            }
157            Ok(Err(e)) => {
158                // Check if this is a suspension (not a real failure)
159                if matches!(e, forge_core::ForgeError::WorkflowSuspended) {
160                    // Workflow suspended itself (sleep or wait_for_event)
161                    // Status already set to 'waiting' by ctx.sleep() - don't mark as failed
162                    return Ok(WorkflowResult::Waiting {
163                        event_type: "timer".to_string(),
164                    });
165                }
166                // Mark as failed - compensation can be triggered via cancel
167                self.fail_workflow(run_id, &e.to_string()).await?;
168                Ok(WorkflowResult::Failed {
169                    error: e.to_string(),
170                })
171            }
172            Err(_) => {
173                // Timeout
174                self.fail_workflow(run_id, "Workflow timed out").await?;
175                Ok(WorkflowResult::Failed {
176                    error: "Workflow timed out".to_string(),
177                })
178            }
179        }
180    }
181
182    /// Execute a resumed workflow with step states loaded from the database.
183    async fn execute_workflow_resumed(
184        &self,
185        run_id: Uuid,
186        entry: &super::registry::WorkflowEntry,
187        input: serde_json::Value,
188        started_at: chrono::DateTime<chrono::Utc>,
189        from_sleep: bool,
190    ) -> forge_core::Result<WorkflowResult> {
191        // Update status to running
192        self.update_workflow_status(run_id, WorkflowStatus::Running)
193            .await?;
194
195        // Load step states from database
196        let step_records = self.get_workflow_steps(run_id).await?;
197        let mut step_states = std::collections::HashMap::new();
198        for step in step_records {
199            let status = step.status;
200            step_states.insert(
201                step.step_name.clone(),
202                forge_core::workflow::StepState {
203                    name: step.step_name,
204                    status,
205                    result: step.result,
206                    error: step.error,
207                    started_at: step.started_at,
208                    completed_at: step.completed_at,
209                },
210            );
211        }
212
213        // Create resumed workflow context with step states
214        let mut ctx = WorkflowContext::resumed(
215            run_id,
216            entry.info.name.to_string(),
217            entry.info.version,
218            started_at,
219            self.pool.clone(),
220            self.http_client.inner().clone(),
221        )
222        .with_step_states(step_states);
223
224        // If resuming from a sleep timer, mark the context so sleep() returns immediately
225        if from_sleep {
226            ctx = ctx.with_resumed_from_sleep();
227        }
228
229        // Execute workflow with timeout
230        let handler = entry.handler.clone();
231        let result = tokio::time::timeout(entry.info.timeout, handler(&ctx, input)).await;
232
233        // Capture compensation state after execution
234        let compensation_state = CompensationState {
235            handlers: ctx.compensation_handlers(),
236            completed_steps: ctx.completed_steps_reversed().into_iter().rev().collect(),
237        };
238        self.compensation_state
239            .write()
240            .await
241            .insert(run_id, compensation_state);
242
243        match result {
244            Ok(Ok(output)) => {
245                // Mark as completed, clean up compensation state
246                self.complete_workflow(run_id, output.clone()).await?;
247                self.compensation_state.write().await.remove(&run_id);
248                Ok(WorkflowResult::Completed(output))
249            }
250            Ok(Err(e)) => {
251                // Check if this is a suspension (not a real failure)
252                if matches!(e, forge_core::ForgeError::WorkflowSuspended) {
253                    // Workflow suspended itself (sleep or wait_for_event)
254                    // Status already set to 'waiting' by ctx.sleep() - don't mark as failed
255                    return Ok(WorkflowResult::Waiting {
256                        event_type: "timer".to_string(),
257                    });
258                }
259                // Mark as failed - compensation can be triggered via cancel
260                self.fail_workflow(run_id, &e.to_string()).await?;
261                Ok(WorkflowResult::Failed {
262                    error: e.to_string(),
263                })
264            }
265            Err(_) => {
266                // Timeout
267                self.fail_workflow(run_id, "Workflow timed out").await?;
268                Ok(WorkflowResult::Failed {
269                    error: "Workflow timed out".to_string(),
270                })
271            }
272        }
273    }
274
275    /// Resume a workflow from where it left off.
276    pub async fn resume(&self, run_id: Uuid) -> forge_core::Result<WorkflowResult> {
277        self.resume_internal(run_id, false).await
278    }
279
280    /// Resume a workflow after a sleep timer expired.
281    pub async fn resume_from_sleep(&self, run_id: Uuid) -> forge_core::Result<WorkflowResult> {
282        self.resume_internal(run_id, true).await
283    }
284
285    /// Internal resume logic.
286    async fn resume_internal(
287        &self,
288        run_id: Uuid,
289        from_sleep: bool,
290    ) -> forge_core::Result<WorkflowResult> {
291        let record = self.get_workflow(run_id).await?;
292
293        let entry = self
294            .registry
295            .get_version(&record.workflow_name, record.version)
296            .ok_or_else(|| {
297                forge_core::ForgeError::NotFound(format!(
298                    "Workflow '{}' version {} not found",
299                    record.workflow_name, record.version
300                ))
301            })?;
302
303        // Check if workflow is resumable
304        match record.status {
305            WorkflowStatus::Running | WorkflowStatus::Waiting => {
306                // Can resume
307            }
308            status if status.is_terminal() => {
309                return Err(forge_core::ForgeError::Validation(format!(
310                    "Cannot resume workflow in {} state",
311                    status.as_str()
312                )));
313            }
314            _ => {}
315        }
316
317        self.execute_workflow_resumed(run_id, entry, record.input, record.started_at, from_sleep)
318            .await
319    }
320
321    /// Get workflow status.
322    pub async fn status(&self, run_id: Uuid) -> forge_core::Result<WorkflowRecord> {
323        self.get_workflow(run_id).await
324    }
325
326    /// Cancel a workflow and run compensation.
327    ///
328    /// Compensation follows the saga pattern: steps are undone in reverse order
329    /// of their completion. This ensures that dependencies are respected. For
330    /// example, if step A created a resource that step B modified, we must
331    /// undo B's modification before deleting A's resource.
332    ///
333    /// Compensation handlers receive the original step result, allowing them
334    /// to know exactly what to undo (e.g., refund the specific payment ID).
335    pub async fn cancel(&self, run_id: Uuid) -> forge_core::Result<()> {
336        self.update_workflow_status(run_id, WorkflowStatus::Compensating)
337            .await?;
338
339        // Get compensation state
340        let state = self.compensation_state.write().await.remove(&run_id);
341
342        if let Some(state) = state {
343            // Get completed steps with results from database
344            let steps = self.get_workflow_steps(run_id).await?;
345
346            // Run compensation in reverse order of completion.
347            // This is critical for maintaining consistency: if step B depends on
348            // step A's output, we must undo B before A. The completed_steps list
349            // preserves insertion order, so reversing gives us the correct undo order.
350            for step_name in state.completed_steps.iter().rev() {
351                if let Some(handler) = state.handlers.get(step_name) {
352                    // Find the step result
353                    let step_result = steps
354                        .iter()
355                        .find(|s| &s.step_name == step_name)
356                        .and_then(|s| s.result.clone())
357                        .unwrap_or(serde_json::Value::Null);
358
359                    // Run compensation handler
360                    match handler(step_result).await {
361                        Ok(()) => {
362                            tracing::info!(
363                                workflow_run_id = %run_id,
364                                step = %step_name,
365                                "Compensation completed"
366                            );
367                            self.update_step_status(run_id, step_name, StepStatus::Compensated)
368                                .await?;
369                        }
370                        Err(e) => {
371                            tracing::error!(
372                                workflow_run_id = %run_id,
373                                step = %step_name,
374                                error = %e,
375                                "Compensation failed"
376                            );
377                            // Continue with other compensations even if one fails
378                        }
379                    }
380                } else {
381                    // No handler, just mark as compensated
382                    self.update_step_status(run_id, step_name, StepStatus::Compensated)
383                        .await?;
384                }
385            }
386        } else {
387            // Fail closed: never report compensated when handlers are unavailable.
388            let msg =
389                "Compensation handlers unavailable (likely restart); refusing to mark compensated";
390            tracing::error!(workflow_run_id = %run_id, "{msg}");
391            self.fail_workflow(run_id, msg).await?;
392            return Err(forge_core::ForgeError::InvalidState(msg.to_string()));
393        }
394
395        self.update_workflow_status(run_id, WorkflowStatus::Compensated)
396            .await?;
397
398        Ok(())
399    }
400
401    /// Get workflow steps from database.
402    async fn get_workflow_steps(
403        &self,
404        workflow_run_id: Uuid,
405    ) -> forge_core::Result<Vec<WorkflowStepRecord>> {
406        let rows = sqlx::query(
407            r#"
408            SELECT id, workflow_run_id, step_name, status, result, error, started_at, completed_at
409            FROM forge_workflow_steps
410            WHERE workflow_run_id = $1
411            ORDER BY started_at ASC
412            "#,
413        )
414        .bind(workflow_run_id)
415        .fetch_all(&self.pool)
416        .await
417        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
418
419        use sqlx::Row;
420        rows.into_iter()
421            .map(|row| {
422                let status_str = row.get::<String, _>("status");
423                let status = status_str.parse().map_err(|e| {
424                    forge_core::ForgeError::Database(format!(
425                        "Invalid step status '{}': {}",
426                        status_str, e
427                    ))
428                })?;
429                Ok(WorkflowStepRecord {
430                    id: row.get("id"),
431                    workflow_run_id: row.get("workflow_run_id"),
432                    step_name: row.get("step_name"),
433                    status,
434                    result: row.get("result"),
435                    error: row.get("error"),
436                    started_at: row.get("started_at"),
437                    completed_at: row.get("completed_at"),
438                })
439            })
440            .collect()
441    }
442
443    /// Update step status.
444    async fn update_step_status(
445        &self,
446        workflow_run_id: Uuid,
447        step_name: &str,
448        status: StepStatus,
449    ) -> forge_core::Result<()> {
450        sqlx::query(
451            r#"
452            UPDATE forge_workflow_steps
453            SET status = $3
454            WHERE workflow_run_id = $1 AND step_name = $2
455            "#,
456        )
457        .bind(workflow_run_id)
458        .bind(step_name)
459        .bind(status.as_str())
460        .execute(&self.pool)
461        .await
462        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
463
464        Ok(())
465    }
466
467    /// Save workflow record to database.
468    async fn save_workflow(&self, record: &WorkflowRecord) -> forge_core::Result<()> {
469        sqlx::query(
470            r#"
471            INSERT INTO forge_workflow_runs (
472                id, workflow_name, version, owner_subject, input, status, current_step,
473                step_results, started_at, trace_id
474            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
475            "#,
476        )
477        .bind(record.id)
478        .bind(&record.workflow_name)
479        .bind(record.version as i32)
480        .bind(&record.owner_subject)
481        .bind(&record.input)
482        .bind(record.status.as_str())
483        .bind(&record.current_step)
484        .bind(&record.step_results)
485        .bind(record.started_at)
486        .bind(&record.trace_id)
487        .execute(&self.pool)
488        .await
489        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
490
491        Ok(())
492    }
493
494    /// Get workflow record from database.
495    async fn get_workflow(&self, run_id: Uuid) -> forge_core::Result<WorkflowRecord> {
496        let row = sqlx::query(
497            r#"
498            SELECT id, workflow_name, version, owner_subject, input, output, status, current_step,
499                   step_results, started_at, completed_at, error, trace_id
500            FROM forge_workflow_runs
501            WHERE id = $1
502            "#,
503        )
504        .bind(run_id)
505        .fetch_optional(&self.pool)
506        .await
507        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
508
509        let row = row.ok_or_else(|| {
510            forge_core::ForgeError::NotFound(format!("Workflow run {} not found", run_id))
511        })?;
512
513        use sqlx::Row;
514        let status_str = row.get::<String, _>("status");
515        let status = status_str.parse().map_err(|e| {
516            forge_core::ForgeError::Database(format!(
517                "Invalid workflow status '{}': {}",
518                status_str, e
519            ))
520        })?;
521        Ok(WorkflowRecord {
522            id: row.get("id"),
523            workflow_name: row.get("workflow_name"),
524            version: row.get::<i32, _>("version") as u32,
525            owner_subject: row.get("owner_subject"),
526            input: row.get("input"),
527            output: row.get("output"),
528            status,
529            current_step: row.get("current_step"),
530            step_results: row.get("step_results"),
531            started_at: row.get("started_at"),
532            completed_at: row.get("completed_at"),
533            error: row.get("error"),
534            trace_id: row.get("trace_id"),
535        })
536    }
537
538    /// Update workflow status.
539    async fn update_workflow_status(
540        &self,
541        run_id: Uuid,
542        status: WorkflowStatus,
543    ) -> forge_core::Result<()> {
544        sqlx::query(
545            r#"
546            UPDATE forge_workflow_runs
547            SET status = $2
548            WHERE id = $1
549            "#,
550        )
551        .bind(run_id)
552        .bind(status.as_str())
553        .execute(&self.pool)
554        .await
555        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
556
557        Ok(())
558    }
559
560    /// Mark workflow as completed.
561    async fn complete_workflow(
562        &self,
563        run_id: Uuid,
564        output: serde_json::Value,
565    ) -> forge_core::Result<()> {
566        sqlx::query(
567            r#"
568            UPDATE forge_workflow_runs
569            SET status = 'completed', output = $2, completed_at = NOW()
570            WHERE id = $1
571            "#,
572        )
573        .bind(run_id)
574        .bind(output)
575        .execute(&self.pool)
576        .await
577        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
578
579        Ok(())
580    }
581
582    /// Mark workflow as failed.
583    async fn fail_workflow(&self, run_id: Uuid, error: &str) -> forge_core::Result<()> {
584        sqlx::query(
585            r#"
586            UPDATE forge_workflow_runs
587            SET status = 'failed', error = $2, completed_at = NOW()
588            WHERE id = $1
589            "#,
590        )
591        .bind(run_id)
592        .bind(error)
593        .execute(&self.pool)
594        .await
595        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
596
597        Ok(())
598    }
599
600    /// Save step record.
601    pub async fn save_step(&self, step: &WorkflowStepRecord) -> forge_core::Result<()> {
602        sqlx::query(
603            r#"
604            INSERT INTO forge_workflow_steps (
605                id, workflow_run_id, step_name, status, result, error, started_at, completed_at
606            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
607            ON CONFLICT (workflow_run_id, step_name) DO UPDATE SET
608                status = EXCLUDED.status,
609                result = EXCLUDED.result,
610                error = EXCLUDED.error,
611                started_at = COALESCE(forge_workflow_steps.started_at, EXCLUDED.started_at),
612                completed_at = EXCLUDED.completed_at
613            "#,
614        )
615        .bind(step.id)
616        .bind(step.workflow_run_id)
617        .bind(&step.step_name)
618        .bind(step.status.as_str())
619        .bind(&step.result)
620        .bind(&step.error)
621        .bind(step.started_at)
622        .bind(step.completed_at)
623        .execute(&self.pool)
624        .await
625        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
626
627        Ok(())
628    }
629
630    /// Start a workflow by its registered name with JSON input.
631    pub async fn start_by_name(
632        &self,
633        workflow_name: &str,
634        input: serde_json::Value,
635        owner_subject: Option<String>,
636    ) -> forge_core::Result<Uuid> {
637        self.start(workflow_name, input, owner_subject).await
638    }
639}
640
641impl WorkflowDispatch for WorkflowExecutor {
642    fn get_info(&self, workflow_name: &str) -> Option<forge_core::workflow::WorkflowInfo> {
643        self.registry.get(workflow_name).map(|e| e.info.clone())
644    }
645
646    fn start_by_name(
647        &self,
648        workflow_name: &str,
649        input: serde_json::Value,
650        owner_subject: Option<String>,
651    ) -> Pin<Box<dyn Future<Output = forge_core::Result<Uuid>> + Send + '_>> {
652        let workflow_name = workflow_name.to_string();
653        Box::pin(async move {
654            self.start_by_name(&workflow_name, input, owner_subject)
655                .await
656        })
657    }
658}
659
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// Smoke test: every `WorkflowResult` variant can be constructed, and a
    /// `Completed` value matches the `Completed` pattern.
    #[test]
    fn test_workflow_result_types() {
        let completed = WorkflowResult::Completed(serde_json::json!({}));
        let _waiting = WorkflowResult::Waiting {
            event_type: "approval".to_string(),
        };
        let _failed = WorkflowResult::Failed {
            error: "test".to_string(),
        };
        let _compensated = WorkflowResult::Compensated;

        if !matches!(completed, WorkflowResult::Completed(_)) {
            panic!("Expected Completed");
        }
    }
}