enact_core/graph/
loader.rs1use super::node::LlmNode;
6use super::schema::{GraphDefinition, NodeDefinition};
7use super::{EdgeTarget, NodeState, StateGraph};
8use crate::providers::ModelProvider;
9use crate::routing::{resolve_model_precedence, DEFAULT_MODEL_ROUTER_ID};
10use anyhow::{anyhow, Context, Result};
11use std::sync::Arc;
12
13#[derive(Default, Clone)]
19pub struct GraphLoaderContext {
20 pub provider: Option<Arc<dyn ModelProvider>>,
22 pub default_model: Option<String>,
24}
25
26impl GraphLoaderContext {
27 pub fn new() -> Self {
29 Self::default()
30 }
31
32 pub fn with_provider(provider: Arc<dyn ModelProvider>) -> Self {
34 Self {
35 provider: Some(provider),
36 default_model: None,
37 }
38 }
39
40 pub fn with_provider_and_model(
42 provider: Arc<dyn ModelProvider>,
43 default_model: impl Into<String>,
44 ) -> Self {
45 Self {
46 provider: Some(provider),
47 default_model: Some(default_model.into()),
48 }
49 }
50}
51
52pub struct GraphLoader;
53
54impl GraphLoader {
55 pub fn load_from_str(yaml: &str) -> Result<StateGraph> {
61 Self::load_from_str_with_context(yaml, &GraphLoaderContext::default())
62 }
63
64 pub fn load_from_str_with_context(yaml: &str, ctx: &GraphLoaderContext) -> Result<StateGraph> {
69 let def: GraphDefinition =
70 serde_yaml::from_str(yaml).context("Failed to parse graph definition YAML")?;
71
72 let mut graph = StateGraph::new();
73
74 for (name, node_def) in &def.nodes {
76 match node_def {
77 NodeDefinition::Llm {
78 model,
79 system_prompt,
80 ..
81 } => {
82 let (resolved_model, selection_source) = resolve_model_precedence(
83 model.as_deref(),
84 def.model.as_deref(),
85 ctx.default_model.as_deref(),
86 );
87
88 if let Some(provider) = &ctx.provider {
89 tracing::debug!(
91 node = %name,
92 selected_model = %resolved_model,
93 source = ?selection_source,
94 "Resolved LLM node model using precedence"
95 );
96
97 let llm_node = LlmNode::with_model(
98 name.clone(),
99 system_prompt.clone(),
100 resolved_model,
101 provider.clone(),
102 );
103 graph = graph.add_node_impl(llm_node);
104 } else {
105 let name_clone = name.clone();
106 let model = if resolved_model.is_empty() {
107 DEFAULT_MODEL_ROUTER_ID.to_string()
108 } else {
109 resolved_model
110 };
111 let prompt = system_prompt.clone();
112
113 graph = graph.add_node(name, move |_state: NodeState| {
114 let n = name_clone.clone();
115 let m = model.clone();
116 let p = prompt.clone();
117 async move {
118 tracing::error!(
119 node = %n,
120 model = %m,
121 prompt = %p,
122 "LLM node requires a model provider but none was configured"
123 );
124 Err(anyhow::anyhow!(
125 "LLM node '{}' requires a model provider. \
126 Use GraphLoaderContext::with_provider() when loading the graph.",
127 n
128 ))
129 }
130 });
131 }
132 }
133 NodeDefinition::Function { action, .. } => {
134 let name_clone = name.clone();
135 let action = action.clone();
136
137 graph = graph.add_node(name, move |state: NodeState| {
138 let n = name_clone.clone();
139 let a = action.clone();
140 async move {
141 println!("⚙️ [Function Node: {}] Action: {}", n, a);
142 if a.starts_with("echo ") {
145 let output = a.trim_start_matches("echo ").to_string();
146 return Ok(NodeState::from_string(&output));
147 }
148 Ok(state)
149 }
150 });
151 }
152 NodeDefinition::Condition { expr, .. } => {
153 let name_clone = name.clone();
154 let expr = expr.clone();
155
156 graph = graph.add_node(name, move |state: NodeState| {
159 let n = name_clone.clone();
160 let e = expr.clone();
161 async move {
162 println!("❓ [Condition Node: {}] Expr: {}", n, e);
163 let input = state.as_str().unwrap_or("");
166 if e.contains("contains('error')") {
167 if input.contains("error") {
168 return Ok(NodeState::from_string("error"));
169 } else {
170 return Ok(NodeState::from_string("ok"));
171 }
172 }
173 Ok(NodeState::from_string("default"))
174 }
175 });
176 }
177 _ => {
178 return Err(anyhow!("Unsupported node type in yaml"));
179 }
180 }
181 }
182
183 for (name, node_def) in &def.nodes {
185 let edges = node_def.edges();
186
187 let is_conditional = matches!(node_def, NodeDefinition::Condition { .. });
196
197 if is_conditional {
198 let edges_clone = edges.clone();
200 let router = move |output: &str| -> EdgeTarget {
201 if let Some(target) = edges_clone.get(output) {
202 if target == "END" {
203 EdgeTarget::End
204 } else {
205 EdgeTarget::Node(target.clone())
206 }
207 } else if let Some(default) = edges_clone.get("_default") {
208 if default == "END" {
209 EdgeTarget::End
210 } else {
211 EdgeTarget::Node(default.clone())
212 }
213 } else {
214 EdgeTarget::End
215 }
216 };
217
218 graph = graph.add_conditional_edge(name, router);
219 } else {
220 if let Some(target) = edges.get("_default") {
224 if target == "END" {
225 graph = graph.add_edge_to_end(name);
226 } else {
227 graph = graph.add_edge(name, target);
228 }
229 }
230 }
231 }
232
233 if def.nodes.contains_key("start") {
243 graph = graph.set_entry_point("start");
244 } else if def.nodes.contains_key("input") {
245 graph = graph.set_entry_point("input");
246 }
247
248 Ok(graph)
249 }
250}
251
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::{ChatChoice, ChatMessage, ChatRequest, ChatResponse};
    use async_trait::async_trait;

    /// Provider stub that answers every chat request with one fixed assistant message.
    struct MockProvider {
        response: String,
    }

    impl MockProvider {
        fn new(response: impl Into<String>) -> Self {
            Self {
                response: response.into(),
            }
        }
    }

    #[async_trait]
    impl ModelProvider for MockProvider {
        fn name(&self) -> &str {
            "mock"
        }

        async fn chat(&self, _request: ChatRequest) -> anyhow::Result<ChatResponse> {
            Ok(ChatResponse {
                id: "mock-id".to_string(),
                choices: vec![ChatChoice {
                    index: 0,
                    message: ChatMessage::assistant(&self.response),
                    finish_reason: Some("stop".to_string()),
                }],
                usage: None,
            })
        }
    }

    /// Single `llm` entry node with no explicit model that routes straight to END.
    const SIMPLE_GRAPH_YAML: &str = r#"
name: test-graph
version: "1.0"
nodes:
  start:
    type: llm
    system_prompt: "You are a helpful assistant"
    edges:
      _default: END
"#;

    #[test]
    fn test_load_without_context_creates_placeholder() {
        let graph = GraphLoader::load_from_str(SIMPLE_GRAPH_YAML).unwrap();
        assert!(graph.nodes.contains_key("start"));
    }

    /// The placeholder installed without a provider must fail when executed
    /// rather than silently producing output.
    #[tokio::test]
    async fn test_placeholder_llm_node_fails_at_runtime() {
        let graph = GraphLoader::load_from_str(SIMPLE_GRAPH_YAML).unwrap();
        let compiled = graph.compile().unwrap();

        let result = compiled.run("User input").await;
        assert!(
            result.is_err(),
            "an LLM node loaded without a provider should fail when run"
        );
    }

    #[test]
    fn test_load_with_context_creates_functional_node() {
        let provider = Arc::new(MockProvider::new("Hello!"));
        let ctx = GraphLoaderContext::with_provider(provider);

        let graph = GraphLoader::load_from_str_with_context(SIMPLE_GRAPH_YAML, &ctx).unwrap();
        assert!(graph.nodes.contains_key("start"));
    }

    #[tokio::test]
    async fn test_functional_llm_node_executes() {
        let provider = Arc::new(MockProvider::new("LLM Response"));
        let ctx = GraphLoaderContext::with_provider(provider);

        let graph = GraphLoader::load_from_str_with_context(SIMPLE_GRAPH_YAML, &ctx).unwrap();
        let compiled = graph.compile().unwrap();

        let result = compiled.run("User input").await.unwrap();

        assert_eq!(result.as_str(), Some("LLM Response"));
    }

    #[test]
    fn test_context_builder() {
        let ctx = GraphLoaderContext::new();
        assert!(ctx.provider.is_none());
        assert!(ctx.default_model.is_none());

        let provider = Arc::new(MockProvider::new("test"));
        let ctx = GraphLoaderContext::with_provider(provider);
        assert!(ctx.provider.is_some());
        assert!(ctx.default_model.is_none());
    }

    #[test]
    fn test_context_builder_with_default_model() {
        let provider = Arc::new(MockProvider::new("test"));
        let ctx = GraphLoaderContext::with_provider_and_model(provider, "agent/default-model");
        assert!(ctx.provider.is_some());
        assert_eq!(ctx.default_model.as_deref(), Some("agent/default-model"));
    }

    /// `llm` -> `condition` -> (`function` | END) pipeline exercising all node types.
    const MULTI_NODE_YAML: &str = r#"
name: multi-node-graph
version: "1.0"
nodes:
  start:
    type: llm
    model: gpt-4
    system_prompt: "Process the input"
    edges:
      _default: check
  check:
    type: condition
    expr: "input.contains('error')"
    edges:
      error: handle_error
      ok: END
  handle_error:
    type: function
    action: "echo error handled"
    edges:
      _default: END
"#;

    #[test]
    fn test_multi_node_graph_loading() {
        let provider = Arc::new(MockProvider::new("processed"));
        let ctx = GraphLoaderContext::with_provider(provider);

        let graph = GraphLoader::load_from_str_with_context(MULTI_NODE_YAML, &ctx).unwrap();

        assert!(graph.nodes.contains_key("start"));
        assert!(graph.nodes.contains_key("check"));
        assert!(graph.nodes.contains_key("handle_error"));
    }
}