1use std::collections::HashMap;
28
29use fastmcp_core::{McpContext, McpOutcome, McpResult, Outcome};
30use fastmcp_protocol::{Content, Tool};
31
32use crate::handler::{BoxFuture, BoxedToolHandler, ToolHandler};
33
/// Zero-sized marker type standing for "no value was provided".
///
/// It carries no data; it is `Copy` and can be obtained via `Default`.
#[derive(Clone, Copy, Debug, Default)]
pub struct NotSet;
37
38#[derive(Debug, Clone, Default)]
43pub struct ArgTransform {
44 pub name: Option<String>,
46 pub description: Option<String>,
48 pub default: Option<serde_json::Value>,
50 pub hide: bool,
53 pub required: Option<bool>,
56 pub type_schema: Option<serde_json::Value>,
58}
59
60impl ArgTransform {
61 #[must_use]
63 pub fn new() -> Self {
64 <Self as Default>::default()
65 }
66
67 #[must_use]
69 pub fn name(mut self, name: impl Into<String>) -> Self {
70 self.name = Some(name.into());
71 self
72 }
73
74 #[must_use]
76 pub fn description(mut self, desc: impl Into<String>) -> Self {
77 self.description = Some(desc.into());
78 self
79 }
80
81 #[must_use]
83 pub fn default(mut self, value: impl Into<serde_json::Value>) -> Self {
84 self.default = Some(value.into());
85 self
86 }
87
88 #[must_use]
90 pub fn default_str(self, value: impl Into<String>) -> Self {
91 self.default(serde_json::Value::String(value.into()))
92 }
93
94 #[must_use]
96 pub fn default_int(self, value: i64) -> Self {
97 self.default(serde_json::Value::Number(value.into()))
98 }
99
100 #[must_use]
102 pub fn default_bool(self, value: bool) -> Self {
103 self.default(serde_json::Value::Bool(value))
104 }
105
106 #[must_use]
111 pub fn hide(mut self) -> Self {
112 self.hide = true;
113 self
114 }
115
116 #[must_use]
118 pub fn required(mut self) -> Self {
119 self.required = Some(true);
120 self
121 }
122
123 #[must_use]
125 pub fn type_schema(mut self, schema: serde_json::Value) -> Self {
126 self.type_schema = Some(schema);
127 self
128 }
129
130 #[must_use]
132 pub fn drop_with_default(value: impl Into<serde_json::Value>) -> Self {
133 Self::new().default(value).hide()
134 }
135}
136
137pub struct TransformedTool {
145 parent: BoxedToolHandler,
147 definition: Tool,
149 arg_transforms: HashMap<String, ArgTransform>,
151 name_mapping: HashMap<String, String>,
153}
154
155impl TransformedTool {
156 pub fn from_tool<H: ToolHandler + 'static>(tool: H) -> TransformedToolBuilder {
158 TransformedToolBuilder::new(Box::new(tool))
159 }
160
161 pub fn from_boxed(tool: BoxedToolHandler) -> TransformedToolBuilder {
163 TransformedToolBuilder::new(tool)
164 }
165
166 #[must_use]
168 pub fn parent_definition(&self) -> Tool {
169 self.parent.definition()
170 }
171
172 #[must_use]
174 pub fn arg_transforms(&self) -> &HashMap<String, ArgTransform> {
175 &self.arg_transforms
176 }
177
178 fn transform_arguments(&self, arguments: serde_json::Value) -> McpResult<serde_json::Value> {
180 let mut args = match arguments {
181 serde_json::Value::Object(map) => map,
182 serde_json::Value::Null => serde_json::Map::new(),
183 _ => {
184 return Err(fastmcp_core::McpError::invalid_params(
185 "Arguments must be an object",
186 ));
187 }
188 };
189
190 let mut result = serde_json::Map::new();
191
192 for (original_name, transform) in &self.arg_transforms {
194 let new_name = transform.name.as_ref().unwrap_or(original_name);
195
196 if let Some(value) = args.remove(new_name) {
198 result.insert(original_name.clone(), value);
200 } else if let Some(default) = &transform.default {
201 result.insert(original_name.clone(), default.clone());
203 } else if transform.hide {
204 return Err(fastmcp_core::McpError::invalid_params(format!(
206 "Hidden argument '{}' requires a default value",
207 original_name
208 )));
209 }
210 }
212
213 for (key, value) in args {
215 if let Some(original) = self.name_mapping.get(&key) {
217 result.insert(original.clone(), value);
218 } else {
219 result.insert(key, value);
220 }
221 }
222
223 Ok(serde_json::Value::Object(result))
224 }
225}
226
227impl std::fmt::Debug for TransformedTool {
228 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
229 f.debug_struct("TransformedTool")
230 .field("definition", &self.definition)
231 .field("arg_transforms", &self.arg_transforms)
232 .finish_non_exhaustive()
233 }
234}
235
236impl ToolHandler for TransformedTool {
237 fn definition(&self) -> Tool {
238 self.definition.clone()
239 }
240
241 fn call(&self, ctx: &McpContext, arguments: serde_json::Value) -> McpResult<Vec<Content>> {
242 let transformed_args = self.transform_arguments(arguments)?;
243 self.parent.call(ctx, transformed_args)
244 }
245
246 fn call_async<'a>(
247 &'a self,
248 ctx: &'a McpContext,
249 arguments: serde_json::Value,
250 ) -> BoxFuture<'a, McpOutcome<Vec<Content>>> {
251 Box::pin(async move {
252 let transformed_args = match self.transform_arguments(arguments) {
253 Ok(args) => args,
254 Err(e) => return Outcome::Err(e),
255 };
256 self.parent.call_async(ctx, transformed_args).await
257 })
258 }
259}
260
261pub struct TransformedToolBuilder {
263 parent: BoxedToolHandler,
264 name: Option<String>,
265 description: Option<String>,
266 arg_transforms: HashMap<String, ArgTransform>,
267}
268
269impl TransformedToolBuilder {
270 pub fn new(parent: BoxedToolHandler) -> Self {
272 Self {
273 parent,
274 name: None,
275 description: None,
276 arg_transforms: HashMap::new(),
277 }
278 }
279
280 #[must_use]
282 pub fn name(mut self, name: impl Into<String>) -> Self {
283 self.name = Some(name.into());
284 self
285 }
286
287 #[must_use]
289 pub fn description(mut self, desc: impl Into<String>) -> Self {
290 self.description = Some(desc.into());
291 self
292 }
293
294 #[must_use]
298 pub fn transform_arg(
299 mut self,
300 original_name: impl Into<String>,
301 transform: ArgTransform,
302 ) -> Self {
303 self.arg_transforms.insert(original_name.into(), transform);
304 self
305 }
306
307 #[must_use]
309 pub fn rename_arg(self, original_name: impl Into<String>, new_name: impl Into<String>) -> Self {
310 self.transform_arg(original_name, ArgTransform::new().name(new_name))
311 }
312
313 #[must_use]
315 pub fn hide_arg(
316 self,
317 original_name: impl Into<String>,
318 default: impl Into<serde_json::Value>,
319 ) -> Self {
320 self.transform_arg(original_name, ArgTransform::drop_with_default(default))
321 }
322
323 #[must_use]
325 pub fn build(self) -> TransformedTool {
326 let parent_def = self.parent.definition();
327
328 let mut name_mapping = HashMap::new();
330 for (original, transform) in &self.arg_transforms {
331 if let Some(new_name) = &transform.name {
332 name_mapping.insert(new_name.clone(), original.clone());
333 }
334 }
335
336 let definition = self.build_definition(&parent_def);
338
339 TransformedTool {
340 parent: self.parent,
341 definition,
342 arg_transforms: self.arg_transforms,
343 name_mapping,
344 }
345 }
346
347 fn build_definition(&self, parent: &Tool) -> Tool {
349 let name = self.name.clone().unwrap_or_else(|| parent.name.clone());
350 let description = self
351 .description
352 .clone()
353 .or_else(|| parent.description.clone());
354
355 let input_schema = self.transform_schema(&parent.input_schema);
357
358 Tool {
359 name,
360 description,
361 input_schema,
362 output_schema: parent.output_schema.clone(),
363 icon: parent.icon.clone(),
364 version: parent.version.clone(),
365 tags: parent.tags.clone(),
366 annotations: parent.annotations.clone(),
367 }
368 }
369
370 fn transform_schema(&self, original: &serde_json::Value) -> serde_json::Value {
372 let mut schema = original.clone();
373
374 let Some(obj) = schema.as_object_mut() else {
375 return schema;
376 };
377
378 if !obj.contains_key("properties") {
382 obj.insert(String::from("properties"), serde_json::json!({}));
383 }
384 if !obj.contains_key("required") {
385 obj.insert(String::from("required"), serde_json::json!([]));
386 }
387
388 let capacity = self.arg_transforms.len();
391 let mut props_to_remove: Vec<String> = Vec::with_capacity(capacity);
392 let mut props_to_add: Vec<(String, serde_json::Value)> = Vec::with_capacity(capacity);
393 let mut required_renames: Vec<(String, String)> = Vec::with_capacity(capacity);
394 let mut required_removes: Vec<String> = Vec::with_capacity(capacity);
395
396 {
398 let props = obj["properties"].as_object().unwrap();
399
400 for (original_name, transform) in &self.arg_transforms {
401 if transform.hide {
402 props_to_remove.push(original_name.clone());
403 required_removes.push(original_name.clone());
404 continue;
405 }
406
407 if let Some(prop_schema) = props.get(original_name).cloned() {
408 let new_name = transform.name.as_ref().unwrap_or(original_name);
409 let mut new_schema = prop_schema;
410
411 if let (Some(desc), Some(schema_obj)) =
413 (&transform.description, new_schema.as_object_mut())
414 {
415 schema_obj.insert(String::from("description"), serde_json::json!(desc));
416 }
417
418 if let Some(type_schema) = &transform.type_schema {
420 new_schema = type_schema.clone();
421 }
422
423 if let (Some(default), Some(schema_obj)) =
425 (&transform.default, new_schema.as_object_mut())
426 {
427 schema_obj.insert(String::from("default"), default.clone());
428 }
429
430 if new_name != original_name {
431 props_to_remove.push(original_name.clone());
432 props_to_add.push((new_name.clone(), new_schema));
433 required_renames.push((original_name.clone(), new_name.clone()));
434 } else {
435 props_to_add.push((original_name.clone(), new_schema));
437 }
438 }
439 }
440 }
441
442 if let Some(props) = obj.get_mut("properties").and_then(|p| p.as_object_mut()) {
444 for name in &props_to_remove {
445 props.remove(name);
446 }
447 for (name, prop_schema) in props_to_add {
448 props.insert(name, prop_schema);
449 }
450 }
451
452 if let Some(required) = obj.get_mut("required").and_then(|r| r.as_array_mut()) {
454 for (old_name, new_name) in required_renames {
456 if let Some(idx) = required.iter().position(|v| v.as_str() == Some(&old_name)) {
457 required[idx] = serde_json::json!(new_name);
458 }
459 }
460 required.retain(|v| {
462 v.as_str()
463 .is_none_or(|s| !required_removes.iter().any(|r| r == s))
464 });
465 }
466
467 schema
468 }
469}
470
#[cfg(test)]
mod tests {
    use super::*;
    use fastmcp_protocol::Content;

    /// Minimal parent tool: a required string `q` and an optional integer `n`.
    struct MockTool {
        name: String,
        description: Option<String>,
        schema: serde_json::Value,
    }

    impl MockTool {
        fn new(name: &str) -> Self {
            let schema = serde_json::json!({
                "type": "object",
                "properties": {
                    "q": { "type": "string", "description": "Query" },
                    "n": { "type": "integer", "description": "Limit" }
                },
                "required": ["q"]
            });
            Self {
                name: String::from(name),
                description: Some(String::from("Mock tool")),
                schema,
            }
        }
    }

    impl ToolHandler for MockTool {
        fn definition(&self) -> Tool {
            Tool {
                name: self.name.clone(),
                description: self.description.clone(),
                input_schema: self.schema.clone(),
                output_schema: None,
                icon: None,
                version: None,
                tags: Vec::new(),
                annotations: None,
            }
        }

        /// Echoes the forwarded arguments back as text.
        fn call(&self, _ctx: &McpContext, arguments: serde_json::Value) -> McpResult<Vec<Content>> {
            let text = format!("Called with: {arguments}");
            Ok(vec![Content::Text { text }])
        }
    }

    #[test]
    fn test_rename_tool() {
        let def = TransformedTool::from_tool(MockTool::new("search"))
            .name("semantic_search")
            .description("Search semantically")
            .build()
            .definition();

        assert_eq!(def.name, "semantic_search");
        assert_eq!(def.description.as_deref(), Some("Search semantically"));
    }

    #[test]
    fn test_rename_arg() {
        let transformed = TransformedTool::from_tool(MockTool::new("search"))
            .rename_arg("q", "query")
            .build();

        let schema = transformed.definition().input_schema;
        let props = schema["properties"].as_object().unwrap();

        // Only the new name is advertised.
        assert!(props.contains_key("query"));
        assert!(!props.contains_key("q"));
    }

    #[test]
    fn test_hide_arg() {
        let transformed = TransformedTool::from_tool(MockTool::new("search"))
            .hide_arg("n", 10)
            .build();

        let schema = transformed.definition().input_schema;
        let props = schema["properties"].as_object().unwrap();

        // The hidden argument disappears; untouched ones remain.
        assert!(props.contains_key("q"));
        assert!(!props.contains_key("n"));
    }

    #[test]
    fn test_transform_arguments() {
        let transformed = TransformedTool::from_tool(MockTool::new("search"))
            .rename_arg("q", "query")
            .hide_arg("n", 10)
            .build();

        // The client uses the new name; the hidden argument is filled from its default.
        let forwarded = transformed
            .transform_arguments(serde_json::json!({ "query": "hello world" }))
            .unwrap();
        let obj = forwarded.as_object().unwrap();

        assert_eq!(obj.get("q"), Some(&serde_json::json!("hello world")));
        assert_eq!(obj.get("n"), Some(&serde_json::json!(10)));
    }

    #[test]
    fn test_arg_transform_builder() {
        let transform = ArgTransform::new()
            .name("search_query")
            .description("The search query string")
            .default_str("*")
            .required();

        assert_eq!(transform.name.as_deref(), Some("search_query"));
        assert_eq!(
            transform.description.as_deref(),
            Some("The search query string")
        );
        assert_eq!(transform.default, Some(serde_json::json!("*")));
        assert_eq!(transform.required, Some(true));
        assert!(!transform.hide);
    }
}