use crate::error::Result;
use crate::final_tagless::ASTRepr;
use crate::symbolic::SymbolicOptimizer;
use std::collections::HashMap;
/// Configuration for the symbolic automatic-differentiation pipeline.
#[derive(Debug, Clone)]
pub struct SymbolicADConfig {
    /// Run the symbolic optimizer on the input expression before differentiating.
    pub pre_optimize: bool,
    /// Run the symbolic optimizer on the function and its derivatives afterwards.
    pub post_optimize: bool,
    /// During post-optimization, attempt to share subexpressions between the
    /// function and its derivatives.
    pub share_subexpressions: bool,
    /// Highest derivative order to compute (1 = first derivatives only,
    /// 2 = also all second derivatives).
    pub max_derivative_order: usize,
    /// Number of independent variables (indices `0..num_variables`) to
    /// differentiate with respect to.
    pub num_variables: usize,
}
impl Default for SymbolicADConfig {
fn default() -> Self {
Self {
pre_optimize: true,
post_optimize: true,
share_subexpressions: true,
max_derivative_order: 1,
num_variables: 1,
}
}
}
/// A function together with its symbolically computed derivatives and the
/// statistics gathered while producing them.
#[derive(Debug, Clone)]
pub struct FunctionWithDerivatives<T> {
    /// The (possibly optimized) original expression.
    pub function: ASTRepr<T>,
    /// First derivatives, keyed by the variable index rendered as a string
    /// (e.g. "0", "1").
    pub first_derivatives: HashMap<String, ASTRepr<T>>,
    /// Second derivatives, keyed by the pair of variable-index strings.
    pub second_derivatives: HashMap<(String, String), ASTRepr<T>>,
    /// Subexpressions shared between the function and its derivatives
    /// (may be empty when sharing is disabled or no sharing was extracted).
    pub shared_subexpressions: HashMap<String, ASTRepr<T>>,
    /// Operation counts and per-stage timings collected during computation.
    pub stats: SymbolicADStats,
}
/// Operation counts and timings collected across the AD pipeline.
#[derive(Debug, Clone, Default)]
pub struct SymbolicADStats {
    /// Operations in the input function before any optimization.
    pub function_operations_before: usize,
    /// Operations in the function after post-optimization.
    pub function_operations_after: usize,
    /// Operations in the function plus all derivatives, before post-optimization.
    pub total_operations_before: usize,
    /// Operations in the function plus all derivatives, after post-optimization.
    pub total_operations_after: usize,
    /// Number of shared subexpressions extracted (0 when sharing is disabled).
    pub shared_subexpressions_count: usize,
    /// Wall-clock time per pipeline stage, in microseconds:
    /// [pre-optimize, differentiate, post-optimize].
    pub stage_times_us: [u64; 3],
}
impl SymbolicADStats {
    /// `after / before` as a ratio, defined as 1.0 ("unchanged") when `before`
    /// is zero so empty inputs never divide by zero.
    fn ratio(after: usize, before: usize) -> f64 {
        if before == 0 {
            1.0
        } else {
            after as f64 / before as f64
        }
    }

    /// Optimization ratio for the function alone (1.0 = unchanged, < 1.0 = smaller).
    #[must_use]
    pub fn function_optimization_ratio(&self) -> f64 {
        Self::ratio(
            self.function_operations_after,
            self.function_operations_before,
        )
    }

    /// Optimization ratio over the function and all derivatives combined.
    #[must_use]
    pub fn total_optimization_ratio(&self) -> f64 {
        Self::ratio(self.total_operations_after, self.total_operations_before)
    }

    /// Total wall-clock time across all three pipeline stages, in microseconds.
    #[must_use]
    pub fn total_time_us(&self) -> u64 {
        self.stage_times_us.iter().sum()
    }

    /// Convenience alias for [`Self::function_optimization_ratio`].
    #[must_use]
    pub fn optimization_ratio(&self) -> f64 {
        self.function_optimization_ratio()
    }

    /// Convenience accessor for `function_operations_before`.
    #[must_use]
    pub fn operations_before(&self) -> usize {
        self.function_operations_before
    }

    /// Convenience accessor for `function_operations_after`.
    #[must_use]
    pub fn operations_after(&self) -> usize {
        self.function_operations_after
    }
}
/// Symbolic automatic-differentiation engine: pre-optimizes an expression,
/// differentiates it symbolically, then post-optimizes the results.
pub struct SymbolicAD {
    /// Pipeline configuration (derivative order, variable count, optimization flags).
    config: SymbolicADConfig,
    /// Symbolic simplifier used for the pre- and post-optimization passes.
    optimizer: SymbolicOptimizer,
    /// Memoized derivatives, keyed by the expression's `Debug` rendering plus
    /// the variable index (see `symbolic_derivative`).
    derivative_cache: HashMap<String, ASTRepr<f64>>,
}
impl SymbolicAD {
pub fn new() -> Result<Self> {
Self::with_config(SymbolicADConfig::default())
}
pub fn with_config(config: SymbolicADConfig) -> Result<Self> {
let optimizer = SymbolicOptimizer::new()?;
Ok(Self {
config,
optimizer,
derivative_cache: HashMap::new(),
})
}
pub fn compute_with_derivatives(
&mut self,
expr: &ASTRepr<f64>,
) -> Result<FunctionWithDerivatives<f64>> {
let _start_time = std::time::Instant::now();
let mut stats = SymbolicADStats::default();
stats.function_operations_before = expr.count_operations();
stats.total_operations_before = expr.count_operations();
let stage1_start = std::time::Instant::now();
let pre_optimized = if self.config.pre_optimize {
self.optimizer.optimize(expr)?
} else {
expr.clone()
};
stats.stage_times_us[0] = stage1_start.elapsed().as_micros() as u64;
let stage2_start = std::time::Instant::now();
let mut first_derivatives = HashMap::new();
let mut second_derivatives = HashMap::new();
let variables = self.config.num_variables;
for var in 0..variables {
let derivative = self.symbolic_derivative(&pre_optimized, var)?;
first_derivatives.insert(var.to_string(), derivative);
}
if self.config.max_derivative_order >= 2 {
for var1 in 0..variables {
for var2 in 0..variables {
if let Some(first_deriv) = first_derivatives.get(&var1.to_string()) {
let second_deriv = self.symbolic_derivative(first_deriv, var2)?;
second_derivatives
.insert((var1.to_string(), var2.to_string()), second_deriv);
}
}
}
}
stats.total_operations_before = stats.function_operations_before
+ first_derivatives
.values()
.map(super::final_tagless::ASTRepr::count_operations)
.sum::<usize>()
+ second_derivatives
.values()
.map(super::final_tagless::ASTRepr::count_operations)
.sum::<usize>();
stats.stage_times_us[1] = stage2_start.elapsed().as_micros() as u64;
let stage3_start = std::time::Instant::now();
let (
optimized_function,
optimized_derivatives,
optimized_second_derivatives,
shared_subexpressions,
) = if self.config.post_optimize && self.config.share_subexpressions {
self.optimize_with_subexpression_sharing(
&pre_optimized,
&first_derivatives,
&second_derivatives,
)?
} else if self.config.post_optimize {
let opt_func = self.optimizer.optimize(&pre_optimized)?;
let mut opt_first = HashMap::new();
for (var, deriv) in &first_derivatives {
opt_first.insert(var.clone(), self.optimizer.optimize(deriv)?);
}
let mut opt_second = HashMap::new();
for ((var1, var2), deriv) in &second_derivatives {
opt_second.insert(
(var1.clone(), var2.clone()),
self.optimizer.optimize(deriv)?,
);
}
(opt_func, opt_first, opt_second, HashMap::new())
} else {
(
pre_optimized,
first_derivatives,
second_derivatives,
HashMap::new(),
)
};
stats.stage_times_us[2] = stage3_start.elapsed().as_micros() as u64;
stats.function_operations_after = optimized_function.count_operations();
stats.total_operations_after = stats.function_operations_after
+ optimized_derivatives
.values()
.map(super::final_tagless::ASTRepr::count_operations)
.sum::<usize>()
+ optimized_second_derivatives
.values()
.map(super::final_tagless::ASTRepr::count_operations)
.sum::<usize>();
stats.shared_subexpressions_count = shared_subexpressions.len();
Ok(FunctionWithDerivatives {
function: optimized_function,
first_derivatives: optimized_derivatives,
second_derivatives: optimized_second_derivatives,
shared_subexpressions,
stats,
})
}
fn symbolic_derivative(&mut self, expr: &ASTRepr<f64>, var: usize) -> Result<ASTRepr<f64>> {
let cache_key = format!("{expr:?}_{var}");
if let Some(cached) = self.derivative_cache.get(&cache_key) {
return Ok(cached.clone());
}
let derivative = self.compute_derivative_recursive(expr, var)?;
self.derivative_cache.insert(cache_key, derivative.clone());
Ok(derivative)
}
fn compute_derivative_recursive(
&self,
expr: &ASTRepr<f64>,
var: usize,
) -> Result<ASTRepr<f64>> {
match expr {
ASTRepr::Constant(_) => Ok(ASTRepr::Constant(0.0)),
ASTRepr::Variable(index) => {
if *index == var {
Ok(ASTRepr::Constant(1.0))
} else {
Ok(ASTRepr::Constant(0.0))
}
}
ASTRepr::Add(left, right) => {
let left_deriv = self.compute_derivative_recursive(left, var)?;
let right_deriv = self.compute_derivative_recursive(right, var)?;
Ok(ASTRepr::Add(Box::new(left_deriv), Box::new(right_deriv)))
}
ASTRepr::Sub(left, right) => {
let left_deriv = self.compute_derivative_recursive(left, var)?;
let right_deriv = self.compute_derivative_recursive(right, var)?;
Ok(ASTRepr::Sub(Box::new(left_deriv), Box::new(right_deriv)))
}
ASTRepr::Mul(left, right) => {
let left_deriv = self.compute_derivative_recursive(left, var)?;
let right_deriv = self.compute_derivative_recursive(right, var)?;
let term1 = ASTRepr::Mul(left.clone(), Box::new(right_deriv));
let term2 = ASTRepr::Mul(right.clone(), Box::new(left_deriv));
Ok(ASTRepr::Add(Box::new(term1), Box::new(term2)))
}
ASTRepr::Div(left, right) => {
let left_deriv = self.compute_derivative_recursive(left, var)?;
let right_deriv = self.compute_derivative_recursive(right, var)?;
let numerator_term1 = ASTRepr::Mul(right.clone(), Box::new(left_deriv));
let numerator_term2 = ASTRepr::Mul(left.clone(), Box::new(right_deriv));
let numerator = ASTRepr::Sub(Box::new(numerator_term1), Box::new(numerator_term2));
let denominator = ASTRepr::Mul(right.clone(), right.clone());
Ok(ASTRepr::Div(Box::new(numerator), Box::new(denominator)))
}
ASTRepr::Pow(base, exp) => {
let base_deriv = self.compute_derivative_recursive(base, var)?;
let exp_deriv = self.compute_derivative_recursive(exp, var)?;
let ln_base = ASTRepr::Ln(base.clone());
let term1 = ASTRepr::Mul(Box::new(exp_deriv), Box::new(ln_base));
let u_prime_over_u = ASTRepr::Div(Box::new(base_deriv), base.clone());
let term2 = ASTRepr::Mul(exp.clone(), Box::new(u_prime_over_u));
let sum = ASTRepr::Add(Box::new(term1), Box::new(term2));
let original_power = ASTRepr::Pow(base.clone(), exp.clone());
Ok(ASTRepr::Mul(Box::new(original_power), Box::new(sum)))
}
ASTRepr::Neg(inner) => {
let inner_deriv = self.compute_derivative_recursive(inner, var)?;
Ok(ASTRepr::Neg(Box::new(inner_deriv)))
}
ASTRepr::Ln(inner) => {
let inner_deriv = self.compute_derivative_recursive(inner, var)?;
Ok(ASTRepr::Div(Box::new(inner_deriv), inner.clone()))
}
ASTRepr::Exp(inner) => {
let inner_deriv = self.compute_derivative_recursive(inner, var)?;
let exp_inner = ASTRepr::Exp(inner.clone());
Ok(ASTRepr::Mul(Box::new(exp_inner), Box::new(inner_deriv)))
}
ASTRepr::Sqrt(inner) => {
let inner_deriv = self.compute_derivative_recursive(inner, var)?;
let sqrt_inner = ASTRepr::Sqrt(inner.clone());
let two = ASTRepr::Constant(2.0);
let denominator = ASTRepr::Mul(Box::new(two), Box::new(sqrt_inner));
Ok(ASTRepr::Div(Box::new(inner_deriv), Box::new(denominator)))
}
ASTRepr::Sin(inner) => {
let inner_deriv = self.compute_derivative_recursive(inner, var)?;
let cos_inner = ASTRepr::Cos(inner.clone());
Ok(ASTRepr::Mul(Box::new(cos_inner), Box::new(inner_deriv)))
}
ASTRepr::Cos(inner) => {
let inner_deriv = self.compute_derivative_recursive(inner, var)?;
let sin_inner = ASTRepr::Sin(inner.clone());
let neg_sin = ASTRepr::Neg(Box::new(sin_inner));
Ok(ASTRepr::Mul(Box::new(neg_sin), Box::new(inner_deriv)))
}
}
}
fn optimize_with_subexpression_sharing(
&mut self,
function: &ASTRepr<f64>,
first_derivatives: &HashMap<String, ASTRepr<f64>>,
second_derivatives: &HashMap<(String, String), ASTRepr<f64>>,
) -> Result<(
ASTRepr<f64>,
HashMap<String, ASTRepr<f64>>,
HashMap<(String, String), ASTRepr<f64>>,
HashMap<String, ASTRepr<f64>>,
)> {
let optimized_function = self.optimizer.optimize(function)?;
let mut optimized_first = HashMap::new();
for (var, deriv) in first_derivatives {
optimized_first.insert(var.clone(), self.optimizer.optimize(deriv)?);
}
let mut optimized_second = HashMap::new();
for ((var1, var2), deriv) in second_derivatives {
optimized_second.insert(
(var1.clone(), var2.clone()),
self.optimizer.optimize(deriv)?,
);
}
let shared_subexpressions = HashMap::new();
Ok((
optimized_function,
optimized_first,
optimized_second,
shared_subexpressions,
))
}
#[must_use]
pub fn config(&self) -> &SymbolicADConfig {
&self.config
}
pub fn set_config(&mut self, config: SymbolicADConfig) {
self.config = config;
}
pub fn clear_cache(&mut self) {
self.derivative_cache.clear();
}
#[must_use]
pub fn cache_stats(&self) -> (usize, usize) {
(
self.derivative_cache.len(),
self.derivative_cache.capacity(),
)
}
}
impl Default for SymbolicAD {
fn default() -> Self {
Self::new().expect("Failed to create default SymbolicAD")
}
}
pub mod convenience {
    use super::{ASTRepr, HashMap, Result, SymbolicAD, SymbolicADConfig};
    use crate::final_tagless::{ASTEval, ASTMathExpr};

    /// Compute all first derivatives of `expr`.
    ///
    /// Only `variables.len()` matters: derivatives are taken with respect to
    /// variable indices `0..variables.len()` and keyed by the index rendered
    /// as a string ("0", "1", …), not by the supplied names.
    ///
    /// # Errors
    /// Propagates failures from the symbolic AD engine.
    pub fn gradient(
        expr: &ASTRepr<f64>,
        variables: &[&str],
    ) -> Result<HashMap<String, ASTRepr<f64>>> {
        // Struct-update syntax instead of mutating a Default
        // (clippy: field_reassign_with_default).
        let config = SymbolicADConfig {
            num_variables: variables.len(),
            ..SymbolicADConfig::default()
        };
        let mut ad = SymbolicAD::with_config(config)?;
        Ok(ad.compute_with_derivatives(expr)?.first_derivatives)
    }

    /// Compute all second derivatives (Hessian entries) of `expr`.
    ///
    /// Keys are pairs of variable-index strings; see [`gradient`] for the
    /// keying convention.
    ///
    /// # Errors
    /// Propagates failures from the symbolic AD engine.
    pub fn hessian(
        expr: &ASTRepr<f64>,
        variables: &[&str],
    ) -> Result<HashMap<(String, String), ASTRepr<f64>>> {
        let config = SymbolicADConfig {
            num_variables: variables.len(),
            max_derivative_order: 2,
            ..SymbolicADConfig::default()
        };
        let mut ad = SymbolicAD::with_config(config)?;
        Ok(ad.compute_with_derivatives(expr)?.second_derivatives)
    }

    /// Build `a*x^2 + b*x + c` over variable index 0.
    #[must_use]
    pub fn quadratic(a: f64, b: f64, c: f64) -> ASTRepr<f64> {
        let x = ASTEval::var(0);
        let x_squared = ASTEval::pow(x.clone(), ASTEval::constant(2.0));
        ASTEval::add(
            ASTEval::add(
                ASTEval::mul(ASTEval::constant(a), x_squared),
                ASTEval::mul(ASTEval::constant(b), x),
            ),
            ASTEval::constant(c),
        )
    }

    /// Build `a*x^2 + b*x*y + c*y^2 + d*x + e*y + f` over variable indices
    /// 0 (x) and 1 (y).
    #[must_use]
    pub fn bivariate_quadratic(a: f64, b: f64, c: f64, d: f64, e: f64, f: f64) -> ASTRepr<f64> {
        let x = ASTEval::var(0);
        let y = ASTEval::var(1);
        let x_squared = ASTEval::pow(x.clone(), ASTEval::constant(2.0));
        let y_squared = ASTEval::pow(y.clone(), ASTEval::constant(2.0));
        let xy = ASTEval::mul(x.clone(), y.clone());
        ASTEval::add(
            ASTEval::add(
                ASTEval::add(
                    ASTEval::add(
                        ASTEval::add(
                            ASTEval::mul(ASTEval::constant(a), x_squared),
                            ASTEval::mul(ASTEval::constant(b), xy),
                        ),
                        ASTEval::mul(ASTEval::constant(c), y_squared),
                    ),
                    ASTEval::mul(ASTEval::constant(d), x),
                ),
                ASTEval::mul(ASTEval::constant(e), y),
            ),
            ASTEval::constant(f),
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::final_tagless::{ASTEval, ASTMathExpr, DirectEval};

    #[test]
    fn test_symbolic_ad_creation() {
        let ad = SymbolicAD::new();
        assert!(ad.is_ok());
        let config = SymbolicADConfig {
            num_variables: 2,
            max_derivative_order: 2,
            ..Default::default()
        };
        let ad_with_config = SymbolicAD::with_config(config);
        assert!(ad_with_config.is_ok());
    }

    #[test]
    fn test_basic_derivative_rules() {
        let mut ad = SymbolicAD::new().unwrap();
        // d/dx x = 1
        let x = ASTEval::var(0);
        let dx = ad.symbolic_derivative(&x, 0).unwrap();
        match dx {
            ASTRepr::Constant(val) => assert_eq!(val, 1.0),
            _ => panic!("Expected constant 1.0"),
        }
        // d/dx c = 0
        let constant = ASTEval::constant(5.0);
        let dc = ad.symbolic_derivative(&constant, 0).unwrap();
        match dc {
            ASTRepr::Constant(val) => assert_eq!(val, 0.0),
            _ => panic!("Expected constant 0.0"),
        }
        // d/dx y = 0 (derivative with respect to a different variable)
        let y = ASTEval::var(1);
        let dy = ad.symbolic_derivative(&y, 0).unwrap();
        match dy {
            ASTRepr::Constant(val) => assert_eq!(val, 0.0),
            _ => panic!("Expected constant 0.0"),
        }
    }

    #[test]
    fn test_arithmetic_derivative_rules() {
        let mut ad = SymbolicAD::new().unwrap();
        // d/dx (x + 2) = 1 + 0 (unsimplified sum-rule output)
        let expr = ASTEval::add(ASTEval::var(0), ASTEval::constant(2.0));
        let derivative = ad.symbolic_derivative(&expr, 0).unwrap();
        match &derivative {
            ASTRepr::Add(left, right) => match (left.as_ref(), right.as_ref()) {
                // Bind and compare: float literals in patterns trigger the
                // `illegal_floating_point_literal_pattern` future-compat lint.
                (ASTRepr::Constant(l), ASTRepr::Constant(r)) if *l == 1.0 && *r == 0.0 => {}
                _ => panic!("Expected Add(1.0, 0.0), got {derivative:?}"),
            },
            _ => panic!("Expected addition, got {derivative:?}"),
        }
    }

    #[test]
    fn test_product_rule() {
        let mut ad = SymbolicAD::new().unwrap();
        // d/dx (x * x) evaluates to 2x; at x = 3 that is 6.
        let x = ASTEval::var(0);
        let x_squared = ASTEval::mul(x.clone(), x);
        let derivative = ad.symbolic_derivative(&x_squared, 0).unwrap();
        match derivative {
            ASTRepr::Add(_, _) => {
                let result = DirectEval::eval_two_vars(&derivative, 3.0, 0.0);
                assert_eq!(result, 6.0);
            }
            _ => panic!("Expected addition for product rule"),
        }
    }

    #[test]
    fn test_chain_rule() {
        let mut ad = SymbolicAD::new().unwrap();
        // d/dx sin(x) = cos(x) * 1 (unsimplified chain-rule output)
        let sin_x = ASTEval::sin(ASTEval::var(0));
        let derivative = ad.symbolic_derivative(&sin_x, 0).unwrap();
        match &derivative {
            ASTRepr::Mul(left, right) => match (left.as_ref(), right.as_ref()) {
                // Accept either operand order; compare the constant via a
                // guard rather than a float-literal pattern.
                (ASTRepr::Cos(_), ASTRepr::Constant(c)) if *c == 1.0 => {}
                (ASTRepr::Constant(c), ASTRepr::Cos(_)) if *c == 1.0 => {}
                _ => panic!("Expected cos(x) * 1, got {derivative:?}"),
            },
            _ => panic!("Expected multiplication for chain rule"),
        }
    }

    #[test]
    fn test_convenience_functions() {
        // d/dx (2x^2 + 3x + 1) = 4x + 3; at x = 2 that is 11.
        let quadratic = convenience::quadratic(2.0, 3.0, 1.0);
        let grad = convenience::gradient(&quadratic, &["0"]).unwrap();
        assert!(grad.contains_key("0"));
        let derivative = &grad["0"];
        let result_at_2 = DirectEval::eval_two_vars(derivative, 2.0, 0.0);
        assert_eq!(result_at_2, 11.0);
        // f = x^2 + 2xy + y^2: df/dx = 2x + 2y, df/dy = 2x + 2y; both are 6 at (1, 2).
        let bivariate = convenience::bivariate_quadratic(1.0, 2.0, 1.0, 0.0, 0.0, 0.0);
        let grad_biv = convenience::gradient(&bivariate, &["0", "1"]).unwrap();
        assert!(grad_biv.contains_key("0"));
        assert!(grad_biv.contains_key("1"));
        let dx_at_1_2 = DirectEval::eval_two_vars(&grad_biv["0"], 1.0, 2.0);
        let dy_at_1_2 = DirectEval::eval_two_vars(&grad_biv["1"], 1.0, 2.0);
        assert_eq!(dx_at_1_2, 6.0);
        assert_eq!(dy_at_1_2, 6.0);
    }

    #[test]
    fn test_full_pipeline() {
        let mut ad = SymbolicAD::new().unwrap();
        // x*0 + x^2 — the x*0 term should be a target for optimization.
        let expr = ASTEval::add(
            ASTEval::mul(ASTEval::var(0), ASTEval::constant(0.0)),
            ASTEval::pow(ASTEval::var(0), ASTEval::constant(2.0)),
        );
        let result = ad.compute_with_derivatives(&expr).unwrap();
        assert!(result.first_derivatives.contains_key("0"));
        assert!(result.stats.function_operations_before > 0);
        assert!(result.stats.function_operations_after > 0);
        println!(
            "Original operations: {}",
            result.stats.function_operations_before
        );
        println!(
            "Optimized operations: {}",
            result.stats.function_operations_after
        );
        println!(
            "Function optimization ratio: {:.2}",
            result.stats.function_optimization_ratio()
        );
        println!("Total time: {} μs", result.stats.total_time_us());
    }

    #[test]
    fn test_cache_functionality() {
        let mut ad = SymbolicAD::new().unwrap();
        let expr = ASTEval::pow(ASTEval::var(0), ASTEval::constant(2.0));
        // First call populates the cache; the second must not grow it.
        let _deriv1 = ad.symbolic_derivative(&expr, 0).unwrap();
        let (cache_size_1, _) = ad.cache_stats();
        let _deriv2 = ad.symbolic_derivative(&expr, 0).unwrap();
        let (cache_size_2, _) = ad.cache_stats();
        assert!(cache_size_1 > 0);
        assert_eq!(cache_size_1, cache_size_2);
    }
}