1use async_trait::async_trait;
2use futures::StreamExt;
3use std::sync::Arc;
4use tokio::sync::RwLock;
5
6use crate::{
7 error::{Error, ErrorCode},
8 protocol::{Request, Response, ResponseError},
9 transport::{Message, Transport},
10 types::{ClientCapabilities, Implementation, ServerCapabilities},
11};
12
13#[async_trait]
15pub trait ServerHandler: Send + Sync {
16 async fn initialize(
18 &self,
19 implementation: Implementation,
20 capabilities: ClientCapabilities,
21 ) -> Result<ServerCapabilities, Error>;
22
23 async fn shutdown(&self) -> Result<(), Error>;
25
26 async fn handle_method(
28 &self,
29 method: &str,
30 params: Option<serde_json::Value>,
31 ) -> Result<serde_json::Value, Error>;
32}
33
34pub struct Server {
36 transport: Arc<dyn Transport>,
37 handler: Arc<dyn ServerHandler>,
38 initialized: Arc<RwLock<bool>>,
39}
40
41impl Server {
42 pub fn new(transport: Arc<dyn Transport>, handler: Arc<dyn ServerHandler>) -> Self {
44 Self {
45 transport,
46 handler,
47 initialized: Arc::new(RwLock::new(false)),
48 }
49 }
50
51 pub async fn start(&self) -> Result<(), Error> {
53 let mut stream = self.transport.receive();
54
55 while let Some(message) = stream.next().await {
56 match message? {
57 Message::Request(request) => {
58 let response = match self.handle_request(request.clone()).await {
59 Ok(response) => response,
60 Err(err) => Response::error(request.id, ResponseError::from(err)),
61 };
62 self.transport.send(Message::Response(response)).await?;
63 }
64 Message::Notification(notification) => {
65 match notification.method.as_str() {
66 "exit" => break,
67 "initialized" => {
68 *self.initialized.write().await = true;
69 }
70 _ => {
71 }
73 }
74 }
75 Message::Response(_) => {
76 return Err(Error::protocol(
78 ErrorCode::InvalidRequest,
79 "Server received unexpected response",
80 ));
81 }
82 }
83 }
84
85 Ok(())
86 }
87
88 async fn handle_request(&self, request: Request) -> Result<Response, Error> {
89 let initialized = *self.initialized.read().await;
90
91 match request.method.as_str() {
92 "initialize" => {
93 if initialized {
94 return Err(Error::protocol(
95 ErrorCode::InvalidRequest,
96 "Server already initialized",
97 ));
98 }
99
100 let params: serde_json::Value = request.params.unwrap_or(serde_json::json!({}));
101 let implementation: Implementation = serde_json::from_value(
102 params.get("implementation").cloned().unwrap_or_default(),
103 )?;
104 let capabilities: ClientCapabilities = serde_json::from_value(
105 params.get("capabilities").cloned().unwrap_or_default(),
106 )?;
107
108 let result = self
109 .handler
110 .initialize(implementation, capabilities)
111 .await?;
112 Ok(Response::success(
113 request.id,
114 Some(serde_json::to_value(result)?),
115 ))
116 }
117 "shutdown" => {
118 if !initialized {
119 return Err(Error::protocol(
120 ErrorCode::ServerNotInitialized,
121 "Server not initialized",
122 ));
123 }
124
125 self.handler.shutdown().await?;
126 Ok(Response::success(request.id, None))
127 }
128 _ => {
129 if !initialized {
130 return Err(Error::protocol(
131 ErrorCode::ServerNotInitialized,
132 "Server not initialized",
133 ));
134 }
135
136 let result = self
137 .handler
138 .handle_method(&request.method, request.params)
139 .await?;
140 Ok(Response::success(request.id, Some(result)))
141 }
142 }
143 }
144}