jsonrpc_utils/
pub_sub.rs

1//! Pub/Sub support.
2
3use std::{
4    collections::HashMap,
5    marker::PhantomData,
6    sync::{Arc, Mutex},
7};
8
9use futures_core::Stream;
10use futures_util::StreamExt;
11use jsonrpc_core::{serde::Serialize, MetaIoHandler, Metadata, Params, Value};
12use rand::{thread_rng, Rng};
13use tokio::sync::mpsc::Sender;
14
15/// Transports intend to support pub/sub should provide `Session`s as metadata.
16///
17/// See websocket implementation for an example.
18#[derive(Clone)]
19pub struct Session {
20    pub raw_tx: Sender<String>,
21    pub id: u64,
22}
23
24impl Metadata for Session {}
25
26fn generate_id() -> String {
27    let id: [u8; 16] = thread_rng().gen();
28    let mut id_hex_bytes = vec![0u8; 34];
29    id_hex_bytes[..2].copy_from_slice(b"0x");
30    hex::encode_to_slice(id, &mut id_hex_bytes[2..]).unwrap();
31    unsafe { String::from_utf8_unchecked(id_hex_bytes) }
32}
33
/// Inner message published to subscribers.
///
/// The payload is stored as pre-serialized JSON text. `is_err` selects
/// whether it is sent as the notification's `result` or `error`.
pub struct PublishMsg<T> {
    is_err: bool,
    // Make clone cheap.
    value: Arc<str>,
    // No `T` is ever stored. `fn() -> T` keeps `PublishMsg<T>` covariant in
    // `T` without making `Send`/`Sync` depend on `T`.
    phantom: PhantomData<fn() -> T>,
}

// Implemented by hand: `#[derive(Clone)]` would require `T: Clone` even though
// only an `Arc<str>` and a flag are cloned.
impl<T> Clone for PublishMsg<T> {
    fn clone(&self) -> Self {
        Self {
            is_err: self.is_err,
            value: Arc::clone(&self.value),
            phantom: PhantomData,
        }
    }
}
42
43impl<T: Serialize> PublishMsg<T> {
44    /// Create a new “result” message by serializing the value into JSON.
45    ///
46    /// If serialization fails, an “error” message is created returned instead.
47    pub fn result(value: &T) -> Self {
48        match jsonrpc_core::serde_json::to_string(value) {
49            Ok(value) => Self {
50                is_err: false,
51                value: value.into(),
52                phantom: PhantomData,
53            },
54            Err(_) => Self::error(&jsonrpc_core::Error {
55                code: jsonrpc_core::ErrorCode::InternalError,
56                message: "".into(),
57                data: None,
58            }),
59        }
60    }
61}
62
63impl<T> PublishMsg<T> {
64    /// Create a new “error” message by serializing the JSONRPC error object.
65    ///
66    /// # Panics
67    ///
68    /// If serializing the error fails.
69    pub fn error(err: &jsonrpc_core::Error) -> Self {
70        Self {
71            is_err: true,
72            value: jsonrpc_core::serde_json::to_string(err).unwrap().into(),
73            phantom: PhantomData,
74        }
75    }
76
77    /// Create a new “result” message.
78    ///
79    /// `value` must be valid JSON.
80    pub fn result_raw_json(value: impl Into<Arc<str>>) -> Self {
81        Self {
82            is_err: false,
83            value: value.into(),
84            phantom: PhantomData,
85        }
86    }
87
88    /// Create a new “error” message.
89    ///
90    /// `value` must be valid JSON.
91    pub fn error_raw_json(value: impl Into<Arc<str>>) -> Self {
92        Self {
93            is_err: true,
94            value: value.into(),
95            phantom: PhantomData,
96        }
97    }
98}
99
100/// Implement this trait to define actual pub/sub logic.
101///
102/// # Streams
103///
104/// Stream wrappers from tokio-stream can be used, e.g. `BroadcastStream`.
105///
106/// Or use the async-stream crate to implement streams with async-await. See the example server.
107pub trait PubSub<T> {
108    type Stream: Stream<Item = PublishMsg<T>> + Send;
109
110    fn subscribe(&self, params: Params) -> Result<Self::Stream, jsonrpc_core::Error>;
111}
112
113impl<T, F, S> PubSub<T> for F
114where
115    F: Fn(Params) -> Result<S, jsonrpc_core::Error>,
116    S: Stream<Item = PublishMsg<T>> + Send,
117{
118    type Stream = S;
119
120    fn subscribe(&self, params: Params) -> Result<Self::Stream, jsonrpc_core::Error> {
121        (self)(params)
122    }
123}
124
125impl<T, P: PubSub<T>> PubSub<T> for Arc<P> {
126    type Stream = P::Stream;
127
128    fn subscribe(&self, params: Params) -> Result<Self::Stream, jsonrpc_core::Error> {
129        <P as PubSub<T>>::subscribe(self, params)
130    }
131}
132
133/// Add subscribe and unsubscribe methods to the jsonrpc handler.
134///
135/// Respond to subscription calls with a stream or an error. If a stream is
136/// returned, a subscription id is automatically generated. Any results produced
137/// by the stream will be sent to the client along with the subscription id. The
138/// stream is dropped if the client calls the unsubscribe method with the
139/// subscription id or if it is disconnected.
140pub fn add_pub_sub<T: Send + 'static>(
141    io: &mut MetaIoHandler<Option<Session>>,
142    subscribe_method: &str,
143    notify_method: &str,
144    unsubscribe_method: &str,
145    pubsub: impl PubSub<T> + Clone + Send + Sync + 'static,
146) {
147    let subscriptions0 = Arc::new(Mutex::new(HashMap::new()));
148    let subscriptions = subscriptions0.clone();
149    let notify_method: Arc<str> = serde_json::to_string(notify_method).unwrap().into();
150    io.add_method_with_meta(
151        subscribe_method,
152        move |params: Params, session: Option<Session>| {
153            let subscriptions = subscriptions.clone();
154            let pubsub = pubsub.clone();
155            let notify_method = notify_method.clone();
156            async move {
157                let session = session.ok_or_else(jsonrpc_core::Error::method_not_found)?;
158                let session_id = session.id;
159                let id = generate_id();
160                let stream = pubsub.subscribe(params)?;
161                let stream = terminate_after_one_error(stream);
162                let handle = tokio::spawn({
163                    let id = id.clone();
164                    let subscriptions = subscriptions.clone();
165                    async move {
166                        tokio::pin!(stream);
167                        loop {
168                            tokio::select! {
169                                biased;
170                                msg = stream.next() => {
171                                    match msg {
172                                        Some(msg) => {
173                                            let msg = format_msg(&id, &notify_method, msg);
174                                            if session.raw_tx.send(msg).await.is_err() {
175                                                break;
176                                            }
177                                        }
178                                        None => break,
179                                    }
180                                }
181                                _ = session.raw_tx.closed() => {
182                                    break;
183                                }
184                            }
185                        }
186                        subscriptions.lock().unwrap().remove(&(session_id, id));
187                    }
188                });
189                subscriptions
190                    .lock()
191                    .unwrap()
192                    .insert((session_id, id.clone()), handle);
193                Ok(Value::String(id))
194            }
195        },
196    );
197    io.add_method_with_meta(
198        unsubscribe_method,
199        move |params: Params, session: Option<Session>| {
200            let subscriptions = subscriptions0.clone();
201            async move {
202                let (id,): (String,) = params.parse()?;
203                let session_id = if let Some(session) = session {
204                    session.id
205                } else {
206                    return Ok(Value::Bool(false));
207                };
208                let result =
209                    if let Some(handle) = subscriptions.lock().unwrap().remove(&(session_id, id)) {
210                        handle.abort();
211                        true
212                    } else {
213                        false
214                    };
215                Ok(Value::Bool(result))
216            }
217        },
218    );
219}
220
221fn format_msg<T>(id: &str, method: &str, msg: PublishMsg<T>) -> String {
222    match msg.is_err {
223        false => format!(
224            r#"{{"jsonrpc":"2.0","method":{},"params":{{"subscription":"{}","result":{}}}}}"#,
225            method, id, msg.value,
226        ),
227        true => format!(
228            r#"{{"jsonrpc":"2.0","method":{},"params":{{"subscription":"{}","error":{}}}}}"#,
229            method, id, msg.value,
230        ),
231    }
232}
233
234pin_project_lite::pin_project! {
235    struct TerminateAfterOneError<S> {
236        #[pin]
237        inner: S,
238        has_error: bool,
239    }
240}
241
242impl<S, T> Stream for TerminateAfterOneError<S>
243where
244    S: Stream<Item = PublishMsg<T>>,
245{
246    type Item = PublishMsg<T>;
247
248    fn poll_next(
249        self: std::pin::Pin<&mut Self>,
250        cx: &mut std::task::Context<'_>,
251    ) -> std::task::Poll<Option<Self::Item>> {
252        if self.has_error {
253            return None.into();
254        }
255        let proj = self.project();
256        match futures_core::ready!(proj.inner.poll_next(cx)) {
257            None => None.into(),
258            Some(msg) => {
259                if msg.is_err {
260                    *proj.has_error = true;
261                }
262                Some(msg).into()
263            }
264        }
265    }
266}
267
268fn terminate_after_one_error<S>(s: S) -> TerminateAfterOneError<S> {
269    TerminateAfterOneError {
270        inner: s,
271        has_error: false,
272    }
273}
274
#[cfg(test)]
mod tests {
    use async_stream::stream;
    use jsonrpc_core::{Call, Id, MethodCall, Output, Version};
    use tokio::sync::mpsc::channel;

    use super::*;

    #[test]
    fn test_id() {
        let id = generate_id();
        // `0x` followed by 16 random bytes rendered as 32 hex digits.
        assert_eq!(id.len(), 34);
        assert!(id.starts_with("0x"));
        assert!(id[2..].bytes().all(|b| b.is_ascii_hexdigit()));
        // 128 random bits: two consecutive ids should never collide.
        assert_ne!(id, generate_id());
    }

    #[tokio::test]
    async fn test_pubsub() {
        let mut rpc = MetaIoHandler::with_compatibility(jsonrpc_core::Compatibility::V2);
        add_pub_sub(&mut rpc, "sub", "notify", "unsub", |_params| {
            Ok(stream! {
                yield PublishMsg::result(&1);
                yield PublishMsg::result(&1);
            })
        });
        let (raw_tx, mut rx) = channel(1);
        let response = rpc
            .handle_call(
                Call::MethodCall(MethodCall {
                    jsonrpc: Some(Version::V2),
                    method: "sub".into(),
                    params: Params::None,
                    id: Id::Num(1),
                }),
                Some(Session {
                    raw_tx: raw_tx.clone(),
                    id: 1,
                }),
            )
            .await
            .unwrap();
        let sub_id = match response {
            Output::Success(s) => s.result,
            _ => unreachable!(),
        };

        assert!(rx.recv().await.is_some());

        // Unsubscribing from a different session should fail.
        let response = rpc
            .handle_call(
                Call::MethodCall(MethodCall {
                    jsonrpc: Some(Version::V2),
                    method: "unsub".into(),
                    params: Params::Array(vec![sub_id.clone()]),
                    id: Id::Num(2),
                }),
                Some(Session {
                    raw_tx: raw_tx.clone(),
                    id: 2,
                }),
            )
            .await
            .unwrap();
        let result = match response {
            Output::Success(s) => s.result,
            _ => unreachable!(),
        };
        assert!(!result.as_bool().unwrap());

        // Unsubscribing from the session that subscribed should succeed.
        let response = rpc
            .handle_call(
                Call::MethodCall(MethodCall {
                    jsonrpc: Some(Version::V2),
                    method: "unsub".into(),
                    params: Params::Array(vec![sub_id.clone()]),
                    id: Id::Num(3),
                }),
                Some(Session { raw_tx, id: 1 }),
            )
            .await
            .unwrap();
        let result = match response {
            Output::Success(s) => s.result,
            _ => unreachable!(),
        };
        assert!(result.as_bool().unwrap());
    }

    #[tokio::test]
    async fn test_terminate_after_one_error() {
        let s = terminate_after_one_error(futures_util::stream::iter([
            PublishMsg::<u64>::result_raw_json(""),
            PublishMsg::error_raw_json(""),
            PublishMsg::result_raw_json(""),
        ]));
        assert_eq!(s.count().await, 2);
    }

    #[test]
    fn test_format_message() {
        let msg = format_msg(
            "id",
            &serde_json::to_string("notification").unwrap(),
            PublishMsg::result(&3u64),
        );
        let msg: serde_json::Value = serde_json::from_str(&msg).unwrap();
        assert_eq!(msg["method"].as_str(), Some("notification"));
        assert_eq!(msg["params"]["subscription"].as_str(), Some("id"));
        assert_eq!(msg["params"]["result"].as_u64(), Some(3));
    }
}