1use darling::{FromDeriveInput, FromField};
2use proc_macro::TokenStream;
3use quote::quote;
4
/// Options darling parses from a struct annotated with `#[derive(Branded)]`.
///
/// `supports(struct_newtype)` restricts the derive to single-field tuple
/// structs, and the boolean flags below are read from `#[branded(...)]`.
/// Each flag defaults to `false` when it is not written.
#[derive(FromDeriveInput)]
#[darling(attributes(branded), supports(struct_newtype))]
pub(crate) struct BrandedTypeOptions {
    /// Name of the struct the derive is applied to.
    ident: syn::Ident,
    /// Body of the input; the `()` variant type is a placeholder because
    /// enum inputs are rejected by `supports(struct_newtype)`.
    data: darling::ast::Data<(), BrandedFieldOptions>,

    /// `#[branded(serde)]`: also emit `serde::Serialize` / `serde::Deserialize` impls.
    #[darling(default)]
    serde: bool,
    /// `#[branded(uuidv4)]`: emit `nil()` and `new_v4()` constructors.
    #[darling(default)]
    uuidv4: bool,
    /// `#[branded(uuidv7)]`: emit `nil()`, `new_v7(ts)` and `now_v7()` constructors.
    #[darling(default)]
    uuidv7: bool,
    /// `#[branded(sqlx)]`: also emit `sqlx::Type` / `Encode` / `Decode` impls.
    #[darling(default)]
    sqlx: bool,
}
20
/// Options darling parses from the single field of a `Branded` newtype.
#[derive(FromField)]
pub(crate) struct BrandedFieldOptions {
    /// Type of the wrapped field; used as `<Brand as Branded>::Inner`.
    ty: syn::Type,
}
25
26#[proc_macro_derive(Branded, attributes(branded))]
27pub fn branded_derive(input: TokenStream) -> TokenStream {
28 let input = syn::parse_macro_input!(input);
29 let options = match BrandedTypeOptions::from_derive_input(&input) {
30 Ok(options) => options,
31 Err(err) => return err.write_errors().into(),
32 };
33 let expanded = match expand_branded_derive(options) {
34 Ok(expanded) => expanded,
35 Err(err) => return err.to_compile_error().into(),
36 };
37 expanded.into()
38}
39
40pub(crate) fn expand_branded_derive(
41 options: BrandedTypeOptions,
42) -> syn::Result<proc_macro2::TokenStream> {
43 let mut tokens = proc_macro2::TokenStream::new();
44 let struct_name = &options.ident;
45 let field = options
46 .data
47 .take_struct()
48 .map(|fields| {
49 fields.into_iter().next().ok_or(syn::Error::new(
50 struct_name.span(),
51 "struct must have exactly one field (newtype pattern)",
52 ))
53 })
54 .transpose()?
55 .ok_or(syn::Error::new(
56 struct_name.span(),
57 "derive(Branded) can only be used on structs",
58 ))?;
59 let ty = field.ty;
60 let constructor_doc_comment = format!("Construct a new `{struct_name}` value.");
61 tokens.extend(quote! {
62 impl Branded for #struct_name {
63 type Inner = #ty;
64 fn inner(&self) -> &#ty { &self.0 }
65 fn into_inner(self) -> #ty { self.0 }
66 }
67 impl #struct_name {
68 #[doc = #constructor_doc_comment]
69 pub fn new(inner: #ty) -> Self { Self(inner) }
70 }
71 });
72
73 tokens.extend(expand_clone_copy_impl(struct_name));
74 tokens.extend(expand_debug_display_impl(struct_name));
75 tokens.extend(expand_default_impl(struct_name));
76 tokens.extend(expand_ord_impl(struct_name));
77 tokens.extend(expand_hash_impl(struct_name));
78
79 if options.serde {
80 tokens.extend(expand_serde_impl(struct_name));
81 }
82
83 if options.sqlx {
84 tokens.extend(expand_sqlx_impl(struct_name));
85 }
86
87 if options.uuidv4 || options.uuidv7 {
88 tokens.extend(expand_uuid_nil_impl(struct_name));
89 }
90
91 if options.uuidv4 {
92 tokens.extend(expand_uuidv4_impl(struct_name));
93 }
94
95 if options.uuidv7 {
96 tokens.extend(expand_uuidv7_impl(struct_name));
97 }
98
99 Ok(tokens)
100}
101
102pub(crate) fn expand_clone_copy_impl(brand_struct_name: &syn::Ident) -> proc_macro2::TokenStream {
104 let copy_trait: syn::Path = syn::parse_quote!(::std::marker::Copy);
105 let clone_trait: syn::Path = syn::parse_quote!(::std::clone::Clone);
106 quote! {
107 impl #clone_trait for #brand_struct_name
108 where
109 for<'__branded> <Self as Branded>::Inner: #clone_trait,
110 {
111 fn clone(&self) -> Self {
112 Self::new(self.inner().clone())
113 }
114 }
115 impl #copy_trait for #brand_struct_name
116 where
117 for<'__branded> <Self as Branded>::Inner: #copy_trait,
118 {
119 }
120 }
121}
122
123pub(crate) fn expand_debug_display_impl(
129 brand_struct_name: &syn::Ident,
130) -> proc_macro2::TokenStream {
131 let display_trait: syn::Path = syn::parse_quote!(::std::fmt::Display);
132 let debug_trait: syn::Path = syn::parse_quote!(::std::fmt::Debug);
133 quote! {
134 impl #display_trait for #brand_struct_name
135 where
136 for<'__branded> <Self as Branded>::Inner: #display_trait,
137 {
138 fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
139 ::std::fmt::Display::fmt(&self.inner(), f)
140 }
141 }
142 impl #debug_trait for #brand_struct_name
143 where
144 for<'__branded> <Self as Branded>::Inner: #debug_trait,
145 {
146 fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
147 f.debug_tuple(stringify!(#brand_struct_name)).field(self.inner()).finish()
148 }
149 }
150 }
151}
152
153pub(crate) fn expand_default_impl(brand_struct_name: &syn::Ident) -> proc_macro2::TokenStream {
155 let path: syn::Path = syn::parse_quote!(::std::default::Default);
156 quote! {
157 impl #path for #brand_struct_name
158 where
159 for<'__branded> <Self as Branded>::Inner: #path,
160 {
161 fn default() -> Self {
162 Self::new(<Self as Branded>::Inner::default())
163 }
164 }
165 }
166}
167
168pub(crate) fn expand_ord_impl(brand_struct_name: &syn::Ident) -> proc_macro2::TokenStream {
171 let eq_trait: syn::Path = syn::parse_quote!(::std::cmp::Eq);
172 let partial_eq_trait: syn::Path = syn::parse_quote!(::std::cmp::PartialEq);
173 let ord_trait: syn::Path = syn::parse_quote!(::std::cmp::Ord);
174 let partial_ord_trait: syn::Path = syn::parse_quote!(::std::cmp::PartialOrd);
175 quote! {
176 impl #partial_eq_trait for #brand_struct_name
177 where
178 for<'__branded> <Self as Branded>::Inner: #partial_eq_trait,
179 {
180 fn eq(&self, other: &Self) -> bool {
181 self.inner().eq(other.inner())
182 }
183 }
184 impl #eq_trait for #brand_struct_name
185 where
186 for<'__branded> <Self as Branded>::Inner: #eq_trait,
187 {
188 }
189 impl #ord_trait for #brand_struct_name
190 where
191 for<'__branded> <Self as Branded>::Inner: #ord_trait,
192 {
193 fn cmp(&self, other: &Self) -> ::std::cmp::Ordering {
194 self.0.cmp(&other.0)
195 }
196 }
197 impl #partial_ord_trait for #brand_struct_name
198 where
199 for<'__branded> <Self as Branded>::Inner: #partial_ord_trait,
200 {
201 fn partial_cmp(&self, other: &Self) -> ::std::option::Option<::std::cmp::Ordering> {
202 self.0.partial_cmp(&other.0)
203 }
204 }
205 }
206}
207
208pub(crate) fn expand_hash_impl(brand_struct_name: &syn::Ident) -> proc_macro2::TokenStream {
210 let hash_trait: syn::Path = syn::parse_quote!(::std::hash::Hash);
211 quote! {
212 impl #hash_trait for #brand_struct_name
213 where
214 for<'__branded> <Self as Branded>::Inner: #hash_trait,
215 {
216 fn hash<H: ::std::hash::Hasher>(&self, state: &mut H) {
217 self.inner().hash(state);
218 }
219 }
220 }
221}
222
223pub(crate) fn expand_serde_impl(brand_struct_name: &syn::Ident) -> proc_macro2::TokenStream {
225 let serialize_trait: syn::Path = syn::parse_quote!(::serde::Serialize);
226 let deserialize_trait: syn::Path = syn::parse_quote!(::serde::Deserialize);
227 quote! {
228 impl #serialize_trait for #brand_struct_name
229 where
230 for<'__branded> <Self as Branded>::Inner: #serialize_trait,
231 {
232 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
233 where
234 S: ::serde::Serializer,
235 {
236 self.inner().serialize(serializer)
237 }
238 }
239
240 impl<'de> #deserialize_trait<'de> for #brand_struct_name
241 where
242 for<'__branded> <Self as Branded>::Inner: #deserialize_trait<'de>,
243 {
244 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
245 where
246 D: ::serde::Deserializer<'de>,
247 {
248 <Self as Branded>::Inner::deserialize(deserializer)
249 .map(Self::new)
250 }
251 }
252 }
253}
254
255pub(crate) fn expand_sqlx_impl(brand_struct_name: &syn::Ident) -> proc_macro2::TokenStream {
257 let type_trait: syn::Path = syn::parse_quote!(::sqlx::Type);
258 let encode_trait: syn::Path = syn::parse_quote!(::sqlx::Encode);
259 let decode_trait: syn::Path = syn::parse_quote!(::sqlx::Decode);
260 quote! {
261 impl<DB> #type_trait<DB> for #brand_struct_name
262 where
263 for<'__branded> <Self as Branded>::Inner: #type_trait<DB>,
264 DB: ::sqlx::Database,
265 {
266 fn type_info() -> DB::TypeInfo {
267 <Self as Branded>::Inner::type_info()
268 }
269 }
270
271 impl<'de, DB> #decode_trait<'de, DB> for #brand_struct_name
272 where
273 for<'__branded> Self: Branded,
274 <Self as Branded>::Inner: for<'a> #decode_trait<'a, DB>,
275 DB: ::sqlx::Database,
276 {
277 fn decode(value: DB::ValueRef<'_>) -> ::std::result::Result<#brand_struct_name, ::sqlx::error::BoxDynError> {
278 <Self as Branded>::Inner::decode(value).map(Self::new)
279 }
280 }
281
282 impl<'en, DB> #encode_trait<'en, DB> for #brand_struct_name
283 where
284 for<'__branded> Self: Branded,
285 <Self as Branded>::Inner: for<'a> #encode_trait<'a, DB>,
286 DB: ::sqlx::Database,
287 {
288 fn encode_by_ref(&self, buf: &mut DB::ArgumentBuffer<'_>) -> ::std::result::Result<::sqlx::encode::IsNull, ::sqlx::error::BoxDynError> {
289 self.inner().encode_by_ref(buf)
290 }
291 }
292 }
293}
294
295pub(crate) fn expand_uuidv4_impl(brand_struct_name: &syn::Ident) -> proc_macro2::TokenStream {
296 quote! {
297 impl #brand_struct_name
298 where
299 for<'__branded> Self: Branded<Inner = ::uuid::Uuid>
300 {
301 pub fn new_v4() -> Self { Self::new(::uuid::Uuid::new_v4()) }
303 }
304 }
305}
306
307pub(crate) fn expand_uuid_nil_impl(brand_struct_name: &syn::Ident) -> proc_macro2::TokenStream {
308 quote! {
309 impl #brand_struct_name
310 where
311 for<'__branded> Self: Branded<Inner = ::uuid::Uuid>
312 {
313 pub fn nil() -> Self { Self::new(::uuid::Uuid::nil()) }
315 }
316 }
317}
318
319pub(crate) fn expand_uuidv7_impl(brand_struct_name: &syn::Ident) -> proc_macro2::TokenStream {
320 quote! {
321 impl #brand_struct_name
322 where
323 for<'__branded> Self: Branded<Inner = ::uuid::Uuid>
324 {
325 pub fn new_v7(ts: ::uuid::timestamp::Timestamp) -> Self { Self::new(::uuid::Uuid::new_v7(ts)) }
327
328 pub fn now_v7() -> Self { Self::new(::uuid::Uuid::now_v7() )}
330 }
331 }
332}