1use proc_macro2::{Span, TokenStream};
2use proc_macro_error::*;
3use std::collections::HashMap;
4use syn::parse::{Parse, ParseStream};
5use syn::punctuated::Punctuated;
6use syn::visit_mut::VisitMut;
7use syn::*;
8use template_quote::quote;
9
/// Local replacement for `syn::parse_quote!` that builds its tokens with
/// `template_quote::quote!`, so the template syntax (`#{expr}`, `#(for ..)`, `#(if ..)`)
/// can be used. The target type is inferred from context; the macro panics if the generated
/// tokens do not parse as that type.
macro_rules! parse_quote {
    ($($template:tt)*) => {{
        let generated = ::template_quote::quote!($($template)*);
        syn::parse2(generated).unwrap()
    }};
}
15
16fn parse_comma_separated<T: Parse>(input: ParseStream) -> Result<Vec<T>> {
17 let mut items = Vec::new();
18 while !input.is_empty() {
19 items.push(input.parse()?);
20 if !input.is_empty() {
21 input.parse::<Token![,]>()?;
22 }
23 }
24 Ok(items)
25}
26
27fn path_insert_type_arg(path: &mut Path, index: usize, ty: Type) {
30 let last_seg = path.segments.last_mut().unwrap();
31 let arg = GenericArgument::Type(ty);
32 match &mut last_seg.arguments {
33 PathArguments::None => {
34 let mut args = Punctuated::new();
35 args.insert(index, arg);
36 last_seg.arguments = PathArguments::AngleBracketed(AngleBracketedGenericArguments {
37 colon2_token: None,
38 lt_token: Default::default(),
39 args,
40 gt_token: Default::default(),
41 });
42 }
43 PathArguments::AngleBracketed(ref mut angle_args) => {
44 angle_args.args.insert(index, arg);
45 }
46 PathArguments::Parenthesized(_) => {}
47 }
48}
49
50fn should_keep_bound(
53 bound: &TypeParamBound,
54 replacing_table: &HashMap<Ident, (usize, Path)>,
55) -> bool {
56 if let TypeParamBound::Trait(trait_bound) = bound {
57 if trait_bound.path.segments.len() == 1 {
58 return !replacing_table.contains_key(&trait_bound.path.segments[0].ident);
59 }
60 }
61 true
62}
63
64fn strip_replaced_bounds(generics: &mut Generics, replacing_table: &HashMap<Ident, (usize, Path)>) {
67 for param in &mut generics.params {
68 if let GenericParam::Type(ref mut type_param) = param {
69 type_param.bounds = type_param
70 .bounds
71 .iter()
72 .filter(|bound| should_keep_bound(bound, replacing_table))
73 .cloned()
74 .collect();
75 if type_param.bounds.is_empty() {
76 type_param.colon_token = None;
77 }
78 }
79 }
80 if let Some(ref mut where_clause) = generics.where_clause {
81 where_clause.predicates = where_clause
82 .predicates
83 .iter()
84 .filter_map(|pred| {
85 if let WherePredicate::Type(type_pred) = pred {
86 let new_bounds: Punctuated<TypeParamBound, Token![+]> = type_pred
87 .bounds
88 .iter()
89 .filter(|bound| should_keep_bound(bound, replacing_table))
90 .cloned()
91 .collect();
92 if new_bounds.is_empty() {
93 None
94 } else {
95 let mut new_pred = type_pred.clone();
96 new_pred.bounds = new_bounds;
97 Some(WherePredicate::Type(new_pred))
98 }
99 } else {
100 Some(pred.clone())
101 }
102 })
103 .collect();
104 if where_clause.predicates.is_empty() {
105 generics.where_clause = None;
106 }
107 }
108}
109
/// AST visitor that rewrites references to the traits being ranked.
///
/// Field 0 maps a trait identifier to `(index, replacement)`. Any single-segment path naming
/// that trait is replaced by `replacement`; the original generic arguments are kept and field 1
/// (the rank type) is inserted as a type argument at `index`.
struct TraitReplacer(HashMap<Ident, (usize, Path)>, Type);
114
115impl VisitMut for TraitReplacer {
116 fn visit_path_mut(&mut self, path: &mut Path) {
117 if path.segments.len() == 1 {
118 if let Some((index, replacement)) = self.0.get(&path.segments[0].ident) {
119 let orig_args =
120 std::mem::replace(&mut path.segments[0].arguments, PathArguments::None);
121 let mut new_path = replacement.clone();
122 new_path.segments.last_mut().unwrap().arguments = orig_args;
123 path_insert_type_arg(&mut new_path, *index, self.1.clone());
124 *path = new_path;
125 return;
126 }
127 }
128 syn::visit_mut::visit_path_mut(self, path);
129 }
130}
131
/// Arguments of a finalize invocation, round-tripped through tokens by the `Parse` and
/// `ToTokens` impls of this type.
pub struct FinalizeArgs {
    /// Paths from the bracketed list. Not read by `finalize` itself (presumably consumed
    /// elsewhere — TODO confirm).
    pub working_list: Vec<Path>,
    /// Trait declarations whose impls get the ranked treatment.
    pub traits: Vec<ItemTrait>,
    /// Trait impls; `finalize` groups them under the trait they implement.
    pub contents: Vec<ItemImpl>,
    /// Depth of the initial rank: the number of 1-tuples wrapped around `()`.
    pub recurse_level: usize,
    /// When true, the rank-`()` impl jumps back to the method at the initial rank through a
    /// recorded function pointer instead of calling `unimplemented!`.
    pub support_infinite_cycle: bool,
}
139
140impl Parse for FinalizeArgs {
141 fn parse(input: ParseStream) -> Result<Self> {
142 let _crate_identity: LitStr = input.parse()?;
143 let crate_version: LitStr = input.parse()?;
144 let expected_version = env!("CARGO_PKG_VERSION");
145 if crate_version.value() != expected_version {
146 abort!(
147 Span::call_site(),
148 "version mismatch: expected '{}', got '{}'",
149 expected_version,
150 crate_version.value()
151 )
152 }
153
154 let working_list_content;
155 bracketed!(working_list_content in input);
156 let working_list = parse_comma_separated(&working_list_content)?;
157
158 let traits_content;
159 braced!(traits_content in input);
160 let traits = parse_comma_separated(&traits_content)?;
161
162 let contents_content;
163 braced!(contents_content in input);
164 let contents = parse_comma_separated(&contents_content)?;
165
166 let lit: LitInt = input.parse()?;
167 let recurse_level = lit.base10_parse()?;
168
169 let lit: LitBool = input.parse()?;
170 let support_infinite_cycle = lit.value;
171
172 Ok(FinalizeArgs {
173 working_list,
174 traits,
175 contents,
176 recurse_level,
177 support_infinite_cycle,
178 })
179 }
180}
181
182impl template_quote::ToTokens for FinalizeArgs {
183 fn to_tokens(&self, tokens: &mut TokenStream) {
184 let crate_identity = LitStr::new(&crate::get_crate_identity(), Span::call_site());
185 let crate_version = env!("CARGO_PKG_VERSION");
186 let working_list = &self.working_list;
187 let traits = &self.traits;
188 let contents = &self.contents;
189
190 let recurse_level = &self.recurse_level;
191 let support_infinite_cycle = &self.support_infinite_cycle;
192
193 tokens.extend(quote! {
194 #crate_identity
195 #crate_version
196 [ #(#working_list),* ]
197 { #(#traits),* }
198 { #(#contents),* }
199 #recurse_level
200 #support_infinite_cycle
201 });
202 }
203}
204
205fn get_initial_rank(count: usize) -> Type {
206 if count == 0 {
207 parse_quote!(())
208 } else {
209 let inner = get_initial_rank(count - 1);
210 parse_quote!((#inner,))
211 }
212}
213
/// Uniform access to generic lists, both at declarations (`Generics`) and at use sites
/// (`Path`). Used when splicing the extra rank parameter into traits and trait references.
trait GenericsScheme {
    /// Returns a copy of `self` with `param` inserted at position `index`.
    fn insert(&self, index: usize, param: TypeParam) -> Self;
    /// Tokens to place after the `impl` keyword (empty for a path).
    fn impl_generics(&self) -> TokenStream;
    /// Tokens to place after a type or trait name when referring to it.
    fn ty_generics(&self) -> TokenStream;
}
219
220impl GenericsScheme for Generics {
221 fn insert(&self, index: usize, param: TypeParam) -> Self {
222 let mut generics = self.clone();
223 generics.params.insert(index, GenericParam::Type(param));
224 generics
225 }
226
227 fn impl_generics(&self) -> TokenStream {
228 let (impl_generics, _, _) = self.split_for_impl();
229 quote!(#impl_generics)
230 }
231
232 fn ty_generics(&self) -> TokenStream {
233 let (_, ty_generics, _) = self.split_for_impl();
234 quote!(#ty_generics)
235 }
236}
237
238impl GenericsScheme for Path {
239 fn insert(&self, index: usize, param: TypeParam) -> Self {
240 let mut path = self.clone();
241 let ty = Type::Path(TypePath {
242 qself: None,
243 path: parse_quote!(#param),
244 });
245 path_insert_type_arg(&mut path, index, ty);
246 path
247 }
248
249 fn impl_generics(&self) -> TokenStream {
250 quote!()
251 }
252
253 fn ty_generics(&self) -> TokenStream {
254 if let Some(last_segment) = self.segments.last() {
255 let args = &last_segment.arguments;
256 quote!(#args)
257 } else {
258 quote!()
259 }
260 }
261}
262
263pub fn finalize(args: FinalizeArgs) -> TokenStream {
264 let random_suffix = crate::get_random();
265 let name =
266 |s: &str| -> Ident { Ident::new(&format!("{}{}", s, &random_suffix), Span::call_site()) };
267
268 let mut traits_impls: HashMap<Path, Vec<_>> = HashMap::new();
270
271 for item_impl in args.contents {
272 let mut trait_path = item_impl.trait_.clone().unwrap().1;
273 if let Some(last_seg) = trait_path.segments.last_mut() {
274 last_seg.arguments = PathArguments::None;
275 }
276 traits_impls.entry(trait_path).or_default().push(item_impl);
277 }
278
279 let replacing_table: HashMap<Ident, (usize, Path)> = args
280 .traits
281 .iter()
282 .map(|trait_| {
283 let ident = &trait_.ident;
284 let g = &trait_.generics;
285 let loc = g
286 .params
287 .iter()
288 .position(|param| !matches!(param, GenericParam::Lifetime(_)))
289 .unwrap_or(g.params.len());
290 let ranked_ident_str = format!("{}Ranked", ident);
291 let ranked_ident = name(ranked_ident_str.as_str());
292 let ranked_path: Path = parse_quote!(#ranked_ident);
293 (ident.clone(), (loc, ranked_path))
294 })
295 .collect();
296
297 let mut output = TokenStream::new();
298 for trait_ in &args.traits {
299 let ident = &trait_.ident;
300 let Some(impls) = traits_impls.get(&parse_quote!(#ident)) else {
301 emit_warning!(ident, "trait '{}' has no implementations", ident);
302 continue;
303 };
304
305 let g = &trait_.generics;
306 let &(loc, ref ranked_path) = replacing_table.get(ident).unwrap();
307 let initial_rank = get_initial_rank(args.recurse_level);
308
309 let make_ranked_path = |rank_ty: Type| -> Path {
310 let mut path: Path = parse_quote!(#ranked_path #{g.ty_generics()});
311 path_insert_type_arg(&mut path, loc, rank_ty);
312 path
313 };
314 let ranked_bound = make_ranked_path(initial_rank.clone());
315 let ranked_bound_end = make_ranked_path(parse_quote!(()));
316
317 let delegated_items: Vec<TokenStream> = trait_
318 .items
319 .iter()
320 .map(|item| match item {
321 TraitItem::Fn(method) => {
322 let sig = &method.sig;
323 let method_ident = &sig.ident;
324 let call_args: Vec<TokenStream> = sig
325 .inputs
326 .iter()
327 .map(|arg| match arg {
328 FnArg::Receiver(receiver) => {
329 let self_token = &receiver.self_token;
330 quote!(#self_token)
331 }
332 FnArg::Typed(pat_type) => {
333 let pat = &pat_type.pat;
334 quote!(#pat)
335 }
336 })
337 .collect();
338 quote! {
339 #sig {
340 <Self as #ranked_bound>::#method_ident(#(#call_args),*)
341 }
342 }
343 }
344 TraitItem::Type(assoc_type) => {
345 let type_ident = &assoc_type.ident;
346 let generics = &assoc_type.generics;
347 quote! {
348 type #type_ident #generics = <Self as #ranked_bound>::#type_ident;
349 }
350 }
351 TraitItem::Const(assoc_const) => {
352 let const_ident = &assoc_const.ident;
353 let ty = &assoc_const.ty;
354 quote! {
355 const #const_ident: #ty = <Self as #ranked_bound>::#const_ident;
356 }
357 }
358 _ => quote!(),
359 })
360 .collect();
361
362 output.extend(quote! {
363 #{&trait_.trait_token} #ranked_path #{g.insert(loc, parse_quote!(#{name("Rank")})).ty_generics()}
364 #{trait_.colon_token} #{&trait_.supertraits} {
365 #(for item in &trait_.items) { #item }
366 }
367 });
368
369 for impl_ in impls {
370 let mut modified_impl = impl_.clone();
371 TraitReplacer(replacing_table.clone(), parse_quote!((#{name("Rank")},)))
372 .visit_path_mut(&mut modified_impl.trait_.as_mut().unwrap().1);
373 TraitReplacer(replacing_table.clone(), parse_quote!(#{name("Rank")}))
374 .visit_item_impl_mut(&mut modified_impl);
375 modified_impl
376 .generics
377 .params
378 .push(parse_quote!(#{name("Rank")}));
379
380 if args.support_infinite_cycle {
381 for (num, item) in modified_impl.items.iter_mut().enumerate() {
382 if let ImplItem::Fn(ImplItemFn { sig, block, .. }) = item {
383 let old_block = block.clone();
384 *block = parse_quote! {
385 {
386 let _ = Self::#{name("get_cell")}(#num).set( <Self as #ranked_bound>::#{&sig.ident} as _);
387 #old_block
388 }
389 };
390 }
391 }
392 }
393
394 let cycle_items: Vec<TokenStream> = impl_
395 .items
396 .iter()
397 .enumerate()
398 .map(|(id, item)| match item {
399 ImplItem::Fn(method) => {
400 let mut sig = method.sig.clone();
401 for (num, p) in sig.inputs.iter_mut().enumerate() {
403 if let FnArg::Typed(PatType { pat, .. }) = p {
404 if !matches!(pat.as_ref(), Pat::Ident(_)) {
405 **pat = Pat::Ident(PatIdent {
406 attrs: vec![],
407 by_ref: None,
408 mutability: None,
409 ident: name(format!("param_{}_", num).as_str()),
410 subpat: None,
411 });
412 }
413 }
414 }
415 quote! {
416 #sig {
417 #(if args.support_infinite_cycle) {
418 #[allow(unused_unsafe)]
420 unsafe {
421 ::core::mem::transmute::<
422 _,
423 #{&sig.unsafety} #{&sig.abi}
424 fn(
425 #(for p in &sig.inputs), {
426 #(if let FnArg::Receiver ( Receiver { ty, .. }) = p) {
427 #ty
428 }
429 #(if let FnArg::Typed ( PatType { ty, .. }) = p) {
430 #ty
431 }
432 }
433 ) #{&sig.output}
434 >(Self::#{name("get_cell")}(#id).get().unwrap())
435 (
436 #(for p in &sig.inputs), {
437 #(if let FnArg::Receiver ( Receiver { self_token, .. }) = p) {
438 #self_token
439 }
440 #(if let FnArg::Typed ( PatType { pat, .. }) = p) {
441 #pat
442 }
443 }
444 )
445 }
446 }
447 #(else) {
448 ::core::unimplemented!("decycle: cycle limit reached")
449 }
450 }
451 }
452 }
453 other => quote!(#other),
454 })
455 .collect();
456
457 let mut impl_generics = impl_.generics.clone();
458 strip_replaced_bounds(&mut impl_generics, &replacing_table);
459 output.extend(quote! {
460 #modified_impl
461
462 #[allow(unused_variables)]
463 impl #{impl_generics.impl_generics()} #ranked_bound_end for #{&impl_.self_ty} #{&impl_generics.where_clause} {
464 #(#cycle_items)*
465 }
466 });
467 let mut delegated_generics = impl_.generics.clone();
468 strip_replaced_bounds(&mut delegated_generics, &replacing_table);
469 let ranked_bound_pred: WherePredicate =
470 parse_quote!(#{&impl_.self_ty}: #ranked_bound);
471 match delegated_generics.where_clause {
472 Some(ref mut where_clause) => where_clause.predicates.push(ranked_bound_pred),
473 None => {
474 delegated_generics.where_clause =
475 Some(parse_quote!(where #ranked_bound_pred));
476 }
477 }
478 output.extend(quote! {
479 #(for attr in &trait_.attrs) { #attr }
480 impl #{delegated_generics.impl_generics()}
481 super::#ident #{g.ty_generics()} for #{&impl_.self_ty} #{&delegated_generics.where_clause}
482 {
483 #(#delegated_items)*
484 }
485 });
486 }
487 }
488
489 quote! {
490 #[doc(hidden)]
493 mod #{name("shadowing_module")} {
494 use super::*;
495
496 #(for ident in replacing_table.keys()) { trait #ident {} }
498
499 #(if args.support_infinite_cycle) {
500 trait #{name("GetVTableKey")} {
501 extern "C" fn #{name("get_vtable_key")}(&self) {}
502
503 fn #{name("get_cell")}(id: ::core::primitive::usize) -> &'static ::std::sync::OnceLock<::core::primitive::usize> {
504 use ::std::sync::{Mutex, OnceLock};
505 use ::std::collections::HashMap;
506 use ::std::primitive::*;
507 static VTABLE_MAP_PARSE: OnceLock<Mutex<HashMap<(usize, usize), OnceLock<usize>>>> = OnceLock::new();
508 let map = VTABLE_MAP_PARSE.get_or_init(|| Mutex::new(HashMap::new()));
509 let mut map = map.lock().unwrap();
510 let r = map.entry((Self::#{name("get_vtable_key")} as usize, id)).or_insert(OnceLock::new());
511 unsafe {
513 ::core::mem::transmute(r)
514 }
515 }
516 }
517
518 impl<T: ?::core::marker::Sized> #{name("GetVTableKey")} for T {}
519 }
520
521 #output
522 }
523 }
524}