pyo3_derive_backend/
pyfunction.rs

1// Copyright (c) 2017-present PyO3 Project and Contributors
2
3use crate::module::add_fn_to_module;
4use proc_macro2::TokenStream;
5use syn::ext::IdentExt;
6use syn::parse::ParseBuffer;
7use syn::punctuated::Punctuated;
8use syn::spanned::Spanned;
9use syn::{NestedMeta, Path};
10
11#[derive(Debug, Clone, PartialEq)]
12pub enum Argument {
13    VarArgsSeparator,
14    VarArgs(syn::Path),
15    KeywordArgs(syn::Path),
16    Arg(syn::Path, Option<String>),
17    Kwarg(syn::Path, Option<String>),
18}
19
20/// The attributes of the pyfunction macro
21#[derive(Default)]
22pub struct PyFunctionAttr {
23    pub arguments: Vec<Argument>,
24    has_kw: bool,
25    has_varargs: bool,
26    has_kwargs: bool,
27    pub pass_module: bool,
28}
29
30impl syn::parse::Parse for PyFunctionAttr {
31    fn parse(input: &ParseBuffer) -> syn::Result<Self> {
32        let attr = Punctuated::<NestedMeta, syn::Token![,]>::parse_terminated(input)?;
33        Self::from_meta(&attr)
34    }
35}
36
37impl PyFunctionAttr {
38    pub fn from_meta<'a>(iter: impl IntoIterator<Item = &'a NestedMeta>) -> syn::Result<Self> {
39        let mut slf = PyFunctionAttr::default();
40
41        for item in iter {
42            slf.add_item(item)?
43        }
44        Ok(slf)
45    }
46
47    pub fn add_item(&mut self, item: &NestedMeta) -> syn::Result<()> {
48        match item {
49            NestedMeta::Meta(syn::Meta::Path(ref ident)) if ident.is_ident("pass_module") => {
50                self.pass_module = true;
51            }
52            NestedMeta::Meta(syn::Meta::Path(ref ident)) => self.add_work(item, ident)?,
53            NestedMeta::Meta(syn::Meta::NameValue(ref nv)) => {
54                self.add_name_value(item, nv)?;
55            }
56            NestedMeta::Lit(ref lit) => {
57                self.add_literal(item, lit)?;
58            }
59            NestedMeta::Meta(syn::Meta::List(ref list)) => {
60                return Err(syn::Error::new_spanned(
61                    list,
62                    "List is not supported as argument",
63                ));
64            }
65        }
66        Ok(())
67    }
68
69    fn add_literal(&mut self, item: &NestedMeta, lit: &syn::Lit) -> syn::Result<()> {
70        match lit {
71            syn::Lit::Str(ref lits) if lits.value() == "*" => {
72                // "*"
73                self.vararg_is_ok(item)?;
74                self.has_varargs = true;
75                self.arguments.push(Argument::VarArgsSeparator);
76                Ok(())
77            }
78            _ => Err(syn::Error::new_spanned(
79                item,
80                format!("Only \"*\" is supported here, got: {:?}", lit),
81            )),
82        }
83    }
84
85    fn add_work(&mut self, item: &NestedMeta, path: &Path) -> syn::Result<()> {
86        if self.has_kw || self.has_kwargs {
87            return Err(syn::Error::new_spanned(
88                item,
89                "Positional argument or varargs(*) is not allowed after keyword arguments",
90            ));
91        }
92        if self.has_varargs {
93            self.arguments.push(Argument::Kwarg(path.clone(), None));
94        } else {
95            self.arguments.push(Argument::Arg(path.clone(), None));
96        }
97        Ok(())
98    }
99
100    fn vararg_is_ok(&self, item: &NestedMeta) -> syn::Result<()> {
101        if self.has_kwargs || self.has_varargs {
102            return Err(syn::Error::new_spanned(
103                item,
104                "* is not allowed after varargs(*) or kwargs(**)",
105            ));
106        }
107        Ok(())
108    }
109
110    fn kw_arg_is_ok(&self, item: &NestedMeta) -> syn::Result<()> {
111        if self.has_kwargs {
112            return Err(syn::Error::new_spanned(
113                item,
114                "Keyword argument or kwargs(**) is not allowed after kwargs(**)",
115            ));
116        }
117        Ok(())
118    }
119
120    fn add_nv_common(
121        &mut self,
122        item: &NestedMeta,
123        name: &syn::Path,
124        value: String,
125    ) -> syn::Result<()> {
126        self.kw_arg_is_ok(item)?;
127        if self.has_varargs {
128            // kw only
129            self.arguments
130                .push(Argument::Kwarg(name.clone(), Some(value)));
131        } else {
132            self.has_kw = true;
133            self.arguments
134                .push(Argument::Arg(name.clone(), Some(value)));
135        }
136        Ok(())
137    }
138
139    fn add_name_value(&mut self, item: &NestedMeta, nv: &syn::MetaNameValue) -> syn::Result<()> {
140        match nv.lit {
141            syn::Lit::Str(ref litstr) => {
142                if litstr.value() == "*" {
143                    // args="*"
144                    self.vararg_is_ok(item)?;
145                    self.has_varargs = true;
146                    self.arguments.push(Argument::VarArgs(nv.path.clone()));
147                } else if litstr.value() == "**" {
148                    // kwargs="**"
149                    self.kw_arg_is_ok(item)?;
150                    self.has_kwargs = true;
151                    self.arguments.push(Argument::KeywordArgs(nv.path.clone()));
152                } else {
153                    self.add_nv_common(item, &nv.path, litstr.value())?;
154                }
155            }
156            syn::Lit::Int(ref litint) => {
157                self.add_nv_common(item, &nv.path, format!("{}", litint))?;
158            }
159            syn::Lit::Bool(ref litb) => {
160                self.add_nv_common(item, &nv.path, format!("{}", litb.value))?;
161            }
162            _ => {
163                return Err(syn::Error::new_spanned(
164                    nv.lit.clone(),
165                    "Only string literal is supported",
166                ));
167            }
168        };
169        Ok(())
170    }
171}
172
173pub fn parse_name_attribute(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Option<syn::Ident>> {
174    let mut name_attrs = Vec::new();
175
176    // Using retain will extract all name attributes from the attribute list
177    attrs.retain(|attr| match attr.parse_meta() {
178        Ok(syn::Meta::NameValue(ref nv)) if nv.path.is_ident("name") => {
179            name_attrs.push((nv.lit.clone(), attr.span()));
180            false
181        }
182        _ => true,
183    });
184
185    if 1 < name_attrs.len() {
186        return Err(syn::Error::new(
187            name_attrs[0].1,
188            "#[name] can not be specified multiple times",
189        ));
190    }
191
192    match name_attrs.get(0) {
193        Some((syn::Lit::Str(s), span)) => {
194            let mut ident: syn::Ident = s.parse()?;
195            // This span is the whole attribute span, which is nicer for reporting errors.
196            ident.set_span(*span);
197            Ok(Some(ident))
198        }
199        Some((_, span)) => Err(syn::Error::new(
200            *span,
201            "Expected string literal for #[name] argument",
202        )),
203        None => Ok(None),
204    }
205}
206
207pub fn build_py_function(ast: &mut syn::ItemFn, args: PyFunctionAttr) -> syn::Result<TokenStream> {
208    let python_name =
209        parse_name_attribute(&mut ast.attrs)?.unwrap_or_else(|| ast.sig.ident.unraw());
210    add_fn_to_module(ast, python_name, args)
211}
212
#[cfg(test)]
mod test {
    use super::{Argument, PyFunctionAttr};
    use proc_macro2::TokenStream;
    use quote::quote;
    use syn::parse_quote;

    /// Parses `input` as pyfunction attribute arguments and returns the
    /// resulting argument list.
    fn items(input: TokenStream) -> syn::Result<Vec<Argument>> {
        syn::parse2::<PyFunctionAttr>(input).map(|attr| attr.arguments)
    }

    #[test]
    fn test_errs() {
        // Bare argument after a defaulted one.
        assert!(items(quote! {test="1", test2}).is_err());
        // Two `*` markers.
        assert!(items(quote! {test, "*", args="*"}).is_err());
        // `*` marker after `**`.
        assert!(items(quote! {test, kwargs="**", args="*"}).is_err());
        // Bare argument after `**`.
        assert!(items(quote! {test, kwargs="**", args}).is_err());
    }

    #[test]
    fn test_unsupported_items() {
        // Nested lists are not accepted.
        assert!(items(quote! {test(inner)}).is_err());
        // The only bare literal accepted is "*".
        assert!(items(quote! {"**"}).is_err());
        // Float literals are not accepted as default values.
        assert!(items(quote! {test=1.5}).is_err());
    }

    #[test]
    fn test_simple_args() {
        let args = items(quote! {test1, test2, test3="None"}).unwrap();
        assert_eq!(
            args,
            vec![
                Argument::Arg(parse_quote! {test1}, None),
                Argument::Arg(parse_quote! {test2}, None),
                Argument::Arg(parse_quote! {test3}, Some("None".to_owned())),
            ]
        );
    }

    #[test]
    fn test_literal_defaults() {
        // Integer and boolean literals are stored as their textual form.
        let args = items(quote! {test1=1, test2=true}).unwrap();
        assert_eq!(
            args,
            vec![
                Argument::Arg(parse_quote! {test1}, Some("1".to_owned())),
                Argument::Arg(parse_quote! {test2}, Some("true".to_owned())),
            ]
        );
    }

    #[test]
    fn test_pass_module() {
        let attr: PyFunctionAttr = syn::parse2(quote! {pass_module}).unwrap();
        assert!(attr.pass_module);
        assert!(attr.arguments.is_empty());
    }

    #[test]
    fn test_varargs() {
        let args = items(quote! {test1, test2="None", "*", test3="None"}).unwrap();
        assert_eq!(
            args,
            vec![
                Argument::Arg(parse_quote! {test1}, None),
                Argument::Arg(parse_quote! {test2}, Some("None".to_owned())),
                Argument::VarArgsSeparator,
                Argument::Kwarg(parse_quote! {test3}, Some("None".to_owned())),
            ]
        );

        let args = items(quote! {"*", test1, test2}).unwrap();
        assert_eq!(
            args,
            vec![
                Argument::VarArgsSeparator,
                Argument::Kwarg(parse_quote! {test1}, None),
                Argument::Kwarg(parse_quote! {test2}, None),
            ]
        );

        let args = items(quote! {"*", test1, test2="None"}).unwrap();
        assert_eq!(
            args,
            vec![
                Argument::VarArgsSeparator,
                Argument::Kwarg(parse_quote! {test1}, None),
                Argument::Kwarg(parse_quote! {test2}, Some("None".to_owned())),
            ]
        );

        let args = items(quote! {"*", test1="None", test2}).unwrap();
        assert_eq!(
            args,
            vec![
                Argument::VarArgsSeparator,
                Argument::Kwarg(parse_quote! {test1}, Some("None".to_owned())),
                Argument::Kwarg(parse_quote! {test2}, None),
            ]
        );
    }

    #[test]
    fn test_all() {
        let args =
            items(quote! {test1, test2="None", args="*", test3="None", kwargs="**"}).unwrap();
        assert_eq!(
            args,
            vec![
                Argument::Arg(parse_quote! {test1}, None),
                Argument::Arg(parse_quote! {test2}, Some("None".to_owned())),
                Argument::VarArgs(parse_quote! {args}),
                Argument::Kwarg(parse_quote! {test3}, Some("None".to_owned())),
                Argument::KeywordArgs(parse_quote! {kwargs}),
            ]
        );
    }
}