1use std::{sync::atomic::AtomicU64, time::Duration};
7
8use futures_core::{future::BoxFuture, Future, Stream};
9use futures_util::{
10 future::{self, Shared},
11 FutureExt, Sink, SinkExt, StreamExt,
12};
13use jsonrpc_core::{MetaIoHandler, Metadata};
14use tokio::{sync::mpsc::channel, time::Instant};
15
16use crate::pub_sub::Session;
17
18#[derive(Clone)]
19pub struct StreamServerConfig {
20 pub(crate) channel_size: usize,
21 pub(crate) pipeline_size: usize,
22 pub(crate) keep_alive: bool,
23 pub(crate) keep_alive_duration: Duration,
24 pub(crate) ping_interval: Duration,
25 pub(crate) shutdown_signal: Shared<BoxFuture<'static, ()>>,
26}
27
28impl Default for StreamServerConfig {
29 fn default() -> Self {
30 Self {
31 channel_size: 8,
32 pipeline_size: 1,
33 keep_alive: false,
34 keep_alive_duration: Duration::from_secs(60),
35 ping_interval: Duration::from_secs(19),
36 shutdown_signal: future::pending().boxed().shared(),
37 }
38 }
39}
40
41impl StreamServerConfig {
42 pub fn with_channel_size(mut self, channel_size: usize) -> Self {
50 assert!(channel_size > 0);
51 self.channel_size = channel_size;
52 self
53 }
54
55 pub fn with_pipeline_size(mut self, pipeline_size: usize) -> Self {
65 assert!(pipeline_size > 0);
66 self.pipeline_size = pipeline_size;
67 self
68 }
69
70 pub fn with_keep_alive(mut self, keep_alive: bool) -> Self {
74 self.keep_alive = keep_alive;
75 self
76 }
77
78 pub fn with_keep_alive_duration(mut self, keep_alive_duration: Duration) -> Self {
83 self.keep_alive_duration = keep_alive_duration;
84 self
85 }
86
87 pub fn with_ping_interval(mut self, ping_interval: Duration) -> Self {
91 self.ping_interval = ping_interval;
92 self
93 }
94
95 pub fn with_shutdown<S>(mut self, shutdown: S) -> StreamServerConfig
96 where
97 S: Future<Output = ()> + Send + 'static,
98 {
99 self.shutdown_signal = shutdown.boxed().shared();
100 self
101 }
102}
103
/// A message exchanged over the stream / sink pair driven by
/// [`serve_stream_sink`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StreamMsg {
    /// A text payload: an incoming request, or an outgoing response or
    /// session-pushed message.
    Str(String),
    /// Keep-alive probe. The server emits it periodically when keep-alive is
    /// enabled.
    Ping,
    /// Keep-alive reply. Like an incoming `Ping`, it is not dispatched to the
    /// RPC handler.
    Pong,
}
109
110pub async fn serve_stream_sink<E, T: Metadata + From<Session>>(
116 rpc: &MetaIoHandler<T>,
117 mut sink: impl Sink<StreamMsg, Error = E> + Unpin,
118 stream: impl Stream<Item = Result<StreamMsg, E>> + Unpin,
119 config: StreamServerConfig,
120) -> Result<(), E> {
121 static SESSION_ID: AtomicU64 = AtomicU64::new(0);
122
123 let (tx, mut rx) = channel(config.channel_size);
124 let session = Session {
125 id: SESSION_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed),
126 raw_tx: tx,
127 };
128
129 let dead_timer = tokio::time::sleep(config.keep_alive_duration);
130 tokio::pin!(dead_timer);
131 let mut ping_interval = tokio::time::interval(config.ping_interval);
132 ping_interval.reset();
133 ping_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
134
135 let mut result_stream = stream
136 .map(|message_or_err| async {
137 let msg = message_or_err?;
138 let msg = match msg {
139 StreamMsg::Str(msg) => msg,
140 _ => return Ok(None),
141 };
142 Ok::<_, E>(rpc.handle_request(&msg, session.clone().into()).await)
143 })
144 .buffer_unordered(config.pipeline_size);
145 let mut shutdown = config.shutdown_signal;
146 loop {
147 tokio::select! {
148 biased;
152 result = result_stream.next() => {
153 match result {
154 Some(result) => {
155 if let Some(result) = result? {
156 sink.send(StreamMsg::Str(result)).await?;
157 }
158 if config.keep_alive {
159 dead_timer
160 .as_mut()
161 .reset(Instant::now() + config.keep_alive_duration);
162 }
163 }
164 _ => break,
165 }
166 }
167 Some(msg) = rx.recv() => {
169 sink.send(StreamMsg::Str(msg)).await?;
170 }
171 _ = ping_interval.tick(), if config.keep_alive => {
172 sink.send(StreamMsg::Ping).await?;
173 }
174 _ = &mut dead_timer, if config.keep_alive => {
175 break;
176 }
177 _ = &mut shutdown => {
178 break;
179 }
180 }
181 }
182 Ok(())
183}