use crate::types::{
ChatMessage, ChatRequest, ChatResponse, LlmProvider, MessageRole, RunnerError, TokenUsage,
ToolCallRequest, ToolDefinition,
};
use serde_json::Value;
use std::fmt::Write;
use std::sync::Arc;
use tracing::{debug, info, warn};
/// Declaration of a tool offered to the model. Alias for [`ToolDefinition`]
/// (a `name`, a `description`, and optional JSON-schema-like `parameters`).
pub type FunctionDeclaration = ToolDefinition;
/// One function invocation requested by the model, as extracted from a
/// `<tool_call>` block (or converted from a [`ToolCallRequest`]).
#[derive(Clone, Debug)]
pub struct FunctionCall {
    /// Name of the function the model asked to run.
    pub name: String,
    /// JSON arguments supplied with the call.
    pub args: Value,
}
impl From<ToolCallRequest> for FunctionCall {
fn from(tc: ToolCallRequest) -> Self {
Self {
name: tc.function_name,
args: tc.arguments,
}
}
}
impl From<FunctionCall> for ToolCallRequest {
fn from(fc: FunctionCall) -> Self {
Self {
id: format!("call_{}", fc.name),
function_name: fc.name,
arguments: fc.args,
}
}
}
/// Output of one tool execution, fed back to the model inside a
/// `<tool_result>` block.
#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
pub struct FunctionResponse {
    /// Name of the tool that produced this response.
    pub name: String,
    /// JSON payload returned by the tool.
    pub response: Value,
}
/// Shape of the JSON object the model writes inside a `<tool_call>` block.
#[derive(serde::Deserialize)]
struct ToolCallPayload {
    /// Function to invoke.
    name: String,
    /// Call arguments; a missing field deserializes to `None`.
    #[serde(default)]
    arguments: Option<Value>,
}
/// Callback that executes a single tool call.
///
/// Receives the function name and its JSON arguments and returns the
/// [`FunctionResponse`] to show the model. Shared behind an `Arc` and
/// required to be `Send + Sync`.
pub type TextToolHandler = Arc<dyn Fn(&str, &Value) -> FunctionResponse + Send + Sync>;
/// Outcome of [`execute_with_text_tools`].
#[derive(Clone, Debug)]
pub struct TextToolResponse {
    /// Final assistant text with any `<tool_call>` blocks removed; empty when
    /// the iteration budget ran out.
    pub content: String,
    /// Token usage reported by the final completion, if any.
    pub usage: Option<TokenUsage>,
    /// Provider finish reason, or `"max_iterations"` when the loop gave up.
    pub finish_reason: Option<String>,
    /// Total number of tool calls executed during the loop.
    pub tool_calls_count: u32,
}
/// Renders the prompt text that teaches a model the `<tool_call>` protocol.
///
/// The text begins with two newlines so it can be appended straight onto an
/// existing system prompt. It contains, in order:
/// 1. the protocol instructions, the output format, and the rules;
/// 2. one `### name` section per declaration: its description followed by
///    its parameter list;
/// 3. when `declarations` is non-empty, a few-shot example built from the
///    first declaration.
#[must_use]
pub fn generate_tool_catalog(declarations: &[FunctionDeclaration]) -> String {
    // Fixed preamble emitted verbatim before the per-function sections.
    const PREAMBLE: [&str; 6] = [
        "\n\n",
        "I am testing a function-calling protocol. For each user request below, \
         generate the correct XML output that invokes the matching function. \
         Output ONLY the raw XML block with no code fences and no explanation.\n\n",
        "The output format is:\n\n",
        "<tool_call>\n{\"name\": \"FUNCTION_NAME\", \"arguments\": {\"PARAM\": \"VALUE\"}}\n</tool_call>\n\n",
        "Rules:\n\
         - Output ONLY <tool_call> blocks. No markdown, no code fences, no commentary.\n\
         - You may output multiple <tool_call> blocks if multiple functions apply.\n\
         - ONLY call functions listed under \"Registered functions\" below. \
         Do NOT call any other tools (Glob, Grep, Read, Bash, Edit, Write, etc.) — they do not exist in this environment.\n\
         - After you receive <tool_result> data, use it to answer the original question.\n\n",
        "Registered functions:\n\n",
    ];

    let mut out = String::with_capacity(4096);
    for section in PREAMBLE {
        out.push_str(section);
    }

    for decl in declarations {
        // Heading line, then the description on its own line.
        let _ = writeln!(out, "### {}\n{}", decl.name, decl.description);
        append_parameter_docs(&mut out, decl);
        out.push('\n');
    }

    if let Some(first) = declarations.first() {
        append_few_shot_example(&mut out, first);
    }
    out
}
fn append_parameter_docs(catalog: &mut String, decl: &FunctionDeclaration) {
let Some(ref params) = decl.parameters else {
return;
};
let Some(props_obj) = params.get("properties").and_then(|p| p.as_object()) else {
return;
};
if props_obj.is_empty() {
return;
}
let required: Vec<&str> = params
.get("required")
.and_then(|r| r.as_array())
.map(|arr| arr.iter().filter_map(|v| v.as_str()).collect())
.unwrap_or_default();
catalog.push_str("Parameters:\n");
for (name, schema) in props_obj {
let type_str = schema.get("type").and_then(|t| t.as_str()).unwrap_or("any");
let is_required = required.contains(&name.as_str());
let req_label = if is_required { ", required" } else { "" };
let _ = writeln!(catalog, "- `{name}` ({type_str}{req_label})");
}
}
/// Appends an example exchange showing a well-formed `<tool_call>` for `decl`.
///
/// The arguments come from [`build_example_args`], serialized as compact JSON.
/// They fall back to `{}` if serialization fails.
fn append_few_shot_example(catalog: &mut String, decl: &FunctionDeclaration) {
    let rendered_args =
        serde_json::to_string(&build_example_args(decl)).unwrap_or_else(|_| "{}".to_owned());
    let name = &decl.name;
    let _ = write!(
        catalog,
        "Example interaction:\n\n\
         User: [asks a question related to {name}]\n\
         Assistant:\n\
         <tool_call>\n{{\"name\": \"{name}\", \"arguments\": {rendered_args}}}\n</tool_call>\n"
    );
}
fn build_example_args(decl: &FunctionDeclaration) -> serde_json::Map<String, Value> {
let mut args = serde_json::Map::new();
let Some(ref params) = decl.parameters else {
return args;
};
let Some(props_obj) = params.get("properties").and_then(|p| p.as_object()) else {
return args;
};
for (name, schema) in props_obj {
let type_str = schema
.get("type")
.and_then(|t| t.as_str())
.unwrap_or("string");
let example_value = match type_str {
"integer" | "number" => Value::Number(serde_json::Number::from(1)),
"boolean" => Value::Bool(true),
"array" => Value::Array(vec![Value::String("example".to_owned())]),
_ => Value::String("example".to_owned()),
};
args.insert(name.clone(), example_value);
}
args
}
/// Adds `catalog` to the conversation's system prompt.
///
/// If the first message has the `System` role, it is replaced by a new
/// system message whose content is the old content followed by `catalog`.
/// Otherwise a new system message containing just `catalog` is inserted at
/// index 0.
pub fn inject_tool_catalog(messages: &mut Vec<ChatMessage>, catalog: &str) {
    let leads_with_system = matches!(
        messages.first(),
        Some(first) if first.role == MessageRole::System
    );
    if !leads_with_system {
        messages.insert(0, ChatMessage::system(catalog));
        return;
    }
    // Non-empty is guaranteed by the check above.
    let head = &mut messages[0];
    *head = ChatMessage::system(format!("{}{catalog}", head.content));
}
#[must_use]
pub fn parse_tool_call_blocks(content: &str) -> Vec<FunctionCall> {
let mut calls = Vec::new();
let mut search_from = 0;
while let Some(start) = content[search_from..].find("<tool_call>") {
let abs_start = search_from + start + "<tool_call>".len();
let Some(end) = content[abs_start..].find("</tool_call>") else {
warn!("Found <tool_call> without matching </tool_call>");
break;
};
let abs_end = abs_start + end;
let json_str = content[abs_start..abs_end].trim();
match serde_json::from_str::<ToolCallPayload>(json_str) {
Ok(payload) => {
info!("Parsed tool call: {}", payload.name);
calls.push(FunctionCall {
name: payload.name,
args: payload
.arguments
.unwrap_or_else(|| Value::Object(serde_json::Map::new())),
});
}
Err(e) => {
warn!(
"Failed to parse <tool_call> JSON ({} bytes): {e}",
json_str.len()
);
}
}
search_from = abs_end + "</tool_call>".len();
}
calls
}
/// Removes every `<tool_call>…</tool_call>` block from `content` and trims the
/// result.
///
/// Text between blocks is kept verbatim. A `<tool_call>` with no closing tag
/// causes everything from that tag to the end of `content` to be dropped.
#[must_use]
pub fn strip_tool_call_blocks(content: &str) -> String {
    const OPEN: &str = "<tool_call>";
    const CLOSE: &str = "</tool_call>";

    let mut kept = String::with_capacity(content.len());
    let mut remaining = content;
    loop {
        let Some(open_at) = remaining.find(OPEN) else {
            kept.push_str(remaining);
            break;
        };
        kept.push_str(&remaining[..open_at]);
        let block = &remaining[open_at..];
        // Resume after the closing tag, or discard the unterminated tail.
        remaining = block
            .find(CLOSE)
            .map_or("", |close_at| &block[close_at + CLOSE.len()..]);
    }
    kept.trim().to_owned()
}
#[must_use]
pub fn format_tool_results_as_text(responses: &[FunctionResponse]) -> String {
let mut text = String::with_capacity(4096);
text.push_str("Here are the results from the tools you requested:\n\n");
for resp in responses {
let _ = writeln!(text, "<tool_result name=\"{}\">", resp.name);
let json_str =
serde_json::to_string_pretty(&resp.response).unwrap_or_else(|_| "{}".to_owned());
let _ = writeln!(text, "{json_str}");
text.push_str("</tool_result>\n\n");
}
text.push_str("Please analyze the data above and respond to the user's question.");
text
}
const MAX_TOOL_ITERATIONS: usize = 10;
pub async fn execute_with_text_tools(
provider: &dyn LlmProvider,
messages: &mut Vec<ChatMessage>,
declarations: &[FunctionDeclaration],
tool_handler: TextToolHandler,
max_iterations: usize,
) -> Result<TextToolResponse, RunnerError> {
let tool_catalog = generate_tool_catalog(declarations);
inject_tool_catalog(messages, &tool_catalog);
debug!(
message_count = messages.len(),
catalog_len = tool_catalog.len(),
tool_count = declarations.len(),
max_iterations,
"Text tool loop: starting with injected tool catalog"
);
let mut tool_calls_count: u32 = 0;
let effective_max = max_iterations.min(MAX_TOOL_ITERATIONS);
for iteration in 0..effective_max {
let request = ChatRequest::new(messages.clone());
let response: ChatResponse = provider.complete(&request).await?;
let parsed_tool_calls = parse_tool_call_blocks(&response.content);
if parsed_tool_calls.is_empty() {
let content = strip_tool_call_blocks(&response.content);
debug!(
iteration,
content_len = content.len(),
total_tool_calls = tool_calls_count,
"Text tool loop: final response (no tool calls)"
);
return Ok(TextToolResponse {
content,
usage: response.usage,
finish_reason: response.finish_reason,
tool_calls_count,
});
}
info!(
"Text tool iteration {}: parsed {} tool call(s)",
iteration,
parsed_tool_calls.len()
);
let mut function_responses = Vec::with_capacity(parsed_tool_calls.len());
for call in &parsed_tool_calls {
info!(tool_name = %call.name, "Executing tool call");
let resp = tool_handler(&call.name, &call.args);
function_responses.push(resp);
}
#[allow(clippy::cast_possible_truncation)]
{
tool_calls_count += parsed_tool_calls.len() as u32;
}
let assistant_text = strip_tool_call_blocks(&response.content);
if !assistant_text.is_empty() {
messages.push(ChatMessage::assistant(assistant_text));
}
let tool_results_text = format_tool_results_as_text(&function_responses);
messages.push(ChatMessage::user(tool_results_text));
}
Ok(TextToolResponse {
content: String::new(),
usage: None,
finish_reason: Some("max_iterations".to_owned()),
tool_calls_count,
})
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn parse_single_tool_call() {
        let content = r#"Let me fetch your data.
<tool_call>
{"name": "get_activities", "arguments": {"provider": "strava", "limit": 25}}
</tool_call>"#;
        let calls = parse_tool_call_blocks(content);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "get_activities");
        assert_eq!(calls[0].args["provider"], "strava");
        assert_eq!(calls[0].args["limit"], 25);
    }

    #[test]
    fn parse_multiple_tool_calls() {
        let content = r#"I'll fetch your data.
<tool_call>
{"name": "get_activities", "arguments": {"provider": "strava", "limit": 10}}
</tool_call>
And your profile:
<tool_call>
{"name": "get_athlete", "arguments": {"provider": "strava"}}
</tool_call>"#;
        let calls = parse_tool_call_blocks(content);
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].name, "get_activities");
        assert_eq!(calls[1].name, "get_athlete");
    }

    #[test]
    fn parse_no_tool_calls() {
        let content = "Here is your analysis of the data. You had a great week!";
        let calls = parse_tool_call_blocks(content);
        assert!(calls.is_empty());
    }

    #[test]
    fn parse_malformed_json_skipped() {
        let content = r#"<tool_call>
{not valid json}
</tool_call>
<tool_call>
{"name": "get_stats", "arguments": {"provider": "strava"}}
</tool_call>"#;
        let calls = parse_tool_call_blocks(content);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "get_stats");
    }

    #[test]
    fn parse_tool_call_without_arguments() {
        let content = r#"<tool_call>
{"name": "get_connection_status"}
</tool_call>"#;
        let calls = parse_tool_call_blocks(content);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "get_connection_status");
        assert!(calls[0].args.is_object());
    }

    #[test]
    fn parse_unterminated_block_is_ignored() {
        let content = "Checking.\n<tool_call>\n{\"name\": \"get_stats\"}";
        assert!(parse_tool_call_blocks(content).is_empty());
    }

    #[test]
    fn strip_tool_call_blocks_removes_blocks() {
        // Escaped literal so the blank lines around the block are explicit:
        // the text before and after each contributes "\n\n".
        let content = "Let me fetch your data.\n\n<tool_call>\n\
                       {\"name\": \"get_activities\", \"arguments\": {\"provider\": \"strava\"}}\n\
                       </tool_call>\n\nAnd some more text.";
        let stripped = strip_tool_call_blocks(content);
        assert_eq!(
            stripped,
            "Let me fetch your data.\n\n\n\nAnd some more text."
        );
        assert!(!stripped.contains("<tool_call>"));
    }

    #[test]
    fn strip_preserves_no_tool_calls() {
        let content = "Just plain text with no tool calls.";
        let stripped = strip_tool_call_blocks(content);
        assert_eq!(stripped, content);
    }

    #[test]
    fn strip_drops_unterminated_block() {
        let content = "Checking.\n<tool_call>\n{\"name\": \"get_stats\"}";
        assert_eq!(strip_tool_call_blocks(content), "Checking.");
    }

    #[test]
    fn generate_tool_catalog_has_tools() {
        let declarations = vec![
            FunctionDeclaration {
                name: "get_activities".to_owned(),
                description: "Get user's recent fitness activities".to_owned(),
                parameters: Some(json!({
                    "type": "object",
                    "properties": {
                        "provider": {"type": "string"},
                        "limit": {"type": "integer"}
                    },
                    "required": ["provider"]
                })),
            },
            FunctionDeclaration {
                name: "get_athlete".to_owned(),
                description: "Get user's athlete profile".to_owned(),
                parameters: Some(json!({
                    "type": "object",
                    "properties": {
                        "provider": {"type": "string"}
                    },
                    "required": ["provider"]
                })),
            },
        ];
        let catalog = generate_tool_catalog(&declarations);
        assert!(catalog.contains("### get_activities"));
        assert!(catalog.contains("### get_athlete"));
        assert!(catalog.contains("<tool_call>"));
        assert!(catalog.contains("`provider` (string, required)"));
        assert!(catalog.contains("`limit` (integer)"));
    }

    #[test]
    fn generate_tool_catalog_no_parameters() {
        let declarations = vec![FunctionDeclaration {
            name: "ping".to_owned(),
            description: "Check connectivity".to_owned(),
            parameters: None,
        }];
        let catalog = generate_tool_catalog(&declarations);
        assert!(catalog.contains("### ping"));
        assert!(catalog.contains("Check connectivity"));
    }

    #[test]
    fn format_tool_results_single() {
        let responses = vec![FunctionResponse {
            name: "get_stats".to_owned(),
            response: json!({"total_distance_km": 1234.5}),
        }];
        let text = format_tool_results_as_text(&responses);
        assert!(text.contains("<tool_result name=\"get_stats\">"));
        assert!(text.contains("1234.5"));
        assert!(text.contains("</tool_result>"));
    }

    #[test]
    fn format_tool_results_multiple() {
        let responses = vec![
            FunctionResponse {
                name: "get_weather".to_owned(),
                response: json!({"temp": 72}),
            },
            FunctionResponse {
                name: "get_time".to_owned(),
                response: json!({"time": "14:30"}),
            },
        ];
        let text = format_tool_results_as_text(&responses);
        assert!(text.contains("<tool_result name=\"get_weather\">"));
        assert!(text.contains("<tool_result name=\"get_time\">"));
    }

    #[test]
    fn inject_appends_to_existing_system() {
        let mut messages = vec![
            ChatMessage::system("You are a helpful assistant."),
            ChatMessage::user("Hello"),
        ];
        let catalog = "\n\n## Tools\nSome tools here.";
        inject_tool_catalog(&mut messages, catalog);
        assert_eq!(messages.len(), 2);
        assert!(messages[0].content.contains("You are a helpful assistant."));
        assert!(messages[0].content.contains("## Tools"));
    }

    #[test]
    fn inject_creates_system_when_missing() {
        let mut messages = vec![ChatMessage::user("Hello")];
        let catalog = "## Tools\nSome tools here.";
        inject_tool_catalog(&mut messages, catalog);
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].role, MessageRole::System);
        assert!(messages[0].content.contains("## Tools"));
    }
}