1use num_traits::Float;
117use std::collections::HashMap;
118use std::fmt::Debug;
119use std::ops::{Add, Div, Mul, Neg, Sub};
120use std::sync::{Arc, RwLock};
121
/// Marker trait for scalar types that can flow through the expression interpreters.
///
/// A type qualifies when it can be cloned, has a default value, is thread-safe
/// (`Send + Sync`), owns all of its data (`'static`) and can be printed with both
/// `Display` and `Debug`. Every such type gets this trait automatically through the
/// blanket implementation that follows.
pub trait NumericType:
    std::fmt::Debug + std::fmt::Display + Clone + Default + Send + Sync + 'static
{
}

// Blanket implementation: any type meeting the supertrait bounds is a `NumericType`.
impl<T> NumericType for T where
    T: std::fmt::Debug + std::fmt::Display + Clone + Default + Send + Sync + 'static
{
}
134
135pub trait MathExpr {
139 type Repr<T>;
141
142 fn constant<T: NumericType>(value: T) -> Self::Repr<T>;
144
145 fn var<T: NumericType>(name: &str) -> Self::Repr<T>;
147
148 fn var_by_index<T: NumericType>(index: usize) -> Self::Repr<T>;
150
151 fn add<L, R, Output>(left: Self::Repr<L>, right: Self::Repr<R>) -> Self::Repr<Output>
154 where
155 L: NumericType + Add<R, Output = Output>,
156 R: NumericType,
157 Output: NumericType;
158
159 fn sub<L, R, Output>(left: Self::Repr<L>, right: Self::Repr<R>) -> Self::Repr<Output>
161 where
162 L: NumericType + Sub<R, Output = Output>,
163 R: NumericType,
164 Output: NumericType;
165
166 fn mul<L, R, Output>(left: Self::Repr<L>, right: Self::Repr<R>) -> Self::Repr<Output>
168 where
169 L: NumericType + Mul<R, Output = Output>,
170 R: NumericType,
171 Output: NumericType;
172
173 fn div<L, R, Output>(left: Self::Repr<L>, right: Self::Repr<R>) -> Self::Repr<Output>
175 where
176 L: NumericType + Div<R, Output = Output>,
177 R: NumericType,
178 Output: NumericType;
179
180 fn pow<T: NumericType + Float>(base: Self::Repr<T>, exp: Self::Repr<T>) -> Self::Repr<T>;
182
183 fn neg<T: NumericType + Neg<Output = T>>(expr: Self::Repr<T>) -> Self::Repr<T>;
185
186 fn ln<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T>;
188
189 fn exp<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T>;
191
192 fn sqrt<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T>;
194
195 fn sin<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T>;
197
198 fn cos<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T>;
200}
201
202pub mod polynomial {
208 use super::{MathExpr, NumericType};
209 use std::ops::{Add, Mul, Sub};
210
211 pub fn horner<E: MathExpr, T>(coeffs: &[T], x: E::Repr<T>) -> E::Repr<T>
239 where
240 T: NumericType + Clone + Add<Output = T> + Mul<Output = T>,
241 E::Repr<T>: Clone,
242 {
243 if coeffs.is_empty() {
244 return E::constant(T::default());
245 }
246
247 if coeffs.len() == 1 {
248 return E::constant(coeffs[0].clone());
249 }
250
251 let mut result = E::constant(coeffs[coeffs.len() - 1].clone());
253
254 for coeff in coeffs.iter().rev().skip(1) {
256 result = E::add(E::mul(result, x.clone()), E::constant(coeff.clone()));
257 }
258
259 result
260 }
261
262 pub fn horner_expr<E: MathExpr, T>(coeffs: &[E::Repr<T>], x: E::Repr<T>) -> E::Repr<T>
283 where
284 T: NumericType + Add<Output = T> + Mul<Output = T>,
285 E::Repr<T>: Clone,
286 {
287 if coeffs.is_empty() {
288 return E::constant(T::default());
289 }
290
291 if coeffs.len() == 1 {
292 return coeffs[0].clone();
293 }
294
295 let mut result = coeffs[coeffs.len() - 1].clone();
297
298 for coeff in coeffs.iter().rev().skip(1) {
300 result = E::add(E::mul(result, x.clone()), coeff.clone());
301 }
302
303 result
304 }
305
306 pub fn from_roots<E: MathExpr, T>(roots: &[T], x: E::Repr<T>) -> E::Repr<T>
324 where
325 T: NumericType + Clone + Sub<Output = T> + num_traits::One,
326 E::Repr<T>: Clone,
327 {
328 if roots.is_empty() {
329 return E::constant(num_traits::One::one());
330 }
331
332 let mut result = E::sub(x.clone(), E::constant(roots[0].clone()));
333
334 for root in roots.iter().skip(1) {
335 let factor = E::sub(x.clone(), E::constant(root.clone()));
336 result = E::mul(result, factor);
337 }
338
339 result
340 }
341
342 pub fn horner_derivative<E: MathExpr, T>(coeffs: &[T], x: E::Repr<T>) -> E::Repr<T>
361 where
362 T: NumericType + Clone + Add<Output = T> + Mul<Output = T> + num_traits::FromPrimitive,
363 E::Repr<T>: Clone,
364 {
365 if coeffs.len() <= 1 {
366 return E::constant(T::default());
367 }
368
369 let mut deriv_coeffs = Vec::with_capacity(coeffs.len() - 1);
371 for (i, coeff) in coeffs.iter().enumerate().skip(1) {
372 let power = num_traits::FromPrimitive::from_usize(i).unwrap_or_else(|| T::default());
374 deriv_coeffs.push(coeff.clone() * power);
375 }
376
377 horner::<E, T>(&deriv_coeffs, x)
378 }
379}
380
381pub struct DirectEval;
467
468impl DirectEval {
469 #[must_use]
472 pub fn var<T: NumericType>(name: &str, value: T) -> T {
473 value
474 }
475
476 #[must_use]
478 pub fn var_by_index<T: NumericType>(_index: usize, value: T) -> T {
479 value
480 }
481
482 #[must_use]
484 pub fn eval_with_vars<T: NumericType + Float + Copy>(expr: &ASTRepr<T>, variables: &[T]) -> T {
485 Self::eval_vars_optimized(expr, variables)
486 }
487
488 #[must_use]
490 pub fn eval_vars_optimized<T: NumericType + Float + Copy>(
491 expr: &ASTRepr<T>,
492 variables: &[T],
493 ) -> T {
494 match expr {
495 ASTRepr::Constant(value) => *value,
496 ASTRepr::Variable(index) => variables.get(*index).copied().unwrap_or_else(|| T::zero()),
497 ASTRepr::Add(left, right) => {
498 Self::eval_vars_optimized(left, variables)
499 + Self::eval_vars_optimized(right, variables)
500 }
501 ASTRepr::Sub(left, right) => {
502 Self::eval_vars_optimized(left, variables)
503 - Self::eval_vars_optimized(right, variables)
504 }
505 ASTRepr::Mul(left, right) => {
506 Self::eval_vars_optimized(left, variables)
507 * Self::eval_vars_optimized(right, variables)
508 }
509 ASTRepr::Div(left, right) => {
510 Self::eval_vars_optimized(left, variables)
511 / Self::eval_vars_optimized(right, variables)
512 }
513 ASTRepr::Pow(base, exp) => Self::eval_vars_optimized(base, variables)
514 .powf(Self::eval_vars_optimized(exp, variables)),
515 ASTRepr::Neg(inner) => -Self::eval_vars_optimized(inner, variables),
516 ASTRepr::Ln(inner) => Self::eval_vars_optimized(inner, variables).ln(),
517 ASTRepr::Exp(inner) => Self::eval_vars_optimized(inner, variables).exp(),
518 ASTRepr::Sin(inner) => Self::eval_vars_optimized(inner, variables).sin(),
519 ASTRepr::Cos(inner) => Self::eval_vars_optimized(inner, variables).cos(),
520 ASTRepr::Sqrt(inner) => Self::eval_vars_optimized(inner, variables).sqrt(),
521 }
522 }
523
524 #[must_use]
526 pub fn eval_two_vars(expr: &ASTRepr<f64>, x: f64, y: f64) -> f64 {
527 Self::eval_two_vars_fast(expr, x, y)
528 }
529
530 #[must_use]
532 pub fn eval_two_vars_fast(expr: &ASTRepr<f64>, x: f64, y: f64) -> f64 {
533 match expr {
534 ASTRepr::Constant(value) => *value,
535 ASTRepr::Variable(index) => match *index {
536 0 => x,
537 1 => y,
538 _ => 0.0, },
540 ASTRepr::Add(left, right) => {
541 Self::eval_two_vars_fast(left, x, y) + Self::eval_two_vars_fast(right, x, y)
542 }
543 ASTRepr::Sub(left, right) => {
544 Self::eval_two_vars_fast(left, x, y) - Self::eval_two_vars_fast(right, x, y)
545 }
546 ASTRepr::Mul(left, right) => {
547 Self::eval_two_vars_fast(left, x, y) * Self::eval_two_vars_fast(right, x, y)
548 }
549 ASTRepr::Div(left, right) => {
550 Self::eval_two_vars_fast(left, x, y) / Self::eval_two_vars_fast(right, x, y)
551 }
552 ASTRepr::Pow(base, exp) => {
553 Self::eval_two_vars_fast(base, x, y).powf(Self::eval_two_vars_fast(exp, x, y))
554 }
555 ASTRepr::Neg(inner) => -Self::eval_two_vars_fast(inner, x, y),
556 ASTRepr::Ln(inner) => Self::eval_two_vars_fast(inner, x, y).ln(),
557 ASTRepr::Exp(inner) => Self::eval_two_vars_fast(inner, x, y).exp(),
558 ASTRepr::Sin(inner) => Self::eval_two_vars_fast(inner, x, y).sin(),
559 ASTRepr::Cos(inner) => Self::eval_two_vars_fast(inner, x, y).cos(),
560 ASTRepr::Sqrt(inner) => Self::eval_two_vars_fast(inner, x, y).sqrt(),
561 }
562 }
563}
564
565impl MathExpr for DirectEval {
566 type Repr<T> = T;
567
568 fn constant<T: NumericType>(value: T) -> Self::Repr<T> {
569 value
570 }
571
572 fn var<T: NumericType>(name: &str) -> Self::Repr<T> {
573 T::default()
575 }
576
577 fn var_by_index<T: NumericType>(_index: usize) -> Self::Repr<T> {
578 T::default()
579 }
580
581 fn add<L, R, Output>(left: Self::Repr<L>, right: Self::Repr<R>) -> Self::Repr<Output>
582 where
583 L: NumericType + Add<R, Output = Output>,
584 R: NumericType,
585 Output: NumericType,
586 {
587 left + right
588 }
589
590 fn sub<L, R, Output>(left: Self::Repr<L>, right: Self::Repr<R>) -> Self::Repr<Output>
591 where
592 L: NumericType + Sub<R, Output = Output>,
593 R: NumericType,
594 Output: NumericType,
595 {
596 left - right
597 }
598
599 fn mul<L, R, Output>(left: Self::Repr<L>, right: Self::Repr<R>) -> Self::Repr<Output>
600 where
601 L: NumericType + Mul<R, Output = Output>,
602 R: NumericType,
603 Output: NumericType,
604 {
605 left * right
606 }
607
608 fn div<L, R, Output>(left: Self::Repr<L>, right: Self::Repr<R>) -> Self::Repr<Output>
609 where
610 L: NumericType + Div<R, Output = Output>,
611 R: NumericType,
612 Output: NumericType,
613 {
614 left / right
615 }
616
617 fn pow<T: NumericType + Float>(base: Self::Repr<T>, exp: Self::Repr<T>) -> Self::Repr<T> {
618 base.powf(exp)
619 }
620
621 fn neg<T: NumericType + Neg<Output = T>>(expr: Self::Repr<T>) -> Self::Repr<T> {
622 -expr
623 }
624
625 fn ln<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
626 expr.ln()
627 }
628
629 fn exp<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
630 expr.exp()
631 }
632
633 fn sqrt<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
634 expr.sqrt()
635 }
636
637 fn sin<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
638 expr.sin()
639 }
640
641 fn cos<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
642 expr.cos()
643 }
644}
645
646pub trait StatisticalExpr: MathExpr {
648 fn logistic<T: NumericType + Float>(x: Self::Repr<T>) -> Self::Repr<T> {
650 let one = Self::constant(T::one());
651 let neg_x = Self::neg(x);
652 let exp_neg_x = Self::exp(neg_x);
653 let denominator = Self::add(one, exp_neg_x);
654 Self::div(Self::constant(T::one()), denominator)
655 }
656
657 fn softplus<T: NumericType + Float>(x: Self::Repr<T>) -> Self::Repr<T> {
659 let one = Self::constant(T::one());
660 let exp_x = Self::exp(x);
661 let one_plus_exp_x = Self::add(one, exp_x);
662 Self::ln(one_plus_exp_x)
663 }
664
665 fn sigmoid<T: NumericType + Float>(x: Self::Repr<T>) -> Self::Repr<T> {
667 Self::logistic(x)
668 }
669}
670
671impl StatisticalExpr for DirectEval {}
673
/// Interpreter that renders expressions as fully parenthesised infix strings.
pub struct PrettyPrint;

impl PrettyPrint {
    /// Returns the textual form of a named variable, which is simply its name.
    #[must_use]
    pub fn var(name: &str) -> String {
        String::from(name)
    }
}
755
756impl MathExpr for PrettyPrint {
757 type Repr<T> = String;
758
759 fn constant<T: NumericType>(value: T) -> Self::Repr<T> {
760 format!("{value}")
761 }
762
763 fn var<T: NumericType>(name: &str) -> Self::Repr<T> {
764 name.to_string()
765 }
766
767 fn var_by_index<T: NumericType>(_index: usize) -> Self::Repr<T> {
768 T::default().to_string()
769 }
770
771 fn add<L, R, Output>(left: Self::Repr<L>, right: Self::Repr<R>) -> Self::Repr<Output>
772 where
773 L: NumericType + Add<R, Output = Output>,
774 R: NumericType,
775 Output: NumericType,
776 {
777 format!("({left} + {right})")
778 }
779
780 fn sub<L, R, Output>(left: Self::Repr<L>, right: Self::Repr<R>) -> Self::Repr<Output>
781 where
782 L: NumericType + Sub<R, Output = Output>,
783 R: NumericType,
784 Output: NumericType,
785 {
786 format!("({left} - {right})")
787 }
788
789 fn mul<L, R, Output>(left: Self::Repr<L>, right: Self::Repr<R>) -> Self::Repr<Output>
790 where
791 L: NumericType + Mul<R, Output = Output>,
792 R: NumericType,
793 Output: NumericType,
794 {
795 format!("({left} * {right})")
796 }
797
798 fn div<L, R, Output>(left: Self::Repr<L>, right: Self::Repr<R>) -> Self::Repr<Output>
799 where
800 L: NumericType + Div<R, Output = Output>,
801 R: NumericType,
802 Output: NumericType,
803 {
804 format!("({left} / {right})")
805 }
806
807 fn pow<T: NumericType + Float>(base: Self::Repr<T>, exp: Self::Repr<T>) -> Self::Repr<T> {
808 format!("({base} ^ {exp})")
809 }
810
811 fn neg<T: NumericType + Neg<Output = T>>(expr: Self::Repr<T>) -> Self::Repr<T> {
812 format!("(-{expr})")
813 }
814
815 fn ln<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
816 format!("ln({expr})")
817 }
818
819 fn exp<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
820 format!("exp({expr})")
821 }
822
823 fn sqrt<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
824 format!("sqrt({expr})")
825 }
826
827 fn sin<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
828 format!("sin({expr})")
829 }
830
831 fn cos<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
832 format!("cos({expr})")
833 }
834}
835
836impl StatisticalExpr for PrettyPrint {}
838
/// Expression tree over scalars of type `T`.
///
/// Binary nodes own both children and unary nodes own their single operand, all boxed.
#[derive(Debug, Clone, PartialEq)]
pub enum ASTRepr<T> {
    /// A literal value.
    Constant(T),
    /// A variable identified by its positional index.
    Variable(usize),
    /// Sum of the two children.
    Add(Box<ASTRepr<T>>, Box<ASTRepr<T>>),
    /// Difference: first child minus second child.
    Sub(Box<ASTRepr<T>>, Box<ASTRepr<T>>),
    /// Product of the two children.
    Mul(Box<ASTRepr<T>>, Box<ASTRepr<T>>),
    /// Quotient: first child divided by second child.
    Div(Box<ASTRepr<T>>, Box<ASTRepr<T>>),
    /// First child (base) raised to the second child (exponent).
    Pow(Box<ASTRepr<T>>, Box<ASTRepr<T>>),
    /// Negation.
    Neg(Box<ASTRepr<T>>),
    /// Natural logarithm.
    Ln(Box<ASTRepr<T>>),
    /// Natural exponential.
    Exp(Box<ASTRepr<T>>),
    /// Square root.
    Sqrt(Box<ASTRepr<T>>),
    /// Sine.
    Sin(Box<ASTRepr<T>>),
    /// Cosine.
    Cos(Box<ASTRepr<T>>),
}

impl<T> ASTRepr<T> {
    /// Counts every operation node (anything other than a constant or variable).
    pub fn count_operations(&self) -> usize {
        // Explicit work-list traversal; each popped operation node contributes one.
        let mut pending: Vec<&ASTRepr<T>> = vec![self];
        let mut total = 0;
        while let Some(node) = pending.pop() {
            match node {
                Self::Constant(_) | Self::Variable(_) => {}
                Self::Add(lhs, rhs)
                | Self::Sub(lhs, rhs)
                | Self::Mul(lhs, rhs)
                | Self::Div(lhs, rhs)
                | Self::Pow(lhs, rhs) => {
                    total += 1;
                    pending.push(&**lhs);
                    pending.push(&**rhs);
                }
                Self::Neg(arg)
                | Self::Ln(arg)
                | Self::Exp(arg)
                | Self::Sin(arg)
                | Self::Cos(arg)
                | Self::Sqrt(arg) => {
                    total += 1;
                    pending.push(&**arg);
                }
            }
        }
        total
    }

    /// Returns the index when this node is a bare `Variable`, otherwise `None`.
    pub fn variable_index(&self) -> Option<usize> {
        if let Self::Variable(index) = self {
            Some(*index)
        } else {
            None
        }
    }
}
918
919pub struct ASTEval;
925
926impl ASTEval {
927 #[must_use]
929 pub fn var<T: NumericType>(index: usize) -> ASTRepr<T> {
930 ASTRepr::Variable(index)
931 }
932
933 #[must_use]
936 pub fn var_by_name(_name: &str) -> ASTRepr<f64> {
937 ASTRepr::Variable(0)
939 }
940}
941
/// Monomorphic (`f64`-only) expression algebra with an interpreter-chosen `Repr`.
///
/// Unlike `MathExpr`, none of these operations is generic over the scalar type.
pub trait ASTMathExpr {
    /// The interpreter's expression representation.
    type Repr;

    /// Lifts an `f64` literal.
    fn constant(value: f64) -> Self::Repr;

    /// Refers to the variable at position `index`.
    fn var(index: usize) -> Self::Repr;

    /// `left + right`.
    fn add(left: Self::Repr, right: Self::Repr) -> Self::Repr;

    /// `left - right`.
    fn sub(left: Self::Repr, right: Self::Repr) -> Self::Repr;

    /// `left * right`.
    fn mul(left: Self::Repr, right: Self::Repr) -> Self::Repr;

    /// `left / right`.
    fn div(left: Self::Repr, right: Self::Repr) -> Self::Repr;

    /// `base` raised to `exp`.
    fn pow(base: Self::Repr, exp: Self::Repr) -> Self::Repr;

    /// Negation of `expr`.
    fn neg(expr: Self::Repr) -> Self::Repr;

    /// Square root of `expr`.
    fn sqrt(expr: Self::Repr) -> Self::Repr;

    /// Natural exponential of `expr`.
    fn exp(expr: Self::Repr) -> Self::Repr;

    /// Natural logarithm of `expr`.
    fn ln(expr: Self::Repr) -> Self::Repr;

    /// Sine of `expr`.
    fn sin(expr: Self::Repr) -> Self::Repr;

    /// Cosine of `expr`.
    fn cos(expr: Self::Repr) -> Self::Repr;
}
987
988impl ASTMathExpr for ASTEval {
989 type Repr = ASTRepr<f64>;
990
991 fn constant(value: f64) -> Self::Repr {
992 ASTRepr::Constant(value)
993 }
994
995 fn var(index: usize) -> Self::Repr {
996 ASTRepr::Variable(index)
997 }
998
999 fn add(left: Self::Repr, right: Self::Repr) -> Self::Repr {
1000 ASTRepr::Add(Box::new(left), Box::new(right))
1001 }
1002
1003 fn sub(left: Self::Repr, right: Self::Repr) -> Self::Repr {
1004 ASTRepr::Sub(Box::new(left), Box::new(right))
1005 }
1006
1007 fn mul(left: Self::Repr, right: Self::Repr) -> Self::Repr {
1008 ASTRepr::Mul(Box::new(left), Box::new(right))
1009 }
1010
1011 fn div(left: Self::Repr, right: Self::Repr) -> Self::Repr {
1012 ASTRepr::Div(Box::new(left), Box::new(right))
1013 }
1014
1015 fn pow(base: Self::Repr, exp: Self::Repr) -> Self::Repr {
1016 ASTRepr::Pow(Box::new(base), Box::new(exp))
1017 }
1018
1019 fn neg(expr: Self::Repr) -> Self::Repr {
1020 ASTRepr::Neg(Box::new(expr))
1021 }
1022
1023 fn ln(expr: Self::Repr) -> Self::Repr {
1024 ASTRepr::Ln(Box::new(expr))
1025 }
1026
1027 fn exp(expr: Self::Repr) -> Self::Repr {
1028 ASTRepr::Exp(Box::new(expr))
1029 }
1030
1031 fn sqrt(expr: Self::Repr) -> Self::Repr {
1032 ASTRepr::Sqrt(Box::new(expr))
1033 }
1034
1035 fn sin(expr: Self::Repr) -> Self::Repr {
1036 ASTRepr::Sin(Box::new(expr))
1037 }
1038
1039 fn cos(expr: Self::Repr) -> Self::Repr {
1040 ASTRepr::Cos(Box::new(expr))
1041 }
1042}
1043
1044impl MathExpr for ASTEval {
1047 type Repr<T> = ASTRepr<T>;
1048
1049 fn constant<T: NumericType>(value: T) -> Self::Repr<T> {
1050 ASTRepr::Constant(value)
1051 }
1052
1053 fn var<T: NumericType>(_name: &str) -> Self::Repr<T> {
1054 ASTRepr::Variable(0)
1056 }
1057
1058 fn var_by_index<T: NumericType>(index: usize) -> Self::Repr<T> {
1059 ASTRepr::Variable(index)
1060 }
1061
1062 fn add<L, R, Output>(_left: Self::Repr<L>, _right: Self::Repr<R>) -> Self::Repr<Output>
1063 where
1064 L: NumericType + Add<R, Output = Output>,
1065 R: NumericType,
1066 Output: NumericType,
1067 {
1068 unimplemented!("Use ASTMathExpr or ASTMathExprf64 for concrete implementations")
1071 }
1072
1073 fn sub<L, R, Output>(_left: Self::Repr<L>, _right: Self::Repr<R>) -> Self::Repr<Output>
1074 where
1075 L: NumericType + Sub<R, Output = Output>,
1076 R: NumericType,
1077 Output: NumericType,
1078 {
1079 unimplemented!("Use ASTMathExpr or ASTMathExprf64 for concrete implementations")
1080 }
1081
1082 fn mul<L, R, Output>(_left: Self::Repr<L>, _right: Self::Repr<R>) -> Self::Repr<Output>
1083 where
1084 L: NumericType + Mul<R, Output = Output>,
1085 R: NumericType,
1086 Output: NumericType,
1087 {
1088 unimplemented!("Use ASTMathExpr or ASTMathExprf64 for concrete implementations")
1089 }
1090
1091 fn div<L, R, Output>(_left: Self::Repr<L>, _right: Self::Repr<R>) -> Self::Repr<Output>
1092 where
1093 L: NumericType + Div<R, Output = Output>,
1094 R: NumericType,
1095 Output: NumericType,
1096 {
1097 unimplemented!("Use ASTMathExpr or ASTMathExprf64 for concrete implementations")
1098 }
1099
1100 fn pow<T: NumericType + Float>(base: Self::Repr<T>, exp: Self::Repr<T>) -> Self::Repr<T> {
1101 ASTRepr::Pow(Box::new(base), Box::new(exp))
1102 }
1103
1104 fn neg<T: NumericType + Neg<Output = T>>(expr: Self::Repr<T>) -> Self::Repr<T> {
1105 ASTRepr::Neg(Box::new(expr))
1106 }
1107
1108 fn ln<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
1109 ASTRepr::Ln(Box::new(expr))
1110 }
1111
1112 fn exp<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
1113 ASTRepr::Exp(Box::new(expr))
1114 }
1115
1116 fn sqrt<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
1117 ASTRepr::Sqrt(Box::new(expr))
1118 }
1119
1120 fn sin<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
1121 ASTRepr::Sin(Box::new(expr))
1122 }
1123
1124 fn cos<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
1125 ASTRepr::Cos(Box::new(expr))
1126 }
1127}
1128
1129impl StatisticalExpr for ASTEval {}
1130
/// `f64`-specialised expression algebra whose representation is chosen by the
/// implementing interpreter.
pub trait ASTMathExprf64 {
    /// Representation produced by the interpreter.
    type Repr;

    /// Wraps an `f64` literal.
    fn constant(value: f64) -> Self::Repr;

    /// Variable at position `index`.
    fn var(index: usize) -> Self::Repr;

    /// Sum of `left` and `right`.
    fn add(left: Self::Repr, right: Self::Repr) -> Self::Repr;

    /// `right` subtracted from `left`.
    fn sub(left: Self::Repr, right: Self::Repr) -> Self::Repr;

    /// Product of `left` and `right`.
    fn mul(left: Self::Repr, right: Self::Repr) -> Self::Repr;

    /// `left` divided by `right`.
    fn div(left: Self::Repr, right: Self::Repr) -> Self::Repr;

    /// `base` to the power `exp`.
    fn pow(base: Self::Repr, exp: Self::Repr) -> Self::Repr;

    /// Negated `expr`.
    fn neg(expr: Self::Repr) -> Self::Repr;

    /// Natural log of `expr`.
    fn ln(expr: Self::Repr) -> Self::Repr;

    /// `e` raised to `expr`.
    fn exp(expr: Self::Repr) -> Self::Repr;

    /// Square root of `expr`.
    fn sqrt(expr: Self::Repr) -> Self::Repr;

    /// Sine of `expr`.
    fn sin(expr: Self::Repr) -> Self::Repr;

    /// Cosine of `expr`.
    fn cos(expr: Self::Repr) -> Self::Repr;
}
1175
1176impl ASTMathExprf64 for ASTEval {
1177 type Repr = ASTRepr<f64>;
1178
1179 fn constant(value: f64) -> Self::Repr {
1180 ASTRepr::Constant(value)
1181 }
1182
1183 fn var(index: usize) -> Self::Repr {
1184 ASTRepr::Variable(index)
1185 }
1186
1187 fn add(left: Self::Repr, right: Self::Repr) -> Self::Repr {
1188 ASTRepr::Add(Box::new(left), Box::new(right))
1189 }
1190
1191 fn sub(left: Self::Repr, right: Self::Repr) -> Self::Repr {
1192 ASTRepr::Sub(Box::new(left), Box::new(right))
1193 }
1194
1195 fn mul(left: Self::Repr, right: Self::Repr) -> Self::Repr {
1196 ASTRepr::Mul(Box::new(left), Box::new(right))
1197 }
1198
1199 fn div(left: Self::Repr, right: Self::Repr) -> Self::Repr {
1200 ASTRepr::Div(Box::new(left), Box::new(right))
1201 }
1202
1203 fn pow(base: Self::Repr, exp: Self::Repr) -> Self::Repr {
1204 ASTRepr::Pow(Box::new(base), Box::new(exp))
1205 }
1206
1207 fn neg(expr: Self::Repr) -> Self::Repr {
1208 ASTRepr::Neg(Box::new(expr))
1209 }
1210
1211 fn ln(expr: Self::Repr) -> Self::Repr {
1212 ASTRepr::Ln(Box::new(expr))
1213 }
1214
1215 fn exp(expr: Self::Repr) -> Self::Repr {
1216 ASTRepr::Exp(Box::new(expr))
1217 }
1218
1219 fn sqrt(expr: Self::Repr) -> Self::Repr {
1220 ASTRepr::Sqrt(Box::new(expr))
1221 }
1222
1223 fn sin(expr: Self::Repr) -> Self::Repr {
1224 ASTRepr::Sin(Box::new(expr))
1225 }
1226
1227 fn cos(expr: Self::Repr) -> Self::Repr {
1228 ASTRepr::Cos(Box::new(expr))
1229 }
1230}
1231
1232pub trait RangeType: Clone + Send + Sync + 'static + std::fmt::Debug {
1241 type IndexType: NumericType;
1243
1244 fn start(&self) -> Self::IndexType;
1246
1247 fn end(&self) -> Self::IndexType;
1249
1250 fn contains(&self, value: &Self::IndexType) -> bool;
1252
1253 fn len(&self) -> Self::IndexType;
1255
1256 fn is_empty(&self) -> bool;
1258}
1259
/// Inclusive integer range `[start, end]`; empty when `end < start`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IntRange {
    /// First index (inclusive).
    pub start: i64,
    /// Last index (inclusive).
    pub end: i64,
}

impl IntRange {
    /// Creates the inclusive range `[start, end]`.
    #[must_use]
    pub fn new(start: i64, end: i64) -> Self {
        Self { start, end }
    }

    /// Creates `[1, n]`, which is empty when `n < 1`.
    #[must_use]
    pub fn one_to_n(n: i64) -> Self {
        Self::new(1, n)
    }

    /// Creates `[0, n - 1]`, i.e. the `n` indices `0..n`; empty when `n <= 0`.
    ///
    /// `n - 1` saturates, so `n == i64::MIN` yields an empty range instead of
    /// overflowing (a panic in debug builds, a huge range after wrapping in release).
    #[must_use]
    pub fn zero_to_n_minus_one(n: i64) -> Self {
        Self::new(0, n.saturating_sub(1))
    }

    /// Iterates over every integer in the range in ascending order.
    pub fn iter(&self) -> impl Iterator<Item = i64> {
        self.start..=self.end
    }
}
1305
1306impl RangeType for IntRange {
1307 type IndexType = i64;
1308
1309 fn start(&self) -> Self::IndexType {
1310 self.start
1311 }
1312
1313 fn end(&self) -> Self::IndexType {
1314 self.end
1315 }
1316
1317 fn contains(&self, value: &Self::IndexType) -> bool {
1318 *value >= self.start && *value <= self.end
1319 }
1320
1321 fn len(&self) -> Self::IndexType {
1322 if self.end >= self.start {
1323 self.end - self.start + 1
1324 } else {
1325 0
1326 }
1327 }
1328
1329 fn is_empty(&self) -> bool {
1330 self.end < self.start
1331 }
1332}
1333
/// Floating-point range from `start` to `end` (inclusive), sampled every `step`.
#[derive(Debug, Clone, PartialEq)]
pub struct FloatRange {
    /// Lower bound.
    pub start: f64,
    /// Upper bound (inclusive).
    pub end: f64,
    /// Spacing between consecutive sample points.
    pub step: f64,
}

impl FloatRange {
    /// Creates a range with an explicit `step`.
    #[must_use]
    pub fn new(start: f64, end: f64, step: f64) -> Self {
        FloatRange { start, end, step }
    }

    /// Creates a range whose step is `1.0`.
    #[must_use]
    pub fn unit_step(start: f64, end: f64) -> Self {
        Self {
            start,
            end,
            step: 1.0,
        }
    }
}
1358
1359impl RangeType for FloatRange {
1360 type IndexType = f64;
1361
1362 fn start(&self) -> Self::IndexType {
1363 self.start
1364 }
1365
1366 fn end(&self) -> Self::IndexType {
1367 self.end
1368 }
1369
1370 fn contains(&self, value: &Self::IndexType) -> bool {
1371 *value >= self.start && *value <= self.end
1372 }
1373
1374 fn len(&self) -> Self::IndexType {
1375 if self.end >= self.start && self.step > 0.0 {
1376 ((self.end - self.start) / self.step).floor() + 1.0
1377 } else {
1378 0.0
1379 }
1380 }
1381
1382 fn is_empty(&self) -> bool {
1383 self.end < self.start || self.step <= 0.0
1384 }
1385}
1386
1387#[derive(Debug, Clone)]
1404pub struct SymbolicRange<T> {
1405 pub start: Box<ASTRepr<T>>,
1406 pub end: Box<ASTRepr<T>>,
1407}
1408
1409impl<T: NumericType> SymbolicRange<T> {
1410 pub fn new(start: ASTRepr<T>, end: ASTRepr<T>) -> Self {
1412 Self {
1413 start: Box::new(start),
1414 end: Box::new(end),
1415 }
1416 }
1417
1418 pub fn one_to_expr(end: ASTRepr<T>) -> Self
1420 where
1421 T: num_traits::One,
1422 {
1423 Self::new(ASTRepr::Constant(T::one()), end)
1424 }
1425
1426 pub fn evaluate_bounds(&self, variables: &[T]) -> Option<(T, T)>
1428 where
1429 T: Float + Copy,
1430 {
1431 let start_val = DirectEval::eval_with_vars(&self.start, variables);
1432 let end_val = DirectEval::eval_with_vars(&self.end, variables);
1433 Some((start_val, end_val))
1434 }
1435}
1436
/// A function of a single index variable that can appear as the body of a summation.
pub trait SummandFunction<T>: Clone + std::fmt::Debug {
    /// Representation of the function body.
    type Body: Clone;

    /// Name of the index variable.
    fn index_var(&self) -> &str;

    /// The (unapplied) function body.
    fn body(&self) -> &Self::Body;

    /// The body with the index variable fixed to `index`.
    fn apply(&self, index: T) -> Self::Body;

    /// Whether the body mentions the index variable at all.
    fn depends_on_index(&self) -> bool;

    /// Splits the body into index-independent factors and the remaining part.
    ///
    /// Returns `(factors, rest)`.
    fn extract_independent_factors(&self) -> (Vec<Self::Body>, Self::Body);
}
1464
1465#[derive(Debug, Clone)]
1488pub struct ASTFunction<T> {
1489 pub index_var: String,
1490 pub body: ASTRepr<T>,
1491}
1492
1493impl<T: NumericType> ASTFunction<T> {
1494 pub fn new(index_var: &str, body: ASTRepr<T>) -> Self {
1496 Self {
1497 index_var: index_var.to_string(),
1498 body,
1499 }
1500 }
1501
1502 pub fn linear(index_var: &str, coefficient: T, constant: T) -> Self {
1504 let body = ASTRepr::Add(
1505 Box::new(ASTRepr::Mul(
1506 Box::new(ASTRepr::Constant(coefficient)),
1507 Box::new(ASTRepr::Variable(0)), )),
1509 Box::new(ASTRepr::Constant(constant)),
1510 );
1511 Self::new(index_var, body)
1512 }
1513
1514 pub fn power(index_var: &str, exponent: T) -> Self {
1516 let body = ASTRepr::Pow(
1517 Box::new(ASTRepr::Variable(0)), Box::new(ASTRepr::Constant(exponent)),
1519 );
1520 Self::new(index_var, body)
1521 }
1522
1523 pub fn constant_func(index_var: &str, value: T) -> Self {
1525 let body = ASTRepr::Constant(value);
1526 Self::new(index_var, body)
1527 }
1528}
1529
1530impl<T: NumericType + Float + Copy> SummandFunction<T> for ASTFunction<T> {
1531 type Body = ASTRepr<T>;
1532
1533 fn index_var(&self) -> &str {
1534 &self.index_var
1535 }
1536
1537 fn body(&self) -> &Self::Body {
1538 &self.body
1539 }
1540
1541 fn apply(&self, index: T) -> Self::Body {
1542 self.substitute_variable(&self.index_var, index)
1545 }
1546
1547 fn depends_on_index(&self) -> bool {
1548 self.contains_variable(&self.body, &self.index_var)
1549 }
1550
1551 fn extract_independent_factors(&self) -> (Vec<Self::Body>, Self::Body) {
1552 self.extract_factors_recursive(&self.body)
1555 }
1556}
1557
1558impl<T: NumericType + Copy> ASTFunction<T> {
1559 fn substitute_variable(&self, var_name: &str, value: T) -> ASTRepr<T> {
1561 self.substitute_in_expr(&self.body, var_name, value)
1562 }
1563
1564 fn substitute_in_expr(&self, expr: &ASTRepr<T>, var_name: &str, value: T) -> ASTRepr<T> {
1566 match expr {
1567 ASTRepr::Constant(c) => ASTRepr::Constant(*c),
1568 ASTRepr::Variable(index) => {
1569 let expected_index = match var_name {
1571 "i" => 0,
1572 "j" => 1,
1573 "k" => 2,
1574 "x" => 0,
1575 "y" => 1,
1576 "z" => 2,
1577 _ => {
1578 if let Some(idx) = get_variable_index(var_name) {
1580 idx
1581 } else {
1582 return expr.clone();
1584 }
1585 }
1586 };
1587
1588 if *index == expected_index {
1589 ASTRepr::Constant(value)
1590 } else {
1591 expr.clone()
1592 }
1593 }
1594 ASTRepr::Add(left, right) => ASTRepr::Add(
1595 Box::new(self.substitute_in_expr(left, var_name, value)),
1596 Box::new(self.substitute_in_expr(right, var_name, value)),
1597 ),
1598 ASTRepr::Sub(left, right) => ASTRepr::Sub(
1599 Box::new(self.substitute_in_expr(left, var_name, value)),
1600 Box::new(self.substitute_in_expr(right, var_name, value)),
1601 ),
1602 ASTRepr::Mul(left, right) => ASTRepr::Mul(
1603 Box::new(self.substitute_in_expr(left, var_name, value)),
1604 Box::new(self.substitute_in_expr(right, var_name, value)),
1605 ),
1606 ASTRepr::Div(left, right) => ASTRepr::Div(
1607 Box::new(self.substitute_in_expr(left, var_name, value)),
1608 Box::new(self.substitute_in_expr(right, var_name, value)),
1609 ),
1610 ASTRepr::Pow(base, exp) => ASTRepr::Pow(
1611 Box::new(self.substitute_in_expr(base, var_name, value)),
1612 Box::new(self.substitute_in_expr(exp, var_name, value)),
1613 ),
1614 ASTRepr::Neg(inner) => {
1615 ASTRepr::Neg(Box::new(self.substitute_in_expr(inner, var_name, value)))
1616 }
1617 ASTRepr::Ln(inner) => {
1618 ASTRepr::Ln(Box::new(self.substitute_in_expr(inner, var_name, value)))
1619 }
1620 ASTRepr::Exp(inner) => {
1621 ASTRepr::Exp(Box::new(self.substitute_in_expr(inner, var_name, value)))
1622 }
1623 ASTRepr::Sin(inner) => {
1624 ASTRepr::Sin(Box::new(self.substitute_in_expr(inner, var_name, value)))
1625 }
1626 ASTRepr::Cos(inner) => {
1627 ASTRepr::Cos(Box::new(self.substitute_in_expr(inner, var_name, value)))
1628 }
1629 ASTRepr::Sqrt(inner) => {
1630 ASTRepr::Sqrt(Box::new(self.substitute_in_expr(inner, var_name, value)))
1631 }
1632 }
1633 }
1634
1635 fn contains_variable(&self, expr: &ASTRepr<T>, var_name: &str) -> bool {
1638 match expr {
1639 ASTRepr::Constant(_) => false,
1640 ASTRepr::Variable(index) => {
1641 let expected_index = match var_name {
1644 "i" => 0,
1645 "j" => 1,
1646 "k" => 2,
1647 "x" => 0,
1648 "y" => 1,
1649 "z" => 2,
1650 _ => {
1651 if let Some(idx) = get_variable_index(var_name) {
1653 idx
1654 } else {
1655 return false;
1657 }
1658 }
1659 };
1660 *index == expected_index
1661 }
1662 ASTRepr::Add(left, right)
1663 | ASTRepr::Sub(left, right)
1664 | ASTRepr::Mul(left, right)
1665 | ASTRepr::Div(left, right)
1666 | ASTRepr::Pow(left, right) => {
1667 self.contains_variable(left, var_name) || self.contains_variable(right, var_name)
1668 }
1669 ASTRepr::Neg(inner)
1670 | ASTRepr::Ln(inner)
1671 | ASTRepr::Exp(inner)
1672 | ASTRepr::Sin(inner)
1673 | ASTRepr::Cos(inner)
1674 | ASTRepr::Sqrt(inner) => self.contains_variable(inner, var_name),
1675 }
1676 }
1677
1678 fn extract_factors_recursive(&self, expr: &ASTRepr<T>) -> (Vec<ASTRepr<T>>, ASTRepr<T>)
1680 where
1681 T: One,
1682 {
1683 match expr {
1684 ASTRepr::Mul(left, right) => {
1686 let left_depends = self.contains_variable(left, &self.index_var);
1687 let right_depends = self.contains_variable(right, &self.index_var);
1688
1689 match (left_depends, right_depends) {
1690 (false, false) => {
1691 (vec![expr.clone()], ASTRepr::Constant(T::one()))
1693 }
1694 (false, true) => {
1695 (vec![(**left).clone()], (**right).clone())
1697 }
1698 (true, false) => {
1699 (vec![(**right).clone()], (**left).clone())
1701 }
1702 (true, true) => {
1703 (vec![], expr.clone())
1705 }
1706 }
1707 }
1708 _ => {
1710 if self.contains_variable(expr, &self.index_var) {
1711 (vec![], expr.clone())
1712 } else {
1713 (vec![expr.clone()], ASTRepr::Constant(T::one()))
1714 }
1715 }
1716 }
1717 }
1718}
1719
1720use num_traits::One;
1722
1723pub trait SummationExpr: MathExpr {
1729 fn sum_finite<T, R, F>(range: Self::Repr<R>, function: Self::Repr<F>) -> Self::Repr<T>
1734 where
1735 T: NumericType,
1736 R: RangeType,
1737 F: SummandFunction<T>,
1738 Self::Repr<T>: Clone;
1739
1740 fn sum_infinite<T, F>(start: Self::Repr<T>, function: Self::Repr<F>) -> Self::Repr<T>
1745 where
1746 T: NumericType,
1747 F: SummandFunction<T>,
1748 Self::Repr<T>: Clone;
1749
1750 fn sum_telescoping<T, F>(range: Self::Repr<IntRange>, function: Self::Repr<F>) -> Self::Repr<T>
1755 where
1756 T: NumericType,
1757 F: SummandFunction<T>;
1758
1759 fn range_to<T: NumericType>(start: Self::Repr<T>, end: Self::Repr<T>) -> Self::Repr<IntRange>;
1761
1762 fn function<T: NumericType>(index_var: &str, body: Self::Repr<T>)
1764 -> Self::Repr<ASTFunction<T>>;
1765}
1766
1767impl<T> ASTRepr<T> {
1769 pub fn count_summation_operations(&self) -> usize {
1780 0
1783 }
1784}
1785
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_direct_eval() {
        // 2x + 1, written once against the generic interface.
        fn linear<E: MathExpr>(x: E::Repr<f64>) -> E::Repr<f64> {
            E::add(E::mul(E::constant(2.0), x), E::constant(1.0))
        }

        let result = linear::<DirectEval>(DirectEval::var("x", 5.0));
        assert_eq!(result, 11.0);
    }

    #[test]
    fn test_statistical_extension() {
        fn logistic_expr<E: StatisticalExpr>(x: E::Repr<f64>) -> E::Repr<f64> {
            E::logistic(x)
        }

        // The logistic function is exactly 1/2 at the origin.
        let result = logistic_expr::<DirectEval>(DirectEval::var("x", 0.0));
        assert!((result - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_pretty_print() {
        // 2x^2 + 3x + 1
        fn quadratic<E: MathExpr>(x: E::Repr<f64>) -> E::Repr<f64>
        where
            E::Repr<f64>: Clone,
        {
            let a = E::constant(2.0);
            let b = E::constant(3.0);
            let c = E::constant(1.0);

            E::add(
                E::add(E::mul(a, E::pow(x.clone(), E::constant(2.0))), E::mul(b, x)),
                c,
            )
        }

        let expr = quadratic::<PrettyPrint>(PrettyPrint::var("x"));
        for needle in ['x', '2', '3', '1'] {
            assert!(expr.contains(needle), "{expr:?} should mention {needle:?}");
        }
    }

    #[test]
    fn test_horner_polynomial() {
        // 1 + 2x + 3x^2 evaluated at x = 2.
        let coeffs = [1.0, 2.0, 3.0];
        let x = DirectEval::var("x", 2.0);
        let result = polynomial::horner::<DirectEval, f64>(&coeffs, x);
        assert_eq!(result, 17.0);
    }

    #[test]
    fn test_horner_pretty_print() {
        let coeffs = [1.0, 2.0, 3.0];
        let x = PrettyPrint::var("x");
        let result = polynomial::horner::<PrettyPrint, f64>(&coeffs, x);
        assert!(result.contains('x'));
    }

    #[test]
    fn test_polynomial_from_roots() {
        // (x - 1)(x - 2)
        let roots = [1.0, 2.0];

        let x = DirectEval::var("x", 0.0);
        let result = polynomial::from_roots::<DirectEval, f64>(&roots, x);
        assert_eq!(result, 2.0);

        let x = DirectEval::var("x", 3.0);
        let result = polynomial::from_roots::<DirectEval, f64>(&roots, x);
        assert_eq!(result, 2.0);

        // Every root must be a zero of the polynomial.
        for &root in &roots {
            let x = DirectEval::var("x", root);
            let result = polynomial::from_roots::<DirectEval, f64>(&roots, x);
            assert_eq!(result, 0.0);
        }
    }

    #[test]
    fn test_division_operations() {
        let div_1_3: f64 = DirectEval::div(DirectEval::constant(1.0), DirectEval::constant(3.0));
        assert!((div_1_3 - 1.0 / 3.0).abs() < 1e-10);

        let div_10_2: f64 = DirectEval::div(DirectEval::constant(10.0), DirectEval::constant(2.0));
        assert!((div_10_2 - 5.0).abs() < 1e-10);

        let div_by_one: f64 =
            DirectEval::div(DirectEval::constant(42.0), DirectEval::constant(1.0));
        assert!((div_by_one - 42.0).abs() < 1e-10);
    }

    #[test]
    fn test_transcendental_functions() {
        let ln_e: f64 = DirectEval::ln(DirectEval::constant(std::f64::consts::E));
        assert!((ln_e - 1.0).abs() < 1e-10);

        let exp_1: f64 = DirectEval::exp(DirectEval::constant(1.0));
        assert!((exp_1 - std::f64::consts::E).abs() < 1e-10);

        let sqrt_4: f64 = DirectEval::sqrt(DirectEval::constant(4.0));
        assert!((sqrt_4 - 2.0).abs() < 1e-10);

        let sin_pi_2: f64 = DirectEval::sin(DirectEval::constant(std::f64::consts::PI / 2.0));
        assert!((sin_pi_2 - 1.0).abs() < 1e-10);

        let cos_0: f64 = DirectEval::cos(DirectEval::constant(0.0));
        assert!((cos_0 - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_pretty_print_basic() {
        let var_x = PrettyPrint::var("x");
        assert_eq!(var_x, "x");

        let const_5 = PrettyPrint::constant::<f64>(5.0);
        assert_eq!(const_5, "5");

        let add_expr =
            PrettyPrint::add::<f64, f64, f64>(PrettyPrint::var("x"), PrettyPrint::constant(1.0));
        assert_eq!(add_expr, "(x + 1)");
    }

    #[test]
    fn test_efficient_variable_indexing() {
        let expr = ASTRepr::Add(
            Box::new(ASTRepr::Variable(0)),
            Box::new(ASTRepr::Variable(1)),
        );
        let result = DirectEval::eval_with_vars(&expr, &[2.0, 3.0]);
        assert_eq!(result, 5.0);

        let expr = ASTRepr::Mul(
            Box::new(ASTRepr::Variable(0)),
            Box::new(ASTRepr::Variable(1)),
        );
        let result = DirectEval::eval_with_vars(&expr, &[4.0, 5.0]);
        assert_eq!(result, 20.0);
    }

    #[test]
    fn test_mixed_variable_types() {
        let expr = ASTRepr::Add(
            Box::new(ASTRepr::Variable(0)),
            Box::new(ASTRepr::Variable(1)),
        );
        let result = DirectEval::eval_with_vars(&expr, &[2.0, 3.0]);
        assert_eq!(result, 5.0);
    }

    #[test]
    fn test_variable_index_access() {
        let expr: ASTRepr<f64> = ASTRepr::Variable(5);
        assert_eq!(expr.variable_index(), Some(5));

        let expr: ASTRepr<f64> = ASTRepr::Constant(42.0);
        assert_eq!(expr.variable_index(), None);
    }

    #[test]
    fn test_out_of_bounds_variable_index() {
        // A variable index past the end of the value slice evaluates as 0.0.
        let expr = ASTRepr::Variable(10);
        let result = DirectEval::eval_with_vars(&expr, &[1.0, 2.0]);
        assert_eq!(result, 0.0);
    }

    #[test]
    fn test_int_range() {
        let range = IntRange::new(1, 10);
        assert_eq!(range.start(), 1);
        assert_eq!(range.end(), 10);
        assert_eq!(range.len(), 10);
        assert!(range.contains(&5));
        assert!(!range.contains(&15));
        assert!(!range.is_empty());

        let empty_range = IntRange::new(5, 3);
        assert!(empty_range.is_empty());
        assert_eq!(empty_range.len(), 0);
    }

    #[test]
    fn test_float_range() {
        let range = FloatRange::new(1.0, 10.0, 1.0);
        assert_eq!(range.start(), 1.0);
        assert_eq!(range.end(), 10.0);
        assert_eq!(range.len(), 10.0);
        assert!(range.contains(&5.5));
        assert!(!range.contains(&15.0));

        let empty_range = FloatRange::new(5.0, 3.0, 1.0);
        assert!(empty_range.is_empty());
    }

    #[test]
    fn test_symbolic_range() {
        // Constant lower bound, variable upper bound.
        let range = SymbolicRange::new(ASTRepr::Constant(1.0), ASTRepr::Variable(0));
        let bounds = range.evaluate_bounds(&[10.0]);
        assert_eq!(bounds, Some((1.0, 10.0)));

        // Both bounds taken from variables.
        let range2 = SymbolicRange::new(ASTRepr::Variable(0), ASTRepr::Variable(1));
        let bounds2 = range2.evaluate_bounds(&[2.0, 8.0]);
        assert_eq!(bounds2, Some((2.0, 8.0)));
    }

    #[test]
    fn test_ast_function_creation() {
        let func = ASTFunction::linear("i", 2.0, 3.0);
        assert_eq!(func.index_var(), "i");
        assert!(func.depends_on_index());

        let const_func = ASTFunction::constant_func("i", 42.0);
        assert!(!const_func.depends_on_index());
    }

    #[test]
    fn test_ast_function_substitution() {
        let func = ASTFunction::linear("i", 2.0, 3.0);
        let result = func.apply(5.0);

        let evaluated = DirectEval::eval_with_vars(&result, &[]);
        assert_eq!(evaluated, 13.0);
    }

    #[test]
    fn test_ast_function_factor_extraction() {
        // 3 * Variable(0), where Variable(0) is the index-dependent operand.
        let func = ASTFunction::new(
            "i",
            ASTRepr::Mul(
                Box::new(ASTRepr::Constant(3.0)),
                Box::new(ASTRepr::Variable(0)),
            ),
        );

        let (factors, remaining) = func.extract_independent_factors();
        assert_eq!(factors.len(), 1);
        let Some(ASTRepr::Constant(value)) = factors.first() else {
            panic!("Expected constant factor");
        };
        assert_eq!(*value, 3.0);
        // Only the index-dependent operand should be left behind.
        assert!(
            matches!(remaining, ASTRepr::Variable(0)),
            "Expected the index-dependent operand to remain after factoring"
        );
    }

    #[test]
    fn test_range_convenience_methods() {
        let range_1_to_n = IntRange::one_to_n(10);
        assert_eq!(range_1_to_n.start(), 1);
        assert_eq!(range_1_to_n.end(), 10);

        let range_0_to_n_minus_1 = IntRange::zero_to_n_minus_one(10);
        assert_eq!(range_0_to_n_minus_1.start(), 0);
        assert_eq!(range_0_to_n_minus_1.end(), 9);
    }

    #[test]
    fn test_power_function() {
        let func = ASTFunction::power("i", 2.0);
        assert!(func.depends_on_index());

        let result = func.apply(3.0);
        let evaluated = DirectEval::eval_with_vars(&result, &[]);
        assert_eq!(evaluated, 9.0);
    }

    #[test]
    fn test_variable_registry() {
        let mut builder = ExpressionBuilder::new();

        let x_index = builder.register_variable("x");
        let y_index = builder.register_variable("y");
        // Re-registering an existing name must hand back the same index.
        let x_index_again = builder.register_variable("x");
        assert_ne!(x_index, y_index);
        assert_eq!(x_index_again, x_index);

        assert_eq!(builder.get_variable_index("x"), Some(x_index));
        assert_eq!(builder.get_variable_index("y"), Some(y_index));
        assert_eq!(builder.get_variable_index("z"), None);

        assert_eq!(builder.get_variable_name(x_index), Some("x"));
        assert_eq!(builder.get_variable_name(y_index), Some("y"));
        let max_index = x_index.max(y_index);
        assert_eq!(builder.get_variable_name(max_index + 10), None);
    }

    #[test]
    fn test_named_variable_evaluation() {
        let mut builder = ExpressionBuilder::new();

        let expr = ASTRepr::Add(Box::new(builder.var("x")), Box::new(builder.var("y")));

        let named_vars = [("x".to_string(), 3.0), ("y".to_string(), 4.0)];
        let result = builder.eval_with_named_vars(&expr, &named_vars);
        assert_eq!(result, 7.0);
    }

    #[test]
    fn test_mixed_variable_access() {
        let mut builder = ExpressionBuilder::new();

        let x_idx = builder.register_variable("x");
        let y_idx = builder.register_variable("y");

        let expr = ASTRepr::Mul(
            Box::new(ASTRepr::Variable(x_idx)),
            Box::new(ASTRepr::Variable(y_idx)),
        );

        // Positional and named evaluation must agree.
        let result1 = builder.eval_with_vars(&expr, &[2.0, 5.0]);
        assert_eq!(result1, 10.0);

        let named_vars = [("x".to_string(), 2.0), ("y".to_string(), 5.0)];
        let result2 = builder.eval_with_named_vars(&expr, &named_vars);
        assert_eq!(result2, 10.0);
    }

    #[test]
    fn test_variable_registry_performance() {
        let mut builder = ExpressionBuilder::new();
        assert_eq!(builder.num_variables(), 0);

        // A fresh builder hands out indices densely, in registration order.
        for i in 0..1000 {
            let var_name = format!("perf_test_var_{i}");
            assert_eq!(builder.register_variable(&var_name), i);
        }

        for i in 0..1000 {
            let var_name = format!("perf_test_var_{i}");
            assert_eq!(builder.get_variable_index(&var_name), Some(i));
            assert_eq!(builder.get_variable_name(i), Some(var_name.as_str()));
        }

        assert_eq!(builder.num_variables(), 1000);
    }

    #[test]
    fn test_generic_operator_overloading() {
        let x_f64 = ASTRepr::<f64>::Variable(0);
        let y_f64 = ASTRepr::<f64>::Variable(1);
        let const_f64 = ASTRepr::<f64>::Constant(2.5);

        // x + (y * c): one addition and one multiplication.
        let expr_f64 = &x_f64 + &y_f64 * &const_f64;
        assert_eq!(expr_f64.count_operations(), 2);

        let x_f32 = ASTRepr::<f32>::Variable(0);
        let y_f32 = ASTRepr::<f32>::Variable(1);
        let const_f32 = ASTRepr::<f32>::Constant(2.5_f32);

        let expr_f32 = &x_f32 + &y_f32 * &const_f32;
        assert_eq!(expr_f32.count_operations(), 2);

        assert!(matches!(-&x_f64, ASTRepr::Neg(_)), "Expected negation");
        assert!(matches!(-&x_f32, ASTRepr::Neg(_)), "Expected negation");

        assert!(matches!(x_f64.sin(), ASTRepr::Sin(_)), "Expected sine");
        assert!(matches!(x_f32.exp(), ASTRepr::Exp(_)), "Expected exponential");
    }
}
2241
/// Bidirectional mapping between variable names and dense `usize` indices.
///
/// Indices are assigned in registration order starting at 0, so they can be
/// used directly as positions in the value slices passed to evaluators.
#[derive(Debug, Clone)]
pub struct VariableRegistry {
    /// Name -> index lookup.
    name_to_index: HashMap<String, usize>,
    /// Index -> name lookup; position `i` holds the name registered with index `i`.
    index_to_name: Vec<String>,
}

impl VariableRegistry {
    /// Creates an empty registry.
    #[must_use]
    pub fn new() -> Self {
        Self {
            name_to_index: HashMap::new(),
            index_to_name: Vec::new(),
        }
    }

    /// Registers `name` and returns its index.
    ///
    /// Registering a name that is already known is a no-op that returns the
    /// existing index; a new name receives the next free index.
    pub fn register_variable(&mut self, name: &str) -> usize {
        // Look up by `&str` first so the "already registered" path does not allocate.
        if let Some(&index) = self.name_to_index.get(name) {
            return index;
        }
        let index = self.index_to_name.len();
        self.name_to_index.insert(name.to_owned(), index);
        self.index_to_name.push(name.to_owned());
        index
    }

    /// Returns the index registered for `name`, if any.
    #[must_use]
    pub fn get_index(&self, name: &str) -> Option<usize> {
        self.name_to_index.get(name).copied()
    }

    /// Returns the name registered at `index`, if any.
    #[must_use]
    pub fn get_name(&self, index: usize) -> Option<&str> {
        self.index_to_name.get(index).map(String::as_str)
    }

    /// Returns all registered names, ordered by index.
    #[must_use]
    pub fn get_all_names(&self) -> &[String] {
        &self.index_to_name
    }

    /// Number of registered variables.
    #[must_use]
    pub fn len(&self) -> usize {
        self.index_to_name.len()
    }

    /// Whether no variables are registered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.index_to_name.is_empty()
    }

    /// Removes every registration; indices handed out earlier are no longer
    /// meaningful, and numbering restarts at 0.
    pub fn clear(&mut self) {
        self.name_to_index.clear();
        self.index_to_name.clear();
    }

    /// Builds a value vector (one slot per registered variable, indexed by
    /// variable index) from `(name, value)` pairs.
    ///
    /// Unregistered names are ignored, variables without a supplied value are
    /// `0.0`, and if a name appears more than once the last value wins.
    #[must_use]
    pub fn create_variable_map(&self, values: &[(String, f64)]) -> Vec<f64> {
        let mut result = vec![0.0; self.len()];
        for (name, value) in values {
            if let Some(index) = self.get_index(name) {
                result[index] = *value;
            }
        }
        result
    }

    /// Builds a value vector (one slot per registered variable) from positional values.
    ///
    /// Values beyond the number of registered variables are dropped; variables
    /// without a supplied value are `0.0`.
    #[must_use]
    pub fn create_ordered_variable_map(&self, values: &[f64]) -> Vec<f64> {
        let mut result = vec![0.0; self.len()];
        // `zip` stops at the shorter side, which gives exactly the truncate/pad behaviour.
        for (slot, &value) in result.iter_mut().zip(values) {
            *slot = value;
        }
        result
    }
}

impl Default for VariableRegistry {
    fn default() -> Self {
        Self::new()
    }
}
2346
2347static GLOBAL_REGISTRY: std::sync::LazyLock<Arc<RwLock<VariableRegistry>>> =
2349 std::sync::LazyLock::new(|| Arc::new(RwLock::new(VariableRegistry::new())));
2350
2351pub fn global_registry() -> Arc<RwLock<VariableRegistry>> {
2353 GLOBAL_REGISTRY.clone()
2354}
2355
2356#[must_use]
2358pub fn register_variable(name: &str) -> usize {
2359 let registry = global_registry();
2360 let mut guard = registry.write().unwrap();
2361 guard.register_variable(name)
2362}
2363
2364#[must_use]
2366pub fn get_variable_index(name: &str) -> Option<usize> {
2367 let registry = global_registry();
2368 let guard = registry.read().unwrap();
2369 guard.get_index(name)
2370}
2371
2372#[must_use]
2374pub fn get_variable_name(index: usize) -> Option<String> {
2375 let registry = global_registry();
2376 let guard = registry.read().unwrap();
2377 guard.get_name(index).map(std::string::ToString::to_string)
2378}
2379
2380#[must_use]
2382pub fn create_variable_map(values: &[(String, f64)]) -> Vec<f64> {
2383 let registry = global_registry();
2384 let guard = registry.read().unwrap();
2385 guard.create_variable_map(values)
2386}
2387
2388pub fn clear_global_registry() {
2390 let registry = global_registry();
2391 let mut guard = registry.write().unwrap();
2392 guard.clear();
2393}
2394
2395#[derive(Debug, Clone)]
2399pub struct ExpressionBuilder {
2400 registry: VariableRegistry,
2401}
2402
2403impl ExpressionBuilder {
2404 #[must_use]
2406 pub fn new() -> Self {
2407 Self {
2408 registry: VariableRegistry::new(),
2409 }
2410 }
2411
2412 pub fn register_variable(&mut self, name: &str) -> usize {
2414 self.registry.register_variable(name)
2415 }
2416
2417 pub fn var(&mut self, name: &str) -> ASTRepr<f64> {
2419 let index = self.register_variable(name);
2420 ASTRepr::Variable(index)
2421 }
2422
2423 #[must_use]
2425 pub fn var_by_index(&self, index: usize) -> ASTRepr<f64> {
2426 ASTRepr::Variable(index)
2427 }
2428
2429 #[must_use]
2431 pub fn constant(&self, value: f64) -> ASTRepr<f64> {
2432 ASTRepr::Constant(value)
2433 }
2434
2435 #[must_use]
2437 pub fn registry(&self) -> &VariableRegistry {
2438 &self.registry
2439 }
2440
2441 pub fn registry_mut(&mut self) -> &mut VariableRegistry {
2443 &mut self.registry
2444 }
2445
2446 #[must_use]
2448 pub fn eval_with_named_vars(&self, expr: &ASTRepr<f64>, named_vars: &[(String, f64)]) -> f64 {
2449 let var_array = self.registry.create_variable_map(named_vars);
2450 DirectEval::eval_with_vars(expr, &var_array)
2451 }
2452
2453 #[must_use]
2455 pub fn eval_with_vars(&self, expr: &ASTRepr<f64>, variables: &[f64]) -> f64 {
2456 DirectEval::eval_with_vars(expr, variables)
2457 }
2458
2459 #[must_use]
2461 pub fn num_variables(&self) -> usize {
2462 self.registry.len()
2463 }
2464
2465 #[must_use]
2467 pub fn variable_names(&self) -> &[String] {
2468 self.registry.get_all_names()
2469 }
2470
2471 #[must_use]
2473 pub fn get_variable_index(&self, name: &str) -> Option<usize> {
2474 self.registry.get_index(name)
2475 }
2476
2477 #[must_use]
2479 pub fn get_variable_name(&self, index: usize) -> Option<&str> {
2480 self.registry.get_name(index)
2481 }
2482}
2483
2484impl Default for ExpressionBuilder {
2485 fn default() -> Self {
2486 Self::new()
2487 }
2488}
2489
2490impl<T> Add for ASTRepr<T>
2496where
2497 T: NumericType + Add<Output = T>,
2498{
2499 type Output = ASTRepr<T>;
2500
2501 fn add(self, rhs: Self) -> Self::Output {
2502 ASTRepr::Add(Box::new(self), Box::new(rhs))
2503 }
2504}
2505
2506impl<T> Add<&ASTRepr<T>> for &ASTRepr<T>
2508where
2509 T: NumericType + Add<Output = T>,
2510{
2511 type Output = ASTRepr<T>;
2512
2513 fn add(self, rhs: &ASTRepr<T>) -> Self::Output {
2514 ASTRepr::Add(Box::new(self.clone()), Box::new(rhs.clone()))
2515 }
2516}
2517
2518impl<T> Add<ASTRepr<T>> for &ASTRepr<T>
2520where
2521 T: NumericType + Add<Output = T>,
2522{
2523 type Output = ASTRepr<T>;
2524
2525 fn add(self, rhs: ASTRepr<T>) -> Self::Output {
2526 ASTRepr::Add(Box::new(self.clone()), Box::new(rhs))
2527 }
2528}
2529
2530impl<T> Add<&ASTRepr<T>> for ASTRepr<T>
2531where
2532 T: NumericType + Add<Output = T>,
2533{
2534 type Output = ASTRepr<T>;
2535
2536 fn add(self, rhs: &ASTRepr<T>) -> Self::Output {
2537 ASTRepr::Add(Box::new(self), Box::new(rhs.clone()))
2538 }
2539}
2540
2541impl<T> Sub for ASTRepr<T>
2543where
2544 T: NumericType + Sub<Output = T>,
2545{
2546 type Output = ASTRepr<T>;
2547
2548 fn sub(self, rhs: Self) -> Self::Output {
2549 ASTRepr::Sub(Box::new(self), Box::new(rhs))
2550 }
2551}
2552
2553impl<T> Sub<&ASTRepr<T>> for &ASTRepr<T>
2555where
2556 T: NumericType + Sub<Output = T>,
2557{
2558 type Output = ASTRepr<T>;
2559
2560 fn sub(self, rhs: &ASTRepr<T>) -> Self::Output {
2561 ASTRepr::Sub(Box::new(self.clone()), Box::new(rhs.clone()))
2562 }
2563}
2564
2565impl<T> Sub<ASTRepr<T>> for &ASTRepr<T>
2567where
2568 T: NumericType + Sub<Output = T>,
2569{
2570 type Output = ASTRepr<T>;
2571
2572 fn sub(self, rhs: ASTRepr<T>) -> Self::Output {
2573 ASTRepr::Sub(Box::new(self.clone()), Box::new(rhs))
2574 }
2575}
2576
2577impl<T> Sub<&ASTRepr<T>> for ASTRepr<T>
2578where
2579 T: NumericType + Sub<Output = T>,
2580{
2581 type Output = ASTRepr<T>;
2582
2583 fn sub(self, rhs: &ASTRepr<T>) -> Self::Output {
2584 ASTRepr::Sub(Box::new(self), Box::new(rhs.clone()))
2585 }
2586}
2587
2588impl<T> Mul for ASTRepr<T>
2590where
2591 T: NumericType + Mul<Output = T>,
2592{
2593 type Output = ASTRepr<T>;
2594
2595 fn mul(self, rhs: Self) -> Self::Output {
2596 ASTRepr::Mul(Box::new(self), Box::new(rhs))
2597 }
2598}
2599
2600impl<T> Mul<&ASTRepr<T>> for &ASTRepr<T>
2602where
2603 T: NumericType + Mul<Output = T>,
2604{
2605 type Output = ASTRepr<T>;
2606
2607 fn mul(self, rhs: &ASTRepr<T>) -> Self::Output {
2608 ASTRepr::Mul(Box::new(self.clone()), Box::new(rhs.clone()))
2609 }
2610}
2611
2612impl<T> Mul<ASTRepr<T>> for &ASTRepr<T>
2614where
2615 T: NumericType + Mul<Output = T>,
2616{
2617 type Output = ASTRepr<T>;
2618
2619 fn mul(self, rhs: ASTRepr<T>) -> Self::Output {
2620 ASTRepr::Mul(Box::new(self.clone()), Box::new(rhs))
2621 }
2622}
2623
2624impl<T> Mul<&ASTRepr<T>> for ASTRepr<T>
2625where
2626 T: NumericType + Mul<Output = T>,
2627{
2628 type Output = ASTRepr<T>;
2629
2630 fn mul(self, rhs: &ASTRepr<T>) -> Self::Output {
2631 ASTRepr::Mul(Box::new(self), Box::new(rhs.clone()))
2632 }
2633}
2634
2635impl<T> Div for ASTRepr<T>
2637where
2638 T: NumericType + Div<Output = T>,
2639{
2640 type Output = ASTRepr<T>;
2641
2642 fn div(self, rhs: Self) -> Self::Output {
2643 ASTRepr::Div(Box::new(self), Box::new(rhs))
2644 }
2645}
2646
2647impl<T> Div<&ASTRepr<T>> for &ASTRepr<T>
2649where
2650 T: NumericType + Div<Output = T>,
2651{
2652 type Output = ASTRepr<T>;
2653
2654 fn div(self, rhs: &ASTRepr<T>) -> Self::Output {
2655 ASTRepr::Div(Box::new(self.clone()), Box::new(rhs.clone()))
2656 }
2657}
2658
2659impl<T> Div<ASTRepr<T>> for &ASTRepr<T>
2661where
2662 T: NumericType + Div<Output = T>,
2663{
2664 type Output = ASTRepr<T>;
2665
2666 fn div(self, rhs: ASTRepr<T>) -> Self::Output {
2667 ASTRepr::Div(Box::new(self.clone()), Box::new(rhs))
2668 }
2669}
2670
2671impl<T> Div<&ASTRepr<T>> for ASTRepr<T>
2672where
2673 T: NumericType + Div<Output = T>,
2674{
2675 type Output = ASTRepr<T>;
2676
2677 fn div(self, rhs: &ASTRepr<T>) -> Self::Output {
2678 ASTRepr::Div(Box::new(self), Box::new(rhs.clone()))
2679 }
2680}
2681
2682impl<T> Neg for ASTRepr<T>
2684where
2685 T: NumericType + Neg<Output = T>,
2686{
2687 type Output = ASTRepr<T>;
2688
2689 fn neg(self) -> Self::Output {
2690 ASTRepr::Neg(Box::new(self))
2691 }
2692}
2693
2694impl<T> Neg for &ASTRepr<T>
2696where
2697 T: NumericType + Neg<Output = T>,
2698{
2699 type Output = ASTRepr<T>;
2700
2701 fn neg(self) -> Self::Output {
2702 ASTRepr::Neg(Box::new(self.clone()))
2703 }
2704}
2705
2706impl<T> ASTRepr<T>
2708where
2709 T: NumericType,
2710{
2711 #[must_use]
2713 pub fn pow(self, exp: ASTRepr<T>) -> ASTRepr<T>
2714 where
2715 T: Float,
2716 {
2717 ASTRepr::Pow(Box::new(self), Box::new(exp))
2718 }
2719
2720 #[must_use]
2722 pub fn pow_ref(&self, exp: &ASTRepr<T>) -> ASTRepr<T>
2723 where
2724 T: Float,
2725 {
2726 ASTRepr::Pow(Box::new(self.clone()), Box::new(exp.clone()))
2727 }
2728
2729 #[must_use]
2731 pub fn ln(self) -> ASTRepr<T>
2732 where
2733 T: Float,
2734 {
2735 ASTRepr::Ln(Box::new(self))
2736 }
2737
2738 #[must_use]
2740 pub fn ln_ref(&self) -> ASTRepr<T>
2741 where
2742 T: Float,
2743 {
2744 ASTRepr::Ln(Box::new(self.clone()))
2745 }
2746
2747 #[must_use]
2749 pub fn exp(self) -> ASTRepr<T>
2750 where
2751 T: Float,
2752 {
2753 ASTRepr::Exp(Box::new(self))
2754 }
2755
2756 #[must_use]
2758 pub fn exp_ref(&self) -> ASTRepr<T>
2759 where
2760 T: Float,
2761 {
2762 ASTRepr::Exp(Box::new(self.clone()))
2763 }
2764
2765 #[must_use]
2767 pub fn sqrt(self) -> ASTRepr<T>
2768 where
2769 T: Float,
2770 {
2771 ASTRepr::Sqrt(Box::new(self))
2772 }
2773
2774 #[must_use]
2776 pub fn sqrt_ref(&self) -> ASTRepr<T>
2777 where
2778 T: Float,
2779 {
2780 ASTRepr::Sqrt(Box::new(self.clone()))
2781 }
2782
2783 #[must_use]
2785 pub fn sin(self) -> ASTRepr<T>
2786 where
2787 T: Float,
2788 {
2789 ASTRepr::Sin(Box::new(self))
2790 }
2791
2792 #[must_use]
2794 pub fn sin_ref(&self) -> ASTRepr<T>
2795 where
2796 T: Float,
2797 {
2798 ASTRepr::Sin(Box::new(self.clone()))
2799 }
2800
2801 #[must_use]
2803 pub fn cos(self) -> ASTRepr<T>
2804 where
2805 T: Float,
2806 {
2807 ASTRepr::Cos(Box::new(self))
2808 }
2809
2810 #[must_use]
2812 pub fn cos_ref(&self) -> ASTRepr<T>
2813 where
2814 T: Float,
2815 {
2816 ASTRepr::Cos(Box::new(self.clone()))
2817 }
2818}