1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use tokio::sync::RwLock;
7use uuid::Uuid;
8
9use super::registry::WorkflowRegistry;
10use super::state::{WorkflowRecord, WorkflowStepRecord};
11use forge_core::function::WorkflowDispatch;
12use forge_core::workflow::{CompensationHandler, StepStatus, WorkflowContext, WorkflowStatus};
13
14#[derive(Debug)]
16pub enum WorkflowResult {
17 Completed(serde_json::Value),
19 Waiting { event_type: String },
21 Failed { error: String },
23 Compensated,
25}
26
27struct CompensationState {
29 handlers: HashMap<String, CompensationHandler>,
30 completed_steps: Vec<String>,
31}
32
33pub struct WorkflowExecutor {
35 registry: Arc<WorkflowRegistry>,
36 pool: sqlx::PgPool,
37 http_client: reqwest::Client,
38 compensation_state: Arc<RwLock<HashMap<Uuid, CompensationState>>>,
40}
41
42impl WorkflowExecutor {
43 pub fn new(
45 registry: Arc<WorkflowRegistry>,
46 pool: sqlx::PgPool,
47 http_client: reqwest::Client,
48 ) -> Self {
49 Self {
50 registry,
51 pool,
52 http_client,
53 compensation_state: Arc::new(RwLock::new(HashMap::new())),
54 }
55 }
56
57 pub async fn start<I: serde::Serialize>(
60 &self,
61 workflow_name: &str,
62 input: I,
63 ) -> forge_core::Result<Uuid> {
64 let entry = self.registry.get(workflow_name).ok_or_else(|| {
65 forge_core::ForgeError::NotFound(format!("Workflow '{}' not found", workflow_name))
66 })?;
67
68 let input_value = serde_json::to_value(input)?;
69
70 let record = WorkflowRecord::new(workflow_name, entry.info.version, input_value.clone());
71 let run_id = record.id;
72
73 let entry_info = entry.info.clone();
75 let entry_handler = entry.handler.clone();
76
77 self.save_workflow(&record).await?;
79
80 let registry = self.registry.clone();
82 let pool = self.pool.clone();
83 let http_client = self.http_client.clone();
84 let compensation_state = self.compensation_state.clone();
85
86 tokio::spawn(async move {
87 let executor = WorkflowExecutor {
88 registry,
89 pool,
90 http_client,
91 compensation_state,
92 };
93 let entry = super::registry::WorkflowEntry {
94 info: entry_info,
95 handler: entry_handler,
96 };
97 if let Err(e) = executor.execute_workflow(run_id, &entry, input_value).await {
98 tracing::error!(
99 workflow_run_id = %run_id,
100 error = %e,
101 "Workflow execution failed"
102 );
103 }
104 });
105
106 Ok(run_id)
107 }
108
109 async fn execute_workflow(
111 &self,
112 run_id: Uuid,
113 entry: &super::registry::WorkflowEntry,
114 input: serde_json::Value,
115 ) -> forge_core::Result<WorkflowResult> {
116 self.update_workflow_status(run_id, WorkflowStatus::Running)
118 .await?;
119
120 let ctx = WorkflowContext::new(
122 run_id,
123 entry.info.name.to_string(),
124 entry.info.version,
125 self.pool.clone(),
126 self.http_client.clone(),
127 );
128
129 let handler = entry.handler.clone();
131 let result = tokio::time::timeout(entry.info.timeout, handler(&ctx, input)).await;
132
133 let compensation_state = CompensationState {
135 handlers: ctx.compensation_handlers(),
136 completed_steps: ctx.completed_steps_reversed().into_iter().rev().collect(),
137 };
138 self.compensation_state
139 .write()
140 .await
141 .insert(run_id, compensation_state);
142
143 match result {
144 Ok(Ok(output)) => {
145 self.complete_workflow(run_id, output.clone()).await?;
147 self.compensation_state.write().await.remove(&run_id);
148 Ok(WorkflowResult::Completed(output))
149 }
150 Ok(Err(e)) => {
151 if matches!(e, forge_core::ForgeError::WorkflowSuspended) {
153 return Ok(WorkflowResult::Waiting {
156 event_type: "timer".to_string(),
157 });
158 }
159 self.fail_workflow(run_id, &e.to_string()).await?;
161 Ok(WorkflowResult::Failed {
162 error: e.to_string(),
163 })
164 }
165 Err(_) => {
166 self.fail_workflow(run_id, "Workflow timed out").await?;
168 Ok(WorkflowResult::Failed {
169 error: "Workflow timed out".to_string(),
170 })
171 }
172 }
173 }
174
175 async fn execute_workflow_resumed(
177 &self,
178 run_id: Uuid,
179 entry: &super::registry::WorkflowEntry,
180 input: serde_json::Value,
181 started_at: chrono::DateTime<chrono::Utc>,
182 from_sleep: bool,
183 ) -> forge_core::Result<WorkflowResult> {
184 self.update_workflow_status(run_id, WorkflowStatus::Running)
186 .await?;
187
188 let step_records = self.get_workflow_steps(run_id).await?;
190 let mut step_states = std::collections::HashMap::new();
191 for step in step_records {
192 let status = step.status;
193 step_states.insert(
194 step.step_name.clone(),
195 forge_core::workflow::StepState {
196 name: step.step_name,
197 status,
198 result: step.result,
199 error: step.error,
200 started_at: step.started_at,
201 completed_at: step.completed_at,
202 },
203 );
204 }
205
206 let mut ctx = WorkflowContext::resumed(
208 run_id,
209 entry.info.name.to_string(),
210 entry.info.version,
211 started_at,
212 self.pool.clone(),
213 self.http_client.clone(),
214 )
215 .with_step_states(step_states);
216
217 if from_sleep {
219 ctx = ctx.with_resumed_from_sleep();
220 }
221
222 let handler = entry.handler.clone();
224 let result = tokio::time::timeout(entry.info.timeout, handler(&ctx, input)).await;
225
226 let compensation_state = CompensationState {
228 handlers: ctx.compensation_handlers(),
229 completed_steps: ctx.completed_steps_reversed().into_iter().rev().collect(),
230 };
231 self.compensation_state
232 .write()
233 .await
234 .insert(run_id, compensation_state);
235
236 match result {
237 Ok(Ok(output)) => {
238 self.complete_workflow(run_id, output.clone()).await?;
240 self.compensation_state.write().await.remove(&run_id);
241 Ok(WorkflowResult::Completed(output))
242 }
243 Ok(Err(e)) => {
244 if matches!(e, forge_core::ForgeError::WorkflowSuspended) {
246 return Ok(WorkflowResult::Waiting {
249 event_type: "timer".to_string(),
250 });
251 }
252 self.fail_workflow(run_id, &e.to_string()).await?;
254 Ok(WorkflowResult::Failed {
255 error: e.to_string(),
256 })
257 }
258 Err(_) => {
259 self.fail_workflow(run_id, "Workflow timed out").await?;
261 Ok(WorkflowResult::Failed {
262 error: "Workflow timed out".to_string(),
263 })
264 }
265 }
266 }
267
268 pub async fn resume(&self, run_id: Uuid) -> forge_core::Result<WorkflowResult> {
270 self.resume_internal(run_id, false).await
271 }
272
273 pub async fn resume_from_sleep(&self, run_id: Uuid) -> forge_core::Result<WorkflowResult> {
275 self.resume_internal(run_id, true).await
276 }
277
278 async fn resume_internal(
280 &self,
281 run_id: Uuid,
282 from_sleep: bool,
283 ) -> forge_core::Result<WorkflowResult> {
284 let record = self.get_workflow(run_id).await?;
285
286 let entry = self.registry.get(&record.workflow_name).ok_or_else(|| {
287 forge_core::ForgeError::NotFound(format!(
288 "Workflow '{}' not found",
289 record.workflow_name
290 ))
291 })?;
292
293 match record.status {
295 WorkflowStatus::Running | WorkflowStatus::Waiting => {
296 }
298 status if status.is_terminal() => {
299 return Err(forge_core::ForgeError::Validation(format!(
300 "Cannot resume workflow in {} state",
301 status.as_str()
302 )));
303 }
304 _ => {}
305 }
306
307 self.execute_workflow_resumed(run_id, entry, record.input, record.started_at, from_sleep)
308 .await
309 }
310
311 pub async fn status(&self, run_id: Uuid) -> forge_core::Result<WorkflowRecord> {
313 self.get_workflow(run_id).await
314 }
315
316 pub async fn cancel(&self, run_id: Uuid) -> forge_core::Result<()> {
318 self.update_workflow_status(run_id, WorkflowStatus::Compensating)
319 .await?;
320
321 let state = self.compensation_state.write().await.remove(&run_id);
323
324 if let Some(state) = state {
325 let steps = self.get_workflow_steps(run_id).await?;
327
328 for step_name in state.completed_steps.iter().rev() {
330 if let Some(handler) = state.handlers.get(step_name) {
331 let step_result = steps
333 .iter()
334 .find(|s| &s.step_name == step_name)
335 .and_then(|s| s.result.clone())
336 .unwrap_or(serde_json::Value::Null);
337
338 match handler(step_result).await {
340 Ok(()) => {
341 tracing::info!(
342 workflow_run_id = %run_id,
343 step = %step_name,
344 "Compensation completed"
345 );
346 self.update_step_status(run_id, step_name, StepStatus::Compensated)
347 .await?;
348 }
349 Err(e) => {
350 tracing::error!(
351 workflow_run_id = %run_id,
352 step = %step_name,
353 error = %e,
354 "Compensation failed"
355 );
356 }
358 }
359 } else {
360 self.update_step_status(run_id, step_name, StepStatus::Compensated)
362 .await?;
363 }
364 }
365 } else {
366 tracing::warn!(
369 workflow_run_id = %run_id,
370 "No compensation state found, marking as compensated without handlers"
371 );
372 }
373
374 self.update_workflow_status(run_id, WorkflowStatus::Compensated)
375 .await?;
376
377 Ok(())
378 }
379
380 async fn get_workflow_steps(
382 &self,
383 workflow_run_id: Uuid,
384 ) -> forge_core::Result<Vec<WorkflowStepRecord>> {
385 let rows = sqlx::query(
386 r#"
387 SELECT id, workflow_run_id, step_name, status, result, error, started_at, completed_at
388 FROM forge_workflow_steps
389 WHERE workflow_run_id = $1
390 ORDER BY started_at ASC
391 "#,
392 )
393 .bind(workflow_run_id)
394 .fetch_all(&self.pool)
395 .await
396 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
397
398 use sqlx::Row;
399 Ok(rows
400 .into_iter()
401 .map(|row| WorkflowStepRecord {
402 id: row.get("id"),
403 workflow_run_id: row.get("workflow_run_id"),
404 step_name: row.get("step_name"),
405 status: row.get::<String, _>("status").parse().unwrap(),
406 result: row.get("result"),
407 error: row.get("error"),
408 started_at: row.get("started_at"),
409 completed_at: row.get("completed_at"),
410 })
411 .collect())
412 }
413
414 async fn update_step_status(
416 &self,
417 workflow_run_id: Uuid,
418 step_name: &str,
419 status: StepStatus,
420 ) -> forge_core::Result<()> {
421 sqlx::query(
422 r#"
423 UPDATE forge_workflow_steps
424 SET status = $3
425 WHERE workflow_run_id = $1 AND step_name = $2
426 "#,
427 )
428 .bind(workflow_run_id)
429 .bind(step_name)
430 .bind(status.as_str())
431 .execute(&self.pool)
432 .await
433 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
434
435 Ok(())
436 }
437
438 async fn save_workflow(&self, record: &WorkflowRecord) -> forge_core::Result<()> {
440 sqlx::query(
441 r#"
442 INSERT INTO forge_workflow_runs (
443 id, workflow_name, input, status, current_step,
444 step_results, started_at, trace_id
445 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
446 "#,
447 )
448 .bind(record.id)
449 .bind(&record.workflow_name)
450 .bind(&record.input)
451 .bind(record.status.as_str())
452 .bind(&record.current_step)
453 .bind(&record.step_results)
454 .bind(record.started_at)
455 .bind(&record.trace_id)
456 .execute(&self.pool)
457 .await
458 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
459
460 Ok(())
461 }
462
463 async fn get_workflow(&self, run_id: Uuid) -> forge_core::Result<WorkflowRecord> {
465 let row = sqlx::query(
466 r#"
467 SELECT id, workflow_name, input, output, status, current_step,
468 step_results, started_at, completed_at, error, trace_id
469 FROM forge_workflow_runs
470 WHERE id = $1
471 "#,
472 )
473 .bind(run_id)
474 .fetch_optional(&self.pool)
475 .await
476 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
477
478 let row = row.ok_or_else(|| {
479 forge_core::ForgeError::NotFound(format!("Workflow run {} not found", run_id))
480 })?;
481
482 use sqlx::Row;
483 Ok(WorkflowRecord {
484 id: row.get("id"),
485 workflow_name: row.get("workflow_name"),
486 version: 1, input: row.get("input"),
488 output: row.get("output"),
489 status: row.get::<String, _>("status").parse().unwrap(),
490 current_step: row.get("current_step"),
491 step_results: row.get("step_results"),
492 started_at: row.get("started_at"),
493 completed_at: row.get("completed_at"),
494 error: row.get("error"),
495 trace_id: row.get("trace_id"),
496 })
497 }
498
499 async fn update_workflow_status(
501 &self,
502 run_id: Uuid,
503 status: WorkflowStatus,
504 ) -> forge_core::Result<()> {
505 sqlx::query(
506 r#"
507 UPDATE forge_workflow_runs
508 SET status = $2
509 WHERE id = $1
510 "#,
511 )
512 .bind(run_id)
513 .bind(status.as_str())
514 .execute(&self.pool)
515 .await
516 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
517
518 Ok(())
519 }
520
521 async fn complete_workflow(
523 &self,
524 run_id: Uuid,
525 output: serde_json::Value,
526 ) -> forge_core::Result<()> {
527 sqlx::query(
528 r#"
529 UPDATE forge_workflow_runs
530 SET status = 'completed', output = $2, completed_at = NOW()
531 WHERE id = $1
532 "#,
533 )
534 .bind(run_id)
535 .bind(output)
536 .execute(&self.pool)
537 .await
538 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
539
540 Ok(())
541 }
542
543 async fn fail_workflow(&self, run_id: Uuid, error: &str) -> forge_core::Result<()> {
545 sqlx::query(
546 r#"
547 UPDATE forge_workflow_runs
548 SET status = 'failed', error = $2, completed_at = NOW()
549 WHERE id = $1
550 "#,
551 )
552 .bind(run_id)
553 .bind(error)
554 .execute(&self.pool)
555 .await
556 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
557
558 Ok(())
559 }
560
561 pub async fn save_step(&self, step: &WorkflowStepRecord) -> forge_core::Result<()> {
563 sqlx::query(
564 r#"
565 INSERT INTO forge_workflow_steps (
566 id, workflow_run_id, step_name, status, result, error, started_at, completed_at
567 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
568 ON CONFLICT (workflow_run_id, step_name) DO UPDATE SET
569 status = EXCLUDED.status,
570 result = EXCLUDED.result,
571 error = EXCLUDED.error,
572 started_at = COALESCE(forge_workflow_steps.started_at, EXCLUDED.started_at),
573 completed_at = EXCLUDED.completed_at
574 "#,
575 )
576 .bind(step.id)
577 .bind(step.workflow_run_id)
578 .bind(&step.step_name)
579 .bind(step.status.as_str())
580 .bind(&step.result)
581 .bind(&step.error)
582 .bind(step.started_at)
583 .bind(step.completed_at)
584 .execute(&self.pool)
585 .await
586 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
587
588 Ok(())
589 }
590
591 pub async fn start_by_name(
593 &self,
594 workflow_name: &str,
595 input: serde_json::Value,
596 ) -> forge_core::Result<Uuid> {
597 self.start(workflow_name, input).await
598 }
599}
600
601impl WorkflowDispatch for WorkflowExecutor {
602 fn start_by_name(
603 &self,
604 workflow_name: &str,
605 input: serde_json::Value,
606 ) -> Pin<Box<dyn Future<Output = forge_core::Result<Uuid>> + Send + '_>> {
607 let workflow_name = workflow_name.to_string();
608 Box::pin(async move { self.start_by_name(&workflow_name, input).await })
609 }
610}
611
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructs every `WorkflowResult` variant and checks that a `Completed`
    /// value matches the `Completed` pattern.
    #[test]
    fn test_workflow_result_types() {
        let _waiting = WorkflowResult::Waiting {
            event_type: String::from("approval"),
        };
        let _failed = WorkflowResult::Failed {
            error: String::from("test"),
        };
        let _compensated = WorkflowResult::Compensated;
        let completed = WorkflowResult::Completed(serde_json::json!({}));

        assert!(
            matches!(completed, WorkflowResult::Completed(_)),
            "Expected Completed"
        );
    }
}
633}