// jsonrpsee_server/transport/ws.rs

1use std::sync::Arc;
2use std::time::Instant;
3
4use crate::future::{IntervalStream, SessionClose};
5use crate::middleware::rpc::{RpcService, RpcServiceBuilder, RpcServiceCfg, RpcServiceT};
6use crate::server::{handle_rpc_call, ConnectionState, ServerConfig};
7use crate::{HttpBody, HttpRequest, HttpResponse, PingConfig, LOG_TARGET};
8
9use futures_util::future::{self, Either};
10use futures_util::io::{BufReader, BufWriter};
11use futures_util::{Future, StreamExt, TryStreamExt};
12use hyper::upgrade::Upgraded;
13use hyper_util::rt::TokioIo;
14use jsonrpsee_core::server::{BoundedSubscriptions, MethodSink, Methods};
15use jsonrpsee_types::error::{reject_too_big_request, ErrorCode};
16use jsonrpsee_types::Id;
17use soketto::connection::Error as SokettoError;
18use soketto::data::ByteSlice125;
19
20use tokio::sync::{mpsc, oneshot};
21use tokio::time::{interval, interval_at};
22use tokio_stream::wrappers::ReceiverStream;
23use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};
24
25pub(crate) type Sender = soketto::Sender<BufReader<BufWriter<Compat<TokioIo<Upgraded>>>>>;
26pub(crate) type Receiver = soketto::Receiver<BufReader<BufWriter<Compat<TokioIo<Upgraded>>>>>;
27
28pub use soketto::handshake::http::is_upgrade_request;
29
/// A WebSocket frame that the connection loop acts upon.
enum Incoming {
	/// A complete data message; only its payload bytes are kept.
	Data(Vec<u8>),
	/// A pong frame, used to refresh the connection's "last active" timestamp.
	Pong,
}
34
35pub(crate) async fn send_message(sender: &mut Sender, response: String) -> Result<(), SokettoError> {
36	sender.send_text_owned(response).await?;
37	sender.flush().await.map_err(Into::into)
38}
39
40pub(crate) async fn send_ping(sender: &mut Sender) -> Result<(), SokettoError> {
41	tracing::debug!(target: LOG_TARGET, "Send ping");
42	// Submit empty slice as "optional" parameter.
43	let slice: &[u8] = &[];
44	// Byte slice fails if the provided slice is larger than 125 bytes.
45	let byte_slice = ByteSlice125::try_from(slice).expect("Empty slice should fit into ByteSlice125");
46	sender.send_ping(byte_slice).await?;
47	sender.flush().await.map_err(Into::into)
48}
49
/// Everything `background_task` needs to drive one upgraded WebSocket connection.
pub(crate) struct BackgroundTaskParams<S> {
	/// Server configuration; the ping, batch and body-size limits are read from it.
	pub(crate) server_cfg: ServerConfig,
	/// Per-connection state; provides the connection id and the server stop handle.
	pub(crate) conn: ConnectionState,
	/// Writing half of the WebSocket; moved into the spawned send task.
	pub(crate) ws_sender: Sender,
	/// Reading half of the WebSocket.
	pub(crate) ws_receiver: Receiver,
	/// Service that handles each incoming JSON-RPC call or batch.
	pub(crate) rpc_service: S,
	/// Handle used to queue responses and errors for the client.
	pub(crate) sink: MethodSink,
	/// Receiving end of the outgoing-message queue; drained by the send task onto the socket.
	pub(crate) rx: mpsc::Receiver<String>,
	/// Yields `None` once every pending call has finished; awaited during graceful shutdown.
	pub(crate) pending_calls_completed: mpsc::Receiver<()>,
	/// Optional hook whose `close()` is invoked after the connection has fully shut down.
	pub(crate) on_session_close: Option<SessionClose>,
	/// Extensions cloned into every call handled on this connection
	/// (`connect` passes the upgrade request's extensions).
	pub(crate) extensions: http::Extensions,
}
62
/// Drives a single WebSocket connection until it is closed, fails, or the server stops.
///
/// - Spawns `send_task`, which writes queued responses from `rx` (and periodic pings) to
///   `ws_sender`.
/// - Reads frames from `ws_receiver` via `try_recv`. Each data message gets its own Tokio task,
///   which dispatches it through `rpc_service` and queues the response on `sink`.
/// - On exit, runs `graceful_shutdown`, drops `conn`, and finally calls `close()` on
///   `on_session_close` if one was provided.
pub(crate) async fn background_task<S>(params: BackgroundTaskParams<S>)
where
	for<'a> S: RpcServiceT<'a> + Send + Sync + 'static,
{
	let BackgroundTaskParams {
		server_cfg,
		conn,
		ws_sender,
		ws_receiver,
		rpc_service,
		sink,
		rx,
		pending_calls_completed,
		mut on_session_close,
		extensions,
	} = params;
	let ServerConfig { ping_config, batch_requests_config, max_request_body_size, max_response_body_size, .. } =
		server_cfg;

	// `graceful_shutdown` uses `conn_tx` to tell the send task to stop.
	let (conn_tx, conn_rx) = oneshot::channel();

	// Spawn another task that sends out the responses on the Websocket.
	let send_task_handle = tokio::spawn(send_task(rx, ws_sender, ping_config, conn_rx));

	// Future that completes when the server is stopping. It is pinned so it can be handed to
	// `try_recv` and recovered from it on every loop iteration.
	let stopped = conn.stop_handle.clone().shutdown();
	let rpc_service = Arc::new(rpc_service);
	// Lives outside `try_recv` so missed pongs accumulate across calls.
	let mut missed_pings = 0;

	tokio::pin!(stopped);

	// Adapt the soketto receiver into a stream of `Incoming` items.
	let ws_stream = futures_util::stream::unfold(ws_receiver, |mut receiver| async {
		let mut data = Vec::new();
		match receiver.receive(&mut data).await {
			Ok(soketto::Incoming::Data(_)) => Some((Ok(Incoming::Data(data)), receiver)),
			Ok(soketto::Incoming::Pong(_)) => Some((Ok(Incoming::Pong), receiver)),
			// A close frame or a `Closed` error ends the stream. `soketto` already logs the
			// closing reason at trace level, so a clean shutdown does not surface as a warning.
			Ok(soketto::Incoming::Closed(_)) | Err(SokettoError::Closed) => None,
			// Any other error is yielded to the reading loop.
			Err(e) => Some((Err(e), receiver)),
		}
	})
	.fuse();

	tokio::pin!(ws_stream);

	let result = loop {
		// `try_recv` takes the stop future by value and gives it back with `Ok`/`Err`,
		// so it is reassigned below rather than recreated.
		let data = match try_recv(&mut ws_stream, stopped, ping_config, &mut missed_pings).await {
			Receive::ConnectionClosed => break Ok(Shutdown::ConnectionClosed),
			Receive::Stopped => break Ok(Shutdown::Stopped),
			Receive::Ok(data, stop) => {
				stopped = stop;
				data
			}
			Receive::Err(err, stop) => {
				stopped = stop;

				match err {
					SokettoError::Closed => {
						break Ok(Shutdown::ConnectionClosed);
					}
					// Oversized message: answer with a JSON-RPC error (null id) and keep reading.
					// If the error cannot be queued, treat the connection as closed.
					SokettoError::MessageTooLarge { current, maximum } => {
						tracing::debug!(
							target: LOG_TARGET,
							"WS recv error: message too large current={}/max={}",
							current,
							maximum
						);
						if sink.send_error(Id::Null, reject_too_big_request(max_request_body_size)).await.is_err() {
							break Ok(Shutdown::ConnectionClosed);
						}

						continue;
					}
					err => {
						tracing::debug!(target: LOG_TARGET, "WS error: {}; terminate connection: {}", err, conn.conn_id);
						break Err(err);
					}
				};
			}
		};

		let rpc_service = rpc_service.clone();
		let sink = sink.clone();
		let extensions = extensions.clone();

		// Process every message in its own task, so one slow call does not block reading.
		tokio::spawn(async move {
			// Only the first 128 bytes are scanned for the opening `{` (single call) or
			// `[` (batch). Anything else is answered with a parse error.
			let first_non_whitespace = data.iter().enumerate().take(128).find(|(_, byte)| !byte.is_ascii_whitespace());

			let (idx, is_single) = match first_non_whitespace {
				Some((start, b'{')) => (start, true),
				Some((start, b'[')) => (start, false),
				_ => {
					_ = sink.send_error(Id::Null, ErrorCode::ParseError.into()).await;
					return;
				}
			};

			if let Some(rp) = handle_rpc_call(
				&data[idx..],
				is_single,
				batch_requests_config,
				max_response_body_size,
				&*rpc_service,
				extensions,
			)
			.await
			{
				// Responses flagged by `is_subscription()` are not sent from here.
				// NOTE(review): presumably the subscription machinery delivers them itself — confirm.
				if !rp.is_subscription() {
					let is_success = rp.is_success();
					let (serialized_rp, mut on_close) = rp.into_parts();

					// The connection is closed, just quit.
					if sink.send(serialized_rp).await.is_err() {
						return;
					}

					// Notify that the message has been sent out to the internal
					// WebSocket buffer.
					if let Some(n) = on_close.take() {
						n.notify(is_success);
					}
				}
			}
		});
	};

	// Drive all running methods to completion.
	// **NOTE** Do not return early in this function. This `await` needs to run to guarantee
	// proper drop behaviour.
	drop(rpc_service);
	graceful_shutdown(result, pending_calls_completed, ws_stream, conn_tx, send_task_handle).await;

	drop(conn);

	if let Some(c) = on_session_close.take() {
		c.close();
	}
}
201
/// A task that waits for new messages on the `rx` channel and sends them out on the WebSocket.
///
/// When `ping_config` is set, it also sends a ping every `ping_interval`. With `None`, no pings
/// are sent.
///
/// The task ends when any of the following happens:
/// - `rx` is exhausted;
/// - a send or ping fails;
/// - `stop` resolves, either because a value was sent or because its sender was dropped.
///
/// It then sends a WebSocket close message and closes `rx`.
async fn send_task(
	rx: mpsc::Receiver<String>,
	mut ws_sender: Sender,
	ping_config: Option<PingConfig>,
	stop: oneshot::Receiver<()>,
) {
	let ping_interval = match ping_config {
		None => IntervalStream::pending(),
		// `interval` yields its first tick immediately, so the first ping is sent as soon as
		// this task starts. This lines up with the receive side (`try_recv`), which starts
		// measuring pong inactivity from its own start.
		Some(p) => IntervalStream::new(interval(p.ping_interval)),
	};
	let rx = ReceiverStream::new(rx);

	tokio::pin!(ping_interval, rx, stop);

	// Next outgoing message from the `rx` channel.
	let mut rx_item = rx.next();
	let next_ping = ping_interval.next();
	let mut futs = future::select(next_ping, stop);

	loop {
		// Whichever side of the select did not finish is stored and polled again next time,
		// so no pending `rx_item` or timer tick is lost.
		// Note: `select!` would also be cancel-safe here; the explicit form is kept for future proofing.
		match future::select(rx_item, futs).await {
			// Received message.
			Either::Left((Some(response), not_ready)) => {
				// If websocket message send fail then terminate the connection.
				if let Err(err) = send_message(&mut ws_sender, response).await {
					tracing::debug!(target: LOG_TARGET, "WS send error: {}", err);
					break;
				}

				rx_item = rx.next();
				futs = not_ready;
			}

			// Nothing else to receive.
			Either::Left((None, _)) => {
				break;
			}

			// Handle timer intervals: send a ping and re-arm the timer with the same stop future.
			Either::Right((Either::Left((_instant, _stopped)), next_rx)) => {
				stop = _stopped;
				if let Err(err) = send_ping(&mut ws_sender).await {
					tracing::debug!(target: LOG_TARGET, "WS send ping error: {}", err);
					break;
				}

				rx_item = next_rx;
				futs = future::select(ping_interval.next(), stop);
			}
			Either::Right((Either::Right((_stopped, _)), _)) => {
				// server has stopped
				break;
			}
		}
	}

	// Terminate connection and send close message.
	let _ = ws_sender.close().await;
	rx.close();
}
268
269enum Receive<S> {
270	ConnectionClosed,
271	Stopped,
272	Err(SokettoError, S),
273	Ok(Vec<u8>, S),
274}
275
/// Waits for the next data message from `ws_stream`, racing it against the `stopped` future.
///
/// While waiting, pong frames refresh a local "last active" timestamp, which is initialised when
/// this call starts. When `ping_config` is set, an interval first fires after `ping_interval` and
/// then every `ping_interval`. Each tick at which more than `inactive_limit` has elapsed since the
/// last activity increments `missed_pings`. Once that count reaches `max_failures`, the connection
/// is reported as closed.
///
/// Returns:
/// - `Receive::Ok(data, stopped)` for a data message;
/// - `Receive::Err(err, stopped)` for a stream error;
/// - `Receive::ConnectionClosed` when the stream ends or the ping limit is exceeded;
/// - `Receive::Stopped` when `stopped` resolves first.
///
/// The stop future is handed back in `Ok`/`Err` so the caller can reuse it.
///
/// NOTE(review): `missed_pings` is never reset in this file, not even on a pong. Failures
/// therefore accumulate for the whole lifetime of the connection rather than counting
/// consecutive misses — confirm this is intended.
async fn try_recv<T, S>(
	ws_stream: &mut T,
	mut stopped: S,
	ping_config: Option<PingConfig>,
	missed_pings: &mut usize,
) -> Receive<S>
where
	S: Future<Output = ()> + Unpin,
	T: StreamExt<Item = Result<Incoming, SokettoError>> + Unpin,
{
	let mut last_active = Instant::now();
	// First check fires one `ping_interval` from now (unlike `interval`, which ticks immediately).
	let inactivity_check = match ping_config {
		Some(p) => IntervalStream::new(interval_at(tokio::time::Instant::now() + p.ping_interval, p.ping_interval)),
		None => IntervalStream::pending(),
	};

	tokio::pin!(inactivity_check);

	let mut futs = futures_util::future::select(ws_stream.next(), inactivity_check.next());

	loop {
		match futures_util::future::select(futs, stopped).await {
			// The connection is closed.
			Either::Left((Either::Left((None, _)), _)) => break Receive::ConnectionClosed,
			// The message has been received, we are done
			Either::Left((Either::Left((Some(Ok(Incoming::Data(d))), _)), s)) => break Receive::Ok(d, s),
			// Got a pong response, update our "last seen" timestamp and keep the pending timer tick.
			Either::Left((Either::Left((Some(Ok(Incoming::Pong)), inactive)), s)) => {
				last_active = Instant::now();
				stopped = s;
				futs = futures_util::future::select(ws_stream.next(), inactive);
			}
			// Received an error; hand it to the caller, which decides whether to terminate.
			Either::Left((Either::Left((Some(Err(e)), _)), s)) => break Receive::Err(e, s),
			// Inactivity timer ticked, check if the connection has been idle too long.
			Either::Left((Either::Right((_instant, rcv)), s)) => {
				if let Some(p) = ping_config {
					if last_active.elapsed() > p.inactive_limit {
						*missed_pings += 1;

						if *missed_pings >= p.max_failures {
							tracing::debug!(
								target: LOG_TARGET,
								"WS ping/pong inactivity limit `{}` exceeded; closing connection",
								p.max_failures,
							);
							break Receive::ConnectionClosed;
						}
					}
				}

				// Keep the in-flight read (`rcv`) and wait for the next timer tick.
				stopped = s;
				futs = futures_util::future::select(rcv, inactivity_check.next());
			}
			// Server has been stopped.
			Either::Right(_) => break Receive::Stopped,
		}
	}
}
336
/// Why the connection loop in `background_task` stopped reading.
#[derive(Clone, Copy, Debug)]
pub(crate) enum Shutdown {
	/// The server is stopping; pending calls are given a chance to complete.
	Stopped,
	/// The peer closed the connection, the ping/pong inactivity limit was exceeded,
	/// or an error reply could not be queued.
	ConnectionClosed,
}
342
343/// Enforce a graceful shutdown.
344///
345/// This will return once the connection has been terminated or all pending calls have been executed.
346async fn graceful_shutdown<S>(
347	result: Result<Shutdown, SokettoError>,
348	pending_calls: mpsc::Receiver<()>,
349	ws_stream: S,
350	mut conn_tx: oneshot::Sender<()>,
351	send_task_handle: tokio::task::JoinHandle<()>,
352) where
353	S: StreamExt<Item = Result<Incoming, SokettoError>> + Unpin,
354{
355	let pending_calls = ReceiverStream::new(pending_calls);
356
357	if let Ok(Shutdown::Stopped) = result {
358		let graceful_shutdown = pending_calls.for_each(|_| async {});
359		let disconnect = ws_stream.try_for_each(|_| async { Ok(()) });
360
361		tokio::select! {
362			_ = graceful_shutdown => {}
363			res = disconnect => {
364				if let Err(err) = res {
365					tracing::warn!(target: LOG_TARGET, "Graceful shutdown terminated because of error: `{err}`");
366				}
367			}
368			_ = conn_tx.closed() => {}
369		}
370	}
371
372	// Send a message to close down the "send task".
373	_ = conn_tx.send(());
374	// Ensure that send task has been closed.
375	_ = send_task_handle.await;
376}
377
378/// Low-level API that tries to upgrade the HTTP connection to a WebSocket connection.
379///
380/// Returns `Ok((http_response, conn_fut))` if the WebSocket connection was successfully established
381/// otherwise `Err(http_response)`.
382///
383/// `conn_fut` is a future that drives the WebSocket connection
384/// and if it's dropped the connection will be closed.
385///
386/// Because this API depends on [`hyper`] the response needs to be sent
387/// to complete the HTTP request.
388///
389/// ```no_run
390/// use jsonrpsee_server::{ws, ServerConfig, Methods, ConnectionState, HttpRequest, HttpResponse};
391/// use jsonrpsee_server::middleware::rpc::{RpcServiceBuilder, RpcServiceT, RpcService};
392///
393/// async fn handle_websocket_conn<L>(
394///     req: HttpRequest,
395///     server_cfg: ServerConfig,
396///     methods: impl Into<Methods> + 'static,
397///     conn: ConnectionState,
398///     rpc_middleware: RpcServiceBuilder<L>,
399///     mut disconnect: tokio::sync::mpsc::Receiver<()>
400/// ) -> HttpResponse
401/// where
402///     L: for<'a> tower::Layer<RpcService> + 'static,
403///     <L as tower::Layer<RpcService>>::Service: Send + Sync + 'static,
404///     for<'a> <L as tower::Layer<RpcService>>::Service: RpcServiceT<'a> + 'static,
405/// {
406///   match ws::connect(req, server_cfg, methods, conn, rpc_middleware).await {
407///     Ok((rp, conn_fut)) => {
408///         tokio::spawn(async move {
409///             // Keep the connection alive until
410///             // a close signal is sent.
411///             tokio::select! {
412///                 _ = conn_fut => (),
413///                 _ = disconnect.recv() => (),
414///             }
415///         });
416///         rp
417///     }
418///     Err(rp) => rp,
419///   }
420/// }
421/// ```
422pub async fn connect<L, B>(
423	req: HttpRequest<B>,
424	server_cfg: ServerConfig,
425	methods: impl Into<Methods>,
426	conn: ConnectionState,
427	rpc_middleware: RpcServiceBuilder<L>,
428) -> Result<(HttpResponse, impl Future<Output = ()>), HttpResponse>
429where
430	L: for<'a> tower::Layer<RpcService>,
431	<L as tower::Layer<RpcService>>::Service: Send + Sync + 'static,
432	for<'a> <L as tower::Layer<RpcService>>::Service: RpcServiceT<'a>,
433{
434	let mut server = soketto::handshake::http::Server::new();
435
436	match server.receive_request(&req) {
437		Ok(response) => {
438			let (tx, rx) = mpsc::channel::<String>(server_cfg.message_buffer_capacity as usize);
439			let sink = MethodSink::new(tx);
440
441			// On each method call the `pending_calls` is cloned
442			// then when all pending_calls are dropped
443			// a graceful shutdown can has occur.
444			let (pending_calls, pending_calls_completed) = mpsc::channel::<()>(1);
445
446			let rpc_service_cfg = RpcServiceCfg::CallsAndSubscriptions {
447				bounded_subscriptions: BoundedSubscriptions::new(server_cfg.max_subscriptions_per_connection),
448				id_provider: server_cfg.id_provider.clone(),
449				sink: sink.clone(),
450				_pending_calls: pending_calls,
451			};
452
453			let rpc_service = RpcService::new(
454				methods.into(),
455				server_cfg.max_response_body_size as usize,
456				conn.conn_id.into(),
457				rpc_service_cfg,
458			);
459
460			let rpc_service = rpc_middleware.service(rpc_service);
461
462			// Note: This can't possibly be fulfilled until the HTTP response
463			// is returned below, so that's why it's a separate async block
464			let fut = async move {
465				let extensions = req.extensions().clone();
466
467				let upgraded = match hyper::upgrade::on(req).await {
468					Ok(upgraded) => upgraded,
469					Err(e) => {
470						tracing::debug!(target: LOG_TARGET, "WS upgrade handshake failed: {}", e);
471						return;
472					}
473				};
474
475				let io = TokioIo::new(upgraded);
476
477				let stream = BufReader::new(BufWriter::new(io.compat()));
478				let mut ws_builder = server.into_builder(stream);
479				ws_builder.set_max_message_size(server_cfg.max_response_body_size as usize);
480				let (sender, receiver) = ws_builder.finish();
481
482				let params = BackgroundTaskParams {
483					server_cfg,
484					conn,
485					ws_sender: sender,
486					ws_receiver: receiver,
487					rpc_service,
488					sink,
489					rx,
490					pending_calls_completed,
491					on_session_close: None,
492					extensions,
493				};
494
495				background_task(params).await;
496			};
497
498			Ok((response.map(|()| HttpBody::default()), fut))
499		}
500		Err(e) => {
501			tracing::debug!(target: LOG_TARGET, "WS upgrade handshake failed: {}", e);
502			Err(HttpResponse::new(HttpBody::from(format!("WS upgrade handshake failed: {e}"))))
503		}
504	}
505}