1use crate::registry::ToolRegistry;
2use async_trait::async_trait;
3use llmoxide::{Client, Event, Message, Prompt, Response, ResponseRequest, Role, ToolCall};
4
/// Reports whether verbose tracing of the tool-streaming loop is switched on.
///
/// The `LLMOXIDE_DEBUG_TOOLS_STREAM` environment variable is read on every call.
/// The result is `true` only when the variable holds exactly `1`, `true`, or `yes`
/// (the comparison is case-sensitive). An unset or non-Unicode value yields `false`.
pub fn tools_stream_debug_enabled() -> bool {
    match std::env::var("LLMOXIDE_DEBUG_TOOLS_STREAM") {
        Ok(value) => ["1", "true", "yes"].contains(&value.as_str()),
        Err(_) => false,
    }
}
16
17fn stream_dbg(msg: impl std::fmt::Display) {
18 if tools_stream_debug_enabled() {
19 eprintln!("[llmoxide-tools stream] {msg}");
20 }
21}
22
/// Failures that can occur while resolving or executing a tool call.
///
/// The `Display` text of each variant comes from its `#[error(...)]` attribute.
/// Within this file only [`ToolError::MissingCallId`] is constructed directly; the
/// other variants are presumably produced by the tool registry — confirm against
/// `crate::registry`.
#[derive(Debug, thiserror::Error)]
pub enum ToolError {
    /// Rendered as `unknown tool: {tool}`; `tool` is the offending tool name.
    #[error("unknown tool: {tool}")]
    UnknownTool { tool: String },

    /// Rendered as `invalid arguments for tool {tool}: {details}`.
    #[error("invalid arguments for tool {tool}: {details}")]
    InvalidArguments { tool: String, details: String },

    /// Rendered as `tool handler error for {tool}: {details}`.
    #[error("tool handler error for {tool}: {details}")]
    Handler { tool: String, details: String },

    /// The provider returned a tool call whose `id` was `None`, so the result could
    /// not be correlated with the call. `tool` is the name of that call.
    #[error("provider returned tool call without id for {tool}")]
    MissingCallId { tool: String },
}
37
/// Settings for a tool-calling run.
#[derive(Debug, Clone)]
pub struct RunConfig {
    /// Upper bound on the number of request/tool-dispatch rounds performed before
    /// the runner issues one last request without executing further tools.
    pub max_rounds: usize,
}

impl Default for RunConfig {
    /// Builds the default configuration, which allows eight rounds.
    fn default() -> Self {
        let max_rounds = 8;
        RunConfig { max_rounds }
    }
}
49
/// Runs a request to completion, executing tool calls along the way.
///
/// Declared with `#[async_trait(?Send)]`, so the returned future is not required
/// to be `Send`.
#[async_trait(?Send)]
pub trait ToolRunner {
    /// Sends `req`, dispatching any tool calls through `tools`, and returns the
    /// resulting [`Response`].
    ///
    /// `cfg` bounds the run; the `Client` implementation in this file performs at
    /// most `cfg.max_rounds` tool-dispatch rounds followed by one final request.
    ///
    /// # Errors
    ///
    /// Returns an [`llmoxide::Error`] when the request or a tool dispatch fails.
    async fn run_with_tools(
        &self,
        req: ResponseRequest,
        tools: &ToolRegistry,
        cfg: RunConfig,
    ) -> Result<Response, llmoxide::Error>;
}
59
/// Convenience counterpart of [`ToolRunner`] that starts from a plain text prompt
/// instead of a fully built [`ResponseRequest`].
///
/// Declared with `#[async_trait(?Send)]`, so the returned future is not required
/// to be `Send`.
#[async_trait(?Send)]
pub trait ToolRunnerText {
    /// Runs a tool-calling exchange seeded with `prompt`.
    ///
    /// The `Client` implementation in this file wraps `prompt` in a single
    /// user-role message and delegates to [`ToolRunner::run_with_tools`].
    ///
    /// # Errors
    ///
    /// Returns an [`llmoxide::Error`] when the underlying run fails.
    async fn run_with_tools_text(
        &self,
        prompt: impl Into<String> + Send,
        tools: &ToolRegistry,
        cfg: RunConfig,
    ) -> Result<Response, llmoxide::Error>;
}
69
/// Streaming variant of [`ToolRunner`]: events are delivered to a callback while
/// the run is in progress.
///
/// Declared with `#[async_trait(?Send)]`, so the returned future is not required
/// to be `Send`.
#[async_trait(?Send)]
pub trait ToolRunnerStream {
    /// Streams `req`, dispatching tool calls through `tools`, and returns the final
    /// [`Response`].
    ///
    /// `on_event` receives the streamed [`Event`]s. The `Client` implementation in
    /// this file withholds the `Event::Completed` of each intermediate round and
    /// emits a single `Event::Completed` carrying the response it returns.
    ///
    /// # Errors
    ///
    /// Returns an [`llmoxide::Error`] when streaming or a tool dispatch fails.
    async fn run_with_tools_stream(
        &self,
        req: ResponseRequest,
        tools: &ToolRegistry,
        cfg: RunConfig,
        on_event: &mut dyn FnMut(Event),
    ) -> Result<Response, llmoxide::Error>;
}
83
/// Convenience counterpart of [`ToolRunnerStream`] that starts from a plain text
/// prompt instead of a fully built [`ResponseRequest`].
///
/// Declared with `#[async_trait(?Send)]`, so the returned future is not required
/// to be `Send`.
#[async_trait(?Send)]
pub trait ToolRunnerStreamText {
    /// Runs a streaming tool-calling exchange seeded with `prompt`, reporting
    /// progress through `on_event`.
    ///
    /// The `Client` implementation in this file wraps `prompt` in a single
    /// user-role message and delegates to
    /// [`ToolRunnerStream::run_with_tools_stream`].
    ///
    /// # Errors
    ///
    /// Returns an [`llmoxide::Error`] when the underlying run fails.
    async fn run_with_tools_stream_text(
        &self,
        prompt: impl Into<String> + Send,
        tools: &ToolRegistry,
        cfg: RunConfig,
        on_event: &mut dyn FnMut(Event),
    ) -> Result<Response, llmoxide::Error>;
}
94
95#[async_trait(?Send)]
96impl ToolRunner for Client {
97 async fn run_with_tools(
98 &self,
99 mut req: ResponseRequest,
100 tools: &ToolRegistry,
101 cfg: RunConfig,
102 ) -> Result<Response, llmoxide::Error> {
103 req = req.tools(tools.specs());
105 let mut history = req.messages;
106
107 for _round in 0..cfg.max_rounds {
108 let req_round = ResponseRequest {
109 model: req.model.clone(),
110 messages: history.clone(),
111 max_output_tokens: req.max_output_tokens,
112 tools: req.tools.clone(),
113 };
114 let resp = self.send(req_round).await?;
115
116 if resp.tool_calls.is_empty() {
117 return Ok(resp);
118 }
119
120 let mut tool_messages: Vec<Message> = Vec::with_capacity(resp.tool_calls.len());
121
122 for call in &resp.tool_calls {
123 let call_id = call.id.clone().ok_or_else(|| {
124 llmoxide::Error::InvalidInput(
125 ToolError::MissingCallId {
126 tool: call.name.clone(),
127 }
128 .to_string()
129 .into(),
130 )
131 })?;
132 let (_name, out) = tools.dispatch(call).await.map_err(|e| {
133 llmoxide::Error::InvalidInput(e.to_string().into())
135 })?;
136 history.push(Message::tool_call(
137 call_id.clone(),
138 call.name.clone(),
139 call.arguments.clone(),
140 ));
141 tool_messages.push(Message::tool_result_named(call_id, call.name.clone(), out));
142 }
143
144 history.extend(tool_messages);
145 }
146
147 let final_req = ResponseRequest {
149 model: req.model,
150 messages: history,
151 max_output_tokens: req.max_output_tokens,
152 tools: req.tools,
153 };
154 self.send(final_req).await
155 }
156}
157
158#[async_trait(?Send)]
159impl ToolRunnerText for Client {
160 async fn run_with_tools_text(
161 &self,
162 prompt: impl Into<String> + Send,
163 tools: &ToolRegistry,
164 cfg: RunConfig,
165 ) -> Result<Response, llmoxide::Error> {
166 let req = ResponseRequest::new_auto().push_message(Message::text(Role::User, prompt));
167 self.run_with_tools(req, tools, cfg).await
168 }
169}
170
171#[async_trait(?Send)]
172impl ToolRunnerText for Prompt {
173 async fn run_with_tools_text(
174 &self,
175 prompt: impl Into<String> + Send,
176 tools: &ToolRegistry,
177 cfg: RunConfig,
178 ) -> Result<Response, llmoxide::Error> {
179 self.client().run_with_tools_text(prompt, tools, cfg).await
180 }
181}
182
183#[async_trait(?Send)]
184impl ToolRunnerStream for Client {
185 async fn run_with_tools_stream(
186 &self,
187 mut req: ResponseRequest,
188 tools: &ToolRegistry,
189 cfg: RunConfig,
190 on_event: &mut dyn FnMut(Event),
191 ) -> Result<Response, llmoxide::Error> {
192 req = req.tools(tools.specs());
193 let mut history = req.messages;
194
195 stream_dbg(format!(
196 "start provider={:?} tool_specs={} max_rounds={} history_messages={}",
197 self.provider(),
198 req.tools.len(),
199 cfg.max_rounds,
200 history.len()
201 ));
202
203 for round in 0..cfg.max_rounds {
204 let req_round = ResponseRequest {
205 model: req.model.clone(),
206 messages: history.clone(),
207 max_output_tokens: req.max_output_tokens,
208 tools: req.tools.clone(),
209 };
210
211 let mut streamed_tool_calls: Vec<ToolCall> = Vec::new();
212
213 stream_dbg(format!(
214 "round {round}: streaming request (messages={}, model={:?})",
215 req_round.messages.len(),
216 req_round.model.as_ref().map(|m| m.0.as_str())
217 ));
218
219 let resp = match self
220 .stream(req_round, |ev| {
221 if let Event::ToolCall(ref tc) = ev {
222 streamed_tool_calls.push(tc.clone());
223 }
224 match ev {
225 Event::Completed(_) => {}
226 other => on_event(other),
227 }
228 })
229 .await
230 {
231 Ok(r) => r,
232 Err(e) => {
233 stream_dbg(format!("round {round}: stream ERROR: {e}"));
234 return Err(e);
235 }
236 };
237
238 stream_dbg(format!(
239 "round {round}: stream OK — resp.tool_calls.len()={}, collected_stream_tool_calls={}, assistant_text_len={:?}",
240 resp.tool_calls.len(),
241 streamed_tool_calls.len(),
242 resp.text().map(|t| t.len())
243 ));
244
245 let tool_calls = if !resp.tool_calls.is_empty() {
246 resp.tool_calls.clone()
247 } else {
248 streamed_tool_calls.clone()
249 };
250
251 if !tool_calls.is_empty() {
252 for (i, c) in tool_calls.iter().enumerate() {
253 stream_dbg(format!(
254 "round {round}: tool_call[{i}] name={:?} id={:?} args={}",
255 c.name, c.id, c.arguments
256 ));
257 }
258 }
259
260 if tool_calls.is_empty() {
261 stream_dbg(format!(
262 "round {round}: no tool calls — emitting Completed and returning (assistant empty={})",
263 resp.text().map(|t| t.is_empty()).unwrap_or(true)
264 ));
265 on_event(Event::Completed(resp.clone()));
266 return Ok(resp);
267 }
268
269 let mut tool_messages: Vec<Message> = Vec::with_capacity(tool_calls.len());
270
271 for call in &tool_calls {
272 let call_id = call.id.clone().ok_or_else(|| {
273 llmoxide::Error::InvalidInput(
274 ToolError::MissingCallId {
275 tool: call.name.clone(),
276 }
277 .to_string()
278 .into(),
279 )
280 })?;
281 let (_name, out) = tools
282 .dispatch(call)
283 .await
284 .map_err(|e| llmoxide::Error::InvalidInput(e.to_string().into()))?;
285 history.push(Message::tool_call(
286 call_id.clone(),
287 call.name.clone(),
288 call.arguments.clone(),
289 ));
290 tool_messages.push(Message::tool_result_named(call_id, call.name.clone(), out));
291 }
292
293 history.extend(tool_messages);
294 stream_dbg(format!(
295 "round {round}: dispatched {} tool result(s); history now {} message(s)",
296 tool_calls.len(),
297 history.len()
298 ));
299 }
300
301 stream_dbg(format!(
302 "max_rounds ({}) exhausted — final stream (no tool execution this turn)",
303 cfg.max_rounds
304 ));
305
306 let final_req = ResponseRequest {
307 model: req.model,
308 messages: history,
309 max_output_tokens: req.max_output_tokens,
310 tools: req.tools,
311 };
312
313 let mut streamed_tool_calls: Vec<ToolCall> = Vec::new();
314 let resp = match self
315 .stream(final_req, |ev| {
316 if let Event::ToolCall(ref tc) = ev {
317 streamed_tool_calls.push(tc.clone());
318 }
319 match ev {
320 Event::Completed(_) => {}
321 other => on_event(other),
322 }
323 })
324 .await
325 {
326 Ok(r) => r,
327 Err(e) => {
328 stream_dbg(format!("final stream ERROR: {e}"));
329 return Err(e);
330 }
331 };
332
333 stream_dbg(format!(
334 "final stream OK — resp.tool_calls.len()={}, streamed_tool_calls={}, assistant_text_len={:?}",
335 resp.tool_calls.len(),
336 streamed_tool_calls.len(),
337 resp.text().map(|t| t.len())
338 ));
339
340 let resp = if resp.tool_calls.is_empty() && !streamed_tool_calls.is_empty() {
341 stream_dbg("merging streamed tool_calls into Response (response had empty tool_calls)");
342 resp.with_tool_calls(streamed_tool_calls)
343 } else {
344 resp
345 };
346
347 on_event(Event::Completed(resp.clone()));
348 Ok(resp)
349 }
350}
351
352#[async_trait(?Send)]
353impl ToolRunnerStreamText for Client {
354 async fn run_with_tools_stream_text(
355 &self,
356 prompt: impl Into<String> + Send,
357 tools: &ToolRegistry,
358 cfg: RunConfig,
359 on_event: &mut dyn FnMut(Event),
360 ) -> Result<Response, llmoxide::Error> {
361 let req = ResponseRequest::new_auto().push_message(Message::text(Role::User, prompt));
362 self.run_with_tools_stream(req, tools, cfg, on_event).await
363 }
364}
365
366#[async_trait(?Send)]
367impl ToolRunnerStream for Prompt {
368 async fn run_with_tools_stream(
369 &self,
370 req: ResponseRequest,
371 tools: &ToolRegistry,
372 cfg: RunConfig,
373 on_event: &mut dyn FnMut(Event),
374 ) -> Result<Response, llmoxide::Error> {
375 self.client()
376 .run_with_tools_stream(req, tools, cfg, on_event)
377 .await
378 }
379}
380
381#[async_trait(?Send)]
382impl ToolRunnerStreamText for Prompt {
383 async fn run_with_tools_stream_text(
384 &self,
385 prompt: impl Into<String> + Send,
386 tools: &ToolRegistry,
387 cfg: RunConfig,
388 on_event: &mut dyn FnMut(Event),
389 ) -> Result<Response, llmoxide::Error> {
390 self.client()
391 .run_with_tools_stream_text(prompt, tools, cfg, on_event)
392 .await
393 }
394}