// agent_line/runner.rs

1use crate::{Ctx, Outcome, StepError, Workflow};
2use std::time::{Duration, Instant};
3
4/// Passed to the `on_step` hook after each successful agent step.
5pub struct StepEvent<'a> {
6    /// Name of the agent that ran.
7    pub agent: &'a str,
8    /// The outcome the agent returned.
9    pub outcome: &'a Outcome,
10    /// Wall-clock time for the step.
11    pub duration: Duration,
12    /// Sequential step counter (starts at 1).
13    pub step_number: usize,
14    /// Consecutive retry count for the current agent.
15    pub retries: usize,
16}
17
18/// Passed to the `on_error` hook when an agent errors or a limit is exceeded.
19pub struct ErrorEvent<'a> {
20    /// Name of the agent that errored.
21    pub agent: &'a str,
22    /// The error that occurred.
23    pub error: &'a StepError,
24    /// Step number where the error happened.
25    pub step_number: usize,
26}
27
28type StepHook = Box<dyn FnMut(&StepEvent)>;
29type ErrorHook = Box<dyn FnMut(&ErrorEvent)>;
30
31/// Executes a [`Workflow`] step by step, handling retries, waits, and routing.
32pub struct Runner<S: Clone + 'static> {
33    wf: Workflow<S>,
34    max_steps: usize,
35    max_retries: usize,
36    on_step: Option<StepHook>,
37    on_error: Option<ErrorHook>,
38}
39
40impl<S: Clone + 'static> Runner<S> {
41    /// Create a runner for the given workflow with default limits
42    /// (max_steps: 10,000, max_retries: 3).
43    pub fn new(wf: Workflow<S>) -> Self {
44        Self {
45            wf,
46            max_steps: 10_000,
47            max_retries: 3,
48            on_step: None,
49            on_error: None,
50        }
51    }
52
53    /// Prevent accidental infinite loops.
54    pub fn with_max_steps(mut self, max_steps: usize) -> Self {
55        self.max_steps = max_steps;
56        self
57    }
58
59    /// Set the maximum consecutive retries per agent before failing.
60    pub fn with_max_retries(mut self, max_retries: usize) -> Self {
61        self.max_retries = max_retries;
62        self
63    }
64
65    /// Register a callback that fires after each successful agent step.
66    pub fn on_step(mut self, cb: impl FnMut(&StepEvent) + 'static) -> Self {
67        self.on_step = Some(Box::new(cb));
68        self
69    }
70
71    /// Register a callback that fires when an agent errors or a limit is exceeded.
72    pub fn on_error(mut self, cb: impl FnMut(&ErrorEvent) + 'static) -> Self {
73        self.on_error = Some(Box::new(cb));
74        self
75    }
76
77    /// Set both hooks to print step transitions and errors to stderr.
78    pub fn with_tracing(self) -> Self {
79        self.on_step(|e| {
80            eprintln!(
81                "[step {}] {} -> {:?} ({:.3}s)",
82                e.step_number,
83                e.agent,
84                e.outcome,
85                e.duration.as_secs_f64()
86            );
87        })
88        .on_error(|e| {
89            eprintln!("[error] {} at step {}: {}", e.agent, e.step_number, e.error);
90        })
91    }
92
93    /// Run the workflow to completion, returning the final state or an error.
94    /// Can be called multiple times on the same runner.
95    pub fn run(&mut self, mut state: S, ctx: &mut Ctx) -> Result<S, StepError> {
96        let mut current = self.wf.start();
97        let mut retries: usize = 0;
98        let mut step_number: usize = 0;
99
100        for _ in 0..self.max_steps {
101            step_number += 1;
102
103            let agent = self
104                .wf
105                .agent_mut(current)
106                .ok_or_else(|| StepError::other(format!("unknown step: {current}")))?;
107
108            let start = Instant::now();
109            let result = agent.run(state.clone(), ctx);
110            let duration = start.elapsed();
111
112            match result {
113                Err(err) => {
114                    if let Some(cb) = &mut self.on_error {
115                        cb(&ErrorEvent {
116                            agent: current,
117                            error: &err,
118                            step_number,
119                        });
120                    }
121                    return Err(err);
122                }
123                Ok((next_state, outcome)) => {
124                    if let Some(cb) = &mut self.on_step {
125                        cb(&StepEvent {
126                            agent: current,
127                            outcome: &outcome,
128                            duration,
129                            step_number,
130                            retries,
131                        });
132                    }
133
134                    state = next_state;
135
136                    match outcome {
137                        Outcome::Done => return Ok(state),
138                        Outcome::Fail(msg) => return Err(StepError::other(msg)),
139                        Outcome::Next(step) => {
140                            current = step;
141                            retries = 0;
142                            continue;
143                        }
144                        Outcome::Continue => {
145                            if let Some(next) = self.wf.default_next(current) {
146                                current = next;
147                                retries = 0;
148                                continue;
149                            }
150                            return Err(StepError::other(format!(
151                                "step '{current}' returned Continue but no default next step is configured"
152                            )));
153                        }
154                        Outcome::Retry(hint) => {
155                            retries += 1;
156                            if retries > self.max_retries {
157                                let err = StepError::other(format!(
158                                    "step '{}' exceeded max retries ({}): {}",
159                                    current, self.max_retries, hint.reason
160                                ));
161                                if let Some(cb) = &mut self.on_error {
162                                    cb(&ErrorEvent {
163                                        agent: current,
164                                        error: &err,
165                                        step_number,
166                                    });
167                                }
168                                return Err(err);
169                            }
170                            continue;
171                        }
172                        Outcome::Wait(dur) => {
173                            retries += 1;
174                            if retries > self.max_retries {
175                                let err = StepError::other(format!(
176                                    "step '{}' exceeded max retries ({}) while waiting",
177                                    current, self.max_retries
178                                ));
179                                if let Some(cb) = &mut self.on_error {
180                                    cb(&ErrorEvent {
181                                        agent: current,
182                                        error: &err,
183                                        step_number,
184                                    });
185                                }
186                                return Err(err);
187                            }
188                            std::thread::sleep(dur);
189                            continue;
190                        }
191                    }
192                }
193            }
194        }
195
196        let err = StepError::other(format!(
197            "max_steps exceeded (possible infinite loop) in workflow {}",
198            self.wf.name()
199        ));
200        if let Some(cb) = &mut self.on_error {
201            cb(&ErrorEvent {
202                agent: current,
203                error: &err,
204                step_number,
205            });
206        }
207        Err(err)
208    }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Agent, Outcome, RetryHint, StepResult, Workflow};
    // Hooks are plain `FnMut + 'static` closures run on the test thread, so
    // single-threaded shared ownership (`Rc` + `Cell`/`RefCell`) is sufficient.
    use std::cell::{Cell, RefCell};
    use std::rc::Rc;
    use std::time::Duration;

    /// Test state: a single counter. `Debug` lets tests use `unwrap_err`.
    #[derive(Clone, Debug)]
    struct S(u32);

    /// Returns `Retry` until its `succeed_on`-th invocation, then `Done`.
    struct RetryAgent {
        attempts: u32,
        succeed_on: u32,
    }

    impl Agent<S> for RetryAgent {
        fn name(&self) -> &'static str {
            "retry_agent"
        }
        fn run(&mut self, state: S, _ctx: &mut Ctx) -> StepResult<S> {
            self.attempts += 1;
            if self.attempts >= self.succeed_on {
                Ok((state, Outcome::Done))
            } else {
                Ok((state, Outcome::Retry(RetryHint::new("not ready"))))
            }
        }
    }

    /// Always asks to be retried.
    struct AlwaysRetry;
    impl Agent<S> for AlwaysRetry {
        fn name(&self) -> &'static str {
            "always_retry"
        }
        fn run(&mut self, state: S, _ctx: &mut Ctx) -> StepResult<S> {
            Ok((state, Outcome::Retry(RetryHint::new("never ready"))))
        }
    }

    /// Waits 1 ms on its first invocation, then finishes.
    struct WaitOnce {
        waited: bool,
    }
    impl Agent<S> for WaitOnce {
        fn name(&self) -> &'static str {
            "wait_once"
        }
        fn run(&mut self, state: S, _ctx: &mut Ctx) -> StepResult<S> {
            if !self.waited {
                self.waited = true;
                Ok((state, Outcome::Wait(Duration::from_millis(1))))
            } else {
                Ok((state, Outcome::Done))
            }
        }
    }

    #[test]
    fn retry_succeeds_within_limit() {
        let wf = Workflow::builder("test")
            .register(RetryAgent {
                attempts: 0,
                succeed_on: 3,
            })
            .build()
            .unwrap();

        let mut runner = Runner::new(wf);
        let mut ctx = Ctx::new();
        assert!(runner.run(S(0), &mut ctx).is_ok());
    }

    #[test]
    fn retry_exceeds_limit() {
        let wf = Workflow::builder("test")
            .register(AlwaysRetry)
            .build()
            .unwrap();

        let mut runner = Runner::new(wf).with_max_retries(2);
        let mut ctx = Ctx::new();
        let err = runner.run(S(0), &mut ctx).unwrap_err();
        assert!(err.to_string().contains("exceeded max retries"));
    }

    #[test]
    fn wait_sleeps_and_reruns() {
        let wf = Workflow::builder("test")
            .register(WaitOnce { waited: false })
            .build()
            .unwrap();

        let mut runner = Runner::new(wf);
        let mut ctx = Ctx::new();
        assert!(runner.run(S(0), &mut ctx).is_ok());
    }

    // --- hook tests ---

    /// Finishes immediately.
    struct DoneAgent;
    impl Agent<S> for DoneAgent {
        fn name(&self) -> &'static str {
            "done_agent"
        }
        fn run(&mut self, state: S, _ctx: &mut Ctx) -> StepResult<S> {
            Ok((state, Outcome::Done))
        }
    }

    /// Always returns an `Err`.
    struct FailingAgent;
    impl Agent<S> for FailingAgent {
        fn name(&self) -> &'static str {
            "failing_agent"
        }
        fn run(&mut self, _state: S, _ctx: &mut Ctx) -> StepResult<S> {
            Err(StepError::transient("boom"))
        }
    }

    /// Always returns `Continue`.
    struct AlwaysContinue;
    impl Agent<S> for AlwaysContinue {
        fn name(&self) -> &'static str {
            "always_continue"
        }
        fn run(&mut self, state: S, _ctx: &mut Ctx) -> StepResult<S> {
            Ok((state, Outcome::Continue))
        }
    }

    #[test]
    fn on_step_fires_on_success() {
        let count = Rc::new(Cell::new(0usize));
        let hook_count = Rc::clone(&count);

        let wf = Workflow::builder("test")
            .register(DoneAgent)
            .build()
            .unwrap();

        let mut runner = Runner::new(wf).on_step(move |_e| {
            hook_count.set(hook_count.get() + 1);
        });

        let mut ctx = Ctx::new();
        runner.run(S(0), &mut ctx).unwrap();
        assert_eq!(count.get(), 1);
    }

    #[test]
    fn on_error_fires_on_agent_error() {
        let count = Rc::new(Cell::new(0usize));
        let hook_count = Rc::clone(&count);

        let wf = Workflow::builder("test")
            .register(FailingAgent)
            .build()
            .unwrap();

        let mut runner = Runner::new(wf).on_error(move |_e| {
            hook_count.set(hook_count.get() + 1);
        });

        let mut ctx = Ctx::new();
        assert!(runner.run(S(0), &mut ctx).is_err());
        assert_eq!(count.get(), 1);
    }

    #[test]
    fn on_error_reports_agent_and_step() {
        let seen = Rc::new(RefCell::new(Vec::new()));
        let hook_seen = Rc::clone(&seen);

        let wf = Workflow::builder("test")
            .register(FailingAgent)
            .build()
            .unwrap();

        let mut runner = Runner::new(wf).on_error(move |e| {
            hook_seen
                .borrow_mut()
                .push((e.agent.to_string(), e.step_number));
        });

        let mut ctx = Ctx::new();
        assert!(runner.run(S(0), &mut ctx).is_err());
        assert_eq!(*seen.borrow(), vec![("failing_agent".to_string(), 1)]);
    }

    #[test]
    fn on_error_fires_on_max_retries() {
        let count = Rc::new(Cell::new(0usize));
        let hook_count = Rc::clone(&count);

        let wf = Workflow::builder("test")
            .register(AlwaysRetry)
            .build()
            .unwrap();

        let mut runner = Runner::new(wf)
            .with_max_retries(1)
            .on_error(move |_e| {
                hook_count.set(hook_count.get() + 1);
            });

        let mut ctx = Ctx::new();
        assert!(runner.run(S(0), &mut ctx).is_err());
        assert_eq!(count.get(), 1);
    }

    #[test]
    fn on_error_fires_on_max_steps() {
        let count = Rc::new(Cell::new(0usize));
        let hook_count = Rc::clone(&count);

        let wf = Workflow::builder("test")
            .register(AlwaysContinue)
            .register(DoneAgent)
            .start_at("always_continue")
            .then("done_agent")
            .build()
            .unwrap();

        // always_continue would hand off to done_agent on step 2, but max_steps=1
        // stops the run right after the first step.
        let mut runner = Runner::new(wf)
            .with_max_steps(1)
            .on_error(move |e| {
                assert!(e.error.to_string().contains("max_steps exceeded"));
                hook_count.set(hook_count.get() + 1);
            });

        let mut ctx = Ctx::new();
        assert!(runner.run(S(0), &mut ctx).is_err());
        assert_eq!(count.get(), 1);
    }

    #[test]
    fn on_step_receives_correct_step_number() {
        let steps = Rc::new(RefCell::new(Vec::new()));
        let hook_steps = Rc::clone(&steps);

        let wf = Workflow::builder("test")
            .register(RetryAgent {
                attempts: 0,
                succeed_on: 3,
            })
            .build()
            .unwrap();

        let mut runner = Runner::new(wf).on_step(move |e| {
            hook_steps.borrow_mut().push((e.step_number, e.retries));
        });

        let mut ctx = Ctx::new();
        runner.run(S(0), &mut ctx).unwrap();

        // Three steps: retry at step 1 (0 retries so far), retry at step 2
        // (1 accumulated), success at step 3 (2 accumulated).
        assert_eq!(*steps.borrow(), vec![(1, 0), (2, 1), (3, 2)]);
    }

    // --- Outcome::Next ---

    /// Increments the state and jumps to `done_agent`.
    struct NextAgent;
    impl Agent<S> for NextAgent {
        fn name(&self) -> &'static str {
            "next_agent"
        }
        fn run(&mut self, state: S, _ctx: &mut Ctx) -> StepResult<S> {
            Ok((S(state.0 + 1), Outcome::Next("done_agent")))
        }
    }

    #[test]
    fn next_jumps_to_named_agent() {
        let wf = Workflow::builder("test")
            .register(NextAgent)
            .register(DoneAgent)
            .build()
            .unwrap();

        let mut runner = Runner::new(wf);
        let mut ctx = Ctx::new();
        let result = runner.run(S(0), &mut ctx).unwrap();
        assert_eq!(result.0, 1);
    }

    #[test]
    fn next_to_unregistered_step_errors() {
        // NextAgent routes to "done_agent", which this workflow never registers.
        let wf = Workflow::builder("test")
            .register(NextAgent)
            .build()
            .unwrap();

        let mut runner = Runner::new(wf);
        let mut ctx = Ctx::new();
        let err = runner.run(S(0), &mut ctx).unwrap_err();
        assert_eq!(err.to_string(), "unknown step: done_agent");
    }

    // --- Outcome::Fail ---

    /// Returns `Outcome::Fail("reason")`.
    struct FailOutcomeAgent;
    impl Agent<S> for FailOutcomeAgent {
        fn name(&self) -> &'static str {
            "fail_outcome"
        }
        fn run(&mut self, state: S, _ctx: &mut Ctx) -> StepResult<S> {
            Ok((state, Outcome::Fail("reason".into())))
        }
    }

    #[test]
    fn fail_outcome_returns_step_error() {
        let wf = Workflow::builder("test")
            .register(FailOutcomeAgent)
            .build()
            .unwrap();

        let mut runner = Runner::new(wf);
        let mut ctx = Ctx::new();
        let err = runner.run(S(0), &mut ctx).unwrap_err();
        assert_eq!(err.to_string(), "reason");
    }

    // --- Outcome::Continue ---

    #[test]
    fn continue_follows_default_next() {
        let wf = Workflow::builder("test")
            .register(AlwaysContinue)
            .register(DoneAgent)
            .start_at("always_continue")
            .then("done_agent")
            .build()
            .unwrap();

        let mut runner = Runner::new(wf);
        let mut ctx = Ctx::new();
        assert!(runner.run(S(0), &mut ctx).is_ok());
    }

    #[test]
    fn continue_without_default_next_errors() {
        let wf = Workflow::builder("test")
            .register(AlwaysContinue)
            .build()
            .unwrap();

        let mut runner = Runner::new(wf);
        let mut ctx = Ctx::new();
        let err = runner.run(S(0), &mut ctx).unwrap_err();
        assert!(err.to_string().contains("no default next step"));
    }

    // --- Wait exceeds max_retries ---

    /// Always returns `Wait(1 ms)`.
    struct AlwaysWait;
    impl Agent<S> for AlwaysWait {
        fn name(&self) -> &'static str {
            "always_wait"
        }
        fn run(&mut self, state: S, _ctx: &mut Ctx) -> StepResult<S> {
            Ok((state, Outcome::Wait(Duration::from_millis(1))))
        }
    }

    #[test]
    fn wait_exceeds_max_retries() {
        let wf = Workflow::builder("test")
            .register(AlwaysWait)
            .build()
            .unwrap();

        let mut runner = Runner::new(wf).with_max_retries(1);
        let mut ctx = Ctx::new();
        let err = runner.run(S(0), &mut ctx).unwrap_err();
        assert!(err.to_string().contains("exceeded max retries"));
    }

    // --- Retry counter resets on step transition ---

    /// Retries once, then returns `Continue`.
    struct RetryOnceThenContinue {
        attempts: u32,
    }
    impl Agent<S> for RetryOnceThenContinue {
        fn name(&self) -> &'static str {
            "retry_once_then_continue"
        }
        fn run(&mut self, state: S, _ctx: &mut Ctx) -> StepResult<S> {
            self.attempts += 1;
            if self.attempts < 2 {
                Ok((state, Outcome::Retry(RetryHint::new("not yet"))))
            } else {
                Ok((state, Outcome::Continue))
            }
        }
    }

    #[test]
    fn retry_counter_resets_on_step_transition() {
        let events = Rc::new(RefCell::new(Vec::new()));
        let hook_events = Rc::clone(&events);

        let wf = Workflow::builder("test")
            .register(RetryOnceThenContinue { attempts: 0 })
            .register(DoneAgent)
            .start_at("retry_once_then_continue")
            .then("done_agent")
            .build()
            .unwrap();

        let mut runner = Runner::new(wf).on_step(move |e| {
            hook_events
                .borrow_mut()
                .push((e.agent.to_string(), e.retries));
        });

        let mut ctx = Ctx::new();
        runner.run(S(0), &mut ctx).unwrap();

        let events = events.borrow();
        // retry_once_then_continue fires twice (retry, then continue); done_agent once.
        assert_eq!(events.len(), 3);
        // done_agent must see retries=0 because the counter resets on transition.
        let done_event = events
            .iter()
            .find(|(name, _)| name == "done_agent")
            .expect("done_agent should have run");
        assert_eq!(done_event.1, 0);
    }

    // --- repeated runs ---

    #[test]
    fn run_can_be_repeated() {
        let wf = Workflow::builder("test")
            .register(DoneAgent)
            .build()
            .unwrap();

        let mut runner = Runner::new(wf);
        let mut ctx = Ctx::new();
        assert_eq!(runner.run(S(1), &mut ctx).unwrap().0, 1);
        assert_eq!(runner.run(S(2), &mut ctx).unwrap().0, 2);
    }
}