base64id_derive/
lib.rs

1//! This crate contains the derive macro for [base64id-rs](https://github.com/shauncksm/base64id-rs).
2//! You shouldn't use this crate directly. See [here](https://docs.rs/base64id/latest/base64id/) instead.
3
4#![forbid(unsafe_code)]
5#![warn(missing_docs)]
6
7use proc_macro::TokenStream;
8use proc_macro2::{Ident, Span, TokenTree};
9use quote::quote;
10use syn::{Attribute, DeriveInput, Meta};
11
/// Panic message emitted whenever the derive input's single tuple field is not one of the
/// supported integer types (`i64`, `u64`, `i32`, `u32`, `i16`, `u16`).
const ERROR_INVALID_INNER_TYPE: &str =
    "invalid type within tuple struct, expected i64, u64, i32, u32, i16 or u16";
14
15/// Create your own base64id tuple struct
16///
17/// # Usage
18///
19/// In it's most simple form, `#[derive(Base64Id)]` may be used as follows:
20/// ```ignore
21/// #[derive(Base64Id)]
22/// struct MyCustomId(T);
23/// ```
24///
25/// Where `T` is any of the following concrete types:
26/// [`i64`](https://doc.rust-lang.org/core/primitive.i64.html),
27/// [`i32`](https://doc.rust-lang.org/core/primitive.i32.html),
28/// [`i16`](https://doc.rust-lang.org/core/primitive.i16.html),
29/// [`u64`](https://doc.rust-lang.org/core/primitive.u64.html),
30/// [`u32`](https://doc.rust-lang.org/core/primitive.u32.html),
31/// [`u16`](https://doc.rust-lang.org/core/primitive.u16.html)
32///
33/// For example:
34/// ```ignore
35/// #[derive(Base64Id)]
36/// struct MyCustomId(i64);
37/// ```
38///
39/// ## Derive Macro Trait Implementations
40/// Once `#[derive(Base64Id)]` is applied to a tuple struct as described above, the following trait implementations are added:
41///
42/// #### [`Display`](https://doc.rust-lang.org/core/fmt/trait.Display.html)
43///
44/// Display is added to encode the inner integer as a base64url string of appropriate length.
45///
46/// #### [`FromStr`](https://doc.rust-lang.org/core/str/trait.FromStr.html)
47///
48/// FromStr is added to decode any string from base64url to the inner integer type, returning an [`Error`](https://docs.rs/base64id/latest/base64id/enum.Error.html) on failure.
49///
50/// #### [`TryFrom`](https://doc.rust-lang.org/core/convert/trait.TryFrom.html)
51///
52/// `TryFrom<[char; n]>` is added where `n` is the length of a given base64url string.
53/// This allows for converting a [`char`](https://doc.rust-lang.org/core/primitive.char.html) array of length `n` into your tuple struct.
54/// The value of `n` is:
55/// - 11 for 64 bit integers
56/// - 6 for 32 bit integers
57/// - 3 for 16 bit integers
58///
59/// #### [`PartialEq`](https://doc.rust-lang.org/core/cmp/trait.PartialEq.html), [`Eq`](https://doc.rust-lang.org/core/cmp/trait.Eq.html)
60///
61/// These are standard impl's and have no special behaviour.
62///
63/// #### [`PartialOrd`](https://doc.rust-lang.org/core/cmp/trait.PartialOrd.html), [`Ord`](https://doc.rust-lang.org/core/cmp/trait.Ord.html)
64///
65/// These traits are implemented with special behaviour where the inner type is a signed integer.
66///
67/// For unsigned integers the ordering behaviour is standard.
68/// For signed integers, the value is converted to big endian bytes, these bytes are then converted into an unsigned integer and order comparsion is done on this.
69///
70/// In other words, order comparions are based on the unsigned integer / binary representation of the integer.
71///
72/// #### [`From`](https://doc.rust-lang.org/core/convert/trait.From.html)
73///
74/// Four `From` trait impl's are added.
75/// Given the following example struct:
76/// ```ignore
77/// #[derive(Base64Id)]
78/// struct MyCustomId(i64);
79/// ```
80///
81/// The following `From` traits would be added:
82///
83/// ```ignore
84/// impl From<MyCustomId> for i64;
85/// impl From<i64> for MyCustomId;
86///
87/// impl From<MyCustomId> for u64;
88/// impl From<u64> for MyCustomId;
89/// ```
90///
91/// The first two trait impl's allow converting to and from the structs internal integer type, useful for simply extracting the internal integer out of the struct.
92///
93/// The last two allow converting integers **of the opposite sign** to and from the struct.
94/// This conversion preserves the binary representation of the integer.
95/// In practice this means signed and unsigned positive integers will have the same decimal value when converting between them.
96/// However, signed and unsigned negative integers will have different decimal values however.
97///
98/// ## Serde Trait Implementations
99///
100/// #### [`Serialize`](https://docs.rs/serde/latest/serde/trait.Serialize.html), [`Deserialize`](https://docs.rs/serde/latest/serde/trait.Deserialize.html)
101///
102/// You can also add optional Serde Serialize and Deserialize trait implementations to the struct.
103/// To do this you must include Serde as a dependency in your Cargo.toml file.
104/// Serde is not a dependency of this crate.
105///
106/// Serde traits can be applied using the following derive macro helper attribute:
107/// ```ignore
108/// #[derive(Base64Id)]
109/// #[base64id(Serialize, Deserialize)]
110/// struct MyCustomId(i64);
111/// ```
112///
113/// You can add neither, either or both traits as needed.
114///
115/// ## `MIN` / `MAX` Constants
116///
117/// In addition to the above trait implementations, `MIN` and `MAX` constants are added.
118/// These values are based on the inner integers unsigned / binary representation.
119#[proc_macro_derive(Base64Id, attributes(base64id))]
120pub fn tuple_struct_into_base64id(input: TokenStream) -> TokenStream {
121    let ast: DeriveInput = syn::parse(input).expect("failed to parse token stream");
122
123    let ident = ast.ident;
124    let struct_inner_type = get_validated_struct_data(ast.data);
125    let struct_inner_type_string = struct_inner_type.to_string();
126
127    let char_len = match struct_inner_type_string.as_str() {
128        "i64" | "u64" => 11,
129        "i32" | "u32" => 6,
130        "i16" | "u16" => 3,
131        _ => panic!("{ERROR_INVALID_INNER_TYPE}"),
132    };
133
134    let is_signed = struct_inner_type_string.starts_with("i");
135
136    let (
137        encode_fn,
138        decode_fn,
139        char_array_type,
140        struct_inner_type_u,
141        struct_inner_type_alt,
142        int_min,
143        int_max,
144    ) = match struct_inner_type_string.as_str() {
145        "i64" => (
146            quote! {::base64id::base64::encode_i64},
147            quote! {::base64id::base64::decode_i64},
148            quote! {[char; #char_len]},
149            quote! {u64},
150            quote! {u64},
151            quote! {0},
152            quote! {-1},
153        ),
154        "u64" => (
155            quote! {::base64id::base64::encode_u64},
156            quote! {::base64id::base64::decode_u64},
157            quote! {[char; #char_len]},
158            quote! {u64},
159            quote! {i64},
160            quote! {0},
161            quote! {#struct_inner_type::MAX},
162        ),
163        "i32" => (
164            quote! {::base64id::base64::encode_i32},
165            quote! {::base64id::base64::decode_i32},
166            quote! {[char; #char_len]},
167            quote! {u32},
168            quote! {u32},
169            quote! {0},
170            quote! {-1},
171        ),
172        "u32" => (
173            quote! {::base64id::base64::encode_u32},
174            quote! {::base64id::base64::decode_u32},
175            quote! {[char; #char_len]},
176            quote! {u32},
177            quote! {i32},
178            quote! {0},
179            quote! {#struct_inner_type::MAX},
180        ),
181        "i16" => (
182            quote! {::base64id::base64::encode_i16},
183            quote! {::base64id::base64::decode_i16},
184            quote! {[char; #char_len]},
185            quote! {u16},
186            quote! {u16},
187            quote! {0},
188            quote! {-1},
189        ),
190        "u16" => (
191            quote! {::base64id::base64::encode_u16},
192            quote! {::base64id::base64::decode_u16},
193            quote! {[char; #char_len]},
194            quote! {u16},
195            quote! {i16},
196            quote! {0},
197            quote! {#struct_inner_type::MAX},
198        ),
199        _ => panic!("{ERROR_INVALID_INNER_TYPE}"),
200    };
201
202    let mut implementation = quote! {
203        impl #ident {
204            const MIN: #ident = #ident(#int_min);
205            const MAX: #ident = #ident(#int_max);
206        }
207
208        impl ::core::fmt::Display for #ident {
209            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
210                use ::core::fmt::Write;
211
212                for c in #encode_fn(self.0) {
213                    f.write_char(c)?;
214                }
215
216                Ok(())
217            }
218        }
219
220        impl ::core::convert::From<#ident> for #struct_inner_type {
221            fn from(id: #ident) -> Self {
222                id.0
223            }
224        }
225
226        impl ::core::convert::From<#struct_inner_type> for #ident {
227            fn from(id: #struct_inner_type) -> Self {
228                Self(id)
229            }
230        }
231
232        impl ::core::convert::From<#ident> for #struct_inner_type_alt {
233            fn from(id: #ident) -> Self {
234                #struct_inner_type_alt::from_be_bytes(id.0.to_be_bytes())
235            }
236        }
237
238        impl ::core::convert::From<#struct_inner_type_alt> for #ident {
239            fn from(id: #struct_inner_type_alt) -> Self {
240                Self(#struct_inner_type::from_be_bytes(id.to_be_bytes()))
241            }
242        }
243
244        impl ::core::convert::TryFrom<#char_array_type> for #ident {
245            type Error = ::base64id::Error;
246
247            fn try_from(input: #char_array_type) -> ::core::result::Result<Self, Self::Error> {
248                Ok(Self(#decode_fn(input)?))
249            }
250        }
251
252        impl ::core::str::FromStr for #ident {
253            type Err = ::base64id::Error;
254
255            fn from_str(id: &str) -> ::core::result::Result<Self, Self::Err> {
256                let mut array: #char_array_type = ::core::default::Default::default();
257                let mut id_iter = id.chars();
258
259                for c in array.iter_mut() {
260                    *c = match id_iter.next() {
261                        Some(d) => d,
262                        None => return Err(::base64id::Error::InvalidLength),
263                    };
264                }
265
266                if id_iter.next().is_some() {
267                    return Err(::base64id::Error::InvalidLength);
268                }
269
270                #ident::try_from(array)
271            }
272        }
273
274        impl ::core::cmp::PartialEq for #ident {
275            fn eq(&self, other: &Self) -> bool {
276                self.0 == other.0
277            }
278        }
279        impl ::core::cmp::Eq for #ident {}
280    };
281
282    apply_ord_trait(&ident, struct_inner_type_u, is_signed, &mut implementation);
283
284    evaluate_attributes(&ident, ast.attrs, char_len, &mut implementation);
285
286    implementation.into()
287}
288
289/// Add PartialOrd and Ord trait to struct.
290/// This applies a signed to unsigned integer byte conversion if the inner integer type is signed
291fn apply_ord_trait(
292    ident: &proc_macro2::Ident,
293    struct_inner_type_u: proc_macro2::TokenStream,
294    is_signed: bool,
295    implementation: &mut proc_macro2::TokenStream,
296) {
297    implementation.extend(quote! {
298        impl ::core::cmp::PartialOrd for #ident {
299            fn partial_cmp(&self, other: &Self) -> ::core::option::Option<::core::cmp::Ordering> {
300                Some(self.cmp(other))
301            }
302        }
303    });
304
305    if is_signed {
306        implementation.extend(quote! {
307            impl ::core::cmp::Ord for #ident {
308                fn cmp(&self, other: &Self) -> ::core::cmp::Ordering {
309                    let this = #struct_inner_type_u::from_be_bytes(self.0.to_be_bytes());
310                    let other = #struct_inner_type_u::from_be_bytes(other.0.to_be_bytes());
311
312                    this.cmp(&other)
313                }
314            }
315        });
316    } else {
317        implementation.extend(quote! {
318            impl ::core::cmp::Ord for #ident {
319                fn cmp(&self, other: &Self) -> ::core::cmp::Ordering {
320                    let this = self.0;
321                    let other = other.0;
322
323                    this.cmp(&other)
324                }
325            }
326        });
327    }
328}
329
330/// Determines if the base64id attribute is present
331/// and if it contains expected keywords
332fn evaluate_attributes(
333    ident: &proc_macro2::Ident,
334    attrs: Vec<Attribute>,
335    char_len: usize,
336    implementation: &mut proc_macro2::TokenStream,
337) {
338    for attr in attrs {
339        let attr_ident = match attr.path().get_ident() {
340            Some(i) => i,
341            None => continue,
342        };
343
344        if attr_ident != "base64id" {
345            continue;
346        }
347
348        let meta_list = match attr.meta {
349            Meta::List(l) => l,
350            _ => continue,
351        };
352
353        for token in meta_list.tokens {
354            let token_ident = match token {
355                TokenTree::Ident(i) => i,
356                _ => continue,
357            };
358
359            if token_ident == "Serialize" {
360                apply_serialize_trait(&ident, implementation);
361            }
362
363            if token_ident == "Deserialize" {
364                apply_deserialize_trait(&ident, char_len, implementation);
365            }
366        }
367
368        return;
369    }
370}
371
372/// Enable the following syntax:
373/// ```ignore
374/// #[derive(base64id::Base64Id)]
375/// #[base64id(Serialize)]
376/// struct MyType(i64);
377/// ```
378fn apply_serialize_trait(
379    ident: &proc_macro2::Ident,
380    implementation: &mut proc_macro2::TokenStream,
381) {
382    implementation.extend(quote!(
383        impl ::serde::Serialize for #ident {
384            fn serialize<S>(&self, serializer: S) -> ::core::result::Result<S::Ok, S::Error>
385                where
386                    S: ::serde::Serializer
387            {
388                serializer.collect_str(self)
389            }
390        }
391    ));
392}
393
394/// Enable the following syntax:
395/// ```ignore
396/// #[derive(base64id::Base64Id)]
397/// #[base64id(Deserialize)]
398/// struct MyType(i64);
399/// ```
400fn apply_deserialize_trait(
401    ident: &proc_macro2::Ident,
402    char_len: usize,
403    implementation: &mut proc_macro2::TokenStream,
404) {
405    let visitor = Ident::new(
406        format!("{ident}__Base64Id_Serde_Visitor").as_str(),
407        Span::call_site(),
408    );
409
410    let last_char_range = match char_len {
411        11 | 3 => "AEIMQUYcgkosw048",
412        6 => "AQgw",
413        _ => panic!("unexpected character length {char_len}. cannot get last_char_range"),
414    };
415
416    implementation.extend(quote!(
417        impl<'de> ::serde::de::Deserialize<'de> for #ident {
418            fn deserialize<D>(deserializer: D) -> ::core::result::Result<Self, D::Error>
419            where
420                D: ::serde::Deserializer<'de>,
421            {
422                deserializer.deserialize_str(#visitor)
423            }
424        }
425
426        #[allow(non_camel_case_types)]
427        struct #visitor;
428
429        impl<'de> ::serde::de::Visitor<'de> for #visitor {
430            type Value = #ident;
431
432            fn expecting(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
433                f.write_str("a base64url encoded string")
434            }
435
436            fn visit_str<E>(self, v: &str) -> ::core::result::Result<Self::Value, E>
437            where
438                E: ::serde::de::Error,
439            {
440                use ::core::str::FromStr;
441
442                const EXP1: &str = concat!("exactly ", #char_len, " base64url characters");
443                const EXP2: &str = concat!(
444                    "the last character must be one of the following: ",
445                    #last_char_range
446                );
447                const ERR: &str = concat!("unknown error! expected exactly ", #char_len, "base64url characters");
448
449                #ident::from_str(v).map_err(|e| match e {
450                    ::base64id::Error::InvalidLength => E::invalid_length(v.len(), &EXP1),
451                    ::base64id::Error::InvalidCharacter => E::invalid_value(
452                        ::serde::de::Unexpected::Other("1 or more non-base64url characters"),
453                        &EXP1,
454                    ),
455                    ::base64id::Error::OutOfBoundsCharacter => E::invalid_value(
456                        ::serde::de::Unexpected::Other("the last character was out of bounds"),
457                        &EXP2,
458                    ),
459                    _ => E::custom(ERR)
460                })
461            }
462        }
463    ));
464}
465
466/// Ensure data type is a tuple struct and contains one of the expected integer types inside
467fn get_validated_struct_data(data: syn::Data) -> syn::Ident {
468    let data = match data {
469        syn::Data::Struct(s) => s,
470        _ => panic!("unsupported data type. expected a tuple struct"),
471    };
472
473    let fields = match data.fields {
474        syn::Fields::Unnamed(f) => f.unnamed,
475        _ => panic!("unsupported data type. expected a tuple struct"),
476    };
477
478    let item = match fields.len() {
479        1 => fields.first().unwrap(),
480        _ => panic!("expected a tuple struct with exactly 1 field"),
481    };
482
483    let item_path = match item.ty.clone() {
484        syn::Type::Path(p) => p.path,
485        _ => panic!("{ERROR_INVALID_INNER_TYPE}"),
486    };
487
488    let item_type = match item_path.get_ident() {
489        Some(t) => t,
490        None => panic!("{ERROR_INVALID_INNER_TYPE}"),
491    };
492
493    match item_type.to_string().as_str() {
494        "i64" | "i32" | "i16" | "u64" | "u32" | "u16" => item_type.clone(),
495        _ => panic!("{ERROR_INVALID_INNER_TYPE}"),
496    }
497}