use super::{Node, NodeState};
use crate::callable::LlmCallable;
use crate::providers::ModelProvider;
use async_trait::async_trait;
use std::sync::Arc;
/// Graph node that hands its input state to an [`LlmCallable`] and turns the
/// callable's textual reply into the next [`NodeState`].
pub struct LlmNode {
    /// Identifier reported by this node; a copy is also given to
    /// `LlmCallable::with_provider` as its first argument.
    name: String,
    /// Callable configured with the provider, system prompt and any optional
    /// model / tool settings applied through the builder methods below.
    callable: LlmCallable,
}

impl LlmNode {
    /// Builds a node named `name` whose callable talks to `provider` using
    /// `system_prompt`. No explicit model is selected here.
    pub fn new(
        name: impl Into<String>,
        system_prompt: impl Into<String>,
        provider: Arc<dyn ModelProvider>,
    ) -> Self {
        let node_name: String = name.into();
        // Fields are evaluated in the order written, so the clone happens
        // before `node_name` is moved into `name`.
        Self {
            callable: LlmCallable::with_provider(node_name.clone(), system_prompt, provider),
            name: node_name,
        }
    }

    /// Same as [`LlmNode::new`], but additionally applies
    /// `LlmCallable::with_model(model)` to the freshly built callable.
    pub fn with_model(
        name: impl Into<String>,
        system_prompt: impl Into<String>,
        model: impl Into<String>,
        provider: Arc<dyn ModelProvider>,
    ) -> Self {
        let mut node = Self::new(name, system_prompt, provider);
        node.callable = node.callable.with_model(model);
        node
    }

    /// Consumes the node and returns it with `tools` added to its callable
    /// through `LlmCallable::add_tools`.
    pub fn with_tools(self, tools: Vec<crate::tool::DynTool>) -> Self {
        let Self { name, callable } = self;
        Self {
            name,
            callable: callable.add_tools(tools),
        }
    }
}
#[async_trait]
impl Node for LlmNode {
    /// Returns the name this node was constructed with.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Feeds the incoming state's data to the LLM callable and returns the
    /// reply wrapped in a new [`NodeState`].
    ///
    /// The JSON data is converted to prompt text as follows:
    /// - a JSON string is used verbatim (no surrounding quotes),
    /// - JSON `null` becomes an empty prompt,
    /// - any other JSON value is serialized with `serde_json::to_string`.
    ///
    /// # Errors
    /// Propagates a serialization failure or any error returned by the
    /// callable's `run`.
    async fn execute(&self, state: NodeState) -> anyhow::Result<NodeState> {
        // `run` comes from the `Callable` trait, which must be in scope.
        use crate::callable::Callable;

        let data = &state.data;
        let prompt = match data.as_str() {
            Some(text) => text.to_owned(),
            None if data.is_null() => String::new(),
            None => serde_json::to_string(data)?,
        };

        let reply = self.callable.run(&prompt).await?;
        Ok(NodeState::from_string(&reply))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::{ChatChoice, ChatMessage, ChatRequest, ChatResponse};
    use async_trait::async_trait;

    /// Provider stub that ignores the request and always answers with one
    /// fixed assistant message.
    struct MockProvider {
        response: String,
    }

    impl MockProvider {
        fn new(response: impl Into<String>) -> Self {
            let response = response.into();
            Self { response }
        }
    }

    #[async_trait]
    impl ModelProvider for MockProvider {
        fn name(&self) -> &str {
            "mock"
        }

        async fn chat(&self, _request: ChatRequest) -> anyhow::Result<ChatResponse> {
            let choice = ChatChoice {
                index: 0,
                message: ChatMessage::assistant(&self.response),
                finish_reason: Some("stop".to_string()),
            };
            Ok(ChatResponse {
                id: "mock-id".to_string(),
                choices: vec![choice],
                usage: None,
            })
        }
    }

    /// Builds an `LlmNode` backed by a `MockProvider` that always replies with `reply`.
    fn mock_node(name: &str, system_prompt: &str, reply: &str) -> LlmNode {
        LlmNode::new(name, system_prompt, Arc::new(MockProvider::new(reply)))
    }

    #[tokio::test]
    async fn test_llm_node_execute() {
        let node = mock_node("test_node", "You are a helpful assistant", "Hello, world!");
        assert_eq!(node.name(), "test_node");

        let output = node
            .execute(NodeState::from_string("Say hello"))
            .await
            .unwrap();
        assert_eq!(output.as_str(), Some("Hello, world!"));
    }

    #[tokio::test]
    async fn test_llm_node_with_json_input() {
        let node = mock_node("json_node", "Process the input", "Processed JSON");
        let payload = NodeState::from_value(serde_json::json!({"key": "value"}));

        let output = node.execute(payload).await.unwrap();
        assert_eq!(output.as_str(), Some("Processed JSON"));
    }

    #[tokio::test]
    async fn test_llm_node_with_empty_input() {
        let node = mock_node("empty_node", "Handle empty input", "Default response");

        let output = node.execute(NodeState::new()).await.unwrap();
        assert_eq!(output.as_str(), Some("Default response"));
    }
}