1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::{Arc, RwLock};
5use std::time::Duration;
6
7use chrono::{DateTime, Utc};
8use serde::de::DeserializeOwned;
9use tokio::sync::mpsc;
10use uuid::Uuid;
11
12use super::parallel::ParallelBuilder;
13use super::step::StepStatus;
14use super::suspend::{SuspendReason, WorkflowEvent};
15use crate::function::AuthContext;
16use crate::{ForgeError, Result};
17
/// Async callback invoked to undo a completed step during compensation.
///
/// Receives the step's recorded JSON result so the handler can reverse the
/// step's side effects; wrapped in `Arc` so it can be cloned across tasks.
pub type CompensationHandler = Arc<
    dyn Fn(serde_json::Value) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync,
>;
22
/// In-memory record of a single workflow step's lifecycle.
#[derive(Debug, Clone)]
pub struct StepState {
    /// Step name, unique within a workflow run.
    pub name: String,
    /// Current lifecycle status (starts at `Pending`).
    pub status: StepStatus,
    /// JSON result captured on completion; `None` until then.
    pub result: Option<serde_json::Value>,
    /// Error message captured on failure; `None` unless the step failed.
    pub error: Option<String>,
    /// When the step entered `Running`; `None` while still pending.
    pub started_at: Option<DateTime<Utc>>,
    /// When the step completed or failed; `compensate()` does not set this.
    pub completed_at: Option<DateTime<Utc>>,
}
39
40impl StepState {
41 pub fn new(name: impl Into<String>) -> Self {
43 Self {
44 name: name.into(),
45 status: StepStatus::Pending,
46 result: None,
47 error: None,
48 started_at: None,
49 completed_at: None,
50 }
51 }
52
53 pub fn start(&mut self) {
55 self.status = StepStatus::Running;
56 self.started_at = Some(Utc::now());
57 }
58
59 pub fn complete(&mut self, result: serde_json::Value) {
61 self.status = StepStatus::Completed;
62 self.result = Some(result);
63 self.completed_at = Some(Utc::now());
64 }
65
66 pub fn fail(&mut self, error: impl Into<String>) {
68 self.status = StepStatus::Failed;
69 self.error = Some(error.into());
70 self.completed_at = Some(Utc::now());
71 }
72
73 pub fn compensate(&mut self) {
75 self.status = StepStatus::Compensated;
76 }
77}
78
/// Execution context threaded through a workflow run.
///
/// Tracks per-step state, compensation handlers, and suspension signalling,
/// and carries shared resources (database pool, HTTP client, auth).
pub struct WorkflowContext {
    /// Unique identifier for this workflow run.
    pub run_id: Uuid,
    /// Registered name of the workflow definition.
    pub workflow_name: String,
    /// Version of the workflow definition executing this run.
    pub version: u32,
    /// Wall-clock time the run originally started.
    pub started_at: DateTime<Utc>,
    /// Logical workflow time; initialized to the run's start time.
    workflow_time: DateTime<Utc>,
    /// Authentication context (defaults to unauthenticated).
    pub auth: AuthContext,
    /// Postgres pool used for run/step persistence.
    db_pool: sqlx::PgPool,
    /// Shared HTTP client for steps calling external services.
    http_client: reqwest::Client,
    /// Per-step lifecycle state, keyed by step name.
    step_states: Arc<RwLock<HashMap<String, StepState>>>,
    /// Completed step names, in completion order (drives compensation order).
    completed_steps: Arc<RwLock<Vec<String>>>,
    /// Compensation callbacks registered per step.
    compensation_handlers: Arc<RwLock<HashMap<String, CompensationHandler>>>,
    /// Channel used to notify the runner when the workflow suspends.
    suspend_tx: Option<mpsc::Sender<SuspendReason>>,
    /// True when this context was rebuilt for a resumed run.
    is_resumed: bool,
    /// True when resuming specifically from a sleep suspension.
    resumed_from_sleep: bool,
    /// Optional tenant scoping for multi-tenant deployments.
    tenant_id: Option<Uuid>,
}
112
113impl WorkflowContext {
114 pub fn new(
116 run_id: Uuid,
117 workflow_name: String,
118 version: u32,
119 db_pool: sqlx::PgPool,
120 http_client: reqwest::Client,
121 ) -> Self {
122 let now = Utc::now();
123 Self {
124 run_id,
125 workflow_name,
126 version,
127 started_at: now,
128 workflow_time: now,
129 auth: AuthContext::unauthenticated(),
130 db_pool,
131 http_client,
132 step_states: Arc::new(RwLock::new(HashMap::new())),
133 completed_steps: Arc::new(RwLock::new(Vec::new())),
134 compensation_handlers: Arc::new(RwLock::new(HashMap::new())),
135 suspend_tx: None,
136 is_resumed: false,
137 resumed_from_sleep: false,
138 tenant_id: None,
139 }
140 }
141
142 pub fn resumed(
144 run_id: Uuid,
145 workflow_name: String,
146 version: u32,
147 started_at: DateTime<Utc>,
148 db_pool: sqlx::PgPool,
149 http_client: reqwest::Client,
150 ) -> Self {
151 Self {
152 run_id,
153 workflow_name,
154 version,
155 started_at,
156 workflow_time: started_at,
157 auth: AuthContext::unauthenticated(),
158 db_pool,
159 http_client,
160 step_states: Arc::new(RwLock::new(HashMap::new())),
161 completed_steps: Arc::new(RwLock::new(Vec::new())),
162 compensation_handlers: Arc::new(RwLock::new(HashMap::new())),
163 suspend_tx: None,
164 is_resumed: true,
165 resumed_from_sleep: false,
166 tenant_id: None,
167 }
168 }
169
    /// Marks this context as resuming from a sleep suspension, so the next
    /// `sleep`/`sleep_until` call returns immediately instead of re-suspending.
    pub fn with_resumed_from_sleep(mut self) -> Self {
        self.resumed_from_sleep = true;
        self
    }

    /// Attaches the channel used to notify the runner when the workflow suspends.
    pub fn with_suspend_channel(mut self, tx: mpsc::Sender<SuspendReason>) -> Self {
        self.suspend_tx = Some(tx);
        self
    }

    /// Scopes this run to a tenant (builder style).
    pub fn with_tenant(mut self, tenant_id: Uuid) -> Self {
        self.tenant_id = Some(tenant_id);
        self
    }
187
    /// Tenant this run is scoped to, if any.
    pub fn tenant_id(&self) -> Option<Uuid> {
        self.tenant_id
    }

    /// Whether this context was rebuilt for a resumed run.
    pub fn is_resumed(&self) -> bool {
        self.is_resumed
    }

    /// Logical workflow time (set to the run's start time at construction).
    pub fn workflow_time(&self) -> DateTime<Utc> {
        self.workflow_time
    }

    /// Shared Postgres pool.
    pub fn db(&self) -> &sqlx::PgPool {
        &self.db_pool
    }

    /// Shared HTTP client.
    pub fn http(&self) -> &reqwest::Client {
        &self.http_client
    }

    /// Replaces the auth context (builder style).
    pub fn with_auth(mut self, auth: AuthContext) -> Self {
        self.auth = auth;
        self
    }
218
219 pub fn with_step_states(self, states: HashMap<String, StepState>) -> Self {
221 let completed: Vec<String> = states
222 .iter()
223 .filter(|(_, s)| s.status == StepStatus::Completed)
224 .map(|(name, _)| name.clone())
225 .collect();
226
227 *self.step_states.write().unwrap() = states;
228 *self.completed_steps.write().unwrap() = completed;
229 self
230 }
231
232 pub fn get_step_state(&self, name: &str) -> Option<StepState> {
234 self.step_states.read().unwrap().get(name).cloned()
235 }
236
237 pub fn is_step_completed(&self, name: &str) -> bool {
239 self.step_states
240 .read()
241 .unwrap()
242 .get(name)
243 .map(|s| s.status == StepStatus::Completed)
244 .unwrap_or(false)
245 }
246
247 pub fn is_step_started(&self, name: &str) -> bool {
252 self.step_states
253 .read()
254 .unwrap()
255 .get(name)
256 .map(|s| s.status != StepStatus::Pending)
257 .unwrap_or(false)
258 }
259
260 pub fn get_step_result<T: serde::de::DeserializeOwned>(&self, name: &str) -> Option<T> {
262 self.step_states
263 .read()
264 .unwrap()
265 .get(name)
266 .and_then(|s| s.result.as_ref())
267 .and_then(|v| serde_json::from_value(v.clone()).ok())
268 }
269
    /// Transitions a step to `Running` (idempotently) and persists the start
    /// row in the background.
    ///
    /// If the step already left `Pending` (e.g. it is being replayed on
    /// resume) this is a no-op. Persistence runs on a spawned task and is
    /// best-effort: a database failure is logged but never fails the workflow.
    pub fn record_step_start(&self, name: &str) {
        let mut states = self.step_states.write().unwrap();
        let state = states
            .entry(name.to_string())
            .or_insert_with(|| StepState::new(name));

        // Idempotency guard: never restart a step that already ran.
        if state.status != StepStatus::Pending {
            return;
        }

        state.start();
        let state_clone = state.clone();
        // Release the write lock before spawning so the async task and other
        // callers never contend with this guard.
        drop(states);

        let pool = self.db_pool.clone();
        let run_id = self.run_id;
        let step_name = name.to_string();
        tokio::spawn(async move {
            let step_id = Uuid::new_v4();
            // ON CONFLICT DO NOTHING keeps a replayed start from clobbering
            // an existing row for this (run, step) pair.
            if let Err(e) = sqlx::query(
                r#"
                INSERT INTO forge_workflow_steps (id, workflow_run_id, step_name, status, started_at)
                VALUES ($1, $2, $3, $4, $5)
                ON CONFLICT (workflow_run_id, step_name) DO NOTHING
                "#,
            )
            .bind(step_id)
            .bind(run_id)
            .bind(&step_name)
            .bind(state_clone.status.as_str())
            .bind(state_clone.started_at)
            .execute(&pool)
            .await
            {
                tracing::warn!(
                    workflow_run_id = %run_id,
                    step = %step_name,
                    "Failed to persist step start: {}",
                    e
                );
            }
        });
    }
320
    /// Marks a step completed in memory and persists the result on a spawned
    /// task (fire-and-forget; failures are logged by the persistence helper).
    pub fn record_step_complete(&self, name: &str, result: serde_json::Value) {
        let state_clone = self.update_step_state_complete(name, result);

        if let Some(state) = state_clone {
            let pool = self.db_pool.clone();
            let run_id = self.run_id;
            let step_name = name.to_string();
            tokio::spawn(async move {
                Self::persist_step_complete(&pool, run_id, &step_name, &state).await;
            });
        }
    }

    /// Like [`Self::record_step_complete`] but awaits persistence inline, so
    /// the caller knows the row was written (or the failure logged) before
    /// continuing.
    pub async fn record_step_complete_async(&self, name: &str, result: serde_json::Value) {
        let state_clone = self.update_step_state_complete(name, result);

        if let Some(state) = state_clone {
            Self::persist_step_complete(&self.db_pool, self.run_id, name, &state).await;
        }
    }
346
347 fn update_step_state_complete(
349 &self,
350 name: &str,
351 result: serde_json::Value,
352 ) -> Option<StepState> {
353 let mut states = self.step_states.write().unwrap();
354 if let Some(state) = states.get_mut(name) {
355 state.complete(result.clone());
356 }
357 let state_clone = states.get(name).cloned();
358 drop(states);
359
360 let mut completed = self.completed_steps.write().unwrap();
361 if !completed.contains(&name.to_string()) {
362 completed.push(name.to_string());
363 }
364 drop(completed);
365
366 state_clone
367 }
368
    /// Upserts the step row with its completion status, result, and
    /// timestamps.
    ///
    /// Best-effort: a database error is logged at `warn` and swallowed so
    /// persistence problems never abort the workflow itself.
    async fn persist_step_complete(
        pool: &sqlx::PgPool,
        run_id: Uuid,
        step_name: &str,
        state: &StepState,
    ) {
        // Insert-or-update on (workflow_run_id, step_name): handles both a
        // step whose start row was persisted and one whose start write raced
        // or failed.
        if let Err(e) = sqlx::query(
            r#"
            INSERT INTO forge_workflow_steps (id, workflow_run_id, step_name, status, result, started_at, completed_at)
            VALUES (gen_random_uuid(), $1, $2, $3, $4, $5, $6)
            ON CONFLICT (workflow_run_id, step_name) DO UPDATE
            SET status = $3, result = $4, completed_at = $6
            "#,
        )
        .bind(run_id)
        .bind(step_name)
        .bind(state.status.as_str())
        .bind(&state.result)
        .bind(state.started_at)
        .bind(state.completed_at)
        .execute(pool)
        .await
        {
            tracing::warn!(
                workflow_run_id = %run_id,
                step = %step_name,
                "Failed to persist step completion: {}",
                e
            );
        }
    }
402
403 pub fn record_step_failure(&self, name: &str, error: impl Into<String>) {
405 let error_str = error.into();
406 let mut states = self.step_states.write().unwrap();
407 if let Some(state) = states.get_mut(name) {
408 state.fail(error_str.clone());
409 }
410 let state_clone = states.get(name).cloned();
411 drop(states);
412
413 if let Some(state) = state_clone {
415 let pool = self.db_pool.clone();
416 let run_id = self.run_id;
417 let step_name = name.to_string();
418 tokio::spawn(async move {
419 if let Err(e) = sqlx::query(
420 r#"
421 UPDATE forge_workflow_steps
422 SET status = $3, error = $4, completed_at = $5
423 WHERE workflow_run_id = $1 AND step_name = $2
424 "#,
425 )
426 .bind(run_id)
427 .bind(&step_name)
428 .bind(state.status.as_str())
429 .bind(&state.error)
430 .bind(state.completed_at)
431 .execute(&pool)
432 .await
433 {
434 tracing::warn!(
435 workflow_run_id = %run_id,
436 step = %step_name,
437 "Failed to persist step failure: {}",
438 e
439 );
440 }
441 });
442 }
443 }
444
    /// Marks a step compensated in memory and persists the status change in
    /// the background.
    ///
    /// No-op if the step was never recorded. Best-effort persistence: a
    /// database error is logged but not propagated.
    pub fn record_step_compensated(&self, name: &str) {
        let mut states = self.step_states.write().unwrap();
        if let Some(state) = states.get_mut(name) {
            state.compensate();
        }
        let state_clone = states.get(name).cloned();
        // Drop the write lock before spawning the persistence task.
        drop(states);

        if let Some(state) = state_clone {
            let pool = self.db_pool.clone();
            let run_id = self.run_id;
            let step_name = name.to_string();
            tokio::spawn(async move {
                if let Err(e) = sqlx::query(
                    r#"
                    UPDATE forge_workflow_steps
                    SET status = $3
                    WHERE workflow_run_id = $1 AND step_name = $2
                    "#,
                )
                .bind(run_id)
                .bind(&step_name)
                .bind(state.status.as_str())
                .execute(&pool)
                .await
                {
                    tracing::warn!(
                        workflow_run_id = %run_id,
                        step = %step_name,
                        "Failed to persist step compensation: {}",
                        e
                    );
                }
            });
        }
    }
483
484 pub fn completed_steps_reversed(&self) -> Vec<String> {
486 let completed = self.completed_steps.read().unwrap();
487 completed.iter().rev().cloned().collect()
488 }
489
490 pub fn all_step_states(&self) -> HashMap<String, StepState> {
492 self.step_states.read().unwrap().clone()
493 }
494
495 pub fn elapsed(&self) -> chrono::Duration {
497 Utc::now() - self.started_at
498 }
499
    /// Registers (or replaces) the compensation handler for a step.
    pub fn register_compensation(&self, step_name: &str, handler: CompensationHandler) {
        let mut handlers = self.compensation_handlers.write().unwrap();
        handlers.insert(step_name.to_string(), handler);
    }

    /// Returns a clone of the step's compensation handler, if registered.
    pub fn get_compensation_handler(&self, step_name: &str) -> Option<CompensationHandler> {
        self.compensation_handlers
            .read()
            .unwrap()
            .get(step_name)
            .cloned()
    }

    /// True if a compensation handler is registered for the step.
    pub fn has_compensation(&self, step_name: &str) -> bool {
        self.compensation_handlers
            .read()
            .unwrap()
            .contains_key(step_name)
    }
522
523 pub async fn run_compensation(&self) -> Vec<(String, bool)> {
526 let steps = self.completed_steps_reversed();
527 let mut results = Vec::new();
528
529 for step_name in steps {
530 let handler = self.get_compensation_handler(&step_name);
531 let result = self
532 .get_step_state(&step_name)
533 .and_then(|s| s.result.clone());
534
535 if let Some(handler) = handler {
536 let step_result = result.unwrap_or(serde_json::Value::Null);
537 match handler(step_result).await {
538 Ok(()) => {
539 self.record_step_compensated(&step_name);
540 results.push((step_name, true));
541 }
542 Err(e) => {
543 tracing::error!(step = %step_name, error = %e, "Compensation failed");
544 results.push((step_name, false));
545 }
546 }
547 } else {
548 self.record_step_compensated(&step_name);
550 results.push((step_name, true));
551 }
552 }
553
554 results
555 }
556
    /// Snapshot of all registered compensation handlers (handlers are `Arc`s,
    /// so this clones pointers, not the closures).
    pub fn compensation_handlers(&self) -> HashMap<String, CompensationHandler> {
        self.compensation_handlers.read().unwrap().clone()
    }
561
562 pub async fn sleep(&self, duration: Duration) -> Result<()> {
577 if self.resumed_from_sleep {
579 return Ok(());
580 }
581
582 let wake_at = Utc::now() + chrono::Duration::from_std(duration).unwrap_or_default();
583 self.sleep_until(wake_at).await
584 }
585
    /// Durable sleep until an absolute wake time.
    ///
    /// Returns `Ok(())` without suspending when resuming from sleep or when
    /// `wake_at` is already in the past. Otherwise it persists the wake time
    /// and suspends. Note that `signal_suspend` always returns
    /// `Err(ForgeError::WorkflowSuspended)`, so the trailing `Ok(())` is
    /// unreachable on the suspend path — the workflow unwinds via that error
    /// and is re-executed later (with `resumed_from_sleep` set).
    pub async fn sleep_until(&self, wake_at: DateTime<Utc>) -> Result<()> {
        if self.resumed_from_sleep {
            return Ok(());
        }

        // Nothing to wait for.
        if wake_at <= Utc::now() {
            return Ok(());
        }

        // Persist the wake time first so the scheduler can resume this run.
        self.set_wake_at(wake_at).await?;

        self.signal_suspend(SuspendReason::Sleep { wake_at })
            .await?;

        Ok(())
    }
616
617 pub async fn wait_for_event<T: DeserializeOwned>(
630 &self,
631 event_name: &str,
632 timeout: Option<Duration>,
633 ) -> Result<T> {
634 let correlation_id = self.run_id.to_string();
635
636 if let Some(event) = self.try_consume_event(event_name, &correlation_id).await? {
638 return serde_json::from_value(event.payload.unwrap_or_default())
639 .map_err(|e| ForgeError::Deserialization(e.to_string()));
640 }
641
642 let timeout_at =
644 timeout.map(|d| Utc::now() + chrono::Duration::from_std(d).unwrap_or_default());
645
646 self.set_waiting_for_event(event_name, timeout_at).await?;
648
649 self.signal_suspend(SuspendReason::WaitingEvent {
651 event_name: event_name.to_string(),
652 timeout: timeout_at,
653 })
654 .await?;
655
656 self.try_consume_event(event_name, &correlation_id)
658 .await?
659 .and_then(|e| e.payload)
660 .and_then(|p| serde_json::from_value(p).ok())
661 .ok_or_else(|| ForgeError::Timeout(format!("Event '{}' timed out", event_name)))
662 }
663
    /// Atomically claims the oldest unconsumed matching event, if any.
    ///
    /// The UPDATE-with-subselect uses `FOR UPDATE SKIP LOCKED` so concurrent
    /// consumers never block each other or double-consume the same row; the
    /// claimed row is stamped with `consumed_at`/`consumed_by` and returned.
    #[allow(clippy::type_complexity)]
    async fn try_consume_event(
        &self,
        event_name: &str,
        correlation_id: &str,
    ) -> Result<Option<WorkflowEvent>> {
        let result: Option<(
            Uuid,
            String,
            String,
            Option<serde_json::Value>,
            DateTime<Utc>,
        )> = sqlx::query_as(
            r#"
            UPDATE forge_workflow_events
            SET consumed_at = NOW(), consumed_by = $3
            WHERE id = (
                SELECT id FROM forge_workflow_events
                WHERE event_name = $1 AND correlation_id = $2 AND consumed_at IS NULL
                ORDER BY created_at ASC LIMIT 1
                FOR UPDATE SKIP LOCKED
            )
            RETURNING id, event_name, correlation_id, payload, created_at
            "#,
        )
        .bind(event_name)
        .bind(correlation_id)
        .bind(self.run_id)
        .fetch_optional(&self.db_pool)
        .await
        .map_err(|e| ForgeError::Database(e.to_string()))?;

        Ok(result.map(
            |(id, event_name, correlation_id, payload, created_at)| WorkflowEvent {
                id,
                event_name,
                correlation_id,
                payload,
                created_at,
            },
        ))
    }
707
    /// Persists this run as `waiting` with an absolute wake time, so the
    /// scheduler knows when to resume it.
    async fn set_wake_at(&self, wake_at: DateTime<Utc>) -> Result<()> {
        sqlx::query(
            r#"
            UPDATE forge_workflow_runs
            SET status = 'waiting', suspended_at = NOW(), wake_at = $2
            WHERE id = $1
            "#,
        )
        .bind(self.run_id)
        .bind(wake_at)
        .execute(&self.db_pool)
        .await
        .map_err(|e| ForgeError::Database(e.to_string()))?;
        Ok(())
    }
724
    /// Persists this run as `waiting` on a named event, with an optional
    /// timeout deadline, so the scheduler can resume it when the event
    /// arrives (or the deadline passes).
    async fn set_waiting_for_event(
        &self,
        event_name: &str,
        timeout_at: Option<DateTime<Utc>>,
    ) -> Result<()> {
        sqlx::query(
            r#"
            UPDATE forge_workflow_runs
            SET status = 'waiting', suspended_at = NOW(), waiting_for_event = $2, event_timeout_at = $3
            WHERE id = $1
            "#,
        )
        .bind(self.run_id)
        .bind(event_name)
        .bind(timeout_at)
        .execute(&self.db_pool)
        .await
        .map_err(|e| ForgeError::Database(e.to_string()))?;
        Ok(())
    }
746
    /// Notifies the runner (if a suspend channel is attached) and then
    /// unconditionally returns `Err(ForgeError::WorkflowSuspended)`.
    ///
    /// The error is the suspension mechanism: callers propagate it with `?`
    /// so the workflow function unwinds at the suspension point. It is
    /// returned even when no channel is attached or the send succeeds.
    async fn signal_suspend(&self, reason: SuspendReason) -> Result<()> {
        if let Some(ref tx) = self.suspend_tx {
            tx.send(reason)
                .await
                .map_err(|_| ForgeError::Internal("Failed to signal suspension".into()))?;
        }
        Err(ForgeError::WorkflowSuspended)
    }
757
    /// Starts a builder for running several steps concurrently.
    pub fn parallel(&self) -> ParallelBuilder<'_> {
        ParallelBuilder::new(self)
    }

    /// Starts a named durable step whose body is `f`.
    ///
    /// NOTE(review): execution/replay semantics live in `super::StepRunner`
    /// (not visible here) — presumably it consults this context's step state
    /// to skip already-completed steps on resume; confirm in `super::step`.
    pub fn step<T, F, Fut>(&self, name: impl Into<String>, f: F) -> super::StepRunner<'_, T>
    where
        T: serde::Serialize + serde::de::DeserializeOwned + Clone + Send + Sync + 'static,
        F: FnOnce() -> Fut + Send + 'static,
        Fut: std::future::Future<Output = crate::Result<T>> + Send + 'static,
    {
        super::StepRunner::new(self, name, f)
    }
842}
843
#[cfg(test)]
mod tests {
    use super::*;

    // Context construction stores the identity fields verbatim. The lazy pool
    // never connects, so no database is required.
    #[tokio::test]
    async fn test_workflow_context_creation() {
        let pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool");

        let run_id = Uuid::new_v4();
        let ctx = WorkflowContext::new(
            run_id,
            "test_workflow".to_string(),
            1,
            pool,
            reqwest::Client::new(),
        );

        assert_eq!(ctx.run_id, run_id);
        assert_eq!(ctx.workflow_name, "test_workflow");
        assert_eq!(ctx.version, 1);
    }

    // In-memory step tracking works start -> complete -> result retrieval;
    // background persistence tasks fail against the dead pool but are
    // best-effort and do not affect the in-memory state.
    #[tokio::test]
    async fn test_step_state_tracking() {
        let pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool");

        let ctx = WorkflowContext::new(
            Uuid::new_v4(),
            "test".to_string(),
            1,
            pool,
            reqwest::Client::new(),
        );

        ctx.record_step_start("step1");
        assert!(!ctx.is_step_completed("step1"));

        ctx.record_step_complete("step1", serde_json::json!({"result": "ok"}));
        assert!(ctx.is_step_completed("step1"));

        let result: Option<serde_json::Value> = ctx.get_step_result("step1");
        assert!(result.is_some());
    }

    // StepState transitions set the expected status and timestamps.
    #[test]
    fn test_step_state_transitions() {
        let mut state = StepState::new("test");
        assert_eq!(state.status, StepStatus::Pending);

        state.start();
        assert_eq!(state.status, StepStatus::Running);
        assert!(state.started_at.is_some());

        state.complete(serde_json::json!({}));
        assert_eq!(state.status, StepStatus::Completed);
        assert!(state.completed_at.is_some());
    }
}