jsonrpc_utils/
pub_sub.rs

1//! Pub/Sub support.
2
3use std::{
4    collections::HashMap,
5    marker::PhantomData,
6    sync::{Arc, Mutex},
7};
8
9use futures_core::Stream;
10use futures_util::StreamExt;
11use jsonrpc_core::{serde::Serialize, MetaIoHandler, Metadata, Params, Value};
12use tokio::sync::mpsc::Sender;
13
14/// Transports intend to support pub/sub should provide `Session`s as metadata.
15///
16/// See websocket implementation for an example.
17#[derive(Clone)]
18pub struct Session {
19    pub raw_tx: Sender<String>,
20    pub id: u64,
21}
22
23impl Metadata for Session {}
24
25fn generate_id() -> String {
26    let id: [u8; 16] = rand::random();
27    let mut id_hex_bytes = vec![0u8; 34];
28    id_hex_bytes[..2].copy_from_slice(b"0x");
29    hex::encode_to_slice(id, &mut id_hex_bytes[2..]).unwrap();
30    unsafe { String::from_utf8_unchecked(id_hex_bytes) }
31}
32
/// Inner message published to subscribers.
///
/// Holds an already-serialized JSON payload, which is either a `result` value
/// or an `error` object. `T` only tags the message with the type it was
/// produced from; no value of type `T` is stored.
pub struct PublishMsg<T> {
    /// Whether `value` is an error object rather than a result.
    is_err: bool,
    // Make clone cheap.
    value: Arc<str>,
    // `fn() -> T` keeps the type tag (covariant in `T`) without making the
    // auto traits `Send`/`Sync`/`Unpin` depend on `T`, since no `T` is stored.
    phantom: PhantomData<fn() -> T>,
}

// Implemented by hand: `#[derive(Clone)]` would add a `T: Clone` bound even
// though cloning only bumps the `Arc` refcount and never touches a `T`.
impl<T> Clone for PublishMsg<T> {
    fn clone(&self) -> Self {
        Self {
            is_err: self.is_err,
            value: Arc::clone(&self.value),
            phantom: PhantomData,
        }
    }
}
41
42impl<T: Serialize> PublishMsg<T> {
43    /// Create a new “result” message by serializing the value into JSON.
44    ///
45    /// If serialization fails, an “error” message is created returned instead.
46    pub fn result(value: &T) -> Self {
47        match jsonrpc_core::serde_json::to_string(value) {
48            Ok(value) => Self {
49                is_err: false,
50                value: value.into(),
51                phantom: PhantomData,
52            },
53            Err(_) => Self::error(&jsonrpc_core::Error {
54                code: jsonrpc_core::ErrorCode::InternalError,
55                message: "".into(),
56                data: None,
57            }),
58        }
59    }
60}
61
62impl<T> PublishMsg<T> {
63    /// Create a new “error” message by serializing the JSONRPC error object.
64    ///
65    /// # Panics
66    ///
67    /// If serializing the error fails.
68    pub fn error(err: &jsonrpc_core::Error) -> Self {
69        Self {
70            is_err: true,
71            value: jsonrpc_core::serde_json::to_string(err).unwrap().into(),
72            phantom: PhantomData,
73        }
74    }
75
76    /// Create a new “result” message.
77    ///
78    /// `value` must be valid JSON.
79    pub fn result_raw_json(value: impl Into<Arc<str>>) -> Self {
80        Self {
81            is_err: false,
82            value: value.into(),
83            phantom: PhantomData,
84        }
85    }
86
87    /// Create a new “error” message.
88    ///
89    /// `value` must be valid JSON.
90    pub fn error_raw_json(value: impl Into<Arc<str>>) -> Self {
91        Self {
92            is_err: true,
93            value: value.into(),
94            phantom: PhantomData,
95        }
96    }
97}
98
/// Implement this trait to define actual pub/sub logic.
///
/// `subscribe` is called once per subscription request with the request's
/// params. Returning `Err` rejects the subscription, and that error becomes
/// the response to the subscribe call. Returning `Ok` starts a subscription
/// whose stream items are forwarded to the client as notifications. After
/// the stream yields an error message it is not polled any further.
///
/// # Streams
///
/// Stream wrappers from tokio-stream can be used, e.g. `BroadcastStream`
/// (mapped so that its items are `PublishMsg<T>`, as `Self::Stream` requires).
///
/// Or use the async-stream crate to implement streams with async-await. See the example server.
pub trait PubSub<T> {
    /// Messages for a single subscription. It must be `Send` because it is
    /// driven on a spawned task.
    type Stream: Stream<Item = PublishMsg<T>> + Send;

    /// Start a subscription for the given request `params`.
    fn subscribe(&self, params: Params) -> Result<Self::Stream, jsonrpc_core::Error>;
}
111
112impl<T, F, S> PubSub<T> for F
113where
114    F: Fn(Params) -> Result<S, jsonrpc_core::Error>,
115    S: Stream<Item = PublishMsg<T>> + Send,
116{
117    type Stream = S;
118
119    fn subscribe(&self, params: Params) -> Result<Self::Stream, jsonrpc_core::Error> {
120        (self)(params)
121    }
122}
123
124impl<T, P: PubSub<T>> PubSub<T> for Arc<P> {
125    type Stream = P::Stream;
126
127    fn subscribe(&self, params: Params) -> Result<Self::Stream, jsonrpc_core::Error> {
128        <P as PubSub<T>>::subscribe(self, params)
129    }
130}
131
132/// Add subscribe and unsubscribe methods to the jsonrpc handler.
133///
134/// Respond to subscription calls with a stream or an error. If a stream is
135/// returned, a subscription id is automatically generated. Any results produced
136/// by the stream will be sent to the client along with the subscription id. The
137/// stream is dropped if the client calls the unsubscribe method with the
138/// subscription id or if it is disconnected.
139pub fn add_pub_sub<T: Send + 'static>(
140    io: &mut MetaIoHandler<Option<Session>>,
141    subscribe_method: &str,
142    notify_method: &str,
143    unsubscribe_method: &str,
144    pubsub: impl PubSub<T> + Clone + Send + Sync + 'static,
145) {
146    let subscriptions0 = Arc::new(Mutex::new(HashMap::new()));
147    let subscriptions = subscriptions0.clone();
148    let notify_method: Arc<str> = serde_json::to_string(notify_method).unwrap().into();
149    io.add_method_with_meta(
150        subscribe_method,
151        move |params: Params, session: Option<Session>| {
152            let subscriptions = subscriptions.clone();
153            let pubsub = pubsub.clone();
154            let notify_method = notify_method.clone();
155            async move {
156                let session = session.ok_or_else(jsonrpc_core::Error::method_not_found)?;
157                let session_id = session.id;
158                let id = generate_id();
159                let stream = pubsub.subscribe(params)?;
160                let stream = terminate_after_one_error(stream);
161                let handle = tokio::spawn({
162                    let id = id.clone();
163                    let subscriptions = subscriptions.clone();
164                    async move {
165                        tokio::pin!(stream);
166                        loop {
167                            tokio::select! {
168                                biased;
169                                msg = stream.next() => {
170                                    match msg {
171                                        Some(msg) => {
172                                            let msg = format_msg(&id, &notify_method, msg);
173                                            if session.raw_tx.send(msg).await.is_err() {
174                                                break;
175                                            }
176                                        }
177                                        None => break,
178                                    }
179                                }
180                                _ = session.raw_tx.closed() => {
181                                    break;
182                                }
183                            }
184                        }
185                        subscriptions.lock().unwrap().remove(&(session_id, id));
186                    }
187                });
188                subscriptions
189                    .lock()
190                    .unwrap()
191                    .insert((session_id, id.clone()), handle);
192                Ok(Value::String(id))
193            }
194        },
195    );
196    io.add_method_with_meta(
197        unsubscribe_method,
198        move |params: Params, session: Option<Session>| {
199            let subscriptions = subscriptions0.clone();
200            async move {
201                let (id,): (String,) = params.parse()?;
202                let session_id = if let Some(session) = session {
203                    session.id
204                } else {
205                    return Ok(Value::Bool(false));
206                };
207                let result =
208                    if let Some(handle) = subscriptions.lock().unwrap().remove(&(session_id, id)) {
209                        handle.abort();
210                        true
211                    } else {
212                        false
213                    };
214                Ok(Value::Bool(result))
215            }
216        },
217    );
218}
219
220fn format_msg<T>(id: &str, method: &str, msg: PublishMsg<T>) -> String {
221    match msg.is_err {
222        false => format!(
223            r#"{{"jsonrpc":"2.0","method":{},"params":{{"subscription":"{}","result":{}}}}}"#,
224            method, id, msg.value,
225        ),
226        true => format!(
227            r#"{{"jsonrpc":"2.0","method":{},"params":{{"subscription":"{}","error":{}}}}}"#,
228            method, id, msg.value,
229        ),
230    }
231}
232
pin_project_lite::pin_project! {
    // Stream adapter that ends the stream right after the first error
    // message has been yielded.
    struct TerminateAfterOneError<S> {
        // The wrapped stream (structurally pinned).
        #[pin]
        inner: S,
        // Set once an error message has been yielded; afterwards `poll_next`
        // returns `None` without polling `inner` again.
        has_error: bool,
    }
}
240
241impl<S, T> Stream for TerminateAfterOneError<S>
242where
243    S: Stream<Item = PublishMsg<T>>,
244{
245    type Item = PublishMsg<T>;
246
247    fn poll_next(
248        self: std::pin::Pin<&mut Self>,
249        cx: &mut std::task::Context<'_>,
250    ) -> std::task::Poll<Option<Self::Item>> {
251        if self.has_error {
252            return None.into();
253        }
254        let proj = self.project();
255        match futures_core::ready!(proj.inner.poll_next(cx)) {
256            None => None.into(),
257            Some(msg) => {
258                if msg.is_err {
259                    *proj.has_error = true;
260                }
261                Some(msg).into()
262            }
263        }
264    }
265}
266
267fn terminate_after_one_error<S>(s: S) -> TerminateAfterOneError<S> {
268    TerminateAfterOneError {
269        inner: s,
270        has_error: false,
271    }
272}
273
#[cfg(test)]
mod tests {
    use async_stream::stream;
    use jsonrpc_core::{Call, Id, MethodCall, Output, Version};
    use tokio::sync::mpsc::channel;

    use super::*;

    #[test]
    fn test_id() {
        let id = generate_id();
        // "0x" followed by 32 lowercase hex digits (16 random bytes).
        assert_eq!(id.len(), 34);
        let digits = id.strip_prefix("0x").expect("id should start with 0x");
        assert!(digits
            .bytes()
            .all(|b| matches!(b, b'0'..=b'9' | b'a'..=b'f')));
        // Ids are random, so two consecutive ones should differ.
        assert_ne!(id, generate_id());
    }

    #[tokio::test]
    async fn test_pubsub() {
        let mut rpc = MetaIoHandler::with_compatibility(jsonrpc_core::Compatibility::V2);
        add_pub_sub(&mut rpc, "sub", "notify", "unsub", |_params| {
            Ok(stream! {
                yield PublishMsg::result(&1);
                yield PublishMsg::result(&1);
            })
        });
        let (raw_tx, mut rx) = channel(1);
        let response = rpc
            .handle_call(
                Call::MethodCall(MethodCall {
                    jsonrpc: Some(Version::V2),
                    method: "sub".into(),
                    params: Params::None,
                    id: Id::Num(1),
                }),
                Some(Session {
                    raw_tx: raw_tx.clone(),
                    id: 1,
                }),
            )
            .await
            .unwrap();
        let sub_id = match response {
            Output::Success(s) => s.result,
            _ => unreachable!(),
        };

        assert!(rx.recv().await.is_some());

        // Unsubscribe with a different id should fail.
        let response = rpc
            .handle_call(
                Call::MethodCall(MethodCall {
                    jsonrpc: Some(Version::V2),
                    method: "unsub".into(),
                    params: Params::Array(vec![sub_id.clone()]),
                    id: Id::Num(2),
                }),
                Some(Session {
                    raw_tx: raw_tx.clone(),
                    id: 2,
                }),
            )
            .await
            .unwrap();
        let result = match response {
            Output::Success(s) => s.result,
            _ => unreachable!(),
        };
        assert!(!result.as_bool().unwrap());

        // Unsubscribe with correct id should succeed.
        let response = rpc
            .handle_call(
                Call::MethodCall(MethodCall {
                    jsonrpc: Some(Version::V2),
                    method: "unsub".into(),
                    params: Params::Array(vec![sub_id.clone()]),
                    id: Id::Num(3),
                }),
                Some(Session { raw_tx, id: 1 }),
            )
            .await
            .unwrap();
        let result = match response {
            Output::Success(s) => s.result,
            _ => unreachable!(),
        };
        assert!(result.as_bool().unwrap());
    }

    #[tokio::test]
    async fn test_terminate_after_one_error() {
        let s = terminate_after_one_error(futures_util::stream::iter([
            PublishMsg::<u64>::result_raw_json(""),
            PublishMsg::error_raw_json(""),
            PublishMsg::result_raw_json(""),
        ]));
        assert_eq!(s.count().await, 2);
    }

    #[test]
    fn test_format_message() {
        let msg = format_msg(
            "id",
            &serde_json::to_string("notification").unwrap(),
            PublishMsg::result(&3u64),
        );
        let msg: serde_json::Value = serde_json::from_str(&msg).unwrap();
        assert_eq!(msg["method"].as_str(), Some("notification"));
        assert_eq!(msg["params"]["subscription"].as_str(), Some("id"));
        assert_eq!(msg["params"]["result"].as_u64(), Some(3));
    }
}