1use std::{
2 future::Future,
3 str::FromStr,
4 time::{Duration, Instant},
5};
6
7use actix::{
8 Actor, ActorContext, ActorFutureExt, ActorStreamExt, AsyncContext, ContextFutureSpawner,
9 StreamHandler, WrapFuture, WrapStream,
10};
11use actix_http::{error::PayloadError, ws};
12use actix_web::{Error, HttpRequest, HttpResponse, web::Bytes};
13use actix_web_actors::ws::{CloseReason, Message, ProtocolError, WebsocketContext};
14use async_graphql::{
15 Data, Executor, Result,
16 http::{
17 ALL_WEBSOCKET_PROTOCOLS, DefaultOnConnInitType, DefaultOnPingType, WebSocket,
18 WebSocketProtocols, WsMessage, default_on_connection_init, default_on_ping,
19 },
20};
21use futures_util::stream::Stream;
22
/// Period of the timer installed by `send_heartbeats`; on every tick the
/// actor checks client liveness and sends a WebSocket ping.
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
/// Maximum time allowed since `last_heartbeat` was last refreshed before the
/// heartbeat timer stops the actor.
const CLIENT_TIMEOUT: Duration = Duration::from_secs(10);
25
/// Error produced by [`GraphQLSubscription::start`] when the request's
/// `sec-websocket-protocol` header is missing, is not valid text, or lists no
/// protocol that parses as a [`WebSocketProtocols`] value.
#[derive(thiserror::Error, Debug)]
#[error("failed to parse graphql protocol")]
pub struct ParseGraphQLProtocolError;
29
/// Builder that collects everything needed to serve GraphQL subscriptions
/// over a WebSocket and, via `start`, turns an HTTP request into a running
/// WebSocket actor.
pub struct GraphQLSubscription<E, OnInit, OnPing> {
    /// Executor handed to the underlying `WebSocket` stream.
    executor: E,
    /// Initial connection data passed to `WebSocket::connection_data`.
    data: Data,
    /// Callback passed to `WebSocket::on_connection_init`.
    on_connection_init: OnInit,
    /// Callback passed to `WebSocket::on_ping`.
    on_ping: OnPing,
    /// Value forwarded to `WebSocket::keepalive_timeout`; `None` by default.
    keepalive_timeout: Option<Duration>,
}
38
39impl<E> GraphQLSubscription<E, DefaultOnConnInitType, DefaultOnPingType> {
40 pub fn new(executor: E) -> Self {
42 Self {
43 executor,
44 data: Default::default(),
45 on_connection_init: default_on_connection_init,
46 on_ping: default_on_ping,
47 keepalive_timeout: None,
48 }
49 }
50}
51
52impl<E, OnInit, OnInitFut, OnPing, OnPingFut> GraphQLSubscription<E, OnInit, OnPing>
53where
54 E: Executor,
55 OnInit: FnOnce(serde_json::Value) -> OnInitFut + Unpin + Send + 'static,
56 OnInitFut: Future<Output = async_graphql::Result<Data>> + Send + 'static,
57 OnPing: FnOnce(Option<&Data>, Option<serde_json::Value>) -> OnPingFut
58 + Clone
59 + Unpin
60 + Send
61 + 'static,
62 OnPingFut: Future<Output = async_graphql::Result<Option<serde_json::Value>>> + Send + 'static,
63{
64 #[must_use]
67 pub fn with_data(self, data: Data) -> Self {
68 Self { data, ..self }
69 }
70
71 #[must_use]
78 pub fn on_connection_init<F, R>(self, callback: F) -> GraphQLSubscription<E, F, OnPing>
79 where
80 F: FnOnce(serde_json::Value) -> R + Unpin + Send + 'static,
81 R: Future<Output = async_graphql::Result<Data>> + Send + 'static,
82 {
83 GraphQLSubscription {
84 executor: self.executor,
85 data: self.data,
86 on_connection_init: callback,
87 on_ping: self.on_ping,
88 keepalive_timeout: self.keepalive_timeout,
89 }
90 }
91
92 #[must_use]
101 pub fn on_ping<F, R>(self, callback: F) -> GraphQLSubscription<E, OnInit, F>
102 where
103 F: FnOnce(Option<&Data>, Option<serde_json::Value>) -> R + Send + Clone + 'static,
104 R: Future<Output = Result<Option<serde_json::Value>>> + Send + 'static,
105 {
106 GraphQLSubscription {
107 executor: self.executor,
108 data: self.data,
109 on_connection_init: self.on_connection_init,
110 on_ping: callback,
111 keepalive_timeout: self.keepalive_timeout,
112 }
113 }
114
115 #[must_use]
122 pub fn keepalive_timeout(self, timeout: impl Into<Option<Duration>>) -> Self {
123 Self {
124 keepalive_timeout: timeout.into(),
125 ..self
126 }
127 }
128
129 pub fn start<S>(self, request: &HttpRequest, stream: S) -> Result<HttpResponse, Error>
131 where
132 S: Stream<Item = Result<Bytes, PayloadError>> + 'static,
133 {
134 let protocol = request
135 .headers()
136 .get("sec-websocket-protocol")
137 .and_then(|value| value.to_str().ok())
138 .and_then(|protocols| {
139 protocols
140 .split(',')
141 .find_map(|p| WebSocketProtocols::from_str(p.trim()).ok())
142 })
143 .ok_or_else(|| actix_web::error::ErrorBadRequest(ParseGraphQLProtocolError))?;
144
145 let actor = GraphQLSubscriptionActor {
146 executor: self.executor,
147 data: Some(self.data),
148 protocol,
149 last_heartbeat: Instant::now(),
150 messages: None,
151 on_connection_init: Some(self.on_connection_init),
152 on_ping: self.on_ping,
153 keepalive_timeout: self.keepalive_timeout,
154 continuation: Vec::new(),
155 };
156
157 actix_web_actors::ws::WsResponseBuilder::new(actor, request, stream)
158 .protocols(&ALL_WEBSOCKET_PROTOCOLS)
159 .start()
160 }
161}
162
/// Actor backing a single GraphQL-over-WebSocket connection.
struct GraphQLSubscriptionActor<E, OnInit, OnPing> {
    /// Executor cloned into the `WebSocket` stream when the actor starts.
    executor: E,
    /// Initial connection data; taken (left as `None`) in `started`.
    data: Option<Data>,
    /// Sub-protocol selected from the request header in `start`.
    protocol: WebSocketProtocols,
    /// Time the last WebSocket ping or pong frame was received from the client.
    last_heartbeat: Instant,
    /// Sending half of the channel feeding client payloads to the `WebSocket`
    /// stream; set in `started`.
    messages: Option<async_channel::Sender<Vec<u8>>>,
    /// Connection-init callback; taken (left as `None`) in `started`.
    on_connection_init: Option<OnInit>,
    /// Ping callback, cloned into the `WebSocket` stream in `started`.
    on_ping: OnPing,
    /// Keepalive timeout forwarded to the `WebSocket` stream.
    keepalive_timeout: Option<Duration>,
    /// Buffer accumulating the payload of a fragmented (continuation) message.
    continuation: Vec<u8>,
}
174
175impl<E, OnInit, OnInitFut, OnPing, OnPingFut> GraphQLSubscriptionActor<E, OnInit, OnPing>
176where
177 E: Executor,
178 OnInit: FnOnce(serde_json::Value) -> OnInitFut + Unpin + Send + 'static,
179 OnInitFut: Future<Output = Result<Data>> + Send + 'static,
180 OnPing: FnOnce(Option<&Data>, Option<serde_json::Value>) -> OnPingFut
181 + Clone
182 + Unpin
183 + Send
184 + 'static,
185 OnPingFut: Future<Output = Result<Option<serde_json::Value>>> + Send + 'static,
186{
187 fn send_heartbeats(&self, ctx: &mut WebsocketContext<Self>) {
188 ctx.run_interval(HEARTBEAT_INTERVAL, |act, ctx| {
189 if Instant::now().duration_since(act.last_heartbeat) > CLIENT_TIMEOUT {
190 ctx.stop();
191 }
192 ctx.ping(b"");
193 });
194 }
195}
196
197impl<E, OnInit, OnInitFut, OnPing, OnPingFut> Actor for GraphQLSubscriptionActor<E, OnInit, OnPing>
198where
199 E: Executor,
200 OnInit: FnOnce(serde_json::Value) -> OnInitFut + Unpin + Send + 'static,
201 OnInitFut: Future<Output = Result<Data>> + Send + 'static,
202 OnPing: FnOnce(Option<&Data>, Option<serde_json::Value>) -> OnPingFut
203 + Clone
204 + Unpin
205 + Send
206 + 'static,
207 OnPingFut: Future<Output = Result<Option<serde_json::Value>>> + Send + 'static,
208{
209 type Context = WebsocketContext<Self>;
210
211 fn started(&mut self, ctx: &mut Self::Context) {
212 self.send_heartbeats(ctx);
213
214 let (tx, rx) = async_channel::unbounded();
215
216 WebSocket::new(self.executor.clone(), rx, self.protocol)
217 .connection_data(self.data.take().unwrap())
218 .on_connection_init(self.on_connection_init.take().unwrap())
219 .on_ping(self.on_ping.clone())
220 .keepalive_timeout(self.keepalive_timeout)
221 .into_actor(self)
222 .map(|response, _act, ctx| match response {
223 WsMessage::Text(text) => ctx.text(text),
224 WsMessage::Close(code, msg) => ctx.close(Some(CloseReason {
225 code: code.into(),
226 description: Some(msg),
227 })),
228 })
229 .finish()
230 .spawn(ctx);
231
232 self.messages = Some(tx);
233 }
234}
235
236impl<E, OnInit, OnInitFut, OnPing, OnPingFut> StreamHandler<Result<Message, ProtocolError>>
237 for GraphQLSubscriptionActor<E, OnInit, OnPing>
238where
239 E: Executor,
240 OnInit: FnOnce(serde_json::Value) -> OnInitFut + Unpin + Send + 'static,
241 OnInitFut: Future<Output = Result<Data>> + Send + 'static,
242 OnPing: FnOnce(Option<&Data>, Option<serde_json::Value>) -> OnPingFut
243 + Clone
244 + Unpin
245 + Send
246 + 'static,
247 OnPingFut: Future<Output = async_graphql::Result<Option<serde_json::Value>>> + Send + 'static,
248{
249 fn handle(&mut self, msg: Result<Message, ProtocolError>, ctx: &mut Self::Context) {
250 let msg = match msg {
251 Err(_) => {
252 ctx.stop();
253 return;
254 }
255 Ok(msg) => msg,
256 };
257
258 let message = match msg {
259 Message::Ping(msg) => {
260 self.last_heartbeat = Instant::now();
261 ctx.pong(&msg);
262 None
263 }
264 Message::Pong(_) => {
265 self.last_heartbeat = Instant::now();
266 None
267 }
268 Message::Continuation(item) => match item {
269 ws::Item::FirstText(bytes) | ws::Item::FirstBinary(bytes) => {
270 self.continuation = bytes.to_vec();
271 None
272 }
273 ws::Item::Continue(bytes) => {
274 self.continuation.extend_from_slice(&bytes);
275 None
276 }
277 ws::Item::Last(bytes) => {
278 self.continuation.extend_from_slice(&bytes);
279 Some(std::mem::take(&mut self.continuation))
280 }
281 },
282 Message::Text(s) => Some(s.into_bytes().to_vec()),
283 Message::Binary(bytes) => Some(bytes.to_vec()),
284 Message::Close(_) => {
285 ctx.stop();
286 None
287 }
288 Message::Nop => None,
289 };
290
291 if let Some(message) = message {
292 let sender = self.messages.as_ref().unwrap().clone();
293
294 async move { sender.send(message).await }
295 .into_actor(self)
296 .map(|res, _actor, ctx| match res {
297 Ok(()) => {}
298 Err(_) => ctx.stop(),
299 })
300 .spawn(ctx)
301 }
302 }
303}