1use crate::Agent;
2use std::collections::HashMap;
3use std::fmt;
4
/// Reasons a workflow definition can be rejected when it is built.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WorkflowError {
    /// More than one registered agent reported this name.
    DuplicateAgent(&'static str),
    /// This step name was used (as the start step or in a chain) but no agent
    /// with that name was registered.
    UnknownStep(&'static str),
    /// No start step was chosen: no agent was registered and neither
    /// `start_at` nor `then` was called.
    MissingStart,
}

impl fmt::Display for WorkflowError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::DuplicateAgent(name) => write!(f, "duplicate agent name: {name}"),
            Self::UnknownStep(name) => write!(f, "unknown step: {name}"),
            Self::MissingStart => f.write_str("workflow missing start step"),
        }
    }
}

impl std::error::Error for WorkflowError {}
31
32pub struct WorkflowBuilder<S: Clone + 'static> {
38 name: &'static str,
39 start: Option<&'static str>,
40 chain_last: Option<&'static str>,
41 agents: HashMap<&'static str, Box<dyn Agent<S>>>,
42 default_next: HashMap<&'static str, &'static str>,
43 duplicate: Option<&'static str>,
44}
45
46impl<S: Clone + 'static> WorkflowBuilder<S> {
47 pub fn register<A: Agent<S>>(mut self, agent: A) -> Self {
49 let name = agent.name();
50 if self.agents.contains_key(name) {
51 self.duplicate = Some(name);
52 }
53 self.agents.insert(name, Box::new(agent));
54
55 if self.start.is_none() {
57 self.start = Some(name);
58 }
59
60 if self.chain_last.is_none() {
62 self.chain_last = Some(name);
63 }
64
65 self
66 }
67
68 pub fn start_at(mut self, step: &'static str) -> Self {
70 self.start = Some(step);
71 self.chain_last = Some(step);
72 self
73 }
74
75 pub fn then(mut self, next: &'static str) -> Self {
77 let Some(current) = self.chain_last else {
78 self.start = Some(next);
80 self.chain_last = Some(next);
81 return self;
82 };
83
84 self.default_next.insert(current, next);
85 self.chain_last = Some(next);
86 self
87 }
88
89 pub fn build(self) -> Result<Workflow<S>, WorkflowError> {
92 if let Some(name) = self.duplicate {
94 return Err(WorkflowError::DuplicateAgent(name));
95 }
96
97 let start = self.start.ok_or(WorkflowError::MissingStart)?;
99
100 if !self.agents.contains_key(start) {
102 return Err(WorkflowError::UnknownStep(start));
103 }
104
105 for &target in self.default_next.values() {
107 if !self.agents.contains_key(target) {
108 return Err(WorkflowError::UnknownStep(target));
109 }
110 }
111
112 Ok(Workflow {
113 name: self.name,
114 start,
115 agents: self.agents,
116 default_next: self.default_next,
117 })
118 }
119}
120
121pub struct Workflow<S: Clone + 'static> {
127 name: &'static str,
128 start: &'static str,
129 agents: HashMap<&'static str, Box<dyn Agent<S>>>,
130 default_next: HashMap<&'static str, &'static str>,
131}
132
133impl<S: Clone + 'static> Workflow<S> {
134 pub fn builder(name: &'static str) -> WorkflowBuilder<S> {
136 WorkflowBuilder {
137 name,
138 start: None,
139 chain_last: None,
140 agents: HashMap::new(),
141 default_next: HashMap::new(),
142 duplicate: None,
143 }
144 }
145
146 pub fn name(&self) -> &'static str {
148 self.name
149 }
150
151 pub(crate) fn start(&self) -> &'static str {
153 self.start
154 }
155
156 pub(crate) fn agent_mut(&mut self, name: &'static str) -> Option<&mut Box<dyn Agent<S>>> {
157 self.agents.get_mut(name)
158 }
159
160 pub(crate) fn default_next(&self, from: &'static str) -> Option<&'static str> {
161 self.default_next.get(from).copied()
162 }
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Ctx, Outcome, StepResult};

    #[derive(Clone)]
    struct S;

    /// Minimal agent that finishes immediately; only its name matters here.
    struct FakeAgent(&'static str);

    impl Agent<S> for FakeAgent {
        fn name(&self) -> &'static str {
            self.0
        }
        fn run(&mut self, state: S, _ctx: &mut Ctx) -> StepResult<S> {
            Ok((state, Outcome::Done))
        }
    }

    /// Builds `builder` and returns its error, panicking if the build succeeded.
    /// (`Workflow` is not `Debug`, so `Result::unwrap_err` is unavailable.)
    fn build_err(builder: WorkflowBuilder<S>) -> WorkflowError {
        match builder.build() {
            Ok(_) => panic!("expected build to fail"),
            Err(err) => err,
        }
    }

    #[test]
    fn build_valid_workflow() {
        let wf = Workflow::builder("test")
            .register(FakeAgent("a"))
            .register(FakeAgent("b"))
            .start_at("a")
            .then("b")
            .build()
            .expect("valid workflow should build");

        assert_eq!(wf.name(), "test");
        assert_eq!(wf.start(), "a");
        assert_eq!(wf.default_next("a"), Some("b"));
    }

    #[test]
    fn missing_start_on_empty_builder() {
        let err = build_err(Workflow::<S>::builder("test"));
        assert!(matches!(err, WorkflowError::MissingStart));
    }

    #[test]
    fn unknown_start_at_step() {
        let err = build_err(
            Workflow::builder("test")
                .register(FakeAgent("a"))
                .start_at("missing"),
        );
        assert!(matches!(err, WorkflowError::UnknownStep("missing")));
    }

    #[test]
    fn unknown_then_target() {
        let err = build_err(
            Workflow::builder("test")
                .register(FakeAgent("a"))
                .start_at("a")
                .then("missing"),
        );
        assert!(matches!(err, WorkflowError::UnknownStep("missing")));
    }

    #[test]
    fn first_agent_becomes_default_start() {
        let wf = Workflow::builder("test")
            .register(FakeAgent("first"))
            .build()
            .expect("single registered agent should build");

        assert_eq!(wf.start(), "first");
        assert_eq!(wf.default_next("first"), None);
    }

    #[test]
    fn start_at_overrides_registration_default() {
        let wf = Workflow::builder("test")
            .register(FakeAgent("a"))
            .register(FakeAgent("b"))
            .start_at("b")
            .build()
            .expect("workflow should build");

        assert_eq!(wf.start(), "b");
    }

    #[test]
    fn then_chains_from_first_registered_agent() {
        let wf = Workflow::builder("test")
            .register(FakeAgent("a"))
            .register(FakeAgent("b"))
            .then("b")
            .build()
            .expect("workflow should build");

        assert_eq!(wf.start(), "a");
        assert_eq!(wf.default_next("a"), Some("b"));
        assert_eq!(wf.default_next("b"), None);
    }

    #[test]
    fn then_on_empty_builder_sets_start() {
        let wf = Workflow::builder("test")
            .then("a")
            .register(FakeAgent("a"))
            .build()
            .expect("workflow should build");

        assert_eq!(wf.start(), "a");
        assert_eq!(wf.default_next("a"), None);
    }

    #[test]
    fn duplicate_agent_rejected() {
        let err = build_err(
            Workflow::builder("test")
                .register(FakeAgent("a"))
                .register(FakeAgent("a")),
        );
        assert!(matches!(err, WorkflowError::DuplicateAgent("a")));
    }
}