use futures::stream::{self, StreamExt};
use zeph_common::math::cosine_similarity;
use crate::tool::McpTool;
/// Errors that can occur while building a [`SemanticToolIndex`].
#[derive(Debug, thiserror::Error)]
pub enum SemanticIndexError {
    /// The tool list was non-empty but not a single tool could be embedded.
    #[error("all {count} tool embeddings failed during index build")]
    AllEmbeddingsFailed {
        /// How many embedding requests failed.
        count: usize,
    },
}
/// A successfully embedded tool stored in the index.
struct ToolEntry {
    /// Embedding computed from the tool's name and sanitized description.
    embedding: Vec<f32>,
    /// The tool definition exactly as it was passed to the index builder.
    tool: McpTool,
}
/// Index that ranks MCP tools by cosine similarity against a query embedding.
#[derive(Debug)]
pub struct SemanticToolIndex {
    /// Tools whose embedding succeeded; only these take part in similarity ranking.
    entries: Vec<ToolEntry>,
    /// Every tool supplied at build time, including ones whose embedding failed,
    /// so that `always_include` pinning can still find them.
    all_tools: Vec<McpTool>,
}
/// Compact debug output: prints the tool name and embedding length instead of the
/// full tool definition and the raw vector.
impl std::fmt::Debug for ToolEntry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("ToolEntry");
        out.field("tool_name", &self.tool.name);
        out.field("embedding_dim", &self.embedding.len());
        out.finish()
    }
}
impl SemanticToolIndex {
    /// Maximum number of embedding requests kept in flight by [`Self::build`].
    const EMBED_CONCURRENCY: usize = 8;

    /// Maximum number of description characters (counted after control characters
    /// are removed) included in the text embedded for each tool.
    const MAX_DESCRIPTION_CHARS: usize = 200;

    /// Builds an index by embedding every tool in `tools` with `embed_fn`.
    ///
    /// Each tool is embedded from the text `"{name}: {description}"`. Control
    /// characters are stripped from the description, which is then capped at 200
    /// characters. Up to 8 embedding requests run concurrently.
    ///
    /// Results are consumed in input order, so the index layout is deterministic. As
    /// a result, score ties in [`Self::select`] are broken by the order of `tools`
    /// rather than by which request happened to finish first.
    ///
    /// Tools whose embedding fails are logged and left out of similarity ranking.
    /// They can still be returned by [`Self::select`] through `always_include`.
    ///
    /// # Errors
    ///
    /// Returns [`SemanticIndexError::AllEmbeddingsFailed`] when `tools` is non-empty
    /// and every embedding request fails. An empty `tools` slice yields an empty
    /// index, not an error.
    pub async fn build<F>(tools: &[McpTool], embed_fn: &F) -> Result<Self, SemanticIndexError>
    where
        F: Fn(&str) -> zeph_llm::provider::EmbedFuture + Send + Sync,
    {
        if tools.is_empty() {
            return Ok(Self {
                entries: Vec::new(),
                all_tools: Vec::new(),
            });
        }

        // `buffered` keeps up to EMBED_CONCURRENCY requests in flight, like
        // `buffer_unordered`, but yields results in input order. That keeps
        // `entries` aligned with `tools` and makes tie-breaking in `select`
        // reproducible across runs.
        let results: Vec<(usize, Result<Vec<f32>, _>)> = stream::iter(tools.iter().enumerate())
            .map(|(idx, tool)| {
                let text = Self::embedding_text(tool);
                let fut = embed_fn(&text);
                async move { (idx, fut.await) }
            })
            .buffered(Self::EMBED_CONCURRENCY)
            .collect()
            .await;

        let mut entries = Vec::with_capacity(tools.len());
        let mut failed = 0usize;
        for (idx, result) in results {
            let tool = &tools[idx];
            match result {
                Ok(embedding) => entries.push(ToolEntry {
                    tool: tool.clone(),
                    embedding,
                }),
                Err(e) => {
                    failed += 1;
                    tracing::warn!(
                        tool_name = %tool.name,
                        server_id = %tool.server_id,
                        "semantic index: embedding failed for tool, excluded from similarity ranking: {e:#}"
                    );
                }
            }
        }

        if entries.is_empty() {
            return Err(SemanticIndexError::AllEmbeddingsFailed { count: failed });
        }
        if failed > 0 {
            tracing::warn!(
                total = tools.len(),
                failed,
                indexed = entries.len(),
                "semantic index: some tools failed to embed"
            );
        }
        Ok(Self {
            entries,
            all_tools: tools.to_vec(),
        })
    }

    /// Builds the text embedded for `tool`.
    ///
    /// The text is the tool name, a colon, and the description with control
    /// characters removed and capped at `MAX_DESCRIPTION_CHARS` characters.
    fn embedding_text(tool: &McpTool) -> String {
        let sanitized_desc: String = tool
            .description
            .chars()
            .filter(|c| !c.is_control())
            .take(Self::MAX_DESCRIPTION_CHARS)
            .collect();
        format!("{}: {}", tool.name, sanitized_desc)
    }

    /// Selects the tools most relevant to `query_embedding`.
    ///
    /// The result is built in two parts:
    /// 1. Every build-time tool whose name appears in `always_include`, in build
    ///    order. This includes tools whose embedding failed.
    /// 2. Up to `top_k` other indexed tools whose cosine similarity to the query is
    ///    at least `min_similarity`, highest score first. Ties keep build order.
    ///
    /// Pinned tools do not count toward `top_k` and are never duplicated.
    ///
    /// Special cases:
    /// - Indexed tools whose embedding dimension differs from the query are skipped
    ///   with a warning.
    /// - Scores that are NaN fail the threshold comparison and are dropped.
    /// - An empty `query_embedding` or `top_k == 0` returns only the pinned tools.
    pub fn select(
        &self,
        query_embedding: &[f32],
        top_k: usize,
        min_similarity: f32,
        always_include: &[String],
    ) -> Vec<McpTool> {
        let is_pinned = |name: &String| always_include.iter().any(|a| a == name);

        let mut pinned: Vec<McpTool> = self
            .all_tools
            .iter()
            .filter(|t| is_pinned(&t.name))
            .cloned()
            .collect();
        if query_embedding.is_empty() || top_k == 0 {
            return pinned;
        }

        let query_dim = query_embedding.len();
        let mut scored: Vec<(f32, &McpTool)> = self
            .entries
            .iter()
            .filter(|e| {
                if e.embedding.len() == query_dim {
                    true
                } else {
                    tracing::warn!(
                        tool_name = %e.tool.name,
                        entry_dim = e.embedding.len(),
                        query_dim,
                        "semantic index: dimension mismatch, skipping tool"
                    );
                    false
                }
            })
            .filter(|e| !is_pinned(&e.tool.name))
            .map(|e| (cosine_similarity(query_embedding, &e.embedding), &e.tool))
            // NaN >= x is always false, so NaN scores never survive this filter.
            .filter(|(score, _)| *score >= min_similarity)
            .collect();

        // NaN has been filtered out above, so `total_cmp` gives a total order here.
        // `sort_by` is stable, so equal scores keep build order (see `build`).
        scored.sort_by(|a, b| b.0.total_cmp(&a.0));
        scored.truncate(top_k);
        for (score, tool) in &scored {
            tracing::debug!(tool_name = %tool.name, score, "semantic tool selection score");
        }
        pinned.extend(scored.into_iter().map(|(_, t)| t.clone()));
        pinned
    }

    /// Number of tools with a usable embedding.
    ///
    /// Tools whose embedding failed during [`Self::build`] are not counted.
    #[must_use]
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` when no tool has a usable embedding.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
}
/// Strategy used to narrow the set of tools offered for a request.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ToolDiscoveryStrategy {
    /// Embedding-based discovery.
    Embedding,
    /// LLM-based discovery.
    Llm,
    /// No discovery strategy; this is the default.
    #[default]
    None,
}
/// Tuning parameters for tool discovery.
///
/// NOTE(review): the first four fields mirror the arguments of
/// `SemanticToolIndex::select` plus a size threshold. Confirm how callers forward
/// them.
#[derive(Debug, Clone)]
pub struct DiscoveryParams {
    /// Maximum number of ranked tools to select. Defaults to `10`.
    pub top_k: usize,
    /// Minimum similarity score required for selection. Defaults to `0.2`.
    pub min_similarity: f32,
    /// Tool-count threshold for discovery filtering. Defaults to `10`.
    /// NOTE(review): presumably filtering is skipped below this count; confirm at
    /// the call site.
    pub min_tools_to_filter: usize,
    /// Tool names that are always included. Defaults to empty.
    pub always_include: Vec<String>,
    /// Strict-mode flag. Defaults to `false`.
    /// NOTE(review): semantics are defined by the caller; confirm.
    pub strict: bool,
}

impl Default for DiscoveryParams {
    fn default() -> Self {
        const DEFAULT_TOP_K: usize = 10;
        const DEFAULT_MIN_SIMILARITY: f32 = 0.2;
        const DEFAULT_MIN_TOOLS_TO_FILTER: usize = 10;

        Self {
            top_k: DEFAULT_TOP_K,
            min_similarity: DEFAULT_MIN_SIMILARITY,
            min_tools_to_filter: DEFAULT_MIN_TOOLS_TO_FILTER,
            always_include: Vec::new(),
            strict: false,
        }
    }
}
#[cfg(test)]
mod tests {
    use zeph_llm::provider::EmbedFn;

    use super::*;

    /// Builds a minimal tool on server `"srv"` with the given name and description.
    fn make_tool(name: &str, desc: &str) -> McpTool {
        McpTool {
            server_id: "srv".into(),
            name: name.into(),
            description: desc.into(),
            input_schema: serde_json::Value::Null,
            security_meta: crate::tool::ToolSecurityMeta::default(),
        }
    }

    /// Embedder that always succeeds with a 3-d vector derived from the first byte
    /// of the input text.
    fn fixed_embed() -> EmbedFn {
        Box::new(|text: &str| -> zeph_llm::provider::EmbedFuture {
            let first = f32::from(text.bytes().next().unwrap_or(b'a'));
            let v = vec![first / 100.0, 1.0, 1.0];
            Box::pin(async move { Ok(v) })
        })
    }

    /// Embedder that always fails.
    fn failing_embed() -> EmbedFn {
        Box::new(|_text: &str| -> zeph_llm::provider::EmbedFuture {
            Box::pin(async move { Err(zeph_llm::LlmError::Other("forced failure".into())) })
        })
    }

    #[tokio::test]
    async fn build_empty_tools_returns_empty_index() {
        let embed = fixed_embed();
        let idx = SemanticToolIndex::build(&[], &embed).await.unwrap();
        assert!(idx.is_empty());
        assert_eq!(idx.len(), 0);
    }

    #[tokio::test]
    async fn build_all_fail_returns_error() {
        let tools = vec![make_tool("a", "desc")];
        let embed = failing_embed();
        let err = SemanticToolIndex::build(&tools, &embed).await.unwrap_err();
        assert!(matches!(
            err,
            SemanticIndexError::AllEmbeddingsFailed { count: 1 }
        ));
    }

    #[tokio::test]
    async fn build_partial_failure_returns_partial_index() {
        let tools = vec![make_tool("aaa", "desc a"), make_tool("bbb", "desc b")];
        let call_count = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let cc = call_count.clone();
        // The first embed call succeeds; every later call fails.
        let embed: EmbedFn = Box::new(move |_text: &str| -> zeph_llm::provider::EmbedFuture {
            let n = cc.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            Box::pin(async move {
                if n == 0 {
                    Ok(vec![1.0, 0.0, 0.0])
                } else {
                    Err(zeph_llm::LlmError::Other("fail".into()))
                }
            })
        });
        let idx = SemanticToolIndex::build(&tools, &embed).await.unwrap();
        assert_eq!(
            idx.len(),
            1,
            "one tool indexed despite second embedding failure"
        );
    }

    #[tokio::test]
    async fn select_returns_top_k() {
        let tools: Vec<McpTool> = (0..5).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        let embed: EmbedFn = Box::new(|text: &str| -> zeph_llm::provider::EmbedFuture {
            // Test input lengths are tiny, so the usize -> f32 cast is exact.
            #[allow(clippy::cast_precision_loss)]
            let v = vec![text.len() as f32 / 10.0, 1.0];
            Box::pin(async move { Ok(v) })
        });
        let idx = SemanticToolIndex::build(&tools, &embed).await.unwrap();
        let query = vec![1.0, 1.0];
        let result = idx.select(&query, 3, 0.0, &[]);
        assert_eq!(result.len(), 3);
    }

    #[tokio::test]
    async fn select_always_include_from_failed_tools() {
        let tools = vec![
            make_tool("pinned", "always here"),
            make_tool("normal", "normal desc"),
        ];
        let call_count = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let cc = call_count.clone();
        // The first embed call (the "pinned" tool) fails; later calls succeed.
        let embed: EmbedFn = Box::new(move |_text: &str| -> zeph_llm::provider::EmbedFuture {
            let n = cc.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            Box::pin(async move {
                if n == 0 {
                    Err(zeph_llm::LlmError::Other("fail".into()))
                } else {
                    Ok(vec![1.0, 0.0])
                }
            })
        });
        let idx = SemanticToolIndex::build(&tools, &embed).await.unwrap();
        let query = vec![1.0, 0.0];
        let result = idx.select(&query, 10, 0.0, &["pinned".to_string()]);
        assert!(
            result.iter().any(|t| t.name == "pinned"),
            "always_include must include failed-to-embed tools"
        );
    }

    #[tokio::test]
    async fn select_min_similarity_filters_low_scores() {
        let tools = vec![make_tool("t0", "x"), make_tool("t1", "y")];
        let embed: EmbedFn = Box::new(|text: &str| -> zeph_llm::provider::EmbedFuture {
            let v = if text.starts_with("t0") {
                vec![1.0_f32, 0.0]
            } else {
                vec![0.0_f32, 1.0]
            };
            Box::pin(async move { Ok(v) })
        });
        let idx = SemanticToolIndex::build(&tools, &embed).await.unwrap();
        let query = vec![1.0_f32, 0.0];
        let result = idx.select(&query, 10, 0.5, &[]);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].name, "t0");
    }

    #[tokio::test]
    async fn select_orders_by_descending_similarity() {
        // Tools are deliberately supplied out of score order.
        let tools = vec![
            make_tool("t2", "d"),
            make_tool("t0", "d"),
            make_tool("t1", "d"),
        ];
        let embed: EmbedFn = Box::new(|text: &str| -> zeph_llm::provider::EmbedFuture {
            let v = if text.starts_with("t0") {
                vec![1.0_f32, 0.0]
            } else if text.starts_with("t1") {
                vec![1.0_f32, 1.0]
            } else {
                vec![0.0_f32, 1.0]
            };
            Box::pin(async move { Ok(v) })
        });
        let idx = SemanticToolIndex::build(&tools, &embed).await.unwrap();
        let result = idx.select(&[1.0, 0.0], 10, -1.0, &[]);
        let names: Vec<&str> = result.iter().map(|t| t.name.as_str()).collect();
        assert_eq!(names, ["t0", "t1", "t2"], "results must be sorted by score");
    }

    #[tokio::test]
    async fn select_places_pinned_before_ranked_and_excludes_them_from_top_k() {
        let tools = vec![make_tool("high", "d"), make_tool("low", "d")];
        let embed: EmbedFn = Box::new(|text: &str| -> zeph_llm::provider::EmbedFuture {
            let v = if text.starts_with("high") {
                vec![1.0_f32, 0.0]
            } else {
                vec![0.0_f32, 1.0]
            };
            Box::pin(async move { Ok(v) })
        });
        let idx = SemanticToolIndex::build(&tools, &embed).await.unwrap();
        // top_k = 1 still yields two tools: the pinned one does not consume a slot.
        let result = idx.select(&[1.0, 0.0], 1, -1.0, &["low".to_string()]);
        let names: Vec<&str> = result.iter().map(|t| t.name.as_str()).collect();
        assert_eq!(names, ["low", "high"]);
    }

    #[tokio::test]
    async fn select_dimension_mismatch_skips_entry() {
        let tools = vec![make_tool("t0", "d")];
        let embed: EmbedFn = Box::new(|_text: &str| -> zeph_llm::provider::EmbedFuture {
            Box::pin(async move { Ok(vec![1.0, 0.0, 0.0]) })
        });
        let idx = SemanticToolIndex::build(&tools, &embed).await.unwrap();
        let result = idx.select(&[1.0, 0.0], 10, 0.0, &[]);
        assert!(result.is_empty(), "dimension mismatch must skip entry");
    }

    #[tokio::test]
    async fn select_top_k_exceeds_available_tools() {
        let tools: Vec<McpTool> = (0..3).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        let embed: EmbedFn = Box::new(|_text: &str| -> zeph_llm::provider::EmbedFuture {
            Box::pin(async move { Ok(vec![1.0, 0.0]) })
        });
        let idx = SemanticToolIndex::build(&tools, &embed).await.unwrap();
        let result = idx.select(&[1.0, 0.0], 100, 0.0, &[]);
        assert_eq!(
            result.len(),
            3,
            "top_k > available tools must return all tools"
        );
    }

    #[tokio::test]
    async fn select_top_k_zero_returns_empty() {
        let tools: Vec<McpTool> = (0..3).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        let embed: EmbedFn = Box::new(|_text: &str| -> zeph_llm::provider::EmbedFuture {
            Box::pin(async move { Ok(vec![1.0, 0.0]) })
        });
        let idx = SemanticToolIndex::build(&tools, &embed).await.unwrap();
        let result = idx.select(&[1.0, 0.0], 0, 0.0, &[]);
        assert!(
            result.is_empty(),
            "top_k=0 with no always_include must return empty"
        );
    }

    #[tokio::test]
    async fn select_top_k_zero_returns_pinned() {
        let tools = vec![make_tool("pinned", "always"), make_tool("other", "other")];
        let embed: EmbedFn = Box::new(|_text: &str| -> zeph_llm::provider::EmbedFuture {
            Box::pin(async move { Ok(vec![1.0, 0.0]) })
        });
        let idx = SemanticToolIndex::build(&tools, &embed).await.unwrap();
        let result = idx.select(&[1.0, 0.0], 0, 0.0, &["pinned".to_string()]);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].name, "pinned");
    }

    #[tokio::test]
    async fn select_empty_query_returns_only_pinned() {
        let tools = vec![make_tool("pinned", "always"), make_tool("other", "other")];
        let embed: EmbedFn = Box::new(|_text: &str| -> zeph_llm::provider::EmbedFuture {
            Box::pin(async move { Ok(vec![1.0, 0.0]) })
        });
        let idx = SemanticToolIndex::build(&tools, &embed).await.unwrap();
        let result = idx.select(&[], 10, 0.0, &["pinned".to_string()]);
        assert_eq!(
            result.len(),
            1,
            "empty query must return only always_include tools"
        );
        assert_eq!(result[0].name, "pinned");
    }

    #[tokio::test]
    async fn select_always_include_no_duplicate() {
        let tools = vec![make_tool("pinned", "always"), make_tool("other", "other")];
        let embed: EmbedFn = Box::new(|_text: &str| -> zeph_llm::provider::EmbedFuture {
            Box::pin(async move { Ok(vec![1.0, 0.0]) })
        });
        let idx = SemanticToolIndex::build(&tools, &embed).await.unwrap();
        let result = idx.select(&[1.0, 0.0], 10, 0.0, &["pinned".to_string()]);
        let pinned_count = result.iter().filter(|t| t.name == "pinned").count();
        assert_eq!(
            pinned_count, 1,
            "always_include tool must not be duplicated in result"
        );
    }

    #[test]
    fn strategy_none_variant_exists() {
        let s = ToolDiscoveryStrategy::None;
        assert_ne!(s, ToolDiscoveryStrategy::Embedding);
        assert_ne!(s, ToolDiscoveryStrategy::Llm);
    }

    #[test]
    fn strategy_default_is_none() {
        assert_eq!(
            ToolDiscoveryStrategy::default(),
            ToolDiscoveryStrategy::None
        );
    }
}