pub mod executor;
pub mod helpers;
pub mod loader;
pub mod mcp;
pub mod mcp_instrumentation;
pub mod registry;
pub mod structured_output;
pub mod validator;
pub mod watcher;
use std::pin::Pin;
use futures::Stream;
use async_trait::async_trait;
use crate::types::tools::{ToolResult, ToolResultContent, ToolResultStatus, ToolSpec, ToolUse};
/// Pinned, boxed, `Send` stream of [`ToolEvent`]s emitted while a tool runs.
///
/// Boxing erases the concrete stream type so that heterogeneous tools can
/// return the same type; pinning lets it be polled directly.
pub type ToolEventStream = Pin<Box<dyn Stream<Item = ToolEvent> + Send>>;

/// Alternate name for [`ToolEventStream`]; both names denote the same type.
pub type ToolGenerator = ToolEventStream;
/// A single event yielded by a tool's event stream.
#[derive(Clone, Debug)]
pub enum ToolEvent {
    /// Human-readable progress message.
    Progress {
        /// The progress text.
        message: String,
    },
    /// Intermediate streamed payload, as arbitrary JSON.
    Stream(serde_json::Value),
    /// Final result of the tool invocation.
    Result(ToolResult),
    /// Interrupt request identified by `id`, carrying JSON `data`.
    /// NOTE(review): how interrupts are handled is not visible in this file.
    Interrupt {
        /// Interrupt identifier.
        id: String,
        /// Payload attached to the interrupt.
        data: serde_json::Value,
    },
}
impl ToolEvent {
pub fn progress(message: impl Into<String>) -> Self {
Self::Progress { message: message.into() }
}
pub fn stream(data: serde_json::Value) -> Self { Self::Stream(data) }
pub fn result(result: ToolResult) -> Self { Self::Result(result) }
pub fn is_result(&self) -> bool { matches!(self, Self::Result(_)) }
pub fn as_result(&self) -> Option<&ToolResult> {
match self {
Self::Result(r) => Some(r),
_ => None,
}
}
}
/// Mutable, serializable state carried through a single tool invocation.
///
/// Every field falls back to its default when absent from serialized input,
/// so partial documents such as `{}` or `{"stop_event_loop": true}` deserialize.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct InvocationState {
    /// Arbitrary key/value data, stored as JSON values.
    // Previously only `stop_event_loop` had `#[serde(default)]`, so any input
    // lacking a `data` key failed to deserialize; make the two fields consistent.
    #[serde(default)]
    pub data: std::collections::HashMap<String, serde_json::Value>,
    /// Flag defaulting to `false`.
    /// NOTE(review): presumably read by the event loop to halt after this
    /// invocation — confirm against the consumer.
    #[serde(default)]
    pub stop_event_loop: bool,
}
impl InvocationState {
    /// Creates an empty state (no data, `stop_event_loop == false`).
    pub fn new() -> Self {
        Default::default()
    }

    /// Looks up `key` and deserializes its JSON value into `T`.
    ///
    /// Returns `None` when the key is absent or the stored value cannot be
    /// deserialized as `T`.
    pub fn get<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
        let stored = self.data.get(key)?;
        T::deserialize(stored).ok()
    }

    /// Serializes `value` to JSON and stores it under `key`.
    ///
    /// If serialization fails the call does nothing: no entry is inserted and
    /// any value already stored under `key` is left untouched.
    pub fn set(&mut self, key: impl Into<String>, value: impl serde::Serialize) {
        if let Ok(json) = serde_json::to_value(value) {
            let owned_key: String = key.into();
            self.data.insert(owned_key, json);
        }
    }
}
/// Context handed to [`AgentTool::invoke`].
#[derive(Clone, Debug, Default)]
pub struct ToolContext {
    /// State shared across the current invocation.
    pub invocation_state: InvocationState,
    /// Optional interrupt identifier (`None` by default).
    /// NOTE(review): its meaning is not established in this file.
    pub interrupt_id: Option<uuid::Uuid>,
}
impl ToolContext {
pub fn new() -> Self { Self::default() }
pub fn with_state(state: InvocationState) -> Self {
Self { invocation_state: state, interrupt_id: None }
}
}
/// Outcome returned by [`AgentTool::invoke`], before a `tool_use_id` is attached.
///
/// [`tool_to_stream`] copies `status` and `content` into a [`ToolResult`]
/// together with the originating tool-use id.
#[derive(Clone, Debug)]
pub struct ToolResult2 {
    /// Success or error status of the invocation.
    pub status: ToolResultStatus,
    /// Content blocks produced by the tool.
    pub content: Vec<ToolResultContent>,
}
impl ToolResult2 {
pub fn success(content: impl Into<String>) -> Self {
Self {
status: ToolResultStatus::Success,
content: vec![ToolResultContent::text(content.into())],
}
}
pub fn success_json(value: serde_json::Value) -> Self {
Self {
status: ToolResultStatus::Success,
content: vec![ToolResultContent::json(value)],
}
}
pub fn error(message: impl Into<String>) -> Self {
Self {
status: ToolResultStatus::Error,
content: vec![ToolResultContent::text(message.into())],
}
}
}
/// A tool that an agent can invoke asynchronously.
///
/// Implementors supply identity (`name`, `description`, `tool_spec`) and the
/// async `invoke` body; the remaining methods have defaults.
#[async_trait]
pub trait AgentTool: Send + Sync {
    /// Identifier of the tool.
    fn name(&self) -> &str;

    /// Human-readable description of the tool.
    fn description(&self) -> &str;

    /// Specification describing the tool.
    fn tool_spec(&self) -> ToolSpec;

    /// Runs the tool on `input` using `context`.
    ///
    /// # Errors
    /// Returns `Err(message)` on failure; [`tool_to_stream`] turns that into a
    /// [`ToolResult`] with [`ToolResultStatus::Error`].
    async fn invoke(
        &self,
        input: serde_json::Value,
        context: &ToolContext,
    ) -> std::result::Result<ToolResult2, String>;

    /// Defaults to [`AgentTool::name`].
    fn tool_name(&self) -> &str {
        self.name()
    }

    /// Kind of tool; defaults to `"function"`.
    fn tool_type(&self) -> &str {
        "function"
    }

    /// Defaults to `false`.
    /// NOTE(review): presumably consulted by the loader/watcher — confirm.
    fn supports_hot_reload(&self) -> bool {
        false
    }

    /// Defaults to `false`.
    fn is_dynamic(&self) -> bool {
        false
    }

    /// Key/value pairs for display, containing `"Name"` and `"Type"`.
    fn get_display_properties(&self) -> std::collections::HashMap<String, String> {
        [("Name", self.name()), ("Type", self.tool_type())]
            .into_iter()
            .map(|(label, value)| (label.to_string(), value.to_string()))
            .collect()
    }
}
pub fn tool_to_stream(
tool: std::sync::Arc<dyn AgentTool>,
tool_use: ToolUse,
invocation_state: InvocationState,
) -> ToolGenerator {
let input = tool_use.input.clone();
let tool_use_id = tool_use.tool_use_id.clone();
let context = ToolContext::with_state(invocation_state);
Box::pin(async_stream::stream! {
let result = match tool.invoke(input, &context).await {
Ok(r) => ToolResult {
tool_use_id,
status: r.status,
content: r.content,
},
Err(e) => ToolResult {
tool_use_id,
status: ToolResultStatus::Error,
content: vec![ToolResultContent::text(e)],
},
};
yield ToolEvent::Result(result);
})
}
/// An [`AgentTool`] that can be flagged as dynamic after construction.
pub trait DynamicAgentTool: AgentTool {
    /// Marks this tool as dynamic.
    /// NOTE(review): presumably implementors make `is_dynamic` return `true`
    /// afterwards — this is not enforced by the trait; confirm.
    fn mark_dynamic(&mut self);
}
pub fn execute_tool_stream(
tool: std::sync::Arc<dyn AgentTool>,
tool_use: ToolUse,
invocation_state: InvocationState,
) -> ToolGenerator {
tool_to_stream(tool, tool_use, invocation_state)
}
pub use loader::{ReloadCallback, ToolLoader, ToolLoaderConfig, ToolWatcher};
pub use mcp::{
ConnectionState, MCPClient, MCPImageContent, MCPImageSource, MCPResultContent,
MCPServerConfig, MCPToolResult, MCPToolSpec, MCPTransport, ToolFilters, ToolProvider,
};
pub use registry::{ToolInput, ToolRegistry};
pub use structured_output::{
flatten_schema, get_required_fields, process_schema_for_optional_fields, schema_to_tool_spec,
structured_output_spec, validate_against_schema, StructuredOutputContext, StructuredOutputResult,
StructuredOutputTool,
};
pub use helpers::{
generate_cancelled_tool_result, generate_missing_tool_result,
generate_missing_tool_result_content, generate_missing_tool_results_for_message,
generate_timeout_tool_result, noop_tool, noop_tool_with, NoopTool,
};
pub use validator::{
is_valid_tool_name, sanitize_tool_name, validate_and_prepare_tools, validate_tool_spec,
validate_tool_specs, validate_tool_use, validate_tool_uses, ToolUseValidationResult,
MAX_TOOL_NAME_LENGTH, MIN_TOOL_NAME_LENGTH,
};
pub use mcp_instrumentation::{
create_mcp_tool_span, extract_trace_context, init_mcp_instrumentation, inject_trace_context,
is_instrumentation_applied, ExtractableContext, InjectableContext, ItemWithContext,
MCPInstrumentationConfig, InstrumentationGuard,
};
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Tool that always succeeds with a fixed text result.
    struct TestTool;

    #[async_trait]
    impl AgentTool for TestTool {
        fn name(&self) -> &str {
            "test_tool"
        }
        fn description(&self) -> &str {
            "A test tool"
        }
        fn tool_spec(&self) -> ToolSpec {
            ToolSpec::new("test_tool", "A test tool")
        }
        async fn invoke(
            &self,
            _input: serde_json::Value,
            _context: &ToolContext,
        ) -> std::result::Result<ToolResult2, String> {
            Ok(ToolResult2::success("Test result"))
        }
    }

    /// Tool whose invocation always returns `Err`, exercising the error path.
    struct FailingTool;

    #[async_trait]
    impl AgentTool for FailingTool {
        fn name(&self) -> &str {
            "failing_tool"
        }
        fn description(&self) -> &str {
            "A tool that always fails"
        }
        fn tool_spec(&self) -> ToolSpec {
            ToolSpec::new("failing_tool", "A tool that always fails")
        }
        async fn invoke(
            &self,
            _input: serde_json::Value,
            _context: &ToolContext,
        ) -> std::result::Result<ToolResult2, String> {
            Err("boom".to_string())
        }
    }

    #[tokio::test]
    async fn test_tool_execution() {
        use futures::StreamExt;
        let tool: Arc<dyn AgentTool> = Arc::new(TestTool);
        let tool_use = ToolUse::new("test_tool", "123", serde_json::json!({}));
        let state = InvocationState::new();
        let mut stream = tool_to_stream(tool, tool_use, state);
        // Require an event: the previous `if let` passed silently on an empty stream.
        let event = stream
            .next()
            .await
            .expect("tool stream should yield a result event");
        assert!(event.is_result());
        let result = event.as_result().expect("event should be a Result");
        assert!(result.is_success());
        // `tool_to_stream` yields exactly one event, then ends.
        assert!(
            stream.next().await.is_none(),
            "tool stream should end after the result event"
        );
    }

    #[tokio::test]
    async fn test_tool_error_becomes_error_result() {
        use futures::StreamExt;
        let tool: Arc<dyn AgentTool> = Arc::new(FailingTool);
        let tool_use = ToolUse::new("failing_tool", "456", serde_json::json!({}));
        let mut stream = tool_to_stream(tool, tool_use, InvocationState::new());
        let event = stream
            .next()
            .await
            .expect("tool stream should yield a result event even on failure");
        let result = event.as_result().expect("event should be a Result");
        assert!(!result.is_success());
        assert!(stream.next().await.is_none());
    }
}