1use std::sync::Arc;
2use std::time::Instant;
3
4use crate::future::{IntervalStream, SessionClose};
5use crate::middleware::rpc::{RpcService, RpcServiceBuilder, RpcServiceCfg, RpcServiceT};
6use crate::server::{handle_rpc_call, ConnectionState, ServerConfig};
7use crate::{HttpBody, HttpRequest, HttpResponse, PingConfig, LOG_TARGET};
8
9use futures_util::future::{self, Either};
10use futures_util::io::{BufReader, BufWriter};
11use futures_util::{Future, StreamExt, TryStreamExt};
12use hyper::upgrade::Upgraded;
13use hyper_util::rt::TokioIo;
14use jsonrpsee_core::server::{BoundedSubscriptions, MethodSink, Methods};
15use jsonrpsee_types::error::{reject_too_big_request, ErrorCode};
16use jsonrpsee_types::Id;
17use soketto::connection::Error as SokettoError;
18use soketto::data::ByteSlice125;
19
20use tokio::sync::{mpsc, oneshot};
21use tokio::time::{interval, interval_at};
22use tokio_stream::wrappers::ReceiverStream;
23use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};
24
25pub(crate) type Sender = soketto::Sender<BufReader<BufWriter<Compat<TokioIo<Upgraded>>>>>;
26pub(crate) type Receiver = soketto::Receiver<BufReader<BufWriter<Compat<TokioIo<Upgraded>>>>>;
27
28pub use soketto::handshake::http::is_upgrade_request;
29
/// An item of interest read off the WebSocket by the receive loop.
enum Incoming {
    /// A pong frame was received; only used to track peer liveness.
    Pong,
    /// The raw bytes of a received data message.
    Data(Vec<u8>),
}
34
35pub(crate) async fn send_message(sender: &mut Sender, response: String) -> Result<(), SokettoError> {
36 sender.send_text_owned(response).await?;
37 sender.flush().await.map_err(Into::into)
38}
39
40pub(crate) async fn send_ping(sender: &mut Sender) -> Result<(), SokettoError> {
41 tracing::debug!(target: LOG_TARGET, "Send ping");
42 let slice: &[u8] = &[];
44 let byte_slice = ByteSlice125::try_from(slice).expect("Empty slice should fit into ByteSlice125");
46 sender.send_ping(byte_slice).await?;
47 sender.flush().await.map_err(Into::into)
48}
49
/// Everything `background_task` needs to serve one upgraded WebSocket connection.
pub(crate) struct BackgroundTaskParams<S> {
    /// Server settings; ping config, batch config and body-size limits are read from it.
    pub(crate) server_cfg: ServerConfig,
    /// Per-connection state; its `stop_handle` is used to detect server shutdown.
    pub(crate) conn: ConnectionState,
    /// Writing half of the WebSocket, handed over to the send task.
    pub(crate) ws_sender: Sender,
    /// Reading half of the WebSocket.
    pub(crate) ws_receiver: Receiver,
    /// The (possibly middleware-wrapped) RPC service that executes calls.
    pub(crate) rpc_service: S,
    /// Sink used to push responses and errors toward the send task.
    pub(crate) sink: MethodSink,
    /// Receiving end of the channel behind `sink`; drained by the send task.
    pub(crate) rx: mpsc::Receiver<String>,
    /// Channel used by graceful shutdown to wait for pending calls
    /// (presumably closes when the last pending call is done — confirm).
    pub(crate) pending_calls_completed: mpsc::Receiver<()>,
    /// Optional notifier closed after the connection has fully shut down.
    pub(crate) on_session_close: Option<SessionClose>,
    /// HTTP request extensions forwarded to every RPC call.
    pub(crate) extensions: http::Extensions,
}
62
/// Serves one upgraded WebSocket connection until it is closed, fails, or the server stops.
///
/// Outgoing traffic (responses from `rx` plus optional pings) is handled by a spawned
/// `send_task`. This function runs the receive side: each data message is parsed and
/// executed on its own spawned task, whose response is pushed into `sink`. Once the
/// receive loop exits, the connection is wound down with `graceful_shutdown`, `conn` is
/// dropped, and finally `on_session_close` (if set) is closed.
pub(crate) async fn background_task<S>(params: BackgroundTaskParams<S>)
where
    for<'a> S: RpcServiceT<'a> + Send + Sync + 'static,
{
    let BackgroundTaskParams {
        server_cfg,
        conn,
        ws_sender,
        ws_receiver,
        rpc_service,
        sink,
        rx,
        pending_calls_completed,
        mut on_session_close,
        extensions,
    } = params;
    let ServerConfig { ping_config, batch_requests_config, max_request_body_size, max_response_body_size, .. } =
        server_cfg;

    // One-shot used by `graceful_shutdown` to tell the send task to exit.
    let (conn_tx, conn_rx) = oneshot::channel();

    let send_task_handle = tokio::spawn(send_task(rx, ws_sender, ping_config, conn_rx));

    // Resolves when the server is asked to shut down.
    let stopped = conn.stop_handle.clone().shutdown();
    // Shared (via `Arc`) with every per-message task spawned below.
    let rpc_service = Arc::new(rpc_service);
    // Ping/pong failure counter; kept across `try_recv` calls for the whole connection.
    let mut missed_pings = 0;

    tokio::pin!(stopped);

    // Turn the soketto receiver into a stream of `Incoming` items. A close frame or a
    // `Closed` error ends the stream; other errors are yielded as items.
    let ws_stream = futures_util::stream::unfold(ws_receiver, |mut receiver| async {
        let mut data = Vec::new();
        match receiver.receive(&mut data).await {
            Ok(soketto::Incoming::Data(_)) => Some((Ok(Incoming::Data(data)), receiver)),
            Ok(soketto::Incoming::Pong(_)) => Some((Ok(Incoming::Pong), receiver)),
            Ok(soketto::Incoming::Closed(_)) | Err(SokettoError::Closed) => None,
            Err(e) => Some((Err(e), receiver)),
        }
    })
    .fuse();

    tokio::pin!(ws_stream);

    let result = loop {
        // `stopped` is moved into `try_recv` and handed back by the variants that keep
        // the loop running, so the same shutdown future is polled across iterations.
        let data = match try_recv(&mut ws_stream, stopped, ping_config, &mut missed_pings).await {
            Receive::ConnectionClosed => break Ok(Shutdown::ConnectionClosed),
            Receive::Stopped => break Ok(Shutdown::Stopped),
            Receive::Ok(data, stop) => {
                stopped = stop;
                data
            }
            Receive::Err(err, stop) => {
                stopped = stop;

                match err {
                    SokettoError::Closed => {
                        break Ok(Shutdown::ConnectionClosed);
                    }
                    // The message exceeded the receiver's size limit: reply with a JSON-RPC
                    // error and keep the connection open.
                    SokettoError::MessageTooLarge { current, maximum } => {
                        tracing::debug!(
                            target: LOG_TARGET,
                            "WS recv error: message too large current={}/max={}",
                            current,
                            maximum
                        );
                        if sink.send_error(Id::Null, reject_too_big_request(max_request_body_size)).await.is_err() {
                            break Ok(Shutdown::ConnectionClosed);
                        }

                        continue;
                    }
                    // Any other error terminates the connection.
                    err => {
                        tracing::debug!(target: LOG_TARGET, "WS error: {}; terminate connection: {}", err, conn.conn_id);
                        break Err(err);
                    }
                };
            }
        };

        let rpc_service = rpc_service.clone();
        let sink = sink.clone();
        let extensions = extensions.clone();

        // Each message is handled on its own task so reading can continue meanwhile.
        tokio::spawn(async move {
            // Look for the first non-whitespace byte within the first 128 bytes: `{` means a
            // single call, `[` a batch; anything else is answered with a parse error.
            let first_non_whitespace = data.iter().enumerate().take(128).find(|(_, byte)| !byte.is_ascii_whitespace());

            let (idx, is_single) = match first_non_whitespace {
                Some((start, b'{')) => (start, true),
                Some((start, b'[')) => (start, false),
                _ => {
                    _ = sink.send_error(Id::Null, ErrorCode::ParseError.into()).await;
                    return;
                }
            };

            if let Some(rp) = handle_rpc_call(
                &data[idx..],
                is_single,
                batch_requests_config,
                max_response_body_size,
                &*rpc_service,
                extensions,
            )
            .await
            {
                // Subscription responses are not sent from here.
                if !rp.is_subscription() {
                    let is_success = rp.is_success();
                    let (serialized_rp, mut on_close) = rp.into_parts();

                    // The sink is closed; nothing more to do.
                    if sink.send(serialized_rp).await.is_err() {
                        return;
                    }

                    // Notify only after the response has been accepted by the sink.
                    if let Some(n) = on_close.take() {
                        n.notify(is_success);
                    }
                }
            }
        });
    };

    // Release this function's handle to the RPC service before waiting for pending calls.
    // NOTE(review): presumably required so the `pending_calls` sender held by the service
    // is dropped and `graceful_shutdown` can complete — confirm.
    drop(rpc_service);
    graceful_shutdown(result, pending_calls_completed, ws_stream, conn_tx, send_task_handle).await;

    // Release the connection state only after shutdown has finished.
    drop(conn);

    if let Some(c) = on_session_close.take() {
        c.close();
    }
}
201
/// Writes outgoing messages from `rx` to the WebSocket and, when `ping_config` is set,
/// sends a ping every `ping_interval`.
///
/// Exits when `rx` is closed, when sending a message or ping fails, or when `stop`
/// resolves (a value was sent or its sender was dropped). On exit it closes the
/// WebSocket sender (ignoring errors) and closes `rx`.
async fn send_task(
    rx: mpsc::Receiver<String>,
    mut ws_sender: Sender,
    ping_config: Option<PingConfig>,
    stop: oneshot::Receiver<()>,
) {
    // Ping ticker; a never-ready stream is used when pings are disabled.
    let ping_interval = match ping_config {
        None => IntervalStream::pending(),
        // tokio's `interval` completes its first tick immediately, so the first ping is sent right away.
        Some(p) => IntervalStream::new(interval(p.ping_interval)),
    };
    let rx = ReceiverStream::new(rx);

    tokio::pin!(ping_interval, rx, stop);

    let mut rx_item = rx.next();
    let next_ping = ping_interval.next();
    // `futs` races the next ping tick against the stop signal. When a message is sent,
    // the same pending race is reused rather than recreated.
    let mut futs = future::select(next_ping, stop);

    loop {
        match future::select(rx_item, futs).await {
            // An outgoing message is ready: send it and wait for the next one.
            Either::Left((Some(response), not_ready)) => {
                if let Err(err) = send_message(&mut ws_sender, response).await {
                    tracing::debug!(target: LOG_TARGET, "WS send error: {}", err);
                    break;
                }

                rx_item = rx.next();
                futs = not_ready;
            }

            // Every sender of `rx` is gone.
            Either::Left((None, _)) => {
                break;
            }

            // Ping tick: send a ping, then re-arm the ping/stop race with the same `stop`.
            Either::Right((Either::Left((_instant, _stopped)), next_rx)) => {
                stop = _stopped;
                if let Err(err) = send_ping(&mut ws_sender).await {
                    tracing::debug!(target: LOG_TARGET, "WS send ping error: {}", err);
                    break;
                }

                rx_item = next_rx;
                futs = future::select(ping_interval.next(), stop);
            }
            // The stop signal resolved.
            Either::Right((Either::Right((_stopped, _)), _)) => {
                break;
            }
        }
    }

    // Best-effort close; the result is ignored because the task is exiting anyway.
    let _ = ws_sender.close().await;
    rx.close();
}
268
269enum Receive<S> {
270 ConnectionClosed,
271 Stopped,
272 Err(SokettoError, S),
273 Ok(Vec<u8>, S),
274}
275
/// Waits for the next data message on `ws_stream`, racing it against the server's
/// `stopped` future and, when `ping_config` is set, a periodic inactivity check.
///
/// Pong frames are consumed here and only refresh the local "last active" timestamp.
/// On each inactivity tick where nothing has refreshed that timestamp for longer than
/// `inactive_limit`, `*missed_pings` is incremented. Once it reaches `max_failures`,
/// `Receive::ConnectionClosed` is returned.
///
/// The `stopped` future is returned inside `Receive::Ok` / `Receive::Err` so the caller
/// can keep polling it.
///
/// NOTE(review): `missed_pings` is owned by the caller and is never reset here (not even
/// when a pong arrives), so misses accumulate over the connection's lifetime — confirm
/// this matches the intended `max_failures` semantics.
async fn try_recv<T, S>(
    ws_stream: &mut T,
    mut stopped: S,
    ping_config: Option<PingConfig>,
    missed_pings: &mut usize,
) -> Receive<S>
where
    S: Future<Output = ()> + Unpin,
    T: StreamExt<Item = Result<Incoming, SokettoError>> + Unpin,
{
    // Time this call started, or of the most recent pong seen during this call.
    let mut last_active = Instant::now();
    // First check fires one `ping_interval` from now, then every `ping_interval`.
    let inactivity_check = match ping_config {
        Some(p) => IntervalStream::new(interval_at(tokio::time::Instant::now() + p.ping_interval, p.ping_interval)),
        None => IntervalStream::pending(),
    };

    tokio::pin!(inactivity_check);

    let mut futs = futures_util::future::select(ws_stream.next(), inactivity_check.next());

    loop {
        match futures_util::future::select(futs, stopped).await {
            // The WebSocket stream ended.
            Either::Left((Either::Left((None, _)), _)) => break Receive::ConnectionClosed,
            // A data message arrived; hand it (and the stop future) back to the caller.
            Either::Left((Either::Left((Some(Ok(Incoming::Data(d))), _)), s)) => break Receive::Ok(d, s),
            // A pong arrived: record activity and keep waiting with the same inactivity tick.
            Either::Left((Either::Left((Some(Ok(Incoming::Pong)), inactive)), s)) => {
                last_active = Instant::now();
                stopped = s;
                futs = futures_util::future::select(ws_stream.next(), inactive);
            }
            // A read error; the caller decides how to handle it.
            Either::Left((Either::Left((Some(Err(e)), _)), s)) => break Receive::Err(e, s),
            // Inactivity tick: count a miss if nothing was seen for longer than `inactive_limit`.
            Either::Left((Either::Right((_instant, rcv)), s)) => {
                if let Some(p) = ping_config {
                    if last_active.elapsed() > p.inactive_limit {
                        *missed_pings += 1;

                        if *missed_pings >= p.max_failures {
                            tracing::debug!(
                                target: LOG_TARGET,
                                "WS ping/pong inactivity limit `{}` exceeded; closing connection",
                                p.max_failures,
                            );
                            break Receive::ConnectionClosed;
                        }
                    }
                }

                stopped = s;
                futs = futures_util::future::select(rcv, inactivity_check.next());
            }
            // The server is shutting down.
            Either::Right(_) => break Receive::Stopped,
        }
    }
}
336
/// Why the WebSocket receive loop ended without an error.
#[derive(Clone, Copy, Debug)]
pub(crate) enum Shutdown {
    /// The connection was closed (the stream ended, or too many missed pongs).
    ConnectionClosed,
    /// The server was asked to stop.
    Stopped,
}
342
343async fn graceful_shutdown<S>(
347 result: Result<Shutdown, SokettoError>,
348 pending_calls: mpsc::Receiver<()>,
349 ws_stream: S,
350 mut conn_tx: oneshot::Sender<()>,
351 send_task_handle: tokio::task::JoinHandle<()>,
352) where
353 S: StreamExt<Item = Result<Incoming, SokettoError>> + Unpin,
354{
355 let pending_calls = ReceiverStream::new(pending_calls);
356
357 if let Ok(Shutdown::Stopped) = result {
358 let graceful_shutdown = pending_calls.for_each(|_| async {});
359 let disconnect = ws_stream.try_for_each(|_| async { Ok(()) });
360
361 tokio::select! {
362 _ = graceful_shutdown => {}
363 res = disconnect => {
364 if let Err(err) = res {
365 tracing::warn!(target: LOG_TARGET, "Graceful shutdown terminated because of error: `{err}`");
366 }
367 }
368 _ = conn_tx.closed() => {}
369 }
370 }
371
372 _ = conn_tx.send(());
374 _ = send_task_handle.await;
376}
377
378pub async fn connect<L, B>(
423 req: HttpRequest<B>,
424 server_cfg: ServerConfig,
425 methods: impl Into<Methods>,
426 conn: ConnectionState,
427 rpc_middleware: RpcServiceBuilder<L>,
428) -> Result<(HttpResponse, impl Future<Output = ()>), HttpResponse>
429where
430 L: for<'a> tower::Layer<RpcService>,
431 <L as tower::Layer<RpcService>>::Service: Send + Sync + 'static,
432 for<'a> <L as tower::Layer<RpcService>>::Service: RpcServiceT<'a>,
433{
434 let mut server = soketto::handshake::http::Server::new();
435
436 match server.receive_request(&req) {
437 Ok(response) => {
438 let (tx, rx) = mpsc::channel::<String>(server_cfg.message_buffer_capacity as usize);
439 let sink = MethodSink::new(tx);
440
441 let (pending_calls, pending_calls_completed) = mpsc::channel::<()>(1);
445
446 let rpc_service_cfg = RpcServiceCfg::CallsAndSubscriptions {
447 bounded_subscriptions: BoundedSubscriptions::new(server_cfg.max_subscriptions_per_connection),
448 id_provider: server_cfg.id_provider.clone(),
449 sink: sink.clone(),
450 _pending_calls: pending_calls,
451 };
452
453 let rpc_service = RpcService::new(
454 methods.into(),
455 server_cfg.max_response_body_size as usize,
456 conn.conn_id.into(),
457 rpc_service_cfg,
458 );
459
460 let rpc_service = rpc_middleware.service(rpc_service);
461
462 let fut = async move {
465 let extensions = req.extensions().clone();
466
467 let upgraded = match hyper::upgrade::on(req).await {
468 Ok(upgraded) => upgraded,
469 Err(e) => {
470 tracing::debug!(target: LOG_TARGET, "WS upgrade handshake failed: {}", e);
471 return;
472 }
473 };
474
475 let io = TokioIo::new(upgraded);
476
477 let stream = BufReader::new(BufWriter::new(io.compat()));
478 let mut ws_builder = server.into_builder(stream);
479 ws_builder.set_max_message_size(server_cfg.max_response_body_size as usize);
480 let (sender, receiver) = ws_builder.finish();
481
482 let params = BackgroundTaskParams {
483 server_cfg,
484 conn,
485 ws_sender: sender,
486 ws_receiver: receiver,
487 rpc_service,
488 sink,
489 rx,
490 pending_calls_completed,
491 on_session_close: None,
492 extensions,
493 };
494
495 background_task(params).await;
496 };
497
498 Ok((response.map(|()| HttpBody::default()), fut))
499 }
500 Err(e) => {
501 tracing::debug!(target: LOG_TARGET, "WS upgrade handshake failed: {}", e);
502 Err(HttpResponse::new(HttpBody::from(format!("WS upgrade handshake failed: {e}"))))
503 }
504 }
505}