use crate::error::Result;
use crate::final_tagless::{ASTRepr, NumericType, VariableRegistry};
use num_traits::{Float, Zero};
use ordered_float::OrderedFloat;
use std::collections::HashMap;
/// Reference to a variable appearing in an ANF expression.
///
/// The derived `PartialOrd`/`Ord` follow declaration order, so every `User`
/// reference sorts before every `Bound` reference.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub enum VarRef {
    /// A user input variable, identified by its index in the variable registry.
    User(usize),
    /// A temporary introduced during ANF conversion; rendered as `t{id}`.
    Bound(u32),
}
impl VarRef {
    /// Returns the identifier this variable has in generated Rust source.
    ///
    /// User variables take their registered name, or `"unknown"` when the index
    /// is not registered; temporaries are spelled `t{id}`.
    #[must_use]
    pub fn to_rust_code(&self, registry: &VariableRegistry) -> String {
        match *self {
            VarRef::Bound(id) => format!("t{id}"),
            VarRef::User(idx) => {
                let name = registry.get_name(idx).unwrap_or("unknown");
                name.to_string()
            }
        }
    }

    /// Returns a diagnostic label: `name(index)` for user variables (with `?`
    /// standing in for an unregistered name) and `t{id}` for temporaries.
    #[must_use]
    pub fn debug_name(&self, registry: &VariableRegistry) -> String {
        match *self {
            VarRef::Bound(id) => format!("t{id}"),
            VarRef::User(idx) => {
                let name = registry.get_name(idx).unwrap_or("?");
                format!("{name}({idx})")
            }
        }
    }

    /// `true` for a user input variable.
    #[must_use]
    pub fn is_user(&self) -> bool {
        !self.is_generated()
    }

    /// `true` for a temporary created by ANF conversion.
    #[must_use]
    pub fn is_generated(&self) -> bool {
        match self {
            VarRef::Bound(_) => true,
            VarRef::User(_) => false,
        }
    }
}
/// Hands out fresh temporary (`Bound`) variables, numbered from zero.
#[derive(Debug, Clone)]
pub struct ANFVarGen {
    /// Id that the next call to [`ANFVarGen::fresh`] will use.
    next_temp_id: u32,
}

impl ANFVarGen {
    /// Creates a generator whose first fresh variable is `t0`.
    #[must_use]
    pub fn new() -> Self {
        ANFVarGen { next_temp_id: 0 }
    }

    /// Returns a new `Bound` variable and advances the internal counter.
    pub fn fresh(&mut self) -> VarRef {
        let id = self.next_temp_id;
        self.next_temp_id = id + 1;
        VarRef::Bound(id)
    }

    /// Wraps a user variable index; the generator's counter is not touched.
    #[must_use]
    pub fn user_var(&self, index: usize) -> VarRef {
        VarRef::User(index)
    }

    /// Number of fresh variables handed out so far.
    #[must_use]
    pub fn generated_count(&self) -> u32 {
        self.next_temp_id
    }
}

impl Default for ANFVarGen {
    fn default() -> Self {
        ANFVarGen::new()
    }
}
/// An atomic ANF operand: either a literal constant or a variable reference.
#[derive(Debug, Clone, PartialEq)]
pub enum ANFAtom<T> {
    /// A literal value.
    Constant(T),
    /// A reference to a user input or a let-bound temporary.
    Variable(VarRef),
}

impl<T: NumericType> ANFAtom<T> {
    /// `true` when this atom is a [`ANFAtom::Constant`].
    pub fn is_constant(&self) -> bool {
        self.as_constant().is_some()
    }

    /// `true` when this atom is a [`ANFAtom::Variable`].
    pub fn is_variable(&self) -> bool {
        self.as_variable().is_some()
    }

    /// Borrows the constant payload, if this atom is a constant.
    pub fn as_constant(&self) -> Option<&T> {
        if let ANFAtom::Constant(value) = self {
            Some(value)
        } else {
            None
        }
    }

    /// Copies out the variable reference, if this atom is a variable.
    pub fn as_variable(&self) -> Option<VarRef> {
        if let ANFAtom::Variable(var) = self {
            Some(*var)
        } else {
            None
        }
    }
}
/// A single primitive operation whose operands are all atoms.
#[derive(Debug, Clone, PartialEq)]
pub enum ANFComputation<T> {
    Add(ANFAtom<T>, ANFAtom<T>),
    Sub(ANFAtom<T>, ANFAtom<T>),
    Mul(ANFAtom<T>, ANFAtom<T>),
    Div(ANFAtom<T>, ANFAtom<T>),
    Pow(ANFAtom<T>, ANFAtom<T>),
    Neg(ANFAtom<T>),
    Ln(ANFAtom<T>),
    Exp(ANFAtom<T>),
    Sin(ANFAtom<T>),
    Cos(ANFAtom<T>),
    Sqrt(ANFAtom<T>),
}

impl<T: NumericType> ANFComputation<T> {
    /// Returns the operands in left-to-right order: two for binary operations,
    /// one for unary operations.
    pub fn operands(&self) -> Vec<&ANFAtom<T>> {
        match self {
            Self::Add(lhs, rhs)
            | Self::Sub(lhs, rhs)
            | Self::Mul(lhs, rhs)
            | Self::Div(lhs, rhs)
            | Self::Pow(lhs, rhs) => vec![lhs, rhs],
            Self::Neg(arg)
            | Self::Ln(arg)
            | Self::Exp(arg)
            | Self::Sin(arg)
            | Self::Cos(arg)
            | Self::Sqrt(arg) => vec![arg],
        }
    }

    /// `true` when no operand is a variable, i.e. the operation could be folded.
    pub fn is_constant_computation(&self) -> bool {
        !self.operands().iter().any(|atom| atom.is_variable())
    }
}
#[derive(Debug, Clone, PartialEq)]
pub enum ANFExpr<T> {
Atom(ANFAtom<T>),
Let(VarRef, ANFComputation<T>, Box<ANFExpr<T>>),
}
impl<T> ANFExpr<T>
where
T: NumericType
+ Float
+ Copy
+ std::ops::Add<Output = T>
+ std::ops::Sub<Output = T>
+ std::ops::Mul<Output = T>
+ std::ops::Div<Output = T>
+ std::ops::Neg<Output = T>
+ Zero,
{
pub fn constant(value: T) -> Self {
ANFExpr::Atom(ANFAtom::Constant(value))
}
#[must_use]
pub fn variable(var_ref: VarRef) -> Self {
ANFExpr::Atom(ANFAtom::Variable(var_ref))
}
pub fn let_binding(var_ref: VarRef, computation: ANFComputation<T>, body: ANFExpr<T>) -> Self {
ANFExpr::Let(var_ref, computation, Box::new(body))
}
pub fn is_atom(&self) -> bool {
matches!(self, ANFExpr::Atom(_))
}
pub fn is_let(&self) -> bool {
matches!(self, ANFExpr::Let(_, _, _))
}
pub fn let_count(&self) -> usize {
match self {
ANFExpr::Atom(_) => 0,
ANFExpr::Let(_, _, body) => 1 + body.let_count(),
}
}
pub fn used_variables(&self) -> Vec<VarRef> {
let mut vars = Vec::new();
self.collect_variables(&mut vars);
vars.sort_unstable();
vars.dedup();
vars
}
fn collect_variables(&self, vars: &mut Vec<VarRef>) {
match self {
ANFExpr::Atom(ANFAtom::Variable(var)) => vars.push(*var),
ANFExpr::Atom(ANFAtom::Constant(_)) => {}
ANFExpr::Let(_, comp, body) => {
for operand in comp.operands() {
if let ANFAtom::Variable(var) = operand {
vars.push(*var);
}
}
body.collect_variables(vars);
}
}
}
pub fn eval(&self, variables: &HashMap<usize, T>) -> T {
self.eval_with_bound_vars(variables, &HashMap::new())
}
fn eval_with_bound_vars(
&self,
user_vars: &HashMap<usize, T>,
bound_vars: &HashMap<u32, T>,
) -> T {
match self {
ANFExpr::Atom(atom) => match atom {
ANFAtom::Constant(value) => *value,
ANFAtom::Variable(var_ref) => match var_ref {
VarRef::User(idx) => user_vars.get(idx).copied().unwrap_or_else(T::zero),
VarRef::Bound(id) => bound_vars.get(id).copied().unwrap_or_else(T::zero),
},
},
ANFExpr::Let(var_ref, computation, body) => {
let comp_result = match computation {
ANFComputation::Add(a, b) => {
self.eval_atom_with_bound(a, user_vars, bound_vars)
+ self.eval_atom_with_bound(b, user_vars, bound_vars)
}
ANFComputation::Sub(a, b) => {
self.eval_atom_with_bound(a, user_vars, bound_vars)
- self.eval_atom_with_bound(b, user_vars, bound_vars)
}
ANFComputation::Mul(a, b) => {
self.eval_atom_with_bound(a, user_vars, bound_vars)
* self.eval_atom_with_bound(b, user_vars, bound_vars)
}
ANFComputation::Div(a, b) => {
self.eval_atom_with_bound(a, user_vars, bound_vars)
/ self.eval_atom_with_bound(b, user_vars, bound_vars)
}
ANFComputation::Pow(a, b) => self
.eval_atom_with_bound(a, user_vars, bound_vars)
.powf(self.eval_atom_with_bound(b, user_vars, bound_vars)),
ANFComputation::Neg(a) => -self.eval_atom_with_bound(a, user_vars, bound_vars),
ANFComputation::Ln(a) => {
self.eval_atom_with_bound(a, user_vars, bound_vars).ln()
}
ANFComputation::Exp(a) => {
self.eval_atom_with_bound(a, user_vars, bound_vars).exp()
}
ANFComputation::Sin(a) => {
self.eval_atom_with_bound(a, user_vars, bound_vars).sin()
}
ANFComputation::Cos(a) => {
self.eval_atom_with_bound(a, user_vars, bound_vars).cos()
}
ANFComputation::Sqrt(a) => {
self.eval_atom_with_bound(a, user_vars, bound_vars).sqrt()
}
};
match var_ref {
VarRef::Bound(id) => {
let mut extended_bound_vars = bound_vars.clone();
extended_bound_vars.insert(*id, comp_result);
body.eval_with_bound_vars(user_vars, &extended_bound_vars)
}
VarRef::User(_) => {
body.eval_with_bound_vars(user_vars, bound_vars)
}
}
}
}
}
fn eval_atom_with_bound(
&self,
atom: &ANFAtom<T>,
user_vars: &HashMap<usize, T>,
bound_vars: &HashMap<u32, T>,
) -> T {
match atom {
ANFAtom::Constant(value) => *value,
ANFAtom::Variable(var_ref) => match var_ref {
VarRef::User(idx) => user_vars.get(idx).copied().unwrap_or_else(T::zero),
VarRef::Bound(id) => bound_vars.get(id).copied().unwrap_or_else(T::zero),
},
}
}
}
/// Hashable structural fingerprint of an `ASTRepr<f64>` tree.
///
/// Two expressions with equal fingerprints have the same shape, the same
/// variable indices and equal constants, so the fingerprint can key a
/// common-subexpression cache. Constants are wrapped in `OrderedFloat` so they
/// can be hashed.
///
/// NOTE(review): `OrderedFloat` equality is believed to treat all NaNs as equal
/// and `0.0` as equal to `-0.0`, so such constants share a fingerprint — confirm
/// this is acceptable for CSE.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum StructuralHash {
    Constant(OrderedFloat<f64>),
    Variable(usize),
    Add(Box<StructuralHash>, Box<StructuralHash>),
    Sub(Box<StructuralHash>, Box<StructuralHash>),
    Mul(Box<StructuralHash>, Box<StructuralHash>),
    Div(Box<StructuralHash>, Box<StructuralHash>),
    Pow(Box<StructuralHash>, Box<StructuralHash>),
    Neg(Box<StructuralHash>),
    Ln(Box<StructuralHash>),
    Exp(Box<StructuralHash>),
    Sin(Box<StructuralHash>),
    Cos(Box<StructuralHash>),
    Sqrt(Box<StructuralHash>),
}

impl StructuralHash {
    /// Builds the fingerprint of `expr`, mirroring its tree node for node.
    #[must_use]
    pub fn from_expr(expr: &ASTRepr<f64>) -> Self {
        /// Fingerprints a child node and boxes it.
        fn sub(node: &ASTRepr<f64>) -> Box<StructuralHash> {
            Box::new(StructuralHash::from_expr(node))
        }

        match expr {
            ASTRepr::Constant(val) => Self::Constant(OrderedFloat(*val)),
            ASTRepr::Variable(idx) => Self::Variable(*idx),
            ASTRepr::Add(lhs, rhs) => Self::Add(sub(lhs), sub(rhs)),
            ASTRepr::Sub(lhs, rhs) => Self::Sub(sub(lhs), sub(rhs)),
            ASTRepr::Mul(lhs, rhs) => Self::Mul(sub(lhs), sub(rhs)),
            ASTRepr::Div(lhs, rhs) => Self::Div(sub(lhs), sub(rhs)),
            ASTRepr::Pow(lhs, rhs) => Self::Pow(sub(lhs), sub(rhs)),
            ASTRepr::Neg(arg) => Self::Neg(sub(arg)),
            ASTRepr::Ln(arg) => Self::Ln(sub(arg)),
            ASTRepr::Exp(arg) => Self::Exp(sub(arg)),
            ASTRepr::Sin(arg) => Self::Sin(sub(arg)),
            ASTRepr::Cos(arg) => Self::Cos(sub(arg)),
            ASTRepr::Sqrt(arg) => Self::Sqrt(sub(arg)),
        }
    }
}
/// Converts `ASTRepr<f64>` trees to ANF, binding repeated compound
/// subexpressions only once.
#[derive(Debug)]
pub struct ANFConverter {
    /// Scope depth compared with a cache entry's recorded depth before that
    /// entry may be reused.
    binding_depth: u32,
    /// Id of the next `Bound` temporary to allocate.
    next_binding_id: u32,
    /// Structural fingerprint -> (depth when cached, bound variable, binding id).
    expr_cache: HashMap<StructuralHash, (u32, VarRef, u32)>,
}
impl ANFConverter {
#[must_use]
pub fn new() -> Self {
Self {
binding_depth: 0,
next_binding_id: 0,
expr_cache: HashMap::new(),
}
}
pub fn convert(&mut self, expr: &ASTRepr<f64>) -> Result<ANFExpr<f64>> {
Ok(self.to_anf(expr))
}
fn to_anf(&mut self, expr: &ASTRepr<f64>) -> ANFExpr<f64> {
if let ASTRepr::Constant(value) = expr {
return ANFExpr::Atom(ANFAtom::Constant(*value));
}
if !matches!(expr, ASTRepr::Constant(_) | ASTRepr::Variable(_)) {
let structural_hash = StructuralHash::from_expr(expr);
if let Some((cached_scope, cached_var, _cached_binding_id)) =
self.expr_cache.get(&structural_hash)
{
if *cached_scope <= self.binding_depth {
return ANFExpr::Atom(ANFAtom::Variable(*cached_var));
}
self.expr_cache.remove(&structural_hash);
}
}
match expr {
ASTRepr::Constant(value) => ANFExpr::Atom(ANFAtom::Constant(*value)),
ASTRepr::Variable(index) => ANFExpr::Atom(ANFAtom::Variable(VarRef::User(*index))),
ASTRepr::Add(left, right) => {
self.convert_binary_op_with_cse(expr, left, right, ANFComputation::Add)
}
ASTRepr::Sub(left, right) => {
self.convert_binary_op_with_cse(expr, left, right, ANFComputation::Sub)
}
ASTRepr::Mul(left, right) => {
self.convert_binary_op_with_cse(expr, left, right, ANFComputation::Mul)
}
ASTRepr::Div(left, right) => {
self.convert_binary_op_with_cse(expr, left, right, ANFComputation::Div)
}
ASTRepr::Pow(left, right) => {
self.convert_binary_op_with_cse(expr, left, right, ANFComputation::Pow)
}
ASTRepr::Neg(inner) => self.convert_unary_op_with_cse(expr, inner, ANFComputation::Neg),
ASTRepr::Ln(inner) => self.convert_unary_op_with_cse(expr, inner, ANFComputation::Ln),
ASTRepr::Exp(inner) => self.convert_unary_op_with_cse(expr, inner, ANFComputation::Exp),
ASTRepr::Sin(inner) => self.convert_unary_op_with_cse(expr, inner, ANFComputation::Sin),
ASTRepr::Cos(inner) => self.convert_unary_op_with_cse(expr, inner, ANFComputation::Cos),
ASTRepr::Sqrt(inner) => {
self.convert_unary_op_with_cse(expr, inner, ANFComputation::Sqrt)
}
}
}
fn convert_binary_op_with_cse(
&mut self,
expr: &ASTRepr<f64>,
left: &ASTRepr<f64>,
right: &ASTRepr<f64>,
op_constructor: fn(ANFAtom<f64>, ANFAtom<f64>) -> ANFComputation<f64>,
) -> ANFExpr<f64> {
fn extract_final_var(expr: &ANFExpr<f64>) -> Option<VarRef> {
match expr {
ANFExpr::Let(var, _, body) => extract_final_var(body).or(Some(*var)),
ANFExpr::Atom(ANFAtom::Variable(var)) => Some(*var),
_ => None,
}
}
let (left_expr, left_atom_orig) = self.to_anf_atom(left);
let (right_expr, right_atom_orig) = self.to_anf_atom(right);
let left_atom = match &left_expr {
Some(e) => extract_final_var(e).map_or(left_atom_orig, ANFAtom::Variable),
None => left_atom_orig,
};
let right_atom = match &right_expr {
Some(e) => extract_final_var(e).map_or(right_atom_orig, ANFAtom::Variable),
None => right_atom_orig,
};
let computation = op_constructor(left_atom.clone(), right_atom.clone());
if left_atom.is_constant() && right_atom.is_constant() {
let result = match computation {
ANFComputation::Add(ANFAtom::Constant(a), ANFAtom::Constant(b)) => {
ANFAtom::Constant(a + b)
}
ANFComputation::Sub(ANFAtom::Constant(a), ANFAtom::Constant(b)) => {
ANFAtom::Constant(a - b)
}
ANFComputation::Mul(ANFAtom::Constant(a), ANFAtom::Constant(b)) => {
ANFAtom::Constant(a * b)
}
ANFComputation::Div(ANFAtom::Constant(a), ANFAtom::Constant(b)) => {
ANFAtom::Constant(a / b)
}
ANFComputation::Pow(ANFAtom::Constant(a), ANFAtom::Constant(b)) => {
ANFAtom::Constant(a.powf(b))
}
_ => unreachable!(),
};
return ANFExpr::Atom(result);
}
let binding_id = self.next_binding_id;
self.next_binding_id += 1;
let result_var = VarRef::Bound(binding_id);
let structural_hash = StructuralHash::from_expr(expr);
self.expr_cache.insert(
structural_hash,
(self.binding_depth, result_var, binding_id),
);
self.binding_depth += 1;
let body = ANFExpr::Atom(ANFAtom::Variable(result_var));
self.binding_depth -= 1;
self.chain_lets(
left_expr,
right_expr,
ANFExpr::Let(result_var, computation, Box::new(body)),
)
}
fn convert_unary_op_with_cse(
&mut self,
expr: &ASTRepr<f64>,
inner: &ASTRepr<f64>,
op_constructor: fn(ANFAtom<f64>) -> ANFComputation<f64>,
) -> ANFExpr<f64> {
fn extract_final_var(expr: &ANFExpr<f64>) -> Option<VarRef> {
match expr {
ANFExpr::Let(var, _, body) => extract_final_var(body).or(Some(*var)),
ANFExpr::Atom(ANFAtom::Variable(var)) => Some(*var),
_ => None,
}
}
let (inner_expr, inner_atom_orig) = self.to_anf_atom(inner);
let inner_atom = match &inner_expr {
Some(e) => extract_final_var(e).map_or(inner_atom_orig, ANFAtom::Variable),
None => inner_atom_orig,
};
let computation = op_constructor(inner_atom.clone());
if inner_atom.is_constant() {
let result = match computation {
ANFComputation::Neg(ANFAtom::Constant(a)) => ANFAtom::Constant(-a),
ANFComputation::Ln(ANFAtom::Constant(a)) => ANFAtom::Constant(a.ln()),
ANFComputation::Exp(ANFAtom::Constant(a)) => ANFAtom::Constant(a.exp()),
ANFComputation::Sin(ANFAtom::Constant(a)) => ANFAtom::Constant(a.sin()),
ANFComputation::Cos(ANFAtom::Constant(a)) => ANFAtom::Constant(a.cos()),
ANFComputation::Sqrt(ANFAtom::Constant(a)) => ANFAtom::Constant(a.sqrt()),
_ => unreachable!(),
};
return ANFExpr::Atom(result);
}
let binding_id = self.next_binding_id;
self.next_binding_id += 1;
let result_var = VarRef::Bound(binding_id);
let structural_hash = StructuralHash::from_expr(expr);
self.expr_cache.insert(
structural_hash,
(self.binding_depth, result_var, binding_id),
);
self.binding_depth += 1;
let body = ANFExpr::Atom(ANFAtom::Variable(result_var));
self.binding_depth -= 1;
self.wrap_with_lets(
inner_expr,
ANFExpr::Let(result_var, computation, Box::new(body)),
)
}
fn to_anf_atom(&mut self, expr: &ASTRepr<f64>) -> (Option<ANFExpr<f64>>, ANFAtom<f64>) {
match expr {
ASTRepr::Constant(value) => (None, ANFAtom::Constant(*value)),
ASTRepr::Variable(index) => (None, ANFAtom::Variable(VarRef::User(*index))),
_ => {
let anf_expr = self.to_anf(expr);
match anf_expr {
ANFExpr::Atom(atom) => (None, atom),
ANFExpr::Let(var, computation, body) => (
Some(ANFExpr::Let(var, computation, body)),
ANFAtom::Variable(var),
),
}
}
}
}
fn chain_lets<T: NumericType + Clone>(
&self,
first: Option<ANFExpr<T>>,
second: Option<ANFExpr<T>>,
final_expr: ANFExpr<T>,
) -> ANFExpr<T> {
match (first, second) {
(None, None) => final_expr,
(Some(first_expr), None) => self.wrap_with_lets(Some(first_expr), final_expr),
(None, Some(second_expr)) => self.wrap_with_lets(Some(second_expr), final_expr),
(Some(first_expr), Some(second_expr)) => {
let combined = self.wrap_with_lets(Some(second_expr), final_expr);
self.wrap_with_lets(Some(first_expr), combined)
}
}
}
fn wrap_with_lets<T: NumericType + Clone>(
&self,
wrapper: Option<ANFExpr<T>>,
body: ANFExpr<T>,
) -> ANFExpr<T> {
match wrapper {
None => body,
Some(ANFExpr::Let(var, computation, inner_body)) => ANFExpr::Let(
var,
computation,
Box::new(self.wrap_with_lets(Some(*inner_body), body)),
),
Some(ANFExpr::Atom(_)) => body, }
}
}
impl Default for ANFConverter {
fn default() -> Self {
Self::new()
}
}
/// Converts `expr` to ANF using a freshly created [`ANFConverter`].
///
/// # Errors
/// Returns whatever error [`ANFConverter::convert`] reports.
pub fn convert_to_anf(expr: &ASTRepr<f64>) -> Result<ANFExpr<f64>> {
    ANFConverter::new().convert(expr)
}
/// Renders ANF expressions as Rust source text.
#[derive(Debug)]
pub struct ANFCodeGen<'a> {
    /// Supplies the source names of user variables.
    registry: &'a VariableRegistry,
}

impl<'a> ANFCodeGen<'a> {
    /// Creates a generator that names user variables via `registry`.
    #[must_use]
    pub fn new(registry: &'a VariableRegistry) -> Self {
        Self { registry }
    }

    /// Renders `expr` as a Rust expression.
    ///
    /// Every `let` opens a nested block, for example
    /// `{ let t0 = x * x;\n{ let t1 = t0 + 1.0;\nt1 } }`.
    pub fn generate<T: NumericType + std::fmt::Display>(&self, expr: &ANFExpr<T>) -> String {
        // Walk the chain iteratively (long chains cannot overflow the stack) and
        // append all closing braces once the final atom has been written.
        let mut code = String::new();
        let mut open_blocks = 0;
        let mut current = expr;
        while let ANFExpr::Let(var, computation, body) = current {
            code.push_str("{ let ");
            code.push_str(&var.to_rust_code(self.registry));
            code.push_str(" = ");
            code.push_str(&self.generate_computation(computation));
            code.push_str(";\n");
            open_blocks += 1;
            current = &**body;
        }
        if let ANFExpr::Atom(atom) = current {
            code.push_str(&self.generate_atom(atom));
        }
        code.push_str(&" }".repeat(open_blocks));
        code
    }

    /// Renders one atom: a float literal for constants, an identifier for variables.
    fn generate_atom<T: NumericType + std::fmt::Display>(&self, atom: &ANFAtom<T>) -> String {
        match atom {
            ANFAtom::Constant(value) => Self::constant_literal(value.to_string()),
            ANFAtom::Variable(var) => var.to_rust_code(self.registry),
        }
    }

    /// Turns a `Display`ed constant into a float literal usable as an operand.
    ///
    /// `f64`'s `Display` prints whole numbers without a fractional part (`2.0`
    /// becomes `"2"`), which Rust reads as an integer literal, so `x * 2` would
    /// not type-check for `x: f64`; a `.0` suffix is added in that case.
    /// Negative values are parenthesized so they stay a single operand: unary
    /// minus binds looser than a method call, so `-2.0.powf(x)` would mean
    /// `-(2.0.powf(x))`.
    ///
    /// NOTE(review): a constant used as a method receiver (e.g. a `Pow` base)
    /// is still an unsuffixed float literal, which rustc may reject as an
    /// ambiguous numeric type — consider emitting a type suffix.
    fn constant_literal(text: String) -> String {
        let unsigned = text.strip_prefix('-');
        let magnitude = unsigned.unwrap_or(&text);
        let is_integral = !magnitude.is_empty() && magnitude.bytes().all(|b| b.is_ascii_digit());
        let is_negative = unsigned.is_some();
        match (is_negative, is_integral) {
            (false, false) => text,
            (false, true) => format!("{text}.0"),
            (true, false) => format!("({text})"),
            (true, true) => format!("({text}.0)"),
        }
    }

    /// Renders the right-hand side of a `let` binding.
    fn generate_computation<T: NumericType + std::fmt::Display>(
        &self,
        comp: &ANFComputation<T>,
    ) -> String {
        let atom = |operand: &ANFAtom<T>| self.generate_atom(operand);
        match comp {
            ANFComputation::Add(left, right) => format!("{} + {}", atom(left), atom(right)),
            ANFComputation::Sub(left, right) => format!("{} - {}", atom(left), atom(right)),
            ANFComputation::Mul(left, right) => format!("{} * {}", atom(left), atom(right)),
            ANFComputation::Div(left, right) => format!("{} / {}", atom(left), atom(right)),
            ANFComputation::Pow(left, right) => format!("{}.powf({})", atom(left), atom(right)),
            ANFComputation::Neg(operand) => format!("-{}", atom(operand)),
            ANFComputation::Ln(operand) => format!("{}.ln()", atom(operand)),
            ANFComputation::Exp(operand) => format!("{}.exp()", atom(operand)),
            ANFComputation::Sin(operand) => format!("{}.sin()", atom(operand)),
            ANFComputation::Cos(operand) => format!("{}.cos()", atom(operand)),
            ANFComputation::Sqrt(operand) => format!("{}.sqrt()", atom(operand)),
        }
    }

    /// Wraps the rendered expression in `fn name(...) -> f64`, taking one
    /// `f64` parameter per name returned by the registry's `get_all_names`.
    pub fn generate_function<T: NumericType + std::fmt::Display>(
        &self,
        name: &str,
        expr: &ANFExpr<T>,
    ) -> String {
        let param_list: Vec<String> = self
            .registry
            .get_all_names()
            .iter()
            .map(|var_name| format!("{var_name}: f64"))
            .collect();
        let body = self.generate(expr);
        format!(
            "fn {}({}) -> f64 {{\n {}\n}}",
            name,
            param_list.join(", "),
            body.replace('\n', "\n ")
        )
    }
}
/// Renders `expr` as Rust source, naming user variables through `registry`.
pub fn generate_rust_code<T: NumericType + std::fmt::Display>(
    expr: &ANFExpr<T>,
    registry: &VariableRegistry,
) -> String {
    ANFCodeGen::new(registry).generate(expr)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_var_generic() {
        let mut generic = ANFVarGen::new();
        let v1 = generic.fresh();
        let v2 = generic.fresh();
        let v3 = generic.user_var(0);
        assert_eq!(v1, VarRef::Bound(0));
        assert_eq!(v2, VarRef::Bound(1));
        assert_eq!(v3, VarRef::User(0));
        assert_ne!(v1, v2);
        assert_ne!(v1, v3);
    }

    #[test]
    fn test_anf_atom() {
        let const_atom: ANFAtom<f64> = ANFAtom::Constant(42.0);
        let var_atom: ANFAtom<f64> = ANFAtom::Variable(VarRef::Bound(0));
        assert!(const_atom.is_constant());
        assert!(!const_atom.is_variable());
        assert_eq!(const_atom.as_constant(), Some(&42.0));
        assert!(var_atom.is_variable());
        assert!(!var_atom.is_constant());
        assert_eq!(var_atom.as_variable(), Some(VarRef::Bound(0)));
    }

    #[test]
    fn test_anf_computation_operands() {
        let a: ANFAtom<f64> = ANFAtom::Constant(1.0);
        let b: ANFAtom<f64> = ANFAtom::Variable(VarRef::Bound(0));
        let add = ANFComputation::Add(a.clone(), b.clone());
        let operands = add.operands();
        assert_eq!(operands.len(), 2);
        assert_eq!(operands[0], &a);
        assert_eq!(operands[1], &b);
        let neg = ANFComputation::Neg(a.clone());
        let neg_operands = neg.operands();
        assert_eq!(neg_operands.len(), 1);
        assert_eq!(neg_operands[0], &a);
    }

    #[test]
    fn test_anf_expr_construction() {
        let var = VarRef::Bound(0);
        let const_val = 42.0;
        let atom_expr: ANFExpr<f64> = ANFExpr::constant(const_val);
        let var_expr: ANFExpr<f64> = ANFExpr::variable(var);
        assert!(atom_expr.is_atom());
        assert!(var_expr.is_atom());
        let computation = ANFComputation::Add(ANFAtom::Variable(var), ANFAtom::Constant(1.0));
        let let_expr: ANFExpr<f64> = ANFExpr::let_binding(var, computation, atom_expr);
        assert!(let_expr.is_let());
        assert_eq!(let_expr.let_count(), 1);
    }

    #[test]
    fn test_variable_collection() {
        let var1 = VarRef::Bound(0);
        let var2 = VarRef::User(0);
        let computation = ANFComputation::Add(ANFAtom::Variable(var2), ANFAtom::Constant(1.0));
        let body: ANFExpr<f64> = ANFExpr::Atom(ANFAtom::Variable(var1));
        let expr: ANFExpr<f64> = ANFExpr::Let(var1, computation, Box::new(body));
        let used_vars = expr.used_variables();
        assert_eq!(used_vars.len(), 2);
        assert!(used_vars.contains(&var1));
        assert!(used_vars.contains(&var2));
    }

    #[test]
    fn test_anf_conversion() {
        use crate::final_tagless::{ASTEval, ASTMathExpr};
        let x = ASTEval::var(0);
        let one = ASTEval::constant(1.0);
        let x_plus_one = ASTEval::add(x.clone(), one.clone());
        let sin_expr = ASTEval::sin(x_plus_one.clone());
        let cos_expr = ASTEval::cos(x_plus_one);
        let full_expr = ASTEval::add(sin_expr, cos_expr);
        let anf_result = convert_to_anf(&full_expr);
        assert!(anf_result.is_ok());
        let anf = anf_result.unwrap();
        assert!(anf.let_count() > 0);
        assert!(anf.is_let());
    }

    #[test]
    fn test_anf_code_generation() {
        use crate::final_tagless::{ASTEval, ASTMathExpr, VariableRegistry};
        let mut registry = VariableRegistry::new();
        let x_idx = registry.register_variable("x");
        let x = ASTEval::var(x_idx);
        let two = ASTEval::constant(2.0);
        let one = ASTEval::constant(1.0);
        let x_squared = ASTEval::mul(x.clone(), x.clone());
        let two_x = ASTEval::mul(two, x);
        let sum1 = ASTEval::add(x_squared, two_x);
        let quadratic = ASTEval::add(sum1, one);
        let anf = convert_to_anf(&quadratic).unwrap();
        // x^2 + 2x + 1 at x = 3.
        let vars = std::collections::HashMap::from([(x_idx, 3.0)]);
        assert_eq!(anf.eval(&vars), 16.0);
        let code = generate_rust_code(&anf, &registry);
        assert!(code.contains("let t"));
        assert!(code.contains('x'));
        let codegen = ANFCodeGen::new(&registry);
        let function_code = codegen.generate_function("quadratic", &anf);
        assert!(function_code.contains("fn quadratic"));
        assert!(function_code.contains("x: f64"));
        assert!(function_code.contains("-> f64"));
        println!("Generated code:\n{code}");
        println!("Generated function:\n{function_code}");
    }

    #[test]
    fn test_anf_complete_pipeline() {
        use crate::final_tagless::{ASTEval, ASTMathExpr, VariableRegistry};
        let mut registry = VariableRegistry::new();
        let x_idx = registry.register_variable("x");
        let y_idx = registry.register_variable("y");
        let x = ASTEval::var(x_idx);
        let y = ASTEval::var(y_idx);
        let x_plus_y = ASTEval::add(x, y);
        let sin_term = ASTEval::sin(x_plus_y.clone());
        let cos_term = ASTEval::cos(x_plus_y.clone());
        let exp_term = ASTEval::exp(x_plus_y);
        let sum1 = ASTEval::add(sin_term, cos_term);
        let final_expr = ASTEval::add(sum1, exp_term);
        let anf = convert_to_anf(&final_expr).unwrap();
        let codegen = ANFCodeGen::new(&registry);
        let function_code = codegen.generate_function("demo_function", &anf);
        println!("\n=== ANF Demo: Automatic Common Subexpression Elimination ===");
        println!("Original expression: sin(x + y) + cos(x + y) + exp(x + y)");
        println!("ANF introduces variables for shared subexpressions automatically\n");
        println!("Generated function:");
        println!("{function_code}");
        assert!(anf.let_count() >= 1);
        assert!(function_code.contains("fn demo_function"));
        assert!(function_code.contains("x: f64, y: f64"));
        assert!(function_code.contains("-> f64"));
    }

    #[test]
    fn test_cse_simple_case() {
        use crate::final_tagless::{ASTEval, ASTMathExpr, VariableRegistry};
        let mut registry = VariableRegistry::new();
        let x_idx = registry.register_variable("x");
        let x = ASTEval::var(x_idx);
        let one = ASTEval::constant(1.0);
        let x_plus_one_left = ASTEval::add(x.clone(), one.clone());
        let x_plus_one_right = ASTEval::add(x, one);
        let final_expr = ASTEval::add(x_plus_one_left, x_plus_one_right);
        let anf = convert_to_anf(&final_expr).unwrap();
        let codegen = ANFCodeGen::new(&registry);
        let function_code = codegen.generate_function("cse_test", &anf);
        println!("\n=== CSE Test: (x + 1) + (x + 1) ===");
        println!("Generated function:");
        println!("{function_code}");
        let code_contains_reuse = function_code.matches("x + 1").count() == 1;
        println!(
            "CSE working correctly: {}",
            if code_contains_reuse {
                "✅ YES"
            } else {
                "❌ NO"
            }
        );
        assert!(anf.let_count() > 0);
    }

    #[test]
    fn test_cse_debug() {
        use crate::final_tagless::{ASTEval, ASTMathExpr, VariableRegistry};
        let mut registry = VariableRegistry::new();
        let x_idx = registry.register_variable("x");
        let x = ASTEval::var(x_idx);
        let expr = ASTEval::add(x.clone(), x.clone());
        let anf = convert_to_anf(&expr).unwrap();
        let codegen = ANFCodeGen::new(&registry);
        let function_code = codegen.generate_function("debug_test", &anf);
        println!("\n=== CSE Debug: x + x ===");
        println!("Generated function:");
        println!("{function_code}");
        assert_eq!(anf.let_count(), 1);
        let vars = std::collections::HashMap::from([(x_idx, 3.0)]);
        assert_eq!(anf.eval(&vars), 6.0);
    }

    #[test]
    fn test_cse_failing_case() {
        use crate::final_tagless::{ASTEval, ASTMathExpr, VariableRegistry};
        let mut registry = VariableRegistry::new();
        let x_idx = registry.register_variable("x");
        let x = ASTEval::var(x_idx);
        let one = ASTEval::constant(1.0);
        let x_plus_one_left = ASTEval::add(x.clone(), one.clone());
        let x_plus_one_right = ASTEval::add(x, one);
        let expr = ASTEval::add(x_plus_one_left, x_plus_one_right);
        let anf = convert_to_anf(&expr).unwrap();
        let codegen = ANFCodeGen::new(&registry);
        let function_code = codegen.generate_function("failing_case", &anf);
        println!("\n=== CSE Failing Case: (x + 1) + (x + 1) ===");
        println!("Generated function:");
        println!("{function_code}");
        println!("Let count: {}", anf.let_count());
        println!("ANF structure: {anf:#?}");
        // Sharing `x + 1` must not change the value: (2 + 1) + (2 + 1) = 6.
        let vars = std::collections::HashMap::from([(x_idx, 2.0)]);
        assert_eq!(anf.eval(&vars), 6.0);
    }
}