pyo3_macros_backend/
utils.rs

1use crate::attributes::{CrateAttribute, RenamingRule};
2use proc_macro2::{Span, TokenStream};
3use quote::{quote, quote_spanned, ToTokens};
4use std::ffi::CString;
5use syn::spanned::Spanned;
6use syn::visit_mut::VisitMut;
7use syn::{punctuated::Punctuated, Token};
8
/// Builds a [`syn::Error`] located at the given span, in the spirit of `anyhow::anyhow!`.
///
/// Usage: `err_spanned!(span => message)`.
macro_rules! err_spanned {
    ($at:expr => $message:expr) => {
        syn::Error::new($at, $message)
    };
}
15
/// Returns early with `Err` holding a [`syn::Error`] at the given span, in the spirit of
/// `anyhow::bail!`.
///
/// Usage: `bail_spanned!(span => message)`. The expansion has no trailing semicolon, so it can
/// be used both as a statement and in expression position (e.g. a `match` arm).
macro_rules! bail_spanned {
    ($at:expr => $message:expr) => {
        return Err(err_spanned!($at => $message))
    };
}
22
/// Returns a spanned compiler error unless a condition holds, in the spirit of `anyhow::ensure!`.
///
/// Single form: `ensure_spanned!(cond, span => msg)` bails on the first failure.
///
/// Multi form: `ensure_spanned!(cond1, span1 => msg1; cond2, span2 => msg2; ...)` evaluates every
/// condition in order. It collects an error for each one that fails and returns all of them
/// together, combined into a single [`syn::Error`] (first failure first).
macro_rules! ensure_spanned {
    ($condition:expr, $span:expr => $msg:expr) => {
        if !($condition) {
            bail_spanned!($span => $msg);
        }
    };
    ($($condition:expr, $span:expr => $msg:expr;)*) => {{
        let mut combined: Option<syn::Error> = None;
        $(
            if !($condition) {
                let error = err_spanned!($span => $msg);
                // Append to the errors gathered so far, or start the chain.
                combined = Some(match combined.take() {
                    Some(mut earlier) => {
                        earlier.combine(error);
                        earlier
                    }
                    None => error,
                });
            }
        )*
        if let Some(error) = combined {
            return Err(error);
        }
    }};
}
45
46/// Check if the given type `ty` is `pyo3::Python`.
47pub fn is_python(ty: &syn::Type) -> bool {
48    match unwrap_ty_group(ty) {
49        syn::Type::Path(typath) => typath
50            .path
51            .segments
52            .last()
53            .map(|seg| seg.ident == "Python")
54            .unwrap_or(false),
55        _ => false,
56    }
57}
58
59/// If `ty` is `Option<T>`, return `Some(T)`, else `None`.
60pub fn option_type_argument(ty: &syn::Type) -> Option<&syn::Type> {
61    if let syn::Type::Path(syn::TypePath { path, .. }) = ty {
62        let seg = path.segments.last().filter(|s| s.ident == "Option")?;
63        if let syn::PathArguments::AngleBracketed(params) = &seg.arguments {
64            if let syn::GenericArgument::Type(ty) = params.args.first()? {
65                return Some(ty);
66            }
67        }
68    }
69    None
70}
71
72// TODO: Replace usage of this by [`syn::LitCStr`] when on MSRV 1.77
73#[derive(Clone)]
74pub struct LitCStr {
75    lit: CString,
76    span: Span,
77    pyo3_path: PyO3CratePath,
78}
79
80impl LitCStr {
81    pub fn new(lit: CString, span: Span, ctx: &Ctx) -> Self {
82        Self {
83            lit,
84            span,
85            pyo3_path: ctx.pyo3_path.clone(),
86        }
87    }
88
89    pub fn empty(ctx: &Ctx) -> Self {
90        Self {
91            lit: CString::new("").unwrap(),
92            span: Span::call_site(),
93            pyo3_path: ctx.pyo3_path.clone(),
94        }
95    }
96}
97
98impl quote::ToTokens for LitCStr {
99    fn to_tokens(&self, tokens: &mut TokenStream) {
100        if cfg!(c_str_lit) {
101            syn::LitCStr::new(&self.lit, self.span).to_tokens(tokens);
102        } else {
103            let pyo3_path = &self.pyo3_path;
104            let lit = self.lit.to_str().unwrap();
105            tokens.extend(quote::quote_spanned!(self.span => #pyo3_path::ffi::c_str!(#lit)));
106        }
107    }
108}
109
110/// A syntax tree which evaluates to a nul-terminated docstring for Python.
111///
112/// Typically the tokens will just be that string, but if the original docs included macro
113/// expressions then the tokens will be a concat!("...", "\n", "\0") expression of the strings and
114/// macro parts. contents such as parse the string contents.
115#[derive(Clone)]
116pub struct PythonDoc(PythonDocKind);
117
118#[derive(Clone)]
119enum PythonDocKind {
120    LitCStr(LitCStr),
121    // There is currently no way to `concat!` c-string literals, we fallback to the `c_str!` macro in
122    // this case.
123    Tokens(TokenStream),
124}
125
126/// Collects all #[doc = "..."] attributes into a TokenStream evaluating to a null-terminated string.
127///
128/// If this doc is for a callable, the provided `text_signature` can be passed to prepend
129/// this to the documentation suitable for Python to extract this into the `__text_signature__`
130/// attribute.
131pub fn get_doc(
132    attrs: &[syn::Attribute],
133    mut text_signature: Option<String>,
134    ctx: &Ctx,
135) -> syn::Result<PythonDoc> {
136    let Ctx { pyo3_path, .. } = ctx;
137    // insert special divider between `__text_signature__` and doc
138    // (assume text_signature is itself well-formed)
139    if let Some(text_signature) = &mut text_signature {
140        text_signature.push_str("\n--\n\n");
141    }
142
143    let mut parts = Punctuated::<TokenStream, Token![,]>::new();
144    let mut first = true;
145    let mut current_part = text_signature.unwrap_or_default();
146    let mut current_part_span = None;
147
148    for attr in attrs {
149        if attr.path().is_ident("doc") {
150            if let Ok(nv) = attr.meta.require_name_value() {
151                current_part_span = match current_part_span {
152                    None => Some(nv.value.span()),
153                    Some(span) => span.join(nv.value.span()),
154                };
155                if !first {
156                    current_part.push('\n');
157                } else {
158                    first = false;
159                }
160                if let syn::Expr::Lit(syn::ExprLit {
161                    lit: syn::Lit::Str(lit_str),
162                    ..
163                }) = &nv.value
164                {
165                    // Strip single left space from literal strings, if needed.
166                    // e.g. `/// Hello world` expands to #[doc = " Hello world"]
167                    let doc_line = lit_str.value();
168                    current_part.push_str(doc_line.strip_prefix(' ').unwrap_or(&doc_line));
169                } else {
170                    // This is probably a macro doc from Rust 1.54, e.g. #[doc = include_str!(...)]
171                    // Reset the string buffer, write that part, and then push this macro part too.
172                    parts.push(quote_spanned!(current_part_span.unwrap_or(Span::call_site()) => #current_part));
173                    current_part.clear();
174                    parts.push(nv.value.to_token_stream());
175                }
176            }
177        }
178    }
179
180    if !parts.is_empty() {
181        // Doc contained macro pieces - return as `concat!` expression
182        if !current_part.is_empty() {
183            parts.push(
184                quote_spanned!(current_part_span.unwrap_or(Span::call_site()) => #current_part),
185            );
186        }
187
188        let mut tokens = TokenStream::new();
189
190        syn::Ident::new("concat", Span::call_site()).to_tokens(&mut tokens);
191        syn::token::Not(Span::call_site()).to_tokens(&mut tokens);
192        syn::token::Bracket(Span::call_site()).surround(&mut tokens, |tokens| {
193            parts.to_tokens(tokens);
194            syn::token::Comma(Span::call_site()).to_tokens(tokens);
195        });
196
197        Ok(PythonDoc(PythonDocKind::Tokens(
198            quote!(#pyo3_path::ffi::c_str!(#tokens)),
199        )))
200    } else {
201        // Just a string doc - return directly with nul terminator
202        let docs = CString::new(current_part).map_err(|e| {
203            syn::Error::new(
204                current_part_span.unwrap_or(Span::call_site()),
205                format!(
206                    "Python doc may not contain nul byte, found nul at position {}",
207                    e.nul_position()
208                ),
209            )
210        })?;
211        Ok(PythonDoc(PythonDocKind::LitCStr(LitCStr::new(
212            docs,
213            current_part_span.unwrap_or(Span::call_site()),
214            ctx,
215        ))))
216    }
217}
218
219impl quote::ToTokens for PythonDoc {
220    fn to_tokens(&self, tokens: &mut TokenStream) {
221        match &self.0 {
222            PythonDocKind::LitCStr(lit) => lit.to_tokens(tokens),
223            PythonDocKind::Tokens(toks) => toks.to_tokens(tokens),
224        }
225    }
226}
227
228pub fn unwrap_ty_group(mut ty: &syn::Type) -> &syn::Type {
229    while let syn::Type::Group(g) = ty {
230        ty = &*g.elem;
231    }
232    ty
233}
234
235pub struct Ctx {
236    /// Where we can find the pyo3 crate
237    pub pyo3_path: PyO3CratePath,
238
239    /// If we are in a pymethod or pyfunction,
240    /// this will be the span of the return type
241    pub output_span: Span,
242}
243
244impl Ctx {
245    pub(crate) fn new(attr: &Option<CrateAttribute>, signature: Option<&syn::Signature>) -> Self {
246        let pyo3_path = match attr {
247            Some(attr) => PyO3CratePath::Given(attr.value.0.clone()),
248            None => PyO3CratePath::Default,
249        };
250
251        let output_span = if let Some(syn::Signature {
252            output: syn::ReturnType::Type(_, output_type),
253            ..
254        }) = &signature
255        {
256            output_type.span()
257        } else {
258            Span::call_site()
259        };
260
261        Self {
262            pyo3_path,
263            output_span,
264        }
265    }
266}
267
268#[derive(Clone)]
269pub enum PyO3CratePath {
270    Given(syn::Path),
271    Default,
272}
273
274impl PyO3CratePath {
275    pub fn to_tokens_spanned(&self, span: Span) -> TokenStream {
276        match self {
277            Self::Given(path) => quote::quote_spanned! { span => #path },
278            Self::Default => quote::quote_spanned! {  span => ::pyo3 },
279        }
280    }
281}
282
283impl quote::ToTokens for PyO3CratePath {
284    fn to_tokens(&self, tokens: &mut TokenStream) {
285        match self {
286            Self::Given(path) => path.to_tokens(tokens),
287            Self::Default => quote::quote! { ::pyo3 }.to_tokens(tokens),
288        }
289    }
290}
291
292pub fn apply_renaming_rule(rule: RenamingRule, name: &str) -> String {
293    use heck::*;
294
295    match rule {
296        RenamingRule::CamelCase => name.to_lower_camel_case(),
297        RenamingRule::KebabCase => name.to_kebab_case(),
298        RenamingRule::Lowercase => name.to_lowercase(),
299        RenamingRule::PascalCase => name.to_upper_camel_case(),
300        RenamingRule::ScreamingKebabCase => name.to_shouty_kebab_case(),
301        RenamingRule::ScreamingSnakeCase => name.to_shouty_snake_case(),
302        RenamingRule::SnakeCase => name.to_snake_case(),
303        RenamingRule::Uppercase => name.to_uppercase(),
304    }
305}
306
307pub(crate) enum IdentOrStr<'a> {
308    Str(&'a str),
309    Ident(syn::Ident),
310}
311
312pub(crate) fn has_attribute(attrs: &[syn::Attribute], ident: &str) -> bool {
313    has_attribute_with_namespace(attrs, None, &[ident])
314}
315
316pub(crate) fn has_attribute_with_namespace(
317    attrs: &[syn::Attribute],
318    crate_path: Option<&PyO3CratePath>,
319    idents: &[&str],
320) -> bool {
321    let mut segments = vec![];
322    if let Some(c) = crate_path {
323        match c {
324            PyO3CratePath::Given(paths) => {
325                for p in &paths.segments {
326                    segments.push(IdentOrStr::Ident(p.ident.clone()));
327                }
328            }
329            PyO3CratePath::Default => segments.push(IdentOrStr::Str("pyo3")),
330        }
331    };
332    for i in idents {
333        segments.push(IdentOrStr::Str(i));
334    }
335
336    attrs.iter().any(|attr| {
337        segments
338            .iter()
339            .eq(attr.path().segments.iter().map(|v| &v.ident))
340    })
341}
342
/// Extension helpers for syntax-tree types.
pub(crate) trait TypeExt {
    /// Returns `self` with every explicit lifetime rewritten to the elided lifetime `'_`.
    ///
    /// Meant for types spliced into `const` contexts, where explicit lifetimes are not
    /// permitted (yet).
    fn elide_lifetimes(self) -> Self;
}
350
351impl TypeExt for syn::Type {
352    fn elide_lifetimes(mut self) -> Self {
353        struct ElideLifetimesVisitor;
354
355        impl VisitMut for ElideLifetimesVisitor {
356            fn visit_lifetime_mut(&mut self, l: &mut syn::Lifetime) {
357                *l = syn::Lifetime::new("'_", l.span());
358            }
359        }
360
361        ElideLifetimesVisitor.visit_type_mut(&mut self);
362        self
363    }
364}
365
366pub fn expr_to_python(expr: &syn::Expr) -> String {
367    match expr {
368        // literal values
369        syn::Expr::Lit(syn::ExprLit { lit, .. }) => match lit {
370            syn::Lit::Str(s) => s.token().to_string(),
371            syn::Lit::Char(c) => c.token().to_string(),
372            syn::Lit::Int(i) => i.base10_digits().to_string(),
373            syn::Lit::Float(f) => f.base10_digits().to_string(),
374            syn::Lit::Bool(b) => {
375                if b.value() {
376                    "True".to_string()
377                } else {
378                    "False".to_string()
379                }
380            }
381            _ => "...".to_string(),
382        },
383        // None
384        syn::Expr::Path(syn::ExprPath { qself, path, .. })
385            if qself.is_none() && path.is_ident("None") =>
386        {
387            "None".to_string()
388        }
389        // others, unsupported yet so defaults to `...`
390        _ => "...".to_string(),
391    }
392}