ostd_pod_derive/
lib.rs

1//! This crate provides a procedural macro for deriving the `Pod` trait defined in `pod-rs`.
2
3use proc_macro2::{Ident, TokenStream};
4use quote::quote;
5use syn::{
6    parse_macro_input, Attribute, Data, DataStruct, DataUnion, DeriveInput, Fields,
7    Generics,
8};
9
10/// Deriving [`Pod`] trait for a struct or union. 
11///
12/// When deriving the `Pod` trait,
13/// this macro performs a safety check because the `Pod` trait is marked as unsafe.
14/// For structs and unions, 
15/// the macro checks that the struct has a valid repr attribute (e.g., `repr(C)`, `repr(u8)`),
16/// and each field is of `Pod` type.
17/// Enums cannot implement the `Pod` trait.
18/// 
19/// If you want to implement `Pod` 
20/// for a struct or union with fields that are not of Pod type,
21/// you can implement it unsafely and perform the necessary checks manually.
22/// 
23/// [`Pod`]: https://docs.rs/pod-rs/latest/ostd_pod/trait.Pod.html
24#[proc_macro_derive(Pod)]
25pub fn derive_pod(input_token: proc_macro::TokenStream) -> proc_macro::TokenStream {
26    let input = parse_macro_input!(input_token as DeriveInput);
27    expand_derive_pod(input).into()
28}
29
/// `repr` arguments that make a type eligible for `#[derive(Pod)]`.
///
/// A `#[repr(...)]` attribute is accepted if any of its comma-separated
/// arguments matches one of these names exactly. Modifiers such as `packed` or
/// `align(N)` are allowed alongside them but are not sufficient on their own.
// `const` items are implicitly `'static`, so no explicit lifetime is written.
const ALLOWED_REPRS: [&str; 11] = [
    "C", "u8", "i8", "u16", "i16", "u32", "i32", "u64", "i64", "usize", "isize",
];
33
34fn expand_derive_pod(input: DeriveInput) -> TokenStream {
35    let attrs = input.attrs;
36    let ident = input.ident;
37    let generics = input.generics;
38    match input.data {
39        Data::Struct(data_struct) => impl_pod_for_struct(data_struct, generics, ident, attrs),
40        Data::Union(data_union) => impl_pod_for_union(data_union, generics, ident, attrs),
41        Data::Enum(_) => panic!("Trying to derive `Pod` trait for enum may be unsound. Use `TryFromInt` instead."),
42    }
43}
44
45fn impl_pod_for_struct(
46    data_struct: DataStruct,
47    generics: Generics,
48    ident: Ident,
49    attrs: Vec<Attribute>,
50) -> TokenStream {
51    if !has_valid_repr(attrs) {
52        panic!("{} has invalid repr to implement Pod", ident.to_string());
53    }
54    let DataStruct { fields, .. } = data_struct;
55    let fields = match fields {
56        Fields::Named(fields_named) => fields_named.named,
57        Fields::Unnamed(fields_unnamed) => fields_unnamed.unnamed,
58        Fields::Unit => panic!("derive pod does not work for struct with unit field"),
59    };
60
61    // deal with generics
62    let (impl_generics, type_generics, where_clause) = generics.split_for_impl();
63
64    let trait_tokens = pod_trait_tokens();
65
66    let pod_where_predicates = fields
67        .into_iter()
68        .map(|field| {
69            let field_ty = field.ty;
70            quote! {
71                #field_ty: #trait_tokens
72            }
73        })
74        .collect::<Vec<_>>();
75
76    // if where_clause is none, we should add a `where` word manually.
77    if where_clause.is_none() {
78        quote! {
79            #[automatically_derived]
80            unsafe impl #impl_generics #trait_tokens for #ident #type_generics where #(#pod_where_predicates),* {}
81        }
82    } else {
83        quote! {
84            #[automatically_derived]
85            unsafe impl #impl_generics #trait_tokens for #ident #type_generics #where_clause, #(#pod_where_predicates),* {}
86        }
87    }
88}
89
90fn impl_pod_for_union(
91    data_union: DataUnion,
92    generics: Generics,
93    ident: Ident,
94    attrs: Vec<Attribute>,
95) -> TokenStream {
96    if !has_valid_repr(attrs) {
97        panic!("{} has invalid repr to implement Pod", ident.to_string());
98    }
99    let fields = data_union.fields.named;
100    // deal with generics
101    let (impl_generics, type_generics, where_clause) = generics.split_for_impl();
102
103    let trait_tokens = pod_trait_tokens();
104    let pod_where_predicates = fields
105        .into_iter()
106        .map(|field| {
107            let field_ty = field.ty;
108            quote! {
109                #field_ty: #trait_tokens
110            }
111        })
112        .collect::<Vec<_>>();
113
114    // if where_clause is none, we should add a `where` word manually.
115    if where_clause.is_none() {
116        quote! {
117            #[automatically_derived]
118            unsafe impl #impl_generics #trait_tokens for #ident #type_generics where #(#pod_where_predicates),* {}
119        }
120    } else {
121        quote! {
122            #[automatically_derived]
123            unsafe impl #impl_generics #trait_tokens for #ident #type_generics #where_clause, #(#pod_where_predicates),* {}
124        }
125    }
126}
127
128fn has_valid_repr(attrs: Vec<Attribute>) -> bool {
129    for attr in attrs {
130        if let Some(ident) = attr.path.get_ident() {
131            if "repr" == ident.to_string().as_str() {
132                let repr = attr.tokens.to_string();
133                let repr = repr.replace("(", "").replace(")", "");
134                let reprs = repr
135                    .split(",")
136                    .map(|one_repr| one_repr.trim())
137                    .collect::<Vec<_>>();
138                if let Some(_) = ALLOWED_REPRS.iter().position(|allowed_repr| {
139                    reprs
140                        .iter()
141                        .position(|one_repr| one_repr == allowed_repr)
142                        .is_some()
143                }) {
144                    return true;
145                }
146            }
147        }
148    }
149    false
150}
151
152fn pod_trait_tokens() -> TokenStream {
153    let package_name = std::env::var("CARGO_PKG_NAME").unwrap();
154
155    // Only `ostd` and the unit test in `ostd-pod` depend on `Pod` fro `ostd-pod`,
156    // other crates depend on `Pod` re-exported from ostd.
157    if package_name.as_str() == "ostd" || package_name.as_str() == "ostd-pod" {
158        quote! {
159            ::ostd_pod::Pod
160        }
161    } else {
162        quote! {
163            ::ostd::Pod
164        }
165    }
166}