1use darling::{FromDeriveInput, FromField};
2use proc_macro::TokenStream;
3use quote::quote;
4
/// Container-level options for `#[derive(Branded)]`, parsed by darling from the
/// deriving struct and its `#[branded(...)]` attributes.
///
/// `supports(struct_newtype)` asks darling to reject any input that is not a
/// newtype struct before expansion runs.
#[derive(FromDeriveInput)]
#[darling(attributes(branded), supports(struct_newtype))]
pub(crate) struct BrandedTypeOptions {
    /// Name of the deriving struct.
    ident: syn::Ident,
    /// Body of the deriving item. The variant type is `()` because only structs
    /// are supported; each field is parsed as a [`BrandedFieldOptions`].
    data: darling::ast::Data<(), BrandedFieldOptions>,

    /// `#[branded(serde)]`: additionally emit the serde integration.
    /// Defaults to `false` when the flag is absent.
    #[darling(default)]
    serde: bool,
    /// `#[branded(uuid)]`: additionally emit the `uuid` helper constructors.
    /// Defaults to `false` when the flag is absent.
    #[darling(default)]
    uuid: bool,
    /// `#[branded(sqlx)]`: additionally emit the sqlx integration.
    /// Defaults to `false` when the flag is absent.
    #[darling(default)]
    sqlx: bool,
}
18
/// Options for the single wrapped field of a branded newtype.
#[derive(FromField)]
pub(crate) struct BrandedFieldOptions {
    /// Declared type of the wrapped field; it becomes `Branded::Inner` and the
    /// parameter type of the generated `new` constructor.
    ty: syn::Type,
}
23
24#[proc_macro_derive(Branded, attributes(branded))]
25pub fn branded_derive(input: TokenStream) -> TokenStream {
26 let input = syn::parse_macro_input!(input);
27 let options = match BrandedTypeOptions::from_derive_input(&input) {
28 Ok(options) => options,
29 Err(err) => return err.write_errors().into(),
30 };
31 let expanded = match expand_branded_derive(options) {
32 Ok(expanded) => expanded,
33 Err(err) => return err.to_compile_error().into(),
34 };
35 expanded.into()
36}
37
38pub(crate) fn expand_branded_derive(
39 options: BrandedTypeOptions,
40) -> syn::Result<proc_macro2::TokenStream> {
41 let mut tokens = proc_macro2::TokenStream::new();
42 let struct_name = &options.ident;
43 let field = options
44 .data
45 .take_struct()
46 .map(|fields| {
47 fields.into_iter().next().ok_or(syn::Error::new(
48 struct_name.span(),
49 "struct must have exactly one field (newtype pattern)",
50 ))
51 })
52 .transpose()?
53 .ok_or(syn::Error::new(
54 struct_name.span(),
55 "derive(Branded) can only be used on structs",
56 ))?;
57 let ty = field.ty;
58 let constructor_doc_comment = format!("Construct a new `{struct_name}` value.");
59 tokens.extend(quote! {
60 impl Branded for #struct_name {
61 type Inner = #ty;
62 fn inner(&self) -> &#ty { &self.0 }
63 fn into_inner(self) -> #ty { self.0 }
64 }
65 impl #struct_name {
66 #[doc = #constructor_doc_comment]
67 pub fn new(inner: #ty) -> Self { Self(inner) }
68 }
69 });
70
71 tokens.extend(expand_clone_copy_impl(struct_name));
72 tokens.extend(expand_debug_display_impl(struct_name));
73 tokens.extend(expand_default_impl(struct_name));
74 tokens.extend(expand_ord_impl(struct_name));
75 tokens.extend(expand_hash_impl(struct_name));
76
77 if options.serde {
78 tokens.extend(expand_serde_impl(struct_name));
79 }
80
81 if options.sqlx {
82 tokens.extend(expand_sqlx_impl(struct_name));
83 }
84
85 if options.uuid {
86 tokens.extend(expand_uuid_impl(struct_name));
87 }
88
89 Ok(tokens)
90}
91
92pub(crate) fn expand_clone_copy_impl(brand_struct_name: &syn::Ident) -> proc_macro2::TokenStream {
94 let copy_trait: syn::Path = syn::parse_quote!(::std::marker::Copy);
95 let clone_trait: syn::Path = syn::parse_quote!(::std::clone::Clone);
96 quote! {
97 impl #clone_trait for #brand_struct_name
98 where
99 for<'__branded> <Self as Branded>::Inner: #clone_trait,
100 {
101 fn clone(&self) -> Self {
102 Self::new(self.inner().clone())
103 }
104 }
105 impl #copy_trait for #brand_struct_name
106 where
107 for<'__branded> <Self as Branded>::Inner: #copy_trait,
108 {
109 }
110 }
111}
112
113pub(crate) fn expand_debug_display_impl(
119 brand_struct_name: &syn::Ident,
120) -> proc_macro2::TokenStream {
121 let display_trait: syn::Path = syn::parse_quote!(::std::fmt::Display);
122 let debug_trait: syn::Path = syn::parse_quote!(::std::fmt::Debug);
123 quote! {
124 impl #display_trait for #brand_struct_name
125 where
126 for<'__branded> <Self as Branded>::Inner: #display_trait,
127 {
128 fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
129 ::std::fmt::Display::fmt(&self.inner(), f)
130 }
131 }
132 impl #debug_trait for #brand_struct_name
133 where
134 for<'__branded> <Self as Branded>::Inner: #debug_trait,
135 {
136 fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
137 f.debug_tuple(stringify!(#brand_struct_name)).field(self.inner()).finish()
138 }
139 }
140 }
141}
142
143pub(crate) fn expand_default_impl(brand_struct_name: &syn::Ident) -> proc_macro2::TokenStream {
145 let path: syn::Path = syn::parse_quote!(::std::default::Default);
146 quote! {
147 impl #path for #brand_struct_name
148 where
149 for<'__branded> <Self as Branded>::Inner: #path,
150 {
151 fn default() -> Self {
152 Self::new(<Self as Branded>::Inner::default())
153 }
154 }
155 }
156}
157
158pub(crate) fn expand_ord_impl(brand_struct_name: &syn::Ident) -> proc_macro2::TokenStream {
161 let eq_trait: syn::Path = syn::parse_quote!(::std::cmp::Eq);
162 let partial_eq_trait: syn::Path = syn::parse_quote!(::std::cmp::PartialEq);
163 let ord_trait: syn::Path = syn::parse_quote!(::std::cmp::Ord);
164 let partial_ord_trait: syn::Path = syn::parse_quote!(::std::cmp::PartialOrd);
165 quote! {
166 impl #partial_eq_trait for #brand_struct_name
167 where
168 for<'__branded> <Self as Branded>::Inner: #partial_eq_trait,
169 {
170 fn eq(&self, other: &Self) -> bool {
171 self.inner().eq(other.inner())
172 }
173 }
174 impl #eq_trait for #brand_struct_name
175 where
176 for<'__branded> <Self as Branded>::Inner: #eq_trait,
177 {
178 }
179 impl #ord_trait for #brand_struct_name
180 where
181 for<'__branded> <Self as Branded>::Inner: #ord_trait,
182 {
183 fn cmp(&self, other: &Self) -> ::std::cmp::Ordering {
184 self.0.cmp(&other.0)
185 }
186 }
187 impl #partial_ord_trait for #brand_struct_name
188 where
189 for<'__branded> <Self as Branded>::Inner: #partial_ord_trait,
190 {
191 fn partial_cmp(&self, other: &Self) -> ::std::option::Option<::std::cmp::Ordering> {
192 self.0.partial_cmp(&other.0)
193 }
194 }
195 }
196}
197
198pub(crate) fn expand_hash_impl(brand_struct_name: &syn::Ident) -> proc_macro2::TokenStream {
200 let hash_trait: syn::Path = syn::parse_quote!(::std::hash::Hash);
201 quote! {
202 impl #hash_trait for #brand_struct_name
203 where
204 for<'__branded> <Self as Branded>::Inner: #hash_trait,
205 {
206 fn hash<H: ::std::hash::Hasher>(&self, state: &mut H) {
207 self.inner().hash(state);
208 }
209 }
210 }
211}
212
213pub(crate) fn expand_serde_impl(brand_struct_name: &syn::Ident) -> proc_macro2::TokenStream {
215 let serialize_trait: syn::Path = syn::parse_quote!(::serde::Serialize);
216 let deserialize_trait: syn::Path = syn::parse_quote!(::serde::Deserialize);
217 quote! {
218 impl #serialize_trait for #brand_struct_name
219 where
220 for<'__branded> <Self as Branded>::Inner: #serialize_trait,
221 {
222 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
223 where
224 S: ::serde::Serializer,
225 {
226 self.inner().serialize(serializer)
227 }
228 }
229
230 impl<'de> #deserialize_trait<'de> for #brand_struct_name
231 where
232 for<'__branded> <Self as Branded>::Inner: #deserialize_trait<'de>,
233 {
234 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
235 where
236 D: ::serde::Deserializer<'de>,
237 {
238 <Self as Branded>::Inner::deserialize(deserializer)
239 .map(Self::new)
240 }
241 }
242 }
243}
244
245pub(crate) fn expand_sqlx_impl(brand_struct_name: &syn::Ident) -> proc_macro2::TokenStream {
247 let type_trait: syn::Path = syn::parse_quote!(::sqlx::Type);
248 let encode_trait: syn::Path = syn::parse_quote!(::sqlx::Encode);
249 let decode_trait: syn::Path = syn::parse_quote!(::sqlx::Decode);
250 quote! {
251 impl<DB> #type_trait<DB> for #brand_struct_name
252 where
253 for<'__branded> <Self as Branded>::Inner: #type_trait<DB>,
254 DB: ::sqlx::Database,
255 {
256 fn type_info() -> DB::TypeInfo {
257 <Self as Branded>::Inner::type_info()
258 }
259 }
260
261 impl<'de, DB> #decode_trait<'de, DB> for #brand_struct_name
262 where
263 for<'__branded> Self: Branded,
264 <Self as Branded>::Inner: for<'a> #decode_trait<'a, DB>,
265 DB: ::sqlx::Database,
266 {
267 fn decode(value: DB::ValueRef<'_>) -> ::std::result::Result<#brand_struct_name, ::sqlx::error::BoxDynError> {
268 <Self as Branded>::Inner::decode(value).map(Self::new)
269 }
270 }
271
272 impl<'en, DB> #encode_trait<'en, DB> for #brand_struct_name
273 where
274 for<'__branded> Self: Branded,
275 <Self as Branded>::Inner: for<'a> #encode_trait<'a, DB>,
276 DB: ::sqlx::Database,
277 {
278 fn encode_by_ref(&self, buf: &mut DB::ArgumentBuffer<'_>) -> ::std::result::Result<::sqlx::encode::IsNull, ::sqlx::error::BoxDynError> {
279 self.inner().encode_by_ref(buf)
280 }
281 }
282 }
283}
284
285pub(crate) fn expand_uuid_impl(brand_struct_name: &syn::Ident) -> proc_macro2::TokenStream {
286 quote! {
287 impl #brand_struct_name
288 where
289 for<'__branded> Self: Branded<Inner = ::uuid::Uuid>
290 {
291 fn nil() -> Self { Self::new(::uuid::Uuid::nil()) }
293
294 fn new_v4() -> Self { Self::new(::uuid::Uuid::new_v4()) }
296 }
297 }
298}