// enact_core/graph/loader.rs
1//! GraphLoader - loads graph definitions from YAML
2//!
3//! Supports provider injection for LLM nodes through `GraphLoaderContext`.
4
5use super::node::LlmNode;
6use super::schema::{GraphDefinition, NodeDefinition};
7use super::{EdgeTarget, NodeState, StateGraph};
8use crate::providers::ModelProvider;
9use crate::routing::{resolve_model_precedence, DEFAULT_MODEL_ROUTER_ID};
10use anyhow::{anyhow, Context, Result};
11use std::sync::Arc;
12
/// Context for loading graphs with runtime dependencies
///
/// Provides optional access to model providers for creating functional LLM nodes.
/// When no provider is set, LLM nodes will be created as placeholders that
/// log their invocation but don't make actual LLM calls.
#[derive(Default, Clone)]
pub struct GraphLoaderContext {
    /// Model provider used to build functional LLM nodes; when `None`,
    /// LLM nodes become placeholders that error at execution time.
    pub provider: Option<Arc<dyn ModelProvider>>,
    /// Agent-level default model pin used in precedence resolution.
    /// Lowest priority: node-level and graph-level models take precedence.
    pub default_model: Option<String>,
}
25
26impl GraphLoaderContext {
27    /// Create a new empty context
28    pub fn new() -> Self {
29        Self::default()
30    }
31
32    /// Create a context with a model provider
33    pub fn with_provider(provider: Arc<dyn ModelProvider>) -> Self {
34        Self {
35            provider: Some(provider),
36            default_model: None,
37        }
38    }
39
40    /// Create a context with provider and agent-level default model.
41    pub fn with_provider_and_model(
42        provider: Arc<dyn ModelProvider>,
43        default_model: impl Into<String>,
44    ) -> Self {
45        Self {
46            provider: Some(provider),
47            default_model: Some(default_model.into()),
48        }
49    }
50}
51
/// Stateless loader that builds a `StateGraph` from a YAML graph definition.
pub struct GraphLoader;
53
54impl GraphLoader {
55    /// Load a graph from YAML string
56    ///
57    /// LLM nodes will be created as placeholders that log but don't execute.
58    /// Use `load_from_str_with_context` to provide a model provider for
59    /// functional LLM nodes.
60    pub fn load_from_str(yaml: &str) -> Result<StateGraph> {
61        Self::load_from_str_with_context(yaml, &GraphLoaderContext::default())
62    }
63
64    /// Load a graph from YAML string with runtime context
65    ///
66    /// If the context contains a model provider, LLM nodes will be created
67    /// as functional nodes that execute actual LLM calls.
68    pub fn load_from_str_with_context(yaml: &str, ctx: &GraphLoaderContext) -> Result<StateGraph> {
69        let def: GraphDefinition =
70            serde_yaml::from_str(yaml).context("Failed to parse graph definition YAML")?;
71
72        let mut graph = StateGraph::new();
73
74        // 1. Add all nodes
75        for (name, node_def) in &def.nodes {
76            match node_def {
77                NodeDefinition::Llm {
78                    model,
79                    system_prompt,
80                    ..
81                } => {
82                    let (resolved_model, selection_source) = resolve_model_precedence(
83                        model.as_deref(),
84                        def.model.as_deref(),
85                        ctx.default_model.as_deref(),
86                    );
87
88                    if let Some(provider) = &ctx.provider {
89                        // Create functional LLM node with real provider
90                        tracing::debug!(
91                            node = %name,
92                            selected_model = %resolved_model,
93                            source = ?selection_source,
94                            "Resolved LLM node model using precedence"
95                        );
96
97                        let llm_node = LlmNode::with_model(
98                            name.clone(),
99                            system_prompt.clone(),
100                            resolved_model,
101                            provider.clone(),
102                        );
103                        graph = graph.add_node_impl(llm_node);
104                    } else {
105                        let name_clone = name.clone();
106                        let model = if resolved_model.is_empty() {
107                            DEFAULT_MODEL_ROUTER_ID.to_string()
108                        } else {
109                            resolved_model
110                        };
111                        let prompt = system_prompt.clone();
112
113                        graph = graph.add_node(name, move |_state: NodeState| {
114                            let n = name_clone.clone();
115                            let m = model.clone();
116                            let p = prompt.clone();
117                            async move {
118                                tracing::error!(
119                                    node = %n,
120                                    model = %m,
121                                    prompt = %p,
122                                    "LLM node requires a model provider but none was configured"
123                                );
124                                Err(anyhow::anyhow!(
125                                    "LLM node '{}' requires a model provider. \
126                                     Use GraphLoaderContext::with_provider() when loading the graph.",
127                                    n
128                                ))
129                            }
130                        });
131                    }
132                }
133                NodeDefinition::Function { action, .. } => {
134                    let name_clone = name.clone();
135                    let action = action.clone();
136
137                    graph = graph.add_node(name, move |state: NodeState| {
138                        let n = name_clone.clone();
139                        let a = action.clone();
140                        async move {
141                            println!("⚙️ [Function Node: {}] Action: {}", n, a);
142                            // In a real implementation, this would execute the action command
143                            // For now, allow simple "echo" for testing
144                            if a.starts_with("echo ") {
145                                let output = a.trim_start_matches("echo ").to_string();
146                                return Ok(NodeState::from_string(&output));
147                            }
148                            Ok(state)
149                        }
150                    });
151                }
152                NodeDefinition::Condition { expr, .. } => {
153                    let name_clone = name.clone();
154                    let expr = expr.clone();
155
156                    // Condition node evaluates expression and returns the result key
157                    // (which matches an edge key)
158                    graph = graph.add_node(name, move |state: NodeState| {
159                        let n = name_clone.clone();
160                        let e = expr.clone();
161                        async move {
162                            println!("❓ [Condition Node: {}] Expr: {}", n, e);
163                            // Simple mock evaluation
164                            // If input contains "error", return "error", else "ok"
165                            let input = state.as_str().unwrap_or("");
166                            if e.contains("contains('error')") {
167                                if input.contains("error") {
168                                    return Ok(NodeState::from_string("error"));
169                                } else {
170                                    return Ok(NodeState::from_string("ok"));
171                                }
172                            }
173                            Ok(NodeState::from_string("default"))
174                        }
175                    });
176                }
177                _ => {
178                    return Err(anyhow!("Unsupported node type in yaml"));
179                }
180            }
181        }
182
183        // 2. Add edges
184        for (name, node_def) in &def.nodes {
185            let edges = node_def.edges();
186
187            // Check if this is a conditional node (router)
188            // If it has multiple edges with keys other than "_default",
189            // valid keys are the outputs of the previous node.
190
191            // For Llm/Function nodes, usually they have a single "_default" edge
192            // or specific keys if they return structured data?
193            // The schema implies simple string matching on output.
194
195            let is_conditional = matches!(node_def, NodeDefinition::Condition { .. });
196
197            if is_conditional {
198                // Conditional edges based on node output
199                let edges_clone = edges.clone();
200                let router = move |output: &str| -> EdgeTarget {
201                    if let Some(target) = edges_clone.get(output) {
202                        if target == "END" {
203                            EdgeTarget::End
204                        } else {
205                            EdgeTarget::Node(target.clone())
206                        }
207                    } else if let Some(default) = edges_clone.get("_default") {
208                        if default == "END" {
209                            EdgeTarget::End
210                        } else {
211                            EdgeTarget::Node(default.clone())
212                        }
213                    } else {
214                        EdgeTarget::End
215                    }
216                };
217
218                graph = graph.add_conditional_edge(name, router);
219            } else {
220                // Standard edges
221                // TODO: Support branching from non-condition nodes?
222                // For now, assume "_default" is the main edge
223                if let Some(target) = edges.get("_default") {
224                    if target == "END" {
225                        graph = graph.add_edge_to_end(name);
226                    } else {
227                        graph = graph.add_edge(name, target);
228                    }
229                }
230            }
231        }
232
233        // 3. Set entry point
234        // Ideally schema allows defining it, or we use first node?
235        // Current StateGraph defaults to first node if not set.
236        // We could look for "start" or "input" node?
237        // The implementation_plan example didn't specify entry point explicitly.
238        // Let's assume the first defined node in YAML (but HashMap is unordered).
239        // Use "start" or "input" if present, else random?
240        // Better: require `triggers` or look for a node named "start".
241
242        if def.nodes.contains_key("start") {
243            graph = graph.set_entry_point("start");
244        } else if def.nodes.contains_key("input") {
245            graph = graph.set_entry_point("input");
246        }
247
248        Ok(graph)
249    }
250}
251
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::{ChatChoice, ChatMessage, ChatRequest, ChatResponse};
    use async_trait::async_trait;

    /// Test double: a provider that answers every chat request with one
    /// canned assistant message.
    struct MockProvider {
        reply: String,
    }

    impl MockProvider {
        fn new(reply: impl Into<String>) -> Self {
            Self {
                reply: reply.into(),
            }
        }
    }

    #[async_trait]
    impl ModelProvider for MockProvider {
        fn name(&self) -> &str {
            "mock"
        }

        async fn chat(&self, _request: ChatRequest) -> anyhow::Result<ChatResponse> {
            let choice = ChatChoice {
                index: 0,
                message: ChatMessage::assistant(&self.reply),
                finish_reason: Some("stop".to_string()),
            };
            Ok(ChatResponse {
                id: "mock-id".to_string(),
                choices: vec![choice],
                usage: None,
            })
        }
    }

    const SIMPLE_GRAPH_YAML: &str = r#"
name: test-graph
version: "1.0"
nodes:
  start:
    type: llm
    system_prompt: "You are a helpful assistant"
    edges:
      _default: END
"#;

    #[test]
    fn test_load_without_context_creates_placeholder() {
        let loaded = GraphLoader::load_from_str(SIMPLE_GRAPH_YAML).unwrap();
        assert!(loaded.nodes.contains_key("start"));
    }

    #[test]
    fn test_load_with_context_creates_functional_node() {
        let context = GraphLoaderContext::with_provider(Arc::new(MockProvider::new("Hello!")));

        let loaded = GraphLoader::load_from_str_with_context(SIMPLE_GRAPH_YAML, &context).unwrap();
        assert!(loaded.nodes.contains_key("start"));
    }

    #[tokio::test]
    async fn test_functional_llm_node_executes() {
        let context =
            GraphLoaderContext::with_provider(Arc::new(MockProvider::new("LLM Response")));

        let loaded = GraphLoader::load_from_str_with_context(SIMPLE_GRAPH_YAML, &context).unwrap();
        let compiled = loaded.compile().unwrap();

        // Run end-to-end: the mock provider's reply becomes the final state.
        let result = compiled.run("User input").await.unwrap();

        assert_eq!(result.as_str(), Some("LLM Response"));
    }

    #[test]
    fn test_context_builder() {
        let empty = GraphLoaderContext::new();
        assert!(empty.provider.is_none());
        assert!(empty.default_model.is_none());

        let with_provider = GraphLoaderContext::with_provider(Arc::new(MockProvider::new("test")));
        assert!(with_provider.provider.is_some());
        assert!(with_provider.default_model.is_none());
    }

    #[test]
    fn test_context_builder_with_default_model() {
        let context = GraphLoaderContext::with_provider_and_model(
            Arc::new(MockProvider::new("test")),
            "agent/default-model",
        );
        assert!(context.provider.is_some());
        assert_eq!(context.default_model.as_deref(), Some("agent/default-model"));
    }

    const MULTI_NODE_YAML: &str = r#"
name: multi-node-graph
version: "1.0"
nodes:
  start:
    type: llm
    model: gpt-4
    system_prompt: "Process the input"
    edges:
      _default: check
  check:
    type: condition
    expr: "input.contains('error')"
    edges:
      error: handle_error
      ok: END
  handle_error:
    type: function
    action: "echo error handled"
    edges:
      _default: END
"#;

    #[test]
    fn test_multi_node_graph_loading() {
        let context = GraphLoaderContext::with_provider(Arc::new(MockProvider::new("processed")));

        let loaded = GraphLoader::load_from_str_with_context(MULTI_NODE_YAML, &context).unwrap();

        for node in ["start", "check", "handle_error"] {
            assert!(loaded.nodes.contains_key(node));
        }
    }
}