Skip to main content

phop_core/
forest.rs

1//! Layer A — tensorized EML forest forward evaluation.
2//!
3//! This module builds an autograd forward graph for an EML tree (or, later, a whole
4//! population) over a batch of data, with **guarded** evaluation so that `exp` cannot
5//! overflow and `ln` is never applied to non-positive values (Risk T1).
6//!
7//! Data enters through fed **placeholders** rather than `constant` nodes: this is the
8//! idiomatic scirs2-autograd path, avoids per-call allocation, and sidesteps verbose
9//! debug output emitted by the constant op.
10
11use crate::error::{PhopError, Result};
12use oxieml::{EmlNode, EmlTree};
13use scirs2_autograd as ag;
14use scirs2_autograd::tensor_ops as T;
15use scirs2_autograd::Context;
16use scirs2_core::ndarray::{Array1, Array2};
17use std::sync::{Mutex, OnceLock};
18
/// Floor applied to the argument of `ln`, so the logarithm only ever sees a strictly
/// positive value.
pub const LN_EPS: f64 = 1.0e-12;
/// Half-width of the symmetric interval `[-EXP_CLAMP, EXP_CLAMP]` that the argument of
/// `exp` is clamped into, so the exponential cannot overflow to `inf`.
pub const EXP_CLAMP: f64 = 50.0;
23
/// Return the interned `'static` placeholder name `"phop_x{j}"` for feature column `j`.
///
/// Names for every column index up to and including `j` are generated on demand, leaked
/// exactly once, and kept in a process-wide mutex-guarded table, so repeated calls for the
/// same `j` return the very same `&'static str`.
pub(crate) fn col_placeholder_name(j: usize) -> &'static str {
    static NAMES: OnceLock<Mutex<Vec<&'static str>>> = OnceLock::new();
    let mut names = NAMES
        .get_or_init(|| Mutex::new(Vec::new()))
        .lock()
        .expect("placeholder name cache poisoned");
    // Fill in any indices not yet cached; the range is empty when `j` is already present.
    let first_missing = names.len();
    names.extend((first_missing..=j).map(|k| -> &'static str {
        Box::leak(format!("phop_x{k}").into_boxed_str())
    }));
    names[j]
}
35
36/// Guarded EML primitive on graph tensors: `eml(a, b) = exp(clip(a)) - ln(clip(b))`.
37pub fn eml_guarded<'g>(a: ag::Tensor<'g, f64>, b: ag::Tensor<'g, f64>) -> ag::Tensor<'g, f64> {
38    let ea = T::exp(T::clip(a, -EXP_CLAMP, EXP_CLAMP));
39    let lb = T::ln(T::clip(b, LN_EPS, f64::MAX));
40    T::sub(ea, lb)
41}
42
43/// Build the forward graph for an EML node over per-variable column tensors.
44///
45/// `ones` is a fed `[batch]` tensor of ones; deriving every leaf from it (rather than from
46/// `ones`/`constant` ops) keeps the graph free of const-generating ops. The result is a
47/// `[batch]` prediction tensor: `1` leaves become `ones`, `Const(c)` becomes `ones * c`.
48pub(crate) fn build_forward<'g>(
49    node: &EmlNode,
50    cols: &[ag::Tensor<'g, f64>],
51    ones: ag::Tensor<'g, f64>,
52    _g: &'g Context<f64>,
53) -> ag::Tensor<'g, f64> {
54    match node {
55        EmlNode::One => ones,
56        EmlNode::Const(c) => ones.scalar_mul(*c),
57        EmlNode::Var(i) => cols[*i],
58        EmlNode::Eml { left, right } => {
59            let a = build_forward(left, cols, ones, _g);
60            let b = build_forward(right, cols, ones, _g);
61            eml_guarded(a, b)
62        }
63    }
64}
65
66/// Evaluate an [`EmlTree`] forward through autograd over the given data.
67///
68/// Returns predictions of length `data.nrows()`.
69///
70/// # Errors
71/// Returns [`PhopError::Eval`] if the autograd evaluation fails, or
72/// [`PhopError::NumericalInstability`] if the output contains non-finite values.
73pub fn eval_tree(tree: &EmlTree, data: &Array2<f64>) -> Result<Array1<f64>> {
74    let batch = data.nrows();
75    let n_vars = data.ncols();
76    let col_vals: Vec<Array1<f64>> = (0..n_vars).map(|j| data.column(j).to_owned()).collect();
77    let ones_val: Array1<f64> = Array1::from_elem(batch, 1.0);
78
79    let values: std::result::Result<Vec<f64>, String> = ag::run(|g: &mut Context<f64>| {
80        let cols: Vec<ag::Tensor<f64>> = (0..n_vars)
81            .map(|j| g.placeholder(col_placeholder_name(j), &[-1]))
82            .collect();
83        let ones = g.placeholder("phop_ones", &[-1]);
84        let out = build_forward(&tree.root, &cols, ones, g);
85        let mut feeder = ag::Feeder::new();
86        for (j, col) in cols.iter().enumerate() {
87            feeder = feeder.push(*col, col_vals[j].view().into_dyn());
88        }
89        feeder = feeder.push(ones, ones_val.view().into_dyn());
90        let results = g.evaluator().push(&out).set_feeder(feeder).run();
91        match &results[0] {
92            Ok(arr) => Ok(arr.iter().copied().collect()),
93            Err(e) => Err(format!("{e:?}")),
94        }
95    });
96
97    let values = values.map_err(PhopError::Eval)?;
98    if values.iter().any(|v| !v.is_finite()) {
99        return Err(PhopError::NumericalInstability(
100            "forward pass produced non-finite values".to_string(),
101        ));
102    }
103    Ok(Array1::from(values))
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use oxieml::{Canonical, EvalCtx};

    /// Assert that `eval_tree` returns one prediction per row and agrees with oxieml's
    /// scalar `eval_real` on every row of `data` (within 1e-9).
    fn assert_matches_oxieml(tree: &EmlTree, data: &Array2<f64>) {
        let ours = eval_tree(tree, data).unwrap();
        assert_eq!(ours.len(), data.nrows(), "one prediction per row expected");
        for (i, row) in data.outer_iter().enumerate() {
            let vars: Vec<f64> = row.to_vec();
            let ctx = EvalCtx::new(&vars);
            let theirs = tree.eval_real(&ctx).unwrap();
            assert!(
                (ours[i] - theirs).abs() < 1e-9,
                "row {i}: ours={} theirs={theirs}",
                ours[i]
            );
        }
    }

    #[test]
    fn forward_matches_oxieml_exp() {
        let x = EmlTree::var(0);
        let tree = Canonical::exp(&x);
        let data = Array2::from_shape_vec((4, 1), vec![0.0, 0.5, 1.0, 2.0]).unwrap();
        assert_matches_oxieml(&tree, &data);
    }

    #[test]
    fn forward_matches_oxieml_two_vars() {
        let tree = EmlTree::eml(&EmlTree::var(0), &EmlTree::var(1));
        let data = Array2::from_shape_vec((3, 2), vec![0.0, 1.0, 1.0, 2.0, 0.5, 3.0]).unwrap();
        assert_matches_oxieml(&tree, &data);
    }

    #[test]
    fn const_leaf_evaluates() {
        // eml(x0, c) = exp(x0) - ln(c)
        let tree = EmlTree::eml(&EmlTree::var(0), &EmlTree::const_val(2.0));
        let data = Array2::from_shape_vec((2, 1), vec![0.0, 1.0]).unwrap();
        assert_matches_oxieml(&tree, &data);
    }

    /// Risk T1: `eml(x0, x1)` on inputs that would overflow `exp` or take `ln` of a
    /// non-positive number. The guard (exp argument clipped to ±EXP_CLAMP, ln argument
    /// clipped to >= LN_EPS) must keep the output finite.
    #[test]
    fn guarded_eval_stays_finite_on_extremes() {
        let t = EmlTree::eml(&EmlTree::var(0), &EmlTree::var(1));
        #[rustfmt::skip]
        let data = Array2::from_shape_vec(
            (4, 2),
            vec![
                1000.0,  -5.0,   // exp(1000) would be inf; ln(-5) is undefined
                -1000.0, 0.0,    // exp(-1000) underflows; ln(0) is -inf
                1e308,   1e-300, // extreme magnitudes
                0.0,     1.0,    // benign baseline row
            ],
        )
        .unwrap();
        let out = eval_tree(&t, &data).expect("guarded eval must not error on extremes");
        assert_eq!(out.len(), 4);
        assert!(
            out.iter().all(|v| v.is_finite()),
            "guarded eval produced a non-finite value: {out:?}"
        );
        // Row 0 saturates both guards, so it must equal the clamped closed form exactly
        // (up to rounding), not merely be some finite number.
        let expected_row0 = EXP_CLAMP.exp() - LN_EPS.ln();
        assert!(
            ((out[0] - expected_row0) / expected_row0).abs() < 1e-9,
            "row 0 should saturate both guards: got {}, expected {expected_row0}",
            out[0]
        );
    }
}