evalexpr_jit/system.rs
1//! System of equations evaluation with JIT compilation.
2//!
3//! This module provides functionality for evaluating multiple mathematical equations
4//! simultaneously using JIT compilation. The equations are combined into a single
5//! optimized function for efficient evaluation.
6//!
7//! # Features
8//!
9//! - JIT compilation of multiple equations into a single function
10//! - Consistent variable ordering across equations
11//! - Automatic variable extraction and mapping
12//! - Efficient batch evaluation
13//! - Automatic derivative computation and Jacobian matrix generation
14//! - Higher-order derivative support
15//! - Parallel evaluation capabilities
16//!
17//! # Example
18//!
19//! ```
20//! use evalexpr_jit::system::EquationSystem;
21//!
22//! let system = EquationSystem::new(vec![
23//! "2*x + y".to_string(), // first equation
24//! "x^2 + z".to_string(), // second equation
25//! ]).unwrap();
26//!
27//! // Variables are automatically sorted (x, y, z)
28//! let results = system.eval(&vec![1.0, 2.0, 3.0]).unwrap();
29//! assert_eq!(results.as_slice(), vec![4.0, 4.0]); // [2*1 + 2, 1^2 + 3]
30//!
31//! // Compute derivatives
32//! let dx = system.gradient(&vec![1.0, 2.0, 3.0], "x").unwrap();
33//! assert_eq!(dx.as_slice(), vec![2.0, 2.0]); // [d/dx(2x + y), d/dx(x^2 + z)]
34//! ```
35//!
36//! # Computing Derivatives
37//!
38//! There are two main ways to create new equation systems with derivatives from existing ones:
39//!
40//! ## 1. Using `derive_wrt` for Gradients
41//!
42//! Creates a new equation system that computes derivatives with respect to specified variables:
43//!
44//! ```
45//! # use evalexpr_jit::system::EquationSystem;
46//! let system = EquationSystem::new(vec![
47//! "x^2*y".to_string(), // f1
48//! "x*y^2".to_string(), // f2
49//! ]).unwrap();
50//!
51//! // First-order derivative with respect to x
52//! let dx = system.derive_wrt(&["x"]).unwrap();
53//! let results = dx.eval(&[2.0, 3.0]).unwrap();
54//! assert_eq!(results.as_slice(), &[12.0, 9.0]); // [d(x^2*y)/dx = 2xy, d(x*y^2)/dx = y^2]
55//!
56//! // Second-order mixed derivative (first x, then y)
57//! let dxy = system.derive_wrt(&["x", "y"]).unwrap();
58//! let results = dxy.eval(&[2.0, 3.0]).unwrap();
59//! assert_eq!(results.as_slice(), &[4.0, 6.0]); // [d²(x^2*y)/dxdy = 2x, d²(x*y^2)/dxdy = 2y]
60//! ```
61//!
62//! ## 2. Using `jacobian_wrt` for Jacobian Matrices
63//!
64//! Creates a new equation system that computes the Jacobian matrix with respect to specified variables:
65//!
66//! ```
67//! # use evalexpr_jit::system::EquationSystem;
68//! # use ndarray::Array2;
69//! let system = EquationSystem::new(vec![
70//! "x^2*y + z".to_string(), // f1
71//! "x*y^2 - z^2".to_string(), // f2
72//! ]).unwrap();
73//!
74//! // Create Jacobian system for x and y
75//! let jacobian_system = system.jacobian_wrt(&["x", "y"]).unwrap();
76//!
77//! // Prepare matrix to store results (2 equations × 2 variables)
78//! let mut results = Array2::zeros((2, 2));
79//!
80//! // Evaluate Jacobian at point (x=2, y=3, z=1)
81//! jacobian_system.eval_into_matrix(&[2.0, 3.0, 1.0], &mut results).unwrap();
82//!
83//! // results now contains:
84//! // [
85//! // [12.0, 4.0], // [∂f1/∂x, ∂f1/∂y]
86//! // [9.0, 12.0], // [∂f2/∂x, ∂f2/∂y]
87//! // ]
88//! ```
89//!
90
91use crate::backends::vector::Vector;
92use crate::builder::build_combined_function;
93use crate::convert::build_ast;
94use crate::equation::{extract_all_symbols, extract_symbols};
95use crate::errors::EquationError;
96use crate::expr::Expr;
97use crate::prelude::Matrix;
98use crate::types::CombinedJITFunction;
99use evalexpr::build_operator_tree;
100use itertools::Itertools;
101use rayon::prelude::*;
102use std::collections::HashMap;
103use std::sync::Arc;
104
105/// Represents a system of mathematical equations that can be evaluated together.
106pub struct EquationSystem {
107 /// The original string representations of the equations
108 pub equations: Vec<String>,
109 /// The AST representations of the equations
110 pub asts: Vec<Expr>,
111 /// Maps variable names to their indices in the input array
112 pub variable_map: HashMap<String, u32>,
113 /// Variables in sorted order for consistent input ordering
114 pub sorted_variables: Vec<String>,
115 /// The JIT-compiled function that evaluates all equations
116 pub combined_fun: CombinedJITFunction,
117 /// Partial derivatives of the system - maps variable names to their derivative functions
118 /// E.g. {"x": df(x, y, z)/dx, "y": df(x, y, z)/dy}
119 pub partial_derivatives: HashMap<String, CombinedJITFunction>,
120 /// The type of output
121 output_type: OutputType,
122}
123
124impl EquationSystem {
125 /// Creates a new equation system from a vector of expression strings.
126 ///
127 /// This constructor automatically extracts variables from the expressions and assigns them indices
128 /// in alphabetical order. The resulting system can evaluate all equations efficiently and compute
129 /// derivatives with respect to any variable.
130 ///
131 /// # Arguments
132 /// * `expressions` - Vector of mathematical expressions as strings
133 ///
134 /// # Returns
135 /// A new `EquationSystem` with JIT-compiled evaluation function and derivative capabilities
136 ///
137 /// # Example
138 /// ```
139 /// # use evalexpr_jit::system::EquationSystem;
140 /// let system = EquationSystem::new(vec![
141 /// "2*x + y".to_string(),
142 /// "x^2 + z".to_string()
143 /// ]).unwrap();
144 ///
145 /// // Evaluate system
146 /// let results = system.eval(&vec![1.0, 2.0, 3.0]).unwrap();
147 ///
148 /// // Compute derivatives
149 /// let dx = system.gradient(&vec![1.0, 2.0, 3.0], "x").unwrap();
150 /// ```
151 pub fn new(expressions: Vec<String>) -> Result<Self, EquationError> {
152 let sorted_variables = extract_all_symbols(&expressions);
153 let variable_map: HashMap<String, u32> = sorted_variables
154 .iter()
155 .enumerate()
156 .map(|(i, v)| (v.clone(), i as u32))
157 .collect();
158 let asts = Self::create_asts(&expressions, &variable_map)?;
159 Self::build(asts, expressions, variable_map, OutputType::Vector)
160 }
161
162 /// Creates a new equation system from a vector of expressions and a variable map.
163 ///
164 /// This constructor allows explicit control over variable ordering by providing a map
165 /// of variable names to their indices. This is useful when you need to ensure specific
166 /// variable ordering or when integrating with other systems that expect variables in
167 /// a particular order.
168 ///
169 /// # Arguments
170 /// * `expressions` - Vector of mathematical expressions as strings
171 /// * `variable_map` - Map of variable names to their indices, defining input order
172 ///
173 /// # Returns
174 /// A new `EquationSystem` with JIT-compiled evaluation function using the specified variable ordering
175 ///
176 /// # Example
177 /// ```
178 /// use evalexpr_jit::prelude::*;
179 /// use std::collections::HashMap;
180 ///
181 /// let var_map: HashMap<String, u32> = [
182 /// ("x".to_string(), 0),
183 /// ("y".to_string(), 1),
184 /// ("z".to_string(), 2),
185 /// ].into_iter().collect();
186 ///
187 /// let system = EquationSystem::from_var_map(
188 /// vec!["2*x + y".to_string(), "x^2 + z".to_string()],
189 /// &var_map
190 /// ).unwrap();
191 /// ```
192 pub fn from_var_map(
193 expressions: Vec<String>,
194 variable_map: &HashMap<String, u32>,
195 ) -> Result<Self, EquationError> {
196 let asts = Self::create_asts(&expressions, variable_map)?;
197 Self::build(asts, expressions, variable_map.clone(), OutputType::Vector)
198 }
199
200 /// Creates a new equation system directly from AST nodes and a variable map.
201 ///
202 /// This constructor allows creating a system from pre-built AST nodes rather than parsing
203 /// expressions from strings. This is useful when augmenting existing systems or creating derivatives
204 /// of existing systems.
205 ///
206 /// Please note, this constructor is not meant to be used by the end user and is only available for internal use
207 /// to create derivatives of existing systems.
208 ///
209 /// # Arguments
210 /// * `asts` - Vector of expression AST nodes
211 /// * `variable_map` - Map of variable names to their indices, defining input order
212 /// * `output_type` - The type of output. Used to determine the shape of the output vector
213 /// # Returns
214 /// A new `EquationSystem` with JIT-compiled evaluation function using the specified ASTs
215 pub(crate) fn from_asts(
216 asts: Vec<Expr>,
217 variable_map: &HashMap<String, u32>,
218 output_type: OutputType,
219 ) -> Result<Self, EquationError> {
220 let expressions = asts.iter().map(|ast| ast.to_string()).collect();
221 Self::build(asts, expressions, variable_map.clone(), output_type)
222 }
223
224 /// Core builder function used by both `new()` and `from_var_map()`.
225 ///
226 /// This internal function handles the common construction logic for both public constructors.
227 /// It builds ASTs and creates JIT-compiled functions for both evaluation and derivatives.
228 ///
229 /// # Arguments
230 /// * `asts` - Vector of AST nodes
231 /// * `equations` - Original expression strings
232 /// * `variable_map` - Map of variable names to indices
233 /// * `output_type` - The type of output (vector or matrix)
234 ///
235 /// # Returns
236 /// A new `EquationSystem` with JIT-compiled evaluation function and derivative capabilities
237 fn build(
238 asts: Vec<Expr>,
239 equations: Vec<String>,
240 variable_map: HashMap<String, u32>,
241 output_type: OutputType,
242 ) -> Result<Self, EquationError> {
243 // Create combined JIT function
244 let combined_fun = build_combined_function(asts.clone(), equations.len())?;
245
246 // Create derivative functions for each variable forming a Jacobian matrix
247 let mut jacobian_funs = HashMap::with_capacity(variable_map.len());
248
249 let sorted_variables: Vec<String> = variable_map
250 .iter()
251 .sorted_by_key(|(_, idx)| *idx)
252 .map(|(var, _)| var.clone())
253 .collect();
254
255 for var in sorted_variables {
256 let derivative_ast = asts
257 .iter()
258 .map(|ast| *ast.derivative(&var))
259 .collect::<Vec<Expr>>();
260 let jacobian_fun = build_combined_function(derivative_ast, asts.len())?;
261 jacobian_funs.insert(var, jacobian_fun);
262 }
263
264 Ok(Self {
265 equations,
266 asts,
267 variable_map: variable_map.clone(),
268 sorted_variables: variable_map.keys().sorted().cloned().collect(),
269 combined_fun,
270 partial_derivatives: jacobian_funs,
271 output_type,
272 })
273 }
274
275 /// Creates abstract syntax trees (ASTs) from a vector of mathematical expressions.
276 ///
277 /// This internal function parses each expression string into an AST, validates that all
278 /// variables used in the expressions exist in the provided variable map, and returns
279 /// simplified ASTs ready for compilation.
280 ///
281 /// # Arguments
282 /// * `expressions` - Vector of mathematical expression strings to parse
283 /// * `variable_map` - Map of valid variable names to their indices
284 ///
285 /// # Returns
286 /// A vector of simplified ASTs, one for each input expression
287 ///
288 /// # Errors
289 /// Returns `EquationError::VariableNotFound` if an expression uses a variable
290 /// that doesn't exist in the variable map
291 fn create_asts(
292 expressions: &[String],
293 variable_map: &HashMap<String, u32>,
294 ) -> Result<Vec<Expr>, EquationError> {
295 expressions
296 .iter()
297 .map(|expr| {
298 let node = build_operator_tree(expr)?;
299
300 // Validate variables
301 let expr_vars = extract_symbols(&node);
302 for var in expr_vars.keys() {
303 if !variable_map.contains_key(var) {
304 return Err(EquationError::VariableNotFound(var.clone()));
305 }
306 }
307
308 // Build and simplify AST
309 let ast = build_ast(&node, variable_map)?;
310 Ok(*ast.simplify())
311 })
312 .collect::<Result<Vec<_>, EquationError>>()
313 }
314
315 /// Evaluates all equations in the system with the given input values
316 /// into a pre-allocated buffer.
317 ///
318 /// # Arguments
319 /// * `inputs` - Input vector implementing the Vector trait
320 /// * `results` - Pre-allocated vector to store results
321 ///
322 /// # Returns
323 /// Reference to the results vector containing the evaluated values
324 pub fn eval_into<V: Vector, R: Vector>(
325 &self,
326 inputs: &V,
327 results: &mut R,
328 ) -> Result<(), EquationError> {
329 self.validate_input_length(inputs.as_slice())?;
330 if results.len() != self.equations.len() {
331 return Err(EquationError::InvalidInputLength {
332 expected: self.equations.len(),
333 got: results.len(),
334 });
335 }
336
337 (self.combined_fun)(inputs.as_slice(), results.as_mut_slice());
338 Ok(())
339 }
340
341 /// Evaluates all equations in the system with the given input values.
342 /// Allocates a new vector for results.
343 ///
344 /// # Arguments
345 /// * `inputs` - Input vector implementing the Vector trait
346 ///
347 /// # Returns
348 /// Vector of results, one for each equation in the system
349 pub fn eval<V: Vector>(&self, inputs: &V) -> Result<V, EquationError> {
350 let mut results = V::zeros(self.equations.len());
351 self.eval_into(inputs, &mut results)?;
352 Ok(results)
353 }
354
355 /// Evaluates all equations in the system with the given input values into a pre-allocated matrix.
356 ///
357 /// This method is used when the equation system is configured to output a matrix rather than a vector.
358 /// The results matrix must have the correct dimensions matching the system's output type.
359 ///
360 /// # Arguments
361 /// * `inputs` - Input vector implementing the Vector trait
362 /// * `results` - Pre-allocated matrix to store results
363 ///
364 /// # Returns
365 /// Ok(()) if successful, or an error if the system is not configured for matrix output
366 pub fn eval_into_matrix<V: Vector, R: Matrix>(
367 &self,
368 inputs: &V,
369 results: &mut R,
370 ) -> Result<(), EquationError> {
371 match self.output_type {
372 OutputType::Vector => {
373 // If the system is not configured to output a matrix, throw an error
374 return Err(EquationError::MatrixOutputRequired);
375 }
376 OutputType::Matrix(n_rows, n_cols) => {
377 self.validate_matrix_dimensions(n_rows, n_cols)?;
378 }
379 }
380
381 (self.combined_fun)(inputs.as_slice(), results.flat_mut_slice());
382 Ok(())
383 }
384
385 /// Evaluates all equations in the system with the given input values into a new matrix.
386 ///
387 /// This method allocates a new matrix with the correct dimensions and evaluates the system into it.
388 /// It should only be used when the equation system is configured to output a matrix.
389 ///
390 /// # Arguments
391 /// * `inputs` - Input vector implementing the Vector trait
392 ///
393 /// # Returns
394 /// Matrix containing the evaluated results, or an error if the system is not configured for matrix output
395 pub fn eval_matrix<V: Vector, R: Matrix>(&self, inputs: &V) -> Result<R, EquationError> {
396 match self.output_type {
397 OutputType::Vector => Err(EquationError::MatrixOutputRequired),
398 OutputType::Matrix(n_rows, n_cols) => {
399 let mut results = R::zeros(n_rows, n_cols);
400 self.eval_into_matrix(inputs, &mut results)?;
401 Ok(results)
402 }
403 }
404 }
405
406 /// Evaluates the equation system in parallel for multiple input sets.
407 ///
408 /// # Arguments
409 /// * `input_sets` - Slice of input vectors, each must match the number of variables
410 ///
411 /// # Returns
412 /// Vector of result vectors, one for each input set
413 pub fn eval_parallel<V: Vector + Send + Sync>(
414 &self,
415 input_sets: &[V],
416 ) -> Result<Vec<V>, EquationError> {
417 let num_threads = std::thread::available_parallelism()
418 .map(|n| n.get())
419 .unwrap_or(8);
420
421 let chunk_size = (input_sets.len() / (num_threads * 4)).max(1);
422 let n_equations = self.equations.len();
423
424 // Since we're using Arc, we can efficiently clone the system for parallel processing
425 let systems: Vec<_> = (0..num_threads).map(|_| self.clone()).collect();
426
427 Ok(input_sets
428 .par_chunks(chunk_size)
429 .enumerate()
430 .map(|(chunk_idx, chunk)| {
431 let system = &systems[chunk_idx % systems.len()];
432 chunk
433 .iter()
434 .map(|inputs| {
435 let mut results = V::zeros(n_equations);
436 (system.combined_fun)(inputs.as_slice(), results.as_mut_slice());
437 results
438 })
439 .collect::<Vec<_>>()
440 })
441 .flatten()
442 .collect())
443 }
444
445 /// Returns the gradient of the equation system with respect to a specific variable.
446 ///
447 /// The gradient contains the partial derivatives of all equations with respect to the given variable,
448 /// evaluated at the provided input values. This is equivalent to one row of the Jacobian matrix.
449 ///
450 /// # Arguments
451 /// * `inputs` - Slice of input values at which to evaluate the gradient
452 /// * `variable` - Name of the variable to compute derivatives with respect to
453 ///
454 /// # Returns
455 /// Vector containing the partial derivatives of each equation with respect to
456 /// the specified variable, evaluated at the given input values
457 ///
458 /// # Example
459 /// ```
460 /// use evalexpr_jit::prelude::*;
461 ///
462 /// let system = EquationSystem::new(vec![
463 /// "x^2*y".to_string(), // f1
464 /// "x*y^2".to_string(), // f2
465 /// ]).unwrap();
466 ///
467 /// let gradient = system.gradient(&[2.0, 3.0], "x").unwrap();
468 /// // gradient contains [12.0, 9.0] (∂f1/∂x, ∂f2/∂x)
469 /// ```
470 pub fn gradient(&self, inputs: &[f64], variable: &str) -> Result<Vec<f64>, EquationError> {
471 self.validate_input_length(inputs)?;
472 let n_equations = self.equations.len();
473 let mut results = vec![0.0; n_equations];
474 self.partial_derivatives
475 .get(variable)
476 .ok_or(EquationError::VariableNotFound(variable.to_string()))?(
477 inputs, &mut results
478 );
479 Ok(results)
480 }
481
482 /// Computes the Jacobian matrix of the equation system at the given input values.
483 ///
484 /// The Jacobian matrix contains all first-order partial derivatives of the system.
485 /// Each row corresponds to an equation, and each column corresponds to a variable.
486 /// The entry at position (i,j) is the partial derivative of equation i with respect to variable j.
487 ///
488 /// # Arguments
489 /// * `inputs` - Slice of input values at which to evaluate the Jacobian
490 /// * `variables` - Optional slice of variable names to include in the Jacobian. If None, includes all variables in sorted order.
491 ///
492 /// # Returns
493 /// The Jacobian matrix as a vector of vectors, where each inner vector
494 /// contains the partial derivatives of one equation with respect to all variables
495 ///
496 /// # Example
497 /// ```
498 /// use evalexpr_jit::prelude::*;
499 ///
500 /// let system = EquationSystem::new(vec![
501 /// "x^2*y".to_string(), // f1
502 /// "x*y^2".to_string(), // f2
503 /// ]).unwrap();
504 ///
505 /// let jacobian = system.eval_jacobian(&[2.0, 3.0], None).unwrap();
506 /// // jacobian[0] contains [12.0, 4.0] // ∂f1/∂x, ∂f1/∂y
507 /// // jacobian[1] contains [9.0, 12.0] // ∂f2/∂x, ∂f2/∂y
508 /// ```
509 pub fn eval_jacobian(
510 &self,
511 inputs: &[f64],
512 variables: Option<&[String]>,
513 ) -> Result<Vec<Vec<f64>>, EquationError> {
514 self.validate_input_length(inputs)?;
515
516 let sorted_variables = variables.unwrap_or(&self.sorted_variables);
517 let mut results = Vec::with_capacity(self.equations.len());
518 let n_vars = sorted_variables.len();
519
520 // Initialize vectors for each equation
521 for _ in 0..self.equations.len() {
522 results.push(Vec::with_capacity(n_vars));
523 }
524
525 // Fill the transposed matrix
526 let n_equations = self.equations.len();
527 for var in sorted_variables {
528 let fun = self.partial_derivatives.get(var).unwrap();
529 let mut derivatives = vec![0.0; n_equations];
530 fun(inputs, &mut derivatives);
531 for (eq_idx, &value) in derivatives.iter().enumerate() {
532 results[eq_idx].push(value);
533 }
534 }
535
536 Ok(results)
537 }
538
539 /// Creates a new equation system that computes the Jacobian matrix with respect to specific variables.
540 ///
541 /// This method creates a new equation system optimized for computing partial derivatives
542 /// with respect to a subset of variables. The resulting system evaluates all equations' derivatives
543 /// with respect to the specified variables and arranges them in matrix form.
544 ///
545 /// # Arguments
546 /// * `variables` - Slice of variable names to include in the Jacobian matrix
547 ///
548 /// # Returns
549 /// A new `EquationSystem` that computes the Jacobian matrix when evaluated. The output matrix
550 /// has dimensions `[n_equations × n_variables]`, where each row contains the derivatives
551 /// of one equation with respect to the specified variables.
552 ///
553 /// # Errors
554 /// Returns `EquationError::VariableNotFound` if any of the specified variables doesn't exist
555 /// in the system.
556 ///
557 /// # Example
558 /// ```
559 /// use evalexpr_jit::prelude::*;
560 /// use ndarray::Array2;
561 /// let system = EquationSystem::new(vec![
562 /// "x^2*y + z".to_string(), // f1
563 /// "x*y^2 - z^2".to_string(), // f2
564 /// ]).unwrap();
565 ///
566 /// // Create Jacobian system for x and y only
567 /// let jacobian_system = system.jacobian_wrt(&["x", "y"]).unwrap();
568 ///
569 /// // Prepare matrix to store results (2 equations × 2 variables)
570 /// let mut results = Array2::zeros((2, 2));
571 ///
572 /// // Evaluate Jacobian at point (x=2, y=3, z=1)
573 /// jacobian_system.eval_into_matrix(&[2.0, 3.0, 1.0], &mut results).unwrap();
574 ///
575 /// // results now contains:
576 /// // [
577 /// // [12.0, 4.0], // [∂f1/∂x, ∂f1/∂y]
578 /// // [9.0, 12.0], // [∂f2/∂x, ∂f2/∂y]
579 /// // ]
580 /// ```
581 ///
582 /// # Performance Notes
583 /// - The returned system is JIT-compiled and optimized for repeated evaluations
584 /// - Pre-allocate the results matrix and reuse it for better performance
585 /// - The matrix dimensions will be `[n_equations × n_variables]`
586 pub fn jacobian_wrt(&self, variables: &[&str]) -> Result<EquationSystem, EquationError> {
587 // Verify all variables exist
588 for var in variables {
589 if !self.variable_map.contains_key(*var) {
590 return Err(EquationError::VariableNotFound(var.to_string()));
591 }
592 }
593
594 let mut asts = vec![];
595 for ast in self.asts.iter() {
596 for var in variables {
597 asts.push(*ast.derivative(var));
598 }
599 }
600
601 let output_type = OutputType::Matrix(self.num_equations(), variables.len());
602
603 EquationSystem::from_asts(asts, &self.variable_map, output_type)
604 }
605
606 /// Creates a new equation system containing the higher-order derivatives of all equations
607 /// with respect to multiple variables.
608 ///
609 /// This method allows computing mixed partial derivatives by specifying the variables
610 /// in the order of differentiation. For example, passing ["x", "y"] computes ∂²f/∂x∂y
611 /// for each equation f.
612 ///
613 /// # Arguments
614 /// * `variables` - Slice of variable names to differentiate with respect to, in order
615 ///
616 /// # Returns
617 /// A JIT-compiled function that computes the higher-order derivatives
618 ///
619 /// # Example
620 /// ```
621 /// # use evalexpr_jit::system::EquationSystem;
622 /// # use ndarray::Array1;
623 /// let system = EquationSystem::new(vec![
624 /// "x^2*y".to_string(), // f1
625 /// "x*y^2".to_string(), // f2
626 /// ]).unwrap();
627 ///
628 /// let derivatives = system.derive_wrt(&["x", "y"]).unwrap();
629 /// let mut results = Array1::zeros(2);
630 /// derivatives.eval_into(&vec![2.0, 3.0], &mut results).unwrap();
631 /// assert_eq!(results.as_slice().unwrap(), vec![4.0, 6.0]); // ∂²f1/∂x∂y = 2x, ∂²f2/∂x∂y = 2y
632 /// ```
633 pub fn derive_wrt(&self, variables: &[&str]) -> Result<EquationSystem, EquationError> {
634 // Verify all variables exist
635 for var in variables {
636 if !self.variable_map.contains_key(*var) {
637 return Err(EquationError::VariableNotFound(var.to_string()));
638 }
639 }
640
641 // Create higher-order derivatives of all ASTs
642 let mut derivative_asts = self.asts.clone();
643 for var in variables {
644 derivative_asts = derivative_asts
645 .into_iter()
646 .map(|ast| *ast.derivative(var))
647 .collect();
648 }
649
650 // Create new system from derivative ASTs
651 EquationSystem::from_asts(derivative_asts, &self.variable_map, OutputType::Vector)
652 }
653
654 pub fn validate_matrix_dimensions(
655 &self,
656 n_rows: usize,
657 n_cols: usize,
658 ) -> Result<(), EquationError> {
659 match self.output_type {
660 OutputType::Vector => {
661 return Err(EquationError::MatrixOutputRequired);
662 }
663 OutputType::Matrix(expected_rows, expected_cols) => {
664 if n_rows != expected_rows || n_cols != expected_cols {
665 return Err(EquationError::InvalidMatrixDimensions {
666 expected_rows,
667 expected_cols,
668 got_rows: n_rows,
669 got_cols: n_cols,
670 });
671 }
672 }
673 }
674 Ok(())
675 }
676
677 /// Returns the sorted variables in the system.
678 pub fn sorted_variables(&self) -> &[String] {
679 &self.sorted_variables
680 }
681
682 /// Returns the map of variable names to their indices.
683 pub fn variables(&self) -> &HashMap<String, u32> {
684 &self.variable_map
685 }
686
687 /// Returns the original equation strings.
688 pub fn equations(&self) -> &[String] {
689 &self.equations
690 }
691
692 /// Returns the compiled evaluation function.
693 pub fn fun(&self) -> &CombinedJITFunction {
694 &self.combined_fun
695 }
696
697 /// Returns the map of variable names to their derivative functions.
698 pub fn jacobian_funs(&self) -> &HashMap<String, CombinedJITFunction> {
699 &self.partial_derivatives
700 }
701
702 /// Returns the derivative function for a specific variable.
703 pub fn gradient_fun(&self, variable: &str) -> &CombinedJITFunction {
704 self.partial_derivatives.get(variable).unwrap()
705 }
706
707 /// Returns the number of equations in the system.
708 pub fn num_equations(&self) -> usize {
709 self.equations.len()
710 }
711
712 /// Validates that the number of input values matches the number of variables.
713 fn validate_input_length(&self, inputs: &[f64]) -> Result<(), EquationError> {
714 if inputs.len() != self.sorted_variables.len() {
715 return Err(EquationError::InvalidInputLength {
716 expected: self.sorted_variables.len(),
717 got: inputs.len(),
718 });
719 }
720 Ok(())
721 }
722}
723
724impl Clone for EquationSystem {
725 fn clone(&self) -> Self {
726 Self {
727 equations: self.equations.clone(),
728 asts: self.asts.clone(),
729 variable_map: self.variable_map.clone(),
730 sorted_variables: self.sorted_variables.clone(),
731 combined_fun: Arc::clone(&self.combined_fun),
732 partial_derivatives: self.partial_derivatives.clone(),
733 output_type: self.output_type,
734 }
735 }
736}
737
/// Layout of the values produced when an `EquationSystem` is evaluated.
#[derive(Clone, Copy, Debug)]
pub(crate) enum OutputType {
    /// One value per equation, written into a vector.
    Vector,
    /// `(n_rows, n_cols)`: values are written into a matrix with this shape.
    Matrix(usize, usize),
}
743
744#[cfg(test)]
745mod tests {
746 use super::*;
747 use nalgebra::DVector;
748 use ndarray::{Array1, Array2};
749
750 #[test]
751 fn test_system_with_different_variables() -> Result<(), Box<dyn std::error::Error>> {
752 let expressions = vec![
753 "2*x + y".to_string(), // uses x, y
754 "z^2".to_string(), // uses only z
755 "x + y + z".to_string(), // uses all
756 ];
757
758 let system = EquationSystem::new(expressions)?;
759
760 // Check that all variables are tracked
761 assert_eq!(system.sorted_variables, &["x", "y", "z"]);
762
763 // Evaluate with values for all variables
764 let results = system.eval(&[1.0, 2.0, 3.0])?;
765
766 // Check results
767 assert_eq!(
768 results.as_slice(),
769 vec![
770 4.0, // 2*1 + 2
771 9.0, // 3^2
772 6.0, // 1 + 2 + 3
773 ]
774 );
775
776 Ok(())
777 }
778
779 #[test]
780 fn test_consistent_variable_ordering() -> Result<(), Box<dyn std::error::Error>> {
781 let expressions = vec![
782 "y + x".to_string(), // variables in different order
783 "x + z".to_string(), // different subset of variables
784 ];
785
786 let system = EquationSystem::new(expressions)?;
787
788 // Check that ordering is consistent (alphabetical)
789 assert_eq!(system.sorted_variables, &["x", "y", "z"]);
790
791 // Values must be provided in the consistent order [x, y, z]
792 let results = system.eval(&vec![1.0, 2.0, 3.0])?;
793
794 assert_eq!(
795 results.as_slice(),
796 vec![
797 3.0, // y(2.0) + x(1.0)
798 4.0, // x(1.0) + z(3.0)
799 ]
800 );
801
802 Ok(())
803 }
804
805 #[test]
806 #[should_panic]
807 fn test_invalid_input_length() {
808 let system = EquationSystem::new(vec!["x + y".to_string(), "y + z".to_string()]).unwrap();
809
810 // Should panic: providing only 2 values when system needs 3 (x, y, z)
811 let _ = system.eval(&[1.0, 2.0]).unwrap();
812 }
813
814 #[test]
815 fn test_complex_expressions() -> Result<(), Box<dyn std::error::Error>> {
816 let expressions = vec![
817 "(x + y) * (x - y)".to_string(), // difference of squares
818 "x^3 + y^2 * z".to_string(), // polynomial
819 "(x + y + z) / (x + 1)".to_string(), // division
820 ];
821
822 let system = EquationSystem::new(expressions)?;
823 let results = system.eval(&[2.0, 3.0, 4.0])?;
824
825 assert_eq!(results[0], -5.0); // (2 + 3) * (2 - 3) = 5 * -1 = -5
826 assert_eq!(results[1], 44.0); // 2^3 + 3^2 * 4 = 8 + 9 * 4 = 44
827 assert_eq!(results[2], 3.0); // (2 + 3 + 4) / (2 + 1) = 9 / 3 = 3
828
829 Ok(())
830 }
831
832 #[test]
833 fn test_custom_variable_map() -> Result<(), Box<dyn std::error::Error>> {
834 let mut var_map = HashMap::new();
835 var_map.insert("alpha".to_string(), 1);
836 var_map.insert("beta".to_string(), 0);
837
838 let expressions = vec!["2*alpha + beta".to_string(), "alpha^2 - beta".to_string()];
839
840 let system = EquationSystem::from_var_map(expressions, &var_map)?;
841 let results = system.eval(&[2.0, 1.0])?;
842
843 assert_eq!(results.as_slice(), &[4.0, -1.0]);
844
845 Ok(())
846 }
847
848 #[test]
849 fn test_error_undefined_variable() {
850 let expressions = vec![
851 "x + y".to_string(),
852 "x + undefined_var".to_string(), // undefined variable
853 ];
854
855 let mut var_map = HashMap::new();
856 var_map.insert("x".to_string(), 0);
857 var_map.insert("y".to_string(), 1);
858
859 let result = EquationSystem::from_var_map(expressions, &var_map);
860 assert!(matches!(result, Err(EquationError::VariableNotFound(_))));
861 }
862
863 #[test]
864 fn test_empty_system() -> Result<(), Box<dyn std::error::Error>> {
865 let system = EquationSystem::new(vec![])?;
866 let results = system.eval(&[])?;
867 assert!(results.is_empty());
868 Ok(())
869 }
870
871 #[test]
872 fn test_derive_wrt() -> Result<(), Box<dyn std::error::Error>> {
873 let system = EquationSystem::new(vec!["x^2*y".to_string(), "x*y^2".to_string()])?;
874
875 // First order derivative
876 let dx = system.derive_wrt(&["x"]).unwrap();
877 let mut dx_results = vec![0.0, 0.0];
878 dx.eval_into(&[2.0, 3.0], &mut dx_results).unwrap();
879 assert_eq!(dx_results, vec![12.0, 9.0]); // d/dx[x^2*y] = 2xy, d/dx[x*y^2] = y^2
880
881 // Second order derivative
882 let dxy = system.derive_wrt(&["x", "y"]).unwrap();
883 let mut dxy_results = vec![0.0, 0.0];
884 dxy.eval_into(&[2.0, 3.0], &mut dxy_results).unwrap();
885 assert_eq!(dxy_results, vec![4.0, 6.0]); // d²/dxdy[x^2*y] = 2x, d²/dxdy[x*y^2] = 2y
886
887 Ok(())
888 }
889
890 #[test]
891 fn test_derive_wrt_invalid_variable() {
892 let system =
893 EquationSystem::new(vec!["2*x + y^2".to_string(), "x^2 + z".to_string()]).unwrap();
894
895 let result = system.derive_wrt(&["w"]);
896 assert!(matches!(result, Err(EquationError::VariableNotFound(_))));
897 }
898
#[test]
fn test_jacobian() -> Result<(), Box<dyn std::error::Error>> {
    // f1 = x^2*y, f2 = x*y^2, evaluated at (x, y) = (2, 3).
    let system = EquationSystem::new(vec![
        "x^2*y".to_string(), // f1
        "x*y^2".to_string(), // f2
    ])?;
    let jac = system.eval_jacobian(&[2.0, 3.0], None)?;

    // One row per equation, one column per variable (x, y):
    //   row 0: [∂f1/∂x, ∂f1/∂y] = [2xy, x^2] = [12, 4]
    //   row 1: [∂f2/∂x, ∂f2/∂y] = [y^2, 2xy] = [9, 12]
    let expected = [vec![12.0, 4.0], vec![9.0, 12.0]];
    assert_eq!(jac.len(), expected.len());
    for (row, want) in expected.iter().enumerate() {
        assert_eq!(jac[row], *want);
    }

    Ok(())
}
918
#[test]
fn test_jacobian_wrt() -> Result<(), Box<dyn std::error::Error>> {
    let system = EquationSystem::new(vec![
        "x^2*y + z".to_string(),   // f1
        "x*y^2 - z^2".to_string(), // f2
    ])?;

    // Jacobian restricted to a subset of the variables (x and y only). The evaluation
    // still receives all three inputs (x, y, z) = (2, 3, 1), while the result matrix
    // has one column per requested variable. Failures propagate with `?` through the
    // test's `Result` return type.
    let jacobian_fn = system.jacobian_wrt(&["x", "y"])?;
    let mut results = Array2::zeros((2, 2));
    jacobian_fn.eval_into_matrix(&vec![2.0, 3.0, 1.0], &mut results)?;

    // Expected derivatives:
    // ∂f1/∂x = 2xy = 2(2)(3) = 12
    // ∂f1/∂y = x^2 = 4
    // ∂f2/∂x = y^2 = 9
    // ∂f2/∂y = 2xy = 2(2)(3) = 12
    assert_eq!(results[[0, 0]], 12.0); // ∂f1/∂x
    assert_eq!(results[[0, 1]], 4.0); // ∂f1/∂y
    assert_eq!(results[[1, 0]], 9.0); // ∂f2/∂x
    assert_eq!(results[[1, 1]], 12.0); // ∂f2/∂y

    Ok(())
}
945
#[test]
fn test_jacobian_wrt_single_variable() -> Result<(), Box<dyn std::error::Error>> {
    let system = EquationSystem::new(vec![
        "x^2*y".to_string(), // f1
        "x*y^2".to_string(), // f2
    ])?;

    // Jacobian with respect to a single variable: a 2x1 column, one row per equation.
    // Both construction and evaluation propagate errors with `?`.
    let jacobian_fn = system.jacobian_wrt(&["x"])?;
    let mut results = Array2::zeros((2, 1));
    jacobian_fn.eval_into_matrix(&vec![2.0, 3.0], &mut results)?;

    // Expected derivatives at (x, y) = (2, 3):
    // ∂f1/∂x = 2xy = 2(2)(3) = 12
    // ∂f2/∂x = y^2 = 9
    assert_eq!(results[[0, 0]], 12.0); // [∂f1/∂x]
    assert_eq!(results[[1, 0]], 9.0); // [∂f2/∂x]

    Ok(())
}
968
#[test]
fn test_jacobian_wrt_all_variables() -> Result<(), Box<dyn std::error::Error>> {
    let system = EquationSystem::new(vec![
        "x^2*y + z".to_string(),   // f1
        "x*y^2 - z^2".to_string(), // f2
    ])?;

    // Full Jacobian: one row per equation, one column per variable (x, y, z).
    // Errors propagate with `?` through the test's `Result` return type.
    let jacobian_fn = system.jacobian_wrt(&["x", "y", "z"])?;
    let mut results = Array2::zeros((2, 3));
    jacobian_fn.eval_into_matrix(&vec![2.0, 3.0, 1.0], &mut results)?;

    // Expected derivatives at (x, y, z) = (2, 3, 1):
    // ∂f1/∂x = 2xy = 2(2)(3) = 12
    // ∂f1/∂y = x^2 = 4
    // ∂f1/∂z = 1
    // ∂f2/∂x = y^2 = 9
    // ∂f2/∂y = 2xy = 2(2)(3) = 12
    // ∂f2/∂z = -2z = -2
    assert_eq!(results[[0, 0]], 12.0); // ∂f1/∂x
    assert_eq!(results[[0, 1]], 4.0); // ∂f1/∂y
    assert_eq!(results[[0, 2]], 1.0); // ∂f1/∂z
    assert_eq!(results[[1, 0]], 9.0); // ∂f2/∂x
    assert_eq!(results[[1, 1]], 12.0); // ∂f2/∂y
    assert_eq!(results[[1, 2]], -2.0); // ∂f2/∂z

    Ok(())
}
999
#[test]
fn test_jacobian_wrt_invalid_variable() {
    // Asking for a column for "w", which no equation uses, must be rejected even
    // though the other requested variable ("x") is valid.
    let equations = vec!["x^2*y + z".to_string(), "x*y^2 - z^2".to_string()];
    let system = EquationSystem::new(equations).unwrap();

    let attempt = system.jacobian_wrt(&["x", "w"]);
    assert!(matches!(attempt, Err(EquationError::VariableNotFound(_))));
}
1009
#[test]
fn test_jacobian_wrt_reuse_buffer() -> Result<(), Box<dyn std::error::Error>> {
    let system = EquationSystem::new(vec![
        "x^2*y".to_string(), // f1
        "x*y^2".to_string(), // f2
    ])?;

    // The same output matrix is written twice. Every entry changes between the two
    // evaluations, so stale values from the first call would be caught.
    let jacobian_fn = system.jacobian_wrt(&["x", "y"])?;
    let mut results = Array2::zeros((2, 2));

    // First evaluation at (x, y) = (2, 3): [[2xy, x^2], [y^2, 2xy]] = [[12, 4], [9, 12]].
    jacobian_fn.eval_into_matrix(&vec![2.0, 3.0], &mut results)?;
    assert_eq!(results[[0, 0]], 12.0);
    assert_eq!(results[[0, 1]], 4.0);
    assert_eq!(results[[1, 0]], 9.0);
    assert_eq!(results[[1, 1]], 12.0);

    // Reuse the buffer at (x, y) = (1, 2): [[4, 1], [4, 4]].
    jacobian_fn.eval_into_matrix(&vec![1.0, 2.0], &mut results)?;
    assert_eq!(results[[0, 0]], 4.0);
    assert_eq!(results[[0, 1]], 1.0);
    assert_eq!(results[[1, 0]], 4.0);
    assert_eq!(results[[1, 1]], 4.0);

    Ok(())
}
1040
#[test]
fn test_different_vector_types() -> Result<(), Box<dyn std::error::Error>> {
    // The same system accepts several vector containers as input. Each must give
    // f1 = x^2*y = 12 and f2 = x*y^2 = 18 at (x, y) = (2, 3).
    let system = EquationSystem::new(vec!["x^2*y".to_string(), "x*y^2".to_string()])?;
    let expected = [12.0, 18.0];

    // std Vec<f64>
    let std_in = vec![2.0, 3.0];
    let std_out = system.eval(&std_in)?;
    assert_eq!(std_out.as_slice(), &expected);

    // ndarray::Array1 (its `as_slice` is fallible, hence the unwrap)
    let nd_in = Array1::from_vec(vec![2.0, 3.0]);
    let nd_out = system.eval(&nd_in)?;
    assert_eq!(nd_out.as_slice().unwrap(), &expected);

    // nalgebra::DVector
    let na_in = DVector::from_vec(vec![2.0, 3.0]);
    let na_out = system.eval(&na_in)?;
    assert_eq!(na_out.as_slice(), &expected);

    Ok(())
}
1062
#[test]
fn test_eval_parallel() -> Result<(), Box<dyn std::error::Error>> {
    // Batch evaluation of f1 = x^2*y, f2 = x*y^2 over several (x, y) points.
    let system = EquationSystem::new(vec!["x^2*y".to_string(), "x*y^2".to_string()])?;

    let points = vec![
        vec![2.0, 3.0],
        vec![1.0, 2.0],
        vec![3.0, 4.0],
        vec![0.0, 1.0],
    ];
    // Expected [x^2*y, x*y^2] for each point, in the same order as `points`.
    let expected = [
        [12.0, 18.0], // [2^2 * 3, 2 * 3^2]
        [2.0, 4.0],   // [1^2 * 2, 1 * 2^2]
        [36.0, 48.0], // [3^2 * 4, 3 * 4^2]
        [0.0, 0.0],   // [0^2 * 1, 0 * 1^2]
    ];

    let outputs = system.eval_parallel(&points)?;
    for (idx, want) in expected.iter().enumerate() {
        assert_eq!(outputs[idx].as_slice(), want);
    }

    Ok(())
}
1083
#[test]
fn test_eval_into() -> Result<(), Box<dyn std::error::Error>> {
    // `eval_into` writes into a caller-provided buffer.
    // f1 = x^2*y = 12, f2 = x*y^2 = 18 at (x, y) = (2, 3).
    let system = EquationSystem::new(vec!["x^2*y".to_string(), "x*y^2".to_string()])?;

    // Vec<f64> input and output buffer.
    let mut vec_out = vec![0.0; 2];
    system.eval_into(&vec![2.0, 3.0], &mut vec_out)?;
    assert_eq!(vec_out, vec![12.0, 18.0]);

    // ndarray input and output buffer.
    let nd_in = Array1::from_vec(vec![2.0, 3.0]);
    let mut nd_out = Array1::zeros(2);
    system.eval_into(&nd_in, &mut nd_out)?;
    assert_eq!(nd_out.as_slice().unwrap(), &[12.0, 18.0]);

    // A 3-slot buffer for a 2-equation system is rejected with InvalidInputLength.
    let mut oversized = vec![0.0; 3];
    assert!(matches!(
        system.eval_into(&vec![2.0, 3.0], &mut oversized),
        Err(EquationError::InvalidInputLength { .. })
    ));

    Ok(())
}
1107
#[test]
fn test_matrix_output_errors() -> Result<(), Box<dyn std::error::Error>> {
    // A plain system of scalar equations yields a vector, so both matrix-producing
    // entry points must refuse with MatrixOutputRequired.
    let system = EquationSystem::new(vec!["x^2*y".to_string(), "x*y^2".to_string()])?;
    let inputs = vec![2.0, 3.0];

    let mut buffer = Array2::zeros((2, 2));
    let into_outcome = system.eval_into_matrix(&inputs, &mut buffer);
    assert!(matches!(into_outcome, Err(EquationError::MatrixOutputRequired)));

    let alloc_outcome = system.eval_matrix::<_, Array2<f64>>(&inputs);
    assert!(matches!(alloc_outcome, Err(EquationError::MatrixOutputRequired)));

    Ok(())
}
1126
#[test]
fn test_gradient() -> Result<(), Box<dyn std::error::Error>> {
    let system = EquationSystem::new(vec![
        "x^2*y + z".to_string(),   // f1
        "x*y^2 - z^2".to_string(), // f2
    ])?;
    let point = [2.0, 3.0, 1.0]; // (x, y, z)

    // Derivative of every equation with respect to one variable at `point`:
    //   x: [∂f1/∂x = 2xy, ∂f2/∂x = y^2]
    //   y: [∂f1/∂y = x^2, ∂f2/∂y = 2xy]
    //   z: [∂f1/∂z = 1,   ∂f2/∂z = -2z]
    let cases = [
        ("x", vec![12.0, 9.0]),
        ("y", vec![4.0, 12.0]),
        ("z", vec![1.0, -2.0]),
    ];
    for (var, expected) in cases {
        assert_eq!(system.gradient(&point, var)?, expected);
    }

    // A variable that appears in no equation is rejected.
    let unknown = system.gradient(&[2.0, 3.0, 1.0], "w");
    assert!(matches!(unknown, Err(EquationError::VariableNotFound(_))));

    // Too few inputs (2 values for 3 variables) is rejected.
    let short = system.gradient(&[2.0, 3.0], "x");
    assert!(matches!(short, Err(EquationError::InvalidInputLength { .. })));

    Ok(())
}
1159
#[test]
fn test_eval_matrix_on_vector_system() -> Result<(), Box<dyn std::error::Error>> {
    let equations = vec!["x^2*y".to_string(), "x*y^2".to_string()];
    let vector_system = EquationSystem::new(equations)?;

    // Writing into a caller-supplied matrix is rejected for a vector-valued system...
    let mut matrix = Array2::zeros((2, 2));
    assert!(matches!(
        vector_system.eval_into_matrix(&vec![2.0, 3.0], &mut matrix),
        Err(EquationError::MatrixOutputRequired)
    ));

    // ...and so is asking for a freshly allocated matrix result.
    assert!(matches!(
        vector_system.eval_matrix::<_, Array2<f64>>(&vec![2.0, 3.0]),
        Err(EquationError::MatrixOutputRequired)
    ));

    Ok(())
}
1179
#[test]
fn test_getters() -> Result<(), Box<dyn std::error::Error>> {
    let system = EquationSystem::new(vec!["x^2*y".to_string(), "x*y^2".to_string()])?;

    // Variable bookkeeping: sorted names and their input-slot indices.
    assert_eq!(system.sorted_variables(), &["x", "y"]);
    let index_of = system.variables();
    assert_eq!(index_of.get("x"), Some(&0));
    assert_eq!(index_of.get("y"), Some(&1));

    // The source expressions are returned as given.
    assert_eq!(system.equations(), &["x^2*y", "x*y^2"]);

    // The compiled function writes [x^2*y, x*y^2] at (2, 3) into the output buffer.
    let compiled = system.fun();
    let mut values = vec![0.0; 2];
    compiled(&[2.0, 3.0], &mut values);
    assert_eq!(values, vec![12.0, 18.0]);

    // Derivative functions are keyed by variable name.
    let derivs = system.jacobian_funs();
    for name in ["x", "y"] {
        assert!(derivs.contains_key(name));
    }

    // d/dx of the system at (2, 3): [∂(x^2*y)/∂x, ∂(x*y^2)/∂x] = [2xy, y^2] = [12, 9].
    let d_dx = system.gradient_fun("x");
    let mut d_dx_values = vec![0.0; 2];
    d_dx(&[2.0, 3.0], &mut d_dx_values);
    assert_eq!(d_dx_values, vec![12.0, 9.0]);

    assert_eq!(system.num_equations(), 2);

    Ok(())
}
1217}