1use super::*;
2use itertools::Itertools;
3use num_traits::ToPrimitive;
4
5pub mod macros;
6
7pub mod pow;
8pub use pow::*;
9
10pub mod eq;
11pub use eq::*;
12
13pub mod add;
14pub use add::*;
15
16pub mod mul;
17pub use mul::*;
18
19pub mod function;
20pub use function::*;
21
22pub mod diff;
23pub use diff::*;
24
25pub mod integral;
26pub use integral::*;
27
28pub mod symbol;
29use schemars::{JsonSchema, json_schema};
30pub use symbol::*;
31
32pub mod integer;
33pub use integer::*;
34
35pub mod rational;
36pub use rational::*;
37
38pub mod ops;
/// Owned, type-erased expression; shorthand for `Box<dyn Expr>`.
pub type BoxExpr = Box<dyn Expr>;
42
43use std::{
44 any::Any,
45 cmp::{Ordering, PartialEq},
46 collections::HashMap,
47 fmt::{self},
48 hash::Hash,
49 iter,
50};
51
52pub trait Arg: Any {
53 fn srepr(&self) -> String;
54
55 fn clone_arg(&self) -> Box<dyn Arg>;
56
57 fn as_expr(&self) -> Option<Box<dyn Expr>> {
58 None
59 }
60
61 fn map_expr(&self, f: &dyn (Fn(&dyn Expr) -> Box<dyn Expr>)) -> Box<dyn Arg> {
62 if let Some(expr) = self.as_expr() {
63 f(expr.get_ref())
64 } else {
65 self.clone_arg()
66 }
67 }
68}
69
70pub trait ArgOperations {}
71
72impl<A: Arg> ArgOperations for A {}
73
74pub trait AsAny {
75 fn as_any(&self) -> &dyn Any;
76}
77
78impl AsAny for Box<dyn Arg> {
96 fn as_any(&self) -> &dyn Any {
97 let any = (&**self) as &dyn Any;
98 any
99 }
100}
101
102impl AsAny for Box<dyn Expr> {
103 fn as_any(&self) -> &dyn Any {
104 let any = (&**self) as &dyn Any;
105 any
106 }
107}
108
109impl fmt::Debug for Box<dyn Arg> {
110 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
111 write!(f, "{}", self.srepr())
112 }
113}
114impl JsonSchema for dyn Expr {
115 fn schema_name() -> std::borrow::Cow<'static, str> {
116 "Expression".into()
117 }
118
119 fn schema_id() -> std::borrow::Cow<'static, str> {
120 concat!(module_path!(), "::Expression").into()
121 }
122
123 fn json_schema(_generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
124 json_schema!({
125 "title": "Expression",
126 "description": "A symbolic expression.",
127 "oneOf": [
128 {
129 "type": "string"
130 },
131 {
132 "type": "number"
133 }
134 ]
135 })
136 }
137
138 fn inline_schema() -> bool {
139 true
140 }
141}
142
143impl Arg for isize {
144 fn srepr(&self) -> String {
145 self.to_string()
146 }
147
148 fn clone_arg(&self) -> Box<dyn Arg> {
149 Box::new(self.clone())
150 }
151}
152
153impl Arg for String {
154 fn srepr(&self) -> String {
155 self.clone()
156 }
157
158 fn clone_arg(&self) -> Box<dyn Arg> {
159 Box::new(self.clone())
160 }
161}
162
163impl<A: Arg + Clone, B: Arg + Clone> Arg for (A, B) {
164 fn srepr(&self) -> String {
165 format!("({}, {})", self.0.srepr(), self.1.srepr())
166 }
167
168 fn clone_arg(&self) -> Box<dyn Arg> {
169 Box::new((self.0.clone(), self.1.clone()))
170 }
171}
172
173impl std::fmt::Debug for &dyn Arg {
174 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
175 write!(f, "{}", self.srepr())
176 }
177}
178
179impl Arg for Box<dyn Expr> {
180 fn srepr(&self) -> String {
181 (**self).srepr()
182 }
183
184 fn clone_arg(&self) -> Box<dyn Arg> {
185 self.clone_box().into()
186 }
187
188 fn as_expr(&self) -> Option<Box<dyn Expr>> {
189 Some(self.clone_box())
190 }
191}
192
193impl Arg for usize {
194 fn srepr(&self) -> String {
195 self.to_string()
196 }
197
198 fn clone_arg(&self) -> Box<dyn Arg> {
199 Box::new(self.clone())
200 }
201}
202
203impl<A: Arg + Clone> Arg for Vec<A> {
204 fn srepr(&self) -> String {
205 let args = self
206 .iter()
207 .map(|arg| arg.srepr())
208 .collect::<Vec<String>>()
209 .join(", ");
210 format!("({})", args)
211 }
212
213 fn clone_arg(&self) -> Box<dyn Arg> {
214 Box::new(self.clone())
215 }
216}
217
218impl Clone for Box<dyn Arg> {
234 fn clone(&self) -> Self {
235 self.clone_arg()
236 }
237}
238
239impl From<Box<dyn Expr>> for Box<dyn Arg> {
240 fn from(value: Box<dyn Expr>) -> Self {
241 Box::new(value.clone())
242 }
243}
244
245impl From<Box<dyn Arg>> for Box<dyn Expr> {
246 fn from(value: Box<dyn Arg>) -> Self {
247 value.as_expr().expect("This arg is not an expr")
248 }
249}
250
251impl<T> Arg for T
252where
253 T: Expr,
254{
255 fn srepr(&self) -> String {
256 let args = self
257 .args()
258 .iter()
259 .map(|arg| arg.srepr())
260 .collect::<Vec<String>>()
261 .join(", ");
262 format!("{}({})", self.name(), args)
263 }
264
265 fn clone_arg(&self) -> Box<dyn Arg> {
266 self.as_arg()
267 }
268}
269
270impl FromIterator<Box<dyn Expr>> for Vec<Box<dyn Arg>> {
271 fn from_iter<T: IntoIterator<Item = Box<dyn Expr>>>(iter: T) -> Self {
272 let mut res = Vec::new();
273 for expr in iter {
274 res.push(expr.clone_box().into())
275 }
276 res
277 }
278}
279
280impl FromIterator<Box<dyn Arg>> for Vec<Box<dyn Expr>> {
281 fn from_iter<T: IntoIterator<Item = Box<dyn Arg>>>(iter: T) -> Self {
282 let mut res = Vec::new();
283 for arg in iter {
284 res.push(arg.clone().into())
286 }
291 res
292 }
293}
294
295pub trait Expr: Arg + Sync + Send {
302 fn as_expr(&self) -> Option<Box<dyn Expr>> {
303 Some(self.clone_box())
304 }
305 fn args(&self) -> Vec<Box<dyn Arg>> {
306 let mut res = Vec::new();
307 self.for_each_arg(&mut |a| res.push(a.clone_arg()));
308 res
309 }
310
311 fn for_each_arg(&self, f: &mut dyn FnMut(&dyn Arg) -> ());
312
313 fn args_map_exprs(&self, _f: &dyn Fn(&dyn Expr) -> Box<dyn Arg>) -> Vec<Box<dyn Arg>> {
314 todo!("Doesn't work at the moment");
315 }
327
328 fn from_args(&self, args: Vec<Box<dyn Arg>>) -> Box<dyn Expr> {
329 panic!(
330 "from_args not implemented for {}, supplied args :\n{:#?}",
331 self.name(),
332 &args
333 )
334 }
335
336 fn as_arg(&self) -> Box<dyn Arg> {
337 self.clone_box().into()
338 }
339
340 fn as_function(&self) -> Option<&Func> {
341 None
342 }
343 fn equals(&self, other: &dyn Expr) -> bool {
344 self.srepr() == other.srepr()
345 }
346 fn clone_box(&self) -> Box<dyn Expr>;
347
348 fn as_symbol(&self) -> Option<Symbol> {
349 let res = self.clone_box();
350 match KnownExpr::from_expr_box(&res) {
351 KnownExpr::Symbol(symbol) => Some(symbol.clone()),
352 _ => None,
353 }
354 }
355
356 fn as_eq(&self) -> Option<Equation> {
357 let res = self.clone_box();
358 match KnownExpr::from_expr_box(&res) {
359 KnownExpr::Eq(eq) => Some(eq.clone()),
360 _ => None,
361 }
362 }
363
364 fn as_mul(&self) -> Option<&Mul> {
365 None
366 }
367
368 fn as_pow(&self) -> Option<&Pow> {
369 None
370 }
371
372 fn as_f64(&self) -> Option<f64> {
373 None
374 }
375
376 fn to_cpp(&self) -> String {
377 self.str()
378 }
379
380 fn simplify(&self) -> Box<dyn Expr> {
381 self.from_args(
382 self.args()
383 .iter()
384 .map(|a| a.map_expr(&|e| e.simplify()))
385 .collect(),
386 )
387 }
388
389 fn simplify_with_dimension(&self, dim: usize) -> Box<dyn Expr> {
390 let expr = self.simplify();
391 expr.from_args(
392 expr.args()
393 .iter()
394 .map(|a| a.map_expr(&|e| e.simplify_with_dimension(dim)))
395 .collect(),
396 )
397 }
398
399 fn as_int(&self) -> Option<Integer> {
400 let res = self.clone_box();
401 match KnownExpr::from_expr_box(&res) {
402 KnownExpr::Integer(i) => Some(i.clone()),
403 _ => None,
404 }
405 }
406
407 fn str(&self) -> String;
408
409 fn pow(&self, exponent: &Box<dyn Expr>) -> Box<dyn Expr> {
410 Pow::pow(self.clone_box(), exponent.clone())
411 }
412
413 fn ipow(&self, exponent: isize) -> Box<dyn Expr> {
414 Pow::pow(self.clone_box(), Integer::new_box(exponent))
415 }
416
417 fn sqrt(&self) -> Box<dyn Expr> {
418 Pow::pow(self.clone_box(), Rational::new_box(1, 2))
419 }
420
421 fn get_exponent(&self) -> (Box<dyn Expr>, Box<dyn Expr>) {
422 (self.clone_box(), Integer::one_box())
423 }
424
425 fn diff(&self, var: &str, order: usize) -> Box<dyn Expr> {
426 Box::new(Diff::idiff(self.clone_box(), Symbol::new(var), order))
427 }
428
429 fn name(&self) -> String {
430 std::any::type_name_of_val(self)
431 .to_string()
432 .split("::")
433 .last()
434 .unwrap()
435 .to_string()
436 }
437
438 fn subs(&self, substitutions: &[[Box<dyn Expr>; 2]]) -> Box<dyn Expr> {
439 ops::subs(self, substitutions)
440 }
441
442 fn has(&self, expr: &dyn Expr) -> bool {
443 if self.equals(expr) {
444 true
445 } else {
446 self.args()
447 .iter()
448 .filter_map(|a| a.as_expr())
449 .any(|e| e.has(expr))
450 }
451 }
452
453 fn has_box(&self, expr: Box<dyn Expr>) -> bool {
454 self.has(&*expr)
455 }
456
457 fn expand(&self) -> Box<dyn Expr> {
461 self.from_args(
462 self.args()
463 .iter()
464 .map(|a| {
465 if let Some(expr) = a.as_expr() {
466 expr.expand()
467 } else {
468 a.clone()
469 }
470 })
471 .collect(),
472 )
473 }
474
475 fn factor(&self, factors: &[&dyn Expr]) -> Box<dyn Expr> {
479 ops::factor(self, factors)
480 }
481
482 fn is_one(&self) -> bool {
483 false
484 }
485 fn is_neg_one(&self) -> bool {
486 false
487 }
488
489 fn is_number(&self) -> bool {
490 false
491 }
492
493 fn is_negative_number(&self) -> bool {
494 false }
496
497 fn is_zero(&self) -> bool {
498 false
499 }
500
501 fn known_expr(&self) -> KnownExpr {
502 KnownExpr::Unknown
503 }
504
505 fn terms<'a>(&'a self) -> Box<dyn Iterator<Item = &'a dyn Expr> + 'a> {
506 Box::new(iter::once(self.get_ref()))
507 }
508
509 fn get_ref<'a>(&'a self) -> &'a dyn Expr;
510
511 fn get_coeff(&self) -> (Rational, Box<dyn Expr>) {
512 match KnownExpr::from_expr(self.get_ref()) {
513 KnownExpr::Integer(i) => i.into(),
514 KnownExpr::Pow(pow) => {
515 let (pow_coeff, pow_expr) = pow.base().get_coeff();
516
517 if pow_coeff.is_one() {
518 return (Rational::one(), pow_expr.pow(&pow.exponent().clone_box()));
519 }
520 let coeff_box = (pow_coeff).pow(&(pow.exponent()).clone_box());
521
522 match coeff_box.known_expr() {
523 KnownExpr::Integer(i) => return (i.into(), pow_expr),
524 KnownExpr::Rational(r) => return (*r, pow_expr),
525 KnownExpr::Pow(Pow {
526 base: coeff_base,
527 exponent: _,
528 }) => {
529 return (
530 Rational::one(),
531 Pow::pow(
532 coeff_base.clone_box() * pow_expr,
533 pow.exponent().clone_box(),
534 ),
535 );
536 }
537 _ => panic!("help: {:?}", coeff_box),
548 }
549
550 }
559 KnownExpr::Rational(r) => r.into(),
563 KnownExpr::Mul(Mul { operands }) => {
564 let mut coeff = Rational::one();
565 let mut expr = Integer::new_box(1);
566
567 operands
568 .iter()
569 .for_each(|op| match KnownExpr::from_expr_box(op) {
570 KnownExpr::Integer(i) => coeff *= i,
571 KnownExpr::Rational(r) => coeff *= r,
572 KnownExpr::Pow(Pow { base, exponent }) if exponent.is_neg_one() => {
573 let (pow_coeff, pow_expr) = base.get_coeff();
574 coeff /= pow_coeff;
575 expr *= (Pow {
576 base: pow_expr,
577 exponent: Integer::new_box(-1),
578 })
579 .get_ref();
580 }
581 _ => expr *= op,
582 });
583
584 (coeff, expr)
585 }
586 _ => (Rational::one(), self.clone_box()),
587 }
588 }
589
590 fn compare(&self, other: &dyn Expr) -> Option<Ordering> {
591 ops::compare(self, other)
592 }
593
594 fn evaluate(&self, vars: Option<HashMap<Symbol, BoxExpr>>) -> BoxExpr {
595 let substs = if let Some(vars) = vars {
596 vars.iter()
597 .map(|(s, v)| [s.clone_box(), v.clone_box()])
598 .collect_vec()
599 } else {
600 Vec::new()
601 };
602
603 let expr = self.subs(&substs).expand().simplify();
604 expr
605 }
606}
607
608impl std::hash::Hash for dyn Expr {
609 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
610 self.srepr().hash(state);
611 }
612}
613
614impl From<&Integer> for (Rational, Box<dyn Expr>) {
615 fn from(i: &Integer) -> Self {
616 (i.into(), Integer::new_box(1))
617 }
618}
619
620impl From<&Rational> for (Rational, Box<dyn Expr>) {
621 fn from(r: &Rational) -> Self {
622 (r.clone(), Integer::new_box(1))
623 }
624}
625
626impl ToPrimitive for Integer {
627 fn to_i64(&self) -> Option<i64> {
628 Some(self.value.try_into().unwrap())
629 }
630
631 fn to_u64(&self) -> Option<u64> {
632 Some(self.value.try_into().unwrap())
633 }
634}
635
636impl ToPrimitive for &Integer {
637 fn to_i64(&self) -> Option<i64> {
638 (*self).to_i64()
639 }
640
641 fn to_u64(&self) -> Option<u64> {
642 (*self).to_u64()
643 }
644}
// Gives borrowed expression trait objects the `ExprOperations` extension trait.
// NOTE(review): `subs_refs` requires `Self: Expr`, and no `Expr` impl for
// `&dyn Expr` is visible in this file, so no method appears callable through
// this impl — confirm it is still needed.
impl ExprOperations for &dyn Expr {}
646
// Unit tests for the generic `Expr` machinery: `has`, `expand`, `get_coeff`,
// hashing and square roots.
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    // `has` finds sub-expressions structurally. The test expects `y + x` to be
    // found but `x + y` not, i.e. addition is assumed to keep operand order.
    #[test]
    fn check_has() {
        let x = &Symbol::new("x");
        let y = &Symbol::new("y");
        let expr = Equation::new(x, (y + x).get_ref());
        assert!(expr.has(x));
        assert!(expr.has(y));
        assert!(expr.has((y + x).get_ref()));
        assert!(!expr.has((x + y).get_ref()));
        assert!(!expr.has(&Symbol::new("z")));
    }

    // (x + y) * z distributes to x*z + y*z.
    #[test]
    fn test_expand_simple() {
        let x = &Symbol::new("x") as &dyn Expr;
        let y = &Symbol::new("y") as &dyn Expr;
        let z = &Symbol::new("z") as &dyn Expr;

        let expr = (x + y) * z;
        let expected = x * z + y * z;

        assert!((expr.expand()).equals(&*expected))
    }
    // A leading integer factor is distributed into every term.
    #[test]
    fn test_expand_with_first_arg_int() {
        let x = &Symbol::new("x");
        let y = &Symbol::new("y");
        let z = &Symbol::new("z");
        let i2 = &Integer::new(2);

        let expr = i2 * &(x + y) * z;
        let expected = i2 * x * z + i2 * y * z;

        assert!((expr.expand()).equals(&*expected))
    }
    // Product of two sums expands into all four cross terms, in order.
    #[test]
    fn test_expand_complex() {
        let x = &Symbol::new("x");
        let y = &Symbol::new("y");
        let w = &Symbol::new("w");
        let z = &Symbol::new("z");
        let i2 = &Integer::new(2);

        let expr = i2 * &(x + y) * &(w + z);
        let expected = i2 * x * w + i2 * x * z + i2 * y * w + i2 * y * z;

        assert!((expr.expand()).equals(&*(expected)))
    }

    // An integer is its own coefficient, with remaining factor 1.
    #[test]
    fn test_get_coeff_trivial() {
        assert_eq!(
            Integer::new(1).get_coeff(),
            (Rational::one(), Integer::new_box(1))
        );
    }
    // 1 / 2 has coefficient 1/2 and remaining factor 1.
    #[test]
    fn test_get_coeff_basic() {
        let expr = Integer::new(1).get_ref() / Integer::new(2).get_ref();
        assert_eq!(expr.get_coeff(), (Rational::new(1, 2), Integer::new_box(1)));
    }
    // 5 / 7 has coefficient 5/7 and remaining factor 1.
    #[test]
    fn test_get_coeff_basic_2() {
        let num = &Integer::new(5);
        let denom = &Integer::new(7);
        let expr = num.get_ref() / denom.get_ref();

        assert_eq!(expr.get_coeff(), (Rational::new(5, 7), Integer::new_box(1)));
    }

    // x * 5 / 7 splits into coefficient 5/7 and factor x.
    #[test]
    fn test_get_coeff_normal() {
        let x = &Symbol::new("x");
        let num = &Integer::new(5);
        let denom = &Integer::new(7);
        let expr = x * num / denom;

        assert_eq!(expr.get_coeff(), (Rational::new(5, 7), x.clone_box()));
    }

    // Two separately built `x` boxes hash and compare by `srepr`, so the set
    // keeps only one of them.
    #[test]
    fn test_check_hashing_works() {
        let mut set = HashSet::with_capacity(2);
        let x = &Symbol::new_box("x");
        let x_bis = &Symbol::new_box("x");

        set.insert(x);
        set.insert(x_bis);

        assert_eq!(set.len(), 1)
    }

    // sqrt is represented as a power with exponent 1/2.
    #[test]
    fn test_check_sqrt() {
        let x = &Integer::new(2).sqrt();

        assert_eq!(x.srepr(), "Pow(Integer(2), Rational(1, 2))")
    }

    // Relies on the power node overriding `get_exponent`; the trait default
    // would return `(self, 1)`.
    #[test]
    fn test_get_sqrt_exponent() {
        let sqrt_2 = &Integer::new(2).sqrt();

        assert_eq!(
            sqrt_2.get_exponent(),
            (Integer::new_box(2), Rational::new_box(1, 2))
        )
    }

    // sqrt(2) * sqrt(2) is expected to collapse to 2 during multiplication,
    // without an explicit `simplify` call.
    #[test]
    fn test_check_sqrt_simplifies() {
        let x = &Integer::new(2).sqrt();

        let expr = x * x;

        assert_eq!(&expr, &Integer::new_box(2))
    }

    // sqrt(2) has no rational coefficient to extract: coefficient 1, factor sqrt(2).
    #[test]
    fn test_check_coeff_sqrt_2() {
        let sqrt_2 = &Integer::new(2).sqrt();

        assert_eq!(sqrt_2.get_coeff(), (Rational::one(), sqrt_2.clone_box()))
    }
}
780
781pub struct ExprWrapper<'a, E: Expr> {
799 expr: &'a E,
800}
801
802impl<'a, E: Expr> std::cmp::PartialEq for ExprWrapper<'a, E> {
803 fn eq(&self, other: &Self) -> bool {
804 self.expr.srepr() == other.expr.srepr()
805 }
806}
807
808impl<'a, E: Expr> std::cmp::Eq for ExprWrapper<'a, E> {}
809
810impl<'a, E: Expr> ExprWrapper<'a, E> {
811 pub fn new(expr: &'a E) -> Self {
812 ExprWrapper { expr }
813 }
814}
815impl std::cmp::PartialEq for &dyn Expr {
816 fn eq(&self, other: &Self) -> bool {
817 self.srepr() == other.srepr()
818 }
819}
820
821impl std::cmp::PartialEq<Box<dyn Expr>> for &dyn Expr {
822 fn eq(&self, other: &Box<dyn Expr>) -> bool {
823 *self == &**other
824 }
825}
826
827impl std::cmp::PartialEq<&dyn Expr> for Box<dyn Expr> {
828 fn eq(&self, other: &&dyn Expr) -> bool {
829 self.get_ref() == *other
830 }
831}
832
833impl std::cmp::PartialEq<&Box<dyn Expr>> for &dyn Expr {
834 fn eq(&self, other: &&Box<dyn Expr>) -> bool {
835 *self == other.get_ref()
836 }
837}
838
839impl fmt::Debug for &dyn Expr {
840 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
841 write!(f, "{} [{}])", self.str(), self.srepr())
842 }
843}
844
// `&dyn Expr` equality compares `srepr` strings, which is reflexive,
// symmetric and transitive, so the `Eq` marker is sound.
impl std::cmp::Eq for &dyn Expr {}
846
847pub trait ExprOperations {
876 fn subs_refs<'a, Iter: IntoIterator<Item = [&'a dyn Expr; 2]>>(
877 &self,
878 substitutions: Iter,
879 ) -> Box<dyn Expr>
880 where
881 Self: Expr,
882 {
883 for [replaced, replacement] in substitutions.into_iter() {
884 if self.srepr() == replaced.srepr() {
885 return replacement.clone_box();
886 }
887 }
888 todo!()
889 }
890
891 }
925
// Every concrete (sized) expression type gets the `ExprOperations` extension methods.
impl<T> ExprOperations for T where T: Expr {}
927impl fmt::Display for &dyn Expr {
930 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
931 write!(f, "{}", self.str())
932 }
933}
934
935impl PartialEq for Box<dyn Expr> {
936 fn eq(&self, other: &Self) -> bool {
937 self.srepr() == other.srepr()
938 }
939}
940
941impl std::cmp::Eq for Box<dyn Expr> {}
942
943impl std::fmt::Debug for Box<dyn Expr> {
944 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
945 write!(f, "{:?}", self.get_ref())
946 }
947}
948
949impl std::fmt::Display for Box<dyn Expr> {
950 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
951 write!(f, "{}", self.str())
952 }
953}
954
955impl Clone for Box<dyn Expr> {
956 fn clone(&self) -> Self {
957 self.clone_box()
958 }
959}
960
961impl std::ops::Neg for &dyn Expr {
962 type Output = Box<dyn Expr>;
963
964 fn neg(self) -> Self::Output {
965 Integer::new_box(-1) * self
966 }
967}
968
969impl std::ops::Neg for &Box<dyn Expr> {
970 type Output = Box<dyn Expr>;
971
972 fn neg(self) -> Self::Output {
973 -&**self
974 }
975}
976
977impl std::ops::Neg for Box<dyn Expr> {
978 type Output = Box<dyn Expr>;
979
980 fn neg(self) -> Self::Output {
981 -&*self
982 }
983}