async_graphql_warp/subscription.rs
1use std::{future::Future, str::FromStr, time::Duration};
2
3use async_graphql::{
4 Data, Executor, Result,
5 http::{
6 DefaultOnConnInitType, DefaultOnPingType, WebSocketProtocols, WsMessage,
7 default_on_connection_init, default_on_ping,
8 },
9};
10use futures_util::{
11 Sink, Stream, StreamExt, future,
12 stream::{SplitSink, SplitStream},
13};
14use warp::{Error, Filter, Rejection, Reply, filters::ws, ws::Message};
15
16/// GraphQL subscription filter
17///
18/// # Examples
19///
20/// ```no_run
21/// use std::time::Duration;
22///
23/// use async_graphql::*;
24/// use async_graphql_warp::*;
25/// use futures_util::stream::{Stream, StreamExt};
26/// use warp::Filter;
27///
28/// struct QueryRoot;
29///
30/// #[Object]
31/// impl QueryRoot {
32/// async fn value(&self) -> i32 {
33/// // A GraphQL Object type must define one or more fields.
34/// 100
35/// }
36/// }
37///
38/// struct SubscriptionRoot;
39///
40/// #[Subscription]
41/// impl SubscriptionRoot {
42/// async fn tick(&self) -> impl Stream<Item = String> {
43/// async_stream::stream! {
44/// let mut interval = tokio::time::interval(Duration::from_secs(1));
45/// loop {
46/// let n = interval.tick().await;
47/// yield format!("{}", n.elapsed().as_secs_f32());
48/// }
49/// }
50/// }
51/// }
52///
53/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
54/// let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
55/// let filter =
56/// async_graphql_warp::graphql_subscription(schema).or(warp::any().map(|| "Hello, World!"));
57/// warp::serve(filter).run(([0, 0, 0, 0], 8000)).await;
58/// # });
59/// ```
60pub fn graphql_subscription<E>(
61 executor: E,
62) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone
63where
64 E: Executor,
65{
66 warp::ws()
67 .and(graphql_protocol())
68 .map(move |ws: ws::Ws, protocol| {
69 let executor = executor.clone();
70
71 let reply = ws.on_upgrade(move |socket| {
72 GraphQLWebSocket::new(socket, executor, protocol).serve()
73 });
74
75 warp::reply::with_header(
76 reply,
77 "Sec-WebSocket-Protocol",
78 protocol.sec_websocket_protocol(),
79 )
80 })
81}
82
83/// Create a `Filter` that parse [WebSocketProtocols] from
84/// `sec-websocket-protocol` header.
85pub fn graphql_protocol() -> impl Filter<Extract = (WebSocketProtocols,), Error = Rejection> + Clone
86{
87 warp::header::optional::<String>("sec-websocket-protocol").map(|protocols: Option<String>| {
88 protocols
89 .and_then(|protocols| {
90 protocols
91 .split(',')
92 .find_map(|p| WebSocketProtocols::from_str(p.trim()).ok())
93 })
94 .unwrap_or(WebSocketProtocols::SubscriptionsTransportWS)
95 })
96}
97
98/// A Websocket connection for GraphQL subscription.
99///
100/// # Examples
101///
102/// ```no_run
103/// use std::time::Duration;
104///
105/// use async_graphql::*;
106/// use async_graphql_warp::*;
107/// use futures_util::stream::{Stream, StreamExt};
108/// use warp::{Filter, ws};
109///
110/// struct QueryRoot;
111///
112/// #[Object]
113/// impl QueryRoot {
114/// async fn value(&self) -> i32 {
115/// // A GraphQL Object type must define one or more fields.
116/// 100
117/// }
118/// }
119///
120/// struct SubscriptionRoot;
121///
122/// #[Subscription]
123/// impl SubscriptionRoot {
124/// async fn tick(&self) -> impl Stream<Item = String> {
125/// async_stream::stream! {
126/// let mut interval = tokio::time::interval(Duration::from_secs(1));
127/// loop {
128/// let n = interval.tick().await;
129/// yield format!("{}", n.elapsed().as_secs_f32());
130/// }
131/// }
132/// }
133/// }
134///
135/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
136/// let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
137///
138/// let filter = warp::ws()
139/// .and(graphql_protocol())
140/// .map(move |ws: ws::Ws, protocol| {
141/// let schema = schema.clone();
142///
143/// let reply = ws
144/// .on_upgrade(move |socket| GraphQLWebSocket::new(socket, schema, protocol).serve());
145///
146/// warp::reply::with_header(
147/// reply,
148/// "Sec-WebSocket-Protocol",
149/// protocol.sec_websocket_protocol(),
150/// )
151/// });
152///
153/// warp::serve(filter).run(([0, 0, 0, 0], 8000)).await;
154/// # });
155/// ```
156pub struct GraphQLWebSocket<Sink, Stream, E, OnInit, OnPing> {
157 sink: Sink,
158 stream: Stream,
159 protocol: WebSocketProtocols,
160 executor: E,
161 data: Data,
162 on_init: OnInit,
163 on_ping: OnPing,
164 keepalive_timeout: Option<Duration>,
165}
166
167impl<S, E>
168 GraphQLWebSocket<
169 SplitSink<S, Message>,
170 SplitStream<S>,
171 E,
172 DefaultOnConnInitType,
173 DefaultOnPingType,
174 >
175where
176 S: Stream<Item = Result<Message, Error>> + Sink<Message>,
177 E: Executor,
178{
179 /// Create a [`GraphQLWebSocket`] object.
180 pub fn new(socket: S, executor: E, protocol: WebSocketProtocols) -> Self {
181 let (sink, stream) = socket.split();
182 GraphQLWebSocket::new_with_pair(sink, stream, executor, protocol)
183 }
184}
185
186impl<Sink, Stream, E> GraphQLWebSocket<Sink, Stream, E, DefaultOnConnInitType, DefaultOnPingType>
187where
188 Sink: futures_util::sink::Sink<Message>,
189 Stream: futures_util::stream::Stream<Item = Result<Message, Error>>,
190 E: Executor,
191{
192 /// Create a [`GraphQLWebSocket`] object with sink and stream objects.
193 pub fn new_with_pair(
194 sink: Sink,
195 stream: Stream,
196 executor: E,
197 protocol: WebSocketProtocols,
198 ) -> Self {
199 GraphQLWebSocket {
200 sink,
201 stream,
202 protocol,
203 executor,
204 data: Data::default(),
205 on_init: default_on_connection_init,
206 on_ping: default_on_ping,
207 keepalive_timeout: None,
208 }
209 }
210}
211
212impl<Sink, Stream, E, OnConnInit, OnConnInitFut, OnPing, OnPingFut>
213 GraphQLWebSocket<Sink, Stream, E, OnConnInit, OnPing>
214where
215 Sink: futures_util::sink::Sink<Message>,
216 Stream: futures_util::stream::Stream<Item = Result<Message, Error>>,
217 E: Executor,
218 OnConnInit: FnOnce(serde_json::Value) -> OnConnInitFut + Send + 'static,
219 OnConnInitFut: Future<Output = async_graphql::Result<Data>> + Send + 'static,
220 OnPing: FnOnce(Option<&Data>, Option<serde_json::Value>) -> OnPingFut + Clone + Send + 'static,
221 OnPingFut: Future<Output = async_graphql::Result<Option<serde_json::Value>>> + Send + 'static,
222{
223 /// Specify the initial subscription context data, usually you can get
224 /// something from the incoming request to create it.
225 #[must_use]
226 pub fn with_data(self, data: Data) -> Self {
227 Self { data, ..self }
228 }
229
230 /// Specify a callback function to be called when the connection is
231 /// initialized.
232 ///
233 /// You can get something from the payload of [`GQL_CONNECTION_INIT` message](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init) to create [`Data`].
234 /// The data returned by this callback function will be merged with the data
235 /// specified by [`with_data`].
236 #[must_use]
237 pub fn on_connection_init<F, R>(
238 self,
239 callback: F,
240 ) -> GraphQLWebSocket<Sink, Stream, E, F, OnPing>
241 where
242 F: FnOnce(serde_json::Value) -> R + Send + 'static,
243 R: Future<Output = async_graphql::Result<Data>> + Send + 'static,
244 {
245 GraphQLWebSocket {
246 sink: self.sink,
247 stream: self.stream,
248 executor: self.executor,
249 data: self.data,
250 on_init: callback,
251 on_ping: self.on_ping,
252 protocol: self.protocol,
253 keepalive_timeout: self.keepalive_timeout,
254 }
255 }
256
257 /// Specify a ping callback function.
258 ///
259 /// This function if present, will be called with the data sent by the
260 /// client in the [`Ping` message](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#ping).
261 ///
262 /// The function should return the data to be sent in the [`Pong` message](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#pong).
263 ///
264 /// NOTE: Only used for the `graphql-ws` protocol.
265 #[must_use]
266 pub fn on_ping<F, R>(self, callback: F) -> GraphQLWebSocket<Sink, Stream, E, OnConnInit, F>
267 where
268 F: FnOnce(Option<&Data>, Option<serde_json::Value>) -> R + Send + Clone + 'static,
269 R: Future<Output = Result<Option<serde_json::Value>>> + Send + 'static,
270 {
271 GraphQLWebSocket {
272 sink: self.sink,
273 stream: self.stream,
274 executor: self.executor,
275 data: self.data,
276 on_init: self.on_init,
277 on_ping: callback,
278 protocol: self.protocol,
279 keepalive_timeout: self.keepalive_timeout,
280 }
281 }
282
283 /// Sets a timeout for receiving an acknowledgement of the keep-alive ping.
284 ///
285 /// If the ping is not acknowledged within the timeout, the connection will
286 /// be closed.
287 ///
288 /// NOTE: Only used for the `graphql-ws` protocol.
289 #[must_use]
290 pub fn keepalive_timeout(self, timeout: impl Into<Option<Duration>>) -> Self {
291 Self {
292 keepalive_timeout: timeout.into(),
293 ..self
294 }
295 }
296
297 /// Processing subscription requests.
298 pub async fn serve(self) {
299 let stream = self
300 .stream
301 .take_while(|msg| future::ready(msg.is_ok()))
302 .map(Result::unwrap)
303 .filter(|msg| future::ready(msg.is_text() || msg.is_binary()))
304 .map(ws::Message::into_bytes);
305
306 let _ = async_graphql::http::WebSocket::new(self.executor.clone(), stream, self.protocol)
307 .connection_data(self.data)
308 .on_connection_init(self.on_init)
309 .on_ping(self.on_ping)
310 .keepalive_timeout(self.keepalive_timeout)
311 .map(|msg| match msg {
312 WsMessage::Text(text) => ws::Message::text(text),
313 WsMessage::Close(code, status) => ws::Message::close_with(code, status),
314 })
315 .map(Ok)
316 .forward(self.sink)
317 .await;
318 }
319}