use crate::app::extract_app_meta;
use crate::context::{partition_context_params, should_inject_context};
use heck::ToLowerCamelCase;
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use server_less_parse::{MethodInfo, extract_methods, get_impl_name, partition_methods};
use syn::{ItemImpl, Token, parse::Parse};
use crate::server_attrs::{has_server_hidden, has_server_skip, validate_server_attrs};
/// Arguments accepted by the `#[graphql(...)]` attribute.
///
/// Every field is optional: an attribute with no arguments parses to
/// `GraphqlArgs::default()`.
#[derive(Default)]
pub(crate) struct GraphqlArgs {
    /// Value of `name = "..."`, when supplied.
    pub name: Option<String>,
    /// Types listed in `enums(...)`. Each one is registered on the schema
    /// through its generated `__graphql_enum_type()` function.
    pub enums: Vec<syn::Ident>,
    /// Types listed in `inputs(...)`. Each one is registered on the schema
    /// through its generated `__graphql_input_type()` function.
    pub inputs: Vec<syn::Ident>,
}
impl Parse for GraphqlArgs {
    /// Parses the argument list of `#[graphql(...)]`.
    ///
    /// Accepted arguments:
    /// - `name = "..."`: a string literal; a later occurrence replaces an earlier one.
    /// - `enums(A, B, ...)`: identifiers; repeated occurrences accumulate.
    /// - `inputs(A, B, ...)`: identifiers; repeated occurrences accumulate.
    ///
    /// Commas between arguments are optional. Any other identifier is rejected
    /// with an error that lists the valid arguments, plus a "did you mean"
    /// hint when `crate::did_you_mean` finds a close match.
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
        const VALID: &[&str] = &["name", "enums", "inputs"];

        /// Parses a parenthesised, comma-separated list of identifiers: `(A, B, ...)`.
        fn parse_ident_list(input: syn::parse::ParseStream) -> syn::Result<Vec<syn::Ident>> {
            let content;
            syn::parenthesized!(content in input);
            let idents = content.parse_terminated(syn::Ident::parse, Token![,])?;
            Ok(idents.into_iter().collect())
        }

        let mut args = GraphqlArgs::default();
        while !input.is_empty() {
            let ident: syn::Ident = input.parse()?;
            match ident.to_string().as_str() {
                "name" => {
                    input.parse::<Token![=]>()?;
                    let lit: syn::LitStr = input.parse()?;
                    args.name = Some(lit.value());
                }
                // Extend rather than overwrite, so `enums(A), enums(B)` keeps
                // both lists instead of silently dropping `A`.
                "enums" => args.enums.extend(parse_ident_list(input)?),
                "inputs" => args.inputs.extend(parse_ident_list(input)?),
                other => {
                    let suggestion = crate::did_you_mean(other, VALID)
                        .map(|s| format!(" — did you mean `{s}`?"))
                        .unwrap_or_default();
                    return Err(syn::Error::new(
                        ident.span(),
                        format!(
                            "unknown argument `{other}`{suggestion}\n\
                             \n\
                             Valid arguments: name, enums, inputs\n\
                             \n\
                             Examples:\n\
                             - #[graphql(name = \"UserAPI\")]\n\
                             - #[graphql(enums(Status, Priority))]\n\
                             - #[graphql(inputs(CreateUserInput))]\n\
                             - #[graphql(name = \"MyAPI\", enums(Status), inputs(CreateUserInput))]"
                        ),
                    ));
                }
            }
            // A separating comma is accepted but not required.
            if input.peek(Token![,]) {
                input.parse::<Token![,]>()?;
            }
        }
        Ok(args)
    }
}
pub(crate) fn expand_graphql(args: GraphqlArgs, mut impl_block: ItemImpl) -> syn::Result<TokenStream2> {
crate::reject_generic_impl(&impl_block)?;
let app_meta = extract_app_meta(&mut impl_block.attrs);
let effective_name = args.name.or(app_meta.name);
let _ = effective_name; let struct_name = get_impl_name(&impl_block)?;
let (impl_generics, _ty_generics, where_clause) = impl_block.generics.split_for_impl();
let self_ty = &impl_block.self_ty;
let methods = extract_methods(&impl_block)?;
for m in &methods {
validate_server_attrs(m)?;
}
let partitioned = partition_methods(&methods, has_server_skip);
let visible_leaf: Vec<_> = partitioned
.leaf
.iter()
.copied()
.filter(|m| !has_server_hidden(m))
.collect();
let leaf_methods = &visible_leaf;
let (query_methods, mutation_methods): (Vec<_>, Vec<_>) = leaf_methods
.iter()
.copied()
.partition(|m| is_query_method(&m.name_str()));
let query_fields = generate_field_registrations(&query_methods);
let mutation_fields = generate_field_registrations(&mutation_methods);
let query_resolvers = generate_resolver_dispatch(&struct_name, &query_methods);
let mutation_resolvers = generate_resolver_dispatch(&struct_name, &mutation_methods);
let query_type_name = format!("{}Query", struct_name);
let mutation_type_name = format!("{}Mutation", struct_name);
let mount_query_merges: Vec<_> = partitioned
.static_mounts
.iter()
.map(|mount| {
let method_ident = &mount.name;
let inner_ty = mount.return_info.reference_inner.as_ref().unwrap();
quote! {
{
let child_arc = ::std::sync::Arc::new(service.#method_ident().clone());
obj = #inner_ty::__graphql_merge_query_fields(obj, child_arc);
}
}
})
.collect();
let mount_mutation_merges: Vec<_> = partitioned
.static_mounts
.iter()
.map(|mount| {
let method_ident = &mount.name;
let inner_ty = mount.return_info.reference_inner.as_ref().unwrap();
quote! {
{
let child_arc = ::std::sync::Arc::new(service.#method_ident().clone());
let (updated_obj, added) = #inner_ty::__graphql_merge_mutation_fields(obj, child_arc);
obj = updated_obj;
mutation_field_count += added;
}
}
})
.collect();
let has_own_mutations = !mutation_methods.is_empty();
let has_mounts = !partitioned.static_mounts.is_empty();
let custom_scalars = collect_custom_scalars(leaf_methods);
let scalar_registrations: Vec<_> = custom_scalars
.iter()
.map(|name| {
quote! {
.register(Scalar::new(#name))
}
})
.collect();
let enum_registrations: Vec<_> = args
.enums
.iter()
.map(|enum_type| {
quote! {
.register(#enum_type::__graphql_enum_type())
}
})
.collect();
let input_registrations: Vec<_> = args
.inputs
.iter()
.map(|input_type| {
quote! {
.register(#input_type::__graphql_input_type())
}
})
.collect();
let schema_build = if has_own_mutations && has_mounts {
quote! {
let mut mutation_field_count: usize = 0;
let mutation = {
let service = service.clone();
let mut obj = Object::new(#mutation_type_name);
#(
{
let service = service.clone();
mutation_field_count += 1;
#mutation_fields
}
)*
#(#mount_mutation_merges)*
obj
};
Schema::build(#query_type_name, Some(#mutation_type_name), None)
.register(query)
.register(mutation)
#(#scalar_registrations)*
#(#enum_registrations)*
#(#input_registrations)*
.finish()
.expect("Failed to build GraphQL schema")
}
} else if has_own_mutations {
quote! {
let mutation = {
let service = service.clone();
let mut obj = Object::new(#mutation_type_name);
#(
{
let service = service.clone();
#mutation_fields
}
)*
obj
};
Schema::build(#query_type_name, Some(#mutation_type_name), None)
.register(query)
.register(mutation)
#(#scalar_registrations)*
#(#enum_registrations)*
#(#input_registrations)*
.finish()
.expect("Failed to build GraphQL schema")
}
} else if has_mounts {
quote! {
let mut mutation_field_count: usize = 0;
let mutation = {
let mut obj = Object::new(#mutation_type_name);
#(#mount_mutation_merges)*
obj
};
if mutation_field_count > 0 {
Schema::build(#query_type_name, Some(#mutation_type_name), None)
.register(query)
.register(mutation)
#(#scalar_registrations)*
#(#enum_registrations)*
#(#input_registrations)*
.finish()
.expect("Failed to build GraphQL schema")
} else {
Schema::build(#query_type_name, None::<&str>, None)
.register(query)
#(#scalar_registrations)*
#(#enum_registrations)*
#(#input_registrations)*
.finish()
.expect("Failed to build GraphQL schema")
}
}
} else {
quote! {
Schema::build(#query_type_name, None::<&str>, None)
.register(query)
#(#scalar_registrations)*
#(#enum_registrations)*
#(#input_registrations)*
.finish()
.expect("Failed to build GraphQL schema")
}
};
let merge_query_helper = generate_merge_query_helper(&struct_name, &query_methods);
let merge_mutation_helper = generate_merge_mutation_helper(&struct_name, &mutation_methods);
let maybe_impl = if crate::is_protocol_impl_emitter(&impl_block, "graphql") {
quote! { #impl_block }
} else {
quote! {}
};
Ok(quote! {
#maybe_impl
impl #impl_generics #self_ty #where_clause {
#[doc(hidden)]
#[allow(dead_code)]
fn __graphql_json_to_value(json_val: ::serde_json::Value) -> ::async_graphql::Value {
match json_val {
::serde_json::Value::Null => ::async_graphql::Value::Null,
::serde_json::Value::Bool(b) => ::async_graphql::Value::Boolean(b),
::serde_json::Value::Number(n) => {
if let Some(i) = n.as_i64() {
::async_graphql::Value::Number((i as i32).into())
} else if let Some(f) = n.as_f64() {
match ::serde_json::to_value(f) {
Ok(::serde_json::Value::Number(num)) => {
::async_graphql::Value::Number(num.into())
}
_ => ::async_graphql::Value::String(f.to_string()),
}
} else {
::async_graphql::Value::Number(n.into())
}
}
::serde_json::Value::String(s) => ::async_graphql::Value::String(s),
::serde_json::Value::Array(arr) => {
let values: Vec<_> = arr
.into_iter()
.map(Self::__graphql_json_to_value)
.collect();
::async_graphql::Value::List(values)
}
::serde_json::Value::Object(obj) => {
let mut fields = ::async_graphql::indexmap::IndexMap::new();
for (key, value) in obj {
fields.insert(
::async_graphql::Name::new(key),
Self::__graphql_json_to_value(value),
);
}
::async_graphql::Value::Object(fields)
}
}
}
#[doc(hidden)]
#[allow(dead_code)]
fn __graphql_to_value<T>(v: T) -> ::async_graphql::Value
where
T: ::serde::Serialize + ::std::fmt::Debug,
{
if let Ok(json_val) = ::serde_json::to_value(&v) {
Self::__graphql_json_to_value(json_val)
} else {
::async_graphql::Value::String(format!("{:?}", v))
}
}
pub fn graphql_schema(self) -> ::async_graphql::dynamic::Schema
where
Self: Clone + Send + Sync + 'static,
{
use ::async_graphql::dynamic::*;
let service = ::std::sync::Arc::new(self);
let query = {
let service = service.clone();
let mut obj = Object::new(#query_type_name);
#(
{
let service = service.clone();
#query_fields
}
)*
#(#mount_query_merges)*
obj
};
#schema_build
}
#merge_query_helper
#merge_mutation_helper
pub fn graphql_router(self) -> ::server_less::axum::Router
where
Self: Clone + Send + Sync + 'static,
{
use ::server_less::axum::routing::{get, post};
use ::server_less::axum::response::IntoResponse;
let schema = self.graphql_schema();
async fn graphql_handler(
schema: ::server_less::axum::extract::State<::async_graphql::dynamic::Schema>,
req: ::async_graphql_axum::GraphQLRequest,
) -> ::async_graphql_axum::GraphQLResponse {
schema.execute(req.into_inner()).await.into()
}
async fn playground() -> impl IntoResponse {
::server_less::axum::response::Html(
::async_graphql::http::playground_source(
::async_graphql::http::GraphQLPlaygroundConfig::new("/graphql")
)
)
}
::server_less::axum::Router::new()
.route("/graphql", get(playground).post(graphql_handler))
.with_state(schema)
}
pub fn graphql_sdl(self) -> String
where
Self: Clone + Send + Sync + 'static,
{
self.graphql_schema().sdl()
}
pub fn graphql_openapi_paths() -> ::std::vec::Vec<::server_less::OpenApiPath> {
vec![
::server_less::OpenApiPath {
path: "/graphql".to_string(),
method: "post".to_string(),
operation: ::server_less::OpenApiOperation {
summary: Some("GraphQL query endpoint".to_string()),
description: None,
operation_id: Some("graphql_query".to_string()),
tags: vec!["graphql".to_string()],
deprecated: false,
parameters: vec![],
request_body: Some(::server_less::serde_json::json!({
"required": true,
"content": {
"application/json": {
"schema": {
"type": "object",
"required": ["query"],
"properties": {
"query": {
"type": "string",
"description": "GraphQL query string"
},
"operationName": {
"type": "string",
"description": "Optional operation name"
},
"variables": {
"type": "object",
"description": "Optional query variables"
}
}
}
}
}
})),
responses: {
let mut r = ::server_less::serde_json::Map::new();
r.insert("200".to_string(), ::server_less::serde_json::json!({
"description": "GraphQL response",
"content": {
"application/json": {
"schema": {
"type": "object",
"properties": {
"data": {},
"errors": {
"type": "array",
"items": {
"type": "object",
"properties": {
"message": {"type": "string"},
"locations": {"type": "array"},
"path": {"type": "array"}
}
}
}
}
}
}
}
}));
r
},
extra: ::server_less::serde_json::Map::new(),
},
},
::server_less::OpenApiPath {
path: "/graphql".to_string(),
method: "get".to_string(),
operation: ::server_less::OpenApiOperation {
summary: Some("GraphQL Playground".to_string()),
description: None,
operation_id: Some("graphql_playground".to_string()),
tags: vec!["graphql".to_string()],
deprecated: false,
parameters: vec![],
request_body: None,
responses: {
let mut r = ::server_less::serde_json::Map::new();
r.insert("200".to_string(), ::server_less::serde_json::json!({
"description": "GraphQL Playground HTML page",
"content": {
"text/html": {
"schema": {"type": "string"}
}
}
}));
r
},
extra: ::server_less::serde_json::Map::new(),
},
}
]
}
fn __graphql_resolve_query(
service: &::std::sync::Arc<Self>,
method: &str,
args: &::async_graphql::dynamic::ResolverContext,
) -> ::async_graphql::Result<::async_graphql::Value>
where
Self: Send + Sync,
{
match method {
#(#query_resolvers)*
_ => Err(::async_graphql::Error::new(format!("Unknown query: {}", method))),
}
}
fn __graphql_resolve_mutation(
service: &::std::sync::Arc<Self>,
method: &str,
args: &::async_graphql::dynamic::ResolverContext,
) -> ::async_graphql::Result<::async_graphql::Value>
where
Self: Send + Sync,
{
match method {
#(#mutation_resolvers)*
_ => Err(::async_graphql::Error::new(format!("Unknown mutation: {}", method))),
}
}
}
})
}
/// Decides whether a method becomes a field on the GraphQL `Query` root.
///
/// The decision is purely by name prefix (`get_`, `list_`, `is_`, ...). The
/// caller places every method that does not match on the `Mutation` root.
fn is_query_method(name: &str) -> bool {
    const QUERY_PREFIXES: [&str; 10] = [
        "get_", "fetch_", "read_", "list_", "find_", "search_", "count_", "exists_", "is_", "has_",
    ];
    QUERY_PREFIXES.iter().any(|&prefix| name.starts_with(prefix))
}
/// Produces one field-registration snippet per method, in input order.
///
/// Each snippet is wrapped in its own block and prefixed with the method's
/// `cfg_attrs`. Conditional compilation applied to a method therefore also
/// applies to its GraphQL field registration.
fn generate_field_registrations(methods: &[&MethodInfo]) -> Vec<TokenStream2> {
    let mut registrations = Vec::with_capacity(methods.len());
    for method in methods {
        let body = generate_field_registration(method);
        let cfg_attrs = &method.cfg_attrs;
        registrations.push(quote! {
            #(#cfg_attrs)*
            { #body }
        });
    }
    registrations
}
/// Generates the statements that build one GraphQL `Field` for `method` and
/// attach it to the surrounding object.
///
/// The emitted code expects `service` (an `Arc<Self>`) and a mutable `obj`
/// in scope, with `async_graphql::dynamic::*` imported. The field name is the
/// method name in lowerCamelCase, and its description comes from the method's
/// docs.
///
/// Result mapping:
/// - A unit return yields `true`.
/// - A `Result` error becomes a GraphQL error carrying its `Display` text.
/// - `None` yields null.
/// - Any other value is serialised through `Self::__graphql_to_value`.
fn generate_field_registration(method: &MethodInfo) -> TokenStream2 {
    let method_name = method.name_str();
    let method_ident = &method.name;
    let field_name = method_name.to_lower_camel_case();
    let description = method.docs.clone().unwrap_or_default();
    let ret = &method.return_info;
    let (type_ref, is_list) = infer_graphql_type_ref(ret);
    // Context parameters are injected at call time (see `param_names` below)
    // and are not exposed as GraphQL arguments. If partitioning fails, every
    // parameter is treated as a user argument.
    let (_ctx_param, user_params) =
        partition_context_params(&method.params).unwrap_or((None, method.params.iter().collect()));

    let arg_registrations: Vec<_> = user_params
        .iter()
        .map(|p| {
            let arg_name = p.name_str();
            let ty = &p.ty;
            let gql_type = rust_type_to_graphql(ty);
            // `rust_type_to_graphql` returns the *element* type for `Vec<T>`,
            // so the list wrapper must be added here. Without it, a `Vec`
            // argument would be declared as a single scalar and could never
            // receive a list.
            let is_list_arg = quote!(#ty).to_string().contains("Vec");
            let arg_type = match (is_list_arg, p.is_optional) {
                (false, false) => quote! { TypeRef::named_nn(#gql_type) },
                (false, true) => quote! { TypeRef::named(#gql_type) },
                (true, false) => quote! { TypeRef::named_nn_list_nn(#gql_type) },
                (true, true) => quote! { TypeRef::named_nn_list(#gql_type) },
            };
            quote! {
                .argument(InputValue::new(#arg_name, #arg_type))
            }
        })
        .collect();

    let arg_extractions: Vec<_> = user_params.iter().map(|p| {
        let arg_name = p.name_str();
        let param_name = &p.name;
        let ty = &p.ty;
        if p.is_optional {
            // Optional arguments: a missing or undeserialisable value becomes `None`.
            quote! {
                let #param_name: #ty = ctx.args.try_get(#arg_name).ok()
                    .and_then(|v| v.deserialize().ok());
            }
        } else {
            quote! {
                let #param_name: #ty = ctx.args.try_get(#arg_name)
                    .map_err(|_| ::async_graphql::Error::new(format!("Missing argument: {}", #arg_name)))?
                    .deserialize()
                    .map_err(|_| ::async_graphql::Error::new(format!("Invalid argument: {}", #arg_name)))?;
            }
        }
    }).collect();

    // Call-site arguments in declaration order; context parameters get a
    // fresh `Context`.
    let param_names: Vec<_> = method.params.iter().map(|p| {
        if should_inject_context(&p.ty, &method.params) {
            quote! { ::server_less::Context::new() }
        } else {
            let name = &p.name;
            quote! { #name }
        }
    }).collect();

    let method_call = if method.is_async {
        quote! { service.#method_ident(#(#param_names),*).await }
    } else {
        quote! { service.#method_ident(#(#param_names),*) }
    };

    let result_conversion = if ret.is_unit {
        quote! {
            #method_call;
            Ok(Some(::async_graphql::Value::Boolean(true)))
        }
    } else if ret.is_result {
        if is_list {
            quote! {
                match #method_call {
                    Ok(items) => {
                        let values: Vec<_> = items.into_iter()
                            .map(|item| Self::__graphql_to_value(item))
                            .collect();
                        Ok(Some(::async_graphql::Value::List(values)))
                    }
                    Err(e) => Err(::async_graphql::Error::new(format!("{}", e))),
                }
            }
        } else {
            quote! {
                match #method_call {
                    Ok(value) => Ok(Some(Self::__graphql_to_value(value))),
                    Err(e) => Err(::async_graphql::Error::new(format!("{}", e))),
                }
            }
        }
    } else if ret.is_option {
        quote! {
            match #method_call {
                Some(value) => Ok(Some(Self::__graphql_to_value(value))),
                None => Ok(None),
            }
        }
    } else if is_list {
        quote! {
            let items = #method_call;
            let values: Vec<_> = items.into_iter()
                .map(|item| Self::__graphql_to_value(item))
                .collect();
            Ok(Some(::async_graphql::Value::List(values)))
        }
    } else {
        quote! {
            let result = #method_call;
            Ok(Some(Self::__graphql_to_value(result)))
        }
    };

    quote! {
        let field = Field::new(#field_name, #type_ref, move |ctx| {
            let service = service.clone();
            FieldFuture::new(async move {
                #(#arg_extractions)*
                #result_conversion
            })
        })
        .description(#description)
        #(#arg_registrations)*;
        obj = obj.field(field);
    }
}
/// Chooses the GraphQL output type for a method's return value.
///
/// Returns the `TypeRef` expression tokens together with a flag saying whether
/// the value is a list (its type mentions `Vec`). Callers use that flag to
/// pick the list-conversion path.
///
/// - Unit returns, or returns with no recorded type, are reported as `Boolean!`.
/// - Lists become `[T!]`, a nullable list of non-null items. This also covers
///   `Option<Vec<T>>`.
/// - `Option<T>` becomes a nullable `T`.
/// - Everything else becomes `T!`.
///
/// The element type `T` is picked by substring match on the rendered type;
/// unrecognised types fall back to the `JSON` scalar.
fn infer_graphql_type_ref(ret: &server_less_parse::ReturnInfo) -> (TokenStream2, bool) {
    let boolean_nn = || (quote! { TypeRef::named_nn(TypeRef::BOOLEAN) }, false);
    if ret.is_unit {
        return boolean_nn();
    }
    let ty = match &ret.ty {
        Some(ty) => ty,
        None => return boolean_nn(),
    };

    // `quote!` renders tokens space-separated, e.g. "serde_json :: Value".
    let type_str = quote!(#ty).to_string();
    let is_list = type_str.contains("Vec");
    const INT_TYPES: [&str; 6] = ["i32", "i64", "u32", "u64", "isize", "usize"];

    let base_type = if type_str.contains("DateTime") {
        quote! { "DateTime" }
    } else if type_str.contains("Uuid") {
        quote! { "UUID" }
    } else if type_str.contains("Url") {
        quote! { "Url" }
    } else if type_str.contains("serde_json :: Value") || type_str == "Value" {
        quote! { "JSON" }
    } else if type_str.contains("String") || type_str.contains("str") {
        quote! { TypeRef::STRING }
    } else if INT_TYPES.iter().any(|&int| type_str.contains(int)) {
        quote! { TypeRef::INT }
    } else if type_str.contains("f32") || type_str.contains("f64") {
        quote! { TypeRef::FLOAT }
    } else if type_str.contains("bool") {
        quote! { TypeRef::BOOLEAN }
    } else {
        quote! { "JSON" }
    };

    let type_ref = if is_list {
        // `named_nn_list` is `[T!]`: the list itself is nullable. That
        // correctly expresses `Option<Vec<T>>` too. The previous
        // `TypeRef::named(TypeRef::named_list(..))` passed a type *reference*
        // where a type *name* is required.
        quote! { TypeRef::named_nn_list(#base_type) }
    } else if ret.is_option {
        quote! { TypeRef::named(#base_type) }
    } else {
        quote! { TypeRef::named_nn(#base_type) }
    };
    (type_ref, is_list)
}
/// Builds one `match` arm per method for the generated
/// `__graphql_resolve_query` / `__graphql_resolve_mutation` functions.
fn generate_resolver_dispatch(
    struct_name: &syn::Ident,
    methods: &[&MethodInfo],
) -> Vec<TokenStream2> {
    let mut arms = Vec::with_capacity(methods.len());
    for method in methods {
        arms.push(generate_resolver_arm(struct_name, method));
    }
    arms
}
/// Emits a placeholder `match` arm keyed on the method's name.
///
/// Actual resolution happens inside each field's `FieldFuture` closure, so the
/// arm body panics with a bug message if it is ever reached.
fn generate_resolver_arm(_struct_name: &syn::Ident, method: &MethodInfo) -> TokenStream2 {
    let name_literal = method.name_str();
    let arm_body = quote! {
        unreachable!("BUG: resolver arm should not be called — dispatch is handled by FieldFuture")
    };
    quote! {
        #name_literal => { #arm_body }
    }
}
/// Maps a parameter's Rust type to the GraphQL type *name* used when declaring
/// the argument.
///
/// For any type mentioning `Vec`, the element type's name is returned; the
/// list wrapper is not part of the result. Well-known types map to custom
/// scalars (`DateTime`, `UUID`, `Url`, `JSON`). Everything else is classified
/// by `server_less_rpc::infer_json_type`, defaulting to `String`.
fn rust_type_to_graphql(ty: &syn::Type) -> &'static str {
    let rendered = quote!(#ty).to_string();
    if rendered.contains("Vec") {
        extract_vec_inner_type(&rendered)
    } else if rendered.contains("DateTime") {
        "DateTime"
    } else if rendered.contains("Uuid") {
        "UUID"
    } else if rendered.contains("Url") {
        "Url"
    } else if rendered.contains("serde_json :: Value") || rendered == "Value" {
        "JSON"
    } else {
        match server_less_rpc::infer_json_type(ty) {
            "integer" => "Int",
            "number" => "Float",
            "boolean" => "Boolean",
            // "string" and anything unrecognised.
            _ => "String",
        }
    }
}
/// Returns the GraphQL type name of the element type of the first `Vec<...>`
/// in `type_str`. Falls back to `"String"` when no `Vec<...>` can be located
/// or its brackets are unbalanced.
///
/// `type_str` normally comes from `quote!(#ty).to_string()`, which separates
/// tokens with spaces (`"Vec < u32 >"`). Whitespace between `Vec` and `<` is
/// therefore tolerated; the old `find("Vec<")` never matched that form. The
/// element is taken up to the *matching* `>`, so nested generics such as
/// `Vec<Option<String>>` keep their full inner type.
fn extract_vec_inner_type(type_str: &str) -> &'static str {
    let mut search_from = 0;
    while let Some(offset) = type_str[search_from..].find("Vec") {
        let after_vec = search_from + offset + "Vec".len();
        if let Some(generics) = type_str[after_vec..].trim_start().strip_prefix('<') {
            let mut depth = 1usize;
            for (idx, ch) in generics.char_indices() {
                match ch {
                    '<' => depth += 1,
                    '>' => {
                        depth -= 1;
                        if depth == 0 {
                            return map_inner_type_to_graphql(generics[..idx].trim());
                        }
                    }
                    _ => {}
                }
            }
            // Unbalanced brackets: use the fallback below.
            break;
        }
        // `Vec` without a following `<` (e.g. part of a longer identifier).
        search_from = after_vec;
    }
    "String"
}
/// Maps the textual element type of a `Vec<...>` to a GraphQL type name.
///
/// Checks are substring matches applied in priority order:
/// 1. well-known custom scalars (`DateTime`, `UUID`, `Url`, `JSON`);
/// 2. `String`/`str`;
/// 3. integer primitives;
/// 4. float primitives;
/// 5. `bool`.
///
/// Anything unrecognised is treated as `JSON`.
fn map_inner_type_to_graphql(inner: &str) -> &'static str {
    fn contains_any(haystack: &str, needles: &[&str]) -> bool {
        needles.iter().any(|&needle| haystack.contains(needle))
    }

    if contains_any(inner, &["DateTime"]) {
        "DateTime"
    } else if contains_any(inner, &["Uuid"]) {
        "UUID"
    } else if contains_any(inner, &["Url"]) {
        "Url"
    } else if contains_any(inner, &["serde_json :: Value"]) || inner == "Value" {
        "JSON"
    } else if contains_any(inner, &["String", "str"]) {
        "String"
    } else if contains_any(inner, &["i32", "i64", "u32", "u64", "isize", "usize"]) {
        "Int"
    } else if contains_any(inner, &["f32", "f64"]) {
        "Float"
    } else if contains_any(inner, &["bool"]) {
        "Boolean"
    } else {
        "JSON"
    }
}
fn collect_custom_scalars(methods: &[&MethodInfo]) -> Vec<String> {
let mut scalars = std::collections::BTreeSet::new();
scalars.insert("JSON".to_string());
for method in methods {
for param in &method.params {
let ty = ¶m.ty;
check_type_for_scalars("e!(#ty).to_string(), &mut scalars);
}
if let Some(ref ty) = method.return_info.ty {
check_type_for_scalars("e!(#ty).to_string(), &mut scalars);
}
}
scalars.into_iter().collect()
}
/// Emits `__graphql_merge_query_fields`.
///
/// When this type is statically mounted, a parent service calls that function
/// to add this service's query fields onto the parent's `Query` object.
fn generate_merge_query_helper(
    _struct_name: &syn::Ident,
    query_methods: &[&MethodInfo],
) -> TokenStream2 {
    // Each registration gets its own scope with a fresh `Arc` clone, because
    // the field closure moves `service`.
    let scoped_registrations: Vec<TokenStream2> = generate_field_registrations(query_methods)
        .into_iter()
        .map(|registration| {
            quote! {
                {
                    let service = service.clone();
                    #registration
                }
            }
        })
        .collect();
    quote! {
        #[doc(hidden)]
        pub fn __graphql_merge_query_fields(
            mut obj: ::async_graphql::dynamic::Object,
            service: ::std::sync::Arc<Self>,
        ) -> ::async_graphql::dynamic::Object
        where
            Self: Clone + Send + Sync + 'static,
        {
            use ::async_graphql::dynamic::*;
            #(#scoped_registrations)*
            obj
        }
    }
}
/// Emits `__graphql_merge_mutation_fields`.
///
/// When this type is statically mounted, a parent service calls that function
/// to add this service's mutation fields onto the parent's `Mutation` object.
/// It also returns how many fields were added, which lets the parent decide
/// whether it needs a Mutation root at all.
///
/// NOTE(review): the count is `mutation_methods.len()` fixed at expansion
/// time, so it also includes methods that their `cfg_attrs` later compile out.
fn generate_merge_mutation_helper(
    _struct_name: &syn::Ident,
    mutation_methods: &[&MethodInfo],
) -> TokenStream2 {
    let contributed = mutation_methods.len();
    let scoped_registrations: Vec<TokenStream2> = generate_field_registrations(mutation_methods)
        .into_iter()
        .map(|registration| {
            quote! {
                {
                    let service = service.clone();
                    #registration
                }
            }
        })
        .collect();
    quote! {
        #[doc(hidden)]
        pub fn __graphql_merge_mutation_fields(
            mut obj: ::async_graphql::dynamic::Object,
            service: ::std::sync::Arc<Self>,
        ) -> (::async_graphql::dynamic::Object, usize)
        where
            Self: Clone + Send + Sync + 'static,
        {
            use ::async_graphql::dynamic::*;
            #(#scoped_registrations)*
            (obj, #contributed)
        }
    }
}
/// Adds to `scalars` every custom GraphQL scalar that a rendered Rust type
/// (`quote!(..).to_string()` output) refers to.
///
/// Detection is by substring:
/// - `DateTime` adds `DateTime`.
/// - `Uuid` adds `UUID`.
/// - `Url` adds `Url`, except for types mentioning `UrlError`.
/// - `serde_json :: Value` or bare `Value` adds `JSON`.
fn check_type_for_scalars(type_str: &str, scalars: &mut std::collections::BTreeSet<String>) {
    let detections = [
        ("DateTime", type_str.contains("DateTime")),
        ("UUID", type_str.contains("Uuid")),
        ("Url", type_str.contains("Url") && !type_str.contains("UrlError")),
        ("JSON", type_str.contains("serde_json :: Value") || type_str == "Value"),
    ];
    scalars.extend(
        detections
            .iter()
            .filter_map(|&(scalar, found)| found.then(|| scalar.to_string())),
    );
}