1use proc_macro2::TokenStream;
2use pyo3::{Bound, FromPyObject, PyAny, PyResult, prelude::PyAnyMethods, types::PyTypeMethods};
3use quote::quote;
4use serde::{Deserialize, Serialize};
5
6use crate::{
7 dump, CodeGen, CodeGenContext, Error, ExprType, Node, PythonOptions, SymbolTableScopes,
8};
9
10#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
11pub enum Compares {
12 Eq,
13 NotEq,
14 Lt,
15 LtE,
16 Gt,
17 GtE,
18 Is,
19 IsNot,
20 In,
21 NotIn,
22
23 Unknown,
24}
25
26impl<'a> FromPyObject<'a> for Compares {
27 fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
28 let err_msg = format!("Unimplemented unary op {}", dump(ob, None)?);
29 Err(pyo3::exceptions::PyValueError::new_err(
30 ob.error_message("<unknown>", err_msg),
31 ))
32 }
33}
34
/// A comparison expression such as `1 == 2` or the chained `1 < a > 6`, extracted from the
/// `ops`, `left` and `comparators` attributes of a Python `Compare` AST node.
///
/// `ops` and `comparators` are parallel lists: `ops[i]` is paired with `comparators[i]`.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Compare {
    /// The comparison operators, in source order.
    ops: Vec<Compares>,
    /// The leftmost operand of the comparison.
    left: Box<ExprType>,
    /// The operands appearing to the right of each operator, in source order.
    comparators: Vec<ExprType>,
}
41
42impl<'a> FromPyObject<'a> for Compare {
43 fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
44 log::debug!("ob: {}", dump(ob, None)?);
45
46 let ops_bound: Vec<Bound<PyAny>> = ob
48 .getattr("ops")
49 .expect(
50 ob.error_message("<unknown>", "error getting unary operator")
51 .as_str(),
52 )
53 .extract()
54 .expect("getting ops from Compare");
55
56 let mut op_list = Vec::new();
57
58 for op in ops_bound.iter() {
59 let op_type = op.get_type().name().expect(
60 ob.error_message(
61 "<unknown>",
62 "error extracting type name for binary operator",
63 )
64 .as_str(),
65 );
66
67 let op_type_str: String = op_type.extract()?;
68 let op = match op_type_str.as_str() {
69 "Eq" => Compares::Eq,
70 "NotEq" => Compares::NotEq,
71 "Lt" => Compares::Lt,
72 "LtE" => Compares::LtE,
73 "Gt" => Compares::Gt,
74 "GtE" => Compares::GtE,
75 "Is" => Compares::Is,
76 "IsNot" => Compares::IsNot,
77 "In" => Compares::In,
78 "NotIn" => Compares::NotIn,
79
80 _ => {
81 log::debug!("Found unknown Compare with type: {}", op_type_str);
82 Compares::Unknown
83 }
84 };
85 op_list.push(op);
86 }
87
88 let left = ob.getattr("left").expect(
89 ob.error_message("<unknown>", "error getting comparator")
90 .as_str(),
91 );
92
93 let comparators = ob.getattr("comparators").expect(
94 ob.error_message("<unknown>", "error getting compoarator")
95 .as_str(),
96 );
97 log::debug!(
98 "left: {}, comparators: {}",
99 dump(&left, None)?,
100 dump(&comparators, None)?
101 );
102
103 let left = left.extract().expect("getting binary operator operand");
104 let comparators: Vec<ExprType> = comparators
105 .extract()
106 .expect("getting comparators from Compare");
107
108 log::debug!(
109 "left: {:?}, comparators: {:?}, op: {:?}",
110 left,
111 comparators,
112 op_list
113 );
114
115 return Ok(Compare {
116 ops: op_list,
117 left: Box::new(left),
118 comparators: comparators,
119 });
120 }
121}
122
123impl CodeGen for Compare {
124 type Context = CodeGenContext;
125 type Options = PythonOptions;
126 type SymbolTable = SymbolTableScopes;
127
128 fn to_rust(
129 self,
130 ctx: Self::Context,
131 options: Self::Options,
132 symbols: Self::SymbolTable,
133 ) -> Result<TokenStream, Box<dyn std::error::Error>> {
134 let mut outer_ts = TokenStream::new();
135 let left = self
136 .left
137 .clone()
138 .to_rust(ctx.clone(), options.clone(), symbols.clone())?;
139 let ops = self.ops.clone();
140 let comparators = self.comparators.clone();
141
142 let mut index = 0;
143 for op in ops.iter() {
144 let comparator = comparators
145 .get(index)
146 .expect("getting comparator")
147 .clone()
148 .to_rust(ctx.clone(), options.clone(), symbols.clone())?;
149 let tokens = match op {
150 Compares::Eq => quote!((#left) == (#comparator)),
151 Compares::NotEq => quote!((#left) != (#comparator)),
152 Compares::Lt => quote!((#left) < (#comparator)),
153 Compares::LtE => quote!((#left) <= (#comparator)),
154 Compares::Gt => quote!((#left) > (#comparator)),
155 Compares::GtE => quote!((#left) >= (#comparator)),
156 Compares::Is => quote!(&#left == &#comparator),
157 Compares::IsNot => quote!(&#left != &#comparator),
158 Compares::In => quote!((#comparator).get(#left) == Some(_)),
159 Compares::NotIn => quote!((#comparator).get(#left) == None),
160
161 _ => return Err(Error::CompareNotYetImplemented(self).into()),
162 };
163
164 index += 1;
165
166 outer_ts.extend(tokens);
167 if index < ops.len() {
168 outer_ts.extend(quote!( && ));
169 }
170 }
171 Ok(outer_ts)
172 }
173}
174
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` as a Python module named `test_case` and runs code generation on it.
    /// Both the parsed tree and the codegen result are logged. Panics only if parsing fails.
    fn transpile(source: &str) {
        let tree = crate::parse(source, "test_case.py").unwrap();
        log::info!("Python tree: {:?}", tree);

        let generated = tree.to_rust(
            CodeGenContext::Module("test_case".to_string()),
            PythonOptions::default(),
            SymbolTableScopes::new(),
        );
        log::info!("module: {:?}", generated);
    }

    #[test]
    fn test_simple_eq() {
        transpile("1 == 2");
    }

    #[test]
    fn test_complex_compare() {
        transpile("1 < a > 6");
    }
}
208}