// branded_derive/lib.rs

1use darling::{FromDeriveInput, FromField};
2use proc_macro::TokenStream;
3use quote::quote;
4
/// Options parsed by darling from a `#[derive(Branded)]` input and its `#[branded(...)]`
/// attribute.
///
/// `supports(struct_newtype)` restricts the derive to single-field tuple structs.
#[derive(FromDeriveInput)]
#[darling(attributes(branded), supports(struct_newtype))]
pub(crate) struct BrandedTypeOptions {
    /// Name of the struct the derive is applied to.
    ident: syn::Ident,
    /// Struct body; only the struct shape is consumed (see `expand_branded_derive`).
    data: darling::ast::Data<(), BrandedFieldOptions>,

    /// `#[branded(serde)]`: also emit serde `Serialize`/`Deserialize` impls. Off when absent.
    #[darling(default)]
    serde: bool,
    /// `#[branded(uuid)]`: also emit `nil`/`new_v4` constructors. Off when absent.
    #[darling(default)]
    uuid: bool,
    /// `#[branded(sqlx)]`: also emit sqlx `Type`/`Encode`/`Decode` impls. Off when absent.
    #[darling(default)]
    sqlx: bool,
}
18
/// Per-field options parsed by darling for the newtype's single field.
#[derive(FromField)]
pub(crate) struct BrandedFieldOptions {
    /// Type of the wrapped field; emitted as the `Branded::Inner` associated type.
    ty: syn::Type,
}
23
24#[proc_macro_derive(Branded, attributes(branded))]
25pub fn branded_derive(input: TokenStream) -> TokenStream {
26    let input = syn::parse_macro_input!(input);
27    let options = match BrandedTypeOptions::from_derive_input(&input) {
28        Ok(options) => options,
29        Err(err) => return err.write_errors().into(),
30    };
31    let expanded = match expand_branded_derive(options) {
32        Ok(expanded) => expanded,
33        Err(err) => return err.to_compile_error().into(),
34    };
35    expanded.into()
36}
37
38pub(crate) fn expand_branded_derive(
39    options: BrandedTypeOptions,
40) -> syn::Result<proc_macro2::TokenStream> {
41    let mut tokens = proc_macro2::TokenStream::new();
42    let struct_name = &options.ident;
43    let field = options
44        .data
45        .take_struct()
46        .map(|fields| {
47            fields.into_iter().next().ok_or(syn::Error::new(
48                struct_name.span(),
49                "struct must have exactly one field (newtype pattern)",
50            ))
51        })
52        .transpose()?
53        .ok_or(syn::Error::new(
54            struct_name.span(),
55            "derive(Branded) can only be used on structs",
56        ))?;
57    let ty = field.ty;
58    let constructor_doc_comment = format!("Construct a new `{struct_name}` value.");
59    tokens.extend(quote! {
60        impl Branded for #struct_name {
61            type Inner = #ty;
62            fn inner(&self) -> &#ty { &self.0 }
63            fn into_inner(self) -> #ty { self.0 }
64        }
65        impl #struct_name {
66            #[doc = #constructor_doc_comment]
67            pub fn new(inner: #ty) -> Self { Self(inner) }
68        }
69    });
70
71    tokens.extend(expand_clone_copy_impl(struct_name));
72    tokens.extend(expand_debug_display_impl(struct_name));
73    tokens.extend(expand_default_impl(struct_name));
74    tokens.extend(expand_ord_impl(struct_name));
75    tokens.extend(expand_hash_impl(struct_name));
76
77    if options.serde {
78        tokens.extend(expand_serde_impl(struct_name));
79    }
80
81    if options.sqlx {
82        tokens.extend(expand_sqlx_impl(struct_name));
83    }
84
85    if options.uuid {
86        tokens.extend(expand_uuid_impl(struct_name));
87    }
88
89    Ok(tokens)
90}
91
92/// Derive a Clone implementation for the branded type if the inner type is Clone.
93pub(crate) fn expand_clone_copy_impl(brand_struct_name: &syn::Ident) -> proc_macro2::TokenStream {
94    let copy_trait: syn::Path = syn::parse_quote!(::std::marker::Copy);
95    let clone_trait: syn::Path = syn::parse_quote!(::std::clone::Clone);
96    quote! {
97        impl #clone_trait for #brand_struct_name
98        where
99            for<'__branded> <Self as Branded>::Inner: #clone_trait,
100        {
101            fn clone(&self) -> Self {
102                Self::new(self.inner().clone())
103            }
104        }
105        impl #copy_trait for #brand_struct_name
106        where
107            for<'__branded> <Self as Branded>::Inner: #copy_trait,
108        {
109        }
110    }
111}
112
113/// Derive a Display and Debug implementation for the branded type if the inner type conforms to
114/// either trait.
115///
116/// For the Debug implementation, this generates a Debug implementation that prints a tuple of the
117/// inner type contained in the branded type name.
118pub(crate) fn expand_debug_display_impl(
119    brand_struct_name: &syn::Ident,
120) -> proc_macro2::TokenStream {
121    let display_trait: syn::Path = syn::parse_quote!(::std::fmt::Display);
122    let debug_trait: syn::Path = syn::parse_quote!(::std::fmt::Debug);
123    quote! {
124        impl #display_trait for #brand_struct_name
125        where
126            for<'__branded> <Self as Branded>::Inner: #display_trait,
127        {
128            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
129                ::std::fmt::Display::fmt(&self.inner(), f)
130            }
131        }
132        impl #debug_trait for #brand_struct_name
133        where
134            for<'__branded> <Self as Branded>::Inner: #debug_trait,
135        {
136            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
137                f.debug_tuple(stringify!(#brand_struct_name)).field(self.inner()).finish()
138            }
139        }
140    }
141}
142
143/// Derive a Default implementation for the branded type if the inner type conforms to Default.
144pub(crate) fn expand_default_impl(brand_struct_name: &syn::Ident) -> proc_macro2::TokenStream {
145    let path: syn::Path = syn::parse_quote!(::std::default::Default);
146    quote! {
147        impl #path for #brand_struct_name
148        where
149            for<'__branded> <Self as Branded>::Inner: #path,
150        {
151            fn default() -> Self {
152                Self::new(<Self as Branded>::Inner::default())
153            }
154        }
155    }
156}
157
158/// Derive a PartialEq, Eq, Ord, and PartialOrd implementation for the branded type if the inner
159/// type conforms to any of those traits.
160pub(crate) fn expand_ord_impl(brand_struct_name: &syn::Ident) -> proc_macro2::TokenStream {
161    let eq_trait: syn::Path = syn::parse_quote!(::std::cmp::Eq);
162    let partial_eq_trait: syn::Path = syn::parse_quote!(::std::cmp::PartialEq);
163    let ord_trait: syn::Path = syn::parse_quote!(::std::cmp::Ord);
164    let partial_ord_trait: syn::Path = syn::parse_quote!(::std::cmp::PartialOrd);
165    quote! {
166        impl #partial_eq_trait for #brand_struct_name
167        where
168            for<'__branded> <Self as Branded>::Inner: #partial_eq_trait,
169        {
170            fn eq(&self, other: &Self) -> bool {
171                self.inner().eq(other.inner())
172            }
173        }
174        impl #eq_trait for #brand_struct_name
175        where
176            for<'__branded> <Self as Branded>::Inner: #eq_trait,
177        {
178        }
179        impl #ord_trait for #brand_struct_name
180        where
181            for<'__branded> <Self as Branded>::Inner: #ord_trait,
182        {
183            fn cmp(&self, other: &Self) -> ::std::cmp::Ordering {
184                self.0.cmp(&other.0)
185            }
186        }
187        impl #partial_ord_trait for #brand_struct_name
188        where
189            for<'__branded> <Self as Branded>::Inner: #partial_ord_trait,
190        {
191            fn partial_cmp(&self, other: &Self) -> ::std::option::Option<::std::cmp::Ordering> {
192                self.0.partial_cmp(&other.0)
193            }
194        }
195    }
196}
197
198/// Derive a Hash implementation for the branded type if the inner type conforms to Hash.
199pub(crate) fn expand_hash_impl(brand_struct_name: &syn::Ident) -> proc_macro2::TokenStream {
200    let hash_trait: syn::Path = syn::parse_quote!(::std::hash::Hash);
201    quote! {
202        impl #hash_trait for #brand_struct_name
203        where
204            for<'__branded> <Self as Branded>::Inner: #hash_trait,
205        {
206            fn hash<H: ::std::hash::Hasher>(&self, state: &mut H) {
207                self.inner().hash(state);
208            }
209        }
210    }
211}
212
213/// Derive a Serde implementation for the branded type if asked for.
214pub(crate) fn expand_serde_impl(brand_struct_name: &syn::Ident) -> proc_macro2::TokenStream {
215    let serialize_trait: syn::Path = syn::parse_quote!(::serde::Serialize);
216    let deserialize_trait: syn::Path = syn::parse_quote!(::serde::Deserialize);
217    quote! {
218        impl #serialize_trait for #brand_struct_name
219        where
220            for<'__branded> <Self as Branded>::Inner: #serialize_trait,
221        {
222            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
223            where
224                S: ::serde::Serializer,
225            {
226                self.inner().serialize(serializer)
227            }
228        }
229
230        impl<'de> #deserialize_trait<'de> for #brand_struct_name
231        where
232            for<'__branded> <Self as Branded>::Inner: #deserialize_trait<'de>,
233        {
234            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
235            where
236                D: ::serde::Deserializer<'de>,
237            {
238                <Self as Branded>::Inner::deserialize(deserializer)
239                    .map(Self::new)
240            }
241        }
242    }
243}
244
245/// Derive a sqlx Type, Encode, and Decode implementation for the branded type if asked for.
246pub(crate) fn expand_sqlx_impl(brand_struct_name: &syn::Ident) -> proc_macro2::TokenStream {
247    let type_trait: syn::Path = syn::parse_quote!(::sqlx::Type);
248    let encode_trait: syn::Path = syn::parse_quote!(::sqlx::Encode);
249    let decode_trait: syn::Path = syn::parse_quote!(::sqlx::Decode);
250    quote! {
251        impl<DB> #type_trait<DB> for #brand_struct_name
252        where
253            for<'__branded> <Self as Branded>::Inner: #type_trait<DB>,
254            DB: ::sqlx::Database,
255        {
256            fn type_info() -> DB::TypeInfo {
257                <Self as Branded>::Inner::type_info()
258            }
259        }
260
261        impl<'de, DB> #decode_trait<'de, DB> for #brand_struct_name
262        where
263            for<'__branded> Self: Branded,
264            <Self as Branded>::Inner: for<'a> #decode_trait<'a, DB>,
265            DB: ::sqlx::Database,
266        {
267            fn decode(value: DB::ValueRef<'_>) -> ::std::result::Result<#brand_struct_name, ::sqlx::error::BoxDynError> {
268                <Self as Branded>::Inner::decode(value).map(Self::new)
269            }
270        }
271
272        impl<'en, DB> #encode_trait<'en, DB> for #brand_struct_name
273        where
274            for<'__branded> Self: Branded,
275            <Self as Branded>::Inner: for<'a> #encode_trait<'a, DB>,
276            DB: ::sqlx::Database,
277        {
278            fn encode_by_ref(&self, buf: &mut DB::ArgumentBuffer<'_>) -> ::std::result::Result<::sqlx::encode::IsNull, ::sqlx::error::BoxDynError> {
279                self.inner().encode_by_ref(buf)
280            }
281        }
282    }
283}
284
285pub(crate) fn expand_uuid_impl(brand_struct_name: &syn::Ident) -> proc_macro2::TokenStream {
286    quote! {
287        impl #brand_struct_name
288        where
289            for<'__branded> Self: Branded<Inner = ::uuid::Uuid>
290        {
291            /// Get the nil UUID.
292            pub fn nil() -> Self { Self::new(::uuid::Uuid::nil()) }
293
294            /// Get a new random UUID v4.
295            pub fn new_v4() -> Self { Self::new(::uuid::Uuid::new_v4()) }
296        }
297    }
298}