// futurize_derive/lib.rs

#![recursion_limit = "512"]

extern crate proc_macro;
extern crate proc_macro2;
extern crate syn;
extern crate heck;

#[macro_use]
extern crate quote;

use heck::SnakeCase;
use proc_macro::TokenStream;
use proc_macro2::Ident;
use quote::ToTokens;
use syn::DeriveInput;

18#[proc_macro_derive(Worker, attributes(returns))]
19pub fn derive_worker(input: TokenStream) -> TokenStream {
20
21    let ast: DeriveInput = syn::parse(input).unwrap();
22    let dnum = match ast.data {
23        syn::Data::Enum(v) => v,
24        _ => panic!("must be enum"),
25    };
26
27    let name            = &ast.ident;
28
29    let mut rets        = Vec::new();
30    let mut call_fns    = Vec::new();
31    let mut trait_fns   = Vec::new();
32    let mut matches     = Vec::new();
33
34    for variant in dnum.variants {
35        let varname = variant.ident;
36        let fname = Ident::new(&format!("{}", varname).to_snake_case(), varname.span());
37        let mut args = Vec::new();
38        let mut argnames  = Vec::new();
39        match variant.fields {
40            syn::Fields::Named(fields) => {
41                for field in fields.named {
42                    let name = field.ident.unwrap();
43                    let typ  = field.ty.into_token_stream();
44
45                    args.push(quote!{
46                        #name : #typ
47                    });
48                    argnames.push(name);
49                }
50            },
51            syn::Fields::Unnamed(_) => {
52                panic!("cannot use unnamed args");
53            },
54            syn::Fields::Unit => (),
55        };
56
57        let mut returns = quote!(());
58        for attr in variant.attrs {
59            if attr.path.segments.len() == 1 {
60                if format!("{}", attr.path.segments[0].ident) == "returns" {
61                    let meta = attr.interpret_meta().expect("cannot parse as meta");
62                    let meta = match meta {
63                        syn::Meta::NameValue(m) => m,
64                        _ =>  panic!("needs name value pair like '#[returns = \"u8\"]"),
65                    };
66
67                    let meta = match meta.lit {
68                        syn::Lit::Str(s) => s.value(),
69                        _ =>  panic!("needs name value pair like '#[returns = \"u8\"]"),
70                    };
71
72                    let meta : syn::Type = syn::parse_str(&meta).unwrap();
73                    returns = meta.into_token_stream();
74                }
75            }
76        }
77
78
79        let varname_    = varname.clone();
80        rets.push(quote!{
81            #varname_(#returns)
82        });
83
84
85        let args_ = args.clone();
86        trait_fns.push(quote! {
87            fn #fname(self, #(#args_),*) -> R<Self,#returns>;
88        });
89
90        let name_       = name.clone();
91        let varname_    = varname.clone();
92        let argnames_   = argnames.clone();
93        let callarg = if argnames.len() > 0 {quote! {
94            #name_::#varname_{#(#argnames_),*}
95        }} else { quote! {
96            #name_::#varname_
97        }};
98        call_fns.push(quote!{
99            pub fn #fname(&mut self, #(#args),*) -> impl futures::Future<Item=#returns, Error=Error> {
100                let (tx,rx) = oneshot::channel();
101                self.tx.clone().send((tx, #callarg))
102                    .map_err(Error::from)
103                    .and_then(|_|rx.map_err(Error::from))
104                    .and_then(|r|r)
105                    .map(|v|{
106                        match v {
107                            Return::#varname_(r) => r,
108                            _ => unreachable!()
109                        }
110                    })
111                    .map_err(Error::from)
112            }
113        });
114
115        let argnames_ = argnames.clone();
116        let then = quote!{
117            then(|v|{
118                match v {
119                    Err((s,e)) => {
120                        ret.send(Err(e)).ok();
121                        s
122                    },
123                    Ok((s,v)) => {
124                        ret.send(Ok(Return::#varname_(v))).ok();
125                        s
126                    }
127                }.ok_or(())
128            })
129        };
130        let mcall = if argnames.len() > 0 { quote! {
131            #name::#varname { #(#argnames),* } => {
132                let ft = t.#fname(#(#argnames_),*)
133                    .#then;
134                Box::new(ft) as Box<Future<Item=T,Error=()> + Send + Sync>
135            }
136        }} else { quote! {
137            #name::#varname => {
138                let ft = t.#fname()
139                    .#then;
140                Box::new(ft) as Box<Future<Item=T,Error=()> + Send + Sync>
141            }
142        }};
143
144        matches.push(mcall);
145    }
146
147    let matches_ = matches.clone();
148    let expanded = quote! {
149        use futures;
150        use futures::Stream;
151        use futures::Sink;
152        use futures::Future;
153        use futures::sync::oneshot;
154        use failure::Error;
155        use std::time::Instant;
156        use futures::future::Either;
157
158        #[derive(Clone)]
159        pub struct Handle {
160            tx: futures::sync::mpsc::Sender<(oneshot::Sender<Result<Return,Error>>, #name)>,
161        }
162        impl Handle {
163            #(#call_fns)*
164        }
165
166        enum Return {
167            #(#rets),*
168        }
169
170        pub type R<S: Worker + Sized, T> = Box<Future<Item=(Option<S>, T), Error=(Option<S>, Error)> + Sync + Send>;
171
172        pub trait Worker
173            where Self: Sized,
174        {
175
176            #(#trait_fns)*
177
178            fn canceled(self) {}
179            fn interval(self, Instant) -> Box<Future<Item=Option<Self>, Error=()> + Sync + Send> {
180                panic!("must implement Worker::interval if using spawn_with_interval");
181            }
182        }
183
184        pub fn spawn<T: Worker> (buffer: usize, t: T)
185            -> (impl Future<Item=(), Error=()>, Handle)
186            where T: 'static + Send + Sync
187        {
188            let (tx,rx) = futures::sync::mpsc::channel(buffer);
189
190            let ft = rx.fold(t, |t, (ret, m) : (oneshot::Sender<Result<Return, Error>>, #name)|{
191                match m {
192                    #(#matches),*
193                }
194            }).and_then(|t|{
195                t.canceled();
196                Ok(())
197            });
198
199            (
200                ft,
201                Handle {
202                    tx,
203                },
204            )
205        }
206
207        pub fn spawn_with_interval<T: Worker, I: Stream<Item=Instant, Error=()>> (buffer: usize, t: T, i : I)
208            -> (impl Future<Item=(), Error=()>, Handle)
209            where T: 'static + Send + Sync
210        {
211            let (tx,rx) = futures::sync::mpsc::channel(buffer);
212
213            let i  = i.map(|i|Either::A(i));
214            let rx = rx.map(|i|Either::B(i));
215            let rx = rx.select(i);
216
217            let ft = rx.fold(t, |t, either : Either<Instant, (oneshot::Sender<Result<Return, Error>>, #name)>|{
218                match either {
219                    Either::A(i) => {
220                        let ft = t.interval(i)
221                            .and_then(|v|{
222                                v.ok_or(())
223                            });
224                        Box::new(ft) as Box<Future<Item=T,Error=()> + Send + Sync>
225                    },
226                    Either::B((ret,m)) => match m {
227                        #(#matches_),*
228                    }
229                }
230            }).and_then(|t|{
231                t.canceled();
232                Ok(())
233            });
234
235            (
236                ft,
237                Handle {
238                    tx,
239                },
240            )
241        }
242    };
243
244    // Hand the output tokens back to the compiler.
245    expanded.into()
246}