jsonrpsee_server/transport/ws.rs

1use std::sync::Arc;
2use std::time::Instant;
3
4use crate::future::{IntervalStream, SessionClose};
5use crate::middleware::rpc::{RpcService, RpcServiceCfg};
6use crate::server::{ConnectionState, ServerConfig, handle_rpc_call};
7use crate::{HttpBody, HttpRequest, HttpResponse, LOG_TARGET, PingConfig};
8
9use futures_util::future::{self, Either};
10use futures_util::io::{BufReader, BufWriter};
11use futures_util::{Future, StreamExt, TryStreamExt};
12use hyper::upgrade::Upgraded;
13use hyper_util::rt::TokioIo;
14use jsonrpsee_core::middleware::{RpcServiceBuilder, RpcServiceT};
15use jsonrpsee_core::server::{BoundedSubscriptions, MethodResponse, MethodSink, Methods};
16use jsonrpsee_types::Id;
17use jsonrpsee_types::error::{ErrorCode, reject_too_big_request};
18use serde_json::value::RawValue;
19use soketto::connection::Error as SokettoError;
20use soketto::data::ByteSlice125;
21use tokio::sync::{mpsc, oneshot};
22use tokio::time::{interval, interval_at};
23use tokio_stream::wrappers::ReceiverStream;
24use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};
25
26pub(crate) type Sender = soketto::Sender<BufReader<BufWriter<Compat<TokioIo<Upgraded>>>>>;
27pub(crate) type Receiver = soketto::Receiver<BufReader<BufWriter<Compat<TokioIo<Upgraded>>>>>;
28
29pub use soketto::handshake::http::is_upgrade_request;
30
/// A frame read from the WebSocket that the server has to act on.
enum Incoming {
	/// The payload of a data message.
	Data(Vec<u8>),
	/// The peer answered one of our pings.
	Pong,
}
35
36pub(crate) async fn send_message(sender: &mut Sender, response: Box<RawValue>) -> Result<(), SokettoError> {
37	sender.send_text_owned(String::from(Box::<str>::from(response))).await?;
38	sender.flush().await
39}
40
41pub(crate) async fn send_ping(sender: &mut Sender) -> Result<(), SokettoError> {
42	tracing::debug!(target: LOG_TARGET, "Send ping");
43	// Submit empty slice as "optional" parameter.
44	let slice: &[u8] = &[];
45	// Byte slice fails if the provided slice is larger than 125 bytes.
46	let byte_slice = ByteSlice125::try_from(slice).expect("Empty slice should fit into ByteSlice125");
47	sender.send_ping(byte_slice).await?;
48	sender.flush().await
49}
50
51pub(crate) struct BackgroundTaskParams<S> {
52	pub(crate) server_cfg: ServerConfig,
53	pub(crate) conn: ConnectionState,
54	pub(crate) ws_sender: Sender,
55	pub(crate) ws_receiver: Receiver,
56	pub(crate) rpc_service: S,
57	pub(crate) sink: MethodSink,
58	pub(crate) rx: mpsc::Receiver<Box<RawValue>>,
59	pub(crate) pending_calls_completed: mpsc::Receiver<()>,
60	pub(crate) on_session_close: Option<SessionClose>,
61	pub(crate) extensions: http::Extensions,
62}
63
/// Drives a single WebSocket connection until it is closed, fails, or the server is stopped.
///
/// A separate `send_task` is spawned to write everything received on `rx` to the WebSocket
/// (and to send pings when `ping_config` is set). This function reads incoming frames via
/// `try_recv` and spawns one task per data message, which executes it against `rpc_service`
/// and pushes the response into `sink`.
///
/// When the read loop ends, `graceful_shutdown` is awaited, then `conn` is dropped and finally
/// `on_session_close` (if any) is closed.
pub(crate) async fn background_task<S>(params: BackgroundTaskParams<S>)
where
	S: RpcServiceT<
			MethodResponse = MethodResponse,
			BatchResponse = MethodResponse,
			NotificationResponse = MethodResponse,
		> + Send
		+ Sync
		+ 'static,
{
	let BackgroundTaskParams {
		server_cfg,
		conn,
		ws_sender,
		ws_receiver,
		rpc_service,
		sink,
		rx,
		pending_calls_completed,
		mut on_session_close,
		extensions,
	} = params;
	let ServerConfig { ping_config, batch_requests_config, max_request_body_size, .. } = server_cfg;

	// `graceful_shutdown` uses `conn_tx` to tell the send task to stop.
	let (conn_tx, conn_rx) = oneshot::channel();

	// Spawn another task that sends out the responses on the Websocket.
	let send_task_handle = tokio::spawn(send_task(rx, ws_sender, ping_config, conn_rx));

	// Completion of this future is reported by `try_recv` as `Receive::Stopped`.
	let stopped = conn.stop_handle.clone().shutdown();
	let rpc_service = Arc::new(rpc_service);
	// Shared across all `try_recv` calls for the lifetime of the connection; never reset here.
	let mut missed_pings = 0;

	tokio::pin!(stopped);

	// Adapt the soketto receiver into a stream of `Incoming` items.
	let ws_stream = futures_util::stream::unfold(ws_receiver, |mut receiver| async {
		let mut data = Vec::new();
		match receiver.receive(&mut data).await {
			Ok(soketto::Incoming::Data(_)) => Some((Ok(Incoming::Data(data)), receiver)),
			Ok(soketto::Incoming::Pong(_)) => Some((Ok(Incoming::Pong), receiver)),
			// A close frame or a `Closed` error simply ends the stream. The closing reason is
			// already logged by `soketto` at trace log level, so ending quietly avoids logging
			// unnecessary warnings on clean shutdown.
			Ok(soketto::Incoming::Closed(_)) | Err(SokettoError::Closed) => None,
			// Any other error is forwarded to the read loop, which decides what to do with it.
			Err(e) => Some((Err(e), receiver)),
		}
	})
	// Fused so that `graceful_shutdown` can keep polling it even after it has ended.
	.fuse();

	tokio::pin!(ws_stream);

	let result = loop {
		// `try_recv` hands the `stopped` future back so it can be reused on the next iteration.
		let data = match try_recv(&mut ws_stream, stopped, ping_config, &mut missed_pings).await {
			Receive::ConnectionClosed => break Ok(Shutdown::ConnectionClosed),
			Receive::Stopped => break Ok(Shutdown::Stopped),
			Receive::Ok(data, stop) => {
				stopped = stop;
				data
			}
			Receive::Err(err, stop) => {
				stopped = stop;

				match err {
					SokettoError::Closed => {
						break Ok(Shutdown::ConnectionClosed);
					}
					// An oversized message is answered with an error (id `null`) and the
					// loop keeps reading, unless the sink is already closed.
					SokettoError::MessageTooLarge { current, maximum } => {
						tracing::debug!(
							target: LOG_TARGET,
							"WS recv error: message too large current={}/max={}",
							current,
							maximum
						);
						if sink.send_error(Id::Null, reject_too_big_request(max_request_body_size)).await.is_err() {
							break Ok(Shutdown::ConnectionClosed);
						}

						continue;
					}
					err => {
						tracing::debug!(target: LOG_TARGET, "WS error: {}; terminate connection: {}", err, conn.conn_id);
						break Err(err);
					}
				};
			}
		};

		let rpc_service = rpc_service.clone();
		let sink = sink.clone();
		let extensions = extensions.clone();

		// Each message is processed in its own task so the read loop is never blocked by a call.
		tokio::spawn(async move {
			// Only the first 128 bytes are scanned for the opening `{` (single call) or
			// `[` (batch); anything else is answered with a parse error.
			let first_non_whitespace = data.iter().enumerate().take(128).find(|(_, byte)| !byte.is_ascii_whitespace());

			let (idx, is_single) = match first_non_whitespace {
				Some((start, b'{')) => (start, true),
				Some((start, b'[')) => (start, false),
				_ => {
					_ = sink.send_error(Id::Null, ErrorCode::ParseError.into()).await;
					return;
				}
			};

			let rp = handle_rpc_call(&data[idx..], is_single, batch_requests_config, &*rpc_service, extensions).await;

			// Subscriptions are handled by the subscription callback and
			// "ordinary notifications" should not be sent back to the client.
			if rp.is_method_call() || rp.is_batch() {
				let is_success = rp.is_success();
				let (json, mut on_close, _) = rp.into_parts();

				// The connection is closed, just quit.
				if sink.send(json).await.is_err() {
					return;
				}

				// Notify that the message has been sent out to the internal
				// WebSocket buffer.
				if let Some(n) = on_close.take() {
					n.notify(is_success);
				}
			}
		});
	};

	// Drive all running methods to completion.
	// **NOTE** Do not return early in this function. This `await` needs to run to guarantee
	// proper drop behaviour.
	drop(rpc_service);
	graceful_shutdown(result, pending_calls_completed, ws_stream, conn_tx, send_task_handle).await;

	drop(conn);

	if let Some(c) = on_session_close.take() {
		c.close();
	}
}
200
/// A task that waits for new messages via the `rx channel` and sends them out on the `WebSocket`.
///
/// When `ping_config` is set, it also sends a ping every `ping_interval`. The task exits when
/// `rx` yields `None`, when writing a message or a ping fails, or when `stop` resolves.
/// On exit it closes the WebSocket sender (sending the close message) and closes `rx`.
async fn send_task(
	rx: mpsc::Receiver<Box<RawValue>>,
	mut ws_sender: Sender,
	ping_config: Option<PingConfig>,
	stop: oneshot::Receiver<()>,
) {
	let ping_interval = match ping_config {
		None => IntervalStream::pending(),
		// NOTE: `interval` completes its first tick immediately, so the first ping is sent as
		// soon as this task starts. This keeps it in sync with the receive side, which starts
		// measuring pong inactivity as soon as it starts up.
		Some(p) => IntervalStream::new(interval(p.ping_interval)),
	};
	let rx = ReceiverStream::new(rx);

	tokio::pin!(ping_interval, rx, stop);

	// Next message to write out, taken from the `rx` channel.
	let mut rx_item = rx.next();
	let next_ping = ping_interval.next();
	let mut futs = future::select(next_ping, stop);

	loop {
		// Ensure select is cancel-safe by fetching and storing the `rx_item` that did not finish yet.
		// Note: Although, this is cancel-safe already, avoid using `select!` macro for future proofing.
		match future::select(rx_item, futs).await {
			// Received message.
			Either::Left((Some(response), not_ready)) => {
				// If websocket message send fail then terminate the connection.
				if let Err(err) = send_message(&mut ws_sender, response).await {
					tracing::debug!(target: LOG_TARGET, "WS send error: {}", err);
					break;
				}

				rx_item = rx.next();
				// Keep the still-pending ping/stop select instead of recreating it.
				futs = not_ready;
			}

			// Nothing else to receive.
			Either::Left((None, _)) => {
				break;
			}

			// Handle timer intervals.
			Either::Right((Either::Left((_instant, _stopped)), next_rx)) => {
				// Take back the (still pending) stop future so it can be selected on again.
				stop = _stopped;
				if let Err(err) = send_ping(&mut ws_sender).await {
					tracing::debug!(target: LOG_TARGET, "WS send ping error: {}", err);
					break;
				}

				rx_item = next_rx;
				futs = future::select(ping_interval.next(), stop);
			}
			Either::Right((Either::Right((_stopped, _)), _)) => {
				// server has stopped
				break;
			}
		}
	}

	// Terminate connection and send close message.
	let _ = ws_sender.close().await;
	rx.close();
}
267
268enum Receive<S> {
269	ConnectionClosed,
270	Stopped,
271	Err(SokettoError, S),
272	Ok(Vec<u8>, S),
273}
274
/// Waits for the next data message on `ws_stream`, giving up if `stopped` completes first.
///
/// Pong frames are consumed here: each one refreshes a "last active" timestamp (initialised when
/// this call starts) and the wait continues. When `ping_config` is set, an inactivity check
/// fires every `ping_interval`, the first one a full interval after this call starts. Each time
/// it fires and more than `inactive_limit` has elapsed since the last activity, `missed_pings`
/// is incremented. Once it reaches `max_failures`, `Receive::ConnectionClosed` is returned.
/// `missed_pings` is owned by the caller and is never reset here.
///
/// Returns `Receive::ConnectionClosed` when the stream ends, `Receive::Stopped` when `stopped`
/// completes, and otherwise hands `stopped` back inside `Receive::Ok` / `Receive::Err`.
async fn try_recv<T, S>(
	ws_stream: &mut T,
	mut stopped: S,
	ping_config: Option<PingConfig>,
	missed_pings: &mut usize,
) -> Receive<S>
where
	S: Future<Output = ()> + Unpin,
	T: StreamExt<Item = Result<Incoming, SokettoError>> + Unpin,
{
	let mut last_active = Instant::now();
	// The first check is delayed by one interval (`interval_at`) rather than firing immediately.
	let inactivity_check = match ping_config {
		Some(p) => IntervalStream::new(interval_at(tokio::time::Instant::now() + p.ping_interval, p.ping_interval)),
		None => IntervalStream::pending(),
	};

	tokio::pin!(inactivity_check);

	let mut futs = futures_util::future::select(ws_stream.next(), inactivity_check.next());

	loop {
		// Outer select: (next frame | inactivity tick) vs. server stop. Unfinished futures are
		// carried over to the next iteration instead of being recreated.
		match futures_util::future::select(futs, stopped).await {
			// The connection is closed.
			Either::Left((Either::Left((None, _)), _)) => break Receive::ConnectionClosed,
			// The message has been received, we are done
			Either::Left((Either::Left((Some(Ok(Incoming::Data(d))), _)), s)) => break Receive::Ok(d, s),
			// Got a pong response, update our "last seen" timestamp.
			Either::Left((Either::Left((Some(Ok(Incoming::Pong)), inactive)), s)) => {
				last_active = Instant::now();
				stopped = s;
				futs = futures_util::future::select(ws_stream.next(), inactive);
			}
			// Received an error, terminate the connection.
			Either::Left((Either::Left((Some(Err(e)), _)), s)) => break Receive::Err(e, s),
			// Max inactivity timeout fired, check if the connection has been idle too long.
			Either::Left((Either::Right((_instant, rcv)), s)) => {
				if let Some(p) = ping_config {
					if last_active.elapsed() > p.inactive_limit {
						*missed_pings += 1;

						if *missed_pings >= p.max_failures {
							tracing::debug!(
								target: LOG_TARGET,
								"WS ping/pong inactivity limit `{}` exceeded; closing connection",
								p.max_failures,
							);
							break Receive::ConnectionClosed;
						}
					}
				}

				stopped = s;
				futs = futures_util::future::select(rcv, inactivity_check.next());
			}
			// Server has been stopped.
			Either::Right(_) => break Receive::Stopped,
		}
	}
}
335
/// Why the read loop of a WebSocket connection ended without an error.
#[derive(Clone, Copy, Debug)]
pub(crate) enum Shutdown {
	/// The server was stopped; pending calls are given a chance to finish.
	Stopped,
	/// The connection was closed (by the peer, the sink, or the inactivity check).
	ConnectionClosed,
}
341
342/// Enforce a graceful shutdown.
343///
344/// This will return once the connection has been terminated or all pending calls have been executed.
345async fn graceful_shutdown<S>(
346	result: Result<Shutdown, SokettoError>,
347	pending_calls: mpsc::Receiver<()>,
348	ws_stream: S,
349	mut conn_tx: oneshot::Sender<()>,
350	send_task_handle: tokio::task::JoinHandle<()>,
351) where
352	S: StreamExt<Item = Result<Incoming, SokettoError>> + Unpin,
353{
354	let pending_calls = ReceiverStream::new(pending_calls);
355
356	if let Ok(Shutdown::Stopped) = result {
357		let graceful_shutdown = pending_calls.for_each(|_| async {});
358		let disconnect = ws_stream.try_for_each(|_| async { Ok(()) });
359
360		tokio::select! {
361			_ = graceful_shutdown => {}
362			res = disconnect => {
363				if let Err(err) = res {
364					tracing::warn!(target: LOG_TARGET, "Graceful shutdown terminated because of error: `{err}`");
365				}
366			}
367			_ = conn_tx.closed() => {}
368		}
369	}
370
371	// Send a message to close down the "send task".
372	_ = conn_tx.send(());
373	// Ensure that send task has been closed.
374	_ = send_task_handle.await;
375}
376
377/// Low-level API that tries to upgrade the HTTP connection to a WebSocket connection.
378///
379/// Returns `Ok((http_response, conn_fut))` if the WebSocket connection was successfully established
380/// otherwise `Err(http_response)`.
381///
382/// `conn_fut` is a future that drives the WebSocket connection
383/// and if it's dropped the connection will be closed.
384///
385/// Because this API depends on [`hyper`] the response needs to be sent
386/// to complete the HTTP request.
387///
388/// ```no_run
389/// use jsonrpsee_server::{ws, ServerConfig, Methods, ConnectionState, HttpRequest, HttpResponse};
390/// use jsonrpsee_server::middleware::rpc::{RpcServiceBuilder, RpcServiceT, RpcService, MethodResponse};
391/// use std::convert::Infallible;
392///
393/// async fn handle_websocket_conn<L>(
394///     req: HttpRequest,
395///     server_cfg: ServerConfig,
396///     methods: impl Into<Methods> + 'static,
397///     conn: ConnectionState,
398///     rpc_middleware: RpcServiceBuilder<L>,
399///     mut disconnect: tokio::sync::mpsc::Receiver<()>
400/// ) -> HttpResponse
401/// where
402///     L: tower::Layer<RpcService> + 'static,
403///     <L as tower::Layer<RpcService>>::Service: RpcServiceT<MethodResponse = MethodResponse, BatchResponse = MethodResponse, NotificationResponse = MethodResponse> + Send + Sync + 'static,
404/// {
405///   match ws::connect(req, server_cfg, methods, conn, rpc_middleware).await {
406///     Ok((rp, conn_fut)) => {
407///         tokio::spawn(async move {
408///             // Keep the connection alive until
409///             // a close signal is sent.
410///             tokio::select! {
411///                 _ = conn_fut => (),
412///                 _ = disconnect.recv() => (),
413///             }
414///         });
415///         rp
416///     }
417///     Err(rp) => rp,
418///   }
419/// }
420/// ```
421pub async fn connect<L, B>(
422	req: HttpRequest<B>,
423	server_cfg: ServerConfig,
424	methods: impl Into<Methods>,
425	conn: ConnectionState,
426	rpc_middleware: RpcServiceBuilder<L>,
427) -> Result<(HttpResponse, impl Future<Output = ()>), HttpResponse>
428where
429	L: tower::Layer<RpcService>,
430	<L as tower::Layer<RpcService>>::Service: RpcServiceT<
431			MethodResponse = MethodResponse,
432			BatchResponse = MethodResponse,
433			NotificationResponse = MethodResponse,
434		> + Send
435		+ Sync
436		+ 'static,
437{
438	let mut server = soketto::handshake::http::Server::new();
439
440	match server.receive_request(&req) {
441		Ok(response) => {
442			let (tx, rx) = mpsc::channel(server_cfg.message_buffer_capacity as usize);
443			let sink = MethodSink::new(tx);
444
445			// On each method call the `pending_calls` is cloned
446			// then when all pending_calls are dropped
447			// a graceful shutdown can has occur.
448			let (pending_calls, pending_calls_completed) = mpsc::channel::<()>(1);
449
450			let rpc_service_cfg = RpcServiceCfg::CallsAndSubscriptions {
451				bounded_subscriptions: BoundedSubscriptions::new(server_cfg.max_subscriptions_per_connection),
452				id_provider: server_cfg.id_provider.clone(),
453				sink: sink.clone(),
454				_pending_calls: pending_calls,
455			};
456
457			let rpc_service = RpcService::new(
458				methods.into(),
459				server_cfg.max_response_body_size as usize,
460				conn.conn_id.into(),
461				rpc_service_cfg,
462			);
463
464			let rpc_service = rpc_middleware.service(rpc_service);
465
466			// Note: This can't possibly be fulfilled until the HTTP response
467			// is returned below, so that's why it's a separate async block
468			let fut = async move {
469				let extensions = req.extensions().clone();
470
471				let upgraded = match hyper::upgrade::on(req).await {
472					Ok(upgraded) => upgraded,
473					Err(e) => {
474						tracing::debug!(target: LOG_TARGET, "WS upgrade handshake failed: {}", e);
475						return;
476					}
477				};
478
479				let io = TokioIo::new(upgraded);
480
481				let stream = BufReader::new(BufWriter::new(io.compat()));
482				let mut ws_builder = server.into_builder(stream);
483				ws_builder.set_max_message_size(server_cfg.max_response_body_size as usize);
484				let (sender, receiver) = ws_builder.finish();
485
486				let params = BackgroundTaskParams {
487					server_cfg,
488					conn,
489					ws_sender: sender,
490					ws_receiver: receiver,
491					rpc_service,
492					sink,
493					rx,
494					pending_calls_completed,
495					on_session_close: None,
496					extensions,
497				};
498
499				background_task(params).await;
500			};
501
502			Ok((response.map(|()| HttpBody::default()), fut))
503		}
504		Err(e) => {
505			tracing::debug!(target: LOG_TARGET, "WS upgrade handshake failed: {}", e);
506			Err(HttpResponse::new(HttpBody::from(format!("WS upgrade handshake failed: {e}"))))
507		}
508	}
509}