// enact_core/graph/node/llm.rs

1//! LLM Node - wraps LlmCallable for use in graphs
2
3use super::{Node, NodeState};
4use crate::callable::LlmCallable;
5use crate::providers::ModelProvider;
6use async_trait::async_trait;
7use std::sync::Arc;
8
/// LLM Node - executes an LLM callable within a graph
///
/// This node wraps an `LlmCallable` and adapts it to the `Node` interface.
/// The input `NodeState` is converted to a prompt, and the LLM response
/// is converted back to a `NodeState` for downstream nodes.
pub struct LlmNode {
    // Node identifier; returned by `Node::name` and also used as the
    // wrapped callable's name (see `LlmNode::new`).
    name: String,
    // Underlying callable that performs the actual LLM chat request.
    callable: LlmCallable,
}
18
19impl LlmNode {
20    /// Create a new LLM node with a provider
21    pub fn new(
22        name: impl Into<String>,
23        system_prompt: impl Into<String>,
24        provider: Arc<dyn ModelProvider>,
25    ) -> Self {
26        let name = name.into();
27        let callable = LlmCallable::with_provider(name.clone(), system_prompt, provider);
28        Self { name, callable }
29    }
30
31    /// Create with explicit model pin.
32    pub fn with_model(
33        name: impl Into<String>,
34        system_prompt: impl Into<String>,
35        model: impl Into<String>,
36        provider: Arc<dyn ModelProvider>,
37    ) -> Self {
38        let name = name.into();
39        let callable =
40            LlmCallable::with_provider(name.clone(), system_prompt, provider).with_model(model);
41        Self { name, callable }
42    }
43
44    /// Add tools to the underlying callable
45    pub fn with_tools(mut self, tools: Vec<crate::tool::DynTool>) -> Self {
46        self.callable = self.callable.add_tools(tools);
47        self
48    }
49}
50
51#[async_trait]
52impl Node for LlmNode {
53    fn name(&self) -> &str {
54        &self.name
55    }
56
57    async fn execute(&self, state: NodeState) -> anyhow::Result<NodeState> {
58        // Convert NodeState to input string for the LLM
59        let input = match &state.data {
60            serde_json::Value::String(s) => s.clone(),
61            serde_json::Value::Null => String::new(),
62            other => serde_json::to_string(other)?,
63        };
64
65        // Execute the callable
66        use crate::callable::Callable;
67        let response = self.callable.run(&input).await?;
68
69        // Convert response back to NodeState
70        Ok(NodeState::from_string(&response))
71    }
72}
73
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::{ChatChoice, ChatMessage, ChatRequest, ChatResponse};
    use async_trait::async_trait;

    /// Provider stub that always replies with a single canned message.
    struct MockProvider {
        response: String,
    }

    impl MockProvider {
        fn new(response: impl Into<String>) -> Self {
            let response = response.into();
            Self { response }
        }
    }

    #[async_trait]
    impl ModelProvider for MockProvider {
        fn name(&self) -> &str {
            "mock"
        }

        async fn chat(&self, _request: ChatRequest) -> anyhow::Result<ChatResponse> {
            // Single-choice response echoing the canned string.
            let choice = ChatChoice {
                index: 0,
                message: ChatMessage::assistant(&self.response),
                finish_reason: Some("stop".to_string()),
            };
            Ok(ChatResponse {
                id: "mock-id".to_string(),
                choices: vec![choice],
                usage: None,
            })
        }
    }

    #[tokio::test]
    async fn test_llm_node_execute() {
        let node = LlmNode::new(
            "test_node",
            "You are a helpful assistant",
            Arc::new(MockProvider::new("Hello, world!")),
        );

        assert_eq!(node.name(), "test_node");

        let output = node
            .execute(NodeState::from_string("Say hello"))
            .await
            .unwrap();
        assert_eq!(output.as_str(), Some("Hello, world!"));
    }

    #[tokio::test]
    async fn test_llm_node_with_json_input() {
        let node = LlmNode::new(
            "json_node",
            "Process the input",
            Arc::new(MockProvider::new("Processed JSON")),
        );

        let state = NodeState::from_value(serde_json::json!({"key": "value"}));
        let output = node.execute(state).await.unwrap();
        assert_eq!(output.as_str(), Some("Processed JSON"));
    }

    #[tokio::test]
    async fn test_llm_node_with_empty_input() {
        let node = LlmNode::new(
            "empty_node",
            "Handle empty input",
            Arc::new(MockProvider::new("Default response")),
        );

        let output = node.execute(NodeState::new()).await.unwrap();
        assert_eq!(output.as_str(), Some("Default response"));
    }
}
146}