use crate::error::Result;
use crate::final_tagless::{ASTRepr, NumericType, VariableRegistry};
use crate::interval_domain::{IntervalDomain, IntervalDomainAnalyzer};
use num_traits::{Float, Zero};
use ordered_float::OrderedFloat;
use std::collections::HashMap;
/// Reference to a variable appearing in an ANF expression.
///
/// The derived `Ord` sorts every `User` reference before every `Bound`
/// reference, because derived ordering follows variant declaration order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum VarRef {
    /// A variable identified by the index used to look it up in a
    /// `VariableRegistry`.
    User(usize),
    /// A let-bound temporary identified by a numeric id (rendered as `t{id}`).
    Bound(u32),
}
impl VarRef {
#[must_use]
pub fn to_rust_code(&self, registry: &VariableRegistry) -> String {
match self {
VarRef::User(idx) => registry.debug_name(*idx),
VarRef::Bound(id) => format!("t{id}"),
}
}
#[must_use]
pub fn debug_name(&self, registry: &VariableRegistry) -> String {
match self {
VarRef::User(idx) => {
format!("{}({})", registry.debug_name(*idx), idx)
}
VarRef::Bound(id) => format!("t{id}"),
}
}
#[must_use]
pub fn is_user(&self) -> bool {
matches!(self, VarRef::User(_))
}
#[must_use]
pub fn is_generated(&self) -> bool {
matches!(self, VarRef::Bound(_))
}
}
/// Generator of fresh [`VarRef::Bound`] temporaries with sequential ids.
#[derive(Debug, Clone)]
pub struct ANFVarGen {
    /// Id that the next call to `fresh` will hand out.
    next_temp_id: u32,
}
impl ANFVarGen {
#[must_use]
pub fn new() -> Self {
Self { next_temp_id: 0 }
}
pub fn fresh(&mut self) -> VarRef {
let var = VarRef::Bound(self.next_temp_id);
self.next_temp_id += 1;
var
}
#[must_use]
pub fn user_var(&self, index: usize) -> VarRef {
VarRef::User(index)
}
#[must_use]
pub fn generated_count(&self) -> u32 {
self.next_temp_id
}
}
impl Default for ANFVarGen {
fn default() -> Self {
Self::new()
}
}
/// An atomic ANF operand: something that needs no further computation.
#[derive(Debug, Clone, PartialEq)]
pub enum ANFAtom<T> {
    /// A literal value.
    Constant(T),
    /// A reference to a user variable or a let-bound temporary.
    Variable(VarRef),
}
impl<T: NumericType> ANFAtom<T> {
    /// Returns `true` when this atom is a literal constant.
    pub fn is_constant(&self) -> bool {
        self.as_constant().is_some()
    }

    /// Returns `true` when this atom is a variable reference.
    pub fn is_variable(&self) -> bool {
        self.as_variable().is_some()
    }

    /// Borrows the constant value, or `None` for a variable atom.
    pub fn as_constant(&self) -> Option<&T> {
        if let Self::Constant(value) = self {
            Some(value)
        } else {
            None
        }
    }

    /// Copies out the variable reference, or `None` for a constant atom.
    pub fn as_variable(&self) -> Option<VarRef> {
        if let Self::Variable(var_ref) = self {
            Some(*var_ref)
        } else {
            None
        }
    }
}
/// A single primitive operation whose operands are all atoms.
///
/// The first five variants are binary (left operand, right operand); the
/// remaining ones are unary.
#[derive(Debug, Clone, PartialEq)]
pub enum ANFComputation<T> {
    /// `left + right`
    Add(ANFAtom<T>, ANFAtom<T>),
    /// `left - right`
    Sub(ANFAtom<T>, ANFAtom<T>),
    /// `left * right`
    Mul(ANFAtom<T>, ANFAtom<T>),
    /// `left / right`
    Div(ANFAtom<T>, ANFAtom<T>),
    /// `base` raised to `exponent`
    Pow(ANFAtom<T>, ANFAtom<T>),
    /// Arithmetic negation.
    Neg(ANFAtom<T>),
    /// Natural logarithm.
    Ln(ANFAtom<T>),
    /// Exponential function.
    Exp(ANFAtom<T>),
    /// Sine.
    Sin(ANFAtom<T>),
    /// Cosine.
    Cos(ANFAtom<T>),
    /// Square root.
    Sqrt(ANFAtom<T>),
}
impl<T: NumericType> ANFComputation<T> {
    /// Lists the operands of this computation in source order: two entries
    /// for binary operations, one for unary operations.
    pub fn operands(&self) -> Vec<&ANFAtom<T>> {
        match self {
            Self::Neg(operand)
            | Self::Ln(operand)
            | Self::Exp(operand)
            | Self::Sin(operand)
            | Self::Cos(operand)
            | Self::Sqrt(operand) => vec![operand],
            Self::Add(lhs, rhs)
            | Self::Sub(lhs, rhs)
            | Self::Mul(lhs, rhs)
            | Self::Div(lhs, rhs)
            | Self::Pow(lhs, rhs) => vec![lhs, rhs],
        }
    }

    /// Returns `true` when every operand is a constant atom.
    pub fn is_constant_computation(&self) -> bool {
        self.operands()
            .into_iter()
            .all(|operand| operand.is_constant())
    }
}
/// An expression in A-normal form: either a bare atom, or a let-binding of
/// one primitive computation followed by a body expression.
#[derive(Debug, Clone, PartialEq)]
pub enum ANFExpr<T> {
    /// The final value of the expression.
    Atom(ANFAtom<T>),
    /// `let <var> = <computation> in <body>`.
    Let(VarRef, ANFComputation<T>, Box<ANFExpr<T>>),
}
impl<T> ANFExpr<T>
where
T: NumericType
+ Float
+ Copy
+ std::ops::Add<Output = T>
+ std::ops::Sub<Output = T>
+ std::ops::Mul<Output = T>
+ std::ops::Div<Output = T>
+ std::ops::Neg<Output = T>
+ Zero,
{
pub fn constant(value: T) -> Self {
ANFExpr::Atom(ANFAtom::Constant(value))
}
#[must_use]
pub fn variable(var_ref: VarRef) -> Self {
ANFExpr::Atom(ANFAtom::Variable(var_ref))
}
pub fn let_binding(var_ref: VarRef, computation: ANFComputation<T>, body: ANFExpr<T>) -> Self {
ANFExpr::Let(var_ref, computation, Box::new(body))
}
pub fn is_atom(&self) -> bool {
matches!(self, ANFExpr::Atom(_))
}
pub fn is_let(&self) -> bool {
matches!(self, ANFExpr::Let(_, _, _))
}
pub fn let_count(&self) -> usize {
match self {
ANFExpr::Atom(_) => 0,
ANFExpr::Let(_, _, body) => 1 + body.let_count(),
}
}
pub fn used_variables(&self) -> Vec<VarRef> {
let mut vars = Vec::new();
self.collect_variables(&mut vars);
vars.sort_unstable();
vars.dedup();
vars
}
fn collect_variables(&self, vars: &mut Vec<VarRef>) {
match self {
ANFExpr::Atom(ANFAtom::Variable(var)) => vars.push(*var),
ANFExpr::Atom(ANFAtom::Constant(_)) => {}
ANFExpr::Let(_, comp, body) => {
for operand in comp.operands() {
if let ANFAtom::Variable(var) = operand {
vars.push(*var);
}
}
body.collect_variables(vars);
}
}
}
pub fn eval(&self, variables: &HashMap<usize, T>) -> T {
self.eval_with_bound_vars(variables, &HashMap::new())
}
pub fn eval_domain_aware(
&self,
variables: &HashMap<usize, T>,
domain_analyzer: &IntervalDomainAnalyzer<T>,
) -> T
where
T: PartialOrd + From<f64>,
{
self.eval_with_bound_vars_domain_aware(variables, &HashMap::new(), domain_analyzer)
}
fn eval_with_bound_vars(
&self,
user_vars: &HashMap<usize, T>,
bound_vars: &HashMap<u32, T>,
) -> T {
match self {
ANFExpr::Atom(atom) => match atom {
ANFAtom::Constant(value) => *value,
ANFAtom::Variable(var_ref) => match var_ref {
VarRef::User(idx) => user_vars.get(idx).copied().unwrap_or_else(T::zero),
VarRef::Bound(id) => bound_vars.get(id).copied().unwrap_or_else(T::zero),
},
},
ANFExpr::Let(var_ref, computation, body) => {
let comp_result = match computation {
ANFComputation::Add(a, b) => {
self.eval_atom_with_bound(a, user_vars, bound_vars)
+ self.eval_atom_with_bound(b, user_vars, bound_vars)
}
ANFComputation::Sub(a, b) => {
self.eval_atom_with_bound(a, user_vars, bound_vars)
- self.eval_atom_with_bound(b, user_vars, bound_vars)
}
ANFComputation::Mul(a, b) => {
self.eval_atom_with_bound(a, user_vars, bound_vars)
* self.eval_atom_with_bound(b, user_vars, bound_vars)
}
ANFComputation::Div(a, b) => {
self.eval_atom_with_bound(a, user_vars, bound_vars)
/ self.eval_atom_with_bound(b, user_vars, bound_vars)
}
ANFComputation::Pow(a, b) => {
let base = self.eval_atom_with_bound(a, user_vars, bound_vars);
let exp = self.eval_atom_with_bound(b, user_vars, bound_vars);
Self::safe_powf(base, exp)
}
ANFComputation::Neg(a) => -self.eval_atom_with_bound(a, user_vars, bound_vars),
ANFComputation::Ln(a) => {
self.eval_atom_with_bound(a, user_vars, bound_vars).ln()
}
ANFComputation::Exp(a) => {
self.eval_atom_with_bound(a, user_vars, bound_vars).exp()
}
ANFComputation::Sin(a) => {
self.eval_atom_with_bound(a, user_vars, bound_vars).sin()
}
ANFComputation::Cos(a) => {
self.eval_atom_with_bound(a, user_vars, bound_vars).cos()
}
ANFComputation::Sqrt(a) => {
self.eval_atom_with_bound(a, user_vars, bound_vars).sqrt()
}
};
match var_ref {
VarRef::Bound(id) => {
let mut extended_bound_vars = bound_vars.clone();
extended_bound_vars.insert(*id, comp_result);
body.eval_with_bound_vars(user_vars, &extended_bound_vars)
}
VarRef::User(_) => {
body.eval_with_bound_vars(user_vars, bound_vars)
}
}
}
}
}
fn eval_with_bound_vars_domain_aware(
&self,
user_vars: &HashMap<usize, T>,
bound_vars: &HashMap<u32, T>,
domain_analyzer: &IntervalDomainAnalyzer<T>,
) -> T
where
T: PartialOrd + From<f64>,
{
match self {
ANFExpr::Atom(atom) => match atom {
ANFAtom::Constant(value) => *value,
ANFAtom::Variable(var_ref) => match var_ref {
VarRef::User(idx) => user_vars.get(idx).copied().unwrap_or_else(T::zero),
VarRef::Bound(id) => bound_vars.get(id).copied().unwrap_or_else(T::zero),
},
},
ANFExpr::Let(var_ref, computation, body) => {
let comp_result = match computation {
ANFComputation::Add(a, b) => {
self.eval_atom_with_bound(a, user_vars, bound_vars)
+ self.eval_atom_with_bound(b, user_vars, bound_vars)
}
ANFComputation::Sub(a, b) => {
self.eval_atom_with_bound(a, user_vars, bound_vars)
- self.eval_atom_with_bound(b, user_vars, bound_vars)
}
ANFComputation::Mul(a, b) => {
self.eval_atom_with_bound(a, user_vars, bound_vars)
* self.eval_atom_with_bound(b, user_vars, bound_vars)
}
ANFComputation::Div(a, b) => {
self.eval_atom_with_bound(a, user_vars, bound_vars)
/ self.eval_atom_with_bound(b, user_vars, bound_vars)
}
ANFComputation::Pow(a, b) => {
let base = self.eval_atom_with_bound(a, user_vars, bound_vars);
let exp = self.eval_atom_with_bound(b, user_vars, bound_vars);
Self::domain_aware_powf(base, exp, domain_analyzer)
}
ANFComputation::Neg(a) => -self.eval_atom_with_bound(a, user_vars, bound_vars),
ANFComputation::Ln(a) => {
self.eval_atom_with_bound(a, user_vars, bound_vars).ln()
}
ANFComputation::Exp(a) => {
self.eval_atom_with_bound(a, user_vars, bound_vars).exp()
}
ANFComputation::Sin(a) => {
self.eval_atom_with_bound(a, user_vars, bound_vars).sin()
}
ANFComputation::Cos(a) => {
self.eval_atom_with_bound(a, user_vars, bound_vars).cos()
}
ANFComputation::Sqrt(a) => {
self.eval_atom_with_bound(a, user_vars, bound_vars).sqrt()
}
};
match var_ref {
VarRef::Bound(id) => {
let mut extended_bound_vars = bound_vars.clone();
extended_bound_vars.insert(*id, comp_result);
body.eval_with_bound_vars_domain_aware(
user_vars,
&extended_bound_vars,
domain_analyzer,
)
}
VarRef::User(_) => {
body.eval_with_bound_vars_domain_aware(
user_vars,
bound_vars,
domain_analyzer,
)
}
}
}
}
}
fn safe_powf(base: T, exp: T) -> T {
let result = base.powf(exp);
if result.is_finite() || result.is_infinite() {
return result;
}
if result.is_nan() {
if base.is_finite() && base < T::zero() && exp.is_finite() {
return T::nan();
}
return result;
}
result
}
fn domain_aware_powf(base: T, exp: T, _domain_analyzer: &IntervalDomainAnalyzer<T>) -> T
where
T: PartialOrd + From<f64>,
{
let result = base.powf(exp);
if result.is_finite() || result.is_infinite() {
return result;
}
if result.is_nan() {
if base.is_finite() && base < T::zero() && exp.is_finite() {
return T::nan();
}
return result;
}
result
}
fn eval_atom_with_bound(
&self,
atom: &ANFAtom<T>,
user_vars: &HashMap<usize, T>,
bound_vars: &HashMap<u32, T>,
) -> T {
match atom {
ANFAtom::Constant(value) => *value,
ANFAtom::Variable(var_ref) => match var_ref {
VarRef::User(idx) => user_vars.get(idx).copied().unwrap_or_else(T::zero),
VarRef::Bound(id) => bound_vars.get(id).copied().unwrap_or_else(T::zero),
},
}
}
}
/// Hashable, `Eq`-comparable mirror of an `ASTRepr<f64>` tree.
///
/// Constants are wrapped in `OrderedFloat` so the whole structure can derive
/// `Eq` and `Hash` and be used as a `HashMap` key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum StructuralHash {
    /// A literal constant.
    Constant(OrderedFloat<f64>),
    /// A variable, identified by its index.
    Variable(usize),
    /// Addition of two sub-expressions.
    Add(Box<StructuralHash>, Box<StructuralHash>),
    /// Subtraction of two sub-expressions.
    Sub(Box<StructuralHash>, Box<StructuralHash>),
    /// Multiplication of two sub-expressions.
    Mul(Box<StructuralHash>, Box<StructuralHash>),
    /// Division of two sub-expressions.
    Div(Box<StructuralHash>, Box<StructuralHash>),
    /// Exponentiation (base, exponent).
    Pow(Box<StructuralHash>, Box<StructuralHash>),
    /// Negation.
    Neg(Box<StructuralHash>),
    /// Natural logarithm.
    Ln(Box<StructuralHash>),
    /// Exponential function.
    Exp(Box<StructuralHash>),
    /// Sine.
    Sin(Box<StructuralHash>),
    /// Cosine.
    Cos(Box<StructuralHash>),
    /// Square root.
    Sqrt(Box<StructuralHash>),
}
impl StructuralHash {
    /// Recursively mirrors `expr` into a [`StructuralHash`] tree, variant for
    /// variant.
    #[must_use]
    pub fn from_expr(expr: &ASTRepr<f64>) -> Self {
        /// Converts a child expression and boxes the result.
        fn boxed(child: &ASTRepr<f64>) -> Box<StructuralHash> {
            Box::new(StructuralHash::from_expr(child))
        }

        match expr {
            ASTRepr::Constant(val) => Self::Constant(OrderedFloat(*val)),
            ASTRepr::Variable(idx) => Self::Variable(*idx),
            ASTRepr::Add(lhs, rhs) => Self::Add(boxed(lhs), boxed(rhs)),
            ASTRepr::Sub(lhs, rhs) => Self::Sub(boxed(lhs), boxed(rhs)),
            ASTRepr::Mul(lhs, rhs) => Self::Mul(boxed(lhs), boxed(rhs)),
            ASTRepr::Div(lhs, rhs) => Self::Div(boxed(lhs), boxed(rhs)),
            ASTRepr::Pow(lhs, rhs) => Self::Pow(boxed(lhs), boxed(rhs)),
            ASTRepr::Neg(operand) => Self::Neg(boxed(operand)),
            ASTRepr::Ln(operand) => Self::Ln(boxed(operand)),
            ASTRepr::Exp(operand) => Self::Exp(boxed(operand)),
            ASTRepr::Sin(operand) => Self::Sin(boxed(operand)),
            ASTRepr::Cos(operand) => Self::Cos(boxed(operand)),
            ASTRepr::Sqrt(operand) => Self::Sqrt(boxed(operand)),
        }
    }
}
/// Converts `ASTRepr<f64>` expressions to ANF, reusing temporaries for
/// structurally identical sub-expressions.
#[derive(Debug)]
pub struct ANFConverter {
    /// Scope depth compared against the depth recorded in `expr_cache`
    /// entries when deciding whether a cached temporary may be reused.
    binding_depth: u32,
    /// Id that the next generated `VarRef::Bound` will receive.
    next_binding_id: u32,
    /// Maps a sub-expression's structure to
    /// `(binding_depth at insertion, bound variable, binding id)`.
    expr_cache: HashMap<StructuralHash, (u32, VarRef, u32)>,
}
impl ANFConverter {
    /// Creates a converter with an empty CSE cache; binding ids start at 0.
    #[must_use]
    pub fn new() -> Self {
        Self {
            binding_depth: 0,
            next_binding_id: 0,
            expr_cache: HashMap::new(),
        }
    }

    /// Converts `expr` to a self-contained ANF expression.
    ///
    /// The CSE cache is cleared first. A cached entry refers to a temporary
    /// bound inside a *previous* result; reusing it here would emit a
    /// reference to a variable this result never binds (which evaluates to
    /// zero). Binding ids are not reset, so ids stay unique across calls.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok`.
    pub fn convert(&mut self, expr: &ASTRepr<f64>) -> Result<ANFExpr<f64>> {
        self.expr_cache.clear();
        Ok(self.to_anf(expr))
    }

    /// Converts `expr`, returning either an atom or a let chain whose final
    /// atom is the variable holding the result.
    ///
    /// If a compound expression with the same structure was already
    /// converted, the cached temporary is returned as a bare atom and no new
    /// bindings are produced.
    fn to_anf(&mut self, expr: &ASTRepr<f64>) -> ANFExpr<f64> {
        let is_leaf = matches!(expr, ASTRepr::Constant(_) | ASTRepr::Variable(_));
        if !is_leaf {
            let key = StructuralHash::from_expr(expr);
            if let Some(&(cached_scope, cached_var, _)) = self.expr_cache.get(&key) {
                if cached_scope <= self.binding_depth {
                    return ANFExpr::Atom(ANFAtom::Variable(cached_var));
                }
                // Entry was recorded in a deeper scope than the current one: drop it.
                self.expr_cache.remove(&key);
            }
        }

        match expr {
            ASTRepr::Constant(value) => ANFExpr::Atom(ANFAtom::Constant(*value)),
            ASTRepr::Variable(index) => ANFExpr::Atom(ANFAtom::Variable(VarRef::User(*index))),
            ASTRepr::Add(left, right) => {
                self.convert_binary_op_with_cse(expr, left, right, ANFComputation::Add)
            }
            ASTRepr::Sub(left, right) => {
                self.convert_binary_op_with_cse(expr, left, right, ANFComputation::Sub)
            }
            ASTRepr::Mul(left, right) => {
                self.convert_binary_op_with_cse(expr, left, right, ANFComputation::Mul)
            }
            ASTRepr::Div(left, right) => {
                self.convert_binary_op_with_cse(expr, left, right, ANFComputation::Div)
            }
            ASTRepr::Pow(left, right) => {
                self.convert_binary_op_with_cse(expr, left, right, ANFComputation::Pow)
            }
            ASTRepr::Neg(inner) => self.convert_unary_op_with_cse(expr, inner, ANFComputation::Neg),
            ASTRepr::Ln(inner) => self.convert_unary_op_with_cse(expr, inner, ANFComputation::Ln),
            ASTRepr::Exp(inner) => self.convert_unary_op_with_cse(expr, inner, ANFComputation::Exp),
            ASTRepr::Sin(inner) => self.convert_unary_op_with_cse(expr, inner, ANFComputation::Sin),
            ASTRepr::Cos(inner) => self.convert_unary_op_with_cse(expr, inner, ANFComputation::Cos),
            ASTRepr::Sqrt(inner) => {
                self.convert_unary_op_with_cse(expr, inner, ANFComputation::Sqrt)
            }
        }
    }

    /// Returns the variable that holds the value of `expr`: the final atom of
    /// the chain if it is a variable, otherwise the innermost let-bound
    /// variable. `None` only for a bare constant atom.
    fn final_result_var(expr: &ANFExpr<f64>) -> Option<VarRef> {
        let mut innermost_binding = None;
        let mut current = expr;
        loop {
            match current {
                ANFExpr::Let(var, _, body) => {
                    innermost_binding = Some(*var);
                    current = &**body;
                }
                ANFExpr::Atom(ANFAtom::Variable(var)) => return Some(*var),
                ANFExpr::Atom(ANFAtom::Constant(_)) => return innermost_binding,
            }
        }
    }

    /// Allocates the next bound-variable id.
    fn fresh_binding(&mut self) -> (u32, VarRef) {
        let binding_id = self.next_binding_id;
        self.next_binding_id += 1;
        (binding_id, VarRef::Bound(binding_id))
    }

    /// Builds `let t = computation in t` for a fresh temporary `t`, without
    /// touching the CSE cache.
    fn bind(&mut self, computation: ANFComputation<f64>) -> (VarRef, ANFExpr<f64>) {
        let (_, var) = self.fresh_binding();
        let body = ANFExpr::Atom(ANFAtom::Variable(var));
        (var, ANFExpr::Let(var, computation, Box::new(body)))
    }

    /// Like `bind`, but also records the temporary in the CSE cache under the
    /// structure of `expr`, tagged with the current `binding_depth`.
    fn bind_cached(
        &mut self,
        expr: &ASTRepr<f64>,
        computation: ANFComputation<f64>,
    ) -> ANFExpr<f64> {
        let (binding_id, var) = self.fresh_binding();
        self.expr_cache.insert(
            StructuralHash::from_expr(expr),
            (self.binding_depth, var, binding_id),
        );
        let body = ANFExpr::Atom(ANFAtom::Variable(var));
        ANFExpr::Let(var, computation, Box::new(body))
    }

    /// Converts a binary operation.
    ///
    /// * `Pow` with a constant integer exponent in `[-64, 64]` other than 0
    ///   and 1 is expanded into multiplications (and a final division for
    ///   negative exponents) instead of a `powf` call.
    /// * If both operands reduce to constants the operation is folded at
    ///   conversion time; a `Pow` whose folded value is not finite is kept as
    ///   a runtime computation instead.
    /// * Otherwise a fresh temporary is bound, cached for CSE, and chained
    ///   after the operands' own bindings (left first, then right).
    fn convert_binary_op_with_cse(
        &mut self,
        expr: &ASTRepr<f64>,
        left: &ASTRepr<f64>,
        right: &ASTRepr<f64>,
        op_constructor: fn(ANFAtom<f64>, ANFAtom<f64>) -> ANFComputation<f64>,
    ) -> ANFExpr<f64> {
        // The constructor is an opaque fn pointer; probe it with dummy operands
        // to learn which operation it builds.
        let is_pow = matches!(
            op_constructor(ANFAtom::Constant(0.0), ANFAtom::Constant(0.0)),
            ANFComputation::Pow(_, _)
        );
        if is_pow {
            if let ASTRepr::Constant(exp_val) = right {
                let small_integer = exp_val.fract() == 0.0
                    && exp_val.abs() <= 64.0
                    && *exp_val != 0.0
                    && *exp_val != 1.0;
                if small_integer {
                    // Exact: the value is integral and within [-64, 64].
                    let exp_int = *exp_val as i32;
                    return self.convert_integer_power_to_anf(left, exp_int);
                }
            }
        }

        let (left_expr, left_atom) = self.to_anf_atom(left);
        let (right_expr, right_atom) = self.to_anf_atom(right);

        if let (ANFAtom::Constant(a), ANFAtom::Constant(b)) = (&left_atom, &right_atom) {
            let (a, b) = (*a, *b);
            let folded = match op_constructor(left_atom.clone(), right_atom.clone()) {
                ANFComputation::Add(..) => Some(a + b),
                ANFComputation::Sub(..) => Some(a - b),
                ANFComputation::Mul(..) => Some(a * b),
                ANFComputation::Div(..) => Some(a / b),
                ANFComputation::Pow(..) => {
                    let power = a.powf(b);
                    // `is_finite` is already false for NaN.
                    if power.is_finite() {
                        Some(power)
                    } else {
                        None
                    }
                }
                _ => unreachable!(),
            };
            if let Some(value) = folded {
                return ANFExpr::Atom(ANFAtom::Constant(value));
            }
        }

        let computation = op_constructor(left_atom, right_atom);
        let binding = self.bind_cached(expr, computation);
        self.chain_lets(left_expr, right_expr, binding)
    }

    /// Expands `base^exp` for an integer exponent into multiplications.
    ///
    /// Negative exponents compute the positive power first and then bind
    /// `1 / power`. The temporaries created here are not entered into the
    /// CSE cache.
    fn convert_integer_power_to_anf(&mut self, base: &ASTRepr<f64>, exp: i32) -> ANFExpr<f64> {
        let (base_expr, base_atom) = self.to_anf_atom(base);
        match exp {
            0 => ANFExpr::Atom(ANFAtom::Constant(1.0)),
            1 => base_expr.unwrap_or(ANFExpr::Atom(base_atom)),
            -1 => {
                let (_, reciprocal) =
                    self.bind(ANFComputation::Div(ANFAtom::Constant(1.0), base_atom));
                self.wrap_with_lets(base_expr, reciprocal)
            }
            2 => {
                let (_, square) = self.bind(ANFComputation::Mul(base_atom.clone(), base_atom));
                self.wrap_with_lets(base_expr, square)
            }
            exp if exp > 0 => {
                self.generate_binary_exponentiation_anf(base_expr, base_atom, exp as u32)
            }
            exp if exp < 0 => {
                let positive_power =
                    self.generate_binary_exponentiation_anf(base_expr, base_atom, (-exp) as u32);
                // The reciprocal must divide by the chain's *final* result, not
                // by the first intermediate square.
                let power_var = self.extract_result_var(&positive_power);
                let (_, reciprocal) = self.bind(ANFComputation::Div(
                    ANFAtom::Constant(1.0),
                    ANFAtom::Variable(power_var),
                ));
                self.wrap_with_lets(Some(positive_power), reciprocal)
            }
            _ => unreachable!(),
        }
    }

    /// Emits a let chain computing `base_atom^exp` by repeated squaring.
    ///
    /// `base_expr` holds the bindings (if any) that define `base_atom`. Each
    /// loop iteration squares the running power; whenever the corresponding
    /// bit of `exp` is set, the running power is multiplied into the
    /// accumulator. The returned chain may repeat identical bindings (same
    /// variable, same computation) because the accumulator and square chains
    /// share a prefix; this is redundant but does not change the value.
    fn generate_binary_exponentiation_anf(
        &mut self,
        base_expr: Option<ANFExpr<f64>>,
        base_atom: ANFAtom<f64>,
        exp: u32,
    ) -> ANFExpr<f64> {
        if exp == 0 {
            // Callers in this impl never pass 0; guard so the result is still correct.
            return ANFExpr::Atom(ANFAtom::Constant(1.0));
        }
        if exp == 1 {
            return base_expr.unwrap_or(ANFExpr::Atom(base_atom));
        }

        // Running square: base^(2^k) after k iterations.
        let mut result_atom = base_atom.clone();
        let mut result_expr = base_expr;
        let mut current_exp = exp;
        // Product of the squares selected by the set bits seen so far.
        let mut accumulated_expr: Option<ANFExpr<f64>> = None;
        let mut accumulated_atom: Option<ANFAtom<f64>> = None;

        if current_exp % 2 == 1 {
            accumulated_atom = Some(base_atom);
            accumulated_expr = result_expr.clone();
            current_exp -= 1;
        }

        while current_exp > 0 {
            let (square_var, square_expr) = self.bind(ANFComputation::Mul(
                result_atom.clone(),
                result_atom.clone(),
            ));
            result_expr = Some(self.wrap_with_lets(result_expr, square_expr));
            result_atom = ANFAtom::Variable(square_var);
            current_exp /= 2;

            if current_exp % 2 == 1 {
                if let Some(acc_atom) = accumulated_atom.take() {
                    let (mul_var, mul_expr) =
                        self.bind(ANFComputation::Mul(acc_atom, result_atom.clone()));
                    // Order: accumulator bindings, then square bindings, then the product.
                    let with_squares = self.wrap_with_lets(result_expr.clone(), mul_expr);
                    accumulated_expr =
                        Some(self.wrap_with_lets(accumulated_expr.take(), with_squares));
                    accumulated_atom = Some(ANFAtom::Variable(mul_var));
                } else {
                    accumulated_atom = Some(result_atom.clone());
                    accumulated_expr = result_expr.clone();
                }
                current_exp -= 1;
            }
        }

        match (accumulated_expr, accumulated_atom) {
            (Some(expr), _) => expr,
            (None, Some(atom)) => ANFExpr::Atom(atom),
            (None, None) => result_expr.unwrap_or(ANFExpr::Atom(result_atom)),
        }
    }

    /// Returns the variable holding the final value of a power chain.
    ///
    /// # Panics
    ///
    /// Panics if `expr` is a bare constant atom.
    fn extract_result_var(&self, expr: &ANFExpr<f64>) -> VarRef {
        Self::final_result_var(expr).expect("Expected variable result from power expression")
    }

    /// Converts a unary operation.
    ///
    /// A constant operand is folded at conversion time; otherwise a fresh
    /// temporary is bound, cached for CSE, and chained after the operand's
    /// own bindings.
    fn convert_unary_op_with_cse(
        &mut self,
        expr: &ASTRepr<f64>,
        inner: &ASTRepr<f64>,
        op_constructor: fn(ANFAtom<f64>) -> ANFComputation<f64>,
    ) -> ANFExpr<f64> {
        let (inner_expr, inner_atom) = self.to_anf_atom(inner);

        if let ANFAtom::Constant(a) = &inner_atom {
            let a = *a;
            let folded = match op_constructor(ANFAtom::Constant(a)) {
                ANFComputation::Neg(_) => -a,
                ANFComputation::Ln(_) => a.ln(),
                ANFComputation::Exp(_) => a.exp(),
                ANFComputation::Sin(_) => a.sin(),
                ANFComputation::Cos(_) => a.cos(),
                ANFComputation::Sqrt(_) => a.sqrt(),
                _ => unreachable!(),
            };
            return ANFExpr::Atom(ANFAtom::Constant(folded));
        }

        let computation = op_constructor(inner_atom);
        let binding = self.bind_cached(expr, computation);
        self.wrap_with_lets(inner_expr, binding)
    }

    /// Converts `expr` and splits the result into the bindings that must
    /// precede its use (if any) and the atom that denotes its value.
    ///
    /// For a let chain the atom is the chain's *final* result variable, so
    /// callers can use it directly as an operand.
    fn to_anf_atom(&mut self, expr: &ASTRepr<f64>) -> (Option<ANFExpr<f64>>, ANFAtom<f64>) {
        match expr {
            ASTRepr::Constant(value) => (None, ANFAtom::Constant(*value)),
            ASTRepr::Variable(index) => (None, ANFAtom::Variable(VarRef::User(*index))),
            _ => match self.to_anf(expr) {
                ANFExpr::Atom(atom) => (None, atom),
                chain => {
                    let result_var = Self::final_result_var(&chain)
                        .expect("a let chain always binds at least one variable");
                    (Some(chain), ANFAtom::Variable(result_var))
                }
            },
        }
    }

    /// Sequences up to two binding chains in front of `final_expr`:
    /// `first`'s bindings, then `second`'s, then `final_expr`.
    fn chain_lets<T: NumericType + Clone>(
        &self,
        first: Option<ANFExpr<T>>,
        second: Option<ANFExpr<T>>,
        final_expr: ANFExpr<T>,
    ) -> ANFExpr<T> {
        // `wrap_with_lets(None, x)` is `x`, so absent chains fall out naturally.
        let tail = self.wrap_with_lets(second, final_expr);
        self.wrap_with_lets(first, tail)
    }

    /// Replaces the final atom of `wrapper`'s let chain with `body`.
    ///
    /// With no wrapper, or a wrapper that is a bare atom, `body` is returned
    /// unchanged.
    fn wrap_with_lets<T: NumericType + Clone>(
        &self,
        wrapper: Option<ANFExpr<T>>,
        body: ANFExpr<T>,
    ) -> ANFExpr<T> {
        match wrapper {
            None | Some(ANFExpr::Atom(_)) => body,
            Some(ANFExpr::Let(var, computation, inner_body)) => ANFExpr::Let(
                var,
                computation,
                Box::new(self.wrap_with_lets(Some(*inner_body), body)),
            ),
        }
    }
}
impl Default for ANFConverter {
fn default() -> Self {
Self::new()
}
}
/// Converts `expr` to ANF using a fresh, single-use [`ANFConverter`].
///
/// # Errors
///
/// Propagates any error returned by [`ANFConverter::convert`].
pub fn convert_to_anf(expr: &ASTRepr<f64>) -> Result<ANFExpr<f64>> {
    ANFConverter::new().convert(expr)
}
/// ANF converter that additionally tracks interval domains for the
/// temporaries it generates.
#[derive(Debug)]
pub struct DomainAwareANFConverter {
    /// Underlying converter that produces the ANF expression.
    anf_converter: ANFConverter,
    /// Source of domains for user variables.
    domain_analyzer: IntervalDomainAnalyzer<f64>,
    /// Domains recorded for bound temporaries, keyed by their id.
    variable_domains: HashMap<u32, IntervalDomain<f64>>,
    /// Memoized results of `is_operation_safe`, keyed by a string built from
    /// the operation name and the debug rendering of its operands.
    safety_cache: HashMap<String, bool>,
}
impl DomainAwareANFConverter {
#[must_use]
pub fn new(domain_analyzer: IntervalDomainAnalyzer<f64>) -> Self {
Self {
anf_converter: ANFConverter::new(),
domain_analyzer,
variable_domains: HashMap::new(),
safety_cache: HashMap::new(),
}
}
pub fn convert(&mut self, expr: &ASTRepr<f64>) -> Result<ANFExpr<f64>> {
let anf = self.anf_converter.convert(expr)?;
self.propagate_domain_information(&anf, expr);
Ok(anf)
}
fn propagate_domain_information(&mut self, anf: &ANFExpr<f64>, original_expr: &ASTRepr<f64>) {
match anf {
ANFExpr::Atom(_) => {
}
ANFExpr::Let(var_ref, computation, body) => {
let domain = self.compute_computation_domain(computation);
if let VarRef::Bound(id) = var_ref {
self.variable_domains.insert(*id, domain);
}
self.propagate_domain_information(body, original_expr);
}
}
}
fn compute_computation_domain(&self, computation: &ANFComputation<f64>) -> IntervalDomain<f64> {
match computation {
ANFComputation::Add(left, right) => {
let left_domain = self.compute_atom_domain(left);
let right_domain = self.compute_atom_domain(right);
if left_domain == IntervalDomain::Bottom || right_domain == IntervalDomain::Bottom {
IntervalDomain::Bottom
} else {
IntervalDomain::Top
}
}
ANFComputation::Mul(left, right) => {
let left_domain = self.compute_atom_domain(left);
let right_domain = self.compute_atom_domain(right);
if left_domain.is_positive(0.0) && right_domain.is_positive(0.0) {
IntervalDomain::positive(0.0)
} else {
IntervalDomain::Top
}
}
ANFComputation::Exp(_) => {
IntervalDomain::positive(0.0)
}
_ => IntervalDomain::Top, }
}
fn compute_atom_domain(&self, atom: &ANFAtom<f64>) -> IntervalDomain<f64> {
match atom {
ANFAtom::Constant(val) => IntervalDomain::Constant(*val),
ANFAtom::Variable(var_ref) => self.get_variable_domain(*var_ref),
}
}
#[must_use]
pub fn get_variable_domain(&self, var_ref: VarRef) -> IntervalDomain<f64> {
match var_ref {
VarRef::User(idx) => self.domain_analyzer.get_variable_domain(idx),
VarRef::Bound(id) => self
.variable_domains
.get(&id)
.cloned()
.unwrap_or(IntervalDomain::Top),
}
}
pub fn set_generated_variable_domain(&mut self, var_id: u32, domain: IntervalDomain<f64>) {
self.variable_domains.insert(var_id, domain);
}
#[must_use]
pub fn domain_analyzer(&self) -> &IntervalDomainAnalyzer<f64> {
&self.domain_analyzer
}
pub fn domain_analyzer_mut(&mut self) -> &mut IntervalDomainAnalyzer<f64> {
&mut self.domain_analyzer
}
pub fn convert_with_domain_constraint(
&mut self,
expr: &ASTRepr<f64>,
expected_domain: &IntervalDomain<f64>,
) -> Result<ANFExpr<f64>> {
let anf = self.convert(expr)?;
let output_domain = self.analyze_expression_domain(expr);
if !self.is_domain_compatible(&output_domain, expected_domain) {
return Err(crate::error::DSLCompileError::DomainError(format!(
"Expression domain {output_domain:?} is not compatible with expected domain {expected_domain:?}"
)));
}
Ok(anf)
}
fn analyze_expression_domain(&self, expr: &ASTRepr<f64>) -> IntervalDomain<f64> {
match expr {
ASTRepr::Constant(val) => IntervalDomain::Constant(*val),
ASTRepr::Variable(idx) => self.domain_analyzer.get_variable_domain(*idx),
ASTRepr::Exp(_) => {
IntervalDomain::positive(0.0)
}
ASTRepr::Ln(inner) => {
let inner_domain = self.analyze_expression_domain(inner);
if inner_domain.is_positive(0.0) {
IntervalDomain::Top } else {
IntervalDomain::Bottom }
}
ASTRepr::Sqrt(inner) => {
let inner_domain = self.analyze_expression_domain(inner);
if inner_domain.is_non_negative(0.0) {
IntervalDomain::non_negative(0.0)
} else {
IntervalDomain::Bottom }
}
ASTRepr::Mul(left, right) => {
let left_domain = self.analyze_expression_domain(left);
let right_domain = self.analyze_expression_domain(right);
if left_domain.is_positive(0.0) && right_domain.is_positive(0.0) {
IntervalDomain::positive(0.0)
} else {
IntervalDomain::Top }
}
_ => IntervalDomain::Top, }
}
fn is_domain_compatible(
&self,
domain1: &IntervalDomain<f64>,
domain2: &IntervalDomain<f64>,
) -> bool {
match (domain1, domain2) {
(IntervalDomain::Bottom, _) => true,
(_, IntervalDomain::Bottom) => false,
(IntervalDomain::Top, IntervalDomain::Top) => true,
(IntervalDomain::Top, _) => false,
(_, IntervalDomain::Top) => true,
(IntervalDomain::Constant(a), IntervalDomain::Constant(b)) => a == b,
(IntervalDomain::Interval { .. }, IntervalDomain::Interval { .. }) => {
domain1 == domain2
}
(IntervalDomain::Constant(val), interval) => interval.contains(*val),
(interval, IntervalDomain::Constant(val)) => interval.contains(*val),
}
}
pub fn is_operation_safe(
&mut self,
operation: &str,
operands: &[&ASTRepr<f64>],
) -> Result<bool> {
let cache_key = format!("{operation}:{operands:?}");
if let Some(&cached_result) = self.safety_cache.get(&cache_key) {
return Ok(cached_result);
}
let result = match operation {
"ln" => {
if operands.len() != 1 {
return Ok(false);
}
let domain = self.analyze_expression_domain(operands[0]);
domain.is_positive(0.0)
}
"sqrt" => {
if operands.len() != 1 {
return Ok(false);
}
let domain = self.analyze_expression_domain(operands[0]);
domain.is_non_negative(0.0)
}
"div" => {
if operands.len() != 2 {
return Ok(false);
}
let denominator_domain = self.analyze_expression_domain(operands[1]);
!matches!(denominator_domain, IntervalDomain::Constant(x) if x == 0.0)
}
_ => true, };
self.safety_cache.insert(cache_key, result);
Ok(result)
}
pub fn clear_caches(&mut self) {
self.safety_cache.clear();
self.variable_domains.clear();
}
#[must_use]
pub fn get_optimization_stats(&self) -> DomainAwareOptimizationStats {
DomainAwareOptimizationStats {
generated_variables: self.variable_domains.len(),
safety_checks_cached: self.safety_cache.len(),
anf_let_bindings: 0, }
}
}
/// Counters describing the state of a `DomainAwareANFConverter`.
#[derive(Debug, Clone)]
pub struct DomainAwareOptimizationStats {
    /// Number of bound temporaries that have a recorded domain.
    pub generated_variables: usize,
    /// Number of memoized operation-safety results.
    pub safety_checks_cached: usize,
    /// Number of ANF let-bindings.
    pub anf_let_bindings: usize,
}
/// Renders ANF expressions as Rust source text.
#[derive(Debug)]
pub struct ANFCodeGen<'a> {
    /// Registry used to resolve user variable indices to names.
    registry: &'a VariableRegistry,
}
impl<'a> ANFCodeGen<'a> {
#[must_use]
pub fn new(registry: &'a VariableRegistry) -> Self {
Self { registry }
}
pub fn generate<T: NumericType + std::fmt::Display>(&self, expr: &ANFExpr<T>) -> String {
match expr {
ANFExpr::Atom(atom) => self.generate_atom(atom),
ANFExpr::Let(var, computation, body) => {
let var_name = var.to_rust_code(self.registry);
let comp_code = self.generate_computation(computation);
let body_code = self.generate(body);
format!("{{ let {var_name} = {comp_code};\n{body_code} }}")
}
}
}
fn generate_atom<T: NumericType + std::fmt::Display>(&self, atom: &ANFAtom<T>) -> String {
match atom {
ANFAtom::Constant(value) => value.to_string(),
ANFAtom::Variable(var) => var.to_rust_code(self.registry),
}
}
fn generate_computation<T: NumericType + std::fmt::Display>(
&self,
comp: &ANFComputation<T>,
) -> String {
match comp {
ANFComputation::Add(left, right) => {
format!(
"{} + {}",
self.generate_atom(left),
self.generate_atom(right)
)
}
ANFComputation::Sub(left, right) => {
format!(
"{} - {}",
self.generate_atom(left),
self.generate_atom(right)
)
}
ANFComputation::Mul(left, right) => {
format!(
"{} * {}",
self.generate_atom(left),
self.generate_atom(right)
)
}
ANFComputation::Div(left, right) => {
format!(
"{} / {}",
self.generate_atom(left),
self.generate_atom(right)
)
}
ANFComputation::Pow(left, right) => {
format!(
"{}.powf({})",
self.generate_atom(left),
self.generate_atom(right)
)
}
ANFComputation::Neg(operand) => {
format!("-{}", self.generate_atom(operand))
}
ANFComputation::Ln(operand) => {
format!("{}.ln()", self.generate_atom(operand))
}
ANFComputation::Exp(operand) => {
format!("{}.exp()", self.generate_atom(operand))
}
ANFComputation::Sin(operand) => {
format!("{}.sin()", self.generate_atom(operand))
}
ANFComputation::Cos(operand) => {
format!("{}.cos()", self.generate_atom(operand))
}
ANFComputation::Sqrt(operand) => {
format!("{}.sqrt()", self.generate_atom(operand))
}
}
}
/// Wraps the generated expression in a complete Rust function definition.
///
/// The function is called `name`, takes one `f64` parameter per registry
/// entry (named with the registry's debug name, in index order) and returns
/// `f64`. Every newline of the generated body is followed by the same
/// indentation used for its first line.
pub fn generate_function<T: NumericType + std::fmt::Display>(
    &self,
    name: &str,
    expr: &ANFExpr<T>,
) -> String {
    // Build "a: f64, b: f64, ..." directly instead of collecting and joining.
    let mut params = String::new();
    for index in 0..self.registry.len() {
        if index > 0 {
            params.push_str(", ");
        }
        params.push_str(&self.registry.debug_name(index));
        params.push_str(": f64");
    }

    let indented_body = self.generate(expr).replace('\n', "\n ");
    format!("fn {name}({params}) -> f64 {{\n {indented_body}\n}}")
}
}
/// Convenience wrapper: renders `expr` as a Rust expression string, naming
/// user variables through `registry`.
///
/// Equivalent to building an [`ANFCodeGen`] and calling its `generate` method.
pub fn generate_rust_code<T: NumericType + std::fmt::Display>(
    expr: &ANFExpr<T>,
    registry: &VariableRegistry,
) -> String {
    ANFCodeGen::new(registry).generate(expr)
}
#[cfg(test)]
mod tests {
use super::*;
// Fresh temporaries are numbered sequentially from zero; user variables keep
// the index they were given and never collide with temporaries.
#[test]
fn test_var_generic() {
    let mut var_gen = ANFVarGen::new();
    let first_temp = var_gen.fresh();
    let second_temp = var_gen.fresh();
    let user = var_gen.user_var(0);

    let expectations = [
        (first_temp, VarRef::Bound(0)),
        (second_temp, VarRef::Bound(1)),
        (user, VarRef::User(0)),
    ];
    for (actual, expected) in expectations {
        assert_eq!(actual, expected);
    }

    assert_ne!(first_temp, second_temp);
    assert_ne!(first_temp, user);
}
// Exercises the kind predicates and accessors of `ANFAtom` for both variants.
#[test]
fn test_anf_atom() {
    let constant = ANFAtom::<f64>::Constant(42.0);
    let variable = ANFAtom::<f64>::Variable(VarRef::Bound(0));

    // Constant atom.
    assert!(constant.is_constant());
    assert!(!constant.is_variable());
    assert_eq!(constant.as_constant(), Some(&42.0));

    // Variable atom.
    assert!(variable.is_variable());
    assert!(!variable.is_constant());
    assert_eq!(variable.as_variable(), Some(VarRef::Bound(0)));
}
// `operands()` must expose the atoms of a computation in positional order:
// two for a binary operation, one for a unary operation.
#[test]
fn test_anf_computation_operands() {
    let lhs: ANFAtom<f64> = ANFAtom::Constant(1.0);
    let rhs: ANFAtom<f64> = ANFAtom::Variable(VarRef::Bound(0));

    let sum = ANFComputation::Add(lhs.clone(), rhs.clone());
    let sum_operands = sum.operands();
    assert_eq!(sum_operands.len(), 2);
    assert_eq!(sum_operands[0], &lhs);
    assert_eq!(sum_operands[1], &rhs);

    let negation = ANFComputation::Neg(lhs.clone());
    let negation_operands = negation.operands();
    assert_eq!(negation_operands.len(), 1);
    assert_eq!(negation_operands[0], &lhs);
}
// The `ANFExpr` constructors must produce atoms for leaves and a single
// binding for `let_binding`.
#[test]
fn test_anf_expr_construction() {
    let temp = VarRef::Bound(0);

    let constant_leaf = ANFExpr::<f64>::constant(42.0);
    let variable_leaf = ANFExpr::<f64>::variable(temp);
    assert!(constant_leaf.is_atom());
    assert!(variable_leaf.is_atom());

    let increment = ANFComputation::Add(ANFAtom::Variable(temp), ANFAtom::Constant(1.0));
    let binding = ANFExpr::<f64>::let_binding(temp, increment, constant_leaf);
    assert!(binding.is_let());
    assert_eq!(binding.let_count(), 1);
}
// `used_variables()` must report both the user variable read by the bound
// computation and the temporary referenced by the body.
#[test]
fn test_variable_collection() {
    let temp = VarRef::Bound(0);
    let user = VarRef::User(0);

    let body: ANFExpr<f64> = ANFExpr::Atom(ANFAtom::Variable(temp));
    let increment = ANFComputation::Add(ANFAtom::Variable(user), ANFAtom::Constant(1.0));
    let expr: ANFExpr<f64> = ANFExpr::Let(temp, increment, Box::new(body));

    let collected = expr.used_variables();
    assert_eq!(collected.len(), 2);
    for expected in [temp, user] {
        assert!(collected.contains(&expected));
    }
}
// Converting `sin(x + 1) + cos(x + 1)` must succeed and introduce at least
// one let-binding.
#[test]
fn test_anf_conversion() {
    use crate::final_tagless::ASTEval;

    let shifted = ASTEval::add(ASTEval::var(0), ASTEval::constant(1.0));
    let sine = ASTEval::sin(shifted.clone());
    let cosine = ASTEval::cos(shifted);
    let expression = ASTEval::add(sine, cosine);

    let conversion = convert_to_anf(&expression);
    assert!(conversion.is_ok());

    let anf = conversion.unwrap();
    assert!(anf.let_count() > 0);
    assert!(anf.is_let());
}
// Converts the quadratic `x*x + 2*x + 1` to ANF and checks both rendering
// entry points: the bare expression and the named `f64` function.
#[test]
fn test_anf_code_generation() {
    use crate::final_tagless::{ASTEval, VariableRegistry};

    let mut registry = VariableRegistry::new();
    let x_idx = registry.register_variable();

    let x = ASTEval::var(x_idx);
    let two = ASTEval::constant(2.0);
    let one = ASTEval::constant(1.0);
    let x_squared = ASTEval::mul(x.clone(), x.clone());
    let two_x = ASTEval::mul(two, x);
    let sum1 = ASTEval::add(x_squared, two_x);
    let quadratic = ASTEval::add(sum1, one);

    let anf = convert_to_anf(&quadratic).unwrap();

    // Bare expression: must bind at least one temporary and mention the user variable.
    let code = generate_rust_code(&anf, &registry);
    assert!(code.contains("let t"));
    assert!(code.contains("var_0"));

    // Full function: signature pieces must be present.
    let codegen = ANFCodeGen::new(&registry);
    let function_code = codegen.generate_function("quadratic", &anf);
    assert!(function_code.contains("fn quadratic"));
    assert!(function_code.contains("var_0: f64"));
    assert!(function_code.contains("-> f64"));

    println!("Generated code:\n{code}");
    println!("Generated function:\n{function_code}");
}
// End-to-end check on `sin(x + y) + cos(x + y) + exp(x + y)`: conversion must
// introduce bindings and the generated function must expose both variables.
#[test]
fn test_anf_complete_pipeline() {
    use crate::final_tagless::{ASTEval, VariableRegistry};

    let mut registry = VariableRegistry::new();
    let x_idx = registry.register_variable();
    let y_idx = registry.register_variable();

    let x = ASTEval::var(x_idx);
    let y = ASTEval::var(y_idx);
    let x_plus_y = ASTEval::add(x, y);
    let sin_term = ASTEval::sin(x_plus_y.clone());
    let cos_term = ASTEval::cos(x_plus_y.clone());
    let exp_term = ASTEval::exp(x_plus_y);
    let sum1 = ASTEval::add(sin_term, cos_term);
    let final_expr = ASTEval::add(sum1, exp_term);

    let anf = convert_to_anf(&final_expr).unwrap();
    let codegen = ANFCodeGen::new(&registry);
    let function_code = codegen.generate_function("demo_function", &anf);

    println!("\n=== ANF Demo: Automatic Common Subexpression Elimination ===");
    println!("Original expression: sin(x + y) + cos(x + y) + exp(x + y)");
    println!("ANF introduces variables for shared subexpressions automatically\n");
    println!("Generated function:");
    println!("{function_code}");

    assert!(anf.let_count() >= 1);
    assert!(function_code.contains("fn demo_function"));
    assert!(function_code.contains("var_0: f64, var_1: f64"));
    assert!(function_code.contains("-> f64"));
}
// Diagnostic for common-subexpression elimination on `(x + 1) + (x + 1)`.
// Only the presence of at least one binding is asserted; whether the shared
// term was actually reused is printed for inspection.
#[test]
fn test_cse_simple_case() {
    use crate::final_tagless::{ASTEval, VariableRegistry};

    let mut registry = VariableRegistry::new();
    let x_idx = registry.register_variable();

    let x = ASTEval::var(x_idx);
    let one = ASTEval::constant(1.0);
    let x_plus_one_left = ASTEval::add(x.clone(), one.clone());
    let x_plus_one_right = ASTEval::add(x, one);
    let final_expr = ASTEval::add(x_plus_one_left, x_plus_one_right);

    let anf = convert_to_anf(&final_expr).unwrap();
    let codegen = ANFCodeGen::new(&registry);
    let function_code = codegen.generate_function("cse_test", &anf);

    println!("\n=== CSE Test: (x + 1) + (x + 1) ===");
    println!("Generated function:");
    println!("{function_code}");

    // The generated code names the first registered variable `var_0`, so the
    // shared term has to be searched for under that name; a literal "x + 1"
    // can never occur in the output.
    let code_contains_reuse = function_code.matches("var_0 + 1").count() == 1;
    println!(
        "CSE working correctly: {}",
        if code_contains_reuse {
            "✅ YES"
        } else {
            "❌ NO"
        }
    );

    assert!(anf.let_count() > 0);
}
// Smallest CSE case: `x + x` needs exactly one binding, for the sum itself.
#[test]
fn test_cse_debug() {
    use crate::final_tagless::{ASTEval, VariableRegistry};

    let mut registry = VariableRegistry::new();
    let x_idx = registry.register_variable();

    let x = ASTEval::var(x_idx);
    let expr = ASTEval::add(x.clone(), x.clone());

    let anf = convert_to_anf(&expr).unwrap();
    let codegen = ANFCodeGen::new(&registry);
    let function_code = codegen.generate_function("debug_test", &anf);

    println!("\n=== CSE Debug: x + x ===");
    println!("Generated function:");
    println!("{function_code}");

    // `assert_eq!` reports the actual count on failure.
    assert_eq!(anf.let_count(), 1);
}
// Diagnostic-only test for `(x + 1) + (x + 1)`: it prints the generated
// function, the binding count and the ANF structure, and asserts nothing
// beyond conversion and code generation not panicking.
#[test]
fn test_cse_failing_case() {
    use crate::final_tagless::{ASTEval, VariableRegistry};

    let mut registry = VariableRegistry::new();
    let x_idx = registry.register_variable();

    let x = ASTEval::var(x_idx);
    let one = ASTEval::constant(1.0);
    let x_plus_one_left = ASTEval::add(x.clone(), one.clone());
    let x_plus_one_right = ASTEval::add(x, one);
    let expr = ASTEval::add(x_plus_one_left, x_plus_one_right);

    let anf = convert_to_anf(&expr).unwrap();
    let codegen = ANFCodeGen::new(&registry);
    let function_code = codegen.generate_function("failing_case", &anf);

    println!("\n=== CSE Failing Case: (x + 1) + (x + 1) ===");
    println!("Generated function:");
    println!("{function_code}");
    println!("Let count: {}", anf.let_count());
    println!("ANF structure: {anf:#?}");
}
}