1use std::future::Future;
28use std::net::SocketAddr;
29use std::pin::Pin;
30use std::sync::Arc;
31use std::task::{Context, Poll};
32use std::time::Duration;
33
34use crate::future::{FutureDriver, ServerHandle, StopMonitor};
35use crate::types::error::{ErrorCode, ErrorObject, BATCHES_NOT_SUPPORTED_CODE, BATCHES_NOT_SUPPORTED_MSG};
36use crate::types::{Id, Request};
37use futures_channel::mpsc;
38use futures_util::future::{Either, FutureExt};
39use futures_util::io::{BufReader, BufWriter};
40use futures_util::stream::StreamExt;
41use futures_util::TryStreamExt;
42use http::header::{HOST, ORIGIN};
43use http::{HeaderMap, HeaderValue};
44use jsonrpsee_core::id_providers::RandomIntegerIdProvider;
45use jsonrpsee_core::middleware::{self, WsMiddleware as Middleware};
46use jsonrpsee_core::server::access_control::AccessControl;
47use jsonrpsee_core::server::helpers::{
48 prepare_error, BatchResponse, BatchResponseBuilder, BoundedSubscriptions, MethodResponse, MethodSink,
49};
50use jsonrpsee_core::server::resource_limiting::Resources;
51use jsonrpsee_core::server::rpc_module::{ConnState, ConnectionId, MethodKind, Methods};
52use jsonrpsee_core::tracing::{rx_log_from_json, rx_log_from_str, tx_log_from_str, RpcTracing};
53use jsonrpsee_core::traits::IdProvider;
54use jsonrpsee_core::{Error, TEN_MB_SIZE_BYTES};
55use jsonrpsee_types::error::{reject_too_big_request, reject_too_many_subscriptions};
56use jsonrpsee_types::Params;
57use soketto::connection::Error as SokettoError;
58use soketto::data::ByteSlice125;
59use soketto::handshake::WebSocketKey;
60use soketto::handshake::{server::Response, Server as SokettoServer};
61use soketto::Sender;
62use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
63use tokio_stream::wrappers::IntervalStream;
64use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};
65use tracing_futures::Instrument;
66
67const MAX_CONNECTIONS: u64 = 100;
69
70pub struct Server<M> {
72 listener: TcpListener,
73 cfg: Settings,
74 stop_monitor: StopMonitor,
75 resources: Resources,
76 middleware: M,
77 id_provider: Arc<dyn IdProvider>,
78}
79
80impl<M> std::fmt::Debug for Server<M> {
81 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82 f.debug_struct("Server")
83 .field("listener", &self.listener)
84 .field("cfg", &self.cfg)
85 .field("stop_monitor", &self.stop_monitor)
86 .field("id_provider", &self.id_provider)
87 .field("resources", &self.resources)
88 .finish()
89 }
90}
91
92impl<M: Middleware> Server<M> {
93 pub fn local_addr(&self) -> Result<SocketAddr, Error> {
95 self.listener.local_addr().map_err(Into::into)
96 }
97
98 pub fn server_handle(&self) -> ServerHandle {
100 self.stop_monitor.handle()
101 }
102
103 pub fn start(mut self, methods: impl Into<Methods>) -> Result<ServerHandle, Error> {
105 let methods = methods.into().initialize_resources(&self.resources)?;
106 let handle = self.server_handle();
107
108 match self.cfg.tokio_runtime.take() {
109 Some(rt) => rt.spawn(self.start_inner(methods)),
110 None => tokio::spawn(self.start_inner(methods)),
111 };
112
113 Ok(handle)
114 }
115
116 async fn start_inner(self, methods: Methods) {
117 let stop_monitor = self.stop_monitor;
118 let resources = self.resources;
119 let middleware = self.middleware;
120
121 let mut id = 0;
122 let mut connections = FutureDriver::default();
123 let mut incoming = Monitored::new(Incoming(self.listener), &stop_monitor);
124
125 loop {
126 match connections.select_with(&mut incoming).await {
127 Ok((socket, _addr)) => {
128 if let Err(e) = socket.set_nodelay(true) {
129 tracing::error!("Could not set NODELAY on socket: {:?}", e);
130 continue;
131 }
132
133 if connections.count() >= self.cfg.max_connections as usize {
134 tracing::warn!("Too many connections. Try again in a while.");
135 connections.add(Box::pin(handshake(socket, HandshakeResponse::Reject { status_code: 429 })));
136 continue;
137 }
138
139 let methods = &methods;
140 let cfg = &self.cfg;
141 let id_provider = self.id_provider.clone();
142
143 connections.add(Box::pin(handshake(
144 socket,
145 HandshakeResponse::Accept {
146 conn_id: id,
147 methods,
148 resources: &resources,
149 cfg,
150 stop_monitor: &stop_monitor,
151 middleware: middleware.clone(),
152 id_provider,
153 },
154 )));
155
156 tracing::info!("Accepting new connection {}/{}", connections.count(), self.cfg.max_connections);
157
158 id = id.wrapping_add(1);
159 }
160 Err(MonitoredError::Selector(err)) => {
161 tracing::error!("Error while awaiting a new connection: {:?}", err);
162 }
163 Err(MonitoredError::Shutdown) => break,
164 }
165 }
166
167 connections.await
168 }
169}
170
171struct Monitored<'a, F> {
173 future: F,
174 stop_monitor: &'a StopMonitor,
175}
176
177impl<'a, F> Monitored<'a, F> {
178 fn new(future: F, stop_monitor: &'a StopMonitor) -> Self {
179 Monitored { future, stop_monitor }
180 }
181}
182
183enum MonitoredError<E> {
184 Shutdown,
185 Selector(E),
186}
187
188struct Incoming(TcpListener);
189
190impl<'a> Future for Monitored<'a, Incoming> {
191 type Output = Result<(TcpStream, SocketAddr), MonitoredError<std::io::Error>>;
192
193 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
194 let this = Pin::into_inner(self);
195
196 if this.stop_monitor.shutdown_requested() {
197 return Poll::Ready(Err(MonitoredError::Shutdown));
198 }
199
200 this.future.0.poll_accept(cx).map_err(MonitoredError::Selector)
201 }
202}
203
204impl<'a, 'f, F, T, E> Future for Monitored<'a, Pin<&'f mut F>>
205where
206 F: Future<Output = Result<T, E>>,
207{
208 type Output = Result<T, MonitoredError<E>>;
209
210 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
211 let this = Pin::into_inner(self);
212
213 if this.stop_monitor.shutdown_requested() {
214 return Poll::Ready(Err(MonitoredError::Shutdown));
215 }
216
217 this.future.poll_unpin(cx).map_err(MonitoredError::Selector)
218 }
219}
220
221enum HandshakeResponse<'a, M> {
222 Reject {
223 status_code: u16,
224 },
225 Accept {
226 conn_id: ConnectionId,
227 methods: &'a Methods,
228 resources: &'a Resources,
229 cfg: &'a Settings,
230 stop_monitor: &'a StopMonitor,
231 middleware: M,
232 id_provider: Arc<dyn IdProvider>,
233 },
234}
235
236async fn handshake<M: Middleware>(socket: tokio::net::TcpStream, mode: HandshakeResponse<'_, M>) -> Result<(), Error> {
237 let remote_addr = socket.peer_addr()?;
238
239 let mut server = SokettoServer::new(BufReader::new(BufWriter::new(socket.compat())));
241
242 match mode {
243 HandshakeResponse::Reject { status_code } => {
244 let reject = Response::Reject { status_code };
246 server.send_response(&reject).await?;
247
248 let (mut sender, _) = server.into_builder().finish();
249
250 sender.close().await?;
252
253 Ok(())
254 }
255 HandshakeResponse::Accept { conn_id, methods, resources, cfg, stop_monitor, middleware, id_provider } => {
256 tracing::debug!("Accepting new connection: {}", conn_id);
257
258 let key_and_headers = get_key_and_headers(&mut server, cfg).await;
259
260 match key_and_headers {
261 Ok((key, headers)) => {
262 middleware.on_connect(remote_addr, &headers);
263 let accept = Response::Accept { key, protocol: None };
264 server.send_response(&accept).await?;
265 }
266 Err(err) => {
267 tracing::warn!("Rejected connection: {:?}", err);
268 let reject = Response::Reject { status_code: 403 };
269 server.send_response(&reject).await?;
270
271 return Err(err);
272 }
273 };
274
275 let join_result = tokio::spawn(background_task(BackgroundTask {
276 server,
277 conn_id,
278 methods: methods.clone(),
279 resources: resources.clone(),
280 max_request_body_size: cfg.max_request_body_size,
281 max_response_body_size: cfg.max_response_body_size,
282 max_log_length: cfg.max_log_length,
283 batch_requests_supported: cfg.batch_requests_supported,
284 bounded_subscriptions: BoundedSubscriptions::new(cfg.max_subscriptions_per_connection),
285 stop_server: stop_monitor.clone(),
286 middleware,
287 id_provider,
288 ping_interval: cfg.ping_interval,
289 remote_addr,
290 }))
291 .await;
292
293 match join_result {
294 Err(_) => Err(Error::Custom("Background task was aborted".into())),
295 Ok(result) => result,
296 }
297 }
298 }
299}
300
301struct BackgroundTask<'a, M> {
302 server: SokettoServer<'a, BufReader<BufWriter<Compat<tokio::net::TcpStream>>>>,
303 conn_id: ConnectionId,
304 methods: Methods,
305 resources: Resources,
306 max_request_body_size: u32,
307 max_response_body_size: u32,
308 max_log_length: u32,
309 batch_requests_supported: bool,
310 bounded_subscriptions: BoundedSubscriptions,
311 stop_server: StopMonitor,
312 middleware: M,
313 id_provider: Arc<dyn IdProvider>,
314 ping_interval: Duration,
315 remote_addr: SocketAddr,
316}
317
318async fn background_task<M: Middleware>(input: BackgroundTask<'_, M>) -> Result<(), Error> {
319 let BackgroundTask {
320 server,
321 conn_id,
322 methods,
323 resources,
324 max_request_body_size,
325 max_response_body_size,
326 max_log_length,
327 batch_requests_supported,
328 bounded_subscriptions,
329 stop_server,
330 middleware,
331 id_provider,
332 ping_interval,
333 remote_addr,
334 } = input;
335
336 let mut builder = server.into_builder();
338 builder.set_max_message_size(max_request_body_size as usize);
339 let (mut sender, mut receiver) = builder.finish();
340 let (tx, mut rx) = mpsc::unbounded::<String>();
341 let bounded_subscriptions2 = bounded_subscriptions.clone();
342
343 let stop_server2 = stop_server.clone();
344 let sink = MethodSink::new_with_limit(tx, max_response_body_size, max_log_length);
345
346 tokio::spawn(async move {
348 let mut rx_item = rx.next();
350
351 let ping_interval = IntervalStream::new(tokio::time::interval(ping_interval));
353 tokio::pin!(ping_interval);
354 let mut next_ping = ping_interval.next();
355
356 while !stop_server2.shutdown_requested() {
357 match futures_util::future::select(rx_item, next_ping).await {
360 Either::Left((Some(response), ping)) => {
361 if let Err(err) = send_ws_message(&mut sender, response).await {
363 tracing::warn!("WS send error: {}; terminate connection", err);
364 break;
365 }
366 rx_item = rx.next();
367 next_ping = ping;
368 }
369 Either::Left((None, _)) => break,
371
372 Either::Right((_, next_rx)) => {
374 if let Err(err) = send_ws_ping(&mut sender).await {
375 tracing::warn!("WS send ping error: {}; terminate connection", err);
376 break;
377 }
378 rx_item = next_rx;
379 next_ping = ping_interval.next();
380 }
381 }
382 }
383
384 let _ = sender.close().await;
386
387 bounded_subscriptions2.close();
389 });
390
391 let mut data = Vec::with_capacity(100);
393 let mut method_executors = FutureDriver::default();
394 let middleware = &middleware;
395
396 let result = loop {
397 data.clear();
398
399 {
400 let receive = async {
402 loop {
404 match receiver.receive(&mut data).await? {
405 soketto::Incoming::Data(d) => break Ok(d),
406 soketto::Incoming::Pong(_) => tracing::debug!("recv pong"),
407 soketto::Incoming::Closed(_) => {
408 break Err(SokettoError::Closed);
411 }
412 }
413 }
414 };
415
416 tokio::pin!(receive);
417
418 if let Err(err) = method_executors.select_with(Monitored::new(receive, &stop_server)).await {
419 match err {
420 MonitoredError::Selector(SokettoError::Closed) => {
421 tracing::debug!("WS transport: remote peer terminated the connection: {}", conn_id);
422 sink.close();
423 break Ok(());
424 }
425 MonitoredError::Selector(SokettoError::MessageTooLarge { current, maximum }) => {
426 tracing::warn!(
427 "WS transport error: outgoing message is too big error ({} bytes, max is {})",
428 current,
429 maximum
430 );
431 sink.send_error(Id::Null, reject_too_big_request(max_request_body_size));
432 continue;
433 }
434 MonitoredError::Selector(err) => {
436 tracing::debug!("WS error: {}; terminate connection {}", err, conn_id);
437 sink.close();
438 break Err(err.into());
439 }
440 MonitoredError::Shutdown => break Ok(()),
441 };
442 };
443 };
444
445 let request_start = middleware.on_request();
446
447 let first_non_whitespace = data.iter().find(|byte| !byte.is_ascii_whitespace());
448 match first_non_whitespace {
449 Some(b'{') => {
450 let data = std::mem::take(&mut data);
451 let sink = sink.clone();
452 let resources = &resources;
453 let methods = &methods;
454 let bounded_subscriptions = bounded_subscriptions.clone();
455 let id_provider = &*id_provider;
456
457 let fut = async move {
458 let call = CallData {
459 conn_id,
460 resources,
461 max_response_body_size,
462 max_log_length,
463 methods,
464 bounded_subscriptions,
465 sink: &sink,
466 id_provider: &*id_provider,
467 middleware,
468 request_start,
469 };
470
471 match process_single_request(data, call).await {
472 MethodResult::JustMiddleware(r) => {
473 middleware.on_response(&r.result, request_start);
474 }
475 MethodResult::SendAndMiddleware(r) => {
476 middleware.on_response(&r.result, request_start);
477 let _ = sink.send_raw(r.result);
478 }
479 };
480 }
481 .boxed();
482
483 method_executors.add(fut);
484 }
485 Some(b'[') if !batch_requests_supported => {
486 let response = MethodResponse::error(
487 Id::Null,
488 ErrorObject::borrowed(BATCHES_NOT_SUPPORTED_CODE, &BATCHES_NOT_SUPPORTED_MSG, None),
489 );
490 middleware.on_response(&response.result, request_start);
491 let _ = sink.send_raw(response.result);
492 }
493 Some(b'[') => {
494 let resources = &resources;
496 let methods = &methods;
497 let bounded_subscriptions = bounded_subscriptions.clone();
498 let sink = sink.clone();
499 let id_provider = id_provider.clone();
500 let data = std::mem::take(&mut data);
501
502 let fut = async move {
503 let response = process_batch_request(Batch {
504 data,
505 call: CallData {
506 conn_id,
507 resources,
508 max_response_body_size,
509 max_log_length,
510 methods,
511 bounded_subscriptions,
512 sink: &sink,
513 id_provider: &*id_provider,
514 middleware,
515 request_start,
516 },
517 })
518 .await;
519
520 tx_log_from_str(&response.result, max_log_length);
521 middleware.on_response(&response.result, request_start);
522 let _ = sink.send_raw(response.result);
523 };
524
525 method_executors.add(Box::pin(fut));
526 }
527 _ => {
528 sink.send_error(Id::Null, ErrorCode::ParseError.into());
529 }
530 }
531 };
532
533 middleware.on_disconnect(remote_addr);
534
535 method_executors.await;
539
540 result
541}
542
543#[derive(Debug, Clone)]
545struct Settings {
546 max_request_body_size: u32,
548 max_response_body_size: u32,
550 max_connections: u64,
552 max_subscriptions_per_connection: u32,
554 max_log_length: u32,
558 access_control: AccessControl,
560 batch_requests_supported: bool,
562 tokio_runtime: Option<tokio::runtime::Handle>,
564 ping_interval: Duration,
566}
567
568impl Default for Settings {
569 fn default() -> Self {
570 Self {
571 max_request_body_size: TEN_MB_SIZE_BYTES,
572 max_response_body_size: TEN_MB_SIZE_BYTES,
573 max_log_length: 4096,
574 max_subscriptions_per_connection: 1024,
575 max_connections: MAX_CONNECTIONS,
576 batch_requests_supported: true,
577 access_control: AccessControl::default(),
578 tokio_runtime: None,
579 ping_interval: Duration::from_secs(60),
580 }
581 }
582}
583
584#[derive(Debug)]
586pub struct Builder<M = ()> {
587 settings: Settings,
588 resources: Resources,
589 middleware: M,
590 id_provider: Arc<dyn IdProvider>,
591}
592
593impl Default for Builder {
594 fn default() -> Self {
595 Builder {
596 settings: Settings::default(),
597 resources: Resources::default(),
598 middleware: (),
599 id_provider: Arc::new(RandomIntegerIdProvider),
600 }
601 }
602}
603
604impl Builder {
605 pub fn new() -> Self {
607 Self::default()
608 }
609}
610
611impl<M> Builder<M> {
612 pub fn max_request_body_size(mut self, size: u32) -> Self {
614 self.settings.max_request_body_size = size;
615 self
616 }
617
618 pub fn max_response_body_size(mut self, size: u32) -> Self {
620 self.settings.max_response_body_size = size;
621 self
622 }
623
624 pub fn max_connections(mut self, max: u64) -> Self {
626 self.settings.max_connections = max;
627 self
628 }
629
630 pub fn batch_requests_supported(mut self, supported: bool) -> Self {
633 self.settings.batch_requests_supported = supported;
634 self
635 }
636
637 pub fn max_subscriptions_per_connection(mut self, max: u32) -> Self {
639 self.settings.max_subscriptions_per_connection = max;
640 self
641 }
642
643 pub fn register_resource(mut self, label: &'static str, capacity: u16, default: u16) -> Result<Self, Error> {
649 self.resources.register(label, capacity, default)?;
650 Ok(self)
651 }
652
653 pub fn set_middleware<T: Middleware>(self, middleware: T) -> Builder<T> {
695 Builder { settings: self.settings, resources: self.resources, middleware, id_provider: self.id_provider }
696 }
697
698 pub fn custom_tokio_runtime(mut self, rt: tokio::runtime::Handle) -> Self {
702 self.settings.tokio_runtime = Some(rt);
703 self
704 }
705
706 pub fn ping_interval(mut self, interval: Duration) -> Self {
723 self.settings.ping_interval = interval;
724 self
725 }
726
727 pub fn set_id_provider<I: IdProvider + 'static>(mut self, id_provider: I) -> Self {
748 self.id_provider = Arc::new(id_provider);
749 self
750 }
751
752 pub fn set_access_control(mut self, acl: AccessControl) -> Self {
754 self.settings.access_control = acl;
755 self
756 }
757
758 pub async fn build(self, addrs: impl ToSocketAddrs) -> Result<Server<M>, Error> {
775 let listener = TcpListener::bind(addrs).await?;
776 let stop_monitor = StopMonitor::new();
777 let resources = self.resources;
778 Ok(Server {
779 listener,
780 cfg: self.settings,
781 stop_monitor,
782 resources,
783 middleware: self.middleware,
784 id_provider: self.id_provider,
785 })
786 }
787}
788
789async fn send_ws_message(
790 sender: &mut Sender<BufReader<BufWriter<Compat<TcpStream>>>>,
791 response: String,
792) -> Result<(), Error> {
793 sender.send_text_owned(response).await?;
794 sender.flush().await.map_err(Into::into)
795}
796
797async fn send_ws_ping(sender: &mut Sender<BufReader<BufWriter<Compat<TcpStream>>>>) -> Result<(), Error> {
798 tracing::debug!("send ping");
799 let slice: &[u8] = &[];
801 let byte_slice = ByteSlice125::try_from(slice).expect("Empty slice should fit into ByteSlice125");
803 sender.send_ping(byte_slice).await?;
804 sender.flush().await.map_err(Into::into)
805}
806
807#[derive(Debug, Clone)]
808struct Batch<'a, M: Middleware> {
809 data: Vec<u8>,
810 call: CallData<'a, M>,
811}
812
813#[derive(Debug, Clone)]
814struct CallData<'a, M: Middleware> {
815 conn_id: usize,
816 bounded_subscriptions: BoundedSubscriptions,
817 id_provider: &'a dyn IdProvider,
818 middleware: &'a M,
819 methods: &'a Methods,
820 max_response_body_size: u32,
821 max_log_length: u32,
822 resources: &'a Resources,
823 sink: &'a MethodSink,
824 request_start: M::Instant,
825}
826
827#[derive(Debug, Clone)]
828struct Call<'a, M: Middleware> {
829 params: Params<'a>,
830 name: &'a str,
831 call: CallData<'a, M>,
832 id: Id<'a>,
833}
834
835enum MethodResult {
836 JustMiddleware(MethodResponse),
837 SendAndMiddleware(MethodResponse),
838}
839
840impl MethodResult {
841 fn as_inner(&self) -> &MethodResponse {
842 match &self {
843 Self::JustMiddleware(r) => r,
844 Self::SendAndMiddleware(r) => r,
845 }
846 }
847}
848
849async fn process_batch_request<M>(b: Batch<'_, M>) -> BatchResponse
853where
854 M: Middleware,
855{
856 let Batch { data, call } = b;
857
858 if let Ok(batch) = serde_json::from_slice::<Vec<Request>>(&data) {
859 return if !batch.is_empty() {
860 let batch = batch.into_iter().map(|req| Ok((req, call.clone())));
861 let batch_stream = futures_util::stream::iter(batch);
862
863 let trace = RpcTracing::batch();
864
865 return async {
866 let max_response_size = call.max_response_body_size;
867
868 let batch_response = batch_stream
869 .try_fold(
870 BatchResponseBuilder::new_with_limit(max_response_size as usize),
871 |batch_response, (req, call)| async move {
872 let params = Params::new(req.params.map(|params| params.get()));
873 let response = execute_call(Call { name: &req.method, params, id: req.id, call }).await;
874 batch_response.append(response.as_inner())
875 },
876 )
877 .await;
878
879 match batch_response {
880 Ok(batch) => batch.finish(),
881 Err(batch_err) => batch_err,
882 }
883 }
884 .instrument(trace.into_span())
885 .await;
886 } else {
887 BatchResponse::error(Id::Null, ErrorObject::from(ErrorCode::InvalidRequest))
888 };
889 }
890
891 let (id, code) = prepare_error(&data);
892 BatchResponse::error(id, ErrorObject::from(code))
893}
894
895async fn process_single_request<M: Middleware>(data: Vec<u8>, call: CallData<'_, M>) -> MethodResult {
896 if let Ok(req) = serde_json::from_slice::<Request>(&data) {
897 let trace = RpcTracing::method_call(&req.method);
898
899 async {
900 rx_log_from_json(&req, call.max_log_length);
901
902 let params = Params::new(req.params.map(|params| params.get()));
903 let name = &req.method;
904 let id = req.id;
905
906 execute_call(Call { name, params, id, call }).await
907 }
908 .instrument(trace.into_span())
909 .await
910 } else {
911 let (id, code) = prepare_error(&data);
912 MethodResult::SendAndMiddleware(MethodResponse::error(id, ErrorObject::from(code)))
913 }
914}
915
916async fn execute_call<M: Middleware>(c: Call<'_, M>) -> MethodResult {
922 let Call { name, id, params, call } = c;
923 let CallData {
924 resources,
925 methods,
926 middleware,
927 max_response_body_size,
928 max_log_length,
929 conn_id,
930 bounded_subscriptions,
931 id_provider,
932 sink,
933 request_start,
934 } = call;
935
936 let response = match methods.method_with_name(name) {
937 None => {
938 middleware.on_call(name, params.clone(), middleware::MethodKind::Unknown);
939 let response = MethodResponse::error(id, ErrorObject::from(ErrorCode::MethodNotFound));
940 MethodResult::SendAndMiddleware(response)
941 }
942 Some((name, method)) => match &method.inner() {
943 MethodKind::Sync(callback) => {
944 middleware.on_call(name, params.clone(), middleware::MethodKind::MethodCall);
945
946 match method.claim(name, resources) {
947 Ok(guard) => {
948 let r = (callback)(id, params, max_response_body_size as usize);
949 drop(guard);
950 MethodResult::SendAndMiddleware(r)
951 }
952 Err(err) => {
953 tracing::error!("[Methods::execute_with_resources] failed to lock resources: {:?}", err);
954 let response = MethodResponse::error(id, ErrorObject::from(ErrorCode::ServerIsBusy));
955 MethodResult::SendAndMiddleware(response)
956 }
957 }
958 }
959 MethodKind::Async(callback) => {
960 middleware.on_call(name, params.clone(), middleware::MethodKind::MethodCall);
961
962 match method.claim(name, resources) {
963 Ok(guard) => {
964 let id = id.into_owned();
965 let params = params.into_owned();
966
967 let response =
968 (callback)(id, params, conn_id, max_response_body_size as usize, Some(guard)).await;
969 MethodResult::SendAndMiddleware(response)
970 }
971 Err(err) => {
972 tracing::error!("[Methods::execute_with_resources] failed to lock resources: {:?}", err);
973 let response = MethodResponse::error(id, ErrorObject::from(ErrorCode::ServerIsBusy));
974 MethodResult::SendAndMiddleware(response)
975 }
976 }
977 }
978 MethodKind::Subscription(callback) => {
979 middleware.on_call(name, params.clone(), middleware::MethodKind::Subscription);
980
981 match method.claim(name, resources) {
982 Ok(guard) => {
983 if let Some(cn) = bounded_subscriptions.acquire() {
984 let conn_state = ConnState { conn_id, close_notify: cn, id_provider };
985 let response = callback(id.clone(), params, sink.clone(), conn_state, Some(guard)).await;
986 MethodResult::JustMiddleware(response)
987 } else {
988 let response =
989 MethodResponse::error(id, reject_too_many_subscriptions(bounded_subscriptions.max()));
990 MethodResult::SendAndMiddleware(response)
991 }
992 }
993 Err(err) => {
994 tracing::error!("[Methods::execute_with_resources] failed to lock resources: {:?}", err);
995 let response = MethodResponse::error(id, ErrorObject::from(ErrorCode::ServerIsBusy));
996 MethodResult::SendAndMiddleware(response)
997 }
998 }
999 }
1000 MethodKind::Unsubscription(callback) => {
1001 middleware.on_call(name, params.clone(), middleware::MethodKind::Unsubscription);
1002
1003 let result = callback(id, params, conn_id, max_response_body_size as usize);
1005 MethodResult::SendAndMiddleware(result)
1006 }
1007 },
1008 };
1009
1010 let r = response.as_inner();
1011
1012 rx_log_from_str(&r.result, max_log_length);
1013 middleware.on_result(name, r.success, request_start);
1014 response
1015}
1016
1017async fn get_key_and_headers(
1019 server: &mut SokettoServer<'_, BufReader<BufWriter<Compat<TcpStream>>>>,
1020 cfg: &Settings,
1021) -> Result<(WebSocketKey, HeaderMap), Error> {
1022 let req = server.receive_request().await?;
1023
1024 tracing::trace!("Connection request: {:?}", req);
1025
1026 let host = std::str::from_utf8(req.headers().host).map_err(|e| Error::HttpHeaderRejected("Host", e.to_string()))?;
1027
1028 let origin = req.headers().origin.and_then(|h| {
1029 let res = std::str::from_utf8(h).ok();
1030 if res.is_none() {
1031 tracing::warn!("Origin header invalid UTF-8; treated as no Origin header");
1032 }
1033 res
1034 });
1035
1036 let host_check = cfg.access_control.verify_host(host);
1037 let origin_check = cfg.access_control.verify_origin(origin, host);
1038
1039 let mut headers = HeaderMap::new();
1040
1041 host_check.and(origin_check).map(|()| {
1042 let key = req.key();
1043
1044 if let Ok(val) = HeaderValue::from_str(host) {
1045 headers.insert(HOST, val);
1046 }
1047
1048 if let Some(Ok(val)) = origin.map(HeaderValue::from_str) {
1049 headers.insert(ORIGIN, val);
1050 }
1051
1052 (key, headers)
1053 })
1054}