1use proc_macro2::TokenStream;
2use pyo3::{Bound, FromPyObject, PyAny, PyResult, prelude::PyAnyMethods};
3use quote::quote;
4use serde::{Deserialize, Serialize};
5
6use crate::{
7 dump, CodeGen, CodeGenContext, Error, ExprType, Node, PythonOptions, SymbolTableScopes,
8 PythonOperator, BinaryOperation, FromPythonString, PyAttributeExtractor,
9};
10
11#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
12pub enum BinOps {
13 Add,
14 Sub,
15 Mult,
16 Div,
17 FloorDiv,
18 Mod,
19 Pow,
20 LShift,
21 RShift,
22 BitOr,
23 BitXor,
24 BitAnd,
25 MatMult,
26
27 Unknown,
28}
29
30impl FromPythonString for BinOps {
31 fn from_python_string(s: &str) -> Option<Self> {
32 match s {
33 "Add" => Some(BinOps::Add),
34 "Sub" => Some(BinOps::Sub),
35 "Mult" => Some(BinOps::Mult),
36 "Div" => Some(BinOps::Div),
37 "FloorDiv" => Some(BinOps::FloorDiv),
38 "Mod" => Some(BinOps::Mod),
39 "Pow" => Some(BinOps::Pow),
40 "LShift" => Some(BinOps::LShift),
41 "RShift" => Some(BinOps::RShift),
42 "BitOr" => Some(BinOps::BitOr),
43 "BitXor" => Some(BinOps::BitXor),
44 "BitAnd" => Some(BinOps::BitAnd),
45 "MatMult" => Some(BinOps::MatMult),
46 _ => None,
47 }
48 }
49
50 fn unknown() -> Self {
51 BinOps::Unknown
52 }
53}
54
55impl PythonOperator for BinOps {
56 fn to_rust_op(&self) -> Result<TokenStream, Box<dyn std::error::Error>> {
57 match self {
58 BinOps::Add => Ok(quote!(+)),
59 BinOps::Sub => Ok(quote!(-)),
60 BinOps::Mult => Ok(quote!(*)),
61 BinOps::Div => Ok(quote!(as f64 /)),
62 BinOps::FloorDiv => Ok(quote!(/)),
63 BinOps::Mod => Ok(quote!(%)),
64 BinOps::Pow => Ok(quote!(.pow)),
65 BinOps::LShift => Ok(quote!(<<)),
66 BinOps::RShift => Ok(quote!(>>)),
67 BinOps::BitOr => Ok(quote!(|)),
68 BinOps::BitXor => Ok(quote!(^)),
69 BinOps::BitAnd => Ok(quote!(&)),
70 _ => Err(Error::BinOpNotYetImplemented(BinOp {
71 op: self.clone(),
72 left: Box::new(ExprType::Name(crate::Name { id: "unknown".to_string() })),
73 right: Box::new(ExprType::Name(crate::Name { id: "unknown".to_string() })),
74 }).into()),
75 }
76 }
77
78 fn precedence(&self) -> u8 {
79 match self {
80 BinOps::Pow => 8,
81 BinOps::Mult | BinOps::Div | BinOps::FloorDiv | BinOps::Mod => 7,
82 BinOps::Add | BinOps::Sub => 6,
83 BinOps::LShift | BinOps::RShift => 5,
84 BinOps::BitAnd => 4,
85 BinOps::BitXor => 3,
86 BinOps::BitOr => 2,
87 _ => 1,
88 }
89 }
90
91 fn is_unknown(&self) -> bool {
92 matches!(self, BinOps::Unknown)
93 }
94}
95
96impl<'a> FromPyObject<'a> for BinOps {
97 fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
98 let err_msg = format!("Unimplemented binary op {}", dump(ob, None)?);
99 Err(pyo3::exceptions::PyValueError::new_err(
100 ob.error_message("<unknown>", err_msg),
101 ))
102 }
103}
104
105#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
106pub struct BinOp {
107 op: BinOps,
108 left: Box<ExprType>,
109 right: Box<ExprType>,
110}
111
112impl BinaryOperation for BinOp {
113 type OperatorType = BinOps;
114
115 fn operator(&self) -> &Self::OperatorType {
116 &self.op
117 }
118
119 fn left(&self) -> &ExprType {
120 &self.left
121 }
122
123 fn right(&self) -> &ExprType {
124 &self.right
125 }
126}
127
128impl<'a> FromPyObject<'a> for BinOp {
129 fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
130 log::debug!("ob: {}", dump(ob, None)?);
131
132 let op = ob.extract_attr_with_context("op", "binary operator")?;
133 let op_type_str = op.extract_type_name("binary operator")?;
134
135 let left = ob.extract_attr_with_context("left", "binary operand")?;
136 let right = ob.extract_attr_with_context("right", "binary operand")?;
137
138 log::debug!("left: {}, right: {}", dump(&left, None)?, dump(&right, None)?);
139
140 let op = BinOps::parse_or_unknown(&op_type_str);
141 if matches!(op, BinOps::Unknown) {
142 log::debug!("Found unknown BinOp {:?}", op_type_str);
143 }
144
145 let left = left.extract().expect("getting binary operator operand");
146 let right = right.extract().expect("getting binary operator operand");
147
148 Ok(BinOp {
149 op,
150 left: Box::new(left),
151 right: Box::new(right),
152 })
153 }
154}
155
156impl CodeGen for BinOp {
157 type Context = CodeGenContext;
158 type Options = PythonOptions;
159 type SymbolTable = SymbolTableScopes;
160
161 fn to_rust(
162 self,
163 ctx: Self::Context,
164 options: Self::Options,
165 symbols: Self::SymbolTable,
166 ) -> std::result::Result<TokenStream, Box<dyn std::error::Error>> {
167 if matches!(self.op, BinOps::Pow) {
169 let left = self.left.clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
170 let right = self.right.clone().to_rust(ctx, options, symbols)?;
171 return Ok(quote!((#left).pow(#right)));
172 }
173
174 if matches!(self.op, BinOps::Div) {
176 let left = self.left.clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
177 let right = self.right.clone().to_rust(ctx, options, symbols)?;
178 return Ok(quote!((#left) as f64 / (#right) as f64));
179 }
180
181 if matches!(self.op, BinOps::Add) {
183 let left = self.left.clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
184 let right = self.right.clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
185 let left_str = left.to_string();
186 let right_str = right.to_string();
187
188 if left_str.contains("vec !") || right_str.contains("iter ()") || right_str.contains("sys :: argv") {
190 return Ok(quote! {
192 {
193 let mut vec = #left;
194 vec.extend(#right);
195 vec
196 }
197 });
198 }
199 }
200
201 self.generate_rust_code(ctx, options, symbols)
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use crate::create_parse_test;

    create_parse_test!(test_add, "1 + 2", "test_case.py");
    create_parse_test!(test_subtract, "1 - 2", "test_case.py");
    create_parse_test!(test_multiply, "3 * 4", "test_case.py");
    create_parse_test!(test_divide, "8 / 2", "test_case.py");
    create_parse_test!(test_power, "2 ** 3", "test_case.py");
    create_parse_test!(test_modulo, "10 % 3", "test_case.py");

    #[test]
    fn test_operator_precedence() {
        let add_op = BinOps::Add;
        let mul_op = BinOps::Mult;
        let pow_op = BinOps::Pow;

        assert!(pow_op.precedence() > mul_op.precedence());
        assert!(mul_op.precedence() > add_op.precedence());
    }

    /// Checks the full Python precedence ladder, loosest to tightest:
    /// `|` < `^` < `&` < shifts < additive < multiplicative < `**`.
    #[test]
    fn test_precedence_follows_python_ordering() {
        let ladder = [
            BinOps::BitOr,
            BinOps::BitXor,
            BinOps::BitAnd,
            BinOps::LShift,
            BinOps::Add,
            BinOps::Mult,
            BinOps::Pow,
        ];
        for pair in ladder.windows(2) {
            assert!(
                pair[0].precedence() < pair[1].precedence(),
                "{:?} should bind looser than {:?}",
                pair[0],
                pair[1]
            );
        }

        // Operators sharing a level in Python share a level here too.
        for op in [BinOps::Div, BinOps::FloorDiv, BinOps::Mod] {
            assert_eq!(op.precedence(), BinOps::Mult.precedence(), "{:?}", op);
        }
        assert_eq!(BinOps::Sub.precedence(), BinOps::Add.precedence());
        assert_eq!(BinOps::RShift.precedence(), BinOps::LShift.precedence());
    }

    #[test]
    fn test_unknown_operator() {
        let unknown_op = BinOps::Unknown;
        assert!(unknown_op.is_unknown());
        assert!(unknown_op.to_rust_op().is_err());
    }

    #[test]
    fn test_from_python_string() {
        assert_eq!(BinOps::from_python_string("Add"), Some(BinOps::Add));
        assert_eq!(BinOps::from_python_string("Unknown"), None);
        assert_eq!(BinOps::parse_or_unknown("Invalid"), BinOps::Unknown);
    }

    /// Every supported Python operator name maps to its own variant.
    #[test]
    fn test_from_python_string_accepts_every_operator_name() {
        let cases = [
            ("Add", BinOps::Add),
            ("Sub", BinOps::Sub),
            ("Mult", BinOps::Mult),
            ("Div", BinOps::Div),
            ("FloorDiv", BinOps::FloorDiv),
            ("Mod", BinOps::Mod),
            ("Pow", BinOps::Pow),
            ("LShift", BinOps::LShift),
            ("RShift", BinOps::RShift),
            ("BitOr", BinOps::BitOr),
            ("BitXor", BinOps::BitXor),
            ("BitAnd", BinOps::BitAnd),
            ("MatMult", BinOps::MatMult),
        ];
        for (name, expected) in cases {
            assert_eq!(BinOps::from_python_string(name), Some(expected), "{}", name);
        }
    }

    /// Plain infix operators translate to their Rust token; `@` is not
    /// supported yet.
    #[test]
    fn test_to_rust_op_tokens() {
        assert_eq!(BinOps::Add.to_rust_op().unwrap().to_string(), "+");
        assert_eq!(BinOps::Mult.to_rust_op().unwrap().to_string(), "*");
        assert!(!BinOps::MatMult.is_unknown());
        assert!(BinOps::MatMult.to_rust_op().is_err());
    }
}
241}