1pub mod executor;
4pub mod helpers;
5pub mod loader;
6pub mod mcp;
7pub mod mcp_instrumentation;
8pub mod registry;
9pub mod structured_output;
10pub mod validator;
11pub mod watcher;
12
13use std::pin::Pin;
14
15use futures::Stream;
16
17use async_trait::async_trait;
18
19use crate::types::tools::{ToolResult, ToolResultContent, ToolResultStatus, ToolSpec, ToolUse};
20
/// Pinned, boxed, `Send` stream of [`ToolEvent`]s emitted while a tool runs.
pub type ToolEventStream = Pin<Box<dyn Stream<Item = ToolEvent> + Send>>;

/// Same type as [`ToolEventStream`]; the return type of [`tool_to_stream`] and
/// [`execute_tool_stream`].
pub type ToolGenerator = ToolEventStream;
26
/// An event produced on a [`ToolEventStream`] during a tool invocation.
#[derive(Debug, Clone)]
pub enum ToolEvent {
    /// Textual progress message.
    Progress { message: String },
    /// Arbitrary JSON payload streamed by the tool.
    Stream(serde_json::Value),
    /// The tool's result; [`tool_to_stream`] yields exactly one of these.
    Result(ToolResult),
    /// Interrupt identified by `id` carrying tool-defined `data`.
    /// NOTE(review): how consumers react to this is not visible here — confirm.
    Interrupt { id: String, data: serde_json::Value },
}
39
40impl ToolEvent {
41 pub fn progress(message: impl Into<String>) -> Self {
42 Self::Progress { message: message.into() }
43 }
44
45 pub fn stream(data: serde_json::Value) -> Self { Self::Stream(data) }
46 pub fn result(result: ToolResult) -> Self { Self::Result(result) }
47 pub fn is_result(&self) -> bool { matches!(self, Self::Result(_)) }
48
49 pub fn as_result(&self) -> Option<&ToolResult> {
50 match self {
51 Self::Result(r) => Some(r),
52 _ => None,
53 }
54 }
55}
56
57#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
59pub struct InvocationState {
60 pub data: std::collections::HashMap<String, serde_json::Value>,
61 #[serde(default)]
62 pub stop_event_loop: bool,
63}
64
65impl InvocationState {
66 pub fn new() -> Self { Self::default() }
67
68 pub fn get<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
69 self.data.get(key).and_then(|v| T::deserialize(v).ok())
70 }
71
72 pub fn set(&mut self, key: impl Into<String>, value: impl serde::Serialize) {
73 if let Ok(v) = serde_json::to_value(value) {
74 self.data.insert(key.into(), v);
75 }
76 }
77}
78
/// Per-call context passed to [`AgentTool::invoke`].
#[derive(Debug, Clone, Default)]
pub struct ToolContext {
    /// State shared with the tool for this invocation.
    pub invocation_state: InvocationState,
    /// Optional interrupt identifier; `None` for contexts built by `new`/`with_state`.
    /// NOTE(review): the code that sets this is not visible here — confirm.
    pub interrupt_id: Option<uuid::Uuid>,
}
85
86impl ToolContext {
87 pub fn new() -> Self { Self::default() }
88
89 pub fn with_state(state: InvocationState) -> Self {
90 Self { invocation_state: state, interrupt_id: None }
91 }
92}
93
/// Outcome returned by [`AgentTool::invoke`]: a status plus content blocks.
///
/// Unlike [`ToolResult`] it carries no `tool_use_id`; [`tool_to_stream`] adds one
/// when converting it into a [`ToolResult`].
#[derive(Debug, Clone)]
pub struct ToolResult2 {
    /// Success or error status of the invocation.
    pub status: ToolResultStatus,
    /// Content blocks produced by the tool.
    pub content: Vec<ToolResultContent>,
}
100
101impl ToolResult2 {
102 pub fn success(content: impl Into<String>) -> Self {
103 Self {
104 status: ToolResultStatus::Success,
105 content: vec![ToolResultContent::text(content.into())],
106 }
107 }
108
109 pub fn success_json(value: serde_json::Value) -> Self {
110 Self {
111 status: ToolResultStatus::Success,
112 content: vec![ToolResultContent::json(value)],
113 }
114 }
115
116 pub fn error(message: impl Into<String>) -> Self {
117 Self {
118 status: ToolResultStatus::Error,
119 content: vec![ToolResultContent::text(message.into())],
120 }
121 }
122}
123
124#[async_trait]
126pub trait AgentTool: Send + Sync {
127 fn name(&self) -> &str;
129
130 fn description(&self) -> &str;
132
133 fn tool_spec(&self) -> ToolSpec;
135
136 async fn invoke(
138 &self,
139 input: serde_json::Value,
140 context: &ToolContext,
141 ) -> std::result::Result<ToolResult2, String>;
142
143 fn tool_name(&self) -> &str { self.name() }
145
146 fn tool_type(&self) -> &str { "function" }
148
149 fn supports_hot_reload(&self) -> bool { false }
151
152 fn is_dynamic(&self) -> bool { false }
154
155 fn get_display_properties(&self) -> std::collections::HashMap<String, String> {
157 let mut props = std::collections::HashMap::new();
158 props.insert("Name".to_string(), self.name().to_string());
159 props.insert("Type".to_string(), self.tool_type().to_string());
160 props
161 }
162}
163
164pub fn tool_to_stream(
166 tool: std::sync::Arc<dyn AgentTool>,
167 tool_use: ToolUse,
168 invocation_state: InvocationState,
169) -> ToolGenerator {
170 let input = tool_use.input.clone();
171 let tool_use_id = tool_use.tool_use_id.clone();
172 let context = ToolContext::with_state(invocation_state);
173
174 Box::pin(async_stream::stream! {
175 let result = match tool.invoke(input, &context).await {
176 Ok(r) => ToolResult {
177 tool_use_id,
178 status: r.status,
179 content: r.content,
180 },
181 Err(e) => ToolResult {
182 tool_use_id,
183 status: ToolResultStatus::Error,
184 content: vec![ToolResultContent::text(e)],
185 },
186 };
187 yield ToolEvent::Result(result);
188 })
189}
190
/// An [`AgentTool`] that additionally exposes a mutating `mark_dynamic` hook.
pub trait DynamicAgentTool: AgentTool {
    /// Marks this tool as dynamic.
    /// NOTE(review): presumably implementations make `AgentTool::is_dynamic` return
    /// `true` afterwards — confirm against implementors.
    fn mark_dynamic(&mut self);
}
196
/// Executes `tool` for `tool_use` and returns its event stream.
///
/// Delegates unchanged to [`tool_to_stream`].
pub fn execute_tool_stream(
    tool: std::sync::Arc<dyn AgentTool>,
    tool_use: ToolUse,
    invocation_state: InvocationState,
) -> ToolGenerator {
    tool_to_stream(tool, tool_use, invocation_state)
}
205
206pub use loader::{ReloadCallback, ToolLoader, ToolLoaderConfig, ToolWatcher};
207pub use mcp::{
208 ConnectionState, MCPClient, MCPImageContent, MCPImageSource, MCPResultContent,
209 MCPServerConfig, MCPToolResult, MCPToolSpec, MCPTransport, ToolFilters, ToolProvider,
210};
211pub use registry::{ToolInput, ToolRegistry};
212pub use structured_output::{
213 flatten_schema, get_required_fields, process_schema_for_optional_fields, schema_to_tool_spec,
214 structured_output_spec, validate_against_schema, StructuredOutputContext, StructuredOutputResult,
215 StructuredOutputTool,
216};
217pub use helpers::{
218 generate_cancelled_tool_result, generate_missing_tool_result,
219 generate_missing_tool_result_content, generate_missing_tool_results_for_message,
220 generate_timeout_tool_result, noop_tool, noop_tool_with, NoopTool,
221};
222pub use validator::{
223 is_valid_tool_name, sanitize_tool_name, validate_and_prepare_tools, validate_tool_spec,
224 validate_tool_specs, validate_tool_use, validate_tool_uses, ToolUseValidationResult,
225 MAX_TOOL_NAME_LENGTH, MIN_TOOL_NAME_LENGTH,
226};
227pub use mcp_instrumentation::{
228 create_mcp_tool_span, extract_trace_context, init_mcp_instrumentation, inject_trace_context,
229 is_instrumentation_applied, ExtractableContext, InjectableContext, ItemWithContext,
230 MCPInstrumentationConfig, InstrumentationGuard,
231};
232
233
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;
    use std::sync::Arc;

    /// Tool whose `invoke` always succeeds with a fixed text result.
    struct TestTool;

    #[async_trait]
    impl AgentTool for TestTool {
        fn name(&self) -> &str {
            "test_tool"
        }

        fn description(&self) -> &str {
            "A test tool"
        }

        fn tool_spec(&self) -> ToolSpec {
            ToolSpec::new("test_tool", "A test tool")
        }

        async fn invoke(
            &self,
            _input: serde_json::Value,
            _context: &ToolContext,
        ) -> std::result::Result<ToolResult2, String> {
            Ok(ToolResult2::success("Test result"))
        }
    }

    /// Tool whose `invoke` always fails, to exercise the `Err` branch of `tool_to_stream`.
    struct FailingTool;

    #[async_trait]
    impl AgentTool for FailingTool {
        fn name(&self) -> &str {
            "failing_tool"
        }

        fn description(&self) -> &str {
            "A tool that always fails"
        }

        fn tool_spec(&self) -> ToolSpec {
            ToolSpec::new("failing_tool", "A tool that always fails")
        }

        async fn invoke(
            &self,
            _input: serde_json::Value,
            _context: &ToolContext,
        ) -> std::result::Result<ToolResult2, String> {
            Err("boom".to_string())
        }
    }

    #[tokio::test]
    async fn test_tool_execution() {
        let tool: Arc<dyn AgentTool> = Arc::new(TestTool);
        let tool_use = ToolUse::new("test_tool", "123", serde_json::json!({}));
        let mut stream = tool_to_stream(tool, tool_use, InvocationState::new());

        // `expect` instead of `if let`: a stream that yields nothing must fail the test.
        let event = stream.next().await.expect("stream should yield a result event");
        assert!(event.is_result());
        assert!(event.as_result().expect("event should be a Result").is_success());
        assert!(stream.next().await.is_none(), "stream should end after the result");
    }

    #[tokio::test]
    async fn test_tool_error_becomes_error_result() {
        let tool: Arc<dyn AgentTool> = Arc::new(FailingTool);
        let tool_use = ToolUse::new("failing_tool", "456", serde_json::json!({}));
        let mut stream = tool_to_stream(tool, tool_use, InvocationState::new());

        let event = stream.next().await.expect("stream should yield a result event");
        let result = event.as_result().expect("event should be a Result");
        assert!(!result.is_success());
        assert!(stream.next().await.is_none(), "stream should end after the result");
    }

    #[test]
    fn test_invocation_state_get_set() {
        let mut state = InvocationState::new();
        state.set("count", 5_i64);
        assert_eq!(state.get::<i64>("count"), Some(5));
        assert_eq!(state.get::<i64>("missing"), None);
        // A type mismatch is reported as `None`, not a panic.
        assert_eq!(state.get::<String>("count"), None);
    }
}