use crate::events::{AgentEvent, AgentEventEnvelope, SequenceCounter};
use crate::llm;
use crate::types::{ToolOutcome, ToolResult, ToolTier};
use anyhow::Result;
use async_trait::async_trait;
use futures::Stream;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use serde_json::Value;
use std::collections::HashMap;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::Arc;
use time::OffsetDateTime;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
/// Marker trait for typed tool names.
///
/// A name is converted to and from its plain string form through serde (see
/// [`tool_name_to_string`] and [`tool_name_from_str`]), so implementors are
/// expected to serialize to a JSON string, e.g. a unit-only enum or a
/// `#[serde(transparent)]` newtype around `String`.
pub trait ToolName: Send + Sync + Serialize + DeserializeOwned + 'static {}
#[must_use]
pub fn tool_name_to_string<N: ToolName>(name: &N) -> String {
serde_json::to_string(name)
.unwrap_or_else(|_| "\"<unknown_tool>\"".to_string())
.trim_matches('"')
.to_string()
}
pub fn tool_name_from_str<N: ToolName>(s: &str) -> Result<N, serde_json::Error> {
serde_json::from_str(&format!("\"{s}\""))
}
/// Statically known tool names (presumably the built-in tool set — confirm).
///
/// Serialized in snake_case, e.g. `MultiEdit` -> `"multi_edit"`,
/// `AskUser` -> `"ask_user"`; that string is what [`tool_name_to_string`]
/// yields for each variant.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PrimitiveToolName {
    Read,
    Write,
    Edit,
    MultiEdit,
    Bash,
    Glob,
    Grep,
    NotebookRead,
    NotebookEdit,
    TodoRead,
    TodoWrite,
    AskUser,
    LinkFetch,
    WebSearch,
}
impl ToolName for PrimitiveToolName {}
/// A tool name known only at runtime, held as an owned string.
///
/// `#[serde(transparent)]` makes it serialize as the bare string, so
/// [`tool_name_to_string`] returns the wrapped value unchanged.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub struct DynamicToolName(String);
impl DynamicToolName {
    /// Wraps anything convertible into a `String` as a tool name.
    #[must_use]
    pub fn new(name: impl Into<String>) -> Self {
        let owned: String = name.into();
        Self(owned)
    }
    /// Borrows the underlying name.
    #[must_use]
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
impl ToolName for DynamicToolName {}
/// Marker trait for the typed stages reported by [`ToolStatus::Progress`].
///
/// Stages are rendered to strings with [`stage_to_string`], which panics if
/// serde serialization of the stage fails.
pub trait ProgressStage: Clone + Send + Sync + Serialize + DeserializeOwned + 'static {}
#[must_use]
pub fn stage_to_string<S: ProgressStage>(stage: &S) -> String {
serde_json::to_string(stage)
.expect("ProgressStage must serialize to string")
.trim_matches('"')
.to_string()
}
/// Status item yielded by [`AsyncTool::check_status`], generic over the
/// tool's stage type.
#[derive(Clone, Debug, Serialize)]
pub enum ToolStatus<S: ProgressStage> {
    /// Intermediate progress: a typed stage, a message and optional extra data.
    Progress {
        stage: S,
        message: String,
        data: Option<serde_json::Value>,
    },
    /// The operation finished; carries its result.
    Completed(ToolResult),
    /// The operation failed; carries its result.
    Failed(ToolResult),
}
/// Type-erased counterpart of [`ToolStatus`], with the stage already rendered
/// to a string (see the `From<ToolStatus<S>>` impl, which uses
/// [`stage_to_string`]).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum ErasedToolStatus {
    /// Intermediate progress with a stringified stage.
    Progress {
        stage: String,
        message: String,
        data: Option<serde_json::Value>,
    },
    /// The operation finished; carries its result.
    Completed(ToolResult),
    /// The operation failed; carries its result.
    Failed(ToolResult),
}
/// Update yielded by [`ListenExecuteTool::listen`] for an operation.
///
/// `expires_at` fields are (de)serialized as optional RFC 3339 timestamps.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum ListenToolUpdate {
    /// Still listening; `snapshot` may be absent at this point.
    Listening {
        operation_id: String,
        // Presumably the value later passed as `expected_revision` to
        // `ListenExecuteTool::execute` — confirm against callers.
        revision: u64,
        message: String,
        snapshot: Option<serde_json::Value>,
        #[serde(with = "time::serde::rfc3339::option")]
        expires_at: Option<OffsetDateTime>,
    },
    /// Ready; unlike `Listening`, a snapshot is always present.
    Ready {
        operation_id: String,
        revision: u64,
        message: String,
        snapshot: serde_json::Value,
        #[serde(with = "time::serde::rfc3339::option")]
        expires_at: Option<OffsetDateTime>,
    },
    /// The operation was invalidated; `recoverable` flags whether it can be
    /// retried (interpretation left to the caller — confirm).
    Invalidated {
        operation_id: String,
        message: String,
        recoverable: bool,
    },
}
/// Reason passed to [`ListenExecuteTool::cancel`] when a listen operation is
/// stopped. Variant meanings are inferred from their names — confirm with the
/// code that triggers cancellation.
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub enum ListenStopReason {
    UserRejected,
    Blocked,
    StreamDisconnected,
    StreamEnded,
}
impl<S: ProgressStage> From<ToolStatus<S>> for ErasedToolStatus {
fn from(status: ToolStatus<S>) -> Self {
match status {
ToolStatus::Progress {
stage,
message,
data,
} => Self::Progress {
stage: stage_to_string(&stage),
message,
data,
},
ToolStatus::Completed(r) => Self::Completed(r),
ToolStatus::Failed(r) => Self::Failed(r),
}
}
}
/// Per-invocation context handed to tool methods.
///
/// Carries the application state (`app`), free-form `metadata`, and optional
/// plumbing for emitting [`AgentEvent`]s and observing cancellation.
pub struct ToolContext<Ctx> {
    /// Application-specific state made available to the tool.
    pub app: Ctx,
    /// Free-form key/value metadata attached by the caller.
    pub metadata: HashMap<String, Value>,
    // Both are set together by `with_event_tx`; `emit_event` only sends
    // when both are present.
    event_tx: Option<mpsc::Sender<AgentEventEnvelope>>,
    event_seq: Option<SequenceCounter>,
    cancel_token: Option<CancellationToken>,
}
impl<Ctx> ToolContext<Ctx> {
    /// Creates a context around `app` with empty metadata, no event channel
    /// and no cancellation token.
    #[must_use]
    pub fn new(app: Ctx) -> Self {
        Self {
            app,
            metadata: HashMap::default(),
            event_tx: None,
            event_seq: None,
            cancel_token: None,
        }
    }
    /// Adds (or overwrites) one metadata entry, builder-style.
    #[must_use]
    pub fn with_metadata(mut self, key: impl Into<String>, value: Value) -> Self {
        let key: String = key.into();
        self.metadata.insert(key, value);
        self
    }
    /// Attaches the event channel and the sequence counter used to wrap
    /// emitted events.
    #[must_use]
    pub fn with_event_tx(
        self,
        tx: mpsc::Sender<AgentEventEnvelope>,
        seq: SequenceCounter,
    ) -> Self {
        Self {
            event_tx: Some(tx),
            event_seq: Some(seq),
            ..self
        }
    }
    /// Wraps `event` in an envelope and sends it without waiting.
    ///
    /// Does nothing when no channel is attached. Delivery is best-effort: if
    /// `try_send` fails (e.g. the channel is full or closed), the event is dropped.
    pub fn emit_event(&self, event: AgentEvent) {
        let (Some(tx), Some(seq)) = (self.event_tx.as_ref(), self.event_seq.as_ref()) else {
            return;
        };
        let _ = tx.try_send(AgentEventEnvelope::wrap(event, seq));
    }
    /// Returns a clone of the attached event sender, if any.
    #[must_use]
    pub fn event_tx(&self) -> Option<mpsc::Sender<AgentEventEnvelope>> {
        self.event_tx.as_ref().cloned()
    }
    /// Returns a clone of the attached sequence counter, if any.
    #[must_use]
    pub fn event_seq(&self) -> Option<SequenceCounter> {
        self.event_seq.as_ref().cloned()
    }
    /// Attaches a cancellation token, builder-style.
    #[must_use]
    pub fn with_cancel_token(self, token: CancellationToken) -> Self {
        Self {
            cancel_token: Some(token),
            ..self
        }
    }
    /// Returns a clone of the attached cancellation token, if any.
    #[must_use]
    pub fn cancel_token(&self) -> Option<CancellationToken> {
        self.cancel_token.as_ref().cloned()
    }
}
/// A tool that produces its [`ToolResult`] from a single `execute` call.
///
/// Registered with [`ToolRegistry::register`], which stores it type-erased as
/// an [`ErasedTool`].
pub trait Tool<Ctx>: Send + Sync {
    /// Typed name; stringified with [`tool_name_to_string`] to form the
    /// registry key and the name exposed to the LLM.
    type Name: ToolName;
    /// Returns this tool's typed name.
    fn name(&self) -> Self::Name;
    /// Name intended for display; not included in the LLM tool definition
    /// built by the registry.
    fn display_name(&self) -> &'static str;
    /// Description sent to the LLM with the tool definition.
    fn description(&self) -> &'static str;
    /// Input schema sent to the LLM as the tool's `input_schema`.
    fn input_schema(&self) -> Value;
    /// Tier classification of the tool; defaults to [`ToolTier::Observe`].
    fn tier(&self) -> ToolTier {
        ToolTier::Observe
    }
    /// Runs the tool on `input`. Nothing in this file validates `input`
    /// against `input_schema` before the call.
    fn execute(
        &self,
        ctx: &ToolContext<Ctx>,
        input: Value,
    ) -> impl Future<Output = Result<ToolResult>> + Send;
}
/// A tool whose `execute` returns a [`ToolOutcome`] and whose progress for an
/// operation can be followed through `check_status`.
///
/// Registered with [`ToolRegistry::register_async`].
pub trait AsyncTool<Ctx>: Send + Sync {
    /// Typed name; stringified with [`tool_name_to_string`].
    type Name: ToolName;
    /// Typed progress stage reported by `check_status`.
    type Stage: ProgressStage;
    /// Returns this tool's typed name.
    fn name(&self) -> Self::Name;
    /// Name intended for display; not included in the LLM tool definition.
    fn display_name(&self) -> &'static str;
    /// Description sent to the LLM with the tool definition.
    fn description(&self) -> &'static str;
    /// Input schema sent to the LLM as the tool's `input_schema`.
    fn input_schema(&self) -> Value;
    /// Tier classification of the tool; defaults to [`ToolTier::Observe`].
    fn tier(&self) -> ToolTier {
        ToolTier::Observe
    }
    /// Starts the tool on `input`. The returned [`ToolOutcome`] presumably
    /// carries the operation id later passed to `check_status` — TODO confirm.
    fn execute(
        &self,
        ctx: &ToolContext<Ctx>,
        input: Value,
    ) -> impl Future<Output = Result<ToolOutcome>> + Send;
    /// Streams status updates for `operation_id`.
    fn check_status(
        &self,
        ctx: &ToolContext<Ctx>,
        operation_id: &str,
    ) -> impl Stream<Item = ToolStatus<Self::Stage>> + Send;
}
/// A two-phase tool: `listen` streams [`ListenToolUpdate`]s for an operation,
/// and `execute` later acts on that operation at an expected revision.
///
/// Registered with [`ToolRegistry::register_listen`].
pub trait ListenExecuteTool<Ctx>: Send + Sync {
    /// Typed name; stringified with [`tool_name_to_string`].
    type Name: ToolName;
    /// Returns this tool's typed name.
    fn name(&self) -> Self::Name;
    /// Name intended for display; not included in the LLM tool definition.
    fn display_name(&self) -> &'static str;
    /// Description sent to the LLM with the tool definition.
    fn description(&self) -> &'static str;
    /// Input schema sent to the LLM as the tool's `input_schema`.
    fn input_schema(&self) -> Value;
    /// Tier classification; unlike the other tool traits, defaults to
    /// [`ToolTier::Confirm`].
    fn tier(&self) -> ToolTier {
        ToolTier::Confirm
    }
    /// Starts listening for `input` and streams updates for the operation.
    fn listen(
        &self,
        ctx: &ToolContext<Ctx>,
        input: Value,
    ) -> impl Stream<Item = ListenToolUpdate> + Send;
    /// Executes `operation_id`. `expected_revision` presumably must match the
    /// latest `revision` seen in a `ListenToolUpdate` — confirm in implementors.
    fn execute(
        &self,
        ctx: &ToolContext<Ctx>,
        operation_id: &str,
        expected_revision: u64,
    ) -> impl Future<Output = Result<ToolResult>> + Send;
    /// Stops `operation_id` for `reason`. The default implementation does
    /// nothing and returns `Ok(())`.
    fn cancel(
        &self,
        _ctx: &ToolContext<Ctx>,
        _operation_id: &str,
        _reason: ListenStopReason,
    ) -> impl Future<Output = Result<()>> + Send {
        async { Ok(()) }
    }
}
/// Object-safe form of [`Tool`], stored as `Arc<dyn ErasedTool<Ctx>>` by
/// [`ToolRegistry`]. `#[async_trait]` boxes the `execute` future so the trait
/// can be used as a trait object.
#[async_trait]
pub trait ErasedTool<Ctx>: Send + Sync {
    /// Stringified tool name (the registry key).
    fn name_str(&self) -> &str;
    fn display_name(&self) -> &'static str;
    fn description(&self) -> &'static str;
    fn input_schema(&self) -> Value;
    fn tier(&self) -> ToolTier;
    /// Runs the underlying tool on `input`.
    async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult>;
}
/// Adapts a typed [`Tool`] to the object-safe [`ErasedTool`] interface.
///
/// The tool's name is stringified once at construction so that `name_str`
/// can hand out a borrowed `&str`.
struct ToolWrapper<T, Ctx>
where
    T: Tool<Ctx>,
{
    inner: T,
    name_cache: String,
    // `Ctx` is otherwise only used in bounds; the marker makes it a used parameter.
    _marker: PhantomData<Ctx>,
}
impl<T, Ctx> ToolWrapper<T, Ctx>
where
    T: Tool<Ctx>,
{
    /// Wraps `tool`, caching its stringified name.
    fn new(tool: T) -> Self {
        Self {
            name_cache: tool_name_to_string(&tool.name()),
            inner: tool,
            _marker: PhantomData,
        }
    }
}
/// Forwards every [`ErasedTool`] method to the wrapped [`Tool`], except
/// `name_str`, which returns the cached name.
#[async_trait]
impl<T, Ctx> ErasedTool<Ctx> for ToolWrapper<T, Ctx>
where
    T: Tool<Ctx> + 'static,
    Ctx: Send + Sync + 'static,
{
    fn name_str(&self) -> &str {
        self.name_cache.as_str()
    }
    fn display_name(&self) -> &'static str {
        <T as Tool<Ctx>>::display_name(&self.inner)
    }
    fn description(&self) -> &'static str {
        <T as Tool<Ctx>>::description(&self.inner)
    }
    fn input_schema(&self) -> Value {
        <T as Tool<Ctx>>::input_schema(&self.inner)
    }
    fn tier(&self) -> ToolTier {
        <T as Tool<Ctx>>::tier(&self.inner)
    }
    async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult> {
        <T as Tool<Ctx>>::execute(&self.inner, ctx, input).await
    }
}
/// Object-safe form of [`AsyncTool`], stored by [`ToolRegistry`].
///
/// The status stream is boxed and its stage stringified (see
/// [`ErasedToolStatus`]), so the trait does not depend on `AsyncTool::Stage`.
#[async_trait]
pub trait ErasedAsyncTool<Ctx>: Send + Sync {
    /// Stringified tool name (the registry key).
    fn name_str(&self) -> &str;
    fn display_name(&self) -> &'static str;
    fn description(&self) -> &'static str;
    fn input_schema(&self) -> Value;
    fn tier(&self) -> ToolTier;
    /// Starts the underlying tool on `input`.
    async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolOutcome>;
    /// Boxed stream of type-erased status updates for `operation_id`; it
    /// borrows `self`, `ctx` and `operation_id` for `'a`.
    fn check_status_stream<'a>(
        &'a self,
        ctx: &'a ToolContext<Ctx>,
        operation_id: &'a str,
    ) -> Pin<Box<dyn Stream<Item = ErasedToolStatus> + Send + 'a>>;
}
/// Adapts a typed [`AsyncTool`] to the object-safe [`ErasedAsyncTool`]
/// interface, caching the stringified tool name.
struct AsyncToolWrapper<T, Ctx>
where
    T: AsyncTool<Ctx>,
{
    inner: T,
    name_cache: String,
    // `Ctx` is otherwise only used in bounds; the marker makes it a used parameter.
    _marker: PhantomData<Ctx>,
}
impl<T, Ctx> AsyncToolWrapper<T, Ctx>
where
    T: AsyncTool<Ctx>,
{
    /// Wraps `tool`, caching its stringified name.
    fn new(tool: T) -> Self {
        Self {
            name_cache: tool_name_to_string(&tool.name()),
            inner: tool,
            _marker: PhantomData,
        }
    }
}
/// Forwards to the wrapped [`AsyncTool`], erasing the stage type of the
/// status stream via `ErasedToolStatus::from`.
#[async_trait]
impl<T, Ctx> ErasedAsyncTool<Ctx> for AsyncToolWrapper<T, Ctx>
where
    T: AsyncTool<Ctx> + 'static,
    Ctx: Send + Sync + 'static,
{
    fn name_str(&self) -> &str {
        self.name_cache.as_str()
    }
    fn display_name(&self) -> &'static str {
        <T as AsyncTool<Ctx>>::display_name(&self.inner)
    }
    fn description(&self) -> &'static str {
        <T as AsyncTool<Ctx>>::description(&self.inner)
    }
    fn input_schema(&self) -> Value {
        <T as AsyncTool<Ctx>>::input_schema(&self.inner)
    }
    fn tier(&self) -> ToolTier {
        <T as AsyncTool<Ctx>>::tier(&self.inner)
    }
    async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolOutcome> {
        <T as AsyncTool<Ctx>>::execute(&self.inner, ctx, input).await
    }
    fn check_status_stream<'a>(
        &'a self,
        ctx: &'a ToolContext<Ctx>,
        operation_id: &'a str,
    ) -> Pin<Box<dyn Stream<Item = ErasedToolStatus> + Send + 'a>> {
        use futures::StreamExt;
        let typed = <T as AsyncTool<Ctx>>::check_status(&self.inner, ctx, operation_id);
        let erased = typed.map(ErasedToolStatus::from);
        Box::pin(erased)
    }
}
/// Object-safe form of [`ListenExecuteTool`], stored by [`ToolRegistry`].
#[async_trait]
pub trait ErasedListenTool<Ctx>: Send + Sync {
    /// Stringified tool name (the registry key).
    fn name_str(&self) -> &str;
    fn display_name(&self) -> &'static str;
    fn description(&self) -> &'static str;
    fn input_schema(&self) -> Value;
    fn tier(&self) -> ToolTier;
    /// Boxed stream of listen updates; it borrows `self` and `ctx` for `'a`.
    fn listen_stream<'a>(
        &'a self,
        ctx: &'a ToolContext<Ctx>,
        input: Value,
    ) -> Pin<Box<dyn Stream<Item = ListenToolUpdate> + Send + 'a>>;
    /// Executes `operation_id` at `expected_revision` on the underlying tool.
    async fn execute(
        &self,
        ctx: &ToolContext<Ctx>,
        operation_id: &str,
        expected_revision: u64,
    ) -> Result<ToolResult>;
    /// Cancels `operation_id` on the underlying tool.
    async fn cancel(
        &self,
        ctx: &ToolContext<Ctx>,
        operation_id: &str,
        reason: ListenStopReason,
    ) -> Result<()>;
}
/// Adapts a typed [`ListenExecuteTool`] to the object-safe
/// [`ErasedListenTool`] interface, caching the stringified tool name.
struct ListenToolWrapper<T, Ctx>
where
    T: ListenExecuteTool<Ctx>,
{
    inner: T,
    name_cache: String,
    // `Ctx` is otherwise only used in bounds; the marker makes it a used parameter.
    _marker: PhantomData<Ctx>,
}
impl<T, Ctx> ListenToolWrapper<T, Ctx>
where
    T: ListenExecuteTool<Ctx>,
{
    /// Wraps `tool`, caching its stringified name.
    fn new(tool: T) -> Self {
        Self {
            name_cache: tool_name_to_string(&tool.name()),
            inner: tool,
            _marker: PhantomData,
        }
    }
}
/// Forwards every [`ErasedListenTool`] method to the wrapped
/// [`ListenExecuteTool`], boxing the listen stream.
#[async_trait]
impl<T, Ctx> ErasedListenTool<Ctx> for ListenToolWrapper<T, Ctx>
where
    T: ListenExecuteTool<Ctx> + 'static,
    Ctx: Send + Sync + 'static,
{
    fn name_str(&self) -> &str {
        self.name_cache.as_str()
    }
    fn display_name(&self) -> &'static str {
        <T as ListenExecuteTool<Ctx>>::display_name(&self.inner)
    }
    fn description(&self) -> &'static str {
        <T as ListenExecuteTool<Ctx>>::description(&self.inner)
    }
    fn input_schema(&self) -> Value {
        <T as ListenExecuteTool<Ctx>>::input_schema(&self.inner)
    }
    fn tier(&self) -> ToolTier {
        <T as ListenExecuteTool<Ctx>>::tier(&self.inner)
    }
    fn listen_stream<'a>(
        &'a self,
        ctx: &'a ToolContext<Ctx>,
        input: Value,
    ) -> Pin<Box<dyn Stream<Item = ListenToolUpdate> + Send + 'a>> {
        Box::pin(<T as ListenExecuteTool<Ctx>>::listen(&self.inner, ctx, input))
    }
    async fn execute(
        &self,
        ctx: &ToolContext<Ctx>,
        operation_id: &str,
        expected_revision: u64,
    ) -> Result<ToolResult> {
        <T as ListenExecuteTool<Ctx>>::execute(&self.inner, ctx, operation_id, expected_revision)
            .await
    }
    async fn cancel(
        &self,
        ctx: &ToolContext<Ctx>,
        operation_id: &str,
        reason: ListenStopReason,
    ) -> Result<()> {
        <T as ListenExecuteTool<Ctx>>::cancel(&self.inner, ctx, operation_id, reason).await
    }
}
pub struct ToolRegistry<Ctx> {
tools: HashMap<String, Arc<dyn ErasedTool<Ctx>>>,
async_tools: HashMap<String, Arc<dyn ErasedAsyncTool<Ctx>>>,
listen_tools: HashMap<String, Arc<dyn ErasedListenTool<Ctx>>>,
}
impl<Ctx> Clone for ToolRegistry<Ctx> {
fn clone(&self) -> Self {
Self {
tools: self.tools.clone(),
async_tools: self.async_tools.clone(),
listen_tools: self.listen_tools.clone(),
}
}
}
impl<Ctx: Send + Sync + 'static> Default for ToolRegistry<Ctx> {
fn default() -> Self {
Self::new()
}
}
impl<Ctx: Send + Sync + 'static> ToolRegistry<Ctx> {
#[must_use]
pub fn new() -> Self {
Self {
tools: HashMap::new(),
async_tools: HashMap::new(),
listen_tools: HashMap::new(),
}
}
pub fn register<T>(&mut self, tool: T) -> &mut Self
where
T: Tool<Ctx> + 'static,
{
let wrapper = ToolWrapper::new(tool);
let name = wrapper.name_str().to_string();
self.tools.insert(name, Arc::new(wrapper));
self
}
pub fn register_async<T>(&mut self, tool: T) -> &mut Self
where
T: AsyncTool<Ctx> + 'static,
{
let wrapper = AsyncToolWrapper::new(tool);
let name = wrapper.name_str().to_string();
self.async_tools.insert(name, Arc::new(wrapper));
self
}
pub fn register_listen<T>(&mut self, tool: T) -> &mut Self
where
T: ListenExecuteTool<Ctx> + 'static,
{
let wrapper = ListenToolWrapper::new(tool);
let name = wrapper.name_str().to_string();
self.listen_tools.insert(name, Arc::new(wrapper));
self
}
#[must_use]
pub fn get(&self, name: &str) -> Option<&Arc<dyn ErasedTool<Ctx>>> {
self.tools.get(name)
}
#[must_use]
pub fn get_async(&self, name: &str) -> Option<&Arc<dyn ErasedAsyncTool<Ctx>>> {
self.async_tools.get(name)
}
#[must_use]
pub fn get_listen(&self, name: &str) -> Option<&Arc<dyn ErasedListenTool<Ctx>>> {
self.listen_tools.get(name)
}
#[must_use]
pub fn is_async(&self, name: &str) -> bool {
self.async_tools.contains_key(name)
}
#[must_use]
pub fn is_listen(&self, name: &str) -> bool {
self.listen_tools.contains_key(name)
}
pub fn all(&self) -> impl Iterator<Item = &Arc<dyn ErasedTool<Ctx>>> {
self.tools.values()
}
pub fn all_async(&self) -> impl Iterator<Item = &Arc<dyn ErasedAsyncTool<Ctx>>> {
self.async_tools.values()
}
pub fn all_listen(&self) -> impl Iterator<Item = &Arc<dyn ErasedListenTool<Ctx>>> {
self.listen_tools.values()
}
#[must_use]
pub fn len(&self) -> usize {
self.tools.len() + self.async_tools.len() + self.listen_tools.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.tools.is_empty() && self.async_tools.is_empty() && self.listen_tools.is_empty()
}
pub fn filter<F>(&mut self, predicate: F)
where
F: Fn(&str) -> bool,
{
self.tools.retain(|name, _| predicate(name));
self.async_tools.retain(|name, _| predicate(name));
self.listen_tools.retain(|name, _| predicate(name));
}
#[must_use]
pub fn to_llm_tools(&self) -> Vec<llm::Tool> {
let mut tools: Vec<_> = self
.tools
.values()
.map(|tool| llm::Tool {
name: tool.name_str().to_string(),
description: tool.description().to_string(),
input_schema: tool.input_schema(),
})
.collect();
tools.extend(self.async_tools.values().map(|tool| llm::Tool {
name: tool.name_str().to_string(),
description: tool.description().to_string(),
input_schema: tool.input_schema(),
}));
tools.extend(self.listen_tools.values().map(|tool| llm::Tool {
name: tool.name_str().to_string(),
description: tool.description().to_string(),
input_schema: tool.input_schema(),
}));
tools
}
}
// Unit tests for tool-name (de)serialization and `ToolRegistry` behavior.
#[cfg(test)]
mod tests {
    use super::*;

    /// Names used by the fixtures below; snake_case serialization maps
    /// `MockTool` -> "mock_tool" and `AnotherTool` -> "another_tool".
    #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
    #[serde(rename_all = "snake_case")]
    enum TestToolName {
        MockTool,
        AnotherTool,
    }
    impl ToolName for TestToolName {}

    /// `Tool<()>` fixture that echoes the `message` field of its input.
    struct MockTool;
    impl Tool<()> for MockTool {
        type Name = TestToolName;
        fn name(&self) -> TestToolName {
            TestToolName::MockTool
        }
        fn display_name(&self) -> &'static str {
            "Mock Tool"
        }
        fn description(&self) -> &'static str {
            "A mock tool for testing"
        }
        fn input_schema(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": {
                    "message": { "type": "string" }
                }
            })
        }
        async fn execute(&self, _ctx: &ToolContext<()>, input: Value) -> Result<ToolResult> {
            // Missing or non-string `message` falls back to "no message".
            let message = input
                .get("message")
                .and_then(|v| v.as_str())
                .unwrap_or("no message");
            Ok(ToolResult::success(format!("Received: {message}")))
        }
    }

    // Round-trips an enum tool name through its string form.
    #[test]
    fn test_tool_name_serialization() {
        let name = TestToolName::MockTool;
        assert_eq!(tool_name_to_string(&name), "mock_tool");
        let parsed: TestToolName = tool_name_from_str("mock_tool").unwrap();
        assert_eq!(parsed, TestToolName::MockTool);
    }

    // A transparent newtype name stringifies to its inner value.
    #[test]
    fn test_dynamic_tool_name() {
        let name = DynamicToolName::new("my_mcp_tool");
        assert_eq!(tool_name_to_string(&name), "my_mcp_tool");
        assert_eq!(name.as_str(), "my_mcp_tool");
    }

    // `Ctx` is inferred as `()` from the registered `Tool<()>` fixture.
    #[test]
    fn test_tool_registry() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);
        assert_eq!(registry.len(), 1);
        assert!(registry.get("mock_tool").is_some());
        assert!(registry.get("nonexistent").is_none());
    }

    #[test]
    fn test_to_llm_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);
        let llm_tools = registry.to_llm_tools();
        assert_eq!(llm_tools.len(), 1);
        assert_eq!(llm_tools[0].name, "mock_tool");
    }

    /// Second `Tool<()>` fixture, used to exercise `filter`.
    struct AnotherTool;
    impl Tool<()> for AnotherTool {
        type Name = TestToolName;
        fn name(&self) -> TestToolName {
            TestToolName::AnotherTool
        }
        fn display_name(&self) -> &'static str {
            "Another Tool"
        }
        fn description(&self) -> &'static str {
            "Another tool for testing"
        }
        fn input_schema(&self) -> Value {
            serde_json::json!({ "type": "object" })
        }
        async fn execute(&self, _ctx: &ToolContext<()>, _input: Value) -> Result<ToolResult> {
            Ok(ToolResult::success("Done"))
        }
    }

    #[test]
    fn test_filter_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);
        registry.register(AnotherTool);
        assert_eq!(registry.len(), 2);
        registry.filter(|name| name != "mock_tool");
        assert_eq!(registry.len(), 1);
        assert!(registry.get("mock_tool").is_none());
        assert!(registry.get("another_tool").is_some());
    }

    #[test]
    fn test_filter_tools_keep_all() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);
        registry.register(AnotherTool);
        registry.filter(|_| true);
        assert_eq!(registry.len(), 2);
    }

    #[test]
    fn test_filter_tools_remove_all() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);
        registry.register(AnotherTool);
        registry.filter(|_| false);
        assert!(registry.is_empty());
    }

    #[test]
    fn test_display_name() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool);
        let tool = registry.get("mock_tool").unwrap();
        assert_eq!(tool.display_name(), "Mock Tool");
    }

    /// `ListenExecuteTool<()>` fixture whose `listen` yields a single `Ready`
    /// update. It reuses the "mock_tool" name but is only registered in its
    /// own registry below.
    struct ListenMockTool;
    impl ListenExecuteTool<()> for ListenMockTool {
        type Name = TestToolName;
        fn name(&self) -> TestToolName {
            TestToolName::MockTool
        }
        fn display_name(&self) -> &'static str {
            "Listen Mock Tool"
        }
        fn description(&self) -> &'static str {
            "A listen/execute mock tool for testing"
        }
        fn input_schema(&self) -> Value {
            serde_json::json!({ "type": "object" })
        }
        fn listen(
            &self,
            _ctx: &ToolContext<()>,
            _input: Value,
        ) -> impl futures::Stream<Item = ListenToolUpdate> + Send {
            futures::stream::iter(vec![ListenToolUpdate::Ready {
                operation_id: "op_1".to_string(),
                revision: 1,
                message: "ready".to_string(),
                snapshot: serde_json::json!({"ok": true}),
                expires_at: None,
            }])
        }
        async fn execute(
            &self,
            _ctx: &ToolContext<()>,
            _operation_id: &str,
            _expected_revision: u64,
        ) -> Result<ToolResult> {
            Ok(ToolResult::success("Executed"))
        }
    }

    #[test]
    fn test_listen_tool_registry() {
        let mut registry = ToolRegistry::new();
        registry.register_listen(ListenMockTool);
        assert_eq!(registry.len(), 1);
        assert!(registry.get_listen("mock_tool").is_some());
        assert!(registry.is_listen("mock_tool"));
    }
}