pluginop_macro/
lib.rs

1//! A set of attribute macros, to be used in the source code of the host implementation,
2//! to ease the process of making it pluginizable, e.g., by transforming a regular Rust
3//! function into a plugin operation.
4
5use darling::FromMeta;
6use proc_macro::TokenStream;
7use quote::{format_ident, quote};
8use syn::{
9    parse_macro_input, punctuated::Punctuated, AttributeArgs, Expr, FnArg, GenericArgument, Ident,
10    ItemFn, Pat, PatType, Path, ReturnType, Type,
11};
12
13extern crate proc_macro;
14
15/// Extracts the `PatType` of a `FnArg`.
16fn extract_arg_pat(a: FnArg) -> Option<Pat> {
17    match a {
18        FnArg::Typed(p) => Some(*p.pat),
19        _ => None,
20    }
21}
22
23/// Retrieves the argument identifiers of a function.
24fn extract_arg_idents(fn_args: Punctuated<FnArg, syn::token::Comma>) -> Vec<Pat> {
25    fn_args
26        .into_iter()
27        .filter_map(extract_arg_pat)
28        .collect::<Vec<_>>()
29}
30
31fn extract_arg_idents_vec(fn_args: Vec<FnArg>) -> Vec<Pat> {
32    fn_args
33        .into_iter()
34        .filter_map(extract_arg_pat)
35        .collect::<Vec<_>>()
36}
37
38// First boolean returns whether the type is `Octets` or `OctetsMut`. Second returns whether the
39// type is exactly `OctetsMut`, The third indicates whether the reference is mutable or not.
40fn has_octets(pt: &PatType) -> (bool, bool, bool) {
41    match &*pt.ty {
42        Type::Reference(tref) => match &*tref.elem {
43            Type::Path(p) => {
44                if p.path
45                    .segments
46                    .iter()
47                    .any(|ps| &ps.ident.to_string() == "Octets")
48                {
49                    (true, false, tref.mutability.is_some())
50                } else if p
51                    .path
52                    .segments
53                    .iter()
54                    .any(|ps| &ps.ident.to_string() == "OctetsMut")
55                {
56                    (true, true, tref.mutability.is_some())
57                } else {
58                    (false, false, false)
59                }
60            }
61            _ => (false, false, false),
62        },
63        _ => (false, false, false),
64    }
65}
66
67fn is_result_unit(ty: &syn::Type) -> bool {
68    match ty {
69        Type::Path(tp) => {
70            if let Some(ps) = tp
71                .path
72                .segments
73                .iter()
74                .find(|s| &s.ident.to_string() == "Result")
75            {
76                if let syn::PathArguments::AngleBracketed(ab) = &ps.arguments {
77                    if ab.args.is_empty() {
78                        return false;
79                    }
80                    if let Some(GenericArgument::Type(syn::Type::Tuple(tu))) = ab.args.first() {
81                        return tu.elems.is_empty();
82                    }
83                }
84            }
85            false
86        }
87        _ => false,
88    }
89}
90
91fn get_param_block(
92    args: &Punctuated<FnArg, syn::token::Comma>,
93    ignore: Option<Ident>,
94    with_octets: bool,
95) -> proc_macro2::TokenStream {
96    let args_code: Vec<proc_macro2::TokenStream> = args
97        .iter()
98        .filter_map(|a| match a {
99            FnArg::Typed(pt) => {
100                let pat = &pt.pat;
101                match has_octets(pt) {
102                    (true, _, true) if !with_octets => None,
103                    (true, false, true) => Some(quote!( OctetsPtr::from(#pat).into_with_ph(ph) )),
104                    (true, true, true) => Some(quote!( OctetsMutPtr::from(#pat).into_with_ph(ph) )),
105                    (true, _, false) => panic!("Octets argument must be mutable"),
106                    _ => {
107                        if let Some(ign) = &ignore {
108                            if let Pat::Ident(pi) = &*pt.pat {
109                                if pi.ident == *ign {
110                                    return None;
111                                }
112                            }
113                        }
114
115                        Some(quote!( #pat.clone().into_with_ph(ph) ))
116                    }
117                }
118            }
119            _ => None,
120        })
121        .collect();
122    quote!(
123        [
124            #(#args_code ,)*
125        ]
126    )
127}
128
129fn get_ret_block(fn_output_type: &ReturnType) -> proc_macro2::TokenStream {
130    match fn_output_type {
131        syn::ReturnType::Default => quote!({
132            if let Err(err) = res {
133                panic!("plugin execution error: {:?}", err);
134            }
135        }),
136        syn::ReturnType::Type(_, t) => {
137            if let Type::Tuple(tu) = *t.clone() {
138                let elems = tu.elems.into_iter();
139                quote! {
140                    let mut it = match res {
141                        Ok(r) => r.into_iter(),
142                        Err(pluginop::Error::OperationError(e)) => todo!("operation error {:?}; should you use pluginop_result?", e),
143                        Err(err) => panic!("plugin execution error: {:?}", err),
144                    };
145                    (
146                        #(
147                            #elems :: try_from(it.next().unwrap()).unwrap(),
148                        )*
149                    )
150                }
151            } else {
152                quote!(
153                    let mut it = match res {
154                        Ok(r) => r.into_iter(),
155                        Err(pluginop::Error::OperationError(e)) => todo!("operation error {:?}; should you use pluginop_result?", e),
156                        Err(err) => panic!("plugin execution error: {:?}", err),
157                    };
158                    { it.next().unwrap().try_into().unwrap() }
159                )
160            }
161        }
162    }
163}
164
165fn get_ret_result_block(fn_output_type: &ReturnType) -> proc_macro2::TokenStream {
166    match fn_output_type {
167        syn::ReturnType::Default => quote!({
168            if let Err(err) = res {
169                panic!("plugin execution error: {:?}", err);
170            }
171        }),
172        syn::ReturnType::Type(_, t) => {
173            if let Type::Tuple(tu) = *t.clone() {
174                let elems = tu.elems.into_iter();
175                quote! {
176                    let mut it = match res {
177                        Ok(r) => r.into_iter(),
178                        Err(pluginop::Error::OperationError(e)) => return Err(e.into()),
179                        Err(err) => panic!("plugin execution error: {:?}", err),
180                    };
181                    Ok((
182                        #(
183                            #elems :: try_from_with_ph(it.next().unwrap(), ph).unwrap(),
184                        )*
185                    ))
186                }
187            } else {
188                // We need to check if this is the unit type.
189                if is_result_unit(t) {
190                    quote!(match res {
191                        Ok(r) => Ok(()),
192                        Err(pluginop::Error::OperationError(e)) => Err(e.into()),
193                        Err(err) => panic!("plugin execution error: {:?}", err),
194                    })
195                } else {
196                    quote!(
197                        let mut it = match res {
198                            Ok(r) => r.into_iter(),
199                            Err(pluginop::Error::OperationError(e)) => return Err(e.into()),
200                            Err(err) => panic!("plugin execution error: {:?}", err),
201                        };
202                        match it.next() {
203                            Some(r) => Ok(r.try_into_with_ph(ph).unwrap()),
204                            None => panic!("Missing output from the plugin"),
205                        }
206                    )
207                }
208            }
209        }
210    }
211}
212
213fn get_out_block(
214    base_fn: &ItemFn,
215    po: &Path,
216    value: Option<Expr>,
217    ret_block: &proc_macro2::TokenStream,
218) -> proc_macro2::TokenStream {
219    let fn_args = extract_arg_idents(base_fn.sig.inputs.clone());
220    let fn_inputs = &base_fn.sig.inputs;
221    let mut fn_inputs_no_self = fn_inputs.clone();
222    fn_inputs_no_self.pop();
223    let fn_vis = &base_fn.vis;
224    let fn_name = &base_fn.sig.ident;
225    let fn_block = &base_fn.block;
226    let fn_output = &base_fn.sig.output;
227    let fn_name_internal = format_ident!("__{}__", fn_name);
228    let param_code = get_param_block(fn_inputs, None, true);
229    let param_code_prepost = get_param_block(fn_inputs, None, false);
230
231    let po_code = if let Some(v) = value {
232        quote! { #po ( #v ) }
233    } else {
234        quote! { #po }
235    };
236
237    quote! {
238        #[allow(unused_variables)]
239        fn #fn_name_internal(#fn_inputs) #fn_output {
240            #fn_block
241        }
242
243        #fn_vis fn #fn_name(#fn_inputs) #fn_output {
244            use pluginop::api::ToPluginizableConnection;
245            use pluginop::Error;
246            use pluginop::IntoWithPH;
247            use pluginop::TryIntoWithPH;
248            use pluginop::octets::OctetsMutPtr;
249            use pluginop::octets::OctetsPtr;
250            let ph = self.get_pluginizable_connection().map(|pc| pc.get_ph_mut());
251            if let Some(ph) = ph {
252                if ph.provides(& #po_code, pluginop::common::Anchor::Define) {
253                    let params = & #param_code;
254                    let res = ph.call(
255                        & #po_code,
256                        params,
257                    );
258                    ph.clear_bytes_content();
259
260                    #ret_block
261                } else {
262                    let has_before = ph.provides(& #po_code, pluginop::common::Anchor::Before);
263                    let has_after = ph.provides(& #po_code, pluginop::common::Anchor::After);
264                    let params = if has_before || has_after { Some(#param_code_prepost) } else { None };
265                    if has_before {
266                        ph.call_direct(
267                            & #po_code,
268                            pluginop::common::Anchor::Before,
269                            params.as_ref().unwrap(),
270                        ).ok();
271                    }
272                    let ret = self.#fn_name_internal(#(#fn_args,)*);
273                    if has_after {
274                        if let Some(ph) = self.get_pluginizable_connection().map(|pc| pc.get_ph_mut()) {
275                            ph.call_direct(
276                                & #po_code,
277                                pluginop::common::Anchor::After,
278                                params.as_ref().unwrap(),
279                            ).ok();
280                        }
281                    }
282                    ret
283                }
284            } else {
285                self.#fn_name_internal(#(#fn_args,)*)
286            }
287        }
288    }
289}
290
291fn get_out_param_block(
292    param: Ident,
293    base_fn: &ItemFn,
294    po: &Path,
295    ret_block: &proc_macro2::TokenStream,
296) -> proc_macro2::TokenStream {
297    let fn_output = &base_fn.sig.output;
298    let fn_inputs = &base_fn.sig.inputs;
299    let fn_inputs_iter: Vec<FnArg> = fn_inputs.clone().into_iter().collect();
300    let fn_inputs_no_self = fn_inputs_iter.clone();
301    let fn_args = extract_arg_idents_vec(fn_inputs_no_self);
302    let fn_vis = &base_fn.vis;
303    let fn_name = &base_fn.sig.ident;
304    let fn_block = &base_fn.block;
305    let fn_name_internal = format_ident!("__{}__", fn_name);
306    let param_code = get_param_block(fn_inputs, Some(param.clone()), true);
307    let param_code_prepost = get_param_block(fn_inputs, Some(param.clone()), false);
308
309    quote! {
310        #[allow(unused_variables)]
311        fn #fn_name_internal(#(#fn_inputs_iter,)*) #fn_output {
312            #fn_block
313        }
314
315        #fn_vis fn #fn_name(#fn_inputs) #fn_output {
316            use pluginop::api::ToPluginizableConnection;
317            use pluginop::IntoWithPH;
318            use pluginop::TryIntoWithPH;
319            use pluginop::octets::OctetsMutPtr;
320            use pluginop::octets::OctetsPtr;
321            let ph = self.get_pluginizable_connection().map(|pc| pc.get_ph_mut());
322            if let Some(ph) = ph {
323                if ph.provides(& #po(#param), pluginop::common::Anchor::Define) {
324                    let params = & #param_code;
325                    let res = ph.call(
326                        & #po(#param),
327                        params,
328                    );
329                    ph.clear_bytes_content();
330
331                    #ret_block
332                } else {
333                    let has_before = ph.provides(& #po(#param), pluginop::common::Anchor::Before);
334                    let has_after = ph.provides(& #po(#param), pluginop::common::Anchor::After);
335                    let params = if has_before || has_after { Some(#param_code_prepost) } else { None };
336                    if has_before {
337                        ph.call_direct(
338                            & #po(#param),
339                            pluginop::common::Anchor::Before,
340                            params.as_ref().unwrap(),
341                        ).ok();
342                    }
343                    let ret = self.#fn_name_internal(#(#fn_args,)*);
344                    if has_after {
345                        if let Some(ph) = self.get_pluginizable_connection().map(|pc| pc.get_ph_mut()) {
346                            ph.call_direct(
347                                & #po(#param),
348                                pluginop::common::Anchor::After,
349                                params.as_ref().unwrap(),
350                            ).ok();
351                        }
352                    }
353                    ret
354                }
355            } else {
356                self.#fn_name_internal(#(#fn_args,)*)
357            }
358        }
359    }
360}
361
362/// Arguments that can be passed through the `protoop` macro. See the
363/// documentation of the macro `protoop` for more details.
364#[derive(Debug, FromMeta)]
365struct MacroSimpleArgs {
366    po: Path,
367    value: Option<Expr>,
368}
369
370/// An attribute macro to transform a non-faillible function into a
371/// non-parametrized plugin operation.
372#[proc_macro_attribute]
373pub fn pluginop(attr: TokenStream, item: TokenStream) -> TokenStream {
374    let attrs = parse_macro_input!(attr as AttributeArgs);
375    let attrs_args = match MacroSimpleArgs::from_list(&attrs) {
376        Ok(v) => v,
377        Err(e) => return TokenStream::from(e.write_errors()),
378    };
379
380    let po = attrs_args.po;
381    let value = attrs_args.value;
382    let base_fn = parse_macro_input!(item as ItemFn);
383
384    let ret_block = get_ret_block(&base_fn.sig.output);
385    let out = get_out_block(&base_fn, &po, value, &ret_block);
386
387    // println!("output is\n{}", out);
388
389    out.into()
390}
391
392/// An attribute macro to transform a function returning a [`Result`] into a
393/// non-parametrized plugin operation.
394#[proc_macro_attribute]
395pub fn pluginop_result(attr: TokenStream, item: TokenStream) -> TokenStream {
396    let attrs = parse_macro_input!(attr as AttributeArgs);
397    let attrs_args = match MacroSimpleArgs::from_list(&attrs) {
398        Ok(v) => v,
399        Err(e) => return TokenStream::from(e.write_errors()),
400    };
401
402    let po = attrs_args.po;
403    let value = attrs_args.value;
404    let base_fn = parse_macro_input!(item as ItemFn);
405
406    let ret_block = get_ret_result_block(&base_fn.sig.output);
407    let out = get_out_block(&base_fn, &po, value, &ret_block);
408
409    // println!("output is\n{}", out);
410
411    out.into()
412}
413
414/// Arguments that can be passed through the `protoop` macro. See the
415/// documentation of the macro `protoop` for more details.
416#[derive(Debug, FromMeta)]
417struct MacroArgs {
418    po: Path,
419    param: Ident,
420}
421
422/// An attribute macro to transform a non-faillible function into a
423/// parametrized plugin operation. One of the arguments of the function
424/// must act as the parameter of the plugin operation.
425#[proc_macro_attribute]
426pub fn pluginop_param(attr: TokenStream, item: TokenStream) -> TokenStream {
427    let attrs = parse_macro_input!(attr as AttributeArgs);
428    let attrs_args = match MacroArgs::from_list(&attrs) {
429        Ok(v) => v,
430        Err(e) => return TokenStream::from(e.write_errors()),
431    };
432
433    let po = attrs_args.po;
434    let param = attrs_args.param;
435
436    let base_fn = parse_macro_input!(item as ItemFn);
437
438    let ret_block = get_ret_block(&base_fn.sig.output);
439    get_out_param_block(param, &base_fn, &po, &ret_block).into()
440}
441
442/// An attribute macro to transform a function returning a [`Result`] into a
443/// parametrized plugin operation. One of the arguments of the function
444/// must act as the parameter of the plugin operation.
445#[proc_macro_attribute]
446pub fn pluginop_result_param(attr: TokenStream, item: TokenStream) -> TokenStream {
447    let attrs = parse_macro_input!(attr as AttributeArgs);
448    let attrs_args = match MacroArgs::from_list(&attrs) {
449        Ok(v) => v,
450        Err(e) => return TokenStream::from(e.write_errors()),
451    };
452
453    let po = attrs_args.po;
454    let param = attrs_args.param;
455
456    let base_fn = parse_macro_input!(item as ItemFn);
457
458    let ret_block = get_ret_result_block(&base_fn.sig.output);
459    let out = get_out_param_block(param, &base_fn, &po, &ret_block);
460
461    // println!("output is\n{}", out);
462
463    out.into()
464}