1use std::future::Future;
28use std::net::{SocketAddr, TcpListener as StdTcpListener};
29use std::pin::Pin;
30use std::sync::Arc;
31use std::sync::atomic::AtomicU32;
32use std::task::Poll;
33use std::time::Duration;
34
35use crate::future::{ConnectionGuard, ServerHandle, SessionClose, SessionClosedFuture, StopHandle, session_close};
36use crate::middleware::rpc::{RpcService, RpcServiceCfg};
37use crate::transport::ws::BackgroundTaskParams;
38use crate::transport::{http, ws};
39use crate::utils::deserialize_with_ext;
40use crate::{Extensions, HttpBody, HttpRequest, HttpResponse, LOG_TARGET};
41
42use futures_util::future::{self, Either, FutureExt};
43use futures_util::io::{BufReader, BufWriter};
44use hyper::body::Bytes;
45use hyper_util::rt::{TokioExecutor, TokioIo};
46use jsonrpsee_core::id_providers::RandomIntegerIdProvider;
47use jsonrpsee_core::middleware::{Batch, BatchEntry, BatchEntryErr, RpcServiceBuilder, RpcServiceT};
48use jsonrpsee_core::server::helpers::prepare_error;
49use jsonrpsee_core::server::{BoundedSubscriptions, ConnectionId, MethodResponse, MethodSink, Methods};
50use jsonrpsee_core::traits::IdProvider;
51use jsonrpsee_core::{BoxError, JsonRawValue, TEN_MB_SIZE_BYTES};
52use jsonrpsee_types::error::{
53 BATCHES_NOT_SUPPORTED_CODE, BATCHES_NOT_SUPPORTED_MSG, ErrorCode, reject_too_big_batch_request,
54};
55use jsonrpsee_types::{ErrorObject, Id};
56use soketto::handshake::http::is_upgrade_request;
57use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
58use tokio::sync::{OwnedSemaphorePermit, mpsc, watch};
59use tokio_util::compat::TokioAsyncReadCompatExt;
60use tower::layer::util::Identity;
61use tower::{Layer, Service};
62use tracing::{Instrument, instrument};
63
/// Default cap on concurrently served connections; used as the `max_connections`
/// value in `ServerConfigBuilder::default`.
const MAX_CONNECTIONS: u32 = 100;

/// Type parameter handed to the notification deserializers in `handle_rpc_call`
/// (presumably the notification's optional raw params, borrowed when possible — TODO confirm).
type Notif<'a> = Option<std::borrow::Cow<'a, JsonRawValue>>;
68
69pub struct Server<HttpMiddleware = Identity, RpcMiddleware = Identity> {
71 listener: TcpListener,
72 server_cfg: ServerConfig,
73 rpc_middleware: RpcServiceBuilder<RpcMiddleware>,
74 http_middleware: tower::ServiceBuilder<HttpMiddleware>,
75}
76
77impl Server<Identity, Identity> {
78 pub fn builder() -> Builder<Identity, Identity> {
80 Builder::new()
81 }
82}
83
84impl<RpcMiddleware, HttpMiddleware> std::fmt::Debug for Server<RpcMiddleware, HttpMiddleware> {
85 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86 f.debug_struct("Server").field("listener", &self.listener).field("server_cfg", &self.server_cfg).finish()
87 }
88}
89
90impl<RpcMiddleware, HttpMiddleware> Server<RpcMiddleware, HttpMiddleware> {
91 pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
93 self.listener.local_addr()
94 }
95}
96
impl<HttpMiddleware, RpcMiddleware, Body> Server<HttpMiddleware, RpcMiddleware>
where
    RpcMiddleware: Layer<RpcService> + Clone + Send + 'static,
    <RpcMiddleware as Layer<RpcService>>::Service: RpcServiceT,
    HttpMiddleware: Layer<TowerServiceNoHttp<RpcMiddleware>> + Send + 'static,
    <HttpMiddleware as Layer<TowerServiceNoHttp<RpcMiddleware>>>::Service:
        Service<HttpRequest, Response = HttpResponse<Body>, Error = BoxError> + Clone + Send,
    <<HttpMiddleware as Layer<TowerServiceNoHttp<RpcMiddleware>>>::Service as Service<HttpRequest>>::Future: Send,
    Body: http_body::Body<Data = Bytes> + Send + 'static,
    <Body as http_body::Body>::Error: Into<BoxError>,
{
    /// Starts serving `methods` in a background task and returns a handle that can stop it.
    ///
    /// The accept loop is spawned on the configured runtime handle when one was set
    /// (`custom_tokio_runtime`); otherwise it uses `tokio::spawn`, which requires the
    /// caller to be inside a Tokio runtime.
    pub fn start(mut self, methods: impl Into<Methods>) -> ServerHandle {
        let methods = methods.into();
        // The sender side stays in the returned `ServerHandle`; receivers observe shutdown.
        let (stop_tx, stop_rx) = watch::channel(());

        let stop_handle = StopHandle::new(stop_rx);

        // `take()` moves the handle out so `self` can then be moved into `start_inner`.
        match self.server_cfg.tokio_runtime.take() {
            Some(rt) => rt.spawn(self.start_inner(methods, stop_handle)),
            None => tokio::spawn(self.start_inner(methods, stop_handle)),
        };

        ServerHandle::new(stop_tx)
    }

    /// Accept loop: hands each accepted socket to `process_connection` until shutdown
    /// is signalled, then waits for every spawned connection task to finish.
    async fn start_inner(self, methods: Methods, stop_handle: StopHandle) {
        // Connection ids wrap around on overflow.
        let mut id: u32 = 0;
        let connection_guard = ConnectionGuard::new(self.server_cfg.max_connections as usize);
        let listener = self.listener;

        // A single pinned shutdown future is threaded through every `try_accept_conn`
        // call and handed back each iteration, so it is never re-created.
        let stopped = stop_handle.clone().shutdown();
        tokio::pin!(stopped);

        // Every connection task holds a clone of `drop_on_completion`; the receiver
        // yields `None` only once all clones are dropped.
        let (drop_on_completion, mut process_connection_awaiter) = mpsc::channel::<()>(1);

        loop {
            match try_accept_conn(&listener, stopped).await {
                AcceptConnection::Established { socket, remote_addr, stop } => {
                    process_connection(ProcessConnection {
                        http_middleware: &self.http_middleware,
                        rpc_middleware: self.rpc_middleware.clone(),
                        remote_addr,
                        methods: methods.clone(),
                        stop_handle: stop_handle.clone(),
                        conn_id: id,
                        server_cfg: self.server_cfg.clone(),
                        conn_guard: &connection_guard,
                        socket,
                        drop_on_completion: drop_on_completion.clone(),
                    });
                    id = id.wrapping_add(1);
                    stopped = stop;
                }
                AcceptConnection::Err((e, stop)) => {
                    // Accept errors are logged and the loop keeps accepting.
                    tracing::debug!(target: LOG_TARGET, "Error while awaiting a new connection: {:?}", e);
                    stopped = stop;
                }
                AcceptConnection::Shutdown => break,
            }
        }

        // Drop our own sender so the channel can close once all connection tasks have ended.
        drop(drop_on_completion);

        // Nothing is expected to be sent; loop until `None`, i.e. every connection task is done.
        while process_connection_awaiter.recv().await.is_some() {
        }
    }
}
171
/// Settings controlling how the server accepts and serves connections.
///
/// Built with [`ServerConfigBuilder`] (see `ServerConfig::builder`).
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Maximum request size in bytes: WebSocket max message size and HTTP request limit.
    pub(crate) max_request_body_size: u32,
    /// Maximum response size in bytes, passed to every `RpcService`.
    pub(crate) max_response_body_size: u32,
    /// Maximum number of concurrent connections (capacity of the `ConnectionGuard`).
    pub(crate) max_connections: u32,
    /// Maximum number of subscriptions per WebSocket connection.
    pub(crate) max_subscriptions_per_connection: u32,
    /// Whether batches are accepted and how many entries they may contain.
    pub(crate) batch_requests_config: BatchRequestConfig,
    /// Runtime to spawn the server on; `None` means `tokio::spawn` on the current runtime.
    pub(crate) tokio_runtime: Option<tokio::runtime::Handle>,
    /// Whether plain HTTP (non-upgrade) requests are served.
    pub(crate) enable_http: bool,
    /// Whether WebSocket upgrade requests are served.
    pub(crate) enable_ws: bool,
    /// Capacity of each WebSocket connection's outgoing message channel.
    pub(crate) message_buffer_capacity: u32,
    /// WebSocket ping settings; `None` disables pings (consumed by the WS transport — not shown here).
    pub(crate) ping_config: Option<PingConfig>,
    /// Id provider handed to WebSocket RPC services (presumably for subscription ids).
    pub(crate) id_provider: Arc<dyn IdProvider>,
    /// Value passed to `set_nodelay` on every accepted TCP socket.
    pub(crate) tcp_no_delay: bool,
    /// HTTP/2 keep-alive interval passed to hyper; `None` as-is.
    pub(crate) keep_alive: Option<std::time::Duration>,
    /// HTTP/2 keep-alive timeout passed to hyper.
    pub(crate) keep_alive_timeout: Duration,
}

/// Builder for [`ServerConfig`]; every field mirrors the one of the same name there.
#[derive(Debug, Clone)]
pub struct ServerConfigBuilder {
    max_request_body_size: u32,
    max_response_body_size: u32,
    max_connections: u32,
    max_subscriptions_per_connection: u32,
    batch_requests_config: BatchRequestConfig,
    tokio_runtime: Option<tokio::runtime::Handle>,
    enable_http: bool,
    enable_ws: bool,
    message_buffer_capacity: u32,
    ping_config: Option<PingConfig>,
    id_provider: Arc<dyn IdProvider>,
    tcp_no_delay: bool,
    keep_alive: Option<std::time::Duration>,
    keep_alive_timeout: std::time::Duration,
}
237
/// Builder for [`TowerService`], for embedding the server into an existing hyper/tower stack.
#[derive(Debug, Clone)]
pub struct TowerServiceBuilder<RpcMiddleware, HttpMiddleware> {
    /// Server configuration copied into every built service.
    pub(crate) server_cfg: ServerConfig,
    /// RPC-level middleware.
    pub(crate) rpc_middleware: RpcServiceBuilder<RpcMiddleware>,
    /// HTTP-level middleware.
    pub(crate) http_middleware: tower::ServiceBuilder<HttpMiddleware>,
    /// Shared connection-id counter; `Arc`, so clones of the builder share it.
    pub(crate) conn_id: Arc<AtomicU32>,
    /// Connection limiter shared by every built service.
    pub(crate) conn_guard: ConnectionGuard,
}
252
/// How JSON-RPC batch requests are handled (see `handle_rpc_call`).
#[derive(Debug, Copy, Clone)]
pub enum BatchRequestConfig {
    /// Batches are rejected with a "batches not supported" error.
    Disabled,
    /// Batches with more than this many entries are rejected.
    Limit(u32),
    /// No limit on batch length.
    Unlimited,
}
263
/// Per-connection state carried by the request handlers.
#[derive(Debug, Clone)]
pub struct ConnectionState {
    /// Handle used to observe server shutdown.
    pub(crate) stop_handle: StopHandle,
    /// Id of this connection.
    pub(crate) conn_id: u32,
    /// Connection-limit permit. It is shared via `Arc` so clones keep it alive; it is
    /// returned to the semaphore when the last clone is dropped.
    pub(crate) _conn_permit: Arc<OwnedSemaphorePermit>,
}

impl ConnectionState {
    /// Creates the state for a connection that holds `conn_permit`.
    pub fn new(stop_handle: StopHandle, conn_id: u32, conn_permit: OwnedSemaphorePermit) -> ConnectionState {
        Self { stop_handle, conn_id, _conn_permit: Arc::new(conn_permit) }
    }
}
282
/// WebSocket ping/pong keep-alive settings.
///
/// Enable them with `ServerConfigBuilder::enable_ws_ping`.
#[derive(Debug, Copy, Clone)]
pub struct PingConfig {
    /// Interval between outgoing pings.
    pub(crate) ping_interval: Duration,
    /// How long to wait for a pong before the client counts as inactive.
    pub(crate) inactive_limit: Duration,
    /// Number of inactivity periods tolerated; always at least 1.
    pub(crate) max_failures: usize,
}

impl Default for PingConfig {
    /// Pings every 30 s, a 40 s inactivity limit and 1 tolerated failure.
    fn default() -> Self {
        Self { ping_interval: Duration::from_secs(30), max_failures: 1, inactive_limit: Duration::from_secs(40) }
    }
}

impl PingConfig {
    /// Same as [`PingConfig::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the interval between outgoing pings.
    pub fn ping_interval(mut self, ping_interval: Duration) -> Self {
        self.ping_interval = ping_interval;
        self
    }

    /// Sets how long to wait for a pong before the client counts as inactive.
    ///
    /// How many such periods are tolerated is set by [`PingConfig::max_failures`].
    pub fn inactive_limit(mut self, inactivity_limit: Duration) -> Self {
        self.inactive_limit = inactivity_limit;
        self
    }

    /// Sets how many inactivity periods are tolerated.
    ///
    /// # Panics
    ///
    /// Panics if `max` is 0.
    pub fn max_failures(mut self, max: usize) -> Self {
        // Same message style as `ServerConfigBuilder::set_message_buffer_capacity`.
        assert!(max > 0, "max_failures must be set to > 0");
        self.max_failures = max;
        self
    }
}
345
346impl Default for ServerConfig {
347 fn default() -> Self {
348 ServerConfig::builder().build()
349 }
350}
351
352impl ServerConfig {
353 pub fn builder() -> ServerConfigBuilder {
355 ServerConfigBuilder::default()
356 }
357}
358
359impl Default for ServerConfigBuilder {
360 fn default() -> Self {
361 ServerConfigBuilder {
362 max_request_body_size: TEN_MB_SIZE_BYTES,
363 max_response_body_size: TEN_MB_SIZE_BYTES,
364 max_connections: MAX_CONNECTIONS,
365 max_subscriptions_per_connection: 1024,
366 batch_requests_config: BatchRequestConfig::Unlimited,
367 tokio_runtime: None,
368 enable_http: true,
369 enable_ws: true,
370 message_buffer_capacity: 1024,
371 ping_config: None,
372 id_provider: Arc::new(RandomIntegerIdProvider),
373 tcp_no_delay: true,
374 keep_alive: None,
375 keep_alive_timeout: Duration::from_secs(20),
377 }
378 }
379}
380
381impl ServerConfigBuilder {
382 pub fn new() -> Self {
384 Self::default()
385 }
386
387 pub fn max_request_body_size(mut self, size: u32) -> Self {
389 self.max_request_body_size = size;
390 self
391 }
392
393 pub fn max_response_body_size(mut self, size: u32) -> Self {
395 self.max_response_body_size = size;
396 self
397 }
398
399 pub fn max_connections(mut self, max: u32) -> Self {
401 self.max_connections = max;
402 self
403 }
404
405 pub fn max_subscriptions_per_connection(mut self, max: u32) -> Self {
407 self.max_subscriptions_per_connection = max;
408 self
409 }
410
411 pub fn set_batch_request_config(mut self, cfg: BatchRequestConfig) -> Self {
416 self.batch_requests_config = cfg;
417 self
418 }
419
420 pub fn custom_tokio_runtime(mut self, rt: tokio::runtime::Handle) -> Self {
424 self.tokio_runtime = Some(rt);
425 self
426 }
427
428 pub fn http_only(mut self) -> Self {
432 self.enable_http = true;
433 self.enable_ws = false;
434 self
435 }
436
437 pub fn ws_only(mut self) -> Self {
443 self.enable_http = false;
444 self.enable_ws = true;
445 self
446 }
447
448 pub fn set_message_buffer_capacity(mut self, c: u32) -> Self {
467 assert!(c > 0, "buffer capacity must be set to > 0");
468 self.message_buffer_capacity = c;
469 self
470 }
471
472 pub fn enable_ws_ping(mut self, config: PingConfig) -> Self {
487 self.ping_config = Some(config);
488 self
489 }
490
491 pub fn disable_ws_ping(mut self) -> Self {
495 self.ping_config = None;
496 self
497 }
498
499 pub fn set_id_provider<I: IdProvider + 'static>(mut self, id_provider: I) -> Self {
520 self.id_provider = Arc::new(id_provider);
521 self
522 }
523
524 pub fn set_tcp_no_delay(mut self, no_delay: bool) -> Self {
528 self.tcp_no_delay = no_delay;
529 self
530 }
531
532 pub fn set_keep_alive(mut self, keep_alive: Option<std::time::Duration>) -> Self {
534 self.keep_alive = keep_alive;
535 self
536 }
537
538 pub fn set_keep_alive_timeout(mut self, keep_alive_timeout: Duration) -> Self {
540 self.keep_alive_timeout = keep_alive_timeout;
541 self
542 }
543
544 pub fn build(self) -> ServerConfig {
546 ServerConfig {
547 max_request_body_size: self.max_request_body_size,
548 max_response_body_size: self.max_response_body_size,
549 max_connections: self.max_connections,
550 max_subscriptions_per_connection: self.max_subscriptions_per_connection,
551 batch_requests_config: self.batch_requests_config,
552 tokio_runtime: self.tokio_runtime,
553 enable_http: self.enable_http,
554 enable_ws: self.enable_ws,
555 message_buffer_capacity: self.message_buffer_capacity,
556 ping_config: self.ping_config,
557 id_provider: self.id_provider,
558 tcp_no_delay: self.tcp_no_delay,
559 keep_alive: self.keep_alive,
560 keep_alive_timeout: self.keep_alive_timeout,
561 }
562 }
563}
564
565#[derive(Debug)]
567pub struct Builder<HttpMiddleware, RpcMiddleware> {
568 server_cfg: ServerConfig,
569 rpc_middleware: RpcServiceBuilder<RpcMiddleware>,
570 http_middleware: tower::ServiceBuilder<HttpMiddleware>,
571}
572
573impl Default for Builder<Identity, Identity> {
574 fn default() -> Self {
575 Builder {
576 server_cfg: ServerConfig::default(),
577 rpc_middleware: RpcServiceBuilder::new(),
578 http_middleware: tower::ServiceBuilder::new(),
579 }
580 }
581}
582
583impl Builder<Identity, Identity> {
584 pub fn new() -> Self {
586 Self::default()
587 }
588
589 pub fn with_config(config: ServerConfig) -> Self {
591 Self { server_cfg: config, ..Default::default() }
592 }
593}
594
595impl<RpcMiddleware, HttpMiddleware> TowerServiceBuilder<RpcMiddleware, HttpMiddleware> {
596 pub fn build(
598 self,
599 methods: impl Into<Methods>,
600 stop_handle: StopHandle,
601 ) -> TowerService<RpcMiddleware, HttpMiddleware> {
602 let conn_id = self.conn_id.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
603
604 let rpc_middleware = TowerServiceNoHttp {
605 rpc_middleware: self.rpc_middleware,
606 inner: ServiceData {
607 methods: methods.into(),
608 stop_handle,
609 conn_id,
610 conn_guard: self.conn_guard,
611 server_cfg: self.server_cfg,
612 },
613 on_session_close: None,
614 };
615
616 TowerService { rpc_middleware, http_middleware: self.http_middleware }
617 }
618
619 pub fn connection_id(mut self, id: u32) -> Self {
623 self.conn_id = Arc::new(AtomicU32::new(id));
624 self
625 }
626
627 pub fn max_connections(mut self, limit: u32) -> Self {
629 self.conn_guard = ConnectionGuard::new(limit as usize);
630 self
631 }
632
633 pub fn set_rpc_middleware<T>(self, rpc_middleware: RpcServiceBuilder<T>) -> TowerServiceBuilder<T, HttpMiddleware> {
635 TowerServiceBuilder {
636 server_cfg: self.server_cfg,
637 rpc_middleware,
638 http_middleware: self.http_middleware,
639 conn_id: self.conn_id,
640 conn_guard: self.conn_guard,
641 }
642 }
643
644 pub fn set_http_middleware<T>(
646 self,
647 http_middleware: tower::ServiceBuilder<T>,
648 ) -> TowerServiceBuilder<RpcMiddleware, T> {
649 TowerServiceBuilder {
650 server_cfg: self.server_cfg,
651 rpc_middleware: self.rpc_middleware,
652 http_middleware,
653 conn_id: self.conn_id,
654 conn_guard: self.conn_guard,
655 }
656 }
657}
658
659impl<HttpMiddleware, RpcMiddleware> Builder<HttpMiddleware, RpcMiddleware> {
660 pub fn set_config(mut self, cfg: ServerConfig) -> Self {
662 self.server_cfg = cfg;
663 self
664 }
665
666 pub fn set_rpc_middleware<T>(self, rpc_middleware: RpcServiceBuilder<T>) -> Builder<HttpMiddleware, T> {
733 Builder { server_cfg: self.server_cfg, rpc_middleware, http_middleware: self.http_middleware }
734 }
735
736 pub fn set_http_middleware<T>(self, http_middleware: tower::ServiceBuilder<T>) -> Builder<T, RpcMiddleware> {
759 Builder { server_cfg: self.server_cfg, http_middleware, rpc_middleware: self.rpc_middleware }
760 }
761
762 pub fn to_service_builder(self) -> TowerServiceBuilder<RpcMiddleware, HttpMiddleware> {
838 let max_conns = self.server_cfg.max_connections as usize;
839
840 TowerServiceBuilder {
841 server_cfg: self.server_cfg,
842 rpc_middleware: self.rpc_middleware,
843 http_middleware: self.http_middleware,
844 conn_id: Arc::new(AtomicU32::new(0)),
845 conn_guard: ConnectionGuard::new(max_conns),
846 }
847 }
848
849 pub async fn build(self, addrs: impl ToSocketAddrs) -> std::io::Result<Server<HttpMiddleware, RpcMiddleware>> {
866 let listener = TcpListener::bind(addrs).await?;
867
868 Ok(Server {
869 listener,
870 server_cfg: self.server_cfg,
871 rpc_middleware: self.rpc_middleware,
872 http_middleware: self.http_middleware,
873 })
874 }
875
876 pub fn build_from_tcp(
900 self,
901 listener: impl Into<StdTcpListener>,
902 ) -> std::io::Result<Server<HttpMiddleware, RpcMiddleware>> {
903 let listener = TcpListener::from_std(listener.into())?;
904
905 Ok(Server {
906 listener,
907 server_cfg: self.server_cfg,
908 rpc_middleware: self.rpc_middleware,
909 http_middleware: self.http_middleware,
910 })
911 }
912}
913
/// Data shared by every request handled by one connection's service.
#[derive(Debug, Clone)]
struct ServiceData {
    /// Registered RPC methods.
    methods: Methods,
    /// Handle used to observe server shutdown.
    stop_handle: StopHandle,
    /// Id of the connection this service serves.
    conn_id: u32,
    /// Connection limiter; a permit is acquired per `call`.
    conn_guard: ConnectionGuard,
    /// Server configuration.
    server_cfg: ServerConfig,
}

/// Tower service combining the HTTP middleware with the inner RPC service.
///
/// On each call the HTTP middleware is layered over a clone of `rpc_middleware`.
#[derive(Debug, Clone)]
pub struct TowerService<RpcMiddleware, HttpMiddleware> {
    rpc_middleware: TowerServiceNoHttp<RpcMiddleware>,
    http_middleware: tower::ServiceBuilder<HttpMiddleware>,
}
938
939impl<RpcMiddleware, HttpMiddleware> TowerService<RpcMiddleware, HttpMiddleware> {
940 pub fn on_session_closed(&mut self) -> SessionClosedFuture {
946 if let Some(n) = self.rpc_middleware.on_session_close.as_mut() {
947 n.closed()
949 } else {
950 let (session_close, fut) = session_close();
951 self.rpc_middleware.on_session_close = Some(session_close);
952 fut
953 }
954 }
955}
956
957impl<RequestBody, ResponseBody, RpcMiddleware, HttpMiddleware> Service<HttpRequest<RequestBody>> for TowerService<RpcMiddleware, HttpMiddleware>
958where
959 RpcMiddleware: Layer<RpcService> + Clone,
960 <RpcMiddleware as Layer<RpcService>>::Service: RpcServiceT + 'static,
961 HttpMiddleware: Layer<TowerServiceNoHttp<RpcMiddleware>> + Send + 'static,
962 <HttpMiddleware as Layer<TowerServiceNoHttp<RpcMiddleware>>>::Service:
963 Service<HttpRequest<RequestBody>, Response = HttpResponse<ResponseBody>, Error = BoxError> + Send,
964 <<HttpMiddleware as Layer<TowerServiceNoHttp<RpcMiddleware>>>::Service as Service<HttpRequest<RequestBody>>>::Future:
965 Send + 'static,
966 RequestBody: http_body::Body<Data = Bytes> + Send + 'static,
967 <RequestBody as http_body::Body>::Error: Into<BoxError>,
968{
969 type Response = HttpResponse<ResponseBody>;
970 type Error = BoxError;
971 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
972
973 fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
974 Poll::Ready(Ok(()))
975 }
976
977 fn call(&mut self, request: HttpRequest<RequestBody>) -> Self::Future {
978 Box::pin(self.http_middleware.service(self.rpc_middleware.clone()).call(request))
979 }
980}
981
/// The RPC tower service without HTTP middleware: dispatches a request either to the
/// WebSocket upgrade path or to the plain HTTP handler.
#[derive(Debug, Clone)]
pub struct TowerServiceNoHttp<L> {
    inner: ServiceData,
    rpc_middleware: RpcServiceBuilder<L>,
    /// Optional session-close notifier; taken by the first `call` that observes it.
    on_session_close: Option<SessionClose>,
}
992
impl<Body, RpcMiddleware> Service<HttpRequest<Body>> for TowerServiceNoHttp<RpcMiddleware>
where
    RpcMiddleware: Layer<RpcService>,
    <RpcMiddleware as Layer<RpcService>>::Service: RpcServiceT<
        MethodResponse = MethodResponse,
        BatchResponse = MethodResponse,
        NotificationResponse = MethodResponse,
    > + Send
    + Sync
    + 'static,
    Body: http_body::Body<Data = Bytes> + Send + 'static,
    <Body as http_body::Body>::Error: Into<BoxError>,
{
    type Response = HttpResponse;

    type Error = BoxError;

    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Always ready; the connection limit is enforced inside `call`.
    fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Handles one request.
    ///
    /// - No connection permit left: responds with `http::response::too_many_requests()`.
    /// - WS enabled and upgrade request: performs the handshake and spawns the WS background task.
    /// - HTTP enabled and not an upgrade: runs the call through `http::call_with_service`.
    /// - Otherwise: responds with `http::response::denied()`.
    fn call(&mut self, request: HttpRequest<Body>) -> Self::Future {
        let mut request = request.map(HttpBody::new);

        let conn_guard = &self.inner.conn_guard;
        let stop_handle = self.inner.stop_handle.clone();
        let conn_id = self.inner.conn_id;
        // Taken so only the first call that sees the notifier hands it on.
        let on_session_close = self.on_session_close.take();

        tracing::trace!(target: LOG_TARGET, "{:?}", request);

        let Some(conn_permit) = conn_guard.try_acquire() else {
            return async move { Ok(http::response::too_many_requests()) }.boxed();
        };

        // `conn` owns the permit; it is released when `conn` (and all clones) are dropped.
        let conn = ConnectionState::new(stop_handle.clone(), conn_id, conn_permit);

        let max_conns = conn_guard.max_connections();
        let curr_conns = max_conns - conn_guard.available_connections();
        tracing::debug!(target: LOG_TARGET, "Accepting new connection {}/{}", curr_conns, max_conns);

        // Expose the guard and connection id to downstream handlers via request extensions.
        let req_ext = request.extensions_mut();
        req_ext.insert::<ConnectionGuard>(conn_guard.clone());
        req_ext.insert::<ConnectionId>(conn.conn_id.into());

        let is_upgrade_request = is_upgrade_request(&request);

        if self.inner.server_cfg.enable_ws && is_upgrade_request {
            let this = self.inner.clone();

            let mut server = soketto::handshake::http::Server::new();

            let response = match server.receive_request(&request) {
                Ok(response) => {
                    // Outgoing message channel for this WebSocket connection.
                    let (tx, rx) = mpsc::channel(this.server_cfg.message_buffer_capacity as usize);
                    let sink = MethodSink::new(tx);

                    // `pending_calls` lives in the RPC service config and `pending_calls_completed`
                    // goes to the background task (presumably to wait for in-flight calls — TODO confirm).
                    let (pending_calls, pending_calls_completed) = mpsc::channel::<()>(1);

                    let cfg = RpcServiceCfg::CallsAndSubscriptions {
                        bounded_subscriptions: BoundedSubscriptions::new(
                            this.server_cfg.max_subscriptions_per_connection,
                        ),
                        id_provider: this.server_cfg.id_provider.clone(),
                        sink: sink.clone(),
                        _pending_calls: pending_calls,
                    };

                    let rpc_service = RpcService::new(
                        this.methods.clone(),
                        this.server_cfg.max_response_body_size as usize,
                        this.conn_id.into(),
                        cfg,
                    );

                    let rpc_service = self.rpc_middleware.service(rpc_service);

                    // The upgrade completes only after the handshake response below has
                    // been sent, so the socket is taken over in a separate task.
                    tokio::spawn(
                        async move {
                            let extensions = request.extensions().clone();

                            let upgraded = match hyper::upgrade::on(request).await {
                                Ok(u) => u,
                                Err(e) => {
                                    tracing::debug!(target: LOG_TARGET, "Could not upgrade connection: {}", e);
                                    return;
                                }
                            };

                            let io = TokioIo::new(upgraded);

                            let stream = BufReader::new(BufWriter::new(io.compat()));
                            let mut ws_builder = server.into_builder(stream);
                            ws_builder.set_max_message_size(this.server_cfg.max_request_body_size as usize);
                            let (sender, receiver) = ws_builder.finish();

                            let params = BackgroundTaskParams {
                                server_cfg: this.server_cfg,
                                conn,
                                ws_sender: sender,
                                ws_receiver: receiver,
                                rpc_service,
                                sink,
                                rx,
                                pending_calls_completed,
                                on_session_close,
                                extensions,
                            };

                            ws::background_task(params).await;
                        }
                        .in_current_span(),
                    );

                    response.map(|()| HttpBody::empty())
                }
                Err(e) => {
                    tracing::debug!(target: LOG_TARGET, "Could not upgrade connection: {}", e);
                    // NOTE(review): `HttpResponse::new` uses the default status, so a failed
                    // handshake is reported with a success status plus an error body —
                    // consider an explicit 4xx status.
                    HttpResponse::new(HttpBody::from(format!("Could not upgrade connection: {e}")))
                }
            };

            async { Ok(response) }.boxed()
        } else if self.inner.server_cfg.enable_http && !is_upgrade_request {
            let this = &self.inner;
            let max_response_size = this.server_cfg.max_response_body_size;
            let max_request_size = this.server_cfg.max_request_body_size;
            let methods = this.methods.clone();
            let batch_config = this.server_cfg.batch_requests_config;

            // Plain HTTP: calls only, no subscriptions.
            let rpc_service = self.rpc_middleware.service(RpcService::new(
                methods,
                max_response_size as usize,
                this.conn_id.into(),
                RpcServiceCfg::OnlyCalls,
            ));

            Box::pin(async move {
                let rp = http::call_with_service(request, batch_config, max_request_size, rpc_service).await;
                // Release the connection permit once the response is ready.
                drop(conn);
                Ok(rp)
            })
        } else {
            // Transport disabled for this request kind.
            Box::pin(async { Ok(http::response::denied()) })
        }
    }
}
1151
/// Everything `process_connection` needs to serve one accepted TCP connection.
struct ProcessConnection<'a, HttpMiddleware, RpcMiddleware> {
    /// HTTP middleware, borrowed from the `Server`.
    http_middleware: &'a tower::ServiceBuilder<HttpMiddleware>,
    rpc_middleware: RpcServiceBuilder<RpcMiddleware>,
    /// Server-wide connection limiter, borrowed from the accept loop.
    conn_guard: &'a ConnectionGuard,
    conn_id: u32,
    server_cfg: ServerConfig,
    stop_handle: StopHandle,
    /// The accepted socket.
    socket: TcpStream,
    /// Dropped when the connection task ends; lets the accept loop wait for all connections.
    drop_on_completion: mpsc::Sender<()>,
    /// Peer address (only used for the tracing span).
    remote_addr: SocketAddr,
    methods: Methods,
}
1164
1165#[instrument(name = "connection", skip_all, fields(remote_addr = %params.remote_addr, conn_id = %params.conn_id), level = "INFO")]
1166fn process_connection<'a, RpcMiddleware, HttpMiddleware, Body>(params: ProcessConnection<HttpMiddleware, RpcMiddleware>)
1167where
1168 HttpMiddleware: Layer<TowerServiceNoHttp<RpcMiddleware>> + Send + 'static,
1169 <HttpMiddleware as Layer<TowerServiceNoHttp<RpcMiddleware>>>::Service:
1170 Service<HttpRequest, Response = HttpResponse<Body>, Error = BoxError> + Clone + Send + 'static,
1171 <<HttpMiddleware as Layer<TowerServiceNoHttp<RpcMiddleware>>>::Service as Service<HttpRequest>>::Future:
1172 Send + 'static,
1173 Body: http_body::Body<Data = Bytes> + Send + 'static,
1174 <Body as http_body::Body>::Error: Into<BoxError>,
1175{
1176 let ProcessConnection {
1177 http_middleware,
1178 rpc_middleware,
1179 conn_guard,
1180 conn_id,
1181 server_cfg,
1182 socket,
1183 stop_handle,
1184 drop_on_completion,
1185 methods,
1186 ..
1187 } = params;
1188
1189 if let Err(e) = socket.set_nodelay(server_cfg.tcp_no_delay) {
1190 tracing::warn!(target: LOG_TARGET, "Could not set NODELAY on socket: {:?}", e);
1191 return;
1192 }
1193
1194 let keep_alive = server_cfg.keep_alive;
1195 let keep_alive_timeout = server_cfg.keep_alive_timeout;
1196
1197 let tower_service = TowerServiceNoHttp {
1198 inner: ServiceData {
1199 server_cfg,
1200 methods,
1201 stop_handle: stop_handle.clone(),
1202 conn_id,
1203 conn_guard: conn_guard.clone(),
1204 },
1205 rpc_middleware,
1206 on_session_close: None,
1207 };
1208
1209 let service = http_middleware.service(tower_service);
1210
1211 tokio::spawn(async move {
1212 let service = crate::utils::TowerToHyperService::new(service);
1214 let io = TokioIo::new(socket);
1215 let mut builder = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
1216
1217 builder.http2().keep_alive_interval(keep_alive).keep_alive_timeout(keep_alive_timeout);
1219
1220 let conn = builder.serve_connection_with_upgrades(io, service);
1221 let stopped = stop_handle.shutdown();
1222
1223 tokio::pin!(stopped, conn);
1224
1225 let res = match future::select(conn, stopped).await {
1226 Either::Left((conn, _)) => conn,
1227 Either::Right((_, mut conn)) => {
1228 conn.as_mut().graceful_shutdown();
1231 conn.await
1232 }
1233 };
1234
1235 if let Err(e) = res {
1236 tracing::debug!(target: LOG_TARGET, "HTTP serve connection failed {:?}", e);
1237 }
1238 drop(drop_on_completion)
1239 });
1240}
1241
1242enum AcceptConnection<S> {
1243 Shutdown,
1244 Established { socket: TcpStream, remote_addr: SocketAddr, stop: S },
1245 Err((std::io::Error, S)),
1246}
1247
1248async fn try_accept_conn<S>(listener: &TcpListener, stopped: S) -> AcceptConnection<S>
1249where
1250 S: Future + Unpin,
1251{
1252 let accept = listener.accept();
1253 tokio::pin!(accept);
1254
1255 match futures_util::future::select(accept, stopped).await {
1256 Either::Left((res, stop)) => match res {
1257 Ok((socket, remote_addr)) => AcceptConnection::Established { socket, remote_addr, stop },
1258 Err(e) => AcceptConnection::Err((e, stop)),
1259 },
1260 Either::Right(_) => AcceptConnection::Shutdown,
1261 }
1262}
1263
1264pub(crate) async fn handle_rpc_call<S>(
1265 body: &[u8],
1266 is_single: bool,
1267 batch_config: BatchRequestConfig,
1268 rpc_service: &S,
1269 extensions: Extensions,
1270) -> MethodResponse
1271where
1272 S: RpcServiceT<
1273 MethodResponse = MethodResponse,
1274 BatchResponse = MethodResponse,
1275 NotificationResponse = MethodResponse,
1276 > + Send,
1277{
1278 if is_single {
1280 if let Ok(req) = deserialize_with_ext::call::from_slice(body, &extensions) {
1281 rpc_service.call(req).await
1282 } else if let Ok(notif) = deserialize_with_ext::notif::from_slice::<Notif>(body, &extensions) {
1283 rpc_service.notification(notif).await
1284 } else {
1285 let (id, code) = prepare_error(body);
1286 MethodResponse::error(id, ErrorObject::from(code))
1287 }
1288 }
1289 else {
1291 let max_len = match batch_config {
1292 BatchRequestConfig::Disabled => {
1293 let rp = MethodResponse::error(
1294 Id::Null,
1295 ErrorObject::borrowed(BATCHES_NOT_SUPPORTED_CODE, BATCHES_NOT_SUPPORTED_MSG, None),
1296 );
1297 return rp;
1298 }
1299 BatchRequestConfig::Limit(limit) => limit as usize,
1300 BatchRequestConfig::Unlimited => usize::MAX,
1301 };
1302
1303 if let Ok(unchecked_batch) = serde_json::from_slice::<Vec<&JsonRawValue>>(body) {
1304 if unchecked_batch.len() > max_len {
1305 return MethodResponse::error(Id::Null, reject_too_big_batch_request(max_len));
1306 }
1307
1308 let mut batch = Vec::with_capacity(unchecked_batch.len());
1309
1310 for call in unchecked_batch {
1311 if let Ok(req) = deserialize_with_ext::call::from_str(call.get(), &extensions) {
1312 batch.push(Ok(BatchEntry::Call(req)));
1313 } else if let Ok(notif) = deserialize_with_ext::notif::from_str::<Notif>(call.get(), &extensions) {
1314 batch.push(Ok(BatchEntry::Notification(notif)));
1315 } else {
1316 let id = match serde_json::from_str::<jsonrpsee_types::InvalidRequest>(call.get()) {
1317 Ok(err) => err.id,
1318 Err(_) => Id::Null,
1319 };
1320
1321 batch.push(Err(BatchEntryErr::new(id, ErrorCode::InvalidRequest.into())));
1322 }
1323 }
1324
1325 rpc_service.batch(Batch::from(batch)).await
1326 } else {
1327 MethodResponse::error(Id::Null, ErrorObject::from(ErrorCode::ParseError))
1328 }
1329 }
1330}