1use crate::method;
5use crate::pyfunction::PyFunctionAttr;
6use crate::pymethod;
7use crate::pymethod::get_arg_names;
8use crate::utils;
9use proc_macro2::{Span, TokenStream};
10use quote::{format_ident, quote};
11use syn::Ident;
12
13pub fn py_init(fnname: &Ident, name: &Ident, doc: syn::LitStr) -> TokenStream {
16 let cb_name = Ident::new(&format!("PyInit_{}", name), Span::call_site());
17
18 quote! {
19 #[no_mangle]
20 #[allow(non_snake_case)]
21 pub unsafe extern "C" fn #cb_name() -> *mut pyo3::ffi::PyObject {
24 use pyo3::derive_utils::ModuleDef;
25 const NAME: &'static str = concat!(stringify!(#name), "\0");
26 static MODULE_DEF: ModuleDef = unsafe { ModuleDef::new(NAME) };
27
28 pyo3::callback_body!(_py, { MODULE_DEF.make_module(#doc, #fnname) })
29 }
30 }
31}
32
33pub fn process_functions_in_module(func: &mut syn::ItemFn) -> syn::Result<()> {
35 let mut stmts: Vec<syn::Stmt> = Vec::new();
36
37 for stmt in func.block.stmts.iter_mut() {
38 if let syn::Stmt::Item(syn::Item::Fn(ref mut func)) = stmt {
39 if let Some((module_name, python_name, pyfn_attrs)) =
40 extract_pyfn_attrs(&mut func.attrs)?
41 {
42 let function_to_python = add_fn_to_module(func, python_name, pyfn_attrs)?;
43 let function_wrapper_ident = function_wrapper_ident(&func.sig.ident);
44 let item: syn::ItemFn = syn::parse_quote! {
45 fn block_wrapper() {
46 #function_to_python
47 #module_name.add_function(#function_wrapper_ident(#module_name)?)?;
48 }
49 };
50 stmts.extend(item.block.stmts.into_iter());
51 }
52 };
53 stmts.push(stmt.clone());
54 }
55
56 func.block.stmts = stmts;
57 Ok(())
58}
59
60fn wrap_fn_argument<'a>(cap: &'a syn::PatType) -> syn::Result<method::FnArg<'a>> {
62 let (mutability, by_ref, ident) = match *cap.pat {
63 syn::Pat::Ident(ref patid) => (&patid.mutability, &patid.by_ref, &patid.ident),
64 _ => return Err(syn::Error::new_spanned(&cap.pat, "Unsupported argument")),
65 };
66
67 Ok(method::FnArg {
68 name: ident,
69 mutability,
70 by_ref,
71 ty: &cap.ty,
72 optional: utils::option_type_argument(&cap.ty),
73 py: utils::is_python(&cap.ty),
74 })
75}
76
77fn extract_pyfn_attrs(
79 attrs: &mut Vec<syn::Attribute>,
80) -> syn::Result<Option<(syn::Path, Ident, PyFunctionAttr)>> {
81 let mut new_attrs = Vec::new();
82 let mut fnname = None;
83 let mut modname = None;
84 let mut fn_attrs = PyFunctionAttr::default();
85
86 for attr in attrs.iter() {
87 match attr.parse_meta() {
88 Ok(syn::Meta::List(ref list)) if list.path.is_ident("pyfn") => {
89 let meta: Vec<_> = list.nested.iter().cloned().collect();
90 if meta.len() >= 2 {
91 match meta[0] {
93 syn::NestedMeta::Meta(syn::Meta::Path(ref path)) => {
94 modname = Some(path.clone())
95 }
96 _ => {
97 return Err(syn::Error::new_spanned(
98 &meta[0],
99 "The first parameter of pyfn must be a MetaItem",
100 ))
101 }
102 }
103 match meta[1] {
105 syn::NestedMeta::Lit(syn::Lit::Str(ref lits)) => {
106 fnname = Some(syn::Ident::new(&lits.value(), lits.span()));
107 }
108 _ => {
109 return Err(syn::Error::new_spanned(
110 &meta[1],
111 "The second parameter of pyfn must be a Literal",
112 ))
113 }
114 }
115 if list.nested.len() >= 3 {
117 fn_attrs = PyFunctionAttr::from_meta(&meta[2..meta.len()])?;
118 }
119 } else {
120 return Err(syn::Error::new_spanned(
121 attr,
122 format!("can not parse 'pyfn' params {:?}", attr),
123 ));
124 }
125 }
126 _ => new_attrs.push(attr.clone()),
127 }
128 }
129
130 *attrs = new_attrs;
131 match (modname, fnname) {
132 (Some(modname), Some(fnname)) => Ok(Some((modname, fnname, fn_attrs))),
133 _ => Ok(None),
134 }
135}
136
137fn function_wrapper_ident(name: &Ident) -> Ident {
139 format_ident!("__pyo3_get_function_{}", name)
141}
142
143pub fn add_fn_to_module(
146 func: &mut syn::ItemFn,
147 python_name: Ident,
148 pyfn_attrs: PyFunctionAttr,
149) -> syn::Result<TokenStream> {
150 let mut arguments = Vec::new();
151
152 for (i, input) in func.sig.inputs.iter().enumerate() {
153 match input {
154 syn::FnArg::Receiver(_) => {
155 return Err(syn::Error::new_spanned(
156 input,
157 "Unexpected receiver for #[pyfn]",
158 ))
159 }
160 syn::FnArg::Typed(ref cap) => {
161 if pyfn_attrs.pass_module && i == 0 {
162 if let syn::Type::Reference(tyref) = cap.ty.as_ref() {
163 if let syn::Type::Path(typath) = tyref.elem.as_ref() {
164 if typath
165 .path
166 .segments
167 .last()
168 .map(|seg| seg.ident == "PyModule")
169 .unwrap_or(false)
170 {
171 continue;
172 }
173 }
174 }
175 return Err(syn::Error::new_spanned(
176 cap,
177 "Expected &PyModule as first argument with `pass_module`.",
178 ));
179 } else {
180 arguments.push(wrap_fn_argument(cap)?);
181 }
182 }
183 }
184 }
185
186 let ty = method::get_return_info(&func.sig.output);
187
188 let text_signature = utils::parse_text_signature_attrs(&mut func.attrs, &python_name)?;
189 let doc = utils::get_doc(&func.attrs, text_signature, true)?;
190
191 let function_wrapper_ident = function_wrapper_ident(&func.sig.ident);
192
193 let spec = method::FnSpec {
194 tp: method::FnType::FnStatic,
195 name: &function_wrapper_ident,
196 python_name,
197 attrs: pyfn_attrs.arguments,
198 args: arguments,
199 output: ty,
200 doc,
201 };
202
203 let doc = syn::LitByteStr::new(spec.doc.value().as_bytes(), spec.doc.span());
204
205 let python_name = &spec.python_name;
206
207 let name = &func.sig.ident;
208 let wrapper_ident = format_ident!("__pyo3_raw_{}", name);
209 let wrapper = function_c_wrapper(name, &wrapper_ident, &spec, pyfn_attrs.pass_module);
210 Ok(quote! {
211 #wrapper
212 fn #function_wrapper_ident<'a>(
213 args: impl Into<pyo3::derive_utils::PyFunctionArguments<'a>>
214 ) -> pyo3::PyResult<&'a pyo3::types::PyCFunction> {
215 let name = concat!(stringify!(#python_name), "\0");
216 let name = std::ffi::CStr::from_bytes_with_nul(name.as_bytes()).unwrap();
217 let doc = std::ffi::CStr::from_bytes_with_nul(#doc).unwrap();
218 pyo3::types::PyCFunction::internal_new(
219 name,
220 doc,
221 pyo3::class::PyMethodType::PyCFunctionWithKeywords(#wrapper_ident),
222 pyo3::ffi::METH_VARARGS | pyo3::ffi::METH_KEYWORDS,
223 args.into(),
224 )
225 }
226 })
227}
228
229fn function_c_wrapper(
231 name: &Ident,
232 wrapper_ident: &Ident,
233 spec: &method::FnSpec<'_>,
234 pass_module: bool,
235) -> TokenStream {
236 let names: Vec<Ident> = get_arg_names(&spec);
237 let cb;
238 let slf_module;
239 if pass_module {
240 cb = quote! {
241 #name(_slf, #(#names),*)
242 };
243 slf_module = Some(quote! {
244 let _slf = _py.from_borrowed_ptr::<pyo3::types::PyModule>(_slf);
245 });
246 } else {
247 cb = quote! {
248 #name(#(#names),*)
249 };
250 slf_module = None;
251 };
252 let body = pymethod::impl_arg_params(spec, None, cb);
253 quote! {
254 unsafe extern "C" fn #wrapper_ident(
255 _slf: *mut pyo3::ffi::PyObject,
256 _args: *mut pyo3::ffi::PyObject,
257 _kwargs: *mut pyo3::ffi::PyObject) -> *mut pyo3::ffi::PyObject
258 {
259 const _LOCATION: &'static str = concat!(stringify!(#name), "()");
260 pyo3::callback_body!(_py, {
261 #slf_module
262 let _args = _py.from_borrowed_ptr::<pyo3::types::PyTuple>(_args);
263 let _kwargs: Option<&pyo3::types::PyDict> = _py.from_borrowed_ptr_or_opt(_kwargs);
264
265 #body
266 })
267 }
268 }
269}