use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::error::Result;
use crate::types::ToolCallId;
/// Description of a tool: its name, a human-readable description, and a JSON
/// value describing the parameters it accepts.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Tool name (e.g. `"read_file"`).
    pub name: String,
    /// Human-readable explanation of what the tool does.
    pub description: String,
    /// Parameter description; the constructors and tests in this file use
    /// JSON-Schema-style objects (`{"type": "object", "properties": …}`).
    pub parameters: serde_json::Value,
}

impl ToolDefinition {
    /// Builds a definition from a name, a description and an arbitrary
    /// parameter value.
    pub fn new(
        name: impl Into<String>,
        description: impl Into<String>,
        parameters: serde_json::Value,
    ) -> Self {
        let (name, description) = (name.into(), description.into());
        Self {
            name,
            description,
            parameters,
        }
    }

    /// Builds a definition whose parameters are an object schema with an
    /// empty `properties` map, i.e. a tool that takes no arguments.
    pub fn no_params(name: impl Into<String>, description: impl Into<String>) -> Self {
        let empty_object_schema = serde_json::json!({
            "type": "object",
            "properties": {}
        });
        Self::new(name, description, empty_object_schema)
    }
}
/// Result produced by a single tool execution.
///
/// When deserializing, only `content` is required: `title` falls back to an
/// empty string, `metadata` to an empty JSON object, and `is_error` to `false`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolOutput {
    /// Optional heading for the result; empty when not set.
    #[serde(default)]
    pub title: String,
    /// The textual body of the result.
    pub content: String,
    /// Free-form structured data attached to the result.
    #[serde(default = "default_metadata")]
    pub metadata: serde_json::Value,
    /// `true` when the output represents a failure.
    #[serde(default)]
    pub is_error: bool,
}

/// Default for [`ToolOutput::metadata`]: an empty JSON object.
fn default_metadata() -> serde_json::Value {
    serde_json::Value::from(serde_json::Map::new())
}

impl ToolOutput {
    /// Successful output with the given content, no title and empty metadata.
    pub fn success(content: impl Into<String>) -> Self {
        let content = content.into();
        Self {
            title: String::new(),
            content,
            metadata: default_metadata(),
            is_error: false,
        }
    }

    /// Successful output with both a title and content.
    pub fn success_with_title(title: impl Into<String>, content: impl Into<String>) -> Self {
        // Convert `title` first to keep the same conversion order as the
        // parameter list.
        let title = title.into();
        Self {
            title,
            ..Self::success(content)
        }
    }

    /// Error output with the given content, no title and empty metadata.
    pub fn error(content: impl Into<String>) -> Self {
        Self {
            is_error: true,
            ..Self::success(content)
        }
    }

    /// Replaces the metadata wholesale (it is not merged).
    pub fn with_metadata(self, metadata: serde_json::Value) -> Self {
        Self { metadata, ..self }
    }

    /// Replaces the title.
    pub fn with_title(self, title: impl Into<String>) -> Self {
        Self {
            title: title.into(),
            ..self
        }
    }
}
/// Policy describing whether tools may, must, or must not be called.
///
/// Serialized as an internally tagged object keyed by `"type"` with
/// snake_case variant names, e.g. `{"type":"auto"}` or
/// `{"type":"specific","name":"bash"}`.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolChoice {
    /// Default: tools are allowed but not forced.
    #[default]
    Auto,
    /// Tools are not allowed.
    None,
    /// Tools are forced, without naming a particular one.
    Required,
    /// The named tool is forced.
    Specific { name: String },
}

impl ToolChoice {
    /// Shorthand for [`ToolChoice::Specific`].
    pub fn specific(name: impl Into<String>) -> Self {
        let name = name.into();
        Self::Specific { name }
    }

    /// `false` only for [`ToolChoice::None`].
    pub fn allows_tools(&self) -> bool {
        match self {
            Self::None => false,
            Self::Auto | Self::Required | Self::Specific { .. } => true,
        }
    }

    /// The forced tool's name for [`ToolChoice::Specific`], otherwise `None`.
    pub fn required_tool(&self) -> Option<&str> {
        // `Option::None` is spelled out so it is not mistaken for `Self::None`.
        if let Self::Specific { name } = self {
            Some(name.as_str())
        } else {
            Option::None
        }
    }

    /// `true` for [`ToolChoice::Required`] and [`ToolChoice::Specific`].
    pub fn is_forced(&self) -> bool {
        // Exhaustive on purpose: a new variant must decide explicitly.
        match self {
            Self::Auto | Self::None => false,
            Self::Required | Self::Specific { .. } => true,
        }
    }
}
/// Cloneable cancellation flag.
///
/// Every clone points at the same `AtomicBool`, so calling [`cancel`] on any
/// clone is visible through all of them. Nothing in this type ever resets
/// the flag to `false`.
///
/// [`cancel`]: CancellationToken::cancel
// `Arc<AtomicBool>::default()` is `Arc::new(AtomicBool::new(false))`, so the
// derived `Default` yields an un-cancelled token.
#[derive(Debug, Clone, Default)]
pub struct CancellationToken(Arc<AtomicBool>);

impl CancellationToken {
    /// A fresh, un-cancelled token with its own flag.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the shared flag. Calling it again is harmless.
    pub fn cancel(&self) {
        // Release pairs with the Acquire load in `is_cancelled`.
        self.0.store(true, Ordering::Release);
    }

    /// Whether `cancel` has been called on this token or any of its clones.
    pub fn is_cancelled(&self) -> bool {
        self.0.load(Ordering::Acquire)
    }
}
/// Scheduling mode a tool reports through `Tool::concurrency_mode`.
///
/// Serialized as a bare snake_case string: `"shared"` or `"exclusive"`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ConcurrencyMode {
    /// The default, and what `Tool::concurrency_mode` returns unless overridden.
    #[default]
    Shared,
    /// Opt-in mode. Presumably it means the tool must not run alongside
    /// others; confirm against the scheduler that consumes it.
    Exclusive,
}
/// Per-call context passed to `Tool::validate` and `Tool::execute`.
pub struct ToolCallContext {
    /// Identifier of the tool call this context belongs to.
    pub call_id: ToolCallId,
    /// Flag a tool can poll via `is_cancelled()` to stop early.
    pub cancellation: CancellationToken,
    /// Arbitrary extra data; an empty JSON object unless replaced.
    pub extra: serde_json::Value,
}

impl ToolCallContext {
    /// Context with a fresh (un-cancelled) token and an empty `extra` object.
    pub fn new(call_id: ToolCallId) -> Self {
        let cancellation = CancellationToken::new();
        let extra = serde_json::json!({});
        Self {
            call_id,
            cancellation,
            extra,
        }
    }

    /// Replaces the cancellation token, e.g. with a clone the caller keeps.
    pub fn with_cancellation(self, token: CancellationToken) -> Self {
        Self {
            cancellation: token,
            ..self
        }
    }

    /// Replaces the `extra` value wholesale (it is not merged).
    pub fn with_extra(self, extra: serde_json::Value) -> Self {
        Self { extra, ..self }
    }
}
/// A capability that can be invoked with JSON arguments.
///
/// Implementors supply [`Tool::definition`] and [`Tool::execute`]; every
/// other method has a default implementation.
#[async_trait]
pub trait Tool: Send + Sync {
    /// The tool's name, description and parameter description.
    fn definition(&self) -> &ToolDefinition;

    /// Pre-execution check on `args`. The default accepts any arguments.
    async fn validate(&self, _args: &serde_json::Value, _ctx: &ToolCallContext) -> Result<()> {
        Ok(())
    }

    /// Runs the tool with the given arguments and call context.
    async fn execute(&self, args: serde_json::Value, ctx: &ToolCallContext) -> Result<ToolOutput>;

    /// Scheduling mode for this tool; defaults to [`ConcurrencyMode::Shared`].
    fn concurrency_mode(&self) -> ConcurrencyMode {
        ConcurrencyMode::Shared
    }

    /// Key this tool reports for permission handling; defaults to the
    /// definition's `name`.
    fn permission_key(&self) -> &str {
        self.definition().name.as_str()
    }

    /// Per-call permission decision. The default returns `Passthrough`.
    fn check_permissions(
        &self,
        _args: &serde_json::Value,
        _ctx: &ToolCallContext,
    ) -> crate::permission::PermissionResult {
        crate::permission::PermissionResult::Passthrough
    }

    /// Optional permission request for this call. The default returns `None`.
    fn permission_request(
        &self,
        _args: &serde_json::Value,
        _ctx: &ToolCallContext,
    ) -> Option<crate::permission::PermissionRequest> {
        Option::None
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_tool_definition_new() {
        let def = ToolDefinition::new(
            "read_file",
            "Read file contents",
            json!({
                "type": "object",
                "properties": {
                    "path": { "type": "string" }
                },
                "required": ["path"]
            }),
        );
        assert_eq!(def.name, "read_file");
        assert_eq!(def.description, "Read file contents");
        // `assert_eq!` reports both sides on failure, unlike `assert!(a == b)`.
        assert_eq!(def.parameters["properties"]["path"]["type"], "string");
    }

    #[test]
    fn test_tool_definition_no_params() {
        let def = ToolDefinition::no_params("get_time", "Get current time");
        assert_eq!(def.name, "get_time");
        assert_eq!(def.parameters["type"], "object");
        assert!(def.parameters["properties"]
            .as_object()
            .unwrap()
            .is_empty());
    }

    #[test]
    fn test_tool_definition_serde_roundtrip() {
        let def = ToolDefinition::new(
            "bash",
            "Run a shell command",
            json!({
                "type": "object",
                "properties": {
                    "command": { "type": "string" }
                },
                "required": ["command"]
            }),
        );
        let json_str = serde_json::to_string(&def).unwrap();
        let restored: ToolDefinition = serde_json::from_str(&json_str).unwrap();
        assert_eq!(def, restored);
    }

    #[test]
    fn test_tool_output_success() {
        let out = ToolOutput::success("hello world");
        assert_eq!(out.content, "hello world");
        assert!(!out.is_error);
        assert!(out.title.is_empty());
    }

    #[test]
    fn test_tool_output_success_with_title() {
        let out = ToolOutput::success_with_title("Read file", "file contents");
        assert_eq!(out.title, "Read file");
        assert_eq!(out.content, "file contents");
        assert!(!out.is_error);
    }

    #[test]
    fn test_tool_output_error() {
        let out = ToolOutput::error("not found");
        assert_eq!(out.content, "not found");
        assert!(out.is_error);
    }

    #[test]
    fn test_tool_output_builder() {
        let out = ToolOutput::success("ok")
            .with_title("Done")
            .with_metadata(json!({"elapsed_ms": 42}));
        assert_eq!(out.title, "Done");
        assert_eq!(out.metadata["elapsed_ms"], 42);
        assert!(!out.is_error);
    }

    #[test]
    fn test_tool_output_serde_roundtrip() {
        let out = ToolOutput::success_with_title("Read", "contents")
            .with_metadata(json!({"lines": 100}));
        let json_str = serde_json::to_string(&out).unwrap();
        let restored: ToolOutput = serde_json::from_str(&json_str).unwrap();
        assert_eq!(out, restored);
    }

    #[test]
    fn test_tool_output_serde_defaults() {
        let json_str = r#"{"content":"hello"}"#;
        let out: ToolOutput = serde_json::from_str(json_str).unwrap();
        assert_eq!(out.content, "hello");
        assert!(!out.is_error);
        assert!(out.title.is_empty());
        assert!(out.metadata.is_object());
    }

    #[test]
    fn test_tool_choice_auto_default() {
        let choice = ToolChoice::default();
        assert_eq!(choice, ToolChoice::Auto);
    }

    #[test]
    fn test_tool_choice_allows_tools() {
        assert!(ToolChoice::Auto.allows_tools());
        assert!(!ToolChoice::None.allows_tools());
        assert!(ToolChoice::Required.allows_tools());
        assert!(ToolChoice::specific("bash").allows_tools());
    }

    #[test]
    fn test_tool_choice_required_tool() {
        assert_eq!(ToolChoice::Auto.required_tool(), Option::None);
        assert_eq!(ToolChoice::None.required_tool(), Option::None);
        assert_eq!(ToolChoice::Required.required_tool(), Option::None);
        assert_eq!(ToolChoice::specific("bash").required_tool(), Some("bash"));
    }

    #[test]
    fn test_tool_choice_is_forced() {
        assert!(!ToolChoice::Auto.is_forced());
        assert!(!ToolChoice::None.is_forced());
        assert!(ToolChoice::Required.is_forced());
        assert!(ToolChoice::specific("bash").is_forced());
    }

    #[test]
    fn test_tool_choice_serde_roundtrip() {
        for choice in [
            ToolChoice::Auto,
            ToolChoice::None,
            ToolChoice::Required,
            ToolChoice::specific("read_file"),
        ] {
            let json_str = serde_json::to_string(&choice).unwrap();
            let restored: ToolChoice = serde_json::from_str(&json_str).unwrap();
            assert_eq!(choice, restored);
        }
    }

    #[test]
    fn test_tool_choice_serde_format() {
        let json_str = serde_json::to_string(&ToolChoice::Auto).unwrap();
        assert!(json_str.contains(r#""type":"auto""#));
        let json_str = serde_json::to_string(&ToolChoice::specific("bash")).unwrap();
        assert!(json_str.contains(r#""type":"specific""#));
        assert!(json_str.contains(r#""name":"bash""#));
    }

    /// Decoding from hand-written wire JSON, not only from our own output.
    #[test]
    fn test_tool_choice_deserialize_wire_format() {
        let choice: ToolChoice =
            serde_json::from_str(r#"{"type":"specific","name":"bash"}"#).unwrap();
        assert_eq!(choice, ToolChoice::specific("bash"));
        let choice: ToolChoice = serde_json::from_str(r#"{"type":"none"}"#).unwrap();
        assert_eq!(choice, ToolChoice::None);
    }

    #[test]
    fn test_cancellation_token_new() {
        let token = CancellationToken::new();
        assert!(!token.is_cancelled());
    }

    #[test]
    fn test_cancellation_token_cancel() {
        let token = CancellationToken::new();
        token.cancel();
        assert!(token.is_cancelled());
    }

    #[test]
    fn test_cancellation_token_clone_shares_state() {
        let token = CancellationToken::new();
        let token2 = token.clone();
        assert!(!token.is_cancelled());
        assert!(!token2.is_cancelled());
        token.cancel();
        assert!(token.is_cancelled());
        assert!(token2.is_cancelled());
    }

    #[test]
    fn test_cancellation_token_default() {
        let token = CancellationToken::default();
        assert!(!token.is_cancelled());
    }

    #[test]
    fn test_concurrency_mode_default() {
        assert_eq!(ConcurrencyMode::default(), ConcurrencyMode::Shared);
    }

    #[test]
    fn test_concurrency_mode_serde_roundtrip() {
        for mode in [ConcurrencyMode::Shared, ConcurrencyMode::Exclusive] {
            let json_str = serde_json::to_string(&mode).unwrap();
            let restored: ConcurrencyMode = serde_json::from_str(&json_str).unwrap();
            assert_eq!(mode, restored);
        }
    }

    #[test]
    fn test_concurrency_mode_serde_format() {
        assert_eq!(
            serde_json::to_string(&ConcurrencyMode::Shared).unwrap(),
            r#""shared""#
        );
        assert_eq!(
            serde_json::to_string(&ConcurrencyMode::Exclusive).unwrap(),
            r#""exclusive""#
        );
    }

    #[test]
    fn test_tool_call_context_new() {
        let ctx = ToolCallContext::new(ToolCallId::new("call_1"));
        assert_eq!(ctx.call_id.as_str(), "call_1");
        assert!(!ctx.cancellation.is_cancelled());
        assert!(ctx.extra.is_object());
    }

    #[test]
    fn test_tool_call_context_builder() {
        let token = CancellationToken::new();
        let ctx = ToolCallContext::new(ToolCallId::new("call_2"))
            .with_cancellation(token.clone())
            .with_extra(json!({"cwd": "/tmp", "env": {"DEBUG": "1"}}));
        assert_eq!(ctx.call_id.as_str(), "call_2");
        assert_eq!(ctx.extra["cwd"], "/tmp");
        assert_eq!(ctx.extra["env"]["DEBUG"], "1");
        token.cancel();
        assert!(ctx.cancellation.is_cancelled());
    }

    /// Minimal tool relying on every default trait method.
    struct EchoTool;

    static ECHO_DEF: std::sync::LazyLock<ToolDefinition> = std::sync::LazyLock::new(|| {
        ToolDefinition::new(
            "echo",
            "Echoes the input message",
            json!({
                "type": "object",
                "properties": {
                    "message": { "type": "string" }
                },
                "required": ["message"]
            }),
        )
    });

    #[async_trait]
    impl Tool for EchoTool {
        fn definition(&self) -> &ToolDefinition {
            &ECHO_DEF
        }

        async fn execute(
            &self,
            args: serde_json::Value,
            _ctx: &ToolCallContext,
        ) -> Result<ToolOutput> {
            let message = args["message"].as_str().unwrap_or("(no message)");
            Ok(ToolOutput::success(message))
        }
    }

    /// Tool overriding `validate`, `concurrency_mode`, and honoring cancellation.
    struct ExclusiveTool;

    static EXCLUSIVE_DEF: std::sync::LazyLock<ToolDefinition> = std::sync::LazyLock::new(|| {
        ToolDefinition::new(
            "write_file",
            "Write content to a file",
            json!({
                "type": "object",
                "properties": {
                    "path": { "type": "string" },
                    "content": { "type": "string" }
                },
                "required": ["path", "content"]
            }),
        )
    });

    #[async_trait]
    impl Tool for ExclusiveTool {
        fn definition(&self) -> &ToolDefinition {
            &EXCLUSIVE_DEF
        }

        // `ctx` is used below, so it must not carry the "unused" `_` prefix.
        async fn validate(&self, args: &serde_json::Value, ctx: &ToolCallContext) -> Result<()> {
            let path = args["path"].as_str().unwrap_or("");
            if path.starts_with("/etc/") {
                return Err(crate::Error::tool(
                    "write_file",
                    ctx.call_id.clone(),
                    "cannot write to /etc/",
                ));
            }
            Ok(())
        }

        async fn execute(
            &self,
            args: serde_json::Value,
            ctx: &ToolCallContext,
        ) -> Result<ToolOutput> {
            if ctx.cancellation.is_cancelled() {
                return Err(crate::Error::Cancelled);
            }
            let path = args["path"].as_str().unwrap_or("?");
            Ok(ToolOutput::success(format!("wrote to {path}")).with_title(format!("Write: {path}")))
        }

        fn concurrency_mode(&self) -> ConcurrencyMode {
            ConcurrencyMode::Exclusive
        }
    }

    #[tokio::test]
    async fn test_tool_echo_execute() {
        let tool = EchoTool;
        let ctx = ToolCallContext::new(ToolCallId::new("c1"));
        let result = tool
            .execute(json!({"message": "hello"}), &ctx)
            .await
            .unwrap();
        assert_eq!(result.content, "hello");
        assert!(!result.is_error);
    }

    // Purely synchronous: no async runtime needed.
    #[test]
    fn test_tool_echo_default_concurrency() {
        let tool = EchoTool;
        assert_eq!(tool.concurrency_mode(), ConcurrencyMode::Shared);
    }

    #[tokio::test]
    async fn test_tool_echo_default_validate() {
        let tool = EchoTool;
        let ctx = ToolCallContext::new(ToolCallId::new("c1"));
        assert!(tool.validate(&json!({}), &ctx).await.is_ok());
    }

    #[test]
    fn test_tool_definition_matches() {
        let tool = EchoTool;
        assert_eq!(tool.definition().name, "echo");
        assert!(!tool.definition().description.is_empty());
    }

    /// Default permission hooks: key is the tool name, no request is produced.
    #[test]
    fn test_tool_default_permission_hooks() {
        let tool = EchoTool;
        let ctx = ToolCallContext::new(ToolCallId::new("c7"));
        let args = json!({"message": "hi"});
        assert_eq!(tool.permission_key(), "echo");
        assert!(tool.permission_request(&args, &ctx).is_none());
    }

    #[test]
    fn test_tool_exclusive_concurrency() {
        let tool = ExclusiveTool;
        assert_eq!(tool.concurrency_mode(), ConcurrencyMode::Exclusive);
    }

    #[tokio::test]
    async fn test_tool_validate_rejects_invalid() {
        let tool = ExclusiveTool;
        let ctx = ToolCallContext::new(ToolCallId::new("c2"));
        let result = tool
            .validate(&json!({"path": "/etc/shadow", "content": "x"}), &ctx)
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_tool_validate_accepts_valid() {
        let tool = ExclusiveTool;
        let ctx = ToolCallContext::new(ToolCallId::new("c3"));
        let result = tool
            .validate(&json!({"path": "/tmp/test.txt", "content": "x"}), &ctx)
            .await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_tool_execute_with_cancellation() {
        let tool = ExclusiveTool;
        let token = CancellationToken::new();
        let ctx = ToolCallContext::new(ToolCallId::new("c4")).with_cancellation(token.clone());
        let result = tool
            .execute(json!({"path": "/tmp/a.txt", "content": "hi"}), &ctx)
            .await
            .unwrap();
        assert_eq!(result.content, "wrote to /tmp/a.txt");
        token.cancel();
        let ctx2 = ToolCallContext::new(ToolCallId::new("c5")).with_cancellation(token.clone());
        let result = tool
            .execute(json!({"path": "/tmp/b.txt", "content": "hi"}), &ctx2)
            .await;
        assert!(matches!(result, Err(crate::Error::Cancelled)));
    }

    #[tokio::test]
    async fn test_tool_dyn_dispatch() {
        let tool: Arc<dyn Tool> = Arc::new(EchoTool);
        let ctx = ToolCallContext::new(ToolCallId::new("c6"));
        let result = tool
            .execute(json!({"message": "dynamic"}), &ctx)
            .await
            .unwrap();
        assert_eq!(result.content, "dynamic");
    }
}