1use std::sync::Arc;
2use std::time::Instant;
3
4use crate::future::{IntervalStream, SessionClose};
5use crate::middleware::rpc::{RpcService, RpcServiceCfg};
6use crate::server::{ConnectionState, ServerConfig, handle_rpc_call};
7use crate::{HttpBody, HttpRequest, HttpResponse, LOG_TARGET, PingConfig};
8
9use futures_util::future::{self, Either};
10use futures_util::io::{BufReader, BufWriter};
11use futures_util::{Future, StreamExt, TryStreamExt};
12use hyper::upgrade::Upgraded;
13use hyper_util::rt::TokioIo;
14use jsonrpsee_core::middleware::{RpcServiceBuilder, RpcServiceT};
15use jsonrpsee_core::server::{BoundedSubscriptions, MethodResponse, MethodSink, Methods};
16use jsonrpsee_types::Id;
17use jsonrpsee_types::error::{ErrorCode, reject_too_big_request};
18use serde_json::value::RawValue;
19use soketto::connection::Error as SokettoError;
20use soketto::data::ByteSlice125;
21use tokio::sync::{mpsc, oneshot};
22use tokio::time::{interval, interval_at};
23use tokio_stream::wrappers::ReceiverStream;
24use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};
25
26pub(crate) type Sender = soketto::Sender<BufReader<BufWriter<Compat<TokioIo<Upgraded>>>>>;
27pub(crate) type Receiver = soketto::Receiver<BufReader<BufWriter<Compat<TokioIo<Upgraded>>>>>;
28
29pub use soketto::handshake::http::is_upgrade_request;
30
/// An item of interest produced by the WebSocket read loop.
enum Incoming {
    /// A pong reply from the peer; its payload is not kept.
    Pong,
    /// A complete data message as raw bytes (soketto's text/binary distinction is dropped).
    Data(Vec<u8>),
}
35
36pub(crate) async fn send_message(sender: &mut Sender, response: Box<RawValue>) -> Result<(), SokettoError> {
37 sender.send_text_owned(String::from(Box::<str>::from(response))).await?;
38 sender.flush().await
39}
40
41pub(crate) async fn send_ping(sender: &mut Sender) -> Result<(), SokettoError> {
42 tracing::debug!(target: LOG_TARGET, "Send ping");
43 let slice: &[u8] = &[];
45 let byte_slice = ByteSlice125::try_from(slice).expect("Empty slice should fit into ByteSlice125");
47 sender.send_ping(byte_slice).await?;
48 sender.flush().await
49}
50
/// Everything [`background_task`] needs to drive one upgraded WebSocket connection.
pub(crate) struct BackgroundTaskParams<S> {
    /// Server configuration; `background_task` reads the ping, batch and
    /// request-size settings from it.
    pub(crate) server_cfg: ServerConfig,
    /// Per-connection state. Its `stop_handle` provides the shutdown future, and
    /// it is dropped only after graceful shutdown has finished.
    pub(crate) conn: ConnectionState,
    /// Write half of the WebSocket; moved into the send task.
    pub(crate) ws_sender: Sender,
    /// Read half of the WebSocket.
    pub(crate) ws_receiver: Receiver,
    /// RPC service (possibly wrapped in middleware) that handles each incoming message.
    pub(crate) rpc_service: S,
    /// Sink used to queue responses and errors for the send task.
    pub(crate) sink: MethodSink,
    /// Receiving end of the response queue that the send task drains.
    pub(crate) rx: mpsc::Receiver<Box<RawValue>>,
    /// Awaited during graceful shutdown. Nothing is sent on it; presumably it
    /// yields `None` once every sender (held by the RPC service) is dropped — confirm.
    pub(crate) pending_calls_completed: mpsc::Receiver<()>,
    /// Optional hook closed after the connection has fully shut down.
    pub(crate) on_session_close: Option<SessionClose>,
    /// HTTP request extensions, cloned into every call.
    pub(crate) extensions: http::Extensions,
}
63
/// Drives one upgraded WebSocket connection until it closes, fails, or the server stops.
///
/// The function proceeds as follows:
/// - Spawns [`send_task`] to write queued responses and pings to the socket.
/// - Reads messages in a loop, dispatching each data message on its own tokio
///   task via `handle_rpc_call` and queueing the response on `sink`.
/// - Answers oversized messages with a "request too big" error without closing
///   the connection; any other read error ends the loop.
/// - Afterwards runs [`graceful_shutdown`], drops the connection state, and
///   closes `on_session_close` if one was provided.
pub(crate) async fn background_task<S>(params: BackgroundTaskParams<S>)
where
    S: RpcServiceT<
            MethodResponse = MethodResponse,
            BatchResponse = MethodResponse,
            NotificationResponse = MethodResponse,
        > + Send
        + Sync
        + 'static,
{
    let BackgroundTaskParams {
        server_cfg,
        conn,
        ws_sender,
        ws_receiver,
        rpc_service,
        sink,
        rx,
        pending_calls_completed,
        mut on_session_close,
        extensions,
    } = params;
    let ServerConfig { ping_config, batch_requests_config, max_request_body_size, .. } = server_cfg;

    // Lets `graceful_shutdown` tell the send task to stop once reading has finished.
    let (conn_tx, conn_rx) = oneshot::channel();

    // Responses (and pings) are written to the socket by a separate task.
    let send_task_handle = tokio::spawn(send_task(rx, ws_sender, ping_config, conn_rx));

    // Completes when the server is stopped; it produces `Receive::Stopped` in `try_recv`.
    let stopped = conn.stop_handle.clone().shutdown();
    let rpc_service = Arc::new(rpc_service);
    // Inactivity counter that persists across every `try_recv` call on this connection.
    let mut missed_pings = 0;

    tokio::pin!(stopped);

    // Adapts the soketto receiver into a stream of `Incoming` items. A close frame or
    // `SokettoError::Closed` ends the stream; other errors are yielded as items.
    let ws_stream = futures_util::stream::unfold(ws_receiver, |mut receiver| async {
        let mut data = Vec::new();
        match receiver.receive(&mut data).await {
            Ok(soketto::Incoming::Data(_)) => Some((Ok(Incoming::Data(data)), receiver)),
            Ok(soketto::Incoming::Pong(_)) => Some((Ok(Incoming::Pong), receiver)),
            Ok(soketto::Incoming::Closed(_)) | Err(SokettoError::Closed) => None,
            Err(e) => Some((Err(e), receiver)),
        }
    })
    .fuse();

    tokio::pin!(ws_stream);

    let result = loop {
        // `stopped` is moved into `try_recv` and handed back with every message or error,
        // so the same shutdown future keeps being polled across iterations.
        let data = match try_recv(&mut ws_stream, stopped, ping_config, &mut missed_pings).await {
            Receive::ConnectionClosed => break Ok(Shutdown::ConnectionClosed),
            Receive::Stopped => break Ok(Shutdown::Stopped),
            Receive::Ok(data, stop) => {
                stopped = stop;
                data
            }
            Receive::Err(err, stop) => {
                stopped = stop;

                match err {
                    // Defensive: the stream above already ends instead of yielding `Closed`.
                    SokettoError::Closed => {
                        break Ok(Shutdown::ConnectionClosed);
                    }
                    // Oversized message: reply with an error but keep the connection open,
                    // unless the response queue is gone.
                    SokettoError::MessageTooLarge { current, maximum } => {
                        tracing::debug!(
                            target: LOG_TARGET,
                            "WS recv error: message too large current={}/max={}",
                            current,
                            maximum
                        );
                        if sink.send_error(Id::Null, reject_too_big_request(max_request_body_size)).await.is_err() {
                            break Ok(Shutdown::ConnectionClosed);
                        }

                        continue;
                    }
                    err => {
                        tracing::debug!(target: LOG_TARGET, "WS error: {}; terminate connection: {}", err, conn.conn_id);
                        break Err(err);
                    }
                };
            }
        };

        // Each message is handled on its own task, which gets its own handles.
        let rpc_service = rpc_service.clone();
        let sink = sink.clone();
        let extensions = extensions.clone();

        tokio::spawn(async move {
            // Only the first 128 bytes are scanned for the opening `{` (single call)
            // or `[` (batch); anything else is reported as a parse error.
            let first_non_whitespace = data.iter().enumerate().take(128).find(|(_, byte)| !byte.is_ascii_whitespace());

            let (idx, is_single) = match first_non_whitespace {
                Some((start, b'{')) => (start, true),
                Some((start, b'[')) => (start, false),
                _ => {
                    _ = sink.send_error(Id::Null, ErrorCode::ParseError.into()).await;
                    return;
                }
            };

            let rp = handle_rpc_call(&data[idx..], is_single, batch_requests_config, &*rpc_service, extensions).await;

            // Only method-call and batch responses are forwarded from here; other kinds
            // (presumably subscription/notification responses) are not sent by this task.
            if rp.is_method_call() || rp.is_batch() {
                let is_success = rp.is_success();
                let (json, mut on_close, _) = rp.into_parts();

                // The queue is closed (the send task has exited), so there is nothing to do.
                if sink.send(json).await.is_err() {
                    return;
                }

                // Tell the response's close hook, once the response is queued, whether the call succeeded.
                if let Some(n) = on_close.take() {
                    n.notify(is_success);
                }
            }
        });
    };

    // Release this task's handle on the service; in-flight calls hold their own clones.
    drop(rpc_service);
    graceful_shutdown(result, pending_calls_completed, ws_stream, conn_tx, send_task_handle).await;

    // Connection state is released only after the send task has finished.
    drop(conn);

    if let Some(c) = on_session_close.take() {
        c.close();
    }
}
200
/// Writes queued responses and periodic pings to the WebSocket until told to stop.
///
/// The task runs until one of the following happens:
/// - the response queue `rx` closes;
/// - a write or ping fails;
/// - `stop` resolves, either because a value was sent or because its sender was dropped.
///
/// It then closes the WebSocket, ignoring any error, and closes `rx` so that
/// further sends into the queue are rejected.
async fn send_task(
    rx: mpsc::Receiver<Box<RawValue>>,
    mut ws_sender: Sender,
    ping_config: Option<PingConfig>,
    stop: oneshot::Receiver<()>,
) {
    // Without a ping config, the ping stream never yields. Note that tokio's
    // `interval` completes its first tick immediately, so the first ping is sent right away.
    let ping_interval = match ping_config {
        None => IntervalStream::pending(),
        Some(p) => IntervalStream::new(interval(p.ping_interval)),
    };
    let rx = ReceiverStream::new(rx);

    tokio::pin!(ping_interval, rx, stop);

    let mut rx_item = rx.next();
    let next_ping = ping_interval.next();
    let mut futs = future::select(next_ping, stop);

    loop {
        // `future::select` polls its left side first, so queued responses are checked
        // before the ping/stop branch. Whichever side did not finish is reused.
        match future::select(rx_item, futs).await {
            Either::Left((Some(response), not_ready)) => {
                if let Err(err) = send_message(&mut ws_sender, response).await {
                    tracing::debug!(target: LOG_TARGET, "WS send error: {}", err);
                    break;
                }

                rx_item = rx.next();
                futs = not_ready;
            }

            // Every sender of the response queue is gone.
            Either::Left((None, _)) => {
                break;
            }

            Either::Right((Either::Left((_instant, _stopped)), next_rx)) => {
                // Put the still-pending stop future back for the next round.
                stop = _stopped;
                if let Err(err) = send_ping(&mut ws_sender).await {
                    tracing::debug!(target: LOG_TARGET, "WS send ping error: {}", err);
                    break;
                }

                rx_item = next_rx;
                futs = future::select(ping_interval.next(), stop);
            }
            Either::Right((Either::Right((_stopped, _)), _)) => {
                break;
            }
        }
    }

    let _ = ws_sender.close().await;
    rx.close();
}
267
268enum Receive<S> {
269 ConnectionClosed,
270 Stopped,
271 Err(SokettoError, S),
272 Ok(Vec<u8>, S),
273}
274
/// Waits for the next data message on `ws_stream`, or for `stopped` to resolve.
///
/// Pong frames are consumed here and refresh `last_active`. `last_active` is local
/// to this call: it starts at the time of the call, so a data message also resets it.
///
/// When `ping_config` is set, an inactivity check first fires one `ping_interval`
/// after the call starts, then every `ping_interval`. Each tick where more than
/// `inactive_limit` has elapsed since `last_active` increments `missed_pings`.
/// Reaching `max_failures` ends the connection with `Receive::ConnectionClosed`.
/// `missed_pings` belongs to the caller and is never reset here, so it accumulates
/// across calls.
///
/// On `Receive::Ok` / `Receive::Err`, the unresolved `stopped` future is returned
/// so the caller can pass it to the next call.
async fn try_recv<T, S>(
    ws_stream: &mut T,
    mut stopped: S,
    ping_config: Option<PingConfig>,
    missed_pings: &mut usize,
) -> Receive<S>
where
    S: Future<Output = ()> + Unpin,
    T: StreamExt<Item = Result<Incoming, SokettoError>> + Unpin,
{
    let mut last_active = Instant::now();
    let inactivity_check = match ping_config {
        Some(p) => IntervalStream::new(interval_at(tokio::time::Instant::now() + p.ping_interval, p.ping_interval)),
        None => IntervalStream::pending(),
    };

    tokio::pin!(inactivity_check);

    let mut futs = futures_util::future::select(ws_stream.next(), inactivity_check.next());

    loop {
        match futures_util::future::select(futs, stopped).await {
            // The stream has ended.
            Either::Left((Either::Left((None, _)), _)) => break Receive::ConnectionClosed,
            // A data message arrived: hand it back along with the stop future.
            Either::Left((Either::Left((Some(Ok(Incoming::Data(d))), _)), s)) => break Receive::Ok(d, s),
            // Pong: record activity and keep waiting, reusing the pending inactivity tick.
            Either::Left((Either::Left((Some(Ok(Incoming::Pong)), inactive)), s)) => {
                last_active = Instant::now();
                stopped = s;
                futs = futures_util::future::select(ws_stream.next(), inactive);
            }
            Either::Left((Either::Left((Some(Err(e)), _)), s)) => break Receive::Err(e, s),
            // Inactivity tick: count a failure if no pong has been seen for too long,
            // then keep waiting, reusing the pending read.
            Either::Left((Either::Right((_instant, rcv)), s)) => {
                if let Some(p) = ping_config {
                    if last_active.elapsed() > p.inactive_limit {
                        *missed_pings += 1;

                        if *missed_pings >= p.max_failures {
                            tracing::debug!(
                                target: LOG_TARGET,
                                "WS ping/pong inactivity limit `{}` exceeded; closing connection",
                                p.max_failures,
                            );
                            break Receive::ConnectionClosed;
                        }
                    }
                }

                stopped = s;
                futs = futures_util::future::select(rcv, inactivity_check.next());
            }
            // The stop future resolved.
            Either::Right(_) => break Receive::Stopped,
        }
    }
}
335
/// Why the connection's receive loop ended without an error.
#[derive(Clone, Copy, Debug)]
pub(crate) enum Shutdown {
    /// The connection went away. This covers three cases: the stream ended, the
    /// peer exceeded the ping/pong inactivity limit, or the response queue was closed.
    ConnectionClosed,
    /// The server's stop handle fired.
    Stopped,
}
341
342async fn graceful_shutdown<S>(
346 result: Result<Shutdown, SokettoError>,
347 pending_calls: mpsc::Receiver<()>,
348 ws_stream: S,
349 mut conn_tx: oneshot::Sender<()>,
350 send_task_handle: tokio::task::JoinHandle<()>,
351) where
352 S: StreamExt<Item = Result<Incoming, SokettoError>> + Unpin,
353{
354 let pending_calls = ReceiverStream::new(pending_calls);
355
356 if let Ok(Shutdown::Stopped) = result {
357 let graceful_shutdown = pending_calls.for_each(|_| async {});
358 let disconnect = ws_stream.try_for_each(|_| async { Ok(()) });
359
360 tokio::select! {
361 _ = graceful_shutdown => {}
362 res = disconnect => {
363 if let Err(err) = res {
364 tracing::warn!(target: LOG_TARGET, "Graceful shutdown terminated because of error: `{err}`");
365 }
366 }
367 _ = conn_tx.closed() => {}
368 }
369 }
370
371 _ = conn_tx.send(());
373 _ = send_task_handle.await;
375}
376
377pub async fn connect<L, B>(
422 req: HttpRequest<B>,
423 server_cfg: ServerConfig,
424 methods: impl Into<Methods>,
425 conn: ConnectionState,
426 rpc_middleware: RpcServiceBuilder<L>,
427) -> Result<(HttpResponse, impl Future<Output = ()>), HttpResponse>
428where
429 L: tower::Layer<RpcService>,
430 <L as tower::Layer<RpcService>>::Service: RpcServiceT<
431 MethodResponse = MethodResponse,
432 BatchResponse = MethodResponse,
433 NotificationResponse = MethodResponse,
434 > + Send
435 + Sync
436 + 'static,
437{
438 let mut server = soketto::handshake::http::Server::new();
439
440 match server.receive_request(&req) {
441 Ok(response) => {
442 let (tx, rx) = mpsc::channel(server_cfg.message_buffer_capacity as usize);
443 let sink = MethodSink::new(tx);
444
445 let (pending_calls, pending_calls_completed) = mpsc::channel::<()>(1);
449
450 let rpc_service_cfg = RpcServiceCfg::CallsAndSubscriptions {
451 bounded_subscriptions: BoundedSubscriptions::new(server_cfg.max_subscriptions_per_connection),
452 id_provider: server_cfg.id_provider.clone(),
453 sink: sink.clone(),
454 _pending_calls: pending_calls,
455 };
456
457 let rpc_service = RpcService::new(
458 methods.into(),
459 server_cfg.max_response_body_size as usize,
460 conn.conn_id.into(),
461 rpc_service_cfg,
462 );
463
464 let rpc_service = rpc_middleware.service(rpc_service);
465
466 let fut = async move {
469 let extensions = req.extensions().clone();
470
471 let upgraded = match hyper::upgrade::on(req).await {
472 Ok(upgraded) => upgraded,
473 Err(e) => {
474 tracing::debug!(target: LOG_TARGET, "WS upgrade handshake failed: {}", e);
475 return;
476 }
477 };
478
479 let io = TokioIo::new(upgraded);
480
481 let stream = BufReader::new(BufWriter::new(io.compat()));
482 let mut ws_builder = server.into_builder(stream);
483 ws_builder.set_max_message_size(server_cfg.max_response_body_size as usize);
484 let (sender, receiver) = ws_builder.finish();
485
486 let params = BackgroundTaskParams {
487 server_cfg,
488 conn,
489 ws_sender: sender,
490 ws_receiver: receiver,
491 rpc_service,
492 sink,
493 rx,
494 pending_calls_completed,
495 on_session_close: None,
496 extensions,
497 };
498
499 background_task(params).await;
500 };
501
502 Ok((response.map(|()| HttpBody::default()), fut))
503 }
504 Err(e) => {
505 tracing::debug!(target: LOG_TARGET, "WS upgrade handshake failed: {}", e);
506 Err(HttpResponse::new(HttpBody::from(format!("WS upgrade handshake failed: {e}"))))
507 }
508 }
509}