Skip to main content

c2rust_bitfields_derive/
lib.rs

1#![recursion_limit = "512"]
2
3use proc_macro::TokenStream;
4use proc_macro2::Span;
5use quote::quote;
6use syn::parse::Error;
7use syn::punctuated::Punctuated;
8use syn::spanned::Spanned;
9use syn::{
10    parse_macro_input, Attribute, Field, Fields, Ident, ItemStruct, LitStr, Path, PathArguments,
11    PathSegment, Token,
12};
13
// Refuse to build on big-endian configurations, which this crate does not support.
// NOTE(review): this is a proc-macro crate (it exports a `#[proc_macro_derive]`), and
// proc-macro crates are compiled for the host. `target_endian` here therefore describes
// the machine that runs the macro, not necessarily the target of the crate that uses the
// derive. Confirm whether a target-side check also exists (e.g. in `c2rust_bitfields`).
#[cfg(target_endian = "big")]
compile_error!("Big endian architectures are not currently supported");
16
/// This struct keeps track of a single bitfield attr's params
/// as well as the bitfield's field name.
///
/// One instance is produced for each non-padding `#[bitfield(...)]` attribute. A single
/// storage field may carry several such attributes, producing several instances that share
/// the same `field_name`.
#[derive(Debug)]
struct BFFieldAttr {
    /// Identifier of the struct field the attribute is attached to (the bit storage).
    field_name: Ident,
    /// Value of `name = "..."`: the getter's name. The setter is named `set_<name>`.
    name: String,
    /// Value of `ty = "..."`: a `::`-separated path naming the accessors' value type.
    ty: String,
    /// Value of `bits = "..."` (expected in the form `"lhs..=rhs"`), paired with the span
    /// of the `bits` key so that range errors can point at the attribute.
    bits: (String, Span),
}
26
27fn parse_bitfield_attr(
28    attr: &Attribute,
29    field_ident: &Ident,
30) -> Result<Option<BFFieldAttr>, Error> {
31    let mut name = None;
32    let mut ty = None;
33    let mut bits = None;
34    let mut bits_span = None;
35    let mut is_padding = false;
36
37    attr.parse_nested_meta(|meta| {
38        if meta.path.is_ident("padding") {
39            // If the attribute is just `#[bitfield(padding)]`, we can skip parsing further.
40            is_padding = true;
41        } else {
42            let value = match meta.value()?.parse::<LitStr>() {
43                Ok(lit_str) => lit_str.value(),
44                Err(_) => {
45                    let err_str = "Found bitfield attribute with non str literal assignment";
46                    return Err(meta.error(err_str));
47                }
48            };
49
50            if meta.path.is_ident("name") {
51                name = Some(value);
52            } else if meta.path.is_ident("ty") {
53                ty = Some(value);
54            } else if meta.path.is_ident("bits") {
55                bits = Some(value);
56                bits_span = Some(meta.path.span());
57            }
58        }
59
60        Ok(())
61    })?;
62
63    if is_padding {
64        return Ok(None);
65    }
66
67    if name.is_none() || ty.is_none() || bits.is_none() {
68        let mut missing_fields = Vec::new();
69
70        if name.is_none() {
71            missing_fields.push("name");
72        }
73
74        if ty.is_none() {
75            missing_fields.push("ty");
76        }
77
78        if bits.is_none() {
79            missing_fields.push("bits");
80        }
81
82        let err_str = format!("Missing bitfield params: {:?}", missing_fields);
83        let span = attr.span();
84
85        return Err(Error::new(span, err_str));
86    }
87
88    Ok(Some(BFFieldAttr {
89        field_name: field_ident.clone(),
90        name: name.unwrap(),
91        ty: ty.unwrap(),
92        bits: (bits.unwrap(), bits_span.unwrap()),
93    }))
94}
95
96fn filter_and_parse_fields(field: &Field) -> Vec<Result<BFFieldAttr, Error>> {
97    let attrs: Vec<_> = field
98        .attrs
99        .iter()
100        .filter(|attr| attr.path().segments.last().unwrap().ident == "bitfield")
101        .collect();
102
103    if attrs.is_empty() {
104        return Vec::new();
105    }
106
107    attrs
108        .into_iter()
109        .map(|attr| parse_bitfield_attr(attr, field.ident.as_ref().unwrap()))
110        .flat_map(Result::transpose) // Remove the Ok(None) values
111        .collect()
112}
113
114fn parse_bitfield_ty_path(field: &BFFieldAttr) -> Path {
115    let mut segments = Punctuated::new();
116    let mut segment_strings = field.ty.split("::").peekable();
117    let colon = Token![::]([Span::call_site(), Span::call_site()]);
118    let leading_colon = segment_strings.next_if_eq(&"").map(|_| colon);
119
120    while let Some(segment_string) = segment_strings.next() {
121        segments.push_value(PathSegment {
122            ident: Ident::new(segment_string, Span::call_site()),
123            arguments: PathArguments::None,
124        });
125
126        if segment_strings.peek().is_some() {
127            segments.push_punct(colon);
128        }
129    }
130
131    Path {
132        leading_colon,
133        segments,
134    }
135}
136
/// Checks that absolute (`::`-prefixed) and relative type strings both split into the
/// expected identifier segments. The leading colon must be kept only when it is present,
/// and the path must have no trailing separator.
#[cfg(test)]
#[test]
fn test_parse_bitfield_ty_path_non_empty_idents() {
    let cases = [("::core::ffi::c_int", true), ("core::ffi::c_int", false)];
    for (ty, expect_leading_colon) in cases {
        let field = BFFieldAttr {
            field_name: Ident::new("field", Span::call_site()),
            name: Default::default(),
            ty: ty.into(),
            bits: (Default::default(), Span::call_site()),
        };
        let path = parse_bitfield_ty_path(&field);

        assert_eq!(
            path.leading_colon.is_some(),
            expect_leading_colon,
            "leading colon for {:?}",
            ty
        );
        let idents: Vec<String> = path
            .segments
            .iter()
            .map(|segment| segment.ident.to_string())
            .collect();
        assert_eq!(idents, ["core", "ffi", "c_int"], "segments for {:?}", ty);
        assert!(
            !path.segments.trailing_punct(),
            "unexpected trailing `::` for {:?}",
            ty
        );
    }
}
151
152#[proc_macro_derive(BitfieldStruct, attributes(bitfield))]
153pub fn bitfield_struct(input: TokenStream) -> TokenStream {
154    let struct_item = parse_macro_input!(input as ItemStruct);
155
156    match bitfield_struct_impl(struct_item) {
157        Ok(ts) => ts,
158        Err(error) => error.to_compile_error().into(),
159    }
160}
161
162fn bitfield_struct_impl(struct_item: ItemStruct) -> Result<TokenStream, Error> {
163    // REVIEW: Should we throw a compile error if bit ranges on a single field overlap?
164    let struct_ident = struct_item.ident;
165    let fields = match struct_item.fields {
166        Fields::Named(named_fields) => named_fields.named,
167        Fields::Unnamed(_) => {
168            let err_str =
169                "Unnamed struct fields are not currently supported but may be in the future.";
170            let span = struct_ident.span();
171
172            return Err(Error::new(span, err_str));
173        }
174        Fields::Unit => {
175            let err_str = "Cannot create bitfield struct out of struct with no fields";
176            let span = struct_ident.span();
177
178            return Err(Error::new(span, err_str));
179        }
180    };
181    let bitfields: Result<Vec<BFFieldAttr>, Error> =
182        fields.iter().flat_map(filter_and_parse_fields).collect();
183    let bitfields = bitfields?;
184    let field_types: Vec<_> = bitfields.iter().map(parse_bitfield_ty_path).collect();
185    let field_types_return = &field_types;
186    let field_types_typedef = &field_types;
187    let field_types_setter_arg = &field_types;
188    let method_names: Vec<_> = bitfields
189        .iter()
190        .map(|field| Ident::new(&field.name, Span::call_site()))
191        .collect();
192    let field_names: Vec<_> = bitfields.iter().map(|field| &field.field_name).collect();
193    let field_names_setters = &field_names;
194    let field_names_getters = &field_names;
195    let method_name_setters: Vec<_> = method_names
196        .iter()
197        .map(|field_ident| {
198            let span = Span::call_site();
199            let setter_name = &format!("set_{}", field_ident);
200
201            Ident::new(setter_name, span)
202        })
203        .collect();
204    let field_bit_info: Result<Vec<_>, Error> = bitfields
205        .iter()
206        .map(|field| {
207            let bit_string = &field.bits.0;
208            let nums: Vec<_> = bit_string.split("..=").collect();
209            let err_str = "bits param must be in the format \"1..=4\"";
210
211            if nums.len() != 2 {
212                return Err(Error::new(field.bits.1, err_str));
213            }
214
215            let lhs = nums[0].parse::<usize>();
216            let rhs = nums[1].parse::<usize>();
217
218            let (lhs, rhs) = match (lhs, rhs) {
219                (Err(_), _) | (_, Err(_)) => return Err(Error::new(field.bits.1, err_str)),
220                (Ok(lhs), Ok(rhs)) => (lhs, rhs),
221            };
222
223            Ok(quote! { (#lhs, #rhs) })
224        })
225        .collect();
226    let field_bit_info = field_bit_info?;
227    let field_bit_info_setters = &field_bit_info;
228    let field_bit_info_getters = &field_bit_info;
229
230    // TODO: Method visibility determined by struct field visibility?
231    let q = quote! {
232        #[automatically_derived]
233        impl #struct_ident {
234            #(
235                /// This method allows you to write to a bitfield with a value
236                pub fn #method_name_setters(&mut self, int: #field_types_setter_arg) {
237                    use c2rust_bitfields::FieldType;
238
239                    let field = &mut self.#field_names_setters;
240                    let (lhs_bit, rhs_bit) = #field_bit_info_setters;
241                    int.set_field(field, (lhs_bit, rhs_bit));
242                }
243
244                /// This method allows you to read from a bitfield to a value
245                pub fn #method_names(&self) -> #field_types_return {
246                    use c2rust_bitfields::FieldType;
247
248                    type IntType = #field_types_typedef;
249
250                    let field = &self.#field_names_getters;
251                    let (lhs_bit, rhs_bit) = #field_bit_info_getters;
252                    <IntType as FieldType>::get_field(field, (lhs_bit, rhs_bit))
253                }
254            )*
255        }
256    };
257
258    Ok(q.into())
259}