adk_graph/
edge.rs

1//! Edge types for graph control flow
2//!
3//! Edges define how execution flows between nodes.
4
5use crate::state::State;
6use std::collections::HashMap;
7use std::sync::Arc;
8
/// Reserved identifier for the start of the graph.
///
/// Entry edges ([`Edge::Entry`]) lead from this point to the first node(s) to run.
pub const START: &str = "__start__";

/// Reserved identifier for the end of the graph.
///
/// [`EdgeTarget::from`] maps this exact string to [`EdgeTarget::End`], and several
/// [`Router`] helpers (e.g. [`Router::by_field`]) fall back to it.
pub const END: &str = "__end__";
12
/// Destination of an edge: either a named node or the end of the graph.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EdgeTarget {
    /// Continue execution at the node with this name.
    Node(String),
    /// Terminate the graph; the reserved `END` identifier converts to this variant.
    End,
}
21
22impl EdgeTarget {
23    /// Check if this is the END target
24    pub fn is_end(&self) -> bool {
25        matches!(self, Self::End)
26    }
27
28    /// Get the node name if this is a Node target
29    pub fn node_name(&self) -> Option<&str> {
30        match self {
31            Self::Node(name) => Some(name),
32            Self::End => None,
33        }
34    }
35}
36
37impl From<&str> for EdgeTarget {
38    fn from(s: &str) -> Self {
39        if s == END {
40            Self::End
41        } else {
42            Self::Node(s.to_string())
43        }
44    }
45}
46
47/// Router function type
48pub type RouterFn = Arc<dyn Fn(&State) -> String + Send + Sync>;
49
50/// Edge type
51#[derive(Clone)]
52pub enum Edge {
53    /// Direct edge: always go from source to target
54    Direct { source: String, target: EdgeTarget },
55
56    /// Conditional edge: route based on state
57    Conditional {
58        source: String,
59        /// Router function returns target node name or END
60        router: RouterFn,
61        /// Map of route names to targets (for validation and documentation)
62        targets: HashMap<String, EdgeTarget>,
63    },
64
65    /// Entry edge: from START to first node(s)
66    Entry { targets: Vec<String> },
67}
68
69impl std::fmt::Debug for Edge {
70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        match self {
72            Self::Direct { source, target } => {
73                f.debug_struct("Direct").field("source", source).field("target", target).finish()
74            }
75            Self::Conditional { source, targets, .. } => f
76                .debug_struct("Conditional")
77                .field("source", source)
78                .field("targets", targets)
79                .finish(),
80            Self::Entry { targets } => f.debug_struct("Entry").field("targets", targets).finish(),
81        }
82    }
83}
84
/// Router helper functions for common patterns.
///
/// `Router` has no fields and serves as a namespace: its associated functions each build a
/// `Fn(&State) -> String` closure meant to be passed to `conditional_edge` (see the examples
/// on the individual functions).
pub struct Router;
87
88impl Router {
89    /// Route based on a state field value
90    ///
91    /// # Example
92    /// ```ignore
93    /// .conditional_edge("supervisor", Router::by_field("next_agent"), targets)
94    /// ```
95    pub fn by_field(field: &str) -> impl Fn(&State) -> String + Send + Sync + Clone {
96        let field = field.to_string();
97        move |state: &State| state.get(&field).and_then(|v| v.as_str()).unwrap_or(END).to_string()
98    }
99
100    /// Route based on whether the last message has tool calls
101    ///
102    /// # Example
103    /// ```ignore
104    /// .conditional_edge("agent", Router::has_tool_calls("messages", "tools", END), targets)
105    /// ```
106    pub fn has_tool_calls(
107        messages_field: &str,
108        if_true: &str,
109        if_false: &str,
110    ) -> impl Fn(&State) -> String + Send + Sync + Clone {
111        let messages_field = messages_field.to_string();
112        let if_true = if_true.to_string();
113        let if_false = if_false.to_string();
114
115        move |state: &State| {
116            let has_calls = state
117                .get(&messages_field)
118                .and_then(|v| v.as_array())
119                .and_then(|arr| arr.last())
120                .and_then(|msg| msg.get("tool_calls"))
121                .map(|tc| !tc.as_array().map(|a| a.is_empty()).unwrap_or(true))
122                .unwrap_or(false);
123
124            if has_calls {
125                if_true.clone()
126            } else {
127                if_false.clone()
128            }
129        }
130    }
131
132    /// Route based on a boolean state field
133    ///
134    /// # Example
135    /// ```ignore
136    /// .conditional_edge("check", Router::by_bool("should_continue", "process", END), targets)
137    /// ```
138    pub fn by_bool(
139        field: &str,
140        if_true: &str,
141        if_false: &str,
142    ) -> impl Fn(&State) -> String + Send + Sync + Clone {
143        let field = field.to_string();
144        let if_true = if_true.to_string();
145        let if_false = if_false.to_string();
146
147        move |state: &State| {
148            let is_true = state.get(&field).and_then(|v| v.as_bool()).unwrap_or(false);
149
150            if is_true {
151                if_true.clone()
152            } else {
153                if_false.clone()
154            }
155        }
156    }
157
158    /// Route based on iteration count
159    ///
160    /// # Example
161    /// ```ignore
162    /// .conditional_edge("loop", Router::max_iterations("iteration", 5, "continue", "done"), targets)
163    /// ```
164    pub fn max_iterations(
165        counter_field: &str,
166        max: usize,
167        continue_target: &str,
168        done_target: &str,
169    ) -> impl Fn(&State) -> String + Send + Sync + Clone {
170        let counter_field = counter_field.to_string();
171        let continue_target = continue_target.to_string();
172        let done_target = done_target.to_string();
173
174        move |state: &State| {
175            let count = state.get(&counter_field).and_then(|v| v.as_u64()).unwrap_or(0) as usize;
176
177            if count < max {
178                continue_target.clone()
179            } else {
180                done_target.clone()
181            }
182        }
183    }
184
185    /// Route based on the presence of an error
186    ///
187    /// # Example
188    /// ```ignore
189    /// .conditional_edge("process", Router::on_error("error", "error_handler", "success"), targets)
190    /// ```
191    pub fn on_error(
192        error_field: &str,
193        error_target: &str,
194        success_target: &str,
195    ) -> impl Fn(&State) -> String + Send + Sync + Clone {
196        let error_field = error_field.to_string();
197        let error_target = error_target.to_string();
198        let success_target = success_target.to_string();
199
200        move |state: &State| {
201            let has_error = state.get(&error_field).map(|v| !v.is_null()).unwrap_or(false);
202
203            if has_error {
204                error_target.clone()
205            } else {
206                success_target.clone()
207            }
208        }
209    }
210
211    /// Create a custom router from a closure
212    ///
213    /// # Example
214    /// ```ignore
215    /// .conditional_edge("agent", Router::custom(|state| {
216    ///     if state.get("done").and_then(|v| v.as_bool()).unwrap_or(false) {
217    ///         END.to_string()
218    ///     } else {
219    ///         "continue".to_string()
220    ///     }
221    /// }), targets)
222    /// ```
223    pub fn custom<F>(f: F) -> impl Fn(&State) -> String + Send + Sync + Clone
224    where
225        F: Fn(&State) -> String + Send + Sync + Clone + 'static,
226    {
227        f
228    }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_by_field_router() {
        let router = Router::by_field("next");

        let mut state = State::new();
        state.insert("next".to_string(), json!("agent_a"));
        assert_eq!(router(&state), "agent_a");

        state.insert("next".to_string(), json!("agent_b"));
        assert_eq!(router(&state), "agent_b");

        // Missing field returns END
        let empty_state = State::new();
        assert_eq!(router(&empty_state), END);
    }

    #[test]
    fn test_by_field_router_non_string_returns_end() {
        let router = Router::by_field("next");

        let mut state = State::new();
        state.insert("next".to_string(), json!(42));
        assert_eq!(router(&state), END);
    }

    #[test]
    fn test_has_tool_calls_router() {
        let router = Router::has_tool_calls("messages", "tools", END);

        // No messages
        let state = State::new();
        assert_eq!(router(&state), END);

        // Messages without tool calls
        let mut state = State::new();
        state.insert("messages".to_string(), json!([{"role": "assistant", "content": "Hello"}]));
        assert_eq!(router(&state), END);

        // Messages with tool calls
        let mut state = State::new();
        state.insert(
            "messages".to_string(),
            json!([{"role": "assistant", "tool_calls": [{"name": "search"}]}]),
        );
        assert_eq!(router(&state), "tools");
    }

    #[test]
    fn test_has_tool_calls_router_edge_cases() {
        let router = Router::has_tool_calls("messages", "tools", END);

        // An empty tool_calls array is not a tool call
        let mut state = State::new();
        state.insert("messages".to_string(), json!([{"role": "assistant", "tool_calls": []}]));
        assert_eq!(router(&state), END);

        // Only the last message is inspected
        state.insert(
            "messages".to_string(),
            json!([
                {"role": "assistant", "tool_calls": [{"name": "search"}]},
                {"role": "tool", "content": "result"}
            ]),
        );
        assert_eq!(router(&state), END);

        // Empty message list
        state.insert("messages".to_string(), json!([]));
        assert_eq!(router(&state), END);
    }

    #[test]
    fn test_by_bool_router() {
        let router = Router::by_bool("should_continue", "continue", "stop");

        let mut state = State::new();
        state.insert("should_continue".to_string(), json!(true));
        assert_eq!(router(&state), "continue");

        state.insert("should_continue".to_string(), json!(false));
        assert_eq!(router(&state), "stop");
    }

    #[test]
    fn test_by_bool_router_missing_or_non_bool_is_false() {
        let router = Router::by_bool("should_continue", "continue", "stop");

        assert_eq!(router(&State::new()), "stop");

        let mut state = State::new();
        state.insert("should_continue".to_string(), json!("true"));
        assert_eq!(router(&state), "stop");
    }

    #[test]
    fn test_max_iterations_router() {
        let router = Router::max_iterations("count", 3, "loop", "done");

        let mut state = State::new();
        state.insert("count".to_string(), json!(0));
        assert_eq!(router(&state), "loop");

        state.insert("count".to_string(), json!(2));
        assert_eq!(router(&state), "loop");

        state.insert("count".to_string(), json!(3));
        assert_eq!(router(&state), "done");
    }

    #[test]
    fn test_max_iterations_router_fallbacks() {
        let router = Router::max_iterations("count", 3, "loop", "done");

        // A missing counter is treated as zero
        assert_eq!(router(&State::new()), "loop");

        // Very large counters must stop the loop
        let mut state = State::new();
        state.insert("count".to_string(), json!(u64::MAX));
        assert_eq!(router(&state), "done");
    }

    #[test]
    fn test_on_error_router() {
        let router = Router::on_error("error", "error_handler", "success");

        // Missing error field
        assert_eq!(router(&State::new()), "success");

        // Explicit null means no error
        let mut state = State::new();
        state.insert("error".to_string(), json!(null));
        assert_eq!(router(&state), "success");

        // Any non-null value is an error
        state.insert("error".to_string(), json!("boom"));
        assert_eq!(router(&state), "error_handler");
    }

    #[test]
    fn test_custom_router() {
        let router = Router::custom(|state: &State| {
            if state.get("done").and_then(|v| v.as_bool()).unwrap_or(false) {
                END.to_string()
            } else {
                "continue".to_string()
            }
        });

        let mut state = State::new();
        assert_eq!(router(&state), "continue");

        state.insert("done".to_string(), json!(true));
        assert_eq!(router(&state), END);
    }

    #[test]
    fn test_edge_target_from_str() {
        assert_eq!(EdgeTarget::from("node_a"), EdgeTarget::Node("node_a".to_string()));
        assert_eq!(EdgeTarget::from(END), EdgeTarget::End);
    }

    #[test]
    fn test_edge_target_helpers() {
        let node = EdgeTarget::Node("worker".to_string());
        assert!(!node.is_end());
        assert_eq!(node.node_name(), Some("worker"));

        assert!(EdgeTarget::End.is_end());
        assert_eq!(EdgeTarget::End.node_name(), None);
    }
}