Skip to main content

nestforge_graphql/
lib.rs

1use std::{any::type_name, sync::Arc};
2
3use async_graphql::{
4    http::GraphiQLSource, EmptyMutation, EmptySubscription, ObjectType, Schema, SubscriptionType,
5};
6use async_graphql_axum::{GraphQLRequest, GraphQLResponse};
7use axum::{
8    extract::{DefaultBodyLimit, Extension, FromRequest, Request, State},
9    http::{header, StatusCode},
10    response::{Html, IntoResponse, Response},
11    routing::{get, post},
12    RequestExt, Router,
13};
14use nestforge_core::{AuthIdentity, Container, RequestId};
15
16pub use async_graphql;
17
/// Shorthand for an [`async_graphql::Schema`] served by this crate's router.
///
/// The mutation and subscription roots default to [`EmptyMutation`] and
/// [`EmptySubscription`], so a query-only schema can be written as `GraphQlSchema<Query>`.
pub type GraphQlSchema<Query, Mutation = EmptyMutation, Subscription = EmptySubscription> =
    Schema<Query, Mutation, Subscription>;
20
21pub fn graphql_container<'ctx>(
22    ctx: &'ctx async_graphql::Context<'ctx>,
23) -> async_graphql::Result<&'ctx Container> {
24    ctx.data::<Container>()
25}
26
27pub fn graphql_request_id<'ctx>(
28    ctx: &'ctx async_graphql::Context<'ctx>,
29) -> Option<&'ctx RequestId> {
30    ctx.data_opt::<RequestId>()
31}
32
33pub fn graphql_auth_identity<'ctx>(
34    ctx: &'ctx async_graphql::Context<'ctx>,
35) -> Option<&'ctx AuthIdentity> {
36    ctx.data_opt::<Arc<AuthIdentity>>().map(AsRef::as_ref)
37}
38
39pub fn resolve_graphql<T>(ctx: &async_graphql::Context<'_>) -> async_graphql::Result<Arc<T>>
40where
41    T: Send + Sync + 'static,
42{
43    let container = graphql_container(ctx)?;
44    container.resolve::<T>().map_err(|_| {
45        async_graphql::Error::new(format!(
46            "Failed to resolve dependency `{}` from GraphQL context",
47            type_name::<T>()
48        ))
49    })
50}
51
/// Settings consumed by [`graphql_router_with_config`].
#[derive(Debug, Clone)]
pub struct GraphQlConfig {
    /// Route path the GraphQL handler is registered on (POST).
    pub endpoint: String,
    /// Route path for the GraphiQL page (GET); `None` means no GraphiQL route is added.
    pub graphiql_endpoint: Option<String>,
    /// Largest request body accepted, in bytes.
    pub max_request_bytes: usize,
}

impl Default for GraphQlConfig {
    /// GraphQL at `/graphql`, GraphiQL at `/graphiql`, and a 1 MiB (1_048_576-byte) body limit.
    fn default() -> Self {
        let one_mebibyte = 1 << 20;
        Self {
            endpoint: String::from("/graphql"),
            graphiql_endpoint: Some(String::from("/graphiql")),
            max_request_bytes: one_mebibyte,
        }
    }
}
68
69impl GraphQlConfig {
70    pub fn new(endpoint: impl Into<String>) -> Self {
71        Self {
72            endpoint: normalize_path(endpoint.into()),
73            ..Self::default()
74        }
75    }
76
77    pub fn with_graphiql(mut self, path: impl Into<String>) -> Self {
78        self.graphiql_endpoint = Some(normalize_path(path.into()));
79        self
80    }
81
82    pub fn without_graphiql(mut self) -> Self {
83        self.graphiql_endpoint = None;
84        self
85    }
86
87    pub fn with_max_request_bytes(mut self, bytes: usize) -> Self {
88        self.max_request_bytes = bytes;
89        self
90    }
91}
92
93pub fn graphql_router<Query, Mutation, Subscription>(
94    schema: GraphQlSchema<Query, Mutation, Subscription>,
95) -> Router<Container>
96where
97    Query: ObjectType + Send + Sync + 'static,
98    Mutation: ObjectType + Send + Sync + 'static,
99    Subscription: SubscriptionType + Send + Sync + 'static,
100{
101    graphql_router_with_config(schema, GraphQlConfig::default())
102}
103
104pub fn graphql_router_with_config<Query, Mutation, Subscription>(
105    schema: GraphQlSchema<Query, Mutation, Subscription>,
106    config: GraphQlConfig,
107) -> Router<Container>
108where
109    Query: ObjectType + Send + Sync + 'static,
110    Mutation: ObjectType + Send + Sync + 'static,
111    Subscription: SubscriptionType + Send + Sync + 'static,
112{
113    let max_request_bytes = config.max_request_bytes;
114    let mut router = Router::new()
115        .route(
116            &config.endpoint,
117            post(
118                move |container, scoped_container, request_id, auth_identity, schema, request| {
119                    graphql_handler::<Query, Mutation, Subscription>(
120                        max_request_bytes,
121                        container,
122                        scoped_container,
123                        request_id,
124                        auth_identity,
125                        schema,
126                        request,
127                    )
128                },
129            ),
130        )
131        .layer(DefaultBodyLimit::max(config.max_request_bytes))
132        .layer(Extension(schema));
133
134    if let Some(graphiql_endpoint) = &config.graphiql_endpoint {
135        let endpoint = config.endpoint.clone();
136        let graphiql_html = GraphiQLSource::build().endpoint(&endpoint).finish();
137        router = router.route(
138            graphiql_endpoint,
139            get(move || {
140                let html = graphiql_html.clone();
141                async move { Html(html) }
142            }),
143        );
144    }
145
146    router
147}
148
149async fn graphql_handler<Query, Mutation, Subscription>(
150    max_request_bytes: usize,
151    State(container): State<Container>,
152    scoped_container: Option<Extension<Container>>,
153    request_id: Option<Extension<RequestId>>,
154    auth_identity: Option<Extension<Arc<AuthIdentity>>>,
155    Extension(schema): Extension<GraphQlSchema<Query, Mutation, Subscription>>,
156    request: Request,
157) -> Response
158where
159    Query: ObjectType + Send + Sync + 'static,
160    Mutation: ObjectType + Send + Sync + 'static,
161    Subscription: SubscriptionType + Send + Sync + 'static,
162{
163    if request
164        .headers()
165        .get(header::CONTENT_LENGTH)
166        .and_then(|value| value.to_str().ok())
167        .and_then(|value| value.parse::<usize>().ok())
168        .is_some_and(|length| length > max_request_bytes)
169    {
170        return StatusCode::PAYLOAD_TOO_LARGE.into_response();
171    }
172
173    let request =
174        match GraphQLRequest::<async_graphql_axum::rejection::GraphQLRejection>::from_request(
175            request.with_limited_body(),
176            &(),
177        )
178        .await
179        {
180            Ok(request) => request,
181            Err(rejection) => return rejection.into_response(),
182        };
183
184    let container = scoped_container.map(|value| value.0).unwrap_or(container);
185    let mut request = request.into_inner().data(container);
186    if let Some(Extension(request_id)) = request_id {
187        request = request.data(request_id);
188    }
189    if let Some(Extension(auth_identity)) = auth_identity {
190        request = request.data(auth_identity);
191    }
192
193    GraphQLResponse::from(schema.execute(request).await).into_response()
194}
195
/// Normalizes a user-supplied route path.
///
/// Surrounding whitespace is trimmed. An empty path or a bare `/` becomes `/graphql`, and
/// a missing leading `/` is added. Anything else is returned unchanged.
fn normalize_path(path: String) -> String {
    match path.trim() {
        "" | "/" => "/graphql".to_string(),
        rooted if rooted.starts_with('/') => rooted.to_string(),
        relative => format!("/{relative}"),
    }
}