use std::sync::Arc;
use serde::{Deserialize, Serialize};
use crate::embedding::{EmbeddingProvider, EmbeddingVector};
/// Description of a single tool/capability that can be indexed and retrieved.
///
/// `category`, `name`, `description` and (when present) `command` are combined
/// into the text that is embedded for similarity search; `skill_path` is not
/// part of that text. All fields are rendered by `format_capability_index`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolEntry {
/// Tool name; rendered as the `<name>` element.
pub name: String,
/// Free-form category label (the tests use values such as "os-tool",
/// "program" and "mcp"); rendered as the `<category>` element.
pub category: String,
/// Human-readable description; rendered as the `<description>` element.
pub description: String,
/// Optional skill path; rendered as the `<skill>` element when present.
/// NOTE(review): presumably a path to a skill document such as
/// `programs/<name>/SKILL.md` — confirm against the callers.
pub skill_path: Option<String>,
/// Optional command string; appended to the embedding text and rendered as
/// the `<command>` element when present.
pub command: Option<String>,
}
impl ToolEntry {
    /// Builds the text that is handed to the embedding provider for this tool.
    ///
    /// The result has the shape `[category] name: description`, followed by
    /// ` command: <command>` when a command is set. `skill_path` is not used.
    fn embedding_text(&self) -> String {
        let header = format!("[{}] {}: {}", self.category, self.name, self.description);
        match self.command.as_deref() {
            Some(cmd) => format!("{} command: {}", header, cmd),
            None => header,
        }
    }
}
/// A tool entry paired with the embedding vector computed for it at index time.
#[derive(Debug, Clone)]
struct IndexedTool {
// The tool's descriptive metadata.
entry: ToolEntry,
// Embedding of `entry.embedding_text()`, as returned by the embedding provider.
vector: EmbeddingVector,
}
/// A retrieval result: a tool entry together with its similarity score.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoredTool {
/// Copy of the indexed tool entry.
pub entry: ToolEntry,
/// Cosine similarity between the query embedding and the tool's embedding,
/// as computed by `EmbeddingVector::cosine_similarity`.
pub score: f64,
}
/// In-memory index of embedded tools, ranked by cosine similarity to a query
/// embedding.
pub struct ToolRetriever {
// Indexed tools in insertion order (`index_tool` appends to the end).
index: Vec<IndexedTool>,
// Provider used to embed tool text; also exposed through `embedder()`.
embedder: Arc<dyn EmbeddingProvider>,
}
impl ToolRetriever {
    /// Creates an empty retriever that embeds tools with `embedder`.
    pub fn new(embedder: Arc<dyn EmbeddingProvider>) -> Self {
        Self {
            index: Vec::new(),
            embedder,
        }
    }

    /// Returns the embedding provider, e.g. so callers can embed a query with
    /// the same provider that was used for the indexed tools.
    pub fn embedder(&self) -> &Arc<dyn EmbeddingProvider> {
        &self.embedder
    }

    /// Embeds `entry` and appends it to the index.
    ///
    /// This is best-effort: when the provider fails to embed the entry's text,
    /// the entry is not indexed, a warning is logged, and no error is returned.
    pub async fn index_tool(&mut self, entry: ToolEntry) {
        let text = entry.embedding_text();
        match self.embedder.embed(&text).await {
            Ok(vector) => {
                self.index.push(IndexedTool { entry, vector });
            }
            Err(e) => {
                tracing::warn!(name = %entry.name, error = %e, "failed to embed tool, skipping");
            }
        }
    }

    /// Returns up to `top_k` indexed tools, ordered by descending cosine
    /// similarity to `query_embedding`.
    ///
    /// Tools with equal scores keep their insertion order. A tool whose score
    /// is NaN is ranked last; its `score` field is still reported as NaN.
    pub fn retrieve(&self, query_embedding: &EmbeddingVector, top_k: usize) -> Vec<ScoredTool> {
        // Maps NaN to the lowest possible key so the comparator below is a
        // total order. Comparing raw scores and treating "unordered" as equal
        // is not transitive when NaN is present, and `sort_by` may then yield
        // an unspecified order or panic.
        fn rank_key(score: f64) -> f64 {
            if score.is_nan() {
                f64::NEG_INFINITY
            } else {
                score
            }
        }

        // Rank borrowed entries first so that only the surviving `top_k`
        // entries are cloned.
        let mut ranked: Vec<(f64, &ToolEntry)> = self
            .index
            .iter()
            .map(|indexed| {
                (
                    query_embedding.cosine_similarity(&indexed.vector),
                    &indexed.entry,
                )
            })
            .collect();

        // Stable sort, highest score first. The keys are never NaN, so
        // `partial_cmp` always yields `Some`; the fallback is unreachable.
        ranked.sort_by(|a, b| {
            rank_key(b.0)
                .partial_cmp(&rank_key(a.0))
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        ranked
            .into_iter()
            .take(top_k)
            .map(|(score, entry)| ScoredTool {
                entry: entry.clone(),
                score,
            })
            .collect()
    }

    /// Number of indexed tools.
    pub fn len(&self) -> usize {
        self.index.len()
    }

    /// Returns `true` when no tools are indexed.
    pub fn is_empty(&self) -> bool {
        self.index.is_empty()
    }

    /// Borrows every indexed entry, in insertion order.
    pub fn entries(&self) -> Vec<&ToolEntry> {
        self.index.iter().map(|i| &i.entry).collect()
    }

    /// Removes all indexed tools; the embedding provider is kept.
    pub fn clear(&mut self) {
        self.index.clear();
    }
}
/// Renders `tools` as an `<available_capabilities>` XML fragment.
///
/// Each tool becomes one `<capability>` element containing `<name>`,
/// `<category>` and `<description>`, followed by `<command>` and `<skill>`
/// when those optional fields are set. All values pass through `escape_xml`.
/// Scores are not rendered, and the output has no trailing newline.
pub fn format_capability_index(tools: &[ScoredTool]) -> String {
    // Appends one `<tag>value</tag>` line with the value XML-escaped.
    fn push_element(buf: &mut String, tag: &str, value: &str) {
        buf.push_str(&format!(" <{}>{}</{}>\n", tag, escape_xml(value), tag));
    }

    let mut out = String::from("<available_capabilities>\n");
    for scored in tools {
        let entry = &scored.entry;
        out.push_str(" <capability>\n");
        push_element(&mut out, "name", &entry.name);
        push_element(&mut out, "category", &entry.category);
        push_element(&mut out, "description", &entry.description);
        if let Some(cmd) = entry.command.as_deref() {
            push_element(&mut out, "command", cmd);
        }
        if let Some(skill) = entry.skill_path.as_deref() {
            push_element(&mut out, "skill", skill);
        }
        out.push_str(" </capability>\n");
    }
    out.push_str("</available_capabilities>");
    out
}
/// Escapes the five characters XML reserves in text and attribute values.
///
/// `&`, `<`, `>`, `"` and `'` are replaced by the predefined entities
/// `&amp;`, `&lt;`, `&gt;`, `&quot;` and `&apos;`. Every other character is
/// copied through unchanged.
fn escape_xml(s: &str) -> String {
    // The input length is a lower bound on the output length.
    let mut out = String::with_capacity(s.len());
    for c in s.chars() {
        match c {
            '&' => out.push_str("&amp;"),
            '<' => out.push_str("&lt;"),
            '>' => out.push_str("&gt;"),
            '"' => out.push_str("&quot;"),
            '\'' => out.push_str("&apos;"),
            _ => out.push(c),
        }
    }
    out
}
// Domain names accepted by `build_kernel_manifest`; any other name passed to
// it is dropped. Each name listed here has a matching arm in
// `domain_description`.
const KNOWN_DOMAINS: &[&str] = &[
"space", "agent", "a2a", "memory", "security", "budget", "resource", "program",
];
/// Builds the Markdown "Kernel Manifest" for the given active domains.
///
/// Names not listed in `KNOWN_DOMAINS` are dropped. The remaining names keep
/// their input order (duplicates included) and appear twice: once in the
/// comma-separated `Active domains:` line and once as a `###` section holding
/// the text from `domain_description`.
pub fn build_kernel_manifest(active_domains: &[&str]) -> String {
    // Keep only the recognised domain names, preserving caller order.
    let mut recognized: Vec<&str> = Vec::new();
    for &candidate in active_domains {
        if KNOWN_DOMAINS.contains(&candidate) {
            recognized.push(candidate);
        }
    }

    let mut manifest = String::from("## Kernel Manifest\n\n");
    manifest.push_str("Active domains: ");
    manifest.push_str(&recognized.join(", "));
    manifest.push_str("\n\n");

    for name in recognized {
        manifest.push_str(&format!("### {}\n{}\n\n", name, domain_description(name)));
    }
    manifest
}
/// Returns the one-line description for a kernel domain name.
///
/// Names without an entry in the table yield `"Unknown domain."`.
fn domain_description(domain: &str) -> &'static str {
    // (domain name, description) pairs; matching is exact and case-sensitive.
    const DESCRIPTIONS: [(&str, &str); 8] = [
        ("space", "Filesystem workspace management and conversation buffers."),
        ("agent", "Agent lifecycle, runtime, and supervisor."),
        ("a2a", "Agent-to-agent communication and delegation."),
        ("memory", "Persistent vector memory and semantic search."),
        ("security", "RBAC access control and audit trail."),
        ("budget", "Token and cost budget enforcement."),
        ("resource", "System resource monitoring and overload protection."),
        ("program", "Installable OS-level programs and tools."),
    ];

    DESCRIPTIONS
        .iter()
        .find(|(name, _)| *name == domain)
        .map(|(_, text)| *text)
        .unwrap_or("Unknown domain.")
}
#[cfg(test)]
mod tests {
    //! Unit tests for the tool retriever, the XML capability formatter and the
    //! kernel manifest builder.

    use super::*;

    /// Deterministic embedder: empty text maps to an empty vector, any other
    /// text to a 3-component vector whose middle component depends on the
    /// text's byte length.
    struct MockEmbedder;

    #[async_trait::async_trait]
    impl EmbeddingProvider for MockEmbedder {
        async fn embed(&self, text: &str) -> anyhow::Result<EmbeddingVector> {
            if text.is_empty() {
                return Ok(EmbeddingVector::DenseF32(vec![]));
            }
            let len = text.len() as f32;
            Ok(EmbeddingVector::DenseF32(vec![1.0, len / 100.0, 0.5]))
        }

        fn name(&self) -> &str {
            "mock"
        }
    }

    /// Builds an entry with no skill path and no command.
    fn mock_entry(name: &str, category: &str, desc: &str) -> ToolEntry {
        ToolEntry {
            name: name.to_string(),
            category: category.to_string(),
            description: desc.to_string(),
            skill_path: None,
            command: None,
        }
    }

    #[tokio::test]
    async fn test_index_and_len() {
        let embedder = Arc::new(MockEmbedder);
        let mut retriever = ToolRetriever::new(embedder);
        assert!(retriever.is_empty());
        assert_eq!(retriever.len(), 0);
        retriever
            .index_tool(mock_entry("exec", "os-tool", "Run commands"))
            .await;
        retriever
            .index_tool(mock_entry("git", "program", "Git operations"))
            .await;
        assert_eq!(retriever.len(), 2);
        assert!(!retriever.is_empty());
    }

    #[tokio::test]
    async fn test_retrieve_top_k() {
        let embedder = Arc::new(MockEmbedder);
        let mut retriever = ToolRetriever::new(embedder);
        retriever
            .index_tool(mock_entry("exec", "os-tool", "Run shell commands"))
            .await;
        retriever
            .index_tool(mock_entry("git", "program", "Git version control"))
            .await;
        retriever
            .index_tool(mock_entry("mcp-github", "mcp", "GitHub API bridge"))
            .await;
        let query = EmbeddingVector::DenseF32(vec![1.0, 0.5, 0.5]);
        let results = retriever.retrieve(&query, 2);
        assert_eq!(results.len(), 2);
        // Results must come back in descending score order.
        assert!(results[0].score >= results[1].score);
    }

    #[tokio::test]
    async fn test_retrieve_exceeds_index() {
        let embedder = Arc::new(MockEmbedder);
        let mut retriever = ToolRetriever::new(embedder);
        retriever
            .index_tool(mock_entry("exec", "os-tool", "Run commands"))
            .await;
        let query = EmbeddingVector::DenseF32(vec![1.0, 0.5, 0.5]);
        let results = retriever.retrieve(&query, 10);
        assert_eq!(results.len(), 1);
    }

    #[tokio::test]
    async fn test_retrieve_empty_index() {
        let embedder = Arc::new(MockEmbedder);
        let retriever = ToolRetriever::new(embedder);
        let query = EmbeddingVector::DenseF32(vec![1.0, 0.5, 0.5]);
        let results = retriever.retrieve(&query, 5);
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_entries() {
        let embedder = Arc::new(MockEmbedder);
        let mut retriever = ToolRetriever::new(embedder);
        retriever
            .index_tool(mock_entry("exec", "os-tool", "Run commands"))
            .await;
        retriever
            .index_tool(mock_entry("git", "program", "Git ops"))
            .await;
        let entries = retriever.entries();
        assert_eq!(entries.len(), 2);
        // Entries are reported in insertion order.
        assert_eq!(entries[0].name, "exec");
        assert_eq!(entries[1].name, "git");
    }

    #[tokio::test]
    async fn test_clear() {
        let embedder = Arc::new(MockEmbedder);
        let mut retriever = ToolRetriever::new(embedder);
        retriever
            .index_tool(mock_entry("exec", "os-tool", "Run commands"))
            .await;
        assert_eq!(retriever.len(), 1);
        retriever.clear();
        assert!(retriever.is_empty());
    }

    #[test]
    fn test_format_capability_index_basic() {
        let tool = ScoredTool {
            entry: ToolEntry {
                name: "exec".into(),
                category: "os-tool".into(),
                description: "Execute shell commands".into(),
                skill_path: None,
                command: None,
            },
            score: 0.95,
        };
        let xml = format_capability_index(&[tool]);
        assert!(xml.contains("<available_capabilities>"));
        assert!(xml.contains("<name>exec</name>"));
        assert!(xml.contains("<category>os-tool</category>"));
        assert!(xml.contains("<description>Execute shell commands</description>"));
        assert!(xml.contains("</available_capabilities>"));
        // Optional elements are omitted when the fields are `None`.
        assert!(!xml.contains("<command>"));
        assert!(!xml.contains("<skill>"));
    }

    #[test]
    fn test_format_capability_index_program() {
        let tool = ScoredTool {
            entry: ToolEntry {
                name: "git-helper".into(),
                category: "program".into(),
                description: "Git workflow automation".into(),
                skill_path: Some("programs/git-helper/SKILL.md".into()),
                command: Some("git-helper".into()),
            },
            score: 0.88,
        };
        let xml = format_capability_index(&[tool]);
        assert!(xml.contains("<command>git-helper</command>"));
        assert!(xml.contains("<skill>programs/git-helper/SKILL.md</skill>"));
    }

    #[test]
    fn test_format_capability_index_xml_escaping() {
        let tool = ScoredTool {
            entry: ToolEntry {
                name: "test<>&".into(),
                category: "os-tool".into(),
                description: "A & B < C > D".into(),
                skill_path: None,
                command: None,
            },
            score: 1.0,
        };
        let xml = format_capability_index(&[tool]);
        // Reserved characters must appear as entities, never raw.
        assert!(xml.contains("<name>test&lt;&gt;&amp;</name>"));
        assert!(xml.contains("<description>A &amp; B &lt; C &gt; D</description>"));
    }

    #[test]
    fn test_escape_xml() {
        assert_eq!(escape_xml("hello"), "hello");
        assert_eq!(
            escape_xml("a&b<c>d\"e'f"),
            "a&amp;b&lt;c&gt;d&quot;e&apos;f"
        );
    }

    #[test]
    fn test_build_kernel_manifest() {
        let md = build_kernel_manifest(&["space", "agent", "memory", "program"]);
        assert!(md.contains("## Kernel Manifest"));
        assert!(md.contains("Active domains: space, agent, memory, program"));
        assert!(md.contains("### space"));
        assert!(md.contains("### agent"));
        assert!(md.contains("### memory"));
        assert!(md.contains("### program"));
        assert!(!md.contains("### security"));
    }

    #[test]
    fn test_build_kernel_manifest_filters_unknown() {
        let md = build_kernel_manifest(&["space", "unknown-domain"]);
        assert!(md.contains("### space"));
        assert!(!md.contains("unknown-domain"));
    }

    #[test]
    fn test_build_kernel_manifest_empty() {
        let md = build_kernel_manifest(&[]);
        assert!(md.contains("## Kernel Manifest"));
        assert!(md.contains("Active domains:"));
    }

    #[test]
    fn test_tool_entry_embedding_text() {
        let entry = mock_entry("exec", "os-tool", "Run commands");
        let text = entry.embedding_text();
        assert!(text.contains("[os-tool] exec: Run commands"));
    }

    #[test]
    fn test_tool_entry_embedding_text_with_command() {
        let entry = ToolEntry {
            name: "git".into(),
            category: "program".into(),
            description: "Git ops".into(),
            skill_path: None,
            command: Some("git binary".into()),
        };
        let text = entry.embedding_text();
        assert!(text.contains("command: git binary"));
    }

    #[tokio::test]
    async fn test_embedder_accessor() {
        let embedder = Arc::new(MockEmbedder);
        let retriever = ToolRetriever::new(embedder);
        assert_eq!(retriever.embedder().name(), "mock");
    }

    #[tokio::test]
    async fn test_with_tfidf_embedder() {
        use crate::embedding::TfIdfEmbeddingProvider;

        let embedder = Arc::new(TfIdfEmbeddingProvider);
        let mut retriever = ToolRetriever::new(embedder);
        retriever
            .index_tool(ToolEntry {
                name: "exec".into(),
                category: "os-tool".into(),
                description: "Execute shell commands in workspace".into(),
                skill_path: None,
                command: None,
            })
            .await;
        retriever
            .index_tool(ToolEntry {
                name: "memory-search".into(),
                category: "os-tool".into(),
                description: "Search persistent vector memory".into(),
                skill_path: None,
                command: None,
            })
            .await;
        // Embed the query with the retriever's own provider so the query and
        // the index share one embedding space.
        let query_embedding = retriever
            .embedder()
            .embed("run a bash command")
            .await
            .unwrap();
        let results = retriever.retrieve(&query_embedding, 2);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].entry.name, "exec");
    }
}