1use serde_json::Value;
7use std::collections::HashMap;
8
9use crate::core::error::{McpError, McpResult};
10use crate::protocol::{messages::*, types::*};
11
12pub struct InitializeHandler;
14
15impl InitializeHandler {
16 pub async fn handle(
18 server_info: &ServerInfo,
19 capabilities: &ServerCapabilities,
20 params: Option<Value>,
21 ) -> McpResult<InitializeResult> {
22 let params: InitializeParams = match params {
23 Some(p) => serde_json::from_value(p)
24 .map_err(|e| McpError::Validation(format!("Invalid initialize params: {}", e)))?,
25 None => {
26 return Err(McpError::Validation(
27 "Missing initialize parameters".to_string(),
28 ))
29 }
30 };
31
32 if params.protocol_version != MCP_PROTOCOL_VERSION {
34 return Err(McpError::Protocol(format!(
35 "Unsupported protocol version: {}. Expected: {}",
36 params.protocol_version, MCP_PROTOCOL_VERSION
37 )));
38 }
39
40 if params.client_info.name.is_empty() {
42 return Err(McpError::Validation(
43 "Client name cannot be empty".to_string(),
44 ));
45 }
46
47 if params.client_info.version.is_empty() {
48 return Err(McpError::Validation(
49 "Client version cannot be empty".to_string(),
50 ));
51 }
52
53 Ok(InitializeResult::new(
54 server_info.clone(),
55 capabilities.clone(),
56 MCP_PROTOCOL_VERSION.to_string(),
57 ))
58 }
59}
60
61pub struct ToolHandler;
63
64impl ToolHandler {
65 pub async fn handle_list(
67 tools: &HashMap<String, crate::core::tool::Tool>,
68 params: Option<Value>,
69 ) -> McpResult<ListToolsResult> {
70 let _params: ListToolsParams = match params {
71 Some(p) => serde_json::from_value(p)
72 .map_err(|e| McpError::Validation(format!("Invalid list tools params: {}", e)))?,
73 None => ListToolsParams::default(),
74 };
75
76 let tools: Vec<ToolInfo> = tools
78 .values()
79 .filter(|tool| tool.enabled)
80 .map(|tool| {
81 ToolInfo {
83 name: tool.info.name.clone(),
84 description: tool.info.description.clone(),
85 input_schema: tool.info.input_schema.clone(),
86 }
87 })
88 .collect();
89
90 Ok(ListToolsResult {
91 tools,
92 next_cursor: None,
93 })
94 }
95
96 pub async fn handle_call(
98 tools: &HashMap<String, crate::core::tool::Tool>,
99 params: Option<Value>,
100 ) -> McpResult<CallToolResult> {
101 let params: CallToolParams = match params {
102 Some(p) => serde_json::from_value(p)
103 .map_err(|e| McpError::Validation(format!("Invalid call tool params: {}", e)))?,
104 None => {
105 return Err(McpError::Validation(
106 "Missing tool call parameters".to_string(),
107 ))
108 }
109 };
110
111 if params.name.is_empty() {
112 return Err(McpError::Validation(
113 "Tool name cannot be empty".to_string(),
114 ));
115 }
116
117 let tool = tools
118 .get(¶ms.name)
119 .ok_or_else(|| McpError::ToolNotFound(params.name.clone()))?;
120
121 if !tool.enabled {
122 return Err(McpError::ToolNotFound(format!(
123 "Tool '{}' is disabled",
124 params.name
125 )));
126 }
127
128 let arguments = params.arguments.unwrap_or_default();
129 let result = tool.handler.call(arguments).await?;
130
131 Ok(CallToolResult {
132 content: result.content,
133 is_error: result.is_error,
134 })
135 }
136}
137
138pub struct ResourceHandler;
140
141impl ResourceHandler {
142 pub async fn handle_list(
144 resources: &HashMap<String, crate::core::resource::Resource>,
145 params: Option<Value>,
146 ) -> McpResult<ListResourcesResult> {
147 let _params: ListResourcesParams = match params {
148 Some(p) => serde_json::from_value(p).map_err(|e| {
149 McpError::Validation(format!("Invalid list resources params: {}", e))
150 })?,
151 None => ListResourcesParams::default(),
152 };
153
154 let resources: Vec<ResourceInfo> = resources
156 .values()
157 .map(|resource| {
158 ResourceInfo {
160 uri: resource.info.uri.clone(),
161 name: resource.info.name.clone(),
162 description: resource.info.description.clone(),
163 mime_type: resource.info.mime_type.clone(),
164 }
165 })
166 .collect();
167
168 Ok(ListResourcesResult {
169 resources,
170 next_cursor: None,
171 })
172 }
173
174 pub async fn handle_read(
176 resources: &HashMap<String, crate::core::resource::Resource>,
177 params: Option<Value>,
178 ) -> McpResult<ReadResourceResult> {
179 let params: ReadResourceParams = match params {
180 Some(p) => serde_json::from_value(p).map_err(|e| {
181 McpError::Validation(format!("Invalid read resource params: {}", e))
182 })?,
183 None => {
184 return Err(McpError::Validation(
185 "Missing resource read parameters".to_string(),
186 ))
187 }
188 };
189
190 if params.uri.is_empty() {
191 return Err(McpError::Validation(
192 "Resource URI cannot be empty".to_string(),
193 ));
194 }
195
196 let resource = resources
197 .get(¶ms.uri)
198 .ok_or_else(|| McpError::ResourceNotFound(params.uri.clone()))?;
199
200 let query_params = HashMap::new();
202 let contents = resource.handler.read(¶ms.uri, &query_params).await?;
203
204 Ok(ReadResourceResult { contents })
205 }
206
207 pub async fn handle_subscribe(
209 resources: &HashMap<String, crate::core::resource::Resource>,
210 params: Option<Value>,
211 ) -> McpResult<SubscribeResourceResult> {
212 let params: SubscribeResourceParams = match params {
213 Some(p) => serde_json::from_value(p).map_err(|e| {
214 McpError::Validation(format!("Invalid subscribe resource params: {}", e))
215 })?,
216 None => {
217 return Err(McpError::Validation(
218 "Missing resource subscribe parameters".to_string(),
219 ))
220 }
221 };
222
223 if params.uri.is_empty() {
224 return Err(McpError::Validation(
225 "Resource URI cannot be empty".to_string(),
226 ));
227 }
228
229 let resource = resources
230 .get(¶ms.uri)
231 .ok_or_else(|| McpError::ResourceNotFound(params.uri.clone()))?;
232
233 resource.handler.subscribe(¶ms.uri).await?;
234
235 Ok(SubscribeResourceResult {})
236 }
237
238 pub async fn handle_unsubscribe(
240 resources: &HashMap<String, crate::core::resource::Resource>,
241 params: Option<Value>,
242 ) -> McpResult<UnsubscribeResourceResult> {
243 let params: UnsubscribeResourceParams = match params {
244 Some(p) => serde_json::from_value(p).map_err(|e| {
245 McpError::Validation(format!("Invalid unsubscribe resource params: {}", e))
246 })?,
247 None => {
248 return Err(McpError::Validation(
249 "Missing resource unsubscribe parameters".to_string(),
250 ))
251 }
252 };
253
254 if params.uri.is_empty() {
255 return Err(McpError::Validation(
256 "Resource URI cannot be empty".to_string(),
257 ));
258 }
259
260 let resource = resources
261 .get(¶ms.uri)
262 .ok_or_else(|| McpError::ResourceNotFound(params.uri.clone()))?;
263
264 resource.handler.unsubscribe(¶ms.uri).await?;
265
266 Ok(UnsubscribeResourceResult {})
267 }
268}
269
270pub struct PromptHandler;
272
273impl PromptHandler {
274 pub async fn handle_list(
276 prompts: &HashMap<String, crate::core::prompt::Prompt>,
277 params: Option<Value>,
278 ) -> McpResult<ListPromptsResult> {
279 let _params: ListPromptsParams = match params {
280 Some(p) => serde_json::from_value(p)
281 .map_err(|e| McpError::Validation(format!("Invalid list prompts params: {}", e)))?,
282 None => ListPromptsParams::default(),
283 };
284
285 let prompts: Vec<PromptInfo> = prompts
287 .values()
288 .map(|prompt| {
289 PromptInfo {
291 name: prompt.info.name.clone(),
292 description: prompt.info.description.clone(),
293 arguments: prompt.info.arguments.as_ref().map(|args| {
294 args.iter()
295 .map(|arg| PromptArgument {
296 name: arg.name.clone(),
297 description: arg.description.clone(),
298 required: arg.required,
299 })
300 .collect()
301 }),
302 }
303 })
304 .collect();
305
306 Ok(ListPromptsResult {
307 prompts,
308 next_cursor: None,
309 })
310 }
311
312 pub async fn handle_get(
314 prompts: &HashMap<String, crate::core::prompt::Prompt>,
315 params: Option<Value>,
316 ) -> McpResult<GetPromptResult> {
317 let params: GetPromptParams = match params {
318 Some(p) => serde_json::from_value(p)
319 .map_err(|e| McpError::Validation(format!("Invalid get prompt params: {}", e)))?,
320 None => {
321 return Err(McpError::Validation(
322 "Missing prompt get parameters".to_string(),
323 ))
324 }
325 };
326
327 if params.name.is_empty() {
328 return Err(McpError::Validation(
329 "Prompt name cannot be empty".to_string(),
330 ));
331 }
332
333 let prompt = prompts
334 .get(¶ms.name)
335 .ok_or_else(|| McpError::PromptNotFound(params.name.clone()))?;
336
337 let arguments = params.arguments.unwrap_or_default();
338 let result = prompt.handler.get(arguments).await?;
339
340 Ok(GetPromptResult {
341 description: result.description,
342 messages: result
343 .messages
344 .into_iter()
345 .map(|msg| {
346 PromptMessage {
348 role: msg.role,
349 content: match msg.content {
350 crate::protocol::types::PromptContent::Text { content_type, text } => {
351 PromptContent::Text { content_type, text }
352 }
353 crate::protocol::types::PromptContent::Image {
354 content_type,
355 data,
356 mime_type,
357 } => PromptContent::Image {
358 content_type,
359 data,
360 mime_type,
361 },
362 },
363 }
364 })
365 .collect(),
366 })
367 }
368}
369
370pub struct SamplingHandler;
372
373impl SamplingHandler {
374 pub async fn handle_create_message(_params: Option<Value>) -> McpResult<CreateMessageResult> {
376 Err(McpError::Protocol(
379 "Sampling not implemented on server side".to_string(),
380 ))
381 }
382}
383
384pub struct LoggingHandler;
386
387impl LoggingHandler {
388 pub async fn handle_set_level(params: Option<Value>) -> McpResult<SetLoggingLevelResult> {
390 let _params: SetLoggingLevelParams = match params {
391 Some(p) => serde_json::from_value(p).map_err(|e| {
392 McpError::Validation(format!("Invalid set logging level params: {}", e))
393 })?,
394 None => {
395 return Err(McpError::Validation(
396 "Missing logging level parameters".to_string(),
397 ))
398 }
399 };
400
401 Ok(SetLoggingLevelResult {})
405 }
406}
407
408pub struct PingHandler;
410
411impl PingHandler {
412 pub async fn handle(_params: Option<Value>) -> McpResult<PingResult> {
414 Ok(PingResult {})
415 }
416}
417
418pub mod validation {
420 use super::*;
421
422 pub fn require_params<T>(params: Option<Value>, error_msg: &str) -> McpResult<T>
424 where
425 T: serde::de::DeserializeOwned,
426 {
427 match params {
428 Some(p) => serde_json::from_value(p)
429 .map_err(|e| McpError::Validation(format!("{}: {}", error_msg, e))),
430 None => Err(McpError::Validation(error_msg.to_string())),
431 }
432 }
433
434 pub fn require_non_empty_string(value: &str, field_name: &str) -> McpResult<()> {
436 if value.is_empty() {
437 Err(McpError::Validation(format!(
438 "{} cannot be empty",
439 field_name
440 )))
441 } else {
442 Ok(())
443 }
444 }
445
446 pub fn validate_uri_format(uri: &str) -> McpResult<()> {
448 if uri.is_empty() {
449 return Err(McpError::Validation("URI cannot be empty".to_string()));
450 }
451
452 if !uri.contains("://") && !uri.starts_with('/') && !uri.starts_with("file:") {
454 return Err(McpError::Validation(
455 "URI must have a scheme or be an absolute path".to_string(),
456 ));
457 }
458
459 Ok(())
460 }
461}
462
463pub mod notifications {
465 use super::*;
466
467 pub fn tools_list_changed() -> McpResult<JsonRpcNotification> {
469 Ok(JsonRpcNotification::new(
470 methods::TOOLS_LIST_CHANGED.to_string(),
471 Some(ToolListChangedParams {}),
472 )?)
473 }
474
475 pub fn resources_list_changed() -> McpResult<JsonRpcNotification> {
477 Ok(JsonRpcNotification::new(
478 methods::RESOURCES_LIST_CHANGED.to_string(),
479 Some(ResourceListChangedParams {}),
480 )?)
481 }
482
483 pub fn prompts_list_changed() -> McpResult<JsonRpcNotification> {
485 Ok(JsonRpcNotification::new(
486 methods::PROMPTS_LIST_CHANGED.to_string(),
487 Some(PromptListChangedParams {}),
488 )?)
489 }
490
491 pub fn resource_updated(uri: String) -> McpResult<JsonRpcNotification> {
493 Ok(JsonRpcNotification::new(
494 methods::RESOURCES_UPDATED.to_string(),
495 Some(ResourceUpdatedParams { uri }),
496 )?)
497 }
498
499 pub fn progress(
501 progress_token: String,
502 progress: f32,
503 total: Option<u32>,
504 ) -> McpResult<JsonRpcNotification> {
505 Ok(JsonRpcNotification::new(
506 methods::PROGRESS.to_string(),
507 Some(ProgressParams {
508 progress_token,
509 progress,
510 total,
511 }),
512 )?)
513 }
514
515 pub fn log_message(
517 level: LoggingLevel,
518 logger: Option<String>,
519 data: Value,
520 ) -> McpResult<JsonRpcNotification> {
521 Ok(JsonRpcNotification::new(
522 methods::LOGGING_MESSAGE.to_string(),
523 Some(LoggingMessageParams {
524 level,
525 logger,
526 data,
527 }),
528 )?)
529 }
530}
531
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Server identity shared by the initialize tests.
    fn test_server_info() -> ServerInfo {
        ServerInfo {
            name: "test-server".to_string(),
            version: "1.0.0".to_string(),
        }
    }

    #[tokio::test]
    async fn test_initialize_handler() {
        let server_info = test_server_info();
        let capabilities = ServerCapabilities::default();

        let params = json!({
            "clientInfo": {
                "name": "test-client",
                "version": "1.0.0"
            },
            "capabilities": {},
            "protocolVersion": MCP_PROTOCOL_VERSION
        });

        let result = InitializeHandler::handle(&server_info, &capabilities, Some(params)).await;
        assert!(result.is_ok());

        let init_result = result.unwrap();
        assert_eq!(init_result.server_info.name, "test-server");
        assert_eq!(init_result.protocol_version, MCP_PROTOCOL_VERSION);
    }

    #[tokio::test]
    async fn test_initialize_rejects_missing_params() {
        let server_info = test_server_info();
        let capabilities = ServerCapabilities::default();

        let result = InitializeHandler::handle(&server_info, &capabilities, None).await;
        assert!(matches!(result, Err(McpError::Validation(_))));
    }

    #[tokio::test]
    async fn test_initialize_rejects_unsupported_version() {
        let server_info = test_server_info();
        let capabilities = ServerCapabilities::default();

        let params = json!({
            "clientInfo": {
                "name": "test-client",
                "version": "1.0.0"
            },
            "capabilities": {},
            "protocolVersion": "0.0.0-unsupported"
        });

        let result = InitializeHandler::handle(&server_info, &capabilities, Some(params)).await;
        assert!(matches!(result, Err(McpError::Protocol(_))));
    }

    #[tokio::test]
    async fn test_initialize_rejects_empty_client_name() {
        let server_info = test_server_info();
        let capabilities = ServerCapabilities::default();

        let params = json!({
            "clientInfo": {
                "name": "",
                "version": "1.0.0"
            },
            "capabilities": {},
            "protocolVersion": MCP_PROTOCOL_VERSION
        });

        let result = InitializeHandler::handle(&server_info, &capabilities, Some(params)).await;
        assert!(matches!(result, Err(McpError::Validation(_))));
    }

    #[tokio::test]
    async fn test_tool_handlers_with_empty_registry() {
        let tools = HashMap::new();

        let listed = ToolHandler::handle_list(&tools, None).await.unwrap();
        assert!(listed.tools.is_empty());
        assert!(listed.next_cursor.is_none());

        let missing_params = ToolHandler::handle_call(&tools, None).await;
        assert!(matches!(missing_params, Err(McpError::Validation(_))));

        let unknown_tool = ToolHandler::handle_call(&tools, Some(json!({"name": "missing"}))).await;
        assert!(unknown_tool.is_err());
    }

    #[tokio::test]
    async fn test_resource_handlers_with_empty_registry() {
        let resources = HashMap::new();

        let listed = ResourceHandler::handle_list(&resources, None).await.unwrap();
        assert!(listed.resources.is_empty());

        let missing_params = ResourceHandler::handle_read(&resources, None).await;
        assert!(matches!(missing_params, Err(McpError::Validation(_))));

        let empty_uri = ResourceHandler::handle_read(&resources, Some(json!({"uri": ""}))).await;
        assert!(matches!(empty_uri, Err(McpError::Validation(_))));

        let unknown =
            ResourceHandler::handle_read(&resources, Some(json!({"uri": "file:///missing"}))).await;
        assert!(unknown.is_err());
    }

    #[tokio::test]
    async fn test_prompt_handlers_with_empty_registry() {
        let prompts = HashMap::new();

        let listed = PromptHandler::handle_list(&prompts, None).await.unwrap();
        assert!(listed.prompts.is_empty());

        let missing_params = PromptHandler::handle_get(&prompts, None).await;
        assert!(matches!(missing_params, Err(McpError::Validation(_))));
    }

    #[tokio::test]
    async fn test_logging_and_sampling_handlers() {
        let missing_level = LoggingHandler::handle_set_level(None).await;
        assert!(matches!(missing_level, Err(McpError::Validation(_))));

        let sampling = SamplingHandler::handle_create_message(None).await;
        assert!(matches!(sampling, Err(McpError::Protocol(_))));
    }

    #[tokio::test]
    async fn test_ping_handler() {
        let result = PingHandler::handle(None).await;
        assert!(result.is_ok());
    }

    #[test]
    fn test_validation_helpers() {
        assert!(validation::require_non_empty_string("test", "field").is_ok());
        assert!(validation::require_non_empty_string("", "field").is_err());

        assert!(validation::validate_uri_format("https://example.com").is_ok());
        assert!(validation::validate_uri_format("file:///path").is_ok());
        assert!(validation::validate_uri_format("/absolute/path").is_ok());
        assert!(validation::validate_uri_format("").is_err());
        assert!(validation::validate_uri_format("invalid").is_err());
    }

    #[test]
    fn test_require_params() {
        let parsed: u32 = validation::require_params(Some(json!(5)), "bad number").unwrap();
        assert_eq!(parsed, 5);

        let missing = validation::require_params::<u32>(None, "bad number");
        assert!(matches!(missing, Err(McpError::Validation(_))));

        let malformed = validation::require_params::<u32>(Some(json!("five")), "bad number");
        assert!(matches!(malformed, Err(McpError::Validation(_))));
    }

    #[test]
    fn test_notification_builders() {
        assert!(notifications::tools_list_changed().is_ok());
        assert!(notifications::resources_list_changed().is_ok());
        assert!(notifications::prompts_list_changed().is_ok());
        assert!(notifications::resource_updated("file:///test".to_string()).is_ok());
        assert!(notifications::progress("token".to_string(), 0.5, Some(100)).is_ok());
        assert!(notifications::log_message(
            LoggingLevel::Info,
            Some("test".to_string()),
            json!({"message": "test log"})
        )
        .is_ok());
    }
}