1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
extern crate proc_macro;
use proc_macro::TokenStream;
use proc_macro2::{Ident, Span};

use quote::quote;
use syn::punctuated::Punctuated;
use syn::token::Where;
use syn::visit_mut::VisitMut;
use syn::{parse_macro_input, ImplItem, ItemImpl, Result, WhereClause};

mod operation_syntax;
use operation_syntax::*;

mod closure_definition_syntax;
use closure_definition_syntax::*;

mod search_replace;
use search_replace::*;

/// Expands one annotated `impl` block into one specialised copy per operation.
///
/// `attr` is parsed as an `OperationSequence` and `item` as a `syn::ItemImpl`.
/// For every operation in the sequence the impl block is cloned and then:
///
/// 1. if the operation carries a `trait_impl` and the block is already a trait
///    impl, the last segment of the trait path is renamed to it;
/// 2. the operation's bound, if any, is applied: a `Type` bound is appended to
///    the `where` clause, an `Eq` bound substitutes the left-hand type with
///    the right-hand one throughout the block and its closure-definition
///    attributes;
/// 3. the placeholder types `type0`, `type1`, ... are replaced by the
///    operation's types;
/// 4. every method is renamed after the operation and the method calls
///    `placeholder`, `unchecked` and `unchecked_` inside its body are rewritten.
///
/// The returned token stream is the concatenation of all generated impl blocks;
/// the original block itself is not emitted.
#[proc_macro_attribute]
pub fn expand_operations(attr: TokenStream, item: TokenStream) -> TokenStream {
    // Parse input and impl block.
    let operation_sequence = parse_macro_input!(attr as OperationSequence);
    let item = parse_macro_input!(item as ItemImpl);

    let mut impl_blocks = Vec::new();

    // Apply changes corresponding to each operation in the sequence.
    for operation in operation_sequence.iter() {
        // Generated names use the alias when one is given, the operation's
        // identifier otherwise.
        let operation_name = operation
            .alias
            .as_ref()
            .unwrap_or(&operation.ident)
            .to_string();
        let mut impl_block = item.clone();

        // If the operation names a trait and the block is already a trait impl,
        // point the impl at that trait by renaming the last segment of its path.
        // Inherent impls are left as they are.
        if let (Some(trait_impl), Some((_, trait_path, _))) =
            (&operation.trait_impl, &mut impl_block.trait_)
        {
            if let Some(last_segment) = trait_path.segments.last_mut() {
                last_segment.ident = trait_impl.clone();
            }
        }

        // Visitor that replaces calls to `placeholder` by the operation's
        // identifier (not its alias).
        let mut placeholder_method_visitor = FindReplaceExprMethodCall {
            find: Ident::new("placeholder", Span::call_site()),
            replace: operation.ident.clone(),
        };

        // Visitors that replace calls to `unchecked` / `unchecked_` by the
        // operation's corresponding unchecked methods.
        let mut unchecked_method_visitor = FindReplaceExprMethodCall {
            find: Ident::new("unchecked", Span::call_site()),
            replace: Ident::new(&format!("{}_unchecked", operation_name), Span::call_site()),
        };
        let mut unchecked_underscore_method_visitor = FindReplaceExprMethodCall {
            find: Ident::new("unchecked_", Span::call_site()),
            replace: Ident::new(&format!("{}_unchecked_", operation_name), Span::call_site()),
        };

        // Add trait bound if specified by the operation.
        // Marginalize generic if an equality bound is specified instead.
        if let Some(bound) = &operation.bound {
            match bound {
                OperationBound::Type(trait_bound) => {
                    push_where_predicate(&mut impl_block.generics, trait_bound.clone());
                }
                OperationBound::Eq(margin) => {
                    let mut generics_visitor = RemoveGenerics {
                        find: margin.lhs_ty.clone(),
                    };
                    let mut type_visitor = FindReplaceType {
                        find: margin.lhs_ty.clone(),
                        replace: margin.rhs_ty.clone(),
                    };
                    generics_visitor.visit_item_impl_mut(&mut impl_block);
                    type_visitor.visit_item_impl_mut(&mut impl_block);

                    // Apply the same substitution inside the closures carried by
                    // the block's closure-definition attributes.
                    let rhs_ty = &margin.rhs_ty;
                    for attribute in impl_block.attrs.iter_mut() {
                        let parsed_attribute: Result<ClosureDefinition> = attribute.parse_args();

                        // Attributes that do not parse as a closure definition
                        // are deliberately left untouched.
                        if let Ok(def) = parsed_attribute {
                            let fn_name = def.fn_name;
                            let mut closure = def.closure;
                            // NOTE(review): `Ident::new` panics when the string is
                            // not a valid identifier, so this presumably requires
                            // `rhs_ty` to be a single-identifier type - confirm.
                            let mut visitor = FindReplaceIdent {
                                find: margin.lhs_ty.clone(),
                                replace: Ident::new(
                                    &quote!(#rhs_ty).to_string(),
                                    Span::call_site(),
                                ),
                            };
                            visitor.visit_expr_closure_mut(&mut closure);

                            attribute.tokens = quote! { (#fn_name:#closure) };
                        }
                    }
                }
            }
        }

        // Replace placeholder types (`type0`, `type1`, ...) with given types.
        if let Some(types) = &operation.types {
            for (i, t) in types.iter().enumerate() {
                let mut pat_type_visitor = FindReplacePatType {
                    find: Ident::new(&format!("type{}", i), Span::call_site()),
                    replace: t.clone(),
                };

                pat_type_visitor.visit_item_impl_mut(&mut impl_block);
            }
        }

        // Adapt all methods to the operation.
        for impl_item in impl_block.items.iter_mut() {
            if let ImplItem::Method(method) = impl_item {
                let renamed =
                    operation_method_name(&operation_name, &method.sig.ident.to_string());
                method.sig.ident = Ident::new(&renamed, Span::call_site());

                placeholder_method_visitor.visit_block_mut(&mut method.block);
                unchecked_method_visitor.visit_block_mut(&mut method.block);
                unchecked_underscore_method_visitor.visit_block_mut(&mut method.block);
            }
        }

        impl_blocks.push(impl_block);
    }

    TokenStream::from(quote! {
        #(#impl_blocks)*
    })
}

/// Appends `predicate` to the `where` clause of `generics`, creating an empty
/// clause first when there is none yet.
fn push_where_predicate(generics: &mut syn::Generics, predicate: syn::WherePredicate) {
    generics
        .where_clause
        .get_or_insert_with(|| WhereClause {
            where_token: Where::default(),
            predicates: Punctuated::new(),
        })
        .predicates
        .push(predicate);
}

/// Derives the name of a generated method from the template method's name.
///
/// If `method_name` contains the word `operation`, its first occurrence is
/// replaced by `operation_name`; otherwise `operation_name` followed by an
/// underscore is prepended.
fn operation_method_name(operation_name: &str, method_name: &str) -> String {
    const PLACEHOLDER: &str = "operation";

    if method_name.contains(PLACEHOLDER) {
        method_name.replacen(PLACEHOLDER, operation_name, 1)
    } else {
        format!("{}_{}", operation_name, method_name)
    }
}

#[proc_macro_attribute]
pub fn define_closure(attr: TokenStream, item: TokenStream) -> TokenStream {
    let attr = parse_macro_input!(attr as ClosureDefinition);
    let mut item = parse_macro_input!(item as ItemImpl);

    for impl_item in item.items.iter_mut() {
        if let ImplItem::Method(method) = impl_item {
            if attr.fn_name == method.sig.ident {
                let mut visitor = ReplaceExprClosure {
                    replace: attr.closure,
                };

                visitor.visit_block_mut(&mut method.block);
                break;
            }
        }
    }

    let result = quote! {
        #item
    };
    result.into()
}