Skip to main content

oxios_kernel/tools/
retrieval.rs

//! Semantic search engine for OS capabilities.
//!
//! [`ToolRetriever`] maintains an in-memory index of all available tools
//! (built-in OS tools, installed programs, OS services, and MCP bridges)
//! and retrieves the most relevant ones for a given query using the
//! embedding module's cosine similarity.
//!
//! # Usage
//!
//! ```no_run
//! use std::sync::Arc;
//! use oxios_kernel::embedding::TfIdfEmbeddingProvider;
//! use oxios_kernel::tools::retrieval::{ToolRetriever, ToolEntry};
//!
//! # async fn example() {
//! let embedder = Arc::new(TfIdfEmbeddingProvider);
//! let mut retriever = ToolRetriever::new(embedder);
//!
//! let tool = ToolEntry {
//!     name: "exec".into(),
//!     category: "os-tool".into(),
//!     description: "Execute a shell command in a workspace".into(),
//!     skill_path: None,
//!     command: None,
//! };
//! retriever.index_tool(tool).await;
//! # }
//! ```
28
29use std::sync::Arc;
30
31use serde::{Deserialize, Serialize};
32
33use crate::embedding::{EmbeddingProvider, EmbeddingVector};
34
35// ---------------------------------------------------------------------------
36// Types
37// ---------------------------------------------------------------------------
38
39/// A searchable entry in the tool index.
40///
41/// Each entry describes a single capability that the agent OS exposes,
42/// such as a built-in execution tool, an installed program, an OS service,
43/// or an MCP bridge.
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct ToolEntry {
46    /// Unique capability name (e.g. `"exec"`, `"git-helper"`, `"mcp:github"`).
47    pub name: String,
48    /// Category of the capability.
49    ///
50    /// One of: `"os-tool"`, `"program"`, `"os-service"`, `"mcp"`.
51    pub category: String,
52    /// Human-readable description used both for indexing and for the
53    /// capability index presented to agents.
54    pub description: String,
55    /// Path to the SKILL.md instruction file, if this is a program.
56    pub skill_path: Option<String>,
57    /// Invocation command, if this is a program that can be called directly.
58    pub command: Option<String>,
59}
60
61impl ToolEntry {
62    /// Produce the text that will be embedded for semantic search.
63    ///
64    /// Combines name, category, and description into a single string so that
65    /// the embedding captures all relevant semantics.
66    fn embedding_text(&self) -> String {
67        let mut parts = format!("[{}] {}: {}", self.category, self.name, self.description);
68        if let Some(ref cmd) = self.command {
69            parts.push_str(&format!(" command: {cmd}"));
70        }
71        parts
72    }
73}
74
75/// A tool entry together with its pre-computed embedding vector.
76#[derive(Debug, Clone)]
77struct IndexedTool {
78    entry: ToolEntry,
79    vector: EmbeddingVector,
80}
81
82/// A tool ranked by relevance to a query.
83#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct ScoredTool {
85    /// The tool entry that matched.
86    pub entry: ToolEntry,
87    /// Cosine similarity score in `[0.0, 1.0]`.  Higher is more relevant.
88    pub score: f64,
89}
90
91// ---------------------------------------------------------------------------
92// ToolRetriever
93// ---------------------------------------------------------------------------
94
95/// Semantic search engine for OS capabilities.
96///
97/// Maintains an in-memory vector index of all registered tools and supports
98/// top-K retrieval via cosine similarity against a query embedding.
99pub struct ToolRetriever {
100    /// The indexed tools with their pre-computed embeddings.
101    index: Vec<IndexedTool>,
102    /// The embedding provider used to vectorize tool descriptions.
103    embedder: Arc<dyn EmbeddingProvider>,
104}
105
106impl ToolRetriever {
107    /// Create a new, empty retriever backed by the given embedder.
108    pub fn new(embedder: Arc<dyn EmbeddingProvider>) -> Self {
109        Self {
110            index: Vec::new(),
111            embedder,
112        }
113    }
114
115    /// Return a reference to the underlying embedder.
116    ///
117    /// Useful when the caller needs to compute a query embedding before
118    /// calling [`retrieve`](Self::retrieve).
119    pub fn embedder(&self) -> &Arc<dyn EmbeddingProvider> {
120        &self.embedder
121    }
122
123    /// Add a tool to the index.
124    ///
125    /// The tool's description is embedded immediately using the configured
126    /// provider.  If the embedding fails the tool is silently skipped
127    /// (logged at warn level in future telemetry).
128    pub async fn index_tool(&mut self, entry: ToolEntry) {
129        let text = entry.embedding_text();
130        match self.embedder.embed(&text).await {
131            Ok(vector) => {
132                self.index.push(IndexedTool { entry, vector });
133            }
134            Err(e) => {
135                tracing::warn!(name = %entry.name, error = %e, "failed to embed tool, skipping");
136            }
137        }
138    }
139
140    /// Retrieve the top-K tools most relevant to the given query embedding.
141    ///
142    /// The `query_embedding` is compared against every indexed tool using
143    /// [`EmbeddingVector::cosine_similarity`].  Results are sorted by score
144    /// descending.
145    ///
146    /// If `top_k` exceeds the number of indexed tools, all tools are returned.
147    pub fn retrieve(&self, query_embedding: &EmbeddingVector, top_k: usize) -> Vec<ScoredTool> {
148        let mut scored: Vec<ScoredTool> = self
149            .index
150            .iter()
151            .map(|indexed| {
152                let score = query_embedding.cosine_similarity(&indexed.vector);
153                ScoredTool {
154                    entry: indexed.entry.clone(),
155                    score,
156                }
157            })
158            .collect();
159
160        // Sort descending by score.
161        scored.sort_by(|a, b| {
162            b.score
163                .partial_cmp(&a.score)
164                .unwrap_or(std::cmp::Ordering::Equal)
165        });
166
167        scored.truncate(top_k);
168        scored
169    }
170
171    /// Number of indexed tools.
172    pub fn len(&self) -> usize {
173        self.index.len()
174    }
175
176    /// Returns `true` if no tools have been indexed.
177    pub fn is_empty(&self) -> bool {
178        self.index.is_empty()
179    }
180
181    /// Get all indexed entries (for capability index generation or debugging).
182    pub fn entries(&self) -> Vec<&ToolEntry> {
183        self.index.iter().map(|i| &i.entry).collect()
184    }
185
186    /// Remove all indexed tools.
187    pub fn clear(&mut self) {
188        self.index.clear();
189    }
190}
191
192// ---------------------------------------------------------------------------
193// Capability index formatting
194// ---------------------------------------------------------------------------
195
196/// Format retrieved tools as an XML capability index suitable for injection
197/// into an agent's system prompt.
198///
199/// Example output:
200///
201/// ```xml
202/// <available_capabilities>
203///   <capability>
204///     <name>exec</name>
205///     <category>os-tool</category>
206///     <description>Execute a shell command in a workspace</description>
207///   </capability>
208///   <capability>
209///     <name>git-helper</name>
210///     <category>program</category>
211///     <description>Git workflow automation</description>
212///     <command>git-helper</command>
213///     <skill>programs/git-helper/SKILL.md</skill>
214///   </capability>
215/// </available_capabilities>
216/// ```
217pub fn format_capability_index(tools: &[ScoredTool]) -> String {
218    let mut xml = String::from("<available_capabilities>\n");
219
220    for tool in tools {
221        xml.push_str("  <capability>\n");
222        xml.push_str(&format!(
223            "    <name>{}</name>\n",
224            escape_xml(&tool.entry.name)
225        ));
226        xml.push_str(&format!(
227            "    <category>{}</category>\n",
228            escape_xml(&tool.entry.category)
229        ));
230        xml.push_str(&format!(
231            "    <description>{}</description>\n",
232            escape_xml(&tool.entry.description)
233        ));
234        if let Some(ref cmd) = tool.entry.command {
235            xml.push_str(&format!("    <command>{}</command>\n", escape_xml(cmd)));
236        }
237        if let Some(ref skill) = tool.entry.skill_path {
238            xml.push_str(&format!("    <skill>{}</skill>\n", escape_xml(skill)));
239        }
240        xml.push_str("  </capability>\n");
241    }
242
243    xml.push_str("</available_capabilities>");
244    xml
245}
246
/// Replace the five XML-reserved characters (`&`, `<`, `>`, `"`, `'`) with
/// their predefined entity references. All other characters are copied
/// through unchanged.
fn escape_xml(s: &str) -> String {
    s.chars()
        .fold(String::with_capacity(s.len()), |mut escaped, ch| {
            match ch {
                '&' => escaped.push_str("&amp;"),
                '<' => escaped.push_str("&lt;"),
                '>' => escaped.push_str("&gt;"),
                '"' => escaped.push_str("&quot;"),
                '\'' => escaped.push_str("&apos;"),
                other => escaped.push(other),
            }
            escaped
        })
}
262
263// ---------------------------------------------------------------------------
264// Kernel manifest
265// ---------------------------------------------------------------------------
266
/// Well-known domain names that can appear in a kernel manifest.
///
/// Only names in this list are rendered by [`build_kernel_manifest`]. Each
/// one should also have a matching arm in [`domain_description`].
const KNOWN_DOMAINS: &[&str] = &[
    "space", "agent", "a2a", "memory", "security", "budget", "resource", "program",
];

/// Build a markdown kernel manifest from the set of active domains.
///
/// The manifest lists which subsystems are currently enabled so that an
/// agent can discover the OS's capabilities at a glance.
///
/// Names that are not well-known domains are dropped. A name that appears
/// more than once is listed only once, at its first position. The output
/// therefore follows the caller's order without duplicated sections.
///
/// # Example
///
/// ```text
/// ## Kernel Manifest
///
/// Active domains: space, agent, memory, program
///
/// ### space
/// Filesystem workspace management and conversation buffers.
///
/// ### agent
/// ...
/// ```
pub fn build_kernel_manifest(active_domains: &[&str]) -> String {
    let mut md = String::from("## Kernel Manifest\n\n");

    // Keep known domains in caller order and skip repeats. Without this, a
    // duplicated input would produce duplicated list items and sections.
    // The known-domain list is tiny, so linear `contains` checks are fine.
    let mut domain_list: Vec<&str> = Vec::with_capacity(active_domains.len());
    for &domain in active_domains {
        if KNOWN_DOMAINS.contains(&domain) && !domain_list.contains(&domain) {
            domain_list.push(domain);
        }
    }

    md.push_str(&format!("Active domains: {}\n\n", domain_list.join(", ")));

    for domain in &domain_list {
        let description = domain_description(domain);
        md.push_str(&format!("### {domain}\n{description}\n\n"));
    }

    md
}

/// Return a short human-readable description for a known domain.
///
/// Falls back to `"Unknown domain."` for names outside the known list.
fn domain_description(domain: &str) -> &'static str {
    match domain {
        "space" => "Filesystem workspace management and conversation buffers.",
        "agent" => "Agent lifecycle, runtime, and supervisor.",
        "a2a" => "Agent-to-agent communication and delegation.",
        "memory" => "Persistent vector memory and semantic search.",
        "security" => "RBAC access control and audit trail.",
        "budget" => "Token and cost budget enforcement.",
        "resource" => "System resource monitoring and overload protection.",
        "program" => "Installable OS-level programs and tools.",
        _ => "Unknown domain.",
    }
}
323
324// ---------------------------------------------------------------------------
325// Tests
326// ---------------------------------------------------------------------------
327
#[cfg(test)]
mod tests {
    use super::*;

    /// A trivial embedder that maps every non-empty text to a deterministic
    /// three-element dense vector derived from the text's byte length, and
    /// empty text to an empty vector. Useful for unit-testing retrieval logic
    /// without depending on TF-IDF or external models.
    struct MockEmbedder;

    #[async_trait::async_trait]
    impl EmbeddingProvider for MockEmbedder {
        async fn embed(&self, text: &str) -> anyhow::Result<EmbeddingVector> {
            if text.is_empty() {
                return Ok(EmbeddingVector::DenseF32(vec![]));
            }
            // Produce a deterministic vector based on text length so different
            // texts get different vectors.
            let len = text.len() as f32;
            Ok(EmbeddingVector::DenseF32(vec![1.0, len / 100.0, 0.5]))
        }

        fn name(&self) -> &str {
            "mock"
        }
    }

    /// Build a `ToolEntry` with no command and no skill path.
    fn mock_entry(name: &str, category: &str, desc: &str) -> ToolEntry {
        ToolEntry {
            name: name.to_string(),
            category: category.to_string(),
            description: desc.to_string(),
            skill_path: None,
            command: None,
        }
    }

    #[tokio::test]
    async fn test_index_and_len() {
        let embedder = Arc::new(MockEmbedder);
        let mut retriever = ToolRetriever::new(embedder);

        assert!(retriever.is_empty());
        assert_eq!(retriever.len(), 0);

        retriever
            .index_tool(mock_entry("exec", "os-tool", "Run commands"))
            .await;
        retriever
            .index_tool(mock_entry("git", "program", "Git operations"))
            .await;

        assert_eq!(retriever.len(), 2);
        assert!(!retriever.is_empty());
    }

    #[tokio::test]
    async fn test_retrieve_top_k() {
        let embedder = Arc::new(MockEmbedder);
        let mut retriever = ToolRetriever::new(embedder);

        retriever
            .index_tool(mock_entry("exec", "os-tool", "Run shell commands"))
            .await;
        retriever
            .index_tool(mock_entry("git", "program", "Git version control"))
            .await;
        retriever
            .index_tool(mock_entry("mcp-github", "mcp", "GitHub API bridge"))
            .await;

        let query = EmbeddingVector::DenseF32(vec![1.0, 0.5, 0.5]);
        let results = retriever.retrieve(&query, 2);

        assert_eq!(results.len(), 2);
        // Results should be sorted by score descending.
        assert!(results[0].score >= results[1].score);
    }

    #[tokio::test]
    async fn test_retrieve_exceeds_index() {
        let embedder = Arc::new(MockEmbedder);
        let mut retriever = ToolRetriever::new(embedder);

        retriever
            .index_tool(mock_entry("exec", "os-tool", "Run commands"))
            .await;

        let query = EmbeddingVector::DenseF32(vec![1.0, 0.5, 0.5]);
        let results = retriever.retrieve(&query, 10);

        // Should return all available tools, not panic.
        assert_eq!(results.len(), 1);
    }

    #[tokio::test]
    async fn test_retrieve_zero_k() {
        let embedder = Arc::new(MockEmbedder);
        let mut retriever = ToolRetriever::new(embedder);

        retriever
            .index_tool(mock_entry("exec", "os-tool", "Run commands"))
            .await;

        let query = EmbeddingVector::DenseF32(vec![1.0, 0.5, 0.5]);
        // Asking for zero results yields nothing, even with a non-empty index.
        assert!(retriever.retrieve(&query, 0).is_empty());
    }

    #[tokio::test]
    async fn test_retrieve_empty_index() {
        let embedder = Arc::new(MockEmbedder);
        let retriever = ToolRetriever::new(embedder);

        let query = EmbeddingVector::DenseF32(vec![1.0, 0.5, 0.5]);
        let results = retriever.retrieve(&query, 5);

        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_entries() {
        let embedder = Arc::new(MockEmbedder);
        let mut retriever = ToolRetriever::new(embedder);

        retriever
            .index_tool(mock_entry("exec", "os-tool", "Run commands"))
            .await;
        retriever
            .index_tool(mock_entry("git", "program", "Git ops"))
            .await;

        let entries = retriever.entries();
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0].name, "exec");
        assert_eq!(entries[1].name, "git");
    }

    #[tokio::test]
    async fn test_clear() {
        let embedder = Arc::new(MockEmbedder);
        let mut retriever = ToolRetriever::new(embedder);

        retriever
            .index_tool(mock_entry("exec", "os-tool", "Run commands"))
            .await;
        assert_eq!(retriever.len(), 1);

        retriever.clear();
        assert!(retriever.is_empty());
    }

    #[test]
    fn test_format_capability_index_basic() {
        let tool = ScoredTool {
            entry: ToolEntry {
                name: "exec".into(),
                category: "os-tool".into(),
                description: "Execute shell commands".into(),
                skill_path: None,
                command: None,
            },
            score: 0.95,
        };

        let xml = format_capability_index(&[tool]);
        assert!(xml.contains("<available_capabilities>"));
        assert!(xml.contains("<name>exec</name>"));
        assert!(xml.contains("<category>os-tool</category>"));
        assert!(xml.contains("<description>Execute shell commands</description>"));
        assert!(xml.contains("</available_capabilities>"));
        // No command/skill tags for os-tool.
        assert!(!xml.contains("<command>"));
        assert!(!xml.contains("<skill>"));
    }

    #[test]
    fn test_format_capability_index_empty() {
        // An empty result set still produces a well-formed (empty) root element.
        assert_eq!(
            format_capability_index(&[]),
            "<available_capabilities>\n</available_capabilities>"
        );
    }

    #[test]
    fn test_format_capability_index_program() {
        let tool = ScoredTool {
            entry: ToolEntry {
                name: "git-helper".into(),
                category: "program".into(),
                description: "Git workflow automation".into(),
                skill_path: Some("programs/git-helper/SKILL.md".into()),
                command: Some("git-helper".into()),
            },
            score: 0.88,
        };

        let xml = format_capability_index(&[tool]);
        assert!(xml.contains("<command>git-helper</command>"));
        assert!(xml.contains("<skill>programs/git-helper/SKILL.md</skill>"));
    }

    #[test]
    fn test_format_capability_index_xml_escaping() {
        let tool = ScoredTool {
            entry: ToolEntry {
                name: "test<>&".into(),
                category: "os-tool".into(),
                description: "A & B < C > D".into(),
                skill_path: None,
                command: None,
            },
            score: 1.0,
        };

        let xml = format_capability_index(&[tool]);
        assert!(xml.contains("<name>test&lt;&gt;&amp;</name>"));
        assert!(xml.contains("<description>A &amp; B &lt; C &gt; D</description>"));
    }

    #[test]
    fn test_escape_xml() {
        assert_eq!(escape_xml("hello"), "hello");
        assert_eq!(
            escape_xml("a&b<c>d\"e'f"),
            "a&amp;b&lt;c&gt;d&quot;e&apos;f"
        );
    }

    #[test]
    fn test_build_kernel_manifest() {
        let md = build_kernel_manifest(&["space", "agent", "memory", "program"]);
        assert!(md.contains("## Kernel Manifest"));
        assert!(md.contains("Active domains: space, agent, memory, program"));
        assert!(md.contains("### space"));
        assert!(md.contains("### agent"));
        assert!(md.contains("### memory"));
        assert!(md.contains("### program"));
        assert!(!md.contains("### security"));
    }

    #[test]
    fn test_build_kernel_manifest_section_format() {
        // Pin the exact layout: header, domain list, then one section per domain.
        let md = build_kernel_manifest(&["space"]);
        assert_eq!(
            md,
            "## Kernel Manifest\n\nActive domains: space\n\n### space\nFilesystem workspace management and conversation buffers.\n\n"
        );
    }

    #[test]
    fn test_build_kernel_manifest_filters_unknown() {
        let md = build_kernel_manifest(&["space", "unknown-domain"]);
        assert!(md.contains("### space"));
        assert!(!md.contains("unknown-domain"));
    }

    #[test]
    fn test_build_kernel_manifest_empty() {
        let md = build_kernel_manifest(&[]);
        assert!(md.contains("## Kernel Manifest"));
        assert!(md.contains("Active domains:"));
    }

    #[test]
    fn test_tool_entry_embedding_text() {
        let entry = mock_entry("exec", "os-tool", "Run commands");
        let text = entry.embedding_text();
        assert!(text.contains("[os-tool] exec: Run commands"));
    }

    #[test]
    fn test_tool_entry_embedding_text_with_command() {
        let entry = ToolEntry {
            name: "git".into(),
            category: "program".into(),
            description: "Git ops".into(),
            skill_path: None,
            command: Some("git binary".into()),
        };
        let text = entry.embedding_text();
        assert!(text.contains("command: git binary"));
    }

    #[tokio::test]
    async fn test_embedder_accessor() {
        let embedder = Arc::new(MockEmbedder);
        let retriever = ToolRetriever::new(embedder);
        assert_eq!(retriever.embedder().name(), "mock");
    }

    // --- Integration-style test with TfIdfEmbeddingProvider ---

    #[tokio::test]
    async fn test_with_tfidf_embedder() {
        use crate::embedding::TfIdfEmbeddingProvider;

        let embedder = Arc::new(TfIdfEmbeddingProvider);
        let mut retriever = ToolRetriever::new(embedder);

        retriever
            .index_tool(ToolEntry {
                name: "exec".into(),
                category: "os-tool".into(),
                description: "Execute shell commands in workspace".into(),
                skill_path: None,
                command: None,
            })
            .await;
        retriever
            .index_tool(ToolEntry {
                name: "memory-search".into(),
                category: "os-tool".into(),
                description: "Search persistent vector memory".into(),
                skill_path: None,
                command: None,
            })
            .await;

        let query_embedding = retriever
            .embedder()
            .embed("run a bash command")
            .await
            .unwrap();
        let results = retriever.retrieve(&query_embedding, 2);

        assert_eq!(results.len(), 2);
        // "exec" should score higher for "run a bash command" query.
        assert_eq!(results[0].entry.name, "exec");
    }
}