1use serde_json::Value;
7use std::collections::HashMap;
8
9use crate::core::error::{McpError, McpResult};
10use crate::protocol::{messages::*, methods, types::*, LATEST_PROTOCOL_VERSION};
11
12pub struct InitializeHandler;
14
15impl InitializeHandler {
16 pub async fn handle(
18 server_info: &ServerInfo,
19 capabilities: &ServerCapabilities,
20 params: Option<Value>,
21 ) -> McpResult<InitializeResult> {
22 let params: InitializeParams = match params {
23 Some(p) => serde_json::from_value(p)
24 .map_err(|e| McpError::Validation(format!("Invalid initialize params: {}", e)))?,
25 None => {
26 return Err(McpError::Validation(
27 "Missing initialize parameters".to_string(),
28 ))
29 }
30 };
31
32 if params.protocol_version != LATEST_PROTOCOL_VERSION {
34 return Err(McpError::Protocol(format!(
35 "Unsupported protocol version: {}. Expected: {}",
36 params.protocol_version, LATEST_PROTOCOL_VERSION
37 )));
38 }
39
40 if params.client_info.name.is_empty() {
42 return Err(McpError::Validation(
43 "Client name cannot be empty".to_string(),
44 ));
45 }
46
47 if params.client_info.version.is_empty() {
48 return Err(McpError::Validation(
49 "Client version cannot be empty".to_string(),
50 ));
51 }
52
53 Ok(InitializeResult::new(
54 LATEST_PROTOCOL_VERSION.to_string(),
55 capabilities.clone(),
56 server_info.clone(),
57 ))
58 }
59}
60
61pub struct ToolHandler;
63
64impl ToolHandler {
65 pub async fn handle_list(
67 tools: &HashMap<String, crate::core::tool::Tool>,
68 params: Option<Value>,
69 ) -> McpResult<ListToolsResult> {
70 let _params: ListToolsParams = match params {
71 Some(p) => serde_json::from_value(p)
72 .map_err(|e| McpError::Validation(format!("Invalid list tools params: {}", e)))?,
73 None => ListToolsParams::default(),
74 };
75
76 let tools: Vec<ToolInfo> = tools
78 .values()
79 .filter(|tool| tool.enabled)
80 .map(|tool| {
81 ToolInfo {
83 name: tool.info.name.clone(),
84 description: tool.info.description.clone(),
85 input_schema: tool.info.input_schema.clone(),
86 annotations: None,
87 }
88 })
89 .collect();
90
91 Ok(ListToolsResult {
92 tools,
93 next_cursor: None,
94 meta: None,
95 })
96 }
97
98 pub async fn handle_call(
100 tools: &HashMap<String, crate::core::tool::Tool>,
101 params: Option<Value>,
102 ) -> McpResult<CallToolResult> {
103 let params: CallToolParams = match params {
104 Some(p) => serde_json::from_value(p)
105 .map_err(|e| McpError::Validation(format!("Invalid call tool params: {}", e)))?,
106 None => {
107 return Err(McpError::Validation(
108 "Missing tool call parameters".to_string(),
109 ))
110 }
111 };
112
113 if params.name.is_empty() {
114 return Err(McpError::Validation(
115 "Tool name cannot be empty".to_string(),
116 ));
117 }
118
119 let tool = tools
120 .get(¶ms.name)
121 .ok_or_else(|| McpError::ToolNotFound(params.name.clone()))?;
122
123 if !tool.enabled {
124 return Err(McpError::ToolNotFound(format!(
125 "Tool '{}' is disabled",
126 params.name
127 )));
128 }
129
130 let arguments = params.arguments.unwrap_or_default();
131 let result = tool.handler.call(arguments).await?;
132
133 Ok(CallToolResult {
134 content: result.content,
135 is_error: result.is_error,
136 meta: None,
137 })
138 }
139}
140
141pub struct ResourceHandler;
143
144impl ResourceHandler {
145 pub async fn handle_list(
147 resources: &HashMap<String, crate::core::resource::Resource>,
148 params: Option<Value>,
149 ) -> McpResult<ListResourcesResult> {
150 let _params: ListResourcesParams = match params {
151 Some(p) => serde_json::from_value(p).map_err(|e| {
152 McpError::Validation(format!("Invalid list resources params: {}", e))
153 })?,
154 None => ListResourcesParams::default(),
155 };
156
157 let resources: Vec<ResourceInfo> = resources
159 .values()
160 .map(|resource| {
161 ResourceInfo {
163 uri: resource.info.uri.clone(),
164 name: resource.info.name.clone(),
165 description: resource.info.description.clone(),
166 mime_type: resource.info.mime_type.clone(),
167 annotations: None,
168 size: None,
169 }
170 })
171 .collect();
172
173 Ok(ListResourcesResult {
174 resources,
175 next_cursor: None,
176 meta: None,
177 })
178 }
179
180 pub async fn handle_read(
182 resources: &HashMap<String, crate::core::resource::Resource>,
183 params: Option<Value>,
184 ) -> McpResult<ReadResourceResult> {
185 let params: ReadResourceParams = match params {
186 Some(p) => serde_json::from_value(p).map_err(|e| {
187 McpError::Validation(format!("Invalid read resource params: {}", e))
188 })?,
189 None => {
190 return Err(McpError::Validation(
191 "Missing resource read parameters".to_string(),
192 ))
193 }
194 };
195
196 if params.uri.is_empty() {
197 return Err(McpError::Validation(
198 "Resource URI cannot be empty".to_string(),
199 ));
200 }
201
202 let resource = resources
203 .get(¶ms.uri)
204 .ok_or_else(|| McpError::ResourceNotFound(params.uri.clone()))?;
205
206 let query_params = HashMap::new();
208 let contents = resource.handler.read(¶ms.uri, &query_params).await?;
209
210 Ok(ReadResourceResult {
211 contents,
212 meta: None,
213 })
214 }
215
216 pub async fn handle_subscribe(
218 resources: &HashMap<String, crate::core::resource::Resource>,
219 params: Option<Value>,
220 ) -> McpResult<SubscribeResourceResult> {
221 let params: SubscribeResourceParams = match params {
222 Some(p) => serde_json::from_value(p).map_err(|e| {
223 McpError::Validation(format!("Invalid subscribe resource params: {}", e))
224 })?,
225 None => {
226 return Err(McpError::Validation(
227 "Missing resource subscribe parameters".to_string(),
228 ))
229 }
230 };
231
232 if params.uri.is_empty() {
233 return Err(McpError::Validation(
234 "Resource URI cannot be empty".to_string(),
235 ));
236 }
237
238 let resource = resources
239 .get(¶ms.uri)
240 .ok_or_else(|| McpError::ResourceNotFound(params.uri.clone()))?;
241
242 resource.handler.subscribe(¶ms.uri).await?;
243
244 Ok(SubscribeResourceResult { meta: None })
245 }
246
247 pub async fn handle_unsubscribe(
249 resources: &HashMap<String, crate::core::resource::Resource>,
250 params: Option<Value>,
251 ) -> McpResult<UnsubscribeResourceResult> {
252 let params: UnsubscribeResourceParams = match params {
253 Some(p) => serde_json::from_value(p).map_err(|e| {
254 McpError::Validation(format!("Invalid unsubscribe resource params: {}", e))
255 })?,
256 None => {
257 return Err(McpError::Validation(
258 "Missing resource unsubscribe parameters".to_string(),
259 ))
260 }
261 };
262
263 if params.uri.is_empty() {
264 return Err(McpError::Validation(
265 "Resource URI cannot be empty".to_string(),
266 ));
267 }
268
269 let resource = resources
270 .get(¶ms.uri)
271 .ok_or_else(|| McpError::ResourceNotFound(params.uri.clone()))?;
272
273 resource.handler.unsubscribe(¶ms.uri).await?;
274
275 Ok(UnsubscribeResourceResult { meta: None })
276 }
277}
278
279pub struct PromptHandler;
281
282impl PromptHandler {
283 pub async fn handle_list(
285 prompts: &HashMap<String, crate::core::prompt::Prompt>,
286 params: Option<Value>,
287 ) -> McpResult<ListPromptsResult> {
288 let _params: ListPromptsParams = match params {
289 Some(p) => serde_json::from_value(p)
290 .map_err(|e| McpError::Validation(format!("Invalid list prompts params: {}", e)))?,
291 None => ListPromptsParams::default(),
292 };
293
294 let prompts: Vec<PromptInfo> = prompts
296 .values()
297 .map(|prompt| {
298 PromptInfo {
300 name: prompt.info.name.clone(),
301 description: prompt.info.description.clone(),
302 arguments: prompt.info.arguments.as_ref().map(|args| {
303 args.iter()
304 .map(|arg| PromptArgument {
305 name: arg.name.clone(),
306 description: arg.description.clone(),
307 required: arg.required,
308 })
309 .collect()
310 }),
311 }
312 })
313 .collect();
314
315 Ok(ListPromptsResult {
316 prompts,
317 next_cursor: None,
318 meta: None,
319 })
320 }
321
322 pub async fn handle_get(
324 prompts: &HashMap<String, crate::core::prompt::Prompt>,
325 params: Option<Value>,
326 ) -> McpResult<GetPromptResult> {
327 let params: GetPromptParams = match params {
328 Some(p) => serde_json::from_value(p)
329 .map_err(|e| McpError::Validation(format!("Invalid get prompt params: {}", e)))?,
330 None => {
331 return Err(McpError::Validation(
332 "Missing prompt get parameters".to_string(),
333 ))
334 }
335 };
336
337 if params.name.is_empty() {
338 return Err(McpError::Validation(
339 "Prompt name cannot be empty".to_string(),
340 ));
341 }
342
343 let prompt = prompts
344 .get(¶ms.name)
345 .ok_or_else(|| McpError::PromptNotFound(params.name.clone()))?;
346
347 let arguments = params
348 .arguments
349 .unwrap_or_default()
350 .into_iter()
351 .map(|(k, v)| (k, serde_json::Value::String(v)))
352 .collect();
353 let result = prompt.handler.get(arguments).await?;
354
355 Ok(GetPromptResult {
356 description: result.description,
357 messages: result
358 .messages
359 .into_iter()
360 .map(|msg| {
361 PromptMessage {
363 role: msg.role,
364 content: match msg.content {
365 Content::Text { text, .. } => Content::Text {
366 text,
367 annotations: None,
368 },
369 Content::Image {
370 data, mime_type, ..
371 } => Content::Image {
372 data,
373 mime_type,
374 annotations: None,
375 },
376 other => other,
377 },
378 }
379 })
380 .collect(),
381 meta: None,
382 })
383 }
384}
385
386pub struct SamplingHandler;
388
389impl SamplingHandler {
390 pub async fn handle_create_message(_params: Option<Value>) -> McpResult<CreateMessageResult> {
392 Err(McpError::Protocol(
395 "Sampling not implemented on server side".to_string(),
396 ))
397 }
398}
399
400pub struct LoggingHandler;
402
403impl LoggingHandler {
404 pub async fn handle_set_level(params: Option<Value>) -> McpResult<SetLoggingLevelResult> {
406 let _params: SetLoggingLevelParams = match params {
407 Some(p) => serde_json::from_value(p).map_err(|e| {
408 McpError::Validation(format!("Invalid set logging level params: {}", e))
409 })?,
410 None => {
411 return Err(McpError::Validation(
412 "Missing logging level parameters".to_string(),
413 ))
414 }
415 };
416
417 Ok(SetLoggingLevelResult { meta: None })
421 }
422}
423
424pub struct PingHandler;
426
427impl PingHandler {
428 pub async fn handle(_params: Option<Value>) -> McpResult<PingResult> {
430 Ok(PingResult { meta: None })
431 }
432}
433
434pub mod validation {
436 use super::*;
437
438 pub fn require_params<T>(params: Option<Value>, error_msg: &str) -> McpResult<T>
440 where
441 T: serde::de::DeserializeOwned,
442 {
443 match params {
444 Some(p) => serde_json::from_value(p)
445 .map_err(|e| McpError::Validation(format!("{}: {}", error_msg, e))),
446 None => Err(McpError::Validation(error_msg.to_string())),
447 }
448 }
449
450 pub fn require_non_empty_string(value: &str, field_name: &str) -> McpResult<()> {
452 if value.is_empty() {
453 Err(McpError::Validation(format!(
454 "{} cannot be empty",
455 field_name
456 )))
457 } else {
458 Ok(())
459 }
460 }
461
462 pub fn validate_uri_format(uri: &str) -> McpResult<()> {
464 if uri.is_empty() {
465 return Err(McpError::Validation("URI cannot be empty".to_string()));
466 }
467
468 if !uri.contains("://") && !uri.starts_with('/') && !uri.starts_with("file:") {
470 return Err(McpError::Validation(
471 "URI must have a scheme or be an absolute path".to_string(),
472 ));
473 }
474
475 Ok(())
476 }
477}
478
479pub mod notifications {
481 use super::*;
482
483 pub fn tools_list_changed() -> McpResult<JsonRpcNotification> {
485 Ok(JsonRpcNotification::new(
486 methods::TOOLS_LIST_CHANGED.to_string(),
487 Some(ToolListChangedParams { meta: None }),
488 )?)
489 }
490
491 pub fn resources_list_changed() -> McpResult<JsonRpcNotification> {
493 Ok(JsonRpcNotification::new(
494 methods::RESOURCES_LIST_CHANGED.to_string(),
495 Some(ResourceListChangedParams { meta: None }),
496 )?)
497 }
498
499 pub fn prompts_list_changed() -> McpResult<JsonRpcNotification> {
501 Ok(JsonRpcNotification::new(
502 methods::PROMPTS_LIST_CHANGED.to_string(),
503 Some(PromptListChangedParams { meta: None }),
504 )?)
505 }
506
507 pub fn resource_updated(uri: String) -> McpResult<JsonRpcNotification> {
509 Ok(JsonRpcNotification::new(
510 methods::RESOURCES_UPDATED.to_string(),
511 Some(ResourceUpdatedParams { uri }),
512 )?)
513 }
514
515 pub fn progress(
517 progress_token: String,
518 progress: f32,
519 total: Option<f32>,
520 ) -> McpResult<JsonRpcNotification> {
521 Ok(JsonRpcNotification::new(
522 methods::PROGRESS.to_string(),
523 Some(ProgressParams {
524 progress_token: serde_json::Value::String(progress_token),
525 progress,
526 total,
527 message: None,
528 }),
529 )?)
530 }
531
532 pub fn log_message(
534 level: LoggingLevel,
535 logger: Option<String>,
536 data: Value,
537 ) -> McpResult<JsonRpcNotification> {
538 Ok(JsonRpcNotification::new(
539 methods::LOGGING_MESSAGE.to_string(),
540 Some(LoggingMessageParams {
541 level,
542 logger,
543 data,
544 }),
545 )?)
546 }
547}
548
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Server identity shared by the initialize tests.
    fn test_server_info() -> ServerInfo {
        ServerInfo {
            name: "test-server".to_string(),
            version: "1.0.0".to_string(),
        }
    }

    #[tokio::test]
    async fn test_initialize_handler() {
        let server_info = test_server_info();
        let capabilities = ServerCapabilities::default();

        let params = json!({
            "clientInfo": {
                "name": "test-client",
                "version": "1.0.0"
            },
            "capabilities": {},
            "protocolVersion": LATEST_PROTOCOL_VERSION
        });

        let result = InitializeHandler::handle(&server_info, &capabilities, Some(params)).await;
        assert!(result.is_ok());

        let init_result = result.unwrap();
        assert_eq!(init_result.server_info.name, "test-server");
        assert_eq!(init_result.protocol_version, LATEST_PROTOCOL_VERSION);
    }

    #[tokio::test]
    async fn test_initialize_requires_params() {
        let server_info = test_server_info();
        let capabilities = ServerCapabilities::default();

        let result = InitializeHandler::handle(&server_info, &capabilities, None).await;
        assert!(matches!(result, Err(McpError::Validation(_))));
    }

    #[tokio::test]
    async fn test_initialize_rejects_empty_client_name() {
        let server_info = test_server_info();
        let capabilities = ServerCapabilities::default();

        let params = json!({
            "clientInfo": {
                "name": "",
                "version": "1.0.0"
            },
            "capabilities": {},
            "protocolVersion": LATEST_PROTOCOL_VERSION
        });

        let result = InitializeHandler::handle(&server_info, &capabilities, Some(params)).await;
        assert!(matches!(result, Err(McpError::Validation(_))));
    }

    #[tokio::test]
    async fn test_ping_handler() {
        assert!(PingHandler::handle(None).await.is_ok());
        // Params are ignored, so supplying some must not fail either.
        assert!(PingHandler::handle(Some(json!({}))).await.is_ok());
    }

    #[tokio::test]
    async fn test_sampling_is_rejected() {
        let result = SamplingHandler::handle_create_message(None).await;
        assert!(matches!(result, Err(McpError::Protocol(_))));
    }

    #[tokio::test]
    async fn test_set_level_requires_params() {
        let result = LoggingHandler::handle_set_level(None).await;
        assert!(matches!(result, Err(McpError::Validation(_))));
    }

    #[test]
    fn test_validation_helpers() {
        assert!(validation::require_non_empty_string("test", "field").is_ok());
        assert!(validation::require_non_empty_string("", "field").is_err());

        assert!(validation::validate_uri_format("https://example.com").is_ok());
        assert!(validation::validate_uri_format("file:///path").is_ok());
        assert!(validation::validate_uri_format("/absolute/path").is_ok());
        assert!(validation::validate_uri_format("").is_err());
        assert!(validation::validate_uri_format("invalid").is_err());
    }

    #[test]
    fn test_require_params() {
        let missing: McpResult<Value> = validation::require_params(None, "missing");
        assert!(matches!(&missing, Err(McpError::Validation(msg)) if msg == "missing"));

        let parsed: Value = validation::require_params(Some(json!({"a": 1})), "bad").unwrap();
        assert_eq!(parsed, json!({"a": 1}));
    }

    #[test]
    fn test_notification_builders() {
        assert!(notifications::tools_list_changed().is_ok());
        assert!(notifications::resources_list_changed().is_ok());
        assert!(notifications::prompts_list_changed().is_ok());
        assert!(notifications::resource_updated("file:///test".to_string()).is_ok());
        assert!(notifications::progress("token".to_string(), 0.5, Some(100.0)).is_ok());
        assert!(notifications::log_message(
            LoggingLevel::Info,
            Some("test".to_string()),
            json!({"message": "test log"})
        )
        .is_ok());
    }
}