1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::{Arc, RwLock};
5use std::time::Duration;
6
7use chrono::{DateTime, Utc};
8use serde::de::DeserializeOwned;
9use uuid::Uuid;
10
11use super::step::StepStatus;
12use super::suspend::{SuspendReason, WorkflowEvent};
13use crate::env::{EnvAccess, EnvProvider, RealEnvProvider};
14use crate::function::{AuthContext, KvHandle};
15use crate::http::CircuitBreakerClient;
16use crate::{ForgeError, Result};
17
/// Panic message used with `.expect(..)` when one of the context's internal
/// `RwLock`/`Mutex` guards is poisoned (another thread panicked while holding it).
const LOCK_POISONED: &str = "workflow lock poisoned";
19
/// Shareable, type-erased compensation ("undo") callback registered for a step.
///
/// The callback receives a JSON value (in `WorkflowContext::run_compensation` this is
/// the step's recorded result, or `serde_json::Value::Null` when none was recorded) and
/// returns a boxed `Send` future resolving to `Result<()>`.
pub type CompensationHandler = Arc<
    dyn Fn(serde_json::Value) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync,
>;
24
/// In-memory record of one workflow step's progress.
///
/// Status transitions are driven by the methods on this type (`start`, `complete`,
/// `fail`, `compensate`).
#[derive(Debug, Clone)]
pub struct StepState {
    /// Step name; normally also the key this state is stored under in the context.
    pub name: String,
    /// Current lifecycle status.
    pub status: StepStatus,
    /// JSON result recorded by `complete`, if any.
    pub result: Option<serde_json::Value>,
    /// Error message recorded by `fail`, if any.
    pub error: Option<String>,
    /// Timestamp set by `start`.
    pub started_at: Option<DateTime<Utc>>,
    /// Timestamp set by `complete` or `fail` (not touched by `compensate`).
    pub completed_at: Option<DateTime<Utc>>,
}
41
42impl StepState {
43 pub fn new(name: impl Into<String>) -> Self {
45 Self {
46 name: name.into(),
47 status: StepStatus::Pending,
48 result: None,
49 error: None,
50 started_at: None,
51 completed_at: None,
52 }
53 }
54
55 pub fn start(&mut self) {
57 self.status = StepStatus::Running;
58 self.started_at = Some(Utc::now());
59 }
60
61 pub fn complete(&mut self, result: serde_json::Value) {
63 self.status = StepStatus::Completed;
64 self.result = Some(result);
65 self.completed_at = Some(Utc::now());
66 }
67
68 pub fn fail(&mut self, error: impl Into<String>) {
70 self.status = StepStatus::Failed;
71 self.error = Some(error.into());
72 self.completed_at = Some(Utc::now());
73 }
74
75 pub fn compensate(&mut self) {
77 self.status = StepStatus::Compensated;
78 }
79}
80
/// Execution context handed to a workflow run.
///
/// Bundles run metadata, database/HTTP access, in-memory step bookkeeping,
/// compensation handlers, saved user state and the suspension signal.
#[non_exhaustive]
pub struct WorkflowContext {
    /// Identifier of this workflow run.
    pub run_id: Uuid,
    /// Name of the workflow being executed.
    pub workflow_name: String,
    /// When the run started (the current time for `new`, the supplied value for `resumed`).
    pub started_at: DateTime<Utc>,
    /// Value returned by `workflow_time()`; both constructors set it equal to `started_at`.
    workflow_time: DateTime<Utc>,
    /// Caller identity; unauthenticated unless replaced via `with_auth`.
    pub auth: AuthContext,
    /// Pool used for all step/event/run bookkeeping queries.
    db_pool: sqlx::PgPool,
    /// HTTP client handed out (with `http_timeout` applied) by `http()`.
    http_client: CircuitBreakerClient,
    /// Optional per-request timeout passed to `http_client.with_timeout`.
    http_timeout: Option<Duration>,
    /// Per-step state keyed by step name.
    step_states: Arc<RwLock<HashMap<String, StepState>>>,
    /// Names of completed steps; iterated in reverse by `run_compensation`.
    completed_steps: Arc<RwLock<Vec<String>>>,
    /// Compensation callbacks keyed by step name.
    compensation_handlers: Arc<RwLock<HashMap<String, CompensationHandler>>>,
    /// True when built via `resumed`; enables replay of already-consumed events.
    is_resumed: bool,
    /// When true, `sleep`/`sleep_until` return immediately.
    resumed_from_sleep: bool,
    /// Optional tenant scope set via `with_tenant`.
    tenant_id: Option<Uuid>,
    /// Source of environment variables for `EnvAccess`.
    env_provider: Arc<dyn EnvProvider>,
    /// Arbitrary JSON key/value state managed by `save_state`/`load_state`.
    saved_state: Arc<RwLock<HashMap<String, serde_json::Value>>>,
    /// Optional KV handle; `kv()` errors when absent.
    kv: Option<Arc<dyn KvHandle>>,
    /// Whether `record_step_start` also inserts a row into `forge_workflow_steps`.
    persist_step_start: bool,
    /// Reason recorded by `signal_suspend`, drained by `take_suspend_reason`.
    suspend_reason: Arc<std::sync::Mutex<Option<SuspendReason>>>,
}
121
122impl WorkflowContext {
123 pub fn new(
125 run_id: Uuid,
126 workflow_name: String,
127 db_pool: sqlx::PgPool,
128 http_client: CircuitBreakerClient,
129 ) -> Self {
130 let now = Utc::now();
131 Self {
132 run_id,
133 workflow_name,
134 started_at: now,
135 workflow_time: now,
136 auth: AuthContext::unauthenticated(),
137 db_pool,
138 http_client,
139 http_timeout: None,
140 step_states: Arc::new(RwLock::new(HashMap::new())),
141 completed_steps: Arc::new(RwLock::new(Vec::new())),
142 compensation_handlers: Arc::new(RwLock::new(HashMap::new())),
143
144 is_resumed: false,
145 resumed_from_sleep: false,
146 tenant_id: None,
147 env_provider: Arc::new(RealEnvProvider::new()),
148 saved_state: Arc::new(RwLock::new(HashMap::new())),
149 kv: None,
150 persist_step_start: false,
151 suspend_reason: Arc::new(std::sync::Mutex::new(None)),
152 }
153 }
154
155 pub fn with_persist_step_start(mut self, persist: bool) -> Self {
160 self.persist_step_start = persist;
161 self
162 }
163
164 pub fn resumed(
166 run_id: Uuid,
167 workflow_name: String,
168 started_at: DateTime<Utc>,
169 db_pool: sqlx::PgPool,
170 http_client: CircuitBreakerClient,
171 ) -> Self {
172 Self {
173 run_id,
174 workflow_name,
175 started_at,
176 workflow_time: started_at,
177 auth: AuthContext::unauthenticated(),
178 db_pool,
179 http_client,
180 http_timeout: None,
181 step_states: Arc::new(RwLock::new(HashMap::new())),
182 completed_steps: Arc::new(RwLock::new(Vec::new())),
183 compensation_handlers: Arc::new(RwLock::new(HashMap::new())),
184
185 is_resumed: true,
186 resumed_from_sleep: false,
187 tenant_id: None,
188 env_provider: Arc::new(RealEnvProvider::new()),
189 saved_state: Arc::new(RwLock::new(HashMap::new())),
190 kv: None,
191 persist_step_start: false,
192 suspend_reason: Arc::new(std::sync::Mutex::new(None)),
193 }
194 }
195
196 pub fn with_kv(mut self, kv: Arc<dyn KvHandle>) -> Self {
199 self.kv = Some(kv);
200 self
201 }
202
203 pub fn kv(&self) -> crate::error::Result<&dyn KvHandle> {
205 self.kv
206 .as_deref()
207 .ok_or_else(|| crate::error::ForgeError::internal("KV store not available"))
208 }
209
210 pub fn with_env_provider(mut self, provider: Arc<dyn EnvProvider>) -> Self {
212 self.env_provider = provider;
213 self
214 }
215
216 pub fn with_resumed_from_sleep(mut self) -> Self {
218 self.resumed_from_sleep = true;
219 self
220 }
221
222 pub fn with_tenant(mut self, tenant_id: Uuid) -> Self {
224 self.tenant_id = Some(tenant_id);
225 self
226 }
227
228 pub fn tenant_id(&self) -> Option<Uuid> {
229 self.tenant_id
230 }
231
232 pub fn is_resumed(&self) -> bool {
233 self.is_resumed
234 }
235
236 pub fn workflow_time(&self) -> DateTime<Utc> {
237 self.workflow_time
238 }
239
240 pub fn db(&self) -> crate::function::ForgeDb {
241 crate::function::ForgeDb::from_pool(&self.db_pool)
242 }
243
244 pub fn db_conn(&self) -> crate::function::DbConn<'_> {
246 crate::function::DbConn::Pool(self.db_pool.clone())
247 }
248
249 pub async fn conn(&self) -> sqlx::Result<crate::function::ForgeConn<'static>> {
251 Ok(crate::function::ForgeConn::Pool(
252 self.db_pool.acquire().await?,
253 ))
254 }
255
256 pub fn http(&self) -> crate::http::HttpClient {
257 self.http_client.with_timeout(self.http_timeout)
258 }
259
260 pub fn raw_http(&self) -> &reqwest::Client {
261 self.http_client.inner()
262 }
263
264 pub fn set_http_timeout(&mut self, timeout: Option<Duration>) {
265 self.http_timeout = timeout;
266 }
267
268 pub fn with_auth(mut self, auth: AuthContext) -> Self {
270 self.auth = auth;
271 self
272 }
273
274 pub fn with_saved_state(self, state: HashMap<String, serde_json::Value>) -> Self {
276 *self.saved_state.write().expect(LOCK_POISONED) = state;
277 self
278 }
279
280 pub fn save_state(&self, key: &str, value: impl serde::Serialize) -> crate::Result<()> {
282 let json = serde_json::to_value(value)
283 .map_err(|e| crate::ForgeError::Serialization(e.to_string()))?;
284 self.saved_state
285 .write()
286 .expect(LOCK_POISONED)
287 .insert(key.to_string(), json);
288 Ok(())
289 }
290
291 pub fn load_state<T: serde::de::DeserializeOwned>(
293 &self,
294 key: &str,
295 ) -> crate::Result<Option<T>> {
296 let guard = self.saved_state.read().expect(LOCK_POISONED);
297 match guard.get(key) {
298 Some(value) => {
299 let result = serde_json::from_value(value.clone())
300 .map_err(|e| crate::ForgeError::Deserialization(e.to_string()))?;
301 Ok(Some(result))
302 }
303 None => Ok(None),
304 }
305 }
306
307 pub fn take_saved_state(&self) -> HashMap<String, serde_json::Value> {
309 self.saved_state.read().expect(LOCK_POISONED).clone()
310 }
311
312 pub fn with_step_states(self, states: HashMap<String, StepState>) -> Self {
314 let completed: Vec<String> = states
315 .iter()
316 .filter(|(_, s)| s.status == StepStatus::Completed)
317 .map(|(name, _)| name.clone())
318 .collect();
319
320 *self.step_states.write().expect(LOCK_POISONED) = states;
321 *self.completed_steps.write().expect(LOCK_POISONED) = completed;
322 self
323 }
324
325 pub fn get_step_state(&self, name: &str) -> Option<StepState> {
326 self.step_states
327 .read()
328 .expect(LOCK_POISONED)
329 .get(name)
330 .cloned()
331 }
332
333 pub fn is_step_completed(&self, name: &str) -> bool {
334 self.step_states
335 .read()
336 .expect(LOCK_POISONED)
337 .get(name)
338 .map(|s| s.status == StepStatus::Completed)
339 .unwrap_or(false)
340 }
341
342 pub fn is_step_started(&self, name: &str) -> bool {
347 self.step_states
348 .read()
349 .expect(LOCK_POISONED)
350 .get(name)
351 .map(|s| s.status != StepStatus::Pending)
352 .unwrap_or(false)
353 }
354
355 pub fn get_step_result<T: serde::de::DeserializeOwned>(&self, name: &str) -> Option<T> {
356 self.step_states
357 .read()
358 .expect(LOCK_POISONED)
359 .get(name)
360 .and_then(|s| s.result.as_ref())
361 .and_then(|v| serde_json::from_value(v.clone()).ok())
362 }
363
364 pub async fn record_step_start(&self, name: &str) -> crate::Result<()> {
379 let state_clone = {
380 let mut states = self.step_states.write().expect(LOCK_POISONED);
381 let state = states
382 .entry(name.to_string())
383 .or_insert_with(|| StepState::new(name));
384
385 if state.status != StepStatus::Pending {
386 return Ok(());
387 }
388
389 state.start();
390 state.clone()
391 };
392
393 if !self.persist_step_start {
394 return Ok(());
395 }
396
397 let step_id = Uuid::new_v4();
398 let step_name = name.to_string();
399 sqlx::query!(
400 r#"
401 INSERT INTO forge_workflow_steps (id, workflow_run_id, step_name, status, started_at)
402 VALUES ($1, $2, $3, $4, $5)
403 ON CONFLICT (workflow_run_id, step_name) DO NOTHING
404 "#,
405 step_id,
406 self.run_id,
407 step_name,
408 state_clone.status.as_str(),
409 state_clone.started_at,
410 )
411 .execute(&self.db_pool)
412 .await
413 .map_err(crate::ForgeError::Database)?;
414 Ok(())
415 }
416
417 pub async fn record_step_complete(
422 &self,
423 name: &str,
424 result: serde_json::Value,
425 ) -> crate::Result<()> {
426 let state_clone = self.update_step_state_complete(name, result);
427
428 if let Some(state) = state_clone {
429 Self::persist_step_complete(&self.db_pool, self.run_id, name, &state).await?;
430 }
431 Ok(())
432 }
433
434 fn update_step_state_complete(
436 &self,
437 name: &str,
438 result: serde_json::Value,
439 ) -> Option<StepState> {
440 let mut states = self.step_states.write().expect(LOCK_POISONED);
441 if let Some(state) = states.get_mut(name) {
442 state.complete(result.clone());
443 }
444 let state_clone = states.get(name).cloned();
445 drop(states);
446
447 let mut completed = self.completed_steps.write().expect(LOCK_POISONED);
448 if !completed.contains(&name.to_string()) {
449 completed.push(name.to_string());
450 }
451 drop(completed);
452
453 state_clone
454 }
455
456 async fn persist_step_complete(
458 pool: &sqlx::PgPool,
459 run_id: Uuid,
460 step_name: &str,
461 state: &StepState,
462 ) -> crate::Result<()> {
463 sqlx::query!(
465 r#"
466 INSERT INTO forge_workflow_steps (id, workflow_run_id, step_name, status, result, started_at, completed_at)
467 VALUES (gen_random_uuid(), $1, $2, $3, $4, $5, $6)
468 ON CONFLICT (workflow_run_id, step_name) DO UPDATE
469 SET status = $3, result = $4, completed_at = $6
470 "#,
471 run_id,
472 step_name,
473 state.status.as_str(),
474 state.result as _,
475 state.started_at,
476 state.completed_at,
477 )
478 .execute(pool)
479 .await
480 .map_err(crate::ForgeError::Database)?;
481 Ok(())
482 }
483
484 pub async fn record_step_failure(
490 &self,
491 name: &str,
492 error: impl Into<String>,
493 ) -> crate::Result<()> {
494 let error_str = error.into();
495 let state_clone = {
496 let mut states = self.step_states.write().expect(LOCK_POISONED);
497 if let Some(state) = states.get_mut(name) {
498 state.fail(error_str.clone());
499 }
500 states.get(name).cloned()
501 };
502
503 if let Some(state) = state_clone {
504 let step_name = name.to_string();
505 sqlx::query!(
506 r#"
507 UPDATE forge_workflow_steps
508 SET status = $3, error = $4, completed_at = $5
509 WHERE workflow_run_id = $1 AND step_name = $2
510 "#,
511 self.run_id,
512 step_name,
513 state.status.as_str(),
514 state.error as _,
515 state.completed_at,
516 )
517 .execute(&self.db_pool)
518 .await
519 .map_err(crate::ForgeError::Database)?;
520 }
521 Ok(())
522 }
523
524 pub async fn record_step_compensated(&self, name: &str) -> crate::Result<()> {
532 let state_clone = {
533 let mut states = self.step_states.write().expect(LOCK_POISONED);
534 if let Some(state) = states.get_mut(name) {
535 state.compensate();
536 }
537 states.get(name).cloned()
538 };
539
540 if let Some(state) = state_clone {
541 let step_name = name.to_string();
542 sqlx::query!(
543 r#"
544 UPDATE forge_workflow_steps
545 SET status = $3
546 WHERE workflow_run_id = $1 AND step_name = $2
547 "#,
548 self.run_id,
549 step_name,
550 state.status.as_str(),
551 )
552 .execute(&self.db_pool)
553 .await
554 .map_err(crate::ForgeError::Database)?;
555 }
556 Ok(())
557 }
558
559 pub fn completed_steps_reversed(&self) -> Vec<String> {
560 let completed = self.completed_steps.read().expect(LOCK_POISONED);
561 completed.iter().rev().cloned().collect()
562 }
563
564 pub fn all_step_states(&self) -> HashMap<String, StepState> {
565 self.step_states.read().expect(LOCK_POISONED).clone()
566 }
567
568 pub fn elapsed(&self) -> chrono::Duration {
569 Utc::now() - self.started_at
570 }
571
572 pub fn register_compensation(&self, step_name: &str, handler: CompensationHandler) {
583 let mut handlers = self.compensation_handlers.write().expect(LOCK_POISONED);
584 handlers.insert(step_name.to_string(), handler);
585 }
586
587 pub fn get_compensation_handler(&self, step_name: &str) -> Option<CompensationHandler> {
588 self.compensation_handlers
589 .read()
590 .expect(LOCK_POISONED)
591 .get(step_name)
592 .cloned()
593 }
594
595 pub fn has_compensation(&self, step_name: &str) -> bool {
596 self.compensation_handlers
597 .read()
598 .expect(LOCK_POISONED)
599 .contains_key(step_name)
600 }
601
602 pub async fn run_compensation(&self) -> Vec<(String, bool)> {
605 let steps = self.completed_steps_reversed();
606 let mut results = Vec::new();
607
608 for step_name in steps {
609 let handler = self.get_compensation_handler(&step_name);
610 let result = self
611 .get_step_state(&step_name)
612 .and_then(|s| s.result.clone());
613
614 if let Some(handler) = handler {
615 let step_result = result.unwrap_or(serde_json::Value::Null);
616 match handler(step_result).await {
617 Ok(()) => match self.record_step_compensated(&step_name).await {
618 Ok(()) => results.push((step_name, true)),
619 Err(e) => {
620 tracing::error!(
621 step = %step_name,
622 error = %e,
623 "Failed to persist step compensation; marking compensation as failed",
624 );
625 results.push((step_name, false));
626 }
627 },
628 Err(e) => {
629 tracing::error!(step = %step_name, error = %e, "Compensation failed");
630 results.push((step_name, false));
631 }
632 }
633 } else {
634 match self.record_step_compensated(&step_name).await {
636 Ok(()) => results.push((step_name, true)),
637 Err(e) => {
638 tracing::error!(
639 step = %step_name,
640 error = %e,
641 "Failed to persist step compensation",
642 );
643 results.push((step_name, false));
644 }
645 }
646 }
647 }
648
649 results
650 }
651
652 pub fn compensation_handlers(&self) -> HashMap<String, CompensationHandler> {
653 self.compensation_handlers
654 .read()
655 .expect(LOCK_POISONED)
656 .clone()
657 }
658
659 pub async fn sleep(&self, duration: Duration) -> Result<()> {
670 if self.resumed_from_sleep {
672 return Ok(());
673 }
674
675 let wake_at = Utc::now() + chrono::Duration::from_std(duration).unwrap_or_default();
676 self.sleep_until(wake_at).await
677 }
678
679 pub async fn sleep_until(&self, wake_at: DateTime<Utc>) -> Result<()> {
690 if self.resumed_from_sleep {
692 return Ok(());
693 }
694
695 if wake_at <= Utc::now() {
697 return Ok(());
698 }
699
700 self.set_wake_at(wake_at).await?;
702
703 self.signal_suspend(SuspendReason::Sleep { wake_at })
705 .await?;
706
707 Ok(())
708 }
709
710 pub async fn wait_for_event<T: DeserializeOwned>(
723 &self,
724 event_name: &str,
725 timeout: Option<Duration>,
726 ) -> Result<T> {
727 let correlation_id = self.run_id.to_string();
728
729 if self.is_resumed
732 && let Some(event) = self
733 .find_consumed_event(event_name, &correlation_id)
734 .await?
735 {
736 return serde_json::from_value(event.payload.unwrap_or_default())
737 .map_err(|e| ForgeError::Deserialization(e.to_string()));
738 }
739
740 if let Some(event) = self.try_consume_event(event_name, &correlation_id).await? {
742 return serde_json::from_value(event.payload.unwrap_or_default())
743 .map_err(|e| ForgeError::Deserialization(e.to_string()));
744 }
745
746 let timeout_at =
748 timeout.map(|d| Utc::now() + chrono::Duration::from_std(d).unwrap_or_default());
749
750 self.set_waiting_for_event(event_name, timeout_at).await?;
752
753 self.signal_suspend(SuspendReason::WaitingEvent {
755 event_name: event_name.to_string(),
756 timeout: timeout_at,
757 })
758 .await?;
759
760 self.try_consume_event(event_name, &correlation_id)
762 .await?
763 .and_then(|e| e.payload)
764 .and_then(|p| serde_json::from_value(p).ok())
765 .ok_or_else(|| ForgeError::Timeout(format!("Event '{}' timed out", event_name)))
766 }
767
768 #[allow(clippy::type_complexity)]
770 async fn try_consume_event(
771 &self,
772 event_name: &str,
773 correlation_id: &str,
774 ) -> Result<Option<WorkflowEvent>> {
775 let result = sqlx::query!(
776 r#"
777 UPDATE forge_workflow_events
778 SET consumed_at = NOW(), consumed_by = $3
779 WHERE id = (
780 SELECT id FROM forge_workflow_events
781 WHERE event_name = $1 AND correlation_id = $2 AND consumed_at IS NULL
782 ORDER BY created_at ASC LIMIT 1
783 FOR UPDATE SKIP LOCKED
784 )
785 RETURNING id, event_name, correlation_id, payload, created_at
786 "#,
787 event_name,
788 correlation_id,
789 self.run_id
790 )
791 .fetch_optional(&self.db_pool)
792 .await
793 .map_err(ForgeError::Database)?;
794
795 Ok(result.map(|row| WorkflowEvent {
796 id: row.id,
797 event_name: row.event_name,
798 correlation_id: row.correlation_id,
799 payload: row.payload,
800 created_at: row.created_at,
801 }))
802 }
803
804 async fn find_consumed_event(
807 &self,
808 event_name: &str,
809 correlation_id: &str,
810 ) -> Result<Option<WorkflowEvent>> {
811 let result = sqlx::query!(
812 r#"
813 SELECT id, event_name, correlation_id, payload, created_at
814 FROM forge_workflow_events
815 WHERE event_name = $1 AND correlation_id = $2 AND consumed_by = $3
816 ORDER BY created_at DESC LIMIT 1
817 "#,
818 event_name,
819 correlation_id,
820 self.run_id
821 )
822 .fetch_optional(&self.db_pool)
823 .await
824 .map_err(ForgeError::Database)?;
825
826 Ok(result.map(|row| WorkflowEvent {
827 id: row.id,
828 event_name: row.event_name,
829 correlation_id: row.correlation_id,
830 payload: row.payload,
831 created_at: row.created_at,
832 }))
833 }
834
835 async fn set_wake_at(&self, wake_at: DateTime<Utc>) -> Result<()> {
837 sqlx::query!(
838 r#"
839 UPDATE forge_workflow_runs
840 SET status = 'sleeping', suspended_at = NOW(), wake_at = $2
841 WHERE id = $1
842 "#,
843 self.run_id,
844 wake_at,
845 )
846 .execute(&self.db_pool)
847 .await
848 .map_err(ForgeError::Database)?;
849
850 #[allow(clippy::disallowed_methods)]
853 if let Err(e) = sqlx::query("SELECT pg_notify('forge_workflow_wakeup', $1::text)")
854 .bind(self.run_id.to_string())
855 .execute(&self.db_pool)
856 .await
857 {
858 tracing::debug!(
859 workflow_run_id = %self.run_id,
860 error = %e,
861 "Failed to send workflow wakeup notify (scheduler will poll)",
862 );
863 }
864
865 Ok(())
866 }
867
868 async fn set_waiting_for_event(
870 &self,
871 event_name: &str,
872 timeout_at: Option<DateTime<Utc>>,
873 ) -> Result<()> {
874 sqlx::query!(
875 r#"
876 UPDATE forge_workflow_runs
877 SET status = 'waiting', suspended_at = NOW(), waiting_for_event = $2, event_timeout_at = $3
878 WHERE id = $1
879 "#,
880 self.run_id,
881 event_name,
882 timeout_at,
883 )
884 .execute(&self.db_pool)
885 .await
886 .map_err(ForgeError::Database)?;
887 Ok(())
888 }
889
890 async fn signal_suspend(&self, reason: SuspendReason) -> Result<()> {
898 *self.suspend_reason.lock().expect(LOCK_POISONED) = Some(reason.clone());
899 Err(ForgeError::WorkflowSuspended(reason))
900 }
901
902 pub fn take_suspend_reason(&self) -> Option<SuspendReason> {
907 self.suspend_reason.lock().expect(LOCK_POISONED).take()
908 }
909}
910
911impl EnvAccess for WorkflowContext {
912 fn env_provider(&self) -> &dyn EnvProvider {
913 self.env_provider.as_ref()
914 }
915}
916
// Unit tests for `StepState` and `WorkflowContext`.
//
// Contexts are backed by a lazily-connected pool pointing at a nonexistent database
// with a 1 ms acquire timeout. Anything that reaches the database is therefore expected
// to fail with `ForgeError::Database`, while pure in-memory behavior is exercised directly.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    // Constructor stores the run id and workflow name verbatim.
    #[tokio::test]
    async fn test_workflow_context_creation() {
        let pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .acquire_timeout(std::time::Duration::from_millis(1))
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool");

        let run_id = Uuid::new_v4();
        let ctx = WorkflowContext::new(
            run_id,
            "test_workflow".to_string(),
            pool,
            CircuitBreakerClient::with_defaults(reqwest::Client::new()),
        );

        assert_eq!(ctx.run_id, run_id);
        assert_eq!(ctx.workflow_name, "test_workflow");
    }

    // Start stays in memory (persistence disabled by default). Completion updates memory
    // first and then surfaces the database error from the unreachable pool.
    #[tokio::test]
    async fn test_step_state_tracking() {
        let pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .acquire_timeout(std::time::Duration::from_millis(1))
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool");

        let ctx = WorkflowContext::new(
            Uuid::new_v4(),
            "test".to_string(),
            pool,
            CircuitBreakerClient::with_defaults(reqwest::Client::new()),
        );

        ctx.record_step_start("step1")
            .await
            .expect("record_step_start should not touch db when persist disabled");
        assert!(!ctx.is_step_completed("step1"));

        let complete_err = ctx
            .record_step_complete("step1", serde_json::json!({"result": "ok"}))
            .await
            .expect_err("record_step_complete should propagate db errors");
        assert!(
            matches!(complete_err, crate::ForgeError::Database(_)),
            "expected Database error, got {complete_err:?}",
        );
        assert!(ctx.is_step_completed("step1"));

        let result: Option<serde_json::Value> = ctx.get_step_result("step1");
        assert!(result.is_some());
    }

    // Pending -> Running -> Completed, with timestamps set along the way.
    #[test]
    fn test_step_state_transitions() {
        let mut state = StepState::new("test");
        assert_eq!(state.status, StepStatus::Pending);

        state.start();
        assert_eq!(state.status, StepStatus::Running);
        assert!(state.started_at.is_some());

        state.complete(serde_json::json!({}));
        assert_eq!(state.status, StepStatus::Completed);
        assert!(state.completed_at.is_some());
    }

    // Shared helper: a fresh context over the unreachable lazy pool.
    fn lazy_ctx() -> WorkflowContext {
        let pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .acquire_timeout(std::time::Duration::from_millis(1))
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool");
        WorkflowContext::new(
            Uuid::new_v4(),
            "test".to_string(),
            pool,
            CircuitBreakerClient::with_defaults(reqwest::Client::new()),
        )
    }

    #[test]
    fn step_state_fail_records_error_and_completion() {
        let mut state = StepState::new("step");
        state.start();
        state.fail("boom");
        assert_eq!(state.status, StepStatus::Failed);
        assert_eq!(state.error.as_deref(), Some("boom"));
        assert!(state.completed_at.is_some());
    }

    // Compensation must not rewrite the completion timestamp.
    #[test]
    fn step_state_compensate_only_flips_status() {
        let mut state = StepState::new("step");
        state.complete(serde_json::json!({"ok": true}));
        let completed_at = state.completed_at;
        state.compensate();
        assert_eq!(state.status, StepStatus::Compensated);
        assert_eq!(state.completed_at, completed_at);
    }

    #[tokio::test]
    async fn save_state_and_load_state_round_trip() {
        let ctx = lazy_ctx();
        ctx.save_state("count", 42_u32).unwrap();
        let v: Option<u32> = ctx.load_state("count").unwrap();
        assert_eq!(v, Some(42));
    }

    #[tokio::test]
    async fn load_state_returns_none_for_unknown_key() {
        let ctx = lazy_ctx();
        let v: Option<String> = ctx.load_state("missing").unwrap();
        assert!(v.is_none());
    }

    #[tokio::test]
    async fn load_state_returns_deserialization_error_on_type_mismatch() {
        let ctx = lazy_ctx();
        ctx.save_state("k", "a string").unwrap();
        let err = ctx.load_state::<u32>("k").unwrap_err();
        assert!(matches!(err, ForgeError::Deserialization(_)));
    }

    #[tokio::test]
    async fn take_saved_state_returns_snapshot_of_all_entries() {
        let ctx = lazy_ctx();
        ctx.save_state("a", 1_u32).unwrap();
        ctx.save_state("b", "two").unwrap();
        let snap = ctx.take_saved_state();
        assert_eq!(snap.len(), 2);
        assert_eq!(snap.get("a"), Some(&serde_json::json!(1)));
        assert_eq!(snap.get("b"), Some(&serde_json::json!("two")));
    }

    #[tokio::test]
    async fn tenant_id_defaults_to_none_and_with_tenant_sets_it() {
        let ctx = lazy_ctx();
        assert!(ctx.tenant_id().is_none());
        let tenant = Uuid::new_v4();
        let ctx = ctx.with_tenant(tenant);
        assert_eq!(ctx.tenant_id(), Some(tenant));
    }

    #[tokio::test]
    async fn is_resumed_defaults_to_false() {
        let ctx = lazy_ctx();
        assert!(!ctx.is_resumed());
    }

    #[tokio::test]
    async fn is_step_completed_and_started_return_false_for_unknown_steps() {
        let ctx = lazy_ctx();
        assert!(!ctx.is_step_completed("nope"));
        assert!(!ctx.is_step_started("nope"));
    }

    #[tokio::test]
    async fn get_step_result_returns_none_for_unknown_step() {
        let ctx = lazy_ctx();
        let v: Option<serde_json::Value> = ctx.get_step_result("nope");
        assert!(v.is_none());
    }

    // Only `Completed` states end up in the compensation list.
    #[tokio::test]
    async fn with_step_states_rebuilds_completed_steps_from_status() {
        let ctx = lazy_ctx();
        let mut s = HashMap::new();
        let mut completed = StepState::new("done");
        completed.complete(serde_json::json!({"v": 1}));
        s.insert("done".to_string(), completed);
        let pending = StepState::new("pending");
        s.insert("pending".to_string(), pending);

        let ctx = ctx.with_step_states(s);
        assert!(ctx.is_step_completed("done"));
        assert!(!ctx.is_step_completed("pending"));

        let reversed = ctx.completed_steps_reversed();
        assert_eq!(reversed, vec!["done".to_string()]);
    }

    #[tokio::test]
    async fn completed_steps_reversed_is_empty_initially() {
        let ctx = lazy_ctx();
        assert!(ctx.completed_steps_reversed().is_empty());
    }

    #[tokio::test]
    async fn elapsed_is_non_negative() {
        let ctx = lazy_ctx();
        let e = ctx.elapsed();
        assert!(e.num_milliseconds() >= 0);
    }

    #[tokio::test]
    async fn register_and_has_compensation_round_trip() {
        let ctx = lazy_ctx();
        assert!(!ctx.has_compensation("step1"));
        let handler: CompensationHandler =
            Arc::new(|_v| Box::pin(async { Ok::<(), ForgeError>(()) }));
        ctx.register_compensation("step1", handler);
        assert!(ctx.has_compensation("step1"));
        assert!(ctx.get_compensation_handler("step1").is_some());
        assert!(ctx.get_compensation_handler("step2").is_none());
    }

    #[tokio::test]
    async fn all_step_states_returns_independent_clone() {
        let ctx = lazy_ctx();
        let mut s = HashMap::new();
        s.insert("a".to_string(), StepState::new("a"));
        let ctx = ctx.with_step_states(s);

        let snap = ctx.all_step_states();
        assert_eq!(snap.len(), 1);
        assert!(snap.contains_key("a"));
    }
}
1149}