1use std::{convert::Infallible, future::Future, str::FromStr, time::Duration};
2
3use async_graphql::{
4 Data, Executor, Result,
5 futures_util::task::{Context, Poll},
6 http::{
7 ALL_WEBSOCKET_PROTOCOLS, DefaultOnConnInitType, DefaultOnPingType, WebSocketProtocols,
8 WsMessage, default_on_connection_init, default_on_ping,
9 },
10};
11use axum::{
12 Error,
13 body::{Body, HttpBody},
14 extract::{
15 FromRequestParts, WebSocketUpgrade,
16 ws::{CloseFrame, Message},
17 },
18 http::{self, Request, Response, StatusCode, request::Parts},
19 response::IntoResponse,
20};
21use futures_util::{
22 Sink, SinkExt, Stream, StreamExt, future,
23 future::BoxFuture,
24 stream::{SplitSink, SplitStream},
25};
26use tower_service::Service;
27
28#[derive(Debug, Copy, Clone, PartialEq, Eq)]
32pub struct GraphQLProtocol(WebSocketProtocols);
33
34impl<S> FromRequestParts<S> for GraphQLProtocol
35where
36 S: Send + Sync,
37{
38 type Rejection = StatusCode;
39
40 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
41 parts
42 .headers
43 .get(http::header::SEC_WEBSOCKET_PROTOCOL)
44 .and_then(|value| value.to_str().ok())
45 .and_then(|protocols| {
46 protocols
47 .split(',')
48 .find_map(|p| WebSocketProtocols::from_str(p.trim()).ok())
49 })
50 .map(Self)
51 .ok_or(StatusCode::BAD_REQUEST)
52 }
53}
54
/// A GraphQL subscription service.
///
/// Implements [`Service`] so it can be mounted as a route; each accepted
/// request is upgraded to a WebSocket served by a clone of the executor.
pub struct GraphQLSubscription<E> {
    /// Executor cloned into every accepted connection.
    executor: E,
}
59
60impl<E> Clone for GraphQLSubscription<E>
61where
62 E: Executor,
63{
64 fn clone(&self) -> Self {
65 Self {
66 executor: self.executor.clone(),
67 }
68 }
69}
70
71impl<E> GraphQLSubscription<E>
72where
73 E: Executor,
74{
75 pub fn new(executor: E) -> Self {
77 Self { executor }
78 }
79}
80
81impl<B, E> Service<Request<B>> for GraphQLSubscription<E>
82where
83 B: HttpBody + Send + 'static,
84 E: Executor,
85{
86 type Response = Response<Body>;
87 type Error = Infallible;
88 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
89
90 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
91 Poll::Ready(Ok(()))
92 }
93
94 fn call(&mut self, req: Request<B>) -> Self::Future {
95 let executor = self.executor.clone();
96
97 Box::pin(async move {
98 let (mut parts, _body) = req.into_parts();
99
100 let protocol = match GraphQLProtocol::from_request_parts(&mut parts, &()).await {
101 Ok(protocol) => protocol,
102 Err(err) => return Ok(err.into_response()),
103 };
104 let upgrade = match WebSocketUpgrade::from_request_parts(&mut parts, &()).await {
105 Ok(protocol) => protocol,
106 Err(err) => return Ok(err.into_response()),
107 };
108
109 let executor = executor.clone();
110
111 let resp = upgrade
112 .protocols(ALL_WEBSOCKET_PROTOCOLS)
113 .on_upgrade(move |stream| {
114 GraphQLWebSocket::new(stream, executor, protocol).serve()
115 });
116 Ok(resp.into_response())
117 })
118 }
119}
120
121pub struct GraphQLWebSocket<Sink, Stream, E, OnConnInit, OnPing> {
123 sink: Sink,
124 stream: Stream,
125 executor: E,
126 data: Data,
127 on_connection_init: OnConnInit,
128 on_ping: OnPing,
129 protocol: GraphQLProtocol,
130 keepalive_timeout: Option<Duration>,
131}
132
133impl<S, E>
134 GraphQLWebSocket<
135 SplitSink<S, Message>,
136 SplitStream<S>,
137 E,
138 DefaultOnConnInitType,
139 DefaultOnPingType,
140 >
141where
142 S: Stream<Item = Result<Message, Error>> + Sink<Message>,
143 E: Executor,
144{
145 pub fn new(stream: S, executor: E, protocol: GraphQLProtocol) -> Self {
147 let (sink, stream) = stream.split();
148 GraphQLWebSocket::new_with_pair(sink, stream, executor, protocol)
149 }
150}
151
152impl<Sink, Stream, E> GraphQLWebSocket<Sink, Stream, E, DefaultOnConnInitType, DefaultOnPingType>
153where
154 Sink: futures_util::sink::Sink<Message>,
155 Stream: futures_util::stream::Stream<Item = Result<Message, Error>>,
156 E: Executor,
157{
158 pub fn new_with_pair(
160 sink: Sink,
161 stream: Stream,
162 executor: E,
163 protocol: GraphQLProtocol,
164 ) -> Self {
165 GraphQLWebSocket {
166 sink,
167 stream,
168 executor,
169 data: Data::default(),
170 on_connection_init: default_on_connection_init,
171 on_ping: default_on_ping,
172 protocol,
173 keepalive_timeout: None,
174 }
175 }
176}
177
178impl<Sink, Stream, E, OnConnInit, OnConnInitFut, OnPing, OnPingFut>
179 GraphQLWebSocket<Sink, Stream, E, OnConnInit, OnPing>
180where
181 Sink: futures_util::sink::Sink<Message>,
182 Stream: futures_util::stream::Stream<Item = Result<Message, Error>>,
183 E: Executor,
184 OnConnInit: FnOnce(serde_json::Value) -> OnConnInitFut + Send + 'static,
185 OnConnInitFut: Future<Output = async_graphql::Result<Data>> + Send + 'static,
186 OnPing: FnOnce(Option<&Data>, Option<serde_json::Value>) -> OnPingFut + Clone + Send + 'static,
187 OnPingFut: Future<Output = async_graphql::Result<Option<serde_json::Value>>> + Send + 'static,
188{
189 #[must_use]
192 pub fn with_data(self, data: Data) -> Self {
193 Self { data, ..self }
194 }
195
196 #[must_use]
203 pub fn on_connection_init<F, R>(
204 self,
205 callback: F,
206 ) -> GraphQLWebSocket<Sink, Stream, E, F, OnPing>
207 where
208 F: FnOnce(serde_json::Value) -> R + Send + 'static,
209 R: Future<Output = async_graphql::Result<Data>> + Send + 'static,
210 {
211 GraphQLWebSocket {
212 sink: self.sink,
213 stream: self.stream,
214 executor: self.executor,
215 data: self.data,
216 on_connection_init: callback,
217 on_ping: self.on_ping,
218 protocol: self.protocol,
219 keepalive_timeout: self.keepalive_timeout,
220 }
221 }
222
223 #[must_use]
232 pub fn on_ping<F, R>(self, callback: F) -> GraphQLWebSocket<Sink, Stream, E, OnConnInit, F>
233 where
234 F: FnOnce(Option<&Data>, Option<serde_json::Value>) -> R + Clone + Send + 'static,
235 R: Future<Output = Result<Option<serde_json::Value>>> + Send + 'static,
236 {
237 GraphQLWebSocket {
238 sink: self.sink,
239 stream: self.stream,
240 executor: self.executor,
241 data: self.data,
242 on_connection_init: self.on_connection_init,
243 on_ping: callback,
244 protocol: self.protocol,
245 keepalive_timeout: self.keepalive_timeout,
246 }
247 }
248
249 #[must_use]
256 pub fn keepalive_timeout(self, timeout: impl Into<Option<Duration>>) -> Self {
257 Self {
258 keepalive_timeout: timeout.into(),
259 ..self
260 }
261 }
262
263 pub async fn serve(self) {
265 let input = self
266 .stream
267 .take_while(|res| future::ready(res.is_ok()))
268 .map(Result::unwrap)
269 .filter_map(|msg| {
270 if let Message::Text(_) | Message::Binary(_) = msg {
271 future::ready(Some(msg))
272 } else {
273 future::ready(None)
274 }
275 })
276 .map(Message::into_data);
277
278 let stream =
279 async_graphql::http::WebSocket::new(self.executor.clone(), input, self.protocol.0)
280 .connection_data(self.data)
281 .on_connection_init(self.on_connection_init)
282 .on_ping(self.on_ping.clone())
283 .keepalive_timeout(self.keepalive_timeout)
284 .map(|msg| match msg {
285 WsMessage::Text(text) => Message::Text(text.into()),
286 WsMessage::Close(code, status) => Message::Close(Some(CloseFrame {
287 code,
288 reason: status.into(),
289 })),
290 });
291
292 let sink = self.sink;
293 futures_util::pin_mut!(stream, sink);
294
295 while let Some(item) = stream.next().await {
296 if sink.send(item).await.is_err() {
297 break;
298 }
299 }
300 }
301}