1use std::borrow::Cow;
2
3use thiserror::Error;
4
5use super::*;
6use crate::model::{
7 CancelledNotification, CancelledNotificationParam, ClientInfo, ClientJsonRpcMessage,
8 ClientNotification, ClientRequest, ClientResult, CreateMessageRequest,
9 CreateMessageRequestParam, CreateMessageResult, ErrorData, ListRootsRequest, ListRootsResult,
10 LoggingMessageNotification, LoggingMessageNotificationParam, ProgressNotification,
11 ProgressNotificationParam, PromptListChangedNotification, ProtocolVersion,
12 ResourceListChangedNotification, ResourceUpdatedNotification, ResourceUpdatedNotificationParam,
13 ServerInfo, ServerNotification, ServerRequest, ServerResult, ToolListChangedNotification,
14};
15
/// Zero-sized marker selecting the server side of the protocol.
///
/// It carries no data; generic service code is specialised for a server by
/// using this type as its `ServiceRole`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct RoleServer;
18
19impl ServiceRole for RoleServer {
20 type Req = ServerRequest;
21 type Resp = ServerResult;
22 type Not = ServerNotification;
23 type PeerReq = ClientRequest;
24 type PeerResp = ClientResult;
25 type PeerNot = ClientNotification;
26 type Info = ServerInfo;
27 type PeerInfo = ClientInfo;
28
29 type InitializeError<E> = ServerInitializeError<E>;
30 const IS_CLIENT: bool = false;
31}
32
33#[derive(Error, Debug)]
37pub enum ServerInitializeError<E> {
38 #[error("expect initialized request, but received: {0:?}")]
39 ExpectedInitializeRequest(Option<ClientJsonRpcMessage>),
40
41 #[error("expect initialized notification, but received: {0:?}")]
42 ExpectedInitializedNotification(Option<ClientJsonRpcMessage>),
43
44 #[error("connection closed: {0}")]
45 ConnectionClosed(String),
46
47 #[error("unexpected initialize result: {0:?}")]
48 UnexpectedInitializeResponse(ServerResult),
49
50 #[error("initialize failed: {0}")]
51 InitializeFailed(ErrorData),
52
53 #[error("unsupported protocol version: {0}")]
54 UnsupportedProtocolVersion(ProtocolVersion),
55
56 #[error("Send message error {error}, when {context}")]
57 TransportError {
58 error: E,
59 context: Cow<'static, str>,
60 },
61
62 #[error("Cancelled")]
63 Cancelled,
64}
65
66pub type ClientSink = Peer<RoleServer>;
67
68impl<S: Service<RoleServer>> ServiceExt<RoleServer> for S {
69 fn serve_with_ct<T, E, A>(
70 self,
71 transport: T,
72 ct: CancellationToken,
73 ) -> impl Future<Output = Result<RunningService<RoleServer, Self>, ServerInitializeError<E>>> + Send
74 where
75 T: IntoTransport<RoleServer, E, A>,
76 E: std::error::Error + Send + Sync + 'static,
77 Self: Sized,
78 {
79 serve_server_with_ct(self, transport, ct)
80 }
81}
82
83pub async fn serve_server<S, T, E, A>(
84 service: S,
85 transport: T,
86) -> Result<RunningService<RoleServer, S>, ServerInitializeError<E>>
87where
88 S: Service<RoleServer>,
89 T: IntoTransport<RoleServer, E, A>,
90 E: std::error::Error + Send + Sync + 'static,
91{
92 serve_server_with_ct(service, transport, CancellationToken::new()).await
93}
94
95async fn expect_next_message<T, E>(
97 transport: &mut T,
98 context: &str,
99) -> Result<ClientJsonRpcMessage, ServerInitializeError<E>>
100where
101 T: Transport<RoleServer>,
102{
103 transport
104 .receive()
105 .await
106 .ok_or_else(|| ServerInitializeError::ConnectionClosed(context.to_string()))
107}
108
109async fn expect_request<T, E>(
111 transport: &mut T,
112 context: &str,
113) -> Result<(ClientRequest, RequestId), ServerInitializeError<E>>
114where
115 T: Transport<RoleServer>,
116{
117 let msg = expect_next_message(transport, context).await?;
118 let msg_clone = msg.clone();
119 msg.into_request()
120 .ok_or(ServerInitializeError::ExpectedInitializeRequest(Some(
121 msg_clone,
122 )))
123}
124
125async fn expect_notification<T, E>(
127 transport: &mut T,
128 context: &str,
129) -> Result<ClientNotification, ServerInitializeError<E>>
130where
131 T: Transport<RoleServer>,
132{
133 let msg = expect_next_message(transport, context).await?;
134 let msg_clone = msg.clone();
135 msg.into_notification()
136 .ok_or(ServerInitializeError::ExpectedInitializedNotification(
137 Some(msg_clone),
138 ))
139}
140
141pub async fn serve_server_with_ct<S, T, E, A>(
142 service: S,
143 transport: T,
144 ct: CancellationToken,
145) -> Result<RunningService<RoleServer, S>, ServerInitializeError<E>>
146where
147 S: Service<RoleServer>,
148 T: IntoTransport<RoleServer, E, A>,
149 E: std::error::Error + Send + Sync + 'static,
150{
151 tokio::select! {
152 result = serve_server_with_ct_inner(service, transport, ct.clone()) => { result }
153 _ = ct.cancelled() => {
154 Err(ServerInitializeError::Cancelled)
155 }
156 }
157}
158
159async fn serve_server_with_ct_inner<S, T, E, A>(
160 service: S,
161 transport: T,
162 ct: CancellationToken,
163) -> Result<RunningService<RoleServer, S>, ServerInitializeError<E>>
164where
165 S: Service<RoleServer>,
166 T: IntoTransport<RoleServer, E, A>,
167 E: std::error::Error + Send + Sync + 'static,
168{
169 let mut transport = transport.into_transport();
170 let id_provider = <Arc<AtomicU32RequestIdProvider>>::default();
171
172 let (request, id) = expect_request(&mut transport, "initialized request").await?;
174
175 let ClientRequest::InitializeRequest(peer_info) = &request else {
176 return Err(ServerInitializeError::ExpectedInitializeRequest(Some(
177 ClientJsonRpcMessage::request(request, id),
178 )));
179 };
180 let (peer, peer_rx) = Peer::new(id_provider, Some(peer_info.params.clone()));
181 let context = RequestContext {
182 ct: ct.child_token(),
183 id: id.clone(),
184 meta: request.get_meta().clone(),
185 extensions: request.extensions().clone(),
186 peer: peer.clone(),
187 };
188 let init_response = service.handle_request(request.clone(), context).await;
190 let mut init_response = match init_response {
191 Ok(ServerResult::InitializeResult(init_response)) => init_response,
192 Ok(result) => {
193 return Err(ServerInitializeError::UnexpectedInitializeResponse(result));
194 }
195 Err(e) => {
196 transport
197 .send(ServerJsonRpcMessage::error(e.clone(), id))
198 .await
199 .map_err(|error| ServerInitializeError::TransportError {
200 error,
201 context: "sending error response".into(),
202 })?;
203 return Err(ServerInitializeError::InitializeFailed(e));
204 }
205 };
206 let peer_protocol_version = peer_info.params.protocol_version.clone();
207 let protocol_version = match peer_protocol_version
208 .partial_cmp(&init_response.protocol_version)
209 .ok_or(ServerInitializeError::UnsupportedProtocolVersion(
210 peer_protocol_version,
211 ))? {
212 std::cmp::Ordering::Less => peer_info.params.protocol_version.clone(),
213 _ => init_response.protocol_version,
214 };
215 init_response.protocol_version = protocol_version;
216 transport
217 .send(ServerJsonRpcMessage::response(
218 ServerResult::InitializeResult(init_response),
219 id,
220 ))
221 .await
222 .map_err(|error| ServerInitializeError::TransportError {
223 error,
224 context: "sending initialize response".into(),
225 })?;
226
227 let notification = expect_notification(&mut transport, "initialize notification").await?;
229 let ClientNotification::InitializedNotification(_) = notification else {
230 return Err(ServerInitializeError::ExpectedInitializedNotification(
231 Some(ClientJsonRpcMessage::notification(notification)),
232 ));
233 };
234 let context = NotificationContext {
235 meta: notification.get_meta().clone(),
236 extensions: notification.extensions().clone(),
237 peer: peer.clone(),
238 };
239 let _ = service.handle_notification(notification, context).await;
240 Ok(serve_inner(service, transport, peer, peer_rx, ct))
242}
243
/// Generates typed convenience methods for a server's client handle.
///
/// Arms:
/// - `peer_req name Req() => Resp`: request without params; returns the
///   `ClientResult::Resp` payload.
/// - `peer_req name Req(Param) => Resp`: request with params; returns the
///   `ClientResult::Resp` payload.
/// - `peer_req name Req(Param)`: request with params whose reply must be
///   `ClientResult::EmptyResult`; returns `()`.
/// - `peer_not name Not(Param)` / `peer_not name Not`: notification with or
///   without params; returns `Ok(())` once `send_notification` succeeds.
///
/// For the request arms, any other reply variant yields
/// `ServiceError::UnexpectedResponse`.
macro_rules! method {
    (peer_req $method:ident $Req:ident() => $Resp: ident ) => {
        pub async fn $method(&self) -> Result<$Resp, ServiceError> {
            let result = self
                .send_request(ServerRequest::$Req($Req {
                    method: Default::default(),
                    extensions: Default::default(),
                }))
                .await?;
            match result {
                ClientResult::$Resp(result) => Ok(result),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };
    (peer_req $method:ident $Req:ident($Param: ident) => $Resp: ident ) => {
        pub async fn $method(&self, params: $Param) -> Result<$Resp, ServiceError> {
            let result = self
                .send_request(ServerRequest::$Req($Req {
                    method: Default::default(),
                    params,
                    extensions: Default::default(),
                }))
                .await?;
            match result {
                ClientResult::$Resp(result) => Ok(result),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };
    (peer_req $method:ident $Req:ident($Param: ident)) => {
        pub fn $method(
            &self,
            params: $Param,
        ) -> impl Future<Output = Result<(), ServiceError>> + Send + '_ {
            async move {
                let result = self
                    .send_request(ServerRequest::$Req($Req {
                        method: Default::default(),
                        params,
                        // Populated like in the other request arms; the struct
                        // literal must name every field of the request type.
                        extensions: Default::default(),
                    }))
                    .await?;
                match result {
                    ClientResult::EmptyResult(_) => Ok(()),
                    _ => Err(ServiceError::UnexpectedResponse),
                }
            }
        }
    };

    (peer_not $method:ident $Not:ident($Param: ident)) => {
        pub async fn $method(&self, params: $Param) -> Result<(), ServiceError> {
            self.send_notification(ServerNotification::$Not($Not {
                method: Default::default(),
                params,
                extensions: Default::default(),
            }))
            .await?;
            Ok(())
        }
    };
    (peer_not $method:ident $Not:ident) => {
        pub async fn $method(&self) -> Result<(), ServiceError> {
            self.send_notification(ServerNotification::$Not($Not {
                method: Default::default(),
                extensions: Default::default(),
            }))
            .await?;
            Ok(())
        }
    };
}
316
317impl Peer<RoleServer> {
318 method!(peer_req create_message CreateMessageRequest(CreateMessageRequestParam) => CreateMessageResult);
319 method!(peer_req list_roots ListRootsRequest() => ListRootsResult);
320
321 method!(peer_not notify_cancelled CancelledNotification(CancelledNotificationParam));
322 method!(peer_not notify_progress ProgressNotification(ProgressNotificationParam));
323 method!(peer_not notify_logging_message LoggingMessageNotification(LoggingMessageNotificationParam));
324 method!(peer_not notify_resource_updated ResourceUpdatedNotification(ResourceUpdatedNotificationParam));
325 method!(peer_not notify_resource_list_changed ResourceListChangedNotification);
326 method!(peer_not notify_tool_list_changed ToolListChangedNotification);
327 method!(peer_not notify_prompt_list_changed PromptListChangedNotification);
328}