python_ast/ast/tree/
for_stmt.rs

1use proc_macro2::TokenStream;
2use pyo3::{Bound, FromPyObject, PyAny, PyResult, types::PyAnyMethods};
3use quote::quote;
4use serde::{Deserialize, Serialize};
5
6use crate::{
7    CodeGen, CodeGenContext, ExprType, PythonOptions, SymbolTableScopes,
8    Node, impl_node_with_positions, PyAttributeExtractor, extract_list
9};
10
11use super::Statement;
12
/// A Python `for` statement.
///
/// The fields mirror the attributes read from the Python AST node in the
/// `FromPyObject` impl: `target`, `iter`, `body` and `orelse`, plus the
/// optional source-position attributes.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct For {
    /// Loop target bound on each iteration (e.g. `x` in `for x in ...`).
    pub target: ExprType,
    /// Expression whose values are iterated over.
    pub iter: ExprType,
    /// Statements of the loop body.
    pub body: Vec<Statement>,
    /// Statements of the `else:` clause; an empty list means there is no
    /// `else:` clause and a plain loop is generated.
    pub orelse: Vec<Statement>,
    /// Starting line of the statement, copied from the Python node if available.
    pub lineno: Option<usize>,
    /// Starting column of the statement, copied from the Python node if available.
    pub col_offset: Option<usize>,
    /// Ending line of the statement, copied from the Python node if available.
    pub end_lineno: Option<usize>,
    /// Ending column of the statement, copied from the Python node if available.
    pub end_col_offset: Option<usize>,
}
24
25impl<'a> FromPyObject<'a> for For {
26    fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
27        let target = ob.extract_attr_with_context("target", "for loop target")?;
28        let iter = ob.extract_attr_with_context("iter", "for loop iterator")?;
29        
30        let target = target.extract().expect("getting for target");
31        let iter = iter.extract().expect("getting for iter");
32        
33        let body: Vec<Statement> = extract_list(ob, "body", "for body statements")?;
34        let orelse: Vec<Statement> = extract_list(ob, "orelse", "for else statements")?;
35        
36        Ok(For {
37            target,
38            iter,
39            body,
40            orelse,
41            lineno: ob.lineno(),
42            col_offset: ob.col_offset(),
43            end_lineno: ob.end_lineno(),
44            end_col_offset: ob.end_col_offset(),
45        })
46    }
47}
48
// Hooks the four source-position fields of `For` into the crate's position
// helpers. NOTE(review): the macro is defined elsewhere in the crate —
// presumably it implements `Node` for `For` using these fields; confirm
// against the macro definition.
impl_node_with_positions!(For { lineno, col_offset, end_lineno, end_col_offset });
50
51impl CodeGen for For {
52    type Context = CodeGenContext;
53    type Options = PythonOptions;
54    type SymbolTable = SymbolTableScopes;
55
56    fn find_symbols(self, symbols: Self::SymbolTable) -> Self::SymbolTable {
57        let symbols = self.target.find_symbols(symbols);
58        let symbols = self.iter.find_symbols(symbols);
59        let symbols = self.body.into_iter().fold(symbols, |acc, stmt| stmt.find_symbols(acc));
60        self.orelse.into_iter().fold(symbols, |acc, stmt| stmt.find_symbols(acc))
61    }
62
63    fn to_rust(
64        self,
65        ctx: Self::Context,
66        options: Self::Options,
67        symbols: Self::SymbolTable,
68    ) -> Result<TokenStream, Box<dyn std::error::Error>> {
69        let target = self.target.to_rust(ctx.clone(), options.clone(), symbols.clone())?;
70        let iter = self.iter.to_rust(ctx.clone(), options.clone(), symbols.clone())?;
71        
72        let body_stmts: Result<Vec<_>, _> = self.body
73            .into_iter()
74            .map(|stmt| stmt.to_rust(ctx.clone(), options.clone(), symbols.clone()))
75            .collect();
76        let body_stmts = body_stmts?;
77        
78        if self.orelse.is_empty() {
79            Ok(quote! {
80                for #target in #iter {
81                    #(#body_stmts)*
82                }
83            })
84        } else {
85            // Note: Rust doesn't have for-else, so we need to track completion
86            let else_stmts: Result<Vec<_>, _> = self.orelse
87                .into_iter()
88                .map(|stmt| stmt.to_rust(ctx.clone(), options.clone(), symbols.clone()))
89                .collect();
90            let else_stmts = else_stmts?;
91            
92            Ok(quote! {
93                {
94                    let mut completed = true;
95                    for #target in #iter {
96                        #(#body_stmts)*
97                        completed = false;
98                        break;
99                    }
100                    if completed {
101                        #(#else_stmts)*
102                    }
103                }
104            })
105        }
106    }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use crate::create_parse_test;

    create_parse_test!(test_simple_for, "for x in range(10):\n    print(x)", "for_test.py");
    create_parse_test!(test_for_else, "for x in range(10):\n    print(x)\nelse:\n    print('done')", "for_test.py");
    create_parse_test!(test_for_list, "for item in [1, 2, 3]:\n    print(item)", "for_test.py");

    // for-else whose body can `break`: in Python the `else:` clause must be
    // skipped when the loop is left via `break`.
    create_parse_test!(
        test_for_else_break,
        "for x in range(10):\n    if x == 5:\n        break\nelse:\n    print('done')",
        "for_test.py"
    );

    // for-else over an empty iterable: in Python the `else:` clause runs
    // even though the body never executes.
    create_parse_test!(
        test_for_else_empty,
        "for x in []:\n    print(x)\nelse:\n    print('empty')",
        "for_test.py"
    );
}