python_ast/ast/tree/
compare.rs

1use proc_macro2::TokenStream;
2use pyo3::{Bound, FromPyObject, PyAny, PyResult, prelude::PyAnyMethods, types::PyTypeMethods};
3use quote::quote;
4use serde::{Deserialize, Serialize};
5
6use crate::{
7    dump, CodeGen, CodeGenContext, Error, ExprType, Node, PythonOptions, SymbolTableScopes,
8};
9
10#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
11pub enum Compares {
12    Eq,
13    NotEq,
14    Lt,
15    LtE,
16    Gt,
17    GtE,
18    Is,
19    IsNot,
20    In,
21    NotIn,
22
23    Unknown,
24}
25
26impl<'a> FromPyObject<'a> for Compares {
27    fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
28        let err_msg = format!("Unimplemented unary op {}", dump(ob, None)?);
29        Err(pyo3::exceptions::PyValueError::new_err(
30            ob.error_message("<unknown>", err_msg),
31        ))
32    }
33}
34
35#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
36pub struct Compare {
37    ops: Vec<Compares>,
38    left: Box<ExprType>,
39    comparators: Vec<ExprType>,
40}
41
42impl<'a> FromPyObject<'a> for Compare {
43    fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
44        log::debug!("ob: {}", dump(ob, None)?);
45
46        // Python allows for multiple comparators, rust we only supports one, so we have to rewrite the comparison a little.
47        let ops_bound: Vec<Bound<PyAny>> = ob
48            .getattr("ops")
49            .expect(
50                ob.error_message("<unknown>", "error getting unary operator")
51                    .as_str(),
52            )
53            .extract()
54            .expect("getting ops from Compare");
55
56        let mut op_list = Vec::new();
57
58        for op in ops_bound.iter() {
59            let op_type = op.get_type().name().expect(
60                ob.error_message(
61                    "<unknown>",
62                    "error extracting type name for binary operator",
63                )
64                .as_str(),
65            );
66
67            let op_type_str: String = op_type.extract()?;
68            let op = match op_type_str.as_str() {
69                "Eq" => Compares::Eq,
70                "NotEq" => Compares::NotEq,
71                "Lt" => Compares::Lt,
72                "LtE" => Compares::LtE,
73                "Gt" => Compares::Gt,
74                "GtE" => Compares::GtE,
75                "Is" => Compares::Is,
76                "IsNot" => Compares::IsNot,
77                "In" => Compares::In,
78                "NotIn" => Compares::NotIn,
79
80                _ => {
81                    log::debug!("Found unknown Compare with type: {}", op_type_str);
82                    Compares::Unknown
83                }
84            };
85            op_list.push(op);
86        }
87
88        let left = ob.getattr("left").expect(
89            ob.error_message("<unknown>", "error getting comparator")
90                .as_str(),
91        );
92
93        let comparators = ob.getattr("comparators").expect(
94            ob.error_message("<unknown>", "error getting compoarator")
95                .as_str(),
96        );
97        log::debug!(
98            "left: {}, comparators: {}",
99            dump(&left, None)?,
100            dump(&comparators, None)?
101        );
102
103        let left = left.extract().expect("getting binary operator operand");
104        let comparators: Vec<ExprType> = comparators
105            .extract()
106            .expect("getting comparators from Compare");
107
108        log::debug!(
109            "left: {:?}, comparators: {:?}, op: {:?}",
110            left,
111            comparators,
112            op_list
113        );
114
115        return Ok(Compare {
116            ops: op_list,
117            left: Box::new(left),
118            comparators: comparators,
119        });
120    }
121}
122
123impl CodeGen for Compare {
124    type Context = CodeGenContext;
125    type Options = PythonOptions;
126    type SymbolTable = SymbolTableScopes;
127
128    fn to_rust(
129        self,
130        ctx: Self::Context,
131        options: Self::Options,
132        symbols: Self::SymbolTable,
133    ) -> Result<TokenStream, Box<dyn std::error::Error>> {
134        let mut outer_ts = TokenStream::new();
135        let left = self
136            .left
137            .clone()
138            .to_rust(ctx.clone(), options.clone(), symbols.clone())?;
139        let ops = self.ops.clone();
140        let comparators = self.comparators.clone();
141
142        let mut index = 0;
143        for op in ops.iter() {
144            let comparator = comparators
145                .get(index)
146                .expect("getting comparator")
147                .clone()
148                .to_rust(ctx.clone(), options.clone(), symbols.clone())?;
149            let tokens = match op {
150                Compares::Eq => quote!((#left) == (#comparator)),
151                Compares::NotEq => quote!((#left) != (#comparator)),
152                Compares::Lt => quote!((#left) < (#comparator)),
153                Compares::LtE => quote!((#left) <= (#comparator)),
154                Compares::Gt => quote!((#left) > (#comparator)),
155                Compares::GtE => quote!((#left) >= (#comparator)),
156                Compares::Is => quote!(&#left == &#comparator),
157                Compares::IsNot => quote!(&#left != &#comparator),
158                Compares::In => quote!((#comparator).get(#left) == Some(_)),
159                Compares::NotIn => quote!((#comparator).get(#left) == None),
160
161                _ => return Err(Error::CompareNotYetImplemented(self).into()),
162            };
163
164            index += 1;
165
166            outer_ts.extend(tokens);
167            if index < ops.len() {
168                outer_ts.extend(quote!( && ));
169            }
170        }
171        Ok(outer_ts)
172    }
173}
174
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` as `test_case.py`, generates a Rust module from it with
    /// default options and an empty symbol table, and logs both the Python
    /// tree and the code-generation result.
    fn transpile_and_log(source: &str) {
        let tree = crate::parse(source, "test_case.py").unwrap();
        log::info!("Python tree: {:?}", tree);

        let generated = tree.to_rust(
            CodeGenContext::Module("test_case".to_string()),
            PythonOptions::default(),
            SymbolTableScopes::new(),
        );
        log::info!("module: {:?}", generated);
    }

    #[test]
    fn test_simple_eq() {
        transpile_and_log("1 == 2");
    }

    #[test]
    fn test_complex_compare() {
        transpile_and_log("1 < a > 6");
    }
}
208}