oxillama_runtime/tool_dispatch.rs
1//! Tool-invocation runtime callbacks.
2//!
3//! Provides infrastructure for detecting, parsing, and dispatching tool calls
4//! produced by language models during generation.
5//!
6//! ## Overview
7//!
8//! Different model families emit tool calls using different delimiter syntax:
9//! - LLaMA 3: `<|tool_call|>{ ... }<|/tool_call|>`
10//! - Qwen: `<tool_call>{ ... }</tool_call>`
11//! - Mistral: `[TOOL_CALLS][ ... ]`
12//! - Custom: user-supplied open/close delimiters
13//!
14//! The [`ToolCallDetector`] accumulates token text, detects open/close
15//! delimiters, validates the JSON payload, and fires whenever a complete tool
16//! call is parsed.
17//!
18//! Tool results can be queued for injection back into the generation stream
19//! via the engine's injection queue mechanism.
20
21use serde_json::Value;
22use std::sync::Arc;
23
24// ─── Core trait ───────────────────────────────────────────────────────────────
25
26/// Dispatches tool calls to registered handler implementations.
27///
28/// Implement this trait to handle tool invocations produced by the model
29/// during generation. The implementation must be `Send + Sync` so it can
30/// be shared across threads and stored in `Arc`.
31///
32/// # Example
33///
34/// ```
35/// use oxillama_runtime::tool_dispatch::{ToolDispatcher, ToolResult};
36/// use serde_json::Value;
37///
38/// struct WeatherTool;
39///
40/// impl ToolDispatcher for WeatherTool {
41/// fn invoke(&self, name: &str, args: &Value) -> ToolResult {
42/// if name == "get_weather" {
43/// ToolResult::Ok(Value::String("sunny, 22°C".to_string()))
44/// } else {
45/// ToolResult::Err(format!("unknown tool: {name}"))
46/// }
47/// }
48/// }
49/// ```
50pub trait ToolDispatcher: Send + Sync {
51 /// Invoke the named tool with the given JSON arguments.
52 ///
53 /// Returns [`ToolResult::Ok`] with the tool's output value on success, or
54 /// [`ToolResult::Err`] with an error message on failure.
55 fn invoke(&self, name: &str, args: &Value) -> ToolResult;
56}
57
58/// Result of a tool invocation.
59#[derive(Debug, Clone)]
60pub enum ToolResult {
61 /// Successful invocation with a JSON result value.
62 Ok(Value),
63 /// Failed invocation with a human-readable error message.
64 Err(String),
65}
66
67impl ToolResult {
68 /// Format the result as a string suitable for injection into the token stream.
69 pub fn as_injection_string(&self) -> String {
70 match self {
71 ToolResult::Ok(v) => {
72 format!("<tool_result>{}</tool_result>", v)
73 }
74 ToolResult::Err(e) => {
75 format!(
76 "<tool_result>{{\"error\":{}}}</tool_result>",
77 serde_json::json!(e)
78 )
79 }
80 }
81 }
82}
83
84// ─── Grammar ─────────────────────────────────────────────────────────────────
85
/// Specifies the delimiter syntax a model uses to wrap tool calls.
///
/// Choose the variant that matches your deployed model:
/// - [`Llama3`](ToolCallGrammar::Llama3) for LLaMA 3 models.
/// - [`Qwen`](ToolCallGrammar::Qwen) for Qwen / Qwen2 models.
/// - [`Mistral`](ToolCallGrammar::Mistral) for Mistral / Mixtral function-calling models.
/// - [`Custom`](ToolCallGrammar::Custom) for any other format.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ToolCallGrammar {
    /// LLaMA 3 tool-call format: `<|tool_call|>...<|/tool_call|>`.
    Llama3,
    /// Qwen / Qwen2 format: `<tool_call>...</tool_call>`.
    Qwen,
    /// Mistral function-calling format: `[TOOL_CALLS][...]`.
    ///
    /// Note that the closing delimiter is a bare `]`, which can also occur
    /// inside the JSON payload itself (e.g. in an array-valued argument).
    Mistral,
    /// User-supplied open/close delimiter pair.
    Custom {
        /// Opening delimiter (e.g. `"<tool_call>"`).
        open: String,
        /// Closing delimiter (e.g. `"</tool_call>"`).
        close: String,
    },
}

impl ToolCallGrammar {
    /// Return the opening delimiter for this grammar variant.
    pub fn open_delimiter(&self) -> &str {
        self.delimiters().0
    }

    /// Return the closing delimiter for this grammar variant.
    pub fn close_delimiter(&self) -> &str {
        self.delimiters().1
    }

    /// Single source of truth for each variant's `(open, close)` pair, so the
    /// two public accessors cannot drift apart.
    fn delimiters(&self) -> (&str, &str) {
        match self {
            ToolCallGrammar::Llama3 => ("<|tool_call|>", "<|/tool_call|>"),
            ToolCallGrammar::Qwen => ("<tool_call>", "</tool_call>"),
            ToolCallGrammar::Mistral => ("[TOOL_CALLS][", "]"),
            ToolCallGrammar::Custom { open, close } => (open.as_str(), close.as_str()),
        }
    }
}
131
132// ─── Parsed tool call ─────────────────────────────────────────────────────────
133
134/// A fully-parsed tool call extracted from the model's output stream.
135#[derive(Debug, Clone)]
136pub struct ToolCall {
137 /// Name of the tool to invoke.
138 pub name: String,
139 /// Arguments to pass to the tool as a JSON value.
140 pub args: Value,
141}
142
143// ─── Detection state machine ─────────────────────────────────────────────────
144
/// Phase of the tool-call detector's state machine.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ToolDetectionState {
    /// Outside any tool call: searching the buffer for the opening delimiter.
    Idle,
    /// Inside a tool call: the opening delimiter was consumed and text is
    /// being collected until the closing delimiter appears.
    Capturing,
}
153
154/// Incremental tool-call detector.
155///
156/// Feed token text via [`feed`](ToolCallDetector::feed) as tokens are generated.
157/// When a complete tool call (open delimiter + valid JSON + close delimiter)
158/// is recognised, `feed` returns `Some(ToolCall)`.
159///
160/// The detector maintains a rolling scan buffer to handle delimiters that
161/// span multiple tokens.
162///
163/// # Example
164///
165/// ```
166/// use oxillama_runtime::tool_dispatch::{ToolCallDetector, ToolCallGrammar};
167///
168/// let mut detector = ToolCallDetector::new(ToolCallGrammar::Llama3);
169/// let call = detector.feed("<|tool_call|>{\"name\":\"ping\",\"args\":{}}<|/tool_call|>");
170/// assert!(call.is_some());
171/// ```
172pub struct ToolCallDetector {
173 grammar: ToolCallGrammar,
174 state: ToolDetectionState,
175 /// Text accumulated since the start of the current token or since the last
176 /// full-delimiter match candidate.
177 buffer: String,
178}
179
180impl ToolCallDetector {
181 /// Construct a new detector for the given grammar.
182 pub fn new(grammar: ToolCallGrammar) -> Self {
183 Self {
184 grammar,
185 state: ToolDetectionState::Idle,
186 buffer: String::new(),
187 }
188 }
189
190 /// Feed one token's decoded text into the detector.
191 ///
192 /// Returns `Some(ToolCall)` when a complete, valid tool call is detected.
193 /// Returns `None` while the tool call is still accumulating or if the text
194 /// is not a tool call.
195 ///
196 /// After returning `Some`, the detector automatically resets to `Idle`.
197 /// This means it can detect multiple sequential tool calls: feed text
198 /// from the second call and it will be detected in a subsequent call.
199 pub fn feed(&mut self, token_text: &str) -> Option<ToolCall> {
200 self.buffer.push_str(token_text);
201 self.try_parse()
202 }
203
204 /// Reset the detector to the idle state, discarding any buffered content.
205 pub fn reset(&mut self) {
206 self.state = ToolDetectionState::Idle;
207 self.buffer.clear();
208 }
209
210 // ─── Internal parsing ─────────────────────────────────────────────────────
211
212 /// Try to parse a complete tool call from the current buffer.
213 ///
214 /// This is the core state machine: it searches for open and close
215 /// delimiters in the buffer and attempts JSON parsing on the content
216 /// between them.
217 ///
218 /// Multiple calls are detected by scanning for repeated open/close pairs.
219 fn try_parse(&mut self) -> Option<ToolCall> {
220 let open = self.grammar.open_delimiter().to_string();
221 let close = self.grammar.close_delimiter().to_string();
222
223 loop {
224 match self.state {
225 ToolDetectionState::Idle => {
226 // Look for the opening delimiter.
227 if let Some(start) = self.buffer.find(open.as_str()) {
228 // Discard everything before the open delimiter.
229 let after_open = start + open.len();
230 self.buffer = self.buffer[after_open..].to_string();
231 self.state = ToolDetectionState::Capturing;
232 // Fall through and look for the close delimiter.
233 } else {
234 // No open delimiter yet; keep only the trailing portion
235 // that could be a partial delimiter prefix.
236 self.trim_idle_buffer(&open);
237 return None;
238 }
239 }
240
241 ToolDetectionState::Capturing => {
242 if let Some(end) = self.buffer.find(close.as_str()) {
243 // Extract the JSON payload.
244 let payload = self.buffer[..end].trim().to_string();
245 // Consume past the close delimiter.
246 let after_close = end + close.len();
247 let remainder = self.buffer[after_close..].to_string();
248 self.buffer = remainder;
249 self.state = ToolDetectionState::Idle;
250
251 // Parse and validate the JSON.
252 if let Some(call) = parse_tool_call_json(&payload) {
253 return Some(call);
254 }
255 // Bad JSON — continue scanning the remainder.
256 // (fall back to Idle and try again)
257 } else {
258 // Close delimiter not yet seen; keep capturing.
259 return None;
260 }
261 }
262 }
263 }
264 }
265
266 /// Trim the idle buffer to at most `max_suffix` chars that could be a
267 /// prefix of the open delimiter. Prevents unbounded growth of the buffer
268 /// when no tool call is ever emitted.
269 fn trim_idle_buffer(&mut self, open: &str) {
270 let max_keep = open.len().saturating_sub(1);
271 if self.buffer.len() > max_keep {
272 let trim_to = self.buffer.len() - max_keep;
273 self.buffer = self.buffer[trim_to..].to_string();
274 }
275 }
276}
277
278// ─── JSON parsing ─────────────────────────────────────────────────────────────
279
280/// Parse a JSON string as a tool call object with `name` and `args` fields.
281///
282/// Accepts two JSON shapes:
283/// 1. `{"name": "...", "args": { ... }}` — preferred
284/// 2. `{"name": "...", "arguments": { ... }}` — OpenAI-compat alias
285///
286/// Returns `None` if the string is not valid JSON or the expected fields
287/// are missing.
288fn parse_tool_call_json(payload: &str) -> Option<ToolCall> {
289 let v: Value = serde_json::from_str(payload).ok()?;
290 let obj = v.as_object()?;
291
292 let name = obj.get("name")?.as_str()?.to_string();
293
294 // Accept either "args" or "arguments" for OpenAI compatibility.
295 let args = obj
296 .get("args")
297 .or_else(|| obj.get("arguments"))
298 .cloned()
299 .unwrap_or(Value::Object(serde_json::Map::new()));
300
301 Some(ToolCall { name, args })
302}
303
304// ─── Tool-dispatcher no-op helper ────────────────────────────────────────────
305
306/// A no-op dispatcher that returns a stub `Ok(null)` for every tool call.
307///
308/// Useful for testing or when you want to detect tool calls but not execute
309/// them yet.
310pub struct NoOpDispatcher;
311
312impl ToolDispatcher for NoOpDispatcher {
313 fn invoke(&self, _name: &str, _args: &Value) -> ToolResult {
314 ToolResult::Ok(Value::Null)
315 }
316}
317
318/// Create a no-op dispatcher wrapped in an `Arc`.
319pub fn no_op_dispatcher() -> Arc<dyn ToolDispatcher> {
320 Arc::new(NoOpDispatcher)
321}
322
323// ─── Tests ────────────────────────────────────────────────────────────────────
324
#[cfg(test)]
mod tests {
    use super::*;

    /// Run a fresh detector over `input` delivered as a single chunk.
    fn detect_once(grammar: ToolCallGrammar, input: &str) -> Option<ToolCall> {
        ToolCallDetector::new(grammar).feed(input)
    }

    // ── A: Basic detection ────────────────────────────────────────────────────

    /// (a) Complete LLaMA-3 tool call in a single feed() call.
    #[test]
    fn tool_call_detection_llama3() {
        let call = detect_once(
            ToolCallGrammar::Llama3,
            r#"<|tool_call|>{"name":"get_weather","args":{"city":"Tokyo"}}<|/tool_call|>"#,
        )
        .expect("must detect a complete Llama3 tool call");
        assert_eq!(call.name, "get_weather");
        assert_eq!(call.args["city"], Value::String("Tokyo".to_string()));
    }

    /// (b) Open delimiter, JSON body and close delimiter arrive in separate
    /// feed() calls (simulates streaming tokenizer output).
    #[test]
    fn tool_call_streamed_across_chunks() {
        let mut det = ToolCallDetector::new(ToolCallGrammar::Llama3);
        assert!(
            det.feed("<|tool_call|>").is_none(),
            "open delimiter alone must not fire"
        );
        assert!(
            det.feed(r#"{"name":"add","args":{"a":1,"b":2}}"#).is_none(),
            "body without close must not fire"
        );
        let call = det
            .feed("<|/tool_call|>")
            .expect("close delimiter should complete the detection");
        assert_eq!(call.name, "add");
        assert_eq!(call.args["a"], 1);
        assert_eq!(call.args["b"], 2);
    }

    /// (b') The open delimiter itself is split across chunks and preceded by
    /// ordinary text; idle-buffer trimming must keep the partial prefix.
    #[test]
    fn open_delimiter_split_across_chunks() {
        let mut det = ToolCallDetector::new(ToolCallGrammar::Llama3);
        assert!(det.feed("Sure, calling now <|tool").is_none());
        assert!(det.feed(r#"_call|>{"name":"ping","args":{}}"#).is_none());
        let call = det
            .feed("<|/tool_call|>")
            .expect("a split open delimiter must still be recognised");
        assert_eq!(call.name, "ping");
    }

    /// (c) Malformed / unclosed JSON must not produce a ToolCall.
    #[test]
    fn malformed_json_does_not_return_call() {
        let mut det = ToolCallDetector::new(ToolCallGrammar::Llama3);
        assert!(
            det.feed("<|tool_call|>{\"name\":\"broken\"").is_none(),
            "partial JSON must not fire"
        );
        // Never close — the detector must stay quiet without panicking.
        for _ in 0..5 {
            assert!(
                det.feed("more garbage").is_none(),
                "unfinished tool call must not fire"
            );
        }
    }

    /// (d) Two complete tool calls back-to-back in the same buffer fire once each.
    #[test]
    fn multiple_calls_sequentially() {
        let mut det = ToolCallDetector::new(ToolCallGrammar::Qwen);
        let first = det
            .feed(r#"<tool_call>{"name":"tool1","args":{"x":1}}</tool_call><tool_call>{"name":"tool2","args":{"y":2}}</tool_call>"#)
            .expect("first call must be detected");
        assert_eq!(first.name, "tool1");

        // The second call is still buffered, so an empty feed surfaces it.
        let second = det
            .feed("")
            .expect("second call must be detected from remainder");
        assert_eq!(second.name, "tool2");
    }

    /// (e) A payload that is not JSON is dropped, and a following valid call
    /// in the same buffer is still detected.
    #[test]
    fn invalid_payload_is_skipped() {
        let call = detect_once(
            ToolCallGrammar::Qwen,
            r#"<tool_call>not json</tool_call><tool_call>{"name":"ok","args":{}}</tool_call>"#,
        )
        .expect("a valid call after an invalid one must be detected");
        assert_eq!(call.name, "ok");
    }

    /// (f) Valid JSON without a `name` field is not a tool call.
    #[test]
    fn payload_without_name_is_rejected() {
        let result = detect_once(ToolCallGrammar::Qwen, r#"<tool_call>{"args":{}}</tool_call>"#);
        assert!(result.is_none(), "a payload without a name must be rejected");
    }

    /// (g) Missing arguments default to an empty JSON object.
    #[test]
    fn missing_args_default_to_empty_object() {
        let call = detect_once(ToolCallGrammar::Qwen, r#"<tool_call>{"name":"bare"}</tool_call>"#)
            .expect("a call without args must be detected");
        assert_eq!(call.args, serde_json::json!({}));
    }

    /// (h) Plain text without tool calls must not grow the idle buffer.
    #[test]
    fn idle_buffer_stays_bounded() {
        let mut det = ToolCallDetector::new(ToolCallGrammar::Llama3);
        for _ in 0..100 {
            assert!(det.feed("no tools here, just prose. ").is_none());
        }
        assert_eq!(det.state, ToolDetectionState::Idle);
        assert!(det.buffer.len() < ToolCallGrammar::Llama3.open_delimiter().len());
    }

    // ── B: Grammar variant tests ──────────────────────────────────────────────

    /// Qwen format detection.
    #[test]
    fn tool_call_detection_qwen() {
        let call = detect_once(
            ToolCallGrammar::Qwen,
            r#"<tool_call>{"name":"calc","args":{"expr":"1+1"}}</tool_call>"#,
        )
        .expect("qwen call");
        assert_eq!(call.name, "calc");
    }

    /// Mistral format detection.
    #[test]
    fn tool_call_detection_mistral() {
        let call = detect_once(
            ToolCallGrammar::Mistral,
            r#"[TOOL_CALLS][{"name":"search","args":{"q":"rust"}}]"#,
        )
        .expect("mistral call");
        assert_eq!(call.name, "search");
        assert_eq!(call.args["q"], "rust");
    }

    /// Custom grammar detection.
    #[test]
    fn tool_call_detection_custom() {
        let grammar = ToolCallGrammar::Custom {
            open: "<<TOOL>>".to_string(),
            close: "<</TOOL>>".to_string(),
        };
        let call = detect_once(grammar, r#"<<TOOL>>{"name":"echo","args":{"msg":"hi"}}<</TOOL>>"#)
            .expect("custom call");
        assert_eq!(call.name, "echo");
    }

    // ── C: Grammar delimiter accessors ───────────────────────────────────────

    #[test]
    fn grammar_delimiters_llama3() {
        let g = ToolCallGrammar::Llama3;
        assert_eq!(g.open_delimiter(), "<|tool_call|>");
        assert_eq!(g.close_delimiter(), "<|/tool_call|>");
    }

    #[test]
    fn grammar_delimiters_qwen() {
        let g = ToolCallGrammar::Qwen;
        assert_eq!(g.open_delimiter(), "<tool_call>");
        assert_eq!(g.close_delimiter(), "</tool_call>");
    }

    #[test]
    fn grammar_delimiters_mistral() {
        let g = ToolCallGrammar::Mistral;
        assert_eq!(g.open_delimiter(), "[TOOL_CALLS][");
        assert_eq!(g.close_delimiter(), "]");
    }

    #[test]
    fn grammar_delimiters_custom() {
        let g = ToolCallGrammar::Custom {
            open: "START".to_string(),
            close: "END".to_string(),
        };
        assert_eq!(g.open_delimiter(), "START");
        assert_eq!(g.close_delimiter(), "END");
    }

    // ── D: Reset test ─────────────────────────────────────────────────────────

    /// After reset(), the detector treats new input as if it had never seen
    /// the previous stream (it can detect a new call from scratch).
    #[test]
    fn reset_clears_state() {
        let mut det = ToolCallDetector::new(ToolCallGrammar::Llama3);

        // Start a call but don't finish it.
        assert!(det.feed("<|tool_call|>{\"name\":\"half").is_none());
        assert_eq!(det.state, ToolDetectionState::Capturing);

        det.reset();
        assert_eq!(det.state, ToolDetectionState::Idle);
        assert!(det.buffer.is_empty());

        // A fresh call after reset should still work.
        let call = det.feed(r#"<|tool_call|>{"name":"fresh","args":{}}<|/tool_call|>"#);
        assert!(call.is_some(), "should detect call after reset");
    }

    // ── E: ToolResult injection string ────────────────────────────────────────

    #[test]
    fn tool_result_ok_injection_string() {
        let s = ToolResult::Ok(Value::String("42°C".to_string())).as_injection_string();
        assert_eq!(s, "<tool_result>\"42°C\"</tool_result>");
    }

    #[test]
    fn tool_result_err_injection_string() {
        let s = ToolResult::Err("not found".to_string()).as_injection_string();
        assert_eq!(s, r#"<tool_result>{"error":"not found"}</tool_result>"#);
    }

    // ── F: OpenAI-compat "arguments" field ────────────────────────────────────

    /// parse_tool_call_json must accept "arguments" as an alias for "args".
    #[test]
    fn tool_call_arguments_alias() {
        let call = detect_once(
            ToolCallGrammar::Llama3,
            r#"<|tool_call|>{"name":"fn","arguments":{"k":"v"}}<|/tool_call|>"#,
        )
        .expect("arguments alias should be accepted");
        assert_eq!(call.args["k"], "v");
    }

    // ── G: NoOpDispatcher ─────────────────────────────────────────────────────

    #[test]
    fn no_op_dispatcher_returns_ok_null() {
        let d = no_op_dispatcher();
        let result = d.invoke("anything", &Value::Null);
        assert!(matches!(result, ToolResult::Ok(Value::Null)));
    }
}