1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use tokio::sync::RwLock;
7use uuid::Uuid;
8
9use super::registry::WorkflowRegistry;
10use super::state::{WorkflowRecord, WorkflowStepRecord};
11use forge_core::CircuitBreakerClient;
12use forge_core::function::WorkflowDispatch;
13use forge_core::workflow::{CompensationHandler, StepStatus, WorkflowContext, WorkflowStatus};
14
/// Outcome reported by the executor after driving a workflow run.
#[derive(Debug)]
pub enum WorkflowResult {
    /// The handler returned successfully; carries the JSON output it produced.
    Completed(serde_json::Value),
    /// The run is suspended. `event_type` names what it is waiting for; the
    /// executor code in this file always reports `"timer"`.
    Waiting { event_type: String },
    /// The handler returned an error or exceeded its timeout; `error` is the message.
    Failed { error: String },
    /// The run was compensated. Within this file only the unit test constructs
    /// this variant.
    Compensated,
}
27
/// Compensation data captured from a workflow context after its handler has run,
/// kept in memory so that `WorkflowExecutor::cancel` can undo completed steps.
struct CompensationState {
    /// Compensation handlers, looked up by step name.
    handlers: HashMap<String, CompensationHandler>,
    /// Step names stored as the reverse of `WorkflowContext::completed_steps_reversed()`;
    /// `cancel` walks this list back to front.
    completed_steps: Vec<String>,
}
33
/// Starts, resumes and cancels workflow runs, persisting their state through `pool`.
pub struct WorkflowExecutor {
    /// Registered workflow definitions, looked up by workflow name.
    registry: Arc<WorkflowRegistry>,
    /// Connection pool used for the `forge_workflow_runs` and `forge_workflow_steps` tables.
    pool: sqlx::PgPool,
    /// HTTP client; its inner client is cloned into every workflow context.
    http_client: CircuitBreakerClient,
    /// Compensation data per run id. This file only ever keeps it in this in-memory map.
    compensation_state: Arc<RwLock<HashMap<Uuid, CompensationState>>>,
}
42
43impl WorkflowExecutor {
44 pub fn new(
46 registry: Arc<WorkflowRegistry>,
47 pool: sqlx::PgPool,
48 http_client: CircuitBreakerClient,
49 ) -> Self {
50 Self {
51 registry,
52 pool,
53 http_client,
54 compensation_state: Arc::new(RwLock::new(HashMap::new())),
55 }
56 }
57
58 pub async fn start<I: serde::Serialize>(
61 &self,
62 workflow_name: &str,
63 input: I,
64 ) -> forge_core::Result<Uuid> {
65 let entry = self.registry.get(workflow_name).ok_or_else(|| {
66 forge_core::ForgeError::NotFound(format!("Workflow '{}' not found", workflow_name))
67 })?;
68
69 let input_value = serde_json::to_value(input)?;
70
71 let record = WorkflowRecord::new(workflow_name, entry.info.version, input_value.clone());
72 let run_id = record.id;
73
74 let entry_info = entry.info.clone();
76 let entry_handler = entry.handler.clone();
77
78 self.save_workflow(&record).await?;
80
81 let registry = self.registry.clone();
83 let pool = self.pool.clone();
84 let http_client = self.http_client.clone();
85 let compensation_state = self.compensation_state.clone();
86
87 tokio::spawn(async move {
88 let executor = WorkflowExecutor {
89 registry,
90 pool,
91 http_client,
92 compensation_state,
93 };
94 let entry = super::registry::WorkflowEntry {
95 info: entry_info,
96 handler: entry_handler,
97 };
98 if let Err(e) = executor.execute_workflow(run_id, &entry, input_value).await {
99 tracing::error!(
100 workflow_run_id = %run_id,
101 error = %e,
102 "Workflow execution failed"
103 );
104 }
105 });
106
107 Ok(run_id)
108 }
109
110 async fn execute_workflow(
112 &self,
113 run_id: Uuid,
114 entry: &super::registry::WorkflowEntry,
115 input: serde_json::Value,
116 ) -> forge_core::Result<WorkflowResult> {
117 self.update_workflow_status(run_id, WorkflowStatus::Running)
119 .await?;
120
121 let ctx = WorkflowContext::new(
123 run_id,
124 entry.info.name.to_string(),
125 entry.info.version,
126 self.pool.clone(),
127 self.http_client.inner().clone(),
128 );
129
130 let handler = entry.handler.clone();
132 let result = tokio::time::timeout(entry.info.timeout, handler(&ctx, input)).await;
133
134 let compensation_state = CompensationState {
136 handlers: ctx.compensation_handlers(),
137 completed_steps: ctx.completed_steps_reversed().into_iter().rev().collect(),
138 };
139 self.compensation_state
140 .write()
141 .await
142 .insert(run_id, compensation_state);
143
144 match result {
145 Ok(Ok(output)) => {
146 self.complete_workflow(run_id, output.clone()).await?;
148 self.compensation_state.write().await.remove(&run_id);
149 Ok(WorkflowResult::Completed(output))
150 }
151 Ok(Err(e)) => {
152 if matches!(e, forge_core::ForgeError::WorkflowSuspended) {
154 return Ok(WorkflowResult::Waiting {
157 event_type: "timer".to_string(),
158 });
159 }
160 self.fail_workflow(run_id, &e.to_string()).await?;
162 Ok(WorkflowResult::Failed {
163 error: e.to_string(),
164 })
165 }
166 Err(_) => {
167 self.fail_workflow(run_id, "Workflow timed out").await?;
169 Ok(WorkflowResult::Failed {
170 error: "Workflow timed out".to_string(),
171 })
172 }
173 }
174 }
175
176 async fn execute_workflow_resumed(
178 &self,
179 run_id: Uuid,
180 entry: &super::registry::WorkflowEntry,
181 input: serde_json::Value,
182 started_at: chrono::DateTime<chrono::Utc>,
183 from_sleep: bool,
184 ) -> forge_core::Result<WorkflowResult> {
185 self.update_workflow_status(run_id, WorkflowStatus::Running)
187 .await?;
188
189 let step_records = self.get_workflow_steps(run_id).await?;
191 let mut step_states = std::collections::HashMap::new();
192 for step in step_records {
193 let status = step.status;
194 step_states.insert(
195 step.step_name.clone(),
196 forge_core::workflow::StepState {
197 name: step.step_name,
198 status,
199 result: step.result,
200 error: step.error,
201 started_at: step.started_at,
202 completed_at: step.completed_at,
203 },
204 );
205 }
206
207 let mut ctx = WorkflowContext::resumed(
209 run_id,
210 entry.info.name.to_string(),
211 entry.info.version,
212 started_at,
213 self.pool.clone(),
214 self.http_client.inner().clone(),
215 )
216 .with_step_states(step_states);
217
218 if from_sleep {
220 ctx = ctx.with_resumed_from_sleep();
221 }
222
223 let handler = entry.handler.clone();
225 let result = tokio::time::timeout(entry.info.timeout, handler(&ctx, input)).await;
226
227 let compensation_state = CompensationState {
229 handlers: ctx.compensation_handlers(),
230 completed_steps: ctx.completed_steps_reversed().into_iter().rev().collect(),
231 };
232 self.compensation_state
233 .write()
234 .await
235 .insert(run_id, compensation_state);
236
237 match result {
238 Ok(Ok(output)) => {
239 self.complete_workflow(run_id, output.clone()).await?;
241 self.compensation_state.write().await.remove(&run_id);
242 Ok(WorkflowResult::Completed(output))
243 }
244 Ok(Err(e)) => {
245 if matches!(e, forge_core::ForgeError::WorkflowSuspended) {
247 return Ok(WorkflowResult::Waiting {
250 event_type: "timer".to_string(),
251 });
252 }
253 self.fail_workflow(run_id, &e.to_string()).await?;
255 Ok(WorkflowResult::Failed {
256 error: e.to_string(),
257 })
258 }
259 Err(_) => {
260 self.fail_workflow(run_id, "Workflow timed out").await?;
262 Ok(WorkflowResult::Failed {
263 error: "Workflow timed out".to_string(),
264 })
265 }
266 }
267 }
268
/// Resumes the run `run_id` without the resumed-from-sleep flag.
///
/// Delegates to `resume_internal(run_id, false)`; see there for the error cases.
pub async fn resume(&self, run_id: Uuid) -> forge_core::Result<WorkflowResult> {
    self.resume_internal(run_id, false).await
}
273
/// Resumes the run `run_id` with the resumed-from-sleep flag set.
///
/// Delegates to `resume_internal(run_id, true)`; see there for the error cases.
pub async fn resume_from_sleep(&self, run_id: Uuid) -> forge_core::Result<WorkflowResult> {
    self.resume_internal(run_id, true).await
}
278
279 async fn resume_internal(
281 &self,
282 run_id: Uuid,
283 from_sleep: bool,
284 ) -> forge_core::Result<WorkflowResult> {
285 let record = self.get_workflow(run_id).await?;
286
287 let entry = self.registry.get(&record.workflow_name).ok_or_else(|| {
288 forge_core::ForgeError::NotFound(format!(
289 "Workflow '{}' not found",
290 record.workflow_name
291 ))
292 })?;
293
294 match record.status {
296 WorkflowStatus::Running | WorkflowStatus::Waiting => {
297 }
299 status if status.is_terminal() => {
300 return Err(forge_core::ForgeError::Validation(format!(
301 "Cannot resume workflow in {} state",
302 status.as_str()
303 )));
304 }
305 _ => {}
306 }
307
308 self.execute_workflow_resumed(run_id, entry, record.input, record.started_at, from_sleep)
309 .await
310 }
311
/// Returns the stored record for the run `run_id`.
///
/// Delegates to `get_workflow`, which reports a missing run as `ForgeError::NotFound`.
pub async fn status(&self, run_id: Uuid) -> forge_core::Result<WorkflowRecord> {
    self.get_workflow(run_id).await
}
316
317 pub async fn cancel(&self, run_id: Uuid) -> forge_core::Result<()> {
319 self.update_workflow_status(run_id, WorkflowStatus::Compensating)
320 .await?;
321
322 let state = self.compensation_state.write().await.remove(&run_id);
324
325 if let Some(state) = state {
326 let steps = self.get_workflow_steps(run_id).await?;
328
329 for step_name in state.completed_steps.iter().rev() {
331 if let Some(handler) = state.handlers.get(step_name) {
332 let step_result = steps
334 .iter()
335 .find(|s| &s.step_name == step_name)
336 .and_then(|s| s.result.clone())
337 .unwrap_or(serde_json::Value::Null);
338
339 match handler(step_result).await {
341 Ok(()) => {
342 tracing::info!(
343 workflow_run_id = %run_id,
344 step = %step_name,
345 "Compensation completed"
346 );
347 self.update_step_status(run_id, step_name, StepStatus::Compensated)
348 .await?;
349 }
350 Err(e) => {
351 tracing::error!(
352 workflow_run_id = %run_id,
353 step = %step_name,
354 error = %e,
355 "Compensation failed"
356 );
357 }
359 }
360 } else {
361 self.update_step_status(run_id, step_name, StepStatus::Compensated)
363 .await?;
364 }
365 }
366 } else {
367 tracing::warn!(
370 workflow_run_id = %run_id,
371 "No compensation state found, marking as compensated without handlers"
372 );
373 }
374
375 self.update_workflow_status(run_id, WorkflowStatus::Compensated)
376 .await?;
377
378 Ok(())
379 }
380
381 async fn get_workflow_steps(
383 &self,
384 workflow_run_id: Uuid,
385 ) -> forge_core::Result<Vec<WorkflowStepRecord>> {
386 let rows = sqlx::query(
387 r#"
388 SELECT id, workflow_run_id, step_name, status, result, error, started_at, completed_at
389 FROM forge_workflow_steps
390 WHERE workflow_run_id = $1
391 ORDER BY started_at ASC
392 "#,
393 )
394 .bind(workflow_run_id)
395 .fetch_all(&self.pool)
396 .await
397 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
398
399 use sqlx::Row;
400 Ok(rows
401 .into_iter()
402 .map(|row| WorkflowStepRecord {
403 id: row.get("id"),
404 workflow_run_id: row.get("workflow_run_id"),
405 step_name: row.get("step_name"),
406 status: row.get::<String, _>("status").parse().unwrap(),
407 result: row.get("result"),
408 error: row.get("error"),
409 started_at: row.get("started_at"),
410 completed_at: row.get("completed_at"),
411 })
412 .collect())
413 }
414
415 async fn update_step_status(
417 &self,
418 workflow_run_id: Uuid,
419 step_name: &str,
420 status: StepStatus,
421 ) -> forge_core::Result<()> {
422 sqlx::query(
423 r#"
424 UPDATE forge_workflow_steps
425 SET status = $3
426 WHERE workflow_run_id = $1 AND step_name = $2
427 "#,
428 )
429 .bind(workflow_run_id)
430 .bind(step_name)
431 .bind(status.as_str())
432 .execute(&self.pool)
433 .await
434 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
435
436 Ok(())
437 }
438
439 async fn save_workflow(&self, record: &WorkflowRecord) -> forge_core::Result<()> {
441 sqlx::query(
442 r#"
443 INSERT INTO forge_workflow_runs (
444 id, workflow_name, input, status, current_step,
445 step_results, started_at, trace_id
446 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
447 "#,
448 )
449 .bind(record.id)
450 .bind(&record.workflow_name)
451 .bind(&record.input)
452 .bind(record.status.as_str())
453 .bind(&record.current_step)
454 .bind(&record.step_results)
455 .bind(record.started_at)
456 .bind(&record.trace_id)
457 .execute(&self.pool)
458 .await
459 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
460
461 Ok(())
462 }
463
464 async fn get_workflow(&self, run_id: Uuid) -> forge_core::Result<WorkflowRecord> {
466 let row = sqlx::query(
467 r#"
468 SELECT id, workflow_name, input, output, status, current_step,
469 step_results, started_at, completed_at, error, trace_id
470 FROM forge_workflow_runs
471 WHERE id = $1
472 "#,
473 )
474 .bind(run_id)
475 .fetch_optional(&self.pool)
476 .await
477 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
478
479 let row = row.ok_or_else(|| {
480 forge_core::ForgeError::NotFound(format!("Workflow run {} not found", run_id))
481 })?;
482
483 use sqlx::Row;
484 Ok(WorkflowRecord {
485 id: row.get("id"),
486 workflow_name: row.get("workflow_name"),
487 version: 1, input: row.get("input"),
489 output: row.get("output"),
490 status: row.get::<String, _>("status").parse().unwrap(),
491 current_step: row.get("current_step"),
492 step_results: row.get("step_results"),
493 started_at: row.get("started_at"),
494 completed_at: row.get("completed_at"),
495 error: row.get("error"),
496 trace_id: row.get("trace_id"),
497 })
498 }
499
500 async fn update_workflow_status(
502 &self,
503 run_id: Uuid,
504 status: WorkflowStatus,
505 ) -> forge_core::Result<()> {
506 sqlx::query(
507 r#"
508 UPDATE forge_workflow_runs
509 SET status = $2
510 WHERE id = $1
511 "#,
512 )
513 .bind(run_id)
514 .bind(status.as_str())
515 .execute(&self.pool)
516 .await
517 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
518
519 Ok(())
520 }
521
522 async fn complete_workflow(
524 &self,
525 run_id: Uuid,
526 output: serde_json::Value,
527 ) -> forge_core::Result<()> {
528 sqlx::query(
529 r#"
530 UPDATE forge_workflow_runs
531 SET status = 'completed', output = $2, completed_at = NOW()
532 WHERE id = $1
533 "#,
534 )
535 .bind(run_id)
536 .bind(output)
537 .execute(&self.pool)
538 .await
539 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
540
541 Ok(())
542 }
543
544 async fn fail_workflow(&self, run_id: Uuid, error: &str) -> forge_core::Result<()> {
546 sqlx::query(
547 r#"
548 UPDATE forge_workflow_runs
549 SET status = 'failed', error = $2, completed_at = NOW()
550 WHERE id = $1
551 "#,
552 )
553 .bind(run_id)
554 .bind(error)
555 .execute(&self.pool)
556 .await
557 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
558
559 Ok(())
560 }
561
562 pub async fn save_step(&self, step: &WorkflowStepRecord) -> forge_core::Result<()> {
564 sqlx::query(
565 r#"
566 INSERT INTO forge_workflow_steps (
567 id, workflow_run_id, step_name, status, result, error, started_at, completed_at
568 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
569 ON CONFLICT (workflow_run_id, step_name) DO UPDATE SET
570 status = EXCLUDED.status,
571 result = EXCLUDED.result,
572 error = EXCLUDED.error,
573 started_at = COALESCE(forge_workflow_steps.started_at, EXCLUDED.started_at),
574 completed_at = EXCLUDED.completed_at
575 "#,
576 )
577 .bind(step.id)
578 .bind(step.workflow_run_id)
579 .bind(&step.step_name)
580 .bind(step.status.as_str())
581 .bind(&step.result)
582 .bind(&step.error)
583 .bind(step.started_at)
584 .bind(step.completed_at)
585 .execute(&self.pool)
586 .await
587 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
588
589 Ok(())
590 }
591
/// Starts a run of `workflow_name` with an input that is already a JSON value.
///
/// Thin wrapper over `start`; returns the id of the new run and fails in the same
/// cases `start` does.
pub async fn start_by_name(
    &self,
    workflow_name: &str,
    input: serde_json::Value,
) -> forge_core::Result<Uuid> {
    self.start(workflow_name, input).await
}
600}
601
602impl WorkflowDispatch for WorkflowExecutor {
603 fn get_info(&self, workflow_name: &str) -> Option<forge_core::workflow::WorkflowInfo> {
604 self.registry.get(workflow_name).map(|e| e.info.clone())
605 }
606
607 fn start_by_name(
608 &self,
609 workflow_name: &str,
610 input: serde_json::Value,
611 ) -> Pin<Box<dyn Future<Output = forge_core::Result<Uuid>> + Send + '_>> {
612 let workflow_name = workflow_name.to_string();
613 Box::pin(async move { self.start_by_name(&workflow_name, input).await })
614 }
615}
616
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `WorkflowResult` variant can be constructed, and a `Completed` value
    /// matches the `Completed` pattern.
    #[test]
    fn test_workflow_result_types() {
        let completed = WorkflowResult::Completed(serde_json::json!({}));
        let _waiting = WorkflowResult::Waiting {
            event_type: "approval".to_string(),
        };
        let _failed = WorkflowResult::Failed {
            error: "test".to_string(),
        };
        let _compensated = WorkflowResult::Compensated;

        assert!(
            matches!(completed, WorkflowResult::Completed(_)),
            "Expected Completed"
        );
    }
}