use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::error::NodeError;
use crate::graph::{NodeExecutor, NodeOutput};
use crate::state::{Message, SharedState, ToolCall};
/// Settings handed to an [`LLMProvider`] alongside the message history on
/// every `generate` call.
///
/// All fields are public; the builder-style methods on this type are a
/// convenience for setting them fluently.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMConfig {
    /// Model identifier string.
    pub model: String,
    /// Optional token limit. Presumably caps the generated output; how it is
    /// applied is up to the provider — TODO confirm.
    pub max_tokens: Option<u32>,
    /// Optional sampling temperature, interpreted by the provider.
    pub temperature: Option<f32>,
    /// Optional system prompt text.
    pub system_prompt: Option<String>,
    /// Optional list of stop sequences.
    pub stop_sequences: Option<Vec<String>>,
    /// Optional API base URL override.
    pub api_base: Option<String>,
    /// Optional API key.
    ///
    /// NOTE(review): this struct derives `Debug` and `Serialize` without
    /// skipping this field, so the key appears in debug output and in
    /// serialized forms.
    pub api_key: Option<String>,
}
impl Default for LLMConfig {
    /// Builds the baseline configuration: the `claude-3-sonnet-20240229`
    /// model, a 4096-token limit, temperature 0.7, and every other option
    /// unset.
    fn default() -> Self {
        let model = String::from("claude-3-sonnet-20240229");

        Self {
            model,
            max_tokens: Some(4096),
            temperature: Some(0.7),
            system_prompt: None,
            stop_sequences: None,
            api_base: None,
            api_key: None,
        }
    }
}
impl LLMConfig {
    /// Creates a configuration for `model`, taking every other field from
    /// [`LLMConfig::default`].
    pub fn new(model: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            ..Default::default()
        }
    }

    /// Sets `max_tokens` to `Some(tokens)` and returns the updated config.
    pub fn max_tokens(mut self, tokens: u32) -> Self {
        self.max_tokens = Some(tokens);
        self
    }

    /// Sets `temperature` to `Some(temp)` and returns the updated config.
    pub fn temperature(mut self, temp: f32) -> Self {
        self.temperature = Some(temp);
        self
    }

    /// Sets `system_prompt` to the given text and returns the updated config.
    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.system_prompt = Some(prompt.into());
        self
    }

    /// Sets `stop_sequences` to the given sequences, replacing any previous
    /// value, and returns the updated config.
    ///
    /// Accepts anything iterable whose items convert into `String`, e.g.
    /// `["\n\n", "END"]` or a `Vec<String>`.
    pub fn stop_sequences<I, S>(mut self, sequences: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.stop_sequences = Some(sequences.into_iter().map(Into::into).collect());
        self
    }

    /// Sets `api_base` to the given URL and returns the updated config.
    pub fn api_base(mut self, url: impl Into<String>) -> Self {
        self.api_base = Some(url.into());
        self
    }

    /// Sets `api_key` to the given key and returns the updated config.
    pub fn api_key(mut self, key: impl Into<String>) -> Self {
        self.api_key = Some(key.into());
        self
    }
}
/// Result of a single [`LLMProvider::generate`] call.
#[derive(Debug, Clone)]
pub struct LLMResponse {
    /// Text content of the reply.
    pub content: String,
    /// Tool invocations carried by the reply; empty when there are none.
    pub tool_calls: Vec<ToolCall>,
    /// Why generation stopped.
    pub stop_reason: StopReason,
    /// Token counts for the call, when available.
    pub usage: Option<TokenUsage>,
}
impl LLMResponse {
    /// Builds a plain-text response: no tool calls, stop reason
    /// [`StopReason::EndTurn`], and no usage information.
    pub fn text(content: impl Into<String>) -> Self {
        let content = content.into();

        Self {
            content,
            tool_calls: Vec::new(),
            stop_reason: StopReason::EndTurn,
            usage: None,
        }
    }

    /// Builds a response that carries `tool_calls`, with stop reason
    /// [`StopReason::ToolUse`] and no usage information.
    ///
    /// The stop reason is `ToolUse` regardless of whether `tool_calls` is
    /// empty.
    pub fn with_tool_calls(content: impl Into<String>, tool_calls: Vec<ToolCall>) -> Self {
        let content = content.into();

        Self {
            content,
            tool_calls,
            stop_reason: StopReason::ToolUse,
            usage: None,
        }
    }
}
/// Reason reported for the end of a generation.
///
/// The variant meanings below are inferred from their names; the exact
/// mapping from provider-specific values is not defined in this file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StopReason {
    /// Used by [`LLMResponse::text`] for ordinary text replies.
    EndTurn,
    /// Presumably: the token limit was reached.
    MaxTokens,
    /// Presumably: a configured stop sequence was produced.
    StopSequence,
    /// Used by [`LLMResponse::with_tool_calls`] for replies carrying tool calls.
    ToolUse,
    /// Any other reason, carried as free-form text.
    Other(String),
}
/// Token counts optionally attached to an `LLMResponse`.
///
/// Both counters default to zero.
#[derive(Debug, Clone, Default)]
pub struct TokenUsage {
    /// Number of input tokens (presumably the prompt side — TODO confirm).
    pub input_tokens: u32,
    /// Number of output tokens (presumably the generated side — TODO confirm).
    pub output_tokens: u32,
}
/// Backend capable of producing an [`LLMResponse`] from a message history.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait LLMProvider: Send + Sync {
    /// Generates a response for `messages` using the settings in `config`.
    ///
    /// # Errors
    ///
    /// Returns a [`NodeError`] when the provider cannot produce a response.
    async fn generate(
        &self,
        messages: &[Message],
        config: &LLMConfig,
    ) -> Result<LLMResponse, NodeError>;

    /// Returns a name for this provider.
    fn name(&self) -> &str;
}
/// Graph node that forwards the conversation to an [`LLMProvider`] and records
/// the reply in shared state.
pub struct LLMNode<P: LLMProvider> {
    /// Identifier returned by the node's `id()` method.
    id: String,
    /// Backend used to generate responses.
    provider: P,
    /// Settings passed to the provider on every call.
    config: LLMConfig,
}
impl<P: LLMProvider> LLMNode<P> {
    /// Creates a node with the given identifier, provider, and configuration.
    pub fn new(id: impl Into<String>, provider: P, config: LLMConfig) -> Self {
        let id = id.into();

        Self {
            id,
            provider,
            config,
        }
    }

    /// Creates a node that uses [`LLMConfig::default`] as its configuration.
    pub fn with_provider(id: impl Into<String>, provider: P) -> Self {
        let config = LLMConfig::default();
        Self::new(id, provider, config)
    }
}
#[async_trait]
impl<P: LLMProvider + 'static> NodeExecutor for LLMNode<P> {
fn id(&self) -> &str {
&self.id
}
async fn execute(&self, state: SharedState) -> Result<NodeOutput, NodeError> {
let messages = {
let guard = state
.read()
.map_err(|e| NodeError::execution_failed(e.to_string()))?;
guard.messages.clone()
};
let response = self.provider.generate(&messages, &self.config).await?;
{
let mut guard = state
.write()
.map_err(|e| NodeError::execution_failed(e.to_string()))?;
guard.add_assistant_message(&response.content);
if !response.tool_calls.is_empty() {
guard.tool_calls = response.tool_calls;
}
if guard.tool_calls.is_empty() {
guard.mark_complete();
}
}
Ok(NodeOutput::cont())
}
fn description(&self) -> Option<&str> {
Some("Calls an LLM provider to generate a response")
}
}
pub struct MockLLMProvider {
responses: std::collections::VecDeque<LLMResponse>,
}
impl MockLLMProvider {
pub fn new() -> Self {
Self {
responses: std::collections::VecDeque::new(),
}
}
pub fn with_response(mut self, response: LLMResponse) -> Self {
self.responses.push_back(response);
self
}
pub fn with_text_response(self, text: impl Into<String>) -> Self {
self.with_response(LLMResponse::text(text))
}
}
impl Default for MockLLMProvider {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl LLMProvider for MockLLMProvider {
async fn generate(
&self,
_messages: &[Message],
_config: &LLMConfig,
) -> Result<LLMResponse, NodeError> {
Ok(LLMResponse::text("Mock response"))
}
fn name(&self) -> &str {
"mock"
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::state::AgentState;
    use std::sync::{Arc, RwLock};

    /// Builder setters override the values they target.
    #[test]
    fn test_llm_config() {
        let built = LLMConfig::new("gpt-4")
            .max_tokens(1000)
            .temperature(0.5)
            .system_prompt("You are helpful");

        assert_eq!(built.model, "gpt-4");
        assert_eq!(built.max_tokens, Some(1000));
        assert_eq!(built.temperature, Some(0.5));
        assert_eq!(built.system_prompt, Some("You are helpful".to_string()));
    }

    /// The two `LLMResponse` constructors set the expected stop reasons.
    #[test]
    fn test_llm_response() {
        let plain = LLMResponse::text("Hello");
        assert_eq!(plain.content, "Hello");
        assert!(plain.tool_calls.is_empty());
        assert_eq!(plain.stop_reason, StopReason::EndTurn);

        let weather_call = ToolCall::new("1", "get_weather", serde_json::json!({}));
        let with_tools = LLMResponse::with_tool_calls("Checking weather", vec![weather_call]);
        assert_eq!(with_tools.tool_calls.len(), 1);
        assert_eq!(with_tools.stop_reason, StopReason::ToolUse);
    }

    /// Running the node against the mock provider appends one assistant
    /// message and marks the state complete.
    #[tokio::test]
    async fn test_llm_node_with_mock() {
        let node = LLMNode::with_provider("llm", MockLLMProvider::new());

        let mut initial = AgentState::new();
        initial.add_user_message("Hello");
        let shared = Arc::new(RwLock::new(initial));

        let output = node.execute(Arc::clone(&shared)).await.unwrap();
        assert!(!output.is_terminal());

        let final_state = shared.read().unwrap();
        assert_eq!(final_state.messages.len(), 2);
        assert!(final_state.is_complete);
    }
}