python_ast/ast/tree/
compare.rs

1use proc_macro2::TokenStream;
2use pyo3::{FromPyObject, PyAny, PyResult};
3use quote::quote;
4use serde::{Deserialize, Serialize};
5
6use crate::{
7    dump, CodeGen, CodeGenContext, Error, ExprType, Node, PythonOptions, SymbolTableScopes,
8};
9
/// A single Python comparison operator.
///
/// Variant names mirror the type names of the operator nodes found in the
/// `ops` list of a Python `Compare` node.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum Compares {
    /// `==`
    Eq,
    /// `!=`
    NotEq,
    /// `<`
    Lt,
    /// `<=`
    LtE,
    /// `>`
    Gt,
    /// `>=`
    GtE,
    /// `is`
    Is,
    /// `is not`
    IsNot,
    /// `in`
    In,
    /// `not in`
    NotIn,

    /// Placeholder for an operator node whose type name was not recognised
    /// during extraction; code generation rejects it.
    Unknown,
}
25
26impl<'a> FromPyObject<'a> for Compares {
27    fn extract(ob: &'a PyAny) -> PyResult<Self> {
28        let err_msg = format!("Unimplemented unary op {}", dump(ob, None)?);
29        Err(pyo3::exceptions::PyValueError::new_err(
30            ob.error_message("<unknown>", err_msg),
31        ))
32    }
33}
34
/// A Python comparison expression, possibly chained (e.g. `1 < a > 6`).
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Compare {
    /// The comparison operators, in source order.
    ops: Vec<Compares>,
    /// The left-most operand of the comparison.
    left: Box<ExprType>,
    /// The operands following `left`; code generation pairs `comparators[i]`
    /// with `ops[i]`.
    comparators: Vec<ExprType>,
}
41
42impl<'a> FromPyObject<'a> for Compare {
43    fn extract(ob: &'a PyAny) -> PyResult<Self> {
44        log::debug!("ob: {}", dump(ob, None)?);
45
46        // Python allows for multiple comparators, rust we only supports one, so we have to rewrite the comparison a little.
47        let ops: Vec<&PyAny> = ob
48            .getattr("ops")
49            .expect(
50                ob.error_message("<unknown>", "error getting unary operator")
51                    .as_str(),
52            )
53            .extract()
54            .expect("getting ops from Compare");
55
56        let mut op_list = Vec::new();
57
58        for op in ops.iter() {
59            let op_type = op.get_type().name().expect(
60                ob.error_message(
61                    "<unknown>",
62                    format!("extracting type name {:?} for binary operator", op),
63                )
64                .as_str(),
65            );
66
67            let op = match op_type.as_ref() {
68                "Eq" => Compares::Eq,
69                "NotEq" => Compares::NotEq,
70                "Lt" => Compares::Lt,
71                "LtE" => Compares::LtE,
72                "Gt" => Compares::Gt,
73                "GtE" => Compares::GtE,
74                "Is" => Compares::Is,
75                "IsNot" => Compares::IsNot,
76                "In" => Compares::In,
77                "NotIn" => Compares::NotIn,
78
79                _ => {
80                    log::debug!("Found unknown Compare {:?}", op);
81                    Compares::Unknown
82                }
83            };
84            op_list.push(op);
85        }
86
87        let left = ob.getattr("left").expect(
88            ob.error_message("<unknown>", "error getting comparator")
89                .as_str(),
90        );
91
92        let comparators = ob.getattr("comparators").expect(
93            ob.error_message("<unknown>", "error getting compoarator")
94                .as_str(),
95        );
96        log::debug!(
97            "left: {}, comparators: {}",
98            dump(left, None)?,
99            dump(comparators, None)?
100        );
101
102        let left = ExprType::extract(left).expect("getting binary operator operand");
103        let comparators: Vec<ExprType> = comparators
104            .extract()
105            .expect("getting comparators from Compare");
106
107        log::debug!(
108            "left: {:?}, comparators: {:?}, op: {:?}",
109            left,
110            comparators,
111            op_list
112        );
113
114        return Ok(Compare {
115            ops: op_list,
116            left: Box::new(left),
117            comparators: comparators,
118        });
119    }
120}
121
122impl CodeGen for Compare {
123    type Context = CodeGenContext;
124    type Options = PythonOptions;
125    type SymbolTable = SymbolTableScopes;
126
127    fn to_rust(
128        self,
129        ctx: Self::Context,
130        options: Self::Options,
131        symbols: Self::SymbolTable,
132    ) -> Result<TokenStream, Box<dyn std::error::Error>> {
133        let mut outer_ts = TokenStream::new();
134        let left = self
135            .left
136            .clone()
137            .to_rust(ctx.clone(), options.clone(), symbols.clone())?;
138        let ops = self.ops.clone();
139        let comparators = self.comparators.clone();
140
141        let mut index = 0;
142        for op in ops.iter() {
143            let comparator = comparators
144                .get(index)
145                .expect("getting comparator")
146                .clone()
147                .to_rust(ctx.clone(), options.clone(), symbols.clone())?;
148            let tokens = match op {
149                Compares::Eq => quote!(((#left) == (#comparator))),
150                Compares::NotEq => quote!(((#left) != (#comparator))),
151                Compares::Lt => quote!(((#left) < (#comparator))),
152                Compares::LtE => quote!(((#left) <= (#comparator))),
153                Compares::Gt => quote!(((#left) > (#comparator))),
154                Compares::GtE => quote!(((#left) >= (#comparator))),
155                Compares::Is => quote!((&(#left) == &(#comparator))),
156                Compares::IsNot => quote!((&(#left) != &(#comparator))),
157                Compares::In => quote!(((#comparator).get(#left) == Some(_))),
158                Compares::NotIn => quote!(((#comparator).get(#left) == None)),
159
160                _ => return Err(Error::CompareNotYetImplemented(self).into()),
161            };
162
163            index += 1;
164
165            outer_ts.extend(tokens);
166            if index < ops.len() {
167                outer_ts.extend(quote!( && ));
168            }
169        }
170        Ok(outer_ts)
171    }
172}
173
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source`, runs code generation on the resulting tree as the
    /// module `test_case`, and logs both the tree and the generation result.
    /// Panics only if parsing fails.
    fn parse_and_generate(source: &str) {
        let options = PythonOptions::default();
        let tree = crate::parse(source, "test_case.py").unwrap();
        log::info!("Python tree: {:?}", tree);

        let module_ctx = CodeGenContext::Module("test_case".to_string());
        let generated = tree.to_rust(module_ctx, options, SymbolTableScopes::new());
        log::info!("module: {:?}", generated);
    }

    /// A single `==` comparison between two literals.
    #[test]
    fn test_simple_eq() {
        parse_and_generate("1 == 2");
    }

    /// A chained comparison using two different operators.
    #[test]
    fn test_complex_compare() {
        parse_and_generate("1 < a > 6");
    }
}