Skip to main content

cognis_graph/
durability.rs

//! Compile-time setting controlling **when** checkpoints are persisted
//! relative to a superstep's execution.
//!
//! Mirrors V1's `Durability` modes:
//!
//! - [`Durability::Sync`] (default) — `cp.save(...)` is awaited inline at
//!   the end of every superstep. Strongest guarantee; the next step never
//!   starts until the previous step's state is durably persisted. Used in
//!   tests and production deployments where loss of a single step's work
//!   is unacceptable.
//!
//! - [`Durability::Async`] — `cp.save(...)` is spawned and the engine
//!   advances without awaiting. The save races with the next superstep;
//!   on a crash you may lose up to one step. Use when checkpoint backends
//!   are slow (network/postgres) and step boundaries are frequent.
//!
//! - [`Durability::Exit`] — only the final state is persisted (one
//!   `cp.save` call when the graph reaches `End` or runs out of work).
//!   Cheapest; suitable for short, non-resumable workflows where
//!   intermediate snapshots aren't useful.
//!
//! - [`Durability::Every`] (extension) — save every N steps. Takes a
//!   stride and a wrapped mode: every Nth step is handled by the wrapped
//!   mode, the steps in between are skipped, and the graph-completion
//!   save is always emitted.
//!
//! Custom strategies plug in via [`Durability::Custom`] and a
//! [`DurabilityHook`] — see the trait for how to pass a fully custom
//! decision function.
27
28use std::sync::Arc;
29
/// The checkpoint action chosen for a single superstep.
///
/// This is the value handed back by [`DurabilityHook::decide`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DurabilityDecision {
    /// Persist now: `cp.save` is awaited inline.
    Sync,
    /// Persist in the background: `cp.save` is spawned.
    Async,
    /// Do not persist this step.
    Skip,
}
41
/// Object-safe hook called per step to decide when to checkpoint. Plugs
/// into [`Durability::Custom`] for fully bespoke policies (e.g. "save
/// every 10 steps to S3", "sync only on tool-result steps").
///
/// Implementors must be `Send + Sync`. The hook receives only the step
/// number and the terminal flag; no other run context is passed in.
pub trait DurabilityHook: Send + Sync {
    /// Decide whether and how to persist the checkpoint for `step`.
    ///
    /// `is_terminal` is true on the final emitted decision (graph End).
    /// Returns [`DurabilityDecision::Sync`] or
    /// [`DurabilityDecision::Async`] to save, or
    /// [`DurabilityDecision::Skip`] to leave the step unpersisted.
    fn decide(&self, step: u64, is_terminal: bool) -> DurabilityDecision;
}
50
51/// Convenience: any `Fn(u64, bool) -> DurabilityDecision + Send + Sync`
52/// is a `DurabilityHook`.
53impl<F> DurabilityHook for F
54where
55    F: Fn(u64, bool) -> DurabilityDecision + Send + Sync,
56{
57    fn decide(&self, step: u64, is_terminal: bool) -> DurabilityDecision {
58        (self)(step, is_terminal)
59    }
60}
61
62/// Checkpoint timing relative to step execution.
63#[derive(Clone, Default)]
64pub enum Durability {
65    /// Await `cp.save` inline after each step (default).
66    #[default]
67    Sync,
68    /// Spawn `cp.save` without awaiting. Crash before the spawn lands
69    /// loses the most recent step.
70    Async,
71    /// Save only at graph completion. No intermediate checkpoints.
72    Exit,
73    /// Save every Nth step (n>=1) using the wrapped sub-mode for that
74    /// step. Other steps are skipped. The graph-completion save is
75    /// always emitted regardless of stride.
76    Every {
77        /// Stride. 1 means "every step" (equivalent to the wrapped mode).
78        n: u64,
79        /// Mode used on the Nth step.
80        mode: Box<Durability>,
81    },
82    /// Fully user-defined policy.
83    Custom(Arc<dyn DurabilityHook>),
84}
85
86impl std::fmt::Debug for Durability {
87    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88        match self {
89            Self::Sync => f.write_str("Sync"),
90            Self::Async => f.write_str("Async"),
91            Self::Exit => f.write_str("Exit"),
92            Self::Every { n, mode } => f
93                .debug_struct("Every")
94                .field("n", n)
95                .field("mode", mode)
96                .finish(),
97            Self::Custom(_) => f.write_str("Custom(<hook>)"),
98        }
99    }
100}
101
102impl PartialEq for Durability {
103    fn eq(&self, other: &Self) -> bool {
104        match (self, other) {
105            (Self::Sync, Self::Sync) | (Self::Async, Self::Async) | (Self::Exit, Self::Exit) => {
106                true
107            }
108            (Self::Every { n: a, mode: ma }, Self::Every { n: b, mode: mb }) => a == b && ma == mb,
109            // Two `Custom` impls are never equal (no way to compare hooks).
110            _ => false,
111        }
112    }
113}
114
115impl Durability {
116    /// Decide what action to take at the end of `step`. `is_terminal`
117    /// is true on the graph-completion save (always honored except for
118    /// `Skip`-returning custom hooks that opt out).
119    pub fn decide(&self, step: u64, is_terminal: bool) -> DurabilityDecision {
120        match self {
121            Self::Sync => DurabilityDecision::Sync,
122            Self::Async => DurabilityDecision::Async,
123            Self::Exit => {
124                if is_terminal {
125                    DurabilityDecision::Sync
126                } else {
127                    DurabilityDecision::Skip
128                }
129            }
130            Self::Every { n, mode } => {
131                if is_terminal {
132                    return DurabilityDecision::Sync;
133                }
134                let stride = (*n).max(1);
135                if step.is_multiple_of(stride) {
136                    mode.decide(step, false)
137                } else {
138                    DurabilityDecision::Skip
139                }
140            }
141            Self::Custom(h) => h.decide(step, is_terminal),
142        }
143    }
144
145    /// True if the engine should save inline after each step.
146    pub fn save_per_step_sync(&self) -> bool {
147        matches!(self, Self::Sync)
148    }
149
150    /// True if the engine should spawn an async save after each step.
151    pub fn save_per_step_async(&self) -> bool {
152        matches!(self, Self::Async)
153    }
154
155    /// True if a final-only save should be emitted on graph completion.
156    pub fn save_on_exit(&self) -> bool {
157        matches!(self, Self::Exit)
158    }
159}
160
#[cfg(test)]
mod tests {
    use super::*;

    /// `Durability::default()` must be the inline-await mode.
    #[test]
    fn default_is_sync() {
        let mode = Durability::default();
        assert_eq!(mode, Durability::Sync);
    }

    /// Each predicate is true for exactly one of the three plain modes.
    #[test]
    fn predicates_match_variants() {
        // Columns: mode, save_per_step_sync, save_per_step_async, save_on_exit.
        let expectations = [
            (Durability::Sync, true, false, false),
            (Durability::Async, false, true, false),
            (Durability::Exit, false, false, true),
        ];
        for (mode, per_step_sync, per_step_async, on_exit) in &expectations {
            assert_eq!(mode.save_per_step_sync(), *per_step_sync, "{mode:?}");
            assert_eq!(mode.save_per_step_async(), *per_step_async, "{mode:?}");
            assert_eq!(mode.save_on_exit(), *on_exit, "{mode:?}");
        }
    }

    /// With a stride of 3 only steps 0 and 3 are saved; the terminal
    /// save ignores the stride.
    #[test]
    fn every_stride_skips_intermediate() {
        let every_third = Durability::Every {
            n: 3,
            mode: Box::new(Durability::Sync),
        };
        let non_terminal = [
            (0, DurabilityDecision::Sync),
            (1, DurabilityDecision::Skip),
            (2, DurabilityDecision::Skip),
            (3, DurabilityDecision::Sync),
        ];
        for (step, expected) in &non_terminal {
            assert_eq!(every_third.decide(*step, false), *expected, "step {step}");
        }
        // Terminal always saves.
        assert_eq!(every_third.decide(7, true), DurabilityDecision::Sync);
    }

    /// A closure wrapped in `Custom` is consulted for every decision.
    #[test]
    fn custom_hook_is_invoked() {
        let even_or_terminal = |step: u64, terminal: bool| {
            if terminal || step.is_multiple_of(2) {
                DurabilityDecision::Sync
            } else {
                DurabilityDecision::Skip
            }
        };
        let policy = Durability::Custom(Arc::new(even_or_terminal));
        assert_eq!(policy.decide(0, false), DurabilityDecision::Sync);
        assert_eq!(policy.decide(1, false), DurabilityDecision::Skip);
        assert_eq!(policy.decide(99, true), DurabilityDecision::Sync);
    }
}