bitbash_macros/bitfield/
mod.rs

1use std::collections::HashMap;
2use std::ops::Range;
3
4use proc_macro::TokenStream;
5use proc_macro2::Span;
6use quote::quote;
7use syn::parse::Result;
8use syn::{parse_macro_input, parse_quote};
9use syn::{token, Error, Expr, Fields, Ident, ItemStruct, Meta, Path, RangeLimits, Token, Type};
10
11mod output;
12mod parse;
13
14use parse::{Composition, Index, Relationship};
15
16fn is_uint(p: &Path) -> bool {
17    p.is_ident("usize")
18        || p.is_ident("u8")
19        || p.is_ident("u16")
20        || p.is_ident("u32")
21        || p.is_ident("u64")
22        || p.is_ident("u128")
23}
24
25fn value_repr_ty(value_ty: &Type) -> Type {
26    match value_ty {
27        Type::Path(t) if is_uint(&t.path) => value_ty.clone(),
28        Type::Path(t) if t.path.is_ident("bool") => parse_quote! { u8 },
29        _ => parse_quote! {
30            <#value_ty as bitbash::ConvertRepr>::Repr
31        },
32    }
33}
34
35fn self_repr_ty(strukt: &ItemStruct, src: &Option<Expr>) -> Type {
36    match &strukt.fields {
37        Fields::Unnamed(fields) => match &fields.unnamed[0].ty {
38            Type::Array(t) => (*t.elem).clone(),
39            t => t.clone(),
40        },
41        Fields::Named(fields) => match &src {
42            Some(Expr::Path(p)) => match p.path.get_ident() {
43                Some(p) => match fields.named.iter().find(|f| f.ident.as_ref().unwrap() == p) {
44                    Some(f) => f.ty.clone(),
45                    None => unreachable!(),
46                },
47                None => unreachable!(),
48            },
49            _ => unreachable!(),
50        },
51        Fields::Unit => unreachable!(),
52    }
53}
54
55fn validate_bits(strukt: &ItemStruct, bits: &parse::Bits) -> Result<()> {
56    match &strukt.fields {
57        Fields::Unnamed(fields) => match fields.unnamed[0].ty {
58            Type::Array(_) if bits.src.is_none() => {
59                return Err(Error::new(
60                    bits.bracket_token.span,
61                    "missing an array index",
62                ))
63            }
64            Type::Array(_) => Ok(()),
65            _ if bits.src.is_some() => {
66                return Err(Error::new_spanned(&bits.src, "unexpected array index"))
67            }
68            _ => Ok(()),
69        },
70        Fields::Named(fields) => match &bits.src {
71            None => return Err(Error::new(bits.bracket_token.span, "missing a field")),
72            Some(Expr::Path(p))
73                if fields
74                    .named
75                    .iter()
76                    .filter_map(|field| field.ident.as_ref())
77                    .find(|field_name| p.path.is_ident(&*field_name.to_string()))
78                    .is_some() =>
79            {
80                Ok(())
81            }
82            _ => return Err(Error::new_spanned(&bits.src, "invalid field")),
83        },
84        Fields::Unit => unreachable!(),
85    }
86}
87
88fn into_output_spec(bitfield: parse::Bitfield, use_const: bool) -> Result<output::Bitfield> {
89    let strukt = bitfield.strukt;
90    match strukt.fields {
91        Fields::Unnamed(fields) if fields.unnamed.len() != 1 => {
92            return Err(Error::new_spanned(
93                fields,
94                "a tuple struct may have only one field",
95            ))
96        }
97        Fields::Unit => return Err(Error::new_spanned(strukt, "unit structs are not supported")),
98        _ => (),
99    }
100    let mut fields = Vec::new();
101    for field in bitfield.fields {
102        fields.push(into_output_field(&strukt, field)?);
103    }
104
105    let new = match bitfield.new {
106        None => None,
107        Some(new) => {
108            let mut attrs = Vec::new();
109            for attr in new.attrs {
110                attrs.push(match attr.parse_meta()? {
111                    Meta::Path(p) if p.is_ident("disable_check") => {
112                        output::NewAttribute::DisableCheck
113                    }
114                    _ => return Err(Error::new_spanned(attr, "invalid attribute")),
115                });
116            }
117
118            let field_tys: HashMap<Ident, Type> = fields
119                .iter()
120                .map(|f| (f.name.clone(), f.value_ty.clone()))
121                .collect();
122            let mut init_field_tys = Vec::new();
123            for name in &new.init_fields {
124                match field_tys.get(name) {
125                    Some(ty) => init_field_tys.push(ty.clone()),
126                    None => return Err(Error::new_spanned(name, "field does not exist")),
127                }
128            }
129            let init_field_names = new.init_fields.into_iter().collect();
130            Some(output::New {
131                attrs,
132                vis: new.vis,
133                init_field_names,
134                init_field_tys,
135            })
136        }
137    };
138
139    let derive_debug = match bitfield.derive_debug {
140        None => false,
141        Some(dd) => {
142            for attr in dd.attrs {
143                return Err(Error::new_spanned(attr, "invalid attribute"));
144            }
145            true
146        }
147    };
148
149    Ok(output::Bitfield {
150        use_const,
151        strukt,
152        new,
153        derive_debug,
154        fields,
155    })
156}
157
158fn into_output_field(bitfield: &ItemStruct, field: parse::Field) -> Result<output::Field> {
159    let parse::Field {
160        attrs: in_attrs,
161        vis,
162        name,
163        value_ty,
164        mut composition,
165        ..
166    } = field;
167
168    let mut out_attrs = Vec::new();
169    for attr in in_attrs {
170        out_attrs.push(match attr.parse_meta()? {
171            Meta::Path(p) if p.is_ident("ro") => output::FieldAttribute::ReadOnly,
172            Meta::Path(p) if p.is_ident("private_write") => output::FieldAttribute::PrivateWrite,
173            _ => return Err(Error::new_spanned(attr, "invalid attribute")),
174        });
175    }
176
177    fn fill_in_range(index: &mut Index, ty: &Type) {
178        if let Index::Range {
179            start: ref mut x @ None,
180            ..
181        } = index
182        {
183            *x = Some(parse_quote! { 0 });
184        }
185        if let Index::Range {
186            end: ref mut x @ None,
187            ..
188        } = index
189        {
190            *x = Some(parse_quote! { ((core::mem::size_of::<#ty>() * 8) as u32) });
191        }
192    }
193
194    match &mut composition {
195        Composition::Mapping { relationships, .. } => {
196            for relationship in relationships {
197                if let Some(src) = &relationship.from.src {
198                    return Err(Error::new_spanned(
199                        src,
200                        "the value index must refer only to bits",
201                    ));
202                }
203                let value_repr_ty = value_repr_ty(&value_ty);
204                fill_in_range(&mut relationship.from.index, &value_repr_ty);
205                validate_bits(bitfield, &relationship.to)?;
206                let self_repr_ty = self_repr_ty(bitfield, &relationship.to.src);
207                fill_in_range(&mut relationship.to.index, &self_repr_ty);
208            }
209        }
210        Composition::Concatenation { bits, .. } => {
211            for bits in bits {
212                validate_bits(bitfield, bits)?;
213                let self_repr_ty = self_repr_ty(bitfield, &bits.src);
214                fill_in_range(&mut bits.index, &self_repr_ty);
215            }
216        }
217    }
218
219    let relationships = match composition {
220        Composition::Mapping { relationships, .. } => relationships.into_iter().collect(),
221        Composition::Concatenation { bits, .. } => {
222            let mut relationships = Vec::new();
223            let mut prev_end = parse_quote! { 0 };
224            for bits in bits {
225                let next_end: Expr = match &bits.index {
226                    Index::One(_) => parse_quote! { (#prev_end + 1) },
227                    Index::Range {
228                        start: Some(start),
229                        limits: RangeLimits::HalfOpen(_),
230                        end: Some(end),
231                    } => parse_quote! { (#prev_end + (#end - #start)) },
232                    Index::Range {
233                        start: Some(start),
234                        limits: RangeLimits::Closed(_),
235                        end: Some(end),
236                    } => parse_quote! { (#prev_end + (1 + #end - #start)) },
237                    _ => unreachable!(),
238                };
239                let from_index = Index::Range {
240                    start: Some(prev_end),
241                    limits: RangeLimits::HalfOpen(Token![..](Span::call_site())),
242                    end: Some(next_end.clone()),
243                };
244                relationships.push(Relationship {
245                    from: parse::Bits {
246                        src: None,
247                        bracket_token: token::Bracket {
248                            span: Span::call_site(),
249                        },
250                        index: from_index,
251                    },
252                    arrow_token: Token![=>](Span::call_site()),
253                    to: bits,
254                });
255                prev_end = next_end;
256            }
257            relationships
258        }
259    };
260
261    let mut mapping = Vec::new();
262    for mut relationship in relationships {
263        fn fix_range(index: &mut Index) {
264            if let Index::Range {
265                limits: ref mut l @ RangeLimits::Closed(_),
266                end: Some(ref mut end),
267                ..
268            } = index
269            {
270                *l = RangeLimits::HalfOpen(Token![..](Span::call_site()));
271                *end = parse_quote! { (#end + 1) };
272            }
273        }
274        fix_range(&mut relationship.from.index);
275        fix_range(&mut relationship.to.index);
276
277        fn get_range(index: Index) -> Range<Expr> {
278            match index {
279                Index::One(start) => {
280                    let end = parse_quote! { (#start + 1) };
281                    start..end
282                }
283                Index::Range {
284                    start: Some(start),
285                    limits: RangeLimits::HalfOpen(_),
286                    end: Some(end),
287                } => start..end,
288                _ => unreachable!(),
289            }
290        }
291
292        let from = get_range(relationship.from.index);
293        let to_src = relationship.to.src;
294        let to = get_range(relationship.to.index);
295        mapping.push(output::Relationship { from, to_src, to });
296    }
297
298    Ok(output::Field {
299        attrs: out_attrs,
300        vis,
301        name,
302        value_ty,
303        rels: mapping,
304    })
305}
306
307pub fn bitfield(input: TokenStream, use_const: bool) -> TokenStream {
308    let bitfield = parse_macro_input!(input as parse::Bitfield);
309    let bitfield = match into_output_spec(bitfield, use_const) {
310        Ok(data) => data,
311        Err(err) => return TokenStream::from(err.to_compile_error()),
312    };
313    let expanded = quote! {
314        #bitfield
315    };
316    TokenStream::from(expanded)
317}