async_graphql_viz/
subscription.rs1use std::{borrow::Cow, future::Future};
2
3use async_graphql::{
4 http::{WebSocketProtocols, WsMessage},
5 Data, ObjectType, Result, Schema, SubscriptionType,
6};
7
8use viz_core::{
9 http::{
10 header,
11 headers::{self, Header, HeaderName, HeaderValue},
12 },
13 ws::{Message, WebSocket},
14};
15use viz_utils::{
16 futures::{future, SinkExt, StreamExt},
17 serde::json::Value,
18};
19
20#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
22pub struct SecWebsocketProtocol(pub WebSocketProtocols);
23
24impl Header for SecWebsocketProtocol {
25 fn name() -> &'static HeaderName {
26 &header::SEC_WEBSOCKET_PROTOCOL
27 }
28
29 fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error>
30 where
31 Self: Sized,
32 I: Iterator<Item = &'i HeaderValue>,
33 {
34 match values.next() {
35 Some(value) => Ok(SecWebsocketProtocol(
36 value
37 .to_str()
38 .map_err(|_| headers::Error::invalid())?
39 .parse()
40 .ok()
41 .unwrap_or(WebSocketProtocols::SubscriptionsTransportWS),
42 )),
43 None => Err(headers::Error::invalid()),
44 }
45 }
46
47 fn encode<E: Extend<HeaderValue>>(&self, values: &mut E) {
48 values.extend(std::iter::once(HeaderValue::from_static(
49 self.0.sec_websocket_protocol(),
50 )))
51 }
52}
53
54pub async fn graphql_subscription<Query, Mutation, Subscription>(
56 websocket: WebSocket,
57 schema: Schema<Query, Mutation, Subscription>,
58 protocol: SecWebsocketProtocol,
59) where
60 Query: ObjectType + Sync + Send + 'static,
61 Mutation: ObjectType + Sync + Send + 'static,
62 Subscription: SubscriptionType + Send + Sync + 'static,
63{
64 graphql_subscription_with_data(websocket, schema, protocol, |_| async {
65 Ok(Default::default())
66 })
67 .await
68}
69
70pub async fn graphql_subscription_with_data<Query, Mutation, Subscription, F, R>(
74 websocket: WebSocket,
75 schema: Schema<Query, Mutation, Subscription>,
76 protocol: SecWebsocketProtocol,
77 initializer: F,
78) where
79 Query: ObjectType + 'static,
80 Mutation: ObjectType + 'static,
81 Subscription: SubscriptionType + 'static,
82 F: FnOnce(Value) -> R + Send + 'static,
83 R: Future<Output = Result<Data>> + Send + 'static,
84{
85 let (mut sink, stream) = websocket.split();
86 let input = stream
87 .take_while(|res| future::ready(res.is_ok()))
88 .map(Result::unwrap)
89 .filter(|msg| future::ready(msg.is_text() || msg.is_binary()))
90 .map(Message::into_bytes);
91
92 let mut stream =
93 async_graphql::http::WebSocket::with_data(schema, input, initializer, protocol.0).map(
94 |msg| match msg {
95 WsMessage::Text(text) => Message::text(text),
96 WsMessage::Close(code, status) => Message::close_with(code, Cow::from(status)),
97 },
98 );
99
100 while let Some(item) = stream.next().await {
101 let _ = sink.send(item).await;
102 }
103}