Skip to main content

decycle_impl/
finalize.rs

1use proc_macro2::{Span, TokenStream};
2use proc_macro_error::*;
3use std::collections::HashMap;
4use syn::parse::{Parse, ParseStream};
5use syn::punctuated::Punctuated;
6use syn::visit_mut::VisitMut;
7use syn::*;
8use template_quote::quote;
9
/// File-local stand-in for `syn::parse_quote!` that builds its tokens with
/// `template_quote::quote!`, so template syntax such as `#{expr}` can be used
/// inside it. The target type is inferred from the call site; the macro
/// panics if the produced tokens do not parse as that type.
macro_rules! parse_quote {
    ($($template:tt)*) => {
        syn::parse2(::template_quote::quote!($($template)*)).unwrap()
    };
}
15
16fn parse_comma_separated<T: Parse>(input: ParseStream) -> Result<Vec<T>> {
17    let mut items = Vec::new();
18    while !input.is_empty() {
19        items.push(input.parse()?);
20        if !input.is_empty() {
21            input.parse::<Token![,]>()?;
22        }
23    }
24    Ok(items)
25}
26
27/// Inserts a `Type` as a `GenericArgument::Type` at the given position
28/// in the last segment's arguments of `path`.
29fn path_insert_type_arg(path: &mut Path, index: usize, ty: Type) {
30    let last_seg = path.segments.last_mut().unwrap();
31    let arg = GenericArgument::Type(ty);
32    match &mut last_seg.arguments {
33        PathArguments::None => {
34            let mut args = Punctuated::new();
35            args.insert(index, arg);
36            last_seg.arguments = PathArguments::AngleBracketed(AngleBracketedGenericArguments {
37                colon2_token: None,
38                lt_token: Default::default(),
39                args,
40                gt_token: Default::default(),
41            });
42        }
43        PathArguments::AngleBracketed(ref mut angle_args) => {
44            angle_args.args.insert(index, arg);
45        }
46        PathArguments::Parenthesized(_) => {}
47    }
48}
49
50/// Returns `false` for trait bounds whose path is a single segment present
51/// in `replacing_table`, used to filter out bounds that will be replaced.
52fn should_keep_bound(
53    bound: &TypeParamBound,
54    replacing_table: &HashMap<Ident, (usize, Path)>,
55) -> bool {
56    if let TypeParamBound::Trait(trait_bound) = bound {
57        if trait_bound.path.segments.len() == 1 {
58            return !replacing_table.contains_key(&trait_bound.path.segments[0].ident);
59        }
60    }
61    true
62}
63
64/// Strips bounds matching `replacing_table` from a `Generics`, removing
65/// type param bounds and where-clause predicates whose paths appear as keys.
66fn strip_replaced_bounds(generics: &mut Generics, replacing_table: &HashMap<Ident, (usize, Path)>) {
67    for param in &mut generics.params {
68        if let GenericParam::Type(ref mut type_param) = param {
69            type_param.bounds = type_param
70                .bounds
71                .iter()
72                .filter(|bound| should_keep_bound(bound, replacing_table))
73                .cloned()
74                .collect();
75            if type_param.bounds.is_empty() {
76                type_param.colon_token = None;
77            }
78        }
79    }
80    if let Some(ref mut where_clause) = generics.where_clause {
81        where_clause.predicates = where_clause
82            .predicates
83            .iter()
84            .filter_map(|pred| {
85                if let WherePredicate::Type(type_pred) = pred {
86                    let new_bounds: Punctuated<TypeParamBound, Token![+]> = type_pred
87                        .bounds
88                        .iter()
89                        .filter(|bound| should_keep_bound(bound, replacing_table))
90                        .cloned()
91                        .collect();
92                    if new_bounds.is_empty() {
93                        None
94                    } else {
95                        let mut new_pred = type_pred.clone();
96                        new_pred.bounds = new_bounds;
97                        Some(WherePredicate::Type(new_pred))
98                    }
99                } else {
100                    Some(pred.clone())
101                }
102            })
103            .collect();
104        if where_clause.predicates.is_empty() {
105            generics.where_clause = None;
106        }
107    }
108}
109
110/// Replaces trait paths that have a single segment matching a key in the
111/// HashMap with the corresponding replacement Path, copying the original
112/// PathArguments and inserting the given Type at the stored position.
113struct TraitReplacer(HashMap<Ident, (usize, Path)>, Type);
114
115impl VisitMut for TraitReplacer {
116    fn visit_expr_path_mut(&mut self, expr_path: &mut ExprPath) {
117        self.replace_qself_path(expr_path.qself.as_mut(), &mut expr_path.path);
118        syn::visit_mut::visit_expr_path_mut(self, expr_path);
119    }
120
121    fn visit_type_path_mut(&mut self, type_path: &mut TypePath) {
122        self.replace_qself_path(type_path.qself.as_mut(), &mut type_path.path);
123        syn::visit_mut::visit_type_path_mut(self, type_path);
124    }
125
126    fn visit_path_mut(&mut self, path: &mut Path) {
127        self.replace_qself_path(None, path);
128        syn::visit_mut::visit_path_mut(self, path);
129    }
130}
131
132impl TraitReplacer {
133    fn replace_qself_path(&mut self, qself: Option<&mut QSelf>, path: &mut Path) -> bool {
134        // allow `Trait` or `<_ as Trait>::path`
135        if !(matches!(qself, Some(QSelf { position: 1, .. })) || qself.is_none())
136            || path.leading_colon.is_some()
137        {
138            return false;
139        }
140
141        if let Some((index, replacement)) = self.0.get(&path.segments[0].ident) {
142            let orig_args = std::mem::replace(&mut path.segments[0].arguments, PathArguments::None);
143            let mut new_path = replacement.clone();
144            new_path.segments.last_mut().unwrap().arguments = orig_args;
145            path_insert_type_arg(&mut new_path, *index, self.1.clone());
146            let mut new_segments: Punctuated<PathSegment, Token![::]> = Punctuated::new();
147            for seg in new_path.segments {
148                new_segments.push(seg);
149            }
150            if let Some(qself) = qself {
151                qself.position = new_segments.len();
152                for seg in path.segments.iter().skip(qself.position) {
153                    new_segments.push(seg.clone());
154                }
155            }
156            path.segments = new_segments;
157        }
158        true
159    }
160}
161
/// Input of the finalize stage. It is serialized by its `ToTokens` impl and
/// read back by its `Parse` impl, in field order after a crate identity and
/// version header.
pub struct FinalizeArgs {
    /// Bracketed list of paths. It is not read by `finalize` in this file.
    pub working_list: Vec<Path>,
    /// Trait definitions for which `finalize` generates ranked copies.
    pub traits: Vec<ItemTrait>,
    /// Trait impls. `finalize` groups them by trait path; each one must be a
    /// trait impl (`finalize` unwraps `trait_`).
    pub contents: Vec<ItemImpl>,
    /// Nesting depth of the initial rank type (`()` wrapped this many times
    /// in a 1-tuple, see `get_initial_rank`).
    pub recurse_level: usize,
    /// When set, the terminating rank calls back into the top-rank method
    /// through a run-time function pointer instead of panicking.
    pub support_infinite_cycle: bool,
}
169
170impl Parse for FinalizeArgs {
171    fn parse(input: ParseStream) -> Result<Self> {
172        let _crate_identity: LitStr = input.parse()?;
173        let crate_version: LitStr = input.parse()?;
174        let expected_version = env!("CARGO_PKG_VERSION");
175        if crate_version.value() != expected_version {
176            abort!(
177                Span::call_site(),
178                "version mismatch: expected '{}', got '{}'",
179                expected_version,
180                crate_version.value()
181            )
182        }
183
184        let working_list_content;
185        bracketed!(working_list_content in input);
186        let working_list = parse_comma_separated(&working_list_content)?;
187
188        let traits_content;
189        braced!(traits_content in input);
190        let traits = parse_comma_separated(&traits_content)?;
191
192        let contents_content;
193        braced!(contents_content in input);
194        let contents = parse_comma_separated(&contents_content)?;
195
196        let lit: LitInt = input.parse()?;
197        let recurse_level = lit.base10_parse()?;
198
199        let lit: LitBool = input.parse()?;
200        let support_infinite_cycle = lit.value;
201
202        Ok(FinalizeArgs {
203            working_list,
204            traits,
205            contents,
206            recurse_level,
207            support_infinite_cycle,
208        })
209    }
210}
211
212impl template_quote::ToTokens for FinalizeArgs {
213    fn to_tokens(&self, tokens: &mut TokenStream) {
214        let crate_identity = LitStr::new(&crate::get_crate_identity(), Span::call_site());
215        let crate_version = env!("CARGO_PKG_VERSION");
216        let working_list = &self.working_list;
217        let traits = &self.traits;
218        let contents = &self.contents;
219
220        let recurse_level = &self.recurse_level;
221        let support_infinite_cycle = &self.support_infinite_cycle;
222
223        tokens.extend(quote! {
224            #crate_identity
225            #crate_version
226            [ #(#working_list),* ]
227            { #(#traits),* }
228            { #(#contents),* }
229            #recurse_level
230            #support_infinite_cycle
231        });
232    }
233}
234
235fn get_initial_rank(count: usize) -> Type {
236    if count == 0 {
237        parse_quote!(())
238    } else {
239        let inner = get_initial_rank(count - 1);
240        parse_quote!((#inner,))
241    }
242}
243
244trait GenericsScheme {
245    fn insert(&self, index: usize, param: TypeParam) -> Self;
246    fn impl_generics(&self) -> TokenStream;
247    fn ty_generics(&self) -> TokenStream;
248}
249
250impl GenericsScheme for Generics {
251    fn insert(&self, index: usize, param: TypeParam) -> Self {
252        let mut generics = self.clone();
253        generics.params.insert(index, GenericParam::Type(param));
254        generics
255    }
256
257    fn impl_generics(&self) -> TokenStream {
258        let (impl_generics, _, _) = self.split_for_impl();
259        quote!(#impl_generics)
260    }
261
262    fn ty_generics(&self) -> TokenStream {
263        let (_, ty_generics, _) = self.split_for_impl();
264        quote!(#ty_generics)
265    }
266}
267
268impl GenericsScheme for Path {
269    fn insert(&self, index: usize, param: TypeParam) -> Self {
270        let mut path = self.clone();
271        let ty = Type::Path(TypePath {
272            qself: None,
273            path: parse_quote!(#param),
274        });
275        path_insert_type_arg(&mut path, index, ty);
276        path
277    }
278
279    fn impl_generics(&self) -> TokenStream {
280        quote!()
281    }
282
283    fn ty_generics(&self) -> TokenStream {
284        if let Some(last_segment) = self.segments.last() {
285            let args = &last_segment.arguments;
286            quote!(#args)
287        } else {
288            quote!()
289        }
290    }
291}
292
293pub fn finalize(args: FinalizeArgs) -> TokenStream {
294    let random_suffix = crate::get_random();
295    let name =
296        |s: &str| -> Ident { Ident::new(&format!("{}{}", s, &random_suffix), Span::call_site()) };
297
298    // Mapping which maps trait path (with no args) to corresponding impl item
299    let mut traits_impls: HashMap<Path, Vec<_>> = HashMap::new();
300
301    for item_impl in args.contents {
302        let mut trait_path = item_impl.trait_.clone().unwrap().1;
303        if let Some(last_seg) = trait_path.segments.last_mut() {
304            last_seg.arguments = PathArguments::None;
305        }
306        traits_impls.entry(trait_path).or_default().push(item_impl);
307    }
308
309    let replacing_table: HashMap<Ident, (usize, Path)> = args
310        .traits
311        .iter()
312        .map(|trait_| {
313            let ident = &trait_.ident;
314            let g = &trait_.generics;
315            let loc = g
316                .params
317                .iter()
318                .position(|param| !matches!(param, GenericParam::Lifetime(_)))
319                .unwrap_or(g.params.len());
320            let ranked_ident_str = format!("{}Ranked", ident);
321            let ranked_ident = name(ranked_ident_str.as_str());
322            let ranked_path: Path = parse_quote!(#ranked_ident);
323            (ident.clone(), (loc, ranked_path))
324        })
325        .collect();
326
327    let mut output = TokenStream::new();
328    for trait_ in &args.traits {
329        let ident = &trait_.ident;
330        let Some(impls) = traits_impls.get(&parse_quote!(#ident)) else {
331            emit_warning!(ident, "trait '{}' has no implementations", ident);
332            continue;
333        };
334
335        let g = &trait_.generics;
336        let &(loc, ref ranked_path) = replacing_table.get(ident).unwrap();
337        let initial_rank = get_initial_rank(args.recurse_level);
338
339        let make_delegated_items =
340            |ranked_bound: &Path, renamer: &mut crate::GenericRenamer| -> Vec<TokenStream> {
341                trait_
342                    .items
343                    .iter()
344                    .map(|item| match item {
345                        TraitItem::Fn(method) => {
346                            let mut sig = method.sig.clone();
347                            let mut sig_renamer = renamer.clone();
348                            for param in &sig.generics.params {
349                                match param {
350                                    GenericParam::Lifetime(lt) => {
351                                        let lifetime = &lt.lifetime;
352                                        sig_renamer
353                                            .lifetime_renames
354                                            .retain(|(old, _)| old != lifetime);
355                                    }
356                                    GenericParam::Type(tp) => {
357                                        let ident = &tp.ident;
358                                        sig_renamer.ident_renames.retain(|(old, _)| old != ident);
359                                    }
360                                    GenericParam::Const(cp) => {
361                                        let ident = &cp.ident;
362                                        sig_renamer.ident_renames.retain(|(old, _)| old != ident);
363                                    }
364                                }
365                            }
366                            sig_renamer.visit_signature_mut(&mut sig);
367                            let method_ident = &sig.ident;
368                            let call_args: Vec<TokenStream> = sig
369                                .inputs
370                                .iter_mut()
371                                .enumerate()
372                                .map(|(num, arg)| match arg {
373                                    FnArg::Receiver(receiver) => {
374                                        let self_token = &receiver.self_token;
375                                        quote!(#self_token)
376                                    }
377                                    FnArg::Typed(pat_type) => {
378                                        if !matches!(*pat_type.pat, Pat::Ident(_)) {
379                                            *pat_type.pat = Pat::Ident(PatIdent {
380                                                attrs: vec![],
381                                                by_ref: None,
382                                                mutability: None,
383                                                ident: name(format!("param_{}_", num).as_str()),
384                                                subpat: None,
385                                            });
386                                        }
387                                        let pat = &pat_type.pat;
388                                        quote!(#pat)
389                                    }
390                                })
391                                .collect();
392                            quote! {
393                                #sig {
394                                    <Self as #ranked_bound>::#method_ident(#(#call_args),*)
395                                }
396                            }
397                        }
398                        TraitItem::Type(assoc_type) => {
399                            let type_ident = &assoc_type.ident;
400                            let generics = &assoc_type.generics;
401                            quote! {
402                                type #type_ident #generics = <Self as #ranked_bound>::#type_ident;
403                            }
404                        }
405                        TraitItem::Const(assoc_const) => {
406                            let const_ident = &assoc_const.ident;
407                            let ty = &assoc_const.ty;
408                            quote! {
409                                const #const_ident: #ty = <Self as #ranked_bound>::#const_ident;
410                            }
411                        }
412                        _ => quote!(),
413                    })
414                    .collect()
415            };
416
417        output.extend(quote! {
418            #[allow(unused)]
419            #{&trait_.trait_token} #ranked_path #{g.insert(loc, parse_quote!(#{name("Rank")})).ty_generics()}
420            #{trait_.colon_token} #{&trait_.supertraits} {
421                #(for item in &trait_.items) { #item }
422            }
423        });
424
425        for impl_ in impls {
426            let impl_trait_path = impl_.trait_.as_ref().unwrap().1.clone();
427            let make_ranked_path_from_impl = |rank_ty: Type| -> Path {
428                let mut path: Path = parse_quote!(#ranked_path);
429                if let (Some(from), Some(to)) =
430                    (impl_trait_path.segments.last(), path.segments.last_mut())
431                {
432                    to.arguments = from.arguments.clone();
433                }
434                path_insert_type_arg(&mut path, loc, rank_ty);
435                path
436            };
437            let ranked_bound = make_ranked_path_from_impl(initial_rank.clone());
438            let ranked_bound_end = make_ranked_path_from_impl(parse_quote!(()));
439
440            let mut modified_impl = impl_.clone();
441            TraitReplacer(replacing_table.clone(), parse_quote!((#{name("Rank")},)))
442                .visit_path_mut(&mut modified_impl.trait_.as_mut().unwrap().1);
443            TraitReplacer(replacing_table.clone(), parse_quote!(#{name("Rank")}))
444                .visit_item_impl_mut(&mut modified_impl);
445            modified_impl
446                .generics
447                .params
448                .push(parse_quote!(#{name("Rank")}));
449            let ranked_bound_with_rank = make_ranked_path_from_impl(parse_quote!(#{name("Rank")}));
450            modified_impl
451                .generics
452                .where_clause
453                .get_or_insert(WhereClause {
454                    where_token: Default::default(),
455                    predicates: Default::default(),
456                })
457                .predicates
458                .push(parse_quote!(Self: #ranked_bound_with_rank));
459
460            if args.support_infinite_cycle {
461                for (num, item) in modified_impl.items.iter_mut().enumerate() {
462                    if let ImplItem::Fn(ImplItemFn { sig, block, .. }) = item {
463                        let old_block = block.clone();
464                        *block = parse_quote! {
465                            {
466                                let _ = Self::#{name("get_cell")}(#num).set( <Self as #ranked_bound>::#{&sig.ident} as _);
467                                #old_block
468                            }
469                        };
470                    }
471                }
472            }
473
474            let cycle_items: Vec<TokenStream> = impl_
475                .items
476                .iter()
477                .enumerate()
478                .map(|(id, item)| match item {
479                    ImplItem::Fn(method) => {
480                        let mut sig = method.sig.clone();
481                        // ensure that all params are ident
482                        for (num, p) in sig.inputs.iter_mut().enumerate() {
483                            if let FnArg::Typed(PatType { pat, .. }) = p {
484                                if !matches!(pat.as_ref(), Pat::Ident(_)) {
485                                    **pat = Pat::Ident(PatIdent {
486                                        attrs: vec![],
487                                        by_ref: None,
488                                        mutability: None,
489                                        ident: name(format!("param_{}_", num).as_str()),
490                                        subpat: None,
491                                    });
492                                }
493                            }
494                        }
495                        quote! {
496                            #sig {
497                                #(if args.support_infinite_cycle) {
498                                    /// SAFETY:
499                                    #[allow(unused_unsafe)]
500                                    unsafe {
501                                        ::core::mem::transmute::<
502                                            _,
503                                            #{&sig.unsafety} #{&sig.abi}
504                                            fn(
505                                                #(for p in &sig.inputs), {
506                                                    #(if let FnArg::Receiver ( Receiver { ty, .. }) = p) {
507                                                        #ty
508                                                    }
509                                                    #(if let FnArg::Typed ( PatType { ty, .. }) = p) {
510                                                        #ty
511                                                    }
512                                                }
513                                            ) #{&sig.output}
514                                        >(Self::#{name("get_cell")}(#id).get().unwrap())
515                                        (
516                                            #(for p in &sig.inputs), {
517                                                #(if let FnArg::Receiver ( Receiver { self_token, .. }) = p) {
518                                                    #self_token
519                                                }
520                                                #(if let FnArg::Typed ( PatType { pat, .. }) = p) {
521                                                    #pat
522                                                }
523                                            }
524                                        )
525                                    }
526                                }
527                                #(else) {
528                                    ::core::unimplemented!("decycle: cycle limit reached")
529                                }
530                            }
531                        }
532                    }
533                    other => quote!(#other),
534                })
535                .collect();
536
537            let mut impl_generics = impl_.generics.clone();
538            strip_replaced_bounds(&mut impl_generics, &replacing_table);
539            output.extend(quote! {
540                #modified_impl
541
542                #[allow(unused_variables)]
543                impl #{impl_generics.impl_generics()} #ranked_bound_end for #{&impl_.self_ty} #{&impl_generics.where_clause} {
544                    #(#cycle_items)*
545                }
546            });
547            let mut delegated_generics = impl_.generics.clone();
548            strip_replaced_bounds(&mut delegated_generics, &replacing_table);
549            let mut delegated_self_ty = (*impl_.self_ty).clone();
550            let mut delegated_trait_path: Path = parse_quote!(super::#ident);
551            if let Some((_, trait_path, _)) = &impl_.trait_ {
552                if let (Some(from), Some(to)) = (
553                    trait_path.segments.last(),
554                    delegated_trait_path.segments.last_mut(),
555                ) {
556                    to.arguments = from.arguments.clone();
557                }
558            }
559            let mut renamer =
560                crate::randomize_impl_generics(&mut delegated_generics, random_suffix);
561            renamer.visit_type_mut(&mut delegated_self_ty);
562            renamer.visit_path_mut(&mut delegated_trait_path);
563            let mut delegated_ranked_bound = ranked_bound.clone();
564            renamer.visit_path_mut(&mut delegated_ranked_bound);
565            let delegated_items = make_delegated_items(&delegated_ranked_bound, &mut renamer);
566            let ranked_bound_pred: WherePredicate =
567                parse_quote!(#delegated_self_ty: #delegated_ranked_bound);
568            match delegated_generics.where_clause {
569                Some(ref mut where_clause) => where_clause.predicates.push(ranked_bound_pred),
570                None => {
571                    delegated_generics.where_clause = Some(parse_quote!(where #ranked_bound_pred));
572                }
573            }
574            output.extend(quote! {
575                #(for attr in &trait_.attrs) { #attr }
576                impl #{delegated_generics.impl_generics()}
577                #delegated_trait_path for #delegated_self_ty #{&delegated_generics.where_clause}
578                {
579                    #(#delegated_items)*
580                }
581            });
582        }
583    }
584
585    quote! {
586        // this module is to prevent confliction of trait method call between ranked and non-ranked
587        // traits
588        #[doc(hidden)]
589        mod #{name("shadowing_module")} {
590            #[allow(unused)]
591            use super::*;
592
593            // Remove the non-ranked traits from namespace to prevent conflicting
594            #(for ident in replacing_table.keys()) { trait #ident {} }
595
596            #(if args.support_infinite_cycle) {
597                #[allow(unused)]
598                trait #{name("GetVTableKey")} {
599                    extern "C" fn #{name("get_vtable_key")}(&self) {}
600
601                    fn #{name("get_cell")}(id: ::core::primitive::usize) -> &'static ::std::sync::OnceLock<::core::primitive::usize> {
602                        use ::std::sync::{Mutex, OnceLock};
603                        use ::std::collections::HashMap;
604                        use ::std::primitive::*;
605                        static VTABLE_MAP_PARSE: OnceLock<Mutex<HashMap<(usize, usize), OnceLock<usize>>>> = OnceLock::new();
606                        let map = VTABLE_MAP_PARSE.get_or_init(|| Mutex::new(HashMap::new()));
607                        let mut map = map.lock().unwrap();
608                        let r = map.entry((Self::#{name("get_vtable_key")} as usize, id)).or_insert(OnceLock::new());
609                        // SAFETY:
610                        unsafe {
611                            ::core::mem::transmute(r)
612                        }
613                    }
614                }
615
616                impl<T: ?::core::marker::Sized> #{name("GetVTableKey")} for T {}
617            }
618
619            #output
620        }
621    }
622}