use dslcompile::ast::ast_utils::{combine_expressions_with_remapping, remap_variables};
use dslcompile::final_tagless::DirectEval;
use dslcompile::prelude::*;
use std::collections::HashMap;
/// Composes `f(x) = x² + 2x + 1` and `g(y) = 3y + 5`, built with separate
/// builders, after remapping `g`'s variable by hand, and checks that
/// `h(2, 3) = f(2) + g(3) = 9 + 14 = 23`.
#[test]
fn test_manual_variable_remapping() {
    use dslcompile::ast::ASTRepr;

    // f(x) = x² + 2x + 1
    let builder_f = ExpressionBuilder::new();
    let x = builder_f.var();
    let f_expr = &x * &x + 2.0 * &x + 1.0;
    let f_ast = f_expr.as_ast();

    // g(y) = 3y + 5, created from an independent builder.
    let builder_g = ExpressionBuilder::new();
    let y = builder_g.var();
    let g_expr = 3.0 * &y + 5.0;
    let g_ast = g_expr.as_ast();

    // Send g's variable index 0 to index 1.
    // NOTE(review): assumes each fresh builder numbers its first variable 0,
    // so without this remap both expressions would read the same input slot.
    let index_map = HashMap::from([(0, 1)]);
    let g_shifted = remap_variables(g_ast, &index_map);

    // h = f + g, evaluated with inputs [2, 3].
    let h_ast = ASTRepr::Add(Box::new(f_ast.clone()), Box::new(g_shifted));
    let inputs = [2.0, 3.0];
    let result = DirectEval::eval_with_vars(&h_ast, &inputs);

    assert_eq!(result, 23.0);
    println!("Manual remapping: h(2,3) = {result}");
}
/// Lets `combine_expressions_with_remapping` remap two independently built
/// expressions, checks that it returns two expressions and a variable count
/// of two, and verifies that their sum evaluates to 23 at inputs `[2, 3]`.
#[test]
fn test_automatic_variable_remapping() {
    use dslcompile::ast::ASTRepr;

    // f(x) = x² + 2x + 1, assembled term by term.
    let builder_f = ExpressionBuilder::new();
    let x = builder_f.var();
    let x_squared = &x * &x;
    let two_x = 2.0 * &x;
    let f_expr = x_squared + two_x + 1.0;
    let f_ast = f_expr.as_ast();

    // g(y) = 3y + 5, from a second builder.
    let builder_g = ExpressionBuilder::new();
    let y = builder_g.var();
    let three_y = 3.0 * &y;
    let g_expr = three_y + 5.0;
    let g_ast = g_expr.as_ast();

    let inputs_to_combine = [f_ast.clone(), g_ast.clone()];
    let (remapped_expressions, total_vars) =
        combine_expressions_with_remapping(&inputs_to_combine);

    // One remapped expression per input, and two variables overall.
    assert_eq!(remapped_expressions.len(), 2);
    assert_eq!(total_vars, 2);

    let f_remapped = remapped_expressions[0].clone();
    let g_remapped = remapped_expressions[1].clone();
    let h_ast = ASTRepr::Add(Box::new(f_remapped), Box::new(g_remapped));

    let result = DirectEval::eval_with_vars(&h_ast, &[2.0, 3.0]);
    assert_eq!(result, 23.0);
    println!("Automatic remapping: h(2,3) = {result}");
}
/// Shortest composition path: hand both ASTs straight to
/// `combine_expressions_with_remapping`, add the results, and check that the
/// sum evaluates to 23 at inputs `[2, 3]`.
#[test]
fn test_simple_composition_api() {
    use dslcompile::ast::ASTRepr;

    // f(x) = x² + 2x + 1
    let builder_f = ExpressionBuilder::new();
    let x = builder_f.var();
    let f_expr = &x * &x + 2.0 * &x + 1.0;

    // g(y) = 3y + 5
    let builder_g = ExpressionBuilder::new();
    let y = builder_g.var();
    let g_expr = 3.0 * &y + 5.0;

    // The returned variable count is not needed in this test.
    let originals = [f_expr.as_ast().clone(), g_expr.as_ast().clone()];
    let (remapped_expressions, _) = combine_expressions_with_remapping(&originals);

    let lhs = Box::new(remapped_expressions[0].clone());
    let rhs = Box::new(remapped_expressions[1].clone());
    let h_ast = ASTRepr::Add(lhs, rhs);

    let result = DirectEval::eval_with_vars(&h_ast, &[2.0, 3.0]);
    assert_eq!(result, 23.0);
    println!("Simple composition: h(2,3) = {result}");
}
/// Composes three independently built expressions,
/// `f(x) = sin(x) + cos(x)`, `g(y) = exp(y) - ln(y + 1)` and `h(z) = z²`,
/// into `k = f * g + h`, then compares the evaluated result at `(1, 1, 3)`
/// with the same formula computed directly on `f64` values.
#[test]
fn test_complex_composition() {
    use dslcompile::ast::ASTRepr;

    // f(x) = sin(x) + cos(x)
    let builder_f = ExpressionBuilder::new();
    let x = builder_f.var();
    let f_expr = x.clone().sin() + x.clone().cos();

    // g(y) = exp(y) - ln(y + 1)
    let builder_g = ExpressionBuilder::new();
    let y = builder_g.var();
    let g_expr = y.clone().exp() - (y.clone() + 1.0).ln();

    // h(z) = z²
    let builder_h = ExpressionBuilder::new();
    let z = builder_h.var();
    let h_expr = &z * &z;

    let originals = [
        f_expr.as_ast().clone(),
        g_expr.as_ast().clone(),
        h_expr.as_ast().clone(),
    ];
    let (remapped_expressions, _) = combine_expressions_with_remapping(&originals);

    // k = (f * g) + h
    let product = ASTRepr::Mul(
        Box::new(remapped_expressions[0].clone()),
        Box::new(remapped_expressions[1].clone()),
    );
    let k_ast = ASTRepr::Add(Box::new(product), Box::new(remapped_expressions[2].clone()));

    // Reference value computed with plain f64 arithmetic at the same point.
    let [x_val, y_val, z_val] = [1.0_f64, 1.0_f64, 3.0_f64];
    let reference_f = x_val.sin() + x_val.cos();
    let reference_g = y_val.exp() - (y_val + 1.0).ln();
    let reference_h = z_val * z_val;
    let expected = reference_f * reference_g + reference_h;

    let result = DirectEval::eval_with_vars(&k_ast, &[x_val, y_val, z_val]);

    // Transcendental functions are involved, so compare with a tolerance.
    assert!((result - expected).abs() < 1e-10);
    println!("Complex composition: k(1,1,3) = {result}");
}
/// Shows the variable-collision pitfall in the compile-time API: when both
/// operands use `var::<0>()`, they read the same input, so the sum at `[4]`
/// is `2*4 + 3*4 = 20`. Using `var::<1>()` for the second operand gives
/// `2*4 + 3*7 = 29` at `[4, 7]`.
#[test]
fn test_compile_time_variable_collision() {
    use dslcompile::compile_time::*;

    // Colliding version: both terms refer to variable index 0.
    let doubled_shared = var::<0>().mul(constant(2.0));
    let tripled_shared = var::<0>().mul(constant(3.0));
    let h_wrong = doubled_shared.add(tripled_shared);
    let result_wrong = h_wrong.eval(&[4.0]);
    assert_eq!(result_wrong, 20.0);

    // Distinct indices: the second term reads variable index 1 instead.
    let doubled_first = var::<0>().mul(constant(2.0));
    let tripled_second = var::<1>().mul(constant(3.0));
    let h_correct = doubled_first.add(tripled_second);
    let result_correct = h_correct.eval(&[4.0, 7.0]);
    assert_eq!(result_correct, 29.0);

    println!("Compile-time collision: {result_wrong} vs correct: {result_correct}");
}
/// Checks that Rust code generation succeeds for an expression composed from
/// two remapped sub-expressions, and that the generated source mentions both
/// `var_0` and `var_1`.
#[test]
fn test_compilation_with_remapped_variables() {
    // f(x) = x² + 2x + 1
    let math_f = ExpressionBuilder::new();
    let x_f = math_f.var();
    let f_expr = &x_f * &x_f + 2.0 * &x_f + 1.0;

    // g(y) = 3y + 5, built with an independent builder.
    let math_g = ExpressionBuilder::new();
    let y_g = math_g.var();
    let g_expr = 3.0 * &y_g + 5.0;

    let (remapped_expressions, _) =
        combine_expressions_with_remapping(&[f_expr.as_ast().clone(), g_expr.as_ast().clone()]);
    let h_ast = dslcompile::ast::ASTRepr::Add(
        Box::new(remapped_expressions[0].clone()),
        Box::new(remapped_expressions[1].clone()),
    );

    let codegen = dslcompile::backends::rust_codegen::RustCodeGenerator::new();
    let mut registry =
        dslcompile::final_tagless::variables::typed_registry::VariableRegistry::new();
    // Register two variables, one for each input of the composed expression.
    let _var0 = registry.register_variable();
    let _var1 = registry.register_variable();

    // Pass the registry by reference (the original text here was the
    // mis-encoded token `®istry`).
    let rust_code =
        codegen.generate_function_with_registry(&h_ast, "composed_func", "f64", &registry);
    // `expect` reports the underlying codegen error if generation fails.
    let code = rust_code.expect("code generation should succeed for the composed expression");
    println!("Generated code:\n{code}");

    // Both variable names must appear in the generated source.
    assert!(code.contains("var_0"));
    assert!(code.contains("var_1"));
}