1use std::{
4 collections::HashMap,
5 future::Future,
6 pin::Pin,
7 sync::Arc,
8 task::{Context, Poll},
9 time::{Duration, Instant},
10};
11
12use futures_util::{
13 FutureExt, StreamExt,
14 future::{BoxFuture, Ready},
15 stream::Stream,
16};
17use pin_project_lite::pin_project;
18use serde::{Deserialize, Serialize};
19
20use crate::{Data, Error, Executor, Request, Response, Result, runtime::Timer as RtTimer};
21
/// Every `Sec-WebSocket-Protocol` token recognised by the `FromStr` impl of [`Protocols`].
pub const ALL_WEBSOCKET_PROTOCOLS: [&str; 2] = ["graphql-transport-ws", "graphql-ws"];

/// A frame the server should send to the client.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum WsMessage {
    /// A text frame; this module only ever fills it with a JSON-serialized server message.
    Text(String),

    /// A close frame carrying a close code and a reason string.
    Close(u16, String),
}

impl WsMessage {
    /// Returns the payload of a [`WsMessage::Text`].
    ///
    /// # Panics
    ///
    /// Panics if the message is a [`WsMessage::Close`]. Thanks to `#[track_caller]` the panic
    /// is reported at the caller's location, like `Option::unwrap`.
    #[track_caller]
    pub fn unwrap_text(self) -> String {
        match self {
            Self::Text(text) => text,
            Self::Close(_, _) => panic!("Not a text message"),
        }
    }

    /// Returns the `(code, reason)` pair of a [`WsMessage::Close`].
    ///
    /// # Panics
    ///
    /// Panics if the message is a [`WsMessage::Text`]. Thanks to `#[track_caller]` the panic
    /// is reported at the caller's location, like `Option::unwrap`.
    #[track_caller]
    pub fn unwrap_close(self) -> (u16, String) {
        match self {
            Self::Close(code, msg) => (code, msg),
            Self::Text(_) => panic!("Not a close message"),
        }
    }
}
66
67struct Timer {
68 interval: Duration,
69 rt_timer: Box<dyn RtTimer>,
70 future: BoxFuture<'static, ()>,
71}
72
73impl Timer {
74 #[inline]
75 fn new<T>(rt_timer: T, interval: Duration) -> Self
76 where
77 T: RtTimer,
78 {
79 Self {
80 interval,
81 future: rt_timer.delay(interval),
82 rt_timer: Box::new(rt_timer),
83 }
84 }
85
86 #[inline]
87 fn reset(&mut self) {
88 self.future = self.rt_timer.delay(self.interval);
89 }
90}
91
92impl Stream for Timer {
93 type Item = ();
94
95 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
96 let this = &mut *self;
97 match this.future.poll_unpin(cx) {
98 Poll::Ready(_) => {
99 this.reset();
100 Poll::Ready(Some(()))
101 }
102 Poll::Pending => Poll::Pending,
103 }
104 }
105}
106
pin_project! {
    /// A GraphQL connection over a websocket, exposed as a [`Stream`] of outgoing [`WsMessage`]s.
    ///
    /// It reads decoded client messages from `stream`, runs subscriptions through `executor`,
    /// and yields the frames the server should send back to the client.
    pub struct WebSocket<S, E, OnInit, OnPing> {
        // Handler for the first `connection_init`; `take()`n when used, so `None` afterwards.
        on_connection_init: Option<OnInit>,
        // Handler for client `ping`s; cloned for every ping.
        on_ping: OnPing,
        // Pending `on_connection_init` future; while set, no further client messages are read.
        init_fut: Option<BoxFuture<'static, Result<Data>>>,
        // Pending `on_ping` future; while set, no further client messages are read.
        ping_fut: Option<BoxFuture<'static, Result<Option<serde_json::Value>>>>,
        // Builder-supplied data, merged with the init handler's result once init succeeds.
        connection_data: Option<Data>,
        // Context data shared by every operation; `None` until init succeeds.
        data: Option<Arc<Data>>,
        executor: E,
        // Running operations keyed by the client-chosen operation id.
        streams: HashMap<String, Pin<Box<dyn Stream<Item = Response> + Send>>>,
        // Incoming, already-decoded client messages.
        #[pin]
        stream: S,
        protocol: Protocols,
        // Time the last client message was received.
        last_msg_at: Instant,
        // Optional keep-alive countdown, restarted on every received message.
        keepalive_timer: Option<Timer>,
        // Once set, `poll_next` returns `Ready(None)`.
        close: bool,
    }
}
131
132type MessageMapStream<S> =
133 futures_util::stream::Map<S, fn(<S as Stream>::Item) -> serde_json::Result<ClientMessage>>;
134
135pub type DefaultOnConnInitType = fn(serde_json::Value) -> Ready<Result<Data>>;
137
138pub type DefaultOnPingType =
140 fn(Option<&Data>, Option<serde_json::Value>) -> Ready<Result<Option<serde_json::Value>>>;
141
142pub fn default_on_connection_init(_: serde_json::Value) -> Ready<Result<Data>> {
144 futures_util::future::ready(Ok(Data::default()))
145}
146
147pub fn default_on_ping(
149 _: Option<&Data>,
150 _: Option<serde_json::Value>,
151) -> Ready<Result<Option<serde_json::Value>>> {
152 futures_util::future::ready(Ok(None))
153}
154
155impl<S, E> WebSocket<S, E, DefaultOnConnInitType, DefaultOnPingType>
156where
157 E: Executor,
158 S: Stream<Item = serde_json::Result<ClientMessage>>,
159{
160 pub fn from_message_stream(executor: E, stream: S, protocol: Protocols) -> Self {
162 WebSocket {
163 on_connection_init: Some(default_on_connection_init),
164 on_ping: default_on_ping,
165 init_fut: None,
166 ping_fut: None,
167 connection_data: None,
168 data: None,
169 executor,
170 streams: HashMap::new(),
171 stream,
172 protocol,
173 last_msg_at: Instant::now(),
174 keepalive_timer: None,
175 close: false,
176 }
177 }
178}
179
180impl<S, E> WebSocket<MessageMapStream<S>, E, DefaultOnConnInitType, DefaultOnPingType>
181where
182 E: Executor,
183 S: Stream,
184 S::Item: AsRef<[u8]>,
185{
186 pub fn new(executor: E, stream: S, protocol: Protocols) -> Self {
188 let stream = stream
189 .map(ClientMessage::from_bytes as fn(S::Item) -> serde_json::Result<ClientMessage>);
190 WebSocket::from_message_stream(executor, stream, protocol)
191 }
192}
193
194impl<S, E, OnInit, OnPing> WebSocket<S, E, OnInit, OnPing>
195where
196 E: Executor,
197 S: Stream<Item = serde_json::Result<ClientMessage>>,
198{
199 #[must_use]
206 pub fn connection_data(mut self, data: Data) -> Self {
207 self.connection_data = Some(data);
208 self
209 }
210
211 #[must_use]
217 pub fn on_connection_init<F, R>(self, callback: F) -> WebSocket<S, E, F, OnPing>
218 where
219 F: FnOnce(serde_json::Value) -> R + Send + 'static,
220 R: Future<Output = Result<Data>> + Send + 'static,
221 {
222 WebSocket {
223 on_connection_init: Some(callback),
224 on_ping: self.on_ping,
225 init_fut: self.init_fut,
226 ping_fut: self.ping_fut,
227 connection_data: self.connection_data,
228 data: self.data,
229 executor: self.executor,
230 streams: self.streams,
231 stream: self.stream,
232 protocol: self.protocol,
233 last_msg_at: self.last_msg_at,
234 keepalive_timer: self.keepalive_timer,
235 close: self.close,
236 }
237 }
238
239 #[must_use]
248 pub fn on_ping<F, R>(self, callback: F) -> WebSocket<S, E, OnInit, F>
249 where
250 F: FnOnce(Option<&Data>, Option<serde_json::Value>) -> R + Send + Clone + 'static,
251 R: Future<Output = Result<Option<serde_json::Value>>> + Send + 'static,
252 {
253 WebSocket {
254 on_connection_init: self.on_connection_init,
255 on_ping: callback,
256 init_fut: self.init_fut,
257 ping_fut: self.ping_fut,
258 connection_data: self.connection_data,
259 data: self.data,
260 executor: self.executor,
261 streams: self.streams,
262 stream: self.stream,
263 protocol: self.protocol,
264 last_msg_at: self.last_msg_at,
265 keepalive_timer: self.keepalive_timer,
266 close: self.close,
267 }
268 }
269
270 #[must_use]
277 pub fn keepalive_timeout<T>(self, timer: T, timeout: impl Into<Option<Duration>>) -> Self
278 where
279 T: RtTimer,
280 {
281 Self {
282 keepalive_timer: timeout.into().map(|timeout| Timer::new(timer, timeout)),
283 ..self
284 }
285 }
286}
287
288impl<S, E, OnInit, InitFut, OnPing, PingFut> Stream for WebSocket<S, E, OnInit, OnPing>
289where
290 E: Executor,
291 S: Stream<Item = serde_json::Result<ClientMessage>>,
292 OnInit: FnOnce(serde_json::Value) -> InitFut + Send + 'static,
293 InitFut: Future<Output = Result<Data>> + Send + 'static,
294 OnPing: FnOnce(Option<&Data>, Option<serde_json::Value>) -> PingFut + Clone + Send + 'static,
295 PingFut: Future<Output = Result<Option<serde_json::Value>>> + Send + 'static,
296{
297 type Item = WsMessage;
298
299 fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
300 let mut this = self.project();
301
302 if *this.close {
303 return Poll::Ready(None);
304 }
305
306 if let Some(keepalive_timer) = this.keepalive_timer
307 && let Poll::Ready(Some(())) = keepalive_timer.poll_next_unpin(cx)
308 {
309 return match this.protocol {
310 Protocols::SubscriptionsTransportWS => {
311 *this.close = true;
312 Poll::Ready(Some(WsMessage::Text(
313 serde_json::to_string(&ServerMessage::ConnectionError {
314 payload: Error::new("timeout"),
315 })
316 .unwrap(),
317 )))
318 }
319 Protocols::GraphQLWS => {
320 *this.close = true;
321 Poll::Ready(Some(WsMessage::Close(3008, "timeout".to_string())))
322 }
323 };
324 }
325
326 if this.init_fut.is_none() && this.ping_fut.is_none() {
327 while let Poll::Ready(message) = Pin::new(&mut this.stream).poll_next(cx) {
328 let message = match message {
329 Some(message) => message,
330 None => return Poll::Ready(None),
331 };
332
333 let message: ClientMessage = match message {
334 Ok(message) => message,
335 Err(err) => {
336 *this.close = true;
337 return Poll::Ready(Some(WsMessage::Close(1002, err.to_string())));
338 }
339 };
340
341 *this.last_msg_at = Instant::now();
342 if let Some(keepalive_timer) = this.keepalive_timer {
343 keepalive_timer.reset();
344 }
345
346 match message {
347 ClientMessage::ConnectionInit { payload } => {
348 if let Some(on_connection_init) = this.on_connection_init.take() {
349 *this.init_fut = Some(Box::pin(async move {
350 on_connection_init(payload.unwrap_or_default()).await
351 }));
352 break;
353 } else {
354 *this.close = true;
355 match this.protocol {
356 Protocols::SubscriptionsTransportWS => {
357 return Poll::Ready(Some(WsMessage::Text(
358 serde_json::to_string(&ServerMessage::ConnectionError {
359 payload: Error::new(
360 "Too many initialisation requests.",
361 ),
362 })
363 .unwrap(),
364 )));
365 }
366 Protocols::GraphQLWS => {
367 return Poll::Ready(Some(WsMessage::Close(
368 4429,
369 "Too many initialisation requests.".to_string(),
370 )));
371 }
372 }
373 }
374 }
375 ClientMessage::Start {
376 id,
377 payload: request,
378 } => {
379 if let Some(data) = this.data.clone() {
380 this.streams.insert(
381 id,
382 Box::pin(this.executor.execute_stream(request, Some(data))),
383 );
384 } else {
385 *this.close = true;
386 return Poll::Ready(Some(WsMessage::Close(
387 1011,
388 "The handshake is not completed.".to_string(),
389 )));
390 }
391 }
392 ClientMessage::Stop { id } => {
393 if this.streams.remove(&id).is_some() {
394 return Poll::Ready(Some(WsMessage::Text(
395 serde_json::to_string(&ServerMessage::Complete { id: &id })
396 .unwrap(),
397 )));
398 }
399 }
400 ClientMessage::ConnectionTerminate => {
404 *this.close = true;
405 return Poll::Ready(None);
406 }
407 ClientMessage::Ping { payload } => {
409 let on_ping = this.on_ping.clone();
410 let data = this.data.clone();
411 *this.ping_fut =
412 Some(Box::pin(
413 async move { on_ping(data.as_deref(), payload).await },
414 ));
415 break;
416 }
417 ClientMessage::Pong { .. } => {
418 }
420 }
421 }
422 }
423
424 if let Some(init_fut) = this.init_fut {
425 return init_fut.poll_unpin(cx).map(|res| {
426 *this.init_fut = None;
427 match res {
428 Ok(data) => {
429 let mut ctx_data = this.connection_data.take().unwrap_or_default();
430 ctx_data.merge(data);
431 *this.data = Some(Arc::new(ctx_data));
432 Some(WsMessage::Text(
433 serde_json::to_string(&ServerMessage::ConnectionAck).unwrap(),
434 ))
435 }
436 Err(err) => {
437 *this.close = true;
438 match this.protocol {
439 Protocols::SubscriptionsTransportWS => Some(WsMessage::Text(
440 serde_json::to_string(&ServerMessage::ConnectionError {
441 payload: Error::new(err.message),
442 })
443 .unwrap(),
444 )),
445 Protocols::GraphQLWS => Some(WsMessage::Close(1002, err.message)),
446 }
447 }
448 }
449 });
450 }
451
452 if let Some(ping_fut) = this.ping_fut {
453 return ping_fut.poll_unpin(cx).map(|res| {
454 *this.ping_fut = None;
455 match res {
456 Ok(payload) => Some(WsMessage::Text(
457 serde_json::to_string(&ServerMessage::Pong { payload }).unwrap(),
458 )),
459 Err(err) => {
460 *this.close = true;
461 match this.protocol {
462 Protocols::SubscriptionsTransportWS => Some(WsMessage::Text(
463 serde_json::to_string(&ServerMessage::ConnectionError {
464 payload: Error::new(err.message),
465 })
466 .unwrap(),
467 )),
468 Protocols::GraphQLWS => Some(WsMessage::Close(1002, err.message)),
469 }
470 }
471 }
472 });
473 }
474
475 for (id, stream) in &mut *this.streams {
476 match Pin::new(stream).poll_next(cx) {
477 Poll::Ready(Some(payload)) => {
478 return Poll::Ready(Some(WsMessage::Text(
479 serde_json::to_string(&this.protocol.next_message(id, payload)).unwrap(),
480 )));
481 }
482 Poll::Ready(None) => {
483 let id = id.clone();
484 this.streams.remove(&id);
485 return Poll::Ready(Some(WsMessage::Text(
486 serde_json::to_string(&ServerMessage::Complete { id: &id }).unwrap(),
487 )));
488 }
489 Poll::Pending => {}
490 }
491 }
492
493 Poll::Pending
494 }
495}
496
497#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
499pub enum Protocols {
500 SubscriptionsTransportWS,
502 GraphQLWS,
504}
505
506impl Protocols {
507 pub fn sec_websocket_protocol(&self) -> &'static str {
509 match self {
510 Protocols::SubscriptionsTransportWS => "graphql-ws",
511 Protocols::GraphQLWS => "graphql-transport-ws",
512 }
513 }
514
515 #[inline]
516 fn next_message<'s>(&self, id: &'s str, payload: Response) -> ServerMessage<'s> {
517 match self {
518 Protocols::SubscriptionsTransportWS => ServerMessage::Data { id, payload },
519 Protocols::GraphQLWS => ServerMessage::Next { id, payload },
520 }
521 }
522}
523
524impl std::str::FromStr for Protocols {
525 type Err = Error;
526
527 fn from_str(protocol: &str) -> Result<Self, Self::Err> {
528 if protocol.eq_ignore_ascii_case("graphql-ws") {
529 Ok(Protocols::SubscriptionsTransportWS)
530 } else if protocol.eq_ignore_ascii_case("graphql-transport-ws") {
531 Ok(Protocols::GraphQLWS)
532 } else {
533 Err(Error::new(format!(
534 "Unsupported Sec-WebSocket-Protocol: {}",
535 protocol
536 )))
537 }
538 }
539}
540
/// A message sent by the client.
///
/// Deserialized from a JSON object whose `type` field selects the variant by its snake_case
/// name. `subscribe` and `complete` are also accepted as aliases of `start` and `stop`.
#[derive(Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
// `Start` holds a whole `Request`, presumably the variant that trips this lint. Boxing it
// would change the public shape of the variant.
#[allow(clippy::large_enum_variant)]
pub enum ClientMessage {
    /// Opens the connection.
    ConnectionInit {
        // Passed to the `connection_init` handler (`Value::Null` when absent).
        payload: Option<serde_json::Value>,
    },
    /// Starts an operation.
    #[serde(alias = "subscribe")]
    Start {
        // Client-chosen operation id, used as the key of the running stream.
        id: String,
        // The GraphQL request to execute.
        payload: Request,
    },
    /// Stops the operation with the given id.
    #[serde(alias = "complete")]
    Stop {
        id: String,
    },
    /// Asks the server to end the connection.
    ConnectionTerminate,
    /// A ping; the server replies with a `pong`.
    Ping {
        // Passed to the ping handler.
        payload: Option<serde_json::Value>,
    },
    /// A pong; only restarts the keep-alive countdown.
    Pong {
        payload: Option<serde_json::Value>,
    },
}
584
585impl ClientMessage {
586 pub fn from_bytes<T>(message: T) -> serde_json::Result<Self>
588 where
589 T: AsRef<[u8]>,
590 {
591 serde_json::from_slice(message.as_ref())
592 }
593}
594
/// A message sent by the server.
///
/// Serialized as a JSON object whose `type` field is the snake_case variant name.
#[derive(Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ServerMessage<'a> {
    /// Fatal connection error; used only under subscriptions-transport-ws.
    ConnectionError {
        payload: Error,
    },
    /// Sent once the `connection_init` handler succeeds.
    ConnectionAck,
    /// One operation result under subscriptions-transport-ws.
    Data {
        id: &'a str,
        payload: Response,
    },
    /// One operation result under graphql-transport-ws.
    Next {
        id: &'a str,
        payload: Response,
    },
    /// An operation finished, either because its stream ended or because the client stopped it.
    Complete {
        id: &'a str,
    },
    /// Reply to a client `ping`.
    Pong {
        // Omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        payload: Option<serde_json::Value>,
    },
}
627 }