use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, RwLock};
use std::time::Duration;

use chrono::{DateTime, Utc};
use serde::de::DeserializeOwned;
use tokio::sync::mpsc;
use uuid::Uuid;

use super::parallel::ParallelBuilder;
use super::step::StepStatus;
use super::suspend::{SuspendReason, WorkflowEvent};
use crate::env::{EnvAccess, EnvProvider, RealEnvProvider};
use crate::function::AuthContext;
use crate::{ForgeError, Result};

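/// Async handler invoked during rollback to compensate a completed step.
/// It receives the step's recorded result as JSON.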
pub type CompensationHandler = Arc<
    dyn Fn(serde_json::Value) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync,
>;

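/// In-memory record of a single step's execution: status, result or error,
/// and start/completion timestamps.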
#[derive(Debug, Clone)]
pub struct StepState {
    pub name: String,
    pub status: StepStatus,
    pub result: Option<serde_json::Value>,
    pub error: Option<String>,
    pub started_at: Option<DateTime<Utc>>,
    pub completed_at: Option<DateTime<Utc>>,
}

impl StepState {
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            status: StepStatus::Pending,
            result: None,
            error: None,
            started_at: None,
            completed_at: None,
        }
    }

    pub fn start(&mut self) {
        self.status = StepStatus::Running;
        self.started_at = Some(Utc::now());
    }

    pub fn complete(&mut self, result: serde_json::Value) {
        self.status = StepStatus::Completed;
        self.result = Some(result);
        self.completed_at = Some(Utc::now());
    }

    pub fn fail(&mut self, error: impl Into<String>) {
        self.status = StepStatus::Failed;
        self.error = Some(error.into());
        self.completed_at = Some(Utc::now());
    }

    pub fn compensate(&mut self) {
        self.status = StepStatus::Compensated;
    }
}

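/// Execution context for a single workflow run. It exposes the database pool
/// and HTTP client, tracks per-step state and compensation handlers, and
/// coordinates suspension for durable sleeps and event waits.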
pub struct WorkflowContext {
    pub run_id: Uuid,
    pub workflow_name: String,
    pub version: u32,
    pub started_at: DateTime<Utc>,
    workflow_time: DateTime<Utc>,
    pub auth: AuthContext,
    db_pool: sqlx::PgPool,
    http_client: reqwest::Client,
    step_states: Arc<RwLock<HashMap<String, StepState>>>,
    completed_steps: Arc<RwLock<Vec<String>>>,
    compensation_handlers: Arc<RwLock<HashMap<String, CompensationHandler>>>,
    suspend_tx: Option<mpsc::Sender<SuspendReason>>,
    is_resumed: bool,
    resumed_from_sleep: bool,
    tenant_id: Option<Uuid>,
    env_provider: Arc<dyn EnvProvider>,
}

impl WorkflowContext {
    pub fn new(
        run_id: Uuid,
        workflow_name: String,
        version: u32,
        db_pool: sqlx::PgPool,
        http_client: reqwest::Client,
    ) -> Self {
        let now = Utc::now();
        Self {
            run_id,
            workflow_name,
            version,
            started_at: now,
            workflow_time: now,
            auth: AuthContext::unauthenticated(),
            db_pool,
            http_client,
            step_states: Arc::new(RwLock::new(HashMap::new())),
            completed_steps: Arc::new(RwLock::new(Vec::new())),
            compensation_handlers: Arc::new(RwLock::new(HashMap::new())),
            suspend_tx: None,
            is_resumed: false,
            resumed_from_sleep: false,
            tenant_id: None,
            env_provider: Arc::new(RealEnvProvider::new()),
        }
    }

    pub fn resumed(
        run_id: Uuid,
        workflow_name: String,
        version: u32,
        started_at: DateTime<Utc>,
        db_pool: sqlx::PgPool,
        http_client: reqwest::Client,
    ) -> Self {
        Self {
            run_id,
            workflow_name,
            version,
            started_at,
            workflow_time: started_at,
            auth: AuthContext::unauthenticated(),
            db_pool,
            http_client,
            step_states: Arc::new(RwLock::new(HashMap::new())),
            completed_steps: Arc::new(RwLock::new(Vec::new())),
            compensation_handlers: Arc::new(RwLock::new(HashMap::new())),
            suspend_tx: None,
            is_resumed: true,
            resumed_from_sleep: false,
            tenant_id: None,
            env_provider: Arc::new(RealEnvProvider::new()),
        }
    }

    pub fn with_env_provider(mut self, provider: Arc<dyn EnvProvider>) -> Self {
        self.env_provider = provider;
        self
    }

    pub fn with_resumed_from_sleep(mut self) -> Self {
        self.resumed_from_sleep = true;
        self
    }

    pub fn with_suspend_channel(mut self, tx: mpsc::Sender<SuspendReason>) -> Self {
        self.suspend_tx = Some(tx);
        self
    }

    pub fn with_tenant(mut self, tenant_id: Uuid) -> Self {
        self.tenant_id = Some(tenant_id);
        self
    }

    pub fn tenant_id(&self) -> Option<Uuid> {
        self.tenant_id
    }

    pub fn is_resumed(&self) -> bool {
        self.is_resumed
    }

    pub fn workflow_time(&self) -> DateTime<Utc> {
        self.workflow_time
    }

    pub fn db(&self) -> &sqlx::PgPool {
        &self.db_pool
    }

    pub fn http(&self) -> &reqwest::Client {
        &self.http_client
    }

    pub fn with_auth(mut self, auth: AuthContext) -> Self {
        self.auth = auth;
        self
    }

    pub fn with_step_states(self, states: HashMap<String, StepState>) -> Self {
        let completed: Vec<String> = states
            .iter()
            .filter(|(_, s)| s.status == StepStatus::Completed)
            .map(|(name, _)| name.clone())
            .collect();

        *self.step_states.write().unwrap() = states;
        *self.completed_steps.write().unwrap() = completed;
        self
    }

    pub fn get_step_state(&self, name: &str) -> Option<StepState> {
        self.step_states.read().unwrap().get(name).cloned()
    }

    pub fn is_step_completed(&self, name: &str) -> bool {
        self.step_states
            .read()
            .unwrap()
            .get(name)
            .map(|s| s.status == StepStatus::Completed)
            .unwrap_or(false)
    }

    pub fn is_step_started(&self, name: &str) -> bool {
        self.step_states
            .read()
            .unwrap()
            .get(name)
            .map(|s| s.status != StepStatus::Pending)
            .unwrap_or(false)
    }

    pub fn get_step_result<T: DeserializeOwned>(&self, name: &str) -> Option<T> {
        self.step_states
            .read()
            .unwrap()
            .get(name)
            .and_then(|s| s.result.as_ref())
            .and_then(|v| serde_json::from_value(v.clone()).ok())
    }

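    /// Marks a step as running and persists the transition in the background.
    /// Steps that are already past `Pending` (e.g. on replay) are left untouched.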
    pub fn record_step_start(&self, name: &str) {
        let mut states = self.step_states.write().unwrap();
        let state = states
            .entry(name.to_string())
            .or_insert_with(|| StepState::new(name));

        // Already started (e.g. a replayed step on resume): leave the in-memory
        // state and the persisted row untouched.
        if state.status != StepStatus::Pending {
            return;
        }

        state.start();
        let state_clone = state.clone();
        drop(states);

        let pool = self.db_pool.clone();
        let run_id = self.run_id;
        let step_name = name.to_string();
        tokio::spawn(async move {
            let step_id = Uuid::new_v4();
            if let Err(e) = sqlx::query(
                r#"
                INSERT INTO forge_workflow_steps (id, workflow_run_id, step_name, status, started_at)
                VALUES ($1, $2, $3, $4, $5)
                ON CONFLICT (workflow_run_id, step_name) DO NOTHING
                "#,
            )
            .bind(step_id)
            .bind(run_id)
            .bind(&step_name)
            .bind(state_clone.status.as_str())
            .bind(state_clone.started_at)
            .execute(&pool)
            .await
            {
                tracing::warn!(
                    workflow_run_id = %run_id,
                    step = %step_name,
                    "Failed to persist step start: {}",
                    e
                );
            }
        });
    }

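    /// Marks a step as completed with its result. Persistence is spawned as a
    /// fire-and-forget task; use `record_step_complete_async` to await it.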
    pub fn record_step_complete(&self, name: &str, result: serde_json::Value) {
        let state_clone = self.update_step_state_complete(name, result);

        if let Some(state) = state_clone {
            let pool = self.db_pool.clone();
            let run_id = self.run_id;
            let step_name = name.to_string();
            tokio::spawn(async move {
                Self::persist_step_complete(&pool, run_id, &step_name, &state).await;
            });
        }
    }

    pub async fn record_step_complete_async(&self, name: &str, result: serde_json::Value) {
        let state_clone = self.update_step_state_complete(name, result);

        if let Some(state) = state_clone {
            Self::persist_step_complete(&self.db_pool, self.run_id, name, &state).await;
        }
    }

    fn update_step_state_complete(
        &self,
        name: &str,
        result: serde_json::Value,
    ) -> Option<StepState> {
        let mut states = self.step_states.write().unwrap();
        if let Some(state) = states.get_mut(name) {
            state.complete(result.clone());
        }
        let state_clone = states.get(name).cloned();
        drop(states);

        let mut completed = self.completed_steps.write().unwrap();
        if !completed.contains(&name.to_string()) {
            completed.push(name.to_string());
        }
        drop(completed);

        state_clone
    }

    async fn persist_step_complete(
        pool: &sqlx::PgPool,
        run_id: Uuid,
        step_name: &str,
        state: &StepState,
    ) {
        if let Err(e) = sqlx::query(
            r#"
            INSERT INTO forge_workflow_steps (id, workflow_run_id, step_name, status, result, started_at, completed_at)
            VALUES (gen_random_uuid(), $1, $2, $3, $4, $5, $6)
            ON CONFLICT (workflow_run_id, step_name) DO UPDATE
            SET status = $3, result = $4, completed_at = $6
            "#,
        )
        .bind(run_id)
        .bind(step_name)
        .bind(state.status.as_str())
        .bind(&state.result)
        .bind(state.started_at)
        .bind(state.completed_at)
        .execute(pool)
        .await
        {
            tracing::warn!(
                workflow_run_id = %run_id,
                step = %step_name,
                "Failed to persist step completion: {}",
                e
            );
        }
    }

    pub fn record_step_failure(&self, name: &str, error: impl Into<String>) {
        let error_str = error.into();
        let mut states = self.step_states.write().unwrap();
        if let Some(state) = states.get_mut(name) {
            state.fail(error_str.clone());
        }
        let state_clone = states.get(name).cloned();
        drop(states);

        if let Some(state) = state_clone {
            let pool = self.db_pool.clone();
            let run_id = self.run_id;
            let step_name = name.to_string();
            tokio::spawn(async move {
                if let Err(e) = sqlx::query(
                    r#"
                    UPDATE forge_workflow_steps
                    SET status = $3, error = $4, completed_at = $5
                    WHERE workflow_run_id = $1 AND step_name = $2
                    "#,
                )
                .bind(run_id)
                .bind(&step_name)
                .bind(state.status.as_str())
                .bind(&state.error)
                .bind(state.completed_at)
                .execute(&pool)
                .await
                {
                    tracing::warn!(
                        workflow_run_id = %run_id,
                        step = %step_name,
                        "Failed to persist step failure: {}",
                        e
                    );
                }
            });
        }
    }

    pub fn record_step_compensated(&self, name: &str) {
        let mut states = self.step_states.write().unwrap();
        if let Some(state) = states.get_mut(name) {
            state.compensate();
        }
        let state_clone = states.get(name).cloned();
        drop(states);

        if let Some(state) = state_clone {
            let pool = self.db_pool.clone();
            let run_id = self.run_id;
            let step_name = name.to_string();
            tokio::spawn(async move {
                if let Err(e) = sqlx::query(
                    r#"
                    UPDATE forge_workflow_steps
                    SET status = $3
                    WHERE workflow_run_id = $1 AND step_name = $2
                    "#,
                )
                .bind(run_id)
                .bind(&step_name)
                .bind(state.status.as_str())
                .execute(&pool)
                .await
                {
                    tracing::warn!(
                        workflow_run_id = %run_id,
                        step = %step_name,
                        "Failed to persist step compensation: {}",
                        e
                    );
                }
            });
        }
    }

    pub fn completed_steps_reversed(&self) -> Vec<String> {
        let completed = self.completed_steps.read().unwrap();
        completed.iter().rev().cloned().collect()
    }

    pub fn all_step_states(&self) -> HashMap<String, StepState> {
        self.step_states.read().unwrap().clone()
    }

    pub fn elapsed(&self) -> chrono::Duration {
        Utc::now() - self.started_at
    }

    pub fn register_compensation(&self, step_name: &str, handler: CompensationHandler) {
        let mut handlers = self.compensation_handlers.write().unwrap();
        handlers.insert(step_name.to_string(), handler);
    }

    pub fn get_compensation_handler(&self, step_name: &str) -> Option<CompensationHandler> {
        self.compensation_handlers
            .read()
            .unwrap()
            .get(step_name)
            .cloned()
    }

    pub fn has_compensation(&self, step_name: &str) -> bool {
        self.compensation_handlers
            .read()
            .unwrap()
            .contains_key(step_name)
    }

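    /// Runs compensation for completed steps in reverse completion order.
    /// Steps without a registered handler are simply marked as compensated.
    /// Returns each step name paired with whether its compensation succeeded.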
    pub async fn run_compensation(&self) -> Vec<(String, bool)> {
        let steps = self.completed_steps_reversed();
        let mut results = Vec::new();

        for step_name in steps {
            let handler = self.get_compensation_handler(&step_name);
            let result = self
                .get_step_state(&step_name)
                .and_then(|s| s.result.clone());

            if let Some(handler) = handler {
                let step_result = result.unwrap_or(serde_json::Value::Null);
                match handler(step_result).await {
                    Ok(()) => {
                        self.record_step_compensated(&step_name);
                        results.push((step_name, true));
                    }
                    Err(e) => {
                        tracing::error!(step = %step_name, error = %e, "Compensation failed");
                        results.push((step_name, false));
                    }
                }
            } else {
                self.record_step_compensated(&step_name);
                results.push((step_name, true));
            }
        }

        results
    }

    pub fn compensation_handlers(&self) -> HashMap<String, CompensationHandler> {
        self.compensation_handlers.read().unwrap().clone()
    }

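    /// Durably sleeps for `duration`.
    ///
    /// On the first pass this records a wake time, signals the runner, and
    /// unwinds with `ForgeError::WorkflowSuspended`. When the run is resumed
    /// from sleep the call returns `Ok(())` immediately.
    ///
    /// A minimal usage sketch (the 24-hour delay is illustrative):
    ///
    /// ```ignore
    /// ctx.sleep(Duration::from_secs(24 * 60 * 60)).await?;
    /// // Execution only proceeds past this point after the run is resumed.
    /// ```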
    pub async fn sleep(&self, duration: Duration) -> Result<()> {
        if self.resumed_from_sleep {
            return Ok(());
        }

        let wake_at = Utc::now() + chrono::Duration::from_std(duration).unwrap_or_default();
        self.sleep_until(wake_at).await
    }

592
593 pub async fn sleep_until(&self, wake_at: DateTime<Utc>) -> Result<()> {
604 if self.resumed_from_sleep {
606 return Ok(());
607 }
608
609 if wake_at <= Utc::now() {
611 return Ok(());
612 }
613
614 self.set_wake_at(wake_at).await?;
616
617 self.signal_suspend(SuspendReason::Sleep { wake_at })
619 .await?;
620
621 Ok(())
622 }
623
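    /// Waits for an external event correlated with this run.
    ///
    /// If a matching unconsumed event already exists it is consumed and its
    /// payload deserialized into `T`. Otherwise the run is marked as waiting
    /// and suspended until the event arrives or the optional timeout elapses.
    ///
    /// A minimal usage sketch (the event name and timeout are illustrative):
    ///
    /// ```ignore
    /// let payload: serde_json::Value = ctx
    ///     .wait_for_event("payment.confirmed", Some(Duration::from_secs(3600)))
    ///     .await?;
    /// ```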
    pub async fn wait_for_event<T: DeserializeOwned>(
        &self,
        event_name: &str,
        timeout: Option<Duration>,
    ) -> Result<T> {
        let correlation_id = self.run_id.to_string();

        if let Some(event) = self.try_consume_event(event_name, &correlation_id).await? {
            return serde_json::from_value(event.payload.unwrap_or_default())
                .map_err(|e| ForgeError::Deserialization(e.to_string()));
        }

        let timeout_at =
            timeout.map(|d| Utc::now() + chrono::Duration::from_std(d).unwrap_or_default());

        self.set_waiting_for_event(event_name, timeout_at).await?;

        self.signal_suspend(SuspendReason::WaitingEvent {
            event_name: event_name.to_string(),
            timeout: timeout_at,
        })
        .await?;

        self.try_consume_event(event_name, &correlation_id)
            .await?
            .and_then(|e| e.payload)
            .and_then(|p| serde_json::from_value(p).ok())
            .ok_or_else(|| ForgeError::Timeout(format!("Event '{}' timed out", event_name)))
    }

    #[allow(clippy::type_complexity)]
    async fn try_consume_event(
        &self,
        event_name: &str,
        correlation_id: &str,
    ) -> Result<Option<WorkflowEvent>> {
        let result: Option<(
            Uuid,
            String,
            String,
            Option<serde_json::Value>,
            DateTime<Utc>,
        )> = sqlx::query_as(
            r#"
            UPDATE forge_workflow_events
            SET consumed_at = NOW(), consumed_by = $3
            WHERE id = (
                SELECT id FROM forge_workflow_events
                WHERE event_name = $1 AND correlation_id = $2 AND consumed_at IS NULL
                ORDER BY created_at ASC LIMIT 1
                FOR UPDATE SKIP LOCKED
            )
            RETURNING id, event_name, correlation_id, payload, created_at
            "#,
        )
        .bind(event_name)
        .bind(correlation_id)
        .bind(self.run_id)
        .fetch_optional(&self.db_pool)
        .await
        .map_err(|e| ForgeError::Database(e.to_string()))?;

        Ok(result.map(
            |(id, event_name, correlation_id, payload, created_at)| WorkflowEvent {
                id,
                event_name,
                correlation_id,
                payload,
                created_at,
            },
        ))
    }

    async fn set_wake_at(&self, wake_at: DateTime<Utc>) -> Result<()> {
        sqlx::query(
            r#"
            UPDATE forge_workflow_runs
            SET status = 'waiting', suspended_at = NOW(), wake_at = $2
            WHERE id = $1
            "#,
        )
        .bind(self.run_id)
        .bind(wake_at)
        .execute(&self.db_pool)
        .await
        .map_err(|e| ForgeError::Database(e.to_string()))?;
        Ok(())
    }

    async fn set_waiting_for_event(
        &self,
        event_name: &str,
        timeout_at: Option<DateTime<Utc>>,
    ) -> Result<()> {
        sqlx::query(
            r#"
            UPDATE forge_workflow_runs
            SET status = 'waiting', suspended_at = NOW(), waiting_for_event = $2, event_timeout_at = $3
            WHERE id = $1
            "#,
        )
        .bind(self.run_id)
        .bind(event_name)
        .bind(timeout_at)
        .execute(&self.db_pool)
        .await
        .map_err(|e| ForgeError::Database(e.to_string()))?;
        Ok(())
    }

    async fn signal_suspend(&self, reason: SuspendReason) -> Result<()> {
        if let Some(ref tx) = self.suspend_tx {
            tx.send(reason)
                .await
                .map_err(|_| ForgeError::Internal("Failed to signal suspension".into()))?;
        }
        // Always unwind with `WorkflowSuspended` so the workflow body stops
        // executing at the suspension point.
        Err(ForgeError::WorkflowSuspended)
    }

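    /// Returns a `ParallelBuilder` bound to this context for running work in
    /// parallel (see `super::parallel`).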
    pub fn parallel(&self) -> ParallelBuilder<'_> {
        ParallelBuilder::new(self)
    }

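    /// Creates a [`StepRunner`](super::StepRunner) for a named, durable step
    /// whose body is the async closure `f`.
    ///
    /// A minimal sketch of constructing a runner (the step name is
    /// illustrative; how the runner is awaited depends on `StepRunner`'s API,
    /// which is defined elsewhere):
    ///
    /// ```ignore
    /// let runner = ctx.step("fetch_user", || async { Ok(42u32) });
    /// ```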
    pub fn step<T, F, Fut>(&self, name: impl Into<String>, f: F) -> super::StepRunner<'_, T>
    where
        T: serde::Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
        F: Fn() -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = crate::Result<T>> + Send + 'static,
    {
        super::StepRunner::new(self, name, f)
    }
}

impl EnvAccess for WorkflowContext {
    fn env_provider(&self) -> &dyn EnvProvider {
        self.env_provider.as_ref()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_workflow_context_creation() {
        let pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool");

        let run_id = Uuid::new_v4();
        let ctx = WorkflowContext::new(
            run_id,
            "test_workflow".to_string(),
            1,
            pool,
            reqwest::Client::new(),
        );

        assert_eq!(ctx.run_id, run_id);
        assert_eq!(ctx.workflow_name, "test_workflow");
        assert_eq!(ctx.version, 1);
    }

    #[tokio::test]
    async fn test_step_state_tracking() {
        let pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool");

        let ctx = WorkflowContext::new(
            Uuid::new_v4(),
            "test".to_string(),
            1,
            pool,
            reqwest::Client::new(),
        );

        ctx.record_step_start("step1");
        assert!(!ctx.is_step_completed("step1"));

        ctx.record_step_complete("step1", serde_json::json!({"result": "ok"}));
        assert!(ctx.is_step_completed("step1"));

        let result: Option<serde_json::Value> = ctx.get_step_result("step1");
        assert!(result.is_some());
    }

    #[test]
    fn test_step_state_transitions() {
        let mut state = StepState::new("test");
        assert_eq!(state.status, StepStatus::Pending);

        state.start();
        assert_eq!(state.status, StepStatus::Running);
        assert!(state.started_at.is_some());

        state.complete(serde_json::json!({}));
        assert_eq!(state.status, StepStatus::Completed);
        assert!(state.completed_at.is_some());
    }
}