munge_macro/
lib.rs

1//! The proc macro at the core of munge.
2
3#![deny(
4    missing_docs,
5    unsafe_op_in_unsafe_fn,
6    clippy::missing_safety_doc,
7    clippy::undocumented_unsafe_blocks,
8    rustdoc::broken_intra_doc_links,
9    rustdoc::missing_crate_level_docs
10)]
11
12use proc_macro2::TokenStream;
13use quote::{quote, quote_spanned};
14use syn::{
15    parse, parse_macro_input,
16    punctuated::Punctuated,
17    spanned::Spanned,
18    token::{Eq, FatArrow, Let, Semi},
19    Error, Expr, FieldPat, Index, Pat, PatIdent, PatRest, PatSlice, PatStruct,
20    PatTuple, PatTupleStruct, Path,
21};
22
23/// Destructures a value by projecting pointers.
24#[proc_macro]
25pub fn munge_with_path(
26    input: proc_macro::TokenStream,
27) -> proc_macro::TokenStream {
28    let input = parse_macro_input!(input as Input);
29    destructure(input)
30        .unwrap_or_else(|e| e.to_compile_error())
31        .into()
32}
33
34struct Input {
35    crate_path: Path,
36    _arrow: FatArrow,
37    destructures: Punctuated<Destructure, Semi>,
38}
39
40impl parse::Parse for Input {
41    fn parse(input: parse::ParseStream) -> parse::Result<Self> {
42        Ok(Input {
43            crate_path: input.parse::<Path>()?,
44            _arrow: input.parse::<FatArrow>()?,
45            destructures: input.parse_terminated(Destructure::parse, Semi)?,
46        })
47    }
48}
49
50struct Destructure {
51    _let_token: Let,
52    pat: Pat,
53    _eq_token: Eq,
54    expr: Expr,
55}
56
57impl parse::Parse for Destructure {
58    fn parse(input: parse::ParseStream) -> parse::Result<Self> {
59        Ok(Destructure {
60            _let_token: input.parse::<Let>()?,
61            pat: Pat::parse_single(input)?,
62            _eq_token: input.parse::<Eq>()?,
63            expr: input.parse::<Expr>()?,
64        })
65    }
66}
67
68fn make_rest_check(crate_path: &Path, rest: &PatRest) -> TokenStream {
69    let span = rest.dot2_token.span();
70    let destructurer = quote! { destructurer };
71
72    quote_spanned! { span => {
73        if false {
74            let ptr = #crate_path::__macro::get_destructuring_ptr(
75                &#destructurer
76            );
77            // SAFETY: This code can never be called.
78            let _ = unsafe { &*ptr } as &dyn #crate_path::__macro::MustBeBorrow;
79        }
80    } }
81}
82
83fn parse_pat(
84    crate_path: &Path,
85    pat: &Pat,
86) -> Result<(TokenStream, TokenStream), Error> {
87    let test_ident = quote_spanned!(pat.span() => test);
88    let test_ident_ref = quote_spanned!(pat.span() => &test);
89    let test = quote! {
90        let #test_ident =
91            #crate_path::__macro::IsReference::for_ptr(ptr).test();
92        let _: &dyn #crate_path::__macro::MustBeAValue = #test_ident_ref;
93    };
94
95    Ok(match pat {
96        Pat::Ident(pat_ident) => {
97            let mutability = &pat_ident.mutability;
98            let ident = &pat_ident.ident;
99
100            if let Some(r#ref) = &pat_ident.by_ref {
101                return Err(Error::new_spanned(
102                    r#ref,
103                    "`ref` is not allowed in munge destructures",
104                ));
105            }
106            if let Some((at, _)) = &pat_ident.subpat {
107                return Err(Error::new_spanned(
108                    at,
109                    "subpatterns are not allowed in munge destructures",
110                ));
111            }
112
113            (
114                quote! { #mutability #ident },
115                quote! {
116                    #test
117
118                    // SAFETY: `ptr` is a properly-aligned pointer to a subfield
119                    // of the pointer underlying `destructurer`.
120                    unsafe {
121                        #crate_path::__macro::restructure_destructurer(
122                            &destructurer,
123                            ptr,
124                        )
125                    }
126                },
127            )
128        }
129        Pat::Tuple(PatTuple { elems, .. })
130        | Pat::TupleStruct(PatTupleStruct { elems, .. }) => {
131            let rest_check = elems.iter().find_map(|e| {
132                if let Pat::Rest(rest) = e {
133                    Some(make_rest_check(crate_path, rest))
134                } else {
135                    None
136                }
137            });
138            let parsed = elems
139                .iter()
140                .filter(|e| !matches!(e, Pat::Rest(_)))
141                .map(|e| parse_pat(crate_path, e))
142                .collect::<Result<Vec<_>, Error>>()?;
143            let (bindings, (exprs, indices)) = parsed
144                .iter()
145                .enumerate()
146                .map(|(i, x)| (&x.0, (&x.1, Index::from(i))))
147                .unzip::<_, _, Vec<_>, (Vec<_>, Vec<_>)>();
148            (
149                quote! { (#(#bindings,)*) },
150                quote! { {
151                    #rest_check
152                    #test
153
154                    ( #({
155                        // SAFETY: `ptr` is guaranteed to always be non-null,
156                        // properly-aligned, and valid for reads.
157                        let ptr = unsafe {
158                            ::core::ptr::addr_of_mut!((*ptr).#indices)
159                        };
160
161                        #exprs
162                    },)* )
163                } },
164            )
165        }
166        Pat::Slice(pat_slice) => {
167            let rest_check = pat_slice.elems.iter().find_map(|e| {
168                if let Pat::Rest(rest) = e {
169                    Some(make_rest_check(crate_path, rest))
170                } else {
171                    None
172                }
173            });
174            let parsed = pat_slice
175                .elems
176                .iter()
177                .filter(|e| !matches!(e, Pat::Rest(_)))
178                .map(|e| parse_pat(crate_path, e))
179                .collect::<Result<Vec<_>, Error>>()?;
180            let (bindings, (exprs, indices)) = parsed
181                .iter()
182                .enumerate()
183                .map(|(i, x)| (&x.0, (&x.1, Index::from(i))))
184                .unzip::<_, _, Vec<_>, (Vec<_>, Vec<_>)>();
185            (
186                quote! { (#(#bindings,)*) },
187                quote! { {
188                    #rest_check
189                    #test
190
191                    ( #({
192                        // SAFETY: `ptr` is guaranteed to always be non-null,
193                        // properly-aligned, and valid for reads.
194                        let ptr = unsafe {
195                            ::core::ptr::addr_of_mut!((*ptr)[#indices])
196                        };
197
198                        #exprs
199                    },)* )
200                } },
201            )
202        }
203        Pat::Struct(pat_struct) => {
204            let parsed = pat_struct
205                .fields
206                .iter()
207                .map(|fp| {
208                    parse_pat(crate_path, &fp.pat).map(|ie| (&fp.member, ie))
209                })
210                .collect::<Result<Vec<_>, Error>>()?;
211            let (members, (bindings, exprs)) =
212                parsed.into_iter().unzip::<_, _, Vec<_>, (Vec<_>, Vec<_>)>();
213
214            let rest_check = pat_struct
215                .rest
216                .as_ref()
217                .map(|rest| make_rest_check(crate_path, rest));
218
219            (
220                quote! { (
221                    #(#bindings,)*
222                ) },
223                quote! { {
224                    #rest_check
225                    #test
226
227                    ( #({
228                        // SAFETY: `ptr` is guaranteed to always be non-null,
229                        // properly-aligned, and valid for reads.
230                        let ptr = unsafe {
231                            ::core::ptr::addr_of_mut!((*ptr).#members)
232                        };
233
234                        #exprs
235                    },)* )
236                } },
237            )
238        }
239        Pat::Rest(_) => unreachable!(
240            "rest patterns only occur in tuples, tuple structs, and slices"
241        ),
242        Pat::Wild(pat_wild) => {
243            let token = &pat_wild.underscore_token;
244            (
245                quote! { #token },
246                quote! {
247                    #test
248
249                    // SAFETY: `ptr` is a properly-aligned pointer to a subfield
250                    // of the pointer underlying `destructurer`.
251                    unsafe {
252                        #crate_path::__macro::restructure_destructurer(
253                            &destructurer,
254                            ptr,
255                        )
256                    }
257                },
258            )
259        }
260        _ => {
261            return Err(Error::new_spanned(
262                pat,
263                "expected a destructuring pattern",
264            ));
265        }
266    })
267}
268
269fn strip_mut(pat: &Pat) -> Result<Pat, Error> {
270    Ok(match pat {
271        Pat::Ident(pat_ident) => Pat::Ident(PatIdent {
272            attrs: pat_ident.attrs.clone(),
273            by_ref: None,
274            mutability: None,
275            ident: pat_ident.ident.clone(),
276            subpat: if let Some((at, pat)) = pat_ident.subpat.as_ref() {
277                Some((*at, Box::new(strip_mut(pat)?)))
278            } else {
279                None
280            },
281        }),
282        Pat::Tuple(pat_tuple) => {
283            let mut elems = Punctuated::new();
284            for elem in pat_tuple.elems.iter() {
285                elems.push_value(strip_mut(elem)?);
286                elems.push_punct(Default::default());
287            }
288            Pat::Tuple(PatTuple {
289                attrs: pat_tuple.attrs.clone(),
290                paren_token: pat_tuple.paren_token,
291                elems,
292            })
293        }
294        Pat::TupleStruct(pat_tuple_struct) => {
295            let mut elems = Punctuated::new();
296            for elem in pat_tuple_struct.elems.iter() {
297                elems.push(strip_mut(elem)?);
298            }
299            Pat::TupleStruct(PatTupleStruct {
300                attrs: pat_tuple_struct.attrs.clone(),
301                qself: pat_tuple_struct.qself.clone(),
302                path: pat_tuple_struct.path.clone(),
303                paren_token: pat_tuple_struct.paren_token,
304                elems,
305            })
306        }
307        Pat::Slice(pat_slice) => {
308            let mut elems = Punctuated::new();
309            for elem in pat_slice.elems.iter() {
310                elems.push(strip_mut(elem)?);
311            }
312            Pat::Slice(PatSlice {
313                attrs: pat_slice.attrs.clone(),
314                bracket_token: pat_slice.bracket_token,
315                elems,
316            })
317        }
318        Pat::Struct(pat_struct) => {
319            let mut fields = Punctuated::new();
320            for field in pat_struct.fields.iter() {
321                fields.push(FieldPat {
322                    attrs: field.attrs.clone(),
323                    member: field.member.clone(),
324                    colon_token: field.colon_token,
325                    pat: Box::new(strip_mut(&field.pat)?),
326                });
327            }
328            Pat::Struct(PatStruct {
329                attrs: pat_struct.attrs.clone(),
330                qself: pat_struct.qself.clone(),
331                path: pat_struct.path.clone(),
332                brace_token: pat_struct.brace_token,
333                fields,
334                rest: pat_struct.rest.clone(),
335            })
336        }
337        Pat::Rest(pat_rest) => Pat::Rest(pat_rest.clone()),
338        Pat::Wild(pat_wild) => Pat::Wild(pat_wild.clone()),
339        _ => todo!(),
340    })
341}
342
343fn destructure(input: Input) -> Result<TokenStream, Error> {
344    let crate_path = &input.crate_path;
345
346    let mut result = TokenStream::new();
347    for destructure in input.destructures.iter() {
348        let pat = &destructure.pat;
349        let expr = &destructure.expr;
350
351        let test_pat = strip_mut(pat)?;
352
353        let (bindings, exprs) = parse_pat(crate_path, pat)?;
354
355        result.extend(quote! {
356            let mut destructurer = #crate_path::__macro::make_destructurer(
357                #expr
358            );
359            let #bindings = {
360                #[allow(
361                    unused_mut,
362                    unused_unsafe,
363                    clippy::undocumented_unsafe_blocks,
364                )]
365                {
366                    use #crate_path::__macro::MaybeReference as _;
367
368                    let ptr = #crate_path::__macro::destructurer_ptr(
369                        &mut destructurer
370                    );
371
372                    #[allow(unreachable_code, unused_variables)]
373                    if false {
374                        // SAFETY: This can never be called.
375                        unsafe { ::core::hint::unreachable_unchecked() };
376                        // SAFETY: This can never be called.
377                        let #test_pat = unsafe {
378                            #crate_path::__macro::test_destructurer(
379                                &mut destructurer,
380                            )
381                        };
382                    }
383
384                    #exprs
385                }
386            };
387        });
388    }
389    Ok(result)
390}