jsonrpc_utils/
stream.rs

//! JSON-RPC server on any stream, e.g. TCP or Unix socket.
//!
//! Use `tokio_util::codec` to convert `AsyncRead`/`AsyncWrite` into a `Stream`
//! and a `Sink`. Use `LinesCodec` or define your own codec.
5
6use std::{ops::Deref, sync::atomic::AtomicU64, time::Duration};
7
8use futures_core::{future::BoxFuture, Future, Stream};
9use futures_util::{
10    future::{self, Shared},
11    FutureExt, Sink, SinkExt, StreamExt,
12};
13use jsonrpc_core::{MetaIoHandler, Metadata};
14use tokio::{sync::mpsc::channel, time::Instant};
15
16use crate::pub_sub::Session;
17
18#[derive(Clone)]
19pub struct StreamServerConfig {
20    pub(crate) channel_size: usize,
21    pub(crate) pipeline_size: usize,
22    pub(crate) keep_alive: bool,
23    pub(crate) keep_alive_duration: Duration,
24    pub(crate) ping_interval: Duration,
25    pub(crate) shutdown_signal: Shared<BoxFuture<'static, ()>>,
26}
27
28impl Default for StreamServerConfig {
29    fn default() -> Self {
30        Self {
31            channel_size: 8,
32            pipeline_size: 1,
33            keep_alive: false,
34            keep_alive_duration: Duration::from_secs(60),
35            ping_interval: Duration::from_secs(19),
36            shutdown_signal: future::pending().boxed().shared(),
37        }
38    }
39}
40
41impl StreamServerConfig {
42    /// Set pub-sub channel buffer size.
43    ///
44    /// Default is 8.
45    ///
46    /// # Panics
47    ///
48    /// If channel_size is 0.
49    pub fn with_channel_size(mut self, channel_size: usize) -> Self {
50        assert!(channel_size > 0);
51        self.channel_size = channel_size;
52        self
53    }
54
55    /// Set maximum request pipelining.
56    ///
57    /// Up to `pipeline_size` number of requests will be handled concurrently.
58    ///
59    /// Default is 1, i.e. no pipelining.
60    ///
61    /// # Panics
62    ///
63    /// if `pipeline_size` is 0.
64    pub fn with_pipeline_size(mut self, pipeline_size: usize) -> Self {
65        assert!(pipeline_size > 0);
66        self.pipeline_size = pipeline_size;
67        self
68    }
69
70    /// Set whether keep alive is enabled.
71    ///
72    /// Default is false.
73    pub fn with_keep_alive(mut self, keep_alive: bool) -> Self {
74        self.keep_alive = keep_alive;
75        self
76    }
77
78    /// Wait for `keep_alive_duration` after the last message is received, then
79    /// close the connection.
80    ///
81    /// Default is 60 seconds.
82    pub fn with_keep_alive_duration(mut self, keep_alive_duration: Duration) -> Self {
83        self.keep_alive_duration = keep_alive_duration;
84        self
85    }
86
87    /// Set interval to send ping messages.
88    ///
89    /// Default is 19 seconds.
90    pub fn with_ping_interval(mut self, ping_interval: Duration) -> Self {
91        self.ping_interval = ping_interval;
92        self
93    }
94
95    pub fn with_shutdown<S>(mut self, shutdown: S) -> StreamServerConfig
96    where
97        S: Future<Output = ()> + Send + 'static,
98    {
99        self.shutdown_signal = shutdown.boxed().shared();
100        self
101    }
102}
103
/// A message exchanged by streaming JSON-RPC servers: either a textual
/// request/response or a keepalive control frame.
///
/// `S` should be some string-like type, e.g. `String` for TCP or `Utf8Bytes`
/// for WebSocket.
#[derive(Debug, PartialEq, Eq)]
pub enum StreamMsg<S> {
    /// A JSON-RPC text payload.
    Str(S),
    /// A keepalive ping, answered with [`StreamMsg::Pong`].
    Ping,
    /// The answer to a [`StreamMsg::Ping`].
    Pong,
}
114
115/// Serve JSON-RPC requests over a bidirectional stream (Stream + Sink).
116///
117/// # Keepalive
118///
119/// We will response to ping messages with pong messages. We will send out ping
120/// messages at the specified interval if keepalive is enabled. If keepalive is
121/// enabled and we don't receive any messages over the stream for
122/// `keep_alive_duration`, we will stop serving (and this function will return).
123pub async fn serve_stream_sink<E, T, S>(
124    rpc: &MetaIoHandler<T>,
125    mut sink: impl Sink<StreamMsg<S>, Error = E> + Unpin,
126    stream: impl Stream<Item = Result<StreamMsg<S>, E>> + Unpin,
127    config: StreamServerConfig,
128) -> Result<(), E>
129where
130    T: Metadata + From<Session>,
131    S: From<String> + Deref<Target = str>,
132{
133    static SESSION_ID: AtomicU64 = AtomicU64::new(0);
134
135    let (tx, mut rx) = channel(config.channel_size);
136    let session = Session {
137        id: SESSION_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed),
138        raw_tx: tx,
139    };
140
141    let dead_timer = tokio::time::sleep(config.keep_alive_duration);
142    tokio::pin!(dead_timer);
143    let mut ping_interval = tokio::time::interval(config.ping_interval);
144    ping_interval.reset();
145    ping_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
146
147    let mut result_stream = stream
148        .map(|message_or_err| async {
149            let msg = message_or_err?;
150            match msg {
151                StreamMsg::Str(msg) => Ok(rpc
152                    .handle_request(&msg, session.clone().into())
153                    .await
154                    .map(|res| StreamMsg::Str(res.into()))),
155                StreamMsg::Ping => Ok(Some(StreamMsg::Pong)),
156                StreamMsg::Pong => Ok(None),
157            }
158        })
159        .buffer_unordered(config.pipeline_size);
160    let mut shutdown = config.shutdown_signal;
161    loop {
162        tokio::select! {
163            biased;
164            // Response/pong messages.
165            result = result_stream.next() => {
166                match result {
167                    Some(result) => {
168                        // Stop serving if the stream returns an error.
169                        if let Some(s) = result? {
170                            sink.send(s).await?;
171                        }
172                        // Reset the keepalive timer if we have received anything from the stream.
173                        // Ordinary messages as well as pings and pongs will all reset the timer.
174                        if config.keep_alive {
175                            dead_timer
176                                .as_mut()
177                                .reset(Instant::now() + config.keep_alive_duration);
178                        }
179                    }
180                    // Stop serving if the stream ends.
181                    None => {
182                        break;
183                    }
184                }
185            }
186            // Subscritpion response messages. This will never be None.
187            Some(msg) = rx.recv() => {
188                sink.send(StreamMsg::Str(msg.into())).await?;
189            }
190            _ = ping_interval.tick(), if config.keep_alive => {
191                sink.send(StreamMsg::Ping).await?;
192            }
193            _ = &mut dead_timer, if config.keep_alive => {
194                break;
195            }
196            _ = &mut shutdown => {
197                break;
198            }
199        }
200    }
201    Ok(())
202}
203
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::stream;

    /// A ping on the input stream is answered with exactly one pong.
    #[tokio::test]
    async fn test_ping_pong() {
        let handler = MetaIoHandler::<Session>::default();
        let incoming = stream::iter([Ok(StreamMsg::Ping)]);
        let mut outgoing: Vec<StreamMsg<String>> = Vec::new();

        let outcome =
            serve_stream_sink(&handler, &mut outgoing, incoming, StreamServerConfig::default())
                .await;

        assert!(outcome.is_ok());
        assert_eq!(outgoing, [StreamMsg::Pong]);
    }

    /// The "ok" response of a request is written before subscription data
    /// that the request queued while it was being handled.
    #[tokio::test]
    async fn test_subscription() {
        let mut handler = MetaIoHandler::default();
        // `add_method_with_meta` is used instead of our `add_pub_sub` so that the
        // subscription data is guaranteed to be queued before the subscription
        // "ok" response is produced.
        handler.add_method_with_meta("subscribe", |_params, session: Session| async move {
            // Queue a subscription message on the session channel.
            session
                .raw_tx
                .send("subscription_data".to_string())
                .await
                .unwrap();
            Ok(serde_json::Value::String("ok".to_string()))
        });

        let incoming = async_stream::stream! {
            yield Ok(StreamMsg::Str(
                r#"{"jsonrpc":"2.0","method":"subscribe","params":[],"id":1}"#.to_string(),
            ));
            // Keep the stream open so the server can forward the queued
            // subscription data before the stream ends.
            tokio::time::sleep(Duration::from_secs(1)).await;
        };
        tokio::pin!(incoming);
        let mut outgoing: Vec<StreamMsg<String>> = Vec::new();

        let outcome =
            serve_stream_sink(&handler, &mut outgoing, incoming, StreamServerConfig::default())
                .await;

        assert!(outcome.is_ok());
        assert_eq!(outgoing.len(), 2);

        // First the "ok" response, then the subscription data.
        let first_payload = match &outgoing[0] {
            StreamMsg::Str(text) => text,
            _ => panic!("Expected StreamMsg::Str, got: {:?}", outgoing[0]),
        };
        let first_json = serde_json::from_str::<serde_json::Value>(first_payload).unwrap();
        assert_eq!(
            first_json,
            serde_json::json!({ "jsonrpc": "2.0", "result": "ok", "id": 1 })
        );
        assert_eq!(
            outgoing[1],
            StreamMsg::Str("subscription_data".to_string())
        );
    }
}