use std::sync::Arc;
use async_trait::async_trait;
use rune_chain_core::{Chain, ChainError, GenerateResult, Llm, Message, PromptArgs, Tool};
/// Iteration cap installed by `AgentExecutor::new`; override it with `AgentExecutor::max_iterations`.
const DEFAULT_MAX_ITERATIONS: usize = 10;
/// Line prefix the model uses to name the tool it wants to call.
const ACTION_PREFIX: &str = "Action:";
/// Line prefix carrying the input to pass to the requested tool.
const ACTION_INPUT_PREFIX: &str = "Action Input:";
/// Line prefix marking the model's final answer; seeing it ends the agent loop.
const FINAL_ANSWER_PREFIX: &str = "Final Answer:";
/// Outcome of a single [`AgentExecutor::run`] call.
#[derive(Debug, Clone)]
pub struct AgentResult {
/// The extracted final answer when the model produced a `Final Answer:` line;
/// otherwise the last (trimmed) model response, returned when the iteration cap was reached.
pub output: String,
/// Number of model responses that did not contain a final answer.
pub iterations: usize,
/// Trace of the run: every trimmed model response in order, each followed by an
/// `Observation: ...` entry when a tool was actually invoked for it.
pub scratchpad: Vec<String>,
}
/// Agent that repeatedly queries an LLM using a Thought / Action / Action Input /
/// Observation text protocol. It invokes the named tools until the model emits a
/// `Final Answer:` or the iteration cap is hit. Configure it with the builder methods.
pub struct AgentExecutor {
/// Model queried once per loop iteration with the full conversation so far.
llm: Arc<dyn Llm>,
/// Tools the model may call by name; looked up in registration order.
tools: Vec<Box<dyn Tool>>,
/// Maximum number of non-final model responses before `run` stops.
max_iterations: usize,
/// Custom system prompt; when `None`, one is generated from `tools`.
system_prompt: Option<String>,
}
impl AgentExecutor {
    /// Creates an agent backed by `llm` with no tools, no custom system prompt,
    /// and an iteration cap of `DEFAULT_MAX_ITERATIONS`.
    pub fn new(llm: Arc<dyn Llm>) -> Self {
        Self {
            system_prompt: None,
            max_iterations: DEFAULT_MAX_ITERATIONS,
            tools: Vec::new(),
            llm,
        }
    }

    /// Registers `tool` so the model can request it by name (builder style).
    ///
    /// Lookup is by `Tool::name` in registration order, so when two tools share a
    /// name the one added first is the one that runs.
    pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
        let boxed: Box<dyn Tool> = Box::new(tool);
        self.tools.push(boxed);
        self
    }

    /// Sets how many non-final model responses [`run`](Self::run) tolerates before
    /// it stops and returns the latest response as-is.
    ///
    /// The model is still called once even when this is 0.
    pub fn max_iterations(self, max_iterations: usize) -> Self {
        Self {
            max_iterations,
            ..self
        }
    }

    /// Replaces the auto-generated system prompt with `system_prompt`.
    ///
    /// When set, the tool list is no longer injected into the prompt automatically.
    pub fn system_prompt(self, system_prompt: impl Into<String>) -> Self {
        Self {
            system_prompt: Some(system_prompt.into()),
            ..self
        }
    }

    /// Runs the agent loop on `input` until the model gives a final answer or the
    /// iteration cap is reached.
    ///
    /// Each turn sends the whole conversation to the model, and the outcome depends
    /// on the response:
    /// - A `Final Answer:` ends the run immediately.
    /// - If the response names a tool via `Action:` / `Action Input:`, that tool is
    ///   invoked. The response and its `Observation:` line are then appended as one
    ///   assistant message.
    /// - A response with no recognizable action is appended unchanged.
    ///
    /// The response that exhausts the cap is returned as the output without running
    /// any action it requests.
    ///
    /// # Errors
    /// Propagates any error from the model's `generate` call, converted into
    /// [`ChainError`] through `?`.
    pub async fn run(&self, input: &str) -> Result<AgentResult, ChainError> {
        let system_content = match &self.system_prompt {
            Some(custom) => custom.clone(),
            None => build_system_prompt(&self.tools),
        };
        let mut messages = vec![Message::system(system_content), Message::human(input)];
        let mut scratchpad: Vec<String> = Vec::new();
        let mut iterations = 0;

        let output = loop {
            let generated = self.llm.generate(&messages).await?;
            let reply = generated.generation.trim().to_string();
            scratchpad.push(reply.clone());

            if let Some(answer) = extract_final_answer(&reply) {
                return Ok(AgentResult {
                    output: answer,
                    iterations,
                    scratchpad,
                });
            }

            iterations += 1;
            if iterations >= self.max_iterations {
                // Cap reached: hand back the raw response without running its action.
                break reply;
            }

            let assistant_turn = match extract_action(&reply) {
                Some((tool_name, tool_input)) => {
                    let observation =
                        format!("Observation: {}", self.invoke_tool(&tool_name, &tool_input));
                    scratchpad.push(observation.clone());
                    format!("{reply}\n{observation}")
                }
                // Unparseable turn: keep it in the history unchanged.
                None => reply,
            };
            messages.push(Message::ai(assistant_turn));
        };

        Ok(AgentResult {
            output,
            iterations,
            scratchpad,
        })
    }

    /// Runs the first registered tool whose name equals `name` on `input`.
    ///
    /// An unknown name does not abort the loop. Instead the returned string is an
    /// `Error: unknown tool '...'` message, which `run` records as the observation.
    fn invoke_tool(&self, name: &str, input: &str) -> String {
        match self.tools.iter().find(|candidate| candidate.name() == name) {
            Some(tool) => tool.run(input),
            None => format!("Error: unknown tool '{name}'"),
        }
    }
}
#[async_trait]
impl Chain for AgentExecutor {
async fn call(&self, input_variables: PromptArgs) -> Result<GenerateResult, ChainError> {
let input = input_variables
.get("input")
.and_then(|v| v.as_str())
.ok_or_else(|| ChainError::MissingVariable("input".to_string()))?;
let agent_result = self.run(input).await?;
Ok(GenerateResult::from_text(agent_result.output))
}
fn input_keys(&self) -> Vec<String> {
vec!["input".to_string()]
}
}
/// Builds the default system prompt.
///
/// The prompt lists each tool as `- name: description`, or a placeholder when there
/// are none. It also spells out the Thought / Action / Action Input / Observation /
/// Final Answer format that the line parsers in this module look for.
fn build_system_prompt(tools: &[Box<dyn Tool>]) -> String {
    let tool_list = match tools {
        [] => String::from("(no tools available)"),
        registered => registered
            .iter()
            .map(|entry| format!("- {}: {}", entry.name(), entry.description()))
            .collect::<Vec<String>>()
            .join("\n"),
    };

    // One array element per prompt line; empty strings produce the blank separator lines.
    [
        "You are an agent that solves problems step by step using tools.",
        "",
        "Available tools:",
        tool_list.as_str(),
        "",
        "Use this format:",
        "Thought: <reason about what to do>",
        "Action: <tool_name>",
        "Action Input: <input to the tool>",
        "Observation: <result of the action>",
        "... (repeat Thought/Action/Action Input/Observation as needed)",
        "Thought: I now have enough information.",
        "Final Answer: <your final answer>",
        "",
        "Begin!",
    ]
    .join("\n")
}
fn extract_final_answer(response: &str) -> Option<String> {
response
.lines()
.find(|line| line.trim_start().starts_with(FINAL_ANSWER_PREFIX))
.map(|line| {
line.trim_start()
.trim_start_matches(FINAL_ANSWER_PREFIX)
.trim()
.to_string()
})
}
/// Parses the tool call requested by a model response.
///
/// Scans the lines, ignoring leading whitespace, and pairs the first `Action Input:`
/// line with the most recent `Action:` line before it. It returns
/// `(tool_name, tool_input)` with both values trimmed, and each prefix is stripped
/// exactly once.
///
/// Parsing stops at that first complete step, which has two consequences:
/// - Steps the model hallucinated afterwards (for example a fake `Observation:`
///   followed by another `Action:`) are ignored.
/// - A trailing `Action:` line can never be paired with an earlier step's input.
///
/// `Action Input:` lines that appear before any `Action:` are skipped. Returns `None`
/// when no complete `Action:` / `Action Input:` pair is found.
fn extract_action(response: &str) -> Option<(String, String)> {
    let mut pending_action: Option<&str> = None;
    for line in response.lines().map(str::trim_start) {
        // "Action Input:" never matches "Action:" (space vs colon at byte 6), so the
        // two prefixes can be tested independently.
        if let Some(name) = line.strip_prefix(ACTION_PREFIX) {
            pending_action = Some(name.trim());
        } else if let Some(input) = line.strip_prefix(ACTION_INPUT_PREFIX) {
            if let Some(name) = pending_action {
                return Some((name.to_string(), input.trim().to_string()));
            }
        }
    }
    None
}
#[cfg(test)]
mod tests {
    use super::*;
    use rune_chain_core::prompt_args;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Test LLM that replies to every `generate` call with the same canned text.
    struct FixedLlm(&'static str);

    #[async_trait::async_trait]
    impl Llm for FixedLlm {
        async fn generate(
            &self,
            _messages: &[Message],
        ) -> Result<GenerateResult, rune_chain_core::LlmError> {
            Ok(GenerateResult::from_text(self.0))
        }
    }

    /// Builds a tool-less agent whose model always answers with `reply`.
    fn agent_answering(reply: &'static str) -> AgentExecutor {
        AgentExecutor::new(Arc::new(FixedLlm(reply)))
    }

    /// Test tool that echoes its input back as `got: <input>`.
    struct Dummy;

    impl Tool for Dummy {
        fn name(&self) -> &str {
            "dummy"
        }
        fn description(&self) -> &str {
            "A dummy tool."
        }
        fn run(&self, input: &str) -> String {
            format!("got: {input}")
        }
    }

    #[test]
    fn extract_final_answer_finds_answer() {
        let response = "Thought: I know the answer.\nFinal Answer: 42";
        assert_eq!(extract_final_answer(response), Some("42".to_string()));
    }

    #[test]
    fn extract_final_answer_returns_none_when_absent() {
        let response = "Thought: I need to think more.\nAction: calculator\nAction Input: 2+2";
        assert!(extract_final_answer(response).is_none());
    }

    #[test]
    fn extract_action_parses_tool_call() {
        let response = "Thought: Let me calculate.\nAction: calculator\nAction Input: 2 + 2";
        assert_eq!(
            extract_action(response),
            Some(("calculator".to_string(), "2 + 2".to_string()))
        );
    }

    #[test]
    fn extract_action_returns_none_when_no_action() {
        let response = "Thought: I am thinking.\nFinal Answer: done";
        assert!(extract_action(response).is_none());
    }

    #[test]
    fn extract_action_requires_both_fields() {
        let response = "Thought: Let me try.\nAction: calculator";
        assert!(extract_action(response).is_none());
    }

    #[test]
    fn build_system_prompt_no_tools() {
        let prompt = build_system_prompt(&[]);
        assert!(prompt.contains("(no tools available)"));
        assert!(prompt.contains("Final Answer:"));
    }

    #[test]
    fn build_system_prompt_with_tools() {
        let tools: Vec<Box<dyn Tool>> = vec![Box::new(Dummy)];
        let prompt = build_system_prompt(&tools);
        assert!(prompt.contains("dummy"));
        assert!(prompt.contains("A dummy tool."));
    }

    #[test]
    fn invoke_tool_unknown_name_returns_error_string() {
        let executor = agent_answering("Final Answer: done").tool(Dummy);
        let result = executor.invoke_tool("nonexistent", "anything");
        assert!(result.contains("unknown tool"));
    }

    #[test]
    fn invoke_tool_known_name_calls_tool() {
        let executor = agent_answering("Final Answer: done").tool(Dummy);
        assert_eq!(executor.invoke_tool("dummy", "hello"), "got: hello");
    }

    #[tokio::test]
    async fn agent_run_returns_final_answer() {
        let agent = agent_answering("Thought: I know this.\nFinal Answer: Paris");
        let result = agent.run("What is the capital of France?").await.unwrap();
        assert_eq!(result.output, "Paris");
        assert_eq!(result.iterations, 0);
        assert_eq!(result.scratchpad.len(), 1);
    }

    #[tokio::test]
    async fn agent_run_calls_tool_then_answers() {
        /// Test LLM that requests the `dummy` tool on its first call and answers afterwards.
        struct SequencedLlm {
            calls: AtomicUsize,
        }

        #[async_trait::async_trait]
        impl Llm for SequencedLlm {
            async fn generate(
                &self,
                _messages: &[Message],
            ) -> Result<GenerateResult, rune_chain_core::LlmError> {
                let call = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
                let response = if call == 1 {
                    "Thought: I need the tool.\nAction: dummy\nAction Input: test"
                } else {
                    "Thought: Got the observation.\nFinal Answer: got: test"
                };
                Ok(GenerateResult::from_text(response))
            }
        }

        let llm = Arc::new(SequencedLlm {
            calls: AtomicUsize::new(0),
        });
        let agent = AgentExecutor::new(llm).tool(Dummy);
        let result = agent.run("Use the dummy tool with 'test'.").await.unwrap();
        assert_eq!(result.output, "got: test");
        assert_eq!(result.iterations, 1);
        assert!(result
            .scratchpad
            .iter()
            .any(|entry| entry.contains("Observation: got: test")));
    }

    #[tokio::test]
    async fn agent_stops_at_max_iterations() {
        let agent = agent_answering("Thought: Still thinking.\nAction: dummy\nAction Input: x")
            .tool(Dummy)
            .max_iterations(3);
        let result = agent.run("Will this loop?").await.unwrap();
        assert_eq!(result.iterations, 3);
        assert!(!result.output.is_empty());
    }

    #[tokio::test]
    async fn chain_call_reads_input_key() {
        let agent = agent_answering("Final Answer: 42");
        let result = agent
            .call(prompt_args! { "input" => "What is 6 * 7?" })
            .await
            .unwrap();
        assert_eq!(result.generation, "42");
    }

    #[tokio::test]
    async fn chain_call_missing_input_returns_error() {
        let agent = agent_answering("Final Answer: ok");
        let err = agent.call(prompt_args! {}).await.unwrap_err();
        assert!(matches!(err, ChainError::MissingVariable(_)));
    }
}