bitregions_impl/
lib.rs

1//
2// Copyright (c) Zach Marcantel. All rights reserved.
3// Licensed under the GPLv3. See LICENSE file in the project root
4// for full license information.
5//
6
7extern crate proc_macro;
8
9#[macro_use]
10extern crate quote;
11
12/// Wrapper type for the visibility of the generated struct
13/// and the parsed syntax defining the regions.
14struct BitRegions {
15    vis: Option<syn::token::Pub>,
16    struct_def: Struct,
17    user_fns: UserFns,
18}
19
20impl syn::parse::Parse for BitRegions {
21    fn parse(input: syn::parse::ParseStream) -> syn::parse::Result<Self> {
22        let vis = if input.peek(syn::token::Pub) {
23            Some(input.parse()?)
24        } else {
25            None
26        };
27
28        let struct_def: Struct = input.parse()?;
29        let user_fns: UserFns = input.parse()?;
30
31        Ok(BitRegions {
32            vis: vis,
33            struct_def,
34            user_fns,
35        })
36    }
37}
38
39/// The memory location to default initialize the region to.
40enum MemoryLocation {
41    Lit(syn::LitInt),
42    Ident(syn::Ident),
43    Expr(syn::ExprBinary),
44}
45
46impl syn::parse::Parse for MemoryLocation {
47    fn parse(input: syn::parse::ParseStream) -> syn::parse::Result<Self> {
48        if !input.peek2(syn::token::Brace) {
49            return Ok(MemoryLocation::Expr(input.parse()?));
50        }
51
52        if input.peek(syn::Ident) {
53            return Ok(MemoryLocation::Ident(input.parse()?));
54        }
55        if input.peek(syn::LitInt) {
56            return Ok(MemoryLocation::Lit(input.parse()?));
57        }
58
59        Err(syn::Error::new(
60            input.span(),
61            "expected ident, literal, or const expression",
62        ))
63    }
64}
65
66/// Holds the identitiy, numeric type representation, and regions.
67struct Struct {
68    ident: syn::Ident,
69    repr: syn::Type,
70    default_loc: Option<MemoryLocation>,
71    fields: syn::punctuated::Punctuated<Field, syn::Token![,]>,
72}
73
74impl syn::parse::Parse for Struct {
75    fn parse(input: syn::parse::ParseStream) -> syn::parse::Result<Self> {
76        // grab the identity and "C" representation
77        let ident: syn::Ident = input.parse()?;
78        let repr: syn::Type = input.parse()?;
79
80        let mut default_loc = None;
81        if input.peek(syn::token::At) {
82            let _: syn::token::At = input.parse()?;
83            default_loc = Some(input.parse()?);
84        }
85
86        // extract everything wrapped in the braces
87        // also collect a region per parsed field
88        let content;
89        let _: syn::token::Brace = syn::braced!(content in input);
90        let fields = content.parse_terminated(Field::parse)?;
91        let regions = fields
92            .iter()
93            .map(|f| f.region.clone())
94            .collect::<Vec<Region>>();
95
96        // check for intersections
97        for (i, r) in regions.iter().enumerate() {
98            // check all entries we haven't been checked against
99            for k in (i + 1)..regions.len() {
100                let other = &regions[k];
101
102                if r.intersects(other) {
103                    let oth_err = syn::Error::new(other.lit.span(), "other region");
104                    let mut err = syn::Error::new(
105                        r.lit.span(),
106                        format!(
107                            "0b{:b} intersected by other region 0b{:b}",
108                            r.value, other.value
109                        ),
110                    );
111                    err.combine(oth_err);
112                    return Err(err);
113                }
114            }
115        }
116
117        Ok(Struct {
118            ident,
119            repr,
120            default_loc,
121            fields,
122        })
123    }
124}
125
126/// Syntax representation of a region with a name and bit-region.
127struct Field {
128    name: syn::Ident,
129    lower_name: syn::Ident,
130    region: Region,
131}
132impl Field {
133    /// Generate the operations on this field.
134    pub fn gen_ops(&self, struct_name: &syn::Ident, repr: &syn::Type) -> proc_macro2::TokenStream {
135        if self.region.len() == 1 {
136            self.gen_single_bit_ops(struct_name)
137        } else {
138            self.gen_region_ops(struct_name, repr)
139        }
140    }
141
142    /// Generate the "with_{name}" constructor
143    pub fn gen_with_ctor(
144        &self,
145        struct_name: &syn::Ident,
146        repr: &syn::Type,
147    ) -> proc_macro2::TokenStream {
148        let name = &self.name;
149        let lower = &self.lower_name;
150        let with_fn = format_ident!("with_{}", lower);
151        let set_call = format_ident!("set_{}", lower);
152        proc_macro2::TokenStream::from(if self.region.len() == 1 {
153            quote! {
154                pub fn #with_fn() -> #struct_name {
155                    #struct_name::new(#struct_name::#name)
156                }
157            }
158        } else {
159            quote! {
160                pub fn #with_fn<T: Into<#repr>>(val: T) -> #struct_name {
161                    let mut r = #struct_name::new(0 as #repr);
162                    r.#set_call(val);
163                    r
164                }
165            }
166        })
167    }
168
169    /// Generate the Display printer for this field.
170    pub fn gen_display(&self) -> proc_macro2::TokenStream {
171        let name = &self.name;
172        let lower = &self.lower_name;
173        proc_macro2::TokenStream::from(if self.region.len() == 1 {
174            quote! {
175                if self.#lower() {
176                    if is_first { is_first = false; } else { write!(f, " | ")?; }
177                    write!(f, stringify!(#name))?;
178                }
179            }
180        } else {
181            quote! {
182                if is_first { is_first = false; } else { write!(f, " | ")?; }
183                write!(f, "{}={:#X}", stringify!(#name), self.#lower())?;
184            }
185        })
186    }
187
188    /// Generates methods to operate on single-bit regions.
189    /// Setter methods do not take values and includes a toggle method.
190    fn gen_single_bit_ops(&self, struct_name: &syn::Ident) -> proc_macro2::TokenStream {
191        let mask = format_ident!("{}", self.name);
192        let lower = &self.lower_name;
193
194        let set = format_ident!("set_{}", lower);
195        let unset = format_ident!("unset_{}", lower);
196        let toggle = format_ident!("toggle_{}", lower);
197        let extract = format_ident!("extract_{}", lower);
198
199        let getters = quote! {
200            pub fn #lower(&self) -> bool {
201                (self.0 & #struct_name::#mask) != 0
202            }
203            pub fn #extract(&self) -> #struct_name {
204                #struct_name(self.0 & #struct_name::#mask)
205            }
206        };
207
208        let setters = quote! {
209            pub fn #set(&mut self) {
210                self.0 |= #struct_name::#mask
211            }
212            pub fn #unset(&mut self) {
213                self.0 &= !#struct_name::#mask
214            }
215            pub fn #toggle(&mut self) {
216                self.0 ^= #struct_name::#mask
217            }
218        };
219
220        proc_macro2::TokenStream::from(quote! {
221            #getters
222            #setters
223        })
224    }
225
226    /// Generates methods to operate on multi-bit regions.
227    /// Setter methods take values and include debug_assert! calls for both
228    /// bit-region as well as the optional value-range.
229    fn gen_region_ops(
230        &self,
231        struct_name: &syn::Ident,
232        repr: &syn::Type,
233    ) -> proc_macro2::TokenStream {
234        let mask = format_ident!("{}", self.name);
235        let lower = &self.lower_name;
236
237        let set = format_ident!("set_{}", lower);
238        let extract = format_ident!("extract_{}", lower);
239
240        let lower_tuple = format_ident!("{}_tuple", lower);
241        let lower_bools = format_ident!("{}_bools", lower);
242        let region_len = self.region.len();
243
244        let bools_repr = (0..region_len)
245            .map(|_| quote! {bool}.into())
246            .collect::<Vec<proc_macro2::TokenStream>>();
247        let bools_result = (0..region_len)
248            .enumerate()
249            .rev()
250            .map(|(i, _)| quote! {(val >> #i) & 1 == 1}.into())
251            .collect::<Vec<proc_macro2::TokenStream>>();
252
253        let tuple_repr = (0..region_len)
254            .map(|_| quote! {u8}.into())
255            .collect::<Vec<proc_macro2::TokenStream>>();
256        let tuple_result = (0..region_len)
257            .enumerate()
258            .rev()
259            .map(|(i, _)| quote! {((val >> #i) & 1) as u8}.into())
260            .collect::<Vec<proc_macro2::TokenStream>>();
261
262        let shift_offset = self.region.shift_offset();
263        let value_assert = format!(
264            "attempted to set {}::{} with value outside of region: {{:#X}}",
265            struct_name, self.name
266        );
267
268        let range_assert = format!(
269            "attempted to set {}::{} with value outside of range ({{:?}}): {{:#X}}",
270            struct_name, self.name
271        );
272        let range_check = self.region.range.as_ref().map(|ref e| {
273            quote! {
274                debug_assert!((#e).contains(&typed), #range_assert, (#e), typed);
275            }
276        });
277
278        let value_repr = match self.region.len() {
279            0..=8 => {
280                quote! { u8 }
281            }
282            9..=16 => {
283                quote! { u16 }
284            }
285            17..=32 => {
286                quote! { u32 }
287            }
288            33..=64 => {
289                quote! { u64 }
290            }
291            _ => {
292                quote! { usize }
293            }
294        };
295
296        let (upshift, downshift) = if self.region.shift_offset() > 0 {
297            (
298                Some(quote! { << #shift_offset }),
299                Some(quote! { >> #shift_offset }),
300            )
301        } else {
302            (None, None)
303        };
304
305        let getters = quote! {
306            pub fn #lower(&self) -> #value_repr {
307                ((self.0 & #struct_name::#mask) #downshift) as #value_repr
308            }
309            pub fn #lower_tuple(&self) -> (#(#tuple_repr),*) {
310                let val = self.#lower();
311                (#(#tuple_result),*)
312            }
313            pub fn #lower_bools(&self) -> (#(#bools_repr),*) {
314                let val = self.#lower();
315                (#(#bools_result),*)
316            }
317            pub fn #extract(&self) -> #struct_name {
318                #struct_name(self.0 & #struct_name::#mask)
319            }
320        };
321
322        let setters = quote! {
323            pub fn #set<T: Into<#repr>>(&mut self, raw: T) {
324                let typed: #repr = raw.into();
325                let val: #repr = (typed #upshift) as #repr;
326                #range_check
327                debug_assert!(val & #struct_name::#mask == val, #value_assert, val);
328
329                let mut tmp: #repr = self.0 & (!#struct_name::#mask);
330                // TODO: may not be able to write entire word (allow slicing?)
331                self.0 = tmp | val;
332            }
333        };
334
335        (quote! {
336            #getters
337            #setters
338        })
339        .into()
340    }
341}
342
343impl syn::parse::Parse for Field {
344    fn parse(input: syn::parse::ParseStream) -> syn::parse::Result<Self> {
345        let name: syn::Ident = input.parse()?;
346        let lower_str = format!("{}", name).trim().to_lowercase();
347        let lower_name = syn::Ident::new(&lower_str, name.span());
348        let _: syn::Token![:] = input.parse()?;
349        let region: Region = input.parse()?;
350
351        // check for gaps
352        if region.has_gaps() {
353            return Err(syn::Error::new(
354                region.lit.span(),
355                "region cannot contain gap(s)",
356            ));
357        }
358
359        Ok(Field {
360            name,
361            lower_name,
362            region,
363        })
364    }
365}
366
367/// Region contains metadata about a bit-region including the literal
368/// expression, the mask, and a range if defined.
369#[derive(Clone)]
370struct Region {
371    lit: syn::LitInt,
372    value: usize,
373    range: Option<syn::ExprRange>,
374}
375impl Region {
376    /// Minimum number of bits needed to represent the mask literal
377    pub fn min_value_bits(&self) -> usize {
378        (core::mem::size_of::<usize>() * 8) - (self.value.leading_zeros() as usize) - 1
379    }
380
381    /// Number of bits in the region
382    pub fn len(&self) -> usize {
383        self.value.count_ones() as usize
384    }
385
386    /// Offset required to shift "1" to the least significant bit in the region
387    pub fn shift_offset(&self) -> usize {
388        self.value.trailing_zeros() as usize
389    }
390
391    /// Check if the defined mask contains gaps
392    pub fn has_gaps(&self) -> bool {
393        (self.len() + self.shift_offset() - 1) != self.min_value_bits()
394    }
395
396    /// Check if this region intersects with another
397    pub fn intersects(&self, other: &Self) -> bool {
398        let self_min = self.shift_offset();
399        let self_max = self_min + self.len() - 1;
400
401        let oth_min = other.shift_offset();
402        let oth_max = oth_min + other.len() - 1;
403
404        if self_max <= oth_max && self_max >= oth_min {
405            true
406        } else if self_min >= oth_min && self_min <= oth_max {
407            true
408        } else {
409            false
410        }
411    }
412}
413
414impl syn::parse::Parse for Region {
415    fn parse(input: syn::parse::ParseStream) -> syn::parse::Result<Self> {
416        let lit: syn::LitInt = input.parse()?;
417        let value = lit
418            .base10_digits()
419            .parse()
420            .expect("failed to parse literal");
421
422        let mut range = None;
423        if input.peek(syn::Token![|]) {
424            let _: syn::Token![|] = input.parse()?;
425            range = Some(input.parse()?);
426        }
427
428        Ok(Region { lit, value, range })
429    }
430}
431
432struct UserFns {
433    fns: Vec<syn::ItemFn>,
434}
435
436impl syn::parse::Parse for UserFns {
437    fn parse(input: syn::parse::ParseStream) -> syn::parse::Result<Self> {
438        let mut fns = vec![];
439
440        while input.peek(syn::token::Pub) || input.peek(syn::token::Fn) {
441            fns.push(input.parse()?);
442        }
443
444        Ok(UserFns { fns })
445    }
446}
447
448#[proc_macro]
449pub fn bitregions(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
450    let input = syn::parse_macro_input!(item as BitRegions);
451    let vis = &input.vis;
452    let name = &input.struct_def.ident;
453    let repr = &input.struct_def.repr;
454    let user_fns = &input.user_fns.fns;
455
456    // create token streams for the const-defs of the masks
457    let mask_defs = input
458        .struct_def
459        .fields
460        .iter()
461        .map(|f| {
462            let val = &f.region.lit;
463            let mask = &f.name;
464            (quote! { pub const #mask: #repr = #val; }).into()
465        })
466        .collect::<Vec<proc_macro2::TokenStream>>();
467
468    // generate token stream for (optional) default
469    let default = input.struct_def.default_loc.map(|m| {
470        let expr = match m {
471            MemoryLocation::Ident(i) => {
472                quote! { #i }
473            }
474            MemoryLocation::Lit(l) => {
475                quote! { #l }
476            }
477            MemoryLocation::Expr(e) => {
478                quote! { #e }
479            }
480        };
481        quote! {
482            pub unsafe fn default_ptr() -> &'static mut Self {
483                Self::at_addr_mut(#expr)
484            }
485        }
486    });
487
488    // generate token streams for the "with_{field}" constructors
489    let with_ctors = input
490        .struct_def
491        .fields
492        .iter()
493        .map(|f| f.gen_with_ctor(name, repr))
494        .collect::<Vec<proc_macro2::TokenStream>>();
495
496    // generate token streams for the methods
497    let mask_ops = input
498        .struct_def
499        .fields
500        .iter()
501        .map(|f| f.gen_ops(name, repr))
502        .collect::<Vec<proc_macro2::TokenStream>>();
503
504    // generate token streams for the field Display impls
505    let display_ops = input
506        .struct_def
507        .fields
508        .iter()
509        .map(|f| f.gen_display())
510        .collect::<Vec<proc_macro2::TokenStream>>();
511
512    // make display and debug impls
513    let display_debug = quote! {
514        impl core::fmt::Display for #name {
515            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
516                let mut is_first = true;
517                #( #display_ops )*
518                Ok(())
519            }
520        }
521
522        impl core::fmt::Debug for #name {
523            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
524                write!(f, "{:#X}", self.0)
525            }
526        }
527    };
528
529    let result = quote! {
530        #[repr(C)]
531        #[derive(Copy, Clone)]
532        #vis struct #name(#repr);
533
534        #display_debug
535
536        impl #name {
537            #(#mask_defs)*
538
539            pub fn raw(&self) -> #repr {
540                self.0
541            }
542
543            pub fn new<T: Into<#repr>>(bits: T) -> #name {
544                #name(bits.into())
545            }
546
547            pub unsafe fn at_addr<'a>(addr: usize) -> &'a #name {
548                &*(addr as *const u8 as *const #name)
549            }
550
551            pub unsafe fn at_addr_mut<'a>(addr: usize) -> &'a mut #name {
552                &mut *(addr as *mut u8 as *mut #name)
553            }
554
555            pub unsafe fn at_ref<'a, T>(r: &T) -> &'a #name {
556                &*(r as *const T as *const #name)
557            }
558
559            pub unsafe fn at_ref_mut<'a, T>(r: &mut T) -> &'a mut #name {
560                &mut *(r as *mut T as *mut #name)
561            }
562
563            #default
564            #(#with_ctors)*
565
566
567            #(#user_fns)*
568
569            #(#mask_ops)*
570        }
571        impl From<#repr> for #name {
572            fn from(orig: #repr) -> #name {
573                Self(orig)
574            }
575        }
576        impl From<#name> for #repr {
577            fn from(orig: #name) -> #repr {
578                orig.0
579            }
580        }
581
582        impl PartialEq for #name {
583            fn eq(&self, other: &Self) -> bool {
584                self.0 == other.0
585            }
586        }
587        impl Default for #name {
588            fn default() -> Self {
589                #name(#repr::default())
590            }
591        }
592
593        //
594        // add
595        //
596        impl<T: Into<#repr>> core::ops::Add<T> for #name {
597            type Output = Self;
598            fn add(self, other: T) -> Self::Output {
599                #name(self.0 + other.into())
600            }
601        }
602        impl<T: Into<#repr>> core::ops::AddAssign<T> for #name {
603            fn add_assign(&mut self, other: T) {
604                self.0 += other.into();
605            }
606        }
607
608        //
609        // sub
610        //
611        impl<T: Into<#repr>> core::ops::Sub<T> for #name {
612            type Output = Self;
613            fn sub(self, other: T) -> Self::Output {
614                #name(self.0 - other.into())
615            }
616        }
617        impl<T: Into<#repr>> core::ops::SubAssign<T> for #name {
618            fn sub_assign(&mut self, other: T) {
619                self.0 -= other.into();
620            }
621        }
622
623        //
624        // mul
625        //
626        impl<T: Into<#repr>> core::ops::Mul<T> for #name {
627            type Output = Self;
628            fn mul(self, other: T) -> Self::Output {
629                #name(self.0 * other.into())
630            }
631        }
632        impl<T: Into<#repr>> core::ops::MulAssign<T> for #name {
633            fn mul_assign(&mut self, other: T) {
634                self.0 *= other.into();
635            }
636        }
637
638        //
639        // div
640        //
641        impl<T: Into<#repr>> core::ops::Div<T> for #name {
642            type Output = Self;
643            fn div(self, other: T) -> Self::Output {
644                #name(self.0 / other.into())
645            }
646        }
647        impl<T: Into<#repr>> core::ops::DivAssign<T> for #name {
648            fn div_assign(&mut self, other: T) {
649                self.0 /= other.into();
650            }
651        }
652
653        //
654        // bitor
655        //
656        impl<T: Into<#repr>> core::ops::BitOr<T> for #name {
657            type Output = Self;
658            fn bitor(self, other: T) -> Self::Output {
659                #name(self.0 | other.into())
660            }
661        }
662        impl<T: Into<#repr>> core::ops::BitOrAssign<T> for #name {
663            fn bitor_assign(&mut self, other: T) {
664                self.0 |= other.into();
665            }
666        }
667
668        //
669        // bitand
670        //
671        impl<T: Into<#repr>> core::ops::BitAnd<T> for #name {
672            type Output = Self;
673            fn bitand(self, other: T) -> Self::Output {
674                #name(self.0 & other.into())
675            }
676        }
677        impl<T: Into<#repr>> core::ops::BitAndAssign<T> for #name {
678            fn bitand_assign(&mut self, other: T) {
679                self.0 &= other.into();
680            }
681        }
682
683
684        //
685        // bitxor
686        //
687        impl<T: Into<#repr>> core::ops::BitXor<T> for #name {
688            type Output = Self;
689            fn bitxor(self, other: T) -> Self::Output {
690                #name(self.0 ^ other.into())
691            }
692        }
693        impl<T: Into<#repr>> core::ops::BitXorAssign<T> for #name {
694            fn bitxor_assign(&mut self, other: T) {
695                self.0 ^= other.into();
696            }
697        }
698
699        //
700        // shr
701        //
702        impl<T: Into<#repr>> core::ops::Shr<T> for #name {
703            type Output = Self;
704            fn shr(self, other: T) -> Self::Output {
705                #name(self.0 >> other.into())
706            }
707        }
708        impl<T: Into<#repr>> core::ops::ShrAssign<T> for #name {
709            fn shr_assign(&mut self, other: T) {
710                self.0 >>= other.into();
711            }
712        }
713
714        //
715        // shl
716        //
717        impl<T: Into<#repr>> core::ops::Shl<T> for #name {
718            type Output = Self;
719            fn shl(self, other: T) -> Self::Output {
720                #name(self.0 << other.into())
721            }
722        }
723        impl<T: Into<#repr>> core::ops::ShlAssign<T> for #name {
724            fn shl_assign(&mut self, other: T) {
725                self.0 <<= other.into();
726            }
727        }
728    };
729
730    result.into()
731}