use crate::authority::{EventAuthority, LocalEventAuthority};
use crate::seed::{HostDependencies, ToolContextSeed};
use crate::stores::EventStore;
use agent_sdk_foundation::events::AgentEvent;
use agent_sdk_foundation::llm;
use agent_sdk_foundation::types::{ToolOutcome, ToolResult, ToolTier};
use anyhow::Result;
use async_trait::async_trait;
use futures::Stream;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use serde_json::Value;
use std::collections::HashMap;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::Arc;
use time::OffsetDateTime;
use tokio_util::sync::CancellationToken;
/// Marker trait for types that can name a tool.
///
/// A name's string form is produced by serializing it with `serde_json`
/// (see `tool_name_to_string` / `tool_name_from_str`), so implementors are
/// expected to serialize as a single JSON string — for example a unit-only enum
/// or a `#[serde(transparent)]` newtype over `String`.
pub trait ToolName: Send + Sync + Serialize + DeserializeOwned + 'static {}
#[must_use]
pub fn tool_name_to_string<N: ToolName>(name: &N) -> String {
serde_json::to_string(name)
.unwrap_or_else(|_| "\"<unknown_tool>\"".to_string())
.trim_matches('"')
.to_string()
}
pub fn tool_name_from_str<N: ToolName>(s: &str) -> Result<N, serde_json::Error> {
serde_json::from_str(&format!("\"{s}\""))
}
/// Names of the built-in primitive tools.
///
/// Serialized in `snake_case` (e.g. `MultiEdit` <-> `"multi_edit"`,
/// `LinkFetch` <-> `"link_fetch"`); that string is what `tool_name_to_string`
/// produces for these names.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PrimitiveToolName {
Read,
Write,
Edit,
MultiEdit,
Bash,
Glob,
Grep,
NotebookRead,
NotebookEdit,
TodoRead,
TodoWrite,
AskUser,
LinkFetch,
WebSearch,
}
// All variants are unit variants, so each serializes as a plain JSON string.
impl ToolName for PrimitiveToolName {}
/// Tool name backed by an arbitrary runtime string.
///
/// `#[serde(transparent)]` makes it serialize as the bare string. The
/// `TypedToolAdapter` and `SimpleToolAdapter` use it to name tools that only
/// provide a `&'static str` name.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub struct DynamicToolName(String);

impl DynamicToolName {
    /// Creates a name from anything convertible into a `String`.
    #[must_use]
    pub fn new(name: impl Into<String>) -> Self {
        let owned: String = name.into();
        Self(owned)
    }

    /// Borrows the underlying name string.
    #[must_use]
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}

impl ToolName for DynamicToolName {}
/// Marker trait for the stage type carried by `ToolStatus::Progress`.
///
/// Stages are turned into the plain `String` of `ErasedToolStatus::Progress`
/// by `stage_to_string`, so implementors are expected to serialize as a JSON
/// string (typically a unit-only enum).
pub trait ProgressStage: Clone + Send + Sync + Serialize + DeserializeOwned + 'static {}
#[must_use]
pub fn stage_to_string<S: ProgressStage>(stage: &S) -> String {
serde_json::to_string(stage)
.expect("ProgressStage must serialize to string")
.trim_matches('"')
.to_string()
}
/// A status update for an operation started by an `AsyncTool`, yielded by
/// `AsyncTool::check_status`.
///
/// `S` is the tool's own progress-stage type. Converting into
/// [`ErasedToolStatus`] (via `From`) renders the stage with `stage_to_string`.
#[derive(Clone, Debug, Serialize)]
pub enum ToolStatus<S: ProgressStage> {
/// Intermediate progress: the current stage, a message and an optional JSON
/// payload.
Progress {
stage: S,
message: String,
data: Option<serde_json::Value>,
},
/// The operation completed; carries its result.
Completed(ToolResult),
/// The operation failed; carries its result.
Failed(ToolResult),
}
/// Type-erased counterpart of [`ToolStatus`], with the stage rendered as a
/// `String`.
///
/// Produced by `ErasedAsyncTool::check_status_stream`, which maps each typed
/// status through `From<ToolStatus<S>>`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum ErasedToolStatus {
/// Intermediate progress; `stage` is the output of `stage_to_string`.
Progress {
stage: String,
message: String,
data: Option<serde_json::Value>,
},
/// The operation completed; carries its result.
Completed(ToolResult),
/// The operation failed; carries its result.
Failed(ToolResult),
}
/// An update emitted by a listen/execute tool's `listen` stream.
///
/// `operation_id` identifies the operation later passed to
/// `ListenExecuteTool::execute` / `cancel`. `revision` is what callers hand
/// back as `expected_revision` when executing.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum ListenToolUpdate {
    /// The operation is still being prepared; `snapshot` is optional.
    Listening {
        operation_id: String,
        revision: u64,
        message: String,
        snapshot: Option<serde_json::Value>,
        // `default` is required: a custom `with` module disables serde's
        // "missing `Option` field means `None`" rule, so without it a payload
        // that omits `expires_at` would fail to deserialize.
        #[serde(default, with = "time::serde::rfc3339::option")]
        expires_at: Option<OffsetDateTime>,
    },
    /// A snapshot is available at this revision.
    Ready {
        operation_id: String,
        revision: u64,
        message: String,
        snapshot: serde_json::Value,
        // See `Listening::expires_at` for why `default` is needed.
        #[serde(default, with = "time::serde::rfc3339::option")]
        expires_at: Option<OffsetDateTime>,
    },
    /// The operation was invalidated; `recoverable` says whether it can be
    /// retried (NOTE(review): retry semantics are defined by the consumer).
    Invalidated {
        operation_id: String,
        message: String,
        recoverable: bool,
    },
}
/// Why a listen/execute operation is being stopped; passed to
/// `ListenExecuteTool::cancel` / `ErasedListenTool::cancel`.
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub enum ListenStopReason {
UserRejected,
Blocked,
StreamDisconnected,
StreamEnded,
}
/// Erases the stage type by rendering it with [`stage_to_string`]; the
/// `Completed` / `Failed` results pass through untouched.
impl<S: ProgressStage> From<ToolStatus<S>> for ErasedToolStatus {
    fn from(status: ToolStatus<S>) -> Self {
        match status {
            ToolStatus::Completed(result) => Self::Completed(result),
            ToolStatus::Failed(result) => Self::Failed(result),
            ToolStatus::Progress { stage, message, data } => {
                let stage = stage_to_string(&stage);
                Self::Progress { stage, message, data }
            }
        }
    }
}
/// Context handed to every tool's `execute` (and the async/listen variants).
///
/// Carries the host application value (`app`), free-form `metadata`, and
/// optional wiring for event emission, cancellation, a sub-agent semaphore and
/// a tool timeout. All optional pieces start unset in `ToolContext::new`.
#[derive(Clone)]
pub struct ToolContext<Ctx> {
/// Application-specific value made available to tools.
pub app: Ctx,
/// Arbitrary key/value metadata attached to this context.
pub metadata: HashMap<String, Value>,
// Event-emission wiring: `emit_event` only writes when store, thread id,
// turn and authority are all set.
event_store: Option<Arc<dyn EventStore>>,
event_thread_id: Option<agent_sdk_foundation::types::ThreadId>,
event_turn: Option<usize>,
event_authority: Option<Arc<dyn EventAuthority>>,
cancel_token: Option<CancellationToken>,
subagent_semaphore: Option<Arc<tokio::sync::Semaphore>>,
tool_timeout: Option<std::time::Duration>,
}
impl<Ctx> ToolContext<Ctx> {
    /// Builds a bare context around `app`.
    ///
    /// Metadata starts empty and every optional capability (event store,
    /// cancellation, sub-agent semaphore, timeout) is unset, so `emit_event`
    /// does nothing until `with_event_store` is called.
    #[must_use]
    pub fn new(app: Ctx) -> Self {
        Self {
            app,
            metadata: HashMap::new(),
            event_store: None,
            event_thread_id: None,
            event_turn: None,
            event_authority: None,
            cancel_token: None,
            subagent_semaphore: None,
            tool_timeout: None,
        }
    }

    /// Builds a fully wired context from a [`ToolContextSeed`] plus
    /// host-provided dependencies.
    ///
    /// A fresh `LocalEventAuthority` starting at `seed.sequence_offset` is
    /// created. Metadata, thread id and turn come from the seed. The event
    /// store, cancel token and optional sub-agent semaphore come from `deps`.
    /// No tool timeout is set.
    #[must_use]
    pub fn from_seed(seed: &ToolContextSeed, app: Ctx, deps: HostDependencies) -> Self {
        let authority = Arc::new(LocalEventAuthority::with_offset(seed.sequence_offset));
        let mut ctx = Self::new(app)
            .with_event_store(deps.event_store, seed.thread_id.clone(), seed.turn, authority)
            .with_cancel_token(deps.cancel_token);
        ctx.metadata = seed.metadata.clone();
        ctx.subagent_semaphore = deps.subagent_semaphore;
        ctx
    }

    /// Inserts (or overwrites) one metadata entry, builder-style.
    #[must_use]
    pub fn with_metadata(mut self, key: impl Into<String>, value: Value) -> Self {
        let key: String = key.into();
        self.metadata.insert(key, value);
        self
    }

    /// Wires event emission: `emit_event` will wrap events with `authority`
    /// and append them to `store` under `thread_id` / `turn`.
    #[must_use]
    pub fn with_event_store(
        mut self,
        store: Arc<dyn EventStore>,
        thread_id: agent_sdk_foundation::types::ThreadId,
        turn: usize,
        authority: Arc<dyn EventAuthority>,
    ) -> Self {
        self.event_authority = Some(authority);
        self.event_turn = Some(turn);
        self.event_thread_id = Some(thread_id);
        self.event_store = Some(store);
        self
    }

    /// Appends `event` to the configured event store, if one is wired.
    ///
    /// The event is wrapped by the event authority and appended under this
    /// context's thread id and turn. If any of store, authority, thread id or
    /// turn is missing, the event is dropped and `Ok(())` is returned.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the store's `append`.
    pub async fn emit_event(&self, event: AgentEvent) -> Result<()>
    where
        Ctx: Sync,
    {
        let (Some(store), Some(authority), Some(thread_id), Some(turn)) = (
            self.event_store.as_ref(),
            self.event_authority.as_ref(),
            self.event_thread_id.as_ref(),
            self.event_turn,
        ) else {
            // Event emission is optional wiring; an unwired context is a no-op.
            return Ok(());
        };
        store.append(thread_id, turn, authority.wrap(event)).await?;
        Ok(())
    }

    /// Returns a shared handle to the event authority, if one is wired.
    #[must_use]
    pub fn event_authority(&self) -> Option<Arc<dyn EventAuthority>> {
        self.event_authority.as_ref().map(Arc::clone)
    }

    /// Attaches a cancellation token, builder-style.
    #[must_use]
    pub fn with_cancel_token(mut self, token: CancellationToken) -> Self {
        self.cancel_token = Some(token);
        self
    }

    /// Returns a clone of the cancellation token, if one is attached.
    #[must_use]
    pub fn cancel_token(&self) -> Option<CancellationToken> {
        self.cancel_token.as_ref().cloned()
    }

    /// Sets the tool timeout, builder-style.
    #[must_use]
    pub const fn with_tool_timeout(mut self, timeout: std::time::Duration) -> Self {
        self.tool_timeout = Some(timeout);
        self
    }

    /// Returns the configured tool timeout, if any.
    #[must_use]
    pub const fn tool_timeout(&self) -> Option<std::time::Duration> {
        self.tool_timeout
    }

    /// Attaches the semaphore limiting concurrent sub-agents, builder-style.
    #[must_use]
    pub fn with_subagent_semaphore(mut self, semaphore: Arc<tokio::sync::Semaphore>) -> Self {
        self.subagent_semaphore = Some(semaphore);
        self
    }

    /// Returns a shared handle to the sub-agent semaphore, if one is attached.
    #[must_use]
    pub fn subagent_semaphore(&self) -> Option<Arc<tokio::sync::Semaphore>> {
        self.subagent_semaphore.as_ref().map(Arc::clone)
    }
}
/// A tool with a strongly typed name that receives its input as raw JSON.
///
/// Registered with `ToolRegistry::register`. Implementations interpret `input`
/// themselves; see [`TypedTool`] for automatic deserialization/validation.
pub trait Tool<Ctx>: Send + Sync {
/// Typed name; its string form (via `tool_name_to_string`) is the registry key.
type Name: ToolName;
/// Returns this tool's name.
fn name(&self) -> Self::Name;
/// Label copied into `llm::Tool::display_name`.
fn display_name(&self) -> &'static str;
/// Text copied into `llm::Tool::description`.
fn description(&self) -> &'static str;
/// Input schema copied into `llm::Tool::input_schema`.
fn input_schema(&self) -> Value;
/// Tier copied into `llm::Tool::tier`; defaults to [`ToolTier::Observe`].
fn tier(&self) -> ToolTier {
ToolTier::Observe
}
/// Runs the tool with the given context and raw JSON input.
fn execute(
&self,
ctx: &ToolContext<Ctx>,
input: Value,
) -> impl Future<Output = Result<ToolResult>> + Send;
}
/// A tool whose raw JSON input is deserialized into `Self::Input` before
/// `execute` is called.
///
/// Wrap with `TypedToolAdapter` (or use `ToolRegistry::register_typed`).
/// Input that fails to deserialize is returned as an error `ToolResult`
/// (see `validate_tool_input`) and `execute` is not called.
pub trait TypedTool<Ctx>: Send + Sync {
/// Argument type; `Value` accepts any JSON unchanged.
type Input: DeserializeOwned + Serialize + Send + 'static;
/// Tool name, wrapped in a `DynamicToolName` by the adapter.
fn name(&self) -> &'static str;
/// Display label; defaults to the empty string.
fn display_name(&self) -> &'static str {
""
}
fn description(&self) -> &'static str;
fn input_schema(&self) -> Value;
/// Defaults to [`ToolTier::Observe`].
fn tier(&self) -> ToolTier {
ToolTier::Observe
}
/// Runs the tool with already-deserialized input.
fn execute(
&self,
ctx: &ToolContext<Ctx>,
input: Self::Input,
) -> impl Future<Output = Result<ToolResult>> + Send;
}
#[must_use]
pub fn invalid_tool_input_result(tool_name: &str, error: &serde_json::Error) -> ToolResult {
ToolResult::error(format!(
"Invalid arguments for tool `{tool_name}`: {error}. \
The arguments did not match the tool's input schema — \
re-read the schema and call the tool again with corrected arguments."
))
}
pub fn validate_tool_input<Input>(tool_name: &str, raw: Value) -> Result<Input, ToolResult>
where
Input: DeserializeOwned,
{
serde_json::from_value(raw).map_err(|error| invalid_tool_input_result(tool_name, &error))
}
/// Adapts a [`TypedTool`] into a [`Tool`] whose name is a [`DynamicToolName`].
///
/// The adapter's `Tool::execute` validates the raw JSON against
/// `T::Input` first and only calls the inner tool on success.
pub struct TypedToolAdapter<T> {
    inner: T,
}

impl<T> TypedToolAdapter<T> {
    /// Wraps `tool` in the adapter.
    #[must_use]
    pub const fn new(tool: T) -> Self {
        Self { inner: tool }
    }

    /// Consumes the adapter and returns the wrapped tool.
    #[must_use]
    pub fn into_inner(self) -> T {
        self.inner
    }
}
/// Exposes a [`TypedTool`] through the untyped [`Tool`] interface.
///
/// Metadata methods forward to the inner tool. `execute` first deserializes
/// the raw JSON into `T::Input`. On failure it returns the validation error as
/// an `Ok` result so the model can self-correct, and the inner tool never runs.
impl<Ctx, T> Tool<Ctx> for TypedToolAdapter<T>
where
    T: TypedTool<Ctx>,
    Ctx: Send + Sync,
{
    type Name = DynamicToolName;

    fn name(&self) -> DynamicToolName {
        let raw: &'static str = <T as TypedTool<Ctx>>::name(&self.inner);
        DynamicToolName::new(raw)
    }

    fn display_name(&self) -> &'static str {
        <T as TypedTool<Ctx>>::display_name(&self.inner)
    }

    fn description(&self) -> &'static str {
        <T as TypedTool<Ctx>>::description(&self.inner)
    }

    fn input_schema(&self) -> Value {
        <T as TypedTool<Ctx>>::input_schema(&self.inner)
    }

    fn tier(&self) -> ToolTier {
        <T as TypedTool<Ctx>>::tier(&self.inner)
    }

    async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult> {
        let tool_name = <T as TypedTool<Ctx>>::name(&self.inner);
        let typed = match validate_tool_input::<<T as TypedTool<Ctx>>::Input>(tool_name, input) {
            Ok(parsed) => parsed,
            // Bad arguments are reported back as a tool result, not an `Err`.
            Err(rejection) => return Ok(rejection),
        };
        <T as TypedTool<Ctx>>::execute(&self.inner, ctx, typed).await
    }
}
/// Execution logic over an already-typed input, without tool metadata.
///
/// NOTE(review): no adapter or registry method in this file consumes
/// `ToolLogic`; confirm whether it is used elsewhere.
pub trait ToolLogic<Ctx>: Send + Sync {
type Input;
/// Runs the logic with the given context and typed input.
fn execute(
&self,
ctx: &ToolContext<Ctx>,
input: Self::Input,
) -> impl Future<Output = Result<ToolResult>> + Send;
}
/// A tool named by a plain `&'static str` that receives raw JSON input.
///
/// Wrap with `SimpleToolAdapter` (or use `ToolRegistry::register_simple`);
/// unlike [`TypedTool`], no input validation is performed before `execute`.
pub trait SimpleTool<Ctx>: Send + Sync {
/// Tool name, wrapped in a `DynamicToolName` by the adapter.
fn name(&self) -> &'static str;
/// Display label; defaults to the empty string.
fn display_name(&self) -> &'static str {
""
}
fn description(&self) -> &'static str;
fn input_schema(&self) -> Value;
/// Defaults to [`ToolTier::Observe`].
fn tier(&self) -> ToolTier {
ToolTier::Observe
}
/// Runs the tool with the given context and raw JSON input.
fn execute(
&self,
ctx: &ToolContext<Ctx>,
input: Value,
) -> impl Future<Output = Result<ToolResult>> + Send;
}
/// Adapts a [`SimpleTool`] into a [`Tool`] whose name is a [`DynamicToolName`].
///
/// All calls, including `execute`, are forwarded unchanged to the inner tool.
pub struct SimpleToolAdapter<T> {
    inner: T,
}

impl<T> SimpleToolAdapter<T> {
    /// Wraps `tool` in the adapter.
    #[must_use]
    pub const fn new(tool: T) -> Self {
        Self { inner: tool }
    }

    /// Consumes the adapter and returns the wrapped tool.
    #[must_use]
    pub fn into_inner(self) -> T {
        self.inner
    }
}
/// Exposes a [`SimpleTool`] through the [`Tool`] interface by pure delegation;
/// the `&'static str` name becomes a [`DynamicToolName`].
impl<Ctx, T> Tool<Ctx> for SimpleToolAdapter<T>
where
    T: SimpleTool<Ctx>,
{
    type Name = DynamicToolName;

    fn name(&self) -> DynamicToolName {
        let raw: &'static str = <T as SimpleTool<Ctx>>::name(&self.inner);
        DynamicToolName::new(raw)
    }

    fn display_name(&self) -> &'static str {
        <T as SimpleTool<Ctx>>::display_name(&self.inner)
    }

    fn description(&self) -> &'static str {
        <T as SimpleTool<Ctx>>::description(&self.inner)
    }

    fn input_schema(&self) -> Value {
        <T as SimpleTool<Ctx>>::input_schema(&self.inner)
    }

    fn tier(&self) -> ToolTier {
        <T as SimpleTool<Ctx>>::tier(&self.inner)
    }

    fn execute(
        &self,
        ctx: &ToolContext<Ctx>,
        input: Value,
    ) -> impl Future<Output = Result<ToolResult>> + Send {
        // Hand back the inner future directly; nothing to do before or after it.
        <T as SimpleTool<Ctx>>::execute(&self.inner, ctx, input)
    }
}
/// A tool whose `execute` returns a [`ToolOutcome`] and whose progress can be
/// followed through `check_status`.
///
/// Registered with `ToolRegistry::register_async`.
pub trait AsyncTool<Ctx>: Send + Sync {
/// Typed name; its string form (via `tool_name_to_string`) is the registry key.
type Name: ToolName;
/// Tool-specific progress-stage type used in [`ToolStatus::Progress`].
type Stage: ProgressStage;
fn name(&self) -> Self::Name;
fn display_name(&self) -> &'static str;
fn description(&self) -> &'static str;
fn input_schema(&self) -> Value;
/// Defaults to [`ToolTier::Observe`].
fn tier(&self) -> ToolTier {
ToolTier::Observe
}
/// Starts the tool with raw JSON input.
fn execute(
&self,
ctx: &ToolContext<Ctx>,
input: Value,
) -> impl Future<Output = Result<ToolOutcome>> + Send;
/// Streams status updates for the operation identified by `operation_id`.
fn check_status(
&self,
ctx: &ToolContext<Ctx>,
operation_id: &str,
) -> impl Stream<Item = ToolStatus<Self::Stage>> + Send;
}
/// A two-phase tool: `listen` streams [`ListenToolUpdate`]s for an operation,
/// then `execute` runs it at a given revision (or `cancel` stops it).
///
/// Registered with `ToolRegistry::register_listen`.
pub trait ListenExecuteTool<Ctx>: Send + Sync {
/// Typed name; its string form (via `tool_name_to_string`) is the registry key.
type Name: ToolName;
fn name(&self) -> Self::Name;
fn display_name(&self) -> &'static str;
fn description(&self) -> &'static str;
fn input_schema(&self) -> Value;
/// Defaults to [`ToolTier::Confirm`] (unlike the other tool traits, which
/// default to `Observe`).
fn tier(&self) -> ToolTier {
ToolTier::Confirm
}
/// Starts listening with raw JSON input, yielding updates for an operation.
fn listen(
&self,
ctx: &ToolContext<Ctx>,
input: Value,
) -> impl Stream<Item = ListenToolUpdate> + Send;
/// Executes `operation_id`; `expected_revision` is the revision the caller
/// last observed from `listen`.
fn execute(
&self,
ctx: &ToolContext<Ctx>,
operation_id: &str,
expected_revision: u64,
) -> impl Future<Output = Result<ToolResult>> + Send;
/// Stops `operation_id` for `reason`. The default implementation does
/// nothing and returns `Ok(())`.
fn cancel(
&self,
_ctx: &ToolContext<Ctx>,
_operation_id: &str,
_reason: ListenStopReason,
) -> impl Future<Output = Result<()>> + Send {
async { Ok(()) }
}
}
/// Object-safe form of [`Tool`] stored in `ToolRegistry` as
/// `Arc<dyn ErasedTool<Ctx>>`.
///
/// The typed name is replaced by `name_str`, and `execute` is boxed via
/// `async_trait`.
#[async_trait]
pub trait ErasedTool<Ctx>: Send + Sync {
/// String form of the tool's name (the registry key).
fn name_str(&self) -> &str;
fn display_name(&self) -> &'static str;
fn description(&self) -> &'static str;
fn input_schema(&self) -> Value;
fn tier(&self) -> ToolTier;
async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult>;
}
/// Private adapter that erases a [`Tool`]'s concrete name type.
///
/// The name is converted to a `String` once, at construction, so `name_str`
/// can return a borrowed `&str` without re-serializing on every call.
struct ToolWrapper<T, Ctx>
where
    T: Tool<Ctx>,
{
    inner: T,
    name_cache: String,
    _marker: PhantomData<Ctx>,
}

impl<T, Ctx> ToolWrapper<T, Ctx>
where
    T: Tool<Ctx>,
{
    fn new(tool: T) -> Self {
        // Fields are evaluated in source order: the name is computed before
        // `tool` is moved into `inner`.
        Self {
            name_cache: tool_name_to_string(&tool.name()),
            inner: tool,
            _marker: PhantomData,
        }
    }
}
/// Forwards every [`ErasedTool`] call to the wrapped [`Tool`], serving the
/// name from the cache built in `ToolWrapper::new`.
#[async_trait]
impl<T, Ctx> ErasedTool<Ctx> for ToolWrapper<T, Ctx>
where
    T: Tool<Ctx> + 'static,
    Ctx: Send + Sync + 'static,
{
    fn name_str(&self) -> &str {
        self.name_cache.as_str()
    }

    fn display_name(&self) -> &'static str {
        <T as Tool<Ctx>>::display_name(&self.inner)
    }

    fn description(&self) -> &'static str {
        <T as Tool<Ctx>>::description(&self.inner)
    }

    fn input_schema(&self) -> Value {
        <T as Tool<Ctx>>::input_schema(&self.inner)
    }

    fn tier(&self) -> ToolTier {
        <T as Tool<Ctx>>::tier(&self.inner)
    }

    async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult> {
        let pending = <T as Tool<Ctx>>::execute(&self.inner, ctx, input);
        pending.await
    }
}
/// Object-safe form of [`AsyncTool`] stored in `ToolRegistry`.
///
/// The stage type is erased: `check_status_stream` yields
/// [`ErasedToolStatus`] items.
#[async_trait]
pub trait ErasedAsyncTool<Ctx>: Send + Sync {
/// String form of the tool's name (the registry key).
fn name_str(&self) -> &str;
fn display_name(&self) -> &'static str;
fn description(&self) -> &'static str;
fn input_schema(&self) -> Value;
fn tier(&self) -> ToolTier;
async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolOutcome>;
/// Boxed, type-erased status stream for `operation_id`.
fn check_status_stream<'a>(
&'a self,
ctx: &'a ToolContext<Ctx>,
operation_id: &'a str,
) -> Pin<Box<dyn Stream<Item = ErasedToolStatus> + Send + 'a>>;
}
/// Private adapter that erases an [`AsyncTool`]'s name and stage types.
///
/// The name string is computed once at construction and served by `name_str`.
struct AsyncToolWrapper<T, Ctx>
where
    T: AsyncTool<Ctx>,
{
    inner: T,
    name_cache: String,
    _marker: PhantomData<Ctx>,
}

impl<T, Ctx> AsyncToolWrapper<T, Ctx>
where
    T: AsyncTool<Ctx>,
{
    fn new(tool: T) -> Self {
        // The name is computed before `tool` is moved into `inner`.
        Self {
            name_cache: tool_name_to_string(&tool.name()),
            inner: tool,
            _marker: PhantomData,
        }
    }
}
#[async_trait]
impl<T, Ctx> ErasedAsyncTool<Ctx> for AsyncToolWrapper<T, Ctx>
where
T: AsyncTool<Ctx> + 'static,
Ctx: Send + Sync + 'static,
{
fn name_str(&self) -> &str {
&self.name_cache
}
fn display_name(&self) -> &'static str {
self.inner.display_name()
}
fn description(&self) -> &'static str {
self.inner.description()
}
fn input_schema(&self) -> Value {
self.inner.input_schema()
}
fn tier(&self) -> ToolTier {
self.inner.tier()
}
async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolOutcome> {
self.inner.execute(ctx, input).await
}
fn check_status_stream<'a>(
&'a self,
ctx: &'a ToolContext<Ctx>,
operation_id: &'a str,
) -> Pin<Box<dyn Stream<Item = ErasedToolStatus> + Send + 'a>> {
use futures::StreamExt;
let stream = self.inner.check_status(ctx, operation_id);
Box::pin(stream.map(ErasedToolStatus::from))
}
}
/// Object-safe form of [`ListenExecuteTool`] stored in `ToolRegistry`.
///
/// `listen_stream` boxes the tool's update stream; `execute` and `cancel`
/// are boxed via `async_trait`.
#[async_trait]
pub trait ErasedListenTool<Ctx>: Send + Sync {
/// String form of the tool's name (the registry key).
fn name_str(&self) -> &str;
fn display_name(&self) -> &'static str;
fn description(&self) -> &'static str;
fn input_schema(&self) -> Value;
fn tier(&self) -> ToolTier;
/// Boxed stream of updates produced by the tool's `listen`.
fn listen_stream<'a>(
&'a self,
ctx: &'a ToolContext<Ctx>,
input: Value,
) -> Pin<Box<dyn Stream<Item = ListenToolUpdate> + Send + 'a>>;
async fn execute(
&self,
ctx: &ToolContext<Ctx>,
operation_id: &str,
expected_revision: u64,
) -> Result<ToolResult>;
async fn cancel(
&self,
ctx: &ToolContext<Ctx>,
operation_id: &str,
reason: ListenStopReason,
) -> Result<()>;
}
/// Private adapter that erases a [`ListenExecuteTool`]'s name type.
///
/// The name string is computed once at construction and served by `name_str`.
struct ListenToolWrapper<T, Ctx>
where
    T: ListenExecuteTool<Ctx>,
{
    inner: T,
    name_cache: String,
    _marker: PhantomData<Ctx>,
}

impl<T, Ctx> ListenToolWrapper<T, Ctx>
where
    T: ListenExecuteTool<Ctx>,
{
    fn new(tool: T) -> Self {
        // The name is computed before `tool` is moved into `inner`.
        Self {
            name_cache: tool_name_to_string(&tool.name()),
            inner: tool,
            _marker: PhantomData,
        }
    }
}
/// Forwards every [`ErasedListenTool`] call to the wrapped
/// [`ListenExecuteTool`], boxing its listen stream.
#[async_trait]
impl<T, Ctx> ErasedListenTool<Ctx> for ListenToolWrapper<T, Ctx>
where
    T: ListenExecuteTool<Ctx> + 'static,
    Ctx: Send + Sync + 'static,
{
    fn name_str(&self) -> &str {
        self.name_cache.as_str()
    }

    fn display_name(&self) -> &'static str {
        <T as ListenExecuteTool<Ctx>>::display_name(&self.inner)
    }

    fn description(&self) -> &'static str {
        <T as ListenExecuteTool<Ctx>>::description(&self.inner)
    }

    fn input_schema(&self) -> Value {
        <T as ListenExecuteTool<Ctx>>::input_schema(&self.inner)
    }

    fn tier(&self) -> ToolTier {
        <T as ListenExecuteTool<Ctx>>::tier(&self.inner)
    }

    fn listen_stream<'a>(
        &'a self,
        ctx: &'a ToolContext<Ctx>,
        input: Value,
    ) -> Pin<Box<dyn Stream<Item = ListenToolUpdate> + Send + 'a>> {
        Box::pin(<T as ListenExecuteTool<Ctx>>::listen(&self.inner, ctx, input))
    }

    async fn execute(
        &self,
        ctx: &ToolContext<Ctx>,
        operation_id: &str,
        expected_revision: u64,
    ) -> Result<ToolResult> {
        let pending = <T as ListenExecuteTool<Ctx>>::execute(
            &self.inner,
            ctx,
            operation_id,
            expected_revision,
        );
        pending.await
    }

    async fn cancel(
        &self,
        ctx: &ToolContext<Ctx>,
        operation_id: &str,
        reason: ListenStopReason,
    ) -> Result<()> {
        let pending =
            <T as ListenExecuteTool<Ctx>>::cancel(&self.inner, ctx, operation_id, reason);
        pending.await
    }
}
/// Registry of every tool available to an agent, keyed by the string form of
/// each tool's name.
///
/// Plain, async and listen/execute tools are kept in separate maps (each
/// holding a type-erased `Arc`); `to_llm_tools` flattens all three.
pub struct ToolRegistry<Ctx> {
tools: HashMap<String, Arc<dyn ErasedTool<Ctx>>>,
async_tools: HashMap<String, Arc<dyn ErasedAsyncTool<Ctx>>>,
listen_tools: HashMap<String, Arc<dyn ErasedListenTool<Ctx>>>,
}
/// Hand-written so that cloning does not require `Ctx: Clone`, which
/// `#[derive(Clone)]` would impose.
///
/// The clone is shallow: tools are shared through their `Arc`s; only the maps
/// and their keys are copied.
impl<Ctx> Clone for ToolRegistry<Ctx> {
    fn clone(&self) -> Self {
        let Self {
            tools,
            async_tools,
            listen_tools,
        } = self;
        Self {
            tools: tools.clone(),
            async_tools: async_tools.clone(),
            listen_tools: listen_tools.clone(),
        }
    }
}
impl<Ctx: Send + Sync + 'static> Default for ToolRegistry<Ctx> {
fn default() -> Self {
Self::new()
}
}
impl<Ctx: Send + Sync + 'static> ToolRegistry<Ctx> {
#[must_use]
pub fn new() -> Self {
Self {
tools: HashMap::new(),
async_tools: HashMap::new(),
listen_tools: HashMap::new(),
}
}
pub fn register<T>(&mut self, tool: T) -> &mut Self
where
T: Tool<Ctx> + 'static,
{
let wrapper = ToolWrapper::new(tool);
let name = wrapper.name_str().to_string();
self.tools.insert(name, Arc::new(wrapper));
self
}
pub fn register_simple<T>(&mut self, tool: T) -> &mut Self
where
T: SimpleTool<Ctx> + 'static,
{
self.register(SimpleToolAdapter::new(tool))
}
pub fn register_typed<T>(&mut self, tool: T) -> &mut Self
where
T: TypedTool<Ctx> + 'static,
{
self.register(TypedToolAdapter::new(tool))
}
pub fn register_async<T>(&mut self, tool: T) -> &mut Self
where
T: AsyncTool<Ctx> + 'static,
{
let wrapper = AsyncToolWrapper::new(tool);
let name = wrapper.name_str().to_string();
self.async_tools.insert(name, Arc::new(wrapper));
self
}
pub fn register_listen<T>(&mut self, tool: T) -> &mut Self
where
T: ListenExecuteTool<Ctx> + 'static,
{
let wrapper = ListenToolWrapper::new(tool);
let name = wrapper.name_str().to_string();
self.listen_tools.insert(name, Arc::new(wrapper));
self
}
#[must_use]
pub fn get(&self, name: &str) -> Option<&Arc<dyn ErasedTool<Ctx>>> {
self.tools.get(name)
}
#[must_use]
pub fn get_async(&self, name: &str) -> Option<&Arc<dyn ErasedAsyncTool<Ctx>>> {
self.async_tools.get(name)
}
#[must_use]
pub fn get_listen(&self, name: &str) -> Option<&Arc<dyn ErasedListenTool<Ctx>>> {
self.listen_tools.get(name)
}
#[must_use]
pub fn is_async(&self, name: &str) -> bool {
self.async_tools.contains_key(name)
}
#[must_use]
pub fn is_listen(&self, name: &str) -> bool {
self.listen_tools.contains_key(name)
}
pub fn all(&self) -> impl Iterator<Item = &Arc<dyn ErasedTool<Ctx>>> {
self.tools.values()
}
pub fn all_async(&self) -> impl Iterator<Item = &Arc<dyn ErasedAsyncTool<Ctx>>> {
self.async_tools.values()
}
pub fn all_listen(&self) -> impl Iterator<Item = &Arc<dyn ErasedListenTool<Ctx>>> {
self.listen_tools.values()
}
#[must_use]
pub fn len(&self) -> usize {
self.tools.len() + self.async_tools.len() + self.listen_tools.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.tools.is_empty() && self.async_tools.is_empty() && self.listen_tools.is_empty()
}
pub fn filter<F>(&mut self, predicate: F)
where
F: Fn(&str) -> bool,
{
self.tools.retain(|name, _| predicate(name));
self.async_tools.retain(|name, _| predicate(name));
self.listen_tools.retain(|name, _| predicate(name));
}
#[must_use]
pub fn to_llm_tools(&self) -> Vec<llm::Tool> {
let mut tools: Vec<_> = self
.tools
.values()
.map(|tool| llm::Tool {
name: tool.name_str().to_string(),
description: tool.description().to_string(),
input_schema: tool.input_schema(),
display_name: tool.display_name().to_string(),
tier: tool.tier(),
})
.collect();
tools.extend(self.async_tools.values().map(|tool| llm::Tool {
name: tool.name_str().to_string(),
description: tool.description().to_string(),
input_schema: tool.input_schema(),
display_name: tool.display_name().to_string(),
tier: tool.tier(),
}));
tools.extend(self.listen_tools.values().map(|tool| llm::Tool {
name: tool.name_str().to_string(),
description: tool.description().to_string(),
input_schema: tool.input_schema(),
display_name: tool.display_name().to_string(),
tier: tool.tier(),
}));
tools.sort_by(|a, b| a.name.cmp(&b.name));
tools
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Context;
    use std::sync::atomic::{AtomicBool, Ordering};

    #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
    #[serde(rename_all = "snake_case")]
    enum TestToolName {
        MockTool,
        AnotherTool,
    }

    impl ToolName for TestToolName {}

    /// Echoes the `message` field of its input.
    struct MockTool;

    impl Tool<()> for MockTool {
        type Name = TestToolName;

        fn name(&self) -> TestToolName {
            TestToolName::MockTool
        }

        fn display_name(&self) -> &'static str {
            "Mock Tool"
        }

        fn description(&self) -> &'static str {
            "A mock tool for testing"
        }

        fn input_schema(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": {
                    "message": { "type": "string" }
                }
            })
        }

        async fn execute(&self, _ctx: &ToolContext<()>, input: Value) -> Result<ToolResult> {
            let message = input
                .get("message")
                .and_then(|v| v.as_str())
                .unwrap_or("no message");
            Ok(ToolResult::success(format!("Received: {message}")))
        }
    }

    /// Second tool whose name sorts before `mock_tool`.
    struct AnotherTool;

    impl Tool<()> for AnotherTool {
        type Name = TestToolName;

        fn name(&self) -> TestToolName {
            TestToolName::AnotherTool
        }

        fn display_name(&self) -> &'static str {
            "Another Tool"
        }

        fn description(&self) -> &'static str {
            "Another tool for testing"
        }

        fn input_schema(&self) -> Value {
            serde_json::json!({ "type": "object" })
        }

        async fn execute(&self, _ctx: &ToolContext<()>, _input: Value) -> Result<ToolResult> {
            Ok(ToolResult::success("Done"))
        }
    }

    /// Collects the names `to_llm_tools` reports, in the order it reports them.
    fn llm_tool_names(registry: &ToolRegistry<()>) -> Vec<String> {
        registry.to_llm_tools().into_iter().map(|t| t.name).collect()
    }

    #[test]
    fn test_tool_name_serialization() {
        let name = TestToolName::MockTool;
        assert_eq!(tool_name_to_string(&name), "mock_tool");
        let parsed: TestToolName = tool_name_from_str("mock_tool").unwrap();
        assert_eq!(parsed, TestToolName::MockTool);
    }

    #[test]
    fn test_dynamic_tool_name() {
        let name = DynamicToolName::new("my_mcp_tool");
        assert_eq!(tool_name_to_string(&name), "my_mcp_tool");
        assert_eq!(name.as_str(), "my_mcp_tool");
    }

    #[test]
    fn test_tool_registry() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);
        assert_eq!(registry.len(), 1);
        assert!(registry.get("mock_tool").is_some());
        assert!(registry.get("nonexistent").is_none());
    }

    #[test]
    fn test_to_llm_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);
        let llm_tools = registry.to_llm_tools();
        assert_eq!(llm_tools.len(), 1);
        assert_eq!(llm_tools[0].name, "mock_tool");
    }

    #[test]
    fn to_llm_tools_returns_alphabetical_order() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);
        registry.register(AnotherTool);
        assert_eq!(llm_tool_names(&registry), vec!["another_tool", "mock_tool"]);
    }

    #[test]
    fn to_llm_tools_is_deterministic_across_calls() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);
        registry.register(AnotherTool);
        let first = llm_tool_names(&registry);
        for _ in 0..32 {
            assert_eq!(
                llm_tool_names(&registry),
                first,
                "tool ordering must be stable across calls"
            );
        }
    }

    #[test]
    fn test_filter_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);
        registry.register(AnotherTool);
        assert_eq!(registry.len(), 2);
        registry.filter(|name| name != "mock_tool");
        assert_eq!(registry.len(), 1);
        assert!(registry.get("mock_tool").is_none());
        assert!(registry.get("another_tool").is_some());
    }

    #[test]
    fn test_filter_tools_keep_all() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);
        registry.register(AnotherTool);
        registry.filter(|_| true);
        assert_eq!(registry.len(), 2);
    }

    #[test]
    fn test_filter_tools_remove_all() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);
        registry.register(AnotherTool);
        registry.filter(|_| false);
        assert!(registry.is_empty());
    }

    #[test]
    fn test_display_name() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);
        let tool = registry.get("mock_tool").unwrap();
        assert_eq!(tool.display_name(), "Mock Tool");
    }

    /// Listen tool that immediately reports a single `Ready` update.
    struct ListenMockTool;

    impl ListenExecuteTool<()> for ListenMockTool {
        type Name = TestToolName;

        fn name(&self) -> TestToolName {
            TestToolName::MockTool
        }

        fn display_name(&self) -> &'static str {
            "Listen Mock Tool"
        }

        fn description(&self) -> &'static str {
            "A listen/execute mock tool for testing"
        }

        fn input_schema(&self) -> Value {
            serde_json::json!({ "type": "object" })
        }

        fn listen(
            &self,
            _ctx: &ToolContext<()>,
            _input: Value,
        ) -> impl futures::Stream<Item = ListenToolUpdate> + Send {
            futures::stream::iter(vec![ListenToolUpdate::Ready {
                operation_id: "op_1".to_string(),
                revision: 1,
                message: "ready".to_string(),
                snapshot: serde_json::json!({"ok": true}),
                expires_at: None,
            }])
        }

        async fn execute(
            &self,
            _ctx: &ToolContext<()>,
            _operation_id: &str,
            _expected_revision: u64,
        ) -> Result<ToolResult> {
            Ok(ToolResult::success("Executed"))
        }
    }

    #[test]
    fn test_listen_tool_registry() {
        let mut registry = ToolRegistry::new();
        registry.register_listen(ListenMockTool);
        assert_eq!(registry.len(), 1);
        assert!(registry.get_listen("mock_tool").is_some());
        assert!(registry.is_listen("mock_tool"));
    }

    #[derive(Debug, Serialize, Deserialize)]
    struct GreetArgs {
        name: String,
        greeting: String,
    }

    /// Typed tool that records whether `execute` was reached.
    struct GreetTool {
        executed: Arc<AtomicBool>,
    }

    impl TypedTool<()> for GreetTool {
        type Input = GreetArgs;

        fn name(&self) -> &'static str {
            "greet"
        }

        fn description(&self) -> &'static str {
            "Greet someone by name"
        }

        fn input_schema(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": {
                    "name": { "type": "string" },
                    "greeting": { "type": "string" }
                },
                "required": ["name", "greeting"]
            })
        }

        async fn execute(&self, _ctx: &ToolContext<()>, input: GreetArgs) -> Result<ToolResult> {
            self.executed.store(true, Ordering::SeqCst);
            Ok(ToolResult::success(format!(
                "{}, {}!",
                input.greeting, input.name
            )))
        }
    }

    #[tokio::test]
    async fn typed_tool_happy_path_receives_typed_input() -> Result<()> {
        let executed = Arc::new(AtomicBool::new(false));
        let adapter = TypedToolAdapter::new(GreetTool {
            executed: executed.clone(),
        });
        let ctx = ToolContext::new(());
        let result = Tool::execute(
            &adapter,
            &ctx,
            serde_json::json!({ "name": "Ada", "greeting": "Hello" }),
        )
        .await?;
        assert!(executed.load(Ordering::SeqCst), "execute must be called");
        assert!(result.success);
        assert_eq!(result.output, "Hello, Ada!");
        Ok(())
    }

    #[tokio::test]
    async fn typed_tool_invalid_args_self_correct_without_executing() -> Result<()> {
        let executed = Arc::new(AtomicBool::new(false));
        let adapter = TypedToolAdapter::new(GreetTool {
            executed: executed.clone(),
        });
        let ctx = ToolContext::new(());
        let result = Tool::execute(&adapter, &ctx, serde_json::json!({ "name": "Ada" })).await?;
        assert!(
            !executed.load(Ordering::SeqCst),
            "execute must NOT be called with invalid arguments"
        );
        assert!(!result.success, "validation failure is an error result");
        assert!(
            result.output.contains("Invalid arguments for tool `greet`"),
            "error must identify the tool: {}",
            result.output
        );
        assert!(
            result.output.contains("greeting"),
            "error must surface the serde message naming the bad field: {}",
            result.output
        );
        Ok(())
    }

    #[tokio::test]
    async fn typed_tool_wrong_type_self_corrects() -> Result<()> {
        let executed = Arc::new(AtomicBool::new(false));
        let adapter = TypedToolAdapter::new(GreetTool {
            executed: executed.clone(),
        });
        let ctx = ToolContext::new(());
        let result = Tool::execute(
            &adapter,
            &ctx,
            serde_json::json!({ "name": 42, "greeting": "Hi" }),
        )
        .await?;
        assert!(!executed.load(Ordering::SeqCst));
        assert!(!result.success);
        Ok(())
    }

    /// Typed tool with `Input = Value`, echoing its input as JSON text.
    struct ValueTypedTool;

    impl TypedTool<()> for ValueTypedTool {
        type Input = Value;

        fn name(&self) -> &'static str {
            "value_typed"
        }

        fn description(&self) -> &'static str {
            "Accepts any JSON, like an untyped tool"
        }

        fn input_schema(&self) -> Value {
            serde_json::json!({ "type": "object" })
        }

        async fn execute(&self, _ctx: &ToolContext<()>, input: Value) -> Result<ToolResult> {
            Ok(ToolResult::success(input.to_string()))
        }
    }

    #[tokio::test]
    async fn typed_tool_value_input_is_identity_passthrough() -> Result<()> {
        let adapter = TypedToolAdapter::new(ValueTypedTool);
        let ctx = ToolContext::new(());
        let input = serde_json::json!({ "anything": [1, 2, 3], "nested": { "ok": true } });
        let expected = input.to_string();
        let result = Tool::execute(&adapter, &ctx, input).await?;
        assert!(result.success);
        // The tool echoes its input, so identical text proves the `Value` was
        // handed through validation unchanged.
        assert_eq!(
            result.output,
            expected.as_str(),
            "Value input must reach execute unchanged"
        );
        Ok(())
    }

    #[test]
    fn register_typed_exposes_tool_via_registry() -> Result<()> {
        let mut registry = ToolRegistry::new();
        registry.register_typed(GreetTool {
            executed: Arc::new(AtomicBool::new(false)),
        });
        assert_eq!(registry.len(), 1);
        let tool = registry.get("greet").context("typed tool registered")?;
        assert_eq!(tool.input_schema()["required"][0], "name");
        Ok(())
    }

    #[test]
    fn invalid_tool_input_result_is_balanced_error() -> Result<()> {
        let Err(err) = serde_json::from_str::<GreetArgs>("{}") else {
            anyhow::bail!("empty object must fail to deserialize GreetArgs");
        };
        let result = invalid_tool_input_result("greet", &err);
        assert!(!result.success);
        assert!(result.output.contains("greet"));
        assert!(result.output.contains("call the tool again"));
        Ok(())
    }
}