1use proc_macro2::{Ident, TokenStream};
4use quote::quote;
5use syn::{
6 parse_macro_input, Attribute, Data, DataStruct, DataUnion, DeriveInput, Fields,
7 Generics,
8};
9
10#[proc_macro_derive(Pod)]
25pub fn derive_pod(input_token: proc_macro::TokenStream) -> proc_macro::TokenStream {
26 let input = parse_macro_input!(input_token as DeriveInput);
27 expand_derive_pod(input).into()
28}
29
/// `repr` arguments that make a type eligible for `#[derive(Pod)]`.
///
/// A type is accepted if any of its `#[repr(...)]` attributes lists at least
/// one of these names; other arguments alongside it (e.g. `packed`,
/// `align(N)`) are ignored by the check.
// `&str` in a `const` already has the `'static` lifetime, so spelling it out
// is redundant (clippy::redundant_static_lifetimes).
const ALLOWED_REPRS: [&str; 11] = [
    "C", "u8", "i8", "u16", "i16", "u32", "i32", "u64", "i64", "usize", "isize",
];
33
34fn expand_derive_pod(input: DeriveInput) -> TokenStream {
35 let attrs = input.attrs;
36 let ident = input.ident;
37 let generics = input.generics;
38 match input.data {
39 Data::Struct(data_struct) => impl_pod_for_struct(data_struct, generics, ident, attrs),
40 Data::Union(data_union) => impl_pod_for_union(data_union, generics, ident, attrs),
41 Data::Enum(_) => panic!("Trying to derive `Pod` trait for enum may be unsound. Use `TryFromInt` instead."),
42 }
43}
44
45fn impl_pod_for_struct(
46 data_struct: DataStruct,
47 generics: Generics,
48 ident: Ident,
49 attrs: Vec<Attribute>,
50) -> TokenStream {
51 if !has_valid_repr(attrs) {
52 panic!("{} has invalid repr to implement Pod", ident.to_string());
53 }
54 let DataStruct { fields, .. } = data_struct;
55 let fields = match fields {
56 Fields::Named(fields_named) => fields_named.named,
57 Fields::Unnamed(fields_unnamed) => fields_unnamed.unnamed,
58 Fields::Unit => panic!("derive pod does not work for struct with unit field"),
59 };
60
61 let (impl_generics, type_generics, where_clause) = generics.split_for_impl();
63
64 let trait_tokens = pod_trait_tokens();
65
66 let pod_where_predicates = fields
67 .into_iter()
68 .map(|field| {
69 let field_ty = field.ty;
70 quote! {
71 #field_ty: #trait_tokens
72 }
73 })
74 .collect::<Vec<_>>();
75
76 if where_clause.is_none() {
78 quote! {
79 #[automatically_derived]
80 unsafe impl #impl_generics #trait_tokens for #ident #type_generics where #(#pod_where_predicates),* {}
81 }
82 } else {
83 quote! {
84 #[automatically_derived]
85 unsafe impl #impl_generics #trait_tokens for #ident #type_generics #where_clause, #(#pod_where_predicates),* {}
86 }
87 }
88}
89
90fn impl_pod_for_union(
91 data_union: DataUnion,
92 generics: Generics,
93 ident: Ident,
94 attrs: Vec<Attribute>,
95) -> TokenStream {
96 if !has_valid_repr(attrs) {
97 panic!("{} has invalid repr to implement Pod", ident.to_string());
98 }
99 let fields = data_union.fields.named;
100 let (impl_generics, type_generics, where_clause) = generics.split_for_impl();
102
103 let trait_tokens = pod_trait_tokens();
104 let pod_where_predicates = fields
105 .into_iter()
106 .map(|field| {
107 let field_ty = field.ty;
108 quote! {
109 #field_ty: #trait_tokens
110 }
111 })
112 .collect::<Vec<_>>();
113
114 if where_clause.is_none() {
116 quote! {
117 #[automatically_derived]
118 unsafe impl #impl_generics #trait_tokens for #ident #type_generics where #(#pod_where_predicates),* {}
119 }
120 } else {
121 quote! {
122 #[automatically_derived]
123 unsafe impl #impl_generics #trait_tokens for #ident #type_generics #where_clause, #(#pod_where_predicates),* {}
124 }
125 }
126}
127
128fn has_valid_repr(attrs: Vec<Attribute>) -> bool {
129 for attr in attrs {
130 if let Some(ident) = attr.path.get_ident() {
131 if "repr" == ident.to_string().as_str() {
132 let repr = attr.tokens.to_string();
133 let repr = repr.replace("(", "").replace(")", "");
134 let reprs = repr
135 .split(",")
136 .map(|one_repr| one_repr.trim())
137 .collect::<Vec<_>>();
138 if let Some(_) = ALLOWED_REPRS.iter().position(|allowed_repr| {
139 reprs
140 .iter()
141 .position(|one_repr| one_repr == allowed_repr)
142 .is_some()
143 }) {
144 return true;
145 }
146 }
147 }
148 }
149 false
150}
151
152fn pod_trait_tokens() -> TokenStream {
153 let package_name = std::env::var("CARGO_PKG_NAME").unwrap();
154
155 if package_name.as_str() == "ostd" || package_name.as_str() == "ostd-pod" {
158 quote! {
159 ::ostd_pod::Pod
160 }
161 } else {
162 quote! {
163 ::ostd::Pod
164 }
165 }
166}