1use crate::state::State;
6use std::collections::HashMap;
7use std::sync::Arc;
8
/// Reserved sentinel node name for the start of a graph (the counterpart of [`END`]).
pub const START: &str = "__start__";
/// Reserved sentinel node name for termination.
///
/// [`EdgeTarget::from`] maps this name to [`EdgeTarget::End`], and the field-based
/// router falls back to it when no next node is recorded in the state.
pub const END: &str = "__end__";
12
/// Destination of an edge: a named node or the terminal marker.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum EdgeTarget {
    /// A regular node, identified by its name.
    Node(String),
    /// The terminal target, produced from the [`END`] sentinel name.
    End,
}
21
22impl EdgeTarget {
23 pub fn is_end(&self) -> bool {
25 matches!(self, Self::End)
26 }
27
28 pub fn node_name(&self) -> Option<&str> {
30 match self {
31 Self::Node(name) => Some(name),
32 Self::End => None,
33 }
34 }
35}
36
37impl From<&str> for EdgeTarget {
38 fn from(s: &str) -> Self {
39 if s == END {
40 Self::End
41 } else {
42 Self::Node(s.to_string())
43 }
44 }
45}
46
47pub type RouterFn = Arc<dyn Fn(&State) -> String + Send + Sync>;
49
50#[derive(Clone)]
52pub enum Edge {
53 Direct { source: String, target: EdgeTarget },
55
56 Conditional {
58 source: String,
59 router: RouterFn,
61 targets: HashMap<String, EdgeTarget>,
63 },
64
65 Entry { targets: Vec<String> },
67}
68
69impl std::fmt::Debug for Edge {
70 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71 match self {
72 Self::Direct { source, target } => {
73 f.debug_struct("Direct").field("source", source).field("target", target).finish()
74 }
75 Self::Conditional { source, targets, .. } => f
76 .debug_struct("Conditional")
77 .field("source", source)
78 .field("targets", targets)
79 .finish(),
80 Self::Entry { targets } => f.debug_struct("Entry").field("targets", targets).finish(),
81 }
82 }
83}
84
85pub struct Router;
87
88impl Router {
89 pub fn by_field(field: &str) -> impl Fn(&State) -> String + Send + Sync + Clone {
96 let field = field.to_string();
97 move |state: &State| state.get(&field).and_then(|v| v.as_str()).unwrap_or(END).to_string()
98 }
99
100 pub fn has_tool_calls(
107 messages_field: &str,
108 if_true: &str,
109 if_false: &str,
110 ) -> impl Fn(&State) -> String + Send + Sync + Clone {
111 let messages_field = messages_field.to_string();
112 let if_true = if_true.to_string();
113 let if_false = if_false.to_string();
114
115 move |state: &State| {
116 let has_calls = state
117 .get(&messages_field)
118 .and_then(|v| v.as_array())
119 .and_then(|arr| arr.last())
120 .and_then(|msg| msg.get("tool_calls"))
121 .map(|tc| !tc.as_array().map(|a| a.is_empty()).unwrap_or(true))
122 .unwrap_or(false);
123
124 if has_calls {
125 if_true.clone()
126 } else {
127 if_false.clone()
128 }
129 }
130 }
131
132 pub fn by_bool(
139 field: &str,
140 if_true: &str,
141 if_false: &str,
142 ) -> impl Fn(&State) -> String + Send + Sync + Clone {
143 let field = field.to_string();
144 let if_true = if_true.to_string();
145 let if_false = if_false.to_string();
146
147 move |state: &State| {
148 let is_true = state.get(&field).and_then(|v| v.as_bool()).unwrap_or(false);
149
150 if is_true {
151 if_true.clone()
152 } else {
153 if_false.clone()
154 }
155 }
156 }
157
158 pub fn max_iterations(
165 counter_field: &str,
166 max: usize,
167 continue_target: &str,
168 done_target: &str,
169 ) -> impl Fn(&State) -> String + Send + Sync + Clone {
170 let counter_field = counter_field.to_string();
171 let continue_target = continue_target.to_string();
172 let done_target = done_target.to_string();
173
174 move |state: &State| {
175 let count = state.get(&counter_field).and_then(|v| v.as_u64()).unwrap_or(0) as usize;
176
177 if count < max {
178 continue_target.clone()
179 } else {
180 done_target.clone()
181 }
182 }
183 }
184
185 pub fn on_error(
192 error_field: &str,
193 error_target: &str,
194 success_target: &str,
195 ) -> impl Fn(&State) -> String + Send + Sync + Clone {
196 let error_field = error_field.to_string();
197 let error_target = error_target.to_string();
198 let success_target = success_target.to_string();
199
200 move |state: &State| {
201 let has_error = state.get(&error_field).map(|v| !v.is_null()).unwrap_or(false);
202
203 if has_error {
204 error_target.clone()
205 } else {
206 success_target.clone()
207 }
208 }
209 }
210
211 pub fn custom<F>(f: F) -> impl Fn(&State) -> String + Send + Sync + Clone
224 where
225 F: Fn(&State) -> String + Send + Sync + Clone + 'static,
226 {
227 f
228 }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_by_field_router() {
        let router = Router::by_field("next");

        let mut state = State::new();
        state.insert("next".to_string(), json!("agent_a"));
        assert_eq!(router(&state), "agent_a");

        state.insert("next".to_string(), json!("agent_b"));
        assert_eq!(router(&state), "agent_b");

        // A missing field falls back to END.
        let empty_state = State::new();
        assert_eq!(router(&empty_state), END);
    }

    #[test]
    fn test_by_field_router_non_string_falls_back_to_end() {
        let router = Router::by_field("next");

        let mut state = State::new();
        state.insert("next".to_string(), json!(42));
        assert_eq!(router(&state), END);
    }

    #[test]
    fn test_has_tool_calls_router() {
        let router = Router::has_tool_calls("messages", "tools", END);

        // No messages at all.
        let state = State::new();
        assert_eq!(router(&state), END);

        // Last message without tool calls.
        let mut state = State::new();
        state.insert("messages".to_string(), json!([{"role": "assistant", "content": "Hello"}]));
        assert_eq!(router(&state), END);

        // Last message with a tool call.
        let mut state = State::new();
        state.insert(
            "messages".to_string(),
            json!([{"role": "assistant", "tool_calls": [{"name": "search"}]}]),
        );
        assert_eq!(router(&state), "tools");
    }

    #[test]
    fn test_has_tool_calls_router_edge_cases() {
        let router = Router::has_tool_calls("messages", "tools", END);

        // Empty or non-array `tool_calls` does not count as a tool call.
        for tool_calls in [json!([]), json!(null), json!("search")] {
            let mut state = State::new();
            state.insert(
                "messages".to_string(),
                json!([{"role": "assistant", "tool_calls": tool_calls}]),
            );
            assert_eq!(router(&state), END);
        }

        // Only the last message is inspected.
        let mut state = State::new();
        state.insert(
            "messages".to_string(),
            json!([
                {"role": "assistant", "tool_calls": [{"name": "search"}]},
                {"role": "tool", "content": "result"}
            ]),
        );
        assert_eq!(router(&state), END);

        // An empty message list has no last message.
        let mut state = State::new();
        state.insert("messages".to_string(), json!([]));
        assert_eq!(router(&state), END);
    }

    #[test]
    fn test_by_bool_router() {
        let router = Router::by_bool("should_continue", "continue", "stop");

        let mut state = State::new();
        state.insert("should_continue".to_string(), json!(true));
        assert_eq!(router(&state), "continue");

        state.insert("should_continue".to_string(), json!(false));
        assert_eq!(router(&state), "stop");
    }

    #[test]
    fn test_by_bool_router_missing_or_non_bool_is_false() {
        let router = Router::by_bool("should_continue", "continue", "stop");
        assert_eq!(router(&State::new()), "stop");

        let mut state = State::new();
        state.insert("should_continue".to_string(), json!("true"));
        assert_eq!(router(&state), "stop");
    }

    #[test]
    fn test_max_iterations_router() {
        let router = Router::max_iterations("count", 3, "loop", "done");

        let mut state = State::new();
        state.insert("count".to_string(), json!(0));
        assert_eq!(router(&state), "loop");

        state.insert("count".to_string(), json!(2));
        assert_eq!(router(&state), "loop");

        state.insert("count".to_string(), json!(3));
        assert_eq!(router(&state), "done");
    }

    #[test]
    fn test_max_iterations_router_missing_and_exceeded() {
        let router = Router::max_iterations("count", 3, "loop", "done");
        assert_eq!(router(&State::new()), "loop");

        let mut state = State::new();
        state.insert("count".to_string(), json!(10));
        assert_eq!(router(&state), "done");
    }

    #[test]
    fn test_on_error_router() {
        let router = Router::on_error("error", "handle_error", "next");
        assert_eq!(router(&State::new()), "next");

        let mut state = State::new();
        state.insert("error".to_string(), json!(null));
        assert_eq!(router(&state), "next");

        state.insert("error".to_string(), json!("boom"));
        assert_eq!(router(&state), "handle_error");
    }

    #[test]
    fn test_custom_router_and_router_fn() {
        let custom: RouterFn =
            std::sync::Arc::new(Router::custom(|_: &State| "fixed".to_string()));
        assert_eq!(custom(&State::new()), "fixed");

        // Built-in routers can be type-erased into a RouterFn as well.
        let by_field: RouterFn = std::sync::Arc::new(Router::by_field("next"));
        let mut state = State::new();
        state.insert("next".to_string(), json!("agent_a"));
        assert_eq!(by_field(&state), "agent_a");
    }

    #[test]
    fn test_edge_target_from_str() {
        assert_eq!(EdgeTarget::from("node_a"), EdgeTarget::Node("node_a".to_string()));
        assert_eq!(EdgeTarget::from(END), EdgeTarget::End);
    }

    #[test]
    fn test_edge_target_accessors() {
        let node = EdgeTarget::Node("node_a".to_string());
        assert!(!node.is_end());
        assert_eq!(node.node_name(), Some("node_a"));

        assert!(EdgeTarget::End.is_end());
        assert_eq!(EdgeTarget::End.node_name(), None);
    }
}