symrs/expr/
pow.rs

1use super::*;
2
3#[derive(Clone)]
4pub struct Pow {
5    pub base: Box<dyn Expr>,
6    pub exponent: Box<dyn Expr>,
7}
8
9impl Expr for Pow {
10    fn known_expr(&self) -> KnownExpr {
11        KnownExpr::Pow(self)
12    }
13
14    fn as_pow(&self) -> Option<&Pow> {
15        Some(self)
16    }
17    fn get_ref<'a>(&'a self) -> &'a dyn Expr {
18        self as &dyn Expr
19    }
20    fn for_each_arg(&self, f: &mut dyn FnMut(&dyn Arg) -> ()) {
21        f(&*self.base);
22        f(&*self.exponent)
23    }
24
25    fn from_args(&self, args: Vec<Box<dyn Arg>>) -> Box<dyn Expr> {
26        Box::new(Pow {
27            base: args[0].clone().into(),
28            exponent: args[1].clone().into(),
29        })
30    }
31
32    fn clone_box(&self) -> Box<dyn Expr> {
33        Box::new(self.clone())
34    }
35
36    fn is_number(&self) -> bool {
37        self.base.is_number() && self.exponent.is_number()
38    }
39
40    fn str(&self) -> String {
41        match (
42            self.base.known_expr(),
43            KnownExpr::from_expr_box(&self.exponent),
44        ) {
45            (KnownExpr::Rational(r), _) => format!("({})^{}", r.str(), self.exponent.str()),
46            (_, KnownExpr::Integer(Integer { value: -1 })) => format!("1 / {}", self.base.str()),
47
48            _ => format!("{}^{}", self.base.str(), self.exponent.str()),
49        }
50    }
51
52    fn get_exponent(&self) -> (Box<dyn Expr>, Box<dyn Expr>) {
53        (self.base.clone(), self.exponent.clone())
54    }
55
56    fn is_one(&self) -> bool {
57        self.exponent.is_neg_one() && self.base.is_one() || self.exponent.is_zero()
58    }
59
60    fn to_cpp(&self) -> String {
61        let exponent = &self.exponent;
62        if exponent.is_zero() {
63            String::from("1")
64        } else if exponent.is_one() {
65            self.base.to_cpp()
66        } else if exponent.is_neg_one() {
67            format!("1 / {}", self.base.to_cpp())
68        } else {
69            if let KnownExpr::Integer(Integer { value: n }) = exponent.known_expr()
70                && *n > 0
71            {
72                let n = *n as usize;
73                let base_cpp = self.base.to_cpp();
74
75                let mut res =
76                    String::with_capacity((base_cpp.len() + 3) * (n - 1) + base_cpp.len());
77                res += &base_cpp;
78                for _ in 1..n {
79                    res += " * ";
80                    res += &base_cpp;
81                }
82                res
83            } else {
84                format!("pow({}, {})", self.base.to_cpp(), self.exponent.to_cpp())
85            }
86        }
87    }
88
89    fn simplify(&self) -> Box<dyn Expr> {
90        let Pow { base, exponent } = self;
91
92        if exponent.is_one() {
93            if let Some(pow) = base.as_pow() {
94                pow.simplify()
95            } else {
96                base.simplify()
97            }
98        } else if exponent.is_zero() {
99            Integer::one_box()
100        } else if base.is_one() {
101            Integer::one_box()
102        } else if let Some(pow) = base.as_pow() {
103            let base = pow.base.clone_box();
104            let exponent = &pow.exponent * exponent;
105            Pow::pow(base, exponent)
106        } else {
107            match (base.known_expr(), exponent.known_expr()) {
108                (
109                    KnownExpr::Rational(Rational { num, denom }),
110                    KnownExpr::Integer(Integer { value }),
111                ) if *value > 0 => {
112                    Rational::new_box(num.pow(*value as u32), denom.pow(*value as u32))
113                }
114                (
115                    KnownExpr::Integer(Integer { value: n }),
116                    KnownExpr::Integer(Integer { value: e }),
117                ) if *e > 0 => Integer::new_box(n.pow(*e as u32)),
118                _ => self.clone_box(),
119            }
120        }
121    }
122}
123
124impl Pow {
125    pub fn new(base: &Box<dyn Expr>, exponent: &Box<dyn Expr>) -> Box<dyn Expr> {
126        Box::new(Pow {
127            base: base.clone(),
128            exponent: exponent.clone(),
129        })
130    }
131
132    pub fn new_move(base: Box<dyn Expr>, exponent: Box<dyn Expr>) -> Pow {
133        Pow { base, exponent }
134    }
135
136    pub fn new_box(base: Box<dyn Expr>, exponent: Box<dyn Expr>) -> Box<dyn Expr> {
137        Box::new(Pow { base, exponent })
138    }
139    pub fn base(&self) -> &dyn Expr {
140        &*self.base
141    }
142
143    pub fn exponent(&self) -> &dyn Expr {
144        &*self.exponent
145    }
146
147    pub fn pow(mut base: Box<dyn Expr>, mut exponent: Box<dyn Expr>) -> Box<dyn Expr> {
148        match (base.clone().known_expr(), exponent.known_expr()) {
149            (KnownExpr::Rational(r), KnownExpr::Integer(i)) if i.value > 0 => {
150                return Rational::new_box(r.num.pow(i.value as u32), r.denom.pow(i.value as u32));
151            }
152            (KnownExpr::Rational(r), _) => {
153                let mut r = r.clone();
154                if exponent.is_negative_number() {
155                    r.invert();
156                    exponent = match exponent.known_expr() {
157                        KnownExpr::Integer(i) => Box::new(-i),
158                        KnownExpr::Rational(r) => Box::new(-r),
159                        _ => panic!("{:?}", exponent.clone_box()),
160                    };
161                }
162                base = r.simplify().clone_box();
163            }
164            (
165                KnownExpr::Pow(Pow {
166                    base: base_base,
167                    exponent: base_exponent,
168                }),
169                _,
170            ) => {
171                base = base_base.clone_box();
172                exponent = base_exponent.get_ref() * exponent.get_ref();
173            }
174            _ => (),
175        }
176        if exponent.is_one() {
177            base.clone()
178        } else if exponent.is_zero() {
179            Integer::one_box()
180        } else {
181            match (base.as_f64(), exponent.as_f64()) {
182                (Some(b), Some(e)) => {
183                    let res = b.powf(e);
184
185                    if res.fract() == 0. {
186                        return Integer::new_box(res.to_isize().unwrap());
187                    }
188                }
189                _ => (),
190            }
191            Pow::new_box(base.clone(), exponent)
192        }
193    }
194}
195
196impl fmt::Debug for Pow {
197    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
198        write!(f, "{:?}", self.get_ref())
199    }
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    /// A power of a power multiplies the exponents: (x^2)^3 == x^6.
    #[test]
    fn test_pow_simplify() {
        let expected = "Pow(Symbol(x), Integer(6))";
        let expr = Symbol::new("x").ipow(2).ipow(3);

        assert_eq!(expr.srepr(), expected)
    }

    /// sqrt(2) is irrational and must stay symbolic.
    #[test]
    fn test_sqrt_2() {
        let expected = "Pow(Integer(2), Rational(1, 2))";
        let root = Integer::new(2).sqrt();

        assert_eq!(root.srepr(), expected)
    }

    /// 4^(1/2) evaluates to the exact integer 2.
    #[test]
    fn test_sqrt_4_simplifies() {
        let half = Rational::new_box(1, 2);
        let root = Integer::new(4).pow(&half);

        assert_eq!(root.srepr(), "Integer(2)")
    }

    /// sqrt(2) * sqrt(3) combines into sqrt(6).
    #[test]
    fn test_mul_sqrts() {
        let lhs = Integer::new_box(2).sqrt();
        let rhs = Integer::new_box(3).sqrt();
        let product = lhs * rhs;

        assert_eq!(product.srepr(), "Pow(Integer(6), Rational(1, 2))")
    }

    /// Simplifying (x^2)^3 yields the flat node x^6.
    #[test]
    fn test_simplify_pow() {
        let nested = Pow {
            base: Pow::new_box(Symbol::new_box("x"), Integer::new_box(2)),
            exponent: Integer::new_box(3),
        };
        let flattened = Pow {
            base: Symbol::new_box("x"),
            exponent: Integer::new_box(6),
        };

        assert_eq!(nested.simplify().get_ref(), flattened.get_ref())
    }

    /// (2/3)^2 evaluates to 4/9.
    #[test]
    fn test_simplify_rational_pow() {
        let squared = Rational::new(2, 3).ipow(2);

        assert_eq!(squared, Rational::new_box(4, 9))
    }
}