use std::collections::HashMap;
use std::sync::Arc;
use serde::Deserialize;
use serde::Serialize;
use crate::symbolic::calculus::differentiate;
use crate::symbolic::core::Expr;
use crate::symbolic::matrix::eigen_decomposition;
use crate::symbolic::simplify_dag::simplify;
use crate::symbolic::solve::solve_system;
/// Classification of a critical point by the second-derivative (Hessian
/// eigenvalue) test performed in `find_extrema`.
///
/// The enum is fieldless, so it is `Copy`; a classification can be read out of
/// a `CriticalPoint` without moving it. Variant order is part of the
/// serialized form for index-based serde formats, so do not reorder.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ExtremumType {
    /// Every Hessian eigenvalue at the point is strictly positive.
    LocalMin,
    /// Every Hessian eigenvalue at the point is strictly negative.
    LocalMax,
    /// The Hessian has at least one strictly positive and one strictly
    /// negative eigenvalue.
    SaddlePoint,
    /// The test is inconclusive, e.g. a zero eigenvalue or one that could not
    /// be evaluated to a number.
    Unknown,
}
/// A stationary point found by `find_extrema`, paired with the outcome of the
/// Hessian test at that point.
#[derive(Debug)]
#[derive(Serialize, Deserialize)]
pub struct CriticalPoint {
    /// Coordinates of the point: each solved variable expression mapped to its
    /// value, as produced by `solve_system` (presumably `Expr::Variable` keys).
    pub point: HashMap<Expr, Expr>,
    /// How the second-derivative test classified `point`.
    pub point_type: ExtremumType,
}
/// Finds a critical point of `f` over `vars` and classifies it with the
/// Hessian eigenvalue test.
///
/// The stationarity equations `∂f/∂v = 0` (one per `v` in `vars`) are passed to
/// [`solve_system`]. If it yields a solution, the symbolic Hessian from
/// [`hessian_matrix`] is evaluated at that solution. Its eigenvalues then decide
/// the [`ExtremumType`]:
///
/// * all `> 0` → [`ExtremumType::LocalMin`]
/// * all `< 0` → [`ExtremumType::LocalMax`]
/// * some `> 0` and some `< 0` → [`ExtremumType::SaddlePoint`]
/// * otherwise (a zero or non-numeric eigenvalue, or no eigenvalues at all)
///   → [`ExtremumType::Unknown`]
///
/// At most one critical point is returned, namely the single solution given by
/// `solve_system`. An empty vector means `solve_system` returned `None`.
///
/// # Errors
///
/// Propagates the error from [`eigen_decomposition`]. Returns an error if the
/// eigenvalues it produces are not an `Expr::Matrix`.
pub fn find_extrema(
    f: &Expr,
    vars: &[&str],
) -> Result<Vec<CriticalPoint>, String> {
    // Stationarity conditions: ∂f/∂v = 0 for every variable v.
    let grad_eqs: Vec<Expr> = vars
        .iter()
        .map(|&var| Expr::Eq(Arc::new(differentiate(f, var)), Arc::new(Expr::Constant(0.0))))
        .collect();
    let critical_points_sol = match solve_system(&grad_eqs, vars) {
        | Some(sol) => sol,
        | None => return Ok(vec![]),
    };
    let crit_point_map: HashMap<Expr, Expr> = critical_points_sol.into_iter().collect();
    // Evaluate the symbolic Hessian at the critical point, one variable at a time.
    // NOTE(review): substitution is keyed on `var.to_string()`, which assumes a
    // solved variable's `Display` form is its bare name — confirm.
    let hessian_at_point = crit_point_map
        .iter()
        .fold(hessian_matrix(f, vars), |hessian, (var, val)| {
            crate::symbolic::calculus::substitute(&hessian, &var.to_string(), val)
        });
    let (eigenvalues_expr, _) = eigen_decomposition(&hessian_at_point)?;
    let eig_rows = match eigenvalues_expr {
        | Expr::Matrix(rows) => rows,
        | _ => return Err("Could not determine eigenvalues of the Hessian.".to_string()),
    };
    // Eigenvalues that do not reduce to a number become NaN. NaN fails every
    // comparison below, so it can never yield a min/max verdict.
    let eigenvalues: Vec<f64> = eig_rows
        .iter()
        .flatten()
        .map(|e| evaluate_constant_expr(e).unwrap_or(f64::NAN))
        .collect();
    let point_type = if eigenvalues.is_empty() {
        // `all` is vacuously true on an empty sequence. Without this guard an
        // empty Hessian (no variables) would be reported as a local minimum.
        ExtremumType::Unknown
    } else if eigenvalues.iter().all(|&v| v > 0.0) {
        ExtremumType::LocalMin
    } else if eigenvalues.iter().all(|&v| v < 0.0) {
        ExtremumType::LocalMax
    } else if eigenvalues.iter().any(|&v| v > 0.0) && eigenvalues.iter().any(|&v| v < 0.0) {
        ExtremumType::SaddlePoint
    } else {
        // A zero (or NaN) eigenvalue with no sign change: the test is inconclusive.
        ExtremumType::Unknown
    };
    Ok(vec![CriticalPoint {
        point: crit_point_map,
        point_type,
    }])
}
/// Numerically evaluates `expr` if it is built solely from constants.
///
/// Handled nodes:
/// * `Constant`, `BigInt` and `Rational` leaves;
/// * `Add`, `Sub`, `Mul`, `Div` and `Neg`;
/// * `Power` (via `f64::powf`) and `Sqrt`;
/// * `Dag` nodes, which are first converted back with `to_expr`.
///
/// Any other node, such as a variable, yields `None`. So does a failed
/// `to_f64` or `to_expr` conversion anywhere in the tree. Arithmetic follows
/// plain `f64` semantics, so division by zero gives an infinity or NaN rather
/// than `None`.
pub(crate) fn evaluate_constant_expr(expr: &Expr) -> Option<f64> {
    use num_traits::ToPrimitive;
    let value = match expr {
        | Expr::Constant(constant) => *constant,
        | Expr::BigInt(big) => big.to_f64()?,
        | Expr::Rational(ratio) => ratio.to_f64()?,
        | Expr::Add(lhs, rhs) => {
            let (l, r) = (evaluate_constant_expr(lhs)?, evaluate_constant_expr(rhs)?);
            l + r
        },
        | Expr::Sub(lhs, rhs) => {
            let (l, r) = (evaluate_constant_expr(lhs)?, evaluate_constant_expr(rhs)?);
            l - r
        },
        | Expr::Mul(lhs, rhs) => {
            let (l, r) = (evaluate_constant_expr(lhs)?, evaluate_constant_expr(rhs)?);
            l * r
        },
        | Expr::Div(lhs, rhs) => {
            let (l, r) = (evaluate_constant_expr(lhs)?, evaluate_constant_expr(rhs)?);
            l / r
        },
        | Expr::Neg(inner) => -evaluate_constant_expr(inner)?,
        | Expr::Power(base, exponent) => {
            let (b, e) = (evaluate_constant_expr(base)?, evaluate_constant_expr(exponent)?);
            b.powf(e)
        },
        | Expr::Sqrt(inner) => evaluate_constant_expr(inner)?.sqrt(),
        | Expr::Dag(node) => {
            let expanded = node.to_expr().ok()?;
            evaluate_constant_expr(&expanded)?
        },
        | _ => return None,
    };
    Some(value)
}
/// Builds the symbolic Hessian of `f` with respect to `vars`.
///
/// Entry `[i][j]` is `∂/∂vars[j] (∂f/∂vars[i])`, passed through `simplify`.
/// The result is an `n × n` `Expr::Matrix` with `n = vars.len()`. An empty
/// `vars` produces an empty matrix.
#[must_use]
pub fn hessian_matrix(
    f: &Expr,
    vars: &[&str],
) -> Expr {
    let matrix: Vec<Vec<Expr>> = vars
        .iter()
        .map(|&var_i| {
            // ∂f/∂x_i is shared by the whole row. Compute it once per row
            // instead of once per entry (n derivatives rather than n²).
            let df_dxi = differentiate(f, var_i);
            vars.iter()
                .map(|&var_j| simplify(&differentiate(&df_dxi, var_j)))
                .collect::<Vec<Expr>>()
        })
        .collect();
    Expr::Matrix(matrix)
}
/// Finds a stationary point of `f` over `vars` subject to equality
/// constraints, using Lagrange multipliers.
///
/// One multiplier is introduced per entry of `constraints`. The Lagrangian is
/// `L = f - Σ λ_k · g_k`, simplified after each term is subtracted. The system
/// `∂L/∂v = 0` for every `v` in `vars` and every multiplier is then handed to
/// [`solve_system`]. Each constraint expression `g_k` is multiplied directly.
/// Presumably it denotes the left-hand side of `g_k = 0` rather than an
/// `Expr::Eq`; confirm with callers.
///
/// Multipliers are named `lambda_0`, `lambda_1`, …. If such a name is already
/// one of `vars`, underscores are appended until it is unique, so a multiplier
/// never aliases a caller variable. Variables that appear in `f` or the
/// constraints but are *not* listed in `vars` are not checked.
///
/// Returns at most one solution map, which presumably includes the multiplier
/// values, or an empty vector when `solve_system` returns `None`. This
/// implementation never returns `Err`.
pub fn find_constrained_extrema(
    f: &Expr,
    constraints: &[Expr],
    vars: &[&str],
) -> Result<Vec<HashMap<Expr, Expr>>, String> {
    let lambda_vars: Vec<String> = (0..constraints.len())
        .map(|i| {
            let mut name = format!("lambda_{i}");
            // Avoid conflating a multiplier with a user variable of the same name.
            while vars.contains(&name.as_str()) {
                name.push('_');
            }
            name
        })
        .collect();
    let lagrangian = constraints
        .iter()
        .zip(&lambda_vars)
        .fold(f.clone(), |acc, (g, lambda)| {
            let term = Expr::new_mul(Expr::Variable(lambda.clone()), g.clone());
            simplify(&Expr::new_sub(acc, term))
        });
    // Solve for the original variables and all multipliers together.
    let all_vars: Vec<&str> = vars
        .iter()
        .copied()
        .chain(lambda_vars.iter().map(String::as_str))
        .collect();
    let grad_eqs: Vec<Expr> = all_vars
        .iter()
        .map(|&var| {
            let deriv = differentiate(&lagrangian, var);
            Expr::Eq(Arc::new(deriv), Arc::new(Expr::Constant(0.0)))
        })
        .collect();
    match solve_system(&grad_eqs, &all_vars) {
        | Some(solution) => Ok(vec![solution.into_iter().collect()]),
        | None => Ok(vec![]),
    }
}