1use std::{any::type_name, sync::Arc};
2
3use async_graphql::{
4 http::GraphiQLSource, EmptyMutation, EmptySubscription, ObjectType, Schema, SubscriptionType,
5};
6use async_graphql_axum::{GraphQLRequest, GraphQLResponse};
7use axum::{
8 extract::{DefaultBodyLimit, Extension, FromRequest, Request, State},
9 http::{header, StatusCode},
10 response::{Html, IntoResponse, Response},
11 routing::{get, post},
12 RequestExt, Router,
13};
14use nestforge_core::{AuthIdentity, Container, RequestId};
15
16pub use async_graphql;
17
18pub type GraphQlSchema<Query, Mutation = EmptyMutation, Subscription = EmptySubscription> =
19 Schema<Query, Mutation, Subscription>;
20
21pub fn graphql_container<'ctx>(
22 ctx: &'ctx async_graphql::Context<'ctx>,
23) -> async_graphql::Result<&'ctx Container> {
24 ctx.data::<Container>()
25}
26
27pub fn graphql_request_id<'ctx>(
28 ctx: &'ctx async_graphql::Context<'ctx>,
29) -> Option<&'ctx RequestId> {
30 ctx.data_opt::<RequestId>()
31}
32
33pub fn graphql_auth_identity<'ctx>(
34 ctx: &'ctx async_graphql::Context<'ctx>,
35) -> Option<&'ctx AuthIdentity> {
36 ctx.data_opt::<Arc<AuthIdentity>>().map(AsRef::as_ref)
37}
38
39pub fn resolve_graphql<T>(ctx: &async_graphql::Context<'_>) -> async_graphql::Result<Arc<T>>
40where
41 T: Send + Sync + 'static,
42{
43 let container = graphql_container(ctx)?;
44 container.resolve::<T>().map_err(|_| {
45 async_graphql::Error::new(format!(
46 "Failed to resolve dependency `{}` from GraphQL context",
47 type_name::<T>()
48 ))
49 })
50}
51
/// Route paths and request-size limit used when building the GraphQL router.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GraphQlConfig {
    /// Path of the route that executes GraphQL requests.
    pub endpoint: String,
    /// Path of the GraphiQL page; `None` means the page is not registered.
    pub graphiql_endpoint: Option<String>,
    /// Largest accepted request body, in bytes.
    pub max_request_bytes: usize,
}
58
59impl Default for GraphQlConfig {
60 fn default() -> Self {
61 Self {
62 endpoint: "/graphql".to_string(),
63 graphiql_endpoint: Some("/graphiql".to_string()),
64 max_request_bytes: 1024 * 1024,
65 }
66 }
67}
68
69impl GraphQlConfig {
70 pub fn new(endpoint: impl Into<String>) -> Self {
71 Self {
72 endpoint: normalize_path(endpoint.into()),
73 ..Self::default()
74 }
75 }
76
77 pub fn with_graphiql(mut self, path: impl Into<String>) -> Self {
78 self.graphiql_endpoint = Some(normalize_path(path.into()));
79 self
80 }
81
82 pub fn without_graphiql(mut self) -> Self {
83 self.graphiql_endpoint = None;
84 self
85 }
86
87 pub fn with_max_request_bytes(mut self, bytes: usize) -> Self {
88 self.max_request_bytes = bytes;
89 self
90 }
91}
92
93pub fn graphql_router<Query, Mutation, Subscription>(
94 schema: GraphQlSchema<Query, Mutation, Subscription>,
95) -> Router<Container>
96where
97 Query: ObjectType + Send + Sync + 'static,
98 Mutation: ObjectType + Send + Sync + 'static,
99 Subscription: SubscriptionType + Send + Sync + 'static,
100{
101 graphql_router_with_config(schema, GraphQlConfig::default())
102}
103
104pub fn graphql_router_with_config<Query, Mutation, Subscription>(
105 schema: GraphQlSchema<Query, Mutation, Subscription>,
106 config: GraphQlConfig,
107) -> Router<Container>
108where
109 Query: ObjectType + Send + Sync + 'static,
110 Mutation: ObjectType + Send + Sync + 'static,
111 Subscription: SubscriptionType + Send + Sync + 'static,
112{
113 let max_request_bytes = config.max_request_bytes;
114 let mut router = Router::new()
115 .route(
116 &config.endpoint,
117 post(
118 move |container, scoped_container, request_id, auth_identity, schema, request| {
119 graphql_handler::<Query, Mutation, Subscription>(
120 max_request_bytes,
121 container,
122 scoped_container,
123 request_id,
124 auth_identity,
125 schema,
126 request,
127 )
128 },
129 ),
130 )
131 .layer(DefaultBodyLimit::max(config.max_request_bytes))
132 .layer(Extension(schema));
133
134 if let Some(graphiql_endpoint) = &config.graphiql_endpoint {
135 let endpoint = config.endpoint.clone();
136 let graphiql_html = GraphiQLSource::build().endpoint(&endpoint).finish();
137 router = router.route(
138 graphiql_endpoint,
139 get(move || {
140 let html = graphiql_html.clone();
141 async move { Html(html) }
142 }),
143 );
144 }
145
146 router
147}
148
149async fn graphql_handler<Query, Mutation, Subscription>(
150 max_request_bytes: usize,
151 State(container): State<Container>,
152 scoped_container: Option<Extension<Container>>,
153 request_id: Option<Extension<RequestId>>,
154 auth_identity: Option<Extension<Arc<AuthIdentity>>>,
155 Extension(schema): Extension<GraphQlSchema<Query, Mutation, Subscription>>,
156 request: Request,
157) -> Response
158where
159 Query: ObjectType + Send + Sync + 'static,
160 Mutation: ObjectType + Send + Sync + 'static,
161 Subscription: SubscriptionType + Send + Sync + 'static,
162{
163 if request
164 .headers()
165 .get(header::CONTENT_LENGTH)
166 .and_then(|value| value.to_str().ok())
167 .and_then(|value| value.parse::<usize>().ok())
168 .is_some_and(|length| length > max_request_bytes)
169 {
170 return StatusCode::PAYLOAD_TOO_LARGE.into_response();
171 }
172
173 let request =
174 match GraphQLRequest::<async_graphql_axum::rejection::GraphQLRejection>::from_request(
175 request.with_limited_body(),
176 &(),
177 )
178 .await
179 {
180 Ok(request) => request,
181 Err(rejection) => return rejection.into_response(),
182 };
183
184 let container = scoped_container.map(|value| value.0).unwrap_or(container);
185 let mut request = request.into_inner().data(container);
186 if let Some(Extension(request_id)) = request_id {
187 request = request.data(request_id);
188 }
189 if let Some(Extension(auth_identity)) = auth_identity {
190 request = request.data(auth_identity);
191 }
192
193 GraphQLResponse::from(schema.execute(request).await).into_response()
194}
195
/// Normalizes a user-supplied route path.
///
/// Surrounding whitespace is trimmed. An empty path or a bare `/` falls back
/// to `/graphql`; any other path is returned with a leading `/`, which is
/// added when missing.
fn normalize_path(path: String) -> String {
    match path.trim() {
        "" | "/" => "/graphql".to_string(),
        trimmed if trimmed.starts_with('/') => trimmed.to_string(),
        trimmed => format!("/{trimmed}"),
    }
}