1use proc_macro2::TokenStream;
2use pyo3::{Bound, FromPyObject, PyAny, PyResult, prelude::PyAnyMethods, types::PyTypeMethods};
3use quote::quote;
4
5use crate::{
6 dump, Assign, AugAssign, Call, ClassDef, CodeGen, CodeGenContext, Error, Expr, FunctionDef, Import,
7 ImportFrom, Node, PythonOptions, SymbolTableScopes, If, For, While, Try, AsyncWith, AsyncFor, Raise, With,
8};
9
10use log::debug;
11
12use serde::{Deserialize, Serialize};
13
/// Marker trait for statement types that can be cloned and compared for equality.
///
/// It declares no methods of its own; any type that is both `PartialEq` and
/// `Clone` may opt in with an empty `impl`.
pub trait PyStatementTrait: PartialEq + Clone {}
17
/// A single Python statement together with its source span.
///
/// Built by `FromPyObject::extract_bound`: the four location fields are read
/// from the Python node through its `lineno()`, `col_offset()`, `end_lineno()`
/// and `end_col_offset()` accessors, and `statement` holds the kind-specific
/// payload.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Statement {
    /// Line on which the statement starts, if the node reported one.
    pub lineno: Option<usize>,
    /// Column at which the statement starts, if the node reported one.
    pub col_offset: Option<usize>,
    /// Line on which the statement ends, if the node reported one.
    pub end_lineno: Option<usize>,
    /// Column at which the statement ends, if the node reported one.
    pub end_col_offset: Option<usize>,
    /// The statement kind and its contents.
    pub statement: StatementType,
}
26
27impl<'a> FromPyObject<'a> for Statement {
28 fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
29 Ok(Self {
30 lineno: ob.lineno(),
31 col_offset: ob.col_offset(),
32 end_lineno: ob.end_lineno(),
33 end_col_offset: ob.end_col_offset(),
34 statement: StatementType::extract_bound(ob)?,
35 })
36 }
37}
38
39impl Node for Statement {
40 fn lineno(&self) -> Option<usize> {
41 self.lineno
42 }
43 fn col_offset(&self) -> Option<usize> {
44 self.col_offset
45 }
46 fn end_lineno(&self) -> Option<usize> {
47 self.end_lineno
48 }
49 fn end_col_offset(&self) -> Option<usize> {
50 self.end_col_offset
51 }
52}
53
54impl CodeGen for Statement {
55 type Context = CodeGenContext;
56 type Options = PythonOptions;
57 type SymbolTable = SymbolTableScopes;
58
59 fn find_symbols(self, symbols: Self::SymbolTable) -> Self::SymbolTable {
60 self.statement.clone().find_symbols(symbols)
61 }
62
63 fn to_rust(
64 self,
65 ctx: Self::Context,
66 options: Self::Options,
67 symbols: Self::SymbolTable,
68 ) -> Result<TokenStream, Box<dyn std::error::Error>> {
69 Ok(self
70 .statement
71 .clone()
72 .to_rust(ctx, options, symbols)
73 .expect(
74 self.error_message(
75 "<unknown>",
76 format!("failed to compile statement {:#?}", self),
77 )
78 .as_str(),
79 ))
80 }
81}
82
/// The kind-specific payload of a Python statement.
///
/// Variants are selected by the Python node's class name in
/// `FromPyObject::extract_bound` and turned into Rust tokens by `CodeGen::to_rust`.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum StatementType {
    /// `async def`: reuses the `FunctionDef` payload; code is generated inside
    /// a `CodeGenContext::Async` context.
    AsyncFunctionDef(FunctionDef),
    /// A plain assignment statement.
    Assign(Assign),
    /// An augmented assignment statement.
    AugAssign(AugAssign),
    /// Generates `break;`.
    Break,
    /// Generates `continue;`.
    Continue,
    /// A class definition.
    ClassDef(ClassDef),
    /// Extracted from the node's `value` attribute when the node's class name is `Call`.
    Call(Call),
    /// Generates no tokens.
    Pass,
    /// A `return`. A bare `return` is stored as `Some` holding a `NoneType`
    /// constant; `None` is used only when the node's `value` attribute cannot be read.
    Return(Option<Expr>),
    /// An `import` statement.
    Import(Import),
    /// A `from … import …` statement.
    ImportFrom(ImportFrom),
    /// An expression used as a statement.
    Expr(Expr),
    /// A synchronous function definition.
    FunctionDef(FunctionDef),
    If(If),
    For(For),
    While(While),
    Try(Try),
    AsyncWith(AsyncWith),
    AsyncFor(AsyncFor),
    Raise(Raise),
    With(With),

    /// Placeholder for statement kinds without a translation. Extraction in
    /// this file does not construct it (unknown node types yield a
    /// `ValueError`). Code generation reports it as
    /// `Error::StatementNotYetImplemented`.
    Unimplemented(String),
}
109
110impl<'a> FromPyObject<'a> for StatementType {
111 fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
112 let err_msg = format!("getting type for statement {:?}", ob);
113 let ob_type = ob
114 .get_type()
115 .name()
116 .unwrap_or_else(|_| panic!("{}", ob.error_message("<unknown>", err_msg)));
117
118 debug!("statement...ob_type: {}...{}", ob_type, dump(ob, Some(4))?);
119 match ob_type.extract::<String>()?.as_str() {
120 "AsyncFunctionDef" => Ok(StatementType::AsyncFunctionDef(
121 FunctionDef::extract_bound(ob).unwrap_or_else(|_| {
122 panic!("Failed to extract async function: {:?}", dump(ob, Some(4)))
123 }),
124 )),
125 "Assign" => {
126 let assignment = Assign::extract_bound(ob).expect("reading assignment");
127 Ok(StatementType::Assign(assignment))
128 }
129 "AugAssign" => {
130 let aug_assignment = AugAssign::extract_bound(ob).expect("reading augmented assignment");
131 Ok(StatementType::AugAssign(aug_assignment))
132 }
133 "Pass" => Ok(StatementType::Pass),
134 "Call" => {
135 let call =
136 Call::extract_bound(&ob.getattr("value").unwrap_or_else(|_| {
137 panic!("getting value from {:?} in call statement", ob)
138 }))
139 .unwrap_or_else(|_| panic!("extracting call statement {:?}", ob));
140 debug!("call: {:?}", call);
141 Ok(StatementType::Call(call))
142 }
143 "ClassDef" => Ok(StatementType::ClassDef(
144 ClassDef::extract_bound(ob).unwrap_or_else(|_| panic!("Class definition {:?}", ob)),
145 )),
146 "Continue" => Ok(StatementType::Continue),
147 "Break" => Ok(StatementType::Break),
148 "FunctionDef" => Ok(StatementType::FunctionDef(
149 FunctionDef::extract_bound(ob).unwrap_or_else(|_| {
150 panic!("Failed to extract function: {:?}", dump(ob, Some(4)))
151 }),
152 )),
153 "Import" => Ok(StatementType::Import(
154 Import::extract_bound(ob).unwrap_or_else(|_| panic!("Import {:?}", ob)),
155 )),
156 "ImportFrom" => Ok(StatementType::ImportFrom(
157 ImportFrom::extract_bound(ob).unwrap_or_else(|_| panic!("ImportFrom {:?}", ob)),
158 )),
159 "Expr" => {
160 let expr = ob.extract()
161 .expect(format!("Expr {:?}", ob).as_str());
162 Ok(StatementType::Expr(expr))
163 }
164 "Return" => {
165 log::debug!("return expression: {}", dump(ob, None)?);
166 let return_value = if let Ok(value_attr) = ob.getattr("value") {
168 if value_attr.is_none() {
169 Some(Expr {
171 value: crate::tree::ExprType::NoneType(crate::tree::Constant(None)),
172 ctx: None,
173 lineno: ob.lineno(),
174 col_offset: ob.col_offset(),
175 end_lineno: ob.end_lineno(),
176 end_col_offset: ob.end_col_offset(),
177 })
178 } else {
179 let expr_value: crate::tree::ExprType = value_attr.extract()
181 .unwrap_or_else(|_| panic!("return value ExprType {:?}", dump(&value_attr, None).unwrap_or_else(|_| "unknown".to_string())));
182 Some(Expr {
183 value: expr_value,
184 ctx: None,
185 lineno: ob.lineno(),
186 col_offset: ob.col_offset(),
187 end_lineno: ob.end_lineno(),
188 end_col_offset: ob.end_col_offset(),
189 })
190 }
191 } else {
192 None
193 };
194 Ok(StatementType::Return(return_value))
195 }
196 "If" => {
197 let if_stmt = If::extract_bound(ob)
198 .unwrap_or_else(|_| panic!("If statement {:?}", dump(ob, None)));
199 Ok(StatementType::If(if_stmt))
200 }
201 "For" => {
202 let for_stmt = For::extract_bound(ob)
203 .unwrap_or_else(|_| panic!("For statement {:?}", dump(ob, None)));
204 Ok(StatementType::For(for_stmt))
205 }
206 "While" => {
207 let while_stmt = While::extract_bound(ob)
208 .unwrap_or_else(|_| panic!("While statement {:?}", dump(ob, None)));
209 Ok(StatementType::While(while_stmt))
210 }
211 "Try" => {
212 let try_stmt = Try::extract_bound(ob)
213 .unwrap_or_else(|_| panic!("Try statement {:?}", dump(ob, None)));
214 Ok(StatementType::Try(try_stmt))
215 }
216 "AsyncWith" => {
217 let async_with_stmt = AsyncWith::extract_bound(ob)
218 .unwrap_or_else(|_| panic!("AsyncWith statement {:?}", dump(ob, None)));
219 Ok(StatementType::AsyncWith(async_with_stmt))
220 }
221 "AsyncFor" => {
222 let async_for_stmt = AsyncFor::extract_bound(ob)
223 .unwrap_or_else(|_| panic!("AsyncFor statement {:?}", dump(ob, None)));
224 Ok(StatementType::AsyncFor(async_for_stmt))
225 }
226 "Raise" => {
227 let raise_stmt = Raise::extract_bound(ob)
228 .unwrap_or_else(|_| panic!("Raise statement {:?}", dump(ob, None)));
229 Ok(StatementType::Raise(raise_stmt))
230 }
231 "With" => {
232 let with_stmt = With::extract_bound(ob)
233 .unwrap_or_else(|_| panic!("With statement {:?}", dump(ob, None)));
234 Ok(StatementType::With(with_stmt))
235 }
236 _ => Err(pyo3::exceptions::PyValueError::new_err(format!(
237 "Unimplemented statement type {}, {}",
238 ob_type,
239 dump(ob, None)?
240 ))),
241 }
242 }
243}
244
245impl CodeGen for StatementType {
246 type Context = CodeGenContext;
247 type Options = PythonOptions;
248 type SymbolTable = SymbolTableScopes;
249
250 fn find_symbols(self, symbols: Self::SymbolTable) -> Self::SymbolTable {
251 match self {
252 StatementType::Assign(a) => a.find_symbols(symbols),
253 StatementType::AugAssign(a) => a.find_symbols(symbols),
254 StatementType::ClassDef(c) => c.find_symbols(symbols),
255 StatementType::FunctionDef(f) => f.find_symbols(symbols),
256 StatementType::Import(i) => i.find_symbols(symbols),
257 StatementType::ImportFrom(i) => i.find_symbols(symbols),
258 StatementType::Expr(e) => e.find_symbols(symbols),
259 StatementType::If(i) => i.find_symbols(symbols),
260 StatementType::For(f) => f.find_symbols(symbols),
261 StatementType::While(w) => w.find_symbols(symbols),
262 StatementType::Try(t) => t.find_symbols(symbols),
263 StatementType::AsyncWith(aw) => aw.find_symbols(symbols),
264 StatementType::AsyncFor(af) => af.find_symbols(symbols),
265 StatementType::Raise(r) => r.find_symbols(symbols),
266 StatementType::With(w) => w.find_symbols(symbols),
267 _ => symbols,
268 }
269 }
270
271 fn to_rust(
272 self,
273 ctx: Self::Context,
274 options: Self::Options,
275 symbols: Self::SymbolTable,
276 ) -> Result<TokenStream, Box<dyn std::error::Error>> {
277 match self {
278 StatementType::AsyncFunctionDef(s) => {
279 let func_def = s
280 .to_rust(Self::Context::Async(Box::new(ctx)), options, symbols)
281 .expect("Parsing async function");
282 Ok(quote!(#func_def))
283 }
284 StatementType::Assign(a) => a.to_rust(ctx, options, symbols),
285 StatementType::AugAssign(a) => a.to_rust(ctx, options, symbols),
286 StatementType::Break => Ok(quote! {break;}),
287 StatementType::Call(c) => c.to_rust(ctx, options, symbols),
288 StatementType::ClassDef(c) => c.to_rust(ctx, options, symbols),
289 StatementType::Continue => Ok(quote! {continue;}),
290 StatementType::Pass => Ok(quote! {}),
291 StatementType::FunctionDef(s) => s.to_rust(ctx, options, symbols),
292 StatementType::Import(s) => s.to_rust(ctx, options, symbols),
293 StatementType::ImportFrom(s) => s.to_rust(ctx, options, symbols),
294 StatementType::Expr(s) => s.to_rust(ctx, options, symbols),
295 StatementType::Return(None) => Ok(quote!(return)),
296 StatementType::Return(Some(e)) => {
297 let exp = e
298 .clone()
299 .to_rust(ctx, options, symbols)
300 .unwrap_or_else(|_| panic!("parsing expression {:#?}", e));
301 Ok(quote!(return #exp))
302 }
303 StatementType::If(i) => i.to_rust(ctx, options, symbols),
304 StatementType::For(f) => f.to_rust(ctx, options, symbols),
305 StatementType::While(w) => w.to_rust(ctx, options, symbols),
306 StatementType::Try(t) => t.to_rust(ctx, options, symbols),
307 StatementType::AsyncWith(aw) => aw.to_rust(ctx, options, symbols),
308 StatementType::AsyncFor(af) => af.to_rust(ctx, options, symbols),
309 StatementType::Raise(r) => r.to_rust(ctx, options, symbols),
310 StatementType::With(w) => w.to_rust(ctx, options, symbols),
311 _ => {
312 let error = Error::StatementNotYetImplemented(self);
313 Err(Box::new(error))
314 }
315 }
316 }
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    /// Generates Rust tokens for `statement` in an empty module context.
    fn compile(statement: StatementType) -> Result<TokenStream, Box<dyn std::error::Error>> {
        statement.to_rust(
            CodeGenContext::Module("".to_string()),
            PythonOptions::default(),
            SymbolTableScopes::new(),
        )
    }

    #[test]
    fn check_pass_statement() {
        let statement = StatementType::Pass;
        let tokens = compile(statement.clone());

        debug!("statement: {:?}, tokens: {:?}", statement, tokens);
        assert!(tokens.unwrap().is_empty());
    }

    #[test]
    fn check_break_statement() {
        let statement = StatementType::Break;
        let tokens = compile(statement.clone());

        debug!("statement: {:?}, tokens: {:?}", statement, tokens);
        assert!(!tokens.unwrap().is_empty());
    }

    #[test]
    fn check_continue_statement() {
        let statement = StatementType::Continue;
        let tokens = compile(statement.clone());

        debug!("statement: {:?}, tokens: {:?}", statement, tokens);
        assert!(!tokens.unwrap().is_empty());
    }

    #[test]
    fn return_without_value_emits_bare_return() {
        let tokens = compile(StatementType::Return(None)).unwrap();
        assert_eq!(tokens.to_string(), "return");
    }

    #[test]
    fn unimplemented_statement_is_an_error() {
        let result = compile(StatementType::Unimplemented("Match".to_string()));
        assert!(result.is_err());
    }

    #[test]
    fn return_with_nothing() {
        let tree = crate::parse("return", "<none>").unwrap();
        assert_eq!(tree.raw.body.len(), 1);
        assert_eq!(
            tree.raw.body[0].statement,
            StatementType::Return(Some(Expr {
                value: crate::tree::ExprType::NoneType(crate::tree::Constant(None)),
                lineno: Some(1),
                col_offset: Some(0),
                end_lineno: Some(1),
                end_col_offset: Some(6),
                ..Default::default()
            }))
        );
    }

    #[test]
    fn return_with_expr() {
        let lit = litrs::Literal::Integer(litrs::IntegerLit::parse(String::from("8")).unwrap());
        let tree = crate::parse("return 8", "<none>").unwrap();
        assert_eq!(tree.raw.body.len(), 1);
        assert_eq!(
            tree.raw.body[0].statement,
            StatementType::Return(Some(Expr {
                value: crate::tree::ExprType::Constant(crate::tree::Constant(Some(lit))),
                lineno: Some(1),
                col_offset: Some(0),
                end_lineno: Some(1),
                end_col_offset: Some(8),
                ..Default::default()
            }))
        );
    }

    #[test]
    fn does_module_compile() {
        let options = PythonOptions::default();
        let result = crate::parse(
            "#test comment
def foo():
    continue
    pass
",
            "test_case",
        )
        .unwrap();
        log::info!("{:?}", result);
        let code = result.to_rust(
            CodeGenContext::Module("".to_string()),
            options,
            SymbolTableScopes::new(),
        );
        log::info!("module: {:?}", code);
    }
}