graphql_starter/graphql/
handler.rs

1use async_graphql::http::{AltairSource, Credentials, GraphiQLSource};
2use axum::response::{Html, IntoResponse};
3
4/// Handler that renders a GraphiQL playground on the given path to explore the API.
5pub async fn graphiql_playground_handler(path: String, title: &str) -> impl IntoResponse {
6    Html(
7        GraphiQLSource::build()
8            .endpoint(&path)
9            .subscription_endpoint(&format!("{path}/ws"))
10            .title(title)
11            .credentials(Credentials::SameOrigin)
12            .header("x-requested-with", "graphiql")
13            .finish(),
14    )
15}
16
17/// Handler that renders an Altair GraphQL playground on the given path to explore the API.
18pub async fn altair_playground_handler(path: String, title: &str) -> impl IntoResponse {
19    Html(
20        AltairSource::build()
21            .title(title)
22            .options(serde_json::json!({
23                "endpointURL": path,
24                "subscriptionsEndpoint": format!("{path}/ws"),
25                "subscriptionsProtocol": "wss",
26                "disableAccount": true,
27                "initialHeaders": {
28                    "x-requested-with": "altair"
29                },
30                "initialSettings": {
31                    "addQueryDepthLimit": 1,
32                    "request.withCredentials": true,
33                    "plugin.list": ["altair-graphql-plugin-graphql-explorer"],
34                    "schema.reloadOnStart": true,
35                }
36            }))
37            .finish(),
38    )
39}
40
#[cfg(feature = "auth")]
mod auth {
    use async_graphql::{
        http::ALL_WEBSOCKET_PROTOCOLS, BatchRequest, BatchResponse, Data, ObjectType, Request, Response, Schema,
        SubscriptionType,
    };
    use async_graphql_axum::{GraphQLProtocol, GraphQLResponse, GraphQLWebSocket};
    use auto_impl::auto_impl;
    use axum::{
        extract::{FromRequestParts, State, WebSocketUpgrade},
        response::IntoResponse,
    };
    use futures_util::{stream::FuturesOrdered, StreamExt};
    use tracing::Instrument;

    use crate::{
        auth::{Auth, AuthErrorCode, AuthState, AuthenticationService, Subject},
        axum::{
            extract::{AcceptLanguage, Extension},
            CorsService, CorsState,
        },
        error::{err, ApiError, GenericErrorCode, MapToErr},
        graphql::GraphQLBatchRequest,
        request_id::RequestId,
    };

    /// Middleware to customize the data attached to each GraphQL request.
    ///
    /// [graphql_batch_handler] calls it once per request of a batch. [graphql_subscription_handler] calls
    /// it once per WebSocket connection, after the subject has been authenticated.
    #[auto_impl(Box, Arc)]
    pub trait RequestDataMiddleware<S: Subject>: Send + Sync + Sized + Clone + 'static {
        /// Customize the given request data, inserting or modifying the content.
        ///
        /// `subject` is the authenticated subject (if any) and `accept_language` the client's accepted
        /// languages. Both handlers insert [RequestId], `Option<Subject>` and [AcceptLanguage] into the data
        /// *after* calling this method.
        fn customize_request_data(&self, subject: &Option<S>, accept_language: &AcceptLanguage, data: &mut Data);
    }

    /// No-op middleware: leaves the request data untouched.
    impl<S: Subject> RequestDataMiddleware<S> for () {
        fn customize_request_data(&self, _subject: &Option<S>, _accept_language: &AcceptLanguage, _data: &mut Data) {}
    }

    /// Handler for [batch requests](https://www.apollographql.com/blog/apollo-client/performance/batching-client-graphql-queries/).
    ///
    /// [RequestId], [`Option<Subject>`](Subject) and [AcceptLanguage] will be added to the GraphQL context before
    /// executing the request on the schema.
    ///
    /// This handler expects two extensions:
    /// - `Schema<Query, Mutation, Subscription>` with the GraphQL [Schema]
    /// - `RequestId` with the request id (see [RequestIdLayer](crate::request_id::RequestIdLayer))
    ///
    /// And optionally:
    /// - `RequestDataMiddleware<Subject>` with the [RequestDataMiddleware]
    ///
    /// Batched requests are executed concurrently and their responses are returned in request order.
    /// Every error in the responses is tagged with a `requestId` extension.
    pub async fn graphql_batch_handler<S: Subject, M: RequestDataMiddleware<S>, Query, Mutation, Subscription>(
        Extension(schema): Extension<Schema<Query, Mutation, Subscription>>,
        Extension(request_id): Extension<RequestId>,
        middleware: Option<Extension<M>>,
        subject: Option<Auth<S>>,
        accept_language: AcceptLanguage,
        req: GraphQLBatchRequest,
    ) -> GraphQLResponse
    where
        Query: ObjectType + 'static,
        Mutation: ObjectType + 'static,
        Subscription: SubscriptionType + 'static,
    {
        let mut req = req.into_inner();
        let subject = subject.map(|s| s.0);
        // Log request operations (only computed when TRACE is enabled, as it allocates)
        if tracing::event_enabled!(tracing::Level::TRACE) {
            let op_names = req
                .iter()
                .filter_map(|r| r.operation_name.as_deref())
                .collect::<Vec<_>>()
                .join(", ");
            tracing::trace!("request operations: {op_names}")
        }
        // Call the request data middleware to include additional data
        if let Some(Extension(middleware)) = middleware {
            match &mut req {
                BatchRequest::Single(r) => {
                    middleware.customize_request_data(&subject, &accept_language, &mut r.data);
                }
                BatchRequest::Batch(b) => {
                    for r in b {
                        middleware.customize_request_data(&subject, &accept_language, &mut r.data);
                    }
                }
            }
        }
        // Include the request_id, subject and accept language into the GraphQL context
        req = req.data(request_id).data(subject).data(accept_language);
        // Execute the requests, instrumenting them with the operation name (if present)
        let mut res = match req {
            BatchRequest::Single(request) => {
                let span = operation_span(&request);
                BatchResponse::Single(schema.execute(request).instrument(span).await)
            }
            BatchRequest::Batch(requests) => BatchResponse::Batch(
                FuturesOrdered::from_iter(requests.into_iter().map(|request| {
                    let span = operation_span(&request);
                    schema.execute(request).instrument(span)
                }))
                .collect()
                .await,
            ),
        };
        // Include the request id if any error is found
        match &mut res {
            BatchResponse::Single(res) => include_request_id(res, &request_id),
            BatchResponse::Batch(responses) => {
                for res in responses {
                    include_request_id(res, &request_id);
                }
            }
        }
        res.into()
    }

    /// Handler for GraphQL [subscriptions](https://www.apollographql.com/docs/react/data/subscriptions/).
    ///
    /// **Note**: For HTTP/1.1 requests, this handler requires the request method to be `GET`; in later versions,
    /// `CONNECT` is used instead. To support both, it should be used with [`any`](axum::routing::any).
    ///
    /// [RequestId], [`Option<Subject>`](Subject) and [AcceptLanguage] will be added to the GraphQL context before
    /// executing the request on the schema.
    ///
    /// This handler expects two extensions:
    /// - `Schema<Query, Mutation, Subscription>` with the GraphQL [Schema]
    /// - `RequestId` with the request id (see [RequestIdLayer](crate::request_id::RequestIdLayer))
    ///
    /// And optionally:
    /// - `RequestDataMiddleware<Subject>` with the [RequestDataMiddleware]
    ///
    /// If the request carries an `Origin` header, it must be one of the CORS allowed origins; otherwise the
    /// request is rejected as forbidden.
    ///
    /// Authentication uses the same criteria as the [Auth](crate::auth::Auth) extractor. The cookie is
    /// retrieved from the upgrade request (all `Cookie` header fields are inspected, since HTTP/2 clients may
    /// split cookies across several fields). The token is retrieved from the
    /// [`GQL_CONNECTION_INIT` message](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init)
    /// payload, matching the auth header name case-insensitively.
    pub async fn graphql_subscription_handler<
        Query,
        Mutation,
        Subscription,
        S: Subject,
        M: RequestDataMiddleware<S>,
        St: AuthState<S> + CorsState,
        B,
    >(
        State(state): State<St>,
        Extension(schema): Extension<Schema<Query, Mutation, Subscription>>,
        Extension(request_id): Extension<RequestId>,
        middleware: Option<Extension<M>>,
        accept_language: AcceptLanguage,
        req: http::Request<B>,
    ) -> axum::response::Response
    where
        Query: ObjectType + 'static,
        Mutation: ObjectType + 'static,
        Subscription: SubscriptionType + 'static,
    {
        let (mut parts, _body) = req.into_parts();

        // Retrieve `Origin` header set by browsers
        let origin_header = match parts
            .headers
            .get(http::header::ORIGIN)
            .map(|v| {
                v.to_str()
                    .map_to_err_with(GenericErrorCode::BadRequest, "Couldn't parse request origin header")
            })
            .transpose()
        {
            Ok(o) => o,
            Err(err) => return ApiError::from_err(err).into_response(),
        };
        // If it's present, check it's allowed
        if let Some(origin_header) = origin_header {
            if !state.cors().allowed_origins().iter().any(|o| o == origin_header) {
                return ApiError::from_err(err!(GenericErrorCode::Forbidden, "The origin is not allowed"))
                    .into_response();
            }
        }

        // Retrieve token & cookie names
        let authn = state.authn().clone();
        let auth_header_name = authn.header_name().to_lowercase();
        let auth_cookie_name = authn.cookie_name().to_owned();

        // Retrieve the auth cookie value. HTTP/2 clients may send several `Cookie` fields (RFC 9113 §8.2.3),
        // so every field is inspected instead of only the first one.
        let mut auth_cookie_value = None;
        for header in parts.headers.get_all(http::header::COOKIE) {
            let cookies = match header
                .to_str()
                .map_to_err_with(AuthErrorCode::AuthMalformedCookies, "Couldn't parse request cookies")
            {
                Ok(c) => c,
                Err(err) => return ApiError::from_err(err).into_response(),
            };
            if let Some(value) = find_cookie(cookies, &auth_cookie_name) {
                auth_cookie_value = Some(value.to_owned());
                break;
            }
        }

        // Based on https://github.com/async-graphql/async-graphql/blob/master/integrations/axum/src/subscription.rs
        // Extract GraphQL WebSocket protocol
        let protocol = match GraphQLProtocol::from_request_parts(&mut parts, &()).await {
            Ok(protocol) => protocol,
            Err(err) => return err.into_response(),
        };
        // Validate the HTTP request as a WebSocket upgrade
        let upgrade = match WebSocketUpgrade::from_request_parts(&mut parts, &()).await {
            Ok(upgrade) => upgrade,
            Err(err) => return err.into_response(),
        };

        // Finalize upgrading the connection to a WebSocket
        upgrade
            .protocols(ALL_WEBSOCKET_PROTOCOLS)
            .on_upgrade(move |stream| {
                // Forward the stream to the GraphQL websocket; the callback runs once, so the schema is moved in
                GraphQLWebSocket::new(stream, schema, protocol)
                    .on_connection_init(move |payload| {
                        // Authenticate the subject on connection init
                        async move {
                            let mut data = Data::default();
                            // Retrieve auth token from the payload (header names are ASCII, case-insensitive)
                            let auth_token = payload.as_object().and_then(|payload| {
                                payload
                                    .iter()
                                    .find(|(k, _)| k.eq_ignore_ascii_case(&auth_header_name))
                                    .and_then(|(_, v)| v.as_str())
                            });
                            // Authenticate the subject
                            let subject = authn.authenticate(auth_token, auth_cookie_value.as_deref()).await?;
                            tracing::trace!("Authenticated as {subject}");
                            let subject = Some(subject);

                            // Call the request data middleware to include additional data
                            if let Some(Extension(middleware)) = middleware {
                                middleware.customize_request_data(&subject, &accept_language, &mut data);
                            }

                            // Include the request_id, subject and accept language into the GraphQL context
                            data.insert(request_id);
                            data.insert(subject);
                            data.insert(accept_language);

                            Ok(data)
                        }
                    })
                    .serve()
            })
            .into_response()
    }

    /// Builds the tracing span for a single GraphQL request, tagged with its operation name when present.
    fn operation_span(request: &Request) -> tracing::Span {
        match &request.operation_name {
            Some(op) => tracing::info_span!("gql", %op),
            None => tracing::info_span!("gql"),
        }
    }

    /// Returns the value of the cookie called `name` within a `Cookie` header value, if present.
    ///
    /// Pairs are separated by `;` and surrounding whitespace is trimmed, so both `a=1; b=2` and `a=1;b=2`
    /// are accepted. Only the first `=` separates name and value, so values may themselves contain `=`.
    fn find_cookie<'a>(cookies: &'a str, name: &str) -> Option<&'a str> {
        cookies.split(';').find_map(|pair| {
            let (key, value) = pair.trim().split_once('=')?;
            (key == name).then_some(value)
        })
    }

    /// Includes the request id extension on the response errors (if any)
    fn include_request_id(res: &mut Response, id: &RequestId) {
        for e in &mut res.errors {
            e.extensions
                .get_or_insert_with(Default::default)
                .set("requestId", id.to_string())
        }
    }
}
311#[cfg(feature = "auth")]
312pub use auth::*;