oxieml 0.1.1

EML operator: all elementary functions from exp(x) - ln(y)
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
//! Monte-Carlo Tree Search (MCTS) over partial EML tree topologies.
//!
//! Implements UCB1-guided exploration of the EML grammar space:
//!
//! ```text
//! S → One | Var(i) | Eml(S, S)
//! ```
//!
//! Each MCTS node represents a *partial* EML tree (a tree with some leaves
//! still unexpanded, called HOLEs). The algorithm selects the leftmost HOLE
//! for expansion at each step, guaranteeing that every complete tree is
//! reachable by exactly one action sequence (no double-counting).
//!
//! **UCB1 score** (for child `c` with parent `p`):
//!
//! ```text
//! score(c) = c.total_value / c.visits
//!            + exploration * sqrt(ln(p.visits) / c.visits)
//! ```
//!
//! **Reward**: `1.0 / (1.0 + mse)` for a successfully fitted rollout, and `0.0` for a
//! pruned or failed one — bounded in `[0, 1]`, suitable for UCB1.

use std::sync::Arc;

use crate::error::EmlError;
use crate::symreg::topology::topology_interval_feasible;
use crate::symreg::{DiscoveredFormula, SymRegConfig, SymRegEngine};
use crate::tree::{EmlNode, EmlTree};

type Rng = rand::rngs::StdRng;

/// A partial EML tree: a recursive enum that mirrors `EmlNode` but adds a `Hole`
/// variant for leaves that have not been expanded yet.
///
/// A tree is *complete* once it contains no `Hole`; only complete trees are meant
/// to be converted to `Arc<EmlNode>`.
///
/// We avoid the flat `Vec<Option<EmlNode>>` representation because `EmlNode` is
/// `Arc`-recursive; conversion would require double marshalling. A recursive
/// enum converts to `Arc<EmlNode>` in O(n) with a simple `match`.
#[derive(Clone, Debug)]
enum PartialNode {
    /// Unexpanded leaf — replaced by `One`, `Var(i)`, or `Eml(Hole, Hole)` during
    /// expansion, or by `One` / `Var(i)` during a random rollout completion.
    Hole,
    /// The constant `1` (corresponds to `EmlNode::One`).
    One,
    /// Input variable `x_i` (corresponds to `EmlNode::Var(i)`).
    Var(usize),
    /// `eml(left, right) = exp(left) − ln(right)`; the first box is the left
    /// operand, the second the right operand.
    Eml(Box<PartialNode>, Box<PartialNode>),
}

impl PartialNode {
    /// Number of `Hole` leaves remaining anywhere in this subtree.
    fn hole_count(&self) -> usize {
        match self {
            PartialNode::Eml(left, right) => left.hole_count() + right.hole_count(),
            PartialNode::Hole => 1,
            PartialNode::One | PartialNode::Var(_) => 0,
        }
    }

    /// Replace the leftmost `Hole` (pre-order, left subtree first) according to
    /// `action`.
    ///
    /// Returns `true` when a `Hole` was found and replaced, `false` when this
    /// subtree contains no `Hole` (in which case nothing is modified).
    fn expand_leftmost(&mut self, action: &MctsAction) -> bool {
        match self {
            // Short-circuit: the right subtree is only searched when the left
            // subtree had no HOLE to fill.
            PartialNode::Eml(left, right) => {
                left.expand_leftmost(action) || right.expand_leftmost(action)
            }
            PartialNode::One | PartialNode::Var(_) => false,
            PartialNode::Hole => {
                let replacement = match action {
                    MctsAction::Expand => {
                        PartialNode::Eml(Box::new(PartialNode::Hole), Box::new(PartialNode::Hole))
                    }
                    MctsAction::Var(i) => PartialNode::Var(*i),
                    MctsAction::One => PartialNode::One,
                };
                *self = replacement;
                true
            }
        }
    }

    /// Fill every remaining `Hole` with a leaf drawn uniformly from
    /// `{One, Var(0), ..., Var(num_vars - 1)}`.
    ///
    /// `Expand` is never chosen here, so the resulting tree is finite and
    /// complete. Holes are visited left-to-right and exactly one random draw is
    /// made per hole.
    fn complete_random(&mut self, num_vars: usize, rng: &mut Rng) {
        use rand::RngExt;
        match self {
            PartialNode::Eml(left, right) => {
                left.complete_random(num_vars, rng);
                right.complete_random(num_vars, rng);
            }
            PartialNode::One | PartialNode::Var(_) => {}
            PartialNode::Hole => {
                // Draw 0 selects the constant; draw k >= 1 selects Var(k - 1).
                *self = match rng.random_range(0..num_vars + 1) {
                    0 => PartialNode::One,
                    pick => PartialNode::Var(pick - 1),
                };
            }
        }
    }

    /// Build the `Arc<EmlNode>` equivalent of a complete (Hole-free) tree.
    ///
    /// A remaining `Hole` is an invariant violation: debug builds panic via
    /// `debug_assert!`, release builds substitute `EmlNode::One` as a sentinel.
    fn to_eml_node(&self) -> Arc<EmlNode> {
        let converted = match self {
            PartialNode::Eml(l, r) => EmlNode::Eml {
                left: l.to_eml_node(),
                right: r.to_eml_node(),
            },
            PartialNode::Var(i) => EmlNode::Var(*i),
            PartialNode::One => EmlNode::One,
            PartialNode::Hole => {
                // Callers must only convert complete trees.
                debug_assert!(false, "to_eml_node called on a Hole — invariant violated");
                EmlNode::One
            }
        };
        Arc::new(converted)
    }
}

/// The action taken to expand the leftmost HOLE of a partial tree.
///
/// Applied by `PartialNode::expand_leftmost`; enumerated by `legal_actions`.
#[derive(Clone, Debug)]
enum MctsAction {
    /// Replace the HOLE with the constant `1` (terminal: adds no new HOLE).
    One,
    /// Replace the HOLE with input variable `x_i` (terminal: adds no new HOLE).
    Var(usize),
    /// Replace the HOLE with `eml(HOLE, HOLE)` — removes one HOLE and adds two.
    Expand,
}

/// Enumerate the actions that may fill the leftmost HOLE sitting at depth
/// `hole_depth`.
///
/// The result is always ordered `One`, `Var(0)`, …, `Var(num_vars - 1)`, followed
/// by `Expand` only when `hole_depth < max_depth`. At `hole_depth >= max_depth`
/// an `Eml` node is not offered because its children would sit at depth
/// `hole_depth + 1 > max_depth`.
fn legal_actions(hole_depth: usize, max_depth: usize, num_vars: usize) -> Vec<MctsAction> {
    let may_expand = hole_depth < max_depth;

    // Room for One, every variable, and the optional Expand.
    let mut actions = Vec::with_capacity(num_vars + 2);
    actions.push(MctsAction::One);
    actions.extend((0..num_vars).map(MctsAction::Var));
    if may_expand {
        actions.push(MctsAction::Expand);
    }
    actions
}

/// Depth of the leftmost HOLE in `node`, where `node` itself sits at depth
/// `current`.
///
/// The left subtree is searched before the right one. Returns `None` when the
/// subtree contains no HOLE.
fn leftmost_hole_depth(node: &PartialNode, current: usize) -> Option<usize> {
    match node {
        PartialNode::Eml(left, right) => {
            let child_depth = current + 1;
            match leftmost_hole_depth(left, child_depth) {
                found @ Some(_) => found,
                None => leftmost_hole_depth(right, child_depth),
            }
        }
        PartialNode::Hole => Some(current),
        PartialNode::One | PartialNode::Var(_) => None,
    }
}

/// A node in the MCTS search tree.
///
/// Nodes live in a flat `Vec<MctsNode>` arena and refer to each other by index
/// (parent/child links) to avoid `Rc<RefCell<...>>` lifetime complexity.
struct MctsNode {
    /// The partial tree stored at this MCTS node.
    partial: PartialNode,
    /// Number of times this node has been visited (incremented once per
    /// backpropagation pass that goes through it).
    visits: u64,
    /// Cumulative reward. Each rollout contributes `1/(1+mse)` for a successful
    /// fit or `0.0` for a pruned/failed one, so every increment lies in `[0, 1]`.
    total_value: f64,
    /// Indices of child nodes in the flat arena, in the order they were created.
    children: Vec<usize>,
    /// Index of parent node (`usize::MAX` for the root).
    parent: usize,
    /// Whether all legal actions from this node have been tried.
    fully_expanded: bool,
    /// Number of children already expanded; doubles as the index of the next
    /// untried action in the `legal_actions` ordering.
    next_action_idx: usize,
    /// Depth of the leftmost HOLE at this node (cached for action generation);
    /// `None` when the partial tree has no HOLE left.
    leftmost_hole_depth: Option<usize>,
}

impl MctsNode {
    /// Create an unvisited node holding `partial`, linked to the arena index
    /// `parent` (`usize::MAX` for the root).
    ///
    /// The depth of the leftmost HOLE is computed once here and cached.
    fn new(partial: PartialNode, parent: usize) -> Self {
        let cached_depth = leftmost_hole_depth(&partial, 0);
        Self {
            leftmost_hole_depth: cached_depth,
            next_action_idx: 0,
            fully_expanded: false,
            parent,
            children: Vec::new(),
            total_value: 0.0,
            visits: 0,
            partial,
        }
    }

    /// `true` when the stored partial tree contains no HOLE.
    fn is_complete(&self) -> bool {
        self.partial.hole_count() == 0
    }

    /// UCB1 score of this node relative to a parent visited `parent_visits` times.
    ///
    /// An unvisited node scores `+∞` so that it is always tried first; otherwise
    /// the score is the mean reward plus
    /// `exploration * sqrt(ln(parent_visits) / visits)`.
    fn ucb1(&self, parent_visits: u64, exploration: f64) -> f64 {
        match self.visits {
            0 => f64::INFINITY,
            n => {
                let mean_reward = self.total_value / n as f64;
                let log_parent = (parent_visits as f64).ln();
                let bonus = exploration * (log_parent / n as f64).sqrt();
                mean_reward + bonus
            }
        }
    }
}

/// Wrap a complete `PartialNode` into an `EmlTree`.
///
/// The node is first converted with `PartialNode::to_eml_node` and then handed
/// to `EmlTree::from_node`. No variable count is passed here; per the original
/// author's note, `from_node` derives it itself (via `count_vars`) — confirm
/// against `EmlTree` if that matters to a caller.
fn partial_to_tree(node: &PartialNode) -> EmlTree {
    EmlTree::from_node(node.to_eml_node())
}

/// Run the MCTS algorithm over EML topology space.
///
/// This is the main entry point called from `SymRegEngine::discover_mcts`.
///
/// Each iteration performs the four classic MCTS phases:
///
/// 1. **Selection** — descend from the root by UCB1 until reaching a node that
///    is complete or still has an untried action.
/// 2. **Expansion** — apply the next untried action to the node's leftmost HOLE,
///    creating one child.
/// 3. **Simulation** — randomly complete the remaining HOLEs, optionally prune
///    by interval feasibility, and score the tree with a cheap surrogate fit.
/// 4. **Backpropagation** — add the reward to every node on the path to the root.
///
/// After the loop, the best (lowest-MSE) structurally unique candidates are
/// passed to `engine.optimize_and_finalize_pub` for the full optimization.
///
/// # Parameters
///
/// * `engine` — supplies the configuration and the final optimization.
/// * `inputs` / `targets` — training rows and their target values; must be
///   non-empty and of equal length.
/// * `num_vars` — number of input variables offered as `Var(i)` leaves.
/// * `iterations` — number of MCTS iterations (one rollout each).
/// * `exploration` — UCB1 exploration constant.
///
/// # Errors
///
/// * `EmlError::EmptyData` if `inputs` or `targets` is empty.
/// * `EmlError::DimensionMismatch` if their lengths differ.
/// * Whatever `engine.optimize_and_finalize_pub` returns.
///
/// Returns `Ok(vec![])` when `iterations == 0` or when no rollout produced a
/// usable candidate.
pub(crate) fn run_mcts(
    engine: &SymRegEngine,
    inputs: &[Vec<f64>],
    targets: &[f64],
    num_vars: usize,
    iterations: usize,
    exploration: f64,
) -> Result<Vec<DiscoveredFormula>, EmlError> {
    /// Maximum number of structurally unique candidates sent to full optimization.
    const TOP_K: usize = 20;

    if inputs.is_empty() || targets.is_empty() {
        return Err(EmlError::EmptyData);
    }
    if inputs.len() != targets.len() {
        return Err(EmlError::DimensionMismatch(inputs.len(), targets.len()));
    }
    if iterations == 0 {
        return Ok(Vec::new());
    }

    let config = engine.config();
    let max_depth = config.max_depth;

    // Surrogate engine: cheap Adam for rollout simulation.
    let surrogate_iters = config.max_iter.clamp(10, 50);
    let surrogate_config = SymRegConfig {
        max_iter: surrogate_iters,
        num_restarts: 1,
        cv_folds: None,
        ..config.clone()
    };
    let surrogate_engine = SymRegEngine::new(surrogate_config);

    // Interval pruning setup (if enabled): per-variable observed [min, max] plus
    // the observed target range.
    let interval_data = if config.interval_pruning {
        use crate::lower_interval::IntervalLO;
        let input_intervals: Vec<IntervalLO> = (0..num_vars)
            .map(|j| {
                let mut lo = f64::INFINITY;
                let mut hi = f64::NEG_INFINITY;
                // Rows shorter than `j + 1` simply do not contribute.
                for row in inputs.iter() {
                    if let Some(&v) = row.get(j) {
                        if v < lo {
                            lo = v;
                        }
                        if v > hi {
                            hi = v;
                        }
                    }
                }
                if lo.is_finite() && hi.is_finite() {
                    IntervalLO::new(lo, hi)
                } else {
                    // No finite observation for this variable: no constraint.
                    IntervalLO::full()
                }
            })
            .collect();
        let target_lo = targets.iter().copied().fold(f64::INFINITY, f64::min);
        let target_hi = targets.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        Some((input_intervals, target_lo, target_hi))
    } else {
        None
    };

    // RNG: use a different salt from the topology-level seeds.
    let mut rng = make_mcts_rng(config.seed);

    // Flat arena of MCTS nodes (avoids Rc cycles). Index 0 is the root: a single HOLE.
    let mut arena: Vec<MctsNode> = vec![MctsNode::new(PartialNode::Hole, usize::MAX)];

    // Arena indices of nodes whose partial tree is already complete. Each node is
    // recorded exactly once, at its first simulation.
    let mut complete_nodes: Vec<usize> = Vec::new();

    // Rollout-completed trees and their surrogate MSE.
    // These are the primary source of interesting candidates when max_depth > 1.
    let mut rollout_candidates: Vec<(EmlTree, f64)> = Vec::new();

    for _ in 0..iterations {
        // === SELECTION ===
        // Walk from root using UCB1 until we reach a node that:
        //   (a) is a complete tree (terminal), or
        //   (b) has at least one unexpanded action (will be expanded next).
        let mut node_idx = 0usize;
        loop {
            let node = &arena[node_idx];
            if node.is_complete() || !node.fully_expanded {
                break;
            }

            // All children visited: select the best UCB1 child.
            let parent_visits = node.visits;
            let best_child = node.children.iter().copied().max_by(|&a, &b| {
                arena[a]
                    .ucb1(parent_visits, exploration)
                    .partial_cmp(&arena[b].ucb1(parent_visits, exploration))
                    .unwrap_or(std::cmp::Ordering::Equal)
            });
            match best_child {
                Some(child) => node_idx = child,
                // A fully-expanded node without children cannot be descended
                // into; stop here instead of spinning on the same index forever.
                None => break,
            }
        }

        // === EXPANSION ===
        let expanded_idx = if arena[node_idx].is_complete() || arena[node_idx].fully_expanded {
            // Can't expand further: simulate from here.
            node_idx
        } else {
            // Legal actions for this node's leftmost HOLE, in a fixed order so
            // that `next_action_idx` identifies the next untried one.
            let hole_depth = arena[node_idx].leftmost_hole_depth.unwrap_or(0);
            let actions = legal_actions(hole_depth, max_depth, num_vars);
            let action_idx = arena[node_idx].next_action_idx;

            match actions.get(action_idx) {
                None => {
                    // All actions exhausted — mark fully expanded.
                    arena[node_idx].fully_expanded = true;
                    node_idx
                }
                Some(action) => {
                    // Apply the action to a copy of the parent's partial tree.
                    let mut new_partial = arena[node_idx].partial.clone();
                    new_partial.expand_leftmost(action);

                    // Update parent's action counter.
                    let parent = &mut arena[node_idx];
                    parent.next_action_idx += 1;
                    if parent.next_action_idx >= actions.len() {
                        parent.fully_expanded = true;
                    }

                    // `MctsNode::new` already caches the child's leftmost HOLE depth.
                    let child_idx = arena.len();
                    arena.push(MctsNode::new(new_partial, node_idx));
                    arena[node_idx].children.push(child_idx);
                    child_idx
                }
            }
        };

        // === SIMULATION (rollout) ===
        // Clone the partial tree, complete remaining HOLEs randomly, evaluate.
        // We also keep the rollout tree so we can include it in the candidate pool.
        let (reward, rollout_tree_opt) = {
            let mut rollout_partial = arena[expanded_idx].partial.clone();
            rollout_partial.complete_random(num_vars, &mut rng);

            let tree = partial_to_tree(&rollout_partial);

            // Interval pruning: skip infeasible topologies. Trees shallower than
            // the configured threshold are never pruned.
            let feasible = if let Some((ref ivs, tlo, thi)) = interval_data {
                let threshold = config.interval_pruning_depth_threshold;
                if tree.depth() < threshold {
                    true
                } else {
                    topology_interval_feasible(&tree, ivs, tlo, thi)
                }
            } else {
                true
            };

            if !feasible {
                // Assign the lowest reward for pruned topologies; discard tree.
                (0.0, None)
            } else {
                // Quick Adam fit — keep the tree AND the reward.
                // NOTE(review): the last argument is the arena index of the
                // simulated node; presumably a per-topology seed/id — confirm
                // against `optimize_topology_pub`.
                match surrogate_engine.optimize_topology_pub(&tree, inputs, targets, expanded_idx)
                {
                    // A NaN MSE would yield a NaN reward that permanently poisons
                    // `total_value` on the whole path during backpropagation, so
                    // a NaN fit is treated like a failed one.
                    Some(f) if !f.mse.is_nan() => (1.0 / (1.0 + f.mse), Some((tree, f.mse))),
                    _ => (0.0, None),
                }
            }
        };

        // Track complete nodes (arena-complete path) AND rollout trees (simulation path).
        // A node's first simulation happens right after its creation, before any
        // backpropagation, so `visits == 0` identifies the first encounter and
        // keeps `complete_nodes` free of duplicates.
        if arena[expanded_idx].visits == 0 && arena[expanded_idx].is_complete() {
            complete_nodes.push(expanded_idx);
        }
        // Always record the rollout tree regardless — this is the primary source of
        // interesting complete trees when max_depth > 1.
        if let Some(rt) = rollout_tree_opt {
            rollout_candidates.push(rt);
        }

        // === BACKPROPAGATION ===
        let mut idx = expanded_idx;
        loop {
            let node = &mut arena[idx];
            node.visits += 1;
            node.total_value += reward;
            if node.parent == usize::MAX {
                break;
            }
            idx = node.parent;
        }
    }

    // === FINALIZATION ===
    // Merge two sources of complete trees into a single candidate pool:
    //   1. Rollout trees (simulation-completed partials) — the primary source
    //      for max_depth > 1; each has a surrogate MSE from the quick Adam fit.
    //   2. Arena-complete nodes — trees that were fully expanded during selection
    //      without needing a random completion rollout.
    //
    // We store (tree, mse) for rollout candidates and convert arena-complete nodes
    // to the same format using their accumulated average reward.
    let mut candidate_trees: Vec<(EmlTree, f64)> = rollout_candidates;

    // Arena-complete nodes: convert average reward back to a pseudo-MSE.
    // reward = 1/(1+mse) → mse = 1/reward − 1
    for &node_idx in &complete_nodes {
        let node = &arena[node_idx];
        if node.visits > 0 {
            let avg_reward = node.total_value / node.visits as f64;
            let pseudo_mse = if avg_reward > 0.0 {
                1.0 / avg_reward - 1.0
            } else {
                f64::INFINITY
            };
            candidate_trees.push((partial_to_tree(&node.partial), pseudo_mse));
        }
    }

    // Sort by MSE ascending (lowest = best fit). `total_cmp` is a total order
    // even for NaN, which `sort_by` needs to be guaranteed not to panic; the
    // sort is stable, so equal-MSE candidates keep their discovery order.
    candidate_trees.sort_by(|a, b| a.1.total_cmp(&b.1));

    // De-duplicate by structural hash of the simplified form, keeping the
    // top-K unique trees.
    let mut seen_hashes = std::collections::HashSet::new();
    let unique_candidates: Vec<EmlTree> = candidate_trees
        .into_iter()
        .filter_map(|(tree, _)| {
            use std::collections::hash_map::DefaultHasher;
            use std::hash::Hasher;
            let simplified = tree.lower().simplify();
            let mut h = DefaultHasher::new();
            simplified.structural_hash(&mut h);
            let hash = h.finish();
            if seen_hashes.insert(hash) {
                Some(tree)
            } else {
                None
            }
        })
        .take(TOP_K)
        .collect();

    if unique_candidates.is_empty() {
        return Ok(Vec::new());
    }

    // Full Adam optimization on the top candidates.
    engine.optimize_and_finalize_pub(unique_candidates, inputs, targets)
}

/// Build the RNG used for MCTS rollouts.
///
/// With `Some(seed)`, the seed is offset by a fixed MCTS-specific salt and run
/// through the SplitMix64 finalizer so that the rollout stream differs from the
/// topology-level seeds while staying reproducible. With `None`, the RNG comes
/// from `rand::make_rng`.
fn make_mcts_rng(seed: Option<u64>) -> Rng {
    use rand::SeedableRng;
    const MCTS_SALT: u64 = 0xDEAD_BEEF_CAFE_1234;

    let Some(user_seed) = seed else {
        return rand::make_rng::<Rng>();
    };

    // SplitMix64 mixing with the salt.
    let salted = user_seed.wrapping_add(MCTS_SALT);
    let round1 = (salted ^ (salted >> 30)).wrapping_mul(0xbf58_476d_1ce4_e5b9);
    let round2 = (round1 ^ (round1 >> 27)).wrapping_mul(0x94d0_49bb_1331_11eb);
    Rng::seed_from_u64(round2 ^ (round2 >> 31))
}