pyo3_macros_backend/
utils.rs

1use crate::attributes::{CrateAttribute, RenamingRule};
2use proc_macro2::{Span, TokenStream};
3use quote::{quote, quote_spanned, ToTokens};
4use std::ffi::CString;
5use syn::spanned::Spanned;
6use syn::{punctuated::Punctuated, Token};
7
/// Builds a [`syn::Error`] located at the given span, in the spirit of `anyhow::anyhow!`.
///
/// Usage: `err_spanned!(span => "message")`. The macro only constructs the error value; it
/// does not return it.
macro_rules! err_spanned {
    ($at:expr => $message:expr) => {
        syn::Error::new($at, $message)
    };
}
14
/// Returns early from the enclosing function with a spanned compiler error, in the spirit of
/// `anyhow::bail!`.
///
/// Usage: `bail_spanned!(span => "message")`. The enclosing function must return a `Result`
/// whose error type is [`syn::Error`].
macro_rules! bail_spanned {
    ($at:expr => $message:expr) => {
        return Err(syn::Error::new($at, $message))
    };
}
21
/// Returns a spanned compiler error from the enclosing function when a condition does not hold,
/// in the spirit of `anyhow::ensure!`.
///
/// Two forms are accepted:
///
/// * `ensure_spanned!(cond, span => "message")` returns immediately if `cond` is false.
/// * `ensure_spanned!(cond_a, span_a => "a"; cond_b, span_b => "b";)` evaluates every condition
///   in order, merges the errors of all failed conditions with [`syn::Error::combine`] and
///   returns the merged error if at least one condition failed.
macro_rules! ensure_spanned {
    ($cond:expr, $at:expr => $message:expr) => {
        if !($cond) {
            return Err(syn::Error::new($at, $message));
        }
    };
    ($($cond:expr, $at:expr => $message:expr;)*) => {
        // One slot per condition: `Some(error)` when the condition failed, `None` otherwise.
        if let Some(combined) = [$(
            if !($cond) {
                Some(syn::Error::new($at, $message))
            } else {
                None
            },
        )*]
            .into_iter()
            .flatten()
            .reduce(|mut merged, next| {
                merged.combine(next);
                merged
            })
        {
            return Err(combined);
        }
    };
}
44
45/// Check if the given type `ty` is `pyo3::Python`.
46pub fn is_python(ty: &syn::Type) -> bool {
47    match unwrap_ty_group(ty) {
48        syn::Type::Path(typath) => typath
49            .path
50            .segments
51            .last()
52            .map(|seg| seg.ident == "Python")
53            .unwrap_or(false),
54        _ => false,
55    }
56}
57
58/// If `ty` is `Option<T>`, return `Some(T)`, else `None`.
59pub fn option_type_argument(ty: &syn::Type) -> Option<&syn::Type> {
60    if let syn::Type::Path(syn::TypePath { path, .. }) = ty {
61        let seg = path.segments.last().filter(|s| s.ident == "Option")?;
62        if let syn::PathArguments::AngleBracketed(params) = &seg.arguments {
63            if let syn::GenericArgument::Type(ty) = params.args.first()? {
64                return Some(ty);
65            }
66        }
67    }
68    None
69}
70
71// TODO: Replace usage of this by [`syn::LitCStr`] when on MSRV 1.77
72#[derive(Clone)]
73pub struct LitCStr {
74    lit: CString,
75    span: Span,
76    pyo3_path: PyO3CratePath,
77}
78
79impl LitCStr {
80    pub fn new(lit: CString, span: Span, ctx: &Ctx) -> Self {
81        Self {
82            lit,
83            span,
84            pyo3_path: ctx.pyo3_path.clone(),
85        }
86    }
87
88    pub fn empty(ctx: &Ctx) -> Self {
89        Self {
90            lit: CString::new("").unwrap(),
91            span: Span::call_site(),
92            pyo3_path: ctx.pyo3_path.clone(),
93        }
94    }
95}
96
97impl quote::ToTokens for LitCStr {
98    fn to_tokens(&self, tokens: &mut TokenStream) {
99        if cfg!(c_str_lit) {
100            syn::LitCStr::new(&self.lit, self.span).to_tokens(tokens);
101        } else {
102            let pyo3_path = &self.pyo3_path;
103            let lit = self.lit.to_str().unwrap();
104            tokens.extend(quote::quote_spanned!(self.span => #pyo3_path::ffi::c_str!(#lit)));
105        }
106    }
107}
108
109/// A syntax tree which evaluates to a nul-terminated docstring for Python.
110///
111/// Typically the tokens will just be that string, but if the original docs included macro
112/// expressions then the tokens will be a concat!("...", "\n", "\0") expression of the strings and
113/// macro parts. contents such as parse the string contents.
114#[derive(Clone)]
115pub struct PythonDoc(PythonDocKind);
116
117#[derive(Clone)]
118enum PythonDocKind {
119    LitCStr(LitCStr),
120    // There is currently no way to `concat!` c-string literals, we fallback to the `c_str!` macro in
121    // this case.
122    Tokens(TokenStream),
123}
124
125/// Collects all #[doc = "..."] attributes into a TokenStream evaluating to a null-terminated string.
126///
127/// If this doc is for a callable, the provided `text_signature` can be passed to prepend
128/// this to the documentation suitable for Python to extract this into the `__text_signature__`
129/// attribute.
130pub fn get_doc(
131    attrs: &[syn::Attribute],
132    mut text_signature: Option<String>,
133    ctx: &Ctx,
134) -> syn::Result<PythonDoc> {
135    let Ctx { pyo3_path, .. } = ctx;
136    // insert special divider between `__text_signature__` and doc
137    // (assume text_signature is itself well-formed)
138    if let Some(text_signature) = &mut text_signature {
139        text_signature.push_str("\n--\n\n");
140    }
141
142    let mut parts = Punctuated::<TokenStream, Token![,]>::new();
143    let mut first = true;
144    let mut current_part = text_signature.unwrap_or_default();
145    let mut current_part_span = None;
146
147    for attr in attrs {
148        if attr.path().is_ident("doc") {
149            if let Ok(nv) = attr.meta.require_name_value() {
150                current_part_span = match current_part_span {
151                    None => Some(nv.value.span()),
152                    Some(span) => span.join(nv.value.span()),
153                };
154                if !first {
155                    current_part.push('\n');
156                } else {
157                    first = false;
158                }
159                if let syn::Expr::Lit(syn::ExprLit {
160                    lit: syn::Lit::Str(lit_str),
161                    ..
162                }) = &nv.value
163                {
164                    // Strip single left space from literal strings, if needed.
165                    // e.g. `/// Hello world` expands to #[doc = " Hello world"]
166                    let doc_line = lit_str.value();
167                    current_part.push_str(doc_line.strip_prefix(' ').unwrap_or(&doc_line));
168                } else {
169                    // This is probably a macro doc from Rust 1.54, e.g. #[doc = include_str!(...)]
170                    // Reset the string buffer, write that part, and then push this macro part too.
171                    parts.push(quote_spanned!(current_part_span.unwrap_or(Span::call_site()) => #current_part));
172                    current_part.clear();
173                    parts.push(nv.value.to_token_stream());
174                }
175            }
176        }
177    }
178
179    if !parts.is_empty() {
180        // Doc contained macro pieces - return as `concat!` expression
181        if !current_part.is_empty() {
182            parts.push(
183                quote_spanned!(current_part_span.unwrap_or(Span::call_site()) => #current_part),
184            );
185        }
186
187        let mut tokens = TokenStream::new();
188
189        syn::Ident::new("concat", Span::call_site()).to_tokens(&mut tokens);
190        syn::token::Not(Span::call_site()).to_tokens(&mut tokens);
191        syn::token::Bracket(Span::call_site()).surround(&mut tokens, |tokens| {
192            parts.to_tokens(tokens);
193            syn::token::Comma(Span::call_site()).to_tokens(tokens);
194        });
195
196        Ok(PythonDoc(PythonDocKind::Tokens(
197            quote!(#pyo3_path::ffi::c_str!(#tokens)),
198        )))
199    } else {
200        // Just a string doc - return directly with nul terminator
201        let docs = CString::new(current_part).map_err(|e| {
202            syn::Error::new(
203                current_part_span.unwrap_or(Span::call_site()),
204                format!(
205                    "Python doc may not contain nul byte, found nul at position {}",
206                    e.nul_position()
207                ),
208            )
209        })?;
210        Ok(PythonDoc(PythonDocKind::LitCStr(LitCStr::new(
211            docs,
212            current_part_span.unwrap_or(Span::call_site()),
213            ctx,
214        ))))
215    }
216}
217
218impl quote::ToTokens for PythonDoc {
219    fn to_tokens(&self, tokens: &mut TokenStream) {
220        match &self.0 {
221            PythonDocKind::LitCStr(lit) => lit.to_tokens(tokens),
222            PythonDocKind::Tokens(toks) => toks.to_tokens(tokens),
223        }
224    }
225}
226
227pub fn unwrap_ty_group(mut ty: &syn::Type) -> &syn::Type {
228    while let syn::Type::Group(g) = ty {
229        ty = &*g.elem;
230    }
231    ty
232}
233
234pub struct Ctx {
235    /// Where we can find the pyo3 crate
236    pub pyo3_path: PyO3CratePath,
237
238    /// If we are in a pymethod or pyfunction,
239    /// this will be the span of the return type
240    pub output_span: Span,
241}
242
243impl Ctx {
244    pub(crate) fn new(attr: &Option<CrateAttribute>, signature: Option<&syn::Signature>) -> Self {
245        let pyo3_path = match attr {
246            Some(attr) => PyO3CratePath::Given(attr.value.0.clone()),
247            None => PyO3CratePath::Default,
248        };
249
250        let output_span = if let Some(syn::Signature {
251            output: syn::ReturnType::Type(_, output_type),
252            ..
253        }) = &signature
254        {
255            output_type.span()
256        } else {
257            Span::call_site()
258        };
259
260        Self {
261            pyo3_path,
262            output_span,
263        }
264    }
265}
266
267#[derive(Clone)]
268pub enum PyO3CratePath {
269    Given(syn::Path),
270    Default,
271}
272
273impl PyO3CratePath {
274    pub fn to_tokens_spanned(&self, span: Span) -> TokenStream {
275        match self {
276            Self::Given(path) => quote::quote_spanned! { span => #path },
277            Self::Default => quote::quote_spanned! {  span => ::pyo3 },
278        }
279    }
280}
281
282impl quote::ToTokens for PyO3CratePath {
283    fn to_tokens(&self, tokens: &mut TokenStream) {
284        match self {
285            Self::Given(path) => path.to_tokens(tokens),
286            Self::Default => quote::quote! { ::pyo3 }.to_tokens(tokens),
287        }
288    }
289}
290
291pub fn apply_renaming_rule(rule: RenamingRule, name: &str) -> String {
292    use heck::*;
293
294    match rule {
295        RenamingRule::CamelCase => name.to_lower_camel_case(),
296        RenamingRule::KebabCase => name.to_kebab_case(),
297        RenamingRule::Lowercase => name.to_lowercase(),
298        RenamingRule::PascalCase => name.to_upper_camel_case(),
299        RenamingRule::ScreamingKebabCase => name.to_shouty_kebab_case(),
300        RenamingRule::ScreamingSnakeCase => name.to_shouty_snake_case(),
301        RenamingRule::SnakeCase => name.to_snake_case(),
302        RenamingRule::Uppercase => name.to_uppercase(),
303    }
304}
305
306pub(crate) enum IdentOrStr<'a> {
307    Str(&'a str),
308    Ident(syn::Ident),
309}
310
311pub(crate) fn has_attribute(attrs: &[syn::Attribute], ident: &str) -> bool {
312    has_attribute_with_namespace(attrs, None, &[ident])
313}
314
315pub(crate) fn has_attribute_with_namespace(
316    attrs: &[syn::Attribute],
317    crate_path: Option<&PyO3CratePath>,
318    idents: &[&str],
319) -> bool {
320    let mut segments = vec![];
321    if let Some(c) = crate_path {
322        match c {
323            PyO3CratePath::Given(paths) => {
324                for p in &paths.segments {
325                    segments.push(IdentOrStr::Ident(p.ident.clone()));
326                }
327            }
328            PyO3CratePath::Default => segments.push(IdentOrStr::Str("pyo3")),
329        }
330    };
331    for i in idents {
332        segments.push(IdentOrStr::Str(i));
333    }
334
335    attrs.iter().any(|attr| {
336        segments
337            .iter()
338            .eq(attr.path().segments.iter().map(|v| &v.ident))
339    })
340}
341
342pub fn expr_to_python(expr: &syn::Expr) -> String {
343    match expr {
344        // literal values
345        syn::Expr::Lit(syn::ExprLit { lit, .. }) => match lit {
346            syn::Lit::Str(s) => s.token().to_string(),
347            syn::Lit::Char(c) => c.token().to_string(),
348            syn::Lit::Int(i) => i.base10_digits().to_string(),
349            syn::Lit::Float(f) => f.base10_digits().to_string(),
350            syn::Lit::Bool(b) => {
351                if b.value() {
352                    "True".to_string()
353                } else {
354                    "False".to_string()
355                }
356            }
357            _ => "...".to_string(),
358        },
359        // None
360        syn::Expr::Path(syn::ExprPath { qself, path, .. })
361            if qself.is_none() && path.is_ident("None") =>
362        {
363            "None".to_string()
364        }
365        // others, unsupported yet so defaults to `...`
366        _ => "...".to_string(),
367    }
368}