1use crate::attributes::KeywordAttribute;
2use crate::combine_errors::CombineErrors;
3#[cfg(feature = "experimental-inspect")]
4use crate::introspection::{function_introspection_code, introspection_id_const};
5use crate::utils::{Ctx, LitCStr};
6use crate::{
7 attributes::{
8 self, get_pyo3_options, take_attributes, take_pyo3_options, CrateAttribute,
9 FromPyWithAttribute, NameAttribute, TextSignatureAttribute,
10 },
11 method::{self, CallingConvention, FnArg},
12 pymethod::check_generic,
13};
14use proc_macro2::{Span, TokenStream};
15use quote::{format_ident, quote, ToTokens};
16use std::cmp::PartialEq;
17use std::ffi::CString;
18use syn::parse::{Parse, ParseStream};
19use syn::punctuated::Punctuated;
20use syn::{ext::IdentExt, spanned::Spanned, LitStr, Path, Result, Token};
21
22mod signature;
23
24pub use self::signature::{ConstructorAttribute, FunctionSignature, SignatureAttribute};
25
26#[derive(Clone, Debug)]
27pub struct PyFunctionArgPyO3Attributes {
28 pub from_py_with: Option<FromPyWithAttribute>,
29 pub cancel_handle: Option<attributes::kw::cancel_handle>,
30}
31
32enum PyFunctionArgPyO3Attribute {
33 FromPyWith(FromPyWithAttribute),
34 CancelHandle(attributes::kw::cancel_handle),
35}
36
37impl Parse for PyFunctionArgPyO3Attribute {
38 fn parse(input: ParseStream<'_>) -> Result<Self> {
39 let lookahead = input.lookahead1();
40 if lookahead.peek(attributes::kw::cancel_handle) {
41 input.parse().map(PyFunctionArgPyO3Attribute::CancelHandle)
42 } else if lookahead.peek(attributes::kw::from_py_with) {
43 input.parse().map(PyFunctionArgPyO3Attribute::FromPyWith)
44 } else {
45 Err(lookahead.error())
46 }
47 }
48}
49
50impl PyFunctionArgPyO3Attributes {
51 pub fn from_attrs(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Self> {
53 let mut attributes = PyFunctionArgPyO3Attributes {
54 from_py_with: None,
55 cancel_handle: None,
56 };
57 take_attributes(attrs, |attr| {
58 if let Some(pyo3_attrs) = get_pyo3_options(attr)? {
59 for attr in pyo3_attrs {
60 match attr {
61 PyFunctionArgPyO3Attribute::FromPyWith(from_py_with) => {
62 ensure_spanned!(
63 attributes.from_py_with.is_none(),
64 from_py_with.span() => "`from_py_with` may only be specified once per argument"
65 );
66 attributes.from_py_with = Some(from_py_with);
67 }
68 PyFunctionArgPyO3Attribute::CancelHandle(cancel_handle) => {
69 ensure_spanned!(
70 attributes.cancel_handle.is_none(),
71 cancel_handle.span() => "`cancel_handle` may only be specified once per argument"
72 );
73 attributes.cancel_handle = Some(cancel_handle);
74 }
75 }
76 ensure_spanned!(
77 attributes.from_py_with.is_none() || attributes.cancel_handle.is_none(),
78 attributes.cancel_handle.unwrap().span() => "`from_py_with` and `cancel_handle` cannot be specified together"
79 );
80 }
81 Ok(true)
82 } else {
83 Ok(false)
84 }
85 })?;
86 Ok(attributes)
87 }
88}
89
90type PyFunctionWarningMessageAttribute = KeywordAttribute<attributes::kw::message, LitStr>;
91type PyFunctionWarningCategoryAttribute = KeywordAttribute<attributes::kw::category, Path>;
92
93pub struct PyFunctionWarningAttribute {
94 pub message: PyFunctionWarningMessageAttribute,
95 pub category: Option<PyFunctionWarningCategoryAttribute>,
96 pub span: Span,
97}
98
99#[derive(PartialEq, Clone)]
100pub enum PyFunctionWarningCategory {
101 Path(Path),
102 UserWarning,
103 DeprecationWarning, }
105
106#[derive(Clone)]
107pub struct PyFunctionWarning {
108 pub message: LitStr,
109 pub category: PyFunctionWarningCategory,
110 pub span: Span,
111}
112
113impl From<PyFunctionWarningAttribute> for PyFunctionWarning {
114 fn from(value: PyFunctionWarningAttribute) -> Self {
115 Self {
116 message: value.message.value,
117 category: value
118 .category
119 .map_or(PyFunctionWarningCategory::UserWarning, |cat| {
120 PyFunctionWarningCategory::Path(cat.value)
121 }),
122 span: value.span,
123 }
124 }
125}
126
127pub trait WarningFactory {
128 fn build_py_warning(&self, ctx: &Ctx) -> TokenStream;
129 fn span(&self) -> Span;
130}
131
132impl WarningFactory for PyFunctionWarning {
133 fn build_py_warning(&self, ctx: &Ctx) -> TokenStream {
134 let message = &self.message.value();
135 let c_message = LitCStr::new(
136 CString::new(message.clone()).unwrap(),
137 Spanned::span(&message),
138 ctx,
139 );
140 let pyo3_path = &ctx.pyo3_path;
141 let category = match &self.category {
142 PyFunctionWarningCategory::Path(path) => quote! {#path},
143 PyFunctionWarningCategory::UserWarning => {
144 quote! {#pyo3_path::exceptions::PyUserWarning}
145 }
146 PyFunctionWarningCategory::DeprecationWarning => {
147 quote! {#pyo3_path::exceptions::PyDeprecationWarning}
148 }
149 };
150 quote! {
151 #pyo3_path::PyErr::warn(py, &<#category as #pyo3_path::PyTypeInfo>::type_object(py), #c_message, 1)?;
152 }
153 }
154
155 fn span(&self) -> Span {
156 self.span
157 }
158}
159
160impl<T: WarningFactory> WarningFactory for Vec<T> {
161 fn build_py_warning(&self, ctx: &Ctx) -> TokenStream {
162 let warnings = self.iter().map(|warning| warning.build_py_warning(ctx));
163
164 quote! {
165 #(#warnings)*
166 }
167 }
168
169 fn span(&self) -> Span {
170 self.iter()
171 .map(|val| val.span())
172 .reduce(|acc, span| acc.join(span).unwrap_or(acc))
173 .unwrap()
174 }
175}
176
177impl Parse for PyFunctionWarningAttribute {
178 fn parse(input: ParseStream<'_>) -> Result<Self> {
179 let mut message: Option<PyFunctionWarningMessageAttribute> = None;
180 let mut category: Option<PyFunctionWarningCategoryAttribute> = None;
181
182 let span = input.parse::<attributes::kw::warn>()?.span();
183
184 let content;
185 syn::parenthesized!(content in input);
186
187 while !content.is_empty() {
188 let lookahead = content.lookahead1();
189
190 if lookahead.peek(attributes::kw::message) {
191 message = content
192 .parse::<PyFunctionWarningMessageAttribute>()
193 .map(Some)?;
194 } else if lookahead.peek(attributes::kw::category) {
195 category = content
196 .parse::<PyFunctionWarningCategoryAttribute>()
197 .map(Some)?;
198 } else {
199 return Err(lookahead.error());
200 }
201
202 if content.peek(Token![,]) {
203 content.parse::<Token![,]>()?;
204 }
205 }
206
207 Ok(PyFunctionWarningAttribute {
208 message: message.ok_or(syn::Error::new(
209 content.span(),
210 "missing `message` in `warn` attribute",
211 ))?,
212 category,
213 span,
214 })
215 }
216}
217
218impl ToTokens for PyFunctionWarningAttribute {
219 fn to_tokens(&self, tokens: &mut TokenStream) {
220 let message_tokens = self.message.to_token_stream();
221 let category_tokens = self
222 .category
223 .as_ref()
224 .map_or(quote! {}, |cat| cat.to_token_stream());
225
226 let token_stream = quote! {
227 warn(#message_tokens, #category_tokens)
228 };
229
230 tokens.extend(token_stream);
231 }
232}
233
234#[derive(Default)]
235pub struct PyFunctionOptions {
236 pub pass_module: Option<attributes::kw::pass_module>,
237 pub name: Option<NameAttribute>,
238 pub signature: Option<SignatureAttribute>,
239 pub text_signature: Option<TextSignatureAttribute>,
240 pub krate: Option<CrateAttribute>,
241 pub warnings: Vec<PyFunctionWarning>,
242}
243
244impl Parse for PyFunctionOptions {
245 fn parse(input: ParseStream<'_>) -> Result<Self> {
246 let mut options = PyFunctionOptions::default();
247
248 let attrs = Punctuated::<PyFunctionOption, syn::Token![,]>::parse_terminated(input)?;
249 options.add_attributes(attrs)?;
250
251 Ok(options)
252 }
253}
254
255pub enum PyFunctionOption {
256 Name(NameAttribute),
257 PassModule(attributes::kw::pass_module),
258 Signature(SignatureAttribute),
259 TextSignature(TextSignatureAttribute),
260 Crate(CrateAttribute),
261 Warning(PyFunctionWarningAttribute),
262}
263
264impl Parse for PyFunctionOption {
265 fn parse(input: ParseStream<'_>) -> Result<Self> {
266 let lookahead = input.lookahead1();
267 if lookahead.peek(attributes::kw::name) {
268 input.parse().map(PyFunctionOption::Name)
269 } else if lookahead.peek(attributes::kw::pass_module) {
270 input.parse().map(PyFunctionOption::PassModule)
271 } else if lookahead.peek(attributes::kw::signature) {
272 input.parse().map(PyFunctionOption::Signature)
273 } else if lookahead.peek(attributes::kw::text_signature) {
274 input.parse().map(PyFunctionOption::TextSignature)
275 } else if lookahead.peek(syn::Token![crate]) {
276 input.parse().map(PyFunctionOption::Crate)
277 } else if lookahead.peek(attributes::kw::warn) {
278 input.parse().map(PyFunctionOption::Warning)
279 } else {
280 Err(lookahead.error())
281 }
282 }
283}
284
285impl PyFunctionOptions {
286 pub fn from_attrs(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Self> {
287 let mut options = PyFunctionOptions::default();
288 options.add_attributes(take_pyo3_options(attrs)?)?;
289 Ok(options)
290 }
291
292 pub fn add_attributes(
293 &mut self,
294 attrs: impl IntoIterator<Item = PyFunctionOption>,
295 ) -> Result<()> {
296 macro_rules! set_option {
297 ($key:ident) => {
298 {
299 ensure_spanned!(
300 self.$key.is_none(),
301 $key.span() => concat!("`", stringify!($key), "` may only be specified once")
302 );
303 self.$key = Some($key);
304 }
305 };
306 }
307 for attr in attrs {
308 match attr {
309 PyFunctionOption::Name(name) => set_option!(name),
310 PyFunctionOption::PassModule(pass_module) => set_option!(pass_module),
311 PyFunctionOption::Signature(signature) => set_option!(signature),
312 PyFunctionOption::TextSignature(text_signature) => set_option!(text_signature),
313 PyFunctionOption::Crate(krate) => set_option!(krate),
314 PyFunctionOption::Warning(warning) => {
315 self.warnings.push(warning.into());
316 }
317 }
318 }
319 Ok(())
320 }
321}
322
323pub fn build_py_function(
324 ast: &mut syn::ItemFn,
325 mut options: PyFunctionOptions,
326) -> syn::Result<TokenStream> {
327 options.add_attributes(take_pyo3_options(&mut ast.attrs)?)?;
328 impl_wrap_pyfunction(ast, options)
329}
330
331pub fn impl_wrap_pyfunction(
334 func: &mut syn::ItemFn,
335 options: PyFunctionOptions,
336) -> syn::Result<TokenStream> {
337 check_generic(&func.sig)?;
338 let PyFunctionOptions {
339 pass_module,
340 name,
341 signature,
342 text_signature,
343 krate,
344 warnings,
345 } = options;
346
347 let ctx = &Ctx::new(&krate, Some(&func.sig));
348 let Ctx { pyo3_path, .. } = &ctx;
349
350 let python_name = name
351 .as_ref()
352 .map_or_else(|| &func.sig.ident, |name| &name.value.0)
353 .unraw();
354
355 let tp = if pass_module.is_some() {
356 let span = match func.sig.inputs.first() {
357 Some(syn::FnArg::Typed(first_arg)) => first_arg.ty.span(),
358 Some(syn::FnArg::Receiver(_)) | None => bail_spanned!(
359 func.sig.paren_token.span.join() => "expected `&PyModule` or `Py<PyModule>` as first argument with `pass_module`"
360 ),
361 };
362 method::FnType::FnModule(span)
363 } else {
364 method::FnType::FnStatic
365 };
366
367 let arguments = func
368 .sig
369 .inputs
370 .iter_mut()
371 .skip(if tp.skip_first_rust_argument_in_python_signature() {
372 1
373 } else {
374 0
375 })
376 .map(FnArg::parse)
377 .try_combine_syn_errors()?;
378
379 let signature = if let Some(signature) = signature {
380 FunctionSignature::from_arguments_and_attribute(arguments, signature)?
381 } else {
382 FunctionSignature::from_arguments(arguments)
383 };
384
385 let vis = &func.vis;
386 let name = &func.sig.ident;
387
388 #[cfg(feature = "experimental-inspect")]
389 let introspection = function_introspection_code(
390 pyo3_path,
391 Some(name),
392 &name.to_string(),
393 &signature,
394 None,
395 func.sig.output.clone(),
396 [] as [String; 0],
397 None,
398 );
399 #[cfg(not(feature = "experimental-inspect"))]
400 let introspection = quote! {};
401 #[cfg(feature = "experimental-inspect")]
402 let introspection_id = introspection_id_const();
403 #[cfg(not(feature = "experimental-inspect"))]
404 let introspection_id = quote! {};
405
406 let spec = method::FnSpec {
407 tp,
408 name: &func.sig.ident,
409 convention: CallingConvention::from_signature(&signature),
410 python_name,
411 signature,
412 text_signature,
413 asyncness: func.sig.asyncness,
414 unsafety: func.sig.unsafety,
415 warnings,
416 #[cfg(feature = "experimental-inspect")]
417 output: func.sig.output.clone(),
418 };
419
420 let wrapper_ident = format_ident!("__pyfunction_{}", spec.name);
421 if spec.asyncness.is_some() {
422 ensure_spanned!(
423 cfg!(feature = "experimental-async"),
424 spec.asyncness.span() => "async functions are only supported with the `experimental-async` feature"
425 );
426 }
427 let wrapper = spec.get_wrapper_function(&wrapper_ident, None, ctx)?;
428 let methoddef = spec.get_methoddef(wrapper_ident, &spec.get_doc(&func.attrs, ctx)?, ctx);
429
430 let wrapped_pyfunction = quote! {
431 #[doc(hidden)]
434 #vis mod #name {
435 pub(crate) struct MakeDef;
436 pub const _PYO3_DEF: #pyo3_path::impl_::pymethods::PyMethodDef = MakeDef::_PYO3_DEF;
437 #introspection_id
438 }
439
440 #[allow(unknown_lints, non_local_definitions)]
445 impl #name::MakeDef {
446 const _PYO3_DEF: #pyo3_path::impl_::pymethods::PyMethodDef = #methoddef;
447 }
448
449 #[allow(non_snake_case)]
450 #wrapper
451
452 #introspection
453 };
454 Ok(wrapped_pyfunction)
455}