use radkit::agent::{LlmFunction, LlmWorker};
use radkit::macros::LLMOutput;
use radkit::models::{Content, ContentPart, Event, LlmResponse, Thread, TokenUsage};
use radkit::test_support::{structured_response, FakeLlm, RecordingTool};
use radkit::tools::{FunctionTool, ToolCall, ToolContext, ToolResult};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::collections::{HashMap, VecDeque};
// Structured-output fixture: the type parameter given to `LlmFunction` in
// `test_llm_function_with_system_instructions`.
// NOTE(review): plain `//` comments are used on purpose; `///` doc comments
// are presumably picked up by the `JsonSchema` derive as schema descriptions,
// which would alter the generated schema — confirm before converting.
#[derive(Debug, PartialEq, Deserialize, LLMOutput, Serialize, JsonSchema)]
struct SentimentAnalysis {
// Sentiment label; the test fixture uses "positive".
sentiment: String,
// Confidence score; the test fixture uses 0.95.
confidence: f64,
}
/// Checks that an `LlmFunction` built with `new_with_system_instructions`
/// decodes the canned structured reply into `SentimentAnalysis`, and that the
/// single recorded request carries both the custom instructions and the
/// schema instructions in its system prompt.
#[tokio::test]
async fn test_llm_function_with_system_instructions() {
    // The one structured reply the fake model is scripted with.
    let canned = SentimentAnalysis {
        sentiment: "positive".to_string(),
        confidence: 0.95,
    };
    let fake = FakeLlm::with_responses("fake-llm", [Ok(structured_response(&canned))]);

    let analyzer = LlmFunction::<SentimentAnalysis>::new_with_system_instructions(
        fake.clone(),
        "You are a sentiment analyzer.",
    );

    let SentimentAnalysis {
        sentiment,
        confidence,
    } = analyzer
        .run(Thread::from_user("Analyze: I love this!"))
        .await
        .unwrap();
    assert_eq!(sentiment, "positive");
    assert_eq!(confidence, 0.95);

    // Inspect what was actually sent to the fake model.
    let recorded = fake.calls();
    assert_eq!(recorded.len(), 1);

    let prompt = recorded[0].system().expect("system instructions present");
    assert!(
        prompt.contains("You are a sentiment analyzer."),
        "Expected system instructions to contain custom instructions"
    );
    assert!(
        prompt.contains("JSON object matching the following schema"),
        "Expected system instructions to contain schema instructions"
    );
}
// Structured-output fixture: the type parameter given to `LlmFunction` in
// `test_llm_function_multi_turn_conversation`.
// NOTE(review): plain `//` comments on purpose; `///` docs may end up in the
// derived JSON schema — confirm before converting.
#[derive(Debug, PartialEq, Deserialize, LLMOutput, Serialize, JsonSchema)]
struct Translation {
// Translated text; the test asserts on this field after each turn.
translated_text: String,
// Source language name; the fixtures set it to "English".
source_language: String,
}
/// Runs two translations through one `LlmFunction`: the first via
/// `run_and_continue`, the second via `run` on the returned thread extended
/// with a new user event. Asserts both translated texts and that the fake
/// model recorded two calls.
#[tokio::test]
async fn test_llm_function_multi_turn_conversation() {
    // One canned structured reply per expected model call.
    let first_reply = Translation {
        translated_text: "Hola mundo".to_string(),
        source_language: "English".to_string(),
    };
    let second_reply = Translation {
        translated_text: "Buenos días".to_string(),
        source_language: "English".to_string(),
    };
    let fake = FakeLlm::with_responses(
        "fake-llm",
        [
            Ok(structured_response(&first_reply)),
            Ok(structured_response(&second_reply)),
        ],
    );
    let translator = LlmFunction::<Translation>::new(fake.clone());

    // Turn 1: keep the returned thread so the conversation can be extended.
    let opening = Thread::from_user("Translate to Spanish: Hello world");
    let (first, history) = translator.run_and_continue(opening).await.unwrap();
    assert_eq!(first.translated_text, "Hola mundo");

    // Turn 2: append a follow-up user event to the returned thread.
    let follow_up = history.add_event(Event::user("Now translate: Good morning"));
    let second = translator.run(follow_up).await.unwrap();
    assert_eq!(second.translated_text, "Buenos días");

    assert_eq!(fake.calls().len(), 2);
}
// Structured-output fixture: the type parameter given to `LlmWorker` in
// `test_llm_worker_with_tool`.
// NOTE(review): plain `//` comments on purpose; `///` docs may end up in the
// derived JSON schema — confirm before converting.
#[derive(Debug, PartialEq, Deserialize, LLMOutput, Serialize, JsonSchema)]
struct WeatherReport {
// Location name; the fixture uses "Seattle".
location: String,
// Temperature reading; the fixture uses 72.5 (unit not specified here).
temperature: f64,
// Short condition label; the fixture uses "sunny".
condition: String,
}
/// Scripts the fake model to first request the `get_weather` tool and then
/// return a structured `WeatherReport`, runs an `LlmWorker` that has that tool
/// registered, and asserts the worker yields the scripted report.
#[tokio::test]
async fn test_llm_worker_with_tool() {
    // Tool under test: echoes the requested location with fixed readings.
    let weather_tool = FunctionTool::new(
        "get_weather",
        "Get current weather",
        |args: HashMap<String, serde_json::Value>, _ctx: &ToolContext| {
            Box::pin(async move {
                // Fall back to "Unknown" when the argument is missing or not a string.
                let location = args
                    .get("location")
                    .and_then(serde_json::Value::as_str)
                    .unwrap_or("Unknown");
                ToolResult::success(json!({
                    "temperature": 72.5,
                    "condition": "sunny",
                    "location": location
                }))
            })
        },
    );

    // First scripted reply: a tool call for Seattle's weather.
    let weather_request = ToolCall::new("call-1", "get_weather", json!({"location": "Seattle"}));
    let tool_call_reply = LlmResponse::new(
        Content::from_parts(vec![ContentPart::ToolCall(weather_request)]),
        TokenUsage::empty(),
    );

    // Second scripted reply: the final structured answer.
    let expected_report = WeatherReport {
        location: "Seattle".to_string(),
        temperature: 72.5,
        condition: "sunny".to_string(),
    };

    let fake = FakeLlm::with_responses(
        "fake-llm",
        [
            Ok(tool_call_reply),
            Ok(structured_response(&expected_report)),
        ],
    );
    let worker = LlmWorker::<WeatherReport>::builder(fake)
        .with_tool(weather_tool)
        .build();

    let WeatherReport {
        location,
        temperature,
        condition,
    } = worker
        .run(Thread::from_user("What's the weather in Seattle?"))
        .await
        .unwrap();
    assert_eq!(location, "Seattle");
    assert_eq!(temperature, 72.5);
    assert_eq!(condition, "sunny");
}
// Structured-output fixture: the type parameter given to `LlmWorker` in
// `test_llm_worker_multiple_tool_calls`.
// NOTE(review): plain `//` comments on purpose; `///` docs may end up in the
// derived JSON schema — confirm before converting.
#[derive(Debug, PartialEq, Deserialize, LLMOutput, Serialize, JsonSchema)]
struct CalculationResult {
// Final numeric answer; the fixture uses 20.0.
result: f64,
// One human-readable string per calculation step, e.g. "add(2, 3) = 5".
steps: Vec<String>,
}
/// Scripts the fake model with two successive tool calls (`add`, then
/// `multiply`) followed by a structured `CalculationResult`, runs an
/// `LlmWorker` with both tools registered, and asserts on the final result
/// value and the number of recorded steps.
#[tokio::test]
async fn test_llm_worker_multiple_tool_calls() {
    // Arithmetic tools; a missing or non-numeric operand is treated as 0.0.
    let add_tool = FunctionTool::new(
        "add",
        "Add two numbers",
        |args: HashMap<String, serde_json::Value>, _ctx: &ToolContext| {
            Box::pin(async move {
                let operand = |key: &str| {
                    args.get(key)
                        .and_then(serde_json::Value::as_f64)
                        .unwrap_or(0.0)
                };
                let sum = operand("a") + operand("b");
                ToolResult::success(json!(sum))
            })
        },
    );
    let multiply_tool = FunctionTool::new(
        "multiply",
        "Multiply two numbers",
        |args: HashMap<String, serde_json::Value>, _ctx: &ToolContext| {
            Box::pin(async move {
                let operand = |key: &str| {
                    args.get(key)
                        .and_then(serde_json::Value::as_f64)
                        .unwrap_or(0.0)
                };
                let product = operand("a") * operand("b");
                ToolResult::success(json!(product))
            })
        },
    );

    // Scripted replies: add(2, 3), then multiply(5, 4), then the final answer.
    let add_call = ToolCall::new("call-1", "add", json!({"a": 2, "b": 3}));
    let add_reply = LlmResponse::new(
        Content::from_parts(vec![ContentPart::ToolCall(add_call)]),
        TokenUsage::empty(),
    );
    let multiply_call = ToolCall::new("call-2", "multiply", json!({"a": 5, "b": 4}));
    let multiply_reply = LlmResponse::new(
        Content::from_parts(vec![ContentPart::ToolCall(multiply_call)]),
        TokenUsage::empty(),
    );
    let final_answer = CalculationResult {
        result: 20.0,
        steps: vec![
            "add(2, 3) = 5".to_string(),
            "multiply(5, 4) = 20".to_string(),
        ],
    };

    let fake = FakeLlm::with_responses(
        "fake-llm",
        [
            Ok(add_reply),
            Ok(multiply_reply),
            Ok(structured_response(&final_answer)),
        ],
    );
    let worker = LlmWorker::<CalculationResult>::builder(fake)
        .with_tool(add_tool)
        .with_tool(multiply_tool)
        .build();

    let outcome = worker
        .run(Thread::from_user("Calculate: (2 + 3) * 4"))
        .await
        .unwrap();
    assert_eq!(outcome.result, 20.0);
    assert_eq!(outcome.steps.len(), 2);
}
// Structured-output fixture: the type parameter given to `LlmWorker` in
// `test_llm_worker_tool_execution_verification`.
// NOTE(review): plain `//` comments on purpose; `///` docs may end up in the
// derived JSON schema — confirm before converting.
#[derive(Debug, Deserialize, LLMOutput, Serialize, JsonSchema)]
struct SearchResult {
// The query string; the fixture uses "rust async".
query: String,
// Flag asserted to be true by the test.
found: bool,
}
/// Runs an `LlmWorker` against a `RecordingTool` and verifies, besides the
/// final structured `SearchResult`, that the tool recorded exactly one call
/// whose `query` argument is "rust async".
#[tokio::test]
async fn test_llm_worker_tool_execution_verification() {
    // The recording tool is handed a single canned result.
    let canned_results = VecDeque::from([ToolResult::success(
        json!({"results": ["item1", "item2"]}),
    )]);
    let search_tool = RecordingTool::new("search", "Search for information", canned_results);

    // Scripted replies: one `search` tool call, then the final answer.
    let search_call = ToolCall::new("call-1", "search", json!({"query": "rust async"}));
    let tool_call_reply = LlmResponse::new(
        Content::from_parts(vec![ContentPart::ToolCall(search_call)]),
        TokenUsage::empty(),
    );
    let final_answer = SearchResult {
        query: "rust async".to_string(),
        found: true,
    };
    let fake = FakeLlm::with_responses(
        "fake-llm",
        [
            Ok(tool_call_reply),
            Ok(structured_response(&final_answer)),
        ],
    );

    // A clone goes to the worker; the original handle is kept for inspection below.
    let worker = LlmWorker::<SearchResult>::builder(fake)
        .with_tool(search_tool.clone())
        .build();

    let outcome = worker
        .run(Thread::from_user("Search for rust async"))
        .await
        .unwrap();
    assert_eq!(outcome.query, "rust async");
    assert!(outcome.found);

    // Inspect what the tool recorded.
    assert_eq!(search_tool.call_count(), 1);
    let recorded = search_tool.calls();
    assert_eq!(recorded[0]["query"], "rust async");
}
/// Scripts the fake model with 25 identical tool-call replies and configures
/// the worker with `with_max_iterations(5)`; asserts that `run` returns an
/// error whose message mentions "iteration" or "limit".
#[tokio::test]
async fn test_llm_worker_respects_max_iterations() {
    // Output type for the worker; the test only inspects the error path.
    #[derive(Debug, Deserialize, LLMOutput, Serialize, JsonSchema)]
    struct SimpleResponse {
        done: bool,
    }

    let infinite_tool = FunctionTool::new(
        "infinite_tool",
        "Tool that never ends",
        |_args: HashMap<String, serde_json::Value>, _ctx: &ToolContext| {
            Box::pin(async move { ToolResult::success(json!({"status": "continue"})) })
        },
    );

    // Every scripted reply is the same tool call; more are queued (25) than
    // the configured iteration limit (5).
    let looping_call = ToolCall::new("call-1", "infinite_tool", json!({}));
    let endless_tool_call = LlmResponse::new(
        Content::from_parts(vec![ContentPart::ToolCall(looping_call)]),
        TokenUsage::empty(),
    );
    let responses: Vec<_> = std::iter::repeat_with(|| Ok(endless_tool_call.clone()))
        .take(25)
        .collect();
    let fake = FakeLlm::with_responses("fake-llm", responses);

    let worker = LlmWorker::<SimpleResponse>::builder(fake)
        .with_tool(infinite_tool)
        .with_max_iterations(5)
        .build();

    let outcome = worker.run(Thread::from_user("Start infinite loop")).await;
    assert!(outcome.is_err());

    // Stringify the error once and accept either wording.
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("iteration") || message.contains("limit"));
}