1use proc_macro2::{Span, TokenStream};
2use proc_macro_error::*;
3use std::collections::HashMap;
4use syn::parse::{Parse, ParseStream};
5use syn::punctuated::Punctuated;
6use syn::visit_mut::VisitMut;
7use syn::*;
8use template_quote::quote;
9
/// File-local stand-in for `syn::parse_quote!` that builds its tokens with
/// `template_quote::quote!`, so the interpolation forms used in this file
/// (`#ident`, `#{expr}`, ...) work inside it. The target type is inferred from
/// the use site; the expansion `unwrap()`s, so tokens that fail to parse panic.
macro_rules! parse_quote {
    ($($input:tt)*) => {
        ::syn::parse2(::template_quote::quote!($($input)*)).unwrap()
    };
}
15
16
17fn parse_comma_separated<T: Parse>(input: ParseStream) -> Result<Vec<T>> {
18 let mut items = Vec::new();
19 while !input.is_empty() {
20 items.push(input.parse()?);
21 if !input.is_empty() {
22 input.parse::<Token![,]>()?;
23 }
24 }
25 Ok(items)
26}
27
28fn path_insert_type_arg(path: &mut Path, index: usize, ty: Type) {
31 let last_seg = path.segments.last_mut().unwrap();
32 let arg = GenericArgument::Type(ty);
33 match &mut last_seg.arguments {
34 PathArguments::None => {
35 let mut args = Punctuated::new();
36 args.insert(index, arg);
37 last_seg.arguments = PathArguments::AngleBracketed(AngleBracketedGenericArguments {
38 colon2_token: None,
39 lt_token: Default::default(),
40 args,
41 gt_token: Default::default(),
42 });
43 }
44 PathArguments::AngleBracketed(ref mut angle_args) => {
45 angle_args.args.insert(index, arg);
46 }
47 PathArguments::Parenthesized(_) => {}
48 }
49}
50
51fn should_keep_bound(bound: &TypeParamBound, replacing_table: &HashMap<Ident, (usize, Path)>) -> bool {
54 if let TypeParamBound::Trait(trait_bound) = bound {
55 if trait_bound.path.segments.len() == 1 {
56 return !replacing_table.contains_key(&trait_bound.path.segments[0].ident);
57 }
58 }
59 true
60}
61
62fn strip_replaced_bounds(
65 generics: &mut Generics,
66 replacing_table: &HashMap<Ident, (usize, Path)>,
67) {
68 for param in &mut generics.params {
69 if let GenericParam::Type(ref mut type_param) = param {
70 type_param.bounds = type_param
71 .bounds
72 .iter()
73 .filter(|bound| should_keep_bound(bound, replacing_table))
74 .cloned()
75 .collect();
76 if type_param.bounds.is_empty() {
77 type_param.colon_token = None;
78 }
79 }
80 }
81 if let Some(ref mut where_clause) = generics.where_clause {
82 where_clause.predicates = where_clause
83 .predicates
84 .iter()
85 .filter_map(|pred| {
86 if let WherePredicate::Type(type_pred) = pred {
87 let new_bounds: Punctuated<TypeParamBound, Token![+]> = type_pred
88 .bounds
89 .iter()
90 .filter(|bound| should_keep_bound(bound, replacing_table))
91 .cloned()
92 .collect();
93 if new_bounds.is_empty() {
94 None
95 } else {
96 let mut new_pred = type_pred.clone();
97 new_pred.bounds = new_bounds;
98 Some(WherePredicate::Type(new_pred))
99 }
100 } else {
101 Some(pred.clone())
102 }
103 })
104 .collect();
105 if where_clause.predicates.is_empty() {
106 generics.where_clause = None;
107 }
108 }
109}
110
111struct TraitReplacer(HashMap<Ident, (usize, Path)>, Type);
115
116impl VisitMut for TraitReplacer {
117 fn visit_path_mut(&mut self, path: &mut Path) {
118 if path.segments.len() == 1 {
119 if let Some((index, replacement)) = self.0.get(&path.segments[0].ident) {
120 let orig_args =
121 std::mem::replace(&mut path.segments[0].arguments, PathArguments::None);
122 let mut new_path = replacement.clone();
123 new_path.segments.last_mut().unwrap().arguments = orig_args;
124 path_insert_type_arg(&mut new_path, *index, self.1.clone());
125 *path = new_path;
126 return;
127 }
128 }
129 syn::visit_mut::visit_path_mut(self, path);
130 }
131}
132
/// Arguments of the final expansion stage (`finalize`).
///
/// Serialized to tokens by its `ToTokens` impl and read back by its `Parse` impl.
pub struct FinalizeArgs {
    /// Paths carried in the bracketed `[...]` list; `finalize` does not read this field.
    pub working_list: Vec<Path>,
    /// Trait definitions whose references are rewritten to ranked twins.
    pub traits: Vec<ItemTrait>,
    /// Impls to rewrite; `finalize` groups them by trait path.
    pub contents: Vec<ItemImpl>,
    /// Nesting depth of the initial rank type built by `get_initial_rank`.
    pub recurse_level: usize,
    /// When true, generated impls record a fn pointer per method, and the rank-`()` impls
    /// call through it instead of `unimplemented!`.
    pub support_infinite_cycle: bool,
}
140
141impl Parse for FinalizeArgs {
142 fn parse(input: ParseStream) -> Result<Self> {
143 let _crate_identity: LitStr = input.parse()?;
144 let crate_version: LitStr = input.parse()?;
145 let expected_version = env!("CARGO_PKG_VERSION");
146 if crate_version.value() != expected_version {
147 abort!(
148 Span::call_site(),
149 "version mismatch: expected '{}', got '{}'",
150 expected_version,
151 crate_version.value()
152 )
153 }
154
155 let working_list_content;
156 bracketed!(working_list_content in input);
157 let working_list = parse_comma_separated(&working_list_content)?;
158
159 let traits_content;
160 braced!(traits_content in input);
161 let traits = parse_comma_separated(&traits_content)?;
162
163 let contents_content;
164 braced!(contents_content in input);
165 let contents = parse_comma_separated(&contents_content)?;
166
167 let lit: LitInt = input.parse()?;
168 let recurse_level = lit.base10_parse()?;
169
170 let lit: LitBool = input.parse()?;
171 let support_infinite_cycle = lit.value;
172
173 Ok(FinalizeArgs {
174 working_list,
175 traits,
176 contents,
177 recurse_level,
178 support_infinite_cycle,
179 })
180 }
181}
182
183impl template_quote::ToTokens for FinalizeArgs {
184 fn to_tokens(&self, tokens: &mut TokenStream) {
185 let crate_identity = LitStr::new(&crate::get_crate_identity(), Span::call_site());
186 let crate_version = env!("CARGO_PKG_VERSION");
187 let working_list = &self.working_list;
188 let traits = &self.traits;
189 let contents = &self.contents;
190
191 let recurse_level = &self.recurse_level;
192 let support_infinite_cycle = &self.support_infinite_cycle;
193
194 tokens.extend(quote! {
195 #crate_identity
196 #crate_version
197 [ #(#working_list),* ]
198 { #(#traits),* }
199 { #(#contents),* }
200 #recurse_level
201 #support_infinite_cycle
202 });
203 }
204}
205
206fn get_initial_rank(count: usize) -> Type {
207 if count == 0 {
208 parse_quote!(())
209 } else {
210 let inner = get_initial_rank(count - 1);
211 parse_quote!((#inner,))
212 }
213}
214
/// Common interface over declaration-site `Generics` and use-site `Path`s, used to
/// splice an extra type parameter/argument in and render the generics as tokens.
trait GenericsScheme {
    /// Returns a copy with `param` inserted at position `index`.
    fn insert(&self, index: usize, param: TypeParam) -> Self;
    /// Tokens for the `impl<...>` position (empty for a `Path`).
    fn impl_generics(&self) -> TokenStream;
    /// Tokens for the type/argument position, e.g. `<T, U>`.
    fn ty_generics(&self) -> TokenStream;
}
220
221impl GenericsScheme for Generics {
222 fn insert(&self, index: usize, param: TypeParam) -> Self {
223 let mut generics = self.clone();
224 generics.params.insert(index, GenericParam::Type(param));
225 generics
226 }
227
228 fn impl_generics(&self) -> TokenStream {
229 let (impl_generics, _, _) = self.split_for_impl();
230 quote!(#impl_generics)
231 }
232
233 fn ty_generics(&self) -> TokenStream {
234 let (_, ty_generics, _) = self.split_for_impl();
235 quote!(#ty_generics)
236 }
237}
238
239impl GenericsScheme for Path {
240 fn insert(&self, index: usize, param: TypeParam) -> Self {
241 let mut path = self.clone();
242 let ty = Type::Path(TypePath {
243 qself: None,
244 path: parse_quote!(#param),
245 });
246 path_insert_type_arg(&mut path, index, ty);
247 path
248 }
249
250 fn impl_generics(&self) -> TokenStream {
251 quote!()
252 }
253
254 fn ty_generics(&self) -> TokenStream {
255 if let Some(last_segment) = self.segments.last() {
256 let args = &last_segment.arguments;
257 quote!(#args)
258 } else {
259 quote!()
260 }
261 }
262}
263
264pub fn finalize(args: FinalizeArgs) -> TokenStream {
265 let random_suffix = crate::get_random();
266 let name =
267 |s: &str| -> Ident { Ident::new(&format!("{}{}", s, &random_suffix), Span::call_site()) };
268
269 let mut traits_impls: HashMap<Path, Vec<_>> = HashMap::new();
271
272 for item_impl in args.contents {
273 let mut trait_path = item_impl.trait_.clone().unwrap().1;
274 if let Some(last_seg) = trait_path.segments.last_mut() {
275 last_seg.arguments = PathArguments::None;
276 }
277 traits_impls.entry(trait_path).or_default().push(item_impl);
278 }
279
280 let replacing_table: HashMap<Ident, (usize, Path)> = args
281 .traits
282 .iter()
283 .map(|trait_| {
284 let ident = &trait_.ident;
285 let g = &trait_.generics;
286 let loc = g
287 .params
288 .iter()
289 .position(|param| !matches!(param, GenericParam::Lifetime(_)))
290 .unwrap_or(g.params.len());
291 let ranked_ident_str = format!("{}Ranked", ident);
292 let ranked_ident = name(ranked_ident_str.as_str());
293 let ranked_path: Path = parse_quote!(#ranked_ident);
294 (ident.clone(), (loc, ranked_path))
295 })
296 .collect();
297
298 let mut output = TokenStream::new();
299 for trait_ in &args.traits {
300 let ident = &trait_.ident;
301 let Some(impls) = traits_impls.get(&parse_quote!(#ident)) else {
302 emit_warning!(ident, "trait '{}' has no implementations", ident);
303 continue;
304 };
305
306 let g = &trait_.generics;
307 let &(loc, ref ranked_path) = replacing_table.get(ident).unwrap();
308 let initial_rank = get_initial_rank(args.recurse_level);
309
310 let make_ranked_path = |rank_ty: Type| -> Path {
311 let mut path: Path = parse_quote!(#ranked_path #{g.ty_generics()});
312 path_insert_type_arg(&mut path, loc, rank_ty);
313 path
314 };
315 let ranked_bound = make_ranked_path(initial_rank.clone());
316 let ranked_bound_end = make_ranked_path(parse_quote!(()));
317
318 let delegated_items: Vec<TokenStream> = trait_
319 .items
320 .iter()
321 .map(|item| match item {
322 TraitItem::Fn(method) => {
323 let sig = &method.sig;
324 let method_ident = &sig.ident;
325 let call_args: Vec<TokenStream> = sig
326 .inputs
327 .iter()
328 .map(|arg| match arg {
329 FnArg::Receiver(receiver) => {
330 let self_token = &receiver.self_token;
331 quote!(#self_token)
332 }
333 FnArg::Typed(pat_type) => {
334 let pat = &pat_type.pat;
335 quote!(#pat)
336 }
337 })
338 .collect();
339 quote! {
340 #sig {
341 <Self as #ranked_bound>::#method_ident(#(#call_args),*)
342 }
343 }
344 }
345 TraitItem::Type(assoc_type) => {
346 let type_ident = &assoc_type.ident;
347 let generics = &assoc_type.generics;
348 quote! {
349 type #type_ident #generics = <Self as #ranked_bound>::#type_ident;
350 }
351 }
352 TraitItem::Const(assoc_const) => {
353 let const_ident = &assoc_const.ident;
354 let ty = &assoc_const.ty;
355 quote! {
356 const #const_ident: #ty = <Self as #ranked_bound>::#const_ident;
357 }
358 }
359 _ => quote!(),
360 })
361 .collect();
362
363 output.extend(quote! {
364 #{&trait_.trait_token} #ranked_path #{g.insert(loc, parse_quote!(#{name("Rank")})).ty_generics()}
365 #{trait_.colon_token} #{&trait_.supertraits} {
366 #(for item in &trait_.items) { #item }
367 }
368 });
369 output.extend(quote! {
370 #(for attr in &trait_.attrs) { #attr }
371 impl #{g.insert(loc, parse_quote!(
372 #{name("Self")}: #ranked_bound
373 )).impl_generics()}
374 super::#ident #{g.ty_generics()} for #{name("Self")} #{&g.where_clause} {
375 #(#delegated_items)*
376 }
377 });
378
379 for impl_ in impls {
380 let mut modified_impl = impl_.clone();
381 TraitReplacer(replacing_table.clone(), parse_quote!((#{name("Rank")},)))
382 .visit_path_mut(&mut modified_impl.trait_.as_mut().unwrap().1);
383 TraitReplacer(replacing_table.clone(), parse_quote!(#{name("Rank")}))
384 .visit_item_impl_mut(&mut modified_impl);
385 modified_impl
386 .generics
387 .params
388 .push(parse_quote!(#{name("Rank")}));
389
390 if args.support_infinite_cycle {
391 for (num, item) in modified_impl.items.iter_mut().enumerate() {
392 if let ImplItem::Fn(ImplItemFn { sig, block, .. }) = item {
393 let old_block = block.clone();
394 *block = parse_quote! {
395 {
396 let _ = Self::#{name("get_cell")}(#num).set( <Self as #ranked_bound>::#{&sig.ident} as _);
397 #old_block
398 }
399 };
400 }
401 }
402 }
403
404 let cycle_items: Vec<TokenStream> = impl_
405 .items
406 .iter()
407 .enumerate()
408 .map(|(id, item)| match item {
409 ImplItem::Fn(method) => {
410 let mut sig = method.sig.clone();
411 for (num, p) in sig.inputs.iter_mut().enumerate() {
413 if let FnArg::Typed(PatType { pat, .. }) = p {
414 if !matches!(pat.as_ref(), Pat::Ident(_)) {
415 **pat = Pat::Ident(PatIdent {
416 attrs: vec![],
417 by_ref: None,
418 mutability: None,
419 ident: name(format!("param_{}_", num).as_str()),
420 subpat: None,
421 });
422 }
423 }
424 }
425 quote! {
426 #sig {
427 #(if args.support_infinite_cycle) {
428 #[allow(unused_unsafe)]
430 unsafe {
431 ::core::mem::transmute::<
432 _,
433 #{&sig.unsafety} #{&sig.abi}
434 fn(
435 #(for p in &sig.inputs), {
436 #(if let FnArg::Receiver ( Receiver { ty, .. }) = p) {
437 #ty
438 }
439 #(if let FnArg::Typed ( PatType { ty, .. }) = p) {
440 #ty
441 }
442 }
443 ) #{&sig.output}
444 >(Self::#{name("get_cell")}(#id).get().unwrap())
445 (
446 #(for p in &sig.inputs), {
447 #(if let FnArg::Receiver ( Receiver { self_token, .. }) = p) {
448 #self_token
449 }
450 #(if let FnArg::Typed ( PatType { pat, .. }) = p) {
451 #pat
452 }
453 }
454 )
455 }
456 }
457 #(else) {
458 ::core::unimplemented!("decycle: cycle limit reached")
459 }
460 }
461 }
462 }
463 other => quote!(#other),
464 })
465 .collect();
466
467 let mut modified_g = g.clone();
468 strip_replaced_bounds(&mut modified_g, &replacing_table);
469
470 output.extend(quote! {
471 #modified_impl
472
473 #[allow(unused_variables)]
474 impl #{modified_g.impl_generics()} #ranked_bound_end for #{&impl_.self_ty} #{&modified_g.where_clause} {
475 #(#cycle_items)*
476 }
477 });
478 }
479 }
480
481 quote! {
482 #[doc(hidden)]
485 mod #{name("shadowing_module")} {
486 use super::*;
487
488 #(for ident in replacing_table.keys()) { trait #ident {} }
490
491 #(if args.support_infinite_cycle) {
492 trait #{name("GetVTableKey")} {
493 extern "C" fn #{name("get_vtable_key")}(&self) {}
494
495 fn #{name("get_cell")}(id: ::core::primitive::usize) -> &'static ::std::sync::OnceLock<::core::primitive::usize> {
496 use ::std::sync::{Mutex, OnceLock};
497 use ::std::collections::HashMap;
498 use ::std::primitive::*;
499 static VTABLE_MAP_PARSE: OnceLock<Mutex<HashMap<(usize, usize), OnceLock<usize>>>> = OnceLock::new();
500 let map = VTABLE_MAP_PARSE.get_or_init(|| Mutex::new(HashMap::new()));
501 let mut map = map.lock().unwrap();
502 let r = map.entry((Self::#{name("get_vtable_key")} as usize, id)).or_insert(OnceLock::new());
503 unsafe {
505 ::core::mem::transmute(r)
506 }
507 }
508 }
509
510 impl<T: ?::core::marker::Sized> #{name("GetVTableKey")} for T {}
511 }
512
513 #output
514 }
515 }
516}