use super::{RegisteredTool, ScriptedExecutionTrace, ScriptedTool, ToolArgs, ToolDef};
use crate::ExecutionLimits;
use crate::tool::{Tool, ToolRequest, ToolResponse, ToolStatus, VERSION};
use async_trait::async_trait;
use schemars::schema_for;
use std::sync::Arc;
/// How a [`ScriptingToolSet`] exposes its registered commands.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DiscoveryMode {
    /// Only the scripted tool is exposed; it is built with
    /// `compact_prompt(false)`. This is the default mode.
    Exclusive,
    /// The scripted tool is built with `compact_prompt(true)` and an extra
    /// `<name>_discover` tool is exposed alongside it.
    WithDiscovery,
}

impl Default for DiscoveryMode {
    /// Returns [`DiscoveryMode::Exclusive`].
    fn default() -> Self {
        Self::Exclusive
    }
}
/// Companion tool exposed in [`DiscoveryMode::WithDiscovery`] mode.
///
/// It carries its own identity (name, display name, description, locale)
/// but forwards every execution request to the wrapped scripted tool.
pub struct DiscoverTool {
    /// Scripted tool that actually runs submitted commands.
    inner: ScriptedTool,
    /// Tool identifier (the toolset builds it as `<toolset>_discover`).
    name: String,
    /// Human-friendly title, also used as the heading of `help()`.
    display_name: String,
    /// One-line description; reused as the long description.
    short_desc: String,
    /// Locale tag reported through `Tool::locale`.
    locale: String,
}
#[async_trait]
impl Tool for DiscoverTool {
    fn name(&self) -> &str {
        self.name.as_str()
    }

    fn display_name(&self) -> &str {
        self.display_name.as_str()
    }

    fn short_description(&self) -> &str {
        self.short_desc.as_str()
    }

    /// There is no separate long description; the short one is reused.
    fn description(&self) -> &str {
        self.short_desc.as_str()
    }

    /// Markdown usage page listing the `discover` and `help` builtins,
    /// headed by the display name.
    fn help(&self) -> String {
        let title = &self.display_name;
        format!(
            "# {title}\n\n\
             Discover available tool commands.\n\n\
             ## Commands\n\n\
             - `discover --categories` — list categories\n\
             - `discover --category <name>` — list tools in category\n\
             - `discover --tag <tag>` — filter by tag\n\
             - `discover --search <keyword>` — search by name/description\n\
             - `help --list` — list all tools\n\
             - `help <tool>` — human-readable usage\n\
             - `help <tool> --json` — machine-readable schema\n\n\
             All commands support `--json` for structured output.\n"
        )
    }

    /// One-line prompt naming the tool and its most useful builtins.
    fn system_prompt(&self) -> String {
        let name = &self.name;
        format!(
            "{name}: discover available tool commands. Use `discover --categories`, \
             `discover --search <keyword>`, `help <tool>`, `help <tool> --json`."
        )
    }

    fn locale(&self) -> &str {
        self.locale.as_str()
    }

    /// JSON schema of [`ToolRequest`]; an empty/default value if
    /// serialization fails.
    fn input_schema(&self) -> serde_json::Value {
        serde_json::to_value(schema_for!(ToolRequest)).unwrap_or_default()
    }

    /// JSON schema of [`ToolResponse`]; an empty/default value if
    /// serialization fails.
    fn output_schema(&self) -> serde_json::Value {
        serde_json::to_value(schema_for!(ToolResponse)).unwrap_or_default()
    }

    fn version(&self) -> &str {
        VERSION
    }

    // Everything below is pure delegation to the wrapped scripted tool.

    fn execution(
        &self,
        args: serde_json::Value,
    ) -> Result<crate::tool::ToolExecution, crate::tool::ToolError> {
        self.inner.execution(args)
    }

    async fn execute(&self, request: ToolRequest) -> ToolResponse {
        self.inner.execute(request).await
    }

    async fn execute_with_status(
        &self,
        request: ToolRequest,
        on_status: Box<dyn FnMut(ToolStatus) + Send>,
    ) -> ToolResponse {
        self.inner.execute_with_status(request, on_status).await
    }
}
/// Builder for [`ScriptingToolSet`].
///
/// Collects the registered commands, optional execution limits, environment
/// variables and the discovery mode; `build` then produces the toolset.
pub struct ScriptingToolSetBuilder {
    /// Toolset name; also passed as the name of the generated scripted tool.
    name: String,
    /// Commands registered through `tool`, in registration order.
    tools: Vec<RegisteredTool>,
    /// Environment variables forwarded to the scripted tool, in insertion order.
    env_vars: Vec<(String, String)>,
    /// Execution limits; `None` leaves limits unset on the scripted tool's builder.
    limits: Option<ExecutionLimits>,
    /// Description override; when `None`, `build` falls back to
    /// `"ScriptingToolSet: <name>"`.
    short_desc: Option<String>,
    /// Locale tag (`"en-US"` unless overridden).
    locale: String,
    /// Whether a separate discover tool is exposed.
    mode: DiscoveryMode,
}
impl ScriptingToolSetBuilder {
    /// Creates a builder with locale `"en-US"`, no description override,
    /// no commands, no limits, no environment variables and
    /// [`DiscoveryMode::Exclusive`].
    fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            locale: "en-US".to_string(),
            short_desc: None,
            tools: Vec::new(),
            limits: None,
            env_vars: Vec::new(),
            mode: DiscoveryMode::Exclusive,
        }
    }

    /// Sets the locale tag reported by the generated tools.
    #[must_use]
    pub fn locale(mut self, locale: &str) -> Self {
        self.locale = locale.to_string();
        self
    }

    /// Overrides the short description of the scripted tool.
    ///
    /// Without an override, `build` uses `"ScriptingToolSet: <name>"`.
    #[must_use]
    pub fn short_description(mut self, desc: impl Into<String>) -> Self {
        self.short_desc = Some(desc.into());
        self
    }

    /// Registers a command described by `def` and implemented by `callback`.
    ///
    /// The callback receives the parsed arguments and returns either the
    /// command's output (`Ok`) or an error message (`Err`).
    #[must_use]
    pub fn tool(
        mut self,
        def: ToolDef,
        callback: impl Fn(&ToolArgs) -> Result<String, String> + Send + Sync + 'static,
    ) -> Self {
        self.tools.push(RegisteredTool {
            def,
            callback: Arc::new(callback),
        });
        self
    }

    /// Sets execution limits forwarded to the scripted tool.
    #[must_use]
    pub fn limits(mut self, limits: ExecutionLimits) -> Self {
        self.limits = Some(limits);
        self
    }

    /// Adds an environment variable forwarded to the scripted tool.
    #[must_use]
    pub fn env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.env_vars.push((key.into(), value.into()));
        self
    }

    /// Switches to [`DiscoveryMode::WithDiscovery`]: the scripted tool gets a
    /// compact prompt and `tools()` additionally returns a discover tool.
    #[must_use]
    pub fn with_discovery(mut self) -> Self {
        self.mode = DiscoveryMode::WithDiscovery;
        self
    }

    /// Builds the toolset.
    ///
    /// The builder is only borrowed, so it can be reused. Each registered
    /// callback is shared with the scripted tool through a cloned `Arc`
    /// handle rather than moved out of the builder.
    #[must_use]
    pub fn build(&self) -> ScriptingToolSet {
        let short_desc = self
            .short_desc
            .clone()
            .unwrap_or_else(|| format!("ScriptingToolSet: {}", self.name));
        // Only request the compact prompt when a discover tool will be
        // available to look commands up.
        let compact = self.mode == DiscoveryMode::WithDiscovery;
        let mut builder = ScriptedTool::builder(&self.name)
            .locale(&self.locale)
            .short_description(&short_desc)
            .compact_prompt(compact);
        if let Some(limits) = &self.limits {
            builder = builder.limits(limits.clone());
        }
        for (key, value) in &self.env_vars {
            builder = builder.env(key, value);
        }
        for registered in &self.tools {
            // Borrow the stored callback and clone the `Arc`; `self` is only
            // borrowed, so the callback cannot be moved out.
            let callback = Arc::clone(&registered.callback);
            builder = builder.tool(registered.def.clone(), move |args: &ToolArgs| {
                callback(args)
            });
        }
        ScriptingToolSet {
            name: self.name.clone(),
            locale: self.locale.clone(),
            inner: builder.build(),
            mode: self.mode,
        }
    }
}
/// A set of registered commands exposed as one scripted tool, optionally
/// accompanied by a discover tool (see [`DiscoveryMode`]).
pub struct ScriptingToolSet {
    /// Scripted tool built from the registered commands.
    inner: ScriptedTool,
    /// Toolset name; the discover tool derives its names from it.
    name: String,
    /// Locale tag copied into the discover tool.
    locale: String,
    /// Mode the toolset was built with.
    mode: DiscoveryMode,
}
impl ScriptingToolSet {
pub fn builder(name: impl Into<String>) -> ScriptingToolSetBuilder {
ScriptingToolSetBuilder::new(name)
}
pub fn discovery_mode(&self) -> DiscoveryMode {
self.mode
}
pub fn take_last_execution_trace(&self) -> Option<ScriptedExecutionTrace> {
self.inner.take_last_execution_trace()
}
pub fn tools(&self) -> Vec<Box<dyn Tool>> {
match self.mode {
DiscoveryMode::Exclusive => {
vec![Box::new(self.inner.clone())]
}
DiscoveryMode::WithDiscovery => {
let discover = DiscoverTool {
name: format!("{}_discover", self.name),
locale: self.locale.clone(),
display_name: format!("{} Discover", self.name),
short_desc: format!("Discover available {} commands", self.name),
inner: self.inner.clone(),
};
vec![Box::new(self.inner.clone()), Box::new(discover)]
}
}
}
}
// Unit tests for `ScriptingToolSet` in both discovery modes: tool lists,
// prompts, builtins, execution, env vars, version and schemas.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tool::ToolRequest;

    /// Builder for a two-command toolset named `test_api`.
    ///
    /// The commands are `get_user` (category `users`, required integer
    /// `id`) and `list_orders` (category `orders`, integer `user_id`).
    fn make_tools() -> ScriptingToolSetBuilder {
        ScriptingToolSet::builder("test_api")
            .short_description("Test API")
            .tool(
                ToolDef::new("get_user", "Fetch user by ID")
                    .with_schema(serde_json::json!({
                        "type": "object",
                        "properties": {
                            "id": {"type": "integer"}
                        },
                        "required": ["id"]
                    }))
                    .with_category("users"),
                |args: &ToolArgs| {
                    let id = args.param_i64("id").ok_or("missing --id")?;
                    Ok(format!("{{\"id\":{id},\"name\":\"Alice\"}}\n"))
                },
            )
            .tool(
                ToolDef::new("list_orders", "List orders for a user")
                    .with_schema(serde_json::json!({
                        "type": "object",
                        "properties": {
                            "user_id": {"type": "integer"}
                        }
                    }))
                    .with_category("orders"),
                |args: &ToolArgs| {
                    let uid = args.param_i64("user_id").ok_or("missing --user_id")?;
                    Ok(format!("[{{\"order_id\":1,\"user_id\":{uid}}}]\n"))
                },
            )
    }

    // A fresh builder must default to exclusive mode.
    #[test]
    fn test_builder_defaults_to_exclusive() {
        let toolset = make_tools().build();
        assert_eq!(toolset.discovery_mode(), DiscoveryMode::Exclusive);
    }

    // `with_discovery()` must switch the recorded mode.
    #[test]
    fn test_with_discovery_switches_mode() {
        let toolset = make_tools().with_discovery().build();
        assert_eq!(toolset.discovery_mode(), DiscoveryMode::WithDiscovery);
    }

    // Exclusive mode exposes only the scripted tool, under the toolset name.
    #[test]
    fn test_exclusive_returns_one_tool() {
        let toolset = make_tools().build();
        let tools = toolset.tools();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name(), "test_api");
    }

    // The non-compact prompt must list each command together with its flags.
    #[test]
    fn test_exclusive_tool_has_full_schemas() {
        let toolset = make_tools().build();
        let tools = toolset.tools();
        let sp = tools[0].system_prompt();
        assert!(sp.contains("get_user [--id <integer>]"), "prompt: {sp}");
        assert!(
            sp.contains("list_orders [--user_id <integer>]"),
            "prompt: {sp}"
        );
    }

    // Without a discover tool, the prompt must not point at discover commands.
    #[test]
    fn test_exclusive_tool_no_discover_instructions() {
        let toolset = make_tools().build();
        let tools = toolset.tools();
        let sp = tools[0].system_prompt();
        assert!(!sp.contains("discover --categories"), "prompt: {sp}");
    }

    // Discovery mode returns the scripted tool first, then `<name>_discover`.
    #[test]
    fn test_discovery_returns_two_tools() {
        let toolset = make_tools().with_discovery().build();
        let tools = toolset.tools();
        assert_eq!(tools.len(), 2);
        assert_eq!(tools[0].name(), "test_api");
        assert_eq!(tools[1].name(), "test_api_discover");
    }

    // In discovery mode, the scripted tool's compact prompt omits per-flag detail.
    #[test]
    fn test_discovery_script_tool_compact_prompt() {
        let toolset = make_tools().with_discovery().build();
        let tools = toolset.tools();
        let sp = tools[0].system_prompt();
        assert!(!sp.contains("--id <integer>"), "prompt: {sp}");
        assert!(!sp.contains("--user_id <integer>"), "prompt: {sp}");
    }

    // The discover tool's prompt mentions both the `discover` and `help` builtins.
    #[test]
    fn test_discovery_discover_tool_prompt() {
        let toolset = make_tools().with_discovery().build();
        let tools = toolset.tools();
        let sp = tools[1].system_prompt();
        assert!(sp.contains("discover"), "prompt: {sp}");
        assert!(sp.contains("help"), "prompt: {sp}");
    }

    // The explicit short description is propagated to the scripted tool.
    #[test]
    fn test_name_and_short_description() {
        let toolset = make_tools().build();
        let tools = toolset.tools();
        assert_eq!(tools[0].name(), "test_api");
        assert_eq!(tools[0].short_description(), "Test API");
    }

    // Without an override, the description falls back to "ScriptingToolSet: <name>".
    #[test]
    fn test_default_short_description() {
        let toolset = ScriptingToolSet::builder("mytools")
            .tool(ToolDef::new("noop", "No-op"), |_: &ToolArgs| {
                Ok("ok\n".into())
            })
            .build();
        let tools = toolset.tools();
        assert_eq!(tools[0].short_description(), "ScriptingToolSet: mytools");
    }

    // A registered command's output can be piped through `jq` in exclusive mode.
    #[tokio::test]
    async fn test_execute_via_exclusive_tool() {
        let toolset = make_tools().build();
        let tools = toolset.tools();
        let resp = tools[0]
            .execute(ToolRequest {
                commands: "get_user --id 42 | jq -r '.name'".into(),
                timeout_ms: None,
            })
            .await;
        assert_eq!(resp.exit_code, 0);
        assert_eq!(resp.stdout.trim(), "Alice");
    }

    // The compact-prompt scripted tool still runs registered commands.
    #[tokio::test]
    async fn test_execute_via_discovery_script_tool() {
        let toolset = make_tools().with_discovery().build();
        let tools = toolset.tools();
        let resp = tools[0]
            .execute(ToolRequest {
                commands: "get_user --id 1".into(),
                timeout_ms: None,
            })
            .await;
        assert_eq!(resp.exit_code, 0);
        assert!(resp.stdout.contains("Alice"));
    }

    // The discover tool lists the categories assigned in `make_tools`.
    #[tokio::test]
    async fn test_execute_via_discover_tool() {
        let toolset = make_tools().with_discovery().build();
        let tools = toolset.tools();
        let resp = tools[1]
            .execute(ToolRequest {
                commands: "discover --categories".into(),
                timeout_ms: None,
            })
            .await;
        assert_eq!(resp.exit_code, 0);
        assert!(resp.stdout.contains("users"));
        assert!(resp.stdout.contains("orders"));
    }

    // `help <tool>` via the discover tool describes the command's flags.
    #[tokio::test]
    async fn test_discover_tool_help_builtin() {
        let toolset = make_tools().with_discovery().build();
        let tools = toolset.tools();
        let resp = tools[1]
            .execute(ToolRequest {
                commands: "help get_user".into(),
                timeout_ms: None,
            })
            .await;
        assert_eq!(resp.exit_code, 0);
        assert!(resp.stdout.contains("get_user"));
        assert!(resp.stdout.contains("--id"));
    }

    // Status callbacks fire during execution and report at least the
    // "validate" and "complete" phases.
    #[tokio::test]
    async fn test_execute_with_status_via_tool() {
        use std::sync::{Arc, Mutex};
        let toolset = make_tools().build();
        let tools = toolset.tools();
        // Shared log of every phase the callback receives.
        let phases = Arc::new(Mutex::new(Vec::new()));
        let phases_clone = phases.clone();
        let resp = tools[0]
            .execute_with_status(
                ToolRequest {
                    commands: "get_user --id 1".into(),
                    timeout_ms: None,
                },
                Box::new(move |status| {
                    phases_clone
                        .lock()
                        .expect("lock poisoned")
                        .push(status.phase.clone());
                }),
            )
            .await;
        assert_eq!(resp.exit_code, 0);
        let phases = phases.lock().expect("lock poisoned");
        assert!(phases.contains(&"validate".to_string()));
        assert!(phases.contains(&"complete".to_string()));
    }

    // The `help` builtin is also available on the exclusive scripted tool.
    #[tokio::test]
    async fn test_help_builtin_works_in_exclusive() {
        let toolset = make_tools().build();
        let tools = toolset.tools();
        let resp = tools[0]
            .execute(ToolRequest {
                commands: "help get_user".into(),
                timeout_ms: None,
            })
            .await;
        assert_eq!(resp.exit_code, 0);
        assert!(resp.stdout.contains("get_user"));
        assert!(resp.stdout.contains("--id"));
    }

    // The `discover` builtin also works on the exclusive scripted tool.
    #[tokio::test]
    async fn test_discover_builtin_works_in_exclusive() {
        let toolset = make_tools().build();
        let tools = toolset.tools();
        let resp = tools[0]
            .execute(ToolRequest {
                commands: "discover --categories".into(),
                timeout_ms: None,
            })
            .await;
        assert_eq!(resp.exit_code, 0);
        assert!(resp.stdout.contains("users"));
        assert!(resp.stdout.contains("orders"));
    }

    // Variables set with `.env()` are visible to executed scripts.
    #[tokio::test]
    async fn test_env_vars_passed_through() {
        let toolset = ScriptingToolSet::builder("env_test")
            .env("MY_VAR", "hello")
            .tool(ToolDef::new("noop", "No-op"), |_: &ToolArgs| {
                Ok("ok\n".into())
            })
            .build();
        let tools = toolset.tools();
        let resp = tools[0]
            .execute(ToolRequest {
                commands: "echo $MY_VAR".into(),
                timeout_ms: None,
            })
            .await;
        assert_eq!(resp.exit_code, 0);
        assert_eq!(resp.stdout.trim(), "hello");
    }

    // The scripted tool reports the crate-wide tool VERSION.
    #[test]
    fn test_version() {
        let toolset = make_tools().build();
        let tools = toolset.tools();
        assert_eq!(tools[0].version(), VERSION);
    }

    // Input and output schemas expose the `commands` and `stdout` properties.
    #[test]
    fn test_schemas() {
        let toolset = make_tools().build();
        let tools = toolset.tools();
        let input = tools[0].input_schema();
        assert!(input["properties"]["commands"].is_object());
        let output = tools[0].output_schema();
        assert!(output["properties"]["stdout"].is_object());
    }

    // The discover tool shares the request schema with the scripted tool.
    #[test]
    fn test_discover_tool_schemas() {
        let toolset = make_tools().with_discovery().build();
        let tools = toolset.tools();
        let input = tools[1].input_schema();
        assert!(input["properties"]["commands"].is_object());
    }
}