1use async_graphql::http::{AltairSource, Credentials, GraphiQLSource};
2use axum::response::{Html, IntoResponse};
3
4pub async fn graphiql_playground_handler(path: String, title: &str) -> impl IntoResponse {
6 Html(
7 GraphiQLSource::build()
8 .endpoint(&path)
9 .subscription_endpoint(&format!("{path}/ws"))
10 .title(title)
11 .credentials(Credentials::SameOrigin)
12 .header("x-requested-with", "graphiql")
13 .finish(),
14 )
15}
16
17pub async fn altair_playground_handler(path: String, title: &str) -> impl IntoResponse {
19 Html(
20 AltairSource::build()
21 .title(title)
22 .options(serde_json::json!({
23 "endpointURL": path,
24 "subscriptionsEndpoint": format!("{path}/ws"),
25 "subscriptionsProtocol": "wss",
26 "disableAccount": true,
27 "initialHeaders": {
28 "x-requested-with": "altair"
29 },
30 "initialSettings": {
31 "addQueryDepthLimit": 1,
32 "request.withCredentials": true,
33 "plugin.list": ["altair-graphql-plugin-graphql-explorer"],
34 "schema.reloadOnStart": true,
35 }
36 }))
37 .finish(),
38 )
39}
40
#[cfg(feature = "auth")]
mod auth {
    use async_graphql::{
        http::ALL_WEBSOCKET_PROTOCOLS, BatchRequest, BatchResponse, Data, ObjectType, Response, Schema,
        SubscriptionType,
    };
    use async_graphql_axum::{GraphQLProtocol, GraphQLResponse, GraphQLWebSocket};
    use auto_impl::auto_impl;
    use axum::{
        extract::{FromRequestParts, State, WebSocketUpgrade},
        response::IntoResponse,
    };
    use futures_util::{stream::FuturesOrdered, StreamExt};
    use tracing::Instrument;

    use crate::{
        auth::{Auth, AuthErrorCode, AuthState, AuthenticationService, Subject},
        axum::{
            extract::{AcceptLanguage, Extension},
            CorsService, CorsState,
        },
        error::{err, ApiError, GenericErrorCode, MapToErr},
        graphql::GraphQLBatchRequest,
        request_id::RequestId,
    };

    /// Hook for adding application-specific values to the GraphQL [`Data`] of every HTTP
    /// request and every WebSocket connection.
    ///
    /// The handlers in this module call it *before* inserting their own `RequestId`,
    /// `Option<S>` subject and `AcceptLanguage` entries.
    #[auto_impl(Box, Arc)]
    pub trait RequestDataMiddleware<S: Subject>: Send + Sync + Sized + Clone + 'static {
        /// Adds or modifies entries in `data` for the given (optional) subject and language.
        fn customize_request_data(&self, subject: &Option<S>, accept_language: &AcceptLanguage, data: &mut Data);
    }

    /// No-op middleware for applications that do not need extra request data.
    impl<S: Subject> RequestDataMiddleware<S> for () {
        fn customize_request_data(&self, _subject: &Option<S>, _accept_language: &AcceptLanguage, _data: &mut Data) {}
    }

    /// Executes an HTTP GraphQL request (single or batched) against `schema`.
    ///
    /// Every operation's `Data` receives the optional middleware's values, then the
    /// request id, the subject (`Option<S>`) and the accept-language. Each operation runs
    /// inside a `gql` tracing span carrying its operation name, when one is given. Batched
    /// operations are driven through a `FuturesOrdered`, so responses keep request order.
    /// Every error in the response gets a `requestId` extension.
    ///
    /// NOTE(review): `subject` is extracted as `Option<Auth<S>>`. If `Auth`'s rejection is
    /// swallowed by the `Option` extractor, invalid credentials are treated as anonymous
    /// rather than rejected. Confirm this is intended.
    pub async fn graphql_batch_handler<S: Subject, M: RequestDataMiddleware<S>, Query, Mutation, Subscription>(
        Extension(schema): Extension<Schema<Query, Mutation, Subscription>>,
        Extension(request_id): Extension<RequestId>,
        middleware: Option<Extension<M>>,
        subject: Option<Auth<S>>,
        accept_language: AcceptLanguage,
        req: GraphQLBatchRequest,
    ) -> GraphQLResponse
    where
        Query: ObjectType + 'static,
        Mutation: ObjectType + 'static,
        Subscription: SubscriptionType + 'static,
    {
        let mut req = req.into_inner();
        let subject = subject.map(|s| s.0);

        // Only pay for collecting/joining operation names when TRACE is actually enabled.
        if tracing::event_enabled!(tracing::Level::TRACE) {
            let op_names = req
                .iter()
                .filter_map(|r| r.operation_name.as_deref())
                .collect::<Vec<_>>()
                .join(", ");
            tracing::trace!("request operations: {op_names}");
        }

        if let Some(Extension(middleware)) = middleware {
            for r in req.iter_mut() {
                middleware.customize_request_data(&subject, &accept_language, &mut r.data);
            }
        }
        req = req.data(request_id).data(subject).data(accept_language);

        let mut res = match req {
            BatchRequest::Single(request) => {
                let span = operation_span(&request);
                BatchResponse::Single(schema.execute(request).instrument(span).await)
            }
            BatchRequest::Batch(requests) => BatchResponse::Batch(
                FuturesOrdered::from_iter(requests.into_iter().map(|request| {
                    let span = operation_span(&request);
                    schema.execute(request).instrument(span)
                }))
                .collect()
                .await,
            ),
        };

        match &mut res {
            BatchResponse::Single(res) => include_request_id(res, &request_id),
            BatchResponse::Batch(responses) => {
                for res in responses {
                    include_request_id(res, &request_id);
                }
            }
        }
        res.into()
    }

    /// Upgrades the request to a GraphQL WebSocket (subscription) connection.
    ///
    /// Before upgrading it:
    /// - rejects the request (403) when an `Origin` header is present but is not in the
    ///   configured CORS allow-list; requests without `Origin` are let through;
    /// - reads the auth cookie value, if a `Cookie` header carries one.
    ///
    /// Authentication happens in the `connection_init` callback. The token is the string
    /// value of the first payload key that matches the auth header name
    /// case-insensitively. It is passed, with the cookie value, to the authentication
    /// service. An authentication error is returned from the callback; on success the
    /// connection `Data` gets the middleware's values, the request id, `Some(subject)` and
    /// the accept-language.
    pub async fn graphql_subscription_handler<
        Query,
        Mutation,
        Subscription,
        S: Subject,
        M: RequestDataMiddleware<S>,
        St: AuthState<S> + CorsState,
        B,
    >(
        State(state): State<St>,
        Extension(schema): Extension<Schema<Query, Mutation, Subscription>>,
        Extension(request_id): Extension<RequestId>,
        middleware: Option<Extension<M>>,
        accept_language: AcceptLanguage,
        req: http::Request<B>,
    ) -> axum::response::Response
    where
        Query: ObjectType + 'static,
        Mutation: ObjectType + 'static,
        Subscription: SubscriptionType + 'static,
    {
        let (mut parts, _body) = req.into_parts();

        // CORS is not applied to WebSocket handshakes, so the allow-list is checked by hand.
        let origin_header = match parts
            .headers
            .get(http::header::ORIGIN)
            .map(|v| {
                v.to_str()
                    .map_to_err_with(GenericErrorCode::BadRequest, "Couldn't parse request origin header")
            })
            .transpose()
        {
            Ok(o) => o,
            Err(err) => return ApiError::from_err(err).into_response(),
        };
        if let Some(origin_header) = origin_header {
            if !state.cors().allowed_origins().iter().any(|o| o == origin_header) {
                return ApiError::from_err(err!(GenericErrorCode::Forbidden, "The origin is not allowed"))
                    .into_response();
            }
        }

        let authn = state.authn().clone();
        // Lowercased once here; payload keys are compared case-insensitively below.
        let auth_header_name = authn.header_name().to_lowercase();
        let auth_cookie_name = authn.cookie_name().to_owned();

        let cookies = match parts
            .headers
            .get(http::header::COOKIE)
            .map(|v| {
                v.to_str()
                    .map_to_err_with(AuthErrorCode::AuthMalformedCookies, "Couldn't parse request cookies")
            })
            .transpose()
        {
            Ok(c) => c,
            Err(err) => return ApiError::from_err(err).into_response(),
        };
        // Owned copy: the header borrow must end before `parts` is borrowed mutably below.
        let auth_cookie_value = cookies
            .and_then(|cookies| find_cookie(cookies, &auth_cookie_name))
            .map(str::to_owned);

        let protocol = match GraphQLProtocol::from_request_parts(&mut parts, &()).await {
            Ok(protocol) => protocol,
            Err(err) => return err.into_response(),
        };
        let upgrade = match WebSocketUpgrade::from_request_parts(&mut parts, &()).await {
            Ok(upgrade) => upgrade,
            Err(err) => return err.into_response(),
        };

        upgrade
            .protocols(ALL_WEBSOCKET_PROTOCOLS)
            .on_upgrade(move |stream| {
                // The upgrade callback runs once, so `schema` can be moved in directly.
                GraphQLWebSocket::new(stream, schema, protocol)
                    .on_connection_init(move |payload| async move {
                        let mut data = Data::default();
                        let auth_token = payload.as_object().and_then(|payload| {
                            payload
                                .iter()
                                .find(|(k, _)| k.eq_ignore_ascii_case(&auth_header_name))
                                .and_then(|(_, v)| v.as_str())
                        });
                        let subject = authn.authenticate(auth_token, auth_cookie_value.as_deref()).await?;
                        tracing::trace!("Authenticated as {subject}");
                        let subject = Some(subject);

                        if let Some(Extension(middleware)) = middleware {
                            middleware.customize_request_data(&subject, &accept_language, &mut data);
                        }

                        data.insert(request_id);
                        data.insert(subject);
                        data.insert(accept_language);

                        Ok(data)
                    })
                    .serve()
            })
            .into_response()
    }

    /// Builds the `gql` tracing span for one operation, recording `op` when it is named.
    fn operation_span(request: &async_graphql::Request) -> tracing::Span {
        match &request.operation_name {
            Some(op) => tracing::info_span!("gql", %op),
            None => tracing::info_span!("gql"),
        }
    }

    /// Returns the value of the cookie called `name` in a `Cookie` header value.
    ///
    /// Pairs are split on `;` and trimmed, so both the canonical `a=1; b=2` form and the
    /// space-less `a=1;b=2` form are accepted. Names are compared exactly, so `name` does
    /// not match a cookie whose name merely starts with it.
    fn find_cookie<'a>(cookie_header: &'a str, name: &str) -> Option<&'a str> {
        cookie_header.split(';').find_map(|pair| {
            let (key, value) = pair.trim().split_once('=')?;
            (key == name).then_some(value)
        })
    }

    /// Adds a `requestId` extension (the request id's string form) to every error in `res`.
    fn include_request_id(res: &mut Response, id: &RequestId) {
        for e in &mut res.errors {
            e.extensions
                .get_or_insert_with(Default::default)
                .set("requestId", id.to_string());
        }
    }
}
311#[cfg(feature = "auth")]
312pub use auth::*;