enact_core/graph/node/
llm.rs1use super::{Node, NodeState};
4use crate::callable::LlmCallable;
5use crate::providers::ModelProvider;
6use async_trait::async_trait;
7use std::sync::Arc;
8
9pub struct LlmNode {
15 name: String,
16 callable: LlmCallable,
17}
18
19impl LlmNode {
20 pub fn new(
22 name: impl Into<String>,
23 system_prompt: impl Into<String>,
24 provider: Arc<dyn ModelProvider>,
25 ) -> Self {
26 let name = name.into();
27 let callable = LlmCallable::with_provider(name.clone(), system_prompt, provider);
28 Self { name, callable }
29 }
30
31 pub fn with_model(
33 name: impl Into<String>,
34 system_prompt: impl Into<String>,
35 model: impl Into<String>,
36 provider: Arc<dyn ModelProvider>,
37 ) -> Self {
38 let name = name.into();
39 let callable =
40 LlmCallable::with_provider(name.clone(), system_prompt, provider).with_model(model);
41 Self { name, callable }
42 }
43
44 pub fn with_tools(mut self, tools: Vec<crate::tool::DynTool>) -> Self {
46 self.callable = self.callable.add_tools(tools);
47 self
48 }
49}
50
51#[async_trait]
52impl Node for LlmNode {
53 fn name(&self) -> &str {
54 &self.name
55 }
56
57 async fn execute(&self, state: NodeState) -> anyhow::Result<NodeState> {
58 let input = match &state.data {
60 serde_json::Value::String(s) => s.clone(),
61 serde_json::Value::Null => String::new(),
62 other => serde_json::to_string(other)?,
63 };
64
65 use crate::callable::Callable;
67 let response = self.callable.run(&input).await?;
68
69 Ok(NodeState::from_string(&response))
71 }
72}
73
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::{ChatChoice, ChatMessage, ChatRequest, ChatResponse};
    use async_trait::async_trait;

    /// Provider stub: ignores the request and always answers with a single,
    /// fixed assistant message.
    struct MockProvider {
        response: String,
    }

    impl MockProvider {
        fn new(response: impl Into<String>) -> Self {
            let response = response.into();
            Self { response }
        }
    }

    #[async_trait]
    impl ModelProvider for MockProvider {
        fn name(&self) -> &str {
            "mock"
        }

        async fn chat(&self, _request: ChatRequest) -> anyhow::Result<ChatResponse> {
            let only_choice = ChatChoice {
                index: 0,
                message: ChatMessage::assistant(&self.response),
                finish_reason: Some("stop".to_string()),
            };
            let reply = ChatResponse {
                id: "mock-id".to_string(),
                choices: vec![only_choice],
                usage: None,
            };
            Ok(reply)
        }
    }

    /// Builds an `LlmNode` whose mock provider always answers with `reply`.
    fn node_replying(name: &str, system_prompt: &str, reply: &str) -> LlmNode {
        let provider = Arc::new(MockProvider::new(reply));
        LlmNode::new(name, system_prompt, provider)
    }

    #[tokio::test]
    async fn test_llm_node_execute() {
        let node = node_replying("test_node", "You are a helpful assistant", "Hello, world!");
        assert_eq!(node.name(), "test_node");

        let output = node
            .execute(NodeState::from_string("Say hello"))
            .await
            .unwrap();
        assert_eq!(output.as_str(), Some("Hello, world!"));
    }

    #[tokio::test]
    async fn test_llm_node_with_json_input() {
        let node = node_replying("json_node", "Process the input", "Processed JSON");

        let payload = NodeState::from_value(serde_json::json!({"key": "value"}));
        let output = node.execute(payload).await.unwrap();
        assert_eq!(output.as_str(), Some("Processed JSON"));
    }

    #[tokio::test]
    async fn test_llm_node_with_empty_input() {
        let node = node_replying("empty_node", "Handle empty input", "Default response");

        let output = node.execute(NodeState::new()).await.unwrap();
        assert_eq!(output.as_str(), Some("Default response"));
    }
}