1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::{Arc, RwLock};
5use std::time::Duration;
6
7use chrono::{DateTime, Utc};
8use serde::de::DeserializeOwned;
9use tokio::sync::mpsc;
10use uuid::Uuid;
11
12use super::parallel::ParallelBuilder;
13use super::step::StepStatus;
14use super::suspend::{SuspendReason, WorkflowEvent};
15use crate::env::{EnvAccess, EnvProvider, RealEnvProvider};
16use crate::function::AuthContext;
17use crate::http::CircuitBreakerClient;
18use crate::{ForgeError, Result};
19
/// Shared async callback invoked to undo a completed step during rollback.
///
/// The handler receives the step's recorded JSON result (or `Null` when the
/// step produced none — see `run_compensation`) and resolves to `Ok(())` when
/// the compensation succeeded.
pub type CompensationHandler = Arc<
    dyn Fn(serde_json::Value) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync,
>;
24
/// In-memory record of a single workflow step's execution.
#[derive(Debug, Clone)]
pub struct StepState {
    /// Step name, unique within one workflow run.
    pub name: String,
    /// Current lifecycle status (starts at `Pending`; moves to `Running`,
    /// then `Completed`/`Failed`, and possibly `Compensated` on rollback).
    pub status: StepStatus,
    /// JSON result recorded when the step completed, if any.
    pub result: Option<serde_json::Value>,
    /// Error message recorded when the step failed, if any.
    pub error: Option<String>,
    /// When the step began running (`None` while still pending).
    pub started_at: Option<DateTime<Utc>>,
    /// When the step finished, either completed or failed.
    pub completed_at: Option<DateTime<Utc>>,
}
41
42impl StepState {
43 pub fn new(name: impl Into<String>) -> Self {
45 Self {
46 name: name.into(),
47 status: StepStatus::Pending,
48 result: None,
49 error: None,
50 started_at: None,
51 completed_at: None,
52 }
53 }
54
55 pub fn start(&mut self) {
57 self.status = StepStatus::Running;
58 self.started_at = Some(Utc::now());
59 }
60
61 pub fn complete(&mut self, result: serde_json::Value) {
63 self.status = StepStatus::Completed;
64 self.result = Some(result);
65 self.completed_at = Some(Utc::now());
66 }
67
68 pub fn fail(&mut self, error: impl Into<String>) {
70 self.status = StepStatus::Failed;
71 self.error = Some(error.into());
72 self.completed_at = Some(Utc::now());
73 }
74
75 pub fn compensate(&mut self) {
77 self.status = StepStatus::Compensated;
78 }
79}
80
/// Per-run execution context handed to workflow bodies.
///
/// Tracks step state in memory (persisting asynchronously to Postgres),
/// registers compensation handlers for rollback, and coordinates suspension
/// for sleeps and event waits.
pub struct WorkflowContext {
    /// Unique id of this workflow run.
    pub run_id: Uuid,
    /// Name of the workflow definition being executed.
    pub workflow_name: String,
    /// Wall-clock time the run originally started.
    pub started_at: DateTime<Utc>,
    // Deterministic "now" for workflow logic: pinned to `started_at` when the
    // context is built via `resumed`.
    workflow_time: DateTime<Utc>,
    /// Authentication context the run executes under.
    pub auth: AuthContext,
    // Postgres pool used for step persistence and workflow-event queries.
    db_pool: sqlx::PgPool,
    // Circuit-breaker-wrapped HTTP client.
    http_client: CircuitBreakerClient,
    // Optional per-request timeout applied by `http()`.
    http_timeout: Option<Duration>,
    // In-memory step states keyed by step name.
    step_states: Arc<RwLock<HashMap<String, StepState>>>,
    // Step names in completion order; consumed in reverse for compensation.
    completed_steps: Arc<RwLock<Vec<String>>>,
    // Registered undo handlers keyed by step name.
    compensation_handlers: Arc<RwLock<HashMap<String, CompensationHandler>>>,
    // Channel for notifying the executor that the run is suspending.
    suspend_tx: Option<mpsc::Sender<SuspendReason>>,
    // True when constructed via `resumed`.
    is_resumed: bool,
    // True when woken from a sleep suspension; makes the next sleep a no-op.
    resumed_from_sleep: bool,
    // Optional tenant scoping for multi-tenant deployments.
    tenant_id: Option<Uuid>,
    // Source of environment variables (swappable for tests).
    env_provider: Arc<dyn EnvProvider>,
}
117
118impl WorkflowContext {
119 pub fn new(
121 run_id: Uuid,
122 workflow_name: String,
123 db_pool: sqlx::PgPool,
124 http_client: CircuitBreakerClient,
125 ) -> Self {
126 let now = Utc::now();
127 Self {
128 run_id,
129 workflow_name,
130 started_at: now,
131 workflow_time: now,
132 auth: AuthContext::unauthenticated(),
133 db_pool,
134 http_client,
135 http_timeout: None,
136 step_states: Arc::new(RwLock::new(HashMap::new())),
137 completed_steps: Arc::new(RwLock::new(Vec::new())),
138 compensation_handlers: Arc::new(RwLock::new(HashMap::new())),
139 suspend_tx: None,
140 is_resumed: false,
141 resumed_from_sleep: false,
142 tenant_id: None,
143 env_provider: Arc::new(RealEnvProvider::new()),
144 }
145 }
146
147 pub fn resumed(
149 run_id: Uuid,
150 workflow_name: String,
151 started_at: DateTime<Utc>,
152 db_pool: sqlx::PgPool,
153 http_client: CircuitBreakerClient,
154 ) -> Self {
155 Self {
156 run_id,
157 workflow_name,
158 started_at,
159 workflow_time: started_at,
160 auth: AuthContext::unauthenticated(),
161 db_pool,
162 http_client,
163 http_timeout: None,
164 step_states: Arc::new(RwLock::new(HashMap::new())),
165 completed_steps: Arc::new(RwLock::new(Vec::new())),
166 compensation_handlers: Arc::new(RwLock::new(HashMap::new())),
167 suspend_tx: None,
168 is_resumed: true,
169 resumed_from_sleep: false,
170 tenant_id: None,
171 env_provider: Arc::new(RealEnvProvider::new()),
172 }
173 }
174
175 pub fn with_env_provider(mut self, provider: Arc<dyn EnvProvider>) -> Self {
177 self.env_provider = provider;
178 self
179 }
180
181 pub fn with_resumed_from_sleep(mut self) -> Self {
183 self.resumed_from_sleep = true;
184 self
185 }
186
187 pub fn with_suspend_channel(mut self, tx: mpsc::Sender<SuspendReason>) -> Self {
189 self.suspend_tx = Some(tx);
190 self
191 }
192
193 pub fn with_tenant(mut self, tenant_id: Uuid) -> Self {
195 self.tenant_id = Some(tenant_id);
196 self
197 }
198
199 pub fn tenant_id(&self) -> Option<Uuid> {
200 self.tenant_id
201 }
202
203 pub fn is_resumed(&self) -> bool {
204 self.is_resumed
205 }
206
207 pub fn workflow_time(&self) -> DateTime<Utc> {
208 self.workflow_time
209 }
210
211 pub fn db(&self) -> crate::function::ForgeDb {
212 crate::function::ForgeDb::from_pool(&self.db_pool)
213 }
214
215 pub async fn conn(&self) -> sqlx::Result<crate::function::ForgeConn<'static>> {
217 Ok(crate::function::ForgeConn::Pool(
218 self.db_pool.acquire().await?,
219 ))
220 }
221
222 pub fn http(&self) -> crate::http::HttpClient {
223 self.http_client.with_timeout(self.http_timeout)
224 }
225
226 pub fn raw_http(&self) -> &reqwest::Client {
227 self.http_client.inner()
228 }
229
230 pub fn set_http_timeout(&mut self, timeout: Option<Duration>) {
231 self.http_timeout = timeout;
232 }
233
234 pub fn with_auth(mut self, auth: AuthContext) -> Self {
236 self.auth = auth;
237 self
238 }
239
240 pub fn with_step_states(self, states: HashMap<String, StepState>) -> Self {
242 let completed: Vec<String> = states
243 .iter()
244 .filter(|(_, s)| s.status == StepStatus::Completed)
245 .map(|(name, _)| name.clone())
246 .collect();
247
248 *self.step_states.write().expect("workflow lock poisoned") = states;
249 *self
250 .completed_steps
251 .write()
252 .expect("workflow lock poisoned") = completed;
253 self
254 }
255
256 pub fn get_step_state(&self, name: &str) -> Option<StepState> {
257 self.step_states
258 .read()
259 .expect("workflow lock poisoned")
260 .get(name)
261 .cloned()
262 }
263
264 pub fn is_step_completed(&self, name: &str) -> bool {
265 self.step_states
266 .read()
267 .expect("workflow lock poisoned")
268 .get(name)
269 .map(|s| s.status == StepStatus::Completed)
270 .unwrap_or(false)
271 }
272
273 pub fn is_step_started(&self, name: &str) -> bool {
278 self.step_states
279 .read()
280 .expect("workflow lock poisoned")
281 .get(name)
282 .map(|s| s.status != StepStatus::Pending)
283 .unwrap_or(false)
284 }
285
286 pub fn get_step_result<T: serde::de::DeserializeOwned>(&self, name: &str) -> Option<T> {
287 self.step_states
288 .read()
289 .expect("workflow lock poisoned")
290 .get(name)
291 .and_then(|s| s.result.as_ref())
292 .and_then(|v| serde_json::from_value(v.clone()).ok())
293 }
294
295 pub fn record_step_start(&self, name: &str) {
300 let mut states = self.step_states.write().expect("workflow lock poisoned");
301 let state = states
302 .entry(name.to_string())
303 .or_insert_with(|| StepState::new(name));
304
305 if state.status != StepStatus::Pending {
308 return;
309 }
310
311 state.start();
312 let state_clone = state.clone();
313 drop(states);
314
315 let pool = self.db_pool.clone();
317 let run_id = self.run_id;
318 let step_name = name.to_string();
319 tokio::spawn(async move {
320 let step_id = Uuid::new_v4();
321 if let Err(e) = sqlx::query!(
322 r#"
323 INSERT INTO forge_workflow_steps (id, workflow_run_id, step_name, status, started_at)
324 VALUES ($1, $2, $3, $4, $5)
325 ON CONFLICT (workflow_run_id, step_name) DO NOTHING
326 "#,
327 step_id,
328 run_id,
329 step_name,
330 state_clone.status.as_str(),
331 state_clone.started_at,
332 )
333 .execute(&pool)
334 .await
335 {
336 tracing::warn!(
337 workflow_run_id = %run_id,
338 step = %step_name,
339 "Failed to persist step start: {}",
340 e
341 );
342 }
343 });
344 }
345
346 pub fn record_step_complete(&self, name: &str, result: serde_json::Value) {
349 let state_clone = self.update_step_state_complete(name, result);
350
351 if let Some(state) = state_clone {
353 let pool = self.db_pool.clone();
354 let run_id = self.run_id;
355 let step_name = name.to_string();
356 tokio::spawn(async move {
357 Self::persist_step_complete(&pool, run_id, &step_name, &state).await;
358 });
359 }
360 }
361
362 pub async fn record_step_complete_async(&self, name: &str, result: serde_json::Value) {
364 let state_clone = self.update_step_state_complete(name, result);
365
366 if let Some(state) = state_clone {
368 Self::persist_step_complete(&self.db_pool, self.run_id, name, &state).await;
369 }
370 }
371
372 fn update_step_state_complete(
374 &self,
375 name: &str,
376 result: serde_json::Value,
377 ) -> Option<StepState> {
378 let mut states = self.step_states.write().expect("workflow lock poisoned");
379 if let Some(state) = states.get_mut(name) {
380 state.complete(result.clone());
381 }
382 let state_clone = states.get(name).cloned();
383 drop(states);
384
385 let mut completed = self
386 .completed_steps
387 .write()
388 .expect("workflow lock poisoned");
389 if !completed.contains(&name.to_string()) {
390 completed.push(name.to_string());
391 }
392 drop(completed);
393
394 state_clone
395 }
396
397 async fn persist_step_complete(
399 pool: &sqlx::PgPool,
400 run_id: Uuid,
401 step_name: &str,
402 state: &StepState,
403 ) {
404 if let Err(e) = sqlx::query!(
406 r#"
407 INSERT INTO forge_workflow_steps (id, workflow_run_id, step_name, status, result, started_at, completed_at)
408 VALUES (gen_random_uuid(), $1, $2, $3, $4, $5, $6)
409 ON CONFLICT (workflow_run_id, step_name) DO UPDATE
410 SET status = $3, result = $4, completed_at = $6
411 "#,
412 run_id,
413 step_name,
414 state.status.as_str(),
415 state.result as _,
416 state.started_at,
417 state.completed_at,
418 )
419 .execute(pool)
420 .await
421 {
422 tracing::warn!(
423 workflow_run_id = %run_id,
424 step = %step_name,
425 "Failed to persist step completion: {}",
426 e
427 );
428 }
429 }
430
431 pub fn record_step_failure(&self, name: &str, error: impl Into<String>) {
433 let error_str = error.into();
434 let mut states = self.step_states.write().expect("workflow lock poisoned");
435 if let Some(state) = states.get_mut(name) {
436 state.fail(error_str.clone());
437 }
438 let state_clone = states.get(name).cloned();
439 drop(states);
440
441 if let Some(state) = state_clone {
443 let pool = self.db_pool.clone();
444 let run_id = self.run_id;
445 let step_name = name.to_string();
446 tokio::spawn(async move {
447 if let Err(e) = sqlx::query!(
448 r#"
449 UPDATE forge_workflow_steps
450 SET status = $3, error = $4, completed_at = $5
451 WHERE workflow_run_id = $1 AND step_name = $2
452 "#,
453 run_id,
454 step_name,
455 state.status.as_str(),
456 state.error as _,
457 state.completed_at,
458 )
459 .execute(&pool)
460 .await
461 {
462 tracing::warn!(
463 workflow_run_id = %run_id,
464 step = %step_name,
465 "Failed to persist step failure: {}",
466 e
467 );
468 }
469 });
470 }
471 }
472
473 pub fn record_step_compensated(&self, name: &str) {
475 let mut states = self.step_states.write().expect("workflow lock poisoned");
476 if let Some(state) = states.get_mut(name) {
477 state.compensate();
478 }
479 let state_clone = states.get(name).cloned();
480 drop(states);
481
482 if let Some(state) = state_clone {
484 let pool = self.db_pool.clone();
485 let run_id = self.run_id;
486 let step_name = name.to_string();
487 tokio::spawn(async move {
488 if let Err(e) = sqlx::query!(
489 r#"
490 UPDATE forge_workflow_steps
491 SET status = $3
492 WHERE workflow_run_id = $1 AND step_name = $2
493 "#,
494 run_id,
495 step_name,
496 state.status.as_str(),
497 )
498 .execute(&pool)
499 .await
500 {
501 tracing::warn!(
502 workflow_run_id = %run_id,
503 step = %step_name,
504 "Failed to persist step compensation: {}",
505 e
506 );
507 }
508 });
509 }
510 }
511
512 pub fn completed_steps_reversed(&self) -> Vec<String> {
513 let completed = self.completed_steps.read().expect("workflow lock poisoned");
514 completed.iter().rev().cloned().collect()
515 }
516
517 pub fn all_step_states(&self) -> HashMap<String, StepState> {
518 self.step_states
519 .read()
520 .expect("workflow lock poisoned")
521 .clone()
522 }
523
524 pub fn elapsed(&self) -> chrono::Duration {
525 Utc::now() - self.started_at
526 }
527
528 pub fn register_compensation(&self, step_name: &str, handler: CompensationHandler) {
530 let mut handlers = self
531 .compensation_handlers
532 .write()
533 .expect("workflow lock poisoned");
534 handlers.insert(step_name.to_string(), handler);
535 }
536
537 pub fn get_compensation_handler(&self, step_name: &str) -> Option<CompensationHandler> {
538 self.compensation_handlers
539 .read()
540 .expect("workflow lock poisoned")
541 .get(step_name)
542 .cloned()
543 }
544
545 pub fn has_compensation(&self, step_name: &str) -> bool {
546 self.compensation_handlers
547 .read()
548 .expect("workflow lock poisoned")
549 .contains_key(step_name)
550 }
551
552 pub async fn run_compensation(&self) -> Vec<(String, bool)> {
555 let steps = self.completed_steps_reversed();
556 let mut results = Vec::new();
557
558 for step_name in steps {
559 let handler = self.get_compensation_handler(&step_name);
560 let result = self
561 .get_step_state(&step_name)
562 .and_then(|s| s.result.clone());
563
564 if let Some(handler) = handler {
565 let step_result = result.unwrap_or(serde_json::Value::Null);
566 match handler(step_result).await {
567 Ok(()) => {
568 self.record_step_compensated(&step_name);
569 results.push((step_name, true));
570 }
571 Err(e) => {
572 tracing::error!(step = %step_name, error = %e, "Compensation failed");
573 results.push((step_name, false));
574 }
575 }
576 } else {
577 self.record_step_compensated(&step_name);
579 results.push((step_name, true));
580 }
581 }
582
583 results
584 }
585
586 pub fn compensation_handlers(&self) -> HashMap<String, CompensationHandler> {
587 self.compensation_handlers
588 .read()
589 .expect("workflow lock poisoned")
590 .clone()
591 }
592
593 pub async fn sleep(&self, duration: Duration) -> Result<()> {
604 if self.resumed_from_sleep {
606 return Ok(());
607 }
608
609 let wake_at = Utc::now() + chrono::Duration::from_std(duration).unwrap_or_default();
610 self.sleep_until(wake_at).await
611 }
612
613 pub async fn sleep_until(&self, wake_at: DateTime<Utc>) -> Result<()> {
624 if self.resumed_from_sleep {
626 return Ok(());
627 }
628
629 if wake_at <= Utc::now() {
631 return Ok(());
632 }
633
634 self.set_wake_at(wake_at).await?;
636
637 self.signal_suspend(SuspendReason::Sleep { wake_at })
639 .await?;
640
641 Ok(())
642 }
643
644 pub async fn wait_for_event<T: DeserializeOwned>(
657 &self,
658 event_name: &str,
659 timeout: Option<Duration>,
660 ) -> Result<T> {
661 let correlation_id = self.run_id.to_string();
662
663 if self.is_resumed
666 && let Some(event) = self
667 .find_consumed_event(event_name, &correlation_id)
668 .await?
669 {
670 return serde_json::from_value(event.payload.unwrap_or_default())
671 .map_err(|e| ForgeError::Deserialization(e.to_string()));
672 }
673
674 if let Some(event) = self.try_consume_event(event_name, &correlation_id).await? {
676 return serde_json::from_value(event.payload.unwrap_or_default())
677 .map_err(|e| ForgeError::Deserialization(e.to_string()));
678 }
679
680 let timeout_at =
682 timeout.map(|d| Utc::now() + chrono::Duration::from_std(d).unwrap_or_default());
683
684 self.set_waiting_for_event(event_name, timeout_at).await?;
686
687 self.signal_suspend(SuspendReason::WaitingEvent {
689 event_name: event_name.to_string(),
690 timeout: timeout_at,
691 })
692 .await?;
693
694 self.try_consume_event(event_name, &correlation_id)
696 .await?
697 .and_then(|e| e.payload)
698 .and_then(|p| serde_json::from_value(p).ok())
699 .ok_or_else(|| ForgeError::Timeout(format!("Event '{}' timed out", event_name)))
700 }
701
702 #[allow(clippy::type_complexity)]
704 async fn try_consume_event(
705 &self,
706 event_name: &str,
707 correlation_id: &str,
708 ) -> Result<Option<WorkflowEvent>> {
709 let result = sqlx::query!(
710 r#"
711 UPDATE forge_workflow_events
712 SET consumed_at = NOW(), consumed_by = $3
713 WHERE id = (
714 SELECT id FROM forge_workflow_events
715 WHERE event_name = $1 AND correlation_id = $2 AND consumed_at IS NULL
716 ORDER BY created_at ASC LIMIT 1
717 FOR UPDATE SKIP LOCKED
718 )
719 RETURNING id, event_name, correlation_id, payload, created_at
720 "#,
721 event_name,
722 correlation_id,
723 self.run_id
724 )
725 .fetch_optional(&self.db_pool)
726 .await
727 .map_err(|e| ForgeError::Database(e.to_string()))?;
728
729 Ok(result.map(|row| WorkflowEvent {
730 id: row.id,
731 event_name: row.event_name,
732 correlation_id: row.correlation_id,
733 payload: row.payload,
734 created_at: row.created_at,
735 }))
736 }
737
738 async fn find_consumed_event(
741 &self,
742 event_name: &str,
743 correlation_id: &str,
744 ) -> Result<Option<WorkflowEvent>> {
745 let result = sqlx::query!(
746 r#"
747 SELECT id, event_name, correlation_id, payload, created_at
748 FROM forge_workflow_events
749 WHERE event_name = $1 AND correlation_id = $2 AND consumed_by = $3
750 ORDER BY created_at DESC LIMIT 1
751 "#,
752 event_name,
753 correlation_id,
754 self.run_id
755 )
756 .fetch_optional(&self.db_pool)
757 .await
758 .map_err(|e| ForgeError::Database(e.to_string()))?;
759
760 Ok(result.map(|row| WorkflowEvent {
761 id: row.id,
762 event_name: row.event_name,
763 correlation_id: row.correlation_id,
764 payload: row.payload,
765 created_at: row.created_at,
766 }))
767 }
768
769 async fn set_wake_at(&self, wake_at: DateTime<Utc>) -> Result<()> {
771 sqlx::query!(
772 r#"
773 UPDATE forge_workflow_runs
774 SET status = 'waiting', suspended_at = NOW(), wake_at = $2
775 WHERE id = $1
776 "#,
777 self.run_id,
778 wake_at,
779 )
780 .execute(&self.db_pool)
781 .await
782 .map_err(|e| ForgeError::Database(e.to_string()))?;
783 Ok(())
784 }
785
786 async fn set_waiting_for_event(
788 &self,
789 event_name: &str,
790 timeout_at: Option<DateTime<Utc>>,
791 ) -> Result<()> {
792 sqlx::query!(
793 r#"
794 UPDATE forge_workflow_runs
795 SET status = 'waiting', suspended_at = NOW(), waiting_for_event = $2, event_timeout_at = $3
796 WHERE id = $1
797 "#,
798 self.run_id,
799 event_name,
800 timeout_at,
801 )
802 .execute(&self.db_pool)
803 .await
804 .map_err(|e| ForgeError::Database(e.to_string()))?;
805 Ok(())
806 }
807
808 async fn signal_suspend(&self, reason: SuspendReason) -> Result<()> {
810 if let Some(ref tx) = self.suspend_tx {
811 tx.send(reason)
812 .await
813 .map_err(|_| ForgeError::Internal("Failed to signal suspension".into()))?;
814 }
815 Err(ForgeError::WorkflowSuspended)
817 }
818
819 pub fn parallel(&self) -> ParallelBuilder<'_> {
835 ParallelBuilder::new(self)
836 }
837
838 pub fn step<T, F, Fut>(&self, name: impl Into<String>, f: F) -> super::StepRunner<'_, T>
888 where
889 T: serde::Serialize + serde::de::DeserializeOwned + Clone + Send + Sync + 'static,
890 F: Fn() -> Fut + Send + Sync + 'static,
891 Fut: std::future::Future<Output = crate::Result<T>> + Send + 'static,
892 {
893 super::StepRunner::new(self, name, f)
894 }
895}
896
897impl EnvAccess for WorkflowContext {
898 fn env_provider(&self) -> &dyn EnvProvider {
899 self.env_provider.as_ref()
900 }
901}
902
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    // Lazily-connected pool; no connection is attempted until a query runs,
    // so these tests never need a real database.
    fn mock_pool() -> sqlx::PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool")
    }

    #[tokio::test]
    async fn test_workflow_context_creation() {
        let id = Uuid::new_v4();
        let ctx = WorkflowContext::new(
            id,
            "test_workflow".to_string(),
            mock_pool(),
            CircuitBreakerClient::with_defaults(reqwest::Client::new()),
        );

        assert_eq!(ctx.run_id, id);
        assert_eq!(ctx.workflow_name, "test_workflow");
    }

    #[tokio::test]
    async fn test_step_state_tracking() {
        let ctx = WorkflowContext::new(
            Uuid::new_v4(),
            "test".to_string(),
            mock_pool(),
            CircuitBreakerClient::with_defaults(reqwest::Client::new()),
        );

        // Starting a step does not complete it.
        ctx.record_step_start("step1");
        assert!(!ctx.is_step_completed("step1"));

        // Completing records the result and flips the status.
        ctx.record_step_complete("step1", serde_json::json!({"result": "ok"}));
        assert!(ctx.is_step_completed("step1"));

        let stored: Option<serde_json::Value> = ctx.get_step_result("step1");
        assert!(stored.is_some());
    }

    #[test]
    fn test_step_state_transitions() {
        let mut st = StepState::new("test");
        assert_eq!(st.status, StepStatus::Pending);

        st.start();
        assert_eq!(st.status, StepStatus::Running);
        assert!(st.started_at.is_some());

        st.complete(serde_json::json!({}));
        assert_eq!(st.status, StepStatus::Completed);
        assert!(st.completed_at.is_some());
    }
}