use dslcompile::ast::ast_utils::remap_variables;
use dslcompile::final_tagless::DirectEval;
use dslcompile::prelude::*;
use std::collections::HashMap;
/// Shows what happens when two independently built expressions are added
/// without reconciling their variable indices.
///
/// `f` and `g` come from separate `ExpressionBuilder`s and the summed AST is
/// evaluated against only two inputs. The asserted value 15 is consistent with
/// both operands reading the same two slots: f(1, 2) + g(1, 2) = 7 + 8.
#[test]
fn test_shared_variable_composition_naive() {
    // f(a, b) = a^2 + a*b + b^2, built on its own builder.
    let builder_f = ExpressionBuilder::new();
    let f_first = builder_f.var();
    let f_second = builder_f.var();
    let f = &f_first * &f_first + &f_first * &f_second + &f_second * &f_second;

    // g(a, b) = 2a + 3b, built on a second, unrelated builder.
    let builder_g = ExpressionBuilder::new();
    let g_first = builder_g.var();
    let g_second = builder_g.var();
    let g = 2.0 * &g_first + 3.0 * &g_second;

    // Add the raw ASTs with no index remapping at all.
    let naive_sum = f.as_ast().clone() + g.as_ast().clone();

    let result_wrong = DirectEval::eval_with_vars(&naive_sum, &[1.0, 2.0]);
    println!("Naive (wrong) result: h(1,2) = {result_wrong}");
    assert_eq!(result_wrong, 15.0);
}
/// Composes `f` and `g` after remapping `g`'s variable indices by hand.
///
/// `g`'s local indices 0 and 1 are remapped to 1 and 2 before the two ASTs are
/// added, and the sum is evaluated against three inputs. The asserted value 20
/// equals f(1, 2) + g(2, 3) = 7 + 13.
#[test]
fn test_shared_variable_composition_manual() {
    // f = a^2 + a*b + b^2 over the first builder's two variables.
    let builder_f = ExpressionBuilder::new();
    let f_first = builder_f.var();
    let f_second = builder_f.var();
    let f = &f_first * &f_first + &f_first * &f_second + &f_second * &f_second;

    // g = 2a + 3b over the second builder's two variables.
    let builder_g = ExpressionBuilder::new();
    let g_first = builder_g.var();
    let g_second = builder_g.var();
    let g = 2.0 * &g_first + 3.0 * &g_second;

    // Shift g's indices up by one so its first variable lands on f's second.
    let shift_by_one: HashMap<usize, usize> = HashMap::from([(0, 1), (1, 2)]);
    let g_shifted = remap_variables(g.as_ast(), &shift_by_one);

    let h_ast = dslcompile::ast::ASTRepr::Add(Box::new(f.as_ast().clone()), Box::new(g_shifted));

    let result = DirectEval::eval_with_vars(&h_ast, &[1.0, 2.0, 3.0]);
    assert_eq!(result, 20.0);
    println!("Manual remapping: h(1,2,3) = f(1,2) + g(2,3) = {result}");
}
/// Composes `f` and `g` by attaching variable names to each expression and
/// remapping both through one shared name-to-index table.
///
/// With the table x -> 0, y -> 1, z -> 2 and inputs [1, 2, 3], the asserted
/// value 20 equals f(1, 2) + g(2, 3) = 7 + 13.
#[test]
fn test_shared_variable_composition_systematic() {
    use dslcompile::ast::ASTRepr;

    /// An expression together with the names of its variables, listed in
    /// local index order (position in `var_names` is the local index).
    #[derive(Debug, Clone)]
    struct NamedFunction {
        ast: ASTRepr<f64>,
        var_names: Vec<String>,
    }

    impl NamedFunction {
        /// Pairs `ast` with the names of its variables.
        fn new(ast: ASTRepr<f64>, var_names: Vec<String>) -> Self {
            NamedFunction { ast, var_names }
        }

        /// Returns a copy of the expression with each local index replaced by
        /// the index `global_var_map` assigns to that variable's name.
        ///
        /// Names absent from `global_var_map` contribute no entry to the
        /// table handed to `remap_variables`.
        fn remap_to_global_indices(
            &self,
            global_var_map: &HashMap<String, usize>,
        ) -> ASTRepr<f64> {
            let local_to_global: HashMap<usize, usize> = self
                .var_names
                .iter()
                .enumerate()
                .filter_map(|(local_idx, name)| {
                    global_var_map
                        .get(name)
                        .map(|&global_idx| (local_idx, global_idx))
                })
                .collect();
            remap_variables(&self.ast, &local_to_global)
        }
    }

    // f = a^2 + a*b + b^2, with its variables named "x" and "y".
    let builder_f = ExpressionBuilder::new();
    let f_first = builder_f.var();
    let f_second = builder_f.var();
    let f = &f_first * &f_first + &f_first * &f_second + &f_second * &f_second;
    let f_named = NamedFunction::new(f.as_ast().clone(), vec!["x".to_string(), "y".to_string()]);

    // g = 2a + 3b, with its variables named "y" and "z".
    let builder_g = ExpressionBuilder::new();
    let g_first = builder_g.var();
    let g_second = builder_g.var();
    let g = 2.0 * &g_first + 3.0 * &g_second;
    let g_named = NamedFunction::new(g.as_ast().clone(), vec!["y".to_string(), "z".to_string()]);

    // One table fixes the position of every name in the evaluation inputs.
    let global_var_map: HashMap<String, usize> = HashMap::from([
        ("x".to_string(), 0),
        ("y".to_string(), 1),
        ("z".to_string(), 2),
    ]);

    let h_ast = ASTRepr::Add(
        Box::new(f_named.remap_to_global_indices(&global_var_map)),
        Box::new(g_named.remap_to_global_indices(&global_var_map)),
    );

    let result = DirectEval::eval_with_vars(&h_ast, &[1.0, 2.0, 3.0]);
    assert_eq!(result, 20.0);
    println!("Systematic approach: h(1,2,3) = {result}");
}
/// Composes three expressions that share variables across four global slots.
///
/// The combined function is
/// h(w, x, y, z) = sin(x) * cos(y) + exp(y + z) + (w^2 - x^2),
/// evaluated at (1, 2, 3, 4) and compared against the same formula computed
/// directly with `f64` arithmetic.
#[test]
fn test_complex_shared_variable_case() {
    /// Builds the local-index -> global-index table for an expression whose
    /// variables are named `local_names`, in local index order.
    ///
    /// Panics if a name is missing from `global`, so a typo in a test fails
    /// loudly instead of leaving an index unmapped.
    fn local_to_global(
        global: &HashMap<&str, usize>,
        local_names: &[&str],
    ) -> HashMap<usize, usize> {
        local_names
            .iter()
            .enumerate()
            .map(|(local_idx, &name)| (local_idx, global[name]))
            .collect()
    }

    // f(x, y) = sin(x) * cos(y)
    let math_f = ExpressionBuilder::new();
    let x_f = math_f.var();
    let y_f = math_f.var();
    let f_expr = x_f.sin() * y_f.cos();

    // g(y, z) = exp(y + z)
    let math_g = ExpressionBuilder::new();
    let y_g = math_g.var();
    let z_g = math_g.var();
    let g_expr = (y_g + z_g).exp();

    // k(w, x) = w^2 - x^2
    let math_k = ExpressionBuilder::new();
    let w_k = math_k.var();
    let x_k = math_k.var();
    let k_expr = &w_k * &w_k - &x_k * &x_k;

    // Single source of truth for where each named variable lives in the
    // evaluation inputs; every per-function table below is derived from it.
    let global_var_map: HashMap<&str, usize> =
        HashMap::from([("w", 0), ("x", 1), ("y", 2), ("z", 3)]);

    let f_remapped = remap_variables(
        f_expr.as_ast(),
        &local_to_global(&global_var_map, &["x", "y"]),
    );
    let g_remapped = remap_variables(
        g_expr.as_ast(),
        &local_to_global(&global_var_map, &["y", "z"]),
    );
    let k_remapped = remap_variables(
        k_expr.as_ast(),
        &local_to_global(&global_var_map, &["w", "x"]),
    );

    // h = (f + g) + k
    let h_ast = dslcompile::ast::ASTRepr::Add(
        Box::new(dslcompile::ast::ASTRepr::Add(
            Box::new(f_remapped),
            Box::new(g_remapped),
        )),
        Box::new(k_remapped),
    );

    let w_val = 1.0_f64;
    let x_val = 2.0_f64;
    let y_val = 3.0_f64;
    let z_val = 4.0_f64;
    let result = DirectEval::eval_with_vars(&h_ast, &[w_val, x_val, y_val, z_val]);

    let expected_f = x_val.sin() * y_val.cos();
    let expected_g = (y_val + z_val).exp();
    let expected_k = w_val * w_val - x_val * x_val;
    let expected = expected_f + expected_g + expected_k;

    // Transcendental functions are involved, so compare with a tolerance
    // rather than for exact equality.
    assert!(
        (result - expected).abs() < 1e-10,
        "expected {expected}, got {result}"
    );
    println!("Complex composition: h(1,2,3,4) = {result}");
}
/// Derives the global variable layout automatically from the variable names
/// attached to each expression, then composes the remapped expressions.
///
/// With names {x, y} and {y, z} and inputs [1, 2, 3], the asserted value 12
/// equals (1^2 + 2^2) + (2*2 + 3) = 5 + 7.
#[test]
fn test_automatic_shared_variable_detection() {
    use dslcompile::ast::ASTRepr;
    use std::collections::BTreeSet;

    /// Computes a global index for every distinct variable name and, for each
    /// function, a table from its local indices to those global indices.
    ///
    /// Names are collected into a `BTreeSet`, so global indices follow the
    /// sorted order of the names. The returned vector has one table per entry
    /// of `functions`, in the same order.
    fn analyze_variable_usage(
        functions: &[(&ASTRepr<f64>, &[&str])],
    ) -> (HashMap<String, usize>, Vec<HashMap<usize, usize>>) {
        // Deduplicate and order the names used by any function.
        let distinct_names: BTreeSet<String> = functions
            .iter()
            .flat_map(|(_, names)| names.iter())
            .map(|&name| name.to_string())
            .collect();

        let global_mapping: HashMap<String, usize> = distinct_names
            .into_iter()
            .enumerate()
            .map(|(global_idx, name)| (name, global_idx))
            .collect();

        // Position within each function's name list is its local index.
        let local_mappings: Vec<HashMap<usize, usize>> = functions
            .iter()
            .map(|(_, names)| {
                names
                    .iter()
                    .enumerate()
                    .filter_map(|(local_idx, &name)| {
                        global_mapping
                            .get(name)
                            .map(|&global_idx| (local_idx, global_idx))
                    })
                    .collect()
            })
            .collect();

        (global_mapping, local_mappings)
    }

    // f = a^2 + b^2, variables named "x" and "y".
    let builder_f = ExpressionBuilder::new();
    let f_first = builder_f.var();
    let f_second = builder_f.var();
    let f = &f_first * &f_first + &f_second * &f_second;
    let f_ast = f.as_ast();

    // g = 2a + b, variables named "y" and "z".
    let builder_g = ExpressionBuilder::new();
    let g_first = builder_g.var();
    let g_second = builder_g.var();
    let g = 2.0 * &g_first + &g_second;
    let g_ast = g.as_ast();

    let f_names: &[&str] = &["x", "y"];
    let g_names: &[&str] = &["y", "z"];
    let functions: [(&ASTRepr<f64>, &[&str]); 2] = [(f_ast, f_names), (g_ast, g_names)];

    let (global_mapping, local_mappings) = analyze_variable_usage(&functions);
    println!("Global variable mapping: {global_mapping:?}");
    println!("Local mappings: {local_mappings:?}");

    let h_ast = ASTRepr::Add(
        Box::new(remap_variables(f_ast, &local_mappings[0])),
        Box::new(remap_variables(g_ast, &local_mappings[1])),
    );

    let result = DirectEval::eval_with_vars(&h_ast, &[1.0, 2.0, 3.0]);
    assert_eq!(result, 12.0);
    println!("Automatic detection: h(1,2,3) = {result}");
}