use super::ast_repr::ASTRepr;
use crate::final_tagless::traits::NumericType;
use num_traits::Float;
/// Rewrites `expr` into canonical form and returns the new tree; `expr` itself is untouched.
///
/// Canonical form contains no `Sub` or `Div` nodes:
/// - `a - b` is rewritten as `a + (-b)`
/// - `a / b` is rewritten as `a * b^(-1)`
///
/// Every other node is rebuilt as-is around its recursively normalized children,
/// so the rewrite reaches arbitrarily deep into the tree.
pub fn normalize<T: NumericType + Clone + Float>(expr: &ASTRepr<T>) -> ASTRepr<T> {
    // Normalizes one child node and boxes it for re-insertion into the rebuilt tree.
    fn child<T: NumericType + Clone + Float>(node: &ASTRepr<T>) -> Box<ASTRepr<T>> {
        Box::new(normalize(node))
    }

    match expr {
        // Leaves are copied verbatim.
        ASTRepr::Constant(c) => ASTRepr::Constant(*c),
        ASTRepr::Variable(idx) => ASTRepr::Variable(*idx),

        // Binary operators that are already canonical: only their operands change.
        ASTRepr::Add(lhs, rhs) => ASTRepr::Add(child(lhs), child(rhs)),
        ASTRepr::Mul(lhs, rhs) => ASTRepr::Mul(child(lhs), child(rhs)),
        ASTRepr::Pow(base, exponent) => ASTRepr::Pow(child(base), child(exponent)),

        // Unary operators: normalize the single operand.
        ASTRepr::Neg(arg) => ASTRepr::Neg(child(arg)),
        ASTRepr::Ln(arg) => ASTRepr::Ln(child(arg)),
        ASTRepr::Exp(arg) => ASTRepr::Exp(child(arg)),
        ASTRepr::Sin(arg) => ASTRepr::Sin(child(arg)),
        ASTRepr::Cos(arg) => ASTRepr::Cos(child(arg)),
        ASTRepr::Sqrt(arg) => ASTRepr::Sqrt(child(arg)),

        // a - b  ==>  a + (-b)
        ASTRepr::Sub(lhs, rhs) => ASTRepr::Add(child(lhs), Box::new(ASTRepr::Neg(child(rhs)))),

        // a / b  ==>  a * b^(-1)
        ASTRepr::Div(lhs, rhs) => {
            let minus_one = Box::new(ASTRepr::Constant(-T::one()));
            ASTRepr::Mul(child(lhs), Box::new(ASTRepr::Pow(child(rhs), minus_one)))
        }
    }
}
/// Returns `true` when `expr` contains no `Sub` or `Div` node anywhere in the tree.
///
/// This is the form [`normalize`] produces. Leaves are always canonical. Every
/// other node is canonical exactly when all of its operands are.
pub fn is_canonical<T: NumericType>(expr: &ASTRepr<T>) -> bool {
    match expr {
        // The two operators that canonical form forbids.
        ASTRepr::Sub(..) | ASTRepr::Div(..) => false,

        // Leaves cannot contain a forbidden operator.
        ASTRepr::Variable(_) | ASTRepr::Constant(_) => true,

        // Unary nodes inherit the verdict of their operand.
        ASTRepr::Neg(operand)
        | ASTRepr::Ln(operand)
        | ASTRepr::Exp(operand)
        | ASTRepr::Sin(operand)
        | ASTRepr::Cos(operand)
        | ASTRepr::Sqrt(operand) => is_canonical(operand),

        // Binary nodes need both sides canonical; the left side is checked first
        // and short-circuits.
        ASTRepr::Add(lhs, rhs) | ASTRepr::Mul(lhs, rhs) | ASTRepr::Pow(lhs, rhs) => {
            is_canonical(lhs) && is_canonical(rhs)
        }
    }
}
/// Converts a canonical-form expression back to conventional subtraction and division.
///
/// Two patterns are recognized, and only when the special node is the *right* operand:
/// - `Add(a, Neg(b))` becomes `Sub(a, b)`
/// - `Mul(a, Pow(b, Constant(c)))` becomes `Div(a, b)` when `|c + 1| < T::epsilon()`,
///   i.e. the exponent is (numerically) `-1`
///
/// Every other node, including any `Sub`/`Div` already present, is rebuilt unchanged
/// around its recursively denormalized children. Mirrored forms such as `(-b) + a`
/// are left as `Add`.
pub fn denormalize<T: NumericType + Clone + PartialEq + Float>(expr: &ASTRepr<T>) -> ASTRepr<T> {
    // Denormalizes a child node and boxes it for the rebuilt tree.
    fn rec<T: NumericType + Clone + PartialEq + Float>(node: &ASTRepr<T>) -> Box<ASTRepr<T>> {
        Box::new(denormalize(node))
    }

    // True when `node` is a constant within machine epsilon of -1.
    fn is_minus_one<T: NumericType + Clone + PartialEq + Float>(node: &ASTRepr<T>) -> bool {
        match node {
            ASTRepr::Constant(value) => (*value + T::one()).abs() < T::epsilon(),
            _ => false,
        }
    }

    match expr {
        ASTRepr::Constant(c) => ASTRepr::Constant(*c),
        ASTRepr::Variable(idx) => ASTRepr::Variable(*idx),

        // a + (-b)  ==>  a - b
        ASTRepr::Add(lhs, rhs) => match rhs.as_ref() {
            ASTRepr::Neg(negated) => ASTRepr::Sub(rec(lhs), rec(negated)),
            _ => ASTRepr::Add(rec(lhs), rec(rhs)),
        },

        // a * b^(-1)  ==>  a / b
        ASTRepr::Mul(lhs, rhs) => match rhs.as_ref() {
            ASTRepr::Pow(base, exponent) if is_minus_one(exponent) => {
                ASTRepr::Div(rec(lhs), rec(base))
            }
            _ => ASTRepr::Mul(rec(lhs), rec(rhs)),
        },

        ASTRepr::Pow(base, exponent) => ASTRepr::Pow(rec(base), rec(exponent)),
        ASTRepr::Sub(lhs, rhs) => ASTRepr::Sub(rec(lhs), rec(rhs)),
        ASTRepr::Div(lhs, rhs) => ASTRepr::Div(rec(lhs), rec(rhs)),

        ASTRepr::Neg(arg) => ASTRepr::Neg(rec(arg)),
        ASTRepr::Ln(arg) => ASTRepr::Ln(rec(arg)),
        ASTRepr::Exp(arg) => ASTRepr::Exp(rec(arg)),
        ASTRepr::Sin(arg) => ASTRepr::Sin(rec(arg)),
        ASTRepr::Cos(arg) => ASTRepr::Cos(rec(arg)),
        ASTRepr::Sqrt(arg) => ASTRepr::Sqrt(rec(arg)),
    }
}
/// Counts the arithmetic operator nodes in `expr`.
///
/// Returns `(add, mul, sub, div)`: the number of `Add`, `Mul`, `Sub` and `Div`
/// nodes in the whole tree. `Pow`, `Neg` and the unary functions are walked into
/// but not counted themselves.
pub fn count_operations<T: NumericType>(expr: &ASTRepr<T>) -> (usize, usize, usize, usize) {
    // Per-operator totals, in `(add, mul, sub, div)` order.
    type Tally = (usize, usize, usize, usize);

    // Component-wise sum of two tallies.
    fn combine(x: Tally, y: Tally) -> Tally {
        (x.0 + y.0, x.1 + y.1, x.2 + y.2, x.3 + y.3)
    }

    // Returns the tally for the subtree rooted at `node`.
    fn tally<T: NumericType>(node: &ASTRepr<T>) -> Tally {
        match node {
            ASTRepr::Constant(_) | ASTRepr::Variable(_) => (0, 0, 0, 0),
            ASTRepr::Add(l, r) => combine((1, 0, 0, 0), combine(tally(l), tally(r))),
            ASTRepr::Mul(l, r) => combine((0, 1, 0, 0), combine(tally(l), tally(r))),
            ASTRepr::Sub(l, r) => combine((0, 0, 1, 0), combine(tally(l), tally(r))),
            ASTRepr::Div(l, r) => combine((0, 0, 0, 1), combine(tally(l), tally(r))),
            // Pow is traversed but is not one of the four tracked operators.
            ASTRepr::Pow(l, r) => combine(tally(l), tally(r)),
            ASTRepr::Neg(x)
            | ASTRepr::Ln(x)
            | ASTRepr::Exp(x)
            | ASTRepr::Sin(x)
            | ASTRepr::Cos(x)
            | ASTRepr::Sqrt(x) => tally(x),
        }
    }

    tally(expr)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `a - b` must normalize to `a + (-b)`.
    #[test]
    fn test_subtraction_normalization() {
        let expr: ASTRepr<f64> = ASTRepr::Sub(
            Box::new(ASTRepr::Variable(0)),
            Box::new(ASTRepr::Variable(1)),
        );
        let normalized = normalize(&expr);
        match normalized {
            ASTRepr::Add(left, right) => {
                assert!(matches!(left.as_ref(), ASTRepr::Variable(0)));
                assert!(
                    matches!(right.as_ref(), ASTRepr::Neg(inner) if matches!(inner.as_ref(), ASTRepr::Variable(1)))
                );
            }
            _ => panic!("Expected Add operation after normalization"),
        }
    }

    /// `a / b` must normalize to `a * b^(-1)`.
    #[test]
    fn test_division_normalization() {
        let expr: ASTRepr<f64> = ASTRepr::Div(
            Box::new(ASTRepr::Variable(0)),
            Box::new(ASTRepr::Variable(1)),
        );
        let normalized = normalize(&expr);
        match normalized {
            ASTRepr::Mul(left, right) => {
                assert!(matches!(left.as_ref(), ASTRepr::Variable(0)));
                match right.as_ref() {
                    ASTRepr::Pow(base, exp) => {
                        assert!(matches!(base.as_ref(), ASTRepr::Variable(1)));
                        assert!(
                            matches!(exp.as_ref(), ASTRepr::Constant(val) if (*val - (-1.0)).abs() < 1e-12)
                        );
                    }
                    _ => panic!("Expected Pow operation in normalized division"),
                }
            }
            _ => panic!("Expected Mul operation after normalization"),
        }
    }

    /// Sub/Div nested under other operators must also be rewritten.
    #[test]
    fn test_nested_normalization() {
        let expr: ASTRepr<f64> = ASTRepr::Div(
            Box::new(ASTRepr::Sub(
                Box::new(ASTRepr::Variable(0)),
                Box::new(ASTRepr::Variable(1)),
            )),
            Box::new(ASTRepr::Add(
                Box::new(ASTRepr::Variable(2)),
                Box::new(ASTRepr::Variable(3)),
            )),
        );
        let normalized = normalize(&expr);
        assert!(is_canonical(&normalized));
    }

    /// Normalization must recurse through unary function nodes, not stop at them.
    #[test]
    fn test_normalization_recurses_into_unary_functions() {
        let expr: ASTRepr<f64> = ASTRepr::Sin(Box::new(ASTRepr::Div(
            Box::new(ASTRepr::Variable(0)),
            Box::new(ASTRepr::Variable(1)),
        )));
        assert!(!is_canonical(&expr));
        let normalized = normalize(&expr);
        assert!(is_canonical(&normalized));
        assert!(
            matches!(&normalized, ASTRepr::Sin(inner) if matches!(inner.as_ref(), ASTRepr::Mul(..)))
        );
    }

    #[test]
    fn test_is_canonical() {
        let canonical: ASTRepr<f64> = ASTRepr::Add(
            Box::new(ASTRepr::Variable(0)),
            Box::new(ASTRepr::Neg(Box::new(ASTRepr::Variable(1)))),
        );
        assert!(is_canonical(&canonical));
        let non_canonical: ASTRepr<f64> = ASTRepr::Sub(
            Box::new(ASTRepr::Variable(0)),
            Box::new(ASTRepr::Variable(1)),
        );
        assert!(!is_canonical(&non_canonical));
    }

    /// normalize → denormalize must round-trip a plain subtraction.
    #[test]
    fn test_denormalization() {
        let original: ASTRepr<f64> = ASTRepr::Sub(
            Box::new(ASTRepr::Variable(0)),
            Box::new(ASTRepr::Variable(1)),
        );
        let normalized = normalize(&original);
        let denormalized = denormalize(&normalized);
        match denormalized {
            ASTRepr::Sub(left, right) => {
                assert!(matches!(left.as_ref(), ASTRepr::Variable(0)));
                assert!(matches!(right.as_ref(), ASTRepr::Variable(1)));
            }
            _ => panic!("Expected Sub operation after denormalization"),
        }
    }

    /// normalize → denormalize must round-trip a plain division.
    #[test]
    fn test_division_denormalization() {
        let original: ASTRepr<f64> = ASTRepr::Div(
            Box::new(ASTRepr::Variable(0)),
            Box::new(ASTRepr::Variable(1)),
        );
        let normalized = normalize(&original);
        let denormalized = denormalize(&normalized);
        match denormalized {
            ASTRepr::Div(left, right) => {
                assert!(matches!(left.as_ref(), ASTRepr::Variable(0)));
                assert!(matches!(right.as_ref(), ASTRepr::Variable(1)));
            }
            _ => panic!("Expected Div operation after denormalization"),
        }
    }

    /// Normalization trades Sub/Div for extra Add/Mul nodes.
    #[test]
    fn test_operation_counting() {
        let expr: ASTRepr<f64> = ASTRepr::Add(
            Box::new(ASTRepr::Sub(
                Box::new(ASTRepr::Variable(0)),
                Box::new(ASTRepr::Variable(1)),
            )),
            Box::new(ASTRepr::Div(
                Box::new(ASTRepr::Variable(2)),
                Box::new(ASTRepr::Variable(3)),
            )),
        );
        let (add, mul, sub, div) = count_operations(&expr);
        assert_eq!(add, 1);
        assert_eq!(mul, 0);
        assert_eq!(sub, 1);
        assert_eq!(div, 1);
        let normalized = normalize(&expr);
        let (norm_add, norm_mul, norm_sub, norm_div) = count_operations(&normalized);
        assert!(norm_add > add);
        assert!(norm_mul > mul);
        assert_eq!(norm_sub, 0);
        assert_eq!(norm_div, 0);
    }

    /// A mixed expression must round-trip back to its original operator structure,
    /// not merely to *some* non-canonical tree.
    #[test]
    fn test_complex_expression_normalization() {
        // (x0 * x1 - 2) / (x2 + 1)
        let expr: ASTRepr<f64> = ASTRepr::Div(
            Box::new(ASTRepr::Sub(
                Box::new(ASTRepr::Mul(
                    Box::new(ASTRepr::Variable(0)),
                    Box::new(ASTRepr::Variable(1)),
                )),
                Box::new(ASTRepr::Constant(2.0)),
            )),
            Box::new(ASTRepr::Add(
                Box::new(ASTRepr::Variable(2)),
                Box::new(ASTRepr::Constant(1.0)),
            )),
        );
        let normalized = normalize(&expr);
        assert!(is_canonical(&normalized));

        let denormalized = denormalize(&normalized);
        assert!(!is_canonical(&denormalized));
        match &denormalized {
            ASTRepr::Div(numerator, denominator) => {
                assert!(matches!(numerator.as_ref(), ASTRepr::Sub(..)));
                assert!(matches!(denominator.as_ref(), ASTRepr::Add(..)));
            }
            _ => panic!("Expected Div at the root after denormalization"),
        }
        assert_eq!(count_operations(&denormalized), count_operations(&expr));
    }
}