python_ast/ast/tree/expression.rs

1use proc_macro2::TokenStream;
2use pyo3::{FromPyObject, PyAny, PyResult};
3use quote::quote;
4use serde::{Deserialize, Serialize};
5
6use crate::{
7    dump, Attribute, Await, BinOp, BoolOp, Call, CodeGen, CodeGenContext, Compare, Constant, Error,
8    Name, NamedExpr, Node, PythonOptions, SymbolTableScopes, UnaryOp,
9};
10
11/// Mostly this shouldn't be used, but it exists so that we don't have to manually implement FromPyObject on all of ExprType
12#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
13#[repr(transparent)]
14pub struct Container<T>(pub T);
15
16impl<'a> FromPyObject<'a> for Container<crate::pytypes::List<ExprType>> {
17    fn extract(ob: &'a PyAny) -> PyResult<Self> {
18        let list = crate::pytypes::List::<ExprType>::new();
19
20        log::debug!("pylist: {}", dump(ob, Some(4))?);
21        let _converted_list: Vec<&PyAny> = ob.extract()?;
22        for item in ob.iter().expect("extracting list") {
23            log::debug!("item: {:?}", item);
24        }
25
26        Ok(Self(list))
27    }
28}
29
30#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
31pub enum ExprType {
32    BoolOp(BoolOp),
33    NamedExpr(NamedExpr),
34    BinOp(BinOp),
35    UnaryOp(UnaryOp),
36    /*Lambda(Lamda),
37    IfExp(IfExp),
38    Dict(Dict),
39    Set(Set),
40    ListComp(ListComp),
41    SetComp(SetComp),
42    DictComp(DictComp),
43    GeneratorExp(),*/
44    Await(Await),
45    /*Yield(),
46    YieldFrom(),*/
47    Compare(Compare),
48    Call(Call),
49    /*FormattedValue(),
50    JoinedStr(),*/
51    Constant(Constant),
52
53    /// These can appear in a few places, such as the left side of an assignment.
54    Attribute(Attribute), /*
55                          Subscript(),
56                          Starred(),*/
57    Name(Name),
58    List(Vec<ExprType>),
59    /*Tuple(),
60    Slice(),*/
61    NoneType(Constant),
62
63    Unimplemented(String),
64    #[default]
65    Unknown,
66}
67
68impl<'a> FromPyObject<'a> for ExprType {
69    fn extract(ob: &'a PyAny) -> PyResult<Self> {
70        log::debug!("exprtype ob: {}", dump(ob, Some(4))?);
71
72        let expr_type = ob.get_type().name().expect(
73            ob.error_message(
74                "<unknown>",
75                format!("extracting type name {:?} in expression", dump(ob, None)),
76            )
77            .as_str(),
78        );
79        log::debug!("expression type: {}, value: {}", expr_type, dump(ob, None)?);
80
81        let r = match expr_type.as_ref() {
82            "Attribute" => {
83                let a = Attribute::extract(ob).expect(
84                    ob.error_message(
85                        "<unknown>",
86                        format!("extracting Attribute in expression {}", dump(ob, None)?),
87                    )
88                    .as_str(),
89                );
90                Ok(Self::Attribute(a))
91            }
92            "Await" => {
93                //println!("await: {}", dump(ob, None)?);
94                let a = Await::extract(ob).expect(
95                    ob.error_message(
96                        "<unknown>",
97                        format!("extracting await value in expression {}", dump(ob, None)?),
98                    )
99                    .as_str(),
100                );
101                Ok(Self::Await(a))
102            }
103            "Call" => {
104                let et = Call::extract(ob).expect(
105                    ob.error_message(
106                        "<unknown>",
107                        format!("parsing Call expression {}", dump(ob, None)?),
108                    )
109                    .as_str(),
110                );
111                Ok(Self::Call(et))
112            }
113            "Compare" => {
114                let c = Compare::extract(ob).expect(
115                    ob.error_message(
116                        "<unknown>",
117                        format!("extracting Compare in expression {}", dump(ob, None)?),
118                    )
119                    .as_str(),
120                );
121                Ok(Self::Compare(c))
122            }
123            "Constant" => {
124                log::debug!("constant: {}", dump(ob, None)?);
125                let c = Constant::extract(ob).expect(
126                    ob.error_message(
127                        "<unknown>",
128                        format!("extracting Constant in expression {}", dump(ob, None)?),
129                    )
130                    .as_str(),
131                );
132                Ok(Self::Constant(c))
133            }
134            "List" => {
135                //let list = crate::pytypes::List::<ExprType>::new();
136                let list: Vec<ExprType> = ob
137                    .extract()
138                    .expect(format!("extracting List {}", dump(ob, None)?).as_str());
139                Ok(Self::List(list))
140            }
141            "Name" => {
142                let name = Name::extract(ob).expect(
143                    ob.error_message(
144                        "<unknown>",
145                        format!("parsing Name expression {}", dump(ob, None)?),
146                    )
147                    .as_str(),
148                );
149                Ok(Self::Name(name))
150            }
151            "UnaryOp" => {
152                let c = UnaryOp::extract(ob).expect(
153                    ob.error_message(
154                        "<unknown>",
155                        format!("extracting UnaryOp in expression {}", dump(ob, None)?),
156                    )
157                    .as_str(),
158                );
159                Ok(Self::UnaryOp(c))
160            }
161            "BinOp" => {
162                let c = BinOp::extract(ob).expect(
163                    ob.error_message(
164                        "<unknown>",
165                        format!("extracting BinOp in expression {}", dump(ob, None)?),
166                    )
167                    .as_str(),
168                );
169                Ok(Self::BinOp(c))
170            }
171            _ => {
172                let err_msg = format!(
173                    "Unimplemented expression type {}, {}",
174                    expr_type,
175                    dump(ob, None)?
176                );
177                Err(pyo3::exceptions::PyValueError::new_err(
178                    ob.error_message("<unknown>", err_msg.as_str()),
179                ))
180            }
181        };
182        r
183    }
184}
185
186impl<'a> CodeGen for ExprType {
187    type Context = CodeGenContext;
188    type Options = PythonOptions;
189    type SymbolTable = SymbolTableScopes;
190
191    fn to_rust(
192        self,
193        ctx: Self::Context,
194        options: Self::Options,
195        symbols: Self::SymbolTable,
196    ) -> std::result::Result<TokenStream, Box<dyn std::error::Error>> {
197        match self {
198            ExprType::Attribute(attribute) => attribute.to_rust(ctx, options, symbols),
199            ExprType::Await(func) => func.to_rust(ctx, options, symbols),
200            ExprType::BinOp(binop) => binop.to_rust(ctx, options, symbols),
201            ExprType::Call(call) => call.to_rust(ctx, options, symbols),
202            ExprType::Compare(c) => c.to_rust(ctx, options, symbols),
203            ExprType::Constant(c) => c.to_rust(ctx, options, symbols),
204            ExprType::List(l) => {
205                let mut ts = TokenStream::new();
206                for li in l {
207                    let code = li
208                        .clone()
209                        .to_rust(ctx.clone(), options.clone(), symbols.clone())
210                        .expect(format!("Extracting list item {:?}", li).as_str());
211                    ts.extend(code);
212                    ts.extend(quote!(,));
213                }
214                Ok(ts)
215            }
216            ExprType::Name(name) => name.to_rust(ctx, options, symbols),
217            ExprType::NoneType(c) => c.to_rust(ctx, options, symbols),
218            ExprType::UnaryOp(operand) => operand.to_rust(ctx, options, symbols),
219
220            _ => {
221                let error = Error::ExprTypeNotYetImplemented(self);
222                Err(error.into())
223            }
224        }
225    }
226}
227
228/// An Expr only contains a single value key, which leads to the actual expression,
229/// which is one of several types.
230#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
231pub struct Expr {
232    pub value: ExprType,
233    pub ctx: Option<String>,
234    pub lineno: Option<usize>,
235    pub col_offset: Option<usize>,
236    pub end_lineno: Option<usize>,
237    pub end_col_offset: Option<usize>,
238}
239
240impl<'a> FromPyObject<'a> for Expr {
241    fn extract(ob: &'a PyAny) -> PyResult<Self> {
242        let err_msg = format!("extracting object value {} in expression", dump(ob, None)?);
243
244        let ob_value = ob
245            .getattr("value")
246            .expect(ob.error_message("<unknown>", err_msg.as_str()).as_str());
247        log::debug!("ob_value: {}", dump(ob_value, None)?);
248
249        // The context is Load, Store, etc. For some types of expressions such as Constants, it does not exist.
250        let ctx: Option<String> = if let Ok(pyany) = ob_value.getattr("ctx") {
251            pyany.get_type().extract().unwrap_or_default()
252        } else {
253            None
254        };
255
256        let mut r = Self {
257            value: ExprType::Unknown,
258            ctx: ctx,
259            lineno: ob.lineno(),
260            col_offset: ob.col_offset(),
261            end_lineno: ob.end_lineno(),
262            end_col_offset: ob.end_col_offset(),
263        };
264
265        let expr_type = ob_value.get_type().name().expect(
266            ob.error_message(
267                "<unknown>",
268                format!("extracting type name {:?} in expression", ob_value),
269            )
270            .as_str(),
271        );
272        log::debug!(
273            "expression type: {}, value: {}",
274            expr_type,
275            dump(ob_value, None)?
276        );
277        match expr_type.as_ref() {
278            "Atribute" => {
279                let a = Attribute::extract(ob_value).expect(
280                    ob.error_message(
281                        "<unknown>",
282                        format!("extracting BinOp in expression {:?}", dump(ob_value, None)?),
283                    )
284                    .as_str(),
285                );
286                r.value = ExprType::Attribute(a);
287                Ok(r)
288            }
289            "Await" => {
290                let a = Await::extract(ob_value).expect(
291                    ob.error_message(
292                        "<unknown>",
293                        format!("extracting BinOp in expression {:?}", dump(ob_value, None)?),
294                    )
295                    .as_str(),
296                );
297                r.value = ExprType::Await(a);
298                Ok(r)
299            }
300            "BinOp" => {
301                let c = BinOp::extract(ob_value).expect(
302                    ob.error_message(
303                        "<unknown>",
304                        format!("extracting BinOp in expression {:?}", dump(ob_value, None)?),
305                    )
306                    .as_str(),
307                );
308                r.value = ExprType::BinOp(c);
309                Ok(r)
310            }
311            "BoolOp" => {
312                let c = BoolOp::extract(ob_value).expect(
313                    ob.error_message(
314                        "<unknown>",
315                        format!("extracting BinOp in expression {:?}", dump(ob_value, None)?),
316                    )
317                    .as_str(),
318                );
319                r.value = ExprType::BoolOp(c);
320                Ok(r)
321            }
322            "Call" => {
323                let et = Call::extract(ob_value).expect(
324                    ob.error_message(
325                        "<unknown>",
326                        format!("parsing Call expression {:?}", ob_value),
327                    )
328                    .as_str(),
329                );
330                r.value = ExprType::Call(et);
331                Ok(r)
332            }
333            "Constant" => {
334                let c = Constant::extract(ob_value).expect(
335                    ob.error_message(
336                        "<unknown>",
337                        format!(
338                            "extracting Constant in expression {:?}",
339                            dump(ob_value, None)?
340                        ),
341                    )
342                    .as_str(),
343                );
344                r.value = ExprType::Constant(c);
345                Ok(r)
346            }
347            "Compare" => {
348                let c = Compare::extract(ob_value).expect(
349                    ob.error_message(
350                        "<unknown>",
351                        format!(
352                            "extracting Compare in expression {:?}",
353                            dump(ob_value, None)?
354                        ),
355                    )
356                    .as_str(),
357                );
358                r.value = ExprType::Compare(c);
359                Ok(r)
360            }
361            "List" => {
362                //let list = crate::pytypes::List::<ExprType>::new();
363                let list: Vec<ExprType> = ob.extract().expect("extracting List");
364                r.value = ExprType::List(list);
365                Ok(r)
366            }
367            "Name" => {
368                let name = Name::extract(ob_value).expect(
369                    ob.error_message(
370                        "<unknown>",
371                        format!("parsing Call expression {:?}", ob_value),
372                    )
373                    .as_str(),
374                );
375                r.value = ExprType::Name(name);
376                Ok(r)
377            }
378            "UnaryOp" => {
379                let c = UnaryOp::extract(ob_value).expect(
380                    ob.error_message(
381                        "<unknown>",
382                        format!(
383                            "extracting UnaryOp in expression {:?}",
384                            dump(ob_value, None)?
385                        ),
386                    )
387                    .as_str(),
388                );
389                r.value = ExprType::UnaryOp(c);
390                Ok(r)
391            }
392            // In sitations where an expression is optional, we may see a NoneType expressions.
393            "NoneType" => {
394                r.value = ExprType::NoneType(Constant(None));
395                Ok(r)
396            }
397            _ => {
398                let err_msg = format!(
399                    "Unimplemented expression type {}, {}",
400                    expr_type,
401                    dump(ob, None)?
402                );
403                Err(pyo3::exceptions::PyValueError::new_err(
404                    ob.error_message("<unknown>", err_msg.as_str()),
405                ))
406            }
407        }
408    }
409}
410
411impl CodeGen for Expr {
412    type Context = CodeGenContext;
413    type Options = PythonOptions;
414    type SymbolTable = SymbolTableScopes;
415
416    fn to_rust(
417        self,
418        ctx: Self::Context,
419        options: Self::Options,
420        symbols: Self::SymbolTable,
421    ) -> std::result::Result<TokenStream, Box<dyn std::error::Error>> {
422        let module_name = match ctx.clone() {
423            CodeGenContext::Module(name) => name,
424            _ => "unknown".to_string(),
425        };
426
427        match self.value.clone() {
428            ExprType::Await(a) => a.to_rust(ctx.clone(), options, symbols),
429            ExprType::BinOp(binop) => binop.to_rust(ctx.clone(), options, symbols),
430            ExprType::BoolOp(boolop) => boolop.to_rust(ctx.clone(), options, symbols),
431            ExprType::Call(call) => call.to_rust(ctx.clone(), options, symbols),
432            ExprType::Constant(constant) => constant.to_rust(ctx, options, symbols),
433            ExprType::Compare(compare) => compare.to_rust(ctx, options, symbols),
434            ExprType::UnaryOp(operand) => operand.to_rust(ctx, options, symbols),
435            ExprType::Name(name) => name.to_rust(ctx, options, symbols),
436            // NoneType expressions generate no code.
437            ExprType::NoneType(_c) => Ok(quote!()),
438            _ => {
439                let error = Error::ExprTypeNotYetImplemented(self.value);
440                Err(error.into())
441            }
442        }
443    }
444}
445
446impl Node for Expr {
447    fn lineno(&self) -> Option<usize> {
448        self.lineno
449    }
450
451    fn col_offset(&self) -> Option<usize> {
452        self.col_offset
453    }
454
455    fn end_lineno(&self) -> Option<usize> {
456        self.end_lineno
457    }
458
459    fn end_col_offset(&self) -> Option<usize> {
460        self.end_col_offset
461    }
462}
463
#[cfg(test)]
mod tests {
    use super::*;

    /// A bare call statement should translate to the equivalent Rust call.
    #[test]
    fn check_call_expression() {
        let module = crate::parse("test()", "test.py").unwrap();
        println!("Python tree: {:#?}", module);

        let mut options = PythonOptions::default();
        options.with_std_python = false;

        let generated = module
            .to_rust(
                CodeGenContext::Module("test".to_string()),
                options,
                SymbolTableScopes::new(),
            )
            .unwrap();
        println!("Rust tokens: {}", generated);

        let expected = quote!(test());
        assert_eq!(generated.to_string(), expected.to_string());
    }
}