Skip to main content

forge_runtime/workflow/
executor.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use tokio::sync::RwLock;
7use uuid::Uuid;
8
9use super::registry::WorkflowRegistry;
10use super::state::{WorkflowRecord, WorkflowStepRecord};
11use forge_core::CircuitBreakerClient;
12use forge_core::function::WorkflowDispatch;
13use forge_core::workflow::{CompensationHandler, StepStatus, WorkflowContext, WorkflowStatus};
14
15/// Workflow execution result.
16#[derive(Debug)]
17pub enum WorkflowResult {
18    /// Workflow completed successfully.
19    Completed(serde_json::Value),
20    /// Workflow is waiting for an external event.
21    Waiting { event_type: String },
22    /// Workflow failed.
23    Failed { error: String },
24    /// Workflow was compensated.
25    Compensated,
26}
27
28/// Compensation state for a running workflow.
29struct CompensationState {
30    handlers: HashMap<String, CompensationHandler>,
31    completed_steps: Vec<String>,
32}
33
34/// Executes workflows.
35pub struct WorkflowExecutor {
36    registry: Arc<WorkflowRegistry>,
37    pool: sqlx::PgPool,
38    http_client: CircuitBreakerClient,
39    /// Compensation state for active workflows (run_id -> state).
40    compensation_state: Arc<RwLock<HashMap<Uuid, CompensationState>>>,
41}
42
43impl WorkflowExecutor {
44    /// Create a new workflow executor.
45    pub fn new(
46        registry: Arc<WorkflowRegistry>,
47        pool: sqlx::PgPool,
48        http_client: CircuitBreakerClient,
49    ) -> Self {
50        Self {
51            registry,
52            pool,
53            http_client,
54            compensation_state: Arc::new(RwLock::new(HashMap::new())),
55        }
56    }
57
58    /// Start a new workflow.
59    /// Returns immediately with the run_id; workflow executes in the background.
60    pub async fn start<I: serde::Serialize>(
61        &self,
62        workflow_name: &str,
63        input: I,
64    ) -> forge_core::Result<Uuid> {
65        let entry = self.registry.get(workflow_name).ok_or_else(|| {
66            forge_core::ForgeError::NotFound(format!("Workflow '{}' not found", workflow_name))
67        })?;
68
69        let input_value = serde_json::to_value(input)?;
70
71        let record = WorkflowRecord::new(workflow_name, entry.info.version, input_value.clone());
72        let run_id = record.id;
73
74        // Clone entry data for background execution
75        let entry_info = entry.info.clone();
76        let entry_handler = entry.handler.clone();
77
78        // Persist workflow record
79        self.save_workflow(&record).await?;
80
81        // Execute workflow in background
82        let registry = self.registry.clone();
83        let pool = self.pool.clone();
84        let http_client = self.http_client.clone();
85        let compensation_state = self.compensation_state.clone();
86
87        tokio::spawn(async move {
88            let executor = WorkflowExecutor {
89                registry,
90                pool,
91                http_client,
92                compensation_state,
93            };
94            let entry = super::registry::WorkflowEntry {
95                info: entry_info,
96                handler: entry_handler,
97            };
98            if let Err(e) = executor.execute_workflow(run_id, &entry, input_value).await {
99                tracing::error!(
100                    workflow_run_id = %run_id,
101                    error = %e,
102                    "Workflow execution failed"
103                );
104            }
105        });
106
107        Ok(run_id)
108    }
109
110    /// Execute a workflow.
111    async fn execute_workflow(
112        &self,
113        run_id: Uuid,
114        entry: &super::registry::WorkflowEntry,
115        input: serde_json::Value,
116    ) -> forge_core::Result<WorkflowResult> {
117        // Update status to running
118        self.update_workflow_status(run_id, WorkflowStatus::Running)
119            .await?;
120
121        // Create workflow context
122        let ctx = WorkflowContext::new(
123            run_id,
124            entry.info.name.to_string(),
125            entry.info.version,
126            self.pool.clone(),
127            self.http_client.inner().clone(),
128        );
129
130        // Execute workflow with timeout
131        let handler = entry.handler.clone();
132        let result = tokio::time::timeout(entry.info.timeout, handler(&ctx, input)).await;
133
134        // Capture compensation state after execution
135        let compensation_state = CompensationState {
136            handlers: ctx.compensation_handlers(),
137            completed_steps: ctx.completed_steps_reversed().into_iter().rev().collect(),
138        };
139        self.compensation_state
140            .write()
141            .await
142            .insert(run_id, compensation_state);
143
144        match result {
145            Ok(Ok(output)) => {
146                // Mark as completed, clean up compensation state
147                self.complete_workflow(run_id, output.clone()).await?;
148                self.compensation_state.write().await.remove(&run_id);
149                Ok(WorkflowResult::Completed(output))
150            }
151            Ok(Err(e)) => {
152                // Check if this is a suspension (not a real failure)
153                if matches!(e, forge_core::ForgeError::WorkflowSuspended) {
154                    // Workflow suspended itself (sleep or wait_for_event)
155                    // Status already set to 'waiting' by ctx.sleep() - don't mark as failed
156                    return Ok(WorkflowResult::Waiting {
157                        event_type: "timer".to_string(),
158                    });
159                }
160                // Mark as failed - compensation can be triggered via cancel
161                self.fail_workflow(run_id, &e.to_string()).await?;
162                Ok(WorkflowResult::Failed {
163                    error: e.to_string(),
164                })
165            }
166            Err(_) => {
167                // Timeout
168                self.fail_workflow(run_id, "Workflow timed out").await?;
169                Ok(WorkflowResult::Failed {
170                    error: "Workflow timed out".to_string(),
171                })
172            }
173        }
174    }
175
176    /// Execute a resumed workflow with step states loaded from the database.
177    async fn execute_workflow_resumed(
178        &self,
179        run_id: Uuid,
180        entry: &super::registry::WorkflowEntry,
181        input: serde_json::Value,
182        started_at: chrono::DateTime<chrono::Utc>,
183        from_sleep: bool,
184    ) -> forge_core::Result<WorkflowResult> {
185        // Update status to running
186        self.update_workflow_status(run_id, WorkflowStatus::Running)
187            .await?;
188
189        // Load step states from database
190        let step_records = self.get_workflow_steps(run_id).await?;
191        let mut step_states = std::collections::HashMap::new();
192        for step in step_records {
193            let status = step.status;
194            step_states.insert(
195                step.step_name.clone(),
196                forge_core::workflow::StepState {
197                    name: step.step_name,
198                    status,
199                    result: step.result,
200                    error: step.error,
201                    started_at: step.started_at,
202                    completed_at: step.completed_at,
203                },
204            );
205        }
206
207        // Create resumed workflow context with step states
208        let mut ctx = WorkflowContext::resumed(
209            run_id,
210            entry.info.name.to_string(),
211            entry.info.version,
212            started_at,
213            self.pool.clone(),
214            self.http_client.inner().clone(),
215        )
216        .with_step_states(step_states);
217
218        // If resuming from a sleep timer, mark the context so sleep() returns immediately
219        if from_sleep {
220            ctx = ctx.with_resumed_from_sleep();
221        }
222
223        // Execute workflow with timeout
224        let handler = entry.handler.clone();
225        let result = tokio::time::timeout(entry.info.timeout, handler(&ctx, input)).await;
226
227        // Capture compensation state after execution
228        let compensation_state = CompensationState {
229            handlers: ctx.compensation_handlers(),
230            completed_steps: ctx.completed_steps_reversed().into_iter().rev().collect(),
231        };
232        self.compensation_state
233            .write()
234            .await
235            .insert(run_id, compensation_state);
236
237        match result {
238            Ok(Ok(output)) => {
239                // Mark as completed, clean up compensation state
240                self.complete_workflow(run_id, output.clone()).await?;
241                self.compensation_state.write().await.remove(&run_id);
242                Ok(WorkflowResult::Completed(output))
243            }
244            Ok(Err(e)) => {
245                // Check if this is a suspension (not a real failure)
246                if matches!(e, forge_core::ForgeError::WorkflowSuspended) {
247                    // Workflow suspended itself (sleep or wait_for_event)
248                    // Status already set to 'waiting' by ctx.sleep() - don't mark as failed
249                    return Ok(WorkflowResult::Waiting {
250                        event_type: "timer".to_string(),
251                    });
252                }
253                // Mark as failed - compensation can be triggered via cancel
254                self.fail_workflow(run_id, &e.to_string()).await?;
255                Ok(WorkflowResult::Failed {
256                    error: e.to_string(),
257                })
258            }
259            Err(_) => {
260                // Timeout
261                self.fail_workflow(run_id, "Workflow timed out").await?;
262                Ok(WorkflowResult::Failed {
263                    error: "Workflow timed out".to_string(),
264                })
265            }
266        }
267    }
268
269    /// Resume a workflow from where it left off.
270    pub async fn resume(&self, run_id: Uuid) -> forge_core::Result<WorkflowResult> {
271        self.resume_internal(run_id, false).await
272    }
273
274    /// Resume a workflow after a sleep timer expired.
275    pub async fn resume_from_sleep(&self, run_id: Uuid) -> forge_core::Result<WorkflowResult> {
276        self.resume_internal(run_id, true).await
277    }
278
279    /// Internal resume logic.
280    async fn resume_internal(
281        &self,
282        run_id: Uuid,
283        from_sleep: bool,
284    ) -> forge_core::Result<WorkflowResult> {
285        let record = self.get_workflow(run_id).await?;
286
287        let entry = self.registry.get(&record.workflow_name).ok_or_else(|| {
288            forge_core::ForgeError::NotFound(format!(
289                "Workflow '{}' not found",
290                record.workflow_name
291            ))
292        })?;
293
294        // Check if workflow is resumable
295        match record.status {
296            WorkflowStatus::Running | WorkflowStatus::Waiting => {
297                // Can resume
298            }
299            status if status.is_terminal() => {
300                return Err(forge_core::ForgeError::Validation(format!(
301                    "Cannot resume workflow in {} state",
302                    status.as_str()
303                )));
304            }
305            _ => {}
306        }
307
308        self.execute_workflow_resumed(run_id, entry, record.input, record.started_at, from_sleep)
309            .await
310    }
311
312    /// Get workflow status.
313    pub async fn status(&self, run_id: Uuid) -> forge_core::Result<WorkflowRecord> {
314        self.get_workflow(run_id).await
315    }
316
317    /// Cancel a workflow and run compensation.
318    pub async fn cancel(&self, run_id: Uuid) -> forge_core::Result<()> {
319        self.update_workflow_status(run_id, WorkflowStatus::Compensating)
320            .await?;
321
322        // Get compensation state
323        let state = self.compensation_state.write().await.remove(&run_id);
324
325        if let Some(state) = state {
326            // Get completed steps with results from database
327            let steps = self.get_workflow_steps(run_id).await?;
328
329            // Run compensation in reverse order
330            for step_name in state.completed_steps.iter().rev() {
331                if let Some(handler) = state.handlers.get(step_name) {
332                    // Find the step result
333                    let step_result = steps
334                        .iter()
335                        .find(|s| &s.step_name == step_name)
336                        .and_then(|s| s.result.clone())
337                        .unwrap_or(serde_json::Value::Null);
338
339                    // Run compensation handler
340                    match handler(step_result).await {
341                        Ok(()) => {
342                            tracing::info!(
343                                workflow_run_id = %run_id,
344                                step = %step_name,
345                                "Compensation completed"
346                            );
347                            self.update_step_status(run_id, step_name, StepStatus::Compensated)
348                                .await?;
349                        }
350                        Err(e) => {
351                            tracing::error!(
352                                workflow_run_id = %run_id,
353                                step = %step_name,
354                                error = %e,
355                                "Compensation failed"
356                            );
357                            // Continue with other compensations even if one fails
358                        }
359                    }
360                } else {
361                    // No handler, just mark as compensated
362                    self.update_step_status(run_id, step_name, StepStatus::Compensated)
363                        .await?;
364                }
365            }
366        } else {
367            // No in-memory state, try to compensate from DB state
368            // This handles the case where the server restarted
369            tracing::warn!(
370                workflow_run_id = %run_id,
371                "No compensation state found, marking as compensated without handlers"
372            );
373        }
374
375        self.update_workflow_status(run_id, WorkflowStatus::Compensated)
376            .await?;
377
378        Ok(())
379    }
380
381    /// Get workflow steps from database.
382    async fn get_workflow_steps(
383        &self,
384        workflow_run_id: Uuid,
385    ) -> forge_core::Result<Vec<WorkflowStepRecord>> {
386        let rows = sqlx::query(
387            r#"
388            SELECT id, workflow_run_id, step_name, status, result, error, started_at, completed_at
389            FROM forge_workflow_steps
390            WHERE workflow_run_id = $1
391            ORDER BY started_at ASC
392            "#,
393        )
394        .bind(workflow_run_id)
395        .fetch_all(&self.pool)
396        .await
397        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
398
399        use sqlx::Row;
400        Ok(rows
401            .into_iter()
402            .map(|row| WorkflowStepRecord {
403                id: row.get("id"),
404                workflow_run_id: row.get("workflow_run_id"),
405                step_name: row.get("step_name"),
406                status: row.get::<String, _>("status").parse().unwrap(),
407                result: row.get("result"),
408                error: row.get("error"),
409                started_at: row.get("started_at"),
410                completed_at: row.get("completed_at"),
411            })
412            .collect())
413    }
414
415    /// Update step status.
416    async fn update_step_status(
417        &self,
418        workflow_run_id: Uuid,
419        step_name: &str,
420        status: StepStatus,
421    ) -> forge_core::Result<()> {
422        sqlx::query(
423            r#"
424            UPDATE forge_workflow_steps
425            SET status = $3
426            WHERE workflow_run_id = $1 AND step_name = $2
427            "#,
428        )
429        .bind(workflow_run_id)
430        .bind(step_name)
431        .bind(status.as_str())
432        .execute(&self.pool)
433        .await
434        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
435
436        Ok(())
437    }
438
439    /// Save workflow record to database.
440    async fn save_workflow(&self, record: &WorkflowRecord) -> forge_core::Result<()> {
441        sqlx::query(
442            r#"
443            INSERT INTO forge_workflow_runs (
444                id, workflow_name, input, status, current_step,
445                step_results, started_at, trace_id
446            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
447            "#,
448        )
449        .bind(record.id)
450        .bind(&record.workflow_name)
451        .bind(&record.input)
452        .bind(record.status.as_str())
453        .bind(&record.current_step)
454        .bind(&record.step_results)
455        .bind(record.started_at)
456        .bind(&record.trace_id)
457        .execute(&self.pool)
458        .await
459        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
460
461        Ok(())
462    }
463
464    /// Get workflow record from database.
465    async fn get_workflow(&self, run_id: Uuid) -> forge_core::Result<WorkflowRecord> {
466        let row = sqlx::query(
467            r#"
468            SELECT id, workflow_name, input, output, status, current_step,
469                   step_results, started_at, completed_at, error, trace_id
470            FROM forge_workflow_runs
471            WHERE id = $1
472            "#,
473        )
474        .bind(run_id)
475        .fetch_optional(&self.pool)
476        .await
477        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
478
479        let row = row.ok_or_else(|| {
480            forge_core::ForgeError::NotFound(format!("Workflow run {} not found", run_id))
481        })?;
482
483        use sqlx::Row;
484        Ok(WorkflowRecord {
485            id: row.get("id"),
486            workflow_name: row.get("workflow_name"),
487            version: 1, // TODO: Add version column
488            input: row.get("input"),
489            output: row.get("output"),
490            status: row.get::<String, _>("status").parse().unwrap(),
491            current_step: row.get("current_step"),
492            step_results: row.get("step_results"),
493            started_at: row.get("started_at"),
494            completed_at: row.get("completed_at"),
495            error: row.get("error"),
496            trace_id: row.get("trace_id"),
497        })
498    }
499
500    /// Update workflow status.
501    async fn update_workflow_status(
502        &self,
503        run_id: Uuid,
504        status: WorkflowStatus,
505    ) -> forge_core::Result<()> {
506        sqlx::query(
507            r#"
508            UPDATE forge_workflow_runs
509            SET status = $2
510            WHERE id = $1
511            "#,
512        )
513        .bind(run_id)
514        .bind(status.as_str())
515        .execute(&self.pool)
516        .await
517        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
518
519        Ok(())
520    }
521
522    /// Mark workflow as completed.
523    async fn complete_workflow(
524        &self,
525        run_id: Uuid,
526        output: serde_json::Value,
527    ) -> forge_core::Result<()> {
528        sqlx::query(
529            r#"
530            UPDATE forge_workflow_runs
531            SET status = 'completed', output = $2, completed_at = NOW()
532            WHERE id = $1
533            "#,
534        )
535        .bind(run_id)
536        .bind(output)
537        .execute(&self.pool)
538        .await
539        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
540
541        Ok(())
542    }
543
544    /// Mark workflow as failed.
545    async fn fail_workflow(&self, run_id: Uuid, error: &str) -> forge_core::Result<()> {
546        sqlx::query(
547            r#"
548            UPDATE forge_workflow_runs
549            SET status = 'failed', error = $2, completed_at = NOW()
550            WHERE id = $1
551            "#,
552        )
553        .bind(run_id)
554        .bind(error)
555        .execute(&self.pool)
556        .await
557        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
558
559        Ok(())
560    }
561
562    /// Save step record.
563    pub async fn save_step(&self, step: &WorkflowStepRecord) -> forge_core::Result<()> {
564        sqlx::query(
565            r#"
566            INSERT INTO forge_workflow_steps (
567                id, workflow_run_id, step_name, status, result, error, started_at, completed_at
568            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
569            ON CONFLICT (workflow_run_id, step_name) DO UPDATE SET
570                status = EXCLUDED.status,
571                result = EXCLUDED.result,
572                error = EXCLUDED.error,
573                started_at = COALESCE(forge_workflow_steps.started_at, EXCLUDED.started_at),
574                completed_at = EXCLUDED.completed_at
575            "#,
576        )
577        .bind(step.id)
578        .bind(step.workflow_run_id)
579        .bind(&step.step_name)
580        .bind(step.status.as_str())
581        .bind(&step.result)
582        .bind(&step.error)
583        .bind(step.started_at)
584        .bind(step.completed_at)
585        .execute(&self.pool)
586        .await
587        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
588
589        Ok(())
590    }
591
592    /// Start a workflow by its registered name with JSON input.
593    pub async fn start_by_name(
594        &self,
595        workflow_name: &str,
596        input: serde_json::Value,
597    ) -> forge_core::Result<Uuid> {
598        self.start(workflow_name, input).await
599    }
600}
601
602impl WorkflowDispatch for WorkflowExecutor {
603    fn get_info(&self, workflow_name: &str) -> Option<forge_core::workflow::WorkflowInfo> {
604        self.registry.get(workflow_name).map(|e| e.info.clone())
605    }
606
607    fn start_by_name(
608        &self,
609        workflow_name: &str,
610        input: serde_json::Value,
611    ) -> Pin<Box<dyn Future<Output = forge_core::Result<Uuid>> + Send + '_>> {
612        let workflow_name = workflow_name.to_string();
613        Box::pin(async move { self.start_by_name(&workflow_name, input).await })
614    }
615}
616
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: every `WorkflowResult` variant can be constructed, and the
    /// `Completed` variant matches as expected.
    #[test]
    fn test_workflow_result_types() {
        let completed = WorkflowResult::Completed(serde_json::json!({}));
        let _waiting = WorkflowResult::Waiting {
            event_type: "approval".to_string(),
        };
        let _failed = WorkflowResult::Failed {
            error: "test".to_string(),
        };
        let _compensated = WorkflowResult::Compensated;

        assert!(
            matches!(completed, WorkflowResult::Completed(_)),
            "Expected Completed"
        );
    }
}