// bitstruct_derive/lib.rs

use core::{cmp::Ordering, convert::TryInto, fmt, ops::Range};

use proc_macro2::TokenStream;
use quote::{quote, ToTokens};
use syn::{
    parse::{Parse, ParseStream},
    parse_macro_input,
    punctuated::Punctuated,
    Token,
};

12#[proc_macro]
13pub fn bitstruct(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
14    let input = parse_macro_input!(tokens as BitStructInput);
15    expand_bitstruct(input)
16        .unwrap_or_else(|err| err.to_compile_error())
17        .into()
18}
19
20fn expand_bitstruct(input: BitStructInput) -> syn::Result<TokenStream> {
21    let attrs = &input.attrs;
22    let vis = &input.vis;
23    let name = &input.name;
24    let raw_vis = &input.raw_vis;
25    let raw = &input.raw.as_type();
26    let fields = input
27        .fields
28        .iter()
29        .map(|field| expand_field_methods(&input, field))
30        .collect::<syn::Result<Vec<TokenStream>>>()?;
31    Ok(quote! {
32        #(#attrs)*
33        #vis struct #name(#raw_vis #raw);
34        impl #name {
35            #(#fields)*
36        }
37    })
38}
39
40fn expand_field_methods(input: &BitStructInput, field: &FieldDef) -> syn::Result<TokenStream> {
41    // Extract any bitstruct specific field attributes.
42    let bitstruct_field_attrs = field
43        .attrs
44        .iter()
45        .find_map(|attr| {
46            let bitstruct: syn::Path = syn::parse_quote! {bitstruct};
47            match attr.parse_meta().ok()? {
48                syn::Meta::List(meta_list) if meta_list.path == bitstruct => Some(meta_list.nested),
49                _ => None,
50            }
51        })
52        .unwrap_or_default();
53
54    let getter_method = expand_field_getter(input, field);
55    let setter_methods = {
56        let omit_setter = bitstruct_field_attrs.iter().any(|nested_meta| {
57            let omit_setter: syn::NestedMeta = syn::parse_quote! {omit_setter};
58            nested_meta == &omit_setter
59        });
60
61        if omit_setter {
62            quote! {}
63        } else {
64            expand_field_setter(input, field)
65        }
66    };
67
68    Ok(quote! {
69        #getter_method
70        #setter_methods
71    })
72}
73
74fn expand_field_getter(input: &BitStructInput, field: &FieldDef) -> TokenStream {
75    // Only pass through the non-bitstruct field attributes.
76    let pass_thru_attrs = field.attrs.iter().filter(|&attr| {
77        let bitstruct: syn::Path = syn::parse_quote! {bitstruct};
78        attr.path != bitstruct
79    });
80
81    let target_ty = field.target.as_type();
82    let mask = hexlit(input.raw, field.bits.get_mask());
83    let start_bit = hexlit(input.raw, field.bits.0.start.into());
84    let mask_and_shift: syn::Expr = syn::parse_quote! {
85        ((self.0 & #mask) >> #start_bit)
86    };
87    let cast = from_raw(mask_and_shift, input.raw, &field.target, &field.bits);
88
89    let field_vis = &field.vis;
90    let field_name = &field.name;
91    let maybe_const_fn = if let Target::Convert(_) = field.target {
92        quote! {fn}
93    } else {
94        quote! {const fn}
95    };
96    quote! {
97        #(#pass_thru_attrs)*
98        #field_vis #maybe_const_fn #field_name(&self) -> #target_ty {
99            #cast
100        }
101    }
102}
103
104fn from_raw(raw_expr: syn::Expr, raw: RawDef, target: &Target, bitrange: &BitRange) -> syn::Expr {
105    match target {
106        Target::Int(raw_def) => {
107            let target_ty = raw_def.as_type();
108            syn::parse_quote! {
109                #raw_expr as #target_ty
110            }
111        }
112        Target::Bool => {
113            syn::parse_quote! {
114                #raw_expr != 0
115            }
116        }
117        Target::Convert(ty) => {
118            let bitlen = bitrange.0.end - bitrange.0.start;
119            let smallest_target = Target::smallest_target(bitlen);
120            let smallest_target_expr = from_raw(raw_expr, raw, &smallest_target, bitrange);
121            let smallest_target_ty = smallest_target.as_type();
122            syn::parse_quote! {
123                <Self as ::bitstruct::FromRaw<#smallest_target_ty, #ty>>::from_raw(#smallest_target_expr)
124            }
125        }
126    }
127}
128
129fn expand_field_setter(input: &BitStructInput, field: &FieldDef) -> TokenStream {
130    // Only pass through the non-bitstruct field attributes.
131    let pass_thru_attrs = field
132        .attrs
133        .iter()
134        .filter(|&attr| {
135            let bitstruct: syn::Path = syn::parse_quote! {bitstruct};
136            attr.path != bitstruct
137        })
138        .collect::<Vec<_>>();
139
140    let target_ty = field.target.as_type();
141    let mask = field.bits.get_mask();
142    let neg_mask = hexlit(input.raw, !mask);
143    let mask = hexlit(input.raw, mask);
144    let start_bit = hexlit(input.raw, field.bits.0.start.into());
145
146    let field_vis = &field.vis;
147    let field_name = &field.name;
148    let with_method = quote::format_ident!("with_{}", field_name);
149    let set_method = quote::format_ident!("set_{}", field_name);
150    let cast = into_raw(
151        syn::parse_quote! {value},
152        &field.target,
153        input.raw,
154        &field.bits,
155    );
156    let maybe_const_fn = if let Target::Convert(_) = field.target {
157        quote! {fn}
158    } else {
159        quote! {const fn}
160    };
161    quote! {
162        #[must_use]
163        #(#pass_thru_attrs)*
164        #field_vis #maybe_const_fn #with_method(mut self, value: #target_ty) -> Self {
165            self.0 = (self.0 & #neg_mask) | ((#cast << #start_bit) & #mask);
166            self
167        }
168
169        #(#pass_thru_attrs)*
170        #field_vis fn #set_method(&mut self, value: #target_ty) {
171            self.0 = (self.0 & #neg_mask) | ((#cast << #start_bit) & #mask);
172        }
173    }
174}
175
176fn into_raw(
177    target_expr: syn::Expr,
178    target: &Target,
179    raw: RawDef,
180    bitrange: &BitRange,
181) -> syn::Expr {
182    match target {
183        Target::Int(_) | Target::Bool => {
184            let raw = raw.as_type();
185            syn::parse_quote! {
186                (#target_expr as #raw)
187            }
188        }
189        Target::Convert(ty) => {
190            let bitlen = bitrange.0.end - bitrange.0.start;
191            let smallest_target = Target::smallest_target(bitlen);
192            let smallest_target_ty = smallest_target.as_type();
193            let smallest_target_expr = syn::parse_quote! {
194                <Self as ::bitstruct::IntoRaw<#smallest_target_ty, #ty>>::into_raw(#target_expr)
195            };
196            into_raw(smallest_target_expr, &smallest_target, raw, bitrange)
197        }
198    }
199}
200
201/// Helper methods on ParseStream that attempt to parse an item but only advance the cursor on success.
202trait TryParse {
203    fn try_parse<T: Parse>(&self) -> syn::Result<T>;
204    fn try_call<T>(&self, function: fn(_: ParseStream<'_>) -> syn::Result<T>) -> syn::Result<T>;
205}
206
207impl TryParse for ParseStream<'_> {
208    fn try_parse<T: Parse>(&self) -> syn::Result<T> {
209        use syn::parse::discouraged::Speculative;
210        let fork = self.fork();
211        match fork.parse::<T>() {
212            Ok(value) => {
213                self.advance_to(&fork);
214                Ok(value)
215            }
216            err => err,
217        }
218    }
219
220    fn try_call<T>(&self, function: fn(_: ParseStream<'_>) -> syn::Result<T>) -> syn::Result<T> {
221        use syn::parse::discouraged::Speculative;
222        let fork = self.fork();
223        match fork.call(function) {
224            Ok(value) => {
225                self.advance_to(&fork);
226                Ok(value)
227            }
228            err => err,
229        }
230    }
231}
232
233#[derive(Debug)]
234struct BitStructInput {
235    attrs: Vec<syn::Attribute>,
236    vis: syn::Visibility,
237    name: syn::Ident,
238    raw_vis: syn::Visibility,
239    raw: RawDef,
240    fields: Punctuated<FieldDef, Token![;]>,
241}
242
243impl Parse for BitStructInput {
244    fn parse(input: ParseStream) -> syn::Result<Self> {
245        let attrs = input.call(syn::Attribute::parse_outer)?;
246        let vis = input.parse()?;
247        input.parse::<Token![struct]>()?;
248        let name = input.parse()?;
249        let within_parens;
250        syn::parenthesized!(within_parens in input);
251        let raw_vis = within_parens.parse()?;
252        let raw: RawDef = within_parens.parse()?;
253        let within_braces;
254        syn::braced!(within_braces in input);
255        let fields: Punctuated<FieldDef, _> = Punctuated::parse_terminated(&within_braces)?;
256        for field in fields.iter() {
257            if field.bits.0.end > raw.bit_len() {
258                return Err(syn::Error::new(
259                    field.name.span(),
260                    format!(
261                        "field `{}` specifies a bitrange beyond `{}` range",
262                        field.name,
263                        raw.as_str()
264                    ),
265                ));
266            }
267        }
268        Ok(BitStructInput {
269            attrs,
270            vis,
271            name,
272            raw_vis,
273            raw,
274            fields,
275        })
276    }
277}
278
279#[derive(Debug, Copy, Clone, Eq, PartialEq)]
280enum RawDef {
281    U8,
282    U16,
283    U32,
284    U64,
285    U128,
286}
287
288impl RawDef {
289    fn as_str(self) -> &'static str {
290        match self {
291            RawDef::U8 => "u8",
292            RawDef::U16 => "u16",
293            RawDef::U32 => "u32",
294            RawDef::U64 => "u64",
295            RawDef::U128 => "u128",
296        }
297    }
298
299    fn as_type(self) -> syn::Type {
300        syn::parse_str(self.as_str()).unwrap()
301    }
302
303    fn bit_len(self) -> u8 {
304        match self {
305            RawDef::U8 => 8,
306            RawDef::U16 => 16,
307            RawDef::U32 => 32,
308            RawDef::U64 => 64,
309            RawDef::U128 => 128,
310        }
311    }
312}
313
314impl Parse for RawDef {
315    fn parse(input: ParseStream) -> syn::Result<Self> {
316        let ident: syn::Ident = input.parse()?;
317        if ident == "u8" {
318            Ok(RawDef::U8)
319        } else if ident == "u16" {
320            Ok(RawDef::U16)
321        } else if ident == "u32" {
322            Ok(RawDef::U32)
323        } else if ident == "u64" {
324            Ok(RawDef::U64)
325        } else if ident == "u128" {
326            Ok(RawDef::U128)
327        } else {
328            Err(input.error(format!(
329                "`{}` is not supported; needs to be one of u8,u16,u32,u64,u128",
330                ident
331            )))
332        }
333    }
334}
335
336#[derive(Debug)]
337struct FieldDef {
338    attrs: Vec<syn::Attribute>,
339    vis: syn::Visibility,
340    name: syn::Ident,
341    target: Target,
342    bits: BitRange,
343}
344
345impl Parse for FieldDef {
346    fn parse(input: ParseStream) -> syn::Result<Self> {
347        let attrs = input.call(syn::Attribute::parse_outer)?;
348        let vis = input.parse()?;
349        let name = input.parse()?;
350        input.parse::<Token![:]>()?;
351        let target: Target = input.parse()?;
352        input.parse::<Token![=]>()?;
353        let bits: BitRange = input.parse()?;
354        if target.bit_len() < bits.bit_len() {
355            return Err(input.error(format!(
356                "target `{}` can only represent {} bits; {} specified",
357                target,
358                target.bit_len(),
359                bits.bit_len(),
360            )));
361        }
362        Ok(FieldDef {
363            attrs,
364            vis,
365            name,
366            target,
367            bits,
368        })
369    }
370}
371
372#[derive(Debug, Eq, PartialEq)]
373enum Target {
374    /// Basic integer type: u8,u16,u32,u64,u128
375    Int(RawDef),
376    /// bool
377    Bool,
378    /// A type that will be converted to/from using bitstruct::{FromRaw, IntoRaw}
379    Convert(syn::Type),
380}
381
382impl Target {
383    fn smallest_target(bitlen: u8) -> Target {
384        match bitlen {
385            x if x == 1 => Target::Bool,
386            x if x <= 8 => Target::Int(RawDef::U8),
387            x if x <= 16 => Target::Int(RawDef::U16),
388            x if x <= 32 => Target::Int(RawDef::U32),
389            x if x <= 64 => Target::Int(RawDef::U64),
390            x if x <= 128 => Target::Int(RawDef::U128),
391            _ => unreachable!("invalid bitlen"),
392        }
393    }
394
395    fn bit_len(&self) -> u8 {
396        match self {
397            Target::Int(raw) => raw.bit_len(),
398            Target::Bool => 1,
399            Target::Convert(_) => u8::MAX,
400        }
401    }
402
403    fn as_type(&self) -> syn::Type {
404        match self {
405            Target::Int(raw) => raw.as_type(),
406            Target::Bool => syn::parse_quote! {bool},
407            Target::Convert(ty) => ty.clone().into(),
408        }
409    }
410}
411
412impl fmt::Display for Target {
413    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
414        match self {
415            Target::Int(rawdef) => write!(f, "{}", rawdef.as_str()),
416            Target::Bool => write!(f, "bool"),
417            Target::Convert(ty) => write!(f, "{:?}", ty),
418        }
419    }
420}
421
422mod kw {
423    syn::custom_keyword!(bool);
424}
425
426impl Parse for Target {
427    fn parse(input: ParseStream) -> syn::Result<Self> {
428        input
429            .try_parse::<RawDef>()
430            .map(|raw_def| Target::Int(raw_def))
431            .or_else(|_| input.try_parse::<kw::bool>().map(|_| Target::Bool))
432            .or_else(|_| input.try_parse::<syn::Type>().map(|ty| Target::Convert(ty)))
433    }
434}
435
/// A half-open range of bit positions `start..end` within the raw integer.
#[derive(Debug, Eq, PartialEq)]
struct BitRange(Range<u8>);

impl BitRange {
    /// Number of bits covered by the range.
    fn bit_len(&self) -> u8 {
        let count: usize = self.0.len();
        count.try_into().unwrap()
    }

    /// A `u128` with exactly the bits in `start..end` set.
    fn get_mask(&self) -> u128 {
        let (start, end) = (self.0.start, self.0.end);
        // Push an all-ones value up against bit 127 and back down to clear bits >= `end`,
        // then down and back up to clear bits < `start`.
        let below_end = (!0u128 << (128 - end)) >> (128 - end);
        (below_end >> start) << start
    }
}

454impl Parse for BitRange {
455    fn parse(input: ParseStream) -> syn::Result<Self> {
456        fn parse_end_range(input: ParseStream) -> syn::Result<u8> {
457            let range_limits: syn::RangeLimits = input.parse()?;
458            let end_bit: u8 = input.parse::<syn::LitInt>()?.base10_parse()?;
459            Ok(match range_limits {
460                syn::RangeLimits::HalfOpen(_) => end_bit,
461                syn::RangeLimits::Closed(_) => end_bit + 1,
462            })
463        }
464
465        let start_bit: u8 = input.parse::<syn::LitInt>()?.base10_parse()?;
466        let range = match input.try_call(parse_end_range) {
467            Ok(end_bit) => start_bit..end_bit,
468            Err(_) => start_bit..start_bit + 1,
469        };
470        match range.start.cmp(&range.end) {
471            Ordering::Less => {}
472            Ordering::Equal => return Err(input.error("empty bit range specified")),
473            Ordering::Greater => {
474                return Err(input
475                    .error("least significant bit must be specified before most significant bit"))
476            }
477        };
478        Ok(BitRange(range))
479    }
480}
481
482fn hexlit(typ: RawDef, value: u128) -> syn::LitInt {
483    let num_hex_chars = typ.bit_len() as usize / 4;
484    syn::LitInt::new(
485        &format!(
486            "0x{value:0width$x}{suffix:}",
487            value = value,
488            suffix = typ.as_str(),
489            width = num_hex_chars
490        ),
491        proc_macro2::Span::call_site(),
492    )
493}
494
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_bitstruct_input() {
        let bitstruct: BitStructInput = syn::parse2(quote! {
            #[derive(Clone,Copy)]
            pub(crate) struct Foo(pub u16) {
                #[inline]
                pub f1: u8 = 0 .. 8;
                pub f2: u8 = 8 .. 12;
            }
        })
        .unwrap();
        assert_eq!(bitstruct.name, quote::format_ident!("Foo"));
        assert_eq!(bitstruct.fields.len(), 2);
        assert_eq!(bitstruct.fields[0].attrs.len(), 1);
        assert_eq!(bitstruct.fields[1].attrs.len(), 0);
    }

    #[test]
    fn reject_field_beyond_raw_width() {
        let result = syn::parse2::<BitStructInput>(quote! {
            struct Foo(u8) {
                f1: u8 = 4 .. 9;
            }
        });
        assert!(result.is_err());
    }

    #[test]
    fn parse_field_def() {
        let field_def: FieldDef = syn::parse2(quote! {
            pub field1: u8 = 3 .. 5
        })
        .unwrap();
        assert_eq!(field_def.name, quote::format_ident!("field1"));
        assert_eq!(field_def.target, Target::Int(RawDef::U8));
        assert_eq!(field_def.bits, BitRange(3..5));

        let field_def: FieldDef = syn::parse2(quote! {
            pub field1: bool = 3
        })
        .unwrap();
        assert_eq!(field_def.name, quote::format_ident!("field1"));
        assert_eq!(field_def.target, Target::Bool);
        assert_eq!(field_def.bits, BitRange(3..4));
    }

    #[test]
    fn reject_field_wider_than_target() {
        assert!(syn::parse2::<FieldDef>(quote! { pub flag: bool = 0 .. 2 }).is_err());
        assert!(syn::parse2::<FieldDef>(quote! { pub small: u8 = 0 .. 9 }).is_err());
    }

    #[test]
    fn parse_target() {
        assert_eq!(
            Target::Int(RawDef::U8),
            syn::parse2::<Target>(quote! {u8}).unwrap(),
        );
        assert_eq!(
            Target::Int(RawDef::U16),
            syn::parse2::<Target>(quote! {u16}).unwrap(),
        );
        assert_eq!(
            Target::Int(RawDef::U128),
            syn::parse2::<Target>(quote! {u128}).unwrap(),
        );
        assert_eq!(Target::Bool, syn::parse2::<Target>(quote! {bool}).unwrap(),);
        assert_eq!(
            Target::Convert(syn::parse_quote! {MyEnum}),
            syn::parse2::<Target>(quote! {MyEnum}).unwrap(),
        );
        assert_eq!(
            Target::Convert(syn::parse_quote! {Vec<u32>}),
            syn::parse2::<Target>(quote! {Vec<u32>}).unwrap(),
        );
    }

    #[test]
    fn parse_raw_def() {
        assert_eq!(RawDef::U32, syn::parse2::<RawDef>(quote! {u32}).unwrap());
        assert!(syn::parse2::<RawDef>(quote! {i32}).is_err());
    }

    #[test]
    fn parse_bitrange() {
        assert_eq!(
            BitRange(0..10),
            syn::parse2::<BitRange>(quote! {0..10}).unwrap()
        );
        assert_eq!(
            BitRange(0..12),
            syn::parse2::<BitRange>(quote! {0..=11}).unwrap()
        );
        assert_eq!(
            BitRange(14..15),
            syn::parse2::<BitRange>(quote! {14}).unwrap()
        );
    }

    #[test]
    fn reject_invalid_bitrange() {
        assert!(syn::parse2::<BitRange>(quote! {5..5}).is_err());
        assert!(syn::parse2::<BitRange>(quote! {5..3}).is_err());
    }

    #[test]
    fn bitrange_mask() {
        assert_eq!(BitRange(0..8).get_mask(), 0xff);
        assert_eq!(BitRange(4..8).get_mask(), 0xf0);
        assert_eq!(BitRange(127..128).get_mask(), 1u128 << 127);
        assert_eq!(BitRange(0..128).get_mask(), u128::MAX);
    }

    #[test]
    fn smallest_target_widths() {
        assert_eq!(Target::smallest_target(1), Target::Bool);
        assert_eq!(Target::smallest_target(2), Target::Int(RawDef::U8));
        assert_eq!(Target::smallest_target(9), Target::Int(RawDef::U16));
        assert_eq!(Target::smallest_target(128), Target::Int(RawDef::U128));
    }

    #[test]
    fn hexlit_pads_to_type_width() {
        assert_eq!(hexlit(RawDef::U8, 0x0f).to_string(), "0x0fu8");
        assert_eq!(hexlit(RawDef::U16, 0xf0).to_string(), "0x00f0u16");
    }
}