Skip to main content

decycle_impl/
finalize.rs

1use proc_macro2::{Span, TokenStream};
2use proc_macro_error::*;
3use std::collections::HashMap;
4use syn::parse::{Parse, ParseStream};
5use syn::punctuated::Punctuated;
6use syn::visit_mut::VisitMut;
7use syn::*;
8use template_quote::quote;
9
/// Crate-local replacement for `syn::parse_quote!` that renders its input
/// with `template_quote::quote!`, so the template syntax used throughout this
/// file (`#{expr}`, `#(for ..)`, `#(if ..)`) is available.
///
/// The target type is inferred at the call site. Panics (via `unwrap`) if the
/// rendered tokens do not parse as that type.
macro_rules! parse_quote {
    ($($template:tt)*) => {{
        let rendered = ::template_quote::quote!($($template)*);
        syn::parse2(rendered).unwrap()
    }};
}
15
16fn parse_comma_separated<T: Parse>(input: ParseStream) -> Result<Vec<T>> {
17    let mut items = Vec::new();
18    while !input.is_empty() {
19        items.push(input.parse()?);
20        if !input.is_empty() {
21            input.parse::<Token![,]>()?;
22        }
23    }
24    Ok(items)
25}
26
27/// Inserts a `Type` as a `GenericArgument::Type` at the given position
28/// in the last segment's arguments of `path`.
29fn path_insert_type_arg(path: &mut Path, index: usize, ty: Type) {
30    let last_seg = path.segments.last_mut().unwrap();
31    let arg = GenericArgument::Type(ty);
32    match &mut last_seg.arguments {
33        PathArguments::None => {
34            let mut args = Punctuated::new();
35            args.insert(index, arg);
36            last_seg.arguments = PathArguments::AngleBracketed(AngleBracketedGenericArguments {
37                colon2_token: None,
38                lt_token: Default::default(),
39                args,
40                gt_token: Default::default(),
41            });
42        }
43        PathArguments::AngleBracketed(ref mut angle_args) => {
44            angle_args.args.insert(index, arg);
45        }
46        PathArguments::Parenthesized(_) => {}
47    }
48}
49
50/// Returns `false` for trait bounds whose path is a single segment present
51/// in `replacing_table`, used to filter out bounds that will be replaced.
52fn should_keep_bound(
53    bound: &TypeParamBound,
54    replacing_table: &HashMap<Ident, (usize, Path)>,
55) -> bool {
56    if let TypeParamBound::Trait(trait_bound) = bound {
57        if trait_bound.path.segments.len() == 1 {
58            return !replacing_table.contains_key(&trait_bound.path.segments[0].ident);
59        }
60    }
61    true
62}
63
64/// Strips bounds matching `replacing_table` from a `Generics`, removing
65/// type param bounds and where-clause predicates whose paths appear as keys.
66fn strip_replaced_bounds(generics: &mut Generics, replacing_table: &HashMap<Ident, (usize, Path)>) {
67    for param in &mut generics.params {
68        if let GenericParam::Type(ref mut type_param) = param {
69            type_param.bounds = type_param
70                .bounds
71                .iter()
72                .filter(|bound| should_keep_bound(bound, replacing_table))
73                .cloned()
74                .collect();
75            if type_param.bounds.is_empty() {
76                type_param.colon_token = None;
77            }
78        }
79    }
80    if let Some(ref mut where_clause) = generics.where_clause {
81        where_clause.predicates = where_clause
82            .predicates
83            .iter()
84            .filter_map(|pred| {
85                if let WherePredicate::Type(type_pred) = pred {
86                    let new_bounds: Punctuated<TypeParamBound, Token![+]> = type_pred
87                        .bounds
88                        .iter()
89                        .filter(|bound| should_keep_bound(bound, replacing_table))
90                        .cloned()
91                        .collect();
92                    if new_bounds.is_empty() {
93                        None
94                    } else {
95                        let mut new_pred = type_pred.clone();
96                        new_pred.bounds = new_bounds;
97                        Some(WherePredicate::Type(new_pred))
98                    }
99                } else {
100                    Some(pred.clone())
101                }
102            })
103            .collect();
104        if where_clause.predicates.is_empty() {
105            generics.where_clause = None;
106        }
107    }
108}
109
110/// Replaces trait paths that have a single segment matching a key in the
111/// HashMap with the corresponding replacement Path, copying the original
112/// PathArguments and inserting the given Type at the stored position.
113struct TraitReplacer(HashMap<Ident, (usize, Path)>, Type);
114
115impl VisitMut for TraitReplacer {
116    fn visit_path_mut(&mut self, path: &mut Path) {
117        if path.segments.len() == 1 {
118            if let Some((index, replacement)) = self.0.get(&path.segments[0].ident) {
119                let orig_args =
120                    std::mem::replace(&mut path.segments[0].arguments, PathArguments::None);
121                let mut new_path = replacement.clone();
122                new_path.segments.last_mut().unwrap().arguments = orig_args;
123                path_insert_type_arg(&mut new_path, *index, self.1.clone());
124                *path = new_path;
125                return;
126            }
127        }
128        syn::visit_mut::visit_path_mut(self, path);
129    }
130}
131
132pub struct FinalizeArgs {
133    pub working_list: Vec<Path>,
134    pub traits: Vec<ItemTrait>,
135    pub contents: Vec<ItemImpl>,
136    pub recurse_level: usize,
137    pub support_infinite_cycle: bool,
138}
139
140impl Parse for FinalizeArgs {
141    fn parse(input: ParseStream) -> Result<Self> {
142        let _crate_identity: LitStr = input.parse()?;
143        let crate_version: LitStr = input.parse()?;
144        let expected_version = env!("CARGO_PKG_VERSION");
145        if crate_version.value() != expected_version {
146            abort!(
147                Span::call_site(),
148                "version mismatch: expected '{}', got '{}'",
149                expected_version,
150                crate_version.value()
151            )
152        }
153
154        let working_list_content;
155        bracketed!(working_list_content in input);
156        let working_list = parse_comma_separated(&working_list_content)?;
157
158        let traits_content;
159        braced!(traits_content in input);
160        let traits = parse_comma_separated(&traits_content)?;
161
162        let contents_content;
163        braced!(contents_content in input);
164        let contents = parse_comma_separated(&contents_content)?;
165
166        let lit: LitInt = input.parse()?;
167        let recurse_level = lit.base10_parse()?;
168
169        let lit: LitBool = input.parse()?;
170        let support_infinite_cycle = lit.value;
171
172        Ok(FinalizeArgs {
173            working_list,
174            traits,
175            contents,
176            recurse_level,
177            support_infinite_cycle,
178        })
179    }
180}
181
182impl template_quote::ToTokens for FinalizeArgs {
183    fn to_tokens(&self, tokens: &mut TokenStream) {
184        let crate_identity = LitStr::new(&crate::get_crate_identity(), Span::call_site());
185        let crate_version = env!("CARGO_PKG_VERSION");
186        let working_list = &self.working_list;
187        let traits = &self.traits;
188        let contents = &self.contents;
189
190        let recurse_level = &self.recurse_level;
191        let support_infinite_cycle = &self.support_infinite_cycle;
192
193        tokens.extend(quote! {
194            #crate_identity
195            #crate_version
196            [ #(#working_list),* ]
197            { #(#traits),* }
198            { #(#contents),* }
199            #recurse_level
200            #support_infinite_cycle
201        });
202    }
203}
204
205fn get_initial_rank(count: usize) -> Type {
206    if count == 0 {
207        parse_quote!(())
208    } else {
209        let inner = get_initial_rank(count - 1);
210        parse_quote!((#inner,))
211    }
212}
213
214trait GenericsScheme {
215    fn insert(&self, index: usize, param: TypeParam) -> Self;
216    fn impl_generics(&self) -> TokenStream;
217    fn ty_generics(&self) -> TokenStream;
218}
219
220impl GenericsScheme for Generics {
221    fn insert(&self, index: usize, param: TypeParam) -> Self {
222        let mut generics = self.clone();
223        generics.params.insert(index, GenericParam::Type(param));
224        generics
225    }
226
227    fn impl_generics(&self) -> TokenStream {
228        let (impl_generics, _, _) = self.split_for_impl();
229        quote!(#impl_generics)
230    }
231
232    fn ty_generics(&self) -> TokenStream {
233        let (_, ty_generics, _) = self.split_for_impl();
234        quote!(#ty_generics)
235    }
236}
237
238impl GenericsScheme for Path {
239    fn insert(&self, index: usize, param: TypeParam) -> Self {
240        let mut path = self.clone();
241        let ty = Type::Path(TypePath {
242            qself: None,
243            path: parse_quote!(#param),
244        });
245        path_insert_type_arg(&mut path, index, ty);
246        path
247    }
248
249    fn impl_generics(&self) -> TokenStream {
250        quote!()
251    }
252
253    fn ty_generics(&self) -> TokenStream {
254        if let Some(last_segment) = self.segments.last() {
255            let args = &last_segment.arguments;
256            quote!(#args)
257        } else {
258            quote!()
259        }
260    }
261}
262
263pub fn finalize(args: FinalizeArgs) -> TokenStream {
264    let random_suffix = crate::get_random();
265    let name =
266        |s: &str| -> Ident { Ident::new(&format!("{}{}", s, &random_suffix), Span::call_site()) };
267
268    // Mapping which maps trait path (with no args) to corresponding impl item
269    let mut traits_impls: HashMap<Path, Vec<_>> = HashMap::new();
270
271    for item_impl in args.contents {
272        let mut trait_path = item_impl.trait_.clone().unwrap().1;
273        if let Some(last_seg) = trait_path.segments.last_mut() {
274            last_seg.arguments = PathArguments::None;
275        }
276        traits_impls.entry(trait_path).or_default().push(item_impl);
277    }
278
279    let replacing_table: HashMap<Ident, (usize, Path)> = args
280        .traits
281        .iter()
282        .map(|trait_| {
283            let ident = &trait_.ident;
284            let g = &trait_.generics;
285            let loc = g
286                .params
287                .iter()
288                .position(|param| !matches!(param, GenericParam::Lifetime(_)))
289                .unwrap_or(g.params.len());
290            let ranked_ident_str = format!("{}Ranked", ident);
291            let ranked_ident = name(ranked_ident_str.as_str());
292            let ranked_path: Path = parse_quote!(#ranked_ident);
293            (ident.clone(), (loc, ranked_path))
294        })
295        .collect();
296
297    let mut output = TokenStream::new();
298    for trait_ in &args.traits {
299        let ident = &trait_.ident;
300        let Some(impls) = traits_impls.get(&parse_quote!(#ident)) else {
301            emit_warning!(ident, "trait '{}' has no implementations", ident);
302            continue;
303        };
304
305        let g = &trait_.generics;
306        let &(loc, ref ranked_path) = replacing_table.get(ident).unwrap();
307        let initial_rank = get_initial_rank(args.recurse_level);
308
309        let make_ranked_path = |rank_ty: Type| -> Path {
310            let mut path: Path = parse_quote!(#ranked_path #{g.ty_generics()});
311            path_insert_type_arg(&mut path, loc, rank_ty);
312            path
313        };
314        let ranked_bound = make_ranked_path(initial_rank.clone());
315        let ranked_bound_end = make_ranked_path(parse_quote!(()));
316
317        let delegated_items: Vec<TokenStream> = trait_
318            .items
319            .iter()
320            .map(|item| match item {
321                TraitItem::Fn(method) => {
322                    let sig = &method.sig;
323                    let method_ident = &sig.ident;
324                    let call_args: Vec<TokenStream> = sig
325                        .inputs
326                        .iter()
327                        .map(|arg| match arg {
328                            FnArg::Receiver(receiver) => {
329                                let self_token = &receiver.self_token;
330                                quote!(#self_token)
331                            }
332                            FnArg::Typed(pat_type) => {
333                                let pat = &pat_type.pat;
334                                quote!(#pat)
335                            }
336                        })
337                        .collect();
338                    quote! {
339                        #sig {
340                            <Self as #ranked_bound>::#method_ident(#(#call_args),*)
341                        }
342                    }
343                }
344                TraitItem::Type(assoc_type) => {
345                    let type_ident = &assoc_type.ident;
346                    let generics = &assoc_type.generics;
347                    quote! {
348                        type #type_ident #generics = <Self as #ranked_bound>::#type_ident;
349                    }
350                }
351                TraitItem::Const(assoc_const) => {
352                    let const_ident = &assoc_const.ident;
353                    let ty = &assoc_const.ty;
354                    quote! {
355                        const #const_ident: #ty = <Self as #ranked_bound>::#const_ident;
356                    }
357                }
358                _ => quote!(),
359            })
360            .collect();
361
362        output.extend(quote! {
363            #{&trait_.trait_token} #ranked_path #{g.insert(loc, parse_quote!(#{name("Rank")})).ty_generics()}
364            #{trait_.colon_token} #{&trait_.supertraits} {
365                #(for item in &trait_.items) { #item }
366            }
367        });
368
369        for impl_ in impls {
370            let mut modified_impl = impl_.clone();
371            TraitReplacer(replacing_table.clone(), parse_quote!((#{name("Rank")},)))
372                .visit_path_mut(&mut modified_impl.trait_.as_mut().unwrap().1);
373            TraitReplacer(replacing_table.clone(), parse_quote!(#{name("Rank")}))
374                .visit_item_impl_mut(&mut modified_impl);
375            modified_impl
376                .generics
377                .params
378                .push(parse_quote!(#{name("Rank")}));
379
380            if args.support_infinite_cycle {
381                for (num, item) in modified_impl.items.iter_mut().enumerate() {
382                    if let ImplItem::Fn(ImplItemFn { sig, block, .. }) = item {
383                        let old_block = block.clone();
384                        *block = parse_quote! {
385                            {
386                                let _ = Self::#{name("get_cell")}(#num).set( <Self as #ranked_bound>::#{&sig.ident} as _);
387                                #old_block
388                            }
389                        };
390                    }
391                }
392            }
393
394            let cycle_items: Vec<TokenStream> = impl_
395                .items
396                .iter()
397                .enumerate()
398                .map(|(id, item)| match item {
399                    ImplItem::Fn(method) => {
400                        let mut sig = method.sig.clone();
401                        // ensure that all params are ident
402                        for (num, p) in sig.inputs.iter_mut().enumerate() {
403                            if let FnArg::Typed(PatType { pat, .. }) = p {
404                                if !matches!(pat.as_ref(), Pat::Ident(_)) {
405                                    **pat = Pat::Ident(PatIdent {
406                                        attrs: vec![],
407                                        by_ref: None,
408                                        mutability: None,
409                                        ident: name(format!("param_{}_", num).as_str()),
410                                        subpat: None,
411                                    });
412                                }
413                            }
414                        }
415                        quote! {
416                            #sig {
417                                #(if args.support_infinite_cycle) {
418                                    /// SAFETY:
419                                    #[allow(unused_unsafe)]
420                                    unsafe {
421                                        ::core::mem::transmute::<
422                                            _,
423                                            #{&sig.unsafety} #{&sig.abi}
424                                            fn(
425                                                #(for p in &sig.inputs), {
426                                                    #(if let FnArg::Receiver ( Receiver { ty, .. }) = p) {
427                                                        #ty
428                                                    }
429                                                    #(if let FnArg::Typed ( PatType { ty, .. }) = p) {
430                                                        #ty
431                                                    }
432                                                }
433                                            ) #{&sig.output}
434                                        >(Self::#{name("get_cell")}(#id).get().unwrap())
435                                        (
436                                            #(for p in &sig.inputs), {
437                                                #(if let FnArg::Receiver ( Receiver { self_token, .. }) = p) {
438                                                    #self_token
439                                                }
440                                                #(if let FnArg::Typed ( PatType { pat, .. }) = p) {
441                                                    #pat
442                                                }
443                                            }
444                                        )
445                                    }
446                                }
447                                #(else) {
448                                    ::core::unimplemented!("decycle: cycle limit reached")
449                                }
450                            }
451                        }
452                    }
453                    other => quote!(#other),
454                })
455                .collect();
456
457            let mut impl_generics = impl_.generics.clone();
458            strip_replaced_bounds(&mut impl_generics, &replacing_table);
459            output.extend(quote! {
460                #modified_impl
461
462                #[allow(unused_variables)]
463                impl #{impl_generics.impl_generics()} #ranked_bound_end for #{&impl_.self_ty} #{&impl_generics.where_clause} {
464                    #(#cycle_items)*
465                }
466            });
467            let mut delegated_generics = impl_.generics.clone();
468            strip_replaced_bounds(&mut delegated_generics, &replacing_table);
469            let ranked_bound_pred: WherePredicate =
470                parse_quote!(#{&impl_.self_ty}: #ranked_bound);
471            match delegated_generics.where_clause {
472                Some(ref mut where_clause) => where_clause.predicates.push(ranked_bound_pred),
473                None => {
474                    delegated_generics.where_clause =
475                        Some(parse_quote!(where #ranked_bound_pred));
476                }
477            }
478            output.extend(quote! {
479                #(for attr in &trait_.attrs) { #attr }
480                impl #{delegated_generics.impl_generics()}
481                super::#ident #{g.ty_generics()} for #{&impl_.self_ty} #{&delegated_generics.where_clause}
482                {
483                    #(#delegated_items)*
484                }
485            });
486        }
487    }
488
489    quote! {
490        // this module is to prevent confliction of trait method call between ranked and non-ranked
491        // traits
492        #[doc(hidden)]
493        mod #{name("shadowing_module")} {
494            use super::*;
495
496            // Remove the non-ranked traits from namespace to prevent conflicting
497            #(for ident in replacing_table.keys()) { trait #ident {} }
498
499            #(if args.support_infinite_cycle) {
500                trait #{name("GetVTableKey")} {
501                    extern "C" fn #{name("get_vtable_key")}(&self) {}
502
503                    fn #{name("get_cell")}(id: ::core::primitive::usize) -> &'static ::std::sync::OnceLock<::core::primitive::usize> {
504                        use ::std::sync::{Mutex, OnceLock};
505                        use ::std::collections::HashMap;
506                        use ::std::primitive::*;
507                        static VTABLE_MAP_PARSE: OnceLock<Mutex<HashMap<(usize, usize), OnceLock<usize>>>> = OnceLock::new();
508                        let map = VTABLE_MAP_PARSE.get_or_init(|| Mutex::new(HashMap::new()));
509                        let mut map = map.lock().unwrap();
510                        let r = map.entry((Self::#{name("get_vtable_key")} as usize, id)).or_insert(OnceLock::new());
511                        // SAFETY:
512                        unsafe {
513                            ::core::mem::transmute(r)
514                        }
515                    }
516                }
517
518                impl<T: ?::core::marker::Sized> #{name("GetVTableKey")} for T {}
519            }
520
521            #output
522        }
523    }
524}