rmcp_openapi/
tool_generator.rs

1use serde_json::{Value, json};
2use std::collections::HashMap;
3
4use crate::error::OpenApiError;
5use crate::server::ToolMetadata;
6use openapiv3::{Operation, Parameter, ParameterData, ReferenceOr, Schema, SchemaKind};
7
/// Stateless namespace for creating MCP tools from OpenAPI operations.
///
/// Every capability is exposed as an associated function; the type carries no data,
/// so it is trivially copyable and comparable.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ToolGenerator;
10
11impl ToolGenerator {
12    /// Generate tool metadata from an OpenAPI operation
13    pub fn generate_tool_metadata(
14        operation: &Operation,
15        method: String,
16        path: String,
17    ) -> Result<ToolMetadata, OpenApiError> {
18        let name = operation.operation_id.clone().unwrap_or_else(|| {
19            format!(
20                "{}_{}",
21                method,
22                path.replace('/', "_").replace(['{', '}'], "")
23            )
24        });
25
26        // Build description from summary and description
27        let description = Self::build_description(operation, &method, &path);
28
29        // Generate parameter schema
30        let parameters = Self::generate_parameter_schema(&operation.parameters, &method)?;
31
32        Ok(ToolMetadata {
33            name,
34            description,
35            parameters,
36            method,
37            path,
38        })
39    }
40
41    /// Build a comprehensive description for the tool
42    fn build_description(operation: &Operation, method: &str, path: &str) -> String {
43        match (&operation.summary, &operation.description) {
44            (Some(summary), Some(desc)) => {
45                format!(
46                    "{}\n\n{}\n\nEndpoint: {} {}",
47                    summary,
48                    desc,
49                    method.to_uppercase(),
50                    path
51                )
52            }
53            (Some(summary), None) => {
54                format!(
55                    "{}\n\nEndpoint: {} {}",
56                    summary,
57                    method.to_uppercase(),
58                    path
59                )
60            }
61            (None, Some(desc)) => {
62                format!("{}\n\nEndpoint: {} {}", desc, method.to_uppercase(), path)
63            }
64            (None, None) => {
65                format!("API endpoint: {} {}", method.to_uppercase(), path)
66            }
67        }
68    }
69
70    /// Generate JSON Schema for tool parameters
71    fn generate_parameter_schema(
72        parameters: &[ReferenceOr<Parameter>],
73        method: &str,
74    ) -> Result<Value, OpenApiError> {
75        let mut properties = serde_json::Map::new();
76        let mut required = Vec::new();
77
78        // Group parameters by location
79        let mut path_params = Vec::new();
80        let mut query_params = Vec::new();
81        let mut header_params = Vec::new();
82        let mut cookie_params = Vec::new();
83
84        for param_ref in parameters {
85            match param_ref {
86                ReferenceOr::Item(param) => match param {
87                    Parameter::Query { parameter_data, .. } => query_params.push(parameter_data),
88                    Parameter::Header { parameter_data, .. } => header_params.push(parameter_data),
89                    Parameter::Path { parameter_data, .. } => path_params.push(parameter_data),
90                    Parameter::Cookie { parameter_data, .. } => cookie_params.push(parameter_data),
91                },
92                ReferenceOr::Reference { .. } => {
93                    // For now, skip reference parameters - could be implemented later
94                    continue;
95                }
96            }
97        }
98
99        // Process path parameters (always required)
100        for param_data in path_params {
101            let param_schema = Self::convert_parameter_schema(param_data, "path")?;
102            properties.insert(param_data.name.clone(), param_schema);
103            required.push(param_data.name.clone());
104        }
105
106        // Process query parameters
107        for param_data in &query_params {
108            let param_schema = Self::convert_parameter_schema(param_data, "query")?;
109            properties.insert(param_data.name.clone(), param_schema);
110            if param_data.required {
111                required.push(param_data.name.clone());
112            }
113        }
114
115        // Process header parameters (optional by default unless explicitly required)
116        for param_data in &header_params {
117            let mut param_schema = Self::convert_parameter_schema(param_data, "header")?;
118
119            // Add location metadata for headers
120            if let Value::Object(ref mut obj) = param_schema {
121                obj.insert("x-location".to_string(), json!("header"));
122            }
123
124            properties.insert(format!("header_{}", param_data.name), param_schema);
125            if param_data.required {
126                required.push(format!("header_{}", param_data.name));
127            }
128        }
129
130        // Process cookie parameters (rare, but supported)
131        for param_data in &cookie_params {
132            let mut param_schema = Self::convert_parameter_schema(param_data, "cookie")?;
133
134            // Add location metadata for cookies
135            if let Value::Object(ref mut obj) = param_schema {
136                obj.insert("x-location".to_string(), json!("cookie"));
137            }
138
139            properties.insert(format!("cookie_{}", param_data.name), param_schema);
140            if param_data.required {
141                required.push(format!("cookie_{}", param_data.name));
142            }
143        }
144
145        // Add request body parameter for operations that typically need it
146        if ["post", "put", "patch"].contains(&method.to_lowercase().as_str()) {
147            properties.insert(
148                "request_body".to_string(),
149                json!({
150                    "type": "object",
151                    "description": "Request body data (JSON)",
152                    "additionalProperties": true,
153                    "x-location": "body",
154                    "x-content-type": mime::APPLICATION_JSON.as_ref()
155                }),
156            );
157        }
158
159        // Add special parameters for request configuration
160        if !query_params.is_empty() || !header_params.is_empty() || !cookie_params.is_empty() {
161            // Add optional timeout parameter
162            properties.insert(
163                "timeout_seconds".to_string(),
164                json!({
165                    "type": "integer",
166                    "description": "Request timeout in seconds",
167                    "minimum": 1,
168                    "maximum": 300,
169                    "default": 30
170                }),
171            );
172        }
173
174        Ok(json!({
175            "type": "object",
176            "properties": properties,
177            "required": required,
178            "additionalProperties": false
179        }))
180    }
181
182    /// Convert OpenAPI parameter schema to JSON Schema for MCP tools
183    fn convert_parameter_schema(
184        param_data: &ParameterData,
185        location: &str,
186    ) -> Result<Value, OpenApiError> {
187        let mut result = serde_json::Map::new();
188
189        // Handle the parameter schema
190        match &param_data.format {
191            openapiv3::ParameterSchemaOrContent::Schema(schema_ref) => {
192                match schema_ref {
193                    ReferenceOr::Item(schema) => {
194                        Self::convert_schema_to_json_schema(schema, &mut result)?;
195                    }
196                    ReferenceOr::Reference { .. } => {
197                        // For now, default to string for references
198                        result.insert("type".to_string(), json!("string"));
199                    }
200                }
201            }
202            openapiv3::ParameterSchemaOrContent::Content(_) => {
203                // For content parameters, default to object
204                result.insert("type".to_string(), json!("object"));
205            }
206        }
207
208        // Add description
209        if let Some(desc) = &param_data.description {
210            result.insert("description".to_string(), json!(desc));
211        } else {
212            result.insert(
213                "description".to_string(),
214                json!(format!("{} parameter", param_data.name)),
215            );
216        }
217
218        // Add parameter location metadata
219        result.insert("x-parameter-location".to_string(), json!(location));
220        result.insert(
221            "x-parameter-required".to_string(),
222            json!(param_data.required),
223        );
224
225        Ok(Value::Object(result))
226    }
227
228    /// Convert openapiv3::Schema to JSON Schema properties
229    fn convert_schema_to_json_schema(
230        schema: &Schema,
231        result: &mut serde_json::Map<String, Value>,
232    ) -> Result<(), OpenApiError> {
233        match &schema.schema_kind {
234            SchemaKind::Type(type_) => match type_ {
235                openapiv3::Type::String(string_type) => {
236                    result.insert("type".to_string(), json!("string"));
237                    if let Some(min_length) = string_type.min_length {
238                        result.insert("minLength".to_string(), json!(min_length));
239                    }
240                    if let Some(max_length) = string_type.max_length {
241                        result.insert("maxLength".to_string(), json!(max_length));
242                    }
243                    if let Some(pattern) = &string_type.pattern {
244                        result.insert("pattern".to_string(), json!(pattern));
245                    }
246                    if let openapiv3::VariantOrUnknownOrEmpty::Item(format) = &string_type.format {
247                        result.insert("format".to_string(), json!(format!("{:?}", format)));
248                    }
249                }
250                openapiv3::Type::Number(number_type) => {
251                    result.insert("type".to_string(), json!("number"));
252                    if let Some(minimum) = number_type.minimum {
253                        result.insert("minimum".to_string(), json!(minimum));
254                    }
255                    if let Some(maximum) = number_type.maximum {
256                        result.insert("maximum".to_string(), json!(maximum));
257                    }
258                    if let openapiv3::VariantOrUnknownOrEmpty::Item(format) = &number_type.format {
259                        result.insert("format".to_string(), json!(format!("{:?}", format)));
260                    }
261                }
262                openapiv3::Type::Integer(integer_type) => {
263                    result.insert("type".to_string(), json!("integer"));
264                    if let Some(minimum) = integer_type.minimum {
265                        result.insert("minimum".to_string(), json!(minimum));
266                    }
267                    if let Some(maximum) = integer_type.maximum {
268                        result.insert("maximum".to_string(), json!(maximum));
269                    }
270                    if let openapiv3::VariantOrUnknownOrEmpty::Item(format) = &integer_type.format {
271                        result.insert("format".to_string(), json!(format!("{:?}", format)));
272                    }
273                }
274                openapiv3::Type::Boolean(_) => {
275                    result.insert("type".to_string(), json!("boolean"));
276                }
277                openapiv3::Type::Array(array_type) => {
278                    result.insert("type".to_string(), json!("array"));
279                    if let Some(items) = &array_type.items {
280                        match items {
281                            ReferenceOr::Item(item_schema) => {
282                                let mut items_result = serde_json::Map::new();
283                                Self::convert_schema_to_json_schema(
284                                    item_schema,
285                                    &mut items_result,
286                                )?;
287                                result.insert("items".to_string(), Value::Object(items_result));
288                            }
289                            ReferenceOr::Reference { .. } => {
290                                result.insert("items".to_string(), json!({"type": "string"}));
291                            }
292                        }
293                    } else {
294                        result.insert("items".to_string(), json!({"type": "string"}));
295                    }
296                }
297                openapiv3::Type::Object(_) => {
298                    result.insert("type".to_string(), json!("object"));
299                    result.insert("additionalProperties".to_string(), json!(true));
300                }
301            },
302            SchemaKind::OneOf { .. } | SchemaKind::AllOf { .. } | SchemaKind::AnyOf { .. } => {
303                // For complex schema types, default to object
304                result.insert("type".to_string(), json!("object"));
305            }
306            SchemaKind::Not { .. } => {
307                // For not schema, default to string
308                result.insert("type".to_string(), json!("string"));
309            }
310            SchemaKind::Any(_) => {
311                // For any schema, allow any type
312                result.insert("type".to_string(), json!("object"));
313            }
314        }
315
316        // Handle enum values - in openapiv3 this is typically handled in the schema_kind
317        // For now, we'll skip enum handling as it's more complex in openapiv3
318
319        Ok(())
320    }
321
322    /// Extract parameter values from MCP tool call arguments
323    pub fn extract_parameters(
324        tool_metadata: &ToolMetadata,
325        arguments: &Value,
326    ) -> Result<ExtractedParameters, OpenApiError> {
327        let args = arguments
328            .as_object()
329            .ok_or_else(|| OpenApiError::Validation("Arguments must be an object".to_string()))?;
330
331        let mut path_params = HashMap::new();
332        let mut query_params = HashMap::new();
333        let mut header_params = HashMap::new();
334        let mut cookie_params = HashMap::new();
335        let mut body_params = HashMap::new();
336        let mut config = RequestConfig::default();
337
338        // Extract timeout if provided
339        if let Some(timeout) = args.get("timeout_seconds").and_then(|v| v.as_u64()) {
340            config.timeout_seconds = timeout as u32;
341        }
342
343        // Process each argument
344        for (key, value) in args {
345            if key == "timeout_seconds" {
346                continue; // Already processed
347            }
348
349            // Handle special request_body parameter
350            if key == "request_body" {
351                body_params.insert("request_body".to_string(), value.clone());
352                continue;
353            }
354
355            // Determine parameter location from the tool metadata
356            let location = Self::get_parameter_location(tool_metadata, key)?;
357
358            match location.as_str() {
359                "path" => {
360                    path_params.insert(key.clone(), value.clone());
361                }
362                "query" => {
363                    query_params.insert(key.clone(), value.clone());
364                }
365                "header" => {
366                    // Remove "header_" prefix if present
367                    let header_name = if key.starts_with("header_") {
368                        key.strip_prefix("header_").unwrap_or(key).to_string()
369                    } else {
370                        key.clone()
371                    };
372                    header_params.insert(header_name, value.clone());
373                }
374                "cookie" => {
375                    // Remove "cookie_" prefix if present
376                    let cookie_name = if key.starts_with("cookie_") {
377                        key.strip_prefix("cookie_").unwrap_or(key).to_string()
378                    } else {
379                        key.clone()
380                    };
381                    cookie_params.insert(cookie_name, value.clone());
382                }
383                "body" => {
384                    // Remove "body_" prefix if present
385                    let body_name = if key.starts_with("body_") {
386                        key.strip_prefix("body_").unwrap_or(key).to_string()
387                    } else {
388                        key.clone()
389                    };
390                    body_params.insert(body_name, value.clone());
391                }
392                _ => {
393                    return Err(OpenApiError::ToolGeneration(format!(
394                        "Unknown parameter location for parameter: {key}"
395                    )));
396                }
397            }
398        }
399
400        let extracted = ExtractedParameters {
401            path: path_params,
402            query: query_params,
403            headers: header_params,
404            cookies: cookie_params,
405            body: body_params,
406            config,
407        };
408
409        // Validate parameters against tool metadata
410        Self::validate_parameters(tool_metadata, &extracted)?;
411
412        Ok(extracted)
413    }
414
415    /// Get parameter location from tool metadata
416    fn get_parameter_location(
417        tool_metadata: &ToolMetadata,
418        param_name: &str,
419    ) -> Result<String, OpenApiError> {
420        let properties = tool_metadata
421            .parameters
422            .get("properties")
423            .and_then(|p| p.as_object())
424            .ok_or_else(|| {
425                OpenApiError::ToolGeneration("Invalid tool parameters schema".to_string())
426            })?;
427
428        if let Some(param_schema) = properties.get(param_name) {
429            if let Some(location) = param_schema
430                .get("x-parameter-location")
431                .and_then(|v| v.as_str())
432            {
433                return Ok(location.to_string());
434            }
435        }
436
437        // Fallback: infer from parameter name prefix
438        if param_name.starts_with("header_") {
439            Ok("header".to_string())
440        } else if param_name.starts_with("cookie_") {
441            Ok("cookie".to_string())
442        } else if param_name.starts_with("body_") {
443            Ok("body".to_string())
444        } else {
445            // Default to query for unknown parameters
446            Ok("query".to_string())
447        }
448    }
449
450    /// Validate extracted parameters against tool metadata
451    fn validate_parameters(
452        tool_metadata: &ToolMetadata,
453        extracted: &ExtractedParameters,
454    ) -> Result<(), OpenApiError> {
455        let schema = &tool_metadata.parameters;
456
457        // Get required parameters from schema
458        let required_params = schema
459            .get("required")
460            .and_then(|r| r.as_array())
461            .map(|arr| {
462                arr.iter()
463                    .filter_map(|v| v.as_str())
464                    .collect::<std::collections::HashSet<_>>()
465            })
466            .unwrap_or_default();
467
468        let _properties = schema
469            .get("properties")
470            .and_then(|p| p.as_object())
471            .ok_or_else(|| {
472                OpenApiError::Validation("Tool schema missing properties".to_string())
473            })?;
474
475        // Check all required parameters are provided
476        for required_param in &required_params {
477            let param_found = extracted.path.contains_key(*required_param)
478                || extracted.query.contains_key(*required_param)
479                || extracted
480                    .headers
481                    .contains_key(&required_param.replace("header_", ""))
482                || extracted
483                    .cookies
484                    .contains_key(&required_param.replace("cookie_", ""))
485                || extracted
486                    .body
487                    .contains_key(&required_param.replace("body_", ""))
488                || (*required_param == "request_body"
489                    && extracted.body.contains_key("request_body"));
490
491            if !param_found {
492                return Err(OpenApiError::InvalidParameter {
493                    parameter: required_param.to_string(),
494                    reason: "Required parameter is missing".to_string(),
495                });
496            }
497        }
498
499        Ok(())
500    }
501}
502
503/// Extracted parameters from MCP tool call
504#[derive(Debug, Clone)]
505pub struct ExtractedParameters {
506    pub path: HashMap<String, Value>,
507    pub query: HashMap<String, Value>,
508    pub headers: HashMap<String, Value>,
509    pub cookies: HashMap<String, Value>,
510    pub body: HashMap<String, Value>,
511    pub config: RequestConfig,
512}
513
/// Request configuration options.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RequestConfig {
    /// Request timeout in seconds.
    pub timeout_seconds: u32,
    /// Media type string.
    // NOTE(review): presumably applied to the outgoing request body - confirm at the call site.
    pub content_type: String,
}
520
521impl Default for RequestConfig {
522    fn default() -> Self {
523        Self {
524            timeout_seconds: 30,
525            content_type: mime::APPLICATION_JSON.to_string(),
526        }
527    }
528}
529
530#[cfg(test)]
531mod tests {
532    use super::*;
533    use openapiv3::*;
534    use serde_json::{Value, json};
535
536    fn validate_tool_against_mcp_schema(metadata: &ToolMetadata) {
537        let schema_content = std::fs::read_to_string("schema/2025-03-26/schema.json")
538            .expect("Failed to read MCP schema file");
539        let full_schema: Value =
540            serde_json::from_str(&schema_content).expect("Failed to parse MCP schema JSON");
541
542        // Create a schema that references the Tool definition from the full schema
543        let tool_schema = json!({
544            "$schema": "http://json-schema.org/draft-07/schema#",
545            "definitions": full_schema.get("definitions"),
546            "$ref": "#/definitions/Tool"
547        });
548
549        let validator =
550            jsonschema::validator_for(&tool_schema).expect("Failed to compile MCP Tool schema");
551
552        // Convert ToolMetadata to MCP Tool format
553        let mcp_tool = json!({
554            "name": metadata.name,
555            "description": metadata.description,
556            "inputSchema": metadata.parameters
557        });
558
559        // Validate the generated tool against MCP schema
560        let errors: Vec<String> = validator
561            .iter_errors(&mcp_tool)
562            .map(|e| e.to_string())
563            .collect();
564
565        if !errors.is_empty() {
566            panic!("Generated tool failed MCP schema validation: {errors:?}");
567        }
568    }
569
570    #[test]
571    fn test_petstore_get_pet_by_id() {
572        let mut operation = Operation {
573            operation_id: Some("getPetById".to_string()),
574            summary: Some("Find pet by ID".to_string()),
575            description: Some("Returns a single pet".to_string()),
576            ..Default::default()
577        };
578
579        // Create a path parameter
580        let param_data = ParameterData {
581            name: "petId".to_string(),
582            description: Some("ID of pet to return".to_string()),
583            required: true,
584            deprecated: None,
585            format: ParameterSchemaOrContent::Schema(ReferenceOr::Item(Schema {
586                schema_data: SchemaData::default(),
587                schema_kind: SchemaKind::Type(Type::Integer(IntegerType {
588                    format: openapiv3::VariantOrUnknownOrEmpty::Item(IntegerFormat::Int64),
589                    minimum: Some(1),
590                    maximum: None,
591                    exclusive_minimum: false,
592                    exclusive_maximum: false,
593                    multiple_of: None,
594                    enumeration: Vec::new(),
595                })),
596            })),
597            example: None,
598            examples: indexmap::IndexMap::new(),
599            extensions: indexmap::IndexMap::new(),
600            explode: None,
601        };
602
603        operation
604            .parameters
605            .push(ReferenceOr::Item(Parameter::Path {
606                parameter_data: param_data,
607                style: Default::default(),
608            }));
609
610        let metadata = ToolGenerator::generate_tool_metadata(
611            &operation,
612            "get".to_string(),
613            "/pet/{petId}".to_string(),
614        )
615        .unwrap();
616
617        assert_eq!(metadata.name, "getPetById");
618        assert_eq!(metadata.method, "get");
619        assert_eq!(metadata.path, "/pet/{petId}");
620        assert!(metadata.description.contains("Find pet by ID"));
621
622        // Validate against MCP Tool schema
623        validate_tool_against_mcp_schema(&metadata);
624    }
625}