Skip to main content

tower_mcp/
tool.rs

//! Tool definition and builder API
//!
//! Provides ergonomic ways to define MCP tools:
//!
//! 1. **Builder pattern** - Fluent API for defining tools
//! 2. **Trait-based** - Implement `McpTool` for full control
//! 3. **Function-based** - Quick tools from async functions or closures
8
9use std::future::Future;
10use std::pin::Pin;
11use std::sync::Arc;
12
13use schemars::JsonSchema;
14use serde::Serialize;
15use serde::de::DeserializeOwned;
16use serde_json::Value;
17
18use crate::context::RequestContext;
19use crate::error::{Error, Result};
20use crate::protocol::{CallToolResult, ToolAnnotations, ToolDefinition, ToolIcon};
21
22/// Validate a tool name according to MCP spec.
23///
24/// Tool names must be:
25/// - 1-128 characters long
26/// - Contain only alphanumeric characters, underscores, hyphens, and dots
27///
28/// Returns `Ok(())` if valid, `Err` with description if invalid.
29pub fn validate_tool_name(name: &str) -> Result<()> {
30    if name.is_empty() {
31        return Err(Error::tool("Tool name cannot be empty"));
32    }
33    if name.len() > 128 {
34        return Err(Error::tool(format!(
35            "Tool name '{}' exceeds maximum length of 128 characters (got {})",
36            name,
37            name.len()
38        )));
39    }
40    if let Some(invalid_char) = name
41        .chars()
42        .find(|c| !c.is_ascii_alphanumeric() && *c != '_' && *c != '-' && *c != '.')
43    {
44        return Err(Error::tool(format!(
45            "Tool name '{}' contains invalid character '{}'. Only alphanumeric, underscore, hyphen, and dot are allowed.",
46            name, invalid_char
47        )));
48    }
49    Ok(())
50}
51
/// Heap-allocated, type-erased future produced by tool handlers.
///
/// It is `Send`, so it can be driven from any executor thread, and it may borrow
/// data for the lifetime `'a`.
pub type BoxFuture<'a, T> =
    std::pin::Pin<Box<dyn std::future::Future<Output = T> + Send + 'a>>;
54
55/// Tool handler trait - the core abstraction for tool execution
56pub trait ToolHandler: Send + Sync {
57    /// Execute the tool with the given arguments
58    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>>;
59
60    /// Execute the tool with request context for progress/cancellation support
61    ///
62    /// The default implementation ignores the context and calls `call`.
63    /// Override this to receive progress/cancellation context.
64    fn call_with_context(
65        &self,
66        _ctx: RequestContext,
67        args: Value,
68    ) -> BoxFuture<'_, Result<CallToolResult>> {
69        self.call(args)
70    }
71
72    /// Returns true if this handler uses context (for optimization)
73    fn uses_context(&self) -> bool {
74        false
75    }
76
77    /// Get the tool's input schema
78    fn input_schema(&self) -> Value;
79}
80
81/// A complete tool definition with handler
82pub struct Tool {
83    pub name: String,
84    pub title: Option<String>,
85    pub description: Option<String>,
86    pub output_schema: Option<Value>,
87    pub icons: Option<Vec<ToolIcon>>,
88    pub annotations: Option<ToolAnnotations>,
89    handler: Arc<dyn ToolHandler>,
90}
91
92impl std::fmt::Debug for Tool {
93    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94        f.debug_struct("Tool")
95            .field("name", &self.name)
96            .field("title", &self.title)
97            .field("description", &self.description)
98            .field("output_schema", &self.output_schema)
99            .field("icons", &self.icons)
100            .field("annotations", &self.annotations)
101            .finish_non_exhaustive()
102    }
103}
104
105impl Tool {
106    /// Create a new tool builder
107    pub fn builder(name: impl Into<String>) -> ToolBuilder {
108        ToolBuilder::new(name)
109    }
110
111    /// Get the tool definition for tools/list
112    pub fn definition(&self) -> ToolDefinition {
113        ToolDefinition {
114            name: self.name.clone(),
115            title: self.title.clone(),
116            description: self.description.clone(),
117            input_schema: self.handler.input_schema(),
118            output_schema: self.output_schema.clone(),
119            icons: self.icons.clone(),
120            annotations: self.annotations.clone(),
121        }
122    }
123
124    /// Call the tool without context
125    pub fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
126        self.handler.call(args)
127    }
128
129    /// Call the tool with request context
130    ///
131    /// Use this when you have a RequestContext available for progress/cancellation.
132    pub fn call_with_context(
133        &self,
134        ctx: RequestContext,
135        args: Value,
136    ) -> BoxFuture<'_, Result<CallToolResult>> {
137        self.handler.call_with_context(ctx, args)
138    }
139
140    /// Returns true if this tool uses context
141    pub fn uses_context(&self) -> bool {
142        self.handler.uses_context()
143    }
144}
145
146// =============================================================================
147// Builder API
148// =============================================================================
149
150/// Builder for creating tools with a fluent API
151///
152/// # Example
153///
154/// ```rust
155/// use tower_mcp::{ToolBuilder, CallToolResult};
156/// use schemars::JsonSchema;
157/// use serde::Deserialize;
158///
159/// #[derive(Debug, Deserialize, JsonSchema)]
160/// struct GreetInput {
161///     name: String,
162/// }
163///
164/// let tool = ToolBuilder::new("greet")
165///     .description("Greet someone by name")
166///     .handler(|input: GreetInput| async move {
167///         Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
168///     })
169///     .build()
170///     .expect("valid tool name");
171///
172/// assert_eq!(tool.name, "greet");
173/// ```
174pub struct ToolBuilder {
175    name: String,
176    title: Option<String>,
177    description: Option<String>,
178    output_schema: Option<Value>,
179    icons: Option<Vec<ToolIcon>>,
180    annotations: Option<ToolAnnotations>,
181}
182
183impl ToolBuilder {
184    pub fn new(name: impl Into<String>) -> Self {
185        Self {
186            name: name.into(),
187            title: None,
188            description: None,
189            output_schema: None,
190            icons: None,
191            annotations: None,
192        }
193    }
194
195    /// Set a human-readable title for the tool
196    pub fn title(mut self, title: impl Into<String>) -> Self {
197        self.title = Some(title.into());
198        self
199    }
200
201    /// Set the output schema (JSON Schema for structured output)
202    pub fn output_schema(mut self, schema: Value) -> Self {
203        self.output_schema = Some(schema);
204        self
205    }
206
207    /// Add an icon for the tool
208    pub fn icon(mut self, src: impl Into<String>) -> Self {
209        self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
210            src: src.into(),
211            mime_type: None,
212            sizes: None,
213        });
214        self
215    }
216
217    /// Add an icon with metadata
218    pub fn icon_with_meta(
219        mut self,
220        src: impl Into<String>,
221        mime_type: Option<String>,
222        sizes: Option<Vec<String>>,
223    ) -> Self {
224        self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
225            src: src.into(),
226            mime_type,
227            sizes,
228        });
229        self
230    }
231
232    /// Set the tool description
233    pub fn description(mut self, description: impl Into<String>) -> Self {
234        self.description = Some(description.into());
235        self
236    }
237
238    /// Mark the tool as read-only (does not modify state)
239    pub fn read_only(mut self) -> Self {
240        self.annotations
241            .get_or_insert_with(ToolAnnotations::default)
242            .read_only_hint = true;
243        self
244    }
245
246    /// Mark the tool as non-destructive
247    pub fn non_destructive(mut self) -> Self {
248        self.annotations
249            .get_or_insert_with(ToolAnnotations::default)
250            .destructive_hint = false;
251        self
252    }
253
254    /// Mark the tool as idempotent (same args = same effect)
255    pub fn idempotent(mut self) -> Self {
256        self.annotations
257            .get_or_insert_with(ToolAnnotations::default)
258            .idempotent_hint = true;
259        self
260    }
261
262    /// Set tool annotations directly
263    pub fn annotations(mut self, annotations: ToolAnnotations) -> Self {
264        self.annotations = Some(annotations);
265        self
266    }
267
268    /// Specify input type and handler.
269    ///
270    /// The input type must implement `JsonSchema` and `DeserializeOwned`.
271    /// The handler receives the deserialized input and returns a `CallToolResult`.
272    ///
273    /// # State Sharing
274    ///
275    /// To share state across tool calls (e.g., database connections, API clients),
276    /// wrap your state in an `Arc` and clone it into the async block:
277    ///
278    /// ```rust
279    /// use std::sync::Arc;
280    /// use tower_mcp::{ToolBuilder, CallToolResult};
281    /// use schemars::JsonSchema;
282    /// use serde::Deserialize;
283    ///
284    /// struct AppState {
285    ///     api_key: String,
286    /// }
287    ///
288    /// #[derive(Debug, Deserialize, JsonSchema)]
289    /// struct MyInput {
290    ///     query: String,
291    /// }
292    ///
293    /// let state = Arc::new(AppState { api_key: "secret".to_string() });
294    ///
295    /// let tool = ToolBuilder::new("my_tool")
296    ///     .description("A tool that uses shared state")
297    ///     .handler(move |input: MyInput| {
298    ///         let state = state.clone(); // Clone Arc for the async block
299    ///         async move {
300    ///             // Use state.api_key here...
301    ///             Ok(CallToolResult::text(format!("Query: {}", input.query)))
302    ///         }
303    ///     })
304    ///     .build()
305    ///     .expect("valid tool name");
306    /// ```
307    ///
308    /// The `move` keyword on the closure captures the `Arc<AppState>`, and
309    /// cloning it inside the closure body allows each async invocation to
310    /// have its own reference to the shared state.
311    pub fn handler<I, F, Fut>(self, handler: F) -> ToolBuilderWithHandler<I, F>
312    where
313        I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
314        F: Fn(I) -> Fut + Send + Sync + 'static,
315        Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
316    {
317        ToolBuilderWithHandler {
318            name: self.name,
319            title: self.title,
320            description: self.description,
321            output_schema: self.output_schema,
322            icons: self.icons,
323            annotations: self.annotations,
324            handler,
325            _phantom: std::marker::PhantomData,
326        }
327    }
328
329    /// Specify input type and context-aware handler
330    ///
331    /// The handler receives a `RequestContext` for progress reporting and
332    /// cancellation checking, along with the deserialized input.
333    ///
334    /// # Example
335    ///
336    /// ```rust
337    /// use tower_mcp::{ToolBuilder, CallToolResult, RequestContext};
338    /// use schemars::JsonSchema;
339    /// use serde::Deserialize;
340    ///
341    /// #[derive(Debug, Deserialize, JsonSchema)]
342    /// struct ProcessInput {
343    ///     items: Vec<String>,
344    /// }
345    ///
346    /// let tool = ToolBuilder::new("process")
347    ///     .description("Process items with progress")
348    ///     .handler_with_context(|ctx: RequestContext, input: ProcessInput| async move {
349    ///         for (i, item) in input.items.iter().enumerate() {
350    ///             if ctx.is_cancelled() {
351    ///                 return Ok(CallToolResult::error("Cancelled"));
352    ///             }
353    ///             ctx.report_progress(i as f64, Some(input.items.len() as f64), Some("Processing...")).await;
354    ///             // Process item...
355    ///         }
356    ///         Ok(CallToolResult::text("Done"))
357    ///     })
358    ///     .build()
359    ///     .expect("valid tool name");
360    /// ```
361    pub fn handler_with_context<I, F, Fut>(self, handler: F) -> ToolBuilderWithContextHandler<I, F>
362    where
363        I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
364        F: Fn(RequestContext, I) -> Fut + Send + Sync + 'static,
365        Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
366    {
367        ToolBuilderWithContextHandler {
368            name: self.name,
369            title: self.title,
370            description: self.description,
371            output_schema: self.output_schema,
372            icons: self.icons,
373            annotations: self.annotations,
374            handler,
375            _phantom: std::marker::PhantomData,
376        }
377    }
378
379    /// Create a tool with raw JSON handling (no automatic deserialization)
380    ///
381    /// Returns an error if the tool name is invalid.
382    pub fn raw_handler<F, Fut>(self, handler: F) -> Result<Tool>
383    where
384        F: Fn(Value) -> Fut + Send + Sync + 'static,
385        Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
386    {
387        validate_tool_name(&self.name)?;
388        Ok(Tool {
389            name: self.name,
390            title: self.title,
391            description: self.description,
392            output_schema: self.output_schema,
393            icons: self.icons,
394            annotations: self.annotations,
395            handler: Arc::new(RawHandler { handler }),
396        })
397    }
398}
399
400/// Builder state after handler is specified
401pub struct ToolBuilderWithHandler<I, F> {
402    name: String,
403    title: Option<String>,
404    description: Option<String>,
405    output_schema: Option<Value>,
406    icons: Option<Vec<ToolIcon>>,
407    annotations: Option<ToolAnnotations>,
408    handler: F,
409    _phantom: std::marker::PhantomData<I>,
410}
411
412impl<I, F, Fut> ToolBuilderWithHandler<I, F>
413where
414    I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
415    F: Fn(I) -> Fut + Send + Sync + 'static,
416    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
417{
418    /// Build the tool
419    ///
420    /// Returns an error if the tool name is invalid.
421    pub fn build(self) -> Result<Tool> {
422        validate_tool_name(&self.name)?;
423        Ok(Tool {
424            name: self.name,
425            title: self.title,
426            description: self.description,
427            output_schema: self.output_schema,
428            icons: self.icons,
429            annotations: self.annotations,
430            handler: Arc::new(TypedHandler {
431                handler: self.handler,
432                _phantom: std::marker::PhantomData,
433            }),
434        })
435    }
436}
437
438/// Builder state after context-aware handler is specified
439pub struct ToolBuilderWithContextHandler<I, F> {
440    name: String,
441    title: Option<String>,
442    description: Option<String>,
443    output_schema: Option<Value>,
444    icons: Option<Vec<ToolIcon>>,
445    annotations: Option<ToolAnnotations>,
446    handler: F,
447    _phantom: std::marker::PhantomData<I>,
448}
449
450impl<I, F, Fut> ToolBuilderWithContextHandler<I, F>
451where
452    I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
453    F: Fn(RequestContext, I) -> Fut + Send + Sync + 'static,
454    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
455{
456    /// Build the tool
457    ///
458    /// Returns an error if the tool name is invalid.
459    pub fn build(self) -> Result<Tool> {
460        validate_tool_name(&self.name)?;
461        Ok(Tool {
462            name: self.name,
463            title: self.title,
464            description: self.description,
465            output_schema: self.output_schema,
466            icons: self.icons,
467            annotations: self.annotations,
468            handler: Arc::new(ContextAwareHandler {
469                handler: self.handler,
470                _phantom: std::marker::PhantomData,
471            }),
472        })
473    }
474}
475
476// =============================================================================
477// Handler implementations
478// =============================================================================
479
480/// Handler that deserializes input to a specific type
481struct TypedHandler<I, F> {
482    handler: F,
483    _phantom: std::marker::PhantomData<I>,
484}
485
486impl<I, F, Fut> ToolHandler for TypedHandler<I, F>
487where
488    I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
489    F: Fn(I) -> Fut + Send + Sync + 'static,
490    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
491{
492    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
493        Box::pin(async move {
494            let input: I = serde_json::from_value(args)
495                .map_err(|e| Error::tool(format!("Invalid input: {}", e)))?;
496            (self.handler)(input).await
497        })
498    }
499
500    fn input_schema(&self) -> Value {
501        let schema = schemars::schema_for!(I);
502        serde_json::to_value(schema).unwrap_or_else(|_| {
503            serde_json::json!({
504                "type": "object"
505            })
506        })
507    }
508}
509
510/// Handler that works with raw JSON
511struct RawHandler<F> {
512    handler: F,
513}
514
515impl<F, Fut> ToolHandler for RawHandler<F>
516where
517    F: Fn(Value) -> Fut + Send + Sync + 'static,
518    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
519{
520    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
521        Box::pin((self.handler)(args))
522    }
523
524    fn input_schema(&self) -> Value {
525        // Raw handlers accept any JSON
526        serde_json::json!({
527            "type": "object",
528            "additionalProperties": true
529        })
530    }
531}
532
533/// Handler that receives request context for progress/cancellation
534struct ContextAwareHandler<I, F> {
535    handler: F,
536    _phantom: std::marker::PhantomData<I>,
537}
538
539impl<I, F, Fut> ToolHandler for ContextAwareHandler<I, F>
540where
541    I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
542    F: Fn(RequestContext, I) -> Fut + Send + Sync + 'static,
543    Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
544{
545    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
546        // When called without context, create a dummy context
547        let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
548        self.call_with_context(ctx, args)
549    }
550
551    fn call_with_context(
552        &self,
553        ctx: RequestContext,
554        args: Value,
555    ) -> BoxFuture<'_, Result<CallToolResult>> {
556        Box::pin(async move {
557            let input: I = serde_json::from_value(args)
558                .map_err(|e| Error::tool(format!("Invalid input: {}", e)))?;
559            (self.handler)(ctx, input).await
560        })
561    }
562
563    fn uses_context(&self) -> bool {
564        true
565    }
566
567    fn input_schema(&self) -> Value {
568        let schema = schemars::schema_for!(I);
569        serde_json::to_value(schema).unwrap_or_else(|_| {
570            serde_json::json!({
571                "type": "object"
572            })
573        })
574    }
575}
576
577// =============================================================================
578// Trait-based tool definition
579// =============================================================================
580
581/// Trait for defining tools with full control
582///
583/// Implement this trait when you need more control than the builder provides,
584/// or when you want to define tools as standalone types.
585///
586/// # Example
587///
588/// ```rust
589/// use tower_mcp::tool::McpTool;
590/// use tower_mcp::error::Result;
591/// use schemars::JsonSchema;
592/// use serde::{Deserialize, Serialize};
593///
594/// #[derive(Debug, Deserialize, JsonSchema)]
595/// struct AddInput {
596///     a: i64,
597///     b: i64,
598/// }
599///
600/// struct AddTool;
601///
602/// impl McpTool for AddTool {
603///     const NAME: &'static str = "add";
604///     const DESCRIPTION: &'static str = "Add two numbers";
605///
606///     type Input = AddInput;
607///     type Output = i64;
608///
609///     async fn call(&self, input: Self::Input) -> Result<Self::Output> {
610///         Ok(input.a + input.b)
611///     }
612/// }
613///
614/// let tool = AddTool.into_tool().expect("valid tool name");
615/// assert_eq!(tool.name, "add");
616/// ```
617pub trait McpTool: Send + Sync + 'static {
618    const NAME: &'static str;
619    const DESCRIPTION: &'static str;
620
621    type Input: JsonSchema + DeserializeOwned + Send;
622    type Output: Serialize + Send;
623
624    fn call(&self, input: Self::Input) -> impl Future<Output = Result<Self::Output>> + Send;
625
626    /// Optional annotations for the tool
627    fn annotations(&self) -> Option<ToolAnnotations> {
628        None
629    }
630
631    /// Convert to a Tool instance
632    ///
633    /// Returns an error if the tool name is invalid.
634    fn into_tool(self) -> Result<Tool>
635    where
636        Self: Sized,
637    {
638        validate_tool_name(Self::NAME)?;
639        let annotations = self.annotations();
640        let tool = Arc::new(self);
641        Ok(Tool {
642            name: Self::NAME.to_string(),
643            title: None,
644            description: Some(Self::DESCRIPTION.to_string()),
645            output_schema: None,
646            icons: None,
647            annotations,
648            handler: Arc::new(McpToolHandler { tool }),
649        })
650    }
651}
652
653/// Wrapper to make McpTool implement ToolHandler
654struct McpToolHandler<T: McpTool> {
655    tool: Arc<T>,
656}
657
658impl<T: McpTool> ToolHandler for McpToolHandler<T> {
659    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
660        let tool = self.tool.clone();
661        Box::pin(async move {
662            let input: T::Input = serde_json::from_value(args)
663                .map_err(|e| Error::tool(format!("Invalid input: {}", e)))?;
664            let output = tool.call(input).await?;
665            let value = serde_json::to_value(output)
666                .map_err(|e| Error::tool(format!("Failed to serialize output: {}", e)))?;
667            Ok(CallToolResult::json(value))
668        })
669    }
670
671    fn input_schema(&self) -> Value {
672        let schema = schemars::schema_for!(T::Input);
673        serde_json::to_value(schema).unwrap_or_else(|_| {
674            serde_json::json!({
675                "type": "object"
676            })
677        })
678    }
679}
680
#[cfg(test)]
mod tests {
    use super::*;
    use schemars::JsonSchema;
    use serde::Deserialize;

    #[derive(Debug, Deserialize, JsonSchema)]
    struct GreetInput {
        name: String,
    }

    /// Builds an echo tool through `raw_handler` and returns the validation result.
    fn echo_tool(name: impl Into<String>, description: &str) -> Result<Tool> {
        ToolBuilder::new(name)
            .description(description)
            .raw_handler(|args: Value| async move { Ok(CallToolResult::json(args)) })
    }

    #[tokio::test]
    async fn test_builder_tool() {
        let tool = ToolBuilder::new("greet")
            .description("Greet someone")
            .handler(|input: GreetInput| async move {
                Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
            })
            .build()
            .expect("valid tool name");

        assert_eq!(tool.name, "greet");
        assert_eq!(tool.description.as_deref(), Some("Greet someone"));

        let args = serde_json::json!({ "name": "World" });
        let result = tool.call(args).await.unwrap();
        assert!(!result.is_error);
    }

    #[tokio::test]
    async fn test_raw_handler() {
        let tool = echo_tool("echo", "Echo input").expect("valid tool name");
        let result = tool.call(serde_json::json!({ "foo": "bar" })).await.unwrap();
        assert!(!result.is_error);
    }

    #[test]
    fn test_invalid_tool_name_empty() {
        let message = echo_tool("", "Empty name").unwrap_err().to_string();
        assert!(message.contains("cannot be empty"));
    }

    #[test]
    fn test_invalid_tool_name_too_long() {
        let message = echo_tool("a".repeat(129), "Too long")
            .unwrap_err()
            .to_string();
        assert!(message.contains("exceeds maximum"));
    }

    #[test]
    fn test_invalid_tool_name_bad_chars() {
        let message = echo_tool("my tool!", "Bad chars").unwrap_err().to_string();
        assert!(message.contains("invalid character"));
    }

    #[test]
    fn test_valid_tool_names() {
        // Every allowed character class, plus the shortest and longest lengths.
        let longest = "a".repeat(128);
        let names = [
            "my_tool",
            "my-tool",
            "my.tool",
            "MyTool123",
            "a",
            longest.as_str(),
        ];
        for name in names {
            assert!(
                echo_tool(name, "Valid").is_ok(),
                "Expected '{}' to be valid",
                name
            );
        }
    }

    #[tokio::test]
    async fn test_context_aware_handler() {
        use crate::context::{RequestContext, notification_channel};
        use crate::protocol::{ProgressToken, RequestId};

        #[derive(Debug, Deserialize, JsonSchema)]
        struct ProcessInput {
            count: i32,
        }

        let tool = ToolBuilder::new("process")
            .description("Process with context")
            .handler_with_context(|ctx: RequestContext, input: ProcessInput| async move {
                let total = input.count as f64;
                for step in 0..input.count {
                    if ctx.is_cancelled() {
                        return Ok(CallToolResult::error("Cancelled"));
                    }
                    ctx.report_progress(step as f64, Some(total), None).await;
                }
                Ok(CallToolResult::text(format!(
                    "Processed {} items",
                    input.count
                )))
            })
            .build()
            .expect("valid tool name");

        assert_eq!(tool.name, "process");
        assert!(tool.uses_context());

        // A context with a progress token and a notification sender.
        let (tx, mut rx) = notification_channel(10);
        let ctx = RequestContext::new(RequestId::Number(1))
            .with_progress_token(ProgressToken::Number(42))
            .with_notification_sender(tx);

        let result = tool
            .call_with_context(ctx, serde_json::json!({ "count": 3 }))
            .await
            .unwrap();
        assert!(!result.is_error);

        // One progress notification should have been queued per step.
        let progress_count = std::iter::from_fn(|| rx.try_recv().ok()).count();
        assert_eq!(progress_count, 3);
    }

    #[tokio::test]
    async fn test_context_aware_handler_cancellation() {
        use crate::context::RequestContext;
        use crate::protocol::RequestId;
        use std::sync::Arc;
        use std::sync::atomic::{AtomicI32, Ordering};

        #[derive(Debug, Deserialize, JsonSchema)]
        struct LongRunningInput {
            iterations: i32,
        }

        let completed = Arc::new(AtomicI32::new(0));
        let completed_in_handler = Arc::clone(&completed);

        let tool = ToolBuilder::new("long_running")
            .description("Long running task")
            .handler_with_context(move |ctx: RequestContext, input: LongRunningInput| {
                let counter = Arc::clone(&completed_in_handler);
                async move {
                    for i in 0..input.iterations {
                        if ctx.is_cancelled() {
                            return Ok(CallToolResult::error("Cancelled"));
                        }
                        counter.fetch_add(1, Ordering::SeqCst);
                        // Simulate work.
                        tokio::task::yield_now().await;
                        // Self-cancel once iteration 2 has finished.
                        if i == 2 {
                            ctx.cancellation_token().cancel();
                        }
                    }
                    Ok(CallToolResult::text("Done"))
                }
            })
            .build()
            .expect("valid tool name");

        let ctx = RequestContext::new(RequestId::Number(1));
        let result = tool
            .call_with_context(ctx, serde_json::json!({ "iterations": 10 }))
            .await
            .unwrap();

        // Iterations 0, 1 and 2 run; iteration 3 sees the cancellation and bails out.
        assert!(result.is_error);
        assert_eq!(completed.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_tool_builder_with_enhanced_fields() {
        let output_schema = serde_json::json!({
            "type": "object",
            "properties": {
                "greeting": {"type": "string"}
            }
        });

        let tool = ToolBuilder::new("greet")
            .title("Greeting Tool")
            .description("Greet someone")
            .output_schema(output_schema.clone())
            .icon("https://example.com/icon.png")
            .icon_with_meta(
                "https://example.com/icon-large.png",
                Some("image/png".to_string()),
                Some(vec!["96x96".to_string()]),
            )
            .handler(|input: GreetInput| async move {
                Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
            })
            .build()
            .expect("valid tool name");

        assert_eq!(tool.name, "greet");
        assert_eq!(tool.title.as_deref(), Some("Greeting Tool"));
        assert_eq!(tool.description.as_deref(), Some("Greet someone"));
        assert_eq!(tool.output_schema, Some(output_schema));
        let icons = tool.icons.as_ref().expect("icons were set");
        assert_eq!(icons.len(), 2);

        // The advertised definition must carry the new fields too.
        let def = tool.definition();
        assert_eq!(def.title.as_deref(), Some("Greeting Tool"));
        assert!(def.output_schema.is_some());
        assert!(def.icons.is_some());
    }
}