1use std::sync::Arc;
2
3use rmcp::{
4 handler::server::{
5 router::tool::{ToolRoute, ToolRouter},
6 tool::ToolCallContext,
7 },
8 model::{
9 CallToolRequestParam, CallToolResult, Content, Implementation, ListToolsResult,
10 PaginatedRequestParam, ServerCapabilities, ServerInfo, Tool, ToolAnnotations,
11 },
12 service::RequestContext,
13 RoleServer, ServerHandler,
14};
15
16use crate::http::HttpClient;
17use crate::types::ApiOperation;
18
19pub struct ApiMcpService {
21 api_name: String,
22 tool_router: ToolRouter<Self>,
23 tool_count: usize,
24}
25
26impl ApiMcpService {
27 pub fn new(
32 api_name: String,
33 operations: Vec<ApiOperation>,
34 http_client: Arc<HttpClient>,
35 ) -> Self {
36 let tool_count = operations.len();
37 let mut router = ToolRouter::new();
38
39 for op in operations {
40 let annotations = annotations_for_method(&op.method);
41 let input_schema = input_schema_to_arc_map(&op.input_schema);
42
43 let description = match &op.hint {
44 Some(hint) => format!("{}\n\nHint: {hint}", op.description),
45 None => op.description.clone(),
46 };
47
48 let tool =
49 Tool::new(op.tool_name.clone(), description, input_schema).annotate(annotations);
50
51 let client = Arc::clone(&http_client);
52 let route = ToolRoute::new_dyn(tool, move |ctx: ToolCallContext<'_, Self>| {
53 let client = Arc::clone(&client);
54 let op = op.clone();
55 Box::pin(async move {
56 let args = ctx.arguments.unwrap_or_default();
57
58 let validation_errors = validate_args(&op.input_schema, &args);
59 if !validation_errors.is_empty() {
60 let msg = format!(
61 "Invalid arguments:\n{}",
62 validation_errors
63 .iter()
64 .map(|e| format!(" - {e}"))
65 .collect::<Vec<_>>()
66 .join("\n")
67 );
68 return Ok(CallToolResult::error(vec![Content::text(msg)]));
69 }
70
71 match client.execute(&op, &args).await {
72 Ok(response) => {
73 let text = serde_json::to_string_pretty(&response)
74 .unwrap_or_else(|_| response.to_string());
75 Ok(CallToolResult::success(vec![Content::text(text)]))
76 }
77 Err(err) => {
78 let msg = match &err {
79 crate::error::Error::ApiError { status, body } => {
80 format!("API returned HTTP {status}:\n{body}")
81 }
82 crate::error::Error::HttpClient(detail) => {
83 format!("Connection error: {detail}")
84 }
85 other => format!("Error: {other}"),
86 };
87 Ok(CallToolResult::error(vec![Content::text(msg)]))
88 }
89 }
90 })
91 });
92
93 router.add_route(route);
94 }
95
96 Self {
97 api_name,
98 tool_router: router,
99 tool_count,
100 }
101 }
102}
103
104impl ServerHandler for ApiMcpService {
105 fn get_info(&self) -> ServerInfo {
106 ServerInfo {
107 protocol_version: Default::default(),
108 capabilities: ServerCapabilities::builder().enable_tools().build(),
109 server_info: Implementation {
110 name: "ferro-api-mcp".to_string(),
111 title: None,
112 version: env!("CARGO_PKG_VERSION").to_string(),
113 icons: None,
114 website_url: None,
115 },
116 instructions: Some(format!(
117 "API tools for {}. {} tools available. Use these tools to interact with the API.",
118 self.api_name, self.tool_count
119 )),
120 }
121 }
122
123 fn list_tools(
124 &self,
125 _request: Option<PaginatedRequestParam>,
126 _context: RequestContext<RoleServer>,
127 ) -> impl Future<Output = Result<ListToolsResult, rmcp::ErrorData>> + Send + '_ {
128 std::future::ready(Ok(ListToolsResult::with_all_items(
129 self.tool_router.list_all(),
130 )))
131 }
132
133 fn call_tool(
134 &self,
135 request: CallToolRequestParam,
136 context: RequestContext<RoleServer>,
137 ) -> impl Future<Output = Result<CallToolResult, rmcp::ErrorData>> + Send + '_ {
138 let tcc = ToolCallContext::new(self, request, context);
139 async move { self.tool_router.call(tcc).await }
140 }
141}
142
143fn annotations_for_method(method: &str) -> ToolAnnotations {
145 match method.to_uppercase().as_str() {
146 "GET" => ToolAnnotations::new()
147 .read_only(true)
148 .idempotent(true)
149 .open_world(true),
150 "POST" => ToolAnnotations::new().read_only(false).open_world(true),
151 "PUT" | "PATCH" => ToolAnnotations::new()
152 .read_only(false)
153 .idempotent(true)
154 .open_world(true),
155 "DELETE" => ToolAnnotations::new()
156 .read_only(false)
157 .destructive(true)
158 .open_world(true),
159 _ => ToolAnnotations::new().open_world(true),
160 }
161}
162
163fn input_schema_to_arc_map(
166 value: &serde_json::Value,
167) -> Arc<serde_json::Map<String, serde_json::Value>> {
168 match value {
169 serde_json::Value::Object(map) => Arc::new(map.clone()),
170 _ => Arc::new(serde_json::Map::new()),
171 }
172}
173
174fn validate_args(
179 input_schema: &serde_json::Value,
180 args: &serde_json::Map<String, serde_json::Value>,
181) -> Vec<String> {
182 let mut errors = Vec::new();
183
184 if let Some(required) = input_schema.get("required").and_then(|r| r.as_array()) {
186 for field in required {
187 if let Some(name) = field.as_str() {
188 if !args.contains_key(name) {
189 errors.push(format!("missing required field: '{name}'"));
190 }
191 }
192 }
193 }
194
195 if let Some(properties) = input_schema.get("properties").and_then(|p| p.as_object()) {
197 for (name, value) in args {
198 if let Some(prop_schema) = properties.get(name) {
199 if let Some(expected_type) = prop_schema.get("type").and_then(|t| t.as_str()) {
200 let type_ok = match expected_type {
201 "string" => value.is_string(),
202 "integer" => value.is_i64() || value.is_u64(),
203 "number" => value.is_number(),
204 "boolean" => value.is_boolean(),
205 "object" => value.is_object(),
206 "array" => value.is_array(),
207 _ => true,
208 };
209 if !type_ok {
210 errors.push(format!(
211 "field '{name}' expects type '{expected_type}', got {}",
212 json_type_name(value)
213 ));
214 }
215 }
216 }
217 }
218 }
219
220 errors
221}
222
223fn json_type_name(value: &serde_json::Value) -> &'static str {
225 match value {
226 serde_json::Value::Null => "null",
227 serde_json::Value::Bool(_) => "boolean",
228 serde_json::Value::Number(_) => "number",
229 serde_json::Value::String(_) => "string",
230 serde_json::Value::Array(_) => "array",
231 serde_json::Value::Object(_) => "object",
232 }
233}
234
235use std::future::Future;
237
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Unwraps a `json!` object literal into the argument map `validate_args` takes.
    fn args_from(value: serde_json::Value) -> serde_json::Map<String, serde_json::Value> {
        match value {
            serde_json::Value::Object(map) => map,
            other => panic!("test arguments must be a JSON object, got {other}"),
        }
    }

    #[test]
    fn validate_args_catches_missing_required_field() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "email": {"type": "string"}
            },
            "required": ["name", "email"]
        });
        let args = args_from(json!({"name": "Alice"}));

        let errors = validate_args(&schema, &args);

        assert_eq!(errors.len(), 1);
        assert!(errors[0].contains("email"));
    }

    #[test]
    fn validate_args_catches_wrong_type() {
        let schema = json!({
            "type": "object",
            "properties": {
                "count": {"type": "integer"}
            },
            "required": []
        });
        let args = args_from(json!({"count": "not a number"}));

        let errors = validate_args(&schema, &args);

        assert_eq!(errors.len(), 1);
        let message = &errors[0];
        assert!(message.contains("count"));
        assert!(message.contains("integer"));
        assert!(message.contains("string"));
    }

    #[test]
    fn validate_args_passes_valid_args() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "integer"},
                "active": {"type": "boolean"}
            },
            "required": ["name"]
        });
        let args = args_from(json!({"name": "Alice", "age": 30, "active": true}));

        assert!(validate_args(&schema, &args).is_empty());
    }

    #[test]
    fn validate_args_ignores_unknown_fields() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            },
            "required": ["name"]
        });
        let args = args_from(json!({"name": "Alice", "extra_field": 42}));

        assert!(validate_args(&schema, &args).is_empty());
    }

    #[test]
    fn validate_args_passes_empty_required() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            },
            "required": []
        });
        let args = serde_json::Map::new();

        assert!(validate_args(&schema, &args).is_empty());
    }

    #[test]
    fn validate_args_checks_all_types() {
        let schema = json!({
            "type": "object",
            "properties": {
                "s": {"type": "string"},
                "n": {"type": "number"},
                "b": {"type": "boolean"},
                "a": {"type": "array"},
                "o": {"type": "object"}
            },
            "required": []
        });
        // Every argument deliberately carries a value of the wrong JSON type.
        let args = args_from(json!({
            "s": 123,
            "n": "text",
            "b": "true",
            "a": {},
            "o": []
        }));

        assert_eq!(validate_args(&schema, &args).len(), 5);
    }

    #[test]
    fn validate_args_number_accepts_integers() {
        let schema = json!({
            "type": "object",
            "properties": {
                "value": {"type": "number"}
            },
            "required": []
        });
        let args = args_from(json!({"value": 42}));

        assert!(validate_args(&schema, &args).is_empty());
    }

    #[test]
    fn json_type_name_returns_correct_names() {
        let cases = [
            (json!(null), "null"),
            (json!(true), "boolean"),
            (json!(42), "number"),
            (json!("hello"), "string"),
            (json!([]), "array"),
            (json!({}), "object"),
        ];
        for (value, expected) in &cases {
            assert_eq!(json_type_name(value), *expected);
        }
    }
}