Skip to main content

atomr_agents_workflow/
interrupt.rs

1//! Dynamic interrupts + static breakpoints + resume.
2//!
3//! `Interruptible` wraps a `StatefulRunner`-style execution with the
4//! ability to:
5//!
//! 1. Pause when a step calls `InterruptCtrl::interrupt(step_id, payload)`.
7//! 2. Pause before/after named steps via `interrupt_before` /
8//!    `interrupt_after`.
9//! 3. Persist the pause state as a checkpoint with a special label so
10//!    the caller can `resume(run_id, command)` and continue.
11//!
12//! The pause mechanism is cooperative: a step runs `interrupt(...)`
13//! and the runner translates that into a `RunOutcome::Paused`. Resume
14//! re-enters the loop with the supplied `Command`.
15
16use std::collections::{HashMap, HashSet};
17use std::sync::Arc;
18
19use atomr_agents_core::{AgentError, Result, RunId, Value, WorkflowId};
20use atomr_agents_state::{CheckpointKey, Checkpointer, RunState, Snapshot, StateSchema};
21use parking_lot::Mutex;
22use serde::{Deserialize, Serialize};
23
24use crate::dag::{Dag, StepId};
25use crate::state_runner::StatefulStep;
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub enum Command {
29    /// Resume with no edits / no injected value.
30    Continue,
31    /// Resume; the supplied value is the return value of the
32    /// `interrupt(...)` call inside the paused step.
33    Resume(Value),
34    /// Edit channels then continue.
35    Update(Vec<(String, Value)>),
36    /// Jump to a specific step on the next super-step.
37    Goto(StepId),
38}
39
40#[derive(Debug, Clone)]
41pub enum RunOutcome {
42    /// Run completed normally.
43    Done(RunState),
44    /// Run paused; supply a `Command` to `Interruptible::resume`.
45    Paused {
46        super_step: u64,
47        reason: PauseReason,
48        payload: Option<Value>,
49    },
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub enum PauseReason {
54    DynamicInterrupt { step_id: StepId },
55    Before(StepId),
56    After(StepId),
57}
58
59/// Per-run interrupt control passed to a step. A step calls
60/// `ctrl.interrupt(payload)` to pause; on resume the runner returns
61/// the value from `Command::Resume(...)`.
62#[derive(Clone)]
63pub struct InterruptCtrl {
64    inner: Arc<Mutex<Option<InterruptRequest>>>,
65    /// On resume, the runner pre-populates this with the value from
66    /// `Command::Resume(...)` so a step can read it.
67    resume_value: Arc<Mutex<Option<Value>>>,
68}
69
70#[derive(Clone)]
71struct InterruptRequest {
72    step_id: StepId,
73    payload: Option<Value>,
74}
75
76impl InterruptCtrl {
77    pub fn new() -> Self {
78        Self {
79            inner: Arc::new(Mutex::new(None)),
80            resume_value: Arc::new(Mutex::new(None)),
81        }
82    }
83
84    /// Called by step code to request a pause.
85    pub fn interrupt(&self, step_id: StepId, payload: Option<Value>) {
86        *self.inner.lock() = Some(InterruptRequest { step_id, payload });
87    }
88
89    /// Called by step code on the resume path to read the resume value.
90    pub fn take_resume_value(&self) -> Option<Value> {
91        self.resume_value.lock().take()
92    }
93
94    fn pending(&self) -> Option<InterruptRequest> {
95        self.inner.lock().take()
96    }
97
98    fn set_resume_value(&self, v: Option<Value>) {
99        *self.resume_value.lock() = v;
100    }
101}
102
103impl Default for InterruptCtrl {
104    fn default() -> Self {
105        Self::new()
106    }
107}
108
109/// `StatefulStep` extension that gets the interrupt ctrl as well.
110#[async_trait::async_trait]
111pub trait InterruptibleStep: Send + Sync + 'static {
112    async fn run(&self, state: &RunState, ctrl: &InterruptCtrl) -> Result<Vec<(String, Value)>>;
113}
114
115/// Adapter that turns any `StatefulStep` into an `InterruptibleStep`.
116pub struct PlainStep(pub Arc<dyn StatefulStep>);
117
118#[async_trait::async_trait]
119impl InterruptibleStep for PlainStep {
120    async fn run(&self, state: &RunState, _ctrl: &InterruptCtrl) -> Result<Vec<(String, Value)>> {
121        self.0.run(state).await
122    }
123}
124
125/// Closure-friendly InterruptibleStep.
126pub struct FnInterruptStep<F>(pub F);
127
128#[async_trait::async_trait]
129impl<F, Fut> InterruptibleStep for FnInterruptStep<F>
130where
131    F: Fn(&RunState, &InterruptCtrl) -> Fut + Send + Sync + 'static,
132    Fut: std::future::Future<Output = Result<Vec<(String, Value)>>> + Send + 'static,
133{
134    async fn run(&self, state: &RunState, ctrl: &InterruptCtrl) -> Result<Vec<(String, Value)>> {
135        (self.0)(state, ctrl).await
136    }
137}
138
139pub struct Interruptible {
140    pub workflow_id: WorkflowId,
141    pub run_id: RunId,
142    pub dag: Dag<Arc<dyn InterruptibleStep>>,
143    pub schema: Arc<StateSchema>,
144    pub checkpointer: Arc<dyn Checkpointer>,
145    pub interrupt_before: HashSet<StepId>,
146    pub interrupt_after: HashSet<StepId>,
147}
148
149impl Interruptible {
150    pub async fn run(&self) -> Result<RunOutcome> {
151        let snap = self.checkpointer.latest(&self.workflow_id, &self.run_id).await?;
152        let mut state = match &snap {
153            Some(s) => RunState::from_snapshot(self.schema.clone(), s.values.clone(), s.key.super_step),
154            None => RunState::new(self.schema.clone()),
155        };
156        self.run_inner(&mut state, None, None, false).await
157    }
158
159    /// Resume from the most recent paused checkpoint.
160    pub async fn resume(&self, command: Command) -> Result<RunOutcome> {
161        let snap = self
162            .checkpointer
163            .latest(&self.workflow_id, &self.run_id)
164            .await?
165            .ok_or_else(|| AgentError::Workflow("resume: no checkpoint".into()))?;
166        let (resume_value, edits, goto): (Option<Value>, Vec<(String, Value)>, Option<StepId>) = match command
167        {
168            Command::Continue => (None, Vec::new(), None),
169            Command::Resume(v) => (Some(v), Vec::new(), None),
170            Command::Update(es) => (None, es, None),
171            Command::Goto(s) => (None, Vec::new(), Some(s)),
172        };
173        let mut values = snap.values.clone();
174        for (k, v) in &edits {
175            values.insert(k.clone(), v.clone());
176        }
177        let mut state = RunState::from_snapshot(self.schema.clone(), values, snap.key.super_step);
178        // Resume always disables the next breakpoint hit so paused
179        // breakpoints don't re-fire immediately. Dynamic interrupts
180        // are similarly cleared by the snapshot label distinguishing
181        // them.
182        self.run_inner(&mut state, resume_value, goto, true).await
183    }
184
185    async fn run_inner(
186        &self,
187        state: &mut RunState,
188        resume_value: Option<Value>,
189        goto: Option<StepId>,
190        mut skip_breakpoints_once: bool,
191    ) -> Result<RunOutcome> {
192        let order = self.dag.topo_sort()?;
193        let layers = layered(&self.dag, &order);
194        let resume_at = state.super_step();
195        let ctrl = InterruptCtrl::new();
196        let mut resume_value = resume_value;
197
198        // Optionally jump to a layer containing `goto`.
199        let goto_layer = goto.as_ref().and_then(|sid| {
200            layers
201                .iter()
202                .position(|layer| layer.contains(sid))
203                .map(|p| p as u64)
204        });
205        let start_layer = goto_layer.unwrap_or(resume_at);
206
207        for (layer_idx, layer) in layers.iter().enumerate() {
208            let super_step = layer_idx as u64 + 1;
209            if super_step <= start_layer {
210                continue;
211            }
212            // interrupt_before
213            for sid in layer {
214                if self.interrupt_before.contains(sid) {
215                    if skip_breakpoints_once {
216                        skip_breakpoints_once = false;
217                        continue;
218                    }
219                    self.persist_pause(
220                        state,
221                        super_step.saturating_sub(1),
222                        PauseReason::Before(sid.clone()),
223                        None,
224                    )
225                    .await?;
226                    return Ok(RunOutcome::Paused {
227                        super_step,
228                        reason: PauseReason::Before(sid.clone()),
229                        payload: None,
230                    });
231                }
232            }
233            // Run all steps in the layer (sequential here so the
234            // `ctrl.interrupt` semantics are unambiguous; full parallel
235            // dispatch lands in R7).
236            let mut all_writes: Vec<(String, Value)> = Vec::new();
237            for sid in layer {
238                if let Some(rv) = resume_value.take() {
239                    ctrl.set_resume_value(Some(rv));
240                }
241                let step = self
242                    .dag
243                    .steps
244                    .get(sid)
245                    .ok_or_else(|| AgentError::Workflow(format!("missing step {}", sid.as_str())))?;
246                let writes = step.run(state, &ctrl).await?;
247                if let Some(req) = ctrl.pending() {
248                    self.persist_pause(
249                        state,
250                        super_step.saturating_sub(1),
251                        PauseReason::DynamicInterrupt {
252                            step_id: req.step_id.clone(),
253                        },
254                        req.payload.clone(),
255                    )
256                    .await?;
257                    return Ok(RunOutcome::Paused {
258                        super_step,
259                        reason: PauseReason::DynamicInterrupt {
260                            step_id: req.step_id.clone(),
261                        },
262                        payload: req.payload,
263                    });
264                }
265                all_writes.extend(writes);
266            }
267            state.merge_writes(all_writes)?;
268            state.advance();
269            self.checkpointer
270                .save(Snapshot {
271                    key: CheckpointKey {
272                        workflow_id: self.workflow_id.clone(),
273                        run_id: self.run_id.clone(),
274                        super_step,
275                    },
276                    values: state.snapshot(),
277                    label: format!("layer:{super_step}"),
278                    timestamp_ms: now_ms(),
279                })
280                .await?;
281            // interrupt_after
282            for sid in layer {
283                if self.interrupt_after.contains(sid) {
284                    if skip_breakpoints_once {
285                        skip_breakpoints_once = false;
286                        continue;
287                    }
288                    return Ok(RunOutcome::Paused {
289                        super_step,
290                        reason: PauseReason::After(sid.clone()),
291                        payload: None,
292                    });
293                }
294            }
295        }
296        Ok(RunOutcome::Done(state.clone()))
297    }
298
299    async fn persist_pause(
300        &self,
301        state: &RunState,
302        super_step: u64,
303        reason: PauseReason,
304        payload: Option<Value>,
305    ) -> Result<()> {
306        let label = match &reason {
307            PauseReason::DynamicInterrupt { step_id } => {
308                format!("interrupt:{}", step_id.as_str())
309            }
310            PauseReason::Before(s) => format!("before:{}", s.as_str()),
311            PauseReason::After(s) => format!("after:{}", s.as_str()),
312        };
313        let mut values = state.snapshot();
314        if let Some(p) = payload {
315            values.insert("__interrupt_payload__".into(), p);
316        }
317        self.checkpointer
318            .save(Snapshot {
319                key: CheckpointKey {
320                    workflow_id: self.workflow_id.clone(),
321                    run_id: self.run_id.clone(),
322                    super_step,
323                },
324                values,
325                label,
326                timestamp_ms: now_ms(),
327            })
328            .await
329    }
330}
331
332fn layered(dag: &Dag<Arc<dyn InterruptibleStep>>, order: &[StepId]) -> Vec<Vec<StepId>> {
333    let mut depth: HashMap<StepId, usize> = HashMap::new();
334    for s in order {
335        depth.insert(s.clone(), 0);
336    }
337    for s in order {
338        if let Some(succs) = dag.edges.get(s) {
339            let cur = depth[s];
340            for n in succs {
341                let next = (cur + 1).max(*depth.get(n).unwrap_or(&0));
342                depth.insert(n.clone(), next);
343            }
344        }
345    }
346    let max_d = depth.values().copied().max().unwrap_or(0);
347    let mut layers: Vec<Vec<StepId>> = vec![Vec::new(); max_d + 1];
348    for s in order {
349        layers[depth[s]].push(s.clone());
350    }
351    layers
352}
353
/// Milliseconds elapsed since the Unix epoch according to the system
/// clock, or `0` if the clock reports a time before the epoch.
fn now_ms() -> i64 {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as i64,
        Err(_) => 0,
    }
}
361
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_agents_state::{InMemoryCheckpointer, LastWriteWins, MergeMap, StateSchema};
    use serde_json::json;

    /// Schema shared by all tests: two last-write-wins channels and one
    /// merge-map channel.
    fn schema() -> Arc<StateSchema> {
        let built = StateSchema::builder()
            .add("approved", LastWriteWins)
            .add("amount", LastWriteWins)
            .add("config", MergeMap)
            .build();
        Arc::new(built)
    }

    /// Builds a runner over `dag` with a fresh in-memory checkpointer,
    /// the given before-breakpoints and no after-breakpoints.
    fn runner(dag: Dag<Arc<dyn InterruptibleStep>>, before: HashSet<StepId>) -> Interruptible {
        let checkpointer: Arc<dyn Checkpointer> = Arc::new(InMemoryCheckpointer::new());
        Interruptible {
            workflow_id: WorkflowId::from("wf"),
            run_id: RunId::from("r"),
            dag,
            schema: schema(),
            checkpointer,
            interrupt_before: before,
            interrupt_after: HashSet::new(),
        }
    }

    /// Step "a" asks for approval via a dynamic interrupt; once resumed
    /// with a value, step "b" derives the amount from it.
    #[tokio::test]
    async fn dynamic_interrupt_pauses_then_resume_with_value() {
        let dag: Dag<Arc<dyn InterruptibleStep>> = Dag::builder("a")
            .step(
                "a",
                Arc::new(FnInterruptStep(|_state: &RunState, ctrl: &InterruptCtrl| {
                    let ctrl = ctrl.clone();
                    async move {
                        match ctrl.take_resume_value() {
                            // Second pass: the caller answered.
                            Some(approval) => Ok(vec![("approved".into(), approval)]),
                            // First pass: ask and yield no writes.
                            None => {
                                ctrl.interrupt(StepId::new("a"), Some(json!({"q": "approve?"})));
                                Ok(vec![])
                            }
                        }
                    }
                })) as Arc<dyn InterruptibleStep>,
            )
            .step(
                "b",
                Arc::new(FnInterruptStep(|state: &RunState, _ctrl: &InterruptCtrl| {
                    let approved = state.read("approved").as_bool().unwrap_or(false);
                    async move { Ok(vec![("amount".into(), json!(if approved { 100 } else { 0 }))]) }
                })) as Arc<dyn InterruptibleStep>,
            )
            .edge("a", "b")
            .build();
        let r = runner(dag, HashSet::new());

        let first = r.run().await.unwrap();
        if let RunOutcome::Paused { reason, payload, .. } = first {
            assert!(matches!(reason, PauseReason::DynamicInterrupt { .. }));
            assert_eq!(payload.unwrap()["q"], "approve?");
        } else {
            panic!("expected pause");
        }

        let second = r.resume(Command::Resume(json!(true))).await.unwrap();
        if let RunOutcome::Done(state) = second {
            assert_eq!(state.read("approved"), &json!(true));
            assert_eq!(state.read("amount"), &json!(100));
        } else {
            panic!("expected done");
        }
    }

    /// A before-breakpoint on "b" pauses the run after "a"; a plain
    /// `Continue` then finishes it.
    #[tokio::test]
    async fn static_interrupt_before_pauses() {
        fn config_step() -> Arc<dyn InterruptibleStep> {
            Arc::new(FnInterruptStep(|_s: &RunState, _c: &InterruptCtrl| async {
                Ok(vec![("config".into(), json!({"x": 1}))])
            }))
        }
        let dag: Dag<Arc<dyn InterruptibleStep>> = Dag::builder("a")
            .step("a", config_step())
            .step("b", config_step())
            .edge("a", "b")
            .build();
        let mut before = HashSet::new();
        before.insert(StepId::new("b"));
        let r = runner(dag, before);

        let first = r.run().await.unwrap();
        if let RunOutcome::Paused { reason, .. } = first {
            assert!(matches!(reason, PauseReason::Before(s) if s.as_str() == "b"));
        } else {
            panic!("expected pause before b");
        }

        let second = r.resume(Command::Continue).await.unwrap();
        assert!(matches!(second, RunOutcome::Done(_)));
    }

    /// `Command::Update` edits are visible to the step that runs after
    /// the resume.
    #[tokio::test]
    async fn update_command_edits_state_at_resume() {
        let dag: Dag<Arc<dyn InterruptibleStep>> = Dag::builder("only")
            .step(
                "only",
                Arc::new(FnInterruptStep(|state: &RunState, _c: &InterruptCtrl| {
                    let v = state.read("config").clone();
                    async move { Ok(vec![("config".into(), v)]) }
                })) as Arc<dyn InterruptibleStep>,
            )
            .build();
        let mut before = HashSet::new();
        before.insert(StepId::new("only"));
        let r = runner(dag, before);

        let _ = r.run().await.unwrap();
        let edit = vec![("config".into(), json!({"injected": true}))];
        let outcome = r.resume(Command::Update(edit)).await.unwrap();
        if let RunOutcome::Done(state) = outcome {
            assert_eq!(state.read("config")["injected"], true);
        } else {
            panic!("expected done");
        }
    }
}