python_ast/ast/tree/bin_ops.rs

1use proc_macro2::TokenStream;
2use pyo3::{FromPyObject, PyAny, PyResult};
3use quote::quote;
4use serde::{Deserialize, Serialize};
5
6use crate::{
7    dump, CodeGen, CodeGenContext, Error, ExprType, Node, PythonOptions, SymbolTableScopes,
8};
9
10#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
11pub enum BinOps {
12    Add,
13    Sub,
14    Mult,
15    Div,
16    FloorDiv,
17    Mod,
18    Pow,
19    LShift,
20    RShift,
21    BitOr,
22    BitXor,
23    BitAnd,
24    MatMult,
25
26    Unknown,
27}
28
29impl<'a> FromPyObject<'a> for BinOps {
30    fn extract(ob: &'a PyAny) -> PyResult<Self> {
31        let err_msg = format!("Unimplemented unary op {}", dump(ob, None)?);
32        Err(pyo3::exceptions::PyValueError::new_err(
33            ob.error_message("<unknown>", err_msg),
34        ))
35    }
36}
37
38#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
39pub struct BinOp {
40    op: BinOps,
41    left: Box<ExprType>,
42    right: Box<ExprType>,
43}
44
45impl<'a> FromPyObject<'a> for BinOp {
46    fn extract(ob: &'a PyAny) -> PyResult<Self> {
47        log::debug!("ob: {}", dump(ob, None)?);
48        let op = ob.getattr("op").expect(
49            ob.error_message("<unknown>", "error getting unary operator")
50                .as_str(),
51        );
52
53        let op_type = op.get_type().name().expect(
54            ob.error_message(
55                "<unknown>",
56                format!("extracting type name {:?} for binary operator", op),
57            )
58            .as_str(),
59        );
60
61        let left = ob.getattr("left").expect(
62            ob.error_message("<unknown>", "error getting binary operand")
63                .as_str(),
64        );
65
66        let right = ob.getattr("right").expect(
67            ob.error_message("<unknown>", "error getting binary operand")
68                .as_str(),
69        );
70        log::debug!("left: {}, right: {}", dump(left, None)?, dump(right, None)?);
71
72        let op = match op_type.as_ref() {
73            "Add" => BinOps::Add,
74            "Sub" => BinOps::Sub,
75            "Mult" => BinOps::Mult,
76            "Div" => BinOps::Div,
77            "FloorDiv" => BinOps::FloorDiv,
78            "Mod" => BinOps::Mod,
79            "Pow" => BinOps::Pow,
80            "LShift" => BinOps::LShift,
81            "RShift" => BinOps::RShift,
82            "BitOr" => BinOps::BitOr,
83            "BitXor" => BinOps::BitXor,
84            "BitAnd" => BinOps::BitAnd,
85            "MatMult" => BinOps::MatMult,
86
87            _ => {
88                log::debug!("Found unknown BinOp {:?}", op);
89                BinOps::Unknown
90            }
91        };
92
93        log::debug!(
94            "left: {}, right: {}, op: {:?}/{:?}",
95            dump(left, None)?,
96            dump(right, None)?,
97            op_type,
98            op
99        );
100
101        let right = ExprType::extract(right).expect("getting binary operator operand");
102        let left = ExprType::extract(left).expect("getting binary operator operand");
103
104        return Ok(BinOp {
105            op: op,
106            left: Box::new(left),
107            right: Box::new(right),
108        });
109    }
110}
111
112impl<'a> CodeGen for BinOp {
113    type Context = CodeGenContext;
114    type Options = PythonOptions;
115    type SymbolTable = SymbolTableScopes;
116
117    fn to_rust(
118        self,
119        ctx: Self::Context,
120        options: Self::Options,
121        symbols: Self::SymbolTable,
122    ) -> std::result::Result<TokenStream, Box<dyn std::error::Error>> {
123        let left = self
124            .left
125            .clone()
126            .to_rust(ctx.clone(), options.clone(), symbols.clone())?;
127        let right = self
128            .right
129            .clone()
130            .to_rust(ctx.clone(), options.clone(), symbols.clone())?;
131        match self.op {
132            BinOps::Add => Ok(quote!((#left) + (#right))),
133            BinOps::Sub => Ok(quote!((#left) - (#right))),
134            BinOps::Mult => Ok(quote!((#left) * (#right))),
135            BinOps::Div => Ok(quote!((#left) as f64 / (#right) as f64)),
136            BinOps::FloorDiv => Ok(quote!((#left) / (#right))),
137            BinOps::Mod => Ok(quote!((#left) % (#right))),
138            BinOps::Pow => Ok(quote!((#left).pow(#right))),
139            BinOps::LShift => Ok(quote!((#left) << (#right))),
140            BinOps::RShift => Ok(quote!((#left) >> (#right))),
141            BinOps::BitOr => Ok(quote!((#left) | (#right))),
142            BinOps::BitXor => Ok(quote!((#left) ^ (#right))),
143            BinOps::BitAnd => Ok(quote!((#left) & (#right))),
144            //MatMult, XXX implement this
145            _ => Err(Error::BinOpNotYetImplemented(self).into()),
146        }
147    }
148}
149
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` as the module `test_case.py`, lowers it to Rust and logs both the
    /// Python tree and the code-generation result.
    fn parse_and_lower(source: &str) {
        let tree = crate::parse(source, "test_case.py").unwrap();
        log::info!("Python tree: {:?}", tree);

        let lowered = tree.to_rust(
            CodeGenContext::Module("test_case".to_string()),
            PythonOptions::default(),
            SymbolTableScopes::new(),
        );
        log::info!("module: {:?}", lowered);
    }

    #[test]
    fn test_add() {
        parse_and_lower("1 + 2");
    }

    #[test]
    fn test_subtract() {
        parse_and_lower("1 - 2");
    }
}