use std::sync::Arc;
use crate::error::EmlError;
use crate::symreg::topology::topology_interval_feasible;
use crate::symreg::{DiscoveredFormula, SymRegConfig, SymRegEngine};
use crate::tree::{EmlNode, EmlTree};
/// Random number generator type used throughout this file (`rand`'s `StdRng`).
type Rng = rand::rngs::StdRng;
/// An expression tree that may still contain unfilled positions.
///
/// A tree without any `Hole` can be converted into an `EmlNode` tree via
/// `to_eml_node`.
#[derive(Clone, Debug)]
enum PartialNode {
/// Unfilled position; replaced by `expand_leftmost` or `complete_random`.
Hole,
/// Leaf that converts to `EmlNode::One`.
One,
/// Leaf that converts to `EmlNode::Var` with the same index.
Var(usize),
/// Binary node with a left and a right sub-tree (converts to `EmlNode::Eml`).
Eml(Box<PartialNode>, Box<PartialNode>),
}
impl PartialNode {
fn hole_count(&self) -> usize {
match self {
PartialNode::Hole => 1,
PartialNode::One | PartialNode::Var(_) => 0,
PartialNode::Eml(l, r) => l.hole_count() + r.hole_count(),
}
}
fn expand_leftmost(&mut self, action: &MctsAction) -> bool {
match self {
PartialNode::Hole => {
*self = match action {
MctsAction::One => PartialNode::One,
MctsAction::Var(i) => PartialNode::Var(*i),
MctsAction::Expand => {
PartialNode::Eml(Box::new(PartialNode::Hole), Box::new(PartialNode::Hole))
}
};
true
}
PartialNode::One | PartialNode::Var(_) => false,
PartialNode::Eml(l, r) => {
if l.expand_leftmost(action) {
true
} else {
r.expand_leftmost(action)
}
}
}
}
fn complete_random(&mut self, num_vars: usize, rng: &mut Rng) {
use rand::RngExt;
match self {
PartialNode::Hole => {
let choices = 1 + num_vars; let idx = rng.random_range(0..choices);
*self = if idx == 0 {
PartialNode::One
} else {
PartialNode::Var(idx - 1)
};
}
PartialNode::One | PartialNode::Var(_) => {}
PartialNode::Eml(l, r) => {
l.complete_random(num_vars, rng);
r.complete_random(num_vars, rng);
}
}
}
fn to_eml_node(&self) -> Arc<EmlNode> {
match self {
PartialNode::Hole => {
debug_assert!(false, "to_eml_node called on a Hole — invariant violated");
Arc::new(EmlNode::One)
}
PartialNode::One => Arc::new(EmlNode::One),
PartialNode::Var(i) => Arc::new(EmlNode::Var(*i)),
PartialNode::Eml(l, r) => Arc::new(EmlNode::Eml {
left: l.to_eml_node(),
right: r.to_eml_node(),
}),
}
}
}
/// A single search move: what to put into the leftmost hole of a partial tree.
#[derive(Clone, Debug)]
enum MctsAction {
/// Fill the hole with the `One` leaf.
One,
/// Fill the hole with the variable leaf of the given index.
Var(usize),
/// Fill the hole with a binary node whose two children are fresh holes.
Expand,
}
/// Lists the actions available for a hole located at depth `hole_depth`.
///
/// The result is always ordered as: `One`, then `Var(0)` .. `Var(num_vars - 1)`,
/// and finally `Expand` — the latter only while `hole_depth < max_depth`.
fn legal_actions(hole_depth: usize, max_depth: usize, num_vars: usize) -> Vec<MctsAction> {
    let may_grow = hole_depth < max_depth;
    std::iter::once(MctsAction::One)
        .chain((0..num_vars).map(MctsAction::Var))
        .chain(may_grow.then_some(MctsAction::Expand))
        .collect()
}
/// Returns the depth of the first `Hole` met in a left-to-right, depth-first
/// walk of `node`, or `None` when the tree has no hole.
///
/// `current` is the depth assigned to `node` itself; each `Eml` level adds one.
fn leftmost_hole_depth(node: &PartialNode, current: usize) -> Option<usize> {
    if let PartialNode::Eml(left, right) = node {
        let below = current + 1;
        // The right sub-tree is only searched when the left one is hole-free.
        return match leftmost_hole_depth(left, below) {
            found @ Some(_) => found,
            None => leftmost_hole_depth(right, below),
        };
    }
    matches!(node, PartialNode::Hole).then_some(current)
}
/// One node of the search tree, stored in a flat arena and linked by index.
struct MctsNode {
/// The (possibly incomplete) expression tree this search node represents.
partial: PartialNode,
/// Number of backpropagation passes that went through this node.
visits: u64,
/// Sum of all rewards backpropagated through this node.
total_value: f64,
/// Arena indices of the children created so far.
children: Vec<usize>,
/// Arena index of the parent; the root is created with `usize::MAX`.
parent: usize,
/// Set once every legal action of this node has produced a child.
fully_expanded: bool,
/// Index into the node's legal-action list of the next action to try.
next_action_idx: usize,
/// Depth of the leftmost hole of `partial`, or `None` when there is none.
leftmost_hole_depth: Option<usize>,
}
impl MctsNode {
    /// Creates an unvisited, childless node for `partial` whose parent lives at
    /// arena index `parent`. The leftmost-hole depth is computed up front.
    fn new(partial: PartialNode, parent: usize) -> Self {
        Self {
            // Evaluated before `partial` is moved into the struct below.
            leftmost_hole_depth: leftmost_hole_depth(&partial, 0),
            partial,
            parent,
            children: Vec::new(),
            visits: 0,
            total_value: 0.0,
            next_action_idx: 0,
            fully_expanded: false,
        }
    }

    /// Returns `true` when the partial tree has no holes left.
    fn is_complete(&self) -> bool {
        matches!(self.partial.hole_count(), 0)
    }

    /// UCB1 score: mean reward plus `exploration * sqrt(ln(parent_visits) / visits)`.
    ///
    /// An unvisited node scores `+inf`.
    fn ucb1(&self, parent_visits: u64, exploration: f64) -> f64 {
        match self.visits {
            0 => f64::INFINITY,
            n => {
                let n = n as f64;
                let mean_reward = self.total_value / n;
                let bonus = exploration * ((parent_visits as f64).ln() / n).sqrt();
                mean_reward + bonus
            }
        }
    }
}
/// Wraps the `EmlNode` conversion of `node` into an `EmlTree`.
fn partial_to_tree(node: &PartialNode) -> EmlTree {
    EmlTree::from_node(node.to_eml_node())
}
pub(crate) fn run_mcts(
engine: &SymRegEngine,
inputs: &[Vec<f64>],
targets: &[f64],
num_vars: usize,
iterations: usize,
exploration: f64,
) -> Result<Vec<DiscoveredFormula>, EmlError> {
if inputs.is_empty() || targets.is_empty() {
return Err(EmlError::EmptyData);
}
if inputs.len() != targets.len() {
return Err(EmlError::DimensionMismatch(inputs.len(), targets.len()));
}
if iterations == 0 {
return Ok(vec![]);
}
let config = engine.config();
let max_depth = config.max_depth;
let surrogate_iters = config.max_iter.clamp(10, 50);
let surrogate_config = SymRegConfig {
max_iter: surrogate_iters,
num_restarts: 1,
cv_folds: None,
..config.clone()
};
let surrogate_engine = SymRegEngine::new(surrogate_config);
let interval_data = if config.interval_pruning {
use crate::lower_interval::IntervalLO;
let input_intervals: Vec<IntervalLO> = (0..num_vars)
.map(|j| {
let mut lo = f64::INFINITY;
let mut hi = f64::NEG_INFINITY;
for row in inputs.iter() {
if let Some(&v) = row.get(j) {
if v < lo {
lo = v;
}
if v > hi {
hi = v;
}
}
}
if lo.is_finite() && hi.is_finite() {
IntervalLO::new(lo, hi)
} else {
IntervalLO::full()
}
})
.collect();
let target_lo = targets.iter().copied().fold(f64::INFINITY, f64::min);
let target_hi = targets.iter().copied().fold(f64::NEG_INFINITY, f64::max);
Some((input_intervals, target_lo, target_hi))
} else {
None
};
let mut rng = make_mcts_rng(config.seed);
let root_partial = PartialNode::Hole;
let mut arena: Vec<MctsNode> = vec![MctsNode::new(root_partial, usize::MAX)];
let mut complete_nodes: Vec<(usize, f64)> = Vec::new();
let mut rollout_candidates: Vec<(EmlTree, f64)> = Vec::new();
for _iter in 0..iterations {
let mut node_idx = 0usize;
loop {
let (is_complete, fully_expanded) = {
let node = &arena[node_idx];
(node.is_complete(), node.fully_expanded)
};
if is_complete || !fully_expanded {
break;
}
let parent_visits = arena[node_idx].visits;
let children: Vec<usize> = arena[node_idx].children.clone();
let best_child = children
.iter()
.copied()
.max_by(|&a, &b| {
arena[a]
.ucb1(parent_visits, exploration)
.partial_cmp(&arena[b].ucb1(parent_visits, exploration))
.unwrap_or(std::cmp::Ordering::Equal)
})
.unwrap_or(node_idx);
node_idx = best_child;
}
let expanded_idx = {
let (is_complete, fully_expanded) = {
let node = &arena[node_idx];
(node.is_complete(), node.fully_expanded)
};
if is_complete || fully_expanded {
node_idx
} else {
let (hole_depth, actions) = {
let node = &arena[node_idx];
let hd = node.leftmost_hole_depth.unwrap_or(0);
let acts = legal_actions(hd, max_depth, num_vars);
(hd, acts)
};
let _ = hole_depth;
let action_idx = arena[node_idx].next_action_idx;
if action_idx >= actions.len() {
arena[node_idx].fully_expanded = true;
node_idx
} else {
let action = &actions[action_idx];
let mut new_partial = arena[node_idx].partial.clone();
new_partial.expand_leftmost(action);
arena[node_idx].next_action_idx += 1;
if arena[node_idx].next_action_idx >= actions.len() {
arena[node_idx].fully_expanded = true;
}
let child_idx = arena.len();
let mut child_node = MctsNode::new(new_partial, node_idx);
child_node.leftmost_hole_depth = leftmost_hole_depth(&child_node.partial, 0);
arena.push(child_node);
arena[node_idx].children.push(child_idx);
child_idx
}
}
};
let (reward, rollout_tree_opt) = {
let mut rollout_partial = arena[expanded_idx].partial.clone();
rollout_partial.complete_random(num_vars, &mut rng);
let tree = partial_to_tree(&rollout_partial);
let feasible = if let Some((ref ivs, tlo, thi)) = interval_data {
let threshold = config.interval_pruning_depth_threshold;
if tree.depth() < threshold {
true
} else {
topology_interval_feasible(&tree, ivs, tlo, thi)
}
} else {
true
};
if !feasible {
(0.0, None)
} else {
let formula =
surrogate_engine.optimize_topology_pub(&tree, inputs, targets, expanded_idx);
match formula {
Some(f) => {
let r = 1.0 / (1.0 + f.mse);
(r, Some((tree, f.mse)))
}
None => (0.0, None),
}
}
};
if arena[expanded_idx].is_complete() {
complete_nodes.push((expanded_idx, reward));
}
if let Some(rt) = rollout_tree_opt {
rollout_candidates.push(rt);
}
let mut idx = expanded_idx;
loop {
arena[idx].visits += 1;
arena[idx].total_value += reward;
let parent = arena[idx].parent;
if parent == usize::MAX {
break;
}
idx = parent;
}
}
let mut candidate_trees: Vec<(EmlTree, f64)> = rollout_candidates;
for (node_idx, _) in &complete_nodes {
let node = &arena[*node_idx];
if node.is_complete() && node.visits > 0 {
let avg_reward = node.total_value / node.visits as f64;
let pseudo_mse = if avg_reward > 0.0 {
1.0 / avg_reward - 1.0
} else {
f64::INFINITY
};
let tree = partial_to_tree(&node.partial);
candidate_trees.push((tree, pseudo_mse));
}
}
candidate_trees.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
let top_k = 20_usize;
let mut seen_hashes = std::collections::HashSet::new();
let unique_candidates: Vec<EmlTree> = candidate_trees
.into_iter()
.filter_map(|(tree, _)| {
use std::collections::hash_map::DefaultHasher;
use std::hash::Hasher;
let simplified = tree.lower().simplify();
let mut h = DefaultHasher::new();
simplified.structural_hash(&mut h);
let hash = h.finish();
if seen_hashes.insert(hash) {
Some(tree)
} else {
None
}
})
.take(top_k)
.collect();
if unique_candidates.is_empty() {
return Ok(vec![]);
}
engine.optimize_and_finalize_pub(unique_candidates, inputs, targets)
}
/// Builds the random number generator used by the search.
///
/// With `Some(seed)` the generator is deterministic: the seed is offset by a
/// fixed salt and scrambled before seeding. With `None` a fresh generator is
/// obtained from `rand::make_rng`.
fn make_mcts_rng(seed: Option<u64>) -> Rng {
    use rand::SeedableRng;
    const MCTS_SALT: u64 = 0xDEAD_BEEF_CAFE_1234;

    /// Bit scrambler using the shifts and multipliers of the SplitMix64 output mixer.
    fn scramble(mut z: u64) -> u64 {
        z ^= z >> 30;
        z = z.wrapping_mul(0xbf58_476d_1ce4_e5b9);
        z ^= z >> 27;
        z = z.wrapping_mul(0x94d0_49bb_1331_11eb);
        z ^ (z >> 31)
    }

    if let Some(user_seed) = seed {
        Rng::seed_from_u64(scramble(user_seed.wrapping_add(MCTS_SALT)))
    } else {
        rand::make_rng::<Rng>()
    }
}