use crate::error::{MathCompileError, Result};
use crate::final_tagless::{ASTRepr, ExpressionBuilder, VariableRegistry};
use crate::symbolic::SymbolicOptimizer;
use crate::symbolic_ad::SymbolicAD;
use std::collections::HashMap;
/// High-level, name-based front end for building and evaluating `ASTRepr<f64>` expressions.
///
/// Wraps an [`ExpressionBuilder`] (which maps variable names to indices), a table of named
/// mathematical constants, and an optional [`SymbolicOptimizer`] used by `optimize`.
#[derive(Debug)]
pub struct MathBuilder {
    /// Creates variables/constants and resolves variable names <-> indices.
    builder: ExpressionBuilder,
    /// Named constants served by `math_constant`; populated once in `new`.
    constants: HashMap<String, f64>,
    /// Set only by `with_optimization`; when `None`, `optimize` just clones its input.
    optimizer: Option<SymbolicOptimizer>,
}
impl MathBuilder {
    /// Creates a builder with an empty variable registry, no optimizer, and the built-in
    /// constants `pi`, `e`, `tau`, `sqrt2`, `ln2` and `ln10`.
    #[must_use]
    pub fn new() -> Self {
        let mut constants = HashMap::new();
        constants.insert("pi".to_string(), std::f64::consts::PI);
        constants.insert("e".to_string(), std::f64::consts::E);
        constants.insert("tau".to_string(), std::f64::consts::TAU);
        constants.insert("sqrt2".to_string(), std::f64::consts::SQRT_2);
        constants.insert("ln2".to_string(), std::f64::consts::LN_2);
        constants.insert("ln10".to_string(), std::f64::consts::LN_10);
        Self {
            builder: ExpressionBuilder::new(),
            constants,
            optimizer: None,
        }
    }

    /// Like [`MathBuilder::new`], but also owns a [`SymbolicOptimizer`] that
    /// [`MathBuilder::optimize`] will use.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `SymbolicOptimizer::new`.
    pub fn with_optimization() -> Result<Self> {
        let mut builder = Self::new();
        builder.optimizer = Some(SymbolicOptimizer::new()?);
        Ok(builder)
    }

    /// Returns the expression for variable `name` via `ExpressionBuilder::var`
    /// (which presumably registers the name on first use — hence `&mut self`).
    #[must_use]
    pub fn var(&mut self, name: &str) -> ASTRepr<f64> {
        self.builder.var(name)
    }

    /// Returns a constant node holding `value`.
    #[must_use]
    pub fn constant(&self, value: f64) -> ASTRepr<f64> {
        self.builder.constant(value)
    }

    /// Looks up a built-in mathematical constant (e.g. `"pi"`) by name.
    ///
    /// # Errors
    ///
    /// Returns [`MathCompileError::InvalidInput`] if `name` is unknown. The message lists the
    /// available names in sorted order so it is deterministic across runs.
    pub fn math_constant(&self, name: &str) -> Result<ASTRepr<f64>> {
        match self.constants.get(name) {
            Some(&value) => Ok(ASTRepr::Constant(value)),
            None => {
                // HashMap iteration order is arbitrary; sort for a stable, readable message.
                let mut available: Vec<&str> = self.constants.keys().map(String::as_str).collect();
                available.sort_unstable();
                Err(MathCompileError::InvalidInput(format!(
                    "Unknown mathematical constant: {name}. Available: {}",
                    available.join(", ")
                )))
            }
        }
    }

    /// Returns `left + right`.
    #[must_use]
    pub fn add(&self, left: &ASTRepr<f64>, right: &ASTRepr<f64>) -> ASTRepr<f64> {
        left + right
    }

    /// Returns `left - right`.
    #[must_use]
    pub fn sub(&self, left: &ASTRepr<f64>, right: &ASTRepr<f64>) -> ASTRepr<f64> {
        left - right
    }

    /// Returns `left * right`.
    #[must_use]
    pub fn mul(&self, left: &ASTRepr<f64>, right: &ASTRepr<f64>) -> ASTRepr<f64> {
        left * right
    }

    /// Returns `left / right`.
    #[must_use]
    pub fn div(&self, left: &ASTRepr<f64>, right: &ASTRepr<f64>) -> ASTRepr<f64> {
        left / right
    }

    /// Returns `base` raised to the power `exp`.
    #[must_use]
    pub fn pow(&self, base: &ASTRepr<f64>, exp: &ASTRepr<f64>) -> ASTRepr<f64> {
        base.pow_ref(exp)
    }

    /// Returns `-expr`.
    #[must_use]
    pub fn neg(&self, expr: &ASTRepr<f64>) -> ASTRepr<f64> {
        -expr
    }

    /// Returns the natural logarithm of `expr`.
    #[must_use]
    pub fn ln(&self, expr: &ASTRepr<f64>) -> ASTRepr<f64> {
        expr.ln_ref()
    }

    /// Returns `e^expr`.
    #[must_use]
    pub fn exp(&self, expr: &ASTRepr<f64>) -> ASTRepr<f64> {
        expr.exp_ref()
    }

    /// Returns `sin(expr)`.
    #[must_use]
    pub fn sin(&self, expr: &ASTRepr<f64>) -> ASTRepr<f64> {
        expr.sin_ref()
    }

    /// Returns `cos(expr)`.
    #[must_use]
    pub fn cos(&self, expr: &ASTRepr<f64>) -> ASTRepr<f64> {
        expr.cos_ref()
    }

    /// Returns `sqrt(expr)`.
    #[must_use]
    pub fn sqrt(&self, expr: &ASTRepr<f64>) -> ASTRepr<f64> {
        expr.sqrt_ref()
    }

    /// Builds the polynomial `c0 + c1*x + c2*x^2 + ...` from `coefficients = [c0, c1, c2, ...]`
    /// (lowest degree first) using Horner's scheme: `(...(c_n*x + c_{n-1})*x + ...) + c0`.
    ///
    /// An empty slice yields the constant `0`; a single coefficient yields that constant.
    #[must_use]
    pub fn poly(&self, coefficients: &[f64], variable: &ASTRepr<f64>) -> ASTRepr<f64> {
        let mut from_highest = coefficients.iter().rev();
        match from_highest.next() {
            None => ASTRepr::Constant(0.0),
            Some(&leading) => from_highest.fold(ASTRepr::Constant(leading), |acc, &coeff| {
                self.add(&self.mul(&acc, variable), &ASTRepr::Constant(coeff))
            }),
        }
    }

    /// Builds `a*x + b`.
    #[must_use]
    pub fn linear(&self, a: f64, b: f64, variable: &ASTRepr<f64>) -> ASTRepr<f64> {
        self.poly(&[b, a], variable)
    }

    /// Builds `a*x^2 + b*x + c`.
    #[must_use]
    pub fn quadratic(&self, a: f64, b: f64, c: f64, variable: &ASTRepr<f64>) -> ASTRepr<f64> {
        self.poly(&[c, b, a], variable)
    }

    /// Builds the normal density `1/sqrt(2*pi*sigma^2) * exp(-(x - mean)^2 / (2*sigma^2))`.
    ///
    /// `std_dev` is squared, so its sign is irrelevant. `std_dev == 0.0` produces a
    /// normalization factor of `1/sqrt(0)`, which evaluates to infinity.
    #[must_use]
    pub fn gaussian(&self, mean: f64, std_dev: f64, variable: &ASTRepr<f64>) -> ASTRepr<f64> {
        let pi = ASTRepr::Constant(std::f64::consts::PI);
        let two = ASTRepr::Constant(2.0);
        let sigma_squared = ASTRepr::Constant(std_dev * std_dev);
        let norm_factor = self.div(
            &ASTRepr::Constant(1.0),
            &self.sqrt(&self.mul(&self.mul(&two, &pi), &sigma_squared)),
        );
        let x_minus_mu = self.sub(variable, &ASTRepr::Constant(mean));
        let x_minus_mu_squared = self.pow(&x_minus_mu, &ASTRepr::Constant(2.0));
        let exponent = self.neg(&self.div(&x_minus_mu_squared, &self.mul(&two, &sigma_squared)));
        self.mul(&norm_factor, &self.exp(&exponent))
    }

    /// Builds the logistic sigmoid `1 / (1 + e^(-x))`.
    #[must_use]
    pub fn logistic(&self, variable: &ASTRepr<f64>) -> ASTRepr<f64> {
        let one = ASTRepr::Constant(1.0);
        let neg_x = self.neg(variable);
        let exp_neg_x = self.exp(&neg_x);
        let denominator = self.add(&one, &exp_neg_x);
        self.div(&one, &denominator)
    }

    /// Builds the hyperbolic tangent as `1 - 2 / (e^(2x) + 1)`.
    ///
    /// This equals `(e^(2x) - 1) / (e^(2x) + 1)` algebraically. Because `e^(2x)` appears only
    /// in the denominator, overflow to `+inf` for large `x` yields `1 - 0 = 1` rather than
    /// `inf / inf = NaN`; for very negative `x` it yields `1 - 2/1 = -1`.
    #[must_use]
    pub fn tanh(&self, variable: &ASTRepr<f64>) -> ASTRepr<f64> {
        let one = ASTRepr::Constant(1.0);
        let two = ASTRepr::Constant(2.0);
        let exp_2x = self.exp(&self.mul(&two, variable));
        let denominator = self.add(&exp_2x, &one);
        self.sub(&one, &self.div(&two, &denominator))
    }

    /// Evaluates `expr` with variables bound by name, e.g. `&[("x", 1.0), ("y", 2.0)]`.
    ///
    /// The borrowed names are copied into owned `(String, f64)` pairs because that is what
    /// `ExpressionBuilder::eval_with_named_vars` is given.
    #[must_use]
    pub fn eval(&self, expr: &ASTRepr<f64>, variables: &[(&str, f64)]) -> f64 {
        let named_vars: Vec<(String, f64)> = variables
            .iter()
            .map(|(name, value)| ((*name).to_string(), *value))
            .collect();
        self.builder.eval_with_named_vars(expr, &named_vars)
    }

    /// Runs the symbolic optimizer on `expr` if this builder was created with
    /// [`MathBuilder::with_optimization`]; otherwise returns a clone of `expr`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the optimizer.
    pub fn optimize(&mut self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
        if let Some(ref mut optimizer) = self.optimizer {
            optimizer.optimize(expr)
        } else {
            Ok(expr.clone())
        }
    }

    /// Builds the `SymbolicAD` configuration sized to the variables currently registered.
    fn ad_config(&self) -> crate::symbolic_ad::SymbolicADConfig {
        let mut config = crate::symbolic_ad::SymbolicADConfig::default();
        config.num_variables = self.builder.num_variables();
        config
    }

    /// Returns the symbolic first derivative of `expr` with respect to `var_name`.
    ///
    /// # Errors
    ///
    /// Returns [`MathCompileError::InvalidInput`] if `var_name` is not registered or if the
    /// differentiation result has no entry for its index; propagates `SymbolicAD` errors.
    pub fn derivative(&mut self, expr: &ASTRepr<f64>, var_name: &str) -> Result<ASTRepr<f64>> {
        let var_index = self.builder.get_variable_index(var_name).ok_or_else(|| {
            MathCompileError::InvalidInput(format!("Variable {var_name} not found in registry"))
        })?;
        let mut ad = SymbolicAD::with_config(self.ad_config())?;
        let result = ad.compute_with_derivatives(expr)?;
        // `first_derivatives` is keyed by the variable index rendered as a string.
        let var_index_str = var_index.to_string();
        result
            .first_derivatives
            .get(&var_index_str)
            .cloned()
            .ok_or_else(|| {
                MathCompileError::InvalidInput(format!(
                    "Variable index {var_index} not found in derivatives"
                ))
            })
    }

    /// Returns all first derivatives of `expr`, keyed by variable name.
    ///
    /// Entries whose key is not a numeric index, or whose index has no registered name,
    /// are skipped.
    ///
    /// # Errors
    ///
    /// Propagates `SymbolicAD` configuration and differentiation errors.
    pub fn gradient(&mut self, expr: &ASTRepr<f64>) -> Result<HashMap<String, ASTRepr<f64>>> {
        let mut ad = SymbolicAD::with_config(self.ad_config())?;
        let result = ad.compute_with_derivatives(expr)?;
        let mut named_derivatives = HashMap::new();
        for (index_str, derivative) in result.first_derivatives {
            if let Ok(index) = index_str.parse::<usize>() {
                if let Some(var_name) = self.builder.get_variable_name(index) {
                    named_derivatives.insert(var_name.to_string(), derivative);
                }
            }
        }
        Ok(named_derivatives)
    }

    /// Number of variables registered in the underlying builder.
    #[must_use]
    pub fn num_variables(&self) -> usize {
        self.builder.num_variables()
    }

    /// Names of the registered variables.
    #[must_use]
    pub fn variable_names(&self) -> &[String] {
        self.builder.variable_names()
    }

    /// The underlying variable registry.
    #[must_use]
    pub fn registry(&self) -> &VariableRegistry {
        self.builder.registry()
    }

    /// Replaces the underlying `ExpressionBuilder` with a fresh one, discarding registered
    /// variables. Constants and the optimizer are kept.
    ///
    /// Expressions built before this call may reference variable indices that are no longer
    /// registered (and would then presumably fail [`MathBuilder::validate`]).
    pub fn clear(&mut self) {
        self.builder = ExpressionBuilder::new();
    }

    /// Checks that `expr` contains only finite constants, only variable indices registered in
    /// this builder, and no division by a literal zero constant.
    ///
    /// # Errors
    ///
    /// Returns [`MathCompileError::InvalidInput`] describing the first problem found.
    pub fn validate(&self, expr: &ASTRepr<f64>) -> Result<()> {
        self.validate_recursive(expr)
    }

    /// Depth-first worker for [`MathBuilder::validate`]; children are checked before the
    /// node's own division-by-zero check.
    fn validate_recursive(&self, expr: &ASTRepr<f64>) -> Result<()> {
        match expr {
            ASTRepr::Constant(value) => {
                if !value.is_finite() {
                    return Err(MathCompileError::InvalidInput(format!(
                        "Invalid constant value: {value}"
                    )));
                }
            }
            ASTRepr::Variable(index) => {
                let num_variables = self.builder.num_variables();
                if *index >= num_variables {
                    return Err(MathCompileError::InvalidInput(format!(
                        "Variable index {index} is out of bounds (must be < {num_variables})"
                    )));
                }
            }
            ASTRepr::Div(left, right) => {
                self.validate_recursive(left)?;
                self.validate_recursive(right)?;
                // Only an exact zero divisor is rejected (this also matches -0.0). A tiny but
                // nonzero constant such as 1e-20 is a valid divisor; f64::EPSILON is a
                // tolerance relative to 1.0, not an absolute zero test.
                if let ASTRepr::Constant(value) = right.as_ref() {
                    if *value == 0.0 {
                        return Err(MathCompileError::InvalidInput(
                            "Division by zero constant detected".to_string(),
                        ));
                    }
                }
            }
            ASTRepr::Add(left, right)
            | ASTRepr::Sub(left, right)
            | ASTRepr::Mul(left, right)
            | ASTRepr::Pow(left, right) => {
                self.validate_recursive(left)?;
                self.validate_recursive(right)?;
            }
            ASTRepr::Neg(inner)
            | ASTRepr::Ln(inner)
            | ASTRepr::Exp(inner)
            | ASTRepr::Sin(inner)
            | ASTRepr::Cos(inner)
            | ASTRepr::Sqrt(inner) => {
                self.validate_recursive(inner)?;
            }
        }
        Ok(())
    }
}
impl Default for MathBuilder {
fn default() -> Self {
Self::new()
}
}
/// Placeholder for evaluating a textual expression with named variable bindings.
///
/// # Errors
///
/// Always returns [`MathCompileError::InvalidInput`]. Expression parsing is not implemented
/// yet, so both arguments are ignored; build expressions with [`MathBuilder`] instead.
pub fn quick_eval(_expression: &str, _variables: &[(&str, f64)]) -> Result<f64> {
    let reason =
        String::from("Expression parsing not yet implemented. Please use MathBuilder for now.");
    Err(MathCompileError::InvalidInput(reason))
}
pub mod presets {
    //! Ready-made expressions assembled from [`MathBuilder`] primitives.
    use super::{ASTRepr, MathBuilder};

    /// Standard normal density: [`MathBuilder::gaussian`] with mean 0 and standard deviation 1.
    #[must_use]
    pub fn standard_normal(math: &MathBuilder, variable: &ASTRepr<f64>) -> ASTRepr<f64> {
        math.gaussian(0.0, 1.0, variable)
    }

    /// Softplus, `ln(1 + e^x)`: a smooth approximation of `max(0, x)`.
    ///
    /// Built literally, so `e^x` overflows to infinity for large `x` (roughly `x > 709`) and
    /// the result then evaluates to infinity rather than approximately `x`.
    #[must_use]
    pub fn softplus(math: &MathBuilder, variable: &ASTRepr<f64>) -> ASTRepr<f64> {
        let exp_x = math.exp(variable);
        let one_plus_exp_x = math.add(&ASTRepr::Constant(1.0), &exp_x);
        math.ln(&one_plus_exp_x)
    }

    /// Smooth ReLU approximation.
    ///
    /// Despite the name, this does **not** build the piecewise-linear `max(0, x)`; it returns
    /// [`softplus`], `ln(1 + e^x)`. The name is kept for backward compatibility. Prefer calling
    /// [`softplus`] directly to make the intent explicit.
    #[must_use]
    pub fn relu(math: &MathBuilder, variable: &ASTRepr<f64>) -> ASTRepr<f64> {
        softplus(math, variable)
    }

    /// Squared error `(y_pred - y_true)^2` for a single prediction; averaging over samples is
    /// left to the caller.
    #[must_use]
    pub fn mse_loss(
        math: &MathBuilder,
        y_pred: &ASTRepr<f64>,
        y_true: &ASTRepr<f64>,
    ) -> ASTRepr<f64> {
        let diff = math.sub(y_pred, y_true);
        math.pow(&diff, &ASTRepr::Constant(2.0))
    }

    /// One cross-entropy term, `-y_true * ln(y_pred)`.
    ///
    /// Summing this over classes gives categorical cross-entropy. Binary cross-entropy would
    /// additionally need the `-(1 - y_true) * ln(1 - y_pred)` term, which is not included.
    /// `ln(y_pred)` is unbounded as `y_pred` approaches 0; no clamping is applied.
    #[must_use]
    pub fn cross_entropy_loss(
        math: &MathBuilder,
        y_pred: &ASTRepr<f64>,
        y_true: &ASTRepr<f64>,
    ) -> ASTRepr<f64> {
        let ln_pred = math.ln(y_pred);
        let product = math.mul(y_true, &ln_pred);
        math.neg(&product)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the value of a constant node, failing the test for any other node kind
    /// (instead of silently skipping the assertion as an `if let` would).
    fn constant_value(expr: &ASTRepr<f64>) -> f64 {
        match expr {
            ASTRepr::Constant(value) => *value,
            _ => panic!("expected ASTRepr::Constant"),
        }
    }

    #[test]
    fn test_basic_arithmetic() {
        let mut math = MathBuilder::new();
        let x = math.var("x");
        let y = math.var("y");
        let expr = math.add(&x, &y);
        let result = math.eval(&expr, &[("x", 3.0), ("y", 4.0)]);
        assert_eq!(result, 7.0);
    }

    #[test]
    fn test_polynomial() {
        let mut math = MathBuilder::new();
        let x = math.var("x");
        // 1 + 3x + 2x^2 at x = 2.
        let poly = math.poly(&[1.0, 3.0, 2.0], &x);
        let result = math.eval(&poly, &[("x", 2.0)]);
        assert_eq!(result, 15.0);
    }

    #[test]
    fn test_polynomial_degenerate_inputs() {
        let mut math = MathBuilder::new();
        let x = math.var("x");
        // No coefficients: the zero polynomial.
        assert_eq!(math.eval(&math.poly(&[], &x), &[("x", 5.0)]), 0.0);
        // A single coefficient is a constant independent of x.
        assert_eq!(math.eval(&math.poly(&[4.5], &x), &[("x", 5.0)]), 4.5);
    }

    #[test]
    fn test_quadratic() {
        let mut math = MathBuilder::new();
        let x = math.var("x");
        let quad = math.quadratic(2.0, -3.0, 1.0, &x);
        let result = math.eval(&quad, &[("x", 2.0)]);
        assert_eq!(result, 3.0);
    }

    #[test]
    fn test_linear() {
        let mut math = MathBuilder::new();
        let x = math.var("x");
        let linear = math.linear(2.0, 3.0, &x);
        let result = math.eval(&linear, &[("x", 4.0)]);
        assert_eq!(result, 11.0);
    }

    #[test]
    fn test_gaussian() {
        let mut math = MathBuilder::new();
        let x = math.var("x");
        let gaussian = math.gaussian(0.0, 1.0, &x);
        let result = math.eval(&gaussian, &[("x", 0.0)]);
        assert!((result - 0.3989).abs() < 0.001);
    }

    #[test]
    fn test_logistic() {
        let mut math = MathBuilder::new();
        let x = math.var("x");
        let logistic = math.logistic(&x);
        let result = math.eval(&logistic, &[("x", 0.0)]);
        assert!((result - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_tanh() {
        let mut math = MathBuilder::new();
        let x = math.var("x");
        let tanh = math.tanh(&x);
        for &input in &[-1.0, 0.0, 0.5, 1.0] {
            let result = math.eval(&tanh, &[("x", input)]);
            assert!((result - f64::tanh(input)).abs() < 1e-12);
        }
    }

    #[test]
    fn test_math_constants() {
        let math = MathBuilder::new();
        let pi = constant_value(&math.math_constant("pi").unwrap());
        assert!((pi - std::f64::consts::PI).abs() < 1e-10);
        let e = constant_value(&math.math_constant("e").unwrap());
        assert!((e - std::f64::consts::E).abs() < 1e-10);
    }

    #[test]
    fn test_unknown_math_constant() {
        let math = MathBuilder::new();
        assert!(math.math_constant("phi").is_err());
    }

    #[test]
    fn test_validation() {
        let math = MathBuilder::new();
        let valid = ASTRepr::Constant(42.0);
        assert!(math.validate(&valid).is_ok());
        let invalid = ASTRepr::Constant(f64::NAN);
        assert!(math.validate(&invalid).is_err());
        let infinite = ASTRepr::Constant(f64::INFINITY);
        assert!(math.validate(&infinite).is_err());
        let div_zero = math.div(&ASTRepr::Constant(1.0), &ASTRepr::Constant(0.0));
        assert!(math.validate(&div_zero).is_err());
        // An index equal to the variable count is one past the last valid index.
        let out_of_bounds = ASTRepr::Variable(math.num_variables());
        assert!(math.validate(&out_of_bounds).is_err());
    }
}