// python_ast/ast/tree/list_comp.rs

1use proc_macro2::TokenStream;
2use pyo3::{Bound, FromPyObject, PyAny, PyResult, prelude::PyAnyMethods};
3use quote::quote;
4use serde::{Deserialize, Serialize};
5
6use crate::{
7    CodeGen, CodeGenContext, ExprType, Node, PythonOptions, SymbolTableScopes,
8    PyAttributeExtractor, extract_list,
9};
10
/// List comprehension (e.g., `[x ** 2 for x in range(10) if x % 2 == 0]`).
///
/// Extracted from a Python AST object exposing `elt` and `generators` attributes.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct ListComp {
    /// The element expression evaluated for each iteration.
    pub elt: Box<ExprType>,
    /// The `for ... in ... if ...` clauses; in Python the first clause is the outermost loop.
    pub generators: Vec<Comprehension>,
    /// Position information: start line, if known.
    pub lineno: Option<usize>,
    /// Position information: start column, if known.
    pub col_offset: Option<usize>,
    /// Position information: end line, if known.
    pub end_lineno: Option<usize>,
    /// Position information: end column, if known.
    pub end_col_offset: Option<usize>,
}
24
/// Set comprehension (e.g., `{x for x in range(10) if x % 2 == 0}`).
///
/// Extracted from a Python AST object exposing `elt` and `generators` attributes.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct SetComp {
    /// The element expression evaluated for each iteration.
    pub elt: Box<ExprType>,
    /// The `for ... in ... if ...` clauses; in Python the first clause is the outermost loop.
    pub generators: Vec<Comprehension>,
    /// Position information: start line, if known.
    pub lineno: Option<usize>,
    /// Position information: start column, if known.
    pub col_offset: Option<usize>,
    /// Position information: end line, if known.
    pub end_lineno: Option<usize>,
    /// Position information: end column, if known.
    pub end_col_offset: Option<usize>,
}
38
/// Generator expression (e.g., `(x for x in range(10) if x % 2 == 0)`).
///
/// Extracted from a Python AST object exposing `elt` and `generators` attributes.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct GeneratorExp {
    /// The element expression evaluated for each iteration.
    pub elt: Box<ExprType>,
    /// The `for ... in ... if ...` clauses; in Python the first clause is the outermost loop.
    pub generators: Vec<Comprehension>,
    /// Position information: start line, if known.
    pub lineno: Option<usize>,
    /// Position information: start column, if known.
    pub col_offset: Option<usize>,
    /// Position information: end line, if known.
    pub end_lineno: Option<usize>,
    /// Position information: end column, if known.
    pub end_col_offset: Option<usize>,
}
52
/// Dictionary comprehension (e.g., `{k: v for k, v in items.items()}`).
///
/// Extracted from a Python AST object exposing `key`, `value` and `generators` attributes.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct DictComp {
    /// The key expression evaluated for each iteration.
    pub key: Box<ExprType>,
    /// The value expression evaluated for each iteration.
    pub value: Box<ExprType>,
    /// The `for ... in ... if ...` clauses; in Python the first clause is the outermost loop.
    pub generators: Vec<Comprehension>,
    /// Position information: start line, if known.
    pub lineno: Option<usize>,
    /// Position information: start column, if known.
    pub col_offset: Option<usize>,
    /// Position information: end line, if known.
    pub end_lineno: Option<usize>,
    /// Position information: end column, if known.
    pub end_col_offset: Option<usize>,
}
68
/// A single comprehension clause: `for target in iter if cond1 if cond2 ...`.
///
/// Shared by list, set, dict comprehensions and generator expressions.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Comprehension {
    /// The loop target (e.g., `x` in `for x in range(10)`); may be a tuple such as `k, v`.
    pub target: ExprType,
    /// The iterable expression (e.g., `range(10)` in `for x in range(10)`).
    pub iter: ExprType,
    /// The `if` conditions attached to this clause; all must hold for an item to be kept.
    pub ifs: Vec<ExprType>,
    /// Whether this clause is an `async for`.
    pub is_async: bool,
}
81
82impl<'a> FromPyObject<'a> for ListComp {
83    fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
84        // Extract the element expression
85        let elt = ob.extract_attr_with_context("elt", "list comprehension element")?;
86        let elt: ExprType = elt.extract()?;
87        
88        // Extract generators
89        let generators: Vec<Comprehension> = extract_list(ob, "generators", "list comprehension generators")?;
90        
91        Ok(ListComp {
92            elt: Box::new(elt),
93            generators,
94            lineno: ob.lineno(),
95            col_offset: ob.col_offset(),
96            end_lineno: ob.end_lineno(),
97            end_col_offset: ob.end_col_offset(),
98        })
99    }
100}
101
102impl<'a> FromPyObject<'a> for SetComp {
103    fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
104        // Extract the element expression
105        let elt = ob.extract_attr_with_context("elt", "set comprehension element")?;
106        let elt: ExprType = elt.extract()?;
107        
108        // Extract generators
109        let generators: Vec<Comprehension> = extract_list(ob, "generators", "set comprehension generators")?;
110        
111        Ok(SetComp {
112            elt: Box::new(elt),
113            generators,
114            lineno: ob.lineno(),
115            col_offset: ob.col_offset(),
116            end_lineno: ob.end_lineno(),
117            end_col_offset: ob.end_col_offset(),
118        })
119    }
120}
121
122impl<'a> FromPyObject<'a> for GeneratorExp {
123    fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
124        // Extract the element expression
125        let elt = ob.extract_attr_with_context("elt", "generator expression element")?;
126        let elt: ExprType = elt.extract()?;
127        
128        // Extract generators
129        let generators: Vec<Comprehension> = extract_list(ob, "generators", "generator expression generators")?;
130        
131        Ok(GeneratorExp {
132            elt: Box::new(elt),
133            generators,
134            lineno: ob.lineno(),
135            col_offset: ob.col_offset(),
136            end_lineno: ob.end_lineno(),
137            end_col_offset: ob.end_col_offset(),
138        })
139    }
140}
141
142impl<'a> FromPyObject<'a> for DictComp {
143    fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
144        // Extract the key expression
145        let key = ob.extract_attr_with_context("key", "dict comprehension key")?;
146        let key: ExprType = key.extract()?;
147        
148        // Extract the value expression
149        let value = ob.extract_attr_with_context("value", "dict comprehension value")?;
150        let value: ExprType = value.extract()?;
151        
152        // Extract generators
153        let generators: Vec<Comprehension> = extract_list(ob, "generators", "dict comprehension generators")?;
154        
155        Ok(DictComp {
156            key: Box::new(key),
157            value: Box::new(value),
158            generators,
159            lineno: ob.lineno(),
160            col_offset: ob.col_offset(),
161            end_lineno: ob.end_lineno(),
162            end_col_offset: ob.end_col_offset(),
163        })
164    }
165}
166
167impl<'a> FromPyObject<'a> for Comprehension {
168    fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
169        // Extract target
170        let target = ob.extract_attr_with_context("target", "comprehension target")?;
171        let target: ExprType = target.extract()?;
172        
173        // Extract iter
174        let iter = ob.extract_attr_with_context("iter", "comprehension iter")?;
175        let iter: ExprType = iter.extract()?;
176        
177        // Extract ifs (list of conditions)
178        let ifs: Vec<ExprType> = extract_list(ob, "ifs", "comprehension conditions").unwrap_or_default();
179        
180        // Extract is_async
181        let is_async: bool = ob.getattr("is_async")?.extract().unwrap_or(false);
182        
183        Ok(Comprehension {
184            target,
185            iter,
186            ifs,
187            is_async,
188        })
189    }
190}
191
192impl Node for ListComp {
193    fn lineno(&self) -> Option<usize> { self.lineno }
194    fn col_offset(&self) -> Option<usize> { self.col_offset }
195    fn end_lineno(&self) -> Option<usize> { self.end_lineno }
196    fn end_col_offset(&self) -> Option<usize> { self.end_col_offset }
197}
198
199impl Node for SetComp {
200    fn lineno(&self) -> Option<usize> { self.lineno }
201    fn col_offset(&self) -> Option<usize> { self.col_offset }
202    fn end_lineno(&self) -> Option<usize> { self.end_lineno }
203    fn end_col_offset(&self) -> Option<usize> { self.end_col_offset }
204}
205
206impl Node for GeneratorExp {
207    fn lineno(&self) -> Option<usize> { self.lineno }
208    fn col_offset(&self) -> Option<usize> { self.col_offset }
209    fn end_lineno(&self) -> Option<usize> { self.end_lineno }
210    fn end_col_offset(&self) -> Option<usize> { self.end_col_offset }
211}
212
213impl Node for DictComp {
214    fn lineno(&self) -> Option<usize> { self.lineno }
215    fn col_offset(&self) -> Option<usize> { self.col_offset }
216    fn end_lineno(&self) -> Option<usize> { self.end_lineno }
217    fn end_col_offset(&self) -> Option<usize> { self.end_col_offset }
218}
219
220impl CodeGen for ListComp {
221    type Context = CodeGenContext;
222    type Options = PythonOptions;
223    type SymbolTable = SymbolTableScopes;
224
225    fn find_symbols(self, symbols: Self::SymbolTable) -> Self::SymbolTable {
226        // Process the element and generators
227        let symbols = (*self.elt).clone().find_symbols(symbols);
228        self.generators.into_iter().fold(symbols, |acc, generator| {
229            let acc = generator.target.find_symbols(acc);
230            let acc = generator.iter.find_symbols(acc);
231            generator.ifs.into_iter().fold(acc, |acc, if_expr| if_expr.find_symbols(acc))
232        })
233    }
234
235    fn to_rust(
236        self,
237        ctx: Self::Context,
238        options: Self::Options,
239        symbols: Self::SymbolTable,
240    ) -> Result<TokenStream, Box<dyn std::error::Error>> {
241        // For now, generate a simple Vec collection since Rust doesn't have list comprehensions
242        // This is a simplified translation that doesn't handle all cases
243        if self.generators.len() == 1 {
244            let generator = &self.generators[0];
245            let elt = (*self.elt).clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
246            let iter_expr = generator.iter.clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
247            
248            if generator.ifs.is_empty() {
249                // Simple case: [expr for x in iter] -> iter.map(|x| expr).collect()
250                Ok(quote! {
251                    (#iter_expr).into_iter().map(|_item| #elt).collect::<Vec<_>>()
252                })
253            } else {
254                // With conditions: [expr for x in iter if cond] -> iter.filter(cond).map(expr).collect()
255                let conditions: Result<Vec<_>, _> = generator.ifs.iter()
256                    .map(|if_expr| if_expr.clone().to_rust(ctx.clone(), options.clone(), symbols.clone()))
257                    .collect();
258                let conditions = conditions?;
259                Ok(quote! {
260                    (#iter_expr).into_iter()
261                        .filter(|_item| { #(#conditions)&&* })
262                        .map(|_item| #elt)
263                        .collect::<Vec<_>>()
264                })
265            }
266        } else {
267            // Multiple generators would need nested iteration - this is complex
268            // For now, return a placeholder
269            Ok(quote! {
270                vec![] // Complex list comprehension with multiple generators not fully supported
271            })
272        }
273    }
274}
275
276impl CodeGen for SetComp {
277    type Context = CodeGenContext;
278    type Options = PythonOptions;
279    type SymbolTable = SymbolTableScopes;
280
281    fn find_symbols(self, symbols: Self::SymbolTable) -> Self::SymbolTable {
282        // Process the element and generators
283        let symbols = (*self.elt).clone().find_symbols(symbols);
284        self.generators.into_iter().fold(symbols, |acc, generator| {
285            let acc = generator.target.find_symbols(acc);
286            let acc = generator.iter.find_symbols(acc);
287            generator.ifs.into_iter().fold(acc, |acc, if_expr| if_expr.find_symbols(acc))
288        })
289    }
290
291    fn to_rust(
292        self,
293        ctx: Self::Context,
294        options: Self::Options,
295        symbols: Self::SymbolTable,
296    ) -> Result<TokenStream, Box<dyn std::error::Error>> {
297        // For now, generate a simple HashSet collection since Rust doesn't have set comprehensions
298        // This is a simplified translation that doesn't handle all cases
299        if self.generators.len() == 1 {
300            let generator = &self.generators[0];
301            let elt = (*self.elt).clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
302            let iter_expr = generator.iter.clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
303            
304            if generator.ifs.is_empty() {
305                // Simple case: {expr for x in iter} -> iter.map(|x| expr).collect()
306                Ok(quote! {
307                    (#iter_expr).into_iter().map(|_item| #elt).collect::<std::collections::HashSet<_>>()
308                })
309            } else {
310                // With conditions: {expr for x in iter if cond} -> iter.filter(cond).map(expr).collect()
311                let conditions: Result<Vec<_>, _> = generator.ifs.iter()
312                    .map(|if_expr| if_expr.clone().to_rust(ctx.clone(), options.clone(), symbols.clone()))
313                    .collect();
314                let conditions = conditions?;
315                Ok(quote! {
316                    (#iter_expr).into_iter()
317                        .filter(|_item| { #(#conditions)&&* })
318                        .map(|_item| #elt)
319                        .collect::<std::collections::HashSet<_>>()
320                })
321            }
322        } else {
323            // Multiple generators would need nested iteration - this is complex
324            // For now, return a placeholder
325            Ok(quote! {
326                std::collections::HashSet::new() // Complex set comprehension with multiple generators not fully supported
327            })
328        }
329    }
330}
331
332impl CodeGen for GeneratorExp {
333    type Context = CodeGenContext;
334    type Options = PythonOptions;
335    type SymbolTable = SymbolTableScopes;
336
337    fn find_symbols(self, symbols: Self::SymbolTable) -> Self::SymbolTable {
338        // Process the element and generators
339        let symbols = (*self.elt).clone().find_symbols(symbols);
340        self.generators.into_iter().fold(symbols, |acc, generator| {
341            let acc = generator.target.find_symbols(acc);
342            let acc = generator.iter.find_symbols(acc);
343            generator.ifs.into_iter().fold(acc, |acc, if_expr| if_expr.find_symbols(acc))
344        })
345    }
346
347    fn to_rust(
348        self,
349        ctx: Self::Context,
350        options: Self::Options,
351        symbols: Self::SymbolTable,
352    ) -> Result<TokenStream, Box<dyn std::error::Error>> {
353        // For now, generate a simple iterator since Rust doesn't have generator expressions
354        // This is a simplified translation that doesn't handle all cases
355        if self.generators.len() == 1 {
356            let generator = &self.generators[0];
357            let elt = (*self.elt).clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
358            let iter_expr = generator.iter.clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
359            
360            if generator.ifs.is_empty() {
361                // Simple case: (expr for x in iter) -> iter.map(|x| expr)
362                Ok(quote! {
363                    (#iter_expr).into_iter().map(|_item| #elt)
364                })
365            } else {
366                // With conditions: (expr for x in iter if cond) -> iter.filter(cond).map(expr)
367                let conditions: Result<Vec<_>, _> = generator.ifs.iter()
368                    .map(|if_expr| if_expr.clone().to_rust(ctx.clone(), options.clone(), symbols.clone()))
369                    .collect();
370                let conditions = conditions?;
371                Ok(quote! {
372                    (#iter_expr).into_iter()
373                        .filter(|_item| { #(#conditions)&&* })
374                        .map(|_item| #elt)
375                })
376            }
377        } else {
378            // Multiple generators would need nested iteration - this is complex
379            // For now, return a placeholder
380            Ok(quote! {
381                std::iter::empty() // Complex generator expression with multiple generators not fully supported
382            })
383        }
384    }
385}
386
387impl CodeGen for DictComp {
388    type Context = CodeGenContext;
389    type Options = PythonOptions;
390    type SymbolTable = SymbolTableScopes;
391
392    fn find_symbols(self, symbols: Self::SymbolTable) -> Self::SymbolTable {
393        // Process the key, value and generators
394        let symbols = (*self.key).clone().find_symbols(symbols);
395        let symbols = (*self.value).clone().find_symbols(symbols);
396        self.generators.into_iter().fold(symbols, |acc, generator| {
397            let acc = generator.target.find_symbols(acc);
398            let acc = generator.iter.find_symbols(acc);
399            generator.ifs.into_iter().fold(acc, |acc, if_expr| if_expr.find_symbols(acc))
400        })
401    }
402
403    fn to_rust(
404        self,
405        ctx: Self::Context,
406        options: Self::Options,
407        symbols: Self::SymbolTable,
408    ) -> Result<TokenStream, Box<dyn std::error::Error>> {
409        // For now, generate a simple HashMap collection since Rust doesn't have dict comprehensions
410        if self.generators.len() == 1 {
411            let generator = &self.generators[0];
412            let key = (*self.key).clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
413            let value = (*self.value).clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
414            let iter_expr = generator.iter.clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
415            
416            if generator.ifs.is_empty() {
417                // Simple case: {k: v for x in iter} -> iter.map(|x| (k, v)).collect()
418                Ok(quote! {
419                    (#iter_expr).into_iter().map(|_item| (#key, #value)).collect::<std::collections::HashMap<_, _>>()
420                })
421            } else {
422                // With conditions: {k: v for x in iter if cond}
423                let conditions: Result<Vec<_>, _> = generator.ifs.iter()
424                    .map(|if_expr| if_expr.clone().to_rust(ctx.clone(), options.clone(), symbols.clone()))
425                    .collect();
426                let conditions = conditions?;
427                Ok(quote! {
428                    (#iter_expr).into_iter()
429                        .filter(|_item| { #(#conditions)&&* })
430                        .map(|_item| (#key, #value))
431                        .collect::<std::collections::HashMap<_, _>>()
432                })
433            }
434        } else {
435            // Multiple generators would need nested iteration - this is complex
436            Ok(quote! {
437                std::collections::HashMap::new() // Complex dict comprehension with multiple generators not fully supported
438            })
439        }
440    }
441}
442
#[cfg(test)]
mod tests {
    // Comprehension parse tests are disabled for now; the original note says they might
    // need additional AST node implementations before they can pass.
    // NOTE(review): `create_parse_test!` is not imported in this module. Confirm it is in
    // scope (e.g. via `#[macro_use]` at the crate root) before uncommenting these.
    // create_parse_test!(test_simple_listcomp, "[x for x in range(5)]", "test.py");
    // create_parse_test!(test_listcomp_with_condition, "[x for x in range(10) if x % 2 == 0]", "test.py");
}