Skip to main content

ferro_api_mcp/
service.rs

1use std::sync::Arc;
2
3use rmcp::{
4    handler::server::{
5        router::tool::{ToolRoute, ToolRouter},
6        tool::ToolCallContext,
7    },
8    model::{
9        CallToolRequestParam, CallToolResult, Content, Implementation, ListToolsResult,
10        PaginatedRequestParam, ServerCapabilities, ServerInfo, Tool, ToolAnnotations,
11    },
12    service::RequestContext,
13    RoleServer, ServerHandler,
14};
15
16use crate::http::HttpClient;
17use crate::types::ApiOperation;
18
19/// MCP service that dynamically registers one tool per OpenAPI operation.
20pub struct ApiMcpService {
21    api_name: String,
22    tool_router: ToolRouter<Self>,
23    tool_count: usize,
24}
25
26impl ApiMcpService {
27    /// Build an MCP service from parsed API operations.
28    ///
29    /// Each `ApiOperation` becomes one MCP tool. The `http_client` is shared
30    /// across all tool handlers for executing API calls.
31    pub fn new(
32        api_name: String,
33        operations: Vec<ApiOperation>,
34        http_client: Arc<HttpClient>,
35    ) -> Self {
36        let tool_count = operations.len();
37        let mut router = ToolRouter::new();
38
39        for op in operations {
40            let annotations = annotations_for_method(&op.method);
41            let input_schema = input_schema_to_arc_map(&op.input_schema);
42
43            let description = match &op.hint {
44                Some(hint) => format!("{}\n\nHint: {hint}", op.description),
45                None => op.description.clone(),
46            };
47
48            let tool =
49                Tool::new(op.tool_name.clone(), description, input_schema).annotate(annotations);
50
51            let client = Arc::clone(&http_client);
52            let route = ToolRoute::new_dyn(tool, move |ctx: ToolCallContext<'_, Self>| {
53                let client = Arc::clone(&client);
54                let op = op.clone();
55                Box::pin(async move {
56                    let args = ctx.arguments.unwrap_or_default();
57
58                    let validation_errors = validate_args(&op.input_schema, &args);
59                    if !validation_errors.is_empty() {
60                        let msg = format!(
61                            "Invalid arguments:\n{}",
62                            validation_errors
63                                .iter()
64                                .map(|e| format!("  - {e}"))
65                                .collect::<Vec<_>>()
66                                .join("\n")
67                        );
68                        return Ok(CallToolResult::error(vec![Content::text(msg)]));
69                    }
70
71                    match client.execute(&op, &args).await {
72                        Ok(response) => {
73                            let text = serde_json::to_string_pretty(&response)
74                                .unwrap_or_else(|_| response.to_string());
75                            Ok(CallToolResult::success(vec![Content::text(text)]))
76                        }
77                        Err(err) => {
78                            let msg = match &err {
79                                crate::error::Error::ApiError { status, body } => {
80                                    format!("API returned HTTP {status}:\n{body}")
81                                }
82                                crate::error::Error::HttpClient(detail) => {
83                                    format!("Connection error: {detail}")
84                                }
85                                other => format!("Error: {other}"),
86                            };
87                            Ok(CallToolResult::error(vec![Content::text(msg)]))
88                        }
89                    }
90                })
91            });
92
93            router.add_route(route);
94        }
95
96        Self {
97            api_name,
98            tool_router: router,
99            tool_count,
100        }
101    }
102}
103
104impl ServerHandler for ApiMcpService {
105    fn get_info(&self) -> ServerInfo {
106        ServerInfo {
107            protocol_version: Default::default(),
108            capabilities: ServerCapabilities::builder().enable_tools().build(),
109            server_info: Implementation {
110                name: "ferro-api-mcp".to_string(),
111                title: None,
112                version: env!("CARGO_PKG_VERSION").to_string(),
113                icons: None,
114                website_url: None,
115            },
116            instructions: Some(format!(
117                "API tools for {}. {} tools available. Use these tools to interact with the API.",
118                self.api_name, self.tool_count
119            )),
120        }
121    }
122
123    fn list_tools(
124        &self,
125        _request: Option<PaginatedRequestParam>,
126        _context: RequestContext<RoleServer>,
127    ) -> impl Future<Output = Result<ListToolsResult, rmcp::ErrorData>> + Send + '_ {
128        std::future::ready(Ok(ListToolsResult::with_all_items(
129            self.tool_router.list_all(),
130        )))
131    }
132
133    fn call_tool(
134        &self,
135        request: CallToolRequestParam,
136        context: RequestContext<RoleServer>,
137    ) -> impl Future<Output = Result<CallToolResult, rmcp::ErrorData>> + Send + '_ {
138        let tcc = ToolCallContext::new(self, request, context);
139        async move { self.tool_router.call(tcc).await }
140    }
141}
142
143/// Map HTTP method to MCP tool annotations.
144fn annotations_for_method(method: &str) -> ToolAnnotations {
145    match method.to_uppercase().as_str() {
146        "GET" => ToolAnnotations::new()
147            .read_only(true)
148            .idempotent(true)
149            .open_world(true),
150        "POST" => ToolAnnotations::new().read_only(false).open_world(true),
151        "PUT" | "PATCH" => ToolAnnotations::new()
152            .read_only(false)
153            .idempotent(true)
154            .open_world(true),
155        "DELETE" => ToolAnnotations::new()
156            .read_only(false)
157            .destructive(true)
158            .open_world(true),
159        _ => ToolAnnotations::new().open_world(true),
160    }
161}
162
163/// Convert a `serde_json::Value` (expected to be an object) into
164/// `Arc<serde_json::Map<String, serde_json::Value>>` for the MCP `Tool` input schema.
165fn input_schema_to_arc_map(
166    value: &serde_json::Value,
167) -> Arc<serde_json::Map<String, serde_json::Value>> {
168    match value {
169        serde_json::Value::Object(map) => Arc::new(map.clone()),
170        _ => Arc::new(serde_json::Map::new()),
171    }
172}
173
174/// Validate tool arguments against the operation's input schema.
175///
176/// Checks required fields are present and basic type correctness.
177/// Returns a list of validation errors, empty if valid.
178fn validate_args(
179    input_schema: &serde_json::Value,
180    args: &serde_json::Map<String, serde_json::Value>,
181) -> Vec<String> {
182    let mut errors = Vec::new();
183
184    // Check required fields
185    if let Some(required) = input_schema.get("required").and_then(|r| r.as_array()) {
186        for field in required {
187            if let Some(name) = field.as_str() {
188                if !args.contains_key(name) {
189                    errors.push(format!("missing required field: '{name}'"));
190                }
191            }
192        }
193    }
194
195    // Check type correctness for provided fields
196    if let Some(properties) = input_schema.get("properties").and_then(|p| p.as_object()) {
197        for (name, value) in args {
198            if let Some(prop_schema) = properties.get(name) {
199                if let Some(expected_type) = prop_schema.get("type").and_then(|t| t.as_str()) {
200                    let type_ok = match expected_type {
201                        "string" => value.is_string(),
202                        "integer" => value.is_i64() || value.is_u64(),
203                        "number" => value.is_number(),
204                        "boolean" => value.is_boolean(),
205                        "object" => value.is_object(),
206                        "array" => value.is_array(),
207                        _ => true,
208                    };
209                    if !type_ok {
210                        errors.push(format!(
211                            "field '{name}' expects type '{expected_type}', got {}",
212                            json_type_name(value)
213                        ));
214                    }
215                }
216            }
217        }
218    }
219
220    errors
221}
222
223/// Returns a human-readable type name for a JSON value.
224fn json_type_name(value: &serde_json::Value) -> &'static str {
225    match value {
226        serde_json::Value::Null => "null",
227        serde_json::Value::Bool(_) => "boolean",
228        serde_json::Value::Number(_) => "number",
229        serde_json::Value::String(_) => "string",
230        serde_json::Value::Array(_) => "array",
231        serde_json::Value::Object(_) => "object",
232    }
233}
234
235// Required for async trait methods in ServerHandler
236use std::future::Future;
237
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{json, Map, Value};

    /// Turn a `json!` object literal into the argument map `validate_args` takes.
    fn args_of(value: Value) -> Map<String, Value> {
        match value {
            Value::Object(map) => map,
            other => panic!("test arguments must be a JSON object, got {other}"),
        }
    }

    #[test]
    fn validate_args_catches_missing_required_field() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "email": {"type": "string"}
            },
            "required": ["name", "email"]
        });
        // "email" is deliberately left out.
        let args = args_of(json!({ "name": "Alice" }));

        let errors = validate_args(&schema, &args);

        assert_eq!(errors.len(), 1);
        assert!(errors[0].contains("email"));
    }

    #[test]
    fn validate_args_catches_wrong_type() {
        let schema = json!({
            "type": "object",
            "properties": { "count": {"type": "integer"} },
            "required": []
        });
        let args = args_of(json!({ "count": "not a number" }));

        let errors = validate_args(&schema, &args);

        assert_eq!(errors.len(), 1);
        let message = &errors[0];
        for fragment in ["count", "integer", "string"] {
            assert!(message.contains(fragment), "{message:?} lacks {fragment:?}");
        }
    }

    #[test]
    fn validate_args_passes_valid_args() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "integer"},
                "active": {"type": "boolean"}
            },
            "required": ["name"]
        });
        let args = args_of(json!({ "name": "Alice", "age": 30, "active": true }));

        assert!(validate_args(&schema, &args).is_empty());
    }

    #[test]
    fn validate_args_ignores_unknown_fields() {
        let schema = json!({
            "type": "object",
            "properties": { "name": {"type": "string"} },
            "required": ["name"]
        });
        let args = args_of(json!({ "name": "Alice", "extra_field": 42 }));

        assert!(validate_args(&schema, &args).is_empty());
    }

    #[test]
    fn validate_args_passes_empty_required() {
        let schema = json!({
            "type": "object",
            "properties": { "name": {"type": "string"} },
            "required": []
        });

        assert!(validate_args(&schema, &Map::new()).is_empty());
    }

    #[test]
    fn validate_args_checks_all_types() {
        let schema = json!({
            "type": "object",
            "properties": {
                "s": {"type": "string"},
                "n": {"type": "number"},
                "b": {"type": "boolean"},
                "a": {"type": "array"},
                "o": {"type": "object"}
            },
            "required": []
        });
        // Every value deliberately has the wrong JSON type for its field.
        let args = args_of(json!({
            "s": 123,
            "n": "text",
            "b": "true",
            "a": {},
            "o": []
        }));

        assert_eq!(validate_args(&schema, &args).len(), 5);
    }

    #[test]
    fn validate_args_number_accepts_integers() {
        let schema = json!({
            "type": "object",
            "properties": { "value": {"type": "number"} },
            "required": []
        });
        let args = args_of(json!({ "value": 42 }));

        assert!(validate_args(&schema, &args).is_empty());
    }

    #[test]
    fn json_type_name_returns_correct_names() {
        let cases = [
            (json!(null), "null"),
            (json!(true), "boolean"),
            (json!(42), "number"),
            (json!("hello"), "string"),
            (json!([]), "array"),
            (json!({}), "object"),
        ];
        for (value, expected) in cases {
            assert_eq!(json_type_name(&value), expected);
        }
    }
}