1use std::collections::{HashMap, HashSet};
17use std::sync::Arc;
18
19use atomr_agents_core::{AgentError, Result, RunId, Value, WorkflowId};
20use atomr_agents_state::{CheckpointKey, Checkpointer, RunState, Snapshot, StateSchema};
21use parking_lot::Mutex;
22use serde::{Deserialize, Serialize};
23
24use crate::dag::{Dag, StepId};
25use crate::state_runner::StatefulStep;
26
/// Instruction passed to [`Interruptible::resume`] describing how to continue a paused run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Command {
    /// Continue from the latest checkpoint without supplying anything.
    Continue,
    /// Continue and make this value available to steps through
    /// [`InterruptCtrl::take_resume_value`].
    Resume(Value),
    /// Continue after writing these `(key, value)` pairs directly into the restored
    /// checkpoint values (inserted as-is, not folded through the schema's reducers).
    Update(Vec<(String, Value)>),
    /// Continue execution at the DAG layer that contains this step.
    Goto(StepId),
}
39
/// Result of driving an [`Interruptible`] run.
#[derive(Debug, Clone)]
pub enum RunOutcome {
    /// Every layer ran; carries the final state.
    Done(RunState),
    /// The run stopped at a breakpoint or a dynamic interrupt.
    Paused {
        /// 1-based super-step of the layer at which the pause occurred: the layer about
        /// to run (before) or running (dynamic interrupt), or the layer just finished
        /// (after).
        super_step: u64,
        reason: PauseReason,
        /// Payload handed to [`InterruptCtrl::interrupt`]; `None` for static breakpoints.
        payload: Option<Value>,
    },
}
51
/// Why a run paused.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PauseReason {
    /// A step called [`InterruptCtrl::interrupt`]; `step_id` is the id it passed.
    DynamicInterrupt { step_id: StepId },
    /// The step is listed in `interrupt_before`; its layer has not run yet.
    Before(StepId),
    /// The step is listed in `interrupt_after`; its layer has run and been checkpointed.
    After(StepId),
}
58
59#[derive(Clone)]
63pub struct InterruptCtrl {
64 inner: Arc<Mutex<Option<InterruptRequest>>>,
65 resume_value: Arc<Mutex<Option<Value>>>,
68}
69
70#[derive(Clone)]
71struct InterruptRequest {
72 step_id: StepId,
73 payload: Option<Value>,
74}
75
76impl InterruptCtrl {
77 pub fn new() -> Self {
78 Self {
79 inner: Arc::new(Mutex::new(None)),
80 resume_value: Arc::new(Mutex::new(None)),
81 }
82 }
83
84 pub fn interrupt(&self, step_id: StepId, payload: Option<Value>) {
86 *self.inner.lock() = Some(InterruptRequest { step_id, payload });
87 }
88
89 pub fn take_resume_value(&self) -> Option<Value> {
91 self.resume_value.lock().take()
92 }
93
94 fn pending(&self) -> Option<InterruptRequest> {
95 self.inner.lock().take()
96 }
97
98 fn set_resume_value(&self, v: Option<Value>) {
99 *self.resume_value.lock() = v;
100 }
101}
102
103impl Default for InterruptCtrl {
104 fn default() -> Self {
105 Self::new()
106 }
107}
108
/// A workflow step that can pause the run through its [`InterruptCtrl`].
#[async_trait::async_trait]
pub trait InterruptibleStep: Send + Sync + 'static {
    /// Executes the step against `state` and returns the `(key, value)` writes to merge
    /// into it. If the step calls [`InterruptCtrl::interrupt`] on `ctrl`, the runner
    /// pauses after this call returns and the writes of the step's layer are discarded.
    async fn run(&self, state: &RunState, ctrl: &InterruptCtrl) -> Result<Vec<(String, Value)>>;
}
114
115pub struct PlainStep(pub Arc<dyn StatefulStep>);
117
118#[async_trait::async_trait]
119impl InterruptibleStep for PlainStep {
120 async fn run(&self, state: &RunState, _ctrl: &InterruptCtrl) -> Result<Vec<(String, Value)>> {
121 self.0.run(state).await
122 }
123}
124
125pub struct FnInterruptStep<F>(pub F);
127
128#[async_trait::async_trait]
129impl<F, Fut> InterruptibleStep for FnInterruptStep<F>
130where
131 F: Fn(&RunState, &InterruptCtrl) -> Fut + Send + Sync + 'static,
132 Fut: std::future::Future<Output = Result<Vec<(String, Value)>>> + Send + 'static,
133{
134 async fn run(&self, state: &RunState, ctrl: &InterruptCtrl) -> Result<Vec<(String, Value)>> {
135 (self.0)(state, ctrl).await
136 }
137}
138
/// DAG runner that executes steps layer by layer (one checkpointed super-step per
/// layer) and can pause before/after configured steps, or whenever a step requests an
/// interrupt through its [`InterruptCtrl`].
pub struct Interruptible {
    /// Workflow id used in every checkpoint key.
    pub workflow_id: WorkflowId,
    /// Run id used in every checkpoint key.
    pub run_id: RunId,
    /// Steps and edges to execute.
    pub dag: Dag<Arc<dyn InterruptibleStep>>,
    /// Schema used to build and restore the `RunState`.
    pub schema: Arc<StateSchema>,
    /// Store for per-layer checkpoints and pause checkpoints.
    pub checkpointer: Arc<dyn Checkpointer>,
    /// Steps whose layer pauses before it executes.
    pub interrupt_before: HashSet<StepId>,
    /// Steps whose layer pauses after it has executed and been checkpointed.
    pub interrupt_after: HashSet<StepId>,
}
148
149impl Interruptible {
150 pub async fn run(&self) -> Result<RunOutcome> {
151 let snap = self.checkpointer.latest(&self.workflow_id, &self.run_id).await?;
152 let mut state = match &snap {
153 Some(s) => RunState::from_snapshot(self.schema.clone(), s.values.clone(), s.key.super_step),
154 None => RunState::new(self.schema.clone()),
155 };
156 self.run_inner(&mut state, None, None, false).await
157 }
158
159 pub async fn resume(&self, command: Command) -> Result<RunOutcome> {
161 let snap = self
162 .checkpointer
163 .latest(&self.workflow_id, &self.run_id)
164 .await?
165 .ok_or_else(|| AgentError::Workflow("resume: no checkpoint".into()))?;
166 let (resume_value, edits, goto): (Option<Value>, Vec<(String, Value)>, Option<StepId>) = match command
167 {
168 Command::Continue => (None, Vec::new(), None),
169 Command::Resume(v) => (Some(v), Vec::new(), None),
170 Command::Update(es) => (None, es, None),
171 Command::Goto(s) => (None, Vec::new(), Some(s)),
172 };
173 let mut values = snap.values.clone();
174 for (k, v) in &edits {
175 values.insert(k.clone(), v.clone());
176 }
177 let mut state = RunState::from_snapshot(self.schema.clone(), values, snap.key.super_step);
178 self.run_inner(&mut state, resume_value, goto, true).await
183 }
184
185 async fn run_inner(
186 &self,
187 state: &mut RunState,
188 resume_value: Option<Value>,
189 goto: Option<StepId>,
190 mut skip_breakpoints_once: bool,
191 ) -> Result<RunOutcome> {
192 let order = self.dag.topo_sort()?;
193 let layers = layered(&self.dag, &order);
194 let resume_at = state.super_step();
195 let ctrl = InterruptCtrl::new();
196 let mut resume_value = resume_value;
197
198 let goto_layer = goto.as_ref().and_then(|sid| {
200 layers
201 .iter()
202 .position(|layer| layer.contains(sid))
203 .map(|p| p as u64)
204 });
205 let start_layer = goto_layer.unwrap_or(resume_at);
206
207 for (layer_idx, layer) in layers.iter().enumerate() {
208 let super_step = layer_idx as u64 + 1;
209 if super_step <= start_layer {
210 continue;
211 }
212 for sid in layer {
214 if self.interrupt_before.contains(sid) {
215 if skip_breakpoints_once {
216 skip_breakpoints_once = false;
217 continue;
218 }
219 self.persist_pause(
220 state,
221 super_step.saturating_sub(1),
222 PauseReason::Before(sid.clone()),
223 None,
224 )
225 .await?;
226 return Ok(RunOutcome::Paused {
227 super_step,
228 reason: PauseReason::Before(sid.clone()),
229 payload: None,
230 });
231 }
232 }
233 let mut all_writes: Vec<(String, Value)> = Vec::new();
237 for sid in layer {
238 if let Some(rv) = resume_value.take() {
239 ctrl.set_resume_value(Some(rv));
240 }
241 let step = self
242 .dag
243 .steps
244 .get(sid)
245 .ok_or_else(|| AgentError::Workflow(format!("missing step {}", sid.as_str())))?;
246 let writes = step.run(state, &ctrl).await?;
247 if let Some(req) = ctrl.pending() {
248 self.persist_pause(
249 state,
250 super_step.saturating_sub(1),
251 PauseReason::DynamicInterrupt {
252 step_id: req.step_id.clone(),
253 },
254 req.payload.clone(),
255 )
256 .await?;
257 return Ok(RunOutcome::Paused {
258 super_step,
259 reason: PauseReason::DynamicInterrupt {
260 step_id: req.step_id.clone(),
261 },
262 payload: req.payload,
263 });
264 }
265 all_writes.extend(writes);
266 }
267 state.merge_writes(all_writes)?;
268 state.advance();
269 self.checkpointer
270 .save(Snapshot {
271 key: CheckpointKey {
272 workflow_id: self.workflow_id.clone(),
273 run_id: self.run_id.clone(),
274 super_step,
275 },
276 values: state.snapshot(),
277 label: format!("layer:{super_step}"),
278 timestamp_ms: now_ms(),
279 })
280 .await?;
281 for sid in layer {
283 if self.interrupt_after.contains(sid) {
284 if skip_breakpoints_once {
285 skip_breakpoints_once = false;
286 continue;
287 }
288 return Ok(RunOutcome::Paused {
289 super_step,
290 reason: PauseReason::After(sid.clone()),
291 payload: None,
292 });
293 }
294 }
295 }
296 Ok(RunOutcome::Done(state.clone()))
297 }
298
299 async fn persist_pause(
300 &self,
301 state: &RunState,
302 super_step: u64,
303 reason: PauseReason,
304 payload: Option<Value>,
305 ) -> Result<()> {
306 let label = match &reason {
307 PauseReason::DynamicInterrupt { step_id } => {
308 format!("interrupt:{}", step_id.as_str())
309 }
310 PauseReason::Before(s) => format!("before:{}", s.as_str()),
311 PauseReason::After(s) => format!("after:{}", s.as_str()),
312 };
313 let mut values = state.snapshot();
314 if let Some(p) = payload {
315 values.insert("__interrupt_payload__".into(), p);
316 }
317 self.checkpointer
318 .save(Snapshot {
319 key: CheckpointKey {
320 workflow_id: self.workflow_id.clone(),
321 run_id: self.run_id.clone(),
322 super_step,
323 },
324 values,
325 label,
326 timestamp_ms: now_ms(),
327 })
328 .await
329 }
330}
331
332fn layered(dag: &Dag<Arc<dyn InterruptibleStep>>, order: &[StepId]) -> Vec<Vec<StepId>> {
333 let mut depth: HashMap<StepId, usize> = HashMap::new();
334 for s in order {
335 depth.insert(s.clone(), 0);
336 }
337 for s in order {
338 if let Some(succs) = dag.edges.get(s) {
339 let cur = depth[s];
340 for n in succs {
341 let next = (cur + 1).max(*depth.get(n).unwrap_or(&0));
342 depth.insert(n.clone(), next);
343 }
344 }
345 }
346 let max_d = depth.values().copied().max().unwrap_or(0);
347 let mut layers: Vec<Vec<StepId>> = vec![Vec::new(); max_d + 1];
348 for s in order {
349 layers[depth[s]].push(s.clone());
350 }
351 layers
352}
353
/// Wall-clock time in milliseconds since the Unix epoch, used to timestamp checkpoints.
///
/// Returns `0` if the system clock reads earlier than the epoch. The `u128`
/// millisecond count saturates at `i64::MAX` instead of silently wrapping.
fn now_ms() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        // Clamp first so the narrowing cast below is lossless.
        .map(|d| d.as_millis().min(i64::MAX as u128) as i64)
        .unwrap_or(0)
}
361
// Tests for the interruptible runner: a dynamic interrupt answered with a resume
// value, a static `interrupt_before` breakpoint, and state edits applied at resume.
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_agents_state::{InMemoryCheckpointer, LastWriteWins, MergeMap, StateSchema};
    use serde_json::json;

    // Schema: `approved`/`amount` use last-write-wins, `config` uses the map-merging reducer.
    fn schema() -> Arc<StateSchema> {
        Arc::new(
            StateSchema::builder()
                .add("approved", LastWriteWins)
                .add("amount", LastWriteWins)
                .add("config", MergeMap)
                .build(),
        )
    }

    // Step `a` interrupts until it receives a resume value; `b` depends on a's write.
    #[tokio::test]
    async fn dynamic_interrupt_pauses_then_resume_with_value() {
        let dag: Dag<Arc<dyn InterruptibleStep>> = Dag::builder("a")
            .step(
                "a",
                Arc::new(FnInterruptStep(|_state: &RunState, ctrl: &InterruptCtrl| {
                    // The future must be 'static, so it owns a clone of the controller.
                    let ctrl = ctrl.clone();
                    async move {
                        if let Some(approval) = ctrl.take_resume_value() {
                            return Ok(vec![("approved".into(), approval)]);
                            // Resumed: record the human's answer and finish normally.
                        }
                        ctrl.interrupt(StepId::new("a"), Some(json!({"q": "approve?"})));
                        Ok(vec![])
                    }
                })) as Arc<dyn InterruptibleStep>,
            )
            .step(
                "b",
                Arc::new(FnInterruptStep(|state: &RunState, _ctrl: &InterruptCtrl| {
                    let approved = state.read("approved").as_bool().unwrap_or(false);
                    async move { Ok(vec![("amount".into(), json!(if approved { 100 } else { 0 }))]) }
                })) as Arc<dyn InterruptibleStep>,
            )
            .edge("a", "b")
            .build();
        let cpt: Arc<dyn Checkpointer> = Arc::new(InMemoryCheckpointer::new());
        let r = Interruptible {
            workflow_id: WorkflowId::from("wf"),
            run_id: RunId::from("r"),
            dag,
            schema: schema(),
            checkpointer: cpt.clone(),
            interrupt_before: HashSet::new(),
            interrupt_after: HashSet::new(),
        };
        // First run stops inside `a` and surfaces the interrupt payload.
        let out = r.run().await.unwrap();
        match out {
            RunOutcome::Paused { reason, payload, .. } => {
                assert!(matches!(reason, PauseReason::DynamicInterrupt { .. }));
                assert_eq!(payload.unwrap()["q"], "approve?");
            }
            _ => panic!("expected pause"),
        }
        // Resuming with `true` lets `a` write `approved`, which `b` then reads.
        let resumed = r.resume(Command::Resume(json!(true))).await.unwrap();
        match resumed {
            RunOutcome::Done(state) => {
                assert_eq!(state.read("approved"), &json!(true));
                assert_eq!(state.read("amount"), &json!(100));
            }
            _ => panic!("expected done"),
        }
    }

    // A breakpoint before `b` pauses the run; `Continue` then runs it to completion.
    #[tokio::test]
    async fn static_interrupt_before_pauses() {
        let mk_step = || -> Arc<dyn InterruptibleStep> {
            Arc::new(FnInterruptStep(|_s: &RunState, _c: &InterruptCtrl| async {
                Ok(vec![("config".into(), json!({"x": 1}))])
            }))
        };
        let dag: Dag<Arc<dyn InterruptibleStep>> = Dag::builder("a")
            .step("a", mk_step())
            .step("b", mk_step())
            .edge("a", "b")
            .build();
        let cpt: Arc<dyn Checkpointer> = Arc::new(InMemoryCheckpointer::new());
        let mut before = HashSet::new();
        before.insert(StepId::new("b"));
        let r = Interruptible {
            workflow_id: WorkflowId::from("wf"),
            run_id: RunId::from("r"),
            dag,
            schema: schema(),
            checkpointer: cpt.clone(),
            interrupt_before: before,
            interrupt_after: HashSet::new(),
        };
        let out = r.run().await.unwrap();
        match out {
            RunOutcome::Paused { reason, .. } => {
                assert!(matches!(reason, PauseReason::Before(s) if s.as_str() == "b"));
            }
            _ => panic!("expected pause before b"),
        }
        let done = r.resume(Command::Continue).await.unwrap();
        assert!(matches!(done, RunOutcome::Done(_)));
    }

    // `Command::Update` injects a value into the restored state before the paused step
    // runs; the step echoes `config` back, so the injected key must survive.
    #[tokio::test]
    async fn update_command_edits_state_at_resume() {
        let dag: Dag<Arc<dyn InterruptibleStep>> = Dag::builder("only")
            .step(
                "only",
                Arc::new(FnInterruptStep(|state: &RunState, _c: &InterruptCtrl| {
                    let v = state.read("config").clone();
                    async move { Ok(vec![("config".into(), v)]) }
                })) as Arc<dyn InterruptibleStep>,
            )
            .build();
        let cpt: Arc<dyn Checkpointer> = Arc::new(InMemoryCheckpointer::new());
        let mut before = HashSet::new();
        before.insert(StepId::new("only"));
        let r = Interruptible {
            workflow_id: WorkflowId::from("wf"),
            run_id: RunId::from("r"),
            dag,
            schema: schema(),
            checkpointer: cpt,
            interrupt_before: before,
            interrupt_after: HashSet::new(),
        };
        // Pauses before `only`; the outcome itself is not under test here.
        let _ = r.run().await.unwrap();
        let done = r
            .resume(Command::Update(vec![(
                "config".into(),
                json!({"injected": true}),
            )]))
            .await
            .unwrap();
        match done {
            RunOutcome::Done(state) => {
                assert_eq!(state.read("config")["injected"], true);
            }
            _ => panic!("expected done"),
        }
    }
}
505}