use crate::types::ToolDefinition;
/// Chooses at most `max_tools` of `tools` that look most relevant to `query`.
///
/// If `tools` already has `max_tools` entries or fewer, all of them are returned.
/// Otherwise every tool is scored with `score_tool`, and the selection is filled
/// from two tiers. Each tier is walked in descending score order, and ties keep the
/// order of `tools`:
///
/// 1. tools whose lowercased, non-empty name appears as a substring of the
///    lowercased query ("explicitly named" tools);
/// 2. all remaining tools.
///
/// Named tools take every slot they can. The result still never exceeds
/// `max_tools` entries: if more tools are named than fit, the highest-scoring named
/// tools are kept. The returned tools are clones, in the same relative order as
/// they appear in `tools`.
#[must_use]
pub fn select_tools(
    query: &str,
    tools: &[ToolDefinition],
    max_tools: usize,
) -> Vec<ToolDefinition> {
    if tools.len() <= max_tools {
        return tools.to_vec();
    }

    let query_lower = query.to_lowercase();
    let query_words: std::collections::HashSet<&str> = query_lower.split_whitespace().collect();

    let mut scored: Vec<(usize, f64)> = tools
        .iter()
        .enumerate()
        .map(|(i, tool)| (i, score_tool(&query_lower, &query_words, tool)))
        .collect();
    // `sort_by` is stable, so tools with equal scores keep their original order.
    scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

    let is_named = |tool: &ToolDefinition| {
        let name_lower = tool.name.to_lowercase();
        // An empty name is a substring of every query, so it never counts as a mention.
        !name_lower.is_empty() && query_lower.contains(name_lower.as_str())
    };

    // Each index occurs exactly once in `scored`, so the two tiers are disjoint and
    // no de-duplication is needed. Both tiers stay in descending score order.
    let (named, unnamed): (Vec<usize>, Vec<usize>) = scored
        .iter()
        .map(|&(i, _)| i)
        .partition(|&i| is_named(&tools[i]));

    // Named tools first, then the rest, hard-capped at `max_tools`.
    let mut selected: Vec<usize> = named.into_iter().chain(unnamed).take(max_tools).collect();

    // Restore the caller's original ordering.
    selected.sort_unstable();
    selected.into_iter().map(|i| tools[i].clone()).collect()
}
/// Heuristic relevance score of `tool` for a query; higher means more relevant.
///
/// `query_lower` is the lowercased query and `query_words` is the set of words the
/// caller matches against. Points are awarded as follows:
///
/// * `+10` if the tool's lowercased, non-empty name occurs as a substring of
///   `query_lower`;
/// * `+3` for each `_`-separated segment of the name whose lowercase form is in
///   `query_words`;
/// * `+1` for each distinct whitespace-separated word of the lowercased description
///   that is also in `query_words`;
/// * `+2` for each key of `tool.parameters.properties` whose lowercase form is in
///   `query_words`.
fn score_tool(
    query_lower: &str,
    query_words: &std::collections::HashSet<&str>,
    tool: &ToolDefinition,
) -> f64 {
    let mut score = 0.0;

    // Whole-name mention. An empty name is a substring of every string, so it
    // must not earn the bonus.
    let name_lower = tool.name.to_lowercase();
    if !name_lower.is_empty() && query_lower.contains(name_lower.as_str()) {
        score += 10.0;
    }

    // Individual name segments, e.g. `get_weather` -> `get`, `weather`.
    for name_word in tool.name.split('_') {
        if query_words.contains(name_word.to_lowercase().as_str()) {
            score += 3.0;
        }
    }

    // Distinct description words shared with the query.
    let desc_lower = tool.description.to_lowercase();
    let desc_words: std::collections::HashSet<&str> = desc_lower.split_whitespace().collect();
    let overlap = query_words.intersection(&desc_words).count();
    // `overlap` is at most the number of query words. Any count below 2^53 converts
    // to f64 exactly, so the lint's precision concern does not apply here.
    #[allow(clippy::cast_precision_loss)]
    {
        score += overlap as f64;
    }

    // Parameter names mentioned as query words.
    for param_name in tool.parameters.properties.keys() {
        if query_words.contains(param_name.to_lowercase().as_str()) {
            score += 2.0;
        }
    }

    score
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ToolParameters;
    use std::collections::HashMap;

    /// Fixture: a tool with the given name and description and no parameters.
    fn make_tool(name: &str, desc: &str) -> ToolDefinition {
        let parameters = ToolParameters {
            schema_type: "object".to_string(),
            properties: HashMap::new(),
            required: vec![],
        };
        ToolDefinition {
            name: name.to_string(),
            description: desc.to_string(),
            parameters,
            icon: None,
        }
    }

    /// True when some tool in `selection` is called `name`.
    fn contains_tool(selection: &[ToolDefinition], name: &str) -> bool {
        selection.iter().any(|tool| tool.name == name)
    }

    #[test]
    fn returns_all_when_under_limit() {
        let catalog = [make_tool("a", "Tool A"), make_tool("b", "Tool B")];
        let picked = select_tools("hello", &catalog, 5);
        assert_eq!(picked.len(), 2);
    }

    #[test]
    fn selects_mentioned_tool() {
        let catalog = [
            make_tool("get_weather", "Get weather forecast"),
            make_tool("search_web", "Search the web"),
            make_tool("get_calendar", "Get calendar events"),
        ];
        let picked = select_tools("What's the weather today?", &catalog, 2);
        assert!(contains_tool(&picked, "get_weather"));
    }

    #[test]
    fn respects_max_tools_limit() {
        let catalog: Vec<ToolDefinition> = (0..10)
            .map(|n| make_tool(&format!("tool_{n}"), "desc"))
            .collect();
        let picked = select_tools("query", &catalog, 3);
        assert_eq!(picked.len(), 3);
    }

    #[test]
    fn explicitly_named_tool_always_included() {
        let catalog = [
            make_tool("irrelevant_1", "Not relevant"),
            make_tool("irrelevant_2", "Not relevant"),
            make_tool("search_web", "Search the web"),
            make_tool("irrelevant_3", "Not relevant"),
        ];
        let picked = select_tools("use search_web to find something", &catalog, 2);
        assert!(contains_tool(&picked, "search_web"));
    }
}