1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use tokio::sync::RwLock;
7use uuid::Uuid;
8
9use super::registry::WorkflowRegistry;
10use super::state::{WorkflowRecord, WorkflowStepRecord};
11use forge_core::CircuitBreakerClient;
12use forge_core::function::WorkflowDispatch;
13use forge_core::workflow::{CompensationHandler, StepStatus, WorkflowContext, WorkflowStatus};
14
15#[derive(Debug)]
17pub enum WorkflowResult {
18 Completed(serde_json::Value),
20 Waiting { event_type: String },
22 Failed { error: String },
24 Compensated,
26}
27
28struct CompensationState {
30 handlers: HashMap<String, CompensationHandler>,
31 completed_steps: Vec<String>,
32}
33
34pub struct WorkflowExecutor {
36 registry: Arc<WorkflowRegistry>,
37 pool: sqlx::PgPool,
38 http_client: CircuitBreakerClient,
39 compensation_state: Arc<RwLock<HashMap<Uuid, CompensationState>>>,
41}
42
43impl WorkflowExecutor {
44 pub fn new(
46 registry: Arc<WorkflowRegistry>,
47 pool: sqlx::PgPool,
48 http_client: CircuitBreakerClient,
49 ) -> Self {
50 Self {
51 registry,
52 pool,
53 http_client,
54 compensation_state: Arc::new(RwLock::new(HashMap::new())),
55 }
56 }
57
58 pub async fn start<I: serde::Serialize>(
61 &self,
62 workflow_name: &str,
63 input: I,
64 owner_subject: Option<String>,
65 ) -> forge_core::Result<Uuid> {
66 let entry = self.registry.get(workflow_name).ok_or_else(|| {
67 forge_core::ForgeError::NotFound(format!("Workflow '{}' not found", workflow_name))
68 })?;
69
70 let input_value = serde_json::to_value(input)?;
71
72 let record = WorkflowRecord::new(
73 workflow_name,
74 entry.info.version,
75 input_value.clone(),
76 owner_subject,
77 );
78 let run_id = record.id;
79
80 let entry_info = entry.info.clone();
82 let entry_handler = entry.handler.clone();
83
84 self.save_workflow(&record).await?;
86
87 let registry = self.registry.clone();
89 let pool = self.pool.clone();
90 let http_client = self.http_client.clone();
91 let compensation_state = self.compensation_state.clone();
92
93 tokio::spawn(async move {
94 let executor = WorkflowExecutor {
95 registry,
96 pool,
97 http_client,
98 compensation_state,
99 };
100 let entry = super::registry::WorkflowEntry {
101 info: entry_info,
102 handler: entry_handler,
103 };
104 if let Err(e) = executor.execute_workflow(run_id, &entry, input_value).await {
105 tracing::error!(
106 workflow_run_id = %run_id,
107 error = %e,
108 "Workflow execution failed"
109 );
110 }
111 });
112
113 Ok(run_id)
114 }
115
116 async fn execute_workflow(
118 &self,
119 run_id: Uuid,
120 entry: &super::registry::WorkflowEntry,
121 input: serde_json::Value,
122 ) -> forge_core::Result<WorkflowResult> {
123 self.update_workflow_status(run_id, WorkflowStatus::Running)
125 .await?;
126
127 let ctx = WorkflowContext::new(
129 run_id,
130 entry.info.name.to_string(),
131 entry.info.version,
132 self.pool.clone(),
133 self.http_client.inner().clone(),
134 );
135
136 let handler = entry.handler.clone();
138 let result = tokio::time::timeout(entry.info.timeout, handler(&ctx, input)).await;
139
140 let compensation_state = CompensationState {
142 handlers: ctx.compensation_handlers(),
143 completed_steps: ctx.completed_steps_reversed().into_iter().rev().collect(),
144 };
145 self.compensation_state
146 .write()
147 .await
148 .insert(run_id, compensation_state);
149
150 match result {
151 Ok(Ok(output)) => {
152 self.complete_workflow(run_id, output.clone()).await?;
154 self.compensation_state.write().await.remove(&run_id);
155 Ok(WorkflowResult::Completed(output))
156 }
157 Ok(Err(e)) => {
158 if matches!(e, forge_core::ForgeError::WorkflowSuspended) {
160 return Ok(WorkflowResult::Waiting {
163 event_type: "timer".to_string(),
164 });
165 }
166 self.fail_workflow(run_id, &e.to_string()).await?;
168 Ok(WorkflowResult::Failed {
169 error: e.to_string(),
170 })
171 }
172 Err(_) => {
173 self.fail_workflow(run_id, "Workflow timed out").await?;
175 Ok(WorkflowResult::Failed {
176 error: "Workflow timed out".to_string(),
177 })
178 }
179 }
180 }
181
182 async fn execute_workflow_resumed(
184 &self,
185 run_id: Uuid,
186 entry: &super::registry::WorkflowEntry,
187 input: serde_json::Value,
188 started_at: chrono::DateTime<chrono::Utc>,
189 from_sleep: bool,
190 ) -> forge_core::Result<WorkflowResult> {
191 self.update_workflow_status(run_id, WorkflowStatus::Running)
193 .await?;
194
195 let step_records = self.get_workflow_steps(run_id).await?;
197 let mut step_states = std::collections::HashMap::new();
198 for step in step_records {
199 let status = step.status;
200 step_states.insert(
201 step.step_name.clone(),
202 forge_core::workflow::StepState {
203 name: step.step_name,
204 status,
205 result: step.result,
206 error: step.error,
207 started_at: step.started_at,
208 completed_at: step.completed_at,
209 },
210 );
211 }
212
213 let mut ctx = WorkflowContext::resumed(
215 run_id,
216 entry.info.name.to_string(),
217 entry.info.version,
218 started_at,
219 self.pool.clone(),
220 self.http_client.inner().clone(),
221 )
222 .with_step_states(step_states);
223
224 if from_sleep {
226 ctx = ctx.with_resumed_from_sleep();
227 }
228
229 let handler = entry.handler.clone();
231 let result = tokio::time::timeout(entry.info.timeout, handler(&ctx, input)).await;
232
233 let compensation_state = CompensationState {
235 handlers: ctx.compensation_handlers(),
236 completed_steps: ctx.completed_steps_reversed().into_iter().rev().collect(),
237 };
238 self.compensation_state
239 .write()
240 .await
241 .insert(run_id, compensation_state);
242
243 match result {
244 Ok(Ok(output)) => {
245 self.complete_workflow(run_id, output.clone()).await?;
247 self.compensation_state.write().await.remove(&run_id);
248 Ok(WorkflowResult::Completed(output))
249 }
250 Ok(Err(e)) => {
251 if matches!(e, forge_core::ForgeError::WorkflowSuspended) {
253 return Ok(WorkflowResult::Waiting {
256 event_type: "timer".to_string(),
257 });
258 }
259 self.fail_workflow(run_id, &e.to_string()).await?;
261 Ok(WorkflowResult::Failed {
262 error: e.to_string(),
263 })
264 }
265 Err(_) => {
266 self.fail_workflow(run_id, "Workflow timed out").await?;
268 Ok(WorkflowResult::Failed {
269 error: "Workflow timed out".to_string(),
270 })
271 }
272 }
273 }
274
275 pub async fn resume(&self, run_id: Uuid) -> forge_core::Result<WorkflowResult> {
277 self.resume_internal(run_id, false).await
278 }
279
280 pub async fn resume_from_sleep(&self, run_id: Uuid) -> forge_core::Result<WorkflowResult> {
282 self.resume_internal(run_id, true).await
283 }
284
285 async fn resume_internal(
287 &self,
288 run_id: Uuid,
289 from_sleep: bool,
290 ) -> forge_core::Result<WorkflowResult> {
291 let record = self.get_workflow(run_id).await?;
292
293 let entry = self
294 .registry
295 .get_version(&record.workflow_name, record.version)
296 .ok_or_else(|| {
297 forge_core::ForgeError::NotFound(format!(
298 "Workflow '{}' version {} not found",
299 record.workflow_name, record.version
300 ))
301 })?;
302
303 match record.status {
305 WorkflowStatus::Running | WorkflowStatus::Waiting => {
306 }
308 status if status.is_terminal() => {
309 return Err(forge_core::ForgeError::Validation(format!(
310 "Cannot resume workflow in {} state",
311 status.as_str()
312 )));
313 }
314 _ => {}
315 }
316
317 self.execute_workflow_resumed(run_id, entry, record.input, record.started_at, from_sleep)
318 .await
319 }
320
321 pub async fn status(&self, run_id: Uuid) -> forge_core::Result<WorkflowRecord> {
323 self.get_workflow(run_id).await
324 }
325
326 pub async fn cancel(&self, run_id: Uuid) -> forge_core::Result<()> {
336 self.update_workflow_status(run_id, WorkflowStatus::Compensating)
337 .await?;
338
339 let state = self.compensation_state.write().await.remove(&run_id);
341
342 if let Some(state) = state {
343 let steps = self.get_workflow_steps(run_id).await?;
345
346 for step_name in state.completed_steps.iter().rev() {
351 if let Some(handler) = state.handlers.get(step_name) {
352 let step_result = steps
354 .iter()
355 .find(|s| &s.step_name == step_name)
356 .and_then(|s| s.result.clone())
357 .unwrap_or(serde_json::Value::Null);
358
359 match handler(step_result).await {
361 Ok(()) => {
362 tracing::info!(
363 workflow_run_id = %run_id,
364 step = %step_name,
365 "Compensation completed"
366 );
367 self.update_step_status(run_id, step_name, StepStatus::Compensated)
368 .await?;
369 }
370 Err(e) => {
371 tracing::error!(
372 workflow_run_id = %run_id,
373 step = %step_name,
374 error = %e,
375 "Compensation failed"
376 );
377 }
379 }
380 } else {
381 self.update_step_status(run_id, step_name, StepStatus::Compensated)
383 .await?;
384 }
385 }
386 } else {
387 let msg =
389 "Compensation handlers unavailable (likely restart); refusing to mark compensated";
390 tracing::error!(workflow_run_id = %run_id, "{msg}");
391 self.fail_workflow(run_id, msg).await?;
392 return Err(forge_core::ForgeError::InvalidState(msg.to_string()));
393 }
394
395 self.update_workflow_status(run_id, WorkflowStatus::Compensated)
396 .await?;
397
398 Ok(())
399 }
400
401 async fn get_workflow_steps(
403 &self,
404 workflow_run_id: Uuid,
405 ) -> forge_core::Result<Vec<WorkflowStepRecord>> {
406 let rows = sqlx::query(
407 r#"
408 SELECT id, workflow_run_id, step_name, status, result, error, started_at, completed_at
409 FROM forge_workflow_steps
410 WHERE workflow_run_id = $1
411 ORDER BY started_at ASC
412 "#,
413 )
414 .bind(workflow_run_id)
415 .fetch_all(&self.pool)
416 .await
417 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
418
419 use sqlx::Row;
420 rows.into_iter()
421 .map(|row| {
422 let status_str = row.get::<String, _>("status");
423 let status = status_str.parse().map_err(|e| {
424 forge_core::ForgeError::Database(format!(
425 "Invalid step status '{}': {}",
426 status_str, e
427 ))
428 })?;
429 Ok(WorkflowStepRecord {
430 id: row.get("id"),
431 workflow_run_id: row.get("workflow_run_id"),
432 step_name: row.get("step_name"),
433 status,
434 result: row.get("result"),
435 error: row.get("error"),
436 started_at: row.get("started_at"),
437 completed_at: row.get("completed_at"),
438 })
439 })
440 .collect()
441 }
442
443 async fn update_step_status(
445 &self,
446 workflow_run_id: Uuid,
447 step_name: &str,
448 status: StepStatus,
449 ) -> forge_core::Result<()> {
450 sqlx::query(
451 r#"
452 UPDATE forge_workflow_steps
453 SET status = $3
454 WHERE workflow_run_id = $1 AND step_name = $2
455 "#,
456 )
457 .bind(workflow_run_id)
458 .bind(step_name)
459 .bind(status.as_str())
460 .execute(&self.pool)
461 .await
462 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
463
464 Ok(())
465 }
466
467 async fn save_workflow(&self, record: &WorkflowRecord) -> forge_core::Result<()> {
469 sqlx::query(
470 r#"
471 INSERT INTO forge_workflow_runs (
472 id, workflow_name, version, owner_subject, input, status, current_step,
473 step_results, started_at, trace_id
474 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
475 "#,
476 )
477 .bind(record.id)
478 .bind(&record.workflow_name)
479 .bind(record.version as i32)
480 .bind(&record.owner_subject)
481 .bind(&record.input)
482 .bind(record.status.as_str())
483 .bind(&record.current_step)
484 .bind(&record.step_results)
485 .bind(record.started_at)
486 .bind(&record.trace_id)
487 .execute(&self.pool)
488 .await
489 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
490
491 Ok(())
492 }
493
494 async fn get_workflow(&self, run_id: Uuid) -> forge_core::Result<WorkflowRecord> {
496 let row = sqlx::query(
497 r#"
498 SELECT id, workflow_name, version, owner_subject, input, output, status, current_step,
499 step_results, started_at, completed_at, error, trace_id
500 FROM forge_workflow_runs
501 WHERE id = $1
502 "#,
503 )
504 .bind(run_id)
505 .fetch_optional(&self.pool)
506 .await
507 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
508
509 let row = row.ok_or_else(|| {
510 forge_core::ForgeError::NotFound(format!("Workflow run {} not found", run_id))
511 })?;
512
513 use sqlx::Row;
514 let status_str = row.get::<String, _>("status");
515 let status = status_str.parse().map_err(|e| {
516 forge_core::ForgeError::Database(format!(
517 "Invalid workflow status '{}': {}",
518 status_str, e
519 ))
520 })?;
521 Ok(WorkflowRecord {
522 id: row.get("id"),
523 workflow_name: row.get("workflow_name"),
524 version: row.get::<i32, _>("version") as u32,
525 owner_subject: row.get("owner_subject"),
526 input: row.get("input"),
527 output: row.get("output"),
528 status,
529 current_step: row.get("current_step"),
530 step_results: row.get("step_results"),
531 started_at: row.get("started_at"),
532 completed_at: row.get("completed_at"),
533 error: row.get("error"),
534 trace_id: row.get("trace_id"),
535 })
536 }
537
538 async fn update_workflow_status(
540 &self,
541 run_id: Uuid,
542 status: WorkflowStatus,
543 ) -> forge_core::Result<()> {
544 sqlx::query(
545 r#"
546 UPDATE forge_workflow_runs
547 SET status = $2
548 WHERE id = $1
549 "#,
550 )
551 .bind(run_id)
552 .bind(status.as_str())
553 .execute(&self.pool)
554 .await
555 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
556
557 Ok(())
558 }
559
560 async fn complete_workflow(
562 &self,
563 run_id: Uuid,
564 output: serde_json::Value,
565 ) -> forge_core::Result<()> {
566 sqlx::query(
567 r#"
568 UPDATE forge_workflow_runs
569 SET status = 'completed', output = $2, completed_at = NOW()
570 WHERE id = $1
571 "#,
572 )
573 .bind(run_id)
574 .bind(output)
575 .execute(&self.pool)
576 .await
577 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
578
579 Ok(())
580 }
581
582 async fn fail_workflow(&self, run_id: Uuid, error: &str) -> forge_core::Result<()> {
584 sqlx::query(
585 r#"
586 UPDATE forge_workflow_runs
587 SET status = 'failed', error = $2, completed_at = NOW()
588 WHERE id = $1
589 "#,
590 )
591 .bind(run_id)
592 .bind(error)
593 .execute(&self.pool)
594 .await
595 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
596
597 Ok(())
598 }
599
600 pub async fn save_step(&self, step: &WorkflowStepRecord) -> forge_core::Result<()> {
602 sqlx::query(
603 r#"
604 INSERT INTO forge_workflow_steps (
605 id, workflow_run_id, step_name, status, result, error, started_at, completed_at
606 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
607 ON CONFLICT (workflow_run_id, step_name) DO UPDATE SET
608 status = EXCLUDED.status,
609 result = EXCLUDED.result,
610 error = EXCLUDED.error,
611 started_at = COALESCE(forge_workflow_steps.started_at, EXCLUDED.started_at),
612 completed_at = EXCLUDED.completed_at
613 "#,
614 )
615 .bind(step.id)
616 .bind(step.workflow_run_id)
617 .bind(&step.step_name)
618 .bind(step.status.as_str())
619 .bind(&step.result)
620 .bind(&step.error)
621 .bind(step.started_at)
622 .bind(step.completed_at)
623 .execute(&self.pool)
624 .await
625 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
626
627 Ok(())
628 }
629
630 pub async fn start_by_name(
632 &self,
633 workflow_name: &str,
634 input: serde_json::Value,
635 owner_subject: Option<String>,
636 ) -> forge_core::Result<Uuid> {
637 self.start(workflow_name, input, owner_subject).await
638 }
639}
640
641impl WorkflowDispatch for WorkflowExecutor {
642 fn get_info(&self, workflow_name: &str) -> Option<forge_core::workflow::WorkflowInfo> {
643 self.registry.get(workflow_name).map(|e| e.info.clone())
644 }
645
646 fn start_by_name(
647 &self,
648 workflow_name: &str,
649 input: serde_json::Value,
650 owner_subject: Option<String>,
651 ) -> Pin<Box<dyn Future<Output = forge_core::Result<Uuid>> + Send + '_>> {
652 let workflow_name = workflow_name.to_string();
653 Box::pin(async move {
654 self.start_by_name(&workflow_name, input, owner_subject)
655 .await
656 })
657 }
658}
659
#[cfg(test)]
// Test code may panic/unwrap freely; these lints target production paths only.
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// Every `WorkflowResult` variant can be built with its payload, and a
    /// `Completed` value matches the `Completed` pattern.
    #[test]
    fn test_workflow_result_types() {
        let completed = WorkflowResult::Completed(serde_json::json!({}));
        let _failed = WorkflowResult::Failed {
            error: "test".to_string(),
        };
        let _waiting = WorkflowResult::Waiting {
            event_type: "approval".to_string(),
        };
        let _compensated = WorkflowResult::Compensated;

        if !matches!(completed, WorkflowResult::Completed(_)) {
            panic!("Expected Completed");
        }
    }
}