use std::collections::{HashMap, HashSet};
use crate::error::{Result, SqzError};
use crate::preset::Preset;
use crate::types::ToolId;
/// Description of a single tool that can be registered with [`ToolSelector`].
#[derive(Debug, Clone)]
pub struct ToolDefinition {
/// Identifier under which the tool is stored and returned by selection.
pub id: ToolId,
/// Tool name. `ToolSelector` in this file does not read it when scoring.
pub name: String,
/// Free-text description; this is the text that gets tokenized into the
/// tool's bag of words on registration and update.
pub description: String,
}
/// Set of distinct, lowercased word tokens extracted from a piece of text.
type BagOfWords = HashSet<String>;

/// Breaks `text` into a bag of words.
///
/// Every character that is not alphanumeric acts as a separator, empty
/// fragments are discarded, and each remaining word is lowercased before being
/// added to the set (so duplicates and case differences collapse).
fn tokenize(text: &str) -> BagOfWords {
    let mut words = BagOfWords::new();
    for fragment in text.split(|ch: char| !ch.is_alphanumeric()) {
        // Consecutive separators (and leading/trailing ones) yield empty fragments.
        if fragment.is_empty() {
            continue;
        }
        words.insert(fragment.to_lowercase());
    }
    words
}
/// Jaccard similarity of two bags of words: `|a ∩ b| / |a ∪ b|`.
///
/// Returns a value in `0.0..=1.0`. When both bags are empty the ratio is
/// undefined; this function reports `0.0` (no similarity) in that case.
fn jaccard(a: &BagOfWords, b: &BagOfWords) -> f64 {
    let shared = a.intersection(b).count();
    // Inclusion–exclusion: |a ∪ b| = |a| + |b| − |a ∩ b|. `shared` can never
    // exceed either length, so the subtraction cannot underflow, and the union
    // does not have to be walked a second time.
    let combined = a.len() + b.len() - shared;
    if combined == 0 {
        // Only reachable when both bags are empty.
        return 0.0;
    }
    shared as f64 / combined as f64
}
/// Picks tools for an intent string by comparing bags of words with Jaccard
/// similarity.
pub struct ToolSelector {
/// Bag of words per registered tool, built from the tool's description.
bags: HashMap<ToolId, BagOfWords>,
/// Registered tool ids in first-registration order; each id appears once.
tool_ids: Vec<ToolId>,
/// Minimum best-match score; below it, selection returns `default_tools`.
threshold: f64,
/// Tools returned when nothing is registered or the best score is too low.
default_tools: Vec<ToolId>,
}
impl ToolSelector {
pub fn new(_model_path: &std::path::Path, preset: &Preset) -> Result<Self> {
let threshold = preset.tool_selection.similarity_threshold;
let default_tools = preset.tool_selection.default_tools.clone();
Ok(Self {
bags: HashMap::new(),
tool_ids: Vec::new(),
threshold,
default_tools,
})
}
pub fn register_tools(&mut self, tools: &[ToolDefinition]) -> Result<()> {
for tool in tools {
let bag = tokenize(&tool.description);
if !self.bags.contains_key(&tool.id) {
self.tool_ids.push(tool.id.clone());
}
self.bags.insert(tool.id.clone(), bag);
}
Ok(())
}
pub fn select(&self, intent: &str, max_tools: usize) -> Result<Vec<ToolId>> {
let tool_count = self.tool_ids.len();
if tool_count == 0 {
return Ok(self.default_tools.clone());
}
let intent_bag = tokenize(intent);
let mut scored: Vec<(f64, &ToolId)> = self
.tool_ids
.iter()
.map(|id| {
let bag = self.bags.get(id).expect("bag must exist for registered tool");
let score = jaccard(&intent_bag, bag);
(score, id)
})
.collect();
scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)
.then_with(|| a.1.cmp(b.1)));
let best_score = scored.first().map(|(s, _)| *s).unwrap_or(0.0);
if best_score < self.threshold {
return Ok(self.default_tools.clone());
}
let upper = max_tools.min(5).min(tool_count);
let lower = 3_usize.min(tool_count);
let count = upper.max(lower);
let result = scored
.into_iter()
.take(count)
.map(|(_, id)| id.clone())
.collect();
Ok(result)
}
pub fn update_tool(&mut self, tool: &ToolDefinition) -> Result<()> {
if !self.bags.contains_key(&tool.id) {
return Err(SqzError::Other(format!(
"tool '{}' is not registered; use register_tools first",
tool.id
)));
}
let bag = tokenize(&tool.description);
self.bags.insert(tool.id.clone(), bag);
Ok(())
}
}
// Unit and property tests for the tokenizer, the Jaccard score and
// `ToolSelector`.
#[cfg(test)]
mod tests {
use super::*;
use proptest::prelude::*;
use std::path::Path;
// Builds a preset from `Preset::default()` with the two tool-selection fields
// read by `ToolSelector::new` overridden.
fn make_preset_with_threshold(threshold: f64, default_tools: Vec<String>) -> Preset {
let mut p = Preset::default();
p.tool_selection.similarity_threshold = threshold;
p.tool_selection.default_tools = default_tools;
p
}
// Builds `n` tools with ids `tool_0` .. `tool_{n-1}`. All descriptions share
// the same words and differ only in the index `i`, which appears twice.
fn make_tools(n: usize) -> Vec<ToolDefinition> {
(0..n)
.map(|i| ToolDefinition {
id: format!("tool_{i}"),
name: format!("Tool {i}"),
description: format!(
"This tool performs operation number {i} for task category alpha beta gamma delta epsilon zeta eta theta iota kappa lambda mu nu xi omicron pi rho sigma tau upsilon phi chi psi omega {i}"
),
})
.collect()
}
// Whitespace-separated words each become a token.
#[test]
fn test_tokenize_basic() {
let bag = tokenize("hello world foo");
assert!(bag.contains("hello"));
assert!(bag.contains("world"));
assert!(bag.contains("foo"));
}
// Underscore, colon and period are non-alphanumeric and therefore split
// tokens; single-letter words are kept.
#[test]
fn test_tokenize_punctuation() {
let bag = tokenize("read_file: reads a file.");
assert!(bag.contains("read"));
assert!(bag.contains("file"));
assert!(bag.contains("reads"));
assert!(bag.contains("a"));
}
// Identical bags score 1.0.
#[test]
fn test_jaccard_identical() {
let a = tokenize("read file");
let b = tokenize("read file");
assert!((jaccard(&a, &b) - 1.0).abs() < 1e-9);
}
// Bags with no common word score 0.0.
#[test]
fn test_jaccard_disjoint() {
let a = tokenize("alpha beta");
let b = tokenize("gamma delta");
assert!((jaccard(&a, &b)).abs() < 1e-9);
}
// With a 0.0 threshold the default-tools fallback cannot trigger (scores are
// never negative), so the result size is governed by the 3..=5 clamp.
#[test]
fn test_select_returns_between_3_and_5_for_large_set() {
let preset = make_preset_with_threshold(0.0, vec![]);
let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
let tools = make_tools(10);
selector.register_tools(&tools).unwrap();
let result = selector.select("operation task alpha beta", 5).unwrap();
assert!(result.len() >= 3, "expected >= 3, got {}", result.len());
assert!(result.len() <= 5, "expected <= 5, got {}", result.len());
}
// With fewer than three tools registered, the result cannot exceed the
// number of registered tools.
#[test]
fn test_select_returns_at_most_tool_count_for_small_set() {
let preset = make_preset_with_threshold(0.0, vec![]);
let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
let tools = make_tools(2);
selector.register_tools(&tools).unwrap();
let result = selector.select("operation task", 5).unwrap();
assert!(result.len() <= 2, "expected <= 2, got {}", result.len());
}
// Threshold 1.0 with an intent that shares no word with any description:
// the best score is 0.0, so the configured defaults are returned.
#[test]
fn test_fallback_to_defaults_on_low_confidence() {
let defaults = vec!["default_a".to_string(), "default_b".to_string()];
let preset = make_preset_with_threshold(1.0, defaults.clone());
let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
let tools = make_tools(5);
selector.register_tools(&tools).unwrap();
let result = selector.select("completely unrelated xyz", 5).unwrap();
assert_eq!(result, defaults);
}
// `update_tool` replaces the stored bag of words for the tool (the
// "embedding" in the test name refers to that bag here).
#[test]
fn test_update_tool_changes_embedding() {
let preset = make_preset_with_threshold(0.0, vec![]);
let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
let tools = vec![ToolDefinition {
id: "t1".to_string(),
name: "T1".to_string(),
description: "alpha beta gamma".to_string(),
}];
selector.register_tools(&tools).unwrap();
let updated = ToolDefinition {
id: "t1".to_string(),
name: "T1".to_string(),
description: "delta epsilon zeta".to_string(),
};
selector.update_tool(&updated).unwrap();
// Inspect the private map directly: new words present, old ones gone.
let bag = selector.bags.get("t1").unwrap();
assert!(bag.contains("delta"));
assert!(!bag.contains("alpha"));
}
// Updating an id that was never registered is an error.
#[test]
fn test_update_tool_unregistered_returns_error() {
let preset = make_preset_with_threshold(0.0, vec![]);
let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
let result = selector.update_tool(&ToolDefinition {
id: "nonexistent".to_string(),
name: "X".to_string(),
description: "desc".to_string(),
});
assert!(result.is_err());
}
// With no tools registered, `select` returns the defaults regardless of the
// threshold.
#[test]
fn test_empty_tool_set_returns_defaults() {
let defaults = vec!["fallback".to_string()];
let preset = make_preset_with_threshold(0.0, defaults.clone());
let selector = ToolSelector::new(Path::new(""), &preset).unwrap();
let result = selector.select("anything", 5).unwrap();
assert_eq!(result, defaults);
}
// Strategy: 5..=20 tools and a trimmed intent drawn from `[a-z ]{5,40}`.
fn arb_tool_count_and_intent() -> impl Strategy<Value = (usize, String)> {
(5usize..=20usize, "[a-z ]{5,40}".prop_map(|s| s.trim().to_string()))
}
// Strategy: 1..=4 tools and a trimmed intent drawn from `[a-z ]{5,40}`.
fn arb_small_tool_count_and_intent() -> impl Strategy<Value = (usize, String)> {
(1usize..=4usize, "[a-z ]{5,40}".prop_map(|s| s.trim().to_string()))
}
proptest! {
// Property: with at least five tools and a 0.0 threshold, `select(_, 5)`
// returns between three and five ids for any generated intent.
#[test]
fn prop_tool_selection_cardinality_large(
(tool_count, intent) in arb_tool_count_and_intent()
) {
let preset = make_preset_with_threshold(0.0, vec![]);
let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
let tools = make_tools(tool_count);
selector.register_tools(&tools).unwrap();
let result = selector.select(&intent, 5).unwrap();
prop_assert!(
result.len() >= 3,
"expected >= 3 tools, got {} (tool_count={}, intent='{}')",
result.len(), tool_count, intent
);
prop_assert!(
result.len() <= 5,
"expected <= 5 tools, got {} (tool_count={}, intent='{}')",
result.len(), tool_count, intent
);
}
// Property: with fewer than five tools, the result never contains more ids
// than there are registered tools.
#[test]
fn prop_tool_selection_cardinality_small(
(tool_count, intent) in arb_small_tool_count_and_intent()
) {
let preset = make_preset_with_threshold(0.0, vec![]);
let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
let tools = make_tools(tool_count);
selector.register_tools(&tools).unwrap();
let result = selector.select(&intent, 5).unwrap();
prop_assert!(
result.len() <= tool_count,
"expected <= {} tools, got {} (intent='{}')",
tool_count, result.len(), intent
);
}
}
}