1use crate::{Ctx, Outcome, StepError, Workflow};
2use std::time::{Duration, Instant};
3
4pub struct StepEvent<'a> {
6 pub agent: &'a str,
8 pub outcome: &'a Outcome,
10 pub duration: Duration,
12 pub step_number: usize,
14 pub retries: usize,
16}
17
18pub struct ErrorEvent<'a> {
20 pub agent: &'a str,
22 pub error: &'a StepError,
24 pub step_number: usize,
26}
27
28type StepHook = Box<dyn FnMut(&StepEvent)>;
29type ErrorHook = Box<dyn FnMut(&ErrorEvent)>;
30
31pub struct Runner<S: Clone + 'static> {
33 wf: Workflow<S>,
34 max_steps: usize,
35 max_retries: usize,
36 on_step: Option<StepHook>,
37 on_error: Option<ErrorHook>,
38}
39
40impl<S: Clone + 'static> Runner<S> {
41 pub fn new(wf: Workflow<S>) -> Self {
44 Self {
45 wf,
46 max_steps: 10_000,
47 max_retries: 3,
48 on_step: None,
49 on_error: None,
50 }
51 }
52
53 pub fn with_max_steps(mut self, max_steps: usize) -> Self {
55 self.max_steps = max_steps;
56 self
57 }
58
59 pub fn with_max_retries(mut self, max_retries: usize) -> Self {
61 self.max_retries = max_retries;
62 self
63 }
64
65 pub fn on_step(mut self, cb: impl FnMut(&StepEvent) + 'static) -> Self {
67 self.on_step = Some(Box::new(cb));
68 self
69 }
70
71 pub fn on_error(mut self, cb: impl FnMut(&ErrorEvent) + 'static) -> Self {
73 self.on_error = Some(Box::new(cb));
74 self
75 }
76
77 pub fn with_tracing(self) -> Self {
79 self.on_step(|e| {
80 eprintln!(
81 "[step {}] {} -> {:?} ({:.3}s)",
82 e.step_number,
83 e.agent,
84 e.outcome,
85 e.duration.as_secs_f64()
86 );
87 })
88 .on_error(|e| {
89 eprintln!("[error] {} at step {}: {}", e.agent, e.step_number, e.error);
90 })
91 }
92
93 pub fn run(&mut self, mut state: S, ctx: &mut Ctx) -> Result<S, StepError> {
96 let mut current = self.wf.start();
97 let mut retries: usize = 0;
98 let mut step_number: usize = 0;
99
100 for _ in 0..self.max_steps {
101 step_number += 1;
102
103 let agent = self
104 .wf
105 .agent_mut(current)
106 .ok_or_else(|| StepError::other(format!("unknown step: {current}")))?;
107
108 let start = Instant::now();
109 let result = agent.run(state.clone(), ctx);
110 let duration = start.elapsed();
111
112 match result {
113 Err(err) => {
114 if let Some(cb) = &mut self.on_error {
115 cb(&ErrorEvent {
116 agent: current,
117 error: &err,
118 step_number,
119 });
120 }
121 return Err(err);
122 }
123 Ok((next_state, outcome)) => {
124 if let Some(cb) = &mut self.on_step {
125 cb(&StepEvent {
126 agent: current,
127 outcome: &outcome,
128 duration,
129 step_number,
130 retries,
131 });
132 }
133
134 state = next_state;
135
136 match outcome {
137 Outcome::Done => return Ok(state),
138 Outcome::Fail(msg) => return Err(StepError::other(msg)),
139 Outcome::Next(step) => {
140 current = step;
141 retries = 0;
142 continue;
143 }
144 Outcome::Continue => {
145 if let Some(next) = self.wf.default_next(current) {
146 current = next;
147 retries = 0;
148 continue;
149 }
150 return Err(StepError::other(format!(
151 "step '{current}' returned Continue but no default next step is configured"
152 )));
153 }
154 Outcome::Retry(hint) => {
155 retries += 1;
156 if retries > self.max_retries {
157 let err = StepError::other(format!(
158 "step '{}' exceeded max retries ({}): {}",
159 current, self.max_retries, hint.reason
160 ));
161 if let Some(cb) = &mut self.on_error {
162 cb(&ErrorEvent {
163 agent: current,
164 error: &err,
165 step_number,
166 });
167 }
168 return Err(err);
169 }
170 continue;
171 }
172 Outcome::Wait(dur) => {
173 retries += 1;
174 if retries > self.max_retries {
175 let err = StepError::other(format!(
176 "step '{}' exceeded max retries ({}) while waiting",
177 current, self.max_retries
178 ));
179 if let Some(cb) = &mut self.on_error {
180 cb(&ErrorEvent {
181 agent: current,
182 error: &err,
183 step_number,
184 });
185 }
186 return Err(err);
187 }
188 std::thread::sleep(dur);
189 continue;
190 }
191 }
192 }
193 }
194 }
195
196 let err = StepError::other(format!(
197 "max_steps exceeded (possible infinite loop) in workflow {}",
198 self.wf.name()
199 ));
200 if let Some(cb) = &mut self.on_error {
201 cb(&ErrorEvent {
202 agent: current,
203 error: &err,
204 step_number,
205 });
206 }
207 Err(err)
208 }
209}
210
211#[cfg(test)]
212mod tests {
213 use super::*;
214 use crate::{Agent, Outcome, RetryHint, StepResult, Workflow};
215 use std::time::Duration;
216
217 #[derive(Clone)]
218 struct S(u32);
219
220 struct RetryAgent {
221 attempts: u32,
222 succeed_on: u32,
223 }
224
225 impl Agent<S> for RetryAgent {
226 fn name(&self) -> &'static str {
227 "retry_agent"
228 }
229 fn run(&mut self, state: S, _ctx: &mut Ctx) -> StepResult<S> {
230 self.attempts += 1;
231 if self.attempts >= self.succeed_on {
232 Ok((state, Outcome::Done))
233 } else {
234 Ok((state, Outcome::Retry(RetryHint::new("not ready"))))
235 }
236 }
237 }
238
239 struct AlwaysRetry;
240 impl Agent<S> for AlwaysRetry {
241 fn name(&self) -> &'static str {
242 "always_retry"
243 }
244 fn run(&mut self, state: S, _ctx: &mut Ctx) -> StepResult<S> {
245 Ok((state, Outcome::Retry(RetryHint::new("never ready"))))
246 }
247 }
248
249 struct WaitOnce {
250 waited: bool,
251 }
252 impl Agent<S> for WaitOnce {
253 fn name(&self) -> &'static str {
254 "wait_once"
255 }
256 fn run(&mut self, state: S, _ctx: &mut Ctx) -> StepResult<S> {
257 if !self.waited {
258 self.waited = true;
259 Ok((state, Outcome::Wait(Duration::from_millis(1))))
260 } else {
261 Ok((state, Outcome::Done))
262 }
263 }
264 }
265
266 #[test]
267 fn retry_succeeds_within_limit() {
268 let wf = Workflow::builder("test")
269 .register(RetryAgent {
270 attempts: 0,
271 succeed_on: 3,
272 })
273 .build()
274 .unwrap();
275
276 let mut runner = Runner::new(wf);
277 let mut ctx = Ctx::new();
278 let result = runner.run(S(0), &mut ctx);
279 assert!(result.is_ok());
280 }
281
282 #[test]
283 fn retry_exceeds_limit() {
284 let wf = Workflow::builder("test")
285 .register(AlwaysRetry)
286 .build()
287 .unwrap();
288
289 let mut runner = Runner::new(wf).with_max_retries(2);
290 let mut ctx = Ctx::new();
291 let err = runner.run(S(0), &mut ctx).err().unwrap();
292 assert!(err.to_string().contains("exceeded max retries"));
293 }
294
295 #[test]
296 fn wait_sleeps_and_reruns() {
297 let wf = Workflow::builder("test")
298 .register(WaitOnce { waited: false })
299 .build()
300 .unwrap();
301
302 let mut runner = Runner::new(wf);
303 let mut ctx = Ctx::new();
304 let result = runner.run(S(0), &mut ctx);
305 assert!(result.is_ok());
306 }
307
308 struct DoneAgent;
311 impl Agent<S> for DoneAgent {
312 fn name(&self) -> &'static str {
313 "done_agent"
314 }
315 fn run(&mut self, state: S, _ctx: &mut Ctx) -> StepResult<S> {
316 Ok((state, Outcome::Done))
317 }
318 }
319
320 struct FailingAgent;
321 impl Agent<S> for FailingAgent {
322 fn name(&self) -> &'static str {
323 "failing_agent"
324 }
325 fn run(&mut self, _state: S, _ctx: &mut Ctx) -> StepResult<S> {
326 Err(StepError::transient("boom"))
327 }
328 }
329
330 struct AlwaysContinue;
331 impl Agent<S> for AlwaysContinue {
332 fn name(&self) -> &'static str {
333 "always_continue"
334 }
335 fn run(&mut self, state: S, _ctx: &mut Ctx) -> StepResult<S> {
336 Ok((state, Outcome::Continue))
337 }
338 }
339
340 #[test]
341 fn on_step_fires_on_success() {
342 use std::sync::{Arc, Mutex};
343
344 let count = Arc::new(Mutex::new(0usize));
345 let count_clone = Arc::clone(&count);
346
347 let wf = Workflow::builder("test")
348 .register(DoneAgent)
349 .build()
350 .unwrap();
351
352 let mut runner = Runner::new(wf).on_step(move |_e| {
353 *count_clone.lock().unwrap() += 1;
354 });
355
356 let mut ctx = Ctx::new();
357 runner.run(S(0), &mut ctx).unwrap();
358 assert_eq!(*count.lock().unwrap(), 1);
359 }
360
361 #[test]
362 fn on_error_fires_on_agent_error() {
363 use std::sync::{Arc, Mutex};
364
365 let count = Arc::new(Mutex::new(0usize));
366 let count_clone = Arc::clone(&count);
367
368 let wf = Workflow::builder("test")
369 .register(FailingAgent)
370 .build()
371 .unwrap();
372
373 let mut runner = Runner::new(wf).on_error(move |_e| {
374 *count_clone.lock().unwrap() += 1;
375 });
376
377 let mut ctx = Ctx::new();
378 let _ = runner.run(S(0), &mut ctx);
379 assert_eq!(*count.lock().unwrap(), 1);
380 }
381
382 #[test]
383 fn on_error_fires_on_max_retries() {
384 use std::sync::{Arc, Mutex};
385
386 let count = Arc::new(Mutex::new(0usize));
387 let count_clone = Arc::clone(&count);
388
389 let wf = Workflow::builder("test")
390 .register(AlwaysRetry)
391 .build()
392 .unwrap();
393
394 let mut runner = Runner::new(wf)
395 .with_max_retries(1)
396 .on_error(move |_e| {
397 *count_clone.lock().unwrap() += 1;
398 });
399
400 let mut ctx = Ctx::new();
401 let _ = runner.run(S(0), &mut ctx);
402 assert_eq!(*count.lock().unwrap(), 1);
403 }
404
405 #[test]
406 fn on_error_fires_on_max_steps() {
407 use std::sync::{Arc, Mutex};
408
409 let count = Arc::new(Mutex::new(0usize));
410 let count_clone = Arc::clone(&count);
411
412 let wf = Workflow::builder("test")
413 .register(AlwaysContinue)
414 .register(DoneAgent)
415 .start_at("always_continue")
416 .then("done_agent")
417 .build()
418 .unwrap();
419
420 let mut runner = Runner::new(wf)
422 .with_max_steps(1)
423 .on_error(move |e| {
424 assert!(e.error.to_string().contains("max_steps exceeded"));
425 *count_clone.lock().unwrap() += 1;
426 });
427
428 let mut ctx = Ctx::new();
429 let _ = runner.run(S(0), &mut ctx);
430 assert_eq!(*count.lock().unwrap(), 1);
431 }
432
433 #[test]
434 fn on_step_receives_correct_step_number() {
435 use std::sync::{Arc, Mutex};
436
437 let steps = Arc::new(Mutex::new(Vec::new()));
438 let steps_clone = Arc::clone(&steps);
439
440 let wf = Workflow::builder("test")
441 .register(RetryAgent {
442 attempts: 0,
443 succeed_on: 3,
444 })
445 .build()
446 .unwrap();
447
448 let mut runner = Runner::new(wf).on_step(move |e| {
449 steps_clone
450 .lock()
451 .unwrap()
452 .push((e.step_number, e.retries));
453 });
454
455 let mut ctx = Ctx::new();
456 runner.run(S(0), &mut ctx).unwrap();
457
458 let steps = steps.lock().unwrap();
459 assert_eq!(steps.len(), 3);
461 assert_eq!(steps[0], (1, 0)); assert_eq!(steps[1], (2, 1)); assert_eq!(steps[2], (3, 2)); }
465
466 struct NextAgent;
469 impl Agent<S> for NextAgent {
470 fn name(&self) -> &'static str {
471 "next_agent"
472 }
473 fn run(&mut self, state: S, _ctx: &mut Ctx) -> StepResult<S> {
474 Ok((S(state.0 + 1), Outcome::Next("done_agent")))
475 }
476 }
477
478 #[test]
479 fn next_jumps_to_named_agent() {
480 let wf = Workflow::builder("test")
481 .register(NextAgent)
482 .register(DoneAgent)
483 .build()
484 .unwrap();
485
486 let mut runner = Runner::new(wf);
487 let mut ctx = Ctx::new();
488 let result = runner.run(S(0), &mut ctx).unwrap();
489 assert_eq!(result.0, 1);
490 }
491
492 struct FailOutcomeAgent;
495 impl Agent<S> for FailOutcomeAgent {
496 fn name(&self) -> &'static str {
497 "fail_outcome"
498 }
499 fn run(&mut self, state: S, _ctx: &mut Ctx) -> StepResult<S> {
500 Ok((state, Outcome::Fail("reason".into())))
501 }
502 }
503
504 #[test]
505 fn fail_outcome_returns_step_error() {
506 let wf = Workflow::builder("test")
507 .register(FailOutcomeAgent)
508 .build()
509 .unwrap();
510
511 let mut runner = Runner::new(wf);
512 let mut ctx = Ctx::new();
513 let err = runner.run(S(0), &mut ctx).err().unwrap();
514 assert_eq!(err.to_string(), "reason");
515 }
516
517 #[test]
520 fn continue_without_default_next_errors() {
521 let wf = Workflow::builder("test")
522 .register(AlwaysContinue)
523 .build()
524 .unwrap();
525
526 let mut runner = Runner::new(wf);
527 let mut ctx = Ctx::new();
528 let err = runner.run(S(0), &mut ctx).err().unwrap();
529 assert!(err.to_string().contains("no default next step"));
530 }
531
532 struct AlwaysWait;
535 impl Agent<S> for AlwaysWait {
536 fn name(&self) -> &'static str {
537 "always_wait"
538 }
539 fn run(&mut self, state: S, _ctx: &mut Ctx) -> StepResult<S> {
540 Ok((state, Outcome::Wait(Duration::from_millis(1))))
541 }
542 }
543
544 #[test]
545 fn wait_exceeds_max_retries() {
546 let wf = Workflow::builder("test")
547 .register(AlwaysWait)
548 .build()
549 .unwrap();
550
551 let mut runner = Runner::new(wf).with_max_retries(1);
552 let mut ctx = Ctx::new();
553 let err = runner.run(S(0), &mut ctx).err().unwrap();
554 assert!(err.to_string().contains("exceeded max retries"));
555 }
556
557 struct RetryOnceThenContinue {
560 attempts: u32,
561 }
562 impl Agent<S> for RetryOnceThenContinue {
563 fn name(&self) -> &'static str {
564 "retry_once_then_continue"
565 }
566 fn run(&mut self, state: S, _ctx: &mut Ctx) -> StepResult<S> {
567 self.attempts += 1;
568 if self.attempts < 2 {
569 Ok((state, Outcome::Retry(RetryHint::new("not yet"))))
570 } else {
571 Ok((state, Outcome::Continue))
572 }
573 }
574 }
575
576 #[test]
577 fn retry_counter_resets_on_step_transition() {
578 use std::sync::{Arc, Mutex};
579
580 let events = Arc::new(Mutex::new(Vec::new()));
581 let events_clone = Arc::clone(&events);
582
583 let wf = Workflow::builder("test")
584 .register(RetryOnceThenContinue { attempts: 0 })
585 .register(DoneAgent)
586 .start_at("retry_once_then_continue")
587 .then("done_agent")
588 .build()
589 .unwrap();
590
591 let mut runner = Runner::new(wf).on_step(move |e| {
592 events_clone
593 .lock()
594 .unwrap()
595 .push((e.agent.to_string(), e.retries));
596 });
597
598 let mut ctx = Ctx::new();
599 runner.run(S(0), &mut ctx).unwrap();
600
601 let events = events.lock().unwrap();
602 assert_eq!(events.len(), 3);
604 let done_event = events.iter().find(|(name, _)| name == "done_agent").unwrap();
606 assert_eq!(done_event.1, 0);
607 }
608}