// python_ast/ast/tree/statement.rs

1use proc_macro2::TokenStream;
2use pyo3::{Bound, FromPyObject, PyAny, PyResult, prelude::PyAnyMethods, types::PyTypeMethods};
3use quote::quote;
4
5use crate::{
6    dump, Assign, AugAssign, Call, ClassDef, CodeGen, CodeGenContext, Error, Expr, FunctionDef, Import,
7    ImportFrom, Node, PythonOptions, SymbolTableScopes, If, For, While, Try, AsyncWith, AsyncFor, Raise, With,
8};
9
10use log::debug;
11
12use serde::{Deserialize, Serialize};
13
/// Marker trait for AST node types that may appear in statement position.
///
/// It adds no methods of its own; implementors only need to be cloneable and
/// comparable for equality.
pub trait PyStatementTrait: Clone + PartialEq {}

18#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
19pub struct Statement {
20    pub lineno: Option<usize>,
21    pub col_offset: Option<usize>,
22    pub end_lineno: Option<usize>,
23    pub end_col_offset: Option<usize>,
24    pub statement: StatementType,
25}
26
27impl<'a> FromPyObject<'a> for Statement {
28    fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
29        Ok(Self {
30            lineno: ob.lineno(),
31            col_offset: ob.col_offset(),
32            end_lineno: ob.end_lineno(),
33            end_col_offset: ob.end_col_offset(),
34            statement: StatementType::extract_bound(ob)?,
35        })
36    }
37}
38
39impl Node for Statement {
40    fn lineno(&self) -> Option<usize> {
41        self.lineno
42    }
43    fn col_offset(&self) -> Option<usize> {
44        self.col_offset
45    }
46    fn end_lineno(&self) -> Option<usize> {
47        self.end_lineno
48    }
49    fn end_col_offset(&self) -> Option<usize> {
50        self.end_col_offset
51    }
52}
53
54impl CodeGen for Statement {
55    type Context = CodeGenContext;
56    type Options = PythonOptions;
57    type SymbolTable = SymbolTableScopes;
58
59    fn find_symbols(self, symbols: Self::SymbolTable) -> Self::SymbolTable {
60        self.statement.clone().find_symbols(symbols)
61    }
62
63    fn to_rust(
64        self,
65        ctx: Self::Context,
66        options: Self::Options,
67        symbols: Self::SymbolTable,
68    ) -> Result<TokenStream, Box<dyn std::error::Error>> {
69        Ok(self
70            .statement
71            .clone()
72            .to_rust(ctx, options, symbols)
73            .expect(
74                self.error_message(
75                    "<unknown>",
76                    format!("failed to compile statement {:#?}", self),
77                )
78                .as_str(),
79            ))
80    }
81}
82
83#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
84pub enum StatementType {
85    AsyncFunctionDef(FunctionDef),
86    Assign(Assign),
87    AugAssign(AugAssign),
88    Break,
89    Continue,
90    ClassDef(ClassDef),
91    Call(Call),
92    Pass,
93    Return(Option<Expr>),
94    Import(Import),
95    ImportFrom(ImportFrom),
96    Expr(Expr),
97    FunctionDef(FunctionDef),
98    If(If),
99    For(For),
100    While(While),
101    Try(Try),
102    AsyncWith(AsyncWith),
103    AsyncFor(AsyncFor),
104    Raise(Raise),
105    With(With),
106
107    Unimplemented(String),
108}
109
110impl<'a> FromPyObject<'a> for StatementType {
111    fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
112        let err_msg = format!("getting type for statement {:?}", ob);
113        let ob_type = ob
114            .get_type()
115            .name()
116            .unwrap_or_else(|_| panic!("{}", ob.error_message("<unknown>", err_msg)));
117
118        debug!("statement...ob_type: {}...{}", ob_type, dump(ob, Some(4))?);
119        match ob_type.extract::<String>()?.as_str() {
120            "AsyncFunctionDef" => Ok(StatementType::AsyncFunctionDef(
121                FunctionDef::extract_bound(ob).unwrap_or_else(|_| {
122                    panic!("Failed to extract async function: {:?}", dump(ob, Some(4)))
123                }),
124            )),
125            "Assign" => {
126                let assignment = Assign::extract_bound(ob).expect("reading assignment");
127                Ok(StatementType::Assign(assignment))
128            }
129            "AugAssign" => {
130                let aug_assignment = AugAssign::extract_bound(ob).expect("reading augmented assignment");
131                Ok(StatementType::AugAssign(aug_assignment))
132            }
133            "Pass" => Ok(StatementType::Pass),
134            "Call" => {
135                let call =
136                    Call::extract_bound(&ob.getattr("value").unwrap_or_else(|_| {
137                        panic!("getting value from {:?} in call statement", ob)
138                    }))
139                    .unwrap_or_else(|_| panic!("extracting call statement {:?}", ob));
140                debug!("call: {:?}", call);
141                Ok(StatementType::Call(call))
142            }
143            "ClassDef" => Ok(StatementType::ClassDef(
144                ClassDef::extract_bound(ob).unwrap_or_else(|_| panic!("Class definition {:?}", ob)),
145            )),
146            "Continue" => Ok(StatementType::Continue),
147            "Break" => Ok(StatementType::Break),
148            "FunctionDef" => Ok(StatementType::FunctionDef(
149                FunctionDef::extract_bound(ob).unwrap_or_else(|_| {
150                    panic!("Failed to extract function: {:?}", dump(ob, Some(4)))
151                }),
152            )),
153            "Import" => Ok(StatementType::Import(
154                Import::extract_bound(ob).unwrap_or_else(|_| panic!("Import {:?}", ob)),
155            )),
156            "ImportFrom" => Ok(StatementType::ImportFrom(
157                ImportFrom::extract_bound(ob).unwrap_or_else(|_| panic!("ImportFrom {:?}", ob)),
158            )),
159            "Expr" => {
160                let expr = ob.extract()
161                    .expect(format!("Expr {:?}", ob).as_str());
162                Ok(StatementType::Expr(expr))
163            }
164            "Return" => {
165                log::debug!("return expression: {}", dump(ob, None)?);
166                // Extract the return value from the Return statement's 'value' field
167                let return_value = if let Ok(value_attr) = ob.getattr("value") {
168                    if value_attr.is_none() {
169                        // Bare 'return' statement - create a NoneType Expr
170                        Some(Expr {
171                            value: crate::tree::ExprType::NoneType(crate::tree::Constant(None)),
172                            ctx: None,
173                            lineno: ob.lineno(),
174                            col_offset: ob.col_offset(),
175                            end_lineno: ob.end_lineno(),
176                            end_col_offset: ob.end_col_offset(),
177                        })
178                    } else {
179                        // Return with actual expression - extract as ExprType then wrap in Expr
180                        let expr_value: crate::tree::ExprType = value_attr.extract()
181                            .unwrap_or_else(|_| panic!("return value ExprType {:?}", dump(&value_attr, None).unwrap_or_else(|_| "unknown".to_string())));
182                        Some(Expr {
183                            value: expr_value,
184                            ctx: None,
185                            lineno: ob.lineno(),
186                            col_offset: ob.col_offset(),
187                            end_lineno: ob.end_lineno(),
188                            end_col_offset: ob.end_col_offset(),
189                        })
190                    }
191                } else {
192                    None
193                };
194                Ok(StatementType::Return(return_value))
195            }
196            "If" => {
197                let if_stmt = If::extract_bound(ob)
198                    .unwrap_or_else(|_| panic!("If statement {:?}", dump(ob, None)));
199                Ok(StatementType::If(if_stmt))
200            }
201            "For" => {
202                let for_stmt = For::extract_bound(ob)
203                    .unwrap_or_else(|_| panic!("For statement {:?}", dump(ob, None)));
204                Ok(StatementType::For(for_stmt))
205            }
206            "While" => {
207                let while_stmt = While::extract_bound(ob)
208                    .unwrap_or_else(|_| panic!("While statement {:?}", dump(ob, None)));
209                Ok(StatementType::While(while_stmt))
210            }
211            "Try" => {
212                let try_stmt = Try::extract_bound(ob)
213                    .unwrap_or_else(|_| panic!("Try statement {:?}", dump(ob, None)));
214                Ok(StatementType::Try(try_stmt))
215            }
216            "AsyncWith" => {
217                let async_with_stmt = AsyncWith::extract_bound(ob)
218                    .unwrap_or_else(|_| panic!("AsyncWith statement {:?}", dump(ob, None)));
219                Ok(StatementType::AsyncWith(async_with_stmt))
220            }
221            "AsyncFor" => {
222                let async_for_stmt = AsyncFor::extract_bound(ob)
223                    .unwrap_or_else(|_| panic!("AsyncFor statement {:?}", dump(ob, None)));
224                Ok(StatementType::AsyncFor(async_for_stmt))
225            }
226            "Raise" => {
227                let raise_stmt = Raise::extract_bound(ob)
228                    .unwrap_or_else(|_| panic!("Raise statement {:?}", dump(ob, None)));
229                Ok(StatementType::Raise(raise_stmt))
230            }
231            "With" => {
232                let with_stmt = With::extract_bound(ob)
233                    .unwrap_or_else(|_| panic!("With statement {:?}", dump(ob, None)));
234                Ok(StatementType::With(with_stmt))
235            }
236            _ => Err(pyo3::exceptions::PyValueError::new_err(format!(
237                "Unimplemented statement type {}, {}",
238                ob_type,
239                dump(ob, None)?
240            ))),
241        }
242    }
243}
244
245impl CodeGen for StatementType {
246    type Context = CodeGenContext;
247    type Options = PythonOptions;
248    type SymbolTable = SymbolTableScopes;
249
250    fn find_symbols(self, symbols: Self::SymbolTable) -> Self::SymbolTable {
251        match self {
252            StatementType::Assign(a) => a.find_symbols(symbols),
253            StatementType::AugAssign(a) => a.find_symbols(symbols),
254            StatementType::ClassDef(c) => c.find_symbols(symbols),
255            StatementType::FunctionDef(f) => f.find_symbols(symbols),
256            StatementType::Import(i) => i.find_symbols(symbols),
257            StatementType::ImportFrom(i) => i.find_symbols(symbols),
258            StatementType::Expr(e) => e.find_symbols(symbols),
259            StatementType::If(i) => i.find_symbols(symbols),
260            StatementType::For(f) => f.find_symbols(symbols),
261            StatementType::While(w) => w.find_symbols(symbols),
262            StatementType::Try(t) => t.find_symbols(symbols),
263            StatementType::AsyncWith(aw) => aw.find_symbols(symbols),
264            StatementType::AsyncFor(af) => af.find_symbols(symbols),
265            StatementType::Raise(r) => r.find_symbols(symbols),
266            StatementType::With(w) => w.find_symbols(symbols),
267            _ => symbols,
268        }
269    }
270
271    fn to_rust(
272        self,
273        ctx: Self::Context,
274        options: Self::Options,
275        symbols: Self::SymbolTable,
276    ) -> Result<TokenStream, Box<dyn std::error::Error>> {
277        match self {
278            StatementType::AsyncFunctionDef(s) => {
279                let func_def = s
280                    .to_rust(Self::Context::Async(Box::new(ctx)), options, symbols)
281                    .expect("Parsing async function");
282                Ok(quote!(#func_def))
283            }
284            StatementType::Assign(a) => a.to_rust(ctx, options, symbols),
285            StatementType::AugAssign(a) => a.to_rust(ctx, options, symbols),
286            StatementType::Break => Ok(quote! {break;}),
287            StatementType::Call(c) => c.to_rust(ctx, options, symbols),
288            StatementType::ClassDef(c) => c.to_rust(ctx, options, symbols),
289            StatementType::Continue => Ok(quote! {continue;}),
290            StatementType::Pass => Ok(quote! {}),
291            StatementType::FunctionDef(s) => s.to_rust(ctx, options, symbols),
292            StatementType::Import(s) => s.to_rust(ctx, options, symbols),
293            StatementType::ImportFrom(s) => s.to_rust(ctx, options, symbols),
294            StatementType::Expr(s) => s.to_rust(ctx, options, symbols),
295            StatementType::Return(None) => Ok(quote!(return)),
296            StatementType::Return(Some(e)) => {
297                let exp = e
298                    .clone()
299                    .to_rust(ctx, options, symbols)
300                    .unwrap_or_else(|_| panic!("parsing expression {:#?}", e));
301                Ok(quote!(return #exp))
302            }
303            StatementType::If(i) => i.to_rust(ctx, options, symbols),
304            StatementType::For(f) => f.to_rust(ctx, options, symbols),
305            StatementType::While(w) => w.to_rust(ctx, options, symbols),
306            StatementType::Try(t) => t.to_rust(ctx, options, symbols),
307            StatementType::AsyncWith(aw) => aw.to_rust(ctx, options, symbols),
308            StatementType::AsyncFor(af) => af.to_rust(ctx, options, symbols),
309            StatementType::Raise(r) => r.to_rust(ctx, options, symbols),
310            StatementType::With(w) => w.to_rust(ctx, options, symbols),
311            _ => {
312                let error = Error::StatementNotYetImplemented(self);
313                Err(Box::new(error))
314            }
315        }
316    }
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    /// Generates Rust for a bare statement. It uses an empty module context,
    /// default options and a fresh symbol table, and logs the result.
    fn compile(statement: &StatementType) -> Result<TokenStream, Box<dyn std::error::Error>> {
        let tokens = statement.clone().to_rust(
            CodeGenContext::Module("".to_string()),
            PythonOptions::default(),
            SymbolTableScopes::new(),
        );
        debug!("statement: {:?}, tokens: {:?}", statement, tokens);
        tokens
    }

    #[test]
    fn check_pass_statement() {
        // `pass` emits no tokens at all.
        assert!(compile(&StatementType::Pass).unwrap().is_empty());
    }

    #[test]
    fn check_break_statement() {
        let tokens = compile(&StatementType::Break).unwrap();
        assert_eq!(tokens.to_string(), quote!(break;).to_string());
    }

    #[test]
    fn check_continue_statement() {
        let tokens = compile(&StatementType::Continue).unwrap();
        assert_eq!(tokens.to_string(), quote!(continue;).to_string());
    }

    #[test]
    fn return_with_nothing() {
        let tree = crate::parse("return", "<none>").unwrap();
        assert_eq!(tree.raw.body.len(), 1);
        assert_eq!(
            tree.raw.body[0].statement,
            StatementType::Return(Some(Expr {
                value: crate::tree::ExprType::NoneType(crate::tree::Constant(None)),
                lineno: Some(1),
                col_offset: Some(0),
                end_lineno: Some(1),
                end_col_offset: Some(6),
                ..Default::default()
            }))
        );
    }

    #[test]
    fn return_with_expr() {
        let lit = litrs::Literal::Integer(litrs::IntegerLit::parse(String::from("8")).unwrap());
        let tree = crate::parse("return 8", "<none>").unwrap();
        assert_eq!(tree.raw.body.len(), 1);
        assert_eq!(
            tree.raw.body[0].statement,
            StatementType::Return(Some(Expr {
                value: crate::tree::ExprType::Constant(crate::tree::Constant(Some(lit))),
                lineno: Some(1),
                col_offset: Some(0),
                end_lineno: Some(1),
                end_col_offset: Some(8),
                ..Default::default()
            }))
        );
    }

    #[test]
    fn does_module_compile() {
        let options = PythonOptions::default();
        let result = crate::parse(
            "#test comment
def foo():
    continue
    pass
",
            "test_case",
        )
        .unwrap();
        log::info!("{:?}", result);
        let code = result.to_rust(
            CodeGenContext::Module("".to_string()),
            options,
            SymbolTableScopes::new(),
        );
        log::info!("module: {:?}", code);
    }
}