use num_traits::Float;
use std::ops::{Add, Div, Mul, Neg, Sub};
/// Scalar types that can appear as leaf values of an expression.
///
/// This is a pure marker trait: the blanket impl below makes every type that is
/// cloneable, defaultable, thread-safe, `'static` and displayable a `NumericType`
/// automatically, so it is never implemented by hand.
pub trait NumericType: Clone + Default + Send + Sync + 'static + std::fmt::Display {}

impl<T: Clone + Default + Send + Sync + 'static + std::fmt::Display> NumericType for T {}
/// Final-tagless interface for building mathematical expressions.
///
/// Each implementor is an *interpreter*. It chooses a representation `Repr<T>` for an
/// expression whose values have type `T`, and it defines how every operation combines
/// representations. Generic expression-building code written against this trait can
/// therefore be run by different interpreters (in this file: `DirectEval`, `PrettyPrint`,
/// and the `jit`-gated `ASTEval`).
pub trait MathExpr {
    /// The interpreter's representation of an expression producing values of type `T`.
    type Repr<T>;
    /// Lifts a literal value into an expression.
    fn constant<T: NumericType>(value: T) -> Self::Repr<T>;
    /// Creates an expression that refers to the variable `name`.
    fn var<T: NumericType>(name: &str) -> Self::Repr<T>;
    /// Addition. The operand types may differ; the result type is the `Output` of
    /// `L: Add<R>`.
    fn add<L, R, Output>(left: Self::Repr<L>, right: Self::Repr<R>) -> Self::Repr<Output>
    where
        L: NumericType + Add<R, Output = Output>,
        R: NumericType,
        Output: NumericType;
    /// Subtraction. The result type is the `Output` of `L: Sub<R>`.
    fn sub<L, R, Output>(left: Self::Repr<L>, right: Self::Repr<R>) -> Self::Repr<Output>
    where
        L: NumericType + Sub<R, Output = Output>,
        R: NumericType,
        Output: NumericType;
    /// Multiplication. The result type is the `Output` of `L: Mul<R>`.
    fn mul<L, R, Output>(left: Self::Repr<L>, right: Self::Repr<R>) -> Self::Repr<Output>
    where
        L: NumericType + Mul<R, Output = Output>,
        R: NumericType,
        Output: NumericType;
    /// Division. The result type is the `Output` of `L: Div<R>`.
    fn div<L, R, Output>(left: Self::Repr<L>, right: Self::Repr<R>) -> Self::Repr<Output>
    where
        L: NumericType + Div<R, Output = Output>,
        R: NumericType,
        Output: NumericType;
    /// `base` raised to the power `exp` (floating-point types only).
    fn pow<T: NumericType + Float>(base: Self::Repr<T>, exp: Self::Repr<T>) -> Self::Repr<T>;
    /// Arithmetic negation.
    fn neg<T: NumericType + Neg<Output = T>>(expr: Self::Repr<T>) -> Self::Repr<T>;
    /// Natural logarithm.
    fn ln<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T>;
    /// Natural exponential `e^expr`.
    fn exp<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T>;
    /// Square root.
    fn sqrt<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T>;
    /// Sine (argument in radians for the standard float types).
    fn sin<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T>;
    /// Cosine (argument in radians for the standard float types).
    fn cos<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T>;
}
#[derive(Debug, Clone)]
pub struct Expr<E: MathExpr, T> {
repr: E::Repr<T>,
_phantom: std::marker::PhantomData<E>,
}
impl<E: MathExpr, T> Expr<E, T> {
pub fn new(repr: E::Repr<T>) -> Self {
Self {
repr,
_phantom: std::marker::PhantomData,
}
}
pub fn into_repr(self) -> E::Repr<T> {
self.repr
}
pub fn as_repr(&self) -> &E::Repr<T> {
&self.repr
}
pub fn constant(value: T) -> Self
where
T: NumericType,
{
Self::new(E::constant(value))
}
pub fn var(name: &str) -> Self
where
T: NumericType,
{
Self::new(E::var(name))
}
pub fn pow(self, exp: Self) -> Self
where
T: NumericType + Float,
{
Self::new(E::pow(self.repr, exp.repr))
}
pub fn ln(self) -> Self
where
T: NumericType + Float,
{
Self::new(E::ln(self.repr))
}
pub fn exp(self) -> Self
where
T: NumericType + Float,
{
Self::new(E::exp(self.repr))
}
pub fn sqrt(self) -> Self
where
T: NumericType + Float,
{
Self::new(E::sqrt(self.repr))
}
pub fn sin(self) -> Self
where
T: NumericType + Float,
{
Self::new(E::sin(self.repr))
}
pub fn cos(self) -> Self
where
T: NumericType + Float,
{
Self::new(E::cos(self.repr))
}
}
impl<T> Expr<DirectEval, T> {
    /// Creates a variable for the eager interpreter, already bound to `value`.
    /// The name is forwarded to `DirectEval::var`, which ignores it.
    pub fn var_with_value(name: &str, value: T) -> Self
    where
        T: NumericType,
    {
        let bound = DirectEval::var(name, value);
        Self::new(bound)
    }

    /// Returns the computed value; for `DirectEval` the representation *is* the value.
    pub fn eval(self) -> T {
        self.into_repr()
    }
}
impl<T> Expr<PrettyPrint, T> {
    /// Consumes the expression and returns its rendered text
    /// (the `PrettyPrint` representation is already a `String`).
    pub fn to_string(self) -> String {
        self.into_repr()
    }
}
/// `a + b` on expressions delegates to the interpreter's `MathExpr::add`.
impl<E, L, R, O> Add<Expr<E, R>> for Expr<E, L>
where
    E: MathExpr,
    L: NumericType + Add<R, Output = O>,
    R: NumericType,
    O: NumericType,
{
    type Output = Expr<E, O>;

    fn add(self, rhs: Expr<E, R>) -> Self::Output {
        let combined = E::add::<L, R, O>(self.into_repr(), rhs.into_repr());
        Expr::new(combined)
    }
}
/// `a - b` on expressions delegates to the interpreter's `MathExpr::sub`.
impl<E, L, R, O> Sub<Expr<E, R>> for Expr<E, L>
where
    E: MathExpr,
    L: NumericType + Sub<R, Output = O>,
    R: NumericType,
    O: NumericType,
{
    type Output = Expr<E, O>;

    fn sub(self, rhs: Expr<E, R>) -> Self::Output {
        let combined = E::sub::<L, R, O>(self.into_repr(), rhs.into_repr());
        Expr::new(combined)
    }
}
/// `a * b` on expressions delegates to the interpreter's `MathExpr::mul`.
impl<E, L, R, O> Mul<Expr<E, R>> for Expr<E, L>
where
    E: MathExpr,
    L: NumericType + Mul<R, Output = O>,
    R: NumericType,
    O: NumericType,
{
    type Output = Expr<E, O>;

    fn mul(self, rhs: Expr<E, R>) -> Self::Output {
        let combined = E::mul::<L, R, O>(self.into_repr(), rhs.into_repr());
        Expr::new(combined)
    }
}
/// `a / b` on expressions delegates to the interpreter's `MathExpr::div`.
impl<E, L, R, O> Div<Expr<E, R>> for Expr<E, L>
where
    E: MathExpr,
    L: NumericType + Div<R, Output = O>,
    R: NumericType,
    O: NumericType,
{
    type Output = Expr<E, O>;

    fn div(self, rhs: Expr<E, R>) -> Self::Output {
        let combined = E::div::<L, R, O>(self.into_repr(), rhs.into_repr());
        Expr::new(combined)
    }
}
/// Unary `-expr` delegates to the interpreter's `MathExpr::neg`.
impl<E, T> Neg for Expr<E, T>
where
    E: MathExpr,
    T: NumericType + Neg<Output = T>,
{
    type Output = Self;

    fn neg(self) -> Self {
        Self::new(E::neg::<T>(self.into_repr()))
    }
}
/// Polynomial constructors that work with any [`MathExpr`] interpreter.
///
/// Coefficient slices are in ascending order of degree: `coeffs[i]` multiplies `x^i`.
pub mod polynomial {
    use super::{MathExpr, NumericType};
    use std::ops::{Add, Mul, Sub};

    /// Builds `coeffs[0] + coeffs[1]*x + ... + coeffs[n-1]*x^(n-1)` with Horner's scheme,
    /// `((c[n-1] * x + c[n-2]) * x + ...) + c[0]`, which needs only `n - 1` multiplications.
    ///
    /// An empty slice yields `E::constant(T::default())`. A single coefficient yields that
    /// constant and does not use `x`.
    pub fn horner<E: MathExpr, T>(coeffs: &[T], x: E::Repr<T>) -> E::Repr<T>
    where
        T: NumericType + Clone + Add<Output = T> + Mul<Output = T>,
        E::Repr<T>: Clone,
    {
        let Some((highest, lower)) = coeffs.split_last() else {
            return E::constant(T::default());
        };
        // Walk the remaining coefficients from high to low degree.
        lower
            .iter()
            .rev()
            .fold(E::constant(highest.clone()), |acc, coeff| {
                E::add::<T, T, T>(E::mul::<T, T, T>(acc, x.clone()), E::constant(coeff.clone()))
            })
    }

    /// Same as [`horner`], but the coefficients are already expressions in the
    /// interpreter's representation rather than plain values.
    ///
    /// An empty slice yields `E::constant(T::default())`.
    pub fn horner_expr<E: MathExpr, T>(coeffs: &[E::Repr<T>], x: E::Repr<T>) -> E::Repr<T>
    where
        T: NumericType + Add<Output = T> + Mul<Output = T>,
        E::Repr<T>: Clone,
    {
        let Some((highest, lower)) = coeffs.split_last() else {
            return E::constant(T::default());
        };
        lower.iter().rev().fold(highest.clone(), |acc, coeff| {
            E::add::<T, T, T>(E::mul::<T, T, T>(acc, x.clone()), coeff.clone())
        })
    }

    /// Builds the monic polynomial `(x - roots[0]) * (x - roots[1]) * ...`.
    ///
    /// An empty slice yields the constant `1`.
    pub fn from_roots<E: MathExpr, T>(roots: &[T], x: E::Repr<T>) -> E::Repr<T>
    where
        T: NumericType + Clone + Sub<Output = T> + num_traits::One,
        E::Repr<T>: Clone,
    {
        let Some((first, rest)) = roots.split_first() else {
            return E::constant(<T as num_traits::One>::one());
        };
        // One linear factor `(x - root)`.
        let factor = |root: &T| E::sub::<T, T, T>(x.clone(), E::constant(root.clone()));
        rest.iter()
            .fold(factor(first), |acc, root| E::mul::<T, T, T>(acc, factor(root)))
    }

    /// Builds the derivative of the polynomial described by `coeffs`,
    /// i.e. `sum_{i>=1} i * coeffs[i] * x^(i-1)`, evaluated with [`horner`].
    ///
    /// Zero or one coefficient yields `E::constant(T::default())`. If a power `i` cannot be
    /// represented in `T` (`FromPrimitive::from_usize` returns `None`), that term's factor
    /// silently falls back to `T::default()`.
    pub fn horner_derivative<E: MathExpr, T>(coeffs: &[T], x: E::Repr<T>) -> E::Repr<T>
    where
        T: NumericType + Clone + Add<Output = T> + Mul<Output = T> + num_traits::FromPrimitive,
        E::Repr<T>: Clone,
    {
        if coeffs.len() <= 1 {
            return E::constant(T::default());
        }
        let deriv_coeffs: Vec<T> = coeffs
            .iter()
            .enumerate()
            .skip(1)
            .map(|(power, coeff)| {
                let factor =
                    <T as num_traits::FromPrimitive>::from_usize(power).unwrap_or_default();
                coeff.clone() * factor
            })
            .collect();
        horner::<E, T>(&deriv_coeffs, x)
    }
}
/// Interpreter that evaluates expressions immediately to plain values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DirectEval;

impl DirectEval {
    /// Returns `value` unchanged. The name is accepted for symmetry with the other
    /// interpreters' `var` constructors but is not used.
    #[must_use]
    pub fn var<T>(_name: &str, value: T) -> T
    where
        T: NumericType,
    {
        value
    }
}
/// Eager interpreter: the representation of an expression is its value, and every
/// operation is computed on the spot with the corresponding operator trait.
impl MathExpr for DirectEval {
    type Repr<T> = T;

    fn constant<T: NumericType>(value: T) -> Self::Repr<T> {
        value
    }

    /// A bare variable has no bound value here, so it evaluates to `T::default()`.
    /// Use the inherent `DirectEval::var(name, value)` to supply a value.
    fn var<T: NumericType>(_name: &str) -> Self::Repr<T> {
        <T as Default>::default()
    }

    fn add<L, R, Output>(left: Self::Repr<L>, right: Self::Repr<R>) -> Self::Repr<Output>
    where
        L: NumericType + Add<R, Output = Output>,
        R: NumericType,
        Output: NumericType,
    {
        <L as Add<R>>::add(left, right)
    }

    fn sub<L, R, Output>(left: Self::Repr<L>, right: Self::Repr<R>) -> Self::Repr<Output>
    where
        L: NumericType + Sub<R, Output = Output>,
        R: NumericType,
        Output: NumericType,
    {
        <L as Sub<R>>::sub(left, right)
    }

    fn mul<L, R, Output>(left: Self::Repr<L>, right: Self::Repr<R>) -> Self::Repr<Output>
    where
        L: NumericType + Mul<R, Output = Output>,
        R: NumericType,
        Output: NumericType,
    {
        <L as Mul<R>>::mul(left, right)
    }

    fn div<L, R, Output>(left: Self::Repr<L>, right: Self::Repr<R>) -> Self::Repr<Output>
    where
        L: NumericType + Div<R, Output = Output>,
        R: NumericType,
        Output: NumericType,
    {
        <L as Div<R>>::div(left, right)
    }

    fn pow<T: NumericType + Float>(base: Self::Repr<T>, exp: Self::Repr<T>) -> Self::Repr<T> {
        <T as Float>::powf(base, exp)
    }

    fn neg<T: NumericType + Neg<Output = T>>(expr: Self::Repr<T>) -> Self::Repr<T> {
        <T as Neg>::neg(expr)
    }

    fn ln<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
        <T as Float>::ln(expr)
    }

    fn exp<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
        <T as Float>::exp(expr)
    }

    fn sqrt<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
        <T as Float>::sqrt(expr)
    }

    fn sin<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
        <T as Float>::sin(expr)
    }

    fn cos<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
        <T as Float>::cos(expr)
    }
}
/// Extension of [`MathExpr`] with functions common in statistics and machine learning.
///
/// Every method has a default body written only in terms of `MathExpr` primitives, so an
/// interpreter opts in with an empty `impl`.
pub trait StatisticalExpr: MathExpr {
    /// Logistic function `1 / (1 + exp(-x))`.
    fn logistic<T: NumericType + Float>(x: Self::Repr<T>) -> Self::Repr<T> {
        let denominator =
            Self::add::<T, T, T>(Self::constant(T::one()), Self::exp::<T>(Self::neg::<T>(x)));
        Self::div::<T, T, T>(Self::constant(T::one()), denominator)
    }

    /// Softplus `ln(1 + exp(x))`.
    ///
    /// Built literally from the formula. For float interpreters, `exp(x)` can overflow to
    /// infinity for large positive `x`.
    fn softplus<T: NumericType + Float>(x: Self::Repr<T>) -> Self::Repr<T> {
        Self::ln::<T>(Self::add::<T, T, T>(
            Self::constant(T::one()),
            Self::exp::<T>(x),
        ))
    }

    /// Alias for [`StatisticalExpr::logistic`].
    fn sigmoid<T: NumericType + Float>(x: Self::Repr<T>) -> Self::Repr<T> {
        Self::logistic::<T>(x)
    }
}

impl StatisticalExpr for DirectEval {}
/// Interpreter that renders expressions as text.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PrettyPrint;

impl PrettyPrint {
    /// Returns the variable's rendered form, which is just its name.
    #[must_use]
    pub fn var(name: &str) -> String {
        String::from(name)
    }
}
/// Formats a binary operation as `(lhs op rhs)`.
fn pretty_infix(lhs: &str, op: &str, rhs: &str) -> String {
    format!("({lhs} {op} {rhs})")
}

/// Formats a named function application as `name(arg)`.
fn pretty_call(name: &str, arg: &str) -> String {
    format!("{name}({arg})")
}

/// Text interpreter. Binary operators and negation are wrapped in parentheses, so the
/// output does not depend on precedence rules. Constants use their `Display` form.
impl MathExpr for PrettyPrint {
    type Repr<T> = String;

    fn constant<T: NumericType>(value: T) -> Self::Repr<T> {
        value.to_string()
    }

    fn var<T: NumericType>(name: &str) -> Self::Repr<T> {
        name.to_owned()
    }

    fn add<L, R, Output>(left: Self::Repr<L>, right: Self::Repr<R>) -> Self::Repr<Output>
    where
        L: NumericType + Add<R, Output = Output>,
        R: NumericType,
        Output: NumericType,
    {
        pretty_infix(&left, "+", &right)
    }

    fn sub<L, R, Output>(left: Self::Repr<L>, right: Self::Repr<R>) -> Self::Repr<Output>
    where
        L: NumericType + Sub<R, Output = Output>,
        R: NumericType,
        Output: NumericType,
    {
        pretty_infix(&left, "-", &right)
    }

    fn mul<L, R, Output>(left: Self::Repr<L>, right: Self::Repr<R>) -> Self::Repr<Output>
    where
        L: NumericType + Mul<R, Output = Output>,
        R: NumericType,
        Output: NumericType,
    {
        pretty_infix(&left, "*", &right)
    }

    fn div<L, R, Output>(left: Self::Repr<L>, right: Self::Repr<R>) -> Self::Repr<Output>
    where
        L: NumericType + Div<R, Output = Output>,
        R: NumericType,
        Output: NumericType,
    {
        pretty_infix(&left, "/", &right)
    }

    fn pow<T: NumericType + Float>(base: Self::Repr<T>, exp: Self::Repr<T>) -> Self::Repr<T> {
        pretty_infix(&base, "^", &exp)
    }

    fn neg<T: NumericType + Neg<Output = T>>(expr: Self::Repr<T>) -> Self::Repr<T> {
        ["(-", expr.as_str(), ")"].concat()
    }

    fn ln<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
        pretty_call("ln", &expr)
    }

    fn exp<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
        pretty_call("exp", &expr)
    }

    fn sqrt<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
        pretty_call("sqrt", &expr)
    }

    fn sin<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
        pretty_call("sin", &expr)
    }

    fn cos<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
        pretty_call("cos", &expr)
    }
}

impl StatisticalExpr for PrettyPrint {}
/// Syntax tree for an expression whose leaves are constants of type `T` or named variables.
///
/// Interior nodes own their children through `Box`. `PartialEq` is derived so trees can be
/// compared structurally, for example in tests.
#[derive(Debug, Clone, PartialEq)]
pub enum ASTRepr<T> {
    /// A literal value.
    Constant(T),
    /// A variable referenced by name.
    Variable(String),
    /// `left + right`.
    Add(Box<ASTRepr<T>>, Box<ASTRepr<T>>),
    /// `left - right`.
    Sub(Box<ASTRepr<T>>, Box<ASTRepr<T>>),
    /// `left * right`.
    Mul(Box<ASTRepr<T>>, Box<ASTRepr<T>>),
    /// `left / right`.
    Div(Box<ASTRepr<T>>, Box<ASTRepr<T>>),
    /// `base ^ exp`.
    Pow(Box<ASTRepr<T>>, Box<ASTRepr<T>>),
    /// `-expr`.
    Neg(Box<ASTRepr<T>>),
    /// Natural logarithm.
    Ln(Box<ASTRepr<T>>),
    /// Natural exponential.
    Exp(Box<ASTRepr<T>>),
    /// Square root.
    Sqrt(Box<ASTRepr<T>>),
    /// Sine.
    Sin(Box<ASTRepr<T>>),
    /// Cosine.
    Cos(Box<ASTRepr<T>>),
}
/// Interpreter that builds an [`ASTRepr`] syntax tree (only with the `jit` feature).
///
/// Derives the same traits as the other interpreter markers (`DirectEval`, `PrettyPrint`),
/// so wrappers that require a cloneable or debuggable marker also work with it.
#[cfg(feature = "jit")]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ASTEval;

#[cfg(feature = "jit")]
impl ASTEval {
    /// Creates a `Variable` node named `name`.
    #[must_use]
    pub fn var<T: NumericType>(name: &str) -> ASTRepr<T> {
        ASTRepr::Variable(name.to_string())
    }
}
/// Non-generic counterpart of [`MathExpr`] whose constants are always `f64`.
///
/// The representation type has no type parameter, which avoids the type-changing
/// operations of `MathExpr`. The error messages in `ASTEval`'s `MathExpr` impl point users
/// here for practical JIT compilation. Available only with the `jit` feature.
#[cfg(feature = "jit")]
pub trait ASTMathExpr {
    /// The interpreter's representation of an expression.
    type Repr;
    /// Lifts an `f64` literal into an expression.
    fn constant(value: f64) -> Self::Repr;
    /// Creates an expression referring to the variable `name`.
    fn var(name: &str) -> Self::Repr;
    /// `left + right`.
    fn add(left: Self::Repr, right: Self::Repr) -> Self::Repr;
    /// `left - right`.
    fn sub(left: Self::Repr, right: Self::Repr) -> Self::Repr;
    /// `left * right`.
    fn mul(left: Self::Repr, right: Self::Repr) -> Self::Repr;
    /// `left / right`.
    fn div(left: Self::Repr, right: Self::Repr) -> Self::Repr;
    /// `base` raised to the power `exp`.
    fn pow(base: Self::Repr, exp: Self::Repr) -> Self::Repr;
    /// Arithmetic negation.
    fn neg(expr: Self::Repr) -> Self::Repr;
    /// Natural logarithm.
    fn ln(expr: Self::Repr) -> Self::Repr;
    /// Natural exponential.
    fn exp(expr: Self::Repr) -> Self::Repr;
    /// Square root.
    fn sqrt(expr: Self::Repr) -> Self::Repr;
    /// Sine.
    fn sin(expr: Self::Repr) -> Self::Repr;
    /// Cosine.
    fn cos(expr: Self::Repr) -> Self::Repr;
}
/// Builds `ASTRepr<f64>` trees: each operation wraps its operands in a new node.
#[cfg(feature = "jit")]
impl ASTMathExpr for ASTEval {
    type Repr = ASTRepr<f64>;

    fn constant(value: f64) -> Self::Repr {
        ASTRepr::Constant(value)
    }

    fn var(name: &str) -> Self::Repr {
        ASTRepr::Variable(name.to_owned())
    }

    fn add(left: Self::Repr, right: Self::Repr) -> Self::Repr {
        ASTRepr::Add(left.into(), right.into())
    }

    fn sub(left: Self::Repr, right: Self::Repr) -> Self::Repr {
        ASTRepr::Sub(left.into(), right.into())
    }

    fn mul(left: Self::Repr, right: Self::Repr) -> Self::Repr {
        ASTRepr::Mul(left.into(), right.into())
    }

    fn div(left: Self::Repr, right: Self::Repr) -> Self::Repr {
        ASTRepr::Div(left.into(), right.into())
    }

    fn pow(base: Self::Repr, exp: Self::Repr) -> Self::Repr {
        ASTRepr::Pow(base.into(), exp.into())
    }

    fn neg(expr: Self::Repr) -> Self::Repr {
        ASTRepr::Neg(expr.into())
    }

    fn ln(expr: Self::Repr) -> Self::Repr {
        ASTRepr::Ln(expr.into())
    }

    fn exp(expr: Self::Repr) -> Self::Repr {
        ASTRepr::Exp(expr.into())
    }

    fn sqrt(expr: Self::Repr) -> Self::Repr {
        ASTRepr::Sqrt(expr.into())
    }

    fn sin(expr: Self::Repr) -> Self::Repr {
        ASTRepr::Sin(expr.into())
    }

    fn cos(expr: Self::Repr) -> Self::Repr {
        ASTRepr::Cos(expr.into())
    }
}
/// Reinterprets an `ASTRepr<A>` as an `ASTRepr<B>` when `A` and `B` are the same type.
///
/// `ASTRepr` nodes carry a single numeric type, so the type-changing binary operations of
/// [`MathExpr`] can only be represented when both operands already have the result type.
/// `NumericType` implies `'static`, so a safe `Any` downcast can perform the check.
///
/// # Panics
///
/// Panics, naming `op`, if `A` and `B` are different types.
#[cfg(feature = "jit")]
fn cast_ast_repr<A: 'static, B: 'static>(expr: ASTRepr<A>, op: &str) -> ASTRepr<B> {
    let erased: Box<dyn std::any::Any> = Box::new(expr);
    match erased.downcast::<ASTRepr<B>>() {
        Ok(same) => *same,
        Err(_) => panic!(
            "ASTEval::{} requires operand and result types to match; use ASTMathExpr trait for practical JIT compilation with f64 types",
            op
        ),
    }
}

/// Generic AST builder. Binary operations are supported when the operand types equal the
/// result type (e.g. `f64 op f64 -> f64`), which covers the `StatisticalExpr` defaults and
/// the `polynomial` helpers. Genuinely type-changing operations panic.
#[cfg(feature = "jit")]
impl MathExpr for ASTEval {
    type Repr<T> = ASTRepr<T>;

    fn constant<T: NumericType>(value: T) -> Self::Repr<T> {
        ASTRepr::Constant(value)
    }

    fn var<T: NumericType>(name: &str) -> Self::Repr<T> {
        ASTRepr::Variable(name.to_string())
    }

    fn add<L, R, Output>(left: Self::Repr<L>, right: Self::Repr<R>) -> Self::Repr<Output>
    where
        L: NumericType + Add<R, Output = Output>,
        R: NumericType,
        Output: NumericType,
    {
        ASTRepr::Add(
            Box::new(cast_ast_repr::<L, Output>(left, "add")),
            Box::new(cast_ast_repr::<R, Output>(right, "add")),
        )
    }

    fn sub<L, R, Output>(left: Self::Repr<L>, right: Self::Repr<R>) -> Self::Repr<Output>
    where
        L: NumericType + Sub<R, Output = Output>,
        R: NumericType,
        Output: NumericType,
    {
        ASTRepr::Sub(
            Box::new(cast_ast_repr::<L, Output>(left, "sub")),
            Box::new(cast_ast_repr::<R, Output>(right, "sub")),
        )
    }

    fn mul<L, R, Output>(left: Self::Repr<L>, right: Self::Repr<R>) -> Self::Repr<Output>
    where
        L: NumericType + Mul<R, Output = Output>,
        R: NumericType,
        Output: NumericType,
    {
        ASTRepr::Mul(
            Box::new(cast_ast_repr::<L, Output>(left, "mul")),
            Box::new(cast_ast_repr::<R, Output>(right, "mul")),
        )
    }

    fn div<L, R, Output>(left: Self::Repr<L>, right: Self::Repr<R>) -> Self::Repr<Output>
    where
        L: NumericType + Div<R, Output = Output>,
        R: NumericType,
        Output: NumericType,
    {
        ASTRepr::Div(
            Box::new(cast_ast_repr::<L, Output>(left, "div")),
            Box::new(cast_ast_repr::<R, Output>(right, "div")),
        )
    }

    fn pow<T: NumericType + Float>(base: Self::Repr<T>, exp: Self::Repr<T>) -> Self::Repr<T> {
        ASTRepr::Pow(Box::new(base), Box::new(exp))
    }

    fn neg<T: NumericType + Neg<Output = T>>(expr: Self::Repr<T>) -> Self::Repr<T> {
        ASTRepr::Neg(Box::new(expr))
    }

    fn ln<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
        ASTRepr::Ln(Box::new(expr))
    }

    fn exp<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
        ASTRepr::Exp(Box::new(expr))
    }

    fn sqrt<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
        ASTRepr::Sqrt(Box::new(expr))
    }

    fn sin<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
        ASTRepr::Sin(Box::new(expr))
    }

    fn cos<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
        ASTRepr::Cos(Box::new(expr))
    }
}

#[cfg(feature = "jit")]
impl StatisticalExpr for ASTEval {}
#[cfg(test)]
mod tests {
    // Import the polynomial helpers relative to this module instead of hard-coding the
    // crate path, so the tests keep compiling if the module is moved or renamed.
    use super::polynomial::{from_roots, horner, horner_derivative, horner_expr};
    use super::*;

    #[test]
    fn test_direct_eval() {
        fn linear<E: MathExpr>(x: E::Repr<f64>) -> E::Repr<f64> {
            let two = E::constant(2.0);
            let three = E::constant(3.0);
            E::add(E::mul(two, x), three)
        }
        // 2 * 5 + 3
        let result = linear::<DirectEval>(DirectEval::var("x", 5.0));
        assert_eq!(result, 13.0);
    }

    #[test]
    fn test_statistical_extension() {
        fn logistic_expr<E: StatisticalExpr>(x: E::Repr<f64>) -> E::Repr<f64> {
            E::logistic(x)
        }
        let result = logistic_expr::<DirectEval>(DirectEval::var("x", 0.0));
        assert!((result - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_softplus_and_sigmoid() {
        // softplus(0) = ln(1 + e^0) = ln 2
        let softplus = DirectEval::softplus::<f64>(0.0);
        assert!((softplus - std::f64::consts::LN_2).abs() < 1e-12);
        // sigmoid is an alias for logistic: sigmoid(0) = 1/2
        let sigmoid = DirectEval::sigmoid::<f64>(0.0);
        assert!((sigmoid - 0.5).abs() < 1e-12);
    }

    #[test]
    fn test_pretty_print() {
        fn quadratic<E: MathExpr>(x: E::Repr<f64>) -> E::Repr<f64>
        where
            E::Repr<f64>: Clone,
        {
            let a = E::constant(2.0);
            let b = E::constant(3.0);
            let c = E::constant(1.0);
            E::add(
                E::add(E::mul(a, E::pow(x.clone(), E::constant(2.0))), E::mul(b, x)),
                c,
            )
        }
        let pretty = quadratic::<PrettyPrint>(PrettyPrint::var("x"));
        assert!(pretty.contains('x'));
        assert!(pretty.contains('2'));
        assert!(pretty.contains('3'));
        assert!(pretty.contains('1'));
        assert!(pretty.contains('^'));
        assert!(pretty.contains('*'));
        assert!(pretty.contains('+'));
    }

    #[test]
    fn test_horner_polynomial() {
        // 1 + 3x + 2x^2 at x = 2
        let coeffs = [1.0, 3.0, 2.0];
        let x = DirectEval::var("x", 2.0);
        let result = horner::<DirectEval, f64>(&coeffs, x);
        assert_eq!(result, 15.0);

        let empty_coeffs: [f64; 0] = [];
        let result_empty = horner::<DirectEval, f64>(&empty_coeffs, DirectEval::var("x", 5.0));
        assert_eq!(result_empty, 0.0);

        let single_coeff = [42.0];
        let result_single = horner::<DirectEval, f64>(&single_coeff, DirectEval::var("x", 5.0));
        assert_eq!(result_single, 42.0);
    }

    #[test]
    fn test_horner_expr() {
        // Same polynomial as above, with coefficients already in DirectEval's representation.
        let coeffs = [1.0, 3.0, 2.0];
        let result = horner_expr::<DirectEval, f64>(&coeffs, DirectEval::var("x", 2.0));
        assert_eq!(result, 15.0);

        let empty: [f64; 0] = [];
        assert_eq!(horner_expr::<DirectEval, f64>(&empty, 5.0), 0.0);
    }

    #[test]
    fn test_horner_derivative() {
        // d/dx (1 + 3x + 2x^2) = 3 + 4x, which is 11 at x = 2
        let coeffs = [1.0, 3.0, 2.0];
        let result = horner_derivative::<DirectEval, f64>(&coeffs, DirectEval::var("x", 2.0));
        assert_eq!(result, 11.0);

        // A constant polynomial has a zero derivative.
        assert_eq!(horner_derivative::<DirectEval, f64>(&[7.0], 3.0), 0.0);
    }

    #[test]
    fn test_horner_pretty_print() {
        let coeffs = [1.0, 3.0, 2.0];
        let x = PrettyPrint::var("x");
        let pretty = horner::<PrettyPrint, f64>(&coeffs, x);
        assert!(pretty.contains('x'));
        assert!(pretty.contains('1'));
        assert!(pretty.contains('3'));
        assert!(pretty.contains('2'));
    }

    #[test]
    fn test_polynomial_from_roots() {
        // (x - 1)(x - 2)
        let roots = [1.0, 2.0];
        let result_0 = from_roots::<DirectEval, f64>(&roots, DirectEval::var("x", 0.0));
        assert_eq!(result_0, 2.0);
        let result_1 = from_roots::<DirectEval, f64>(&roots, DirectEval::var("x", 1.0));
        assert_eq!(result_1, 0.0);
        let result_2 = from_roots::<DirectEval, f64>(&roots, DirectEval::var("x", 2.0));
        assert_eq!(result_2, 0.0);
        let result_3 = from_roots::<DirectEval, f64>(&roots, DirectEval::var("x", 3.0));
        assert_eq!(result_3, 2.0);

        // No roots: the empty product is 1.
        let no_roots: [f64; 0] = [];
        assert_eq!(from_roots::<DirectEval, f64>(&no_roots, 9.0), 1.0);
    }

    #[test]
    fn test_expr_operator_overloading() {
        fn quadratic(x: Expr<DirectEval, f64>) -> Expr<DirectEval, f64> {
            let a = Expr::constant(2.0);
            let b = Expr::constant(3.0);
            let c = Expr::constant(1.0);
            a * x.clone() * x.clone() + b * x + c
        }
        let x = Expr::var_with_value("x", 2.0);
        let result = quadratic(x);
        assert_eq!(result.eval(), 15.0);
        let x = Expr::var_with_value("x", 0.0);
        let result = quadratic(x);
        assert_eq!(result.eval(), 1.0);
    }

    #[test]
    fn test_expr_transcendental_functions() {
        let x = Expr::var_with_value("x", 5.0);
        let result = x.ln().exp();
        assert!((result.eval() - 5.0).abs() < 1e-10);

        // sin^2 + cos^2 = 1
        let x = Expr::var_with_value("x", 1.5);
        let sin_x = x.clone().sin();
        let cos_x = x.cos();
        let result = sin_x.clone() * sin_x + cos_x.clone() * cos_x;
        assert!((result.eval() - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_expr_pretty_print() {
        fn simple_expr(x: Expr<PrettyPrint, f64>) -> Expr<PrettyPrint, f64> {
            let two = Expr::constant(2.0);
            let three = Expr::constant(3.0);
            two * x + three
        }
        let x = Expr::<PrettyPrint, f64>::var("x");
        let pretty = simple_expr(x);
        let result = pretty.to_string();
        assert!(result.contains('x'));
        assert!(result.contains('2'));
        assert!(result.contains('3'));
        assert!(result.contains('*'));
        assert!(result.contains('+'));
    }

    #[test]
    fn test_expr_negation() {
        let x = Expr::var_with_value("x", 5.0);
        let neg_x = -x;
        assert_eq!(neg_x.eval(), -5.0);

        let x = Expr::var_with_value("x", 3.0);
        let y = Expr::var_with_value("y", 2.0);
        let result = -(x.clone() + y.clone());
        let expected = -x - y;
        assert_eq!(result.eval(), expected.eval());
        assert_eq!(result.eval(), -5.0);
    }

    #[test]
    fn test_expr_mixed_operations() {
        // (x + 1)(x - 1) == x^2 - 1 at x = 4
        let x = Expr::var_with_value("x", 4.0);
        let one = Expr::constant(1.0);
        let left = x.clone() + one.clone();
        let right = x.clone() - one;
        let result = left * right;
        assert_eq!(result.eval(), 15.0);
        let x_squared_minus_one = x.clone() * x - Expr::constant(1.0);
        assert_eq!(result.eval(), x_squared_minus_one.eval());
    }
}