1use async_trait::async_trait;
44use parking_lot::RwLock;
45use std::collections::HashMap;
46use std::sync::Arc;
47use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
48use tracing::{debug, error, info, warn};
49
50use crate::error::McpError;
51use crate::protocol::*;
52
/// Static identity and capability switches for an [`McpServer`].
///
/// `name`, `version` and `instructions` are echoed back to the client in the
/// `initialize` result; the `enable_*` flags decide which capability entries
/// that result advertises.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Server name reported as `serverInfo.name`.
    pub name: String,
    /// Server version reported as `serverInfo.version`.
    pub version: String,
    /// Optional free-form usage instructions returned from `initialize`.
    pub instructions: Option<String>,
    /// Advertise the `tools` capability.
    pub enable_tools: bool,
    /// Advertise the `resources` capability.
    pub enable_resources: bool,
    /// Advertise the `prompts` capability.
    pub enable_prompts: bool,
}
73
74impl Default for ServerConfig {
75 fn default() -> Self {
76 Self {
77 name: "cortexai-mcp-server".to_string(),
78 version: env!("CARGO_PKG_VERSION").to_string(),
79 instructions: None,
80 enable_tools: true,
81 enable_resources: false,
82 enable_prompts: false,
83 }
84 }
85}
86
87#[async_trait]
93pub trait ToolHandler: Send + Sync {
94 fn definition(&self) -> McpTool;
96
97 async fn execute(&self, arguments: serde_json::Value) -> Result<CallToolResult, McpError>;
99}
100
101#[async_trait]
103pub trait ResourceHandler: Send + Sync {
104 fn list(&self) -> Vec<McpResource>;
106
107 async fn read(&self, uri: &str) -> Result<ResourceContent, McpError>;
109}
110
111#[async_trait]
113pub trait PromptHandler: Send + Sync {
114 fn list(&self) -> Vec<McpPrompt>;
116
117 async fn get(
119 &self,
120 name: &str,
121 arguments: HashMap<String, String>,
122 ) -> Result<Vec<PromptMessage>, McpError>;
123}
124
125#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
127pub struct PromptMessage {
128 pub role: String,
129 pub content: PromptContent,
130}
131
132#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
134#[serde(tag = "type")]
135pub enum PromptContent {
136 #[serde(rename = "text")]
137 Text { text: String },
138 #[serde(rename = "image")]
139 Image { data: String, mime_type: String },
140 #[serde(rename = "resource")]
141 Resource { resource: ResourceContent },
142}
143
144pub struct FnTool<F>
150where
151 F: Fn(serde_json::Value) -> Result<serde_json::Value, String> + Send + Sync + 'static,
152{
153 definition: McpTool,
154 handler: F,
155}
156
157impl<F> FnTool<F>
158where
159 F: Fn(serde_json::Value) -> Result<serde_json::Value, String> + Send + Sync + 'static,
160{
161 pub fn new(
162 name: impl Into<String>,
163 description: impl Into<String>,
164 schema: serde_json::Value,
165 handler: F,
166 ) -> Self {
167 Self {
168 definition: McpTool {
169 name: name.into(),
170 description: Some(description.into()),
171 input_schema: schema,
172 },
173 handler,
174 }
175 }
176}
177
178#[async_trait]
179impl<F> ToolHandler for FnTool<F>
180where
181 F: Fn(serde_json::Value) -> Result<serde_json::Value, String> + Send + Sync + 'static,
182{
183 fn definition(&self) -> McpTool {
184 self.definition.clone()
185 }
186
187 async fn execute(&self, arguments: serde_json::Value) -> Result<CallToolResult, McpError> {
188 match (self.handler)(arguments) {
189 Ok(result) => Ok(CallToolResult {
190 content: vec![ToolContent::text(result.to_string())],
191 is_error: false,
192 }),
193 Err(e) => Ok(CallToolResult {
194 content: vec![ToolContent::text(e)],
195 is_error: true,
196 }),
197 }
198 }
199}
200
201pub struct AsyncFnTool<F, Fut>
203where
204 F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
205 Fut: std::future::Future<Output = Result<serde_json::Value, String>> + Send + 'static,
206{
207 definition: McpTool,
208 handler: F,
209}
210
211impl<F, Fut> AsyncFnTool<F, Fut>
212where
213 F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
214 Fut: std::future::Future<Output = Result<serde_json::Value, String>> + Send + 'static,
215{
216 pub fn new(
217 name: impl Into<String>,
218 description: impl Into<String>,
219 schema: serde_json::Value,
220 handler: F,
221 ) -> Self {
222 Self {
223 definition: McpTool {
224 name: name.into(),
225 description: Some(description.into()),
226 input_schema: schema,
227 },
228 handler,
229 }
230 }
231}
232
233#[async_trait]
234impl<F, Fut> ToolHandler for AsyncFnTool<F, Fut>
235where
236 F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
237 Fut: std::future::Future<Output = Result<serde_json::Value, String>> + Send + 'static,
238{
239 fn definition(&self) -> McpTool {
240 self.definition.clone()
241 }
242
243 async fn execute(&self, arguments: serde_json::Value) -> Result<CallToolResult, McpError> {
244 match (self.handler)(arguments).await {
245 Ok(result) => Ok(CallToolResult {
246 content: vec![ToolContent::text(result.to_string())],
247 is_error: false,
248 }),
249 Err(e) => Ok(CallToolResult {
250 content: vec![ToolContent::text(e)],
251 is_error: true,
252 }),
253 }
254 }
255}
256
257pub struct McpServer {
263 config: ServerConfig,
264 tools: RwLock<HashMap<String, Arc<dyn ToolHandler>>>,
265 resources: RwLock<Option<Arc<dyn ResourceHandler>>>,
266 prompts: RwLock<Option<Arc<dyn PromptHandler>>>,
267 initialized: RwLock<bool>,
268}
269
270impl McpServer {
271 pub fn new() -> Self {
273 Self::with_config(ServerConfig::default())
274 }
275
276 pub fn with_config(config: ServerConfig) -> Self {
278 Self {
279 config,
280 tools: RwLock::new(HashMap::new()),
281 resources: RwLock::new(None),
282 prompts: RwLock::new(None),
283 initialized: RwLock::new(false),
284 }
285 }
286
287 pub fn builder() -> McpServerBuilder {
289 McpServerBuilder::new()
290 }
291
292 pub fn tool_count(&self) -> usize {
294 self.tools.read().len()
295 }
296
297 pub fn add_tool(&self, handler: impl ToolHandler + 'static) {
299 let def = handler.definition();
300 self.tools
301 .write()
302 .insert(def.name.clone(), Arc::new(handler));
303 }
304
305 pub fn set_resource_handler(&self, handler: impl ResourceHandler + 'static) {
307 *self.resources.write() = Some(Arc::new(handler));
308 }
309
310 pub fn set_prompt_handler(&self, handler: impl PromptHandler + 'static) {
312 *self.prompts.write() = Some(Arc::new(handler));
313 }
314
315 pub async fn run_stdio(self: Arc<Self>) -> Result<(), McpError> {
317 info!(
318 "Starting MCP server '{}' v{} on STDIO",
319 self.config.name, self.config.version
320 );
321
322 let stdin = tokio::io::stdin();
323 let mut stdout = tokio::io::stdout();
324 let reader = BufReader::new(stdin);
325 let mut lines = reader.lines();
326
327 while let Ok(Some(line)) = lines.next_line().await {
328 debug!("Received: {}", line);
329
330 let response = self.handle_message_internal(&line).await;
331
332 if let Some(resp) = response {
333 let json = serde_json::to_string(&resp).unwrap();
334 debug!("Sending: {}", json);
335
336 if let Err(e) = stdout.write_all(format!("{}\n", json).as_bytes()).await {
337 error!("Failed to write response: {}", e);
338 break;
339 }
340 if let Err(e) = stdout.flush().await {
341 error!("Failed to flush response: {}", e);
342 break;
343 }
344 }
345 }
346
347 info!("MCP server shutting down");
348 Ok(())
349 }
350
351 async fn handle_message_internal(&self, message: &str) -> Option<JsonRpcResponse> {
353 if let Ok(request) = serde_json::from_str::<JsonRpcRequest>(message) {
355 let response = self.handle_request(request).await;
356 return Some(response);
357 }
358
359 if let Ok(notification) = serde_json::from_str::<JsonRpcNotification>(message) {
361 self.handle_notification(notification).await;
362 return None;
363 }
364
365 warn!("Failed to parse message: {}", message);
366 None
367 }
368
369 pub async fn handle_request(&self, request: JsonRpcRequest) -> JsonRpcResponse {
371 debug!("Handling request: {} (id={:?})", request.method, request.id);
372
373 let result = match request.method.as_str() {
374 "initialize" => self.handle_initialize(request.params).await,
375 "ping" => Ok(serde_json::json!({})),
376 "tools/list" => self.handle_tools_list(request.params).await,
377 "tools/call" => self.handle_tools_call(request.params).await,
378 "resources/list" => self.handle_resources_list(request.params).await,
379 "resources/read" => self.handle_resources_read(request.params).await,
380 "prompts/list" => self.handle_prompts_list(request.params).await,
381 "prompts/get" => self.handle_prompts_get(request.params).await,
382 _ => Err(McpError::MethodNotFound(request.method.clone())),
383 };
384
385 match result {
386 Ok(value) => JsonRpcResponse {
387 jsonrpc: "2.0".to_string(),
388 id: request.id,
389 result: Some(value),
390 error: None,
391 },
392 Err(e) => {
393 let (code, message) = match &e {
394 McpError::MethodNotFound(_) => (-32601, e.to_string()),
395 McpError::InvalidParams(msg) => (-32602, msg.clone()),
396 _ => (-32000, e.to_string()),
397 };
398 JsonRpcResponse {
399 jsonrpc: "2.0".to_string(),
400 id: request.id,
401 result: None,
402 error: Some(JsonRpcError {
403 code,
404 message,
405 data: None,
406 }),
407 }
408 }
409 }
410 }
411
412 async fn handle_notification(&self, notification: JsonRpcNotification) {
414 debug!("Handling notification: {}", notification.method);
415
416 match notification.method.as_str() {
417 "notifications/initialized" => {
418 info!("Client initialized");
419 }
420 "notifications/cancelled" => {
421 debug!("Request cancelled");
422 }
423 _ => {
424 debug!("Unknown notification: {}", notification.method);
425 }
426 }
427 }
428
429 async fn handle_initialize(
434 &self,
435 params: Option<serde_json::Value>,
436 ) -> Result<serde_json::Value, McpError> {
437 let _params: InitializeParams = params
438 .map(serde_json::from_value)
439 .transpose()
440 .map_err(|e| McpError::InvalidParams(e.to_string()))?
441 .unwrap_or_else(|| InitializeParams {
442 protocol_version: PROTOCOL_VERSION.to_string(),
443 capabilities: ClientCapabilities::default(),
444 client_info: Implementation {
445 name: "unknown".to_string(),
446 version: "0.0.0".to_string(),
447 },
448 });
449
450 *self.initialized.write() = true;
451
452 let result = InitializeResult {
453 protocol_version: PROTOCOL_VERSION.to_string(),
454 capabilities: ServerCapabilities {
455 tools: if self.config.enable_tools {
456 Some(ToolsCapability {
457 list_changed: Some(true),
458 })
459 } else {
460 None
461 },
462 resources: if self.config.enable_resources {
463 Some(ResourcesCapability {
464 subscribe: Some(false),
465 list_changed: Some(true),
466 })
467 } else {
468 None
469 },
470 prompts: if self.config.enable_prompts {
471 Some(PromptsCapability {
472 list_changed: Some(true),
473 })
474 } else {
475 None
476 },
477 logging: None,
478 experimental: None,
479 },
480 server_info: Implementation {
481 name: self.config.name.clone(),
482 version: self.config.version.clone(),
483 },
484 instructions: self.config.instructions.clone(),
485 };
486
487 serde_json::to_value(result).map_err(|e| McpError::Internal(e.to_string()))
488 }
489
490 async fn handle_tools_list(
491 &self,
492 _params: Option<serde_json::Value>,
493 ) -> Result<serde_json::Value, McpError> {
494 let tools = self.tools.read();
495 let tool_list: Vec<McpTool> = tools.values().map(|h| h.definition()).collect();
496
497 let result = ListToolsResult {
498 tools: tool_list,
499 next_cursor: None,
500 };
501
502 serde_json::to_value(result).map_err(|e| McpError::Internal(e.to_string()))
503 }
504
505 async fn handle_tools_call(
506 &self,
507 params: Option<serde_json::Value>,
508 ) -> Result<serde_json::Value, McpError> {
509 let params: CallToolParams = params
510 .map(serde_json::from_value)
511 .transpose()
512 .map_err(|e| McpError::InvalidParams(e.to_string()))?
513 .ok_or_else(|| McpError::InvalidParams("Missing params".to_string()))?;
514
515 let handler = {
516 let tools = self.tools.read();
517 tools
518 .get(¶ms.name)
519 .ok_or_else(|| McpError::ToolNotFound(params.name.clone()))?
520 .clone()
521 };
522
523 let arguments = params.arguments.unwrap_or(serde_json::json!({}));
524 let result = handler.execute(arguments).await?;
525
526 serde_json::to_value(result).map_err(|e| McpError::Internal(e.to_string()))
527 }
528
529 async fn handle_resources_list(
530 &self,
531 _params: Option<serde_json::Value>,
532 ) -> Result<serde_json::Value, McpError> {
533 let handler = self.resources.read();
534 let resources = handler.as_ref().map(|h| h.list()).unwrap_or_default();
535
536 let result = ListResourcesResult {
537 resources,
538 next_cursor: None,
539 };
540
541 serde_json::to_value(result).map_err(|e| McpError::Internal(e.to_string()))
542 }
543
544 async fn handle_resources_read(
545 &self,
546 params: Option<serde_json::Value>,
547 ) -> Result<serde_json::Value, McpError> {
548 #[derive(serde::Deserialize)]
549 struct ReadParams {
550 uri: String,
551 }
552
553 let params: ReadParams = params
554 .map(serde_json::from_value)
555 .transpose()
556 .map_err(|e| McpError::InvalidParams(e.to_string()))?
557 .ok_or_else(|| McpError::InvalidParams("Missing uri".to_string()))?;
558
559 let handler = {
560 let guard = self.resources.read();
561 guard
562 .as_ref()
563 .ok_or_else(|| McpError::CapabilityNotSupported("resources".to_string()))?
564 .clone()
565 };
566
567 let content = handler.read(¶ms.uri).await?;
568
569 let result = serde_json::json!({
570 "contents": [content]
571 });
572
573 Ok(result)
574 }
575
576 async fn handle_prompts_list(
577 &self,
578 _params: Option<serde_json::Value>,
579 ) -> Result<serde_json::Value, McpError> {
580 let handler = self.prompts.read();
581 let prompts = handler.as_ref().map(|h| h.list()).unwrap_or_default();
582
583 let result = ListPromptsResult {
584 prompts,
585 next_cursor: None,
586 };
587
588 serde_json::to_value(result).map_err(|e| McpError::Internal(e.to_string()))
589 }
590
591 async fn handle_prompts_get(
592 &self,
593 params: Option<serde_json::Value>,
594 ) -> Result<serde_json::Value, McpError> {
595 #[derive(serde::Deserialize)]
596 struct GetParams {
597 name: String,
598 #[serde(default)]
599 arguments: HashMap<String, String>,
600 }
601
602 let params: GetParams = params
603 .map(serde_json::from_value)
604 .transpose()
605 .map_err(|e| McpError::InvalidParams(e.to_string()))?
606 .ok_or_else(|| McpError::InvalidParams("Missing name".to_string()))?;
607
608 let handler = {
609 let guard = self.prompts.read();
610 guard
611 .as_ref()
612 .ok_or_else(|| McpError::CapabilityNotSupported("prompts".to_string()))?
613 .clone()
614 };
615
616 let messages = handler.get(¶ms.name, params.arguments).await?;
617
618 let result = serde_json::json!({
619 "messages": messages
620 });
621
622 Ok(result)
623 }
624}
625
626impl Default for McpServer {
627 fn default() -> Self {
628 Self::new()
629 }
630}
631
632pub struct McpServerBuilder {
638 config: ServerConfig,
639 tools: Vec<Arc<dyn ToolHandler>>,
640 resource_handler: Option<Arc<dyn ResourceHandler>>,
641 prompt_handler: Option<Arc<dyn PromptHandler>>,
642}
643
644impl McpServerBuilder {
645 pub fn new() -> Self {
646 Self {
647 config: ServerConfig::default(),
648 tools: Vec::new(),
649 resource_handler: None,
650 prompt_handler: None,
651 }
652 }
653
654 pub fn name(mut self, name: impl Into<String>) -> Self {
655 self.config.name = name.into();
656 self
657 }
658
659 pub fn version(mut self, version: impl Into<String>) -> Self {
660 self.config.version = version.into();
661 self
662 }
663
664 pub fn instructions(mut self, instructions: impl Into<String>) -> Self {
665 self.config.instructions = Some(instructions.into());
666 self
667 }
668
669 pub fn enable_resources(mut self, enable: bool) -> Self {
670 self.config.enable_resources = enable;
671 self
672 }
673
674 pub fn enable_prompts(mut self, enable: bool) -> Self {
675 self.config.enable_prompts = enable;
676 self
677 }
678
679 pub fn add_tool(mut self, handler: impl ToolHandler + 'static) -> Self {
680 self.tools.push(Arc::new(handler));
681 self
682 }
683
684 pub fn resource_handler(mut self, handler: impl ResourceHandler + 'static) -> Self {
685 self.config.enable_resources = true;
686 self.resource_handler = Some(Arc::new(handler));
687 self
688 }
689
690 pub fn prompt_handler(mut self, handler: impl PromptHandler + 'static) -> Self {
691 self.config.enable_prompts = true;
692 self.prompt_handler = Some(Arc::new(handler));
693 self
694 }
695
696 pub fn build(self) -> Arc<McpServer> {
697 let server = McpServer::with_config(self.config);
698
699 for tool in self.tools {
700 let def = tool.definition();
701 server.tools.write().insert(def.name.clone(), tool);
702 }
703
704 if let Some(handler) = self.resource_handler {
705 *server.resources.write() = Some(handler);
706 }
707
708 if let Some(handler) = self.prompt_handler {
709 *server.prompts.write() = Some(handler);
710 }
711
712 Arc::new(server)
713 }
714}
715
716impl Default for McpServerBuilder {
717 fn default() -> Self {
718 Self::new()
719 }
720}
721
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Minimal tool that returns its `message` argument verbatim.
    struct EchoTool;

    #[async_trait]
    impl ToolHandler for EchoTool {
        fn definition(&self) -> McpTool {
            McpTool {
                name: "echo".to_string(),
                description: Some("Echoes the input".to_string()),
                input_schema: json!({
                    "type": "object",
                    "properties": {
                        "message": {"type": "string"}
                    },
                    "required": ["message"]
                }),
            }
        }

        async fn execute(&self, arguments: serde_json::Value) -> Result<CallToolResult, McpError> {
            let message = arguments
                .get("message")
                .and_then(|v| v.as_str())
                .unwrap_or("No message");

            Ok(CallToolResult {
                content: vec![ToolContent::text(message)],
                is_error: false,
            })
        }
    }

    #[tokio::test]
    async fn test_server_builder() {
        let server = McpServer::builder()
            .name("test-server")
            .version("1.0.0")
            .add_tool(EchoTool)
            .build();

        assert_eq!(server.config.name, "test-server");
        assert_eq!(server.config.version, "1.0.0");
        assert!(server.tools.read().contains_key("echo"));
    }

    #[tokio::test]
    async fn test_handle_initialize() {
        let server = McpServer::builder()
            .name("test-server")
            .version("1.0.0")
            .build();

        let request = JsonRpcRequest::new(1i64, "initialize").with_params(json!({
            "protocolVersion": PROTOCOL_VERSION,
            "capabilities": {},
            "clientInfo": {
                "name": "test-client",
                "version": "1.0.0"
            }
        }));

        let response = server.handle_request(request).await;

        assert!(response.error.is_none());
        assert!(response.result.is_some());

        let result = response.result.unwrap();
        assert_eq!(result["serverInfo"]["name"], "test-server");
        assert_eq!(result["protocolVersion"], PROTOCOL_VERSION);
    }

    #[tokio::test]
    async fn test_handle_ping() {
        let server = McpServer::new();

        let response = server.handle_request(JsonRpcRequest::new(1i64, "ping")).await;

        assert!(response.error.is_none());
        assert_eq!(response.result.unwrap(), json!({}));
    }

    #[tokio::test]
    async fn test_handle_tools_list() {
        let server = McpServer::builder()
            .name("test-server")
            .add_tool(EchoTool)
            .build();

        *server.initialized.write() = true;

        let request = JsonRpcRequest::new(1i64, "tools/list");
        let response = server.handle_request(request).await;

        assert!(response.error.is_none());

        let result = response.result.unwrap();
        let tools = result["tools"].as_array().unwrap();

        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0]["name"], "echo");
    }

    #[tokio::test]
    async fn test_handle_tools_call() {
        let server = McpServer::builder()
            .name("test-server")
            .add_tool(EchoTool)
            .build();

        *server.initialized.write() = true;

        let request = JsonRpcRequest::new(1i64, "tools/call").with_params(json!({
            "name": "echo",
            "arguments": {
                "message": "Hello, MCP!"
            }
        }));

        let response = server.handle_request(request).await;

        assert!(response.error.is_none());

        let result = response.result.unwrap();
        assert_eq!(result["isError"], false);

        let content = &result["content"][0];
        assert_eq!(content["type"], "text");
        assert_eq!(content["text"], "Hello, MCP!");
    }

    #[tokio::test]
    async fn test_tools_call_missing_params() {
        let server = McpServer::builder().add_tool(EchoTool).build();

        let response = server
            .handle_request(JsonRpcRequest::new(1i64, "tools/call"))
            .await;

        assert!(response.result.is_none());
        assert_eq!(response.error.unwrap().code, -32602);
    }

    #[tokio::test]
    async fn test_tools_call_unknown_tool() {
        let server = McpServer::builder().add_tool(EchoTool).build();

        let request =
            JsonRpcRequest::new(1i64, "tools/call").with_params(json!({ "name": "missing" }));
        let response = server.handle_request(request).await;

        assert!(response.result.is_none());
        assert!(response.error.is_some());
    }

    #[tokio::test]
    async fn test_resources_read_without_handler() {
        let server = McpServer::new();

        let request = JsonRpcRequest::new(1i64, "resources/read")
            .with_params(json!({ "uri": "file:///nothing" }));
        let response = server.handle_request(request).await;

        assert!(response.result.is_none());
        assert!(response.error.is_some());
    }

    #[tokio::test]
    async fn test_handle_unknown_method() {
        let server = McpServer::new();

        let request = JsonRpcRequest::new(1i64, "unknown/method");
        let response = server.handle_request(request).await;

        assert!(response.error.is_some());
        assert_eq!(response.error.unwrap().code, -32601);
    }

    #[tokio::test]
    async fn test_fn_tool() {
        let tool = FnTool::new(
            "add",
            "Adds two numbers",
            json!({
                "type": "object",
                "properties": {
                    "a": {"type": "number"},
                    "b": {"type": "number"}
                }
            }),
            |args| {
                let a = args["a"].as_f64().unwrap_or(0.0);
                let b = args["b"].as_f64().unwrap_or(0.0);
                Ok(json!(a + b))
            },
        );

        let result = tool.execute(json!({"a": 2, "b": 3})).await.unwrap();
        assert!(!result.is_error);
        let text = result.content[0].as_text().unwrap();
        // An f64 sum may render with or without a fractional part.
        assert!(text == "5" || text == "5.0");
    }

    #[tokio::test]
    async fn test_fn_tool_error() {
        let tool = FnTool::new("fail", "Always fails", json!({}), |_args| {
            Err("boom".to_string())
        });

        let result = tool.execute(json!({})).await.unwrap();
        assert!(result.is_error);
        let text = result.content[0].as_text().unwrap();
        assert_eq!(text, "boom");
    }

    #[tokio::test]
    async fn test_async_fn_tool() {
        let tool = AsyncFnTool::new(
            "double",
            "Doubles a number",
            json!({
                "type": "object",
                "properties": {
                    "x": {"type": "number"}
                }
            }),
            |args: serde_json::Value| async move {
                let x = args["x"].as_i64().unwrap_or(0);
                Ok::<_, String>(json!(x * 2))
            },
        );

        assert_eq!(tool.definition().name, "double");

        let result = tool.execute(json!({"x": 21})).await.unwrap();
        assert!(!result.is_error);
        let text = result.content[0].as_text().unwrap();
        assert_eq!(text, "42");
    }
}