//! nestforge-graphql 1.2.0
//!
//! GraphQL support for NestForge, built on async-graphql.
use std::{any::type_name, sync::Arc};

use async_graphql::{
    http::GraphiQLSource, EmptyMutation, EmptySubscription, ObjectType, Schema, SubscriptionType,
};
use async_graphql_axum::{GraphQLRequest, GraphQLResponse};
use axum::{
    extract::{Extension, State},
    response::Html,
    routing::{get, post},
    Router,
};
use nestforge_core::{AuthIdentity, Container, RequestId};

pub use async_graphql;

/// Schema type accepted by the NestForge GraphQL routers.
///
/// Thin alias over [`async_graphql::Schema`] whose mutation and subscription roots default to
/// [`EmptyMutation`] and [`EmptySubscription`], so a query-only API can be spelled
/// `GraphQlSchema<Query>`.
pub type GraphQlSchema<
    Query,
    Mutation = EmptyMutation,
    Subscription = EmptySubscription,
> = Schema<Query, Mutation, Subscription>;

/// Returns the NestForge DI [`Container`] stored in the GraphQL request context.
///
/// # Errors
///
/// Propagates the `async_graphql` error produced when no `Container` is present in the
/// context data.
pub fn graphql_container<'ctx>(
    ctx: &'ctx async_graphql::Context<'ctx>,
) -> async_graphql::Result<&'ctx Container> {
    // The data type is inferred from the declared return type (`Container`).
    ctx.data()
}

/// Returns the [`RequestId`] attached to the GraphQL request context, if one was stored.
pub fn graphql_request_id<'ctx>(
    ctx: &'ctx async_graphql::Context<'ctx>,
) -> Option<&'ctx RequestId> {
    // `data_opt` yields `None` instead of an error when the value is absent;
    // the looked-up type is inferred from the return type.
    ctx.data_opt()
}

/// Returns the authenticated identity attached to the GraphQL request context, if any.
///
/// The identity is looked up as an `Arc<AuthIdentity>` (the form the HTTP handler stores)
/// and borrowed through the `Arc`.
pub fn graphql_auth_identity<'ctx>(
    ctx: &'ctx async_graphql::Context<'ctx>,
) -> Option<&'ctx AuthIdentity> {
    let shared = ctx.data_opt::<Arc<AuthIdentity>>()?;
    Some(shared.as_ref())
}

pub fn resolve_graphql<T>(ctx: &async_graphql::Context<'_>) -> async_graphql::Result<Arc<T>>
where
    T: Send + Sync + 'static,
{
    let container = graphql_container(ctx)?;
    container.resolve::<T>().map_err(|_| {
        async_graphql::Error::new(format!(
            "Failed to resolve dependency `{}` from GraphQL context",
            type_name::<T>()
        ))
    })
}

/// Route configuration for the GraphQL router.
///
/// The fields are public, so a value built with a struct literal bypasses the path
/// normalization that `GraphQlConfig::new` and `GraphQlConfig::with_graphiql` apply.
#[derive(Debug, Clone)]
pub struct GraphQlConfig {
    /// Path that serves GraphQL requests via `POST` (default `/graphql`).
    pub endpoint: String,
    /// Path that serves the GraphiQL IDE via `GET`, or `None` to disable it
    /// (default `/graphiql`).
    pub graphiql_endpoint: Option<String>,
}

/// Serves GraphQL at `/graphql` and GraphiQL at `/graphiql`.
impl Default for GraphQlConfig {
    fn default() -> Self {
        let endpoint = String::from("/graphql");
        let graphiql_endpoint = Some(String::from("/graphiql"));
        Self {
            endpoint,
            graphiql_endpoint,
        }
    }
}

impl GraphQlConfig {
    /// Creates a config serving GraphQL at `endpoint` and GraphiQL at the default path.
    ///
    /// `endpoint` is normalized: surrounding whitespace is trimmed, a leading `/` is added
    /// if missing, and a blank or bare `/` value becomes `/graphql`.
    pub fn new(endpoint: impl Into<String>) -> Self {
        let mut config = Self::default();
        config.endpoint = normalize_path(endpoint.into());
        config
    }

    /// Enables GraphiQL at `path`, normalized the same way as in [`GraphQlConfig::new`].
    ///
    /// Note: a blank or bare `/` path normalizes to `/graphql`, not `/graphiql`.
    pub fn with_graphiql(self, path: impl Into<String>) -> Self {
        Self {
            graphiql_endpoint: Some(normalize_path(path.into())),
            ..self
        }
    }

    /// Disables the GraphiQL route.
    pub fn without_graphiql(self) -> Self {
        Self {
            graphiql_endpoint: None,
            ..self
        }
    }
}

/// Builds a GraphQL router using [`GraphQlConfig::default`]: `POST /graphql` for queries
/// and `GET /graphiql` for the GraphiQL IDE.
pub fn graphql_router<Query, Mutation, Subscription>(
    schema: GraphQlSchema<Query, Mutation, Subscription>,
) -> Router<Container>
where
    Query: ObjectType + Send + Sync + 'static,
    Mutation: ObjectType + Send + Sync + 'static,
    Subscription: SubscriptionType + Send + Sync + 'static,
{
    let config = GraphQlConfig::default();
    graphql_router_with_config(schema, config)
}

pub fn graphql_router_with_config<Query, Mutation, Subscription>(
    schema: GraphQlSchema<Query, Mutation, Subscription>,
    config: GraphQlConfig,
) -> Router<Container>
where
    Query: ObjectType + Send + Sync + 'static,
    Mutation: ObjectType + Send + Sync + 'static,
    Subscription: SubscriptionType + Send + Sync + 'static,
{
    let mut router = Router::new()
        .route(&config.endpoint, post(graphql_handler::<Query, Mutation, Subscription>))
        .layer(Extension(schema));

    if let Some(graphiql_endpoint) = &config.graphiql_endpoint {
        let endpoint = config.endpoint.clone();
        let graphiql_html = GraphiQLSource::build().endpoint(&endpoint).finish();
        router = router.route(
            graphiql_endpoint,
            get(move || {
                let html = graphiql_html.clone();
                async move { Html(html) }
            }),
        );
    }

    router
}

/// Executes one GraphQL request against the schema installed by the router.
///
/// The following values are injected into the async-graphql request data, where resolvers
/// can read them through the `graphql_*` helpers:
/// - the DI `Container`: a request-scoped `Container` extension when one is present
///   (presumably inserted by upstream middleware), otherwise the router state;
/// - the `RequestId`, when a request-id extension is present;
/// - the `Arc<AuthIdentity>`, when an auth extension is present.
///
/// `RequestId` is extracted optionally, matching `graphql_request_id` (which returns an
/// `Option`). As a required extractor it would reject every GraphQL call with a 500 when no
/// request-id middleware inserted the extension.
async fn graphql_handler<Query, Mutation, Subscription>(
    State(container): State<Container>,
    scoped_container: Option<Extension<Container>>,
    request_id: Option<Extension<RequestId>>,
    auth_identity: Option<Extension<Arc<AuthIdentity>>>,
    Extension(schema): Extension<GraphQlSchema<Query, Mutation, Subscription>>,
    request: GraphQLRequest,
) -> GraphQLResponse
where
    Query: ObjectType + Send + Sync + 'static,
    Mutation: ObjectType + Send + Sync + 'static,
    Subscription: SubscriptionType + Send + Sync + 'static,
{
    // Prefer the request-scoped container over the application-wide state.
    let container = scoped_container.map_or(container, |Extension(scoped)| scoped);
    let mut request = request.into_inner().data(container);
    if let Some(Extension(request_id)) = request_id {
        request = request.data(request_id);
    }
    if let Some(Extension(auth_identity)) = auth_identity {
        request = request.data(auth_identity);
    }

    schema.execute(request).await.into()
}

/// Normalizes a user-supplied route path.
///
/// Trims surrounding whitespace and prepends `/` when it is missing. Blank input or a bare
/// `/` falls back to `/graphql`. The `GraphQlConfig` constructors use this for both the
/// GraphQL and the GraphiQL paths.
fn normalize_path(path: String) -> String {
    match path.trim() {
        "" | "/" => String::from("/graphql"),
        absolute if absolute.starts_with('/') => absolute.to_owned(),
        relative => format!("/{relative}"),
    }
}