use crate::error::EmlError;
use crate::eval::EvalCtx;
use crate::tree::{EmlNode, EmlTree};
use num_complex::Complex64;
use std::sync::Arc;
/// An [`EmlTree`] topology paired with one real-valued parameter per `One` leaf.
///
/// During evaluation, `params[k]` is used as the value of the k-th `One` leaf,
/// counting leaves in depth-first, left-before-right order (the order in which
/// the tape builder in this module visits them).
#[derive(Clone, Debug)]
pub struct ParameterizedEmlTree {
/// The tree shape being parameterized.
pub topology: EmlTree,
/// Leaf values; expected to have exactly one entry per `One` leaf in
/// `topology` (evaluation indexes into it and panics if it is too short).
pub params: Vec<f64>,
}
/// One recorded node of the forward-pass tape.
///
/// The tape builder appends entries in post-order, so both operands of an
/// `Eml` entry always appear earlier on the tape than the entry itself.
#[derive(Clone, Debug)]
enum TapeEntry {
/// A `One` leaf whose value was taken from `ParameterizedEmlTree::params`.
Param,
/// A `Var` leaf whose value was read from the evaluation context.
Var,
/// An `eml(left, right)` node; the fields are the tape indices of its
/// left and right operands.
Eml(usize, usize),
}
impl ParameterizedEmlTree {
    /// Imaginary parts with magnitude below this are treated as round-off and
    /// the tree output is accepted as real.
    const IMAG_TOLERANCE: f64 = 1e-12;

    /// Builds a parameterized copy of `topology` with one parameter per `One`
    /// leaf, each initialised to `init_value`.
    pub fn from_topology(topology: &EmlTree, init_value: f64) -> Self {
        let count = count_ones(&topology.root);
        Self {
            topology: topology.clone(),
            params: vec![init_value; count],
        }
    }

    /// Number of trainable parameters (the length of `params`).
    pub fn num_params(&self) -> usize {
        self.params.len()
    }

    /// Evaluates the tree at `ctx` and returns its real-valued output.
    ///
    /// # Errors
    /// - `EmlError::VarOutOfBounds` if a `Var` leaf indexes past `ctx`.
    /// - `EmlError::NanEncountered` if any `eml` node evaluates to NaN.
    /// - `EmlError::ComplexResult` if the output's imaginary part is not below
    ///   `IMAG_TOLERANCE` in magnitude.
    ///
    /// # Panics
    /// If `params` has fewer entries than the topology has `One` leaves.
    pub fn forward(&self, ctx: &EvalCtx) -> Result<f64, EmlError> {
        let (_tape, values) = self.build_tape_and_forward(ctx)?;
        Self::real_output(&values)
    }

    /// Evaluates the tree and returns `(output, jacobian)`, where
    /// `jacobian[k]` is the derivative of the output with respect to
    /// `params[k]`.
    ///
    /// # Errors
    /// Same as [`Self::forward`]. Also returns `EmlError::NanEncountered` if any
    /// derivative is NaN (e.g. an `eml` node whose right operand is exactly 0).
    ///
    /// # Panics
    /// If `params` has fewer entries than the topology has `One` leaves.
    pub fn forward_with_jacobian(&self, ctx: &EvalCtx) -> Result<(f64, Vec<f64>), EmlError> {
        let (tape, values) = self.build_tape_and_forward(ctx)?;
        let output = Self::real_output(&values)?;
        let jacobian = Self::param_gradients(&tape, &values, 1.0)?;
        Ok((output, jacobian))
    }

    /// Evaluates the tree and returns `(loss, grads)` for the squared error
    /// `loss = (output - target)^2`, where `grads[k] = d loss / d params[k]`.
    ///
    /// # Errors
    /// Same as [`Self::forward_with_jacobian`].
    ///
    /// # Panics
    /// If `params` has fewer entries than the topology has `One` leaves.
    pub fn forward_backward(
        &self,
        ctx: &EvalCtx,
        target: f64,
    ) -> Result<(f64, Vec<f64>), EmlError> {
        let (tape, values) = self.build_tape_and_forward(ctx)?;
        let output = Self::real_output(&values)?;
        let residual = output - target;
        let loss = residual * residual;
        // d(loss)/d(output) = 2 * residual seeds the reverse sweep.
        let grads = Self::param_gradients(&tape, &values, 2.0 * residual)?;
        Ok((loss, grads))
    }

    /// Extracts the root value (last tape entry) and checks that it is real.
    fn real_output(values: &[Complex64]) -> Result<f64, EmlError> {
        let result = values.last().copied().unwrap_or(Complex64::new(0.0, 0.0));
        if result.im.abs() < Self::IMAG_TOLERANCE {
            Ok(result.re)
        } else {
            Err(EmlError::ComplexResult(result.im.abs()))
        }
    }

    /// Reverse-mode sweep over `tape`, seeding the root adjoint with `seed`.
    ///
    /// Returns `Re(adjoint)` for every `Param` entry in tape order, which is
    /// the same order in which `params` were consumed. Every parameter is real
    /// and each node is a holomorphic function of its operands. The real part
    /// of the complex adjoint is therefore `seed * d Re(root) / d param`.
    ///
    /// # Errors
    /// `EmlError::NanEncountered` if any resulting gradient is NaN, matching
    /// the NaN handling of the forward pass.
    fn param_gradients(
        tape: &[TapeEntry],
        values: &[Complex64],
        seed: f64,
    ) -> Result<Vec<f64>, EmlError> {
        let mut adjoints = vec![Complex64::new(0.0, 0.0); tape.len()];
        if let Some(root) = adjoints.last_mut() {
            *root = Complex64::new(seed, 0.0);
        }
        // Operands always precede their parent on the tape, so by the time the
        // reverse loop reaches entry `i`, its adjoint is fully accumulated.
        for i in (0..tape.len()).rev() {
            if let TapeEntry::Eml(left_idx, right_idx) = &tape[i] {
                let g = adjoints[i];
                // d/dl exp(l) = exp(l). The clamped exponential is reused here.
                // Outside the clamp range this is a straight-through estimate
                // rather than the exact derivative of the clamped function.
                let d_left = clamped_exp(values[*left_idx]);
                // d/dr (-ln r) = -1 / r; can be NaN when r == 0.
                let d_right = -Complex64::new(1.0, 0.0) / values[*right_idx];
                adjoints[*left_idx] += g * d_left;
                adjoints[*right_idx] += g * d_right;
            }
        }
        let grads: Vec<f64> = tape
            .iter()
            .zip(&adjoints)
            .filter(|(entry, _)| matches!(entry, TapeEntry::Param))
            .map(|(_, adjoint)| adjoint.re)
            .collect();
        if grads.iter().any(|g| g.is_nan()) {
            return Err(EmlError::NanEncountered);
        }
        Ok(grads)
    }

    /// Runs the forward pass, returning the tape and the complex value of
    /// every tape entry (the root is the last entry).
    fn build_tape_and_forward(
        &self,
        ctx: &EvalCtx,
    ) -> Result<(Vec<TapeEntry>, Vec<Complex64>), EmlError> {
        let mut tape = Vec::new();
        let mut values = Vec::new();
        let mut param_idx = 0;
        self.build_tape_recursive(
            &self.topology.root,
            ctx,
            &mut tape,
            &mut values,
            &mut param_idx,
        )?;
        Ok((tape, values))
    }

    /// Appends `node`'s subtree to the tape in post-order (left subtree, right
    /// subtree, then the node itself) and returns the node's tape index.
    ///
    /// `param_idx` is the next unread index into `self.params`; it advances by
    /// one for every `One` leaf visited.
    ///
    /// # Panics
    /// If `self.params` runs out before all `One` leaves are visited.
    fn build_tape_recursive(
        &self,
        node: &EmlNode,
        ctx: &EvalCtx,
        tape: &mut Vec<TapeEntry>,
        values: &mut Vec<Complex64>,
        param_idx: &mut usize,
    ) -> Result<usize, EmlError> {
        match node {
            EmlNode::One => {
                let idx = tape.len();
                let p = self.params[*param_idx];
                *param_idx += 1;
                tape.push(TapeEntry::Param);
                values.push(Complex64::new(p, 0.0));
                Ok(idx)
            }
            EmlNode::Var(var_idx) => {
                let idx = tape.len();
                let val = ctx
                    .get(*var_idx)
                    .ok_or(EmlError::VarOutOfBounds(*var_idx, ctx.num_vars()))?;
                tape.push(TapeEntry::Var);
                values.push(Complex64::new(val, 0.0));
                Ok(idx)
            }
            EmlNode::Eml { left, right } => {
                let left_idx = self.build_tape_recursive(left, ctx, tape, values, param_idx)?;
                let right_idx = self.build_tape_recursive(right, ctx, tape, values, param_idx)?;
                let left_val = values[left_idx];
                let right_val = values[right_idx];
                let result = eml_complex_grad(left_val, right_val)?;
                let idx = tape.len();
                tape.push(TapeEntry::Eml(left_idx, right_idx));
                values.push(result);
                Ok(idx)
            }
        }
    }
}
/// Counts the `One` leaves in the tree rooted at `node`. This is the number of
/// parameters `ParameterizedEmlTree::from_topology` allocates.
///
/// Walks the tree with an explicit work stack instead of recursion.
fn count_ones(node: &EmlNode) -> usize {
    let mut pending: Vec<&EmlNode> = vec![node];
    let mut ones = 0;
    while let Some(current) = pending.pop() {
        match current {
            EmlNode::One => ones += 1,
            EmlNode::Var(_) => {}
            EmlNode::Eml { left, right } => {
                pending.push(right);
                pending.push(left);
            }
        }
    }
    ones
}
/// Computes `exp(z)` after limiting the real part of `z` to `[-709, 709]`.
///
/// `e^709` is still a finite `f64`, so the clamp keeps the exponential's
/// magnitude from overflowing. The imaginary part is passed through unchanged,
/// and a NaN real part stays NaN.
fn clamped_exp(z: Complex64) -> Complex64 {
    const LIMIT: f64 = 709.0;
    let bounded_re = z.re.clamp(-LIMIT, LIMIT);
    Complex64::new(bounded_re, z.im).exp()
}
/// Forward value of a single EML node: `clamped_exp(left) - ln(right)`.
///
/// Despite its name, this computes a node value, not a gradient.
///
/// # Errors
/// `EmlError::NanEncountered` if either component of the result is NaN.
/// Infinite results are not rejected here.
fn eml_complex_grad(left: Complex64, right: Complex64) -> Result<Complex64, EmlError> {
    let value = clamped_exp(left) - right.ln();
    if !value.re.is_nan() && !value.im.is_nan() {
        Ok(value)
    } else {
        Err(EmlError::NanEncountered)
    }
}
/// Rebuilds a fresh [`EmlTree`] with the same shape as `ptree.topology`.
///
/// NOTE(review): parameters are consumed positionally but their values are not
/// stored. Every `One` leaf is emitted as a plain `EmlNode::One`, so learned
/// values in `ptree.params` do not appear in the result. Confirm this is
/// intended.
///
/// # Panics
/// If `ptree.params` has fewer entries than the topology has `One` leaves.
pub fn reconstruct_tree(ptree: &ParameterizedEmlTree) -> EmlTree {
    let mut cursor = 0;
    let rebuilt_root = reconstruct_node(&ptree.topology.root, &ptree.params, &mut cursor);
    EmlTree::from_node(rebuilt_root)
}
/// Recursively copies `node` into a new `Arc`-owned subtree.
///
/// `param_idx` advances by one for each `One` leaf. The corresponding entry of
/// `params` is read but not stored in the output.
///
/// # Panics
/// If `params` runs out before all `One` leaves are visited.
fn reconstruct_node(node: &EmlNode, params: &[f64], param_idx: &mut usize) -> Arc<EmlNode> {
    let rebuilt = match node {
        EmlNode::Var(i) => EmlNode::Var(*i),
        EmlNode::One => {
            // Bounds-checked read kept for its panic-on-short-params behaviour.
            let _consumed = params[*param_idx];
            *param_idx += 1;
            EmlNode::One
        }
        EmlNode::Eml { left, right } => {
            let new_left = reconstruct_node(left, params, param_idx);
            let new_right = reconstruct_node(right, params, param_idx);
            EmlNode::Eml {
                left: new_left,
                right: new_right,
            }
        }
    };
    Arc::new(rebuilt)
}
#[cfg(test)]
mod tests {
    use super::*;

    const TOL: f64 = 1e-10;

    #[test]
    fn test_parameterized_forward() {
        // eml(1, 1) = e^1 - ln(1) = e
        let one = EmlTree::one();
        let tree = EmlTree::eml(&one, &one);
        let ptree = ParameterizedEmlTree::from_topology(&tree, 1.0);
        assert_eq!(ptree.num_params(), 2);
        let ctx = EvalCtx::new(&[]);
        let result = ptree
            .forward(&ctx)
            .expect("parameterized forward pass should succeed");
        assert!((result - std::f64::consts::E).abs() < TOL);
    }

    #[test]
    fn test_forward_backward() {
        // eml(x, p) at x = 1, p = 1 equals e exactly, so loss and gradient vanish.
        let x = EmlTree::var(0);
        let one = EmlTree::one();
        let tree = EmlTree::eml(&x, &one);
        let ptree = ParameterizedEmlTree::from_topology(&tree, 1.0);
        assert_eq!(ptree.num_params(), 1);
        let ctx = EvalCtx::new(&[1.0]);
        let target = std::f64::consts::E;
        let (loss, grads) = ptree
            .forward_backward(&ctx, target)
            .expect("forward_backward should succeed");
        assert!(loss < 1e-20);
        assert_eq!(grads.len(), 1);
        assert!(grads[0].abs() < TOL);
    }

    #[test]
    fn test_gradient_nonzero() {
        // f = e^p0 - ln(p1) at p = (1, 1) is e; loss = (e - 3)^2.
        // dL/dp0 = 2 (e - 3) e^p0 = 2 (e - 3) e, dL/dp1 = 2 (e - 3) (-1/p1).
        let one = EmlTree::one();
        let tree = EmlTree::eml(&one, &one);
        let ptree = ParameterizedEmlTree::from_topology(&tree, 1.0);
        let ctx = EvalCtx::new(&[]);
        let (loss, grads) = ptree
            .forward_backward(&ctx, 3.0)
            .expect("gradient computation should succeed");
        let e = std::f64::consts::E;
        let residual = e - 3.0;
        assert!(loss > 0.0);
        assert!((loss - residual * residual).abs() < TOL);
        assert_eq!(grads.len(), 2);
        assert!((grads[0] - 2.0 * residual * e).abs() < TOL);
        assert!((grads[1] + 2.0 * residual).abs() < TOL);
    }

    #[test]
    fn test_forward_with_jacobian_values() {
        // eml(x, p) = e^x - ln(p); d/dp = -1/p = -0.5 at p = 2.
        let x = EmlTree::var(0);
        let one = EmlTree::one();
        let tree = EmlTree::eml(&x, &one);
        let mut ptree = ParameterizedEmlTree::from_topology(&tree, 1.0);
        ptree.params = vec![2.0];
        let ctx = EvalCtx::new(&[0.5]);
        let (output, jac) = ptree
            .forward_with_jacobian(&ctx)
            .expect("jacobian computation should succeed");
        assert!((output - (0.5f64.exp() - 2.0f64.ln())).abs() < TOL);
        assert_eq!(jac.len(), 1);
        assert!((jac[0] + 0.5).abs() < TOL);
    }

    #[test]
    fn test_jacobian_matches_finite_differences() {
        // eml(eml(p0, x), p1): checks the chain rule through a nested node.
        let x = EmlTree::var(0);
        let one = EmlTree::one();
        let inner = EmlTree::eml(&one, &x);
        let tree = EmlTree::eml(&inner, &one);
        let mut ptree = ParameterizedEmlTree::from_topology(&tree, 1.0);
        ptree.params = vec![0.3, 1.5];
        let ctx = EvalCtx::new(&[2.0]);
        let (_, jac) = ptree
            .forward_with_jacobian(&ctx)
            .expect("jacobian computation should succeed");
        assert_eq!(jac.len(), 2);
        let h = 1e-6;
        for k in 0..jac.len() {
            let mut plus = ptree.clone();
            plus.params[k] += h;
            let mut minus = ptree.clone();
            minus.params[k] -= h;
            let f_plus = plus.forward(&ctx).expect("perturbed forward should succeed");
            let f_minus = minus.forward(&ctx).expect("perturbed forward should succeed");
            let fd = (f_plus - f_minus) / (2.0 * h);
            assert!(
                (fd - jac[k]).abs() < 1e-6,
                "param {k}: analytic {} vs finite-difference {}",
                jac[k],
                fd
            );
        }
    }

    #[test]
    fn test_reconstruct_preserves_parameter_count() {
        let x = EmlTree::var(0);
        let one = EmlTree::one();
        let tree = EmlTree::eml(&EmlTree::eml(&one, &x), &one);
        let ptree = ParameterizedEmlTree::from_topology(&tree, 1.0);
        let rebuilt = reconstruct_tree(&ptree);
        let again = ParameterizedEmlTree::from_topology(&rebuilt, 1.0);
        assert_eq!(again.num_params(), ptree.num_params());
    }
}