graphql_extract/
lib.rs

//! Macro to extract data from deeply nested types representing GraphQL results
//!
//! # Suggested workflow
//!
//! 1. Generate query types using [cynic] and its [generator]
//! 1. Use [insta] to define an inline snapshot test so that the query string is visible in the
//!    module that defines the query types
//! 1. Define an `extract` function that takes the root query type and returns the data of interest
//! 1. Inside `extract`, use [`extract!`](crate::extract!) as `extract!(data => { ... })`
//! 1. Inside the curly braces, paste the query string from the snapshot test above
//! 1. Change all node names from `camelCase` to `snake_case`
//! 1. Add `?` after the nodes that are nullable
//! 1. Add `[]` after the nodes that are iterable
//!
//! [cynic]: https://cynic-rs.dev/
//! [generator]: https://generator.cynic-rs.dev/
//! [insta]: https://insta.rs/
18
19use proc_macro2::{Span, TokenStream};
20use quote::{ToTokens as _, quote};
21use syn::parse::{Parse, ParseStream};
22use syn::spanned::Spanned as _;
23use syn::token::{self, Brace};
24use syn::{Error, Ident, Token, braced, bracketed, parse_macro_input, parse_quote};
25
26/// See the top-level [`crate`] doc for a description.
27#[proc_macro]
28pub fn extract(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
29    let root = parse_macro_input!(input as Root);
30    let stmt = root.generate_extract();
31    stmt.into()
32}
33
34struct Root {
35    expr: syn::Expr,
36    nested: Nested,
37}
38
39struct Node {
40    ident: Ident,
41    alias: Option<Ident>,
42    optional: bool,
43    iterable: bool,
44    nested: Option<Nested>,
45}
46
47enum Nested {
48    Nodes(Vec<Node>),
49    Variant(Variant),
50}
51
52struct Variant {
53    path: syn::Path,
54    nodes: Vec<Node>,
55}
56
57//=================================================================================================
58// Parsing
59//=================================================================================================
60
61impl Parse for Root {
62    fn parse(input: ParseStream) -> syn::Result<Self> {
63        let expr = input.parse()?;
64        let _: Token![=>] = input.parse()?;
65        let nested = input.parse()?;
66        Ok(Self { expr, nested })
67    }
68}
69
70impl Parse for Node {
71    fn parse(input: ParseStream) -> syn::Result<Self> {
72        let mut self_ = Self {
73            ident: input.parse()?,
74            alias: None,
75            optional: false,
76            iterable: false,
77            nested: None,
78        };
79
80        // Caller is allowed to set an alias like `alias: node`
81        let lookahead = input.lookahead1();
82        if lookahead.peek(Token![:]) {
83            let _: Token![:] = input.parse()?;
84            self_.alias = Some(self_.ident);
85            self_.ident = input.parse()?;
86        }
87
88        while !input.is_empty() {
89            let lookahead = input.lookahead1();
90            if lookahead.peek(Ident) {
91                break; // There's another field to be parsed
92            } else if lookahead.peek(Token![?]) {
93                let question: Token![?] = input.parse()?;
94                if self_.optional {
95                    return Err(Error::new_spanned(
96                        question,
97                        "Can't have two `?` for the same node",
98                    ));
99                }
100                self_.optional = true;
101            } else if lookahead.peek(token::Bracket) {
102                let content;
103                let bracket = bracketed!(content in input);
104                if self_.iterable {
105                    return Err(Error::new(
106                        bracket.span.span(),
107                        "Can't have two `[]` for the same node",
108                    ));
109                }
110                if !content.is_empty() {
111                    return Err(Error::new(
112                        bracket.span.span(),
113                        "Only empty brackets allowed",
114                    ));
115                }
116                self_.iterable = true;
117            } else if lookahead.peek(token::Brace) {
118                let nested = input.parse()?;
119                self_.nested = Some(nested);
120                break; // Everything after the closing brace is ignored
121            } else {
122                return Err(lookahead.error());
123            }
124        }
125
126        Ok(self_)
127    }
128}
129
130impl Node {
131    fn within_braces(brace: Brace, content: ParseStream) -> syn::Result<Vec<Self>> {
132        let mut nodes = vec![];
133        while !content.is_empty() {
134            let lookahead = content.lookahead1();
135            if lookahead.peek(Token![...]) {
136                return Err(Error::new(
137                    brace.span.span(),
138                    "Nodes can't be mixed with '... on Variant' matches",
139                ));
140            }
141            nodes.push(content.parse()?);
142        }
143        if nodes.is_empty() {
144            return Err(Error::new(
145                brace.span.span(),
146                "Empty braces; must have at least one node",
147            ));
148        }
149        Ok(nodes)
150    }
151}
152
153impl Parse for Nested {
154    fn parse(input: ParseStream) -> syn::Result<Self> {
155        let content;
156        let brace = braced!(content in input);
157
158        let lookahead = content.lookahead1();
159        Ok(if lookahead.peek(Token![...]) {
160            let var = Self::Variant(content.parse()?);
161            if !content.is_empty() {
162                return Err(Error::new(
163                    brace.span.span(),
164                    "Only a single '... on Variant' match is supported within the same braces",
165                ));
166            }
167            var
168        } else {
169            Self::Nodes(Node::within_braces(brace, &content)?)
170        })
171    }
172}
173
174impl Parse for Variant {
175    fn parse(input: ParseStream) -> syn::Result<Self> {
176        input.parse::<Token![...]>()?;
177        let on: Ident = input.parse()?;
178        if on != "on" {
179            return Err(Error::new(on.span(), "Expected 'on'"));
180        }
181        let path = input.parse()?;
182        let content;
183        let brace = braced!(content in input);
184        Ok(Self {
185            path,
186            nodes: Node::within_braces(brace, &content)?,
187        })
188    }
189}
190
191//=================================================================================================
192// Generation
193//=================================================================================================
194
195impl Root {
196    fn generate_extract(self) -> TokenStream {
197        let Self { expr, nested, .. } = self;
198        let data = Ident::new("data", Span::mixed_site());
199        let err = data.to_string() + " is null";
200        let (pats, tokens): (Vec<_>, Vec<_>) =
201            nested.generate_extract(data.clone(), data.to_string());
202        quote! {
203            let #data = ( #expr ).ok_or::<&'static str>(#err)?;
204            let ( #(#pats),* ) = {
205                #(#tokens)*
206                ( #(#pats),* )
207            };
208        }
209    }
210}
211
212impl Node {
213    fn generate_extract(self, data: Ident, path: String) -> (syn::Pat, TokenStream) {
214        let Self {
215            ident,
216            alias,
217            optional,
218            iterable,
219            nested,
220        } = self;
221        let field = &ident;
222        let ident = alias.as_ref().unwrap_or(&ident);
223
224        let path = path + " -> " + ident.to_string().as_str();
225
226        let assign = if optional {
227            let err = path.clone() + " is null";
228            quote!(let #ident = #data.#field.ok_or::<&'static str>(#err)?;)
229        } else {
230            quote!(let #ident = #data.#field;)
231        };
232
233        let Some(inner) = nested else {
234            return (parse_quote!(#ident), assign);
235        };
236
237        let (pats, tokens) = inner.generate_extract(ident.clone(), path);
238        let (pat, tokens_);
239        // TODO: consider
240        // - verifying that no nested `[]` exist
241        // - detecting any `?` in the subtree and setting the return type accordingly
242        if iterable {
243            pat = parse_quote!(#ident);
244            tokens_ = quote! {
245                #assign
246                let #ident = #ident.into_iter().map(|#ident| -> Result<_, &'static str> {
247                    #(#tokens)*
248                    Ok(( #(#pats),* ))
249                });
250            };
251        } else {
252            pat = parse_quote!( (#(#pats),*) );
253            tokens_ = quote! {
254                #assign
255                let ( #(#pats),* ) = {
256                    #(#tokens)*
257                    ( #(#pats),* )
258                };
259            };
260        }
261        (pat, tokens_)
262    }
263}
264
265impl Nested {
266    fn generate_extract(self, data: Ident, path: String) -> (Vec<syn::Pat>, Vec<TokenStream>) {
267        match self {
268            Self::Nodes(nodes) => nodes
269                .into_iter()
270                .map(|n| n.generate_extract(data.clone(), path.clone()))
271                .unzip(),
272            Self::Variant(Variant { path: var, nodes }) => {
273                let path = path + " ... on " + var.to_token_stream().to_string().as_str();
274                let err = path.clone() + " is null";
275                let val = Ident::new("val", Span::mixed_site());
276                let assign = quote! {
277                    let #var(#val) = #data else {
278                        return Err(#err);
279                    };
280                };
281
282                let mut tokens_ = vec![assign];
283                let (pats, tokens): (Vec<_>, Vec<_>) = nodes
284                    .into_iter()
285                    .map(|n| n.generate_extract(val.clone(), path.clone()))
286                    .unzip();
287                tokens_.extend(tokens);
288                (pats, tokens_)
289            }
290        }
291    }
292}