1#![doc = include_str!("../README.md")]
2
3mod normalize;
4mod substitute;
5
6use normalize::WherePredicateBinding;
7use proc_macro::TokenStream as TokenStream1;
8use proc_macro2::{Span, TokenStream};
9use proc_macro_error::{abort, proc_macro_error};
10use std::collections::{HashMap, HashSet};
11use substitute::{Substitute, SubstituteEnvironment};
12use syn::punctuated::Punctuated;
13use syn::spanned::Spanned;
14use syn::visit_mut::VisitMut;
15use syn::*;
16use template_quote::quote;
17
18fn replace_type_of_trait_item_fn(mut ty: TraitItemFn, from: &Type, to: &Type) -> TraitItemFn {
19 use syn::visit_mut::VisitMut;
20 struct Visitor<'a>(&'a Type, &'a Type);
21 impl<'a> VisitMut for Visitor<'a> {
22 fn visit_type_mut(&mut self, ty: &mut Type) {
23 if &ty == &self.0 {
24 *ty = self.1.clone();
25 }
26 syn::visit_mut::visit_type_mut(self, ty)
27 }
28 }
29 let mut visitor = Visitor(from, to);
30 visitor.visit_trait_item_fn_mut(&mut ty);
31 ty
32}
33
34fn check_defaultness(item_impl: &ItemImpl) -> Option<bool> {
35 let mut ret = false;
36 if item_impl.defaultness.is_some() {
38 return None;
39 }
40 for item in item_impl.items.iter() {
41 match item {
42 ImplItem::Const(item_const) if item_const.defaultness.is_some() => {
43 return None;
44 }
45 ImplItem::Fn(item_method) if item_method.defaultness.is_some() => {
46 ret = true;
47 }
48 ImplItem::Type(item_type) if item_type.defaultness.is_some() => {
49 return None;
50 }
51 _ => (),
52 }
53 }
54 Some(ret)
55}
56
57fn normalize_params_and_predicates(
58 impl_: &ItemImpl,
59) -> (HashSet<GenericParam>, HashSet<WherePredicateBinding>) {
60 let (mut gps, mut wps) = (HashSet::new(), HashSet::new());
61 for gp in impl_.generics.params.iter() {
62 let (gp, nwps) = normalize::normalize_generic_param(gp.clone());
63 gps.insert(gp);
64 wps.extend(nwps);
65 }
66 if let Some(wc) = &impl_.generics.where_clause {
67 for p in wc.predicates.iter() {
68 let nwps = normalize::normalize_where_predicate(p.clone());
69 wps.extend(nwps);
70 }
71 }
72 (gps, wps)
73}
74
75fn get_param_ident(p: GenericParam) -> Option<Ident> {
76 match p {
77 GenericParam::Type(tp) => Some(tp.ident),
78 _ => None,
79 }
80}
81
82fn get_type_ident(ty: Type) -> Option<Ident> {
83 match ty {
84 Type::Path(tp) if tp.qself.is_none() => tp.path.get_ident().cloned(),
85 _ => None,
86 }
87}
88
89fn find_type_ident(ty: &Type, ident: &Ident) -> bool {
90 use syn::visit::Visit;
91 struct Visitor<'a>(&'a Ident, bool);
92 impl<'ast, 'a> Visit<'ast> for Visitor<'a> {
93 fn visit_type(&mut self, i: &'ast Type) {
94 match i {
95 Type::Path(tp) if tp.qself.is_none() && tp.path.get_ident() == Some(&self.0) => {
96 self.1 = true;
97 }
98 _ => {
99 syn::visit::visit_type(self, i);
100 }
101 }
102 }
103 }
104 let mut vis = Visitor(ident, false);
105 vis.visit_type(ty);
106 vis.1
107}
108
109fn get_trivial_substitutions(
110 special_params: &HashSet<Ident>,
111 substitution: &HashMap<Ident, Type>,
112) -> Vec<(Ident, Ident)> {
113 substitution
114 .iter()
115 .filter_map(|(d, s)| {
116 get_type_ident(s.clone())
117 .and_then(|i| special_params.iter().find(|ii| &&i == ii).cloned())
118 .map(|s| (d.clone(), s))
119 })
120 .collect()
121}
122
123fn substitute_impl(
124 default_impl: &ItemImpl,
125 special_impl: &ItemImpl,
126) -> Vec<(HashMap<Ident, Type>, usize)> {
127 let (d_ps, d_ws) = normalize_params_and_predicates(default_impl);
128 let (s_ps, s_ws) = normalize_params_and_predicates(special_impl);
129 let self_ident = Ident::new("Self", Span::call_site());
131 let d_ws = d_ws
132 .into_iter()
133 .map(|w| {
134 w.replace_type_params(
135 core::iter::once((self_ident.clone(), default_impl.self_ty.as_ref().clone()))
136 .collect(),
137 )
138 })
139 .collect::<HashSet<_>>();
140 let s_ws = s_ws
141 .into_iter()
142 .map(|w| {
143 w.replace_type_params(
144 core::iter::once((self_ident.clone(), special_impl.self_ty.as_ref().clone()))
145 .collect(),
146 )
147 })
148 .collect::<HashSet<_>>();
149 let s_ps: HashSet<_> = s_ps.into_iter().filter_map(get_param_ident).collect();
150 let env = SubstituteEnvironment {
151 general_params: d_ps.into_iter().filter_map(get_param_ident).collect(),
152 };
153 let s = env.substitute(&d_ws, &s_ws)
154 * env.substitute(
155 &default_impl.trait_.as_ref().unwrap().1,
156 &special_impl.trait_.as_ref().unwrap().1,
157 )
158 * env.substitute(&*default_impl.self_ty, &*special_impl.self_ty);
159 s.0.into_iter()
161 .filter(|m| {
162 m.iter().all(|(_, ty)| {
163 s_ps.iter().all(|i| {
164 &get_type_ident(ty.clone()).as_ref() == &Some(i) || !find_type_ident(ty, &i)
165 })
166 })
167 })
168 .map(|r| {
169 (
170 r.clone(),
171 r.len() - get_trivial_substitutions(&s_ps, &r).len(),
172 )
173 })
174 .collect()
175}
176
177trait ReplaceTypeParams {
178 fn replace_type_params(self, map: HashMap<Ident, Type>) -> Self;
179}
180
181const _: () = {
182 fn filter_map_with_generics(
183 map: &HashMap<Ident, Type>,
184 generics: &Generics,
185 ) -> HashMap<Ident, Type> {
186 map.clone()
187 .into_iter()
188 .filter(|(k, _)| {
189 generics
190 .params
191 .iter()
192 .filter_map(|o| {
193 if let GenericParam::Type(pt) = o {
194 Some(&pt.ident)
195 } else {
196 None
197 }
198 })
199 .all(|id| k != id)
200 })
201 .collect()
202 }
203 #[derive(Clone)]
204 struct Visitor(HashMap<Ident, Type>);
205 impl VisitMut for Visitor {
206 fn visit_type_mut(&mut self, i: &mut Type) {
207 if let Type::Path(tp) = i {
208 if let Some(id) = tp.path.get_ident() {
209 if let Some(replaced) = self.0.get(id) {
210 *i = replaced.clone();
211 return;
212 }
213 }
214 }
215 syn::visit_mut::visit_type_mut(self, i)
216 }
217 fn visit_item_fn_mut(&mut self, i: &mut ItemFn) {
218 let mut this = Visitor(filter_map_with_generics(&self.0, &i.sig.generics));
219 syn::visit_mut::visit_item_fn_mut(&mut this, i);
220 }
221 fn visit_item_impl_mut(&mut self, i: &mut ItemImpl) {
222 let mut this = Visitor(filter_map_with_generics(&self.0, &i.generics));
223 syn::visit_mut::visit_item_impl_mut(&mut this, i);
224 }
225 fn visit_item_trait_mut(&mut self, i: &mut ItemTrait) {
226 let mut this = Visitor(filter_map_with_generics(&self.0, &i.generics));
227 syn::visit_mut::visit_item_trait_mut(&mut this, i);
228 }
229 fn visit_item_struct_mut(&mut self, i: &mut ItemStruct) {
230 let mut this = Visitor(filter_map_with_generics(&self.0, &i.generics));
231 syn::visit_mut::visit_item_struct_mut(&mut this, i);
232 }
233 fn visit_item_enum_mut(&mut self, i: &mut ItemEnum) {
234 let mut this = Visitor(filter_map_with_generics(&self.0, &i.generics));
235 syn::visit_mut::visit_item_enum_mut(&mut this, i);
236 }
237 fn visit_item_type_mut(&mut self, i: &mut ItemType) {
238 let mut this = Visitor(filter_map_with_generics(&self.0, &i.generics));
239 syn::visit_mut::visit_item_type_mut(&mut this, i);
240 }
241 fn visit_item_union_mut(&mut self, i: &mut ItemUnion) {
242 let mut this = Visitor(filter_map_with_generics(&self.0, &i.generics));
243 syn::visit_mut::visit_item_union_mut(&mut this, i);
244 }
245 }
246
247 impl ReplaceTypeParams for WherePredicateBinding {
248 fn replace_type_params(self, map: HashMap<Ident, Type>) -> Self {
249 match self {
250 WherePredicateBinding::Lifetime(lt) => {
251 WherePredicateBinding::Lifetime(lt.replace_type_params(map))
252 }
253 WherePredicateBinding::Type(pt) => {
254 WherePredicateBinding::Type(pt.replace_type_params(map))
255 }
256 WherePredicateBinding::Eq {
257 lhs_ty,
258 eq_token,
259 rhs_ty,
260 } => WherePredicateBinding::Eq {
261 lhs_ty: lhs_ty.replace_type_params(map.clone()),
262 eq_token,
263 rhs_ty: rhs_ty.replace_type_params(map),
264 },
265 }
266 }
267 }
268 impl ReplaceTypeParams for PredicateType {
269 fn replace_type_params(mut self, map: HashMap<Ident, Type>) -> Self {
270 let mut visitor = Visitor(map);
271 visitor.visit_predicate_type_mut(&mut self);
272 self
273 }
274 }
275 impl ReplaceTypeParams for PredicateLifetime {
276 fn replace_type_params(mut self, map: HashMap<Ident, Type>) -> Self {
277 let mut visitor = Visitor(map);
278 visitor.visit_predicate_lifetime_mut(&mut self);
279 self
280 }
281 }
282 impl ReplaceTypeParams for ImplItemFn {
283 fn replace_type_params(mut self, map: HashMap<Ident, Type>) -> Self {
284 let mut visitor = Visitor(map);
285 visitor.visit_impl_item_fn_mut(&mut self);
286 self
287 }
288 }
289 impl ReplaceTypeParams for Type {
290 fn replace_type_params(mut self, map: HashMap<Ident, Type>) -> Self {
291 let mut visitor = Visitor(map);
292 visitor.visit_type_mut(&mut self);
293 self
294 }
295 }
296};
297
298fn contains_generics_param(param: &GenericParam, ty: &Type) -> bool {
299 use syn::visit::Visit;
300 struct Visitor<'a>(&'a GenericParam, bool);
301 impl<'ast, 'a> Visit<'ast> for Visitor<'a> {
302 fn visit_lifetime(&mut self, i: &Lifetime) {
303 if matches!(&self.0, GenericParam::Lifetime(l) if &l.lifetime == i) {
304 self.1 = true;
305 }
306 }
307 fn visit_type_path(&mut self, i: &TypePath) {
308 if matches!(
309 (&self.0, &i.qself, i.path.get_ident()),
310 (GenericParam::Type(TypeParam {ident, ..}), &None, Some(id)) |
311 (GenericParam::Const(ConstParam {ident, ..}), &None, Some(id))
312 if ident == id
313 ) {
314 self.1 = true;
315 } else {
316 syn::visit::visit_type_path(self, i)
317 }
318 }
319 }
320 let mut visitor = Visitor(param, false);
321 visitor.visit_type(ty);
322 visitor.1
323}
324
325fn specialize_item_fn_trait(
326 impl_: &ItemImpl,
327 ident: &Ident,
328 fn_ident: &Ident,
329 impl_item_fn: &ImplItemFn,
330 needs_sized_bound: bool,
331 self_ty: &Type,
332) -> (TokenStream, Punctuated<GenericParam, Token![,]>) {
333 let trait_path = &impl_.trait_.as_ref().unwrap().1;
334 let impl_generics: Punctuated<_, Token![,]> = impl_
335 .generics
336 .params
337 .iter()
338 .filter(|p| {
339 contains_generics_param(
340 p,
341 &Type::Path(TypePath {
342 qself: None,
343 path: trait_path.clone(),
344 }),
345 ) || contains_generics_param(p, self_ty)
346 })
347 .cloned()
348 .collect();
349 let ty_generics: Punctuated<_, Token![,]> = impl_
350 .generics
351 .params
352 .iter()
353 .filter(|p| {
354 contains_generics_param(
355 p,
356 &Type::Path(TypePath {
357 qself: None,
358 path: trait_path.clone(),
359 }),
360 )
361 })
362 .map(|p| {
363 let mut p = p.clone();
364 match &mut p {
365 GenericParam::Lifetime(p) => {
366 p.attrs = Vec::new();
367 p.colon_token = None;
368 p.bounds = Punctuated::new();
369 }
370 GenericParam::Type(t) => {
371 t.attrs = Vec::new();
372 t.colon_token = None;
373 t.bounds = Punctuated::new();
374 t.eq_token = None;
375 t.default = None;
376 }
377 GenericParam::Const(c) => {
378 c.attrs = Vec::new();
379 c.eq_token = None;
380 c.default = None;
381 }
382 }
383 p
384 })
385 .collect();
386 let mut item_fn = replace_type_of_trait_item_fn(
387 TraitItemFn {
388 attrs: vec![],
389 sig: impl_item_fn.sig.clone(),
390 default: None,
391 semi_token: Some(Default::default()),
392 },
393 &impl_.self_ty,
394 &parse_quote!(Self),
395 );
396 item_fn.sig.ident = fn_ident.clone();
397 let mut impl_item_fn = impl_item_fn.clone();
398 impl_item_fn.defaultness = None;
399 impl_item_fn.sig.ident = fn_ident.clone();
400 let out = quote! {
401 trait #ident<#ty_generics>: #trait_path
402 #(if needs_sized_bound) { + ::core::marker::Sized }
403 {
404 #item_fn
405 }
406 impl<#impl_generics> #ident<#ty_generics> for #self_ty
407 #{&impl_.generics.where_clause}
408 {
409 #impl_item_fn
410 }
411 };
412 (out, ty_generics)
413}
414
415fn set_argument_named(sig: &mut Signature) {
416 for (n, arg) in sig.inputs.iter_mut().enumerate() {
417 if let FnArg::Typed(PatType { pat, .. }) = arg {
418 if let Pat::Wild(_) = &**pat {
419 *pat = Box::new(Pat::Ident(PatIdent {
420 attrs: Vec::new(),
421 by_ref: None,
422 mutability: None,
423 ident: Ident::new(&format!("_min_specialization_v{}", n), pat.span()),
424 subpat: None,
425 }));
426 }
427 }
428 }
429}
430
431fn specialize_item_fn(
432 default_impl: &ItemImpl,
433 mut ifn: ImplItemFn,
434 specials: Vec<(HashMap<Ident, Type>, ItemImpl, ImplItemFn)>,
435 needs_sized_bound: bool,
436) -> ImplItemFn {
437 let itrait_name = Ident::new("__MinSpecialization_InnerTrait", Span::call_site());
438 let ifn_name = Ident::new("__min_specialization__inner_fn", Span::call_site());
439 set_argument_named(&mut ifn.sig);
440 let specials_out = specials
441 .into_iter()
442 .enumerate()
443 .map(|(n, (m, simpl, mut sfn))| {
444 let strait_name = Ident::new(
445 &format!("__MinSpecialization_InnerTrait_{}", n),
446 Span::call_site(),
447 );
448 let sfn_name = Ident::new(
449 &format!("__min_specialization__inner_fn_{}", n),
450 Span::call_site(),
451 );
452 sfn.sig.ident = sfn_name.clone();
453 let mut condition = quote! {true};
454 let mut replacement = HashMap::new();
455 for (lhs, rhs) in m.iter() {
456 if let Some(rhs) = get_type_ident(rhs.clone()) {
457 if simpl
458 .generics
459 .params
460 .iter()
461 .filter_map(|p| {
462 if let GenericParam::Type(p) = p {
463 Some(&p.ident)
464 } else {
465 None
466 }
467 })
468 .any(|p| p == &rhs)
469 {
470 let lhs = Type::Path(TypePath {
471 qself: None,
472 path: Path {
473 leading_colon: None,
474 segments: Some(PathSegment {
475 ident: lhs.clone(),
476 arguments: PathArguments::None,
477 })
478 .into_iter()
479 .collect(),
480 },
481 });
482 replacement.insert(rhs, lhs);
483 continue;
484 }
485 }
486 condition.extend(quote! {
487 && __min_specialization_id::<#lhs> as *const ()
488 == __min_specialization_id::<#rhs> as *const ()
489 });
490 }
491 let sfn = sfn.replace_type_params(replacement.clone());
492 let replaced_self_ty = default_impl.self_ty.clone().replace_type_params(m.clone());
493 let (special_trait_impl, special_trait_params) = specialize_item_fn_trait(
494 default_impl,
495 &strait_name,
496 &sfn_name,
497 &sfn,
498 needs_sized_bound,
499 &replaced_self_ty,
500 );
501 quote! {
502 if #condition {
503 #special_trait_impl
504 __min_specialization_transmute(
505 <#replaced_self_ty as #strait_name<
506 #(for par in &special_trait_params), {
507 #(if let GenericParam::Type(TypeParam{ident, ..}) = par) {
508 #(if let Some(ident) = replacement.get(ident)) {
509 #ident
510 }
511 #(else) {
512 #ident
513 }
514 } #(else) {
515 #par
516 }
517 }
518 >>::#sfn_name(
519 #(for arg in &ifn.sig.inputs), {
520 #(if let FnArg::Receiver(_) = arg) {
521 __min_specialization_transmute(self)
522 }
523 #(if let FnArg::Typed(pt) = arg) {
524 __min_specialization_transmute(#{&pt.pat})
525 }
526 }
527 )
528 )
529 } else
530 }
531 })
532 .collect::<Vec<_>>();
533 let (default_trait_impl, default_trait_params) = specialize_item_fn_trait(
534 default_impl,
535 &itrait_name,
536 &ifn_name,
537 &ifn,
538 needs_sized_bound,
539 &default_impl.self_ty,
540 );
541 let inner = quote! {
542 #(for attr in &ifn.attrs) {#attr}
543 #{&ifn.vis}
544 #{&ifn.sig}
545 {
546 fn __min_specialization_id<T>(input: &T) -> ! {
547 unsafe {
548 let _ = ::core::mem::MaybeUninit::new(
549 ::core::ptr::read_volatile(input as *const _)
550 );
551 }
552 ::core::panic!()
553 }
554 fn __min_specialization_transmute<T, U>(input: T) -> U {
555 ::core::assert_eq!(
556 ::core::mem::size_of::<T>(),
557 ::core::mem::size_of::<U>()
558 );
559 ::core::assert_eq!(
560 ::core::mem::align_of::<T>(),
561 ::core::mem::align_of::<U>()
562 );
563 let mut rhs = ::core::mem::MaybeUninit::new(input);
564 let mut lhs = ::core::mem::MaybeUninit::<U>::uninit();
565 unsafe {
566 let rhs = ::core::mem::transmute::<
567 _, &mut ::core::mem::MaybeUninit<U>
568 >(&mut rhs);
569 ::core::ptr::swap(lhs.as_mut_ptr(), rhs.as_mut_ptr());
570 lhs.assume_init()
571 }
572 }
573 #( #specials_out)*
574 {
575 #default_trait_impl
576 <#{&default_impl.self_ty} as #itrait_name<#default_trait_params>>::#ifn_name(
577 #(for arg in &ifn.sig.inputs),{
578 #(if let FnArg::Receiver(Receiver{self_token, ..}) = arg) {
579 #self_token
580 }
581 #(if let FnArg::Typed(PatType{pat, ..}) = arg) {
582 #pat
583 }
584 }
585 )
586 }
587 }
588 };
589 parse2(inner).unwrap()
590}
591
592fn check_needs_sized_bound(impl_: &ItemImpl) -> bool {
593 impl_
594 .items
595 .iter()
596 .filter_map(|item| {
597 if let ImplItem::Fn(item) = item {
598 Some(item)
599 } else {
600 None
601 }
602 })
603 .any(|item| {
604 item.sig
605 .inputs
606 .iter()
607 .filter_map(|item| {
608 if let FnArg::Typed(PatType { ty, .. }) = item {
609 Some(&*ty)
610 } else {
611 None
612 }
613 })
614 .chain(if let ReturnType::Type(_, ty) = &item.sig.output {
615 Some(&*ty)
616 } else {
617 None
618 })
619 .any(|ty| ty == &impl_.self_ty || ty == &parse_quote!(Self))
620 })
621}
622
623fn specialize_impl(
624 mut default_impl: ItemImpl,
625 special_impls: Vec<(ItemImpl, HashMap<Ident, Type>)>,
626) -> ItemImpl {
627 if special_impls.len() == 0 {
628 return default_impl;
629 }
630 let needs_sized_bound = check_needs_sized_bound(&default_impl);
631 let mut fn_map = HashMap::new();
632 for (simpl, ssub) in special_impls.into_iter() {
633 for item in simpl.items.iter() {
634 match item {
635 ImplItem::Fn(ifn) => {
636 fn_map
637 .entry(ifn.sig.ident.clone())
638 .or_insert(Vec::new())
639 .push((ssub.clone(), simpl.clone(), ifn.clone()));
640 }
641 o => abort!(o.span(), "This item cannot be specialized"),
642 }
643 }
644 }
645 let mut out = Vec::new();
646 for item in &default_impl.items {
647 match item {
648 ImplItem::Fn(ifn) => {
649 let specials = fn_map.get(&ifn.sig.ident).cloned().unwrap_or(Vec::new());
650 out.push(ImplItem::Fn(specialize_item_fn(
651 &default_impl,
652 ifn.clone(),
653 specials,
654 needs_sized_bound,
655 )));
656 }
657 o => out.push(o.clone()),
658 }
659 }
660 default_impl.items = out;
661 default_impl
662}
663
664fn specialize_trait(
665 default_impls: HashSet<ItemImpl>,
666 special_impls: HashSet<ItemImpl>,
667) -> (Vec<ItemImpl>, Vec<ItemImpl>) {
668 let mut default_map: HashMap<_, _> = default_impls
669 .iter()
670 .cloned()
671 .map(|d| (d, Vec::new()))
672 .collect();
673 let mut orphan_impls = Vec::new();
674 for s in special_impls.into_iter() {
675 if let Some((d, a, _)) = default_impls
676 .iter()
677 .map(|d| {
678 substitute_impl(d, &s)
679 .into_iter()
680 .map(move |(sub, n)| (d, sub, n))
681 })
682 .flatten()
683 .min_by_key(|(_, _, n)| *n)
684 {
685 default_map
686 .entry(d.clone())
687 .or_insert_with(|| unreachable!())
688 .push((s, a));
689 } else {
690 orphan_impls.push(s);
691 }
692 }
693 (
694 default_map
695 .into_iter()
696 .map(|(d, s)| specialize_impl(d, s))
697 .collect(),
698 orphan_impls,
699 )
700}
701
702fn specialization_mod(module: ItemMod) -> TokenStream {
703 let (_, content) = if let Some(inner) = module.content {
704 inner
705 } else {
706 abort!(module.span(), "Require mod content")
707 };
708 let (mut defaults, mut specials): (HashSet<_>, HashSet<_>) = Default::default();
709 let mut generated_content = Vec::new();
710 for item in content.into_iter() {
711 if let Item::Impl(item_impl) = &item {
712 if item_impl.trait_.is_some() {
713 if let Some(defaultness) = check_defaultness(&item_impl) {
714 if defaultness {
715 defaults.insert(item_impl.clone());
716 } else {
717 specials.insert(item_impl.clone());
718 }
719 continue;
720 }
721 }
722 }
723 generated_content.push(item);
724 }
725 let (impls, orphans) = specialize_trait(defaults, specials);
726 generated_content.extend(impls.into_iter().map(Item::Impl));
727 generated_content.extend(orphans.into_iter().map(Item::Impl));
728
729 quote! {
730 #(for attr in &module.attrs) { #attr }
731 #{&module.vis}
732 #{&module.mod_token}
733 #{&module.ident}
734 {
735 #(#generated_content)*
736 }
737 }
738}
739
740#[proc_macro_error]
741#[proc_macro_attribute]
742pub fn specialization(_attr: TokenStream1, input: TokenStream1) -> TokenStream1 {
743 let module = parse_macro_input!(input);
744 specialization_mod(module).into()
745}