pub mod optimized;
pub mod scoped;
pub use scoped::{
ScopedConst, ScopedMathExpr, ScopedVar, ScopedVarArray, compose, scoped_constant, scoped_var,
};
/// A statically-typed expression tree over `f64` values.
///
/// Every node type implements [`MathExpr::eval`]. The provided combinators consume
/// `self` and wrap it in a new node, so the whole shape of an expression is recorded
/// in its type (for example `x.add(y)` has type `Add<X, Y>`).
pub trait MathExpr: Clone + Sized {
    /// Evaluates the expression; `vars` supplies the values read by variable leaves.
    fn eval(&self, vars: &[f64]) -> f64;

    /// Builds the node `self + rhs`.
    fn add<R>(self, rhs: R) -> Add<Self, R>
    where
        R: MathExpr,
    {
        Add {
            left: self,
            right: rhs,
        }
    }

    /// Builds the node `self * rhs`.
    fn mul<R>(self, rhs: R) -> Mul<Self, R>
    where
        R: MathExpr,
    {
        Mul {
            left: self,
            right: rhs,
        }
    }

    /// Builds the node `self - rhs`.
    fn sub<R>(self, rhs: R) -> Sub<Self, R>
    where
        R: MathExpr,
    {
        Sub {
            left: self,
            right: rhs,
        }
    }

    /// Builds the node `self / rhs`.
    fn div<R>(self, rhs: R) -> Div<Self, R>
    where
        R: MathExpr,
    {
        Div {
            left: self,
            right: rhs,
        }
    }

    /// Builds the node `self` raised to the power `exponent`.
    fn pow<E>(self, exponent: E) -> Pow<Self, E>
    where
        E: MathExpr,
    {
        Pow {
            base: self,
            exponent,
        }
    }

    /// Builds the node `e^self`.
    fn exp(self) -> Exp<Self> {
        Exp { inner: self }
    }

    /// Builds the node `ln(self)` (natural logarithm).
    fn ln(self) -> Ln<Self> {
        Ln { inner: self }
    }

    /// Builds the node `sin(self)`.
    fn sin(self) -> Sin<Self> {
        Sin { inner: self }
    }

    /// Builds the node `cos(self)`.
    fn cos(self) -> Cos<Self> {
        Cos { inner: self }
    }

    /// Builds the node `sqrt(self)`.
    fn sqrt(self) -> Sqrt<Self> {
        Sqrt { inner: self }
    }

    /// Builds the node `-self`.
    fn neg(self) -> Neg<Self> {
        Neg { inner: self }
    }
}
/// A compile-time algebraic rewrite of one specific expression shape.
///
/// Implementations exist only for the concrete node shapes they simplify. Calling
/// `optimize` consumes the expression and returns the rewritten tree, whose type
/// is the associated `Optimized` type.
pub trait Optimize: MathExpr {
    /// Expression type produced by the rewrite.
    type Optimized: MathExpr;

    /// Consumes `self` and returns the rewritten expression.
    fn optimize(self) -> <Self as Optimize>::Optimized;
}
/// A variable leaf that reads position `ID` of the slice passed to `eval`.
///
/// Out-of-range positions do not panic: if `vars` has no element at `ID`, the
/// variable evaluates to `0.0`.
#[derive(Clone, Debug)]
pub struct Var<const ID: usize>;

impl<const ID: usize> MathExpr for Var<ID> {
    fn eval(&self, vars: &[f64]) -> f64 {
        // Fall back to a literal zero when the slot is missing.
        *vars.get(ID).unwrap_or(&0.0)
    }
}
/// A zero-sized constant whose `f64` value is encoded in the type as its raw
/// IEEE-754 bit pattern (`BITS`).
#[derive(Clone, Debug)]
pub struct Const<const BITS: u64>;

impl<const BITS: u64> Const<BITS> {
    /// Creates the constant.
    ///
    /// NOTE: the argument is neither stored nor checked. The value always comes
    /// from the `BITS` type parameter, so `Const::<0>::new(5.0).value()` is `0.0`.
    #[must_use]
    pub fn new(_value: f64) -> Self {
        Const
    }

    /// Decodes the constant's value from the `BITS` type parameter.
    #[must_use]
    pub fn value(&self) -> f64 {
        let bits: u64 = BITS;
        f64::from_bits(bits)
    }
}
/// A type-level constant ignores the variable slice and yields its encoded value.
impl<const BITS: u64> MathExpr for Const<BITS> {
    fn eval(&self, _: &[f64]) -> f64 {
        f64::from_bits(BITS)
    }
}
/// Sum node: evaluates to `left + right`.
#[derive(Clone, Debug)]
pub struct Add<L: MathExpr, R: MathExpr> {
    left: L,
    right: R,
}

impl<L, R> MathExpr for Add<L, R>
where
    L: MathExpr,
    R: MathExpr,
{
    fn eval(&self, vars: &[f64]) -> f64 {
        let lhs = self.left.eval(vars);
        let rhs = self.right.eval(vars);
        lhs + rhs
    }
}
/// Product node: evaluates to `left * right`.
#[derive(Clone, Debug)]
pub struct Mul<L: MathExpr, R: MathExpr> {
    left: L,
    right: R,
}

impl<L, R> MathExpr for Mul<L, R>
where
    L: MathExpr,
    R: MathExpr,
{
    fn eval(&self, vars: &[f64]) -> f64 {
        let lhs = self.left.eval(vars);
        let rhs = self.right.eval(vars);
        lhs * rhs
    }
}
/// Difference node: evaluates to `left - right`.
#[derive(Clone, Debug)]
pub struct Sub<L: MathExpr, R: MathExpr> {
    left: L,
    right: R,
}

impl<L, R> MathExpr for Sub<L, R>
where
    L: MathExpr,
    R: MathExpr,
{
    fn eval(&self, vars: &[f64]) -> f64 {
        let minuend = self.left.eval(vars);
        let subtrahend = self.right.eval(vars);
        minuend - subtrahend
    }
}
/// Quotient node: evaluates to `left / right` using `f64` division.
///
/// A zero divisor produces an infinity or NaN rather than a panic.
#[derive(Clone, Debug)]
pub struct Div<L: MathExpr, R: MathExpr> {
    left: L,
    right: R,
}

impl<L, R> MathExpr for Div<L, R>
where
    L: MathExpr,
    R: MathExpr,
{
    fn eval(&self, vars: &[f64]) -> f64 {
        let numerator = self.left.eval(vars);
        let denominator = self.right.eval(vars);
        numerator / denominator
    }
}
/// Power node: evaluates to `base` raised to `exponent` via `f64::powf`.
#[derive(Clone, Debug)]
pub struct Pow<B: MathExpr, E: MathExpr> {
    base: B,
    exponent: E,
}

impl<B, E> MathExpr for Pow<B, E>
where
    B: MathExpr,
    E: MathExpr,
{
    fn eval(&self, vars: &[f64]) -> f64 {
        // The base is evaluated before the exponent, matching receiver-then-argument order.
        let b = self.base.eval(vars);
        let e = self.exponent.eval(vars);
        f64::powf(b, e)
    }
}
/// Exponential node: evaluates to `e^inner`.
#[derive(Clone, Debug)]
pub struct Exp<T: MathExpr> {
    inner: T,
}

impl<T> MathExpr for Exp<T>
where
    T: MathExpr,
{
    fn eval(&self, vars: &[f64]) -> f64 {
        let arg = self.inner.eval(vars);
        f64::exp(arg)
    }
}
/// Natural-logarithm node: evaluates to `ln(inner)`.
///
/// Follows `f64::ln`: zero gives negative infinity and negative inputs give NaN.
#[derive(Clone, Debug)]
pub struct Ln<T: MathExpr> {
    inner: T,
}

impl<T> MathExpr for Ln<T>
where
    T: MathExpr,
{
    fn eval(&self, vars: &[f64]) -> f64 {
        let arg = self.inner.eval(vars);
        f64::ln(arg)
    }
}
/// Sine node: evaluates to `sin(inner)`, with `inner` in radians.
#[derive(Clone, Debug)]
pub struct Sin<T: MathExpr> {
    inner: T,
}

impl<T> MathExpr for Sin<T>
where
    T: MathExpr,
{
    fn eval(&self, vars: &[f64]) -> f64 {
        let angle = self.inner.eval(vars);
        f64::sin(angle)
    }
}
/// Cosine node: evaluates to `cos(inner)`, with `inner` in radians.
#[derive(Clone, Debug)]
pub struct Cos<T: MathExpr> {
    inner: T,
}

impl<T> MathExpr for Cos<T>
where
    T: MathExpr,
{
    fn eval(&self, vars: &[f64]) -> f64 {
        let angle = self.inner.eval(vars);
        f64::cos(angle)
    }
}
/// Square-root node: evaluates to `sqrt(inner)`.
///
/// Follows `f64::sqrt`: negative inputs give NaN.
#[derive(Clone, Debug)]
pub struct Sqrt<T: MathExpr> {
    inner: T,
}

impl<T> MathExpr for Sqrt<T>
where
    T: MathExpr,
{
    fn eval(&self, vars: &[f64]) -> f64 {
        let radicand = self.inner.eval(vars);
        f64::sqrt(radicand)
    }
}
/// Negation node: evaluates to `-inner` (IEEE sign flip, so `0.0` becomes `-0.0`).
#[derive(Clone, Debug)]
pub struct Neg<T: MathExpr> {
    inner: T,
}

impl<T> MathExpr for Neg<T>
where
    T: MathExpr,
{
    fn eval(&self, vars: &[f64]) -> f64 {
        let value = self.inner.eval(vars);
        -value
    }
}
impl<T: MathExpr> Optimize for Ln<Exp<T>> {
type Optimized = T;
fn optimize(self) -> T {
self.inner.inner
}
}
impl<T: MathExpr> Optimize for Exp<Ln<T>> {
type Optimized = T;
fn optimize(self) -> T {
self.inner.inner
}
}
impl<const ID: usize> Optimize for Add<Var<ID>, Const<0>> {
type Optimized = Var<ID>;
fn optimize(self) -> Var<ID> {
self.left
}
}
impl<const ID: usize> Optimize for Add<Const<0>, Var<ID>> {
type Optimized = Var<ID>;
fn optimize(self) -> Var<ID> {
self.right
}
}
impl<const ID: usize> Optimize for Mul<Var<ID>, Const<4607182418800017408>> {
type Optimized = Var<ID>;
fn optimize(self) -> Var<ID> {
self.left
}
}
impl<const ID: usize> Optimize for Mul<Const<4607182418800017408>, Var<ID>> {
type Optimized = Var<ID>;
fn optimize(self) -> Var<ID> {
self.right
}
}
/// Rewrites `x * 0` to the constant `0`.
///
/// NOTE(review): this is not IEEE-exact. For `x` that is NaN or infinite, the
/// unoptimized product is NaN. For negative finite `x`, it is `-0.0`. The rewrite
/// always yields `+0.0`.
impl<const ID: usize> Optimize for Mul<Var<ID>, Const<0>> {
    type Optimized = Const<0>;

    fn optimize(self) -> Self::Optimized {
        Const
    }
}

/// Rewrites `0 * x` to the constant `0`. The same IEEE caveats as `x * 0` apply.
impl<const ID: usize> Optimize for Mul<Const<0>, Var<ID>> {
    type Optimized = Const<0>;

    fn optimize(self) -> Self::Optimized {
        Const
    }
}
impl<A: MathExpr, B: MathExpr> Optimize for Ln<Mul<A, B>> {
type Optimized = Add<Ln<A>, Ln<B>>;
fn optimize(self) -> Add<Ln<A>, Ln<B>> {
Add {
left: Ln {
inner: self.inner.left,
},
right: Ln {
inner: self.inner.right,
},
}
}
}
impl<A: MathExpr, B: MathExpr> Optimize for Exp<Add<A, B>> {
type Optimized = Mul<Exp<A>, Exp<B>>;
fn optimize(self) -> Mul<Exp<A>, Exp<B>> {
Mul {
left: Exp {
inner: self.inner.left,
},
right: Exp {
inner: self.inner.right,
},
}
}
}
/// Returns the variable leaf that reads `vars[ID]` when evaluated.
///
/// A missing slot evaluates to `0.0`.
#[must_use]
pub const fn var<const ID: usize>() -> Var<ID> {
    Var::<ID>
}
/// Wraps a runtime `f64` in a leaf expression.
///
/// The leaf evaluates to `value` regardless of the variable slice and also converts
/// to an AST constant via `optimized::ToAst`.
#[must_use]
pub fn constant(value: f64) -> impl MathExpr + optimized::ToAst {
    ConstantValue { value }
}
/// Leaf holding a runtime `f64`.
///
/// Unlike `Const<BITS>`, the value lives in the struct, not in the type.
#[derive(Clone, Debug)]
pub struct ConstantValue {
    value: f64,
}

impl MathExpr for ConstantValue {
    fn eval(&self, _vars: &[f64]) -> f64 {
        let Self { value } = self;
        *value
    }
}

/// Converts the stored value into an `ASTRepr::Constant` node.
impl optimized::ToAst for ConstantValue {
    fn to_ast(&self) -> crate::ast::ASTRepr<f64> {
        let value = self.value;
        crate::ast::ASTRepr::Constant(value)
    }
}
/// Returns the type-level constant `+0.0`; its bit pattern is all zeros.
#[must_use]
pub const fn zero() -> Const<0> {
    Const::<0>
}
/// Returns the type-level constant `1.0`.
///
/// `0x3FF0_0000_0000_0000` is the IEEE-754 bit pattern of `1.0_f64`. It equals the
/// decimal `4607182418800017408`, so the return type is unchanged.
#[must_use]
pub const fn one() -> Const<0x3FF0_0000_0000_0000> {
    Const::<0x3FF0_0000_0000_0000>
}
pub use dslcompile_macros::optimize_compile_time;
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for results that pass through transcendental functions.
    const EPS: f64 = 1e-10;

    #[test]
    fn test_basic_evaluation() {
        let expr = var::<0>().add(var::<1>());
        assert_eq!(expr.eval(&[2.0, 3.0]), 5.0);
    }

    #[test]
    fn test_complex_expression() {
        let x = var::<0>();
        let y = var::<1>();
        // x * y + x with x = 2, y = 3.
        let expr = x.clone().mul(y).add(x);
        assert_eq!(expr.eval(&[2.0, 3.0]), 8.0);
    }

    #[test]
    fn test_transcendental_functions() {
        let expr = var::<0>().exp().ln();
        let result = expr.eval(&[2.0]);
        assert!((result - 2.0).abs() < EPS);
    }

    #[test]
    fn test_missing_variable_evaluates_to_zero() {
        assert_eq!(var::<3>().eval(&[1.0]), 0.0);
    }

    #[test]
    fn test_constant_helpers() {
        assert_eq!(zero().value(), 0.0);
        assert_eq!(one().value(), 1.0);
        assert_eq!(constant(2.5).eval(&[]), 2.5);
    }

    #[test]
    fn test_ln_exp_optimization() {
        let original = var::<0>().exp().ln();
        let optimized = original.clone().optimize();
        let original_result = original.eval(&[2.0]);
        let optimized_result = optimized.eval(&[2.0]);
        assert!((original_result - optimized_result).abs() < EPS);
        assert!((optimized_result - 2.0).abs() < EPS);
    }

    #[test]
    fn test_exp_ln_optimization() {
        let original = var::<0>().ln().exp();
        let optimized = original.clone().optimize();
        let original_result = original.eval(&[2.0]);
        let optimized_result = optimized.eval(&[2.0]);
        assert!((original_result - optimized_result).abs() < EPS);
        assert_eq!(optimized_result, 2.0);
    }

    #[test]
    fn test_zero_addition_optimization() {
        let original = var::<0>().add(zero());
        let optimized = original.clone().optimize();
        let original_result = original.eval(&[5.0]);
        let optimized_result = optimized.eval(&[5.0]);
        assert_eq!(original_result, optimized_result);
        assert_eq!(optimized_result, 5.0);

        let mirrored = zero().add(var::<0>());
        assert_eq!(mirrored.eval(&[5.0]), 5.0);
        assert_eq!(mirrored.optimize().eval(&[5.0]), 5.0);
    }

    #[test]
    fn test_one_multiplication_optimization() {
        let original = var::<0>().mul(one());
        assert_eq!(original.eval(&[3.0]), 3.0);
        assert_eq!(original.optimize().eval(&[3.0]), 3.0);

        let mirrored = one().mul(var::<0>());
        assert_eq!(mirrored.eval(&[3.0]), 3.0);
        assert_eq!(mirrored.optimize().eval(&[3.0]), 3.0);
    }

    #[test]
    fn test_zero_multiplication_optimization() {
        let original = var::<0>().mul(zero());
        assert_eq!(original.eval(&[3.0]), 0.0);
        assert_eq!(original.optimize().eval(&[3.0]), 0.0);

        let mirrored = zero().mul(var::<0>());
        assert_eq!(mirrored.eval(&[3.0]), 0.0);
        assert_eq!(mirrored.optimize().eval(&[3.0]), 0.0);
    }

    #[test]
    fn test_ln_product_optimization() {
        let original = var::<0>().mul(var::<1>()).ln();
        let optimized = original.clone().optimize();
        let vars = [2.0, 3.0];
        assert!((original.eval(&vars) - optimized.eval(&vars)).abs() < EPS);
        assert!((optimized.eval(&vars) - 6.0_f64.ln()).abs() < EPS);
    }

    #[test]
    fn test_exp_sum_optimization() {
        let original = var::<0>().add(var::<1>()).exp();
        let optimized = original.clone().optimize();
        let vars = [0.5, 1.5];
        assert!((original.eval(&vars) - optimized.eval(&vars)).abs() < EPS);
        assert!((optimized.eval(&vars) - 2.0_f64.exp()).abs() < EPS);
    }
}