async_graphql_warp/
subscription.rs

1use std::{future::Future, str::FromStr, time::Duration};
2
3use async_graphql::{
4    Data, Executor, Result,
5    http::{
6        DefaultOnConnInitType, DefaultOnPingType, WebSocketProtocols, WsMessage,
7        default_on_connection_init, default_on_ping,
8    },
9};
10use futures_util::{
11    Sink, Stream, StreamExt, future,
12    stream::{SplitSink, SplitStream},
13};
14use warp::{Error, Filter, Rejection, Reply, filters::ws, ws::Message};
15
16/// GraphQL subscription filter
17///
18/// # Examples
19///
20/// ```no_run
21/// use std::time::Duration;
22///
23/// use async_graphql::*;
24/// use async_graphql_warp::*;
25/// use futures_util::stream::{Stream, StreamExt};
26/// use warp::Filter;
27///
28/// struct QueryRoot;
29///
30/// #[Object]
31/// impl QueryRoot {
32///     async fn value(&self) -> i32 {
33///         // A GraphQL Object type must define one or more fields.
34///         100
35///     }
36/// }
37///
38/// struct SubscriptionRoot;
39///
40/// #[Subscription]
41/// impl SubscriptionRoot {
42///     async fn tick(&self) -> impl Stream<Item = String> {
43///         async_stream::stream! {
44///             let mut interval = tokio::time::interval(Duration::from_secs(1));
45///             loop {
46///                 let n = interval.tick().await;
47///                 yield format!("{}", n.elapsed().as_secs_f32());
48///             }
49///         }
50///     }
51/// }
52///
53/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
54/// let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
55/// let filter =
56///     async_graphql_warp::graphql_subscription(schema).or(warp::any().map(|| "Hello, World!"));
57/// warp::serve(filter).run(([0, 0, 0, 0], 8000)).await;
58/// # });
59/// ```
60pub fn graphql_subscription<E>(
61    executor: E,
62) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone
63where
64    E: Executor,
65{
66    warp::ws()
67        .and(graphql_protocol())
68        .map(move |ws: ws::Ws, protocol| {
69            let executor = executor.clone();
70
71            let reply = ws.on_upgrade(move |socket| {
72                GraphQLWebSocket::new(socket, executor, protocol).serve()
73            });
74
75            warp::reply::with_header(
76                reply,
77                "Sec-WebSocket-Protocol",
78                protocol.sec_websocket_protocol(),
79            )
80        })
81}
82
83/// Create a `Filter` that parse [WebSocketProtocols] from
84/// `sec-websocket-protocol` header.
85pub fn graphql_protocol() -> impl Filter<Extract = (WebSocketProtocols,), Error = Rejection> + Clone
86{
87    warp::header::optional::<String>("sec-websocket-protocol").map(|protocols: Option<String>| {
88        protocols
89            .and_then(|protocols| {
90                protocols
91                    .split(',')
92                    .find_map(|p| WebSocketProtocols::from_str(p.trim()).ok())
93            })
94            .unwrap_or(WebSocketProtocols::SubscriptionsTransportWS)
95    })
96}
97
98/// A Websocket connection for GraphQL subscription.
99///
100/// # Examples
101///
102/// ```no_run
103/// use std::time::Duration;
104///
105/// use async_graphql::*;
106/// use async_graphql_warp::*;
107/// use futures_util::stream::{Stream, StreamExt};
108/// use warp::{Filter, ws};
109///
110/// struct QueryRoot;
111///
112/// #[Object]
113/// impl QueryRoot {
114///     async fn value(&self) -> i32 {
115///         // A GraphQL Object type must define one or more fields.
116///         100
117///     }
118/// }
119///
120/// struct SubscriptionRoot;
121///
122/// #[Subscription]
123/// impl SubscriptionRoot {
124///     async fn tick(&self) -> impl Stream<Item = String> {
125///         async_stream::stream! {
126///             let mut interval = tokio::time::interval(Duration::from_secs(1));
127///             loop {
128///                 let n = interval.tick().await;
129///                 yield format!("{}", n.elapsed().as_secs_f32());
130///             }
131///         }
132///     }
133/// }
134///
135/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
136/// let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
137///
138/// let filter = warp::ws()
139///     .and(graphql_protocol())
140///     .map(move |ws: ws::Ws, protocol| {
141///         let schema = schema.clone();
142///
143///         let reply = ws
144///             .on_upgrade(move |socket| GraphQLWebSocket::new(socket, schema, protocol).serve());
145///
146///         warp::reply::with_header(
147///             reply,
148///             "Sec-WebSocket-Protocol",
149///             protocol.sec_websocket_protocol(),
150///         )
151///     });
152///
153/// warp::serve(filter).run(([0, 0, 0, 0], 8000)).await;
154/// # });
155/// ```
156pub struct GraphQLWebSocket<Sink, Stream, E, OnInit, OnPing> {
157    sink: Sink,
158    stream: Stream,
159    protocol: WebSocketProtocols,
160    executor: E,
161    data: Data,
162    on_init: OnInit,
163    on_ping: OnPing,
164    keepalive_timeout: Option<Duration>,
165}
166
167impl<S, E>
168    GraphQLWebSocket<
169        SplitSink<S, Message>,
170        SplitStream<S>,
171        E,
172        DefaultOnConnInitType,
173        DefaultOnPingType,
174    >
175where
176    S: Stream<Item = Result<Message, Error>> + Sink<Message>,
177    E: Executor,
178{
179    /// Create a [`GraphQLWebSocket`] object.
180    pub fn new(socket: S, executor: E, protocol: WebSocketProtocols) -> Self {
181        let (sink, stream) = socket.split();
182        GraphQLWebSocket::new_with_pair(sink, stream, executor, protocol)
183    }
184}
185
186impl<Sink, Stream, E> GraphQLWebSocket<Sink, Stream, E, DefaultOnConnInitType, DefaultOnPingType>
187where
188    Sink: futures_util::sink::Sink<Message>,
189    Stream: futures_util::stream::Stream<Item = Result<Message, Error>>,
190    E: Executor,
191{
192    /// Create a [`GraphQLWebSocket`] object with sink and stream objects.
193    pub fn new_with_pair(
194        sink: Sink,
195        stream: Stream,
196        executor: E,
197        protocol: WebSocketProtocols,
198    ) -> Self {
199        GraphQLWebSocket {
200            sink,
201            stream,
202            protocol,
203            executor,
204            data: Data::default(),
205            on_init: default_on_connection_init,
206            on_ping: default_on_ping,
207            keepalive_timeout: None,
208        }
209    }
210}
211
212impl<Sink, Stream, E, OnConnInit, OnConnInitFut, OnPing, OnPingFut>
213    GraphQLWebSocket<Sink, Stream, E, OnConnInit, OnPing>
214where
215    Sink: futures_util::sink::Sink<Message>,
216    Stream: futures_util::stream::Stream<Item = Result<Message, Error>>,
217    E: Executor,
218    OnConnInit: FnOnce(serde_json::Value) -> OnConnInitFut + Send + 'static,
219    OnConnInitFut: Future<Output = async_graphql::Result<Data>> + Send + 'static,
220    OnPing: FnOnce(Option<&Data>, Option<serde_json::Value>) -> OnPingFut + Clone + Send + 'static,
221    OnPingFut: Future<Output = async_graphql::Result<Option<serde_json::Value>>> + Send + 'static,
222{
223    /// Specify the initial subscription context data, usually you can get
224    /// something from the incoming request to create it.
225    #[must_use]
226    pub fn with_data(self, data: Data) -> Self {
227        Self { data, ..self }
228    }
229
230    /// Specify a callback function to be called when the connection is
231    /// initialized.
232    ///
233    /// You can get something from the payload of [`GQL_CONNECTION_INIT` message](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init) to create [`Data`].
234    /// The data returned by this callback function will be merged with the data
235    /// specified by [`with_data`].
236    #[must_use]
237    pub fn on_connection_init<F, R>(
238        self,
239        callback: F,
240    ) -> GraphQLWebSocket<Sink, Stream, E, F, OnPing>
241    where
242        F: FnOnce(serde_json::Value) -> R + Send + 'static,
243        R: Future<Output = async_graphql::Result<Data>> + Send + 'static,
244    {
245        GraphQLWebSocket {
246            sink: self.sink,
247            stream: self.stream,
248            executor: self.executor,
249            data: self.data,
250            on_init: callback,
251            on_ping: self.on_ping,
252            protocol: self.protocol,
253            keepalive_timeout: self.keepalive_timeout,
254        }
255    }
256
257    /// Specify a ping callback function.
258    ///
259    /// This function if present, will be called with the data sent by the
260    /// client in the [`Ping` message](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#ping).
261    ///
262    /// The function should return the data to be sent in the [`Pong` message](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#pong).
263    ///
264    /// NOTE: Only used for the `graphql-ws` protocol.
265    #[must_use]
266    pub fn on_ping<F, R>(self, callback: F) -> GraphQLWebSocket<Sink, Stream, E, OnConnInit, F>
267    where
268        F: FnOnce(Option<&Data>, Option<serde_json::Value>) -> R + Send + Clone + 'static,
269        R: Future<Output = Result<Option<serde_json::Value>>> + Send + 'static,
270    {
271        GraphQLWebSocket {
272            sink: self.sink,
273            stream: self.stream,
274            executor: self.executor,
275            data: self.data,
276            on_init: self.on_init,
277            on_ping: callback,
278            protocol: self.protocol,
279            keepalive_timeout: self.keepalive_timeout,
280        }
281    }
282
283    /// Sets a timeout for receiving an acknowledgement of the keep-alive ping.
284    ///
285    /// If the ping is not acknowledged within the timeout, the connection will
286    /// be closed.
287    ///
288    /// NOTE: Only used for the `graphql-ws` protocol.
289    #[must_use]
290    pub fn keepalive_timeout(self, timeout: impl Into<Option<Duration>>) -> Self {
291        Self {
292            keepalive_timeout: timeout.into(),
293            ..self
294        }
295    }
296
297    /// Processing subscription requests.
298    pub async fn serve(self) {
299        let stream = self
300            .stream
301            .take_while(|msg| future::ready(msg.is_ok()))
302            .map(Result::unwrap)
303            .filter(|msg| future::ready(msg.is_text() || msg.is_binary()))
304            .map(ws::Message::into_bytes);
305
306        let _ = async_graphql::http::WebSocket::new(self.executor.clone(), stream, self.protocol)
307            .connection_data(self.data)
308            .on_connection_init(self.on_init)
309            .on_ping(self.on_ping)
310            .keepalive_timeout(self.keepalive_timeout)
311            .map(|msg| match msg {
312                WsMessage::Text(text) => ws::Message::text(text),
313                WsMessage::Close(code, status) => ws::Message::close_with(code, status),
314            })
315            .map(Ok)
316            .forward(self.sink)
317            .await;
318    }
319}