jsonrpsee_ws_server/
server.rs

1// Copyright 2019-2021 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any
4// person obtaining a copy of this software and associated
5// documentation files (the "Software"), to deal in the
6// Software without restriction, including without
7// limitation the rights to use, copy, modify, merge,
8// publish, distribute, sublicense, and/or sell copies of
9// the Software, and to permit persons to whom the Software
10// is furnished to do so, subject to the following
11// conditions:
12//
13// The above copyright notice and this permission notice
14// shall be included in all copies or substantial portions
15// of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
18// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
19// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
20// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
21// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
22// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
23// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
24// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
25// DEALINGS IN THE SOFTWARE.
26
27use std::future::Future;
28use std::net::SocketAddr;
29use std::pin::Pin;
30use std::sync::Arc;
31use std::task::{Context, Poll};
32use std::time::Duration;
33
34use crate::future::{FutureDriver, ServerHandle, StopMonitor};
35use crate::types::error::{ErrorCode, ErrorObject, BATCHES_NOT_SUPPORTED_CODE, BATCHES_NOT_SUPPORTED_MSG};
36use crate::types::{Id, Request};
37use futures_channel::mpsc;
38use futures_util::future::{Either, FutureExt};
39use futures_util::io::{BufReader, BufWriter};
40use futures_util::stream::StreamExt;
41use futures_util::TryStreamExt;
42use http::header::{HOST, ORIGIN};
43use http::{HeaderMap, HeaderValue};
44use jsonrpsee_core::id_providers::RandomIntegerIdProvider;
45use jsonrpsee_core::middleware::{self, WsMiddleware as Middleware};
46use jsonrpsee_core::server::access_control::AccessControl;
47use jsonrpsee_core::server::helpers::{
48	prepare_error, BatchResponse, BatchResponseBuilder, BoundedSubscriptions, MethodResponse, MethodSink,
49};
50use jsonrpsee_core::server::resource_limiting::Resources;
51use jsonrpsee_core::server::rpc_module::{ConnState, ConnectionId, MethodKind, Methods};
52use jsonrpsee_core::tracing::{rx_log_from_json, rx_log_from_str, tx_log_from_str, RpcTracing};
53use jsonrpsee_core::traits::IdProvider;
54use jsonrpsee_core::{Error, TEN_MB_SIZE_BYTES};
55use jsonrpsee_types::error::{reject_too_big_request, reject_too_many_subscriptions};
56use jsonrpsee_types::Params;
57use soketto::connection::Error as SokettoError;
58use soketto::data::ByteSlice125;
59use soketto::handshake::WebSocketKey;
60use soketto::handshake::{server::Response, Server as SokettoServer};
61use soketto::Sender;
62use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
63use tokio_stream::wrappers::IntervalStream;
64use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};
65use tracing_futures::Instrument;
66
/// Default maximum number of concurrent connections, used by `Settings::default`.
const MAX_CONNECTIONS: u64 = 100;

/// A WebSocket JSON RPC server.
///
/// Created by [`Builder::build`]; call [`Server::start`] to begin accepting connections.
pub struct Server<M> {
	/// Bound TCP listener that incoming connections are accepted from.
	listener: TcpListener,
	/// Server configuration taken over from the builder.
	cfg: Settings,
	/// Stop signal shared with every connection; see [`Server::server_handle`].
	stop_monitor: StopMonitor,
	/// Resource kinds registered on the builder; used to initialize the methods in `start`.
	resources: Resources,
	/// Middleware notified on connect, request, call, response and disconnect.
	middleware: M,
	/// Provider of subscription IDs.
	id_provider: Arc<dyn IdProvider>,
}
79
80impl<M> std::fmt::Debug for Server<M> {
81	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82		f.debug_struct("Server")
83			.field("listener", &self.listener)
84			.field("cfg", &self.cfg)
85			.field("stop_monitor", &self.stop_monitor)
86			.field("id_provider", &self.id_provider)
87			.field("resources", &self.resources)
88			.finish()
89	}
90}
91
92impl<M: Middleware> Server<M> {
93	/// Returns socket address to which the server is bound.
94	pub fn local_addr(&self) -> Result<SocketAddr, Error> {
95		self.listener.local_addr().map_err(Into::into)
96	}
97
98	/// Returns the handle to stop the running server.
99	pub fn server_handle(&self) -> ServerHandle {
100		self.stop_monitor.handle()
101	}
102
103	/// Start responding to connections requests. This will run on the tokio runtime until the server is stopped.
104	pub fn start(mut self, methods: impl Into<Methods>) -> Result<ServerHandle, Error> {
105		let methods = methods.into().initialize_resources(&self.resources)?;
106		let handle = self.server_handle();
107
108		match self.cfg.tokio_runtime.take() {
109			Some(rt) => rt.spawn(self.start_inner(methods)),
110			None => tokio::spawn(self.start_inner(methods)),
111		};
112
113		Ok(handle)
114	}
115
116	async fn start_inner(self, methods: Methods) {
117		let stop_monitor = self.stop_monitor;
118		let resources = self.resources;
119		let middleware = self.middleware;
120
121		let mut id = 0;
122		let mut connections = FutureDriver::default();
123		let mut incoming = Monitored::new(Incoming(self.listener), &stop_monitor);
124
125		loop {
126			match connections.select_with(&mut incoming).await {
127				Ok((socket, _addr)) => {
128					if let Err(e) = socket.set_nodelay(true) {
129						tracing::error!("Could not set NODELAY on socket: {:?}", e);
130						continue;
131					}
132
133					if connections.count() >= self.cfg.max_connections as usize {
134						tracing::warn!("Too many connections. Try again in a while.");
135						connections.add(Box::pin(handshake(socket, HandshakeResponse::Reject { status_code: 429 })));
136						continue;
137					}
138
139					let methods = &methods;
140					let cfg = &self.cfg;
141					let id_provider = self.id_provider.clone();
142
143					connections.add(Box::pin(handshake(
144						socket,
145						HandshakeResponse::Accept {
146							conn_id: id,
147							methods,
148							resources: &resources,
149							cfg,
150							stop_monitor: &stop_monitor,
151							middleware: middleware.clone(),
152							id_provider,
153						},
154					)));
155
156					tracing::info!("Accepting new connection {}/{}", connections.count(), self.cfg.max_connections);
157
158					id = id.wrapping_add(1);
159				}
160				Err(MonitoredError::Selector(err)) => {
161					tracing::error!("Error while awaiting a new connection: {:?}", err);
162				}
163				Err(MonitoredError::Shutdown) => break,
164			}
165		}
166
167		connections.await
168	}
169}
170
171/// This is a glorified select listening for new messages, while also checking the `stop_receiver` signal.
172struct Monitored<'a, F> {
173	future: F,
174	stop_monitor: &'a StopMonitor,
175}
176
177impl<'a, F> Monitored<'a, F> {
178	fn new(future: F, stop_monitor: &'a StopMonitor) -> Self {
179		Monitored { future, stop_monitor }
180	}
181}
182
/// Error returned by polling a [`Monitored`] future.
enum MonitoredError<E> {
	/// The server's stop signal was observed before the wrapped future completed.
	Shutdown,
	/// The wrapped future completed with the error `E`.
	Selector(E),
}
187
/// Newtype over the server's `TcpListener`; polling a `Monitored<Incoming>` accepts the next connection.
struct Incoming(TcpListener);
189
190impl<'a> Future for Monitored<'a, Incoming> {
191	type Output = Result<(TcpStream, SocketAddr), MonitoredError<std::io::Error>>;
192
193	fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
194		let this = Pin::into_inner(self);
195
196		if this.stop_monitor.shutdown_requested() {
197			return Poll::Ready(Err(MonitoredError::Shutdown));
198		}
199
200		this.future.0.poll_accept(cx).map_err(MonitoredError::Selector)
201	}
202}
203
204impl<'a, 'f, F, T, E> Future for Monitored<'a, Pin<&'f mut F>>
205where
206	F: Future<Output = Result<T, E>>,
207{
208	type Output = Result<T, MonitoredError<E>>;
209
210	fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
211		let this = Pin::into_inner(self);
212
213		if this.stop_monitor.shutdown_requested() {
214			return Poll::Ready(Err(MonitoredError::Shutdown));
215		}
216
217		this.future.poll_unpin(cx).map_err(MonitoredError::Selector)
218	}
219}
220
/// What `handshake` should do with a freshly accepted TCP socket.
enum HandshakeResponse<'a, M> {
	/// Answer the WebSocket handshake with an HTTP rejection carrying `status_code`, then
	/// close the connection without reading the request.
	Reject {
		status_code: u16,
	},
	/// Validate the handshake request and, on success, serve the connection with
	/// `background_task` using the given per-connection state.
	Accept {
		conn_id: ConnectionId,
		methods: &'a Methods,
		resources: &'a Resources,
		cfg: &'a Settings,
		stop_monitor: &'a StopMonitor,
		middleware: M,
		id_provider: Arc<dyn IdProvider>,
	},
}
235
236async fn handshake<M: Middleware>(socket: tokio::net::TcpStream, mode: HandshakeResponse<'_, M>) -> Result<(), Error> {
237	let remote_addr = socket.peer_addr()?;
238
239	// For each incoming background_task we perform a handshake.
240	let mut server = SokettoServer::new(BufReader::new(BufWriter::new(socket.compat())));
241
242	match mode {
243		HandshakeResponse::Reject { status_code } => {
244			// Forced rejection, don't need to read anything from the socket
245			let reject = Response::Reject { status_code };
246			server.send_response(&reject).await?;
247
248			let (mut sender, _) = server.into_builder().finish();
249
250			// Gracefully shut down the connection
251			sender.close().await?;
252
253			Ok(())
254		}
255		HandshakeResponse::Accept { conn_id, methods, resources, cfg, stop_monitor, middleware, id_provider } => {
256			tracing::debug!("Accepting new connection: {}", conn_id);
257
258			let key_and_headers = get_key_and_headers(&mut server, cfg).await;
259
260			match key_and_headers {
261				Ok((key, headers)) => {
262					middleware.on_connect(remote_addr, &headers);
263					let accept = Response::Accept { key, protocol: None };
264					server.send_response(&accept).await?;
265				}
266				Err(err) => {
267					tracing::warn!("Rejected connection: {:?}", err);
268					let reject = Response::Reject { status_code: 403 };
269					server.send_response(&reject).await?;
270
271					return Err(err);
272				}
273			};
274
275			let join_result = tokio::spawn(background_task(BackgroundTask {
276				server,
277				conn_id,
278				methods: methods.clone(),
279				resources: resources.clone(),
280				max_request_body_size: cfg.max_request_body_size,
281				max_response_body_size: cfg.max_response_body_size,
282				max_log_length: cfg.max_log_length,
283				batch_requests_supported: cfg.batch_requests_supported,
284				bounded_subscriptions: BoundedSubscriptions::new(cfg.max_subscriptions_per_connection),
285				stop_server: stop_monitor.clone(),
286				middleware,
287				id_provider,
288				ping_interval: cfg.ping_interval,
289				remote_addr,
290			}))
291			.await;
292
293			match join_result {
294				Err(_) => Err(Error::Custom("Background task was aborted".into())),
295				Ok(result) => result,
296			}
297		}
298	}
299}
300
/// Everything `background_task` needs to serve one accepted WebSocket connection.
struct BackgroundTask<'a, M> {
	/// Soketto server that completed the handshake; split into a sender/receiver pair by `background_task`.
	server: SokettoServer<'a, BufReader<BufWriter<Compat<tokio::net::TcpStream>>>>,
	/// Identifier of this connection.
	conn_id: ConnectionId,
	/// Registered RPC methods.
	methods: Methods,
	/// Resource limits consulted when methods are executed.
	resources: Resources,
	/// Maximum incoming message size in bytes (applied with `set_max_message_size`).
	max_request_body_size: u32,
	/// Maximum response size in bytes.
	max_response_body_size: u32,
	/// Maximum length of logged requests and responses.
	max_log_length: u32,
	/// When `false`, batch requests are answered with a "batches not supported" error.
	batch_requests_supported: bool,
	/// Per-connection subscription limiter.
	bounded_subscriptions: BoundedSubscriptions,
	/// Server-wide stop signal.
	stop_server: StopMonitor,
	/// Middleware for this connection.
	middleware: M,
	/// Provider of subscription IDs.
	id_provider: Arc<dyn IdProvider>,
	/// Interval between `Ping` frames sent to the peer.
	ping_interval: Duration,
	/// Peer address, reported to the middleware on disconnect.
	remote_addr: SocketAddr,
}
317
318async fn background_task<M: Middleware>(input: BackgroundTask<'_, M>) -> Result<(), Error> {
319	let BackgroundTask {
320		server,
321		conn_id,
322		methods,
323		resources,
324		max_request_body_size,
325		max_response_body_size,
326		max_log_length,
327		batch_requests_supported,
328		bounded_subscriptions,
329		stop_server,
330		middleware,
331		id_provider,
332		ping_interval,
333		remote_addr,
334	} = input;
335
336	// And we can finally transition to a websocket background_task.
337	let mut builder = server.into_builder();
338	builder.set_max_message_size(max_request_body_size as usize);
339	let (mut sender, mut receiver) = builder.finish();
340	let (tx, mut rx) = mpsc::unbounded::<String>();
341	let bounded_subscriptions2 = bounded_subscriptions.clone();
342
343	let stop_server2 = stop_server.clone();
344	let sink = MethodSink::new_with_limit(tx, max_response_body_size, max_log_length);
345
346	// Send results back to the client.
347	tokio::spawn(async move {
348		// Received messages from the WebSocket.
349		let mut rx_item = rx.next();
350
351		// Interval to send out continuously `pings`.
352		let ping_interval = IntervalStream::new(tokio::time::interval(ping_interval));
353		tokio::pin!(ping_interval);
354		let mut next_ping = ping_interval.next();
355
356		while !stop_server2.shutdown_requested() {
357			// Ensure select is cancel-safe by fetching and storing the `rx_item` that did not finish yet.
358			// Note: Although, this is cancel-safe already, avoid using `select!` macro for future proofing.
359			match futures_util::future::select(rx_item, next_ping).await {
360				Either::Left((Some(response), ping)) => {
361					// If websocket message send fail then terminate the connection.
362					if let Err(err) = send_ws_message(&mut sender, response).await {
363						tracing::warn!("WS send error: {}; terminate connection", err);
364						break;
365					}
366					rx_item = rx.next();
367					next_ping = ping;
368				}
369				// Nothing else to receive.
370				Either::Left((None, _)) => break,
371
372				// Handle timer intervals.
373				Either::Right((_, next_rx)) => {
374					if let Err(err) = send_ws_ping(&mut sender).await {
375						tracing::warn!("WS send ping error: {}; terminate connection", err);
376						break;
377					}
378					rx_item = next_rx;
379					next_ping = ping_interval.next();
380				}
381			}
382		}
383
384		// Terminate connection and send close message.
385		let _ = sender.close().await;
386
387		// Notify all listeners and close down associated tasks.
388		bounded_subscriptions2.close();
389	});
390
391	// Buffer for incoming data.
392	let mut data = Vec::with_capacity(100);
393	let mut method_executors = FutureDriver::default();
394	let middleware = &middleware;
395
396	let result = loop {
397		data.clear();
398
399		{
400			// Need the extra scope to drop this pinned future and reclaim access to `data`
401			let receive = async {
402				// Identical loop to `soketto::receive_data` with debug logs for `Pong` frames.
403				loop {
404					match receiver.receive(&mut data).await? {
405						soketto::Incoming::Data(d) => break Ok(d),
406						soketto::Incoming::Pong(_) => tracing::debug!("recv pong"),
407						soketto::Incoming::Closed(_) => {
408							// The closing reason is already logged by `soketto` trace log level.
409							// Return the `Closed` error to avoid logging unnecessary warnings on clean shutdown.
410							break Err(SokettoError::Closed);
411						}
412					}
413				}
414			};
415
416			tokio::pin!(receive);
417
418			if let Err(err) = method_executors.select_with(Monitored::new(receive, &stop_server)).await {
419				match err {
420					MonitoredError::Selector(SokettoError::Closed) => {
421						tracing::debug!("WS transport: remote peer terminated the connection: {}", conn_id);
422						sink.close();
423						break Ok(());
424					}
425					MonitoredError::Selector(SokettoError::MessageTooLarge { current, maximum }) => {
426						tracing::warn!(
427							"WS transport error: outgoing message is too big error ({} bytes, max is {})",
428							current,
429							maximum
430						);
431						sink.send_error(Id::Null, reject_too_big_request(max_request_body_size));
432						continue;
433					}
434					// These errors can not be gracefully handled, so just log them and terminate the connection.
435					MonitoredError::Selector(err) => {
436						tracing::debug!("WS error: {}; terminate connection {}", err, conn_id);
437						sink.close();
438						break Err(err.into());
439					}
440					MonitoredError::Shutdown => break Ok(()),
441				};
442			};
443		};
444
445		let request_start = middleware.on_request();
446
447		let first_non_whitespace = data.iter().find(|byte| !byte.is_ascii_whitespace());
448		match first_non_whitespace {
449			Some(b'{') => {
450				let data = std::mem::take(&mut data);
451				let sink = sink.clone();
452				let resources = &resources;
453				let methods = &methods;
454				let bounded_subscriptions = bounded_subscriptions.clone();
455				let id_provider = &*id_provider;
456
457				let fut = async move {
458					let call = CallData {
459						conn_id,
460						resources,
461						max_response_body_size,
462						max_log_length,
463						methods,
464						bounded_subscriptions,
465						sink: &sink,
466						id_provider: &*id_provider,
467						middleware,
468						request_start,
469					};
470
471					match process_single_request(data, call).await {
472						MethodResult::JustMiddleware(r) => {
473							middleware.on_response(&r.result, request_start);
474						}
475						MethodResult::SendAndMiddleware(r) => {
476							middleware.on_response(&r.result, request_start);
477							let _ = sink.send_raw(r.result);
478						}
479					};
480				}
481				.boxed();
482
483				method_executors.add(fut);
484			}
485			Some(b'[') if !batch_requests_supported => {
486				let response = MethodResponse::error(
487					Id::Null,
488					ErrorObject::borrowed(BATCHES_NOT_SUPPORTED_CODE, &BATCHES_NOT_SUPPORTED_MSG, None),
489				);
490				middleware.on_response(&response.result, request_start);
491				let _ = sink.send_raw(response.result);
492			}
493			Some(b'[') => {
494				// Make sure the following variables are not moved into async closure below.
495				let resources = &resources;
496				let methods = &methods;
497				let bounded_subscriptions = bounded_subscriptions.clone();
498				let sink = sink.clone();
499				let id_provider = id_provider.clone();
500				let data = std::mem::take(&mut data);
501
502				let fut = async move {
503					let response = process_batch_request(Batch {
504						data,
505						call: CallData {
506							conn_id,
507							resources,
508							max_response_body_size,
509							max_log_length,
510							methods,
511							bounded_subscriptions,
512							sink: &sink,
513							id_provider: &*id_provider,
514							middleware,
515							request_start,
516						},
517					})
518					.await;
519
520					tx_log_from_str(&response.result, max_log_length);
521					middleware.on_response(&response.result, request_start);
522					let _ = sink.send_raw(response.result);
523				};
524
525				method_executors.add(Box::pin(fut));
526			}
527			_ => {
528				sink.send_error(Id::Null, ErrorCode::ParseError.into());
529			}
530		}
531	};
532
533	middleware.on_disconnect(remote_addr);
534
535	// Drive all running methods to completion.
536	// **NOTE** Do not return early in this function. This `await` needs to run to guarantee
537	// proper drop behaviour.
538	method_executors.await;
539
540	result
541}
542
/// JSON-RPC WebSocket server settings.
#[derive(Debug, Clone)]
struct Settings {
	/// Maximum size in bytes of a request; also used as the maximum incoming WebSocket message size.
	max_request_body_size: u32,
	/// Maximum size in bytes of a response.
	max_response_body_size: u32,
	/// Maximum number of incoming connections allowed; further connections are rejected
	/// with HTTP status 429.
	max_connections: u64,
	/// Maximum number of subscriptions per connection.
	max_subscriptions_per_connection: u32,
	/// Maximum length of logged requests and responses.
	///
	/// Logs bigger than this limit will be truncated.
	max_log_length: u32,
	/// Access control based on HTTP headers (presumably enforced by `get_key_and_headers`
	/// during the handshake — TODO confirm).
	access_control: AccessControl,
	/// Whether batch requests are supported by this server or not.
	batch_requests_supported: bool,
	/// Custom tokio runtime to run the server on; `None` means `tokio::spawn` is used.
	tokio_runtime: Option<tokio::runtime::Handle>,
	/// The interval at which `Ping` frames are submitted.
	ping_interval: Duration,
}
567
568impl Default for Settings {
569	fn default() -> Self {
570		Self {
571			max_request_body_size: TEN_MB_SIZE_BYTES,
572			max_response_body_size: TEN_MB_SIZE_BYTES,
573			max_log_length: 4096,
574			max_subscriptions_per_connection: 1024,
575			max_connections: MAX_CONNECTIONS,
576			batch_requests_supported: true,
577			access_control: AccessControl::default(),
578			tokio_runtime: None,
579			ping_interval: Duration::from_secs(60),
580		}
581	}
582}
583
/// Builder to configure and create a JSON-RPC WebSocket server.
///
/// `M` is the middleware type; it is `()` until [`Builder::set_middleware`] is called.
#[derive(Debug)]
pub struct Builder<M = ()> {
	/// Server settings, moved into the built `Server`.
	settings: Settings,
	/// Resource kinds registered with `register_resource`.
	resources: Resources,
	/// Middleware handed to the built `Server`.
	middleware: M,
	/// Subscription ID provider; see `set_id_provider`.
	id_provider: Arc<dyn IdProvider>,
}
592
593impl Default for Builder {
594	fn default() -> Self {
595		Builder {
596			settings: Settings::default(),
597			resources: Resources::default(),
598			middleware: (),
599			id_provider: Arc::new(RandomIntegerIdProvider),
600		}
601	}
602}
603
604impl Builder {
605	/// Create a default server builder.
606	pub fn new() -> Self {
607		Self::default()
608	}
609}
610
611impl<M> Builder<M> {
612	/// Set the maximum size of a request body in bytes. Default is 10 MiB.
613	pub fn max_request_body_size(mut self, size: u32) -> Self {
614		self.settings.max_request_body_size = size;
615		self
616	}
617
618	/// Set the maximum size of a response body in bytes. Default is 10 MiB.
619	pub fn max_response_body_size(mut self, size: u32) -> Self {
620		self.settings.max_response_body_size = size;
621		self
622	}
623
624	/// Set the maximum number of connections allowed. Default is 100.
625	pub fn max_connections(mut self, max: u64) -> Self {
626		self.settings.max_connections = max;
627		self
628	}
629
630	/// Enables or disables support of [batch requests](https://www.jsonrpc.org/specification#batch).
631	/// By default, support is enabled.
632	pub fn batch_requests_supported(mut self, supported: bool) -> Self {
633		self.settings.batch_requests_supported = supported;
634		self
635	}
636
637	/// Set the maximum number of connections allowed. Default is 1024.
638	pub fn max_subscriptions_per_connection(mut self, max: u32) -> Self {
639		self.settings.max_subscriptions_per_connection = max;
640		self
641	}
642
643	/// Register a new resource kind. Errors if `label` is already registered, or if the number of
644	/// registered resources on this server instance would exceed 8.
645	///
646	/// See the module documentation for [`resurce_limiting`](../jsonrpsee_utils/server/resource_limiting/index.html#resource-limiting)
647	/// for details.
648	pub fn register_resource(mut self, label: &'static str, capacity: u16, default: u16) -> Result<Self, Error> {
649		self.resources.register(label, capacity, default)?;
650		Ok(self)
651	}
652
653	/// Add a middleware to the builder [`Middleware`](../jsonrpsee_core/middleware/trait.Middleware.html).
654	///
655	/// ```
656	/// use std::{time::Instant, net::SocketAddr};
657	///
658	/// use jsonrpsee_core::middleware::{WsMiddleware, Headers, MethodKind, Params};
659	/// use jsonrpsee_ws_server::WsServerBuilder;
660	///
661	/// #[derive(Clone)]
662	/// struct MyMiddleware;
663	///
664	/// impl WsMiddleware for MyMiddleware {
665	///     type Instant = Instant;
666	///
667	///     fn on_connect(&self, remote_addr: SocketAddr, headers: &Headers) {
668	///          println!("[MyMiddleware::on_call] remote_addr: {}, headers: {:?}", remote_addr, headers);
669	///     }
670	///
671	///     fn on_request(&self) -> Self::Instant {
672	///          Instant::now()
673	///     }
674	///
675	///     fn on_call(&self, method_name: &str, params: Params, kind: MethodKind) {
676	///          println!("[MyMiddleware::on_call] method: '{}' params: {:?}, kind: {:?}", method_name, params, kind);
677	///     }
678	///
679	///     fn on_result(&self, method_name: &str, success: bool, started_at: Self::Instant) {
680	///          println!("[MyMiddleware::on_result] '{}', worked? {}, time elapsed {:?}", method_name, success, started_at.elapsed());
681	///     }
682	///
683	///     fn on_response(&self, result: &str, started_at: Self::Instant) {
684	///          println!("[MyMiddleware::on_response] result: {}, time elapsed {:?}", result, started_at.elapsed());
685	///     }
686	///
687	///     fn on_disconnect(&self, remote_addr: SocketAddr) {
688	///          println!("[MyMiddleware::on_disconnect] remote_addr: {}", remote_addr);
689	///     }
690	/// }
691	///
692	/// let builder = WsServerBuilder::new().set_middleware(MyMiddleware);
693	/// ```
694	pub fn set_middleware<T: Middleware>(self, middleware: T) -> Builder<T> {
695		Builder { settings: self.settings, resources: self.resources, middleware, id_provider: self.id_provider }
696	}
697
698	/// Configure a custom [`tokio::runtime::Handle`] to run the server on.
699	///
700	/// Default: [`tokio::spawn`]
701	pub fn custom_tokio_runtime(mut self, rt: tokio::runtime::Handle) -> Self {
702		self.settings.tokio_runtime = Some(rt);
703		self
704	}
705
706	/// Configure the interval at which pings are submitted.
707	///
708	/// This option is used to keep the connection alive, and is just submitting `Ping` frames,
709	/// without making any assumptions about when a `Pong` frame should be received.
710	///
711	/// Default: 60 seconds.
712	///
713	/// # Examples
714	///
715	/// ```rust
716	/// use std::time::Duration;
717	/// use jsonrpsee_ws_server::WsServerBuilder;
718	///
719	/// // Set the ping interval to 10 seconds.
720	/// let builder = WsServerBuilder::default().ping_interval(Duration::from_secs(10));
721	/// ```
722	pub fn ping_interval(mut self, interval: Duration) -> Self {
723		self.settings.ping_interval = interval;
724		self
725	}
726
727	/// Configure custom `subscription ID` provider for the server to use
728	/// to when getting new subscription calls.
729	///
730	/// You may choose static dispatch or dynamic dispatch because
731	/// `IdProvider` is implemented for `Box<T>`.
732	///
733	/// Default: [`RandomIntegerIdProvider`].
734	///
735	/// # Examples
736	///
737	/// ```rust
738	/// use jsonrpsee_ws_server::{WsServerBuilder, RandomStringIdProvider, IdProvider};
739	///
740	/// // static dispatch
741	/// let builder1 = WsServerBuilder::default().set_id_provider(RandomStringIdProvider::new(16));
742	///
743	/// // or dynamic dispatch
744	/// let builder2 = WsServerBuilder::default().set_id_provider(Box::new(RandomStringIdProvider::new(16)));
745	/// ```
746	///
747	pub fn set_id_provider<I: IdProvider + 'static>(mut self, id_provider: I) -> Self {
748		self.id_provider = Arc::new(id_provider);
749		self
750	}
751
752	/// Sets access control settings.
753	pub fn set_access_control(mut self, acl: AccessControl) -> Self {
754		self.settings.access_control = acl;
755		self
756	}
757
758	/// Finalize the configuration of the server. Consumes the [`Builder`].
759	///
760	/// ```rust
761	/// #[tokio::main]
762	/// async fn main() {
763	///   let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
764	///   let occupied_addr = listener.local_addr().unwrap();
765	///   let addrs: &[std::net::SocketAddr] = &[
766	///       occupied_addr,
767	///       "127.0.0.1:0".parse().unwrap(),
768	///   ];
769	///   assert!(jsonrpsee_ws_server::WsServerBuilder::default().build(occupied_addr).await.is_err());
770	///   assert!(jsonrpsee_ws_server::WsServerBuilder::default().build(addrs).await.is_ok());
771	/// }
772	/// ```
773	///
774	pub async fn build(self, addrs: impl ToSocketAddrs) -> Result<Server<M>, Error> {
775		let listener = TcpListener::bind(addrs).await?;
776		let stop_monitor = StopMonitor::new();
777		let resources = self.resources;
778		Ok(Server {
779			listener,
780			cfg: self.settings,
781			stop_monitor,
782			resources,
783			middleware: self.middleware,
784			id_provider: self.id_provider,
785		})
786	}
787}
788
789async fn send_ws_message(
790	sender: &mut Sender<BufReader<BufWriter<Compat<TcpStream>>>>,
791	response: String,
792) -> Result<(), Error> {
793	sender.send_text_owned(response).await?;
794	sender.flush().await.map_err(Into::into)
795}
796
797async fn send_ws_ping(sender: &mut Sender<BufReader<BufWriter<Compat<TcpStream>>>>) -> Result<(), Error> {
798	tracing::debug!("send ping");
799	// Submit empty slice as "optional" parameter.
800	let slice: &[u8] = &[];
801	// Byte slice fails if the provided slice is larger than 125 bytes.
802	let byte_slice = ByteSlice125::try_from(slice).expect("Empty slice should fit into ByteSlice125");
803	sender.send_ping(byte_slice).await?;
804	sender.flush().await.map_err(Into::into)
805}
806
/// A raw batch request together with the call context shared by its entries.
#[derive(Debug, Clone)]
struct Batch<'a, M: Middleware> {
	/// Undecoded request bytes, expected to contain a JSON array of requests.
	data: Vec<u8>,
	/// Call context, cloned for every request in the batch.
	call: CallData<'a, M>,
}
812
/// Per-request context used by both the single-request and the batch execution paths.
#[derive(Debug, Clone)]
struct CallData<'a, M: Middleware> {
	/// Identifier of the connection the request arrived on.
	conn_id: usize,
	/// Per-connection subscription limiter.
	bounded_subscriptions: BoundedSubscriptions,
	/// Provider of IDs for new subscriptions.
	id_provider: &'a dyn IdProvider,
	/// Middleware notified about the call.
	middleware: &'a M,
	/// Registered RPC methods.
	methods: &'a Methods,
	/// Maximum response size in bytes.
	max_response_body_size: u32,
	/// Maximum length of logged requests and responses.
	max_log_length: u32,
	/// Resource limits claimed while a method runs.
	resources: &'a Resources,
	/// Outgoing message sink of the connection (handed to subscription callbacks).
	sink: &'a MethodSink,
	/// Value returned by the middleware's `on_request` for this request.
	request_start: M::Instant,
}
826
/// A single decoded method call, ready for `execute_call`.
#[derive(Debug, Clone)]
struct Call<'a, M: Middleware> {
	/// Parameters of the call.
	params: Params<'a>,
	/// Method name as sent by the client.
	name: &'a str,
	/// Context the call executes in.
	call: CallData<'a, M>,
	/// Request ID, used when building the response.
	id: Id<'a>,
}
834
/// Result of executing a call; the variant tells the caller what remains to be done with the response.
enum MethodResult {
	/// Only report the response to the middleware; it must not be sent again (used for
	/// subscriptions, which presumably answer through the sink they were given).
	JustMiddleware(MethodResponse),
	/// Report the response to the middleware and send it to the client.
	SendAndMiddleware(MethodResponse),
}
839
840impl MethodResult {
841	fn as_inner(&self) -> &MethodResponse {
842		match &self {
843			Self::JustMiddleware(r) => r,
844			Self::SendAndMiddleware(r) => r,
845		}
846	}
847}
848
849// Batch responses must be sent back as a single message so we read the results from each
850// request in the batch and read the results off of a new channel, `rx_batch`, and then send the
851// complete batch response back to the client over `tx`.
852async fn process_batch_request<M>(b: Batch<'_, M>) -> BatchResponse
853where
854	M: Middleware,
855{
856	let Batch { data, call } = b;
857
858	if let Ok(batch) = serde_json::from_slice::<Vec<Request>>(&data) {
859		return if !batch.is_empty() {
860			let batch = batch.into_iter().map(|req| Ok((req, call.clone())));
861			let batch_stream = futures_util::stream::iter(batch);
862
863			let trace = RpcTracing::batch();
864
865			return async {
866				let max_response_size = call.max_response_body_size;
867
868				let batch_response = batch_stream
869					.try_fold(
870						BatchResponseBuilder::new_with_limit(max_response_size as usize),
871						|batch_response, (req, call)| async move {
872							let params = Params::new(req.params.map(|params| params.get()));
873							let response = execute_call(Call { name: &req.method, params, id: req.id, call }).await;
874							batch_response.append(response.as_inner())
875						},
876					)
877					.await;
878
879				match batch_response {
880					Ok(batch) => batch.finish(),
881					Err(batch_err) => batch_err,
882				}
883			}
884			.instrument(trace.into_span())
885			.await;
886		} else {
887			BatchResponse::error(Id::Null, ErrorObject::from(ErrorCode::InvalidRequest))
888		};
889	}
890
891	let (id, code) = prepare_error(&data);
892	BatchResponse::error(id, ErrorObject::from(code))
893}
894
895async fn process_single_request<M: Middleware>(data: Vec<u8>, call: CallData<'_, M>) -> MethodResult {
896	if let Ok(req) = serde_json::from_slice::<Request>(&data) {
897		let trace = RpcTracing::method_call(&req.method);
898
899		async {
900			rx_log_from_json(&req, call.max_log_length);
901
902			let params = Params::new(req.params.map(|params| params.get()));
903			let name = &req.method;
904			let id = req.id;
905
906			execute_call(Call { name, params, id, call }).await
907		}
908		.instrument(trace.into_span())
909		.await
910	} else {
911		let (id, code) = prepare_error(&data);
912		MethodResult::SendAndMiddleware(MethodResponse::error(id, ErrorObject::from(code)))
913	}
914}
915
916/// Execute a call which returns result of the call with a additional sink
917/// to fire a signal once the subscription call has been answered.
918///
919/// Returns `(MethodResponse, None)` on every call that isn't a subscription
920/// Otherwise `(MethodResponse, Some(PendingSubscriptionCallTx)`.
921async fn execute_call<M: Middleware>(c: Call<'_, M>) -> MethodResult {
922	let Call { name, id, params, call } = c;
923	let CallData {
924		resources,
925		methods,
926		middleware,
927		max_response_body_size,
928		max_log_length,
929		conn_id,
930		bounded_subscriptions,
931		id_provider,
932		sink,
933		request_start,
934	} = call;
935
936	let response = match methods.method_with_name(name) {
937		None => {
938			middleware.on_call(name, params.clone(), middleware::MethodKind::Unknown);
939			let response = MethodResponse::error(id, ErrorObject::from(ErrorCode::MethodNotFound));
940			MethodResult::SendAndMiddleware(response)
941		}
942		Some((name, method)) => match &method.inner() {
943			MethodKind::Sync(callback) => {
944				middleware.on_call(name, params.clone(), middleware::MethodKind::MethodCall);
945
946				match method.claim(name, resources) {
947					Ok(guard) => {
948						let r = (callback)(id, params, max_response_body_size as usize);
949						drop(guard);
950						MethodResult::SendAndMiddleware(r)
951					}
952					Err(err) => {
953						tracing::error!("[Methods::execute_with_resources] failed to lock resources: {:?}", err);
954						let response = MethodResponse::error(id, ErrorObject::from(ErrorCode::ServerIsBusy));
955						MethodResult::SendAndMiddleware(response)
956					}
957				}
958			}
959			MethodKind::Async(callback) => {
960				middleware.on_call(name, params.clone(), middleware::MethodKind::MethodCall);
961
962				match method.claim(name, resources) {
963					Ok(guard) => {
964						let id = id.into_owned();
965						let params = params.into_owned();
966
967						let response =
968							(callback)(id, params, conn_id, max_response_body_size as usize, Some(guard)).await;
969						MethodResult::SendAndMiddleware(response)
970					}
971					Err(err) => {
972						tracing::error!("[Methods::execute_with_resources] failed to lock resources: {:?}", err);
973						let response = MethodResponse::error(id, ErrorObject::from(ErrorCode::ServerIsBusy));
974						MethodResult::SendAndMiddleware(response)
975					}
976				}
977			}
978			MethodKind::Subscription(callback) => {
979				middleware.on_call(name, params.clone(), middleware::MethodKind::Subscription);
980
981				match method.claim(name, resources) {
982					Ok(guard) => {
983						if let Some(cn) = bounded_subscriptions.acquire() {
984							let conn_state = ConnState { conn_id, close_notify: cn, id_provider };
985							let response = callback(id.clone(), params, sink.clone(), conn_state, Some(guard)).await;
986							MethodResult::JustMiddleware(response)
987						} else {
988							let response =
989								MethodResponse::error(id, reject_too_many_subscriptions(bounded_subscriptions.max()));
990							MethodResult::SendAndMiddleware(response)
991						}
992					}
993					Err(err) => {
994						tracing::error!("[Methods::execute_with_resources] failed to lock resources: {:?}", err);
995						let response = MethodResponse::error(id, ErrorObject::from(ErrorCode::ServerIsBusy));
996						MethodResult::SendAndMiddleware(response)
997					}
998				}
999			}
1000			MethodKind::Unsubscription(callback) => {
1001				middleware.on_call(name, params.clone(), middleware::MethodKind::Unsubscription);
1002
1003				// Don't adhere to any resource or subscription limits; always let unsubscribing happen!
1004				let result = callback(id, params, conn_id, max_response_body_size as usize);
1005				MethodResult::SendAndMiddleware(result)
1006			}
1007		},
1008	};
1009
1010	let r = response.as_inner();
1011
1012	rx_log_from_str(&r.result, max_log_length);
1013	middleware.on_result(name, r.success, request_start);
1014	response
1015}
1016
1017/// Helper to fetch the `WebSocketKey` and `Headers` from the WebSocket handshake.
1018async fn get_key_and_headers(
1019	server: &mut SokettoServer<'_, BufReader<BufWriter<Compat<TcpStream>>>>,
1020	cfg: &Settings,
1021) -> Result<(WebSocketKey, HeaderMap), Error> {
1022	let req = server.receive_request().await?;
1023
1024	tracing::trace!("Connection request: {:?}", req);
1025
1026	let host = std::str::from_utf8(req.headers().host).map_err(|e| Error::HttpHeaderRejected("Host", e.to_string()))?;
1027
1028	let origin = req.headers().origin.and_then(|h| {
1029		let res = std::str::from_utf8(h).ok();
1030		if res.is_none() {
1031			tracing::warn!("Origin header invalid UTF-8; treated as no Origin header");
1032		}
1033		res
1034	});
1035
1036	let host_check = cfg.access_control.verify_host(host);
1037	let origin_check = cfg.access_control.verify_origin(origin, host);
1038
1039	let mut headers = HeaderMap::new();
1040
1041	host_check.and(origin_check).map(|()| {
1042		let key = req.key();
1043
1044		if let Ok(val) = HeaderValue::from_str(host) {
1045			headers.insert(HOST, val);
1046		}
1047
1048		if let Some(Ok(val)) = origin.map(HeaderValue::from_str) {
1049			headers.insert(ORIGIN, val);
1050		}
1051
1052		(key, headers)
1053	})
1054}