Skip to main content

oxios_kernel/tools/
retrieval.rs

1//! Semantic search engine for OS capabilities.
2//!
3//! [`ToolRetriever`] maintains an in-memory index of all available tools
4//! (built-in OS tools, installed programs, OS services, and MCP bridges)
5//! and retrieves the most relevant ones for a given query using the
6//! embedding module's cosine similarity.
7//!
8//! # Usage
9//!
10//! ```ignore
11//! use std::sync::Arc;
12//! use oxios_kernel::embedding::TfIdfEmbeddingProvider;
13//! use oxios_kernel::tools::retrieval::{ToolRetriever, ToolEntry};
14//!
15//! let embedder = Arc::new(TfIdfEmbeddingProvider);
16//! let mut retriever = ToolRetriever::new(embedder);
17//!
18//! let tool = ToolEntry {
19//!     name: "exec".into(),
20//!     category: "os-tool".into(),
21//!     description: "Execute a shell command in a workspace".into(),
22//!     skill_path: None,
23//!     command: None,
24//! };
25//! retriever.index_tool(tool).await;
26//!
27//! let query_embedding = retriever.embedder().embed("run a bash command").await?;
28//! let results = retriever.retrieve(&query_embedding, 5);
29//! ```
30
31use std::sync::Arc;
32
33use serde::{Deserialize, Serialize};
34
35use crate::embedding::{EmbeddingProvider, EmbeddingVector};
36
37// ---------------------------------------------------------------------------
38// Types
39// ---------------------------------------------------------------------------
40
/// A searchable entry in the tool index.
///
/// Each entry describes a single capability that the agent OS exposes,
/// such as a built-in execution tool, an installed program, an OS service,
/// or an MCP bridge.
///
/// `name`, `category`, `description`, and (when present) `command` are the
/// fields that feed the text embedded for semantic search; `skill_path` is
/// only carried through to the formatted capability index.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolEntry {
    /// Capability name (e.g. `"exec"`, `"git-helper"`, `"mcp:github"`).
    ///
    /// Intended to be unique, but nothing in this module enforces that.
    pub name: String,
    /// Category of the capability.
    ///
    /// Expected to be one of `"os-tool"`, `"program"`, `"os-service"`,
    /// `"mcp"`.  Stored as a free-form string; it is not validated here.
    pub category: String,
    /// Human-readable description used both for indexing and for the
    /// capability index presented to agents.
    pub description: String,
    /// Path to the SKILL.md instruction file, if this is a program.
    pub skill_path: Option<String>,
    /// Invocation command, if this is a program that can be called directly.
    pub command: Option<String>,
}
62
63impl ToolEntry {
64    /// Produce the text that will be embedded for semantic search.
65    ///
66    /// Combines name, category, and description into a single string so that
67    /// the embedding captures all relevant semantics.
68    fn embedding_text(&self) -> String {
69        let mut parts = format!("[{}] {}: {}", self.category, self.name, self.description);
70        if let Some(ref cmd) = self.command {
71            parts.push_str(&format!(" command: {}", cmd));
72        }
73        parts
74    }
75}
76
/// A tool entry together with its pre-computed embedding vector.
///
/// Internal storage unit of the retriever's index: the vector is computed
/// once when the tool is indexed and reused for every query.
#[derive(Debug, Clone)]
struct IndexedTool {
    /// The tool metadata as supplied by the caller.
    entry: ToolEntry,
    /// Embedding of the entry's embedding text, produced at indexing time.
    vector: EmbeddingVector,
}
83
/// A tool ranked by relevance to a query.
///
/// Holds an owned copy of the matching [`ToolEntry`] so results can outlive
/// the retriever's index.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoredTool {
    /// The tool entry that matched.
    pub entry: ToolEntry,
    /// Cosine similarity between the query and the tool's embedding, as
    /// returned by `EmbeddingVector::cosine_similarity`.  Higher is more
    /// relevant.
    ///
    /// NOTE(review): previously documented as lying in `[0.0, 1.0]`; that
    /// range depends on the embedding provider (general dense vectors can
    /// yield negative cosine values) — confirm against the embedding module.
    pub score: f64,
}
92
93// ---------------------------------------------------------------------------
94// ToolRetriever
95// ---------------------------------------------------------------------------
96
/// Semantic search engine for OS capabilities.
///
/// Maintains an in-memory vector index of all registered tools and supports
/// top-K retrieval via cosine similarity against a query embedding.
///
/// The index is a plain `Vec` scanned linearly on every query; entries are
/// kept in the order they were indexed.
pub struct ToolRetriever {
    /// The indexed tools with their pre-computed embeddings.
    index: Vec<IndexedTool>,
    /// The embedding provider used to vectorize tool descriptions.
    /// Shared via `Arc` so callers can reuse it for query embeddings.
    embedder: Arc<dyn EmbeddingProvider>,
}
107
108impl ToolRetriever {
109    /// Create a new, empty retriever backed by the given embedder.
110    pub fn new(embedder: Arc<dyn EmbeddingProvider>) -> Self {
111        Self {
112            index: Vec::new(),
113            embedder,
114        }
115    }
116
117    /// Return a reference to the underlying embedder.
118    ///
119    /// Useful when the caller needs to compute a query embedding before
120    /// calling [`retrieve`](Self::retrieve).
121    pub fn embedder(&self) -> &Arc<dyn EmbeddingProvider> {
122        &self.embedder
123    }
124
125    /// Add a tool to the index.
126    ///
127    /// The tool's description is embedded immediately using the configured
128    /// provider.  If the embedding fails the tool is silently skipped
129    /// (logged at warn level in future telemetry).
130    pub async fn index_tool(&mut self, entry: ToolEntry) {
131        let text = entry.embedding_text();
132        match self.embedder.embed(&text).await {
133            Ok(vector) => {
134                self.index.push(IndexedTool { entry, vector });
135            }
136            Err(e) => {
137                tracing::warn!(name = %entry.name, error = %e, "failed to embed tool, skipping");
138            }
139        }
140    }
141
142    /// Retrieve the top-K tools most relevant to the given query embedding.
143    ///
144    /// The `query_embedding` is compared against every indexed tool using
145    /// [`EmbeddingVector::cosine_similarity`].  Results are sorted by score
146    /// descending.
147    ///
148    /// If `top_k` exceeds the number of indexed tools, all tools are returned.
149    pub fn retrieve(&self, query_embedding: &EmbeddingVector, top_k: usize) -> Vec<ScoredTool> {
150        let mut scored: Vec<ScoredTool> = self
151            .index
152            .iter()
153            .map(|indexed| {
154                let score = query_embedding.cosine_similarity(&indexed.vector);
155                ScoredTool {
156                    entry: indexed.entry.clone(),
157                    score,
158                }
159            })
160            .collect();
161
162        // Sort descending by score.
163        scored.sort_by(|a, b| {
164            b.score
165                .partial_cmp(&a.score)
166                .unwrap_or(std::cmp::Ordering::Equal)
167        });
168
169        scored.truncate(top_k);
170        scored
171    }
172
173    /// Number of indexed tools.
174    pub fn len(&self) -> usize {
175        self.index.len()
176    }
177
178    /// Returns `true` if no tools have been indexed.
179    pub fn is_empty(&self) -> bool {
180        self.index.is_empty()
181    }
182
183    /// Get all indexed entries (for capability index generation or debugging).
184    pub fn entries(&self) -> Vec<&ToolEntry> {
185        self.index.iter().map(|i| &i.entry).collect()
186    }
187
188    /// Remove all indexed tools.
189    pub fn clear(&mut self) {
190        self.index.clear();
191    }
192}
193
194// ---------------------------------------------------------------------------
195// Capability index formatting
196// ---------------------------------------------------------------------------
197
198/// Format retrieved tools as an XML capability index suitable for injection
199/// into an agent's system prompt.
200///
201/// Example output:
202///
203/// ```xml
204/// <available_capabilities>
205///   <capability>
206///     <name>exec</name>
207///     <category>os-tool</category>
208///     <description>Execute a shell command in a workspace</description>
209///   </capability>
210///   <capability>
211///     <name>git-helper</name>
212///     <category>program</category>
213///     <description>Git workflow automation</description>
214///     <command>git-helper</command>
215///     <skill>programs/git-helper/SKILL.md</skill>
216///   </capability>
217/// </available_capabilities>
218/// ```
219pub fn format_capability_index(tools: &[ScoredTool]) -> String {
220    let mut xml = String::from("<available_capabilities>\n");
221
222    for tool in tools {
223        xml.push_str("  <capability>\n");
224        xml.push_str(&format!(
225            "    <name>{}</name>\n",
226            escape_xml(&tool.entry.name)
227        ));
228        xml.push_str(&format!(
229            "    <category>{}</category>\n",
230            escape_xml(&tool.entry.category)
231        ));
232        xml.push_str(&format!(
233            "    <description>{}</description>\n",
234            escape_xml(&tool.entry.description)
235        ));
236        if let Some(ref cmd) = tool.entry.command {
237            xml.push_str(&format!("    <command>{}</command>\n", escape_xml(cmd)));
238        }
239        if let Some(ref skill) = tool.entry.skill_path {
240            xml.push_str(&format!("    <skill>{}</skill>\n", escape_xml(skill)));
241        }
242        xml.push_str("  </capability>\n");
243    }
244
245    xml.push_str("</available_capabilities>");
246    xml
247}
248
/// Replace the five XML-reserved characters with their predefined entities.
///
/// `&`, `<`, `>`, `"` and `'` become `&amp;`, `&lt;`, `&gt;`, `&quot;` and
/// `&apos;`; every other character is copied through unchanged.  Input that
/// already contains entities is escaped again (`&amp;` -> `&amp;amp;`).
fn escape_xml(s: &str) -> String {
    s.chars()
        .fold(String::with_capacity(s.len()), |mut escaped, ch| {
            let entity = match ch {
                '&' => "&amp;",
                '<' => "&lt;",
                '>' => "&gt;",
                '"' => "&quot;",
                '\'' => "&apos;",
                other => {
                    // Not reserved: copy as-is and move on.
                    escaped.push(other);
                    return escaped;
                }
            };
            escaped.push_str(entity);
            escaped
        })
}
264
265// ---------------------------------------------------------------------------
266// Kernel manifest
267// ---------------------------------------------------------------------------
268
/// Well-known domain names that can appear in a kernel manifest.
///
/// Acts as the allow-list for `build_kernel_manifest`: names not listed here
/// are dropped from the manifest.  Each name also has a matching arm in
/// `domain_description`; keep the two in sync when adding a domain.
const KNOWN_DOMAINS: &[&str] = &[
    "space", "agent", "a2a", "memory", "security", "budget", "resource", "program",
];
273
274/// Build a markdown kernel manifest from the set of active domains.
275///
276/// The manifest lists which subsystems are currently enabled so that an
277/// agent can discover the OS's capabilities at a glance.
278///
279/// # Example
280///
281/// ```text
282/// ## Kernel Manifest
283///
284/// Active domains: space, agent, memory, program
285///
286/// ### space
287/// Filesystem workspace management and conversation buffers.
288///
289/// ### agent
290/// ...
291/// ```
292pub fn build_kernel_manifest(active_domains: &[&str]) -> String {
293    let mut md = String::from("## Kernel Manifest\n\n");
294
295    let domain_list: Vec<&str> = active_domains
296        .iter()
297        .filter(|d| KNOWN_DOMAINS.contains(d))
298        .copied()
299        .collect();
300
301    md.push_str(&format!("Active domains: {}\n\n", domain_list.join(", ")));
302
303    for domain in &domain_list {
304        let description = domain_description(domain);
305        md.push_str(&format!("### {}\n{}\n\n", domain, description));
306    }
307
308    md
309}
310
/// Look up the one-line, human-readable summary for a kernel domain.
///
/// Matching is exact and case-sensitive.  Any name without an entry in the
/// table yields `"Unknown domain."`.
fn domain_description(domain: &str) -> &'static str {
    // (domain name, description) pairs for every documented domain.
    const DESCRIPTIONS: [(&str, &str); 8] = [
        (
            "space",
            "Filesystem workspace management and conversation buffers.",
        ),
        ("agent", "Agent lifecycle, runtime, and supervisor."),
        ("a2a", "Agent-to-agent communication and delegation."),
        ("memory", "Persistent vector memory and semantic search."),
        ("security", "RBAC access control and audit trail."),
        ("budget", "Token and cost budget enforcement."),
        (
            "resource",
            "System resource monitoring and overload protection.",
        ),
        ("program", "Installable OS-level programs and tools."),
    ];

    DESCRIPTIONS
        .iter()
        .find(|pair| pair.0 == domain)
        .map_or("Unknown domain.", |pair| pair.1)
}
325
326// ---------------------------------------------------------------------------
327// Tests
328// ---------------------------------------------------------------------------
329
#[cfg(test)]
mod tests {
    use super::*;

    /// A trivial embedder for unit-testing retrieval logic without
    /// depending on TF-IDF or external models.
    ///
    /// Empty text maps to an empty dense vector; any other text maps to a
    /// three-component dense vector whose middle component is derived from
    /// the text's byte length.
    struct MockEmbedder;

    #[async_trait::async_trait]
    impl EmbeddingProvider for MockEmbedder {
        async fn embed(&self, text: &str) -> anyhow::Result<EmbeddingVector> {
            if text.is_empty() {
                return Ok(EmbeddingVector::DenseF32(vec![]));
            }
            // Deterministic vector keyed on byte length only: texts of
            // different lengths get different vectors, texts of equal length
            // get the same one.
            let len = text.len() as f32;
            Ok(EmbeddingVector::DenseF32(vec![1.0, len / 100.0, 0.5]))
        }

        fn name(&self) -> &str {
            "mock"
        }
    }

    /// Build a `ToolEntry` with no skill path and no command.
    fn mock_entry(name: &str, category: &str, desc: &str) -> ToolEntry {
        ToolEntry {
            name: name.to_string(),
            category: category.to_string(),
            description: desc.to_string(),
            skill_path: None,
            command: None,
        }
    }

    /// `len` / `is_empty` track the number of successfully indexed tools.
    #[tokio::test]
    async fn test_index_and_len() {
        let embedder = Arc::new(MockEmbedder);
        let mut retriever = ToolRetriever::new(embedder);

        assert!(retriever.is_empty());
        assert_eq!(retriever.len(), 0);

        retriever
            .index_tool(mock_entry("exec", "os-tool", "Run commands"))
            .await;
        retriever
            .index_tool(mock_entry("git", "program", "Git operations"))
            .await;

        assert_eq!(retriever.len(), 2);
        assert!(!retriever.is_empty());
    }

    /// `retrieve` caps the result at `top_k` and orders it by score.
    #[tokio::test]
    async fn test_retrieve_top_k() {
        let embedder = Arc::new(MockEmbedder);
        let mut retriever = ToolRetriever::new(embedder);

        retriever
            .index_tool(mock_entry("exec", "os-tool", "Run shell commands"))
            .await;
        retriever
            .index_tool(mock_entry("git", "program", "Git version control"))
            .await;
        retriever
            .index_tool(mock_entry("mcp-github", "mcp", "GitHub API bridge"))
            .await;

        let query = EmbeddingVector::DenseF32(vec![1.0, 0.5, 0.5]);
        let results = retriever.retrieve(&query, 2);

        assert_eq!(results.len(), 2);
        // Results should be sorted by score descending.
        assert!(results[0].score >= results[1].score);
    }

    /// A `top_k` larger than the index returns everything that is indexed.
    #[tokio::test]
    async fn test_retrieve_exceeds_index() {
        let embedder = Arc::new(MockEmbedder);
        let mut retriever = ToolRetriever::new(embedder);

        retriever
            .index_tool(mock_entry("exec", "os-tool", "Run commands"))
            .await;

        let query = EmbeddingVector::DenseF32(vec![1.0, 0.5, 0.5]);
        let results = retriever.retrieve(&query, 10);

        // Should return all available tools, not panic.
        assert_eq!(results.len(), 1);
    }

    /// Querying an empty index yields an empty result.
    #[tokio::test]
    async fn test_retrieve_empty_index() {
        let embedder = Arc::new(MockEmbedder);
        let retriever = ToolRetriever::new(embedder);

        let query = EmbeddingVector::DenseF32(vec![1.0, 0.5, 0.5]);
        let results = retriever.retrieve(&query, 5);

        assert!(results.is_empty());
    }

    /// `entries` exposes the indexed tools in insertion order.
    #[tokio::test]
    async fn test_entries() {
        let embedder = Arc::new(MockEmbedder);
        let mut retriever = ToolRetriever::new(embedder);

        retriever
            .index_tool(mock_entry("exec", "os-tool", "Run commands"))
            .await;
        retriever
            .index_tool(mock_entry("git", "program", "Git ops"))
            .await;

        let entries = retriever.entries();
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0].name, "exec");
        assert_eq!(entries[1].name, "git");
    }

    /// `clear` empties the index.
    #[tokio::test]
    async fn test_clear() {
        let embedder = Arc::new(MockEmbedder);
        let mut retriever = ToolRetriever::new(embedder);

        retriever
            .index_tool(mock_entry("exec", "os-tool", "Run commands"))
            .await;
        assert_eq!(retriever.len(), 1);

        retriever.clear();
        assert!(retriever.is_empty());
    }

    /// An entry without command/skill renders only the three base elements.
    #[test]
    fn test_format_capability_index_basic() {
        let tool = ScoredTool {
            entry: ToolEntry {
                name: "exec".into(),
                category: "os-tool".into(),
                description: "Execute shell commands".into(),
                skill_path: None,
                command: None,
            },
            score: 0.95,
        };

        let xml = format_capability_index(&[tool]);
        assert!(xml.contains("<available_capabilities>"));
        assert!(xml.contains("<name>exec</name>"));
        assert!(xml.contains("<category>os-tool</category>"));
        assert!(xml.contains("<description>Execute shell commands</description>"));
        assert!(xml.contains("</available_capabilities>"));
        // No command/skill tags when the entry has neither field set.
        assert!(!xml.contains("<command>"));
        assert!(!xml.contains("<skill>"));
    }

    /// An entry with command and skill path renders both optional elements.
    #[test]
    fn test_format_capability_index_program() {
        let tool = ScoredTool {
            entry: ToolEntry {
                name: "git-helper".into(),
                category: "program".into(),
                description: "Git workflow automation".into(),
                skill_path: Some("programs/git-helper/SKILL.md".into()),
                command: Some("git-helper".into()),
            },
            score: 0.88,
        };

        let xml = format_capability_index(&[tool]);
        assert!(xml.contains("<command>git-helper</command>"));
        assert!(xml.contains("<skill>programs/git-helper/SKILL.md</skill>"));
    }

    /// Reserved XML characters in entry fields are escaped in the output.
    #[test]
    fn test_format_capability_index_xml_escaping() {
        let tool = ScoredTool {
            entry: ToolEntry {
                name: "test<>&".into(),
                category: "os-tool".into(),
                description: "A & B < C > D".into(),
                skill_path: None,
                command: None,
            },
            score: 1.0,
        };

        let xml = format_capability_index(&[tool]);
        assert!(xml.contains("<name>test&lt;&gt;&amp;</name>"));
        assert!(xml.contains("<description>A &amp; B &lt; C &gt; D</description>"));
    }

    /// `escape_xml` passes plain text through and escapes all five
    /// reserved characters.
    #[test]
    fn test_escape_xml() {
        assert_eq!(escape_xml("hello"), "hello");
        assert_eq!(
            escape_xml("a&b<c>d\"e'f"),
            "a&amp;b&lt;c&gt;d&quot;e&apos;f"
        );
    }

    /// The manifest lists exactly the active domains, each with a section.
    #[test]
    fn test_build_kernel_manifest() {
        let md = build_kernel_manifest(&["space", "agent", "memory", "program"]);
        assert!(md.contains("## Kernel Manifest"));
        assert!(md.contains("Active domains: space, agent, memory, program"));
        assert!(md.contains("### space"));
        assert!(md.contains("### agent"));
        assert!(md.contains("### memory"));
        assert!(md.contains("### program"));
        assert!(!md.contains("### security"));
    }

    /// Domain names outside the known set never reach the manifest.
    #[test]
    fn test_build_kernel_manifest_filters_unknown() {
        let md = build_kernel_manifest(&["space", "unknown-domain"]);
        assert!(md.contains("### space"));
        assert!(!md.contains("unknown-domain"));
    }

    /// With no active domains the heading and list label are still present.
    #[test]
    fn test_build_kernel_manifest_empty() {
        let md = build_kernel_manifest(&[]);
        assert!(md.contains("## Kernel Manifest"));
        assert!(md.contains("Active domains:"));
    }

    /// Embedding text starts with `[category] name: description`.
    #[test]
    fn test_tool_entry_embedding_text() {
        let entry = mock_entry("exec", "os-tool", "Run commands");
        let text = entry.embedding_text();
        assert!(text.contains("[os-tool] exec: Run commands"));
    }

    /// Embedding text includes the command when one is set.
    #[test]
    fn test_tool_entry_embedding_text_with_command() {
        let entry = ToolEntry {
            name: "git".into(),
            category: "program".into(),
            description: "Git ops".into(),
            skill_path: None,
            command: Some("git binary".into()),
        };
        let text = entry.embedding_text();
        assert!(text.contains("command: git binary"));
    }

    /// `embedder()` hands back the provider the retriever was built with.
    #[tokio::test]
    async fn test_embedder_accessor() {
        let embedder = Arc::new(MockEmbedder);
        let retriever = ToolRetriever::new(embedder);
        assert_eq!(retriever.embedder().name(), "mock");
    }

    // --- Integration-style test with TfIdfEmbeddingProvider ---

    /// End-to-end check with the crate's TF-IDF provider: a shell-related
    /// query is expected to rank the `exec` tool first.
    #[tokio::test]
    async fn test_with_tfidf_embedder() {
        use crate::embedding::TfIdfEmbeddingProvider;

        let embedder = Arc::new(TfIdfEmbeddingProvider);
        let mut retriever = ToolRetriever::new(embedder);

        retriever
            .index_tool(ToolEntry {
                name: "exec".into(),
                category: "os-tool".into(),
                description: "Execute shell commands in workspace".into(),
                skill_path: None,
                command: None,
            })
            .await;
        retriever
            .index_tool(ToolEntry {
                name: "memory-search".into(),
                category: "os-tool".into(),
                description: "Search persistent vector memory".into(),
                skill_path: None,
                command: None,
            })
            .await;

        let query_embedding = retriever
            .embedder()
            .embed("run a bash command")
            .await
            .unwrap();
        let results = retriever.retrieve(&query_embedding, 2);

        assert_eq!(results.len(), 2);
        // "exec" should score higher for "run a bash command" query.
        assert_eq!(results[0].entry.name, "exec");
    }
}