1use proc_macro2::TokenStream;
3use pyo3::{Bound, FromPyObject, PyAny, PyResult, prelude::PyAnyMethods};
4use quote::quote;
5use serde::{Deserialize, Serialize};
6
7use crate::{
8 CodeGen, CodeGenContext, ExprType, Node, PythonOptions, SymbolTableScopes,
9};
10
/// An argument expression together with the source span of the Python node
/// it was extracted from.
///
/// The location fields are filled in by the `FromPyObject` impl below from the
/// object's `lineno` / `col_offset` / `end_lineno` / `end_col_offset`
/// accessors; each of them is optional.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Argument {
    /// The argument expression itself; code generation delegates to it.
    pub value: ExprType,
    /// Starting line of the node, if recorded.
    pub lineno: Option<usize>,
    /// Starting column of the node, if recorded.
    pub col_offset: Option<usize>,
    /// Ending line of the node, if recorded.
    pub end_lineno: Option<usize>,
    /// Ending column of the node, if recorded.
    pub end_col_offset: Option<usize>,
}
23
/// A bare argument expression, without the source-location metadata that
/// [`Argument`] carries.
pub type Arg = ExprType;
27
/// A single formal parameter of a function signature.
///
/// Built from a Python object exposing `arg`, `annotation` and `type_comment`
/// attributes (the fields of Python's `ast.arg` node).
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Parameter {
    /// The parameter name; emitted verbatim as a Rust identifier.
    pub arg: String,
    /// Optional type annotation; when present it becomes the Rust parameter type.
    pub annotation: Option<Box<ExprType>>,
    /// Optional `# type:` comment text, kept but not used by code generation here.
    pub type_comment: Option<String>,
    /// Starting line of the node, if recorded.
    pub lineno: Option<usize>,
    /// Starting column of the node, if recorded.
    pub col_offset: Option<usize>,
    /// Ending line of the node, if recorded.
    pub end_lineno: Option<usize>,
    /// Ending column of the node, if recorded.
    pub end_col_offset: Option<usize>,
}
43
/// The full parameter list of a function definition, mirroring the attributes
/// of Python's `ast.arguments` node.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Arguments {
    /// Positional-only parameters (those before `/`).
    pub posonlyargs: Vec<Parameter>,
    /// Regular positional-or-keyword parameters.
    pub args: Vec<Parameter>,
    /// The `*args` parameter, if any.
    pub vararg: Option<Parameter>,
    /// Keyword-only parameters (those after `*` or `*args`).
    pub kwonlyargs: Vec<Parameter>,
    /// Defaults for `kwonlyargs`, index-aligned with it; `None` marks a
    /// keyword-only parameter without a default.
    pub kw_defaults: Vec<Option<Box<ExprType>>>,
    /// The `**kwargs` parameter, if any.
    pub kwarg: Option<Parameter>,
    /// Default values for positional parameters. Per the Python `ast`
    /// documentation these belong to the *last* N parameters of
    /// `posonlyargs` followed by `args`.
    pub defaults: Vec<Box<ExprType>>,
}
62
63
/// The arguments supplied at a call site: positional expressions followed by
/// keyword arguments.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct CallArguments {
    /// Positional argument expressions, in call order.
    pub args: Vec<ExprType>,
    /// Keyword arguments, in call order; emitted after the positional ones.
    pub keywords: Vec<crate::Keyword>,
}
72
73impl<'a> FromPyObject<'a> for Argument {
75 fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
76 let value: ExprType = ob.extract()?;
78
79 Ok(Self {
80 value,
81 lineno: ob.lineno(),
82 col_offset: ob.col_offset(),
83 end_lineno: ob.end_lineno(),
84 end_col_offset: ob.end_col_offset(),
85 })
86 }
87}
88
89impl CodeGen for Argument {
90 type Context = CodeGenContext;
91 type Options = PythonOptions;
92 type SymbolTable = SymbolTableScopes;
93
94 fn to_rust(
95 self,
96 ctx: Self::Context,
97 options: Self::Options,
98 symbols: Self::SymbolTable,
99 ) -> std::result::Result<TokenStream, Box<dyn std::error::Error>> {
100 self.value.to_rust(ctx, options, symbols)
101 }
102}
103
104impl<'a> FromPyObject<'a> for Parameter {
106 fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
107 let arg: String = ob.getattr("arg")?.extract()?;
108
109 let annotation = if let Ok(ann) = ob.getattr("annotation") {
111 if ann.is_none() {
112 None
113 } else {
114 Some(Box::new(ann.extract()?))
115 }
116 } else {
117 None
118 };
119
120 let type_comment = if let Ok(tc) = ob.getattr("type_comment") {
122 if tc.is_none() {
123 None
124 } else {
125 Some(tc.extract()?)
126 }
127 } else {
128 None
129 };
130
131 Ok(Self {
132 arg,
133 annotation,
134 type_comment,
135 lineno: ob.lineno(),
136 col_offset: ob.col_offset(),
137 end_lineno: ob.end_lineno(),
138 end_col_offset: ob.end_col_offset(),
139 })
140 }
141}
142
143impl CodeGen for Parameter {
144 type Context = CodeGenContext;
145 type Options = PythonOptions;
146 type SymbolTable = SymbolTableScopes;
147
148 fn to_rust(
149 self,
150 ctx: Self::Context,
151 options: Self::Options,
152 symbols: Self::SymbolTable,
153 ) -> std::result::Result<TokenStream, Box<dyn std::error::Error>> {
154 use quote::format_ident;
155
156 let param_name = format_ident!("{}", self.arg);
157
158 if let Some(annotation) = self.annotation {
160 let rust_type = annotation.to_rust(ctx, options, symbols)?;
161 Ok(quote!(#param_name: #rust_type))
162 } else {
163 Ok(quote!(#param_name: impl Into<PyObject>))
165 }
166 }
167}
168
169impl<'a> FromPyObject<'a> for Arguments {
171 fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
172 let posonlyargs: Vec<Parameter> = ob.getattr("posonlyargs")?.extract().unwrap_or_default();
174 let args: Vec<Parameter> = ob.getattr("args")?.extract().unwrap_or_default();
175
176 let vararg = if let Ok(va) = ob.getattr("vararg") {
177 if va.is_none() { None } else { Some(va.extract()?) }
178 } else { None };
179
180 let kwonlyargs: Vec<Parameter> = ob.getattr("kwonlyargs")?.extract().unwrap_or_default();
181
182 let kw_defaults = if let Ok(kw_def) = ob.getattr("kw_defaults") {
184 let defaults_list: Vec<Bound<PyAny>> = kw_def.extract().unwrap_or_default();
185 let mut processed_defaults = Vec::new();
186 for default in defaults_list {
187 if default.is_none() {
188 processed_defaults.push(None);
189 } else {
190 processed_defaults.push(Some(Box::new(default.extract()?)));
191 }
192 }
193 processed_defaults
194 } else {
195 Vec::new()
196 };
197
198 let kwarg = if let Ok(kw) = ob.getattr("kwarg") {
199 if kw.is_none() { None } else { Some(kw.extract()?) }
200 } else { None };
201
202 let defaults_raw: Vec<ExprType> = ob.getattr("defaults")?.extract().unwrap_or_default();
203 let defaults = defaults_raw.into_iter().map(Box::new).collect();
204
205 Ok(Self {
206 posonlyargs,
207 args,
208 vararg,
209 kwonlyargs,
210 kw_defaults,
211 kwarg,
212 defaults,
213 })
214 }
215}
216
217impl CodeGen for Arguments {
218 type Context = CodeGenContext;
219 type Options = PythonOptions;
220 type SymbolTable = SymbolTableScopes;
221
222 fn to_rust(
223 self,
224 ctx: Self::Context,
225 options: Self::Options,
226 symbols: Self::SymbolTable,
227 ) -> std::result::Result<TokenStream, Box<dyn std::error::Error>> {
228 let mut params = Vec::new();
229
230 for arg in self.posonlyargs {
232 let param = arg.to_rust(ctx.clone(), options.clone(), symbols.clone())?;
233 params.push(param);
234 }
235
236 let defaults_offset = self.args.len().saturating_sub(self.defaults.len());
238 for (i, arg) in self.args.into_iter().enumerate() {
239 if i >= defaults_offset {
240 let default_idx = i - defaults_offset;
242 let default_value = &self.defaults[default_idx];
243 let _default_rust = default_value.as_ref().clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
244 let param_name = quote::format_ident!("{}", arg.arg);
245
246 if let Some(annotation) = &arg.annotation {
247 let rust_type = annotation.as_ref().clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
248 params.push(quote!(#param_name: Option<#rust_type>));
249 } else {
250 params.push(quote!(#param_name: Option<impl Into<PyObject>>));
251 }
252 } else {
253 let param = arg.to_rust(ctx.clone(), options.clone(), symbols.clone())?;
254 params.push(param);
255 }
256 }
257
258 if let Some(vararg) = self.vararg {
260 let vararg_name = quote::format_ident!("{}", vararg.arg);
261 params.push(quote!(#vararg_name: impl IntoIterator<Item = impl Into<PyObject>>));
262 }
263
264 for (i, arg) in self.kwonlyargs.into_iter().enumerate() {
266 let param_name = quote::format_ident!("{}", arg.arg);
267
268 let has_default = i < self.kw_defaults.len() && self.kw_defaults[i].is_some();
270
271 if let Some(annotation) = &arg.annotation {
272 let rust_type = annotation.as_ref().clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
273 if has_default {
274 params.push(quote!(#param_name: Option<#rust_type>));
275 } else {
276 params.push(quote!(#param_name: #rust_type));
277 }
278 } else {
279 if has_default {
280 params.push(quote!(#param_name: Option<impl Into<PyObject>>));
281 } else {
282 params.push(quote!(#param_name: impl Into<PyObject>));
283 }
284 }
285 }
286
287 if let Some(kwarg) = self.kwarg {
289 let kwarg_name = quote::format_ident!("{}", kwarg.arg);
290 params.push(quote!(#kwarg_name: impl IntoIterator<Item = (impl AsRef<str>, impl Into<PyObject>)>));
291 }
292
293 Ok(quote!(#(#params),*))
294 }
295}
296
297
298impl<'a> FromPyObject<'a> for CallArguments {
300 fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
301 let args: Vec<ExprType> = ob.getattr("args")?.extract().unwrap_or_default();
302 let keywords: Vec<crate::Keyword> = ob.getattr("keywords")?.extract().unwrap_or_default();
303
304 Ok(Self { args, keywords })
305 }
306}
307
308impl CodeGen for CallArguments {
309 type Context = CodeGenContext;
310 type Options = PythonOptions;
311 type SymbolTable = SymbolTableScopes;
312
313 fn to_rust(
314 self,
315 ctx: Self::Context,
316 options: Self::Options,
317 symbols: Self::SymbolTable,
318 ) -> std::result::Result<TokenStream, Box<dyn std::error::Error>> {
319 let mut all_args = Vec::new();
320
321 for arg in self.args {
323 let rust_arg = arg.to_rust(ctx.clone(), options.clone(), symbols.clone())?;
324 all_args.push(rust_arg);
325 }
326
327 for keyword in self.keywords {
329 let rust_kw = keyword.to_rust(ctx.clone(), options.clone(), symbols.clone())?;
330 all_args.push(rust_kw);
331 }
332
333 Ok(quote!(#(#all_args),*))
334 }
335}
336
337
338impl Node for Argument {
340 fn lineno(&self) -> Option<usize> { self.lineno }
341 fn col_offset(&self) -> Option<usize> { self.col_offset }
342 fn end_lineno(&self) -> Option<usize> { self.end_lineno }
343 fn end_col_offset(&self) -> Option<usize> { self.end_col_offset }
344}
345
346impl Node for Parameter {
347 fn lineno(&self) -> Option<usize> { self.lineno }
348 fn col_offset(&self) -> Option<usize> { self.col_offset }
349 fn end_lineno(&self) -> Option<usize> { self.end_lineno }
350 fn end_col_offset(&self) -> Option<usize> { self.end_col_offset }
351}
352
353
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{parse, CodeGenContext, ExprType, PythonOptions, SymbolTableScopes};
    use test_log::test;

    /// Code-generation context shared by every test: a module named `test`.
    fn module_ctx() -> CodeGenContext {
        CodeGenContext::Module("test".to_string())
    }

    /// Parses `source` as `test.py` and transpiles it with default options and
    /// an empty symbol table, panicking on any parse or codegen failure.
    fn transpile(source: &str) -> impl std::fmt::Display {
        let ast = parse(source, "test.py").unwrap();
        ast.to_rust(module_ctx(), PythonOptions::default(), SymbolTableScopes::new())
            .unwrap()
    }

    #[test]
    fn test_simple_function_call() {
        let rust_code = transpile("func(1, 2, 3)");
        println!("Generated code: {}", rust_code);
    }

    #[test]
    fn test_keyword_arguments() {
        let rust_code = transpile("func(a=1, b=2)");
        println!("Generated code: {}", rust_code);
    }

    #[test]
    fn test_mixed_arguments() {
        let rust_code = transpile("func(1, 2, c=3, d=4)");
        println!("Generated code: {}", rust_code);
    }

    #[test]
    fn test_function_with_defaults() {
        let rust_code = transpile(
            r#"
def func(a, b=2, c=3):
    pass
    "#,
        );
        println!("Generated function: {}", rust_code);
    }

    #[test]
    fn test_function_with_varargs() {
        let rust_code = transpile(
            r#"
def func(a, *args):
    pass
    "#,
        );
        println!("Generated function: {}", rust_code);
    }

    #[test]
    fn test_function_with_kwargs() {
        let rust_code = transpile(
            r#"
def func(a, **kwargs):
    pass
    "#,
        );
        println!("Generated function: {}", rust_code);
    }

    #[test]
    fn test_complex_function_signature() {
        let rust_code = transpile(
            r#"
def func(a, b=2, *args, c, d=4, **kwargs):
    pass
    "#,
        );
        println!("Generated function: {}", rust_code);
    }

    #[test]
    fn test_keyword_only_arguments() {
        let rust_code = transpile(
            r#"
def func(a, *, b, c=3):
    pass
    "#,
        );
        println!("Generated function: {}", rust_code);
    }

    #[test]
    fn test_argument_unpacking_call() {
        // Either stage may legitimately fail for this not-yet-supported syntax;
        // the test only reports which one did.
        match parse("func(*args, **kwargs)", "test.py") {
            Err(e) => println!("Parse error (expected for unimplemented features): {}", e),
            Ok(ast) => match ast.to_rust(
                module_ctx(),
                PythonOptions::default(),
                SymbolTableScopes::new(),
            ) {
                Ok(code) => println!("Generated code: {}", code),
                Err(e) => println!("Expected error for unimplemented feature: {}", e),
            },
        }
    }

    #[test]
    fn test_arg_with_constant() {
        use litrs::Literal;

        let literal = Literal::parse("42").unwrap().into_owned();
        let arg: Arg = ExprType::Constant(crate::Constant(Some(literal)));

        let rust_code = arg
            .to_rust(module_ctx(), PythonOptions::default(), SymbolTableScopes::new())
            .unwrap();

        println!("Constant arg code: {}", rust_code);
        assert!(rust_code.to_string().contains("42"));
    }

    #[test]
    fn test_arg_with_name() {
        let arg: Arg = ExprType::Name(crate::Name {
            id: "variable".to_string(),
        });

        let rust_code = arg
            .to_rust(module_ctx(), PythonOptions::default(), SymbolTableScopes::new())
            .unwrap();

        println!("Name arg code: {}", rust_code);
        assert!(rust_code.to_string().contains("variable"));
    }
}