use crate::ast::ASTRepr;
use crate::compile_time::{Add, Const, Cos, Exp, Ln, MathExpr, Mul, Pow, Sin, Sub, Var};
#[must_use]
pub fn generate_direct_code(ast: &ASTRepr<f64>, var_names: &[&str]) -> String {
match ast {
ASTRepr::Constant(c) => format!("{c}"),
ASTRepr::Variable(idx) => {
if *idx < var_names.len() {
var_names[*idx].to_string()
} else {
format!("0.0 /* undefined var {idx} */")
}
}
ASTRepr::Add(left, right) => {
format!(
"({} + {})",
generate_direct_code(left, var_names),
generate_direct_code(right, var_names)
)
}
ASTRepr::Sub(left, right) => {
format!(
"({} - {})",
generate_direct_code(left, var_names),
generate_direct_code(right, var_names)
)
}
ASTRepr::Mul(left, right) => {
format!(
"({} * {})",
generate_direct_code(left, var_names),
generate_direct_code(right, var_names)
)
}
ASTRepr::Div(left, right) => {
format!(
"({} / {})",
generate_direct_code(left, var_names),
generate_direct_code(right, var_names)
)
}
ASTRepr::Pow(base, exp) => {
format!(
"{}.powf({})",
generate_direct_code_with_parens(base, var_names),
generate_direct_code(exp, var_names)
)
}
ASTRepr::Sin(inner) => {
format!(
"{}.sin()",
generate_direct_code_with_parens(inner, var_names)
)
}
ASTRepr::Cos(inner) => {
format!(
"{}.cos()",
generate_direct_code_with_parens(inner, var_names)
)
}
ASTRepr::Exp(inner) => {
format!(
"{}.exp()",
generate_direct_code_with_parens(inner, var_names)
)
}
ASTRepr::Ln(inner) => {
format!(
"{}.ln()",
generate_direct_code_with_parens(inner, var_names)
)
}
ASTRepr::Sqrt(inner) => {
format!(
"{}.sqrt()",
generate_direct_code_with_parens(inner, var_names)
)
}
ASTRepr::Neg(inner) => {
format!("-({})", generate_direct_code(inner, var_names))
}
}
}
fn generate_direct_code_with_parens(ast: &ASTRepr<f64>, var_names: &[&str]) -> String {
match ast {
ASTRepr::Constant(_) | ASTRepr::Variable(_) => generate_direct_code(ast, var_names),
_ => {
format!("({})", generate_direct_code(ast, var_names))
}
}
}
#[must_use]
pub fn apply_simple_optimizations(ast: &ASTRepr<f64>) -> Option<ASTRepr<f64>> {
match ast {
ASTRepr::Ln(inner) => {
if let ASTRepr::Exp(exp_inner) = inner.as_ref() {
Some((**exp_inner).clone())
} else if let ASTRepr::Mul(left, right) = inner.as_ref() {
Some(ASTRepr::Add(
Box::new(ASTRepr::Ln(left.clone())),
Box::new(ASTRepr::Ln(right.clone())),
))
} else {
None
}
}
ASTRepr::Exp(inner) => {
if let ASTRepr::Ln(ln_inner) = inner.as_ref() {
Some((**ln_inner).clone())
} else if let ASTRepr::Add(left, right) = inner.as_ref() {
Some(ASTRepr::Mul(
Box::new(ASTRepr::Exp(left.clone())),
Box::new(ASTRepr::Exp(right.clone())),
))
} else {
None
}
}
ASTRepr::Add(left, right) => {
if let ASTRepr::Constant(0.0) = right.as_ref() {
Some((**left).clone())
} else if let ASTRepr::Constant(0.0) = left.as_ref() {
Some((**right).clone())
} else {
None
}
}
ASTRepr::Mul(left, right) => {
if let ASTRepr::Constant(1.0) = right.as_ref() {
Some((**left).clone())
} else if let ASTRepr::Constant(1.0) = left.as_ref() {
Some((**right).clone())
} else if let ASTRepr::Constant(0.0) = right.as_ref() {
Some(ASTRepr::Constant(0.0))
} else if let ASTRepr::Constant(0.0) = left.as_ref() {
Some(ASTRepr::Constant(0.0))
} else if let (ASTRepr::Exp(exp_left), ASTRepr::Exp(exp_right)) =
(left.as_ref(), right.as_ref())
{
Some(ASTRepr::Exp(Box::new(ASTRepr::Add(
exp_left.clone(),
exp_right.clone(),
))))
} else {
None
}
}
_ => None,
}
}
pub trait ToAst {
fn to_ast(&self) -> ASTRepr<f64>;
}
impl<const ID: usize> ToAst for Var<ID> {
fn to_ast(&self) -> ASTRepr<f64> {
ASTRepr::Variable(ID)
}
}
impl<const BITS: u64> ToAst for Const<BITS> {
fn to_ast(&self) -> ASTRepr<f64> {
ASTRepr::Constant(f64::from_bits(BITS))
}
}
impl<L, R> ToAst for Add<L, R>
where
L: ToAst + MathExpr,
R: ToAst + MathExpr,
{
fn to_ast(&self) -> ASTRepr<f64> {
ASTRepr::Add(Box::new(self.left.to_ast()), Box::new(self.right.to_ast()))
}
}
impl<L, R> ToAst for Mul<L, R>
where
L: ToAst + MathExpr,
R: ToAst + MathExpr,
{
fn to_ast(&self) -> ASTRepr<f64> {
ASTRepr::Mul(Box::new(self.left.to_ast()), Box::new(self.right.to_ast()))
}
}
impl<T> ToAst for Sin<T>
where
T: ToAst + MathExpr,
{
fn to_ast(&self) -> ASTRepr<f64> {
ASTRepr::Sin(Box::new(self.inner.to_ast()))
}
}
impl<T> ToAst for Cos<T>
where
T: ToAst + MathExpr,
{
fn to_ast(&self) -> ASTRepr<f64> {
ASTRepr::Cos(Box::new(self.inner.to_ast()))
}
}
impl<T> ToAst for Exp<T>
where
T: ToAst + MathExpr,
{
fn to_ast(&self) -> ASTRepr<f64> {
ASTRepr::Exp(Box::new(self.inner.to_ast()))
}
}
impl<T> ToAst for Ln<T>
where
T: ToAst + MathExpr,
{
fn to_ast(&self) -> ASTRepr<f64> {
ASTRepr::Ln(Box::new(self.inner.to_ast()))
}
}
impl<B, E> ToAst for Pow<B, E>
where
B: ToAst + MathExpr,
E: ToAst + MathExpr,
{
fn to_ast(&self) -> ASTRepr<f64> {
ASTRepr::Pow(
Box::new(self.base.to_ast()),
Box::new(self.exponent.to_ast()),
)
}
}
impl<L, R> ToAst for Sub<L, R>
where
L: ToAst + MathExpr,
R: ToAst + MathExpr,
{
fn to_ast(&self) -> ASTRepr<f64> {
ASTRepr::Sub(Box::new(self.left.to_ast()), Box::new(self.right.to_ast()))
}
}
#[must_use]
pub fn equality_saturation(ast: &ASTRepr<f64>, max_iterations: usize) -> ASTRepr<f64> {
let mut current = ast.clone();
let mut iteration = 0;
while iteration < max_iterations {
let next = apply_all_optimizations(¤t);
if ast_equal(¤t, &next) {
break;
}
current = next;
iteration += 1;
}
current
}
fn apply_all_optimizations(ast: &ASTRepr<f64>) -> ASTRepr<f64> {
let ast_with_optimized_children = match ast {
ASTRepr::Add(left, right) => ASTRepr::Add(
Box::new(apply_all_optimizations(left)),
Box::new(apply_all_optimizations(right)),
),
ASTRepr::Sub(left, right) => ASTRepr::Sub(
Box::new(apply_all_optimizations(left)),
Box::new(apply_all_optimizations(right)),
),
ASTRepr::Mul(left, right) => ASTRepr::Mul(
Box::new(apply_all_optimizations(left)),
Box::new(apply_all_optimizations(right)),
),
ASTRepr::Ln(inner) => ASTRepr::Ln(Box::new(apply_all_optimizations(inner))),
ASTRepr::Exp(inner) => ASTRepr::Exp(Box::new(apply_all_optimizations(inner))),
ASTRepr::Sin(inner) => ASTRepr::Sin(Box::new(apply_all_optimizations(inner))),
ASTRepr::Cos(inner) => ASTRepr::Cos(Box::new(apply_all_optimizations(inner))),
ASTRepr::Pow(base, exp) => ASTRepr::Pow(
Box::new(apply_all_optimizations(base)),
Box::new(apply_all_optimizations(exp)),
),
_ => ast.clone(),
};
apply_simple_optimizations(&ast_with_optimized_children).unwrap_or(ast_with_optimized_children)
}
fn ast_equal(a: &ASTRepr<f64>, b: &ASTRepr<f64>) -> bool {
match (a, b) {
(ASTRepr::Constant(a), ASTRepr::Constant(b)) => (a - b).abs() < 1e-10,
(ASTRepr::Variable(a), ASTRepr::Variable(b)) => a == b,
(ASTRepr::Add(a1, a2), ASTRepr::Add(b1, b2)) => ast_equal(a1, b1) && ast_equal(a2, b2),
(ASTRepr::Sub(a1, a2), ASTRepr::Sub(b1, b2)) => ast_equal(a1, b1) && ast_equal(a2, b2),
(ASTRepr::Mul(a1, a2), ASTRepr::Mul(b1, b2)) => ast_equal(a1, b1) && ast_equal(a2, b2),
(ASTRepr::Ln(a), ASTRepr::Ln(b)) => ast_equal(a, b),
(ASTRepr::Exp(a), ASTRepr::Exp(b)) => ast_equal(a, b),
(ASTRepr::Sin(a), ASTRepr::Sin(b)) => ast_equal(a, b),
(ASTRepr::Cos(a), ASTRepr::Cos(b)) => ast_equal(a, b),
(ASTRepr::Pow(a1, a2), ASTRepr::Pow(b1, b2)) => ast_equal(a1, b1) && ast_equal(a2, b2),
_ => false,
}
}
#[must_use]
pub fn eval_ast(ast: &ASTRepr<f64>, vars: &[f64]) -> f64 {
match ast {
ASTRepr::Constant(c) => *c,
ASTRepr::Variable(idx) => vars.get(*idx).copied().unwrap_or(0.0),
ASTRepr::Add(left, right) => eval_ast(left, vars) + eval_ast(right, vars),
ASTRepr::Sub(left, right) => eval_ast(left, vars) - eval_ast(right, vars),
ASTRepr::Mul(left, right) => eval_ast(left, vars) * eval_ast(right, vars),
ASTRepr::Div(left, right) => eval_ast(left, vars) / eval_ast(right, vars),
ASTRepr::Pow(base, exp) => eval_ast(base, vars).powf(eval_ast(exp, vars)),
ASTRepr::Sin(inner) => eval_ast(inner, vars).sin(),
ASTRepr::Cos(inner) => eval_ast(inner, vars).cos(),
ASTRepr::Exp(inner) => eval_ast(inner, vars).exp(),
ASTRepr::Ln(inner) => eval_ast(inner, vars).ln(),
ASTRepr::Sqrt(inner) => eval_ast(inner, vars).sqrt(),
ASTRepr::Neg(inner) => -eval_ast(inner, vars),
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::compile_time::{optimize_compile_time, var};
#[test]
fn test_direct_code_generation() {
let x = var::<0>();
let expr = x.sin();
let ast = expr.to_ast();
let optimized = equality_saturation(&ast, 10);
let code = generate_direct_code(&optimized, &["x"]);
assert_eq!(code, "x.sin()");
}
#[test]
fn test_optimization_ln_exp() {
let x = var::<0>();
let expr = x.exp().ln();
let ast = expr.to_ast();
let optimized = equality_saturation(&ast, 10);
assert!(matches!(optimized, ASTRepr::Variable(0)));
}
#[test]
fn test_macro_usage() {
let x = 2.5_f64;
let y = 1.0_f64;
let result = optimize_compile_time!(var::<0>().add(var::<1>()), [x, y]);
let expected = x + y;
assert!((result - expected).abs() < 1e-10);
}
#[test]
fn test_identity_optimization() {
let x = 2.71_f64;
let result = optimize_compile_time!(var::<0>().add(constant(0.0)), [x]);
assert!((result - x).abs() < 1e-10);
}
}