Skip to main content

oxillama_runtime/
tool_dispatch.rs

1//! Tool-invocation runtime callbacks.
2//!
3//! Provides infrastructure for detecting, parsing, and dispatching tool calls
4//! produced by language models during generation.
5//!
6//! ## Overview
7//!
8//! Different model families emit tool calls using different delimiter syntax:
9//! - LLaMA 3: `<|tool_call|>{ ... }<|/tool_call|>`
10//! - Qwen: `<tool_call>{ ... }</tool_call>`
11//! - Mistral: `[TOOL_CALLS][ ... ]`
12//! - Custom: user-supplied open/close delimiters
13//!
//! The [`ToolCallDetector`] accumulates streamed token text, detects the
//! open/close delimiters, validates the JSON payload between them, and returns
//! a [`ToolCall`] from [`ToolCallDetector::feed`] whenever a complete tool call
//! is parsed.
//!
//! Tool results, rendered with [`ToolResult::as_injection_string`], can be
//! queued for injection back into the generation stream via the engine's
//! injection queue mechanism.
20
21use serde_json::Value;
22use std::sync::Arc;
23
24// ─── Core trait ───────────────────────────────────────────────────────────────
25
26/// Dispatches tool calls to registered handler implementations.
27///
28/// Implement this trait to handle tool invocations produced by the model
29/// during generation.  The implementation must be `Send + Sync` so it can
30/// be shared across threads and stored in `Arc`.
31///
32/// # Example
33///
34/// ```
35/// use oxillama_runtime::tool_dispatch::{ToolDispatcher, ToolResult};
36/// use serde_json::Value;
37///
38/// struct WeatherTool;
39///
40/// impl ToolDispatcher for WeatherTool {
41///     fn invoke(&self, name: &str, args: &Value) -> ToolResult {
42///         if name == "get_weather" {
43///             ToolResult::Ok(Value::String("sunny, 22°C".to_string()))
44///         } else {
45///             ToolResult::Err(format!("unknown tool: {name}"))
46///         }
47///     }
48/// }
49/// ```
50pub trait ToolDispatcher: Send + Sync {
51    /// Invoke the named tool with the given JSON arguments.
52    ///
53    /// Returns [`ToolResult::Ok`] with the tool's output value on success, or
54    /// [`ToolResult::Err`] with an error message on failure.
55    fn invoke(&self, name: &str, args: &Value) -> ToolResult;
56}
57
58/// Result of a tool invocation.
59#[derive(Debug, Clone)]
60pub enum ToolResult {
61    /// Successful invocation with a JSON result value.
62    Ok(Value),
63    /// Failed invocation with a human-readable error message.
64    Err(String),
65}
66
67impl ToolResult {
68    /// Format the result as a string suitable for injection into the token stream.
69    pub fn as_injection_string(&self) -> String {
70        match self {
71            ToolResult::Ok(v) => {
72                format!("<tool_result>{}</tool_result>", v)
73            }
74            ToolResult::Err(e) => {
75                format!(
76                    "<tool_result>{{\"error\":{}}}</tool_result>",
77                    serde_json::json!(e)
78                )
79            }
80        }
81    }
82}
83
84// ─── Grammar ─────────────────────────────────────────────────────────────────
85
/// Specifies the delimiter syntax a model family uses to wrap tool calls.
///
/// Choose the variant that matches your deployed model:
/// - [`Llama3`](ToolCallGrammar::Llama3) for LLaMA 3 models.
/// - [`Qwen`](ToolCallGrammar::Qwen) for Qwen / Qwen2 models.
/// - [`Mistral`](ToolCallGrammar::Mistral) for Mistral / Mixtral function-calling models.
/// - [`Custom`](ToolCallGrammar::Custom) for any other format.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ToolCallGrammar {
    /// LLaMA 3 tool-call format: `<|tool_call|>...<|/tool_call|>`.
    Llama3,
    /// Qwen / Qwen2 format: `<tool_call>...</tool_call>`.
    Qwen,
    /// Mistral function-calling format: `[TOOL_CALLS][...]`.
    ///
    /// The opening delimiter includes the `[` that starts the call list, and
    /// the closing delimiter is a bare `]`. A `]` can therefore also appear
    /// inside the JSON payload itself, e.g. in an array argument.
    Mistral,
    /// User-supplied open/close delimiter pair.
    Custom {
        /// Opening delimiter (e.g. `"<tool_call>"`).
        open: String,
        /// Closing delimiter (e.g. `"</tool_call>"`).
        close: String,
    },
}

impl ToolCallGrammar {
    /// Return the opening delimiter for this grammar variant.
    pub fn open_delimiter(&self) -> &str {
        match self {
            ToolCallGrammar::Llama3 => "<|tool_call|>",
            ToolCallGrammar::Qwen => "<tool_call>",
            ToolCallGrammar::Mistral => "[TOOL_CALLS][",
            ToolCallGrammar::Custom { open, .. } => open.as_str(),
        }
    }

    /// Return the closing delimiter for this grammar variant.
    pub fn close_delimiter(&self) -> &str {
        match self {
            ToolCallGrammar::Llama3 => "<|/tool_call|>",
            ToolCallGrammar::Qwen => "</tool_call>",
            ToolCallGrammar::Mistral => "]",
            ToolCallGrammar::Custom { close, .. } => close.as_str(),
        }
    }
}
131
132// ─── Parsed tool call ─────────────────────────────────────────────────────────
133
134/// A fully-parsed tool call extracted from the model's output stream.
135#[derive(Debug, Clone)]
136pub struct ToolCall {
137    /// Name of the tool to invoke.
138    pub name: String,
139    /// Arguments to pass to the tool as a JSON value.
140    pub args: Value,
141}
142
143// ─── Detection state machine ─────────────────────────────────────────────────
144
/// Phase of the tool-call detector's state machine.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ToolDetectionState {
    /// Waiting for the open delimiter to show up in the text.
    Idle,
    /// Open delimiter consumed; collecting payload text until the close delimiter.
    Capturing,
}
153
154/// Incremental tool-call detector.
155///
156/// Feed token text via [`feed`](ToolCallDetector::feed) as tokens are generated.
157/// When a complete tool call (open delimiter + valid JSON + close delimiter)
158/// is recognised, `feed` returns `Some(ToolCall)`.
159///
160/// The detector maintains a rolling scan buffer to handle delimiters that
161/// span multiple tokens.
162///
163/// # Example
164///
165/// ```
166/// use oxillama_runtime::tool_dispatch::{ToolCallDetector, ToolCallGrammar};
167///
168/// let mut detector = ToolCallDetector::new(ToolCallGrammar::Llama3);
169/// let call = detector.feed("<|tool_call|>{\"name\":\"ping\",\"args\":{}}<|/tool_call|>");
170/// assert!(call.is_some());
171/// ```
172pub struct ToolCallDetector {
173    grammar: ToolCallGrammar,
174    state: ToolDetectionState,
175    /// Text accumulated since the start of the current token or since the last
176    /// full-delimiter match candidate.
177    buffer: String,
178}
179
180impl ToolCallDetector {
181    /// Construct a new detector for the given grammar.
182    pub fn new(grammar: ToolCallGrammar) -> Self {
183        Self {
184            grammar,
185            state: ToolDetectionState::Idle,
186            buffer: String::new(),
187        }
188    }
189
190    /// Feed one token's decoded text into the detector.
191    ///
192    /// Returns `Some(ToolCall)` when a complete, valid tool call is detected.
193    /// Returns `None` while the tool call is still accumulating or if the text
194    /// is not a tool call.
195    ///
196    /// After returning `Some`, the detector automatically resets to `Idle`.
197    /// This means it can detect multiple sequential tool calls: feed text
198    /// from the second call and it will be detected in a subsequent call.
199    pub fn feed(&mut self, token_text: &str) -> Option<ToolCall> {
200        self.buffer.push_str(token_text);
201        self.try_parse()
202    }
203
204    /// Reset the detector to the idle state, discarding any buffered content.
205    pub fn reset(&mut self) {
206        self.state = ToolDetectionState::Idle;
207        self.buffer.clear();
208    }
209
210    // ─── Internal parsing ─────────────────────────────────────────────────────
211
212    /// Try to parse a complete tool call from the current buffer.
213    ///
214    /// This is the core state machine: it searches for open and close
215    /// delimiters in the buffer and attempts JSON parsing on the content
216    /// between them.
217    ///
218    /// Multiple calls are detected by scanning for repeated open/close pairs.
219    fn try_parse(&mut self) -> Option<ToolCall> {
220        let open = self.grammar.open_delimiter().to_string();
221        let close = self.grammar.close_delimiter().to_string();
222
223        loop {
224            match self.state {
225                ToolDetectionState::Idle => {
226                    // Look for the opening delimiter.
227                    if let Some(start) = self.buffer.find(open.as_str()) {
228                        // Discard everything before the open delimiter.
229                        let after_open = start + open.len();
230                        self.buffer = self.buffer[after_open..].to_string();
231                        self.state = ToolDetectionState::Capturing;
232                        // Fall through and look for the close delimiter.
233                    } else {
234                        // No open delimiter yet; keep only the trailing portion
235                        // that could be a partial delimiter prefix.
236                        self.trim_idle_buffer(&open);
237                        return None;
238                    }
239                }
240
241                ToolDetectionState::Capturing => {
242                    if let Some(end) = self.buffer.find(close.as_str()) {
243                        // Extract the JSON payload.
244                        let payload = self.buffer[..end].trim().to_string();
245                        // Consume past the close delimiter.
246                        let after_close = end + close.len();
247                        let remainder = self.buffer[after_close..].to_string();
248                        self.buffer = remainder;
249                        self.state = ToolDetectionState::Idle;
250
251                        // Parse and validate the JSON.
252                        if let Some(call) = parse_tool_call_json(&payload) {
253                            return Some(call);
254                        }
255                        // Bad JSON — continue scanning the remainder.
256                        // (fall back to Idle and try again)
257                    } else {
258                        // Close delimiter not yet seen; keep capturing.
259                        return None;
260                    }
261                }
262            }
263        }
264    }
265
266    /// Trim the idle buffer to at most `max_suffix` chars that could be a
267    /// prefix of the open delimiter.  Prevents unbounded growth of the buffer
268    /// when no tool call is ever emitted.
269    fn trim_idle_buffer(&mut self, open: &str) {
270        let max_keep = open.len().saturating_sub(1);
271        if self.buffer.len() > max_keep {
272            let trim_to = self.buffer.len() - max_keep;
273            self.buffer = self.buffer[trim_to..].to_string();
274        }
275    }
276}
277
278// ─── JSON parsing ─────────────────────────────────────────────────────────────
279
280/// Parse a JSON string as a tool call object with `name` and `args` fields.
281///
282/// Accepts two JSON shapes:
283/// 1. `{"name": "...", "args": { ... }}` — preferred
284/// 2. `{"name": "...", "arguments": { ... }}` — OpenAI-compat alias
285///
286/// Returns `None` if the string is not valid JSON or the expected fields
287/// are missing.
288fn parse_tool_call_json(payload: &str) -> Option<ToolCall> {
289    let v: Value = serde_json::from_str(payload).ok()?;
290    let obj = v.as_object()?;
291
292    let name = obj.get("name")?.as_str()?.to_string();
293
294    // Accept either "args" or "arguments" for OpenAI compatibility.
295    let args = obj
296        .get("args")
297        .or_else(|| obj.get("arguments"))
298        .cloned()
299        .unwrap_or(Value::Object(serde_json::Map::new()));
300
301    Some(ToolCall { name, args })
302}
303
304// ─── Tool-dispatcher no-op helper ────────────────────────────────────────────
305
306/// A no-op dispatcher that returns a stub `Ok(null)` for every tool call.
307///
308/// Useful for testing or when you want to detect tool calls but not execute
309/// them yet.
310pub struct NoOpDispatcher;
311
312impl ToolDispatcher for NoOpDispatcher {
313    fn invoke(&self, _name: &str, _args: &Value) -> ToolResult {
314        ToolResult::Ok(Value::Null)
315    }
316}
317
318/// Create a no-op dispatcher wrapped in an `Arc`.
319pub fn no_op_dispatcher() -> Arc<dyn ToolDispatcher> {
320    Arc::new(NoOpDispatcher)
321}
322
323// ─── Tests ────────────────────────────────────────────────────────────────────
324
#[cfg(test)]
mod tests {
    use super::*;

    // ── A: Basic detection ────────────────────────────────────────────────────

    /// (a) Complete LLaMA-3 tool call in a single feed() call.
    #[test]
    fn tool_call_detection_llama3() {
        let mut det = ToolCallDetector::new(ToolCallGrammar::Llama3);
        let result = det
            .feed(r#"<|tool_call|>{"name":"get_weather","args":{"city":"Tokyo"}}<|/tool_call|>"#);
        assert!(result.is_some(), "must detect a complete Llama3 tool call");
        let call = result.expect("detection should succeed");
        assert_eq!(call.name, "get_weather");
        assert_eq!(call.args["city"], Value::String("Tokyo".to_string()));
    }

    /// (b) Tool call where the open delimiter, JSON body, and close delimiter
    ///     arrive in separate feed() calls (simulates streaming tokenizer output).
    #[test]
    fn tool_call_streamed_across_chunks() {
        let mut det = ToolCallDetector::new(ToolCallGrammar::Llama3);

        // Chunk 1: opening delimiter
        let r1 = det.feed("<|tool_call|>");
        assert!(r1.is_none(), "open delimiter alone must not fire");

        // Chunk 2: JSON body (no close yet)
        let r2 = det.feed(r#"{"name":"add","args":{"a":1,"b":2}}"#);
        assert!(r2.is_none(), "body without close must not fire");

        // Chunk 3: closing delimiter
        let r3 = det.feed("<|/tool_call|>");
        assert!(r3.is_some(), "close delimiter should complete the detection");
        let call = r3.expect("detection should succeed");
        assert_eq!(call.name, "add");
        assert_eq!(call.args["a"], 1);
        assert_eq!(call.args["b"], 2);
    }

    /// (c) Malformed / unclosed JSON must not produce a ToolCall.
    #[test]
    fn malformed_json_does_not_return_call() {
        let mut det = ToolCallDetector::new(ToolCallGrammar::Llama3);

        // Send an open delimiter followed by unclosed JSON.
        let r1 = det.feed("<|tool_call|>{\"name\":\"broken\"");
        assert!(r1.is_none(), "partial JSON must not fire");

        // Never close — detector should stay in Capturing without panic.
        for _ in 0..5 {
            let r = det.feed("more garbage");
            assert!(r.is_none(), "unfinished tool call must not fire");
        }
    }

    /// (d) Two complete tool calls back-to-back in the same buffer.
    ///     The detector must fire twice (once per call).
    #[test]
    fn multiple_calls_sequentially() {
        let mut det = ToolCallDetector::new(ToolCallGrammar::Qwen);

        let r1 = det.feed(
            r#"<tool_call>{"name":"tool1","args":{"x":1}}</tool_call><tool_call>{"name":"tool2","args":{"y":2}}</tool_call>"#,
        );
        assert!(r1.is_some(), "first call must be detected");
        let c1 = r1.expect("first call");
        assert_eq!(c1.name, "tool1");

        // Second call should be detected on an empty feed (it's still in buffer).
        let r2 = det.feed("");
        assert!(r2.is_some(), "second call must be detected from remainder");
        let c2 = r2.expect("second call");
        assert_eq!(c2.name, "tool2");
    }

    /// (e) Both delimiters split mid-way across chunks, with prose before the call.
    #[test]
    fn delimiters_split_across_chunks() {
        let mut det = ToolCallDetector::new(ToolCallGrammar::Llama3);

        let r1 = det.feed("Sure, let me check. <|tool_");
        assert!(r1.is_none(), "partial open delimiter must not fire");

        let r2 = det.feed(r#"call|>{"name":"ping","args":{}}<|/tool_"#);
        assert!(r2.is_none(), "partial close delimiter must not fire");

        let call = det
            .feed("call|>")
            .expect("call completed by the final chunk");
        assert_eq!(call.name, "ping");
    }

    /// (f) Plain text with no tool call keeps the idle buffer bounded.
    #[test]
    fn idle_buffer_stays_bounded() {
        let mut det = ToolCallDetector::new(ToolCallGrammar::Llama3);
        for _ in 0..100 {
            assert!(det.feed("plain text without any tool call ").is_none());
        }
        assert_eq!(det.state, ToolDetectionState::Idle);
        assert!(
            det.buffer.len() < "<|tool_call|>".len(),
            "idle buffer must keep at most a partial delimiter"
        );
    }

    /// (g) A malformed payload is discarded and a following valid call is
    ///     still detected.
    #[test]
    fn malformed_payload_is_skipped() {
        let mut det = ToolCallDetector::new(ToolCallGrammar::Qwen);
        let r = det.feed(
            r#"<tool_call>not json</tool_call><tool_call>{"name":"ok","args":{}}</tool_call>"#,
        );
        let call = r.expect("valid call after a malformed one must be detected");
        assert_eq!(call.name, "ok");
    }

    // ── B: Grammar variant tests ──────────────────────────────────────────────

    /// Qwen format detection.
    #[test]
    fn tool_call_detection_qwen() {
        let mut det = ToolCallDetector::new(ToolCallGrammar::Qwen);
        let result = det.feed(r#"<tool_call>{"name":"calc","args":{"expr":"1+1"}}</tool_call>"#);
        assert!(result.is_some());
        let call = result.expect("qwen call");
        assert_eq!(call.name, "calc");
    }

    /// Mistral format detection.
    #[test]
    fn tool_call_detection_mistral() {
        let mut det = ToolCallDetector::new(ToolCallGrammar::Mistral);
        let result = det.feed(r#"[TOOL_CALLS][{"name":"search","args":{"q":"rust"}}]"#);
        assert!(result.is_some());
        let call = result.expect("mistral call");
        assert_eq!(call.name, "search");
        assert_eq!(call.args["q"], "rust");
    }

    /// Custom grammar detection.
    #[test]
    fn tool_call_detection_custom() {
        let mut det = ToolCallDetector::new(ToolCallGrammar::Custom {
            open: "<<TOOL>>".to_string(),
            close: "<</TOOL>>".to_string(),
        });
        let result = det.feed(r#"<<TOOL>>{"name":"echo","args":{"msg":"hi"}}<</TOOL>>"#);
        assert!(result.is_some());
        let call = result.expect("custom call");
        assert_eq!(call.name, "echo");
    }

    // ── C: Grammar delimiter accessors ───────────────────────────────────────

    #[test]
    fn grammar_delimiters_llama3() {
        let g = ToolCallGrammar::Llama3;
        assert_eq!(g.open_delimiter(), "<|tool_call|>");
        assert_eq!(g.close_delimiter(), "<|/tool_call|>");
    }

    #[test]
    fn grammar_delimiters_qwen() {
        let g = ToolCallGrammar::Qwen;
        assert_eq!(g.open_delimiter(), "<tool_call>");
        assert_eq!(g.close_delimiter(), "</tool_call>");
    }

    #[test]
    fn grammar_delimiters_mistral() {
        let g = ToolCallGrammar::Mistral;
        assert_eq!(g.open_delimiter(), "[TOOL_CALLS][");
        assert_eq!(g.close_delimiter(), "]");
    }

    #[test]
    fn grammar_delimiters_custom() {
        let g = ToolCallGrammar::Custom {
            open: "START".to_string(),
            close: "END".to_string(),
        };
        assert_eq!(g.open_delimiter(), "START");
        assert_eq!(g.close_delimiter(), "END");
    }

    // ── D: Reset test ─────────────────────────────────────────────────────────

    /// After reset(), the detector treats new input as if it had never
    /// seen the previous stream (can detect a new call from scratch).
    #[test]
    fn reset_clears_state() {
        let mut det = ToolCallDetector::new(ToolCallGrammar::Llama3);

        // Start a call but don't finish it.
        det.feed("<|tool_call|>{\"name\":\"half");
        assert_eq!(det.state, ToolDetectionState::Capturing);

        // Reset.
        det.reset();
        assert_eq!(det.state, ToolDetectionState::Idle);
        assert!(det.buffer.is_empty());

        // A fresh call after reset should still work.
        let r = det.feed(r#"<|tool_call|>{"name":"fresh","args":{}}<|/tool_call|>"#);
        assert!(r.is_some(), "should detect call after reset");
    }

    // ── E: ToolResult injection string ────────────────────────────────────────

    #[test]
    fn tool_result_ok_injection_string() {
        let result = ToolResult::Ok(Value::String("42°C".to_string()));
        assert_eq!(
            result.as_injection_string(),
            r#"<tool_result>"42°C"</tool_result>"#
        );
    }

    #[test]
    fn tool_result_err_injection_string() {
        let result = ToolResult::Err("not found".to_string());
        assert_eq!(
            result.as_injection_string(),
            r#"<tool_result>{"error":"not found"}</tool_result>"#
        );

        // The error message is JSON-escaped.
        let quoted = ToolResult::Err("bad \"x\"".to_string());
        assert_eq!(
            quoted.as_injection_string(),
            r#"<tool_result>{"error":"bad \"x\""}</tool_result>"#
        );
    }

    // ── F: Argument field handling ────────────────────────────────────────────

    /// parse_tool_call_json must accept "arguments" as alias for "args".
    #[test]
    fn tool_call_arguments_alias() {
        let mut det = ToolCallDetector::new(ToolCallGrammar::Llama3);
        let r = det.feed(r#"<|tool_call|>{"name":"fn","arguments":{"k":"v"}}<|/tool_call|>"#);
        assert!(r.is_some(), "arguments alias should be accepted");
        let call = r.expect("call with arguments");
        assert_eq!(call.args["k"], "v");
    }

    /// A call without any arguments field gets an empty JSON object.
    #[test]
    fn missing_args_default_to_empty_object() {
        let mut det = ToolCallDetector::new(ToolCallGrammar::Qwen);
        let call = det
            .feed(r#"<tool_call>{"name":"ping"}</tool_call>"#)
            .expect("call without args");
        assert_eq!(call.name, "ping");
        assert_eq!(call.args, serde_json::json!({}));
    }

    // ── G: NoOpDispatcher ─────────────────────────────────────────────────────

    #[test]
    fn no_op_dispatcher_returns_ok_null() {
        let d = no_op_dispatcher();
        let result = d.invoke("anything", &Value::Null);
        assert!(matches!(result, ToolResult::Ok(Value::Null)));
    }
}