pub mod attr;
pub mod derive;
use std::collections::HashMap;
use proc_macro2::TokenStream;
use quote::{ToTokens, format_ident, quote};
use syn::{
ext::IdentExt as _,
parse::{Parse, ParseStream},
parse_quote,
spanned::Spanned as _,
token,
};
use crate::common::{
AttrNames, Description, SpanContainer, filter_attrs, generate,
parse::{
ParseBufferExt as _,
attr::{OptionExt as _, err},
},
scalar,
};
/// Map from a variant's Rust type to the path of its external resolver function, as collected
/// from `on <Type> = <path>` attribute arguments.
type AttrResolvers = HashMap<syn::Type, SpanContainer<syn::ExprPath>>;
/// Parsed arguments of a GraphQL union attribute.
///
/// Produced by the [`Parse`] impl from a comma-separated argument list; several parsed
/// attributes are combined via `Attr::try_merge`.
#[derive(Debug, Default)]
struct Attr {
/// Value of the `name = "..."` argument, if specified.
name: Option<SpanContainer<String>>,
/// Value of the `desc = ...` / `description = ...` argument, if specified.
///
/// `Attr::from_attrs` falls back to `Description::parse_from_doc_attrs` when this is absent.
description: Option<SpanContainer<Description>>,
/// Type given by the `ctx = ...` / `context = ...` / `Context = ...` argument, if specified.
context: Option<SpanContainer<syn::Type>>,
/// Value given by the `scalar = ...` / `Scalar = ...` / `ScalarValue = ...` argument, if
/// specified.
scalar: Option<SpanContainer<scalar::AttrValue>>,
/// Resolver paths collected from `on <Type> = <path>` arguments, keyed by `<Type>`.
external_resolvers: AttrResolvers,
/// Set to `true` when the bare `internal` argument is present.
is_internal: bool,
}
impl Parse for Attr {
fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
let mut out = Self::default();
while !input.is_empty() {
let ident = input.parse::<syn::Ident>()?;
match ident.to_string().as_str() {
"name" => {
input.parse::<token::Eq>()?;
let name = input.parse::<syn::LitStr>()?;
out.name
.replace(SpanContainer::new(
ident.span(),
Some(name.span()),
name.value(),
))
.none_or_else(|_| err::dup_arg(&ident))?
}
"desc" | "description" => {
input.parse::<token::Eq>()?;
let desc = input.parse::<Description>()?;
out.description
.replace(SpanContainer::new(ident.span(), Some(desc.span()), desc))
.none_or_else(|_| err::dup_arg(&ident))?
}
"ctx" | "context" | "Context" => {
input.parse::<token::Eq>()?;
let ctx = input.parse::<syn::Type>()?;
out.context
.replace(SpanContainer::new(ident.span(), Some(ctx.span()), ctx))
.none_or_else(|_| err::dup_arg(&ident))?
}
"scalar" | "Scalar" | "ScalarValue" => {
input.parse::<token::Eq>()?;
let scl = input.parse::<scalar::AttrValue>()?;
out.scalar
.replace(SpanContainer::new(ident.span(), Some(scl.span()), scl))
.none_or_else(|_| err::dup_arg(&ident))?
}
"on" => {
let ty = input.parse::<syn::Type>()?;
input.parse::<token::Eq>()?;
let rslvr = input.parse::<syn::ExprPath>()?;
let rslvr_spanned = SpanContainer::new(ident.span(), Some(ty.span()), rslvr);
let rslvr_span = rslvr_spanned.span_joined();
out.external_resolvers
.insert(ty, rslvr_spanned)
.none_or_else(|_| err::dup_arg(rslvr_span))?
}
"internal" => {
out.is_internal = true;
}
name => {
return Err(err::unknown_arg(&ident, name));
}
}
input.try_parse::<token::Comma>()?;
}
Ok(out)
}
}
impl Attr {
    /// Combines two [`Attr`]s into a single one, field by field.
    ///
    /// The merging of each optional field and of the resolvers map is delegated to the
    /// `try_merge_opt!` / `try_merge_hashmap!` macros; `is_internal` is the logical OR of both
    /// sides.
    fn try_merge(self, mut another: Self) -> syn::Result<Self> {
        let name = try_merge_opt!(name: self, another);
        let description = try_merge_opt!(description: self, another);
        let context = try_merge_opt!(context: self, another);
        let scalar = try_merge_opt!(scalar: self, another);
        let external_resolvers = try_merge_hashmap!(
            external_resolvers: self, another => span_joined
        );
        let is_internal = self.is_internal || another.is_internal;

        Ok(Self {
            name,
            description,
            context,
            scalar,
            external_resolvers,
            is_internal,
        })
    }

    /// Parses every attribute selected by `filter_attrs(names, attrs)` and merges the results
    /// into one [`Attr`].
    ///
    /// When no description was given explicitly, it is taken from
    /// `Description::parse_from_doc_attrs(attrs)` instead.
    ///
    /// # Errors
    ///
    /// Returns the first parsing or merging error encountered.
    fn from_attrs(names: impl AttrNames, attrs: &[syn::Attribute]) -> syn::Result<Self> {
        let mut merged = Self::default();
        for attr in filter_attrs(names, attrs) {
            merged = merged.try_merge(attr.parse_args()?)?;
        }

        if merged.description.is_none() {
            merged.description = Description::parse_from_doc_attrs(attrs)?;
        }

        Ok(merged)
    }
}
/// Parsed arguments of an attribute placed on a single GraphQL union variant.
#[derive(Debug, Default)]
struct VariantAttr {
/// The `ignore` / `skip` identifier itself, if that argument is present.
ignore: Option<SpanContainer<syn::Ident>>,
/// Path given by the `with = <path>` argument, if specified.
external_resolver: Option<SpanContainer<syn::ExprPath>>,
}
impl Parse for VariantAttr {
    /// Parses a comma-separated list of variant attribute arguments into a [`VariantAttr`].
    ///
    /// Recognized arguments:
    /// - `ignore` / `skip`
    /// - `with = <path>`
    ///
    /// # Errors
    ///
    /// Fails on an unknown or duplicated argument, or when the `with` value is not a path.
    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
        let mut attr = Self::default();

        while !input.is_empty() {
            let arg = input.parse::<syn::Ident>()?;
            let arg_name = arg.to_string();

            match arg_name.as_str() {
                "ignore" | "skip" => {
                    let spanned = SpanContainer::new(arg.span(), None, arg.clone());
                    attr.ignore
                        .replace(spanned)
                        .none_or_else(|_| err::dup_arg(&arg))?;
                }
                "with" => {
                    input.parse::<token::Eq>()?;
                    let resolver = input.parse::<syn::ExprPath>()?;
                    let spanned = SpanContainer::new(arg.span(), Some(resolver.span()), resolver);
                    attr.external_resolver
                        .replace(spanned)
                        .none_or_else(|_| err::dup_arg(&arg))?;
                }
                unknown => return Err(err::unknown_arg(&arg, unknown)),
            }

            input.try_parse::<token::Comma>()?;
        }

        Ok(attr)
    }
}
impl VariantAttr {
    /// Combines two [`VariantAttr`]s into a single one, delegating the merging of each field to
    /// the `try_merge_opt!` macro.
    fn try_merge(self, mut another: Self) -> syn::Result<Self> {
        let ignore = try_merge_opt!(ignore: self, another);
        let external_resolver = try_merge_opt!(external_resolver: self, another);

        Ok(Self {
            ignore,
            external_resolver,
        })
    }

    /// Parses every attribute selected by `filter_attrs(name, attrs)` and merges the results
    /// into one [`VariantAttr`].
    ///
    /// # Errors
    ///
    /// Returns the first parsing or merging error encountered.
    fn from_attrs(name: &str, attrs: &[syn::Attribute]) -> syn::Result<Self> {
        let mut merged = Self::default();
        for attr in filter_attrs(name, attrs) {
            merged = merged.try_merge(attr.parse_args()?)?;
        }
        Ok(merged)
    }
}
/// Everything needed to generate the trait implementations of a GraphQL union.
///
/// Its [`ToTokens`] impl emits all the generated `impl` blocks.
struct Definition {
/// Name of the union, as returned by the generated `GraphQLType::name()` and mentioned in the
/// generated error messages.
name: String,
/// Rust type the implementations are generated for.
ty: syn::Type,
/// Generics of the Rust type, used to build the generated `impl` headers.
generics: syn::Generics,
/// When `true`, the implementations target
/// `dyn #ty + '__obj + Send + Sync` instead of `#ty` itself.
is_trait_object: bool,
/// Optional description, interpolated into the generated `GraphQLType::meta()` builder chain.
description: Option<Description>,
/// Type used as `GraphQLValue::Context` in the generated code.
context: syn::Type,
/// Scalar value type the implementations are parametrized with.
scalar: scalar::Type,
/// Variants of the union.
variants: Vec<VariantDefinition>,
}
impl ToTokens for Definition {
    /// Appends all the generated trait implementations of this union to `into`, in a fixed
    /// order: `GraphQLUnion`, `IsOutputType`, `GraphQLType`, `GraphQLValue`,
    /// `GraphQLValueAsync`, and finally the reflection traits.
    fn to_tokens(&self, into: &mut TokenStream) {
        into.extend([
            self.impl_graphql_union_tokens(),
            self.impl_output_type_tokens(),
            self.impl_graphql_type_tokens(),
            self.impl_graphql_value_tokens(),
            self.impl_graphql_value_async_tokens(),
            self.impl_reflection_traits_tokens(),
        ]);
    }
}
impl Definition {
    /// Prepares the pieces of an `impl` header for this union's type.
    ///
    /// Returns a tuple of:
    /// 1. tokens of the `impl` generics;
    /// 2. tokens of the full implementor type (wrapped into
    ///    `dyn ... + '__obj + Send + Sync` for trait objects);
    /// 3. the `where` clause, if any predicates are required.
    ///
    /// If `for_async` is `true`, additional `Sync` (for the implementor) and `Send + Sync` (for
    /// a generic scalar) predicates are added.
    #[must_use]
    fn impl_generics(
        &self,
        for_async: bool,
    ) -> (TokenStream, TokenStream, Option<syn::WhereClause>) {
        let (_, ty_generics, _) = self.generics.split_for_impl();
        let ty = &self.ty;

        let mut ty_full = quote! { #ty #ty_generics };
        if self.is_trait_object {
            ty_full =
                quote! { dyn #ty_full + '__obj + ::core::marker::Send + ::core::marker::Sync };
        }

        let mut generics = self.generics.clone();

        if self.is_trait_object {
            // The `'__obj` lifetime used in `ty_full` above has to be declared on the impl.
            generics.params.push(parse_quote! { '__obj });
        }

        let scalar = &self.scalar;
        if scalar.is_implicit_generic() {
            generics.params.push(parse_quote! { #scalar });
        }
        if scalar.is_generic() {
            generics
                .make_where_clause()
                .predicates
                .push(parse_quote! { #scalar: ::juniper::ScalarValue });
        }
        if let Some(bound) = scalar.bounds() {
            generics.make_where_clause().predicates.push(bound);
        }

        if for_async {
            let self_ty = if !self.is_trait_object && self.generics.lifetimes().next().is_some() {
                // Spell the type out under a `for<...>` binder with every lifetime renamed to
                // `__fa__<name>` — presumably so these names cannot clash with the lifetimes
                // already declared on the impl itself.
                let mut generics = self.generics.clone();
                for lt in generics.lifetimes_mut() {
                    let ident = lt.lifetime.ident.unraw();
                    lt.lifetime.ident = format_ident!("__fa__{ident}");
                }

                let lifetimes = generics.lifetimes().map(|lt| &lt.lifetime);
                let ty = &self.ty;
                let (_, ty_generics, _) = generics.split_for_impl();

                quote! { for<#( #lifetimes ),*> #ty #ty_generics }
            } else {
                quote! { Self }
            };
            generics
                .make_where_clause()
                .predicates
                .push(parse_quote! { #self_ty: ::core::marker::Sync });

            if scalar.is_generic() {
                generics
                    .make_where_clause()
                    .predicates
                    .push(parse_quote! { #scalar: ::core::marker::Send + ::core::marker::Sync });
            }
        }

        let (impl_generics, _, where_clause) = generics.split_for_impl();
        (
            quote! { #impl_generics },
            quote! { #ty_full },
            where_clause.cloned(),
        )
    }

    /// Generates the `::juniper::marker::GraphQLUnion` implementation.
    ///
    /// The generated `mark()` calls `GraphQLObject::mark()` for every variant type and, when
    /// there is more than one variant, asserts via `::juniper::sa::assert_type_ne_all!` that
    /// all the variant types differ.
    #[must_use]
    fn impl_graphql_union_tokens(&self) -> TokenStream {
        let scalar = &self.scalar;

        let (impl_generics, ty_full, where_clause) = self.impl_generics(false);

        let variant_tys: Vec<_> = self.variants.iter().map(|var| &var.ty).collect();
        let all_variants_unique = (variant_tys.len() > 1).then(|| {
            quote! { ::juniper::sa::assert_type_ne_all!(#( #variant_tys ),*); }
        });

        quote! {
            #[automatically_derived]
            impl #impl_generics ::juniper::marker::GraphQLUnion<#scalar> for #ty_full #where_clause
            {
                fn mark() {
                    #all_variants_unique
                    #( <#variant_tys as ::juniper::marker::GraphQLObject<#scalar>>::mark(); )*
                }
            }
        }
    }

    /// Generates the `::juniper::marker::IsOutputType` implementation, whose `mark()` calls
    /// `IsOutputType::mark()` for every variant type.
    #[must_use]
    fn impl_output_type_tokens(&self) -> TokenStream {
        let scalar = &self.scalar;

        let (impl_generics, ty_full, where_clause) = self.impl_generics(false);

        let variant_tys = self.variants.iter().map(|var| &var.ty);

        quote! {
            #[automatically_derived]
            impl #impl_generics ::juniper::marker::IsOutputType<#scalar> for #ty_full #where_clause
            {
                fn mark() {
                    #( <#variant_tys as ::juniper::marker::IsOutputType<#scalar>>::mark(); )*
                }
            }
        }
    }

    /// Generates the `::juniper::GraphQLType` implementation.
    ///
    /// The generated `name()` returns this union's name, and `meta()` registers every variant
    /// type and builds a union meta type out of them, with the description applied if present.
    #[must_use]
    fn impl_graphql_type_tokens(&self) -> TokenStream {
        let scalar = &self.scalar;

        let (impl_generics, ty_full, where_clause) = self.impl_generics(false);

        let name = &self.name;
        let description = &self.description;

        let variant_tys = self.variants.iter().map(|var| &var.ty);

        quote! {
            #[automatically_derived]
            impl #impl_generics ::juniper::GraphQLType<#scalar> for #ty_full #where_clause
            {
                fn name(_ : &Self::TypeInfo) -> ::core::option::Option<::juniper::ArcStr> {
                    ::core::option::Option::Some(::juniper::arcstr::literal!(#name))
                }

                fn meta(
                    info: &Self::TypeInfo,
                    registry: &mut ::juniper::Registry<#scalar>,
                ) -> ::juniper::meta::MetaType<#scalar> {
                    let types = [
                        #( registry.get_type::<#variant_tys>(info), )*
                    ];
                    registry.build_union_type::<#ty_full>(info, &types)
                        #description
                        .into_meta()
                }
            }
        }
    }

    /// Generates the `::juniper::GraphQLValue` implementation.
    ///
    /// The generated `concrete_type_name()` tries every variant's check in order and panics if
    /// none matches; the generated `resolve_into_type()` tries every variant's resolver in order
    /// and returns an error if `type_name` matches none of them.
    #[must_use]
    fn impl_graphql_value_tokens(&self) -> TokenStream {
        let scalar = &self.scalar;
        let context = &self.context;

        let (impl_generics, ty_full, where_clause) = self.impl_generics(false);

        let name = &self.name;

        let match_variant_names = self
            .variants
            .iter()
            .map(|v| v.method_concrete_type_name_tokens(scalar));

        let variant_resolvers = self
            .variants
            .iter()
            .map(|v| v.method_resolve_into_type_tokens(scalar));

        quote! {
            #[automatically_derived]
            impl #impl_generics ::juniper::GraphQLValue<#scalar> for #ty_full #where_clause
            {
                type Context = #context;
                type TypeInfo = ();

                fn type_name(
                    &self,
                    info: &Self::TypeInfo,
                ) -> ::core::option::Option<::juniper::ArcStr> {
                    <Self as ::juniper::GraphQLType<#scalar>>::name(info)
                }

                fn concrete_type_name(
                    &self,
                    context: &Self::Context,
                    info: &Self::TypeInfo,
                ) -> ::std::string::String {
                    #( #match_variant_names )*
                    ::core::panic!(
                        "GraphQL union `{}` cannot be resolved into any of its \
                         variants in its current state",
                        #name,
                    );
                }

                fn resolve_into_type(
                    &self,
                    info: &Self::TypeInfo,
                    type_name: &::core::primitive::str,
                    _: ::core::option::Option<&[::juniper::Selection<'_, #scalar>]>,
                    executor: &::juniper::Executor<'_, '_, Self::Context, #scalar>,
                ) -> ::juniper::ExecutionResult<#scalar> {
                    let context = executor.context();
                    #( #variant_resolvers )*
                    return ::core::result::Result::Err(::juniper::FieldError::from(::std::format!(
                        "Concrete type `{}` is not handled by instance \
                         resolvers on GraphQL union `{}`",
                        type_name, #name,
                    )));
                }
            }
        }
    }

    /// Generates the `::juniper::GraphQLValueAsync` implementation.
    ///
    /// The generated `resolve_into_type_async()` tries every variant's asynchronous resolver in
    /// order and returns an error future if `type_name` matches none of them.
    #[must_use]
    fn impl_graphql_value_async_tokens(&self) -> TokenStream {
        let scalar = &self.scalar;

        let (impl_generics, ty_full, where_clause) = self.impl_generics(true);

        let name = &self.name;

        let variant_async_resolvers = self
            .variants
            .iter()
            .map(|v| v.method_resolve_into_type_async_tokens(scalar));

        quote! {
            #[allow(non_snake_case)]
            #[automatically_derived]
            impl #impl_generics ::juniper::GraphQLValueAsync<#scalar> for #ty_full #where_clause
            {
                fn resolve_into_type_async<'b>(
                    &'b self,
                    info: &'b Self::TypeInfo,
                    type_name: &::core::primitive::str,
                    _: ::core::option::Option<&'b [::juniper::Selection<'b, #scalar>]>,
                    executor: &'b ::juniper::Executor<'b, 'b, Self::Context, #scalar>
                ) -> ::juniper::BoxFuture<'b, ::juniper::ExecutionResult<#scalar>> {
                    let context = executor.context();
                    #( #variant_async_resolvers )*
                    return ::juniper::macros::helper::err_fut(::std::format!(
                        "Concrete type `{}` is not handled by instance \
                         resolvers on GraphQL union `{}`",
                        type_name, #name,
                    ));
                }
            }
        }
    }

    /// Generates the `::juniper::macros::reflect` trait implementations: `BaseType` (the
    /// union's name), `BaseSubTypes` (the union's own name followed by the names of all its
    /// variants) and `WrappedType` (the constant `1`).
    #[must_use]
    pub(crate) fn impl_reflection_traits_tokens(&self) -> TokenStream {
        let scalar = &self.scalar;
        let name = &self.name;
        let variants = self.variants.iter().map(|var| &var.ty);
        let (impl_generics, ty, where_clause) = self.impl_generics(false);

        quote! {
            #[automatically_derived]
            impl #impl_generics ::juniper::macros::reflect::BaseType<#scalar>
                for #ty
                #where_clause
            {
                const NAME: ::juniper::macros::reflect::Type = #name;
            }

            #[automatically_derived]
            impl #impl_generics ::juniper::macros::reflect::BaseSubTypes<#scalar>
                for #ty
                #where_clause
            {
                const NAMES: ::juniper::macros::reflect::Types = &[
                    <Self as ::juniper::macros::reflect::BaseType<#scalar>>::NAME,
                    #(<#variants as ::juniper::macros::reflect::BaseType<#scalar>>::NAME),*
                ];
            }

            #[automatically_derived]
            impl #impl_generics ::juniper::macros::reflect::WrappedType<#scalar>
                for #ty
                #where_clause
            {
                const VALUE: ::juniper::macros::reflect::WrappedValue = 1;
            }
        }
    }
}
/// Everything needed to generate the code for a single GraphQL union variant.
struct VariantDefinition {
/// Rust type of this variant.
ty: syn::Type,
/// Expression producing this variant's value; it is spliced into the generated
/// `resolve_into_type()` / `resolve_into_type_async()` bodies.
resolver_code: syn::Expr,
/// Boolean expression spliced into the generated `concrete_type_name()` body as an `if`
/// condition, telling whether the union currently holds this variant.
resolver_check: syn::Expr,
/// Optional context type associated with this variant.
///
/// NOTE(review): not read anywhere in this file — presumably consumed by the code building
/// the `Definition`; confirm against callers.
context: Option<syn::Type>,
}
impl VariantDefinition {
    /// Renders this variant's Rust type as a string, for use in generated error reporting.
    fn ty_display(&self) -> String {
        self.ty.to_token_stream().to_string()
    }

    /// Generates the fragment of `GraphQLValue::concrete_type_name()` for this variant: if the
    /// variant's check expression holds, the name of the variant's type is returned.
    ///
    /// The generated code expects an `info` binding to be in scope.
    #[must_use]
    fn method_concrete_type_name_tokens(&self, scalar: &scalar::Type) -> TokenStream {
        let Self {
            ty: variant_ty,
            resolver_check: is_this_variant,
            ..
        } = self;

        quote! {
            if #is_this_variant {
                return <#variant_ty as ::juniper::GraphQLType<#scalar>>::name(info)
                    .unwrap()
                    .to_string();
            }
        }
    }

    /// Generates the fragment of `GraphQLValue::resolve_into_type()` for this variant: if
    /// `type_name` equals the name of the variant's type, the variant is resolved and returned.
    ///
    /// The generated code expects `info` and `type_name` bindings to be in scope, and fails
    /// with `err_unnamed_type` if the variant's type has no name.
    #[must_use]
    fn method_resolve_into_type_tokens(&self, scalar: &scalar::Type) -> TokenStream {
        let variant_ty = &self.ty;
        let variant_ty_str = self.ty_display();
        let resolver = &self.resolver_code;
        let resolve_sync = generate::sync_resolving_code();

        quote! {
            if type_name == <#variant_ty as ::juniper::GraphQLType<#scalar>>::name(info)
                .ok_or_else(|| ::juniper::macros::helper::err_unnamed_type(#variant_ty_str))?
            {
                let res = { #resolver };
                return #resolve_sync;
            }
        }
    }

    /// Generates the fragment of `GraphQLValueAsync::resolve_into_type_async()` for this
    /// variant: if `type_name` equals the name of the variant's type, the variant is resolved
    /// asynchronously and returned.
    ///
    /// The generated code expects `info` and `type_name` bindings to be in scope, and returns
    /// `err_unnamed_type_fut` if the variant's type has no name.
    #[must_use]
    fn method_resolve_into_type_async_tokens(&self, scalar: &scalar::Type) -> TokenStream {
        let variant_ty = &self.ty;
        let variant_ty_str = self.ty_display();
        let resolver = &self.resolver_code;
        let resolve_async = generate::async_resolving_code(None);

        quote! {
            match <#variant_ty as ::juniper::GraphQLType<#scalar>>::name(info) {
                ::core::option::Option::Some(name) => {
                    if type_name == name {
                        let fut = ::juniper::futures::future::ready({ #resolver });
                        return #resolve_async;
                    }
                }
                ::core::option::Option::None => {
                    return ::juniper::macros::helper::err_unnamed_type_fut(#variant_ty_str);
                }
            }
        }
    }
}
/// Merges the `external_resolvers` collected from attributes into the given `variants`.
///
/// For every `(type, resolver function)` pair:
/// - if a variant of that type already exists, its resolver code and check are overwritten;
/// - otherwise a new variant of that type is appended, without a context.
///
/// The resolver function is invoked as
/// `resolver(self, ::juniper::FromContext::from(context))`.
///
/// NOTE: `external_resolvers` is a `HashMap`, so the order in which new variants are appended
/// is unspecified.
fn emerge_union_variants_from_attr(
    variants: &mut Vec<VariantDefinition>,
    external_resolvers: AttrResolvers,
) {
    // An empty map simply yields no iterations, so no separate emptiness guard is required.
    for (ty, resolver) in external_resolvers {
        let resolver_fn = resolver.into_inner();

        let resolver_code = parse_quote! {
            #resolver_fn(self, ::juniper::FromContext::from(context))
        };
        // The check embeds the whole resolver call, so in the generated code the resolver runs
        // once for the check and once more for the actual resolution.
        let resolver_check = parse_quote! {
            ({ #resolver_code } as ::core::option::Option<&#ty>).is_some()
        };

        match variants.iter_mut().find(|existing| existing.ty == ty) {
            Some(existing) => {
                existing.resolver_code = resolver_code;
                existing.resolver_check = resolver_check;
            }
            None => variants.push(VariantDefinition {
                ty,
                resolver_code,
                resolver_check,
                context: None,
            }),
        }
    }
}
/// Checks whether every variant in `variants` has a distinct Rust type.
///
/// Returns `true` for an empty slice. Duplicates are detected regardless of their position in
/// the slice: `Vec::dedup()` cannot be used here, because it removes *consecutive* duplicates
/// only and so would miss a repeated type in `[A, B, A]`.
fn all_variants_different(variants: &[VariantDefinition]) -> bool {
    let mut seen = ::std::collections::HashSet::with_capacity(variants.len());
    // `HashSet::insert()` returns `false` once a type has already been seen.
    variants.iter().all(|var| seen.insert(&var.ty))
}