1use std::{
4 collections::HashMap,
5 marker::PhantomData,
6 sync::{Arc, Mutex},
7};
8
9use futures_core::Stream;
10use futures_util::StreamExt;
11use jsonrpc_core::{serde::Serialize, MetaIoHandler, Metadata, Params, Value};
12use rand::{thread_rng, Rng};
13use tokio::sync::mpsc::Sender;
14
15#[derive(Clone)]
19pub struct Session {
20 pub raw_tx: Sender<String>,
21 pub id: u64,
22}
23
// Marker impl so that `Session` can be used as handler metadata
// (`add_pub_sub` registers its methods on a `MetaIoHandler<Option<Session>>`).
impl Metadata for Session {}
25
26fn generate_id() -> String {
27 let id: [u8; 16] = thread_rng().gen();
28 let mut id_hex_bytes = vec![0u8; 34];
29 id_hex_bytes[..2].copy_from_slice(b"0x");
30 hex::encode_to_slice(id, &mut id_hex_bytes[2..]).unwrap();
31 unsafe { String::from_utf8_unchecked(id_hex_bytes) }
32}
33
/// A message to publish to a subscriber: either a result or an error, held as
/// an already-serialized JSON string.
///
/// `T` is the type of the result payload. It exists only at the type level; no
/// value of type `T` is stored.
pub struct PublishMsg<T> {
    // `true` when `value` is an error payload rather than a result payload.
    is_err: bool,
    // Serialized JSON. `Arc<str>` makes cloning a reference-count bump instead
    // of a copy of the text.
    value: Arc<str>,
    phantom: PhantomData<T>,
}

// Implemented by hand because `#[derive(Clone)]` would add a `T: Clone` bound,
// which is unnecessary: only the flag and the shared JSON text are cloned.
impl<T> Clone for PublishMsg<T> {
    fn clone(&self) -> Self {
        Self {
            is_err: self.is_err,
            value: Arc::clone(&self.value),
            phantom: PhantomData,
        }
    }
}
42
43impl<T: Serialize> PublishMsg<T> {
44 pub fn result(value: &T) -> Self {
48 match jsonrpc_core::serde_json::to_string(value) {
49 Ok(value) => Self {
50 is_err: false,
51 value: value.into(),
52 phantom: PhantomData,
53 },
54 Err(_) => Self::error(&jsonrpc_core::Error {
55 code: jsonrpc_core::ErrorCode::InternalError,
56 message: "".into(),
57 data: None,
58 }),
59 }
60 }
61}
62
63impl<T> PublishMsg<T> {
64 pub fn error(err: &jsonrpc_core::Error) -> Self {
70 Self {
71 is_err: true,
72 value: jsonrpc_core::serde_json::to_string(err).unwrap().into(),
73 phantom: PhantomData,
74 }
75 }
76
77 pub fn result_raw_json(value: impl Into<Arc<str>>) -> Self {
81 Self {
82 is_err: false,
83 value: value.into(),
84 phantom: PhantomData,
85 }
86 }
87
88 pub fn error_raw_json(value: impl Into<Arc<str>>) -> Self {
92 Self {
93 is_err: true,
94 value: value.into(),
95 phantom: PhantomData,
96 }
97 }
98}
99
/// A source of subscription streams.
///
/// `add_pub_sub` calls [`subscribe`](PubSub::subscribe) for every subscribe
/// request and forwards the messages of the returned stream to the subscriber.
pub trait PubSub<T> {
    /// Stream of messages published to a single subscription.
    type Stream: Stream<Item = PublishMsg<T>> + Send;

    /// Starts a subscription for the given request parameters.
    ///
    /// Returns the stream of messages to publish, or a JSON-RPC error that is
    /// reported as the response to the subscribe request.
    fn subscribe(&self, params: Params) -> Result<Self::Stream, jsonrpc_core::Error>;
}
112
113impl<T, F, S> PubSub<T> for F
114where
115 F: Fn(Params) -> Result<S, jsonrpc_core::Error>,
116 S: Stream<Item = PublishMsg<T>> + Send,
117{
118 type Stream = S;
119
120 fn subscribe(&self, params: Params) -> Result<Self::Stream, jsonrpc_core::Error> {
121 (self)(params)
122 }
123}
124
125impl<T, P: PubSub<T>> PubSub<T> for Arc<P> {
126 type Stream = P::Stream;
127
128 fn subscribe(&self, params: Params) -> Result<Self::Stream, jsonrpc_core::Error> {
129 <P as PubSub<T>>::subscribe(self, params)
130 }
131}
132
133pub fn add_pub_sub<T: Send + 'static>(
141 io: &mut MetaIoHandler<Option<Session>>,
142 subscribe_method: &str,
143 notify_method: &str,
144 unsubscribe_method: &str,
145 pubsub: impl PubSub<T> + Clone + Send + Sync + 'static,
146) {
147 let subscriptions0 = Arc::new(Mutex::new(HashMap::new()));
148 let subscriptions = subscriptions0.clone();
149 let notify_method: Arc<str> = serde_json::to_string(notify_method).unwrap().into();
150 io.add_method_with_meta(
151 subscribe_method,
152 move |params: Params, session: Option<Session>| {
153 let subscriptions = subscriptions.clone();
154 let pubsub = pubsub.clone();
155 let notify_method = notify_method.clone();
156 async move {
157 let session = session.ok_or_else(jsonrpc_core::Error::method_not_found)?;
158 let session_id = session.id;
159 let id = generate_id();
160 let stream = pubsub.subscribe(params)?;
161 let stream = terminate_after_one_error(stream);
162 let handle = tokio::spawn({
163 let id = id.clone();
164 let subscriptions = subscriptions.clone();
165 async move {
166 tokio::pin!(stream);
167 loop {
168 tokio::select! {
169 biased;
170 msg = stream.next() => {
171 match msg {
172 Some(msg) => {
173 let msg = format_msg(&id, ¬ify_method, msg);
174 if session.raw_tx.send(msg).await.is_err() {
175 break;
176 }
177 }
178 None => break,
179 }
180 }
181 _ = session.raw_tx.closed() => {
182 break;
183 }
184 }
185 }
186 subscriptions.lock().unwrap().remove(&(session_id, id));
187 }
188 });
189 subscriptions
190 .lock()
191 .unwrap()
192 .insert((session_id, id.clone()), handle);
193 Ok(Value::String(id))
194 }
195 },
196 );
197 io.add_method_with_meta(
198 unsubscribe_method,
199 move |params: Params, session: Option<Session>| {
200 let subscriptions = subscriptions0.clone();
201 async move {
202 let (id,): (String,) = params.parse()?;
203 let session_id = if let Some(session) = session {
204 session.id
205 } else {
206 return Ok(Value::Bool(false));
207 };
208 let result =
209 if let Some(handle) = subscriptions.lock().unwrap().remove(&(session_id, id)) {
210 handle.abort();
211 true
212 } else {
213 false
214 };
215 Ok(Value::Bool(result))
216 }
217 },
218 );
219}
220
221fn format_msg<T>(id: &str, method: &str, msg: PublishMsg<T>) -> String {
222 match msg.is_err {
223 false => format!(
224 r#"{{"jsonrpc":"2.0","method":{},"params":{{"subscription":"{}","result":{}}}}}"#,
225 method, id, msg.value,
226 ),
227 true => format!(
228 r#"{{"jsonrpc":"2.0","method":{},"params":{{"subscription":"{}","error":{}}}}}"#,
229 method, id, msg.value,
230 ),
231 }
232}
233
pin_project_lite::pin_project! {
    /// Stream adapter that ends the stream after an error message has been
    /// yielded; see the `Stream` implementation for the exact behavior.
    struct TerminateAfterOneError<S> {
        // The wrapped stream, projected as pinned so it can be polled in place.
        #[pin]
        inner: S,
        // Set when an error message has been yielded; from then on `poll_next`
        // reports the end of the stream without polling `inner`.
        has_error: bool,
    }
}
241
242impl<S, T> Stream for TerminateAfterOneError<S>
243where
244 S: Stream<Item = PublishMsg<T>>,
245{
246 type Item = PublishMsg<T>;
247
248 fn poll_next(
249 self: std::pin::Pin<&mut Self>,
250 cx: &mut std::task::Context<'_>,
251 ) -> std::task::Poll<Option<Self::Item>> {
252 if self.has_error {
253 return None.into();
254 }
255 let proj = self.project();
256 match futures_core::ready!(proj.inner.poll_next(cx)) {
257 None => None.into(),
258 Some(msg) => {
259 if msg.is_err {
260 *proj.has_error = true;
261 }
262 Some(msg).into()
263 }
264 }
265 }
266}
267
268fn terminate_after_one_error<S>(s: S) -> TerminateAfterOneError<S> {
269 TerminateAfterOneError {
270 inner: s,
271 has_error: false,
272 }
273}
274
#[cfg(test)]
mod tests {
    use async_stream::stream;
    use jsonrpc_core::{Call, Id, MethodCall, Output, Version};
    use tokio::sync::mpsc::channel;

    use super::*;

    /// Generated ids are valid UTF-8 of the form `0x` + 32 hex digits.
    #[test]
    fn test_id() {
        let id = generate_id();
        assert!(std::str::from_utf8(id.as_bytes()).is_ok());
        assert_eq!(id.len(), 34);
        let digits = id.strip_prefix("0x").expect("id starts with `0x`");
        assert!(digits.bytes().all(|b| b.is_ascii_hexdigit()));
    }

    /// Subscribing delivers notifications, and only the session that created a
    /// subscription can cancel it.
    #[tokio::test]
    async fn test_pubsub() {
        let mut rpc = MetaIoHandler::with_compatibility(jsonrpc_core::Compatibility::V2);
        add_pub_sub(&mut rpc, "sub", "notify", "unsub", |_params| {
            Ok(stream! {
                yield PublishMsg::result(&1);
                yield PublishMsg::result(&1);
            })
        });
        let (raw_tx, mut rx) = channel(1);
        let response = rpc
            .handle_call(
                Call::MethodCall(MethodCall {
                    jsonrpc: Some(Version::V2),
                    method: "sub".into(),
                    params: Params::None,
                    id: Id::Num(1),
                }),
                Some(Session {
                    raw_tx: raw_tx.clone(),
                    id: 1,
                }),
            )
            .await
            .unwrap();
        let sub_id = match response {
            Output::Success(s) => s.result,
            _ => unreachable!(),
        };

        assert!(rx.recv().await.is_some());

        // Unsubscribing from a different session must not cancel the subscription.
        let response = rpc
            .handle_call(
                Call::MethodCall(MethodCall {
                    jsonrpc: Some(Version::V2),
                    method: "unsub".into(),
                    params: Params::Array(vec![sub_id.clone()]),
                    id: Id::Num(2),
                }),
                Some(Session {
                    raw_tx: raw_tx.clone(),
                    id: 2,
                }),
            )
            .await
            .unwrap();
        let result = match response {
            Output::Success(s) => s.result,
            _ => unreachable!(),
        };
        assert!(!result.as_bool().unwrap());

        // Unsubscribing from the owning session succeeds.
        let response = rpc
            .handle_call(
                Call::MethodCall(MethodCall {
                    jsonrpc: Some(Version::V2),
                    method: "unsub".into(),
                    params: Params::Array(vec![sub_id.clone()]),
                    id: Id::Num(3),
                }),
                Some(Session { raw_tx, id: 1 }),
            )
            .await
            .unwrap();
        let result = match response {
            Output::Success(s) => s.result,
            _ => unreachable!(),
        };
        assert!(result.as_bool().unwrap());
    }

    /// The error message itself is yielded, but nothing after it.
    #[tokio::test]
    async fn test_terminate_after_one_error() {
        let s = terminate_after_one_error(futures_util::stream::iter([
            PublishMsg::<u64>::result_raw_json(""),
            PublishMsg::error_raw_json(""),
            PublishMsg::result_raw_json(""),
        ]));
        assert_eq!(s.count().await, 2);
    }

    /// Both result and error messages are rendered as valid JSON notifications.
    #[test]
    fn test_format_message() {
        let msg = format_msg(
            "id",
            &serde_json::to_string("notification").unwrap(),
            PublishMsg::result(&3u64),
        );
        let msg: serde_json::Value = serde_json::from_str(&msg).unwrap();
        assert_eq!(msg["method"].as_str(), Some("notification"));
        assert_eq!(msg["params"]["subscription"].as_str(), Some("id"));
        assert_eq!(msg["params"]["result"].as_u64(), Some(3));

        let msg = format_msg(
            "id",
            &serde_json::to_string("notification").unwrap(),
            PublishMsg::<u64>::error_raw_json(r#"{"code":-32603,"message":"boom"}"#),
        );
        let msg: serde_json::Value = serde_json::from_str(&msg).unwrap();
        assert_eq!(msg["method"].as_str(), Some("notification"));
        assert_eq!(msg["params"]["subscription"].as_str(), Some("id"));
        assert!(msg["params"].get("result").is_none());
        assert_eq!(msg["params"]["error"]["code"].as_i64(), Some(-32603));
    }
}
385}