// forge_runtime/workflow/executor.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use tokio::sync::RwLock;
7use uuid::Uuid;
8
9use super::registry::WorkflowRegistry;
10use super::state::{WorkflowRecord, WorkflowStepRecord};
11use forge_core::CircuitBreakerClient;
12use forge_core::function::WorkflowDispatch;
13use forge_core::workflow::{CompensationHandler, StepStatus, WorkflowContext, WorkflowStatus};
14
15/// Workflow execution result.
16#[derive(Debug)]
17pub enum WorkflowResult {
18    /// Workflow completed successfully.
19    Completed(serde_json::Value),
20    /// Workflow is waiting for an external event.
21    Waiting { event_type: String },
22    /// Workflow failed.
23    Failed { error: String },
24    /// Workflow was compensated.
25    Compensated,
26}
27
28/// Compensation state for a running workflow.
29struct CompensationState {
30    handlers: HashMap<String, CompensationHandler>,
31    completed_steps: Vec<String>,
32}
33
34/// Executes workflows.
35pub struct WorkflowExecutor {
36    registry: Arc<WorkflowRegistry>,
37    pool: sqlx::PgPool,
38    http_client: CircuitBreakerClient,
39    /// Compensation state for active workflows (run_id -> state).
40    compensation_state: Arc<RwLock<HashMap<Uuid, CompensationState>>>,
41}
42
43impl WorkflowExecutor {
44    /// Create a new workflow executor.
45    pub fn new(
46        registry: Arc<WorkflowRegistry>,
47        pool: sqlx::PgPool,
48        http_client: CircuitBreakerClient,
49    ) -> Self {
50        Self {
51            registry,
52            pool,
53            http_client,
54            compensation_state: Arc::new(RwLock::new(HashMap::new())),
55        }
56    }
57
58    /// Start a new workflow.
59    /// Returns immediately with the run_id; workflow executes in the background.
60    pub async fn start<I: serde::Serialize>(
61        &self,
62        workflow_name: &str,
63        input: I,
64        owner_subject: Option<String>,
65    ) -> forge_core::Result<Uuid> {
66        let entry = self.registry.get(workflow_name).ok_or_else(|| {
67            forge_core::ForgeError::NotFound(format!("Workflow '{}' not found", workflow_name))
68        })?;
69
70        let input_value = serde_json::to_value(input)?;
71
72        let record = WorkflowRecord::new(
73            workflow_name,
74            entry.info.version,
75            input_value.clone(),
76            owner_subject,
77        );
78        let run_id = record.id;
79
80        // Clone entry data for background execution
81        let entry_info = entry.info.clone();
82        let entry_handler = entry.handler.clone();
83
84        // Persist workflow record
85        self.save_workflow(&record).await?;
86
87        // Execute workflow in background
88        let registry = self.registry.clone();
89        let pool = self.pool.clone();
90        let http_client = self.http_client.clone();
91        let compensation_state = self.compensation_state.clone();
92
93        tokio::spawn(async move {
94            let executor = WorkflowExecutor {
95                registry,
96                pool,
97                http_client,
98                compensation_state,
99            };
100            let entry = super::registry::WorkflowEntry {
101                info: entry_info,
102                handler: entry_handler,
103            };
104            if let Err(e) = executor.execute_workflow(run_id, &entry, input_value).await {
105                tracing::error!(
106                    workflow_run_id = %run_id,
107                    error = %e,
108                    "Workflow execution failed"
109                );
110            }
111        });
112
113        Ok(run_id)
114    }
115
116    /// Execute a workflow.
117    async fn execute_workflow(
118        &self,
119        run_id: Uuid,
120        entry: &super::registry::WorkflowEntry,
121        input: serde_json::Value,
122    ) -> forge_core::Result<WorkflowResult> {
123        // Update status to running
124        self.update_workflow_status(run_id, WorkflowStatus::Running)
125            .await?;
126
127        // Create workflow context
128        let mut ctx = WorkflowContext::new(
129            run_id,
130            entry.info.name.to_string(),
131            entry.info.version,
132            self.pool.clone(),
133            self.http_client.clone(),
134        );
135        ctx.set_http_timeout(entry.info.http_timeout);
136
137        // Execute workflow with timeout
138        let handler = entry.handler.clone();
139        let result = tokio::time::timeout(entry.info.timeout, handler(&ctx, input)).await;
140
141        // Capture compensation state after execution
142        let compensation_state = CompensationState {
143            handlers: ctx.compensation_handlers(),
144            completed_steps: ctx.completed_steps_reversed().into_iter().rev().collect(),
145        };
146        self.compensation_state
147            .write()
148            .await
149            .insert(run_id, compensation_state);
150
151        match result {
152            Ok(Ok(output)) => {
153                // Mark as completed, clean up compensation state
154                self.complete_workflow(run_id, output.clone()).await?;
155                self.compensation_state.write().await.remove(&run_id);
156                Ok(WorkflowResult::Completed(output))
157            }
158            Ok(Err(e)) => {
159                // Check if this is a suspension (not a real failure)
160                if matches!(e, forge_core::ForgeError::WorkflowSuspended) {
161                    // Workflow suspended itself (sleep or wait_for_event)
162                    // Status already set to 'waiting' by ctx.sleep() - don't mark as failed
163                    return Ok(WorkflowResult::Waiting {
164                        event_type: "timer".to_string(),
165                    });
166                }
167                // Mark as failed - compensation can be triggered via cancel
168                self.fail_workflow(run_id, &e.to_string()).await?;
169                Ok(WorkflowResult::Failed {
170                    error: e.to_string(),
171                })
172            }
173            Err(_) => {
174                // Timeout
175                self.fail_workflow(run_id, "Workflow timed out").await?;
176                Ok(WorkflowResult::Failed {
177                    error: "Workflow timed out".to_string(),
178                })
179            }
180        }
181    }
182
183    /// Execute a resumed workflow with step states loaded from the database.
184    async fn execute_workflow_resumed(
185        &self,
186        run_id: Uuid,
187        entry: &super::registry::WorkflowEntry,
188        input: serde_json::Value,
189        started_at: chrono::DateTime<chrono::Utc>,
190        from_sleep: bool,
191    ) -> forge_core::Result<WorkflowResult> {
192        // Update status to running
193        self.update_workflow_status(run_id, WorkflowStatus::Running)
194            .await?;
195
196        // Load step states from database
197        let step_records = self.get_workflow_steps(run_id).await?;
198        let mut step_states = std::collections::HashMap::new();
199        for step in step_records {
200            let status = step.status;
201            step_states.insert(
202                step.step_name.clone(),
203                forge_core::workflow::StepState {
204                    name: step.step_name,
205                    status,
206                    result: step.result,
207                    error: step.error,
208                    started_at: step.started_at,
209                    completed_at: step.completed_at,
210                },
211            );
212        }
213
214        // Create resumed workflow context with step states
215        let mut ctx = WorkflowContext::resumed(
216            run_id,
217            entry.info.name.to_string(),
218            entry.info.version,
219            started_at,
220            self.pool.clone(),
221            self.http_client.clone(),
222        )
223        .with_step_states(step_states);
224        ctx.set_http_timeout(entry.info.http_timeout);
225
226        // If resuming from a sleep timer, mark the context so sleep() returns immediately
227        if from_sleep {
228            ctx = ctx.with_resumed_from_sleep();
229        }
230
231        // Execute workflow with timeout
232        let handler = entry.handler.clone();
233        let result = tokio::time::timeout(entry.info.timeout, handler(&ctx, input)).await;
234
235        // Capture compensation state after execution
236        let compensation_state = CompensationState {
237            handlers: ctx.compensation_handlers(),
238            completed_steps: ctx.completed_steps_reversed().into_iter().rev().collect(),
239        };
240        self.compensation_state
241            .write()
242            .await
243            .insert(run_id, compensation_state);
244
245        match result {
246            Ok(Ok(output)) => {
247                // Mark as completed, clean up compensation state
248                self.complete_workflow(run_id, output.clone()).await?;
249                self.compensation_state.write().await.remove(&run_id);
250                Ok(WorkflowResult::Completed(output))
251            }
252            Ok(Err(e)) => {
253                // Check if this is a suspension (not a real failure)
254                if matches!(e, forge_core::ForgeError::WorkflowSuspended) {
255                    // Workflow suspended itself (sleep or wait_for_event)
256                    // Status already set to 'waiting' by ctx.sleep() - don't mark as failed
257                    return Ok(WorkflowResult::Waiting {
258                        event_type: "timer".to_string(),
259                    });
260                }
261                // Mark as failed - compensation can be triggered via cancel
262                self.fail_workflow(run_id, &e.to_string()).await?;
263                Ok(WorkflowResult::Failed {
264                    error: e.to_string(),
265                })
266            }
267            Err(_) => {
268                // Timeout
269                self.fail_workflow(run_id, "Workflow timed out").await?;
270                Ok(WorkflowResult::Failed {
271                    error: "Workflow timed out".to_string(),
272                })
273            }
274        }
275    }
276
277    /// Resume a workflow from where it left off.
278    pub async fn resume(&self, run_id: Uuid) -> forge_core::Result<WorkflowResult> {
279        self.resume_internal(run_id, false).await
280    }
281
282    /// Resume a workflow after a sleep timer expired.
283    pub async fn resume_from_sleep(&self, run_id: Uuid) -> forge_core::Result<WorkflowResult> {
284        self.resume_internal(run_id, true).await
285    }
286
287    /// Internal resume logic.
288    async fn resume_internal(
289        &self,
290        run_id: Uuid,
291        from_sleep: bool,
292    ) -> forge_core::Result<WorkflowResult> {
293        let record = self.get_workflow(run_id).await?;
294
295        let entry = self
296            .registry
297            .get_version(&record.workflow_name, record.version)
298            .ok_or_else(|| {
299                forge_core::ForgeError::NotFound(format!(
300                    "Workflow '{}' version {} not found",
301                    record.workflow_name, record.version
302                ))
303            })?;
304
305        // Check if workflow is resumable
306        match record.status {
307            WorkflowStatus::Running | WorkflowStatus::Waiting => {
308                // Can resume
309            }
310            status if status.is_terminal() => {
311                return Err(forge_core::ForgeError::Validation(format!(
312                    "Cannot resume workflow in {} state",
313                    status.as_str()
314                )));
315            }
316            _ => {}
317        }
318
319        self.execute_workflow_resumed(run_id, entry, record.input, record.started_at, from_sleep)
320            .await
321    }
322
323    /// Get workflow status.
324    pub async fn status(&self, run_id: Uuid) -> forge_core::Result<WorkflowRecord> {
325        self.get_workflow(run_id).await
326    }
327
328    /// Cancel a workflow and run compensation.
329    ///
330    /// Compensation follows the saga pattern: steps are undone in reverse order
331    /// of their completion. This ensures that dependencies are respected. For
332    /// example, if step A created a resource that step B modified, we must
333    /// undo B's modification before deleting A's resource.
334    ///
335    /// Compensation handlers receive the original step result, allowing them
336    /// to know exactly what to undo (e.g., refund the specific payment ID).
337    pub async fn cancel(&self, run_id: Uuid) -> forge_core::Result<()> {
338        self.update_workflow_status(run_id, WorkflowStatus::Compensating)
339            .await?;
340
341        // Get compensation state
342        let state = self.compensation_state.write().await.remove(&run_id);
343
344        if let Some(state) = state {
345            // Get completed steps with results from database
346            let steps = self.get_workflow_steps(run_id).await?;
347
348            // Run compensation in reverse order of completion.
349            // This is critical for maintaining consistency: if step B depends on
350            // step A's output, we must undo B before A. The completed_steps list
351            // preserves insertion order, so reversing gives us the correct undo order.
352            for step_name in state.completed_steps.iter().rev() {
353                if let Some(handler) = state.handlers.get(step_name) {
354                    // Find the step result
355                    let step_result = steps
356                        .iter()
357                        .find(|s| &s.step_name == step_name)
358                        .and_then(|s| s.result.clone())
359                        .unwrap_or(serde_json::Value::Null);
360
361                    // Run compensation handler
362                    match handler(step_result).await {
363                        Ok(()) => {
364                            tracing::info!(
365                                workflow_run_id = %run_id,
366                                step = %step_name,
367                                "Compensation completed"
368                            );
369                            self.update_step_status(run_id, step_name, StepStatus::Compensated)
370                                .await?;
371                        }
372                        Err(e) => {
373                            tracing::error!(
374                                workflow_run_id = %run_id,
375                                step = %step_name,
376                                error = %e,
377                                "Compensation failed"
378                            );
379                            // Continue with other compensations even if one fails
380                        }
381                    }
382                } else {
383                    // No handler, just mark as compensated
384                    self.update_step_status(run_id, step_name, StepStatus::Compensated)
385                        .await?;
386                }
387            }
388        } else {
389            // Fail closed: never report compensated when handlers are unavailable.
390            let msg =
391                "Compensation handlers unavailable (likely restart); refusing to mark compensated";
392            tracing::error!(workflow_run_id = %run_id, "{msg}");
393            self.fail_workflow(run_id, msg).await?;
394            return Err(forge_core::ForgeError::InvalidState(msg.to_string()));
395        }
396
397        self.update_workflow_status(run_id, WorkflowStatus::Compensated)
398            .await?;
399
400        Ok(())
401    }
402
403    /// Get workflow steps from database.
404    async fn get_workflow_steps(
405        &self,
406        workflow_run_id: Uuid,
407    ) -> forge_core::Result<Vec<WorkflowStepRecord>> {
408        let rows = sqlx::query(
409            r#"
410            SELECT id, workflow_run_id, step_name, status, result, error, started_at, completed_at
411            FROM forge_workflow_steps
412            WHERE workflow_run_id = $1
413            ORDER BY started_at ASC
414            "#,
415        )
416        .bind(workflow_run_id)
417        .fetch_all(&self.pool)
418        .await
419        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
420
421        use sqlx::Row;
422        rows.into_iter()
423            .map(|row| {
424                let status_str = row.get::<String, _>("status");
425                let status = status_str.parse().map_err(|e| {
426                    forge_core::ForgeError::Database(format!(
427                        "Invalid step status '{}': {}",
428                        status_str, e
429                    ))
430                })?;
431                Ok(WorkflowStepRecord {
432                    id: row.get("id"),
433                    workflow_run_id: row.get("workflow_run_id"),
434                    step_name: row.get("step_name"),
435                    status,
436                    result: row.get("result"),
437                    error: row.get("error"),
438                    started_at: row.get("started_at"),
439                    completed_at: row.get("completed_at"),
440                })
441            })
442            .collect()
443    }
444
445    /// Update step status.
446    async fn update_step_status(
447        &self,
448        workflow_run_id: Uuid,
449        step_name: &str,
450        status: StepStatus,
451    ) -> forge_core::Result<()> {
452        sqlx::query(
453            r#"
454            UPDATE forge_workflow_steps
455            SET status = $3
456            WHERE workflow_run_id = $1 AND step_name = $2
457            "#,
458        )
459        .bind(workflow_run_id)
460        .bind(step_name)
461        .bind(status.as_str())
462        .execute(&self.pool)
463        .await
464        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
465
466        Ok(())
467    }
468
469    /// Save workflow record to database.
470    async fn save_workflow(&self, record: &WorkflowRecord) -> forge_core::Result<()> {
471        sqlx::query(
472            r#"
473            INSERT INTO forge_workflow_runs (
474                id, workflow_name, version, owner_subject, input, status, current_step,
475                step_results, started_at, trace_id
476            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
477            "#,
478        )
479        .bind(record.id)
480        .bind(&record.workflow_name)
481        .bind(record.version as i32)
482        .bind(&record.owner_subject)
483        .bind(&record.input)
484        .bind(record.status.as_str())
485        .bind(&record.current_step)
486        .bind(&record.step_results)
487        .bind(record.started_at)
488        .bind(&record.trace_id)
489        .execute(&self.pool)
490        .await
491        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
492
493        Ok(())
494    }
495
496    /// Get workflow record from database.
497    async fn get_workflow(&self, run_id: Uuid) -> forge_core::Result<WorkflowRecord> {
498        let row = sqlx::query(
499            r#"
500            SELECT id, workflow_name, version, owner_subject, input, output, status, current_step,
501                   step_results, started_at, completed_at, error, trace_id
502            FROM forge_workflow_runs
503            WHERE id = $1
504            "#,
505        )
506        .bind(run_id)
507        .fetch_optional(&self.pool)
508        .await
509        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
510
511        let row = row.ok_or_else(|| {
512            forge_core::ForgeError::NotFound(format!("Workflow run {} not found", run_id))
513        })?;
514
515        use sqlx::Row;
516        let status_str = row.get::<String, _>("status");
517        let status = status_str.parse().map_err(|e| {
518            forge_core::ForgeError::Database(format!(
519                "Invalid workflow status '{}': {}",
520                status_str, e
521            ))
522        })?;
523        Ok(WorkflowRecord {
524            id: row.get("id"),
525            workflow_name: row.get("workflow_name"),
526            version: row.get::<i32, _>("version") as u32,
527            owner_subject: row.get("owner_subject"),
528            input: row.get("input"),
529            output: row.get("output"),
530            status,
531            current_step: row.get("current_step"),
532            step_results: row.get("step_results"),
533            started_at: row.get("started_at"),
534            completed_at: row.get("completed_at"),
535            error: row.get("error"),
536            trace_id: row.get("trace_id"),
537        })
538    }
539
540    /// Update workflow status.
541    async fn update_workflow_status(
542        &self,
543        run_id: Uuid,
544        status: WorkflowStatus,
545    ) -> forge_core::Result<()> {
546        sqlx::query(
547            r#"
548            UPDATE forge_workflow_runs
549            SET status = $2
550            WHERE id = $1
551            "#,
552        )
553        .bind(run_id)
554        .bind(status.as_str())
555        .execute(&self.pool)
556        .await
557        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
558
559        Ok(())
560    }
561
562    /// Mark workflow as completed.
563    async fn complete_workflow(
564        &self,
565        run_id: Uuid,
566        output: serde_json::Value,
567    ) -> forge_core::Result<()> {
568        sqlx::query(
569            r#"
570            UPDATE forge_workflow_runs
571            SET status = 'completed', output = $2, completed_at = NOW()
572            WHERE id = $1
573            "#,
574        )
575        .bind(run_id)
576        .bind(output)
577        .execute(&self.pool)
578        .await
579        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
580
581        Ok(())
582    }
583
584    /// Mark workflow as failed.
585    async fn fail_workflow(&self, run_id: Uuid, error: &str) -> forge_core::Result<()> {
586        sqlx::query(
587            r#"
588            UPDATE forge_workflow_runs
589            SET status = 'failed', error = $2, completed_at = NOW()
590            WHERE id = $1
591            "#,
592        )
593        .bind(run_id)
594        .bind(error)
595        .execute(&self.pool)
596        .await
597        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
598
599        Ok(())
600    }
601
602    /// Save step record.
603    pub async fn save_step(&self, step: &WorkflowStepRecord) -> forge_core::Result<()> {
604        sqlx::query(
605            r#"
606            INSERT INTO forge_workflow_steps (
607                id, workflow_run_id, step_name, status, result, error, started_at, completed_at
608            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
609            ON CONFLICT (workflow_run_id, step_name) DO UPDATE SET
610                status = EXCLUDED.status,
611                result = EXCLUDED.result,
612                error = EXCLUDED.error,
613                started_at = COALESCE(forge_workflow_steps.started_at, EXCLUDED.started_at),
614                completed_at = EXCLUDED.completed_at
615            "#,
616        )
617        .bind(step.id)
618        .bind(step.workflow_run_id)
619        .bind(&step.step_name)
620        .bind(step.status.as_str())
621        .bind(&step.result)
622        .bind(&step.error)
623        .bind(step.started_at)
624        .bind(step.completed_at)
625        .execute(&self.pool)
626        .await
627        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
628
629        Ok(())
630    }
631
632    /// Start a workflow by its registered name with JSON input.
633    pub async fn start_by_name(
634        &self,
635        workflow_name: &str,
636        input: serde_json::Value,
637        owner_subject: Option<String>,
638    ) -> forge_core::Result<Uuid> {
639        self.start(workflow_name, input, owner_subject).await
640    }
641}
642
643impl WorkflowDispatch for WorkflowExecutor {
644    fn get_info(&self, workflow_name: &str) -> Option<forge_core::workflow::WorkflowInfo> {
645        self.registry.get(workflow_name).map(|e| e.info.clone())
646    }
647
648    fn start_by_name(
649        &self,
650        workflow_name: &str,
651        input: serde_json::Value,
652        owner_subject: Option<String>,
653    ) -> Pin<Box<dyn Future<Output = forge_core::Result<Uuid>> + Send + '_>> {
654        let workflow_name = workflow_name.to_string();
655        Box::pin(async move {
656            self.start_by_name(&workflow_name, input, owner_subject)
657                .await
658        })
659    }
660}
661
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// Every `WorkflowResult` variant can be constructed, and a `Completed`
    /// value matches its own variant.
    #[test]
    fn test_workflow_result_types() {
        let completed = WorkflowResult::Completed(serde_json::json!({}));
        let _compensated = WorkflowResult::Compensated;
        let _failed = WorkflowResult::Failed {
            error: "test".to_string(),
        };
        let _waiting = WorkflowResult::Waiting {
            event_type: "approval".to_string(),
        };

        assert!(
            matches!(completed, WorkflowResult::Completed(_)),
            "Expected Completed"
        );
    }
}