use std::sync::OnceLock;
use regex::Regex;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::api_types::{FunctionCall, Tool, ToolCall, ToolChoice};
pub use infernum_core::ModelFamily;
/// Cache for the regex that captures the payload between `<tool_call>` tags.
static TOOL_CALL_REGEX: OnceLock<Regex> = OnceLock::new();
/// Cache for the regex that matches a whole `<tool_call>…</tool_call>` span.
static TOOL_CALL_EXTRACT_REGEX: OnceLock<Regex> = OnceLock::new();

/// Returns the shared regex whose capture group 1 is the (whitespace-trimmed)
/// text found between `<tool_call>` and `</tool_call>`.
///
/// The regex is compiled on first use and reused afterwards.
fn get_tool_call_regex() -> &'static Regex {
    const PATTERN: &str = r"(?s)<tool_call>\s*(.*?)\s*</tool_call>";
    TOOL_CALL_REGEX.get_or_init(|| Regex::new(PATTERN).expect("invalid tool_call regex"))
}

/// Returns the shared regex that matches an entire `<tool_call>…</tool_call>`
/// span, tags included, so the span can be stripped from text.
///
/// The regex is compiled on first use and reused afterwards.
fn get_tool_call_extract_regex() -> &'static Regex {
    const PATTERN: &str = r"(?s)<tool_call>.*?</tool_call>";
    TOOL_CALL_EXTRACT_REGEX
        .get_or_init(|| Regex::new(PATTERN).expect("invalid tool_call_extract regex"))
}
/// A tool invocation parsed out of raw model output.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DetectedToolCall {
    /// Identifier of this call; the detectors in this file generate
    /// values of the form `call_<uuid>`.
    pub id: String,
    /// Name of the function the model asked for.
    pub name: String,
    /// Call arguments as a JSON-encoded string.
    pub arguments: String,
}

impl DetectedToolCall {
    /// Builds the API-level [`ToolCall`] for this detection.
    ///
    /// All fields are cloned and the call type is always `"function"`.
    #[must_use]
    pub fn to_tool_call(&self) -> ToolCall {
        let Self {
            id,
            name,
            arguments,
        } = self;
        let function = FunctionCall {
            name: name.clone(),
            arguments: arguments.clone(),
        };
        ToolCall {
            id: id.clone(),
            call_type: String::from("function"),
            function,
        }
    }
}
/// Result of post-processing a complete model output.
#[derive(Debug, Clone)]
pub struct ToolProcessingResult {
/// Text left after tool-call markup was stripped; `None` when nothing
/// but whitespace remains.
pub content: Option<String>,
/// Tool calls detected in the output, converted to the API type.
pub tool_calls: Vec<ToolCall>,
/// `"tool_calls"` when at least one call was detected, otherwise `"stop"`
/// (as set by the `process_model_output*` functions in this file).
pub finish_reason: String,
}
#[must_use]
pub fn format_tools_for_prompt(tools: &[Tool], model_family: ModelFamily) -> String {
if tools.is_empty() {
return String::new();
}
match model_family {
ModelFamily::Qwen | ModelFamily::Unknown => format_tools_qwen(tools),
ModelFamily::Llama => format_tools_llama(tools),
ModelFamily::Mistral => format_tools_mistral(tools),
}
}
/// Formats tool definitions in the Qwen style: one compact JSON function
/// definition per line inside `<tools></tools>`, followed by instructions
/// to answer with `<tool_call></tool_call>` blocks.
fn format_tools_qwen(tools: &[Tool]) -> String {
    const HEADER: &str = "\n\n# Tools\n\n\
        You may call one or more functions to assist with the user query.\n\n\
        You are provided with function signatures within <tools></tools> XML tags:\n\
        <tools>";
    const FOOTER: &str = "\n</tools>\n\n\
        For each function call, return a json object with function name and arguments \
        within <tool_call></tool_call> XML tags:\n\
        <tool_call>\n\
        {\"name\": <function-name>, \"arguments\": <args-json-object>}\n\
        </tool_call>";

    let mut prompt = String::from(HEADER);
    for tool in tools {
        let function = &tool.function;
        let definition = serde_json::json!({
            "type": "function",
            "function": {
                "name": function.name,
                "description": function.description,
                "parameters": function.parameters
            }
        });
        // A definition that fails to serialize contributes an empty line.
        let line = serde_json::to_string(&definition).unwrap_or_default();
        prompt.push('\n');
        prompt.push_str(&line);
    }
    prompt.push_str(FOOTER);
    prompt
}
/// Formats tool definitions in the Llama style: each function is
/// pretty-printed as JSON, followed by instructions to call a function
/// with a `<|python_tag|>`-prefixed JSON object.
fn format_tools_llama(tools: &[Tool]) -> String {
    const CLOSING_LINES: [&str; 3] = [
        "To call a function, respond with a JSON object in the following format:\n",
        "<|python_tag|>{\"name\": \"function_name\", \"arguments\": {\"arg1\": \"value1\"}}\n",
        "\nOnly call functions when necessary to answer the user's request.\n",
    ];

    let mut prompt = String::from("\n\nYou have access to the following functions:\n\n");
    for tool in tools {
        let function = &tool.function;
        let definition = serde_json::json!({
            "name": function.name,
            "description": function.description,
            "parameters": function.parameters
        });
        // A definition that fails to serialize contributes only the blank separator.
        let rendered = serde_json::to_string_pretty(&definition).unwrap_or_default();
        prompt.push_str(&rendered);
        prompt.push_str("\n\n");
    }
    for line in CLOSING_LINES {
        prompt.push_str(line);
    }
    prompt
}
/// Formats tool definitions in the Mistral style: a single JSON array of
/// function definitions between `[AVAILABLE_TOOLS]` markers, followed by
/// instructions to answer with a `[TOOL_CALLS]` array.
fn format_tools_mistral(tools: &[Tool]) -> String {
    let mut definitions: Vec<serde_json::Value> = Vec::with_capacity(tools.len());
    for tool in tools {
        let function = &tool.function;
        definitions.push(serde_json::json!({
            "type": "function",
            "function": {
                "name": function.name,
                "description": function.description,
                "parameters": function.parameters
            }
        }));
    }
    // Serialization failure yields an empty tool list body.
    let serialized = serde_json::to_string(&definitions).unwrap_or_default();

    let mut prompt = String::from("\n\n[AVAILABLE_TOOLS]\n");
    prompt.push_str(&serialized);
    prompt.push_str("\n[/AVAILABLE_TOOLS]\n\n");
    prompt.push_str("When you need to call a tool, respond with:\n");
    prompt.push_str(
        "[TOOL_CALLS] [{\"name\": \"function_name\", \"arguments\": {\"arg1\": \"value1\"}}]\n",
    );
    prompt
}
#[must_use]
pub fn detect_tool_calls(output: &str, model_family: ModelFamily) -> Vec<DetectedToolCall> {
match model_family {
ModelFamily::Qwen | ModelFamily::Unknown => detect_tool_calls_qwen(output),
ModelFamily::Llama => detect_tool_calls_llama(output),
ModelFamily::Mistral => detect_tool_calls_mistral(output),
}
}
/// Collects every `<tool_call>…</tool_call>` block in `output` whose body
/// parses as `{"name": …, "arguments": …}`.
///
/// Blocks with an unparsable body are skipped silently. Each detected call
/// gets a fresh `call_<uuid>` identifier.
fn detect_tool_calls_qwen(output: &str) -> Vec<DetectedToolCall> {
    get_tool_call_regex()
        .captures_iter(output)
        .filter_map(|caps| caps.get(1))
        .filter_map(|payload| serde_json::from_str::<ToolCallJson>(payload.as_str()).ok())
        .map(|parsed| DetectedToolCall {
            id: format!("call_{}", Uuid::new_v4().simple()),
            name: parsed.name,
            arguments: serde_json::to_string(&parsed.arguments).unwrap_or_default(),
        })
        .collect()
}
/// Wire shape of a tool call as emitted by the model:
/// `{"name": "...", "arguments": ...}`.
#[derive(Debug, Deserialize)]
struct ToolCallJson {
/// Name of the function to call.
name: String,
/// Arguments as arbitrary JSON; the detectors re-serialize this to a string.
arguments: serde_json::Value,
}
/// Collects every Llama-style tool call in `output`: a `<|python_tag|>`
/// marker followed by a JSON object of the form
/// `{"name": …, "arguments": …}`.
///
/// A marker whose following JSON object is missing or unparsable is skipped.
/// When no Llama-style call is found at all, the Qwen `<tool_call>` format
/// is tried as a fallback.
fn detect_tool_calls_llama(output: &str) -> Vec<DetectedToolCall> {
    let marker = "<|python_tag|>";
    let mut calls = Vec::new();
    let mut search_start = 0;
    while let Some(marker_pos) = output[search_start..].find(marker) {
        let json_start = search_start + marker_pos + marker.len();
        let remaining = &output[json_start..];

        // The object need not start right after the marker (whitespace or
        // other text may precede it), so remember where it begins: the scan
        // must resume after the object's real end. Resuming at
        // `json_start + json_str.len()` would land inside the JSON, possibly
        // in the middle of a multi-byte character, which makes the next
        // slice panic.
        let parsed_call = remaining.find('{').and_then(|obj_offset| {
            let json_str = extract_json_object(remaining, obj_offset)?;
            let parsed = serde_json::from_str::<ToolCallJson>(&json_str).ok()?;
            Some((obj_offset + json_str.len(), parsed))
        });

        match parsed_call {
            Some((consumed, parsed)) => {
                calls.push(DetectedToolCall {
                    id: format!("call_{}", Uuid::new_v4().simple()),
                    name: parsed.name,
                    arguments: serde_json::to_string(&parsed.arguments).unwrap_or_default(),
                });
                search_start = json_start + consumed;
            }
            // Nothing usable after this marker: continue right behind it.
            None => search_start = json_start,
        }
    }
    if calls.is_empty() {
        calls = detect_tool_calls_qwen(output);
    }
    calls
}
/// Cache for the regex that captures the contents of a `[TOOL_CALLS] [...]` array.
static MISTRAL_TOOL_CALL_REGEX: OnceLock<Regex> = OnceLock::new();

/// Returns the shared regex whose capture group 1 is the text between the
/// square brackets that follow `[TOOL_CALLS]`.
///
/// The capture stops at the first `]`, so a payload that itself contains
/// `]` is not captured in full.
fn get_mistral_tool_call_regex() -> &'static Regex {
    const PATTERN: &str = r#"\[TOOL_CALLS\]\s*\[([^\]]+)\]"#;
    MISTRAL_TOOL_CALL_REGEX
        .get_or_init(|| Regex::new(PATTERN).expect("invalid mistral tool_call regex"))
}
/// Collects every call listed in `[TOOL_CALLS] [...]` arrays in `output`.
///
/// Arrays that do not parse as a list of `{"name": …, "arguments": …}`
/// objects are skipped. When no Mistral-style call is found, the Qwen
/// `<tool_call>` format is tried as a fallback.
fn detect_tool_calls_mistral(output: &str) -> Vec<DetectedToolCall> {
    let calls: Vec<DetectedToolCall> = get_mistral_tool_call_regex()
        .captures_iter(output)
        .filter_map(|caps| caps.get(1))
        .filter_map(|inner| {
            // The regex captures only the array's interior; restore the brackets.
            let array_json = format!("[{}]", inner.as_str());
            serde_json::from_str::<Vec<ToolCallJson>>(&array_json).ok()
        })
        .flatten()
        .map(|entry| DetectedToolCall {
            id: format!("call_{}", Uuid::new_v4().simple()),
            name: entry.name,
            arguments: serde_json::to_string(&entry.arguments).unwrap_or_default(),
        })
        .collect();

    if calls.is_empty() {
        detect_tool_calls_qwen(output)
    } else {
        calls
    }
}
#[must_use]
pub fn extract_text_content(output: &str, model_family: ModelFamily) -> Option<String> {
match model_family {
ModelFamily::Qwen | ModelFamily::Unknown => extract_text_content_qwen(output),
ModelFamily::Llama => extract_text_content_llama(output),
ModelFamily::Mistral => extract_text_content_mistral(output),
}
}
/// Removes every `<tool_call>…</tool_call>` span from `output` and returns
/// the trimmed remainder, or `None` when nothing is left.
fn extract_text_content_qwen(output: &str) -> Option<String> {
    let without_calls = get_tool_call_extract_regex().replace_all(output, "");
    match without_calls.trim() {
        "" => None,
        text => Some(text.to_string()),
    }
}
/// Cache for the regex that matches a `<|python_tag|>` marker plus its JSON object.
static LLAMA_EXTRACT_REGEX: OnceLock<Regex> = OnceLock::new();

/// Returns the shared regex that matches `<|python_tag|>` followed by a JSON
/// object, so the whole span can be stripped from text.
///
/// The pattern allows braces nested one level deep inside the object.
fn get_llama_extract_regex() -> &'static Regex {
    const PATTERN: &str = r#"<\|python_tag\|>\s*\{(?:[^{}]|\{[^{}]*\})*\}"#;
    LLAMA_EXTRACT_REGEX.get_or_init(|| Regex::new(PATTERN).expect("invalid llama extract regex"))
}
/// Removes Llama-style (`<|python_tag|>{…}`) and then Qwen-style
/// (`<tool_call>…</tool_call>`) tool-call spans from `output`, returning
/// the trimmed remainder or `None` when nothing is left.
fn extract_text_content_llama(output: &str) -> Option<String> {
    let without_llama = get_llama_extract_regex().replace_all(output, "");
    // Also strip the Qwen format, which detection accepts as a fallback.
    let without_any = get_tool_call_extract_regex().replace_all(&without_llama, "");
    match without_any.trim() {
        "" => None,
        text => Some(text.to_string()),
    }
}
/// Cache for the regex that matches a whole `[TOOL_CALLS] [...]` span.
static MISTRAL_EXTRACT_REGEX: OnceLock<Regex> = OnceLock::new();

/// Returns the shared regex that matches `[TOOL_CALLS]` plus the bracketed
/// array after it (up to the first `]`), so the span can be stripped from text.
fn get_mistral_extract_regex() -> &'static Regex {
    const PATTERN: &str = r#"\[TOOL_CALLS\]\s*\[[^\]]+\]"#;
    MISTRAL_EXTRACT_REGEX
        .get_or_init(|| Regex::new(PATTERN).expect("invalid mistral extract regex"))
}
/// Removes Mistral-style (`[TOOL_CALLS] [...]`) and then Qwen-style
/// (`<tool_call>…</tool_call>`) tool-call spans from `output`, returning
/// the trimmed remainder or `None` when nothing is left.
fn extract_text_content_mistral(output: &str) -> Option<String> {
    let without_mistral = get_mistral_extract_regex().replace_all(output, "");
    // Also strip the Qwen format, which detection accepts as a fallback.
    let without_any = get_tool_call_extract_regex().replace_all(&without_mistral, "");
    match without_any.trim() {
        "" => None,
        text => Some(text.to_string()),
    }
}
/// Processes a complete model `output`: detects tool calls, extracts the
/// remaining text, and derives the finish reason.
///
/// `finish_reason` is `"tool_calls"` when at least one call was detected
/// and `"stop"` otherwise.
#[must_use]
pub fn process_model_output(output: &str, model_family: ModelFamily) -> ToolProcessingResult {
    let tool_calls: Vec<ToolCall> = detect_tool_calls(output, model_family)
        .iter()
        .map(|detected| detected.to_tool_call())
        .collect();
    let content = extract_text_content(output, model_family);
    let finish_reason = String::from(if tool_calls.is_empty() {
        "stop"
    } else {
        "tool_calls"
    });
    ToolProcessingResult {
        content,
        tool_calls,
        finish_reason,
    }
}
/// Returns `true` when `tools` contains a function named exactly `tool_name`.
#[must_use]
pub fn validate_tool_exists(tool_name: &str, tools: &[Tool]) -> bool {
    for tool in tools {
        if tool.function.name == tool_name {
            return true;
        }
    }
    false
}
/// Decides whether tool definitions should be offered to the model.
///
/// Tools are included unless `tool_choice` is the string `"none"`; an
/// absent choice, any other string, and a specific tool all include them.
#[must_use]
pub fn should_include_tools(tool_choice: Option<&ToolChoice>) -> bool {
    let explicitly_disabled = matches!(tool_choice, Some(ToolChoice::String(s)) if s == "none");
    !explicitly_disabled
}
/// Returns the name of the function that `tool_choice` forces, if any.
///
/// Only a specific-tool choice yields a name; string choices and an absent
/// choice yield `None`.
#[must_use]
pub fn get_forced_tool(tool_choice: Option<&ToolChoice>) -> Option<&str> {
    let Some(ToolChoice::Tool(forced)) = tool_choice else {
        return None;
    };
    Some(&forced.function.name)
}
// Start-of-tool-call markers per model family, used by
// `buffer_might_contain_tool_start`. The first entry of each list is the
// complete marker; the remaining entries are shorter prefixes of it.

/// Start markers for the Qwen `<tool_call>` format.
const QWEN_START_MARKERS: &[&str] = &["<tool_call>", "<tool_call", "<tool_", "<tool"];
/// Start markers for the Llama `<|python_tag|>` format.
const LLAMA_START_MARKERS: &[&str] = &[
"<|python_tag|>",
"<|python_tag",
"<|python_",
"<|python",
"<|",
];
/// Start markers for the Mistral `[TOOL_CALLS]` format.
const MISTRAL_START_MARKERS: &[&str] = &["[TOOL_CALLS]", "[TOOL_CALLS", "[TOOL_", "[TOOL"];
/// Returns `true` when `buffer` ends with any non-empty prefix of a start
/// marker for `model_family`, i.e. when the end of the buffer could be the
/// beginning of a tool call that is still being streamed.
///
/// `ModelFamily::Unknown` uses the Qwen markers.
#[must_use]
pub fn buffer_might_contain_tool_start(buffer: &str, model_family: ModelFamily) -> bool {
    let candidates: &[&str] = match model_family {
        ModelFamily::Qwen | ModelFamily::Unknown => QWEN_START_MARKERS,
        ModelFamily::Llama => LLAMA_START_MARKERS,
        ModelFamily::Mistral => MISTRAL_START_MARKERS,
    };
    candidates.iter().any(|candidate| {
        // Markers are ASCII, so every byte length is a valid slice boundary.
        (1..=candidate.len()).any(|len| buffer.ends_with(&candidate[..len]))
    })
}
/// Returns `true` when `buffer` can safely be emitted as plain text: it
/// neither contains the complete start marker for `model_family` nor ends
/// with a partial one.
///
/// `ModelFamily::Unknown` uses the Qwen marker.
#[must_use]
pub fn definitely_not_tool_call(buffer: &str, model_family: ModelFamily) -> bool {
    let full_marker = match model_family {
        ModelFamily::Qwen | ModelFamily::Unknown => "<tool_call>",
        ModelFamily::Llama => "<|python_tag|>",
        ModelFamily::Mistral => "[TOOL_CALLS]",
    };
    if buffer.contains(full_marker) {
        return false;
    }
    // `buffer_might_contain_tool_start` only tests suffixes that are no
    // longer than the full marker, so checking the whole buffer gives the
    // same answer as checking its last `full_marker.len()` bytes. Slicing
    // that tail by byte offset is avoided on purpose: the offset can fall
    // inside a multi-byte UTF-8 character, and such a slice panics.
    !buffer_might_contain_tool_start(buffer, model_family)
}
/// Outcome of trying to pull one complete tool call out of a streaming buffer.
#[derive(Debug, Clone)]
pub struct StreamingExtractResult {
/// Whether a complete, parsable tool call was found.
pub found: bool,
/// Text that preceded the call in the buffer; `None` when the call was at
/// the very start or when no call was found.
pub text_before: Option<String>,
/// The extracted call; `None` when no call was found.
pub call: Option<DetectedToolCall>,
/// Buffer content after the extracted call, or the whole buffer when no
/// call was found.
pub remaining: String,
}
#[must_use]
pub fn try_extract_complete_tool_call(
buffer: &str,
model_family: ModelFamily,
) -> StreamingExtractResult {
match model_family {
ModelFamily::Qwen | ModelFamily::Unknown => try_extract_qwen(buffer),
ModelFamily::Llama => try_extract_llama(buffer),
ModelFamily::Mistral => try_extract_mistral(buffer),
}
}
/// Tries to extract the first `<tool_call>…</tool_call>` block from `buffer`.
///
/// Succeeds only when both tags are present and the trimmed body parses as
/// `{"name": …, "arguments": …}`. Otherwise the result has `found: false`
/// and `remaining` holds the untouched buffer.
fn try_extract_qwen(buffer: &str) -> StreamingExtractResult {
    const OPEN_TAG: &str = "<tool_call>";
    const CLOSE_TAG: &str = "</tool_call>";

    let not_found = || StreamingExtractResult {
        found: false,
        text_before: None,
        call: None,
        remaining: buffer.to_string(),
    };

    let Some(open_at) = buffer.find(OPEN_TAG) else {
        return not_found();
    };
    let body_at = open_at + OPEN_TAG.len();
    let Some(close_offset) = buffer[body_at..].find(CLOSE_TAG) else {
        return not_found();
    };
    let close_at = body_at + close_offset;
    let Ok(parsed) = serde_json::from_str::<ToolCallJson>(buffer[body_at..close_at].trim()) else {
        return not_found();
    };

    let call = DetectedToolCall {
        id: format!("call_{}", Uuid::new_v4().simple()),
        name: parsed.name,
        arguments: serde_json::to_string(&parsed.arguments).unwrap_or_default(),
    };
    // Text in front of the opening tag is handed back separately.
    let text_before = (open_at > 0).then(|| buffer[..open_at].to_string());
    StreamingExtractResult {
        found: true,
        text_before,
        call: Some(call),
        remaining: buffer[close_at + CLOSE_TAG.len()..].to_string(),
    }
}
/// Tries to extract the first Llama-style tool call from `buffer`: a
/// `<|python_tag|>` marker followed by a complete JSON object of the form
/// `{"name": …, "arguments": …}`.
///
/// When no such call can be extracted, the Qwen `<tool_call>` format is
/// tried on the same buffer as a fallback.
fn try_extract_llama(buffer: &str) -> StreamingExtractResult {
    let marker = "<|python_tag|>";
    if let Some(start_idx) = buffer.find(marker) {
        let json_start = start_idx + marker.len();
        let json_part = &buffer[json_start..];
        // The object may be preceded by whitespace or other text after the
        // marker; its offset is needed to find where the object really ends.
        if let Some(obj_offset) = json_part.find('{') {
            if let Some(json_str) = extract_json_object(json_part, obj_offset) {
                if let Ok(parsed) = serde_json::from_str::<ToolCallJson>(&json_str) {
                    let id = format!("call_{}", Uuid::new_v4().simple());
                    let call = DetectedToolCall {
                        id,
                        name: parsed.name,
                        arguments: serde_json::to_string(&parsed.arguments).unwrap_or_default(),
                    };
                    let text_before = if start_idx > 0 {
                        Some(buffer[..start_idx].to_string())
                    } else {
                        None
                    };
                    // Cut after the object's actual end. Without
                    // `obj_offset` the tail of the JSON would leak into
                    // `remaining`, and the cut could fall inside a
                    // multi-byte character and panic.
                    let json_end = json_start + obj_offset + json_str.len();
                    let remaining = buffer[json_end..].to_string();
                    return StreamingExtractResult {
                        found: true,
                        text_before,
                        call: Some(call),
                        remaining,
                    };
                }
            }
        }
    }
    try_extract_qwen(buffer)
}
/// Tries to extract a Mistral-style tool call from `buffer`: a
/// `[TOOL_CALLS]` marker followed by a complete JSON array of
/// `{"name": …, "arguments": …}` objects.
///
/// Only the first element of the array is returned as the call; the whole
/// array is consumed from the buffer. When nothing can be extracted, the
/// Qwen `<tool_call>` format is tried on the same buffer as a fallback.
fn try_extract_mistral(buffer: &str) -> StreamingExtractResult {
    const MARKER: &str = "[TOOL_CALLS]";

    let extract = || -> Option<StreamingExtractResult> {
        let marker_at = buffer.find(MARKER)?;
        let after_marker = &buffer[marker_at + MARKER.len()..];
        let array_open = after_marker.find('[')?;
        let array_close = find_matching_bracket(after_marker, array_open)?;
        let entries: Vec<ToolCallJson> =
            serde_json::from_str(&after_marker[array_open..=array_close]).ok()?;
        let first = entries.into_iter().next()?;

        let call = DetectedToolCall {
            id: format!("call_{}", Uuid::new_v4().simple()),
            name: first.name,
            arguments: serde_json::to_string(&first.arguments).unwrap_or_default(),
        };
        Some(StreamingExtractResult {
            found: true,
            text_before: (marker_at > 0).then(|| buffer[..marker_at].to_string()),
            call: Some(call),
            remaining: after_marker[array_close + 1..].to_string(),
        })
    };

    extract().unwrap_or_else(|| try_extract_qwen(buffer))
}
/// Finds the byte index of the bracket that closes the one at `start`.
///
/// `s[start]` must be `[` or `{`; any other byte, or a `start` past the end
/// of `s`, yields `None`. Only brackets of the same kind as the opener are
/// counted, and brackets inside double-quoted strings (with backslash
/// escapes honoured) are ignored. Returns `None` when the opener is never
/// closed.
fn find_matching_bracket(s: &str, start: usize) -> Option<usize> {
    let bytes = s.as_bytes();
    let opener = *bytes.get(start)?;
    let closer = match opener {
        b'[' => b']',
        b'{' => b'}',
        _ => return None,
    };

    let mut depth: i32 = 0;
    let mut inside_string = false;
    let mut skip_next = false;
    for (offset, &byte) in bytes[start..].iter().enumerate() {
        if skip_next {
            // This byte is escaped by the preceding backslash.
            skip_next = false;
            continue;
        }
        match byte {
            b'\\' if inside_string => skip_next = true,
            b'"' => inside_string = !inside_string,
            _ if inside_string => {}
            b if b == opener => depth += 1,
            b if b == closer => {
                depth -= 1;
                if depth == 0 {
                    return Some(start + offset);
                }
            }
            _ => {}
        }
    }
    None
}
/// Returns the first balanced `{…}` object found in `s` at or after byte
/// offset `start`, as an owned string.
///
/// Yields `None` when there is no `{` at or after `start`, or when the
/// first one is never closed.
#[must_use]
pub fn extract_json_object(s: &str, start: usize) -> Option<String> {
    let open_at = s
        .as_bytes()
        .iter()
        .enumerate()
        .skip(start)
        .find(|&(_, &byte)| byte == b'{')
        .map(|(index, _)| index)?;
    let close_at = find_matching_bracket(s, open_at)?;
    Some(s[open_at..=close_at].to_string())
}
/// Options for `process_model_output_with_options`.
#[derive(Debug, Clone, Default)]
pub struct ProcessingOptions {
/// When `false` (the derived default), only the first detected tool call
/// is kept; when `true`, all detected calls are kept.
pub parallel_tool_calls: bool,
}
/// Applies the parallel-tool-calls policy to `calls`.
///
/// With `parallel` set, the calls are returned unchanged; otherwise only
/// the first call (if any) is kept.
#[must_use]
pub fn enforce_parallel_tool_calls(
    mut calls: Vec<DetectedToolCall>,
    parallel: bool,
) -> Vec<DetectedToolCall> {
    if !parallel {
        calls.truncate(1);
    }
    calls
}
/// Like `process_model_output`, but applies `options` to the detected
/// calls: unless `options.parallel_tool_calls` is set, only the first
/// detected call is kept.
///
/// `finish_reason` is `"tool_calls"` when at least one call remains and
/// `"stop"` otherwise.
#[must_use]
pub fn process_model_output_with_options(
    output: &str,
    model_family: ModelFamily,
    options: ProcessingOptions,
) -> ToolProcessingResult {
    let kept = enforce_parallel_tool_calls(
        detect_tool_calls(output, model_family),
        options.parallel_tool_calls,
    );
    let content = extract_text_content(output, model_family);
    let tool_calls: Vec<ToolCall> = kept.iter().map(|call| call.to_tool_call()).collect();
    let finish_reason = String::from(if tool_calls.is_empty() {
        "stop"
    } else {
        "tool_calls"
    });
    ToolProcessingResult {
        content,
        tool_calls,
        finish_reason,
    }
}
/// Parses `arguments` as JSON and validates the value against `schema`.
///
/// # Errors
///
/// Returns `"Invalid JSON: …"` when `arguments` is not valid JSON, or the
/// first schema violation found by the validator.
pub fn validate_tool_arguments(arguments: &str, schema: &serde_json::Value) -> Result<(), String> {
    match serde_json::from_str::<serde_json::Value>(arguments) {
        Ok(parsed) => validate_value_against_schema(&parsed, schema, ""),
        Err(e) => Err(format!("Invalid JSON: {e}")),
    }
}
/// Recursively validates `value` against a subset of JSON Schema.
///
/// Checks, in this order: `type` (when given as a string; an integer also
/// satisfies `"number"`), `enum`, and then for objects `required` and the
/// schemas under `properties`, and for arrays the `items` schema. Object
/// keys without a matching entry in `properties` are not checked.
///
/// `path` is the location of `value` for error messages; an empty path is
/// reported as `root`.
///
/// # Errors
///
/// Returns a description of the first violation encountered.
fn validate_value_against_schema(
    value: &serde_json::Value,
    schema: &serde_json::Value,
    path: &str,
) -> Result<(), String> {
    let location = if path.is_empty() { "root" } else { path };

    // 1. Type check.
    if let Some(expected_type) = schema.get("type").and_then(serde_json::Value::as_str) {
        let actual_type = match value {
            serde_json::Value::Null => "null",
            serde_json::Value::Bool(_) => "boolean",
            serde_json::Value::Number(n) if n.is_i64() || n.is_u64() => "integer",
            serde_json::Value::Number(_) => "number",
            serde_json::Value::String(_) => "string",
            serde_json::Value::Array(_) => "array",
            serde_json::Value::Object(_) => "object",
        };
        let integer_as_number = actual_type == "integer" && expected_type == "number";
        if actual_type != expected_type && !integer_as_number {
            return Err(format!(
                "Type mismatch at {location}: expected {expected_type}, got {actual_type}"
            ));
        }
    }

    // 2. Enum membership.
    if let Some(allowed) = schema.get("enum").and_then(serde_json::Value::as_array) {
        if !allowed.iter().any(|candidate| candidate == value) {
            return Err(format!(
                "Invalid enum value at {}: {:?} not in {:?}",
                location, value, allowed
            ));
        }
    }

    // 3. Objects: required fields first, then known properties.
    if let Some(fields) = value.as_object() {
        let missing = schema
            .get("required")
            .and_then(serde_json::Value::as_array)
            .into_iter()
            .flatten()
            .filter_map(serde_json::Value::as_str)
            .find(|name| !fields.contains_key(*name));
        if let Some(field_name) = missing {
            return Err(format!("Missing required field: {field_name}"));
        }

        if let Some(properties) = schema
            .get("properties")
            .and_then(serde_json::Value::as_object)
        {
            for (key, field_value) in fields {
                let Some(field_schema) = properties.get(key) else {
                    continue;
                };
                let field_path = if path.is_empty() {
                    key.clone()
                } else {
                    format!("{path}.{key}")
                };
                validate_value_against_schema(field_value, field_schema, &field_path)?;
            }
        }
    }

    // 4. Arrays: every element against `items`.
    if let (Some(elements), Some(items_schema)) = (value.as_array(), schema.get("items")) {
        for (i, element) in elements.iter().enumerate() {
            validate_value_against_schema(element, items_schema, &format!("{path}[{i}]"))?;
        }
    }

    Ok(())
}
/// Result of processing model output together with argument validation.
#[derive(Debug, Clone)]
pub struct ToolValidationResult {
/// The regular processing result (text, tool calls, finish reason).
pub result: ToolProcessingResult,
/// One `(tool name, error message)` pair per detected call whose
/// arguments failed validation.
pub validation_errors: Vec<(String, String)>,
}
/// Like `process_model_output`, but additionally validates the arguments
/// of detected calls against their tool's parameter schema.
///
/// A call is validated only when a tool with the same name exists in
/// `tools`, that tool has `strict == Some(true)`, and it declares
/// parameters. Validation failures are reported in `validation_errors`;
/// they do not remove the call from the result.
#[must_use]
pub fn process_model_output_with_validation(
    output: &str,
    model_family: ModelFamily,
    tools: &[Tool],
) -> ToolValidationResult {
    let detected = detect_tool_calls(output, model_family);
    let content = extract_text_content(output, model_family);

    let validation_errors: Vec<(String, String)> = detected
        .iter()
        .filter_map(|call| {
            let tool = tools.iter().find(|t| t.function.name == call.name)?;
            if tool.function.strict != Some(true) {
                return None;
            }
            let schema = tool.function.parameters.as_ref()?;
            validate_tool_arguments(&call.arguments, schema)
                .err()
                .map(|message| (call.name.clone(), message))
        })
        .collect();

    let tool_calls: Vec<ToolCall> = detected.iter().map(|call| call.to_tool_call()).collect();
    let finish_reason = String::from(if tool_calls.is_empty() {
        "stop"
    } else {
        "tool_calls"
    });

    ToolValidationResult {
        result: ToolProcessingResult {
            content,
            tool_calls,
            finish_reason,
        },
        validation_errors,
    }
}
/// Detected calls split by whether their tool name is known.
#[derive(Debug, Clone)]
pub struct DetectedCallsValidation {
/// Calls whose name matches one of the available tools.
pub valid_calls: Vec<DetectedToolCall>,
/// Names of calls that match no available tool.
pub unknown_tools: Vec<String>,
}
/// Splits `detected` into calls that name a tool present in `tools` and
/// the names of calls that do not. Input order is preserved in both lists.
#[must_use]
pub fn validate_detected_calls(
    detected: &[DetectedToolCall],
    tools: &[Tool],
) -> DetectedCallsValidation {
    let (known, unknown): (Vec<&DetectedToolCall>, Vec<&DetectedToolCall>) = detected
        .iter()
        .partition(|call| validate_tool_exists(&call.name, tools));

    DetectedCallsValidation {
        valid_calls: known.into_iter().cloned().collect(),
        unknown_tools: unknown.into_iter().map(|call| call.name.clone()).collect(),
    }
}
/// Event produced by `StreamingToolDetector` while consuming chunks.
#[derive(Debug, Clone, PartialEq)]
pub enum ToolDetectionEvent {
/// Plain text that is safe to forward.
Text(String),
/// A complete tool call extracted from the stream.
ToolCall(DetectedToolCall),
/// The detector is holding back input that may be the start of a tool call.
Buffering,
}
/// Incremental tool-call detector for streamed model output.
///
/// Chunks are appended to an internal buffer. Complete tool calls are
/// extracted as soon as they are available, text that cannot be part of a
/// tool call is released, and anything that may still become a tool call
/// is held back.
#[derive(Debug)]
pub struct StreamingToolDetector {
    /// Input received so far that has not been emitted yet.
    buffer: String,
    /// Model family whose tool-call markup is recognised.
    model_family: ModelFamily,
    /// Whether the buffer is currently held back as a potential tool call.
    in_potential_tool_call: bool,
}

impl StreamingToolDetector {
    /// Creates a detector with an empty buffer for `model_family`.
    #[must_use]
    pub fn new(model_family: ModelFamily) -> Self {
        Self {
            buffer: String::new(),
            model_family,
            in_potential_tool_call: false,
        }
    }

    /// Appends `chunk` to the buffer and returns the events that can be
    /// emitted now (possibly none).
    pub fn process_chunk(&mut self, chunk: &str) -> Vec<ToolDetectionEvent> {
        self.buffer.push_str(chunk);
        self.evaluate_buffer()
    }

    /// Ends the stream: whatever is still buffered is emitted as a single
    /// `Text` event and the detector is reset. Returns no events when the
    /// buffer is empty.
    pub fn finish(&mut self) -> Vec<ToolDetectionEvent> {
        if self.buffer.is_empty() {
            return vec![];
        }
        let remaining = std::mem::take(&mut self.buffer);
        self.in_potential_tool_call = false;
        vec![ToolDetectionEvent::Text(remaining)]
    }

    /// Drains the buffer as far as possible and decides what to do with
    /// the rest.
    fn evaluate_buffer(&mut self) -> Vec<ToolDetectionEvent> {
        let mut events = Vec::new();

        // Extract every complete tool call that is already in the buffer.
        loop {
            let result = try_extract_complete_tool_call(&self.buffer, self.model_family);
            if !result.found {
                break;
            }
            if let Some(text) = result.text_before {
                if !text.is_empty() {
                    events.push(ToolDetectionEvent::Text(text));
                }
            }
            if let Some(call) = result.call {
                events.push(ToolDetectionEvent::ToolCall(call));
            }
            self.buffer = result.remaining;
            self.in_potential_tool_call = false;
        }

        if buffer_might_contain_tool_start(&self.buffer, self.model_family) {
            // The buffer ends with a (partial) start marker: wait for more input.
            self.in_potential_tool_call = true;
            if events.is_empty() {
                events.push(ToolDetectionEvent::Buffering);
            }
        } else if definitely_not_tool_call(&self.buffer, self.model_family) {
            // No marker anywhere: release the buffer as plain text.
            if !self.buffer.is_empty() {
                let text = std::mem::take(&mut self.buffer);
                events.push(ToolDetectionEvent::Text(text));
            }
            self.in_potential_tool_call = false;
        } else {
            // The full start marker is present but the call is not complete
            // yet. Keep buffering regardless of how the marker arrived:
            // flushing here would emit the beginning of a tool call as text
            // whenever the marker and part of its payload came in one chunk.
            self.in_potential_tool_call = true;
        }
        events
    }

    /// Returns `true` while input is being held back as a potential tool call.
    #[must_use]
    pub fn is_buffering(&self) -> bool {
        self.in_potential_tool_call
    }

    /// Returns the input that is currently buffered and not yet emitted.
    #[must_use]
    pub fn buffer(&self) -> &str {
        &self.buffer
    }
}
/// Event sent to the client, serialized as JSON with a `type` field that
/// holds the snake_case variant name.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SseEvent {
/// A piece of plain text content.
Text {
content: String,
},
/// A complete tool call; `arguments` is carried as structured JSON.
ToolCall {
id: String,
name: String,
arguments: serde_json::Value,
},
/// An error report; `details` is omitted from the JSON when `None`.
Error {
code: String,
message: String,
#[serde(skip_serializing_if = "Option::is_none")]
details: Option<serde_json::Value>,
},
/// End of the response; `usage` is omitted from the JSON when `None`.
Done {
finish_reason: String,
#[serde(skip_serializing_if = "Option::is_none")]
usage: Option<SseUsage>,
},
}
/// Token usage attached to a `Done` event.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SseUsage {
/// Number of tokens in the prompt.
pub prompt_tokens: u32,
/// Number of tokens generated.
pub completion_tokens: u32,
/// Sum of prompt and completion tokens (as filled in by `SseEvent::done_with_usage`).
pub total_tokens: u32,
}
impl SseEvent {
    /// Creates a `Text` event.
    #[must_use]
    pub fn text(content: impl Into<String>) -> Self {
        Self::Text {
            content: content.into(),
        }
    }

    /// Creates a `ToolCall` event from a detected call.
    ///
    /// The call's argument string is parsed as JSON; when it is not valid
    /// JSON it is carried through unchanged as a JSON string value.
    #[must_use]
    pub fn tool_call(call: &DetectedToolCall) -> Self {
        let arguments = serde_json::from_str(&call.arguments)
            .unwrap_or_else(|_| serde_json::Value::String(call.arguments.clone()));
        Self::ToolCall {
            id: call.id.clone(),
            name: call.name.clone(),
            arguments,
        }
    }

    /// Creates an `Error` event without details.
    #[must_use]
    pub fn error(code: impl Into<String>, message: impl Into<String>) -> Self {
        Self::Error {
            code: code.into(),
            message: message.into(),
            details: None,
        }
    }

    /// Creates an `Error` event carrying additional structured `details`.
    #[must_use]
    pub fn error_with_details(
        code: impl Into<String>,
        message: impl Into<String>,
        details: serde_json::Value,
    ) -> Self {
        Self::Error {
            code: code.into(),
            message: message.into(),
            details: Some(details),
        }
    }

    /// Creates a `Done` event without usage information.
    #[must_use]
    pub fn done(finish_reason: impl Into<String>) -> Self {
        Self::Done {
            finish_reason: finish_reason.into(),
            usage: None,
        }
    }

    /// Creates a `Done` event with token usage.
    ///
    /// `total_tokens` is the sum of both counts, saturating at `u32::MAX`
    /// so that extreme inputs can neither panic nor wrap around.
    #[must_use]
    pub fn done_with_usage(
        finish_reason: impl Into<String>,
        prompt_tokens: u32,
        completion_tokens: u32,
    ) -> Self {
        Self::Done {
            finish_reason: finish_reason.into(),
            usage: Some(SseUsage {
                prompt_tokens,
                completion_tokens,
                total_tokens: prompt_tokens.saturating_add(completion_tokens),
            }),
        }
    }

    /// Serializes the event to a JSON string.
    ///
    /// Should serialization fail, a fixed `serialization_error` error event
    /// is returned instead.
    #[must_use]
    pub fn to_json(&self) -> String {
        serde_json::to_string(self).unwrap_or_else(|_| {
            r#"{"type":"error","code":"serialization_error","message":"Failed to serialize event"}"#.to_string()
        })
    }
}
/// Maps detector events to client events; `Buffering` has no client-side
/// counterpart and converts to `None`.
impl From<ToolDetectionEvent> for Option<SseEvent> {
    fn from(event: ToolDetectionEvent) -> Self {
        let sse_event = match event {
            ToolDetectionEvent::Buffering => return None,
            ToolDetectionEvent::Text(content) => SseEvent::text(content),
            ToolDetectionEvent::ToolCall(call) => SseEvent::tool_call(&call),
        };
        Some(sse_event)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
fn make_tool(name: &str, description: &str, params: serde_json::Value) -> Tool {
Tool {
tool_type: "function".to_string(),
function: crate::api_types::FunctionDefinition {
name: name.to_string(),
description: Some(description.to_string()),
parameters: Some(params),
strict: None,
},
}
}
#[test]
fn test_model_family_qwen() {
assert_eq!(
ModelFamily::from_model_name("Qwen/Qwen2.5-7B-Instruct"),
ModelFamily::Qwen
);
assert_eq!(
ModelFamily::from_model_name("qwen2.5-coder"),
ModelFamily::Qwen
);
}
#[test]
fn test_model_family_llama() {
assert_eq!(
ModelFamily::from_model_name("meta-llama/Llama-3.2-3B-Instruct"),
ModelFamily::Llama
);
assert_eq!(
ModelFamily::from_model_name("llama-3.1"),
ModelFamily::Llama
);
}
#[test]
fn test_model_family_mistral() {
assert_eq!(
ModelFamily::from_model_name("mistralai/Mistral-7B-Instruct"),
ModelFamily::Mistral
);
assert_eq!(
ModelFamily::from_model_name("Mixtral-8x7B"),
ModelFamily::Mistral
);
}
#[test]
fn test_model_family_unknown() {
assert_eq!(
ModelFamily::from_model_name("unknown-model"),
ModelFamily::Unknown
);
}
#[test]
fn test_format_empty_tools() {
let result = format_tools_for_prompt(&[], ModelFamily::Qwen);
assert!(result.is_empty());
}
#[test]
fn test_format_single_tool_qwen_native() {
let tool = make_tool(
"get_weather",
"Get current weather for a location",
json!({
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "City name"
}
},
"required": ["location"]
}),
);
let result = format_tools_for_prompt(&[tool], ModelFamily::Qwen);
assert!(result.contains("# Tools"), "Missing '# Tools' header");
assert!(
result.contains("You may call one or more functions to assist with the user query"),
"Missing native preamble text"
);
assert!(result.contains("<tools>"), "Missing <tools> opening tag");
assert!(result.contains("</tools>"), "Missing </tools> closing tag");
assert!(
result.contains("\"type\":\"function\""),
"Tool not serialized as JSON function definition"
);
assert!(
result.contains("\"name\":\"get_weather\""),
"Tool name not in JSON format"
);
assert!(
!result.contains("## get_weather"),
"Still using old markdown-style format"
);
assert!(
!result.contains("Parameters:\n-"),
"Still using old markdown parameter list"
);
assert!(
result.contains("<tool_call>"),
"Missing <tool_call> instruction"
);
}
#[test]
fn test_format_multiple_tools_qwen_native() {
let tools = vec![
make_tool(
"tool_a",
"First tool",
json!({"type": "object", "properties": {}}),
),
make_tool(
"tool_b",
"Second tool",
json!({"type": "object", "properties": {}}),
),
];
let result = format_tools_for_prompt(&tools, ModelFamily::Qwen);
assert!(result.contains("<tools>"));
assert!(result.contains("</tools>"));
assert!(result.contains("\"name\":\"tool_a\""));
assert!(result.contains("\"name\":\"tool_b\""));
assert!(result.contains("\"description\":\"First tool\""));
assert!(result.contains("\"description\":\"Second tool\""));
assert!(!result.contains("## tool_a"));
assert!(!result.contains("## tool_b"));
}
#[test]
fn test_detect_qwen_tool_call() {
let output = r#"I'll get the weather for you.
<tool_call>
{"name": "get_weather", "arguments": {"location": "Seattle"}}
</tool_call>"#;
let calls = detect_tool_calls(output, ModelFamily::Qwen);
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].name, "get_weather");
assert!(calls[0].arguments.contains("Seattle"));
assert!(calls[0].id.starts_with("call_"));
}
#[test]
fn test_detect_multiple_tool_calls() {
let output = r#"<tool_call>
{"name": "tool_a", "arguments": {}}
</tool_call>
Some text
<tool_call>
{"name": "tool_b", "arguments": {"x": 1}}
</tool_call>"#;
let calls = detect_tool_calls(output, ModelFamily::Qwen);
assert_eq!(calls.len(), 2);
assert_eq!(calls[0].name, "tool_a");
assert_eq!(calls[1].name, "tool_b");
}
#[test]
fn test_detect_no_tool_calls() {
let output = "Just a regular response without any tool calls.";
let calls = detect_tool_calls(output, ModelFamily::Qwen);
assert!(calls.is_empty());
}
#[test]
fn test_detect_malformed_tool_call() {
let output = r#"<tool_call>
not valid json
</tool_call>"#;
let calls = detect_tool_calls(output, ModelFamily::Qwen);
assert!(calls.is_empty()); }
#[test]
fn test_extract_text_with_tool_call() {
let output = r#"Here is some text.
<tool_call>
{"name": "test", "arguments": {}}
</tool_call>
More text after."#;
let content = extract_text_content(output, ModelFamily::Qwen);
assert!(content.is_some());
let text = content.unwrap();
assert!(text.contains("Here is some text."));
assert!(text.contains("More text after."));
assert!(!text.contains("<tool_call>"));
}
#[test]
fn test_extract_text_only_tool_call() {
let output = r#"<tool_call>
{"name": "test", "arguments": {}}
</tool_call>"#;
let content = extract_text_content(output, ModelFamily::Qwen);
assert!(content.is_none());
}
#[test]
fn test_extract_text_no_tool_call() {
let output = "Just plain text.";
let content = extract_text_content(output, ModelFamily::Qwen);
assert!(content.is_some());
assert_eq!(content.unwrap(), "Just plain text.");
}
#[test]
fn test_process_model_output_with_tool_call() {
let output = r#"<tool_call>
{"name": "get_weather", "arguments": {"location": "Seattle"}}
</tool_call>"#;
let result = process_model_output(output, ModelFamily::Qwen);
assert_eq!(result.finish_reason, "tool_calls");
assert!(result.content.is_none());
assert_eq!(result.tool_calls.len(), 1);
assert_eq!(result.tool_calls[0].function.name, "get_weather");
}
#[test]
fn test_process_model_output_no_tool_call() {
let output = "This is a regular response.";
let result = process_model_output(output, ModelFamily::Qwen);
assert_eq!(result.finish_reason, "stop");
assert!(result.content.is_some());
assert!(result.tool_calls.is_empty());
}
#[test]
fn test_process_model_output_mixed() {
let output = r#"Let me help you with that.
<tool_call>
{"name": "search", "arguments": {"query": "test"}}
</tool_call>"#;
let result = process_model_output(output, ModelFamily::Qwen);
assert_eq!(result.finish_reason, "tool_calls");
assert!(result.content.is_some());
assert!(result.content.unwrap().contains("Let me help you"));
assert_eq!(result.tool_calls.len(), 1);
}
#[test]
fn test_validate_tool_exists() {
let tools = vec![
make_tool("tool_a", "A", json!({})),
make_tool("tool_b", "B", json!({})),
];
assert!(validate_tool_exists("tool_a", &tools));
assert!(validate_tool_exists("tool_b", &tools));
assert!(!validate_tool_exists("tool_c", &tools));
}
#[test]
fn test_detected_to_tool_call() {
let detected = DetectedToolCall {
id: "call_abc123".to_string(),
name: "test_function".to_string(),
arguments: r#"{"key": "value"}"#.to_string(),
};
let tool_call = detected.to_tool_call();
assert_eq!(tool_call.id, "call_abc123");
assert_eq!(tool_call.call_type, "function");
assert_eq!(tool_call.function.name, "test_function");
assert_eq!(tool_call.function.arguments, r#"{"key": "value"}"#);
}
#[test]
fn test_should_include_tools_none_default() {
assert!(should_include_tools(None));
}
#[test]
fn test_should_include_tools_auto() {
let choice = ToolChoice::String("auto".to_string());
assert!(should_include_tools(Some(&choice)));
}
#[test]
fn test_should_include_tools_none_string() {
let choice = ToolChoice::String("none".to_string());
assert!(!should_include_tools(Some(&choice)));
}
#[test]
fn test_should_include_tools_required() {
let choice = ToolChoice::String("required".to_string());
assert!(should_include_tools(Some(&choice)));
}
#[test]
fn test_should_include_tools_specific_tool() {
use crate::api_types::{ToolChoiceFunction, ToolChoiceFunctionName};
let choice = ToolChoice::Tool(ToolChoiceFunction {
choice_type: "function".to_string(),
function: ToolChoiceFunctionName {
name: "get_weather".to_string(),
},
});
assert!(should_include_tools(Some(&choice)));
}
#[test]
fn test_get_forced_tool_none() {
assert!(get_forced_tool(None).is_none());
}
#[test]
fn test_get_forced_tool_auto() {
let choice = ToolChoice::String("auto".to_string());
assert!(get_forced_tool(Some(&choice)).is_none());
}
#[test]
fn test_get_forced_tool_specific() {
use crate::api_types::{ToolChoiceFunction, ToolChoiceFunctionName};
let choice = ToolChoice::Tool(ToolChoiceFunction {
choice_type: "function".to_string(),
function: ToolChoiceFunctionName {
name: "get_weather".to_string(),
},
});
assert_eq!(get_forced_tool(Some(&choice)), Some("get_weather"));
}
/// The Llama prompt for a single tool is expected to mention the tool name,
/// its description, and the `<|python_tag|>` marker.
#[test]
fn test_format_single_tool_llama() {
    let schema = json!({
        "type": "object",
        "properties": {
            "location": {"type": "string", "description": "City name"}
        },
        "required": ["location"]
    });
    let weather_tool = make_tool("get_weather", "Get current weather for a location", schema);

    let prompt = format_tools_for_prompt(&[weather_tool], ModelFamily::Llama);

    for needle in ["get_weather", "Get current weather", "<|python_tag|>"] {
        assert!(prompt.contains(needle));
    }
}
/// A `<|python_tag|>` JSON payload following prose is expected to be detected
/// as exactly one `get_weather` call whose arguments mention Seattle.
#[test]
fn test_detect_llama_tool_call() {
    let model_output = r#"I'll check the weather.
<|python_tag|>{"name": "get_weather", "arguments": {"location": "Seattle"}}"#;

    let detected = detect_tool_calls(model_output, ModelFamily::Llama);

    assert_eq!(detected.len(), 1);
    let call = &detected[0];
    assert_eq!(call.name, "get_weather");
    assert!(call.arguments.contains("Seattle"));
}
/// Text extraction for Llama output is expected to keep the prose on both
/// sides of the tool call while dropping the `<|python_tag|>` marker.
#[test]
fn test_extract_text_llama() {
    let output = r#"Here is some text.
<|python_tag|>{"name": "test", "arguments": {}}
More text after."#;

    // A single `expect` replaces the former `is_some()` assertion + `unwrap()`
    // pair and reports why the value was required if it is missing.
    let text = extract_text_content(output, ModelFamily::Llama)
        .expect("text surrounding the Llama tool call should be extracted");

    assert!(text.contains("Here is some text."));
    assert!(text.contains("More text after."));
    assert!(!text.contains("<|python_tag|>"));
}
/// The Mistral prompt for a single tool is expected to contain the
/// `[AVAILABLE_TOOLS]` marker, the tool name, and the `[TOOL_CALLS]` marker.
#[test]
fn test_format_single_tool_mistral() {
    let schema = json!({
        "type": "object",
        "properties": {
            "location": {"type": "string", "description": "City name"}
        },
        "required": ["location"]
    });
    let weather_tool = make_tool("get_weather", "Get current weather for a location", schema);

    let prompt = format_tools_for_prompt(&[weather_tool], ModelFamily::Mistral);

    for needle in ["[AVAILABLE_TOOLS]", "get_weather", "[TOOL_CALLS]"] {
        assert!(prompt.contains(needle));
    }
}
/// A `[TOOL_CALLS]` JSON array following prose is expected to be detected as
/// exactly one `get_weather` call whose arguments mention Seattle.
#[test]
fn test_detect_mistral_tool_call() {
    let model_output = r#"I'll check the weather.
[TOOL_CALLS] [{"name": "get_weather", "arguments": {"location": "Seattle"}}]"#;

    let detected = detect_tool_calls(model_output, ModelFamily::Mistral);

    assert_eq!(detected.len(), 1);
    let call = &detected[0];
    assert_eq!(call.name, "get_weather");
    assert!(call.arguments.contains("Seattle"));
}
/// Text extraction for Mistral output is expected to keep the prose on both
/// sides of the tool call while dropping the `[TOOL_CALLS]` marker.
#[test]
fn test_extract_text_mistral() {
    let output = r#"Here is some text.
[TOOL_CALLS] [{"name": "test", "arguments": {}}]
More text after."#;

    // A single `expect` replaces the former `is_some()` assertion + `unwrap()`
    // pair and reports why the value was required if it is missing.
    let text = extract_text_content(output, ModelFamily::Mistral)
        .expect("text surrounding the Mistral tool call should be extracted");

    assert!(text.contains("Here is some text."));
    assert!(text.contains("More text after."));
    assert!(!text.contains("[TOOL_CALLS]"));
}
/// Detection for the Llama family is expected to also pick up a call wrapped
/// in `<tool_call>` tags.
#[test]
fn test_llama_falls_back_to_qwen_format() {
    let model_output = r#"<tool_call>
{"name": "get_weather", "arguments": {"location": "Seattle"}}
</tool_call>"#;

    let detected = detect_tool_calls(model_output, ModelFamily::Llama);

    assert_eq!(detected.len(), 1);
    let call = &detected[0];
    assert_eq!(call.name, "get_weather");
}
/// Detection for the Mistral family is expected to also pick up a call
/// wrapped in `<tool_call>` tags.
#[test]
fn test_mistral_falls_back_to_qwen_format() {
    let model_output = r#"<tool_call>
{"name": "get_weather", "arguments": {"location": "Seattle"}}
</tool_call>"#;

    let detected = detect_tool_calls(model_output, ModelFamily::Mistral);

    assert_eq!(detected.len(), 1);
    let call = &detected[0];
    assert_eq!(call.name, "get_weather");
}
/// Tests for the buffer-level streaming helpers: spotting a possible tool-call
/// start, extracting a complete call from a buffer, and ruling out look-alikes.
mod phase3_streaming {
    use super::*;

    /// Partial `<tool_call` prefixes are flagged for Qwen; ordinary prose is not.
    #[test]
    fn test_buffer_detects_potential_tool_start_qwen() {
        for partial in ["<tool", "<tool_", "<tool_call"] {
            assert!(buffer_might_contain_tool_start(partial, ModelFamily::Qwen));
        }
        for prose in ["hello world", "the tool is ready"] {
            assert!(!buffer_might_contain_tool_start(prose, ModelFamily::Qwen));
        }
    }

    /// Partial `<|python_tag` prefixes are flagged for Llama; ordinary prose is not.
    #[test]
    fn test_buffer_detects_potential_tool_start_llama() {
        for partial in ["<|python", "<|python_tag"] {
            assert!(buffer_might_contain_tool_start(partial, ModelFamily::Llama));
        }
        assert!(!buffer_might_contain_tool_start("hello", ModelFamily::Llama));
    }

    /// Partial `[TOOL_CALLS` prefixes are flagged for Mistral; ordinary prose is not.
    #[test]
    fn test_buffer_detects_potential_tool_start_mistral() {
        for partial in ["[TOOL", "[TOOL_CALLS"] {
            assert!(buffer_might_contain_tool_start(partial, ModelFamily::Mistral));
        }
        assert!(!buffer_might_contain_tool_start("hello", ModelFamily::Mistral));
    }

    /// A complete Qwen call embedded in text is split into the text before it,
    /// the parsed call, and the remaining text.
    #[test]
    fn test_try_extract_complete_tool_call_qwen() {
        let buffer = r#"Some text<tool_call>
{"name": "test", "arguments": {"key": "value"}}
</tool_call>More text"#;

        let extraction = try_extract_complete_tool_call(buffer, ModelFamily::Qwen);

        assert!(extraction.found);
        assert_eq!(extraction.text_before, Some(String::from("Some text")));
        let call = extraction.call.as_ref().unwrap();
        assert_eq!(call.name, "test");
        assert_eq!(extraction.remaining, "More text");
    }

    /// A call without its closing tag must not be reported as found.
    #[test]
    fn test_try_extract_incomplete_tool_call() {
        let buffer = r#"<tool_call>
{"name": "test", "arguments": {"key": "value"}}"#;

        let extraction = try_extract_complete_tool_call(buffer, ModelFamily::Qwen);

        assert!(!extraction.found);
        assert!(extraction.call.is_none());
    }

    /// `<tooltip>` markup is ruled out as a tool call, while a real
    /// `<tool_call>` tag is not.
    #[test]
    fn test_definitely_not_tool_call() {
        let lookalike = "<tooltip>hover</tooltip>";
        assert!(definitely_not_tool_call(lookalike, ModelFamily::Qwen));

        let genuine = "<tool_call>";
        assert!(!definitely_not_tool_call(genuine, ModelFamily::Qwen));
    }
}
/// Tests for the `parallel_tool_calls` switch.
mod phase3_parallel_tool_calls {
    use super::*;

    /// With parallel calls disabled only the first call survives; with them
    /// enabled both calls are kept.
    #[test]
    fn test_enforce_single_tool_call() {
        let detected: Vec<DetectedToolCall> = [("call_1", "tool_a"), ("call_2", "tool_b")]
            .iter()
            .map(|&(id, name)| DetectedToolCall {
                id: id.to_string(),
                name: name.to_string(),
                arguments: "{}".to_string(),
            })
            .collect();

        let single = enforce_parallel_tool_calls(detected.clone(), false);
        assert_eq!(single.len(), 1);
        assert_eq!(single[0].name, "tool_a");

        let everything = enforce_parallel_tool_calls(detected, true);
        assert_eq!(everything.len(), 2);
    }

    /// Processing output that contains two calls with `parallel_tool_calls`
    /// set to `false` is expected to yield only `tool_a`.
    #[test]
    fn test_process_model_output_respects_parallel() {
        let model_output = r#"<tool_call>
{"name": "tool_a", "arguments": {}}
</tool_call>
<tool_call>
{"name": "tool_b", "arguments": {}}
</tool_call>"#;
        let options = ProcessingOptions {
            parallel_tool_calls: false,
            ..Default::default()
        };

        let processed = process_model_output_with_options(model_output, ModelFamily::Qwen, options);

        assert_eq!(processed.tool_calls.len(), 1);
        assert_eq!(processed.tool_calls[0].function.name, "tool_a");
    }
}
/// Tests for validating tool-call arguments against a JSON schema.
mod phase3_strict_mode {
    use super::*;

    /// Arguments that satisfy the schema, including the enum, validate successfully.
    #[test]
    fn test_validate_tool_arguments_valid() {
        let schema = json!({
            "type": "object",
            "properties": {
                "location": {"type": "string"},
                "units": {"type": "string", "enum": ["celsius", "fahrenheit"]}
            },
            "required": ["location"]
        });
        let args = r#"{"location": "Seattle", "units": "celsius"}"#;

        let outcome = validate_tool_arguments(args, &schema);

        assert!(outcome.is_ok());
    }

    /// An empty object fails validation, and the error names the missing
    /// `location` field.
    #[test]
    fn test_validate_tool_arguments_missing_required() {
        let schema = json!({
            "type": "object",
            "properties": {
                "location": {"type": "string"}
            },
            "required": ["location"]
        });
        let args = r#"{}"#;

        let outcome = validate_tool_arguments(args, &schema);

        assert!(outcome.is_err());
        let message = outcome.unwrap_err();
        assert!(message.contains("location"));
    }

    /// A string supplied where the schema asks for an integer is rejected.
    #[test]
    fn test_validate_tool_arguments_wrong_type() {
        let schema = json!({
            "type": "object",
            "properties": {
                "count": {"type": "integer"}
            }
        });
        let args = r#"{"count": "not a number"}"#;

        let outcome = validate_tool_arguments(args, &schema);

        assert!(outcome.is_err());
    }

    /// A value outside the schema's enum is rejected.
    #[test]
    fn test_validate_tool_arguments_invalid_enum() {
        let schema = json!({
            "type": "object",
            "properties": {
                "status": {"type": "string", "enum": ["active", "inactive"]}
            }
        });
        let args = r#"{"status": "unknown"}"#;

        let outcome = validate_tool_arguments(args, &schema);

        assert!(outcome.is_err());
    }

    /// A call to a `strict` tool that omits a required argument is expected to
    /// produce at least one validation error.
    #[test]
    fn test_process_validates_strict_tools() {
        let parameters = json!({
            "type": "object",
            "properties": {
                "location": {"type": "string"}
            },
            "required": ["location"]
        });
        let strict_tool = Tool {
            tool_type: String::from("function"),
            function: crate::api_types::FunctionDefinition {
                name: String::from("get_weather"),
                description: Some(String::from("Get weather")),
                parameters: Some(parameters),
                strict: Some(true),
            },
        };
        let model_output = r#"<tool_call>
{"name": "get_weather", "arguments": {}}
</tool_call>"#;

        let processed =
            process_model_output_with_validation(model_output, ModelFamily::Qwen, &[strict_tool]);

        assert!(!processed.validation_errors.is_empty());
    }
}
/// Tests for extracting and detecting JSON with nesting, arrays, and escaped quotes.
mod phase3_deep_json_parsing {
    use super::*;

    /// Extraction from offset 0 of a nested object keeps the innermost key and value.
    #[test]
    fn test_extract_deeply_nested_json() {
        let json_str =
            r#"{"name": "test", "arguments": {"outer": {"middle": {"inner": "value"}}}}"#;

        // `expect` replaces the `is_some()` assertion + `unwrap()` pair.
        let extracted = extract_json_object(json_str, 0)
            .expect("nested object should be extracted in full");

        assert!(extracted.contains("inner"));
        assert!(extracted.contains("value"));
    }

    /// Objects containing nested arrays are extracted as parseable JSON.
    #[test]
    fn test_extract_json_with_arrays() {
        let json_str = r#"{"items": [{"a": 1}, {"b": [1, 2, {"c": 3}]}]}"#;

        let extracted = extract_json_object(json_str, 0)
            .expect("object with nested arrays should be extracted");

        let parsed: serde_json::Value = serde_json::from_str(&extracted).unwrap();
        assert!(parsed.get("items").is_some());
    }

    /// Escaped quotes inside a string value do not prevent extraction.
    #[test]
    fn test_extract_json_with_escaped_quotes() {
        let json_str = r#"{"message": "He said \"hello\""}"#;

        let extracted = extract_json_object(json_str, 0);

        assert!(extracted.is_some());
    }

    /// A Qwen call with several levels of nested arguments is detected and its
    /// arguments parse back to the original structure.
    #[test]
    fn test_detect_deeply_nested_tool_call() {
        let output = r#"<tool_call>
{"name": "complex_tool", "arguments": {"data": {"level1": {"level2": {"level3": "deep"}}}}}
</tool_call>"#;

        let calls = detect_tool_calls(output, ModelFamily::Qwen);

        assert_eq!(calls.len(), 1);
        let args: serde_json::Value = serde_json::from_str(&calls[0].arguments).unwrap();
        // `assert_eq!` prints both sides on failure, unlike `assert!(a == b)`.
        assert_eq!(
            args["data"]["level1"]["level2"]["level3"].as_str(),
            Some("deep")
        );
    }

    /// A Llama call with several levels of nested arguments is detected and
    /// its arguments parse back to the original structure.
    #[test]
    fn test_llama_deeply_nested() {
        let output = r#"<|python_tag|>{"name": "test", "arguments": {"a": {"b": {"c": {"d": "deep"}}}}}"#;

        let calls = detect_tool_calls(output, ModelFamily::Llama);

        assert_eq!(calls.len(), 1);
        let args: serde_json::Value = serde_json::from_str(&calls[0].arguments).unwrap();
        assert_eq!(args["a"]["b"]["c"]["d"], "deep");
    }
}
/// Edge-case tests: stray tags, unknown tools, malformed JSON, empty or
/// unusual values, and deep nesting.
mod phase3_edge_cases {
    use super::*;

    /// A stray closing tag before the real call must not stop the valid call
    /// from being extracted.
    #[test]
    fn test_qwen_end_tag_before_start_tag() {
        let buffer = r#"</tool_call>garbage<tool_call>
{"name": "real_tool", "arguments": {"key": "value"}}
</tool_call>remaining"#;

        let result = try_extract_complete_tool_call(buffer, ModelFamily::Qwen);

        assert!(result.found, "Should find the valid tool call");
        assert_eq!(result.call.as_ref().unwrap().name, "real_tool");
        assert_eq!(result.remaining, "remaining");
    }

    /// Several stray closing tags before the real call are tolerated.
    #[test]
    fn test_qwen_multiple_stray_end_tags() {
        let buffer = r#"text</tool_call></tool_call><tool_call>
{"name": "test", "arguments": {}}
</tool_call>"#;

        let result = try_extract_complete_tool_call(buffer, ModelFamily::Qwen);

        assert!(result.found);
        assert_eq!(result.call.as_ref().unwrap().name, "test");
    }

    /// A call to a tool that is not in the provided list is still expected to
    /// appear in the processed result's `tool_calls`.
    #[test]
    fn test_use_validation_function_for_tool_checking() {
        let tool = Tool {
            tool_type: "function".to_string(),
            function: crate::api_types::FunctionDefinition {
                name: "known_tool".to_string(),
                description: Some("A known tool".to_string()),
                parameters: Some(json!({
                    "type": "object",
                    "properties": {},
                })),
                strict: Some(true),
            },
        };
        let output = r#"<tool_call>
{"name": "unknown_tool", "arguments": {}}
</tool_call>"#;

        let result = process_model_output_with_validation(output, ModelFamily::Qwen, &[tool]);

        assert_eq!(result.result.tool_calls.len(), 1);
    }

    /// Validation separates calls to known tools from calls to unknown ones.
    #[test]
    fn test_valid_calls_only_contains_valid() {
        let tools = vec![make_tool("known_tool", "A known tool", json!({}))];
        let detected = vec![
            DetectedToolCall {
                id: "call_1".to_string(),
                name: "known_tool".to_string(),
                arguments: "{}".to_string(),
            },
            DetectedToolCall {
                id: "call_2".to_string(),
                name: "unknown_tool".to_string(),
                arguments: "{}".to_string(),
            },
        ];

        let result = validate_detected_calls(&detected, &tools);

        assert_eq!(result.unknown_tools.len(), 1);
        assert_eq!(result.unknown_tools[0], "unknown_tool");
        assert_eq!(
            result.valid_calls.len(),
            1,
            "valid_calls should only contain calls to known tools"
        );
        assert_eq!(result.valid_calls[0].name, "known_tool");
    }

    /// An object that is never closed yields `None` instead of panicking.
    #[test]
    fn test_mismatched_brackets_no_panic() {
        let json_str = r#"{"key": {"unclosed": "value"}"#;

        let result = extract_json_object(json_str, 0);

        assert!(result.is_none());
    }

    /// A surplus closing brace after a complete object is left out of the extraction.
    #[test]
    fn test_extra_closing_bracket() {
        let json_str = r#"{"key": "value"}}"#;

        // `expect` replaces the `is_some()` assertion + `unwrap()` pair.
        let extracted = extract_json_object(json_str, 0)
            .expect("the balanced prefix should be extracted");

        assert_eq!(extracted, r#"{"key": "value"}"#);
    }

    /// A call with an empty tool name is still detected, with the name left empty.
    #[test]
    fn test_empty_tool_name() {
        let output = r#"<tool_call>
{"name": "", "arguments": {"key": "value"}}
</tool_call>"#;

        let calls = detect_tool_calls(output, ModelFamily::Qwen);

        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "");
    }

    /// A buffer holding only whitespace contains no tool call.
    #[test]
    fn test_whitespace_only_buffer() {
        let buffer = "   \n\t  ";

        let result = try_extract_complete_tool_call(buffer, ModelFamily::Qwen);

        assert!(!result.found);
        assert!(result.call.is_none());
    }

    /// `null` arguments are kept as the literal string `"null"`.
    #[test]
    fn test_null_arguments() {
        let output = r#"<tool_call>
{"name": "test", "arguments": null}
</tool_call>"#;

        let calls = detect_tool_calls(output, ModelFamily::Qwen);

        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].arguments, "null");
    }

    /// Non-ASCII text in arguments survives detection unchanged.
    #[test]
    fn test_unicode_in_arguments() {
        let output = r#"<tool_call>
{"name": "test", "arguments": {"message": "Hello 世界 🌍"}}
</tool_call>"#;

        let calls = detect_tool_calls(output, ModelFamily::Qwen);

        assert_eq!(calls.len(), 1);
        let args: serde_json::Value = serde_json::from_str(&calls[0].arguments).unwrap();
        assert_eq!(args["message"], "Hello 世界 🌍");
    }

    /// Brackets and braces inside a string value do not confuse detection.
    #[test]
    fn test_brackets_in_string_values() {
        let output = r#"<tool_call>
{"name": "test", "arguments": {"code": "arr[0] = {a: 1}"}}
</tool_call>"#;

        let calls = detect_tool_calls(output, ModelFamily::Qwen);

        assert_eq!(calls.len(), 1);
        let args: serde_json::Value = serde_json::from_str(&calls[0].arguments).unwrap();
        assert_eq!(args["code"], "arr[0] = {a: 1}");
    }

    /// Arguments nested ten levels deep are detected and parse back intact.
    #[test]
    fn test_ten_levels_deep_nesting() {
        let output = r#"<tool_call>
{"name": "deep", "arguments": {"l1": {"l2": {"l3": {"l4": {"l5": {"l6": {"l7": {"l8": {"l9": {"l10": "bottom"}}}}}}}}}}}
</tool_call>"#;

        let calls = detect_tool_calls(output, ModelFamily::Qwen);

        assert_eq!(calls.len(), 1);
        let args: serde_json::Value = serde_json::from_str(&calls[0].arguments).unwrap();
        assert_eq!(
            args["l1"]["l2"]["l3"]["l4"]["l5"]["l6"]["l7"]["l8"]["l9"]["l10"],
            "bottom"
        );
    }
}
/// Tests for `validate_detected_calls` classifying calls as known or unknown.
mod phase3_unknown_tool_logging {
    use super::*;

    /// A call to one of the registered tools is reported as valid, with no
    /// unknown tools.
    #[test]
    fn test_validate_detected_calls_known_tools() {
        let registered = vec![
            make_tool("tool_a", "A", json!({})),
            make_tool("tool_b", "B", json!({})),
        ];
        let call = DetectedToolCall {
            id: String::from("call_1"),
            name: String::from("tool_a"),
            arguments: String::from("{}"),
        };
        let detected = vec![call];

        let report = validate_detected_calls(&detected, &registered);

        assert!(report.unknown_tools.is_empty());
        assert_eq!(report.valid_calls.len(), 1);
    }

    /// A call to an unregistered tool is listed under `unknown_tools` and
    /// left out of `valid_calls`.
    #[test]
    fn test_validate_detected_calls_unknown_tools() {
        let registered = vec![make_tool("tool_a", "A", json!({}))];
        let call = DetectedToolCall {
            id: String::from("call_1"),
            name: String::from("unknown_tool"),
            arguments: String::from("{}"),
        };
        let detected = vec![call];

        let report = validate_detected_calls(&detected, &registered);

        assert_eq!(report.unknown_tools.len(), 1);
        assert_eq!(report.unknown_tools[0], "unknown_tool");
        assert_eq!(report.valid_calls.len(), 0);
    }
}
/// Tests for `StreamingToolDetector`: chunk-by-chunk processing, buffering
/// state, and the events emitted for text and tool calls.
mod streaming_tool_detector {
    use super::*;

    /// A fresh detector is not buffering and holds an empty buffer.
    #[test]
    fn test_new_detector() {
        let detector = StreamingToolDetector::new(ModelFamily::Qwen);
        assert!(!detector.is_buffering());
        assert!(detector.buffer().is_empty());
    }

    /// Plain text is emitted as a single `Text` event without buffering.
    #[test]
    fn test_plain_text_streams_immediately() {
        let mut detector = StreamingToolDetector::new(ModelFamily::Qwen);

        let events = detector.process_chunk("Hello, world!");

        assert_eq!(events.len(), 1);
        assert!(matches!(&events[0], ToolDetectionEvent::Text(t) if t == "Hello, world!"));
        assert!(!detector.is_buffering());
    }

    /// Each plain-text chunk produces its own `Text` event.
    #[test]
    fn test_multiple_plain_text_chunks() {
        let mut detector = StreamingToolDetector::new(ModelFamily::Qwen);

        let events1 = detector.process_chunk("Hello ");
        let events2 = detector.process_chunk("world!");

        assert_eq!(events1.len(), 1);
        assert!(matches!(&events1[0], ToolDetectionEvent::Text(t) if t == "Hello "));
        assert_eq!(events2.len(), 1);
        assert!(matches!(&events2[0], ToolDetectionEvent::Text(t) if t == "world!"));
    }

    /// A whole tool call delivered in one chunk yields exactly one `ToolCall` event.
    #[test]
    fn test_complete_tool_call_in_one_chunk() {
        let mut detector = StreamingToolDetector::new(ModelFamily::Qwen);
        let chunk = r#"<tool_call>
{"name": "get_weather", "arguments": {"location": "NYC"}}
</tool_call>"#;

        let events = detector.process_chunk(chunk);

        assert_eq!(events.len(), 1);
        match &events[0] {
            ToolDetectionEvent::ToolCall(call) => {
                assert_eq!(call.name, "get_weather");
                assert!(call.arguments.contains("NYC"));
            },
            other => panic!("Expected ToolCall, got {:?}", other),
        }
    }

    /// A tool call arriving in pieces keeps the detector buffering until the
    /// closing tag arrives, at which point one `ToolCall` event is emitted.
    #[test]
    fn test_tool_call_split_across_chunks() {
        let mut detector = StreamingToolDetector::new(ModelFamily::Qwen);

        let first_events = detector.process_chunk("<tool");
        assert!(detector.is_buffering());
        assert_eq!(first_events.len(), 1);
        assert!(matches!(first_events[0], ToolDetectionEvent::Buffering));

        // The events of the two middle chunks are not inspected; only the
        // buffering state matters here, so the return values are discarded
        // instead of being bound to unused locals.
        detector.process_chunk("_call>");
        assert!(detector.is_buffering());
        detector.process_chunk(
            r#"
{"name": "test", "arguments": {}}"#,
        );
        assert!(detector.is_buffering());

        let last_events = detector.process_chunk("\n</tool_call>");
        assert!(!detector.is_buffering());
        assert_eq!(last_events.len(), 1);
        match &last_events[0] {
            ToolDetectionEvent::ToolCall(call) => {
                assert_eq!(call.name, "test");
            },
            other => panic!("Expected ToolCall, got {:?}", other),
        }
    }

    /// Text preceding a tool call is emitted first, followed by the call.
    #[test]
    fn test_text_before_tool_call() {
        let mut detector = StreamingToolDetector::new(ModelFamily::Qwen);
        let chunk = r#"Some intro text<tool_call>
{"name": "test", "arguments": {}}
</tool_call>"#;

        let events = detector.process_chunk(chunk);

        assert_eq!(events.len(), 2);
        assert!(matches!(&events[0], ToolDetectionEvent::Text(t) if t == "Some intro text"));
        assert!(matches!(&events[1], ToolDetectionEvent::ToolCall(_)));
    }

    /// Text following a tool call is emitted after the call, leaving nothing
    /// for `finish` to flush.
    #[test]
    fn test_text_after_tool_call() {
        let mut detector = StreamingToolDetector::new(ModelFamily::Qwen);
        let chunk = r#"<tool_call>
{"name": "test", "arguments": {}}
</tool_call>Some trailing text"#;

        let events = detector.process_chunk(chunk);

        assert_eq!(events.len(), 2);
        assert!(matches!(&events[0], ToolDetectionEvent::ToolCall(_)));
        assert!(matches!(&events[1], ToolDetectionEvent::Text(t) if t == "Some trailing text"));
        let final_events = detector.finish();
        assert!(final_events.is_empty());
    }

    /// Two tool calls in one chunk produce two `ToolCall` events.
    #[test]
    fn test_multiple_tool_calls() {
        let mut detector = StreamingToolDetector::new(ModelFamily::Qwen);
        let chunk = r#"<tool_call>
{"name": "tool_a", "arguments": {}}
</tool_call>
<tool_call>
{"name": "tool_b", "arguments": {}}
</tool_call>"#;

        let events = detector.process_chunk(chunk);

        let tool_calls: Vec<_> = events
            .iter()
            .filter(|e| matches!(e, ToolDetectionEvent::ToolCall(_)))
            .collect();
        assert_eq!(tool_calls.len(), 2);
    }

    /// `finish` flushes an unterminated tool call as plain text.
    #[test]
    fn test_finish_flushes_incomplete_buffer() {
        let mut detector = StreamingToolDetector::new(ModelFamily::Qwen);
        detector.process_chunk("<tool_call>");
        detector.process_chunk(r#"{"name": "incomplete""#);
        assert!(detector.is_buffering());

        let events = detector.finish();

        assert_eq!(events.len(), 1);
        match &events[0] {
            ToolDetectionEvent::Text(text) => {
                assert!(text.contains("<tool_call>"));
                assert!(text.contains("incomplete"));
            },
            other => panic!("Expected Text, got {:?}", other),
        }
    }

    /// After a complete tool call, `finish` has nothing left to emit.
    #[test]
    fn test_empty_finish() {
        let mut detector = StreamingToolDetector::new(ModelFamily::Qwen);
        detector.process_chunk(r#"<tool_call>{"name": "t", "arguments": {}}</tool_call>"#);

        let events = detector.finish();

        assert!(events.is_empty());
    }

    /// A `<tool` prefix that turns out to be `<tooltip>` stops buffering and
    /// is released as text.
    #[test]
    fn test_false_positive_tool_start() {
        let mut detector = StreamingToolDetector::new(ModelFamily::Qwen);

        // Only the buffering state after the first chunk is checked, so its
        // events are discarded instead of being bound to an unused local.
        detector.process_chunk("<tool");
        assert!(detector.is_buffering());

        let events = detector.process_chunk("tip>helpful tip</tooltip>");
        assert!(!detector.is_buffering());
        let has_text = events
            .iter()
            .any(|e| matches!(e, ToolDetectionEvent::Text(_)));
        assert!(has_text);
    }

    /// The Llama `<|python_tag|>` format is detected in streaming mode.
    #[test]
    fn test_llama_format() {
        let mut detector = StreamingToolDetector::new(ModelFamily::Llama);
        let chunk = r#"<|python_tag|>{"name": "test", "arguments": {"x": 1}}"#;

        let events = detector.process_chunk(chunk);

        assert_eq!(events.len(), 1);
        assert!(matches!(&events[0], ToolDetectionEvent::ToolCall(c) if c.name == "test"));
    }

    /// The Mistral `[TOOL_CALLS]` format is detected in streaming mode.
    #[test]
    fn test_mistral_format() {
        let mut detector = StreamingToolDetector::new(ModelFamily::Mistral);
        let chunk = r#"[TOOL_CALLS][{"name": "test", "arguments": {}}]"#;

        let events = detector.process_chunk(chunk);

        assert_eq!(events.len(), 1);
        assert!(matches!(&events[0], ToolDetectionEvent::ToolCall(c) if c.name == "test"));
    }

    /// Feeding the stream one character at a time still yields exactly one
    /// tool call and at least one text event.
    #[test]
    fn test_character_by_character_streaming() {
        let mut detector = StreamingToolDetector::new(ModelFamily::Qwen);
        let full_text = r#"Hi<tool_call>{"name": "t", "arguments": {}}</tool_call>bye"#;

        let mut all_events = Vec::new();
        for c in full_text.chars() {
            all_events.extend(detector.process_chunk(&c.to_string()));
        }
        all_events.extend(detector.finish());

        let text_events: Vec<_> = all_events
            .iter()
            .filter(|e| matches!(e, ToolDetectionEvent::Text(_)))
            .collect();
        let tool_events: Vec<_> = all_events
            .iter()
            .filter(|e| matches!(e, ToolDetectionEvent::ToolCall(_)))
            .collect();
        assert_eq!(tool_events.len(), 1);
        assert!(!text_events.is_empty());
    }

    /// Nested arguments survive streaming detection and parse back intact.
    #[test]
    fn test_deeply_nested_json() {
        let mut detector = StreamingToolDetector::new(ModelFamily::Qwen);
        let chunk = r#"<tool_call>
{"name": "nested", "arguments": {"a": {"b": {"c": {"d": "deep"}}}}}
</tool_call>"#;

        let events = detector.process_chunk(chunk);

        assert_eq!(events.len(), 1);
        match &events[0] {
            ToolDetectionEvent::ToolCall(call) => {
                assert_eq!(call.name, "nested");
                let args: serde_json::Value = serde_json::from_str(&call.arguments).unwrap();
                assert_eq!(args["a"]["b"]["c"]["d"], "deep");
            },
            other => panic!("Expected ToolCall, got {:?}", other),
        }
    }
}
/// Tests for `SseEvent`: JSON serialization, deserialization, and conversion
/// from `ToolDetectionEvent`.
mod sse_events {
    use super::*;

    /// A text event serializes with `type` and `content` and no `id` or `name` keys.
    #[test]
    fn test_text_event_json_format() {
        let event = SseEvent::text("Hello world");

        let json = event.to_json();
        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed["type"], "text");
        assert_eq!(parsed["content"], "Hello world");
        assert!(parsed.get("id").is_none());
        assert!(parsed.get("name").is_none());
    }

    /// A tool-call event serializes its id and name, with the arguments as a
    /// JSON object rather than a string.
    #[test]
    fn test_tool_call_event_json_format() {
        let call = DetectedToolCall {
            id: "call_abc123".to_string(),
            name: "get_weather".to_string(),
            arguments: r#"{"location":"NYC","units":"celsius"}"#.to_string(),
        };
        let event = SseEvent::tool_call(&call);

        let json = event.to_json();
        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed["type"], "tool_call");
        assert_eq!(parsed["id"], "call_abc123");
        assert_eq!(parsed["name"], "get_weather");
        assert_eq!(parsed["arguments"]["location"], "NYC");
        assert_eq!(parsed["arguments"]["units"], "celsius");
    }

    /// An error event without details serializes `code` and `message` and
    /// omits `details`.
    #[test]
    fn test_error_event_json_format() {
        let event = SseEvent::error("rate_limit", "Too many requests");

        let json = event.to_json();
        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed["type"], "error");
        assert_eq!(parsed["code"], "rate_limit");
        assert_eq!(parsed["message"], "Too many requests");
        assert!(parsed.get("details").is_none());
    }

    /// An error event with details serializes the `details` object.
    #[test]
    fn test_error_event_with_details_json_format() {
        let event = SseEvent::error_with_details(
            "validation_error",
            "Invalid arguments",
            json!({"field": "location", "reason": "required"}),
        );

        let json = event.to_json();
        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed["type"], "error");
        assert_eq!(parsed["details"]["field"], "location");
    }

    /// A done event without usage serializes `finish_reason` and omits `usage`.
    #[test]
    fn test_done_event_json_format() {
        let event = SseEvent::done("stop");

        let json = event.to_json();
        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed["type"], "done");
        assert_eq!(parsed["finish_reason"], "stop");
        assert!(parsed.get("usage").is_none());
    }

    /// A done event with usage serializes prompt and completion token counts
    /// and their sum as `total_tokens`.
    #[test]
    fn test_done_event_with_usage_json_format() {
        let event = SseEvent::done_with_usage("tool_calls", 42, 18);

        let json = event.to_json();
        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed["type"], "done");
        assert_eq!(parsed["finish_reason"], "tool_calls");
        assert_eq!(parsed["usage"]["prompt_tokens"], 42);
        assert_eq!(parsed["usage"]["completion_tokens"], 18);
        assert_eq!(parsed["usage"]["total_tokens"], 60);
    }

    /// JSON tagged `"text"` deserializes into `SseEvent::Text`.
    #[test]
    fn test_deserialize_text_event() {
        let json = r#"{"type":"text","content":"Hello"}"#;

        let event: SseEvent = serde_json::from_str(json).unwrap();

        match event {
            SseEvent::Text { content } => assert_eq!(content, "Hello"),
            other => panic!("Expected Text, got {:?}", other),
        }
    }

    /// JSON tagged `"tool_call"` deserializes into `SseEvent::ToolCall`.
    #[test]
    fn test_deserialize_tool_call_event() {
        let json = r#"{"type":"tool_call","id":"call_123","name":"test","arguments":{"x":1}}"#;

        let event: SseEvent = serde_json::from_str(json).unwrap();

        match event {
            SseEvent::ToolCall {
                id,
                name,
                arguments,
            } => {
                assert_eq!(id, "call_123");
                assert_eq!(name, "test");
                assert_eq!(arguments["x"], 1);
            },
            other => panic!("Expected ToolCall, got {:?}", other),
        }
    }

    /// JSON tagged `"done"` without usage deserializes into `SseEvent::Done`
    /// with `usage` unset.
    #[test]
    fn test_deserialize_done_event() {
        let json = r#"{"type":"done","finish_reason":"stop"}"#;

        let event: SseEvent = serde_json::from_str(json).unwrap();

        match event {
            SseEvent::Done {
                finish_reason,
                usage,
            } => {
                assert_eq!(finish_reason, "stop");
                assert!(usage.is_none());
            },
            other => panic!("Expected Done, got {:?}", other),
        }
    }

    /// A `Text` detection event converts into `Some(SseEvent::Text)`.
    #[test]
    fn test_convert_text_detection_to_sse() {
        let detection = ToolDetectionEvent::Text("Hello".to_string());

        let sse: Option<SseEvent> = detection.into();

        // Matching on the `Option` directly covers both "is it Some?" and
        // "is it the right variant?" without an `is_some()` + `unwrap()` pair.
        match sse {
            Some(SseEvent::Text { content }) => assert_eq!(content, "Hello"),
            other => panic!("Expected Text, got {:?}", other),
        }
    }

    /// A `ToolCall` detection event converts into `Some(SseEvent::ToolCall)`
    /// with parsed arguments.
    #[test]
    fn test_convert_tool_call_detection_to_sse() {
        let call = DetectedToolCall {
            id: "call_xyz".to_string(),
            name: "my_tool".to_string(),
            arguments: r#"{"a":1}"#.to_string(),
        };
        let detection = ToolDetectionEvent::ToolCall(call);

        let sse: Option<SseEvent> = detection.into();

        match sse {
            Some(SseEvent::ToolCall {
                id,
                name,
                arguments,
            }) => {
                assert_eq!(id, "call_xyz");
                assert_eq!(name, "my_tool");
                assert_eq!(arguments["a"], 1);
            },
            other => panic!("Expected ToolCall, got {:?}", other),
        }
    }

    /// A `Buffering` detection event converts into `None`.
    #[test]
    fn test_convert_buffering_detection_to_sse() {
        let detection = ToolDetectionEvent::Buffering;

        let sse: Option<SseEvent> = detection.into();

        assert!(sse.is_none());
    }

    /// Plain text run through the detector becomes one text SSE event with
    /// the expected JSON fields.
    #[test]
    fn test_detector_to_sse_stream_text_only() {
        let mut detector = StreamingToolDetector::new(ModelFamily::Qwen);

        let events = detector.process_chunk("Hello world");
        let sse_events: Vec<SseEvent> = events.into_iter().filter_map(|e| e.into()).collect();

        assert_eq!(sse_events.len(), 1);
        let json = sse_events[0].to_json();
        assert!(json.contains(r#""type":"text""#));
        assert!(json.contains(r#""content":"Hello world""#));
    }

    /// Text followed by a tool call becomes a text SSE event and then a
    /// tool-call SSE event.
    #[test]
    fn test_detector_to_sse_stream_with_tool_call() {
        let mut detector = StreamingToolDetector::new(ModelFamily::Qwen);
        let chunk = r#"Let me help<tool_call>
{"name": "search", "arguments": {"query": "rust"}}
</tool_call>"#;

        let events = detector.process_chunk(chunk);
        let sse_events: Vec<SseEvent> = events.into_iter().filter_map(|e| e.into()).collect();

        assert_eq!(sse_events.len(), 2);
        match &sse_events[0] {
            SseEvent::Text { content } => assert_eq!(content, "Let me help"),
            other => panic!("Expected Text, got {:?}", other),
        }
        match &sse_events[1] {
            SseEvent::ToolCall {
                name, arguments, ..
            } => {
                assert_eq!(name, "search");
                assert_eq!(arguments["query"], "rust");
            },
            other => panic!("Expected ToolCall, got {:?}", other),
        }
    }

    /// A multi-chunk session yields at least one text event and exactly one
    /// tool call carrying the expected name and arguments.
    #[test]
    fn test_full_streaming_session_to_sse() {
        let mut detector = StreamingToolDetector::new(ModelFamily::Qwen);
        let mut all_sse: Vec<SseEvent> = Vec::new();
        // A fixed-size array suffices; no heap-allocated `Vec` is needed.
        let chunks = [
            "I'll check ",
            "the weather",
            "<tool_call>",
            r#"{"name": "get_weather", "arguments": {"location": "NYC"}}"#,
            "</tool_call>",
            " for you.",
        ];

        for chunk in chunks {
            let events = detector.process_chunk(chunk);
            all_sse.extend(events.into_iter().filter_map(|e| e.into()));
        }
        all_sse.extend(detector.finish().into_iter().filter_map(|e| e.into()));

        let text_count = all_sse
            .iter()
            .filter(|e| matches!(e, SseEvent::Text { .. }))
            .count();
        let tool_count = all_sse
            .iter()
            .filter(|e| matches!(e, SseEvent::ToolCall { .. }))
            .count();
        assert!(text_count >= 1, "Should have text events");
        assert_eq!(tool_count, 1, "Should have exactly one tool call");

        // `find_map` locates the tool call and pulls out its fields in one
        // step, so no `unreachable!()` arm is needed afterwards.
        let (name, arguments) = all_sse
            .iter()
            .find_map(|event| match event {
                SseEvent::ToolCall {
                    name, arguments, ..
                } => Some((name, arguments)),
                _ => None,
            })
            .expect("tool_count == 1 guarantees a ToolCall event exists");
        assert_eq!(name, "get_weather");
        assert_eq!(arguments["location"], "NYC");
    }
}
}