Skip to main content

adk_graph/
graph.rs

1//! StateGraph builder for constructing graphs
2
3use crate::checkpoint::Checkpointer;
4use crate::edge::{END, Edge, EdgeTarget, RouterFn, START};
5use crate::error::{GraphError, Result};
6use crate::node::{FunctionNode, Node, NodeContext, NodeOutput};
7use crate::state::{State, StateSchema};
8use std::collections::{HashMap, HashSet};
9use std::future::Future;
10use std::sync::Arc;
11
12/// Builder for constructing graphs
13pub struct StateGraph {
14    /// State schema
15    pub schema: StateSchema,
16    /// Registered nodes
17    pub nodes: HashMap<String, Arc<dyn Node>>,
18    /// Registered edges
19    pub edges: Vec<Edge>,
20}
21
22impl StateGraph {
23    /// Create a new graph with the given state schema
24    pub fn new(schema: StateSchema) -> Self {
25        Self { schema, nodes: HashMap::new(), edges: vec![] }
26    }
27
28    /// Create with a simple schema (just channel names, all overwrite)
29    pub fn with_channels(channels: &[&str]) -> Self {
30        Self::new(StateSchema::simple(channels))
31    }
32
33    /// Add a node to the graph
34    pub fn add_node<N: Node + 'static>(mut self, node: N) -> Self {
35        self.nodes.insert(node.name().to_string(), Arc::new(node));
36        self
37    }
38
39    /// Add a function as a node
40    pub fn add_node_fn<F, Fut>(self, name: &str, func: F) -> Self
41    where
42        F: Fn(NodeContext) -> Fut + Send + Sync + 'static,
43        Fut: Future<Output = Result<NodeOutput>> + Send + 'static,
44    {
45        self.add_node(FunctionNode::new(name, func))
46    }
47
48    /// Add a direct edge from source to target
49    pub fn add_edge(mut self, source: &str, target: &str) -> Self {
50        let target = EdgeTarget::from(target);
51
52        if source == START {
53            // Find existing entry or create new one
54            let entry_idx = self.edges.iter().position(|e| matches!(e, Edge::Entry { .. }));
55
56            match entry_idx {
57                Some(idx) => {
58                    if let Edge::Entry { targets } = &mut self.edges[idx]
59                        && let EdgeTarget::Node(node) = &target
60                        && !targets.contains(node)
61                    {
62                        targets.push(node.clone());
63                    }
64                }
65                None => {
66                    if let EdgeTarget::Node(node) = target {
67                        self.edges.push(Edge::Entry { targets: vec![node] });
68                    }
69                }
70            }
71        } else {
72            self.edges.push(Edge::Direct { source: source.to_string(), target });
73        }
74
75        self
76    }
77
78    /// Add a conditional edge with a router function
79    pub fn add_conditional_edges<F, I>(mut self, source: &str, router: F, targets: I) -> Self
80    where
81        F: Fn(&State) -> String + Send + Sync + 'static,
82        I: IntoIterator<Item = (&'static str, &'static str)>,
83    {
84        let targets_map: HashMap<String, EdgeTarget> =
85            targets.into_iter().map(|(k, v)| (k.to_string(), EdgeTarget::from(v))).collect();
86
87        self.edges.push(Edge::Conditional {
88            source: source.to_string(),
89            router: Arc::new(router),
90            targets: targets_map,
91        });
92
93        self
94    }
95
96    /// Add a conditional edge with an Arc router (for pre-built routers)
97    pub fn add_conditional_edges_arc<I>(
98        mut self,
99        source: &str,
100        router: RouterFn,
101        targets: I,
102    ) -> Self
103    where
104        I: IntoIterator<Item = (&'static str, &'static str)>,
105    {
106        let targets_map: HashMap<String, EdgeTarget> =
107            targets.into_iter().map(|(k, v)| (k.to_string(), EdgeTarget::from(v))).collect();
108
109        self.edges.push(Edge::Conditional {
110            source: source.to_string(),
111            router,
112            targets: targets_map,
113        });
114
115        self
116    }
117
118    /// Compile the graph for execution
119    pub fn compile(self) -> Result<CompiledGraph> {
120        self.validate()?;
121
122        Ok(CompiledGraph {
123            schema: self.schema,
124            nodes: self.nodes,
125            edges: self.edges,
126            checkpointer: None,
127            interrupt_before: HashSet::new(),
128            interrupt_after: HashSet::new(),
129            recursion_limit: 50,
130            timeout_policies: HashMap::new(),
131            default_timeout: None,
132            deferred_configs: HashMap::new(),
133            #[cfg(feature = "node-cache")]
134            cache_policies: HashMap::new(),
135        })
136    }
137
138    /// Validate the graph structure
139    fn validate(&self) -> Result<()> {
140        // Check for entry point
141        let has_entry = self.edges.iter().any(|e| matches!(e, Edge::Entry { .. }));
142        if !has_entry {
143            return Err(GraphError::NoEntryPoint);
144        }
145
146        // Check all node references exist
147        for edge in &self.edges {
148            match edge {
149                Edge::Direct { source, target } => {
150                    if source != START && !self.nodes.contains_key(source) {
151                        return Err(GraphError::NodeNotFound(source.clone()));
152                    }
153                    if let EdgeTarget::Node(name) = target
154                        && !self.nodes.contains_key(name)
155                    {
156                        return Err(GraphError::EdgeTargetNotFound(name.clone()));
157                    }
158                }
159                Edge::Conditional { source, targets, .. } => {
160                    if !self.nodes.contains_key(source) {
161                        return Err(GraphError::NodeNotFound(source.clone()));
162                    }
163                    for target in targets.values() {
164                        if let EdgeTarget::Node(name) = target
165                            && !self.nodes.contains_key(name)
166                        {
167                            return Err(GraphError::EdgeTargetNotFound(name.clone()));
168                        }
169                    }
170                }
171                Edge::Entry { targets } => {
172                    for target in targets {
173                        if !self.nodes.contains_key(target) {
174                            return Err(GraphError::EdgeTargetNotFound(target.clone()));
175                        }
176                    }
177                }
178            }
179        }
180
181        Ok(())
182    }
183}
184
185/// A compiled graph ready for execution
186pub struct CompiledGraph {
187    pub(crate) schema: StateSchema,
188    pub(crate) nodes: HashMap<String, Arc<dyn Node>>,
189    pub(crate) edges: Vec<Edge>,
190    pub(crate) checkpointer: Option<Arc<dyn Checkpointer>>,
191    pub(crate) interrupt_before: HashSet<String>,
192    pub(crate) interrupt_after: HashSet<String>,
193    pub(crate) recursion_limit: usize,
194    /// Per-node timeout policies, keyed by node name.
195    pub(crate) timeout_policies: HashMap<String, crate::timeout::TimeoutPolicy>,
196    /// Default timeout policy applied to all nodes without an explicit override.
197    pub(crate) default_timeout: Option<crate::timeout::TimeoutPolicy>,
198    /// Deferred node configurations, keyed by node name.
199    pub(crate) deferred_configs: HashMap<String, crate::deferred::DeferredNodeConfig>,
200    /// Per-node cache policies, keyed by node name.
201    #[cfg(feature = "node-cache")]
202    pub(crate) cache_policies: HashMap<String, crate::cache::NodeCachePolicy>,
203}
204
205impl CompiledGraph {
206    /// Configure checkpointing
207    pub fn with_checkpointer<C: Checkpointer + 'static>(mut self, checkpointer: C) -> Self {
208        self.checkpointer = Some(Arc::new(checkpointer));
209        self
210    }
211
212    /// Configure checkpointing with Arc
213    pub fn with_checkpointer_arc(mut self, checkpointer: Arc<dyn Checkpointer>) -> Self {
214        self.checkpointer = Some(checkpointer);
215        self
216    }
217
218    /// Configure interrupt before specific nodes
219    pub fn with_interrupt_before(mut self, nodes: &[&str]) -> Self {
220        self.interrupt_before = nodes.iter().map(|s| s.to_string()).collect();
221        self
222    }
223
224    /// Configure interrupt after specific nodes
225    pub fn with_interrupt_after(mut self, nodes: &[&str]) -> Self {
226        self.interrupt_after = nodes.iter().map(|s| s.to_string()).collect();
227        self
228    }
229
230    /// Set recursion limit for cycles
231    pub fn with_recursion_limit(mut self, limit: usize) -> Self {
232        self.recursion_limit = limit;
233        self
234    }
235
236    /// Get the effective timeout policy for a node.
237    ///
238    /// Returns the per-node policy if one was configured via
239    /// `GraphAgentBuilder::node_timeout`, otherwise falls back to the
240    /// default timeout policy. Returns `None` if neither is set.
241    pub fn timeout_policy_for(&self, node_name: &str) -> Option<&crate::timeout::TimeoutPolicy> {
242        self.timeout_policies.get(node_name).or(self.default_timeout.as_ref())
243    }
244
245    /// Get entry nodes
246    pub fn get_entry_nodes(&self) -> Vec<String> {
247        for edge in &self.edges {
248            if let Edge::Entry { targets } = edge {
249                return targets.clone();
250            }
251        }
252        vec![]
253    }
254
255    /// Get next nodes after executing the given nodes
256    pub fn get_next_nodes(&self, executed: &[String], state: &State) -> Vec<String> {
257        let mut next = Vec::new();
258
259        for edge in &self.edges {
260            match edge {
261                Edge::Direct { source, target: EdgeTarget::Node(n) }
262                    if executed.contains(source) =>
263                {
264                    if !next.contains(n) {
265                        next.push(n.clone());
266                    }
267                }
268                Edge::Conditional { source, router, targets } if executed.contains(source) => {
269                    let route = router(state);
270                    if let Some(EdgeTarget::Node(n)) = targets.get(&route)
271                        && !next.contains(n)
272                    {
273                        next.push(n.clone());
274                    }
275                    // If route leads to END or not found in targets, next will be empty for this path
276                }
277                _ => {}
278            }
279        }
280
281        next
282    }
283
284    /// Check if any of the executed nodes lead to END
285    pub fn leads_to_end(&self, executed: &[String], state: &State) -> bool {
286        for edge in &self.edges {
287            match edge {
288                Edge::Direct { source, target } if executed.contains(source) => {
289                    if target.is_end() {
290                        return true;
291                    }
292                }
293                Edge::Conditional { source, router, targets } if executed.contains(source) => {
294                    let route = router(state);
295                    if route == END {
296                        return true;
297                    }
298                    if let Some(target) = targets.get(&route)
299                        && target.is_end()
300                    {
301                        return true;
302                    }
303                }
304                _ => {}
305            }
306        }
307        false
308    }
309
310    /// Get all upstream source nodes for a given target node.
311    ///
312    /// Returns the names of all nodes that have an edge pointing to the given
313    /// target node. This is used by the deferred node scheduler to determine
314    /// which upstream paths must complete before a fan-in node can execute.
315    ///
316    /// For conditional edges, all possible source nodes are included since any
317    /// of them could route to the target at runtime.
318    pub fn get_upstream_nodes(&self, target_node: &str) -> Vec<String> {
319        let mut sources = Vec::new();
320
321        for edge in &self.edges {
322            match edge {
323                Edge::Direct { source, target } => {
324                    if let EdgeTarget::Node(name) = target
325                        && name == target_node
326                        && !sources.contains(source)
327                    {
328                        sources.push(source.clone());
329                    }
330                }
331                Edge::Conditional { source, targets, .. } => {
332                    for target in targets.values() {
333                        if let EdgeTarget::Node(name) = target
334                            && name == target_node
335                            && !sources.contains(source)
336                        {
337                            sources.push(source.clone());
338                        }
339                    }
340                }
341                Edge::Entry { targets } => {
342                    if targets.contains(&target_node.to_string()) {
343                        // Entry nodes come from START, which is not a real node
344                        // so we don't add it as an upstream source
345                    }
346                }
347            }
348        }
349
350        sources
351    }
352
353    /// Get the state schema
354    pub fn schema(&self) -> &StateSchema {
355        &self.schema
356    }
357
358    /// Get the checkpointer if configured
359    pub fn checkpointer(&self) -> Option<&Arc<dyn Checkpointer>> {
360        self.checkpointer.as_ref()
361    }
362}
363
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_basic_graph_construction() {
        let graph = StateGraph::with_channels(&["input", "output"])
            .add_node_fn("process", |_ctx| async { Ok(NodeOutput::new()) })
            .add_edge(START, "process")
            .add_edge("process", END)
            .compile();

        assert!(graph.is_ok());
    }

    #[test]
    fn test_graph_missing_entry() {
        let graph = StateGraph::with_channels(&["input"])
            .add_node_fn("process", |_ctx| async { Ok(NodeOutput::new()) })
            .add_edge("process", END) // No START -> process edge
            .compile();

        assert!(matches!(graph, Err(GraphError::NoEntryPoint)));
    }

    #[test]
    fn test_graph_missing_node() {
        let graph = StateGraph::with_channels(&["input"]).add_edge(START, "nonexistent").compile();

        assert!(matches!(graph, Err(GraphError::EdgeTargetNotFound(_))));
    }

    #[test]
    fn test_conditional_edges() {
        let graph = StateGraph::with_channels(&["next"])
            .add_node_fn("router", |_ctx| async { Ok(NodeOutput::new()) })
            .add_node_fn("path_a", |_ctx| async { Ok(NodeOutput::new()) })
            .add_node_fn("path_b", |_ctx| async { Ok(NodeOutput::new()) })
            .add_edge(START, "router")
            .add_conditional_edges(
                "router",
                |state| state.get("next").and_then(|v| v.as_str()).unwrap_or(END).to_string(),
                [("path_a", "path_a"), ("path_b", "path_b"), (END, END)],
            )
            .compile()
            .unwrap();

        // Test routing
        let mut state = State::new();
        state.insert("next".to_string(), json!("path_a"));
        let next = graph.get_next_nodes(&["router".to_string()], &state);
        assert_eq!(next, vec!["path_a".to_string()]);

        state.insert("next".to_string(), json!("path_b"));
        let next = graph.get_next_nodes(&["router".to_string()], &state);
        assert_eq!(next, vec!["path_b".to_string()]);
    }

    #[test]
    fn test_entry_edges_are_merged_and_deduplicated() {
        let graph = StateGraph::with_channels(&["input"])
            .add_node_fn("a", |_ctx| async { Ok(NodeOutput::new()) })
            .add_node_fn("b", |_ctx| async { Ok(NodeOutput::new()) })
            .add_edge(START, "a")
            .add_edge(START, "b")
            .add_edge(START, "a")
            .add_edge("a", END)
            .add_edge("b", END)
            .compile()
            .unwrap();

        // Repeated START edges collapse into one entry edge, first-seen order kept.
        assert_eq!(graph.get_entry_nodes(), vec!["a".to_string(), "b".to_string()]);
        let entry_edges =
            graph.edges.iter().filter(|edge| matches!(edge, Edge::Entry { .. })).count();
        assert_eq!(entry_edges, 1);
    }

    #[test]
    fn test_conditional_edge_to_missing_node_fails_validation() {
        let graph = StateGraph::with_channels(&["next"])
            .add_node_fn("router", |_ctx| async { Ok(NodeOutput::new()) })
            .add_edge(START, "router")
            .add_conditional_edges("router", |_state| "missing".to_string(), [(
                "missing", "missing",
            )])
            .compile();

        assert!(matches!(
            &graph,
            Err(GraphError::EdgeTargetNotFound(name)) if name == "missing"
        ));
    }

    #[test]
    fn test_leads_to_end() {
        let graph = StateGraph::with_channels(&["next"])
            .add_node_fn("router", |_ctx| async { Ok(NodeOutput::new()) })
            .add_node_fn("path_a", |_ctx| async { Ok(NodeOutput::new()) })
            .add_edge(START, "router")
            .add_conditional_edges(
                "router",
                |state| state.get("next").and_then(|v| v.as_str()).unwrap_or(END).to_string(),
                [("path_a", "path_a"), (END, END)],
            )
            .add_edge("path_a", END)
            .compile()
            .unwrap();

        let executed = ["router".to_string()];
        let mut state = State::new();

        // No "next" value: the router falls back to END, so nothing runs next.
        assert!(graph.leads_to_end(&executed, &state));
        assert!(graph.get_next_nodes(&executed, &state).is_empty());

        state.insert("next".to_string(), json!("path_a"));
        assert!(!graph.leads_to_end(&executed, &state));
        // Direct edge path_a -> END.
        assert!(graph.leads_to_end(&["path_a".to_string()], &state));
    }

    #[test]
    fn test_upstream_nodes() {
        let graph = StateGraph::with_channels(&["input"])
            .add_node_fn("a", |_ctx| async { Ok(NodeOutput::new()) })
            .add_node_fn("b", |_ctx| async { Ok(NodeOutput::new()) })
            .add_node_fn("join", |_ctx| async { Ok(NodeOutput::new()) })
            .add_edge(START, "a")
            .add_edge(START, "b")
            .add_edge("a", "join")
            .add_conditional_edges("b", |_state| "join".to_string(), [
                ("join", "join"),
                (END, END),
            ])
            .add_edge("join", END)
            .compile()
            .unwrap();

        assert_eq!(graph.get_upstream_nodes("join"), vec!["a".to_string(), "b".to_string()]);
        // Entry nodes have no upstream nodes: START is not a real node.
        assert!(graph.get_upstream_nodes("a").is_empty());
    }
}