use crate::registry::ToolRegistry;
use async_trait::async_trait;
use llmoxide::{Client, Event, Message, Prompt, Response, ResponseRequest, Role, ToolCall};
/// Reports whether verbose tracing of the streaming tool runner is enabled.
///
/// The `LLMOXIDE_DEBUG_TOOLS_STREAM` environment variable is read on every call.
/// `1`, `true`, `yes` and `on` enable tracing. The comparison ignores ASCII case
/// and surrounding whitespace, so `TRUE` or ` Yes ` work as well. Any other value,
/// a non-UTF-8 value, or an unset variable disables it.
pub fn tools_stream_debug_enabled() -> bool {
    const ENABLED_VALUES: [&str; 4] = ["1", "true", "yes", "on"];
    match std::env::var("LLMOXIDE_DEBUG_TOOLS_STREAM") {
        Ok(raw) => {
            let value = raw.trim();
            ENABLED_VALUES
                .iter()
                .any(|accepted| value.eq_ignore_ascii_case(accepted))
        }
        Err(_) => false,
    }
}
/// Prints `msg` to stderr, prefixed with `[llmoxide-tools stream]`.
///
/// This is a no-op unless `tools_stream_debug_enabled()` returns `true`.
fn stream_dbg(msg: impl std::fmt::Display) {
    if !tools_stream_debug_enabled() {
        return;
    }
    eprintln!("[llmoxide-tools stream] {}", msg);
}
/// Errors describing why a model-requested tool call could not be handled.
///
/// Within this module these values are not returned directly. For example,
/// `MissingCallId` is rendered with `to_string()` and wrapped in
/// `llmoxide::Error::InvalidInput`.
#[derive(Debug, thiserror::Error)]
pub enum ToolError {
    /// The call named a tool that is not registered.
    /// NOTE(review): presumably produced by `ToolRegistry` dispatch; not constructed in this module.
    #[error("unknown tool: {tool}")]
    UnknownTool { tool: String },
    /// The call's arguments were rejected for `tool`; `details` explains why.
    /// NOTE(review): presumably produced by `ToolRegistry` dispatch; not constructed in this module.
    #[error("invalid arguments for tool {tool}: {details}")]
    InvalidArguments { tool: String, details: String },
    /// The handler registered for `tool` failed; `details` carries its message.
    /// NOTE(review): presumably produced by `ToolRegistry` dispatch; not constructed in this module.
    #[error("tool handler error for {tool}: {details}")]
    Handler { tool: String, details: String },
    /// The provider returned a tool call with no id. The runners need that id to
    /// pair the recorded tool call with its tool result in the conversation history.
    #[error("provider returned tool call without id for {tool}")]
    MissingCallId { tool: String },
}
/// Limits applied to a tool-calling run.
#[derive(Debug, Clone)]
pub struct RunConfig {
    /// Maximum number of rounds in which the model's tool calls are executed.
    /// After that, the runners send one last request and return its response
    /// without executing any further tool calls.
    pub max_rounds: usize,
}

impl Default for RunConfig {
    /// Allows up to eight tool-executing rounds.
    fn default() -> Self {
        const DEFAULT_MAX_ROUNDS: usize = 8;
        RunConfig {
            max_rounds: DEFAULT_MAX_ROUNDS,
        }
    }
}
/// Runs a model request while executing the tool calls the model asks for.
#[async_trait(?Send)]
pub trait ToolRunner {
    /// Runs `req` with the tool specs from `tools`, dispatching requested tool
    /// calls through `tools` and feeding their results back to the model.
    ///
    /// The `Client` implementation in this module does the following:
    /// - It returns the first response that contains no tool calls.
    /// - After `cfg.max_rounds` tool-executing rounds, it sends one final request
    ///   and returns that response without executing its tool calls.
    async fn run_with_tools(
        &self,
        req: ResponseRequest,
        tools: &ToolRegistry,
        cfg: RunConfig,
    ) -> Result<Response, llmoxide::Error>;
}
/// Convenience form of a tool-calling run that starts from a plain text prompt.
#[async_trait(?Send)]
pub trait ToolRunnerText {
    /// Runs a tool-calling conversation whose request carries `prompt` as a
    /// `Role::User` message. The implementations in this module build the
    /// request with `ResponseRequest::new_auto()` and delegate to
    /// `ToolRunner::run_with_tools`.
    async fn run_with_tools_text(
        &self,
        prompt: impl Into<String> + Send,
        tools: &ToolRegistry,
        cfg: RunConfig,
    ) -> Result<Response, llmoxide::Error>;
}
/// Streaming variant of a tool-calling run; streamed events are reported through a callback.
#[async_trait(?Send)]
pub trait ToolRunnerStream {
    /// Like `ToolRunner::run_with_tools`, but streams each request and passes
    /// events to `on_event` as they arrive.
    ///
    /// The `Client` implementation in this module suppresses the per-request
    /// `Event::Completed`. On success it emits a single `Event::Completed`
    /// carrying the response it returns.
    async fn run_with_tools_stream(
        &self,
        req: ResponseRequest,
        tools: &ToolRegistry,
        cfg: RunConfig,
        on_event: &mut dyn FnMut(Event),
    ) -> Result<Response, llmoxide::Error>;
}
/// Streaming, text-prompt convenience form of a tool-calling run.
#[async_trait(?Send)]
pub trait ToolRunnerStreamText {
    /// Streams a tool-calling conversation whose request carries `prompt` as a
    /// `Role::User` message. The implementations in this module build the
    /// request with `ResponseRequest::new_auto()` and delegate to
    /// `ToolRunnerStream::run_with_tools_stream`.
    async fn run_with_tools_stream_text(
        &self,
        prompt: impl Into<String> + Send,
        tools: &ToolRegistry,
        cfg: RunConfig,
        on_event: &mut dyn FnMut(Event),
    ) -> Result<Response, llmoxide::Error>;
}
#[async_trait(?Send)]
impl ToolRunner for Client {
async fn run_with_tools(
&self,
mut req: ResponseRequest,
tools: &ToolRegistry,
cfg: RunConfig,
) -> Result<Response, llmoxide::Error> {
req = req.tools(tools.specs());
let mut history = req.messages;
for _round in 0..cfg.max_rounds {
let req_round = ResponseRequest {
model: req.model.clone(),
messages: history.clone(),
max_output_tokens: req.max_output_tokens,
tools: req.tools.clone(),
};
let resp = self.send(req_round).await?;
if resp.tool_calls.is_empty() {
return Ok(resp);
}
let mut tool_messages: Vec<Message> = Vec::with_capacity(resp.tool_calls.len());
for call in &resp.tool_calls {
let call_id = call.id.clone().ok_or_else(|| {
llmoxide::Error::InvalidInput(
ToolError::MissingCallId {
tool: call.name.clone(),
}
.to_string()
.into(),
)
})?;
let (_name, out) = tools.dispatch(call).await.map_err(|e| {
llmoxide::Error::InvalidInput(e.to_string().into())
})?;
history.push(Message::tool_call(
call_id.clone(),
call.name.clone(),
call.arguments.clone(),
));
tool_messages.push(Message::tool_result_named(call_id, call.name.clone(), out));
}
history.extend(tool_messages);
}
let final_req = ResponseRequest {
model: req.model,
messages: history,
max_output_tokens: req.max_output_tokens,
tools: req.tools,
};
self.send(final_req).await
}
}
#[async_trait(?Send)]
impl ToolRunnerText for Client {
    /// Pushes `prompt` as a `Role::User` message onto a fresh
    /// `ResponseRequest::new_auto()` request and runs it through
    /// `ToolRunner::run_with_tools`.
    async fn run_with_tools_text(
        &self,
        prompt: impl Into<String> + Send,
        tools: &ToolRegistry,
        cfg: RunConfig,
    ) -> Result<Response, llmoxide::Error> {
        let request = ResponseRequest::new_auto();
        let request = request.push_message(Message::text(Role::User, prompt));
        self.run_with_tools(request, tools, cfg).await
    }
}
#[async_trait(?Send)]
impl ToolRunnerText for Prompt {
    /// Forwards to the `ToolRunnerText` implementation of the client returned
    /// by `self.client()`.
    async fn run_with_tools_text(
        &self,
        prompt: impl Into<String> + Send,
        tools: &ToolRegistry,
        cfg: RunConfig,
    ) -> Result<Response, llmoxide::Error> {
        let client = self.client();
        client.run_with_tools_text(prompt, tools, cfg).await
    }
}
#[async_trait(?Send)]
impl ToolRunnerStream for Client {
async fn run_with_tools_stream(
&self,
mut req: ResponseRequest,
tools: &ToolRegistry,
cfg: RunConfig,
on_event: &mut dyn FnMut(Event),
) -> Result<Response, llmoxide::Error> {
req = req.tools(tools.specs());
let mut history = req.messages;
stream_dbg(format!(
"start provider={:?} tool_specs={} max_rounds={} history_messages={}",
self.provider(),
req.tools.len(),
cfg.max_rounds,
history.len()
));
for round in 0..cfg.max_rounds {
let req_round = ResponseRequest {
model: req.model.clone(),
messages: history.clone(),
max_output_tokens: req.max_output_tokens,
tools: req.tools.clone(),
};
let mut streamed_tool_calls: Vec<ToolCall> = Vec::new();
stream_dbg(format!(
"round {round}: streaming request (messages={}, model={:?})",
req_round.messages.len(),
req_round.model.as_ref().map(|m| m.0.as_str())
));
let resp = match self
.stream(req_round, |ev| {
if let Event::ToolCall(ref tc) = ev {
streamed_tool_calls.push(tc.clone());
}
match ev {
Event::Completed(_) => {}
other => on_event(other),
}
})
.await
{
Ok(r) => r,
Err(e) => {
stream_dbg(format!("round {round}: stream ERROR: {e}"));
return Err(e);
}
};
stream_dbg(format!(
"round {round}: stream OK — resp.tool_calls.len()={}, collected_stream_tool_calls={}, assistant_text_len={:?}",
resp.tool_calls.len(),
streamed_tool_calls.len(),
resp.text().map(|t| t.len())
));
let tool_calls = if !resp.tool_calls.is_empty() {
resp.tool_calls.clone()
} else {
streamed_tool_calls.clone()
};
if !tool_calls.is_empty() {
for (i, c) in tool_calls.iter().enumerate() {
stream_dbg(format!(
"round {round}: tool_call[{i}] name={:?} id={:?} args={}",
c.name, c.id, c.arguments
));
}
}
if tool_calls.is_empty() {
stream_dbg(format!(
"round {round}: no tool calls — emitting Completed and returning (assistant empty={})",
resp.text().map(|t| t.is_empty()).unwrap_or(true)
));
on_event(Event::Completed(resp.clone()));
return Ok(resp);
}
let mut tool_messages: Vec<Message> = Vec::with_capacity(tool_calls.len());
for call in &tool_calls {
let call_id = call.id.clone().ok_or_else(|| {
llmoxide::Error::InvalidInput(
ToolError::MissingCallId {
tool: call.name.clone(),
}
.to_string()
.into(),
)
})?;
let (_name, out) = tools
.dispatch(call)
.await
.map_err(|e| llmoxide::Error::InvalidInput(e.to_string().into()))?;
history.push(Message::tool_call(
call_id.clone(),
call.name.clone(),
call.arguments.clone(),
));
tool_messages.push(Message::tool_result_named(call_id, call.name.clone(), out));
}
history.extend(tool_messages);
stream_dbg(format!(
"round {round}: dispatched {} tool result(s); history now {} message(s)",
tool_calls.len(),
history.len()
));
}
stream_dbg(format!(
"max_rounds ({}) exhausted — final stream (no tool execution this turn)",
cfg.max_rounds
));
let final_req = ResponseRequest {
model: req.model,
messages: history,
max_output_tokens: req.max_output_tokens,
tools: req.tools,
};
let mut streamed_tool_calls: Vec<ToolCall> = Vec::new();
let resp = match self
.stream(final_req, |ev| {
if let Event::ToolCall(ref tc) = ev {
streamed_tool_calls.push(tc.clone());
}
match ev {
Event::Completed(_) => {}
other => on_event(other),
}
})
.await
{
Ok(r) => r,
Err(e) => {
stream_dbg(format!("final stream ERROR: {e}"));
return Err(e);
}
};
stream_dbg(format!(
"final stream OK — resp.tool_calls.len()={}, streamed_tool_calls={}, assistant_text_len={:?}",
resp.tool_calls.len(),
streamed_tool_calls.len(),
resp.text().map(|t| t.len())
));
let resp = if resp.tool_calls.is_empty() && !streamed_tool_calls.is_empty() {
stream_dbg("merging streamed tool_calls into Response (response had empty tool_calls)");
resp.with_tool_calls(streamed_tool_calls)
} else {
resp
};
on_event(Event::Completed(resp.clone()));
Ok(resp)
}
}
#[async_trait(?Send)]
impl ToolRunnerStreamText for Client {
    /// Pushes `prompt` as a `Role::User` message onto a fresh
    /// `ResponseRequest::new_auto()` request and streams it through
    /// `ToolRunnerStream::run_with_tools_stream`, forwarding `on_event`.
    async fn run_with_tools_stream_text(
        &self,
        prompt: impl Into<String> + Send,
        tools: &ToolRegistry,
        cfg: RunConfig,
        on_event: &mut dyn FnMut(Event),
    ) -> Result<Response, llmoxide::Error> {
        let request = ResponseRequest::new_auto();
        let request = request.push_message(Message::text(Role::User, prompt));
        self.run_with_tools_stream(request, tools, cfg, on_event)
            .await
    }
}
#[async_trait(?Send)]
impl ToolRunnerStream for Prompt {
    /// Forwards to the `ToolRunnerStream` implementation of the client returned
    /// by `self.client()`.
    async fn run_with_tools_stream(
        &self,
        req: ResponseRequest,
        tools: &ToolRegistry,
        cfg: RunConfig,
        on_event: &mut dyn FnMut(Event),
    ) -> Result<Response, llmoxide::Error> {
        let client = self.client();
        client.run_with_tools_stream(req, tools, cfg, on_event).await
    }
}
#[async_trait(?Send)]
impl ToolRunnerStreamText for Prompt {
    /// Forwards to the `ToolRunnerStreamText` implementation of the client
    /// returned by `self.client()`.
    async fn run_with_tools_stream_text(
        &self,
        prompt: impl Into<String> + Send,
        tools: &ToolRegistry,
        cfg: RunConfig,
        on_event: &mut dyn FnMut(Event),
    ) -> Result<Response, llmoxide::Error> {
        let client = self.client();
        client
            .run_with_tools_stream_text(prompt, tools, cfg, on_event)
            .await
    }
}