python_ast/ast/tree/
arguments.rs

1//! The module defines Python-syntax arguments and maps them into Rust-syntax versions.
2use proc_macro2::TokenStream;
3use pyo3::{Bound, FromPyObject, PyAny, PyResult, prelude::PyAnyMethods};
4use quote::quote;
5use serde::{Deserialize, Serialize};
6
7use crate::{
8    CodeGen, CodeGenContext, ExprType, Node, PythonOptions, SymbolTableScopes,
9};
10
11/// A complete argument representation that can hold any Python expression.
12/// This replaces the limited Arg enum to support all argument types.
13#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
14pub struct Argument {
15    /// The argument expression (can be any valid Python expression)
16    pub value: ExprType,
17    /// Position information
18    pub lineno: Option<usize>,
19    pub col_offset: Option<usize>,
20    pub end_lineno: Option<usize>,
21    pub end_col_offset: Option<usize>,
22}
23
24/// An argument value that can be any expression.
25/// This replaces the old limited Arg enum.
26pub type Arg = ExprType;
27
28/// A function parameter definition with optional type annotation and default value.
29#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
30pub struct Parameter {
31    /// Parameter name
32    pub arg: String,
33    /// Optional type annotation
34    pub annotation: Option<Box<ExprType>>,
35    /// Optional type comment (deprecated Python feature)
36    pub type_comment: Option<String>,
37    /// Position information
38    pub lineno: Option<usize>,
39    pub col_offset: Option<usize>,
40    pub end_lineno: Option<usize>,
41    pub end_col_offset: Option<usize>,
42}
43
44/// Comprehensive function arguments structure supporting all Python argument types.
45#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
46pub struct Arguments {
47    /// Positional-only parameters (before / in Python 3.8+)
48    pub posonlyargs: Vec<Parameter>,
49    /// Regular positional parameters
50    pub args: Vec<Parameter>,
51    /// Variable positional parameter (*args)
52    pub vararg: Option<Parameter>,
53    /// Keyword-only parameters (after * or *args)
54    pub kwonlyargs: Vec<Parameter>,
55    /// Default values for keyword-only parameters (None = required)
56    pub kw_defaults: Vec<Option<Box<ExprType>>>,
57    /// Variable keyword parameter (**kwargs)
58    pub kwarg: Option<Parameter>,
59    /// Default values for regular positional parameters
60    pub defaults: Vec<Box<ExprType>>,
61}
62
63
64/// Function call arguments supporting all Python call patterns.
65#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
66pub struct CallArguments {
67    /// Positional arguments
68    pub args: Vec<ExprType>,
69    /// Keyword arguments
70    pub keywords: Vec<crate::Keyword>,
71}
72
73// Implementation for new Argument struct
74impl<'a> FromPyObject<'a> for Argument {
75    fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
76        // Extract the expression value
77        let value: ExprType = ob.extract()?;
78        
79        Ok(Self {
80            value,
81            lineno: ob.lineno(),
82            col_offset: ob.col_offset(),
83            end_lineno: ob.end_lineno(),
84            end_col_offset: ob.end_col_offset(),
85        })
86    }
87}
88
89impl CodeGen for Argument {
90    type Context = CodeGenContext;
91    type Options = PythonOptions;
92    type SymbolTable = SymbolTableScopes;
93
94    fn to_rust(
95        self,
96        ctx: Self::Context,
97        options: Self::Options,
98        symbols: Self::SymbolTable,
99    ) -> std::result::Result<TokenStream, Box<dyn std::error::Error>> {
100        self.value.to_rust(ctx, options, symbols)
101    }
102}
103
104// Implementation for Parameter struct
105impl<'a> FromPyObject<'a> for Parameter {
106    fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
107        let arg: String = ob.getattr("arg")?.extract()?;
108        
109        // Extract optional annotation
110        let annotation = if let Ok(ann) = ob.getattr("annotation") {
111            if ann.is_none() {
112                None
113            } else {
114                Some(Box::new(ann.extract()?))
115            }
116        } else {
117            None
118        };
119        
120        // Extract optional type comment
121        let type_comment = if let Ok(tc) = ob.getattr("type_comment") {
122            if tc.is_none() {
123                None
124            } else {
125                Some(tc.extract()?)
126            }
127        } else {
128            None
129        };
130        
131        Ok(Self {
132            arg,
133            annotation,
134            type_comment,
135            lineno: ob.lineno(),
136            col_offset: ob.col_offset(),
137            end_lineno: ob.end_lineno(),
138            end_col_offset: ob.end_col_offset(),
139        })
140    }
141}
142
143impl CodeGen for Parameter {
144    type Context = CodeGenContext;
145    type Options = PythonOptions;
146    type SymbolTable = SymbolTableScopes;
147
148    fn to_rust(
149        self,
150        ctx: Self::Context,
151        options: Self::Options,
152        symbols: Self::SymbolTable,
153    ) -> std::result::Result<TokenStream, Box<dyn std::error::Error>> {
154        use quote::format_ident;
155        
156        let param_name = format_ident!("{}", self.arg);
157        
158        // Generate type annotation if present
159        if let Some(annotation) = self.annotation {
160            let rust_type = annotation.to_rust(ctx, options, symbols)?;
161            Ok(quote!(#param_name: #rust_type))
162        } else {
163            // Default to generic type for untyped parameters
164            Ok(quote!(#param_name: impl Into<PyObject>))
165        }
166    }
167}
168
169// Implementation for Arguments struct
170impl<'a> FromPyObject<'a> for Arguments {
171    fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
172        // Extract each field with proper error handling
173        let posonlyargs: Vec<Parameter> = ob.getattr("posonlyargs")?.extract().unwrap_or_default();
174        let args: Vec<Parameter> = ob.getattr("args")?.extract().unwrap_or_default();
175        
176        let vararg = if let Ok(va) = ob.getattr("vararg") {
177            if va.is_none() { None } else { Some(va.extract()?) }
178        } else { None };
179        
180        let kwonlyargs: Vec<Parameter> = ob.getattr("kwonlyargs")?.extract().unwrap_or_default();
181        
182        // Handle kw_defaults which can contain None values
183        let kw_defaults = if let Ok(kw_def) = ob.getattr("kw_defaults") {
184            let defaults_list: Vec<Bound<PyAny>> = kw_def.extract().unwrap_or_default();
185            let mut processed_defaults = Vec::new();
186            for default in defaults_list {
187                if default.is_none() {
188                    processed_defaults.push(None);
189                } else {
190                    processed_defaults.push(Some(Box::new(default.extract()?)));
191                }
192            }
193            processed_defaults
194        } else {
195            Vec::new()
196        };
197        
198        let kwarg = if let Ok(kw) = ob.getattr("kwarg") {
199            if kw.is_none() { None } else { Some(kw.extract()?) }
200        } else { None };
201        
202        let defaults_raw: Vec<ExprType> = ob.getattr("defaults")?.extract().unwrap_or_default();
203        let defaults = defaults_raw.into_iter().map(Box::new).collect();
204        
205        Ok(Self {
206            posonlyargs,
207            args,
208            vararg,
209            kwonlyargs,
210            kw_defaults,
211            kwarg,
212            defaults,
213        })
214    }
215}
216
217impl CodeGen for Arguments {
218    type Context = CodeGenContext;
219    type Options = PythonOptions;
220    type SymbolTable = SymbolTableScopes;
221
222    fn to_rust(
223        self,
224        ctx: Self::Context,
225        options: Self::Options,
226        symbols: Self::SymbolTable,
227    ) -> std::result::Result<TokenStream, Box<dyn std::error::Error>> {
228        let mut params = Vec::new();
229        
230        // Process positional-only arguments
231        for arg in self.posonlyargs {
232            let param = arg.to_rust(ctx.clone(), options.clone(), symbols.clone())?;
233            params.push(param);
234        }
235        
236        // Process regular positional arguments with defaults
237        let defaults_offset = self.args.len().saturating_sub(self.defaults.len());
238        for (i, arg) in self.args.into_iter().enumerate() {
239            if i >= defaults_offset {
240                // This argument has a default value
241                let default_idx = i - defaults_offset;
242                let default_value = &self.defaults[default_idx];
243                let _default_rust = default_value.as_ref().clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
244                let param_name = quote::format_ident!("{}", arg.arg);
245                
246                if let Some(annotation) = &arg.annotation {
247                    let rust_type = annotation.as_ref().clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
248                    params.push(quote!(#param_name: Option<#rust_type>));
249                } else {
250                    params.push(quote!(#param_name: Option<impl Into<PyObject>>));
251                }
252            } else {
253                let param = arg.to_rust(ctx.clone(), options.clone(), symbols.clone())?;
254                params.push(param);
255            }
256        }
257        
258        // Process *args
259        if let Some(vararg) = self.vararg {
260            let vararg_name = quote::format_ident!("{}", vararg.arg);
261            params.push(quote!(#vararg_name: impl IntoIterator<Item = impl Into<PyObject>>));
262        }
263        
264        // Process keyword-only arguments
265        for (i, arg) in self.kwonlyargs.into_iter().enumerate() {
266            let param_name = quote::format_ident!("{}", arg.arg);
267            
268            // Check if this keyword-only arg has a default
269            let has_default = i < self.kw_defaults.len() && self.kw_defaults[i].is_some();
270            
271            if let Some(annotation) = &arg.annotation {
272                let rust_type = annotation.as_ref().clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
273                if has_default {
274                    params.push(quote!(#param_name: Option<#rust_type>));
275                } else {
276                    params.push(quote!(#param_name: #rust_type));
277                }
278            } else {
279                if has_default {
280                    params.push(quote!(#param_name: Option<impl Into<PyObject>>));
281                } else {
282                    params.push(quote!(#param_name: impl Into<PyObject>));
283                }
284            }
285        }
286        
287        // Process **kwargs
288        if let Some(kwarg) = self.kwarg {
289            let kwarg_name = quote::format_ident!("{}", kwarg.arg);
290            params.push(quote!(#kwarg_name: impl IntoIterator<Item = (impl AsRef<str>, impl Into<PyObject>)>));
291        }
292        
293        Ok(quote!(#(#params),*))
294    }
295}
296
297
298// Implementation for CallArguments
299impl<'a> FromPyObject<'a> for CallArguments {
300    fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
301        let args: Vec<ExprType> = ob.getattr("args")?.extract().unwrap_or_default();
302        let keywords: Vec<crate::Keyword> = ob.getattr("keywords")?.extract().unwrap_or_default();
303        
304        Ok(Self { args, keywords })
305    }
306}
307
308impl CodeGen for CallArguments {
309    type Context = CodeGenContext;
310    type Options = PythonOptions;
311    type SymbolTable = SymbolTableScopes;
312
313    fn to_rust(
314        self,
315        ctx: Self::Context,
316        options: Self::Options,
317        symbols: Self::SymbolTable,
318    ) -> std::result::Result<TokenStream, Box<dyn std::error::Error>> {
319        let mut all_args = Vec::new();
320        
321        // Add positional arguments
322        for arg in self.args {
323            let rust_arg = arg.to_rust(ctx.clone(), options.clone(), symbols.clone())?;
324            all_args.push(rust_arg);
325        }
326        
327        // Add keyword arguments
328        for keyword in self.keywords {
329            let rust_kw = keyword.to_rust(ctx.clone(), options.clone(), symbols.clone())?;
330            all_args.push(rust_kw);
331        }
332        
333        Ok(quote!(#(#all_args),*))
334    }
335}
336
337
338// Node trait implementations for position tracking
339impl Node for Argument {
340    fn lineno(&self) -> Option<usize> { self.lineno }
341    fn col_offset(&self) -> Option<usize> { self.col_offset }
342    fn end_lineno(&self) -> Option<usize> { self.end_lineno }
343    fn end_col_offset(&self) -> Option<usize> { self.end_col_offset }
344}
345
346impl Node for Parameter {
347    fn lineno(&self) -> Option<usize> { self.lineno }
348    fn col_offset(&self) -> Option<usize> { self.col_offset }
349    fn end_lineno(&self) -> Option<usize> { self.end_lineno }
350    fn end_col_offset(&self) -> Option<usize> { self.end_col_offset }
351}
352
353
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{parse, CodeGenContext, ExprType, PythonOptions, SymbolTableScopes};
    use test_log::test;

    /// Parses `source` as `test.py` and renders the result to Rust text,
    /// panicking if either step fails.
    fn translate(source: &str) -> String {
        let module = parse(source, "test.py").unwrap();
        module
            .to_rust(
                CodeGenContext::Module("test".to_string()),
                PythonOptions::default(),
                SymbolTableScopes::new(),
            )
            .unwrap()
            .to_string()
    }

    /// Renders a single argument expression to Rust text, panicking on failure.
    fn render_arg(arg: Arg) -> String {
        arg.to_rust(
            CodeGenContext::Module("test".to_string()),
            PythonOptions::default(),
            SymbolTableScopes::new(),
        )
        .unwrap()
        .to_string()
    }

    // Call with positional arguments only.
    #[test]
    fn test_simple_function_call() {
        let rust_code = translate("func(1, 2, 3)");
        println!("Generated code: {}", rust_code);
    }

    // Call with keyword arguments only.
    #[test]
    fn test_keyword_arguments() {
        let rust_code = translate("func(a=1, b=2)");
        println!("Generated code: {}", rust_code);
    }

    // Call mixing positional and keyword arguments.
    #[test]
    fn test_mixed_arguments() {
        let rust_code = translate("func(1, 2, c=3, d=4)");
        println!("Generated code: {}", rust_code);
    }

    // Definition whose trailing positional parameters have defaults.
    #[test]
    fn test_function_with_defaults() {
        let code = r#"
def func(a, b=2, c=3):
    pass
        "#;
        println!("Generated function: {}", translate(code));
    }

    // Definition with a `*args` parameter.
    #[test]
    fn test_function_with_varargs() {
        let code = r#"
def func(a, *args):
    pass
        "#;
        println!("Generated function: {}", translate(code));
    }

    // Definition with a `**kwargs` parameter.
    #[test]
    fn test_function_with_kwargs() {
        let code = r#"
def func(a, **kwargs):
    pass
        "#;
        println!("Generated function: {}", translate(code));
    }

    // Definition combining every parameter kind.
    #[test]
    fn test_complex_function_signature() {
        let code = r#"
def func(a, b=2, *args, c, d=4, **kwargs):
    pass
        "#;
        println!("Generated function: {}", translate(code));
    }

    // Definition with keyword-only parameters after a bare `*`.
    #[test]
    fn test_keyword_only_arguments() {
        let code = r#"
def func(a, *, b, c=3):
    pass
        "#;
        println!("Generated function: {}", translate(code));
    }

    // Unpacking at the call site may not be supported yet (it would need
    // Starred expression support), so parse and codegen failures are only
    // reported, never treated as test failures.
    #[test]
    fn test_argument_unpacking_call() {
        let ast = match parse("func(*args, **kwargs)", "test.py") {
            Ok(ast) => ast,
            Err(e) => {
                println!("Parse error (expected for unimplemented features): {}", e);
                return;
            }
        };

        let outcome = ast.to_rust(
            CodeGenContext::Module("test".to_string()),
            PythonOptions::default(),
            SymbolTableScopes::new(),
        );

        match outcome {
            Ok(code) => println!("Generated code: {}", code),
            Err(e) => println!("Expected error for unimplemented feature: {}", e),
        }
    }

    // `Arg` (now an alias for `ExprType`) accepts constants.
    #[test]
    fn test_arg_with_constant() {
        use litrs::Literal;

        let literal = Literal::parse("42").unwrap().into_owned();
        let arg: Arg = ExprType::Constant(crate::Constant(Some(literal)));

        let rust_code = render_arg(arg);
        println!("Constant arg code: {}", rust_code);
        assert!(rust_code.contains("42"));
    }

    // `Arg` (now an alias for `ExprType`) accepts name expressions.
    #[test]
    fn test_arg_with_name() {
        let arg: Arg = ExprType::Name(crate::Name {
            id: "variable".to_string(),
        });

        let rust_code = render_arg(arg);
        println!("Name arg code: {}", rust_code);
        assert!(rust_code.contains("variable"));
    }
}