sayiir_runtime/runner/distributed.rs

1//! Checkpointing workflow runner for single-process execution with persistence.
2//!
3//! This runner executes an entire workflow within a single process while saving
4//! snapshots after each task completion. This enables crash recovery and resumption
5//! without requiring multiple workers.
6//!
7//! **Use this when**: You want to run a workflow reliably on a single node with
8//! the ability to resume after crashes.
9//!
10//! **Use [`PooledWorker`](crate::worker::PooledWorker) instead when**: You need
11//! horizontal scaling with multiple workers collaborating on tasks.
12
13use bytes::Bytes;
14use sayiir_core::codec::Codec;
15use sayiir_core::codec::sealed;
16use sayiir_core::context::{WorkflowContext, with_context};
17use sayiir_core::error::WorkflowError;
18use sayiir_core::snapshot::{ExecutionPosition, SignalKind, SignalRequest, WorkflowSnapshot};
19use sayiir_core::workflow::{Workflow, WorkflowContinuation, WorkflowStatus};
20use sayiir_persistence::PersistentBackend;
21use std::sync::Arc;
22
23use crate::error::RuntimeError;
24use crate::execution::{
25    ForkBranchOutcome, JoinResolution, ResumeParkedPosition, branch_execute_or_skip_task,
26    check_guards, collect_cached_branches, execute_or_skip_task, finalize_execution,
27    get_resume_input, park_at_delay, park_at_signal, park_branch_at_delay, park_branch_at_signal,
28    resolve_join, retry_with_checkpoint, set_deadline_if_needed, settle_fork_outcome,
29};
30
31/// A single-process workflow runner with checkpointing for crash recovery.
32///
33/// `CheckpointingRunner` executes an entire workflow within one process,
34/// saving snapshots after each task. Fork branches run concurrently as tokio tasks.
35/// If the process crashes, the workflow can be resumed from the last checkpoint.
36///
37/// # When to Use
38///
39/// - **Single-node execution**: One process runs the entire workflow
40/// - **Crash recovery**: Resume from the last completed task after restart
41/// - **Simple deployment**: No coordination between workers needed
42///
43/// For horizontal scaling with multiple workers, use [`PooledWorker`](crate::worker::PooledWorker).
44///
45/// # Example
46///
47/// ```rust,no_run
48/// # use sayiir_runtime::prelude::*;
49/// # use std::sync::Arc;
50/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
51/// let backend = InMemoryBackend::new();
52/// let runner = CheckpointingRunner::new(backend);
53///
54/// let ctx = WorkflowContext::new("my-workflow", Arc::new(JsonCodec), Arc::new(()));
55/// let workflow = WorkflowBuilder::new(ctx)
56///     .then("step1", |i: u32| async move { Ok(i + 1) })
57///     .build()?;
58///
59/// // Run workflow - snapshots are saved automatically
60/// let status = runner.run(&workflow, "instance-123", 1u32).await?;
61///
62/// // Resume from checkpoint if needed (e.g., after crash)
63/// let status = runner.resume(&workflow, "instance-123").await?;
64/// # Ok(())
65/// # }
66/// ```
67pub struct CheckpointingRunner<B> {
68    backend: Arc<B>,
69}
70
71impl<B> CheckpointingRunner<B>
72where
73    B: PersistentBackend,
74{
75    /// Create a new checkpointing runner with the given backend.
76    pub fn new(backend: B) -> Self {
77        Self {
78            backend: Arc::new(backend),
79        }
80    }
81
82    /// Request cancellation of a workflow.
83    ///
84    /// Stores a cancellation signal for the specified workflow instance; the
85    /// running workflow observes it and cancels at the next task boundary.
86    ///
87    /// # Parameters
88    ///
89    /// - `instance_id`: The workflow instance ID to cancel
90    /// - `reason`: Optional reason for the cancellation
91    /// - `cancelled_by`: Optional identifier of who requested the cancellation
92    ///
93    /// # Errors
94    ///
95    /// Returns an error if the workflow cannot be cancelled (not found or in terminal state).
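    ///
    /// # Example
    ///
    /// An illustrative sketch, assuming the same prelude re-exports used in the
    /// struct-level example:
    ///
    /// ```rust,no_run
    /// # use sayiir_runtime::prelude::*;
    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// let runner = CheckpointingRunner::new(InMemoryBackend::new());
    ///
    /// // Ask for cancellation; the running workflow observes the signal at its
    /// // next task boundary.
    /// runner
    ///     .cancel(
    ///         "instance-123",
    ///         Some("no longer needed".into()),
    ///         Some("admin".into()),
    ///     )
    ///     .await?;
    /// # Ok(())
    /// # }
    /// ```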
96    pub async fn cancel(
97        &self,
98        instance_id: &str,
99        reason: Option<String>,
100        cancelled_by: Option<String>,
101    ) -> Result<(), RuntimeError> {
102        self.backend
103            .store_signal(
104                instance_id,
105                SignalKind::Cancel,
106                SignalRequest::new(reason, cancelled_by),
107            )
108            .await?;
109
110        Ok(())
111    }
112
113    /// Request pausing of a workflow.
114    ///
115    /// The workflow will be paused at the next task boundary.
116    ///
117    /// # Errors
118    ///
119    /// Returns an error if the backend fails to store the pause request.
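    ///
    /// # Example
    ///
    /// An illustrative sketch, assuming the same prelude re-exports used in the
    /// struct-level example:
    ///
    /// ```rust,no_run
    /// # use sayiir_runtime::prelude::*;
    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// let runner = CheckpointingRunner::new(InMemoryBackend::new());
    ///
    /// // Request a pause; the workflow parks at the next task boundary and can
    /// // later be unpaused with `unpause`.
    /// runner
    ///     .pause("instance-123", Some("maintenance".into()), Some("ops".into()))
    ///     .await?;
    /// # Ok(())
    /// # }
    /// ```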
120    pub async fn pause(
121        &self,
122        instance_id: &str,
123        reason: Option<String>,
124        paused_by: Option<String>,
125    ) -> Result<(), RuntimeError> {
126        self.backend
127            .store_signal(
128                instance_id,
129                SignalKind::Pause,
130                SignalRequest::new(reason, paused_by),
131            )
132            .await?;
133        Ok(())
134    }
135
136    /// Unpause a paused workflow and return the updated snapshot.
137    ///
138    /// Transitions the workflow from `Paused` back to `InProgress`.
139    ///
140    /// # Errors
141    ///
142    /// Returns an error if the backend fails to unpause the workflow.
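    ///
    /// # Example
    ///
    /// An illustrative sketch, assuming the same prelude re-exports used in the
    /// struct-level example:
    ///
    /// ```rust,no_run
    /// # use sayiir_runtime::prelude::*;
    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// let runner = CheckpointingRunner::new(InMemoryBackend::new());
    ///
    /// // Move a paused workflow back to `InProgress` and inspect the snapshot.
    /// let snapshot = runner.unpause("instance-123").await?;
    /// assert!(snapshot.state.is_in_progress());
    /// # Ok(())
    /// # }
    /// ```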
143    pub async fn unpause(&self, instance_id: &str) -> Result<WorkflowSnapshot, RuntimeError> {
144        let snapshot = self.backend.unpause(instance_id).await?;
145        Ok(snapshot)
146    }
147
148    /// Get a reference to the backend.
149    #[must_use]
150    pub fn backend(&self) -> &Arc<B> {
151        &self.backend
152    }
153}
154
155impl<B> CheckpointingRunner<B>
156where
157    B: PersistentBackend + 'static,
158{
159    /// Run a workflow from the beginning, saving checkpoints after each task.
160    ///
161    /// The `instance_id` uniquely identifies this workflow execution instance.
162    /// If a snapshot with this ID already exists, it will be overwritten.
163    ///
164    /// # Errors
165    ///
166    /// Returns an error if the workflow cannot be executed or if snapshot
167    /// operations fail.
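    ///
    /// # Example
    ///
    /// A sketch of starting a workflow and inspecting the returned status. It
    /// assumes the prelude re-exports used in the struct-level example and takes
    /// `WorkflowStatus` from `sayiir_core`, matching this module's imports:
    ///
    /// ```rust,no_run
    /// # use sayiir_runtime::prelude::*;
    /// # use sayiir_core::workflow::WorkflowStatus;
    /// # use std::sync::Arc;
    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// let runner = CheckpointingRunner::new(InMemoryBackend::new());
    /// let ctx = WorkflowContext::new("my-workflow", Arc::new(JsonCodec), Arc::new(()));
    /// let workflow = WorkflowBuilder::new(ctx)
    ///     .then("step1", |i: u32| async move { Ok(i + 1) })
    ///     .build()?;
    ///
    /// match runner.run(&workflow, "instance-123", 1u32).await? {
    ///     WorkflowStatus::Completed => { /* finished */ }
    ///     WorkflowStatus::Waiting { .. } => {
    ///         // Parked at a delay or signal; call `resume` later.
    ///     }
    ///     other => eprintln!("workflow ended with {other:?}"),
    /// }
    /// # Ok(())
    /// # }
    /// ```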
168    pub async fn run<C, Input, M>(
169        &self,
170        workflow: &Workflow<C, Input, M>,
171        instance_id: impl Into<String>,
172        input: Input,
173    ) -> Result<WorkflowStatus, RuntimeError>
174    where
175        Input: Send + 'static,
176        M: Send + Sync + 'static,
177        C: Codec + sealed::EncodeValue<Input> + sealed::DecodeValue<Input> + 'static,
178    {
179        let instance_id = instance_id.into();
180        let definition_hash = workflow.definition_hash().to_string();
181
182        // Encode initial input
183        let input_bytes = workflow.context().codec.encode(&input)?;
184
185        // Create initial snapshot with input
186        let mut snapshot = WorkflowSnapshot::with_initial_input(
187            instance_id.clone(),
188            definition_hash.clone(),
189            input_bytes.clone(),
190        );
191        snapshot.update_position(ExecutionPosition::AtTask {
192            task_id: workflow.continuation().first_task_id().to_string(),
193        });
194
195        // Save initial snapshot
196        self.backend.save_snapshot(&snapshot).await?;
197
198        // Execute workflow with checkpointing
199        let context = workflow.context().clone();
200        let continuation = workflow.continuation();
201        let backend = Arc::clone(&self.backend);
202
203        with_context(context.clone(), || async move {
204            let result = Self::execute_with_checkpointing(
205                continuation,
206                input_bytes,
207                &mut snapshot,
208                Arc::clone(&backend),
209                context,
210            )
211            .await;
212
213            let (status, _output) =
214                finalize_execution(result, &mut snapshot, backend.as_ref()).await?;
215            Ok(status)
216        })
217        .await
218    }
219
220    /// Resume a workflow from a saved snapshot.
221    ///
222    /// Loads the snapshot for the given instance ID and continues execution
223    /// from the last checkpoint.
224    ///
225    /// # Errors
226    ///
227    /// Returns an error if:
228    /// - The snapshot is not found
229    /// - The workflow definition hash doesn't match (workflow definition changed)
230    /// - The workflow cannot be resumed
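    ///
    /// # Example
    ///
    /// A sketch of resuming after a restart: rebuild the identical workflow
    /// definition (so the definition hash matches) and continue the instance from
    /// its last checkpoint. Prelude assumptions as in the struct-level example:
    ///
    /// ```rust,no_run
    /// # use sayiir_runtime::prelude::*;
    /// # use std::sync::Arc;
    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// let runner = CheckpointingRunner::new(InMemoryBackend::new());
    /// let ctx = WorkflowContext::new("my-workflow", Arc::new(JsonCodec), Arc::new(()));
    /// let workflow = WorkflowBuilder::new(ctx)
    ///     .then("step1", |i: u32| async move { Ok(i + 1) })
    ///     .build()?;
    ///
    /// let status = runner.resume(&workflow, "instance-123").await?;
    /// # let _ = status;
    /// # Ok(())
    /// # }
    /// ```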
231    #[allow(clippy::needless_lifetimes)]
232    pub async fn resume<'w, C, Input, M>(
233        &self,
234        workflow: &'w Workflow<C, Input, M>,
235        instance_id: &str,
236    ) -> Result<WorkflowStatus, RuntimeError>
237    where
238        Input: Send + 'static,
239        M: Send + Sync + 'static,
240        C: Codec + sealed::DecodeValue<Input> + sealed::EncodeValue<Input> + 'static,
241    {
242        // Load snapshot
243        let mut snapshot = self.backend.load_snapshot(instance_id).await?;
244
245        // Validate definition hash
246        if snapshot.definition_hash != workflow.definition_hash() {
247            return Err(WorkflowError::DefinitionMismatch {
248                expected: workflow.definition_hash().to_string(),
249                found: snapshot.definition_hash.clone(),
250            }
251            .into());
252        }
253
254        // Check if already in terminal state
255        if let Some(status) = snapshot.state.as_terminal_status() {
256            return Ok(status);
257        }
258
259        // Resolve any parked position (delay / fork) before resuming.
260        let parked = ResumeParkedPosition::extract(&snapshot);
261        if let Some(status) = parked
262            .resolve(&mut snapshot, instance_id, self.backend.as_ref())
263            .await?
264        {
265            return Ok(status);
266        }
267
268        // Resume execution
269        let context = workflow.context().clone();
270        let continuation = workflow.continuation();
271        let backend = Arc::clone(&self.backend);
272
273        with_context(context.clone(), || async move {
274            // Get the last completed task's output or initial input
275            let input_bytes = get_resume_input(&snapshot)?;
276
277            let result = Self::execute_with_checkpointing(
278                continuation,
279                input_bytes,
280                &mut snapshot,
281                Arc::clone(&backend),
282                context,
283            )
284            .await;
285
286            let (status, _output) =
287                finalize_execution(result, &mut snapshot, backend.as_ref()).await?;
288            Ok(status)
289        })
290        .await
291    }
292
293    /// Execute continuation with checkpointing after each task (iterative, no boxing).
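    ///
    /// Walks the continuation chain in a loop: each `Task` runs through
    /// `retry_with_checkpoint` (skipping work when a cached result already exists
    /// in the snapshot), the snapshot is saved, and `check_guards` is consulted
    /// before and after the task. `Delay` and `AwaitSignal` nodes park the
    /// workflow via `park_at_delay` / `park_at_signal`; `Fork` nodes execute
    /// their branches concurrently and feed the join result back into the loop.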
294    #[allow(clippy::manual_async_fn, clippy::too_many_lines)]
295    async fn execute_with_checkpointing<'a, C, M>(
296        continuation: &'a WorkflowContinuation,
297        input: Bytes,
298        snapshot: &'a mut WorkflowSnapshot,
299        backend: Arc<B>,
300        context: WorkflowContext<C, M>,
301    ) -> Result<Bytes, RuntimeError>
302    where
303        B: 'static,
304        C: Codec + 'static,
305        M: Send + Sync + 'static,
306    {
307        let mut current = continuation;
308        let mut current_input = input;
309
310        loop {
311            match current {
312                WorkflowContinuation::Task {
313                    id,
314                    func: Some(func),
315                    timeout,
316                    retry_policy,
317                    next,
318                } => {
319                    check_guards(backend.as_ref(), &snapshot.instance_id, Some(id)).await?;
320                    set_deadline_if_needed(id, timeout.as_ref(), snapshot, backend.as_ref())
321                        .await?;
322
323                    let output = retry_with_checkpoint(
324                        id,
325                        retry_policy.as_ref(),
326                        timeout.as_ref(),
327                        snapshot,
328                        Some(backend.as_ref()),
329                        async |snap| {
330                            execute_or_skip_task(id, current_input.clone(), |i| func.run(i), snap)
331                                .await
332                        },
333                    )
334                    .await?;
335
336                    if let Some(next_cont) = next {
337                        snapshot.update_position(ExecutionPosition::AtTask {
338                            task_id: next_cont.first_task_id().to_string(),
339                        });
340                    }
341                    backend.save_snapshot(snapshot).await?;
342                    check_guards(backend.as_ref(), &snapshot.instance_id, None).await?;
343
344                    match next {
345                        Some(next_continuation) => {
346                            current = next_continuation;
347                            current_input = output;
348                        }
349                        None => return Ok(output),
350                    }
351                }
352                WorkflowContinuation::Task { func: None, id, .. } => {
353                    return Err(WorkflowError::TaskNotImplemented(id.clone()).into());
354                }
355                WorkflowContinuation::Delay { id, duration, next } => {
356                    check_guards(backend.as_ref(), &snapshot.instance_id, Some(id)).await?;
357
358                    if snapshot.get_task_result(id).is_some() {
359                        match next {
360                            Some(next_continuation) => {
361                                current = next_continuation;
362                                continue;
363                            }
364                            None => return Ok(current_input),
365                        }
366                    }
367
368                    return Err(park_at_delay(
369                        id,
370                        duration,
371                        next.as_deref(),
372                        current_input,
373                        snapshot,
374                        backend.as_ref(),
375                    )
376                    .await);
377                }
378                WorkflowContinuation::AwaitSignal {
379                    id,
380                    signal_name,
381                    timeout,
382                    next,
383                } => {
384                    check_guards(backend.as_ref(), &snapshot.instance_id, Some(id)).await?;
385
386                    if snapshot.get_task_result(id).is_some() {
387                        match next {
388                            Some(n) => {
389                                current = n;
390                                current_input =
391                                    snapshot.get_task_result_bytes(id).unwrap_or(current_input);
392                                continue;
393                            }
394                            None => return Ok(current_input),
395                        }
396                    }
397
398                    let err = park_at_signal(
399                        id,
400                        signal_name,
401                        timeout.as_ref(),
402                        next.as_deref(),
403                        snapshot,
404                        backend.as_ref(),
405                    )
406                    .await;
407
408                    if matches!(err, RuntimeError::Workflow(WorkflowError::SignalConsumed)) {
409                        if let Some(n) = next {
410                            current = n;
411                            current_input =
412                                snapshot.get_task_result_bytes(id).unwrap_or(current_input);
413                            continue;
414                        }
415                        let output = snapshot.get_task_result_bytes(id).unwrap_or(current_input);
416                        return Ok(output);
417                    }
418
419                    return Err(err);
420                }
421                WorkflowContinuation::Fork {
422                    id: fork_id,
423                    branches,
424                    join,
425                } => {
426                    check_guards(backend.as_ref(), &snapshot.instance_id, None).await?;
427
428                    let branch_results =
429                        if let Some(cached) = collect_cached_branches(branches, snapshot) {
430                            cached
431                        } else {
432                            let outcome = Self::execute_fork_branches_parallel(
433                                branches,
434                                &current_input,
435                                snapshot,
436                                &backend,
437                                &context,
438                            )
439                            .await?;
440                            settle_fork_outcome(
441                                fork_id,
442                                outcome,
443                                join.as_deref(),
444                                snapshot,
445                                backend.as_ref(),
446                            )
447                            .await?
448                        };
449
450                    match resolve_join(join.as_deref(), &branch_results)? {
451                        JoinResolution::Continue { next, input } => {
452                            current = next;
453                            current_input = input;
454                        }
455                        JoinResolution::Done(output) => return Ok(output),
456                    }
457                }
458            }
459        }
460    }
461
462    /// Execute fork branches in parallel using tokio tasks.
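    ///
    /// Branches with a cached result in the snapshot are reused without being
    /// re-spawned. Branches that park at a delay surface `WorkflowError::Waiting`;
    /// their wake times are folded into the returned `ForkBranchOutcome` as a
    /// single `max_wake_at` (the latest wake time across parked branches).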
463    async fn execute_fork_branches_parallel<C, M>(
464        branches: &[Arc<WorkflowContinuation>],
465        input: &Bytes,
466        snapshot: &WorkflowSnapshot,
467        backend: &Arc<B>,
468        context: &WorkflowContext<C, M>,
469    ) -> Result<ForkBranchOutcome, RuntimeError>
470    where
471        B: 'static,
472        C: Codec + 'static,
473        M: Send + Sync + 'static,
474    {
475        let mut branch_results = Vec::with_capacity(branches.len());
476        let mut set = tokio::task::JoinSet::new();
477        let instance_id = snapshot.instance_id.clone();
478
479        for branch in branches {
480            let branch_id = branch.id().to_string();
481
482            if let Some(result) = snapshot.get_task_result(&branch_id) {
483                branch_results.push((branch_id, result.output.clone()));
484            } else {
485                let branch = Arc::clone(branch);
486                let branch_input = input.clone();
487                let branch_backend = Arc::clone(backend);
488                let branch_instance_id = instance_id.clone();
489                let ctx_for_work = context.clone();
490
491                set.spawn(with_context(context.clone(), || async move {
492                    let result = Self::execute_branch_with_checkpoint(
493                        &branch,
494                        branch_input,
495                        branch_backend,
496                        branch_instance_id,
497                        ctx_for_work,
498                    )
499                    .await?;
500                    Ok((branch_id, result))
501                }));
502            }
503        }
504
505        let mut max_wake_at: Option<chrono::DateTime<chrono::Utc>> = None;
506
507        while let Some(result) = set.join_next().await {
508            match result {
509                Ok(Ok((branch_id, output))) => {
510                    branch_results.push((branch_id, output));
511                }
512                Ok(Err(RuntimeError::Workflow(WorkflowError::Waiting { wake_at }))) => {
513                    max_wake_at = Some(match max_wake_at {
514                        Some(existing) => existing.max(wake_at),
515                        None => wake_at,
516                    });
517                }
518                Ok(Err(e)) => return Err(e),
519                Err(join_err) => return Err(RuntimeError::from(join_err)),
520            }
521        }
522
523        Ok(ForkBranchOutcome {
524            results: branch_results,
525            max_wake_at,
526        })
527    }
528
529    /// Execute nested fork branches in parallel within a branch.
530    ///
531    /// Spawns each branch as a tokio task, collects all results, and propagates
532    /// errors (including `JoinError`).
533    async fn execute_nested_fork_branches<C, M>(
534        branches: &[Arc<WorkflowContinuation>],
535        input: &Bytes,
536        backend: &Arc<B>,
537        instance_id: &str,
538        context: &WorkflowContext<C, M>,
539    ) -> Result<Vec<(String, Bytes)>, RuntimeError>
540    where
541        B: 'static,
542        C: Codec + 'static,
543        M: Send + Sync + 'static,
544    {
545        let mut set: tokio::task::JoinSet<Result<(String, Bytes), RuntimeError>> =
546            tokio::task::JoinSet::new();
547        for branch in branches {
548            let id = branch.id().to_string();
549            let branch = Arc::clone(branch);
550            let branch_input = input.clone();
551            let branch_backend = Arc::clone(backend);
552            let branch_instance_id = instance_id.to_string();
553            let ctx_for_work = context.clone();
554
555            set.spawn(with_context(context.clone(), || async move {
556                let result = Self::execute_branch_with_checkpoint(
557                    &branch,
558                    branch_input,
559                    branch_backend,
560                    branch_instance_id,
561                    ctx_for_work,
562                )
563                .await?;
564                Ok((id, result))
565            }));
566        }
567
568        let mut branch_results: Vec<(String, Bytes)> = Vec::with_capacity(set.len());
569        while let Some(res) = set.join_next().await {
570            branch_results.push(res??);
571        }
572        Ok(branch_results)
573    }
574
575    /// Execute branch continuation with per-task checkpointing (iterative, no boxing).
576    ///
577    /// Unlike `execute_with_checkpointing`, this doesn't update position tracking
578    /// (branches run independently). It saves each task result directly to the backend.
579    ///
580    /// On resume after `AtFork`, the backend snapshot contains sub-task results from
581    /// the previous execution. This function loads the snapshot to skip cached tasks
582    /// and parks at delays instead of sleeping through them.
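    ///
    /// Retries inside a branch are handled in-process: on a retryable failure the
    /// retry state is recorded in the snapshot, the task deadline is cleared, and
    /// the branch sleeps until `next_retry_at` before re-running the task.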
583    #[allow(clippy::manual_async_fn, clippy::too_many_lines)]
584    fn execute_branch_with_checkpoint<C, M>(
585        continuation: &WorkflowContinuation,
586        input: Bytes,
587        backend: Arc<B>,
588        instance_id: String,
589        context: WorkflowContext<C, M>,
590    ) -> impl std::future::Future<Output = Result<Bytes, RuntimeError>> + Send + '_
591    where
592        B: 'static,
593        C: Codec + 'static,
594        M: Send + Sync + 'static,
595    {
596        async move {
597            // Load snapshot for checking cached results (populated on resume after AtFork)
598            let mut snapshot = backend.load_snapshot(&instance_id).await?;
599
600            let mut current = continuation;
601            let mut current_input = input;
602
603            loop {
604                match current {
605                    WorkflowContinuation::Task {
606                        id,
607                        func,
608                        timeout,
609                        retry_policy,
610                        next,
611                    } => {
612                        let func = func
613                            .as_ref()
614                            .ok_or_else(|| WorkflowError::TaskNotImplemented(id.clone()))?;
615
616                        let output = loop {
617                            match branch_execute_or_skip_task(
618                                id,
619                                current_input.clone(),
620                                |i| func.run(i),
621                                timeout.as_ref(),
622                                &mut snapshot,
623                                &instance_id,
624                                backend.as_ref(),
625                            )
626                            .await
627                            {
628                                Ok(output) => {
629                                    snapshot.clear_retry_state(id);
630                                    break output;
631                                }
632                                Err(e) => {
633                                    if let Some(rp) = retry_policy
634                                        && !snapshot.retries_exhausted(id)
635                                    {
636                                        let next_retry_at =
637                                            snapshot.record_retry(id, rp, &e.to_string(), None);
638                                        snapshot.clear_task_deadline();
639                                        tracing::info!(
640                                            task_id = %id,
641                                            attempt = snapshot.get_retry_state(id).map_or(0, |rs| rs.attempts),
642                                            max_retries = rp.max_retries,
643                                            %next_retry_at,
644                                            error = %e,
645                                            "Retrying task (branch)"
646                                        );
647                                        let delay = (next_retry_at - chrono::Utc::now())
648                                            .to_std()
649                                            .unwrap_or_default();
650                                        tokio::time::sleep(delay).await;
651                                        continue;
652                                    }
653                                    return Err(e);
654                                }
655                            }
656                        };
657
658                        match next {
659                            Some(next_continuation) => {
660                                current = next_continuation;
661                                current_input = output;
662                            }
663                            None => return Ok(output),
664                        }
665                    }
666                    WorkflowContinuation::Delay { id, duration, next } => {
667                        // Skip if pass-through was already saved (resume case)
668                        if let Some(result) = snapshot.get_task_result(id) {
669                            tracing::debug!(delay_id = %id, "delay already completed in branch, skipping");
670                            match next {
671                                Some(next_cont) => {
672                                    current = next_cont;
673                                    current_input = result.output.clone();
674                                    continue;
675                                }
676                                None => return Ok(result.output.clone()),
677                            }
678                        }
679
680                        return Err(park_branch_at_delay(
681                            id,
682                            duration,
683                            current_input,
684                            &instance_id,
685                            backend.as_ref(),
686                        )
687                        .await);
688                    }
689                    WorkflowContinuation::AwaitSignal {
690                        id,
691                        signal_name,
692                        timeout,
693                        next,
694                    } => {
695                        // Skip if signal was already consumed (resume case)
696                        if let Some(result) = snapshot.get_task_result(id) {
697                            tracing::debug!(signal_id = %id, %signal_name, "signal already consumed in branch, skipping");
698                            match next {
699                                Some(next_cont) => {
700                                    current = next_cont;
701                                    current_input = result.output.clone();
702                                    continue;
703                                }
704                                None => return Ok(result.output.clone()),
705                            }
706                        }
707
708                        return Err(park_branch_at_signal(
709                            id,
710                            signal_name,
711                            timeout.as_ref(),
712                            current_input,
713                            &instance_id,
714                            backend.as_ref(),
715                        )
716                        .await);
717                    }
718                    WorkflowContinuation::Fork { branches, join, .. } => {
719                        let branch_results = Self::execute_nested_fork_branches(
720                            branches,
721                            &current_input,
722                            &backend,
723                            &instance_id,
724                            &context,
725                        )
726                        .await?;
727
728                        match resolve_join(join.as_deref(), &branch_results)? {
729                            JoinResolution::Continue { next, input } => {
730                                current = next;
731                                current_input = input;
732                            }
733                            JoinResolution::Done(output) => return Ok(output),
734                        }
735                    }
736                }
737            }
738        }
739    }
740}
741
742#[cfg(test)]
743#[allow(
744    clippy::unwrap_used,
745    clippy::expect_used,
746    clippy::panic,
747    clippy::indexing_slicing,
748    clippy::too_many_lines
749)]
750mod tests {
751    use super::*;
752    use crate::serialization::JsonCodec;
753    use sayiir_core::codec::Encoder;
754    use sayiir_core::context::WorkflowContext;
755    use sayiir_core::error::BoxError;
756    use sayiir_core::snapshot::WorkflowSnapshotState;
757    use sayiir_core::task::BranchOutputs;
758    use sayiir_core::workflow::WorkflowBuilder;
759    use sayiir_persistence::InMemoryBackend;
760    use sayiir_persistence::{SignalStore, SnapshotStore};
761
762    fn ctx() -> WorkflowContext<JsonCodec, ()> {
763        WorkflowContext::new("test-workflow", Arc::new(JsonCodec), Arc::new(()))
764    }
765
766    // ========================================================================
767    // Run (fresh execution)
768    // ========================================================================
769
770    #[tokio::test]
771    async fn test_run_single_task() {
772        let backend = InMemoryBackend::new();
773        let runner = CheckpointingRunner::new(backend);
774
775        let workflow = WorkflowBuilder::new(ctx())
776            .then("add_one", |i: u32| async move { Ok(i + 1) })
777            .build()
778            .unwrap();
779
780        let status = runner.run(&workflow, "inst-1", 5u32).await.unwrap();
781        assert!(matches!(status, WorkflowStatus::Completed));
782
783        // Verify snapshot was saved as completed
784        let snapshot = runner.backend().load_snapshot("inst-1").await.unwrap();
785        assert!(snapshot.state.is_completed());
786    }
787
788    #[tokio::test]
789    async fn test_run_chained_tasks() {
790        let backend = InMemoryBackend::new();
791        let runner = CheckpointingRunner::new(backend);
792
793        let workflow = WorkflowBuilder::new(ctx())
794            .then("add_one", |i: u32| async move { Ok(i + 1) })
795            .then("double", |i: u32| async move { Ok(i * 2) })
796            .build()
797            .unwrap();
798
799        let status = runner.run(&workflow, "inst-1", 10u32).await.unwrap();
800        assert!(matches!(status, WorkflowStatus::Completed));
801
802        let snapshot = runner.backend().load_snapshot("inst-1").await.unwrap();
803        assert!(snapshot.state.is_completed());
804    }
805
806    #[tokio::test]
807    async fn test_run_three_task_chain() {
808        let backend = InMemoryBackend::new();
809        let runner = CheckpointingRunner::new(backend);
810
811        let workflow = WorkflowBuilder::new(ctx())
812            .then("step1", |i: u32| async move { Ok(i + 1) })
813            .then("step2", |i: u32| async move { Ok(i * 3) })
814            .then("step3", |i: u32| async move { Ok(i - 2) })
815            .build()
816            .unwrap();
817
818        let status = runner.run(&workflow, "inst-1", 5u32).await.unwrap();
819        // 5+1=6, 6*3=18, 18-2=16
820        assert!(matches!(status, WorkflowStatus::Completed));
821    }
822
823    #[tokio::test]
824    async fn test_run_task_failure() {
825        let backend = InMemoryBackend::new();
826        let runner = CheckpointingRunner::new(backend);
827
828        let workflow = WorkflowBuilder::new(ctx())
829            .then("fail", |_i: u32| async move {
830                Err::<u32, BoxError>("intentional failure".into())
831            })
832            .build()
833            .unwrap();
834
835        let status = runner.run(&workflow, "inst-1", 1u32).await.unwrap();
836        match status {
837            WorkflowStatus::Failed(e) => {
838                assert!(e.contains("intentional failure"));
839            }
840            _ => panic!("Expected Failed status"),
841        }
842
843        // Snapshot should be marked as failed
844        let snapshot = runner.backend().load_snapshot("inst-1").await.unwrap();
845        assert!(snapshot.state.is_failed());
846    }
847
848    #[tokio::test]
849    async fn test_run_fork_join() {
850        let backend = InMemoryBackend::new();
851        let runner = CheckpointingRunner::new(backend);
852
853        let workflow = WorkflowBuilder::new(ctx())
854            .then("prepare", |i: u32| async move { Ok(i) })
855            .branches(|b| {
856                b.add("double", |i: u32| async move { Ok(i * 2) });
857                b.add("add_ten", |i: u32| async move { Ok(i + 10) });
858            })
859            .join("combine", |outputs: BranchOutputs<JsonCodec>| async move {
860                let doubled: u32 = outputs.get("double")?;
861                let added: u32 = outputs.get("add_ten")?;
862                Ok(doubled + added)
863            })
864            .build()
865            .unwrap();
866
867        let status = runner.run(&workflow, "inst-1", 5u32).await.unwrap();
868        assert!(matches!(status, WorkflowStatus::Completed));
869
870        let snapshot = runner.backend().load_snapshot("inst-1").await.unwrap();
871        assert!(snapshot.state.is_completed());
872    }
873
874    #[tokio::test]
875    async fn test_run_checkpoints_intermediate_tasks() {
876        let backend = InMemoryBackend::new();
877        let runner = CheckpointingRunner::new(backend);
878
879        let workflow = WorkflowBuilder::new(ctx())
880            .then("step1", |i: u32| async move { Ok(i + 1) })
881            .then("step2", |i: u32| async move { Ok(i * 2) })
882            .build()
883            .unwrap();
884
885        let status = runner.run(&workflow, "inst-1", 10u32).await.unwrap();
886        assert!(matches!(status, WorkflowStatus::Completed));
887
888        // The final snapshot should be completed, but we can verify the
889        // instance was tracked throughout
890        let snapshot = runner.backend().load_snapshot("inst-1").await.unwrap();
891        assert!(snapshot.state.is_completed());
892    }
893
894    // ========================================================================
895    // Resume
896    // ========================================================================
897
898    #[tokio::test]
899    async fn test_resume_completed_workflow() {
900        let backend = InMemoryBackend::new();
901        let runner = CheckpointingRunner::new(backend);
902
903        let workflow = WorkflowBuilder::new(ctx())
904            .then("step1", |i: u32| async move { Ok(i + 1) })
905            .build()
906            .unwrap();
907
908        // Run to completion
909        runner.run(&workflow, "inst-1", 5u32).await.unwrap();
910
911        // Resume should return Completed immediately
912        let status = runner.resume(&workflow, "inst-1").await.unwrap();
913        assert!(matches!(status, WorkflowStatus::Completed));
914    }
915
916    #[tokio::test]
917    async fn test_resume_failed_workflow() {
918        let backend = InMemoryBackend::new();
919        let runner = CheckpointingRunner::new(backend);
920
921        let workflow = WorkflowBuilder::new(ctx())
922            .then("fail", |_i: u32| async move {
923                Err::<u32, BoxError>("failure".into())
924            })
925            .build()
926            .unwrap();
927
928        runner.run(&workflow, "inst-1", 1u32).await.unwrap();
929
930        let status = runner.resume(&workflow, "inst-1").await.unwrap();
931        match status {
932            WorkflowStatus::Failed(_) => {}
933            _ => panic!("Expected Failed status"),
934        }
935    }
936
937    #[tokio::test]
938    async fn test_resume_definition_hash_mismatch() {
939        let backend = InMemoryBackend::new();
940        let runner = CheckpointingRunner::new(backend);
941
942        let workflow1 = WorkflowBuilder::new(ctx())
943            .then("step1", |i: u32| async move { Ok(i + 1) })
944            .build()
945            .unwrap();
946
947        // Run with workflow1
948        runner.run(&workflow1, "inst-1", 5u32).await.unwrap();
949
950        // Manually create in-progress snapshot with workflow1's hash
951        let mut snapshot = WorkflowSnapshot::with_initial_input(
952            "inst-2".into(),
953            workflow1.definition_hash().to_string(),
954            Bytes::from(serde_json::to_vec(&5u32).unwrap()),
955        );
956        snapshot.update_position(ExecutionPosition::AtTask {
957            task_id: "step1".into(),
958        });
959        runner.backend().save_snapshot(&snapshot).await.unwrap();
960
961        // Build a different workflow
962        let workflow2 = WorkflowBuilder::new(ctx())
963            .then("step1", |i: u32| async move { Ok(i + 1) })
964            .then("step2", |i: u32| async move { Ok(i * 2) })
965            .build()
966            .unwrap();
967
968        // Resume with different workflow definition should fail
969        let result = runner.resume(&workflow2, "inst-2").await;
970        assert!(result.is_err());
971        assert!(result.unwrap_err().to_string().contains("mismatch"));
972    }
973
974    // ========================================================================
975    // Cancellation
976    // ========================================================================
977
978    #[tokio::test]
979    async fn test_cancel_running_workflow() {
980        let backend = InMemoryBackend::new();
981        let runner = CheckpointingRunner::new(backend);
982
983        // Create a workflow with a slow task
984        let workflow = WorkflowBuilder::new(ctx())
985            .then("slow_task", |i: u32| async move {
986                tokio::time::sleep(std::time::Duration::from_secs(10)).await;
987                Ok(i)
988            })
989            .build()
990            .unwrap();
991
992        // Set up a snapshot as if it's in progress
993        let input_bytes = Arc::new(JsonCodec).encode(&1u32).unwrap();
994        let mut snapshot = WorkflowSnapshot::with_initial_input(
995            "inst-cancel".into(),
996            workflow.definition_hash().to_string(),
997            input_bytes,
998        );
999        snapshot.update_position(ExecutionPosition::AtTask {
1000            task_id: "slow_task".into(),
1001        });
1002        runner.backend().save_snapshot(&snapshot).await.unwrap();
1003
1004        // Request cancellation
1005        runner
1006            .cancel(
1007                "inst-cancel",
1008                Some("testing".into()),
1009                Some("test-suite".into()),
1010            )
1011            .await
1012            .unwrap();
1013
1014        // Verify cancellation request was stored
1015        let req = runner
1016            .backend()
1017            .get_signal("inst-cancel", SignalKind::Cancel)
1018            .await
1019            .unwrap();
1020        assert!(req.is_some());
1021        assert_eq!(req.unwrap().reason, Some("testing".into()));
1022    }
1023
1024    #[tokio::test]
1025    async fn test_run_with_pre_cancellation() {
1026        let backend = InMemoryBackend::new();
1027        let runner = CheckpointingRunner::new(backend);
1028
1029        let workflow = WorkflowBuilder::new(ctx())
1030            .then("task1", |i: u32| async move { Ok(i + 1) })
1031            .then("task2", |i: u32| async move { Ok(i * 2) })
1032            .build()
1033            .unwrap();
1034
1035        // Save initial snapshot and request cancellation before running
1036        let input_bytes = Arc::new(JsonCodec).encode(&1u32).unwrap();
1037        let mut snapshot = WorkflowSnapshot::with_initial_input(
1038            "inst-precancel".into(),
1039            workflow.definition_hash().to_string(),
1040            input_bytes,
1041        );
1042        snapshot.update_position(ExecutionPosition::AtTask {
1043            task_id: "task1".into(),
1044        });
1045        runner.backend().save_snapshot(&snapshot).await.unwrap();
1046
1047        runner
1048            .cancel("inst-precancel", Some("pre-cancel".into()), None)
1049            .await
1050            .unwrap();
1051
1052        // Resume should detect cancellation
1053        let status = runner.resume(&workflow, "inst-precancel").await.unwrap();
1054        match status {
1055            WorkflowStatus::Cancelled { reason, .. } => {
1056                assert_eq!(reason, Some("pre-cancel".into()));
1057            }
1058            _ => panic!("Expected Cancelled status, got: {status:?}"),
1059        }
1060    }
1061
1062    // ========================================================================
1063    // Edge cases
1064    // ========================================================================
1065
1066    #[tokio::test]
1067    async fn test_resume_nonexistent_instance() {
1068        let backend = InMemoryBackend::new();
1069        let runner = CheckpointingRunner::new(backend);
1070
1071        let workflow = WorkflowBuilder::new(ctx())
1072            .then("task", |i: u32| async move { Ok(i) })
1073            .build()
1074            .unwrap();
1075
1076        let result = runner.resume(&workflow, "nonexistent").await;
1077        assert!(result.is_err());
1078    }
1079
1080    #[tokio::test]
1081    async fn test_run_failure_in_chain_saves_snapshot() {
1082        let backend = InMemoryBackend::new();
1083        let runner = CheckpointingRunner::new(backend);
1084
1085        let workflow = WorkflowBuilder::new(ctx())
1086            .then("step1", |i: u32| async move { Ok(i + 1) })
1087            .then("fail_step", |_i: u32| async move {
1088                Err::<u32, BoxError>("mid-chain failure".into())
1089            })
1090            .then("step3", |i: u32| async move { Ok(i * 2) })
1091            .build()
1092            .unwrap();
1093
1094        let status = runner.run(&workflow, "inst-1", 10u32).await.unwrap();
1095        match status {
1096            WorkflowStatus::Failed(e) => {
1097                assert!(e.contains("mid-chain failure"));
1098            }
1099            _ => panic!("Expected Failed"),
1100        }
1101
1102        // Snapshot should be saved as failed
1103        let snapshot = runner.backend().load_snapshot("inst-1").await.unwrap();
1104        assert!(snapshot.state.is_failed());
1105    }
1106
1107    // ========================================================================
1108    // Delay tests
1109    // ========================================================================
1110
1111    #[tokio::test]
1112    async fn test_run_workflow_with_delay_returns_waiting() {
1113        let backend = InMemoryBackend::new();
1114        let runner = CheckpointingRunner::new(backend);
1115
1116        let workflow = WorkflowBuilder::new(ctx())
1117            .then("step1", |i: u32| async move { Ok(i + 1) })
1118            .delay("wait_1h", std::time::Duration::from_secs(3600))
1119            .then("step2", |i: u32| async move { Ok(i * 2) })
1120            .build()
1121            .unwrap();
1122
1123        let status = runner.run(&workflow, "inst-1", 10u32).await.unwrap();
1124
1125        // Should return Waiting (delay is 1 hour in the future)
1126        match &status {
1127            WorkflowStatus::Waiting { delay_id, .. } => {
1128                assert_eq!(delay_id, "wait_1h");
1129            }
1130            _ => panic!("Expected Waiting status, got {status:?}"),
1131        }
1132
1133        // Snapshot should be in-progress at AtDelay position
1134        let snapshot = runner.backend().load_snapshot("inst-1").await.unwrap();
1135        assert!(snapshot.state.is_in_progress());
1136        match &snapshot.state {
1137            WorkflowSnapshotState::InProgress { position, .. } => match position {
1138                ExecutionPosition::AtDelay {
1139                    delay_id,
1140                    next_task_id,
1141                    ..
1142                } => {
1143                    assert_eq!(delay_id, "wait_1h");
1144                    assert_eq!(next_task_id.as_deref(), Some("step2"));
1145                }
1146                other => panic!("Expected AtDelay, got {other:?}"),
1147            },
1148            _ => panic!("Expected InProgress"),
1149        }
1150
1151        // step1 should have been completed
1152        assert!(snapshot.get_task_result("step1").is_some());
1153        // delay pass-through should be stored
1154        assert!(snapshot.get_task_result("wait_1h").is_some());
1155    }
1156
1157    #[tokio::test]
1158    async fn test_resume_before_delay_expires_returns_waiting() {
1159        let backend = InMemoryBackend::new();
1160        let runner = CheckpointingRunner::new(backend);
1161
1162        let workflow = WorkflowBuilder::new(ctx())
1163            .then("step1", |i: u32| async move { Ok(i + 1) })
1164            .delay("wait_1h", std::time::Duration::from_secs(3600))
1165            .then("step2", |i: u32| async move { Ok(i * 2) })
1166            .build()
1167            .unwrap();
1168
1169        // Run to delay
1170        let status = runner.run(&workflow, "inst-1", 10u32).await.unwrap();
1171        assert!(matches!(status, WorkflowStatus::Waiting { .. }));
1172
1173        // Resume immediately (delay hasn't expired)
1174        let status = runner.resume(&workflow, "inst-1").await.unwrap();
1175        match &status {
1176            WorkflowStatus::Waiting { delay_id, .. } => {
1177                assert_eq!(delay_id, "wait_1h");
1178            }
1179            _ => panic!("Expected Waiting on resume, got {status:?}"),
1180        }
1181    }
1182
1183    #[tokio::test]
1184    async fn test_resume_after_delay_expires_completes() {
1185        let backend = InMemoryBackend::new();
1186        let runner = CheckpointingRunner::new(backend);
1187
1188        // Use a very short delay so it expires immediately
1189        let workflow = WorkflowBuilder::new(ctx())
1190            .then("step1", |i: u32| async move { Ok(i + 1) })
1191            .delay("wait_short", std::time::Duration::from_millis(1))
1192            .then("step2", |i: u32| async move { Ok(i * 2) })
1193            .build()
1194            .unwrap();
1195
1196        // Run — delay is so short it should still park (snapshot is saved before checking time)
1197        let status = runner.run(&workflow, "inst-1", 10u32).await.unwrap();
1198        assert!(matches!(status, WorkflowStatus::Waiting { .. }));
1199
1200        // Wait a bit for the delay to definitely expire
1201        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
1202
1203        // Resume — delay should have expired, execution continues
1204        let status = runner.resume(&workflow, "inst-1").await.unwrap();
1205        assert!(
1206            matches!(status, WorkflowStatus::Completed),
1207            "Expected Completed after delay expired, got {status:?}"
1208        );
1209
1210        // Verify final state
1211        let snapshot = runner.backend().load_snapshot("inst-1").await.unwrap();
1212        assert!(snapshot.state.is_completed());
1213    }
1214
1215    #[tokio::test]
1216    async fn test_cancel_during_delay() {
1217        let backend = InMemoryBackend::new();
1218        let runner = CheckpointingRunner::new(backend);
1219
1220        let workflow = WorkflowBuilder::new(ctx())
1221            .then("step1", |i: u32| async move { Ok(i + 1) })
1222            .delay("wait_1h", std::time::Duration::from_secs(3600))
1223            .then("step2", |i: u32| async move { Ok(i * 2) })
1224            .build()
1225            .unwrap();
1226
1227        // Run to delay
1228        let status = runner.run(&workflow, "inst-1", 10u32).await.unwrap();
1229        assert!(matches!(status, WorkflowStatus::Waiting { .. }));
1230
1231        // Cancel during delay
1232        runner
1233            .cancel(
1234                "inst-1",
1235                Some("no longer needed".into()),
1236                Some("admin".into()),
1237            )
1238            .await
1239            .unwrap();
1240
1241        // Resume should detect cancellation
1242        let status = runner.resume(&workflow, "inst-1").await.unwrap();
1243        match status {
1244            WorkflowStatus::Cancelled {
1245                reason,
1246                cancelled_by,
1247            } => {
1248                assert_eq!(reason, Some("no longer needed".into()));
1249                assert_eq!(cancelled_by, Some("admin".into()));
1250            }
1251            _ => panic!("Expected Cancelled status, got {status:?}"),
1252        }
1253    }
1254
1255    #[tokio::test]
1256    async fn test_delay_as_last_node() {
1257        let backend = InMemoryBackend::new();
1258        let runner = CheckpointingRunner::new(backend);
1259
1260        let workflow = WorkflowBuilder::new(ctx())
1261            .then("step1", |i: u32| async move { Ok(i + 1) })
1262            .delay("final_wait", std::time::Duration::from_millis(1))
1263            .build()
1264            .unwrap();
1265
1266        // Run to delay
1267        let status = runner.run(&workflow, "inst-1", 10u32).await.unwrap();
1268        assert!(matches!(status, WorkflowStatus::Waiting { .. }));
1269
1270        // Wait for delay to expire
1271        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
1272
1273        // Resume — delay was the last node, should complete
1274        let status = runner.resume(&workflow, "inst-1").await.unwrap();
1275        assert!(
1276            matches!(status, WorkflowStatus::Completed),
1277            "Expected Completed when delay is last node, got {status:?}"
1278        );
1279    }
1280
1281    #[tokio::test]
1282    async fn test_delay_data_passthrough() {
1283        let backend = InMemoryBackend::new();
1284        let runner = CheckpointingRunner::new(backend);
1285
1286        // step1 produces 11, delay passes it through, step2 receives 11 and doubles
1287        let workflow = WorkflowBuilder::new(ctx())
1288            .then("step1", |i: u32| async move { Ok(i + 1) })
1289            .delay("wait", std::time::Duration::from_millis(1))
1290            .then("step2", |i: u32| async move {
1291                // Verify input is the passthrough value from step1
1292                assert_eq!(i, 11);
1293                Ok(i * 2)
1294            })
1295            .build()
1296            .unwrap();
1297
1298        // Run to delay
1299        runner.run(&workflow, "inst-1", 10u32).await.unwrap();
1300
1301        // Wait and resume
1302        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
1303        let status = runner.resume(&workflow, "inst-1").await.unwrap();
1304        assert!(matches!(status, WorkflowStatus::Completed));
1305    }
1306
1307    // ========================================================================
1308    // Timeout tests
1309    // ========================================================================
1310
1311    #[tokio::test]
1312    async fn test_run_task_timeout_fails_workflow() {
1313        use sayiir_core::task::TaskMetadata;
1314
1315        let backend = InMemoryBackend::new();
1316        let runner = CheckpointingRunner::new(backend);
1317
1318        let workflow = WorkflowBuilder::new(ctx())
1319            .with_registry()
1320            .then("slow_task", |i: u32| async move {
1321                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1322                Ok(i)
1323            })
1324            .with_metadata(TaskMetadata {
1325                timeout: Some(std::time::Duration::from_millis(5)),
1326                ..Default::default()
1327            })
1328            .build()
1329            .unwrap();
1330
1331        let status = runner
1332            .run(workflow.workflow(), "inst-timeout", 5u32)
1333            .await
1334            .unwrap();
1335        match status {
1336            WorkflowStatus::Failed(msg) => {
1337                assert!(
1338                    msg.contains("timed out"),
1339                    "Expected timeout error, got: {msg}"
1340                );
1341                assert!(
1342                    msg.contains("slow_task"),
1343                    "Expected task id in error, got: {msg}"
1344                );
1345            }
1346            other => panic!("Expected Failed status, got {other:?}"),
1347        }
1348    }
1349
1350    #[tokio::test]
1351    async fn test_run_task_within_timeout_succeeds() {
1352        use sayiir_core::task::TaskMetadata;
1353
1354        let backend = InMemoryBackend::new();
1355        let runner = CheckpointingRunner::new(backend);
1356
1357        let workflow = WorkflowBuilder::new(ctx())
1358            .with_registry()
1359            .then("fast_task", |i: u32| async move { Ok(i + 1) })
1360            .with_metadata(TaskMetadata {
1361                timeout: Some(std::time::Duration::from_secs(5)),
1362                ..Default::default()
1363            })
1364            .build()
1365            .unwrap();
1366
1367        let status = runner
1368            .run(workflow.workflow(), "inst-fast", 5u32)
1369            .await
1370            .unwrap();
1371        assert!(matches!(status, WorkflowStatus::Completed));
1372    }
1373}