1use std::sync::Arc;
30
31use serde::{Deserialize, Serialize};
32
33use crate::embedding::{EmbeddingProvider, EmbeddingVector};
34
/// Description of a single tool/capability that can be indexed by
/// [`ToolRetriever`] and rendered by [`format_capability_index`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolEntry {
    /// Tool name; part of the embedding text and emitted as `<name>`.
    pub name: String,
    /// Free-form category label (this file's tests use values such as
    /// `"os-tool"`, `"program"` and `"mcp"`); part of the embedding text and
    /// emitted as `<category>`.
    pub category: String,
    /// Human-readable description; part of the embedding text and emitted as
    /// `<description>`.
    pub description: String,
    /// Optional skill path, emitted as `<skill>` when present. It is not
    /// included in the embedding text.
    pub skill_path: Option<String>,
    /// Optional command string; appended to the embedding text and emitted as
    /// `<command>` when present.
    pub command: Option<String>,
}
60
61impl ToolEntry {
62 fn embedding_text(&self) -> String {
67 let mut parts = format!("[{}] {}: {}", self.category, self.name, self.description);
68 if let Some(ref cmd) = self.command {
69 parts.push_str(&format!(" command: {cmd}"));
70 }
71 parts
72 }
73}
74
/// Internal index record: a tool entry paired with the embedding vector the
/// embedder produced for it.
#[derive(Debug, Clone)]
struct IndexedTool {
    /// The tool as supplied to `ToolRetriever::index_tool`.
    entry: ToolEntry,
    /// Embedding of `entry.embedding_text()`, computed once at indexing time.
    vector: EmbeddingVector,
}
81
/// A tool entry together with its relevance score for one query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoredTool {
    /// Copy of the indexed tool entry.
    pub entry: ToolEntry,
    /// Cosine similarity between the query embedding and the tool's embedding,
    /// as computed by `ToolRetriever::retrieve`.
    pub score: f64,
}
90
/// In-memory index of tools that can be ranked against a query embedding.
pub struct ToolRetriever {
    /// Indexed tools, in the order they were successfully added.
    index: Vec<IndexedTool>,
    /// Provider used to embed tool text when indexing.
    embedder: Arc<dyn EmbeddingProvider>,
}
105
106impl ToolRetriever {
107 pub fn new(embedder: Arc<dyn EmbeddingProvider>) -> Self {
109 Self {
110 index: Vec::new(),
111 embedder,
112 }
113 }
114
115 pub fn embedder(&self) -> &Arc<dyn EmbeddingProvider> {
120 &self.embedder
121 }
122
123 pub async fn index_tool(&mut self, entry: ToolEntry) {
129 let text = entry.embedding_text();
130 match self.embedder.embed(&text).await {
131 Ok(vector) => {
132 self.index.push(IndexedTool { entry, vector });
133 }
134 Err(e) => {
135 tracing::warn!(name = %entry.name, error = %e, "failed to embed tool, skipping");
136 }
137 }
138 }
139
140 pub fn retrieve(&self, query_embedding: &EmbeddingVector, top_k: usize) -> Vec<ScoredTool> {
148 let mut scored: Vec<ScoredTool> = self
149 .index
150 .iter()
151 .map(|indexed| {
152 let score = query_embedding.cosine_similarity(&indexed.vector);
153 ScoredTool {
154 entry: indexed.entry.clone(),
155 score,
156 }
157 })
158 .collect();
159
160 scored.sort_by(|a, b| {
162 b.score
163 .partial_cmp(&a.score)
164 .unwrap_or(std::cmp::Ordering::Equal)
165 });
166
167 scored.truncate(top_k);
168 scored
169 }
170
171 pub fn len(&self) -> usize {
173 self.index.len()
174 }
175
176 pub fn is_empty(&self) -> bool {
178 self.index.is_empty()
179 }
180
181 pub fn entries(&self) -> Vec<&ToolEntry> {
183 self.index.iter().map(|i| &i.entry).collect()
184 }
185
186 pub fn clear(&mut self) {
188 self.index.clear();
189 }
190}
191
192pub fn format_capability_index(tools: &[ScoredTool]) -> String {
218 let mut xml = String::from("<available_capabilities>\n");
219
220 for tool in tools {
221 xml.push_str(" <capability>\n");
222 xml.push_str(&format!(
223 " <name>{}</name>\n",
224 escape_xml(&tool.entry.name)
225 ));
226 xml.push_str(&format!(
227 " <category>{}</category>\n",
228 escape_xml(&tool.entry.category)
229 ));
230 xml.push_str(&format!(
231 " <description>{}</description>\n",
232 escape_xml(&tool.entry.description)
233 ));
234 if let Some(ref cmd) = tool.entry.command {
235 xml.push_str(&format!(" <command>{}</command>\n", escape_xml(cmd)));
236 }
237 if let Some(ref skill) = tool.entry.skill_path {
238 xml.push_str(&format!(" <skill>{}</skill>\n", escape_xml(skill)));
239 }
240 xml.push_str(" </capability>\n");
241 }
242
243 xml.push_str("</available_capabilities>");
244 xml
245}
246
/// Escapes the five characters that are special in XML text and attribute
/// values, replacing them with their predefined entities:
/// `&` → `&amp;`, `<` → `&lt;`, `>` → `&gt;`, `"` → `&quot;`, `'` → `&apos;`.
///
/// All other characters are copied unchanged. Input that already contains
/// entities is escaped again (`&amp;` becomes `&amp;amp;`).
fn escape_xml(s: &str) -> String {
    // The output is at least as long as the input; this avoids early regrowth.
    let mut out = String::with_capacity(s.len());
    for c in s.chars() {
        match c {
            '&' => out.push_str("&amp;"),
            '<' => out.push_str("&lt;"),
            '>' => out.push_str("&gt;"),
            '"' => out.push_str("&quot;"),
            '\'' => out.push_str("&apos;"),
            _ => out.push(c),
        }
    }
    out
}
262
/// Domain names that `build_kernel_manifest` recognizes; any requested domain
/// not in this list is dropped from the manifest.
const KNOWN_DOMAINS: &[&str] = &[
    "space", "agent", "a2a", "memory", "security", "budget", "resource", "program",
];
271
272pub fn build_kernel_manifest(active_domains: &[&str]) -> String {
291 let mut md = String::from("## Kernel Manifest\n\n");
292
293 let domain_list: Vec<&str> = active_domains
294 .iter()
295 .filter(|d| KNOWN_DOMAINS.contains(d))
296 .copied()
297 .collect();
298
299 md.push_str(&format!("Active domains: {}\n\n", domain_list.join(", ")));
300
301 for domain in &domain_list {
302 let description = domain_description(domain);
303 md.push_str(&format!("### {domain}\n{description}\n\n"));
304 }
305
306 md
307}
308
/// Returns the one-line description for a kernel domain name.
///
/// Matching is exact and case-sensitive; any name not in the table yields
/// `"Unknown domain."`.
fn domain_description(domain: &str) -> &'static str {
    const DESCRIPTIONS: &[(&str, &str)] = &[
        (
            "space",
            "Filesystem workspace management and conversation buffers.",
        ),
        ("agent", "Agent lifecycle, runtime, and supervisor."),
        ("a2a", "Agent-to-agent communication and delegation."),
        ("memory", "Persistent vector memory and semantic search."),
        ("security", "RBAC access control and audit trail."),
        ("budget", "Token and cost budget enforcement."),
        (
            "resource",
            "System resource monitoring and overload protection.",
        ),
        ("program", "Installable OS-level programs and tools."),
    ];

    DESCRIPTIONS
        .iter()
        .find(|(name, _)| *name == domain)
        .map_or("Unknown domain.", |&(_, text)| text)
}
323
/// Unit tests for tool indexing/retrieval, capability-index formatting, XML
/// escaping and kernel-manifest generation.
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic embedder for tests: empty text maps to an empty vector,
    /// any other text to a 3-component vector whose middle component depends
    /// on the text length.
    struct MockEmbedder;

    #[async_trait::async_trait]
    impl EmbeddingProvider for MockEmbedder {
        async fn embed(&self, text: &str) -> anyhow::Result<EmbeddingVector> {
            if text.is_empty() {
                return Ok(EmbeddingVector::DenseF32(vec![]));
            }
            let len = text.len() as f32;
            Ok(EmbeddingVector::DenseF32(vec![1.0, len / 100.0, 0.5]))
        }

        fn name(&self) -> &str {
            "mock"
        }
    }

    /// Builds a `ToolEntry` with no skill path and no command.
    fn mock_entry(name: &str, category: &str, desc: &str) -> ToolEntry {
        ToolEntry {
            name: name.to_string(),
            category: category.to_string(),
            description: desc.to_string(),
            skill_path: None,
            command: None,
        }
    }

    #[tokio::test]
    async fn test_index_and_len() {
        let embedder = Arc::new(MockEmbedder);
        let mut retriever = ToolRetriever::new(embedder);

        assert!(retriever.is_empty());
        assert_eq!(retriever.len(), 0);

        retriever
            .index_tool(mock_entry("exec", "os-tool", "Run commands"))
            .await;
        retriever
            .index_tool(mock_entry("git", "program", "Git operations"))
            .await;

        assert_eq!(retriever.len(), 2);
        assert!(!retriever.is_empty());
    }

    #[tokio::test]
    async fn test_retrieve_top_k() {
        let embedder = Arc::new(MockEmbedder);
        let mut retriever = ToolRetriever::new(embedder);

        retriever
            .index_tool(mock_entry("exec", "os-tool", "Run shell commands"))
            .await;
        retriever
            .index_tool(mock_entry("git", "program", "Git version control"))
            .await;
        retriever
            .index_tool(mock_entry("mcp-github", "mcp", "GitHub API bridge"))
            .await;

        let query = EmbeddingVector::DenseF32(vec![1.0, 0.5, 0.5]);
        let results = retriever.retrieve(&query, 2);

        // Three tools indexed, but only the two best are returned, best first.
        assert_eq!(results.len(), 2);
        assert!(results[0].score >= results[1].score);
    }

    #[tokio::test]
    async fn test_retrieve_exceeds_index() {
        let embedder = Arc::new(MockEmbedder);
        let mut retriever = ToolRetriever::new(embedder);

        retriever
            .index_tool(mock_entry("exec", "os-tool", "Run commands"))
            .await;

        let query = EmbeddingVector::DenseF32(vec![1.0, 0.5, 0.5]);
        let results = retriever.retrieve(&query, 10);

        // Asking for more than is indexed returns everything that exists.
        assert_eq!(results.len(), 1);
    }

    #[tokio::test]
    async fn test_retrieve_empty_index() {
        let embedder = Arc::new(MockEmbedder);
        let retriever = ToolRetriever::new(embedder);

        let query = EmbeddingVector::DenseF32(vec![1.0, 0.5, 0.5]);
        let results = retriever.retrieve(&query, 5);

        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_entries() {
        let embedder = Arc::new(MockEmbedder);
        let mut retriever = ToolRetriever::new(embedder);

        retriever
            .index_tool(mock_entry("exec", "os-tool", "Run commands"))
            .await;
        retriever
            .index_tool(mock_entry("git", "program", "Git ops"))
            .await;

        // Entries come back in insertion order.
        let entries = retriever.entries();
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0].name, "exec");
        assert_eq!(entries[1].name, "git");
    }

    #[tokio::test]
    async fn test_clear() {
        let embedder = Arc::new(MockEmbedder);
        let mut retriever = ToolRetriever::new(embedder);

        retriever
            .index_tool(mock_entry("exec", "os-tool", "Run commands"))
            .await;
        assert_eq!(retriever.len(), 1);

        retriever.clear();
        assert!(retriever.is_empty());
    }

    #[test]
    fn test_format_capability_index_basic() {
        let tool = ScoredTool {
            entry: ToolEntry {
                name: "exec".into(),
                category: "os-tool".into(),
                description: "Execute shell commands".into(),
                skill_path: None,
                command: None,
            },
            score: 0.95,
        };

        let xml = format_capability_index(&[tool]);
        assert!(xml.contains("<available_capabilities>"));
        assert!(xml.contains("<name>exec</name>"));
        assert!(xml.contains("<category>os-tool</category>"));
        assert!(xml.contains("<description>Execute shell commands</description>"));
        assert!(xml.contains("</available_capabilities>"));
        // Optional elements must be absent when the fields are `None`.
        assert!(!xml.contains("<command>"));
        assert!(!xml.contains("<skill>"));
    }

    #[test]
    fn test_format_capability_index_program() {
        let tool = ScoredTool {
            entry: ToolEntry {
                name: "git-helper".into(),
                category: "program".into(),
                description: "Git workflow automation".into(),
                skill_path: Some("programs/git-helper/SKILL.md".into()),
                command: Some("git-helper".into()),
            },
            score: 0.88,
        };

        let xml = format_capability_index(&[tool]);
        assert!(xml.contains("<command>git-helper</command>"));
        assert!(xml.contains("<skill>programs/git-helper/SKILL.md</skill>"));
    }

    #[test]
    fn test_format_capability_index_xml_escaping() {
        let tool = ScoredTool {
            entry: ToolEntry {
                name: "test<>&".into(),
                category: "os-tool".into(),
                description: "A & B < C > D".into(),
                skill_path: None,
                command: None,
            },
            score: 1.0,
        };

        // Special characters in values must be emitted as XML entities.
        let xml = format_capability_index(&[tool]);
        assert!(xml.contains("<name>test&lt;&gt;&amp;</name>"));
        assert!(xml.contains("<description>A &amp; B &lt; C &gt; D</description>"));
    }

    #[test]
    fn test_escape_xml() {
        assert_eq!(escape_xml("hello"), "hello");
        assert_eq!(
            escape_xml("a&b<c>d\"e'f"),
            "a&amp;b&lt;c&gt;d&quot;e&apos;f"
        );
    }

    #[test]
    fn test_build_kernel_manifest() {
        let md = build_kernel_manifest(&["space", "agent", "memory", "program"]);
        assert!(md.contains("## Kernel Manifest"));
        assert!(md.contains("Active domains: space, agent, memory, program"));
        assert!(md.contains("### space"));
        assert!(md.contains("### agent"));
        assert!(md.contains("### memory"));
        assert!(md.contains("### program"));
        assert!(!md.contains("### security"));
    }

    #[test]
    fn test_build_kernel_manifest_filters_unknown() {
        let md = build_kernel_manifest(&["space", "unknown-domain"]);
        assert!(md.contains("### space"));
        assert!(!md.contains("unknown-domain"));
    }

    #[test]
    fn test_build_kernel_manifest_empty() {
        let md = build_kernel_manifest(&[]);
        assert!(md.contains("## Kernel Manifest"));
        assert!(md.contains("Active domains:"));
    }

    #[test]
    fn test_tool_entry_embedding_text() {
        let entry = mock_entry("exec", "os-tool", "Run commands");
        let text = entry.embedding_text();
        assert!(text.contains("[os-tool] exec: Run commands"));
    }

    #[test]
    fn test_tool_entry_embedding_text_with_command() {
        let entry = ToolEntry {
            name: "git".into(),
            category: "program".into(),
            description: "Git ops".into(),
            skill_path: None,
            command: Some("git binary".into()),
        };
        let text = entry.embedding_text();
        assert!(text.contains("command: git binary"));
    }

    #[tokio::test]
    async fn test_embedder_accessor() {
        let embedder = Arc::new(MockEmbedder);
        let retriever = ToolRetriever::new(embedder);
        assert_eq!(retriever.embedder().name(), "mock");
    }

    /// End-to-end check with the crate's TF-IDF provider instead of the mock.
    #[tokio::test]
    async fn test_with_tfidf_embedder() {
        use crate::embedding::TfIdfEmbeddingProvider;

        let embedder = Arc::new(TfIdfEmbeddingProvider);
        let mut retriever = ToolRetriever::new(embedder);

        retriever
            .index_tool(ToolEntry {
                name: "exec".into(),
                category: "os-tool".into(),
                description: "Execute shell commands in workspace".into(),
                skill_path: None,
                command: None,
            })
            .await;
        retriever
            .index_tool(ToolEntry {
                name: "memory-search".into(),
                category: "os-tool".into(),
                description: "Search persistent vector memory".into(),
                skill_path: None,
                command: None,
            })
            .await;

        let query_embedding = retriever
            .embedder()
            .embed("run a bash command")
            .await
            .unwrap();
        let results = retriever.retrieve(&query_embedding, 2);

        assert_eq!(results.len(), 2);
        // The command-related query is expected to rank `exec` first.
        assert_eq!(results[0].entry.name, "exec");
    }
}