Skip to main content

sayiir_runtime/runner/
distributed.rs

1//! Checkpointing workflow runner for single-process execution with persistence.
2//!
3//! This runner executes an entire workflow within a single process while saving
4//! snapshots after each task completion. This enables crash recovery and resumption
5//! without requiring multiple workers.
6//!
7//! **Use this when**: You want to run a workflow reliably on a single node with
8//! the ability to resume after crashes.
9//!
10//! **Use [`PooledWorker`](crate::worker::PooledWorker) instead when**: You need
11//! horizontal scaling with multiple workers collaborating on tasks.
12
13use std::ops::ControlFlow;
14use std::sync::Arc;
15
16use bytes::Bytes;
17use sayiir_core::codec::sealed;
18use sayiir_core::codec::{Codec, EnvelopeCodec};
19use sayiir_core::context::WorkflowContext;
20use sayiir_core::error::WorkflowError;
21use sayiir_core::snapshot::{ExecutionPosition, TaskHint, WorkflowSnapshot};
22use sayiir_core::workflow::{ConflictPolicy, Workflow, WorkflowContinuation, WorkflowStatus};
23use sayiir_persistence::PersistentBackend;
24
25use crate::error::RuntimeError;
26use crate::execution::control_flow::{
27    ParkReason, StepOutcome, StepResult, compute_signal_timeout, compute_wake_at,
28    save_branch_park_checkpoint, save_park_checkpoint,
29};
30use crate::execution::loop_runner::{
31    CheckpointingLoopHooks, LoopConfig, LoopExit, LoopNext, resolve_loop_iteration, run_loop_async,
32};
33use crate::execution::{
34    ForkBranchOutcome, JoinResolution, ResumeParkedPosition, branch_execute_or_skip_task,
35    check_guards, collect_cached_branches, execute_or_skip_task, finalize_execution,
36    get_resume_input, resolve_join, retry_with_checkpoint, set_deadline_if_needed,
37    settle_fork_outcome,
38};
39
/// A single-process workflow runner with checkpointing for crash recovery.
///
/// `CheckpointingRunner` executes an entire workflow within one process,
/// saving snapshots after each task. Fork branches run concurrently as tokio tasks.
/// If the process crashes, the workflow can be resumed from the last checkpoint.
///
/// # When to Use
///
/// - **Single-node execution**: One process runs the entire workflow
/// - **Crash recovery**: Resume from the last completed task after restart
/// - **Simple deployment**: No coordination between workers needed
///
/// For horizontal scaling with multiple workers, use [`PooledWorker`](crate::worker::PooledWorker).
///
/// # Example
///
/// ```rust,no_run
/// # use sayiir_runtime::prelude::*;
/// # use std::sync::Arc;
/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
/// let backend = InMemoryBackend::new();
/// let runner = CheckpointingRunner::new(backend);
///
/// let ctx = WorkflowContext::new("my-workflow", Arc::new(JsonCodec), Arc::new(()));
/// let workflow = WorkflowBuilder::new(ctx)
///     .then("step1", |i: u32| async move { Ok(i + 1) })
///     .build()?;
///
/// // Run workflow - snapshots are saved automatically
/// let status = runner.run(&workflow, "instance-123", 1u32).await?;
///
/// // Resume from checkpoint if needed (e.g., after crash)
/// let status = runner.resume(&workflow, "instance-123").await?;
/// # Ok(())
/// # }
/// ```
pub struct CheckpointingRunner<B> {
    // Persistence backend for saving/loading snapshots. Held in an `Arc` so
    // fork branches spawned as tokio tasks can each hold their own handle.
    backend: Arc<B>,
    // Policy applied when `run` is called with an instance ID that already
    // has a snapshot (see `with_conflict_policy`).
    conflict_policy: ConflictPolicy,
}
80
81impl<B> CheckpointingRunner<B>
82where
83    B: PersistentBackend,
84{
85    /// Create a new checkpointing runner with the given backend.
86    pub fn new(backend: B) -> Self {
87        Self {
88            backend: Arc::new(backend),
89            conflict_policy: ConflictPolicy::default(),
90        }
91    }
92
93    /// Create a runner from a shared backend reference.
94    ///
95    /// Useful when the same backend is shared with a [`WorkflowClient`](crate::WorkflowClient).
96    pub fn from_shared(backend: Arc<B>) -> Self {
97        Self {
98            backend,
99            conflict_policy: ConflictPolicy::default(),
100        }
101    }
102
103    /// Set the conflict policy for duplicate instance IDs.
104    #[must_use]
105    pub fn with_conflict_policy(mut self, policy: ConflictPolicy) -> Self {
106        self.conflict_policy = policy;
107        self
108    }
109
110    /// Get a reference to the backend.
111    #[must_use]
112    pub fn backend(&self) -> &Arc<B> {
113        &self.backend
114    }
115}
116
117impl<B> CheckpointingRunner<B>
118where
119    B: PersistentBackend + 'static,
120{
121    /// Run a workflow from the beginning, saving checkpoints after each task.
122    ///
123    /// The `instance_id` uniquely identifies this workflow execution instance.
124    /// The [`ConflictPolicy`] (set via [`with_conflict_policy`](Self::with_conflict_policy))
125    /// controls behaviour when a snapshot with this ID already exists.
126    ///
127    /// # Errors
128    ///
129    /// Returns an error if the workflow cannot be executed, if snapshot
130    /// operations fail, or if the conflict policy rejects a duplicate.
131    pub async fn run<C, Input, M>(
132        &self,
133        workflow: &Workflow<C, Input, M>,
134        instance_id: impl Into<String>,
135        input: Input,
136    ) -> Result<WorkflowStatus, RuntimeError>
137    where
138        Input: Send + 'static,
139        M: Send + Sync + 'static,
140        C: Codec
141            + EnvelopeCodec
142            + sealed::EncodeValue<Input>
143            + sealed::DecodeValue<Input>
144            + 'static,
145    {
146        use crate::{PrepareRunOutcome, check_existing_instance, prepare_run};
147
148        let instance_id = instance_id.into();
149        let definition_hash = workflow.definition_hash().to_string();
150        let conflict_policy = self.conflict_policy;
151
152        // Phase 1: check for existing instance before encoding input.
153        if let Some((status, _output)) = check_existing_instance(
154            &instance_id,
155            &definition_hash,
156            self.backend.as_ref(),
157            conflict_policy,
158        )
159        .await?
160        {
161            return Ok(status);
162        }
163
164        // Phase 2: encode input and prepare snapshot.
165        let input_bytes = workflow.context().codec.encode(&input)?;
166        let first_task = workflow.continuation().first_task_hint();
167
168        let mut snapshot = match prepare_run(
169            instance_id,
170            definition_hash,
171            input_bytes.clone(),
172            first_task,
173            self.backend.as_ref(),
174            conflict_policy,
175            true, // prechecked — check_existing_instance already ran
176        )
177        .await?
178        {
179            PrepareRunOutcome::Fresh(s) => *s,
180            PrepareRunOutcome::ExistingStatus(status, _output) => {
181                return Ok(status);
182            }
183        };
184
185        // Execute workflow with checkpointing
186        let context = workflow.context().clone();
187        let continuation = workflow.continuation();
188        let backend = Arc::clone(&self.backend);
189
190        let result = Self::execute_with_checkpointing(
191            continuation,
192            input_bytes,
193            &mut snapshot,
194            Arc::clone(&backend),
195            context,
196        )
197        .await;
198
199        let (status, _output) = finalize_execution(result, &mut snapshot, backend.as_ref()).await?;
200        Ok(status)
201    }
202
203    /// Resume a workflow from a saved snapshot.
204    ///
205    /// Loads the snapshot for the given instance ID and continues execution
206    /// from the last checkpoint.
207    ///
208    /// # Errors
209    ///
210    /// Returns an error if:
211    /// - The snapshot is not found
212    /// - The workflow definition hash doesn't match (workflow definition changed)
213    /// - The workflow cannot be resumed
214    #[allow(clippy::needless_lifetimes)]
215    pub async fn resume<'w, C, Input, M>(
216        &self,
217        workflow: &'w Workflow<C, Input, M>,
218        instance_id: &str,
219    ) -> Result<WorkflowStatus, RuntimeError>
220    where
221        Input: Send + 'static,
222        M: Send + Sync + 'static,
223        C: Codec
224            + EnvelopeCodec
225            + sealed::DecodeValue<Input>
226            + sealed::EncodeValue<Input>
227            + 'static,
228    {
229        // Load snapshot
230        let mut snapshot = self.backend.load_snapshot(instance_id).await?;
231
232        // Validate definition hash
233        if snapshot.definition_hash != workflow.definition_hash() {
234            return Err(WorkflowError::DefinitionMismatch {
235                expected: workflow.definition_hash().to_string(),
236                found: snapshot.definition_hash.clone(),
237            }
238            .into());
239        }
240
241        // Check if already in terminal state
242        if let Some(status) = snapshot.state.as_terminal_status() {
243            return Ok(status);
244        }
245
246        // Resolve any parked position (delay / fork) before resuming.
247        let parked = ResumeParkedPosition::extract(&snapshot);
248        if let Some(status) = parked
249            .resolve(&mut snapshot, instance_id, self.backend.as_ref())
250            .await?
251        {
252            return Ok(status);
253        }
254
255        // Resume execution
256        let context = workflow.context().clone();
257        let continuation = workflow.continuation();
258        let backend = Arc::clone(&self.backend);
259
260        // Get the last completed task's output or initial input
261        let input_bytes = get_resume_input(&snapshot)?;
262
263        let result = Self::execute_with_checkpointing(
264            continuation,
265            input_bytes,
266            &mut snapshot,
267            Arc::clone(&backend),
268            context,
269        )
270        .await;
271
272        let (status, _output) = finalize_execution(result, &mut snapshot, backend.as_ref()).await?;
273        Ok(status)
274    }
275
    /// Execute continuation with checkpointing after each task (iterative, no boxing).
    ///
    /// Walks the continuation chain in a loop. Each `match` arm yields a
    /// [`StepResult`]:
    /// - `ControlFlow::Continue(bytes)` — advance to the next continuation,
    ///   feeding it `bytes` as input;
    /// - `ControlFlow::Break(StepOutcome::Done(bytes))` — workflow finished;
    /// - `ControlFlow::Break(StepOutcome::Park(reason))` — execution must stop
    ///   and wait (delay / signal); the park is persisted and surfaced as an
    ///   `Err` by [`save_park_checkpoint`].
    ///
    /// Loop bodies and child workflows re-enter this function recursively via
    /// `Box::pin` (an async fn cannot recurse unboxed).
    #[allow(clippy::manual_async_fn, clippy::too_many_lines)]
    async fn execute_with_checkpointing<'a, C, M>(
        continuation: &'a WorkflowContinuation,
        input: Bytes,
        snapshot: &'a mut WorkflowSnapshot,
        backend: Arc<B>,
        context: WorkflowContext<C, M>,
    ) -> Result<Bytes, RuntimeError>
    where
        B: 'static,
        C: Codec + EnvelopeCodec + 'static,
        M: Send + Sync + 'static,
    {
        // `current` is the continuation being executed; `current_input` holds
        // the bytes produced by the previous step.
        let mut current = continuation;
        let mut current_input = input;

        loop {
            let step: StepResult = match current {
                // Implemented task: run it (with retry/timeout handling via
                // `retry_with_checkpoint`), record the next position, and
                // checkpoint the snapshot.
                WorkflowContinuation::Task {
                    id,
                    func: Some(func),
                    timeout,
                    retry_policy,
                    ..
                } => {
                    check_guards(backend.as_ref(), &snapshot.instance_id, Some(id)).await?;
                    set_deadline_if_needed(id, timeout.as_ref(), snapshot, backend.as_ref())
                        .await?;

                    // The closure is re-invoked per retry attempt;
                    // `execute_or_skip_task` reuses a cached result when one
                    // exists in the snapshot (see crate::execution).
                    let output = retry_with_checkpoint(
                        id,
                        retry_policy.as_ref(),
                        timeout.as_ref(),
                        snapshot,
                        Some(backend.as_ref()),
                        async |snap| {
                            execute_or_skip_task(id, current_input.clone(), |i| func.run(i), snap)
                                .await
                        },
                    )
                    .await?;

                    // Point the snapshot at the next task before saving, so a
                    // crash after this save resumes at the correct position.
                    if let Some(next_cont) = current.get_next() {
                        let next_id = next_cont.first_task_id().to_string();
                        snapshot.set_task_hint(&TaskHint {
                            id: next_id.clone(),
                            priority: continuation.get_task_priority(&next_id),
                            tags: continuation.get_task_tags(&next_id),
                        });
                        snapshot.update_position(ExecutionPosition::AtTask { task_id: next_id });
                    }
                    backend.save_snapshot(snapshot).await?;
                    // Re-check guards after the checkpoint (no specific task id).
                    check_guards(backend.as_ref(), &snapshot.instance_id, None).await?;

                    Ok(ControlFlow::Continue(output))
                }
                // A task without an implementation cannot be executed here.
                WorkflowContinuation::Task { func: None, id, .. } => {
                    return Err(WorkflowError::TaskNotImplemented(id.clone()).into());
                }
                // Delay: if already satisfied (result cached), pass the input
                // through; otherwise park until the computed wake time.
                WorkflowContinuation::Delay { id, duration, next } => {
                    check_guards(backend.as_ref(), &snapshot.instance_id, Some(id)).await?;

                    if snapshot.get_task_result(id).is_some() {
                        Ok(ControlFlow::Continue(current_input.clone()))
                    } else {
                        let wake_at = compute_wake_at(duration)?;
                        Ok(ControlFlow::Break(StepOutcome::Park(ParkReason::Delay {
                            delay_id: id.clone(),
                            wake_at,
                            next_task: next.as_deref().map(WorkflowContinuation::first_task_hint),
                            passthrough: current_input.clone(),
                        })))
                    }
                }
                // Signal wait: consume a pending event if one was delivered,
                // otherwise park until the signal (or its timeout) arrives.
                WorkflowContinuation::AwaitSignal {
                    id,
                    signal_name,
                    timeout,
                    next,
                } => {
                    check_guards(backend.as_ref(), &snapshot.instance_id, Some(id)).await?;

                    if snapshot.get_task_result(id).is_some() {
                        // Already satisfied on a previous run: reuse the stored
                        // payload (falling back to the current input).
                        let payload = snapshot
                            .get_task_result_bytes(id)
                            .unwrap_or(current_input.clone());
                        Ok(ControlFlow::Continue(payload))
                    } else {
                        match backend
                            .consume_event(&snapshot.instance_id, signal_name)
                            .await
                        {
                            Ok(Some(payload)) => {
                                // Signal was delivered: record it, advance the
                                // position, and checkpoint before continuing.
                                snapshot.mark_task_completed(id.clone(), payload);
                                if let Some(next_cont) = next.as_deref() {
                                    let next_id = next_cont.first_task_id().to_string();
                                    snapshot.set_task_hint(&TaskHint {
                                        id: next_id.clone(),
                                        priority: continuation.get_task_priority(&next_id),
                                        tags: continuation.get_task_tags(&next_id),
                                    });
                                    snapshot.update_position(ExecutionPosition::AtTask {
                                        task_id: next_id,
                                    });
                                }
                                backend.save_snapshot(snapshot).await?;
                                let output = snapshot
                                    .get_task_result_bytes(id)
                                    .unwrap_or(current_input.clone());
                                Ok(ControlFlow::Continue(output))
                            }
                            // No event yet: park awaiting the signal.
                            Ok(None) => Ok(ControlFlow::Break(StepOutcome::Park(
                                ParkReason::AwaitingSignal {
                                    signal_id: id.clone(),
                                    signal_name: signal_name.clone(),
                                    timeout: compute_signal_timeout(timeout.as_ref()),
                                    next_task: next
                                        .as_deref()
                                        .map(WorkflowContinuation::first_task_hint),
                                },
                            ))),
                            Err(e) => Err(RuntimeError::from(e)),
                        }
                    }
                }
                // Fork: run all branches (reusing cached results if every
                // branch completed previously), then resolve the join.
                WorkflowContinuation::Fork {
                    id: fork_id,
                    branches,
                    join,
                } => {
                    check_guards(backend.as_ref(), &snapshot.instance_id, None).await?;

                    let branch_results =
                        if let Some(cached) = collect_cached_branches(branches, snapshot) {
                            cached
                        } else {
                            // `settle_fork_outcome` persists the results and
                            // handles branches that reported a wake-up time.
                            let outcome = Self::execute_fork_branches_parallel(
                                branches,
                                &current_input,
                                snapshot,
                                &backend,
                                &context,
                            )
                            .await?;
                            settle_fork_outcome(
                                fork_id,
                                outcome,
                                join.as_deref(),
                                snapshot,
                                backend.as_ref(),
                            )
                            .await?
                        };

                    match resolve_join(join.as_deref(), &branch_results, context.codec.as_ref())? {
                        JoinResolution::Continue { input, .. } => Ok(ControlFlow::Continue(input)),
                        JoinResolution::Done(output) => {
                            Ok(ControlFlow::Break(StepOutcome::Done(output)))
                        }
                    }
                }
                // Conditional branch: evaluate the key function, pick the
                // matching branch (or the default), and run it to completion.
                WorkflowContinuation::Branch {
                    id,
                    key_fn: Some(key_fn),
                    branches,
                    default,
                    ..
                } => {
                    check_guards(backend.as_ref(), &snapshot.instance_id, Some(id)).await?;

                    if let Some(result) = snapshot.get_task_result(id) {
                        // Branch already resolved on a previous run.
                        Ok(ControlFlow::Continue(result.output.clone()))
                    } else {
                        let key_bytes = key_fn
                            .run(current_input.clone())
                            .await
                            .map_err(RuntimeError::from)?;
                        let key: String = context
                            .codec
                            .decode_string(&key_bytes)
                            .map_err(RuntimeError::from)?;

                        // Unknown key with no default branch is an error.
                        let chosen = branches.get(&key).or(default.as_ref()).ok_or_else(|| {
                            WorkflowError::BranchKeyNotFound {
                                branch_id: id.clone(),
                                key: key.clone(),
                            }
                        })?;

                        let branch_output = Self::execute_branch_with_checkpoint(
                            chosen,
                            current_input.clone(),
                            Arc::clone(&backend),
                            snapshot.instance_id.clone(),
                            context.clone(),
                        )
                        .await?;

                        // The result is stored as a (key, output) envelope so
                        // downstream steps can tell which branch ran.
                        let envelope_bytes = context
                            .codec
                            .encode_branch_envelope(&key, &branch_output)
                            .map_err(RuntimeError::from)?;

                        snapshot.mark_task_completed(id.clone(), envelope_bytes.clone());
                        backend.save_snapshot(snapshot).await?;

                        Ok(ControlFlow::Continue(envelope_bytes))
                    }
                }
                // A branch without a key function cannot be evaluated.
                WorkflowContinuation::Branch {
                    key_fn: None, id, ..
                } => {
                    return Err(WorkflowError::TaskNotImplemented(
                        sayiir_core::workflow::key_fn_id(id),
                    )
                    .into());
                }
                // Loop: execute the body up to `max_iterations` times, saving
                // the iteration counter after every pass so a crash resumes at
                // the correct iteration.
                WorkflowContinuation::Loop {
                    id,
                    body,
                    max_iterations,
                    on_max,
                    ..
                } => {
                    check_guards(backend.as_ref(), &snapshot.instance_id, Some(id)).await?;

                    if let Some(result) = snapshot.get_task_result(id) {
                        // Loop already completed on a previous run.
                        Ok(ControlFlow::Continue(result.output.clone()))
                    } else {
                        let cfg = LoopConfig {
                            id,
                            body,
                            max_iterations: *max_iterations,
                            on_max: *on_max,
                            // Resume from the persisted iteration counter.
                            start_iteration: snapshot.loop_iteration(id),
                        };
                        let mut loop_input = current_input.clone();
                        let mut final_output = None;

                        for iteration in cfg.start_iteration..cfg.max_iterations {
                            // Recurse into the body (boxed: async recursion).
                            let output = Box::pin(Self::execute_with_checkpointing(
                                body,
                                loop_input.clone(),
                                snapshot,
                                Arc::clone(&backend),
                                context.clone(),
                            ))
                            .await?;

                            // Clear the body's cached task results so the next
                            // iteration re-executes them instead of skipping.
                            let body_ser = body.to_serializable();
                            for tid in &body_ser.task_ids() {
                                snapshot.remove_task_result(tid);
                            }

                            match resolve_loop_iteration(&output, iteration, &cfg)? {
                                // Loop finished: record its final result and
                                // drop the iteration counter.
                                ControlFlow::Break(LoopExit(inner)) => {
                                    snapshot.clear_loop_iteration(id);
                                    snapshot.mark_task_completed(id.clone(), inner.clone());
                                    backend.save_snapshot(snapshot).await?;
                                    final_output = Some(inner);
                                    break;
                                }
                                // Another pass: persist the advanced counter
                                // before running the body again.
                                ControlFlow::Continue(LoopNext(inner)) => {
                                    snapshot.set_loop_iteration(id, iteration + 1);
                                    snapshot.update_position(ExecutionPosition::InLoop {
                                        loop_id: id.clone(),
                                        iteration: iteration + 1,
                                        next_task_id: Some(body.first_task_id().to_string()),
                                    });
                                    backend.save_snapshot(snapshot).await?;
                                    loop_input = inner;
                                }
                            }
                        }

                        // Falling out of the `for` without a break means the
                        // iteration budget was exhausted without an exit.
                        match final_output {
                            Some(output) => Ok(ControlFlow::Continue(output)),
                            None => Err(RuntimeError::from(WorkflowError::MaxIterationsExceeded {
                                loop_id: id.clone(),
                                max_iterations: *max_iterations,
                            })),
                        }
                    }
                }
                // Child workflow: run the child continuation inline (boxed
                // recursion) and cache its output under the child's id.
                WorkflowContinuation::ChildWorkflow { id, child, .. } => {
                    check_guards(backend.as_ref(), &snapshot.instance_id, Some(id)).await?;

                    if let Some(result) = snapshot.get_task_result(id) {
                        Ok(ControlFlow::Continue(result.output.clone()))
                    } else {
                        let output = Box::pin(Self::execute_with_checkpointing(
                            child,
                            current_input.clone(),
                            snapshot,
                            Arc::clone(&backend),
                            context.clone(),
                        ))
                        .await?;

                        snapshot.mark_task_completed(id.clone(), output.clone());
                        backend.save_snapshot(snapshot).await?;

                        Ok(ControlFlow::Continue(output))
                    }
                }
            };

            // Dispatch the step outcome: advance, finish, or park.
            match step? {
                ControlFlow::Continue(output) => match current.get_next() {
                    Some(next) => {
                        current = next;
                        current_input = output;
                    }
                    // End of the chain: the last output is the workflow result.
                    None => return Ok(output),
                },
                ControlFlow::Break(StepOutcome::Done(output)) => return Ok(output),
                ControlFlow::Break(StepOutcome::Park(reason)) => {
                    // Persist the park checkpoint; the returned error carries
                    // the park to the caller (see the Waiting handling in
                    // `execute_fork_branches_parallel`).
                    return Err(save_park_checkpoint(reason, snapshot, backend.as_ref()).await);
                }
            }
        }
    }
599
600    /// Execute fork branches in parallel using tokio tasks.
601    async fn execute_fork_branches_parallel<C, M>(
602        branches: &[Arc<WorkflowContinuation>],
603        input: &Bytes,
604        snapshot: &WorkflowSnapshot,
605        backend: &Arc<B>,
606        context: &WorkflowContext<C, M>,
607    ) -> Result<ForkBranchOutcome, RuntimeError>
608    where
609        B: 'static,
610        C: Codec + EnvelopeCodec + 'static,
611        M: Send + Sync + 'static,
612    {
613        let mut branch_results = Vec::with_capacity(branches.len());
614        let mut set = tokio::task::JoinSet::new();
615        let instance_id = snapshot.instance_id.clone();
616
617        for branch in branches {
618            let branch_id = branch.id().to_string();
619
620            if let Some(result) = snapshot.get_task_result(&branch_id) {
621                branch_results.push((branch_id, result.output.clone()));
622            } else {
623                let branch = Arc::clone(branch);
624                let branch_input = input.clone();
625                let branch_backend = Arc::clone(backend);
626                let branch_instance_id = instance_id.clone();
627                let ctx_for_work = context.clone();
628
629                set.spawn(async move {
630                    let result = Self::execute_branch_with_checkpoint(
631                        &branch,
632                        branch_input,
633                        branch_backend,
634                        branch_instance_id,
635                        ctx_for_work,
636                    )
637                    .await?;
638                    Ok((branch_id, result))
639                });
640            }
641        }
642
643        let mut max_wake_at: Option<chrono::DateTime<chrono::Utc>> = None;
644
645        while let Some(result) = set.join_next().await {
646            match result {
647                Ok(Ok((branch_id, output))) => {
648                    branch_results.push((branch_id, output));
649                }
650                Ok(Err(RuntimeError::Workflow(WorkflowError::Waiting { wake_at }))) => {
651                    max_wake_at = Some(match max_wake_at {
652                        Some(existing) => existing.max(wake_at),
653                        None => wake_at,
654                    });
655                }
656                Ok(Err(e)) => return Err(e),
657                Err(join_err) => return Err(RuntimeError::from(join_err)),
658            }
659        }
660
661        Ok(ForkBranchOutcome {
662            results: branch_results,
663            max_wake_at,
664        })
665    }
666
667    /// Execute nested fork branches in parallel within a branch.
668    ///
669    /// Spawns each branch as a tokio task, collects all results, and propagates
670    /// errors (including `JoinError`).
671    async fn execute_nested_fork_branches<C, M>(
672        branches: &[Arc<WorkflowContinuation>],
673        input: &Bytes,
674        backend: &Arc<B>,
675        instance_id: &str,
676        context: &WorkflowContext<C, M>,
677    ) -> Result<Vec<(String, Bytes)>, RuntimeError>
678    where
679        B: 'static,
680        C: Codec + EnvelopeCodec + 'static,
681        M: Send + Sync + 'static,
682    {
683        let mut set: tokio::task::JoinSet<Result<(String, Bytes), RuntimeError>> =
684            tokio::task::JoinSet::new();
685        for branch in branches {
686            let id = branch.id().to_string();
687            let branch = Arc::clone(branch);
688            let branch_input = input.clone();
689            let branch_backend = Arc::clone(backend);
690            let branch_instance_id = instance_id.to_string();
691            let ctx_for_work = context.clone();
692
693            set.spawn(async move {
694                let result = Self::execute_branch_with_checkpoint(
695                    &branch,
696                    branch_input,
697                    branch_backend,
698                    branch_instance_id,
699                    ctx_for_work,
700                )
701                .await?;
702                Ok((id, result))
703            });
704        }
705
706        let mut branch_results: Vec<(String, Bytes)> = Vec::with_capacity(set.len());
707        while let Some(res) = set.join_next().await {
708            branch_results.push(res??);
709        }
710        Ok(branch_results)
711    }
712
713    /// Execute branch continuation with per-task checkpointing (iterative, no boxing).
714    ///
715    /// Unlike `execute_with_checkpointing`, this doesn't update position tracking
716    /// (branches run independently). It saves each task result directly to the backend.
717    ///
718    /// On resume after `AtFork`, the backend snapshot contains sub-task results from
719    /// the previous execution. This function loads the snapshot to skip cached tasks
720    /// and parks at delays instead of sleeping through them.
721    #[allow(clippy::manual_async_fn, clippy::too_many_lines)]
    fn execute_branch_with_checkpoint<C, M>(
        continuation: &WorkflowContinuation,
        input: Bytes,
        backend: Arc<B>,
        instance_id: String,
        context: WorkflowContext<C, M>,
    ) -> impl std::future::Future<Output = Result<Bytes, RuntimeError>> + Send + '_
    where
        B: 'static,
        C: Codec + EnvelopeCodec + 'static,
        M: Send + Sync + 'static,
    {
        async move {
            // Load the shared instance snapshot so sub-steps completed by a
            // previous execution of this branch (e.g. before a crash or an
            // `AtFork` park) can be skipped instead of re-run.
            let mut snapshot = backend.load_snapshot(&instance_id).await?;

            // Iterative walk over the continuation chain. `current` is the
            // node being executed; `current_input` carries the encoded output
            // of the previous step into the next one.
            let mut current = continuation;
            let mut current_input = input;

            loop {
                let step: StepResult = match current {
                    WorkflowContinuation::Task {
                        id,
                        func,
                        timeout,
                        retry_policy,
                        ..
                    } => {
                        // A task node without an attached function cannot run.
                        let func = func
                            .as_ref()
                            .ok_or_else(|| WorkflowError::TaskNotImplemented(id.clone()))?;

                        // Inner loop: execute (or skip, if cached) the task,
                        // retrying per `retry_policy` until success or the
                        // retry budget is exhausted.
                        let output = loop {
                            match branch_execute_or_skip_task(
                                id,
                                current_input.clone(),
                                |i| func.run(i),
                                timeout.as_ref(),
                                &mut snapshot,
                                &instance_id,
                                backend.as_ref(),
                            )
                            .await
                            {
                                Ok(output) => {
                                    // Success: discard any retry bookkeeping
                                    // accumulated for this task.
                                    snapshot.clear_retry_state(id);
                                    break output;
                                }
                                Err(e) => {
                                    if let Some(rp) = retry_policy
                                        && !snapshot.retries_exhausted(id)
                                    {
                                        // Record the failed attempt, then sleep in-process
                                        // until the computed retry time (branches retry by
                                        // sleeping here rather than parking).
                                        // NOTE(review): the retry state recorded below lives in
                                        // the in-memory snapshot; whether it is persisted before
                                        // the sleep depends on `branch_execute_or_skip_task` /
                                        // `record_retry` internals — confirm crash-during-sleep
                                        // does not lose the attempt count.
                                        let next_retry_at =
                                            snapshot.record_retry(id, rp, &e.to_string(), None);
                                        snapshot.clear_task_deadline();
                                        tracing::info!(
                                            task_id = %id,
                                            attempt = snapshot.get_retry_state(id).map_or(0, |rs| rs.attempts),
                                            max_retries = rp.max_retries,
                                            %next_retry_at,
                                            error = %e,
                                            "Retrying task (branch)"
                                        );
                                        // A wake time already in the past maps to a zero sleep.
                                        let delay = (next_retry_at - chrono::Utc::now())
                                            .to_std()
                                            .unwrap_or_default();
                                        tokio::time::sleep(delay).await;
                                        continue;
                                    }
                                    // No retry policy, or retries exhausted:
                                    // the branch fails with the task error.
                                    return Err(e);
                                }
                            }
                        };
                        Ok(ControlFlow::Continue(output))
                    }
                    WorkflowContinuation::Delay { id, duration, .. } => {
                        if let Some(result) = snapshot.get_task_result(id) {
                            // Delay already elapsed in a prior run: pass its
                            // stored pass-through output along.
                            tracing::debug!(delay_id = %id, "delay already completed in branch, skipping");
                            Ok(ControlFlow::Continue(result.output.clone()))
                        } else {
                            // Park at the delay instead of sleeping through it.
                            let wake_at = compute_wake_at(duration)?;
                            Ok(ControlFlow::Break(StepOutcome::Park(ParkReason::Delay {
                                delay_id: id.clone(),
                                wake_at,
                                next_task: None,
                                passthrough: current_input.clone(),
                            })))
                        }
                    }
                    WorkflowContinuation::AwaitSignal {
                        id,
                        signal_name,
                        timeout,
                        ..
                    } => {
                        if let Some(result) = snapshot.get_task_result(id) {
                            // Signal was already delivered and consumed in a
                            // prior run: continue with its stored payload.
                            tracing::debug!(signal_id = %id, %signal_name, "signal already consumed in branch, skipping");
                            Ok(ControlFlow::Continue(result.output.clone()))
                        } else {
                            // Park until the named signal arrives (or the
                            // optional timeout elapses).
                            let wake_at = compute_signal_timeout(timeout.as_ref());
                            Ok(ControlFlow::Break(StepOutcome::Park(
                                ParkReason::AwaitingSignal {
                                    signal_id: id.clone(),
                                    signal_name: signal_name.clone(),
                                    timeout: wake_at,
                                    next_task: None,
                                },
                            )))
                        }
                    }
                    WorkflowContinuation::Fork { branches, join, .. } => {
                        // Nested fork inside a branch: run the sub-branches,
                        // then combine their results through the join.
                        let branch_results = Self::execute_nested_fork_branches(
                            branches,
                            &current_input,
                            &backend,
                            &instance_id,
                            &context,
                        )
                        .await?;

                        match resolve_join(
                            join.as_deref(),
                            &branch_results,
                            context.codec.as_ref(),
                        )? {
                            JoinResolution::Continue { input, .. } => {
                                Ok(ControlFlow::Continue(input))
                            }
                            JoinResolution::Done(output) => {
                                Ok(ControlFlow::Break(StepOutcome::Done(output)))
                            }
                        }
                    }
                    WorkflowContinuation::Branch {
                        id,
                        key_fn: Some(key_fn),
                        branches,
                        default,
                        ..
                    } => {
                        if let Some(result) = snapshot.get_task_result(id) {
                            // Whole branch node already resolved: reuse the
                            // stored envelope output.
                            Ok(ControlFlow::Continue(result.output.clone()))
                        } else {
                            // Evaluate the key function and decode its result
                            // into the routing key string.
                            let key_bytes = key_fn
                                .run(current_input.clone())
                                .await
                                .map_err(RuntimeError::from)?;
                            let key: String = context
                                .codec
                                .decode_string(&key_bytes)
                                .map_err(RuntimeError::from)?;

                            // Pick the matching branch, falling back to the
                            // default; no match is a hard error.
                            let chosen =
                                branches.get(&key).or(default.as_ref()).ok_or_else(|| {
                                    WorkflowError::BranchKeyNotFound {
                                        branch_id: id.clone(),
                                        key: key.clone(),
                                    }
                                })?;

                            // Box::pin breaks the recursive async type cycle.
                            let branch_output = Box::pin(Self::execute_branch_with_checkpoint(
                                chosen,
                                current_input.clone(),
                                Arc::clone(&backend),
                                instance_id.clone(),
                                context.clone(),
                            ))
                            .await?;

                            // Wrap the output together with the chosen key so
                            // downstream consumers know which path ran.
                            let envelope_bytes = context
                                .codec
                                .encode_branch_envelope(&key, &branch_output)
                                .map_err(RuntimeError::from)?;

                            // Checkpoint the resolved branch node as a whole.
                            snapshot.mark_task_completed(id.clone(), envelope_bytes.clone());
                            backend.save_snapshot(&snapshot).await?;

                            Ok(ControlFlow::Continue(envelope_bytes))
                        }
                    }
                    WorkflowContinuation::Branch {
                        key_fn: None, id, ..
                    } => {
                        // A branch node without a key function cannot route.
                        return Err(WorkflowError::TaskNotImplemented(
                            sayiir_core::workflow::key_fn_id(id),
                        )
                        .into());
                    }
                    WorkflowContinuation::Loop {
                        id,
                        body,
                        max_iterations,
                        on_max,
                        ..
                    } => {
                        if let Some(result) = snapshot.get_task_result(id) {
                            // Loop already completed in a prior run.
                            Ok(ControlFlow::Continue(result.output.clone()))
                        } else {
                            // Resume from the iteration recorded in the
                            // snapshot rather than restarting at zero.
                            let cfg = LoopConfig {
                                id,
                                body,
                                max_iterations: *max_iterations,
                                on_max: *on_max,
                                start_iteration: snapshot.loop_iteration(id),
                            };
                            // track_position is false: branches run
                            // independently and do not update the workflow's
                            // position tracking (see fn docs above).
                            let mut hooks = CheckpointingLoopHooks {
                                snapshot: &mut snapshot,
                                backend: backend.as_ref(),
                                track_position: false,
                            };
                            let output = run_loop_async(
                                &cfg,
                                current_input.clone(),
                                |input| {
                                    Box::pin(Self::execute_branch_with_checkpoint(
                                        body,
                                        input,
                                        Arc::clone(&backend),
                                        instance_id.clone(),
                                        context.clone(),
                                    ))
                                },
                                &mut hooks,
                            )
                            .await?;
                            Ok(ControlFlow::Continue(output))
                        }
                    }
                    WorkflowContinuation::ChildWorkflow { id, child, .. } => {
                        if let Some(result) = snapshot.get_task_result(id) {
                            // Child workflow already completed: reuse output.
                            Ok(ControlFlow::Continue(result.output.clone()))
                        } else {
                            // Recurse into the child chain (boxed, as above).
                            let output = Box::pin(Self::execute_branch_with_checkpoint(
                                child,
                                current_input.clone(),
                                Arc::clone(&backend),
                                instance_id.clone(),
                                context.clone(),
                            ))
                            .await?;

                            // Checkpoint the child workflow node as a whole.
                            snapshot.mark_task_completed(id.clone(), output.clone());
                            backend.save_snapshot(&snapshot).await?;

                            Ok(ControlFlow::Continue(output))
                        }
                    }
                };

                // Apply the step result: advance to the next node, finish the
                // branch, or persist a park checkpoint and bail out.
                match step? {
                    ControlFlow::Continue(output) => match current.get_next() {
                        Some(next) => {
                            current = next;
                            current_input = output;
                        }
                        // End of the chain: this is the branch's final output.
                        None => return Ok(output),
                    },
                    ControlFlow::Break(StepOutcome::Done(output)) => return Ok(output),
                    ControlFlow::Break(StepOutcome::Park(reason)) => {
                        // Save the park checkpoint and surface it as an error
                        // so the caller can settle this branch accordingly.
                        return Err(save_branch_park_checkpoint(
                            reason,
                            &instance_id,
                            backend.as_ref(),
                        )
                        .await);
                    }
                }
            }
        }
    }
991}
992
993#[cfg(test)]
994#[allow(
995    clippy::unwrap_used,
996    clippy::expect_used,
997    clippy::panic,
998    clippy::indexing_slicing,
999    clippy::too_many_lines,
1000    clippy::manual_let_else
1001)]
1002mod tests {
1003    use super::*;
1004    use crate::serialization::JsonCodec;
1005    use sayiir_core::codec::Encoder;
1006    use sayiir_core::context::WorkflowContext;
1007    use sayiir_core::error::BoxError;
1008    use sayiir_core::snapshot::SignalKind;
1009    use sayiir_core::snapshot::WorkflowSnapshotState;
1010    use sayiir_core::task::BranchOutputs;
1011    use sayiir_core::workflow::WorkflowBuilder;
1012    use sayiir_macros::BranchKey;
1013    use sayiir_persistence::InMemoryBackend;
1014    use sayiir_persistence::{SignalStore, SnapshotStore};
1015
    // Two-way routing key (Billing vs. Tech) for branch-routing tests.
    #[derive(BranchKey)]
    enum RouteKey {
        Billing,
        Tech,
    }
1021
    // A/B-style branch key for two-alternative branching tests.
    #[derive(BranchKey)]
    enum AbKey {
        A,
        B,
    }
1027
    // Builds a fresh test context: JSON codec, unit metadata, fixed name.
    fn ctx() -> WorkflowContext<JsonCodec, ()> {
        WorkflowContext::new("test-workflow", Arc::new(JsonCodec), Arc::new(()))
    }
1031
1032    // ========================================================================
1033    // Run (fresh execution)
1034    // ========================================================================
1035
1036    #[tokio::test]
1037    async fn test_run_single_task() {
1038        let backend = InMemoryBackend::new();
1039        let runner = CheckpointingRunner::new(backend);
1040
1041        let workflow = WorkflowBuilder::new(ctx())
1042            .then("add_one", |i: u32| async move { Ok(i + 1) })
1043            .build()
1044            .unwrap();
1045
1046        let status = runner.run(&workflow, "inst-1", 5u32).await.unwrap();
1047        assert!(matches!(status, WorkflowStatus::Completed));
1048
1049        // Verify snapshot was saved as completed
1050        let snapshot = runner.backend().load_snapshot("inst-1").await.unwrap();
1051        assert!(snapshot.state.is_completed());
1052    }
1053
1054    #[tokio::test]
1055    async fn test_run_chained_tasks() {
1056        let backend = InMemoryBackend::new();
1057        let runner = CheckpointingRunner::new(backend);
1058
1059        let workflow = WorkflowBuilder::new(ctx())
1060            .then("add_one", |i: u32| async move { Ok(i + 1) })
1061            .then("double", |i: u32| async move { Ok(i * 2) })
1062            .build()
1063            .unwrap();
1064
1065        let status = runner.run(&workflow, "inst-1", 10u32).await.unwrap();
1066        assert!(matches!(status, WorkflowStatus::Completed));
1067
1068        let snapshot = runner.backend().load_snapshot("inst-1").await.unwrap();
1069        assert!(snapshot.state.is_completed());
1070    }
1071
1072    #[tokio::test]
1073    async fn test_run_three_task_chain() {
1074        let backend = InMemoryBackend::new();
1075        let runner = CheckpointingRunner::new(backend);
1076
1077        let workflow = WorkflowBuilder::new(ctx())
1078            .then("step1", |i: u32| async move { Ok(i + 1) })
1079            .then("step2", |i: u32| async move { Ok(i * 3) })
1080            .then("step3", |i: u32| async move { Ok(i - 2) })
1081            .build()
1082            .unwrap();
1083
1084        let status = runner.run(&workflow, "inst-1", 5u32).await.unwrap();
1085        // 5+1=6, 6*3=18, 18-2=16
1086        assert!(matches!(status, WorkflowStatus::Completed));
1087    }
1088
1089    #[tokio::test]
1090    async fn test_run_task_failure() {
1091        let backend = InMemoryBackend::new();
1092        let runner = CheckpointingRunner::new(backend);
1093
1094        let workflow = WorkflowBuilder::new(ctx())
1095            .then("fail", |_i: u32| async move {
1096                Err::<u32, BoxError>("intentional failure".into())
1097            })
1098            .build()
1099            .unwrap();
1100
1101        let status = runner.run(&workflow, "inst-1", 1u32).await.unwrap();
1102        match status {
1103            WorkflowStatus::Failed(e) => {
1104                assert!(e.contains("intentional failure"));
1105            }
1106            _ => panic!("Expected Failed status"),
1107        }
1108
1109        // Snapshot should be marked as failed
1110        let snapshot = runner.backend().load_snapshot("inst-1").await.unwrap();
1111        assert!(snapshot.state.is_failed());
1112    }
1113
1114    #[tokio::test]
1115    async fn test_run_fork_join() {
1116        let backend = InMemoryBackend::new();
1117        let runner = CheckpointingRunner::new(backend);
1118
1119        let workflow = WorkflowBuilder::new(ctx())
1120            .then("prepare", |i: u32| async move { Ok(i) })
1121            .branches(|b| {
1122                b.add("double", |i: u32| async move { Ok(i * 2) });
1123                b.add("add_ten", |i: u32| async move { Ok(i + 10) });
1124            })
1125            .join("combine", |outputs: BranchOutputs<JsonCodec>| async move {
1126                let doubled: u32 = outputs.get_by_id("double")?;
1127                let added: u32 = outputs.get_by_id("add_ten")?;
1128                Ok(doubled + added)
1129            })
1130            .build()
1131            .unwrap();
1132
1133        let status = runner.run(&workflow, "inst-1", 5u32).await.unwrap();
1134        assert!(matches!(status, WorkflowStatus::Completed));
1135
1136        let snapshot = runner.backend().load_snapshot("inst-1").await.unwrap();
1137        assert!(snapshot.state.is_completed());
1138    }
1139
1140    #[tokio::test]
1141    async fn test_run_checkpoints_intermediate_tasks() {
1142        let backend = InMemoryBackend::new();
1143        let runner = CheckpointingRunner::new(backend);
1144
1145        let workflow = WorkflowBuilder::new(ctx())
1146            .then("step1", |i: u32| async move { Ok(i + 1) })
1147            .then("step2", |i: u32| async move { Ok(i * 2) })
1148            .build()
1149            .unwrap();
1150
1151        let status = runner.run(&workflow, "inst-1", 10u32).await.unwrap();
1152        assert!(matches!(status, WorkflowStatus::Completed));
1153
1154        // The final snapshot should be completed, but we can verify the
1155        // instance was tracked throughout
1156        let snapshot = runner.backend().load_snapshot("inst-1").await.unwrap();
1157        assert!(snapshot.state.is_completed());
1158    }
1159
1160    // ========================================================================
1161    // Resume
1162    // ========================================================================
1163
1164    #[tokio::test]
1165    async fn test_resume_completed_workflow() {
1166        let backend = InMemoryBackend::new();
1167        let runner = CheckpointingRunner::new(backend);
1168
1169        let workflow = WorkflowBuilder::new(ctx())
1170            .then("step1", |i: u32| async move { Ok(i + 1) })
1171            .build()
1172            .unwrap();
1173
1174        // Run to completion
1175        runner.run(&workflow, "inst-1", 5u32).await.unwrap();
1176
1177        // Resume should return Completed immediately
1178        let status = runner.resume(&workflow, "inst-1").await.unwrap();
1179        assert!(matches!(status, WorkflowStatus::Completed));
1180    }
1181
1182    #[tokio::test]
1183    async fn test_resume_failed_workflow() {
1184        let backend = InMemoryBackend::new();
1185        let runner = CheckpointingRunner::new(backend);
1186
1187        let workflow = WorkflowBuilder::new(ctx())
1188            .then("fail", |_i: u32| async move {
1189                Err::<u32, BoxError>("failure".into())
1190            })
1191            .build()
1192            .unwrap();
1193
1194        runner.run(&workflow, "inst-1", 1u32).await.unwrap();
1195
1196        let status = runner.resume(&workflow, "inst-1").await.unwrap();
1197        match status {
1198            WorkflowStatus::Failed(_) => {}
1199            _ => panic!("Expected Failed status"),
1200        }
1201    }
1202
    // Resume must reject a snapshot whose recorded definition hash does not
    // match the workflow it is being resumed with.
    #[tokio::test]
    async fn test_resume_definition_hash_mismatch() {
        let backend = InMemoryBackend::new();
        let runner = CheckpointingRunner::new(backend);

        let workflow1 = WorkflowBuilder::new(ctx())
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .build()
            .unwrap();

        // Run with workflow1
        runner.run(&workflow1, "inst-1", 5u32).await.unwrap();

        // Manually create in-progress snapshot with workflow1's hash
        // (bypasses the runner so the instance is left mid-execution).
        let mut snapshot = WorkflowSnapshot::with_initial_input(
            "inst-2".into(),
            workflow1.definition_hash().to_string(),
            Bytes::from(serde_json::to_vec(&5u32).unwrap()),
        );
        snapshot.update_position(ExecutionPosition::AtTask {
            task_id: "step1".into(),
        });
        runner.backend().save_snapshot(&snapshot).await.unwrap();

        // Build a different workflow (extra step changes the definition hash)
        let workflow2 = WorkflowBuilder::new(ctx())
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .then("step2", |i: u32| async move { Ok(i * 2) })
            .build()
            .unwrap();

        // Resume with different workflow definition should fail with an
        // error mentioning the mismatch.
        let result = runner.resume(&workflow2, "inst-2").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("mismatch"));
    }
1239
1240    // ========================================================================
1241    // Cancellation
1242    // ========================================================================
1243
    // Cancellation requests issued via WorkflowClient are persisted as
    // signals in the backend for an in-progress instance.
    // NOTE(review): despite the name, this test never actually runs or
    // resumes the slow workflow — it only verifies the cancel signal is
    // stored; the 10 s task appears to exist to document intent. Confirm
    // whether an execution-path assertion was intended.
    #[tokio::test]
    async fn test_cancel_running_workflow() {
        let backend = InMemoryBackend::new();
        let runner = CheckpointingRunner::new(backend);

        // Create a workflow with a slow task
        let workflow = WorkflowBuilder::new(ctx())
            .then("slow_task", |i: u32| async move {
                tokio::time::sleep(std::time::Duration::from_secs(10)).await;
                Ok(i)
            })
            .build()
            .unwrap();

        // Set up a snapshot as if it's in progress
        let input_bytes = Arc::new(JsonCodec).encode(&1u32).unwrap();
        let mut snapshot = WorkflowSnapshot::with_initial_input(
            "inst-cancel".into(),
            workflow.definition_hash().to_string(),
            input_bytes,
        );
        snapshot.update_position(ExecutionPosition::AtTask {
            task_id: "slow_task".into(),
        });
        runner.backend().save_snapshot(&snapshot).await.unwrap();

        // Request cancellation via WorkflowClient
        let client = crate::WorkflowClient::from_shared(Arc::clone(runner.backend()));
        client
            .cancel(
                "inst-cancel",
                Some("testing".into()),
                Some("test-suite".into()),
            )
            .await
            .unwrap();

        // Verify cancellation request was stored with its reason intact
        let req = runner
            .backend()
            .get_signal("inst-cancel", SignalKind::Cancel)
            .await
            .unwrap();
        assert!(req.is_some());
        assert_eq!(req.unwrap().reason, Some("testing".into()));
    }
1290
    // A cancellation requested before resume must be honored: resume reports
    // Cancelled (with the stored reason) instead of executing tasks.
    #[tokio::test]
    async fn test_run_with_pre_cancellation() {
        let backend = InMemoryBackend::new();
        let runner = CheckpointingRunner::new(backend);

        let workflow = WorkflowBuilder::new(ctx())
            .then("task1", |i: u32| async move { Ok(i + 1) })
            .then("task2", |i: u32| async move { Ok(i * 2) })
            .build()
            .unwrap();

        // Save initial snapshot and request cancellation before running
        let input_bytes = Arc::new(JsonCodec).encode(&1u32).unwrap();
        let mut snapshot = WorkflowSnapshot::with_initial_input(
            "inst-precancel".into(),
            workflow.definition_hash().to_string(),
            input_bytes,
        );
        snapshot.update_position(ExecutionPosition::AtTask {
            task_id: "task1".into(),
        });
        runner.backend().save_snapshot(&snapshot).await.unwrap();

        let client = crate::WorkflowClient::from_shared(Arc::clone(runner.backend()));
        client
            .cancel("inst-precancel", Some("pre-cancel".into()), None)
            .await
            .unwrap();

        // Resume should detect cancellation and carry the stored reason
        let status = runner.resume(&workflow, "inst-precancel").await.unwrap();
        match status {
            WorkflowStatus::Cancelled { reason, .. } => {
                assert_eq!(reason, Some("pre-cancel".into()));
            }
            _ => panic!("Expected Cancelled status, got: {status:?}"),
        }
    }
1329
1330    // ========================================================================
1331    // Edge cases
1332    // ========================================================================
1333
1334    #[tokio::test]
1335    async fn test_resume_nonexistent_instance() {
1336        let backend = InMemoryBackend::new();
1337        let runner = CheckpointingRunner::new(backend);
1338
1339        let workflow = WorkflowBuilder::new(ctx())
1340            .then("task", |i: u32| async move { Ok(i) })
1341            .build()
1342            .unwrap();
1343
1344        let result = runner.resume(&workflow, "nonexistent").await;
1345        assert!(result.is_err());
1346    }
1347
1348    #[tokio::test]
1349    async fn test_run_failure_in_chain_saves_snapshot() {
1350        let backend = InMemoryBackend::new();
1351        let runner = CheckpointingRunner::new(backend);
1352
1353        let workflow = WorkflowBuilder::new(ctx())
1354            .then("step1", |i: u32| async move { Ok(i + 1) })
1355            .then("fail_step", |_i: u32| async move {
1356                Err::<u32, BoxError>("mid-chain failure".into())
1357            })
1358            .then("step3", |i: u32| async move { Ok(i * 2) })
1359            .build()
1360            .unwrap();
1361
1362        let status = runner.run(&workflow, "inst-1", 10u32).await.unwrap();
1363        match status {
1364            WorkflowStatus::Failed(e) => {
1365                assert!(e.contains("mid-chain failure"));
1366            }
1367            _ => panic!("Expected Failed"),
1368        }
1369
1370        // Snapshot should be saved as failed
1371        let snapshot = runner.backend().load_snapshot("inst-1").await.unwrap();
1372        assert!(snapshot.state.is_failed());
1373    }
1374
1375    // ========================================================================
1376    // Delay tests
1377    // ========================================================================
1378
    // Running into a long delay must park the workflow: the runner returns
    // Waiting, persists an AtDelay position pointing at the next task, and
    // keeps the results of everything completed so far.
    #[tokio::test]
    async fn test_run_workflow_with_delay_returns_waiting() {
        let backend = InMemoryBackend::new();
        let runner = CheckpointingRunner::new(backend);

        let workflow = WorkflowBuilder::new(ctx())
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .delay("wait_1h", std::time::Duration::from_secs(3600))
            .then("step2", |i: u32| async move { Ok(i * 2) })
            .build()
            .unwrap();

        let status = runner.run(&workflow, "inst-1", 10u32).await.unwrap();

        // Should return Waiting (delay is 1 hour in the future)
        match &status {
            WorkflowStatus::Waiting { delay_id, .. } => {
                assert_eq!(delay_id, "wait_1h");
            }
            _ => panic!("Expected Waiting status, got {status:?}"),
        }

        // Snapshot should be in-progress at AtDelay position, with the
        // delay id and the task to resume at recorded.
        let snapshot = runner.backend().load_snapshot("inst-1").await.unwrap();
        assert!(snapshot.state.is_in_progress());
        match &snapshot.state {
            WorkflowSnapshotState::InProgress { position, .. } => match position {
                ExecutionPosition::AtDelay {
                    delay_id,
                    next_task_id,
                    ..
                } => {
                    assert_eq!(delay_id, "wait_1h");
                    assert_eq!(next_task_id.as_deref(), Some("step2"));
                }
                other => panic!("Expected AtDelay, got {other:?}"),
            },
            _ => panic!("Expected InProgress"),
        }

        // step1 should have been completed
        assert!(snapshot.get_task_result("step1").is_some());
        // delay pass-through should be stored
        assert!(snapshot.get_task_result("wait_1h").is_some());
    }
1424
1425    #[tokio::test]
1426    async fn test_resume_before_delay_expires_returns_waiting() {
1427        let backend = InMemoryBackend::new();
1428        let runner = CheckpointingRunner::new(backend);
1429
1430        let workflow = WorkflowBuilder::new(ctx())
1431            .then("step1", |i: u32| async move { Ok(i + 1) })
1432            .delay("wait_1h", std::time::Duration::from_secs(3600))
1433            .then("step2", |i: u32| async move { Ok(i * 2) })
1434            .build()
1435            .unwrap();
1436
1437        // Run to delay
1438        let status = runner.run(&workflow, "inst-1", 10u32).await.unwrap();
1439        assert!(matches!(status, WorkflowStatus::Waiting { .. }));
1440
1441        // Resume immediately (delay hasn't expired)
1442        let status = runner.resume(&workflow, "inst-1").await.unwrap();
1443        match &status {
1444            WorkflowStatus::Waiting { delay_id, .. } => {
1445                assert_eq!(delay_id, "wait_1h");
1446            }
1447            _ => panic!("Expected Waiting on resume, got {status:?}"),
1448        }
1449    }
1450
1451    #[tokio::test]
1452    async fn test_resume_after_delay_expires_completes() {
1453        let backend = InMemoryBackend::new();
1454        let runner = CheckpointingRunner::new(backend);
1455
1456        // Use a very short delay so it expires immediately
1457        let workflow = WorkflowBuilder::new(ctx())
1458            .then("step1", |i: u32| async move { Ok(i + 1) })
1459            .delay("wait_short", std::time::Duration::from_millis(1))
1460            .then("step2", |i: u32| async move { Ok(i * 2) })
1461            .build()
1462            .unwrap();
1463
1464        // Run — delay is so short it should still park (snapshot is saved before checking time)
1465        let status = runner.run(&workflow, "inst-1", 10u32).await.unwrap();
1466        assert!(matches!(status, WorkflowStatus::Waiting { .. }));
1467
1468        // Wait a bit for the delay to definitely expire
1469        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
1470
1471        // Resume — delay should have expired, execution continues
1472        let status = runner.resume(&workflow, "inst-1").await.unwrap();
1473        assert!(
1474            matches!(status, WorkflowStatus::Completed),
1475            "Expected Completed after delay expired, got {status:?}"
1476        );
1477
1478        // Verify final state
1479        let snapshot = runner.backend().load_snapshot("inst-1").await.unwrap();
1480        assert!(snapshot.state.is_completed());
1481    }
1482
1483    #[tokio::test]
1484    async fn test_cancel_during_delay() {
1485        let backend = InMemoryBackend::new();
1486        let runner = CheckpointingRunner::new(backend);
1487
1488        let workflow = WorkflowBuilder::new(ctx())
1489            .then("step1", |i: u32| async move { Ok(i + 1) })
1490            .delay("wait_1h", std::time::Duration::from_secs(3600))
1491            .then("step2", |i: u32| async move { Ok(i * 2) })
1492            .build()
1493            .unwrap();
1494
1495        // Run to delay
1496        let status = runner.run(&workflow, "inst-1", 10u32).await.unwrap();
1497        assert!(matches!(status, WorkflowStatus::Waiting { .. }));
1498
1499        // Cancel during delay via WorkflowClient
1500        let client = crate::WorkflowClient::from_shared(Arc::clone(runner.backend()));
1501        client
1502            .cancel(
1503                "inst-1",
1504                Some("no longer needed".into()),
1505                Some("admin".into()),
1506            )
1507            .await
1508            .unwrap();
1509
1510        // Resume should detect cancellation
1511        let status = runner.resume(&workflow, "inst-1").await.unwrap();
1512        match status {
1513            WorkflowStatus::Cancelled {
1514                reason,
1515                cancelled_by,
1516            } => {
1517                assert_eq!(reason, Some("no longer needed".into()));
1518                assert_eq!(cancelled_by, Some("admin".into()));
1519            }
1520            _ => panic!("Expected Cancelled status, got {status:?}"),
1521        }
1522    }
1523
1524    #[tokio::test]
1525    async fn test_delay_as_last_node() {
1526        let backend = InMemoryBackend::new();
1527        let runner = CheckpointingRunner::new(backend);
1528
1529        let workflow = WorkflowBuilder::new(ctx())
1530            .then("step1", |i: u32| async move { Ok(i + 1) })
1531            .delay("final_wait", std::time::Duration::from_millis(1))
1532            .build()
1533            .unwrap();
1534
1535        // Run to delay
1536        let status = runner.run(&workflow, "inst-1", 10u32).await.unwrap();
1537        assert!(matches!(status, WorkflowStatus::Waiting { .. }));
1538
1539        // Wait for delay to expire
1540        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
1541
1542        // Resume — delay was the last node, should complete
1543        let status = runner.resume(&workflow, "inst-1").await.unwrap();
1544        assert!(
1545            matches!(status, WorkflowStatus::Completed),
1546            "Expected Completed when delay is last node, got {status:?}"
1547        );
1548    }
1549
1550    #[tokio::test]
1551    async fn test_delay_data_passthrough() {
1552        let backend = InMemoryBackend::new();
1553        let runner = CheckpointingRunner::new(backend);
1554
1555        // step1 produces 11, delay passes it through, step2 receives 11 and doubles
1556        let workflow = WorkflowBuilder::new(ctx())
1557            .then("step1", |i: u32| async move { Ok(i + 1) })
1558            .delay("wait", std::time::Duration::from_millis(1))
1559            .then("step2", |i: u32| async move {
1560                // Verify input is the passthrough value from step1
1561                assert_eq!(i, 11);
1562                Ok(i * 2)
1563            })
1564            .build()
1565            .unwrap();
1566
1567        // Run to delay
1568        runner.run(&workflow, "inst-1", 10u32).await.unwrap();
1569
1570        // Wait and resume
1571        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
1572        let status = runner.resume(&workflow, "inst-1").await.unwrap();
1573        assert!(matches!(status, WorkflowStatus::Completed));
1574    }
1575
1576    // ========================================================================
1577    // Timeout tests
1578    // ========================================================================
1579
1580    #[tokio::test]
1581    async fn test_run_task_timeout_fails_workflow() {
1582        use sayiir_core::task::TaskMetadata;
1583
1584        let backend = InMemoryBackend::new();
1585        let runner = CheckpointingRunner::new(backend);
1586
1587        let workflow = WorkflowBuilder::new(ctx())
1588            .with_registry()
1589            .then("slow_task", |i: u32| async move {
1590                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1591                Ok(i)
1592            })
1593            .with_metadata(TaskMetadata {
1594                timeout: Some(std::time::Duration::from_millis(5)),
1595                ..Default::default()
1596            })
1597            .build()
1598            .unwrap();
1599
1600        let status = runner
1601            .run(workflow.workflow(), "inst-timeout", 5u32)
1602            .await
1603            .unwrap();
1604        match status {
1605            WorkflowStatus::Failed(msg) => {
1606                assert!(
1607                    msg.contains("timed out"),
1608                    "Expected timeout error, got: {msg}"
1609                );
1610                assert!(
1611                    msg.contains("slow_task"),
1612                    "Expected task id in error, got: {msg}"
1613                );
1614            }
1615            other => panic!("Expected Failed status, got {other:?}"),
1616        }
1617    }
1618
1619    #[tokio::test]
1620    async fn test_run_task_within_timeout_succeeds() {
1621        use sayiir_core::task::TaskMetadata;
1622
1623        let backend = InMemoryBackend::new();
1624        let runner = CheckpointingRunner::new(backend);
1625
1626        let workflow = WorkflowBuilder::new(ctx())
1627            .with_registry()
1628            .then("fast_task", |i: u32| async move { Ok(i + 1) })
1629            .with_metadata(TaskMetadata {
1630                timeout: Some(std::time::Duration::from_secs(5)),
1631                ..Default::default()
1632            })
1633            .build()
1634            .unwrap();
1635
1636        let status = runner
1637            .run(workflow.workflow(), "inst-fast", 5u32)
1638            .await
1639            .unwrap();
1640        assert!(matches!(status, WorkflowStatus::Completed));
1641    }
1642
1643    #[tokio::test]
1644    async fn test_route_selects_correct_branch() {
1645        let backend = InMemoryBackend::new();
1646        let runner = CheckpointingRunner::new(backend.clone());
1647
1648        let workflow = WorkflowBuilder::new(ctx())
1649            .then("classify", |input: String| async move {
1650                Ok(serde_json::json!({ "intent": input }))
1651            })
1652            .route::<u32, RouteKey, _, _>(|data: serde_json::Value| async move {
1653                match data["intent"].as_str().unwrap_or("unknown") {
1654                    "billing" => Ok(RouteKey::Billing),
1655                    "tech" => Ok(RouteKey::Tech),
1656                    other => Err(format!("unknown intent: {other}").into()),
1657                }
1658            })
1659            .branch(RouteKey::Billing, |sub| {
1660                sub.then("handle_billing", |_data: serde_json::Value| async move {
1661                    Ok(100u32)
1662                })
1663            })
1664            .branch(RouteKey::Tech, |sub| {
1665                sub.then("handle_tech", |_data: serde_json::Value| async move {
1666                    Ok(200u32)
1667                })
1668            })
1669            .done()
1670            .build()
1671            .unwrap();
1672
1673        // Route to "billing"
1674        let status = runner
1675            .run(&workflow, "inst-branch-1", "billing".to_string())
1676            .await
1677            .unwrap();
1678        assert!(matches!(status, WorkflowStatus::Completed));
1679
1680        let snapshot = backend.load_snapshot("inst-branch-1").await.unwrap();
1681        // Workflow completed — check the final output (which is the branch envelope
1682        // since route is the last step)
1683        match &snapshot.state {
1684            WorkflowSnapshotState::Completed { final_output } => {
1685                let envelope: serde_json::Value = serde_json::from_slice(final_output).unwrap();
1686                assert_eq!(envelope["branch"], "billing");
1687                assert_eq!(envelope["result"], 100);
1688            }
1689            other => panic!("Expected Completed, got: {other:?}"),
1690        }
1691    }
1692
1693    #[tokio::test]
1694    async fn test_route_with_default() {
1695        let backend = InMemoryBackend::new();
1696        let runner = CheckpointingRunner::new(backend.clone());
1697
1698        // With typed keys the default branch catches enum variants that
1699        // don't have an explicit `.branch()` call.  Route "b" has no
1700        // branch, so the default fires.
1701        let workflow = WorkflowBuilder::new(ctx())
1702            .route::<String, AbKey, _, _>(|input: String| async move {
1703                match input.as_str() {
1704                    "a" => Ok(AbKey::A),
1705                    "b" => Ok(AbKey::B),
1706                    other => Err(format!("unknown: {other}").into()),
1707                }
1708            })
1709            .branch(AbKey::A, |sub| {
1710                sub.then("handle_a", |_data: String| async move {
1711                    Ok("matched".to_string())
1712                })
1713            })
1714            .default_branch(|sub| {
1715                sub.then("handle_fallback", |_data: String| async move {
1716                    Ok("fallback".to_string())
1717                })
1718            })
1719            .done()
1720            .build()
1721            .unwrap();
1722
1723        // Send "b" — not explicitly branched, so the default fires
1724        let status = runner
1725            .run(&workflow, "inst-branch-default", "b".to_string())
1726            .await
1727            .unwrap();
1728        assert!(matches!(status, WorkflowStatus::Completed));
1729
1730        let snapshot = backend.load_snapshot("inst-branch-default").await.unwrap();
1731        match &snapshot.state {
1732            WorkflowSnapshotState::Completed { final_output } => {
1733                let envelope: serde_json::Value = serde_json::from_slice(final_output).unwrap();
1734                assert_eq!(envelope["branch"], "b");
1735                assert_eq!(envelope["result"], "fallback");
1736            }
1737            other => panic!("Expected Completed, got: {other:?}"),
1738        }
1739    }
1740
1741    #[tokio::test]
1742    async fn test_route_missing_branches_detected() {
1743        // With typed keys, missing branches are caught at build time.
1744        // RouteKey has {billing, tech} but we only branch on billing → MissingBranches.
1745        let result = WorkflowBuilder::new(ctx())
1746            .route::<String, RouteKey, _, _>(|input: String| async move {
1747                match input.as_str() {
1748                    "billing" => Ok(RouteKey::Billing),
1749                    _ => Ok(RouteKey::Tech),
1750                }
1751            })
1752            .branch(RouteKey::Billing, |sub| {
1753                sub.then("handle_billing", |_data: String| async move {
1754                    Ok("ok".to_string())
1755                })
1756            })
1757            .done()
1758            .build();
1759
1760        let errors = match result {
1761            Err(e) => e,
1762            Ok(_) => panic!("expected build error"),
1763        };
1764        let has_missing = errors.iter().any(|e| {
1765            matches!(
1766                e,
1767                sayiir_core::error::BuildError::MissingBranches {
1768                    branch_id,
1769                    missing_keys,
1770                } if branch_id == "branch_1" && missing_keys.contains(&"tech".to_string())
1771            )
1772        });
1773        assert!(has_missing, "Expected MissingBranches error in: {errors:?}");
1774    }
1775
1776    #[tokio::test]
1777    async fn test_route_then_next_step() {
1778        use sayiir_core::task::BranchEnvelope;
1779
1780        let backend = InMemoryBackend::new();
1781        let runner = CheckpointingRunner::new(backend.clone());
1782
1783        let workflow = WorkflowBuilder::new(ctx())
1784            .route::<u32, AbKey, _, _>(|input: String| async move {
1785                match input.as_str() {
1786                    "a" => Ok(AbKey::A),
1787                    "b" => Ok(AbKey::B),
1788                    other => Err(format!("unknown: {other}").into()),
1789                }
1790            })
1791            .branch(AbKey::A, |sub| {
1792                sub.then("handle_a", |_data: String| async move { Ok(10u32) })
1793            })
1794            .branch(AbKey::B, |sub| {
1795                sub.then("handle_b", |_data: String| async move { Ok(20u32) })
1796            })
1797            .done()
1798            .then("finalize", |env: BranchEnvelope<u32>| async move {
1799                Ok(env.result + 1)
1800            })
1801            .build()
1802            .unwrap();
1803
1804        let status = runner
1805            .run(&workflow, "inst-branch-next", "a".to_string())
1806            .await
1807            .unwrap();
1808        assert!(matches!(status, WorkflowStatus::Completed));
1809
1810        let snapshot = backend.load_snapshot("inst-branch-next").await.unwrap();
1811        match &snapshot.state {
1812            WorkflowSnapshotState::Completed { final_output } => {
1813                let val: u32 = serde_json::from_slice(final_output).unwrap();
1814                assert_eq!(val, 11); // branch "a" returned 10, finalize adds 1
1815            }
1816            other => panic!("Expected Completed, got: {other:?}"),
1817        }
1818    }
1819}