1use std::{io::Error as IoError, str::FromStr, time::Duration};
2
3use async_graphql::{
4 Data, Executor,
5 http::{
6 ALL_WEBSOCKET_PROTOCOLS, DefaultOnConnInitType, DefaultOnPingType, WebSocketProtocols,
7 WsMessage, default_on_connection_init, default_on_ping,
8 },
9};
10use futures_util::{
11 Future, Sink, SinkExt, Stream, StreamExt,
12 future::{self},
13 stream::{SplitSink, SplitStream},
14};
15use poem::{
16 Endpoint, Error, FromRequest, IntoResponse, Request, RequestBody, Response, Result,
17 http::StatusCode,
18 web::websocket::{Message, WebSocket},
19};
20
21#[derive(Debug, Copy, Clone, PartialEq, Eq)]
25pub struct GraphQLProtocol(pub WebSocketProtocols);
26
27impl<'a> FromRequest<'a> for GraphQLProtocol {
28 async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result<Self> {
29 req.headers()
30 .get(http::header::SEC_WEBSOCKET_PROTOCOL)
31 .and_then(|value| value.to_str().ok())
32 .and_then(|protocols| {
33 protocols
34 .split(',')
35 .find_map(|p| WebSocketProtocols::from_str(p.trim()).ok())
36 })
37 .map(Self)
38 .ok_or_else(|| Error::from_status(StatusCode::BAD_REQUEST))
39 }
40}
41
/// A poem [`Endpoint`] that upgrades requests to WebSockets and serves
/// GraphQL operations over them using the wrapped executor.
pub struct GraphQLSubscription<E> {
    /// Executor that is cloned into every upgraded connection.
    executor: E,
}

impl<E> GraphQLSubscription<E> {
    /// Creates a subscription endpoint that runs operations with `executor`.
    pub fn new(executor: E) -> Self {
        GraphQLSubscription { executor: executor }
    }
}
85
86impl<E> Endpoint for GraphQLSubscription<E>
87where
88 E: Executor,
89{
90 type Output = Response;
91
92 async fn call(&self, req: Request) -> Result<Self::Output> {
93 let (req, mut body) = req.split();
94 let websocket = WebSocket::from_request(&req, &mut body).await?;
95 let protocol = GraphQLProtocol::from_request(&req, &mut body).await?;
96 let executor = self.executor.clone();
97
98 let resp = websocket
99 .protocols(ALL_WEBSOCKET_PROTOCOLS)
100 .on_upgrade(move |stream| GraphQLWebSocket::new(stream, executor, protocol).serve())
101 .into_response();
102 Ok(resp)
103 }
104}
105
106pub struct GraphQLWebSocket<Sink, Stream, E, OnConnInit, OnPing> {
108 sink: Sink,
109 stream: Stream,
110 executor: E,
111 data: Data,
112 on_connection_init: OnConnInit,
113 on_ping: OnPing,
114 protocol: GraphQLProtocol,
115 keepalive_timeout: Option<Duration>,
116}
117
118impl<S, E>
119 GraphQLWebSocket<
120 SplitSink<S, Message>,
121 SplitStream<S>,
122 E,
123 DefaultOnConnInitType,
124 DefaultOnPingType,
125 >
126where
127 S: Stream<Item = Result<Message, IoError>> + Sink<Message>,
128 E: Executor,
129{
130 pub fn new(stream: S, executor: E, protocol: GraphQLProtocol) -> Self {
132 let (sink, stream) = stream.split();
133 GraphQLWebSocket::new_with_pair(sink, stream, executor, protocol)
134 }
135}
136
137impl<Sink, Stream, E> GraphQLWebSocket<Sink, Stream, E, DefaultOnConnInitType, DefaultOnPingType>
138where
139 Sink: futures_util::sink::Sink<Message>,
140 Stream: futures_util::stream::Stream<Item = Result<Message, IoError>>,
141 E: Executor,
142{
143 pub fn new_with_pair(
145 sink: Sink,
146 stream: Stream,
147 executor: E,
148 protocol: GraphQLProtocol,
149 ) -> Self {
150 GraphQLWebSocket {
151 sink,
152 stream,
153 executor,
154 data: Data::default(),
155 on_connection_init: default_on_connection_init,
156 on_ping: default_on_ping,
157 protocol,
158 keepalive_timeout: None,
159 }
160 }
161}
162
163impl<Sink, Stream, E, OnConnInit, OnConnInitFut, OnPing, OnPingFut>
164 GraphQLWebSocket<Sink, Stream, E, OnConnInit, OnPing>
165where
166 Sink: futures_util::sink::Sink<Message>,
167 Stream: futures_util::stream::Stream<Item = Result<Message, IoError>>,
168 E: Executor,
169 OnConnInit: FnOnce(serde_json::Value) -> OnConnInitFut + Send + 'static,
170 OnConnInitFut: Future<Output = async_graphql::Result<Data>> + Send + 'static,
171 OnPing: FnOnce(Option<&Data>, Option<serde_json::Value>) -> OnPingFut + Clone + Send + 'static,
172 OnPingFut: Future<Output = async_graphql::Result<Option<serde_json::Value>>> + Send + 'static,
173{
174 #[must_use]
177 pub fn with_data(self, data: Data) -> Self {
178 Self { data, ..self }
179 }
180
181 #[must_use]
188 pub fn on_connection_init<F, R>(
189 self,
190 callback: F,
191 ) -> GraphQLWebSocket<Sink, Stream, E, F, OnPing>
192 where
193 F: FnOnce(serde_json::Value) -> R + Send + 'static,
194 R: Future<Output = async_graphql::Result<Data>> + Send + 'static,
195 {
196 GraphQLWebSocket {
197 sink: self.sink,
198 stream: self.stream,
199 executor: self.executor,
200 data: self.data,
201 on_connection_init: callback,
202 on_ping: self.on_ping,
203 protocol: self.protocol,
204 keepalive_timeout: self.keepalive_timeout,
205 }
206 }
207
208 #[must_use]
217 pub fn on_ping<F, R>(self, callback: F) -> GraphQLWebSocket<Sink, Stream, E, OnConnInit, F>
218 where
219 F: FnOnce(Option<&Data>, Option<serde_json::Value>) -> R + Clone + Send + 'static,
220 R: Future<Output = async_graphql::Result<Option<serde_json::Value>>> + Send + 'static,
221 {
222 GraphQLWebSocket {
223 sink: self.sink,
224 stream: self.stream,
225 executor: self.executor,
226 data: self.data,
227 on_connection_init: self.on_connection_init,
228 on_ping: callback,
229 protocol: self.protocol,
230 keepalive_timeout: self.keepalive_timeout,
231 }
232 }
233
234 #[must_use]
241 pub fn keepalive_timeout(self, timeout: impl Into<Option<Duration>>) -> Self {
242 Self {
243 keepalive_timeout: timeout.into(),
244 ..self
245 }
246 }
247
248 pub async fn serve(self) {
250 let stream = self
251 .stream
252 .take_while(|res| future::ready(res.is_ok()))
253 .map(Result::unwrap)
254 .filter_map(|msg| {
255 if msg.is_text() || msg.is_binary() {
256 future::ready(Some(msg))
257 } else {
258 future::ready(None)
259 }
260 })
261 .map(Message::into_bytes);
262
263 let stream =
264 async_graphql::http::WebSocket::new(self.executor.clone(), stream, self.protocol.0)
265 .connection_data(self.data)
266 .on_connection_init(self.on_connection_init)
267 .on_ping(self.on_ping.clone())
268 .keepalive_timeout(self.keepalive_timeout)
269 .map(|msg| match msg {
270 WsMessage::Text(text) => Message::text(text),
271 WsMessage::Close(code, status) => Message::close_with(code, status),
272 });
273
274 let sink = self.sink;
275 futures_util::pin_mut!(stream, sink);
276
277 while let Some(item) = stream.next().await {
278 if sink.send(item).await.is_err() {
279 break;
280 }
281 }
282 }
283}