// forge_runtime/workflow/executor.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use tokio::sync::RwLock;
7use uuid::Uuid;
8
9use super::registry::WorkflowRegistry;
10use super::state::{WorkflowRecord, WorkflowStepRecord};
11use forge_core::function::WorkflowDispatch;
12use forge_core::workflow::{CompensationHandler, StepStatus, WorkflowContext, WorkflowStatus};
13
14/// Workflow execution result.
15#[derive(Debug)]
16pub enum WorkflowResult {
17    /// Workflow completed successfully.
18    Completed(serde_json::Value),
19    /// Workflow is waiting for an external event.
20    Waiting { event_type: String },
21    /// Workflow failed.
22    Failed { error: String },
23    /// Workflow was compensated.
24    Compensated,
25}
26
27/// Compensation state for a running workflow.
28struct CompensationState {
29    handlers: HashMap<String, CompensationHandler>,
30    completed_steps: Vec<String>,
31}
32
33/// Executes workflows.
34pub struct WorkflowExecutor {
35    registry: Arc<WorkflowRegistry>,
36    pool: sqlx::PgPool,
37    http_client: reqwest::Client,
38    /// Compensation state for active workflows (run_id -> state).
39    compensation_state: Arc<RwLock<HashMap<Uuid, CompensationState>>>,
40}
41
42impl WorkflowExecutor {
43    /// Create a new workflow executor.
44    pub fn new(
45        registry: Arc<WorkflowRegistry>,
46        pool: sqlx::PgPool,
47        http_client: reqwest::Client,
48    ) -> Self {
49        Self {
50            registry,
51            pool,
52            http_client,
53            compensation_state: Arc::new(RwLock::new(HashMap::new())),
54        }
55    }
56
57    /// Start a new workflow.
58    /// Returns immediately with the run_id; workflow executes in the background.
59    pub async fn start<I: serde::Serialize>(
60        &self,
61        workflow_name: &str,
62        input: I,
63    ) -> forge_core::Result<Uuid> {
64        let entry = self.registry.get(workflow_name).ok_or_else(|| {
65            forge_core::ForgeError::NotFound(format!("Workflow '{}' not found", workflow_name))
66        })?;
67
68        let input_value = serde_json::to_value(input)?;
69
70        let record = WorkflowRecord::new(workflow_name, entry.info.version, input_value.clone());
71        let run_id = record.id;
72
73        // Clone entry data for background execution
74        let entry_info = entry.info.clone();
75        let entry_handler = entry.handler.clone();
76
77        // Persist workflow record
78        self.save_workflow(&record).await?;
79
80        // Execute workflow in background
81        let registry = self.registry.clone();
82        let pool = self.pool.clone();
83        let http_client = self.http_client.clone();
84        let compensation_state = self.compensation_state.clone();
85
86        tokio::spawn(async move {
87            let executor = WorkflowExecutor {
88                registry,
89                pool,
90                http_client,
91                compensation_state,
92            };
93            let entry = super::registry::WorkflowEntry {
94                info: entry_info,
95                handler: entry_handler,
96            };
97            if let Err(e) = executor.execute_workflow(run_id, &entry, input_value).await {
98                tracing::error!(
99                    workflow_run_id = %run_id,
100                    error = %e,
101                    "Workflow execution failed"
102                );
103            }
104        });
105
106        Ok(run_id)
107    }
108
109    /// Execute a workflow.
110    async fn execute_workflow(
111        &self,
112        run_id: Uuid,
113        entry: &super::registry::WorkflowEntry,
114        input: serde_json::Value,
115    ) -> forge_core::Result<WorkflowResult> {
116        // Update status to running
117        self.update_workflow_status(run_id, WorkflowStatus::Running)
118            .await?;
119
120        // Create workflow context
121        let ctx = WorkflowContext::new(
122            run_id,
123            entry.info.name.to_string(),
124            entry.info.version,
125            self.pool.clone(),
126            self.http_client.clone(),
127        );
128
129        // Execute workflow with timeout
130        let handler = entry.handler.clone();
131        let result = tokio::time::timeout(entry.info.timeout, handler(&ctx, input)).await;
132
133        // Capture compensation state after execution
134        let compensation_state = CompensationState {
135            handlers: ctx.compensation_handlers(),
136            completed_steps: ctx.completed_steps_reversed().into_iter().rev().collect(),
137        };
138        self.compensation_state
139            .write()
140            .await
141            .insert(run_id, compensation_state);
142
143        match result {
144            Ok(Ok(output)) => {
145                // Mark as completed, clean up compensation state
146                self.complete_workflow(run_id, output.clone()).await?;
147                self.compensation_state.write().await.remove(&run_id);
148                Ok(WorkflowResult::Completed(output))
149            }
150            Ok(Err(e)) => {
151                // Check if this is a suspension (not a real failure)
152                if matches!(e, forge_core::ForgeError::WorkflowSuspended) {
153                    // Workflow suspended itself (sleep or wait_for_event)
154                    // Status already set to 'waiting' by ctx.sleep() - don't mark as failed
155                    return Ok(WorkflowResult::Waiting {
156                        event_type: "timer".to_string(),
157                    });
158                }
159                // Mark as failed - compensation can be triggered via cancel
160                self.fail_workflow(run_id, &e.to_string()).await?;
161                Ok(WorkflowResult::Failed {
162                    error: e.to_string(),
163                })
164            }
165            Err(_) => {
166                // Timeout
167                self.fail_workflow(run_id, "Workflow timed out").await?;
168                Ok(WorkflowResult::Failed {
169                    error: "Workflow timed out".to_string(),
170                })
171            }
172        }
173    }
174
175    /// Execute a resumed workflow with step states loaded from the database.
176    async fn execute_workflow_resumed(
177        &self,
178        run_id: Uuid,
179        entry: &super::registry::WorkflowEntry,
180        input: serde_json::Value,
181        started_at: chrono::DateTime<chrono::Utc>,
182        from_sleep: bool,
183    ) -> forge_core::Result<WorkflowResult> {
184        // Update status to running
185        self.update_workflow_status(run_id, WorkflowStatus::Running)
186            .await?;
187
188        // Load step states from database
189        let step_records = self.get_workflow_steps(run_id).await?;
190        let mut step_states = std::collections::HashMap::new();
191        for step in step_records {
192            let status = step.status;
193            step_states.insert(
194                step.step_name.clone(),
195                forge_core::workflow::StepState {
196                    name: step.step_name,
197                    status,
198                    result: step.result,
199                    error: step.error,
200                    started_at: step.started_at,
201                    completed_at: step.completed_at,
202                },
203            );
204        }
205
206        // Create resumed workflow context with step states
207        let mut ctx = WorkflowContext::resumed(
208            run_id,
209            entry.info.name.to_string(),
210            entry.info.version,
211            started_at,
212            self.pool.clone(),
213            self.http_client.clone(),
214        )
215        .with_step_states(step_states);
216
217        // If resuming from a sleep timer, mark the context so sleep() returns immediately
218        if from_sleep {
219            ctx = ctx.with_resumed_from_sleep();
220        }
221
222        // Execute workflow with timeout
223        let handler = entry.handler.clone();
224        let result = tokio::time::timeout(entry.info.timeout, handler(&ctx, input)).await;
225
226        // Capture compensation state after execution
227        let compensation_state = CompensationState {
228            handlers: ctx.compensation_handlers(),
229            completed_steps: ctx.completed_steps_reversed().into_iter().rev().collect(),
230        };
231        self.compensation_state
232            .write()
233            .await
234            .insert(run_id, compensation_state);
235
236        match result {
237            Ok(Ok(output)) => {
238                // Mark as completed, clean up compensation state
239                self.complete_workflow(run_id, output.clone()).await?;
240                self.compensation_state.write().await.remove(&run_id);
241                Ok(WorkflowResult::Completed(output))
242            }
243            Ok(Err(e)) => {
244                // Check if this is a suspension (not a real failure)
245                if matches!(e, forge_core::ForgeError::WorkflowSuspended) {
246                    // Workflow suspended itself (sleep or wait_for_event)
247                    // Status already set to 'waiting' by ctx.sleep() - don't mark as failed
248                    return Ok(WorkflowResult::Waiting {
249                        event_type: "timer".to_string(),
250                    });
251                }
252                // Mark as failed - compensation can be triggered via cancel
253                self.fail_workflow(run_id, &e.to_string()).await?;
254                Ok(WorkflowResult::Failed {
255                    error: e.to_string(),
256                })
257            }
258            Err(_) => {
259                // Timeout
260                self.fail_workflow(run_id, "Workflow timed out").await?;
261                Ok(WorkflowResult::Failed {
262                    error: "Workflow timed out".to_string(),
263                })
264            }
265        }
266    }
267
268    /// Resume a workflow from where it left off.
269    pub async fn resume(&self, run_id: Uuid) -> forge_core::Result<WorkflowResult> {
270        self.resume_internal(run_id, false).await
271    }
272
273    /// Resume a workflow after a sleep timer expired.
274    pub async fn resume_from_sleep(&self, run_id: Uuid) -> forge_core::Result<WorkflowResult> {
275        self.resume_internal(run_id, true).await
276    }
277
278    /// Internal resume logic.
279    async fn resume_internal(
280        &self,
281        run_id: Uuid,
282        from_sleep: bool,
283    ) -> forge_core::Result<WorkflowResult> {
284        let record = self.get_workflow(run_id).await?;
285
286        let entry = self.registry.get(&record.workflow_name).ok_or_else(|| {
287            forge_core::ForgeError::NotFound(format!(
288                "Workflow '{}' not found",
289                record.workflow_name
290            ))
291        })?;
292
293        // Check if workflow is resumable
294        match record.status {
295            WorkflowStatus::Running | WorkflowStatus::Waiting => {
296                // Can resume
297            }
298            status if status.is_terminal() => {
299                return Err(forge_core::ForgeError::Validation(format!(
300                    "Cannot resume workflow in {} state",
301                    status.as_str()
302                )));
303            }
304            _ => {}
305        }
306
307        self.execute_workflow_resumed(run_id, entry, record.input, record.started_at, from_sleep)
308            .await
309    }
310
311    /// Get workflow status.
312    pub async fn status(&self, run_id: Uuid) -> forge_core::Result<WorkflowRecord> {
313        self.get_workflow(run_id).await
314    }
315
316    /// Cancel a workflow and run compensation.
317    pub async fn cancel(&self, run_id: Uuid) -> forge_core::Result<()> {
318        self.update_workflow_status(run_id, WorkflowStatus::Compensating)
319            .await?;
320
321        // Get compensation state
322        let state = self.compensation_state.write().await.remove(&run_id);
323
324        if let Some(state) = state {
325            // Get completed steps with results from database
326            let steps = self.get_workflow_steps(run_id).await?;
327
328            // Run compensation in reverse order
329            for step_name in state.completed_steps.iter().rev() {
330                if let Some(handler) = state.handlers.get(step_name) {
331                    // Find the step result
332                    let step_result = steps
333                        .iter()
334                        .find(|s| &s.step_name == step_name)
335                        .and_then(|s| s.result.clone())
336                        .unwrap_or(serde_json::Value::Null);
337
338                    // Run compensation handler
339                    match handler(step_result).await {
340                        Ok(()) => {
341                            tracing::info!(
342                                workflow_run_id = %run_id,
343                                step = %step_name,
344                                "Compensation completed"
345                            );
346                            self.update_step_status(run_id, step_name, StepStatus::Compensated)
347                                .await?;
348                        }
349                        Err(e) => {
350                            tracing::error!(
351                                workflow_run_id = %run_id,
352                                step = %step_name,
353                                error = %e,
354                                "Compensation failed"
355                            );
356                            // Continue with other compensations even if one fails
357                        }
358                    }
359                } else {
360                    // No handler, just mark as compensated
361                    self.update_step_status(run_id, step_name, StepStatus::Compensated)
362                        .await?;
363                }
364            }
365        } else {
366            // No in-memory state, try to compensate from DB state
367            // This handles the case where the server restarted
368            tracing::warn!(
369                workflow_run_id = %run_id,
370                "No compensation state found, marking as compensated without handlers"
371            );
372        }
373
374        self.update_workflow_status(run_id, WorkflowStatus::Compensated)
375            .await?;
376
377        Ok(())
378    }
379
380    /// Get workflow steps from database.
381    async fn get_workflow_steps(
382        &self,
383        workflow_run_id: Uuid,
384    ) -> forge_core::Result<Vec<WorkflowStepRecord>> {
385        let rows = sqlx::query(
386            r#"
387            SELECT id, workflow_run_id, step_name, status, result, error, started_at, completed_at
388            FROM forge_workflow_steps
389            WHERE workflow_run_id = $1
390            ORDER BY started_at ASC
391            "#,
392        )
393        .bind(workflow_run_id)
394        .fetch_all(&self.pool)
395        .await
396        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
397
398        use sqlx::Row;
399        Ok(rows
400            .into_iter()
401            .map(|row| WorkflowStepRecord {
402                id: row.get("id"),
403                workflow_run_id: row.get("workflow_run_id"),
404                step_name: row.get("step_name"),
405                status: row.get::<String, _>("status").parse().unwrap(),
406                result: row.get("result"),
407                error: row.get("error"),
408                started_at: row.get("started_at"),
409                completed_at: row.get("completed_at"),
410            })
411            .collect())
412    }
413
414    /// Update step status.
415    async fn update_step_status(
416        &self,
417        workflow_run_id: Uuid,
418        step_name: &str,
419        status: StepStatus,
420    ) -> forge_core::Result<()> {
421        sqlx::query(
422            r#"
423            UPDATE forge_workflow_steps
424            SET status = $3
425            WHERE workflow_run_id = $1 AND step_name = $2
426            "#,
427        )
428        .bind(workflow_run_id)
429        .bind(step_name)
430        .bind(status.as_str())
431        .execute(&self.pool)
432        .await
433        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
434
435        Ok(())
436    }
437
438    /// Save workflow record to database.
439    async fn save_workflow(&self, record: &WorkflowRecord) -> forge_core::Result<()> {
440        sqlx::query(
441            r#"
442            INSERT INTO forge_workflow_runs (
443                id, workflow_name, input, status, current_step,
444                step_results, started_at, trace_id
445            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
446            "#,
447        )
448        .bind(record.id)
449        .bind(&record.workflow_name)
450        .bind(&record.input)
451        .bind(record.status.as_str())
452        .bind(&record.current_step)
453        .bind(&record.step_results)
454        .bind(record.started_at)
455        .bind(&record.trace_id)
456        .execute(&self.pool)
457        .await
458        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
459
460        Ok(())
461    }
462
463    /// Get workflow record from database.
464    async fn get_workflow(&self, run_id: Uuid) -> forge_core::Result<WorkflowRecord> {
465        let row = sqlx::query(
466            r#"
467            SELECT id, workflow_name, input, output, status, current_step,
468                   step_results, started_at, completed_at, error, trace_id
469            FROM forge_workflow_runs
470            WHERE id = $1
471            "#,
472        )
473        .bind(run_id)
474        .fetch_optional(&self.pool)
475        .await
476        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
477
478        let row = row.ok_or_else(|| {
479            forge_core::ForgeError::NotFound(format!("Workflow run {} not found", run_id))
480        })?;
481
482        use sqlx::Row;
483        Ok(WorkflowRecord {
484            id: row.get("id"),
485            workflow_name: row.get("workflow_name"),
486            version: 1, // TODO: Add version column
487            input: row.get("input"),
488            output: row.get("output"),
489            status: row.get::<String, _>("status").parse().unwrap(),
490            current_step: row.get("current_step"),
491            step_results: row.get("step_results"),
492            started_at: row.get("started_at"),
493            completed_at: row.get("completed_at"),
494            error: row.get("error"),
495            trace_id: row.get("trace_id"),
496        })
497    }
498
499    /// Update workflow status.
500    async fn update_workflow_status(
501        &self,
502        run_id: Uuid,
503        status: WorkflowStatus,
504    ) -> forge_core::Result<()> {
505        sqlx::query(
506            r#"
507            UPDATE forge_workflow_runs
508            SET status = $2
509            WHERE id = $1
510            "#,
511        )
512        .bind(run_id)
513        .bind(status.as_str())
514        .execute(&self.pool)
515        .await
516        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
517
518        Ok(())
519    }
520
521    /// Mark workflow as completed.
522    async fn complete_workflow(
523        &self,
524        run_id: Uuid,
525        output: serde_json::Value,
526    ) -> forge_core::Result<()> {
527        sqlx::query(
528            r#"
529            UPDATE forge_workflow_runs
530            SET status = 'completed', output = $2, completed_at = NOW()
531            WHERE id = $1
532            "#,
533        )
534        .bind(run_id)
535        .bind(output)
536        .execute(&self.pool)
537        .await
538        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
539
540        Ok(())
541    }
542
543    /// Mark workflow as failed.
544    async fn fail_workflow(&self, run_id: Uuid, error: &str) -> forge_core::Result<()> {
545        sqlx::query(
546            r#"
547            UPDATE forge_workflow_runs
548            SET status = 'failed', error = $2, completed_at = NOW()
549            WHERE id = $1
550            "#,
551        )
552        .bind(run_id)
553        .bind(error)
554        .execute(&self.pool)
555        .await
556        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
557
558        Ok(())
559    }
560
561    /// Save step record.
562    pub async fn save_step(&self, step: &WorkflowStepRecord) -> forge_core::Result<()> {
563        sqlx::query(
564            r#"
565            INSERT INTO forge_workflow_steps (
566                id, workflow_run_id, step_name, status, result, error, started_at, completed_at
567            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
568            ON CONFLICT (workflow_run_id, step_name) DO UPDATE SET
569                status = EXCLUDED.status,
570                result = EXCLUDED.result,
571                error = EXCLUDED.error,
572                started_at = COALESCE(forge_workflow_steps.started_at, EXCLUDED.started_at),
573                completed_at = EXCLUDED.completed_at
574            "#,
575        )
576        .bind(step.id)
577        .bind(step.workflow_run_id)
578        .bind(&step.step_name)
579        .bind(step.status.as_str())
580        .bind(&step.result)
581        .bind(&step.error)
582        .bind(step.started_at)
583        .bind(step.completed_at)
584        .execute(&self.pool)
585        .await
586        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
587
588        Ok(())
589    }
590
591    /// Start a workflow by its registered name with JSON input.
592    pub async fn start_by_name(
593        &self,
594        workflow_name: &str,
595        input: serde_json::Value,
596    ) -> forge_core::Result<Uuid> {
597        self.start(workflow_name, input).await
598    }
599}
600
601impl WorkflowDispatch for WorkflowExecutor {
602    fn start_by_name(
603        &self,
604        workflow_name: &str,
605        input: serde_json::Value,
606    ) -> Pin<Box<dyn Future<Output = forge_core::Result<Uuid>> + Send + '_>> {
607        let workflow_name = workflow_name.to_string();
608        Box::pin(async move { self.start_by_name(&workflow_name, input).await })
609    }
610}
611
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `WorkflowResult` variant must remain constructible with its current shape.
    #[test]
    fn test_workflow_result_types() {
        let completed = WorkflowResult::Completed(serde_json::json!({}));
        let _other_variants = [
            WorkflowResult::Waiting {
                event_type: "approval".to_string(),
            },
            WorkflowResult::Failed {
                error: "test".to_string(),
            },
            WorkflowResult::Compensated,
        ];

        assert!(
            matches!(completed, WorkflowResult::Completed(_)),
            "Expected Completed"
        );
    }
}