use crate::error::EmlError;
use crate::tree::{EmlNode, EmlTree};
use num_complex::Complex64;
/// Real parts of an `exp` argument are clamped to `[-EXP_CLAMP, EXP_CLAMP]`
/// before exponentiation; `e^709` (about 8.2e307) is still a finite `f64`.
const EXP_CLAMP: f64 = 709.0;
/// A final imaginary part whose magnitude is below this value is treated as
/// zero when projecting a complex result onto the real line.
const IMAG_THRESHOLD: f64 = 1e-12;
/// Minimum batch size at which `eval_batch` switches to rayon parallel
/// iteration (only when the `parallel` feature is enabled).
// Only the rayon code path reads this constant, so silence the dead-code
// lint when that path is compiled out.
#[cfg_attr(not(feature = "parallel"), allow(dead_code))]
const PARALLEL_BATCH_THRESHOLD: usize = 128;

/// Real-valued variable bindings for evaluating an [`EmlTree`].
///
/// The value stored at position `i` is substituted for every `Var(i)` leaf.
#[derive(Clone, Debug)]
pub struct EvalCtx {
    vars: Vec<f64>,
}

impl EvalCtx {
    /// Builds a context holding an owned copy of `vars`.
    pub fn new(vars: &[f64]) -> Self {
        let owned = Vec::from(vars);
        EvalCtx { vars: owned }
    }

    /// Returns the value bound to variable `index`, or `None` if no such
    /// variable exists.
    pub fn get(&self, index: usize) -> Option<f64> {
        self.as_slice().get(index).copied()
    }

    /// Number of bound variables.
    pub fn num_vars(&self) -> usize {
        self.as_slice().len()
    }

    /// Borrows the bound values, in variable-index order.
    pub fn as_slice(&self) -> &[f64] {
        self.vars.as_slice()
    }
}
impl EmlTree {
    /// Evaluates the tree at a real-valued point.
    ///
    /// Each variable in `ctx` is lifted to the complex number `v + 0i`. The
    /// tree is evaluated in the complex plane, so an intermediate `ln` of a
    /// negative value is allowed. The final value is then projected back onto
    /// the real line.
    ///
    /// # Errors
    ///
    /// - [`EmlError::VarOutOfBounds`] if the tree references a variable index
    ///   that `ctx` does not bind.
    /// - [`EmlError::NanEncountered`] if an intermediate `eml` step yields NaN,
    ///   or the final real part is NaN.
    /// - [`EmlError::ComplexResult`] (carrying `|im|`) if the final imaginary
    ///   part is not below `IMAG_THRESHOLD` in magnitude.
    pub fn eval_real(&self, ctx: &EvalCtx) -> Result<f64, EmlError> {
        // Reuse the exact code path of `eval_batch`, so single-point and batch
        // evaluation cannot diverge. This also avoids building an intermediate
        // complex copy of the variables.
        let mut instructions = Vec::new();
        flatten_postorder(&self.root, &mut instructions);
        eval_point(&instructions, ctx.as_slice())
    }

    /// Evaluates the tree with complex-valued variable bindings.
    ///
    /// `vars[i]` is substituted for every `Var(i)` leaf.
    ///
    /// # Errors
    ///
    /// - [`EmlError::VarOutOfBounds`] (carrying the index and `vars.len()`) if
    ///   the tree references an unbound variable.
    /// - [`EmlError::NanEncountered`] if any `eml` step yields a NaN component.
    pub fn eval_complex(&self, vars: &[Complex64]) -> Result<Complex64, EmlError> {
        let mut instructions = Vec::new();
        flatten_postorder(&self.root, &mut instructions);
        let mut stack: Vec<Complex64> = Vec::with_capacity(instructions.len());
        for inst in &instructions {
            let value = match *inst {
                Instruction::PushOne => Complex64::new(1.0, 0.0),
                Instruction::PushVar(idx) => *vars
                    .get(idx)
                    .ok_or(EmlError::VarOutOfBounds(idx, vars.len()))?,
                Instruction::Eml => {
                    // Postorder flattening guarantees two operands here. An
                    // underflow would indicate an internal bug; it is surfaced
                    // as an error rather than a panic.
                    let right = stack.pop().ok_or(EmlError::NanEncountered)?;
                    let left = stack.pop().ok_or(EmlError::NanEncountered)?;
                    eml_complex(left, right)?
                }
            };
            stack.push(value);
        }
        // A well-formed postorder program leaves exactly one value.
        debug_assert_eq!(stack.len(), 1);
        Ok(stack[0])
    }

    /// Evaluates the tree at every point in `data`, returning one real result
    /// per point in input order.
    ///
    /// The tree is flattened once and the program is reused for every point.
    /// With the `parallel` feature enabled, batches of at least
    /// `PARALLEL_BATCH_THRESHOLD` points are evaluated with rayon.
    ///
    /// # Errors
    ///
    /// - [`EmlError::EmptyData`] if `data` is empty.
    /// - Otherwise, any error [`EmlTree::eval_real`] would produce for a point.
    ///   The sequential path reports the first failing point. The parallel path
    ///   reports a failing point, but which one is not guaranteed.
    pub fn eval_batch(&self, data: &[Vec<f64>]) -> Result<Vec<f64>, EmlError> {
        if data.is_empty() {
            return Err(EmlError::EmptyData);
        }
        let mut instructions = Vec::new();
        flatten_postorder(&self.root, &mut instructions);
        #[cfg(feature = "parallel")]
        if data.len() >= PARALLEL_BATCH_THRESHOLD {
            use rayon::prelude::*;
            return data
                .par_iter()
                .map(|point| eval_point(&instructions, point))
                .collect::<Result<Vec<f64>, EmlError>>();
        }
        data.iter()
            .map(|point| eval_point(&instructions, point))
            .collect::<Result<Vec<f64>, EmlError>>()
    }
}
/// One step of the postorder stack program produced by `flatten_postorder`.
///
/// The enum is plain data (at most one `usize`), so it derives `Copy` and
/// equality for cheap copying and for comparisons in tests.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Instruction {
    /// Push the constant `1 + 0i`.
    PushOne,
    /// Push the value bound to the variable with this index.
    PushVar(usize),
    /// Pop `right`, then `left`, and push `eml_complex(left, right)`.
    Eml,
}
/// Runs a flattened postorder program against one real-valued point and
/// projects the result onto the real line.
///
/// Each `point[i]` is treated as the complex number `point[i] + 0i`.
///
/// # Errors
///
/// - [`EmlError::VarOutOfBounds`] (carrying the index and `point.len()`) if the
///   program references an unbound variable.
/// - [`EmlError::NanEncountered`] if an `eml` step yields NaN, or the final
///   real part is NaN.
/// - [`EmlError::ComplexResult`] (carrying `|im|`) if the final imaginary part
///   is not below `IMAG_THRESHOLD` in magnitude.
fn eval_point(instructions: &[Instruction], point: &[f64]) -> Result<f64, EmlError> {
    let mut stack: Vec<Complex64> = Vec::with_capacity(instructions.len());
    for inst in instructions {
        let value = match *inst {
            Instruction::PushOne => Complex64::new(1.0, 0.0),
            // Lift lazily instead of pre-building a complex copy of `point`.
            // This saves one heap allocation per evaluated point.
            Instruction::PushVar(idx) => {
                let re = *point
                    .get(idx)
                    .ok_or(EmlError::VarOutOfBounds(idx, point.len()))?;
                Complex64::new(re, 0.0)
            }
            Instruction::Eml => {
                // Postorder programs always supply two operands here; underflow
                // would be an internal bug and is surfaced as an error.
                let right = stack.pop().ok_or(EmlError::NanEncountered)?;
                let left = stack.pop().ok_or(EmlError::NanEncountered)?;
                eml_complex(left, right)?
            }
        };
        stack.push(value);
    }
    // A well-formed postorder program leaves exactly one value.
    debug_assert_eq!(stack.len(), 1);
    let result = stack[0];
    if result.im.abs() < IMAG_THRESHOLD {
        if result.re.is_nan() {
            return Err(EmlError::NanEncountered);
        }
        Ok(result.re)
    } else {
        Err(EmlError::ComplexResult(result.im.abs()))
    }
}
/// Appends the postorder stack program for `node` to `out`.
///
/// The order is: left subtree, right subtree, then `Eml`. Leaves become
/// `PushOne` / `PushVar`.
///
/// The traversal uses an explicit heap-allocated work list instead of
/// recursion, so very deep trees cannot overflow the native call stack.
fn flatten_postorder(node: &EmlNode, out: &mut Vec<Instruction>) {
    // Each entry is (node, children_already_scheduled).
    let mut pending: Vec<(&EmlNode, bool)> = vec![(node, false)];
    while let Some((current, expanded)) = pending.pop() {
        match current {
            EmlNode::One => out.push(Instruction::PushOne),
            EmlNode::Var(idx) => out.push(Instruction::PushVar(*idx)),
            EmlNode::Eml { left, right } => {
                if expanded {
                    // Both subtrees have been emitted; now emit the operator.
                    out.push(Instruction::Eml);
                } else {
                    // The work list is LIFO, so push in reverse of the desired
                    // emission order: the operator last, the left subtree first.
                    let left: &EmlNode = left;
                    let right: &EmlNode = right;
                    pending.push((current, true));
                    pending.push((right, false));
                    pending.push((left, false));
                }
            }
        }
    }
}
/// Applies the EML primitive `eml(l, r) = exp(l) - ln(r)` over complex numbers.
///
/// The real part of `left` is clamped to `[-EXP_CLAMP, EXP_CLAMP]` before
/// exponentiation, so `exp` stays finite. The imaginary part is passed through
/// untouched. `ln` is `num_complex`'s complex natural logarithm.
///
/// # Errors
///
/// Returns [`EmlError::NanEncountered`] if either component of the result is
/// NaN.
fn eml_complex(left: Complex64, right: Complex64) -> Result<Complex64, EmlError> {
    // `f64::clamp` returns NaN unchanged, matching explicit `<`/`>` comparisons.
    let bounded_re = left.re.clamp(-EXP_CLAMP, EXP_CLAMP);
    let growth = Complex64::new(bounded_re, left.im).exp();
    let value = growth - right.ln();
    let has_nan = value.re.is_nan() || value.im.is_nan();
    if has_nan {
        Err(EmlError::NanEncountered)
    } else {
        Ok(value)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tree::EmlTree;

    #[test]
    fn test_eval_one() {
        let t = EmlTree::one();
        let ctx = EvalCtx::new(&[]);
        let result = t.eval_real(&ctx).expect("eval of One leaf should succeed");
        assert!((result - 1.0).abs() < 1e-15);
    }

    #[test]
    fn test_eval_var() {
        let t = EmlTree::var(0);
        let ctx = EvalCtx::new(&[2.71]);
        let result = t.eval_real(&ctx).expect("eval of Var leaf should succeed");
        assert!((result - 2.71).abs() < 1e-15);
    }

    #[test]
    fn test_eval_exp() {
        let x = EmlTree::var(0);
        let one = EmlTree::one();
        let exp_x = EmlTree::eml(&x, &one);
        let ctx = EvalCtx::new(&[1.0]);
        let result = exp_x
            .eval_real(&ctx)
            .expect("eval of eml(x,1) = exp(x) should succeed");
        assert!((result - std::f64::consts::E).abs() < 1e-10);
    }

    #[test]
    fn test_eval_euler() {
        let one = EmlTree::one();
        let euler = EmlTree::eml(&one, &one);
        let ctx = EvalCtx::new(&[]);
        let result = euler
            .eval_real(&ctx)
            .expect("eval of eml(1,1) = e should succeed");
        assert!((result - std::f64::consts::E).abs() < 1e-10);
    }

    #[test]
    fn test_eval_batch() {
        let x = EmlTree::var(0);
        let one = EmlTree::one();
        let exp_x = EmlTree::eml(&x, &one);
        let data = vec![vec![0.0], vec![1.0], vec![2.0]];
        let results = exp_x
            .eval_batch(&data)
            .expect("batch eval of exp should succeed");
        assert!((results[0] - 1.0).abs() < 1e-10);
        assert!((results[1] - std::f64::consts::E).abs() < 1e-10);
        assert!((results[2] - (2.0_f64).exp()).abs() < 1e-10);
    }

    #[test]
    fn test_var_out_of_bounds() {
        let t = EmlTree::var(5);
        let ctx = EvalCtx::new(&[1.0]);
        assert!(matches!(
            t.eval_real(&ctx),
            Err(EmlError::VarOutOfBounds(5, 1))
        ));
    }

    #[test]
    fn test_eval_batch_parallel() {
        let x = EmlTree::var(0);
        let one = EmlTree::one();
        let exp_x = EmlTree::eml(&x, &one);
        let data: Vec<Vec<f64>> = (0..200).map(|i| vec![i as f64 * 0.01]).collect();
        // Guard: the batch must be large enough to take the parallel path.
        assert!(data.len() >= PARALLEL_BATCH_THRESHOLD);
        let results = exp_x
            .eval_batch(&data)
            .expect("parallel batch eval should succeed");
        assert_eq!(results.len(), 200);
        for (i, &r) in results.iter().enumerate() {
            let expected = (i as f64 * 0.01_f64).exp();
            assert!(
                (r - expected).abs() < 1e-8,
                "index {i}: got {r}, expected {expected}"
            );
        }
    }

    #[test]
    fn test_eval_batch_parallel_error_short_circuit() {
        let t = EmlTree::var(5);
        // Must reach the threshold; with only a handful of points the
        // sequential path runs and the parallel error path goes untested.
        let data = vec![vec![1.0]; PARALLEL_BATCH_THRESHOLD + 1];
        let result = t.eval_batch(&data);
        assert!(matches!(result, Err(EmlError::VarOutOfBounds(5, 1))));
    }

    #[test]
    fn test_eval_batch_empty_data() {
        let t = EmlTree::one();
        let data: Vec<Vec<f64>> = Vec::new();
        assert!(matches!(t.eval_batch(&data), Err(EmlError::EmptyData)));
    }

    #[test]
    fn test_eval_real_complex_result() {
        // eml(1, x) = e - ln(x); ln(-1) = iπ, so the result is e - iπ.
        let t = EmlTree::eml(&EmlTree::one(), &EmlTree::var(0));
        let ctx = EvalCtx::new(&[-1.0]);
        assert!(matches!(
            t.eval_real(&ctx),
            Err(EmlError::ComplexResult(im)) if (im - std::f64::consts::PI).abs() < 1e-12
        ));
    }

    #[test]
    fn test_eval_real_nan_input() {
        let t = EmlTree::eml(&EmlTree::var(0), &EmlTree::one());
        let ctx = EvalCtx::new(&[f64::NAN]);
        assert!(matches!(t.eval_real(&ctx), Err(EmlError::NanEncountered)));
    }

    #[test]
    fn test_exp_argument_is_clamped() {
        // exp(1000) would overflow; the argument is clamped to EXP_CLAMP.
        let t = EmlTree::eml(&EmlTree::var(0), &EmlTree::one());
        let ctx = EvalCtx::new(&[1000.0]);
        let result = t
            .eval_real(&ctx)
            .expect("clamped exp should stay finite");
        let expected = EXP_CLAMP.exp();
        assert!(result.is_finite());
        assert!(((result - expected) / expected).abs() < 1e-12);
    }

    #[test]
    fn test_eval_complex_euler_identity() {
        // eml(x, 1) = exp(x), and exp(iπ) = -1.
        let t = EmlTree::eml(&EmlTree::var(0), &EmlTree::one());
        let z = t
            .eval_complex(&[Complex64::new(0.0, std::f64::consts::PI)])
            .expect("complex eval should succeed");
        assert!((z.re + 1.0).abs() < 1e-12);
        assert!(z.im.abs() < 1e-12);
    }

    #[test]
    fn test_eval_real_matches_eval_batch() {
        let x = EmlTree::var(0);
        let one = EmlTree::one();
        let exp_exp_x = EmlTree::eml(&EmlTree::eml(&x, &one), &one);
        let data: Vec<Vec<f64>> = (0..10).map(|i| vec![i as f64 * 0.1 - 0.5]).collect();
        let batch = exp_exp_x
            .eval_batch(&data)
            .expect("batch eval should succeed");
        assert_eq!(batch.len(), data.len());
        for (point, &b) in data.iter().zip(&batch) {
            let single = exp_exp_x
                .eval_real(&EvalCtx::new(point))
                .expect("single-point eval should succeed");
            assert!(
                (single - b).abs() <= 1e-12 * b.abs().max(1.0),
                "point {point:?}: single {single}, batch {b}"
            );
        }
    }
}