1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
use crate::AccountsStruct;
use quote::quote;
use std::iter;
use syn::punctuated::Punctuated;
use syn::{ConstParam, LifetimeDef, Token, TypeParam};
use syn::{GenericParam, PredicateLifetime, WhereClause, WherePredicate};

mod __client_accounts;
mod __cpi_client_accounts;
mod constraints;
mod exit;
mod to_account_infos;
mod to_account_metas;
mod try_accounts;

/// Expands an accounts struct into all of its generated code.
///
/// The output is the code from the `try_accounts`, `to_account_infos`,
/// `to_account_metas` and `exit` generators, in that order. It is followed by the
/// modules produced by `__client_accounts` and `__cpi_client_accounts`.
pub fn generate(accs: &AccountsStruct) -> proc_macro2::TokenStream {
    // Implementations generated for the struct itself; the vector preserves
    // generation (and emission) order.
    let trait_impls = vec![
        try_accounts::generate(accs),
        to_account_infos::generate(accs),
        to_account_metas::generate(accs),
        exit::generate(accs),
    ];

    // Companion modules emitted next to the struct.
    let client_mod = __client_accounts::generate(accs);
    let cpi_client_mod = __cpi_client_accounts::generate(accs);

    quote! {
        #(#trait_impls)*
        #client_mod
        #cpi_client_mod
    }
}

fn generics(accs: &AccountsStruct) -> ParsedGenerics {
    let trait_lifetime = accs
        .generics
        .lifetimes()
        .next()
        .cloned()
        .unwrap_or_else(|| syn::parse_str("'info").expect("Could not parse lifetime"));

    let mut where_clause = accs.generics.where_clause.clone().unwrap_or(WhereClause {
        where_token: Default::default(),
        predicates: Default::default(),
    });
    for lifetime in accs.generics.lifetimes().map(|def| &def.lifetime) {
        where_clause
            .predicates
            .push(WherePredicate::Lifetime(PredicateLifetime {
                lifetime: lifetime.clone(),
                colon_token: Default::default(),
                bounds: iter::once(trait_lifetime.lifetime.clone()).collect(),
            }))
    }
    let trait_lifetime = GenericParam::Lifetime(trait_lifetime);

    ParsedGenerics {
        combined_generics: if accs.generics.lifetimes().next().is_some() {
            accs.generics.params.clone()
        } else {
            iter::once(trait_lifetime.clone())
                .chain(accs.generics.params.clone())
                .collect()
        },
        trait_generics: iter::once(trait_lifetime).collect(),
        struct_generics: accs
            .generics
            .params
            .clone()
            .into_iter()
            .map(|param: GenericParam| match param {
                GenericParam::Const(ConstParam { ident, .. })
                | GenericParam::Type(TypeParam { ident, .. }) => GenericParam::Type(TypeParam {
                    attrs: vec![],
                    ident,
                    colon_token: None,
                    bounds: Default::default(),
                    eq_token: None,
                    default: None,
                }),
                GenericParam::Lifetime(LifetimeDef { lifetime, .. }) => {
                    GenericParam::Lifetime(LifetimeDef {
                        attrs: vec![],
                        lifetime,
                        colon_token: None,
                        bounds: Default::default(),
                    })
                }
            })
            .collect(),
        where_clause,
    }
}

/// Comma-separated list of generic parameters, as stored in `ParsedGenerics`.
type GenericParamList = Punctuated<GenericParam, Token![,]>;

/// Generic parameter lists and where clause computed by `generics` for the code
/// generated from an accounts struct.
struct ParsedGenerics {
    /// The struct's generic parameters. The trait lifetime is prepended when the
    /// struct declares no lifetime of its own.
    pub combined_generics: GenericParamList,
    /// A single entry: the trait lifetime.
    pub trait_generics: GenericParamList,
    /// The struct's generic parameters with attributes, bounds and defaults removed.
    pub struct_generics: GenericParamList,
    /// The struct's where clause (or an empty one), plus predicates relating each
    /// struct lifetime to the trait lifetime.
    pub where_clause: WhereClause,
}