Skip to main content

oxios_kernel/tools/
retrieval.rs

1//! Semantic search engine for OS capabilities.
2//!
3//! [`ToolRetriever`] maintains an in-memory index of all available tools
4//! (built-in OS tools, installed programs, OS services, and MCP bridges)
5//! and retrieves the most relevant ones for a given query using the
6//! embedding module's cosine similarity.
7//!
8//! # Usage
9//!
10//! ```no_run
11//! use std::sync::Arc;
12//! use oxios_kernel::embedding::TfIdfEmbeddingProvider;
13//! use oxios_kernel::tools::retrieval::{ToolRetriever, ToolEntry};
14//!
15//! # async fn example() {
16//! let embedder = Arc::new(TfIdfEmbeddingProvider);
17//! let mut retriever = ToolRetriever::new(embedder);
18//!
19//! let tool = ToolEntry {
20//!     name: "exec".into(),
21//!     category: "os-tool".into(),
22//!     description: "Execute a shell command in a workspace".into(),
23//!     skill_path: None,
24//!     command: None,
25//! };
26//! retriever.index_tool(tool).await;
27//! # }
28//! ```
29use std::sync::Arc;
30
31use serde::{Deserialize, Serialize};
32
33use crate::embedding::{EmbeddingProvider, EmbeddingVector};
34
35// ---------------------------------------------------------------------------
36// Types
37// ---------------------------------------------------------------------------
38
39/// A searchable entry in the tool index.
40///
41/// Each entry describes a single capability that the agent OS exposes,
42/// such as a built-in execution tool, an installed program, an OS service,
43/// or an MCP bridge.
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct ToolEntry {
46    /// Unique capability name (e.g. `"exec"`, `"git-helper"`, `"mcp:github"`).
47    pub name: String,
48    /// Category of the capability.
49    ///
50    /// One of: `"os-tool"`, `"program"`, `"os-service"`, `"mcp"`.
51    pub category: String,
52    /// Human-readable description used both for indexing and for the
53    /// capability index presented to agents.
54    pub description: String,
55    /// Path to the SKILL.md instruction file, if this is a program.
56    pub skill_path: Option<String>,
57    /// Invocation command, if this is a program that can be called directly.
58    pub command: Option<String>,
59}
60
61impl ToolEntry {
62    /// Produce the text that will be embedded for semantic search.
63    ///
64    /// Combines name, category, and description into a single string so that
65    /// the embedding captures all relevant semantics.
66    fn embedding_text(&self) -> String {
67        let mut parts = format!("[{}] {}: {}", self.category, self.name, self.description);
68        if let Some(ref cmd) = self.command {
69            parts.push_str(&format!(" command: {cmd}"));
70        }
71        parts
72    }
73}
74
75/// A tool entry together with its pre-computed embedding vector.
76#[derive(Debug, Clone)]
77struct IndexedTool {
78    entry: ToolEntry,
79    vector: EmbeddingVector,
80}
81
82/// A tool ranked by relevance to a query.
83#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct ScoredTool {
85    /// The tool entry that matched.
86    pub entry: ToolEntry,
87    /// Cosine similarity score in `[0.0, 1.0]`.  Higher is more relevant.
88    pub score: f64,
89}
90
91// ---------------------------------------------------------------------------
92// ToolRetriever
93// ---------------------------------------------------------------------------
94
95/// Semantic search engine for OS capabilities.
96///
97/// Maintains an in-memory vector index of all registered tools and supports
98/// top-K retrieval via cosine similarity against a query embedding.
99pub struct ToolRetriever {
100    /// The indexed tools with their pre-computed embeddings.
101    index: Vec<IndexedTool>,
102    /// The embedding provider used to vectorize tool descriptions.
103    embedder: Arc<dyn EmbeddingProvider>,
104}
105
106impl ToolRetriever {
107    /// Create a new, empty retriever backed by the given embedder.
108    pub fn new(embedder: Arc<dyn EmbeddingProvider>) -> Self {
109        Self {
110            index: Vec::new(),
111            embedder,
112        }
113    }
114
115    /// Return a reference to the underlying embedder.
116    ///
117    /// Useful when the caller needs to compute a query embedding before
118    /// calling [`retrieve`](Self::retrieve).
119    pub fn embedder(&self) -> &Arc<dyn EmbeddingProvider> {
120        &self.embedder
121    }
122
123    /// Add a tool to the index.
124    ///
125    /// The tool's description is embedded immediately using the configured
126    /// provider.  If the embedding fails the tool is silently skipped
127    /// (logged at warn level in future telemetry).
128    pub async fn index_tool(&mut self, entry: ToolEntry) {
129        let text = entry.embedding_text();
130        match self.embedder.embed(&text).await {
131            Ok(vector) => {
132                self.index.push(IndexedTool { entry, vector });
133            }
134            Err(e) => {
135                tracing::warn!(name = %entry.name, error = %e, "failed to embed tool, skipping");
136            }
137        }
138    }
139
140    /// Retrieve the top-K tools most relevant to the given query embedding.
141    ///
142    /// The `query_embedding` is compared against every indexed tool using
143    /// [`EmbeddingVector::cosine_similarity`].  Results are sorted by score
144    /// descending.
145    ///
146    /// If `top_k` exceeds the number of indexed tools, all tools are returned.
147    pub fn retrieve(&self, query_embedding: &EmbeddingVector, top_k: usize) -> Vec<ScoredTool> {
148        let mut scored: Vec<ScoredTool> = self
149            .index
150            .iter()
151            .map(|indexed| {
152                let score = query_embedding.cosine_similarity(&indexed.vector);
153                ScoredTool {
154                    entry: indexed.entry.clone(),
155                    score,
156                }
157            })
158            .collect();
159
160        // Sort descending by score.
161        scored.sort_by(|a, b| {
162            b.score
163                .partial_cmp(&a.score)
164                .unwrap_or(std::cmp::Ordering::Equal)
165        });
166
167        scored.truncate(top_k);
168        scored
169    }
170
171    /// Number of indexed tools.
172    pub fn len(&self) -> usize {
173        self.index.len()
174    }
175
176    /// Returns `true` if no tools have been indexed.
177    pub fn is_empty(&self) -> bool {
178        self.index.is_empty()
179    }
180
181    /// Get all indexed entries (for capability index generation or debugging).
182    pub fn entries(&self) -> Vec<&ToolEntry> {
183        self.index.iter().map(|i| &i.entry).collect()
184    }
185
186    /// Remove all indexed tools.
187    pub fn clear(&mut self) {
188        self.index.clear();
189    }
190}
191
192// ---------------------------------------------------------------------------
193// Capability index formatting
194// ---------------------------------------------------------------------------
195
196/// Format retrieved tools as an XML capability index suitable for injection
197/// into an agent's system prompt.
198///
199/// Example output:
200///
201/// ```xml
202/// <available_capabilities>
203///   <capability>
204///     <name>exec</name>
205///     <category>os-tool</category>
206///     <description>Execute a shell command in a workspace</description>
207///   </capability>
208///   <capability>
209///     <name>git-helper</name>
210///     <category>program</category>
211///     <description>Git workflow automation</description>
212///     <command>git-helper</command>
213///     <skill>programs/git-helper/SKILL.md</skill>
214///   </capability>
215/// </available_capabilities>
216/// ```
217pub fn format_capability_index(tools: &[ScoredTool]) -> String {
218    let mut xml = String::from("<available_capabilities>\n");
219
220    for tool in tools {
221        xml.push_str("  <capability>\n");
222        xml.push_str(&format!(
223            "    <name>{}</name>\n",
224            escape_xml(&tool.entry.name)
225        ));
226        xml.push_str(&format!(
227            "    <category>{}</category>\n",
228            escape_xml(&tool.entry.category)
229        ));
230        xml.push_str(&format!(
231            "    <description>{}</description>\n",
232            escape_xml(&tool.entry.description)
233        ));
234        if let Some(ref cmd) = tool.entry.command {
235            xml.push_str(&format!("    <command>{}</command>\n", escape_xml(cmd)));
236        }
237        if let Some(ref skill) = tool.entry.skill_path {
238            xml.push_str(&format!("    <skill>{}</skill>\n", escape_xml(skill)));
239        }
240        xml.push_str("  </capability>\n");
241    }
242
243    xml.push_str("</available_capabilities>");
244    xml
245}
246
/// Replace the five XML special characters (`&`, `<`, `>`, `"`, `'`) with
/// their predefined entities; every other character is copied unchanged.
fn escape_xml(s: &str) -> String {
    /// Entity for `ch`, or `None` when the character needs no escaping.
    fn entity_for(ch: char) -> Option<&'static str> {
        match ch {
            '&' => Some("&amp;"),
            '<' => Some("&lt;"),
            '>' => Some("&gt;"),
            '"' => Some("&quot;"),
            '\'' => Some("&apos;"),
            _ => None,
        }
    }

    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match entity_for(ch) {
            Some(entity) => escaped.push_str(entity),
            None => escaped.push(ch),
        }
    }
    escaped
}
262
263// ---------------------------------------------------------------------------
264// Kernel manifest
265// ---------------------------------------------------------------------------
266
/// Domain names recognised in a kernel manifest.
///
/// `build_kernel_manifest` drops any requested domain that is not listed
/// here. Each name has a matching arm in `domain_description`; keep the two
/// in sync when adding a domain.
const KNOWN_DOMAINS: &[&str] = &[
    "space",
    "agent",
    "a2a",
    "memory",
    "knowledge",
    "security",
    "budget",
    "resource",
    "program",
];
279
280/// Build a markdown kernel manifest from the set of active domains.
281///
282/// The manifest lists which subsystems are currently enabled so that an
283/// agent can discover the OS's capabilities at a glance.
284///
285/// # Example
286///
287/// ```text
288/// ## Kernel Manifest
289///
290/// Active domains: space, agent, memory, program
291///
292/// ### space
293/// Filesystem workspace management and conversation buffers.
294///
295/// ### agent
296/// ...
297/// ```
298pub fn build_kernel_manifest(active_domains: &[&str]) -> String {
299    let mut md = String::from("## Kernel Manifest\n\n");
300
301    let domain_list: Vec<&str> = active_domains
302        .iter()
303        .filter(|d| KNOWN_DOMAINS.contains(d))
304        .copied()
305        .collect();
306
307    md.push_str(&format!("Active domains: {}\n\n", domain_list.join(", ")));
308
309    for domain in &domain_list {
310        let description = domain_description(domain);
311        md.push_str(&format!("### {domain}\n{description}\n\n"));
312    }
313
314    md
315}
316
/// Look up the short human-readable description of a domain.
///
/// Returns `"Unknown domain."` for any name that has no entry in the table.
fn domain_description(domain: &str) -> &'static str {
    // (domain name, description) pairs.
    const DESCRIPTIONS: &[(&str, &str)] = &[
        (
            "space",
            "Filesystem workspace management and conversation buffers.",
        ),
        ("agent", "Agent lifecycle, runtime, and supervisor."),
        ("a2a", "Agent-to-agent communication and delegation."),
        (
            "memory",
            "Internal agent recall — facts, preferences, behavioral patterns. Not user-visible.",
        ),
        (
            "knowledge",
            "Personal markdown vault — documents, articles, notes, journal. File-based with backlinks and full-text search.",
        ),
        ("security", "RBAC access control and audit trail."),
        ("budget", "Token and cost budget enforcement."),
        (
            "resource",
            "System resource monitoring and overload protection.",
        ),
        ("program", "Installable OS-level programs and tools."),
    ];

    DESCRIPTIONS
        .iter()
        .find(|pair| pair.0 == domain)
        .map_or("Unknown domain.", |pair| pair.1)
}
336
337// ---------------------------------------------------------------------------
338// Tests
339// ---------------------------------------------------------------------------
340
#[cfg(test)]
mod tests {
    use super::*;

    /// A trivial embedder for unit-testing retrieval logic without depending
    /// on TF-IDF or external models.
    ///
    /// Empty text maps to an empty dense vector. Any other text maps to a
    /// three-component dense vector whose middle component is the text's
    /// byte length divided by 100, so texts of different lengths get
    /// different vectors.
    struct MockEmbedder;

    #[async_trait::async_trait]
    impl EmbeddingProvider for MockEmbedder {
        async fn embed(&self, text: &str) -> anyhow::Result<EmbeddingVector> {
            if text.is_empty() {
                return Ok(EmbeddingVector::DenseF32(vec![]));
            }
            // Produce a deterministic vector based on text length so different
            // texts get different vectors.
            let len = text.len() as f32;
            Ok(EmbeddingVector::DenseF32(vec![1.0, len / 100.0, 0.5]))
        }

        fn name(&self) -> &str {
            "mock"
        }
    }

    /// Build a [`ToolEntry`] with no skill path and no command.
    fn mock_entry(name: &str, category: &str, desc: &str) -> ToolEntry {
        ToolEntry {
            name: name.to_string(),
            category: category.to_string(),
            description: desc.to_string(),
            skill_path: None,
            command: None,
        }
    }

    /// `len` and `is_empty` reflect the number of indexed tools.
    #[tokio::test]
    async fn test_index_and_len() {
        let embedder = Arc::new(MockEmbedder);
        let mut retriever = ToolRetriever::new(embedder);

        assert!(retriever.is_empty());
        assert_eq!(retriever.len(), 0);

        retriever
            .index_tool(mock_entry("exec", "os-tool", "Run commands"))
            .await;
        retriever
            .index_tool(mock_entry("git", "program", "Git operations"))
            .await;

        assert_eq!(retriever.len(), 2);
        assert!(!retriever.is_empty());
    }

    /// `retrieve` caps the result at `top_k` and orders it by score.
    #[tokio::test]
    async fn test_retrieve_top_k() {
        let embedder = Arc::new(MockEmbedder);
        let mut retriever = ToolRetriever::new(embedder);

        retriever
            .index_tool(mock_entry("exec", "os-tool", "Run shell commands"))
            .await;
        retriever
            .index_tool(mock_entry("git", "program", "Git version control"))
            .await;
        retriever
            .index_tool(mock_entry("mcp-github", "mcp", "GitHub API bridge"))
            .await;

        let query = EmbeddingVector::DenseF32(vec![1.0, 0.5, 0.5]);
        let results = retriever.retrieve(&query, 2);

        assert_eq!(results.len(), 2);
        // Results should be sorted by score descending.
        assert!(results[0].score >= results[1].score);
    }

    /// A `top_k` larger than the index returns every indexed tool.
    #[tokio::test]
    async fn test_retrieve_exceeds_index() {
        let embedder = Arc::new(MockEmbedder);
        let mut retriever = ToolRetriever::new(embedder);

        retriever
            .index_tool(mock_entry("exec", "os-tool", "Run commands"))
            .await;

        let query = EmbeddingVector::DenseF32(vec![1.0, 0.5, 0.5]);
        let results = retriever.retrieve(&query, 10);

        // Should return all available tools, not panic.
        assert_eq!(results.len(), 1);
    }

    /// Retrieving from an empty index yields no results.
    #[tokio::test]
    async fn test_retrieve_empty_index() {
        let embedder = Arc::new(MockEmbedder);
        let retriever = ToolRetriever::new(embedder);

        let query = EmbeddingVector::DenseF32(vec![1.0, 0.5, 0.5]);
        let results = retriever.retrieve(&query, 5);

        assert!(results.is_empty());
    }

    /// `entries` returns the indexed entries in insertion order.
    #[tokio::test]
    async fn test_entries() {
        let embedder = Arc::new(MockEmbedder);
        let mut retriever = ToolRetriever::new(embedder);

        retriever
            .index_tool(mock_entry("exec", "os-tool", "Run commands"))
            .await;
        retriever
            .index_tool(mock_entry("git", "program", "Git ops"))
            .await;

        let entries = retriever.entries();
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0].name, "exec");
        assert_eq!(entries[1].name, "git");
    }

    /// `clear` empties the index.
    #[tokio::test]
    async fn test_clear() {
        let embedder = Arc::new(MockEmbedder);
        let mut retriever = ToolRetriever::new(embedder);

        retriever
            .index_tool(mock_entry("exec", "os-tool", "Run commands"))
            .await;
        assert_eq!(retriever.len(), 1);

        retriever.clear();
        assert!(retriever.is_empty());
    }

    /// An entry without command or skill path renders only the three
    /// mandatory elements.
    #[test]
    fn test_format_capability_index_basic() {
        let tool = ScoredTool {
            entry: ToolEntry {
                name: "exec".into(),
                category: "os-tool".into(),
                description: "Execute shell commands".into(),
                skill_path: None,
                command: None,
            },
            score: 0.95,
        };

        let xml = format_capability_index(&[tool]);
        assert!(xml.contains("<available_capabilities>"));
        assert!(xml.contains("<name>exec</name>"));
        assert!(xml.contains("<category>os-tool</category>"));
        assert!(xml.contains("<description>Execute shell commands</description>"));
        assert!(xml.contains("</available_capabilities>"));
        // No command/skill tags for os-tool.
        assert!(!xml.contains("<command>"));
        assert!(!xml.contains("<skill>"));
    }

    /// An entry with command and skill path renders both optional elements.
    #[test]
    fn test_format_capability_index_program() {
        let tool = ScoredTool {
            entry: ToolEntry {
                name: "git-helper".into(),
                category: "program".into(),
                description: "Git workflow automation".into(),
                skill_path: Some("programs/git-helper/SKILL.md".into()),
                command: Some("git-helper".into()),
            },
            score: 0.88,
        };

        let xml = format_capability_index(&[tool]);
        assert!(xml.contains("<command>git-helper</command>"));
        assert!(xml.contains("<skill>programs/git-helper/SKILL.md</skill>"));
    }

    /// Special characters in entry fields are escaped in the rendered XML.
    #[test]
    fn test_format_capability_index_xml_escaping() {
        let tool = ScoredTool {
            entry: ToolEntry {
                name: "test<>&".into(),
                category: "os-tool".into(),
                description: "A & B < C > D".into(),
                skill_path: None,
                command: None,
            },
            score: 1.0,
        };

        let xml = format_capability_index(&[tool]);
        assert!(xml.contains("<name>test&lt;&gt;&amp;</name>"));
        assert!(xml.contains("<description>A &amp; B &lt; C &gt; D</description>"));
    }

    /// `escape_xml` leaves plain text alone and escapes all five special
    /// characters.
    #[test]
    fn test_escape_xml() {
        assert_eq!(escape_xml("hello"), "hello");
        assert_eq!(
            escape_xml("a&b<c>d\"e'f"),
            "a&amp;b&lt;c&gt;d&quot;e&apos;f"
        );
    }

    /// The manifest lists exactly the requested known domains, in order.
    #[test]
    fn test_build_kernel_manifest() {
        let md = build_kernel_manifest(&["space", "agent", "memory", "knowledge", "program"]);
        assert!(md.contains("## Kernel Manifest"));
        assert!(md.contains("Active domains: space, agent, memory, knowledge, program"));
        assert!(md.contains("### space"));
        assert!(md.contains("### agent"));
        assert!(md.contains("### memory"));
        assert!(md.contains("### knowledge"));
        assert!(md.contains("### program"));
        assert!(!md.contains("### security"));
    }

    /// Domains that are not in `KNOWN_DOMAINS` never reach the manifest.
    #[test]
    fn test_build_kernel_manifest_filters_unknown() {
        let md = build_kernel_manifest(&["space", "unknown-domain"]);
        assert!(md.contains("### space"));
        assert!(!md.contains("unknown-domain"));
    }

    /// With no active domains the heading and summary line are still emitted.
    #[test]
    fn test_build_kernel_manifest_empty() {
        let md = build_kernel_manifest(&[]);
        assert!(md.contains("## Kernel Manifest"));
        assert!(md.contains("Active domains:"));
    }

    /// The embedding text combines category, name and description.
    #[test]
    fn test_tool_entry_embedding_text() {
        let entry = mock_entry("exec", "os-tool", "Run commands");
        let text = entry.embedding_text();
        assert!(text.contains("[os-tool] exec: Run commands"));
    }

    /// The embedding text includes the command when one is set.
    #[test]
    fn test_tool_entry_embedding_text_with_command() {
        let entry = ToolEntry {
            name: "git".into(),
            category: "program".into(),
            description: "Git ops".into(),
            skill_path: None,
            command: Some("git binary".into()),
        };
        let text = entry.embedding_text();
        assert!(text.contains("command: git binary"));
    }

    /// `embedder` exposes the provider passed to `new`.
    #[tokio::test]
    async fn test_embedder_accessor() {
        let embedder = Arc::new(MockEmbedder);
        let retriever = ToolRetriever::new(embedder);
        assert_eq!(retriever.embedder().name(), "mock");
    }

    // --- Integration-style test with TfIdfEmbeddingProvider ---

    /// End-to-end check with the real TF-IDF provider: a shell-related query
    /// is expected to rank the `exec` tool first.
    #[tokio::test]
    async fn test_with_tfidf_embedder() {
        use crate::embedding::TfIdfEmbeddingProvider;

        let embedder = Arc::new(TfIdfEmbeddingProvider);
        let mut retriever = ToolRetriever::new(embedder);

        retriever
            .index_tool(ToolEntry {
                name: "exec".into(),
                category: "os-tool".into(),
                description: "Execute shell commands in workspace".into(),
                skill_path: None,
                command: None,
            })
            .await;
        retriever
            .index_tool(ToolEntry {
                name: "memory-search".into(),
                category: "os-tool".into(),
                description: "Search persistent vector memory".into(),
                skill_path: None,
                command: None,
            })
            .await;

        let query_embedding = retriever
            .embedder()
            .embed("run a bash command")
            .await
            .unwrap();
        let results = retriever.retrieve(&query_embedding, 2);

        assert_eq!(results.len(), 2);
        // "exec" should score higher for "run a bash command" query.
        assert_eq!(results[0].entry.name, "exec");
    }
}