1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::{Arc, RwLock};
5use std::time::Duration;
6
7use chrono::{DateTime, Utc};
8use serde::de::DeserializeOwned;
9use tokio::sync::mpsc;
10use uuid::Uuid;
11
12use super::parallel::ParallelBuilder;
13use super::step::StepStatus;
14use super::suspend::{SuspendReason, WorkflowEvent};
15use crate::env::{EnvAccess, EnvProvider, RealEnvProvider};
16use crate::function::AuthContext;
17use crate::{ForgeError, Result};
18
/// Async callback invoked during rollback for a completed step.
///
/// Receives the step's recorded JSON result and undoes the step's side
/// effects. `Arc`-wrapped so handlers can be stored in a shared map and
/// cloned cheaply across tasks.
pub type CompensationHandler = Arc<
    dyn Fn(serde_json::Value) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync,
>;
23
/// In-memory record of a single workflow step's lifecycle.
#[derive(Debug, Clone)]
pub struct StepState {
    /// Step name, unique within a workflow run.
    pub name: String,
    /// Current lifecycle status (Pending → Running → Completed/Failed/Compensated).
    pub status: StepStatus,
    /// JSON result captured when the step completed.
    pub result: Option<serde_json::Value>,
    /// Error message captured when the step failed.
    pub error: Option<String>,
    /// When the step started running, if it has.
    pub started_at: Option<DateTime<Utc>>,
    /// When the step finished (set on both completion and failure).
    pub completed_at: Option<DateTime<Utc>>,
}
40
41impl StepState {
42 pub fn new(name: impl Into<String>) -> Self {
44 Self {
45 name: name.into(),
46 status: StepStatus::Pending,
47 result: None,
48 error: None,
49 started_at: None,
50 completed_at: None,
51 }
52 }
53
54 pub fn start(&mut self) {
56 self.status = StepStatus::Running;
57 self.started_at = Some(Utc::now());
58 }
59
60 pub fn complete(&mut self, result: serde_json::Value) {
62 self.status = StepStatus::Completed;
63 self.result = Some(result);
64 self.completed_at = Some(Utc::now());
65 }
66
67 pub fn fail(&mut self, error: impl Into<String>) {
69 self.status = StepStatus::Failed;
70 self.error = Some(error.into());
71 self.completed_at = Some(Utc::now());
72 }
73
74 pub fn compensate(&mut self) {
76 self.status = StepStatus::Compensated;
77 }
78}
79
/// Per-run execution context handed to workflow code.
///
/// Tracks step state for replay/resume, registered compensation handlers,
/// and the channel used to suspend the run (sleep / wait-for-event).
pub struct WorkflowContext {
    /// Unique id of this workflow run.
    pub run_id: Uuid,
    /// Registered workflow name.
    pub workflow_name: String,
    /// Workflow definition version.
    pub version: u32,
    /// Wall-clock time the run originally started.
    pub started_at: DateTime<Utc>,
    /// Deterministic "workflow time": pinned to the original start time,
    /// including across resume.
    workflow_time: DateTime<Utc>,
    /// Auth context the run executes under.
    pub auth: AuthContext,
    /// Database pool used for persistence and step bookkeeping.
    db_pool: sqlx::PgPool,
    /// Shared HTTP client for workflow steps.
    http_client: reqwest::Client,
    /// Step name → lifecycle state, shared across step runners.
    step_states: Arc<RwLock<HashMap<String, StepState>>>,
    /// Completed step names in completion order (compensation runs in reverse).
    completed_steps: Arc<RwLock<Vec<String>>>,
    /// Step name → compensation callback.
    compensation_handlers: Arc<RwLock<HashMap<String, CompensationHandler>>>,
    /// Channel used to notify the runner that this run is suspending.
    suspend_tx: Option<mpsc::Sender<SuspendReason>>,
    /// True when this context was rebuilt for a resumed run.
    is_resumed: bool,
    /// True when resuming after a sleep; makes `sleep`/`sleep_until` no-ops.
    resumed_from_sleep: bool,
    /// Tenant this run is scoped to, if any.
    tenant_id: Option<Uuid>,
    /// Environment variable provider (swappable via `with_env_provider`).
    env_provider: Arc<dyn EnvProvider>,
}
115
116impl WorkflowContext {
117 pub fn new(
119 run_id: Uuid,
120 workflow_name: String,
121 version: u32,
122 db_pool: sqlx::PgPool,
123 http_client: reqwest::Client,
124 ) -> Self {
125 let now = Utc::now();
126 Self {
127 run_id,
128 workflow_name,
129 version,
130 started_at: now,
131 workflow_time: now,
132 auth: AuthContext::unauthenticated(),
133 db_pool,
134 http_client,
135 step_states: Arc::new(RwLock::new(HashMap::new())),
136 completed_steps: Arc::new(RwLock::new(Vec::new())),
137 compensation_handlers: Arc::new(RwLock::new(HashMap::new())),
138 suspend_tx: None,
139 is_resumed: false,
140 resumed_from_sleep: false,
141 tenant_id: None,
142 env_provider: Arc::new(RealEnvProvider::new()),
143 }
144 }
145
146 pub fn resumed(
148 run_id: Uuid,
149 workflow_name: String,
150 version: u32,
151 started_at: DateTime<Utc>,
152 db_pool: sqlx::PgPool,
153 http_client: reqwest::Client,
154 ) -> Self {
155 Self {
156 run_id,
157 workflow_name,
158 version,
159 started_at,
160 workflow_time: started_at,
161 auth: AuthContext::unauthenticated(),
162 db_pool,
163 http_client,
164 step_states: Arc::new(RwLock::new(HashMap::new())),
165 completed_steps: Arc::new(RwLock::new(Vec::new())),
166 compensation_handlers: Arc::new(RwLock::new(HashMap::new())),
167 suspend_tx: None,
168 is_resumed: true,
169 resumed_from_sleep: false,
170 tenant_id: None,
171 env_provider: Arc::new(RealEnvProvider::new()),
172 }
173 }
174
    /// Replaces the environment provider (builder-style; useful for
    /// injecting a fake provider in tests).
    pub fn with_env_provider(mut self, provider: Arc<dyn EnvProvider>) -> Self {
        self.env_provider = provider;
        self
    }

    /// Marks this context as resuming after a sleep, which makes
    /// subsequent `sleep`/`sleep_until` calls return immediately.
    pub fn with_resumed_from_sleep(mut self) -> Self {
        self.resumed_from_sleep = true;
        self
    }

    /// Attaches the channel used to notify the runner when this run
    /// suspends (sleep or wait-for-event).
    pub fn with_suspend_channel(mut self, tx: mpsc::Sender<SuspendReason>) -> Self {
        self.suspend_tx = Some(tx);
        self
    }

    /// Scopes this run to a tenant.
    pub fn with_tenant(mut self, tenant_id: Uuid) -> Self {
        self.tenant_id = Some(tenant_id);
        self
    }

    /// Tenant this run is scoped to, if any.
    pub fn tenant_id(&self) -> Option<Uuid> {
        self.tenant_id
    }

    /// True when this context was rebuilt for a resumed run.
    pub fn is_resumed(&self) -> bool {
        self.is_resumed
    }

    /// Deterministic workflow time: the original start time, stable
    /// across suspension and resume.
    pub fn workflow_time(&self) -> DateTime<Utc> {
        self.workflow_time
    }

    /// Database pool for workflow queries.
    pub fn db(&self) -> &sqlx::PgPool {
        &self.db_pool
    }

    /// Pool-backed connection wrapper for function-style call sites.
    pub fn db_conn(&self) -> crate::function::DbConn<'_> {
        crate::function::DbConn::Pool(&self.db_pool)
    }

    /// Shared HTTP client.
    pub fn http(&self) -> &reqwest::Client {
        &self.http_client
    }

    /// Sets the auth context the workflow executes under.
    pub fn with_auth(mut self, auth: AuthContext) -> Self {
        self.auth = auth;
        self
    }
229
230 pub fn with_step_states(self, states: HashMap<String, StepState>) -> Self {
232 let completed: Vec<String> = states
233 .iter()
234 .filter(|(_, s)| s.status == StepStatus::Completed)
235 .map(|(name, _)| name.clone())
236 .collect();
237
238 *self.step_states.write().expect("workflow lock poisoned") = states;
239 *self
240 .completed_steps
241 .write()
242 .expect("workflow lock poisoned") = completed;
243 self
244 }
245
    /// Returns a clone of the named step's state, if any.
    pub fn get_step_state(&self, name: &str) -> Option<StepState> {
        self.step_states
            .read()
            .expect("workflow lock poisoned")
            .get(name)
            .cloned()
    }

    /// True when the named step has completed successfully.
    pub fn is_step_completed(&self, name: &str) -> bool {
        self.step_states
            .read()
            .expect("workflow lock poisoned")
            .get(name)
            .map(|s| s.status == StepStatus::Completed)
            .unwrap_or(false)
    }

    /// True when the named step has left `Pending` (running, completed,
    /// failed, or compensated).
    pub fn is_step_started(&self, name: &str) -> bool {
        self.step_states
            .read()
            .expect("workflow lock poisoned")
            .get(name)
            .map(|s| s.status != StepStatus::Pending)
            .unwrap_or(false)
    }

    /// Deserializes the named step's recorded result into `T`.
    ///
    /// Returns `None` when the step is unknown, has no result, or the
    /// payload does not deserialize into `T` (the error is swallowed).
    pub fn get_step_result<T: serde::de::DeserializeOwned>(&self, name: &str) -> Option<T> {
        self.step_states
            .read()
            .expect("workflow lock poisoned")
            .get(name)
            .and_then(|s| s.result.as_ref())
            .and_then(|v| serde_json::from_value(v.clone()).ok())
    }
284
    /// Marks a step as running in memory and persists the start best-effort.
    ///
    /// Idempotent: if the step is already past `Pending` (e.g. on replay)
    /// nothing happens. Persistence runs on a spawned task so the workflow
    /// is not blocked; failures are logged rather than propagated, and
    /// `ON CONFLICT DO NOTHING` keeps replays from duplicating rows.
    pub fn record_step_start(&self, name: &str) {
        let mut states = self.step_states.write().expect("workflow lock poisoned");
        let state = states
            .entry(name.to_string())
            .or_insert_with(|| StepState::new(name));

        // Already started (replay/resume) — keep the existing record.
        if state.status != StepStatus::Pending {
            return;
        }

        state.start();
        let state_clone = state.clone();
        // Release the lock before spawning the persistence task.
        drop(states);

        let pool = self.db_pool.clone();
        let run_id = self.run_id;
        let step_name = name.to_string();
        tokio::spawn(async move {
            let step_id = Uuid::new_v4();
            if let Err(e) = sqlx::query(
                r#"
                INSERT INTO forge_workflow_steps (id, workflow_run_id, step_name, status, started_at)
                VALUES ($1, $2, $3, $4, $5)
                ON CONFLICT (workflow_run_id, step_name) DO NOTHING
                "#,
            )
            .bind(step_id)
            .bind(run_id)
            .bind(&step_name)
            .bind(state_clone.status.as_str())
            .bind(state_clone.started_at)
            .execute(&pool)
            .await
            {
                tracing::warn!(
                    workflow_run_id = %run_id,
                    step = %step_name,
                    "Failed to persist step start: {}",
                    e
                );
            }
        });
    }
335
    /// Records a step completion in memory and persists it on a spawned
    /// task (fire-and-forget; persistence failures are logged).
    pub fn record_step_complete(&self, name: &str, result: serde_json::Value) {
        let state_clone = self.update_step_state_complete(name, result);

        if let Some(state) = state_clone {
            let pool = self.db_pool.clone();
            let run_id = self.run_id;
            let step_name = name.to_string();
            tokio::spawn(async move {
                Self::persist_step_complete(&pool, run_id, &step_name, &state).await;
            });
        }
    }

    /// Like `record_step_complete`, but awaits persistence so the caller
    /// knows the write has been attempted before continuing.
    pub async fn record_step_complete_async(&self, name: &str, result: serde_json::Value) {
        let state_clone = self.update_step_state_complete(name, result);

        if let Some(state) = state_clone {
            Self::persist_step_complete(&self.db_pool, self.run_id, name, &state).await;
        }
    }
361
362 fn update_step_state_complete(
364 &self,
365 name: &str,
366 result: serde_json::Value,
367 ) -> Option<StepState> {
368 let mut states = self.step_states.write().expect("workflow lock poisoned");
369 if let Some(state) = states.get_mut(name) {
370 state.complete(result.clone());
371 }
372 let state_clone = states.get(name).cloned();
373 drop(states);
374
375 let mut completed = self
376 .completed_steps
377 .write()
378 .expect("workflow lock poisoned");
379 if !completed.contains(&name.to_string()) {
380 completed.push(name.to_string());
381 }
382 drop(completed);
383
384 state_clone
385 }
386
    /// Upserts the completed step row; on conflict the existing row is
    /// updated in place. Errors are logged and swallowed (best-effort).
    async fn persist_step_complete(
        pool: &sqlx::PgPool,
        run_id: Uuid,
        step_name: &str,
        state: &StepState,
    ) {
        if let Err(e) = sqlx::query(
            r#"
            INSERT INTO forge_workflow_steps (id, workflow_run_id, step_name, status, result, started_at, completed_at)
            VALUES (gen_random_uuid(), $1, $2, $3, $4, $5, $6)
            ON CONFLICT (workflow_run_id, step_name) DO UPDATE
            SET status = $3, result = $4, completed_at = $6
            "#,
        )
        .bind(run_id)
        .bind(step_name)
        .bind(state.status.as_str())
        .bind(&state.result)
        .bind(state.started_at)
        .bind(state.completed_at)
        .execute(pool)
        .await
        {
            tracing::warn!(
                workflow_run_id = %run_id,
                step = %step_name,
                "Failed to persist step completion: {}",
                e
            );
        }
    }
420
    /// Marks a step failed in memory and persists the failure best-effort
    /// on a spawned task (errors are logged, not propagated).
    ///
    /// NOTE(review): this issues an UPDATE rather than an upsert, so if the
    /// start row was never persisted the failure is silently dropped —
    /// confirm that is intended.
    pub fn record_step_failure(&self, name: &str, error: impl Into<String>) {
        let error_str = error.into();
        let mut states = self.step_states.write().expect("workflow lock poisoned");
        if let Some(state) = states.get_mut(name) {
            state.fail(error_str.clone());
        }
        let state_clone = states.get(name).cloned();
        drop(states);

        if let Some(state) = state_clone {
            let pool = self.db_pool.clone();
            let run_id = self.run_id;
            let step_name = name.to_string();
            tokio::spawn(async move {
                if let Err(e) = sqlx::query(
                    r#"
                    UPDATE forge_workflow_steps
                    SET status = $3, error = $4, completed_at = $5
                    WHERE workflow_run_id = $1 AND step_name = $2
                    "#,
                )
                .bind(run_id)
                .bind(&step_name)
                .bind(state.status.as_str())
                .bind(&state.error)
                .bind(state.completed_at)
                .execute(&pool)
                .await
                {
                    tracing::warn!(
                        workflow_run_id = %run_id,
                        step = %step_name,
                        "Failed to persist step failure: {}",
                        e
                    );
                }
            });
        }
    }
462
    /// Marks a step compensated in memory and persists the status change
    /// best-effort on a spawned task (errors are logged, not propagated).
    pub fn record_step_compensated(&self, name: &str) {
        let mut states = self.step_states.write().expect("workflow lock poisoned");
        if let Some(state) = states.get_mut(name) {
            state.compensate();
        }
        let state_clone = states.get(name).cloned();
        drop(states);

        if let Some(state) = state_clone {
            let pool = self.db_pool.clone();
            let run_id = self.run_id;
            let step_name = name.to_string();
            tokio::spawn(async move {
                if let Err(e) = sqlx::query(
                    r#"
                    UPDATE forge_workflow_steps
                    SET status = $3
                    WHERE workflow_run_id = $1 AND step_name = $2
                    "#,
                )
                .bind(run_id)
                .bind(&step_name)
                .bind(state.status.as_str())
                .execute(&pool)
                .await
                {
                    tracing::warn!(
                        workflow_run_id = %run_id,
                        step = %step_name,
                        "Failed to persist step compensation: {}",
                        e
                    );
                }
            });
        }
    }
501
    /// Completed step names in reverse completion order — the order in
    /// which compensation must run.
    pub fn completed_steps_reversed(&self) -> Vec<String> {
        let completed = self.completed_steps.read().expect("workflow lock poisoned");
        completed.iter().rev().cloned().collect()
    }

    /// Snapshot of all step states.
    pub fn all_step_states(&self) -> HashMap<String, StepState> {
        self.step_states
            .read()
            .expect("workflow lock poisoned")
            .clone()
    }

    /// Wall-clock time elapsed since the run started.
    pub fn elapsed(&self) -> chrono::Duration {
        Utc::now() - self.started_at
    }

    /// Registers (or replaces) the compensation handler for a step.
    pub fn register_compensation(&self, step_name: &str, handler: CompensationHandler) {
        let mut handlers = self
            .compensation_handlers
            .write()
            .expect("workflow lock poisoned");
        handlers.insert(step_name.to_string(), handler);
    }

    /// Returns the step's compensation handler, if registered.
    pub fn get_compensation_handler(&self, step_name: &str) -> Option<CompensationHandler> {
        self.compensation_handlers
            .read()
            .expect("workflow lock poisoned")
            .get(step_name)
            .cloned()
    }

    /// True when a compensation handler is registered for the step.
    pub fn has_compensation(&self, step_name: &str) -> bool {
        self.compensation_handlers
            .read()
            .expect("workflow lock poisoned")
            .contains_key(step_name)
    }
541
542 pub async fn run_compensation(&self) -> Vec<(String, bool)> {
545 let steps = self.completed_steps_reversed();
546 let mut results = Vec::new();
547
548 for step_name in steps {
549 let handler = self.get_compensation_handler(&step_name);
550 let result = self
551 .get_step_state(&step_name)
552 .and_then(|s| s.result.clone());
553
554 if let Some(handler) = handler {
555 let step_result = result.unwrap_or(serde_json::Value::Null);
556 match handler(step_result).await {
557 Ok(()) => {
558 self.record_step_compensated(&step_name);
559 results.push((step_name, true));
560 }
561 Err(e) => {
562 tracing::error!(step = %step_name, error = %e, "Compensation failed");
563 results.push((step_name, false));
564 }
565 }
566 } else {
567 self.record_step_compensated(&step_name);
569 results.push((step_name, true));
570 }
571 }
572
573 results
574 }
575
    /// Snapshot of all registered compensation handlers (cheap `Arc` clones).
    pub fn compensation_handlers(&self) -> HashMap<String, CompensationHandler> {
        self.compensation_handlers
            .read()
            .expect("workflow lock poisoned")
            .clone()
    }
582
    /// Durable sleep: suspends the workflow for `duration`.
    ///
    /// No-op when this context was resumed from a sleep (the sleep already
    /// elapsed). Otherwise delegates to `sleep_until`, which normally
    /// returns `Err(ForgeError::WorkflowSuspended)` to unwind the workflow.
    ///
    /// NOTE(review): `from_std` overflow falls back to a zero duration via
    /// `unwrap_or_default`, so an absurdly large `duration` would not sleep
    /// at all — confirm acceptable.
    pub async fn sleep(&self, duration: Duration) -> Result<()> {
        if self.resumed_from_sleep {
            return Ok(());
        }

        let wake_at = Utc::now() + chrono::Duration::from_std(duration).unwrap_or_default();
        self.sleep_until(wake_at).await
    }

    /// Durable sleep until an absolute wall-clock time.
    ///
    /// Returns `Ok(())` immediately when resuming from sleep or when
    /// `wake_at` is already past. Otherwise persists the wake time and
    /// signals suspension; `signal_suspend` always returns an error, so the
    /// trailing `Ok(())` is unreachable in practice.
    pub async fn sleep_until(&self, wake_at: DateTime<Utc>) -> Result<()> {
        if self.resumed_from_sleep {
            return Ok(());
        }

        if wake_at <= Utc::now() {
            return Ok(());
        }

        self.set_wake_at(wake_at).await?;

        self.signal_suspend(SuspendReason::Sleep { wake_at })
            .await?;

        Ok(())
    }
633
    /// Waits for a named external event correlated to this run (by run id).
    ///
    /// Fast path: if a matching unconsumed event already exists it is
    /// claimed and deserialized immediately. Otherwise the run is marked as
    /// waiting and suspended; after resume, one more consume attempt is
    /// made and failure is reported as `ForgeError::Timeout`.
    ///
    /// NOTE(review): on the post-resume path a payload that fails to
    /// deserialize is also reported as `Timeout` (the serde error is
    /// swallowed by `.ok()`) — confirm that conflation is acceptable.
    pub async fn wait_for_event<T: DeserializeOwned>(
        &self,
        event_name: &str,
        timeout: Option<Duration>,
    ) -> Result<T> {
        let correlation_id = self.run_id.to_string();

        if let Some(event) = self.try_consume_event(event_name, &correlation_id).await? {
            return serde_json::from_value(event.payload.unwrap_or_default())
                .map_err(|e| ForgeError::Deserialization(e.to_string()));
        }

        let timeout_at =
            timeout.map(|d| Utc::now() + chrono::Duration::from_std(d).unwrap_or_default());

        self.set_waiting_for_event(event_name, timeout_at).await?;

        self.signal_suspend(SuspendReason::WaitingEvent {
            event_name: event_name.to_string(),
            timeout: timeout_at,
        })
        .await?;

        self.try_consume_event(event_name, &correlation_id)
            .await?
            .and_then(|e| e.payload)
            .and_then(|p| serde_json::from_value(p).ok())
            .ok_or_else(|| ForgeError::Timeout(format!("Event '{}' timed out", event_name)))
    }
680
    /// Atomically claims the oldest matching unconsumed event, if any.
    ///
    /// The self-selecting UPDATE with `FOR UPDATE SKIP LOCKED` ensures
    /// concurrent consumers cannot double-consume the same row; the claimed
    /// event is stamped with this run's id and returned.
    #[allow(clippy::type_complexity)]
    async fn try_consume_event(
        &self,
        event_name: &str,
        correlation_id: &str,
    ) -> Result<Option<WorkflowEvent>> {
        let result: Option<(
            Uuid,
            String,
            String,
            Option<serde_json::Value>,
            DateTime<Utc>,
        )> = sqlx::query_as(
            r#"
            UPDATE forge_workflow_events
            SET consumed_at = NOW(), consumed_by = $3
            WHERE id = (
                SELECT id FROM forge_workflow_events
                WHERE event_name = $1 AND correlation_id = $2 AND consumed_at IS NULL
                ORDER BY created_at ASC LIMIT 1
                FOR UPDATE SKIP LOCKED
            )
            RETURNING id, event_name, correlation_id, payload, created_at
            "#,
        )
        .bind(event_name)
        .bind(correlation_id)
        .bind(self.run_id)
        .fetch_optional(&self.db_pool)
        .await
        .map_err(|e| ForgeError::Database(e.to_string()))?;

        Ok(result.map(
            |(id, event_name, correlation_id, payload, created_at)| WorkflowEvent {
                id,
                event_name,
                correlation_id,
                payload,
                created_at,
            },
        ))
    }
724
    /// Persists the run as `waiting` with an absolute wake time (used by
    /// durable sleep).
    async fn set_wake_at(&self, wake_at: DateTime<Utc>) -> Result<()> {
        sqlx::query(
            r#"
            UPDATE forge_workflow_runs
            SET status = 'waiting', suspended_at = NOW(), wake_at = $2
            WHERE id = $1
            "#,
        )
        .bind(self.run_id)
        .bind(wake_at)
        .execute(&self.db_pool)
        .await
        .map_err(|e| ForgeError::Database(e.to_string()))?;
        Ok(())
    }

    /// Persists the run as `waiting` on a named event, with an optional
    /// timeout deadline.
    async fn set_waiting_for_event(
        &self,
        event_name: &str,
        timeout_at: Option<DateTime<Utc>>,
    ) -> Result<()> {
        sqlx::query(
            r#"
            UPDATE forge_workflow_runs
            SET status = 'waiting', suspended_at = NOW(), waiting_for_event = $2, event_timeout_at = $3
            WHERE id = $1
            "#,
        )
        .bind(self.run_id)
        .bind(event_name)
        .bind(timeout_at)
        .execute(&self.db_pool)
        .await
        .map_err(|e| ForgeError::Database(e.to_string()))?;
        Ok(())
    }
763
    /// Notifies the runner (when a suspend channel is attached) and then
    /// ALWAYS returns `Err(ForgeError::WorkflowSuspended)`: suspension is
    /// propagated by unwinding the workflow body with this error.
    async fn signal_suspend(&self, reason: SuspendReason) -> Result<()> {
        if let Some(ref tx) = self.suspend_tx {
            tx.send(reason)
                .await
                .map_err(|_| ForgeError::Internal("Failed to signal suspension".into()))?;
        }
        Err(ForgeError::WorkflowSuspended)
    }
774
    /// Starts a builder for running several steps concurrently.
    pub fn parallel(&self) -> ParallelBuilder<'_> {
        ParallelBuilder::new(self)
    }

    /// Starts a runner for a single named step whose body is `f`.
    ///
    /// The result type must round-trip through serde so it can be recorded
    /// in step state and replayed on resume.
    pub fn step<T, F, Fut>(&self, name: impl Into<String>, f: F) -> super::StepRunner<'_, T>
    where
        T: serde::Serialize + serde::de::DeserializeOwned + Clone + Send + Sync + 'static,
        F: Fn() -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = crate::Result<T>> + Send + 'static,
    {
        super::StepRunner::new(self, name, f)
    }
851}
852
/// Exposes the context's configured environment provider through the
/// shared `EnvAccess` trait.
impl EnvAccess for WorkflowContext {
    fn env_provider(&self) -> &dyn EnvProvider {
        self.env_provider.as_ref()
    }
}
858
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// Lazily-connected pool against a nonexistent database; no connection
    /// is attempted until a query actually executes.
    fn lazy_pool() -> sqlx::PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool")
    }

    #[tokio::test]
    async fn test_workflow_context_creation() {
        let id = Uuid::new_v4();
        let ctx = WorkflowContext::new(
            id,
            "test_workflow".to_string(),
            1,
            lazy_pool(),
            reqwest::Client::new(),
        );

        assert_eq!(ctx.run_id, id);
        assert_eq!(ctx.workflow_name, "test_workflow");
        assert_eq!(ctx.version, 1);
    }

    #[tokio::test]
    async fn test_step_state_tracking() {
        let ctx = WorkflowContext::new(
            Uuid::new_v4(),
            "test".to_string(),
            1,
            lazy_pool(),
            reqwest::Client::new(),
        );

        // Starting alone must not mark the step completed.
        ctx.record_step_start("step1");
        assert!(!ctx.is_step_completed("step1"));

        ctx.record_step_complete("step1", serde_json::json!({"result": "ok"}));
        assert!(ctx.is_step_completed("step1"));

        let fetched: Option<serde_json::Value> = ctx.get_step_result("step1");
        assert!(fetched.is_some());
    }

    #[test]
    fn test_step_state_transitions() {
        let mut step = StepState::new("test");
        assert_eq!(step.status, StepStatus::Pending);

        step.start();
        assert_eq!(step.status, StepStatus::Running);
        assert!(step.started_at.is_some());

        step.complete(serde_json::json!({}));
        assert_eq!(step.status, StepStatus::Completed);
        assert!(step.completed_at.is_some());
    }
}