use crate::backends::vector::Vector;
use crate::builder::build_combined_function;
use crate::convert::build_ast;
use crate::equation::{extract_all_symbols, extract_symbols};
use crate::errors::EquationError;
use crate::expr::Expr;
use crate::prelude::Matrix;
use crate::types::CombinedJITFunction;
use evalexpr::build_operator_tree;
use itertools::Itertools;
use rayon::prelude::*;
use std::collections::HashMap;
use std::sync::Arc;
/// A set of equations compiled together into one JIT function, along with one compiled
/// function per variable that evaluates the partial derivatives of every equation.
pub struct EquationSystem {
    /// Source text of each equation: the original strings for `new`/`from_var_map`, or the
    /// `Display` rendering of each AST for systems built from ASTs.
    pub equations: Vec<String>,
    /// Expression trees backing `equations`, one per output slot.
    pub asts: Vec<Expr>,
    /// Maps each variable name to the index of the input slot that feeds it.
    pub variable_map: HashMap<String, u32>,
    /// Names of every variable in `variable_map`, in the system's canonical order. Its
    /// length is the required input length, and it is the default Jacobian column order.
    pub sorted_variables: Vec<String>,
    /// Compiled function evaluating all equations at once, invoked as `f(inputs, outputs)`.
    pub combined_fun: CombinedJITFunction,
    /// For each variable name, a compiled function that writes d(equation)/d(variable)
    /// for every equation into its output slice.
    pub partial_derivatives: HashMap<String, CombinedJITFunction>,
    /// Whether the outputs are a flat vector or a `(rows, cols)` matrix.
    output_type: OutputType,
}
impl EquationSystem {
    /// Builds a system from expression strings, discovering the variables automatically.
    ///
    /// Each variable receives the input index of its position in the list returned by
    /// `extract_all_symbols`, so `inputs[i]` feeds the `i`-th discovered variable.
    ///
    /// # Errors
    /// Returns an error if an expression fails to parse, cannot be converted into an AST,
    /// or cannot be JIT-compiled.
    pub fn new(expressions: Vec<String>) -> Result<Self, EquationError> {
        let sorted_variables = extract_all_symbols(&expressions);
        let variable_map: HashMap<String, u32> = sorted_variables
            .iter()
            .enumerate()
            .map(|(i, v)| (v.clone(), i as u32))
            .collect();
        let asts = Self::create_asts(&expressions, &variable_map)?;
        Self::build(asts, expressions, variable_map, OutputType::Vector)
    }

    /// Builds a system using a caller-supplied mapping from variable name to input index.
    ///
    /// # Errors
    /// Returns [`EquationError::VariableNotFound`] if any expression references a variable
    /// missing from `variable_map`, and propagates parse/compile errors.
    pub fn from_var_map(
        expressions: Vec<String>,
        variable_map: &HashMap<String, u32>,
    ) -> Result<Self, EquationError> {
        let asts = Self::create_asts(&expressions, variable_map)?;
        Self::build(asts, expressions, variable_map.clone(), OutputType::Vector)
    }

    /// Builds a system directly from already-constructed ASTs (e.g. derivatives).
    ///
    /// The stored equation strings are the `Display` rendering of each AST.
    pub(crate) fn from_asts(
        asts: Vec<Expr>,
        variable_map: &HashMap<String, u32>,
        output_type: OutputType,
    ) -> Result<Self, EquationError> {
        let expressions = asts.iter().map(|ast| ast.to_string()).collect();
        Self::build(asts, expressions, variable_map.clone(), output_type)
    }

    /// Compiles the combined function and one partial-derivative function per variable.
    fn build(
        asts: Vec<Expr>,
        equations: Vec<String>,
        variable_map: HashMap<String, u32>,
        output_type: OutputType,
    ) -> Result<Self, EquationError> {
        let combined_fun = build_combined_function(asts.clone(), equations.len())?;

        // Order variables by their input index so that `sorted_variables[i]` names the
        // variable fed by `inputs[i]`. This keeps the default Jacobian columns aligned with
        // the input layout, including for custom maps passed to `from_var_map`.
        let sorted_variables: Vec<String> = variable_map
            .iter()
            .sorted_by_key(|(_, idx)| *idx)
            .map(|(var, _)| var.clone())
            .collect();

        let mut partial_derivatives = HashMap::with_capacity(sorted_variables.len());
        for var in &sorted_variables {
            let derivative_asts: Vec<Expr> =
                asts.iter().map(|ast| *ast.derivative(var)).collect();
            let jacobian_fun = build_combined_function(derivative_asts, asts.len())?;
            partial_derivatives.insert(var.clone(), jacobian_fun);
        }

        Ok(Self {
            equations,
            asts,
            variable_map,
            sorted_variables,
            combined_fun,
            partial_derivatives,
            output_type,
        })
    }

    /// Parses each expression, checks that all its variables are known, and returns the
    /// simplified ASTs.
    ///
    /// # Errors
    /// Returns [`EquationError::VariableNotFound`] for an unknown variable and propagates
    /// parse / AST-conversion errors.
    fn create_asts(
        expressions: &[String],
        variable_map: &HashMap<String, u32>,
    ) -> Result<Vec<Expr>, EquationError> {
        expressions
            .iter()
            .map(|expr| {
                let node = build_operator_tree(expr)?;
                let expr_vars = extract_symbols(&node);
                for var in expr_vars.keys() {
                    if !variable_map.contains_key(var) {
                        return Err(EquationError::VariableNotFound(var.clone()));
                    }
                }
                let ast = build_ast(&node, variable_map)?;
                Ok(*ast.simplify())
            })
            .collect::<Result<Vec<_>, EquationError>>()
    }

    /// Evaluates all equations for `inputs`, writing one value per equation into `results`.
    ///
    /// # Errors
    /// Returns [`EquationError::InvalidInputLength`] if `inputs` does not have one entry per
    /// variable, or if `results` does not have one slot per equation.
    pub fn eval_into<V: Vector, R: Vector>(
        &self,
        inputs: &V,
        results: &mut R,
    ) -> Result<(), EquationError> {
        self.validate_input_length(inputs.as_slice())?;
        if results.len() != self.equations.len() {
            return Err(EquationError::InvalidInputLength {
                expected: self.equations.len(),
                got: results.len(),
            });
        }
        (self.combined_fun)(inputs.as_slice(), results.as_mut_slice());
        Ok(())
    }

    /// Evaluates all equations for `inputs` into a freshly allocated vector.
    ///
    /// # Errors
    /// Same as [`Self::eval_into`].
    pub fn eval<V: Vector>(&self, inputs: &V) -> Result<V, EquationError> {
        let mut results = V::zeros(self.equations.len());
        self.eval_into(inputs, &mut results)?;
        Ok(results)
    }

    /// Evaluates a matrix-output system (e.g. from [`Self::jacobian_wrt`]) into `results`.
    ///
    /// Outputs are written through `Matrix::flat_mut_slice` in the order of the system's
    /// equations. For `jacobian_wrt` systems that is row-major. How this maps onto `R`'s
    /// storage depends on `R`'s `Matrix` impl (TODO confirm for column-major types).
    ///
    /// # Errors
    /// - [`EquationError::MatrixOutputRequired`] if this is a vector-output system.
    /// - [`EquationError::InvalidInputLength`] if `inputs` has the wrong length, or if
    ///   `results` does not hold exactly `rows * cols` elements.
    pub fn eval_into_matrix<V: Vector, R: Matrix>(
        &self,
        inputs: &V,
        results: &mut R,
    ) -> Result<(), EquationError> {
        let (n_rows, n_cols) = match self.output_type {
            OutputType::Vector => return Err(EquationError::MatrixOutputRequired),
            OutputType::Matrix(n_rows, n_cols) => (n_rows, n_cols),
        };
        self.validate_input_length(inputs.as_slice())?;
        // The JIT function writes every output slot without bounds knowledge, so the
        // destination buffer must be exactly the advertised size.
        let out = results.flat_mut_slice();
        let expected = n_rows * n_cols;
        if out.len() != expected {
            return Err(EquationError::InvalidInputLength {
                expected,
                got: out.len(),
            });
        }
        (self.combined_fun)(inputs.as_slice(), out);
        Ok(())
    }

    /// Evaluates a matrix-output system into a freshly allocated `rows x cols` matrix.
    ///
    /// # Errors
    /// Same as [`Self::eval_into_matrix`].
    pub fn eval_matrix<V: Vector, R: Matrix>(&self, inputs: &V) -> Result<R, EquationError> {
        match self.output_type {
            OutputType::Vector => Err(EquationError::MatrixOutputRequired),
            OutputType::Matrix(n_rows, n_cols) => {
                let mut results = R::zeros(n_rows, n_cols);
                self.eval_into_matrix(inputs, &mut results)?;
                Ok(results)
            }
        }
    }

    /// Evaluates the system for many input sets in parallel (rayon).
    ///
    /// Results are returned in the same order as `input_sets`.
    ///
    /// # Errors
    /// Returns [`EquationError::InvalidInputLength`] if any input set has the wrong length.
    /// All sets are validated before any evaluation starts.
    pub fn eval_parallel<V: Vector + Send + Sync>(
        &self,
        input_sets: &[V],
    ) -> Result<Vec<V>, EquationError> {
        // Validate up front so the JIT function never reads past the end of a short input.
        for inputs in input_sets {
            self.validate_input_length(inputs.as_slice())?;
        }
        let num_threads = std::thread::available_parallelism()
            .map(|n| n.get())
            .unwrap_or(8);
        let chunk_size = (input_sets.len() / (num_threads * 4)).max(1);
        let n_equations = self.equations.len();
        // The compiled function is behind an `Arc` and shared read-only by all workers, so
        // there is no need to clone the whole system per thread.
        let fun = &self.combined_fun;
        Ok(input_sets
            .par_chunks(chunk_size)
            .map(|chunk| {
                chunk
                    .iter()
                    .map(|inputs| {
                        let mut results = V::zeros(n_equations);
                        fun(inputs.as_slice(), results.as_mut_slice());
                        results
                    })
                    .collect::<Vec<_>>()
            })
            .flatten()
            .collect())
    }

    /// Returns the partial derivative of every equation with respect to `variable`,
    /// evaluated at `inputs`.
    ///
    /// # Errors
    /// [`EquationError::InvalidInputLength`] for a wrong-length `inputs`, or
    /// [`EquationError::VariableNotFound`] for an unknown `variable`.
    pub fn gradient(&self, inputs: &[f64], variable: &str) -> Result<Vec<f64>, EquationError> {
        self.validate_input_length(inputs)?;
        let fun = self
            .partial_derivatives
            .get(variable)
            .ok_or_else(|| EquationError::VariableNotFound(variable.to_string()))?;
        let mut results = vec![0.0; self.equations.len()];
        fun(inputs, &mut results);
        Ok(results)
    }

    /// Evaluates the Jacobian at `inputs`, returning one row per equation.
    ///
    /// Columns follow `variables` if given, otherwise [`Self::sorted_variables`].
    ///
    /// # Errors
    /// [`EquationError::InvalidInputLength`] for a wrong-length `inputs`, or
    /// [`EquationError::VariableNotFound`] if `variables` names an unknown variable.
    pub fn eval_jacobian(
        &self,
        inputs: &[f64],
        variables: Option<&[String]>,
    ) -> Result<Vec<Vec<f64>>, EquationError> {
        self.validate_input_length(inputs)?;
        let columns = variables.unwrap_or(&self.sorted_variables);
        let n_equations = self.equations.len();
        let mut results: Vec<Vec<f64>> = (0..n_equations)
            .map(|_| Vec::with_capacity(columns.len()))
            .collect();
        // One scratch buffer, fully overwritten by each derivative function call.
        let mut derivatives = vec![0.0; n_equations];
        for var in columns {
            let fun = self
                .partial_derivatives
                .get(var)
                .ok_or_else(|| EquationError::VariableNotFound(var.clone()))?;
            fun(inputs, &mut derivatives);
            for (row, &value) in results.iter_mut().zip(&derivatives) {
                row.push(value);
            }
        }
        Ok(results)
    }

    /// Builds a matrix-output system computing the Jacobian with respect to `variables`.
    ///
    /// Output `i * variables.len() + j` is the derivative of equation `i` with respect to
    /// `variables[j]`. The result has `num_equations()` rows and `variables.len()` columns.
    ///
    /// # Errors
    /// [`EquationError::VariableNotFound`] if any entry of `variables` is unknown.
    pub fn jacobian_wrt(&self, variables: &[&str]) -> Result<EquationSystem, EquationError> {
        if let Some(missing) = variables
            .iter()
            .find(|var| !self.variable_map.contains_key(**var))
        {
            return Err(EquationError::VariableNotFound(missing.to_string()));
        }
        let asts: Vec<Expr> = self
            .asts
            .iter()
            .flat_map(|ast| variables.iter().map(move |var| *ast.derivative(var)))
            .collect();
        let output_type = OutputType::Matrix(self.num_equations(), variables.len());
        EquationSystem::from_asts(asts, &self.variable_map, output_type)
    }

    /// Builds a vector-output system by differentiating every equation successively with
    /// respect to each entry of `variables`. For example, `["x", "y"]` differentiates by
    /// `x` and then by `y`.
    ///
    /// # Errors
    /// [`EquationError::VariableNotFound`] if any entry of `variables` is unknown.
    pub fn derive_wrt(&self, variables: &[&str]) -> Result<EquationSystem, EquationError> {
        if let Some(missing) = variables
            .iter()
            .find(|var| !self.variable_map.contains_key(**var))
        {
            return Err(EquationError::VariableNotFound(missing.to_string()));
        }
        let mut derivative_asts = self.asts.clone();
        for var in variables {
            derivative_asts = derivative_asts
                .into_iter()
                .map(|ast| *ast.derivative(var))
                .collect();
        }
        EquationSystem::from_asts(derivative_asts, &self.variable_map, OutputType::Vector)
    }

    /// Checks that `(n_rows, n_cols)` matches this matrix-output system's shape.
    ///
    /// # Errors
    /// [`EquationError::MatrixOutputRequired`] for a vector-output system, or
    /// [`EquationError::InvalidMatrixDimensions`] on a shape mismatch.
    pub fn validate_matrix_dimensions(
        &self,
        n_rows: usize,
        n_cols: usize,
    ) -> Result<(), EquationError> {
        match self.output_type {
            OutputType::Vector => Err(EquationError::MatrixOutputRequired),
            OutputType::Matrix(expected_rows, expected_cols)
                if (n_rows, n_cols) != (expected_rows, expected_cols) =>
            {
                Err(EquationError::InvalidMatrixDimensions {
                    expected_rows,
                    expected_cols,
                    got_rows: n_rows,
                    got_cols: n_cols,
                })
            }
            OutputType::Matrix(..) => Ok(()),
        }
    }

    /// Variable names in the system's canonical order (ordered by input index).
    pub fn sorted_variables(&self) -> &[String] {
        &self.sorted_variables
    }

    /// Mapping from variable name to input index.
    pub fn variables(&self) -> &HashMap<String, u32> {
        &self.variable_map
    }

    /// Source text of each equation.
    pub fn equations(&self) -> &[String] {
        &self.equations
    }

    /// The compiled function evaluating all equations, called as `f(inputs, outputs)`.
    pub fn fun(&self) -> &CombinedJITFunction {
        &self.combined_fun
    }

    /// Compiled partial-derivative functions, keyed by variable name.
    pub fn jacobian_funs(&self) -> &HashMap<String, CombinedJITFunction> {
        &self.partial_derivatives
    }

    /// Compiled partial-derivative function for `variable`.
    ///
    /// # Panics
    /// Panics if `variable` is not one of the system's variables. Use
    /// `jacobian_funs().get(..)` for a non-panicking lookup.
    pub fn gradient_fun(&self, variable: &str) -> &CombinedJITFunction {
        self.partial_derivatives
            .get(variable)
            .unwrap_or_else(|| panic!("gradient_fun: unknown variable `{}`", variable))
    }

    /// Number of equations (output slots) in the system.
    pub fn num_equations(&self) -> usize {
        self.equations.len()
    }

    /// Ensures `inputs` has exactly one value per variable.
    fn validate_input_length(&self, inputs: &[f64]) -> Result<(), EquationError> {
        if inputs.len() != self.sorted_variables.len() {
            return Err(EquationError::InvalidInputLength {
                expected: self.sorted_variables.len(),
                got: inputs.len(),
            });
        }
        Ok(())
    }
}
impl Clone for EquationSystem {
    /// Copies the equation data. The compiled functions are `Arc`s, so clones share them
    /// instead of recompiling.
    fn clone(&self) -> Self {
        // Destructure so that adding a field to `EquationSystem` forces this impl to be
        // revisited (the pattern stops compiling until the new field is handled).
        let Self {
            equations,
            asts,
            variable_map,
            sorted_variables,
            combined_fun,
            partial_derivatives,
            output_type,
        } = self;
        Self {
            equations: equations.clone(),
            asts: asts.clone(),
            variable_map: variable_map.clone(),
            sorted_variables: sorted_variables.clone(),
            combined_fun: Arc::clone(combined_fun),
            partial_derivatives: partial_derivatives.clone(),
            output_type: *output_type,
        }
    }
}
/// Shape of the values written by an `EquationSystem`'s combined function.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum OutputType {
    /// One output per equation, written as a flat vector.
    Vector,
    /// Outputs that form a matrix of `(rows, cols)`. The product must equal the number of
    /// equations in the system.
    Matrix(usize, usize),
}
#[cfg(test)]
mod tests {
    //! Behavioral tests for `EquationSystem`: evaluation, derivatives, Jacobians,
    //! matrix output, parallel evaluation and accessors.
    use super::*;
    use nalgebra::DVector;
    use ndarray::{Array1, Array2};

    #[test]
    fn test_system_with_different_variables() -> Result<(), Box<dyn std::error::Error>> {
        let expressions = vec![
            "2*x + y".to_string(),
            "z^2".to_string(),
            "x + y + z".to_string(),
        ];
        let system = EquationSystem::new(expressions)?;
        assert_eq!(system.sorted_variables, &["x", "y", "z"]);
        let results = system.eval(&[1.0, 2.0, 3.0])?;
        assert_eq!(results.as_slice(), vec![4.0, 9.0, 6.0]);
        Ok(())
    }

    #[test]
    fn test_consistent_variable_ordering() -> Result<(), Box<dyn std::error::Error>> {
        let expressions = vec!["y + x".to_string(), "x + z".to_string()];
        let system = EquationSystem::new(expressions)?;
        assert_eq!(system.sorted_variables, &["x", "y", "z"]);
        let results = system.eval(&vec![1.0, 2.0, 3.0])?;
        assert_eq!(results.as_slice(), vec![3.0, 4.0]);
        Ok(())
    }

    /// A short input must be rejected with a specific error, not just "some panic".
    #[test]
    fn test_invalid_input_length() {
        let system = EquationSystem::new(vec!["x + y".to_string(), "y + z".to_string()]).unwrap();
        let result = system.eval(&[1.0, 2.0]);
        assert!(matches!(
            result,
            Err(EquationError::InvalidInputLength { .. })
        ));
    }

    #[test]
    fn test_complex_expressions() -> Result<(), Box<dyn std::error::Error>> {
        let expressions = vec![
            "(x + y) * (x - y)".to_string(),
            "x^3 + y^2 * z".to_string(),
            "(x + y + z) / (x + 1)".to_string(),
        ];
        let system = EquationSystem::new(expressions)?;
        let results = system.eval(&[2.0, 3.0, 4.0])?;
        assert_eq!(results[0], -5.0);
        assert_eq!(results[1], 44.0);
        assert_eq!(results[2], 3.0);
        Ok(())
    }

    #[test]
    fn test_custom_variable_map() -> Result<(), Box<dyn std::error::Error>> {
        let mut var_map = HashMap::new();
        var_map.insert("alpha".to_string(), 1);
        var_map.insert("beta".to_string(), 0);
        let expressions = vec!["2*alpha + beta".to_string(), "alpha^2 - beta".to_string()];
        let system = EquationSystem::from_var_map(expressions, &var_map)?;
        let results = system.eval(&[2.0, 1.0])?;
        assert_eq!(results.as_slice(), &[4.0, -1.0]);
        Ok(())
    }

    #[test]
    fn test_error_undefined_variable() {
        let expressions = vec!["x + y".to_string(), "x + undefined_var".to_string()];
        let mut var_map = HashMap::new();
        var_map.insert("x".to_string(), 0);
        var_map.insert("y".to_string(), 1);
        let result = EquationSystem::from_var_map(expressions, &var_map);
        assert!(matches!(result, Err(EquationError::VariableNotFound(_))));
    }

    #[test]
    fn test_empty_system() -> Result<(), Box<dyn std::error::Error>> {
        let system = EquationSystem::new(vec![])?;
        let results = system.eval(&[])?;
        assert!(results.is_empty());
        Ok(())
    }

    #[test]
    fn test_derive_wrt() -> Result<(), Box<dyn std::error::Error>> {
        let system = EquationSystem::new(vec!["x^2*y".to_string(), "x*y^2".to_string()])?;

        let dx = system.derive_wrt(&["x"]).unwrap();
        let mut dx_results = vec![0.0, 0.0];
        dx.eval_into(&[2.0, 3.0], &mut dx_results).unwrap();
        assert_eq!(dx_results, vec![12.0, 9.0]);

        let dxy = system.derive_wrt(&["x", "y"]).unwrap();
        let mut dxy_results = vec![0.0, 0.0];
        dxy.eval_into(&[2.0, 3.0], &mut dxy_results).unwrap();
        assert_eq!(dxy_results, vec![4.0, 6.0]);
        Ok(())
    }

    #[test]
    fn test_derive_wrt_invalid_variable() {
        let system =
            EquationSystem::new(vec!["2*x + y^2".to_string(), "x^2 + z".to_string()]).unwrap();
        let result = system.derive_wrt(&["w"]);
        assert!(matches!(result, Err(EquationError::VariableNotFound(_))));
    }

    #[test]
    fn test_jacobian() -> Result<(), Box<dyn std::error::Error>> {
        let system = EquationSystem::new(vec!["x^2*y".to_string(), "x*y^2".to_string()])?;
        let jacobian = system.eval_jacobian(&[2.0, 3.0], None)?;
        assert_eq!(jacobian.len(), 2);
        assert_eq!(jacobian[0], vec![12.0, 4.0]);
        assert_eq!(jacobian[1], vec![9.0, 12.0]);
        Ok(())
    }

    /// Explicit variable lists select (and order) the Jacobian columns.
    #[test]
    fn test_jacobian_subset_of_variables() -> Result<(), Box<dyn std::error::Error>> {
        let system = EquationSystem::new(vec!["x^2*y".to_string(), "x*y^2".to_string()])?;
        let vars = vec!["y".to_string()];
        let jacobian = system.eval_jacobian(&[2.0, 3.0], Some(vars.as_slice()))?;
        assert_eq!(jacobian, vec![vec![4.0], vec![12.0]]);
        Ok(())
    }

    #[test]
    fn test_jacobian_wrt() -> Result<(), Box<dyn std::error::Error>> {
        let system = EquationSystem::new(vec![
            "x^2*y + z".to_string(),
            "x*y^2 - z^2".to_string(),
        ])?;
        let jacobian_fn = system.jacobian_wrt(&["x", "y"]).unwrap();
        let mut results = Array2::zeros((2, 2));
        jacobian_fn
            .eval_into_matrix(&vec![2.0, 3.0, 1.0], &mut results)
            .unwrap();
        assert_eq!(results[[0, 0]], 12.0);
        assert_eq!(results[[0, 1]], 4.0);
        assert_eq!(results[[1, 0]], 9.0);
        assert_eq!(results[[1, 1]], 12.0);
        Ok(())
    }

    #[test]
    fn test_jacobian_wrt_single_variable() -> Result<(), Box<dyn std::error::Error>> {
        let system = EquationSystem::new(vec!["x^2*y".to_string(), "x*y^2".to_string()])?;
        let jacobian_fn = system.jacobian_wrt(&["x"])?;
        let mut results = Array2::zeros((2, 1));
        jacobian_fn
            .eval_into_matrix(&vec![2.0, 3.0], &mut results)
            .unwrap();
        assert_eq!(results[[0, 0]], 12.0);
        assert_eq!(results[[1, 0]], 9.0);
        Ok(())
    }

    #[test]
    fn test_jacobian_wrt_all_variables() -> Result<(), Box<dyn std::error::Error>> {
        let system = EquationSystem::new(vec![
            "x^2*y + z".to_string(),
            "x*y^2 - z^2".to_string(),
        ])?;
        let jacobian_fn = system.jacobian_wrt(&["x", "y", "z"])?;
        let mut results = Array2::zeros((2, 3));
        jacobian_fn
            .eval_into_matrix(&vec![2.0, 3.0, 1.0], &mut results)
            .unwrap();
        assert_eq!(results[[0, 0]], 12.0);
        assert_eq!(results[[0, 1]], 4.0);
        assert_eq!(results[[0, 2]], 1.0);
        assert_eq!(results[[1, 0]], 9.0);
        assert_eq!(results[[1, 1]], 12.0);
        assert_eq!(results[[1, 2]], -2.0);
        Ok(())
    }

    #[test]
    fn test_jacobian_wrt_invalid_variable() {
        let system =
            EquationSystem::new(vec!["x^2*y + z".to_string(), "x*y^2 - z^2".to_string()]).unwrap();
        let result = system.jacobian_wrt(&["x", "w"]);
        assert!(matches!(result, Err(EquationError::VariableNotFound(_))));
    }

    #[test]
    fn test_jacobian_wrt_reuse_buffer() -> Result<(), Box<dyn std::error::Error>> {
        let system = EquationSystem::new(vec!["x^2*y".to_string(), "x*y^2".to_string()])?;
        let jacobian_fn = system.jacobian_wrt(&["x", "y"])?;
        let mut results = Array2::zeros((2, 2));

        jacobian_fn
            .eval_into_matrix(&vec![2.0, 3.0], &mut results)
            .unwrap();
        assert_eq!(results[[0, 0]], 12.0);
        assert_eq!(results[[0, 1]], 4.0);
        assert_eq!(results[[1, 0]], 9.0);
        assert_eq!(results[[1, 1]], 12.0);

        // Second call into the same buffer must fully overwrite the previous values.
        jacobian_fn
            .eval_into_matrix(&vec![1.0, 2.0], &mut results)
            .unwrap();
        assert_eq!(results[[0, 0]], 4.0);
        assert_eq!(results[[0, 1]], 1.0);
        assert_eq!(results[[1, 0]], 4.0);
        assert_eq!(results[[1, 1]], 4.0);
        Ok(())
    }

    #[test]
    fn test_different_vector_types() -> Result<(), Box<dyn std::error::Error>> {
        let system = EquationSystem::new(vec!["x^2*y".to_string(), "x*y^2".to_string()])?;

        let vec_inputs = vec![2.0, 3.0];
        let vec_results = system.eval(&vec_inputs)?;
        assert_eq!(vec_results.as_slice(), &[12.0, 18.0]);

        let ndarray_inputs = Array1::from_vec(vec![2.0, 3.0]);
        let ndarray_results = system.eval(&ndarray_inputs)?;
        assert_eq!(ndarray_results.as_slice().unwrap(), &[12.0, 18.0]);

        let nalgebra_inputs = DVector::from_vec(vec![2.0, 3.0]);
        let nalgebra_results = system.eval(&nalgebra_inputs)?;
        assert_eq!(nalgebra_results.as_slice(), &[12.0, 18.0]);
        Ok(())
    }

    #[test]
    fn test_eval_parallel() -> Result<(), Box<dyn std::error::Error>> {
        let system = EquationSystem::new(vec!["x^2*y".to_string(), "x*y^2".to_string()])?;
        let input_sets = vec![
            vec![2.0, 3.0],
            vec![1.0, 2.0],
            vec![3.0, 4.0],
            vec![0.0, 1.0],
        ];
        let results = system.eval_parallel(&input_sets)?;
        assert_eq!(results[0].as_slice(), &[12.0, 18.0]);
        assert_eq!(results[1].as_slice(), &[2.0, 4.0]);
        assert_eq!(results[2].as_slice(), &[36.0, 48.0]);
        assert_eq!(results[3].as_slice(), &[0.0, 0.0]);
        Ok(())
    }

    #[test]
    fn test_eval_parallel_empty() -> Result<(), Box<dyn std::error::Error>> {
        let system = EquationSystem::new(vec!["x^2*y".to_string(), "x*y^2".to_string()])?;
        let input_sets: Vec<Vec<f64>> = vec![];
        let results = system.eval_parallel(&input_sets)?;
        assert!(results.is_empty());
        Ok(())
    }

    #[test]
    fn test_eval_into() -> Result<(), Box<dyn std::error::Error>> {
        let system = EquationSystem::new(vec!["x^2*y".to_string(), "x*y^2".to_string()])?;

        let mut results = vec![0.0; 2];
        system.eval_into(&vec![2.0, 3.0], &mut results)?;
        assert_eq!(results, vec![12.0, 18.0]);

        let mut ndarray_results = Array1::zeros(2);
        system.eval_into(&Array1::from_vec(vec![2.0, 3.0]), &mut ndarray_results)?;
        assert_eq!(ndarray_results.as_slice().unwrap(), &[12.0, 18.0]);

        let mut wrong_size = vec![0.0; 3];
        assert!(matches!(
            system.eval_into(&vec![2.0, 3.0], &mut wrong_size),
            Err(EquationError::InvalidInputLength { .. })
        ));
        Ok(())
    }

    #[test]
    fn test_matrix_output_errors() -> Result<(), Box<dyn std::error::Error>> {
        let system = EquationSystem::new(vec!["x^2*y".to_string(), "x*y^2".to_string()])?;
        let mut results = Array2::zeros((2, 2));
        assert!(matches!(
            system.eval_into_matrix(&vec![2.0, 3.0], &mut results),
            Err(EquationError::MatrixOutputRequired)
        ));
        assert!(matches!(
            system.eval_matrix::<_, Array2<f64>>(&vec![2.0, 3.0]),
            Err(EquationError::MatrixOutputRequired)
        ));
        Ok(())
    }

    /// `eval_matrix` allocates and fills a matrix for a matrix-output system.
    #[test]
    fn test_eval_matrix_on_matrix_system() -> Result<(), Box<dyn std::error::Error>> {
        let system = EquationSystem::new(vec!["x^2*y".to_string(), "x*y^2".to_string()])?;
        let jacobian_fn = system.jacobian_wrt(&["x", "y"])?;
        let results = jacobian_fn.eval_matrix::<_, Array2<f64>>(&vec![2.0, 3.0])?;
        assert_eq!(results[[0, 0]], 12.0);
        assert_eq!(results[[0, 1]], 4.0);
        assert_eq!(results[[1, 0]], 9.0);
        assert_eq!(results[[1, 1]], 12.0);
        Ok(())
    }

    #[test]
    fn test_gradient() -> Result<(), Box<dyn std::error::Error>> {
        let system = EquationSystem::new(vec![
            "x^2*y + z".to_string(),
            "x*y^2 - z^2".to_string(),
        ])?;

        let dx = system.gradient(&[2.0, 3.0, 1.0], "x")?;
        assert_eq!(dx, vec![12.0, 9.0]);
        let dy = system.gradient(&[2.0, 3.0, 1.0], "y")?;
        assert_eq!(dy, vec![4.0, 12.0]);
        let dz = system.gradient(&[2.0, 3.0, 1.0], "z")?;
        assert_eq!(dz, vec![1.0, -2.0]);

        let result = system.gradient(&[2.0, 3.0, 1.0], "w");
        assert!(matches!(result, Err(EquationError::VariableNotFound(_))));

        let result = system.gradient(&[2.0, 3.0], "x");
        assert!(matches!(
            result,
            Err(EquationError::InvalidInputLength { .. })
        ));
        Ok(())
    }

    #[test]
    fn test_getters() -> Result<(), Box<dyn std::error::Error>> {
        let system = EquationSystem::new(vec!["x^2*y".to_string(), "x*y^2".to_string()])?;

        assert_eq!(system.sorted_variables(), &["x", "y"]);

        let var_map = system.variables();
        assert_eq!(var_map.get("x"), Some(&0));
        assert_eq!(var_map.get("y"), Some(&1));

        assert_eq!(system.equations(), &["x^2*y", "x*y^2"]);

        let fun = system.fun();
        let mut results = vec![0.0; 2];
        fun(&[2.0, 3.0], &mut results);
        assert_eq!(results, vec![12.0, 18.0]);

        let jacobian_funs = system.jacobian_funs();
        assert!(jacobian_funs.contains_key("x"));
        assert!(jacobian_funs.contains_key("y"));

        let dx_fun = system.gradient_fun("x");
        let mut dx_results = vec![0.0; 2];
        dx_fun(&[2.0, 3.0], &mut dx_results);
        assert_eq!(dx_results, vec![12.0, 9.0]);

        assert_eq!(system.num_equations(), 2);
        Ok(())
    }
}