1use std::{
4 collections::HashMap,
5 marker::PhantomData,
6 sync::{Arc, Mutex},
7};
8
9use futures_core::Stream;
10use futures_util::StreamExt;
11use jsonrpc_core::{serde::Serialize, MetaIoHandler, Metadata, Params, Value};
12use tokio::sync::mpsc::Sender;
13
14#[derive(Clone)]
18pub struct Session {
19 pub raw_tx: Sender<String>,
20 pub id: u64,
21}
22
23impl Metadata for Session {}
24
25fn generate_id() -> String {
26 let id: [u8; 16] = rand::random();
27 let mut id_hex_bytes = vec![0u8; 34];
28 id_hex_bytes[..2].copy_from_slice(b"0x");
29 hex::encode_to_slice(id, &mut id_hex_bytes[2..]).unwrap();
30 unsafe { String::from_utf8_unchecked(id_hex_bytes) }
31}
32
/// One message to publish on a subscription.
///
/// `value` holds already-serialized JSON that is embedded verbatim into the outgoing
/// notification, under `"error"` when `is_err` is set and under `"result"` otherwise.
/// `T` is only a type-level marker for the payload type the stream publishes.
pub struct PublishMsg<T> {
    is_err: bool,
    value: Arc<str>,
    phantom: PhantomData<T>,
}

// Implemented by hand rather than derived: `#[derive(Clone)]` would require `T: Clone`,
// but cloning only copies the flag and bumps the `Arc` refcount, so no bound on `T` is
// needed.
impl<T> Clone for PublishMsg<T> {
    fn clone(&self) -> Self {
        Self {
            is_err: self.is_err,
            value: Arc::clone(&self.value),
            phantom: PhantomData,
        }
    }
}
41
42impl<T: Serialize> PublishMsg<T> {
43 pub fn result(value: &T) -> Self {
47 match jsonrpc_core::serde_json::to_string(value) {
48 Ok(value) => Self {
49 is_err: false,
50 value: value.into(),
51 phantom: PhantomData,
52 },
53 Err(_) => Self::error(&jsonrpc_core::Error {
54 code: jsonrpc_core::ErrorCode::InternalError,
55 message: "".into(),
56 data: None,
57 }),
58 }
59 }
60}
61
62impl<T> PublishMsg<T> {
63 pub fn error(err: &jsonrpc_core::Error) -> Self {
69 Self {
70 is_err: true,
71 value: jsonrpc_core::serde_json::to_string(err).unwrap().into(),
72 phantom: PhantomData,
73 }
74 }
75
76 pub fn result_raw_json(value: impl Into<Arc<str>>) -> Self {
80 Self {
81 is_err: false,
82 value: value.into(),
83 phantom: PhantomData,
84 }
85 }
86
87 pub fn error_raw_json(value: impl Into<Arc<str>>) -> Self {
91 Self {
92 is_err: true,
93 value: value.into(),
94 phantom: PhantomData,
95 }
96 }
97}
98
99pub trait PubSub<T> {
107 type Stream: Stream<Item = PublishMsg<T>> + Send;
108
109 fn subscribe(&self, params: Params) -> Result<Self::Stream, jsonrpc_core::Error>;
110}
111
112impl<T, F, S> PubSub<T> for F
113where
114 F: Fn(Params) -> Result<S, jsonrpc_core::Error>,
115 S: Stream<Item = PublishMsg<T>> + Send,
116{
117 type Stream = S;
118
119 fn subscribe(&self, params: Params) -> Result<Self::Stream, jsonrpc_core::Error> {
120 (self)(params)
121 }
122}
123
124impl<T, P: PubSub<T>> PubSub<T> for Arc<P> {
125 type Stream = P::Stream;
126
127 fn subscribe(&self, params: Params) -> Result<Self::Stream, jsonrpc_core::Error> {
128 <P as PubSub<T>>::subscribe(self, params)
129 }
130}
131
132pub fn add_pub_sub<T: Send + 'static>(
140 io: &mut MetaIoHandler<Option<Session>>,
141 subscribe_method: &str,
142 notify_method: &str,
143 unsubscribe_method: &str,
144 pubsub: impl PubSub<T> + Clone + Send + Sync + 'static,
145) {
146 let subscriptions0 = Arc::new(Mutex::new(HashMap::new()));
147 let subscriptions = subscriptions0.clone();
148 let notify_method: Arc<str> = serde_json::to_string(notify_method).unwrap().into();
149 io.add_method_with_meta(
150 subscribe_method,
151 move |params: Params, session: Option<Session>| {
152 let subscriptions = subscriptions.clone();
153 let pubsub = pubsub.clone();
154 let notify_method = notify_method.clone();
155 async move {
156 let session = session.ok_or_else(jsonrpc_core::Error::method_not_found)?;
157 let session_id = session.id;
158 let id = generate_id();
159 let stream = pubsub.subscribe(params)?;
160 let stream = terminate_after_one_error(stream);
161 let handle = tokio::spawn({
162 let id = id.clone();
163 let subscriptions = subscriptions.clone();
164 async move {
165 tokio::pin!(stream);
166 loop {
167 tokio::select! {
168 biased;
169 msg = stream.next() => {
170 match msg {
171 Some(msg) => {
172 let msg = format_msg(&id, ¬ify_method, msg);
173 if session.raw_tx.send(msg).await.is_err() {
174 break;
175 }
176 }
177 None => break,
178 }
179 }
180 _ = session.raw_tx.closed() => {
181 break;
182 }
183 }
184 }
185 subscriptions.lock().unwrap().remove(&(session_id, id));
186 }
187 });
188 subscriptions
189 .lock()
190 .unwrap()
191 .insert((session_id, id.clone()), handle);
192 Ok(Value::String(id))
193 }
194 },
195 );
196 io.add_method_with_meta(
197 unsubscribe_method,
198 move |params: Params, session: Option<Session>| {
199 let subscriptions = subscriptions0.clone();
200 async move {
201 let (id,): (String,) = params.parse()?;
202 let session_id = if let Some(session) = session {
203 session.id
204 } else {
205 return Ok(Value::Bool(false));
206 };
207 let result =
208 if let Some(handle) = subscriptions.lock().unwrap().remove(&(session_id, id)) {
209 handle.abort();
210 true
211 } else {
212 false
213 };
214 Ok(Value::Bool(result))
215 }
216 },
217 );
218}
219
220fn format_msg<T>(id: &str, method: &str, msg: PublishMsg<T>) -> String {
221 match msg.is_err {
222 false => format!(
223 r#"{{"jsonrpc":"2.0","method":{},"params":{{"subscription":"{}","result":{}}}}}"#,
224 method, id, msg.value,
225 ),
226 true => format!(
227 r#"{{"jsonrpc":"2.0","method":{},"params":{{"subscription":"{}","error":{}}}}}"#,
228 method, id, msg.value,
229 ),
230 }
231}
232
pin_project_lite::pin_project! {
    /// Stream adapter created by `terminate_after_one_error`: forwards items from `inner`
    /// and ends the stream right after the first error message has been yielded.
    struct TerminateAfterOneError<S> {
        // The wrapped stream; `#[pin]` projects it as `Pin<&mut S>` so it can be polled.
        #[pin]
        inner: S,
        // Set once an error message has been yielded; from then on `poll_next` returns
        // `None` without polling `inner`.
        has_error: bool,
    }
}
240
241impl<S, T> Stream for TerminateAfterOneError<S>
242where
243 S: Stream<Item = PublishMsg<T>>,
244{
245 type Item = PublishMsg<T>;
246
247 fn poll_next(
248 self: std::pin::Pin<&mut Self>,
249 cx: &mut std::task::Context<'_>,
250 ) -> std::task::Poll<Option<Self::Item>> {
251 if self.has_error {
252 return None.into();
253 }
254 let proj = self.project();
255 match futures_core::ready!(proj.inner.poll_next(cx)) {
256 None => None.into(),
257 Some(msg) => {
258 if msg.is_err {
259 *proj.has_error = true;
260 }
261 Some(msg).into()
262 }
263 }
264 }
265}
266
267fn terminate_after_one_error<S>(s: S) -> TerminateAfterOneError<S> {
268 TerminateAfterOneError {
269 inner: s,
270 has_error: false,
271 }
272}
273
#[cfg(test)]
mod tests {
    use async_stream::stream;
    use jsonrpc_core::{Call, Id, MethodCall, Output, Version};
    use tokio::sync::mpsc::channel;

    use super::*;

    #[test]
    fn test_id() {
        let id = generate_id();
        // Expected shape: "0x" followed by 32 hex digits (16 random bytes).
        assert_eq!(id.len(), 34);
        assert!(id.starts_with("0x"));
        assert!(id[2..].bytes().all(|b| b.is_ascii_hexdigit()));
        // 128 bits of randomness: two ids should never collide in practice.
        assert_ne!(id, generate_id());
    }

    #[tokio::test]
    async fn test_pubsub() {
        let mut rpc = MetaIoHandler::with_compatibility(jsonrpc_core::Compatibility::V2);
        add_pub_sub(&mut rpc, "sub", "notify", "unsub", |_params| {
            Ok(stream! {
                yield PublishMsg::result(&1);
                yield PublishMsg::result(&1);
            })
        });
        let (raw_tx, mut rx) = channel(1);
        let response = rpc
            .handle_call(
                Call::MethodCall(MethodCall {
                    jsonrpc: Some(Version::V2),
                    method: "sub".into(),
                    params: Params::None,
                    id: Id::Num(1),
                }),
                Some(Session {
                    raw_tx: raw_tx.clone(),
                    id: 1,
                }),
            )
            .await
            .unwrap();
        let sub_id = match response {
            Output::Success(s) => s.result,
            _ => unreachable!(),
        };

        // The first notification reaches the session's channel.
        assert!(rx.recv().await.is_some());

        // A different session (id 2) must not be able to cancel session 1's subscription.
        let response = rpc
            .handle_call(
                Call::MethodCall(MethodCall {
                    jsonrpc: Some(Version::V2),
                    method: "unsub".into(),
                    params: Params::Array(vec![sub_id.clone()]),
                    id: Id::Num(2),
                }),
                Some(Session {
                    raw_tx: raw_tx.clone(),
                    id: 2,
                }),
            )
            .await
            .unwrap();
        let result = match response {
            Output::Success(s) => s.result,
            _ => unreachable!(),
        };
        assert!(!result.as_bool().unwrap());

        // The owning session can cancel it.
        let response = rpc
            .handle_call(
                Call::MethodCall(MethodCall {
                    jsonrpc: Some(Version::V2),
                    method: "unsub".into(),
                    params: Params::Array(vec![sub_id.clone()]),
                    id: Id::Num(3),
                }),
                Some(Session { raw_tx, id: 1 }),
            )
            .await
            .unwrap();
        let result = match response {
            Output::Success(s) => s.result,
            _ => unreachable!(),
        };
        assert!(result.as_bool().unwrap());
    }

    #[tokio::test]
    async fn test_terminate_after_one_error() {
        let s = terminate_after_one_error(futures_util::stream::iter([
            PublishMsg::<u64>::result_raw_json(""),
            PublishMsg::error_raw_json(""),
            PublishMsg::result_raw_json(""),
        ]));
        assert_eq!(s.count().await, 2);
    }

    #[test]
    fn test_format_message() {
        let msg = format_msg(
            "id",
            &serde_json::to_string("notification").unwrap(),
            PublishMsg::result(&3u64),
        );
        let msg: serde_json::Value = serde_json::from_str(&msg).unwrap();
        assert_eq!(msg["method"].as_str(), Some("notification"));
        assert_eq!(msg["params"]["subscription"].as_str(), Some("id"));
        assert_eq!(msg["params"]["result"].as_u64(), Some(3));
    }

    #[test]
    fn test_format_error_message() {
        let msg = format_msg(
            "id",
            &serde_json::to_string("notification").unwrap(),
            PublishMsg::<u64>::error(&jsonrpc_core::Error::method_not_found()),
        );
        let msg: serde_json::Value = serde_json::from_str(&msg).unwrap();
        assert_eq!(msg["params"]["subscription"].as_str(), Some("id"));
        assert_eq!(msg["params"]["error"]["code"].as_i64(), Some(-32601));
        assert!(msg["params"].get("result").is_none());
    }

    #[test]
    fn test_result_serialization_failure() {
        // JSON object keys must be strings, so a tuple-keyed map cannot be serialized;
        // `result` must fall back to an internal-error message.
        let mut map = HashMap::new();
        map.insert((1u8, 2u8), 3u8);
        let msg = PublishMsg::result(&map);
        assert!(msg.is_err);
        let err: serde_json::Value = serde_json::from_str(&msg.value).unwrap();
        assert_eq!(err["code"].as_i64(), Some(-32603));
    }
}