use std::collections::HashSet;
use tensorlogic_ir::TLExpr;
use super::helpers::collect_free_pred_vars;
use super::pe_core::pe_rec;
use super::types::{PEConfig, PEEnv, PEResult, PEStats};
/// Partially evaluates `expr` under the bindings held in `env`.
///
/// The expression is cloned and handed to the recursive worker `pe_rec`. It starts
/// with an initial level argument of `0` and a fresh, zeroed `PEStats` accumulator.
/// The rewritten expression is then scanned with `collect_free_pred_vars`, starting
/// from an empty bound-variable set. The names it reports (presumably the predicate
/// variables left unbound after evaluation — confirm against the helper) become
/// `residual_vars`.
///
/// # Returns
/// A `PEResult` holding the rewritten expression, the statistics gathered during
/// evaluation, and the residual variable names in ascending order.
pub fn partially_evaluate(expr: &TLExpr, env: &PEEnv, config: &PEConfig) -> PEResult {
    let mut stats = PEStats::default();
    // `pe_rec` consumes its input, so the caller's expression is cloned here.
    let expr = pe_rec(expr.clone(), env, config, 0, &mut stats);

    // Scan the result with no variables considered bound up front.
    let no_bound_vars = HashSet::new();
    let mut free_vars = HashSet::new();
    collect_free_pred_vars(&expr, &no_bound_vars, &mut free_vars);

    // HashSet iteration order is unspecified; sorting makes the output stable.
    // The names are already distinct, so an unstable sort yields the same order.
    let mut residual_vars: Vec<String> = free_vars.into_iter().collect();
    residual_vars.sort_unstable();

    PEResult {
        expr,
        stats,
        residual_vars,
    }
}
/// Specializes `expr` by fixing the given variables to numeric values.
///
/// Starts from an empty `PEEnv`. Each `(name, value)` pair is registered through
/// `bind_f64`, in slice order, and the result is delegated to `partially_evaluate`.
///
/// # Parameters
/// - `expr`: the expression to specialize; it is not modified.
/// - `bindings`: variable-name / value pairs to bind before evaluating.
/// - `config`: forwarded unchanged to `partially_evaluate`.
pub fn specialize(expr: &TLExpr, bindings: &[(String, f64)], config: &PEConfig) -> PEResult {
    let env = bindings
        .iter()
        .fold(PEEnv::new(), |mut acc, (name, value)| {
            acc.bind_f64(name.clone(), *value);
            acc
        });
    partially_evaluate(expr, &env, config)
}
/// Runs `specialize` once per entry of `binding_sets`.
///
/// Every run uses the same `expr` and `config`, and builds its own environment
/// from its binding set, so the runs do not share bindings.
///
/// # Returns
/// One `PEResult` per binding set, in the same order as `binding_sets`.
/// An empty `binding_sets` yields an empty vector.
pub fn specialize_batch(
    expr: &TLExpr,
    binding_sets: &[Vec<(String, f64)>],
    config: &PEConfig,
) -> Vec<PEResult> {
    let mut results = Vec::with_capacity(binding_sets.len());
    for bindings in binding_sets {
        results.push(specialize(expr, bindings, config));
    }
    results
}