1use std::{convert::Infallible, future::Future, str::FromStr, time::Duration};
2
3use async_graphql::{
4 futures_util::task::{Context, Poll},
5 http::{
6 default_on_connection_init, default_on_ping, DefaultOnConnInitType, DefaultOnPingType,
7 WebSocketProtocols, WsMessage, ALL_WEBSOCKET_PROTOCOLS,
8 },
9 Data, Executor, Result,
10};
11use axum::{
12 body::{Body, HttpBody},
13 extract::{
14 ws::{CloseFrame, Message},
15 FromRequestParts, WebSocketUpgrade,
16 },
17 http::{self, request::Parts, Request, Response, StatusCode},
18 response::IntoResponse,
19 Error,
20};
21use futures_util::{
22 future,
23 future::BoxFuture,
24 stream::{SplitSink, SplitStream},
25 Sink, SinkExt, Stream, StreamExt,
26};
27use tower_service::Service;
28
29#[derive(Debug, Copy, Clone, PartialEq, Eq)]
33pub struct GraphQLProtocol(WebSocketProtocols);
34
35impl<S> FromRequestParts<S> for GraphQLProtocol
36where
37 S: Send + Sync,
38{
39 type Rejection = StatusCode;
40
41 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
42 parts
43 .headers
44 .get(http::header::SEC_WEBSOCKET_PROTOCOL)
45 .and_then(|value| value.to_str().ok())
46 .and_then(|protocols| {
47 protocols
48 .split(',')
49 .find_map(|p| WebSocketProtocols::from_str(p.trim()).ok())
50 })
51 .map(Self)
52 .ok_or(StatusCode::BAD_REQUEST)
53 }
54}
55
/// A tower `Service` that upgrades incoming HTTP requests to GraphQL
/// WebSocket connections and serves them with the wrapped executor.
pub struct GraphQLSubscription<E> {
    /// Executor used to run the operations received on each upgraded socket.
    executor: E,
}
60
61impl<E> Clone for GraphQLSubscription<E>
62where
63 E: Executor,
64{
65 fn clone(&self) -> Self {
66 Self {
67 executor: self.executor.clone(),
68 }
69 }
70}
71
72impl<E> GraphQLSubscription<E>
73where
74 E: Executor,
75{
76 pub fn new(executor: E) -> Self {
78 Self { executor }
79 }
80}
81
82impl<B, E> Service<Request<B>> for GraphQLSubscription<E>
83where
84 B: HttpBody + Send + 'static,
85 E: Executor,
86{
87 type Response = Response<Body>;
88 type Error = Infallible;
89 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
90
91 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
92 Poll::Ready(Ok(()))
93 }
94
95 fn call(&mut self, req: Request<B>) -> Self::Future {
96 let executor = self.executor.clone();
97
98 Box::pin(async move {
99 let (mut parts, _body) = req.into_parts();
100
101 let protocol = match GraphQLProtocol::from_request_parts(&mut parts, &()).await {
102 Ok(protocol) => protocol,
103 Err(err) => return Ok(err.into_response()),
104 };
105 let upgrade = match WebSocketUpgrade::from_request_parts(&mut parts, &()).await {
106 Ok(protocol) => protocol,
107 Err(err) => return Ok(err.into_response()),
108 };
109
110 let executor = executor.clone();
111
112 let resp = upgrade
113 .protocols(ALL_WEBSOCKET_PROTOCOLS)
114 .on_upgrade(move |stream| {
115 GraphQLWebSocket::new(stream, executor, protocol).serve()
116 });
117 Ok(resp.into_response())
118 })
119 }
120}
121
122pub struct GraphQLWebSocket<Sink, Stream, E, OnConnInit, OnPing> {
124 sink: Sink,
125 stream: Stream,
126 executor: E,
127 data: Data,
128 on_connection_init: OnConnInit,
129 on_ping: OnPing,
130 protocol: GraphQLProtocol,
131 keepalive_timeout: Option<Duration>,
132}
133
134impl<S, E>
135 GraphQLWebSocket<
136 SplitSink<S, Message>,
137 SplitStream<S>,
138 E,
139 DefaultOnConnInitType,
140 DefaultOnPingType,
141 >
142where
143 S: Stream<Item = Result<Message, Error>> + Sink<Message>,
144 E: Executor,
145{
146 pub fn new(stream: S, executor: E, protocol: GraphQLProtocol) -> Self {
148 let (sink, stream) = stream.split();
149 GraphQLWebSocket::new_with_pair(sink, stream, executor, protocol)
150 }
151}
152
153impl<Sink, Stream, E> GraphQLWebSocket<Sink, Stream, E, DefaultOnConnInitType, DefaultOnPingType>
154where
155 Sink: futures_util::sink::Sink<Message>,
156 Stream: futures_util::stream::Stream<Item = Result<Message, Error>>,
157 E: Executor,
158{
159 pub fn new_with_pair(
161 sink: Sink,
162 stream: Stream,
163 executor: E,
164 protocol: GraphQLProtocol,
165 ) -> Self {
166 GraphQLWebSocket {
167 sink,
168 stream,
169 executor,
170 data: Data::default(),
171 on_connection_init: default_on_connection_init,
172 on_ping: default_on_ping,
173 protocol,
174 keepalive_timeout: None,
175 }
176 }
177}
178
179impl<Sink, Stream, E, OnConnInit, OnConnInitFut, OnPing, OnPingFut>
180 GraphQLWebSocket<Sink, Stream, E, OnConnInit, OnPing>
181where
182 Sink: futures_util::sink::Sink<Message>,
183 Stream: futures_util::stream::Stream<Item = Result<Message, Error>>,
184 E: Executor,
185 OnConnInit: FnOnce(serde_json::Value) -> OnConnInitFut + Send + 'static,
186 OnConnInitFut: Future<Output = async_graphql::Result<Data>> + Send + 'static,
187 OnPing: FnOnce(Option<&Data>, Option<serde_json::Value>) -> OnPingFut + Clone + Send + 'static,
188 OnPingFut: Future<Output = async_graphql::Result<Option<serde_json::Value>>> + Send + 'static,
189{
190 #[must_use]
193 pub fn with_data(self, data: Data) -> Self {
194 Self { data, ..self }
195 }
196
197 #[must_use]
204 pub fn on_connection_init<F, R>(
205 self,
206 callback: F,
207 ) -> GraphQLWebSocket<Sink, Stream, E, F, OnPing>
208 where
209 F: FnOnce(serde_json::Value) -> R + Send + 'static,
210 R: Future<Output = async_graphql::Result<Data>> + Send + 'static,
211 {
212 GraphQLWebSocket {
213 sink: self.sink,
214 stream: self.stream,
215 executor: self.executor,
216 data: self.data,
217 on_connection_init: callback,
218 on_ping: self.on_ping,
219 protocol: self.protocol,
220 keepalive_timeout: self.keepalive_timeout,
221 }
222 }
223
224 #[must_use]
233 pub fn on_ping<F, R>(self, callback: F) -> GraphQLWebSocket<Sink, Stream, E, OnConnInit, F>
234 where
235 F: FnOnce(Option<&Data>, Option<serde_json::Value>) -> R + Clone + Send + 'static,
236 R: Future<Output = Result<Option<serde_json::Value>>> + Send + 'static,
237 {
238 GraphQLWebSocket {
239 sink: self.sink,
240 stream: self.stream,
241 executor: self.executor,
242 data: self.data,
243 on_connection_init: self.on_connection_init,
244 on_ping: callback,
245 protocol: self.protocol,
246 keepalive_timeout: self.keepalive_timeout,
247 }
248 }
249
250 #[must_use]
257 pub fn keepalive_timeout(self, timeout: impl Into<Option<Duration>>) -> Self {
258 Self {
259 keepalive_timeout: timeout.into(),
260 ..self
261 }
262 }
263
264 pub async fn serve(self) {
266 let input = self
267 .stream
268 .take_while(|res| future::ready(res.is_ok()))
269 .map(Result::unwrap)
270 .filter_map(|msg| {
271 if let Message::Text(_) | Message::Binary(_) = msg {
272 future::ready(Some(msg))
273 } else {
274 future::ready(None)
275 }
276 })
277 .map(Message::into_data);
278
279 let stream =
280 async_graphql::http::WebSocket::new(self.executor.clone(), input, self.protocol.0)
281 .connection_data(self.data)
282 .on_connection_init(self.on_connection_init)
283 .on_ping(self.on_ping.clone())
284 .keepalive_timeout(self.keepalive_timeout)
285 .map(|msg| match msg {
286 WsMessage::Text(text) => Message::Text(text.into()),
287 WsMessage::Close(code, status) => Message::Close(Some(CloseFrame {
288 code,
289 reason: status.into(),
290 })),
291 });
292
293 let sink = self.sink;
294 futures_util::pin_mut!(stream, sink);
295
296 while let Some(item) = stream.next().await {
297 if sink.send(item).await.is_err() {
298 break;
299 }
300 }
301 }
302}