adk_graph/
edge.rs

//! Edge types for graph control flow
//!
//! Edges define how execution flows between nodes.
4
5use crate::state::State;
6use std::collections::HashMap;
7use std::sync::Arc;
8
/// Special node identifier for the virtual entry point of a graph.
pub const START: &str = "__start__";

/// Special node identifier for the terminal point of a graph.
///
/// Routers return this value to stop execution, and converting it into an
/// [`EdgeTarget`] yields [`EdgeTarget::End`].
pub const END: &str = "__end__";
12
/// Destination of an edge: either a named node or the graph's terminal end.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum EdgeTarget {
    /// Specific node
    Node(String),
    /// End of graph
    End,
}

impl EdgeTarget {
    /// Returns `true` only for [`EdgeTarget::End`].
    pub fn is_end(&self) -> bool {
        match self {
            EdgeTarget::End => true,
            EdgeTarget::Node(_) => false,
        }
    }

    /// Returns the node name of an [`EdgeTarget::Node`], or `None` for the end target.
    pub fn node_name(&self) -> Option<&str> {
        if let EdgeTarget::Node(name) = self {
            Some(name.as_str())
        } else {
            None
        }
    }
}

impl From<&str> for EdgeTarget {
    /// Converts a node identifier into a target.
    ///
    /// The [`END`] sentinel becomes [`EdgeTarget::End`]; every other string
    /// (including [`START`]) becomes an [`EdgeTarget::Node`].
    fn from(s: &str) -> Self {
        match s {
            END => EdgeTarget::End,
            other => EdgeTarget::Node(other.to_owned()),
        }
    }
}
42
43/// Router function type
44pub type RouterFn = Arc<dyn Fn(&State) -> String + Send + Sync>;
45
46/// Edge type
47#[derive(Clone)]
48pub enum Edge {
49    /// Direct edge: always go from source to target
50    Direct { source: String, target: EdgeTarget },
51
52    /// Conditional edge: route based on state
53    Conditional {
54        source: String,
55        /// Router function returns target node name or END
56        router: RouterFn,
57        /// Map of route names to targets (for validation and documentation)
58        targets: HashMap<String, EdgeTarget>,
59    },
60
61    /// Entry edge: from START to first node(s)
62    Entry { targets: Vec<String> },
63}
64
65impl std::fmt::Debug for Edge {
66    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67        match self {
68            Self::Direct { source, target } => {
69                f.debug_struct("Direct").field("source", source).field("target", target).finish()
70            }
71            Self::Conditional { source, targets, .. } => f
72                .debug_struct("Conditional")
73                .field("source", source)
74                .field("targets", targets)
75                .finish(),
76            Self::Entry { targets } => f.debug_struct("Entry").field("targets", targets).finish(),
77        }
78    }
79}
80
81/// Router helper functions for common patterns
82pub struct Router;
83
84impl Router {
85    /// Route based on a state field value
86    ///
87    /// # Example
88    /// ```ignore
89    /// .conditional_edge("supervisor", Router::by_field("next_agent"), targets)
90    /// ```
91    pub fn by_field(field: &str) -> impl Fn(&State) -> String + Send + Sync + Clone {
92        let field = field.to_string();
93        move |state: &State| state.get(&field).and_then(|v| v.as_str()).unwrap_or(END).to_string()
94    }
95
96    /// Route based on whether the last message has tool calls
97    ///
98    /// # Example
99    /// ```ignore
100    /// .conditional_edge("agent", Router::has_tool_calls("messages", "tools", END), targets)
101    /// ```
102    pub fn has_tool_calls(
103        messages_field: &str,
104        if_true: &str,
105        if_false: &str,
106    ) -> impl Fn(&State) -> String + Send + Sync + Clone {
107        let messages_field = messages_field.to_string();
108        let if_true = if_true.to_string();
109        let if_false = if_false.to_string();
110
111        move |state: &State| {
112            let has_calls = state
113                .get(&messages_field)
114                .and_then(|v| v.as_array())
115                .and_then(|arr| arr.last())
116                .and_then(|msg| msg.get("tool_calls"))
117                .map(|tc| !tc.as_array().map(|a| a.is_empty()).unwrap_or(true))
118                .unwrap_or(false);
119
120            if has_calls { if_true.clone() } else { if_false.clone() }
121        }
122    }
123
124    /// Route based on a boolean state field
125    ///
126    /// # Example
127    /// ```ignore
128    /// .conditional_edge("check", Router::by_bool("should_continue", "process", END), targets)
129    /// ```
130    pub fn by_bool(
131        field: &str,
132        if_true: &str,
133        if_false: &str,
134    ) -> impl Fn(&State) -> String + Send + Sync + Clone {
135        let field = field.to_string();
136        let if_true = if_true.to_string();
137        let if_false = if_false.to_string();
138
139        move |state: &State| {
140            let is_true = state.get(&field).and_then(|v| v.as_bool()).unwrap_or(false);
141
142            if is_true { if_true.clone() } else { if_false.clone() }
143        }
144    }
145
146    /// Route based on iteration count
147    ///
148    /// # Example
149    /// ```ignore
150    /// .conditional_edge("loop", Router::max_iterations("iteration", 5, "continue", "done"), targets)
151    /// ```
152    pub fn max_iterations(
153        counter_field: &str,
154        max: usize,
155        continue_target: &str,
156        done_target: &str,
157    ) -> impl Fn(&State) -> String + Send + Sync + Clone {
158        let counter_field = counter_field.to_string();
159        let continue_target = continue_target.to_string();
160        let done_target = done_target.to_string();
161
162        move |state: &State| {
163            let count = state.get(&counter_field).and_then(|v| v.as_u64()).unwrap_or(0) as usize;
164
165            if count < max { continue_target.clone() } else { done_target.clone() }
166        }
167    }
168
169    /// Route based on the presence of an error
170    ///
171    /// # Example
172    /// ```ignore
173    /// .conditional_edge("process", Router::on_error("error", "error_handler", "success"), targets)
174    /// ```
175    pub fn on_error(
176        error_field: &str,
177        error_target: &str,
178        success_target: &str,
179    ) -> impl Fn(&State) -> String + Send + Sync + Clone {
180        let error_field = error_field.to_string();
181        let error_target = error_target.to_string();
182        let success_target = success_target.to_string();
183
184        move |state: &State| {
185            let has_error = state.get(&error_field).map(|v| !v.is_null()).unwrap_or(false);
186
187            if has_error { error_target.clone() } else { success_target.clone() }
188        }
189    }
190
191    /// Create a custom router from a closure
192    ///
193    /// # Example
194    /// ```ignore
195    /// .conditional_edge("agent", Router::custom(|state| {
196    ///     if state.get("done").and_then(|v| v.as_bool()).unwrap_or(false) {
197    ///         END.to_string()
198    ///     } else {
199    ///         "continue".to_string()
200    ///     }
201    /// }), targets)
202    /// ```
203    pub fn custom<F>(f: F) -> impl Fn(&State) -> String + Send + Sync + Clone
204    where
205        F: Fn(&State) -> String + Send + Sync + Clone + 'static,
206    {
207        f
208    }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_by_field_router() {
        let router = Router::by_field("next");

        let mut state = State::new();
        state.insert("next".to_string(), json!("agent_a"));
        assert_eq!(router(&state), "agent_a");

        state.insert("next".to_string(), json!("agent_b"));
        assert_eq!(router(&state), "agent_b");

        // Non-string value falls back to END
        state.insert("next".to_string(), json!(42));
        assert_eq!(router(&state), END);

        // Missing field returns END
        let empty_state = State::new();
        assert_eq!(router(&empty_state), END);
    }

    #[test]
    fn test_has_tool_calls_router() {
        let router = Router::has_tool_calls("messages", "tools", END);

        // No messages
        let state = State::new();
        assert_eq!(router(&state), END);

        // Messages without tool calls
        let mut state = State::new();
        state.insert("messages".to_string(), json!([{"role": "assistant", "content": "Hello"}]));
        assert_eq!(router(&state), END);

        // Messages with tool calls
        let mut state = State::new();
        state.insert(
            "messages".to_string(),
            json!([{"role": "assistant", "tool_calls": [{"name": "search"}]}]),
        );
        assert_eq!(router(&state), "tools");

        // Only the last message matters, and an empty tool_calls array means "no calls"
        let mut state = State::new();
        state.insert(
            "messages".to_string(),
            json!([
                {"role": "assistant", "tool_calls": [{"name": "search"}]},
                {"role": "assistant", "tool_calls": []}
            ]),
        );
        assert_eq!(router(&state), END);
    }

    #[test]
    fn test_by_bool_router() {
        let router = Router::by_bool("should_continue", "continue", "stop");

        // Missing field routes to the false branch
        assert_eq!(router(&State::new()), "stop");

        let mut state = State::new();
        state.insert("should_continue".to_string(), json!(true));
        assert_eq!(router(&state), "continue");

        state.insert("should_continue".to_string(), json!(false));
        assert_eq!(router(&state), "stop");
    }

    #[test]
    fn test_max_iterations_router() {
        let router = Router::max_iterations("count", 3, "loop", "done");

        // Missing counter counts as zero
        assert_eq!(router(&State::new()), "loop");

        let mut state = State::new();
        state.insert("count".to_string(), json!(0));
        assert_eq!(router(&state), "loop");

        state.insert("count".to_string(), json!(2));
        assert_eq!(router(&state), "loop");

        state.insert("count".to_string(), json!(3));
        assert_eq!(router(&state), "done");

        state.insert("count".to_string(), json!(10));
        assert_eq!(router(&state), "done");
    }

    #[test]
    fn test_on_error_router() {
        let router = Router::on_error("error", "handler", "ok");

        let mut state = State::new();
        assert_eq!(router(&state), "ok");

        state.insert("error".to_string(), json!(null));
        assert_eq!(router(&state), "ok");

        state.insert("error".to_string(), json!("boom"));
        assert_eq!(router(&state), "handler");
    }

    #[test]
    fn test_custom_router() {
        let router = Router::custom(|state: &State| {
            if state.get("done").and_then(|v| v.as_bool()).unwrap_or(false) {
                END.to_string()
            } else {
                "continue".to_string()
            }
        });

        let mut state = State::new();
        assert_eq!(router(&state), "continue");

        state.insert("done".to_string(), json!(true));
        assert_eq!(router(&state), END);
    }

    #[test]
    fn test_edge_target_from_str() {
        assert_eq!(EdgeTarget::from("node_a"), EdgeTarget::Node("node_a".to_string()));
        assert_eq!(EdgeTarget::from(END), EdgeTarget::End);
    }

    #[test]
    fn test_edge_target_accessors() {
        assert!(EdgeTarget::End.is_end());
        assert_eq!(EdgeTarget::End.node_name(), None);

        let node = EdgeTarget::Node("a".to_string());
        assert!(!node.is_end());
        assert_eq!(node.node_name(), Some("a"));
    }

    #[test]
    fn test_direct_edge_debug() {
        let edge = Edge::Direct { source: "a".to_string(), target: EdgeTarget::End };
        assert_eq!(format!("{:?}", edge), r#"Direct { source: "a", target: End }"#);
    }
}