1use proc_macro2::TokenStream;
2use pyo3::{FromPyObject, PyAny, PyResult};
3use quote::quote;
4
5use crate::{
6 dump, Assign, Call, ClassDef, CodeGen, CodeGenContext, Error, Expr, FunctionDef, Import,
7 ImportFrom, Node, PythonOptions, SymbolTableScopes,
8};
9
10use log::debug;
11
12use serde::{Deserialize, Serialize};
13
/// A single Python statement together with the source span it came from.
///
/// The four position fields are copied from the Python node through the `Node`
/// accessors when the statement is extracted; `statement` holds the
/// kind-specific payload.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Statement {
    /// Line on which the statement starts, as reported by the source node.
    pub lineno: Option<usize>,
    /// Column offset at which the statement starts.
    pub col_offset: Option<usize>,
    /// Line on which the statement ends.
    pub end_lineno: Option<usize>,
    /// Column offset at which the statement ends.
    pub end_col_offset: Option<usize>,
    /// The statement kind and its payload.
    pub statement: StatementType,
}
22
23impl<'a> FromPyObject<'a> for Statement {
24 fn extract(ob: &'a PyAny) -> PyResult<Self> {
25 Ok(Self {
26 lineno: ob.lineno(),
27 col_offset: ob.col_offset(),
28 end_lineno: ob.end_lineno(),
29 end_col_offset: ob.end_col_offset(),
30 statement: StatementType::extract(ob)?,
31 })
32 }
33}
34
35impl Node for Statement {
36 fn lineno(&self) -> Option<usize> {
37 self.lineno
38 }
39 fn col_offset(&self) -> Option<usize> {
40 self.col_offset
41 }
42 fn end_lineno(&self) -> Option<usize> {
43 self.end_lineno
44 }
45 fn end_col_offset(&self) -> Option<usize> {
46 self.end_col_offset
47 }
48}
49
50impl CodeGen for Statement {
51 type Context = CodeGenContext;
52 type Options = PythonOptions;
53 type SymbolTable = SymbolTableScopes;
54
55 fn find_symbols(self, symbols: Self::SymbolTable) -> Self::SymbolTable {
56 self.statement.clone().find_symbols(symbols)
57 }
58
59 fn to_rust(
60 self,
61 ctx: Self::Context,
62 options: Self::Options,
63 symbols: Self::SymbolTable,
64 ) -> Result<TokenStream, Box<dyn std::error::Error>> {
65 Ok(self
66 .statement
67 .clone()
68 .to_rust(ctx, options, symbols)
69 .expect(
70 self.error_message(
71 "<unknown>",
72 format!("failed to compile statement {:#?}", self),
73 )
74 .as_str(),
75 ))
76 }
77}
78
/// The kind of a Python statement, together with its kind-specific payload.
///
/// Values are built from Python `ast` nodes by matching on the node's type
/// name (the `FromPyObject` impl) and turned into Rust tokens by the
/// `CodeGen` impl.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum StatementType {
    /// `async def`; shares the `FunctionDef` payload and is generated under an
    /// `Async` code-generation context.
    AsyncFunctionDef(FunctionDef),
    /// Assignment statement.
    Assign(Assign),
    /// `break`; generates `break;`.
    Break,
    /// `continue`; generates `continue;`.
    Continue,
    /// Class definition.
    ClassDef(ClassDef),
    /// A call read from the node's `value` attribute when the node's type
    /// name is `Call`.
    Call(Call),
    /// `pass`; generates no tokens.
    Pass,
    /// `return` with an optional value. Extraction currently always produces
    /// `Some(..)`; `None` generates a bare `return`.
    Return(Option<Expr>),
    /// `import ...` statement.
    Import(Import),
    /// `from ... import ...` statement.
    ImportFrom(ImportFrom),
    /// Expression statement.
    Expr(Expr),
    /// Function definition (`def`).
    FunctionDef(FunctionDef),

    /// Placeholder for unsupported statements. Not produced by the `extract`
    /// impl in this file (unknown node types are reported as errors there);
    /// code generation returns `Error::StatementNotYetImplemented` for it.
    Unimplemented(String),
}
96
97impl<'a> FromPyObject<'a> for StatementType {
98 fn extract(ob: &'a PyAny) -> PyResult<Self> {
99 let err_msg = format!("getting type for statement {:?}", ob);
100 let ob_type = ob
101 .get_type()
102 .name()
103 .unwrap_or_else(|_| panic!("{}", ob.error_message("<unknown>", err_msg)));
104
105 debug!("statement...ob_type: {}...{}", ob_type, dump(ob, Some(4))?);
106 match ob_type.as_ref() {
107 "AsyncFunctionDef" => Ok(StatementType::AsyncFunctionDef(
108 FunctionDef::extract(ob).unwrap_or_else(|_| {
109 panic!("Failed to extract async function: {:?}", dump(ob, Some(4)))
110 }),
111 )),
112 "Assign" => {
113 let assignment = Assign::extract(ob).expect("reading assignment");
114 Ok(StatementType::Assign(assignment))
115 }
116 "Pass" => Ok(StatementType::Pass),
117 "Call" => {
118 let call =
119 Call::extract(ob.getattr("value").unwrap_or_else(|_| {
120 panic!("getting value from {:?} in call statement", ob)
121 }))
122 .unwrap_or_else(|_| panic!("extracting call statement {:?}", ob));
123 debug!("call: {:?}", call);
124 Ok(StatementType::Call(call))
125 }
126 "ClassDef" => Ok(StatementType::ClassDef(
127 ClassDef::extract(ob).unwrap_or_else(|_| panic!("Class definition {:?}", ob)),
128 )),
129 "Continue" => Ok(StatementType::Continue),
130 "Break" => Ok(StatementType::Break),
131 "FunctionDef" => Ok(StatementType::FunctionDef(
132 FunctionDef::extract(ob).unwrap_or_else(|_| {
133 panic!("Failed to extract function: {:?}", dump(ob, Some(4)))
134 }),
135 )),
136 "Import" => Ok(StatementType::Import(
137 Import::extract(ob).unwrap_or_else(|_| panic!("Import {:?}", ob)),
138 )),
139 "ImportFrom" => Ok(StatementType::ImportFrom(
140 ImportFrom::extract(ob).unwrap_or_else(|_| panic!("ImportFrom {:?}", ob)),
141 )),
142 "Expr" => {
143 let expr = Expr::extract(
144 ob.extract()
145 .unwrap_or_else(|_| panic!("extracting Expr {:?}", ob)),
146 )
147 .expect(format!("Expr {:?}", ob).as_str());
148 Ok(StatementType::Expr(expr))
149 }
150 "Return" => {
151 log::debug!("return expression: {}", dump(ob, None)?);
152 let expr = Expr::extract(
153 ob.extract()
154 .unwrap_or_else(|_| panic!("extracting return Expr {:?}", ob)),
155 )
156 .unwrap_or_else(|_| panic!("return Expr {:?}", dump(ob, None)));
157 Ok(StatementType::Return(Some(expr)))
158 }
159 _ => Err(pyo3::exceptions::PyValueError::new_err(format!(
160 "Unimplemented statement type {}, {}",
161 ob_type,
162 dump(ob, None)?
163 ))),
164 }
165 }
166}
167
168impl CodeGen for StatementType {
169 type Context = CodeGenContext;
170 type Options = PythonOptions;
171 type SymbolTable = SymbolTableScopes;
172
173 fn find_symbols(self, symbols: Self::SymbolTable) -> Self::SymbolTable {
174 match self {
175 StatementType::Assign(a) => a.find_symbols(symbols),
176 StatementType::ClassDef(c) => c.find_symbols(symbols),
177 StatementType::FunctionDef(f) => f.find_symbols(symbols),
178 StatementType::Import(i) => i.find_symbols(symbols),
179 StatementType::ImportFrom(i) => i.find_symbols(symbols),
180 StatementType::Expr(e) => e.find_symbols(symbols),
181 _ => symbols,
182 }
183 }
184
185 fn to_rust(
186 self,
187 ctx: Self::Context,
188 options: Self::Options,
189 symbols: Self::SymbolTable,
190 ) -> Result<TokenStream, Box<dyn std::error::Error>> {
191 match self {
192 StatementType::AsyncFunctionDef(s) => {
193 let func_def = s
194 .to_rust(Self::Context::Async(Box::new(ctx)), options, symbols)
195 .expect("Parsing async function");
196 Ok(quote!(#func_def))
197 }
198 StatementType::Assign(a) => a.to_rust(ctx, options, symbols),
199 StatementType::Break => Ok(quote! {break;}),
200 StatementType::Call(c) => c.to_rust(ctx, options, symbols),
201 StatementType::ClassDef(c) => c.to_rust(ctx, options, symbols),
202 StatementType::Continue => Ok(quote! {continue;}),
203 StatementType::Pass => Ok(quote! {}),
204 StatementType::FunctionDef(s) => s.to_rust(ctx, options, symbols),
205 StatementType::Import(s) => s.to_rust(ctx, options, symbols),
206 StatementType::ImportFrom(s) => s.to_rust(ctx, options, symbols),
207 StatementType::Expr(s) => s.to_rust(ctx, options, symbols),
208 StatementType::Return(None) => Ok(quote!(return)),
209 StatementType::Return(Some(e)) => {
210 let exp = e
211 .clone()
212 .to_rust(ctx, options, symbols)
213 .unwrap_or_else(|_| panic!("parsing expression {:#?}", e));
214 Ok(quote!(return #exp))
215 }
216 _ => {
217 let error = Error::StatementNotYetImplemented(self);
218 Err(Box::new(error))
219 }
220 }
221 }
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    /// Generates tokens for `statement` in an empty module context, logging
    /// the statement and the raw result before returning it.
    fn tokens_for(statement: &StatementType) -> Result<TokenStream, Box<dyn std::error::Error>> {
        let tokens = statement.clone().to_rust(
            CodeGenContext::Module("".to_string()),
            PythonOptions::default(),
            SymbolTableScopes::new(),
        );
        debug!("statement: {:?}, tokens: {:?}", statement, tokens);
        tokens
    }

    #[test]
    fn check_pass_statement() {
        let tokens = tokens_for(&StatementType::Pass).unwrap();
        assert!(tokens.is_empty());
    }

    #[test]
    fn check_break_statement() {
        let tokens = tokens_for(&StatementType::Break).unwrap();
        assert!(!tokens.is_empty());
    }

    #[test]
    fn check_continue_statement() {
        let tokens = tokens_for(&StatementType::Continue).unwrap();
        assert!(!tokens.is_empty());
    }

    /// A bare `return` is currently extracted as `Return(Some(NoneType))`,
    /// because `StatementType::extract` always wraps the extracted expression
    /// in `Some`.
    #[test]
    fn return_with_nothing() {
        let tree = crate::parse("return", "<none>").unwrap();
        assert_eq!(tree.raw.body.len(), 1);
        assert_eq!(
            tree.raw.body[0].statement,
            StatementType::Return(Some(Expr {
                value: crate::tree::ExprType::NoneType(crate::tree::Constant(None)),
                lineno: Some(1),
                col_offset: Some(0),
                end_lineno: Some(1),
                end_col_offset: Some(6),
                ..Default::default()
            }))
        );
    }

    #[test]
    fn return_with_expr() {
        let lit = litrs::Literal::Integer(litrs::IntegerLit::parse(String::from("8")).unwrap());
        let tree = crate::parse("return 8", "<none>").unwrap();
        assert_eq!(tree.raw.body.len(), 1);
        assert_eq!(
            tree.raw.body[0].statement,
            StatementType::Return(Some(Expr {
                value: crate::tree::ExprType::Constant(crate::tree::Constant(Some(lit))),
                lineno: Some(1),
                col_offset: Some(0),
                end_lineno: Some(1),
                end_col_offset: Some(8),
                ..Default::default()
            }))
        );
    }

    /// Smoke test: parses and generates a small module. The generation result
    /// is only logged, not asserted.
    #[test]
    fn does_module_compile() {
        let options = PythonOptions::default();
        let result = crate::parse(
            "#test comment
def foo():
    continue
    pass
",
            "test_case",
        )
        .unwrap();
        log::info!("{:?}", result);
        let code = result.to_rust(
            CodeGenContext::Module("".to_string()),
            options,
            SymbolTableScopes::new(),
        );
        log::info!("module: {:?}", code);
    }
}