use serde_json::Value;
use std::sync::Arc;
/// Executes tools that the model asks for by name.
///
/// The `Sync + Send` bounds allow one dispatcher to be shared across threads.
pub trait ToolDispatcher: Sync + Send {
    /// Runs the tool called `name` with the JSON arguments `args` and reports the outcome.
    fn invoke(&self, name: &str, args: &Value) -> ToolResult;
}
#[derive(Debug, Clone)]
pub enum ToolResult {
Ok(Value),
Err(String),
}
impl ToolResult {
pub fn as_injection_string(&self) -> String {
match self {
ToolResult::Ok(v) => {
format!("<tool_result>{}</tool_result>", v)
}
ToolResult::Err(e) => {
format!(
"<tool_result>{{\"error\":{}}}</tool_result>",
serde_json::json!(e)
)
}
}
}
}
/// Delimiter convention a model uses to frame tool calls in its output.
#[derive(Debug, Clone)]
pub enum ToolCallGrammar {
    /// Framed by `<|tool_call|>` … `<|/tool_call|>`.
    Llama3,
    /// Framed by `<tool_call>` … `</tool_call>`.
    Qwen,
    /// Framed by `[TOOL_CALLS][` … `]`.
    Mistral,
    /// Framed by caller-supplied delimiters.
    Custom {
        /// Text that opens a tool call.
        open: String,
        /// Text that closes a tool call.
        close: String,
    },
}

impl ToolCallGrammar {
    /// Returns this grammar's `(open, close)` delimiter pair.
    fn delimiters(&self) -> (&str, &str) {
        match self {
            Self::Llama3 => ("<|tool_call|>", "<|/tool_call|>"),
            Self::Qwen => ("<tool_call>", "</tool_call>"),
            Self::Mistral => ("[TOOL_CALLS][", "]"),
            Self::Custom { open, close } => (open.as_str(), close.as_str()),
        }
    }

    /// Text that marks the start of a tool call.
    pub fn open_delimiter(&self) -> &str {
        self.delimiters().0
    }

    /// Text that marks the end of a tool call.
    pub fn close_delimiter(&self) -> &str {
        self.delimiters().1
    }
}
/// A tool invocation extracted from model output.
#[derive(Clone, Debug)]
pub struct ToolCall {
    /// Name of the tool to invoke.
    pub name: String,
    /// JSON arguments for the tool.
    pub args: Value,
}
/// Phase of the detector's scan over buffered model output.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ToolDetectionState {
    /// Looking for an open delimiter.
    Idle,
    /// Open delimiter consumed; collecting payload text until a close delimiter.
    Capturing,
}
/// Incremental detector that extracts tool calls from streamed model output.
pub struct ToolCallDetector {
    /// Delimiter convention that frames tool calls.
    grammar: ToolCallGrammar,
    /// Whether an open delimiter has been consumed and a payload is being collected.
    state: ToolDetectionState,
    /// Unconsumed text. While idle, this is a short tail that may hold a partially
    /// received open delimiter; while capturing, it is the text after the open delimiter.
    buffer: String,
}
impl ToolCallDetector {
pub fn new(grammar: ToolCallGrammar) -> Self {
Self {
grammar,
state: ToolDetectionState::Idle,
buffer: String::new(),
}
}
pub fn feed(&mut self, token_text: &str) -> Option<ToolCall> {
self.buffer.push_str(token_text);
self.try_parse()
}
pub fn reset(&mut self) {
self.state = ToolDetectionState::Idle;
self.buffer.clear();
}
fn try_parse(&mut self) -> Option<ToolCall> {
let open = self.grammar.open_delimiter().to_string();
let close = self.grammar.close_delimiter().to_string();
loop {
match self.state {
ToolDetectionState::Idle => {
if let Some(start) = self.buffer.find(open.as_str()) {
let after_open = start + open.len();
self.buffer = self.buffer[after_open..].to_string();
self.state = ToolDetectionState::Capturing;
} else {
self.trim_idle_buffer(&open);
return None;
}
}
ToolDetectionState::Capturing => {
if let Some(end) = self.buffer.find(close.as_str()) {
let payload = self.buffer[..end].trim().to_string();
let after_close = end + close.len();
let remainder = self.buffer[after_close..].to_string();
self.buffer = remainder;
self.state = ToolDetectionState::Idle;
if let Some(call) = parse_tool_call_json(&payload) {
return Some(call);
}
} else {
return None;
}
}
}
}
}
fn trim_idle_buffer(&mut self, open: &str) {
let max_keep = open.len().saturating_sub(1);
if self.buffer.len() > max_keep {
let trim_to = self.buffer.len() - max_keep;
self.buffer = self.buffer[trim_to..].to_string();
}
}
}
/// Parses a framed payload of the form `{"name": "...", <args key>: ...}` into a [`ToolCall`].
///
/// Returns `None` if the payload is not valid JSON, is not an object, or lacks a
/// string `name`.
///
/// The arguments are taken from the first present key among `args`, `arguments` and
/// `parameters`; `parameters` is the key used by Llama 3.1's JSON tool-call format.
/// When none of these keys is present, the arguments default to an empty JSON object.
fn parse_tool_call_json(payload: &str) -> Option<ToolCall> {
    let mut value: Value = serde_json::from_str(payload).ok()?;
    let obj = value.as_object_mut()?;
    let name = obj.get("name")?.as_str()?.to_owned();
    // Move the arguments out of the owned parse result instead of deep-cloning them.
    let args = ["args", "arguments", "parameters"]
        .iter()
        .find_map(|key| obj.remove(*key))
        .unwrap_or_else(|| Value::Object(serde_json::Map::new()));
    Some(ToolCall { name, args })
}
/// Dispatcher that performs no work and reports success with a JSON `null`.
pub struct NoOpDispatcher;

impl ToolDispatcher for NoOpDispatcher {
    /// Ignores the tool name and arguments; always yields `ToolResult::Ok(Value::Null)`.
    fn invoke(&self, _name: &str, _args: &Value) -> ToolResult {
        let nothing = Value::Null;
        ToolResult::Ok(nothing)
    }
}

/// Returns a [`NoOpDispatcher`] behind a shared, type-erased handle.
pub fn no_op_dispatcher() -> Arc<dyn ToolDispatcher> {
    Arc::new(NoOpDispatcher) as Arc<dyn ToolDispatcher>
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a fresh detector for `grammar`.
    fn detector(grammar: ToolCallGrammar) -> ToolCallDetector {
        ToolCallDetector::new(grammar)
    }

    /// Checks both delimiters of `grammar` against the expected text.
    fn assert_delimiters(grammar: &ToolCallGrammar, open: &str, close: &str) {
        assert_eq!(grammar.open_delimiter(), open);
        assert_eq!(grammar.close_delimiter(), close);
    }

    #[test]
    fn tool_call_detection_llama3() {
        let mut det = detector(ToolCallGrammar::Llama3);
        let input = r#"<|tool_call|>{"name":"get_weather","args":{"city":"Tokyo"}}<|/tool_call|>"#;
        let detected = det.feed(input);
        assert!(detected.is_some(), "must detect a complete Llama3 tool call");
        let call = detected.expect("detection should succeed");
        assert_eq!(call.name, "get_weather");
        assert_eq!(call.args["city"], Value::String(String::from("Tokyo")));
    }

    #[test]
    fn tool_call_streamed_across_chunks() {
        let mut det = detector(ToolCallGrammar::Llama3);
        // Chunks that, on their own, must not complete a call.
        let incomplete = [
            ("<|tool_call|>", "open delimiter alone must not fire"),
            (
                r#"{"name":"add","args":{"a":1,"b":2}}"#,
                "body without close must not fire",
            ),
        ];
        for &(chunk, why) in incomplete.iter() {
            assert!(det.feed(chunk).is_none(), "{}", why);
        }
        let completed = det.feed("<|/tool_call|>");
        assert!(
            completed.is_some(),
            "close delimiter should complete the detection"
        );
        let call = completed.expect("detection should succeed");
        assert_eq!(call.name, "add");
        assert_eq!(call.args["a"], 1);
        assert_eq!(call.args["b"], 2);
    }

    #[test]
    fn malformed_json_does_not_return_call() {
        let mut det = detector(ToolCallGrammar::Llama3);
        let first = det.feed(r#"<|tool_call|>{"name":"broken""#);
        assert!(first.is_none(), "partial JSON must not fire");
        for chunk in ["more garbage"; 5].iter() {
            let later = det.feed(chunk);
            assert!(later.is_none(), "unfinished tool call must not fire");
        }
    }

    #[test]
    fn multiple_calls_sequentially() {
        let mut det = detector(ToolCallGrammar::Qwen);
        let both_calls = concat!(
            r#"<tool_call>{"name":"tool1","args":{"x":1}}</tool_call>"#,
            r#"<tool_call>{"name":"tool2","args":{"y":2}}</tool_call>"#,
        );
        let first = det.feed(both_calls);
        assert!(first.is_some(), "first call must be detected");
        assert_eq!(first.expect("first call").name, "tool1");
        // An empty feed must surface the call still sitting in the buffer.
        let second = det.feed("");
        assert!(second.is_some(), "second call must be detected from remainder");
        assert_eq!(second.expect("second call").name, "tool2");
    }

    #[test]
    fn tool_call_detection_qwen() {
        let mut det = detector(ToolCallGrammar::Qwen);
        let detected =
            det.feed(r#"<tool_call>{"name":"calc","args":{"expr":"1+1"}}</tool_call>"#);
        assert!(detected.is_some());
        assert_eq!(detected.expect("qwen call").name, "calc");
    }

    #[test]
    fn tool_call_detection_mistral() {
        let mut det = detector(ToolCallGrammar::Mistral);
        let detected = det.feed(r#"[TOOL_CALLS][{"name":"search","args":{"q":"rust"}}]"#);
        assert!(detected.is_some());
        let call = detected.expect("mistral call");
        assert_eq!(call.name, "search");
        assert_eq!(call.args["q"], "rust");
    }

    #[test]
    fn tool_call_detection_custom() {
        let grammar = ToolCallGrammar::Custom {
            open: String::from("<<TOOL>>"),
            close: String::from("<</TOOL>>"),
        };
        let mut det = detector(grammar);
        let detected = det.feed(r#"<<TOOL>>{"name":"echo","args":{"msg":"hi"}}<</TOOL>>"#);
        assert!(detected.is_some());
        assert_eq!(detected.expect("custom call").name, "echo");
    }

    #[test]
    fn grammar_delimiters_llama3() {
        assert_delimiters(&ToolCallGrammar::Llama3, "<|tool_call|>", "<|/tool_call|>");
    }

    #[test]
    fn grammar_delimiters_qwen() {
        assert_delimiters(&ToolCallGrammar::Qwen, "<tool_call>", "</tool_call>");
    }

    #[test]
    fn grammar_delimiters_mistral() {
        assert_delimiters(&ToolCallGrammar::Mistral, "[TOOL_CALLS][", "]");
    }

    #[test]
    fn grammar_delimiters_custom() {
        let grammar = ToolCallGrammar::Custom {
            open: String::from("START"),
            close: String::from("END"),
        };
        assert_delimiters(&grammar, "START", "END");
    }

    #[test]
    fn reset_clears_state() {
        let mut det = detector(ToolCallGrammar::Llama3);
        det.feed(r#"<|tool_call|>{"name":"half"#);
        assert_eq!(det.state, ToolDetectionState::Capturing);
        det.reset();
        assert_eq!(det.state, ToolDetectionState::Idle);
        assert!(det.buffer.is_empty());
        let fresh = det.feed(r#"<|tool_call|>{"name":"fresh","args":{}}<|/tool_call|>"#);
        assert!(fresh.is_some(), "should detect call after reset");
    }

    #[test]
    fn tool_result_ok_injection_string() {
        let rendered = ToolResult::Ok(Value::String(String::from("42°C"))).as_injection_string();
        assert!(rendered.contains("<tool_result>"), "must contain opening tag");
        assert!(rendered.contains("</tool_result>"), "must contain closing tag");
        assert!(rendered.contains("42°C"), "must contain result value");
    }

    #[test]
    fn tool_result_err_injection_string() {
        let rendered = ToolResult::Err(String::from("not found")).as_injection_string();
        assert!(rendered.contains("<tool_result>"), "must contain opening tag");
        assert!(rendered.contains("error"), "must contain error key");
    }

    #[test]
    fn tool_call_arguments_alias() {
        let mut det = detector(ToolCallGrammar::Llama3);
        let detected =
            det.feed(r#"<|tool_call|>{"name":"fn","arguments":{"k":"v"}}<|/tool_call|>"#);
        assert!(detected.is_some(), "arguments alias should be accepted");
        let call = detected.expect("call with arguments");
        assert_eq!(call.args["k"], "v");
    }

    #[test]
    fn no_op_dispatcher_returns_ok_null() {
        let dispatcher = no_op_dispatcher();
        let outcome = dispatcher.invoke("anything", &Value::Null);
        assert!(matches!(outcome, ToolResult::Ok(Value::Null)));
    }
}