1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use tokio::sync::RwLock;
7use uuid::Uuid;
8
9use super::registry::WorkflowRegistry;
10use super::state::{WorkflowRecord, WorkflowStepRecord};
11use forge_core::CircuitBreakerClient;
12use forge_core::function::WorkflowDispatch;
13use forge_core::workflow::{CompensationHandler, StepStatus, WorkflowContext, WorkflowStatus};
14
/// Outcome reported by the executor after driving a workflow run.
#[derive(Debug)]
pub enum WorkflowResult {
    /// The handler returned successfully; carries the handler's output value.
    Completed(serde_json::Value),
    /// The handler suspended the run; `event_type` names what it is waiting on
    /// (the executor in this file always reports `"timer"`).
    Waiting { event_type: String },
    /// The run did not complete; `error` holds the failure description
    /// (handler error, timeout, or the reason the run was blocked on resume).
    Failed { error: String },
    /// The run was compensated.
    // NOTE(review): no code path in this file constructs this variant outside
    // the unit test — confirm where it is produced.
    Compensated,
}
27
/// Compensation data captured from a `WorkflowContext` once its handler has run.
///
/// Kept in memory only (see `WorkflowExecutor::compensation_state`).
struct CompensationState {
    /// Compensation handlers, looked up by step name when cancelling.
    handlers: HashMap<String, CompensationHandler>,
    /// Names of completed steps; `cancel` walks this list back to front.
    completed_steps: Vec<String>,
}
33
/// Starts, resumes and cancels workflow runs, persisting their state in Postgres.
pub struct WorkflowExecutor {
    /// Source of workflow definitions (active versions and resume validation).
    registry: Arc<WorkflowRegistry>,
    /// Connection pool used for every `forge_workflow_runs` / `forge_workflow_steps` query.
    pool: sqlx::PgPool,
    /// HTTP client handed to each `WorkflowContext`.
    http_client: CircuitBreakerClient,
    /// Per-run compensation data, keyed by run id. Held in process memory and
    /// shared (via `Arc`) with the background tasks spawned by `start`.
    compensation_state: Arc<RwLock<HashMap<Uuid, CompensationState>>>,
}
42
43impl WorkflowExecutor {
44 pub fn new(
46 registry: Arc<WorkflowRegistry>,
47 pool: sqlx::PgPool,
48 http_client: CircuitBreakerClient,
49 ) -> Self {
50 Self {
51 registry,
52 pool,
53 http_client,
54 compensation_state: Arc::new(RwLock::new(HashMap::new())),
55 }
56 }
57
58 pub async fn start<I: serde::Serialize>(
61 &self,
62 workflow_name: &str,
63 input: I,
64 owner_subject: Option<String>,
65 ) -> forge_core::Result<Uuid> {
66 let entry = self.registry.get_active(workflow_name).ok_or_else(|| {
67 forge_core::ForgeError::NotFound(format!(
68 "No active version of workflow '{}'",
69 workflow_name
70 ))
71 })?;
72
73 let input_value = serde_json::to_value(input)?;
74
75 let record = WorkflowRecord::new(
76 workflow_name,
77 entry.info.version,
78 entry.info.signature,
79 input_value.clone(),
80 owner_subject,
81 );
82 let run_id = record.id;
83
84 let entry_info = entry.info.clone();
86 let entry_handler = entry.handler.clone();
87
88 self.save_workflow(&record).await?;
90
91 let registry = self.registry.clone();
93 let pool = self.pool.clone();
94 let http_client = self.http_client.clone();
95 let compensation_state = self.compensation_state.clone();
96
97 tokio::spawn(async move {
98 let executor = WorkflowExecutor {
99 registry,
100 pool,
101 http_client,
102 compensation_state,
103 };
104 let entry = super::registry::WorkflowEntry {
105 info: entry_info,
106 handler: entry_handler,
107 };
108 if let Err(e) = executor.execute_workflow(run_id, &entry, input_value).await {
109 tracing::error!(
110 workflow_run_id = %run_id,
111 error = %e,
112 "Workflow execution failed"
113 );
114 }
115 });
116
117 Ok(run_id)
118 }
119
120 async fn execute_workflow(
122 &self,
123 run_id: Uuid,
124 entry: &super::registry::WorkflowEntry,
125 input: serde_json::Value,
126 ) -> forge_core::Result<WorkflowResult> {
127 self.update_workflow_status(run_id, WorkflowStatus::Running)
129 .await?;
130
131 let mut ctx = WorkflowContext::new(
133 run_id,
134 entry.info.name.to_string(),
135 self.pool.clone(),
136 self.http_client.clone(),
137 );
138 ctx.set_http_timeout(entry.info.http_timeout);
139
140 let handler = entry.handler.clone();
142 let result = tokio::time::timeout(entry.info.timeout, handler(&ctx, input)).await;
143
144 let compensation_state = CompensationState {
146 handlers: ctx.compensation_handlers(),
147 completed_steps: ctx.completed_steps_reversed().into_iter().rev().collect(),
148 };
149 self.compensation_state
150 .write()
151 .await
152 .insert(run_id, compensation_state);
153
154 match result {
155 Ok(Ok(output)) => {
156 self.complete_workflow(run_id, output.clone()).await?;
158 self.compensation_state.write().await.remove(&run_id);
159 Ok(WorkflowResult::Completed(output))
160 }
161 Ok(Err(e)) => {
162 if matches!(e, forge_core::ForgeError::WorkflowSuspended) {
164 return Ok(WorkflowResult::Waiting {
167 event_type: "timer".to_string(),
168 });
169 }
170 self.fail_workflow(run_id, &e.to_string()).await?;
172 Ok(WorkflowResult::Failed {
173 error: e.to_string(),
174 })
175 }
176 Err(_) => {
177 self.fail_workflow(run_id, "Workflow timed out").await?;
179 Ok(WorkflowResult::Failed {
180 error: "Workflow timed out".to_string(),
181 })
182 }
183 }
184 }
185
186 async fn execute_workflow_resumed(
188 &self,
189 run_id: Uuid,
190 entry: &super::registry::WorkflowEntry,
191 input: serde_json::Value,
192 started_at: chrono::DateTime<chrono::Utc>,
193 from_sleep: bool,
194 ) -> forge_core::Result<WorkflowResult> {
195 self.update_workflow_status(run_id, WorkflowStatus::Running)
197 .await?;
198
199 let step_records = self.get_workflow_steps(run_id).await?;
201 let mut step_states = std::collections::HashMap::new();
202 for step in step_records {
203 let status = step.status;
204 step_states.insert(
205 step.step_name.clone(),
206 forge_core::workflow::StepState {
207 name: step.step_name,
208 status,
209 result: step.result,
210 error: step.error,
211 started_at: step.started_at,
212 completed_at: step.completed_at,
213 },
214 );
215 }
216
217 let mut ctx = WorkflowContext::resumed(
219 run_id,
220 entry.info.name.to_string(),
221 started_at,
222 self.pool.clone(),
223 self.http_client.clone(),
224 )
225 .with_step_states(step_states);
226 ctx.set_http_timeout(entry.info.http_timeout);
227
228 if from_sleep {
230 ctx = ctx.with_resumed_from_sleep();
231 }
232
233 let handler = entry.handler.clone();
235 let result = tokio::time::timeout(entry.info.timeout, handler(&ctx, input)).await;
236
237 let compensation_state = CompensationState {
239 handlers: ctx.compensation_handlers(),
240 completed_steps: ctx.completed_steps_reversed().into_iter().rev().collect(),
241 };
242 self.compensation_state
243 .write()
244 .await
245 .insert(run_id, compensation_state);
246
247 match result {
248 Ok(Ok(output)) => {
249 self.complete_workflow(run_id, output.clone()).await?;
251 self.compensation_state.write().await.remove(&run_id);
252 Ok(WorkflowResult::Completed(output))
253 }
254 Ok(Err(e)) => {
255 if matches!(e, forge_core::ForgeError::WorkflowSuspended) {
257 return Ok(WorkflowResult::Waiting {
260 event_type: "timer".to_string(),
261 });
262 }
263 self.fail_workflow(run_id, &e.to_string()).await?;
265 Ok(WorkflowResult::Failed {
266 error: e.to_string(),
267 })
268 }
269 Err(_) => {
270 self.fail_workflow(run_id, "Workflow timed out").await?;
272 Ok(WorkflowResult::Failed {
273 error: "Workflow timed out".to_string(),
274 })
275 }
276 }
277 }
278
    /// Resumes the run `run_id` without marking the context as woken from a sleep.
    ///
    /// Delegates to `resume_internal` with `from_sleep = false`; see that
    /// method for the checks performed and the errors returned.
    pub async fn resume(&self, run_id: Uuid) -> forge_core::Result<WorkflowResult> {
        self.resume_internal(run_id, false).await
    }
283
    /// Resumes the run `run_id` and marks the context as resumed from a sleep.
    ///
    /// Delegates to `resume_internal` with `from_sleep = true`; see that
    /// method for the checks performed and the errors returned.
    pub async fn resume_from_sleep(&self, run_id: Uuid) -> forge_core::Result<WorkflowResult> {
        self.resume_internal(run_id, true).await
    }
288
289 async fn resume_internal(
291 &self,
292 run_id: Uuid,
293 from_sleep: bool,
294 ) -> forge_core::Result<WorkflowResult> {
295 let record = self.get_workflow(run_id).await?;
296
297 match record.status {
299 WorkflowStatus::Running | WorkflowStatus::Waiting => {
300 }
302 status if status.is_terminal() || status.is_blocked() => {
303 return Err(forge_core::ForgeError::Validation(format!(
304 "Cannot resume workflow in {} state",
305 status.as_str()
306 )));
307 }
308 _ => {}
309 }
310
311 match self.registry.validate_resume(
313 &record.workflow_name,
314 &record.workflow_version,
315 &record.workflow_signature,
316 ) {
317 Ok(entry) => {
318 self.execute_workflow_resumed(
319 run_id,
320 entry,
321 record.input,
322 record.started_at,
323 from_sleep,
324 )
325 .await
326 }
327 Err(reason) => {
328 let status = reason.to_status();
330 let description = reason.description();
331 self.block_workflow(run_id, status, &description).await?;
332 tracing::warn!(
333 workflow_run_id = %run_id,
334 workflow_name = %record.workflow_name,
335 workflow_version = %record.workflow_version,
336 reason = %description,
337 "Workflow run blocked"
338 );
339 Ok(WorkflowResult::Failed { error: description })
340 }
341 }
342 }
343
    /// Returns the stored record of the run `run_id`.
    ///
    /// # Errors
    ///
    /// Propagates the errors of `get_workflow` (run not found, database
    /// failure, unparsable status value).
    pub async fn status(&self, run_id: Uuid) -> forge_core::Result<WorkflowRecord> {
        self.get_workflow(run_id).await
    }
348
349 pub async fn cancel(&self, run_id: Uuid) -> forge_core::Result<()> {
359 self.update_workflow_status(run_id, WorkflowStatus::Compensating)
360 .await?;
361
362 let state = self.compensation_state.write().await.remove(&run_id);
364
365 if let Some(state) = state {
366 let steps = self.get_workflow_steps(run_id).await?;
368
369 for step_name in state.completed_steps.iter().rev() {
374 if let Some(handler) = state.handlers.get(step_name) {
375 let step_result = steps
377 .iter()
378 .find(|s| &s.step_name == step_name)
379 .and_then(|s| s.result.clone())
380 .unwrap_or(serde_json::Value::Null);
381
382 match handler(step_result).await {
384 Ok(()) => {
385 tracing::info!(
386 workflow_run_id = %run_id,
387 step = %step_name,
388 "Compensation completed"
389 );
390 self.update_step_status(run_id, step_name, StepStatus::Compensated)
391 .await?;
392 }
393 Err(e) => {
394 tracing::error!(
395 workflow_run_id = %run_id,
396 step = %step_name,
397 error = %e,
398 "Compensation failed"
399 );
400 }
402 }
403 } else {
404 self.update_step_status(run_id, step_name, StepStatus::Compensated)
406 .await?;
407 }
408 }
409 } else {
410 let msg =
412 "Compensation handlers unavailable (likely restart); refusing to mark compensated";
413 tracing::error!(workflow_run_id = %run_id, "{msg}");
414 self.fail_workflow(run_id, msg).await?;
415 return Err(forge_core::ForgeError::InvalidState(msg.to_string()));
416 }
417
418 self.update_workflow_status(run_id, WorkflowStatus::Compensated)
419 .await?;
420
421 Ok(())
422 }
423
424 async fn get_workflow_steps(
426 &self,
427 workflow_run_id: Uuid,
428 ) -> forge_core::Result<Vec<WorkflowStepRecord>> {
429 let rows = sqlx::query!(
430 r#"
431 SELECT id, workflow_run_id, step_name, status, result, error, started_at, completed_at
432 FROM forge_workflow_steps
433 WHERE workflow_run_id = $1
434 ORDER BY started_at ASC
435 "#,
436 workflow_run_id,
437 )
438 .fetch_all(&self.pool)
439 .await
440 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
441
442 rows.into_iter()
443 .map(|row| {
444 let status = row.status.parse().map_err(|e| {
445 forge_core::ForgeError::Database(format!(
446 "Invalid step status '{}': {}",
447 row.status, e
448 ))
449 })?;
450 Ok(WorkflowStepRecord {
451 id: row.id,
452 workflow_run_id: row.workflow_run_id,
453 step_name: row.step_name,
454 status,
455 result: row.result,
456 error: row.error,
457 started_at: row.started_at,
458 completed_at: row.completed_at,
459 })
460 })
461 .collect()
462 }
463
464 async fn update_step_status(
466 &self,
467 workflow_run_id: Uuid,
468 step_name: &str,
469 status: StepStatus,
470 ) -> forge_core::Result<()> {
471 sqlx::query!(
472 r#"
473 UPDATE forge_workflow_steps
474 SET status = $3
475 WHERE workflow_run_id = $1 AND step_name = $2
476 "#,
477 workflow_run_id,
478 step_name,
479 status.as_str(),
480 )
481 .execute(&self.pool)
482 .await
483 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
484
485 Ok(())
486 }
487
488 async fn save_workflow(&self, record: &WorkflowRecord) -> forge_core::Result<()> {
490 sqlx::query!(
491 r#"
492 INSERT INTO forge_workflow_runs (
493 id, workflow_name, workflow_version, workflow_signature,
494 owner_subject, input, status, current_step,
495 step_results, started_at, trace_id
496 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
497 "#,
498 record.id,
499 &record.workflow_name,
500 &record.workflow_version,
501 &record.workflow_signature,
502 record.owner_subject as _,
503 record.input as _,
504 record.status.as_str(),
505 record.current_step as _,
506 record.step_results as _,
507 record.started_at,
508 record.trace_id.as_deref(),
509 )
510 .execute(&self.pool)
511 .await
512 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
513
514 Ok(())
515 }
516
517 async fn get_workflow(&self, run_id: Uuid) -> forge_core::Result<WorkflowRecord> {
519 let row = sqlx::query!(
520 r#"
521 SELECT id, workflow_name, workflow_version, workflow_signature,
522 owner_subject, input, output, status, blocking_reason,
523 resolution_reason, current_step, step_results, started_at,
524 completed_at, error, trace_id
525 FROM forge_workflow_runs
526 WHERE id = $1
527 "#,
528 run_id,
529 )
530 .fetch_optional(&self.pool)
531 .await
532 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
533
534 let row = row.ok_or_else(|| {
535 forge_core::ForgeError::NotFound(format!("Workflow run {} not found", run_id))
536 })?;
537
538 let status = row.status.parse().map_err(|e| {
539 forge_core::ForgeError::Database(format!(
540 "Invalid workflow status '{}': {}",
541 row.status, e
542 ))
543 })?;
544 Ok(WorkflowRecord {
545 id: row.id,
546 workflow_name: row.workflow_name,
547 workflow_version: row.workflow_version,
548 workflow_signature: row.workflow_signature,
549 owner_subject: row.owner_subject,
550 input: row.input,
551 output: row.output,
552 status,
553 blocking_reason: row.blocking_reason,
554 resolution_reason: row.resolution_reason,
555 current_step: row.current_step,
556 step_results: row.step_results.unwrap_or_default(),
557 started_at: row.started_at,
558 completed_at: row.completed_at,
559 error: row.error,
560 trace_id: row.trace_id,
561 })
562 }
563
564 fn valid_source_states(target: &WorkflowStatus) -> &'static [&'static str] {
567 match target {
568 WorkflowStatus::Running => &["created", "waiting", "running"],
569 WorkflowStatus::Waiting => &["running"],
570 WorkflowStatus::Completed => &["running"],
571 WorkflowStatus::Compensating => &["running", "waiting", "failed"],
572 WorkflowStatus::Compensated => &["compensating"],
573 WorkflowStatus::Failed => &["running", "waiting", "compensating"],
574 WorkflowStatus::BlockedMissingVersion
575 | WorkflowStatus::BlockedSignatureMismatch
576 | WorkflowStatus::BlockedMissingHandler => &["waiting", "running", "created"],
577 WorkflowStatus::RetiredUnresumable | WorkflowStatus::CancelledByOperator => &[
578 "created",
579 "running",
580 "waiting",
581 "failed",
582 "blocked_missing_version",
583 "blocked_signature_mismatch",
584 "blocked_missing_handler",
585 ],
586 WorkflowStatus::Created => &[], }
588 }
589
590 async fn update_workflow_status(
595 &self,
596 run_id: Uuid,
597 status: WorkflowStatus,
598 ) -> forge_core::Result<()> {
599 let valid_from = Self::valid_source_states(&status);
600
601 if !valid_from.is_empty() {
602 let current = sqlx::query_scalar!(
603 "SELECT status FROM forge_workflow_runs WHERE id = $1",
604 run_id,
605 )
606 .fetch_optional(&self.pool)
607 .await
608 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
609
610 match current {
611 Some(ref s) if valid_from.contains(&s.as_str()) => {}
612 Some(_) => {
613 return Err(forge_core::ForgeError::InvalidState(format!(
614 "Cannot transition workflow {} to {:?}: invalid current state",
615 run_id, status
616 )));
617 }
618 None => {
619 return Err(forge_core::ForgeError::NotFound(format!(
620 "Workflow run {} not found",
621 run_id
622 )));
623 }
624 }
625 }
626
627 sqlx::query!(
628 "UPDATE forge_workflow_runs SET status = $1 WHERE id = $2",
629 status.as_str(),
630 run_id,
631 )
632 .execute(&self.pool)
633 .await
634 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
635
636 Ok(())
637 }
638
639 async fn complete_workflow(
641 &self,
642 run_id: Uuid,
643 output: serde_json::Value,
644 ) -> forge_core::Result<()> {
645 let result = sqlx::query!(
646 "UPDATE forge_workflow_runs SET status = 'completed', output = $1, completed_at = NOW() WHERE id = $2 AND status = 'running'",
647 output as _,
648 run_id,
649 )
650 .execute(&self.pool)
651 .await
652 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
653
654 if result.rows_affected() == 0 {
655 return Err(forge_core::ForgeError::InvalidState(format!(
656 "Cannot complete workflow {}: not in 'running' state",
657 run_id
658 )));
659 }
660
661 Ok(())
662 }
663
664 async fn fail_workflow(&self, run_id: Uuid, error: &str) -> forge_core::Result<()> {
666 let result = sqlx::query!(
667 "UPDATE forge_workflow_runs SET status = 'failed', error = $1, completed_at = NOW() WHERE id = $2 AND status IN ('running', 'waiting', 'compensating')",
668 error,
669 run_id,
670 )
671 .execute(&self.pool)
672 .await
673 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
674
675 if result.rows_affected() == 0 {
676 return Err(forge_core::ForgeError::InvalidState(format!(
677 "Cannot fail workflow {}: not in a valid state for failure",
678 run_id
679 )));
680 }
681
682 Ok(())
683 }
684
685 async fn block_workflow(
687 &self,
688 run_id: Uuid,
689 status: WorkflowStatus,
690 reason: &str,
691 ) -> forge_core::Result<()> {
692 sqlx::query!(
693 "UPDATE forge_workflow_runs SET status = $1, blocking_reason = $2 WHERE id = $3 AND status IN ('waiting', 'running', 'created')",
694 status.as_str(),
695 reason,
696 run_id,
697 )
698 .execute(&self.pool)
699 .await
700 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
701
702 Ok(())
703 }
704
705 pub async fn cancel_by_operator(&self, run_id: Uuid, reason: &str) -> forge_core::Result<()> {
707 let result = sqlx::query!(
708 "UPDATE forge_workflow_runs SET status = 'cancelled_by_operator', resolution_reason = $1, completed_at = NOW() WHERE id = $2 AND status NOT IN ('completed', 'compensated', 'cancelled_by_operator', 'retired_unresumable')",
709 reason,
710 run_id,
711 )
712 .execute(&self.pool)
713 .await
714 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
715
716 if result.rows_affected() == 0 {
717 return Err(forge_core::ForgeError::InvalidState(format!(
718 "Cannot cancel workflow {}: already in a terminal state",
719 run_id
720 )));
721 }
722
723 Ok(())
724 }
725
726 pub async fn retire_unresumable(&self, run_id: Uuid, reason: &str) -> forge_core::Result<()> {
728 let result = sqlx::query!(
729 "UPDATE forge_workflow_runs SET status = 'retired_unresumable', resolution_reason = $1, completed_at = NOW() WHERE id = $2 AND status NOT IN ('completed', 'compensated', 'cancelled_by_operator', 'retired_unresumable')",
730 reason,
731 run_id,
732 )
733 .execute(&self.pool)
734 .await
735 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
736
737 if result.rows_affected() == 0 {
738 return Err(forge_core::ForgeError::InvalidState(format!(
739 "Cannot retire workflow {}: already in a terminal state",
740 run_id
741 )));
742 }
743
744 Ok(())
745 }
746
747 pub async fn save_step(&self, step: &WorkflowStepRecord) -> forge_core::Result<()> {
749 sqlx::query!(
750 r#"
751 INSERT INTO forge_workflow_steps (
752 id, workflow_run_id, step_name, status, result, error, started_at, completed_at
753 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
754 ON CONFLICT (workflow_run_id, step_name) DO UPDATE SET
755 status = EXCLUDED.status,
756 result = EXCLUDED.result,
757 error = EXCLUDED.error,
758 started_at = COALESCE(forge_workflow_steps.started_at, EXCLUDED.started_at),
759 completed_at = EXCLUDED.completed_at
760 "#,
761 step.id,
762 step.workflow_run_id,
763 &step.step_name,
764 step.status.as_str(),
765 step.result as _,
766 step.error as _,
767 step.started_at,
768 step.completed_at,
769 )
770 .execute(&self.pool)
771 .await
772 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
773
774 Ok(())
775 }
776
    /// Starts a workflow from an already serialised JSON input.
    ///
    /// Non-generic counterpart of `start`: it forwards all arguments to
    /// `start` unchanged and returns its result.
    pub async fn start_by_name(
        &self,
        workflow_name: &str,
        input: serde_json::Value,
        owner_subject: Option<String>,
    ) -> forge_core::Result<Uuid> {
        self.start(workflow_name, input, owner_subject).await
    }
786}
787
788impl WorkflowDispatch for WorkflowExecutor {
789 fn get_info(&self, workflow_name: &str) -> Option<forge_core::workflow::WorkflowInfo> {
790 self.registry
791 .get_active(workflow_name)
792 .map(|e| e.info.clone())
793 }
794
795 fn start_by_name(
796 &self,
797 workflow_name: &str,
798 input: serde_json::Value,
799 owner_subject: Option<String>,
800 ) -> Pin<Box<dyn Future<Output = forge_core::Result<Uuid>> + Send + '_>> {
801 let workflow_name = workflow_name.to_string();
802 Box::pin(async move {
803 self.start_by_name(&workflow_name, input, owner_subject)
804 .await
805 })
806 }
807}
808
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// Every `WorkflowResult` variant can be constructed, and a `Completed`
    /// value matches the `Completed` pattern.
    #[test]
    fn test_workflow_result_types() {
        let completed = WorkflowResult::Completed(serde_json::json!({}));
        let _waiting = WorkflowResult::Waiting {
            event_type: String::from("approval"),
        };
        let _failed = WorkflowResult::Failed {
            error: String::from("test"),
        };
        let _compensated = WorkflowResult::Compensated;

        if !matches!(completed, WorkflowResult::Completed(_)) {
            panic!("Expected Completed");
        }
    }
}