1use crate::{Ctx, Outcome, StepError, Workflow};
2use std::time::{Duration, Instant};
3
4pub struct StepEvent<'a> {
6 pub agent: &'a str,
8 pub outcome: &'a Outcome,
10 pub duration: Duration,
12 pub step_number: usize,
14 pub retries: usize,
16}
17
18pub struct ErrorEvent<'a> {
20 pub agent: &'a str,
22 pub error: &'a StepError,
24 pub step_number: usize,
26}
27
28type StepHook = Box<dyn FnMut(&StepEvent)>;
29type ErrorHook = Box<dyn FnMut(&ErrorEvent)>;
30
31pub struct Runner<S: Clone + 'static> {
33 wf: Workflow<S>,
34 max_steps: usize,
35 max_retries: usize,
36 on_step: Option<StepHook>,
37 on_error: Option<ErrorHook>,
38}
39
40impl<S: Clone + 'static> Runner<S> {
41 pub fn new(wf: Workflow<S>) -> Self {
44 Self {
45 wf,
46 max_steps: 10_000,
47 max_retries: 3,
48 on_step: None,
49 on_error: None,
50 }
51 }
52
53 pub fn with_max_steps(mut self, max_steps: usize) -> Self {
55 self.max_steps = max_steps;
56 self
57 }
58
59 pub fn with_max_retries(mut self, max_retries: usize) -> Self {
61 self.max_retries = max_retries;
62 self
63 }
64
65 pub fn on_step(mut self, cb: impl FnMut(&StepEvent) + 'static) -> Self {
67 self.on_step = Some(Box::new(cb));
68 self
69 }
70
71 pub fn on_error(mut self, cb: impl FnMut(&ErrorEvent) + 'static) -> Self {
73 self.on_error = Some(Box::new(cb));
74 self
75 }
76
77 pub fn with_tracing(self) -> Self {
79 self.on_step(|e| {
80 eprintln!(
81 "[step {}] {} -> {:?} ({:.3}s)",
82 e.step_number,
83 e.agent,
84 e.outcome,
85 e.duration.as_secs_f64()
86 );
87 })
88 .on_error(|e| {
89 eprintln!("[error] {} at step {}: {}", e.agent, e.step_number, e.error);
90 })
91 }
92
93 pub fn run(&mut self, mut state: S, ctx: &mut Ctx) -> Result<S, StepError> {
96 let mut current = self.wf.start();
97 let mut retries: usize = 0;
98 let mut step_number: usize = 0;
99
100 for _ in 0..self.max_steps {
101 step_number += 1;
102
103 let agent = self
104 .wf
105 .agent_mut(current)
106 .ok_or_else(|| StepError::other(format!("unknown step: {current}")))?;
107
108 let start = Instant::now();
109 let result = agent.run(state.clone(), ctx);
110 let duration = start.elapsed();
111
112 match result {
113 Err(err) => {
114 if let Some(cb) = &mut self.on_error {
115 cb(&ErrorEvent {
116 agent: current,
117 error: &err,
118 step_number,
119 });
120 }
121 return Err(err);
122 }
123 Ok((next_state, outcome)) => {
124 if let Some(cb) = &mut self.on_step {
125 cb(&StepEvent {
126 agent: current,
127 outcome: &outcome,
128 duration,
129 step_number,
130 retries,
131 });
132 }
133
134 state = next_state;
135
136 match outcome {
137 Outcome::Done => return Ok(state),
138 Outcome::Fail(msg) => return Err(StepError::other(msg)),
139 Outcome::Next(step) => {
140 current = step;
141 retries = 0;
142 continue;
143 }
144 Outcome::Continue => {
145 if let Some(next) = self.wf.default_next(current) {
146 current = next;
147 retries = 0;
148 continue;
149 }
150 return Err(StepError::other(format!(
151 "step '{current}' returned Continue but no default next step is configured"
152 )));
153 }
154 Outcome::Retry(hint) => {
155 retries += 1;
156 if retries > self.max_retries {
157 let err = StepError::other(format!(
158 "step '{}' exceeded max retries ({}): {}",
159 current, self.max_retries, hint.reason
160 ));
161 if let Some(cb) = &mut self.on_error {
162 cb(&ErrorEvent {
163 agent: current,
164 error: &err,
165 step_number,
166 });
167 }
168 return Err(err);
169 }
170 continue;
171 }
172 Outcome::Wait(dur) => {
173 retries += 1;
174 if retries > self.max_retries {
175 let err = StepError::other(format!(
176 "step '{}' exceeded max retries ({}) while waiting",
177 current, self.max_retries
178 ));
179 if let Some(cb) = &mut self.on_error {
180 cb(&ErrorEvent {
181 agent: current,
182 error: &err,
183 step_number,
184 });
185 }
186 return Err(err);
187 }
188 std::thread::sleep(dur);
189 continue;
190 }
191 }
192 }
193 }
194 }
195
196 let err = StepError::other(format!(
197 "max_steps exceeded (possible infinite loop) in workflow {}",
198 self.wf.name()
199 ));
200 if let Some(cb) = &mut self.on_error {
201 cb(&ErrorEvent {
202 agent: current,
203 error: &err,
204 step_number,
205 });
206 }
207 Err(err)
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Agent, Outcome, RetryHint, StepResult, Workflow};
    use std::sync::{Arc, Mutex};
    use std::time::Duration;

    // ---------------------------------------------------------------------
    // Fixtures
    // ---------------------------------------------------------------------

    /// Minimal workflow state used by every test.
    #[derive(Clone)]
    struct S(u32);

    /// Returns two handles to the same mutex-protected value: one to keep in
    /// the test body and one to move into a hook closure.
    fn shared<T>(initial: T) -> (Arc<Mutex<T>>, Arc<Mutex<T>>) {
        let owner = Arc::new(Mutex::new(initial));
        let handle = Arc::clone(&owner);
        (owner, handle)
    }

    /// Asks for a retry until it has been run `succeed_on` times, then finishes.
    struct RetryAgent {
        attempts: u32,
        succeed_on: u32,
    }

    impl Agent<S> for RetryAgent {
        fn name(&self) -> &'static str {
            "retry_agent"
        }

        fn run(&mut self, state: S, _ctx: &mut Ctx) -> StepResult<S> {
            self.attempts += 1;
            let outcome = if self.attempts >= self.succeed_on {
                Outcome::Done
            } else {
                Outcome::Retry(RetryHint::new("not ready"))
            };
            Ok((state, outcome))
        }
    }

    /// Asks for a retry on every execution.
    struct AlwaysRetry;

    impl Agent<S> for AlwaysRetry {
        fn name(&self) -> &'static str {
            "always_retry"
        }

        fn run(&mut self, state: S, _ctx: &mut Ctx) -> StepResult<S> {
            Ok((state, Outcome::Retry(RetryHint::new("never ready"))))
        }
    }

    /// Requests a 1 ms wait on its first execution and finishes on the next.
    struct WaitOnce {
        waited: bool,
    }

    impl Agent<S> for WaitOnce {
        fn name(&self) -> &'static str {
            "wait_once"
        }

        fn run(&mut self, state: S, _ctx: &mut Ctx) -> StepResult<S> {
            let outcome = if self.waited {
                Outcome::Done
            } else {
                self.waited = true;
                Outcome::Wait(Duration::from_millis(1))
            };
            Ok((state, outcome))
        }
    }

    /// Requests a 1 ms wait on every execution.
    struct AlwaysWait;

    impl Agent<S> for AlwaysWait {
        fn name(&self) -> &'static str {
            "always_wait"
        }

        fn run(&mut self, state: S, _ctx: &mut Ctx) -> StepResult<S> {
            Ok((state, Outcome::Wait(Duration::from_millis(1))))
        }
    }

    /// Finishes the workflow immediately.
    struct DoneAgent;

    impl Agent<S> for DoneAgent {
        fn name(&self) -> &'static str {
            "done_agent"
        }

        fn run(&mut self, state: S, _ctx: &mut Ctx) -> StepResult<S> {
            Ok((state, Outcome::Done))
        }
    }

    /// Returns a hard error from `run`.
    struct FailingAgent;

    impl Agent<S> for FailingAgent {
        fn name(&self) -> &'static str {
            "failing_agent"
        }

        fn run(&mut self, _state: S, _ctx: &mut Ctx) -> StepResult<S> {
            Err(StepError::transient("boom"))
        }
    }

    /// Succeeds, but reports `Outcome::Fail`.
    struct FailOutcomeAgent;

    impl Agent<S> for FailOutcomeAgent {
        fn name(&self) -> &'static str {
            "fail_outcome"
        }

        fn run(&mut self, state: S, _ctx: &mut Ctx) -> StepResult<S> {
            Ok((state, Outcome::Fail("reason".into())))
        }
    }

    /// Always defers to the workflow's default successor.
    struct AlwaysContinue;

    impl Agent<S> for AlwaysContinue {
        fn name(&self) -> &'static str {
            "always_continue"
        }

        fn run(&mut self, state: S, _ctx: &mut Ctx) -> StepResult<S> {
            Ok((state, Outcome::Continue))
        }
    }

    /// Increments the state and jumps explicitly to `done_agent`.
    struct NextAgent;

    impl Agent<S> for NextAgent {
        fn name(&self) -> &'static str {
            "next_agent"
        }

        fn run(&mut self, state: S, _ctx: &mut Ctx) -> StepResult<S> {
            let S(value) = state;
            Ok((S(value + 1), Outcome::Next("done_agent")))
        }
    }

    /// Asks for one retry, then defers to the default successor.
    struct RetryOnceThenContinue {
        attempts: u32,
    }

    impl Agent<S> for RetryOnceThenContinue {
        fn name(&self) -> &'static str {
            "retry_once_then_continue"
        }

        fn run(&mut self, state: S, _ctx: &mut Ctx) -> StepResult<S> {
            self.attempts += 1;
            let outcome = if self.attempts < 2 {
                Outcome::Retry(RetryHint::new("not yet"))
            } else {
                Outcome::Continue
            };
            Ok((state, outcome))
        }
    }

    // ---------------------------------------------------------------------
    // Retry / wait handling
    // ---------------------------------------------------------------------

    /// Two retries are within the default retry limit of three.
    #[test]
    fn retry_succeeds_within_limit() {
        let agent = RetryAgent {
            attempts: 0,
            succeed_on: 3,
        };
        let wf = Workflow::builder("test").register(agent).build().unwrap();

        let mut ctx = Ctx::new();
        let mut runner = Runner::new(wf);
        assert!(runner.run(S(0), &mut ctx).is_ok());
    }

    /// An agent that never stops retrying trips the retry limit.
    #[test]
    fn retry_exceeds_limit() {
        let wf = Workflow::builder("test")
            .register(AlwaysRetry)
            .build()
            .unwrap();

        let mut ctx = Ctx::new();
        let mut runner = Runner::new(wf).with_max_retries(2);
        let failure = runner.run(S(0), &mut ctx).err().unwrap();
        assert!(failure.to_string().contains("exceeded max retries"));
    }

    /// A `Wait` outcome re-runs the same step after sleeping.
    #[test]
    fn wait_sleeps_and_reruns() {
        let wf = Workflow::builder("test")
            .register(WaitOnce { waited: false })
            .build()
            .unwrap();

        let mut ctx = Ctx::new();
        let mut runner = Runner::new(wf);
        assert!(runner.run(S(0), &mut ctx).is_ok());
    }

    /// `Wait` outcomes count against the retry limit as well.
    #[test]
    fn wait_exceeds_max_retries() {
        let wf = Workflow::builder("test")
            .register(AlwaysWait)
            .build()
            .unwrap();

        let mut ctx = Ctx::new();
        let mut runner = Runner::new(wf).with_max_retries(1);
        let failure = runner.run(S(0), &mut ctx).err().unwrap();
        assert!(failure.to_string().contains("exceeded max retries"));
    }

    // ---------------------------------------------------------------------
    // Hooks
    // ---------------------------------------------------------------------

    /// A single successful step triggers the step hook exactly once.
    #[test]
    fn on_step_fires_on_success() {
        let (hits, sink) = shared(0usize);

        let wf = Workflow::builder("test")
            .register(DoneAgent)
            .build()
            .unwrap();
        let mut runner = Runner::new(wf).on_step(move |_e| {
            *sink.lock().unwrap() += 1;
        });

        let mut ctx = Ctx::new();
        runner.run(S(0), &mut ctx).unwrap();

        assert_eq!(*hits.lock().unwrap(), 1);
    }

    /// An error returned by the agent reaches the error hook once.
    #[test]
    fn on_error_fires_on_agent_error() {
        let (hits, sink) = shared(0usize);

        let wf = Workflow::builder("test")
            .register(FailingAgent)
            .build()
            .unwrap();
        let mut runner = Runner::new(wf).on_error(move |_e| {
            *sink.lock().unwrap() += 1;
        });

        let mut ctx = Ctx::new();
        let _ = runner.run(S(0), &mut ctx);

        assert_eq!(*hits.lock().unwrap(), 1);
    }

    /// Exhausting the retry budget reaches the error hook once.
    #[test]
    fn on_error_fires_on_max_retries() {
        let (hits, sink) = shared(0usize);

        let wf = Workflow::builder("test")
            .register(AlwaysRetry)
            .build()
            .unwrap();
        let mut runner = Runner::new(wf)
            .with_max_retries(1)
            .on_error(move |_e| {
                *sink.lock().unwrap() += 1;
            });

        let mut ctx = Ctx::new();
        let _ = runner.run(S(0), &mut ctx);

        assert_eq!(*hits.lock().unwrap(), 1);
    }

    /// With a budget of one step the run stops after `always_continue`, and
    /// the hook sees the "max_steps exceeded" error.
    #[test]
    fn on_error_fires_on_max_steps() {
        let (hits, sink) = shared(0usize);

        let wf = Workflow::builder("test")
            .register(AlwaysContinue)
            .register(DoneAgent)
            .start_at("always_continue")
            .then("done_agent")
            .build()
            .unwrap();
        let mut runner = Runner::new(wf).with_max_steps(1).on_error(move |e| {
            assert!(e.error.to_string().contains("max_steps exceeded"));
            *sink.lock().unwrap() += 1;
        });

        let mut ctx = Ctx::new();
        let _ = runner.run(S(0), &mut ctx);

        assert_eq!(*hits.lock().unwrap(), 1);
    }

    /// Step numbers grow by one per execution; the retry count reported to the
    /// hook is the number of retries recorded before that execution.
    #[test]
    fn on_step_receives_correct_step_number() {
        let (seen, sink) = shared(Vec::new());

        let agent = RetryAgent {
            attempts: 0,
            succeed_on: 3,
        };
        let wf = Workflow::builder("test").register(agent).build().unwrap();
        let mut runner = Runner::new(wf).on_step(move |e| {
            sink.lock().unwrap().push((e.step_number, e.retries));
        });

        let mut ctx = Ctx::new();
        runner.run(S(0), &mut ctx).unwrap();

        let seen = seen.lock().unwrap();
        assert_eq!(*seen, vec![(1, 0), (2, 1), (3, 2)]);
    }

    /// Moving on to another step clears the retry counter.
    #[test]
    fn retry_counter_resets_on_step_transition() {
        let (log, sink) = shared(Vec::new());

        let wf = Workflow::builder("test")
            .register(RetryOnceThenContinue { attempts: 0 })
            .register(DoneAgent)
            .start_at("retry_once_then_continue")
            .then("done_agent")
            .build()
            .unwrap();
        let mut runner = Runner::new(wf).on_step(move |e| {
            sink.lock().unwrap().push((e.agent.to_string(), e.retries));
        });

        let mut ctx = Ctx::new();
        runner.run(S(0), &mut ctx).unwrap();

        let log = log.lock().unwrap();
        assert_eq!(log.len(), 3);

        let retries_at_done = log
            .iter()
            .find(|(name, _)| name == "done_agent")
            .map(|(_, retries)| *retries);
        assert_eq!(retries_at_done, Some(0));
    }

    // ---------------------------------------------------------------------
    // Outcome routing
    // ---------------------------------------------------------------------

    /// `Outcome::Next` transfers control to the named agent and keeps the
    /// state produced by the jumping step.
    #[test]
    fn next_jumps_to_named_agent() {
        let wf = Workflow::builder("test")
            .register(NextAgent)
            .register(DoneAgent)
            .build()
            .unwrap();

        let mut ctx = Ctx::new();
        let mut runner = Runner::new(wf);
        let S(final_value) = runner.run(S(0), &mut ctx).unwrap();
        assert_eq!(final_value, 1);
    }

    /// `Outcome::Fail` surfaces as an error carrying the agent's message.
    #[test]
    fn fail_outcome_returns_step_error() {
        let wf = Workflow::builder("test")
            .register(FailOutcomeAgent)
            .build()
            .unwrap();

        let mut ctx = Ctx::new();
        let mut runner = Runner::new(wf);
        let failure = runner.run(S(0), &mut ctx).err().unwrap();
        assert_eq!(failure.to_string(), "reason");
    }

    /// `Outcome::Continue` without a configured successor is an error.
    #[test]
    fn continue_without_default_next_errors() {
        let wf = Workflow::builder("test")
            .register(AlwaysContinue)
            .build()
            .unwrap();

        let mut ctx = Ctx::new();
        let mut runner = Runner::new(wf);
        let failure = runner.run(S(0), &mut ctx).err().unwrap();
        assert!(failure.to_string().contains("no default next step"));
    }
}