Skip to main content

decycle_impl/
finalize.rs

1use proc_macro2::{Span, TokenStream};
2use proc_macro_error::*;
3use std::collections::HashMap;
4use syn::parse::{Parse, ParseStream};
5use syn::punctuated::Punctuated;
6use syn::visit_mut::VisitMut;
7use syn::*;
8use template_quote::quote;
9
/// File-local replacement for `syn::parse_quote!`.
///
/// The tokens are rendered with `template_quote::quote!` (so `#{expr}`,
/// `#(for ..)` and `#(if ..)` interpolation are available) and then parsed
/// into whatever syn type the call site expects. Panics (via `unwrap`) when
/// the rendered tokens do not parse as that type.
macro_rules! parse_quote {
    ($($tt:tt)*) => {{
        let rendered = ::template_quote::quote!($($tt)*);
        syn::parse2(rendered).unwrap()
    }};
}
15
16
17fn parse_comma_separated<T: Parse>(input: ParseStream) -> Result<Vec<T>> {
18    let mut items = Vec::new();
19    while !input.is_empty() {
20        items.push(input.parse()?);
21        if !input.is_empty() {
22            input.parse::<Token![,]>()?;
23        }
24    }
25    Ok(items)
26}
27
28/// Inserts a `Type` as a `GenericArgument::Type` at the given position
29/// in the last segment's arguments of `path`.
30fn path_insert_type_arg(path: &mut Path, index: usize, ty: Type) {
31    let last_seg = path.segments.last_mut().unwrap();
32    let arg = GenericArgument::Type(ty);
33    match &mut last_seg.arguments {
34        PathArguments::None => {
35            let mut args = Punctuated::new();
36            args.insert(index, arg);
37            last_seg.arguments = PathArguments::AngleBracketed(AngleBracketedGenericArguments {
38                colon2_token: None,
39                lt_token: Default::default(),
40                args,
41                gt_token: Default::default(),
42            });
43        }
44        PathArguments::AngleBracketed(ref mut angle_args) => {
45            angle_args.args.insert(index, arg);
46        }
47        PathArguments::Parenthesized(_) => {}
48    }
49}
50
51/// Returns `false` for trait bounds whose path is a single segment present
52/// in `replacing_table`, used to filter out bounds that will be replaced.
53fn should_keep_bound(bound: &TypeParamBound, replacing_table: &HashMap<Ident, (usize, Path)>) -> bool {
54    if let TypeParamBound::Trait(trait_bound) = bound {
55        if trait_bound.path.segments.len() == 1 {
56            return !replacing_table.contains_key(&trait_bound.path.segments[0].ident);
57        }
58    }
59    true
60}
61
62/// Strips bounds matching `replacing_table` from a `Generics`, removing
63/// type param bounds and where-clause predicates whose paths appear as keys.
64fn strip_replaced_bounds(
65    generics: &mut Generics,
66    replacing_table: &HashMap<Ident, (usize, Path)>,
67) {
68    for param in &mut generics.params {
69        if let GenericParam::Type(ref mut type_param) = param {
70            type_param.bounds = type_param
71                .bounds
72                .iter()
73                .filter(|bound| should_keep_bound(bound, replacing_table))
74                .cloned()
75                .collect();
76            if type_param.bounds.is_empty() {
77                type_param.colon_token = None;
78            }
79        }
80    }
81    if let Some(ref mut where_clause) = generics.where_clause {
82        where_clause.predicates = where_clause
83            .predicates
84            .iter()
85            .filter_map(|pred| {
86                if let WherePredicate::Type(type_pred) = pred {
87                    let new_bounds: Punctuated<TypeParamBound, Token![+]> = type_pred
88                        .bounds
89                        .iter()
90                        .filter(|bound| should_keep_bound(bound, replacing_table))
91                        .cloned()
92                        .collect();
93                    if new_bounds.is_empty() {
94                        None
95                    } else {
96                        let mut new_pred = type_pred.clone();
97                        new_pred.bounds = new_bounds;
98                        Some(WherePredicate::Type(new_pred))
99                    }
100                } else {
101                    Some(pred.clone())
102                }
103            })
104            .collect();
105        if where_clause.predicates.is_empty() {
106            generics.where_clause = None;
107        }
108    }
109}
110
111/// Replaces trait paths that have a single segment matching a key in the
112/// HashMap with the corresponding replacement Path, copying the original
113/// PathArguments and inserting the given Type at the stored position.
114struct TraitReplacer(HashMap<Ident, (usize, Path)>, Type);
115
116impl VisitMut for TraitReplacer {
117    fn visit_path_mut(&mut self, path: &mut Path) {
118        if path.segments.len() == 1 {
119            if let Some((index, replacement)) = self.0.get(&path.segments[0].ident) {
120                let orig_args =
121                    std::mem::replace(&mut path.segments[0].arguments, PathArguments::None);
122                let mut new_path = replacement.clone();
123                new_path.segments.last_mut().unwrap().arguments = orig_args;
124                path_insert_type_arg(&mut new_path, *index, self.1.clone());
125                *path = new_path;
126                return;
127            }
128        }
129        syn::visit_mut::visit_path_mut(self, path);
130    }
131}
132
133pub struct FinalizeArgs {
134    pub working_list: Vec<Path>,
135    pub traits: Vec<ItemTrait>,
136    pub contents: Vec<ItemImpl>,
137    pub recurse_level: usize,
138    pub support_infinite_cycle: bool,
139}
140
141impl Parse for FinalizeArgs {
142    fn parse(input: ParseStream) -> Result<Self> {
143        let _crate_identity: LitStr = input.parse()?;
144        let crate_version: LitStr = input.parse()?;
145        let expected_version = env!("CARGO_PKG_VERSION");
146        if crate_version.value() != expected_version {
147            abort!(
148                Span::call_site(),
149                "version mismatch: expected '{}', got '{}'",
150                expected_version,
151                crate_version.value()
152            )
153        }
154
155        let working_list_content;
156        bracketed!(working_list_content in input);
157        let working_list = parse_comma_separated(&working_list_content)?;
158
159        let traits_content;
160        braced!(traits_content in input);
161        let traits = parse_comma_separated(&traits_content)?;
162
163        let contents_content;
164        braced!(contents_content in input);
165        let contents = parse_comma_separated(&contents_content)?;
166
167        let lit: LitInt = input.parse()?;
168        let recurse_level = lit.base10_parse()?;
169
170        let lit: LitBool = input.parse()?;
171        let support_infinite_cycle = lit.value;
172
173        Ok(FinalizeArgs {
174            working_list,
175            traits,
176            contents,
177            recurse_level,
178            support_infinite_cycle,
179        })
180    }
181}
182
183impl template_quote::ToTokens for FinalizeArgs {
184    fn to_tokens(&self, tokens: &mut TokenStream) {
185        let crate_identity = LitStr::new(&crate::get_crate_identity(), Span::call_site());
186        let crate_version = env!("CARGO_PKG_VERSION");
187        let working_list = &self.working_list;
188        let traits = &self.traits;
189        let contents = &self.contents;
190
191        let recurse_level = &self.recurse_level;
192        let support_infinite_cycle = &self.support_infinite_cycle;
193
194        tokens.extend(quote! {
195            #crate_identity
196            #crate_version
197            [ #(#working_list),* ]
198            { #(#traits),* }
199            { #(#contents),* }
200            #recurse_level
201            #support_infinite_cycle
202        });
203    }
204}
205
206fn get_initial_rank(count: usize) -> Type {
207    if count == 0 {
208        parse_quote!(())
209    } else {
210        let inner = get_initial_rank(count - 1);
211        parse_quote!((#inner,))
212    }
213}
214
215trait GenericsScheme {
216    fn insert(&self, index: usize, param: TypeParam) -> Self;
217    fn impl_generics(&self) -> TokenStream;
218    fn ty_generics(&self) -> TokenStream;
219}
220
221impl GenericsScheme for Generics {
222    fn insert(&self, index: usize, param: TypeParam) -> Self {
223        let mut generics = self.clone();
224        generics.params.insert(index, GenericParam::Type(param));
225        generics
226    }
227
228    fn impl_generics(&self) -> TokenStream {
229        let (impl_generics, _, _) = self.split_for_impl();
230        quote!(#impl_generics)
231    }
232
233    fn ty_generics(&self) -> TokenStream {
234        let (_, ty_generics, _) = self.split_for_impl();
235        quote!(#ty_generics)
236    }
237}
238
239impl GenericsScheme for Path {
240    fn insert(&self, index: usize, param: TypeParam) -> Self {
241        let mut path = self.clone();
242        let ty = Type::Path(TypePath {
243            qself: None,
244            path: parse_quote!(#param),
245        });
246        path_insert_type_arg(&mut path, index, ty);
247        path
248    }
249
250    fn impl_generics(&self) -> TokenStream {
251        quote!()
252    }
253
254    fn ty_generics(&self) -> TokenStream {
255        if let Some(last_segment) = self.segments.last() {
256            let args = &last_segment.arguments;
257            quote!(#args)
258        } else {
259            quote!()
260        }
261    }
262}
263
264pub fn finalize(args: FinalizeArgs) -> TokenStream {
265    let random_suffix = crate::get_random();
266    let name =
267        |s: &str| -> Ident { Ident::new(&format!("{}{}", s, &random_suffix), Span::call_site()) };
268
269    // Mapping which maps trait path (with no args) to corresponding impl item
270    let mut traits_impls: HashMap<Path, Vec<_>> = HashMap::new();
271
272    for item_impl in args.contents {
273        let mut trait_path = item_impl.trait_.clone().unwrap().1;
274        if let Some(last_seg) = trait_path.segments.last_mut() {
275            last_seg.arguments = PathArguments::None;
276        }
277        traits_impls.entry(trait_path).or_default().push(item_impl);
278    }
279
280    let replacing_table: HashMap<Ident, (usize, Path)> = args
281        .traits
282        .iter()
283        .map(|trait_| {
284            let ident = &trait_.ident;
285            let g = &trait_.generics;
286            let loc = g
287                .params
288                .iter()
289                .position(|param| !matches!(param, GenericParam::Lifetime(_)))
290                .unwrap_or(g.params.len());
291            let ranked_ident_str = format!("{}Ranked", ident);
292            let ranked_ident = name(ranked_ident_str.as_str());
293            let ranked_path: Path = parse_quote!(#ranked_ident);
294            (ident.clone(), (loc, ranked_path))
295        })
296        .collect();
297
298    let mut output = TokenStream::new();
299    for trait_ in &args.traits {
300        let ident = &trait_.ident;
301        let Some(impls) = traits_impls.get(&parse_quote!(#ident)) else {
302            emit_warning!(ident, "trait '{}' has no implementations", ident);
303            continue;
304        };
305
306        let g = &trait_.generics;
307        let &(loc, ref ranked_path) = replacing_table.get(ident).unwrap();
308        let initial_rank = get_initial_rank(args.recurse_level);
309
310        let make_ranked_path = |rank_ty: Type| -> Path {
311            let mut path: Path = parse_quote!(#ranked_path #{g.ty_generics()});
312            path_insert_type_arg(&mut path, loc, rank_ty);
313            path
314        };
315        let ranked_bound = make_ranked_path(initial_rank.clone());
316        let ranked_bound_end = make_ranked_path(parse_quote!(()));
317
318        let delegated_items: Vec<TokenStream> = trait_
319            .items
320            .iter()
321            .map(|item| match item {
322                TraitItem::Fn(method) => {
323                    let sig = &method.sig;
324                    let method_ident = &sig.ident;
325                    let call_args: Vec<TokenStream> = sig
326                        .inputs
327                        .iter()
328                        .map(|arg| match arg {
329                            FnArg::Receiver(receiver) => {
330                                let self_token = &receiver.self_token;
331                                quote!(#self_token)
332                            }
333                            FnArg::Typed(pat_type) => {
334                                let pat = &pat_type.pat;
335                                quote!(#pat)
336                            }
337                        })
338                        .collect();
339                    quote! {
340                        #sig {
341                            <Self as #ranked_bound>::#method_ident(#(#call_args),*)
342                        }
343                    }
344                }
345                TraitItem::Type(assoc_type) => {
346                    let type_ident = &assoc_type.ident;
347                    let generics = &assoc_type.generics;
348                    quote! {
349                        type #type_ident #generics = <Self as #ranked_bound>::#type_ident;
350                    }
351                }
352                TraitItem::Const(assoc_const) => {
353                    let const_ident = &assoc_const.ident;
354                    let ty = &assoc_const.ty;
355                    quote! {
356                        const #const_ident: #ty = <Self as #ranked_bound>::#const_ident;
357                    }
358                }
359                _ => quote!(),
360            })
361            .collect();
362
363        output.extend(quote! {
364            #{&trait_.trait_token} #ranked_path #{g.insert(loc, parse_quote!(#{name("Rank")})).ty_generics()}
365            #{trait_.colon_token} #{&trait_.supertraits} {
366                #(for item in &trait_.items) { #item }
367            }
368        });
369        output.extend(quote! {
370            #(for attr in &trait_.attrs) { #attr }
371            impl #{g.insert(loc, parse_quote!(
372                #{name("Self")}: #ranked_bound
373            )).impl_generics()}
374            super::#ident #{g.ty_generics()} for #{name("Self")} #{&g.where_clause} {
375                #(#delegated_items)*
376            }
377        });
378
379        for impl_ in impls {
380            let mut modified_impl = impl_.clone();
381            TraitReplacer(replacing_table.clone(), parse_quote!((#{name("Rank")},)))
382                .visit_path_mut(&mut modified_impl.trait_.as_mut().unwrap().1);
383            TraitReplacer(replacing_table.clone(), parse_quote!(#{name("Rank")}))
384                .visit_item_impl_mut(&mut modified_impl);
385            modified_impl
386                .generics
387                .params
388                .push(parse_quote!(#{name("Rank")}));
389
390            if args.support_infinite_cycle {
391                for (num, item) in modified_impl.items.iter_mut().enumerate() {
392                    if let ImplItem::Fn(ImplItemFn { sig, block, .. }) = item {
393                        let old_block = block.clone();
394                        *block = parse_quote! {
395                            {
396                                let _ = Self::#{name("get_cell")}(#num).set( <Self as #ranked_bound>::#{&sig.ident} as _);
397                                #old_block
398                            }
399                        };
400                    }
401                }
402            }
403
404            let cycle_items: Vec<TokenStream> = impl_
405                .items
406                .iter()
407                .enumerate()
408                .map(|(id, item)| match item {
409                    ImplItem::Fn(method) => {
410                        let mut sig = method.sig.clone();
411                        // ensure that all params are ident
412                        for (num, p) in sig.inputs.iter_mut().enumerate() {
413                            if let FnArg::Typed(PatType { pat, .. }) = p {
414                                if !matches!(pat.as_ref(), Pat::Ident(_)) {
415                                    **pat = Pat::Ident(PatIdent {
416                                        attrs: vec![],
417                                        by_ref: None,
418                                        mutability: None,
419                                        ident: name(format!("param_{}_", num).as_str()),
420                                        subpat: None,
421                                    });
422                                }
423                            }
424                        }
425                        quote! {
426                            #sig {
427                                #(if args.support_infinite_cycle) {
428                                    /// SAFETY:
429                                    #[allow(unused_unsafe)]
430                                    unsafe {
431                                        ::core::mem::transmute::<
432                                            _,
433                                            #{&sig.unsafety} #{&sig.abi}
434                                            fn(
435                                                #(for p in &sig.inputs), {
436                                                    #(if let FnArg::Receiver ( Receiver { ty, .. }) = p) {
437                                                        #ty
438                                                    }
439                                                    #(if let FnArg::Typed ( PatType { ty, .. }) = p) {
440                                                        #ty
441                                                    }
442                                                }
443                                            ) #{&sig.output}
444                                        >(Self::#{name("get_cell")}(#id).get().unwrap())
445                                        (
446                                            #(for p in &sig.inputs), {
447                                                #(if let FnArg::Receiver ( Receiver { self_token, .. }) = p) {
448                                                    #self_token
449                                                }
450                                                #(if let FnArg::Typed ( PatType { pat, .. }) = p) {
451                                                    #pat
452                                                }
453                                            }
454                                        )
455                                    }
456                                }
457                                #(else) {
458                                    ::core::unimplemented!("decycle: cycle limit reached")
459                                }
460                            }
461                        }
462                    }
463                    other => quote!(#other),
464                })
465                .collect();
466
467            let mut modified_g = g.clone();
468            strip_replaced_bounds(&mut modified_g, &replacing_table);
469
470            output.extend(quote! {
471                #modified_impl
472
473                #[allow(unused_variables)]
474                impl #{modified_g.impl_generics()} #ranked_bound_end for #{&impl_.self_ty} #{&modified_g.where_clause} {
475                    #(#cycle_items)*
476                }
477            });
478        }
479    }
480
481    quote! {
482        // this module is to prevent confliction of trait method call between ranked and non-ranked
483        // traits
484        #[doc(hidden)]
485        mod #{name("shadowing_module")} {
486            use super::*;
487
488            // Remove the non-ranked traits from namespace to prevent conflicting
489            #(for ident in replacing_table.keys()) { trait #ident {} }
490
491            #(if args.support_infinite_cycle) {
492                trait #{name("GetVTableKey")} {
493                    extern "C" fn #{name("get_vtable_key")}(&self) {}
494
495                    fn #{name("get_cell")}(id: ::core::primitive::usize) -> &'static ::std::sync::OnceLock<::core::primitive::usize> {
496                        use ::std::sync::{Mutex, OnceLock};
497                        use ::std::collections::HashMap;
498                        use ::std::primitive::*;
499                        static VTABLE_MAP_PARSE: OnceLock<Mutex<HashMap<(usize, usize), OnceLock<usize>>>> = OnceLock::new();
500                        let map = VTABLE_MAP_PARSE.get_or_init(|| Mutex::new(HashMap::new()));
501                        let mut map = map.lock().unwrap();
502                        let r = map.entry((Self::#{name("get_vtable_key")} as usize, id)).or_insert(OnceLock::new());
503                        // SAFETY:
504                        unsafe {
505                            ::core::mem::transmute(r)
506                        }
507                    }
508                }
509
510                impl<T: ?::core::marker::Sized> #{name("GetVTableKey")} for T {}
511            }
512
513            #output
514        }
515    }
516}