async_graphql_viz/
subscription.rs

1use std::{borrow::Cow, future::Future};
2
3use async_graphql::{
4    http::{WebSocketProtocols, WsMessage},
5    Data, ObjectType, Result, Schema, SubscriptionType,
6};
7
8use viz_core::{
9    http::{
10        header,
11        headers::{self, Header, HeaderName, HeaderValue},
12    },
13    ws::{Message, WebSocket},
14};
15use viz_utils::{
16    futures::{future, SinkExt, StreamExt},
17    serde::json::Value,
18};
19
20/// The Sec-Websocket-Protocol header.
21#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
22pub struct SecWebsocketProtocol(pub WebSocketProtocols);
23
24impl Header for SecWebsocketProtocol {
25    fn name() -> &'static HeaderName {
26        &header::SEC_WEBSOCKET_PROTOCOL
27    }
28
29    fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error>
30    where
31        Self: Sized,
32        I: Iterator<Item = &'i HeaderValue>,
33    {
34        match values.next() {
35            Some(value) => Ok(SecWebsocketProtocol(
36                value
37                    .to_str()
38                    .map_err(|_| headers::Error::invalid())?
39                    .parse()
40                    .ok()
41                    .unwrap_or(WebSocketProtocols::SubscriptionsTransportWS),
42            )),
43            None => Err(headers::Error::invalid()),
44        }
45    }
46
47    fn encode<E: Extend<HeaderValue>>(&self, values: &mut E) {
48        values.extend(std::iter::once(HeaderValue::from_static(
49            self.0.sec_websocket_protocol(),
50        )))
51    }
52}
53
54/// GraphQL subscription handler
55pub async fn graphql_subscription<Query, Mutation, Subscription>(
56    websocket: WebSocket,
57    schema: Schema<Query, Mutation, Subscription>,
58    protocol: SecWebsocketProtocol,
59) where
60    Query: ObjectType + Sync + Send + 'static,
61    Mutation: ObjectType + Sync + Send + 'static,
62    Subscription: SubscriptionType + Send + Sync + 'static,
63{
64    graphql_subscription_with_data(websocket, schema, protocol, |_| async {
65        Ok(Default::default())
66    })
67    .await
68}
69
70/// GraphQL subscription handler
71///
72/// Specifies that a function converts the init payload to data.
73pub async fn graphql_subscription_with_data<Query, Mutation, Subscription, F, R>(
74    websocket: WebSocket,
75    schema: Schema<Query, Mutation, Subscription>,
76    protocol: SecWebsocketProtocol,
77    initializer: F,
78) where
79    Query: ObjectType + 'static,
80    Mutation: ObjectType + 'static,
81    Subscription: SubscriptionType + 'static,
82    F: FnOnce(Value) -> R + Send + 'static,
83    R: Future<Output = Result<Data>> + Send + 'static,
84{
85    let (mut sink, stream) = websocket.split();
86    let input = stream
87        .take_while(|res| future::ready(res.is_ok()))
88        .map(Result::unwrap)
89        .filter(|msg| future::ready(msg.is_text() || msg.is_binary()))
90        .map(Message::into_bytes);
91
92    let mut stream =
93        async_graphql::http::WebSocket::with_data(schema, input, initializer, protocol.0).map(
94            |msg| match msg {
95                WsMessage::Text(text) => Message::text(text),
96                WsMessage::Close(code, status) => Message::close_with(code, Cow::from(status)),
97            },
98        );
99
100    while let Some(item) = stream.next().await {
101        let _ = sink.send(item).await;
102    }
103}