// python_ast/ast/tree/statement.rs

use log::debug;
use proc_macro2::TokenStream;
use pyo3::{exceptions::PyValueError, FromPyObject, PyAny, PyErr, PyResult};
use quote::quote;
use serde::{Deserialize, Serialize};

use crate::{
    dump, Assign, Call, ClassDef, CodeGen, CodeGenContext, Error, Expr, FunctionDef, Import,
    ImportFrom, Node, PythonOptions, SymbolTableScopes,
};
14#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
15pub struct Statement {
16    pub lineno: Option<usize>,
17    pub col_offset: Option<usize>,
18    pub end_lineno: Option<usize>,
19    pub end_col_offset: Option<usize>,
20    pub statement: StatementType,
21}
22
23impl<'a> FromPyObject<'a> for Statement {
24    fn extract(ob: &'a PyAny) -> PyResult<Self> {
25        Ok(Self {
26            lineno: ob.lineno(),
27            col_offset: ob.col_offset(),
28            end_lineno: ob.end_lineno(),
29            end_col_offset: ob.end_col_offset(),
30            statement: StatementType::extract(ob)?,
31        })
32    }
33}
34
35impl Node for Statement {
36    fn lineno(&self) -> Option<usize> {
37        self.lineno
38    }
39    fn col_offset(&self) -> Option<usize> {
40        self.col_offset
41    }
42    fn end_lineno(&self) -> Option<usize> {
43        self.end_lineno
44    }
45    fn end_col_offset(&self) -> Option<usize> {
46        self.end_col_offset
47    }
48}
49
50impl CodeGen for Statement {
51    type Context = CodeGenContext;
52    type Options = PythonOptions;
53    type SymbolTable = SymbolTableScopes;
54
55    fn find_symbols(self, symbols: Self::SymbolTable) -> Self::SymbolTable {
56        self.statement.clone().find_symbols(symbols)
57    }
58
59    fn to_rust(
60        self,
61        ctx: Self::Context,
62        options: Self::Options,
63        symbols: Self::SymbolTable,
64    ) -> Result<TokenStream, Box<dyn std::error::Error>> {
65        Ok(self
66            .statement
67            .clone()
68            .to_rust(ctx, options, symbols)
69            .expect(
70                self.error_message(
71                    "<unknown>",
72                    format!("failed to compile statement {:#?}", self),
73                )
74                .as_str(),
75            ))
76    }
77}
78
79#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
80pub enum StatementType {
81    AsyncFunctionDef(FunctionDef),
82    Assign(Assign),
83    Break,
84    Continue,
85    ClassDef(ClassDef),
86    Call(Call),
87    Pass,
88    Return(Option<Expr>),
89    Import(Import),
90    ImportFrom(ImportFrom),
91    Expr(Expr),
92    FunctionDef(FunctionDef),
93
94    Unimplemented(String),
95}
96
97impl<'a> FromPyObject<'a> for StatementType {
98    fn extract(ob: &'a PyAny) -> PyResult<Self> {
99        let err_msg = format!("getting type for statement {:?}", ob);
100        let ob_type = ob
101            .get_type()
102            .name()
103            .unwrap_or_else(|_| panic!("{}", ob.error_message("<unknown>", err_msg)));
104
105        debug!("statement...ob_type: {}...{}", ob_type, dump(ob, Some(4))?);
106        match ob_type.as_ref() {
107            "AsyncFunctionDef" => Ok(StatementType::AsyncFunctionDef(
108                FunctionDef::extract(ob).unwrap_or_else(|_| {
109                    panic!("Failed to extract async function: {:?}", dump(ob, Some(4)))
110                }),
111            )),
112            "Assign" => {
113                let assignment = Assign::extract(ob).expect("reading assignment");
114                Ok(StatementType::Assign(assignment))
115            }
116            "Pass" => Ok(StatementType::Pass),
117            "Call" => {
118                let call =
119                    Call::extract(ob.getattr("value").unwrap_or_else(|_| {
120                        panic!("getting value from {:?} in call statement", ob)
121                    }))
122                    .unwrap_or_else(|_| panic!("extracting call statement {:?}", ob));
123                debug!("call: {:?}", call);
124                Ok(StatementType::Call(call))
125            }
126            "ClassDef" => Ok(StatementType::ClassDef(
127                ClassDef::extract(ob).unwrap_or_else(|_| panic!("Class definition {:?}", ob)),
128            )),
129            "Continue" => Ok(StatementType::Continue),
130            "Break" => Ok(StatementType::Break),
131            "FunctionDef" => Ok(StatementType::FunctionDef(
132                FunctionDef::extract(ob).unwrap_or_else(|_| {
133                    panic!("Failed to extract function: {:?}", dump(ob, Some(4)))
134                }),
135            )),
136            "Import" => Ok(StatementType::Import(
137                Import::extract(ob).unwrap_or_else(|_| panic!("Import {:?}", ob)),
138            )),
139            "ImportFrom" => Ok(StatementType::ImportFrom(
140                ImportFrom::extract(ob).unwrap_or_else(|_| panic!("ImportFrom {:?}", ob)),
141            )),
142            "Expr" => {
143                let expr = Expr::extract(
144                    ob.extract()
145                        .unwrap_or_else(|_| panic!("extracting Expr {:?}", ob)),
146                )
147                .expect(format!("Expr {:?}", ob).as_str());
148                Ok(StatementType::Expr(expr))
149            }
150            "Return" => {
151                log::debug!("return expression: {}", dump(ob, None)?);
152                let expr = Expr::extract(
153                    ob.extract()
154                        .unwrap_or_else(|_| panic!("extracting return Expr {:?}", ob)),
155                )
156                .unwrap_or_else(|_| panic!("return Expr {:?}", dump(ob, None)));
157                Ok(StatementType::Return(Some(expr)))
158            }
159            _ => Err(pyo3::exceptions::PyValueError::new_err(format!(
160                "Unimplemented statement type {}, {}",
161                ob_type,
162                dump(ob, None)?
163            ))),
164        }
165    }
166}
167
168impl CodeGen for StatementType {
169    type Context = CodeGenContext;
170    type Options = PythonOptions;
171    type SymbolTable = SymbolTableScopes;
172
173    fn find_symbols(self, symbols: Self::SymbolTable) -> Self::SymbolTable {
174        match self {
175            StatementType::Assign(a) => a.find_symbols(symbols),
176            StatementType::ClassDef(c) => c.find_symbols(symbols),
177            StatementType::FunctionDef(f) => f.find_symbols(symbols),
178            StatementType::Import(i) => i.find_symbols(symbols),
179            StatementType::ImportFrom(i) => i.find_symbols(symbols),
180            StatementType::Expr(e) => e.find_symbols(symbols),
181            _ => symbols,
182        }
183    }
184
185    fn to_rust(
186        self,
187        ctx: Self::Context,
188        options: Self::Options,
189        symbols: Self::SymbolTable,
190    ) -> Result<TokenStream, Box<dyn std::error::Error>> {
191        match self {
192            StatementType::AsyncFunctionDef(s) => {
193                let func_def = s
194                    .to_rust(Self::Context::Async(Box::new(ctx)), options, symbols)
195                    .expect("Parsing async function");
196                Ok(quote!(#func_def))
197            }
198            StatementType::Assign(a) => a.to_rust(ctx, options, symbols),
199            StatementType::Break => Ok(quote! {break;}),
200            StatementType::Call(c) => c.to_rust(ctx, options, symbols),
201            StatementType::ClassDef(c) => c.to_rust(ctx, options, symbols),
202            StatementType::Continue => Ok(quote! {continue;}),
203            StatementType::Pass => Ok(quote! {}),
204            StatementType::FunctionDef(s) => s.to_rust(ctx, options, symbols),
205            StatementType::Import(s) => s.to_rust(ctx, options, symbols),
206            StatementType::ImportFrom(s) => s.to_rust(ctx, options, symbols),
207            StatementType::Expr(s) => s.to_rust(ctx, options, symbols),
208            StatementType::Return(None) => Ok(quote!(return)),
209            StatementType::Return(Some(e)) => {
210                let exp = e
211                    .clone()
212                    .to_rust(ctx, options, symbols)
213                    .unwrap_or_else(|_| panic!("parsing expression {:#?}", e));
214                Ok(quote!(return #exp))
215            }
216            _ => {
217                let error = Error::StatementNotYetImplemented(self);
218                Err(Box::new(error))
219            }
220        }
221    }
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    /// Translates `statement` in a fresh module context with default options and an empty
    /// symbol table, which is the setup every statement-level test here needs.
    fn compile(statement: StatementType) -> Result<TokenStream, Box<dyn std::error::Error>> {
        statement.to_rust(
            CodeGenContext::Module("".to_string()),
            PythonOptions::default(),
            SymbolTableScopes::new(),
        )
    }

    #[test]
    fn check_pass_statement() {
        let statement = StatementType::Pass;
        let tokens = compile(statement.clone());

        debug!("statement: {:?}, tokens: {:?}", statement, tokens);
        // `pass` has no Rust counterpart, so nothing may be emitted.
        assert!(tokens.unwrap().is_empty());
    }

    #[test]
    fn check_break_statement() {
        let statement = StatementType::Break;
        let tokens = compile(statement.clone());

        debug!("statement: {:?}, tokens: {:?}", statement, tokens);
        assert_eq!(tokens.unwrap().to_string(), quote! {break;}.to_string());
    }

    #[test]
    fn check_continue_statement() {
        let statement = StatementType::Continue;
        let tokens = compile(statement.clone());

        debug!("statement: {:?}, tokens: {:?}", statement, tokens);
        assert_eq!(tokens.unwrap().to_string(), quote! {continue;}.to_string());
    }

    #[test]
    fn check_bare_return_statement() {
        let tokens = compile(StatementType::Return(None));
        assert_eq!(tokens.unwrap().to_string(), "return");
    }

    #[test]
    fn unimplemented_statement_is_an_error() {
        let tokens = compile(StatementType::Unimplemented("while".to_string()));
        assert!(tokens.is_err());
    }

    #[test]
    fn return_with_nothing() {
        let tree = crate::parse("return", "<none>").unwrap();
        assert_eq!(tree.raw.body.len(), 1);
        assert_eq!(
            tree.raw.body[0].statement,
            StatementType::Return(Some(Expr {
                value: crate::tree::ExprType::NoneType(crate::tree::Constant(None)),
                lineno: Some(1),
                col_offset: Some(0),
                end_lineno: Some(1),
                end_col_offset: Some(6),
                ..Default::default()
            }))
        );
    }

    #[test]
    fn return_with_expr() {
        let lit = litrs::Literal::Integer(litrs::IntegerLit::parse(String::from("8")).unwrap());
        let tree = crate::parse("return 8", "<none>").unwrap();
        assert_eq!(tree.raw.body.len(), 1);
        assert_eq!(
            tree.raw.body[0].statement,
            StatementType::Return(Some(Expr {
                value: crate::tree::ExprType::Constant(crate::tree::Constant(Some(lit))),
                lineno: Some(1),
                col_offset: Some(0),
                end_lineno: Some(1),
                end_col_offset: Some(8),
                ..Default::default()
            }))
        );
    }

    #[test]
    fn does_module_compile() {
        let options = PythonOptions::default();
        let result = crate::parse(
            "#test comment
def foo():
    continue
    pass
",
            "test_case",
        )
        .unwrap();
        log::info!("{:?}", result);
        let code = result.to_rust(
            CodeGenContext::Module("".to_string()),
            options,
            SymbolTableScopes::new(),
        );
        log::info!("module: {:?}", code);
    }
}