use std::fmt::Write as _;
use serde_json::{json, Value};
use super::corpus;
use crate::agent::{AgentCommandResult, AgentError, AgentWorkspace, AgentWorkspaceConfig};
use crate::protocol::{
create_chat_completion_with_solver, ChatCompletionRequest, ChatMessage, ToolCall,
};
use crate::solver::{SolverConfig, UniversalSolver};
/// Names of the tools the driver advertises to the solver; each one has a matching
/// dispatch arm in `execute_tool_call`.
pub const DRIVER_TOOLS: [&str; 4] = ["web_search", "web_fetch", "write_file", "run_command"];
/// Upper bound on completion requests per task; reaching it ends the run with
/// `hit_turn_cap` set.
const MAX_TURNS: usize = 12;
/// One tool call executed by the driver, together with the text it produced.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DriverToolStep {
/// Name of the tool the assistant asked for.
pub tool: String,
/// Argument string exactly as it appeared on the tool call (not re-serialized).
pub arguments: String,
/// Text that was returned to the conversation as the tool result.
pub result: String,
}
/// Summary of a finished driver run: what was asked, which tools ran, and how it ended.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DriverOutcome {
/// Copy of the task prompt the run was started with.
pub task: String,
/// Tool calls executed during the run, in execution order.
pub steps: Vec<DriverToolStep>,
/// Plain text of the assistant message that ended the run; empty when the run stopped
/// at the turn cap or the completion contained no choices.
pub final_answer: String,
/// Number of completion requests that were issued.
pub turns: usize,
/// `true` only when the run stopped because `MAX_TURNS` requests had been made.
pub hit_turn_cap: bool,
}
impl DriverOutcome {
    /// Renders the run as a plain-text report.
    ///
    /// The report contains the task, the turn and tool-call counts, one numbered entry per
    /// recorded step followed by a whitespace-collapsed preview of its result (at most 200
    /// characters), and a closing note when the run stopped at the turn cap.
    #[must_use]
    pub fn transcript(&self) -> String {
        let mut report = String::new();
        // The `fmt::Result`s are discarded on purpose: the target is an in-memory `String`.
        let _ = writeln!(report, "task: {}", self.task);
        let tool_calls = self.steps.len();
        let _ = writeln!(report, "turns: {} (tool calls: {})", self.turns, tool_calls);
        // Steps are numbered from 1 in the report.
        for (ordinal, step) in (1usize..).zip(&self.steps) {
            let _ = writeln!(report, " [{}] {} {}", ordinal, step.tool, step.arguments);
            let _ = writeln!(report, " -> {}", preview(&step.result, 200));
        }
        if self.hit_turn_cap {
            let _ = writeln!(report, "(stopped at the {MAX_TURNS}-turn safety cap)");
        }
        report
    }
}
/// Runs `task` through the agentic tool loop with the default workspace configuration.
///
/// This is a convenience wrapper around [`run_agentic_task_in`].
///
/// # Errors
///
/// Returns whatever [`AgentError`] `run_agentic_task_in` reports.
pub fn run_agentic_task(task: &str) -> Result<DriverOutcome, AgentError> {
    let default_config = AgentWorkspaceConfig::default();
    run_agentic_task_in(task, &default_config)
}
/// Runs `task` through the solver-backed tool loop using the supplied workspace configuration.
///
/// Each turn sends the conversation so far to `create_chat_completion_with_solver`. When the
/// first choice has finish reason `"tool_calls"` and carries at least one tool call, every
/// call is executed via `execute_tool_call`, recorded as a [`DriverToolStep`], and appended
/// to the conversation after the assistant message. Any other reply ends the run with that
/// message's plain text as the final answer.
///
/// The run also ends with an empty final answer when the completion has no choices, or when
/// `MAX_TURNS` requests have already been made (in which case `hit_turn_cap` is `true`).
///
/// # Errors
///
/// Returns the [`AgentError`] produced by `AgentWorkspace::for_prompt` when the workspace
/// cannot be set up.
pub fn run_agentic_task_in(
    task: &str,
    config: &AgentWorkspaceConfig,
) -> Result<DriverOutcome, AgentError> {
    let solver = UniversalSolver::new(SolverConfig {
        agent_mode: true,
        ..SolverConfig::default()
    });
    let tools = tool_definitions(&DRIVER_TOOLS);
    let mut workspace = AgentWorkspace::for_prompt(task, config)?;

    // The request is built once and reused: only `messages` changes between turns, so the
    // conversation and the tool list are never deep-cloned per request.
    let mut request = ChatCompletionRequest {
        model: None,
        messages: vec![ChatMessage::user(task)],
        temperature: None,
        stream: false,
        tools,
        tool_choice: None,
        functions: Vec::new(),
        function_call: None,
    };
    let mut steps = Vec::new();
    let mut turns = 0usize;

    // The loop yields `(final_answer, hit_turn_cap)`; the outcome is assembled in one place.
    let (final_answer, hit_turn_cap) = loop {
        if turns >= MAX_TURNS {
            break (String::new(), true);
        }
        turns += 1;

        let completion = create_chat_completion_with_solver(&request, &solver);
        let Some(choice) = completion.choices.into_iter().next() else {
            // No choice to act on: finish without an answer.
            break (String::new(), false);
        };

        let requested_tools =
            choice.finish_reason == "tool_calls" && !choice.message.tool_calls.is_empty();
        if !requested_tools {
            break (choice.message.content.plain_text(), false);
        }

        let assistant = choice.message;
        let mut results = Vec::with_capacity(assistant.tool_calls.len());
        for call in &assistant.tool_calls {
            let result = execute_tool_call(call, &mut workspace);
            steps.push(DriverToolStep {
                tool: call.function.name.clone(),
                arguments: call.function.arguments.clone(),
                result: result.clone(),
            });
            results.push(ChatMessage::tool_result(
                call.id.clone(),
                call.function.name.clone(),
                result,
            ));
        }
        // The assistant message goes first, then the results of the calls it requested.
        request.messages.push(assistant);
        request.messages.extend(results);
    };

    Ok(DriverOutcome {
        task: task.to_owned(),
        steps,
        final_answer,
        turns,
        hit_turn_cap,
    })
}
fn execute_tool_call(call: &ToolCall, workspace: &mut AgentWorkspace) -> String {
let arguments: Value = serde_json::from_str(&call.function.arguments).unwrap_or(Value::Null);
match call.function.name.as_str() {
"web_search" => corpus::web_search(arg_str(&arguments, "query")),
"web_fetch" => corpus::web_fetch(arg_str(&arguments, "url")),
"write_file" => {
let path = arg_str(&arguments, "path");
let content = arg_str(&arguments, "content");
workspace.create_file(path, content);
format!("wrote {} byte(s) to {path}", content.len())
}
"run_command" => {
let command = arg_str(&arguments, "command");
workspace.run_command(command);
workspace.last_command_result().map_or_else(
|| format!("run_command produced no result for {command:?}"),
format_command_result,
)
}
other => format!("error: unsupported tool {other}"),
}
}
/// Builds one function-style tool definition per entry of `names`.
///
/// Each definition is a JSON object with `"type": "function"` and a `"function"` object
/// holding the tool's name and the text from `tool_description`.
fn tool_definitions(names: &[&str]) -> Vec<Value> {
    let mut definitions = Vec::with_capacity(names.len());
    for &name in names {
        definitions.push(json!({
            "type": "function",
            "function": { "name": name, "description": tool_description(name) },
        }));
    }
    definitions
}
/// Returns the human-readable description advertised for the tool called `name`.
///
/// Names without a dedicated entry get the generic description `"Tool."`.
fn tool_description(name: &str) -> &'static str {
    // (tool name, description) pairs for every tool with a dedicated description.
    const DESCRIPTIONS: [(&str, &str); 4] = [
        (
            "web_search",
            "Search the web for sources. Arguments: {\"query\": string}.",
        ),
        (
            "web_fetch",
            "Fetch the text at a URL. Arguments: {\"url\": string}.",
        ),
        (
            "write_file",
            "Write a workspace file. Arguments: {\"path\": string, \"content\": string}.",
        ),
        (
            "run_command",
            "Run an allowlisted command in the workspace. Arguments: {\"command\": string}.",
        ),
    ];
    DESCRIPTIONS
        .iter()
        .find(|entry| entry.0 == name)
        .map_or("Tool.", |entry| entry.1)
}
/// Turns a finished command into the text reported back as the tool result.
///
/// A timed-out command reports only the command line. Otherwise exit status 0 yields stdout
/// alone, a non-zero status yields the status with stdout and stderr, and a missing status
/// yields a termination notice with stderr.
fn format_command_result(result: &AgentCommandResult) -> String {
    match (result.timed_out, result.status_code) {
        // A timeout takes precedence over whatever exit status was recorded.
        (true, _) => format!("command timed out: {}", result.command),
        (false, Some(0)) => result.stdout.clone(),
        (false, Some(code)) => format!(
            "command exited with status {code}\nstdout:\n{}\nstderr:\n{}",
            result.stdout, result.stderr
        ),
        (false, None) => format!(
            "command terminated without an exit status\nstderr:\n{}",
            result.stderr
        ),
    }
}
/// Reads the string stored under `key` in `arguments`.
///
/// Returns `""` when the key is absent, the value is not a JSON string, or `arguments` is
/// not something `key` can index into.
fn arg_str<'a>(arguments: &'a Value, key: &str) -> &'a str {
    match arguments.get(key) {
        Some(Value::String(text)) => text.as_str(),
        _ => "",
    }
}
/// Collapses every run of whitespace in `text` to a single space and caps the result at
/// `max` characters, appending `…` when anything was cut off.
///
/// Leading and trailing whitespace is dropped. The limit counts `char`s, not bytes.
fn preview(text: &str, max: usize) -> String {
    let mut collapsed = String::with_capacity(text.len());
    for word in text.split_whitespace() {
        if !collapsed.is_empty() {
            collapsed.push(' ');
        }
        collapsed.push_str(word);
    }
    // Byte offset of the first character past the limit, if the text is that long.
    let cut = collapsed.char_indices().nth(max).map(|(offset, _)| offset);
    match cut {
        None => collapsed,
        Some(offset) => {
            collapsed.truncate(offset);
            collapsed.push('…');
            collapsed
        }
    }
}