1use proc_macro2::TokenStream;
3use pyo3::{Bound, FromPyObject, PyAny, PyResult, prelude::PyAnyMethods};
4use quote::quote;
5use serde::{Deserialize, Serialize};
6
7use crate::{
8 CodeGen, CodeGenContext, ExprType, Node, PythonOptions, SymbolTableScopes,
9};
10
/// An expression together with the source location of the Python node it was
/// extracted from.
///
/// Code generation (`CodeGen::to_rust`) only uses `value`; the location fields
/// are exposed through the `Node` trait.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Argument {
    /// The argument expression.
    pub value: ExprType,
    /// Starting line of the source node, if available.
    pub lineno: Option<usize>,
    /// Starting column offset of the source node, if available.
    pub col_offset: Option<usize>,
    /// Ending line of the source node, if available.
    pub end_lineno: Option<usize>,
    /// Ending column offset of the source node, if available.
    pub end_col_offset: Option<usize>,
}
23
/// A bare argument expression without location information (an alias for
/// [`ExprType`]).
pub type Arg = ExprType;
27
/// A single parameter of a Python function signature, mirroring the fields
/// read from a Python `ast.arg` node (`arg`, `annotation`, `type_comment`).
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Parameter {
    /// Parameter name; used verbatim as the Rust identifier in generated code.
    pub arg: String,
    /// Optional type annotation. When present it is converted into the Rust
    /// parameter type; otherwise generated code uses `impl Into<PyObject>`.
    pub annotation: Option<Box<ExprType>>,
    /// Optional type comment. Extracted, but not read by the code generation
    /// in this module.
    pub type_comment: Option<String>,
    /// Starting line of the source node, if available.
    pub lineno: Option<usize>,
    /// Starting column offset of the source node, if available.
    pub col_offset: Option<usize>,
    /// Ending line of the source node, if available.
    pub end_lineno: Option<usize>,
    /// Ending column offset of the source node, if available.
    pub end_col_offset: Option<usize>,
}
43
/// The parameter list of a Python function, mirroring Python's `ast.arguments`
/// node.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Arguments {
    /// Positional-only parameters (declared before `/`).
    pub posonlyargs: Vec<Parameter>,
    /// Regular positional-or-keyword parameters.
    pub args: Vec<Parameter>,
    /// The `*args` parameter, if any.
    pub vararg: Option<Parameter>,
    /// Keyword-only parameters (declared after `*` or `*args`).
    pub kwonlyargs: Vec<Parameter>,
    /// Defaults for `kwonlyargs`, index-aligned with it; a `None` entry marks a
    /// keyword-only parameter without a default.
    pub kw_defaults: Vec<Option<Box<ExprType>>>,
    /// The `**kwargs` parameter, if any.
    pub kwarg: Option<Parameter>,
    /// Defaults for positional parameters. Per Python's `ast` semantics these
    /// belong to the last `defaults.len()` parameters of `posonlyargs`
    /// followed by `args`.
    pub defaults: Vec<Box<ExprType>>,
}
62
63
/// The arguments of a call, mirroring the `args` and `keywords` attributes of
/// the Python node it is extracted from (as on Python's `ast.Call`).
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct CallArguments {
    /// Positional argument expressions, in call order.
    pub args: Vec<ExprType>,
    /// Keyword arguments, in call order; rendered after `args`.
    pub keywords: Vec<crate::Keyword>,
}
72
73impl<'a> FromPyObject<'a> for Argument {
75 fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
76 let value: ExprType = ob.extract()?;
78
79 Ok(Self {
80 value,
81 lineno: ob.lineno(),
82 col_offset: ob.col_offset(),
83 end_lineno: ob.end_lineno(),
84 end_col_offset: ob.end_col_offset(),
85 })
86 }
87}
88
89impl CodeGen for Argument {
90 type Context = CodeGenContext;
91 type Options = PythonOptions;
92 type SymbolTable = SymbolTableScopes;
93
94 fn to_rust(
95 self,
96 ctx: Self::Context,
97 options: Self::Options,
98 symbols: Self::SymbolTable,
99 ) -> std::result::Result<TokenStream, Box<dyn std::error::Error>> {
100 self.value.to_rust(ctx, options, symbols)
101 }
102}
103
104impl<'a> FromPyObject<'a> for Parameter {
106 fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
107 let arg: String = ob.getattr("arg")?.extract()?;
108
109 let annotation = if let Ok(ann) = ob.getattr("annotation") {
111 if ann.is_none() {
112 None
113 } else {
114 Some(Box::new(ann.extract()?))
115 }
116 } else {
117 None
118 };
119
120 let type_comment = if let Ok(tc) = ob.getattr("type_comment") {
122 if tc.is_none() {
123 None
124 } else {
125 Some(tc.extract()?)
126 }
127 } else {
128 None
129 };
130
131 Ok(Self {
132 arg,
133 annotation,
134 type_comment,
135 lineno: ob.lineno(),
136 col_offset: ob.col_offset(),
137 end_lineno: ob.end_lineno(),
138 end_col_offset: ob.end_col_offset(),
139 })
140 }
141}
142
143impl CodeGen for Parameter {
144 type Context = CodeGenContext;
145 type Options = PythonOptions;
146 type SymbolTable = SymbolTableScopes;
147
148 fn to_rust(
149 self,
150 ctx: Self::Context,
151 options: Self::Options,
152 symbols: Self::SymbolTable,
153 ) -> std::result::Result<TokenStream, Box<dyn std::error::Error>> {
154 use quote::format_ident;
155
156 let param_name = format_ident!("{}", self.arg);
157
158 if let Some(annotation) = self.annotation {
160 let rust_type = annotation.to_rust(ctx, options, symbols)?;
161 Ok(quote!(#param_name: #rust_type))
162 } else {
163 Ok(quote!(#param_name: impl Into<PyObject>))
165 }
166 }
167}
168
169impl<'a> FromPyObject<'a> for Arguments {
171 fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
172 let posonlyargs: Vec<Parameter> = ob.getattr("posonlyargs")?.extract().unwrap_or_default();
174 let args: Vec<Parameter> = ob.getattr("args")?.extract().unwrap_or_default();
175
176 let vararg = if let Ok(va) = ob.getattr("vararg") {
177 if va.is_none() { None } else { Some(va.extract()?) }
178 } else { None };
179
180 let kwonlyargs: Vec<Parameter> = ob.getattr("kwonlyargs")?.extract().unwrap_or_default();
181
182 let kw_defaults = if let Ok(kw_def) = ob.getattr("kw_defaults") {
184 let defaults_list: Vec<Bound<PyAny>> = kw_def.extract().unwrap_or_default();
185 let mut processed_defaults = Vec::new();
186 for default in defaults_list {
187 if default.is_none() {
188 processed_defaults.push(None);
189 } else {
190 processed_defaults.push(Some(Box::new(default.extract()?)));
191 }
192 }
193 processed_defaults
194 } else {
195 Vec::new()
196 };
197
198 let kwarg = if let Ok(kw) = ob.getattr("kwarg") {
199 if kw.is_none() { None } else { Some(kw.extract()?) }
200 } else { None };
201
202 let defaults_raw: Vec<ExprType> = ob.getattr("defaults")?.extract().unwrap_or_default();
203 let defaults = defaults_raw.into_iter().map(Box::new).collect();
204
205 Ok(Self {
206 posonlyargs,
207 args,
208 vararg,
209 kwonlyargs,
210 kw_defaults,
211 kwarg,
212 defaults,
213 })
214 }
215}
216
217impl CodeGen for Arguments {
218 type Context = CodeGenContext;
219 type Options = PythonOptions;
220 type SymbolTable = SymbolTableScopes;
221
222 fn to_rust(
223 self,
224 ctx: Self::Context,
225 options: Self::Options,
226 symbols: Self::SymbolTable,
227 ) -> std::result::Result<TokenStream, Box<dyn std::error::Error>> {
228 let mut params = Vec::new();
229
230 for arg in self.posonlyargs {
232 let param = arg.to_rust(ctx.clone(), options.clone(), symbols.clone())?;
233 params.push(param);
234 }
235
236 let defaults_offset = self.args.len().saturating_sub(self.defaults.len());
238 for (i, arg) in self.args.into_iter().enumerate() {
239 if i >= defaults_offset {
240 let default_idx = i - defaults_offset;
242 let default_value = &self.defaults[default_idx];
243 let _default_rust = default_value.as_ref().clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
244 let param_name = quote::format_ident!("{}", arg.arg);
245
246 if let Some(annotation) = &arg.annotation {
247 let rust_type = annotation.as_ref().clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
248 params.push(quote!(#param_name: Option<#rust_type>));
249 } else {
250 params.push(quote!(#param_name: Option<impl Into<PyObject>>));
251 }
252 } else {
253 let param = arg.to_rust(ctx.clone(), options.clone(), symbols.clone())?;
254 params.push(param);
255 }
256 }
257
258 if let Some(vararg) = self.vararg {
260 let vararg_name = quote::format_ident!("{}", vararg.arg);
261 params.push(quote!(#vararg_name: impl IntoIterator<Item = impl Into<PyObject>>));
262 }
263
264 for (i, arg) in self.kwonlyargs.into_iter().enumerate() {
266 let param_name = quote::format_ident!("{}", arg.arg);
267
268 let has_default = i < self.kw_defaults.len() && self.kw_defaults[i].is_some();
270
271 if let Some(annotation) = &arg.annotation {
272 let rust_type = annotation.as_ref().clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
273 if has_default {
274 params.push(quote!(#param_name: Option<#rust_type>));
275 } else {
276 params.push(quote!(#param_name: #rust_type));
277 }
278 } else {
279 if has_default {
280 params.push(quote!(#param_name: Option<impl Into<PyObject>>));
281 } else {
282 params.push(quote!(#param_name: impl Into<PyObject>));
283 }
284 }
285 }
286
287 if let Some(kwarg) = self.kwarg {
289 let kwarg_name = quote::format_ident!("{}", kwarg.arg);
290 params.push(quote!(#kwarg_name: impl IntoIterator<Item = (impl AsRef<str>, impl Into<PyObject>)>));
291 }
292
293 Ok(quote!(#(#params),*))
294 }
295}
296
297
298impl<'a> FromPyObject<'a> for CallArguments {
300 fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
301 let args: Vec<ExprType> = ob.getattr("args")?.extract().unwrap_or_default();
302 let keywords: Vec<crate::Keyword> = ob.getattr("keywords")?.extract().unwrap_or_default();
303
304 Ok(Self { args, keywords })
305 }
306}
307
308impl CodeGen for CallArguments {
309 type Context = CodeGenContext;
310 type Options = PythonOptions;
311 type SymbolTable = SymbolTableScopes;
312
313 fn to_rust(
314 self,
315 ctx: Self::Context,
316 options: Self::Options,
317 symbols: Self::SymbolTable,
318 ) -> std::result::Result<TokenStream, Box<dyn std::error::Error>> {
319 let mut all_args = Vec::new();
320
321 for arg in self.args {
323 let rust_arg = arg.to_rust(ctx.clone(), options.clone(), symbols.clone())?;
324 all_args.push(rust_arg);
325 }
326
327 for keyword in self.keywords {
329 let rust_kw = keyword.to_rust(ctx.clone(), options.clone(), symbols.clone())?;
330 all_args.push(rust_kw);
331 }
332
333 Ok(quote!(#(#all_args),*))
334 }
335}
336
337
338impl Node for Argument {
340 fn lineno(&self) -> Option<usize> { self.lineno }
341 fn col_offset(&self) -> Option<usize> { self.col_offset }
342 fn end_lineno(&self) -> Option<usize> { self.end_lineno }
343 fn end_col_offset(&self) -> Option<usize> { self.end_col_offset }
344}
345
346impl Node for Parameter {
347 fn lineno(&self) -> Option<usize> { self.lineno }
348 fn col_offset(&self) -> Option<usize> { self.col_offset }
349 fn end_lineno(&self) -> Option<usize> { self.end_lineno }
350 fn end_col_offset(&self) -> Option<usize> { self.end_col_offset }
351}
352
353
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{parse, CodeGenContext, ExprType, PythonOptions, SymbolTableScopes};
    use test_log::test;

    /// Parses `code` (as `test.py`) and converts it to Rust in a module
    /// context named `test`, panicking — and so failing the calling test — if
    /// either step errors.
    fn transpile(code: &str) -> proc_macro2::TokenStream {
        let module = parse(code, "test.py").unwrap();
        module
            .to_rust(
                CodeGenContext::Module("test".to_string()),
                PythonOptions::default(),
                SymbolTableScopes::new(),
            )
            .unwrap()
    }

    /// Converts a single argument expression in a module context named `test`
    /// and returns the generated tokens as a string.
    fn arg_to_rust_string(arg: Arg) -> String {
        arg.to_rust(
            CodeGenContext::Module("test".to_string()),
            PythonOptions::default(),
            SymbolTableScopes::new(),
        )
        .unwrap()
        .to_string()
    }

    #[test]
    fn test_simple_function_call() {
        transpile("func(1, 2, 3)");
    }

    #[test]
    fn test_keyword_arguments() {
        transpile("func(a=1, b=2)");
    }

    #[test]
    fn test_mixed_arguments() {
        transpile("func(1, 2, c=3, d=4)");
    }

    #[test]
    fn test_function_with_defaults() {
        transpile(
            r#"
def func(a, b=2, c=3):
    pass
"#,
        );
    }

    #[test]
    fn test_function_with_varargs() {
        transpile(
            r#"
def func(a, *args):
    pass
"#,
        );
    }

    #[test]
    fn test_function_with_kwargs() {
        transpile(
            r#"
def func(a, **kwargs):
    pass
"#,
        );
    }

    #[test]
    fn test_complex_function_signature() {
        transpile(
            r#"
def func(a, b=2, *args, c, d=4, **kwargs):
    pass
"#,
        );
    }

    #[test]
    fn test_keyword_only_arguments() {
        transpile(
            r#"
def func(a, *, b, c=3):
    pass
"#,
        );
    }

    /// Only checks that parsing and conversion do not panic.
    ///
    /// NOTE(review): errors from either step are deliberately ignored,
    /// presumably because call-site `*`/`**` unpacking is not fully supported
    /// yet; tighten this once it is.
    #[test]
    fn test_argument_unpacking_call() {
        if let Ok(ast) = parse("func(*args, **kwargs)", "test.py") {
            let _ = ast.to_rust(
                CodeGenContext::Module("test".to_string()),
                PythonOptions::default(),
                SymbolTableScopes::new(),
            );
        }
    }

    #[test]
    fn test_arg_with_constant() {
        use litrs::Literal;

        let literal = Literal::parse("42").unwrap().into_owned();
        let arg: Arg = ExprType::Constant(crate::Constant(Some(literal)));

        assert!(arg_to_rust_string(arg).contains("42"));
    }

    #[test]
    fn test_arg_with_name() {
        let arg: Arg = ExprType::Name(crate::Name {
            id: "variable".to_string(),
        });

        assert!(arg_to_rust_string(arg).contains("variable"));
    }
}
566}