1use crate::state::State;
6use std::collections::HashMap;
7use std::sync::Arc;
8
/// Reserved node name marking the start of a graph.
pub const START: &str = "__start__";

/// Reserved node name marking the end of a graph.
///
/// [`EdgeTarget::from`] maps this exact string to [`EdgeTarget::End`].
pub const END: &str = "__end__";

/// Destination of an edge: either a named node or the terminal [`END`] marker.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EdgeTarget {
    /// A regular node, identified by its name.
    Node(String),
    /// The terminal marker that corresponds to [`END`].
    End,
}

impl EdgeTarget {
    /// Returns `true` only for [`EdgeTarget::End`].
    pub fn is_end(&self) -> bool {
        match self {
            Self::End => true,
            Self::Node(_) => false,
        }
    }

    /// Returns the target node's name, or `None` for [`EdgeTarget::End`].
    pub fn node_name(&self) -> Option<&str> {
        if let Self::Node(name) = self {
            Some(name.as_str())
        } else {
            None
        }
    }
}

impl From<&str> for EdgeTarget {
    /// Converts [`END`] into [`EdgeTarget::End`]; every other string (including
    /// [`START`]) becomes an [`EdgeTarget::Node`] with an owned copy of the name.
    fn from(s: &str) -> Self {
        match s {
            END => Self::End,
            other => Self::Node(other.to_owned()),
        }
    }
}
42
43pub type RouterFn = Arc<dyn Fn(&State) -> String + Send + Sync>;
45
46#[derive(Clone)]
48pub enum Edge {
49 Direct { source: String, target: EdgeTarget },
51
52 Conditional {
54 source: String,
55 router: RouterFn,
57 targets: HashMap<String, EdgeTarget>,
59 },
60
61 Entry { targets: Vec<String> },
63}
64
65impl std::fmt::Debug for Edge {
66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67 match self {
68 Self::Direct { source, target } => {
69 f.debug_struct("Direct").field("source", source).field("target", target).finish()
70 }
71 Self::Conditional { source, targets, .. } => f
72 .debug_struct("Conditional")
73 .field("source", source)
74 .field("targets", targets)
75 .finish(),
76 Self::Entry { targets } => f.debug_struct("Entry").field("targets", targets).finish(),
77 }
78 }
79}
80
81pub struct Router;
83
84impl Router {
85 pub fn by_field(field: &str) -> impl Fn(&State) -> String + Send + Sync + Clone {
92 let field = field.to_string();
93 move |state: &State| state.get(&field).and_then(|v| v.as_str()).unwrap_or(END).to_string()
94 }
95
96 pub fn has_tool_calls(
103 messages_field: &str,
104 if_true: &str,
105 if_false: &str,
106 ) -> impl Fn(&State) -> String + Send + Sync + Clone {
107 let messages_field = messages_field.to_string();
108 let if_true = if_true.to_string();
109 let if_false = if_false.to_string();
110
111 move |state: &State| {
112 let has_calls = state
113 .get(&messages_field)
114 .and_then(|v| v.as_array())
115 .and_then(|arr| arr.last())
116 .and_then(|msg| msg.get("tool_calls"))
117 .map(|tc| !tc.as_array().map(|a| a.is_empty()).unwrap_or(true))
118 .unwrap_or(false);
119
120 if has_calls { if_true.clone() } else { if_false.clone() }
121 }
122 }
123
124 pub fn by_bool(
131 field: &str,
132 if_true: &str,
133 if_false: &str,
134 ) -> impl Fn(&State) -> String + Send + Sync + Clone {
135 let field = field.to_string();
136 let if_true = if_true.to_string();
137 let if_false = if_false.to_string();
138
139 move |state: &State| {
140 let is_true = state.get(&field).and_then(|v| v.as_bool()).unwrap_or(false);
141
142 if is_true { if_true.clone() } else { if_false.clone() }
143 }
144 }
145
146 pub fn max_iterations(
153 counter_field: &str,
154 max: usize,
155 continue_target: &str,
156 done_target: &str,
157 ) -> impl Fn(&State) -> String + Send + Sync + Clone {
158 let counter_field = counter_field.to_string();
159 let continue_target = continue_target.to_string();
160 let done_target = done_target.to_string();
161
162 move |state: &State| {
163 let count = state.get(&counter_field).and_then(|v| v.as_u64()).unwrap_or(0) as usize;
164
165 if count < max { continue_target.clone() } else { done_target.clone() }
166 }
167 }
168
169 pub fn on_error(
176 error_field: &str,
177 error_target: &str,
178 success_target: &str,
179 ) -> impl Fn(&State) -> String + Send + Sync + Clone {
180 let error_field = error_field.to_string();
181 let error_target = error_target.to_string();
182 let success_target = success_target.to_string();
183
184 move |state: &State| {
185 let has_error = state.get(&error_field).map(|v| !v.is_null()).unwrap_or(false);
186
187 if has_error { error_target.clone() } else { success_target.clone() }
188 }
189 }
190
191 pub fn custom<F>(f: F) -> impl Fn(&State) -> String + Send + Sync + Clone
204 where
205 F: Fn(&State) -> String + Send + Sync + Clone + 'static,
206 {
207 f
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_by_field_router() {
        let router = Router::by_field("next");

        let mut state = State::new();
        state.insert("next".to_string(), json!("agent_a"));
        assert_eq!(router(&state), "agent_a");

        state.insert("next".to_string(), json!("agent_b"));
        assert_eq!(router(&state), "agent_b");

        // A missing field routes to END.
        let empty_state = State::new();
        assert_eq!(router(&empty_state), END);
    }

    #[test]
    fn test_by_field_router_non_string_falls_back_to_end() {
        let router = Router::by_field("next");

        let mut state = State::new();
        state.insert("next".to_string(), json!(42));
        assert_eq!(router(&state), END);
    }

    #[test]
    fn test_has_tool_calls_router() {
        let router = Router::has_tool_calls("messages", "tools", END);

        // No messages field at all.
        let state = State::new();
        assert_eq!(router(&state), END);

        // Last message has no tool_calls key.
        let mut state = State::new();
        state.insert("messages".to_string(), json!([{"role": "assistant", "content": "Hello"}]));
        assert_eq!(router(&state), END);

        // Last message has a non-empty tool_calls array.
        let mut state = State::new();
        state.insert(
            "messages".to_string(),
            json!([{"role": "assistant", "tool_calls": [{"name": "search"}]}]),
        );
        assert_eq!(router(&state), "tools");
    }

    #[test]
    fn test_has_tool_calls_router_empty_or_invalid_tool_calls() {
        let router = Router::has_tool_calls("messages", "tools", END);

        let mut state = State::new();
        state.insert("messages".to_string(), json!([]));
        assert_eq!(router(&state), END);

        state.insert("messages".to_string(), json!([{"role": "assistant", "tool_calls": []}]));
        assert_eq!(router(&state), END);

        state.insert("messages".to_string(), json!([{"role": "assistant", "tool_calls": null}]));
        assert_eq!(router(&state), END);

        // Only the most recent message is inspected.
        state.insert(
            "messages".to_string(),
            json!([
                {"role": "assistant", "tool_calls": [{"name": "search"}]},
                {"role": "tool", "content": "result"}
            ]),
        );
        assert_eq!(router(&state), END);
    }

    #[test]
    fn test_by_bool_router() {
        let router = Router::by_bool("should_continue", "continue", "stop");

        let mut state = State::new();
        state.insert("should_continue".to_string(), json!(true));
        assert_eq!(router(&state), "continue");

        state.insert("should_continue".to_string(), json!(false));
        assert_eq!(router(&state), "stop");
    }

    #[test]
    fn test_by_bool_router_defaults_to_false() {
        let router = Router::by_bool("should_continue", "continue", "stop");
        assert_eq!(router(&State::new()), "stop");

        let mut state = State::new();
        state.insert("should_continue".to_string(), json!("yes"));
        assert_eq!(router(&state), "stop");
    }

    #[test]
    fn test_max_iterations_router() {
        let router = Router::max_iterations("count", 3, "loop", "done");

        let mut state = State::new();
        state.insert("count".to_string(), json!(0));
        assert_eq!(router(&state), "loop");

        state.insert("count".to_string(), json!(2));
        assert_eq!(router(&state), "loop");

        state.insert("count".to_string(), json!(3));
        assert_eq!(router(&state), "done");
    }

    #[test]
    fn test_max_iterations_router_missing_and_large_counter() {
        let router = Router::max_iterations("count", 3, "loop", "done");

        // A missing counter counts as 0.
        assert_eq!(router(&State::new()), "loop");

        let mut state = State::new();
        state.insert("count".to_string(), json!(10));
        assert_eq!(router(&state), "done");
    }

    #[test]
    fn test_on_error_router() {
        let router = Router::on_error("error", "handle_error", "next");

        assert_eq!(router(&State::new()), "next");

        let mut state = State::new();
        state.insert("error".to_string(), json!(null));
        assert_eq!(router(&state), "next");

        state.insert("error".to_string(), json!("boom"));
        assert_eq!(router(&state), "handle_error");
    }

    #[test]
    fn test_custom_router() {
        let router = Router::custom(|_: &State| "fixed".to_string());
        assert_eq!(router(&State::new()), "fixed");
    }

    #[test]
    fn test_edge_target_from_str() {
        assert_eq!(EdgeTarget::from("node_a"), EdgeTarget::Node("node_a".to_string()));
        assert_eq!(EdgeTarget::from(END), EdgeTarget::End);
    }

    #[test]
    fn test_edge_target_accessors() {
        let node = EdgeTarget::from("node_a");
        assert!(!node.is_end());
        assert_eq!(node.node_name(), Some("node_a"));

        let end = EdgeTarget::from(END);
        assert!(end.is_end());
        assert_eq!(end.node_name(), None);
    }

    #[test]
    fn test_direct_edge_debug() {
        let edge = Edge::Direct { source: "a".to_string(), target: EdgeTarget::End };
        assert_eq!(format!("{:?}", edge), r#"Direct { source: "a", target: End }"#);
    }
}