mmcp_rpc/
lib.rs

1use std::collections::HashMap;
2
3use futures::{
4    Sink, SinkExt, Stream, StreamExt,
5    channel::{mpsc, oneshot},
6};
7use mmcp_protocol::{
8    mcp::{
9        JSONRPCError, JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse,
10        JsonrpcBatchResponseItem, JsonrpcNotificationParams, JsonrpcRequestParams, RequestId,
11        Result as JsonRpcResult,
12    },
13    port::{RPCPort, RPCPortError, RPCSink},
14};
15use serde_json::Value;
16
17type ResponseSubscriber = oneshot::Sender<Result<JSONRPCResponse, JSONRPCError>>;
18
19enum Command {
20    WaitResponse {
21        request_id: RequestId,
22        response: ResponseSubscriber,
23    },
24}
25
26#[derive(Clone)]
27pub struct RPCSender<T> {
28    rpc_tx: T,
29    command_tx: mpsc::Sender<Command>,
30}
31
32pub struct RPCRuntime<T, R> {
33    rpc_tx: T,
34    rpc_rx: R,
35    command_tx: mpsc::Sender<Command>,
36    command_rx: mpsc::Receiver<Command>,
37    response_subscriptions: HashMap<RequestId, ResponseSubscriber>,
38}
39
40impl<T, R> RPCRuntime<T, R> {
41    pub fn new(rpc_tx: T, rpc_rx: R) -> Self {
42        let (command_tx, command_rx) = mpsc::channel(100);
43        Self {
44            rpc_tx,
45            rpc_rx,
46            command_tx,
47            command_rx,
48            response_subscriptions: Default::default(),
49        }
50    }
51}
52
53impl<S, R> RPCPort for RPCRuntime<S, R>
54where
55    S: Sink<JSONRPCMessage> + Unpin + Clone + Send + Sync + 'static,
56    R: Stream<Item = anyhow::Result<JSONRPCMessage>> + Unpin + Send + Sync + 'static,
57{
58    fn sink(&self) -> impl RPCSink + Clone + Send + 'static {
59        RPCSender {
60            rpc_tx: self.rpc_tx.clone(),
61            command_tx: self.command_tx.clone(),
62        }
63    }
64
65    async fn progress(&mut self) -> anyhow::Result<Option<JSONRPCMessage>> {
66        // 1. Process all pending commands first with priority
67        while let Ok(Some(command)) = self.command_rx.try_next() {
68            match command {
69                Command::WaitResponse {
70                    request_id,
71                    response,
72                } => {
73                    self.response_subscriptions.insert(request_id, response);
74                }
75            }
76        }
77
78        // 2. Try to get a message from the stream, returning None if the stream is closed
79        match self.rpc_rx.next().await {
80            Some(Ok(message)) => {
81                match &message {
82                    JSONRPCMessage::JSONRPCResponse(response) => {
83                        self.handle_response(response);
84                    }
85                    JSONRPCMessage::JSONRPCError(error) => {
86                        self.handle_error(error);
87                    }
88                    JSONRPCMessage::JSONRPCBatchResponse(batch) => {
89                        for item in batch.0.iter() {
90                            match item {
91                                JsonrpcBatchResponseItem::JSONRPCResponse(response) => {
92                                    self.handle_response(response);
93                                }
94                                JsonrpcBatchResponseItem::JSONRPCError(error) => {
95                                    self.handle_error(error);
96                                }
97                            }
98                        }
99                    }
100                    _ => {}
101                }
102                Ok(Some(message))
103            }
104            Some(Err(e)) => Err(e),
105            None => {
106                // Return None only when the stream is closed
107                Ok(None)
108            }
109        }
110    }
111}
112
113impl<S, R> RPCRuntime<S, R>
114where
115    S: Sink<JSONRPCMessage> + Unpin + 'static,
116{
117    fn handle_response(&mut self, response: &JSONRPCResponse) {
118        if let Some(subscriber) = self.response_subscriptions.remove(&response.id) {
119            // Ignore errors if the subscriber dropped their receiver
120            let _ = subscriber.send(Ok(response.clone()));
121            // Return the message anyway so callers can process it if needed
122        }
123    }
124
125    fn handle_error(&mut self, error: &JSONRPCError) {
126        if let Some(subscriber) = self.response_subscriptions.remove(&error.id) {
127            // Ignore errors if the subscriber dropped their receiver
128            let _ = subscriber.send(Err(error.clone()));
129            // Return the message anyway so callers can process it if needed
130        }
131    }
132}
133
134impl<S> RPCSink for RPCSender<S>
135where
136    S: Sink<JSONRPCMessage> + Unpin + Send + Sync,
137{
138    async fn send_message(&mut self, message: JSONRPCMessage) -> anyhow::Result<()> {
139        self.rpc_tx
140            .send(message)
141            .await
142            .map_err(|_| anyhow::anyhow!("failed to send message to rpc"))?;
143        Ok(())
144    }
145
146    async fn send_notification<T: serde::Serialize + Send>(
147        &mut self,
148        method: &str,
149        notification: T,
150    ) -> anyhow::Result<()> {
151        // Serialize notification to JSON
152        let notification_value = serde_json::to_value(notification)
153            .map_err(|e| anyhow::anyhow!("failed to serialize notification: {}", e))?;
154
155        // If the notification is already a JSON object or null for optional params, extract it
156        let params = match notification_value {
157            Value::Object(obj) => Some(JsonrpcNotificationParams {
158                meta: None,
159                extra: obj,
160            }),
161            Value::Null => None, // Allow null for optional params
162            _ => return Err(RPCPortError::SerializeNotObject(notification_value).into()),
163        };
164
165        // Create notification message
166        let rpc_notification = JSONRPCNotification {
167            jsonrpc: Default::default(),
168            method: method.to_string(),
169            params,
170            extra: Default::default(),
171        };
172
173        // Send notification
174        self.send_message(JSONRPCMessage::JSONRPCNotification(rpc_notification))
175            .await
176    }
177
178    async fn send_response<T: serde::Serialize + Send>(
179        &mut self,
180        request_id: RequestId,
181        response: T,
182    ) -> anyhow::Result<()> {
183        // Serialize response
184        let response_value = serde_json::to_value(response)
185            .map_err(|e| anyhow::anyhow!("failed to serialize response: {}", e))?;
186
187        // Create JSON-RPC response with Result
188        let result = JsonRpcResult {
189            meta: None,
190            extra: match response_value {
191                Value::Object(obj) => obj,
192                _ => return Err(RPCPortError::SerializeNotObject(response_value).into()),
193            },
194        };
195
196        let rpc_response = JSONRPCResponse {
197            id: request_id,
198            jsonrpc: Default::default(),
199            result,
200            extra: Default::default(),
201        };
202
203        // Send response
204        self.send_message(JSONRPCMessage::JSONRPCResponse(rpc_response))
205            .await
206    }
207
208    async fn request<T: serde::Serialize + Send, R: serde::de::DeserializeOwned + Send>(
209        &mut self,
210        request_id: RequestId,
211        method: &str,
212        request: T,
213    ) -> anyhow::Result<Result<R, JSONRPCError>> {
214        // Create oneshot channel for receiving the response
215        let (response_tx, response_rx) = oneshot::channel();
216
217        // Create command to register response subscriber
218        self.command_tx
219            .send(Command::WaitResponse {
220                request_id: request_id.clone(),
221                response: response_tx,
222            })
223            .await
224            .map_err(|_| anyhow::anyhow!("failed to register response subscriber"))?;
225
226        // Serialize request to JSON
227        let params_value = serde_json::to_value(request)
228            .map_err(|e| anyhow::anyhow!("failed to serialize request params: {}", e))?;
229
230        // If the params is already a JSON object or null for optional params, extract it
231        let params = match params_value {
232            Value::Object(obj) => Some(JsonrpcRequestParams {
233                meta: None,
234                extra: obj,
235            }),
236            Value::Null => None, // Allow null for optional params
237            _ => return Err(RPCPortError::SerializeNotObject(params_value).into()),
238        };
239
240        // Create JSON-RPC request
241        let rpc_request = JSONRPCRequest {
242            id: request_id,
243            jsonrpc: Default::default(),
244            method: method.to_string(),
245            params,
246            extra: Default::default(),
247        };
248
249        // Send request
250        self.send_message(JSONRPCMessage::JSONRPCRequest(rpc_request))
251            .await?;
252
253        // Wait for response
254        let response = response_rx
255            .await
256            .map_err(|_| anyhow::anyhow!("response channel closed"))?;
257
258        // Process result
259        match response {
260            Ok(response) => {
261                // Directly use the extra field as the Value object
262                let result_value = Value::Object(response.result.extra);
263
264                // Deserialize into the expected type
265                let result = serde_json::from_value(result_value)
266                    .map_err(|e| anyhow::anyhow!("failed to deserialize response: {}", e))?;
267
268                Ok(Ok(result))
269            }
270            Err(error) => Ok(Err(error)),
271        }
272    }
273}