branded_derive/
lib.rs

1use darling::{FromDeriveInput, FromField};
2use proc_macro::TokenStream;
3use quote::quote;
4
/// Options parsed from the input of `#[derive(Branded)]` and its `#[branded(...)]` attribute.
///
/// Only newtype structs are accepted (`supports(struct_newtype)`).
#[derive(FromDeriveInput)]
#[darling(attributes(branded), supports(struct_newtype))]
pub(crate) struct BrandedTypeOptions {
    /// Name of the struct the derive is applied to.
    ident: syn::Ident,
    /// Body of the input; struct fields are parsed as `BrandedFieldOptions`.
    data: darling::ast::Data<(), BrandedFieldOptions>,

    /// When set, `expand_serde_impl` output is included in the expansion.
    #[darling(default)]
    serde: bool,
    /// When set, the `nil()` and `new_v4()` constructors are generated.
    #[darling(default)]
    uuidv4: bool,
    /// When set, the `nil()`, `new_v7()` and `now_v7()` constructors are generated.
    #[darling(default)]
    uuidv7: bool,
    /// When set, `expand_sqlx_impl` output is included in the expansion.
    #[darling(default)]
    sqlx: bool,
}
20
/// Per-field data captured from the newtype struct's single field.
#[derive(FromField)]
pub(crate) struct BrandedFieldOptions {
    /// Declared type of the field; used as the branded type's `Inner` type.
    ty: syn::Type,
}
25
26#[proc_macro_derive(Branded, attributes(branded))]
27pub fn branded_derive(input: TokenStream) -> TokenStream {
28    let input = syn::parse_macro_input!(input);
29    let options = match BrandedTypeOptions::from_derive_input(&input) {
30        Ok(options) => options,
31        Err(err) => return err.write_errors().into(),
32    };
33    let expanded = match expand_branded_derive(options) {
34        Ok(expanded) => expanded,
35        Err(err) => return err.to_compile_error().into(),
36    };
37    expanded.into()
38}
39
40pub(crate) fn expand_branded_derive(
41    options: BrandedTypeOptions,
42) -> syn::Result<proc_macro2::TokenStream> {
43    let mut tokens = proc_macro2::TokenStream::new();
44    let struct_name = &options.ident;
45    let field = options
46        .data
47        .take_struct()
48        .map(|fields| {
49            fields.into_iter().next().ok_or(syn::Error::new(
50                struct_name.span(),
51                "struct must have exactly one field (newtype pattern)",
52            ))
53        })
54        .transpose()?
55        .ok_or(syn::Error::new(
56            struct_name.span(),
57            "derive(Branded) can only be used on structs",
58        ))?;
59    let ty = field.ty;
60    let constructor_doc_comment = format!("Construct a new `{struct_name}` value.");
61    tokens.extend(quote! {
62        impl Branded for #struct_name {
63            type Inner = #ty;
64            fn inner(&self) -> &#ty { &self.0 }
65            fn into_inner(self) -> #ty { self.0 }
66        }
67        impl #struct_name {
68            #[doc = #constructor_doc_comment]
69            pub fn new(inner: #ty) -> Self { Self(inner) }
70        }
71    });
72
73    tokens.extend(expand_clone_copy_impl(struct_name));
74    tokens.extend(expand_debug_display_impl(struct_name));
75    tokens.extend(expand_default_impl(struct_name));
76    tokens.extend(expand_ord_impl(struct_name));
77    tokens.extend(expand_hash_impl(struct_name));
78
79    if options.serde {
80        tokens.extend(expand_serde_impl(struct_name));
81    }
82
83    if options.sqlx {
84        tokens.extend(expand_sqlx_impl(struct_name));
85    }
86
87    if options.uuidv4 || options.uuidv7 {
88        tokens.extend(expand_uuid_nil_impl(struct_name));
89    }
90
91    if options.uuidv4 {
92        tokens.extend(expand_uuidv4_impl(struct_name));
93    }
94
95    if options.uuidv7 {
96        tokens.extend(expand_uuidv7_impl(struct_name));
97    }
98
99    Ok(tokens)
100}
101
102/// Derive a Clone implementation for the branded type if the inner type is Clone.
103pub(crate) fn expand_clone_copy_impl(brand_struct_name: &syn::Ident) -> proc_macro2::TokenStream {
104    let copy_trait: syn::Path = syn::parse_quote!(::std::marker::Copy);
105    let clone_trait: syn::Path = syn::parse_quote!(::std::clone::Clone);
106    quote! {
107        impl #clone_trait for #brand_struct_name
108        where
109            for<'__branded> <Self as Branded>::Inner: #clone_trait,
110        {
111            fn clone(&self) -> Self {
112                Self::new(self.inner().clone())
113            }
114        }
115        impl #copy_trait for #brand_struct_name
116        where
117            for<'__branded> <Self as Branded>::Inner: #copy_trait,
118        {
119        }
120    }
121}
122
123/// Derive a Display and Debug implementation for the branded type if the inner type conforms to
124/// either trait.
125///
126/// For the Debug implementation, this generates a Debug implementation that prints a tuple of the
127/// inner type contained in the branded type name.
128pub(crate) fn expand_debug_display_impl(
129    brand_struct_name: &syn::Ident,
130) -> proc_macro2::TokenStream {
131    let display_trait: syn::Path = syn::parse_quote!(::std::fmt::Display);
132    let debug_trait: syn::Path = syn::parse_quote!(::std::fmt::Debug);
133    quote! {
134        impl #display_trait for #brand_struct_name
135        where
136            for<'__branded> <Self as Branded>::Inner: #display_trait,
137        {
138            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
139                ::std::fmt::Display::fmt(&self.inner(), f)
140            }
141        }
142        impl #debug_trait for #brand_struct_name
143        where
144            for<'__branded> <Self as Branded>::Inner: #debug_trait,
145        {
146            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
147                f.debug_tuple(stringify!(#brand_struct_name)).field(self.inner()).finish()
148            }
149        }
150    }
151}
152
153/// Derive a Default implementation for the branded type if the inner type conforms to Default.
154pub(crate) fn expand_default_impl(brand_struct_name: &syn::Ident) -> proc_macro2::TokenStream {
155    let path: syn::Path = syn::parse_quote!(::std::default::Default);
156    quote! {
157        impl #path for #brand_struct_name
158        where
159            for<'__branded> <Self as Branded>::Inner: #path,
160        {
161            fn default() -> Self {
162                Self::new(<Self as Branded>::Inner::default())
163            }
164        }
165    }
166}
167
168/// Derive a PartialEq, Eq, Ord, and PartialOrd implementation for the branded type if the inner
169/// type conforms to any of those traits.
170pub(crate) fn expand_ord_impl(brand_struct_name: &syn::Ident) -> proc_macro2::TokenStream {
171    let eq_trait: syn::Path = syn::parse_quote!(::std::cmp::Eq);
172    let partial_eq_trait: syn::Path = syn::parse_quote!(::std::cmp::PartialEq);
173    let ord_trait: syn::Path = syn::parse_quote!(::std::cmp::Ord);
174    let partial_ord_trait: syn::Path = syn::parse_quote!(::std::cmp::PartialOrd);
175    quote! {
176        impl #partial_eq_trait for #brand_struct_name
177        where
178            for<'__branded> <Self as Branded>::Inner: #partial_eq_trait,
179        {
180            fn eq(&self, other: &Self) -> bool {
181                self.inner().eq(other.inner())
182            }
183        }
184        impl #eq_trait for #brand_struct_name
185        where
186            for<'__branded> <Self as Branded>::Inner: #eq_trait,
187        {
188        }
189        impl #ord_trait for #brand_struct_name
190        where
191            for<'__branded> <Self as Branded>::Inner: #ord_trait,
192        {
193            fn cmp(&self, other: &Self) -> ::std::cmp::Ordering {
194                self.0.cmp(&other.0)
195            }
196        }
197        impl #partial_ord_trait for #brand_struct_name
198        where
199            for<'__branded> <Self as Branded>::Inner: #partial_ord_trait,
200        {
201            fn partial_cmp(&self, other: &Self) -> ::std::option::Option<::std::cmp::Ordering> {
202                self.0.partial_cmp(&other.0)
203            }
204        }
205    }
206}
207
208/// Derive a Hash implementation for the branded type if the inner type conforms to Hash.
209pub(crate) fn expand_hash_impl(brand_struct_name: &syn::Ident) -> proc_macro2::TokenStream {
210    let hash_trait: syn::Path = syn::parse_quote!(::std::hash::Hash);
211    quote! {
212        impl #hash_trait for #brand_struct_name
213        where
214            for<'__branded> <Self as Branded>::Inner: #hash_trait,
215        {
216            fn hash<H: ::std::hash::Hasher>(&self, state: &mut H) {
217                self.inner().hash(state);
218            }
219        }
220    }
221}
222
223/// Derive a Serde implementation for the branded type if asked for.
224pub(crate) fn expand_serde_impl(brand_struct_name: &syn::Ident) -> proc_macro2::TokenStream {
225    let serialize_trait: syn::Path = syn::parse_quote!(::serde::Serialize);
226    let deserialize_trait: syn::Path = syn::parse_quote!(::serde::Deserialize);
227    quote! {
228        impl #serialize_trait for #brand_struct_name
229        where
230            for<'__branded> <Self as Branded>::Inner: #serialize_trait,
231        {
232            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
233            where
234                S: ::serde::Serializer,
235            {
236                self.inner().serialize(serializer)
237            }
238        }
239
240        impl<'de> #deserialize_trait<'de> for #brand_struct_name
241        where
242            for<'__branded> <Self as Branded>::Inner: #deserialize_trait<'de>,
243        {
244            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
245            where
246                D: ::serde::Deserializer<'de>,
247            {
248                <Self as Branded>::Inner::deserialize(deserializer)
249                    .map(Self::new)
250            }
251        }
252    }
253}
254
255/// Derive a sqlx Type, Encode, and Decode implementation for the branded type if asked for.
256pub(crate) fn expand_sqlx_impl(brand_struct_name: &syn::Ident) -> proc_macro2::TokenStream {
257    let type_trait: syn::Path = syn::parse_quote!(::sqlx::Type);
258    let encode_trait: syn::Path = syn::parse_quote!(::sqlx::Encode);
259    let decode_trait: syn::Path = syn::parse_quote!(::sqlx::Decode);
260    quote! {
261        impl<DB> #type_trait<DB> for #brand_struct_name
262        where
263            for<'__branded> <Self as Branded>::Inner: #type_trait<DB>,
264            DB: ::sqlx::Database,
265        {
266            fn type_info() -> DB::TypeInfo {
267                <Self as Branded>::Inner::type_info()
268            }
269        }
270
271        impl<'de, DB> #decode_trait<'de, DB> for #brand_struct_name
272        where
273            for<'__branded> Self: Branded,
274            <Self as Branded>::Inner: for<'a> #decode_trait<'a, DB>,
275            DB: ::sqlx::Database,
276        {
277            fn decode(value: DB::ValueRef<'_>) -> ::std::result::Result<#brand_struct_name, ::sqlx::error::BoxDynError> {
278                <Self as Branded>::Inner::decode(value).map(Self::new)
279            }
280        }
281
282        impl<'en, DB> #encode_trait<'en, DB> for #brand_struct_name
283        where
284            for<'__branded> Self: Branded,
285            <Self as Branded>::Inner: for<'a> #encode_trait<'a, DB>,
286            DB: ::sqlx::Database,
287        {
288            fn encode_by_ref(&self, buf: &mut DB::ArgumentBuffer<'_>) -> ::std::result::Result<::sqlx::encode::IsNull, ::sqlx::error::BoxDynError> {
289                self.inner().encode_by_ref(buf)
290            }
291        }
292    }
293}
294
295pub(crate) fn expand_uuidv4_impl(brand_struct_name: &syn::Ident) -> proc_macro2::TokenStream {
296    quote! {
297        impl #brand_struct_name
298        where
299            for<'__branded> Self: Branded<Inner = ::uuid::Uuid>
300        {
301            /// Get a new random UUID v4.
302            pub fn new_v4() -> Self { Self::new(::uuid::Uuid::new_v4()) }
303        }
304    }
305}
306
307pub(crate) fn expand_uuid_nil_impl(brand_struct_name: &syn::Ident) -> proc_macro2::TokenStream {
308    quote! {
309        impl #brand_struct_name
310        where
311            for<'__branded> Self: Branded<Inner = ::uuid::Uuid>
312        {
313            /// Get a new random UUID v4.
314            pub fn nil() -> Self { Self::new(::uuid::Uuid::nil()) }
315        }
316    }
317}
318
319pub(crate) fn expand_uuidv7_impl(brand_struct_name: &syn::Ident) -> proc_macro2::TokenStream {
320    quote! {
321        impl #brand_struct_name
322        where
323            for<'__branded> Self: Branded<Inner = ::uuid::Uuid>
324        {
325            /// Get a new random UUID v7.
326            pub fn new_v7(ts: ::uuid::timestamp::Timestamp) -> Self { Self::new(::uuid::Uuid::new_v7(ts)) }
327
328            /// Get a new random UUID v7 with the timestamp bits set to now.
329            pub fn now_v7() -> Self { Self::new(::uuid::Uuid::now_v7() )}
330        }
331    }
332}