1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::{Arc, RwLock};
5use std::time::Duration;
6
7use chrono::{DateTime, Utc};
8use serde::de::DeserializeOwned;
9use tokio::sync::mpsc;
10use uuid::Uuid;
11
12use super::parallel::ParallelBuilder;
13use super::step::StepStatus;
14use super::suspend::{SuspendReason, WorkflowEvent};
15use crate::env::{EnvAccess, EnvProvider, RealEnvProvider};
16use crate::function::AuthContext;
17use crate::{ForgeError, Result};
18
/// Async compensation callback registered for a completed step.
///
/// Receives the step's recorded result (as JSON) and undoes the step's side
/// effects, returning `Err` if the rollback fails. Wrapped in `Arc` so
/// handlers can be cheaply cloned out of the shared registry.
pub type CompensationHandler = Arc<
    dyn Fn(serde_json::Value) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync,
>;
23
/// In-memory record of a single workflow step's lifecycle within one run.
#[derive(Debug, Clone)]
pub struct StepState {
    /// Step name (used as the key in the per-run step map).
    pub name: String,
    /// Current lifecycle status (`Pending` → `Running` → `Completed`/`Failed`,
    /// or `Compensated` after rollback).
    pub status: StepStatus,
    /// JSON result recorded when the step completes.
    pub result: Option<serde_json::Value>,
    /// Error message recorded when the step fails.
    pub error: Option<String>,
    /// When the step entered `Running`.
    pub started_at: Option<DateTime<Utc>>,
    /// When the step finished (set on both completion and failure).
    pub completed_at: Option<DateTime<Utc>>,
}
40
41impl StepState {
42 pub fn new(name: impl Into<String>) -> Self {
44 Self {
45 name: name.into(),
46 status: StepStatus::Pending,
47 result: None,
48 error: None,
49 started_at: None,
50 completed_at: None,
51 }
52 }
53
54 pub fn start(&mut self) {
56 self.status = StepStatus::Running;
57 self.started_at = Some(Utc::now());
58 }
59
60 pub fn complete(&mut self, result: serde_json::Value) {
62 self.status = StepStatus::Completed;
63 self.result = Some(result);
64 self.completed_at = Some(Utc::now());
65 }
66
67 pub fn fail(&mut self, error: impl Into<String>) {
69 self.status = StepStatus::Failed;
70 self.error = Some(error.into());
71 self.completed_at = Some(Utc::now());
72 }
73
74 pub fn compensate(&mut self) {
76 self.status = StepStatus::Compensated;
77 }
78}
79
/// Execution context handed to a workflow body for one run.
///
/// Tracks run identity, per-step state, compensation handlers, and the
/// handles needed to persist progress and suspend/resume the run.
pub struct WorkflowContext {
    /// Unique id of this workflow run.
    pub run_id: Uuid,
    /// Name of the workflow definition being executed.
    pub workflow_name: String,
    /// Version of the workflow definition.
    pub version: u32,
    /// When the run originally started (preserved across resume).
    pub started_at: DateTime<Utc>,
    // Time snapshot taken at construction; equals `started_at` (the original
    // start for resumed runs). Exposed via `workflow_time()`.
    workflow_time: DateTime<Utc>,
    /// Auth context the run executes under (defaults to unauthenticated).
    pub auth: AuthContext,
    // Postgres pool used for step/run persistence and event consumption.
    db_pool: sqlx::PgPool,
    // Shared HTTP client exposed via `http()`.
    http_client: reqwest::Client,
    // Per-step lifecycle state, keyed by step name.
    step_states: Arc<RwLock<HashMap<String, StepState>>>,
    // Names of completed steps in completion order; reversed for compensation.
    completed_steps: Arc<RwLock<Vec<String>>>,
    // Registered compensation callbacks, keyed by step name.
    compensation_handlers: Arc<RwLock<HashMap<String, CompensationHandler>>>,
    // Channel used to notify the runner when this workflow suspends.
    suspend_tx: Option<mpsc::Sender<SuspendReason>>,
    // True when this context was rebuilt via `resumed()`.
    is_resumed: bool,
    // True when the run resumed after a durable sleep; makes sleep a no-op.
    resumed_from_sleep: bool,
    // Optional tenant id set via `with_tenant`.
    tenant_id: Option<Uuid>,
    // Environment provider; defaults to `RealEnvProvider`, overridable via
    // `with_env_provider`.
    env_provider: Arc<dyn EnvProvider>,
}
115
impl WorkflowContext {
    /// Builds a context for a brand-new run.
    ///
    /// Both `started_at` and the workflow-time snapshot are set to "now";
    /// all tracking maps start empty and no suspend channel is attached.
    pub fn new(
        run_id: Uuid,
        workflow_name: String,
        version: u32,
        db_pool: sqlx::PgPool,
        http_client: reqwest::Client,
    ) -> Self {
        let now = Utc::now();
        Self {
            run_id,
            workflow_name,
            version,
            started_at: now,
            workflow_time: now,
            auth: AuthContext::unauthenticated(),
            db_pool,
            http_client,
            step_states: Arc::new(RwLock::new(HashMap::new())),
            completed_steps: Arc::new(RwLock::new(Vec::new())),
            compensation_handlers: Arc::new(RwLock::new(HashMap::new())),
            suspend_tx: None,
            is_resumed: false,
            resumed_from_sleep: false,
            tenant_id: None,
            env_provider: Arc::new(RealEnvProvider::new()),
        }
    }

    /// Builds a context for a resumed run, preserving the original
    /// `started_at` (also used as the workflow-time snapshot) and marking
    /// `is_resumed`. Step states are restored separately via
    /// `with_step_states`.
    pub fn resumed(
        run_id: Uuid,
        workflow_name: String,
        version: u32,
        started_at: DateTime<Utc>,
        db_pool: sqlx::PgPool,
        http_client: reqwest::Client,
    ) -> Self {
        Self {
            run_id,
            workflow_name,
            version,
            started_at,
            workflow_time: started_at,
            auth: AuthContext::unauthenticated(),
            db_pool,
            http_client,
            step_states: Arc::new(RwLock::new(HashMap::new())),
            completed_steps: Arc::new(RwLock::new(Vec::new())),
            compensation_handlers: Arc::new(RwLock::new(HashMap::new())),
            suspend_tx: None,
            is_resumed: true,
            resumed_from_sleep: false,
            tenant_id: None,
            env_provider: Arc::new(RealEnvProvider::new()),
        }
    }

    /// Overrides the environment provider (defaults to `RealEnvProvider`).
    pub fn with_env_provider(mut self, provider: Arc<dyn EnvProvider>) -> Self {
        self.env_provider = provider;
        self
    }

    /// Marks this context as resumed after a durable sleep, which turns
    /// `sleep`/`sleep_until` into no-ops on replay.
    pub fn with_resumed_from_sleep(mut self) -> Self {
        self.resumed_from_sleep = true;
        self
    }

    /// Attaches the channel used to notify the runner when this workflow
    /// suspends (see `signal_suspend`).
    pub fn with_suspend_channel(mut self, tx: mpsc::Sender<SuspendReason>) -> Self {
        self.suspend_tx = Some(tx);
        self
    }

    /// Scopes this run to a tenant.
    pub fn with_tenant(mut self, tenant_id: Uuid) -> Self {
        self.tenant_id = Some(tenant_id);
        self
    }

    /// Tenant this run is scoped to, if any.
    pub fn tenant_id(&self) -> Option<Uuid> {
        self.tenant_id
    }

    /// True when this context was built via `resumed`.
    pub fn is_resumed(&self) -> bool {
        self.is_resumed
    }

    /// Time snapshot taken at construction (the original start time for
    /// resumed runs).
    pub fn workflow_time(&self) -> DateTime<Utc> {
        self.workflow_time
    }

    /// Database handle backed by this context's pool.
    pub fn db(&self) -> crate::function::ForgeDb {
        crate::function::ForgeDb::from_pool(&self.db_pool)
    }

    /// Acquires a dedicated connection from the pool.
    pub async fn conn(&self) -> sqlx::Result<crate::function::ForgeConn<'static>> {
        Ok(crate::function::ForgeConn::Pool(
            self.db_pool.acquire().await?,
        ))
    }

    /// Shared HTTP client.
    pub fn http(&self) -> &reqwest::Client {
        &self.http_client
    }

    /// Sets the auth context the workflow executes under.
    pub fn with_auth(mut self, auth: AuthContext) -> Self {
        self.auth = auth;
        self
    }

    /// Restores previously persisted step states (used on resume) and
    /// rebuilds the completed-steps list from the `Completed` entries.
    ///
    /// NOTE(review): the completed list is rebuilt in `HashMap` iteration
    /// order, not original completion order, so compensation order after a
    /// resume may differ from the in-flight order — confirm whether callers
    /// depend on it.
    pub fn with_step_states(self, states: HashMap<String, StepState>) -> Self {
        let completed: Vec<String> = states
            .iter()
            .filter(|(_, s)| s.status == StepStatus::Completed)
            .map(|(name, _)| name.clone())
            .collect();

        *self.step_states.write().expect("workflow lock poisoned") = states;
        *self
            .completed_steps
            .write()
            .expect("workflow lock poisoned") = completed;
        self
    }

    /// Snapshot of a single step's state, if one has been recorded.
    pub fn get_step_state(&self, name: &str) -> Option<StepState> {
        self.step_states
            .read()
            .expect("workflow lock poisoned")
            .get(name)
            .cloned()
    }

    /// True when the step has reached `Completed`.
    pub fn is_step_completed(&self, name: &str) -> bool {
        self.step_states
            .read()
            .expect("workflow lock poisoned")
            .get(name)
            .map(|s| s.status == StepStatus::Completed)
            .unwrap_or(false)
    }

    /// True when the step has left `Pending` (any non-pending status).
    pub fn is_step_started(&self, name: &str) -> bool {
        self.step_states
            .read()
            .expect("workflow lock poisoned")
            .get(name)
            .map(|s| s.status != StepStatus::Pending)
            .unwrap_or(false)
    }

    /// Deserializes the step's recorded result into `T`.
    ///
    /// Returns `None` if the step is unknown, has no result, or the result
    /// does not deserialize into `T` (the error is swallowed).
    pub fn get_step_result<T: serde::de::DeserializeOwned>(&self, name: &str) -> Option<T> {
        self.step_states
            .read()
            .expect("workflow lock poisoned")
            .get(name)
            .and_then(|s| s.result.as_ref())
            .and_then(|v| serde_json::from_value(v.clone()).ok())
    }

    /// Marks a step as running in memory and persists the start row in the
    /// background (fire-and-forget `tokio::spawn`).
    ///
    /// Idempotent: if the step is already past `Pending` (e.g. restored on
    /// resume) nothing changes and no row is written. DB errors are logged,
    /// never propagated.
    pub fn record_step_start(&self, name: &str) {
        let mut states = self.step_states.write().expect("workflow lock poisoned");
        let state = states
            .entry(name.to_string())
            .or_insert_with(|| StepState::new(name));

        // Already started (possibly restored from a prior attempt): keep the
        // existing record untouched.
        if state.status != StepStatus::Pending {
            return;
        }

        state.start();
        let state_clone = state.clone();
        // Release the write lock before spawning the persistence task.
        drop(states);

        let pool = self.db_pool.clone();
        let run_id = self.run_id;
        let step_name = name.to_string();
        tokio::spawn(async move {
            let step_id = Uuid::new_v4();
            // ON CONFLICT DO NOTHING keeps replays from duplicating rows.
            if let Err(e) = sqlx::query(
                r#"
            INSERT INTO forge_workflow_steps (id, workflow_run_id, step_name, status, started_at)
            VALUES ($1, $2, $3, $4, $5)
            ON CONFLICT (workflow_run_id, step_name) DO NOTHING
            "#,
            )
            .bind(step_id)
            .bind(run_id)
            .bind(&step_name)
            .bind(state_clone.status.as_str())
            .bind(state_clone.started_at)
            .execute(&pool)
            .await
            {
                tracing::warn!(
                    workflow_run_id = %run_id,
                    step = %step_name,
                    "Failed to persist step start: {}",
                    e
                );
            }
        });
    }

    /// Marks a step completed in memory and persists the completion in the
    /// background. Use `record_step_complete_async` when the write must land
    /// before continuing.
    pub fn record_step_complete(&self, name: &str, result: serde_json::Value) {
        let state_clone = self.update_step_state_complete(name, result);

        if let Some(state) = state_clone {
            let pool = self.db_pool.clone();
            let run_id = self.run_id;
            let step_name = name.to_string();
            tokio::spawn(async move {
                Self::persist_step_complete(&pool, run_id, &step_name, &state).await;
            });
        }
    }

    /// Same as `record_step_complete` but awaits the database write instead
    /// of spawning it.
    pub async fn record_step_complete_async(&self, name: &str, result: serde_json::Value) {
        let state_clone = self.update_step_state_complete(name, result);

        if let Some(state) = state_clone {
            Self::persist_step_complete(&self.db_pool, self.run_id, name, &state).await;
        }
    }

    /// Applies the in-memory completion: updates the step's state (when it
    /// exists) and appends the name to `completed_steps` exactly once.
    /// Returns a snapshot of the updated state for persistence.
    ///
    /// NOTE(review): the name is appended to `completed_steps` even when no
    /// state exists for it (return value `None`) — confirm intended.
    fn update_step_state_complete(
        &self,
        name: &str,
        result: serde_json::Value,
    ) -> Option<StepState> {
        let mut states = self.step_states.write().expect("workflow lock poisoned");
        if let Some(state) = states.get_mut(name) {
            state.complete(result.clone());
        }
        let state_clone = states.get(name).cloned();
        // Drop the states lock before taking the completed-steps lock.
        drop(states);

        let mut completed = self
            .completed_steps
            .write()
            .expect("workflow lock poisoned");
        if !completed.contains(&name.to_string()) {
            completed.push(name.to_string());
        }
        drop(completed);

        state_clone
    }

    /// Upserts the step's completion row (status, result, timestamps).
    /// Errors are logged and swallowed — persistence is best-effort here.
    async fn persist_step_complete(
        pool: &sqlx::PgPool,
        run_id: Uuid,
        step_name: &str,
        state: &StepState,
    ) {
        if let Err(e) = sqlx::query(
            r#"
            INSERT INTO forge_workflow_steps (id, workflow_run_id, step_name, status, result, started_at, completed_at)
            VALUES (gen_random_uuid(), $1, $2, $3, $4, $5, $6)
            ON CONFLICT (workflow_run_id, step_name) DO UPDATE
            SET status = $3, result = $4, completed_at = $6
            "#,
        )
        .bind(run_id)
        .bind(step_name)
        .bind(state.status.as_str())
        .bind(&state.result)
        .bind(state.started_at)
        .bind(state.completed_at)
        .execute(pool)
        .await
        {
            tracing::warn!(
                workflow_run_id = %run_id,
                step = %step_name,
                "Failed to persist step completion: {}",
                e
            );
        }
    }

    /// Marks a step failed in memory (when it exists) and persists the
    /// failure in the background; DB errors are logged and swallowed.
    pub fn record_step_failure(&self, name: &str, error: impl Into<String>) {
        let error_str = error.into();
        let mut states = self.step_states.write().expect("workflow lock poisoned");
        if let Some(state) = states.get_mut(name) {
            state.fail(error_str.clone());
        }
        let state_clone = states.get(name).cloned();
        // Release the lock before spawning the persistence task.
        drop(states);

        if let Some(state) = state_clone {
            let pool = self.db_pool.clone();
            let run_id = self.run_id;
            let step_name = name.to_string();
            tokio::spawn(async move {
                if let Err(e) = sqlx::query(
                    r#"
                    UPDATE forge_workflow_steps
                    SET status = $3, error = $4, completed_at = $5
                    WHERE workflow_run_id = $1 AND step_name = $2
                    "#,
                )
                .bind(run_id)
                .bind(&step_name)
                .bind(state.status.as_str())
                .bind(&state.error)
                .bind(state.completed_at)
                .execute(&pool)
                .await
                {
                    tracing::warn!(
                        workflow_run_id = %run_id,
                        step = %step_name,
                        "Failed to persist step failure: {}",
                        e
                    );
                }
            });
        }
    }

    /// Marks a step compensated in memory (when it exists) and persists the
    /// status change in the background; DB errors are logged and swallowed.
    pub fn record_step_compensated(&self, name: &str) {
        let mut states = self.step_states.write().expect("workflow lock poisoned");
        if let Some(state) = states.get_mut(name) {
            state.compensate();
        }
        let state_clone = states.get(name).cloned();
        // Release the lock before spawning the persistence task.
        drop(states);

        if let Some(state) = state_clone {
            let pool = self.db_pool.clone();
            let run_id = self.run_id;
            let step_name = name.to_string();
            tokio::spawn(async move {
                if let Err(e) = sqlx::query(
                    r#"
                    UPDATE forge_workflow_steps
                    SET status = $3
                    WHERE workflow_run_id = $1 AND step_name = $2
                    "#,
                )
                .bind(run_id)
                .bind(&step_name)
                .bind(state.status.as_str())
                .execute(&pool)
                .await
                {
                    tracing::warn!(
                        workflow_run_id = %run_id,
                        step = %step_name,
                        "Failed to persist step compensation: {}",
                        e
                    );
                }
            });
        }
    }

    /// Completed step names in reverse completion order (the order
    /// compensation should run in).
    pub fn completed_steps_reversed(&self) -> Vec<String> {
        let completed = self.completed_steps.read().expect("workflow lock poisoned");
        completed.iter().rev().cloned().collect()
    }

    /// Snapshot of every step's state, keyed by step name.
    pub fn all_step_states(&self) -> HashMap<String, StepState> {
        self.step_states
            .read()
            .expect("workflow lock poisoned")
            .clone()
    }

    /// Wall-clock time elapsed since the run started.
    pub fn elapsed(&self) -> chrono::Duration {
        Utc::now() - self.started_at
    }

    /// Registers (or replaces) the compensation handler for a step.
    pub fn register_compensation(&self, step_name: &str, handler: CompensationHandler) {
        let mut handlers = self
            .compensation_handlers
            .write()
            .expect("workflow lock poisoned");
        handlers.insert(step_name.to_string(), handler);
    }

    /// Clones out the compensation handler for a step, if registered.
    pub fn get_compensation_handler(&self, step_name: &str) -> Option<CompensationHandler> {
        self.compensation_handlers
            .read()
            .expect("workflow lock poisoned")
            .get(step_name)
            .cloned()
    }

    /// True when a compensation handler is registered for the step.
    pub fn has_compensation(&self, step_name: &str) -> bool {
        self.compensation_handlers
            .read()
            .expect("workflow lock poisoned")
            .contains_key(step_name)
    }

    /// Runs compensation for all completed steps in reverse completion order.
    ///
    /// Steps with a registered handler invoke it with the step's recorded
    /// result (JSON `null` when absent); steps without a handler are simply
    /// marked compensated. Returns `(step_name, success)` pairs; a failing
    /// handler is logged and recorded as `false` but does not stop the
    /// remaining compensations.
    pub async fn run_compensation(&self) -> Vec<(String, bool)> {
        let steps = self.completed_steps_reversed();
        let mut results = Vec::new();

        for step_name in steps {
            let handler = self.get_compensation_handler(&step_name);
            let result = self
                .get_step_state(&step_name)
                .and_then(|s| s.result.clone());

            if let Some(handler) = handler {
                let step_result = result.unwrap_or(serde_json::Value::Null);
                match handler(step_result).await {
                    Ok(()) => {
                        self.record_step_compensated(&step_name);
                        results.push((step_name, true));
                    }
                    Err(e) => {
                        tracing::error!(step = %step_name, error = %e, "Compensation failed");
                        results.push((step_name, false));
                    }
                }
            } else {
                // No handler registered: nothing to undo, mark as compensated.
                self.record_step_compensated(&step_name);
                results.push((step_name, true));
            }
        }

        results
    }

    /// Clones out the full compensation-handler registry.
    pub fn compensation_handlers(&self) -> HashMap<String, CompensationHandler> {
        self.compensation_handlers
            .read()
            .expect("workflow lock poisoned")
            .clone()
    }

    /// Durably sleeps the workflow for `duration`.
    ///
    /// No-op when this context was resumed from a sleep. Otherwise computes
    /// the wake time and delegates to `sleep_until`, which suspends the run
    /// (propagating `Err(ForgeError::WorkflowSuspended)`).
    pub async fn sleep(&self, duration: Duration) -> Result<()> {
        if self.resumed_from_sleep {
            return Ok(());
        }

        // NOTE(review): `from_std` fails for out-of-range durations and then
        // defaults to zero — i.e. an immediate wake. Confirm acceptable.
        let wake_at = Utc::now() + chrono::Duration::from_std(duration).unwrap_or_default();
        self.sleep_until(wake_at).await
    }

    /// Durably sleeps the workflow until `wake_at`.
    ///
    /// No-op when resumed from a sleep or when the wake time has already
    /// passed. Otherwise persists the wake time on the run row, signals the
    /// runner, and returns `Err(ForgeError::WorkflowSuspended)` (via
    /// `signal_suspend`) so the workflow body unwinds.
    pub async fn sleep_until(&self, wake_at: DateTime<Utc>) -> Result<()> {
        if self.resumed_from_sleep {
            return Ok(());
        }

        if wake_at <= Utc::now() {
            return Ok(());
        }

        // Persist the wake time first so the scheduler can resume the run
        // even if the process dies after suspension.
        self.set_wake_at(wake_at).await?;

        self.signal_suspend(SuspendReason::Sleep { wake_at })
            .await?;

        Ok(())
    }

    /// Waits for a named event correlated to this run, deserializing its
    /// payload into `T`.
    ///
    /// If a matching unconsumed event already exists it is consumed and
    /// returned immediately. Otherwise the run is marked as waiting (with an
    /// optional timeout) and suspended; on resume the workflow re-enters
    /// this function and the consume at the top picks up the event.
    pub async fn wait_for_event<T: DeserializeOwned>(
        &self,
        event_name: &str,
        timeout: Option<Duration>,
    ) -> Result<T> {
        // Events are correlated to this specific run by its id.
        let correlation_id = self.run_id.to_string();

        if let Some(event) = self.try_consume_event(event_name, &correlation_id).await? {
            return serde_json::from_value(event.payload.unwrap_or_default())
                .map_err(|e| ForgeError::Deserialization(e.to_string()));
        }

        let timeout_at =
            timeout.map(|d| Utc::now() + chrono::Duration::from_std(d).unwrap_or_default());

        self.set_waiting_for_event(event_name, timeout_at).await?;

        self.signal_suspend(SuspendReason::WaitingEvent {
            event_name: event_name.to_string(),
            timeout: timeout_at,
        })
        .await?;

        // NOTE(review): `signal_suspend` currently always returns `Err`, so
        // this tail only executes if that contract changes.
        self.try_consume_event(event_name, &correlation_id)
            .await?
            .and_then(|e| e.payload)
            .and_then(|p| serde_json::from_value(p).ok())
            .ok_or_else(|| ForgeError::Timeout(format!("Event '{}' timed out", event_name)))
    }

    /// Atomically claims the oldest matching unconsumed event for this run.
    ///
    /// Uses `FOR UPDATE SKIP LOCKED` so concurrent consumers never claim the
    /// same row; the claimed row is stamped with `consumed_at`/`consumed_by`
    /// and returned. `Ok(None)` when no matching event exists.
    #[allow(clippy::type_complexity)]
    async fn try_consume_event(
        &self,
        event_name: &str,
        correlation_id: &str,
    ) -> Result<Option<WorkflowEvent>> {
        let result: Option<(
            Uuid,
            String,
            String,
            Option<serde_json::Value>,
            DateTime<Utc>,
        )> = sqlx::query_as(
            r#"
            UPDATE forge_workflow_events
            SET consumed_at = NOW(), consumed_by = $3
            WHERE id = (
                SELECT id FROM forge_workflow_events
                WHERE event_name = $1 AND correlation_id = $2 AND consumed_at IS NULL
                ORDER BY created_at ASC LIMIT 1
                FOR UPDATE SKIP LOCKED
            )
            RETURNING id, event_name, correlation_id, payload, created_at
            "#,
        )
        .bind(event_name)
        .bind(correlation_id)
        .bind(self.run_id)
        .fetch_optional(&self.db_pool)
        .await
        .map_err(|e| ForgeError::Database(e.to_string()))?;

        Ok(result.map(
            |(id, event_name, correlation_id, payload, created_at)| WorkflowEvent {
                id,
                event_name,
                correlation_id,
                payload,
                created_at,
            },
        ))
    }

    /// Persists the run as 'waiting' with the given wake time.
    async fn set_wake_at(&self, wake_at: DateTime<Utc>) -> Result<()> {
        sqlx::query(
            r#"
            UPDATE forge_workflow_runs
            SET status = 'waiting', suspended_at = NOW(), wake_at = $2
            WHERE id = $1
            "#,
        )
        .bind(self.run_id)
        .bind(wake_at)
        .execute(&self.db_pool)
        .await
        .map_err(|e| ForgeError::Database(e.to_string()))?;
        Ok(())
    }

    /// Persists the run as 'waiting' for a named event, with an optional
    /// timeout deadline.
    async fn set_waiting_for_event(
        &self,
        event_name: &str,
        timeout_at: Option<DateTime<Utc>>,
    ) -> Result<()> {
        sqlx::query(
            r#"
            UPDATE forge_workflow_runs
            SET status = 'waiting', suspended_at = NOW(), waiting_for_event = $2, event_timeout_at = $3
            WHERE id = $1
            "#,
        )
        .bind(self.run_id)
        .bind(event_name)
        .bind(timeout_at)
        .execute(&self.db_pool)
        .await
        .map_err(|e| ForgeError::Database(e.to_string()))?;
        Ok(())
    }

    /// Notifies the runner (when a suspend channel is attached) and then
    /// unconditionally returns `Err(ForgeError::WorkflowSuspended)` so the
    /// workflow body unwinds via `?`.
    async fn signal_suspend(&self, reason: SuspendReason) -> Result<()> {
        if let Some(ref tx) = self.suspend_tx {
            tx.send(reason)
                .await
                .map_err(|_| ForgeError::Internal("Failed to signal suspension".into()))?;
        }
        Err(ForgeError::WorkflowSuspended)
    }

    /// Starts a builder for running multiple steps in parallel.
    pub fn parallel(&self) -> ParallelBuilder<'_> {
        ParallelBuilder::new(self)
    }

    /// Starts a builder for a single named step whose closure produces the
    /// step's (serializable) result.
    pub fn step<T, F, Fut>(&self, name: impl Into<String>, f: F) -> super::StepRunner<'_, T>
    where
        T: serde::Serialize + serde::de::DeserializeOwned + Clone + Send + Sync + 'static,
        F: Fn() -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = crate::Result<T>> + Send + 'static,
    {
        super::StepRunner::new(self, name, f)
    }
}
854
impl EnvAccess for WorkflowContext {
    /// Exposes the configured environment provider to the `EnvAccess` trait.
    fn env_provider(&self) -> &dyn EnvProvider {
        self.env_provider.as_ref()
    }
}
860
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    // Lazily-initialized pool: no connection is attempted until first use,
    // so tests can run without a live database.
    fn mock_pool() -> sqlx::PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool")
    }

    #[tokio::test]
    async fn test_workflow_context_creation() {
        let id = Uuid::new_v4();
        let ctx = WorkflowContext::new(
            id,
            "test_workflow".to_string(),
            1,
            mock_pool(),
            reqwest::Client::new(),
        );

        assert_eq!(ctx.run_id, id);
        assert_eq!(ctx.workflow_name, "test_workflow");
        assert_eq!(ctx.version, 1);
    }

    #[tokio::test]
    async fn test_step_state_tracking() {
        let ctx = WorkflowContext::new(
            Uuid::new_v4(),
            "test".to_string(),
            1,
            mock_pool(),
            reqwest::Client::new(),
        );

        // Starting a step does not complete it.
        ctx.record_step_start("step1");
        assert!(!ctx.is_step_completed("step1"));

        // Completing records the result and flips the completion flag.
        ctx.record_step_complete("step1", serde_json::json!({"result": "ok"}));
        assert!(ctx.is_step_completed("step1"));
        assert!(ctx.get_step_result::<serde_json::Value>("step1").is_some());
    }

    #[test]
    fn test_step_state_transitions() {
        let mut state = StepState::new("test");
        assert_eq!(state.status, StepStatus::Pending);

        state.start();
        assert_eq!(state.status, StepStatus::Running);
        assert!(state.started_at.is_some());

        state.complete(serde_json::json!({}));
        assert_eq!(state.status, StepStatus::Completed);
        assert!(state.completed_at.is_some());
    }
}