1use super::*;
2
3#[derive(Clone)]
4pub struct Pow {
5 pub base: Box<dyn Expr>,
6 pub exponent: Box<dyn Expr>,
7}
8
9impl Expr for Pow {
10 fn known_expr(&self) -> KnownExpr {
11 KnownExpr::Pow(self)
12 }
13
14 fn as_pow(&self) -> Option<&Pow> {
15 Some(self)
16 }
17 fn get_ref<'a>(&'a self) -> &'a dyn Expr {
18 self as &dyn Expr
19 }
20 fn for_each_arg(&self, f: &mut dyn FnMut(&dyn Arg) -> ()) {
21 f(&*self.base);
22 f(&*self.exponent)
23 }
24
25 fn from_args(&self, args: Vec<Box<dyn Arg>>) -> Box<dyn Expr> {
26 Box::new(Pow {
27 base: args[0].clone().into(),
28 exponent: args[1].clone().into(),
29 })
30 }
31
32 fn clone_box(&self) -> Box<dyn Expr> {
33 Box::new(self.clone())
34 }
35
36 fn is_number(&self) -> bool {
37 self.base.is_number() && self.exponent.is_number()
38 }
39
40 fn str(&self) -> String {
41 match (
42 self.base.known_expr(),
43 KnownExpr::from_expr_box(&self.exponent),
44 ) {
45 (KnownExpr::Rational(r), _) => format!("({})^{}", r.str(), self.exponent.str()),
46 (_, KnownExpr::Integer(Integer { value: -1 })) => format!("1 / {}", self.base.str()),
47
48 _ => format!("{}^{}", self.base.str(), self.exponent.str()),
49 }
50 }
51
52 fn get_exponent(&self) -> (Box<dyn Expr>, Box<dyn Expr>) {
53 (self.base.clone(), self.exponent.clone())
54 }
55
56 fn is_one(&self) -> bool {
57 self.exponent.is_neg_one() && self.base.is_one() || self.exponent.is_zero()
58 }
59
60 fn to_cpp(&self) -> String {
61 let exponent = &self.exponent;
62 if exponent.is_zero() {
63 String::from("1")
64 } else if exponent.is_one() {
65 self.base.to_cpp()
66 } else if exponent.is_neg_one() {
67 format!("1 / {}", self.base.to_cpp())
68 } else {
69 if let KnownExpr::Integer(Integer { value: n }) = exponent.known_expr()
70 && *n > 0
71 {
72 let n = *n as usize;
73 let base_cpp = self.base.to_cpp();
74
75 let mut res =
76 String::with_capacity((base_cpp.len() + 3) * (n - 1) + base_cpp.len());
77 res += &base_cpp;
78 for _ in 1..n {
79 res += " * ";
80 res += &base_cpp;
81 }
82 res
83 } else {
84 format!("pow({}, {})", self.base.to_cpp(), self.exponent.to_cpp())
85 }
86 }
87 }
88
89 fn simplify(&self) -> Box<dyn Expr> {
90 let Pow { base, exponent } = self;
91
92 if exponent.is_one() {
93 if let Some(pow) = base.as_pow() {
94 pow.simplify()
95 } else {
96 base.simplify()
97 }
98 } else if exponent.is_zero() {
99 Integer::one_box()
100 } else if base.is_one() {
101 Integer::one_box()
102 } else if let Some(pow) = base.as_pow() {
103 let base = pow.base.clone_box();
104 let exponent = &pow.exponent * exponent;
105 Pow::pow(base, exponent)
106 } else {
107 match (base.known_expr(), exponent.known_expr()) {
108 (
109 KnownExpr::Rational(Rational { num, denom }),
110 KnownExpr::Integer(Integer { value }),
111 ) if *value > 0 => {
112 Rational::new_box(num.pow(*value as u32), denom.pow(*value as u32))
113 }
114 (
115 KnownExpr::Integer(Integer { value: n }),
116 KnownExpr::Integer(Integer { value: e }),
117 ) if *e > 0 => Integer::new_box(n.pow(*e as u32)),
118 _ => self.clone_box(),
119 }
120 }
121 }
122}
123
124impl Pow {
125 pub fn new(base: &Box<dyn Expr>, exponent: &Box<dyn Expr>) -> Box<dyn Expr> {
126 Box::new(Pow {
127 base: base.clone(),
128 exponent: exponent.clone(),
129 })
130 }
131
132 pub fn new_move(base: Box<dyn Expr>, exponent: Box<dyn Expr>) -> Pow {
133 Pow { base, exponent }
134 }
135
136 pub fn new_box(base: Box<dyn Expr>, exponent: Box<dyn Expr>) -> Box<dyn Expr> {
137 Box::new(Pow { base, exponent })
138 }
139 pub fn base(&self) -> &dyn Expr {
140 &*self.base
141 }
142
143 pub fn exponent(&self) -> &dyn Expr {
144 &*self.exponent
145 }
146
147 pub fn pow(mut base: Box<dyn Expr>, mut exponent: Box<dyn Expr>) -> Box<dyn Expr> {
148 match (base.clone().known_expr(), exponent.known_expr()) {
149 (KnownExpr::Rational(r), KnownExpr::Integer(i)) if i.value > 0 => {
150 return Rational::new_box(r.num.pow(i.value as u32), r.denom.pow(i.value as u32));
151 }
152 (KnownExpr::Rational(r), _) => {
153 let mut r = r.clone();
154 if exponent.is_negative_number() {
155 r.invert();
156 exponent = match exponent.known_expr() {
157 KnownExpr::Integer(i) => Box::new(-i),
158 KnownExpr::Rational(r) => Box::new(-r),
159 _ => panic!("{:?}", exponent.clone_box()),
160 };
161 }
162 base = r.simplify().clone_box();
163 }
164 (
165 KnownExpr::Pow(Pow {
166 base: base_base,
167 exponent: base_exponent,
168 }),
169 _,
170 ) => {
171 base = base_base.clone_box();
172 exponent = base_exponent.get_ref() * exponent.get_ref();
173 }
174 _ => (),
175 }
176 if exponent.is_one() {
177 base.clone()
178 } else if exponent.is_zero() {
179 Integer::one_box()
180 } else {
181 match (base.as_f64(), exponent.as_f64()) {
182 (Some(b), Some(e)) => {
183 let res = b.powf(e);
184
185 if res.fract() == 0. {
186 return Integer::new_box(res.to_isize().unwrap());
187 }
188 }
189 _ => (),
190 }
191 Pow::new_box(base.clone(), exponent)
192 }
193 }
194}
195
196impl fmt::Debug for Pow {
197 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
198 write!(f, "{:?}", self.get_ref())
199 }
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pow_simplify() {
        // (x^2)^3 collapses into one power whose exponent is the product.
        let x_squared = Symbol::new("x").ipow(2);
        let x_sixth = x_squared.ipow(3);
        let expected = "Pow(Symbol(x), Integer(6))";
        assert_eq!(x_sixth.srepr(), expected)
    }

    #[test]
    fn test_sqrt_2() {
        // sqrt(2) is irrational, so it stays a symbolic power.
        let expected = "Pow(Integer(2), Rational(1, 2))";
        assert_eq!(Integer::new(2).sqrt().srepr(), expected)
    }

    #[test]
    fn test_sqrt_4_simplifies() {
        // 4^(1/2) evaluates exactly to an integer.
        let half = Rational::new_box(1, 2);
        let expected = "Integer(2)";
        assert_eq!(Integer::new(4).pow(&half).srepr(), expected)
    }

    #[test]
    fn test_mul_sqrts() {
        // sqrt(2) * sqrt(3) combines under a single square root.
        let product = Integer::new_box(2).sqrt() * Integer::new_box(3).sqrt();
        let expected = "Pow(Integer(6), Rational(1, 2))";
        assert_eq!(product.srepr(), expected)
    }

    #[test]
    fn test_simplify_pow() {
        let nested = Pow {
            base: Pow::new_box(Symbol::new_box("x"), Integer::new_box(2)),
            exponent: Integer::new_box(3),
        };
        let flattened = Pow {
            base: Symbol::new_box("x"),
            exponent: Integer::new_box(6),
        };
        let simplified = nested.simplify();
        assert_eq!(simplified.get_ref(), flattened.get_ref())
    }

    #[test]
    fn test_simplify_rational_pow() {
        let squared = Rational::new(2, 3).ipow(2);
        let expected = Rational::new_box(4, 9);
        assert_eq!(squared, expected)
    }
}