pydeco/
lib.rs

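//! Python-like function decorators for Rust.
//!
//! `#[deco(decorator)]` replaces the annotated function with a wrapper that
//! passes the original function to `decorator` and calls the returned value
//! with the same arguments, so `decorator` must return something callable with
//! the original signature.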
2//!
3//! Example
4//! --------
5//!
6//! ```
7//! use pydeco::deco;
8//!
9//! fn logging<F>(func: F) -> impl Fn(i32) -> i32
10//! where
11//!     F: Fn(i32) -> i32,
12//! {
13//!     move |i| {
14//!         println!("Input = {}", i);
15//!         let out = func(i);
16//!         println!("Output = {}", out);
17//!         out
18//!     }
19//! }
20//!
21//! #[deco(logging)]
22//! fn add2(i: i32) -> i32 {
23//!     i + 2
24//! }
25//!
26//! add2(2);
27//! ```
28//!
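//! Decorators can also take parameters: `#[deco(logging("test.log"))]` first
//! evaluates `logging("test.log")` and then applies the returned decorator to
//! the function.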
30//!
31//! ```
32//! use pydeco::deco;
33//! use std::{fs, io::Write};
34//!
35//! fn logging<InputFunc: 'static>(
36//!     log_filename: &'static str,
37//! ) -> impl Fn(InputFunc) -> Box<dyn Fn(i32) -> i32>
38//! where
39//!     InputFunc: Fn(i32) -> i32,
40//! {
41//!     move |func: InputFunc| {
42//!         Box::new(move |i: i32| {
43//!             let mut f = fs::File::create(log_filename).unwrap();
44//!             writeln!(f, "Input = {}", i).unwrap();
45//!             let out = func(i);
46//!             writeln!(f, "Output = {}", out).unwrap();
47//!             out
48//!         })
49//!     }
50//! }
51//!
52//! #[deco(logging("test.log"))]
53//! fn add2(i: i32) -> i32 {
54//!     i + 2
55//! }
56//!
57//! add2(2);
58//! ```
59//!
60
61use anyhow::{bail, Result};
62use proc_macro::TokenStream;
63use proc_macro2::TokenTree;
64use syn::*;
65
66#[proc_macro_attribute]
67pub fn deco(attr: TokenStream, func: TokenStream) -> TokenStream {
68    let func = func.into();
69    let item_fn: ItemFn = syn::parse(func).expect("Input is not a function");
70    let vis = &item_fn.vis;
71    let ident = &item_fn.sig.ident;
72    let block = &item_fn.block;
73
74    let inputs = item_fn.sig.inputs;
75    let output = item_fn.sig.output;
76
77    let input_values: Vec<_> = inputs
78        .iter()
79        .map(|arg| match arg {
80            &FnArg::Typed(ref val) => &val.pat,
81            _ => unimplemented!("#[deco] cannot be used with associated function"),
82        })
83        .collect();
84
85    let attr = DecoratorAttr::parse(attr.into()).expect("Failed to parse attribute");
86    let caller = match attr {
87        DecoratorAttr::Fixed { name } => {
88            quote::quote! {
89                #vis fn #ident(#inputs) #output {
90                    let f = #name(deco_internal);
91                    return f(#(#input_values,) *);
92
93                    fn deco_internal(#inputs) #output #block
94                }
95            }
96        }
97        DecoratorAttr::Parametric { name, args } => {
98            quote::quote! {
99                #vis fn #ident(#inputs) #output {
100                    let deco = #name(#(#args,) *);
101                    let f = deco(deco_internal);
102                    return f(#(#input_values,) *);
103
104                    fn deco_internal(#inputs) #output #block
105                }
106            }
107        }
108    };
109    caller.into()
110}
111
112#[derive(Debug, PartialEq)]
113enum DecoratorAttr {
114    Fixed { name: Ident },
115    Parametric { name: Ident, args: Vec<Expr> },
116}
117
118impl DecoratorAttr {
119    fn parse(attr: proc_macro2::TokenStream) -> Result<Self> {
120        let mut ident = None;
121        let mut args = Vec::new();
122        for at in attr {
123            match at {
124                TokenTree::Ident(id) => {
125                    ident = Some(id);
126                }
127                TokenTree::Group(grp) => {
128                    if ident.is_none() {
129                        bail!("Invalid token stream");
130                    }
131                    for t in grp.stream() {
132                        if let Ok(expr) = syn::parse2(t.into()) {
133                            args.push(expr);
134                        }
135                    }
136                }
137                _ => bail!("Invalid token stream"),
138            }
139        }
140        if let Some(name) = ident {
141            if args.is_empty() {
142                Ok(DecoratorAttr::Fixed { name })
143            } else {
144                Ok(DecoratorAttr::Parametric { name, args })
145            }
146        } else {
147            bail!("Decorator name not found");
148        }
149    }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    /// Lexes `src` into a `proc_macro2` token stream and feeds it to
    /// `DecoratorAttr::parse`.
    fn parse_src(src: &str) -> Result<DecoratorAttr> {
        let tokens = proc_macro2::TokenStream::from_str(src).unwrap();
        DecoratorAttr::parse(tokens)
    }

    /// Parses `src`, fails unless it yields a parametric decorator, and
    /// returns how many arguments were captured.
    fn parametric_arg_count(src: &str) -> Result<usize> {
        if let DecoratorAttr::Parametric { args, .. } = parse_src(src)? {
            Ok(args.len())
        } else {
            bail!("Failed to parse args");
        }
    }

    #[test]
    fn parse_attr() -> Result<()> {
        let parsed = parse_src("logging")?;
        assert!(matches!(parsed, DecoratorAttr::Fixed { .. }));
        Ok(())
    }

    #[test]
    fn parse_attr_parametric_literal() -> Result<()> {
        assert_eq!(parametric_arg_count(r#"logging("test.log", 2)"#)?, 2);
        Ok(())
    }

    #[test]
    fn parse_attr_parametric_variable() -> Result<()> {
        assert_eq!(
            parametric_arg_count(r#"logging("test.log", some_variable)"#)?,
            2
        );
        Ok(())
    }

    #[test]
    fn parse_attr_parametric_expr() -> Result<()> {
        assert_eq!(parametric_arg_count(r#"logging("test.log", (1 + 2))"#)?, 2);
        Ok(())
    }

    #[test]
    fn parse_attr_empty() -> Result<()> {
        assert!(parse_src("").is_err());
        Ok(())
    }

    #[test]
    fn parse_attr_invalid() -> Result<()> {
        // Argument list placed before the decorator name.
        assert!(parse_src(r#"("test.log", 2)logging"#).is_err());
        Ok(())
    }
}