1use std::{any::type_name, sync::Arc};
2
3use async_graphql::{
4 http::GraphiQLSource, EmptyMutation, EmptySubscription, ObjectType, Schema, SubscriptionType,
5};
6use async_graphql_axum::{GraphQLRequest, GraphQLResponse};
7use axum::{
8 extract::{DefaultBodyLimit, Extension, FromRequest, Request, State},
9 http::{header, StatusCode},
10 response::{Html, IntoResponse, Response},
11 routing::{get, post},
12 RequestExt, Router,
13};
14use nestforge_core::{AuthIdentity, Container, RequestId};
15
16pub use async_graphql;
20
21pub type GraphQlSchema<Query, Mutation = EmptyMutation, Subscription = EmptySubscription> =
38 Schema<Query, Mutation, Subscription>;
39
40pub fn graphql_container<'ctx>(
55 ctx: &'ctx async_graphql::Context<'ctx>,
56) -> async_graphql::Result<&'ctx Container> {
57 ctx.data::<Container>()
58}
59
60pub fn graphql_request_id<'ctx>(
67 ctx: &'ctx async_graphql::Context<'ctx>,
68) -> Option<&'ctx RequestId> {
69 ctx.data_opt::<RequestId>()
70}
71
72pub fn graphql_auth_identity<'ctx>(
79 ctx: &'ctx async_graphql::Context<'ctx>,
80) -> Option<&'ctx AuthIdentity> {
81 ctx.data_opt::<Arc<AuthIdentity>>().map(AsRef::as_ref)
82}
83
84pub fn resolve_graphql<T>(ctx: &async_graphql::Context<'_>) -> async_graphql::Result<Arc<T>>
102where
103 T: Send + Sync + 'static,
104{
105 let container = graphql_container(ctx)?;
106 container.resolve::<T>().map_err(|_| {
107 async_graphql::Error::new(format!(
108 "Failed to resolve dependency `{}` from GraphQL context",
109 type_name::<T>()
110 ))
111 })
112}
113
/// Route and request-size settings for the GraphQL router.
///
/// The fields are public. Values assigned directly are used as-is; only the
/// builder methods normalize paths.
#[derive(Clone, Debug)]
pub struct GraphQlConfig {
    /// Route path that accepts GraphQL `POST` requests.
    pub endpoint: String,
    /// Route path serving the GraphiQL page; `None` disables the page.
    pub graphiql_endpoint: Option<String>,
    /// Maximum accepted request body size, in bytes.
    pub max_request_bytes: usize,
}

impl Default for GraphQlConfig {
    /// Default settings:
    /// - GraphQL at `/graphql`
    /// - GraphiQL at `/graphiql`
    /// - a 1 MiB body limit
    fn default() -> Self {
        const ONE_MIB: usize = 1024 * 1024;

        GraphQlConfig {
            endpoint: String::from("/graphql"),
            graphiql_endpoint: Some(String::from("/graphiql")),
            max_request_bytes: ONE_MIB,
        }
    }
}
140
141impl GraphQlConfig {
142 pub fn new(endpoint: impl Into<String>) -> Self {
146 Self {
147 endpoint: normalize_path(endpoint.into()),
148 ..Self::default()
149 }
150 }
151
152 pub fn with_graphiql(mut self, path: impl Into<String>) -> Self {
156 self.graphiql_endpoint = Some(normalize_path(path.into()));
157 self
158 }
159
160 pub fn without_graphiql(mut self) -> Self {
164 self.graphiql_endpoint = None;
165 self
166 }
167
168 pub fn with_max_request_bytes(mut self, bytes: usize) -> Self {
172 self.max_request_bytes = bytes;
173 self
174 }
175}
176
177pub fn graphql_router<Query, Mutation, Subscription>(
178 schema: GraphQlSchema<Query, Mutation, Subscription>,
179) -> Router<Container>
180where
181 Query: ObjectType + Send + Sync + 'static,
182 Mutation: ObjectType + Send + Sync + 'static,
183 Subscription: SubscriptionType + Send + Sync + 'static,
184{
185 graphql_router_with_config(schema, GraphQlConfig::default())
186}
187
188pub fn graphql_router_with_config<Query, Mutation, Subscription>(
189 schema: GraphQlSchema<Query, Mutation, Subscription>,
190 config: GraphQlConfig,
191) -> Router<Container>
192where
193 Query: ObjectType + Send + Sync + 'static,
194 Mutation: ObjectType + Send + Sync + 'static,
195 Subscription: SubscriptionType + Send + Sync + 'static,
196{
197 let max_request_bytes = config.max_request_bytes;
198 let mut router = Router::new()
199 .route(
200 &config.endpoint,
201 post(
202 move |container, scoped_container, request_id, auth_identity, schema, request| {
203 graphql_handler::<Query, Mutation, Subscription>(
204 max_request_bytes,
205 container,
206 scoped_container,
207 request_id,
208 auth_identity,
209 schema,
210 request,
211 )
212 },
213 ),
214 )
215 .layer(DefaultBodyLimit::max(config.max_request_bytes))
216 .layer(Extension(schema));
217
218 if let Some(graphiql_endpoint) = &config.graphiql_endpoint {
219 let endpoint = config.endpoint.clone();
220 let graphiql_html = GraphiQLSource::build().endpoint(&endpoint).finish();
221 router = router.route(
222 graphiql_endpoint,
223 get(move || {
224 let html = graphiql_html.clone();
225 async move { Html(html) }
226 }),
227 );
228 }
229
230 router
231}
232
233async fn graphql_handler<Query, Mutation, Subscription>(
234 max_request_bytes: usize,
235 State(container): State<Container>,
236 scoped_container: Option<Extension<Container>>,
237 request_id: Option<Extension<RequestId>>,
238 auth_identity: Option<Extension<Arc<AuthIdentity>>>,
239 Extension(schema): Extension<GraphQlSchema<Query, Mutation, Subscription>>,
240 request: Request,
241) -> Response
242where
243 Query: ObjectType + Send + Sync + 'static,
244 Mutation: ObjectType + Send + Sync + 'static,
245 Subscription: SubscriptionType + Send + Sync + 'static,
246{
247 if request
248 .headers()
249 .get(header::CONTENT_LENGTH)
250 .and_then(|value| value.to_str().ok())
251 .and_then(|value| value.parse::<usize>().ok())
252 .is_some_and(|length| length > max_request_bytes)
253 {
254 return StatusCode::PAYLOAD_TOO_LARGE.into_response();
255 }
256
257 let request =
258 match GraphQLRequest::<async_graphql_axum::rejection::GraphQLRejection>::from_request(
259 request.with_limited_body(),
260 &(),
261 )
262 .await
263 {
264 Ok(request) => request,
265 Err(rejection) => return rejection.into_response(),
266 };
267
268 let container = scoped_container.map(|value| value.0).unwrap_or(container);
269 let mut request = request.into_inner().data(container);
270 if let Some(Extension(request_id)) = request_id {
271 request = request.data(request_id);
272 }
273 if let Some(Extension(auth_identity)) = auth_identity {
274 request = request.data(auth_identity);
275 }
276
277 GraphQLResponse::from(schema.execute(request).await).into_response()
278}
279
/// Normalizes a user-supplied route path.
///
/// Rules:
/// - Surrounding whitespace is trimmed.
/// - An empty path or a bare `/` falls back to `/graphql`.
/// - Any other path is returned with a leading `/` added if it lacks one.
fn normalize_path(path: String) -> String {
    match path.trim() {
        "" | "/" => "/graphql".to_string(),
        rooted if rooted.starts_with('/') => rooted.to_string(),
        relative => format!("/{relative}"),
    }
}