1#![doc(test(
174 no_crate_inject,
175 attr(allow(
176 dead_code,
177 unused_variables,
178 unreachable_pub,
179 clippy::undocumented_unsafe_blocks,
180 clippy::unused_trait_names,
181 ))
182))]
183#![forbid(unsafe_code)]
184
185#[allow(unused_extern_crates)]
187extern crate proc_macro;
188
189#[macro_use]
190mod error;
191
192mod ast;
193mod iter;
194mod to_tokens;
195
196use std::{collections::hash_map::DefaultHasher, hash::Hasher, iter::FromIterator, mem};
197
198use proc_macro::{Delimiter, Group, Ident, Punct, Spacing, Span, TokenStream, TokenTree};
199
200use self::{
201 ast::{
202 Attribute, AttributeKind, FnArg, GenericParam, Generics, ImplItem, ItemImpl, ItemTrait,
203 PredicateType, Signature, TraitItem, TraitItemConst, TraitItemMethod, TraitItemType,
204 TypeParam, Visibility, WherePredicate, parsing,
205 },
206 error::{Error, Result},
207 iter::TokenIter,
208 to_tokens::ToTokens,
209};
210
211#[proc_macro_attribute]
217pub fn ext(args: TokenStream, input: TokenStream) -> TokenStream {
218 expand(args, input).unwrap_or_else(Error::into_compile_error)
219}
220
221fn expand(args: TokenStream, input: TokenStream) -> Result<TokenStream> {
222 let trait_name = match parse_args(args)? {
223 None => Ident::new(&format!("__ExtTrait{}", hash(&input)), Span::call_site()),
224 Some(trait_name) => trait_name,
225 };
226
227 let mut item: ItemImpl = parsing::parse_impl(&mut TokenIter::new(input))?;
228
229 let mut tokens = trait_from_impl(&mut item, trait_name)?.to_token_stream();
230 item.to_tokens(&mut tokens);
231 Ok(tokens)
232}
233
234fn parse_args(input: TokenStream) -> Result<Option<Ident>> {
235 let input = &mut TokenIter::new(input);
236 let vis = ast::parsing::parse_visibility(input)?;
237 if !vis.is_inherited() {
238 bail!(vis, "use `{} impl` instead", vis);
239 }
240 let trait_name = input.parse_ident_opt();
241 if !input.is_empty() {
242 let tt = input.next().unwrap();
243 bail!(tt, "unexpected token: `{}`", tt);
244 }
245 Ok(trait_name)
246}
247
248fn determine_trait_generics<'a>(
249 generics: &mut Generics,
250 self_ty: &'a [TokenTree],
251) -> Option<&'a Ident> {
252 if self_ty.len() != 1 {
253 return None;
254 }
255 if let TokenTree::Ident(self_ty) = &self_ty[0] {
256 let i = generics.params.iter().position(|(param, _)| {
257 if let GenericParam::Type(param) = param {
258 param.ident.to_string() == self_ty.to_string()
259 } else {
260 false
261 }
262 });
263 if let Some(i) = i {
264 let mut params = mem::replace(&mut generics.params, vec![]);
265 let (param, _) = params.remove(i);
266 generics.params = params;
267
268 if let GenericParam::Type(TypeParam {
269 colon_token: Some(colon_token), bounds, ..
270 }) = param
271 {
272 let bounds = bounds.into_iter().filter(|(b, _)| !b.is_maybe).collect::<Vec<_>>();
273 if !bounds.is_empty() {
274 let where_clause = generics.make_where_clause();
275 if let Some((_, p)) = where_clause.predicates.last_mut() {
276 p.get_or_insert_with(|| Punct::new(',', Spacing::Alone));
277 }
278 where_clause.predicates.push((
279 WherePredicate::Type(PredicateType {
280 lifetimes: None,
281 bounded_ty: std::iter::once(TokenTree::Ident(Ident::new(
282 "Self",
283 self_ty.span(),
284 )))
285 .collect(),
286 colon_token,
287 bounds,
288 }),
289 None,
290 ));
291 }
292 }
293
294 return Some(self_ty);
295 }
296 }
297 None
298}
299
300fn trait_from_impl(item: &mut ItemImpl, trait_name: Ident) -> Result<ItemTrait> {
301 struct ReplaceParam {
303 self_ty: String,
304 remove_maybe: bool,
308 }
309
310 impl ReplaceParam {
311 fn visit_token_stream(&self, tokens: &mut TokenStream) -> bool {
312 let mut out: Vec<TokenTree> = vec![];
313 let mut modified = false;
314 let iter = tokens.clone().into_iter();
315 for tt in iter {
316 match tt {
317 TokenTree::Ident(ident) => {
318 if ident.to_string() == self.self_ty {
319 modified = true;
320 let self_ = Ident::new("Self", ident.span());
321 out.push(self_.into());
322 } else {
323 out.push(TokenTree::Ident(ident));
324 }
325 }
326 TokenTree::Group(group) => {
327 let mut content = group.stream();
328 modified |= self.visit_token_stream(&mut content);
329 let mut new = Group::new(group.delimiter(), content);
330 new.set_span(group.span());
331 out.push(TokenTree::Group(new));
332 }
333 other => out.push(other),
334 }
335 }
336 if modified {
337 *tokens = TokenStream::from_iter(out);
338 }
339 modified
340 }
341
342 fn visit_trait_item_mut(&self, node: &mut TraitItem) {
345 match node {
346 TraitItem::Const(node) => {
347 self.visit_token_stream(&mut node.ty);
348 }
349 TraitItem::Method(node) => {
350 self.visit_signature_mut(&mut node.sig);
351 }
352 TraitItem::Type(node) => {
353 self.visit_generics_mut(&mut node.generics);
354 }
355 }
356 }
357
358 fn visit_signature_mut(&self, node: &mut Signature) {
359 self.visit_generics_mut(&mut node.generics);
360 for arg in &mut node.inputs {
361 self.visit_fn_arg_mut(arg);
362 }
363 if let Some(ty) = &mut node.output {
364 self.visit_token_stream(ty);
365 }
366 }
367
368 fn visit_fn_arg_mut(&self, node: &mut FnArg) {
369 match node {
370 FnArg::Receiver(pat, _) => {
371 self.visit_token_stream(pat);
372 }
373 FnArg::Typed(pat, _, ty, _) => {
374 self.visit_token_stream(pat);
375 self.visit_token_stream(ty);
376 }
377 }
378 }
379
380 fn visit_generics_mut(&self, generics: &mut Generics) {
381 for (param, _) in &mut generics.params {
382 match param {
383 GenericParam::Type(param) => {
384 for (bound, _) in &mut param.bounds {
385 self.visit_token_stream(&mut bound.tokens);
386 }
387 }
388 GenericParam::Const(_) | GenericParam::Lifetime(_) => {}
389 }
390 }
391 if let Some(where_clause) = &mut generics.where_clause {
392 let predicates = Vec::with_capacity(where_clause.predicates.len());
393 for (mut predicate, p) in mem::replace(&mut where_clause.predicates, predicates) {
394 match &mut predicate {
395 WherePredicate::Type(pred) => {
396 if self.remove_maybe {
397 let mut iter = pred.bounded_ty.clone().into_iter();
398 if let Some(TokenTree::Ident(i)) = iter.next() {
399 if iter.next().is_none() && self.self_ty == i.to_string() {
400 let bounds = mem::replace(&mut pred.bounds, vec![])
401 .into_iter()
402 .filter(|(b, _)| !b.is_maybe)
403 .collect::<Vec<_>>();
404 if !bounds.is_empty() {
405 self.visit_token_stream(&mut pred.bounded_ty);
406 pred.bounds = bounds;
407 for (bound, _) in &mut pred.bounds {
408 self.visit_token_stream(&mut bound.tokens);
409 }
410 where_clause.predicates.push((predicate, p));
411 }
412 continue;
413 }
414 }
415 }
416
417 self.visit_token_stream(&mut pred.bounded_ty);
418 for (bound, _) in &mut pred.bounds {
419 self.visit_token_stream(&mut bound.tokens);
420 }
421 }
422 WherePredicate::Lifetime(_) => {}
423 }
424 where_clause.predicates.push((predicate, p));
425 }
426 }
427 }
428 }
429
430 let mut generics = item.generics.clone();
431 let mut visitor = determine_trait_generics(&mut generics, &item.self_ty)
432 .map(|self_ty| ReplaceParam { self_ty: self_ty.to_string(), remove_maybe: false });
433
434 if let Some(visitor) = &mut visitor {
435 visitor.remove_maybe = true;
436 visitor.visit_generics_mut(&mut generics);
437 visitor.remove_maybe = false;
438 }
439 let ty_generics = generics.ty_generics();
440 item.trait_ = Some((
441 trait_name.clone(),
442 ty_generics.to_token_stream(),
443 Ident::new("for", Span::call_site()),
444 ));
445
446 let impl_vis = if item.vis.is_inherited() { None } else { Some(item.vis.clone()) };
448 let mut assoc_vis = None;
450 let mut items = Vec::with_capacity(item.items.len());
451 item.items.iter_mut().try_for_each(|item| {
452 trait_item_from_impl_item(item, &mut assoc_vis, impl_vis.as_ref()).map(|mut item| {
453 if let Some(visitor) = &mut visitor {
454 visitor.visit_trait_item_mut(&mut item);
455 }
456 items.push(item);
457 })
458 })?;
459
460 let mut attrs = item.attrs.clone();
461 find_remove(&mut item.attrs, AttributeKind::Doc); attrs.push(Attribute::new(vec![
463 TokenTree::Ident(Ident::new("allow", Span::call_site())),
464 TokenTree::Group(Group::new(
465 Delimiter::Parenthesis,
466 std::iter::once(TokenTree::Ident(Ident::new(
467 "patterns_in_fns_without_body",
468 Span::call_site(),
469 )))
470 .collect(),
471 )),
472 ])); Ok(ItemTrait {
475 attrs,
476 vis: impl_vis.unwrap_or_else(|| assoc_vis.unwrap_or(Visibility::Inherited)),
478 unsafety: item.unsafety.clone(),
479 trait_token: Ident::new("trait", item.impl_token.span()),
480 ident: trait_name,
481 generics,
482 brace_token: item.brace_token,
483 items,
484 })
485}
486
487fn trait_item_from_impl_item(
488 impl_item: &mut ImplItem,
489 prev_vis: &mut Option<Visibility>,
490 impl_vis: Option<&Visibility>,
491) -> Result<TraitItem> {
492 fn check_visibility(
493 current: Visibility,
494 prev: &mut Option<Visibility>,
495 impl_vis: Option<&Visibility>,
496 span: &dyn ToTokens,
497 ) -> Result<()> {
498 if impl_vis.is_some() {
499 if current.is_inherited() {
500 return Ok(());
501 }
502 bail!(current, "all associated items must have inherited visibility");
503 }
504 match prev {
505 None => *prev = Some(current),
506 Some(prev) if *prev == current => {}
507 Some(prev) => {
508 if prev.is_inherited() {
509 bail!(current, "all associated items must have inherited visibility");
510 }
511 bail!(
512 if current.is_inherited() { span } else { ¤t },
513 "all associated items must have a visibility of `{}`",
514 prev,
515 );
516 }
517 }
518 Ok(())
519 }
520
521 match impl_item {
522 ImplItem::Const(impl_const) => {
523 let vis = mem::replace(&mut impl_const.vis, Visibility::Inherited);
524 check_visibility(vis, prev_vis, impl_vis, &impl_const.ident)?;
525
526 let attrs = impl_const.attrs.clone();
527 find_remove(&mut impl_const.attrs, AttributeKind::Doc); Ok(TraitItem::Const(TraitItemConst {
529 attrs,
530 const_token: impl_const.const_token.clone(),
531 ident: impl_const.ident.clone(),
532 colon_token: impl_const.colon_token.clone(),
533 ty: impl_const.ty.clone(),
534 semi_token: impl_const.semi_token.clone(),
535 }))
536 }
537 ImplItem::Type(impl_type) => {
538 let vis = mem::replace(&mut impl_type.vis, Visibility::Inherited);
539 check_visibility(vis, prev_vis, impl_vis, &impl_type.ident)?;
540
541 let attrs = impl_type.attrs.clone();
542 find_remove(&mut impl_type.attrs, AttributeKind::Doc); Ok(TraitItem::Type(TraitItemType {
544 attrs,
545 type_token: impl_type.type_token.clone(),
546 ident: impl_type.ident.clone(),
547 generics: impl_type.generics.clone(),
548 semi_token: impl_type.semi_token.clone(),
549 }))
550 }
551 ImplItem::Method(impl_method) => {
552 let vis = mem::replace(&mut impl_method.vis, Visibility::Inherited);
553 check_visibility(vis, prev_vis, impl_vis, &impl_method.sig.ident)?;
554
555 let mut attrs = impl_method.attrs.clone();
556 find_remove(&mut impl_method.attrs, AttributeKind::Doc); find_remove(&mut attrs, AttributeKind::Inline); Ok(TraitItem::Method(TraitItemMethod {
559 attrs,
560 sig: {
561 let mut sig = impl_method.sig.clone();
562 for arg in &mut sig.inputs {
563 if let FnArg::Typed(pat, ..) = arg {
564 if pat.to_string() != "self" {
565 *pat = std::iter::once(TokenTree::Ident(Ident::new(
566 "_",
567 pat.clone().into_iter().next().unwrap().span(),
568 )))
569 .collect();
570 }
571 }
572 }
573 sig
574 },
575 semi_token: {
576 let mut punct = Punct::new(';', Spacing::Alone);
577 punct.set_span(impl_method.body.span());
578 punct
579 },
580 }))
581 }
582 }
583}
584
585fn find_remove(attrs: &mut Vec<Attribute>, kind: AttributeKind) {
586 while let Some(i) = attrs.iter().position(|attr| attr.kind == kind) {
587 attrs.remove(i);
588 }
589}
590
/// Hashes the textual form of `input`; `expand` uses the result to build the
/// default trait name (`__ExtTrait<hash>`) when `#[ext]` is given no name.
///
/// NOTE(review): `DefaultHasher`'s algorithm is unspecified and may change
/// between Rust releases, so the generated name is only stable per toolchain.
fn hash(input: &TokenStream) -> u64 {
    let source_text = input.to_string();
    let mut state = DefaultHasher::new();
    state.write(source_text.as_bytes());
    state.finish()
}