cgp_macro_lib/entrypoints/cgp_impl.rs

1use proc_macro2::{Span, TokenStream};
2use quote::{ToTokens, quote};
3use syn::parse::discouraged::Speculative;
4use syn::parse::{Parse, ParseStream};
5use syn::spanned::Spanned;
6use syn::token::{Colon, For};
7use syn::{Error, FnArg, Ident, ImplItem, ItemImpl, Type, parse2};
8
9use crate::derive_provider::{
10    derive_component_name_from_provider_impl, derive_is_provider_for, derive_provider_struct,
11};
12use crate::parse::SimpleType;
13use crate::replace_self::{
14    replace_self_receiver, replace_self_type, replace_self_var, to_snake_case_ident,
15};
16
17pub fn cgp_impl(attr: TokenStream, body: TokenStream) -> syn::Result<TokenStream> {
18    let spec: ImplProviderSpec = parse2(attr)?;
19    let item_impl: ItemImpl = parse2(body)?;
20
21    let consumer_trait_path = &item_impl
22        .trait_
23        .as_ref()
24        .ok_or_else(|| Error::new(item_impl.span(), "expect impl trait to contain path"))?
25        .1;
26
27    let consumer_trait_path: SimpleType = parse2(consumer_trait_path.to_token_stream())?;
28
29    let provider_impl =
30        transform_impl_trait(&item_impl, &consumer_trait_path, &spec.provider_type)?;
31
32    let component_type = match &spec.component_type {
33        Some(component_type) => component_type.clone(),
34        None => derive_component_name_from_provider_impl(&provider_impl)?,
35    };
36
37    let is_provider_for_impl: ItemImpl = derive_is_provider_for(&component_type, &provider_impl)?;
38
39    let provider_struct = if spec.new_struct {
40        Some(derive_provider_struct(&provider_impl)?)
41    } else {
42        None
43    };
44
45    Ok(quote! {
46        #provider_struct
47
48        #provider_impl
49
50        #is_provider_for_impl
51    })
52}
53
54pub struct ImplProviderSpec {
55    pub new_struct: bool,
56    pub provider_type: Type,
57    pub component_type: Option<Type>,
58}
59
60impl Parse for ImplProviderSpec {
61    fn parse(input: ParseStream) -> syn::Result<Self> {
62        let new_struct = {
63            let fork = input.fork();
64            let new_ident: Option<Ident> = fork.parse().ok();
65            match new_ident {
66                Some(new_ident) if new_ident == "new" => {
67                    input.advance_to(&fork);
68                    true
69                }
70                _ => false,
71            }
72        };
73
74        let provider_type = input.parse()?;
75
76        let component_type = if let Some(_colon) = input.parse::<Option<Colon>>()? {
77            let component_type: Type = input.parse()?;
78            Some(component_type)
79        } else {
80            None
81        };
82
83        Ok(ImplProviderSpec {
84            new_struct,
85            provider_type,
86            component_type,
87        })
88    }
89}
90
91pub fn transform_impl_trait(
92    item_impl: &ItemImpl,
93    consumer_trait_path: &SimpleType,
94    provider_type: &Type,
95) -> syn::Result<ItemImpl> {
96    let context_type = item_impl.self_ty.as_ref();
97
98    let context_var = if let Ok(ident) = parse2::<Ident>(context_type.to_token_stream()) {
99        to_snake_case_ident(&ident)
100    } else {
101        Ident::new("__context__", Span::call_site())
102    };
103
104    let local_assoc_types: Vec<Ident> = item_impl
105        .items
106        .iter()
107        .filter_map(|item| {
108            if let ImplItem::Type(assoc_type) = item {
109                Some(assoc_type.ident.clone())
110            } else {
111                None
112            }
113        })
114        .collect();
115
116    let raw_out_impl = replace_self_type(
117        item_impl.to_token_stream(),
118        context_type.to_token_stream(),
119        &local_assoc_types,
120    );
121
122    let mut out_impl: ItemImpl = parse2(raw_out_impl)?;
123    out_impl.self_ty = Box::new(provider_type.clone());
124
125    let mut provider_trait_path: SimpleType = consumer_trait_path.clone();
126
127    match &mut provider_trait_path.generics {
128        Some(generics) => {
129            generics
130                .args
131                .insert(0, parse2(context_type.to_token_stream())?);
132        }
133        None => {
134            provider_trait_path.generics = Some(parse2(quote! { < #context_type > })?);
135        }
136    }
137
138    out_impl.trait_ = Some((
139        None,
140        parse2(provider_trait_path.to_token_stream())?,
141        For(Span::call_site()),
142    ));
143
144    for item in out_impl.items.iter_mut() {
145        if let ImplItem::Fn(item_fn) = item
146            && let Some(arg) = item_fn.sig.inputs.first_mut()
147            && let FnArg::Receiver(receiver) = arg
148        {
149            *arg = replace_self_receiver(receiver, &context_var, context_type.to_token_stream());
150
151            let replaced_block = replace_self_var(item_fn.block.to_token_stream(), &context_var);
152            item_fn.block = parse2(replaced_block)?;
153        }
154    }
155
156    Ok(out_impl)
157}