use crate::config::Config;
use crate::dataset::DataSet;
use crate::error::{PhopError, Result};
use crate::pareto::ParetoFront;
use crate::solution::Solution;
use oxieml::EmlTree;
use scirs2_autograd as ag;
use scirs2_autograd::optimizers::adam::Adam;
use scirs2_autograd::prelude::*;
use scirs2_autograd::tensor_ops as T;
use scirs2_core::ndarray::Array1;
#[cfg(test)]
use scirs2_core::ndarray::Array2;
/// Upper bound on the number of restarts `discover_gumbel` will run; `cfg.population` is
/// clamped into `1..=MAX_RESTARTS` there.
const MAX_RESTARTS: usize = 16;
use crate::rng::SplitMix64;
/// Returns the interned placeholder name `phop_g{l}` as a `&'static str`.
///
/// Names are created on demand, leaked once, and remembered in a process-wide table so
/// that repeated calls with the same index hand back the very same string slice.
///
/// # Panics
///
/// Panics if the mutex guarding the name table has been poisoned.
fn gumbel_name(l: usize) -> &'static str {
    use std::sync::{Mutex, OnceLock};

    static NAMES: OnceLock<Mutex<Vec<&'static str>>> = OnceLock::new();

    let mut names = NAMES
        .get_or_init(|| Mutex::new(Vec::new()))
        .lock()
        .expect("gumbel name cache poisoned");

    // Fill every index that is still missing, up to and including `l`; the range is empty
    // when the table already reaches that far.
    for idx in names.len()..=l {
        let leaked: &'static str = Box::leak(format!("phop_g{idx}").into_boxed_str());
        names.push(leaked);
    }
    names[l]
}
/// Hard decision decoded for a single leaf of the expression tree.
///
/// `run_restart` produces one of these per leaf and `build_tree` turns each into an
/// `EmlTree` leaf node.
enum LeafChoice {
/// The leaf reads the input variable with this column index (becomes `EmlTree::var`).
Var(usize),
/// The leaf is a constant with this value (becomes `EmlTree::const_val`).
Const(f64),
}
/// Recursively materialises the subtree rooted at heap position `node`.
///
/// The tree is laid out in heap order: positions below `internal_count` are internal
/// nodes whose children sit at `2 * node + 1` and `2 * node + 2`; every position at or
/// beyond `internal_count` is a leaf described by `choices[node - internal_count]`.
///
/// # Panics
///
/// Panics if a leaf position has no matching entry in `choices`.
fn build_tree(node: usize, internal_count: usize, choices: &[LeafChoice]) -> EmlTree {
    // `checked_sub` succeeds exactly when `node >= internal_count`, i.e. for leaves.
    if let Some(leaf_idx) = node.checked_sub(internal_count) {
        return match &choices[leaf_idx] {
            LeafChoice::Const(value) => EmlTree::const_val(*value),
            LeafChoice::Var(column) => EmlTree::var(*column),
        };
    }

    let left_pos = 2 * node + 1;
    let left = build_tree(left_pos, internal_count, choices);
    let right = build_tree(left_pos + 1, internal_count, choices);
    EmlTree::eml(&left, &right)
}
/// Runs a single Gumbel-Softmax restart and returns the decoded, polished solution.
///
/// The model is a complete binary tree of depth `depth` stored in heap order. Every leaf
/// owns `k = n_vars + 1` selection logits (one per input column plus one for a learnable
/// constant); internal nodes combine their two children with `crate::forest::eml_guarded`.
/// Each epoch the logits are perturbed with fresh noise from `SplitMix64::gumbel`, turned
/// into soft selection weights, and all variables take one Adam step on the mean-squared
/// error plus a structural penalty. Afterwards each leaf is hardened to its highest-logit
/// source, the tree is assembled with `build_tree`, and its constants are passed through
/// `polish_constants` and `snap_constants`.
///
/// `seed` initialises the noise generator. The visible body contains no early return and
/// no `?`, so it always yields `Ok`.
///
/// # Panics
///
/// Panics through `expect` if a node tensor is missing when its parent (or the root) is
/// read; the fill order below is meant to make that unreachable.
fn run_restart(ds: &DataSet, cfg: &Config, depth: usize, seed: u64) -> Result<Solution> {
let n_vars = ds.n_vars();
// k: number of candidate sources per leaf (every input column + one constant).
// internal_count: 2^depth - 1 internal nodes precede the leaves in heap order.
let k = n_vars + 1; let internal_count = (1usize << depth) - 1;
// 2^depth leaves and 2^(depth+1) - 1 nodes overall.
let n_leaves = 1usize << depth;
let total = (1usize << (depth + 1)) - 1;
let batch = ds.len();
let x = &ds.x;
let y = &ds.y;
let mut env = ag::VariableEnvironment::<f64>::new();
// Register, leaf by leaf, k logits `z{l}_{i}` starting at 0.0 followed by one constant
// `c{l}` starting at 1.0.
for l in 0..n_leaves {
for i in 0..k {
env.name(format!("z{l}_{i}"))
.set(scirs2_core::ndarray::arr0(0.0));
}
env.name(format!("c{l}"))
.set(scirs2_core::ndarray::arr0(1.0));
}
// NOTE(review): the index arithmetic below assumes `current_var_ids()` returns the ids
// in registration order (k logits, then the constant, per leaf) — confirm.
let ids: Vec<_> = env.default_namespace().current_var_ids();
let block = k + 1;
let logit_id = |l: usize, i: usize| ids[l * block + i];
let const_id = |l: usize| ids[l * block + k];
// Adam over every registered variable. The literals are presumably
// (eps, beta1, beta2) following the learning rate — verify against `Adam::new`.
let adam = Adam::new(
cfg.learning_rate,
1e-8,
0.9,
0.999,
ids.clone(),
&mut env,
"phop_g_adam",
);
// Owned copies of each input column and a ones vector of batch length, fed to the
// graph placeholders every epoch.
let col_vals: Vec<Array1<f64>> = (0..n_vars).map(|j| x.column(j).to_owned()).collect();
let ones_val: Array1<f64> = Array1::from_elem(batch, 1.0);
let mut rng = SplitMix64::new(seed);
// Guard held for the whole training and decoding phase; presumably silences stdout
// while alive (see `crate::silence`). It is dropped explicitly before polishing.
let _silencer = crate::silence::SilenceStdout::new();
for epoch in 0..cfg.max_epochs {
// Temperature from the schedule at the fraction of training completed, floored at
// 1e-2 so `inv_tau` never exceeds 100.
let tau = cfg
.temperature(epoch as f64 / cfg.max_epochs.max(1) as f64)
.max(1e-2);
let inv_tau = 1.0 / tau;
// One fresh noise sample per (leaf, source) pair, drawn outside the graph so the
// draw order depends only on `seed` and the epoch count.
let gval_arrays: Vec<_> = (0..n_leaves * k)
.map(|_| scirs2_core::ndarray::arr0(rng.gumbel()))
.collect();
env.run(|g| {
// Placeholders: input columns, the ones vector, the targets, and one scalar noise
// slot per (leaf, source) pair.
let cols: Vec<ag::Tensor<f64>> = (0..n_vars)
.map(|j| g.placeholder(crate::forest::col_placeholder_name(j), &[-1]))
.collect();
let ones = g.placeholder("phop_ones", &[-1]);
let yt = g.placeholder("phop_y", &[-1]);
let gphs: Vec<ag::Tensor<f64>> = (0..n_leaves * k)
.map(|idx| g.placeholder(gumbel_name(idx), &[]))
.collect();
// Per-node output tensors in heap order; leaves live at `internal_count + l`.
let mut vals: Vec<Option<ag::Tensor<f64>>> = vec![None; total];
// Running sum, over all leaves, of the soft weights assigned to variable sources.
let mut struct_pen: Option<ag::Tensor<f64>> = None;
for l in 0..n_leaves {
let cst = g.variable_by_id(const_id(l));
// The scalar constant times the ones vector gives the constant source (presumably
// broadcast to batch length); `source(i)` is column i for i < n_vars, else that constant.
let const_col = T::mul(cst, ones); let source = |i: usize| if i < n_vars { cols[i] } else { const_col };
// Unnormalised weights exp(clip((z + noise) / tau, -30, 30)); the clip bounds the
// exponent in place of the usual max-subtraction.
let a: Vec<ag::Tensor<f64>> = (0..k)
.map(|i| {
let z = g.variable_by_id(logit_id(l, i));
let perturbed = T::add(z, gphs[l * k + i]).scalar_mul(inv_tau);
T::exp(T::clip(perturbed, -30.0, 30.0)) })
.collect();
// Normalise by the sum so the k weights of this leaf add up to one.
let mut denom = a[0];
for ai in a.iter().skip(1) {
denom = T::add(denom, *ai); }
let weights: Vec<ag::Tensor<f64>> = a.iter().map(|&ai| T::div(ai, denom)).collect();
// Leaf output: weighted sum of all candidate sources.
let mut leaf_val: Option<ag::Tensor<f64>> = None;
for (i, &w) in weights.iter().enumerate() {
let term = T::mul(w, source(i)); leaf_val = Some(match leaf_val {
None => term,
Some(acc) => T::add(acc, term),
});
}
// Only the first n_vars weights (variable sources) enter the structural penalty;
// the constant slot is not penalised.
for &w in weights.iter().take(n_vars) {
struct_pen = Some(match struct_pen {
None => w,
Some(acc) => T::add(acc, w),
});
}
vals[internal_count + l] = leaf_val;
}
// Walk internal nodes from the highest index down so both children are already
// filled in when a parent is combined.
for i in (0..internal_count).rev() {
let a = vals[2 * i + 1].expect("child computed");
let b = vals[2 * i + 2].expect("child computed");
vals[i] = Some(crate::forest::eml_guarded(a, b));
}
let pred = vals[0].expect("root computed");
// Mean of the squared residuals over axis 0.
let mut loss = T::reduce_mean(T::square(T::sub(pred, yt)), &[0], false);
// All three regularisation knobs are folded into one coefficient; the parsimony
// term is scaled by the tree depth.
let struct_lambda =
cfg.lambda_complexity + cfg.lambda_sparsity + cfg.lambda_parsimony * depth as f64;
if let Some(pen) = struct_pen {
if struct_lambda != 0.0 {
loss = T::add(loss, pen.scalar_mul(struct_lambda));
}
}
// Gradients of the loss with respect to every registered variable, in `ids` order.
let var_tensors: Vec<ag::Tensor<f64>> =
ids.iter().map(|&vid| g.variable_by_id(vid)).collect();
let grads = T::grad(&[loss], &var_tensors);
// Bind concrete arrays to each placeholder, then apply one Adam update.
let mut feeder = ag::Feeder::new();
for (j, cv) in col_vals.iter().enumerate() {
feeder = feeder.push(cols[j], cv.view().into_dyn());
}
feeder = feeder.push(ones, ones_val.view().into_dyn());
feeder = feeder.push(yt, y.view().into_dyn());
for (idx, gph) in gphs.iter().enumerate() {
feeder = feeder.push(*gph, gval_arrays[idx].view().into_dyn());
}
adam.update(&var_tensors, &grads, g, feeder);
});
}
// Decode: harden every leaf to the source with the largest raw (noise-free) logit.
let choices: Vec<LeafChoice> = env.run(|g| {
(0..n_leaves)
.map(|l| {
// Reads the first element of a variable's value; an evaluation error or an empty
// array silently becomes 0.0.
let read = |vid| {
g.variable_by_id(vid)
.eval(g)
.ok()
.and_then(|a| a.iter().copied().next())
.unwrap_or(0.0)
};
let logits: Vec<f64> = (0..k).map(|i| read(logit_id(l, i))).collect();
let cst = read(const_id(l));
// Incomparable pairs (NaN) count as equal. `max_by` keeps the last of several equal
// maxima, so a full tie resolves to index k - 1, the constant slot.
let best = (0..k)
.max_by(|&i, &j| {
logits[i]
.partial_cmp(&logits[j])
.unwrap_or(std::cmp::Ordering::Equal)
})
.unwrap_or(0);
if best < n_vars {
LeafChoice::Var(best)
} else {
LeafChoice::Const(cst)
}
})
.collect()
});
// Release the guard before the polishing stage runs.
drop(_silencer);
let tree = build_tree(0, internal_count, &choices);
// NOTE(review): 40 and 0.02 are presumably an iteration budget and a snapping
// tolerance — confirm against `crate::polish`.
let (polished, _) = crate::polish::polish_constants(&tree, ds, 40);
let (snapped, m) = crate::polish::snap_constants(&polished, ds, 0.02);
Ok(Solution::new(snapped, m))
}
pub fn discover_gumbel(ds: &DataSet, cfg: &Config) -> Result<ParetoFront> {
if ds.is_empty() {
return Err(PhopError::ShapeMismatch("empty dataset".to_string()));
}
let depth = cfg.max_depth.clamp(1, 4);
let restarts = cfg.population.clamp(1, MAX_RESTARTS);
let mut sols: Vec<Solution> = Vec::new();
for r in 0..restarts {
if let Ok(sol) = run_restart(ds, cfg, depth, cfg.seed.wrapping_add(r as u64 + 1)) {
if sol.mse.is_finite() {
sols.push(sol);
}
}
}
if sols.is_empty() {
return Err(PhopError::NotConverged(
"no Gumbel-Softmax restart converged to a finite solution".to_string(),
));
}
Ok(ParetoFront::from_candidates(sols))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the fixture shared by both tests: 40 samples of `y = exp(x)` at
    /// `x = 0.08 * i`, returned together with the raw target values.
    fn exp_fixture() -> (DataSet, Vec<f64>) {
        let xs: Vec<f64> = (0..40).map(|i| f64::from(i) * 0.08).collect();
        let ys: Vec<f64> = xs.iter().map(|&x| x.exp()).collect();
        let rows = xs.len();
        let features = Array2::from_shape_vec((rows, 1), xs).unwrap();
        let targets = Array1::from(ys.clone());
        let ds = DataSet::from_arrays(features, targets).unwrap();
        (ds, ys)
    }

    /// Population variance of `ys` (sum of squared deviations divided by the sample count).
    fn population_variance(ys: &[f64]) -> f64 {
        let mean = ys.iter().sum::<f64>() / ys.len() as f64;
        ys.iter().map(|v| (v - mean) * (v - mean)).sum::<f64>() / ys.len() as f64
    }

    /// A depth-1 search must explain more than half of the variance of `exp(x)`.
    #[test]
    fn gumbel_recovers_exp_structure() {
        let (ds, ys) = exp_fixture();
        let cfg = Config::default()
            .max_depth(1)
            .population(6)
            .max_epochs(1200)
            .learning_rate(0.1);

        let front = discover_gumbel(&ds, &cfg).unwrap();
        assert!(!front.is_empty());

        let half_var = population_variance(&ys) * 0.5;
        let best = front.best().unwrap();
        assert!(
            best.mse < half_var,
            "gumbel best mse {} not below half-variance {} ({})",
            best.mse,
            half_var,
            best.pretty()
        );
    }

    /// With all structural penalties enabled the fit must still be good, and two runs
    /// with the same seed must agree on the best error.
    #[test]
    fn structural_penalty_preserves_fit_and_is_deterministic() {
        let (ds, ys) = exp_fixture();
        let mut cfg = Config::default()
            .max_depth(1)
            .population(4)
            .max_epochs(1000)
            .learning_rate(0.1)
            .seed(11);
        cfg.lambda_complexity = 1e-2;
        cfg.lambda_sparsity = 1e-2;
        cfg.lambda_parsimony = 1e-2;

        let half_var = population_variance(&ys) * 0.5;

        let first = discover_gumbel(&ds, &cfg).unwrap();
        let second = discover_gumbel(&ds, &cfg).unwrap();

        let best_first = first.best().unwrap();
        assert!(
            best_first.mse < half_var,
            "penalized gumbel did not fit: mse={} ({})",
            best_first.mse,
            best_first.pretty()
        );

        let best_second = second.best().unwrap();
        assert!(
            (best_first.mse - best_second.mse).abs() < 1e-9,
            "gumbel not deterministic for a fixed seed"
        );
    }
}