eff_attr/
lib.rs

1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use proc_macro_hack::proc_macro_hack;
5use quote::{quote, ToTokens};
6use syn::parse::Parse;
7
8type TokenStream2 = proc_macro2::TokenStream;
9
10enum CommaOrColon {
11    Comma(syn::token::Comma),
12    Colon(syn::token::Colon),
13}
14
15impl Parse for CommaOrColon {
16    fn parse(input: syn::parse::ParseStream) -> syn::parse::Result<Self> {
17        if input.peek(syn::token::Comma) {
18            input.parse().map(CommaOrColon::Comma)
19        } else {
20            input.parse().map(CommaOrColon::Colon)
21        }
22    }
23}
24
25impl ToTokens for CommaOrColon {
26    fn to_tokens(&self, tokens: &mut TokenStream2) {
27        match self {
28            CommaOrColon::Comma(comma) => comma.to_tokens(tokens),
29            CommaOrColon::Colon(colon) => colon.to_tokens(tokens),
30        }
31    }
32}
33
34fn wrap_pattern(tokens: impl ToTokens, n: usize) -> TokenStream2 {
35    if n == 0 {
36        tokens.into_token_stream()
37    } else {
38        wrap_pattern(quote! { eff::coproduct::Either::B(#tokens) }, n - 1)
39    }
40}
41
42/// Declare the function to be an effectful computation whose effect type is the coproduct of the arguments
43///
44/// This macro transforms the function like what `async fn` does
45#[proc_macro_attribute]
46pub fn eff(attr: TokenStream, item: TokenStream) -> TokenStream {
47    use syn::parse::Parser;
48    use syn::punctuated::Punctuated;
49
50    if let Ok(syn::Item::Fn(mut func)) = syn::parse(item.clone()) {
51        let effects_parser = Punctuated::<syn::Type, CommaOrColon>::parse_terminated;
52        let types = effects_parser
53            .parse(attr)
54            .expect("failed to parse attribute");
55
56        let mut ret = TokenStream2::new();
57
58        let effects_type_name = quote! {
59            eff::Coproduct![#types]
60        };
61
62        func.sig.output = syn::parse2(match func.sig.output {
63            syn::ReturnType::Default => quote! {
64                -> impl eff::Effectful<Output = (), Effect = #effects_type_name>
65            },
66            syn::ReturnType::Type(arrow, ty) => quote! {
67                #arrow impl eff::Effectful<Output = #ty, Effect = #effects_type_name>
68            },
69        })
70        .expect("return type is invalid");
71
72        let original_block = func.block;
73        func.block = syn::parse2(quote! {
74            {
75                eff::from_generator(static move || {
76                    if false {
77                        yield unreachable!();
78                    }
79
80                    #original_block
81                })
82            }
83        })
84        .expect("function block is invalid");
85
86        // supress warning
87        ret.extend(quote! {
88            #[allow(unreachable_code)]
89            #func
90        });
91
92        ret.into()
93    } else if let Ok(syn::Expr::Match(mut m)) = syn::parse(item) {
94        // Provide a nice pattern-match syntax for polling result of an effectful computation.
95        // Concretely, convert a match clause of the form
96        // ```
97        // match <poll_expr> {
98        //     <value_pattern> => <value_body>,
99        //     (<eff1>, <k1>) => <eff_body1>,
100        //     ...
101        //     (<effN>, <kN>) => <eff_bodyN>,
102        // }
103        // ```
104        // into
105        // ```
106        // match <poll_expr> {
107        //     Complete(<value_pattern>) => <value_body>,
108        //     Effect((A(<eff1>, <k1>) => <eff_body1>,
109        //     ...
110        //     Effect(B(B(...B(A(<effN>, <kN>))...))) => <eff_bodyN>,
111        //     Effect(B(B(...B(B(__rest))...))) => reperform_rest!(__rest),
112        // }
113        // ```
114        // TODO: Currently, this macro doesn't support other forms such as if guards.
115        assert!(
116            m.arms.len() >= 1,
117            "An effect match clause must have an arm for the value pattern"
118        );
119        {
120            let pat = m.arms[0].pat.clone();
121            m.arms[0].pat = syn::parse2(quote! { eff::Event::Complete(#pat) })
122                .expect("value pattern is invalid");
123        }
124        for (idx, ref mut arm) in m.arms[1..].iter_mut().enumerate() {
125            let pat = arm.pat.clone();
126            let wrapped = wrap_pattern(quote! { eff::coproduct::Either::A #pat }, idx);
127            arm.pat = syn::parse2(quote! { eff::Event::Effect(#wrapped) })
128                .expect(&format!("{}'th pattern is invalid", idx));
129        }
130        {
131            let ident = quote! { __rest };
132            let wrapped = wrap_pattern(&ident, m.arms.len() - 1);
133            // allow unreachable as there can be no remaining effects
134            m.arms.push(
135                syn::parse2(quote! {
136                    #[allow(unreachable_code)] eff::Event::Effect(#wrapped) => eff::reperform_rest!(#ident),
137                })
138                .expect("reperform arm is invalid"),
139            );
140        }
141        m.into_token_stream().into()
142    } else {
143        panic!("eff couldn't parse the content");
144    }
145}
146
147#[proc_macro_hack]
148pub fn poll(input: TokenStream) -> TokenStream {
149    use syn::parse::Parser;
150    use syn::punctuated::Punctuated;
151
152    let parser = Punctuated::<syn::Expr, syn::token::Comma>::parse_terminated;
153    let exprs = parser.parse(input).expect("failed to parse input");
154
155    let mut ret = TokenStream2::new();
156    let names = exprs
157        .iter()
158        .enumerate()
159        .map(|(idx, expr)| {
160            let name = quote::format_ident!("__comp{}", idx);
161            ret.extend(quote! {
162                let mut #name = eff::Effectful::next_event(#expr);
163            });
164            name
165        })
166        .collect::<Vec<_>>();
167
168    ret.extend(quote! {
169        loop {
170            let mut __all_occured = true;
171
172            #(
173                if let eff::Poll::Pending = eff::poll_with_task_context(unsafe { eff::pin_reexport::Pin::new_unchecked(&mut #names) }) {
174                    __all_occured = false;
175                }
176            )*
177
178            if __all_occured {
179                break (#( unsafe { eff::pin_reexport::Pin::new_unchecked(&mut #names) }.take_event().unwrap() ),*);
180            }
181        }
182    });
183
184    quote!({ #ret }).into()
185}