nova_impl/
lib.rs

1//! This crate implements the macro for `nova` and should not be used directly.
2
3use std::ops::Deref;
4use std::{cmp::Ordering, collections::HashSet, iter::FromIterator};
5
6use darling::util::PathList;
7use darling::FromMeta;
8use proc_macro2::TokenStream;
9use quote::quote;
10use syn::{
11    parse::{Parse, ParseStream},
12    punctuated::Punctuated,
13    AttributeArgs, GenericArgument, Token, TypePath,
14};
15
/// Options accepted by the newtype attribute macro, parsed from its arguments
/// with darling (e.g. `copy`, `new`, `borrow = "str"`).
///
/// Every option is marked `#[darling(default)]`, so omitted options fall back to
/// `false` / `None`.
#[doc(hidden)]
#[derive(Debug, Default, FromMeta)]
pub struct Attrs {
    /// Generate a `new` constructor and a `From<Inner>` impl.
    #[darling(default)]
    new: bool,
    /// Add `#[derive(Copy)]`; this also makes the generated `new` a `const fn`.
    #[darling(default)]
    copy: bool,
    /// Hide the inner value: no `Deref` impl and no `into_inner` method.
    #[darling(default)]
    opaque: bool,
    /// Derive serde's `Serialize`/`Deserialize` (as `transparent`, or via
    /// `try_from` when that option is set).
    #[darling(default)]
    serde: bool,
    /// Derive `sqlx::Type` as a transparent wrapper; the struct then does not get
    /// `#[repr(transparent)]`.
    #[darling(default)]
    sqlx: bool,
    /// Register the newtype with `async_graphql::scalar!`.
    #[darling(default)]
    async_graphql: bool,
    /// Type to use as `Deref::Target` instead of the wrapped type
    /// (e.g. `str` for a `Cow<'a, str>` wrapper).
    #[darling(default)]
    borrow: Option<syn::Path>,
    /// Only used together with `serde`: emit `#[serde(try_from = "...")]` instead
    /// of `#[serde(transparent)]`.
    #[darling(default)]
    try_from: Option<syn::LitStr>,
    /// Implement `Display` by forwarding to the wrapped value.
    #[darling(default)]
    display: bool,

    /// Replaces the default derive list
    /// (`Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash`) entirely.
    #[darling(default)]
    derive: Option<PathList>,
}
41
42fn pointy_bits(ty: &syn::Type) -> Punctuated<GenericArgument, Token![,]> {
43    let set = match ty {
44        syn::Type::Path(path) => path
45            .path
46            .segments
47            .iter()
48            .map(|x| match &x.arguments {
49                syn::PathArguments::AngleBracketed(a) => {
50                    a.args.iter().map(|x| x).cloned().collect()
51                }
52                syn::PathArguments::Parenthesized(_) => vec![],
53                syn::PathArguments::None => vec![],
54            })
55            .flatten()
56            .collect::<HashSet<_>>(),
57        _ => Default::default(),
58    };
59
60    let mut vec = set.into_iter().collect::<Vec<_>>();
61    vec.sort_by(|a, b| {
62        if a == b {
63            return Ordering::Equal;
64        }
65
66        match (a, b) {
67            (GenericArgument::Lifetime(_), _) => Ordering::Greater,
68            (GenericArgument::Type(_), GenericArgument::Lifetime(_)) => Ordering::Less,
69            (GenericArgument::Type(_), GenericArgument::Const(_)) => Ordering::Greater,
70            (GenericArgument::Const(_), _) => Ordering::Less,
71            _ => Ordering::Less,
72        }
73    });
74
75    Punctuated::from_iter(vec.into_iter())
76}
77
/// Options for the serde integration, in the same darling `FromMeta` form as
/// [`Attrs`].
///
/// NOTE(review): nothing in this file constructs or reads `SerdeAttrs`, and the
/// expansion hard-codes the `serde` path. This looks like groundwork for a
/// configurable serde crate path; confirm before relying on it.
#[doc(hidden)]
#[derive(Debug, Default, FromMeta)]
pub struct SerdeAttrs {
    /// Path to the serde crate, set via the `crate` key. The field is named
    /// `crate_` because `crate` is a keyword; darling's `rename` maps the key.
    #[allow(dead_code)]
    #[darling(default, rename = "crate")]
    crate_: Option<syn::Path>,
}
85
86fn do_newtype(mut attrs: Attrs, item: Item) -> Result<TokenStream, syn::Error> {
87    let Item {
88        visibility,
89        new_ty,
90        wrapped_ty,
91    } = item;
92
93    let borrow_ty = attrs
94        .borrow
95        .take()
96        .map(|path| syn::Type::Path(TypePath { qself: None, path }))
97        .unwrap_or_else(|| wrapped_ty.clone());
98
99    let copy = if attrs.copy {
100        Some(quote! {
101            #[derive(Copy)]
102        })
103    } else {
104        None
105    };
106
107    let serde = if attrs.serde {
108        let serde_path: syn::Path = syn::parse_quote! { serde };
109        Some(match attrs.try_from.as_ref() {
110            Some(path) => {
111                quote! {
112                    #[derive(#serde_path::Deserialize, #serde_path::Serialize)]
113                    #[serde(try_from = #path)]
114                }
115            }
116            None => quote! {
117                #[derive(#serde_path::Deserialize, #serde_path::Serialize)]
118                #[serde(transparent)]
119            },
120        })
121    } else {
122        None
123    };
124
125    let sqlx = if attrs.sqlx {
126        let segments = match &wrapped_ty {
127            syn::Type::Path(p) => &p.path.segments,
128            _ => panic!("Ahhhh"),
129        };
130
131        let sql_type_literal = match &*segments.last().unwrap().ident.to_string() {
132            "u128" | "i128" | "Uuid" => "UUID",
133            "u64" | "i64" => "INT8",
134            "u32" | "i32" => "INT4",
135            "u16" | "i16" | "u8" | "i8" => "INT2",
136            "bool" => "BOOL",
137            _ => "",
138        };
139
140        let sql_type_literal = if sql_type_literal != "" {
141            quote! { #[sqlx(transparent, type_name = #sql_type_literal)] }
142        } else {
143            quote! { #[sqlx(transparent)] }
144        };
145
146        quote! {
147            #[derive(sqlx::Type)]
148            #sql_type_literal
149        }
150    } else {
151        // sqlx's derive interferes with a repr declaration, so we do it here.
152        quote! {
153            #[repr(transparent)]
154        }
155    };
156
157    let async_graphql = if attrs.async_graphql {
158        Some(quote! {
159            async_graphql::scalar!(#new_ty);
160        })
161    } else {
162        None
163    };
164
165    let pointy_bits = pointy_bits(&new_ty);
166    let pointy = quote!( < #pointy_bits > );
167
168    let deref = if attrs.opaque {
169        None
170    } else {
171        Some(quote! {
172            impl #pointy core::ops::Deref for #new_ty {
173                type Target = #borrow_ty;
174
175                fn deref(&self) -> &Self::Target {
176                    &self.0
177                }
178            }
179
180            impl #pointy #new_ty {
181                #[allow(dead_code)]
182                pub fn into_inner(self) -> #wrapped_ty {
183                    self.0
184                }
185            }
186        })
187    };
188
189    let new = if attrs.new {
190        let consty = if attrs.copy {
191            Some(quote! { const })
192        } else {
193            None
194        };
195        Some(quote! {
196            impl #pointy #new_ty {
197                pub #consty fn new(input: #wrapped_ty) -> Self {
198                    Self(input)
199                }
200            }
201
202            impl #pointy From<#wrapped_ty> for #new_ty {
203                fn from(x: #wrapped_ty) -> Self {
204                    Self(x)
205                }
206            }
207        })
208    } else {
209        None
210    };
211
212    let trait_impl = quote! {
213        impl #pointy ::nova::NewType for #new_ty {
214            type Inner = #wrapped_ty;
215        }
216    };
217
218    let display = if attrs.display {
219        Some(quote! {
220            impl #pointy core::fmt::Display for #new_ty {
221                fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
222                    core::fmt::Display::fmt(&self.0, f)
223                }
224            }
225        })
226    } else {
227        None
228    };
229
230    let derives = if let Some(custom_derives) = attrs.derive {
231        let paths = custom_derives.deref().clone();
232        quote! { #[derive( #(#paths),*)]}
233    } else {
234        quote! { #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, core::hash::Hash)]}
235    };
236    let out = quote! {
237        #derives
238        #copy
239        #serde
240        #sqlx
241        #visibility struct #new_ty(#wrapped_ty);
242        #async_graphql
243        #deref
244        #new
245        #trait_impl
246        #display
247    };
248
249    Ok(out)
250}
251
252#[doc(hidden)]
253pub fn newtype(attrs: AttributeArgs, item: TokenStream) -> Result<TokenStream, syn::Error> {
254    let attrs = match Attrs::from_list(&attrs) {
255        Ok(v) => v,
256        Err(e) => {
257            return Ok(TokenStream::from(e.write_errors()));
258        }
259    };
260
261    let item: Item = syn::parse2(item.clone())?;
262
263    do_newtype(attrs, item)
264}
265
/// The item the newtype attribute is applied to: `[pub ...] type NewType = WrappedType;`.
#[derive(Debug)]
struct Item {
    /// Visibility given to the generated struct; `Inherited` when none was written.
    visibility: syn::Visibility,
    /// The type being declared, including generic arguments such as `'a` in `S<'a>`.
    new_ty: syn::Type,
    /// The type stored in the newtype's single tuple field.
    wrapped_ty: syn::Type,
}
272
273impl Parse for Item {
274    fn parse(input: ParseStream) -> Result<Self, syn::Error> {
275        let lookahead = input.lookahead1();
276
277        let visibility = if lookahead.peek(Token![pub]) {
278            let visibility: syn::Visibility = input.call(syn::Visibility::parse)?;
279            visibility
280        } else {
281            syn::Visibility::Inherited
282        };
283
284        let _: Token![type] = input.parse()?;
285
286        let new_ty: syn::Type = input.parse()?;
287        let _: Token![=] = input.parse()?;
288        let wrapped_ty: syn::Type = input.parse()?;
289        let _: Token![;] = input.parse()?;
290
291        // println!("{:?}", input.cursor().token_stream().to_string());
292
293        Ok(Item {
294            visibility,
295            new_ty,
296            wrapped_ty,
297        })
298    }
299}
300
#[cfg(test)]
mod tests {
    use super::*;

    /// Expands `item` with the given attribute arguments and renders the result.
    fn expand(attrs: AttributeArgs, item: TokenStream) -> String {
        newtype(attrs, item).unwrap().to_string()
    }

    /// Asserts that `expanded` contains `expected`. Both are rendered through
    /// `TokenStream`'s `Display`, so their whitespace conventions match.
    fn assert_contains(expanded: &str, expected: TokenStream) {
        let expected = expected.to_string();
        assert!(
            expanded.contains(&expected),
            "expected `{}` in expansion:\n{}",
            expected,
            expanded
        );
    }

    #[test]
    fn copy_newtype_around_primitive() {
        let out = expand(
            vec![syn::parse_quote!(copy)],
            quote! { pub(crate) type Hello = u8; },
        );
        assert_contains(&out, quote!(struct Hello(u8);));
        assert_contains(&out, quote!(#[derive(Copy)]));
        assert_contains(&out, quote!(#[repr(transparent)]));
    }

    #[test]
    fn restricted_visibility_and_path_type() {
        let out = expand(
            vec![syn::parse_quote!(copy)],
            quote! { pub(in super) type SpecialUuid = uuid::Uuid; },
        );
        assert_contains(&out, quote!(struct SpecialUuid(uuid::Uuid);));
    }

    #[test]
    fn generic_newtype_with_borrow_target() {
        let out = expand(
            vec![syn::parse_quote!(new), syn::parse_quote!(borrow = "str")],
            quote! { pub(in super) type S<'a> = std::borrow::Cow<'a, str>; },
        );
        assert_contains(&out, quote!(type Target = str;));
        assert_contains(&out, quote!(fn new));
    }
}
334}