1use num_bigint::BigInt;
9use num_integer::Integer;
10use num_traits::{One, Signed, ToPrimitive, Zero};
11
12#[derive(Debug, Clone, PartialEq, Eq, Hash)]
14pub enum SymbolicExpr {
15 Integer(BigInt),
16 Rational(BigInt, BigInt),
18 Sqrt { radicand: BigInt },
20 ScaledSqrt { coeff: (BigInt, BigInt), rad: BigInt },
22 Pi,
23 E,
24 Add(Vec<SymbolicExpr>),
25 Mul(Vec<SymbolicExpr>),
26 Pow { base: Box<SymbolicExpr>, exp: Box<SymbolicExpr> },
27 Sin(Box<SymbolicExpr>),
28 Cos(Box<SymbolicExpr>),
29 Exp(Box<SymbolicExpr>),
30 Ln(Box<SymbolicExpr>),
31}
32
33use SymbolicExpr::*;
34
35impl SymbolicExpr {
36 pub fn int(n: i64) -> Self {
37 Integer(BigInt::from(n))
38 }
39 pub fn rational(p: i64, q: i64) -> Self {
40 Rational(BigInt::from(p), BigInt::from(q))
41 }
42 pub fn sqrt(n: i64) -> Self {
43 Sqrt { radicand: BigInt::from(n) }
44 }
45 pub fn add(terms: Vec<SymbolicExpr>) -> Self {
46 Add(terms)
47 }
48 pub fn mul(factors: Vec<SymbolicExpr>) -> Self {
49 Mul(factors)
50 }
51 pub fn sin(x: SymbolicExpr) -> Self {
52 Sin(Box::new(x))
53 }
54 pub fn cos(x: SymbolicExpr) -> Self {
55 Cos(Box::new(x))
56 }
57 pub fn exp(x: SymbolicExpr) -> Self {
58 Exp(Box::new(x))
59 }
60 pub fn ln(x: SymbolicExpr) -> Self {
61 Ln(Box::new(x))
62 }
63
64 fn as_rational(&self) -> Option<(BigInt, BigInt)> {
66 match self {
67 Integer(n) => Some((n.clone(), BigInt::one())),
68 Rational(p, q) => Some((p.clone(), q.clone())),
69 _ => None,
70 }
71 }
72}
73
/// Coarse classification of the kind of number an expression denotes.
///
/// Variants are declared from least to most complex, and the derived `PartialOrd`/`Ord`
/// follow that declaration order, so levels can be compared and combined with
/// `max`/`min` directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum TowerLevel {
    /// Integer leaves.
    Integer,
    /// Rational-number leaves.
    Rational,
    /// Square-root leaves.
    Algebraic,
    /// Powers whose operands are no more complex than this.
    Symbolic,
    /// Anything involving π, e, or an elementary function.
    Transcendental,
}
83
84pub fn tower_level(expr: &SymbolicExpr) -> TowerLevel {
86 match expr {
87 Integer(_) => TowerLevel::Integer,
88 Rational(_, _) => TowerLevel::Rational,
89 Sqrt { .. } | ScaledSqrt { .. } => TowerLevel::Algebraic,
90 Pi | E | Sin(_) | Cos(_) | Exp(_) | Ln(_) => TowerLevel::Transcendental,
91 Add(t) => t.iter().map(tower_level).max_by_key(level_rank).unwrap_or(TowerLevel::Integer),
92 Mul(t) => t.iter().map(tower_level).max_by_key(level_rank).unwrap_or(TowerLevel::Integer),
93 Pow { base, .. } => tower_level(base).max_symbolic(),
94 }
95}
96
97fn level_rank(l: &TowerLevel) -> u8 {
98 match l {
99 TowerLevel::Integer => 0,
100 TowerLevel::Rational => 1,
101 TowerLevel::Algebraic => 2,
102 TowerLevel::Symbolic => 3,
103 TowerLevel::Transcendental => 4,
104 }
105}
106
107impl TowerLevel {
108 fn max_symbolic(self) -> TowerLevel {
109 if level_rank(&self) >= level_rank(&TowerLevel::Symbolic) {
110 self
111 } else {
112 TowerLevel::Symbolic
113 }
114 }
115}
116
/// Stateless rule set that rewrites [`SymbolicExpr`] trees toward a simpler form.
///
/// It carries no data, so it is cheap to copy and to print for diagnostics.
#[derive(Debug, Clone, Copy)]
pub struct IdentityGraph;
121
122impl Default for IdentityGraph {
123 fn default() -> Self {
124 Self::standard()
125 }
126}
127
128impl IdentityGraph {
129 pub fn standard() -> Self {
131 IdentityGraph
132 }
133
134 pub fn simplify(&self, expr: SymbolicExpr) -> SymbolicExpr {
136 let mut current = expr;
137 for _ in 0..64 {
138 let next = self.step(current.clone());
139 if next == current {
140 return next;
141 }
142 current = next;
143 }
144 current
145 }
146
147 fn step(&self, expr: SymbolicExpr) -> SymbolicExpr {
149 match expr {
150 Add(terms) => self.simplify_add(terms),
151 Mul(factors) => self.simplify_mul(factors),
152 Sin(x) => self.simplify_sin(self.step(*x)),
153 Cos(x) => self.simplify_cos(self.step(*x)),
154 Exp(x) => self.simplify_exp(self.step(*x)),
155 Ln(x) => self.simplify_ln(self.step(*x)),
156 Pow { base, exp } => Pow {
157 base: Box::new(self.step(*base)),
158 exp: Box::new(self.step(*exp)),
159 },
160 Rational(p, q) => normalize_rational(p, q),
161 other => other,
162 }
163 }
164
165 fn simplify_add(&self, terms: Vec<SymbolicExpr>) -> SymbolicExpr {
166 let mut const_num = BigInt::zero();
167 let mut const_den = BigInt::one();
168 let mut others: Vec<SymbolicExpr> = Vec::new();
169 for t in terms {
170 let t = self.step(t);
171 match &t {
172 Add(inner) => {
173 for it in inner.clone() {
175 self.accumulate_add(it, &mut const_num, &mut const_den, &mut others);
176 }
177 }
178 _ => self.accumulate_add(t, &mut const_num, &mut const_den, &mut others),
179 }
180 }
181 let mut result: Vec<SymbolicExpr> = Vec::new();
182 if !const_num.is_zero() {
183 result.push(normalize_rational(const_num, const_den));
184 }
185 result.append(&mut others);
186 match result.len() {
187 0 => SymbolicExpr::int(0),
188 1 => result.into_iter().next().unwrap(),
189 _ => Add(result),
190 }
191 }
192
193 fn accumulate_add(
194 &self,
195 t: SymbolicExpr,
196 num: &mut BigInt,
197 den: &mut BigInt,
198 others: &mut Vec<SymbolicExpr>,
199 ) {
200 if let Some((p, q)) = t.as_rational() {
201 *num = &*num * &q + &p * &*den;
203 *den = &*den * &q;
204 } else {
205 others.push(t);
206 }
207 }
208
209 fn simplify_mul(&self, factors: Vec<SymbolicExpr>) -> SymbolicExpr {
210 let mut coeff_num = BigInt::one();
211 let mut coeff_den = BigInt::one();
212 let mut radicand = BigInt::one();
213 let mut others: Vec<SymbolicExpr> = Vec::new();
214 let mut is_zero = false;
215
216 let mut stack: Vec<SymbolicExpr> = factors.into_iter().map(|f| self.step(f)).collect();
217 while let Some(f) = stack.pop() {
218 match f {
219 Mul(inner) => stack.extend(inner.into_iter().map(|f| self.step(f))),
220 Integer(n) => {
221 if n.is_zero() {
222 is_zero = true;
223 }
224 coeff_num *= n;
225 }
226 Rational(p, q) => {
227 if p.is_zero() {
228 is_zero = true;
229 }
230 coeff_num *= p;
231 coeff_den *= q;
232 }
233 Sqrt { radicand: r } => radicand *= r,
234 ScaledSqrt { coeff: (a, b), rad } => {
235 coeff_num *= a;
236 coeff_den *= b;
237 radicand *= rad;
238 }
239 other => others.push(other),
240 }
241 }
242
243 if is_zero {
244 return SymbolicExpr::int(0);
245 }
246
247 if !radicand.is_one() {
249 match simplify_sqrt(radicand) {
250 Integer(k) => coeff_num *= k,
251 ScaledSqrt { coeff: (a, b), rad } => {
252 coeff_num *= a;
253 coeff_den *= b;
254 others.push(Sqrt { radicand: rad });
255 }
256 Sqrt { radicand: r } => others.push(Sqrt { radicand: r }),
257 e => others.push(e),
258 }
259 }
260
261 let g = coeff_num.gcd(&coeff_den);
263 if !g.is_zero() {
264 coeff_num /= &g;
265 coeff_den /= &g;
266 }
267 if coeff_den.is_negative() {
268 coeff_num = -coeff_num;
269 coeff_den = -coeff_den;
270 }
271
272 let coeff_is_one = coeff_num.is_one() && coeff_den.is_one();
273
274 if others.len() == 1 {
276 if let Sqrt { radicand: r } = &others[0] {
277 if coeff_is_one {
278 return Sqrt { radicand: r.clone() };
279 }
280 return ScaledSqrt {
281 coeff: (coeff_num, coeff_den),
282 rad: r.clone(),
283 };
284 }
285 }
286
287 let mut result: Vec<SymbolicExpr> = Vec::new();
288 if !coeff_is_one {
289 result.push(normalize_rational(coeff_num, coeff_den));
290 }
291 result.append(&mut others);
292 match result.len() {
293 0 => SymbolicExpr::int(1),
294 1 => result.into_iter().next().unwrap(),
295 _ => Mul(result),
296 }
297 }
298
299 fn simplify_sin(&self, x: SymbolicExpr) -> SymbolicExpr {
300 if let Integer(n) = &x {
301 if n.is_zero() {
302 return SymbolicExpr::int(0);
303 }
304 }
305 if let Some((a, b)) = as_pi_multiple(&x) {
306 let (a, b) = reduce(a, b);
307 if let Some(v) = sin_pi_table(&a, &b) {
309 return v;
310 }
311 }
312 Sin(Box::new(x))
313 }
314
315 fn simplify_cos(&self, x: SymbolicExpr) -> SymbolicExpr {
316 if let Integer(n) = &x {
317 if n.is_zero() {
318 return SymbolicExpr::int(1);
319 }
320 }
321 if let Some((a, b)) = as_pi_multiple(&x) {
322 let (a, b) = reduce(a, b);
323 if let Some(v) = cos_pi_table(&a, &b) {
324 return v;
325 }
326 }
327 Cos(Box::new(x))
328 }
329
330 fn simplify_exp(&self, x: SymbolicExpr) -> SymbolicExpr {
331 if let Integer(n) = &x {
332 if n.is_zero() {
333 return SymbolicExpr::int(1);
334 }
335 }
336 Exp(Box::new(x))
337 }
338
339 fn simplify_ln(&self, x: SymbolicExpr) -> SymbolicExpr {
340 if let Integer(n) = &x {
341 if n.is_one() {
342 return SymbolicExpr::int(0);
343 }
344 }
345 Ln(Box::new(x))
346 }
347}
348
349fn normalize_rational(mut p: BigInt, mut q: BigInt) -> SymbolicExpr {
351 if q.is_zero() {
352 return Rational(p, q); }
354 if q.is_negative() {
355 p = -p;
356 q = -q;
357 }
358 let g = p.gcd(&q);
359 if !g.is_zero() {
360 p /= &g;
361 q /= &g;
362 }
363 if q.is_one() {
364 Integer(p)
365 } else {
366 Rational(p, q)
367 }
368}
369
370fn reduce(mut a: BigInt, mut b: BigInt) -> (BigInt, BigInt) {
372 if b.is_negative() {
373 a = -a;
374 b = -b;
375 }
376 let g = a.gcd(&b);
377 if !g.is_zero() {
378 a /= &g;
379 b /= &g;
380 }
381 (a, b)
382}
383
384fn simplify_sqrt(n: BigInt) -> SymbolicExpr {
386 if n.is_negative() || n.is_zero() {
387 return Sqrt { radicand: n };
388 }
389 let nu = match n.to_u128() {
390 Some(v) => v,
391 None => return Sqrt { radicand: n },
392 };
393 let mut square = 1u128;
394 let mut rad = nu;
395 let mut d = 2u128;
396 while d * d <= rad {
397 while rad % (d * d) == 0 {
398 rad /= d * d;
399 square *= d;
400 }
401 d += 1;
402 }
403 let s = BigInt::from(square);
404 let r = BigInt::from(rad);
405 if rad == 1 {
406 Integer(s)
407 } else if square == 1 {
408 Sqrt { radicand: r }
409 } else {
410 ScaledSqrt { coeff: (s, BigInt::one()), rad: r }
411 }
412}
413
414fn as_pi_multiple(expr: &SymbolicExpr) -> Option<(BigInt, BigInt)> {
416 match expr {
417 Pi => Some((BigInt::one(), BigInt::one())),
418 Mul(factors) => {
419 let mut num = BigInt::one();
420 let mut den = BigInt::one();
421 let mut pi_count = 0;
422 for f in factors {
423 match f {
424 Pi => pi_count += 1,
425 Integer(n) => num *= n,
426 Rational(p, q) => {
427 num *= p;
428 den *= q;
429 }
430 _ => return None,
431 }
432 }
433 if pi_count == 1 {
434 Some((num, den))
435 } else {
436 None
437 }
438 }
439 _ => None,
440 }
441}
442
443fn frac_is(a: &BigInt, b: &BigInt, n: i64, d: i64) -> bool {
444 *a == BigInt::from(n) && *b == BigInt::from(d)
445}
446
447fn sin_pi_table(a: &BigInt, b: &BigInt) -> Option<SymbolicExpr> {
449 if a.is_zero() {
450 return Some(SymbolicExpr::int(0));
451 }
452 if frac_is(a, b, 1, 1) {
453 return Some(SymbolicExpr::int(0)); }
455 if frac_is(a, b, 1, 6) {
456 return Some(SymbolicExpr::rational(1, 2));
457 }
458 if frac_is(a, b, 1, 4) {
459 return Some(ScaledSqrt { coeff: (BigInt::one(), BigInt::from(2)), rad: BigInt::from(2) });
460 }
461 if frac_is(a, b, 1, 3) {
462 return Some(ScaledSqrt { coeff: (BigInt::one(), BigInt::from(2)), rad: BigInt::from(3) });
463 }
464 if frac_is(a, b, 1, 2) {
465 return Some(SymbolicExpr::int(1)); }
467 None
468}
469
470fn cos_pi_table(a: &BigInt, b: &BigInt) -> Option<SymbolicExpr> {
472 if a.is_zero() {
473 return Some(SymbolicExpr::int(1));
474 }
475 if frac_is(a, b, 1, 1) {
476 return Some(SymbolicExpr::int(-1)); }
478 if frac_is(a, b, 1, 2) {
479 return Some(SymbolicExpr::int(0)); }
481 None
482}
483
#[cfg(test)]
mod tests {
    use super::*;

    /// Simplify `expr` with the standard identity graph.
    fn simp(expr: SymbolicExpr) -> SymbolicExpr {
        IdentityGraph::standard().simplify(expr)
    }

    #[test]
    fn sin_pi_is_zero() {
        let got = simp(SymbolicExpr::sin(Pi));
        assert_eq!(got, SymbolicExpr::int(0));
    }

    #[test]
    fn cos_pi_is_minus_one() {
        let got = simp(SymbolicExpr::cos(Pi));
        assert_eq!(got, SymbolicExpr::int(-1));
    }

    #[test]
    fn sin_pi_over_six() {
        let angle = Mul(vec![SymbolicExpr::rational(1, 6), Pi]);
        assert_eq!(simp(SymbolicExpr::sin(angle)), SymbolicExpr::rational(1, 2));
    }

    #[test]
    fn exp_zero_is_one() {
        let zero = SymbolicExpr::int(0);
        assert_eq!(simp(SymbolicExpr::exp(zero)), SymbolicExpr::int(1));
    }

    #[test]
    fn ln_one_is_zero() {
        let one = SymbolicExpr::int(1);
        assert_eq!(simp(SymbolicExpr::ln(one)), SymbolicExpr::int(0));
    }

    #[test]
    fn sqrt_times_sqrt() {
        let product = Mul(vec![SymbolicExpr::sqrt(2), SymbolicExpr::sqrt(2)]);
        assert_eq!(simp(product), SymbolicExpr::int(2));
    }

    #[test]
    fn x_times_zero() {
        let product = Mul(vec![Pi, SymbolicExpr::int(0)]);
        assert_eq!(simp(product), SymbolicExpr::int(0));
    }

    #[test]
    fn add_zero_identity() {
        let sum = Add(vec![Pi, SymbolicExpr::int(0)]);
        assert_eq!(simp(sum), Pi);
    }

    #[test]
    fn mul_one_identity() {
        let product = Mul(vec![Pi, SymbolicExpr::int(1)]);
        assert_eq!(simp(product), Pi);
    }

    #[test]
    fn sqrt_eight_simplifies() {
        let expected = ScaledSqrt {
            coeff: (BigInt::from(2), BigInt::one()),
            rad: BigInt::from(2),
        };
        assert_eq!(simplify_sqrt(BigInt::from(8)), expected);
    }

    #[test]
    fn classification() {
        let cases = [
            (SymbolicExpr::int(3), TowerLevel::Integer),
            (SymbolicExpr::sqrt(2), TowerLevel::Algebraic),
            (Pi, TowerLevel::Transcendental),
        ];
        for (expr, level) in &cases {
            assert_eq!(tower_level(expr), *level);
        }
    }
}