1use std::sync::Arc;
32
33use serde::{Deserialize, Serialize};
34
35use crate::embedding::{EmbeddingProvider, EmbeddingVector};
36
/// A tool or program that can be indexed by [`ToolRetriever`] and rendered by
/// [`format_capability_index`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolEntry {
    /// Tool name; rendered as `<name>` and logged if embedding the tool fails.
    pub name: String,
    /// Category label (this file's tests use e.g. `os-tool`, `program`, `mcp`);
    /// prefixed in brackets in the embedding text and rendered as `<category>`.
    pub category: String,
    /// Human-readable description; part of the embedding text and rendered as
    /// `<description>`.
    pub description: String,
    /// Optional path to a skill document (tests use `programs/git-helper/SKILL.md`);
    /// rendered as `<skill>` but not included in the embedding text.
    pub skill_path: Option<String>,
    /// Optional command; appended to the embedding text and rendered as `<command>`.
    pub command: Option<String>,
}
62
63impl ToolEntry {
64 fn embedding_text(&self) -> String {
69 let mut parts = format!("[{}] {}: {}", self.category, self.name, self.description);
70 if let Some(ref cmd) = self.command {
71 parts.push_str(&format!(" command: {}", cmd));
72 }
73 parts
74 }
75}
76
/// A tool entry stored in the retriever index together with the embedding
/// computed from its `embedding_text()` at index time.
#[derive(Debug, Clone)]
struct IndexedTool {
    entry: ToolEntry,
    vector: EmbeddingVector,
}
83
/// A retrieval result: an indexed tool and its score against the query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoredTool {
    /// The matched tool (a copy of the indexed entry).
    pub entry: ToolEntry,
    /// Cosine similarity between the query embedding and the tool's embedding;
    /// `ToolRetriever::retrieve` orders results by this value, highest first.
    pub score: f64,
}
92
93pub struct ToolRetriever {
102 index: Vec<IndexedTool>,
104 embedder: Arc<dyn EmbeddingProvider>,
106}
107
108impl ToolRetriever {
109 pub fn new(embedder: Arc<dyn EmbeddingProvider>) -> Self {
111 Self {
112 index: Vec::new(),
113 embedder,
114 }
115 }
116
117 pub fn embedder(&self) -> &Arc<dyn EmbeddingProvider> {
122 &self.embedder
123 }
124
125 pub async fn index_tool(&mut self, entry: ToolEntry) {
131 let text = entry.embedding_text();
132 match self.embedder.embed(&text).await {
133 Ok(vector) => {
134 self.index.push(IndexedTool { entry, vector });
135 }
136 Err(e) => {
137 tracing::warn!(name = %entry.name, error = %e, "failed to embed tool, skipping");
138 }
139 }
140 }
141
142 pub fn retrieve(&self, query_embedding: &EmbeddingVector, top_k: usize) -> Vec<ScoredTool> {
150 let mut scored: Vec<ScoredTool> = self
151 .index
152 .iter()
153 .map(|indexed| {
154 let score = query_embedding.cosine_similarity(&indexed.vector);
155 ScoredTool {
156 entry: indexed.entry.clone(),
157 score,
158 }
159 })
160 .collect();
161
162 scored.sort_by(|a, b| {
164 b.score
165 .partial_cmp(&a.score)
166 .unwrap_or(std::cmp::Ordering::Equal)
167 });
168
169 scored.truncate(top_k);
170 scored
171 }
172
173 pub fn len(&self) -> usize {
175 self.index.len()
176 }
177
178 pub fn is_empty(&self) -> bool {
180 self.index.is_empty()
181 }
182
183 pub fn entries(&self) -> Vec<&ToolEntry> {
185 self.index.iter().map(|i| &i.entry).collect()
186 }
187
188 pub fn clear(&mut self) {
190 self.index.clear();
191 }
192}
193
194pub fn format_capability_index(tools: &[ScoredTool]) -> String {
220 let mut xml = String::from("<available_capabilities>\n");
221
222 for tool in tools {
223 xml.push_str(" <capability>\n");
224 xml.push_str(&format!(
225 " <name>{}</name>\n",
226 escape_xml(&tool.entry.name)
227 ));
228 xml.push_str(&format!(
229 " <category>{}</category>\n",
230 escape_xml(&tool.entry.category)
231 ));
232 xml.push_str(&format!(
233 " <description>{}</description>\n",
234 escape_xml(&tool.entry.description)
235 ));
236 if let Some(ref cmd) = tool.entry.command {
237 xml.push_str(&format!(" <command>{}</command>\n", escape_xml(cmd)));
238 }
239 if let Some(ref skill) = tool.entry.skill_path {
240 xml.push_str(&format!(" <skill>{}</skill>\n", escape_xml(skill)));
241 }
242 xml.push_str(" </capability>\n");
243 }
244
245 xml.push_str("</available_capabilities>");
246 xml
247}
248
/// Escapes the five XML special characters in `s` so the result can be used as
/// element text or a quoted attribute value.
///
/// `&`, `<`, `>`, `"` and `'` become `&amp;`, `&lt;`, `&gt;`, `&quot;` and
/// `&apos;`. Every other character, including non-ASCII text, is copied
/// unchanged. Already-escaped input is escaped again (`&amp;` → `&amp;amp;`).
fn escape_xml(s: &str) -> String {
    // Most input needs no escaping, so the input length is a good starting capacity.
    let mut out = String::with_capacity(s.len());
    for c in s.chars() {
        match c {
            '&' => out.push_str("&amp;"),
            '<' => out.push_str("&lt;"),
            '>' => out.push_str("&gt;"),
            '"' => out.push_str("&quot;"),
            '\'' => out.push_str("&apos;"),
            _ => out.push(c),
        }
    }
    out
}
264
/// Kernel domain names accepted by `build_kernel_manifest`; any other name
/// passed to it is silently dropped. Order here does not affect the manifest,
/// which follows the caller's order.
const KNOWN_DOMAINS: &[&str] = &[
    "space",
    "agent",
    "a2a",
    "memory",
    "security",
    "budget",
    "resource",
    "program",
];
273
274pub fn build_kernel_manifest(active_domains: &[&str]) -> String {
293 let mut md = String::from("## Kernel Manifest\n\n");
294
295 let domain_list: Vec<&str> = active_domains
296 .iter()
297 .filter(|d| KNOWN_DOMAINS.contains(d))
298 .copied()
299 .collect();
300
301 md.push_str(&format!("Active domains: {}\n\n", domain_list.join(", ")));
302
303 for domain in &domain_list {
304 let description = domain_description(domain);
305 md.push_str(&format!("### {}\n{}\n\n", domain, description));
306 }
307
308 md
309}
310
/// Returns the one-line description of a kernel domain, or
/// `"Unknown domain."` for names that have none.
fn domain_description(domain: &str) -> &'static str {
    const DESCRIPTIONS: &[(&str, &str)] = &[
        ("space", "Filesystem workspace management and conversation buffers."),
        ("agent", "Agent lifecycle, runtime, and supervisor."),
        ("a2a", "Agent-to-agent communication and delegation."),
        ("memory", "Persistent vector memory and semantic search."),
        ("security", "RBAC access control and audit trail."),
        ("budget", "Token and cost budget enforcement."),
        ("resource", "System resource monitoring and overload protection."),
        ("program", "Installable OS-level programs and tools."),
    ];

    DESCRIPTIONS
        .iter()
        .find(|(name, _)| *name == domain)
        .map(|(_, text)| *text)
        .unwrap_or("Unknown domain.")
}
325
#[cfg(test)]
mod tests {
    // Unit tests for tool retrieval, capability rendering and the kernel manifest.
    use super::*;

    /// Deterministic stand-in embedder: empty text maps to an empty vector,
    /// anything else to a 3-d vector whose middle component grows with the
    /// text length.
    struct MockEmbedder;

    #[async_trait::async_trait]
    impl EmbeddingProvider for MockEmbedder {
        async fn embed(&self, text: &str) -> anyhow::Result<EmbeddingVector> {
            if text.is_empty() {
                return Ok(EmbeddingVector::DenseF32(vec![]));
            }
            let len = text.len() as f32;
            Ok(EmbeddingVector::DenseF32(vec![1.0, len / 100.0, 0.5]))
        }

        fn name(&self) -> &str {
            "mock"
        }
    }

    fn mock_entry(name: &str, category: &str, desc: &str) -> ToolEntry {
        ToolEntry {
            name: name.to_string(),
            category: category.to_string(),
            description: desc.to_string(),
            skill_path: None,
            command: None,
        }
    }

    #[tokio::test]
    async fn test_index_and_len() {
        let embedder = Arc::new(MockEmbedder);
        let mut retriever = ToolRetriever::new(embedder);

        assert!(retriever.is_empty());
        assert_eq!(retriever.len(), 0);

        retriever
            .index_tool(mock_entry("exec", "os-tool", "Run commands"))
            .await;
        retriever
            .index_tool(mock_entry("git", "program", "Git operations"))
            .await;

        assert_eq!(retriever.len(), 2);
        assert!(!retriever.is_empty());
    }

    #[tokio::test]
    async fn test_retrieve_top_k() {
        let embedder = Arc::new(MockEmbedder);
        let mut retriever = ToolRetriever::new(embedder);

        retriever
            .index_tool(mock_entry("exec", "os-tool", "Run shell commands"))
            .await;
        retriever
            .index_tool(mock_entry("git", "program", "Git version control"))
            .await;
        retriever
            .index_tool(mock_entry("mcp-github", "mcp", "GitHub API bridge"))
            .await;

        let query = EmbeddingVector::DenseF32(vec![1.0, 0.5, 0.5]);
        let results = retriever.retrieve(&query, 2);

        assert_eq!(results.len(), 2);
        assert!(results[0].score >= results[1].score);
    }

    #[tokio::test]
    async fn test_retrieve_exceeds_index() {
        let embedder = Arc::new(MockEmbedder);
        let mut retriever = ToolRetriever::new(embedder);

        retriever
            .index_tool(mock_entry("exec", "os-tool", "Run commands"))
            .await;

        let query = EmbeddingVector::DenseF32(vec![1.0, 0.5, 0.5]);
        let results = retriever.retrieve(&query, 10);

        assert_eq!(results.len(), 1);
    }

    #[tokio::test]
    async fn test_retrieve_top_k_zero() {
        let embedder = Arc::new(MockEmbedder);
        let mut retriever = ToolRetriever::new(embedder);

        retriever
            .index_tool(mock_entry("exec", "os-tool", "Run commands"))
            .await;

        let query = EmbeddingVector::DenseF32(vec![1.0, 0.5, 0.5]);
        assert!(retriever.retrieve(&query, 0).is_empty());
    }

    #[tokio::test]
    async fn test_retrieve_empty_index() {
        let embedder = Arc::new(MockEmbedder);
        let retriever = ToolRetriever::new(embedder);

        let query = EmbeddingVector::DenseF32(vec![1.0, 0.5, 0.5]);
        let results = retriever.retrieve(&query, 5);

        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_entries() {
        let embedder = Arc::new(MockEmbedder);
        let mut retriever = ToolRetriever::new(embedder);

        retriever
            .index_tool(mock_entry("exec", "os-tool", "Run commands"))
            .await;
        retriever
            .index_tool(mock_entry("git", "program", "Git ops"))
            .await;

        let entries = retriever.entries();
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0].name, "exec");
        assert_eq!(entries[1].name, "git");
    }

    #[tokio::test]
    async fn test_clear() {
        let embedder = Arc::new(MockEmbedder);
        let mut retriever = ToolRetriever::new(embedder);

        retriever
            .index_tool(mock_entry("exec", "os-tool", "Run commands"))
            .await;
        assert_eq!(retriever.len(), 1);

        retriever.clear();
        assert!(retriever.is_empty());
    }

    #[test]
    fn test_format_capability_index_basic() {
        let tool = ScoredTool {
            entry: ToolEntry {
                name: "exec".into(),
                category: "os-tool".into(),
                description: "Execute shell commands".into(),
                skill_path: None,
                command: None,
            },
            score: 0.95,
        };

        let xml = format_capability_index(&[tool]);
        assert!(xml.contains("<available_capabilities>"));
        assert!(xml.contains("<name>exec</name>"));
        assert!(xml.contains("<category>os-tool</category>"));
        assert!(xml.contains("<description>Execute shell commands</description>"));
        assert!(xml.contains("</available_capabilities>"));
        assert!(!xml.contains("<command>"));
        assert!(!xml.contains("<skill>"));
    }

    #[test]
    fn test_format_capability_index_empty() {
        assert_eq!(
            format_capability_index(&[]),
            "<available_capabilities>\n</available_capabilities>"
        );
    }

    #[test]
    fn test_format_capability_index_program() {
        let tool = ScoredTool {
            entry: ToolEntry {
                name: "git-helper".into(),
                category: "program".into(),
                description: "Git workflow automation".into(),
                skill_path: Some("programs/git-helper/SKILL.md".into()),
                command: Some("git-helper".into()),
            },
            score: 0.88,
        };

        let xml = format_capability_index(&[tool]);
        assert!(xml.contains("<command>git-helper</command>"));
        assert!(xml.contains("<skill>programs/git-helper/SKILL.md</skill>"));
    }

    #[test]
    fn test_format_capability_index_xml_escaping() {
        let tool = ScoredTool {
            entry: ToolEntry {
                name: "test<>&".into(),
                category: "os-tool".into(),
                description: "A & B < C > D".into(),
                skill_path: None,
                command: None,
            },
            score: 1.0,
        };

        let xml = format_capability_index(&[tool]);
        // Special characters must appear as entities, never raw.
        assert!(xml.contains("<name>test&lt;&gt;&amp;</name>"));
        assert!(xml.contains("<description>A &amp; B &lt; C &gt; D</description>"));
    }

    #[test]
    fn test_escape_xml() {
        assert_eq!(escape_xml("hello"), "hello");
        assert_eq!(
            escape_xml("a&b<c>d\"e'f"),
            "a&amp;b&lt;c&gt;d&quot;e&apos;f"
        );
    }

    #[test]
    fn test_build_kernel_manifest() {
        let md = build_kernel_manifest(&["space", "agent", "memory", "program"]);
        assert!(md.contains("## Kernel Manifest"));
        assert!(md.contains("Active domains: space, agent, memory, program"));
        assert!(md.contains("### space"));
        assert!(md.contains("### agent"));
        assert!(md.contains("### memory"));
        assert!(md.contains("### program"));
        assert!(!md.contains("### security"));
    }

    #[test]
    fn test_build_kernel_manifest_filters_unknown() {
        let md = build_kernel_manifest(&["space", "unknown-domain"]);
        assert!(md.contains("### space"));
        assert!(!md.contains("unknown-domain"));
    }

    #[test]
    fn test_build_kernel_manifest_empty() {
        let md = build_kernel_manifest(&[]);
        assert!(md.contains("## Kernel Manifest"));
        assert!(md.contains("Active domains:"));
    }

    #[test]
    fn test_tool_entry_embedding_text() {
        let entry = mock_entry("exec", "os-tool", "Run commands");
        let text = entry.embedding_text();
        assert!(text.contains("[os-tool] exec: Run commands"));
    }

    #[test]
    fn test_tool_entry_embedding_text_with_command() {
        let entry = ToolEntry {
            name: "git".into(),
            category: "program".into(),
            description: "Git ops".into(),
            skill_path: None,
            command: Some("git binary".into()),
        };
        let text = entry.embedding_text();
        assert!(text.contains("command: git binary"));
    }

    #[tokio::test]
    async fn test_embedder_accessor() {
        let embedder = Arc::new(MockEmbedder);
        let retriever = ToolRetriever::new(embedder);
        assert_eq!(retriever.embedder().name(), "mock");
    }

    #[tokio::test]
    async fn test_with_tfidf_embedder() {
        use crate::embedding::TfIdfEmbeddingProvider;

        let embedder = Arc::new(TfIdfEmbeddingProvider);
        let mut retriever = ToolRetriever::new(embedder);

        retriever
            .index_tool(ToolEntry {
                name: "exec".into(),
                category: "os-tool".into(),
                description: "Execute shell commands in workspace".into(),
                skill_path: None,
                command: None,
            })
            .await;
        retriever
            .index_tool(ToolEntry {
                name: "memory-search".into(),
                category: "os-tool".into(),
                description: "Search persistent vector memory".into(),
                skill_path: None,
                command: None,
            })
            .await;

        let query_embedding = retriever
            .embedder()
            .embed("run a bash command")
            .await
            .unwrap();
        let results = retriever.retrieve(&query_embedding, 2);

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].entry.name, "exec");
    }
}