// doctor_syn/transformation/expand.rs

use crate::error::{Result, Error};
use crate::visitor::Visitor;
use syn::{Expr, ExprUnary, ExprBinary, UnOp, BinOp};
use syn::spanned::Spanned;
use super::tools::*;

/// Transformation that expands an expression into a flat sum of terms,
/// distributing multiplication over addition and subtraction.
///
/// The type carries no state; all of its behaviour lives in its `Visitor`
/// implementation. Being stateless, it is trivially `Copy` and `Default`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Expand {
}

11// Get a series of additions/subtractions.
12fn match_sum(expr: &Expr, sum: &mut Vec<Expr>, is_negated: bool) -> Result<()> {
13    let expr = deparen(expr);
14    match expr {
15        Expr::Binary(ExprBinary { left, op, right, ..}) => {
16            match op {
17                BinOp::Add(_) => {
18                    match_sum(left, sum, is_negated)?;
19                    match_sum(right, sum, is_negated)?;
20                    return Ok(());
21                }
22                BinOp::Sub(_) => {
23                    match_sum(left, sum, is_negated)?;
24                    match_sum(right, sum, !is_negated)?;
25                    return Ok(());
26                }
27                BinOp::Mul(_) => {
28                    let mut left_sum = Vec::new();
29                    match_sum(left, &mut left_sum, is_negated)?;
30                    let mut right_sum = Vec::new();
31                    match_sum(right, &mut right_sum, false)?;
32                    for lhs in left_sum.iter() {
33                        for rhs in right_sum.iter() {
34                            sum.push(make_binary(lhs.clone(), op.clone(), rhs.clone()));
35                        }
36                    }
37                    return Ok(());
38                }
39                _ => {
40                }
41            }
42        }
43        _ => {
44        }
45    }
46    if is_negated {
47        sum.push(negate(&expr));
48    } else {
49        sum.push(expr.clone());
50    }
51    Ok(())
52}
53
54impl Visitor for Expand {
55    fn visit_unary(&self, exprunary: &ExprUnary) -> Result<Expr> {
56        match exprunary.op {
57            UnOp::Deref(_) => Err(Error::UnsupportedExpr(exprunary.span())),
58            UnOp::Not(_) => Err(Error::UnsupportedExpr(exprunary.span())),
59            UnOp::Neg(_) => {
60                let mut sum = Vec::new();
61                match_sum(deparen(&exprunary.expr), &mut sum, true)?;
62                Ok(make_sum(&*sum))
63            },
64        }
65    }
66
67    fn visit_binary(&self, exprbinary: &ExprBinary) -> Result<Expr> {
68        let mut sum = Vec::new();
69        let expr : Expr = exprbinary.clone().into(); 
70        match_sum(&expr, &mut sum, false)?;
71        Ok(make_sum(&*sum))
72    }
73}
74
#[test]
fn expand() -> Result<()> {
    use crate::expr;

    // (input, expected expansion) pairs, checked in order.
    let cases = vec![
        // Unary
        (expr!(- - x), expr!(x)),
        (expr!(-(x+1)), expr!(-x-1)),
        // Binary add/sub
        (expr!((x+1)+(x+1)), expr!(x + 1 + x + 1)),
        (expr!((x+1)+((x+1)+(x+1))), expr!(x + 1 + x + 1 + x + 1)),
        (expr!((x+1)-(x+1)), expr!(x + 1 - x - 1)),
        (expr!((x+1)-((x+1)-(x+1))), expr!(x + 1 - x - 1 + x + 1)),
        (expr!((x+1)-((x+1)-(-x+1))), expr!(x + 1 - x - 1 - x + 1)),
        // Binary mul
        (expr!(x*x), expr!(x * x)),
        (expr!(x*(x+1)), expr!(x * x + x * 1)),
        (expr!((x+1)*x), expr!(x * x + 1 * x)),
        (expr!((x+1)*(x+1)), expr!(x * x + x * 1 + 1 * x + 1 * 1)),
        (
            expr!((x+1)*(x+1)*(x+1)),
            expr!(x * x * x + x * x * 1 + x * 1 * x + x * 1 * 1 + 1 * x * x + 1 * x * 1 + 1 * 1 * x + 1 * 1 * 1),
        ),
    ];

    for (index, (input, expected)) in cases.into_iter().enumerate() {
        assert_eq!(input.expand()?, expected, "case #{} expanded incorrectly", index);
    }
    Ok(())
}