use wick_core::Ast;
use wick_core::optimize::Pass;
pub struct ScalarConstantFolding;
impl Pass for ScalarConstantFolding {
fn transform(&self, ast: &Ast) -> Option<Ast> {
let Ast::Call(name, args) = ast else {
return None;
};
let const_args: Option<Vec<f64>> = args
.iter()
.map(|a| match a {
Ast::Num(n) => Some(*n),
_ => None,
})
.collect();
let const_args = const_args?;
let result = evaluate_scalar_function(name, &const_args)?;
Some(Ast::Num(result))
}
}
fn evaluate_scalar_function(name: &str, args: &[f64]) -> Option<f64> {
match (name, args.len()) {
("pi", 0) => Some(std::f64::consts::PI),
("e", 0) => Some(std::f64::consts::E),
("tau", 0) => Some(std::f64::consts::TAU),
("sin", 1) => Some(args[0].sin()),
("cos", 1) => Some(args[0].cos()),
("tan", 1) => Some(args[0].tan()),
("asin", 1) => Some(args[0].asin()),
("acos", 1) => Some(args[0].acos()),
("atan", 1) => Some(args[0].atan()),
("sinh", 1) => Some(args[0].sinh()),
("cosh", 1) => Some(args[0].cosh()),
("tanh", 1) => Some(args[0].tanh()),
("exp", 1) => Some(args[0].exp()),
("exp2", 1) => Some(args[0].exp2()),
("log", 1) => Some(args[0].ln()),
("ln", 1) => Some(args[0].ln()),
("log2", 1) => Some(args[0].log2()),
("log10", 1) => Some(args[0].log10()),
("sqrt", 1) => Some(args[0].sqrt()),
("inversesqrt", 1) => Some(1.0 / args[0].sqrt()),
("abs", 1) => Some(args[0].abs()),
("sign", 1) => Some(if args[0] > 0.0 {
1.0
} else if args[0] < 0.0 {
-1.0
} else {
0.0
}),
("floor", 1) => Some(args[0].floor()),
("ceil", 1) => Some(args[0].ceil()),
("round", 1) => Some(args[0].round()),
("trunc", 1) => Some(args[0].trunc()),
("fract", 1) => Some(args[0].fract()),
("saturate", 1) => Some(args[0].clamp(0.0, 1.0)),
("atan2", 2) => Some(args[0].atan2(args[1])),
("pow", 2) => Some(args[0].powf(args[1])),
("min", 2) => Some(args[0].min(args[1])),
("max", 2) => Some(args[0].max(args[1])),
("step", 2) => Some(if args[1] < args[0] { 0.0 } else { 1.0 }),
("clamp", 3) => Some(args[0].clamp(args[1], args[2])),
("lerp", 3) | ("mix", 3) => {
let (a, b, t) = (args[0], args[1], args[2]);
Some(a + (b - a) * t)
}
("smoothstep", 3) => {
let (edge0, edge1, x) = (args[0], args[1], args[2]);
let t = ((x - edge0) / (edge1 - edge0)).clamp(0.0, 1.0);
Some(t * t * (3.0 - 2.0 * t))
}
("inverse_lerp", 3) => {
let (a, b, v) = (args[0], args[1], args[2]);
Some((v - a) / (b - a))
}
("remap", 5) => {
let (x, in_lo, in_hi, out_lo, out_hi) = (args[0], args[1], args[2], args[3], args[4]);
let t = (x - in_lo) / (in_hi - in_lo);
Some(out_lo + (out_hi - out_lo) * t)
}
_ => None,
}
}
pub fn scalar_passes() -> Vec<&'static dyn Pass> {
vec![&ScalarConstantFolding]
}
#[cfg(test)]
mod tests {
use super::*;
use wick_core::Expr;
use wick_core::optimize::{optimize, standard_passes};
fn optimized(input: &str) -> String {
let expr = Expr::parse(input).unwrap();
let mut passes: Vec<&dyn Pass> = standard_passes();
passes.push(&ScalarConstantFolding);
let result = optimize(expr.ast().clone(), &passes);
result.to_string()
}
fn optimized_value(input: &str) -> f64 {
let expr = Expr::parse(input).unwrap();
let mut passes: Vec<&dyn Pass> = standard_passes();
passes.push(&ScalarConstantFolding);
let result = optimize(expr.ast().clone(), &passes);
match result {
Ast::Num(n) => n,
other => panic!("expected Num, got {other:?}"),
}
}
#[test]
fn test_pi() {
let v = optimized_value("pi()");
assert!((v - std::f64::consts::PI).abs() < 0.0001);
}
#[test]
fn test_e() {
let v = optimized_value("e()");
assert!((v - std::f64::consts::E).abs() < 0.0001);
}
#[test]
fn test_tau() {
let v = optimized_value("tau()");
assert!((v - std::f64::consts::TAU).abs() < 0.0001);
}
#[test]
fn test_sin_zero() {
let v = optimized_value("sin(0)");
assert!(v.abs() < 0.0001);
}
#[test]
fn test_cos_zero() {
let v = optimized_value("cos(0)");
assert!((v - 1.0).abs() < 0.0001);
}
#[test]
fn test_tan_zero() {
let v = optimized_value("tan(0)");
assert!(v.abs() < 0.0001);
}
#[test]
fn test_exp_zero() {
let v = optimized_value("exp(0)");
assert!((v - 1.0).abs() < 0.0001);
}
#[test]
fn test_ln_one() {
let v = optimized_value("ln(1)");
assert!(v.abs() < 0.0001);
}
#[test]
fn test_sqrt() {
let v = optimized_value("sqrt(16)");
assert!((v - 4.0).abs() < 0.0001);
}
#[test]
fn test_log2() {
let v = optimized_value("log2(8)");
assert!((v - 3.0).abs() < 0.0001);
}
#[test]
fn test_abs() {
assert_eq!(optimized("abs(-5)"), "5");
}
#[test]
fn test_floor() {
assert_eq!(optimized("floor(3.7)"), "3");
}
#[test]
fn test_ceil() {
assert_eq!(optimized("ceil(3.2)"), "4");
}
#[test]
fn test_round() {
assert_eq!(optimized("round(3.5)"), "4");
}
#[test]
fn test_sign() {
assert_eq!(optimized("sign(-5)"), "-1");
assert_eq!(optimized("sign(5)"), "1");
assert_eq!(optimized("sign(0)"), "0");
}
#[test]
fn test_saturate() {
assert_eq!(optimized("saturate(1.5)"), "1");
assert_eq!(optimized("saturate(-0.5)"), "0");
}
#[test]
fn test_min_max() {
assert_eq!(optimized("min(3, 7)"), "3");
assert_eq!(optimized("max(3, 7)"), "7");
}
#[test]
fn test_pow() {
let v = optimized_value("pow(2, 3)");
assert!((v - 8.0).abs() < 0.0001);
}
#[test]
fn test_step() {
assert_eq!(optimized("step(0.5, 0.3)"), "0");
assert_eq!(optimized("step(0.5, 0.7)"), "1");
}
#[test]
fn test_clamp() {
assert_eq!(optimized("clamp(5, 0, 3)"), "3");
assert_eq!(optimized("clamp(-1, 0, 3)"), "0");
}
#[test]
fn test_lerp() {
assert_eq!(optimized("lerp(0, 10, 0.5)"), "5");
assert_eq!(optimized("mix(0, 10, 0.5)"), "5");
}
#[test]
fn test_smoothstep() {
let v = optimized_value("smoothstep(0, 1, 0.5)");
assert!((v - 0.5).abs() < 0.1);
}
#[test]
fn test_inverse_lerp() {
let v = optimized_value("inverse_lerp(0, 10, 5)");
assert!((v - 0.5).abs() < 0.0001);
}
#[test]
fn test_remap() {
let v = optimized_value("remap(5, 0, 10, 0, 100)");
assert!((v - 50.0).abs() < 0.0001);
}
#[test]
fn test_combined_sin_cos() {
assert_eq!(optimized("sin(0) + cos(0)"), "1");
}
#[test]
fn test_combined_with_arithmetic() {
assert_eq!(optimized("sqrt(4) * 3 + 1"), "7");
}
#[test]
fn test_nested_calls() {
assert_eq!(optimized("sqrt(sqrt(16))"), "2");
}
#[test]
fn test_partial_constant() {
assert_eq!(optimized("sqrt(x)"), "sqrt(x)");
}
#[test]
fn test_partial_mixed() {
assert_eq!(optimized("sqrt(4) + x"), "(2 + x)");
}
}