pyo3_derive_backend/
method.rs

1// Copyright (c) 2017-present PyO3 Project and Contributors
2
3use crate::pyfunction::Argument;
4use crate::pyfunction::{parse_name_attribute, PyFunctionAttr};
5use crate::utils;
6use proc_macro2::TokenStream;
7use quote::ToTokens;
8use quote::{quote, quote_spanned};
9use syn::ext::IdentExt;
10use syn::spanned::Spanned;
11
12#[derive(Clone, PartialEq, Debug)]
13pub struct FnArg<'a> {
14    pub name: &'a syn::Ident,
15    pub by_ref: &'a Option<syn::token::Ref>,
16    pub mutability: &'a Option<syn::token::Mut>,
17    pub ty: &'a syn::Type,
18    pub optional: Option<&'a syn::Type>,
19    pub py: bool,
20}
21
/// The method-kind attribute found on a `#[pymethods]` function, if any.
///
/// At most one of these may be present on a method (see `parse_method_attributes`).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum MethodTypeAttribute {
    /// `#[new]` (or `#[__new__]`): the Python constructor.
    New,
    /// `#[call]` (or `#[__call__]`): makes instances callable.
    Call,
    /// `#[classmethod]`: first argument is the class object.
    ClassMethod,
    /// `#[classattr]`: an argument-less method evaluated as a class attribute.
    ClassAttribute,
    /// `#[staticmethod]`: no receiver.
    StaticMethod,
    /// `#[getter]`: property read accessor.
    Getter,
    /// `#[setter]`: property write accessor.
    Setter,
}
39
40#[derive(Clone, Debug)]
41pub enum FnType {
42    Getter(SelfType),
43    Setter(SelfType),
44    Fn(SelfType),
45    FnCall(SelfType),
46    FnNew,
47    FnClass,
48    FnStatic,
49    ClassAttribute,
50}
51
52#[derive(Clone, Debug)]
53pub enum SelfType {
54    Receiver { mutable: bool },
55    TryFromPyCell(proc_macro2::Span),
56}
57
58impl SelfType {
59    pub fn receiver(&self, cls: &syn::Type) -> TokenStream {
60        match self {
61            SelfType::Receiver { mutable: false } => {
62                quote! {
63                    let _cell = _py.from_borrowed_ptr::<pyo3::PyCell<#cls>>(_slf);
64                    let _ref = _cell.try_borrow()?;
65                    let _slf = &_ref;
66                }
67            }
68            SelfType::Receiver { mutable: true } => {
69                quote! {
70                    let _cell = _py.from_borrowed_ptr::<pyo3::PyCell<#cls>>(_slf);
71                    let mut _ref = _cell.try_borrow_mut()?;
72                    let _slf = &mut _ref;
73                }
74            }
75            SelfType::TryFromPyCell(span) => {
76                quote_spanned! { *span =>
77                    let _cell = _py.from_borrowed_ptr::<pyo3::PyCell<#cls>>(_slf);
78                    #[allow(clippy::useless_conversion)]  // In case _slf is PyCell<Self>
79                    let _slf = std::convert::TryFrom::try_from(_cell)?;
80                }
81            }
82        }
83    }
84}
85
86#[derive(Clone, Debug)]
87pub struct FnSpec<'a> {
88    pub tp: FnType,
89    // Rust function name
90    pub name: &'a syn::Ident,
91    // Wrapped python name. This should not have any leading r#.
92    // r# can be removed by syn::ext::IdentExt::unraw()
93    pub python_name: syn::Ident,
94    pub attrs: Vec<Argument>,
95    pub args: Vec<FnArg<'a>>,
96    pub output: syn::Type,
97    pub doc: syn::LitStr,
98}
99
100pub fn get_return_info(output: &syn::ReturnType) -> syn::Type {
101    match output {
102        syn::ReturnType::Default => syn::Type::Infer(syn::parse_quote! {_}),
103        syn::ReturnType::Type(_, ref ty) => *ty.clone(),
104    }
105}
106
107pub fn parse_method_receiver(arg: &syn::FnArg) -> syn::Result<SelfType> {
108    match arg {
109        syn::FnArg::Receiver(recv) => Ok(SelfType::Receiver {
110            mutable: recv.mutability.is_some(),
111        }),
112        syn::FnArg::Typed(syn::PatType { ref ty, .. }) => Ok(SelfType::TryFromPyCell(ty.span())),
113    }
114}
115
116impl<'a> FnSpec<'a> {
117    /// Parser function signature and function attributes
118    pub fn parse(
119        sig: &'a syn::Signature,
120        meth_attrs: &mut Vec<syn::Attribute>,
121        allow_custom_name: bool,
122    ) -> syn::Result<FnSpec<'a>> {
123        let name = &sig.ident;
124        let MethodAttributes {
125            ty: fn_type_attr,
126            args: fn_attrs,
127            mut python_name,
128        } = parse_method_attributes(meth_attrs, allow_custom_name)?;
129
130        let mut arguments = Vec::new();
131        let mut inputs_iter = sig.inputs.iter();
132
133        let mut parse_receiver = |msg: &'static str| {
134            inputs_iter
135                .next()
136                .ok_or_else(|| syn::Error::new_spanned(sig, msg))
137                .and_then(parse_method_receiver)
138        };
139
140        // strip get_ or set_
141        let strip_fn_name = |prefix: &'static str| {
142            let ident = sig.ident.unraw().to_string();
143            if ident.starts_with(prefix) {
144                Some(syn::Ident::new(&ident[prefix.len()..], ident.span()))
145            } else {
146                None
147            }
148        };
149
150        // Parse receiver & function type for various method types
151        let fn_type = match fn_type_attr {
152            Some(MethodTypeAttribute::StaticMethod) => FnType::FnStatic,
153            Some(MethodTypeAttribute::ClassAttribute) => {
154                if !sig.inputs.is_empty() {
155                    return Err(syn::Error::new_spanned(
156                        name,
157                        "Class attribute methods cannot take arguments",
158                    ));
159                }
160                FnType::ClassAttribute
161            }
162            Some(MethodTypeAttribute::New) => FnType::FnNew,
163            Some(MethodTypeAttribute::ClassMethod) => {
164                // Skip first argument for classmethod - always &PyType
165                let _ = inputs_iter.next();
166                FnType::FnClass
167            }
168            Some(MethodTypeAttribute::Call) => {
169                FnType::FnCall(parse_receiver("expected receiver for #[call]")?)
170            }
171            Some(MethodTypeAttribute::Getter) => {
172                // Strip off "get_" prefix if needed
173                if python_name.is_none() {
174                    python_name = strip_fn_name("get_");
175                }
176
177                FnType::Getter(parse_receiver("expected receiver for #[getter]")?)
178            }
179            Some(MethodTypeAttribute::Setter) => {
180                // Strip off "set_" prefix if needed
181                if python_name.is_none() {
182                    python_name = strip_fn_name("set_");
183                }
184
185                FnType::Setter(parse_receiver("expected receiver for #[setter]")?)
186            }
187            None => FnType::Fn(parse_receiver(
188                "Static method needs #[staticmethod] attribute",
189            )?),
190        };
191
192        // parse rest of arguments
193        for input in inputs_iter {
194            match input {
195                syn::FnArg::Receiver(recv) => {
196                    return Err(syn::Error::new_spanned(
197                        recv,
198                        "Unexpected receiver for method",
199                    ));
200                }
201                syn::FnArg::Typed(syn::PatType {
202                    ref pat, ref ty, ..
203                }) => {
204                    let (ident, by_ref, mutability) = match **pat {
205                        syn::Pat::Ident(syn::PatIdent {
206                            ref ident,
207                            ref by_ref,
208                            ref mutability,
209                            ..
210                        }) => (ident, by_ref, mutability),
211                        _ => {
212                            return Err(syn::Error::new_spanned(pat, "unsupported argument"));
213                        }
214                    };
215
216                    arguments.push(FnArg {
217                        name: ident,
218                        by_ref,
219                        mutability,
220                        ty,
221                        optional: utils::option_type_argument(ty),
222                        py: utils::is_python(ty),
223                    });
224                }
225            }
226        }
227
228        let ty = get_return_info(&sig.output);
229        let python_name = python_name.unwrap_or_else(|| name.unraw());
230
231        let mut parse_erroneous_text_signature = |error_msg: &str| {
232            // try to parse anyway to give better error messages
233            if let Some(text_signature) =
234                utils::parse_text_signature_attrs(meth_attrs, &python_name)?
235            {
236                Err(syn::Error::new_spanned(text_signature, error_msg))
237            } else {
238                Ok(None)
239            }
240        };
241
242        let text_signature = match &fn_type {
243            FnType::Fn(_) | FnType::FnClass | FnType::FnStatic => {
244                utils::parse_text_signature_attrs(&mut *meth_attrs, name)?
245            }
246            FnType::FnNew => parse_erroneous_text_signature(
247                "text_signature not allowed on __new__; if you want to add a signature on \
248                 __new__, put it on the struct definition instead",
249            )?,
250            FnType::FnCall(_) | FnType::Getter(_) | FnType::Setter(_) | FnType::ClassAttribute => {
251                parse_erroneous_text_signature("text_signature not allowed with this attribute")?
252            }
253        };
254
255        let doc = utils::get_doc(&meth_attrs, text_signature, true)?;
256
257        Ok(FnSpec {
258            tp: fn_type,
259            name,
260            python_name,
261            attrs: fn_attrs,
262            args: arguments,
263            output: ty,
264            doc,
265        })
266    }
267
268    pub fn is_args(&self, name: &syn::Ident) -> bool {
269        for s in self.attrs.iter() {
270            if let Argument::VarArgs(ref path) = s {
271                return path.is_ident(name);
272            }
273        }
274        false
275    }
276
277    pub fn is_kwargs(&self, name: &syn::Ident) -> bool {
278        for s in self.attrs.iter() {
279            if let Argument::KeywordArgs(ref path) = s {
280                return path.is_ident(name);
281            }
282        }
283        false
284    }
285
286    pub fn default_value(&self, name: &syn::Ident) -> Option<TokenStream> {
287        for s in self.attrs.iter() {
288            match *s {
289                Argument::Arg(ref path, ref opt) | Argument::Kwarg(ref path, ref opt) => {
290                    if path.is_ident(name) {
291                        if let Some(ref val) = opt {
292                            let i: syn::Expr = syn::parse_str(&val).unwrap();
293                            return Some(i.into_token_stream());
294                        }
295                    }
296                }
297                _ => (),
298            }
299        }
300        None
301    }
302
303    pub fn is_kw_only(&self, name: &syn::Ident) -> bool {
304        for s in self.attrs.iter() {
305            if let Argument::Kwarg(ref path, _) = s {
306                if path.is_ident(name) {
307                    return true;
308                }
309            }
310        }
311        false
312    }
313}
314
315#[derive(Clone, PartialEq, Debug)]
316struct MethodAttributes {
317    ty: Option<MethodTypeAttribute>,
318    args: Vec<Argument>,
319    python_name: Option<syn::Ident>,
320}
321
322fn parse_method_attributes(
323    attrs: &mut Vec<syn::Attribute>,
324    allow_custom_name: bool,
325) -> syn::Result<MethodAttributes> {
326    let mut new_attrs = Vec::new();
327    let mut args = Vec::new();
328    let mut ty: Option<MethodTypeAttribute> = None;
329    let mut property_name = None;
330
331    macro_rules! set_ty {
332        ($new_ty:expr, $ident:expr) => {
333            if ty.replace($new_ty).is_some() {
334                return Err(syn::Error::new_spanned(
335                    $ident,
336                    "Cannot specify a second method type",
337                ));
338            }
339        };
340    }
341
342    for attr in attrs.iter() {
343        match attr.parse_meta()? {
344            syn::Meta::Path(ref name) => {
345                if name.is_ident("new") || name.is_ident("__new__") {
346                    set_ty!(MethodTypeAttribute::New, name);
347                } else if name.is_ident("init") || name.is_ident("__init__") {
348                    return Err(syn::Error::new_spanned(
349                        name,
350                        "#[init] is disabled since PyO3 0.9.0",
351                    ));
352                } else if name.is_ident("call") || name.is_ident("__call__") {
353                    set_ty!(MethodTypeAttribute::Call, name);
354                } else if name.is_ident("classmethod") {
355                    set_ty!(MethodTypeAttribute::ClassMethod, name);
356                } else if name.is_ident("staticmethod") {
357                    set_ty!(MethodTypeAttribute::StaticMethod, name);
358                } else if name.is_ident("classattr") {
359                    set_ty!(MethodTypeAttribute::ClassAttribute, name);
360                } else if name.is_ident("setter") || name.is_ident("getter") {
361                    if let syn::AttrStyle::Inner(_) = attr.style {
362                        return Err(syn::Error::new_spanned(
363                            attr,
364                            "Inner style attribute is not supported for setter and getter",
365                        ));
366                    }
367                    if name.is_ident("setter") {
368                        set_ty!(MethodTypeAttribute::Setter, name);
369                    } else {
370                        set_ty!(MethodTypeAttribute::Getter, name);
371                    }
372                } else {
373                    new_attrs.push(attr.clone())
374                }
375            }
376            syn::Meta::List(syn::MetaList {
377                ref path,
378                ref nested,
379                ..
380            }) => {
381                if path.is_ident("new") {
382                    set_ty!(MethodTypeAttribute::New, path);
383                } else if path.is_ident("init") {
384                    return Err(syn::Error::new_spanned(
385                        path,
386                        "#[init] is disabled since PyO3 0.9.0",
387                    ));
388                } else if path.is_ident("call") {
389                    set_ty!(MethodTypeAttribute::Call, path);
390                } else if path.is_ident("setter") || path.is_ident("getter") {
391                    if let syn::AttrStyle::Inner(_) = attr.style {
392                        return Err(syn::Error::new_spanned(
393                            attr,
394                            "Inner style attribute is not supported for setter and getter",
395                        ));
396                    }
397                    if nested.len() != 1 {
398                        return Err(syn::Error::new_spanned(
399                            attr,
400                            "setter/getter requires one value",
401                        ));
402                    }
403
404                    if path.is_ident("setter") {
405                        set_ty!(MethodTypeAttribute::Setter, path);
406                    } else {
407                        set_ty!(MethodTypeAttribute::Getter, path);
408                    };
409
410                    property_name = match nested.first().unwrap() {
411                        syn::NestedMeta::Meta(syn::Meta::Path(ref w)) if w.segments.len() == 1 => {
412                            Some(w.segments[0].ident.clone())
413                        }
414                        syn::NestedMeta::Lit(ref lit) => match *lit {
415                            syn::Lit::Str(ref s) => Some(s.parse()?),
416                            _ => {
417                                return Err(syn::Error::new_spanned(
418                                    lit,
419                                    "setter/getter attribute requires str value",
420                                ))
421                            }
422                        },
423                        _ => {
424                            return Err(syn::Error::new_spanned(
425                                nested.first().unwrap(),
426                                "expected ident or string literal for property name",
427                            ))
428                        }
429                    };
430                } else if path.is_ident("args") {
431                    let attrs = PyFunctionAttr::from_meta(nested)?;
432                    args.extend(attrs.arguments)
433                } else {
434                    new_attrs.push(attr.clone())
435                }
436            }
437            syn::Meta::NameValue(_) => new_attrs.push(attr.clone()),
438        }
439    }
440
441    attrs.clear();
442    attrs.extend(new_attrs);
443
444    let python_name = if allow_custom_name {
445        parse_method_name_attribute(ty.as_ref(), attrs, property_name)?
446    } else {
447        property_name
448    };
449
450    Ok(MethodAttributes {
451        ty,
452        args,
453        python_name,
454    })
455}
456
457fn parse_method_name_attribute(
458    ty: Option<&MethodTypeAttribute>,
459    attrs: &mut Vec<syn::Attribute>,
460    property_name: Option<syn::Ident>,
461) -> syn::Result<Option<syn::Ident>> {
462    use MethodTypeAttribute::*;
463    let name = parse_name_attribute(attrs)?;
464
465    // Reject some invalid combinations
466    if let (Some(name), Some(ty)) = (&name, ty) {
467        match ty {
468            New | Call | Getter | Setter => {
469                return Err(syn::Error::new_spanned(
470                    name,
471                    "name not allowed with this method type",
472                ))
473            }
474            _ => {}
475        }
476    }
477
478    // Thanks to check above we can be sure that this generates the right python name
479    Ok(match ty {
480        Some(New) => Some(syn::Ident::new("__new__", proc_macro2::Span::call_site())),
481        Some(Call) => Some(syn::Ident::new("__call__", proc_macro2::Span::call_site())),
482        Some(Getter) | Some(Setter) => property_name,
483        _ => name,
484    })
485}