1use std::{collections::HashMap, sync::Arc};
2
3use futures::StreamExt;
4use serde_json::Value;
5use tokio::sync::{broadcast, mpsc, oneshot, Mutex};
6
7use crate::{internal::jsonrpc, ExecError, Router};
8
9use super::{
10 jsonrpc::{RequestId, RequestInner, ResponseInner},
11 ProcedureKind, RequestContext, ValueOrStream,
12};
13
14pub enum SubscriptionMap<'a> {
17 Ref(&'a mut HashMap<RequestId, oneshot::Sender<()>>),
18 Mutex(&'a Mutex<HashMap<RequestId, oneshot::Sender<()>>>),
19 None,
20}
21
22impl<'a> SubscriptionMap<'a> {
23 pub async fn has_subscription(&self, id: &RequestId) -> bool {
24 match self {
25 SubscriptionMap::Ref(map) => map.contains_key(id),
26 SubscriptionMap::Mutex(map) => {
27 let map = map.lock().await;
28 map.contains_key(id)
29 }
30 SubscriptionMap::None => unreachable!(),
31 }
32 }
33
34 pub async fn insert(&mut self, id: RequestId, tx: oneshot::Sender<()>) {
35 match self {
36 SubscriptionMap::Ref(map) => {
37 map.insert(id, tx);
38 }
39 SubscriptionMap::Mutex(map) => {
40 let mut map = map.lock().await;
41 map.insert(id, tx);
42 }
43 SubscriptionMap::None => unreachable!(),
44 }
45 }
46
47 pub async fn remove(&mut self, id: &RequestId) {
48 match self {
49 SubscriptionMap::Ref(map) => {
50 map.remove(id);
51 }
52 SubscriptionMap::Mutex(map) => {
53 let mut map = map.lock().await;
54 map.remove(id);
55 }
56 SubscriptionMap::None => unreachable!(),
57 }
58 }
59}
60pub enum Sender<'a> {
61 Channel(&'a mut mpsc::Sender<jsonrpc::Response>),
62 ResponseChannel(&'a mut mpsc::UnboundedSender<jsonrpc::Response>),
63 Broadcast(&'a broadcast::Sender<jsonrpc::Response>),
64 Response(Option<jsonrpc::Response>),
65}
66
67pub enum Sender2 {
68 Channel(mpsc::Sender<jsonrpc::Response>),
69 ResponseChannel(mpsc::UnboundedSender<jsonrpc::Response>),
70 Broadcast(broadcast::Sender<jsonrpc::Response>),
71}
72
73impl Sender2 {
74 pub async fn send(
75 &mut self,
76 resp: jsonrpc::Response,
77 ) -> Result<(), mpsc::error::SendError<jsonrpc::Response>> {
78 match self {
79 Self::Channel(tx) => tx.send(resp).await?,
80 Self::ResponseChannel(tx) => tx.send(resp)?,
81 Self::Broadcast(tx) => {
82 let _ = tx.send(resp).map_err(|_err| {
83 });
86 }
87 }
88
89 Ok(())
90 }
91}
92
93impl<'a> Sender<'a> {
94 pub async fn send(
95 &mut self,
96 resp: jsonrpc::Response,
97 ) -> Result<(), mpsc::error::SendError<jsonrpc::Response>> {
98 match self {
99 Self::Channel(tx) => tx.send(resp).await?,
100 Self::ResponseChannel(tx) => tx.send(resp)?,
101 Self::Broadcast(tx) => {
102 let _ = tx.send(resp).map_err(|_err| {
103 });
106 }
107 Self::Response(r) => {
108 *r = Some(resp);
109 }
110 }
111
112 Ok(())
113 }
114
115 pub fn sender2(&mut self) -> Sender2 {
116 match self {
117 Self::Channel(tx) => Sender2::Channel(tx.clone()),
118 Self::ResponseChannel(tx) => Sender2::ResponseChannel(tx.clone()),
119 Self::Broadcast(tx) => Sender2::Broadcast(tx.clone()),
120 Self::Response(_) => unreachable!(),
121 }
122 }
123}
124
125pub async fn handle_json_rpc<TCtx, TMeta>(
126 ctx: TCtx,
127 req: jsonrpc::Request,
128 router: &Arc<Router<TCtx, TMeta>>,
129 sender: &mut Sender<'_>,
130 subscriptions: &mut SubscriptionMap<'_>,
131) where
132 TCtx: 'static,
133{
134 if req.jsonrpc.is_some() && req.jsonrpc.as_deref() != Some("2.0") {
135 let _ = sender
136 .send(jsonrpc::Response {
137 jsonrpc: "2.0",
138 id: req.id.clone(),
139 result: ResponseInner::Error(ExecError::InvalidJsonRpcVersion.into()),
140 })
141 .await
142 .map_err(|_err| {
143 });
146 }
147
148 let (path, input, procedures, sub_id) = match req.inner {
149 RequestInner::Query { path, input } => (path, input, router.queries(), None),
150 RequestInner::Mutation { path, input } => (path, input, router.mutations(), None),
151 RequestInner::Subscription { path, input } => {
152 (path, input.1, router.subscriptions(), Some(input.0))
153 }
154 RequestInner::SubscriptionStop { input } => {
155 subscriptions.remove(&input).await;
156 return;
157 }
158 };
159
160 let result = match procedures
161 .get(&path)
162 .ok_or_else(|| ExecError::OperationNotFound(path.clone()))
163 .and_then(|v| {
164 v.exec.call(
165 ctx,
166 input.unwrap_or(Value::Null),
167 RequestContext {
168 kind: ProcedureKind::Query,
169 path,
170 },
171 )
172 }) {
173 Ok(op) => match op.into_value_or_stream().await {
174 Ok(ValueOrStream::Value(v)) => ResponseInner::Response(v),
175 Ok(ValueOrStream::Stream(mut stream)) => {
176 if matches!(sender, Sender::Response(_))
177 || matches!(subscriptions, SubscriptionMap::None)
178 {
179 let _ = sender
180 .send(jsonrpc::Response {
181 jsonrpc: "2.0",
182 id: req.id.clone(),
183 result: ResponseInner::Error(
184 ExecError::UnsupportedMethod("Subscription".to_string()).into(),
185 ),
186 })
187 .await
188 .map_err(|_err| {
189 });
192 }
193
194 if let Some(id) = sub_id {
195 if matches!(id, RequestId::Null) {
196 let _ = sender
197 .send(jsonrpc::Response {
198 jsonrpc: "2.0",
199 id: req.id.clone(),
200 result: ResponseInner::Error(
201 ExecError::ErrSubscriptionWithNullId.into(),
202 ),
203 })
204 .await
205 .map_err(|_err| {
206 });
209 } else if subscriptions.has_subscription(&id).await {
210 let _ = sender
211 .send(jsonrpc::Response {
212 jsonrpc: "2.0",
213 id: req.id.clone(),
214 result: ResponseInner::Error(
215 ExecError::ErrSubscriptionDuplicateId.into(),
216 ),
217 })
218 .await
219 .map_err(|_err| {
220 });
223 }
224
225 let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
226 subscriptions.insert(id.clone(), shutdown_tx).await;
227 let mut sender2 = sender.sender2();
228 tokio::spawn(async move {
229 loop {
230 tokio::select! {
231 biased; _ = &mut shutdown_rx => {
233 break;
236 }
237 v = stream.next() => {
238 match v {
239 Some(Ok(v)) => {
240 let _ = sender2.send(jsonrpc::Response {
241 jsonrpc: "2.0",
242 id: id.clone(),
243 result: ResponseInner::Event(v),
244 })
245 .await
246 .map_err(|_err| {
247 });
250 }
251 Some(Err(_err)) => {
252 }
255 None => {
256 break;
257 }
258 }
259 }
260 }
261 }
262 });
263 }
264
265 return;
266 }
267 Err(err) => {
268 ResponseInner::Error(err.into())
272 }
273 },
274 Err(err) => {
275 ResponseInner::Error(err.into())
278 }
279 };
280
281 let _ = sender
282 .send(jsonrpc::Response {
283 jsonrpc: "2.0",
284 id: req.id,
285 result,
286 })
287 .await
288 .map_err(|_err| {
289 });
292}