rspc_legacy/internal/
jsonrpc_exec.rs

1use std::{collections::HashMap, sync::Arc};
2
3use futures::StreamExt;
4use serde_json::Value;
5use tokio::sync::{broadcast, mpsc, oneshot, Mutex};
6
7use crate::{internal::jsonrpc, ExecError, Router};
8
9use super::{
10    jsonrpc::{RequestId, RequestInner, ResponseInner},
11    ProcedureKind, RequestContext, ValueOrStream,
12};
13
14// TODO: Deduplicate this function with the httpz integration
15
16pub enum SubscriptionMap<'a> {
17    Ref(&'a mut HashMap<RequestId, oneshot::Sender<()>>),
18    Mutex(&'a Mutex<HashMap<RequestId, oneshot::Sender<()>>>),
19    None,
20}
21
22impl<'a> SubscriptionMap<'a> {
23    pub async fn has_subscription(&self, id: &RequestId) -> bool {
24        match self {
25            SubscriptionMap::Ref(map) => map.contains_key(id),
26            SubscriptionMap::Mutex(map) => {
27                let map = map.lock().await;
28                map.contains_key(id)
29            }
30            SubscriptionMap::None => unreachable!(),
31        }
32    }
33
34    pub async fn insert(&mut self, id: RequestId, tx: oneshot::Sender<()>) {
35        match self {
36            SubscriptionMap::Ref(map) => {
37                map.insert(id, tx);
38            }
39            SubscriptionMap::Mutex(map) => {
40                let mut map = map.lock().await;
41                map.insert(id, tx);
42            }
43            SubscriptionMap::None => unreachable!(),
44        }
45    }
46
47    pub async fn remove(&mut self, id: &RequestId) {
48        match self {
49            SubscriptionMap::Ref(map) => {
50                map.remove(id);
51            }
52            SubscriptionMap::Mutex(map) => {
53                let mut map = map.lock().await;
54                map.remove(id);
55            }
56            SubscriptionMap::None => unreachable!(),
57        }
58    }
59}
60pub enum Sender<'a> {
61    Channel(&'a mut mpsc::Sender<jsonrpc::Response>),
62    ResponseChannel(&'a mut mpsc::UnboundedSender<jsonrpc::Response>),
63    Broadcast(&'a broadcast::Sender<jsonrpc::Response>),
64    Response(Option<jsonrpc::Response>),
65}
66
67pub enum Sender2 {
68    Channel(mpsc::Sender<jsonrpc::Response>),
69    ResponseChannel(mpsc::UnboundedSender<jsonrpc::Response>),
70    Broadcast(broadcast::Sender<jsonrpc::Response>),
71}
72
73impl Sender2 {
74    pub async fn send(
75        &mut self,
76        resp: jsonrpc::Response,
77    ) -> Result<(), mpsc::error::SendError<jsonrpc::Response>> {
78        match self {
79            Self::Channel(tx) => tx.send(resp).await?,
80            Self::ResponseChannel(tx) => tx.send(resp)?,
81            Self::Broadcast(tx) => {
82                let _ = tx.send(resp).map_err(|_err| {
83                    // #[cfg(feature = "tracing")]
84                    // tracing::error!("Failed to send response: {}", _err);
85                });
86            }
87        }
88
89        Ok(())
90    }
91}
92
93impl<'a> Sender<'a> {
94    pub async fn send(
95        &mut self,
96        resp: jsonrpc::Response,
97    ) -> Result<(), mpsc::error::SendError<jsonrpc::Response>> {
98        match self {
99            Self::Channel(tx) => tx.send(resp).await?,
100            Self::ResponseChannel(tx) => tx.send(resp)?,
101            Self::Broadcast(tx) => {
102                let _ = tx.send(resp).map_err(|_err| {
103                    // #[cfg(feature = "tracing")]
104                    // tracing::error!("Failed to send response: {}", _err);
105                });
106            }
107            Self::Response(r) => {
108                *r = Some(resp);
109            }
110        }
111
112        Ok(())
113    }
114
115    pub fn sender2(&mut self) -> Sender2 {
116        match self {
117            Self::Channel(tx) => Sender2::Channel(tx.clone()),
118            Self::ResponseChannel(tx) => Sender2::ResponseChannel(tx.clone()),
119            Self::Broadcast(tx) => Sender2::Broadcast(tx.clone()),
120            Self::Response(_) => unreachable!(),
121        }
122    }
123}
124
125pub async fn handle_json_rpc<TCtx, TMeta>(
126    ctx: TCtx,
127    req: jsonrpc::Request,
128    router: &Arc<Router<TCtx, TMeta>>,
129    sender: &mut Sender<'_>,
130    subscriptions: &mut SubscriptionMap<'_>,
131) where
132    TCtx: 'static,
133{
134    if req.jsonrpc.is_some() && req.jsonrpc.as_deref() != Some("2.0") {
135        let _ = sender
136            .send(jsonrpc::Response {
137                jsonrpc: "2.0",
138                id: req.id.clone(),
139                result: ResponseInner::Error(ExecError::InvalidJsonRpcVersion.into()),
140            })
141            .await
142            .map_err(|_err| {
143                // #[cfg(feature = "tracing")]
144                // tracing::error!("Failed to send response: {}", _err);
145            });
146    }
147
148    let (path, input, procedures, sub_id) = match req.inner {
149        RequestInner::Query { path, input } => (path, input, router.queries(), None),
150        RequestInner::Mutation { path, input } => (path, input, router.mutations(), None),
151        RequestInner::Subscription { path, input } => {
152            (path, input.1, router.subscriptions(), Some(input.0))
153        }
154        RequestInner::SubscriptionStop { input } => {
155            subscriptions.remove(&input).await;
156            return;
157        }
158    };
159
160    let result = match procedures
161        .get(&path)
162        .ok_or_else(|| ExecError::OperationNotFound(path.clone()))
163        .and_then(|v| {
164            v.exec.call(
165                ctx,
166                input.unwrap_or(Value::Null),
167                RequestContext {
168                    kind: ProcedureKind::Query,
169                    path,
170                },
171            )
172        }) {
173        Ok(op) => match op.into_value_or_stream().await {
174            Ok(ValueOrStream::Value(v)) => ResponseInner::Response(v),
175            Ok(ValueOrStream::Stream(mut stream)) => {
176                if matches!(sender, Sender::Response(_))
177                    || matches!(subscriptions, SubscriptionMap::None)
178                {
179                    let _ = sender
180                        .send(jsonrpc::Response {
181                            jsonrpc: "2.0",
182                            id: req.id.clone(),
183                            result: ResponseInner::Error(
184                                ExecError::UnsupportedMethod("Subscription".to_string()).into(),
185                            ),
186                        })
187                        .await
188                        .map_err(|_err| {
189                            // #[cfg(feature = "tracing")]
190                            // tracing::error!("Failed to send response: {}", _err);
191                        });
192                }
193
194                if let Some(id) = sub_id {
195                    if matches!(id, RequestId::Null) {
196                        let _ = sender
197                            .send(jsonrpc::Response {
198                                jsonrpc: "2.0",
199                                id: req.id.clone(),
200                                result: ResponseInner::Error(
201                                    ExecError::ErrSubscriptionWithNullId.into(),
202                                ),
203                            })
204                            .await
205                            .map_err(|_err| {
206                                // #[cfg(feature = "tracing")]
207                                // tracing::error!("Failed to send response: {}", _err);
208                            });
209                    } else if subscriptions.has_subscription(&id).await {
210                        let _ = sender
211                            .send(jsonrpc::Response {
212                                jsonrpc: "2.0",
213                                id: req.id.clone(),
214                                result: ResponseInner::Error(
215                                    ExecError::ErrSubscriptionDuplicateId.into(),
216                                ),
217                            })
218                            .await
219                            .map_err(|_err| {
220                                // #[cfg(feature = "tracing")]
221                                // tracing::error!("Failed to send response: {}", _err);
222                            });
223                    }
224
225                    let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
226                    subscriptions.insert(id.clone(), shutdown_tx).await;
227                    let mut sender2 = sender.sender2();
228                    tokio::spawn(async move {
229                        loop {
230                            tokio::select! {
231                                biased; // Note: Order matters
232                                _ = &mut shutdown_rx => {
233                                    // #[cfg(feature = "tracing")]
234                                    // tracing::debug!("Removing subscription with id '{:?}'", id);
235                                    break;
236                                }
237                                v = stream.next() => {
238                                    match v {
239                                        Some(Ok(v)) => {
240                                            let _ = sender2.send(jsonrpc::Response {
241                                                jsonrpc: "2.0",
242                                                id: id.clone(),
243                                                result: ResponseInner::Event(v),
244                                            })
245                                            .await
246                                            .map_err(|_err| {
247                                                // #[cfg(feature = "tracing")]
248                                                // tracing::error!("Failed to send response: {:?}", _err);
249                                            });
250                                        }
251                                        Some(Err(_err)) => {
252                                           // #[cfg(feature = "tracing")]
253                                           //  tracing::error!("Subscription error: {:?}", _err);
254                                        }
255                                        None => {
256                                            break;
257                                        }
258                                    }
259                                }
260                            }
261                        }
262                    });
263                }
264
265                return;
266            }
267            Err(err) => {
268                // #[cfg(feature = "tracing")]
269                // tracing::error!("Error executing operation: {:?}", err);
270
271                ResponseInner::Error(err.into())
272            }
273        },
274        Err(err) => {
275            // #[cfg(feature = "tracing")]
276            // tracing::error!("Error executing operation: {:?}", err);
277            ResponseInner::Error(err.into())
278        }
279    };
280
281    let _ = sender
282        .send(jsonrpc::Response {
283            jsonrpc: "2.0",
284            id: req.id,
285            result,
286        })
287        .await
288        .map_err(|_err| {
289            // #[cfg(feature = "tracing")]
290            // tracing::error!("Failed to send response: {:?}", _err);
291        });
292}