1use crate::error::{PhopError, Result};
12use oxieml::{EmlNode, EmlTree};
13use scirs2_autograd as ag;
14use scirs2_autograd::tensor_ops as T;
15use scirs2_autograd::Context;
16use scirs2_core::ndarray::{Array1, Array2};
17use std::sync::{Mutex, OnceLock};
18
/// Lower clip bound applied to the argument of `ln` inside `eml_guarded`,
/// so the logarithm never sees zero or a negative value.
pub const LN_EPS: f64 = 1e-12;
/// Symmetric clip bound for the argument of `exp` inside `eml_guarded`:
/// the argument is limited to `[-EXP_CLAMP, EXP_CLAMP]` before exponentiation.
pub const EXP_CLAMP: f64 = 50.0;
23
/// Returns the interned placeholder name `phop_x<j>` for data column `j`.
///
/// Names are created lazily, leaked to obtain a `'static` lifetime, and kept
/// in a function-local static cache guarded by a mutex, so repeated calls
/// with the same index hand back the very same string.
///
/// # Panics
///
/// Panics if the cache mutex has been poisoned.
pub(crate) fn col_placeholder_name(j: usize) -> &'static str {
    static NAMES: OnceLock<Mutex<Vec<&'static str>>> = OnceLock::new();

    let mut names = NAMES
        .get_or_init(|| Mutex::new(Vec::new()))
        .lock()
        .expect("placeholder name cache poisoned");

    // Fill every missing slot up to and including `j`; the cache stays dense,
    // so position `k` always holds the name for column `k`. The range is
    // empty when `j` is already cached.
    let first_missing = names.len();
    names.extend((first_missing..=j).map(|idx| -> &'static str {
        Box::leak(format!("phop_x{}", idx).into_boxed_str())
    }));

    names[j]
}
35
36pub fn eml_guarded<'g>(a: ag::Tensor<'g, f64>, b: ag::Tensor<'g, f64>) -> ag::Tensor<'g, f64> {
38 let ea = T::exp(T::clip(a, -EXP_CLAMP, EXP_CLAMP));
39 let lb = T::ln(T::clip(b, LN_EPS, f64::MAX));
40 T::sub(ea, lb)
41}
42
43pub(crate) fn build_forward<'g>(
49 node: &EmlNode,
50 cols: &[ag::Tensor<'g, f64>],
51 ones: ag::Tensor<'g, f64>,
52 _g: &'g Context<f64>,
53) -> ag::Tensor<'g, f64> {
54 match node {
55 EmlNode::One => ones,
56 EmlNode::Const(c) => ones.scalar_mul(*c),
57 EmlNode::Var(i) => cols[*i],
58 EmlNode::Eml { left, right } => {
59 let a = build_forward(left, cols, ones, _g);
60 let b = build_forward(right, cols, ones, _g);
61 eml_guarded(a, b)
62 }
63 }
64}
65
66pub fn eval_tree(tree: &EmlTree, data: &Array2<f64>) -> Result<Array1<f64>> {
74 let batch = data.nrows();
75 let n_vars = data.ncols();
76 let col_vals: Vec<Array1<f64>> = (0..n_vars).map(|j| data.column(j).to_owned()).collect();
77 let ones_val: Array1<f64> = Array1::from_elem(batch, 1.0);
78
79 let values: std::result::Result<Vec<f64>, String> = ag::run(|g: &mut Context<f64>| {
80 let cols: Vec<ag::Tensor<f64>> = (0..n_vars)
81 .map(|j| g.placeholder(col_placeholder_name(j), &[-1]))
82 .collect();
83 let ones = g.placeholder("phop_ones", &[-1]);
84 let out = build_forward(&tree.root, &cols, ones, g);
85 let mut feeder = ag::Feeder::new();
86 for (j, col) in cols.iter().enumerate() {
87 feeder = feeder.push(*col, col_vals[j].view().into_dyn());
88 }
89 feeder = feeder.push(ones, ones_val.view().into_dyn());
90 let results = g.evaluator().push(&out).set_feeder(feeder).run();
91 match &results[0] {
92 Ok(arr) => Ok(arr.iter().copied().collect()),
93 Err(e) => Err(format!("{e:?}")),
94 }
95 });
96
97 let values = values.map_err(PhopError::Eval)?;
98 if values.iter().any(|v| !v.is_finite()) {
99 return Err(PhopError::NumericalInstability(
100 "forward pass produced non-finite values".to_string(),
101 ));
102 }
103 Ok(Array1::from(values))
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use oxieml::{Canonical, EvalCtx};

    /// Asserts that `eval_tree` agrees with oxieml's own `eval_real` on every
    /// row of `data`, to within an absolute tolerance of `1e-9`.
    fn assert_matches_oxieml(tree: &EmlTree, data: &Array2<f64>) {
        let ours = eval_tree(tree, data).unwrap();
        for i in 0..data.nrows() {
            // Gather row `i` as the variable assignment for the reference evaluator.
            let row: Vec<f64> = (0..data.ncols()).map(|j| data[[i, j]]).collect();
            let ctx = EvalCtx::new(row.as_slice());
            let theirs = tree.eval_real(&ctx).unwrap();
            assert!(
                (ours[i] - theirs).abs() < 1e-9,
                "row {i}: ours={} theirs={theirs}",
                ours[i]
            );
        }
    }

    #[test]
    fn forward_matches_oxieml_exp() {
        let x = EmlTree::var(0);
        let tree = Canonical::exp(&x);
        let data = Array2::from_shape_vec((4, 1), vec![0.0, 0.5, 1.0, 2.0]).unwrap();
        assert_matches_oxieml(&tree, &data);
    }

    #[test]
    fn forward_matches_oxieml_two_vars() {
        let t = EmlTree::eml(&EmlTree::var(0), &EmlTree::var(1));
        let data = Array2::from_shape_vec((3, 2), vec![0.0, 1.0, 1.0, 2.0, 0.5, 3.0]).unwrap();
        assert_matches_oxieml(&t, &data);
    }

    #[test]
    fn const_leaf_evaluates() {
        let t = EmlTree::eml(&EmlTree::var(0), &EmlTree::const_val(2.0));
        let data = Array2::from_shape_vec((2, 1), vec![0.0, 1.0]).unwrap();
        assert_matches_oxieml(&t, &data);
    }

    #[test]
    fn guarded_eval_stays_finite_on_extremes() {
        let t = EmlTree::eml(&EmlTree::var(0), &EmlTree::var(1));
        // Each pair is one row: (exp argument, ln argument).
        let data = Array2::from_shape_vec(
            (4, 2),
            vec![
                // Large positive exp argument, negative ln argument.
                1000.0, -5.0,
                // Large negative exp argument, zero ln argument.
                -1000.0, 0.0,
                // Exp argument near the top of the f64 range, tiny positive ln argument.
                1e308, 1e-300,
                // Ordinary values.
                0.0, 1.0,
            ],
        )
        .unwrap();
        let out = eval_tree(&t, &data).expect("guarded eval must not error on extremes");
        assert!(
            out.iter().all(|v| v.is_finite()),
            "guarded eval produced a non-finite value: {out:?}"
        );
    }
}