1extern crate proc_macro;
2
3use proc_macro2::{Ident, Span, TokenStream};
4
5use proc_macro_crate::FoundCrate;
6use quote::{quote, ToTokens};
7use syn::{
8 bracketed, parenthesized,
9 parse::{Parse, ParseStream},
10 parse_macro_input, token, Expr, Index, LitInt, Token, Type,
11};
12
13mod quote_into_hack;
14use quote_into_hack::quote_into;
15
16#[proc_macro]
17pub fn element_ptr(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
18 let input = parse_macro_input!(input as MacroInput);
19
20 let base_crate = {
21 let found =
22 proc_macro_crate::crate_name("element-ptr").unwrap_or_else(|_| FoundCrate::Itself);
23
24 match found {
25 FoundCrate::Itself => String::from("element_ptr"),
26 FoundCrate::Name(name) => name,
27 }
28 };
29
30 let base_crate = Ident::new(&base_crate, Span::call_site());
31
32 let ctx = AccessListToTokensCtx {
33 list: &input.body,
34 base_crate: &base_crate,
35 };
36
37 let ptr = input.ptr;
38
39 (quote! {
40 {
41 let ptr = #ptr;
42 :: #base_crate ::helper::element_ptr_unsafe();
43 #[allow(unused_unsafe)]
44 unsafe {
45 let ptr = :: #base_crate ::helper::new_pointer(ptr);
46 #ctx
47 }
48 }
49 })
50 .into()
51}
52
53struct AccessList(Vec<ElementAccess>);
54
55struct AccessListToTokensCtx<'i> {
56 list: &'i AccessList,
57 base_crate: &'i Ident,
58}
59
60impl<'i> ToTokens for AccessListToTokensCtx<'i> {
61 fn to_tokens(&self, mut tokens: &mut TokenStream) {
62 let base_crate = self.base_crate;
63
64 let mut dirty = false;
65
66 for access in &self.list.0 {
67 use ElementAccess::*;
68
69 if dirty {
70 quote_into! { tokens =>
71 let ptr = :: #base_crate ::helper::new_pointer(ptr);
72 };
73 dirty = false;
74 }
75
76 match access {
77 Field(FieldAccess { _dot, field }) => match &field {
78 Some(FieldAccessType::Named(ident)) => quote_into! { tokens =>
79 let ptr = ptr.copy_addr(
80 ::core::ptr::addr_of!( ( *ptr.into_const() ) . #ident )
81 );
82 },
83 Some(FieldAccessType::Tuple(index)) => quote_into! { tokens =>
84 let ptr = ptr.copy_addr(
85 ::core::ptr::addr_of!( ( *ptr.into_const() ) . #index )
86 );
87 },
88 Some(FieldAccessType::Deref(..)) => {
89 dirty = true;
90 quote_into! { tokens =>
91 let ptr = ptr.read();
92 }
93 }
94 None => {
96 let error = syn::Error::new_spanned(
103 _dot,
104 "expected an identifier, integer literal, or `*` after this `.`",
105 )
106 .into_compile_error();
107 quote_into! { tokens =>
108 let ptr = ptr.copy_addr(
109 ::core::ptr::addr_of!( ( *ptr.into_const() ) #_dot )
110 );
111 #error;
112 }
113 return;
115 }
116 },
117 Index(IndexAccess { index, .. }) => quote_into! { tokens =>
118 let ptr = :: #base_crate ::helper::index(ptr, #index);
119 },
120 Offset(access) => {
121 let name = match (&access.offset_type, access.byte.is_some()) {
122 (OffsetType::Add(..), false) => Ident::new("add", Span::call_site()),
123 (OffsetType::Sub(..), false) => Ident::new("sub", Span::call_site()),
124 (OffsetType::Add(..), true) => Ident::new("byte_add", Span::call_site()),
125 (OffsetType::Sub(..), true) => Ident::new("byte_sub", Span::call_site()),
126 };
127 let offset = &access.value;
128 quote_into! { tokens =>
129 let ptr = ptr . #name ( #offset );
130 }
131 }
132 Cast(CastAccess { ty, .. }) => quote_into! { tokens =>
133 let ptr = ptr.cast::<#ty>();
134 },
135 Group(access) => {
136 let list = AccessListToTokensCtx {
137 list: &access.inner,
138 base_crate: self.base_crate,
139 };
140 quote_into! { tokens =>
141 let ptr = {
142 #list
143 };
144 };
145 dirty = true;
146 }
147 };
148 }
149 if dirty {
150 quote_into! { tokens =>
151 ptr
152 };
153 } else {
154 quote_into! { tokens =>
155 ptr.into_inner()
156 };
157 }
158 }
159}
160
161impl Parse for AccessList {
162 fn parse(input: ParseStream) -> syn::Result<Self> {
163 let mut out = Vec::new();
164 while !input.is_empty() {
165 let access: ElementAccess = input.parse()?;
166 if access.is_final() && !input.is_empty() {
167 return Err(input.error(""));
168 }
169 out.push(access);
170 }
171 Ok(Self(out))
172 }
173}
174
175struct MacroInput {
176 ptr: Expr,
177 _arrow: Token![=>],
178 body: AccessList,
179}
180
181impl Parse for MacroInput {
182 fn parse(input: ParseStream) -> syn::Result<Self> {
183 Ok(Self {
184 ptr: input.parse()?,
185 _arrow: input.parse()?,
186 body: input.parse()?,
187 })
188 }
189}
190
191enum ElementAccess {
192 Field(FieldAccess),
193 Index(IndexAccess),
194 Offset(OffsetAccess),
195 Cast(CastAccess),
196 Group(GroupAccess),
197}
198
199impl ElementAccess {
200 fn is_final(&self) -> bool {
201 match self {
202 Self::Cast(acc) => acc.arrow.is_none(),
203 _ => false,
204 }
205 }
206}
207
208impl Parse for ElementAccess {
209 fn parse(input: ParseStream) -> syn::Result<Self> {
210 if input.peek(Token![.]) {
211 input.parse().map(Self::Field)
212 } else if input.peek(token::Bracket) {
213 input.parse().map(Self::Index)
214 } else if input.peek(kw::u8) || input.peek(Token![+]) || input.peek(Token![-]) {
215 input.parse().map(Self::Offset)
216 } else if input.peek(Token![as]) {
217 input.parse().map(Self::Cast)
218 } else if input.peek(token::Paren) {
219 input.parse().map(Self::Group)
220 } else {
221 Err(input.error("expected valid element access"))
222 }
223 }
224}
225
226struct FieldAccess {
228 _dot: Token![.],
229 field: Option<FieldAccessType>,
230}
231
232impl Parse for FieldAccess {
233 fn parse(input: ParseStream) -> syn::Result<Self> {
234 Ok(Self {
235 _dot: input.parse()?,
236 field: {
237 if input.is_empty() {
238 None
239 } else {
240 Some(input.parse()?)
241 }
242 },
243 })
244 }
245}
246
247enum FieldAccessType {
248 Named(Ident),
249 Tuple(Index),
250 Deref(Token![*]),
251}
252
253impl Parse for FieldAccessType {
254 fn parse(input: ParseStream) -> syn::Result<Self> {
255 let l = input.lookahead1();
256 if l.peek(Token![*]) {
257 input.parse().map(Self::Deref)
258 } else if l.peek(syn::Ident) {
259 input.parse().map(Self::Named)
260 } else if l.peek(LitInt) {
261 input.parse().map(Self::Tuple)
263 } else {
264 Err(l.error())
265 }
266 }
267}
268
269struct IndexAccess {
270 _bracket: token::Bracket,
271 index: Expr,
272}
273
274impl Parse for IndexAccess {
275 fn parse(input: ParseStream) -> syn::Result<Self> {
276 let content;
277 Ok(Self {
278 _bracket: bracketed!(content in input),
279 index: content.parse()?,
280 })
281 }
282}
283
284struct OffsetAccess {
299 byte: Option<kw::u8>,
300 offset_type: OffsetType,
301 value: OffsetValue,
302}
303
304impl Parse for OffsetAccess {
305 fn parse(input: ParseStream) -> syn::Result<Self> {
306 Ok(Self {
307 byte: input.parse()?,
308 offset_type: input.parse()?,
309 value: input.parse()?,
310 })
311 }
312}
313
314enum OffsetType {
315 Add(Token![+]),
316 Sub(Token![-]),
317}
318
319impl Parse for OffsetType {
320 fn parse(input: ParseStream) -> syn::Result<Self> {
321 let l = input.lookahead1();
322 if l.peek(Token![+]) {
323 input.parse().map(Self::Add)
324 } else if l.peek(Token![-]) {
325 input.parse().map(Self::Sub)
326 } else {
327 Err(l.error())
328 }
329 }
330}
331
332enum OffsetValue {
333 Integer { int: LitInt },
334 Grouped { _paren: token::Paren, expr: Expr },
335}
336
337impl Parse for OffsetValue {
338 fn parse(input: ParseStream) -> syn::Result<Self> {
339 let l = input.lookahead1();
340 if l.peek(token::Paren) {
341 let content;
342 Ok(Self::Grouped {
343 _paren: parenthesized!(content in input),
344 expr: content.parse()?,
345 })
346 } else if l.peek(LitInt) {
347 Ok(Self::Integer {
348 int: input.parse()?,
349 })
350 } else {
351 Err(l.error())
352 }
353 }
354}
355
356impl ToTokens for OffsetValue {
357 fn to_tokens(&self, tokens: &mut TokenStream) {
358 match self {
359 Self::Integer { int } => int.to_tokens(tokens),
360 Self::Grouped { expr, .. } => expr.to_tokens(tokens),
361 }
362 }
363}
364
365struct CastAccess {
366 _as_token: Token![as],
367 ty: Type,
368 arrow: Option<Token![=>]>,
370}
371
372impl Parse for CastAccess {
373 fn parse(input: ParseStream) -> syn::Result<Self> {
374 Ok(Self {
375 _as_token: input.parse()?,
376 ty: input.parse()?,
377 arrow: input.parse()?,
378 })
379 }
380}
381
382struct GroupAccess {
383 _paren: token::Paren,
384 inner: AccessList,
385}
386
387impl Parse for GroupAccess {
388 fn parse(input: ParseStream) -> syn::Result<Self> {
389 let content;
390 Ok(Self {
391 _paren: parenthesized!(content in input),
392 inner: content.parse()?,
393 })
394 }
395}
396
/// Custom keywords used by the access grammar.
mod kw {
    // `u8` is a primitive type name, not a Rust keyword, so syn needs a custom
    // token to peek/parse it as the byte-offset marker in `u8 + n` / `u8 - n`.
    syn::custom_keyword!(u8);
}