1use std::borrow::Cow;
35use std::convert::Infallible;
36use std::fmt;
37use std::future::Future;
38use std::pin::Pin;
39use std::sync::Arc;
40use std::task::{Context, Poll};
41
42use schemars::{JsonSchema, Schema, SchemaGenerator};
43use serde::Serialize;
44use serde::de::DeserializeOwned;
45use serde_json::Value;
46use tower::util::BoxCloneService;
47use tower_service::Service;
48
49use crate::context::RequestContext;
50use crate::error::{Error, Result};
51use crate::protocol::{CallToolResult, ToolAnnotations, ToolDefinition, ToolIcon};
52
53#[derive(Debug, Clone)]
62pub struct ToolRequest {
63 pub ctx: RequestContext,
65 pub args: Value,
67}
68
69impl ToolRequest {
70 pub fn new(ctx: RequestContext, args: Value) -> Self {
72 Self { ctx, args }
73 }
74}
75
/// Type-erased, cloneable tool service: it takes a [`ToolRequest`], yields a
/// [`CallToolResult`], and uses [`Infallible`] as its error type.
pub type BoxToolService = BoxCloneService<ToolRequest, CallToolResult, Infallible>;
82
83pub struct ToolCatchError<S> {
89 inner: S,
90}
91
92impl<S> ToolCatchError<S> {
93 pub fn new(inner: S) -> Self {
95 Self { inner }
96 }
97}
98
99impl<S: Clone> Clone for ToolCatchError<S> {
100 fn clone(&self) -> Self {
101 Self {
102 inner: self.inner.clone(),
103 }
104 }
105}
106
107impl<S: fmt::Debug> fmt::Debug for ToolCatchError<S> {
108 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
109 f.debug_struct("ToolCatchError")
110 .field("inner", &self.inner)
111 .finish()
112 }
113}
114
115impl<S> Service<ToolRequest> for ToolCatchError<S>
116where
117 S: Service<ToolRequest, Response = CallToolResult> + Clone + Send + 'static,
118 S::Error: fmt::Display + Send,
119 S::Future: Send,
120{
121 type Response = CallToolResult;
122 type Error = Infallible;
123 type Future =
124 Pin<Box<dyn Future<Output = std::result::Result<CallToolResult, Infallible>> + Send>>;
125
126 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
127 match self.inner.poll_ready(cx) {
129 Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
130 Poll::Ready(Err(_)) => Poll::Ready(Ok(())),
131 Poll::Pending => Poll::Pending,
132 }
133 }
134
135 fn call(&mut self, req: ToolRequest) -> Self::Future {
136 let fut = self.inner.call(req);
137
138 Box::pin(async move {
139 match fut.await {
140 Ok(result) => Ok(result),
141 Err(err) => Ok(CallToolResult::error(err.to_string())),
142 }
143 })
144 }
145}
146
/// Marker input type for tools that take no parameters.
///
/// Its `Deserialize` implementation accepts `null` or an object, and its
/// `JsonSchema` implementation reports `{"type": "object"}`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NoParams;
169
170impl<'de> serde::Deserialize<'de> for NoParams {
171 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
172 where
173 D: serde::Deserializer<'de>,
174 {
175 struct NoParamsVisitor;
177
178 impl<'de> serde::de::Visitor<'de> for NoParamsVisitor {
179 type Value = NoParams;
180
181 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
182 formatter.write_str("null or an object")
183 }
184
185 fn visit_unit<E>(self) -> std::result::Result<Self::Value, E>
186 where
187 E: serde::de::Error,
188 {
189 Ok(NoParams)
190 }
191
192 fn visit_none<E>(self) -> std::result::Result<Self::Value, E>
193 where
194 E: serde::de::Error,
195 {
196 Ok(NoParams)
197 }
198
199 fn visit_some<D>(self, deserializer: D) -> std::result::Result<Self::Value, D::Error>
200 where
201 D: serde::Deserializer<'de>,
202 {
203 serde::Deserialize::deserialize(deserializer)
204 }
205
206 fn visit_map<A>(self, mut map: A) -> std::result::Result<Self::Value, A::Error>
207 where
208 A: serde::de::MapAccess<'de>,
209 {
210 while map
212 .next_entry::<serde::de::IgnoredAny, serde::de::IgnoredAny>()?
213 .is_some()
214 {}
215 Ok(NoParams)
216 }
217 }
218
219 deserializer.deserialize_any(NoParamsVisitor)
220 }
221}
222
223impl JsonSchema for NoParams {
224 fn schema_name() -> Cow<'static, str> {
225 Cow::Borrowed("NoParams")
226 }
227
228 fn json_schema(_generator: &mut SchemaGenerator) -> Schema {
229 serde_json::json!({
230 "type": "object"
231 })
232 .try_into()
233 .expect("valid schema")
234 }
235}
236
237pub fn validate_tool_name(name: &str) -> Result<()> {
245 if name.is_empty() {
246 return Err(Error::tool("Tool name cannot be empty"));
247 }
248 if name.len() > 128 {
249 return Err(Error::tool(format!(
250 "Tool name '{}' exceeds maximum length of 128 characters (got {})",
251 name,
252 name.len()
253 )));
254 }
255 if let Some(invalid_char) = name
256 .chars()
257 .find(|c| !c.is_ascii_alphanumeric() && *c != '_' && *c != '-' && *c != '.')
258 {
259 return Err(Error::tool(format!(
260 "Tool name '{}' contains invalid character '{}'. Only alphanumeric, underscore, hyphen, and dot are allowed.",
261 name, invalid_char
262 )));
263 }
264 Ok(())
265}
266
/// Pinned, boxed, `Send` future yielding `T` and bounded by lifetime `'a`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
269
/// Behaviour a tool implementation provides: executing a call and describing
/// its input.
pub trait ToolHandler: Send + Sync {
    /// Runs the tool with the given JSON arguments.
    fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>>;

    /// Runs the tool with the request context available.
    ///
    /// The default implementation ignores the context and forwards to
    /// [`ToolHandler::call`].
    fn call_with_context(
        &self,
        _ctx: RequestContext,
        args: Value,
    ) -> BoxFuture<'_, Result<CallToolResult>> {
        self.call(args)
    }

    /// Reports whether the handler makes use of the request context.
    ///
    /// The default implementation returns `false`.
    fn uses_context(&self) -> bool {
        false
    }

    /// Returns the schema of the tool's input as a JSON value.
    fn input_schema(&self) -> Value;
}
295
/// Service wrapper around a handler kept behind an [`Arc`], so cloning the
/// service only clones the pointer.
pub(crate) struct ToolHandlerService<H> {
    handler: Arc<H>,
}

impl<H> ToolHandlerService<H> {
    /// Moves `handler` into a new reference-counted service.
    pub(crate) fn new(handler: H) -> Self {
        let handler = Arc::new(handler);
        Self { handler }
    }
}

// Implemented by hand so that `H` itself does not have to be `Clone`.
impl<H> Clone for ToolHandlerService<H> {
    fn clone(&self) -> Self {
        let handler = Arc::clone(&self.handler);
        Self { handler }
    }
}
319
320impl<H> Service<ToolRequest> for ToolHandlerService<H>
321where
322 H: ToolHandler + 'static,
323{
324 type Response = CallToolResult;
325 type Error = Error;
326 type Future = Pin<Box<dyn Future<Output = std::result::Result<CallToolResult, Error>> + Send>>;
327
328 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
329 Poll::Ready(Ok(()))
330 }
331
332 fn call(&mut self, req: ToolRequest) -> Self::Future {
333 let handler = self.handler.clone();
334 Box::pin(async move { handler.call_with_context(req.ctx, req.args).await })
335 }
336}
337
/// A tool: its descriptive metadata plus the boxed service that executes it.
pub struct Tool {
    /// Name of the tool.
    pub name: String,
    /// Optional title.
    pub title: Option<String>,
    /// Optional description.
    pub description: Option<String>,
    /// Optional schema of the tool's output, as a JSON value.
    pub output_schema: Option<Value>,
    /// Optional list of icons.
    pub icons: Option<Vec<ToolIcon>>,
    /// Optional annotations.
    pub annotations: Option<ToolAnnotations>,
    /// Boxed service that is cloned and invoked for every call.
    pub(crate) service: BoxToolService,
    /// Schema of the tool's input, reported through `definition()`.
    pub(crate) input_schema: Value,
}
362
363impl std::fmt::Debug for Tool {
364 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
365 f.debug_struct("Tool")
366 .field("name", &self.name)
367 .field("title", &self.title)
368 .field("description", &self.description)
369 .field("output_schema", &self.output_schema)
370 .field("icons", &self.icons)
371 .field("annotations", &self.annotations)
372 .finish_non_exhaustive()
373 }
374}
375
// SAFETY (unverified — NOTE(review)): these impls assert thread-safety that
// the compiler did not derive for `Tool`, presumably because the boxed
// `service` field lacks the corresponding auto trait — confirm. Within this
// file the service is only reached through `&self` in order to `clone()` it;
// verify that concurrent clones of the boxed service are sound for every
// service and layer that can end up inside it.
unsafe impl Send for Tool {}
// SAFETY: same reasoning (and the same open question) as for `Send` above.
unsafe impl Sync for Tool {}
380
381impl Tool {
382 pub fn builder(name: impl Into<String>) -> ToolBuilder {
384 ToolBuilder::new(name)
385 }
386
387 pub fn definition(&self) -> ToolDefinition {
389 ToolDefinition {
390 name: self.name.clone(),
391 title: self.title.clone(),
392 description: self.description.clone(),
393 input_schema: self.input_schema.clone(),
394 output_schema: self.output_schema.clone(),
395 icons: self.icons.clone(),
396 annotations: self.annotations.clone(),
397 }
398 }
399
400 pub fn call(&self, args: Value) -> BoxFuture<'static, CallToolResult> {
405 let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
406 self.call_with_context(ctx, args)
407 }
408
409 pub fn call_with_context(
420 &self,
421 ctx: RequestContext,
422 args: Value,
423 ) -> BoxFuture<'static, CallToolResult> {
424 use tower::ServiceExt;
425 let service = self.service.clone();
426 Box::pin(async move {
427 service.oneshot(ToolRequest::new(ctx, args)).await.unwrap()
430 })
431 }
432
433 pub fn with_name_prefix(&self, prefix: &str) -> Self {
461 Self {
462 name: format!("{}.{}", prefix, self.name),
463 title: self.title.clone(),
464 description: self.description.clone(),
465 output_schema: self.output_schema.clone(),
466 icons: self.icons.clone(),
467 annotations: self.annotations.clone(),
468 service: self.service.clone(),
469 input_schema: self.input_schema.clone(),
470 }
471 }
472
473 fn from_handler<H: ToolHandler + 'static>(
475 name: String,
476 title: Option<String>,
477 description: Option<String>,
478 output_schema: Option<Value>,
479 icons: Option<Vec<ToolIcon>>,
480 annotations: Option<ToolAnnotations>,
481 handler: H,
482 ) -> Self {
483 let input_schema = handler.input_schema();
484 let handler_service = ToolHandlerService::new(handler);
485 let catch_error = ToolCatchError::new(handler_service);
486 let service = BoxCloneService::new(catch_error);
487
488 Self {
489 name,
490 title,
491 description,
492 output_schema,
493 icons,
494 annotations,
495 service,
496 input_schema,
497 }
498 }
499}
500
/// Builder collecting a tool's metadata before a handler is attached.
pub struct ToolBuilder {
    /// Name the tool will be created with.
    name: String,
    /// Optional title.
    title: Option<String>,
    /// Optional description.
    description: Option<String>,
    /// Optional output schema, as a JSON value.
    output_schema: Option<Value>,
    /// Icons added so far; `None` until the first icon is added.
    icons: Option<Vec<ToolIcon>>,
    /// Annotations; `None` until an annotation setter is called.
    annotations: Option<ToolAnnotations>,
}
537
538impl ToolBuilder {
539 pub fn new(name: impl Into<String>) -> Self {
540 Self {
541 name: name.into(),
542 title: None,
543 description: None,
544 output_schema: None,
545 icons: None,
546 annotations: None,
547 }
548 }
549
550 pub fn title(mut self, title: impl Into<String>) -> Self {
552 self.title = Some(title.into());
553 self
554 }
555
556 pub fn output_schema(mut self, schema: Value) -> Self {
558 self.output_schema = Some(schema);
559 self
560 }
561
562 pub fn icon(mut self, src: impl Into<String>) -> Self {
564 self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
565 src: src.into(),
566 mime_type: None,
567 sizes: None,
568 });
569 self
570 }
571
572 pub fn icon_with_meta(
574 mut self,
575 src: impl Into<String>,
576 mime_type: Option<String>,
577 sizes: Option<Vec<String>>,
578 ) -> Self {
579 self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
580 src: src.into(),
581 mime_type,
582 sizes,
583 });
584 self
585 }
586
587 pub fn description(mut self, description: impl Into<String>) -> Self {
589 self.description = Some(description.into());
590 self
591 }
592
593 pub fn read_only(mut self) -> Self {
595 self.annotations
596 .get_or_insert_with(ToolAnnotations::default)
597 .read_only_hint = true;
598 self
599 }
600
601 pub fn non_destructive(mut self) -> Self {
603 self.annotations
604 .get_or_insert_with(ToolAnnotations::default)
605 .destructive_hint = false;
606 self
607 }
608
609 pub fn idempotent(mut self) -> Self {
611 self.annotations
612 .get_or_insert_with(ToolAnnotations::default)
613 .idempotent_hint = true;
614 self
615 }
616
617 pub fn annotations(mut self, annotations: ToolAnnotations) -> Self {
619 self.annotations = Some(annotations);
620 self
621 }
622
623 pub fn handler<I, F, Fut>(self, handler: F) -> ToolBuilderWithHandler<I, F>
667 where
668 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
669 F: Fn(I) -> Fut + Send + Sync + 'static,
670 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
671 {
672 ToolBuilderWithHandler {
673 name: self.name,
674 title: self.title,
675 description: self.description,
676 output_schema: self.output_schema,
677 icons: self.icons,
678 annotations: self.annotations,
679 handler,
680 _phantom: std::marker::PhantomData,
681 }
682 }
683
684 pub fn extractor_handler<S, F, T>(
738 self,
739 state: S,
740 handler: F,
741 ) -> crate::extract::ToolBuilderWithExtractor<S, F, T>
742 where
743 S: Clone + Send + Sync + 'static,
744 F: crate::extract::ExtractorHandler<S, T> + Clone,
745 T: Send + Sync + 'static,
746 {
747 crate::extract::ToolBuilderWithExtractor {
748 name: self.name,
749 title: self.title,
750 description: self.description,
751 output_schema: self.output_schema,
752 icons: self.icons,
753 annotations: self.annotations,
754 state,
755 handler,
756 input_schema: F::input_schema(),
757 _phantom: std::marker::PhantomData,
758 }
759 }
760
761 pub fn extractor_handler_typed<S, F, T, I>(
795 self,
796 state: S,
797 handler: F,
798 ) -> crate::extract::ToolBuilderWithTypedExtractor<S, F, T, I>
799 where
800 S: Clone + Send + Sync + 'static,
801 F: crate::extract::TypedExtractorHandler<S, T, I> + Clone,
802 T: Send + Sync + 'static,
803 I: schemars::JsonSchema + Send + Sync + 'static,
804 {
805 crate::extract::ToolBuilderWithTypedExtractor {
806 name: self.name,
807 title: self.title,
808 description: self.description,
809 output_schema: self.output_schema,
810 icons: self.icons,
811 annotations: self.annotations,
812 state,
813 handler,
814 _phantom: std::marker::PhantomData,
815 }
816 }
817}
818
/// Builder state after a typed handler has been attached.
///
/// `I` is the handler's input type and `F` the handler itself.
pub struct ToolBuilderWithHandler<I, F> {
    name: String,
    title: Option<String>,
    description: Option<String>,
    output_schema: Option<Value>,
    icons: Option<Vec<ToolIcon>>,
    annotations: Option<ToolAnnotations>,
    /// The handler function.
    handler: F,
    /// Records the input type `I` without storing a value of it.
    _phantom: std::marker::PhantomData<I>,
}
830
831impl<I, F, Fut> ToolBuilderWithHandler<I, F>
832where
833 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
834 F: Fn(I) -> Fut + Send + Sync + 'static,
835 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
836{
837 pub fn build(self) -> Result<Tool> {
841 validate_tool_name(&self.name)?;
842 Ok(Tool::from_handler(
843 self.name,
844 self.title,
845 self.description,
846 self.output_schema,
847 self.icons,
848 self.annotations,
849 TypedHandler {
850 handler: self.handler,
851 _phantom: std::marker::PhantomData,
852 },
853 ))
854 }
855
856 pub fn layer<L>(self, layer: L) -> ToolBuilderWithLayer<I, F, L> {
883 ToolBuilderWithLayer {
884 name: self.name,
885 title: self.title,
886 description: self.description,
887 output_schema: self.output_schema,
888 icons: self.icons,
889 annotations: self.annotations,
890 handler: self.handler,
891 layer,
892 _phantom: std::marker::PhantomData,
893 }
894 }
895}
896
/// Builder state after a handler and at least one layer have been attached.
///
/// `I` is the handler's input type, `F` the handler and `L` the layer (or
/// stack of layers) applied to the handler service.
pub struct ToolBuilderWithLayer<I, F, L> {
    name: String,
    title: Option<String>,
    description: Option<String>,
    output_schema: Option<Value>,
    icons: Option<Vec<ToolIcon>>,
    annotations: Option<ToolAnnotations>,
    /// The handler function.
    handler: F,
    /// Layer applied to the handler service when the tool is built.
    layer: L,
    /// Records the input type `I` without storing a value of it.
    _phantom: std::marker::PhantomData<I>,
}
911
912#[allow(private_bounds)]
915impl<I, F, Fut, L> ToolBuilderWithLayer<I, F, L>
916where
917 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
918 F: Fn(I) -> Fut + Send + Sync + 'static,
919 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
920 L: tower::Layer<ToolHandlerService<TypedHandler<I, F>>> + Clone + Send + Sync + 'static,
921 L::Service: Service<ToolRequest, Response = CallToolResult> + Clone + Send + 'static,
922 <L::Service as Service<ToolRequest>>::Error: fmt::Display + Send,
923 <L::Service as Service<ToolRequest>>::Future: Send,
924{
925 pub fn build(self) -> Result<Tool> {
929 validate_tool_name(&self.name)?;
930
931 let input_schema = schemars::schema_for!(I);
932 let input_schema = serde_json::to_value(input_schema)
933 .unwrap_or_else(|_| serde_json::json!({ "type": "object" }));
934
935 let handler_service = ToolHandlerService::new(TypedHandler {
936 handler: self.handler,
937 _phantom: std::marker::PhantomData,
938 });
939 let layered = self.layer.layer(handler_service);
940 let catch_error = ToolCatchError::new(layered);
941 let service = BoxCloneService::new(catch_error);
942
943 Ok(Tool {
944 name: self.name,
945 title: self.title,
946 description: self.description,
947 output_schema: self.output_schema,
948 icons: self.icons,
949 annotations: self.annotations,
950 service,
951 input_schema,
952 })
953 }
954
955 pub fn layer<L2>(
960 self,
961 layer: L2,
962 ) -> ToolBuilderWithLayer<I, F, tower::layer::util::Stack<L2, L>> {
963 ToolBuilderWithLayer {
964 name: self.name,
965 title: self.title,
966 description: self.description,
967 output_schema: self.output_schema,
968 icons: self.icons,
969 annotations: self.annotations,
970 handler: self.handler,
971 layer: tower::layer::util::Stack::new(layer, self.layer),
972 _phantom: std::marker::PhantomData,
973 }
974 }
975}
976
/// Pairs a handler function `F` with the input type `I` that its arguments
/// are deserialized into.
struct TypedHandler<I, F> {
    /// The handler function.
    handler: F,
    /// Records the input type `I` without storing a value of it.
    _phantom: std::marker::PhantomData<I>,
}
986
987impl<I, F, Fut> ToolHandler for TypedHandler<I, F>
988where
989 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
990 F: Fn(I) -> Fut + Send + Sync + 'static,
991 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
992{
993 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
994 Box::pin(async move {
995 let input: I = serde_json::from_value(args)
996 .map_err(|e| Error::tool(format!("Invalid input: {}", e)))?;
997 (self.handler)(input).await
998 })
999 }
1000
1001 fn input_schema(&self) -> Value {
1002 let schema = schemars::schema_for!(I);
1003 serde_json::to_value(schema).unwrap_or_else(|_| {
1004 serde_json::json!({
1005 "type": "object"
1006 })
1007 })
1008 }
1009}
1010
1011pub trait McpTool: Send + Sync + 'static {
1052 const NAME: &'static str;
1053 const DESCRIPTION: &'static str;
1054
1055 type Input: JsonSchema + DeserializeOwned + Send;
1056 type Output: Serialize + Send;
1057
1058 fn call(&self, input: Self::Input) -> impl Future<Output = Result<Self::Output>> + Send;
1059
1060 fn annotations(&self) -> Option<ToolAnnotations> {
1062 None
1063 }
1064
1065 fn into_tool(self) -> Result<Tool>
1069 where
1070 Self: Sized,
1071 {
1072 validate_tool_name(Self::NAME)?;
1073 let annotations = self.annotations();
1074 let tool = Arc::new(self);
1075 Ok(Tool::from_handler(
1076 Self::NAME.to_string(),
1077 None,
1078 Some(Self::DESCRIPTION.to_string()),
1079 None,
1080 None,
1081 annotations,
1082 McpToolHandler { tool },
1083 ))
1084 }
1085}
1086
/// Adapter that holds an [`McpTool`] behind an [`Arc`] so it can be used as a
/// `ToolHandler`.
struct McpToolHandler<T: McpTool> {
    /// The shared tool implementation.
    tool: Arc<T>,
}
1091
1092impl<T: McpTool> ToolHandler for McpToolHandler<T> {
1093 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
1094 let tool = self.tool.clone();
1095 Box::pin(async move {
1096 let input: T::Input = serde_json::from_value(args)
1097 .map_err(|e| Error::tool(format!("Invalid input: {}", e)))?;
1098 let output = tool.call(input).await?;
1099 let value = serde_json::to_value(output)
1100 .map_err(|e| Error::tool(format!("Failed to serialize output: {}", e)))?;
1101 Ok(CallToolResult::json(value))
1102 })
1103 }
1104
1105 fn input_schema(&self) -> Value {
1106 let schema = schemars::schema_for!(T::Input);
1107 serde_json::to_value(schema).unwrap_or_else(|_| {
1108 serde_json::json!({
1109 "type": "object"
1110 })
1111 })
1112 }
1113}
1114
1115#[cfg(test)]
1116mod tests {
1117 use super::*;
1118 use crate::extract::{Context, Json, RawArgs, State};
1119 use crate::protocol::Content;
1120 use schemars::JsonSchema;
1121 use serde::Deserialize;
1122
1123 #[derive(Debug, Deserialize, JsonSchema)]
1124 struct GreetInput {
1125 name: String,
1126 }
1127
1128 #[tokio::test]
1129 async fn test_builder_tool() {
1130 let tool = ToolBuilder::new("greet")
1131 .description("Greet someone")
1132 .handler(|input: GreetInput| async move {
1133 Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
1134 })
1135 .build()
1136 .expect("valid tool name");
1137
1138 assert_eq!(tool.name, "greet");
1139 assert_eq!(tool.description.as_deref(), Some("Greet someone"));
1140
1141 let result = tool.call(serde_json::json!({"name": "World"})).await;
1142
1143 assert!(!result.is_error);
1144 }
1145
1146 #[tokio::test]
1147 async fn test_raw_handler() {
1148 let tool = ToolBuilder::new("echo")
1149 .description("Echo input")
1150 .extractor_handler((), |RawArgs(args): RawArgs| async move {
1151 Ok(CallToolResult::json(args))
1152 })
1153 .build()
1154 .expect("valid tool name");
1155
1156 let result = tool.call(serde_json::json!({"foo": "bar"})).await;
1157
1158 assert!(!result.is_error);
1159 }
1160
1161 #[test]
1162 fn test_invalid_tool_name_empty() {
1163 let result = ToolBuilder::new("")
1164 .description("Empty name")
1165 .extractor_handler((), |RawArgs(args): RawArgs| async move {
1166 Ok(CallToolResult::json(args))
1167 })
1168 .build();
1169
1170 assert!(result.is_err());
1171 assert!(result.unwrap_err().to_string().contains("cannot be empty"));
1172 }
1173
1174 #[test]
1175 fn test_invalid_tool_name_too_long() {
1176 let long_name = "a".repeat(129);
1177 let result = ToolBuilder::new(long_name)
1178 .description("Too long")
1179 .extractor_handler((), |RawArgs(args): RawArgs| async move {
1180 Ok(CallToolResult::json(args))
1181 })
1182 .build();
1183
1184 assert!(result.is_err());
1185 assert!(result.unwrap_err().to_string().contains("exceeds maximum"));
1186 }
1187
1188 #[test]
1189 fn test_invalid_tool_name_bad_chars() {
1190 let result = ToolBuilder::new("my tool!")
1191 .description("Bad chars")
1192 .extractor_handler((), |RawArgs(args): RawArgs| async move {
1193 Ok(CallToolResult::json(args))
1194 })
1195 .build();
1196
1197 assert!(result.is_err());
1198 assert!(
1199 result
1200 .unwrap_err()
1201 .to_string()
1202 .contains("invalid character")
1203 );
1204 }
1205
1206 #[test]
1207 fn test_valid_tool_names() {
1208 let names = [
1210 "my_tool",
1211 "my-tool",
1212 "my.tool",
1213 "MyTool123",
1214 "a",
1215 &"a".repeat(128),
1216 ];
1217 for name in names {
1218 let result = ToolBuilder::new(name)
1219 .description("Valid")
1220 .extractor_handler((), |RawArgs(args): RawArgs| async move {
1221 Ok(CallToolResult::json(args))
1222 })
1223 .build();
1224 assert!(result.is_ok(), "Expected '{}' to be valid", name);
1225 }
1226 }
1227
1228 #[tokio::test]
1229 async fn test_context_aware_handler() {
1230 use crate::context::notification_channel;
1231 use crate::protocol::{ProgressToken, RequestId};
1232
1233 #[derive(Debug, Deserialize, JsonSchema)]
1234 struct ProcessInput {
1235 count: i32,
1236 }
1237
1238 let tool = ToolBuilder::new("process")
1239 .description("Process with context")
1240 .extractor_handler_typed::<_, _, _, ProcessInput>(
1241 (),
1242 |ctx: Context, Json(input): Json<ProcessInput>| async move {
1243 for i in 0..input.count {
1245 if ctx.is_cancelled() {
1246 return Ok(CallToolResult::error("Cancelled"));
1247 }
1248 ctx.report_progress(i as f64, Some(input.count as f64), None)
1249 .await;
1250 }
1251 Ok(CallToolResult::text(format!(
1252 "Processed {} items",
1253 input.count
1254 )))
1255 },
1256 )
1257 .build()
1258 .expect("valid tool name");
1259
1260 assert_eq!(tool.name, "process");
1261
1262 let (tx, mut rx) = notification_channel(10);
1264 let ctx = RequestContext::new(RequestId::Number(1))
1265 .with_progress_token(ProgressToken::Number(42))
1266 .with_notification_sender(tx);
1267
1268 let result = tool
1269 .call_with_context(ctx, serde_json::json!({"count": 3}))
1270 .await;
1271
1272 assert!(!result.is_error);
1273
1274 let mut progress_count = 0;
1276 while rx.try_recv().is_ok() {
1277 progress_count += 1;
1278 }
1279 assert_eq!(progress_count, 3);
1280 }
1281
1282 #[tokio::test]
1283 async fn test_context_aware_handler_cancellation() {
1284 use crate::protocol::RequestId;
1285 use std::sync::atomic::{AtomicI32, Ordering};
1286
1287 #[derive(Debug, Deserialize, JsonSchema)]
1288 struct LongRunningInput {
1289 iterations: i32,
1290 }
1291
1292 let iterations_completed = Arc::new(AtomicI32::new(0));
1293 let iterations_ref = iterations_completed.clone();
1294
1295 let tool = ToolBuilder::new("long_running")
1296 .description("Long running task")
1297 .extractor_handler_typed::<_, _, _, LongRunningInput>(
1298 (),
1299 move |ctx: Context, Json(input): Json<LongRunningInput>| {
1300 let completed = iterations_ref.clone();
1301 async move {
1302 for i in 0..input.iterations {
1303 if ctx.is_cancelled() {
1304 return Ok(CallToolResult::error("Cancelled"));
1305 }
1306 completed.fetch_add(1, Ordering::SeqCst);
1307 tokio::task::yield_now().await;
1309 if i == 2 {
1311 ctx.cancellation_token().cancel();
1312 }
1313 }
1314 Ok(CallToolResult::text("Done"))
1315 }
1316 },
1317 )
1318 .build()
1319 .expect("valid tool name");
1320
1321 let ctx = RequestContext::new(RequestId::Number(1));
1322
1323 let result = tool
1324 .call_with_context(ctx, serde_json::json!({"iterations": 10}))
1325 .await;
1326
1327 assert!(result.is_error);
1330 assert_eq!(iterations_completed.load(Ordering::SeqCst), 3);
1331 }
1332
1333 #[tokio::test]
1334 async fn test_tool_builder_with_enhanced_fields() {
1335 let output_schema = serde_json::json!({
1336 "type": "object",
1337 "properties": {
1338 "greeting": {"type": "string"}
1339 }
1340 });
1341
1342 let tool = ToolBuilder::new("greet")
1343 .title("Greeting Tool")
1344 .description("Greet someone")
1345 .output_schema(output_schema.clone())
1346 .icon("https://example.com/icon.png")
1347 .icon_with_meta(
1348 "https://example.com/icon-large.png",
1349 Some("image/png".to_string()),
1350 Some(vec!["96x96".to_string()]),
1351 )
1352 .handler(|input: GreetInput| async move {
1353 Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
1354 })
1355 .build()
1356 .expect("valid tool name");
1357
1358 assert_eq!(tool.name, "greet");
1359 assert_eq!(tool.title.as_deref(), Some("Greeting Tool"));
1360 assert_eq!(tool.description.as_deref(), Some("Greet someone"));
1361 assert_eq!(tool.output_schema, Some(output_schema));
1362 assert!(tool.icons.is_some());
1363 assert_eq!(tool.icons.as_ref().unwrap().len(), 2);
1364
1365 let def = tool.definition();
1367 assert_eq!(def.title.as_deref(), Some("Greeting Tool"));
1368 assert!(def.output_schema.is_some());
1369 assert!(def.icons.is_some());
1370 }
1371
1372 #[tokio::test]
1373 async fn test_handler_with_state() {
1374 let shared = Arc::new("shared-state".to_string());
1375
1376 let tool = ToolBuilder::new("stateful")
1377 .description("Uses shared state")
1378 .extractor_handler_typed::<_, _, _, GreetInput>(
1379 shared,
1380 |State(state): State<Arc<String>>, Json(input): Json<GreetInput>| async move {
1381 Ok(CallToolResult::text(format!(
1382 "{}: Hello, {}!",
1383 state, input.name
1384 )))
1385 },
1386 )
1387 .build()
1388 .expect("valid tool name");
1389
1390 let result = tool.call(serde_json::json!({"name": "World"})).await;
1391 assert!(!result.is_error);
1392 }
1393
1394 #[tokio::test]
1395 async fn test_handler_with_state_and_context() {
1396 use crate::protocol::RequestId;
1397
1398 let shared = Arc::new(42_i32);
1399
1400 let tool =
1401 ToolBuilder::new("stateful_ctx")
1402 .description("Uses state and context")
1403 .extractor_handler_typed::<_, _, _, GreetInput>(
1404 shared,
1405 |State(state): State<Arc<i32>>,
1406 _ctx: Context,
1407 Json(input): Json<GreetInput>| async move {
1408 Ok(CallToolResult::text(format!(
1409 "{}: Hello, {}!",
1410 state, input.name
1411 )))
1412 },
1413 )
1414 .build()
1415 .expect("valid tool name");
1416
1417 let ctx = RequestContext::new(RequestId::Number(1));
1418 let result = tool
1419 .call_with_context(ctx, serde_json::json!({"name": "World"}))
1420 .await;
1421 assert!(!result.is_error);
1422 }
1423
1424 #[tokio::test]
1425 async fn test_handler_no_params() {
1426 let tool = ToolBuilder::new("no_params")
1427 .description("Takes no parameters")
1428 .extractor_handler_typed::<_, _, _, NoParams>((), |Json(_): Json<NoParams>| async {
1429 Ok(CallToolResult::text("no params result"))
1430 })
1431 .build()
1432 .expect("valid tool name");
1433
1434 assert_eq!(tool.name, "no_params");
1435
1436 let result = tool.call(serde_json::json!({})).await;
1438 assert!(!result.is_error);
1439
1440 let result = tool.call(serde_json::json!({"unexpected": "value"})).await;
1442 assert!(!result.is_error);
1443
1444 let schema = tool.definition().input_schema;
1446 assert_eq!(schema.get("type").unwrap().as_str().unwrap(), "object");
1447 }
1448
1449 #[tokio::test]
1450 async fn test_handler_with_state_no_params() {
1451 let shared = Arc::new("shared_value".to_string());
1452
1453 let tool = ToolBuilder::new("with_state_no_params")
1454 .description("Takes no parameters but has state")
1455 .extractor_handler_typed::<_, _, _, NoParams>(
1456 shared,
1457 |State(state): State<Arc<String>>, Json(_): Json<NoParams>| async move {
1458 Ok(CallToolResult::text(format!("state: {}", state)))
1459 },
1460 )
1461 .build()
1462 .expect("valid tool name");
1463
1464 assert_eq!(tool.name, "with_state_no_params");
1465
1466 let result = tool.call(serde_json::json!({})).await;
1468 assert!(!result.is_error);
1469 assert_eq!(result.first_text().unwrap(), "state: shared_value");
1470
1471 let schema = tool.definition().input_schema;
1473 assert_eq!(schema.get("type").unwrap().as_str().unwrap(), "object");
1474 }
1475
1476 #[tokio::test]
1477 async fn test_handler_no_params_with_context() {
1478 let tool = ToolBuilder::new("no_params_with_context")
1479 .description("Takes no parameters but has context")
1480 .extractor_handler_typed::<_, _, _, NoParams>(
1481 (),
1482 |_ctx: Context, Json(_): Json<NoParams>| async move {
1483 Ok(CallToolResult::text("context available"))
1484 },
1485 )
1486 .build()
1487 .expect("valid tool name");
1488
1489 assert_eq!(tool.name, "no_params_with_context");
1490
1491 let result = tool.call(serde_json::json!({})).await;
1492 assert!(!result.is_error);
1493 assert_eq!(result.first_text().unwrap(), "context available");
1494 }
1495
1496 #[tokio::test]
1497 async fn test_handler_with_state_and_context_no_params() {
1498 let shared = Arc::new("shared".to_string());
1499
1500 let tool = ToolBuilder::new("state_context_no_params")
1501 .description("Has state and context, no params")
1502 .extractor_handler_typed::<_, _, _, NoParams>(
1503 shared,
1504 |State(state): State<Arc<String>>,
1505 _ctx: Context,
1506 Json(_): Json<NoParams>| async move {
1507 Ok(CallToolResult::text(format!("state: {}", state)))
1508 },
1509 )
1510 .build()
1511 .expect("valid tool name");
1512
1513 assert_eq!(tool.name, "state_context_no_params");
1514
1515 let result = tool.call(serde_json::json!({})).await;
1516 assert!(!result.is_error);
1517 assert_eq!(result.first_text().unwrap(), "state: shared");
1518 }
1519
1520 #[tokio::test]
1521 async fn test_raw_handler_with_state() {
1522 let prefix = Arc::new("prefix:".to_string());
1523
1524 let tool = ToolBuilder::new("raw_with_state")
1525 .description("Raw handler with state")
1526 .extractor_handler(
1527 prefix,
1528 |State(state): State<Arc<String>>, RawArgs(args): RawArgs| async move {
1529 Ok(CallToolResult::text(format!("{} {}", state, args)))
1530 },
1531 )
1532 .build()
1533 .expect("valid tool name");
1534
1535 assert_eq!(tool.name, "raw_with_state");
1536
1537 let result = tool.call(serde_json::json!({"key": "value"})).await;
1538 assert!(!result.is_error);
1539 assert!(result.first_text().unwrap().starts_with("prefix:"));
1540 }
1541
1542 #[tokio::test]
1543 async fn test_raw_handler_with_state_and_context() {
1544 let prefix = Arc::new("prefix:".to_string());
1545
1546 let tool = ToolBuilder::new("raw_state_context")
1547 .description("Raw handler with state and context")
1548 .extractor_handler(
1549 prefix,
1550 |State(state): State<Arc<String>>,
1551 _ctx: Context,
1552 RawArgs(args): RawArgs| async move {
1553 Ok(CallToolResult::text(format!("{} {}", state, args)))
1554 },
1555 )
1556 .build()
1557 .expect("valid tool name");
1558
1559 assert_eq!(tool.name, "raw_state_context");
1560
1561 let result = tool.call(serde_json::json!({"key": "value"})).await;
1562 assert!(!result.is_error);
1563 assert!(result.first_text().unwrap().starts_with("prefix:"));
1564 }
1565
1566 #[tokio::test]
1567 async fn test_tool_with_timeout_layer() {
1568 use std::time::Duration;
1569 use tower::timeout::TimeoutLayer;
1570
1571 #[derive(Debug, Deserialize, JsonSchema)]
1572 struct SlowInput {
1573 delay_ms: u64,
1574 }
1575
1576 let tool = ToolBuilder::new("slow_tool")
1578 .description("A slow tool")
1579 .handler(|input: SlowInput| async move {
1580 tokio::time::sleep(Duration::from_millis(input.delay_ms)).await;
1581 Ok(CallToolResult::text("completed"))
1582 })
1583 .layer(TimeoutLayer::new(Duration::from_millis(50)))
1584 .build()
1585 .expect("valid tool name");
1586
1587 let result = tool.call(serde_json::json!({"delay_ms": 10})).await;
1589 assert!(!result.is_error);
1590 assert_eq!(result.first_text().unwrap(), "completed");
1591
1592 let result = tool.call(serde_json::json!({"delay_ms": 200})).await;
1594 assert!(result.is_error);
1595 let msg = result.first_text().unwrap().to_lowercase();
1597 assert!(
1598 msg.contains("timed out") || msg.contains("timeout") || msg.contains("elapsed"),
1599 "Expected timeout error, got: {}",
1600 msg
1601 );
1602 }
1603
1604 #[tokio::test]
1605 async fn test_tool_with_concurrency_limit_layer() {
1606 use std::sync::atomic::{AtomicU32, Ordering};
1607 use std::time::Duration;
1608 use tower::limit::ConcurrencyLimitLayer;
1609
1610 #[derive(Debug, Deserialize, JsonSchema)]
1611 struct WorkInput {
1612 id: u32,
1613 }
1614
1615 let max_concurrent = Arc::new(AtomicU32::new(0));
1616 let current_concurrent = Arc::new(AtomicU32::new(0));
1617 let max_ref = max_concurrent.clone();
1618 let current_ref = current_concurrent.clone();
1619
1620 let tool = ToolBuilder::new("concurrent_tool")
1622 .description("A concurrent tool")
1623 .handler(move |input: WorkInput| {
1624 let max = max_ref.clone();
1625 let current = current_ref.clone();
1626 async move {
1627 let prev = current.fetch_add(1, Ordering::SeqCst);
1629 max.fetch_max(prev + 1, Ordering::SeqCst);
1630
1631 tokio::time::sleep(Duration::from_millis(50)).await;
1633
1634 current.fetch_sub(1, Ordering::SeqCst);
1635 Ok(CallToolResult::text(format!("completed {}", input.id)))
1636 }
1637 })
1638 .layer(ConcurrencyLimitLayer::new(2))
1639 .build()
1640 .expect("valid tool name");
1641
1642 let handles: Vec<_> = (0..4)
1644 .map(|i| {
1645 let t = tool.call(serde_json::json!({"id": i}));
1646 tokio::spawn(t)
1647 })
1648 .collect();
1649
1650 for handle in handles {
1651 let result = handle.await.unwrap();
1652 assert!(!result.is_error);
1653 }
1654
1655 assert!(max_concurrent.load(Ordering::SeqCst) <= 2);
1657 }
1658
1659 #[tokio::test]
1660 async fn test_tool_with_multiple_layers() {
1661 use std::time::Duration;
1662 use tower::limit::ConcurrencyLimitLayer;
1663 use tower::timeout::TimeoutLayer;
1664
1665 #[derive(Debug, Deserialize, JsonSchema)]
1666 struct Input {
1667 value: String,
1668 }
1669
1670 let tool = ToolBuilder::new("multi_layer_tool")
1672 .description("Tool with multiple layers")
1673 .handler(|input: Input| async move {
1674 Ok(CallToolResult::text(format!("processed: {}", input.value)))
1675 })
1676 .layer(TimeoutLayer::new(Duration::from_secs(5)))
1677 .layer(ConcurrencyLimitLayer::new(10))
1678 .build()
1679 .expect("valid tool name");
1680
1681 let result = tool.call(serde_json::json!({"value": "test"})).await;
1682 assert!(!result.is_error);
1683 assert_eq!(result.first_text().unwrap(), "processed: test");
1684 }
1685
/// Builds a minimal tool and starts a call on it without a runtime.
///
/// The future returned by `call` is created but never polled; the test only
/// checks that constructing it does not panic.
/// NOTE(review): presumably `call` clones the wrapped service internally,
/// which is what the test name refers to — confirm against `Tool::call`.
#[test]
fn test_tool_catch_error_clone() {
    let builder = ToolBuilder::new("test").description("test");
    let tool = builder
        .extractor_handler((), |RawArgs(_args): RawArgs| async {
            Ok(CallToolResult::text("ok"))
        })
        .build()
        .unwrap();

    let empty_args = serde_json::json!({});
    let _pending_call = tool.call(empty_args);
}
1700
/// The `Debug` output of `ToolCatchError` includes the wrapper's type name.
#[test]
fn test_tool_catch_error_debug() {
    /// Trivial always-ready service; only its `Debug` impl matters here.
    #[derive(Clone, Debug)]
    struct DebugService;

    impl Service<ToolRequest> for DebugService {
        type Response = CallToolResult;
        type Error = crate::error::Error;
        type Future = Pin<
            Box<dyn Future<Output = std::result::Result<Self::Response, Self::Error>> + Send>,
        >;

        // `std::task::Context` is spelled out in full, as in the original.
        fn poll_ready(
            &mut self,
            _cx: &mut std::task::Context<'_>,
        ) -> Poll<std::result::Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _req: ToolRequest) -> Self::Future {
            Box::pin(async { Ok(CallToolResult::text("ok")) })
        }
    }

    let wrapped = ToolCatchError::new(DebugService);
    let rendered = format!("{wrapped:?}");
    assert!(rendered.contains("ToolCatchError"));
}
1735
/// `ToolRequest::new` stores the arguments it is given unchanged.
#[test]
fn test_tool_request_new() {
    use crate::protocol::RequestId;

    let ctx = RequestContext::new(RequestId::Number(42));
    let args = serde_json::json!({"key": "value"});

    // `ctx` is not needed after construction, so it is moved rather than
    // cloned; `args` is cloned because it is compared against below.
    let req = ToolRequest::new(ctx, args.clone());

    assert_eq!(req.args, args);
}
1746
/// The JSON schema generated for `NoParams` declares `"type": "object"`.
#[test]
fn test_no_params_schema() {
    let generated = schemars::schema_for!(NoParams);
    let as_json = serde_json::to_value(&generated).unwrap();

    let declared_type = as_json.get("type").and_then(|v| v.as_str());
    assert_eq!(
        declared_type,
        Some("object"),
        "NoParams should generate type: object schema"
    );
}
1758
/// `NoParams` deserializes from an empty object, from `null`, and from an
/// object carrying unexpected fields.
#[test]
fn test_no_params_deserialize() {
    let accepted_inputs = ["{}", "null", r#"{"unexpected": "value"}"#];

    for raw in accepted_inputs {
        let parsed: NoParams = serde_json::from_str(raw).unwrap();
        assert_eq!(parsed, NoParams);
    }
}
1773
1774 #[tokio::test]
1775 async fn test_no_params_type_in_handler() {
1776 let tool = ToolBuilder::new("status")
1778 .description("Get status")
1779 .handler(|_input: NoParams| async move { Ok(CallToolResult::text("OK")) })
1780 .build()
1781 .expect("valid tool name");
1782
1783 let schema = tool.definition().input_schema;
1785 assert_eq!(
1786 schema.get("type").and_then(|v| v.as_str()),
1787 Some("object"),
1788 "NoParams handler should produce type: object schema"
1789 );
1790
1791 let result = tool.call(serde_json::json!({})).await;
1793 assert!(!result.is_error);
1794 }
1795
1796 #[tokio::test]
1797 async fn test_tool_with_name_prefix() {
1798 #[derive(Debug, Deserialize, JsonSchema)]
1799 struct Input {
1800 value: String,
1801 }
1802
1803 let tool = ToolBuilder::new("query")
1804 .description("Query something")
1805 .title("Query Tool")
1806 .handler(|input: Input| async move { Ok(CallToolResult::text(&input.value)) })
1807 .build()
1808 .expect("valid tool name");
1809
1810 let prefixed = tool.with_name_prefix("db");
1812
1813 assert_eq!(prefixed.name, "db.query");
1815
1816 assert_eq!(prefixed.description.as_deref(), Some("Query something"));
1818 assert_eq!(prefixed.title.as_deref(), Some("Query Tool"));
1819
1820 let result = prefixed
1822 .call(serde_json::json!({"value": "test input"}))
1823 .await;
1824 assert!(!result.is_error);
1825 match &result.content[0] {
1826 Content::Text { text, .. } => assert_eq!(text, "test input"),
1827 _ => panic!("Expected text content"),
1828 }
1829 }
1830
1831 #[tokio::test]
1832 async fn test_tool_with_name_prefix_multiple_levels() {
1833 let tool = ToolBuilder::new("action")
1834 .description("Do something")
1835 .handler(|_: NoParams| async move { Ok(CallToolResult::text("done")) })
1836 .build()
1837 .expect("valid tool name");
1838
1839 let prefixed = tool.with_name_prefix("level1");
1841 assert_eq!(prefixed.name, "level1.action");
1842
1843 let double_prefixed = prefixed.with_name_prefix("level0");
1844 assert_eq!(double_prefixed.name, "level0.level1.action");
1845 }
1846}