async_graphql_poem/
subscription.rs

1use std::{io::Error as IoError, str::FromStr, time::Duration};
2
3use async_graphql::{
4    Data, Executor,
5    http::{
6        ALL_WEBSOCKET_PROTOCOLS, DefaultOnConnInitType, DefaultOnPingType, WebSocketProtocols,
7        WsMessage, default_on_connection_init, default_on_ping,
8    },
9};
10use futures_util::{
11    Future, Sink, SinkExt, Stream, StreamExt,
12    future::{self},
13    stream::{SplitSink, SplitStream},
14};
15use poem::{
16    Endpoint, Error, FromRequest, IntoResponse, Request, RequestBody, Response, Result,
17    http::StatusCode,
18    web::websocket::{Message, WebSocket},
19};
20
21/// A GraphQL protocol extractor.
22///
23/// It extract GraphQL protocol from `SEC_WEBSOCKET_PROTOCOL` header.
24#[derive(Debug, Copy, Clone, PartialEq, Eq)]
25pub struct GraphQLProtocol(pub WebSocketProtocols);
26
27impl<'a> FromRequest<'a> for GraphQLProtocol {
28    async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result<Self> {
29        req.headers()
30            .get(http::header::SEC_WEBSOCKET_PROTOCOL)
31            .and_then(|value| value.to_str().ok())
32            .and_then(|protocols| {
33                protocols
34                    .split(',')
35                    .find_map(|p| WebSocketProtocols::from_str(p.trim()).ok())
36            })
37            .map(Self)
38            .ok_or_else(|| Error::from_status(StatusCode::BAD_REQUEST))
39    }
40}
41
/// An endpoint that serves GraphQL subscriptions.
///
/// The endpoint owns an executor of type `E`, which is handed to each
/// connection it serves.
///
/// # Example
///
/// ```
/// use async_graphql::{EmptyMutation, Object, Schema, Subscription};
/// use async_graphql_poem::GraphQLSubscription;
/// use futures_util::{Stream, stream};
/// use poem::{Route, get};
///
/// struct Query;
///
/// #[Object]
/// impl Query {
///     async fn value(&self) -> i32 {
///         100
///     }
/// }
///
/// struct Subscription;
///
/// #[Subscription]
/// impl Subscription {
///     async fn values(&self) -> impl Stream<Item = i32> {
///         stream::iter(vec![1, 2, 3, 4, 5])
///     }
/// }
///
/// type MySchema = Schema<Query, EmptyMutation, Subscription>;
///
/// let schema = Schema::new(Query, EmptyMutation, Subscription);
/// let app = Route::new().at("/ws", get(GraphQLSubscription::new(schema)));
/// ```
pub struct GraphQLSubscription<E> {
    /// Executor used to run the subscriptions of every accepted connection.
    executor: E,
}

impl<E> GraphQLSubscription<E> {
    /// Builds a subscription endpoint that wraps the given `executor`.
    pub fn new(executor: E) -> Self {
        GraphQLSubscription { executor }
    }
}
85
86impl<E> Endpoint for GraphQLSubscription<E>
87where
88    E: Executor,
89{
90    type Output = Response;
91
92    async fn call(&self, req: Request) -> Result<Self::Output> {
93        let (req, mut body) = req.split();
94        let websocket = WebSocket::from_request(&req, &mut body).await?;
95        let protocol = GraphQLProtocol::from_request(&req, &mut body).await?;
96        let executor = self.executor.clone();
97
98        let resp = websocket
99            .protocols(ALL_WEBSOCKET_PROTOCOLS)
100            .on_upgrade(move |stream| GraphQLWebSocket::new(stream, executor, protocol).serve())
101            .into_response();
102        Ok(resp)
103    }
104}
105
106/// A Websocket connection for GraphQL subscription.
107pub struct GraphQLWebSocket<Sink, Stream, E, OnConnInit, OnPing> {
108    sink: Sink,
109    stream: Stream,
110    executor: E,
111    data: Data,
112    on_connection_init: OnConnInit,
113    on_ping: OnPing,
114    protocol: GraphQLProtocol,
115    keepalive_timeout: Option<Duration>,
116}
117
118impl<S, E>
119    GraphQLWebSocket<
120        SplitSink<S, Message>,
121        SplitStream<S>,
122        E,
123        DefaultOnConnInitType,
124        DefaultOnPingType,
125    >
126where
127    S: Stream<Item = Result<Message, IoError>> + Sink<Message>,
128    E: Executor,
129{
130    /// Create a [`GraphQLWebSocket`] object.
131    pub fn new(stream: S, executor: E, protocol: GraphQLProtocol) -> Self {
132        let (sink, stream) = stream.split();
133        GraphQLWebSocket::new_with_pair(sink, stream, executor, protocol)
134    }
135}
136
137impl<Sink, Stream, E> GraphQLWebSocket<Sink, Stream, E, DefaultOnConnInitType, DefaultOnPingType>
138where
139    Sink: futures_util::sink::Sink<Message>,
140    Stream: futures_util::stream::Stream<Item = Result<Message, IoError>>,
141    E: Executor,
142{
143    /// Create a [`GraphQLWebSocket`] object with sink and stream objects.
144    pub fn new_with_pair(
145        sink: Sink,
146        stream: Stream,
147        executor: E,
148        protocol: GraphQLProtocol,
149    ) -> Self {
150        GraphQLWebSocket {
151            sink,
152            stream,
153            executor,
154            data: Data::default(),
155            on_connection_init: default_on_connection_init,
156            on_ping: default_on_ping,
157            protocol,
158            keepalive_timeout: None,
159        }
160    }
161}
162
163impl<Sink, Stream, E, OnConnInit, OnConnInitFut, OnPing, OnPingFut>
164    GraphQLWebSocket<Sink, Stream, E, OnConnInit, OnPing>
165where
166    Sink: futures_util::sink::Sink<Message>,
167    Stream: futures_util::stream::Stream<Item = Result<Message, IoError>>,
168    E: Executor,
169    OnConnInit: FnOnce(serde_json::Value) -> OnConnInitFut + Send + 'static,
170    OnConnInitFut: Future<Output = async_graphql::Result<Data>> + Send + 'static,
171    OnPing: FnOnce(Option<&Data>, Option<serde_json::Value>) -> OnPingFut + Clone + Send + 'static,
172    OnPingFut: Future<Output = async_graphql::Result<Option<serde_json::Value>>> + Send + 'static,
173{
174    /// Specify the initial subscription context data, usually you can get
175    /// something from the incoming request to create it.
176    #[must_use]
177    pub fn with_data(self, data: Data) -> Self {
178        Self { data, ..self }
179    }
180
181    /// Specify a callback function to be called when the connection is
182    /// initialized.
183    ///
184    /// You can get something from the payload of [`GQL_CONNECTION_INIT` message](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init) to create [`Data`].
185    /// The data returned by this callback function will be merged with the data
186    /// specified by [`with_data`].
187    #[must_use]
188    pub fn on_connection_init<F, R>(
189        self,
190        callback: F,
191    ) -> GraphQLWebSocket<Sink, Stream, E, F, OnPing>
192    where
193        F: FnOnce(serde_json::Value) -> R + Send + 'static,
194        R: Future<Output = async_graphql::Result<Data>> + Send + 'static,
195    {
196        GraphQLWebSocket {
197            sink: self.sink,
198            stream: self.stream,
199            executor: self.executor,
200            data: self.data,
201            on_connection_init: callback,
202            on_ping: self.on_ping,
203            protocol: self.protocol,
204            keepalive_timeout: self.keepalive_timeout,
205        }
206    }
207
208    /// Specify a ping callback function.
209    ///
210    /// This function if present, will be called with the data sent by the
211    /// client in the [`Ping` message](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#ping).
212    ///
213    /// The function should return the data to be sent in the [`Pong` message](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#pong).
214    ///
215    /// NOTE: Only used for the `graphql-ws` protocol.
216    #[must_use]
217    pub fn on_ping<F, R>(self, callback: F) -> GraphQLWebSocket<Sink, Stream, E, OnConnInit, F>
218    where
219        F: FnOnce(Option<&Data>, Option<serde_json::Value>) -> R + Clone + Send + 'static,
220        R: Future<Output = async_graphql::Result<Option<serde_json::Value>>> + Send + 'static,
221    {
222        GraphQLWebSocket {
223            sink: self.sink,
224            stream: self.stream,
225            executor: self.executor,
226            data: self.data,
227            on_connection_init: self.on_connection_init,
228            on_ping: callback,
229            protocol: self.protocol,
230            keepalive_timeout: self.keepalive_timeout,
231        }
232    }
233
234    /// Sets a timeout for receiving an acknowledgement of the keep-alive ping.
235    ///
236    /// If the ping is not acknowledged within the timeout, the connection will
237    /// be closed.
238    ///
239    /// NOTE: Only used for the `graphql-ws` protocol.
240    #[must_use]
241    pub fn keepalive_timeout(self, timeout: impl Into<Option<Duration>>) -> Self {
242        Self {
243            keepalive_timeout: timeout.into(),
244            ..self
245        }
246    }
247
248    /// Processing subscription requests.
249    pub async fn serve(self) {
250        let stream = self
251            .stream
252            .take_while(|res| future::ready(res.is_ok()))
253            .map(Result::unwrap)
254            .filter_map(|msg| {
255                if msg.is_text() || msg.is_binary() {
256                    future::ready(Some(msg))
257                } else {
258                    future::ready(None)
259                }
260            })
261            .map(Message::into_bytes);
262
263        let stream =
264            async_graphql::http::WebSocket::new(self.executor.clone(), stream, self.protocol.0)
265                .connection_data(self.data)
266                .on_connection_init(self.on_connection_init)
267                .on_ping(self.on_ping.clone())
268                .keepalive_timeout(self.keepalive_timeout)
269                .map(|msg| match msg {
270                    WsMessage::Text(text) => Message::text(text),
271                    WsMessage::Close(code, status) => Message::close_with(code, status),
272                });
273
274        let sink = self.sink;
275        futures_util::pin_mut!(stream, sink);
276
277        while let Some(item) = stream.next().await {
278            if sink.send(item).await.is_err() {
279                break;
280            }
281        }
282    }
283}