1use std::{ops::Deref, sync::atomic::AtomicU64, time::Duration};
7
8use futures_core::{future::BoxFuture, Future, Stream};
9use futures_util::{
10 future::{self, Shared},
11 FutureExt, Sink, SinkExt, StreamExt,
12};
13use jsonrpc_core::{MetaIoHandler, Metadata};
14use tokio::{sync::mpsc::channel, time::Instant};
15
16use crate::pub_sub::Session;
17
18#[derive(Clone)]
19pub struct StreamServerConfig {
20 pub(crate) channel_size: usize,
21 pub(crate) pipeline_size: usize,
22 pub(crate) keep_alive: bool,
23 pub(crate) keep_alive_duration: Duration,
24 pub(crate) ping_interval: Duration,
25 pub(crate) shutdown_signal: Shared<BoxFuture<'static, ()>>,
26}
27
28impl Default for StreamServerConfig {
29 fn default() -> Self {
30 Self {
31 channel_size: 8,
32 pipeline_size: 1,
33 keep_alive: false,
34 keep_alive_duration: Duration::from_secs(60),
35 ping_interval: Duration::from_secs(19),
36 shutdown_signal: future::pending().boxed().shared(),
37 }
38 }
39}
40
41impl StreamServerConfig {
42 pub fn with_channel_size(mut self, channel_size: usize) -> Self {
50 assert!(channel_size > 0);
51 self.channel_size = channel_size;
52 self
53 }
54
55 pub fn with_pipeline_size(mut self, pipeline_size: usize) -> Self {
65 assert!(pipeline_size > 0);
66 self.pipeline_size = pipeline_size;
67 self
68 }
69
70 pub fn with_keep_alive(mut self, keep_alive: bool) -> Self {
74 self.keep_alive = keep_alive;
75 self
76 }
77
78 pub fn with_keep_alive_duration(mut self, keep_alive_duration: Duration) -> Self {
83 self.keep_alive_duration = keep_alive_duration;
84 self
85 }
86
87 pub fn with_ping_interval(mut self, ping_interval: Duration) -> Self {
91 self.ping_interval = ping_interval;
92 self
93 }
94
95 pub fn with_shutdown<S>(mut self, shutdown: S) -> StreamServerConfig
96 where
97 S: Future<Output = ()> + Send + 'static,
98 {
99 self.shutdown_signal = shutdown.boxed().shared();
100 self
101 }
102}
103
/// A single frame exchanged over the stream served by [`serve_stream_sink`].
///
/// `S` is the text type carrying the payload. `serve_stream_sink` requires it
/// to be constructible from `String` and to dereference to `str`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StreamMsg<S> {
    /// Text payload. Incoming ones are passed to the JSON-RPC handler; outgoing
    /// ones carry responses or notifications.
    Str(S),
    /// Keep-alive probe; the server answers it with [`StreamMsg::Pong`].
    Ping,
    /// Reply to a [`StreamMsg::Ping`]; the server produces no output for it.
    Pong,
}
114
/// Serves JSON-RPC over a message stream: reads frames from `stream` and
/// writes responses and notifications to `sink` until the session ends.
///
/// Each call creates a new [`Session`]:
/// - its id comes from a process-wide counter;
/// - its `raw_tx` is the sender of a channel with capacity
///   `config.channel_size`.
///
/// Every `String` received on that channel is forwarded to `sink` as
/// [`StreamMsg::Str`].
///
/// Incoming frames are handled as follows:
/// - `StreamMsg::Str`: passed to `rpc.handle_request`. A response, if one is
///   produced, is sent back as `StreamMsg::Str`.
/// - `StreamMsg::Ping`: answered with `StreamMsg::Pong`.
/// - `StreamMsg::Pong`: produces no output.
///
/// Up to `config.pipeline_size` incoming frames are processed concurrently,
/// and their outputs may be written out of order.
///
/// When `config.keep_alive` is set:
/// - a `StreamMsg::Ping` is sent every `config.ping_interval`;
/// - the session ends if no incoming frame has been handled for
///   `config.keep_alive_duration`.
///
/// Notifications still queued in the channel when the loop exits are not
/// forwarded, and `sink` is not closed on return.
///
/// # Errors
///
/// Returns the first error yielded by `stream` or returned by `sink`.
/// Returns `Ok(())` when any of these happens:
/// - `stream` ends;
/// - the keep-alive timer expires;
/// - `config.shutdown_signal` resolves.
///
/// # Panics
///
/// NOTE(review): `tokio::time::interval` panics on a zero period. It is
/// called here unconditionally, so a zero `config.ping_interval` panics even
/// with keep-alive disabled.
pub async fn serve_stream_sink<E, T, S>(
    rpc: &MetaIoHandler<T>,
    mut sink: impl Sink<StreamMsg<S>, Error = E> + Unpin,
    stream: impl Stream<Item = Result<StreamMsg<S>, E>> + Unpin,
    config: StreamServerConfig,
) -> Result<(), E>
where
    T: Metadata + From<Session>,
    S: From<String> + Deref<Target = str>,
{
    // Shared by all calls so every session gets a distinct id.
    static SESSION_ID: AtomicU64 = AtomicU64::new(0);

    // `tx` goes into the session (and every clone handed to RPC handlers);
    // `rx` is drained by the loop below to forward notifications.
    let (tx, mut rx) = channel(config.channel_size);
    let session = Session {
        id: SESSION_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed),
        raw_tx: tx,
    };

    // Idle deadline; only polled when keep-alive is enabled and pushed back
    // each time an incoming frame has been handled.
    let dead_timer = tokio::time::sleep(config.keep_alive_duration);
    tokio::pin!(dead_timer);
    let mut ping_interval = tokio::time::interval(config.ping_interval);
    // An interval's first tick completes immediately; resetting it makes the
    // first ping go out one full period after the session starts.
    ping_interval.reset();
    // After a missed tick (e.g. a slow `sink.send`), wait a full period
    // instead of sending a burst of catch-up pings.
    ping_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);

    // Turn each incoming frame into a future yielding an optional outgoing
    // frame, running up to `pipeline_size` of them at once.
    let mut result_stream = stream
        .map(|message_or_err| async {
            let msg = message_or_err?;
            match msg {
                StreamMsg::Str(msg) => Ok(rpc
                    .handle_request(&msg, session.clone().into())
                    .await
                    .map(|res| StreamMsg::Str(res.into()))),
                StreamMsg::Ping => Ok(Some(StreamMsg::Pong)),
                StreamMsg::Pong => Ok(None),
            }
        })
        .buffer_unordered(config.pipeline_size);
    let mut shutdown = config.shutdown_signal;
    loop {
        tokio::select! {
            // Poll branches in the order written: handled requests first, then
            // notifications, pings, keep-alive expiry and finally shutdown.
            biased;
            result = result_stream.next() => {
                match result {
                    Some(result) => {
                        if let Some(s) = result? {
                            sink.send(s).await?;
                        }
                        if config.keep_alive {
                            // Any handled frame (including a Pong) counts as activity.
                            dead_timer
                                .as_mut()
                                .reset(Instant::now() + config.keep_alive_duration);
                        }
                    }
                    None => {
                        // Input stream ended: the session is over.
                        break;
                    }
                }
            }
            // Notifications pushed by RPC handlers through `session.raw_tx`.
            Some(msg) = rx.recv() => {
                sink.send(StreamMsg::Str(msg.into())).await?;
            }
            _ = ping_interval.tick(), if config.keep_alive => {
                sink.send(StreamMsg::Ping).await?;
            }
            _ = &mut dead_timer, if config.keep_alive => {
                break;
            }
            _ = &mut shutdown => {
                break;
            }
        }
    }
    Ok(())
}
203
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::stream;

    /// A lone `Ping` must be answered with exactly one `Pong`.
    #[tokio::test]
    async fn test_ping_pong() {
        let handler = MetaIoHandler::<Session>::default();
        let mut output: Vec<StreamMsg<String>> = Vec::new();
        let input = stream::iter([Ok(StreamMsg::Ping)]);

        let outcome =
            serve_stream_sink(&handler, &mut output, input, StreamServerConfig::default()).await;

        assert!(outcome.is_ok());
        assert_eq!(output, [StreamMsg::Pong]);
    }

    /// A handler that pushes a notification through `raw_tx` must have both
    /// its RPC response and the notification written to the sink.
    #[tokio::test]
    async fn test_subscription() {
        let mut handler = MetaIoHandler::default();
        handler.add_method_with_meta("subscribe", |_params, session: Session| async move {
            session
                .raw_tx
                .send("subscription_data".to_string())
                .await
                .unwrap();
            Ok(serde_json::Value::String("ok".to_string()))
        });

        let requests = async_stream::stream! {
            yield Ok(StreamMsg::Str(
                r#"{"jsonrpc":"2.0","method":"subscribe","params":[],"id":1}"#.to_string(),
            ));
            // Keep the input open so the queued notification is forwarded
            // before the stream ends and the server returns.
            tokio::time::sleep(Duration::from_secs(1)).await;
        };
        tokio::pin!(requests);
        let mut output: Vec<StreamMsg<String>> = Vec::new();

        let outcome =
            serve_stream_sink(&handler, &mut output, requests, StreamServerConfig::default())
                .await;

        assert!(outcome.is_ok());
        assert_eq!(output.len(), 2);

        let first_text = match &output[0] {
            StreamMsg::Str(text) => text,
            unexpected => panic!("Expected StreamMsg::Str, got: {:?}", unexpected),
        };
        let response: serde_json::Value = serde_json::from_str(first_text).unwrap();
        assert_eq!(
            response,
            serde_json::json!({ "jsonrpc": "2.0", "result": "ok", "id": 1 })
        );
        assert_eq!(output[1], StreamMsg::Str("subscription_data".to_string()));
    }
}