1use std::sync::Arc;
30
31use serde::{Deserialize, Serialize};
32
33use crate::embedding::{EmbeddingProvider, EmbeddingVector};
34
/// A tool or capability that can be indexed and retrieved by semantic search.
///
/// Its embedding text comes from `ToolEntry::embedding_text`, and it is rendered
/// into the capability index by `format_capability_index`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolEntry {
    /// Tool name; part of the embedding text and emitted as `<name>`.
    pub name: String,
    /// Category label (e.g. `"os-tool"`, `"program"`); prefixed as `[category]`
    /// in the embedding text and emitted as `<category>`.
    pub category: String,
    /// Human-readable description; part of the embedding text and emitted as
    /// `<description>`.
    pub description: String,
    /// Optional path to a skill document (e.g. `programs/git-helper/SKILL.md`);
    /// emitted as `<skill>` when present.
    pub skill_path: Option<String>,
    /// Optional command; appended to the embedding text and emitted as
    /// `<command>` when present.
    pub command: Option<String>,
}
60
61impl ToolEntry {
62 fn embedding_text(&self) -> String {
67 let mut parts = format!("[{}] {}: {}", self.category, self.name, self.description);
68 if let Some(ref cmd) = self.command {
69 parts.push_str(&format!(" command: {cmd}"));
70 }
71 parts
72 }
73}
74
/// A `ToolEntry` stored together with the embedding computed from its
/// `embedding_text` when it was indexed.
#[derive(Debug, Clone)]
struct IndexedTool {
    entry: ToolEntry,
    /// Embedding of `entry`; compared against query embeddings in
    /// `ToolRetriever::retrieve`.
    vector: EmbeddingVector,
}
81
/// One retrieval result: a tool entry and its similarity to the query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoredTool {
    pub entry: ToolEntry,
    /// Cosine similarity between the query embedding and the tool's embedding;
    /// results are ordered by this value, highest first.
    pub score: f64,
}
90
/// In-memory semantic index over tools.
///
/// Each tool is embedded once when indexed (`index_tool`); queries are ranked
/// by cosine similarity against those stored embeddings (`retrieve`).
pub struct ToolRetriever {
    /// Indexed tools, in insertion order.
    index: Vec<IndexedTool>,
    /// Provider used to embed tool descriptions.
    embedder: Arc<dyn EmbeddingProvider>,
}
105
106impl ToolRetriever {
107 pub fn new(embedder: Arc<dyn EmbeddingProvider>) -> Self {
109 Self {
110 index: Vec::new(),
111 embedder,
112 }
113 }
114
115 pub fn embedder(&self) -> &Arc<dyn EmbeddingProvider> {
120 &self.embedder
121 }
122
123 pub async fn index_tool(&mut self, entry: ToolEntry) {
129 let text = entry.embedding_text();
130 match self.embedder.embed(&text).await {
131 Ok(vector) => {
132 self.index.push(IndexedTool { entry, vector });
133 }
134 Err(e) => {
135 tracing::warn!(name = %entry.name, error = %e, "failed to embed tool, skipping");
136 }
137 }
138 }
139
140 pub fn retrieve(&self, query_embedding: &EmbeddingVector, top_k: usize) -> Vec<ScoredTool> {
148 let mut scored: Vec<ScoredTool> = self
149 .index
150 .iter()
151 .map(|indexed| {
152 let score = query_embedding.cosine_similarity(&indexed.vector);
153 ScoredTool {
154 entry: indexed.entry.clone(),
155 score,
156 }
157 })
158 .collect();
159
160 scored.sort_by(|a, b| {
162 b.score
163 .partial_cmp(&a.score)
164 .unwrap_or(std::cmp::Ordering::Equal)
165 });
166
167 scored.truncate(top_k);
168 scored
169 }
170
171 pub fn len(&self) -> usize {
173 self.index.len()
174 }
175
176 pub fn is_empty(&self) -> bool {
178 self.index.is_empty()
179 }
180
181 pub fn entries(&self) -> Vec<&ToolEntry> {
183 self.index.iter().map(|i| &i.entry).collect()
184 }
185
186 pub fn clear(&mut self) {
188 self.index.clear();
189 }
190}
191
192pub fn format_capability_index(tools: &[ScoredTool]) -> String {
218 let mut xml = String::from("<available_capabilities>\n");
219
220 for tool in tools {
221 xml.push_str(" <capability>\n");
222 xml.push_str(&format!(
223 " <name>{}</name>\n",
224 escape_xml(&tool.entry.name)
225 ));
226 xml.push_str(&format!(
227 " <category>{}</category>\n",
228 escape_xml(&tool.entry.category)
229 ));
230 xml.push_str(&format!(
231 " <description>{}</description>\n",
232 escape_xml(&tool.entry.description)
233 ));
234 if let Some(ref cmd) = tool.entry.command {
235 xml.push_str(&format!(" <command>{}</command>\n", escape_xml(cmd)));
236 }
237 if let Some(ref skill) = tool.entry.skill_path {
238 xml.push_str(&format!(" <skill>{}</skill>\n", escape_xml(skill)));
239 }
240 xml.push_str(" </capability>\n");
241 }
242
243 xml.push_str("</available_capabilities>");
244 xml
245}
246
/// Escapes the five XML special characters so `s` can be embedded safely in
/// element text or attribute values.
///
/// `&` → `&amp;`, `<` → `&lt;`, `>` → `&gt;`, `"` → `&quot;`, `'` → `&apos;`.
/// Every other character (including non-ASCII) is copied through unchanged.
fn escape_xml(s: &str) -> String {
    // Escaped output is never shorter than the input, so this is a lower bound.
    let mut out = String::with_capacity(s.len());
    for c in s.chars() {
        match c {
            '&' => out.push_str("&amp;"),
            '<' => out.push_str("&lt;"),
            '>' => out.push_str("&gt;"),
            '"' => out.push_str("&quot;"),
            '\'' => out.push_str("&apos;"),
            _ => out.push(c),
        }
    }
    out
}
262
/// Kernel domains that `build_kernel_manifest` is allowed to list.
///
/// Every entry here should have a matching arm in `domain_description`.
const KNOWN_DOMAINS: &[&str] = &[
    "space",
    "agent",
    "a2a",
    "memory",
    "knowledge",
    "security",
    "budget",
    "resource",
    "program",
];

/// Renders the "Kernel Manifest" markdown section for `active_domains`.
///
/// Names not in `KNOWN_DOMAINS` are dropped. Each remaining domain is listed
/// once, in order of first appearance, on the `Active domains:` line and gets
/// its own `### <domain>` section with a one-line description.
pub fn build_kernel_manifest(active_domains: &[&str]) -> String {
    // Keep only the first occurrence of each known domain so a repeated name
    // does not produce a duplicated heading. Both lists are tiny, so linear
    // `contains` checks are cheaper than building a set.
    let mut domain_list: Vec<&str> = Vec::with_capacity(active_domains.len());
    for &domain in active_domains {
        if KNOWN_DOMAINS.contains(&domain) && !domain_list.contains(&domain) {
            domain_list.push(domain);
        }
    }

    let mut md = String::from("## Kernel Manifest\n\n");
    md.push_str(&format!("Active domains: {}\n\n", domain_list.join(", ")));

    for domain in &domain_list {
        let description = domain_description(domain);
        md.push_str(&format!("### {domain}\n{description}\n\n"));
    }

    md
}

/// One-line description for a kernel domain.
///
/// Falls back to `"Unknown domain."` for names without an arm;
/// `build_kernel_manifest` only calls this with names from `KNOWN_DOMAINS`.
fn domain_description(domain: &str) -> &'static str {
    match domain {
        "space" => "Filesystem workspace management and conversation buffers.",
        "agent" => "Agent lifecycle, runtime, and supervisor.",
        "a2a" => "Agent-to-agent communication and delegation.",
        "memory" => {
            "Internal agent recall — facts, preferences, behavioral patterns. Not user-visible."
        }
        "knowledge" => {
            "Personal markdown vault — documents, articles, notes, journal. File-based with backlinks and full-text search."
        }
        "security" => "RBAC access control and audit trail.",
        "budget" => "Token and cost budget enforcement.",
        "resource" => "System resource monitoring and overload protection.",
        "program" => "Installable OS-level programs and tools.",
        _ => "Unknown domain.",
    }
}
336
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic test embedder: returns an empty vector for empty text,
    /// otherwise a 3-component vector whose middle value scales with the
    /// text length (so equal-length texts embed identically).
    struct MockEmbedder;

    #[async_trait::async_trait]
    impl EmbeddingProvider for MockEmbedder {
        async fn embed(&self, text: &str) -> anyhow::Result<EmbeddingVector> {
            if text.is_empty() {
                return Ok(EmbeddingVector::DenseF32(vec![]));
            }
            let len = text.len() as f32;
            Ok(EmbeddingVector::DenseF32(vec![1.0, len / 100.0, 0.5]))
        }

        fn name(&self) -> &str {
            "mock"
        }
    }

    fn mock_entry(name: &str, category: &str, desc: &str) -> ToolEntry {
        ToolEntry {
            name: name.to_string(),
            category: category.to_string(),
            description: desc.to_string(),
            skill_path: None,
            command: None,
        }
    }

    #[tokio::test]
    async fn test_index_and_len() {
        let embedder = Arc::new(MockEmbedder);
        let mut retriever = ToolRetriever::new(embedder);

        assert!(retriever.is_empty());
        assert_eq!(retriever.len(), 0);

        retriever
            .index_tool(mock_entry("exec", "os-tool", "Run commands"))
            .await;
        retriever
            .index_tool(mock_entry("git", "program", "Git operations"))
            .await;

        assert_eq!(retriever.len(), 2);
        assert!(!retriever.is_empty());
    }

    #[tokio::test]
    async fn test_retrieve_top_k() {
        let embedder = Arc::new(MockEmbedder);
        let mut retriever = ToolRetriever::new(embedder);

        retriever
            .index_tool(mock_entry("exec", "os-tool", "Run shell commands"))
            .await;
        retriever
            .index_tool(mock_entry("git", "program", "Git version control"))
            .await;
        retriever
            .index_tool(mock_entry("mcp-github", "mcp", "GitHub API bridge"))
            .await;

        let query = EmbeddingVector::DenseF32(vec![1.0, 0.5, 0.5]);
        let results = retriever.retrieve(&query, 2);

        assert_eq!(results.len(), 2);
        assert!(results[0].score >= results[1].score);
    }

    #[tokio::test]
    async fn test_retrieve_exceeds_index() {
        let embedder = Arc::new(MockEmbedder);
        let mut retriever = ToolRetriever::new(embedder);

        retriever
            .index_tool(mock_entry("exec", "os-tool", "Run commands"))
            .await;

        let query = EmbeddingVector::DenseF32(vec![1.0, 0.5, 0.5]);
        let results = retriever.retrieve(&query, 10);

        assert_eq!(results.len(), 1);
    }

    #[tokio::test]
    async fn test_retrieve_zero_top_k() {
        let embedder = Arc::new(MockEmbedder);
        let mut retriever = ToolRetriever::new(embedder);

        retriever
            .index_tool(mock_entry("exec", "os-tool", "Run commands"))
            .await;

        let query = EmbeddingVector::DenseF32(vec![1.0, 0.5, 0.5]);
        assert!(retriever.retrieve(&query, 0).is_empty());
    }

    #[tokio::test]
    async fn test_retrieve_ties_keep_insertion_order() {
        let embedder = Arc::new(MockEmbedder);
        let mut retriever = ToolRetriever::new(embedder);

        // Equal-length embedding texts => identical mock vectors => tied scores.
        retriever
            .index_tool(mock_entry("aaaa", "os-tool", "Run commands"))
            .await;
        retriever
            .index_tool(mock_entry("bbbb", "os-tool", "Run commands"))
            .await;

        let query = EmbeddingVector::DenseF32(vec![1.0, 0.5, 0.5]);
        let results = retriever.retrieve(&query, 2);

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].entry.name, "aaaa");
        assert_eq!(results[1].entry.name, "bbbb");
    }

    #[tokio::test]
    async fn test_retrieve_empty_index() {
        let embedder = Arc::new(MockEmbedder);
        let retriever = ToolRetriever::new(embedder);

        let query = EmbeddingVector::DenseF32(vec![1.0, 0.5, 0.5]);
        let results = retriever.retrieve(&query, 5);

        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_entries() {
        let embedder = Arc::new(MockEmbedder);
        let mut retriever = ToolRetriever::new(embedder);

        retriever
            .index_tool(mock_entry("exec", "os-tool", "Run commands"))
            .await;
        retriever
            .index_tool(mock_entry("git", "program", "Git ops"))
            .await;

        let entries = retriever.entries();
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0].name, "exec");
        assert_eq!(entries[1].name, "git");
    }

    #[tokio::test]
    async fn test_clear() {
        let embedder = Arc::new(MockEmbedder);
        let mut retriever = ToolRetriever::new(embedder);

        retriever
            .index_tool(mock_entry("exec", "os-tool", "Run commands"))
            .await;
        assert_eq!(retriever.len(), 1);

        retriever.clear();
        assert!(retriever.is_empty());
    }

    #[test]
    fn test_format_capability_index_basic() {
        let tool = ScoredTool {
            entry: ToolEntry {
                name: "exec".into(),
                category: "os-tool".into(),
                description: "Execute shell commands".into(),
                skill_path: None,
                command: None,
            },
            score: 0.95,
        };

        let xml = format_capability_index(&[tool]);
        assert!(xml.contains("<available_capabilities>"));
        assert!(xml.contains("<name>exec</name>"));
        assert!(xml.contains("<category>os-tool</category>"));
        assert!(xml.contains("<description>Execute shell commands</description>"));
        assert!(xml.contains("</available_capabilities>"));
        assert!(!xml.contains("<command>"));
        assert!(!xml.contains("<skill>"));
    }

    #[test]
    fn test_format_capability_index_empty() {
        let xml = format_capability_index(&[]);
        assert_eq!(xml, "<available_capabilities>\n</available_capabilities>");
    }

    #[test]
    fn test_format_capability_index_program() {
        let tool = ScoredTool {
            entry: ToolEntry {
                name: "git-helper".into(),
                category: "program".into(),
                description: "Git workflow automation".into(),
                skill_path: Some("programs/git-helper/SKILL.md".into()),
                command: Some("git-helper".into()),
            },
            score: 0.88,
        };

        let xml = format_capability_index(&[tool]);
        assert!(xml.contains("<command>git-helper</command>"));
        assert!(xml.contains("<skill>programs/git-helper/SKILL.md</skill>"));
    }

    #[test]
    fn test_format_capability_index_xml_escaping() {
        let tool = ScoredTool {
            entry: ToolEntry {
                name: "test<>&".into(),
                category: "os-tool".into(),
                description: "A & B < C > D".into(),
                skill_path: None,
                command: None,
            },
            score: 1.0,
        };

        let xml = format_capability_index(&[tool]);
        assert!(xml.contains("<name>test&lt;&gt;&amp;</name>"));
        assert!(xml.contains("<description>A &amp; B &lt; C &gt; D</description>"));
    }

    #[test]
    fn test_escape_xml() {
        assert_eq!(escape_xml("hello"), "hello");
        assert_eq!(escape_xml(""), "");
        assert_eq!(
            escape_xml("a&b<c>d\"e'f"),
            "a&amp;b&lt;c&gt;d&quot;e&apos;f"
        );
    }

    #[test]
    fn test_build_kernel_manifest() {
        let md = build_kernel_manifest(&["space", "agent", "memory", "knowledge", "program"]);
        assert!(md.contains("## Kernel Manifest"));
        assert!(md.contains("Active domains: space, agent, memory, knowledge, program"));
        assert!(md.contains("### space"));
        assert!(md.contains("### agent"));
        assert!(md.contains("### memory"));
        assert!(md.contains("### knowledge"));
        assert!(md.contains("### program"));
        assert!(!md.contains("### security"));
    }

    #[test]
    fn test_build_kernel_manifest_filters_unknown() {
        let md = build_kernel_manifest(&["space", "unknown-domain"]);
        assert!(md.contains("### space"));
        assert!(!md.contains("unknown-domain"));
    }

    #[test]
    fn test_build_kernel_manifest_empty() {
        let md = build_kernel_manifest(&[]);
        assert!(md.contains("## Kernel Manifest"));
        assert!(md.contains("Active domains:"));
    }

    #[test]
    fn test_tool_entry_embedding_text() {
        let entry = mock_entry("exec", "os-tool", "Run commands");
        let text = entry.embedding_text();
        assert!(text.contains("[os-tool] exec: Run commands"));
    }

    #[test]
    fn test_tool_entry_embedding_text_with_command() {
        let entry = ToolEntry {
            name: "git".into(),
            category: "program".into(),
            description: "Git ops".into(),
            skill_path: None,
            command: Some("git binary".into()),
        };
        let text = entry.embedding_text();
        assert!(text.contains("command: git binary"));
    }

    #[tokio::test]
    async fn test_embedder_accessor() {
        let embedder = Arc::new(MockEmbedder);
        let retriever = ToolRetriever::new(embedder);
        assert_eq!(retriever.embedder().name(), "mock");
    }

    #[tokio::test]
    async fn test_with_tfidf_embedder() {
        use crate::embedding::TfIdfEmbeddingProvider;

        let embedder = Arc::new(TfIdfEmbeddingProvider);
        let mut retriever = ToolRetriever::new(embedder);

        retriever
            .index_tool(ToolEntry {
                name: "exec".into(),
                category: "os-tool".into(),
                description: "Execute shell commands in workspace".into(),
                skill_path: None,
                command: None,
            })
            .await;
        retriever
            .index_tool(ToolEntry {
                name: "memory-search".into(),
                category: "os-tool".into(),
                description: "Search persistent vector memory".into(),
                skill_path: None,
                command: None,
            })
            .await;

        let query_embedding = retriever
            .embedder()
            .embed("run a bash command")
            .await
            .unwrap();
        let results = retriever.retrieve(&query_embedding, 2);

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].entry.name, "exec");
    }
}