element_ptr_macro/
lib.rs

1extern crate proc_macro;
2
3use proc_macro2::{Ident, Span, TokenStream};
4
5use proc_macro_crate::FoundCrate;
6use quote::{quote, ToTokens};
7use syn::{
8    bracketed, parenthesized,
9    parse::{Parse, ParseStream},
10    parse_macro_input, token, Expr, Index, LitInt, Token, Type,
11};
12
13mod quote_into_hack;
14use quote_into_hack::quote_into;
15
16#[proc_macro]
17pub fn element_ptr(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
18    let input = parse_macro_input!(input as MacroInput);
19
20    let base_crate = {
21        let found =
22            proc_macro_crate::crate_name("element-ptr").unwrap_or_else(|_| FoundCrate::Itself);
23
24        match found {
25            FoundCrate::Itself => String::from("element_ptr"),
26            FoundCrate::Name(name) => name,
27        }
28    };
29
30    let base_crate = Ident::new(&base_crate, Span::call_site());
31
32    let ctx = AccessListToTokensCtx {
33        list: &input.body,
34        base_crate: &base_crate,
35    };
36
37    let ptr = input.ptr;
38
39    (quote! {
40        {
41            let ptr = #ptr;
42            :: #base_crate ::helper::element_ptr_unsafe();
43            #[allow(unused_unsafe)]
44            unsafe {
45                let ptr = :: #base_crate ::helper::new_pointer(ptr);
46                #ctx
47            }
48        }
49    })
50    .into()
51}
52
/// A sequence of element accesses; code generation emits them in order,
/// left to right, against the current pointer.
struct AccessList(Vec<ElementAccess>);
54
/// Code-generation context for an [`AccessList`]: the accesses to emit plus
/// the identifier of the runtime crate whose `helper` module the generated
/// code calls.
struct AccessListToTokensCtx<'i> {
    /// Accesses to translate into pointer operations.
    list: &'i AccessList,
    /// Identifier of the `element-ptr` runtime crate as resolved for the call site.
    base_crate: &'i Ident,
}
59
impl<'i> ToTokens for AccessListToTokensCtx<'i> {
    /// Emits one `let ptr = ...;` statement per access, followed by a tail
    /// expression that evaluates to the resulting pointer.
    ///
    /// In the generated code, `ptr` is in one of two states:
    /// - wrapped: the value produced by `helper::new_pointer`, on which the
    ///   generated code calls `copy_addr`, `into_const`, `read`, `cast`,
    ///   `add`/`sub`/`byte_add`/`byte_sub` and `into_inner`;
    /// - unwrapped: the value produced by `.read()` (deref) or by a nested
    ///   group's tail expression.
    ///
    /// `dirty` records the unwrapped state. When it is set, the next access
    /// first re-wraps `ptr` with `helper::new_pointer`, and the final tail
    /// expression is `ptr` instead of `ptr.into_inner()`.
    /// NOTE(review): this assumes `.read()` yields the same kind of value as
    /// `into_inner()`. TODO confirm against `helper`.
    ///
    /// NOTE(review): `tokens` is bound `mut`, presumably because `quote_into!`
    /// (defined in `quote_into_hack`, not visible here) needs it. Verify before
    /// removing.
    fn to_tokens(&self, mut tokens: &mut TokenStream) {
        let base_crate = self.base_crate;

        // True when `ptr` in the generated code is currently unwrapped.
        let mut dirty = false;

        for access in &self.list.0 {
            use ElementAccess::*;

            // Re-wrap an unwrapped pointer before applying the next access.
            if dirty {
                quote_into! { tokens =>
                    let ptr = :: #base_crate ::helper::new_pointer(ptr);
                };
                dirty = false;
            }

            match access {
                Field(FieldAccess { _dot, field }) => match &field {
                    // `.name`: take the address of the named field of the pointee.
                    Some(FieldAccessType::Named(ident)) => quote_into! { tokens =>
                        let ptr = ptr.copy_addr(
                            ::core::ptr::addr_of!( ( *ptr.into_const() ) . #ident )
                        );
                    },
                    // `.0`: take the address of the tuple field of the pointee.
                    Some(FieldAccessType::Tuple(index)) => quote_into! { tokens =>
                        let ptr = ptr.copy_addr(
                            ::core::ptr::addr_of!( ( *ptr.into_const() ) . #index )
                        );
                    },
                    // `.*`: `ptr.read()` leaves `ptr` unwrapped, so it must be
                    // re-wrapped before the next access (or returned as-is).
                    Some(FieldAccessType::Deref(..)) => {
                        dirty = true;
                        quote_into! { tokens =>
                            let ptr = ptr.read();
                        }
                    }
                    // A bare trailing `.`: output something for rust-analyzer
                    // autocomplete.
                    None => {
                        // It is not fully clear why this exact shape lets
                        // rust-analyzer autocomplete after the dot, but it does.
                        // It also gives a correct (and somewhat fake) compiler
                        // error of "unexpected token `)`", alongside the explicit
                        // error below. There is no better way to interact with
                        // rust-analyzer here, so this hack will have to do.
                        let error = syn::Error::new_spanned(
                            _dot,
                            "expected an identifier, integer literal, or `*` after this `.`",
                        )
                        .into_compile_error();
                        quote_into! { tokens =>
                            let ptr = ptr.copy_addr(
                                ::core::ptr::addr_of!( ( *ptr.into_const() ) #_dot )
                            );
                            #error;
                        }
                        // Stop generating here: no further accesses and no tail
                        // expression are emitted.
                        return;
                    }
                },
                // `[expr]`: delegated entirely to the runtime helper.
                Index(IndexAccess { index, .. }) => quote_into! { tokens =>
                    let ptr = :: #base_crate ::helper::index(ptr, #index);
                },
                // `+ n` / `- n` map to `add`/`sub`; with a `u8` prefix they map
                // to `byte_add`/`byte_sub`.
                Offset(access) => {
                    let name = match (&access.offset_type, access.byte.is_some()) {
                        (OffsetType::Add(..), false) => Ident::new("add", Span::call_site()),
                        (OffsetType::Sub(..), false) => Ident::new("sub", Span::call_site()),
                        (OffsetType::Add(..), true) => Ident::new("byte_add", Span::call_site()),
                        (OffsetType::Sub(..), true) => Ident::new("byte_sub", Span::call_site()),
                    };
                    let offset = &access.value;
                    quote_into! { tokens =>
                        let ptr = ptr . #name ( #offset );
                    }
                }
                // `as Type`: change the pointee type.
                Cast(CastAccess { ty, .. }) => quote_into! { tokens =>
                    let ptr = ptr.cast::<#ty>();
                },
                // `( ... )`: generate the nested list inside its own block. Its
                // tail expression (`ptr` or `ptr.into_inner()`) is unwrapped, so
                // the outer `ptr` becomes dirty.
                Group(access) => {
                    let list = AccessListToTokensCtx {
                        list: &access.inner,
                        base_crate: self.base_crate,
                    };
                    quote_into! { tokens =>
                        let ptr = {
                            #list
                        };
                    };
                    dirty = true;
                }
            };
        }
        // Tail expression: hand back an unwrapped pointer in either state.
        if dirty {
            quote_into! { tokens =>
                ptr
            };
        } else {
            quote_into! { tokens =>
                ptr.into_inner()
            };
        }
    }
}
160
161impl Parse for AccessList {
162    fn parse(input: ParseStream) -> syn::Result<Self> {
163        let mut out = Vec::new();
164        while !input.is_empty() {
165            let access: ElementAccess = input.parse()?;
166            if access.is_final() && !input.is_empty() {
167                return Err(input.error(""));
168            }
169            out.push(access);
170        }
171        Ok(Self(out))
172    }
173}
174
/// Full macro input: `<pointer expression> => <access list>`.
struct MacroInput {
    /// Expression producing the base pointer.
    ptr: Expr,
    _arrow: Token![=>],
    /// Accesses applied to `ptr`, in order.
    body: AccessList,
}
180
181impl Parse for MacroInput {
182    fn parse(input: ParseStream) -> syn::Result<Self> {
183        Ok(Self {
184            ptr: input.parse()?,
185            _arrow: input.parse()?,
186            body: input.parse()?,
187        })
188    }
189}
190
/// One step of an access chain.
enum ElementAccess {
    /// `.field`, `.0`, `.*`, or a bare trailing `.`.
    Field(FieldAccess),
    /// `[expr]`.
    Index(IndexAccess),
    /// `+ n`, `- n`, `u8 + n`, or `u8 - n`, where `n` is an integer literal
    /// or a parenthesized expression.
    Offset(OffsetAccess),
    /// `as Type`, optionally followed by `=>`.
    Cast(CastAccess),
    /// `( ... )`: a nested access list.
    Group(GroupAccess),
}
198
199impl ElementAccess {
200    fn is_final(&self) -> bool {
201        match self {
202            Self::Cast(acc) => acc.arrow.is_none(),
203            _ => false,
204        }
205    }
206}
207
208impl Parse for ElementAccess {
209    fn parse(input: ParseStream) -> syn::Result<Self> {
210        if input.peek(Token![.]) {
211            input.parse().map(Self::Field)
212        } else if input.peek(token::Bracket) {
213            input.parse().map(Self::Index)
214        } else if input.peek(kw::u8) || input.peek(Token![+]) || input.peek(Token![-]) {
215            input.parse().map(Self::Offset)
216        } else if input.peek(Token![as]) {
217            input.parse().map(Self::Cast)
218        } else if input.peek(token::Paren) {
219            input.parse().map(Self::Group)
220        } else {
221            Err(input.error("expected valid element access"))
222        }
223    }
224}
225
/// A `.`-prefixed access: `.field`, `.0`, or `.*`.
///
/// Dereference (`.*`) is modelled here too because it is syntactically similar.
struct FieldAccess {
    _dot: Token![.],
    /// `None` when nothing follows the `.` in the current token stream. Code
    /// generation then emits a deliberate error intended to help
    /// rust-analyzer autocomplete.
    field: Option<FieldAccessType>,
}
231
232impl Parse for FieldAccess {
233    fn parse(input: ParseStream) -> syn::Result<Self> {
234        Ok(Self {
235            _dot: input.parse()?,
236            field: {
237                if input.is_empty() {
238                    None
239                } else {
240                    Some(input.parse()?)
241                }
242            },
243        })
244    }
245}
246
/// What follows the `.` in a [`FieldAccess`].
enum FieldAccessType {
    /// Named field, e.g. `.len`.
    Named(Ident),
    /// Tuple field, e.g. `.0`.
    Tuple(Index),
    /// Dereference, `.*`; code generation emits `ptr.read()`.
    Deref(Token![*]),
}
252
253impl Parse for FieldAccessType {
254    fn parse(input: ParseStream) -> syn::Result<Self> {
255        let l = input.lookahead1();
256        if l.peek(Token![*]) {
257            input.parse().map(Self::Deref)
258        } else if l.peek(syn::Ident) {
259            input.parse().map(Self::Named)
260        } else if l.peek(LitInt) {
261            // no amazing way to do this unfortunately.
262            input.parse().map(Self::Tuple)
263        } else {
264            Err(l.error())
265        }
266    }
267}
268
/// `[expr]`: an index access; code generation passes it to `helper::index`.
struct IndexAccess {
    _bracket: token::Bracket,
    /// The expression inside the brackets.
    index: Expr,
}
273
274impl Parse for IndexAccess {
275    fn parse(input: ParseStream) -> syn::Result<Self> {
276        let content;
277        Ok(Self {
278            _bracket: bracketed!(content in input),
279            index: content.parse()?,
280        })
281    }
282}
283
284// struct DerefAccess {
285//     dot: Token![.],
286//     star: Token![*],
287// }
288
289// impl Parse for DerefAccess {
290//     fn parse(input: ParseStream) -> syn::Result<Self> {
291//         Ok(Self {
292//             dot: input.parse()?,
293//             star: input.parse()?,
294//         })
295//     }
296// }
297
/// Pointer offset: `+ n` / `- n`, optionally prefixed with `u8`.
///
/// Code generation calls `add`/`sub` without the prefix and
/// `byte_add`/`byte_sub` with it.
struct OffsetAccess {
    /// Present when the access was written with the `u8` prefix.
    byte: Option<kw::u8>,
    /// Direction (`+` or `-`).
    offset_type: OffsetType,
    /// Amount to offset by.
    value: OffsetValue,
}
303
304impl Parse for OffsetAccess {
305    fn parse(input: ParseStream) -> syn::Result<Self> {
306        Ok(Self {
307            byte: input.parse()?,
308            offset_type: input.parse()?,
309            value: input.parse()?,
310        })
311    }
312}
313
/// Direction of an [`OffsetAccess`].
enum OffsetType {
    /// `+`: generates `add` / `byte_add`.
    Add(Token![+]),
    /// `-`: generates `sub` / `byte_sub`.
    Sub(Token![-]),
}
318
319impl Parse for OffsetType {
320    fn parse(input: ParseStream) -> syn::Result<Self> {
321        let l = input.lookahead1();
322        if l.peek(Token![+]) {
323            input.parse().map(Self::Add)
324        } else if l.peek(Token![-]) {
325            input.parse().map(Self::Sub)
326        } else {
327            Err(l.error())
328        }
329    }
330}
331
/// Amount of an [`OffsetAccess`]: an integer literal, or any expression in
/// parentheses.
enum OffsetValue {
    /// Bare integer literal, e.g. `+ 3`.
    Integer { int: LitInt },
    /// Parenthesized expression, e.g. `+ (n * 2)`. Only `expr` is re-emitted
    /// by `ToTokens`, not the parentheses.
    Grouped { _paren: token::Paren, expr: Expr },
}
336
337impl Parse for OffsetValue {
338    fn parse(input: ParseStream) -> syn::Result<Self> {
339        let l = input.lookahead1();
340        if l.peek(token::Paren) {
341            let content;
342            Ok(Self::Grouped {
343                _paren: parenthesized!(content in input),
344                expr: content.parse()?,
345            })
346        } else if l.peek(LitInt) {
347            Ok(Self::Integer {
348                int: input.parse()?,
349            })
350        } else {
351            Err(l.error())
352        }
353    }
354}
355
356impl ToTokens for OffsetValue {
357    fn to_tokens(&self, tokens: &mut TokenStream) {
358        match self {
359            Self::Integer { int } => int.to_tokens(tokens),
360            Self::Grouped { expr, .. } => expr.to_tokens(tokens),
361        }
362    }
363}
364
/// `as Type`, optionally followed by `=>`.
///
/// Without the `=>`, the cast must be the last access in its list.
struct CastAccess {
    _as_token: Token![as],
    /// Target type; code generation emits `ptr.cast::<Type>()`.
    ty: Type,
    /// `=>` that allows further accesses after the cast.
    // TODO: is this best syntax for this?
    arrow: Option<Token![=>]>,
}
371
372impl Parse for CastAccess {
373    fn parse(input: ParseStream) -> syn::Result<Self> {
374        Ok(Self {
375            _as_token: input.parse()?,
376            ty: input.parse()?,
377            arrow: input.parse()?,
378        })
379    }
380}
381
/// `( ... )`: a parenthesized, nested access list.
struct GroupAccess {
    _paren: token::Paren,
    /// Accesses inside the parentheses.
    inner: AccessList,
}
386
387impl Parse for GroupAccess {
388    fn parse(input: ParseStream) -> syn::Result<Self> {
389        let content;
390        Ok(Self {
391            _paren: parenthesized!(content in input),
392            inner: content.parse()?,
393        })
394    }
395}
396
/// Custom keywords recognized by the parser.
mod kw {
    // `u8` before `+`/`-` selects a byte offset (`byte_add`/`byte_sub`).
    syn::custom_keyword!(u8);
}