1use proc_macro2::TokenStream;
2use pyo3::{Bound, FromPyObject, PyAny, PyResult, prelude::PyAnyMethods};
3use quote::quote;
4use serde::{Deserialize, Serialize};
5
6use crate::{
7 dump, CodeGen, CodeGenContext, Error, ExprType, Node, PythonOptions, SymbolTableScopes,
8 PythonOperator, BinaryOperation, FromPythonString, PyAttributeExtractor,
9};
10
/// Python binary operators — the `op` of an `ast.BinOp` node — as seen by the converter.
///
/// Variant names mirror the Python `ast` operator class names (`"Add"`, `"MatMult"`, ...);
/// `FromPythonString::from_python_string` performs that name-to-variant mapping.
/// `Unknown` is the fallback used when an operator name is not recognised.
///
/// NOTE: variant order is significant for order/index-based serde formats; append, don't reorder.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum BinOps {
    /// `+`
    Add,
    /// `-`
    Sub,
    /// `*`
    Mult,
    /// `/` (true division)
    Div,
    /// `//` (floor division)
    FloorDiv,
    /// `%`
    Mod,
    /// `**`
    Pow,
    /// `<<`
    LShift,
    /// `>>`
    RShift,
    /// `|`
    BitOr,
    /// `^`
    BitXor,
    /// `&`
    BitAnd,
    /// `@` (matrix multiplication); `to_rust_op` has no lowering for it yet.
    MatMult,

    /// Placeholder for an operator name that could not be mapped to a variant above.
    Unknown,
}
29
30impl FromPythonString for BinOps {
31 fn from_python_string(s: &str) -> Option<Self> {
32 match s {
33 "Add" => Some(BinOps::Add),
34 "Sub" => Some(BinOps::Sub),
35 "Mult" => Some(BinOps::Mult),
36 "Div" => Some(BinOps::Div),
37 "FloorDiv" => Some(BinOps::FloorDiv),
38 "Mod" => Some(BinOps::Mod),
39 "Pow" => Some(BinOps::Pow),
40 "LShift" => Some(BinOps::LShift),
41 "RShift" => Some(BinOps::RShift),
42 "BitOr" => Some(BinOps::BitOr),
43 "BitXor" => Some(BinOps::BitXor),
44 "BitAnd" => Some(BinOps::BitAnd),
45 "MatMult" => Some(BinOps::MatMult),
46 _ => None,
47 }
48 }
49
50 fn unknown() -> Self {
51 BinOps::Unknown
52 }
53}
54
55impl PythonOperator for BinOps {
56 fn to_rust_op(&self) -> Result<TokenStream, Box<dyn std::error::Error>> {
57 match self {
58 BinOps::Add => Ok(quote!(+)),
59 BinOps::Sub => Ok(quote!(-)),
60 BinOps::Mult => Ok(quote!(*)),
61 BinOps::Div => Ok(quote!(as f64 /)),
62 BinOps::FloorDiv => Ok(quote!(/)),
63 BinOps::Mod => Ok(quote!(%)),
64 BinOps::Pow => Ok(quote!(.pow)),
65 BinOps::LShift => Ok(quote!(<<)),
66 BinOps::RShift => Ok(quote!(>>)),
67 BinOps::BitOr => Ok(quote!(|)),
68 BinOps::BitXor => Ok(quote!(^)),
69 BinOps::BitAnd => Ok(quote!(&)),
70 _ => Err(Error::BinOpNotYetImplemented(BinOp {
71 op: self.clone(),
72 left: Box::new(ExprType::Name(crate::Name { id: "unknown".to_string() })),
73 right: Box::new(ExprType::Name(crate::Name { id: "unknown".to_string() })),
74 }).into()),
75 }
76 }
77
78 fn precedence(&self) -> u8 {
79 match self {
80 BinOps::Pow => 8,
81 BinOps::Mult | BinOps::Div | BinOps::FloorDiv | BinOps::Mod => 7,
82 BinOps::Add | BinOps::Sub => 6,
83 BinOps::LShift | BinOps::RShift => 5,
84 BinOps::BitAnd => 4,
85 BinOps::BitXor => 3,
86 BinOps::BitOr => 2,
87 _ => 1,
88 }
89 }
90
91 fn is_unknown(&self) -> bool {
92 matches!(self, BinOps::Unknown)
93 }
94}
95
96impl<'a> FromPyObject<'a> for BinOps {
97 fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
98 let err_msg = format!("Unimplemented binary op {}", dump(ob, None)?);
99 Err(pyo3::exceptions::PyValueError::new_err(
100 ob.error_message("<unknown>", err_msg),
101 ))
102 }
103}
104
/// A Python binary expression `left <op> right` (e.g. `1 + 2`), taken from an `ast.BinOp` node.
///
/// NOTE: field order is significant for order-based serde formats; don't reorder.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct BinOp {
    /// The operator; `BinOps::Unknown` when the Python operator name was not recognised.
    op: BinOps,
    /// Left-hand operand.
    left: Box<ExprType>,
    /// Right-hand operand.
    right: Box<ExprType>,
}
111
112impl BinaryOperation for BinOp {
113 type OperatorType = BinOps;
114
115 fn operator(&self) -> &Self::OperatorType {
116 &self.op
117 }
118
119 fn left(&self) -> &ExprType {
120 &self.left
121 }
122
123 fn right(&self) -> &ExprType {
124 &self.right
125 }
126}
127
128impl<'a> FromPyObject<'a> for BinOp {
129 fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
130 log::debug!("ob: {}", dump(ob, None)?);
131
132 let op = ob.extract_attr_with_context("op", "binary operator")?;
133 let op_type_str = op.extract_type_name("binary operator")?;
134
135 let left = ob.extract_attr_with_context("left", "binary operand")?;
136 let right = ob.extract_attr_with_context("right", "binary operand")?;
137
138 log::debug!("left: {}, right: {}", dump(&left, None)?, dump(&right, None)?);
139
140 let op = BinOps::parse_or_unknown(&op_type_str);
141 if matches!(op, BinOps::Unknown) {
142 log::debug!("Found unknown BinOp {:?}", op_type_str);
143 }
144
145 let left = left.extract().expect("getting binary operator operand");
146 let right = right.extract().expect("getting binary operator operand");
147
148 Ok(BinOp {
149 op,
150 left: Box::new(left),
151 right: Box::new(right),
152 })
153 }
154}
155
156impl CodeGen for BinOp {
157 type Context = CodeGenContext;
158 type Options = PythonOptions;
159 type SymbolTable = SymbolTableScopes;
160
161 fn to_rust(
162 self,
163 ctx: Self::Context,
164 options: Self::Options,
165 symbols: Self::SymbolTable,
166 ) -> std::result::Result<TokenStream, Box<dyn std::error::Error>> {
167 if matches!(self.op, BinOps::Pow) {
169 let left = self.left.clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
170 let right = self.right.clone().to_rust(ctx, options, symbols)?;
171 return Ok(quote!((#left).pow(#right)));
172 }
173
174 if matches!(self.op, BinOps::Div) {
176 let left = self.left.clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
177 let right = self.right.clone().to_rust(ctx, options, symbols)?;
178 return Ok(quote!((#left) as f64 / (#right) as f64));
179 }
180
181 self.generate_rust_code(ctx, options, symbols)
183 }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use crate::create_parse_test;

    // End-to-end parse checks for the common arithmetic operators.
    create_parse_test!(test_add, "1 + 2", "test_case.py");
    create_parse_test!(test_subtract, "1 - 2", "test_case.py");
    create_parse_test!(test_multiply, "3 * 4", "test_case.py");
    create_parse_test!(test_divide, "8 / 2", "test_case.py");
    create_parse_test!(test_power, "2 ** 3", "test_case.py");
    create_parse_test!(test_modulo, "10 % 3", "test_case.py");

    /// `**` binds tighter than `*`, which binds tighter than `+`.
    #[test]
    fn test_operator_precedence() {
        let [add, mul, pow] = [BinOps::Add, BinOps::Mult, BinOps::Pow].map(|op| op.precedence());

        assert!(pow > mul);
        assert!(mul > add);
    }

    /// The fallback variant reports itself as unknown and has no Rust lowering.
    #[test]
    fn test_unknown_operator() {
        let op = BinOps::Unknown;
        assert!(op.is_unknown());
        assert!(op.to_rust_op().is_err());
    }

    /// Only real operator class names map to variants; anything else becomes `Unknown`.
    #[test]
    fn test_from_python_string() {
        assert_eq!(BinOps::from_python_string("Add"), Some(BinOps::Add));
        assert!(BinOps::from_python_string("Unknown").is_none());
        assert_eq!(BinOps::parse_or_unknown("Invalid"), BinOps::Unknown);
    }
}