use std::collections::{HashMap, HashSet};
use zeph_common::ToolName;
use zeph_common::math::cosine_similarity;
use crate::config::ToolDependency;
/// Precomputed embedding vector for a single tool.
///
/// [`ToolSchemaFilter::filter`] compares it against the query embedding with
/// `cosine_similarity`.
#[derive(Debug, Clone)]
pub struct ToolEmbedding {
    /// Id of the tool this embedding belongs to.
    pub tool_id: ToolName,
    /// The embedding vector.
    pub embedding: Vec<f32>,
}
/// Why a tool ended up in [`ToolFilterResult::included`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InclusionReason {
    /// The tool is in the filter's always-on set.
    AlwaysOn,
    /// The tool id appears as a whole word (case-insensitive) in the query.
    NameMentioned,
    /// The tool was among the `top_k` embeddings most similar to the query embedding.
    SimilarityRank,
    /// The tool's description has fewer than `min_description_words` words.
    ShortDescription,
    /// No embedding is stored for the tool, so it cannot be ranked and is included.
    NoEmbedding,
    /// Not produced by any code in this file.
    /// NOTE(review): presumably reserved for tools admitted because their dependencies
    /// are satisfied — confirm with callers.
    DependencyMet,
    /// Recorded by [`ToolDependencyGraph::apply`] when it boosts an included tool that had
    /// no reason recorded yet.
    PreferenceBoost,
}
/// A tool removed from the included set by a hard (`requires`) dependency gate in
/// [`ToolDependencyGraph::apply`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DependencyExclusion {
    /// The excluded tool.
    pub tool_id: ToolName,
    /// The tool's `requires` entries that were missing from the completed set,
    /// in configuration order.
    pub unmet_requires: Vec<String>,
}
/// Outcome of [`ToolSchemaFilter::filter`], optionally refined in place by
/// [`ToolDependencyGraph::apply`].
#[derive(Debug, Clone)]
pub struct ToolFilterResult {
    /// Ids of the selected tools.
    pub included: HashSet<String>,
    /// Offered ids that were not selected, including tools removed by dependency gates.
    pub excluded: Vec<String>,
    /// `(tool id, score)` for every embedded tool that went through similarity ranking
    /// (tools already included by an earlier rule are not scored), sorted best-first.
    /// After `ToolDependencyGraph::apply` scores may include a preference boost.
    pub scores: Vec<(String, f32)>,
    /// Why each tool was included, in the order the rules admitted them.
    pub inclusion_reasons: Vec<(String, InclusionReason)>,
    /// Tools removed by hard dependency gates; empty until `ToolDependencyGraph::apply` runs.
    pub dependency_exclusions: Vec<DependencyExclusion>,
}
/// Hard (`requires`) and soft (`prefers`) dependency rules between tools, keyed by tool id.
///
/// Build it with [`ToolDependencyGraph::new`], which clears `requires` for tools that sit on
/// a dependency cycle. `Default` yields an empty graph; `apply` is a no-op on an empty graph.
#[derive(Debug, Clone, Default)]
pub struct ToolDependencyGraph {
    /// Dependency rule per tool id; tools without an entry have no constraints.
    deps: HashMap<String, ToolDependency>,
}
impl ToolDependencyGraph {
    /// Builds the graph from per-tool dependency rules keyed by tool id.
    ///
    /// Every tool reported by `detect_cycles` has its `requires` list cleared and a warning
    /// is logged. A cyclic configuration therefore cannot gate its participants forever.
    /// `prefers` lists are kept as configured.
    #[must_use]
    pub fn new(deps: HashMap<String, ToolDependency>) -> Self {
        if deps.is_empty() {
            return Self { deps };
        }
        let cycled = detect_cycles(&deps);
        if !cycled.is_empty() {
            tracing::warn!(
                tools = ?cycled,
                "tool dependency graph: cycles detected, removing requires for cycle participants"
            );
        }
        let mut resolved = deps;
        for tool_id in &cycled {
            if let Some(dep) = resolved.get_mut(tool_id) {
                dep.requires.clear();
            }
        }
        Self { deps: resolved }
    }

    /// Returns `true` when no dependency rules are configured.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.deps.is_empty()
    }

    /// Returns `true` when every `requires` entry of `tool_id` is in `completed`.
    ///
    /// Tools without a rule are always considered satisfied.
    #[must_use]
    pub fn requirements_met(&self, tool_id: &str, completed: &HashSet<String>) -> bool {
        self.deps
            .get(tool_id)
            .is_none_or(|d| d.requires.iter().all(|r| completed.contains(r)))
    }

    /// Lists the `requires` entries of `tool_id` that are not in `completed`, in
    /// configuration order. Empty for tools without a rule.
    #[must_use]
    pub fn unmet_requires<'a>(
        &'a self,
        tool_id: &str,
        completed: &HashSet<String>,
    ) -> Vec<&'a str> {
        self.deps.get(tool_id).map_or_else(Vec::new, |d| {
            d.requires
                .iter()
                .filter(|r| !completed.contains(r.as_str()))
                .map(String::as_str)
                .collect()
        })
    }

    /// Soft-dependency boost for `tool_id`.
    ///
    /// Adds `boost_per_dep` for every `prefers` entry found in `completed`, capped at
    /// `max_total_boost`. Returns `0.0` for tools without a rule.
    #[must_use]
    pub fn preference_boost(
        &self,
        tool_id: &str,
        completed: &HashSet<String>,
        boost_per_dep: f32,
        max_total_boost: f32,
    ) -> f32 {
        self.deps.get(tool_id).map_or(0.0, |d| {
            let met = d
                .prefers
                .iter()
                .filter(|p| completed.contains(p.as_str()))
                .count();
            // `met` is bounded by the number of configured preferences, far below f32's
            // exact-integer range.
            #[allow(clippy::cast_precision_loss)]
            let boost = met as f32 * boost_per_dep;
            boost.min(max_total_boost)
        })
    }

    /// Applies hard dependency gates and soft preference boosts to `result`, in place.
    ///
    /// * Tools admitted as [`InclusionReason::AlwaysOn`], or listed in `always_on`, are
    ///   never gated. This matches [`Self::filter_tool_names`].
    /// * Every other included tool with a `requires` entry missing from `completed` is
    ///   moved from `included` to `excluded`. It is also recorded in
    ///   `dependency_exclusions`, ordered by tool id.
    /// * Deadlock guard: if that would remove every gateable tool, no gates are applied
    ///   for this call and a warning is logged.
    /// * Each included tool in `scores` gets its [`Self::preference_boost`] added. Tools
    ///   with no recorded reason get [`InclusionReason::PreferenceBoost`]. `scores` is then
    ///   re-sorted best-first, with NaN scores last.
    pub fn apply(
        &self,
        result: &mut ToolFilterResult,
        completed: &HashSet<String>,
        boost_per_dep: f32,
        max_total_boost: f32,
        always_on: &HashSet<String>,
    ) {
        if self.deps.is_empty() {
            return;
        }

        // Gate exemptions: what the filter admitted as always-on plus the caller's set.
        // Without the union, a tool in `always_on` but not tagged `AlwaysOn` could be
        // gated while being left out of the deadlock-guard count.
        let bypassed: HashSet<&str> = result
            .inclusion_reasons
            .iter()
            .filter(|(_, r)| matches!(r, InclusionReason::AlwaysOn))
            .map(|(id, _)| id.as_str())
            .chain(always_on.iter().map(String::as_str))
            .collect();

        // Sorted so that `dependency_exclusions` / `excluded` do not depend on hash order.
        let mut gateable: Vec<&str> = result
            .included
            .iter()
            .map(String::as_str)
            .filter(|id| !bypassed.contains(id))
            .collect();
        gateable.sort_unstable();

        let mut to_exclude: Vec<DependencyExclusion> = Vec::new();
        for &tool_id in &gateable {
            let unmet: Vec<String> = self
                .unmet_requires(tool_id, completed)
                .into_iter()
                .map(str::to_owned)
                .collect();
            if !unmet.is_empty() {
                to_exclude.push(DependencyExclusion {
                    tool_id: tool_id.into(),
                    unmet_requires: unmet,
                });
            }
        }

        // Deadlock guard: never leave only the exempt tools available.
        if !to_exclude.is_empty() && to_exclude.len() >= gateable.len() {
            tracing::warn!(
                gated = to_exclude.len(),
                non_always_on = gateable.len(),
                "tool dependency graph: all non-always-on tools would be blocked; \
                 disabling hard gates for this turn"
            );
            to_exclude.clear();
        }

        for excl in &to_exclude {
            result.included.remove(excl.tool_id.as_str());
            result.excluded.push(excl.tool_id.to_string());
            tracing::debug!(
                tool_id = %excl.tool_id,
                unmet = ?excl.unmet_requires,
                "tool dependency gate: excluded (requires not met)"
            );
        }
        result.dependency_exclusions = to_exclude;

        for (tool_id, score) in &mut result.scores {
            if !result.included.contains(tool_id) {
                continue;
            }
            let boost = self.preference_boost(tool_id, completed, boost_per_dep, max_total_boost);
            if boost > 0.0 {
                *score += boost;
                let already_recorded = result.inclusion_reasons.iter().any(|(id, _)| id == tool_id);
                if !already_recorded {
                    result
                        .inclusion_reasons
                        .push((tool_id.clone(), InclusionReason::PreferenceBoost));
                }
                tracing::debug!(
                    tool_id = %tool_id,
                    boost,
                    "tool dependency: preference boost applied"
                );
            }
        }

        // Best-first with NaN last. This keeps the comparator a total order:
        // `partial_cmp(..).unwrap_or(Equal)` alone is not one, and std's sort may
        // panic on such comparators.
        result.scores.sort_by(|a, b| {
            a.1.is_nan()
                .cmp(&b.1.is_nan())
                .then_with(|| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal))
        });
    }

    /// Keeps the names in `names` that are in `always_on` or have all `requires` met,
    /// preserving input order.
    ///
    /// Unlike [`Self::apply`], this has no deadlock fallback.
    #[must_use]
    pub fn filter_tool_names<'a>(
        &self,
        names: &[&'a str],
        completed: &HashSet<String>,
        always_on: &HashSet<String>,
    ) -> Vec<&'a str> {
        names
            .iter()
            .copied()
            .filter(|n| always_on.contains(*n) || self.requirements_met(n, completed))
            .collect()
    }
}
fn detect_cycles(deps: &HashMap<String, ToolDependency>) -> HashSet<String> {
#[derive(Clone, Copy, PartialEq)]
enum State {
Unvisited,
InProgress,
Done,
}
let mut state: HashMap<&str, State> = HashMap::new();
let mut cycled: HashSet<String> = HashSet::new();
for start in deps.keys() {
if state
.get(start.as_str())
.copied()
.unwrap_or(State::Unvisited)
!= State::Unvisited
{
continue;
}
let mut stack: Vec<(&str, usize)> = vec![(start.as_str(), 0)];
state.insert(start.as_str(), State::InProgress);
while let Some((node, child_idx)) = stack.last_mut() {
let node = *node;
let requires = deps
.get(node)
.map_or(&[] as &[String], |d| d.requires.as_slice());
if *child_idx >= requires.len() {
state.insert(node, State::Done);
stack.pop();
continue;
}
let child = requires[*child_idx].as_str();
*child_idx += 1;
match state.get(child).copied().unwrap_or(State::Unvisited) {
State::InProgress => {
let cycle_start = stack.iter().position(|(n, _)| *n == child);
if let Some(start) = cycle_start {
for (path_node, _) in &stack[start..] {
cycled.insert((*path_node).to_owned());
}
}
cycled.insert(child.to_owned());
}
State::Unvisited => {
state.insert(child, State::InProgress);
stack.push((child, 0));
}
State::Done => {}
}
}
}
cycled
}
/// Selects a per-query subset of tools (see [`ToolSchemaFilter::filter`]).
///
/// Selection uses the always-on list, name mentions, description length and embedding
/// similarity.
pub struct ToolSchemaFilter {
    /// Tool ids included whenever they are offered.
    always_on: HashSet<String>,
    /// How many tools similarity ranking may admit; `0` means no limit.
    top_k: usize,
    /// Tools whose description has fewer words than this are included unconditionally.
    min_description_words: usize,
    /// Precomputed tool embeddings compared against the query embedding.
    embeddings: Vec<ToolEmbedding>,
    /// Incremented every time `recompute` replaces `embeddings`.
    version: u64,
}
impl ToolSchemaFilter {
    /// Creates a filter.
    ///
    /// * `always_on` — tool ids included whenever they are offered.
    /// * `top_k` — how many tools to admit by embedding similarity; `0` admits all.
    /// * `min_description_words` — tools whose description has fewer words are
    ///   always included.
    /// * `embeddings` — precomputed tool embeddings. [`Self::version`] starts at 0.
    #[must_use]
    pub fn new(
        always_on: Vec<String>,
        top_k: usize,
        min_description_words: usize,
        embeddings: Vec<ToolEmbedding>,
    ) -> Self {
        Self {
            always_on: always_on.into_iter().collect(),
            top_k,
            min_description_words,
            embeddings,
            version: 0,
        }
    }

    /// Number of times [`Self::recompute`] has replaced the embeddings.
    #[must_use]
    pub fn version(&self) -> u64 {
        self.version
    }

    /// Number of stored tool embeddings.
    #[must_use]
    pub fn embedding_count(&self) -> usize {
        self.embeddings.len()
    }

    /// Configured similarity cutoff (`0` means unlimited).
    #[must_use]
    pub fn top_k(&self) -> usize {
        self.top_k
    }

    /// Number of distinct always-on tool ids.
    #[must_use]
    pub fn always_on_count(&self) -> usize {
        self.always_on.len()
    }

    /// Replaces the stored embeddings and increments [`Self::version`].
    pub fn recompute(&mut self, embeddings: Vec<ToolEmbedding>) {
        self.embeddings = embeddings;
        self.version += 1;
    }

    /// Selects the tools to include for `query`.
    ///
    /// Rules run in this order, and each tool keeps the first reason that admitted it:
    /// 1. ids in the always-on set ([`InclusionReason::AlwaysOn`]);
    /// 2. ids mentioned by name in `query` ([`InclusionReason::NameMentioned`]);
    /// 3. ids whose entry in `tool_descriptions` has fewer than `min_description_words`
    ///    whitespace-separated words ([`InclusionReason::ShortDescription`]);
    /// 4. the `top_k` remaining embedded tools, ranked by cosine similarity to
    ///    `query_embedding`, NaN scores last ([`InclusionReason::SimilarityRank`]);
    /// 5. offered tools with no stored embedding ([`InclusionReason::NoEmbedding`]).
    ///
    /// Only ids in `all_tool_ids` can be included. Embeddings or descriptions for tools
    /// not offered this turn (e.g. stale entries) are ignored, so they cannot take
    /// `top_k` slots. Every offered id that is not included is listed in `excluded`.
    #[must_use]
    pub fn filter(
        &self,
        all_tool_ids: &[&str],
        tool_descriptions: &[(&str, &str)],
        query: &str,
        query_embedding: &[f32],
    ) -> ToolFilterResult {
        let available: HashSet<&str> = all_tool_ids.iter().copied().collect();
        let mut included = HashSet::new();
        let mut inclusion_reasons = Vec::new();

        // 1. Always-on tools (the insert guard avoids duplicate reasons for repeated ids).
        for id in all_tool_ids {
            if self.always_on.contains(*id) && included.insert((*id).to_owned()) {
                inclusion_reasons.push(((*id).to_owned(), InclusionReason::AlwaysOn));
            }
        }

        // 2. Tools mentioned by name in the query.
        for id in find_mentioned_tool_ids(query, all_tool_ids) {
            if included.insert(id.clone()) {
                inclusion_reasons.push((id, InclusionReason::NameMentioned));
            }
        }

        // 3. Tools whose description is too short to rank reliably.
        for &(id, desc) in tool_descriptions {
            if !available.contains(id) {
                continue;
            }
            let word_count = desc.split_whitespace().count();
            if word_count < self.min_description_words && included.insert(id.to_owned()) {
                inclusion_reasons.push((id.to_owned(), InclusionReason::ShortDescription));
            }
        }

        // 4. Similarity ranking over offered, not-yet-included embedded tools.
        let mut scores: Vec<(String, f32)> = self
            .embeddings
            .iter()
            .filter(|e| {
                let id = e.tool_id.as_str();
                available.contains(id) && !included.contains(id)
            })
            .map(|e| {
                let score = cosine_similarity(query_embedding, &e.embedding);
                (e.tool_id.to_string(), score)
            })
            .collect();
        // Best-first with NaN last; a total order, which std's sort requires.
        scores.sort_by(|a, b| {
            a.1.is_nan()
                .cmp(&b.1.is_nan())
                .then_with(|| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal))
        });
        let take = if self.top_k == 0 {
            scores.len()
        } else {
            self.top_k.min(scores.len())
        };
        for (id, _score) in scores.iter().take(take) {
            if included.insert(id.clone()) {
                inclusion_reasons.push((id.clone(), InclusionReason::SimilarityRank));
            }
        }

        // 5. Tools that cannot be ranked because they have no embedding.
        let embedded_ids: HashSet<&str> =
            self.embeddings.iter().map(|e| e.tool_id.as_str()).collect();
        for id in all_tool_ids {
            if !embedded_ids.contains(*id) && included.insert((*id).to_owned()) {
                inclusion_reasons.push(((*id).to_owned(), InclusionReason::NoEmbedding));
            }
        }

        let excluded: Vec<String> = all_tool_ids
            .iter()
            .filter(|id| !included.contains(**id))
            .map(|id| (*id).to_owned())
            .collect();
        ToolFilterResult {
            included,
            excluded,
            scores,
            inclusion_reasons,
            dependency_exclusions: Vec::new(),
        }
    }
}
/// Returns the ids from `all_tool_ids` that appear in `query` as whole words.
///
/// Matching is case-insensitive: both sides are lowercased with `str::to_lowercase`.
/// A match counts only when the characters right before and after it are not word
/// characters (Unicode alphanumerics or `_`). So `read` matches "please read it" but
/// not "thread" or "caféread". Empty ids never match. The result keeps the order of
/// `all_tool_ids`.
#[must_use]
pub fn find_mentioned_tool_ids(query: &str, all_tool_ids: &[&str]) -> Vec<String> {
    let query_lower = query.to_lowercase();
    all_tool_ids
        .iter()
        .filter(|id| contains_word(&query_lower, &id.to_lowercase()))
        .map(|id| (*id).to_owned())
        .collect()
}

/// `true` for characters that continue a word: Unicode alphanumerics and `_`.
fn is_word_char(c: char) -> bool {
    c.is_alphanumeric() || c == '_'
}

/// Returns `true` when `needle` occurs in `haystack` bounded by non-word characters or
/// the string edges.
///
/// Overlapping occurrences are considered. An empty `needle` never matches.
fn contains_word(haystack: &str, needle: &str) -> bool {
    if needle.is_empty() {
        return false;
    }
    let mut start = 0;
    while let Some(offset) = haystack[start..].find(needle) {
        let pos = start + offset;
        let end = pos + needle.len();
        let before_ok = !haystack[..pos].chars().next_back().is_some_and(is_word_char);
        let after_ok = !haystack[end..].chars().next().is_some_and(is_word_char);
        if before_ok && after_ok {
            return true;
        }
        // Advance by one whole character. Advancing one byte would slice inside a
        // multi-byte character and panic.
        start = pos + haystack[pos..].chars().next().map_or(1, char::len_utf8);
    }
    false
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Filter with five embedded tools in a 3-d space.
    ///
    /// Tools such as `bash` or `read` have no embedding, so they are included only via
    /// always-on, a name mention, or the `NoEmbedding` fallback.
    fn make_filter(always_on: Vec<&str>, top_k: usize) -> ToolSchemaFilter {
        ToolSchemaFilter::new(
            always_on.into_iter().map(String::from).collect(),
            top_k,
            5,
            vec![
                ToolEmbedding {
                    tool_id: "grep".into(),
                    embedding: vec![0.9, 0.1, 0.0],
                },
                ToolEmbedding {
                    tool_id: "write".into(),
                    embedding: vec![0.1, 0.9, 0.0],
                },
                ToolEmbedding {
                    tool_id: "find_path".into(),
                    embedding: vec![0.5, 0.5, 0.0],
                },
                ToolEmbedding {
                    tool_id: "web_scrape".into(),
                    embedding: vec![0.0, 0.0, 1.0],
                },
                ToolEmbedding {
                    tool_id: "diagnostics".into(),
                    embedding: vec![0.0, 0.1, 0.9],
                },
            ],
        )
    }

    #[test]
    fn top_k_ranking_selects_most_similar() {
        let filter = make_filter(vec!["bash"], 2);
        let all_ids: Vec<&str> = vec![
            "bash",
            "grep",
            "write",
            "find_path",
            "web_scrape",
            "diagnostics",
        ];
        // Closest embeddings to this query: grep, then find_path.
        let query_emb = vec![0.8, 0.2, 0.0];
        let result = filter.filter(&all_ids, &[], "search for pattern", &query_emb);
        assert!(result.included.contains("bash"));
        assert!(result.included.contains("grep"));
        assert!(result.included.contains("find_path"));
        assert!(!result.included.contains("web_scrape"));
        assert!(!result.included.contains("diagnostics"));
    }

    #[test]
    fn always_on_tools_always_included() {
        let filter = make_filter(vec!["bash", "read"], 1);
        let all_ids: Vec<&str> = vec!["bash", "read", "grep", "write"];
        // Points at `write`; with top_k = 1 only `write` is ranked in.
        let query_emb = vec![0.0, 1.0, 0.0];
        let result = filter.filter(&all_ids, &[], "test query", &query_emb);
        assert!(result.included.contains("bash"));
        assert!(result.included.contains("read"));
        assert!(result.included.contains("write"));
        assert!(!result.included.contains("grep"));
    }

    #[test]
    fn name_mention_force_includes() {
        let filter = make_filter(vec!["bash"], 1);
        let all_ids: Vec<&str> = vec!["bash", "grep", "web_scrape", "write"];
        let query_emb = vec![0.0, 1.0, 0.0];
        let result = filter.filter(&all_ids, &[], "use web_scrape to fetch", &query_emb);
        // Included by name despite a poor similarity score.
        assert!(result.included.contains("web_scrape"));
        // Still the top-1 similarity pick.
        assert!(result.included.contains("write"));
        assert!(result.included.contains("bash"));
    }

    #[test]
    fn short_mcp_description_auto_included() {
        let filter = make_filter(vec!["bash"], 1);
        let all_ids: Vec<&str> = vec!["bash", "grep", "mcp_query"];
        let descriptions: Vec<(&str, &str)> = vec![
            ("mcp_query", "Run query"),
            ("grep", "Search file contents recursively"),
        ];
        let query_emb = vec![0.9, 0.1, 0.0];
        let result = filter.filter(&all_ids, &descriptions, "test", &query_emb);
        // "Run query" has 2 words, below min_description_words (5).
        assert!(result.included.contains("mcp_query"));
    }

    #[test]
    fn empty_embeddings_includes_all_via_no_embedding_fallback() {
        let filter = ToolSchemaFilter::new(vec!["bash".into()], 6, 5, vec![]);
        let all_ids: Vec<&str> = vec!["bash", "grep", "write"];
        let query_emb = vec![0.5, 0.5, 0.0];
        let result = filter.filter(&all_ids, &[], "test", &query_emb);
        assert!(result.included.contains("bash"));
        assert!(result.included.contains("grep"));
        assert!(result.included.contains("write"));
        assert!(result.excluded.is_empty());
    }

    #[test]
    fn top_k_zero_includes_all_filterable() {
        let filter = make_filter(vec!["bash"], 0);
        let all_ids: Vec<&str> = vec![
            "bash",
            "grep",
            "write",
            "find_path",
            "web_scrape",
            "diagnostics",
        ];
        let query_emb = vec![0.1, 0.1, 0.1];
        let result = filter.filter(&all_ids, &[], "test", &query_emb);
        // top_k = 0 means "no limit": bash + all five embedded tools.
        assert_eq!(result.included.len(), 6);
        assert!(result.excluded.is_empty());
    }

    #[test]
    fn top_k_exceeds_filterable_count_includes_all() {
        let filter = make_filter(vec!["bash"], 100);
        let all_ids: Vec<&str> = vec![
            "bash",
            "grep",
            "write",
            "find_path",
            "web_scrape",
            "diagnostics",
        ];
        let query_emb = vec![0.1, 0.1, 0.1];
        let result = filter.filter(&all_ids, &[], "test", &query_emb);
        assert_eq!(result.included.len(), 6);
    }

    #[test]
    fn accessors_return_configured_values() {
        let filter = make_filter(vec!["bash", "read"], 7);
        assert_eq!(filter.top_k(), 7);
        assert_eq!(filter.always_on_count(), 2);
        assert_eq!(filter.embedding_count(), 5);
    }

    #[test]
    fn version_counter_incremented_on_recompute() {
        let mut filter = make_filter(vec![], 3);
        assert_eq!(filter.version(), 0);
        filter.recompute(vec![]);
        assert_eq!(filter.version(), 1);
        filter.recompute(vec![]);
        assert_eq!(filter.version(), 2);
    }

    #[test]
    fn inclusion_reason_correctness() {
        let filter = make_filter(vec!["bash"], 1);
        let all_ids: Vec<&str> = vec!["bash", "grep", "web_scrape", "write"];
        // One-word description triggers ShortDescription before similarity ranking.
        let descriptions: Vec<(&str, &str)> = vec![("web_scrape", "Scrape")];
        // Points at `write`.
        let query_emb = vec![0.1, 0.9, 0.0];
        let result = filter.filter(&all_ids, &descriptions, "test query", &query_emb);
        let reasons: std::collections::HashMap<String, InclusionReason> =
            result.inclusion_reasons.into_iter().collect();
        assert_eq!(reasons.get("bash"), Some(&InclusionReason::AlwaysOn));
        assert_eq!(
            reasons.get("web_scrape"),
            Some(&InclusionReason::ShortDescription)
        );
        assert_eq!(reasons.get("write"), Some(&InclusionReason::SimilarityRank));
    }

    #[test]
    fn cosine_similarity_identical_vectors() {
        let v = vec![1.0, 2.0, 3.0];
        let sim = cosine_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 1e-5);
    }

    #[test]
    fn cosine_similarity_orthogonal_vectors() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-5);
    }

    #[test]
    fn cosine_similarity_empty_returns_zero() {
        assert!(cosine_similarity(&[], &[]) < f32::EPSILON);
    }

    #[test]
    fn cosine_similarity_mismatched_length_returns_zero() {
        assert!(cosine_similarity(&[1.0], &[1.0, 2.0]) < f32::EPSILON);
    }

    #[test]
    fn find_mentioned_tool_ids_case_insensitive() {
        let ids = vec!["web_scrape", "grep", "Bash"];
        let found = find_mentioned_tool_ids("use WEB_SCRAPE and BASH", &ids);
        assert!(found.contains(&"web_scrape".to_owned()));
        assert!(found.contains(&"Bash".to_owned()));
        assert!(!found.contains(&"grep".to_owned()));
    }

    #[test]
    fn find_mentioned_tool_ids_word_boundary_no_false_positives() {
        let ids = vec!["read", "edit", "fetch", "grep"];
        let found = find_mentioned_tool_ids("thread breadcrumb", &ids);
        assert!(found.is_empty());
    }

    #[test]
    fn find_mentioned_tool_ids_word_boundary_matches_standalone() {
        let ids = vec!["read", "edit"];
        let found = find_mentioned_tool_ids("please read and edit the file", &ids);
        assert!(found.contains(&"read".to_owned()));
        assert!(found.contains(&"edit".to_owned()));
    }

    /// Builds a graph from `(tool, requires, prefers)` rules.
    fn make_dep_graph(rules: &[(&str, Vec<&str>, Vec<&str>)]) -> ToolDependencyGraph {
        let deps = rules
            .iter()
            .map(|(id, requires, prefers)| {
                (
                    (*id).to_owned(),
                    crate::config::ToolDependency {
                        requires: requires.iter().map(|s| (*s).to_owned()).collect(),
                        prefers: prefers.iter().map(|s| (*s).to_owned()).collect(),
                    },
                )
            })
            .collect();
        ToolDependencyGraph::new(deps)
    }

    /// Set of completed tool ids.
    fn completed(ids: &[&str]) -> HashSet<String> {
        ids.iter().map(|s| (*s).to_owned()).collect()
    }

    #[test]
    fn requirements_met_no_deps() {
        let graph = make_dep_graph(&[]);
        assert!(graph.requirements_met("any_tool", &completed(&[])));
    }

    #[test]
    fn requirements_met_all_satisfied() {
        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
        assert!(graph.requirements_met("apply_patch", &completed(&["read"])));
    }

    #[test]
    fn requirements_met_unmet() {
        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
        assert!(!graph.requirements_met("apply_patch", &completed(&[])));
    }

    #[test]
    fn requirements_met_unconfigured_tool() {
        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
        assert!(graph.requirements_met("grep", &completed(&[])));
    }

    #[test]
    fn preference_boost_none_met() {
        let graph = make_dep_graph(&[("format", vec![], vec!["search", "grep"])]);
        let boost = graph.preference_boost("format", &completed(&[]), 0.15, 0.2);
        assert!(boost < f32::EPSILON);
    }

    #[test]
    fn preference_boost_partial() {
        let graph = make_dep_graph(&[("format", vec![], vec!["search", "grep"])]);
        let boost = graph.preference_boost("format", &completed(&["search"]), 0.15, 0.2);
        assert!((boost - 0.15).abs() < 1e-5);
    }

    #[test]
    fn preference_boost_capped_at_max() {
        let graph = make_dep_graph(&[("format", vec![], vec!["a", "b", "c"])]);
        let boost = graph.preference_boost("format", &completed(&["a", "b", "c"]), 0.15, 0.2);
        assert!((boost - 0.2).abs() < 1e-5);
    }

    #[test]
    fn cycle_detection_simple_cycle() {
        let graph = make_dep_graph(&[
            ("tool_a", vec!["tool_b"], vec![]),
            ("tool_b", vec!["tool_a"], vec![]),
        ]);
        // Both cycle participants have their requires cleared.
        assert!(graph.requirements_met("tool_a", &completed(&[])));
        assert!(graph.requirements_met("tool_b", &completed(&[])));
    }

    #[test]
    fn cycle_detection_does_not_affect_non_cycle_tools() {
        // Only tool_c <-> tool_d form a cycle; tool_a and tool_b merely lead into it.
        let graph = make_dep_graph(&[
            ("tool_a", vec!["tool_b"], vec![]),
            ("tool_b", vec!["tool_c"], vec![]),
            ("tool_c", vec!["tool_d"], vec![]),
            ("tool_d", vec!["tool_c"], vec![]),
        ]);
        assert!(graph.requirements_met("tool_c", &completed(&[])));
        assert!(graph.requirements_met("tool_d", &completed(&[])));
        assert!(!graph.requirements_met("tool_a", &completed(&[])));
        assert!(!graph.requirements_met("tool_b", &completed(&[])));
    }

    #[test]
    fn apply_excludes_gated_tool() {
        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
        let filter = make_filter(vec!["bash"], 5);
        let all_ids = vec!["bash", "read", "apply_patch", "grep"];
        let query_emb = vec![0.5, 0.5, 0.0];
        let mut result = filter.filter(&all_ids, &[], "test", &query_emb);
        result.included.insert("apply_patch".into());
        let always_on: HashSet<String> = ["bash".into()].into();
        graph.apply(&mut result, &completed(&[]), 0.15, 0.2, &always_on);
        assert!(!result.included.contains("apply_patch"));
        assert_eq!(result.dependency_exclusions.len(), 1);
        assert_eq!(result.dependency_exclusions[0].tool_id, "apply_patch");
        assert_eq!(result.dependency_exclusions[0].unmet_requires, vec!["read"]);
    }

    #[test]
    fn apply_includes_gated_tool_when_dep_met() {
        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
        let filter = make_filter(vec!["bash"], 5);
        let all_ids = vec!["bash", "read", "apply_patch"];
        let query_emb = vec![0.5, 0.5, 0.0];
        let mut result = filter.filter(&all_ids, &[], "test", &query_emb);
        result.included.insert("apply_patch".into());
        let always_on: HashSet<String> = ["bash".into()].into();
        graph.apply(&mut result, &completed(&["read"]), 0.15, 0.2, &always_on);
        assert!(result.included.contains("apply_patch"));
        assert!(result.dependency_exclusions.is_empty());
    }

    #[test]
    fn apply_deadlock_fallback_when_all_gated() {
        let filter = ToolSchemaFilter::new(
            vec!["bash".into()],
            5,
            5,
            // No embeddings: every non-always-on tool is included via `NoEmbedding`.
            vec![],
        );
        let graph = make_dep_graph(&[("only_tool", vec!["missing"], vec![])]);
        let all_ids = vec!["bash", "only_tool"];
        let query_emb = vec![0.5, 0.5, 0.0];
        let mut result = filter.filter(&all_ids, &[], "test", &query_emb);
        assert!(result.included.contains("only_tool"));
        assert!(result.included.contains("bash"));
        let always_on: HashSet<String> = ["bash".into()].into();
        graph.apply(&mut result, &completed(&[]), 0.15, 0.2, &always_on);
        // Gating only_tool would leave only always-on tools, so gates are lifted.
        assert!(result.included.contains("only_tool"));
        assert!(result.dependency_exclusions.is_empty());
    }

    #[test]
    fn apply_always_on_bypasses_gate() {
        let graph = make_dep_graph(&[("bash", vec!["nonexistent"], vec![])]);
        let filter = make_filter(vec!["bash"], 5);
        let all_ids = vec!["bash", "grep"];
        let query_emb = vec![0.5, 0.5, 0.0];
        let mut result = filter.filter(&all_ids, &[], "test", &query_emb);
        let always_on: HashSet<String> = ["bash".into()].into();
        graph.apply(&mut result, &completed(&[]), 0.15, 0.2, &always_on);
        assert!(result.included.contains("bash"));
    }

    #[test]
    fn cycle_detection_does_not_clear_ancestor_requires() {
        let graph = make_dep_graph(&[
            ("tool_a", vec!["tool_b"], vec![]),
            ("tool_b", vec!["tool_c"], vec![]),
            ("tool_c", vec!["tool_d"], vec![]),
            ("tool_d", vec!["tool_c"], vec![]),
        ]);
        assert!(graph.requirements_met("tool_c", &completed(&[])));
        assert!(graph.requirements_met("tool_d", &completed(&[])));
        assert!(!graph.requirements_met("tool_a", &completed(&[])));
        assert!(!graph.requirements_met("tool_b", &completed(&[])));
        // Ancestors keep their original requirements.
        assert!(graph.requirements_met("tool_b", &completed(&["tool_c"])));
        assert!(graph.requirements_met("tool_a", &completed(&["tool_b"])));
    }

    #[test]
    fn name_mentioned_does_not_bypass_hard_gate() {
        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
        let filter = make_filter(vec!["bash"], 5);
        let all_ids = vec!["bash", "read", "apply_patch"];
        let query_emb = vec![0.5, 0.5, 0.0];
        let mut result = filter.filter(&all_ids, &[], "use apply_patch to fix the bug", &query_emb);
        assert!(result.included.contains("apply_patch"));
        let reason = result
            .inclusion_reasons
            .iter()
            .find(|(id, _)| id == "apply_patch")
            .map(|(_, r)| r);
        assert_eq!(reason, Some(&InclusionReason::NameMentioned));
        let always_on: HashSet<String> = ["bash".into()].into();
        graph.apply(&mut result, &completed(&[]), 0.15, 0.2, &always_on);
        assert!(!result.included.contains("apply_patch"));
        assert_eq!(result.dependency_exclusions.len(), 1);
        assert_eq!(result.dependency_exclusions[0].tool_id, "apply_patch");
    }

    #[test]
    fn multi_turn_chain_two_steps() {
        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
        let always_on: HashSet<String> = ["bash".into()].into();
        let filter = ToolSchemaFilter::new(vec!["bash".into()], 5, 5, vec![]);
        let all_ids = vec!["bash", "read", "apply_patch"];
        let q = vec![0.5, 0.5, 0.0];
        let mut result = filter.filter(&all_ids, &[], "fix bug", &q);
        graph.apply(&mut result, &completed(&[]), 0.15, 0.2, &always_on);
        assert!(!result.included.contains("apply_patch"));
        assert_eq!(result.dependency_exclusions.len(), 1);
        let mut result2 = filter.filter(&all_ids, &[], "fix bug", &q);
        graph.apply(&mut result2, &completed(&["read"]), 0.15, 0.2, &always_on);
        assert!(result2.included.contains("apply_patch"));
        assert!(result2.dependency_exclusions.is_empty());
    }

    #[test]
    fn multi_turn_chain_three_steps() {
        let graph = make_dep_graph(&[
            ("search", vec!["read"], vec![]),
            ("apply_patch", vec!["search"], vec![]),
        ]);
        let always_on: HashSet<String> = ["bash".into()].into();
        let filter = ToolSchemaFilter::new(vec!["bash".into()], 5, 5, vec![]);
        let all_ids = vec!["bash", "read", "search", "apply_patch"];
        let q = vec![0.5, 0.5, 0.0];
        let mut r1 = filter.filter(&all_ids, &[], "q", &q);
        graph.apply(&mut r1, &completed(&[]), 0.15, 0.2, &always_on);
        assert!(r1.included.contains("read"));
        assert!(!r1.included.contains("search"));
        assert!(!r1.included.contains("apply_patch"));
        let mut r2 = filter.filter(&all_ids, &[], "q", &q);
        graph.apply(&mut r2, &completed(&["read"]), 0.15, 0.2, &always_on);
        assert!(r2.included.contains("search"));
        assert!(!r2.included.contains("apply_patch"));
        let mut r3 = filter.filter(&all_ids, &[], "q", &q);
        graph.apply(
            &mut r3,
            &completed(&["read", "search"]),
            0.15,
            0.2,
            &always_on,
        );
        assert!(r3.included.contains("apply_patch"));
        assert!(r3.dependency_exclusions.is_empty());
    }

    #[test]
    fn multi_turn_multi_requires_both_must_complete() {
        let graph = make_dep_graph(&[("apply_patch", vec!["read", "search"], vec![])]);
        let always_on: HashSet<String> = ["bash".into()].into();
        let filter = ToolSchemaFilter::new(vec!["bash".into()], 5, 5, vec![]);
        let all_ids = vec!["bash", "read", "search", "apply_patch"];
        let q = vec![0.5, 0.5, 0.0];
        let mut r1 = filter.filter(&all_ids, &[], "q", &q);
        graph.apply(&mut r1, &completed(&["read"]), 0.15, 0.2, &always_on);
        assert!(!r1.included.contains("apply_patch"));
        let excl = &r1.dependency_exclusions[0];
        assert_eq!(excl.unmet_requires, vec!["search"]);
        let mut r2 = filter.filter(&all_ids, &[], "q", &q);
        graph.apply(
            &mut r2,
            &completed(&["read", "search"]),
            0.15,
            0.2,
            &always_on,
        );
        assert!(r2.included.contains("apply_patch"));
        assert!(r2.dependency_exclusions.is_empty());
    }

    #[test]
    fn multi_turn_preference_boost_accumulates() {
        let graph = make_dep_graph(&[("format", vec![], vec!["search", "grep"])]);
        let always_on: HashSet<String> = HashSet::new();
        let filter = ToolSchemaFilter::new(
            vec![],
            5,
            5,
            vec![
                ToolEmbedding {
                    tool_id: "format".into(),
                    embedding: vec![0.6, 0.4, 0.0],
                },
                ToolEmbedding {
                    tool_id: "search".into(),
                    embedding: vec![0.7, 0.3, 0.0],
                },
                ToolEmbedding {
                    tool_id: "grep".into(),
                    embedding: vec![0.8, 0.2, 0.0],
                },
            ],
        );
        let all_ids = vec!["format", "search", "grep"];
        let q = vec![0.5, 0.5, 0.0];
        let boost_per = 0.15_f32;
        let max_boost = 0.3_f32;
        let score_of = |result: &ToolFilterResult, id: &str| -> f32 {
            result
                .scores
                .iter()
                .find(|(tid, _)| tid == id)
                .map_or(0.0, |(_, s)| *s)
        };

        // Turn 1: nothing completed, no boost.
        let mut r1 = filter.filter(&all_ids, &[], "q", &q);
        let base_score = score_of(&r1, "format");
        graph.apply(&mut r1, &completed(&[]), boost_per, max_boost, &always_on);
        assert!((score_of(&r1, "format") - base_score).abs() < 1e-5);

        // Turn 2: one preferred tool completed.
        let mut r2 = filter.filter(&all_ids, &[], "q", &q);
        graph.apply(
            &mut r2,
            &completed(&["search"]),
            boost_per,
            max_boost,
            &always_on,
        );
        let delta2 = score_of(&r2, "format") - base_score;
        assert!(
            (delta2 - 0.15).abs() < 1e-4,
            "expected +0.15 boost, got {delta2}"
        );

        // Turn 3: both preferred tools completed (exactly at the cap).
        let mut r3 = filter.filter(&all_ids, &[], "q", &q);
        graph.apply(
            &mut r3,
            &completed(&["search", "grep"]),
            boost_per,
            max_boost,
            &always_on,
        );
        let delta3 = score_of(&r3, "format") - base_score;
        assert!(
            (delta3 - 0.30).abs() < 1e-4,
            "expected +0.30 boost, got {delta3}"
        );
    }

    #[test]
    fn filter_tool_names_multi_turn_unlocks_after_completion() {
        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
        let always_on: HashSet<String> = ["bash".into()].into();
        let all_names = vec!["bash", "read", "apply_patch"];
        let filtered_before = graph.filter_tool_names(&all_names, &completed(&[]), &always_on);
        assert!(filtered_before.contains(&"bash"));
        assert!(filtered_before.contains(&"read"));
        assert!(!filtered_before.contains(&"apply_patch"));
        let filtered_after = graph.filter_tool_names(&all_names, &completed(&["read"]), &always_on);
        assert!(filtered_after.contains(&"bash"));
        assert!(filtered_after.contains(&"read"));
        assert!(filtered_after.contains(&"apply_patch"));
    }

    /// Unlike `apply`, `filter_tool_names` has no deadlock fallback: a gated tool stays
    /// filtered out even when only always-on tools remain.
    #[test]
    fn filter_tool_names_has_no_deadlock_fallback() {
        let graph = make_dep_graph(&[("only_tool", vec!["missing"], vec![])]);
        let always_on: HashSet<String> = ["bash".into()].into();
        let all_names = vec!["bash", "only_tool"];
        let filtered = graph.filter_tool_names(&all_names, &completed(&[]), &always_on);
        assert!(filtered.contains(&"bash"));
        assert!(!filtered.contains(&"only_tool"));
    }
}