async_graphql_warp/subscription.rs
1use std::{future::Future, str::FromStr, time::Duration};
2
3use async_graphql::{
4 Data, Executor, Result,
5 http::{
6 DefaultOnConnInitType, DefaultOnPingType, WebSocketProtocols, WsMessage,
7 default_on_connection_init, default_on_ping,
8 },
9};
10use futures_util::{
11 Sink, Stream, StreamExt, future,
12 stream::{SplitSink, SplitStream},
13};
14use warp::{Error, Filter, Rejection, Reply, filters::ws, ws::Message};
15
16/// GraphQL subscription filter
17///
18/// # Examples
19///
20/// ```no_run
21/// use std::time::Duration;
22///
23/// use async_graphql::*;
24/// use async_graphql_warp::*;
25/// use futures_util::stream::{Stream, StreamExt};
26/// use warp::Filter;
27///
28/// struct QueryRoot;
29///
30/// #[Object]
31/// impl QueryRoot {
32/// async fn value(&self) -> i32 {
33/// // A GraphQL Object type must define one or more fields.
34/// 100
35/// }
36/// }
37///
38/// struct SubscriptionRoot;
39///
40/// #[Subscription]
41/// impl SubscriptionRoot {
42/// async fn tick(&self) -> impl Stream<Item = String> {
43/// asynk_strim::stream_fn(|mut yielder| async move {
44/// let mut interval = tokio::time::interval(Duration::from_secs(1));
45/// loop {
46/// let n = interval.tick().await;
47/// yielder
48/// .yield_item(format!("{}", n.elapsed().as_secs_f32()))
49/// .await;
50/// }
51/// })
52/// }
53/// }
54///
55/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
56/// let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
57/// let filter =
58/// async_graphql_warp::graphql_subscription(schema).or(warp::any().map(|| "Hello, World!"));
59/// warp::serve(filter).run(([0, 0, 0, 0], 8000)).await;
60/// # });
61/// ```
62pub fn graphql_subscription<E>(
63 executor: E,
64) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone
65where
66 E: Executor,
67{
68 warp::ws()
69 .and(graphql_protocol())
70 .map(move |ws: ws::Ws, protocol| {
71 let executor = executor.clone();
72
73 let reply = ws.on_upgrade(move |socket| {
74 GraphQLWebSocket::new(socket, executor, protocol).serve()
75 });
76
77 warp::reply::with_header(
78 reply,
79 "Sec-WebSocket-Protocol",
80 protocol.sec_websocket_protocol(),
81 )
82 })
83}
84
85/// Create a `Filter` that parse [WebSocketProtocols] from
86/// `sec-websocket-protocol` header.
87pub fn graphql_protocol() -> impl Filter<Extract = (WebSocketProtocols,), Error = Rejection> + Clone
88{
89 warp::header::optional::<String>("sec-websocket-protocol").map(|protocols: Option<String>| {
90 protocols
91 .and_then(|protocols| {
92 protocols
93 .split(',')
94 .find_map(|p| WebSocketProtocols::from_str(p.trim()).ok())
95 })
96 .unwrap_or(WebSocketProtocols::SubscriptionsTransportWS)
97 })
98}
99
100/// A Websocket connection for GraphQL subscription.
101///
102/// # Examples
103///
104/// ```no_run
105/// use std::time::Duration;
106///
107/// use async_graphql::*;
108/// use async_graphql_warp::*;
109/// use futures_util::stream::{Stream, StreamExt};
110/// use warp::{Filter, ws};
111///
112/// struct QueryRoot;
113///
114/// #[Object]
115/// impl QueryRoot {
116/// async fn value(&self) -> i32 {
117/// // A GraphQL Object type must define one or more fields.
118/// 100
119/// }
120/// }
121///
122/// struct SubscriptionRoot;
123///
124/// #[Subscription]
125/// impl SubscriptionRoot {
126/// async fn tick(&self) -> impl Stream<Item = String> {
127/// asynk_strim::stream_fn(|mut yielder| async move {
128/// let mut interval = tokio::time::interval(Duration::from_secs(1));
129/// loop {
130/// let n = interval.tick().await;
131/// yielder
132/// .yield_item(format!("{}", n.elapsed().as_secs_f32()))
133/// .await;
134/// }
135/// })
136/// }
137/// }
138///
139/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
140/// let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
141///
142/// let filter = warp::ws()
143/// .and(graphql_protocol())
144/// .map(move |ws: ws::Ws, protocol| {
145/// let schema = schema.clone();
146///
147/// let reply = ws
148/// .on_upgrade(move |socket| GraphQLWebSocket::new(socket, schema, protocol).serve());
149///
150/// warp::reply::with_header(
151/// reply,
152/// "Sec-WebSocket-Protocol",
153/// protocol.sec_websocket_protocol(),
154/// )
155/// });
156///
157/// warp::serve(filter).run(([0, 0, 0, 0], 8000)).await;
158/// # });
159/// ```
160pub struct GraphQLWebSocket<Sink, Stream, E, OnInit, OnPing> {
161 sink: Sink,
162 stream: Stream,
163 protocol: WebSocketProtocols,
164 executor: E,
165 data: Data,
166 on_init: OnInit,
167 on_ping: OnPing,
168 keepalive_timeout: Option<Duration>,
169}
170
171impl<S, E>
172 GraphQLWebSocket<
173 SplitSink<S, Message>,
174 SplitStream<S>,
175 E,
176 DefaultOnConnInitType,
177 DefaultOnPingType,
178 >
179where
180 S: Stream<Item = Result<Message, Error>> + Sink<Message>,
181 E: Executor,
182{
183 /// Create a [`GraphQLWebSocket`] object.
184 pub fn new(socket: S, executor: E, protocol: WebSocketProtocols) -> Self {
185 let (sink, stream) = socket.split();
186 GraphQLWebSocket::new_with_pair(sink, stream, executor, protocol)
187 }
188}
189
190impl<Sink, Stream, E> GraphQLWebSocket<Sink, Stream, E, DefaultOnConnInitType, DefaultOnPingType>
191where
192 Sink: futures_util::sink::Sink<Message>,
193 Stream: futures_util::stream::Stream<Item = Result<Message, Error>>,
194 E: Executor,
195{
196 /// Create a [`GraphQLWebSocket`] object with sink and stream objects.
197 pub fn new_with_pair(
198 sink: Sink,
199 stream: Stream,
200 executor: E,
201 protocol: WebSocketProtocols,
202 ) -> Self {
203 GraphQLWebSocket {
204 sink,
205 stream,
206 protocol,
207 executor,
208 data: Data::default(),
209 on_init: default_on_connection_init,
210 on_ping: default_on_ping,
211 keepalive_timeout: None,
212 }
213 }
214}
215
216impl<Sink, Stream, E, OnConnInit, OnConnInitFut, OnPing, OnPingFut>
217 GraphQLWebSocket<Sink, Stream, E, OnConnInit, OnPing>
218where
219 Sink: futures_util::sink::Sink<Message>,
220 Stream: futures_util::stream::Stream<Item = Result<Message, Error>>,
221 E: Executor,
222 OnConnInit: FnOnce(serde_json::Value) -> OnConnInitFut + Send + 'static,
223 OnConnInitFut: Future<Output = async_graphql::Result<Data>> + Send + 'static,
224 OnPing: FnOnce(Option<&Data>, Option<serde_json::Value>) -> OnPingFut + Clone + Send + 'static,
225 OnPingFut: Future<Output = async_graphql::Result<Option<serde_json::Value>>> + Send + 'static,
226{
227 /// Specify the initial subscription context data, usually you can get
228 /// something from the incoming request to create it.
229 #[must_use]
230 pub fn with_data(self, data: Data) -> Self {
231 Self { data, ..self }
232 }
233
234 /// Specify a callback function to be called when the connection is
235 /// initialized.
236 ///
237 /// You can get something from the payload of [`GQL_CONNECTION_INIT` message](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init) to create [`Data`].
238 /// The data returned by this callback function will be merged with the data
239 /// specified by [`with_data`].
240 #[must_use]
241 pub fn on_connection_init<F, R>(
242 self,
243 callback: F,
244 ) -> GraphQLWebSocket<Sink, Stream, E, F, OnPing>
245 where
246 F: FnOnce(serde_json::Value) -> R + Send + 'static,
247 R: Future<Output = async_graphql::Result<Data>> + Send + 'static,
248 {
249 GraphQLWebSocket {
250 sink: self.sink,
251 stream: self.stream,
252 executor: self.executor,
253 data: self.data,
254 on_init: callback,
255 on_ping: self.on_ping,
256 protocol: self.protocol,
257 keepalive_timeout: self.keepalive_timeout,
258 }
259 }
260
261 /// Specify a ping callback function.
262 ///
263 /// This function if present, will be called with the data sent by the
264 /// client in the [`Ping` message](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#ping).
265 ///
266 /// The function should return the data to be sent in the [`Pong` message](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#pong).
267 ///
268 /// NOTE: Only used for the `graphql-ws` protocol.
269 #[must_use]
270 pub fn on_ping<F, R>(self, callback: F) -> GraphQLWebSocket<Sink, Stream, E, OnConnInit, F>
271 where
272 F: FnOnce(Option<&Data>, Option<serde_json::Value>) -> R + Send + Clone + 'static,
273 R: Future<Output = Result<Option<serde_json::Value>>> + Send + 'static,
274 {
275 GraphQLWebSocket {
276 sink: self.sink,
277 stream: self.stream,
278 executor: self.executor,
279 data: self.data,
280 on_init: self.on_init,
281 on_ping: callback,
282 protocol: self.protocol,
283 keepalive_timeout: self.keepalive_timeout,
284 }
285 }
286
287 /// Sets a timeout for receiving an acknowledgement of the keep-alive ping.
288 ///
289 /// If the ping is not acknowledged within the timeout, the connection will
290 /// be closed.
291 ///
292 /// NOTE: Only used for the `graphql-ws` protocol.
293 #[must_use]
294 pub fn keepalive_timeout(self, timeout: impl Into<Option<Duration>>) -> Self {
295 Self {
296 keepalive_timeout: timeout.into(),
297 ..self
298 }
299 }
300
301 /// Processing subscription requests.
302 pub async fn serve(self) {
303 let stream = self
304 .stream
305 .take_while(|msg| future::ready(msg.is_ok()))
306 .map(Result::unwrap)
307 .filter(|msg| future::ready(msg.is_text() || msg.is_binary()))
308 .map(ws::Message::into_bytes);
309
310 let _ = async_graphql::http::WebSocket::new(self.executor.clone(), stream, self.protocol)
311 .connection_data(self.data)
312 .on_connection_init(self.on_init)
313 .on_ping(self.on_ping)
314 .keepalive_timeout(self.keepalive_timeout)
315 .map(|msg| match msg {
316 WsMessage::Text(text) => ws::Message::text(text),
317 WsMessage::Close(code, status) => ws::Message::close_with(code, status),
318 })
319 .map(Ok)
320 .forward(self.sink)
321 .await;
322 }
323}