1use std::{
4 collections::HashMap,
5 future::Future,
6 pin::Pin,
7 sync::Arc,
8 task::{Context, Poll},
9 time::{Duration, Instant},
10};
11
12use futures_util::{
13 FutureExt, StreamExt,
14 future::{BoxFuture, Ready},
15 stream::Stream,
16};
17use pin_project_lite::pin_project;
18use serde::{Deserialize, Serialize};
19
20use crate::{Data, Error, Executor, Request, Response, Result, util::Delay};
21
/// Every `Sec-WebSocket-Protocol` name this module accepts. These are exactly the
/// values returned by `Protocols::sec_websocket_protocol` for the two protocol
/// variants (`GraphQLWS` and `SubscriptionsTransportWS` respectively).
pub const ALL_WEBSOCKET_PROTOCOLS: [&str; 2] = ["graphql-transport-ws", "graphql-ws"];
24
/// A frame to be sent to the client over the websocket.
///
/// Within this module, text frames always carry a JSON-serialized server message.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum WsMessage {
    /// A text frame.
    Text(String),

    /// A close frame: the websocket close code and the close reason.
    Close(u16, String),
}

impl WsMessage {
    /// Returns the payload of a [`WsMessage::Text`].
    ///
    /// # Panics
    ///
    /// Panics with `"Not a text message"` if this is a [`WsMessage::Close`].
    /// Like [`Option::unwrap`], the reported panic location is the caller's.
    #[track_caller]
    pub fn unwrap_text(self) -> String {
        match self {
            Self::Text(text) => text,
            Self::Close(_, _) => panic!("Not a text message"),
        }
    }

    /// Returns the `(code, reason)` pair of a [`WsMessage::Close`].
    ///
    /// # Panics
    ///
    /// Panics with `"Not a close message"` if this is a [`WsMessage::Text`].
    /// Like [`Option::unwrap`], the reported panic location is the caller's.
    #[track_caller]
    pub fn unwrap_close(self) -> (u16, String) {
        match self {
            Self::Close(code, msg) => (code, msg),
            Self::Text(_) => panic!("Not a close message"),
        }
    }
}
66
67struct Timer {
68 interval: Duration,
69 delay: Delay,
70}
71
72impl Timer {
73 #[inline]
74 fn new(interval: Duration) -> Self {
75 Self {
76 interval,
77 delay: Delay::new(interval),
78 }
79 }
80
81 #[inline]
82 fn reset(&mut self) {
83 self.delay.reset(self.interval);
84 }
85}
86
87impl Stream for Timer {
88 type Item = ();
89
90 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
91 let this = &mut *self;
92 match this.delay.poll_unpin(cx) {
93 Poll::Ready(_) => {
94 this.delay.reset(this.interval);
95 Poll::Ready(Some(()))
96 }
97 Poll::Pending => Poll::Pending,
98 }
99 }
100}
101
pin_project! {
    /// A GraphQL connection over a websocket, driven as a [`Stream`] of outgoing
    /// [`WsMessage`]s.
    ///
    /// Client messages are read from `stream`, and subscriptions are executed with
    /// `executor`. The wire format follows the negotiated `protocol`.
    pub struct WebSocket<S, E, OnInit, OnPing> {
        // Connection-init callback. It is `take()`n on the first `connection_init`;
        // a second `connection_init` therefore closes the connection.
        on_connection_init: Option<OnInit>,
        // Ping callback, cloned for every client `ping`.
        on_ping: OnPing,
        // Pending connection-init callback future. While this (or `ping_fut`) is set,
        // no further client messages are read.
        init_fut: Option<BoxFuture<'static, Result<Data>>>,
        // Pending ping callback future; its output becomes the pong payload.
        ping_fut: Option<BoxFuture<'static, Result<Option<serde_json::Value>>>>,
        // Server-provided data. It is taken and merged with the init callback's result.
        connection_data: Option<Data>,
        // Merged connection data. It is `Some` only after a successful init and is
        // required before any subscription may start.
        data: Option<Arc<Data>>,
        // Executes subscription requests.
        executor: E,
        // Active subscriptions, keyed by the client-chosen operation id.
        streams: HashMap<String, Pin<Box<dyn Stream<Item = Response> + Send>>>,
        // Incoming, already-decoded client messages.
        #[pin]
        stream: S,
        // Negotiated sub-protocol; selects message names and error reporting.
        protocol: Protocols,
        // Updated whenever a client message is parsed.
        // NOTE(review): not read anywhere in this file.
        last_msg_at: Instant,
        // Optional keepalive. It is reset on every client message; if it fires, the
        // connection is closed with a timeout error.
        keepalive_timer: Option<Timer>,
        // Once set, `poll_next` returns `None`.
        close: bool,
    }
}
126
/// The client-message stream built by `WebSocket::new`: raw frames mapped through
/// [`ClientMessage::from_bytes`] via a plain function pointer.
type MessageMapStream<S> =
    futures_util::stream::Map<S, fn(<S as Stream>::Item) -> serde_json::Result<ClientMessage>>;

/// Type of [`default_on_connection_init`]. It is the connection-init callback
/// installed by the `WebSocket` constructors until `on_connection_init` replaces it.
pub type DefaultOnConnInitType = fn(serde_json::Value) -> Ready<Result<Data>>;

/// Type of [`default_on_ping`]. It is the ping callback installed by the `WebSocket`
/// constructors until `on_ping` replaces it.
pub type DefaultOnPingType =
    fn(Option<&Data>, Option<serde_json::Value>) -> Ready<Result<Option<serde_json::Value>>>;
136
137pub fn default_on_connection_init(_: serde_json::Value) -> Ready<Result<Data>> {
139 futures_util::future::ready(Ok(Data::default()))
140}
141
142pub fn default_on_ping(
144 _: Option<&Data>,
145 _: Option<serde_json::Value>,
146) -> Ready<Result<Option<serde_json::Value>>> {
147 futures_util::future::ready(Ok(None))
148}
149
150impl<S, E> WebSocket<S, E, DefaultOnConnInitType, DefaultOnPingType>
151where
152 E: Executor,
153 S: Stream<Item = serde_json::Result<ClientMessage>>,
154{
155 pub fn from_message_stream(executor: E, stream: S, protocol: Protocols) -> Self {
157 WebSocket {
158 on_connection_init: Some(default_on_connection_init),
159 on_ping: default_on_ping,
160 init_fut: None,
161 ping_fut: None,
162 connection_data: None,
163 data: None,
164 executor,
165 streams: HashMap::new(),
166 stream,
167 protocol,
168 last_msg_at: Instant::now(),
169 keepalive_timer: None,
170 close: false,
171 }
172 }
173}
174
175impl<S, E> WebSocket<MessageMapStream<S>, E, DefaultOnConnInitType, DefaultOnPingType>
176where
177 E: Executor,
178 S: Stream,
179 S::Item: AsRef<[u8]>,
180{
181 pub fn new(executor: E, stream: S, protocol: Protocols) -> Self {
183 let stream = stream
184 .map(ClientMessage::from_bytes as fn(S::Item) -> serde_json::Result<ClientMessage>);
185 WebSocket::from_message_stream(executor, stream, protocol)
186 }
187}
188
189impl<S, E, OnInit, OnPing> WebSocket<S, E, OnInit, OnPing>
190where
191 E: Executor,
192 S: Stream<Item = serde_json::Result<ClientMessage>>,
193{
194 #[must_use]
201 pub fn connection_data(mut self, data: Data) -> Self {
202 self.connection_data = Some(data);
203 self
204 }
205
206 #[must_use]
212 pub fn on_connection_init<F, R>(self, callback: F) -> WebSocket<S, E, F, OnPing>
213 where
214 F: FnOnce(serde_json::Value) -> R + Send + 'static,
215 R: Future<Output = Result<Data>> + Send + 'static,
216 {
217 WebSocket {
218 on_connection_init: Some(callback),
219 on_ping: self.on_ping,
220 init_fut: self.init_fut,
221 ping_fut: self.ping_fut,
222 connection_data: self.connection_data,
223 data: self.data,
224 executor: self.executor,
225 streams: self.streams,
226 stream: self.stream,
227 protocol: self.protocol,
228 last_msg_at: self.last_msg_at,
229 keepalive_timer: self.keepalive_timer,
230 close: self.close,
231 }
232 }
233
234 #[must_use]
243 pub fn on_ping<F, R>(self, callback: F) -> WebSocket<S, E, OnInit, F>
244 where
245 F: FnOnce(Option<&Data>, Option<serde_json::Value>) -> R + Send + Clone + 'static,
246 R: Future<Output = Result<Option<serde_json::Value>>> + Send + 'static,
247 {
248 WebSocket {
249 on_connection_init: self.on_connection_init,
250 on_ping: callback,
251 init_fut: self.init_fut,
252 ping_fut: self.ping_fut,
253 connection_data: self.connection_data,
254 data: self.data,
255 executor: self.executor,
256 streams: self.streams,
257 stream: self.stream,
258 protocol: self.protocol,
259 last_msg_at: self.last_msg_at,
260 keepalive_timer: self.keepalive_timer,
261 close: self.close,
262 }
263 }
264
265 #[must_use]
272 pub fn keepalive_timeout(self, timeout: impl Into<Option<Duration>>) -> Self {
273 Self {
274 keepalive_timer: timeout.into().map(Timer::new),
275 ..self
276 }
277 }
278}
279
280impl<S, E, OnInit, InitFut, OnPing, PingFut> Stream for WebSocket<S, E, OnInit, OnPing>
281where
282 E: Executor,
283 S: Stream<Item = serde_json::Result<ClientMessage>>,
284 OnInit: FnOnce(serde_json::Value) -> InitFut + Send + 'static,
285 InitFut: Future<Output = Result<Data>> + Send + 'static,
286 OnPing: FnOnce(Option<&Data>, Option<serde_json::Value>) -> PingFut + Clone + Send + 'static,
287 PingFut: Future<Output = Result<Option<serde_json::Value>>> + Send + 'static,
288{
289 type Item = WsMessage;
290
291 fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
292 let mut this = self.project();
293
294 if *this.close {
295 return Poll::Ready(None);
296 }
297
298 if let Some(keepalive_timer) = this.keepalive_timer
299 && let Poll::Ready(Some(())) = keepalive_timer.poll_next_unpin(cx)
300 {
301 return match this.protocol {
302 Protocols::SubscriptionsTransportWS => {
303 *this.close = true;
304 Poll::Ready(Some(WsMessage::Text(
305 serde_json::to_string(&ServerMessage::ConnectionError {
306 payload: Error::new("timeout"),
307 })
308 .unwrap(),
309 )))
310 }
311 Protocols::GraphQLWS => {
312 *this.close = true;
313 Poll::Ready(Some(WsMessage::Close(3008, "timeout".to_string())))
314 }
315 };
316 }
317
318 if this.init_fut.is_none() && this.ping_fut.is_none() {
319 while let Poll::Ready(message) = Pin::new(&mut this.stream).poll_next(cx) {
320 let message = match message {
321 Some(message) => message,
322 None => return Poll::Ready(None),
323 };
324
325 let message: ClientMessage = match message {
326 Ok(message) => message,
327 Err(err) => {
328 *this.close = true;
329 return Poll::Ready(Some(WsMessage::Close(1002, err.to_string())));
330 }
331 };
332
333 *this.last_msg_at = Instant::now();
334 if let Some(keepalive_timer) = this.keepalive_timer {
335 keepalive_timer.reset();
336 }
337
338 match message {
339 ClientMessage::ConnectionInit { payload } => {
340 if let Some(on_connection_init) = this.on_connection_init.take() {
341 *this.init_fut = Some(Box::pin(async move {
342 on_connection_init(payload.unwrap_or_default()).await
343 }));
344 break;
345 } else {
346 *this.close = true;
347 match this.protocol {
348 Protocols::SubscriptionsTransportWS => {
349 return Poll::Ready(Some(WsMessage::Text(
350 serde_json::to_string(&ServerMessage::ConnectionError {
351 payload: Error::new(
352 "Too many initialisation requests.",
353 ),
354 })
355 .unwrap(),
356 )));
357 }
358 Protocols::GraphQLWS => {
359 return Poll::Ready(Some(WsMessage::Close(
360 4429,
361 "Too many initialisation requests.".to_string(),
362 )));
363 }
364 }
365 }
366 }
367 ClientMessage::Start {
368 id,
369 payload: request,
370 } => {
371 if let Some(data) = this.data.clone() {
372 this.streams.insert(
373 id,
374 Box::pin(this.executor.execute_stream(request, Some(data))),
375 );
376 } else {
377 *this.close = true;
378 return Poll::Ready(Some(WsMessage::Close(
379 1011,
380 "The handshake is not completed.".to_string(),
381 )));
382 }
383 }
384 ClientMessage::Stop { id } => {
385 if this.streams.remove(&id).is_some() {
386 return Poll::Ready(Some(WsMessage::Text(
387 serde_json::to_string(&ServerMessage::Complete { id: &id })
388 .unwrap(),
389 )));
390 }
391 }
392 ClientMessage::ConnectionTerminate => {
396 *this.close = true;
397 return Poll::Ready(None);
398 }
399 ClientMessage::Ping { payload } => {
401 let on_ping = this.on_ping.clone();
402 let data = this.data.clone();
403 *this.ping_fut =
404 Some(Box::pin(
405 async move { on_ping(data.as_deref(), payload).await },
406 ));
407 break;
408 }
409 ClientMessage::Pong { .. } => {
410 }
412 }
413 }
414 }
415
416 if let Some(init_fut) = this.init_fut {
417 return init_fut.poll_unpin(cx).map(|res| {
418 *this.init_fut = None;
419 match res {
420 Ok(data) => {
421 let mut ctx_data = this.connection_data.take().unwrap_or_default();
422 ctx_data.merge(data);
423 *this.data = Some(Arc::new(ctx_data));
424 Some(WsMessage::Text(
425 serde_json::to_string(&ServerMessage::ConnectionAck).unwrap(),
426 ))
427 }
428 Err(err) => {
429 *this.close = true;
430 match this.protocol {
431 Protocols::SubscriptionsTransportWS => Some(WsMessage::Text(
432 serde_json::to_string(&ServerMessage::ConnectionError {
433 payload: Error::new(err.message),
434 })
435 .unwrap(),
436 )),
437 Protocols::GraphQLWS => Some(WsMessage::Close(1002, err.message)),
438 }
439 }
440 }
441 });
442 }
443
444 if let Some(ping_fut) = this.ping_fut {
445 return ping_fut.poll_unpin(cx).map(|res| {
446 *this.ping_fut = None;
447 match res {
448 Ok(payload) => Some(WsMessage::Text(
449 serde_json::to_string(&ServerMessage::Pong { payload }).unwrap(),
450 )),
451 Err(err) => {
452 *this.close = true;
453 match this.protocol {
454 Protocols::SubscriptionsTransportWS => Some(WsMessage::Text(
455 serde_json::to_string(&ServerMessage::ConnectionError {
456 payload: Error::new(err.message),
457 })
458 .unwrap(),
459 )),
460 Protocols::GraphQLWS => Some(WsMessage::Close(1002, err.message)),
461 }
462 }
463 }
464 });
465 }
466
467 for (id, stream) in &mut *this.streams {
468 match Pin::new(stream).poll_next(cx) {
469 Poll::Ready(Some(payload)) => {
470 return Poll::Ready(Some(WsMessage::Text(
471 serde_json::to_string(&this.protocol.next_message(id, payload)).unwrap(),
472 )));
473 }
474 Poll::Ready(None) => {
475 let id = id.clone();
476 this.streams.remove(&id);
477 return Poll::Ready(Some(WsMessage::Text(
478 serde_json::to_string(&ServerMessage::Complete { id: &id }).unwrap(),
479 )));
480 }
481 Poll::Pending => {}
482 }
483 }
484
485 Poll::Pending
486 }
487}
488
/// The websocket sub-protocols a `WebSocket` connection can speak.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum Protocols {
    /// Negotiated with the `Sec-WebSocket-Protocol` value `graphql-ws`. Subscription
    /// results are sent as `data` messages, and fatal errors as `connection_error`
    /// text messages.
    SubscriptionsTransportWS,
    /// Negotiated with the `Sec-WebSocket-Protocol` value `graphql-transport-ws`.
    /// Subscription results are sent as `next` messages, and fatal errors close the
    /// socket with a close code.
    GraphQLWS,
}
497
498impl Protocols {
499 pub fn sec_websocket_protocol(&self) -> &'static str {
501 match self {
502 Protocols::SubscriptionsTransportWS => "graphql-ws",
503 Protocols::GraphQLWS => "graphql-transport-ws",
504 }
505 }
506
507 #[inline]
508 fn next_message<'s>(&self, id: &'s str, payload: Response) -> ServerMessage<'s> {
509 match self {
510 Protocols::SubscriptionsTransportWS => ServerMessage::Data { id, payload },
511 Protocols::GraphQLWS => ServerMessage::Next { id, payload },
512 }
513 }
514}
515
516impl std::str::FromStr for Protocols {
517 type Err = Error;
518
519 fn from_str(protocol: &str) -> Result<Self, Self::Err> {
520 if protocol.eq_ignore_ascii_case("graphql-ws") {
521 Ok(Protocols::SubscriptionsTransportWS)
522 } else if protocol.eq_ignore_ascii_case("graphql-transport-ws") {
523 Ok(Protocols::GraphQLWS)
524 } else {
525 Err(Error::new(format!(
526 "Unsupported Sec-WebSocket-Protocol: {}",
527 protocol
528 )))
529 }
530 }
531}
532
/// A message sent by the client, deserialized from a JSON object whose `"type"`
/// field names the variant in `snake_case`.
#[derive(Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
// `Start` embeds a whole `Request`, which presumably makes it much larger than the
// other variants; the lint is silenced instead of boxing it.
#[allow(clippy::large_enum_variant)]
pub enum ClientMessage {
    /// Asks to initialise the connection.
    ConnectionInit {
        /// Init payload. It is passed to the connection-init callback (JSON `null`
        /// when absent).
        payload: Option<serde_json::Value>,
    },
    /// Starts a subscription. A `"subscribe"` type is accepted as an alias.
    #[serde(alias = "subscribe")]
    Start {
        /// Client-chosen operation id; keys the subscription on this connection.
        id: String,
        /// The GraphQL request to execute as a stream.
        payload: Request,
    },
    /// Stops the subscription with `id`. A `"complete"` type is accepted as an alias.
    #[serde(alias = "complete")]
    Stop {
        /// Operation id of the subscription to stop.
        id: String,
    },
    /// Asks the server to end the connection.
    ConnectionTerminate,
    /// A ping; the server answers with a pong.
    Ping {
        /// Passed to the ping callback.
        payload: Option<serde_json::Value>,
    },
    /// A pong. The server only uses it to restart the keepalive timer.
    Pong {
        /// Ignored by the server.
        payload: Option<serde_json::Value>,
    },
}
576
577impl ClientMessage {
578 pub fn from_bytes<T>(message: T) -> serde_json::Result<Self>
580 where
581 T: AsRef<[u8]>,
582 {
583 serde_json::from_slice(message.as_ref())
584 }
585}
586
/// A message sent by the server, serialized to a JSON object whose `"type"` field
/// names the variant in `snake_case`.
#[derive(Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ServerMessage<'a> {
    /// Reports a fatal connection error. Only emitted for
    /// `Protocols::SubscriptionsTransportWS`.
    ConnectionError {
        payload: Error,
    },
    /// Acknowledges a successful connection init.
    ConnectionAck,
    /// One subscription result under `Protocols::SubscriptionsTransportWS`.
    Data {
        /// Operation id of the subscription.
        id: &'a str,
        payload: Response,
    },
    /// One subscription result under `Protocols::GraphQLWS`.
    Next {
        /// Operation id of the subscription.
        id: &'a str,
        payload: Response,
    },
    /// Signals that a subscription is over, either because its stream ended or
    /// because the client stopped it.
    Complete {
        /// Operation id of the finished subscription.
        id: &'a str,
    },
    /// Reply to a client ping.
    Pong {
        /// Value returned by the ping callback; omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        payload: Option<serde_json::Value>,
    },
}
619 }