lash-plugin-tool-discovery 0.1.0-alpha.39

Tool-discovery plugin for the lash agent runtime.
Documentation
use std::sync::{Arc, RwLock};

use lash_core::{ToolCall, ToolContext, ToolResult};
use lash_tool_support::{StaticToolExecute, StaticToolProvider};
use serde_json::{Value, json};

use crate::common::{LLM_CANDIDATE_LIMIT, args_with_limit, catalog_key, limit_from_args};
use crate::definitions::search_tools_definition;
use crate::ranking::ToolDiscoveryIndex;
use crate::rerank::{
    llm_rerank_request, merge_llm_selection, parse_llm_tool_names, rerank_payment_action_intent,
};

/// Single-slot cache for the most recently built tool-discovery index.
///
/// The derived `Default` leaves `index` as `None`, i.e. nothing cached yet.
#[derive(Clone, Default)]
struct IndexCache {
    // Shared handle to the cached index, or `None` when nothing has been built.
    index: Option<Arc<ToolDiscoveryIndex>>,
}

/// Executor for the `search_tools` tool.
///
/// The index cache sits behind an `Arc`, so cloning the provider yields a
/// handle to the same cache rather than an independent copy.
#[derive(Clone)]
pub struct ToolDiscoveryToolsProvider {
    // Lock-guarded slot holding the most recently built discovery index.
    cache: Arc<RwLock<IndexCache>>,
}

impl Default for ToolDiscoveryToolsProvider {
    /// Same as [`ToolDiscoveryToolsProvider::new`]: a provider with an empty index cache.
    fn default() -> Self {
        Self::new()
    }
}

impl ToolDiscoveryToolsProvider {
    pub fn new() -> Self {
        Self {
            cache: Arc::default(),
        }
    }

    fn index_for_catalog(&self, catalog: Arc<Vec<Value>>) -> Arc<ToolDiscoveryIndex> {
        let key = catalog_key(catalog.as_ref());
        if let Some(index) = self
            .cache
            .read()
            .expect("tool discovery cache lock poisoned")
            .index
            .as_ref()
            .filter(|index| index.key == key)
            .cloned()
        {
            return index;
        }

        let index = Arc::new(ToolDiscoveryIndex::build(key, catalog.as_ref().clone()));
        self.cache
            .write()
            .expect("tool discovery cache lock poisoned")
            .index = Some(Arc::clone(&index));
        index
    }

    async fn search_tools(
        &self,
        args: &Value,
        catalog: Arc<Vec<Value>>,
        context: &ToolContext<'_>,
    ) -> ToolResult {
        let index = self.index_for_catalog(catalog);
        let limit = limit_from_args(args);
        let candidate_args = args_with_limit(args, LLM_CANDIDATE_LIMIT);
        let candidates = index.search(&candidate_args);
        if candidates.is_empty() {
            return ToolResult::ok(json!([]));
        }
        let query = args
            .get("query")
            .and_then(Value::as_str)
            .unwrap_or_default();
        let candidates = rerank_payment_action_intent(query, candidates);

        let model = match context.sessions().model().await {
            Ok(model) => model,
            Err(err) => {
                return ToolResult::err_fmt(format_args!(
                    "search_tools could not resolve parent model: {err}"
                ));
            }
        };
        let request =
            llm_rerank_request(args, &candidates, limit, model.model, model.model_variant);
        let completion = match context
            .direct_completions()
            .complete(request, "search_tools")
            .await
        {
            Ok(completion) => completion,
            Err(err) => return ToolResult::err_fmt(format_args!("search_tools failed: {err}")),
        };

        let selected_names = match parse_llm_tool_names(&completion.text) {
            Ok(names) => names,
            Err(err) => {
                return ToolResult::err_fmt(format_args!(
                    "search_tools returned invalid JSON: {err}"
                ));
            }
        };

        ToolResult::ok(json!(merge_llm_selection(
            candidates,
            selected_names,
            limit
        )))
    }
}

/// Construct the static tool provider that exposes `search_tools`.
///
/// The returned provider wraps a newly created [`ToolDiscoveryToolsProvider`],
/// so its discovery cache starts out empty.
pub fn tool_discovery_provider() -> StaticToolProvider<ToolDiscoveryToolsProvider> {
    let executor = ToolDiscoveryToolsProvider::new();
    StaticToolProvider::new(vec![search_tools_definition()], executor)
}

#[async_trait::async_trait]
impl StaticToolExecute for ToolDiscoveryToolsProvider {
    /// Dispatch a tool call; only `search_tools` is handled.
    ///
    /// Any other name yields an "Unknown tool" error result. For
    /// `search_tools`, the shared tool catalog is fetched from the session
    /// service first, and a failure to fetch it is returned as an error result.
    async fn execute(&self, call: ToolCall<'_>) -> ToolResult {
        if call.name != "search_tools" {
            return ToolResult::err_fmt(format_args!("Unknown tool: {}", call.name));
        }

        let catalog = call.context.sessions().shared_tool_catalog().await;
        match catalog {
            Err(err) => ToolResult::err_fmt(err.to_string()),
            Ok(catalog) => self.search_tools(call.args, catalog, call.context).await,
        }
    }
}

// Tests for the `search_tools` provider, driven through a fake session host
// and a canned direct-completion client.
#[cfg(test)]
mod tests {
    use super::*;
    use lash_core::plugin::runtime_host::{
        SessionGraphService, SessionLifecycleService, SessionStateService,
    };
    use lash_core::plugin::{PluginError, SessionHandle, SessionSnapshot};
    use lash_core::{
        DirectCompletion, TokenUsage, ToolAgentSurface, ToolCall, ToolContract, ToolDefinition,
        ToolProvider,
    };
    use serde_json::json;
    use std::sync::Mutex;

    /// In-memory stand-in for the session host used by these tests.
    #[derive(Default)]
    struct FakeSessionManager {
        // Snapshot handed back by both `snapshot_current` and `snapshot_session`.
        snapshot: SessionSnapshot,
        // Catalog entries returned by `tool_catalog`.
        catalog: Vec<Value>,
        // Canned completion text; `None` makes the client fall back to an empty selection.
        direct_response: Mutex<Option<String>>,
        // Every direct-completion request received, in call order.
        direct_requests: Mutex<Vec<lash_core::DirectRequest>>,
    }

    #[async_trait::async_trait]
    impl SessionStateService for FakeSessionManager {
        async fn snapshot_current(&self) -> Result<SessionSnapshot, PluginError> {
            Ok(self.snapshot.clone())
        }

        // The session id is ignored: every session shares the one stored snapshot.
        async fn snapshot_session(
            &self,
            _session_id: &str,
        ) -> Result<SessionSnapshot, PluginError> {
            Ok(self.snapshot.clone())
        }

        async fn tool_catalog(&self, _session_id: &str) -> Result<Vec<Value>, PluginError> {
            Ok(self.catalog.clone())
        }

        // No-op: accepts the request and always returns 0.
        async fn set_tools_availability(
            &self,
            _session_id: &str,
            _tool_names: &[String],
            _availability: Option<lash_core::ToolAvailability>,
        ) -> Result<u64, PluginError> {
            Ok(0)
        }
    }

    #[async_trait::async_trait]
    impl SessionLifecycleService for FakeSessionManager {
        // Session creation is not exercised by these tests, so it always fails.
        async fn create_session(
            &self,
            _request: lash_core::plugin::SessionCreateRequest,
        ) -> Result<SessionHandle, PluginError> {
            Err(PluginError::Session("unused".to_string()))
        }

        async fn close_session(&self, _session_id: &str) -> Result<(), PluginError> {
            Ok(())
        }
    }

    // Relies entirely on the trait's provided methods.
    #[async_trait::async_trait]
    impl SessionGraphService for FakeSessionManager {}

    /// Build a default snapshot whose policy names the given model id and variant.
    fn snapshot_with_model(model: &str, variant: Option<&str>) -> SessionSnapshot {
        let mut snapshot = SessionSnapshot::default();
        snapshot.policy.model.id = model.to_string();
        snapshot.policy.model.variant = variant.map(str::to_string);
        snapshot
    }

    /// Build a tool context backed by `host`.
    ///
    /// The direct-completion client records each request on the host and
    /// answers with the host's canned response, or `{"tool_names":[]}` when
    /// none is set.
    fn discovery_context(host: Arc<FakeSessionManager>) -> lash_core::ToolContext<'static> {
        // Second handle to the host, moved into the completion closure.
        let direct_host = Arc::clone(&host);
        lash_core::testing::mock_tool_context_with_host_and_direct_completions(
            host,
            lash_core::DirectCompletionClient::from_fn(move |request, _usage_source| {
                direct_host
                    .direct_requests
                    .lock()
                    .expect("direct requests lock poisoned")
                    .push(request);
                let text = direct_host
                    .direct_response
                    .lock()
                    .expect("direct response lock poisoned")
                    .clone()
                    .unwrap_or_else(|| "{\"tool_names\":[]}".to_string());
                Ok(DirectCompletion {
                    text,
                    usage: TokenUsage::default(),
                })
            }),
        )
    }

    /// Build one catalog entry (as JSON) for a tool with the given metadata.
    ///
    /// When `module` is `None` the module is chosen from the tool name
    /// (`read_file` -> `files`, `search_web` -> `web`, anything else -> `tools`).
    /// The operation is likewise derived from the name (`read`, `search`, or
    /// the name itself).
    fn catalog_tool_with_metadata(
        name: &str,
        description: &str,
        module: Option<&str>,
        aliases: Vec<&str>,
    ) -> Value {
        let tool = ToolDefinition::raw_named(
            name,
            description,
            ToolContract::default_input_schema(),
            json!({}),
        )
        .with_agent_surface(
            ToolAgentSurface::new(
                [module.unwrap_or(match name {
                    "read_file" => "files",
                    "search_web" => "web",
                    _ => "tools",
                })],
                match name {
                    "read_file" => "read",
                    "search_web" => "search",
                    _ => name,
                },
            )
            .with_aliases(aliases),
        );
        let manifest = tool.manifest();
        let agent_surface = manifest.agent_surface.executable_for(&manifest.name);
        let call = agent_surface.call_path();
        // Every entry is marked searchable but neither callable nor showcased.
        json!({
            "id": manifest.id,
            "name": manifest.name,
            "module_path": agent_surface.module_path.clone(),
            "operation": agent_surface.operation.clone(),
            "call": call,
            "description": manifest.description,
            "aliases": agent_surface.aliases.clone(),
            "availability": "searchable",
            "callable": false,
            "showcased": false,
            "searchable": true,
            "activation": manifest.activation,
            "contract": manifest.compact_contract.expect("compact contract"),
        })
    }

    /// Collect the `name` field of each result, panicking if one is missing or not a string.
    fn ranked_names(results: &[Value]) -> Vec<String> {
        results
            .iter()
            .map(|result| {
                result
                    .get("name")
                    .and_then(Value::as_str)
                    .expect("ranked result name")
                    .to_string()
            })
            .collect()
    }

    // The provider must advertise exactly one tool, `search_tools`.
    #[test]
    fn provider_exposes_search_tools_only() {
        let names = tool_discovery_provider()
            .tool_manifests()
            .into_iter()
            .map(|definition| definition.name)
            .collect::<Vec<_>>();

        assert_eq!(names, vec!["search_tools"]);
    }

    // Searching by alias within a module returns the matching catalog entry in
    // its projected form (call path, signature, description, module path) and
    // without a `score` field. No canned response is set, so the completion
    // client replies with an empty selection.
    #[tokio::test]
    async fn search_tools_uses_host_catalog_and_projects_compact_contract() {
        let host = Arc::new(FakeSessionManager {
            catalog: vec![
                catalog_tool_with_metadata(
                    "read_file",
                    "Read file contents",
                    Some("filesystem"),
                    vec!["cat"],
                ),
                catalog_tool_with_metadata(
                    "search_web",
                    "Search the web",
                    Some("web"),
                    vec!["web_search"],
                ),
            ],
            ..Default::default()
        });
        let provider = tool_discovery_provider();
        let context = discovery_context(host);

        let args = json!({
            "query": "cat",
            "module": "filesystem",
            "limit": 1,
        });
        let result = provider
            .execute(ToolCall {
                name: "search_tools",
                args: &args,
                context: &context,
                progress: None,
            })
            .await;

        assert!(result.is_success(), "{result:?}");
        let value = result.value_for_projection();
        let results = value.as_array().expect("search result list");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0]["name"], json!("filesystem.read"));
        assert_eq!(results[0]["call"], json!("filesystem.read"));
        assert!(
            results[0]["signature"]
                .as_str()
                .expect("signature")
                .starts_with("await filesystem.read({})? -> ")
        );
        assert_eq!(results[0]["description"], json!("Read file contents"));
        assert_eq!(results[0]["module_path"], json!(["filesystem"]));
        assert!(results[0].get("score").is_none());
    }

    // The canned completion names `search_web` twice plus an unknown tool; the
    // result must contain `web.search` exactly once, and the single recorded
    // request must carry the model id and variant from the session snapshot.
    #[tokio::test]
    async fn search_tools_reranks_candidates_with_direct_completion() {
        let host = Arc::new(FakeSessionManager {
            snapshot: snapshot_with_model("gpt-5.5", Some("medium")),
            catalog: vec![
                catalog_tool_with_metadata("read_file", "Read file contents", None, vec!["cat"]),
                catalog_tool_with_metadata("search_web", "Search the web", None, vec!["web"]),
            ],
            direct_response: Mutex::new(Some(
                "{\"tool_names\":[\"search_web\",\"search_web\",\"unknown\"]}".to_string(),
            )),
            ..Default::default()
        });
        let provider = tool_discovery_provider();
        let context = discovery_context(host.clone());

        let args = json!({
            "query": "",
            "exclude": ["read_file"],
            "limit": 2,
        });
        let result = provider
            .execute(ToolCall {
                name: "search_tools",
                args: &args,
                context: &context,
                progress: None,
            })
            .await;

        assert!(result.is_success(), "{result:?}");
        let value = result.value_for_projection();
        let results = value.as_array().expect("search result list");
        assert_eq!(ranked_names(results), vec!["web.search"]);
        let requests = host
            .direct_requests
            .lock()
            .expect("direct requests lock poisoned");
        assert_eq!(requests.len(), 1);
        assert_eq!(requests[0].model, "gpt-5.5");
        assert_eq!(requests[0].model_variant.as_deref(), Some("medium"));
    }
}