Skip to main content

bit_by_bit/
lib.rs

1extern crate proc_macro;
2
3use std::str::FromStr;
4
5use proc_macro::TokenStream;
6use quote::quote;
7use syn::punctuated::Punctuated;
8use syn::{parse, Token};
9
/// Width and signedness of one of the primitive integer types accepted by `#[bit(N)]`.
#[derive(Clone, Copy, PartialEq, Eq)]
struct TypeInfo {
    /// Number of bits in the type: 8, 16, 32, 64 or 128 when produced by `from_type_path`.
    bittness: usize,
    /// `true` for the `iN` types, `false` for the `uN` types.
    is_signed: bool,
}

16impl TypeInfo {
17    fn from_type_path(path: &syn::TypePath) -> Self {
18        assert!(path.qself.is_none());
19        let ty_str = path
20            .path
21            .get_ident()
22            .unwrap_or_else(|| {
23                let path = &path.path;
24                panic!("Unsupported type `{}`", quote!(#path))
25            })
26            .to_string();
27        match ty_str.as_str() {
28            "u8" => TypeInfo {
29                bittness: 8,
30                is_signed: false,
31            },
32            "i8" => TypeInfo {
33                bittness: 8,
34                is_signed: true,
35            },
36            "u16" => TypeInfo {
37                bittness: 16,
38                is_signed: false,
39            },
40            "i16" => TypeInfo {
41                bittness: 16,
42                is_signed: true,
43            },
44            "u32" => TypeInfo {
45                bittness: 32,
46                is_signed: false,
47            },
48            "i32" => TypeInfo {
49                bittness: 32,
50                is_signed: true,
51            },
52            "u64" => TypeInfo {
53                bittness: 64,
54                is_signed: false,
55            },
56            "i64" => TypeInfo {
57                bittness: 64,
58                is_signed: true,
59            },
60            "u128" => TypeInfo {
61                bittness: 128,
62                is_signed: false,
63            },
64            "i128" => TypeInfo {
65                bittness: 128,
66                is_signed: true,
67            },
68            s => panic!("Unsupported type `{}`", s),
69        }
70    }
71
72    fn from_type(ty: &syn::Type) -> Self {
73        match ty {
74            syn::Type::Path(ref path) => Self::from_type_path(path),
75            syn::Type::Group(group) => Self::from_type(&group.elem),
76            ty => panic!(
77                "Only primitive types are supported, but provided `{}`",
78                quote!(#ty)
79            ),
80        }
81    }
82}
83
84impl quote::ToTokens for TypeInfo {
85    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
86        let s = format!(
87            "{}{}",
88            if self.is_signed { 'i' } else { 'u' },
89            self.bittness
90        );
91        let stream = proc_macro::TokenStream::from_str(&s).unwrap();
92        let ty: syn::Type = parse(stream).unwrap();
93        tokens.extend(quote! { #ty })
94    }
95}
96
97struct PrevField {
98    ty_info: TypeInfo,
99    bits_left: usize,
100    common_field_name: String,
101}
102
103struct FieldInfo {
104    field_name: String,
105    common_field_name: String,
106    ty_info: TypeInfo,
107    start_bit: usize,
108    end_bit: usize,
109}
110
111#[proc_macro_attribute]
112pub fn bit_by_bit(_attr: TokenStream, input: TokenStream) -> TokenStream {
113    let syn::ItemStruct {
114        attrs,
115        vis,
116        ident,
117        generics,
118        fields,
119        ..
120    } = parse(input).expect("Expected a struct item");
121
122    let fields = match fields {
123        syn::Fields::Named(fields) => fields.named,
124        _ => panic!("Only structs with named fields are supported"),
125    };
126
127    let mut replaced_fileds = Vec::<FieldInfo>::new();
128    let mut new_fields = 0;
129    let mut prev_field = None::<PrevField>;
130
131    let fields = fields
132        .into_iter()
133        .filter_map(|mut f| {
134            if let Some(idx) = f
135                .attrs
136                .iter()
137                .position(|attr| attr.path.get_ident().unwrap().to_string().as_str() == "bit")
138            {
139                let ty_info = TypeInfo::from_type(&f.ty);
140                let bits = f
141                    .attrs
142                    .remove(idx)
143                    .parse_args::<syn::LitInt>()
144                    .unwrap()
145                    .base10_parse::<usize>()
146                    .unwrap();
147                assert!(
148                    bits <= ty_info.bittness,
149                    "bitness overflow, the type support up to {} bit, but {} provided",
150                    ty_info.bittness,
151                    bits
152                );
153
154                let field_name = f
155                    .ident
156                    .as_ref()
157                    .expect("fields are always named")
158                    .to_string();
159
160                if let Some(mut prev) = prev_field
161                    .take()
162                    .filter(|prev| prev.bits_left >= bits && prev.ty_info == ty_info)
163                {
164                    let start_bit = ty_info.bittness - prev.bits_left;
165                    prev.bits_left -= bits;
166                    let end_bit = ty_info.bittness - prev.bits_left;
167                    replaced_fileds.push(FieldInfo {
168                        field_name,
169                        common_field_name: prev.common_field_name.clone(),
170                        ty_info,
171                        start_bit,
172                        end_bit,
173                    });
174                    prev_field = Some(prev);
175                    return None;
176                } else {
177                    let common_field_name = format!("__base_field_{}", new_fields);
178                    replaced_fileds.push(FieldInfo {
179                        field_name,
180                        common_field_name: common_field_name.clone(),
181                        ty_info,
182                        start_bit: 0,
183                        end_bit: bits,
184                    });
185
186                    let ident = f.ident.as_mut().expect("checked earlier");
187                    let span = ident.span();
188                    *ident = syn::Ident::new(&common_field_name, span);
189
190                    prev_field = Some(PrevField {
191                        ty_info,
192                        bits_left: ty_info.bittness - bits,
193                        common_field_name,
194                    });
195                    new_fields += 1;
196                }
197            }
198
199            Some(f)
200        })
201        .collect::<Punctuated<_, Token![,]>>();
202
203    let fns = replaced_fileds
204        .iter()
205        .map(|info| {
206            let FieldInfo {
207                field_name,
208                common_field_name,
209                ty_info,
210                start_bit,
211                end_bit,
212            } = info;
213            let field_ident = quote::format_ident!("{}", field_name);
214            let set_field_ident = quote::format_ident!("set_{}", field_name);
215            let cfn = quote::format_ident!("{}", common_field_name);
216            let mut mask = 0u128;
217            for _ in 0..*end_bit {
218                mask <<= 1;
219                mask |= 1;
220            }
221            let mask = syn::LitInt::new(&mask.to_string(), proc_macro2::Span::call_site());
222
223            quote! {
224                fn #field_ident (&self) -> #ty_info {
225                    (self.#cfn >> #start_bit) & #mask
226                }
227
228                fn #set_field_ident (&mut self, val: #ty_info) {
229                    self.#cfn ^= (self.#cfn >> #start_bit) & #mask;
230                    self.#cfn |= (val & #mask) << #start_bit;
231                }
232            }
233        })
234        .collect::<Vec<_>>();
235
236    let generic_params = &generics.params;
237    let generic_param_names = generics
238        .params
239        .iter()
240        .map(|param| {
241            use syn::GenericParam;
242            match &param {
243                GenericParam::Type(ty) => {
244                    let ident = &ty.ident;
245                    quote! { #ident }
246                }
247                GenericParam::Lifetime(lf) => {
248                    let lifetime = &lf.lifetime;
249                    quote! { #lifetime }
250                }
251                GenericParam::Const(c) => {
252                    let ident = &c.ident;
253                    quote! { #ident }
254                }
255            }
256        })
257        .collect::<Punctuated<_, Token![,]>>();
258    let generic_where = &generics.where_clause;
259
260    let item = quote! {
261        #(#attrs)*
262        #vis struct #ident #generics
263        #generic_where
264        {
265            #fields
266        }
267
268        impl < #generic_params > #ident < #generic_param_names >
269        #generic_where
270        {
271            #(#fns)*
272        }
273    };
274
275    item.into()
276}