1use crate::{Arg, CodeGen, CodeGenContext, Node, PythonOptions, SymbolTableScopes};
2
3use proc_macro2::TokenStream;
4
5use std::default::Default;
6
7use pyo3::FromPyObject;
8use quote::{format_ident, quote};
9
10use serde::{Deserialize, Serialize};
11
12#[derive(Clone, Debug, Default, FromPyObject, PartialEq, Serialize, Deserialize)]
13pub struct Parameter {
14 pub arg: String,
15}
16
17impl CodeGen for Parameter {
18 type Context = CodeGenContext;
19 type Options = PythonOptions;
20 type SymbolTable = SymbolTableScopes;
21
22 fn to_rust(
23 self,
24 _ctx: Self::Context,
25 _options: Self::Options,
26 _symbols: Self::SymbolTable,
27 ) -> Result<TokenStream, Box<dyn std::error::Error>> {
28 let ident = format_ident!("{}", self.arg);
29 Ok(quote! {
30 #ident: PyObject
31 })
32 }
33}
34#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
36pub struct ParameterList {
37 pub posonlyargs: Vec<Parameter>,
38 pub args: Vec<Parameter>,
39 pub vararg: Option<Parameter>,
40 pub kwonlyargs: Vec<Parameter>,
41 pub kw_defaults: Vec<Arg>,
42 pub kwarg: Option<Parameter>,
43 pub defaults: Vec<Arg>,
44}
45
46use pyo3::{PyAny, PyResult};
47
48impl<'source> FromPyObject<'source> for ParameterList {
55 fn extract(ob: &'source PyAny) -> PyResult<Self> {
56 let err_msg = ob.error_message("<unknown>", "failed extracting posonlyargs");
57 let posonlyargs = ob.getattr("posonlyargs").expect(err_msg.as_str());
58 let posonlyargs_list: Vec<Parameter> = posonlyargs
59 .extract()
60 .expect("failed extracting posonlyargs");
61
62 let err_msg = ob.error_message("<unknown>", "failed extracting args");
63 let args = ob.getattr("args").expect(err_msg.as_str());
64 let args_list: Vec<Parameter> = args.extract().expect(err_msg.as_str());
65
66 let err_msg = ob.error_message("<unknown>", "failed extracting varargs");
67 let vararg = ob.getattr("vararg").expect(err_msg.as_str());
68 let vararg_option: Option<Parameter> = vararg.extract().expect(err_msg.as_str());
69
70 let err_msg = ob.error_message("<unknown>", "failed extracting kwonlyargs");
71 let kwonlyargs = ob.getattr("kwonlyargs").expect(err_msg.as_str());
72 let kwonlyargs_list: Vec<Parameter> = kwonlyargs.extract().expect(err_msg.as_str());
73
74 let err_msg = ob.error_message("<unknown>", "failed extracting kw_default");
75 let kw_defaults = ob.getattr("kw_defaults").expect(err_msg.as_str());
76 let kw_defaults_list: Vec<Arg> = if let Ok(list) = kw_defaults.extract() {
77 list
78 } else {
79 Vec::new()
80 };
81
82 let err_msg = ob.error_message("<unknown>", "failed extracting kwargs");
83 let kwarg = ob.getattr("kwarg").expect(err_msg.as_str());
84 let kwarg_option: Option<Parameter> = kwarg.extract().expect(err_msg.as_str());
85
86 let err_msg = ob.error_message("<unknown>", "failed extracting defaults");
87 let defaults = ob.getattr("defaults").expect(err_msg.as_str());
88 let defaults_list: Vec<Arg> = defaults.extract().expect(err_msg.as_str());
89
90 Ok(ParameterList {
91 posonlyargs: posonlyargs_list,
92 args: args_list,
93 vararg: vararg_option,
94 kwonlyargs: kwonlyargs_list,
95 kw_defaults: kw_defaults_list,
96 kwarg: kwarg_option,
97 defaults: defaults_list,
98
99 ..Default::default()
100 })
101 }
102}
103
104impl CodeGen for ParameterList {
105 type Context = CodeGenContext;
106 type Options = PythonOptions;
107 type SymbolTable = SymbolTableScopes;
108
109 fn to_rust(
110 self,
111 ctx: Self::Context,
112 options: Self::Options,
113 symbols: Self::SymbolTable,
114 ) -> Result<TokenStream, Box<dyn std::error::Error>> {
115 let mut stream = TokenStream::new();
116
117 for arg in self.args {
119 stream.extend(
120 arg.clone()
121 .to_rust(ctx.clone(), options.clone(), symbols.clone())
122 .expect(format!("generating arg {:?}", arg).as_str()),
123 );
124 stream.extend(quote!(,));
125 }
126
127 if let Some(arg) = self.vararg {
129 let name = format_ident!("{}", arg.arg);
130 stream.extend(quote!(#name: Vec<PyAny>));
131 stream.extend(quote!(,));
132 }
133
134 for arg in self.kwonlyargs {
136 stream.extend(
137 arg.clone()
138 .to_rust(ctx.clone(), options.clone(), symbols.clone())
139 .expect(format!("generating kwonlyarg {:?}", arg).as_str()),
140 );
141 stream.extend(quote!(,));
142 }
143
144 if let Some(arg) = self.kwarg {
146 let name = format_ident!("{}", arg.arg);
147 stream.extend(quote!(#name: PyDict<PyAny>));
148 stream.extend(quote!(,));
149 }
150
151 Ok(quote!(#stream))
152 }
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use test_log::test;

    use crate::parse;
    use crate::tree::statement::StatementType;
    use crate::tree::Module;
    use pyo3::PyResult;

    fn setup(input: &str) -> PyResult<Module> {
        let ast = parse(input, "__test__.py")?;
        Ok(ast)
    }

    /// Parses `source` and returns the parameter list of its first statement.
    ///
    /// Panics if that statement is not a function definition.
    fn parameters_of(source: &str) -> ParameterList {
        let module = setup(source).unwrap();
        let function_def_statement = module.raw.body[0].clone();

        if let StatementType::FunctionDef(f) = function_def_statement.statement {
            ParameterList::clone(&f.args)
        } else {
            panic!(
                "Expected function definition, found {:#?}",
                function_def_statement
            );
        }
    }

    /// Shorthand for a `Parameter` with the given name.
    fn param(name: &str) -> Parameter {
        Parameter {
            arg: name.to_string(),
        }
    }

    #[test]
    fn no_parameters() {
        let params = parameters_of("def foo():\n pass\n");
        assert_eq!(params.args.len(), 0);
    }

    #[test]
    fn one_parameter() {
        let params = parameters_of("def foo1(a):\n pass\n");
        assert_eq!(params.args.len(), 1);
    }

    #[test]
    fn multiple_positional_parameter() {
        let params = parameters_of("def foo2(a, b, c):\n pass\n");
        assert_eq!(params.args.len(), 3);
    }

    #[test]
    fn vararg_only() {
        let params = parameters_of("def foo3(*a):\n pass\n");
        assert_eq!(params.args.len(), 0);
        assert_eq!(params.vararg, Some(param("a")));
    }

    #[test]
    fn positional_and_vararg() {
        let params = parameters_of("def foo4(a, *b):\n pass\n");
        assert_eq!(params.args.len(), 1);
        assert_eq!(params.vararg, Some(param("b")));
    }

    #[test]
    fn positional_and_vararg_and_kw() {
        let params = parameters_of("def foo5(a, *b, c=7):\n pass\n");
        assert_eq!(params.args.len(), 1);
        assert_eq!(params.vararg, Some(param("b")));
        assert_eq!(params.kwonlyargs, vec![param("c")]);
    }

    #[test]
    fn positional_and_kw() {
        let params = parameters_of("def foo6(a, c=7):\n pass\n");
        assert_eq!(params.args.len(), 2);
        assert_eq!(params.defaults.len(), 1);
    }

    #[test]
    fn default_only() {
        let params = parameters_of("def foo7(a=7):\n pass\n");
        assert_eq!(params.args.len(), 1);
        assert_eq!(params.defaults.len(), 1);
    }

    #[test]
    fn kwargs_only() {
        let params = parameters_of("def foo8(**a):\n pass\n");
        assert_eq!(params.args.len(), 0);
        assert_eq!(params.kwarg, Some(param("a")));
    }

    #[test]
    fn named_and_positional() {
        let params = parameters_of("def foo9(a, *, b):\n pass\n");
        assert_eq!(params.args.len(), 1);
        assert_eq!(params.vararg, None);
        assert_eq!(params.kwonlyargs, vec![param("b")]);
    }

    #[test]
    fn positional_only_and_positional() {
        let params = parameters_of("def foo10(a, /, b):\n pass\n");
        assert_eq!(params.posonlyargs, vec![param("a")]);
        assert_eq!(params.args, vec![param("b")]);
    }
}