python_ast/ast/tree/
bin_ops.rs

1use proc_macro2::TokenStream;
2use pyo3::{Bound, FromPyObject, PyAny, PyResult, prelude::PyAnyMethods};
3use quote::quote;
4use serde::{Deserialize, Serialize};
5
6use crate::{
7    dump, CodeGen, CodeGenContext, Error, ExprType, Node, PythonOptions, SymbolTableScopes,
8    PythonOperator, BinaryOperation, FromPythonString, PyAttributeExtractor,
9};
10
11#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
12pub enum BinOps {
13    Add,
14    Sub,
15    Mult,
16    Div,
17    FloorDiv,
18    Mod,
19    Pow,
20    LShift,
21    RShift,
22    BitOr,
23    BitXor,
24    BitAnd,
25    MatMult,
26
27    Unknown,
28}
29
30impl FromPythonString for BinOps {
31    fn from_python_string(s: &str) -> Option<Self> {
32        match s {
33            "Add" => Some(BinOps::Add),
34            "Sub" => Some(BinOps::Sub),
35            "Mult" => Some(BinOps::Mult),
36            "Div" => Some(BinOps::Div),
37            "FloorDiv" => Some(BinOps::FloorDiv),
38            "Mod" => Some(BinOps::Mod),
39            "Pow" => Some(BinOps::Pow),
40            "LShift" => Some(BinOps::LShift),
41            "RShift" => Some(BinOps::RShift),
42            "BitOr" => Some(BinOps::BitOr),
43            "BitXor" => Some(BinOps::BitXor),
44            "BitAnd" => Some(BinOps::BitAnd),
45            "MatMult" => Some(BinOps::MatMult),
46            _ => None,
47        }
48    }
49    
50    fn unknown() -> Self {
51        BinOps::Unknown
52    }
53}
54
55impl PythonOperator for BinOps {
56    fn to_rust_op(&self) -> Result<TokenStream, Box<dyn std::error::Error>> {
57        match self {
58            BinOps::Add => Ok(quote!(+)),
59            BinOps::Sub => Ok(quote!(-)),
60            BinOps::Mult => Ok(quote!(*)),
61            BinOps::Div => Ok(quote!(as f64 /)),
62            BinOps::FloorDiv => Ok(quote!(/)),
63            BinOps::Mod => Ok(quote!(%)),
64            BinOps::Pow => Ok(quote!(.pow)),
65            BinOps::LShift => Ok(quote!(<<)),
66            BinOps::RShift => Ok(quote!(>>)),
67            BinOps::BitOr => Ok(quote!(|)),
68            BinOps::BitXor => Ok(quote!(^)),
69            BinOps::BitAnd => Ok(quote!(&)),
70            _ => Err(Error::BinOpNotYetImplemented(BinOp { 
71                op: self.clone(), 
72                left: Box::new(ExprType::Name(crate::Name { id: "unknown".to_string() })),
73                right: Box::new(ExprType::Name(crate::Name { id: "unknown".to_string() })),
74            }).into()),
75        }
76    }
77    
78    fn precedence(&self) -> u8 {
79        match self {
80            BinOps::Pow => 8,
81            BinOps::Mult | BinOps::Div | BinOps::FloorDiv | BinOps::Mod => 7,
82            BinOps::Add | BinOps::Sub => 6,
83            BinOps::LShift | BinOps::RShift => 5,
84            BinOps::BitAnd => 4,
85            BinOps::BitXor => 3,
86            BinOps::BitOr => 2,
87            _ => 1,
88        }
89    }
90    
91    fn is_unknown(&self) -> bool {
92        matches!(self, BinOps::Unknown)
93    }
94}
95
96impl<'a> FromPyObject<'a> for BinOps {
97    fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
98        let err_msg = format!("Unimplemented binary op {}", dump(ob, None)?);
99        Err(pyo3::exceptions::PyValueError::new_err(
100            ob.error_message("<unknown>", err_msg),
101        ))
102    }
103}
104
105#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
106pub struct BinOp {
107    op: BinOps,
108    left: Box<ExprType>,
109    right: Box<ExprType>,
110}
111
112impl BinaryOperation for BinOp {
113    type OperatorType = BinOps;
114    
115    fn operator(&self) -> &Self::OperatorType {
116        &self.op
117    }
118    
119    fn left(&self) -> &ExprType {
120        &self.left
121    }
122    
123    fn right(&self) -> &ExprType {
124        &self.right
125    }
126}
127
128impl<'a> FromPyObject<'a> for BinOp {
129    fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
130        log::debug!("ob: {}", dump(ob, None)?);
131        
132        let op = ob.extract_attr_with_context("op", "binary operator")?;
133        let op_type_str = op.extract_type_name("binary operator")?;
134        
135        let left = ob.extract_attr_with_context("left", "binary operand")?;
136        let right = ob.extract_attr_with_context("right", "binary operand")?;
137        
138        log::debug!("left: {}, right: {}", dump(&left, None)?, dump(&right, None)?);
139
140        let op = BinOps::parse_or_unknown(&op_type_str);
141        if matches!(op, BinOps::Unknown) {
142            log::debug!("Found unknown BinOp {:?}", op_type_str);
143        }
144
145        let left = left.extract().expect("getting binary operator operand");
146        let right = right.extract().expect("getting binary operator operand");
147
148        Ok(BinOp {
149            op,
150            left: Box::new(left),
151            right: Box::new(right),
152        })
153    }
154}
155
156impl CodeGen for BinOp {
157    type Context = CodeGenContext;
158    type Options = PythonOptions;
159    type SymbolTable = SymbolTableScopes;
160
161    fn to_rust(
162        self,
163        ctx: Self::Context,
164        options: Self::Options,
165        symbols: Self::SymbolTable,
166    ) -> std::result::Result<TokenStream, Box<dyn std::error::Error>> {
167        // Special handling for Pow operator which needs different syntax
168        if matches!(self.op, BinOps::Pow) {
169            let left = self.left.clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
170            let right = self.right.clone().to_rust(ctx, options, symbols)?;
171            return Ok(quote!((#left).pow(#right)));
172        }
173        
174        // For Div, we need to cast to f64
175        if matches!(self.op, BinOps::Div) {
176            let left = self.left.clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
177            let right = self.right.clone().to_rust(ctx, options, symbols)?;
178            return Ok(quote!((#left) as f64 / (#right) as f64));
179        }
180        
181        // Special handling for list addition (concatenation)
182        if matches!(self.op, BinOps::Add) {
183            let left = self.left.clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
184            let right = self.right.clone().to_rust(ctx.clone(), options.clone(), symbols.clone())?;
185            let left_str = left.to_string();
186            let right_str = right.to_string();
187            
188            // Check if we're adding vectors or lists together
189            if left_str.contains("vec !") || right_str.contains("iter ()") || right_str.contains("sys :: argv") {
190                // This is vector concatenation - use Vec::extend pattern
191                return Ok(quote! {
192                    {
193                        let mut vec = #left;
194                        vec.extend(#right);
195                        vec
196                    }
197                });
198            }
199        }
200        
201        // Use the generic binary operation implementation for everything else
202        self.generate_rust_code(ctx, options, symbols)
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use crate::create_parse_test;

    create_parse_test!(test_add, "1 + 2", "test_case.py");
    create_parse_test!(test_subtract, "1 - 2", "test_case.py");
    create_parse_test!(test_multiply, "3 * 4", "test_case.py");
    create_parse_test!(test_divide, "8 / 2", "test_case.py");
    create_parse_test!(test_power, "2 ** 3", "test_case.py");
    create_parse_test!(test_modulo, "10 % 3", "test_case.py");

    #[test]
    fn test_operator_precedence() {
        let add_op = BinOps::Add;
        let mul_op = BinOps::Mult;
        let pow_op = BinOps::Pow;

        assert!(pow_op.precedence() > mul_op.precedence());
        assert!(mul_op.precedence() > add_op.precedence());
    }

    /// One representative per tier of Python's precedence table, tightest first,
    /// plus the operators that must share a tier.
    #[test]
    fn test_precedence_ladder_matches_python() {
        let ladder = [
            BinOps::Pow,
            BinOps::Mult,
            BinOps::Add,
            BinOps::LShift,
            BinOps::BitAnd,
            BinOps::BitXor,
            BinOps::BitOr,
        ];
        for pair in ladder.windows(2) {
            assert!(
                pair[0].precedence() > pair[1].precedence(),
                "{:?} should bind tighter than {:?}",
                pair[0],
                pair[1]
            );
        }

        assert_eq!(BinOps::Mult.precedence(), BinOps::Div.precedence());
        assert_eq!(BinOps::Mult.precedence(), BinOps::FloorDiv.precedence());
        assert_eq!(BinOps::Mult.precedence(), BinOps::Mod.precedence());
        assert_eq!(BinOps::Add.precedence(), BinOps::Sub.precedence());
        assert_eq!(BinOps::LShift.precedence(), BinOps::RShift.precedence());
    }

    #[test]
    fn test_unknown_operator() {
        let unknown_op = BinOps::Unknown;
        assert!(unknown_op.is_unknown());
        assert!(unknown_op.to_rust_op().is_err());
    }

    #[test]
    fn test_from_python_string() {
        assert_eq!(BinOps::from_python_string("Add"), Some(BinOps::Add));
        assert_eq!(BinOps::from_python_string("Unknown"), None);
        assert_eq!(BinOps::parse_or_unknown("Invalid"), BinOps::Unknown);
    }

    /// Every Python operator class name maps to its own variant, never `Unknown`.
    #[test]
    fn test_from_python_string_covers_every_operator() {
        let cases = [
            ("Add", BinOps::Add),
            ("Sub", BinOps::Sub),
            ("Mult", BinOps::Mult),
            ("Div", BinOps::Div),
            ("FloorDiv", BinOps::FloorDiv),
            ("Mod", BinOps::Mod),
            ("Pow", BinOps::Pow),
            ("LShift", BinOps::LShift),
            ("RShift", BinOps::RShift),
            ("BitOr", BinOps::BitOr),
            ("BitXor", BinOps::BitXor),
            ("BitAnd", BinOps::BitAnd),
            ("MatMult", BinOps::MatMult),
        ];
        for (name, expected) in cases.iter() {
            assert_eq!(BinOps::from_python_string(name), Some(expected.clone()));
            assert!(!BinOps::parse_or_unknown(name).is_unknown());
        }
    }
}