Skip to main content

agenterra_rmcp/service/
server.rs

1use std::borrow::Cow;
2
3use thiserror::Error;
4
5use super::*;
6use crate::model::{
7    CancelledNotification, CancelledNotificationParam, ClientInfo, ClientJsonRpcMessage,
8    ClientNotification, ClientRequest, ClientResult, CreateMessageRequest,
9    CreateMessageRequestParam, CreateMessageResult, ErrorData, ListRootsRequest, ListRootsResult,
10    LoggingMessageNotification, LoggingMessageNotificationParam, ProgressNotification,
11    ProgressNotificationParam, PromptListChangedNotification, ProtocolVersion,
12    ResourceListChangedNotification, ResourceUpdatedNotification, ResourceUpdatedNotificationParam,
13    ServerInfo, ServerNotification, ServerRequest, ServerResult, ToolListChangedNotification,
14};
15
/// Zero-sized marker type that selects the server side of a connection.
///
/// It carries no data; it exists so role-generic code can be instantiated
/// for the server role.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct RoleServer;
18
19impl ServiceRole for RoleServer {
20    type Req = ServerRequest;
21    type Resp = ServerResult;
22    type Not = ServerNotification;
23    type PeerReq = ClientRequest;
24    type PeerResp = ClientResult;
25    type PeerNot = ClientNotification;
26    type Info = ServerInfo;
27    type PeerInfo = ClientInfo;
28
29    type InitializeError<E> = ServerInitializeError<E>;
30    const IS_CLIENT: bool = false;
31}
32
33/// It represents the error that may occur when serving the server.
34///
35/// if you want to handle the error, you can use `serve_server_with_ct` or `serve_server` with `Result<RunningService<RoleServer, S>, ServerError>`
36#[derive(Error, Debug)]
37pub enum ServerInitializeError<E> {
38    #[error("expect initialized request, but received: {0:?}")]
39    ExpectedInitializeRequest(Option<ClientJsonRpcMessage>),
40
41    #[error("expect initialized notification, but received: {0:?}")]
42    ExpectedInitializedNotification(Option<ClientJsonRpcMessage>),
43
44    #[error("connection closed: {0}")]
45    ConnectionClosed(String),
46
47    #[error("unexpected initialize result: {0:?}")]
48    UnexpectedInitializeResponse(ServerResult),
49
50    #[error("initialize failed: {0}")]
51    InitializeFailed(ErrorData),
52
53    #[error("unsupported protocol version: {0}")]
54    UnsupportedProtocolVersion(ProtocolVersion),
55
56    #[error("Send message error {error}, when {context}")]
57    TransportError {
58        error: E,
59        context: Cow<'static, str>,
60    },
61
62    #[error("Cancelled")]
63    Cancelled,
64}
65
66pub type ClientSink = Peer<RoleServer>;
67
68impl<S: Service<RoleServer>> ServiceExt<RoleServer> for S {
69    fn serve_with_ct<T, E, A>(
70        self,
71        transport: T,
72        ct: CancellationToken,
73    ) -> impl Future<Output = Result<RunningService<RoleServer, Self>, ServerInitializeError<E>>> + Send
74    where
75        T: IntoTransport<RoleServer, E, A>,
76        E: std::error::Error + Send + Sync + 'static,
77        Self: Sized,
78    {
79        serve_server_with_ct(self, transport, ct)
80    }
81}
82
83pub async fn serve_server<S, T, E, A>(
84    service: S,
85    transport: T,
86) -> Result<RunningService<RoleServer, S>, ServerInitializeError<E>>
87where
88    S: Service<RoleServer>,
89    T: IntoTransport<RoleServer, E, A>,
90    E: std::error::Error + Send + Sync + 'static,
91{
92    serve_server_with_ct(service, transport, CancellationToken::new()).await
93}
94
95/// Helper function to get the next message from the stream
96async fn expect_next_message<T, E>(
97    transport: &mut T,
98    context: &str,
99) -> Result<ClientJsonRpcMessage, ServerInitializeError<E>>
100where
101    T: Transport<RoleServer>,
102{
103    transport
104        .receive()
105        .await
106        .ok_or_else(|| ServerInitializeError::ConnectionClosed(context.to_string()))
107}
108
109/// Helper function to expect a request from the stream
110async fn expect_request<T, E>(
111    transport: &mut T,
112    context: &str,
113) -> Result<(ClientRequest, RequestId), ServerInitializeError<E>>
114where
115    T: Transport<RoleServer>,
116{
117    let msg = expect_next_message(transport, context).await?;
118    let msg_clone = msg.clone();
119    msg.into_request()
120        .ok_or(ServerInitializeError::ExpectedInitializeRequest(Some(
121            msg_clone,
122        )))
123}
124
125/// Helper function to expect a notification from the stream
126async fn expect_notification<T, E>(
127    transport: &mut T,
128    context: &str,
129) -> Result<ClientNotification, ServerInitializeError<E>>
130where
131    T: Transport<RoleServer>,
132{
133    let msg = expect_next_message(transport, context).await?;
134    let msg_clone = msg.clone();
135    msg.into_notification()
136        .ok_or(ServerInitializeError::ExpectedInitializedNotification(
137            Some(msg_clone),
138        ))
139}
140
141pub async fn serve_server_with_ct<S, T, E, A>(
142    service: S,
143    transport: T,
144    ct: CancellationToken,
145) -> Result<RunningService<RoleServer, S>, ServerInitializeError<E>>
146where
147    S: Service<RoleServer>,
148    T: IntoTransport<RoleServer, E, A>,
149    E: std::error::Error + Send + Sync + 'static,
150{
151    tokio::select! {
152        result = serve_server_with_ct_inner(service, transport, ct.clone()) => { result }
153        _ = ct.cancelled() => {
154            Err(ServerInitializeError::Cancelled)
155        }
156    }
157}
158
159async fn serve_server_with_ct_inner<S, T, E, A>(
160    service: S,
161    transport: T,
162    ct: CancellationToken,
163) -> Result<RunningService<RoleServer, S>, ServerInitializeError<E>>
164where
165    S: Service<RoleServer>,
166    T: IntoTransport<RoleServer, E, A>,
167    E: std::error::Error + Send + Sync + 'static,
168{
169    let mut transport = transport.into_transport();
170    let id_provider = <Arc<AtomicU32RequestIdProvider>>::default();
171
172    // Get initialize request
173    let (request, id) = expect_request(&mut transport, "initialized request").await?;
174
175    let ClientRequest::InitializeRequest(peer_info) = &request else {
176        return Err(ServerInitializeError::ExpectedInitializeRequest(Some(
177            ClientJsonRpcMessage::request(request, id),
178        )));
179    };
180    let (peer, peer_rx) = Peer::new(id_provider, Some(peer_info.params.clone()));
181    let context = RequestContext {
182        ct: ct.child_token(),
183        id: id.clone(),
184        meta: request.get_meta().clone(),
185        extensions: request.extensions().clone(),
186        peer: peer.clone(),
187    };
188    // Send initialize response
189    let init_response = service.handle_request(request.clone(), context).await;
190    let mut init_response = match init_response {
191        Ok(ServerResult::InitializeResult(init_response)) => init_response,
192        Ok(result) => {
193            return Err(ServerInitializeError::UnexpectedInitializeResponse(result));
194        }
195        Err(e) => {
196            transport
197                .send(ServerJsonRpcMessage::error(e.clone(), id))
198                .await
199                .map_err(|error| ServerInitializeError::TransportError {
200                    error,
201                    context: "sending error response".into(),
202                })?;
203            return Err(ServerInitializeError::InitializeFailed(e));
204        }
205    };
206    let peer_protocol_version = peer_info.params.protocol_version.clone();
207    let protocol_version = match peer_protocol_version
208        .partial_cmp(&init_response.protocol_version)
209        .ok_or(ServerInitializeError::UnsupportedProtocolVersion(
210            peer_protocol_version,
211        ))? {
212        std::cmp::Ordering::Less => peer_info.params.protocol_version.clone(),
213        _ => init_response.protocol_version,
214    };
215    init_response.protocol_version = protocol_version;
216    transport
217        .send(ServerJsonRpcMessage::response(
218            ServerResult::InitializeResult(init_response),
219            id,
220        ))
221        .await
222        .map_err(|error| ServerInitializeError::TransportError {
223            error,
224            context: "sending initialize response".into(),
225        })?;
226
227    // Wait for initialize notification
228    let notification = expect_notification(&mut transport, "initialize notification").await?;
229    let ClientNotification::InitializedNotification(_) = notification else {
230        return Err(ServerInitializeError::ExpectedInitializedNotification(
231            Some(ClientJsonRpcMessage::notification(notification)),
232        ));
233    };
234    let context = NotificationContext {
235        meta: notification.get_meta().clone(),
236        extensions: notification.extensions().clone(),
237        peer: peer.clone(),
238    };
239    let _ = service.handle_notification(notification, context).await;
240    // Continue processing service
241    Ok(serve_inner(service, transport, peer, peer_rx, ct))
242}
243
// Generates the typed convenience methods on `Peer<RoleServer>`.
//
// Arms:
//   peer_req name Req() => Resp        request without params, typed result
//   peer_req name Req(Param) => Resp   request with params, typed result
//   peer_req name Req(Param)           request with params, empty result
//   peer_not name Not(Param)           notification with params
//   peer_not name Not                  notification without params
//
// Every generated request method maps a result of the wrong kind to
// `ServiceError::UnexpectedResponse`.
macro_rules! method {
    // Request without params, expecting `ClientResult::$Resp`.
    (peer_req $method:ident $Req:ident() => $Resp: ident ) => {
        pub async fn $method(&self) -> Result<$Resp, ServiceError> {
            let result = self
                .send_request(ServerRequest::$Req($Req {
                    method: Default::default(),
                    extensions: Default::default(),
                }))
                .await?;
            match result {
                ClientResult::$Resp(result) => Ok(result),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };
    // Request with params, expecting `ClientResult::$Resp`.
    (peer_req $method:ident $Req:ident($Param: ident) => $Resp: ident ) => {
        pub async fn $method(&self, params: $Param) -> Result<$Resp, ServiceError> {
            let result = self
                .send_request(ServerRequest::$Req($Req {
                    method: Default::default(),
                    params,
                    extensions: Default::default(),
                }))
                .await?;
            match result {
                ClientResult::$Resp(result) => Ok(result),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };
    // Request with params, expecting `ClientResult::EmptyResult`.
    (peer_req $method:ident $Req:ident($Param: ident)) => {
        pub fn $method(
            &self,
            params: $Param,
        ) -> impl Future<Output = Result<(), ServiceError>> + Send + '_ {
            async move {
                let result = self
                    .send_request(ServerRequest::$Req($Req {
                        method: Default::default(),
                        params,
                        // Initialised like in every other arm; without it the
                        // struct literal is incomplete for request types that
                        // carry an `extensions` field.
                        extensions: Default::default(),
                    }))
                    .await?;
                match result {
                    ClientResult::EmptyResult(_) => Ok(()),
                    _ => Err(ServiceError::UnexpectedResponse),
                }
            }
        }
    };

    // Notification with params.
    (peer_not $method:ident $Not:ident($Param: ident)) => {
        pub async fn $method(&self, params: $Param) -> Result<(), ServiceError> {
            self.send_notification(ServerNotification::$Not($Not {
                method: Default::default(),
                params,
                extensions: Default::default(),
            }))
            .await?;
            Ok(())
        }
    };
    // Notification without params.
    (peer_not $method:ident $Not:ident) => {
        pub async fn $method(&self) -> Result<(), ServiceError> {
            self.send_notification(ServerNotification::$Not($Not {
                method: Default::default(),
                extensions: Default::default(),
            }))
            .await?;
            Ok(())
        }
    };
}
316
317impl Peer<RoleServer> {
318    method!(peer_req create_message CreateMessageRequest(CreateMessageRequestParam) => CreateMessageResult);
319    method!(peer_req list_roots ListRootsRequest() => ListRootsResult);
320
321    method!(peer_not notify_cancelled CancelledNotification(CancelledNotificationParam));
322    method!(peer_not notify_progress ProgressNotification(ProgressNotificationParam));
323    method!(peer_not notify_logging_message LoggingMessageNotification(LoggingMessageNotificationParam));
324    method!(peer_not notify_resource_updated ResourceUpdatedNotification(ResourceUpdatedNotificationParam));
325    method!(peer_not notify_resource_list_changed ResourceListChangedNotification);
326    method!(peer_not notify_tool_list_changed ToolListChangedNotification);
327    method!(peer_not notify_prompt_list_changed PromptListChangedNotification);
328}