Skip to main content

adk_graph/
graph.rs

1//! StateGraph builder for constructing graphs
2
3use crate::checkpoint::Checkpointer;
4use crate::edge::{END, Edge, EdgeTarget, RouterFn, START};
5use crate::error::{GraphError, Result};
6use crate::node::{FunctionNode, Node, NodeContext, NodeOutput};
7use crate::state::{State, StateSchema};
8use std::collections::{HashMap, HashSet};
9use std::future::Future;
10use std::sync::Arc;
11
12/// Builder for constructing graphs
13pub struct StateGraph {
14    /// State schema
15    pub schema: StateSchema,
16    /// Registered nodes
17    pub nodes: HashMap<String, Arc<dyn Node>>,
18    /// Registered edges
19    pub edges: Vec<Edge>,
20}
21
22impl StateGraph {
23    /// Create a new graph with the given state schema
24    pub fn new(schema: StateSchema) -> Self {
25        Self { schema, nodes: HashMap::new(), edges: vec![] }
26    }
27
28    /// Create with a simple schema (just channel names, all overwrite)
29    pub fn with_channels(channels: &[&str]) -> Self {
30        Self::new(StateSchema::simple(channels))
31    }
32
33    /// Add a node to the graph
34    pub fn add_node<N: Node + 'static>(mut self, node: N) -> Self {
35        self.nodes.insert(node.name().to_string(), Arc::new(node));
36        self
37    }
38
39    /// Add a function as a node
40    pub fn add_node_fn<F, Fut>(self, name: &str, func: F) -> Self
41    where
42        F: Fn(NodeContext) -> Fut + Send + Sync + 'static,
43        Fut: Future<Output = Result<NodeOutput>> + Send + 'static,
44    {
45        self.add_node(FunctionNode::new(name, func))
46    }
47
48    /// Add a direct edge from source to target
49    pub fn add_edge(mut self, source: &str, target: &str) -> Self {
50        let target = EdgeTarget::from(target);
51
52        if source == START {
53            // Find existing entry or create new one
54            let entry_idx = self.edges.iter().position(|e| matches!(e, Edge::Entry { .. }));
55
56            match entry_idx {
57                Some(idx) => {
58                    if let Edge::Entry { targets } = &mut self.edges[idx] {
59                        if let EdgeTarget::Node(node) = &target {
60                            if !targets.contains(node) {
61                                targets.push(node.clone());
62                            }
63                        }
64                    }
65                }
66                None => {
67                    if let EdgeTarget::Node(node) = target {
68                        self.edges.push(Edge::Entry { targets: vec![node] });
69                    }
70                }
71            }
72        } else {
73            self.edges.push(Edge::Direct { source: source.to_string(), target });
74        }
75
76        self
77    }
78
79    /// Add a conditional edge with a router function
80    pub fn add_conditional_edges<F, I>(mut self, source: &str, router: F, targets: I) -> Self
81    where
82        F: Fn(&State) -> String + Send + Sync + 'static,
83        I: IntoIterator<Item = (&'static str, &'static str)>,
84    {
85        let targets_map: HashMap<String, EdgeTarget> =
86            targets.into_iter().map(|(k, v)| (k.to_string(), EdgeTarget::from(v))).collect();
87
88        self.edges.push(Edge::Conditional {
89            source: source.to_string(),
90            router: Arc::new(router),
91            targets: targets_map,
92        });
93
94        self
95    }
96
97    /// Add a conditional edge with an Arc router (for pre-built routers)
98    pub fn add_conditional_edges_arc<I>(
99        mut self,
100        source: &str,
101        router: RouterFn,
102        targets: I,
103    ) -> Self
104    where
105        I: IntoIterator<Item = (&'static str, &'static str)>,
106    {
107        let targets_map: HashMap<String, EdgeTarget> =
108            targets.into_iter().map(|(k, v)| (k.to_string(), EdgeTarget::from(v))).collect();
109
110        self.edges.push(Edge::Conditional {
111            source: source.to_string(),
112            router,
113            targets: targets_map,
114        });
115
116        self
117    }
118
119    /// Compile the graph for execution
120    pub fn compile(self) -> Result<CompiledGraph> {
121        self.validate()?;
122
123        Ok(CompiledGraph {
124            schema: self.schema,
125            nodes: self.nodes,
126            edges: self.edges,
127            checkpointer: None,
128            interrupt_before: HashSet::new(),
129            interrupt_after: HashSet::new(),
130            recursion_limit: 50,
131            timeout_policies: HashMap::new(),
132            default_timeout: None,
133            deferred_configs: HashMap::new(),
134            #[cfg(feature = "node-cache")]
135            cache_policies: HashMap::new(),
136        })
137    }
138
139    /// Validate the graph structure
140    fn validate(&self) -> Result<()> {
141        // Check for entry point
142        let has_entry = self.edges.iter().any(|e| matches!(e, Edge::Entry { .. }));
143        if !has_entry {
144            return Err(GraphError::NoEntryPoint);
145        }
146
147        // Check all node references exist
148        for edge in &self.edges {
149            match edge {
150                Edge::Direct { source, target } => {
151                    if source != START && !self.nodes.contains_key(source) {
152                        return Err(GraphError::NodeNotFound(source.clone()));
153                    }
154                    if let EdgeTarget::Node(name) = target {
155                        if !self.nodes.contains_key(name) {
156                            return Err(GraphError::EdgeTargetNotFound(name.clone()));
157                        }
158                    }
159                }
160                Edge::Conditional { source, targets, .. } => {
161                    if !self.nodes.contains_key(source) {
162                        return Err(GraphError::NodeNotFound(source.clone()));
163                    }
164                    for target in targets.values() {
165                        if let EdgeTarget::Node(name) = target {
166                            if !self.nodes.contains_key(name) {
167                                return Err(GraphError::EdgeTargetNotFound(name.clone()));
168                            }
169                        }
170                    }
171                }
172                Edge::Entry { targets } => {
173                    for target in targets {
174                        if !self.nodes.contains_key(target) {
175                            return Err(GraphError::EdgeTargetNotFound(target.clone()));
176                        }
177                    }
178                }
179            }
180        }
181
182        Ok(())
183    }
184}
185
186/// A compiled graph ready for execution
187pub struct CompiledGraph {
188    pub(crate) schema: StateSchema,
189    pub(crate) nodes: HashMap<String, Arc<dyn Node>>,
190    pub(crate) edges: Vec<Edge>,
191    pub(crate) checkpointer: Option<Arc<dyn Checkpointer>>,
192    pub(crate) interrupt_before: HashSet<String>,
193    pub(crate) interrupt_after: HashSet<String>,
194    pub(crate) recursion_limit: usize,
195    /// Per-node timeout policies, keyed by node name.
196    pub(crate) timeout_policies: HashMap<String, crate::timeout::TimeoutPolicy>,
197    /// Default timeout policy applied to all nodes without an explicit override.
198    pub(crate) default_timeout: Option<crate::timeout::TimeoutPolicy>,
199    /// Deferred node configurations, keyed by node name.
200    pub(crate) deferred_configs: HashMap<String, crate::deferred::DeferredNodeConfig>,
201    /// Per-node cache policies, keyed by node name.
202    #[cfg(feature = "node-cache")]
203    pub(crate) cache_policies: HashMap<String, crate::cache::NodeCachePolicy>,
204}
205
206impl CompiledGraph {
207    /// Configure checkpointing
208    pub fn with_checkpointer<C: Checkpointer + 'static>(mut self, checkpointer: C) -> Self {
209        self.checkpointer = Some(Arc::new(checkpointer));
210        self
211    }
212
213    /// Configure checkpointing with Arc
214    pub fn with_checkpointer_arc(mut self, checkpointer: Arc<dyn Checkpointer>) -> Self {
215        self.checkpointer = Some(checkpointer);
216        self
217    }
218
219    /// Configure interrupt before specific nodes
220    pub fn with_interrupt_before(mut self, nodes: &[&str]) -> Self {
221        self.interrupt_before = nodes.iter().map(|s| s.to_string()).collect();
222        self
223    }
224
225    /// Configure interrupt after specific nodes
226    pub fn with_interrupt_after(mut self, nodes: &[&str]) -> Self {
227        self.interrupt_after = nodes.iter().map(|s| s.to_string()).collect();
228        self
229    }
230
231    /// Set recursion limit for cycles
232    pub fn with_recursion_limit(mut self, limit: usize) -> Self {
233        self.recursion_limit = limit;
234        self
235    }
236
237    /// Get the effective timeout policy for a node.
238    ///
239    /// Returns the per-node policy if one was configured via
240    /// [`GraphAgentBuilder::node_timeout`], otherwise falls back to the
241    /// default timeout policy. Returns `None` if neither is set.
242    pub fn timeout_policy_for(&self, node_name: &str) -> Option<&crate::timeout::TimeoutPolicy> {
243        self.timeout_policies.get(node_name).or(self.default_timeout.as_ref())
244    }
245
246    /// Get entry nodes
247    pub fn get_entry_nodes(&self) -> Vec<String> {
248        for edge in &self.edges {
249            if let Edge::Entry { targets } = edge {
250                return targets.clone();
251            }
252        }
253        vec![]
254    }
255
256    /// Get next nodes after executing the given nodes
257    pub fn get_next_nodes(&self, executed: &[String], state: &State) -> Vec<String> {
258        let mut next = Vec::new();
259
260        for edge in &self.edges {
261            match edge {
262                Edge::Direct { source, target: EdgeTarget::Node(n) }
263                    if executed.contains(source) =>
264                {
265                    if !next.contains(n) {
266                        next.push(n.clone());
267                    }
268                }
269                Edge::Conditional { source, router, targets } if executed.contains(source) => {
270                    let route = router(state);
271                    if let Some(EdgeTarget::Node(n)) = targets.get(&route) {
272                        if !next.contains(n) {
273                            next.push(n.clone());
274                        }
275                    }
276                    // If route leads to END or not found in targets, next will be empty for this path
277                }
278                _ => {}
279            }
280        }
281
282        next
283    }
284
285    /// Check if any of the executed nodes lead to END
286    pub fn leads_to_end(&self, executed: &[String], state: &State) -> bool {
287        for edge in &self.edges {
288            match edge {
289                Edge::Direct { source, target } if executed.contains(source) => {
290                    if target.is_end() {
291                        return true;
292                    }
293                }
294                Edge::Conditional { source, router, targets } if executed.contains(source) => {
295                    let route = router(state);
296                    if route == END {
297                        return true;
298                    }
299                    if let Some(target) = targets.get(&route) {
300                        if target.is_end() {
301                            return true;
302                        }
303                    }
304                }
305                _ => {}
306            }
307        }
308        false
309    }
310
311    /// Get all upstream source nodes for a given target node.
312    ///
313    /// Returns the names of all nodes that have an edge pointing to the given
314    /// target node. This is used by the deferred node scheduler to determine
315    /// which upstream paths must complete before a fan-in node can execute.
316    ///
317    /// For conditional edges, all possible source nodes are included since any
318    /// of them could route to the target at runtime.
319    pub fn get_upstream_nodes(&self, target_node: &str) -> Vec<String> {
320        let mut sources = Vec::new();
321
322        for edge in &self.edges {
323            match edge {
324                Edge::Direct { source, target } => {
325                    if let EdgeTarget::Node(name) = target {
326                        if name == target_node && !sources.contains(source) {
327                            sources.push(source.clone());
328                        }
329                    }
330                }
331                Edge::Conditional { source, targets, .. } => {
332                    for target in targets.values() {
333                        if let EdgeTarget::Node(name) = target {
334                            if name == target_node && !sources.contains(source) {
335                                sources.push(source.clone());
336                            }
337                        }
338                    }
339                }
340                Edge::Entry { targets } => {
341                    if targets.contains(&target_node.to_string()) {
342                        // Entry nodes come from START, which is not a real node
343                        // so we don't add it as an upstream source
344                    }
345                }
346            }
347        }
348
349        sources
350    }
351
352    /// Get the state schema
353    pub fn schema(&self) -> &StateSchema {
354        &self.schema
355    }
356
357    /// Get the checkpointer if configured
358    pub fn checkpointer(&self) -> Option<&Arc<dyn Checkpointer>> {
359        self.checkpointer.as_ref()
360    }
361}
362
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_basic_graph_construction() {
        let graph = StateGraph::with_channels(&["input", "output"])
            .add_node_fn("process", |_ctx| async { Ok(NodeOutput::new()) })
            .add_edge(START, "process")
            .add_edge("process", END)
            .compile();

        assert!(graph.is_ok());
    }

    #[test]
    fn test_graph_missing_entry() {
        let graph = StateGraph::with_channels(&["input"])
            .add_node_fn("process", |_ctx| async { Ok(NodeOutput::new()) })
            .add_edge("process", END) // No START -> process edge
            .compile();

        assert!(matches!(graph, Err(GraphError::NoEntryPoint)));
    }

    #[test]
    fn test_graph_missing_node() {
        let graph = StateGraph::with_channels(&["input"]).add_edge(START, "nonexistent").compile();

        assert!(matches!(graph, Err(GraphError::EdgeTargetNotFound(_))));
    }

    #[test]
    fn test_start_to_end_is_not_an_entry_point() {
        // `START -> END` is dropped by `add_edge`, so there is still no entry.
        let graph = StateGraph::with_channels(&["input"]).add_edge(START, END).compile();

        assert!(matches!(graph, Err(GraphError::NoEntryPoint)));
    }

    #[test]
    fn test_duplicate_entry_edges_are_merged() {
        let graph = StateGraph::with_channels(&["input"])
            .add_node_fn("a", |_ctx| async { Ok(NodeOutput::new()) })
            .add_node_fn("b", |_ctx| async { Ok(NodeOutput::new()) })
            .add_edge(START, "a")
            .add_edge(START, "b")
            .add_edge(START, "a")
            .add_edge("a", END)
            .add_edge("b", END)
            .compile()
            .unwrap();

        assert_eq!(graph.get_entry_nodes(), vec!["a".to_string(), "b".to_string()]);
    }

    #[test]
    fn test_direct_edges_and_end() {
        let graph = StateGraph::with_channels(&["input"])
            .add_node_fn("first", |_ctx| async { Ok(NodeOutput::new()) })
            .add_node_fn("second", |_ctx| async { Ok(NodeOutput::new()) })
            .add_edge(START, "first")
            .add_edge("first", "second")
            .add_edge("second", END)
            .compile()
            .unwrap();

        let state = State::new();
        let first = ["first".to_string()];
        let second = ["second".to_string()];

        assert_eq!(graph.get_next_nodes(&first, &state), vec!["second".to_string()]);
        assert!(!graph.leads_to_end(&first, &state));

        assert!(graph.get_next_nodes(&second, &state).is_empty());
        assert!(graph.leads_to_end(&second, &state));
    }

    #[test]
    fn test_conditional_edges() {
        let graph = StateGraph::with_channels(&["next"])
            .add_node_fn("router", |_ctx| async { Ok(NodeOutput::new()) })
            .add_node_fn("path_a", |_ctx| async { Ok(NodeOutput::new()) })
            .add_node_fn("path_b", |_ctx| async { Ok(NodeOutput::new()) })
            .add_edge(START, "router")
            .add_conditional_edges(
                "router",
                |state| state.get("next").and_then(|v| v.as_str()).unwrap_or(END).to_string(),
                [("path_a", "path_a"), ("path_b", "path_b"), (END, END)],
            )
            .compile()
            .unwrap();

        // Test routing
        let mut state = State::new();
        state.insert("next".to_string(), json!("path_a"));
        let next = graph.get_next_nodes(&["router".to_string()], &state);
        assert_eq!(next, vec!["path_a".to_string()]);

        state.insert("next".to_string(), json!("path_b"));
        let next = graph.get_next_nodes(&["router".to_string()], &state);
        assert_eq!(next, vec!["path_b".to_string()]);
    }

    #[test]
    fn test_conditional_route_to_end() {
        let graph = StateGraph::with_channels(&["next"])
            .add_node_fn("router", |_ctx| async { Ok(NodeOutput::new()) })
            .add_node_fn("path_a", |_ctx| async { Ok(NodeOutput::new()) })
            .add_edge(START, "router")
            .add_conditional_edges(
                "router",
                |state| state.get("next").and_then(|v| v.as_str()).unwrap_or(END).to_string(),
                [("path_a", "path_a"), (END, END)],
            )
            .compile()
            .unwrap();

        // With no "next" key the router returns END.
        let state = State::new();
        let executed = ["router".to_string()];

        assert!(graph.get_next_nodes(&executed, &state).is_empty());
        assert!(graph.leads_to_end(&executed, &state));
        assert_eq!(graph.get_upstream_nodes("path_a"), vec!["router".to_string()]);
    }

    #[test]
    fn test_upstream_nodes_for_fan_in() {
        let graph = StateGraph::with_channels(&["input"])
            .add_node_fn("a", |_ctx| async { Ok(NodeOutput::new()) })
            .add_node_fn("b", |_ctx| async { Ok(NodeOutput::new()) })
            .add_node_fn("join", |_ctx| async { Ok(NodeOutput::new()) })
            .add_edge(START, "a")
            .add_edge(START, "b")
            .add_edge("a", "join")
            .add_edge("b", "join")
            .add_edge("join", END)
            .compile()
            .unwrap();

        assert_eq!(graph.get_upstream_nodes("join"), vec!["a".to_string(), "b".to_string()]);
        // Entry nodes have no upstream node: START is not a real node.
        assert!(graph.get_upstream_nodes("a").is_empty());
    }
}