python_ast/ast/tree/parameters.rs

1use crate::{Arg, CodeGen, CodeGenContext, Node, PythonOptions, SymbolTableScopes};
2
3use proc_macro2::TokenStream;
4
5use std::default::Default;
6
7use pyo3::FromPyObject;
8use quote::{format_ident, quote};
9
10use serde::{Deserialize, Serialize};
11
12#[derive(Clone, Debug, Default, FromPyObject, PartialEq, Serialize, Deserialize)]
13pub struct Parameter {
14    pub arg: String,
15}
16
17impl CodeGen for Parameter {
18    type Context = CodeGenContext;
19    type Options = PythonOptions;
20    type SymbolTable = SymbolTableScopes;
21
22    fn to_rust(
23        self,
24        _ctx: Self::Context,
25        _options: Self::Options,
26        _symbols: Self::SymbolTable,
27    ) -> Result<TokenStream, Box<dyn std::error::Error>> {
28        let ident = format_ident!("{}", self.arg);
29        Ok(quote! {
30            #ident: PyObject
31        })
32    }
33}
34/// The parameter list of a function.
35#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
36pub struct ParameterList {
37    pub posonlyargs: Vec<Parameter>,
38    pub args: Vec<Parameter>,
39    pub vararg: Option<Parameter>,
40    pub kwonlyargs: Vec<Parameter>,
41    pub kw_defaults: Vec<Arg>,
42    pub kwarg: Option<Parameter>,
43    pub defaults: Vec<Arg>,
44}
45
46use pyo3::{PyAny, PyResult};
47
48// We have to manually implement the conversion of ParameterList objects
49// because under a number of conditions, attributes that should be lists
50// are unset, which causes them to be retrieved as None, which causes the
51// derived implementation to error when converting to a Vec type. It would
52// be nice if they generated empty Vecs instead, but since it doesn't, we
53// have to do it manually.
54impl<'source> FromPyObject<'source> for ParameterList {
55    fn extract(ob: &'source PyAny) -> PyResult<Self> {
56        let err_msg = ob.error_message("<unknown>", "failed extracting posonlyargs");
57        let posonlyargs = ob.getattr("posonlyargs").expect(err_msg.as_str());
58        let posonlyargs_list: Vec<Parameter> = posonlyargs
59            .extract()
60            .expect("failed extracting posonlyargs");
61
62        let err_msg = ob.error_message("<unknown>", "failed extracting args");
63        let args = ob.getattr("args").expect(err_msg.as_str());
64        let args_list: Vec<Parameter> = args.extract().expect(err_msg.as_str());
65
66        let err_msg = ob.error_message("<unknown>", "failed extracting varargs");
67        let vararg = ob.getattr("vararg").expect(err_msg.as_str());
68        let vararg_option: Option<Parameter> = vararg.extract().expect(err_msg.as_str());
69
70        let err_msg = ob.error_message("<unknown>", "failed extracting kwonlyargs");
71        let kwonlyargs = ob.getattr("kwonlyargs").expect(err_msg.as_str());
72        let kwonlyargs_list: Vec<Parameter> = kwonlyargs.extract().expect(err_msg.as_str());
73
74        let err_msg = ob.error_message("<unknown>", "failed extracting kw_default");
75        let kw_defaults = ob.getattr("kw_defaults").expect(err_msg.as_str());
76        let kw_defaults_list: Vec<Arg> = if let Ok(list) = kw_defaults.extract() {
77            list
78        } else {
79            Vec::new()
80        };
81
82        let err_msg = ob.error_message("<unknown>", "failed extracting kwargs");
83        let kwarg = ob.getattr("kwarg").expect(err_msg.as_str());
84        let kwarg_option: Option<Parameter> = kwarg.extract().expect(err_msg.as_str());
85
86        let err_msg = ob.error_message("<unknown>", "failed extracting defaults");
87        let defaults = ob.getattr("defaults").expect(err_msg.as_str());
88        let defaults_list: Vec<Arg> = defaults.extract().expect(err_msg.as_str());
89
90        Ok(ParameterList {
91            posonlyargs: posonlyargs_list,
92            args: args_list,
93            vararg: vararg_option,
94            kwonlyargs: kwonlyargs_list,
95            kw_defaults: kw_defaults_list,
96            kwarg: kwarg_option,
97            defaults: defaults_list,
98
99            ..Default::default()
100        })
101    }
102}
103
104impl CodeGen for ParameterList {
105    type Context = CodeGenContext;
106    type Options = PythonOptions;
107    type SymbolTable = SymbolTableScopes;
108
109    fn to_rust(
110        self,
111        ctx: Self::Context,
112        options: Self::Options,
113        symbols: Self::SymbolTable,
114    ) -> Result<TokenStream, Box<dyn std::error::Error>> {
115        let mut stream = TokenStream::new();
116
117        // Ordinary args
118        for arg in self.args {
119            stream.extend(
120                arg.clone()
121                    .to_rust(ctx.clone(), options.clone(), symbols.clone())
122                    .expect(format!("generating arg {:?}", arg).as_str()),
123            );
124            stream.extend(quote!(,));
125        }
126
127        // Variable positional arg
128        if let Some(arg) = self.vararg {
129            let name = format_ident!("{}", arg.arg);
130            stream.extend(quote!(#name: Vec<PyAny>));
131            stream.extend(quote!(,));
132        }
133
134        // kwonlyargs
135        for arg in self.kwonlyargs {
136            stream.extend(
137                arg.clone()
138                    .to_rust(ctx.clone(), options.clone(), symbols.clone())
139                    .expect(format!("generating kwonlyarg {:?}", arg).as_str()),
140            );
141            stream.extend(quote!(,));
142        }
143
144        // kwarg
145        if let Some(arg) = self.kwarg {
146            let name = format_ident!("{}", arg.arg);
147            stream.extend(quote!(#name: PyDict<PyAny>));
148            stream.extend(quote!(,));
149        }
150
151        Ok(quote!(#stream))
152    }
153}
154
// It's fairly easy to break the automatic parsing of parameter structs, so we
// need thorough test coverage for each kind of parameter: positional-only,
// positional, variadic (`*args`), keyword-only, variadic keyword (`**kwargs`),
// and parameters with defaults.
#[cfg(test)]
mod tests {
    use super::*;
    use test_log::test;

    use crate::parse;
    use crate::tree::statement::StatementType;
    use crate::tree::Module;
    use pyo3::PyResult;

    /// Parses `input` as a Python module named `__test__.py`.
    fn setup(input: &str) -> PyResult<Module> {
        let ast = parse(input, "__test__.py")?;
        Ok(ast)
    }

    #[test]
    fn no_parameters() {
        let test_function = "def foo():\n    pass\n";
        let module = setup(test_function).unwrap();

        let function_def_statement = module.raw.body[0].clone();

        if let StatementType::FunctionDef(f) = function_def_statement.statement {
            assert_eq!(f.args.args.len(), 0)
        } else {
            panic!(
                "Expected function definition, found {:#?}",
                function_def_statement
            );
        }
    }

    #[test]
    fn one_parameter() {
        let test_function = "def foo1(a):\n    pass\n";
        let module = setup(test_function).unwrap();

        let function_def_statement = module.raw.body[0].clone();

        if let StatementType::FunctionDef(f) = function_def_statement.statement {
            assert_eq!(f.args.args.len(), 1)
        } else {
            panic!(
                "Expected function definition, found {:#?}",
                function_def_statement
            );
        }
    }

    #[test]
    fn multiple_positional_parameter() {
        let test_function = "def foo2(a, b, c):\n    pass\n";
        let module = setup(test_function).unwrap();

        let function_def_statement = module.raw.body[0].clone();

        if let StatementType::FunctionDef(f) = function_def_statement.statement {
            assert_eq!(f.args.args.len(), 3)
        } else {
            panic!(
                "Expected function definition, found {:#?}",
                function_def_statement
            );
        }
    }

    #[test]
    fn vararg_only() {
        let test_function = "def foo3(*a):\n    pass\n";
        let module = setup(test_function).unwrap();

        let function_def_statement = module.raw.body[0].clone();

        if let StatementType::FunctionDef(f) = function_def_statement.statement {
            assert_eq!(f.args.args.len(), 0);
            assert_eq!(
                f.args.vararg,
                Some(Parameter {
                    arg: "a".to_string()
                })
            );
        } else {
            panic!(
                "Expected function definition, found {:#?}",
                function_def_statement
            );
        }
    }

    #[test]
    fn positional_and_vararg() {
        let test_function = "def foo4(a, *b):\n    pass\n";
        let module = setup(test_function).unwrap();

        let function_def_statement = module.raw.body[0].clone();

        if let StatementType::FunctionDef(f) = function_def_statement.statement {
            assert_eq!(f.args.args.len(), 1);
            assert_eq!(
                f.args.vararg,
                Some(Parameter {
                    arg: "b".to_string()
                })
            );
        } else {
            panic!(
                "Expected function definition, found {:#?}",
                function_def_statement
            );
        }
    }

    #[test]
    fn positional_and_vararg_and_kw() {
        let test_function = "def foo5(a, *b, c=7):\n    pass\n";
        let module = setup(test_function).unwrap();

        let function_def_statement = module.raw.body[0].clone();

        if let StatementType::FunctionDef(f) = function_def_statement.statement {
            assert_eq!(f.args.args.len(), 1);
            assert_eq!(
                f.args.vararg,
                Some(Parameter {
                    arg: "b".to_string()
                })
            );
            assert_eq!(
                f.args.kwonlyargs,
                vec![Parameter {
                    arg: "c".to_string()
                }]
            );
        } else {
            panic!(
                "Expected function definition, found {:#?}",
                function_def_statement
            );
        }
    }

    #[test]
    fn positional_and_kw() {
        let test_function = "def foo6(a, c=7):\n    pass\n";
        let module = setup(test_function).unwrap();

        let function_def_statement = module.raw.body[0].clone();

        if let StatementType::FunctionDef(f) = function_def_statement.statement {
            assert_eq!(f.args.args.len(), 2);
            assert_eq!(f.args.defaults.len(), 1);
            // TODO: also assert that the single default is the constant `7`.
        } else {
            panic!(
                "Expected function definition, found {:#?}",
                function_def_statement
            );
        }
    }

    #[test]
    fn default_only() {
        let test_function = "def foo7(a=7):\n    pass\n";
        let module = setup(test_function).unwrap();

        let function_def_statement = module.raw.body[0].clone();

        if let StatementType::FunctionDef(f) = function_def_statement.statement {
            assert_eq!(f.args.args.len(), 1);
            assert_eq!(f.args.defaults.len(), 1);
            // TODO: also assert that the single default is the constant `7`.
        } else {
            panic!(
                "Expected function definition, found {:#?}",
                function_def_statement
            );
        }
    }

    #[test]
    fn kwargs_only() {
        let test_function = "def foo8(**a):\n    pass\n";
        let module = setup(test_function).unwrap();

        let function_def_statement = module.raw.body[0].clone();

        if let StatementType::FunctionDef(f) = function_def_statement.statement {
            assert_eq!(f.args.args.len(), 0);
            assert_eq!(
                f.args.kwarg,
                Some(Parameter {
                    arg: "a".to_string()
                })
            );
        } else {
            panic!(
                "Expected function definition, found {:#?}",
                function_def_statement
            );
        }
    }

    #[test]
    fn named_and_positional() {
        let test_function = "def foo9(a, *, b):\n    pass\n";
        let module = setup(test_function).unwrap();

        let function_def_statement = module.raw.body[0].clone();

        if let StatementType::FunctionDef(f) = function_def_statement.statement {
            assert_eq!(f.args.args.len(), 1);
            assert_eq!(f.args.vararg, None);
            assert_eq!(
                f.args.kwonlyargs,
                vec![Parameter {
                    arg: "b".to_string()
                }]
            );
        } else {
            panic!(
                "Expected function definition, found {:#?}",
                function_def_statement
            );
        }
    }

    #[test]
    fn positional_only_and_positional() {
        // Parameters before `/` are positional-only and must not be mixed
        // into `args`.
        let test_function = "def foo10(a, /, b):\n    pass\n";
        let module = setup(test_function).unwrap();

        let function_def_statement = module.raw.body[0].clone();

        if let StatementType::FunctionDef(f) = function_def_statement.statement {
            assert_eq!(
                f.args.posonlyargs,
                vec![Parameter {
                    arg: "a".to_string()
                }]
            );
            assert_eq!(
                f.args.args,
                vec![Parameter {
                    arg: "b".to_string()
                }]
            );
        } else {
            panic!(
                "Expected function definition, found {:#?}",
                function_def_statement
            );
        }
    }
}