1use proc_macro2::TokenStream;
2use pyo3::{FromPyObject, PyAny, PyResult};
3use quote::quote;
4use serde::{Deserialize, Serialize};
5
6use crate::{
7 dump, CodeGen, CodeGenContext, Error, ExprType, Node, PythonOptions, SymbolTableScopes,
8};
9
10#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
11pub enum BinOps {
12 Add,
13 Sub,
14 Mult,
15 Div,
16 FloorDiv,
17 Mod,
18 Pow,
19 LShift,
20 RShift,
21 BitOr,
22 BitXor,
23 BitAnd,
24 MatMult,
25
26 Unknown,
27}
28
29impl<'a> FromPyObject<'a> for BinOps {
30 fn extract(ob: &'a PyAny) -> PyResult<Self> {
31 let err_msg = format!("Unimplemented unary op {}", dump(ob, None)?);
32 Err(pyo3::exceptions::PyValueError::new_err(
33 ob.error_message("<unknown>", err_msg),
34 ))
35 }
36}
37
38#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
39pub struct BinOp {
40 op: BinOps,
41 left: Box<ExprType>,
42 right: Box<ExprType>,
43}
44
45impl<'a> FromPyObject<'a> for BinOp {
46 fn extract(ob: &'a PyAny) -> PyResult<Self> {
47 log::debug!("ob: {}", dump(ob, None)?);
48 let op = ob.getattr("op").expect(
49 ob.error_message("<unknown>", "error getting unary operator")
50 .as_str(),
51 );
52
53 let op_type = op.get_type().name().expect(
54 ob.error_message(
55 "<unknown>",
56 format!("extracting type name {:?} for binary operator", op),
57 )
58 .as_str(),
59 );
60
61 let left = ob.getattr("left").expect(
62 ob.error_message("<unknown>", "error getting binary operand")
63 .as_str(),
64 );
65
66 let right = ob.getattr("right").expect(
67 ob.error_message("<unknown>", "error getting binary operand")
68 .as_str(),
69 );
70 log::debug!("left: {}, right: {}", dump(left, None)?, dump(right, None)?);
71
72 let op = match op_type.as_ref() {
73 "Add" => BinOps::Add,
74 "Sub" => BinOps::Sub,
75 "Mult" => BinOps::Mult,
76 "Div" => BinOps::Div,
77 "FloorDiv" => BinOps::FloorDiv,
78 "Mod" => BinOps::Mod,
79 "Pow" => BinOps::Pow,
80 "LShift" => BinOps::LShift,
81 "RShift" => BinOps::RShift,
82 "BitOr" => BinOps::BitOr,
83 "BitXor" => BinOps::BitXor,
84 "BitAnd" => BinOps::BitAnd,
85 "MatMult" => BinOps::MatMult,
86
87 _ => {
88 log::debug!("Found unknown BinOp {:?}", op);
89 BinOps::Unknown
90 }
91 };
92
93 log::debug!(
94 "left: {}, right: {}, op: {:?}/{:?}",
95 dump(left, None)?,
96 dump(right, None)?,
97 op_type,
98 op
99 );
100
101 let right = ExprType::extract(right).expect("getting binary operator operand");
102 let left = ExprType::extract(left).expect("getting binary operator operand");
103
104 return Ok(BinOp {
105 op: op,
106 left: Box::new(left),
107 right: Box::new(right),
108 });
109 }
110}
111
112impl<'a> CodeGen for BinOp {
113 type Context = CodeGenContext;
114 type Options = PythonOptions;
115 type SymbolTable = SymbolTableScopes;
116
117 fn to_rust(
118 self,
119 ctx: Self::Context,
120 options: Self::Options,
121 symbols: Self::SymbolTable,
122 ) -> std::result::Result<TokenStream, Box<dyn std::error::Error>> {
123 let left = self
124 .left
125 .clone()
126 .to_rust(ctx.clone(), options.clone(), symbols.clone())?;
127 let right = self
128 .right
129 .clone()
130 .to_rust(ctx.clone(), options.clone(), symbols.clone())?;
131 match self.op {
132 BinOps::Add => Ok(quote!((#left) + (#right))),
133 BinOps::Sub => Ok(quote!((#left) - (#right))),
134 BinOps::Mult => Ok(quote!((#left) * (#right))),
135 BinOps::Div => Ok(quote!((#left) as f64 / (#right) as f64)),
136 BinOps::FloorDiv => Ok(quote!((#left) / (#right))),
137 BinOps::Mod => Ok(quote!((#left) % (#right))),
138 BinOps::Pow => Ok(quote!((#left).pow(#right))),
139 BinOps::LShift => Ok(quote!((#left) << (#right))),
140 BinOps::RShift => Ok(quote!((#left) >> (#right))),
141 BinOps::BitOr => Ok(quote!((#left) | (#right))),
142 BinOps::BitXor => Ok(quote!((#left) ^ (#right))),
143 BinOps::BitAnd => Ok(quote!((#left) & (#right))),
144 _ => Err(Error::BinOpNotYetImplemented(self).into()),
146 }
147 }
148}
149
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` as `test_case.py`, runs code generation in a
    /// `CodeGenContext::Module("test_case")` context and logs both results.
    ///
    /// Panics only if parsing fails; the code generation result is logged,
    /// not checked.
    fn parse_and_generate(source: &'static str) {
        let options = PythonOptions::default();
        let tree = crate::parse(source, "test_case.py").unwrap();
        log::info!("Python tree: {:?}", tree);

        let context = CodeGenContext::Module("test_case".to_string());
        let generated = tree.to_rust(context, options, SymbolTableScopes::new());
        log::info!("module: {:?}", generated);
    }

    #[test]
    fn test_add() {
        parse_and_generate("1 + 2");
    }

    #[test]
    fn test_subtract() {
        parse_and_generate("1 - 2");
    }
}