1use std::collections::HashMap;
2
3use futures::{
4 Sink, SinkExt, Stream, StreamExt,
5 channel::{mpsc, oneshot},
6};
7use mmcp_protocol::{
8 mcp::{
9 JSONRPCError, JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse,
10 JsonrpcBatchResponseItem, JsonrpcNotificationParams, JsonrpcRequestParams, RequestId,
11 Result as JsonRpcResult,
12 },
13 port::{RPCPort, RPCPortError, RPCSink},
14};
15use serde_json::Value;
16
/// One-shot channel through which the runtime hands a waiting `request` call either the
/// success response or the error reply carrying the same request id.
type ResponseSubscriber = oneshot::Sender<Result<JSONRPCResponse, JSONRPCError>>;

/// Messages sent from [`RPCSender`] handles to the [`RPCRuntime`] that owns them.
enum Command {
    /// Register `response` to be completed when a reply whose id equals `request_id`
    /// is received by the runtime.
    WaitResponse {
        request_id: RequestId,
        response: ResponseSubscriber,
    },
}
25
/// Cloneable handle for sending messages through an [`RPCRuntime`].
///
/// Messages are written straight to the outgoing sink; `request` additionally registers a
/// response waiter with the owning runtime via `command_tx`.
#[derive(Clone)]
pub struct RPCSender<T> {
    /// Outgoing message sink, cloned from the runtime's sink.
    rpc_tx: T,
    /// Queue used to register response waiters with the owning runtime.
    command_tx: mpsc::Sender<Command>,
}

/// Drives a JSON-RPC connection: reads incoming messages in `progress` and routes
/// responses and errors to the `request` calls waiting on them.
pub struct RPCRuntime<T, R> {
    /// Outgoing message sink; cloned into every handle returned by `sink()`.
    rpc_tx: T,
    /// Incoming message stream, polled by `progress()`.
    rpc_rx: R,
    /// Retained so that `sink()` can hand out senders for `command_rx`. Because the
    /// runtime itself holds a sender, the command channel never reports "closed".
    command_tx: mpsc::Sender<Command>,
    /// Waiter registrations queued by handles; drained by `progress()`.
    command_rx: mpsc::Receiver<Command>,
    /// Pending requests keyed by id; an entry is removed when the matching response or
    /// error is dispatched.
    response_subscriptions: HashMap<RequestId, ResponseSubscriber>,
}
39
40impl<T, R> RPCRuntime<T, R> {
41 pub fn new(rpc_tx: T, rpc_rx: R) -> Self {
42 let (command_tx, command_rx) = mpsc::channel(100);
43 Self {
44 rpc_tx,
45 rpc_rx,
46 command_tx,
47 command_rx,
48 response_subscriptions: Default::default(),
49 }
50 }
51}
52
53impl<S, R> RPCPort for RPCRuntime<S, R>
54where
55 S: Sink<JSONRPCMessage> + Unpin + Clone + Send + Sync + 'static,
56 R: Stream<Item = anyhow::Result<JSONRPCMessage>> + Unpin + Send + Sync + 'static,
57{
58 fn sink(&self) -> impl RPCSink + Clone + Send + 'static {
59 RPCSender {
60 rpc_tx: self.rpc_tx.clone(),
61 command_tx: self.command_tx.clone(),
62 }
63 }
64
65 async fn progress(&mut self) -> anyhow::Result<Option<JSONRPCMessage>> {
66 while let Ok(Some(command)) = self.command_rx.try_next() {
68 match command {
69 Command::WaitResponse {
70 request_id,
71 response,
72 } => {
73 self.response_subscriptions.insert(request_id, response);
74 }
75 }
76 }
77
78 match self.rpc_rx.next().await {
80 Some(Ok(message)) => {
81 match &message {
82 JSONRPCMessage::JSONRPCResponse(response) => {
83 self.handle_response(response);
84 }
85 JSONRPCMessage::JSONRPCError(error) => {
86 self.handle_error(error);
87 }
88 JSONRPCMessage::JSONRPCBatchResponse(batch) => {
89 for item in batch.0.iter() {
90 match item {
91 JsonrpcBatchResponseItem::JSONRPCResponse(response) => {
92 self.handle_response(response);
93 }
94 JsonrpcBatchResponseItem::JSONRPCError(error) => {
95 self.handle_error(error);
96 }
97 }
98 }
99 }
100 _ => {}
101 }
102 Ok(Some(message))
103 }
104 Some(Err(e)) => Err(e),
105 None => {
106 Ok(None)
108 }
109 }
110 }
111}
112
113impl<S, R> RPCRuntime<S, R>
114where
115 S: Sink<JSONRPCMessage> + Unpin + 'static,
116{
117 fn handle_response(&mut self, response: &JSONRPCResponse) {
118 if let Some(subscriber) = self.response_subscriptions.remove(&response.id) {
119 let _ = subscriber.send(Ok(response.clone()));
121 }
123 }
124
125 fn handle_error(&mut self, error: &JSONRPCError) {
126 if let Some(subscriber) = self.response_subscriptions.remove(&error.id) {
127 let _ = subscriber.send(Err(error.clone()));
129 }
131 }
132}
133
134impl<S> RPCSink for RPCSender<S>
135where
136 S: Sink<JSONRPCMessage> + Unpin + Send + Sync,
137{
138 async fn send_message(&mut self, message: JSONRPCMessage) -> anyhow::Result<()> {
139 self.rpc_tx
140 .send(message)
141 .await
142 .map_err(|_| anyhow::anyhow!("failed to send message to rpc"))?;
143 Ok(())
144 }
145
146 async fn send_notification<T: serde::Serialize + Send>(
147 &mut self,
148 method: &str,
149 notification: T,
150 ) -> anyhow::Result<()> {
151 let notification_value = serde_json::to_value(notification)
153 .map_err(|e| anyhow::anyhow!("failed to serialize notification: {}", e))?;
154
155 let params = match notification_value {
157 Value::Object(obj) => Some(JsonrpcNotificationParams {
158 meta: None,
159 extra: obj,
160 }),
161 Value::Null => None, _ => return Err(RPCPortError::SerializeNotObject(notification_value).into()),
163 };
164
165 let rpc_notification = JSONRPCNotification {
167 jsonrpc: Default::default(),
168 method: method.to_string(),
169 params,
170 extra: Default::default(),
171 };
172
173 self.send_message(JSONRPCMessage::JSONRPCNotification(rpc_notification))
175 .await
176 }
177
178 async fn send_response<T: serde::Serialize + Send>(
179 &mut self,
180 request_id: RequestId,
181 response: T,
182 ) -> anyhow::Result<()> {
183 let response_value = serde_json::to_value(response)
185 .map_err(|e| anyhow::anyhow!("failed to serialize response: {}", e))?;
186
187 let result = JsonRpcResult {
189 meta: None,
190 extra: match response_value {
191 Value::Object(obj) => obj,
192 _ => return Err(RPCPortError::SerializeNotObject(response_value).into()),
193 },
194 };
195
196 let rpc_response = JSONRPCResponse {
197 id: request_id,
198 jsonrpc: Default::default(),
199 result,
200 extra: Default::default(),
201 };
202
203 self.send_message(JSONRPCMessage::JSONRPCResponse(rpc_response))
205 .await
206 }
207
208 async fn request<T: serde::Serialize + Send, R: serde::de::DeserializeOwned + Send>(
209 &mut self,
210 request_id: RequestId,
211 method: &str,
212 request: T,
213 ) -> anyhow::Result<Result<R, JSONRPCError>> {
214 let (response_tx, response_rx) = oneshot::channel();
216
217 self.command_tx
219 .send(Command::WaitResponse {
220 request_id: request_id.clone(),
221 response: response_tx,
222 })
223 .await
224 .map_err(|_| anyhow::anyhow!("failed to register response subscriber"))?;
225
226 let params_value = serde_json::to_value(request)
228 .map_err(|e| anyhow::anyhow!("failed to serialize request params: {}", e))?;
229
230 let params = match params_value {
232 Value::Object(obj) => Some(JsonrpcRequestParams {
233 meta: None,
234 extra: obj,
235 }),
236 Value::Null => None, _ => return Err(RPCPortError::SerializeNotObject(params_value).into()),
238 };
239
240 let rpc_request = JSONRPCRequest {
242 id: request_id,
243 jsonrpc: Default::default(),
244 method: method.to_string(),
245 params,
246 extra: Default::default(),
247 };
248
249 self.send_message(JSONRPCMessage::JSONRPCRequest(rpc_request))
251 .await?;
252
253 let response = response_rx
255 .await
256 .map_err(|_| anyhow::anyhow!("response channel closed"))?;
257
258 match response {
260 Ok(response) => {
261 let result_value = Value::Object(response.result.extra);
263
264 let result = serde_json::from_value(result_value)
266 .map_err(|e| anyhow::anyhow!("failed to deserialize response: {}", e))?;
267
268 Ok(Ok(result))
269 }
270 Err(error) => Ok(Err(error)),
271 }
272 }
273}