1use proc_macro2::TokenStream;
2use pyo3::{FromPyObject, PyAny, PyResult};
3use quote::quote;
4use serde::{Deserialize, Serialize};
5
6use crate::{
7 dump, CodeGen, CodeGenContext, Error, ExprType, Node, PythonOptions, SymbolTableScopes,
8};
9
10#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
11pub enum Compares {
12 Eq,
13 NotEq,
14 Lt,
15 LtE,
16 Gt,
17 GtE,
18 Is,
19 IsNot,
20 In,
21 NotIn,
22
23 Unknown,
24}
25
26impl<'a> FromPyObject<'a> for Compares {
27 fn extract(ob: &'a PyAny) -> PyResult<Self> {
28 let err_msg = format!("Unimplemented unary op {}", dump(ob, None)?);
29 Err(pyo3::exceptions::PyValueError::new_err(
30 ob.error_message("<unknown>", err_msg),
31 ))
32 }
33}
34
35#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
36pub struct Compare {
37 ops: Vec<Compares>,
38 left: Box<ExprType>,
39 comparators: Vec<ExprType>,
40}
41
42impl<'a> FromPyObject<'a> for Compare {
43 fn extract(ob: &'a PyAny) -> PyResult<Self> {
44 log::debug!("ob: {}", dump(ob, None)?);
45
46 let ops: Vec<&PyAny> = ob
48 .getattr("ops")
49 .expect(
50 ob.error_message("<unknown>", "error getting unary operator")
51 .as_str(),
52 )
53 .extract()
54 .expect("getting ops from Compare");
55
56 let mut op_list = Vec::new();
57
58 for op in ops.iter() {
59 let op_type = op.get_type().name().expect(
60 ob.error_message(
61 "<unknown>",
62 format!("extracting type name {:?} for binary operator", op),
63 )
64 .as_str(),
65 );
66
67 let op = match op_type.as_ref() {
68 "Eq" => Compares::Eq,
69 "NotEq" => Compares::NotEq,
70 "Lt" => Compares::Lt,
71 "LtE" => Compares::LtE,
72 "Gt" => Compares::Gt,
73 "GtE" => Compares::GtE,
74 "Is" => Compares::Is,
75 "IsNot" => Compares::IsNot,
76 "In" => Compares::In,
77 "NotIn" => Compares::NotIn,
78
79 _ => {
80 log::debug!("Found unknown Compare {:?}", op);
81 Compares::Unknown
82 }
83 };
84 op_list.push(op);
85 }
86
87 let left = ob.getattr("left").expect(
88 ob.error_message("<unknown>", "error getting comparator")
89 .as_str(),
90 );
91
92 let comparators = ob.getattr("comparators").expect(
93 ob.error_message("<unknown>", "error getting compoarator")
94 .as_str(),
95 );
96 log::debug!(
97 "left: {}, comparators: {}",
98 dump(left, None)?,
99 dump(comparators, None)?
100 );
101
102 let left = ExprType::extract(left).expect("getting binary operator operand");
103 let comparators: Vec<ExprType> = comparators
104 .extract()
105 .expect("getting comparators from Compare");
106
107 log::debug!(
108 "left: {:?}, comparators: {:?}, op: {:?}",
109 left,
110 comparators,
111 op_list
112 );
113
114 return Ok(Compare {
115 ops: op_list,
116 left: Box::new(left),
117 comparators: comparators,
118 });
119 }
120}
121
122impl CodeGen for Compare {
123 type Context = CodeGenContext;
124 type Options = PythonOptions;
125 type SymbolTable = SymbolTableScopes;
126
127 fn to_rust(
128 self,
129 ctx: Self::Context,
130 options: Self::Options,
131 symbols: Self::SymbolTable,
132 ) -> Result<TokenStream, Box<dyn std::error::Error>> {
133 let mut outer_ts = TokenStream::new();
134 let left = self
135 .left
136 .clone()
137 .to_rust(ctx.clone(), options.clone(), symbols.clone())?;
138 let ops = self.ops.clone();
139 let comparators = self.comparators.clone();
140
141 let mut index = 0;
142 for op in ops.iter() {
143 let comparator = comparators
144 .get(index)
145 .expect("getting comparator")
146 .clone()
147 .to_rust(ctx.clone(), options.clone(), symbols.clone())?;
148 let tokens = match op {
149 Compares::Eq => quote!(((#left) == (#comparator))),
150 Compares::NotEq => quote!(((#left) != (#comparator))),
151 Compares::Lt => quote!(((#left) < (#comparator))),
152 Compares::LtE => quote!(((#left) <= (#comparator))),
153 Compares::Gt => quote!(((#left) > (#comparator))),
154 Compares::GtE => quote!(((#left) >= (#comparator))),
155 Compares::Is => quote!((&(#left) == &(#comparator))),
156 Compares::IsNot => quote!((&(#left) != &(#comparator))),
157 Compares::In => quote!(((#comparator).get(#left) == Some(_))),
158 Compares::NotIn => quote!(((#comparator).get(#left) == None)),
159
160 _ => return Err(Error::CompareNotYetImplemented(self).into()),
161 };
162
163 index += 1;
164
165 outer_ts.extend(tokens);
166 if index < ops.len() {
167 outer_ts.extend(quote!( && ));
168 }
169 }
170 Ok(outer_ts)
171 }
172}
173
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` as `test_case.py`, translates it as module `test_case`, and logs both
    /// the parsed tree and the translation result. Panics only if parsing fails.
    fn parse_and_translate(source: &str) {
        let options = PythonOptions::default();
        let tree = crate::parse(source, "test_case.py").unwrap();
        log::info!("Python tree: {:?}", tree);
        let generated = tree.to_rust(
            CodeGenContext::Module("test_case".to_string()),
            options,
            SymbolTableScopes::new(),
        );
        log::info!("module: {:?}", generated);
    }

    /// A single `==` comparison between two literals.
    #[test]
    fn test_simple_eq() {
        parse_and_translate("1 == 2");
    }

    /// A chained comparison with two operators.
    #[test]
    fn test_complex_compare() {
        parse_and_translate("1 < a > 6");
    }
}