1use std::borrow::Cow;
2#[cfg(feature = "elicitation")]
3use std::collections::HashSet;
4
5use thiserror::Error;
6#[cfg(feature = "elicitation")]
7use url::Url;
8
9use super::*;
10#[cfg(feature = "elicitation")]
11use crate::model::{
12 CreateElicitationRequest, CreateElicitationRequestParams, CreateElicitationResult,
13 ElicitationAction, ElicitationCompletionNotification, ElicitationResponseNotificationParam,
14};
15use crate::{
16 model::{
17 CancelledNotification, CancelledNotificationParam, ClientInfo, ClientJsonRpcMessage,
18 ClientNotification, ClientRequest, ClientResult, CreateMessageRequest,
19 CreateMessageRequestParams, CreateMessageResult, EmptyResult, ErrorData, ListRootsRequest,
20 ListRootsResult, LoggingMessageNotification, LoggingMessageNotificationParam,
21 ProgressNotification, ProgressNotificationParam, PromptListChangedNotification,
22 ProtocolVersion, ResourceListChangedNotification, ResourceUpdatedNotification,
23 ResourceUpdatedNotificationParam, ServerInfo, ServerNotification, ServerRequest,
24 ServerResult, ToolListChangedNotification,
25 },
26 transport::DynamicTransportError,
27};
28
/// Zero-sized marker for the server side of an MCP connection.
///
/// It is plugged into `ServiceRole`-generic machinery (for example `Peer<RoleServer>`, aliased
/// in this module as `ClientSink`) to select the server's message types.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct RoleServer;
31
32impl ServiceRole for RoleServer {
33 type Req = ServerRequest;
34 type Resp = ServerResult;
35 type Not = ServerNotification;
36 type PeerReq = ClientRequest;
37 type PeerResp = ClientResult;
38 type PeerNot = ClientNotification;
39 type Info = ServerInfo;
40 type PeerInfo = ClientInfo;
41
42 type InitializeError = ServerInitializeError;
43 const IS_CLIENT: bool = false;
44}
45
46#[derive(Error, Debug)]
50#[non_exhaustive]
51pub enum ServerInitializeError {
52 #[error("expect initialized request, but received: {0:?}")]
53 ExpectedInitializeRequest(Option<ClientJsonRpcMessage>),
54
55 #[error("expect initialized notification, but received: {0:?}")]
56 ExpectedInitializedNotification(Option<ClientJsonRpcMessage>),
57
58 #[error("connection closed: {0}")]
59 ConnectionClosed(String),
60
61 #[error("unexpected initialize result: {0:?}")]
62 UnexpectedInitializeResponse(ServerResult),
63
64 #[error("initialize failed: {0}")]
65 InitializeFailed(ErrorData),
66
67 #[error("unsupported protocol version: {0}")]
68 UnsupportedProtocolVersion(ProtocolVersion),
69
70 #[error("Send message error {error}, when {context}")]
71 TransportError {
72 error: DynamicTransportError,
73 context: Cow<'static, str>,
74 },
75
76 #[error("Cancelled")]
77 Cancelled,
78}
79
80impl ServerInitializeError {
81 pub fn transport<T: Transport<RoleServer> + 'static>(
82 error: T::Error,
83 context: impl Into<Cow<'static, str>>,
84 ) -> Self {
85 Self::TransportError {
86 error: DynamicTransportError::new::<T, _>(error),
87 context: context.into(),
88 }
89 }
90}
91pub type ClientSink = Peer<RoleServer>;
92
93impl<S: Service<RoleServer>> ServiceExt<RoleServer> for S {
94 fn serve_with_ct<T, E, A>(
95 self,
96 transport: T,
97 ct: CancellationToken,
98 ) -> impl Future<Output = Result<RunningService<RoleServer, Self>, ServerInitializeError>> + Send
99 where
100 T: IntoTransport<RoleServer, E, A>,
101 E: std::error::Error + Send + Sync + 'static,
102 Self: Sized,
103 {
104 serve_server_with_ct(self, transport, ct)
105 }
106}
107
108pub async fn serve_server<S, T, E, A>(
109 service: S,
110 transport: T,
111) -> Result<RunningService<RoleServer, S>, ServerInitializeError>
112where
113 S: Service<RoleServer>,
114 T: IntoTransport<RoleServer, E, A>,
115 E: std::error::Error + Send + Sync + 'static,
116{
117 serve_server_with_ct(service, transport, CancellationToken::new()).await
118}
119
120async fn expect_next_message<T>(
122 transport: &mut T,
123 context: &str,
124) -> Result<ClientJsonRpcMessage, ServerInitializeError>
125where
126 T: Transport<RoleServer>,
127{
128 transport
129 .receive()
130 .await
131 .ok_or_else(|| ServerInitializeError::ConnectionClosed(context.to_string()))
132}
133
134async fn expect_request<T>(
136 transport: &mut T,
137 context: &str,
138) -> Result<(ClientRequest, RequestId), ServerInitializeError>
139where
140 T: Transport<RoleServer>,
141{
142 let msg = expect_next_message(transport, context).await?;
143 let msg_clone = msg.clone();
144 msg.into_request()
145 .ok_or(ServerInitializeError::ExpectedInitializeRequest(Some(
146 msg_clone,
147 )))
148}
149
150pub async fn serve_server_with_ct<S, T, E, A>(
151 service: S,
152 transport: T,
153 ct: CancellationToken,
154) -> Result<RunningService<RoleServer, S>, ServerInitializeError>
155where
156 S: Service<RoleServer>,
157 T: IntoTransport<RoleServer, E, A>,
158 E: std::error::Error + Send + Sync + 'static,
159{
160 tokio::select! {
161 result = serve_server_with_ct_inner(service, transport.into_transport(), ct.clone()) => { result }
162 _ = ct.cancelled() => {
163 Err(ServerInitializeError::Cancelled)
164 }
165 }
166}
167
168async fn serve_server_with_ct_inner<S, T>(
169 service: S,
170 transport: T,
171 ct: CancellationToken,
172) -> Result<RunningService<RoleServer, S>, ServerInitializeError>
173where
174 S: Service<RoleServer>,
175 T: Transport<RoleServer> + 'static,
176{
177 let mut transport = transport.into_transport();
178 let id_provider = <Arc<AtomicU32RequestIdProvider>>::default();
179
180 let (request, id) = expect_request(&mut transport, "initialized request").await?;
182
183 let ClientRequest::InitializeRequest(peer_info) = &request else {
184 return Err(ServerInitializeError::ExpectedInitializeRequest(Some(
185 ClientJsonRpcMessage::request(request, id),
186 )));
187 };
188 let (peer, peer_rx) = Peer::new(id_provider, Some(peer_info.params.clone()));
189 let context = RequestContext {
190 ct: ct.child_token(),
191 id: id.clone(),
192 meta: request.get_meta().clone(),
193 extensions: request.extensions().clone(),
194 peer: peer.clone(),
195 };
196 let init_response = service.handle_request(request.clone(), context).await;
198 let mut init_response = match init_response {
199 Ok(ServerResult::InitializeResult(init_response)) => init_response,
200 Ok(result) => {
201 return Err(ServerInitializeError::UnexpectedInitializeResponse(result));
202 }
203 Err(e) => {
204 transport
205 .send(ServerJsonRpcMessage::error(e.clone(), id))
206 .await
207 .map_err(|error| {
208 ServerInitializeError::transport::<T>(error, "sending error response")
209 })?;
210 return Err(ServerInitializeError::InitializeFailed(e));
211 }
212 };
213 let peer_protocol_version = peer_info.params.protocol_version.clone();
214 let protocol_version = match peer_protocol_version
215 .partial_cmp(&init_response.protocol_version)
216 .ok_or(ServerInitializeError::UnsupportedProtocolVersion(
217 peer_protocol_version,
218 ))? {
219 std::cmp::Ordering::Less => peer_info.params.protocol_version.clone(),
220 _ => init_response.protocol_version,
221 };
222 init_response.protocol_version = protocol_version;
223 transport
224 .send(ServerJsonRpcMessage::response(
225 ServerResult::InitializeResult(init_response),
226 id,
227 ))
228 .await
229 .map_err(|error| {
230 ServerInitializeError::transport::<T>(error, "sending initialize response")
231 })?;
232
233 let notification = loop {
236 let msg = expect_next_message(&mut transport, "initialize notification").await?;
237 match msg {
238 ClientJsonRpcMessage::Notification(n)
239 if matches!(
240 n.notification,
241 ClientNotification::InitializedNotification(_)
242 ) =>
243 {
244 break n.notification;
245 }
246 ClientJsonRpcMessage::Request(req)
247 if matches!(
248 req.request,
249 ClientRequest::SetLevelRequest(_) | ClientRequest::PingRequest(_)
250 ) =>
251 {
252 transport
253 .send(ServerJsonRpcMessage::response(
254 ServerResult::EmptyResult(EmptyResult {}),
255 req.id,
256 ))
257 .await
258 .map_err(|error| {
259 ServerInitializeError::transport::<T>(error, "sending pre-init response")
260 })?;
261 }
262 other => {
263 return Err(ServerInitializeError::ExpectedInitializedNotification(
264 Some(other),
265 ));
266 }
267 }
268 };
269 let context = NotificationContext {
270 meta: notification.get_meta().clone(),
271 extensions: notification.extensions().clone(),
272 peer: peer.clone(),
273 };
274 let _ = service.handle_notification(notification, context).await;
275 Ok(serve_inner(service, transport, peer, peer_rx, ct))
277}
278
/// Generates typed request/notification helpers for `impl Peer<RoleServer>`.
///
/// Arms:
/// - `peer_req name Req() => Resp`: parameterless request, yields the `ClientResult::Resp` payload.
/// - `peer_req name Req(Param) => Resp`: request with params, yields `ClientResult::Resp`.
/// - `peer_req name Req(Param)`: request with params whose only accepted answer is
///   `ClientResult::EmptyResult`, yielding `()`.
/// - `peer_not name Not(Param)` / `peer_not name Not`: send a notification.
/// - `peer_req_with_timeout name Req() => Resp` / `peer_req_with_timeout name Req(Param) => Resp`:
///   like `peer_req`, but sent through `send_request_with_option` with an optional timeout.
///
/// Any client result other than the expected variant becomes `ServiceError::UnexpectedResponse`.
macro_rules! method {
    (peer_req $method:ident $Req:ident() => $Resp: ident ) => {
        pub async fn $method(&self) -> Result<$Resp, ServiceError> {
            let result = self
                .send_request(ServerRequest::$Req($Req {
                    method: Default::default(),
                    extensions: Default::default(),
                }))
                .await?;
            match result {
                ClientResult::$Resp(result) => Ok(result),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };
    (peer_req $method:ident $Req:ident($Param: ident) => $Resp: ident ) => {
        pub async fn $method(&self, params: $Param) -> Result<$Resp, ServiceError> {
            let result = self
                .send_request(ServerRequest::$Req($Req {
                    method: Default::default(),
                    params,
                    extensions: Default::default(),
                }))
                .await?;
            match result {
                ClientResult::$Resp(result) => Ok(result),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };
    (peer_req $method:ident $Req:ident($Param: ident)) => {
        // Returns an explicitly `Send` future so callers can spawn it.
        pub fn $method(
            &self,
            params: $Param,
        ) -> impl Future<Output = Result<(), ServiceError>> + Send + '_ {
            async move {
                let result = self
                    .send_request(ServerRequest::$Req($Req {
                        method: Default::default(),
                        params,
                        // Every request struct carries `extensions` (see the other arms); without
                        // this field any expansion of this arm would fail to compile.
                        extensions: Default::default(),
                    }))
                    .await?;
                match result {
                    ClientResult::EmptyResult(_) => Ok(()),
                    _ => Err(ServiceError::UnexpectedResponse),
                }
            }
        }
    };

    (peer_not $method:ident $Not:ident($Param: ident)) => {
        pub async fn $method(&self, params: $Param) -> Result<(), ServiceError> {
            self.send_notification(ServerNotification::$Not($Not {
                method: Default::default(),
                params,
                extensions: Default::default(),
            }))
            .await?;
            Ok(())
        }
    };
    (peer_not $method:ident $Not:ident) => {
        pub async fn $method(&self) -> Result<(), ServiceError> {
            self.send_notification(ServerNotification::$Not($Not {
                method: Default::default(),
                extensions: Default::default(),
            }))
            .await?;
            Ok(())
        }
    };

    (peer_req_with_timeout $method_with_timeout:ident $Req:ident() => $Resp: ident) => {
        pub async fn $method_with_timeout(
            &self,
            timeout: Option<std::time::Duration>,
        ) -> Result<$Resp, ServiceError> {
            let request = ServerRequest::$Req($Req {
                method: Default::default(),
                extensions: Default::default(),
            });
            let options = crate::service::PeerRequestOptions {
                timeout,
                meta: None,
            };
            let result = self
                .send_request_with_option(request, options)
                .await?
                .await_response()
                .await?;
            match result {
                ClientResult::$Resp(result) => Ok(result),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };

    (peer_req_with_timeout $method_with_timeout:ident $Req:ident($Param: ident) => $Resp: ident) => {
        pub async fn $method_with_timeout(
            &self,
            params: $Param,
            timeout: Option<std::time::Duration>,
        ) -> Result<$Resp, ServiceError> {
            let request = ServerRequest::$Req($Req {
                method: Default::default(),
                params,
                extensions: Default::default(),
            });
            let options = crate::service::PeerRequestOptions {
                timeout,
                meta: None,
            };
            let result = self
                .send_request_with_option(request, options)
                .await?
                .await_response()
                .await?;
            match result {
                ClientResult::$Resp(result) => Ok(result),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };
}
404
405impl Peer<RoleServer> {
406 pub fn supports_sampling_tools(&self) -> bool {
408 if let Some(client_info) = self.peer_info() {
409 client_info
410 .capabilities
411 .sampling
412 .as_ref()
413 .and_then(|s| s.tools.as_ref())
414 .is_some()
415 } else {
416 false
417 }
418 }
419
420 pub async fn create_message(
421 &self,
422 params: CreateMessageRequestParams,
423 ) -> Result<CreateMessageResult, ServiceError> {
424 if (params.tools.is_some() || params.tool_choice.is_some())
426 && !self.supports_sampling_tools()
427 {
428 return Err(ServiceError::McpError(ErrorData::invalid_params(
429 "tools or toolChoice provided but client does not support sampling tools capability",
430 None,
431 )));
432 }
433 params
435 .validate()
436 .map_err(|e| ServiceError::McpError(ErrorData::invalid_params(e, None)))?;
437 let result = self
438 .send_request(ServerRequest::CreateMessageRequest(CreateMessageRequest {
439 method: Default::default(),
440 params,
441 extensions: Default::default(),
442 }))
443 .await?;
444 match result {
445 ClientResult::CreateMessageResult(result) => Ok(*result),
446 _ => Err(ServiceError::UnexpectedResponse),
447 }
448 }
449 method!(peer_req list_roots ListRootsRequest() => ListRootsResult);
450 #[cfg(feature = "elicitation")]
451 method!(peer_req create_elicitation CreateElicitationRequest(CreateElicitationRequestParams) => CreateElicitationResult);
452 #[cfg(feature = "elicitation")]
453 method!(peer_req_with_timeout create_elicitation_with_timeout CreateElicitationRequest(CreateElicitationRequestParams) => CreateElicitationResult);
454 #[cfg(feature = "elicitation")]
455 method!(peer_not notify_url_elicitation_completed ElicitationCompletionNotification(ElicitationResponseNotificationParam));
456
457 method!(peer_not notify_cancelled CancelledNotification(CancelledNotificationParam));
458 method!(peer_not notify_progress ProgressNotification(ProgressNotificationParam));
459 method!(peer_not notify_logging_message LoggingMessageNotification(LoggingMessageNotificationParam));
460 method!(peer_not notify_resource_updated ResourceUpdatedNotification(ResourceUpdatedNotificationParam));
461 method!(peer_not notify_resource_list_changed ResourceListChangedNotification);
462 method!(peer_not notify_tool_list_changed ToolListChangedNotification);
463 method!(peer_not notify_prompt_list_changed PromptListChangedNotification);
464}
465
/// Failures of the high-level elicitation helpers on `Peer<RoleServer>`.
#[cfg(feature = "elicitation")]
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum ElicitationError {
    /// The underlying request failed.
    /// Also used for an invalid schema, which is wrapped as `invalid_params`.
    #[error("Service error: {0}")]
    Service(#[from] ServiceError),

    /// The client answered with the `Decline` action.
    #[error("User explicitly declined the request")]
    UserDeclined,

    /// The client answered with the `Cancel` action.
    #[error("User cancelled/dismissed the request")]
    UserCancelled,

    /// The accepted content could not be deserialized into the requested type.
    /// `data` holds the raw JSON that was received.
    #[error("Failed to parse response data: {error}\nReceived data: {data}")]
    ParseError {
        error: serde_json::Error,
        data: serde_json::Value,
    },

    /// The client accepted but sent no content.
    #[error("No response content provided")]
    NoContent,

    /// The client did not declare the elicitation mode the helper needs.
    #[error("Client does not support elicitation - capability not declared during initialization")]
    CapabilityNotSupported,
}
507
/// Opt-in marker for types that may be requested through form elicitation
/// (`Peer::elicit` / `Peer::elicit_with_timeout`).
///
/// A JSON schema is generated from the type, so it must implement `schemars::JsonSchema`.
/// Implement the marker with the `elicit_safe!` macro. Implementing it asserts that the type's
/// schema is acceptable for an elicitation form; presumably that means the restricted schema
/// subset the MCP spec allows (TODO confirm).
#[cfg(feature = "elicitation")]
pub trait ElicitationSafe
where
    Self: schemars::JsonSchema,
{
}
527
/// Implements `ElicitationSafe` for every listed type.
///
/// Accepts a comma-separated list of types, with an optional trailing comma.
#[cfg(feature = "elicitation")]
#[macro_export]
macro_rules! elicit_safe {
    ($($target:ty),* $(,)?) => {
        $(
            impl $crate::service::ElicitationSafe for $target {}
        )*
    };
}
560
/// The ways a server can gather information from the user through the client.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ElicitationMode {
    /// In-band form: the server sends a JSON schema and receives structured content back.
    Form,
    /// Out-of-band URL: the server sends a URL plus an elicitation id and receives only the
    /// user's action.
    Url,
}
566
#[cfg(feature = "elicitation")]
impl Peer<RoleServer> {
    /// Returns the elicitation modes the connected client declared during initialization.
    ///
    /// - No client info, or no `elicitation` capability: empty set.
    /// - `elicitation` declared with neither `form` nor `url`: `{Form}`.
    /// - Otherwise: `Form` if `form` was declared, and `Url` if `url` was declared.
    pub fn supported_elicitation_modes(&self) -> HashSet<ElicitationMode> {
        let Some(client_info) = self.peer_info() else {
            return HashSet::new();
        };
        let Some(capability) = &client_info.capabilities.elicitation else {
            return HashSet::new();
        };
        let declares_form = capability.form.is_some();
        let declares_url = capability.url.is_some();

        let mut modes = HashSet::new();
        // A bare `elicitation` capability (no explicit modes) is treated as form-only.
        if declares_form || !declares_url {
            modes.insert(ElicitationMode::Form);
        }
        if declares_url {
            modes.insert(ElicitationMode::Url);
        }
        modes
    }

    /// Asks the user, through a form, for a value of type `T`, with no timeout.
    ///
    /// Equivalent to `elicit_with_timeout(message, None)`.
    #[cfg(all(feature = "schemars", feature = "elicitation"))]
    pub async fn elicit<T>(&self, message: impl Into<String>) -> Result<Option<T>, ElicitationError>
    where
        T: ElicitationSafe + for<'de> serde::Deserialize<'de>,
    {
        self.elicit_with_timeout(message, None).await
    }

    /// Asks the user, through a form generated from `T`'s JSON schema, for a value of type `T`.
    ///
    /// On acceptance, returns `Ok(Some(value))`. This implementation never returns `Ok(None)`.
    ///
    /// # Errors
    ///
    /// - [`ElicitationError::CapabilityNotSupported`] if the client lacks form elicitation.
    /// - [`ElicitationError::Service`] if schema generation or the request fails.
    /// - [`ElicitationError::UserDeclined`] / [`ElicitationError::UserCancelled`] for the
    ///   corresponding user actions.
    /// - [`ElicitationError::NoContent`] if the user accepted but no content was sent.
    /// - [`ElicitationError::ParseError`] if the content does not deserialize into `T`.
    #[cfg(all(feature = "schemars", feature = "elicitation"))]
    pub async fn elicit_with_timeout<T>(
        &self,
        message: impl Into<String>,
        timeout: Option<std::time::Duration>,
    ) -> Result<Option<T>, ElicitationError>
    where
        T: ElicitationSafe + for<'de> serde::Deserialize<'de>,
    {
        if !self
            .supported_elicitation_modes()
            .contains(&ElicitationMode::Form)
        {
            return Err(ElicitationError::CapabilityNotSupported);
        }

        let schema = crate::model::ElicitationSchema::from_type::<T>().map_err(|e| {
            ElicitationError::Service(ServiceError::McpError(crate::ErrorData::invalid_params(
                format!(
                    "Invalid schema for type {}: {}",
                    std::any::type_name::<T>(),
                    e
                ),
                None,
            )))
        })?;

        let response = self
            .create_elicitation_with_timeout(
                CreateElicitationRequestParams::FormElicitationParams {
                    meta: None,
                    message: message.into(),
                    requested_schema: schema,
                },
                timeout,
            )
            .await?;

        match response.action {
            ElicitationAction::Accept => {
                let value = response.content.ok_or(ElicitationError::NoContent)?;
                // Deserialize from a borrow (`&Value` is a serde `Deserializer`). The raw value is
                // only needed, by move, when parsing fails, so no up-front clone is required.
                match T::deserialize(&value) {
                    Ok(parsed) => Ok(Some(parsed)),
                    Err(error) => Err(ElicitationError::ParseError { error, data: value }),
                }
            }
            ElicitationAction::Decline => Err(ElicitationError::UserDeclined),
            ElicitationAction::Cancel => Err(ElicitationError::UserCancelled),
        }
    }

    /// Asks the client to send the user to `url` (URL-mode elicitation), with no timeout.
    ///
    /// Equivalent to `elicit_url_with_timeout(message, url, elicitation_id, None)`.
    #[cfg(feature = "elicitation")]
    pub async fn elicit_url(
        &self,
        message: impl Into<String>,
        url: impl Into<Url>,
        elicitation_id: impl Into<String>,
    ) -> Result<ElicitationAction, ElicitationError> {
        self.elicit_url_with_timeout(message, url, elicitation_id, None)
            .await
    }

    /// Asks the client to send the user to `url` (URL-mode elicitation).
    ///
    /// Returns the user's action as reported by the client. The URL is sent in its string form.
    ///
    /// # Errors
    ///
    /// - [`ElicitationError::CapabilityNotSupported`] if the client lacks URL elicitation.
    /// - [`ElicitationError::Service`] if the request fails.
    #[cfg(feature = "elicitation")]
    pub async fn elicit_url_with_timeout(
        &self,
        message: impl Into<String>,
        url: impl Into<Url>,
        elicitation_id: impl Into<String>,
        timeout: Option<std::time::Duration>,
    ) -> Result<ElicitationAction, ElicitationError> {
        if !self
            .supported_elicitation_modes()
            .contains(&ElicitationMode::Url)
        {
            return Err(ElicitationError::CapabilityNotSupported);
        }

        let action = self
            .create_elicitation_with_timeout(
                CreateElicitationRequestParams::UrlElicitationParams {
                    meta: None,
                    message: message.into(),
                    url: url.into().to_string(),
                    elicitation_id: elicitation_id.into(),
                },
                timeout,
            )
            .await?
            .action;
        Ok(action)
    }
}