async_graphql_warp/
subscription.rs

1use std::{future::Future, str::FromStr, time::Duration};
2
3use async_graphql::{
4    Data, Executor, Result,
5    http::{
6        DefaultOnConnInitType, DefaultOnPingType, WebSocketProtocols, WsMessage,
7        default_on_connection_init, default_on_ping,
8    },
9};
10use futures_util::{
11    Sink, Stream, StreamExt, future,
12    stream::{SplitSink, SplitStream},
13};
14use warp::{Error, Filter, Rejection, Reply, filters::ws, ws::Message};
15
16/// GraphQL subscription filter
17///
18/// # Examples
19///
20/// ```no_run
21/// use std::time::Duration;
22///
23/// use async_graphql::*;
24/// use async_graphql_warp::*;
25/// use futures_util::stream::{Stream, StreamExt};
26/// use warp::Filter;
27///
28/// struct QueryRoot;
29///
30/// #[Object]
31/// impl QueryRoot {
32///     async fn value(&self) -> i32 {
33///         // A GraphQL Object type must define one or more fields.
34///         100
35///     }
36/// }
37///
38/// struct SubscriptionRoot;
39///
40/// #[Subscription]
41/// impl SubscriptionRoot {
42///     async fn tick(&self) -> impl Stream<Item = String> {
43///         asynk_strim::stream_fn(|mut yielder| async move {
44///             let mut interval = tokio::time::interval(Duration::from_secs(1));
45///             loop {
46///                 let n = interval.tick().await;
47///                 yielder
48///                     .yield_item(format!("{}", n.elapsed().as_secs_f32()))
49///                     .await;
50///             }
51///         })
52///     }
53/// }
54///
55/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
56/// let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
57/// let filter =
58///     async_graphql_warp::graphql_subscription(schema).or(warp::any().map(|| "Hello, World!"));
59/// warp::serve(filter).run(([0, 0, 0, 0], 8000)).await;
60/// # });
61/// ```
62pub fn graphql_subscription<E>(
63    executor: E,
64) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone
65where
66    E: Executor,
67{
68    warp::ws()
69        .and(graphql_protocol())
70        .map(move |ws: ws::Ws, protocol| {
71            let executor = executor.clone();
72
73            let reply = ws.on_upgrade(move |socket| {
74                GraphQLWebSocket::new(socket, executor, protocol).serve()
75            });
76
77            warp::reply::with_header(
78                reply,
79                "Sec-WebSocket-Protocol",
80                protocol.sec_websocket_protocol(),
81            )
82        })
83}
84
85/// Create a `Filter` that parse [WebSocketProtocols] from
86/// `sec-websocket-protocol` header.
87pub fn graphql_protocol() -> impl Filter<Extract = (WebSocketProtocols,), Error = Rejection> + Clone
88{
89    warp::header::optional::<String>("sec-websocket-protocol").map(|protocols: Option<String>| {
90        protocols
91            .and_then(|protocols| {
92                protocols
93                    .split(',')
94                    .find_map(|p| WebSocketProtocols::from_str(p.trim()).ok())
95            })
96            .unwrap_or(WebSocketProtocols::SubscriptionsTransportWS)
97    })
98}
99
100/// A Websocket connection for GraphQL subscription.
101///
102/// # Examples
103///
104/// ```no_run
105/// use std::time::Duration;
106///
107/// use async_graphql::*;
108/// use async_graphql_warp::*;
109/// use futures_util::stream::{Stream, StreamExt};
110/// use warp::{Filter, ws};
111///
112/// struct QueryRoot;
113///
114/// #[Object]
115/// impl QueryRoot {
116///     async fn value(&self) -> i32 {
117///         // A GraphQL Object type must define one or more fields.
118///         100
119///     }
120/// }
121///
122/// struct SubscriptionRoot;
123///
124/// #[Subscription]
125/// impl SubscriptionRoot {
126///     async fn tick(&self) -> impl Stream<Item = String> {
127///         asynk_strim::stream_fn(|mut yielder| async move {
128///             let mut interval = tokio::time::interval(Duration::from_secs(1));
129///             loop {
130///                 let n = interval.tick().await;
131///                 yielder
132///                     .yield_item(format!("{}", n.elapsed().as_secs_f32()))
133///                     .await;
134///             }
135///         })
136///     }
137/// }
138///
139/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
140/// let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
141///
142/// let filter = warp::ws()
143///     .and(graphql_protocol())
144///     .map(move |ws: ws::Ws, protocol| {
145///         let schema = schema.clone();
146///
147///         let reply = ws
148///             .on_upgrade(move |socket| GraphQLWebSocket::new(socket, schema, protocol).serve());
149///
150///         warp::reply::with_header(
151///             reply,
152///             "Sec-WebSocket-Protocol",
153///             protocol.sec_websocket_protocol(),
154///         )
155///     });
156///
157/// warp::serve(filter).run(([0, 0, 0, 0], 8000)).await;
158/// # });
159/// ```
160pub struct GraphQLWebSocket<Sink, Stream, E, OnInit, OnPing> {
161    sink: Sink,
162    stream: Stream,
163    protocol: WebSocketProtocols,
164    executor: E,
165    data: Data,
166    on_init: OnInit,
167    on_ping: OnPing,
168    keepalive_timeout: Option<Duration>,
169}
170
171impl<S, E>
172    GraphQLWebSocket<
173        SplitSink<S, Message>,
174        SplitStream<S>,
175        E,
176        DefaultOnConnInitType,
177        DefaultOnPingType,
178    >
179where
180    S: Stream<Item = Result<Message, Error>> + Sink<Message>,
181    E: Executor,
182{
183    /// Create a [`GraphQLWebSocket`] object.
184    pub fn new(socket: S, executor: E, protocol: WebSocketProtocols) -> Self {
185        let (sink, stream) = socket.split();
186        GraphQLWebSocket::new_with_pair(sink, stream, executor, protocol)
187    }
188}
189
190impl<Sink, Stream, E> GraphQLWebSocket<Sink, Stream, E, DefaultOnConnInitType, DefaultOnPingType>
191where
192    Sink: futures_util::sink::Sink<Message>,
193    Stream: futures_util::stream::Stream<Item = Result<Message, Error>>,
194    E: Executor,
195{
196    /// Create a [`GraphQLWebSocket`] object with sink and stream objects.
197    pub fn new_with_pair(
198        sink: Sink,
199        stream: Stream,
200        executor: E,
201        protocol: WebSocketProtocols,
202    ) -> Self {
203        GraphQLWebSocket {
204            sink,
205            stream,
206            protocol,
207            executor,
208            data: Data::default(),
209            on_init: default_on_connection_init,
210            on_ping: default_on_ping,
211            keepalive_timeout: None,
212        }
213    }
214}
215
216impl<Sink, Stream, E, OnConnInit, OnConnInitFut, OnPing, OnPingFut>
217    GraphQLWebSocket<Sink, Stream, E, OnConnInit, OnPing>
218where
219    Sink: futures_util::sink::Sink<Message>,
220    Stream: futures_util::stream::Stream<Item = Result<Message, Error>>,
221    E: Executor,
222    OnConnInit: FnOnce(serde_json::Value) -> OnConnInitFut + Send + 'static,
223    OnConnInitFut: Future<Output = async_graphql::Result<Data>> + Send + 'static,
224    OnPing: FnOnce(Option<&Data>, Option<serde_json::Value>) -> OnPingFut + Clone + Send + 'static,
225    OnPingFut: Future<Output = async_graphql::Result<Option<serde_json::Value>>> + Send + 'static,
226{
227    /// Specify the initial subscription context data, usually you can get
228    /// something from the incoming request to create it.
229    #[must_use]
230    pub fn with_data(self, data: Data) -> Self {
231        Self { data, ..self }
232    }
233
234    /// Specify a callback function to be called when the connection is
235    /// initialized.
236    ///
237    /// You can get something from the payload of [`GQL_CONNECTION_INIT` message](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init) to create [`Data`].
238    /// The data returned by this callback function will be merged with the data
239    /// specified by [`with_data`].
240    #[must_use]
241    pub fn on_connection_init<F, R>(
242        self,
243        callback: F,
244    ) -> GraphQLWebSocket<Sink, Stream, E, F, OnPing>
245    where
246        F: FnOnce(serde_json::Value) -> R + Send + 'static,
247        R: Future<Output = async_graphql::Result<Data>> + Send + 'static,
248    {
249        GraphQLWebSocket {
250            sink: self.sink,
251            stream: self.stream,
252            executor: self.executor,
253            data: self.data,
254            on_init: callback,
255            on_ping: self.on_ping,
256            protocol: self.protocol,
257            keepalive_timeout: self.keepalive_timeout,
258        }
259    }
260
261    /// Specify a ping callback function.
262    ///
263    /// This function if present, will be called with the data sent by the
264    /// client in the [`Ping` message](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#ping).
265    ///
266    /// The function should return the data to be sent in the [`Pong` message](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#pong).
267    ///
268    /// NOTE: Only used for the `graphql-ws` protocol.
269    #[must_use]
270    pub fn on_ping<F, R>(self, callback: F) -> GraphQLWebSocket<Sink, Stream, E, OnConnInit, F>
271    where
272        F: FnOnce(Option<&Data>, Option<serde_json::Value>) -> R + Send + Clone + 'static,
273        R: Future<Output = Result<Option<serde_json::Value>>> + Send + 'static,
274    {
275        GraphQLWebSocket {
276            sink: self.sink,
277            stream: self.stream,
278            executor: self.executor,
279            data: self.data,
280            on_init: self.on_init,
281            on_ping: callback,
282            protocol: self.protocol,
283            keepalive_timeout: self.keepalive_timeout,
284        }
285    }
286
287    /// Sets a timeout for receiving an acknowledgement of the keep-alive ping.
288    ///
289    /// If the ping is not acknowledged within the timeout, the connection will
290    /// be closed.
291    ///
292    /// NOTE: Only used for the `graphql-ws` protocol.
293    #[must_use]
294    pub fn keepalive_timeout(self, timeout: impl Into<Option<Duration>>) -> Self {
295        Self {
296            keepalive_timeout: timeout.into(),
297            ..self
298        }
299    }
300
301    /// Processing subscription requests.
302    pub async fn serve(self) {
303        let stream = self
304            .stream
305            .take_while(|msg| future::ready(msg.is_ok()))
306            .map(Result::unwrap)
307            .filter(|msg| future::ready(msg.is_text() || msg.is_binary()))
308            .map(ws::Message::into_bytes);
309
310        let _ = async_graphql::http::WebSocket::new(self.executor.clone(), stream, self.protocol)
311            .connection_data(self.data)
312            .on_connection_init(self.on_init)
313            .on_ping(self.on_ping)
314            .keepalive_timeout(self.keepalive_timeout)
315            .map(|msg| match msg {
316                WsMessage::Text(text) => ws::Message::text(text),
317                WsMessage::Close(code, status) => ws::Message::close_with(code, status),
318            })
319            .map(Ok)
320            .forward(self.sink)
321            .await;
322    }
323}