Skip to main content

llmoxide_tools/
runner.rs

1use crate::registry::ToolRegistry;
2use async_trait::async_trait;
3use llmoxide::{Client, Event, Message, Prompt, Response, ResponseRequest, Role, ToolCall};
4
/// Reports whether tool-stream debug logging is switched on.
///
/// Returns `true` only when the **`LLMOXIDE_DEBUG_TOOLS_STREAM`** environment variable is set to
/// exactly `1`, `true`, or `yes`. An unset variable, a non-Unicode value, or any other value
/// yields `false`. When enabled, [`ToolRunnerStream`] logs diagnostics to stderr (provider,
/// rounds, tool-call counts, response summaries).
///
/// Anthropic SSE tracing uses the same env values as **`LLMOXIDE_DEBUG_ANTHROPIC_STREAM`** (or
/// **`LLMOXIDE_DEBUG_TOOLS_STREAM`**): stderr lines prefixed with **`[llmoxide anthropic stream]`**.
pub fn tools_stream_debug_enabled() -> bool {
    // The comparison is exact: no trimming and no case folding.
    const ENABLING_VALUES: [&str; 3] = ["1", "true", "yes"];
    match std::env::var("LLMOXIDE_DEBUG_TOOLS_STREAM") {
        Ok(value) => ENABLING_VALUES.contains(&value.as_str()),
        Err(_) => false,
    }
}
16
17fn stream_dbg(msg: impl std::fmt::Display) {
18    if tools_stream_debug_enabled() {
19        eprintln!("[llmoxide-tools stream] {msg}");
20    }
21}
22
23#[derive(Debug, thiserror::Error)]
24pub enum ToolError {
25    #[error("unknown tool: {tool}")]
26    UnknownTool { tool: String },
27
28    #[error("invalid arguments for tool {tool}: {details}")]
29    InvalidArguments { tool: String, details: String },
30
31    #[error("tool handler error for {tool}: {details}")]
32    Handler { tool: String, details: String },
33
34    #[error("provider returned tool call without id for {tool}")]
35    MissingCallId { tool: String },
36}
37
/// Settings that bound a tool-running loop.
#[derive(Debug, Clone)]
pub struct RunConfig {
    /// Maximum number of tool rounds (model -> tools -> model ...).
    pub max_rounds: usize,
}

impl Default for RunConfig {
    /// Builds the default configuration, which allows eight tool rounds.
    fn default() -> Self {
        const DEFAULT_MAX_ROUNDS: usize = 8;
        RunConfig {
            max_rounds: DEFAULT_MAX_ROUNDS,
        }
    }
}
49
50#[async_trait(?Send)]
51pub trait ToolRunner {
52    async fn run_with_tools(
53        &self,
54        req: ResponseRequest,
55        tools: &ToolRegistry,
56        cfg: RunConfig,
57    ) -> Result<Response, llmoxide::Error>;
58}
59
60#[async_trait(?Send)]
61pub trait ToolRunnerText {
62    async fn run_with_tools_text(
63        &self,
64        prompt: impl Into<String> + Send,
65        tools: &ToolRegistry,
66        cfg: RunConfig,
67    ) -> Result<Response, llmoxide::Error>;
68}
69
70/// Streaming variant: forwards [`Event::TextDelta`] and [`Event::ToolCall`] while running the tool
71/// loop. Inner [`Event::Completed`] events from [`Client::stream`] are suppressed; a single
72/// [`Event::Completed`] is emitted for the **final** assistant response once the loop finishes.
73#[async_trait(?Send)]
74pub trait ToolRunnerStream {
75    async fn run_with_tools_stream(
76        &self,
77        req: ResponseRequest,
78        tools: &ToolRegistry,
79        cfg: RunConfig,
80        on_event: &mut dyn FnMut(Event),
81    ) -> Result<Response, llmoxide::Error>;
82}
83
84#[async_trait(?Send)]
85pub trait ToolRunnerStreamText {
86    async fn run_with_tools_stream_text(
87        &self,
88        prompt: impl Into<String> + Send,
89        tools: &ToolRegistry,
90        cfg: RunConfig,
91        on_event: &mut dyn FnMut(Event),
92    ) -> Result<Response, llmoxide::Error>;
93}
94
95#[async_trait(?Send)]
96impl ToolRunner for Client {
97    async fn run_with_tools(
98        &self,
99        mut req: ResponseRequest,
100        tools: &ToolRegistry,
101        cfg: RunConfig,
102    ) -> Result<Response, llmoxide::Error> {
103        // Attach schemas once; providers will ignore if unsupported.
104        req = req.tools(tools.specs());
105        let mut history = req.messages;
106
107        for _round in 0..cfg.max_rounds {
108            let req_round = ResponseRequest {
109                model: req.model.clone(),
110                messages: history.clone(),
111                max_output_tokens: req.max_output_tokens,
112                tools: req.tools.clone(),
113            };
114            let resp = self.send(req_round).await?;
115
116            if resp.tool_calls.is_empty() {
117                return Ok(resp);
118            }
119
120            let mut tool_messages: Vec<Message> = Vec::with_capacity(resp.tool_calls.len());
121
122            for call in &resp.tool_calls {
123                let call_id = call.id.clone().ok_or_else(|| {
124                    llmoxide::Error::InvalidInput(
125                        ToolError::MissingCallId {
126                            tool: call.name.clone(),
127                        }
128                        .to_string()
129                        .into(),
130                    )
131                })?;
132                let (_name, out) = tools.dispatch(call).await.map_err(|e| {
133                    // Map tool-layer errors into llmoxide::Error::InvalidInput for now.
134                    llmoxide::Error::InvalidInput(e.to_string().into())
135                })?;
136                history.push(Message::tool_call(
137                    call_id.clone(),
138                    call.name.clone(),
139                    call.arguments.clone(),
140                ));
141                tool_messages.push(Message::tool_result_named(call_id, call.name.clone(), out));
142            }
143
144            history.extend(tool_messages);
145        }
146
147        // If we hit the round limit, do one final call without executing tools.
148        let final_req = ResponseRequest {
149            model: req.model,
150            messages: history,
151            max_output_tokens: req.max_output_tokens,
152            tools: req.tools,
153        };
154        self.send(final_req).await
155    }
156}
157
158#[async_trait(?Send)]
159impl ToolRunnerText for Client {
160    async fn run_with_tools_text(
161        &self,
162        prompt: impl Into<String> + Send,
163        tools: &ToolRegistry,
164        cfg: RunConfig,
165    ) -> Result<Response, llmoxide::Error> {
166        let req = ResponseRequest::new_auto().push_message(Message::text(Role::User, prompt));
167        self.run_with_tools(req, tools, cfg).await
168    }
169}
170
171#[async_trait(?Send)]
172impl ToolRunnerText for Prompt {
173    async fn run_with_tools_text(
174        &self,
175        prompt: impl Into<String> + Send,
176        tools: &ToolRegistry,
177        cfg: RunConfig,
178    ) -> Result<Response, llmoxide::Error> {
179        self.client().run_with_tools_text(prompt, tools, cfg).await
180    }
181}
182
183#[async_trait(?Send)]
184impl ToolRunnerStream for Client {
185    async fn run_with_tools_stream(
186        &self,
187        mut req: ResponseRequest,
188        tools: &ToolRegistry,
189        cfg: RunConfig,
190        on_event: &mut dyn FnMut(Event),
191    ) -> Result<Response, llmoxide::Error> {
192        req = req.tools(tools.specs());
193        let mut history = req.messages;
194
195        stream_dbg(format!(
196            "start provider={:?} tool_specs={} max_rounds={} history_messages={}",
197            self.provider(),
198            req.tools.len(),
199            cfg.max_rounds,
200            history.len()
201        ));
202
203        for round in 0..cfg.max_rounds {
204            let req_round = ResponseRequest {
205                model: req.model.clone(),
206                messages: history.clone(),
207                max_output_tokens: req.max_output_tokens,
208                tools: req.tools.clone(),
209            };
210
211            let mut streamed_tool_calls: Vec<ToolCall> = Vec::new();
212
213            stream_dbg(format!(
214                "round {round}: streaming request (messages={}, model={:?})",
215                req_round.messages.len(),
216                req_round.model.as_ref().map(|m| m.0.as_str())
217            ));
218
219            let resp = match self
220                .stream(req_round, |ev| {
221                    if let Event::ToolCall(ref tc) = ev {
222                        streamed_tool_calls.push(tc.clone());
223                    }
224                    match ev {
225                        Event::Completed(_) => {}
226                        other => on_event(other),
227                    }
228                })
229                .await
230            {
231                Ok(r) => r,
232                Err(e) => {
233                    stream_dbg(format!("round {round}: stream ERROR: {e}"));
234                    return Err(e);
235                }
236            };
237
238            stream_dbg(format!(
239                "round {round}: stream OK — resp.tool_calls.len()={}, collected_stream_tool_calls={}, assistant_text_len={:?}",
240                resp.tool_calls.len(),
241                streamed_tool_calls.len(),
242                resp.text().map(|t| t.len())
243            ));
244
245            let tool_calls = if !resp.tool_calls.is_empty() {
246                resp.tool_calls.clone()
247            } else {
248                streamed_tool_calls.clone()
249            };
250
251            if !tool_calls.is_empty() {
252                for (i, c) in tool_calls.iter().enumerate() {
253                    stream_dbg(format!(
254                        "round {round}: tool_call[{i}] name={:?} id={:?} args={}",
255                        c.name, c.id, c.arguments
256                    ));
257                }
258            }
259
260            if tool_calls.is_empty() {
261                stream_dbg(format!(
262                    "round {round}: no tool calls — emitting Completed and returning (assistant empty={})",
263                    resp.text().map(|t| t.is_empty()).unwrap_or(true)
264                ));
265                on_event(Event::Completed(resp.clone()));
266                return Ok(resp);
267            }
268
269            let mut tool_messages: Vec<Message> = Vec::with_capacity(tool_calls.len());
270
271            for call in &tool_calls {
272                let call_id = call.id.clone().ok_or_else(|| {
273                    llmoxide::Error::InvalidInput(
274                        ToolError::MissingCallId {
275                            tool: call.name.clone(),
276                        }
277                        .to_string()
278                        .into(),
279                    )
280                })?;
281                let (_name, out) = tools
282                    .dispatch(call)
283                    .await
284                    .map_err(|e| llmoxide::Error::InvalidInput(e.to_string().into()))?;
285                history.push(Message::tool_call(
286                    call_id.clone(),
287                    call.name.clone(),
288                    call.arguments.clone(),
289                ));
290                tool_messages.push(Message::tool_result_named(call_id, call.name.clone(), out));
291            }
292
293            history.extend(tool_messages);
294            stream_dbg(format!(
295                "round {round}: dispatched {} tool result(s); history now {} message(s)",
296                tool_calls.len(),
297                history.len()
298            ));
299        }
300
301        stream_dbg(format!(
302            "max_rounds ({}) exhausted — final stream (no tool execution this turn)",
303            cfg.max_rounds
304        ));
305
306        let final_req = ResponseRequest {
307            model: req.model,
308            messages: history,
309            max_output_tokens: req.max_output_tokens,
310            tools: req.tools,
311        };
312
313        let mut streamed_tool_calls: Vec<ToolCall> = Vec::new();
314        let resp = match self
315            .stream(final_req, |ev| {
316                if let Event::ToolCall(ref tc) = ev {
317                    streamed_tool_calls.push(tc.clone());
318                }
319                match ev {
320                    Event::Completed(_) => {}
321                    other => on_event(other),
322                }
323            })
324            .await
325        {
326            Ok(r) => r,
327            Err(e) => {
328                stream_dbg(format!("final stream ERROR: {e}"));
329                return Err(e);
330            }
331        };
332
333        stream_dbg(format!(
334            "final stream OK — resp.tool_calls.len()={}, streamed_tool_calls={}, assistant_text_len={:?}",
335            resp.tool_calls.len(),
336            streamed_tool_calls.len(),
337            resp.text().map(|t| t.len())
338        ));
339
340        let resp = if resp.tool_calls.is_empty() && !streamed_tool_calls.is_empty() {
341            stream_dbg("merging streamed tool_calls into Response (response had empty tool_calls)");
342            resp.with_tool_calls(streamed_tool_calls)
343        } else {
344            resp
345        };
346
347        on_event(Event::Completed(resp.clone()));
348        Ok(resp)
349    }
350}
351
352#[async_trait(?Send)]
353impl ToolRunnerStreamText for Client {
354    async fn run_with_tools_stream_text(
355        &self,
356        prompt: impl Into<String> + Send,
357        tools: &ToolRegistry,
358        cfg: RunConfig,
359        on_event: &mut dyn FnMut(Event),
360    ) -> Result<Response, llmoxide::Error> {
361        let req = ResponseRequest::new_auto().push_message(Message::text(Role::User, prompt));
362        self.run_with_tools_stream(req, tools, cfg, on_event).await
363    }
364}
365
366#[async_trait(?Send)]
367impl ToolRunnerStream for Prompt {
368    async fn run_with_tools_stream(
369        &self,
370        req: ResponseRequest,
371        tools: &ToolRegistry,
372        cfg: RunConfig,
373        on_event: &mut dyn FnMut(Event),
374    ) -> Result<Response, llmoxide::Error> {
375        self.client()
376            .run_with_tools_stream(req, tools, cfg, on_event)
377            .await
378    }
379}
380
381#[async_trait(?Send)]
382impl ToolRunnerStreamText for Prompt {
383    async fn run_with_tools_stream_text(
384        &self,
385        prompt: impl Into<String> + Send,
386        tools: &ToolRegistry,
387        cfg: RunConfig,
388        on_event: &mut dyn FnMut(Event),
389    ) -> Result<Response, llmoxide::Error> {
390        self.client()
391            .run_with_tools_stream_text(prompt, tools, cfg, on_event)
392            .await
393    }
394}