use anyhow::Result;
use chrono::Utc;
use crate::llm::{LlmMessage, LlmProvider, LlmRequest};
use crate::observe::{Observer, ToolCallRecord};
use crate::tool::ToolRegistry;
/// Limits applied by [`run_tool_loop`].
#[derive(Debug, Clone)]
pub struct ToolLoopConfig {
    /// Maximum number of LLM completion requests `run_tool_loop` will make.
    /// If the model still asks for tools once this budget is spent, the loop
    /// returns an error instead of continuing.
    pub max_rounds: usize,
    /// Concurrency limit passed to `ToolRegistry::execute_with_concurrency`
    /// for the tool calls of each round.
    pub max_concurrency: usize,
}
impl Default for ToolLoopConfig {
fn default() -> Self {
Self {
max_rounds: 10,
max_concurrency: 5,
}
}
}
/// Outcome of a successful `run_tool_loop` call.
#[derive(Debug)]
pub struct ToolLoopResult {
    /// Text content of the final assistant reply, i.e. the first reply that
    /// requested no tool calls.
    pub final_text: String,
    /// The full conversation: the request's original messages followed by
    /// every assistant and `"tool"` message appended during the loop.
    pub messages: Vec<LlmMessage>,
    /// Number of LLM completion requests that were made.
    pub rounds: usize,
    /// Total number of tool calls the model requested (and that were handed to
    /// the tool registry) across all rounds.
    pub total_tool_calls: usize,
}
/// Drives a multi-round tool-calling conversation with `llm`.
///
/// Every round sends the accumulated conversation to `llm.complete`. It is
/// sent together with the tool definitions from `tools.to_llm_tools()` and the
/// `model`, `system_prompt`, `temperature` and `max_tokens` taken from
/// `request`. The assistant reply is then appended to the conversation.
///
/// If the reply requested no tool calls, the loop finishes and returns that
/// reply's text. Otherwise the requested calls are run through
/// `tools.execute_with_concurrency` (limit: `config.max_concurrency`). Each
/// call is reported to `observer` (if any), and one `"tool"` message per
/// result is appended before the next round.
///
/// Note: `request.tools` is never read. The registry's tools are always
/// advertised instead.
///
/// # Errors
///
/// - Propagates any error returned by `llm.complete`.
/// - Returns an error if `config.max_rounds` is 0, or if the reply to round
///   `config.max_rounds` still requests tool calls. In the latter case those
///   calls are *not* executed: their results could never be sent back to the
///   model, so running them would only cause side effects whose output is
///   thrown away.
pub async fn run_tool_loop(
    llm: &dyn LlmProvider,
    request: LlmRequest,
    tools: &ToolRegistry,
    config: &ToolLoopConfig,
    observer: Option<&dyn Observer>,
) -> Result<ToolLoopResult> {
    let max_rounds_exceeded =
        || anyhow::anyhow!("tool loop exceeded max_rounds={}", config.max_rounds);
    if config.max_rounds == 0 {
        return Err(max_rounds_exceeded());
    }

    // `request` is owned, so move its history out instead of cloning it.
    let mut messages = request.messages;
    let mut rounds = 0;
    let mut total_tool_calls = 0;
    loop {
        let req = LlmRequest {
            messages: messages.clone(),
            tools: Some(tools.to_llm_tools()),
            model: request.model.clone(),
            system_prompt: request.system_prompt.clone(),
            temperature: request.temperature,
            max_tokens: request.max_tokens,
        };
        let response = llm.complete(req).await?;
        rounds += 1;
        messages.push(LlmMessage {
            role: "assistant".into(),
            content: response.content.clone(),
            tool_calls: response.tool_calls.clone(),
            tool_call_id: None,
        });

        let calls = response.tool_calls.unwrap_or_default();
        if calls.is_empty() {
            return Ok(ToolLoopResult {
                final_text: response.content,
                messages,
                rounds,
                total_tool_calls,
            });
        }
        // The model wants more tools, but no round is left to hand their
        // results back in: fail now rather than execute calls whose output
        // would be discarded.
        if rounds >= config.max_rounds {
            return Err(max_rounds_exceeded());
        }

        total_tool_calls += calls.len();
        let results = tools
            .execute_with_concurrency(&calls, config.max_concurrency)
            .await;

        if let Some(obs) = observer {
            // Pair each result with the call at the same position, without a
            // panic path if the two lengths ever differ.
            // NOTE(review): assumes `execute_with_concurrency` returns results
            // in input order — confirm.
            for (call, (result, duration_ms)) in calls.iter().zip(results.iter()) {
                // Fall back to the raw argument string when it is not valid
                // JSON, so malformed calls stay visible to the observer
                // instead of being recorded as `null`.
                let input: serde_json::Value = serde_json::from_str(&call.arguments)
                    .unwrap_or_else(|_| serde_json::Value::String(call.arguments.clone()));
                let record = ToolCallRecord {
                    case_key: String::new(),
                    step_name: None,
                    tool_name: call.name.clone(),
                    call_id: call.id.clone(),
                    input,
                    output: Some(result.content.clone()),
                    is_error: result.is_error,
                    duration_ms: *duration_ms,
                    timestamp: Utc::now(),
                };
                obs.on_tool_call(&record).await;
            }
        }

        // `results` is not needed after this, so move each payload into its
        // message instead of cloning it. Messages are keyed by `call_id`, so
        // their correctness does not depend on result ordering.
        messages.extend(results.into_iter().map(|(result, _duration_ms)| LlmMessage {
            role: "tool".into(),
            content: result.content,
            tool_calls: None,
            tool_call_id: Some(result.call_id),
        }));
    }
}