1use crate::module::add_fn_to_module;
4use proc_macro2::TokenStream;
5use syn::ext::IdentExt;
6use syn::parse::ParseBuffer;
7use syn::punctuated::Punctuated;
8use syn::spanned::Spanned;
9use syn::{NestedMeta, Path};
10
11#[derive(Debug, Clone, PartialEq)]
12pub enum Argument {
13 VarArgsSeparator,
14 VarArgs(syn::Path),
15 KeywordArgs(syn::Path),
16 Arg(syn::Path, Option<String>),
17 Kwarg(syn::Path, Option<String>),
18}
19
20#[derive(Default)]
22pub struct PyFunctionAttr {
23 pub arguments: Vec<Argument>,
24 has_kw: bool,
25 has_varargs: bool,
26 has_kwargs: bool,
27 pub pass_module: bool,
28}
29
30impl syn::parse::Parse for PyFunctionAttr {
31 fn parse(input: &ParseBuffer) -> syn::Result<Self> {
32 let attr = Punctuated::<NestedMeta, syn::Token![,]>::parse_terminated(input)?;
33 Self::from_meta(&attr)
34 }
35}
36
37impl PyFunctionAttr {
38 pub fn from_meta<'a>(iter: impl IntoIterator<Item = &'a NestedMeta>) -> syn::Result<Self> {
39 let mut slf = PyFunctionAttr::default();
40
41 for item in iter {
42 slf.add_item(item)?
43 }
44 Ok(slf)
45 }
46
47 pub fn add_item(&mut self, item: &NestedMeta) -> syn::Result<()> {
48 match item {
49 NestedMeta::Meta(syn::Meta::Path(ref ident)) if ident.is_ident("pass_module") => {
50 self.pass_module = true;
51 }
52 NestedMeta::Meta(syn::Meta::Path(ref ident)) => self.add_work(item, ident)?,
53 NestedMeta::Meta(syn::Meta::NameValue(ref nv)) => {
54 self.add_name_value(item, nv)?;
55 }
56 NestedMeta::Lit(ref lit) => {
57 self.add_literal(item, lit)?;
58 }
59 NestedMeta::Meta(syn::Meta::List(ref list)) => {
60 return Err(syn::Error::new_spanned(
61 list,
62 "List is not supported as argument",
63 ));
64 }
65 }
66 Ok(())
67 }
68
69 fn add_literal(&mut self, item: &NestedMeta, lit: &syn::Lit) -> syn::Result<()> {
70 match lit {
71 syn::Lit::Str(ref lits) if lits.value() == "*" => {
72 self.vararg_is_ok(item)?;
74 self.has_varargs = true;
75 self.arguments.push(Argument::VarArgsSeparator);
76 Ok(())
77 }
78 _ => Err(syn::Error::new_spanned(
79 item,
80 format!("Only \"*\" is supported here, got: {:?}", lit),
81 )),
82 }
83 }
84
85 fn add_work(&mut self, item: &NestedMeta, path: &Path) -> syn::Result<()> {
86 if self.has_kw || self.has_kwargs {
87 return Err(syn::Error::new_spanned(
88 item,
89 "Positional argument or varargs(*) is not allowed after keyword arguments",
90 ));
91 }
92 if self.has_varargs {
93 self.arguments.push(Argument::Kwarg(path.clone(), None));
94 } else {
95 self.arguments.push(Argument::Arg(path.clone(), None));
96 }
97 Ok(())
98 }
99
100 fn vararg_is_ok(&self, item: &NestedMeta) -> syn::Result<()> {
101 if self.has_kwargs || self.has_varargs {
102 return Err(syn::Error::new_spanned(
103 item,
104 "* is not allowed after varargs(*) or kwargs(**)",
105 ));
106 }
107 Ok(())
108 }
109
110 fn kw_arg_is_ok(&self, item: &NestedMeta) -> syn::Result<()> {
111 if self.has_kwargs {
112 return Err(syn::Error::new_spanned(
113 item,
114 "Keyword argument or kwargs(**) is not allowed after kwargs(**)",
115 ));
116 }
117 Ok(())
118 }
119
120 fn add_nv_common(
121 &mut self,
122 item: &NestedMeta,
123 name: &syn::Path,
124 value: String,
125 ) -> syn::Result<()> {
126 self.kw_arg_is_ok(item)?;
127 if self.has_varargs {
128 self.arguments
130 .push(Argument::Kwarg(name.clone(), Some(value)));
131 } else {
132 self.has_kw = true;
133 self.arguments
134 .push(Argument::Arg(name.clone(), Some(value)));
135 }
136 Ok(())
137 }
138
139 fn add_name_value(&mut self, item: &NestedMeta, nv: &syn::MetaNameValue) -> syn::Result<()> {
140 match nv.lit {
141 syn::Lit::Str(ref litstr) => {
142 if litstr.value() == "*" {
143 self.vararg_is_ok(item)?;
145 self.has_varargs = true;
146 self.arguments.push(Argument::VarArgs(nv.path.clone()));
147 } else if litstr.value() == "**" {
148 self.kw_arg_is_ok(item)?;
150 self.has_kwargs = true;
151 self.arguments.push(Argument::KeywordArgs(nv.path.clone()));
152 } else {
153 self.add_nv_common(item, &nv.path, litstr.value())?;
154 }
155 }
156 syn::Lit::Int(ref litint) => {
157 self.add_nv_common(item, &nv.path, format!("{}", litint))?;
158 }
159 syn::Lit::Bool(ref litb) => {
160 self.add_nv_common(item, &nv.path, format!("{}", litb.value))?;
161 }
162 _ => {
163 return Err(syn::Error::new_spanned(
164 nv.lit.clone(),
165 "Only string literal is supported",
166 ));
167 }
168 };
169 Ok(())
170 }
171}
172
173pub fn parse_name_attribute(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Option<syn::Ident>> {
174 let mut name_attrs = Vec::new();
175
176 attrs.retain(|attr| match attr.parse_meta() {
178 Ok(syn::Meta::NameValue(ref nv)) if nv.path.is_ident("name") => {
179 name_attrs.push((nv.lit.clone(), attr.span()));
180 false
181 }
182 _ => true,
183 });
184
185 if 1 < name_attrs.len() {
186 return Err(syn::Error::new(
187 name_attrs[0].1,
188 "#[name] can not be specified multiple times",
189 ));
190 }
191
192 match name_attrs.get(0) {
193 Some((syn::Lit::Str(s), span)) => {
194 let mut ident: syn::Ident = s.parse()?;
195 ident.set_span(*span);
197 Ok(Some(ident))
198 }
199 Some((_, span)) => Err(syn::Error::new(
200 *span,
201 "Expected string literal for #[name] argument",
202 )),
203 None => Ok(None),
204 }
205}
206
207pub fn build_py_function(ast: &mut syn::ItemFn, args: PyFunctionAttr) -> syn::Result<TokenStream> {
208 let python_name =
209 parse_name_attribute(&mut ast.attrs)?.unwrap_or_else(|| ast.sig.ident.unraw());
210 add_fn_to_module(ast, python_name, args)
211}
212
#[cfg(test)]
mod test {
    use super::{Argument, PyFunctionAttr};
    use proc_macro2::TokenStream;
    use quote::quote;
    use syn::parse_quote;

    /// Parses `input` as a `PyFunctionAttr` and returns only its argument list.
    fn items(input: TokenStream) -> syn::Result<Vec<Argument>> {
        let py_fn_attr: PyFunctionAttr = syn::parse2(input)?;
        Ok(py_fn_attr.arguments)
    }

    #[test]
    fn test_errs() {
        assert!(items(quote! {test="1", test2}).is_err());
        assert!(items(quote! {test, "*", args="*"}).is_err());
        assert!(items(quote! {test, kwargs="**", args="*"}).is_err());
        assert!(items(quote! {test, kwargs="**", args}).is_err());
        // Only `"*"` is accepted as a bare literal.
        assert!(items(quote! {"**"}).is_err());
        // List-form items are rejected.
        assert!(items(quote! {test(1)}).is_err());
        // Float literals are not accepted as default values.
        assert!(items(quote! {test=1.5}).is_err());
        // A second kwargs parameter is rejected.
        assert!(items(quote! {a="**", b="**"}).is_err());
    }

    #[test]
    fn test_simple_args() {
        let args = items(quote! {test1, test2, test3="None"}).unwrap();
        assert_eq!(
            args,
            vec![
                Argument::Arg(parse_quote! {test1}, None),
                Argument::Arg(parse_quote! {test2}, None),
                Argument::Arg(parse_quote! {test3}, Some("None".to_owned())),
            ]
        );
    }

    #[test]
    fn test_literal_defaults() {
        let args = items(quote! {a=1, b=true}).unwrap();
        assert_eq!(
            args,
            vec![
                Argument::Arg(parse_quote! {a}, Some("1".to_owned())),
                Argument::Arg(parse_quote! {b}, Some("true".to_owned())),
            ]
        );
    }

    #[test]
    fn test_pass_module() {
        let attr: PyFunctionAttr = syn::parse2(quote! {pass_module, a}).unwrap();
        assert!(attr.pass_module);
        assert_eq!(attr.arguments, vec![Argument::Arg(parse_quote! {a}, None)]);
    }

    #[test]
    fn test_varargs() {
        let args = items(quote! {test1, test2="None", "*", test3="None"}).unwrap();
        assert_eq!(
            args,
            vec![
                Argument::Arg(parse_quote! {test1}, None),
                Argument::Arg(parse_quote! {test2}, Some("None".to_owned())),
                Argument::VarArgsSeparator,
                Argument::Kwarg(parse_quote! {test3}, Some("None".to_owned())),
            ]
        );

        let args = items(quote! {"*", test1, test2}).unwrap();
        assert_eq!(
            args,
            vec![
                Argument::VarArgsSeparator,
                Argument::Kwarg(parse_quote! {test1}, None),
                Argument::Kwarg(parse_quote! {test2}, None),
            ]
        );

        let args = items(quote! {"*", test1, test2="None"}).unwrap();
        assert_eq!(
            args,
            vec![
                Argument::VarArgsSeparator,
                Argument::Kwarg(parse_quote! {test1}, None),
                Argument::Kwarg(parse_quote! {test2}, Some("None".to_owned())),
            ]
        );

        let args = items(quote! {"*", test1="None", test2}).unwrap();
        assert_eq!(
            args,
            vec![
                Argument::VarArgsSeparator,
                Argument::Kwarg(parse_quote! {test1}, Some("None".to_owned())),
                Argument::Kwarg(parse_quote! {test2}, None),
            ]
        );
    }

    #[test]
    fn test_all() {
        let args =
            items(quote! {test1, test2="None", args="*", test3="None", kwargs="**"}).unwrap();
        assert_eq!(
            args,
            vec![
                Argument::Arg(parse_quote! {test1}, None),
                Argument::Arg(parse_quote! {test2}, Some("None".to_owned())),
                Argument::VarArgs(parse_quote! {args}),
                Argument::Kwarg(parse_quote! {test3}, Some("None".to_owned())),
                Argument::KeywordArgs(parse_quote! {kwargs}),
            ]
        );
    }
}