python_ast/ast/tree/
arguments.rs

1//! The module defines Python-syntax arguments and maps them into Rust-syntax versions.
2use proc_macro2::TokenStream;
3use pyo3::{Bound, FromPyObject, PyAny, PyResult, prelude::PyAnyMethods};
4use quote::quote;
5use serde::{Deserialize, Serialize};
6
7use crate::{
8    CodeGen, CodeGenContext, ExprType, Node, PythonOptions, SymbolTableScopes,
9};
10
11/// A complete argument representation that can hold any Python expression.
12/// This replaces the limited Arg enum to support all argument types.
13#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
14pub struct Argument {
15    /// The argument expression (can be any valid Python expression)
16    pub value: ExprType,
17    /// Position information
18    pub lineno: Option<usize>,
19    pub col_offset: Option<usize>,
20    pub end_lineno: Option<usize>,
21    pub end_col_offset: Option<usize>,
22}
23
24/// An argument value that can be any expression.
25/// This replaces the old limited Arg enum.
26pub type Arg = ExprType;
27
28/// A function parameter definition with optional type annotation and default value.
29#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
30pub struct Parameter {
31    /// Parameter name
32    pub arg: String,
33    /// Optional type annotation
34    pub annotation: Option<Box<ExprType>>,
35    /// Optional type comment (deprecated Python feature)
36    pub type_comment: Option<String>,
37    /// Position information
38    pub lineno: Option<usize>,
39    pub col_offset: Option<usize>,
40    pub end_lineno: Option<usize>,
41    pub end_col_offset: Option<usize>,
42}
43
44/// Comprehensive function arguments structure supporting all Python argument types.
45#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
46pub struct Arguments {
47    /// Positional-only parameters (before / in Python 3.8+)
48    pub posonlyargs: Vec<Parameter>,
49    /// Regular positional parameters
50    pub args: Vec<Parameter>,
51    /// Variable positional parameter (*args)
52    pub vararg: Option<Parameter>,
53    /// Keyword-only parameters (after * or *args)
54    pub kwonlyargs: Vec<Parameter>,
55    /// Default values for keyword-only parameters (None = required)
56    pub kw_defaults: Vec<Option<Box<ExprType>>>,
57    /// Variable keyword parameter (**kwargs)
58    pub kwarg: Option<Parameter>,
59    /// Default values for regular positional parameters
60    pub defaults: Vec<Box<ExprType>>,
61}
62
63
64/// Function call arguments supporting all Python call patterns.
65#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
66pub struct CallArguments {
67    /// Positional arguments
68    pub args: Vec<ExprType>,
69    /// Keyword arguments
70    pub keywords: Vec<crate::Keyword>,
71}
72
73// Implementation for new Argument struct
74impl<'a> FromPyObject<'a> for Argument {
75    fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
76        // Extract the expression value
77        let value: ExprType = ob.extract()?;
78        
79        Ok(Self {
80            value,
81            lineno: ob.lineno(),
82            col_offset: ob.col_offset(),
83            end_lineno: ob.end_lineno(),
84            end_col_offset: ob.end_col_offset(),
85        })
86    }
87}
88
89impl CodeGen for Argument {
90    type Context = CodeGenContext;
91    type Options = PythonOptions;
92    type SymbolTable = SymbolTableScopes;
93
94    fn to_rust(
95        self,
96        ctx: Self::Context,
97        options: Self::Options,
98        symbols: Self::SymbolTable,
99    ) -> std::result::Result<TokenStream, Box<dyn std::error::Error>> {
100        self.value.to_rust(ctx, options, symbols)
101    }
102}
103
104// Implementation for Parameter struct
105impl<'a> FromPyObject<'a> for Parameter {
106    fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
107        let arg: String = ob.getattr("arg")?.extract()?;
108        
109        // Extract optional annotation
110        let annotation = if let Ok(ann) = ob.getattr("annotation") {
111            if ann.is_none() {
112                None
113            } else {
114                Some(Box::new(ann.extract()?))
115            }
116        } else {
117            None
118        };
119        
120        // Extract optional type comment
121        let type_comment = if let Ok(tc) = ob.getattr("type_comment") {
122            if tc.is_none() {
123                None
124            } else {
125                Some(tc.extract()?)
126            }
127        } else {
128            None
129        };
130        
131        Ok(Self {
132            arg,
133            annotation,
134            type_comment,
135            lineno: ob.lineno(),
136            col_offset: ob.col_offset(),
137            end_lineno: ob.end_lineno(),
138            end_col_offset: ob.end_col_offset(),
139        })
140    }
141}
142
143impl CodeGen for Parameter {
144    type Context = CodeGenContext;
145    type Options = PythonOptions;
146    type SymbolTable = SymbolTableScopes;
147
148    fn to_rust(
149        self,
150        ctx: Self::Context,
151        options: Self::Options,
152        symbols: Self::SymbolTable,
153    ) -> std::result::Result<TokenStream, Box<dyn std::error::Error>> {
154        use quote::format_ident;
155        
156        let param_name = format_ident!("{}", self.arg);
157        
158        // Generate type annotation if present
159        if let Some(annotation) = self.annotation {
160            let rust_type = annotation.to_rust(ctx, options, symbols)?;
161            Ok(quote!(#param_name: #rust_type))
162        } else {
163            // Default to generic type for untyped parameters
164            Ok(quote!(#param_name: impl Into<PyObject>))
165        }
166    }
167}
168
169// Implementation for Arguments struct
170impl<'a> FromPyObject<'a> for Arguments {
171    fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
172        // Extract each field with proper error handling
173        let posonlyargs: Vec<Parameter> = ob.getattr("posonlyargs")?.extract().unwrap_or_default();
174        let args: Vec<Parameter> = ob.getattr("args")?.extract().unwrap_or_default();
175        
176        let vararg = if let Ok(va) = ob.getattr("vararg") {
177            if va.is_none() { None } else { Some(va.extract()?) }
178        } else { None };
179        
180        let kwonlyargs: Vec<Parameter> = ob.getattr("kwonlyargs")?.extract().unwrap_or_default();
181        
182        // Handle kw_defaults which can contain None values
183        let kw_defaults = if let Ok(kw_def) = ob.getattr("kw_defaults") {
184            let defaults_list: Vec<Bound<PyAny>> = kw_def.extract().unwrap_or_default();
185            let mut processed_defaults = Vec::new();
186            for default in defaults_list {
187                if default.is_none() {
188                    processed_defaults.push(None);
189                } else {
190                    processed_defaults.push(Some(Box::new(default.extract()?)));
191                }
192            }
193            processed_defaults
194        } else {
195            Vec::new()
196        };
197        
198        let kwarg = if let Ok(kw) = ob.getattr("kwarg") {
199            if kw.is_none() { None } else { Some(kw.extract()?) }
200        } else { None };
201        
202        let defaults_raw: Vec<ExprType> = ob.getattr("defaults")?.extract().unwrap_or_default();
203        let defaults = defaults_raw.into_iter().map(Box::new).collect();
204        
205        Ok(Self {
206            posonlyargs,
207            args,
208            vararg,
209            kwonlyargs,
210            kw_defaults,
211            kwarg,
212            defaults,
213        })
214    }
215}
216
217impl CodeGen for Arguments {
218    type Context = CodeGenContext;
219    type Options = PythonOptions;
220    type SymbolTable = SymbolTableScopes;
221
222    fn to_rust(
223        self,
224        ctx: Self::Context,
225        options: Self::Options,
226        symbols: Self::SymbolTable,
227    ) -> std::result::Result<TokenStream, Box<dyn std::error::Error>> {
228        let mut params = Vec::new();
229        
230        // Process positional-only arguments
231        for arg in self.posonlyargs {
232            let param = arg.to_rust(ctx.clone(), options.clone(), symbols.clone())?;
233            params.push(param);
234        }
235        
236        // Process regular positional arguments with defaults
237        let defaults_offset = self.args.len().saturating_sub(self.defaults.len());
238        for (i, arg) in self.args.into_iter().enumerate() {
239            if i >= defaults_offset {
240                // This argument has a default value
241                let default_idx = i - defaults_offset;
242                let default_value = &self.defaults[default_idx];
243                let _default_rust = default_value.as_ref().clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
244                let param_name = quote::format_ident!("{}", arg.arg);
245                
246                if let Some(annotation) = &arg.annotation {
247                    let rust_type = annotation.as_ref().clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
248                    params.push(quote!(#param_name: Option<#rust_type>));
249                } else {
250                    params.push(quote!(#param_name: Option<impl Into<PyObject>>));
251                }
252            } else {
253                let param = arg.to_rust(ctx.clone(), options.clone(), symbols.clone())?;
254                params.push(param);
255            }
256        }
257        
258        // Process *args
259        if let Some(vararg) = self.vararg {
260            let vararg_name = quote::format_ident!("{}", vararg.arg);
261            params.push(quote!(#vararg_name: impl IntoIterator<Item = impl Into<PyObject>>));
262        }
263        
264        // Process keyword-only arguments
265        for (i, arg) in self.kwonlyargs.into_iter().enumerate() {
266            let param_name = quote::format_ident!("{}", arg.arg);
267            
268            // Check if this keyword-only arg has a default
269            let has_default = i < self.kw_defaults.len() && self.kw_defaults[i].is_some();
270            
271            if let Some(annotation) = &arg.annotation {
272                let rust_type = annotation.as_ref().clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
273                if has_default {
274                    params.push(quote!(#param_name: Option<#rust_type>));
275                } else {
276                    params.push(quote!(#param_name: #rust_type));
277                }
278            } else {
279                if has_default {
280                    params.push(quote!(#param_name: Option<impl Into<PyObject>>));
281                } else {
282                    params.push(quote!(#param_name: impl Into<PyObject>));
283                }
284            }
285        }
286        
287        // Process **kwargs
288        if let Some(kwarg) = self.kwarg {
289            let kwarg_name = quote::format_ident!("{}", kwarg.arg);
290            params.push(quote!(#kwarg_name: impl IntoIterator<Item = (impl AsRef<str>, impl Into<PyObject>)>));
291        }
292        
293        Ok(quote!(#(#params),*))
294    }
295}
296
297
298// Implementation for CallArguments
299impl<'a> FromPyObject<'a> for CallArguments {
300    fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
301        let args: Vec<ExprType> = ob.getattr("args")?.extract().unwrap_or_default();
302        let keywords: Vec<crate::Keyword> = ob.getattr("keywords")?.extract().unwrap_or_default();
303        
304        Ok(Self { args, keywords })
305    }
306}
307
308impl CodeGen for CallArguments {
309    type Context = CodeGenContext;
310    type Options = PythonOptions;
311    type SymbolTable = SymbolTableScopes;
312
313    fn to_rust(
314        self,
315        ctx: Self::Context,
316        options: Self::Options,
317        symbols: Self::SymbolTable,
318    ) -> std::result::Result<TokenStream, Box<dyn std::error::Error>> {
319        let mut all_args = Vec::new();
320        
321        // Add positional arguments
322        for arg in self.args {
323            let rust_arg = arg.to_rust(ctx.clone(), options.clone(), symbols.clone())?;
324            all_args.push(rust_arg);
325        }
326        
327        // Add keyword arguments
328        for keyword in self.keywords {
329            let rust_kw = keyword.to_rust(ctx.clone(), options.clone(), symbols.clone())?;
330            all_args.push(rust_kw);
331        }
332        
333        Ok(quote!(#(#all_args),*))
334    }
335}
336
337
338// Node trait implementations for position tracking
339impl Node for Argument {
340    fn lineno(&self) -> Option<usize> { self.lineno }
341    fn col_offset(&self) -> Option<usize> { self.col_offset }
342    fn end_lineno(&self) -> Option<usize> { self.end_lineno }
343    fn end_col_offset(&self) -> Option<usize> { self.end_col_offset }
344}
345
346impl Node for Parameter {
347    fn lineno(&self) -> Option<usize> { self.lineno }
348    fn col_offset(&self) -> Option<usize> { self.col_offset }
349    fn end_lineno(&self) -> Option<usize> { self.end_lineno }
350    fn end_col_offset(&self) -> Option<usize> { self.end_col_offset }
351}
352
353
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{parse, CodeGenContext, ExprType, PythonOptions, SymbolTableScopes};
    use test_log::test;

    /// Parses `code` as `test.py` and generates Rust for it in a module context
    /// named `test`, panicking if either step fails.
    fn generate(code: &str) -> proc_macro2::TokenStream {
        let module = parse(code, "test.py").unwrap();
        module
            .to_rust(
                CodeGenContext::Module("test".to_string()),
                PythonOptions::default(),
                SymbolTableScopes::new(),
            )
            .unwrap()
    }

    /// Generates Rust for a single argument expression in a module context named `test`.
    fn generate_arg(arg: Arg) -> proc_macro2::TokenStream {
        arg.to_rust(
            CodeGenContext::Module("test".to_string()),
            PythonOptions::default(),
            SymbolTableScopes::new(),
        )
        .unwrap()
    }

    #[test]
    fn test_simple_function_call() {
        // Call with positional arguments only.
        let _ = generate("func(1, 2, 3)");
    }

    #[test]
    fn test_keyword_arguments() {
        // Call with keyword arguments only.
        let _ = generate("func(a=1, b=2)");
    }

    #[test]
    fn test_mixed_arguments() {
        // Call mixing positional and keyword arguments.
        let _ = generate("func(1, 2, c=3, d=4)");
    }

    #[test]
    fn test_function_with_defaults() {
        // Definition whose trailing parameters have defaults.
        let code = r#"
def func(a, b=2, c=3):
    pass
        "#;
        let _ = generate(code);
    }

    #[test]
    fn test_function_with_varargs() {
        // Definition with a `*args` parameter.
        let code = r#"
def func(a, *args):
    pass
        "#;
        let _ = generate(code);
    }

    #[test]
    fn test_function_with_kwargs() {
        // Definition with a `**kwargs` parameter.
        let code = r#"
def func(a, **kwargs):
    pass
        "#;
        let _ = generate(code);
    }

    #[test]
    fn test_complex_function_signature() {
        // Definition using every parameter kind at once.
        let code = r#"
def func(a, b=2, *args, c, d=4, **kwargs):
    pass
        "#;
        let _ = generate(code);
    }

    #[test]
    fn test_keyword_only_arguments() {
        // Definition with a bare `*` separating keyword-only parameters.
        let code = r#"
def func(a, *, b, c=3):
    pass
        "#;
        let _ = generate(code);
    }

    #[test]
    fn test_argument_unpacking_call() {
        // Starred call arguments may not be supported yet, so parse and
        // code-generation failures are both tolerated; the test only requires
        // that neither step panics.
        if let Ok(ast) = parse("func(*args, **kwargs)", "test.py") {
            let _ = ast.to_rust(
                CodeGenContext::Module("test".to_string()),
                PythonOptions::default(),
                SymbolTableScopes::new(),
            );
        }
    }

    #[test]
    fn test_arg_with_constant() {
        // `Arg` is an alias for `ExprType`, so constants are valid arguments.
        use litrs::Literal;
        let literal = Literal::parse("42").unwrap().into_owned();
        let arg: Arg = ExprType::Constant(crate::Constant(Some(literal)));

        assert!(generate_arg(arg).to_string().contains("42"));
    }

    #[test]
    fn test_arg_with_name() {
        // `Arg` is an alias for `ExprType`, so names are valid arguments.
        let arg: Arg = ExprType::Name(crate::Name {
            id: "variable".to_string(),
        });

        assert!(generate_arg(arg).to_string().contains("variable"));
    }
}