pod_derive/
lib.rs

1//! This crate provides a procedural macro for deriving the `Pod` trait defined in `pod-rs`.
2
3use proc_macro2::{Ident, TokenStream};
4use quote::quote;
5use syn::{
6    parse_macro_input, Attribute, Data, DataStruct, DataUnion, DeriveInput, Fields,
7    Generics,
8};
9
10/// Deriving [`Pod`] trait for a struct or union. 
11///
12/// When deriving the `Pod` trait,
13/// this macro performs a safety check because the `Pod` trait is marked as unsafe.
14/// For structs and unions, 
15/// the macro checks that the struct has a valid repr attribute (e.g., `repr(C)`, `repr(u8)`),
16/// and each field is of `Pod` type.
17/// Enums cannot implement the `Pod` trait.
18/// 
19/// If you want to implement `Pod` 
20/// for a struct or union with fields that are not of Pod type,
21/// you can implement it unsafely and perform the necessary checks manually.
22/// 
23/// [`Pod`]: https://docs.rs/pod-rs/latest/pod_rs/trait.Pod.html
24#[proc_macro_derive(Pod)]
25pub fn derive_pod(input_token: proc_macro::TokenStream) -> proc_macro::TokenStream {
26    let input = parse_macro_input!(input_token as DeriveInput);
27    expand_derive_pod(input).into()
28}
29
/// Entries accepted inside `#[repr(...)]` by `has_valid_repr`: `C` plus the
/// primitive integer representations. An item passes the check if any one of
/// its `repr` entries appears in this list.
const ALLOWED_REPRS: [&str; 11] = [
    // C-compatible layout.
    "C",
    // Primitive integer representations, unsigned/signed pairs by width.
    "u8", "i8",
    "u16", "i16",
    "u32", "i32",
    "u64", "i64",
    "usize", "isize",
];
33
34fn expand_derive_pod(input: DeriveInput) -> TokenStream {
35    let attrs = input.attrs;
36    let ident = input.ident;
37    let generics = input.generics;
38    match input.data {
39        Data::Struct(data_struct) => impl_pod_for_struct(data_struct, generics, ident, attrs),
40        Data::Union(data_union) => impl_pod_for_union(data_union, generics, ident, attrs),
41        Data::Enum(_) => panic!("Trying to derive `Pod` trait for enum may be unsound. Use `TryFromInt` instead."),
42    }
43}
44
45fn impl_pod_for_struct(
46    data_struct: DataStruct,
47    generics: Generics,
48    ident: Ident,
49    attrs: Vec<Attribute>,
50) -> TokenStream {
51    if !has_valid_repr(attrs) {
52        panic!("{} has invalid repr to implement Pod", ident.to_string());
53    }
54    let DataStruct { fields, .. } = data_struct;
55    let fields = match fields {
56        Fields::Named(fields_named) => fields_named.named,
57        Fields::Unnamed(fields_unnamed) => fields_unnamed.unnamed,
58        Fields::Unit => panic!("derive pod does not work for struct with unit field"),
59    };
60
61    // deal with generics
62    let (impl_generics, type_generics, where_clause) = generics.split_for_impl();
63
64    let pod_where_predicates = fields
65        .into_iter()
66        .map(|field| {
67            let field_ty = field.ty;
68            quote! {
69                #field_ty: ::pod_rs::Pod
70            }
71        })
72        .collect::<Vec<_>>();
73
74    // if where_clause is none, we should add a `where` word manually.
75    if where_clause.is_none() {
76        quote! {
77            #[automatically_derived]
78            unsafe impl #impl_generics ::pod_rs::Pod for #ident #type_generics where #(#pod_where_predicates),* {}
79        }
80    } else {
81        quote! {
82            #[automatically_derived]
83            unsafe impl #impl_generics ::pod_rs::Pod for #ident #type_generics #where_clause, #(#pod_where_predicates),* {}
84        }
85    }
86}
87
88fn impl_pod_for_union(
89    data_union: DataUnion,
90    generics: Generics,
91    ident: Ident,
92    attrs: Vec<Attribute>,
93) -> TokenStream {
94    if !has_valid_repr(attrs) {
95        panic!("{} has invalid repr to implement Pod", ident.to_string());
96    }
97    let fields = data_union.fields.named;
98    // deal with generics
99    let (impl_generics, type_generics, where_clause) = generics.split_for_impl();
100
101    let pod_where_predicates = fields
102        .into_iter()
103        .map(|field| {
104            let field_ty = field.ty;
105            quote! {
106                #field_ty: ::pod_rs::Pod
107            }
108        })
109        .collect::<Vec<_>>();
110
111    // if where_clause is none, we should add a `where` word manually.
112    if where_clause.is_none() {
113        quote! {
114            #[automatically_derived]
115            unsafe impl #impl_generics ::pod_rs::Pod for #ident #type_generics where #(#pod_where_predicates),* {}
116        }
117    } else {
118        quote! {
119            #[automatically_derived]
120            unsafe impl #impl_generics ::pod_rs::Pod for #ident #type_generics #where_clause, #(#pod_where_predicates),* {}
121        }
122    }
123}
124
125fn has_valid_repr(attrs: Vec<Attribute>) -> bool {
126    for attr in attrs {
127        if let Some(ident) = attr.path.get_ident() {
128            if "repr" == ident.to_string().as_str() {
129                let repr = attr.tokens.to_string();
130                let repr = repr.replace("(", "").replace(")", "");
131                let reprs = repr
132                    .split(",")
133                    .map(|one_repr| one_repr.trim())
134                    .collect::<Vec<_>>();
135                if let Some(_) = ALLOWED_REPRS.iter().position(|allowed_repr| {
136                    reprs
137                        .iter()
138                        .position(|one_repr| one_repr == allowed_repr)
139                        .is_some()
140                }) {
141                    return true;
142                }
143            }
144        }
145    }
146    false
147}