roboticus-api 0.11.3

HTTP routes, WebSocket, auth, rate limiting, and dashboard for the Roboticus agent runtime
Documentation
//! Tool-pruning helpers: embed, score, and select the most relevant tool definitions
//! for a given query before passing them to the LLM.

use roboticus_agent::capability::CapabilitySource;
use roboticus_agent::tool_search::{SearchConfig, ToolDescriptor, ToolSearchStats, ToolSourceInfo};
use sha2::{Digest, Sha256};

use super::super::AppState;

/// Roughly estimates how many prompt tokens a tool definition occupies.
///
/// The estimate is the combined byte length of the tool's name, description,
/// and JSON-serialized parameter schema, divided by four. A schema that fails
/// to serialize contributes zero bytes. The result is never less than 1.
pub(crate) fn estimate_tool_token_cost(def: &roboticus_llm::format::ToolDefinition) -> usize {
    let schema_bytes = match serde_json::to_string(&def.parameters) {
        Ok(json) => json.len(),
        Err(_) => 0,
    };
    let total_bytes = def.name.len() + def.description.len() + schema_bytes;
    std::cmp::max(total_bytes / 4, 1)
}

/// Returns the hex-encoded SHA-256 digest of `description`'s UTF-8 bytes.
///
/// Within this module the digest is passed alongside the tool name when
/// reading and writing cached tool embeddings.
pub(crate) fn description_hash(description: &str) -> String {
    let digest = Sha256::digest(description.as_bytes());
    hex::encode(digest)
}

/// Converts a capability's registry source into the tool-search source tag.
///
/// Built-in and plugin sources map one-to-one. For MCP sources only the
/// `server` field is carried over; any other fields are dropped.
pub(crate) fn tool_source_info(source: CapabilitySource) -> ToolSourceInfo {
    match source {
        CapabilitySource::Mcp { server, .. } => ToolSourceInfo::Mcp { server },
        CapabilitySource::Plugin(plugin_name) => ToolSourceInfo::Plugin(plugin_name),
        CapabilitySource::BuiltIn => ToolSourceInfo::BuiltIn,
    }
}

pub(crate) async fn prune_tool_definitions(
    state: &AppState,
    defs: Vec<roboticus_llm::format::ToolDefinition>,
    query_embedding: Option<&[f32]>,
    embedding_client: &roboticus_llm::EmbeddingClient,
) -> (
    Vec<roboticus_llm::format::ToolDefinition>,
    Option<ToolSearchStats>,
) {
    let Some(query_embedding) = query_embedding else {
        return (defs, None);
    };

    let mut source_by_name = std::collections::HashMap::new();
    source_by_name.insert("orchestrate-subagents".to_string(), ToolSourceInfo::BuiltIn);
    source_by_name.insert("assign-tasks".to_string(), ToolSourceInfo::BuiltIn);
    source_by_name.insert("delegate-subagent".to_string(), ToolSourceInfo::BuiltIn);
    source_by_name.insert("task-status".to_string(), ToolSourceInfo::BuiltIn);
    source_by_name.insert("retry-task".to_string(), ToolSourceInfo::BuiltIn);
    source_by_name.insert("list-open-tasks".to_string(), ToolSourceInfo::BuiltIn);
    source_by_name.insert("compose-subagent".to_string(), ToolSourceInfo::BuiltIn);
    source_by_name.insert(
        "update-subagent-skills".to_string(),
        ToolSourceInfo::BuiltIn,
    );
    source_by_name.insert("list-subagent-roster".to_string(), ToolSourceInfo::BuiltIn);
    source_by_name.insert("list-available-skills".to_string(), ToolSourceInfo::BuiltIn);
    source_by_name.insert("remove-subagent".to_string(), ToolSourceInfo::BuiltIn);
    source_by_name.insert(
        "retire-unused-subagents".to_string(),
        ToolSourceInfo::BuiltIn,
    );
    source_by_name.insert("compose-skill".to_string(), ToolSourceInfo::BuiltIn);
    source_by_name.insert(
        "validate-subagent-roster".to_string(),
        ToolSourceInfo::BuiltIn,
    );

    if state.capabilities.is_empty().await {
        for tool in state.tools.list() {
            source_by_name.insert(tool.name().to_string(), ToolSourceInfo::BuiltIn);
        }
    } else {
        for capability in state.capabilities.catalog().await {
            source_by_name.insert(capability.name, tool_source_info(capability.source));
        }
    }

    let mut descriptors = Vec::with_capacity(defs.len());
    let mut missing_names = Vec::new();
    let mut missing_descs = Vec::new();

    for def in &defs {
        let desc_hash = description_hash(&def.description);
        let embedding = match roboticus_db::tool_embeddings::get_tool_embedding(
            &state.db, &def.name, &desc_hash,
        ) {
            Ok(embedding) => embedding,
            Err(e) => {
                tracing::warn!(tool = %def.name, error = %e, "failed to read tool embedding cache");
                None
            }
        };
        if embedding.is_none() {
            missing_names.push((def.name.clone(), desc_hash));
            missing_descs.push(def.description.as_str());
        }
        descriptors.push(ToolDescriptor {
            name: def.name.clone(),
            description: def.description.clone(),
            token_cost: estimate_tool_token_cost(def),
            source: source_by_name
                .get(&def.name)
                .cloned()
                .unwrap_or(ToolSourceInfo::BuiltIn),
            embedding,
        });
    }

    if !missing_descs.is_empty() {
        match embedding_client.embed(&missing_descs).await {
            Ok(generated) => {
                for ((name, desc_hash), embedding) in
                    missing_names.into_iter().zip(generated.into_iter())
                {
                    if let Some(desc) = descriptors.iter_mut().find(|d| d.name == name) {
                        desc.embedding = Some(embedding.clone());
                    }
                    if let Err(e) = roboticus_db::tool_embeddings::save_tool_embedding(
                        &state.db, &name, &desc_hash, &embedding,
                    ) {
                        tracing::warn!(tool = %name, error = %e, "failed to cache tool embedding");
                    }
                }
            }
            Err(e) => {
                tracing::warn!(error = %e, "tool search embedding failed; falling back to static top-k");
                // Graceful degradation: keep always-include tools + alphabetical top-k
                // instead of drowning the prompt with all tools.
                let always = always_include_operational_tools();
                let top_k = SearchConfig::default().top_k;
                let mut kept: Vec<_> = defs
                    .iter()
                    .filter(|d| always.iter().any(|a| d.name.contains(a)))
                    .cloned()
                    .collect();
                for def in &defs {
                    if kept.len() >= top_k {
                        break;
                    }
                    if !kept.iter().any(|k| k.name == def.name) {
                        kept.push(def.clone());
                    }
                }
                let fallback_stats = roboticus_agent::tool_search::ToolSearchStats {
                    candidates_considered: defs.len(),
                    candidates_selected: kept.len(),
                    candidates_pruned: defs.len() - kept.len(),
                    token_savings: 0,
                    top_scores: vec![],
                    embedding_status: "failed".to_string(),
                };
                return (kept, Some(fallback_stats));
            }
        }
    }

    let search_config = SearchConfig {
        always_include: always_include_operational_tools(),
        ..Default::default()
    };
    let (ranked, stats) = roboticus_agent::tool_search::search_and_prune(
        &descriptors,
        query_embedding,
        &search_config,
    );
    let defs_by_name: std::collections::HashMap<_, _> = defs
        .into_iter()
        .map(|def| (def.name.clone(), def))
        .collect();
    let selected = ranked
        .into_iter()
        .filter_map(|candidate| defs_by_name.get(&candidate.source_id).cloned())
        .collect();
    (selected, Some(stats))
}

/// Returns the names of the operational tools that tool pruning always keeps.
///
/// The list is fixed and returned as freshly allocated `String`s in a stable
/// order on every call.
pub(crate) fn always_include_operational_tools() -> Vec<String> {
    const OPERATIONAL_TOOL_NAMES: &[&str] = &[
        "get_memory_stats",
        "get_runtime_context",
        "list-subagent-roster",
        "list-available-skills",
        "compose-subagent",
        "compose-skill",
        "memory_store",
        "delegate-subagent",
        "orchestrate-subagents",
        "task-status",
        "retry-task",
        "list-open-tasks",
    ];
    OPERATIONAL_TOOL_NAMES
        .iter()
        .copied()
        .map(String::from)
        .collect()
}