1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::{Arc, RwLock};
5use std::time::Duration;
6
7use chrono::{DateTime, Utc};
8use serde::de::DeserializeOwned;
9use tokio::sync::mpsc;
10use uuid::Uuid;
11
12use super::parallel::ParallelBuilder;
13use super::step::StepStatus;
14use super::suspend::{SuspendReason, WorkflowEvent};
15use crate::env::{EnvAccess, EnvProvider, RealEnvProvider};
16use crate::function::AuthContext;
17use crate::http::CircuitBreakerClient;
18use crate::{ForgeError, Result};
19
/// Async rollback callback for a completed step.
///
/// Receives the step's stored JSON result and undoes the step's effects.
/// Boxed and `Arc`-wrapped so handlers are type-erased, cheaply cloneable,
/// and shareable across threads.
pub type CompensationHandler = Arc<
    dyn Fn(serde_json::Value) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync,
>;
24
/// In-memory record of a single workflow step's lifecycle within a run.
#[derive(Debug, Clone)]
pub struct StepState {
    /// Step name, unique within a workflow run.
    pub name: String,
    /// Current lifecycle status (pending / running / completed / failed / compensated).
    pub status: StepStatus,
    /// JSON result stored when the step completes.
    pub result: Option<serde_json::Value>,
    /// Error message stored when the step fails.
    pub error: Option<String>,
    /// When the step began running; `None` until started.
    pub started_at: Option<DateTime<Utc>>,
    /// When the step finished (completed or failed); `None` while in flight.
    pub completed_at: Option<DateTime<Utc>>,
}
41
42impl StepState {
43 pub fn new(name: impl Into<String>) -> Self {
45 Self {
46 name: name.into(),
47 status: StepStatus::Pending,
48 result: None,
49 error: None,
50 started_at: None,
51 completed_at: None,
52 }
53 }
54
55 pub fn start(&mut self) {
57 self.status = StepStatus::Running;
58 self.started_at = Some(Utc::now());
59 }
60
61 pub fn complete(&mut self, result: serde_json::Value) {
63 self.status = StepStatus::Completed;
64 self.result = Some(result);
65 self.completed_at = Some(Utc::now());
66 }
67
68 pub fn fail(&mut self, error: impl Into<String>) {
70 self.status = StepStatus::Failed;
71 self.error = Some(error.into());
72 self.completed_at = Some(Utc::now());
73 }
74
75 pub fn compensate(&mut self) {
77 self.status = StepStatus::Compensated;
78 }
79}
80
/// Per-run execution context handed to workflow code.
///
/// Tracks step state in memory (persisting it asynchronously), carries the
/// database pool and circuit-breaking HTTP client, and implements durable
/// suspension (sleep / event wait) by persisting run state and unwinding
/// with `ForgeError::WorkflowSuspended`.
pub struct WorkflowContext {
    /// Unique id of this workflow run.
    pub run_id: Uuid,
    /// Registered workflow name.
    pub workflow_name: String,
    /// Workflow definition version.
    pub version: u32,
    /// When the run originally started (preserved across resumes).
    pub started_at: DateTime<Utc>,
    /// Logical clock: fixed at construction (fresh run) or set to
    /// `started_at` on resume.
    workflow_time: DateTime<Utc>,
    /// Auth context of the caller that started the run.
    pub auth: AuthContext,
    /// Postgres pool used for both app queries and step persistence.
    db_pool: sqlx::PgPool,
    /// HTTP client with circuit breaking.
    http_client: CircuitBreakerClient,
    /// Optional per-context timeout applied by `http()`.
    http_timeout: Option<Duration>,
    /// Shared step-state map; DB writes happen asynchronously.
    step_states: Arc<RwLock<HashMap<String, StepState>>>,
    /// Completed step names in completion order (drives compensation order).
    completed_steps: Arc<RwLock<Vec<String>>>,
    /// Rollback handlers keyed by step name.
    compensation_handlers: Arc<RwLock<HashMap<String, CompensationHandler>>>,
    /// Notifies the engine when the run suspends (sleep / event wait).
    suspend_tx: Option<mpsc::Sender<SuspendReason>>,
    /// True when this context was rebuilt via `resumed`.
    is_resumed: bool,
    /// True when re-entering after a durable sleep; makes `sleep` a no-op.
    resumed_from_sleep: bool,
    /// Optional tenant scope for multi-tenant deployments.
    tenant_id: Option<Uuid>,
    /// Environment-variable provider (swappable for tests).
    env_provider: Arc<dyn EnvProvider>,
}
119
120impl WorkflowContext {
121 pub fn new(
123 run_id: Uuid,
124 workflow_name: String,
125 version: u32,
126 db_pool: sqlx::PgPool,
127 http_client: CircuitBreakerClient,
128 ) -> Self {
129 let now = Utc::now();
130 Self {
131 run_id,
132 workflow_name,
133 version,
134 started_at: now,
135 workflow_time: now,
136 auth: AuthContext::unauthenticated(),
137 db_pool,
138 http_client,
139 http_timeout: None,
140 step_states: Arc::new(RwLock::new(HashMap::new())),
141 completed_steps: Arc::new(RwLock::new(Vec::new())),
142 compensation_handlers: Arc::new(RwLock::new(HashMap::new())),
143 suspend_tx: None,
144 is_resumed: false,
145 resumed_from_sleep: false,
146 tenant_id: None,
147 env_provider: Arc::new(RealEnvProvider::new()),
148 }
149 }
150
151 pub fn resumed(
153 run_id: Uuid,
154 workflow_name: String,
155 version: u32,
156 started_at: DateTime<Utc>,
157 db_pool: sqlx::PgPool,
158 http_client: CircuitBreakerClient,
159 ) -> Self {
160 Self {
161 run_id,
162 workflow_name,
163 version,
164 started_at,
165 workflow_time: started_at,
166 auth: AuthContext::unauthenticated(),
167 db_pool,
168 http_client,
169 http_timeout: None,
170 step_states: Arc::new(RwLock::new(HashMap::new())),
171 completed_steps: Arc::new(RwLock::new(Vec::new())),
172 compensation_handlers: Arc::new(RwLock::new(HashMap::new())),
173 suspend_tx: None,
174 is_resumed: true,
175 resumed_from_sleep: false,
176 tenant_id: None,
177 env_provider: Arc::new(RealEnvProvider::new()),
178 }
179 }
180
    /// Replaces the environment provider (used to inject fakes in tests).
    pub fn with_env_provider(mut self, provider: Arc<dyn EnvProvider>) -> Self {
        self.env_provider = provider;
        self
    }

    /// Marks this context as re-entering after a durable sleep, which makes
    /// the next `sleep`/`sleep_until` call return immediately.
    pub fn with_resumed_from_sleep(mut self) -> Self {
        self.resumed_from_sleep = true;
        self
    }

    /// Attaches the channel used to notify the engine when the workflow
    /// suspends (sleep or event wait).
    pub fn with_suspend_channel(mut self, tx: mpsc::Sender<SuspendReason>) -> Self {
        self.suspend_tx = Some(tx);
        self
    }

    /// Scopes this run to a tenant.
    pub fn with_tenant(mut self, tenant_id: Uuid) -> Self {
        self.tenant_id = Some(tenant_id);
        self
    }

    /// The tenant this run is scoped to, if any.
    pub fn tenant_id(&self) -> Option<Uuid> {
        self.tenant_id
    }

    /// True when this context was rebuilt for a resumed run rather than a
    /// fresh start.
    pub fn is_resumed(&self) -> bool {
        self.is_resumed
    }

    /// The run's logical clock, fixed at construction/resume time.
    pub fn workflow_time(&self) -> DateTime<Utc> {
        self.workflow_time
    }
216
    /// Database handle wrapping the run's connection pool.
    pub fn db(&self) -> crate::function::ForgeDb {
        crate::function::ForgeDb::from_pool(&self.db_pool)
    }

    /// Checks out a dedicated connection from the pool.
    pub async fn conn(&self) -> sqlx::Result<crate::function::ForgeConn<'static>> {
        Ok(crate::function::ForgeConn::Pool(
            self.db_pool.acquire().await?,
        ))
    }

    /// HTTP client with circuit breaking, honoring any per-context timeout
    /// set via `set_http_timeout`.
    pub fn http(&self) -> crate::http::HttpClient {
        self.http_client.with_timeout(self.http_timeout)
    }

    /// The underlying reqwest client, bypassing the circuit-breaker wrapper.
    pub fn raw_http(&self) -> &reqwest::Client {
        self.http_client.inner()
    }

    /// Alias for `http()`, for call sites that want the circuit breaking
    /// made explicit.
    pub fn http_with_circuit_breaker(&self) -> crate::http::HttpClient {
        self.http()
    }

    /// Sets (or clears) the timeout applied by `http()`.
    pub fn set_http_timeout(&mut self, timeout: Option<Duration>) {
        self.http_timeout = timeout;
    }

    /// Attaches the caller's auth context.
    pub fn with_auth(mut self, auth: AuthContext) -> Self {
        self.auth = auth;
        self
    }
249
250 pub fn with_step_states(self, states: HashMap<String, StepState>) -> Self {
252 let completed: Vec<String> = states
253 .iter()
254 .filter(|(_, s)| s.status == StepStatus::Completed)
255 .map(|(name, _)| name.clone())
256 .collect();
257
258 *self.step_states.write().expect("workflow lock poisoned") = states;
259 *self
260 .completed_steps
261 .write()
262 .expect("workflow lock poisoned") = completed;
263 self
264 }
265
266 pub fn get_step_state(&self, name: &str) -> Option<StepState> {
267 self.step_states
268 .read()
269 .expect("workflow lock poisoned")
270 .get(name)
271 .cloned()
272 }
273
274 pub fn is_step_completed(&self, name: &str) -> bool {
275 self.step_states
276 .read()
277 .expect("workflow lock poisoned")
278 .get(name)
279 .map(|s| s.status == StepStatus::Completed)
280 .unwrap_or(false)
281 }
282
283 pub fn is_step_started(&self, name: &str) -> bool {
288 self.step_states
289 .read()
290 .expect("workflow lock poisoned")
291 .get(name)
292 .map(|s| s.status != StepStatus::Pending)
293 .unwrap_or(false)
294 }
295
296 pub fn get_step_result<T: serde::de::DeserializeOwned>(&self, name: &str) -> Option<T> {
297 self.step_states
298 .read()
299 .expect("workflow lock poisoned")
300 .get(name)
301 .and_then(|s| s.result.as_ref())
302 .and_then(|v| serde_json::from_value(v.clone()).ok())
303 }
304
    /// Marks a step as running in memory and persists the start row in the
    /// background.
    ///
    /// Idempotent: if the step already left `Pending` (e.g. on replay of a
    /// resumed run) nothing happens. Persistence is fire-and-forget — the
    /// spawned INSERT uses `ON CONFLICT DO NOTHING` so a replayed start never
    /// clobbers an existing row, and failures are only logged.
    pub fn record_step_start(&self, name: &str) {
        let mut states = self.step_states.write().expect("workflow lock poisoned");
        let state = states
            .entry(name.to_string())
            .or_insert_with(|| StepState::new(name));

        // Already started (or finished) on a previous pass — skip.
        if state.status != StepStatus::Pending {
            return;
        }

        state.start();
        let state_clone = state.clone();
        drop(states); // release the lock before spawning async work

        let pool = self.db_pool.clone();
        let run_id = self.run_id;
        let step_name = name.to_string();
        tokio::spawn(async move {
            let step_id = Uuid::new_v4();
            if let Err(e) = sqlx::query(
                r#"
                INSERT INTO forge_workflow_steps (id, workflow_run_id, step_name, status, started_at)
                VALUES ($1, $2, $3, $4, $5)
                ON CONFLICT (workflow_run_id, step_name) DO NOTHING
                "#,
            )
            .bind(step_id)
            .bind(run_id)
            .bind(&step_name)
            .bind(state_clone.status.as_str())
            .bind(state_clone.started_at)
            .execute(&pool)
            .await
            {
                // Best-effort persistence: log and move on.
                tracing::warn!(
                    workflow_run_id = %run_id,
                    step = %step_name,
                    "Failed to persist step start: {}",
                    e
                );
            }
        });
    }
355
356 pub fn record_step_complete(&self, name: &str, result: serde_json::Value) {
359 let state_clone = self.update_step_state_complete(name, result);
360
361 if let Some(state) = state_clone {
363 let pool = self.db_pool.clone();
364 let run_id = self.run_id;
365 let step_name = name.to_string();
366 tokio::spawn(async move {
367 Self::persist_step_complete(&pool, run_id, &step_name, &state).await;
368 });
369 }
370 }
371
372 pub async fn record_step_complete_async(&self, name: &str, result: serde_json::Value) {
374 let state_clone = self.update_step_state_complete(name, result);
375
376 if let Some(state) = state_clone {
378 Self::persist_step_complete(&self.db_pool, self.run_id, name, &state).await;
379 }
380 }
381
382 fn update_step_state_complete(
384 &self,
385 name: &str,
386 result: serde_json::Value,
387 ) -> Option<StepState> {
388 let mut states = self.step_states.write().expect("workflow lock poisoned");
389 if let Some(state) = states.get_mut(name) {
390 state.complete(result.clone());
391 }
392 let state_clone = states.get(name).cloned();
393 drop(states);
394
395 let mut completed = self
396 .completed_steps
397 .write()
398 .expect("workflow lock poisoned");
399 if !completed.contains(&name.to_string()) {
400 completed.push(name.to_string());
401 }
402 drop(completed);
403
404 state_clone
405 }
406
    /// Upserts a step's completion row (insert if the start was never
    /// persisted, otherwise update status/result/completed_at).
    ///
    /// Best-effort: failures are logged, never propagated.
    async fn persist_step_complete(
        pool: &sqlx::PgPool,
        run_id: Uuid,
        step_name: &str,
        state: &StepState,
    ) {
        if let Err(e) = sqlx::query(
            r#"
            INSERT INTO forge_workflow_steps (id, workflow_run_id, step_name, status, result, started_at, completed_at)
            VALUES (gen_random_uuid(), $1, $2, $3, $4, $5, $6)
            ON CONFLICT (workflow_run_id, step_name) DO UPDATE
            SET status = $3, result = $4, completed_at = $6
            "#,
        )
        .bind(run_id)
        .bind(step_name)
        .bind(state.status.as_str())
        .bind(&state.result)
        .bind(state.started_at)
        .bind(state.completed_at)
        .execute(pool)
        .await
        {
            tracing::warn!(
                workflow_run_id = %run_id,
                step = %step_name,
                "Failed to persist step completion: {}",
                e
            );
        }
    }
440
    /// Marks a step as failed in memory and persists the failure in the
    /// background (fire-and-forget; DB errors are only logged).
    ///
    /// NOTE(review): persistence is an UPDATE keyed on an existing row — if
    /// the background INSERT from `record_step_start` has not landed yet, the
    /// failure is silently not persisted. Confirm ordering guarantees.
    pub fn record_step_failure(&self, name: &str, error: impl Into<String>) {
        let error_str = error.into();
        let mut states = self.step_states.write().expect("workflow lock poisoned");
        if let Some(state) = states.get_mut(name) {
            state.fail(error_str.clone());
        }
        let state_clone = states.get(name).cloned();
        drop(states); // release the lock before spawning async work

        // Unknown step name: nothing to persist.
        if let Some(state) = state_clone {
            let pool = self.db_pool.clone();
            let run_id = self.run_id;
            let step_name = name.to_string();
            tokio::spawn(async move {
                if let Err(e) = sqlx::query(
                    r#"
                    UPDATE forge_workflow_steps
                    SET status = $3, error = $4, completed_at = $5
                    WHERE workflow_run_id = $1 AND step_name = $2
                    "#,
                )
                .bind(run_id)
                .bind(&step_name)
                .bind(state.status.as_str())
                .bind(&state.error)
                .bind(state.completed_at)
                .execute(&pool)
                .await
                {
                    tracing::warn!(
                        workflow_run_id = %run_id,
                        step = %step_name,
                        "Failed to persist step failure: {}",
                        e
                    );
                }
            });
        }
    }
482
    /// Marks a step as compensated in memory and persists the status change
    /// in the background (fire-and-forget; DB errors are only logged).
    pub fn record_step_compensated(&self, name: &str) {
        let mut states = self.step_states.write().expect("workflow lock poisoned");
        if let Some(state) = states.get_mut(name) {
            state.compensate();
        }
        let state_clone = states.get(name).cloned();
        drop(states); // release the lock before spawning async work

        // Unknown step name: nothing to persist.
        if let Some(state) = state_clone {
            let pool = self.db_pool.clone();
            let run_id = self.run_id;
            let step_name = name.to_string();
            tokio::spawn(async move {
                if let Err(e) = sqlx::query(
                    r#"
                    UPDATE forge_workflow_steps
                    SET status = $3
                    WHERE workflow_run_id = $1 AND step_name = $2
                    "#,
                )
                .bind(run_id)
                .bind(&step_name)
                .bind(state.status.as_str())
                .execute(&pool)
                .await
                {
                    tracing::warn!(
                        workflow_run_id = %run_id,
                        step = %step_name,
                        "Failed to persist step compensation: {}",
                        e
                    );
                }
            });
        }
    }
521
522 pub fn completed_steps_reversed(&self) -> Vec<String> {
523 let completed = self.completed_steps.read().expect("workflow lock poisoned");
524 completed.iter().rev().cloned().collect()
525 }
526
527 pub fn all_step_states(&self) -> HashMap<String, StepState> {
528 self.step_states
529 .read()
530 .expect("workflow lock poisoned")
531 .clone()
532 }
533
534 pub fn elapsed(&self) -> chrono::Duration {
535 Utc::now() - self.started_at
536 }
537
    /// Registers a compensation (rollback) handler for a step. The handler
    /// receives the step's stored JSON result when compensation runs.
    pub fn register_compensation(&self, step_name: &str, handler: CompensationHandler) {
        let mut handlers = self
            .compensation_handlers
            .write()
            .expect("workflow lock poisoned");
        handlers.insert(step_name.to_string(), handler);
    }

    /// Looks up the compensation handler registered for a step, if any.
    /// Cloning the handler is cheap — it is an `Arc`.
    pub fn get_compensation_handler(&self, step_name: &str) -> Option<CompensationHandler> {
        self.compensation_handlers
            .read()
            .expect("workflow lock poisoned")
            .get(step_name)
            .cloned()
    }

    /// True when a compensation handler is registered for the step.
    pub fn has_compensation(&self, step_name: &str) -> bool {
        self.compensation_handlers
            .read()
            .expect("workflow lock poisoned")
            .contains_key(step_name)
    }
561
562 pub async fn run_compensation(&self) -> Vec<(String, bool)> {
565 let steps = self.completed_steps_reversed();
566 let mut results = Vec::new();
567
568 for step_name in steps {
569 let handler = self.get_compensation_handler(&step_name);
570 let result = self
571 .get_step_state(&step_name)
572 .and_then(|s| s.result.clone());
573
574 if let Some(handler) = handler {
575 let step_result = result.unwrap_or(serde_json::Value::Null);
576 match handler(step_result).await {
577 Ok(()) => {
578 self.record_step_compensated(&step_name);
579 results.push((step_name, true));
580 }
581 Err(e) => {
582 tracing::error!(step = %step_name, error = %e, "Compensation failed");
583 results.push((step_name, false));
584 }
585 }
586 } else {
587 self.record_step_compensated(&step_name);
589 results.push((step_name, true));
590 }
591 }
592
593 results
594 }
595
    /// Snapshot of all registered compensation handlers. Cloning the map is
    /// cheap: each handler is an `Arc`.
    pub fn compensation_handlers(&self) -> HashMap<String, CompensationHandler> {
        self.compensation_handlers
            .read()
            .expect("workflow lock poisoned")
            .clone()
    }
602
603 pub async fn sleep(&self, duration: Duration) -> Result<()> {
614 if self.resumed_from_sleep {
616 return Ok(());
617 }
618
619 let wake_at = Utc::now() + chrono::Duration::from_std(duration).unwrap_or_default();
620 self.sleep_until(wake_at).await
621 }
622
    /// Durably sleeps until `wake_at`.
    ///
    /// Returns immediately when resuming from a sleep or when the wake time
    /// is already in the past. Otherwise persists the wake time and suspends
    /// the run via `signal_suspend` — which always returns
    /// `Err(WorkflowSuspended)`, so the trailing `Ok(())` is only reachable
    /// if that contract ever changes.
    pub async fn sleep_until(&self, wake_at: DateTime<Utc>) -> Result<()> {
        if self.resumed_from_sleep {
            return Ok(());
        }

        if wake_at <= Utc::now() {
            return Ok(());
        }

        // Persist the wake time so the scheduler can requeue the run.
        self.set_wake_at(wake_at).await?;

        self.signal_suspend(SuspendReason::Sleep { wake_at })
            .await?;

        Ok(())
    }
653
    /// Waits for an external event correlated to this run, durably suspending
    /// the workflow until it arrives.
    ///
    /// First tries to consume an already-delivered event (which is how a
    /// re-executed, resumed run picks the event up); if none exists, marks
    /// the run as waiting (with an optional timeout deadline) and suspends
    /// via `signal_suspend`.
    ///
    /// NOTE(review): `signal_suspend` currently always returns `Err`, so the
    /// tail after `.await?` never executes in the same pass. It also maps a
    /// payload that fails to deserialize to a `Timeout` error, unlike the
    /// fast path above which reports `Deserialization` — confirm intent.
    pub async fn wait_for_event<T: DeserializeOwned>(
        &self,
        event_name: &str,
        timeout: Option<Duration>,
    ) -> Result<T> {
        // Events are correlated to this run by its id.
        let correlation_id = self.run_id.to_string();

        // Fast path: event already delivered (including after a resume).
        if let Some(event) = self.try_consume_event(event_name, &correlation_id).await? {
            return serde_json::from_value(event.payload.unwrap_or_default())
                .map_err(|e| ForgeError::Deserialization(e.to_string()));
        }

        let timeout_at =
            timeout.map(|d| Utc::now() + chrono::Duration::from_std(d).unwrap_or_default());

        // Persist the wait so the engine can deliver or expire it, then
        // suspend the run.
        self.set_waiting_for_event(event_name, timeout_at).await?;

        self.signal_suspend(SuspendReason::WaitingEvent {
            event_name: event_name.to_string(),
            timeout: timeout_at,
        })
        .await?;

        self.try_consume_event(event_name, &correlation_id)
            .await?
            .and_then(|e| e.payload)
            .and_then(|p| serde_json::from_value(p).ok())
            .ok_or_else(|| ForgeError::Timeout(format!("Event '{}' timed out", event_name)))
    }
700
    /// Atomically claims (marks consumed) the oldest unconsumed event
    /// matching `event_name` and this run's correlation id, returning it if
    /// one existed.
    ///
    /// `FOR UPDATE SKIP LOCKED` lets concurrent consumers race safely: a row
    /// mid-claim by another transaction is skipped rather than blocked on, so
    /// each event is consumed exactly once.
    #[allow(clippy::type_complexity)]
    async fn try_consume_event(
        &self,
        event_name: &str,
        correlation_id: &str,
    ) -> Result<Option<WorkflowEvent>> {
        let result = sqlx::query!(
            r#"
            UPDATE forge_workflow_events
            SET consumed_at = NOW(), consumed_by = $3
            WHERE id = (
                SELECT id FROM forge_workflow_events
                WHERE event_name = $1 AND correlation_id = $2 AND consumed_at IS NULL
                ORDER BY created_at ASC LIMIT 1
                FOR UPDATE SKIP LOCKED
            )
            RETURNING id, event_name, correlation_id, payload, created_at
            "#,
            event_name,
            correlation_id,
            self.run_id
        )
        .fetch_optional(&self.db_pool)
        .await
        .map_err(|e| ForgeError::Database(e.to_string()))?;

        Ok(result.map(|row| WorkflowEvent {
            id: row.id,
            event_name: row.event_name,
            correlation_id: row.correlation_id,
            payload: row.payload,
            created_at: row.created_at,
        }))
    }
736
    /// Persists a 'waiting' status plus wake time so the scheduler can
    /// requeue the run once the timer fires.
    async fn set_wake_at(&self, wake_at: DateTime<Utc>) -> Result<()> {
        sqlx::query(
            r#"
            UPDATE forge_workflow_runs
            SET status = 'waiting', suspended_at = NOW(), wake_at = $2
            WHERE id = $1
            "#,
        )
        .bind(self.run_id)
        .bind(wake_at)
        .execute(&self.db_pool)
        .await
        .map_err(|e| ForgeError::Database(e.to_string()))?;
        Ok(())
    }
753
    /// Persists a 'waiting' status plus the event name (and optional timeout
    /// deadline) the run is blocked on, so the engine can wake it on delivery
    /// or expiry.
    async fn set_waiting_for_event(
        &self,
        event_name: &str,
        timeout_at: Option<DateTime<Utc>>,
    ) -> Result<()> {
        sqlx::query(
            r#"
            UPDATE forge_workflow_runs
            SET status = 'waiting', suspended_at = NOW(), waiting_for_event = $2, event_timeout_at = $3
            WHERE id = $1
            "#,
        )
        .bind(self.run_id)
        .bind(event_name)
        .bind(timeout_at)
        .execute(&self.db_pool)
        .await
        .map_err(|e| ForgeError::Database(e.to_string()))?;
        Ok(())
    }
775
    /// Notifies the engine (when a suspend channel is attached) and then
    /// unconditionally returns `ForgeError::WorkflowSuspended`.
    ///
    /// The error return IS the suspension mechanism: callers propagate it
    /// with `?`, unwinding the workflow body so the engine can checkpoint and
    /// later re-execute it. Code placed after a `signal_suspend(..)?` call
    /// therefore never runs in the same execution.
    async fn signal_suspend(&self, reason: SuspendReason) -> Result<()> {
        if let Some(ref tx) = self.suspend_tx {
            tx.send(reason)
                .await
                .map_err(|_| ForgeError::Internal("Failed to signal suspension".into()))?;
        }
        Err(ForgeError::WorkflowSuspended)
    }
786
    /// Starts a builder for running multiple steps concurrently.
    pub fn parallel(&self) -> ParallelBuilder<'_> {
        ParallelBuilder::new(self)
    }

    /// Begins a named step whose body is `f`; execution semantics (replay,
    /// retries, persistence) are delegated to `StepRunner`. The result type
    /// must round-trip through serde so it can be stored as the step result.
    pub fn step<T, F, Fut>(&self, name: impl Into<String>, f: F) -> super::StepRunner<'_, T>
    where
        T: serde::Serialize + serde::de::DeserializeOwned + Clone + Send + Sync + 'static,
        F: Fn() -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = crate::Result<T>> + Send + 'static,
    {
        super::StepRunner::new(self, name, f)
    }
863}
864
/// Gives workflow code environment-variable access through the injected
/// provider (swappable via `with_env_provider`, e.g. for tests).
impl EnvAccess for WorkflowContext {
    fn env_provider(&self) -> &dyn EnvProvider {
        self.env_provider.as_ref()
    }
}
870
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// Lazily-connected pool pointing at a nonexistent database — the
    /// in-memory paths under test never actually reach Postgres.
    fn mock_pool() -> sqlx::PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool")
    }

    /// Fresh context at version 1 backed by the mock pool.
    fn make_ctx(run_id: Uuid, name: &str) -> WorkflowContext {
        WorkflowContext::new(
            run_id,
            name.to_string(),
            1,
            mock_pool(),
            CircuitBreakerClient::with_defaults(reqwest::Client::new()),
        )
    }

    #[tokio::test]
    async fn test_workflow_context_creation() {
        let run_id = Uuid::new_v4();
        let ctx = make_ctx(run_id, "test_workflow");

        assert_eq!(ctx.run_id, run_id);
        assert_eq!(ctx.workflow_name, "test_workflow");
        assert_eq!(ctx.version, 1);
    }

    #[tokio::test]
    async fn test_step_state_tracking() {
        let ctx = make_ctx(Uuid::new_v4(), "test");

        ctx.record_step_start("step1");
        assert!(!ctx.is_step_completed("step1"));

        ctx.record_step_complete("step1", serde_json::json!({"result": "ok"}));
        assert!(ctx.is_step_completed("step1"));

        let result: Option<serde_json::Value> = ctx.get_step_result("step1");
        assert!(result.is_some());
    }

    #[test]
    fn test_step_state_transitions() {
        let mut state = StepState::new("test");
        assert_eq!(state.status, StepStatus::Pending);

        state.start();
        assert_eq!(state.status, StepStatus::Running);
        assert!(state.started_at.is_some());

        state.complete(serde_json::json!({}));
        assert_eq!(state.status, StepStatus::Completed);
        assert!(state.completed_at.is_some());
    }
}