use std::collections::HashMap;
use hyperreal::{Real, RealSign};
/// Expression tree over [`Real`] values.
///
/// Leaves are either literal constants or references to an evaluation input;
/// interior nodes are the four binary arithmetic operators and unary negation.
#[derive(Debug, Clone, PartialEq)]
pub enum GpRealExpr {
    /// A literal value, boxed so the enum stays small.
    Constant(Box<Real>),
    /// Index of the input value to read when the expression is evaluated.
    Input(usize),
    /// Sum of the two operands.
    Add(Box<GpRealExpr>, Box<GpRealExpr>),
    /// First operand minus the second.
    Sub(Box<GpRealExpr>, Box<GpRealExpr>),
    /// Product of the two operands.
    Mul(Box<GpRealExpr>, Box<GpRealExpr>),
    /// First operand divided by the second.
    Div(Box<GpRealExpr>, Box<GpRealExpr>),
    /// Arithmetic negation of the operand.
    Neg(Box<GpRealExpr>),
}
/// Bounds an expression is checked against by `GpRealExpr::validate`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct GpValidationLimits {
    /// Number of available inputs; an `Input` index must be strictly below it.
    pub input_arity: usize,
    /// Largest permitted tree depth (a lone leaf has depth 1).
    pub max_depth: usize,
    /// Largest permitted total number of nodes.
    pub max_nodes: usize,
}
/// A problem found while validating or evaluating a `GpRealExpr`.
///
/// The type implements [`std::fmt::Display`] and [`std::error::Error`], so it
/// can be shown to users and propagated with `?` into boxed or wrapped errors.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum GpValidationIssue {
    /// Validation found an `Input` node whose index is not below the arity.
    InputOutOfBounds {
        /// The offending input index.
        input: usize,
        /// The arity the expression was validated against.
        arity: usize,
    },
    /// The tree is deeper than the configured maximum.
    DepthExceeded {
        /// Measured depth of the tree.
        depth: usize,
        /// The configured limit.
        max_depth: usize,
    },
    /// The tree has more nodes than the configured maximum.
    NodeBudgetExceeded {
        /// Measured node count of the tree.
        nodes: usize,
        /// The configured limit.
        max_nodes: usize,
    },
    /// Validation found a division whose divisor is a constant equal to zero.
    DivisionByZero,
    /// Evaluation needed an input that was not supplied.
    MissingInput {
        /// Index of the input that was requested.
        input: usize,
    },
    /// Evaluation attempted a division that the underlying arithmetic rejected.
    UnsupportedDivision,
}

impl std::fmt::Display for GpValidationIssue {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InputOutOfBounds { input, arity } => {
                write!(f, "input index {input} is out of bounds for arity {arity}")
            }
            Self::DepthExceeded { depth, max_depth } => {
                write!(f, "expression depth {depth} exceeds the maximum of {max_depth}")
            }
            Self::NodeBudgetExceeded { nodes, max_nodes } => {
                write!(f, "expression has {nodes} nodes, exceeding the budget of {max_nodes}")
            }
            Self::DivisionByZero => f.write_str("division by a constant zero"),
            Self::MissingInput { input } => write!(f, "no value supplied for input {input}"),
            Self::UnsupportedDivision => f.write_str("division could not be evaluated"),
        }
    }
}

impl std::error::Error for GpValidationIssue {}
/// Outcome of `GpRealExpr::validate`: the measured size of the tree together
/// with every issue that was found.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GpValidationReport {
    /// Measured depth of the validated expression.
    pub depth: usize,
    /// Measured node count of the validated expression.
    pub nodes: usize,
    /// Issues found, in the order validation recorded them.
    pub issues: Vec<GpValidationIssue>,
}
impl GpValidationReport {
    /// Returns `true` when validation recorded no issues at all.
    pub fn is_valid(&self) -> bool {
        let Self { issues, .. } = self;
        issues.is_empty()
    }
}
impl GpRealExpr {
pub fn validate(&self, limits: GpValidationLimits) -> GpValidationReport {
let mut issues = Vec::new();
let depth = self.depth();
let nodes = self.node_count();
if depth > limits.max_depth {
issues.push(GpValidationIssue::DepthExceeded {
depth,
max_depth: limits.max_depth,
});
}
if nodes > limits.max_nodes {
issues.push(GpValidationIssue::NodeBudgetExceeded {
nodes,
max_nodes: limits.max_nodes,
});
}
self.collect_validation_issues(limits.input_arity, &mut issues);
GpValidationReport {
depth,
nodes,
issues,
}
}
pub fn eval(&self, inputs: &[Real]) -> Result<Real, GpValidationIssue> {
match self {
Self::Constant(value) => Ok((**value).clone()),
Self::Input(index) => inputs
.get(*index)
.cloned()
.ok_or(GpValidationIssue::MissingInput { input: *index }),
Self::Add(left, right) => Ok(left.eval(inputs)? + right.eval(inputs)?),
Self::Sub(left, right) => Ok(left.eval(inputs)? - right.eval(inputs)?),
Self::Mul(left, right) => Ok(left.eval(inputs)? * right.eval(inputs)?),
Self::Div(left, right) => (left.eval(inputs)? / right.eval(inputs)?)
.map_err(|_| GpValidationIssue::UnsupportedDivision),
Self::Neg(value) => Ok(-value.eval(inputs)?),
}
}
pub fn depth(&self) -> usize {
match self {
Self::Constant(_) | Self::Input(_) => 1,
Self::Neg(value) => 1 + value.depth(),
Self::Add(left, right)
| Self::Sub(left, right)
| Self::Mul(left, right)
| Self::Div(left, right) => 1 + left.depth().max(right.depth()),
}
}
pub fn node_count(&self) -> usize {
match self {
Self::Constant(_) | Self::Input(_) => 1,
Self::Neg(value) => 1 + value.node_count(),
Self::Add(left, right)
| Self::Sub(left, right)
| Self::Mul(left, right)
| Self::Div(left, right) => 1 + left.node_count() + right.node_count(),
}
}
fn collect_validation_issues(&self, arity: usize, issues: &mut Vec<GpValidationIssue>) {
match self {
Self::Input(index) if *index >= arity => {
issues.push(GpValidationIssue::InputOutOfBounds {
input: *index,
arity,
})
}
Self::Div(_, right)
if matches!(
right
.constant_value()
.map(|value| value.structural_facts().sign),
Some(Some(RealSign::Zero))
) =>
{
issues.push(GpValidationIssue::DivisionByZero);
}
_ => {}
}
for child in self.children() {
child.collect_validation_issues(arity, issues);
}
}
fn children(&self) -> Vec<&GpRealExpr> {
match self {
Self::Add(left, right)
| Self::Sub(left, right)
| Self::Mul(left, right)
| Self::Div(left, right) => vec![left, right],
Self::Neg(value) => vec![value],
Self::Constant(_) | Self::Input(_) => Vec::new(),
}
}
fn constant_value(&self) -> Option<Real> {
match self {
Self::Constant(value) => Some((**value).clone()),
Self::Neg(value) => value.constant_value().map(|value| -value),
Self::Add(left, right) => Some(left.constant_value()? + right.constant_value()?),
Self::Sub(left, right) => Some(left.constant_value()? - right.constant_value()?),
Self::Mul(left, right) => Some(left.constant_value()? * right.constant_value()?),
Self::Div(left, right) => (left.constant_value()? / right.constant_value()?).ok(),
Self::Input(_) => None,
}
}
}
pub fn eval_gp_batch(
expressions: &[GpRealExpr],
inputs: &HashMap<usize, Real>,
) -> Vec<Result<Real, GpValidationIssue>> {
let dense_inputs = dense_input_vector(inputs);
expressions
.iter()
.map(|expression| expression.eval(&dense_inputs))
.collect()
}
fn dense_input_vector(inputs: &HashMap<usize, Real>) -> Vec<Real> {
let len = inputs
.keys()
.copied()
.max()
.map(|index| index + 1)
.unwrap_or(0);
let mut values = vec![Real::zero(); len];
for (index, value) in inputs {
values[*index] = value.clone();
}
values
}