roboticus-agent 0.11.3

Agent core with ReAct loop, policy engine, injection defense, memory system, and skill loader
Documentation
//! Semantic tool search — rank and prune tools before presenting to the LLM.
//!
//! Embeds tool descriptions at registration time, ranks against user query
//! at inference time, and prunes to top-K within a token budget. Uses the
//! `RankedCandidate` type from `ranking.rs` for convergence with memory retrieval.

use crate::ranking::{CandidateKind, RankedCandidate, top_k_with_pinned};
use serde::{Deserialize, Serialize};

/// Where a tool came from (simplified for search ranking).
///
/// Within this module the only distinction that matters is `Mcp` versus
/// everything else: `rank_tools` subtracts `SearchConfig::mcp_latency_penalty`
/// from the score of `Mcp` tools and applies no penalty to the other variants.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ToolSourceInfo {
    /// Carries no payload; ranked without any penalty.
    BuiltIn,
    /// Carries a single string — presumably the plugin's name or id (TODO confirm
    /// against the code that registers plugins). Ranked without any penalty.
    Plugin(String),
    /// The variant that receives the latency penalty during ranking.
    /// `server` presumably identifies the originating MCP server (TODO confirm).
    Mcp { server: String },
}

/// A tool descriptor with its cached embedding.
///
/// This is the input record consumed by `rank_tools`; one descriptor yields
/// exactly one ranked candidate.
#[derive(Debug, Clone)]
pub struct ToolDescriptor {
    /// Tool name; copied into the candidate's `source_id` during ranking.
    pub name: String,
    /// Human-readable description. Not read by the ranking code in this file;
    /// presumably the text the embedding was computed from (TODO confirm).
    pub description: String,
    /// Token cost carried over unchanged onto the ranked candidate.
    pub token_cost: usize,
    /// Origin of the tool; decides whether the MCP latency penalty applies.
    pub source: ToolSourceInfo,
    /// Cached embedding vector. `None` makes the tool score `0.0` in
    /// `rank_tools` rather than being dropped.
    pub embedding: Option<Vec<f32>>,
}

/// Tunable knobs for tool search ranking and pruning.
#[derive(Debug, Clone)]
pub struct SearchConfig {
    /// Forwarded to `top_k_with_pinned` as its `k` argument.
    pub top_k: usize,
    /// Forwarded to `top_k_with_pinned` as its token budget argument.
    pub token_budget: usize,
    /// Amount subtracted from the raw similarity of MCP-sourced tools.
    pub mcp_latency_penalty: f64,
    /// Tool names forwarded to `top_k_with_pinned` as the pinned set.
    pub always_include: Vec<String>,
}

impl Default for SearchConfig {
    /// Builds the stock configuration: 15 tools, a 4000-token budget, a 0.05
    /// MCP penalty, and `memory_store` / `delegate` always pinned.
    fn default() -> Self {
        let always_include = vec![String::from("memory_store"), String::from("delegate")];

        SearchConfig {
            top_k: 15,
            token_budget: 4000,
            mcp_latency_penalty: 0.05,
            always_include,
        }
    }
}

/// Score every tool against the query embedding and return them best-first.
///
/// Each tool's raw score is its cosine similarity to `query_embedding`, or
/// `0.0` when the tool has no cached embedding. MCP-sourced tools then have
/// `config.mcp_latency_penalty` subtracted, and the adjusted score is floored
/// at `0.0`. No tool is dropped here: the result has one candidate per input
/// tool, ordered by descending adjusted score. The sort is stable, so tools
/// with equal adjusted scores keep their input order.
pub fn rank_tools(
    tools: &[ToolDescriptor],
    query_embedding: &[f32],
    config: &SearchConfig,
) -> Vec<RankedCandidate> {
    let mut scored = Vec::with_capacity(tools.len());

    for descriptor in tools {
        let similarity = match descriptor.embedding.as_deref() {
            Some(vector) => cosine_similarity(vector, query_embedding),
            None => 0.0,
        };

        // Only MCP tools pay the latency penalty; local sources are free.
        let latency_penalty = match descriptor.source {
            ToolSourceInfo::Mcp { .. } => config.mcp_latency_penalty,
            ToolSourceInfo::BuiltIn | ToolSourceInfo::Plugin(_) => 0.0,
        };

        scored.push(RankedCandidate {
            source_id: descriptor.name.clone(),
            source_kind: CandidateKind::Tool,
            raw_score: similarity,
            adjusted_score: (similarity - latency_penalty).max(0.0),
            token_cost: descriptor.token_cost,
        });
    }

    // Descending order: compare right-hand side against left-hand side.
    scored.sort_by(|lhs, rhs| {
        rhs.adjusted_score
            .partial_cmp(&lhs.adjusted_score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    scored
}

/// Rank and prune to top-K within budget, preserving pinned tools.
///
/// Ranks every tool with `rank_tools`, then hands the ranked list to
/// `top_k_with_pinned` together with `config.top_k`, `config.token_budget`
/// and the names in `config.always_include`.
///
/// Returns the surviving candidates plus a [`ToolSearchStats`] record for
/// trace annotation. `embedding_status` is always reported as `"ok"` by this
/// function; it has no failure path of its own.
pub fn search_and_prune(
    tools: &[ToolDescriptor],
    query_embedding: &[f32],
    config: &SearchConfig,
) -> (Vec<RankedCandidate>, ToolSearchStats) {
    let ranked = rank_tools(tools, query_embedding, config);
    let always_refs: Vec<&str> = config.always_include.iter().map(String::as_str).collect();
    let pruned = top_k_with_pinned(&ranked, config.top_k, config.token_budget, &always_refs);

    let total_before = tools.len();
    let total_after = pruned.len();

    let tokens_before: usize = ranked.iter().map(|c| c.token_cost).sum();
    let tokens_after: usize = pruned.iter().map(|c| c.token_cost).sum();

    // Only the first ten selected tools are recorded to keep traces compact.
    let top_scores: Vec<(String, f64)> = pruned
        .iter()
        .take(10)
        .map(|c| (c.source_id.clone(), c.adjusted_score))
        .collect();

    let stats = ToolSearchStats {
        candidates_considered: total_before,
        candidates_selected: total_after,
        // Saturate instead of subtracting directly: these are diagnostics, and
        // a selection that is somehow larger than the input (e.g. a pinned
        // entry counted twice) must not panic or wrap around.
        candidates_pruned: total_before.saturating_sub(total_after),
        token_savings: tokens_before.saturating_sub(tokens_after),
        top_scores,
        embedding_status: "ok".to_string(),
    };

    (pruned, stats)
}

/// Stats for trace annotation.
///
/// Produced by `search_and_prune` alongside the selected candidates.
///
/// NOTE(review): only `Serialize` is derived here, while `#[serde(default)]`
/// is a deserialization-side attribute — so the `default` settings below look
/// inert unless `Deserialize` is derived elsewhere or added later. Confirm.
#[derive(Debug, Clone, Serialize)]
pub struct ToolSearchStats {
    /// Number of tools that were passed in for ranking.
    pub candidates_considered: usize,
    /// Number of candidates that survived pruning.
    pub candidates_selected: usize,
    /// `candidates_considered` minus `candidates_selected`.
    pub candidates_pruned: usize,
    /// Summed token cost of all ranked candidates minus that of the selected ones.
    pub token_savings: usize,
    /// Top-10 selected tools with their adjusted ranking scores.
    /// Empty when embedding failed (graceful degradation).
    /// Omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub top_scores: Vec<(String, f64)>,
    /// Whether the embedding provider was used successfully.
    /// `"ok"` = normal ranking, `"failed"` = fallback to static ordering.
    /// NOTE(review): `search_and_prune` in this file always writes `"ok"`;
    /// the `"failed"` value is presumably set by a caller — verify.
    #[serde(default = "default_embedding_status")]
    pub embedding_status: String,
}

/// Default value for `ToolSearchStats::embedding_status`, named by path in
/// that field's `#[serde(default = "...")]` attribute.
#[allow(dead_code)] // Used by serde(default) attribute
fn default_embedding_status() -> String {
    String::from("ok")
}

/// Cosine similarity between two embedding vectors, in `[-1.0, 1.0]`.
///
/// Returns `0.0` (treated as "no signal") when:
/// - the vectors have different lengths — comparing embeddings of different
///   dimensionality is meaningless, and truncating to the shorter one would
///   mix a partial dot product with full-length norms;
/// - either vector is empty or has zero magnitude;
/// - the result is not finite (e.g. an input contained NaN or infinity).
///
/// Accumulation is done in `f64` so that long vectors do not lose precision
/// before the final division.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    if a.len() != b.len() {
        return 0.0;
    }

    let mut dot = 0.0_f64;
    let mut sq_a = 0.0_f64;
    let mut sq_b = 0.0_f64;
    for (&x, &y) in a.iter().zip(b.iter()) {
        let (x, y) = (f64::from(x), f64::from(y));
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }

    let denom = sq_a.sqrt() * sq_b.sqrt();
    if denom == 0.0 {
        return 0.0;
    }

    let similarity = dot / denom;
    if similarity.is_finite() {
        // Rounding can push a perfect match a hair past 1.0; keep the
        // documented range exact.
        similarity.clamp(-1.0, 1.0)
    } else {
        0.0
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a descriptor with a fixed token cost so each test only spells
    /// out the fields it actually cares about.
    fn tool(name: &str, source: ToolSourceInfo, embedding: Option<Vec<f32>>) -> ToolDescriptor {
        ToolDescriptor {
            name: name.into(),
            description: format!("{name} description"),
            token_cost: 50,
            source,
            embedding,
        }
    }

    #[test]
    fn rank_tools_returns_sorted_by_adjusted_score() {
        // The best match is deliberately placed LAST: with a stable sort, an
        // already-ordered input would pass even if no sorting happened.
        let tools = vec![
            tool(
                "memory_store",
                ToolSourceInfo::BuiltIn,
                Some(vec![0.1, 0.9, 0.0]),
            ),
            tool(
                "web_search",
                ToolSourceInfo::BuiltIn,
                Some(vec![0.9, 0.1, 0.0]),
            ),
        ];
        let query_embedding = vec![0.85, 0.15, 0.0];

        let ranked = rank_tools(&tools, &query_embedding, &SearchConfig::default());
        assert_eq!(ranked.len(), 2);
        assert_eq!(ranked[0].source_id, "web_search");
        assert_eq!(ranked[1].source_id, "memory_store");
        assert!(ranked[0].adjusted_score >= ranked[1].adjusted_score);
    }

    #[test]
    fn mcp_tools_receive_latency_penalty() {
        // The remote tool is listed FIRST with an identical embedding, so only
        // the latency penalty can move the local tool ahead of it.
        let tools = vec![
            tool(
                "server::remote_tool",
                ToolSourceInfo::Mcp {
                    server: "server".into(),
                },
                Some(vec![0.9, 0.1]),
            ),
            tool("local_tool", ToolSourceInfo::BuiltIn, Some(vec![0.9, 0.1])),
        ];
        let query_embedding = vec![0.9, 0.1];

        let config = SearchConfig {
            mcp_latency_penalty: 0.1,
            ..Default::default()
        };
        let ranked = rank_tools(&tools, &query_embedding, &config);
        assert_eq!(ranked[0].source_id, "local_tool");
        assert_eq!(ranked[1].source_id, "server::remote_tool");
        assert_eq!(ranked[0].raw_score, ranked[1].raw_score);
        assert!(ranked[1].adjusted_score < ranked[0].adjusted_score);
    }

    #[test]
    fn tools_without_embeddings_are_included_unranked() {
        let tools = vec![tool("no_embedding", ToolSourceInfo::BuiltIn, None)];
        let query_embedding = vec![0.9, 0.1];

        let ranked = rank_tools(&tools, &query_embedding, &SearchConfig::default());
        assert_eq!(ranked.len(), 1);
        assert_eq!(ranked[0].source_id, "no_embedding");
        assert_eq!(ranked[0].raw_score, 0.0);
        assert_eq!(ranked[0].adjusted_score, 0.0);
    }
}