// async_graphql_axum/subscription.rs

1use std::{convert::Infallible, future::Future, str::FromStr, time::Duration};
2
3use async_graphql::{
4    futures_util::task::{Context, Poll},
5    http::{
6        default_on_connection_init, default_on_ping, DefaultOnConnInitType, DefaultOnPingType,
7        WebSocketProtocols, WsMessage, ALL_WEBSOCKET_PROTOCOLS,
8    },
9    Data, Executor, Result,
10};
11use axum::{
12    body::{Body, HttpBody},
13    extract::{
14        ws::{CloseFrame, Message},
15        FromRequestParts, WebSocketUpgrade,
16    },
17    http::{self, request::Parts, Request, Response, StatusCode},
18    response::IntoResponse,
19    Error,
20};
21use futures_util::{
22    future,
23    future::BoxFuture,
24    stream::{SplitSink, SplitStream},
25    Sink, SinkExt, Stream, StreamExt,
26};
27use tower_service::Service;
28
29/// A GraphQL protocol extractor.
30///
31/// It extract GraphQL protocol from `SEC_WEBSOCKET_PROTOCOL` header.
32#[derive(Debug, Copy, Clone, PartialEq, Eq)]
33pub struct GraphQLProtocol(WebSocketProtocols);
34
35impl<S> FromRequestParts<S> for GraphQLProtocol
36where
37    S: Send + Sync,
38{
39    type Rejection = StatusCode;
40
41    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
42        parts
43            .headers
44            .get(http::header::SEC_WEBSOCKET_PROTOCOL)
45            .and_then(|value| value.to_str().ok())
46            .and_then(|protocols| {
47                protocols
48                    .split(',')
49                    .find_map(|p| WebSocketProtocols::from_str(p.trim()).ok())
50            })
51            .map(Self)
52            .ok_or(StatusCode::BAD_REQUEST)
53    }
54}
55
/// A GraphQL subscription service.
///
/// Implements [`Service`] over HTTP requests: each request is validated as a
/// GraphQL WebSocket upgrade and the resulting connection is served with a
/// clone of the wrapped executor.
pub struct GraphQLSubscription<E> {
    /// Executor used to run the operations received on every connection.
    executor: E,
}
60
61impl<E> Clone for GraphQLSubscription<E>
62where
63    E: Executor,
64{
65    fn clone(&self) -> Self {
66        Self {
67            executor: self.executor.clone(),
68        }
69    }
70}
71
72impl<E> GraphQLSubscription<E>
73where
74    E: Executor,
75{
76    /// Create a GraphQL subscription service.
77    pub fn new(executor: E) -> Self {
78        Self { executor }
79    }
80}
81
82impl<B, E> Service<Request<B>> for GraphQLSubscription<E>
83where
84    B: HttpBody + Send + 'static,
85    E: Executor,
86{
87    type Response = Response<Body>;
88    type Error = Infallible;
89    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
90
91    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
92        Poll::Ready(Ok(()))
93    }
94
95    fn call(&mut self, req: Request<B>) -> Self::Future {
96        let executor = self.executor.clone();
97
98        Box::pin(async move {
99            let (mut parts, _body) = req.into_parts();
100
101            let protocol = match GraphQLProtocol::from_request_parts(&mut parts, &()).await {
102                Ok(protocol) => protocol,
103                Err(err) => return Ok(err.into_response()),
104            };
105            let upgrade = match WebSocketUpgrade::from_request_parts(&mut parts, &()).await {
106                Ok(protocol) => protocol,
107                Err(err) => return Ok(err.into_response()),
108            };
109
110            let executor = executor.clone();
111
112            let resp = upgrade
113                .protocols(ALL_WEBSOCKET_PROTOCOLS)
114                .on_upgrade(move |stream| {
115                    GraphQLWebSocket::new(stream, executor, protocol).serve()
116                });
117            Ok(resp.into_response())
118        })
119    }
120}
121
122/// A Websocket connection for GraphQL subscription.
123pub struct GraphQLWebSocket<Sink, Stream, E, OnConnInit, OnPing> {
124    sink: Sink,
125    stream: Stream,
126    executor: E,
127    data: Data,
128    on_connection_init: OnConnInit,
129    on_ping: OnPing,
130    protocol: GraphQLProtocol,
131    keepalive_timeout: Option<Duration>,
132}
133
134impl<S, E>
135    GraphQLWebSocket<
136        SplitSink<S, Message>,
137        SplitStream<S>,
138        E,
139        DefaultOnConnInitType,
140        DefaultOnPingType,
141    >
142where
143    S: Stream<Item = Result<Message, Error>> + Sink<Message>,
144    E: Executor,
145{
146    /// Create a [`GraphQLWebSocket`] object.
147    pub fn new(stream: S, executor: E, protocol: GraphQLProtocol) -> Self {
148        let (sink, stream) = stream.split();
149        GraphQLWebSocket::new_with_pair(sink, stream, executor, protocol)
150    }
151}
152
153impl<Sink, Stream, E> GraphQLWebSocket<Sink, Stream, E, DefaultOnConnInitType, DefaultOnPingType>
154where
155    Sink: futures_util::sink::Sink<Message>,
156    Stream: futures_util::stream::Stream<Item = Result<Message, Error>>,
157    E: Executor,
158{
159    /// Create a [`GraphQLWebSocket`] object with sink and stream objects.
160    pub fn new_with_pair(
161        sink: Sink,
162        stream: Stream,
163        executor: E,
164        protocol: GraphQLProtocol,
165    ) -> Self {
166        GraphQLWebSocket {
167            sink,
168            stream,
169            executor,
170            data: Data::default(),
171            on_connection_init: default_on_connection_init,
172            on_ping: default_on_ping,
173            protocol,
174            keepalive_timeout: None,
175        }
176    }
177}
178
179impl<Sink, Stream, E, OnConnInit, OnConnInitFut, OnPing, OnPingFut>
180    GraphQLWebSocket<Sink, Stream, E, OnConnInit, OnPing>
181where
182    Sink: futures_util::sink::Sink<Message>,
183    Stream: futures_util::stream::Stream<Item = Result<Message, Error>>,
184    E: Executor,
185    OnConnInit: FnOnce(serde_json::Value) -> OnConnInitFut + Send + 'static,
186    OnConnInitFut: Future<Output = async_graphql::Result<Data>> + Send + 'static,
187    OnPing: FnOnce(Option<&Data>, Option<serde_json::Value>) -> OnPingFut + Clone + Send + 'static,
188    OnPingFut: Future<Output = async_graphql::Result<Option<serde_json::Value>>> + Send + 'static,
189{
190    /// Specify the initial subscription context data, usually you can get
191    /// something from the incoming request to create it.
192    #[must_use]
193    pub fn with_data(self, data: Data) -> Self {
194        Self { data, ..self }
195    }
196
197    /// Specify a callback function to be called when the connection is
198    /// initialized.
199    ///
200    /// You can get something from the payload of [`GQL_CONNECTION_INIT` message](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init) to create [`Data`].
201    /// The data returned by this callback function will be merged with the data
202    /// specified by [`with_data`].
203    #[must_use]
204    pub fn on_connection_init<F, R>(
205        self,
206        callback: F,
207    ) -> GraphQLWebSocket<Sink, Stream, E, F, OnPing>
208    where
209        F: FnOnce(serde_json::Value) -> R + Send + 'static,
210        R: Future<Output = async_graphql::Result<Data>> + Send + 'static,
211    {
212        GraphQLWebSocket {
213            sink: self.sink,
214            stream: self.stream,
215            executor: self.executor,
216            data: self.data,
217            on_connection_init: callback,
218            on_ping: self.on_ping,
219            protocol: self.protocol,
220            keepalive_timeout: self.keepalive_timeout,
221        }
222    }
223
224    /// Specify a ping callback function.
225    ///
226    /// This function if present, will be called with the data sent by the
227    /// client in the [`Ping` message](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#ping).
228    ///
229    /// The function should return the data to be sent in the [`Pong` message](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#pong).
230    ///
231    /// NOTE: Only used for the `graphql-ws` protocol.
232    #[must_use]
233    pub fn on_ping<F, R>(self, callback: F) -> GraphQLWebSocket<Sink, Stream, E, OnConnInit, F>
234    where
235        F: FnOnce(Option<&Data>, Option<serde_json::Value>) -> R + Clone + Send + 'static,
236        R: Future<Output = Result<Option<serde_json::Value>>> + Send + 'static,
237    {
238        GraphQLWebSocket {
239            sink: self.sink,
240            stream: self.stream,
241            executor: self.executor,
242            data: self.data,
243            on_connection_init: self.on_connection_init,
244            on_ping: callback,
245            protocol: self.protocol,
246            keepalive_timeout: self.keepalive_timeout,
247        }
248    }
249
250    /// Sets a timeout for receiving an acknowledgement of the keep-alive ping.
251    ///
252    /// If the ping is not acknowledged within the timeout, the connection will
253    /// be closed.
254    ///
255    /// NOTE: Only used for the `graphql-ws` protocol.
256    #[must_use]
257    pub fn keepalive_timeout(self, timeout: impl Into<Option<Duration>>) -> Self {
258        Self {
259            keepalive_timeout: timeout.into(),
260            ..self
261        }
262    }
263
264    /// Processing subscription requests.
265    pub async fn serve(self) {
266        let input = self
267            .stream
268            .take_while(|res| future::ready(res.is_ok()))
269            .map(Result::unwrap)
270            .filter_map(|msg| {
271                if let Message::Text(_) | Message::Binary(_) = msg {
272                    future::ready(Some(msg))
273                } else {
274                    future::ready(None)
275                }
276            })
277            .map(Message::into_data);
278
279        let stream =
280            async_graphql::http::WebSocket::new(self.executor.clone(), input, self.protocol.0)
281                .connection_data(self.data)
282                .on_connection_init(self.on_connection_init)
283                .on_ping(self.on_ping.clone())
284                .keepalive_timeout(self.keepalive_timeout)
285                .map(|msg| match msg {
286                    WsMessage::Text(text) => Message::Text(text.into()),
287                    WsMessage::Close(code, status) => Message::Close(Some(CloseFrame {
288                        code,
289                        reason: status.into(),
290                    })),
291                });
292
293        let sink = self.sink;
294        futures_util::pin_mut!(stream, sink);
295
296        while let Some(item) = stream.next().await {
297            if sink.send(item).await.is_err() {
298                break;
299            }
300        }
301    }
302}