use serde_json::{json, Value};
use super::message::ChatMessage;
/// A tool (function) that can be advertised to the model.
///
/// [`ToolDefinition::to_openai_function`] wraps it in the
/// `{"type": "function", "function": {…}}` envelope.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct ToolDefinition {
    /// Name the model uses to invoke the tool.
    pub name: String,
    /// Human-readable description of the tool.
    pub description: String,
    /// JSON-Schema-style object describing the tool's arguments.
    pub parameters: Value,
}

impl ToolDefinition {
    /// Creates a tool whose argument schema is an empty object
    /// (`{"type": "object", "properties": {}, "required": []}`).
    #[must_use]
    pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
        let no_arguments = json!({
            "type": "object",
            "properties": {},
            "required": []
        });
        Self {
            name: name.into(),
            description: description.into(),
            parameters: no_arguments,
        }
    }

    /// Replaces the argument schema, builder style.
    #[must_use]
    pub fn with_parameters(self, params: Value) -> Self {
        Self {
            parameters: params,
            ..self
        }
    }

    /// Renders the tool as `{"type": "function", "function": {name, description, parameters}}`.
    #[must_use]
    pub fn to_openai_function(&self) -> Value {
        json!({
            "type": "function",
            "function": self.to_json(),
        })
    }

    /// Renders the bare `{name, description, parameters}` object.
    pub(crate) fn to_json(&self) -> Value {
        json!({
            "name": self.name,
            "description": self.description,
            "parameters": self.parameters,
        })
    }
}
/// A tool invocation extracted from model output.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct ToolCall {
    /// Call identifier (the parsers in this module mint `call_N` ids).
    pub id: String,
    /// Name of the tool being invoked.
    pub name: String,
    /// Arguments as parsed JSON.
    pub arguments: Value,
}

impl ToolCall {
    /// Builds a call from its id, tool name and JSON arguments.
    #[must_use]
    pub fn new(id: impl Into<String>, name: impl Into<String>, arguments: Value) -> Self {
        let id = id.into();
        let name = name.into();
        Self {
            id,
            name,
            arguments,
        }
    }

    /// Renders the call as a `{id, name, arguments}` JSON object.
    pub(crate) fn to_json(&self) -> Value {
        let Self {
            id,
            name,
            arguments,
        } = self;
        json!({
            "id": id,
            "name": name,
            "arguments": arguments,
        })
    }
}
/// Reasons a tool-call payload could not be turned into a [`ToolCall`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ToolParseError {
    /// The payload was not valid JSON; carries a description of the problem.
    InvalidJson(String),
    /// The payload had no `name` field.
    MissingName,
    /// The payload had no `arguments` field.
    MissingArguments,
}

impl std::fmt::Display for ToolParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The two "missing field" variants share one message template.
        let field = match self {
            Self::InvalidJson(detail) => return write!(f, "invalid JSON: {detail}"),
            Self::MissingName => "name",
            Self::MissingArguments => "arguments",
        };
        write!(f, "missing `{field}` field")
    }
}

impl std::error::Error for ToolParseError {}
/// Wire format a model uses to emit tool calls.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum ToolFormat {
    /// `<tool_call>{…}</tool_call>` (Qwen, Hermes, generic ChatML templates).
    #[default]
    ChatMl,
    /// `[TOOL_CALLS]` followed by a JSON array of calls (Mistral / Mixtral).
    Mistral,
    /// `<|python_tag|>` followed by a JSON object (Llama 3, FireFunction).
    Llama3,
    /// Bare top-level JSON objects anywhere in the text.
    Plain,
    /// JSON payloads terminated by `<|call|>`.
    Functionary,
}

impl ToolFormat {
    /// Infers the tool-call format from a chat-template or model name.
    ///
    /// Matching is a case-insensitive substring test. Rules are tried from the
    /// most specific family to the least. `functionary` comes first because
    /// Functionary models are built on Llama 3, and their template names often
    /// also contain `llama-3`. Names matching no rule fall back to
    /// [`ToolFormat::Plain`].
    #[must_use]
    pub fn from_chat_format(name: &str) -> Self {
        const RULES: &[(&[&str], ToolFormat)] = &[
            (&["functionary"], ToolFormat::Functionary),
            (&["qwen", "hermes", "chatml"], ToolFormat::ChatMl),
            (&["mistral", "mixtral"], ToolFormat::Mistral),
            (&["llama-3", "llama3", "firefunction"], ToolFormat::Llama3),
        ];
        let lowered = name.to_ascii_lowercase();
        RULES
            .iter()
            .find(|(needles, _)| needles.iter().any(|needle| lowered.contains(*needle)))
            .map_or(Self::Plain, |&(_, format)| format)
    }
}
/// Incremental, format-aware extractor of tool calls from streamed model output.
///
/// Feed chunks with [`ToolParser::feed`]; every call whose closing delimiter
/// arrives is returned immediately. Call [`ToolParser::finish`] at end of
/// stream to salvage a trailing `{…}` payload and reset the parser.
///
/// Per-format framing:
/// - `ChatMl`: `<tool_call>{…}</tool_call>`, one JSON object per tag.
/// - `Mistral`: `[TOOL_CALLS]` followed by a JSON array of call objects.
/// - `Llama3`: `<|python_tag|>` followed by one JSON object.
/// - `Plain`: any bare top-level JSON object in the text.
/// - `Functionary`: a JSON object terminated by `<|call|>`, optionally
///   introduced by `<|start|>function<|message|>`.
///
/// Bracket counting ignores delimiters inside JSON string literals, so an
/// argument value such as `"fn f() {"` neither ends nor prolongs a call.
/// Some payloads are dropped without producing an error:
/// - payloads that are not valid JSON;
/// - payloads without a string `name`;
/// - payloads without an `arguments` (or `parameters`) member.
#[derive(Debug)]
pub struct ToolParser {
    format: ToolFormat,
    /// Pending text. Outside a call it is a tag-scanning window; inside a call
    /// it is the payload received so far.
    buffer: String,
    /// Nesting depth of the format's payload delimiter, outside string literals.
    brace_depth: i32,
    /// Whether the parser is inside a tool-call payload.
    in_call: bool,
    /// Whether the payload scanner is inside a JSON string literal.
    in_string: bool,
    /// Whether the previous string character was an unconsumed backslash.
    escaped: bool,
    /// Counter used to mint `call_N` ids (the first id is `call_1`).
    next_id: u32,
}

impl ToolParser {
    /// Opening tag of a ChatML/Hermes tool call.
    const CHATML_OPEN: &'static str = "<tool_call>";
    /// Closing tag of a ChatML/Hermes tool call.
    const CHATML_CLOSE: &'static str = "</tool_call>";
    /// Marker that introduces Mistral's JSON array of calls.
    const MISTRAL_OPEN: &'static str = "[TOOL_CALLS]";
    /// When Mistral's tag window exceeds this many bytes...
    const MISTRAL_WINDOW_MAX: usize = 64;
    /// ...it is trimmed back to (at most) this many trailing bytes.
    const MISTRAL_WINDOW_KEEP: usize = 32;
    /// Tag that introduces a Llama 3 tool call.
    const LLAMA3_OPEN: &'static str = "<|python_tag|>";
    /// Optional header preceding a Functionary payload.
    const FUNCTIONARY_HEADER: &'static str = "<|start|>function<|message|>";
    /// Terminator of a Functionary payload.
    const FUNCTIONARY_CALL_END: &'static str = "<|call|>";
    /// Above this many bytes, non-payload Functionary text is trimmed.
    const FUNCTIONARY_MAX_PROSE: usize = 1024;

    /// Creates a parser for `format`.
    #[must_use]
    pub fn new(format: ToolFormat) -> Self {
        Self {
            format,
            buffer: String::new(),
            brace_depth: 0,
            in_call: false,
            in_string: false,
            escaped: false,
            next_id: 0,
        }
    }

    /// Creates a parser whose format is inferred from a chat-template or model name.
    #[must_use]
    pub fn for_chat_format(name: &str) -> Self {
        Self::new(ToolFormat::from_chat_format(name))
    }

    /// Feeds a chunk of model output and returns every call it completed, in order.
    ///
    /// Tags and JSON may be split across chunks at any character.
    pub fn feed(&mut self, chunk: &str) -> Vec<Result<ToolCall, ToolParseError>> {
        let mut out = Vec::new();
        for c in chunk.chars() {
            self.feed_char(c, &mut out);
        }
        out
    }

    /// Whether the parser is inside a tool-call payload.
    ///
    /// Only the tag-delimited formats (`ChatMl`, `Mistral`, `Llama3`) ever report `true`.
    #[must_use]
    pub fn in_call(&self) -> bool {
        self.in_call
    }

    /// The payload received so far for the open call, or `None` outside a call.
    #[must_use]
    pub fn current_partial(&self) -> Option<&str> {
        if self.in_call {
            Some(self.buffer.as_str())
        } else {
            None
        }
    }

    /// Flushes end-of-stream state.
    ///
    /// If the pending buffer, once trimmed, is a complete-looking `{…}` payload,
    /// it is parsed as a call. This salvages calls whose bracket counting never
    /// balanced. The parser is then reset to the out-of-call state so it can be
    /// reused. The id counter keeps running.
    pub fn finish(&mut self) -> Vec<Result<ToolCall, ToolParseError>> {
        let mut out = Vec::new();
        let buf = std::mem::take(&mut self.buffer);
        let trimmed = buf.trim();
        if trimmed.starts_with('{') && trimmed.ends_with('}') {
            self.push_json_call(trimmed, &mut out);
        }
        self.leave_call();
        out
    }

    /// Dispatches one character to the active format's state machine.
    fn feed_char(&mut self, c: char, out: &mut Vec<Result<ToolCall, ToolParseError>>) {
        match self.format {
            ToolFormat::ChatMl => self.feed_chatml(c, out),
            ToolFormat::Mistral => self.feed_mistral(c, out),
            ToolFormat::Llama3 => self.feed_llama3(c, out),
            ToolFormat::Plain => self.feed_plain(c, out),
            ToolFormat::Functionary => self.feed_functionary(c, out),
        }
    }

    /// `<tool_call>{…}</tool_call>`: one JSON object per tag pair.
    fn feed_chatml(&mut self, c: char, out: &mut Vec<Result<ToolCall, ToolParseError>>) {
        if !self.in_call {
            if self.scan_for_tag(c, Self::CHATML_OPEN) {
                self.enter_call();
            }
            return;
        }
        self.buffer.push(c);
        if self.track_nesting(c, '{', '}') {
            let raw = std::mem::take(&mut self.buffer);
            self.leave_call();
            self.push_json_call(&raw, out);
        } else if !self.in_string && self.buffer.ends_with(Self::CHATML_CLOSE) {
            // The closing tag arrived before the braces balanced, so the payload is
            // malformed or empty. Try it anyway, then leave the call so the rest of
            // the stream (including later calls) is not swallowed.
            let mut raw = std::mem::take(&mut self.buffer);
            raw.truncate(raw.len() - Self::CHATML_CLOSE.len());
            self.leave_call();
            self.push_json_call(&raw, out);
        }
    }

    /// `[TOOL_CALLS][{…}, …]`: a JSON array whose items are call objects.
    fn feed_mistral(&mut self, c: char, out: &mut Vec<Result<ToolCall, ToolParseError>>) {
        if !self.in_call {
            self.buffer.push(c);
            if self.buffer.ends_with(Self::MISTRAL_OPEN) {
                self.enter_call();
            } else if self.buffer.len() > Self::MISTRAL_WINDOW_MAX {
                self.trim_window(Self::MISTRAL_WINDOW_KEEP);
            }
            return;
        }
        self.buffer.push(c);
        if self.track_nesting(c, '[', ']') {
            let raw = std::mem::take(&mut self.buffer);
            self.leave_call();
            if let Ok(Value::Array(items)) = serde_json::from_str::<Value>(&raw) {
                for item in &items {
                    if let Some(call) = self.parse_call_obj(item) {
                        out.push(Ok(call));
                    }
                }
            }
        }
    }

    /// `<|python_tag|>{…}`: a single JSON object after the tag.
    fn feed_llama3(&mut self, c: char, out: &mut Vec<Result<ToolCall, ToolParseError>>) {
        if !self.in_call {
            // The text before the tag is kept whole rather than windowed. As a
            // result, `finish` parses it when the entire output is one bare JSON
            // object with no tag.
            self.buffer.push(c);
            if self.buffer.ends_with(Self::LLAMA3_OPEN) {
                self.enter_call();
            }
            return;
        }
        self.buffer.push(c);
        if self.track_nesting(c, '{', '}') {
            let raw = std::mem::take(&mut self.buffer);
            self.leave_call();
            self.push_json_call(&raw, out);
        }
    }

    /// Bare JSON objects anywhere in the text.
    fn feed_plain(&mut self, c: char, out: &mut Vec<Result<ToolCall, ToolParseError>>) {
        // Between objects only `{` can start a candidate. Skipping prose keeps the
        // buffer anchored at the object being scanned, so text before it is harmless.
        if self.brace_depth == 0 && c != '{' {
            return;
        }
        self.buffer.push(c);
        if self.track_nesting(c, '{', '}') {
            let raw = std::mem::take(&mut self.buffer);
            self.reset_nesting();
            self.push_json_call(&raw, out);
        }
    }

    /// `[<|start|>function<|message|>]{…}<|call|>`.
    fn feed_functionary(&mut self, c: char, out: &mut Vec<Result<ToolCall, ToolParseError>>) {
        self.buffer.push(c);
        if self.buffer.ends_with(Self::FUNCTIONARY_HEADER) {
            // Whatever preceded the header is prose, not part of the payload.
            self.buffer.clear();
        } else if self.buffer.ends_with(Self::FUNCTIONARY_CALL_END) {
            let mut raw = std::mem::take(&mut self.buffer);
            raw.truncate(raw.len() - Self::FUNCTIONARY_CALL_END.len());
            self.push_json_call(raw.trim(), out);
        } else if self.buffer.len() > Self::FUNCTIONARY_MAX_PROSE
            && !self.buffer.trim_start().starts_with('{')
        {
            // Bound memory on long prose, but keep enough of the tail to still
            // recognise a header or terminator split across chunks. A payload
            // (which starts with `{`) is never cut.
            self.trim_window(Self::FUNCTIONARY_HEADER.len());
        }
    }

    /// Appends `c` to the tag window and reports whether the window now ends with `tag`.
    ///
    /// The window slides (keeping the last `tag.len()` bytes), so a tag split
    /// across chunks or window boundaries is never missed.
    fn scan_for_tag(&mut self, c: char, tag: &str) -> bool {
        self.buffer.push(c);
        if self.buffer.ends_with(tag) {
            return true;
        }
        self.trim_window(tag.len());
        false
    }

    /// Drops leading bytes so at most `keep` remain, never splitting a UTF-8 character.
    fn trim_window(&mut self, keep: usize) {
        if self.buffer.len() <= keep {
            return;
        }
        let mut cut = self.buffer.len() - keep;
        while !self.buffer.is_char_boundary(cut) {
            cut += 1;
        }
        self.buffer.drain(..cut);
    }

    /// Updates string/escape/nesting state for one payload character.
    ///
    /// Returns `true` when `c` closes the outermost `open`…`close` pair. Quotes
    /// only matter inside a value, and delimiters inside string literals are
    /// ignored. A `close` seen before any `open` is ignored rather than driving
    /// the depth negative.
    fn track_nesting(&mut self, c: char, open: char, close: char) -> bool {
        if self.in_string {
            if self.escaped {
                self.escaped = false;
            } else if c == '\\' {
                self.escaped = true;
            } else if c == '"' {
                self.in_string = false;
            }
            return false;
        }
        if c == '"' && self.brace_depth > 0 {
            self.in_string = true;
            false
        } else if c == open {
            self.brace_depth += 1;
            false
        } else if c == close && self.brace_depth > 0 {
            self.brace_depth -= 1;
            self.brace_depth == 0
        } else {
            false
        }
    }

    /// Clears the scan window and starts reading a call payload.
    fn enter_call(&mut self) {
        self.buffer.clear();
        self.in_call = true;
        self.reset_nesting();
    }

    /// Returns to the out-of-call state (the caller has already taken the buffer).
    fn leave_call(&mut self) {
        self.in_call = false;
        self.reset_nesting();
    }

    /// Resets delimiter/string tracking for the next payload.
    fn reset_nesting(&mut self) {
        self.brace_depth = 0;
        self.in_string = false;
        self.escaped = false;
    }

    /// Parses `raw` as a call object and appends it to `out` if it is one.
    fn push_json_call(&mut self, raw: &str, out: &mut Vec<Result<ToolCall, ToolParseError>>) {
        if let Some(call) = self.parse_json_call(raw) {
            out.push(Ok(call));
        }
    }

    /// Parses `raw` as JSON and converts it with [`Self::parse_call_obj`].
    fn parse_json_call(&mut self, raw: &str) -> Option<ToolCall> {
        let v: Value = serde_json::from_str(raw).ok()?;
        self.parse_call_obj(&v)
    }

    /// Converts a `{"name": <string>, "arguments": <any>}` object into a call
    /// with a fresh `call_N` id.
    ///
    /// `parameters` is accepted when `arguments` is absent, because some
    /// templates (e.g. Llama 3.1's JSON tool calling) use that key.
    fn parse_call_obj(&mut self, v: &Value) -> Option<ToolCall> {
        let name = v.get("name")?.as_str()?.to_string();
        let arguments = v
            .get("arguments")
            .or_else(|| v.get("parameters"))?
            .clone();
        self.next_id += 1;
        let id = format!("call_{}", self.next_id);
        Some(ToolCall::new(id, name, arguments))
    }
}
/// Parses every tool call in a complete `text` using `format`.
///
/// This is a one-shot convenience. It runs a fresh [`ToolParser`] over the
/// whole text and then flushes it with [`ToolParser::finish`]. Without that
/// flush, a trailing complete `{…}` payload would be silently dropped.
pub fn extract_tool_calls(format: ToolFormat, text: &str) -> Vec<Result<ToolCall, ToolParseError>> {
    let mut parser = ToolParser::new(format);
    let mut calls = parser.feed(text);
    calls.extend(parser.finish());
    calls
}
/// One incremental update for a streamed tool call, as emitted by
/// `ToolCallStream::feed` / `ToolCallStream::finish`.
///
/// Typically only the field(s) relevant to the update are `Some`; every update
/// for the same call carries the same `index`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolCallDelta {
/// Position of the call within the stream (0-based, assigned in order).
pub index: usize,
/// Call id, sent once when the call opens.
pub id: Option<String>,
/// Tool name, sent once its JSON string has been fully received.
pub name: Option<String>,
/// Newly received fragment of the raw `arguments` JSON text.
pub arguments: Option<String>,
/// The fully parsed call, sent when the call closes.
pub completed: Option<ToolCall>,
}
#[derive(Debug)]
pub struct ToolCallStream {
parser: ToolParser,
next_index: usize,
active: Option<ActiveStreamCall>,
just_started: bool,
}
#[derive(Debug)]
struct ActiveStreamCall {
index: usize,
id: String,
name: Option<String>,
name_emitted: bool,
args_emitted: String,
args_value_start: Option<usize>,
}
impl ToolCallStream {
#[must_use]
pub fn new(format: ToolFormat) -> Self {
Self {
parser: ToolParser::new(format),
next_index: 0,
active: None,
just_started: false,
}
}
#[must_use]
pub fn for_chat_format(name: &str) -> Self {
Self::new(ToolFormat::from_chat_format(name))
}
#[must_use]
pub fn completed_count(&self) -> usize {
self.next_index
.saturating_sub(usize::from(self.active.is_some()))
}
#[must_use]
pub fn in_call(&self) -> bool {
self.parser.in_call()
}
pub fn feed(&mut self, chunk: &str) -> Vec<ToolCallDelta> {
let was_in_call = self.parser.in_call();
let completed = self.parser.feed(chunk);
let now_in_call = self.parser.in_call();
let mut out = Vec::new();
for call in completed {
let call = match call {
Ok(c) => c,
Err(err) => {
let index = if let Some(a) = self.active.take() {
a.index
} else {
let idx = self.next_index;
self.next_index += 1;
idx
};
out.push(ToolCallDelta {
index,
id: None,
name: None,
arguments: None,
completed: Some(ToolCall::new(
format!("call_err_{index}"),
String::new(),
serde_json::Value::String(err.to_string()),
)),
});
continue;
}
};
let index = if let Some(a) = self.active.take() {
a.index
} else {
let idx = self.next_index;
self.next_index += 1;
idx
};
out.push(ToolCallDelta {
index,
id: None,
name: None,
arguments: None,
completed: Some(call),
});
}
if !was_in_call && now_in_call {
self.just_started = true;
}
if self.just_started {
self.just_started = false;
let index = self.next_index;
self.next_index += 1;
let id = format!("call_{index}");
self.active = Some(ActiveStreamCall {
index,
id: id.clone(),
name: None,
name_emitted: false,
args_emitted: String::new(),
args_value_start: None,
});
out.push(ToolCallDelta {
index,
id: Some(id),
name: None,
arguments: None,
completed: None,
});
}
if let Some(active) = self.active.as_mut() {
if let Some(partial) = self.parser.current_partial() {
if !active.name_emitted {
if let Some(name) = extract_top_level_name(partial) {
active.name = Some(name.clone());
active.name_emitted = true;
out.push(ToolCallDelta {
index: active.index,
id: None,
name: Some(name),
arguments: None,
completed: None,
});
}
}
if active.args_value_start.is_none() {
if let Some(start) = extract_top_level_value_start(partial, "arguments") {
active.args_value_start = Some(start);
}
}
if let Some(start) = active.args_value_start {
let value_end = value_end_offset(partial, start);
if value_end > active.args_emitted.len() {
let diff =
partial[start + active.args_emitted.len()..value_end].to_string();
active.args_emitted.push_str(&diff);
out.push(ToolCallDelta {
index: active.index,
id: None,
name: None,
arguments: Some(diff),
completed: None,
});
}
}
}
}
out
}
pub fn finish(&mut self) -> Vec<ToolCallDelta> {
let mut out = Vec::new();
let final_completed = self.parser.finish();
for call in final_completed {
let call = match call {
Ok(c) => c,
Err(_) => continue,
};
let index = if let Some(a) = self.active.take() {
a.index
} else {
let idx = self.next_index;
self.next_index += 1;
idx
};
out.push(ToolCallDelta {
index,
id: None,
name: None,
arguments: None,
completed: Some(call),
});
}
out
}
}
/// Returns the value of the top-level `"name"` key of a possibly incomplete
/// JSON object, once that string has been fully received.
///
/// The input may be a prefix of a streamed payload, so the function scans
/// rather than parses. It walks the top-level keys in order and skips the
/// values of other keys, including nested objects/arrays and strings that
/// contain brackets.
///
/// Returns `None` when:
/// - `partial` does not start (after whitespace) with `{`;
/// - `"name"` has not arrived yet;
/// - the name value is not a closed string;
/// - the object is malformed before the name.
///
/// JSON escapes are decoded, including `\uXXXX` surrogate pairs, and
/// non-ASCII text is preserved.
fn extract_top_level_name(partial: &str) -> Option<String> {
    /// Skips JSON insignificant whitespace starting at `i`.
    fn skip_ws(b: &[u8], mut i: usize) -> usize {
        while i < b.len() && matches!(b[i], b' ' | b'\t' | b'\n' | b'\r') {
            i += 1;
        }
        i
    }

    /// With `i` just past an opening quote, returns the index of the closing quote.
    fn string_end(b: &[u8], mut i: usize) -> Option<usize> {
        while i < b.len() {
            match b[i] {
                b'"' => return Some(i),
                b'\\' => i += 2,
                _ => i += 1,
            }
        }
        None
    }

    /// With `i` at the first byte of a value, returns the index just past it,
    /// or `None` while the value is still incomplete.
    fn value_end(b: &[u8], i: usize) -> Option<usize> {
        match b[i] {
            b'"' => string_end(b, i + 1).map(|end| end + 1),
            open @ (b'{' | b'[') => {
                let close = if open == b'{' { b'}' } else { b']' };
                let mut depth = 0_usize;
                let mut j = i;
                while j < b.len() {
                    match b[j] {
                        b'"' => j = string_end(b, j + 1)?,
                        c if c == open => depth += 1,
                        c if c == close => {
                            depth -= 1;
                            if depth == 0 {
                                return Some(j + 1);
                            }
                        }
                        _ => {}
                    }
                    j += 1;
                }
                None
            }
            // Numbers / literals run up to the next separator.
            _ => Some(
                b[i..]
                    .iter()
                    .position(|&c| c == b',' || c == b'}')
                    .map_or(b.len(), |p| i + p),
            ),
        }
    }

    /// Reads exactly four hex digits.
    fn hex4(chars: &mut std::str::Chars<'_>) -> Option<u32> {
        (0..4).try_fold(0_u32, |acc, _| Some(acc * 16 + chars.next()?.to_digit(16)?))
    }

    /// Decodes the `XXXX` after `\u`, joining a following `\uXXXX` low surrogate.
    fn decode_unicode_escape(chars: &mut std::str::Chars<'_>) -> Option<char> {
        let hi = hex4(chars)?;
        if !(0xD800..0xDC00).contains(&hi) {
            return char::from_u32(hi);
        }
        let mut look = chars.clone();
        if look.next() != Some('\\') || look.next() != Some('u') {
            return None;
        }
        let lo = hex4(&mut look).filter(|lo| (0xDC00..0xE000).contains(lo))?;
        *chars = look;
        char::from_u32(0x10000 + ((hi - 0xD800) << 10) + (lo - 0xDC00))
    }

    /// Decodes JSON escapes in a string body; malformed `\u` escapes become U+FFFD.
    fn unescape(raw: &str) -> String {
        let mut out = String::with_capacity(raw.len());
        let mut chars = raw.chars();
        while let Some(c) = chars.next() {
            if c != '\\' {
                out.push(c);
                continue;
            }
            match chars.next() {
                Some('n') => out.push('\n'),
                Some('t') => out.push('\t'),
                Some('r') => out.push('\r'),
                Some('b') => out.push('\u{8}'),
                Some('f') => out.push('\u{c}'),
                Some('u') => out.push(decode_unicode_escape(&mut chars).unwrap_or('\u{fffd}')),
                // `\"`, `\\`, `\/` (and any non-standard escape) stand for the char itself.
                Some(other) => out.push(other),
                None => {}
            }
        }
        out
    }

    let b = partial.as_bytes();
    let mut i = skip_ws(b, 0);
    if b.get(i) != Some(&b'{') {
        return None;
    }
    i += 1;
    loop {
        i = skip_ws(b, i);
        // Every member starts with a quoted key; `}` (or anything else) ends the search.
        if b.get(i) != Some(&b'"') {
            return None;
        }
        let key_end = string_end(b, i + 1)?;
        let key = &partial[i + 1..key_end];
        i = skip_ws(b, key_end + 1);
        if b.get(i) != Some(&b':') {
            return None;
        }
        i = skip_ws(b, i + 1);
        if i >= b.len() {
            return None;
        }
        if key == "name" {
            if b[i] != b'"' {
                return None;
            }
            let end = string_end(b, i + 1)?;
            return Some(unescape(&partial[i + 1..end]));
        }
        // Skip this member's value, any trailing whitespace, and the separating comma.
        i = skip_ws(b, value_end(b, i)?);
        if b.get(i) == Some(&b',') {
            i += 1;
        }
    }
}
/// Returns the byte offset in `partial` where the value of the top-level
/// `target_key` begins, scanning a possibly incomplete JSON object.
///
/// The offset points just past the `:` and any whitespace after it. It may
/// therefore equal `partial.len()` when the value has not started arriving.
/// Keys inside nested values are skipped.
///
/// Returns `None` when:
/// - `partial` does not start (after whitespace) with `{`;
/// - the key has not been seen yet;
/// - the object is malformed before reaching the key.
fn extract_top_level_value_start(partial: &str, target_key: &str) -> Option<usize> {
    /// Skips JSON insignificant whitespace starting at `i`.
    fn skip_ws(b: &[u8], mut i: usize) -> usize {
        while i < b.len() && matches!(b[i], b' ' | b'\t' | b'\n' | b'\r') {
            i += 1;
        }
        i
    }

    /// With `i` just past an opening quote, returns the index of the closing quote.
    fn string_end(b: &[u8], mut i: usize) -> Option<usize> {
        while i < b.len() {
            match b[i] {
                b'"' => return Some(i),
                b'\\' => i += 2,
                _ => i += 1,
            }
        }
        None
    }

    /// With `i` at the first byte of a value, returns the index just past it,
    /// or `None` while the value is still incomplete.
    fn value_end(b: &[u8], i: usize) -> Option<usize> {
        match b[i] {
            b'"' => string_end(b, i + 1).map(|end| end + 1),
            open @ (b'{' | b'[') => {
                let close = if open == b'{' { b'}' } else { b']' };
                let mut depth = 0_usize;
                let mut j = i;
                while j < b.len() {
                    match b[j] {
                        b'"' => j = string_end(b, j + 1)?,
                        c if c == open => depth += 1,
                        c if c == close => {
                            depth -= 1;
                            if depth == 0 {
                                return Some(j + 1);
                            }
                        }
                        _ => {}
                    }
                    j += 1;
                }
                None
            }
            // Numbers / literals run up to the next separator.
            _ => Some(
                b[i..]
                    .iter()
                    .position(|&c| c == b',' || c == b'}')
                    .map_or(b.len(), |p| i + p),
            ),
        }
    }

    let b = partial.as_bytes();
    let mut i = skip_ws(b, 0);
    if b.get(i) != Some(&b'{') {
        return None;
    }
    i += 1;
    loop {
        i = skip_ws(b, i);
        // Every member starts with a quoted key; `}` (or anything else) ends the search.
        if b.get(i) != Some(&b'"') {
            return None;
        }
        let key_end = string_end(b, i + 1)?;
        let key = &partial[i + 1..key_end];
        i = skip_ws(b, key_end + 1);
        if b.get(i) != Some(&b':') {
            return None;
        }
        i = skip_ws(b, i + 1);
        if key == target_key {
            return Some(i);
        }
        if i >= b.len() {
            return None;
        }
        // Skip this member's value, any trailing whitespace, and the separating comma.
        i = skip_ws(b, value_end(b, i)?);
        if b.get(i) == Some(&b',') {
            i += 1;
        }
    }
}
/// Returns the byte offset just past the JSON value that begins at `start`
/// in `partial`, or `partial.len()` if that value is not complete yet.
///
/// Matching rules by value kind:
/// - objects/arrays: only their own bracket kind is counted, and brackets
///   inside string literals are ignored;
/// - strings: they end at the first unescaped quote;
/// - any other token: it runs up to the next `,` or `}`.
///
/// If `start` is at or past the end of `partial`, `start` is returned unchanged.
fn value_end_offset(partial: &str, start: usize) -> usize {
    let bytes = partial.as_bytes();
    let len = bytes.len();
    if start >= len {
        return start;
    }

    // Index just past the closing quote of the string whose opening quote is at
    // `quote`, or `len` if the string never closes.
    let past_string = |quote: usize| -> usize {
        let mut k = quote + 1;
        while k < len {
            match bytes[k] {
                b'"' => return k + 1,
                b'\\' if k + 1 < len => k += 2,
                _ => k += 1,
            }
        }
        len
    };

    let first = bytes[start];
    match first {
        b'{' | b'[' => {
            let close = if first == b'{' { b'}' } else { b']' };
            let mut depth = 0_u32;
            let mut j = start;
            while j < len {
                let byte = bytes[j];
                if byte == b'"' {
                    j = past_string(j);
                    continue;
                }
                j += 1;
                if byte == first {
                    depth += 1;
                } else if byte == close {
                    depth -= 1;
                    if depth == 0 {
                        return j;
                    }
                }
            }
            len
        }
        b'"' => past_string(start),
        _ => bytes[start..]
            .iter()
            .position(|&b| b == b',' || b == b'}')
            .map_or(len, |offset| start + offset),
    }
}
/// Packs `calls` into one assistant [`ChatMessage`] with empty text content.
///
/// Each call is cloned and attached, in order, via `with_tool_call`.
pub fn tool_calls_to_message(calls: &[ToolCall]) -> ChatMessage {
    use super::message::Role;
    calls.iter().cloned().fold(
        ChatMessage::new(Role::Assistant, String::new()),
        |message, call| message.with_tool_call(call),
    )
}
// Unit tests: one-shot parsing for each ToolFormat, chat-format detection, and
// the id -> name -> arguments -> completed delta ordering of ToolCallStream.
#[cfg(test)]
mod tests {
use super::*;
// A full `<tool_call>…</tool_call>` block yields one call with the expected name.
#[test]
fn parse_chatml() {
let s = r#"<tool_call>{"name": "get_weather", "arguments": {"city": "Tokyo"}}</tool_call>"#;
let mut p = ToolParser::new(ToolFormat::ChatMl);
let calls: Vec<_> = p.feed(s).into_iter().filter_map(|r| r.ok()).collect();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].name, "get_weather");
}
// `[TOOL_CALLS]` followed by a one-element array yields one call.
#[test]
fn parse_mistral() {
let s = r#"[TOOL_CALLS][{"name": "x", "arguments": {}}]"#;
let mut p = ToolParser::new(ToolFormat::Mistral);
let calls: Vec<_> = p.feed(s).into_iter().filter_map(|r| r.ok()).collect();
assert_eq!(calls.len(), 1);
}
// `<|python_tag|>` followed by one JSON object yields one call.
#[test]
fn parse_llama3() {
let s = r#"<|python_tag|>{"name": "x", "arguments": {}}"#;
let mut p = ToolParser::new(ToolFormat::Llama3);
let calls: Vec<_> = p.feed(s).into_iter().filter_map(|r| r.ok()).collect();
assert_eq!(calls.len(), 1);
}
// A bare JSON object is recognised in the Plain format.
#[test]
fn parse_plain() {
let s = r#"{"name": "x", "arguments": {}}"#;
let mut p = ToolParser::new(ToolFormat::Plain);
let calls: Vec<_> = p.feed(s).into_iter().filter_map(|r| r.ok()).collect();
assert_eq!(calls.len(), 1);
}
// Representative template names map to their formats; unknown names fall back to Plain.
#[test]
fn auto_detect_format() {
assert_eq!(ToolFormat::from_chat_format("qwen"), ToolFormat::ChatMl);
assert_eq!(ToolFormat::from_chat_format("llama-3"), ToolFormat::Llama3);
assert_eq!(ToolFormat::from_chat_format("mistral"), ToolFormat::Mistral);
assert_eq!(ToolFormat::from_chat_format("plain"), ToolFormat::Plain);
}
// Streaming a ChatML call in small chunks: the first delta announces id `call_0`
// for index 0, a later delta carries the name, and the last completed delta
// holds the fully parsed call.
#[test]
fn stream_emits_id_then_name_then_arguments_for_chatml() {
let mut s = ToolCallStream::new(ToolFormat::ChatMl);
let mut all = Vec::new();
all.extend(s.feed("<tool_call>"));
all.extend(s.feed(r#"{"name":"#));
all.extend(s.feed(r#""get_weather""#));
all.extend(s.feed(r#","arguments":{"city":"Tokyo"}}"#));
all.extend(s.feed("</tool_call>"));
assert_eq!(all[0].index, 0);
assert_eq!(all[0].id.as_deref(), Some("call_0"));
assert!(all[0].name.is_none());
assert!(all[0].arguments.is_none());
let name_delta = all
.iter()
.find(|d| d.name.is_some())
.expect("expected a name delta");
assert_eq!(name_delta.index, 0);
assert_eq!(name_delta.name.as_deref(), Some("get_weather"));
let completed = all
.iter()
.rev()
.find(|d| d.completed.is_some())
.expect("expected a completed delta");
let call = completed.completed.as_ref().unwrap();
assert_eq!(call.name, "get_weather");
assert_eq!(call.arguments["city"], "Tokyo");
}
// Concatenating every `arguments` delta reproduces the raw arguments text exactly
// when the call's closing brace arrives in a later chunk.
#[test]
fn stream_emits_arguments_growth() {
let mut s = ToolCallStream::new(ToolFormat::ChatMl);
let mut all = Vec::new();
for chunk in [
"<tool_call>",
r#"{"name":"f","arguments":{"#,
r#""a":1}"#,
"}",
"</tool_call>",
] {
all.extend(s.feed(chunk));
}
let arg_diffs: Vec<String> = all.iter().filter_map(|d| d.arguments.clone()).collect();
let joined: String = arg_diffs.iter().map(String::as_str).collect();
assert_eq!(joined, r#"{"a":1}"#);
}
// A Mistral array with two items yields two completed calls, in order.
#[test]
fn stream_mistral_array_emits_two_calls() {
let mut s = ToolCallStream::new(ToolFormat::Mistral);
let mut all = Vec::new();
all.extend(s.feed("[TOOL_CALLS]"));
all.extend(s.feed(r#"[{"name":"a","arguments":{}},{"name":"b","arguments":{"x":1}}]"#));
let completed: Vec<_> = all.iter().filter_map(|d| d.completed.as_ref()).collect();
assert_eq!(completed.len(), 2);
assert_eq!(completed[0].name, "a");
assert_eq!(completed[1].name, "b");
assert_eq!(completed[1].arguments["x"], 1);
}
}