1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::{Arc, RwLock};
5use std::time::Duration;
6
7use chrono::{DateTime, Utc};
8use serde::de::DeserializeOwned;
9use tokio::sync::mpsc;
10use uuid::Uuid;
11
12use super::parallel::ParallelBuilder;
13use super::step::StepStatus;
14use super::suspend::{SuspendReason, WorkflowEvent};
15use crate::env::{EnvAccess, EnvProvider, RealEnvProvider};
16use crate::function::AuthContext;
17use crate::{ForgeError, Result};
18
/// Type-erased async compensation callback for a completed step.
///
/// Receives the step's recorded result (raw JSON) and undoes the step's side
/// effects. Stored behind `Arc` so handlers can be cloned out of the shared
/// registry and invoked later, after the workflow has failed.
pub type CompensationHandler = Arc<
    dyn Fn(serde_json::Value) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync,
>;
23
/// Tracked execution state for a single workflow step.
#[derive(Debug, Clone)]
pub struct StepState {
    /// Step name (unique within a workflow run).
    pub name: String,
    /// Current lifecycle status (Pending -> Running -> Completed/Failed, or Compensated).
    pub status: StepStatus,
    /// JSON result recorded on successful completion, if any.
    pub result: Option<serde_json::Value>,
    /// Error message recorded on failure, if any.
    pub error: Option<String>,
    /// When the step began running.
    pub started_at: Option<DateTime<Utc>>,
    /// When the step finished (set on both completion and failure;
    /// `compensate` does not touch it).
    pub completed_at: Option<DateTime<Utc>>,
}
40
41impl StepState {
42 pub fn new(name: impl Into<String>) -> Self {
44 Self {
45 name: name.into(),
46 status: StepStatus::Pending,
47 result: None,
48 error: None,
49 started_at: None,
50 completed_at: None,
51 }
52 }
53
54 pub fn start(&mut self) {
56 self.status = StepStatus::Running;
57 self.started_at = Some(Utc::now());
58 }
59
60 pub fn complete(&mut self, result: serde_json::Value) {
62 self.status = StepStatus::Completed;
63 self.result = Some(result);
64 self.completed_at = Some(Utc::now());
65 }
66
67 pub fn fail(&mut self, error: impl Into<String>) {
69 self.status = StepStatus::Failed;
70 self.error = Some(error.into());
71 self.completed_at = Some(Utc::now());
72 }
73
74 pub fn compensate(&mut self) {
76 self.status = StepStatus::Compensated;
77 }
78}
79
/// Per-run execution context handed to workflow code.
///
/// Tracks step state, compensation handlers, and suspension signalling for a
/// single workflow run, and exposes shared resources (Postgres pool, HTTP
/// client, environment provider).
pub struct WorkflowContext {
    /// Unique identifier of this workflow run.
    pub run_id: Uuid,
    /// Registered name of the workflow definition.
    pub workflow_name: String,
    /// Version of the workflow definition this run executes.
    pub version: u32,
    /// Wall-clock time the run originally started.
    pub started_at: DateTime<Utc>,
    /// Logical workflow time (seeded from the start time on creation/resume).
    workflow_time: DateTime<Utc>,
    /// Authentication context the run executes under.
    pub auth: AuthContext,
    /// Shared Postgres pool used for step/event persistence.
    db_pool: sqlx::PgPool,
    /// Shared HTTP client for outbound requests from steps.
    http_client: reqwest::Client,
    /// In-memory state of every known step, keyed by step name.
    step_states: Arc<RwLock<HashMap<String, StepState>>>,
    /// Names of completed steps in completion order (used for reverse-order compensation).
    completed_steps: Arc<RwLock<Vec<String>>>,
    /// Compensation callbacks registered per step name.
    compensation_handlers: Arc<RwLock<HashMap<String, CompensationHandler>>>,
    /// Channel used to notify the runtime when the workflow suspends.
    suspend_tx: Option<mpsc::Sender<SuspendReason>>,
    /// True when this context was rebuilt for a resumed run.
    is_resumed: bool,
    /// True when the current resume was triggered by a sleep expiring
    /// (makes pending sleep calls no-ops on replay).
    resumed_from_sleep: bool,
    /// Optional tenant scoping for multi-tenant deployments.
    tenant_id: Option<Uuid>,
    /// Environment variable provider (swappable, e.g. for tests).
    env_provider: Arc<dyn EnvProvider>,
}
115
116impl WorkflowContext {
117 pub fn new(
119 run_id: Uuid,
120 workflow_name: String,
121 version: u32,
122 db_pool: sqlx::PgPool,
123 http_client: reqwest::Client,
124 ) -> Self {
125 let now = Utc::now();
126 Self {
127 run_id,
128 workflow_name,
129 version,
130 started_at: now,
131 workflow_time: now,
132 auth: AuthContext::unauthenticated(),
133 db_pool,
134 http_client,
135 step_states: Arc::new(RwLock::new(HashMap::new())),
136 completed_steps: Arc::new(RwLock::new(Vec::new())),
137 compensation_handlers: Arc::new(RwLock::new(HashMap::new())),
138 suspend_tx: None,
139 is_resumed: false,
140 resumed_from_sleep: false,
141 tenant_id: None,
142 env_provider: Arc::new(RealEnvProvider::new()),
143 }
144 }
145
146 pub fn resumed(
148 run_id: Uuid,
149 workflow_name: String,
150 version: u32,
151 started_at: DateTime<Utc>,
152 db_pool: sqlx::PgPool,
153 http_client: reqwest::Client,
154 ) -> Self {
155 Self {
156 run_id,
157 workflow_name,
158 version,
159 started_at,
160 workflow_time: started_at,
161 auth: AuthContext::unauthenticated(),
162 db_pool,
163 http_client,
164 step_states: Arc::new(RwLock::new(HashMap::new())),
165 completed_steps: Arc::new(RwLock::new(Vec::new())),
166 compensation_handlers: Arc::new(RwLock::new(HashMap::new())),
167 suspend_tx: None,
168 is_resumed: true,
169 resumed_from_sleep: false,
170 tenant_id: None,
171 env_provider: Arc::new(RealEnvProvider::new()),
172 }
173 }
174
    /// Overrides the environment provider (primarily for tests).
    pub fn with_env_provider(mut self, provider: Arc<dyn EnvProvider>) -> Self {
        self.env_provider = provider;
        self
    }

    /// Flags that this resume was triggered by a sleep expiring, so pending
    /// `sleep`/`sleep_until` calls become no-ops on replay.
    pub fn with_resumed_from_sleep(mut self) -> Self {
        self.resumed_from_sleep = true;
        self
    }

    /// Attaches the channel used to notify the runtime of suspension.
    pub fn with_suspend_channel(mut self, tx: mpsc::Sender<SuspendReason>) -> Self {
        self.suspend_tx = Some(tx);
        self
    }

    /// Scopes this run to a tenant.
    pub fn with_tenant(mut self, tenant_id: Uuid) -> Self {
        self.tenant_id = Some(tenant_id);
        self
    }
198
    /// Tenant this run is scoped to, if any.
    pub fn tenant_id(&self) -> Option<Uuid> {
        self.tenant_id
    }

    /// True when this context was rebuilt for a resumed run.
    pub fn is_resumed(&self) -> bool {
        self.is_resumed
    }

    /// Logical workflow time (stable across resumes, unlike `Utc::now()`).
    pub fn workflow_time(&self) -> DateTime<Utc> {
        self.workflow_time
    }

    /// Shared Postgres pool.
    pub fn db(&self) -> &sqlx::PgPool {
        &self.db_pool
    }

    /// Shared HTTP client.
    pub fn http(&self) -> &reqwest::Client {
        &self.http_client
    }
218
    /// Sets the authentication context the workflow executes under.
    pub fn with_auth(mut self, auth: AuthContext) -> Self {
        self.auth = auth;
        self
    }
224
225 pub fn with_step_states(self, states: HashMap<String, StepState>) -> Self {
227 let completed: Vec<String> = states
228 .iter()
229 .filter(|(_, s)| s.status == StepStatus::Completed)
230 .map(|(name, _)| name.clone())
231 .collect();
232
233 *self.step_states.write().expect("workflow lock poisoned") = states;
234 *self
235 .completed_steps
236 .write()
237 .expect("workflow lock poisoned") = completed;
238 self
239 }
240
241 pub fn get_step_state(&self, name: &str) -> Option<StepState> {
242 self.step_states
243 .read()
244 .expect("workflow lock poisoned")
245 .get(name)
246 .cloned()
247 }
248
249 pub fn is_step_completed(&self, name: &str) -> bool {
250 self.step_states
251 .read()
252 .expect("workflow lock poisoned")
253 .get(name)
254 .map(|s| s.status == StepStatus::Completed)
255 .unwrap_or(false)
256 }
257
258 pub fn is_step_started(&self, name: &str) -> bool {
263 self.step_states
264 .read()
265 .expect("workflow lock poisoned")
266 .get(name)
267 .map(|s| s.status != StepStatus::Pending)
268 .unwrap_or(false)
269 }
270
271 pub fn get_step_result<T: serde::de::DeserializeOwned>(&self, name: &str) -> Option<T> {
272 self.step_states
273 .read()
274 .expect("workflow lock poisoned")
275 .get(name)
276 .and_then(|s| s.result.as_ref())
277 .and_then(|v| serde_json::from_value(v.clone()).ok())
278 }
279
    /// Marks a step as running in memory and persists the transition.
    ///
    /// Idempotent: if the step has already left `Pending` (e.g. on replay
    /// after a resume), this is a no-op. Persistence runs on a detached task,
    /// so a write failure is logged at WARN but never fails the workflow.
    pub fn record_step_start(&self, name: &str) {
        let mut states = self.step_states.write().expect("workflow lock poisoned");
        let state = states
            .entry(name.to_string())
            .or_insert_with(|| StepState::new(name));

        // Already started (or finished) — nothing to record.
        if state.status != StepStatus::Pending {
            return;
        }

        state.start();
        let state_clone = state.clone();
        // Release the lock before spawning so no guard is held across the task.
        drop(states);

        let pool = self.db_pool.clone();
        let run_id = self.run_id;
        let step_name = name.to_string();
        tokio::spawn(async move {
            let step_id = Uuid::new_v4();
            // ON CONFLICT DO NOTHING: a row may already exist from an earlier attempt.
            if let Err(e) = sqlx::query(
                r#"
            INSERT INTO forge_workflow_steps (id, workflow_run_id, step_name, status, started_at)
            VALUES ($1, $2, $3, $4, $5)
            ON CONFLICT (workflow_run_id, step_name) DO NOTHING
            "#,
            )
            .bind(step_id)
            .bind(run_id)
            .bind(&step_name)
            .bind(state_clone.status.as_str())
            .bind(state_clone.started_at)
            .execute(&pool)
            .await
            {
                tracing::warn!(
                    workflow_run_id = %run_id,
                    step = %step_name,
                    "Failed to persist step start: {}",
                    e
                );
            }
        });
    }
330
331 pub fn record_step_complete(&self, name: &str, result: serde_json::Value) {
334 let state_clone = self.update_step_state_complete(name, result);
335
336 if let Some(state) = state_clone {
338 let pool = self.db_pool.clone();
339 let run_id = self.run_id;
340 let step_name = name.to_string();
341 tokio::spawn(async move {
342 Self::persist_step_complete(&pool, run_id, &step_name, &state).await;
343 });
344 }
345 }
346
347 pub async fn record_step_complete_async(&self, name: &str, result: serde_json::Value) {
349 let state_clone = self.update_step_state_complete(name, result);
350
351 if let Some(state) = state_clone {
353 Self::persist_step_complete(&self.db_pool, self.run_id, name, &state).await;
354 }
355 }
356
357 fn update_step_state_complete(
359 &self,
360 name: &str,
361 result: serde_json::Value,
362 ) -> Option<StepState> {
363 let mut states = self.step_states.write().expect("workflow lock poisoned");
364 if let Some(state) = states.get_mut(name) {
365 state.complete(result.clone());
366 }
367 let state_clone = states.get(name).cloned();
368 drop(states);
369
370 let mut completed = self
371 .completed_steps
372 .write()
373 .expect("workflow lock poisoned");
374 if !completed.contains(&name.to_string()) {
375 completed.push(name.to_string());
376 }
377 drop(completed);
378
379 state_clone
380 }
381
    /// Upserts a completed step row (insert on first write, update on replay).
    ///
    /// Failures are logged at WARN and swallowed: step persistence is
    /// best-effort and must not fail the workflow itself.
    async fn persist_step_complete(
        pool: &sqlx::PgPool,
        run_id: Uuid,
        step_name: &str,
        state: &StepState,
    ) {
        if let Err(e) = sqlx::query(
            r#"
        INSERT INTO forge_workflow_steps (id, workflow_run_id, step_name, status, result, started_at, completed_at)
        VALUES (gen_random_uuid(), $1, $2, $3, $4, $5, $6)
        ON CONFLICT (workflow_run_id, step_name) DO UPDATE
        SET status = $3, result = $4, completed_at = $6
        "#,
        )
        .bind(run_id)
        .bind(step_name)
        .bind(state.status.as_str())
        .bind(&state.result)
        .bind(state.started_at)
        .bind(state.completed_at)
        .execute(pool)
        .await
        {
            tracing::warn!(
                workflow_run_id = %run_id,
                step = %step_name,
                "Failed to persist step completion: {}",
                e
            );
        }
    }
415
    /// Marks a step as failed in memory and persists the failure in the background.
    ///
    /// Persistence is an UPDATE only, so it relies on the row having been
    /// created by `record_step_start`. NOTE(review): a failure recorded for a
    /// step that was never started updates nothing in the database — confirm
    /// that is intended. Write failures are logged, never surfaced.
    pub fn record_step_failure(&self, name: &str, error: impl Into<String>) {
        let error_str = error.into();
        let mut states = self.step_states.write().expect("workflow lock poisoned");
        if let Some(state) = states.get_mut(name) {
            state.fail(error_str.clone());
        }
        let state_clone = states.get(name).cloned();
        // Release the lock before spawning the persistence task.
        drop(states);

        if let Some(state) = state_clone {
            let pool = self.db_pool.clone();
            let run_id = self.run_id;
            let step_name = name.to_string();
            tokio::spawn(async move {
                if let Err(e) = sqlx::query(
                    r#"
                UPDATE forge_workflow_steps
                SET status = $3, error = $4, completed_at = $5
                WHERE workflow_run_id = $1 AND step_name = $2
                "#,
                )
                .bind(run_id)
                .bind(&step_name)
                .bind(state.status.as_str())
                .bind(&state.error)
                .bind(state.completed_at)
                .execute(&pool)
                .await
                {
                    tracing::warn!(
                        workflow_run_id = %run_id,
                        step = %step_name,
                        "Failed to persist step failure: {}",
                        e
                    );
                }
            });
        }
    }
457
    /// Marks a step as compensated in memory and persists the status change.
    ///
    /// Like the other persistence paths, the UPDATE runs on a detached task
    /// and only touches rows that already exist; failures are logged at WARN.
    pub fn record_step_compensated(&self, name: &str) {
        let mut states = self.step_states.write().expect("workflow lock poisoned");
        if let Some(state) = states.get_mut(name) {
            state.compensate();
        }
        let state_clone = states.get(name).cloned();
        // Release the lock before spawning the persistence task.
        drop(states);

        if let Some(state) = state_clone {
            let pool = self.db_pool.clone();
            let run_id = self.run_id;
            let step_name = name.to_string();
            tokio::spawn(async move {
                if let Err(e) = sqlx::query(
                    r#"
                UPDATE forge_workflow_steps
                SET status = $3
                WHERE workflow_run_id = $1 AND step_name = $2
                "#,
                )
                .bind(run_id)
                .bind(&step_name)
                .bind(state.status.as_str())
                .execute(&pool)
                .await
                {
                    tracing::warn!(
                        workflow_run_id = %run_id,
                        step = %step_name,
                        "Failed to persist step compensation: {}",
                        e
                    );
                }
            });
        }
    }
496
497 pub fn completed_steps_reversed(&self) -> Vec<String> {
498 let completed = self.completed_steps.read().expect("workflow lock poisoned");
499 completed.iter().rev().cloned().collect()
500 }
501
502 pub fn all_step_states(&self) -> HashMap<String, StepState> {
503 self.step_states
504 .read()
505 .expect("workflow lock poisoned")
506 .clone()
507 }
508
509 pub fn elapsed(&self) -> chrono::Duration {
510 Utc::now() - self.started_at
511 }
512
513 pub fn register_compensation(&self, step_name: &str, handler: CompensationHandler) {
515 let mut handlers = self
516 .compensation_handlers
517 .write()
518 .expect("workflow lock poisoned");
519 handlers.insert(step_name.to_string(), handler);
520 }
521
522 pub fn get_compensation_handler(&self, step_name: &str) -> Option<CompensationHandler> {
523 self.compensation_handlers
524 .read()
525 .expect("workflow lock poisoned")
526 .get(step_name)
527 .cloned()
528 }
529
530 pub fn has_compensation(&self, step_name: &str) -> bool {
531 self.compensation_handlers
532 .read()
533 .expect("workflow lock poisoned")
534 .contains_key(step_name)
535 }
536
537 pub async fn run_compensation(&self) -> Vec<(String, bool)> {
540 let steps = self.completed_steps_reversed();
541 let mut results = Vec::new();
542
543 for step_name in steps {
544 let handler = self.get_compensation_handler(&step_name);
545 let result = self
546 .get_step_state(&step_name)
547 .and_then(|s| s.result.clone());
548
549 if let Some(handler) = handler {
550 let step_result = result.unwrap_or(serde_json::Value::Null);
551 match handler(step_result).await {
552 Ok(()) => {
553 self.record_step_compensated(&step_name);
554 results.push((step_name, true));
555 }
556 Err(e) => {
557 tracing::error!(step = %step_name, error = %e, "Compensation failed");
558 results.push((step_name, false));
559 }
560 }
561 } else {
562 self.record_step_compensated(&step_name);
564 results.push((step_name, true));
565 }
566 }
567
568 results
569 }
570
    /// Snapshot of all registered compensation handlers.
    ///
    /// Handlers are `Arc`s, so cloning the map only bumps reference counts.
    pub fn compensation_handlers(&self) -> HashMap<String, CompensationHandler> {
        self.compensation_handlers
            .read()
            .expect("workflow lock poisoned")
            .clone()
    }
577
578 pub async fn sleep(&self, duration: Duration) -> Result<()> {
589 if self.resumed_from_sleep {
591 return Ok(());
592 }
593
594 let wake_at = Utc::now() + chrono::Duration::from_std(duration).unwrap_or_default();
595 self.sleep_until(wake_at).await
596 }
597
    /// Durably sleeps until `wake_at`.
    ///
    /// No-op when resuming from an expired sleep or when `wake_at` is already
    /// in the past. Otherwise it persists the wake time and signals the
    /// runtime; `signal_suspend` always returns
    /// `Err(ForgeError::WorkflowSuspended)`, so the `?` propagates it and the
    /// trailing `Ok(())` is only reachable if that contract ever changes.
    pub async fn sleep_until(&self, wake_at: DateTime<Utc>) -> Result<()> {
        if self.resumed_from_sleep {
            return Ok(());
        }

        if wake_at <= Utc::now() {
            return Ok(());
        }

        self.set_wake_at(wake_at).await?;

        self.signal_suspend(SuspendReason::Sleep { wake_at })
            .await?;

        Ok(())
    }
628
629 pub async fn wait_for_event<T: DeserializeOwned>(
642 &self,
643 event_name: &str,
644 timeout: Option<Duration>,
645 ) -> Result<T> {
646 let correlation_id = self.run_id.to_string();
647
648 if let Some(event) = self.try_consume_event(event_name, &correlation_id).await? {
650 return serde_json::from_value(event.payload.unwrap_or_default())
651 .map_err(|e| ForgeError::Deserialization(e.to_string()));
652 }
653
654 let timeout_at =
656 timeout.map(|d| Utc::now() + chrono::Duration::from_std(d).unwrap_or_default());
657
658 self.set_waiting_for_event(event_name, timeout_at).await?;
660
661 self.signal_suspend(SuspendReason::WaitingEvent {
663 event_name: event_name.to_string(),
664 timeout: timeout_at,
665 })
666 .await?;
667
668 self.try_consume_event(event_name, &correlation_id)
670 .await?
671 .and_then(|e| e.payload)
672 .and_then(|p| serde_json::from_value(p).ok())
673 .ok_or_else(|| ForgeError::Timeout(format!("Event '{}' timed out", event_name)))
674 }
675
    /// Atomically claims the oldest unconsumed matching event, if any.
    ///
    /// The subselect uses `FOR UPDATE SKIP LOCKED` so concurrent consumers
    /// never block on — or double-claim — the same row; the claimed event is
    /// stamped with `consumed_at`/`consumed_by` (this run's id) and returned.
    #[allow(clippy::type_complexity)]
    async fn try_consume_event(
        &self,
        event_name: &str,
        correlation_id: &str,
    ) -> Result<Option<WorkflowEvent>> {
        // Tuple mirrors the RETURNING column list below.
        let result: Option<(
            Uuid,
            String,
            String,
            Option<serde_json::Value>,
            DateTime<Utc>,
        )> = sqlx::query_as(
            r#"
        UPDATE forge_workflow_events
        SET consumed_at = NOW(), consumed_by = $3
        WHERE id = (
            SELECT id FROM forge_workflow_events
            WHERE event_name = $1 AND correlation_id = $2 AND consumed_at IS NULL
            ORDER BY created_at ASC LIMIT 1
            FOR UPDATE SKIP LOCKED
        )
        RETURNING id, event_name, correlation_id, payload, created_at
        "#,
        )
        .bind(event_name)
        .bind(correlation_id)
        .bind(self.run_id)
        .fetch_optional(&self.db_pool)
        .await
        .map_err(|e| ForgeError::Database(e.to_string()))?;

        Ok(result.map(
            |(id, event_name, correlation_id, payload, created_at)| WorkflowEvent {
                id,
                event_name,
                correlation_id,
                payload,
                created_at,
            },
        ))
    }
719
    /// Persists this run as `waiting` with the given wake time so the
    /// scheduler can resume it when the sleep expires.
    async fn set_wake_at(&self, wake_at: DateTime<Utc>) -> Result<()> {
        sqlx::query(
            r#"
        UPDATE forge_workflow_runs
        SET status = 'waiting', suspended_at = NOW(), wake_at = $2
        WHERE id = $1
        "#,
        )
        .bind(self.run_id)
        .bind(wake_at)
        .execute(&self.db_pool)
        .await
        .map_err(|e| ForgeError::Database(e.to_string()))?;
        Ok(())
    }
736
    /// Persists this run as `waiting` on a named event (with an optional
    /// timeout deadline) so the scheduler can resume it on delivery or expiry.
    async fn set_waiting_for_event(
        &self,
        event_name: &str,
        timeout_at: Option<DateTime<Utc>>,
    ) -> Result<()> {
        sqlx::query(
            r#"
        UPDATE forge_workflow_runs
        SET status = 'waiting', suspended_at = NOW(), waiting_for_event = $2, event_timeout_at = $3
        WHERE id = $1
        "#,
        )
        .bind(self.run_id)
        .bind(event_name)
        .bind(timeout_at)
        .execute(&self.db_pool)
        .await
        .map_err(|e| ForgeError::Database(e.to_string()))?;
        Ok(())
    }
758
    /// Notifies the runtime (when a channel is attached) and then
    /// unconditionally returns `Err(ForgeError::WorkflowSuspended)`.
    ///
    /// Suspension is implemented as error propagation: callers use `?` so the
    /// whole workflow body unwinds when this fires.
    async fn signal_suspend(&self, reason: SuspendReason) -> Result<()> {
        if let Some(ref tx) = self.suspend_tx {
            tx.send(reason)
                .await
                .map_err(|_| ForgeError::Internal("Failed to signal suspension".into()))?;
        }
        Err(ForgeError::WorkflowSuspended)
    }
769
    /// Starts a builder for running several steps concurrently.
    pub fn parallel(&self) -> ParallelBuilder<'_> {
        ParallelBuilder::new(self)
    }

    /// Starts a builder for executing a named, durable step.
    ///
    /// `f` produces the step's future on demand; the result type must
    /// round-trip through serde so it can be persisted and replayed.
    pub fn step<T, F, Fut>(&self, name: impl Into<String>, f: F) -> super::StepRunner<'_, T>
    where
        T: serde::Serialize + serde::de::DeserializeOwned + Clone + Send + Sync + 'static,
        F: Fn() -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = crate::Result<T>> + Send + 'static,
    {
        super::StepRunner::new(self, name, f)
    }
846}
847
/// Exposes the configured environment provider through the shared `EnvAccess` trait.
impl EnvAccess for WorkflowContext {
    fn env_provider(&self) -> &dyn EnvProvider {
        self.env_provider.as_ref()
    }
}
853
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    // `connect_lazy` never touches the network, so these tests exercise only
    // the in-memory paths; the detached persistence tasks fail harmlessly.
    #[tokio::test]
    async fn test_workflow_context_creation() {
        let pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool");

        let run_id = Uuid::new_v4();
        let ctx = WorkflowContext::new(
            run_id,
            "test_workflow".to_string(),
            1,
            pool,
            reqwest::Client::new(),
        );

        assert_eq!(ctx.run_id, run_id);
        assert_eq!(ctx.workflow_name, "test_workflow");
        assert_eq!(ctx.version, 1);
    }

    // Verifies the start -> complete transition and result retrieval.
    #[tokio::test]
    async fn test_step_state_tracking() {
        let pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool");

        let ctx = WorkflowContext::new(
            Uuid::new_v4(),
            "test".to_string(),
            1,
            pool,
            reqwest::Client::new(),
        );

        ctx.record_step_start("step1");
        assert!(!ctx.is_step_completed("step1"));

        ctx.record_step_complete("step1", serde_json::json!({"result": "ok"}));
        assert!(ctx.is_step_completed("step1"));

        let result: Option<serde_json::Value> = ctx.get_step_result("step1");
        assert!(result.is_some());
    }

    // Pure in-memory lifecycle of a single StepState.
    #[test]
    fn test_step_state_transitions() {
        let mut state = StepState::new("test");
        assert_eq!(state.status, StepStatus::Pending);

        state.start();
        assert_eq!(state.status, StepStatus::Running);
        assert!(state.started_at.is_some());

        state.complete(serde_json::json!({}));
        assert_eq!(state.status, StepStatus::Completed);
        assert!(state.completed_at.is_some());
    }
}