#![warn(missing_docs)]
#![allow(unstable_name_collisions)]
pub mod error;
pub mod final_tagless;
pub mod ast_utils;
pub mod power_utils;
pub mod symbolic;
#[cfg(feature = "optimization")]
pub mod egglog_integration;
pub mod symbolic_ad;
pub mod summation;
pub mod backends;
pub mod transcendental;
pub mod ergonomics;
pub use error::{MathCompileError, Result};
pub use expr::Expr;
pub use final_tagless::{
ASTEval, ASTMathExpr, ASTRepr, DirectEval, MathExpr, NumericType, PrettyPrint, StatisticalExpr,
};
pub use symbolic::{
CompilationApproach, CompilationStrategy, OptimizationConfig, SymbolicOptimizer,
};
pub use backends::{CompiledRustFunction, RustCodeGenerator, RustCompiler, RustOptLevel};
pub use ergonomics::{MathBuilder, presets};
#[cfg(feature = "cranelift")]
pub use backends::cranelift::{CompilationStats, JITCompiler, JITFunction, JITSignature};
#[cfg(feature = "cranelift")]
pub use backends::cranelift;
pub use symbolic_ad::{FunctionWithDerivatives, SymbolicAD, SymbolicADConfig, SymbolicADStats};
pub use summation::{SumResult, SummationConfig, SummationPattern, SummationSimplifier};
pub use anf::{
ANFAtom, ANFCodeGen, ANFComputation, ANFConverter, ANFExpr, ANFVarGen, VarRef, convert_to_anf,
generate_rust_code,
};
pub mod anf;
/// The crate version string, read from `CARGO_PKG_VERSION` at compile time.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
/// Glob-importable collection of the most commonly used types, traits and functions.
///
/// Bring everything into scope with `use <this crate>::prelude::*;`. Items from the
/// Cranelift backend are only re-exported when the `cranelift` feature is enabled.
pub mod prelude {
    // Final-tagless core: interpreter types, expression traits and the AST representation.
    pub use crate::final_tagless::{
        ASTEval, ASTMathExpr, ASTRepr, DirectEval, ExpressionBuilder, MathExpr, NumericType,
        PrettyPrint, StatisticalExpr, VariableRegistry,
    };
    // Crate-wide error type and result alias.
    pub use crate::error::{MathCompileError, Result};
    // High-level builder API.
    pub use crate::ergonomics::{MathBuilder, presets};
    // Symbolic optimization.
    pub use crate::symbolic::{OptimizationConfig, SymbolicOptimizer};
    // Symbolic differentiation; `convenience` is re-exported under the name `ad_convenience`.
    pub use crate::symbolic_ad::{SymbolicAD, SymbolicADConfig, convenience as ad_convenience};
    // Rust code generation / compilation backend.
    pub use crate::backends::{
        CompiledRustFunction, RustCodeGenerator, RustCompiler, RustOptLevel,
    };
    // Cranelift JIT backend (only with the `cranelift` feature).
    #[cfg(feature = "cranelift")]
    pub use crate::backends::cranelift::{
        CompilationStats, JITCompiler, JITFunction, JITSignature,
    };
    // Operator-overloading expression wrapper.
    pub use crate::expr::Expr;
    // Summation simplification.
    pub use crate::summation::{SummationConfig, SummationSimplifier};
    // ANF conversion and ANF-based code generation.
    pub use crate::anf::{ANFCodeGen, ANFExpr, convert_to_anf, generate_rust_code};
}
pub mod expr {
    //! Operator-overloading wrapper over the final-tagless interpreters.
    //!
    //! [`Expr<E, T>`] stores interpreter `E`'s representation of a value of type `T`. It
    //! implements `+`, `-`, `*`, `/` and unary `-` by forwarding to the matching
    //! [`MathExpr`] methods. This lets expressions be written with ordinary Rust syntax and
    //! then be interpreted by whichever `E` was chosen, for example [`DirectEval`] or
    //! [`PrettyPrint`].

    use crate::final_tagless::{DirectEval, MathExpr, NumericType, PrettyPrint};
    use num_traits::Float;
    use std::marker::PhantomData;
    use std::ops::{Add, Div, Mul, Neg, Sub};

    /// An expression interpreted by `E` whose value type is `T`.
    ///
    /// Only `E::Repr<T>` is stored. `E` is carried in [`PhantomData`] so that the
    /// interpreter is part of the type, and every binary operator below requires both
    /// operands to use the same `E`.
    #[derive(Debug)]
    pub struct Expr<E: MathExpr, T> {
        pub(crate) repr: E::Repr<T>,
        _phantom: PhantomData<E>,
    }

    // Implemented by hand: `#[derive(Clone)]` would also demand `E: Clone`, but only the
    // stored representation actually needs to be cloneable.
    impl<E: MathExpr, T> Clone for Expr<E, T>
    where
        E::Repr<T>: Clone,
    {
        fn clone(&self) -> Self {
            Self {
                repr: self.repr.clone(),
                _phantom: PhantomData,
            }
        }
    }

    impl<E: MathExpr, T> Expr<E, T> {
        /// Wraps an existing interpreter representation.
        #[must_use]
        pub fn new(repr: E::Repr<T>) -> Self {
            Self {
                repr,
                _phantom: PhantomData,
            }
        }

        /// Consumes the wrapper and returns the underlying representation.
        #[must_use]
        pub fn into_repr(self) -> E::Repr<T> {
            self.repr
        }

        /// Borrows the underlying representation.
        #[must_use]
        pub fn as_repr(&self) -> &E::Repr<T> {
            &self.repr
        }

        /// Creates a constant expression via `E::constant`.
        #[must_use]
        pub fn constant(value: T) -> Self
        where
            T: NumericType,
        {
            Self::new(E::constant(value))
        }

        /// Creates a variable referenced by positional index via `E::var_by_index`.
        #[must_use]
        pub fn var_by_index(index: usize) -> Self
        where
            T: NumericType,
        {
            Self::new(E::var_by_index(index))
        }

        /// Creates a named variable via `E::var`.
        #[must_use]
        pub fn var(name: &str) -> Self
        where
            T: NumericType,
        {
            Self::new(E::var(name))
        }

        /// Raises `self` to the power `exp` via `E::pow`.
        #[must_use]
        pub fn pow(self, exp: Self) -> Self
        where
            T: NumericType + Float,
        {
            Self::new(E::pow(self.repr, exp.repr))
        }

        /// Natural logarithm of `self` via `E::ln`.
        #[must_use]
        pub fn ln(self) -> Self
        where
            T: NumericType + Float,
        {
            Self::new(E::ln(self.repr))
        }

        /// Exponential of `self` via `E::exp`.
        #[must_use]
        pub fn exp(self) -> Self
        where
            T: NumericType + Float,
        {
            Self::new(E::exp(self.repr))
        }

        /// Square root of `self` via `E::sqrt`.
        #[must_use]
        pub fn sqrt(self) -> Self
        where
            T: NumericType + Float,
        {
            Self::new(E::sqrt(self.repr))
        }

        /// Sine of `self` via `E::sin`.
        #[must_use]
        pub fn sin(self) -> Self
        where
            T: NumericType + Float,
        {
            Self::new(E::sin(self.repr))
        }

        /// Cosine of `self` via `E::cos`.
        #[must_use]
        pub fn cos(self) -> Self
        where
            T: NumericType + Float,
        {
            Self::new(E::cos(self.repr))
        }
    }

    impl<T> Expr<DirectEval, T> {
        /// Creates a variable for immediate evaluation.
        ///
        /// `name` and `value` are forwarded to `DirectEval::var`.
        #[must_use]
        pub fn var_with_value(name: &str, value: T) -> Self
        where
            T: NumericType,
        {
            Self::new(DirectEval::var(name, value))
        }

        /// Returns the computed value.
        ///
        /// Under `DirectEval` the representation is the value itself, so this simply
        /// unwraps it.
        #[must_use]
        pub fn eval(self) -> T {
            self.repr
        }
    }

    impl<T> Expr<PrettyPrint, T> {
        /// Consumes the expression and returns its rendered text.
        #[must_use]
        pub fn to_string(self) -> String {
            self.repr
        }
    }

    // The binary operators allow operands with different value types (`L`, `R`), as long as
    // `L` implements the corresponding `std::ops` trait with `R`. The result's value type is
    // that trait's `Output`.

    impl<E: MathExpr, L, R, Output> Add<Expr<E, R>> for Expr<E, L>
    where
        L: NumericType + Add<R, Output = Output>,
        R: NumericType,
        Output: NumericType,
    {
        type Output = Expr<E, Output>;

        fn add(self, rhs: Expr<E, R>) -> Self::Output {
            Expr::new(E::add(self.repr, rhs.repr))
        }
    }

    impl<E: MathExpr, L, R, Output> Sub<Expr<E, R>> for Expr<E, L>
    where
        L: NumericType + Sub<R, Output = Output>,
        R: NumericType,
        Output: NumericType,
    {
        type Output = Expr<E, Output>;

        fn sub(self, rhs: Expr<E, R>) -> Self::Output {
            Expr::new(E::sub(self.repr, rhs.repr))
        }
    }

    impl<E: MathExpr, L, R, Output> Mul<Expr<E, R>> for Expr<E, L>
    where
        L: NumericType + Mul<R, Output = Output>,
        R: NumericType,
        Output: NumericType,
    {
        type Output = Expr<E, Output>;

        fn mul(self, rhs: Expr<E, R>) -> Self::Output {
            Expr::new(E::mul(self.repr, rhs.repr))
        }
    }

    impl<E: MathExpr, L, R, Output> Div<Expr<E, R>> for Expr<E, L>
    where
        L: NumericType + Div<R, Output = Output>,
        R: NumericType,
        Output: NumericType,
    {
        type Output = Expr<E, Output>;

        fn div(self, rhs: Expr<E, R>) -> Self::Output {
            Expr::new(E::div(self.repr, rhs.repr))
        }
    }

    impl<E: MathExpr, T> Neg for Expr<E, T>
    where
        T: NumericType + Neg<Output = T>,
    {
        type Output = Expr<E, T>;

        fn neg(self) -> Self::Output {
            Expr::new(E::neg(self.repr))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The version string is embedded at compile time and must not be empty.
    #[test]
    fn test_version_info() {
        assert!(!VERSION.is_empty());
        println!("MathCompile version: {VERSION}");
    }

    /// Builds and evaluates simple expressions with `MathBuilder`.
    #[test]
    fn test_basic_expression_building() {
        let mut math = MathBuilder::new();
        let x = math.var("x");
        // 2x + 1 at x = 3
        let expr = math.add(&math.mul(&x, &math.constant(2.0)), &math.constant(1.0));
        let result = math.eval(&expr, &[("x", 3.0)]);
        assert_eq!(result, 7.0);

        // 2x + y at x = 3, y = 4
        let y = math.var("y");
        let expr2 = math.add(&math.mul(&x, &math.constant(2.0)), &y);
        let result2 = math.eval(&expr2, &[("x", 3.0), ("y", 4.0)]);
        assert_eq!(result2, 10.0);
    }

    /// Evaluates expressions containing identity/absorbing patterns (x + 0, x * 1, x * 0).
    #[test]
    fn test_optimization_pipeline() {
        let mut math = MathBuilder::new();
        let x = math.var("x");

        let expr = math.add(&x, &math.constant(0.0));
        let result = math.eval(&expr, &[("x", 5.0)]);
        assert_eq!(result, 5.0);

        let expr = math.mul(&x, &math.constant(1.0));
        let result = math.eval(&expr, &[("x", 7.0)]);
        assert_eq!(result, 7.0);

        let expr = math.mul(&x, &math.constant(0.0));
        let result = math.eval(&expr, &[("x", 100.0)]);
        assert_eq!(result, 0.0);

        let y = math.var("y");
        let expr = math.add(&math.mul(&x, &math.constant(2.0)), &y);
        let result = math.eval(&expr, &[("x", 3.0), ("y", 4.0)]);
        assert_eq!(result, 10.0);

        let expr = math.sin(&x);
        let result = math.eval(&expr, &[("x", 0.0)]);
        assert!(result.abs() < 1e-10);
    }

    /// JIT-compiles 2x + 1 and checks it against the builder's evaluation of the same formula.
    #[cfg(feature = "cranelift")]
    #[test]
    fn test_cranelift_compilation() {
        let mut math = MathBuilder::new();
        let x = math.var("x");
        let expr = math.add(&math.mul(&x, &math.constant(2.0)), &math.constant(1.0));
        // Cross-check: the builder expression must agree with the JIT result below.
        assert_eq!(math.eval(&expr, &[("x", 3.0)]), 7.0);

        let traditional_expr = <ASTEval as ASTMathExpr>::add(
            <ASTEval as ASTMathExpr>::mul(
                <ASTEval as ASTMathExpr>::var(0),
                <ASTEval as ASTMathExpr>::constant(2.0),
            ),
            <ASTEval as ASTMathExpr>::constant(1.0),
        );
        let compiler = JITCompiler::new().unwrap();
        let jit_func = compiler.compile_single_var(&traditional_expr, "x").unwrap();
        let result = jit_func.call_single(3.0);
        assert_eq!(result, 7.0);
    }

    /// Generates Rust source for 2x + 1 and checks the expected fragments appear.
    #[test]
    fn test_rust_code_generation() {
        let mut math = MathBuilder::new();
        let x = math.var("x");
        let expr = math.add(&math.mul(&x, &math.constant(2.0)), &math.constant(1.0));
        // Cross-check: the builder expression describes the same formula (7 at x = 3).
        assert_eq!(math.eval(&expr, &[("x", 3.0)]), 7.0);

        let traditional_expr = <ASTEval as ASTMathExpr>::add(
            <ASTEval as ASTMathExpr>::mul(
                <ASTEval as ASTMathExpr>::var(0),
                <ASTEval as ASTMathExpr>::constant(2.0),
            ),
            <ASTEval as ASTMathExpr>::constant(1.0),
        );
        let codegen = RustCodeGenerator::new();
        let rust_code = codegen
            .generate_function(&traditional_expr, "test_func")
            .unwrap();
        assert!(rust_code.contains("test_func"));
        assert!(rust_code.contains("var_0 * 2"));
        assert!(rust_code.contains("+ 1"));
    }

    /// 2x^2 + 3x + 1 written with overloaded operators and evaluated directly.
    #[test]
    fn test_expr_operator_overloading() {
        fn quadratic(x: Expr<DirectEval, f64>) -> Expr<DirectEval, f64> {
            let a = Expr::constant(2.0);
            let b = Expr::constant(3.0);
            let c = Expr::constant(1.0);
            a * x.clone() * x.clone() + b * x + c
        }

        let x = Expr::var_with_value("x", 2.0);
        assert_eq!(quadratic(x).eval(), 15.0);

        let x = Expr::var_with_value("x", 0.0);
        assert_eq!(quadratic(x).eval(), 1.0);
    }

    /// exp(ln x) == x and sin^2 + cos^2 == 1.
    #[test]
    fn test_expr_transcendental_functions() {
        let x = Expr::var_with_value("x", 5.0);
        let result = x.ln().exp();
        assert!((result.eval() - 5.0_f64).abs() < 1e-10);

        let x = Expr::var_with_value("x", 1.5_f64);
        let sin_x = x.clone().sin();
        let cos_x = x.cos();
        let result = sin_x.clone() * sin_x + cos_x.clone() * cos_x;
        assert!((result.eval() - 1.0_f64).abs() < 1e-10);
    }

    /// Pretty-printing 2x + 3 mentions every operand and operator.
    #[test]
    fn test_expr_pretty_print() {
        fn simple_expr(x: Expr<PrettyPrint, f64>) -> Expr<PrettyPrint, f64> {
            let two = Expr::constant(2.0);
            let three = Expr::constant(3.0);
            two * x + three
        }

        let x = Expr::<PrettyPrint, f64>::var("x");
        let result = simple_expr(x).to_string();
        assert!(result.contains('x'));
        assert!(result.contains('2'));
        assert!(result.contains('3'));
        assert!(result.contains('*'));
        assert!(result.contains('+'));
    }

    /// Unary negation, including -(x + y) == -x - y.
    #[test]
    fn test_expr_negation() {
        let x = Expr::var_with_value("x", 5.0);
        let neg_x = -x;
        assert_eq!(neg_x.eval(), -5.0);

        let x = Expr::var_with_value("x", 3.0);
        let y = Expr::var_with_value("y", 2.0);
        let result = -(x.clone() + y.clone());
        let expected = -x - y;
        let result_val = result.eval();
        assert_eq!(result_val, expected.eval());
        assert_eq!(result_val, -5.0);
    }

    /// (x + 1)(x - 1) == x^2 - 1 at x = 4.
    #[test]
    fn test_expr_mixed_operations() {
        let x = Expr::var_with_value("x", 4.0);
        let one = Expr::constant(1.0);
        let left = x.clone() + one.clone();
        let right = x.clone() - one;
        let result_val = (left * right).eval();
        assert_eq!(result_val, 15.0);

        let x_squared_minus_one = x.clone() * x - Expr::constant(1.0);
        assert_eq!(result_val, x_squared_minus_one.eval());
    }
}
#[cfg(test)]
mod integration_tests {
    use super::*;

    /// Optimizes `(x + 0) * 2 + (y - 0)`, generates Rust source from the optimized tree,
    /// and checks the direct evaluation at (3, 4) as well as the generated function name.
    #[test]
    fn test_end_to_end_pipeline() {
        type Ast = ASTEval;

        // Sub-terms are built in the same order the nested call evaluated them.
        let x_plus_zero = <Ast as ASTMathExpr>::add(
            <Ast as ASTMathExpr>::var(0),
            <Ast as ASTMathExpr>::constant(0.0),
        );
        let doubled =
            <Ast as ASTMathExpr>::mul(x_plus_zero, <Ast as ASTMathExpr>::constant(2.0));
        let y_minus_zero = <Ast as ASTMathExpr>::sub(
            <Ast as ASTMathExpr>::var(1),
            <Ast as ASTMathExpr>::constant(0.0),
        );
        let input = <Ast as ASTMathExpr>::add(doubled, y_minus_zero);

        let mut optimizer = SymbolicOptimizer::new().unwrap();
        let simplified = optimizer.optimize(&input).unwrap();

        let generator = RustCodeGenerator::new();
        let source = generator
            .generate_function(&simplified, "optimized_func")
            .unwrap();

        // 3 * 2 + 4
        let evaluated = DirectEval::eval_two_vars(&simplified, 3.0, 4.0);
        assert_eq!(evaluated, 10.0);
        assert!(source.contains("optimized_func"));
        println!("Generated optimized Rust code:\n{}", source);
    }

    /// With an adaptive strategy, early calls pick Cranelift; after enough recorded
    /// executions the optimizer should recommend a Rust-based approach.
    #[cfg(feature = "cranelift")]
    #[test]
    fn test_adaptive_compilation_strategy() {
        type Ast = ASTEval;
        const KEY: &str = "adaptive_test";

        let mut optimizer = SymbolicOptimizer::new().unwrap();
        optimizer.set_compilation_strategy(CompilationStrategy::Adaptive {
            call_threshold: 3,
            complexity_threshold: 10,
        });

        let x_plus_one = <Ast as ASTMathExpr>::add(
            <Ast as ASTMathExpr>::var(0),
            <Ast as ASTMathExpr>::constant(1.0),
        );

        for i in 0..5 {
            let chosen = optimizer.choose_compilation_approach(&x_plus_one, KEY);
            println!("Call {i}: {chosen:?}");
            let still_warming_up = i < 2;
            if still_warming_up {
                assert_eq!(chosen, CompilationApproach::Cranelift);
            }
            optimizer.record_execution(KEY, 1000);
        }

        let settled = optimizer.choose_compilation_approach(&x_plus_one, KEY);
        assert!(matches!(
            settled,
            CompilationApproach::UpgradeToRust | CompilationApproach::RustHotLoad
        ));
    }
}
pub mod pretty;
pub use pretty::{pretty_anf, pretty_ast};
pub mod interval_domain;