use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use evalexpr::{build_operator_tree, Node, Operator};
use crate::backends::vector::Vector;
use crate::builder::build_function;
use crate::convert::build_ast;
use crate::errors::EquationError;
use crate::expr::Expr;
use crate::system::OutputType;
use crate::types::JITFunction;
use crate::EquationSystem;
use colored::Colorize;
use itertools::Itertools;
/// A single equation together with its compiled evaluation function and
/// compiled first- and second-order partial derivatives.
pub struct Equation {
/// Source text the equation was built from.
equation_str: String,
/// Simplified expression tree produced from `equation_str`.
ast: Box<Expr>,
/// Compiled function that evaluates `ast`.
fun: JITFunction,
/// Compiled first-order partial derivatives, keyed by variable name.
derivatives_first_order: HashMap<String, JITFunction>,
/// Compiled second-order partial derivatives; entry `[i][j]` differentiates
/// with respect to `sorted_variables[i]` and then `sorted_variables[j]`.
derivatives_second_order: Vec<Vec<JITFunction>>,
/// Maps each variable name to its index.
var_map: HashMap<String, u32>,
/// Variable names ordered by ascending index in `var_map`.
sorted_variables: Vec<String>,
}
impl std::fmt::Debug for Equation {
    /// Writes a brace-delimited summary: the equation text, the variable map
    /// and the index-ordered variable list, each with a cyan label.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only these three fields are printed; the compiled functions are skipped.
        let Self {
            equation_str,
            var_map,
            sorted_variables,
            ..
        } = self;

        writeln!(f, "{{\n")?;
        writeln!(f, " {}: {}\n", "Equation".cyan(), equation_str)?;
        writeln!(f, " {}: {:?}\n", "Variables".cyan(), var_map)?;
        writeln!(f, " {}: {:?}\n", "Sorted Variables".cyan(), sorted_variables)?;
        writeln!(f, "}}")
    }
}
impl std::fmt::Display for Equation {
    /// Writes the equation text, the variable map and the index-ordered
    /// variable list between braces, each entry prefixed by a cyan label.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let equation_label = "Equation".cyan();
        let variables_label = "Variables".cyan();
        let sorted_label = "Sorted Variables".cyan();

        writeln!(f, "{{\n")?;
        writeln!(f, " {}: {}\n", equation_label, self.equation_str)?;
        writeln!(f, " {}: {:?}\n", variables_label, self.var_map)?;
        writeln!(f, " {}: {:?}\n", sorted_label, self.sorted_variables)?;
        writeln!(f, "}}")
    }
}
impl Equation {
pub fn new(equation_str: String) -> Result<Self, EquationError> {
let node = build_operator_tree(&equation_str)?;
let variables = extract_symbols(&node);
Self::build(&variables, equation_str)
}
pub fn from_var_map(
equation_str: String,
variables: &HashMap<String, u32>,
) -> Result<Self, EquationError> {
Self::build(variables, equation_str)
}
fn build(
variables: &HashMap<String, u32>,
equation_str: String,
) -> Result<Self, EquationError> {
let node = build_operator_tree(&equation_str)?;
let mut non_defined_variables = HashSet::new();
let control_variables = extract_symbols(&node);
for variable in control_variables.keys() {
if !variables.contains_key(variable) {
non_defined_variables.insert(variable.clone());
}
}
if !non_defined_variables.is_empty() {
return Err(EquationError::VariableNotFound(
non_defined_variables
.into_iter()
.collect::<Vec<String>>()
.join(", "),
));
}
let sorted_variables: Vec<String> = variables
.iter()
.sorted_by_key(|(_, &idx)| idx)
.map(|(var, _)| var.clone())
.collect();
let ast = build_ast(&node, variables)?;
let ast = *ast.simplify();
let fun = build_function(ast.clone())?;
let mut derivatives_first_order = HashMap::new();
for variable in sorted_variables.iter() {
let derivative = ast.derivative(variable);
let derivative_func = build_function(*derivative)?;
derivatives_first_order.insert(variable.clone(), derivative_func);
}
let mut derivatives_second_order = Vec::new();
for variable in sorted_variables.iter() {
let mut derivatives_second_order_row = Vec::new();
for variable2 in sorted_variables.iter() {
let derivative = ast.derivative(variable).derivative(variable2);
let derivative_func = build_function(*derivative)?;
derivatives_second_order_row.push(derivative_func);
}
derivatives_second_order.push(derivatives_second_order_row);
}
Ok(Self {
equation_str,
ast: Box::new(ast),
fun,
derivatives_first_order,
derivatives_second_order,
var_map: variables.clone(),
sorted_variables,
})
}
pub fn eval<V: Vector>(&self, values: &V) -> Result<f64, EquationError> {
self.validate_input_length(values.as_slice())?;
Ok((self.fun)(values.as_slice()))
}
pub fn gradient(&self, values: &[f64]) -> Result<Vec<f64>, EquationError> {
self.validate_input_length(values)?;
Ok(self
.sorted_variables
.iter()
.map(|variable| (self.derivatives_first_order[variable])(values))
.collect())
}
pub fn hessian(&self, values: &[f64]) -> Result<Vec<Vec<f64>>, EquationError> {
self.validate_input_length(values)?;
Ok(self
.derivatives_second_order
.iter()
.map(|row| row.iter().map(|func| func(values)).collect())
.collect())
}
pub fn derivative(&self, variable: &str) -> Result<&JITFunction, EquationError> {
self.derivatives_first_order
.get(variable)
.ok_or(EquationError::DerivativeNotFound(variable.to_string()))
}
pub fn derive_wrt(&self, variables: &[&str]) -> Result<JITFunction, EquationError> {
let mut non_defined_variables = HashSet::new();
for variable in variables.iter() {
if !self.sorted_variables.contains(&variable.to_string()) {
non_defined_variables.insert(variable.to_string());
}
}
if !non_defined_variables.is_empty() {
return Err(EquationError::DerivativeNotFound(
non_defined_variables
.into_iter()
.collect::<Vec<String>>()
.join(", "),
));
}
let mut expr = self.ast.clone();
for variable in variables {
expr = expr.derivative(variable);
}
let fun = build_function(*expr)?;
Ok(fun)
}
pub fn derive_wrt_stack(&self, variables: &[&str]) -> Result<EquationSystem, EquationError> {
let mut derivative_asts = Vec::with_capacity(variables.len());
for variable in variables {
derivative_asts.push(*self.ast.derivative(variable));
}
EquationSystem::from_asts(derivative_asts, &self.var_map, OutputType::Vector)
}
pub fn variables(&self) -> &HashMap<String, u32> {
&self.var_map
}
pub fn equation_str(&self) -> &str {
&self.equation_str
}
pub fn fun(&self) -> &JITFunction {
&self.fun
}
pub fn sorted_variables(&self) -> &[String] {
&self.sorted_variables
}
fn validate_input_length(&self, values: &[f64]) -> Result<(), EquationError> {
if values.len() != self.sorted_variables.len() {
return Err(EquationError::InvalidInputLength {
expected: self.sorted_variables.len(),
got: values.len(),
});
}
Ok(())
}
}
/// Collects every variable identifier read anywhere under `node` and assigns
/// each a zero-based index following the alphabetical order of the names.
pub fn extract_symbols(node: &Node) -> HashMap<String, u32> {
    let mut found = HashSet::new();
    extract_symbols_from_node(node, &mut found);

    // Names in a set are unique, so an unstable sort yields the same order.
    let mut names = found.into_iter().collect::<Vec<String>>();
    names.sort_unstable();

    let mut indexed = HashMap::with_capacity(names.len());
    for (position, name) in names.into_iter().enumerate() {
        indexed.insert(name, position as u32);
    }
    indexed
}
/// Returns the alphabetically sorted, de-duplicated list of variable names
/// used across all of `equations`.
///
/// # Panics
///
/// Panics if any of the equations fails to parse.
pub fn extract_all_symbols(equations: &[String]) -> Vec<String> {
    let mut unique: HashSet<String> = HashSet::new();
    for equation in equations {
        let tree: Node = build_operator_tree(equation).unwrap();
        unique.extend(extract_symbols(&tree).into_keys());
    }

    let mut ordered: Vec<String> = unique.into_iter().collect();
    ordered.sort();
    ordered
}
/// Adds the identifier of every variable-read node found in the tree rooted
/// at `node` to `symbols`.
fn extract_symbols_from_node(node: &Node, symbols: &mut HashSet<String>) {
    // Explicit work stack instead of recursion; visiting order is irrelevant
    // because the results are gathered in a set.
    let mut pending = vec![node];
    while let Some(current) = pending.pop() {
        if let Operator::VariableIdentifierRead { identifier } = current.operator() {
            symbols.insert(identifier.to_string());
        } else {
            pending.extend(current.children());
        }
    }
}
impl Clone for Equation {
    /// Clones every field; the evaluation function is shared through its
    /// `Arc` handle.
    fn clone(&self) -> Self {
        // Destructuring without `..` makes a newly added field a compile error
        // here until it is cloned as well.
        let Self {
            equation_str,
            ast,
            fun,
            derivatives_first_order,
            derivatives_second_order,
            var_map,
            sorted_variables,
        } = self;

        Self {
            equation_str: equation_str.clone(),
            ast: ast.clone(),
            fun: Arc::clone(fun),
            derivatives_first_order: derivatives_first_order.clone(),
            derivatives_second_order: derivatives_second_order.clone(),
            var_map: var_map.clone(),
            sorted_variables: sorted_variables.clone(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Evaluating `2*x + y^2` at x = 1, y = 2 gives 6.
    #[test]
    fn test_equation() {
        let eq = Equation::new("2*x + y^2".to_string()).unwrap();
        let result = eq.eval(&[1.0, 2.0]).unwrap();
        assert_eq!(result, 6.0);
    }

    /// The gradient is returned in variable order: [d/dx, d/dy].
    #[test]
    fn test_gradient() {
        let eq = Equation::new("2*x + y^2".to_string()).unwrap();
        let gradient = eq.gradient(&[1.0, 2.0]).unwrap();
        assert_eq!(gradient, vec![2.0, 4.0]);
    }

    /// Only d²/dy² is non-zero for `2*x + y^2`.
    #[test]
    fn test_hessian() {
        let eq = Equation::new("2*x + y^2".to_string()).unwrap();
        let hessian = eq.hessian(&[1.0, 2.0]).unwrap();
        assert_eq!(hessian, vec![vec![0.0, 0.0], vec![0.0, 2.0]]);
    }

    /// A single first-order derivative can be fetched by variable name.
    #[test]
    fn test_derivative() {
        let eq = Equation::new("2*x + y^2".to_string()).unwrap();
        let derivative = eq.derivative("x").unwrap();
        let values = vec![1.0, 2.0];
        let result = derivative(&values);
        assert_eq!(result, 2.0);
    }

    /// d²/dxdy of `x^2 * y^2` is `4*x*y`, i.e. 24 at (2, 3).
    #[test]
    fn test_derive_wrt() {
        let eq = Equation::new("x^2 * y^2".to_string()).unwrap();
        let dxdy = eq.derive_wrt(&["x", "y"]).unwrap();
        let values = vec![2.0, 3.0];
        let result = dxdy(&values);
        assert_eq!(result, 24.0);
    }

    /// An unknown variable must be reported as `DerivativeNotFound`, naming
    /// exactly the offending variable.
    #[test]
    fn test_derive_wrt_invalid() {
        let eq = Equation::new("x^2 * y^2".to_string()).unwrap();
        match eq.derive_wrt(&["x", "z"]) {
            Err(EquationError::DerivativeNotFound(names)) => assert_eq!(names, "z"),
            _ => panic!("expected EquationError::DerivativeNotFound"),
        }
    }

    /// Too few input values must be reported as `InvalidInputLength`.
    #[test]
    fn test_eval_invalid() {
        let eq = Equation::new("2*x + y^2".to_string()).unwrap();
        match eq.eval(&[1.0]) {
            Err(EquationError::InvalidInputLength { expected, got }) => {
                assert_eq!(expected, 2);
                assert_eq!(got, 1);
            }
            _ => panic!("expected EquationError::InvalidInputLength"),
        }
    }

    /// With a custom map (y at index 0, x at index 1) the inputs are read in
    /// that order.
    #[test]
    fn test_from_var_map() {
        let eq = Equation::from_var_map(
            "2*x + y^2".to_string(),
            &HashMap::from([("x".to_string(), 1), ("y".to_string(), 0)]),
        )
        .unwrap();
        let result = eq.eval(&[2.0, 1.0]).unwrap();
        assert_eq!(result, 6.0);
    }

    /// A variable used by the equation but absent from the map must be
    /// reported as `VariableNotFound`.
    #[test]
    fn test_from_var_map_invalid() {
        let result = Equation::from_var_map(
            "2*x + y^2".to_string(),
            &HashMap::from([("x".to_string(), 0), ("z".to_string(), 1)]),
        );
        match result {
            Err(EquationError::VariableNotFound(names)) => assert_eq!(names, "y"),
            _ => panic!("expected EquationError::VariableNotFound"),
        }
    }

    /// The stacked derivatives follow the order of the requested variables.
    #[test]
    fn test_derive_wrt_stack() {
        let eq = Equation::new("x^2 + 2*x*y + y^2 + z^3".to_string()).unwrap();
        let values = vec![2.0, 3.0, 2.0];

        let derivatives = eq.derive_wrt_stack(&["x", "z"]).unwrap();
        let mut results = vec![0.0, 0.0];
        derivatives.eval_into(&values, &mut results).unwrap();
        assert_eq!(results, vec![10.0, 12.0]);

        let derivatives = eq.derive_wrt_stack(&["z", "x"]).unwrap();
        let mut results = vec![0.0, 0.0];
        derivatives.eval_into(&values, &mut results).unwrap();
        assert_eq!(results, vec![12.0, 10.0]);
    }

    /// `eval` accepts every supported vector backend.
    #[test]
    fn test_all_backends() {
        use nalgebra::DVector;
        use ndarray::Array1;

        let eq = Equation::new("2*x + y^2".to_string()).unwrap();
        let expected = 6.0;

        let vec_input = vec![1.0, 2.0];
        assert_eq!(eq.eval(&vec_input).unwrap(), expected);

        let nalgebra_input = DVector::from_vec(vec![1.0, 2.0]);
        assert_eq!(eq.eval(&nalgebra_input).unwrap(), expected);

        let ndarray_input = Array1::from_vec(vec![1.0, 2.0]);
        assert_eq!(eq.eval(&ndarray_input).unwrap(), expected);
    }

    /// Symbols from several equations are merged, de-duplicated and sorted.
    #[test]
    fn test_extract_all_symbols() {
        let equations = vec![
            "2*x + y".to_string(),
            "z + x^2".to_string(),
            "y*z".to_string(),
        ];
        let variables = extract_all_symbols(&equations);
        assert_eq!(
            variables,
            vec!["x".to_string(), "y".to_string(), "z".to_string()]
        );
    }

    /// Both `Debug` and `Display` output include the label and the equation text.
    #[test]
    fn test_debug_and_display_formatting() {
        let eq = Equation::new("2*x + y^2".to_string()).unwrap();

        let debug_output = format!("{eq:?}");
        assert!(debug_output.contains("Equation"));
        assert!(debug_output.contains("2*x + y^2"));

        let display_output = format!("{eq}");
        assert!(display_output.contains("Equation"));
        assert!(display_output.contains("2*x + y^2"));
    }

    /// A clone evaluates identically to the original.
    #[test]
    fn test_equation_clone() {
        let eq = Equation::new("2*x + y^2".to_string()).unwrap();
        let cloned = eq.clone();
        let values = vec![1.0, 2.0];
        assert_eq!(eq.eval(&values).unwrap(), cloned.eval(&values).unwrap());
        assert_eq!(
            eq.gradient(&values).unwrap(),
            cloned.gradient(&values).unwrap()
        );
    }

    /// A syntactically broken expression yields an error instead of a panic.
    #[test]
    fn test_invalid_expression() {
        let result = Equation::new("2*x + )".to_string());
        assert!(result.is_err());
    }

    /// The accessors expose the stored text and variable tables.
    #[test]
    fn test_accessor_methods() {
        let eq = Equation::new("2*x + y^2".to_string()).unwrap();
        assert_eq!(eq.equation_str(), "2*x + y^2");
        assert!(!eq.variables().is_empty());
        assert!(!eq.sorted_variables().is_empty());
    }

    /// `sorted_variables` follows the indices of the supplied map, not the
    /// alphabetical order of the names.
    #[test]
    fn test_variable_ordering() {
        let mut vars = HashMap::new();
        vars.insert("z".to_string(), 0);
        vars.insert("y".to_string(), 1);
        vars.insert("x".to_string(), 2);

        let eq = Equation::from_var_map("x + y + z".to_string(), &vars).unwrap();
        assert_eq!(eq.sorted_variables(), &["z", "y", "x"]);

        let result = eq.eval(&[1.0, 2.0, 3.0]).unwrap();
        assert_eq!(result, 6.0);
    }
}