c2rust_bitfields_derive/
lib.rs

1#![recursion_limit = "512"]
2
3use proc_macro::{Span, TokenStream};
4use quote::quote;
5use syn::parse::Error;
6use syn::punctuated::Punctuated;
7use syn::spanned::Spanned;
8use syn::{
9    parse_macro_input, Attribute, Field, Fields, Ident, ItemStruct, Lit, Meta, NestedMeta, Path,
10    PathArguments, PathSegment, Token,
11};
12
// Abort compilation up front when targeting a big-endian architecture,
// since this crate does not currently support it.
#[cfg(target_endian = "big")]
compile_error!("Big endian architectures are not currently supported");
15
16/// This struct keeps track of a single bitfield attr's params
17/// as well as the bitfield's field name.
18#[derive(Debug)]
19struct BFFieldAttr {
20    field_name: Ident,
21    name: String,
22    ty: String,
23    bits: (String, proc_macro2::Span),
24}
25
26fn parse_bitfield_attr(
27    attr: &Attribute,
28    field_ident: &Ident,
29) -> Result<Option<BFFieldAttr>, Error> {
30    let mut name = None;
31    let mut ty = None;
32    let mut bits = None;
33    let mut bits_span = None;
34
35    if let Meta::List(meta_list) = attr.parse_meta()? {
36        for nested_meta in meta_list.nested {
37            if let NestedMeta::Meta(Meta::NameValue(meta_name_value)) = nested_meta {
38                let rhs_string = match meta_name_value.lit {
39                    Lit::Str(lit_str) => lit_str.value(),
40                    _ => {
41                        let err_str = "Found bitfield attribute with non str literal assignment";
42                        let span = meta_name_value.path.span();
43
44                        return Err(Error::new(span, err_str));
45                    }
46                };
47
48                if let Some(lhs_ident) = meta_name_value.path.get_ident() {
49                    match lhs_ident.to_string().as_str() {
50                        "name" => name = Some(rhs_string),
51                        "ty" => ty = Some(rhs_string),
52                        "bits" => {
53                            bits = Some(rhs_string);
54                            bits_span = Some(meta_name_value.path.span());
55                        }
56                        // This one shouldn't ever occur here,
57                        // but we're handling it just to be safe
58                        "padding" => {
59                            return Ok(None);
60                        }
61                        _ => {}
62                    }
63                }
64            } else if let NestedMeta::Meta(Meta::Path(ref path)) = nested_meta {
65                if let Some(ident) = path.get_ident() {
66                    if ident == "padding" {
67                        return Ok(None);
68                    }
69                }
70            }
71        }
72    }
73
74    if name.is_none() || ty.is_none() || bits.is_none() {
75        let mut missing_fields = Vec::new();
76
77        if name.is_none() {
78            missing_fields.push("name");
79        }
80
81        if ty.is_none() {
82            missing_fields.push("ty");
83        }
84
85        if bits.is_none() {
86            missing_fields.push("bits");
87        }
88
89        let err_str = format!("Missing bitfield params: {:?}", missing_fields);
90        let span = attr.path.segments.span();
91
92        return Err(Error::new(span, err_str));
93    }
94
95    Ok(Some(BFFieldAttr {
96        field_name: field_ident.clone(),
97        name: name.unwrap(),
98        ty: ty.unwrap(),
99        bits: (bits.unwrap(), bits_span.unwrap()),
100    }))
101}
102
103fn filter_and_parse_fields(field: &Field) -> Vec<Result<BFFieldAttr, Error>> {
104    let attrs: Vec<_> = field
105        .attrs
106        .iter()
107        .filter(|attr| attr.path.segments.last().unwrap().ident == "bitfield")
108        .collect();
109
110    if attrs.is_empty() {
111        return Vec::new();
112    }
113
114    attrs
115        .into_iter()
116        .map(|attr| parse_bitfield_attr(attr, field.ident.as_ref().unwrap()))
117        .flat_map(Result::transpose) // Remove the Ok(None) values
118        .collect()
119}
120
121fn parse_bitfield_ty_path(field: &BFFieldAttr) -> Path {
122    let leading_colon = if field.ty.starts_with("::") {
123        Some(Token![::]([
124            Span::call_site().into(),
125            Span::call_site().into(),
126        ]))
127    } else {
128        None
129    };
130
131    let mut segments = Punctuated::new();
132    let mut segment_strings = field.ty.split("::").peekable();
133
134    while let Some(segment_string) = segment_strings.next() {
135        segments.push_value(PathSegment {
136            ident: Ident::new(segment_string, Span::call_site().into()),
137            arguments: PathArguments::None,
138        });
139
140        if segment_strings.peek().is_some() {
141            segments.push_punct(Token![::]([
142                Span::call_site().into(),
143                Span::call_site().into(),
144            ]));
145        }
146    }
147
148    Path {
149        leading_colon,
150        segments,
151    }
152}
153
154#[proc_macro_derive(BitfieldStruct, attributes(bitfield))]
155pub fn bitfield_struct(input: TokenStream) -> TokenStream {
156    let struct_item = parse_macro_input!(input as ItemStruct);
157
158    match bitfield_struct_impl(struct_item) {
159        Ok(ts) => ts,
160        Err(error) => error.to_compile_error().into(),
161    }
162}
163
164fn bitfield_struct_impl(struct_item: ItemStruct) -> Result<TokenStream, Error> {
165    // REVIEW: Should we throw a compile error if bit ranges on a single field overlap?
166    let struct_ident = struct_item.ident;
167    let fields = match struct_item.fields {
168        Fields::Named(named_fields) => named_fields.named,
169        Fields::Unnamed(_) => {
170            let err_str =
171                "Unnamed struct fields are not currently supported but may be in the future.";
172            let span = struct_ident.span();
173
174            return Err(Error::new(span, err_str));
175        }
176        Fields::Unit => {
177            let err_str = "Cannot create bitfield struct out of struct with no fields";
178            let span = struct_ident.span();
179
180            return Err(Error::new(span, err_str));
181        }
182    };
183    let bitfields: Result<Vec<BFFieldAttr>, Error> =
184        fields.iter().flat_map(filter_and_parse_fields).collect();
185    let bitfields = bitfields?;
186    let field_types: Vec<_> = bitfields.iter().map(parse_bitfield_ty_path).collect();
187    let field_types_return = &field_types;
188    let field_types_typedef = &field_types;
189    let field_types_setter_arg = &field_types;
190    let method_names: Vec<_> = bitfields
191        .iter()
192        .map(|field| Ident::new(&field.name, Span::call_site().into()))
193        .collect();
194    let field_names: Vec<_> = bitfields.iter().map(|field| &field.field_name).collect();
195    let field_names_setters = &field_names;
196    let field_names_getters = &field_names;
197    let method_name_setters: Vec<_> = method_names
198        .iter()
199        .map(|field_ident| {
200            let span = Span::call_site().into();
201            let setter_name = &format!("set_{}", field_ident);
202
203            Ident::new(setter_name, span)
204        })
205        .collect();
206    let field_bit_info: Result<Vec<_>, Error> = bitfields
207        .iter()
208        .map(|field| {
209            let bit_string = &field.bits.0;
210            let nums: Vec<_> = bit_string.split("..=").collect();
211            let err_str = "bits param must be in the format \"1..=4\"";
212
213            if nums.len() != 2 {
214                return Err(Error::new(field.bits.1, err_str));
215            }
216
217            let lhs = nums[0].parse::<usize>();
218            let rhs = nums[1].parse::<usize>();
219
220            let (lhs, rhs) = match (lhs, rhs) {
221                (Err(_), _) | (_, Err(_)) => return Err(Error::new(field.bits.1, err_str)),
222                (Ok(lhs), Ok(rhs)) => (lhs, rhs),
223            };
224
225            Ok(quote! { (#lhs, #rhs) })
226        })
227        .collect();
228    let field_bit_info = field_bit_info?;
229    let field_bit_info_setters = &field_bit_info;
230    let field_bit_info_getters = &field_bit_info;
231
232    // TODO: Method visibility determined by struct field visibility?
233    let q = quote! {
234        #[automatically_derived]
235        impl #struct_ident {
236            #(
237                /// This method allows you to write to a bitfield with a value
238                pub fn #method_name_setters(&mut self, int: #field_types_setter_arg) {
239                    use c2rust_bitfields::FieldType;
240
241                    let field = &mut self.#field_names_setters;
242                    let (lhs_bit, rhs_bit) = #field_bit_info_setters;
243                    int.set_field(field, (lhs_bit, rhs_bit));
244                }
245
246                /// This method allows you to read from a bitfield to a value
247                pub fn #method_names(&self) -> #field_types_return {
248                    use c2rust_bitfields::FieldType;
249
250                    type IntType = #field_types_typedef;
251
252                    let field = &self.#field_names_getters;
253                    let (lhs_bit, rhs_bit) = #field_bit_info_getters;
254                    <IntType as FieldType>::get_field(field, (lhs_bit, rhs_bit))
255                }
256            )*
257        }
258    };
259
260    Ok(q.into())
261}