use crate::repository::ToolRepository;
use crate::tools::{Tool, ToolSearchStrategy};
use anyhow::Result;
use async_trait::async_trait;
use regex::Regex;
use std::cmp::Ordering;
use std::collections::HashSet;
use std::sync::Arc;
pub struct TagSearchStrategy {
tool_repository: Arc<dyn ToolRepository>,
description_weight: f64,
word_regex: Regex,
}
impl TagSearchStrategy {
pub fn new(repo: Arc<dyn ToolRepository>, description_weight: f64) -> Self {
Self {
tool_repository: repo,
description_weight,
word_regex: Regex::new(r"\w+").unwrap(),
}
}
}
#[derive(Clone)]
struct ScoredTool {
tool: Tool,
score: f64,
}
#[async_trait]
impl ToolSearchStrategy for TagSearchStrategy {
async fn search_tools(&self, query: &str, limit: usize) -> Result<Vec<Tool>> {
let query_lower = query.trim().to_lowercase();
let query_word_set: HashSet<String> = self
.word_regex
.find_iter(&query_lower)
.map(|m| m.as_str().to_string())
.collect();
let tools = self.tool_repository.get_tools().await?;
if tools.is_empty() {
return Ok(Vec::new());
}
let mut positives = Vec::new();
let mut nonpositives = Vec::new();
for tool in tools {
let score = self.score_tool(&tool, &query_lower, &query_word_set);
let entry = ScoredTool { tool, score };
if score > 0.0 {
positives.push(entry);
} else {
nonpositives.push(entry);
}
}
if limit == 0 {
if !positives.is_empty() {
positives.sort_unstable_by(compare_scored);
return Ok(positives.into_iter().map(|st| st.tool).collect());
}
nonpositives.sort_unstable_by(compare_scored);
return Ok(nonpositives.into_iter().map(|st| st.tool).collect());
}
if !positives.is_empty() {
take_top_n(&mut positives, limit);
return Ok(positives.into_iter().map(|st| st.tool).collect());
}
if nonpositives.is_empty() {
return Ok(Vec::new());
}
take_top_n(&mut nonpositives, limit);
Ok(nonpositives.into_iter().map(|st| st.tool).collect())
}
}
impl TagSearchStrategy {
fn score_tool(&self, tool: &Tool, query_lower: &str, query_word_set: &HashSet<String>) -> f64 {
let mut score = 0.0;
for tag in &tool.tags {
let tag_lower = tag.to_ascii_lowercase();
if query_lower.contains(&tag_lower) {
score += 1.0;
}
for m in self.word_regex.find_iter(&tag_lower) {
if query_word_set.contains(m.as_str()) {
score += self.description_weight;
}
}
}
for m in self.word_regex.find_iter(&tool.description) {
let word = m.as_str().to_ascii_lowercase();
if word.len() > 2 && query_word_set.contains(&word) {
score += self.description_weight;
}
}
score
}
}
fn compare_scored(a: &ScoredTool, b: &ScoredTool) -> Ordering {
b.score
.total_cmp(&a.score)
.then_with(|| a.tool.name.cmp(&b.tool.name))
}
fn take_top_n(scored: &mut Vec<ScoredTool>, limit: usize) {
if limit == 0 {
scored.sort_unstable_by(compare_scored);
return;
}
if scored.len() > limit {
let pivot = limit - 1;
scored.select_nth_unstable_by(pivot, compare_scored);
scored.truncate(limit);
}
scored.sort_unstable_by(compare_scored);
}
#[cfg(test)]
mod tests {
use super::*;
use crate::providers::base::{BaseProvider, ProviderType};
use crate::repository::in_memory::InMemoryToolRepository;
use crate::tools::ToolInputOutputSchema;
use std::sync::Arc;
fn schema() -> ToolInputOutputSchema {
ToolInputOutputSchema {
type_: "object".to_string(),
properties: None,
required: None,
description: None,
title: None,
items: None,
enum_: None,
minimum: None,
maximum: None,
format: None,
}
}
fn make_tool(name: &str, description: &str, tags: &[&str]) -> Tool {
Tool {
name: name.to_string(),
description: description.to_string(),
inputs: schema(),
outputs: schema(),
tags: tags.iter().map(|t| t.to_string()).collect(),
average_response_size: None,
provider: None,
}
}
async fn setup_repo(tools: Vec<Tool>) -> Arc<InMemoryToolRepository> {
let repo = Arc::new(InMemoryToolRepository::new());
let provider = Arc::new(BaseProvider {
name: "test".to_string(),
provider_type: ProviderType::Http,
auth: None,
allowed_communication_protocols: None,
});
repo.save_provider_with_tools(provider, tools)
.await
.unwrap();
repo
}
#[tokio::test]
async fn returns_top_scoring_tools_with_limit() {
let repo = setup_repo(vec![
make_tool(
"p1.weather_primary",
"Weather forecast endpoint",
&["weather"],
),
make_tool("p1.weather_backup", "Weather data service", &["climate"]),
make_tool("p1.finance", "Stock price lookup", &["stocks"]),
])
.await;
let strategy = TagSearchStrategy::new(repo, 0.5);
let results = strategy.search_tools("weather forecast", 2).await.unwrap();
assert_eq!(results.len(), 2);
assert_eq!(results[0].name, "p1.weather_primary");
assert_eq!(results[1].name, "p1.weather_backup");
}
#[tokio::test]
async fn falls_back_when_no_positive_scores() {
let repo = setup_repo(vec![
make_tool("p1.alpha", "No overlap here", &["alpha"]),
make_tool("p1.beta", "Still nothing useful", &["beta"]),
make_tool("p1.gamma", "More unrelated content", &["gamma"]),
])
.await;
let strategy = TagSearchStrategy::new(repo, 1.0);
let results = strategy.search_tools("nonsense", 2).await.unwrap();
assert_eq!(results.len(), 2);
assert_eq!(results[0].name, "p1.alpha");
assert_eq!(results[1].name, "p1.beta");
}
#[tokio::test]
async fn ties_are_sorted_by_name_within_limit() {
let repo = setup_repo(vec![
make_tool("p1.alpha", "Math helper", &["math"]),
make_tool("p1.beta", "Math helper", &["math"]),
make_tool("p1.gamma", "Math helper", &["math"]),
])
.await;
let strategy = TagSearchStrategy::new(repo, 1.0);
let results = strategy.search_tools("math", 2).await.unwrap();
assert_eq!(results.len(), 2);
assert_eq!(results[0].name, "p1.alpha");
assert_eq!(results[1].name, "p1.beta");
}
}