#![warn(missing_docs)]
#![allow(unstable_name_collisions)]
pub mod error;
pub mod final_tagless;
pub mod symbolic;
pub mod backends;
pub use error::{DSLCompileError, Result};
pub use expr::Expr;
pub use final_tagless::{
ASTEval,
ASTMathExpr,
ASTRepr,
DirectEval,
MathBuilder,
MathExpr,
NumericType,
PrettyPrint,
StatisticalExpr,
TypeCategory,
TypedBuilderExpr,
TypedVar,
VariableRegistry,
};
pub use symbolic::symbolic::{
CompilationApproach, CompilationStrategy, OptimizationConfig, SymbolicOptimizer,
};
pub use symbolic::anf;
pub use backends::{CompiledRustFunction, RustCodeGenerator, RustCompiler, RustOptLevel};
#[cfg(feature = "cranelift")]
pub use backends::{
CompilationMetadata, CraneliftCompiledFunction, CraneliftCompiler, CraneliftFunctionSignature,
CraneliftOptLevel,
};
#[cfg(feature = "cranelift")]
pub use backends::cranelift;
pub use symbolic::summation::{SumResult, SummationConfig, SummationPattern, SummationSimplifier};
pub use symbolic::anf::{
ANFAtom, ANFCodeGen, ANFComputation, ANFConverter, ANFExpr, ANFVarGen, DomainAwareANFConverter,
DomainAwareOptimizationStats, VarRef, convert_to_anf, generate_rust_code,
};
/// Version of this crate, captured from `CARGO_PKG_VERSION` at compile time.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
pub mod prelude {
pub use crate::final_tagless::{
ASTEval,
ASTMathExpr,
ASTRepr,
DirectEval,
ExpressionBuilder,
MathBuilder,
MathExpr,
NumericType,
PrettyPrint,
StatisticalExpr,
TypeCategory,
TypedBuilderExpr,
TypedVar,
VariableRegistry,
};
pub use crate::error::{DSLCompileError, Result};
pub use crate::symbolic::symbolic::{OptimizationConfig, SymbolicOptimizer};
pub use crate::symbolic::symbolic_ad::{
SymbolicAD, SymbolicADConfig, convenience as ad_convenience,
};
pub use crate::backends::{
CompiledRustFunction, RustCodeGenerator, RustCompiler, RustOptLevel,
};
#[cfg(feature = "cranelift")]
pub use crate::backends::cranelift::{
CompilationMetadata, CompiledFunction, CraneliftCompiler, FunctionSignature,
};
pub use crate::expr::Expr;
pub use crate::symbolic::summation::{SummationConfig, SummationSimplifier};
pub use crate::symbolic::anf::{
ANFCodeGen, ANFConverter, ANFExpr, DomainAwareANFConverter, DomainAwareOptimizationStats,
convert_to_anf, generate_rust_code,
};
}
/// Operator-overloaded wrapper around final-tagless expression representations.
pub mod expr {
    use crate::final_tagless::{DirectEval, MathExpr, NumericType, PrettyPrint};
    use num_traits::Float;
    use std::marker::PhantomData;
    use std::ops::{Add, Div, Mul, Neg, Sub};

    /// An expression built with the interpreter `E` over values of type `T`.
    ///
    /// `Expr` stores an `E::Repr<T>` and forwards every operation to the
    /// corresponding `E::…` method of [`MathExpr`]. Rust's arithmetic
    /// operators (`+`, `-`, `*`, `/`, unary `-`) and the methods below can
    /// therefore compose expressions for any interpreter.
    #[derive(Debug)]
    pub struct Expr<E: MathExpr, T> {
        /// The interpreter-specific representation of this expression.
        pub(crate) repr: E::Repr<T>,
        /// Zero-sized marker naming the interpreter `E`.
        _phantom: PhantomData<E>,
    }

    // Written by hand instead of derived. `#[derive(Clone)]` would add
    // `E: Clone` and `T: Clone` bounds, but only the stored representation
    // actually needs to be cloneable.
    impl<E: MathExpr, T> Clone for Expr<E, T>
    where
        E::Repr<T>: Clone,
    {
        fn clone(&self) -> Self {
            Self {
                repr: self.repr.clone(),
                _phantom: PhantomData,
            }
        }
    }

    impl<E: MathExpr, T> Expr<E, T> {
        /// Wraps an existing interpreter representation.
        #[must_use]
        pub fn new(repr: E::Repr<T>) -> Self {
            Self {
                repr,
                _phantom: PhantomData,
            }
        }

        /// Consumes the wrapper and returns the underlying representation.
        #[must_use]
        pub fn into_repr(self) -> E::Repr<T> {
            self.repr
        }

        /// Borrows the underlying representation.
        #[must_use]
        pub fn as_repr(&self) -> &E::Repr<T> {
            &self.repr
        }

        /// Builds a constant expression via `E::constant`.
        #[must_use]
        pub fn constant(value: T) -> Self
        where
            T: NumericType,
        {
            Self::new(E::constant(value))
        }

        /// Builds a reference to the variable at `index` via `E::var`.
        #[must_use]
        pub fn var(index: usize) -> Self
        where
            T: NumericType,
        {
            Self::new(E::var(index))
        }

        /// Raises `self` to the power `exp` via `E::pow`.
        #[must_use]
        pub fn pow(self, exp: Self) -> Self
        where
            T: NumericType + Float,
        {
            Self::new(E::pow(self.repr, exp.repr))
        }

        /// Natural logarithm via `E::ln`.
        #[must_use]
        pub fn ln(self) -> Self
        where
            T: NumericType + Float,
        {
            Self::new(E::ln(self.repr))
        }

        /// Exponential function via `E::exp`.
        #[must_use]
        pub fn exp(self) -> Self
        where
            T: NumericType + Float,
        {
            Self::new(E::exp(self.repr))
        }

        /// Square root via `E::sqrt`.
        #[must_use]
        pub fn sqrt(self) -> Self
        where
            T: NumericType + Float,
        {
            Self::new(E::sqrt(self.repr))
        }

        /// Sine via `E::sin`.
        #[must_use]
        pub fn sin(self) -> Self
        where
            T: NumericType + Float,
        {
            Self::new(E::sin(self.repr))
        }

        /// Cosine via `E::cos`.
        #[must_use]
        pub fn cos(self) -> Self
        where
            T: NumericType + Float,
        {
            Self::new(E::cos(self.repr))
        }
    }

    impl<T> Expr<DirectEval, T> {
        /// Creates a directly-evaluated variable from its value.
        ///
        /// Under [`DirectEval`] an expression is represented by its value, so
        /// the index argument is not used. It is accepted only for symmetry
        /// with [`Expr::var`].
        #[must_use]
        pub fn var_with_value(_index: usize, value: T) -> Self
        where
            T: NumericType,
        {
            Self::new(value)
        }

        /// Returns the value produced by direct evaluation.
        #[must_use]
        pub fn eval(self) -> T {
            self.repr
        }
    }

    impl<T> Expr<PrettyPrint, T> {
        /// Returns the pretty-printed textual form of the expression.
        #[must_use]
        pub fn to_string(self) -> String {
            self.repr
        }
    }

    // Binary operators forward to the interpreter. Operand types may differ
    // as long as the underlying numeric operation is defined.
    impl<E: MathExpr, L, R, Output> Add<Expr<E, R>> for Expr<E, L>
    where
        L: NumericType + Add<R, Output = Output>,
        R: NumericType,
        Output: NumericType,
    {
        type Output = Expr<E, Output>;

        fn add(self, rhs: Expr<E, R>) -> Self::Output {
            Expr::new(E::add(self.repr, rhs.repr))
        }
    }

    impl<E: MathExpr, L, R, Output> Sub<Expr<E, R>> for Expr<E, L>
    where
        L: NumericType + Sub<R, Output = Output>,
        R: NumericType,
        Output: NumericType,
    {
        type Output = Expr<E, Output>;

        fn sub(self, rhs: Expr<E, R>) -> Self::Output {
            Expr::new(E::sub(self.repr, rhs.repr))
        }
    }

    impl<E: MathExpr, L, R, Output> Mul<Expr<E, R>> for Expr<E, L>
    where
        L: NumericType + Mul<R, Output = Output>,
        R: NumericType,
        Output: NumericType,
    {
        type Output = Expr<E, Output>;

        fn mul(self, rhs: Expr<E, R>) -> Self::Output {
            Expr::new(E::mul(self.repr, rhs.repr))
        }
    }

    impl<E: MathExpr, L, R, Output> Div<Expr<E, R>> for Expr<E, L>
    where
        L: NumericType + Div<R, Output = Output>,
        R: NumericType,
        Output: NumericType,
    {
        type Output = Expr<E, Output>;

        fn div(self, rhs: Expr<E, R>) -> Self::Output {
            Expr::new(E::div(self.repr, rhs.repr))
        }
    }

    // Unary negation forwards to `E::neg`.
    impl<E: MathExpr, T> Neg for Expr<E, T>
    where
        T: NumericType + Neg<Output = T>,
    {
        type Output = Expr<E, T>;

        fn neg(self) -> Self::Output {
            Expr::new(E::neg(self.repr))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_version_info() {
        assert!(!VERSION.is_empty());
        println!("DSLCompile version: {VERSION}");
    }

    #[test]
    fn test_ergonomic_api() {
        let math = MathBuilder::new();
        let x = math.var();
        let expr = &x * 2.0 + 1.0;
        let result = math.eval(&expr, &[3.0]);
        assert_eq!(result, 7.0);

        let y = math.var();
        let expr2 = &x * 2.0 + &y;
        let result2 = math.eval(&expr2, &[3.0, 4.0]);
        assert_eq!(result2, 10.0);
    }

    #[test]
    fn test_optimization_pipeline() {
        let math = MathBuilder::new();
        let x = math.var();
        let expr = &x - &x;
        let optimized_result = math.eval(&expr, &[5.0]);
        assert_eq!(optimized_result, 0.0);

        let y = math.var();
        let expr = &x * 2.0 + &y;
        let result = math.eval(&expr, &[3.0, 4.0]);
        assert_eq!(result, 10.0);
    }

    #[test]
    fn test_transcendental_functions() {
        let math = MathBuilder::new();
        let x = math.var();
        // sin(0) should be (numerically) zero.
        let result = math.eval(&x.sin(), &[0.0]);
        assert!(result.abs() < 1e-10);
    }

    // NOTE(review): this test only builds expressions. It never invokes the
    // Cranelift backend, so it asserts nothing yet; it is kept `#[ignore]`d
    // as a placeholder. The bindings are `_`-prefixed to avoid
    // unused-variable warnings when the `cranelift` feature is enabled.
    #[cfg(feature = "cranelift")]
    #[test]
    #[ignore]
    fn test_cranelift_compilation() {
        let math = MathBuilder::new();
        let x = math.var();
        let _expr = &x * 2.0 + 1.0;

        let _traditional_expr = <ASTEval as ASTMathExpr>::add(
            <ASTEval as ASTMathExpr>::mul(
                <ASTEval as ASTMathExpr>::var(0),
                <ASTEval as ASTMathExpr>::constant(2.0),
            ),
            <ASTEval as ASTMathExpr>::constant(1.0),
        );
    }

    #[test]
    fn test_rust_code_generation() {
        let math = MathBuilder::new();
        let x = math.var();
        let _expr = &x * 2.0 + 1.0;

        // x * 2 + 1 built through the AST interpreter.
        let traditional_expr = <ASTEval as ASTMathExpr>::add(
            <ASTEval as ASTMathExpr>::mul(
                <ASTEval as ASTMathExpr>::var(0),
                <ASTEval as ASTMathExpr>::constant(2.0),
            ),
            <ASTEval as ASTMathExpr>::constant(1.0),
        );

        let codegen = RustCodeGenerator::new();
        let rust_code = codegen
            .generate_function(&traditional_expr, "test_func")
            .unwrap();
        assert!(rust_code.contains("test_func"));
        assert!(rust_code.contains("var_0 * 2"));
        assert!(rust_code.contains("+ 1"));
    }
}
#[cfg(test)]
mod integration_tests {
    use super::*;

    /// Optimizes `(x + 0) * 2 + (y - 0)`, generates Rust source for the
    /// result, and checks that the optimized AST evaluates to `3 * 2 + 4`.
    #[test]
    fn test_end_to_end_pipeline() {
        // Left operand: (x + 0) * 2
        let doubled = <ASTEval as ASTMathExpr>::mul(
            <ASTEval as ASTMathExpr>::add(
                <ASTEval as ASTMathExpr>::var(0),
                <ASTEval as ASTMathExpr>::constant(0.0),
            ),
            <ASTEval as ASTMathExpr>::constant(2.0),
        );
        // Right operand: y - 0
        let offset = <ASTEval as ASTMathExpr>::sub(
            <ASTEval as ASTMathExpr>::var(1),
            <ASTEval as ASTMathExpr>::constant(0.0),
        );
        let expr = <ASTEval as ASTMathExpr>::add(doubled, offset);

        let mut optimizer = SymbolicOptimizer::new().unwrap();
        let optimized = optimizer.optimize(&expr).unwrap();

        let codegen = RustCodeGenerator::new();
        let rust_code = codegen
            .generate_function(&optimized, "optimized_func")
            .unwrap();

        assert_eq!(optimized.eval_two_vars(3.0, 4.0), 10.0);
        assert!(rust_code.contains("optimized_func"));
        println!("Generated optimized Rust code:\n{rust_code}");
    }

    /// Under the adaptive strategy, the first decisions for a function are
    /// expected to be Cranelift. After several recorded executions, the
    /// optimizer is expected to move to a Rust-based approach.
    #[cfg(feature = "cranelift")]
    #[test]
    fn test_adaptive_compilation_strategy() {
        let mut optimizer = SymbolicOptimizer::new().unwrap();
        let strategy = CompilationStrategy::Adaptive {
            call_threshold: 3,
            complexity_threshold: 10,
        };
        optimizer.set_compilation_strategy(strategy);

        // x + 1
        let expr = <ASTEval as ASTMathExpr>::add(
            <ASTEval as ASTMathExpr>::var(0),
            <ASTEval as ASTMathExpr>::constant(1.0),
        );

        for i in 0..5 {
            let approach = optimizer.choose_compilation_approach(&expr, "adaptive_test");
            println!("Call {i}: {approach:?}");
            let is_early_call = i < 2;
            if is_early_call {
                assert_eq!(approach, CompilationApproach::Cranelift);
            }
            optimizer.record_execution("adaptive_test", 1000);
        }

        let final_approach = optimizer.choose_compilation_approach(&expr, "adaptive_test");
        let upgraded = matches!(
            final_approach,
            CompilationApproach::UpgradeToRust | CompilationApproach::RustHotLoad
        );
        assert!(upgraded);
    }
}
pub mod interval_domain;
/// Polynomial utilities, re-exported wholesale from `final_tagless::polynomial`
/// so they are reachable at the crate root.
pub mod polynomial {
    pub use crate::final_tagless::polynomial::*;
}
pub mod ast;
pub mod compile_time;