use std::collections::HashSet;
use std::str::FromStr;
use crate::types::response_api::Tool;
/// Strategy used when resolving the final tool list from predefined tools and
/// tools returned by tool search.
///
/// The default is [`ToolPriority::Merge`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum ToolPriority {
/// Use only the predefined tools; any searched tools are ignored.
PreferDefined,
/// Use only the searched tools; any predefined tools are ignored.
PreferSearched,
/// Combine predefined and searched tools into one list, de-duplicated by
/// tool name.
#[default]
Merge,
}
/// Lenient, case-insensitive parsing of a priority name.
///
/// Accepts `prefer_defined` / `prefer-defined`, `prefer_searched` /
/// `prefer-searched`, and `merge` / `combined`. Parsing never fails: any other
/// input logs a warning and yields [`ToolPriority::Merge`].
impl FromStr for ToolPriority {
    type Err = std::convert::Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let normalized = s.to_lowercase();

        // First try to recognise the spelling; `None` means "unknown".
        let recognized = match normalized.as_str() {
            "prefer_defined" | "prefer-defined" => Some(Self::PreferDefined),
            "prefer_searched" | "prefer-searched" => Some(Self::PreferSearched),
            "merge" | "combined" => Some(Self::Merge),
            _ => None,
        };

        // Unknown values fall back to `Merge`, reporting the original input.
        let priority = recognized.unwrap_or_else(|| {
            tracing::warn!(
                "[TOOL_SEARCH] unknown tool_priority '{}', defaulting to 'merge'",
                s
            );
            Self::Merge
        });

        Ok(priority)
    }
}
/// Accumulates predefined tools and tool-search results, and resolves them
/// into a single tool list according to a [`ToolPriority`].
#[derive(Debug, Clone)]
pub struct ToolSearchContext {
/// Ids of tool-search calls that have been registered and not yet completed.
pending_calls: HashSet<String>,
/// Tools collected from tool-search outputs, in the order they were added.
searched_tools: Vec<Tool>,
/// Tools supplied up front via `set_predefined_tools`.
predefined_tools: Vec<Tool>,
/// Resolved tool list exposed by `resolved_tools()`; starts empty.
resolved_tools: Vec<Tool>,
/// Strategy used to combine predefined and searched tools.
priority: ToolPriority,
/// Whether resolution has already run; starts `false`.
finalized: bool,
}
impl ToolSearchContext {
    /// Creates an empty context that resolves tools using `priority`.
    pub fn new(priority: ToolPriority) -> Self {
        Self {
            pending_calls: HashSet::new(),
            searched_tools: Vec::new(),
            predefined_tools: Vec::new(),
            resolved_tools: Vec::new(),
            priority,
            finalized: false,
        }
    }

    /// Records `call_id` as a pending tool-search call.
    pub fn register_search_call(&mut self, call_id: &str) {
        self.pending_calls.insert(call_id.to_string());
        tracing::debug!(
            "[TOOL_SEARCH] registered tool_search_call: id={}",
            call_id
        );
    }

    /// Returns `true` while `call_id` is registered and not yet completed.
    pub fn is_registered_search(&self, call_id: &str) -> bool {
        self.pending_calls.contains(call_id)
    }

    /// Replaces the predefined tool list with `tools`.
    pub fn set_predefined_tools(&mut self, tools: Vec<Tool>) {
        self.predefined_tools = tools;
        tracing::debug!(
            "[TOOL_SEARCH] set predefined tools: count={}",
            self.predefined_tools.len()
        );
    }

    /// Appends `tools` from a tool-search output to the searched tool list.
    ///
    /// If `call_id` is not currently pending, a warning is logged but the
    /// tools are still added (best effort).
    pub fn add_searched_tools(&mut self, tools: Vec<Tool>, call_id: &str) {
        if !self.pending_calls.contains(call_id) {
            tracing::warn!(
                "[TOOL_SEARCH] tool_search_output has unrecognized call_id '{}', \
                 it may not have a corresponding tool_search_call",
                call_id
            );
        }
        let count = tools.len();
        self.searched_tools.extend(tools);
        tracing::debug!(
            "[TOOL_SEARCH] added {} tools from tool_search_output: call_id={}, total_searched={}",
            count,
            call_id,
            self.searched_tools.len()
        );
    }

    /// Removes `call_id` from the set of pending tool-search calls.
    ///
    /// Completing an id that is not pending is a no-op apart from the log line.
    pub fn complete_search(&mut self, call_id: &str) {
        self.pending_calls.remove(call_id);
        tracing::debug!(
            "[TOOL_SEARCH] completed tool_search_call: call_id={}",
            call_id
        );
    }

    /// Consumes the context and returns the resolved tool list.
    ///
    /// * `PreferDefined` returns the predefined tools only.
    /// * `PreferSearched` returns the searched tools only.
    /// * `Merge` combines both lists via [`merge_tools_map`], de-duplicated
    ///   by tool name.
    #[must_use]
    pub fn finalize(mut self) -> Vec<Tool> {
        // Defensive: a context already marked as finalized hands back its
        // stored result instead of resolving again.
        if self.finalized {
            return std::mem::take(&mut self.resolved_tools);
        }
        tracing::debug!(
            "[TOOL_SEARCH] finalizing: predefined={}, searched={}, priority={:?}",
            self.predefined_tools.len(),
            self.searched_tools.len(),
            self.priority
        );
        let result = match self.priority {
            ToolPriority::PreferDefined => {
                if !self.searched_tools.is_empty() {
                    tracing::info!(
                        "[TOOL_SEARCH] PreferDefined: ignoring {} searched tools",
                        self.searched_tools.len()
                    );
                }
                std::mem::take(&mut self.predefined_tools)
            }
            ToolPriority::PreferSearched => {
                if !self.predefined_tools.is_empty() {
                    tracing::info!(
                        "[TOOL_SEARCH] PreferSearched: ignoring {} predefined tools",
                        self.predefined_tools.len()
                    );
                }
                std::mem::take(&mut self.searched_tools)
            }
            ToolPriority::Merge => {
                merge_tools_map(&self.predefined_tools, &self.searched_tools)
            }
        };
        // `self` is dropped when this function returns, so the result is handed
        // straight to the caller rather than cloned back into the context.
        tracing::debug!(
            "[TOOL_SEARCH] resolved tools: count={}",
            result.len()
        );
        result
    }

    /// Returns the resolved tool list stored on this context.
    ///
    /// NOTE(review): `finalize` consumes the context, so in the code visible
    /// here this slice is empty for any context created with `new`.
    pub fn resolved_tools(&self) -> &[Tool] {
        &self.resolved_tools
    }

    /// Returns the priority this context was created with.
    pub fn priority(&self) -> ToolPriority {
        self.priority
    }

    /// Returns `true` if at least one registered search call is not yet completed.
    pub fn has_pending_searches(&self) -> bool {
        !self.pending_calls.is_empty()
    }

    /// Number of predefined tools currently held.
    pub fn predefined_count(&self) -> usize {
        self.predefined_tools.len()
    }

    /// Number of searched tools currently held.
    pub fn searched_count(&self) -> usize {
        self.searched_tools.len()
    }
}
pub(crate) fn merge_tools_map(first: &[Tool], second: &[Tool]) -> Vec<Tool> {
use std::collections::HashMap;
let mut name_to_tool: HashMap<String, &Tool> = HashMap::new();
for tool in first {
if let Some(name) = &tool.name {
name_to_tool.insert(name.clone(), tool);
}
}
for tool in second {
if let Some(name) = &tool.name {
name_to_tool.insert(name.clone(), tool);
}
}
let mut result: Vec<Tool> = Vec::new();
let mut seen: HashSet<String> = HashSet::new();
for tool in first {
if let Some(name) = &tool.name
&& !seen.contains(name)
{
result.push(tool.clone());
seen.insert(name.clone());
}
}
for tool in second {
if let Some(name) = &tool.name {
let first_had = first.iter().any(|t| t.name.as_ref() == Some(name));
if !first_had || !seen.contains(name) {
if !seen.contains(name) {
result.push(tool.clone());
seen.insert(name.clone());
}
}
}
}
result
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::response_api::{Tool, ToolType};

    /// Builds a minimal function tool that carries only a name.
    fn make_tool(name: &str) -> Tool {
        Tool {
            tool_type: ToolType::Function,
            name: Some(name.to_string()),
            description: None,
            parameters: None,
            strict: None,
            extra: std::collections::HashMap::new(),
        }
    }

    /// Counts the tools in `tools` whose name equals `name`.
    fn count_named(tools: &[Tool], name: &str) -> usize {
        tools
            .iter()
            .filter(|t| t.name.as_deref() == Some(name))
            .count()
    }

    #[test]
    fn test_prefer_defined_keeps_predefined() {
        let mut ctx = ToolSearchContext::new(ToolPriority::PreferDefined);
        ctx.set_predefined_tools(vec![make_tool("tool_a"), make_tool("tool_b")]);
        ctx.add_searched_tools(vec![make_tool("tool_c")], "call_1");
        let resolved = ctx.finalize();
        assert_eq!(resolved.len(), 2);
        assert_eq!(count_named(&resolved, "tool_a"), 1);
        assert_eq!(count_named(&resolved, "tool_b"), 1);
        assert_eq!(count_named(&resolved, "tool_c"), 0);
    }

    #[test]
    fn test_prefer_searched_keeps_searched() {
        let mut ctx = ToolSearchContext::new(ToolPriority::PreferSearched);
        ctx.set_predefined_tools(vec![make_tool("tool_a"), make_tool("tool_b")]);
        ctx.add_searched_tools(vec![make_tool("tool_c"), make_tool("tool_d")], "call_1");
        let resolved = ctx.finalize();
        assert_eq!(resolved.len(), 2);
        assert_eq!(count_named(&resolved, "tool_c"), 1);
        assert_eq!(count_named(&resolved, "tool_d"), 1);
    }

    #[test]
    fn test_merge_combines_all_unique() {
        let mut ctx = ToolSearchContext::new(ToolPriority::Merge);
        ctx.set_predefined_tools(vec![make_tool("tool_a"), make_tool("tool_b")]);
        ctx.add_searched_tools(vec![make_tool("tool_c"), make_tool("tool_d")], "call_1");
        let resolved = ctx.finalize();
        assert_eq!(resolved.len(), 4);
    }

    #[test]
    fn test_merge_searched_overrides_on_conflict() {
        let mut ctx = ToolSearchContext::new(ToolPriority::Merge);
        ctx.set_predefined_tools(vec![make_tool("tool_a"), make_tool("tool_b")]);
        ctx.add_searched_tools(vec![make_tool("tool_b"), make_tool("tool_c")], "call_1");
        let resolved = ctx.finalize();
        assert_eq!(resolved.len(), 3);
        // The conflicting name must appear exactly once in the merged list.
        assert_eq!(count_named(&resolved, "tool_b"), 1);
    }

    #[test]
    fn test_register_search_call() {
        let mut ctx = ToolSearchContext::new(ToolPriority::Merge);
        assert!(!ctx.has_pending_searches());
        ctx.register_search_call("call_123");
        assert!(ctx.is_registered_search("call_123"));
        assert!(!ctx.is_registered_search("call_456"));
        assert!(ctx.has_pending_searches());
        ctx.complete_search("call_123");
        assert!(!ctx.is_registered_search("call_123"));
        assert!(!ctx.has_pending_searches());
    }

    #[test]
    fn test_priority_from_str() {
        assert_eq!("prefer_defined".parse::<ToolPriority>(), Ok(ToolPriority::PreferDefined));
        assert_eq!("prefer-searched".parse::<ToolPriority>(), Ok(ToolPriority::PreferSearched));
        assert_eq!("merge".parse::<ToolPriority>(), Ok(ToolPriority::Merge));
        // Matching is case-insensitive.
        assert_eq!("PREFER_DEFINED".parse::<ToolPriority>(), Ok(ToolPriority::PreferDefined));
        // Unknown values fall back to the default strategy.
        assert_eq!("unknown".parse::<ToolPriority>(), Ok(ToolPriority::Merge));
    }

    #[test]
    fn test_finalize_idempotent() {
        // `finalize` consumes the context, so repeatability is checked with
        // two identically configured contexts.
        let build = || {
            let mut ctx = ToolSearchContext::new(ToolPriority::Merge);
            ctx.set_predefined_tools(vec![make_tool("tool_a")]);
            ctx.add_searched_tools(vec![make_tool("tool_b")], "call_1");
            ctx
        };
        let first = build().finalize();
        assert_eq!(first.len(), 2);
        let second = build().finalize();
        assert_eq!(second.len(), 2);
    }
}