async_graphql_actix_web/
subscription.rs

1use std::{
2    future::Future,
3    str::FromStr,
4    time::{Duration, Instant},
5};
6
7use actix::{
8    Actor, ActorContext, ActorFutureExt, ActorStreamExt, AsyncContext, ContextFutureSpawner,
9    StreamHandler, WrapFuture, WrapStream,
10};
11use actix_http::{error::PayloadError, ws};
12use actix_web::{Error, HttpRequest, HttpResponse, web::Bytes};
13use actix_web_actors::ws::{CloseReason, Message, ProtocolError, WebsocketContext};
14use async_graphql::{
15    Data, Executor, Result,
16    http::{
17        ALL_WEBSOCKET_PROTOCOLS, DefaultOnConnInitType, DefaultOnPingType, WebSocket,
18        WebSocketProtocols, WsMessage, default_on_connection_init, default_on_ping,
19    },
20};
21use futures_util::stream::Stream;
22
/// Period of the websocket-level heartbeat: every tick the actor sends a ping
/// frame to the client (unless the client has already timed out).
const HEARTBEAT_INTERVAL: Duration = Duration::from_millis(5_000);
/// Maximum time allowed since the last ping or pong received from the client;
/// once exceeded, the heartbeat stops the actor.
const CLIENT_TIMEOUT: Duration = Duration::from_millis(10_000);
25
26#[derive(thiserror::Error, Debug)]
27#[error("failed to parse graphql protocol")]
28pub struct ParseGraphQLProtocolError;
29
30/// A builder for websocket subscription actor.
31pub struct GraphQLSubscription<E, OnInit, OnPing> {
32    executor: E,
33    data: Data,
34    on_connection_init: OnInit,
35    on_ping: OnPing,
36    keepalive_timeout: Option<Duration>,
37}
38
39impl<E> GraphQLSubscription<E, DefaultOnConnInitType, DefaultOnPingType> {
40    /// Create a GraphQL subscription builder.
41    pub fn new(executor: E) -> Self {
42        Self {
43            executor,
44            data: Default::default(),
45            on_connection_init: default_on_connection_init,
46            on_ping: default_on_ping,
47            keepalive_timeout: None,
48        }
49    }
50}
51
52impl<E, OnInit, OnInitFut, OnPing, OnPingFut> GraphQLSubscription<E, OnInit, OnPing>
53where
54    E: Executor,
55    OnInit: FnOnce(serde_json::Value) -> OnInitFut + Unpin + Send + 'static,
56    OnInitFut: Future<Output = async_graphql::Result<Data>> + Send + 'static,
57    OnPing: FnOnce(Option<&Data>, Option<serde_json::Value>) -> OnPingFut
58        + Clone
59        + Unpin
60        + Send
61        + 'static,
62    OnPingFut: Future<Output = async_graphql::Result<Option<serde_json::Value>>> + Send + 'static,
63{
64    /// Specify the initial subscription context data, usually you can get
65    /// something from the incoming request to create it.
66    #[must_use]
67    pub fn with_data(self, data: Data) -> Self {
68        Self { data, ..self }
69    }
70
71    /// Specify a callback function to be called when the connection is
72    /// initialized.
73    ///
74    /// You can get something from the payload of [`GQL_CONNECTION_INIT` message](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init) to create [`Data`].
75    /// The data returned by this callback function will be merged with the data
76    /// specified by [`with_data`].
77    #[must_use]
78    pub fn on_connection_init<F, R>(self, callback: F) -> GraphQLSubscription<E, F, OnPing>
79    where
80        F: FnOnce(serde_json::Value) -> R + Unpin + Send + 'static,
81        R: Future<Output = async_graphql::Result<Data>> + Send + 'static,
82    {
83        GraphQLSubscription {
84            executor: self.executor,
85            data: self.data,
86            on_connection_init: callback,
87            on_ping: self.on_ping,
88            keepalive_timeout: self.keepalive_timeout,
89        }
90    }
91
92    /// Specify a ping callback function.
93    ///
94    /// This function if present, will be called with the data sent by the
95    /// client in the [`Ping` message](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#ping).
96    ///
97    /// The function should return the data to be sent in the [`Pong` message](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#pong).
98    ///
99    /// NOTE: Only used for the `graphql-ws` protocol.
100    #[must_use]
101    pub fn on_ping<F, R>(self, callback: F) -> GraphQLSubscription<E, OnInit, F>
102    where
103        F: FnOnce(Option<&Data>, Option<serde_json::Value>) -> R + Send + Clone + 'static,
104        R: Future<Output = Result<Option<serde_json::Value>>> + Send + 'static,
105    {
106        GraphQLSubscription {
107            executor: self.executor,
108            data: self.data,
109            on_connection_init: self.on_connection_init,
110            on_ping: callback,
111            keepalive_timeout: self.keepalive_timeout,
112        }
113    }
114
115    /// Sets a timeout for receiving an acknowledgement of the keep-alive ping.
116    ///
117    /// If the ping is not acknowledged within the timeout, the connection will
118    /// be closed.
119    ///
120    /// NOTE: Only used for the `graphql-ws` protocol.
121    #[must_use]
122    pub fn keepalive_timeout(self, timeout: impl Into<Option<Duration>>) -> Self {
123        Self {
124            keepalive_timeout: timeout.into(),
125            ..self
126        }
127    }
128
129    /// Start the subscription actor.
130    pub fn start<S>(self, request: &HttpRequest, stream: S) -> Result<HttpResponse, Error>
131    where
132        S: Stream<Item = Result<Bytes, PayloadError>> + 'static,
133    {
134        let protocol = request
135            .headers()
136            .get("sec-websocket-protocol")
137            .and_then(|value| value.to_str().ok())
138            .and_then(|protocols| {
139                protocols
140                    .split(',')
141                    .find_map(|p| WebSocketProtocols::from_str(p.trim()).ok())
142            })
143            .ok_or_else(|| actix_web::error::ErrorBadRequest(ParseGraphQLProtocolError))?;
144
145        let actor = GraphQLSubscriptionActor {
146            executor: self.executor,
147            data: Some(self.data),
148            protocol,
149            last_heartbeat: Instant::now(),
150            messages: None,
151            on_connection_init: Some(self.on_connection_init),
152            on_ping: self.on_ping,
153            keepalive_timeout: self.keepalive_timeout,
154            continuation: Vec::new(),
155        };
156
157        actix_web_actors::ws::WsResponseBuilder::new(actor, request, stream)
158            .protocols(&ALL_WEBSOCKET_PROTOCOLS)
159            .start()
160    }
161}
162
163struct GraphQLSubscriptionActor<E, OnInit, OnPing> {
164    executor: E,
165    data: Option<Data>,
166    protocol: WebSocketProtocols,
167    last_heartbeat: Instant,
168    messages: Option<async_channel::Sender<Vec<u8>>>,
169    on_connection_init: Option<OnInit>,
170    on_ping: OnPing,
171    keepalive_timeout: Option<Duration>,
172    continuation: Vec<u8>,
173}
174
175impl<E, OnInit, OnInitFut, OnPing, OnPingFut> GraphQLSubscriptionActor<E, OnInit, OnPing>
176where
177    E: Executor,
178    OnInit: FnOnce(serde_json::Value) -> OnInitFut + Unpin + Send + 'static,
179    OnInitFut: Future<Output = Result<Data>> + Send + 'static,
180    OnPing: FnOnce(Option<&Data>, Option<serde_json::Value>) -> OnPingFut
181        + Clone
182        + Unpin
183        + Send
184        + 'static,
185    OnPingFut: Future<Output = Result<Option<serde_json::Value>>> + Send + 'static,
186{
187    fn send_heartbeats(&self, ctx: &mut WebsocketContext<Self>) {
188        ctx.run_interval(HEARTBEAT_INTERVAL, |act, ctx| {
189            if Instant::now().duration_since(act.last_heartbeat) > CLIENT_TIMEOUT {
190                ctx.stop();
191            }
192            ctx.ping(b"");
193        });
194    }
195}
196
197impl<E, OnInit, OnInitFut, OnPing, OnPingFut> Actor for GraphQLSubscriptionActor<E, OnInit, OnPing>
198where
199    E: Executor,
200    OnInit: FnOnce(serde_json::Value) -> OnInitFut + Unpin + Send + 'static,
201    OnInitFut: Future<Output = Result<Data>> + Send + 'static,
202    OnPing: FnOnce(Option<&Data>, Option<serde_json::Value>) -> OnPingFut
203        + Clone
204        + Unpin
205        + Send
206        + 'static,
207    OnPingFut: Future<Output = Result<Option<serde_json::Value>>> + Send + 'static,
208{
209    type Context = WebsocketContext<Self>;
210
211    fn started(&mut self, ctx: &mut Self::Context) {
212        self.send_heartbeats(ctx);
213
214        let (tx, rx) = async_channel::unbounded();
215
216        WebSocket::new(self.executor.clone(), rx, self.protocol)
217            .connection_data(self.data.take().unwrap())
218            .on_connection_init(self.on_connection_init.take().unwrap())
219            .on_ping(self.on_ping.clone())
220            .keepalive_timeout(self.keepalive_timeout)
221            .into_actor(self)
222            .map(|response, _act, ctx| match response {
223                WsMessage::Text(text) => ctx.text(text),
224                WsMessage::Close(code, msg) => ctx.close(Some(CloseReason {
225                    code: code.into(),
226                    description: Some(msg),
227                })),
228            })
229            .finish()
230            .spawn(ctx);
231
232        self.messages = Some(tx);
233    }
234}
235
236impl<E, OnInit, OnInitFut, OnPing, OnPingFut> StreamHandler<Result<Message, ProtocolError>>
237    for GraphQLSubscriptionActor<E, OnInit, OnPing>
238where
239    E: Executor,
240    OnInit: FnOnce(serde_json::Value) -> OnInitFut + Unpin + Send + 'static,
241    OnInitFut: Future<Output = Result<Data>> + Send + 'static,
242    OnPing: FnOnce(Option<&Data>, Option<serde_json::Value>) -> OnPingFut
243        + Clone
244        + Unpin
245        + Send
246        + 'static,
247    OnPingFut: Future<Output = async_graphql::Result<Option<serde_json::Value>>> + Send + 'static,
248{
249    fn handle(&mut self, msg: Result<Message, ProtocolError>, ctx: &mut Self::Context) {
250        let msg = match msg {
251            Err(_) => {
252                ctx.stop();
253                return;
254            }
255            Ok(msg) => msg,
256        };
257
258        let message = match msg {
259            Message::Ping(msg) => {
260                self.last_heartbeat = Instant::now();
261                ctx.pong(&msg);
262                None
263            }
264            Message::Pong(_) => {
265                self.last_heartbeat = Instant::now();
266                None
267            }
268            Message::Continuation(item) => match item {
269                ws::Item::FirstText(bytes) | ws::Item::FirstBinary(bytes) => {
270                    self.continuation = bytes.to_vec();
271                    None
272                }
273                ws::Item::Continue(bytes) => {
274                    self.continuation.extend_from_slice(&bytes);
275                    None
276                }
277                ws::Item::Last(bytes) => {
278                    self.continuation.extend_from_slice(&bytes);
279                    Some(std::mem::take(&mut self.continuation))
280                }
281            },
282            Message::Text(s) => Some(s.into_bytes().to_vec()),
283            Message::Binary(bytes) => Some(bytes.to_vec()),
284            Message::Close(_) => {
285                ctx.stop();
286                None
287            }
288            Message::Nop => None,
289        };
290
291        if let Some(message) = message {
292            let sender = self.messages.as_ref().unwrap().clone();
293
294            async move { sender.send(message).await }
295                .into_actor(self)
296                .map(|res, _actor, ctx| match res {
297                    Ok(()) => {}
298                    Err(_) => ctx.stop(),
299                })
300                .spawn(ctx)
301        }
302    }
303}