// mcp_sdk_rs/server/mod.rs

use std::sync::Arc;

use async_trait::async_trait;
use futures::StreamExt;
use tokio::sync::RwLock;

use crate::{
    error::{Error, ErrorCode},
    protocol::{Request, Response, ResponseError},
    transport::{Message, Transport},
    types::{ClientCapabilities, Implementation, ServerCapabilities},
};

13/// Trait for implementing MCP server handlers
14#[async_trait]
15pub trait ServerHandler: Send + Sync {
16    /// Handle initialization
17    async fn initialize(
18        &self,
19        implementation: Implementation,
20        capabilities: ClientCapabilities,
21    ) -> Result<ServerCapabilities, Error>;
22
23    /// Handle shutdown request
24    async fn shutdown(&self) -> Result<(), Error>;
25
26    /// Handle custom method calls
27    async fn handle_method(
28        &self,
29        method: &str,
30        params: Option<serde_json::Value>,
31    ) -> Result<serde_json::Value, Error>;
32}
33
34/// Server state
35pub struct Server {
36    transport: Arc<dyn Transport>,
37    handler: Arc<dyn ServerHandler>,
38    initialized: Arc<RwLock<bool>>,
39}
40
41impl Server {
42    /// Create a new MCP server
43    pub fn new(transport: Arc<dyn Transport>, handler: Arc<dyn ServerHandler>) -> Self {
44        Self {
45            transport,
46            handler,
47            initialized: Arc::new(RwLock::new(false)),
48        }
49    }
50
51    /// Start the server
52    pub async fn start(&self) -> Result<(), Error> {
53        let mut stream = self.transport.receive();
54
55        while let Some(message) = stream.next().await {
56            match message? {
57                Message::Request(request) => {
58                    let response = match self.handle_request(request.clone()).await {
59                        Ok(response) => response,
60                        Err(err) => Response::error(request.id, ResponseError::from(err)),
61                    };
62                    self.transport.send(Message::Response(response)).await?;
63                }
64                Message::Notification(notification) => {
65                    match notification.method.as_str() {
66                        "exit" => break,
67                        "initialized" => {
68                            *self.initialized.write().await = true;
69                        }
70                        _ => {
71                            // Handle other notifications
72                        }
73                    }
74                }
75                Message::Response(_) => {
76                    // Server shouldn't receive responses
77                    return Err(Error::protocol(
78                        ErrorCode::InvalidRequest,
79                        "Server received unexpected response",
80                    ));
81                }
82            }
83        }
84
85        Ok(())
86    }
87
88    async fn handle_request(&self, request: Request) -> Result<Response, Error> {
89        let initialized = *self.initialized.read().await;
90
91        match request.method.as_str() {
92            "initialize" => {
93                if initialized {
94                    return Err(Error::protocol(
95                        ErrorCode::InvalidRequest,
96                        "Server already initialized",
97                    ));
98                }
99
100                let params: serde_json::Value = request.params.unwrap_or(serde_json::json!({}));
101                let implementation: Implementation = serde_json::from_value(
102                    params.get("implementation").cloned().unwrap_or_default(),
103                )?;
104                let capabilities: ClientCapabilities = serde_json::from_value(
105                    params.get("capabilities").cloned().unwrap_or_default(),
106                )?;
107
108                let result = self
109                    .handler
110                    .initialize(implementation, capabilities)
111                    .await?;
112                Ok(Response::success(
113                    request.id,
114                    Some(serde_json::to_value(result)?),
115                ))
116            }
117            "shutdown" => {
118                if !initialized {
119                    return Err(Error::protocol(
120                        ErrorCode::ServerNotInitialized,
121                        "Server not initialized",
122                    ));
123                }
124
125                self.handler.shutdown().await?;
126                Ok(Response::success(request.id, None))
127            }
128            _ => {
129                if !initialized {
130                    return Err(Error::protocol(
131                        ErrorCode::ServerNotInitialized,
132                        "Server not initialized",
133                    ));
134                }
135
136                let result = self
137                    .handler
138                    .handle_method(&request.method, request.params)
139                    .await?;
140                Ok(Response::success(request.id, Some(result)))
141            }
142        }
143    }
144}