// agent_line/workflow.rs

1use crate::Agent;
2use std::collections::HashMap;
3use std::fmt;

// ---------------------------------------------------------------------------
// WorkflowError
// ---------------------------------------------------------------------------

/// Errors returned by [`WorkflowBuilder::build`].
///
/// Every payload is the offending step/agent name, so errors can be compared
/// directly (e.g. `assert_eq!(err, WorkflowError::UnknownStep("x"))`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WorkflowError {
    /// Two agents were registered with the same name; carries that name.
    DuplicateAgent(&'static str),
    /// A step named via `start_at` or `then` does not match any registered agent.
    UnknownStep(&'static str),
    /// No start step was set: no agent was registered and neither `start_at`
    /// nor `then` was called.
    MissingStart,
}

impl fmt::Display for WorkflowError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::DuplicateAgent(name) => write!(f, "duplicate agent name: {name}"),
            Self::UnknownStep(name) => write!(f, "unknown step: {name}"),
            Self::MissingStart => write!(f, "workflow missing start step"),
        }
    }
}

impl std::error::Error for WorkflowError {}

// ---------------------------------------------------------------------------
// WorkflowBuilder
// ---------------------------------------------------------------------------

36/// Step-by-step builder for a [`Workflow`]. Obtained via [`Workflow::builder`].
37pub struct WorkflowBuilder<S: Clone + 'static> {
38    name: &'static str,
39    start: Option<&'static str>,
40    chain_last: Option<&'static str>,
41    agents: HashMap<&'static str, Box<dyn Agent<S>>>,
42    default_next: HashMap<&'static str, &'static str>,
43    duplicate: Option<&'static str>,
44}
45
46impl<S: Clone + 'static> WorkflowBuilder<S> {
47    /// Register an agent. The first agent registered becomes the default start step.
48    pub fn register<A: Agent<S>>(mut self, agent: A) -> Self {
49        let name = agent.name();
50        if self.agents.contains_key(name) {
51            self.duplicate = Some(name);
52        }
53        self.agents.insert(name, Box::new(agent));
54
55        // If this is the first agent added and start isn't set, default start to it.
56        if self.start.is_none() {
57            self.start = Some(name);
58        }
59
60        // Also initialize chain_last if it's not set.
61        if self.chain_last.is_none() {
62            self.chain_last = Some(name);
63        }
64
65        self
66    }
67
68    /// Set which agent runs first (overrides the default).
69    pub fn start_at(mut self, step: &'static str) -> Self {
70        self.start = Some(step);
71        self.chain_last = Some(step);
72        self
73    }
74
75    /// Chain the next step: current(chain_last) -> next
76    pub fn then(mut self, next: &'static str) -> Self {
77        let Some(current) = self.chain_last else {
78            // No prior step; treat `next` as the start
79            self.start = Some(next);
80            self.chain_last = Some(next);
81            return self;
82        };
83
84        self.default_next.insert(current, next);
85        self.chain_last = Some(next);
86        self
87    }
88
89    /// Validate and build the workflow. Returns an error if agents are
90    /// missing, duplicated, or if routing targets don't exist.
91    pub fn build(self) -> Result<Workflow<S>, WorkflowError> {
92        // Check for duplicate agents.
93        if let Some(name) = self.duplicate {
94            return Err(WorkflowError::DuplicateAgent(name));
95        }
96
97        // Check for a start step.
98        let start = self.start.ok_or(WorkflowError::MissingStart)?;
99
100        // Validate start_at target exists as a registered agent.
101        if !self.agents.contains_key(start) {
102            return Err(WorkflowError::UnknownStep(start));
103        }
104
105        // Validate every `then` target exists as a registered agent.
106        for &target in self.default_next.values() {
107            if !self.agents.contains_key(target) {
108                return Err(WorkflowError::UnknownStep(target));
109            }
110        }
111
112        Ok(Workflow {
113            name: self.name,
114            start,
115            agents: self.agents,
116            default_next: self.default_next,
117        })
118    }
119}

// ---------------------------------------------------------------------------
// Workflow (validated, only constructed via build())
// ---------------------------------------------------------------------------

125/// A validated workflow of agents. Built via [`Workflow::builder`].
126pub struct Workflow<S: Clone + 'static> {
127    name: &'static str,
128    start: &'static str,
129    agents: HashMap<&'static str, Box<dyn Agent<S>>>,
130    default_next: HashMap<&'static str, &'static str>,
131}
132
133impl<S: Clone + 'static> Workflow<S> {
134    /// Create a new builder with the given workflow name.
135    pub fn builder(name: &'static str) -> WorkflowBuilder<S> {
136        WorkflowBuilder {
137            name,
138            start: None,
139            chain_last: None,
140            agents: HashMap::new(),
141            default_next: HashMap::new(),
142            duplicate: None,
143        }
144    }
145
146    /// The workflow's name (set at builder creation).
147    pub fn name(&self) -> &'static str {
148        self.name
149    }
150
151    // --- stuff the runner uses (keep pub(crate)) ---
152    pub(crate) fn start(&self) -> &'static str {
153        self.start
154    }
155
156    pub(crate) fn agent_mut(&mut self, name: &'static str) -> Option<&mut Box<dyn Agent<S>>> {
157        self.agents.get_mut(name)
158    }
159
160    pub(crate) fn default_next(&self, from: &'static str) -> Option<&'static str> {
161        self.default_next.get(from).copied()
162    }
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Ctx, Outcome, StepResult};

    #[derive(Clone)]
    struct S;

    /// Minimal agent that finishes immediately; only its name matters here.
    struct FakeAgent(&'static str);

    impl Agent<S> for FakeAgent {
        fn name(&self) -> &'static str {
            self.0
        }
        fn run(&mut self, state: S, _ctx: &mut Ctx) -> StepResult<S> {
            Ok((state, Outcome::Done))
        }
    }

    /// Build `builder` and return its error, panicking if it unexpectedly succeeds.
    /// (`Workflow` is not `Debug`, so `Result::unwrap_err` is unavailable.)
    fn build_err(builder: WorkflowBuilder<S>) -> WorkflowError {
        match builder.build() {
            Ok(_) => panic!("expected build to fail"),
            Err(err) => err,
        }
    }

    #[test]
    fn build_valid_workflow() {
        let wf = Workflow::builder("test")
            .register(FakeAgent("a"))
            .register(FakeAgent("b"))
            .start_at("a")
            .then("b")
            .build()
            .expect("valid workflow should build");

        assert_eq!(wf.name(), "test");
        assert_eq!(wf.start(), "a");
        assert_eq!(wf.default_next("a"), Some("b"));
        assert_eq!(wf.default_next("b"), None);
    }

    #[test]
    fn missing_start_on_empty_builder() {
        let err = build_err(Workflow::<S>::builder("test"));
        assert!(matches!(err, WorkflowError::MissingStart));
    }

    #[test]
    fn unknown_start_at_step() {
        let err = build_err(
            Workflow::builder("test")
                .register(FakeAgent("a"))
                .start_at("missing"),
        );
        assert!(matches!(err, WorkflowError::UnknownStep("missing")));
    }

    #[test]
    fn unknown_then_target() {
        let err = build_err(
            Workflow::builder("test")
                .register(FakeAgent("a"))
                .start_at("a")
                .then("missing"),
        );
        assert!(matches!(err, WorkflowError::UnknownStep("missing")));
    }

    #[test]
    fn first_agent_becomes_default_start() {
        let wf = Workflow::builder("test")
            .register(FakeAgent("first"))
            .register(FakeAgent("second"))
            .build()
            .expect("workflow without routes should build");

        assert_eq!(wf.start(), "first");
    }

    #[test]
    fn then_chains_from_first_registered_agent() {
        let wf = Workflow::builder("test")
            .register(FakeAgent("a"))
            .register(FakeAgent("b"))
            .then("b")
            .build()
            .expect("chained workflow should build");

        assert_eq!(wf.start(), "a");
        assert_eq!(wf.default_next("a"), Some("b"));
    }

    #[test]
    fn then_without_prior_step_sets_start() {
        let wf = Workflow::builder("test")
            .then("a")
            .register(FakeAgent("a"))
            .build()
            .expect("then-first workflow should build");

        assert_eq!(wf.start(), "a");
        assert_eq!(wf.default_next("a"), None);
    }

    #[test]
    fn duplicate_agent_rejected() {
        let err = build_err(
            Workflow::builder("test")
                .register(FakeAgent("a"))
                .register(FakeAgent("a")),
        );
        assert!(matches!(err, WorkflowError::DuplicateAgent("a")));
    }
}