1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::{Arc, RwLock};
5use std::time::Duration;
6
7use chrono::{DateTime, Utc};
8use serde::de::DeserializeOwned;
9use tokio::sync::mpsc;
10use uuid::Uuid;
11
12use super::parallel::ParallelBuilder;
13use super::step::StepStatus;
14use super::suspend::{SuspendReason, WorkflowEvent};
15use crate::env::{EnvAccess, EnvProvider, RealEnvProvider};
16use crate::function::AuthContext;
17use crate::{ForgeError, Result};
18
/// Async rollback callback registered per step.
///
/// Receives the step's recorded JSON result and undoes that step's side
/// effects; stored behind `Arc` so handlers can be cloned out of the shared
/// map and invoked after the registration lock is released.
pub type CompensationHandler = Arc<
    dyn Fn(serde_json::Value) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync,
>;
23
/// In-memory lifecycle record for a single workflow step.
#[derive(Debug, Clone)]
pub struct StepState {
    /// Step name, unique within one workflow run.
    pub name: String,
    /// Current lifecycle status (Pending → Running → Completed/Failed,
    /// optionally Compensated afterwards — see the transition methods below).
    pub status: StepStatus,
    /// JSON result recorded on completion; `None` until then.
    pub result: Option<serde_json::Value>,
    /// Error message recorded on failure; `None` otherwise.
    pub error: Option<String>,
    /// Set by `start()`; `None` while still Pending.
    pub started_at: Option<DateTime<Utc>>,
    /// Set by `complete()` or `fail()`; `None` while running or pending.
    pub completed_at: Option<DateTime<Utc>>,
}
40
41impl StepState {
42 pub fn new(name: impl Into<String>) -> Self {
44 Self {
45 name: name.into(),
46 status: StepStatus::Pending,
47 result: None,
48 error: None,
49 started_at: None,
50 completed_at: None,
51 }
52 }
53
54 pub fn start(&mut self) {
56 self.status = StepStatus::Running;
57 self.started_at = Some(Utc::now());
58 }
59
60 pub fn complete(&mut self, result: serde_json::Value) {
62 self.status = StepStatus::Completed;
63 self.result = Some(result);
64 self.completed_at = Some(Utc::now());
65 }
66
67 pub fn fail(&mut self, error: impl Into<String>) {
69 self.status = StepStatus::Failed;
70 self.error = Some(error.into());
71 self.completed_at = Some(Utc::now());
72 }
73
74 pub fn compensate(&mut self) {
76 self.status = StepStatus::Compensated;
77 }
78}
79
/// Execution context threaded through a single workflow run.
///
/// Tracks per-step state in memory (mirrored best-effort to Postgres),
/// holds registered compensation handlers for rollback, and drives
/// suspension (sleep / wait-for-event) via an optional channel back to the
/// engine.
pub struct WorkflowContext {
    /// Unique id of this workflow run.
    pub run_id: Uuid,
    /// Registered workflow name.
    pub workflow_name: String,
    /// Workflow definition version.
    pub version: u32,
    /// When the run first started (preserved across resume).
    pub started_at: DateTime<Utc>,
    // Deterministic "workflow time"; seeded from started_at on resume and
    // from now() on fresh runs. Not advanced anywhere in this file.
    workflow_time: DateTime<Utc>,
    /// Auth context of the caller; unauthenticated by default.
    pub auth: AuthContext,
    // Postgres pool used for all persistence in this module.
    db_pool: sqlx::PgPool,
    // Shared HTTP client exposed via `http()`.
    http_client: reqwest::Client,
    // Step name -> lifecycle state. Shared + locked because step recording
    // can be reached from cloned handles/closures.
    step_states: Arc<RwLock<HashMap<String, StepState>>>,
    // Names of completed steps in completion order; read in reverse for
    // compensation.
    completed_steps: Arc<RwLock<Vec<String>>>,
    // Step name -> rollback handler registered by the workflow body.
    compensation_handlers: Arc<RwLock<HashMap<String, CompensationHandler>>>,
    // When present, sleep/wait-for-event notify the engine through here
    // before the workflow unwinds with `WorkflowSuspended`.
    suspend_tx: Option<mpsc::Sender<SuspendReason>>,
    // True when constructed via `resumed()`.
    is_resumed: bool,
    // True when this resume woke up from a completed sleep; makes the next
    // sleep call a no-op.
    resumed_from_sleep: bool,
    // Optional tenant scoping for multi-tenant deployments.
    tenant_id: Option<Uuid>,
    // Environment variable provider (real by default, swappable for tests).
    env_provider: Arc<dyn EnvProvider>,
}
115
impl WorkflowContext {
    /// Creates a context for a fresh (first-attempt) run.
    ///
    /// `started_at` and `workflow_time` are both seeded from wall-clock now;
    /// all step/compensation maps start empty and the run is unauthenticated
    /// until `with_auth` is applied.
    pub fn new(
        run_id: Uuid,
        workflow_name: String,
        version: u32,
        db_pool: sqlx::PgPool,
        http_client: reqwest::Client,
    ) -> Self {
        let now = Utc::now();
        Self {
            run_id,
            workflow_name,
            version,
            started_at: now,
            workflow_time: now,
            auth: AuthContext::unauthenticated(),
            db_pool,
            http_client,
            step_states: Arc::new(RwLock::new(HashMap::new())),
            completed_steps: Arc::new(RwLock::new(Vec::new())),
            compensation_handlers: Arc::new(RwLock::new(HashMap::new())),
            suspend_tx: None,
            is_resumed: false,
            resumed_from_sleep: false,
            tenant_id: None,
            env_provider: Arc::new(RealEnvProvider::new()),
        }
    }

    /// Creates a context for a resumed run: the original `started_at` is
    /// preserved (and also seeds `workflow_time`), and `is_resumed` is set.
    /// Step states are expected to be re-hydrated via `with_step_states`.
    pub fn resumed(
        run_id: Uuid,
        workflow_name: String,
        version: u32,
        started_at: DateTime<Utc>,
        db_pool: sqlx::PgPool,
        http_client: reqwest::Client,
    ) -> Self {
        Self {
            run_id,
            workflow_name,
            version,
            started_at,
            workflow_time: started_at,
            auth: AuthContext::unauthenticated(),
            db_pool,
            http_client,
            step_states: Arc::new(RwLock::new(HashMap::new())),
            completed_steps: Arc::new(RwLock::new(Vec::new())),
            compensation_handlers: Arc::new(RwLock::new(HashMap::new())),
            suspend_tx: None,
            is_resumed: true,
            resumed_from_sleep: false,
            tenant_id: None,
            env_provider: Arc::new(RealEnvProvider::new()),
        }
    }

    /// Builder: replaces the default real environment provider (e.g. with a
    /// mock in tests).
    pub fn with_env_provider(mut self, provider: Arc<dyn EnvProvider>) -> Self {
        self.env_provider = provider;
        self
    }

    /// Builder: marks this resume as waking from a completed sleep, so the
    /// next `sleep`/`sleep_until` call returns immediately.
    pub fn with_resumed_from_sleep(mut self) -> Self {
        self.resumed_from_sleep = true;
        self
    }

    /// Builder: installs the channel used to notify the engine when the
    /// workflow suspends (sleep or wait-for-event).
    pub fn with_suspend_channel(mut self, tx: mpsc::Sender<SuspendReason>) -> Self {
        self.suspend_tx = Some(tx);
        self
    }

    /// Builder: scopes this run to a tenant.
    pub fn with_tenant(mut self, tenant_id: Uuid) -> Self {
        self.tenant_id = Some(tenant_id);
        self
    }

    /// Tenant this run is scoped to, if any.
    pub fn tenant_id(&self) -> Option<Uuid> {
        self.tenant_id
    }

    /// True when this context was built via `resumed()`.
    pub fn is_resumed(&self) -> bool {
        self.is_resumed
    }

    /// Deterministic workflow time (run start on resume; construction time
    /// on fresh runs).
    pub fn workflow_time(&self) -> DateTime<Utc> {
        self.workflow_time
    }

    /// Borrow of the underlying Postgres pool.
    pub fn db(&self) -> &sqlx::PgPool {
        &self.db_pool
    }

    /// Pool wrapped in the project's `DbConn` abstraction.
    pub fn db_conn(&self) -> crate::function::DbConn<'_> {
        crate::function::DbConn::Pool(&self.db_pool)
    }

    /// Acquires a dedicated connection from the pool, wrapped as `ForgeConn`.
    ///
    /// # Errors
    /// Returns the sqlx error if the pool cannot hand out a connection.
    pub async fn conn(&self) -> sqlx::Result<crate::function::ForgeConn<'static>> {
        Ok(crate::function::ForgeConn::Pool(
            self.db_pool.acquire().await?,
        ))
    }

    /// Shared HTTP client for outbound calls from workflow steps.
    pub fn http(&self) -> &reqwest::Client {
        &self.http_client
    }

    /// Builder: attaches the caller's auth context.
    pub fn with_auth(mut self, auth: AuthContext) -> Self {
        self.auth = auth;
        self
    }

    /// Seeds the context with previously persisted step states (used on
    /// resume) and rebuilds the completed-steps list from every state whose
    /// status is `Completed`.
    ///
    /// NOTE(review): the rebuilt list follows `HashMap` iteration order, not
    /// original completion order — `completed_steps_reversed()` (and thus
    /// compensation ordering after a resume) may therefore be
    /// nondeterministic. Confirm whether persisted ordering should be
    /// restored here.
    pub fn with_step_states(self, states: HashMap<String, StepState>) -> Self {
        let completed: Vec<String> = states
            .iter()
            .filter(|(_, s)| s.status == StepStatus::Completed)
            .map(|(name, _)| name.clone())
            .collect();

        *self.step_states.write().expect("workflow lock poisoned") = states;
        *self
            .completed_steps
            .write()
            .expect("workflow lock poisoned") = completed;
        self
    }

    /// Snapshot of one step's state, if recorded.
    pub fn get_step_state(&self, name: &str) -> Option<StepState> {
        self.step_states
            .read()
            .expect("workflow lock poisoned")
            .get(name)
            .cloned()
    }

    /// True when the named step has status `Completed`.
    pub fn is_step_completed(&self, name: &str) -> bool {
        self.step_states
            .read()
            .expect("workflow lock poisoned")
            .get(name)
            .map(|s| s.status == StepStatus::Completed)
            .unwrap_or(false)
    }

    /// True when the named step exists and has left `Pending` (running,
    /// completed, failed, or compensated).
    pub fn is_step_started(&self, name: &str) -> bool {
        self.step_states
            .read()
            .expect("workflow lock poisoned")
            .get(name)
            .map(|s| s.status != StepStatus::Pending)
            .unwrap_or(false)
    }

    /// Deserializes the named step's recorded result into `T`.
    /// Returns `None` if the step is unknown, has no result, or the result
    /// does not deserialize (the error is swallowed).
    pub fn get_step_result<T: serde::de::DeserializeOwned>(&self, name: &str) -> Option<T> {
        self.step_states
            .read()
            .expect("workflow lock poisoned")
            .get(name)
            .and_then(|s| s.result.as_ref())
            .and_then(|v| serde_json::from_value(v.clone()).ok())
    }

    /// Marks a step as started, in memory and (best-effort, asynchronously)
    /// in Postgres.
    ///
    /// Idempotent: if the step already left `Pending` (e.g. replay after a
    /// resume) nothing happens. Persistence runs on a detached tokio task;
    /// failures are logged, never surfaced.
    pub fn record_step_start(&self, name: &str) {
        let mut states = self.step_states.write().expect("workflow lock poisoned");
        let state = states
            .entry(name.to_string())
            .or_insert_with(|| StepState::new(name));

        if state.status != StepStatus::Pending {
            // Already started in a previous attempt — skip both the
            // in-memory transition and the DB insert.
            return;
        }

        state.start();
        let state_clone = state.clone();
        // Release the lock before any await-capable work.
        drop(states);

        let pool = self.db_pool.clone();
        let run_id = self.run_id;
        let step_name = name.to_string();
        tokio::spawn(async move {
            let step_id = Uuid::new_v4();
            // ON CONFLICT DO NOTHING keeps a replayed start from clobbering
            // an existing row.
            if let Err(e) = sqlx::query(
                r#"
                INSERT INTO forge_workflow_steps (id, workflow_run_id, step_name, status, started_at)
                VALUES ($1, $2, $3, $4, $5)
                ON CONFLICT (workflow_run_id, step_name) DO NOTHING
                "#,
            )
            .bind(step_id)
            .bind(run_id)
            .bind(&step_name)
            .bind(state_clone.status.as_str())
            .bind(state_clone.started_at)
            .execute(&pool)
            .await
            {
                tracing::warn!(
                    workflow_run_id = %run_id,
                    step = %step_name,
                    "Failed to persist step start: {}",
                    e
                );
            }
        });
    }

    /// Records a step completion in memory and persists it on a detached
    /// task (fire-and-forget; see `record_step_complete_async` for the
    /// awaited variant).
    pub fn record_step_complete(&self, name: &str, result: serde_json::Value) {
        let state_clone = self.update_step_state_complete(name, result);

        if let Some(state) = state_clone {
            let pool = self.db_pool.clone();
            let run_id = self.run_id;
            let step_name = name.to_string();
            tokio::spawn(async move {
                Self::persist_step_complete(&pool, run_id, &step_name, &state).await;
            });
        }
    }

    /// Same as `record_step_complete`, but awaits the DB write so the row is
    /// durable before the caller proceeds. Persistence failure is still only
    /// logged, not returned.
    pub async fn record_step_complete_async(&self, name: &str, result: serde_json::Value) {
        let state_clone = self.update_step_state_complete(name, result);

        if let Some(state) = state_clone {
            Self::persist_step_complete(&self.db_pool, self.run_id, name, &state).await;
        }
    }

    /// In-memory half of completion: flips the state to `Completed` and
    /// appends the step to the completion-ordered list (once).
    ///
    /// Returns the updated state for persistence, or `None` when the step
    /// was never recorded. NOTE(review): even in that `None` case the name
    /// is still pushed onto `completed_steps` — confirm that is intended.
    fn update_step_state_complete(
        &self,
        name: &str,
        result: serde_json::Value,
    ) -> Option<StepState> {
        let mut states = self.step_states.write().expect("workflow lock poisoned");
        if let Some(state) = states.get_mut(name) {
            state.complete(result.clone());
        }
        let state_clone = states.get(name).cloned();
        drop(states);

        let mut completed = self
            .completed_steps
            .write()
            .expect("workflow lock poisoned");
        if !completed.contains(&name.to_string()) {
            completed.push(name.to_string());
        }
        drop(completed);

        state_clone
    }

    /// Upserts the step row with completed status/result/timestamps.
    /// Errors are logged and swallowed (persistence is best-effort).
    async fn persist_step_complete(
        pool: &sqlx::PgPool,
        run_id: Uuid,
        step_name: &str,
        state: &StepState,
    ) {
        if let Err(e) = sqlx::query(
            r#"
            INSERT INTO forge_workflow_steps (id, workflow_run_id, step_name, status, result, started_at, completed_at)
            VALUES (gen_random_uuid(), $1, $2, $3, $4, $5, $6)
            ON CONFLICT (workflow_run_id, step_name) DO UPDATE
            SET status = $3, result = $4, completed_at = $6
            "#,
        )
        .bind(run_id)
        .bind(step_name)
        .bind(state.status.as_str())
        .bind(&state.result)
        .bind(state.started_at)
        .bind(state.completed_at)
        .execute(pool)
        .await
        {
            tracing::warn!(
                workflow_run_id = %run_id,
                step = %step_name,
                "Failed to persist step completion: {}",
                e
            );
        }
    }

    /// Records a step failure in memory and persists it on a detached task.
    /// Unlike completion this is an UPDATE only, so it assumes the row was
    /// already inserted by `record_step_start`.
    pub fn record_step_failure(&self, name: &str, error: impl Into<String>) {
        let error_str = error.into();
        let mut states = self.step_states.write().expect("workflow lock poisoned");
        if let Some(state) = states.get_mut(name) {
            state.fail(error_str.clone());
        }
        let state_clone = states.get(name).cloned();
        drop(states);

        if let Some(state) = state_clone {
            let pool = self.db_pool.clone();
            let run_id = self.run_id;
            let step_name = name.to_string();
            tokio::spawn(async move {
                if let Err(e) = sqlx::query(
                    r#"
                    UPDATE forge_workflow_steps
                    SET status = $3, error = $4, completed_at = $5
                    WHERE workflow_run_id = $1 AND step_name = $2
                    "#,
                )
                .bind(run_id)
                .bind(&step_name)
                .bind(state.status.as_str())
                .bind(&state.error)
                .bind(state.completed_at)
                .execute(&pool)
                .await
                {
                    tracing::warn!(
                        workflow_run_id = %run_id,
                        step = %step_name,
                        "Failed to persist step failure: {}",
                        e
                    );
                }
            });
        }
    }

    /// Marks a step as compensated in memory and persists the status change
    /// on a detached task (UPDATE only, best-effort).
    pub fn record_step_compensated(&self, name: &str) {
        let mut states = self.step_states.write().expect("workflow lock poisoned");
        if let Some(state) = states.get_mut(name) {
            state.compensate();
        }
        let state_clone = states.get(name).cloned();
        drop(states);

        if let Some(state) = state_clone {
            let pool = self.db_pool.clone();
            let run_id = self.run_id;
            let step_name = name.to_string();
            tokio::spawn(async move {
                if let Err(e) = sqlx::query(
                    r#"
                    UPDATE forge_workflow_steps
                    SET status = $3
                    WHERE workflow_run_id = $1 AND step_name = $2
                    "#,
                )
                .bind(run_id)
                .bind(&step_name)
                .bind(state.status.as_str())
                .execute(&pool)
                .await
                {
                    tracing::warn!(
                        workflow_run_id = %run_id,
                        step = %step_name,
                        "Failed to persist step compensation: {}",
                        e
                    );
                }
            });
        }
    }

    /// Completed step names, newest first — the order compensation runs in.
    pub fn completed_steps_reversed(&self) -> Vec<String> {
        let completed = self.completed_steps.read().expect("workflow lock poisoned");
        completed.iter().rev().cloned().collect()
    }

    /// Clone of the full step-state map.
    pub fn all_step_states(&self) -> HashMap<String, StepState> {
        self.step_states
            .read()
            .expect("workflow lock poisoned")
            .clone()
    }

    /// Wall-clock time elapsed since the run started.
    pub fn elapsed(&self) -> chrono::Duration {
        Utc::now() - self.started_at
    }

    /// Registers (or replaces) the rollback handler for a step.
    pub fn register_compensation(&self, step_name: &str, handler: CompensationHandler) {
        let mut handlers = self
            .compensation_handlers
            .write()
            .expect("workflow lock poisoned");
        handlers.insert(step_name.to_string(), handler);
    }

    /// Cloned handle to a step's rollback handler, if registered.
    pub fn get_compensation_handler(&self, step_name: &str) -> Option<CompensationHandler> {
        self.compensation_handlers
            .read()
            .expect("workflow lock poisoned")
            .get(step_name)
            .cloned()
    }

    /// True when a rollback handler is registered for the step.
    pub fn has_compensation(&self, step_name: &str) -> bool {
        self.compensation_handlers
            .read()
            .expect("workflow lock poisoned")
            .contains_key(step_name)
    }

    /// Runs compensation for every completed step, newest first.
    ///
    /// Steps without a handler are simply marked compensated and reported as
    /// successful. A failing handler is logged and reported as `false`, but
    /// does NOT stop the remaining steps from being compensated.
    /// Returns `(step_name, succeeded)` pairs in execution order.
    pub async fn run_compensation(&self) -> Vec<(String, bool)> {
        let steps = self.completed_steps_reversed();
        let mut results = Vec::new();

        for step_name in steps {
            let handler = self.get_compensation_handler(&step_name);
            let result = self
                .get_step_state(&step_name)
                .and_then(|s| s.result.clone());

            if let Some(handler) = handler {
                // Handlers receive JSON null when the step recorded no result.
                let step_result = result.unwrap_or(serde_json::Value::Null);
                match handler(step_result).await {
                    Ok(()) => {
                        self.record_step_compensated(&step_name);
                        results.push((step_name, true));
                    }
                    Err(e) => {
                        tracing::error!(step = %step_name, error = %e, "Compensation failed");
                        results.push((step_name, false));
                    }
                }
            } else {
                self.record_step_compensated(&step_name);
                results.push((step_name, true));
            }
        }

        results
    }

    /// Clone of the full handler map (cheap: `CompensationHandler` is `Arc`).
    pub fn compensation_handlers(&self) -> HashMap<String, CompensationHandler> {
        self.compensation_handlers
            .read()
            .expect("workflow lock poisoned")
            .clone()
    }

    /// Durable sleep: suspends the workflow for `duration`.
    ///
    /// No-op when this run just resumed from a completed sleep. Durations
    /// that don't fit a `chrono::Duration` collapse to zero via
    /// `unwrap_or_default`, i.e. an effectively immediate wake.
    ///
    /// # Errors
    /// Propagates DB errors; otherwise ends in `WorkflowSuspended` from
    /// `signal_suspend` (the workflow unwinds and is re-run later).
    pub async fn sleep(&self, duration: Duration) -> Result<()> {
        if self.resumed_from_sleep {
            return Ok(());
        }

        let wake_at = Utc::now() + chrono::Duration::from_std(duration).unwrap_or_default();
        self.sleep_until(wake_at).await
    }

    /// Durable sleep until an absolute instant.
    ///
    /// Returns immediately when resuming from a completed sleep or when
    /// `wake_at` is already in the past; otherwise persists the wake time
    /// and suspends.
    ///
    /// NOTE(review): `signal_suspend` currently always returns `Err`, so the
    /// trailing `Ok(())` is unreachable in practice.
    pub async fn sleep_until(&self, wake_at: DateTime<Utc>) -> Result<()> {
        if self.resumed_from_sleep {
            return Ok(());
        }

        if wake_at <= Utc::now() {
            return Ok(());
        }

        self.set_wake_at(wake_at).await?;

        self.signal_suspend(SuspendReason::Sleep { wake_at })
            .await?;

        Ok(())
    }

    /// Waits for a named external event correlated to this run (correlation
    /// id = the run id as a string).
    ///
    /// If a matching unconsumed event already exists it is consumed and its
    /// payload deserialized immediately. Otherwise the run is marked waiting
    /// (with optional timeout) and suspended. After a resume, a second
    /// consume attempt maps any missing/undeserializable payload to
    /// `ForgeError::Timeout`.
    ///
    /// NOTE(review): since `signal_suspend` always returns `Err` today, the
    /// post-suspend re-check only runs on a fresh invocation after resume.
    pub async fn wait_for_event<T: DeserializeOwned>(
        &self,
        event_name: &str,
        timeout: Option<Duration>,
    ) -> Result<T> {
        let correlation_id = self.run_id.to_string();

        if let Some(event) = self.try_consume_event(event_name, &correlation_id).await? {
            return serde_json::from_value(event.payload.unwrap_or_default())
                .map_err(|e| ForgeError::Deserialization(e.to_string()));
        }

        let timeout_at =
            timeout.map(|d| Utc::now() + chrono::Duration::from_std(d).unwrap_or_default());

        self.set_waiting_for_event(event_name, timeout_at).await?;

        self.signal_suspend(SuspendReason::WaitingEvent {
            event_name: event_name.to_string(),
            timeout: timeout_at,
        })
        .await?;

        self.try_consume_event(event_name, &correlation_id)
            .await?
            .and_then(|e| e.payload)
            .and_then(|p| serde_json::from_value(p).ok())
            .ok_or_else(|| ForgeError::Timeout(format!("Event '{}' timed out", event_name)))
    }

    /// Atomically claims the oldest unconsumed matching event for this run.
    ///
    /// Uses `FOR UPDATE SKIP LOCKED` inside the UPDATE's subselect so
    /// concurrent consumers never double-consume the same row.
    #[allow(clippy::type_complexity)]
    async fn try_consume_event(
        &self,
        event_name: &str,
        correlation_id: &str,
    ) -> Result<Option<WorkflowEvent>> {
        let result: Option<(
            Uuid,
            String,
            String,
            Option<serde_json::Value>,
            DateTime<Utc>,
        )> = sqlx::query_as(
            r#"
            UPDATE forge_workflow_events
            SET consumed_at = NOW(), consumed_by = $3
            WHERE id = (
                SELECT id FROM forge_workflow_events
                WHERE event_name = $1 AND correlation_id = $2 AND consumed_at IS NULL
                ORDER BY created_at ASC LIMIT 1
                FOR UPDATE SKIP LOCKED
            )
            RETURNING id, event_name, correlation_id, payload, created_at
            "#,
        )
        .bind(event_name)
        .bind(correlation_id)
        .bind(self.run_id)
        .fetch_optional(&self.db_pool)
        .await
        .map_err(|e| ForgeError::Database(e.to_string()))?;

        Ok(result.map(
            |(id, event_name, correlation_id, payload, created_at)| WorkflowEvent {
                id,
                event_name,
                correlation_id,
                payload,
                created_at,
            },
        ))
    }

    /// Persists that this run is waiting to be woken at `wake_at`.
    async fn set_wake_at(&self, wake_at: DateTime<Utc>) -> Result<()> {
        sqlx::query(
            r#"
            UPDATE forge_workflow_runs
            SET status = 'waiting', suspended_at = NOW(), wake_at = $2
            WHERE id = $1
            "#,
        )
        .bind(self.run_id)
        .bind(wake_at)
        .execute(&self.db_pool)
        .await
        .map_err(|e| ForgeError::Database(e.to_string()))?;
        Ok(())
    }

    /// Persists that this run is waiting on a named event, with an optional
    /// timeout deadline.
    async fn set_waiting_for_event(
        &self,
        event_name: &str,
        timeout_at: Option<DateTime<Utc>>,
    ) -> Result<()> {
        sqlx::query(
            r#"
            UPDATE forge_workflow_runs
            SET status = 'waiting', suspended_at = NOW(), waiting_for_event = $2, event_timeout_at = $3
            WHERE id = $1
            "#,
        )
        .bind(self.run_id)
        .bind(event_name)
        .bind(timeout_at)
        .execute(&self.db_pool)
        .await
        .map_err(|e| ForgeError::Database(e.to_string()))?;
        Ok(())
    }

    /// Notifies the engine (when a channel is installed) and then ALWAYS
    /// returns `Err(WorkflowSuspended)` so the workflow body unwinds; the
    /// engine re-runs it after the wake condition is met.
    async fn signal_suspend(&self, reason: SuspendReason) -> Result<()> {
        if let Some(ref tx) = self.suspend_tx {
            tx.send(reason)
                .await
                .map_err(|_| ForgeError::Internal("Failed to signal suspension".into()))?;
        }
        Err(ForgeError::WorkflowSuspended)
    }

    /// Builder for running several steps concurrently.
    pub fn parallel(&self) -> ParallelBuilder<'_> {
        ParallelBuilder::new(self)
    }

    /// Defines a durable step: `f` runs at most once per run; on replay the
    /// recorded result is returned instead. Returns a `StepRunner` builder.
    pub fn step<T, F, Fut>(&self, name: impl Into<String>, f: F) -> super::StepRunner<'_, T>
    where
        T: serde::Serialize + serde::de::DeserializeOwned + Clone + Send + Sync + 'static,
        F: Fn() -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = crate::Result<T>> + Send + 'static,
    {
        super::StepRunner::new(self, name, f)
    }
}
859
860impl EnvAccess for WorkflowContext {
861 fn env_provider(&self) -> &dyn EnvProvider {
862 self.env_provider.as_ref()
863 }
864}
865
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// Builds a lazily-connecting pool; no database is actually contacted.
    fn lazy_pool() -> sqlx::PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool")
    }

    #[tokio::test]
    async fn test_workflow_context_creation() {
        let id = Uuid::new_v4();
        let ctx = WorkflowContext::new(
            id,
            "test_workflow".to_string(),
            1,
            lazy_pool(),
            reqwest::Client::new(),
        );

        assert_eq!(ctx.run_id, id);
        assert_eq!(ctx.workflow_name, "test_workflow");
        assert_eq!(ctx.version, 1);
    }

    #[tokio::test]
    async fn test_step_state_tracking() {
        let ctx = WorkflowContext::new(
            Uuid::new_v4(),
            "test".to_string(),
            1,
            lazy_pool(),
            reqwest::Client::new(),
        );

        ctx.record_step_start("step1");
        assert!(!ctx.is_step_completed("step1"));

        ctx.record_step_complete("step1", serde_json::json!({"result": "ok"}));
        assert!(ctx.is_step_completed("step1"));

        assert!(ctx.get_step_result::<serde_json::Value>("step1").is_some());
    }

    #[test]
    fn test_step_state_transitions() {
        let mut state = StepState::new("test");
        assert_eq!(state.status, StepStatus::Pending);

        state.start();
        assert_eq!(state.status, StepStatus::Running);
        assert!(state.started_at.is_some());

        state.complete(serde_json::json!({}));
        assert_eq!(state.status, StepStatus::Completed);
        assert!(state.completed_at.is_some());
    }
}