1use crate::attributes::KeywordAttribute;
2use crate::combine_errors::CombineErrors;
3#[cfg(feature = "experimental-inspect")]
4use crate::introspection::{function_introspection_code, introspection_id_const};
5#[cfg(feature = "experimental-inspect")]
6use crate::utils::get_doc;
7use crate::utils::Ctx;
8use crate::{
9 attributes::{
10 self, get_pyo3_options, take_attributes, take_pyo3_options, CrateAttribute,
11 FromPyWithAttribute, NameAttribute, TextSignatureAttribute,
12 },
13 method::{self, CallingConvention, FnArg},
14 pymethod::check_generic,
15};
16use proc_macro2::{Span, TokenStream};
17use quote::{format_ident, quote, ToTokens};
18use std::cmp::PartialEq;
19use std::ffi::CString;
20#[cfg(feature = "experimental-inspect")]
21use std::iter::empty;
22use syn::parse::{Parse, ParseStream};
23use syn::punctuated::Punctuated;
24use syn::LitCStr;
25use syn::{ext::IdentExt, spanned::Spanned, LitStr, Path, Result, Token};
26
27mod signature;
28
29pub use self::signature::{ConstructorAttribute, FunctionSignature, SignatureAttribute};
30
31#[derive(Clone, Debug)]
32pub struct PyFunctionArgPyForgeAttributes {
33 pub from_py_with: Option<FromPyWithAttribute>,
34 pub cancel_handle: Option<attributes::kw::cancel_handle>,
35}
36
37enum PyFunctionArgPyForgeAttribute {
38 FromPyWith(FromPyWithAttribute),
39 CancelHandle(attributes::kw::cancel_handle),
40}
41
42impl Parse for PyFunctionArgPyForgeAttribute {
43 fn parse(input: ParseStream<'_>) -> Result<Self> {
44 let lookahead = input.lookahead1();
45 if lookahead.peek(attributes::kw::cancel_handle) {
46 input.parse().map(PyFunctionArgPyForgeAttribute::CancelHandle)
47 } else if lookahead.peek(attributes::kw::from_py_with) {
48 input.parse().map(PyFunctionArgPyForgeAttribute::FromPyWith)
49 } else {
50 Err(lookahead.error())
51 }
52 }
53}
54
55impl PyFunctionArgPyForgeAttributes {
56 pub fn from_attrs(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Self> {
58 let mut attributes = PyFunctionArgPyForgeAttributes {
59 from_py_with: None,
60 cancel_handle: None,
61 };
62 take_attributes(attrs, |attr| {
63 if let Some(pyo3_attrs) = get_pyo3_options(attr)? {
64 for attr in pyo3_attrs {
65 match attr {
66 PyFunctionArgPyForgeAttribute::FromPyWith(from_py_with) => {
67 ensure_spanned!(
68 attributes.from_py_with.is_none(),
69 from_py_with.span() => "`from_py_with` may only be specified once per argument"
70 );
71 attributes.from_py_with = Some(from_py_with);
72 }
73 PyFunctionArgPyForgeAttribute::CancelHandle(cancel_handle) => {
74 ensure_spanned!(
75 attributes.cancel_handle.is_none(),
76 cancel_handle.span() => "`cancel_handle` may only be specified once per argument"
77 );
78 attributes.cancel_handle = Some(cancel_handle);
79 }
80 }
81 ensure_spanned!(
82 attributes.from_py_with.is_none() || attributes.cancel_handle.is_none(),
83 attributes.cancel_handle.unwrap().span() => "`from_py_with` and `cancel_handle` cannot be specified together"
84 );
85 }
86 Ok(true)
87 } else {
88 Ok(false)
89 }
90 })?;
91 Ok(attributes)
92 }
93}
94
95type PyFunctionWarningMessageAttribute = KeywordAttribute<attributes::kw::message, LitStr>;
96type PyFunctionWarningCategoryAttribute = KeywordAttribute<attributes::kw::category, Path>;
97
98pub struct PyFunctionWarningAttribute {
99 pub message: PyFunctionWarningMessageAttribute,
100 pub category: Option<PyFunctionWarningCategoryAttribute>,
101 pub span: Span,
102}
103
104#[derive(PartialEq, Clone)]
105pub enum PyFunctionWarningCategory {
106 Path(Path),
107 UserWarning,
108 DeprecationWarning, }
110
111#[derive(Clone)]
112pub struct PyFunctionWarning {
113 pub message: LitStr,
114 pub category: PyFunctionWarningCategory,
115 pub span: Span,
116}
117
118impl From<PyFunctionWarningAttribute> for PyFunctionWarning {
119 fn from(value: PyFunctionWarningAttribute) -> Self {
120 Self {
121 message: value.message.value,
122 category: value
123 .category
124 .map_or(PyFunctionWarningCategory::UserWarning, |cat| {
125 PyFunctionWarningCategory::Path(cat.value)
126 }),
127 span: value.span,
128 }
129 }
130}
131
132pub trait WarningFactory {
133 fn build_py_warning(&self, ctx: &Ctx) -> TokenStream;
134 fn span(&self) -> Span;
135}
136
137impl WarningFactory for PyFunctionWarning {
138 fn build_py_warning(&self, ctx: &Ctx) -> TokenStream {
139 let message = &self.message.value();
140 let c_message = LitCStr::new(
141 &CString::new(message.clone()).unwrap(),
142 Spanned::span(&message),
143 );
144 let pyo3_path = &ctx.pyo3_path;
145 let category = match &self.category {
146 PyFunctionWarningCategory::Path(path) => quote! {#path},
147 PyFunctionWarningCategory::UserWarning => {
148 quote! {#pyo3_path::exceptions::PyUserWarning}
149 }
150 PyFunctionWarningCategory::DeprecationWarning => {
151 quote! {#pyo3_path::exceptions::PyDeprecationWarning}
152 }
153 };
154 quote! {
155 #pyo3_path::PyErr::warn(py, &<#category as #pyo3_path::PyTypeInfo>::type_object(py), #c_message, 1)?;
156 }
157 }
158
159 fn span(&self) -> Span {
160 self.span
161 }
162}
163
164impl<T: WarningFactory> WarningFactory for Vec<T> {
165 fn build_py_warning(&self, ctx: &Ctx) -> TokenStream {
166 let warnings = self.iter().map(|warning| warning.build_py_warning(ctx));
167
168 quote! {
169 #(#warnings)*
170 }
171 }
172
173 fn span(&self) -> Span {
174 self.iter()
175 .map(|val| val.span())
176 .reduce(|acc, span| acc.join(span).unwrap_or(acc))
177 .unwrap()
178 }
179}
180
181impl Parse for PyFunctionWarningAttribute {
182 fn parse(input: ParseStream<'_>) -> Result<Self> {
183 let mut message: Option<PyFunctionWarningMessageAttribute> = None;
184 let mut category: Option<PyFunctionWarningCategoryAttribute> = None;
185
186 let span = input.parse::<attributes::kw::warn>()?.span();
187
188 let content;
189 syn::parenthesized!(content in input);
190
191 while !content.is_empty() {
192 let lookahead = content.lookahead1();
193
194 if lookahead.peek(attributes::kw::message) {
195 message = content
196 .parse::<PyFunctionWarningMessageAttribute>()
197 .map(Some)?;
198 } else if lookahead.peek(attributes::kw::category) {
199 category = content
200 .parse::<PyFunctionWarningCategoryAttribute>()
201 .map(Some)?;
202 } else {
203 return Err(lookahead.error());
204 }
205
206 if content.peek(Token![,]) {
207 content.parse::<Token![,]>()?;
208 }
209 }
210
211 Ok(PyFunctionWarningAttribute {
212 message: message.ok_or(syn::Error::new(
213 content.span(),
214 "missing `message` in `warn` attribute",
215 ))?,
216 category,
217 span,
218 })
219 }
220}
221
222impl ToTokens for PyFunctionWarningAttribute {
223 fn to_tokens(&self, tokens: &mut TokenStream) {
224 let message_tokens = self.message.to_token_stream();
225 let category_tokens = self
226 .category
227 .as_ref()
228 .map_or(quote! {}, |cat| cat.to_token_stream());
229
230 let token_stream = quote! {
231 warn(#message_tokens, #category_tokens)
232 };
233
234 tokens.extend(token_stream);
235 }
236}
237
238#[derive(Default)]
239pub struct PyFunctionOptions {
240 pub pass_module: Option<attributes::kw::pass_module>,
241 pub name: Option<NameAttribute>,
242 pub signature: Option<SignatureAttribute>,
243 pub text_signature: Option<TextSignatureAttribute>,
244 pub krate: Option<CrateAttribute>,
245 pub warnings: Vec<PyFunctionWarning>,
246}
247
248impl Parse for PyFunctionOptions {
249 fn parse(input: ParseStream<'_>) -> Result<Self> {
250 let mut options = PyFunctionOptions::default();
251
252 let attrs = Punctuated::<PyFunctionOption, syn::Token![,]>::parse_terminated(input)?;
253 options.add_attributes(attrs)?;
254
255 Ok(options)
256 }
257}
258
259pub enum PyFunctionOption {
260 Name(NameAttribute),
261 PassModule(attributes::kw::pass_module),
262 Signature(SignatureAttribute),
263 TextSignature(TextSignatureAttribute),
264 Crate(CrateAttribute),
265 Warning(PyFunctionWarningAttribute),
266}
267
268impl Parse for PyFunctionOption {
269 fn parse(input: ParseStream<'_>) -> Result<Self> {
270 let lookahead = input.lookahead1();
271 if lookahead.peek(attributes::kw::name) {
272 input.parse().map(PyFunctionOption::Name)
273 } else if lookahead.peek(attributes::kw::pass_module) {
274 input.parse().map(PyFunctionOption::PassModule)
275 } else if lookahead.peek(attributes::kw::signature) {
276 input.parse().map(PyFunctionOption::Signature)
277 } else if lookahead.peek(attributes::kw::text_signature) {
278 input.parse().map(PyFunctionOption::TextSignature)
279 } else if lookahead.peek(syn::Token![crate]) {
280 input.parse().map(PyFunctionOption::Crate)
281 } else if lookahead.peek(attributes::kw::warn) {
282 input.parse().map(PyFunctionOption::Warning)
283 } else {
284 Err(lookahead.error())
285 }
286 }
287}
288
289impl PyFunctionOptions {
290 pub fn from_attrs(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Self> {
291 let mut options = PyFunctionOptions::default();
292 options.add_attributes(take_pyo3_options(attrs)?)?;
293 Ok(options)
294 }
295
296 pub fn add_attributes(
297 &mut self,
298 attrs: impl IntoIterator<Item = PyFunctionOption>,
299 ) -> Result<()> {
300 macro_rules! set_option {
301 ($key:ident) => {
302 {
303 ensure_spanned!(
304 self.$key.is_none(),
305 $key.span() => concat!("`", stringify!($key), "` may only be specified once")
306 );
307 self.$key = Some($key);
308 }
309 };
310 }
311 for attr in attrs {
312 match attr {
313 PyFunctionOption::Name(name) => set_option!(name),
314 PyFunctionOption::PassModule(pass_module) => set_option!(pass_module),
315 PyFunctionOption::Signature(signature) => set_option!(signature),
316 PyFunctionOption::TextSignature(text_signature) => set_option!(text_signature),
317 PyFunctionOption::Crate(krate) => set_option!(krate),
318 PyFunctionOption::Warning(warning) => {
319 self.warnings.push(warning.into());
320 }
321 }
322 }
323 Ok(())
324 }
325}
326
327pub fn build_py_function(
328 ast: &mut syn::ItemFn,
329 mut options: PyFunctionOptions,
330) -> syn::Result<TokenStream> {
331 options.add_attributes(take_pyo3_options(&mut ast.attrs)?)?;
332 impl_wrap_pyfunction(ast, options)
333}
334
335pub fn impl_wrap_pyfunction(
338 func: &mut syn::ItemFn,
339 options: PyFunctionOptions,
340) -> syn::Result<TokenStream> {
341 check_generic(&func.sig)?;
342 let PyFunctionOptions {
343 pass_module,
344 name,
345 signature,
346 text_signature,
347 krate,
348 warnings,
349 } = options;
350
351 let ctx = &Ctx::new(&krate, Some(&func.sig));
352 let Ctx { pyo3_path, .. } = &ctx;
353
354 let python_name = name
355 .as_ref()
356 .map_or_else(|| &func.sig.ident, |name| &name.value.0)
357 .unraw();
358
359 let tp = if pass_module.is_some() {
360 let span = match func.sig.inputs.first() {
361 Some(syn::FnArg::Typed(first_arg)) => first_arg.ty.span(),
362 Some(syn::FnArg::Receiver(_)) | None => bail_spanned!(
363 func.sig.paren_token.span.join() => "expected `&PyModule` or `Py<PyModule>` as first argument with `pass_module`"
364 ),
365 };
366 method::FnType::FnModule(span)
367 } else {
368 method::FnType::FnStatic
369 };
370
371 let arguments = func
372 .sig
373 .inputs
374 .iter_mut()
375 .skip(if tp.skip_first_rust_argument_in_python_signature() {
376 1
377 } else {
378 0
379 })
380 .map(FnArg::parse)
381 .try_combine_syn_errors()?;
382
383 let signature = if let Some(signature) = signature {
384 FunctionSignature::from_arguments_and_attribute(arguments, signature)?
385 } else {
386 FunctionSignature::from_arguments(arguments)
387 };
388
389 let spec = method::FnSpec {
390 tp,
391 name: &func.sig.ident,
392 python_name,
393 signature,
394 text_signature,
395 asyncness: func.sig.asyncness,
396 unsafety: func.sig.unsafety,
397 warnings,
398 output: func.sig.output.clone(),
399 };
400
401 let vis = &func.vis;
402 let name = &func.sig.ident;
403
404 #[cfg(feature = "experimental-inspect")]
405 let introspection = function_introspection_code(
406 pyo3_path,
407 Some(name),
408 &name.to_string(),
409 &spec.signature,
410 None,
411 func.sig.output.clone(),
412 empty(),
413 func.sig.asyncness.is_some(),
414 false,
415 get_doc(&func.attrs, None).as_ref(),
416 None,
417 );
418 #[cfg(not(feature = "experimental-inspect"))]
419 let introspection = quote! {};
420 #[cfg(feature = "experimental-inspect")]
421 let introspection_id = introspection_id_const();
422 #[cfg(not(feature = "experimental-inspect"))]
423 let introspection_id = quote! {};
424
425 let wrapper_ident = format_ident!("__pyfunction_{}", spec.name);
426 let calling_convention = CallingConvention::from_signature(&spec.signature);
428 let wrapper = spec.get_wrapper_function(&wrapper_ident, None, calling_convention, ctx)?;
429 let methoddef = spec.get_methoddef(
430 wrapper_ident,
431 spec.get_doc(&func.attrs).as_ref(),
432 calling_convention,
433 ctx,
434 )?;
435
436 let wrapped_pyfunction = quote! {
437 #[doc(hidden)]
440 #vis mod #name {
441 pub(crate) struct MakeDef;
442 pub static _PYO3_DEF: #pyo3_path::impl_::pyfunction::PyFunctionDef = MakeDef::_PYO3_DEF;
443 #introspection_id
444 }
445
446 #[allow(unknown_lints, non_local_definitions)]
451 impl #name::MakeDef {
452 #[allow(clippy::declare_interior_mutable_const)]
454 const _PYO3_DEF: #pyo3_path::impl_::pyfunction::PyFunctionDef =
455 #pyo3_path::impl_::pyfunction::PyFunctionDef::from_method_def(#methoddef);
456 }
457
458 #[allow(non_snake_case)]
459 #wrapper
460
461 #introspection
462 };
463 Ok(wrapped_pyfunction)
464}