//! llmoxide-tools 0.1.0
//!
//! Tool-calling runner for llmoxide (schemas, dispatch, streaming callbacks).
use crate::registry::ToolRegistry;
use async_trait::async_trait;
use llmoxide::{Client, Event, Message, Prompt, Response, ResponseRequest, Role, ToolCall};

/// Returns `true` when the **`LLMOXIDE_DEBUG_TOOLS_STREAM`** environment variable is exactly
/// `1`, `true`, or `yes` (case-sensitive). Any other value, an unset variable, or a value that
/// is not valid Unicode yields `false`. The variable is read on every call, so changes made at
/// runtime take effect immediately.
///
/// While enabled, [`ToolRunnerStream`] writes diagnostics to stderr (provider, rounds,
/// tool-call counts, response summaries).
///
/// Anthropic SSE tracing is switched on by the same values of
/// **`LLMOXIDE_DEBUG_ANTHROPIC_STREAM`** (or **`LLMOXIDE_DEBUG_TOOLS_STREAM`**) and writes
/// stderr lines prefixed with **`[llmoxide anthropic stream]`**.
pub fn tools_stream_debug_enabled() -> bool {
    const ENABLED_VALUES: [&str; 3] = ["1", "true", "yes"];
    std::env::var("LLMOXIDE_DEBUG_TOOLS_STREAM")
        .map(|value| ENABLED_VALUES.contains(&value.as_str()))
        .unwrap_or(false)
}

/// Writes `msg` to stderr, prefixed with `[llmoxide-tools stream]`, but only when
/// [`tools_stream_debug_enabled`] returns `true`; otherwise does nothing.
fn stream_dbg(msg: impl std::fmt::Display) {
    if !tools_stream_debug_enabled() {
        return;
    }
    eprintln!("[llmoxide-tools stream] {}", msg);
}

#[derive(Debug, thiserror::Error)]
pub enum ToolError {
    #[error("unknown tool: {tool}")]
    UnknownTool { tool: String },

    #[error("invalid arguments for tool {tool}: {details}")]
    InvalidArguments { tool: String, details: String },

    #[error("tool handler error for {tool}: {details}")]
    Handler { tool: String, details: String },

    #[error("provider returned tool call without id for {tool}")]
    MissingCallId { tool: String },
}

/// Limits applied by the tool-calling runners.
#[derive(Clone, Debug)]
pub struct RunConfig {
    /// Maximum number of tool rounds (model call → tool execution → next model call).
    ///
    /// Once this many rounds have executed tools, the runners make one final model call and
    /// return its response without executing any tool calls it contains. A value of `0`
    /// therefore means a single model call with no tool execution.
    pub max_rounds: usize,
}

impl Default for RunConfig {
    /// Allows up to eight tool rounds.
    fn default() -> Self {
        RunConfig { max_rounds: 8 }
    }
}

/// Runs a request through a model while executing the tool calls it makes, using handlers
/// from a [`ToolRegistry`].
#[async_trait(?Send)]
pub trait ToolRunner {
    /// Sends `req` with the registry's tool schemas attached. Tool calls are dispatched
    /// between model turns for up to `cfg.max_rounds` rounds, and the final [`Response`]
    /// is returned.
    async fn run_with_tools(
        &self,
        req: ResponseRequest,
        tools: &ToolRegistry,
        cfg: RunConfig,
    ) -> Result<Response, llmoxide::Error>;
}

/// Convenience form of [`ToolRunner`] that starts from a plain-text user prompt instead of a
/// fully built [`ResponseRequest`].
#[async_trait(?Send)]
pub trait ToolRunnerText {
    /// Runs the tool loop for a conversation consisting of the single user message `prompt`.
    async fn run_with_tools_text(
        &self,
        prompt: impl Into<String> + Send,
        tools: &ToolRegistry,
        cfg: RunConfig,
    ) -> Result<Response, llmoxide::Error>;
}

/// Streaming variant of [`ToolRunner`]: runs the same model → tools → model loop while
/// reporting stream events through `on_event`.
///
/// Every event produced by the inner [`Client::stream`] calls is forwarded **except**
/// [`Event::Completed`]. This includes [`Event::TextDelta`] and [`Event::ToolCall`], and also
/// covers the tool calls of intermediate rounds, which are executed rather than returned.
/// A single [`Event::Completed`] is emitted for the **final** assistant response once the
/// loop finishes, and that same response is returned.
///
/// If `cfg.max_rounds` is exhausted, one last streamed request is made. Any tool calls it
/// contains are returned in the [`Response`] but are not executed.
#[async_trait(?Send)]
pub trait ToolRunnerStream {
    /// Streams `req` through the tool loop, invoking `on_event` as described on the trait.
    async fn run_with_tools_stream(
        &self,
        req: ResponseRequest,
        tools: &ToolRegistry,
        cfg: RunConfig,
        on_event: &mut dyn FnMut(Event),
    ) -> Result<Response, llmoxide::Error>;
}

/// Convenience form of [`ToolRunnerStream`] that starts from a plain-text user prompt.
#[async_trait(?Send)]
pub trait ToolRunnerStreamText {
    /// Streams the tool loop for a conversation consisting of the single user message
    /// `prompt`, reporting events through `on_event`.
    async fn run_with_tools_stream_text(
        &self,
        prompt: impl Into<String> + Send,
        tools: &ToolRegistry,
        cfg: RunConfig,
        on_event: &mut dyn FnMut(Event),
    ) -> Result<Response, llmoxide::Error>;
}

#[async_trait(?Send)]
impl ToolRunner for Client {
    async fn run_with_tools(
        &self,
        mut req: ResponseRequest,
        tools: &ToolRegistry,
        cfg: RunConfig,
    ) -> Result<Response, llmoxide::Error> {
        // Attach schemas once; providers will ignore if unsupported.
        req = req.tools(tools.specs());
        let mut history = req.messages;

        for _round in 0..cfg.max_rounds {
            let req_round = ResponseRequest {
                model: req.model.clone(),
                messages: history.clone(),
                max_output_tokens: req.max_output_tokens,
                tools: req.tools.clone(),
            };
            let resp = self.send(req_round).await?;

            if resp.tool_calls.is_empty() {
                return Ok(resp);
            }

            let mut tool_messages: Vec<Message> = Vec::with_capacity(resp.tool_calls.len());

            for call in &resp.tool_calls {
                let call_id = call.id.clone().ok_or_else(|| {
                    llmoxide::Error::InvalidInput(
                        ToolError::MissingCallId {
                            tool: call.name.clone(),
                        }
                        .to_string()
                        .into(),
                    )
                })?;
                let (_name, out) = tools.dispatch(call).await.map_err(|e| {
                    // Map tool-layer errors into llmoxide::Error::InvalidInput for now.
                    llmoxide::Error::InvalidInput(e.to_string().into())
                })?;
                history.push(Message::tool_call(
                    call_id.clone(),
                    call.name.clone(),
                    call.arguments.clone(),
                ));
                tool_messages.push(Message::tool_result_named(call_id, call.name.clone(), out));
            }

            history.extend(tool_messages);
        }

        // If we hit the round limit, do one final call without executing tools.
        let final_req = ResponseRequest {
            model: req.model,
            messages: history,
            max_output_tokens: req.max_output_tokens,
            tools: req.tools,
        };
        self.send(final_req).await
    }
}

#[async_trait(?Send)]
impl ToolRunnerText for Client {
    /// Builds a request with [`ResponseRequest::new_auto`] containing `prompt` as a single
    /// user message, then runs it through [`ToolRunner::run_with_tools`].
    async fn run_with_tools_text(
        &self,
        prompt: impl Into<String> + Send,
        tools: &ToolRegistry,
        cfg: RunConfig,
    ) -> Result<Response, llmoxide::Error> {
        let user_turn = Message::text(Role::User, prompt);
        let request = ResponseRequest::new_auto().push_message(user_turn);
        self.run_with_tools(request, tools, cfg).await
    }
}

#[async_trait(?Send)]
impl ToolRunnerText for Prompt {
    /// Delegates to the [`ToolRunnerText`] implementation of the client returned by
    /// `self.client()`.
    async fn run_with_tools_text(
        &self,
        prompt: impl Into<String> + Send,
        tools: &ToolRegistry,
        cfg: RunConfig,
    ) -> Result<Response, llmoxide::Error> {
        let client = self.client();
        client.run_with_tools_text(prompt, tools, cfg).await
    }
}

#[async_trait(?Send)]
impl ToolRunnerStream for Client {
    async fn run_with_tools_stream(
        &self,
        mut req: ResponseRequest,
        tools: &ToolRegistry,
        cfg: RunConfig,
        on_event: &mut dyn FnMut(Event),
    ) -> Result<Response, llmoxide::Error> {
        req = req.tools(tools.specs());
        let mut history = req.messages;

        stream_dbg(format!(
            "start provider={:?} tool_specs={} max_rounds={} history_messages={}",
            self.provider(),
            req.tools.len(),
            cfg.max_rounds,
            history.len()
        ));

        for round in 0..cfg.max_rounds {
            let req_round = ResponseRequest {
                model: req.model.clone(),
                messages: history.clone(),
                max_output_tokens: req.max_output_tokens,
                tools: req.tools.clone(),
            };

            let mut streamed_tool_calls: Vec<ToolCall> = Vec::new();

            stream_dbg(format!(
                "round {round}: streaming request (messages={}, model={:?})",
                req_round.messages.len(),
                req_round.model.as_ref().map(|m| m.0.as_str())
            ));

            let resp = match self
                .stream(req_round, |ev| {
                    if let Event::ToolCall(ref tc) = ev {
                        streamed_tool_calls.push(tc.clone());
                    }
                    match ev {
                        Event::Completed(_) => {}
                        other => on_event(other),
                    }
                })
                .await
            {
                Ok(r) => r,
                Err(e) => {
                    stream_dbg(format!("round {round}: stream ERROR: {e}"));
                    return Err(e);
                }
            };

            stream_dbg(format!(
                "round {round}: stream OK — resp.tool_calls.len()={}, collected_stream_tool_calls={}, assistant_text_len={:?}",
                resp.tool_calls.len(),
                streamed_tool_calls.len(),
                resp.text().map(|t| t.len())
            ));

            let tool_calls = if !resp.tool_calls.is_empty() {
                resp.tool_calls.clone()
            } else {
                streamed_tool_calls.clone()
            };

            if !tool_calls.is_empty() {
                for (i, c) in tool_calls.iter().enumerate() {
                    stream_dbg(format!(
                        "round {round}: tool_call[{i}] name={:?} id={:?} args={}",
                        c.name, c.id, c.arguments
                    ));
                }
            }

            if tool_calls.is_empty() {
                stream_dbg(format!(
                    "round {round}: no tool calls — emitting Completed and returning (assistant empty={})",
                    resp.text().map(|t| t.is_empty()).unwrap_or(true)
                ));
                on_event(Event::Completed(resp.clone()));
                return Ok(resp);
            }

            let mut tool_messages: Vec<Message> = Vec::with_capacity(tool_calls.len());

            for call in &tool_calls {
                let call_id = call.id.clone().ok_or_else(|| {
                    llmoxide::Error::InvalidInput(
                        ToolError::MissingCallId {
                            tool: call.name.clone(),
                        }
                        .to_string()
                        .into(),
                    )
                })?;
                let (_name, out) = tools
                    .dispatch(call)
                    .await
                    .map_err(|e| llmoxide::Error::InvalidInput(e.to_string().into()))?;
                history.push(Message::tool_call(
                    call_id.clone(),
                    call.name.clone(),
                    call.arguments.clone(),
                ));
                tool_messages.push(Message::tool_result_named(call_id, call.name.clone(), out));
            }

            history.extend(tool_messages);
            stream_dbg(format!(
                "round {round}: dispatched {} tool result(s); history now {} message(s)",
                tool_calls.len(),
                history.len()
            ));
        }

        stream_dbg(format!(
            "max_rounds ({}) exhausted — final stream (no tool execution this turn)",
            cfg.max_rounds
        ));

        let final_req = ResponseRequest {
            model: req.model,
            messages: history,
            max_output_tokens: req.max_output_tokens,
            tools: req.tools,
        };

        let mut streamed_tool_calls: Vec<ToolCall> = Vec::new();
        let resp = match self
            .stream(final_req, |ev| {
                if let Event::ToolCall(ref tc) = ev {
                    streamed_tool_calls.push(tc.clone());
                }
                match ev {
                    Event::Completed(_) => {}
                    other => on_event(other),
                }
            })
            .await
        {
            Ok(r) => r,
            Err(e) => {
                stream_dbg(format!("final stream ERROR: {e}"));
                return Err(e);
            }
        };

        stream_dbg(format!(
            "final stream OK — resp.tool_calls.len()={}, streamed_tool_calls={}, assistant_text_len={:?}",
            resp.tool_calls.len(),
            streamed_tool_calls.len(),
            resp.text().map(|t| t.len())
        ));

        let resp = if resp.tool_calls.is_empty() && !streamed_tool_calls.is_empty() {
            stream_dbg("merging streamed tool_calls into Response (response had empty tool_calls)");
            resp.with_tool_calls(streamed_tool_calls)
        } else {
            resp
        };

        on_event(Event::Completed(resp.clone()));
        Ok(resp)
    }
}

#[async_trait(?Send)]
impl ToolRunnerStreamText for Client {
    /// Builds a request with [`ResponseRequest::new_auto`] containing `prompt` as a single
    /// user message, then streams it through [`ToolRunnerStream::run_with_tools_stream`].
    async fn run_with_tools_stream_text(
        &self,
        prompt: impl Into<String> + Send,
        tools: &ToolRegistry,
        cfg: RunConfig,
        on_event: &mut dyn FnMut(Event),
    ) -> Result<Response, llmoxide::Error> {
        let user_turn = Message::text(Role::User, prompt);
        let request = ResponseRequest::new_auto().push_message(user_turn);
        self.run_with_tools_stream(request, tools, cfg, on_event).await
    }
}

#[async_trait(?Send)]
impl ToolRunnerStream for Prompt {
    /// Delegates to the [`ToolRunnerStream`] implementation of the client returned by
    /// `self.client()`.
    async fn run_with_tools_stream(
        &self,
        req: ResponseRequest,
        tools: &ToolRegistry,
        cfg: RunConfig,
        on_event: &mut dyn FnMut(Event),
    ) -> Result<Response, llmoxide::Error> {
        let client = self.client();
        client.run_with_tools_stream(req, tools, cfg, on_event).await
    }
}

#[async_trait(?Send)]
impl ToolRunnerStreamText for Prompt {
    /// Delegates to the [`ToolRunnerStreamText`] implementation of the client returned by
    /// `self.client()`.
    async fn run_with_tools_stream_text(
        &self,
        prompt: impl Into<String> + Send,
        tools: &ToolRegistry,
        cfg: RunConfig,
        on_event: &mut dyn FnMut(Event),
    ) -> Result<Response, llmoxide::Error> {
        let client = self.client();
        client
            .run_with_tools_stream_text(prompt, tools, cfg, on_event)
            .await
    }
}