1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use tokio::sync::RwLock;
7use uuid::Uuid;
8
9use super::registry::WorkflowRegistry;
10use super::state::{WorkflowRecord, WorkflowStepRecord};
11use forge_core::CircuitBreakerClient;
12use forge_core::function::WorkflowDispatch;
13use forge_core::workflow::{CompensationHandler, StepStatus, WorkflowContext, WorkflowStatus};
14
/// Outcome of a single attempt to drive a workflow run forward.
#[derive(Debug)]
pub enum WorkflowResult {
    /// The handler ran to completion; carries its JSON output.
    Completed(serde_json::Value),
    /// The handler suspended itself and is waiting on an external event
    /// (currently always reported as "timer" by the executor).
    Waiting { event_type: String },
    /// The handler returned an error or the overall timeout elapsed.
    Failed { error: String },
    /// The run was cancelled and its completed steps were compensated.
    Compensated,
}
27
/// In-memory snapshot of a run's compensation data, captured after the
/// handler returns so `cancel` can undo completed steps in reverse order.
/// Lost on process restart (see the guard in `cancel`).
struct CompensationState {
    /// step name -> compensation handler registered by the workflow body.
    handlers: HashMap<String, CompensationHandler>,
    /// Step names in forward (execution) order; `cancel` iterates reversed.
    completed_steps: Vec<String>,
}
33
/// Executes registered workflows, persisting run/step state to Postgres and
/// keeping per-run compensation handlers in memory.
pub struct WorkflowExecutor {
    /// Registry of workflow definitions, keyed by name (and version).
    registry: Arc<WorkflowRegistry>,
    /// Pool for the `forge_workflow_runs` / `forge_workflow_steps` tables.
    pool: sqlx::PgPool,
    /// HTTP client handed to workflow contexts for outbound calls.
    http_client: CircuitBreakerClient,
    /// run id -> compensation snapshot; shared with spawned executor clones.
    compensation_state: Arc<RwLock<HashMap<Uuid, CompensationState>>>,
}
42
43impl WorkflowExecutor {
44 pub fn new(
46 registry: Arc<WorkflowRegistry>,
47 pool: sqlx::PgPool,
48 http_client: CircuitBreakerClient,
49 ) -> Self {
50 Self {
51 registry,
52 pool,
53 http_client,
54 compensation_state: Arc::new(RwLock::new(HashMap::new())),
55 }
56 }
57
    /// Start a new run of the named workflow and return its run id.
    ///
    /// The run row is persisted *before* the handler starts, then execution
    /// happens on a detached background task; callers poll via `status` /
    /// `resume`. Returns `NotFound` for an unregistered workflow name and a
    /// serialization error if `input` cannot become JSON.
    pub async fn start<I: serde::Serialize>(
        &self,
        workflow_name: &str,
        input: I,
        owner_subject: Option<String>,
    ) -> forge_core::Result<Uuid> {
        let entry = self.registry.get(workflow_name).ok_or_else(|| {
            forge_core::ForgeError::NotFound(format!("Workflow '{}' not found", workflow_name))
        })?;

        let input_value = serde_json::to_value(input)?;

        let record = WorkflowRecord::new(
            workflow_name,
            entry.info.version,
            input_value.clone(),
            owner_subject,
        );
        let run_id = record.id;

        let entry_info = entry.info.clone();
        let entry_handler = entry.handler.clone();

        // Persist first so the run is observable (and resumable) even if the
        // process dies before the spawned task makes progress.
        self.save_workflow(&record).await?;

        // Clone the executor's shared handles; the spawned task needs an
        // owned 'static executor rather than borrowing `self`.
        let registry = self.registry.clone();
        let pool = self.pool.clone();
        let http_client = self.http_client.clone();
        let compensation_state = self.compensation_state.clone();

        tokio::spawn(async move {
            // Shallow clone of `self`: all fields are shared handles, so this
            // executor observes the same registry and compensation map.
            let executor = WorkflowExecutor {
                registry,
                pool,
                http_client,
                compensation_state,
            };
            let entry = super::registry::WorkflowEntry {
                info: entry_info,
                handler: entry_handler,
            };
            // Failure here means a persistence error, not a workflow error;
            // workflow-level failures are recorded in the DB by
            // `execute_workflow` itself.
            if let Err(e) = executor.execute_workflow(run_id, &entry, input_value).await {
                tracing::error!(
                    workflow_run_id = %run_id,
                    error = %e,
                    "Workflow execution failed"
                );
            }
        });

        Ok(run_id)
    }
115
    /// Drive one fresh execution of a run: mark it Running, invoke the
    /// handler under the workflow's overall timeout, record compensation
    /// data, and persist the terminal outcome.
    async fn execute_workflow(
        &self,
        run_id: Uuid,
        entry: &super::registry::WorkflowEntry,
        input: serde_json::Value,
    ) -> forge_core::Result<WorkflowResult> {
        self.update_workflow_status(run_id, WorkflowStatus::Running)
            .await?;

        let mut ctx = WorkflowContext::new(
            run_id,
            entry.info.name.to_string(),
            entry.info.version,
            self.pool.clone(),
            self.http_client.clone(),
        );
        ctx.set_http_timeout(entry.info.http_timeout);

        let handler = entry.handler.clone();
        // Outer Result is the overall timeout; inner Result is the handler's.
        let result = tokio::time::timeout(entry.info.timeout, handler(&ctx, input)).await;

        // Capture compensation data unconditionally — even on suspension or
        // failure — so a later `cancel` can undo the steps that did complete.
        // `completed_steps_reversed` yields newest-first; re-reverse so the
        // snapshot stores steps in execution order.
        let compensation_state = CompensationState {
            handlers: ctx.compensation_handlers(),
            completed_steps: ctx.completed_steps_reversed().into_iter().rev().collect(),
        };
        self.compensation_state
            .write()
            .await
            .insert(run_id, compensation_state);

        match result {
            Ok(Ok(output)) => {
                self.complete_workflow(run_id, output.clone()).await?;
                // A completed run can no longer be compensated via `cancel`.
                self.compensation_state.write().await.remove(&run_id);
                Ok(WorkflowResult::Completed(output))
            }
            Ok(Err(e)) => {
                if matches!(e, forge_core::ForgeError::WorkflowSuspended) {
                    // NOTE(review): the DB status stays Running here —
                    // presumably the suspending step records the wait itself
                    // via the context; confirm against WorkflowContext.
                    return Ok(WorkflowResult::Waiting {
                        event_type: "timer".to_string(),
                    });
                }
                self.fail_workflow(run_id, &e.to_string()).await?;
                Ok(WorkflowResult::Failed {
                    error: e.to_string(),
                })
            }
            Err(_) => {
                // The overall timeout elapsed before the handler finished.
                self.fail_workflow(run_id, "Workflow timed out").await?;
                Ok(WorkflowResult::Failed {
                    error: "Workflow timed out".to_string(),
                })
            }
        }
    }
182
    /// Re-run a workflow handler for an existing run, pre-seeding the context
    /// with persisted step results so already-completed steps replay from
    /// cache instead of re-executing.
    async fn execute_workflow_resumed(
        &self,
        run_id: Uuid,
        entry: &super::registry::WorkflowEntry,
        input: serde_json::Value,
        started_at: chrono::DateTime<chrono::Utc>,
        from_sleep: bool,
    ) -> forge_core::Result<WorkflowResult> {
        self.update_workflow_status(run_id, WorkflowStatus::Running)
            .await?;

        // Rebuild the in-memory step map from the persisted step rows.
        let step_records = self.get_workflow_steps(run_id).await?;
        let mut step_states = std::collections::HashMap::new();
        for step in step_records {
            let status = step.status;
            step_states.insert(
                step.step_name.clone(),
                forge_core::workflow::StepState {
                    name: step.step_name,
                    status,
                    result: step.result,
                    error: step.error,
                    started_at: step.started_at,
                    completed_at: step.completed_at,
                },
            );
        }

        // The context keeps the run's original start time; note the overall
        // `timeout` below still restarts from now, not from `started_at`.
        let mut ctx = WorkflowContext::resumed(
            run_id,
            entry.info.name.to_string(),
            entry.info.version,
            started_at,
            self.pool.clone(),
            self.http_client.clone(),
        )
        .with_step_states(step_states);
        ctx.set_http_timeout(entry.info.http_timeout);

        if from_sleep {
            // Flags the context that the pending sleep/timer step has fired.
            ctx = ctx.with_resumed_from_sleep();
        }

        let handler = entry.handler.clone();
        // Outer Result is the overall timeout; inner Result is the handler's.
        let result = tokio::time::timeout(entry.info.timeout, handler(&ctx, input)).await;

        // Refresh the compensation snapshot (handlers + steps in execution
        // order) so `cancel` can undo completed work even after a resume.
        let compensation_state = CompensationState {
            handlers: ctx.compensation_handlers(),
            completed_steps: ctx.completed_steps_reversed().into_iter().rev().collect(),
        };
        self.compensation_state
            .write()
            .await
            .insert(run_id, compensation_state);

        match result {
            Ok(Ok(output)) => {
                self.complete_workflow(run_id, output.clone()).await?;
                // A completed run can no longer be compensated via `cancel`.
                self.compensation_state.write().await.remove(&run_id);
                Ok(WorkflowResult::Completed(output))
            }
            Ok(Err(e)) => {
                if matches!(e, forge_core::ForgeError::WorkflowSuspended) {
                    // The handler suspended again (another wait/sleep).
                    return Ok(WorkflowResult::Waiting {
                        event_type: "timer".to_string(),
                    });
                }
                self.fail_workflow(run_id, &e.to_string()).await?;
                Ok(WorkflowResult::Failed {
                    error: e.to_string(),
                })
            }
            Err(_) => {
                self.fail_workflow(run_id, "Workflow timed out").await?;
                Ok(WorkflowResult::Failed {
                    error: "Workflow timed out".to_string(),
                })
            }
        }
    }
276
    /// Resume a previously-started run (e.g. after an awaited event arrives).
    pub async fn resume(&self, run_id: Uuid) -> forge_core::Result<WorkflowResult> {
        self.resume_internal(run_id, false).await
    }

    /// Resume a run whose handler suspended itself on a durable sleep/timer.
    pub async fn resume_from_sleep(&self, run_id: Uuid) -> forge_core::Result<WorkflowResult> {
        self.resume_internal(run_id, true).await
    }
286
287 async fn resume_internal(
289 &self,
290 run_id: Uuid,
291 from_sleep: bool,
292 ) -> forge_core::Result<WorkflowResult> {
293 let record = self.get_workflow(run_id).await?;
294
295 let entry = self
296 .registry
297 .get_version(&record.workflow_name, record.version)
298 .ok_or_else(|| {
299 forge_core::ForgeError::NotFound(format!(
300 "Workflow '{}' version {} not found",
301 record.workflow_name, record.version
302 ))
303 })?;
304
305 match record.status {
307 WorkflowStatus::Running | WorkflowStatus::Waiting => {
308 }
310 status if status.is_terminal() => {
311 return Err(forge_core::ForgeError::Validation(format!(
312 "Cannot resume workflow in {} state",
313 status.as_str()
314 )));
315 }
316 _ => {}
317 }
318
319 self.execute_workflow_resumed(run_id, entry, record.input, record.started_at, from_sleep)
320 .await
321 }
322
    /// Fetch the current persisted state of a run; `NotFound` if it does not exist.
    pub async fn status(&self, run_id: Uuid) -> forge_core::Result<WorkflowRecord> {
        self.get_workflow(run_id).await
    }
327
328 pub async fn cancel(&self, run_id: Uuid) -> forge_core::Result<()> {
338 self.update_workflow_status(run_id, WorkflowStatus::Compensating)
339 .await?;
340
341 let state = self.compensation_state.write().await.remove(&run_id);
343
344 if let Some(state) = state {
345 let steps = self.get_workflow_steps(run_id).await?;
347
348 for step_name in state.completed_steps.iter().rev() {
353 if let Some(handler) = state.handlers.get(step_name) {
354 let step_result = steps
356 .iter()
357 .find(|s| &s.step_name == step_name)
358 .and_then(|s| s.result.clone())
359 .unwrap_or(serde_json::Value::Null);
360
361 match handler(step_result).await {
363 Ok(()) => {
364 tracing::info!(
365 workflow_run_id = %run_id,
366 step = %step_name,
367 "Compensation completed"
368 );
369 self.update_step_status(run_id, step_name, StepStatus::Compensated)
370 .await?;
371 }
372 Err(e) => {
373 tracing::error!(
374 workflow_run_id = %run_id,
375 step = %step_name,
376 error = %e,
377 "Compensation failed"
378 );
379 }
381 }
382 } else {
383 self.update_step_status(run_id, step_name, StepStatus::Compensated)
385 .await?;
386 }
387 }
388 } else {
389 let msg =
391 "Compensation handlers unavailable (likely restart); refusing to mark compensated";
392 tracing::error!(workflow_run_id = %run_id, "{msg}");
393 self.fail_workflow(run_id, msg).await?;
394 return Err(forge_core::ForgeError::InvalidState(msg.to_string()));
395 }
396
397 self.update_workflow_status(run_id, WorkflowStatus::Compensated)
398 .await?;
399
400 Ok(())
401 }
402
403 async fn get_workflow_steps(
405 &self,
406 workflow_run_id: Uuid,
407 ) -> forge_core::Result<Vec<WorkflowStepRecord>> {
408 let rows = sqlx::query(
409 r#"
410 SELECT id, workflow_run_id, step_name, status, result, error, started_at, completed_at
411 FROM forge_workflow_steps
412 WHERE workflow_run_id = $1
413 ORDER BY started_at ASC
414 "#,
415 )
416 .bind(workflow_run_id)
417 .fetch_all(&self.pool)
418 .await
419 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
420
421 use sqlx::Row;
422 rows.into_iter()
423 .map(|row| {
424 let status_str = row.get::<String, _>("status");
425 let status = status_str.parse().map_err(|e| {
426 forge_core::ForgeError::Database(format!(
427 "Invalid step status '{}': {}",
428 status_str, e
429 ))
430 })?;
431 Ok(WorkflowStepRecord {
432 id: row.get("id"),
433 workflow_run_id: row.get("workflow_run_id"),
434 step_name: row.get("step_name"),
435 status,
436 result: row.get("result"),
437 error: row.get("error"),
438 started_at: row.get("started_at"),
439 completed_at: row.get("completed_at"),
440 })
441 })
442 .collect()
443 }
444
445 async fn update_step_status(
447 &self,
448 workflow_run_id: Uuid,
449 step_name: &str,
450 status: StepStatus,
451 ) -> forge_core::Result<()> {
452 sqlx::query(
453 r#"
454 UPDATE forge_workflow_steps
455 SET status = $3
456 WHERE workflow_run_id = $1 AND step_name = $2
457 "#,
458 )
459 .bind(workflow_run_id)
460 .bind(step_name)
461 .bind(status.as_str())
462 .execute(&self.pool)
463 .await
464 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
465
466 Ok(())
467 }
468
469 async fn save_workflow(&self, record: &WorkflowRecord) -> forge_core::Result<()> {
471 sqlx::query(
472 r#"
473 INSERT INTO forge_workflow_runs (
474 id, workflow_name, version, owner_subject, input, status, current_step,
475 step_results, started_at, trace_id
476 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
477 "#,
478 )
479 .bind(record.id)
480 .bind(&record.workflow_name)
481 .bind(record.version as i32)
482 .bind(&record.owner_subject)
483 .bind(&record.input)
484 .bind(record.status.as_str())
485 .bind(&record.current_step)
486 .bind(&record.step_results)
487 .bind(record.started_at)
488 .bind(&record.trace_id)
489 .execute(&self.pool)
490 .await
491 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
492
493 Ok(())
494 }
495
496 async fn get_workflow(&self, run_id: Uuid) -> forge_core::Result<WorkflowRecord> {
498 let row = sqlx::query(
499 r#"
500 SELECT id, workflow_name, version, owner_subject, input, output, status, current_step,
501 step_results, started_at, completed_at, error, trace_id
502 FROM forge_workflow_runs
503 WHERE id = $1
504 "#,
505 )
506 .bind(run_id)
507 .fetch_optional(&self.pool)
508 .await
509 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
510
511 let row = row.ok_or_else(|| {
512 forge_core::ForgeError::NotFound(format!("Workflow run {} not found", run_id))
513 })?;
514
515 use sqlx::Row;
516 let status_str = row.get::<String, _>("status");
517 let status = status_str.parse().map_err(|e| {
518 forge_core::ForgeError::Database(format!(
519 "Invalid workflow status '{}': {}",
520 status_str, e
521 ))
522 })?;
523 Ok(WorkflowRecord {
524 id: row.get("id"),
525 workflow_name: row.get("workflow_name"),
526 version: row.get::<i32, _>("version") as u32,
527 owner_subject: row.get("owner_subject"),
528 input: row.get("input"),
529 output: row.get("output"),
530 status,
531 current_step: row.get("current_step"),
532 step_results: row.get("step_results"),
533 started_at: row.get("started_at"),
534 completed_at: row.get("completed_at"),
535 error: row.get("error"),
536 trace_id: row.get("trace_id"),
537 })
538 }
539
540 async fn update_workflow_status(
542 &self,
543 run_id: Uuid,
544 status: WorkflowStatus,
545 ) -> forge_core::Result<()> {
546 sqlx::query(
547 r#"
548 UPDATE forge_workflow_runs
549 SET status = $2
550 WHERE id = $1
551 "#,
552 )
553 .bind(run_id)
554 .bind(status.as_str())
555 .execute(&self.pool)
556 .await
557 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
558
559 Ok(())
560 }
561
562 async fn complete_workflow(
564 &self,
565 run_id: Uuid,
566 output: serde_json::Value,
567 ) -> forge_core::Result<()> {
568 sqlx::query(
569 r#"
570 UPDATE forge_workflow_runs
571 SET status = 'completed', output = $2, completed_at = NOW()
572 WHERE id = $1
573 "#,
574 )
575 .bind(run_id)
576 .bind(output)
577 .execute(&self.pool)
578 .await
579 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
580
581 Ok(())
582 }
583
584 async fn fail_workflow(&self, run_id: Uuid, error: &str) -> forge_core::Result<()> {
586 sqlx::query(
587 r#"
588 UPDATE forge_workflow_runs
589 SET status = 'failed', error = $2, completed_at = NOW()
590 WHERE id = $1
591 "#,
592 )
593 .bind(run_id)
594 .bind(error)
595 .execute(&self.pool)
596 .await
597 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
598
599 Ok(())
600 }
601
602 pub async fn save_step(&self, step: &WorkflowStepRecord) -> forge_core::Result<()> {
604 sqlx::query(
605 r#"
606 INSERT INTO forge_workflow_steps (
607 id, workflow_run_id, step_name, status, result, error, started_at, completed_at
608 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
609 ON CONFLICT (workflow_run_id, step_name) DO UPDATE SET
610 status = EXCLUDED.status,
611 result = EXCLUDED.result,
612 error = EXCLUDED.error,
613 started_at = COALESCE(forge_workflow_steps.started_at, EXCLUDED.started_at),
614 completed_at = EXCLUDED.completed_at
615 "#,
616 )
617 .bind(step.id)
618 .bind(step.workflow_run_id)
619 .bind(&step.step_name)
620 .bind(step.status.as_str())
621 .bind(&step.result)
622 .bind(&step.error)
623 .bind(step.started_at)
624 .bind(step.completed_at)
625 .execute(&self.pool)
626 .await
627 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
628
629 Ok(())
630 }
631
    /// Convenience wrapper around [`WorkflowExecutor::start`] for callers
    /// that already hold a JSON input value.
    pub async fn start_by_name(
        &self,
        workflow_name: &str,
        input: serde_json::Value,
        owner_subject: Option<String>,
    ) -> forge_core::Result<Uuid> {
        self.start(workflow_name, input, owner_subject).await
    }
641}
642
643impl WorkflowDispatch for WorkflowExecutor {
644 fn get_info(&self, workflow_name: &str) -> Option<forge_core::workflow::WorkflowInfo> {
645 self.registry.get(workflow_name).map(|e| e.info.clone())
646 }
647
648 fn start_by_name(
649 &self,
650 workflow_name: &str,
651 input: serde_json::Value,
652 owner_subject: Option<String>,
653 ) -> Pin<Box<dyn Future<Output = forge_core::Result<Uuid>> + Send + '_>> {
654 let workflow_name = workflow_name.to_string();
655 Box::pin(async move {
656 self.start_by_name(&workflow_name, input, owner_subject)
657 .await
658 })
659 }
660}
661
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// Smoke test: every `WorkflowResult` variant constructs, and a value
    /// pattern-matches against its own variant.
    #[test]
    fn test_workflow_result_types() {
        let completed = WorkflowResult::Completed(serde_json::json!({}));
        let _waiting = WorkflowResult::Waiting {
            event_type: "approval".to_string(),
        };
        let _failed = WorkflowResult::Failed {
            error: "test".to_string(),
        };
        let _compensated = WorkflowResult::Compensated;

        assert!(matches!(completed, WorkflowResult::Completed(_)));
    }
}
684}