1use std::borrow::Cow;
2
3use thiserror::Error;
4
5use super::*;
6#[cfg(feature = "elicitation")]
7use crate::model::{
8 CreateElicitationRequest, CreateElicitationRequestParams, CreateElicitationResult,
9};
10use crate::{
11 model::{
12 CancelledNotification, CancelledNotificationParam, ClientInfo, ClientJsonRpcMessage,
13 ClientNotification, ClientRequest, ClientResult, CreateMessageRequest,
14 CreateMessageRequestParams, CreateMessageResult, ErrorData, ListRootsRequest,
15 ListRootsResult, LoggingMessageNotification, LoggingMessageNotificationParam,
16 ProgressNotification, ProgressNotificationParam, PromptListChangedNotification,
17 ProtocolVersion, ResourceListChangedNotification, ResourceUpdatedNotification,
18 ResourceUpdatedNotificationParam, ServerInfo, ServerNotification, ServerRequest,
19 ServerResult, ToolListChangedNotification,
20 },
21 transport::DynamicTransportError,
22};
23
/// Zero-sized marker selecting the server end of an MCP connection.
///
/// It holds no data; it only parameterises the generic service machinery
/// (through its `ServiceRole` implementation) with server-side types.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct RoleServer;
26
27impl ServiceRole for RoleServer {
28 type Req = ServerRequest;
29 type Resp = ServerResult;
30 type Not = ServerNotification;
31 type PeerReq = ClientRequest;
32 type PeerResp = ClientResult;
33 type PeerNot = ClientNotification;
34 type Info = ServerInfo;
35 type PeerInfo = ClientInfo;
36
37 type InitializeError = ServerInitializeError;
38 const IS_CLIENT: bool = false;
39}
40
41#[derive(Error, Debug)]
45pub enum ServerInitializeError {
46 #[error("expect initialized request, but received: {0:?}")]
47 ExpectedInitializeRequest(Option<ClientJsonRpcMessage>),
48
49 #[error("expect initialized notification, but received: {0:?}")]
50 ExpectedInitializedNotification(Option<ClientJsonRpcMessage>),
51
52 #[error("connection closed: {0}")]
53 ConnectionClosed(String),
54
55 #[error("unexpected initialize result: {0:?}")]
56 UnexpectedInitializeResponse(ServerResult),
57
58 #[error("initialize failed: {0}")]
59 InitializeFailed(ErrorData),
60
61 #[error("unsupported protocol version: {0}")]
62 UnsupportedProtocolVersion(ProtocolVersion),
63
64 #[error("Send message error {error}, when {context}")]
65 TransportError {
66 error: DynamicTransportError,
67 context: Cow<'static, str>,
68 },
69
70 #[error("Cancelled")]
71 Cancelled,
72}
73
74impl ServerInitializeError {
75 pub fn transport<T: Transport<RoleServer> + 'static>(
76 error: T::Error,
77 context: impl Into<Cow<'static, str>>,
78 ) -> Self {
79 Self::TransportError {
80 error: DynamicTransportError::new::<T, _>(error),
81 context: context.into(),
82 }
83 }
84}
/// Server-side peer handle: the `Peer` through which a server sends requests
/// and notifications to its connected client.
pub type ClientSink = Peer<RoleServer>;
86
87impl<S: Service<RoleServer>> ServiceExt<RoleServer> for S {
88 fn serve_with_ct<T, E, A>(
89 self,
90 transport: T,
91 ct: CancellationToken,
92 ) -> impl Future<Output = Result<RunningService<RoleServer, Self>, ServerInitializeError>> + Send
93 where
94 T: IntoTransport<RoleServer, E, A>,
95 E: std::error::Error + Send + Sync + 'static,
96 Self: Sized,
97 {
98 serve_server_with_ct(self, transport, ct)
99 }
100}
101
102pub async fn serve_server<S, T, E, A>(
103 service: S,
104 transport: T,
105) -> Result<RunningService<RoleServer, S>, ServerInitializeError>
106where
107 S: Service<RoleServer>,
108 T: IntoTransport<RoleServer, E, A>,
109 E: std::error::Error + Send + Sync + 'static,
110{
111 serve_server_with_ct(service, transport, CancellationToken::new()).await
112}
113
114async fn expect_next_message<T>(
116 transport: &mut T,
117 context: &str,
118) -> Result<ClientJsonRpcMessage, ServerInitializeError>
119where
120 T: Transport<RoleServer>,
121{
122 transport
123 .receive()
124 .await
125 .ok_or_else(|| ServerInitializeError::ConnectionClosed(context.to_string()))
126}
127
128async fn expect_request<T>(
130 transport: &mut T,
131 context: &str,
132) -> Result<(ClientRequest, RequestId), ServerInitializeError>
133where
134 T: Transport<RoleServer>,
135{
136 let msg = expect_next_message(transport, context).await?;
137 let msg_clone = msg.clone();
138 msg.into_request()
139 .ok_or(ServerInitializeError::ExpectedInitializeRequest(Some(
140 msg_clone,
141 )))
142}
143
144async fn expect_notification<T>(
146 transport: &mut T,
147 context: &str,
148) -> Result<ClientNotification, ServerInitializeError>
149where
150 T: Transport<RoleServer>,
151{
152 let msg = expect_next_message(transport, context).await?;
153 let msg_clone = msg.clone();
154 msg.into_notification()
155 .ok_or(ServerInitializeError::ExpectedInitializedNotification(
156 Some(msg_clone),
157 ))
158}
159
160pub async fn serve_server_with_ct<S, T, E, A>(
161 service: S,
162 transport: T,
163 ct: CancellationToken,
164) -> Result<RunningService<RoleServer, S>, ServerInitializeError>
165where
166 S: Service<RoleServer>,
167 T: IntoTransport<RoleServer, E, A>,
168 E: std::error::Error + Send + Sync + 'static,
169{
170 tokio::select! {
171 result = serve_server_with_ct_inner(service, transport.into_transport(), ct.clone()) => { result }
172 _ = ct.cancelled() => {
173 Err(ServerInitializeError::Cancelled)
174 }
175 }
176}
177
178async fn serve_server_with_ct_inner<S, T>(
179 service: S,
180 transport: T,
181 ct: CancellationToken,
182) -> Result<RunningService<RoleServer, S>, ServerInitializeError>
183where
184 S: Service<RoleServer>,
185 T: Transport<RoleServer> + 'static,
186{
187 let mut transport = transport.into_transport();
188 let id_provider = <Arc<AtomicU32RequestIdProvider>>::default();
189
190 let (request, id) = expect_request(&mut transport, "initialized request").await?;
192
193 let ClientRequest::InitializeRequest(peer_info) = &request else {
194 return Err(ServerInitializeError::ExpectedInitializeRequest(Some(
195 ClientJsonRpcMessage::request(request, id),
196 )));
197 };
198 let (peer, peer_rx) = Peer::new(id_provider, Some(peer_info.params.clone()));
199 let context = RequestContext {
200 ct: ct.child_token(),
201 id: id.clone(),
202 meta: request.get_meta().clone(),
203 extensions: request.extensions().clone(),
204 peer: peer.clone(),
205 };
206 let init_response = service.handle_request(request.clone(), context).await;
208 let mut init_response = match init_response {
209 Ok(ServerResult::InitializeResult(init_response)) => init_response,
210 Ok(result) => {
211 return Err(ServerInitializeError::UnexpectedInitializeResponse(result));
212 }
213 Err(e) => {
214 transport
215 .send(ServerJsonRpcMessage::error(e.clone(), id))
216 .await
217 .map_err(|error| {
218 ServerInitializeError::transport::<T>(error, "sending error response")
219 })?;
220 return Err(ServerInitializeError::InitializeFailed(e));
221 }
222 };
223 let peer_protocol_version = peer_info.params.protocol_version.clone();
224 let protocol_version = match peer_protocol_version
225 .partial_cmp(&init_response.protocol_version)
226 .ok_or(ServerInitializeError::UnsupportedProtocolVersion(
227 peer_protocol_version,
228 ))? {
229 std::cmp::Ordering::Less => peer_info.params.protocol_version.clone(),
230 _ => init_response.protocol_version,
231 };
232 init_response.protocol_version = protocol_version;
233 transport
234 .send(ServerJsonRpcMessage::response(
235 ServerResult::InitializeResult(init_response),
236 id,
237 ))
238 .await
239 .map_err(|error| {
240 ServerInitializeError::transport::<T>(error, "sending initialize response")
241 })?;
242
243 let notification = expect_notification(&mut transport, "initialize notification").await?;
245 let ClientNotification::InitializedNotification(_) = notification else {
246 return Err(ServerInitializeError::ExpectedInitializedNotification(
247 Some(ClientJsonRpcMessage::notification(notification)),
248 ));
249 };
250 let context = NotificationContext {
251 meta: notification.get_meta().clone(),
252 extensions: notification.extensions().clone(),
253 peer: peer.clone(),
254 };
255 let _ = service.handle_notification(notification, context).await;
256 Ok(serve_inner(service, transport, peer, peer_rx, ct))
258}
259
/// Generates typed request/notification helpers for `Peer<RoleServer>`.
///
/// Supported forms:
/// - `peer_req m Req() => Resp`: parameterless request; the reply must be
///   `ClientResult::Resp`.
/// - `peer_req m Req(Param) => Resp`: request with params; the reply must be
///   `ClientResult::Resp`.
/// - `peer_req m Req(Param)`: request with params; the reply must be
///   `ClientResult::EmptyResult`.
/// - `peer_not m Not(Param)` and `peer_not m Not`: notifications.
/// - `peer_req_with_timeout m Req() => Resp` and
///   `peer_req_with_timeout m Req(Param) => Resp`: like `peer_req`, plus an
///   optional per-request timeout.
///
/// Any other reply variant becomes `ServiceError::UnexpectedResponse`.
macro_rules! method {
    // Parameterless request returning a typed result.
    (peer_req $method:ident $Req:ident() => $Resp:ident) => {
        #[doc = concat!(
            "Sends a `", stringify!($Req), "` to the client and returns its `",
            stringify!($Resp), "`."
        )]
        pub async fn $method(&self) -> Result<$Resp, ServiceError> {
            let result = self
                .send_request(ServerRequest::$Req($Req {
                    method: Default::default(),
                    extensions: Default::default(),
                }))
                .await?;
            match result {
                ClientResult::$Resp(result) => Ok(result),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };
    // Request with params returning a typed result.
    (peer_req $method:ident $Req:ident($Param:ident) => $Resp:ident) => {
        #[doc = concat!(
            "Sends a `", stringify!($Req), "` to the client and returns its `",
            stringify!($Resp), "`."
        )]
        pub async fn $method(&self, params: $Param) -> Result<$Resp, ServiceError> {
            let result = self
                .send_request(ServerRequest::$Req($Req {
                    method: Default::default(),
                    params,
                    extensions: Default::default(),
                }))
                .await?;
            match result {
                ClientResult::$Resp(result) => Ok(result),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };
    // Request with params whose reply must be an empty result.
    (peer_req $method:ident $Req:ident($Param:ident)) => {
        #[doc = concat!(
            "Sends a `", stringify!($Req),
            "` to the client; succeeds only if the client replies with an empty result."
        )]
        pub fn $method(
            &self,
            params: $Param,
        ) -> impl Future<Output = Result<(), ServiceError>> + Send + '_ {
            async move {
                let result = self
                    .send_request(ServerRequest::$Req($Req {
                        method: Default::default(),
                        params,
                        // Every other arm sets `extensions`; omitting it here
                        // fails to compile as soon as this arm is used.
                        extensions: Default::default(),
                    }))
                    .await?;
                match result {
                    ClientResult::EmptyResult(_) => Ok(()),
                    _ => Err(ServiceError::UnexpectedResponse),
                }
            }
        }
    };

    // Notification with params.
    (peer_not $method:ident $Not:ident($Param:ident)) => {
        #[doc = concat!("Sends a `", stringify!($Not), "` to the client.")]
        pub async fn $method(&self, params: $Param) -> Result<(), ServiceError> {
            self.send_notification(ServerNotification::$Not($Not {
                method: Default::default(),
                params,
                extensions: Default::default(),
            }))
            .await?;
            Ok(())
        }
    };
    // Parameterless notification.
    (peer_not $method:ident $Not:ident) => {
        #[doc = concat!("Sends a `", stringify!($Not), "` to the client.")]
        pub async fn $method(&self) -> Result<(), ServiceError> {
            self.send_notification(ServerNotification::$Not($Not {
                method: Default::default(),
                extensions: Default::default(),
            }))
            .await?;
            Ok(())
        }
    };

    // Parameterless request with an optional timeout.
    (peer_req_with_timeout $method_with_timeout:ident $Req:ident() => $Resp:ident) => {
        #[doc = concat!(
            "Sends a `", stringify!($Req), "` to the client and returns its `",
            stringify!($Resp), "`, giving up after `timeout` if one is set."
        )]
        pub async fn $method_with_timeout(
            &self,
            timeout: Option<std::time::Duration>,
        ) -> Result<$Resp, ServiceError> {
            let request = ServerRequest::$Req($Req {
                method: Default::default(),
                extensions: Default::default(),
            });
            let options = crate::service::PeerRequestOptions {
                timeout,
                meta: None,
            };
            let result = self
                .send_request_with_option(request, options)
                .await?
                .await_response()
                .await?;
            match result {
                ClientResult::$Resp(result) => Ok(result),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };

    // Request with params and an optional timeout.
    (peer_req_with_timeout $method_with_timeout:ident $Req:ident($Param:ident) => $Resp:ident) => {
        #[doc = concat!(
            "Sends a `", stringify!($Req), "` to the client and returns its `",
            stringify!($Resp), "`, giving up after `timeout` if one is set."
        )]
        pub async fn $method_with_timeout(
            &self,
            params: $Param,
            timeout: Option<std::time::Duration>,
        ) -> Result<$Resp, ServiceError> {
            let request = ServerRequest::$Req($Req {
                method: Default::default(),
                params,
                extensions: Default::default(),
            });
            let options = crate::service::PeerRequestOptions {
                timeout,
                meta: None,
            };
            let result = self
                .send_request_with_option(request, options)
                .await?
                .await_response()
                .await?;
            match result {
                ClientResult::$Resp(result) => Ok(result),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };
}
385
386impl Peer<RoleServer> {
387 pub async fn create_message(
388 &self,
389 params: CreateMessageRequestParams,
390 ) -> Result<CreateMessageResult, ServiceError> {
391 let result = self
392 .send_request(ServerRequest::CreateMessageRequest(CreateMessageRequest {
393 method: Default::default(),
394 params,
395 extensions: Default::default(),
396 }))
397 .await?;
398 match result {
399 ClientResult::CreateMessageResult(result) => Ok(*result),
400 _ => Err(ServiceError::UnexpectedResponse),
401 }
402 }
403 method!(peer_req list_roots ListRootsRequest() => ListRootsResult);
404 #[cfg(feature = "elicitation")]
405 method!(peer_req create_elicitation CreateElicitationRequest(CreateElicitationRequestParams) => CreateElicitationResult);
406 #[cfg(feature = "elicitation")]
407 method!(peer_req_with_timeout create_elicitation_with_timeout CreateElicitationRequest(CreateElicitationRequestParams) => CreateElicitationResult);
408
409 method!(peer_not notify_cancelled CancelledNotification(CancelledNotificationParam));
410 method!(peer_not notify_progress ProgressNotification(ProgressNotificationParam));
411 method!(peer_not notify_logging_message LoggingMessageNotification(LoggingMessageNotificationParam));
412 method!(peer_not notify_resource_updated ResourceUpdatedNotification(ResourceUpdatedNotificationParam));
413 method!(peer_not notify_resource_list_changed ResourceListChangedNotification);
414 method!(peer_not notify_tool_list_changed ToolListChangedNotification);
415 method!(peer_not notify_prompt_list_changed PromptListChangedNotification);
416}
417
/// Errors returned by the typed elicitation helpers on `Peer<RoleServer>`.
#[cfg(feature = "elicitation")]
#[derive(Debug, Error)]
pub enum ElicitationError {
    /// The underlying request failed, or the schema for the requested type
    /// could not be built.
    #[error("Service error: {0}")]
    Service(#[from] ServiceError),

    /// The client reported that the user declined the request.
    #[error("User explicitly declined the request")]
    UserDeclined,

    /// The client reported that the user cancelled or dismissed the request.
    #[error("User cancelled/dismissed the request")]
    UserCancelled,

    /// The returned content could not be deserialized into the requested type.
    #[error("Failed to parse response data: {error}\nReceived data: {data}")]
    ParseError {
        /// Deserialization error reported by serde_json.
        error: serde_json::Error,
        /// The raw JSON content that failed to parse.
        data: serde_json::Value,
    },

    /// The user accepted, but the response carried no content.
    #[error("No response content provided")]
    NoContent,

    /// The client did not declare the elicitation capability.
    #[error("Client does not support elicitation - capability not declared during initialization")]
    CapabilityNotSupported,
}
458
/// Marker trait for types that may be requested from the user via
/// elicitation.
///
/// Implement it with the `elicit_safe!` macro. It requires
/// `schemars::JsonSchema` because the elicitation helpers build the requested
/// schema from the type.
///
/// NOTE(review): this item is gated only on `elicitation` but names
/// `schemars`; this assumes the `elicitation` feature enables `schemars`.
/// Confirm in Cargo.toml.
#[cfg(feature = "elicitation")]
pub trait ElicitationSafe
where
    Self: schemars::JsonSchema,
{
}
478
/// Implements `ElicitationSafe` for every listed type.
///
/// Takes a comma-separated list of types; a trailing comma is allowed:
///
/// ```ignore
/// elicit_safe!(UserProfile, Settings,);
/// ```
#[cfg(feature = "elicitation")]
#[macro_export]
macro_rules! elicit_safe {
    ($($target:ty),* $(,)?) => {
        $(impl $crate::service::ElicitationSafe for $target {})*
    };
}
511
#[cfg(feature = "elicitation")]
impl Peer<RoleServer> {
    /// Returns `true` if the client declared the elicitation capability
    /// during initialization.
    ///
    /// Returns `false` when no client info is available.
    pub fn supports_elicitation(&self) -> bool {
        self.peer_info()
            .is_some_and(|client_info| client_info.capabilities.elicitation.is_some())
    }

    /// Asks the user, through the client, for a value of type `T`, with no
    /// timeout.
    ///
    /// Same as `elicit_with_timeout(message, None)`. On success the value is
    /// always `Some`; a decline or cancel is reported as an error.
    ///
    /// # Errors
    /// See `elicit_with_timeout`.
    #[cfg(all(feature = "schemars", feature = "elicitation"))]
    pub async fn elicit<T>(&self, message: impl Into<String>) -> Result<Option<T>, ElicitationError>
    where
        T: ElicitationSafe + for<'de> serde::Deserialize<'de>,
    {
        self.elicit_with_timeout(message, None).await
    }

    /// Asks the user, through the client, for a value of type `T`.
    ///
    /// The requested schema is derived from `T`, and accepted content is
    /// deserialized back into `T`. `timeout`, if set, is forwarded as the
    /// request timeout. On success the value is always `Some`.
    ///
    /// # Errors
    /// - `ElicitationError::CapabilityNotSupported` if the client did not
    ///   declare elicitation support (no request is sent).
    /// - `ElicitationError::Service` if the schema cannot be built from `T`,
    ///   or if the request itself fails.
    /// - `ElicitationError::UserDeclined` or `ElicitationError::UserCancelled`
    ///   for those user actions.
    /// - `ElicitationError::NoContent` if the user accepted but no content was
    ///   returned.
    /// - `ElicitationError::ParseError` if the content does not deserialize
    ///   into `T`.
    #[cfg(all(feature = "schemars", feature = "elicitation"))]
    pub async fn elicit_with_timeout<T>(
        &self,
        message: impl Into<String>,
        timeout: Option<std::time::Duration>,
    ) -> Result<Option<T>, ElicitationError>
    where
        T: ElicitationSafe + for<'de> serde::Deserialize<'de>,
    {
        if !self.supports_elicitation() {
            return Err(ElicitationError::CapabilityNotSupported);
        }

        let schema = crate::model::ElicitationSchema::from_type::<T>().map_err(|e| {
            ElicitationError::Service(ServiceError::McpError(crate::ErrorData::invalid_params(
                format!(
                    "Invalid schema for type {}: {}",
                    std::any::type_name::<T>(),
                    e
                ),
                None,
            )))
        })?;

        let response = self
            .create_elicitation_with_timeout(
                CreateElicitationRequestParams {
                    meta: None,
                    message: message.into(),
                    requested_schema: schema,
                },
                timeout,
            )
            .await?;

        match response.action {
            crate::model::ElicitationAction::Accept => {
                let Some(value) = response.content else {
                    return Err(ElicitationError::NoContent);
                };
                // Deserialize from a borrow (`&Value` is a serde Deserializer).
                // This keeps `value` for the error report without cloning the
                // whole JSON tree on every call.
                match T::deserialize(&value) {
                    Ok(parsed) => Ok(Some(parsed)),
                    Err(error) => Err(ElicitationError::ParseError { error, data: value }),
                }
            }
            crate::model::ElicitationAction::Decline => Err(ElicitationError::UserDeclined),
            crate::model::ElicitationAction::Cancel => Err(ElicitationError::UserCancelled),
        }
    }
}