1use std::ops::Deref;
4use std::{cmp::Ordering, collections::HashSet, iter::FromIterator};
5
6use darling::util::PathList;
7use darling::FromMeta;
8use proc_macro2::TokenStream;
9use quote::quote;
10use syn::{
11 parse::{Parse, ParseStream},
12 punctuated::Punctuated,
13 AttributeArgs, GenericArgument, Token, TypePath,
14};
15
/// Options accepted by the `newtype` attribute, parsed with darling's `FromMeta`.
///
/// Every option may be omitted: flags then default to `false` and valued options
/// to `None`.
#[doc(hidden)]
#[derive(Debug, Default, FromMeta)]
pub struct Attrs {
    /// Generate a `new` constructor and a `From<Wrapped>` impl. The constructor is
    /// a `const fn` when `copy` is also set.
    #[darling(default)]
    new: bool,
    /// Add `#[derive(Copy)]` to the generated struct.
    #[darling(default)]
    copy: bool,
    /// Suppress the `Deref` impl and the `into_inner` accessor.
    #[darling(default)]
    opaque: bool,
    /// Derive serde's `Deserialize`/`Serialize`. Uses `#[serde(transparent)]`, or
    /// `#[serde(try_from = ...)]` when `try_from` is set.
    #[darling(default)]
    serde: bool,
    /// Derive `sqlx::Type` with `#[sqlx(transparent, ...)]` instead of emitting
    /// `#[repr(transparent)]`.
    #[darling(default)]
    sqlx: bool,
    /// Invoke `async_graphql::scalar!` for the generated type.
    #[darling(default)]
    async_graphql: bool,
    /// Target type of the generated `Deref` impl; defaults to the wrapped type.
    #[darling(default)]
    borrow: Option<syn::Path>,
    /// String emitted as `#[serde(try_from = "...")]`; only read when `serde` is set.
    #[darling(default)]
    try_from: Option<syn::LitStr>,
    /// Implement `Display` by delegating to the wrapped value.
    #[darling(default)]
    display: bool,

    /// Replaces the default derive list
    /// (`Debug, Clone, PartialEq, Eq, PartialOrd, Ord, core::hash::Hash`).
    #[darling(default)]
    derive: Option<PathList>,
}
41
42fn pointy_bits(ty: &syn::Type) -> Punctuated<GenericArgument, Token![,]> {
43 let set = match ty {
44 syn::Type::Path(path) => path
45 .path
46 .segments
47 .iter()
48 .map(|x| match &x.arguments {
49 syn::PathArguments::AngleBracketed(a) => {
50 a.args.iter().map(|x| x).cloned().collect()
51 }
52 syn::PathArguments::Parenthesized(_) => vec![],
53 syn::PathArguments::None => vec![],
54 })
55 .flatten()
56 .collect::<HashSet<_>>(),
57 _ => Default::default(),
58 };
59
60 let mut vec = set.into_iter().collect::<Vec<_>>();
61 vec.sort_by(|a, b| {
62 if a == b {
63 return Ordering::Equal;
64 }
65
66 match (a, b) {
67 (GenericArgument::Lifetime(_), _) => Ordering::Greater,
68 (GenericArgument::Type(_), GenericArgument::Lifetime(_)) => Ordering::Less,
69 (GenericArgument::Type(_), GenericArgument::Const(_)) => Ordering::Greater,
70 (GenericArgument::Const(_), _) => Ordering::Less,
71 _ => Ordering::Less,
72 }
73 });
74
75 Punctuated::from_iter(vec.into_iter())
76}
77
/// Serde-specific options; currently only an alternative path for the `serde` crate.
///
/// NOTE(review): nothing in the visible code constructs or reads this struct, and
/// the code generator hard-codes the `serde` path, so `crate_` has no effect yet.
#[doc(hidden)]
#[derive(Debug, Default, FromMeta)]
pub struct SerdeAttrs {
    /// Path to the serde crate, given as `crate = "..."` (renamed because `crate`
    /// is a keyword). Unused for now, hence the `dead_code` allowance.
    #[allow(dead_code)]
    #[darling(default, rename = "crate")]
    crate_: Option<syn::Path>,
}
85
86fn do_newtype(mut attrs: Attrs, item: Item) -> Result<TokenStream, syn::Error> {
87 let Item {
88 visibility,
89 new_ty,
90 wrapped_ty,
91 } = item;
92
93 let borrow_ty = attrs
94 .borrow
95 .take()
96 .map(|path| syn::Type::Path(TypePath { qself: None, path }))
97 .unwrap_or_else(|| wrapped_ty.clone());
98
99 let copy = if attrs.copy {
100 Some(quote! {
101 #[derive(Copy)]
102 })
103 } else {
104 None
105 };
106
107 let serde = if attrs.serde {
108 let serde_path: syn::Path = syn::parse_quote! { serde };
109 Some(match attrs.try_from.as_ref() {
110 Some(path) => {
111 quote! {
112 #[derive(#serde_path::Deserialize, #serde_path::Serialize)]
113 #[serde(try_from = #path)]
114 }
115 }
116 None => quote! {
117 #[derive(#serde_path::Deserialize, #serde_path::Serialize)]
118 #[serde(transparent)]
119 },
120 })
121 } else {
122 None
123 };
124
125 let sqlx = if attrs.sqlx {
126 let segments = match &wrapped_ty {
127 syn::Type::Path(p) => &p.path.segments,
128 _ => panic!("Ahhhh"),
129 };
130
131 let sql_type_literal = match &*segments.last().unwrap().ident.to_string() {
132 "u128" | "i128" | "Uuid" => "UUID",
133 "u64" | "i64" => "INT8",
134 "u32" | "i32" => "INT4",
135 "u16" | "i16" | "u8" | "i8" => "INT2",
136 "bool" => "BOOL",
137 _ => "",
138 };
139
140 let sql_type_literal = if sql_type_literal != "" {
141 quote! { #[sqlx(transparent, type_name = #sql_type_literal)] }
142 } else {
143 quote! { #[sqlx(transparent)] }
144 };
145
146 quote! {
147 #[derive(sqlx::Type)]
148 #sql_type_literal
149 }
150 } else {
151 quote! {
153 #[repr(transparent)]
154 }
155 };
156
157 let async_graphql = if attrs.async_graphql {
158 Some(quote! {
159 async_graphql::scalar!(#new_ty);
160 })
161 } else {
162 None
163 };
164
165 let pointy_bits = pointy_bits(&new_ty);
166 let pointy = quote!( < #pointy_bits > );
167
168 let deref = if attrs.opaque {
169 None
170 } else {
171 Some(quote! {
172 impl #pointy core::ops::Deref for #new_ty {
173 type Target = #borrow_ty;
174
175 fn deref(&self) -> &Self::Target {
176 &self.0
177 }
178 }
179
180 impl #pointy #new_ty {
181 #[allow(dead_code)]
182 pub fn into_inner(self) -> #wrapped_ty {
183 self.0
184 }
185 }
186 })
187 };
188
189 let new = if attrs.new {
190 let consty = if attrs.copy {
191 Some(quote! { const })
192 } else {
193 None
194 };
195 Some(quote! {
196 impl #pointy #new_ty {
197 pub #consty fn new(input: #wrapped_ty) -> Self {
198 Self(input)
199 }
200 }
201
202 impl #pointy From<#wrapped_ty> for #new_ty {
203 fn from(x: #wrapped_ty) -> Self {
204 Self(x)
205 }
206 }
207 })
208 } else {
209 None
210 };
211
212 let trait_impl = quote! {
213 impl #pointy ::nova::NewType for #new_ty {
214 type Inner = #wrapped_ty;
215 }
216 };
217
218 let display = if attrs.display {
219 Some(quote! {
220 impl #pointy core::fmt::Display for #new_ty {
221 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
222 core::fmt::Display::fmt(&self.0, f)
223 }
224 }
225 })
226 } else {
227 None
228 };
229
230 let derives = if let Some(custom_derives) = attrs.derive {
231 let paths = custom_derives.deref().clone();
232 quote! { #[derive( #(#paths),*)]}
233 } else {
234 quote! { #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, core::hash::Hash)]}
235 };
236 let out = quote! {
237 #derives
238 #copy
239 #serde
240 #sqlx
241 #visibility struct #new_ty(#wrapped_ty);
242 #async_graphql
243 #deref
244 #new
245 #trait_impl
246 #display
247 };
248
249 Ok(out)
250}
251
252#[doc(hidden)]
253pub fn newtype(attrs: AttributeArgs, item: TokenStream) -> Result<TokenStream, syn::Error> {
254 let attrs = match Attrs::from_list(&attrs) {
255 Ok(v) => v,
256 Err(e) => {
257 return Ok(TokenStream::from(e.write_errors()));
258 }
259 };
260
261 let item: Item = syn::parse2(item.clone())?;
262
263 do_newtype(attrs, item)
264}
265
/// Parsed form of the item the macro is applied to: `[pub ...] type NewTy = WrappedTy;`.
#[derive(Debug)]
struct Item {
    /// Visibility written before `type`; `Visibility::Inherited` when no `pub` was given.
    visibility: syn::Visibility,
    /// Left-hand side of the alias; becomes the name (and generics) of the generated struct.
    new_ty: syn::Type,
    /// Right-hand side of the alias; the single field of the generated tuple struct.
    wrapped_ty: syn::Type,
}
272
273impl Parse for Item {
274 fn parse(input: ParseStream) -> Result<Self, syn::Error> {
275 let lookahead = input.lookahead1();
276
277 let visibility = if lookahead.peek(Token![pub]) {
278 let visibility: syn::Visibility = input.call(syn::Visibility::parse)?;
279 visibility
280 } else {
281 syn::Visibility::Inherited
282 };
283
284 let _: Token![type] = input.parse()?;
285
286 let new_ty: syn::Type = input.parse()?;
287 let _: Token![=] = input.parse()?;
288 let wrapped_ty: syn::Type = input.parse()?;
289 let _: Token![;] = input.parse()?;
290
291 Ok(Item {
294 visibility,
295 new_ty,
296 wrapped_ty,
297 })
298 }
299}
300
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: expands a few representative aliases and prints the generated
    /// tokens. It only checks that each expansion returns `Ok`.
    #[test]
    fn example() {
        let cases: Vec<(AttributeArgs, TokenStream)> = vec![
            (
                vec![syn::parse_quote!(copy)],
                quote! { pub(crate) type Hello = u8; },
            ),
            (
                vec![syn::parse_quote!(copy)],
                quote! { pub(in super) type SpecialUuid = uuid::Uuid; },
            ),
            (
                vec![syn::parse_quote!(new), syn::parse_quote!(borrow = "str")],
                quote! { pub(in super) type S<'a> = std::borrow::Cow<'a, str>; },
            ),
        ];

        for (args, item) in cases {
            let expanded = newtype(args, item).unwrap();
            println!("{:?}", expanded);
        }
    }
}